diff --git a/build_tools/autogen_ltc_backend.yaml b/build_tools/autogen_ltc_backend.yaml
index a586565f0..63434211e 100644
--- a/build_tools/autogen_ltc_backend.yaml
+++ b/build_tools/autogen_ltc_backend.yaml
@@ -8,6 +8,7 @@ blacklist:
 - index_put_ # Error: TODO not sure if there are other valid types to handle here
 
 # Ops with list of tensors output
+- split.Tensor
 - unbind.int
 
 # Additional ops which autogen is supported for but don't compile yet
diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index f65c6f5c6..0155ac529 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -726,6 +726,7 @@ STABLEHLO_PASS_SET = {
     "PrimsViewOfModule_basic",
     "PrimsViewOfZeroRankModule_basic",
     "AtenComplex64Module_basic",
+    "SplitTensorGetItem_Module_basic",
     "UnbindIntListUnpack_Module_basic",
     "UnbindIntGetItem_Module_basic",
 }
@@ -1012,6 +1013,7 @@ TOSA_PASS_SET = {
     "TensorsConcatStaticModule_basic",
     "TensorsConcatNegativeDimStaticModule_basic",
     "AtenComplex64Module_basic",
+    "SplitTensorGetItem_Module_basic",
 }
 
 LTC_XFAIL_SET = {
@@ -1191,6 +1193,7 @@ LTC_XFAIL_SET = {
    "AtenComplexImagModule_basic",
     "AtenComplexRealModule_basic",
     "AtenComplexViewModule_basic",
+    "SplitTensorGetItem_Module_basic",
     "UnbindIntListUnpack_Module_basic",
     "UnbindIntGetItem_Module_basic",
 }
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 591ee4c64..d37bc9c27 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -9566,6 +9566,30 @@ def Torch_AtenSortOp : Torch_Op<"aten.sort", [
   }];
 }
 
+def Torch_AtenSplitTensorOp : Torch_Op<"aten.split.Tensor", [
+    AllowsTypeRefinement,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::split.Tensor : (Tensor, int, int) -> (Tensor[])`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$split_size,
+    Torch_IntType:$dim
+  );
+  let results = (outs
+    AnyTorchListOfTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenSplitTensorOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void AtenSplitTensorOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+}
+
 def Torch_AtenUnbindIntOp : Torch_Op<"aten.unbind.int", [
     AllowsTypeRefinement,
     ReadOnly
diff --git a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp
index 4d664a58a..116181f84 100644
--- a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp
@@ -181,6 +181,61 @@ public:
     return success();
   }
 };
+
+// Recompose aten.split.Tensor + __getitem__ (with a constant index) into a
+// single aten.slice.Tensor, so the intermediate list of tensors (which most
+// backends cannot lower) never materializes.
+class RecomposeSplitTensorGetItemOp
+    : public OpRewritePattern<Aten__Getitem__TOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(Aten__Getitem__TOp op,
+                                PatternRewriter &rewriter) const override {
+    // recompose AtenSplitTensorOp + __getitem__t to AtenSliceTensorOp
+    // getDefiningOp<OpTy>() is null-safe: it returns null when the list is a
+    // block argument, where dyn_cast on a null getDefiningOp() would assert.
+    auto splitTensorOp = op.getList().getDefiningOp<AtenSplitTensorOp>();
+    if (!splitTensorOp)
+      return failure();
+    if (isListPotentiallyMutated(splitTensorOp.getResult()))
+      return failure();
+    int64_t index;
+    if (!matchPattern(op.getIdx(), m_TorchConstantInt(&index)))
+      return rewriter.notifyMatchFailure(
+          op, "Expected `idx` of `Aten__Getitem__TOp` to be a constant int");
+    // A negative index counts from the end of the split list, whose length
+    // depends on the (possibly dynamic) dim size; bail out instead of
+    // emitting a slice with a wrong negative start offset.
+    if (index < 0)
+      return rewriter.notifyMatchFailure(
+          op, "Expected `idx` of `Aten__Getitem__TOp` to be non-negative");
+
+    int64_t splitSize;
+    if (!matchPattern(splitTensorOp.getSplitSize(),
+                      m_TorchConstantInt(&splitSize)))
+      return rewriter.notifyMatchFailure(
+          op,
+          "Expected `SplitSize` of `AtenSplitTensorOp` to be a constant int");
+
+    // Chunk i covers [i * splitSize, (i + 1) * splitSize); aten.slice.Tensor
+    // clamps `end`, which handles a smaller final chunk.
+    Location loc = op.getLoc();
+    Value step =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    Value start = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(index * splitSize));
+    Value end = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(index * splitSize + splitSize));
+    Value sliceTensorOp = rewriter.create<AtenSliceTensorOp>(
+        loc, op.getResult().getType(), splitTensorOp.getSelf(),
+        splitTensorOp.getDim(), start, end, step);
+    rewriter.replaceOp(op, sliceTensorOp);
+    // Erase the split op once its list result is dead.
+    if (splitTensorOp.getResult().use_empty())
+      rewriter.eraseOp(splitTensorOp);
+    return success();
+  }
+};
 } // namespace
 
 namespace {
@@ -194,6 +249,7 @@ public:
     // pattern.add calls go here
     patterns.add<RecomposeSliceCopy_Op>(context);
     patterns.add<RecomposeSelectFill_Op>(context);
+    patterns.add<RecomposeSplitTensorGetItemOp>(context);
     patterns.add<RecomposeUnbindListUnpack>(context);
     patterns.add<RecomposeUnbindGetItem>(context);
 
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 81ec69689..095d3aa01 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -590,6 +590,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::any.bool : (bool[]) -> (bool)")
     emit("aten::sort.int : (int[], bool) -> ()", has_canonicalizer=True)
     emit("aten::sort : (Tensor, int, bool) -> (Tensor, Tensor)")
+    emit("aten::split.Tensor : (Tensor, int, int) -> (Tensor[])")
     emit("aten::unbind.int : (Tensor, int) -> (Tensor[])")
 
     # Str ops.
diff --git a/python/torch_mlir_e2e_test/test_suite/slice_like.py b/python/torch_mlir_e2e_test/test_suite/slice_like.py
index 7897a8ac4..7b9d7395b 100644
--- a/python/torch_mlir_e2e_test/test_suite/slice_like.py
+++ b/python/torch_mlir_e2e_test/test_suite/slice_like.py
@@ -581,3 +581,24 @@ class UnbindIntGetItem_Module(torch.nn.Module):
 @register_test_case(module_factory=lambda: UnbindIntGetItem_Module())
 def UnbindIntGetItem_Module_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 3, 4))
+
+
+# ==============================================================================
+
+class SplitTensorGetItem_Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 3, 4], torch.float32, True),
+    ])
+    def forward(self, x):
+        splits = torch.ops.aten.split(x, 1, 0)
+        return torch.ops.aten.sub(splits[0], splits[1])
+
+@register_test_case(module_factory=lambda: SplitTensorGetItem_Module())
+def SplitTensorGetItem_Module_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 4))
+