diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index bf2773488..5ec1fa768 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -273,6 +273,9 @@ TORCHDYNAMO_XFAIL_SET = { # ERROR: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {} "ScatterValueIntModule_basic", + # AssertionError: Unregistered operation: torch.aten._unsafe_index_put + "UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic", + } TORCHDYNAMO_CRASHING_SET = { @@ -1431,4 +1434,5 @@ LTC_XFAIL_SET = { "ScatterValueIntModule_basic", "IndexTensorNegativeIndexModule_basic", "UniformStaticShapeModule_basic", + "UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic", } diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 8bff52697..29c22d48f 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -4324,6 +4324,32 @@ def Torch_AtenIndexPut_HackedTwinOp : Torch_Op<"aten.index_put_.hacked_twin", [ }]; } +def Torch_Aten_UnsafeIndexPutHackedTwinOp : Torch_Op<"aten._unsafe_index_put.hacked_twin", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::_unsafe_index_put.hacked_twin : (Tensor, Tensor[], Tensor, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTensorType:$indices, + AnyTorchTensorType:$values, + Torch_BoolType:$accumulate + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten_UnsafeIndexPutHackedTwinOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void Aten_UnsafeIndexPutHackedTwinOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenLinearOp : Torch_Op<"aten.linear", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 65ddcdcdb..f41d1862d 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -8552,6 +8552,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._unsafe_index_put.hacked_twin\"(%arg0: !torch.tuple, %arg1: !torch.list>, %arg2: !torch.tuple, %arg3: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten._index_put_impl\"(%arg0: !torch.tuple, %arg1: !torch.list>>, %arg2: !torch.tuple, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 1616d4b17..d9ba95bd9 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -3254,6 +3254,25 @@ public: }; } // namespace +namespace { +// Decompose `aten._unsafe_indexPut.hackedTwin` op into `aten._index_put_impl` +// op. +class DecomposeAten_UnsafeIndexPutHackedTwinOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(Aten_UnsafeIndexPutHackedTwinOp op, + PatternRewriter &rewriter) const override { + Value cstFalse = rewriter.create(op.getLoc(), false); + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(), + op.getAccumulate(), + /*unsafe=*/cstFalse); + return success(); + } +}; +} // namespace + namespace { // Decompose `aten.pad` op into `aten.constantPadNd` op. class DecomposeAtenPadOp : public OpRewritePattern { @@ -4889,6 +4908,7 @@ public: addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index cc4230d65..caa14e9e4 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -443,6 +443,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/python/torch_mlir/__init__.py b/python/torch_mlir/__init__.py index 8de6cc1a1..0f52d14a1 100644 --- a/python/torch_mlir/__init__.py +++ b/python/torch_mlir/__init__.py @@ -274,6 +274,7 @@ def _lower_mlir_module(verbose, output_type, module): print("Torch Backend IR") print(module) + # module.dump() if output_type == OutputType.TORCH: return module @@ -292,6 +293,7 @@ def _lower_mlir_module(verbose, output_type, module): module, "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)", "Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR") + # module.dump() if verbose: print("\n====================") print("LINALG Backend IR") @@ -446,6 +448,7 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with: if output_type == OutputType.RAW: return mb.module + # mb.module.dump() option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + \ " extra-library=" + extra_library_file_name + "}" run_pipeline_with_repro_report( diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index 986d1115d..5fb21fe42 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -1652,6 +1652,11 @@ def aten〇index_put〇hacked_twin〡dtype(self_rank_dtype: Tuple[int, int], ind self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function(_index_put_invocations) +def aten〇_unsafe_index_put〇hacked_twin〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: List[Tuple[int, int]], values_rank_dtype: Tuple[int, int], accumulate: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function(_index_put_invocations) def aten〇_index_put_impl〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: List[Optional[Tuple[int, int]]], values_rank_dtype: Tuple[int, int], accumulate: bool = False, unsafe: bool = False) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index b51f8eb58..ac7bdcfdb 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -353,6 +353,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): "aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)") emit_with_mutating_variants( "aten::index_put.hacked_twin : (Tensor, Tensor[], Tensor, bool) -> (Tensor)") + emit("aten::_unsafe_index_put.hacked_twin : (Tensor, Tensor[], Tensor, bool) -> (Tensor)") # Non-elementwise tensor compute ops emit("aten::linear : (Tensor, Tensor, Tensor?) -> (Tensor)") diff --git a/python/torch_mlir_e2e_test/test_suite/scatter.py b/python/torch_mlir_e2e_test/test_suite/scatter.py index 404089a40..8a84961a5 100644 --- a/python/torch_mlir_e2e_test/test_suite/scatter.py +++ b/python/torch_mlir_e2e_test/test_suite/scatter.py @@ -843,6 +843,35 @@ def IndexPutHackedTwin3DIntAccumulateModule_basic(module, tu: TestUtils): module.forward(tu.randint(10, 8, 6, high=1000), tu.randint(5, high=4), tu.randint(5, 8, 6, high=1000)) + +# ============================================================================== +# UnsafeIndexPutHackedTwin tests are using the aten._unsafe_index_put.hacked_twin operator. + + +class UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.float32, True), + ([-1], torch.int64, True), + ([-1], torch.float32, True), + ]) + def forward(self, input, index, value): + return torch.ops.aten._unsafe_index_put(input, [index], + value, + accumulate=False) + + +@register_test_case( + module_factory=lambda: UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule()) +def UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic(module, tu: TestUtils): + module.forward(tu.rand(100), tu.randint(250, high=100), tu.rand(250)) + + # ============================================================================== class ScatterSrcStaticModule(torch.nn.Module):