[MLIR][TORCH] Add support for aten._unsafe_index_put.hacked_twin op

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
pull/2392/head snapshot-20230811.927
Vivek Khandelwal 2023-07-14 07:26:54 +00:00
parent f0a8f273f7
commit e61ef1ee54
9 changed files with 93 additions and 0 deletions

View File

@ -273,6 +273,9 @@ TORCHDYNAMO_XFAIL_SET = {
# ERROR: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {} # ERROR: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {}
"ScatterValueIntModule_basic", "ScatterValueIntModule_basic",
# AssertionError: Unregistered operation: torch.aten._unsafe_index_put
"UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
} }
TORCHDYNAMO_CRASHING_SET = { TORCHDYNAMO_CRASHING_SET = {
@ -1431,4 +1434,5 @@ LTC_XFAIL_SET = {
"ScatterValueIntModule_basic", "ScatterValueIntModule_basic",
"IndexTensorNegativeIndexModule_basic", "IndexTensorNegativeIndexModule_basic",
"UniformStaticShapeModule_basic", "UniformStaticShapeModule_basic",
"UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
} }

View File

@ -4324,6 +4324,32 @@ def Torch_AtenIndexPut_HackedTwinOp : Torch_Op<"aten.index_put_.hacked_twin", [
}]; }];
} }
def Torch_Aten_UnsafeIndexPutHackedTwinOp : Torch_Op<"aten._unsafe_index_put.hacked_twin", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::_unsafe_index_put.hacked_twin : (Tensor, Tensor[], Tensor, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTensorType:$indices,
AnyTorchTensorType:$values,
Torch_BoolType:$accumulate
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten_UnsafeIndexPutHackedTwinOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void Aten_UnsafeIndexPutHackedTwinOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
def Torch_AtenLinearOp : Torch_Op<"aten.linear", [ def Torch_AtenLinearOp : Torch_Op<"aten.linear", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,

View File

@ -8552,6 +8552,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n" " return %0#1 : !torch.int\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten._unsafe_index_put.hacked_twin\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<tuple<int, int>>, %arg2: !torch.tuple<int, int>, %arg3: !torch.bool) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten._index_put_impl\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<optional<tuple<int, int>>>, %arg2: !torch.tuple<int, int>, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten._index_put_impl\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<optional<tuple<int, int>>>, %arg2: !torch.tuple<int, int>, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n" " return %0#1 : !torch.int\n"

View File

@ -3254,6 +3254,25 @@ public:
}; };
} // namespace } // namespace
namespace {
// Decompose `aten._unsafe_indexPut.hackedTwin` op into `aten._index_put_impl`
// op.
class DecomposeAten_UnsafeIndexPutHackedTwinOp
: public OpRewritePattern<Aten_UnsafeIndexPutHackedTwinOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(Aten_UnsafeIndexPutHackedTwinOp op,
PatternRewriter &rewriter) const override {
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
rewriter.replaceOpWithNewOp<Aten_IndexPutImplOp>(
op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(),
op.getAccumulate(),
/*unsafe=*/cstFalse);
return success();
}
};
} // namespace
namespace { namespace {
// Decompose `aten.pad` op into `aten.constantPadNd` op. // Decompose `aten.pad` op into `aten.constantPadNd` op.
class DecomposeAtenPadOp : public OpRewritePattern<AtenPadOp> { class DecomposeAtenPadOp : public OpRewritePattern<AtenPadOp> {
@ -4889,6 +4908,7 @@ public:
addPatternIfTargetOpIsIllegal<DeomposeAtenNativeDropoutOp>(patterns); addPatternIfTargetOpIsIllegal<DeomposeAtenNativeDropoutOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexPutHackedTwinOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenIndexPutHackedTwinOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_UnsafeIndexPutHackedTwinOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenPadOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenPadOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenToDtypeLayoutOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenToDtypeLayoutOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenToDeviceOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenToDeviceOp>(patterns);

View File

@ -443,6 +443,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenNativeDropoutOp>(); target.addIllegalOp<AtenNativeDropoutOp>();
target.addIllegalOp<AtenNewEmptyOp>(); target.addIllegalOp<AtenNewEmptyOp>();
target.addIllegalOp<AtenIndexPutHackedTwinOp>(); target.addIllegalOp<AtenIndexPutHackedTwinOp>();
target.addIllegalOp<Aten_UnsafeIndexPutHackedTwinOp>();
target.addIllegalOp<AtenPadOp>(); target.addIllegalOp<AtenPadOp>();
target.addIllegalOp<AtenToDtypeLayoutOp>(); target.addIllegalOp<AtenToDtypeLayoutOp>();
target.addIllegalOp<AtenToDeviceOp>(); target.addIllegalOp<AtenToDeviceOp>();

View File

@ -274,6 +274,7 @@ def _lower_mlir_module(verbose, output_type, module):
print("Torch Backend IR") print("Torch Backend IR")
print(module) print(module)
# module.dump()
if output_type == OutputType.TORCH: if output_type == OutputType.TORCH:
return module return module
@ -292,6 +293,7 @@ def _lower_mlir_module(verbose, output_type, module):
module, module,
"builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)", "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)",
"Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR") "Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR")
# module.dump()
if verbose: if verbose:
print("\n====================") print("\n====================")
print("LINALG Backend IR") print("LINALG Backend IR")
@ -446,6 +448,7 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
if output_type == OutputType.RAW: if output_type == OutputType.RAW:
return mb.module return mb.module
# mb.module.dump()
option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + \ option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + \
" extra-library=" + extra_library_file_name + "}" " extra-library=" + extra_library_file_name + "}"
run_pipeline_with_repro_report( run_pipeline_with_repro_report(

View File

@ -1652,6 +1652,11 @@ def atenindex_puthacked_twin〡dtype(self_rank_dtype: Tuple[int, int], ind
self_rank, self_dtype = self_rank_dtype self_rank, self_dtype = self_rank_dtype
return self_dtype return self_dtype
@check_dtype_function(_index_put_invocations)
def aten_unsafe_index_puthacked_twin〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: List[Tuple[int, int]], values_rank_dtype: Tuple[int, int], accumulate: bool = False) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype
@check_dtype_function(_index_put_invocations) @check_dtype_function(_index_put_invocations)
def aten_index_put_impl〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: List[Optional[Tuple[int, int]]], values_rank_dtype: Tuple[int, int], accumulate: bool = False, unsafe: bool = False) -> int: def aten_index_put_impl〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: List[Optional[Tuple[int, int]]], values_rank_dtype: Tuple[int, int], accumulate: bool = False, unsafe: bool = False) -> int:
self_rank, self_dtype = self_rank_dtype self_rank, self_dtype = self_rank_dtype

View File

@ -353,6 +353,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
"aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)") "aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)")
emit_with_mutating_variants( emit_with_mutating_variants(
"aten::index_put.hacked_twin : (Tensor, Tensor[], Tensor, bool) -> (Tensor)") "aten::index_put.hacked_twin : (Tensor, Tensor[], Tensor, bool) -> (Tensor)")
emit("aten::_unsafe_index_put.hacked_twin : (Tensor, Tensor[], Tensor, bool) -> (Tensor)")
# Non-elementwise tensor compute ops # Non-elementwise tensor compute ops
emit("aten::linear : (Tensor, Tensor, Tensor?) -> (Tensor)") emit("aten::linear : (Tensor, Tensor, Tensor?) -> (Tensor)")

View File

@ -843,6 +843,35 @@ def IndexPutHackedTwin3DIntAccumulateModule_basic(module, tu: TestUtils):
module.forward(tu.randint(10, 8, 6, high=1000), tu.randint(5, high=4), module.forward(tu.randint(10, 8, 6, high=1000), tu.randint(5, high=4),
tu.randint(5, 8, 6, high=1000)) tu.randint(5, 8, 6, high=1000))
# ==============================================================================
# UnsafeIndexPutHackedTwin tests are using the aten._unsafe_index_put.hacked_twin operator.
class UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1], torch.float32, True),
([-1], torch.int64, True),
([-1], torch.float32, True),
])
def forward(self, input, index, value):
return torch.ops.aten._unsafe_index_put(input, [index],
value,
accumulate=False)
@register_test_case(
module_factory=lambda: UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule())
def UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic(module, tu: TestUtils):
module.forward(tu.rand(100), tu.randint(250, high=100), tu.rand(250))
# ============================================================================== # ==============================================================================
class ScatterSrcStaticModule(torch.nn.Module): class ScatterSrcStaticModule(torch.nn.Module):