From d4313eed4a61d33f31dd4c42088c0c024ed0aec6 Mon Sep 17 00:00:00 2001
From: Xinyu Yang
Date: Thu, 18 Apr 2024 06:27:51 +0800
Subject: [PATCH] [Torch] Add decomposition of RepeatInterleaveSelfInt Op
 (#3075)

Decompose RepeatInterleaveSelfInt with the following ops:
```python
def my_repeat_interleave(input, repeats, dim=None):
    if dim is None:
        # Flatten the input and then repeat
        return input.flatten().unsqueeze(-1).tile((1, repeats)).flatten()
    else:
        # Calculate the shape after repeat
        expanded_shape = list(input.shape)
        expanded_shape[dim] *= repeats
        # Repeat the tensor along the specified dimension
        repeat_shape = [1] * (input.dim() + 1)
        repeat_shape[dim + 1] = repeats
        input = input.unsqueeze(-1)

        # Tile and then reshape
        tiled = torch.tile(input, repeat_shape)
        # Rearrange and reshape
        repeated = tiled.reshape(*expanded_shape)
        return repeated
```

The stablehlo and linalg tests pass. When testing onnx, something strange happens: in torch-mlir's CI **torch_nightly** and in my own environment (torch==2.4.0.dev20240318+cpu) the test **passes**, but in torch-mlir's CI **torch_stable** it **fails**. The failing test case is `RepeatInterleaveSelfIntNoDimModule_basic`; the expected result shape is [120].
```python
class RepeatInterleaveSelfIntNoDimModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 4, 5], torch.float32, True),
    ])
    def forward(self, x):
        return x.repeat_interleave(2)


@register_test_case(module_factory=lambda: RepeatInterleaveSelfIntNoDimModule())
def RepeatInterleaveSelfIntNoDimModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5))
```
The error log is as follows:
```
Unexpected outcome summary: (onnx)

****** Failed tests - 1 tests
    FAIL - "RepeatInterleaveSelfIntNoDimModule_basic"
        @ trace item #0 - call to "forward"
        @ output of call to "forward"
        ERROR: shape (torch.Size([6, 4, 5])) is not equal to golden shape (torch.Size([120]))
```

@rsuderman Would you please help me check what's wrong with my PR? Thanks a lot.
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     | 26 +++++
 lib/Conversion/TorchToStablehlo/ViewLike.cpp  | 54 +++++++++++
 .../Transforms/AbstractInterpLibrary.cpp      | 30 ++++++
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 96 +++++++++++++++++++
 .../Transforms/LowerToBackendContract.cpp     |  1 +
 projects/pt1/e2e_testing/xfail_sets.py        | 12 +++
 .../build_tools/abstract_interp_lib_gen.py    | 14 +++
 .../build_tools/torch_ods_gen.py              |  1 +
 .../torch_mlir_e2e_test/test_suite/basic.py   | 41 ++++++++
 9 files changed, 275 insertions(+)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index e10bf6464..8d00e0f96 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -10418,6 +10418,32 @@ def Torch_AtenRepeatOp : Torch_Op<"aten.repeat", [
   }];
 }
 
+def Torch_AtenRepeatInterleaveSelfIntOp : Torch_Op<"aten.repeat_interleave.self_int", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::repeat_interleave.self_int : (Tensor, int, int?, int?) -> (Tensor)`";
-> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$repeats, + AnyTorchOptionalIntType:$dim, + AnyTorchOptionalIntType:$output_size + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRepeatInterleaveSelfIntOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenRepeatInterleaveSelfIntOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenTileOp : Torch_Op<"aten.tile", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToStablehlo/ViewLike.cpp b/lib/Conversion/TorchToStablehlo/ViewLike.cpp index f6c879907..fdd482a0d 100644 --- a/lib/Conversion/TorchToStablehlo/ViewLike.cpp +++ b/lib/Conversion/TorchToStablehlo/ViewLike.cpp @@ -393,6 +393,59 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + PrimsCollapseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto selfType = adaptor.getA().getType().dyn_cast(); + if (!selfType) { + return op.emitError("only tensor types are currently supported"); + } + + auto rank = selfType.getRank(); + if (rank == 0) + return rewriter.notifyMatchFailure( + op, "the rank of tensor must be greater than 0"); + + int64_t start, end; + if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) + return rewriter.notifyMatchFailure( + op, "only constant start is currently supported"); + if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) + return rewriter.notifyMatchFailure( + op, "only constant end is currently supported"); + + start = toPositiveDim(start, rank); + end = toPositiveDim(end, rank); + SmallVector dims; + dims.reserve(rank); + for (int r = 0; r < start; ++r) + dims.push_back(r); + int64_t collapsedDimSize = 1; + for (int r = start; r <= end; ++r) { + if (selfType.getShape()[r] == ShapedType::kDynamic) + return rewriter.notifyMatchFailure( + op, "the size of the dimension being collapsed is can't be unknown"); + collapsedDimSize *= selfType.getShape()[r]; + } + dims.push_back(collapsedDimSize); + for (int r = end + 1; r < rank; ++r) + dims.push_back(r); + + auto newDimSizesInfo = hlo::getDimSizesOfTensor( + rewriter, op, adaptor.getA(), dims, options.dimSizeIndexBits); + if (failed(newDimSizesInfo)) + return rewriter.notifyMatchFailure( + op, "failed to get dimension sizes of the input"); + auto newDimSizes = *newDimSizesInfo; + auto stablehloShape = + rewriter.create(op.getLoc(), newDimSizes); + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), adaptor.getA(), + stablehloShape); + return success(); +} + void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, const TorchToStablehloOptions &options) { @@ -405,6 +458,7 @@ void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality( INSERT_ATENOP_PATTERN(AtenSqueezeOp); INSERT_ATENOP_PATTERN(AtenSqueezeDimOp); INSERT_ATENOP_PATTERN(AtenUnsqueezeOp); + INSERT_ATENOP_PATTERN(PrimsCollapseOp); #undef INSERT_ATENOP_PATTERN #define INSERT_VIEW_OP_PATTERN(AtenOp) \ diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index b1e3a7322..5ad9b8216 100644 --- 
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7331,6 +7331,32 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %6 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.repeat_interleave.self_int\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
+"    %int-1 = torch.constant.int -1\n"
+"    %none = torch.constant.none\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int1 = torch.constant.int 1\n"
+"    %0 = torch.aten.__is__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
+"    %1 = torch.prim.If %0 -> (!torch.list<int>) {\n"
+"      %2 = func.call @__torch__.torch.jit._shape_functions.flatten(%arg0, %int0, %int-1) : (!torch.list<int>, !torch.int, !torch.int) -> !torch.list<int>\n"
+"      %3 = torch.aten.__getitem__.t %2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %4 = torch.aten.mul.int %3, %arg1 : !torch.int, !torch.int -> !torch.int\n"
+"      %5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list<int>\n"
+"      torch.prim.If.yield %5 : !torch.list<int>\n"
+"    } else {\n"
+"      %2 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
+"      %3 = torch.aten.slice.t %arg0, %none, %2, %int1 : !torch.list<int>, !torch.none, !torch.int, !torch.int -> !torch.list<int>\n"
+"      %4 = torch.aten.__getitem__.t %arg0, %2 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %5 = torch.aten.mul.int %4, %arg1 : !torch.int, !torch.int -> !torch.int\n"
+"      %6 = torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>\n"
+"      %7 = torch.aten.add.t %3, %6 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
+"      %8 = torch.aten.add.int %2, %int1 : !torch.int, !torch.int -> !torch.int\n"
+"      %9 = torch.aten.slice.t %arg0, %8, %none, %int1 : !torch.list<int>, !torch.int, !torch.none, !torch.int -> !torch.list<int>\n"
+"      %10 = torch.aten.add.t %7, %9 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
+"      torch.prim.If.yield %10 : !torch.list<int>\n"
+"    }\n"
+"    return %1 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.tile\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
 "    %int1 = torch.constant.int 1\n"
 "    %0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
@@ -10429,6 +10455,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.repeat_interleave.self_int\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.tile\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index e436bdf92..97bd85063 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -2800,6 +2800,100 @@ public:
 };
 } // namespace
 
+// Decompose aten.repeat_interleave.self_int into aten.unsqueeze, aten.expand,
+// and prims.collapse (or aten.flatten.using_ints when dim is None).
+namespace {
+
+class DecomposeAtenRepeatInterleaveSelfIntOp
+    : public OpRewritePattern<AtenRepeatInterleaveSelfIntOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenRepeatInterleaveSelfIntOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto context = op.getContext();
+    Value self = op.getSelf();
+    auto selfTy = cast<BaseTensorType>(self.getType());
+    if (!selfTy.hasSizes())
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented: no implementation for rankless tensor");
+    auto resType = op.getType().cast<BaseTensorType>();
+    if (!resType.hasSizes())
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented: no implementation for rankless tensor");
+
+    int64_t inputRank = selfTy.getSizes().size();
+    int64_t repeats;
+    if (!matchPattern(op.getRepeats(), m_TorchConstantInt(&repeats)))
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented: repeats not constant int");
+
+    bool dimIsNone = false;
+    int64_t dim;
+    Value dimValue = op.getDim();
+    if (dimValue.getType().isa<Torch::NoneType>()) {
+      dimIsNone = true;
+      dim = inputRank - 1;
+    } else {
+      if (!matchPattern(dimValue, m_TorchConstantInt(&dim)))
+        return rewriter.notifyMatchFailure(
+            op, "Unimplemented: dim not constant int");
+      dim = toPositiveDim(dim, inputRank);
+    }
+
+    dimValue =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(dim));
+    Value dimValuePlusOne = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(dim + 1));
+
+    auto unsqueezedInfo = unsqueezeTensor(rewriter, op, self, dimValuePlusOne);
+    if (failed(unsqueezedInfo))
+      return rewriter.notifyMatchFailure(op,
+                                         "cannot generate unsqueeze tensor op");
+    self = *unsqueezedInfo;
+
+    Value constMinusOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
+    SmallVector<Value> expandShapeValueList(inputRank + 1, constMinusOne);
+    expandShapeValueList[dim + 1] = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(repeats));
+    Value expandShapeList = rewriter.create<PrimListConstructOp>(
+        loc, ListType::get(IntType::get(context)), expandShapeValueList);
+    Value constFalse =
+        rewriter.create<ConstantBoolOp>(loc, rewriter.getBoolAttr(false));
+
+    SmallVector<int64_t> expandShape(inputRank + 1);
+    for (int64_t i = 0; i <= dim; i++) {
+      expandShape[i] = selfTy.getSizes()[i];
+    }
+    expandShape[dim + 1] = repeats;
+    for (int64_t i = dim + 1; i < inputRank; i++) {
+      expandShape[i + 1] = selfTy.getSizes()[i];
+    }
+
+    BaseTensorType expandTy = rewriter.getType<ValueTensorType>(
+        expandShape, selfTy.getOptionalDtype());
+
+    Value expandSelf = rewriter.create<AtenExpandOp>(
+        loc, expandTy, self, expandShapeList, constFalse);
+
+    Value result;
+    if (dimIsNone) {
+      Value constZero =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
+      result = rewriter.create<AtenFlattenUsingIntsOp>(
+          loc, resType, expandSelf, constZero, constMinusOne);
+    } else {
+      result = rewriter.create<PrimsCollapseOp>(loc, resType, expandSelf,
+                                                dimValue, dimValuePlusOne);
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // namespace
+
 // Decompose aten.flatten.using_ints into aten.view op.
namespace { class DecomposeAtenFlattenUsingIntsOp @@ -7465,6 +7559,8 @@ public: addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal( + patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index b10456cd6..701300fef 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -377,6 +377,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 53d468b17..f6949abcf 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -588,6 +588,8 @@ STABLEHLO_PASS_SET = { "ChunkListUnpackUneven_Module_basic", "ChunkListUnpack_Module_basic", "CloneModule_basic", + "CollapseAllDimensionsModule_basic", + "CollapseStaticModule_basic", "ConstantBoolParameterModule_basic", "ContainsIntList_False", "ContainsIntList_True", @@ -853,6 +855,8 @@ STABLEHLO_PASS_SET = { "ReduceSumFloatModule_basic", "ReduceSumSignedIntModule_basic", "ReduceSumUnsignedIntModule_basic", + "RepeatInterleaveSelfIntModule_basic", + "RepeatInterleaveSelfIntNoDimModule_basic", "ReturnThreeTensorFloat32_basic", "ReturnTwoTensorF32I64_basic", "RollModule_basic", @@ -1390,6 +1394,7 @@ TOSA_PASS_SET = { "ReduceSumSignedIntModule_basic", "ReduceSumUnsignedIntModule_basic", "RepeatModule_basic", + "RepeatInterleaveSelfIntNoDimModule_basic", "ResNet18StaticModule_basic", "ReshapeAliasCollapseModule_basic", "ReshapeAliasExpandModule_basic", @@ -1512,6 +1517,7 @@ MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | { "TensorIntModule_basic", "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", + "RepeatInterleaveSelfIntModule_basic", "TorchPrimLoopForLikeTensorArgModule_basic", "ViewSizeDimFollowedByCollapsedOnesModule_basic", "ViewSizeDimFollowedByExpandedOnesModule_basic", @@ -2352,6 +2358,12 @@ if torch_version_for_comparison() >= version.parse("2.4.0.dev"): "ReduceL1NormWithDTypeModule_basic", } +if torch_version_for_comparison() < version.parse('2.3.0.dev'): + ONNX_XFAIL_SET = ONNX_XFAIL_SET | { + # ERROR: shape (torch.Size([6, 4, 5])) is not equal to golden shape (torch.Size([120])) + "RepeatInterleaveSelfIntNoDimModule_basic", + } + ONNX_CRASHING_SET = { "FakeQuantizePerTensorAffineModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 1097778e5..5d63d7a7d 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -726,6 +726,15 @@ def aten〇repeat〡shape(self: List[int], repeats: List[int]) -> List[int]: out.append(self[i] * repeats[i + leading_rank]) return out +def aten〇repeat_interleave〇self_int〡shape(self: List[int], repeats: int, dim: Optional[int] = None, output_size: Optional[int] = None) -> List[int]: + if dim is None: + 
flatten_size = upstream_shape_functions.flatten(self, 0, -1)[0] + return [flatten_size * repeats] + else: + out = self[:dim] + [self[dim] * repeats] + self[dim + 1:] + return out + + @check_shape_function([ Invocation(TensorOfShape(3, 2, 8), [2, 2]), # dims_length < self_length Invocation(TensorOfShape(3, 2, 8), [2, 2, 2]) # dims_length >= self_length @@ -2625,6 +2634,11 @@ def aten〇repeat〡dtype(self_rank_dtype: Tuple[int, int], repeats: List[int]) self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, repeats=1)) +def aten〇repeat_interleave〇self_int〡dtype(self_rank_dtype: Tuple[int, int], repeats: int, dim: Optional[int] = None, output_size: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[1])) def aten〇tile〡dtype(self_rank_dtype: Tuple[int, int], dims: List[int]) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 15baeef00..78861202c 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -648,6 +648,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)") emit("aten::numel : (Tensor) -> (int)", has_canonicalizer=True) emit("aten::repeat : (Tensor, int[]) -> (Tensor)") + emit("aten::repeat_interleave.self_int : (Tensor, int, int?, int?) -> (Tensor)") emit("aten::tile : (Tensor, int[]) -> (Tensor)") emit("aten::reshape : (Tensor, int[]) -> (Tensor)") emit("aten::reshape_as : (Tensor, Tensor) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index bc55973d0..fa99522a8 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -1842,6 +1842,47 @@ def RepeatModule_basic(module, tu: TestUtils): # ============================================================================== +class RepeatInterleaveSelfIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 4, 5], torch.float32, True), + ]) + def forward(self, x): + return x.repeat_interleave(2, 1) + + +@register_test_case(module_factory=lambda: RepeatInterleaveSelfIntModule()) +def RepeatInterleaveSelfIntModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5)) + +# ============================================================================== + + +class RepeatInterleaveSelfIntNoDimModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 4, 5], torch.float32, True), + ]) + def forward(self, x): + return x.repeat_interleave(2) + + +@register_test_case(module_factory=lambda: RepeatInterleaveSelfIntNoDimModule()) +def RepeatInterleaveSelfIntNoDimModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5)) + +# ============================================================================== + class TileSmallDimsSizeModule(torch.nn.Module): def __init__(self):