diff --git a/build_tools/autogen_ltc_backend.yaml b/build_tools/autogen_ltc_backend.yaml
index 77ebc93c0..a586565f0 100644
--- a/build_tools/autogen_ltc_backend.yaml
+++ b/build_tools/autogen_ltc_backend.yaml
@@ -7,6 +7,9 @@ blacklist:
 - index_put # Error: TODO not sure if there are other valid types to handle here
 - index_put_ # Error: TODO not sure if there are other valid types to handle here
 
+# Ops with list of tensors output
+- unbind.int
+
 # Additional ops which autogen is supported for but don't compile yet
 - _convolution
 - detach
diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index 30f054780..f65c6f5c6 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -259,6 +259,10 @@ TORCHDYNAMO_XFAIL_SET = {
     "AtenComplexImagModule_basic",
     "AtenComplexRealModule_basic",
     # END tests failing due to: complex floating point ops
+
+    # ERROR: Exception: Unsupported: return type List[Tensor] in schema for aten.unbind.int
+    "UnbindIntListUnpack_Module_basic",
+    "UnbindIntGetItem_Module_basic",
 }
 
 TORCHDYNAMO_CRASHING_SET = {
@@ -722,6 +726,8 @@ STABLEHLO_PASS_SET = {
     "PrimsViewOfModule_basic",
     "PrimsViewOfZeroRankModule_basic",
     "AtenComplex64Module_basic",
+    "UnbindIntListUnpack_Module_basic",
+    "UnbindIntGetItem_Module_basic",
 }
 
 # Write the TOSA set as a "passing" set as it is very early in development
@@ -1001,6 +1007,8 @@ TOSA_PASS_SET = {
     "PrimsViewOfModule_basic",
     "PrimsViewOfZeroRankModule_basic",
     "DetachModule_basic",
+    "UnbindIntListUnpack_Module_basic",
+    "UnbindIntGetItem_Module_basic",
     "TensorsConcatStaticModule_basic",
     "TensorsConcatNegativeDimStaticModule_basic",
     "AtenComplex64Module_basic",
@@ -1182,5 +1190,7 @@ LTC_XFAIL_SET = {
     "VarMeanDimBiasedModule_basic",
     "AtenComplexImagModule_basic",
     "AtenComplexRealModule_basic",
-    "AtenComplexViewModule_basic"
+    "AtenComplexViewModule_basic",
+    "UnbindIntListUnpack_Module_basic",
+    "UnbindIntGetItem_Module_basic",
 }
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index c14daf0be..7a828e754 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -9521,6 +9521,29 @@ def Torch_AtenSortOp : Torch_Op<"aten.sort", [
   }];
 }
 
+def Torch_AtenUnbindIntOp : Torch_Op<"aten.unbind.int", [
+    AllowsTypeRefinement,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::unbind.int : (Tensor, int) -> (Tensor[])`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$dim
+  );
+  let results = (outs
+    AnyTorchListOfTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenUnbindIntOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenUnbindIntOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenAddStrOp : Torch_Op<"aten.add.str", [
   AllowsTypeRefinement,
   HasValueSemantics,
diff --git a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp
index dbddcc312..d35a8f564 100644
--- a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp
@@ -121,6 +121,66 @@ public:
     return success();
   }
 };
+
+class RecomposeUnbindListUnpack : public OpRewritePattern<PrimListUnpackOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(PrimListUnpackOp op,
+                                PatternRewriter &rewriter) const override {
+    // recompose AtenUnbindIntOp + PrimListUnpackOp to select.int
+    auto unbind = dyn_cast<AtenUnbindIntOp>(op.getOperand().getDefiningOp());
+    if (!unbind)
+      return failure();
+    if (isListPotentiallyMutated(unbind.getResult()))
+      return failure();
+    Value dim = unbind.getDim();
+    Value input = unbind.getSelf();
+    SmallVector<Value> slices;
+    for (size_t i = 0; i < op.getNumResults(); i++) {
+      // rewrite to select.int op
+      auto resultTy = op.getResult(i).getType();
+      auto index = rewriter.create<Torch::ConstantIntOp>(
+          op->getLoc(), rewriter.getI64IntegerAttr(i));
+      auto newSelect = rewriter.create<AtenSelectIntOp>(op->getLoc(), resultTy,
+                                                        input, dim, index);
+      slices.push_back(newSelect);
+    }
+    rewriter.replaceOp(op, slices);
+    if (unbind.getResult().use_empty())
+      rewriter.eraseOp(unbind);
+    return success();
+  }
+};
+
+class RecomposeUnbindGetItem : public OpRewritePattern<Aten__Getitem__TOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(Aten__Getitem__TOp op,
+                                PatternRewriter &rewriter) const override {
+    // recompose AtenUnbindIntOp + __getitem__.t to select.int
+    auto unbind = dyn_cast<AtenUnbindIntOp>(op.getList().getDefiningOp());
+    if (!unbind)
+      return failure();
+    if (isListPotentiallyMutated(unbind.getResult()))
+      return failure();
+    int64_t index;
+    if (!matchPattern(op.getIdx(), m_TorchConstantInt(&index)))
+      return rewriter.notifyMatchFailure(
+          op, "Expected `idx` of `Aten__Getitem__TOp` to be a constant int");
+
+    Location loc = op.getLoc();
+    Value dim = unbind.getDim();
+    Value input = unbind.getSelf();
+    // rewrite to select.int op
+    auto resultTy = op.getResult().getType();
+    Value newSelect = rewriter.create<AtenSelectIntOp>(loc, resultTy, input,
+                                                       dim, op.getIdx());
+    rewriter.replaceOp(op, newSelect);
+    if (unbind.getResult().use_empty())
+      rewriter.eraseOp(unbind);
+    return success();
+  }
+};
 } // namespace
 
 namespace {
@@ -134,6 +194,8 @@ public:
     // pattern.add calls go here
     patterns.add<RecomposeSliceCopy_>(context);
     patterns.add<RecomposeSelectFill_>(context);
+    patterns.add<RecomposeUnbindListUnpack>(context);
+    patterns.add<RecomposeUnbindGetItem>(context);
 
     GreedyRewriteConfig config;
     config.useTopDownTraversal = true;
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index b49136c73..c1d27b8ed 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -589,6 +589,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::any.bool : (bool[]) -> (bool)")
     emit("aten::sort.int : (int[], bool) -> ()", has_canonicalizer=True)
     emit("aten::sort : (Tensor, int, bool) -> (Tensor, Tensor)")
+    emit("aten::unbind.int : (Tensor, int) -> (Tensor[])")
 
     # Str ops.
     emit("aten::add.str : (str, str) -> (str)")
diff --git a/python/torch_mlir_e2e_test/test_suite/slice_like.py b/python/torch_mlir_e2e_test/test_suite/slice_like.py
index 08cb00e19..7897a8ac4 100644
--- a/python/torch_mlir_e2e_test/test_suite/slice_like.py
+++ b/python/torch_mlir_e2e_test/test_suite/slice_like.py
@@ -542,3 +542,42 @@ class SliceCopyNegative_Module(torch.nn.Module):
 @register_test_case(module_factory=lambda: SliceCopyNegative_Module())
 def SliceCopyNegative_Module_basic(module, tu: TestUtils):
     module.forward(tu.rand(10, 4, 4), tu.rand(4, 4, 4))
+
+
+# ==============================================================================
+
+class UnbindIntListUnpack_Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 3, 4], torch.float32, True),
+    ])
+    def forward(self, x):
+        unbind_0, unbind_1 = torch.unbind(x, 0)
+        return torch.ops.aten.sub(unbind_0, unbind_1)
+
+@register_test_case(module_factory=lambda: UnbindIntListUnpack_Module())
+def UnbindIntListUnpack_Module_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 4))
+
+# ==============================================================================
+
+class UnbindIntGetItem_Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 3, 4], torch.float32, True),
+    ])
+    def forward(self, x):
+        unbind = torch.unbind(x, 0)
+        return torch.ops.aten.sub(unbind[0], unbind[1])
+
+@register_test_case(module_factory=lambda: UnbindIntGetItem_Module())
+def UnbindIntGetItem_Module_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 4))
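
For reference, a minimal standalone PyTorch sketch (not part of the patch) of the equivalence the two recompose patterns rely on: element `i` of `aten.unbind.int` along `dim` is exactly `aten.select.int` at index `i`, which is why both the `prim.ListUnpack` and the `__getitem__.t` consumers can be rewritten to per-element selects. It uses only standard `torch` APIs, and the tensor shape mirrors the e2e tests above.

    import torch

    x = torch.rand(2, 3, 4)

    # What the e2e tests exercise: unpacking / indexing the unbind result.
    unbind_0, unbind_1 = torch.unbind(x, 0)

    # What the recompose patterns emit: one select per extracted element.
    assert torch.equal(unbind_0, x.select(0, 0))
    assert torch.equal(unbind_1, x.select(0, 1))
    assert torch.equal(torch.ops.aten.sub(unbind_0, unbind_1),
                       x.select(0, 0) - x.select(0, 1))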