diff --git a/build_tools/autogen_ltc_backend.yaml b/build_tools/autogen_ltc_backend.yaml
index 63434211e..87c3b1088 100644
--- a/build_tools/autogen_ltc_backend.yaml
+++ b/build_tools/autogen_ltc_backend.yaml
@@ -10,6 +10,7 @@ blacklist:
 # Ops with list of tensors output
 - split.Tensor
 - unbind.int
+- chunk
 
 # Additional ops which autogen is supported for but don't compile yet
 - _convolution
diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index 2742efba2..38224e7c9 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -731,6 +731,8 @@ STABLEHLO_PASS_SET = {
     "SplitTensorGetItem_Module_basic",
     "UnbindIntListUnpack_Module_basic",
     "UnbindIntGetItem_Module_basic",
+    "ChunkListUnpack_Module_basic",
+    "ChunkListUnpackUneven_Module_basic",
     "RandIntDtypeModule_basic",
     "RandIntLowDtypeModule_basic",
     "RandIntLowModule_basic",
@@ -1022,6 +1024,8 @@ TOSA_PASS_SET = {
     "TensorsConcatNegativeDimStaticModule_basic",
     "AtenComplex64Module_basic",
     "SplitTensorGetItem_Module_basic",
+    "ChunkListUnpack_Module_basic",
+    "ChunkListUnpackUneven_Module_basic",
 }
 
 LTC_XFAIL_SET = {
@@ -1204,4 +1208,8 @@ LTC_XFAIL_SET = {
     "SplitTensorGetItem_Module_basic",
     "UnbindIntListUnpack_Module_basic",
     "UnbindIntGetItem_Module_basic",
+    "ChunkListUnpack_Module_basic",
+    "ChunkListUnpackUneven_Module_basic",
+    "ChunkListUnpackDynamic_Module_basic",
+    "ChunkListUnpackUnevenDynamic_Module_basic",
 }
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index d37bc9c27..3cb4fb642 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -9613,6 +9613,30 @@ def Torch_AtenUnbindIntOp : Torch_Op<"aten.unbind.int", [
   }];
 }
 
+def Torch_AtenChunkOp : Torch_Op<"aten.chunk", [
+    AllowsTypeRefinement,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::chunk : (Tensor, int, int) -> (Tensor[])`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$chunks,
+    Torch_IntType:$dim
+  );
+  let results = (outs
+    AnyTorchListOfTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenChunkOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void AtenChunkOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+}
+
 def Torch_AtenAddStrOp : Torch_Op<"aten.add.str", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp
index 116181f84..6d202783c 100644
--- a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp
@@ -130,14 +130,15 @@ public:
     // recompose AtenUnbindOp + PrimListUnpackOp to select.int
     auto unbind = dyn_cast<AtenUnbindIntOp>(op.getOperand().getDefiningOp());
     if (!unbind)
-      return failure();
+      return rewriter.notifyMatchFailure(op, "Input is not AtenUnbindIntOp");
     if (isListPotentiallyMutated(unbind.getResult()))
-      return failure();
+      return rewriter.notifyMatchFailure(
+          op, "AtenUnbindIntOp result is potentially mutated");
     Value dim = unbind.getDim();
     Value input = unbind.getSelf();
     SmallVector<Value> slices;
     for (size_t i = 0; i < op.getNumResults(); i++) {
-      // rewrite to slice op
+      // rewrite to select.int op
       auto resultTy = op.getResult(i).getType();
       auto index = rewriter.create<Torch::ConstantIntOp>(
           op->getLoc(), rewriter.getI64IntegerAttr(i));
@@ -160,9 +161,10 @@ public:
     // recompose AtenUnbindIntOp + __getitem__t to select.int
     auto unbind = dyn_cast<AtenUnbindIntOp>(op.getList().getDefiningOp());
     if (!unbind)
-      return failure();
+      return rewriter.notifyMatchFailure(op, "Input is not AtenUnbindIntOp");
     if (isListPotentiallyMutated(unbind.getResult()))
-      return failure();
+      return rewriter.notifyMatchFailure(
+          op, "AtenUnbindIntOp result is potentially mutated");
     int64_t index;
     if (!matchPattern(op.getIdx(), m_TorchConstantInt(&index)))
       return rewriter.notifyMatchFailure(
@@ -192,9 +194,10 @@ public:
     auto splitTensorOp =
         dyn_cast<AtenSplitTensorOp>(op.getList().getDefiningOp());
     if (!splitTensorOp)
-      return failure();
+      return rewriter.notifyMatchFailure(op, "Input is not AtenSplitTensorOp");
     if (isListPotentiallyMutated(splitTensorOp.getResult()))
-      return failure();
+      return rewriter.notifyMatchFailure(
+          op, "SplitTensorOp result is potentially mutated");
     int64_t index;
     if (!matchPattern(op.getIdx(), m_TorchConstantInt(&index)))
       return rewriter.notifyMatchFailure(
@@ -223,6 +226,59 @@ public:
     return success();
   }
 };
+
+class RecomposeChunkListUnpack : public OpRewritePattern<PrimListUnpackOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(PrimListUnpackOp op,
+                                PatternRewriter &rewriter) const override {
+    // recompose AtenChunkOp + PrimListUnpackOp to AtenSliceTensorOps
+    auto chunk = dyn_cast<AtenChunkOp>(op.getOperand().getDefiningOp());
+    if (!chunk)
+      return rewriter.notifyMatchFailure(op, "Input is not AtenChunkOp");
+    if (isListPotentiallyMutated(chunk.getResult()))
+      return rewriter.notifyMatchFailure(
+          op, "AtenChunkOp result is potentially mutated");
+    Value dim = chunk.getDim();
+    Value input = chunk.getSelf();
+    Value chunks = chunk.getChunks();
+    Location loc = chunk.getLoc();
+    Value totalSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
+
+    // chunkSize = floordiv(totalSize + chunks - 1, chunks)
+    Value cstOne =
+        rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    Value dividend = rewriter.create<AtenAddIntOp>(loc, totalSize, chunks);
+    dividend = rewriter.create<AtenSubIntOp>(loc, dividend, cstOne);
+    Value chunkSize = rewriter.create<AtenFloordivIntOp>(loc, dividend, chunks);
+
+    SmallVector<Value> slices;
+    for (size_t i = 0; i < op.getNumResults(); i++) {
+      // rewrite to slice op with
+      //   start = chunkSize * i,
+      //   end = lastIndex ? totalSize : chunkSize * (i+1)
+      auto resultTy = op.getResult(i).getType();
+      auto index = rewriter.create<Torch::ConstantIntOp>(
+          op->getLoc(), rewriter.getI64IntegerAttr(i));
+      auto start = rewriter.create<AtenMulIntOp>(loc, index, chunkSize);
+      Value end;
+      if (i == op.getNumResults() - 1) {
+        end = totalSize;
+      } else {
+        auto nextIdx = rewriter.create<AtenAddIntOp>(loc, index, cstOne);
+        end = rewriter.create<AtenMulIntOp>(loc, nextIdx, chunkSize);
+      }
+      Value sliceTensorOp = rewriter.create<AtenSliceTensorOp>(
+          loc, resultTy, input, dim, start, end, cstOne);
+      slices.push_back(sliceTensorOp);
+    }
+    rewriter.replaceOp(op, slices);
+    // erase chunkOp if no users are left
+    if (chunk.getResult().use_empty())
+      rewriter.eraseOp(chunk);
+    return success();
+  }
+};
 } // namespace
 
 namespace {
@@ -239,6 +295,7 @@ public:
     patterns.add(context);
     patterns.add(context);
     patterns.add(context);
+    patterns.add<RecomposeChunkListUnpack>(context);
 
     GreedyRewriteConfig config;
     config.useTopDownTraversal = true;
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 095d3aa01..368f34777 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -592,6 +592,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::sort : (Tensor, int, bool) -> (Tensor, Tensor)")
     emit("aten::split.Tensor : (Tensor, int, int) -> (Tensor[])")
     emit("aten::unbind.int : (Tensor, int) -> (Tensor[])")
+    emit("aten::chunk : (Tensor, int, int) -> (Tensor[])")
 
     # Str ops.
     emit("aten::add.str : (str, str) -> (str)")
diff --git a/python/torch_mlir_e2e_test/test_suite/slice_like.py b/python/torch_mlir_e2e_test/test_suite/slice_like.py
index 7b9d7395b..a279be5dd 100644
--- a/python/torch_mlir_e2e_test/test_suite/slice_like.py
+++ b/python/torch_mlir_e2e_test/test_suite/slice_like.py
@@ -602,3 +602,82 @@ class SplitTensorGetItem_Module(torch.nn.Module):
 def SplitTensorGetItem_Module_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 3, 4))
 
+# ==============================================================================
+
+class ChunkListUnpack_Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 12, 2], torch.float32, True),
+    ])
+    def forward(self, x):
+        chunk_0, chunk_1, chunk_2 = torch.chunk(x, 3, 1)
+        add = torch.ops.aten.add(chunk_0, chunk_1)
+        sum = torch.ops.aten.add(add, chunk_2)
+        return sum
+
+@register_test_case(module_factory=lambda: ChunkListUnpack_Module())
+def ChunkListUnpack_Module_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 12, 2))
+
+# ==============================================================================
+
+class ChunkListUnpackUneven_Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 13, 2], torch.float32, True),
+    ])
+    def forward(self, x):
+        chunk_0, chunk_1, chunk_2 = torch.chunk(x, 3, 1)
+        return torch.ops.aten.add(chunk_0, chunk_1), chunk_2
+
+@register_test_case(module_factory=lambda: ChunkListUnpackUneven_Module())
+def ChunkListUnpackUneven_Module_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 13, 2))
+
+# ==============================================================================
+
+class ChunkListUnpackDynamic_Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        chunk_0, chunk_1, chunk_2 = torch.chunk(x, 3, 1)
+        add = torch.ops.aten.add(chunk_0, chunk_1)
+        sum = torch.ops.aten.add(add, chunk_2)
+        return sum
+
+@register_test_case(module_factory=lambda: ChunkListUnpackDynamic_Module())
+def ChunkListUnpackDynamic_Module_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 12, 2))
+
+# ==============================================================================
+
+class ChunkListUnpackUnevenDynamic_Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        chunk_0, chunk_1, chunk_2 = torch.chunk(x, 3, 1)
+        return torch.ops.aten.add(chunk_0, chunk_1), chunk_2
+
+@register_test_case(module_factory=lambda: ChunkListUnpackUnevenDynamic_Module())
+def ChunkListUnpackUnevenDynamic_Module_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 13, 2))
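
Editorial note, not part of the patch: RecomposeChunkListUnpack computes chunkSize = floordiv(totalSize + chunks - 1, chunks), i.e. a ceiling division, and emits aten.slice.Tensor ops over [i * chunkSize, (i + 1) * chunkSize), clamping the last slice to totalSize. Below is a minimal Python sketch of that arithmetic, checked against torch.chunk for the even and uneven shapes used in the tests above; the helper name chunk_via_slices is hypothetical and introduced only for illustration.

import torch

def chunk_via_slices(x, chunks, dim):
    # Mirrors RecomposeChunkListUnpack: ceil-divide the dimension size,
    # then take fixed-size slices, clamping the last one to the dim size.
    total_size = x.size(dim)
    chunk_size = (total_size + chunks - 1) // chunks  # floordiv(totalSize + chunks - 1, chunks)
    slices = []
    for i in range(chunks):
        start = i * chunk_size
        end = total_size if i == chunks - 1 else (i + 1) * chunk_size
        slices.append(torch.ops.aten.slice(x, dim, start, end, 1))
    return slices

for shape in [(2, 12, 2), (2, 13, 2)]:  # even and uneven cases from the tests
    x = torch.rand(*shape)
    for ours, ref in zip(chunk_via_slices(x, 3, 1), torch.chunk(x, 3, 1)):
        assert torch.equal(ours, ref)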