[Torch] Add decomposition of RepeatInterleaveSelfInt Op (#3075)

Decomposition RepeatInterleaveSelfInt with following ops:
```python

def my_repeat_interleave(input, repeats, dim=None):
    if dim is None:
        # Flatten the input and then repeat
        return input.flatten().unsqueeze(-1).tile((1, repeats)).flatten()
    else:
        # Calculate the shape after repeat
        expanded_shape = list(input.shape)
        expanded_shape[dim] *= repeats
        # Repeat the tensor along the specified dimension
        repeat_shape = [1] * (input.dim() + 1)
        repeat_shape[dim + 1] = repeats
        input = input.unsqueeze(-1)

        # Tile and then reshape
        tiled = torch.tile(input, repeat_shape)
        # Rearrange and reshape
        repeated = tiled.reshape(*expanded_shape)
    return repeated

```

I passed the tests of stablehlo and linalg. When testing onnx, strange
things happened.
In torch-mlir's CI **torch_nightly** and my own
environment(torch==2.4.0.dev20240318+cpu), it can **pass the pass**.
In torch-mlir's CI  **torch_stable**, it **failed**.
The test case is `RepeatInterleaveSelfIntNoDimModule_basic`, the result
shape should be [120].
```python
class RepeatInterleaveSelfIntNoDimModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 4, 5], torch.float32, True),
    ])
    def forward(self, x):
        return x.repeat_interleave(2)


@register_test_case(module_factory=lambda: RepeatInterleaveSelfIntNoDimModule())
def RepeatInterleaveSelfIntNoDimModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5))
```
The error log is as follows:
```
  Unexpected outcome summary: (onnx)
  
  ****** Failed tests - 1 tests
      FAIL - "RepeatInterleaveSelfIntNoDimModule_basic"
          @ trace item #0 - call to "forward"
          @ output of call to "forward"
          ERROR: shape (torch.Size([6, 4, 5])) is not equal to golden shape (torch.Size([120]))
```

@rsuderman 
Would you please help me check what's wrong with my PR? Thanks a lot.
pull/3183/head
Xinyu Yang 2024-04-18 06:27:51 +08:00 committed by GitHub
parent 491f4820f5
commit d4313eed4a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 275 additions and 0 deletions

View File

@ -10418,6 +10418,32 @@ def Torch_AtenRepeatOp : Torch_Op<"aten.repeat", [
}];
}
def Torch_AtenRepeatInterleaveSelfIntOp : Torch_Op<"aten.repeat_interleave.self_int", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::repeat_interleave.self_int : (Tensor, int, int?, int?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$repeats,
AnyTorchOptionalIntType:$dim,
AnyTorchOptionalIntType:$output_size
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenRepeatInterleaveSelfIntOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenRepeatInterleaveSelfIntOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
def Torch_AtenTileOp : Torch_Op<"aten.tile", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -393,6 +393,59 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
return success();
}
template <>
LogicalResult ConvertAtenOp<PrimsCollapseOp>::matchAndRewrite(
PrimsCollapseOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto selfType = adaptor.getA().getType().dyn_cast<TensorType>();
if (!selfType) {
return op.emitError("only tensor types are currently supported");
}
auto rank = selfType.getRank();
if (rank == 0)
return rewriter.notifyMatchFailure(
op, "the rank of tensor must be greater than 0");
int64_t start, end;
if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
return rewriter.notifyMatchFailure(
op, "only constant start is currently supported");
if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end)))
return rewriter.notifyMatchFailure(
op, "only constant end is currently supported");
start = toPositiveDim(start, rank);
end = toPositiveDim(end, rank);
SmallVector<int64_t, 4> dims;
dims.reserve(rank);
for (int r = 0; r < start; ++r)
dims.push_back(r);
int64_t collapsedDimSize = 1;
for (int r = start; r <= end; ++r) {
if (selfType.getShape()[r] == ShapedType::kDynamic)
return rewriter.notifyMatchFailure(
op, "the size of the dimension being collapsed is can't be unknown");
collapsedDimSize *= selfType.getShape()[r];
}
dims.push_back(collapsedDimSize);
for (int r = end + 1; r < rank; ++r)
dims.push_back(r);
auto newDimSizesInfo = hlo::getDimSizesOfTensor(
rewriter, op, adaptor.getA(), dims, options.dimSizeIndexBits);
if (failed(newDimSizesInfo))
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
auto newDimSizes = *newDimSizesInfo;
auto stablehloShape =
rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.getA(),
stablehloShape);
return success();
}
void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options) {
@ -405,6 +458,7 @@ void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenSqueezeOp);
INSERT_ATENOP_PATTERN(AtenSqueezeDimOp);
INSERT_ATENOP_PATTERN(AtenUnsqueezeOp);
INSERT_ATENOP_PATTERN(PrimsCollapseOp);
#undef INSERT_ATENOP_PATTERN
#define INSERT_VIEW_OP_PATTERN(AtenOp) \

View File

@ -7331,6 +7331,32 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %6 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.repeat_interleave.self_int\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
" %int-1 = torch.constant.int -1\n"
" %none = torch.constant.none\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.__is__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.list<int>) {\n"
" %2 = func.call @__torch__.torch.jit._shape_functions.flatten(%arg0, %int0, %int-1) : (!torch.list<int>, !torch.int, !torch.int) -> !torch.list<int>\n"
" %3 = torch.aten.__getitem__.t %2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %4 = torch.aten.mul.int %3, %arg1 : !torch.int, !torch.int -> !torch.int\n"
" %5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %5 : !torch.list<int>\n"
" } else {\n"
" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
" %3 = torch.aten.slice.t %arg0, %none, %2, %int1 : !torch.list<int>, !torch.none, !torch.int, !torch.int -> !torch.list<int>\n"
" %4 = torch.aten.__getitem__.t %arg0, %2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %5 = torch.aten.mul.int %4, %arg1 : !torch.int, !torch.int -> !torch.int\n"
" %6 = torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>\n"
" %7 = torch.aten.add.t %3, %6 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
" %8 = torch.aten.add.int %2, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %9 = torch.aten.slice.t %arg0, %8, %none, %int1 : !torch.list<int>, !torch.int, !torch.none, !torch.int -> !torch.list<int>\n"
" %10 = torch.aten.add.t %7, %9 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
" torch.prim.If.yield %10 : !torch.list<int>\n"
" }\n"
" return %1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.tile\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
@ -10429,6 +10455,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.repeat_interleave.self_int\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.tile\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"

View File

@ -2800,6 +2800,100 @@ public:
};
} // namespace
// decompose aten.repeat_interleave.self_int into following ops:
// aten.flatten.using_ints, aten.unsqueeze, aten.tile, aten.reshape
namespace {
class DecomposeAtenRepeatInterleaveSelfIntOp
: public OpRewritePattern<AtenRepeatInterleaveSelfIntOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenRepeatInterleaveSelfIntOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto context = op.getContext();
Value self = op.getSelf();
auto selfTy = cast<BaseTensorType>(self.getType());
if (!selfTy.hasSizes())
return rewriter.notifyMatchFailure(
op, "Unimplemented: no implementation for rankless tensor");
auto resType = op.getType().cast<BaseTensorType>();
if (!resType.hasSizes())
return rewriter.notifyMatchFailure(
op, "Unimplemented: no implementation for rankless tensor");
int64_t inputRank = selfTy.getSizes().size();
int64_t repeats;
if (!matchPattern(op.getRepeats(), m_TorchConstantInt(&repeats)))
return rewriter.notifyMatchFailure(
op, "Unimplemented: repeats not constant int");
bool dimIsNone = false;
int64_t dim;
Value dimValue = op.getDim();
if (dimValue.getType().isa<Torch::NoneType>()) {
dimIsNone = true;
dim = inputRank - 1;
} else {
if (!matchPattern(dimValue, m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(
op, "Unimplemented: dim not constant int");
dim = toPositiveDim(dim, inputRank);
}
dimValue =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(dim));
Value dimValuePlusOne = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(dim + 1));
auto unsqueezedInfo = unsqueezeTensor(rewriter, op, self, dimValuePlusOne);
if (failed(unsqueezedInfo))
return rewriter.notifyMatchFailure(op,
"cannot generate unsqueeze tensor op");
self = *unsqueezedInfo;
Value constMinusOne =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
SmallVector<Value> expandShapeValueList(inputRank + 1, constMinusOne);
expandShapeValueList[dim + 1] = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(repeats));
Value expandShapeList = rewriter.create<PrimListConstructOp>(
loc, ListType::get(IntType::get(context)), expandShapeValueList);
Value constFalse =
rewriter.create<ConstantBoolOp>(loc, rewriter.getBoolAttr(false));
SmallVector<int64_t> expandShape(inputRank + 1);
for (int64_t i = 0; i <= dim; i++) {
expandShape[i] = selfTy.getSizes()[i];
}
expandShape[dim + 1] = repeats;
for (int64_t i = dim + 1; i < inputRank; i++) {
expandShape[i + 1] = selfTy.getSizes()[i];
}
BaseTensorType expandTy = rewriter.getType<ValueTensorType>(
expandShape, selfTy.getOptionalDtype());
Value expandSelf = rewriter.create<AtenExpandOp>(
loc, expandTy, self, expandShapeList, constFalse);
Value result;
if (dimIsNone) {
Value constZero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
result = rewriter.create<AtenFlattenUsingIntsOp>(
loc, resType, expandSelf, constZero, constMinusOne);
} else {
result = rewriter.create<PrimsCollapseOp>(loc, resType, expandSelf,
dimValue, dimValuePlusOne);
}
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace
// Decompose aten.flatten.using_ints into aten.view op.
namespace {
class DecomposeAtenFlattenUsingIntsOp
@ -7465,6 +7559,8 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenStackOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRollOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRepeatOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRepeatInterleaveSelfIntOp>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenExpandOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenFlattenUsingIntsOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenUnflattenIntOp>(patterns);

View File

@ -377,6 +377,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenStackOp>();
target.addIllegalOp<AtenRollOp>();
target.addIllegalOp<AtenRepeatOp>();
target.addIllegalOp<AtenRepeatInterleaveSelfIntOp>();
target.addIllegalOp<AtenExpandOp>();
target.addIllegalOp<AtenFlattenUsingIntsOp>();
target.addIllegalOp<AtenWhereScalarOp>();

View File

@ -588,6 +588,8 @@ STABLEHLO_PASS_SET = {
"ChunkListUnpackUneven_Module_basic",
"ChunkListUnpack_Module_basic",
"CloneModule_basic",
"CollapseAllDimensionsModule_basic",
"CollapseStaticModule_basic",
"ConstantBoolParameterModule_basic",
"ContainsIntList_False",
"ContainsIntList_True",
@ -853,6 +855,8 @@ STABLEHLO_PASS_SET = {
"ReduceSumFloatModule_basic",
"ReduceSumSignedIntModule_basic",
"ReduceSumUnsignedIntModule_basic",
"RepeatInterleaveSelfIntModule_basic",
"RepeatInterleaveSelfIntNoDimModule_basic",
"ReturnThreeTensorFloat32_basic",
"ReturnTwoTensorF32I64_basic",
"RollModule_basic",
@ -1390,6 +1394,7 @@ TOSA_PASS_SET = {
"ReduceSumSignedIntModule_basic",
"ReduceSumUnsignedIntModule_basic",
"RepeatModule_basic",
"RepeatInterleaveSelfIntNoDimModule_basic",
"ResNet18StaticModule_basic",
"ReshapeAliasCollapseModule_basic",
"ReshapeAliasExpandModule_basic",
@ -1512,6 +1517,7 @@ MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
"TensorIntModule_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
"RepeatInterleaveSelfIntModule_basic",
"TorchPrimLoopForLikeTensorArgModule_basic",
"ViewSizeDimFollowedByCollapsedOnesModule_basic",
"ViewSizeDimFollowedByExpandedOnesModule_basic",
@ -2352,6 +2358,12 @@ if torch_version_for_comparison() >= version.parse("2.4.0.dev"):
"ReduceL1NormWithDTypeModule_basic",
}
if torch_version_for_comparison() < version.parse('2.3.0.dev'):
ONNX_XFAIL_SET = ONNX_XFAIL_SET | {
# ERROR: shape (torch.Size([6, 4, 5])) is not equal to golden shape (torch.Size([120]))
"RepeatInterleaveSelfIntNoDimModule_basic",
}
ONNX_CRASHING_SET = {
"FakeQuantizePerTensorAffineModule_basic",

View File

@ -726,6 +726,15 @@ def atenrepeat〡shape(self: List[int], repeats: List[int]) -> List[int]:
out.append(self[i] * repeats[i + leading_rank])
return out
def atenrepeat_interleaveself_int〡shape(self: List[int], repeats: int, dim: Optional[int] = None, output_size: Optional[int] = None) -> List[int]:
if dim is None:
flatten_size = upstream_shape_functions.flatten(self, 0, -1)[0]
return [flatten_size * repeats]
else:
out = self[:dim] + [self[dim] * repeats] + self[dim + 1:]
return out
@check_shape_function([
Invocation(TensorOfShape(3, 2, 8), [2, 2]), # dims_length < self_length
Invocation(TensorOfShape(3, 2, 8), [2, 2, 2]) # dims_length >= self_length
@ -2625,6 +2634,11 @@ def atenrepeat〡dtype(self_rank_dtype: Tuple[int, int], repeats: List[int])
self_rank, self_dtype = self_rank_dtype
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, repeats=1))
def atenrepeat_interleaveself_int〡dtype(self_rank_dtype: Tuple[int, int], repeats: int, dim: Optional[int] = None, output_size: Optional[int] = None) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[1]))
def atentile〡dtype(self_rank_dtype: Tuple[int, int], dims: List[int]) -> int:
self_rank, self_dtype = self_rank_dtype

View File

@ -648,6 +648,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)")
emit("aten::numel : (Tensor) -> (int)", has_canonicalizer=True)
emit("aten::repeat : (Tensor, int[]) -> (Tensor)")
emit("aten::repeat_interleave.self_int : (Tensor, int, int?, int?) -> (Tensor)")
emit("aten::tile : (Tensor, int[]) -> (Tensor)")
emit("aten::reshape : (Tensor, int[]) -> (Tensor)")
emit("aten::reshape_as : (Tensor, Tensor) -> (Tensor)")

View File

@ -1842,6 +1842,47 @@ def RepeatModule_basic(module, tu: TestUtils):
# ==============================================================================
class RepeatInterleaveSelfIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3, 4, 5], torch.float32, True),
])
def forward(self, x):
return x.repeat_interleave(2, 1)
@register_test_case(module_factory=lambda: RepeatInterleaveSelfIntModule())
def RepeatInterleaveSelfIntModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))
# ==============================================================================
class RepeatInterleaveSelfIntNoDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3, 4, 5], torch.float32, True),
])
def forward(self, x):
return x.repeat_interleave(2)
@register_test_case(module_factory=lambda: RepeatInterleaveSelfIntNoDimModule())
def RepeatInterleaveSelfIntNoDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))
# ==============================================================================
class TileSmallDimsSizeModule(torch.nn.Module):
def __init__(self):