diff --git a/build_tools/autogen_ltc_backend.yaml b/build_tools/autogen_ltc_backend.yaml index 1ef6d83bc..d57f693cc 100644 --- a/build_tools/autogen_ltc_backend.yaml +++ b/build_tools/autogen_ltc_backend.yaml @@ -5,6 +5,7 @@ blacklist: # Ops with list of tensors output - split.Tensor +- split_with_sizes - unbind.int - chunk diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index b9836a7ee..67c0515d2 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -872,6 +872,7 @@ STABLEHLO_PASS_SET = { "SplitTensorListUnpackModule_basic", "SplitTensorNegativeDimModule_basic", "SplitTensorLastSmallerModule_basic", + "SplitWithSizesListUnpackModule_basic", "UnbindIntListUnpack_Module_basic", "UnbindIntGetItem_Module_basic", "ChunkListUnpack_Module_basic", @@ -1216,6 +1217,7 @@ TOSA_PASS_SET = { "SplitTensorListUnpackModule_basic", "SplitTensorNegativeDimModule_basic", "SplitTensorLastSmallerModule_basic", + "SplitWithSizesListUnpackModule_basic", "ChunkListUnpack_Module_basic", "ChunkListUnpackUneven_Module_basic", "TupleModule_basic", diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index bb3250317..09147dc8f 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -11017,6 +11017,30 @@ def Torch_AtenSplitTensorOp : Torch_Op<"aten.split.Tensor", [ }]; } +def Torch_AtenSplitWithSizesOp : Torch_Op<"aten.split_with_sizes", [ + AllowsTypeRefinement, + ReadOnly + ]> { + let summary = "Generated op for `aten::split_with_sizes : (Tensor, int[], int) -> (Tensor[])`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$split_sizes, + Torch_IntType:$dim + ); + let results = (outs + AnyTorchListOfTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSplitWithSizesOp::parse(OpAsmParser &parser, OperationState 
&result) {
+    return parseDefaultTorchOp(parser, result, 3, 1);
+  }
+  void AtenSplitWithSizesOp::print(OpAsmPrinter &printer) {
+    printDefaultTorchOp(printer, *this, 3, 1);
+  }
+  }];
+}
+
 def Torch_AtenUnbindIntOp : Torch_Op<"aten.unbind.int", [
     AllowsTypeRefinement,
     ReadOnly
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index c2bedd5f6..872856f67 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -1444,7 +1444,7 @@ OpFoldResult AtenAnyBoolOp::fold(FoldAdaptor adaptor) {
     return nullptr;
   // If any operand is a constant true, return true.
   for (auto operand : inputConstruct.getOperands()) {
-    bool b;
+    bool b = false;
     if (matchPattern(operand, m_TorchConstantBool(&b)) && b) {
       return getI1IntegerAttr(getContext(), true);
     }
diff --git a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp
index 30ed51079..01a962f0b 100644
--- a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp
@@ -363,6 +363,94 @@ public:
   }
 };
 
+class RecomposeSplitWithSizesListUnpack
+    : public OpRewritePattern<PrimListUnpackOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(PrimListUnpackOp op,
+                                PatternRewriter &rewriter) const override {
+    // recompose AtenSplitWithSizesOp + PrimListUnpackOp to AtenSliceTensorOps
+    // NOTE: getDefiningOp<T>() is null-safe (the operand may be a block
+    // argument), unlike dyn_cast on a possibly-null Operation*.
+    auto splitOp = op.getOperand().getDefiningOp<AtenSplitWithSizesOp>();
+    if (!splitOp) {
+      return rewriter.notifyMatchFailure(op,
+                                         "Input is not AtenSplitWithSizesOp");
+    }
+    if (isListPotentiallyMutated(splitOp.getResult())) {
+      return rewriter.notifyMatchFailure(
+          op, "splitWithSizesOp result is potentially mutated");
+    }
+    if (isListPotentiallyMutated(splitOp.getSplitSizes())) {
+      return rewriter.notifyMatchFailure(
+          op, "splitWithSizesOp's split_sizes is potentially mutated");
+    }
+    auto splitSizesConstruct =
+        splitOp.getSplitSizes().getDefiningOp<PrimListConstructOp>();
+    if (!splitSizesConstruct) {
      return rewriter.notifyMatchFailure(
+          op, "split_sizes is not from PrimListConstructOp");
+    }
+
+    int64_t sumSplitSize = 0;
+    SmallVector<int64_t> splitSizes;
+    for (auto operand : splitSizesConstruct.getOperands()) {
+      int64_t value = -1;
+      // TODO: support when split_sizes are not constant int
+      if (!matchPattern(operand, m_TorchConstantInt(&value))) {
+        return rewriter.notifyMatchFailure(
+            op, "one of split_sizes is not constant int");
+      }
+      // Zero-size splits are legal for aten::split_with_sizes; only reject
+      // negative sizes.
+      if (value < 0) {
+        return rewriter.notifyMatchFailure(
+            op, "all of split_sizes must be non-negative");
+      }
+      sumSplitSize += value;
+      splitSizes.push_back(value);
+    }
+    if (splitSizes.size() != op.getNumResults()) {
+      return rewriter.notifyMatchFailure(
+          op, "split_sizes must be same as splitOp result size");
+    }
+
+    Location loc = op.getLoc();
+    Value input = splitOp.getSelf();
+    Value dim = splitOp.getDim();
+
+    // add runtime.assert to check that the size of the split dimension equals
+    // the sum of split_sizes
+    Value totalSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
+    Value cstSumSplitSize = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(sumSplitSize));
+    Value eqOrNot =
+        rewriter.create<AtenEqIntOp>(loc, totalSize, cstSumSplitSize);
+    rewriter.create<RuntimeAssertOp>(
+        loc, eqOrNot,
+        rewriter.getStringAttr("split dim must be sum of split_sizes"));
+
+    // calculate slice op's lower bound and upper bound
+    SmallVector<int64_t> boundaryOfSliceOp(splitSizes.size() + 1, 0);
+    for (size_t i = 1; i < boundaryOfSliceOp.size(); i++) {
+      boundaryOfSliceOp[i] = boundaryOfSliceOp[i - 1] + splitSizes[i - 1];
+    }
+    SmallVector<Value> slices;
+    Value cstOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    for (size_t i = 0; i < op.getNumResults(); i++) {
+      auto resultTy = op.getResult(i).getType();
+      auto start = rewriter.create<ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(boundaryOfSliceOp[i]));
+      auto end = rewriter.create<ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr((boundaryOfSliceOp[i + 1])));
+      Value sliceTensorOp = rewriter.create<AtenSliceTensorOp>(
+          loc, resultTy, input, dim, start, end, /*step=*/cstOne);
+      slices.push_back(sliceTensorOp);
+    }
+    rewriter.replaceOp(op, 
slices);
+    // erase splitOp if no user left
+    if (splitOp.getResult().use_empty())
+      rewriter.eraseOp(splitOp);
+    return success();
+  }
+};
+
 class RecomposeChunkListUnpack : public OpRewritePattern<PrimListUnpackOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -436,6 +524,7 @@ public:
     patterns.add(context);
     patterns.add(context);
     patterns.add(context);
+    patterns.add<RecomposeSplitWithSizesListUnpack>(context);
     patterns.add(context);
     patterns.add(context);
     patterns.add(context);
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index e418a9337..0866f0b9a 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -651,6 +651,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::sort.int : (int[], bool) -> ()", has_canonicalizer=True)
     emit("aten::sort : (Tensor, int, bool) -> (Tensor, Tensor)")
     emit("aten::split.Tensor : (Tensor, int, int) -> (Tensor[])")
+    emit("aten::split_with_sizes : (Tensor, int[], int) -> (Tensor[])")
     emit("aten::unbind.int : (Tensor, int) -> (Tensor[])")
     emit("aten::chunk : (Tensor, int, int) -> (Tensor[])")
 
diff --git a/python/torch_mlir_e2e_test/test_suite/slice_like.py b/python/torch_mlir_e2e_test/test_suite/slice_like.py
index e18775a85..e5d31fe9c 100644
--- a/python/torch_mlir_e2e_test/test_suite/slice_like.py
+++ b/python/torch_mlir_e2e_test/test_suite/slice_like.py
@@ -800,6 +800,26 @@ def SplitTensorNegativeDimModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class SplitWithSizesListUnpackModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([10, 12], torch.float32, True)
+    ])
+    def forward(self, x):
+        s0, s1, s2 = torch.ops.aten.split_with_sizes(x, [3, 4, 5], -1)
+        return (s0, s1, s2)
+
+@register_test_case(module_factory=lambda: SplitWithSizesListUnpackModule()) +def SplitWithSizesListUnpackModule_basic(module, tu: TestUtils): + module.forward(tu.rand(10, 12)) + +# ============================================================================== + class ChunkListUnpack_Module(torch.nn.Module): def __init__(self): super().__init__()