mirror of https://github.com/llvm/torch-mlir
[Torch] Add decomposition of RepeatInterleaveSelfInt Op (#3075)
Decomposition RepeatInterleaveSelfInt with following ops: ```python def my_repeat_interleave(input, repeats, dim=None): if dim is None: # Flatten the input and then repeat return input.flatten().unsqueeze(-1).tile((1, repeats)).flatten() else: # Calculate the shape after repeat expanded_shape = list(input.shape) expanded_shape[dim] *= repeats # Repeat the tensor along the specified dimension repeat_shape = [1] * (input.dim() + 1) repeat_shape[dim + 1] = repeats input = input.unsqueeze(-1) # Tile and then reshape tiled = torch.tile(input, repeat_shape) # Rearrange and reshape repeated = tiled.reshape(*expanded_shape) return repeated ``` I passed the tests of stablehlo and linalg. When testing onnx, strange things happened. In torch-mlir's CI **torch_nightly** and my own environment(torch==2.4.0.dev20240318+cpu), it can **pass the pass**. In torch-mlir's CI **torch_stable**, it **failed**. The test case is `RepeatInterleaveSelfIntNoDimModule_basic`, the result shape should be [120]. ```python class RepeatInterleaveSelfIntNoDimModule(torch.nn.Module): def __init__(self): super().__init__() @export @annotate_args([ None, ([3, 4, 5], torch.float32, True), ]) def forward(self, x): return x.repeat_interleave(2) @register_test_case(module_factory=lambda: RepeatInterleaveSelfIntNoDimModule()) def RepeatInterleaveSelfIntNoDimModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) ``` The error log is as follows: ``` Unexpected outcome summary: (onnx) ****** Failed tests - 1 tests FAIL - "RepeatInterleaveSelfIntNoDimModule_basic" @ trace item #0 - call to "forward" @ output of call to "forward" ERROR: shape (torch.Size([6, 4, 5])) is not equal to golden shape (torch.Size([120])) ``` @rsuderman Would you please help me check what's wrong with my PR? Thanks a lot.pull/3183/head
parent
491f4820f5
commit
d4313eed4a
|
@ -10418,6 +10418,32 @@ def Torch_AtenRepeatOp : Torch_Op<"aten.repeat", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenRepeatInterleaveSelfIntOp : Torch_Op<"aten.repeat_interleave.self_int", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::repeat_interleave.self_int : (Tensor, int, int?, int?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_IntType:$repeats,
|
||||
AnyTorchOptionalIntType:$dim,
|
||||
AnyTorchOptionalIntType:$output_size
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenRepeatInterleaveSelfIntOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||
}
|
||||
void AtenRepeatInterleaveSelfIntOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 4, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenTileOp : Torch_Op<"aten.tile", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -393,6 +393,59 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<PrimsCollapseOp>::matchAndRewrite(
|
||||
PrimsCollapseOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto selfType = adaptor.getA().getType().dyn_cast<TensorType>();
|
||||
if (!selfType) {
|
||||
return op.emitError("only tensor types are currently supported");
|
||||
}
|
||||
|
||||
auto rank = selfType.getRank();
|
||||
if (rank == 0)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "the rank of tensor must be greater than 0");
|
||||
|
||||
int64_t start, end;
|
||||
if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only constant start is currently supported");
|
||||
if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only constant end is currently supported");
|
||||
|
||||
start = toPositiveDim(start, rank);
|
||||
end = toPositiveDim(end, rank);
|
||||
SmallVector<int64_t, 4> dims;
|
||||
dims.reserve(rank);
|
||||
for (int r = 0; r < start; ++r)
|
||||
dims.push_back(r);
|
||||
int64_t collapsedDimSize = 1;
|
||||
for (int r = start; r <= end; ++r) {
|
||||
if (selfType.getShape()[r] == ShapedType::kDynamic)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "the size of the dimension being collapsed is can't be unknown");
|
||||
collapsedDimSize *= selfType.getShape()[r];
|
||||
}
|
||||
dims.push_back(collapsedDimSize);
|
||||
for (int r = end + 1; r < rank; ++r)
|
||||
dims.push_back(r);
|
||||
|
||||
auto newDimSizesInfo = hlo::getDimSizesOfTensor(
|
||||
rewriter, op, adaptor.getA(), dims, options.dimSizeIndexBits);
|
||||
if (failed(newDimSizesInfo))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to get dimension sizes of the input");
|
||||
auto newDimSizes = *newDimSizesInfo;
|
||||
auto stablehloShape =
|
||||
rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
|
||||
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), adaptor.getA(),
|
||||
stablehloShape);
|
||||
return success();
|
||||
}
|
||||
|
||||
void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
||||
|
@ -405,6 +458,7 @@ void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
|
|||
INSERT_ATENOP_PATTERN(AtenSqueezeOp);
|
||||
INSERT_ATENOP_PATTERN(AtenSqueezeDimOp);
|
||||
INSERT_ATENOP_PATTERN(AtenUnsqueezeOp);
|
||||
INSERT_ATENOP_PATTERN(PrimsCollapseOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
|
||||
#define INSERT_VIEW_OP_PATTERN(AtenOp) \
|
||||
|
|
|
@ -7331,6 +7331,32 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %6 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.repeat_interleave.self_int\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
|
||||
" %int-1 = torch.constant.int -1\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
" %0 = torch.aten.__is__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
|
||||
" %1 = torch.prim.If %0 -> (!torch.list<int>) {\n"
|
||||
" %2 = func.call @__torch__.torch.jit._shape_functions.flatten(%arg0, %int0, %int-1) : (!torch.list<int>, !torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %3 = torch.aten.__getitem__.t %2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %4 = torch.aten.mul.int %3, %arg1 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %5 = torch.prim.ListConstruct %4 : (!torch.int) -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield %5 : !torch.list<int>\n"
|
||||
" } else {\n"
|
||||
" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
|
||||
" %3 = torch.aten.slice.t %arg0, %none, %2, %int1 : !torch.list<int>, !torch.none, !torch.int, !torch.int -> !torch.list<int>\n"
|
||||
" %4 = torch.aten.__getitem__.t %arg0, %2 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %5 = torch.aten.mul.int %4, %arg1 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %6 = torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list<int>\n"
|
||||
" %7 = torch.aten.add.t %3, %6 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
|
||||
" %8 = torch.aten.add.int %2, %int1 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %9 = torch.aten.slice.t %arg0, %8, %none, %int1 : !torch.list<int>, !torch.int, !torch.none, !torch.int -> !torch.list<int>\n"
|
||||
" %10 = torch.aten.add.t %7, %9 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield %10 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" return %1 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.tile\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
" %0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
|
||||
|
@ -10429,6 +10455,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.repeat_interleave.self_int\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.tile\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
|
|
|
@ -2800,6 +2800,100 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// decompose aten.repeat_interleave.self_int into following ops:
|
||||
// aten.flatten.using_ints, aten.unsqueeze, aten.tile, aten.reshape
|
||||
namespace {
|
||||
|
||||
class DecomposeAtenRepeatInterleaveSelfIntOp
|
||||
: public OpRewritePattern<AtenRepeatInterleaveSelfIntOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenRepeatInterleaveSelfIntOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
auto context = op.getContext();
|
||||
Value self = op.getSelf();
|
||||
auto selfTy = cast<BaseTensorType>(self.getType());
|
||||
if (!selfTy.hasSizes())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unimplemented: no implementation for rankless tensor");
|
||||
auto resType = op.getType().cast<BaseTensorType>();
|
||||
if (!resType.hasSizes())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unimplemented: no implementation for rankless tensor");
|
||||
|
||||
int64_t inputRank = selfTy.getSizes().size();
|
||||
int64_t repeats;
|
||||
if (!matchPattern(op.getRepeats(), m_TorchConstantInt(&repeats)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unimplemented: repeats not constant int");
|
||||
|
||||
bool dimIsNone = false;
|
||||
int64_t dim;
|
||||
Value dimValue = op.getDim();
|
||||
if (dimValue.getType().isa<Torch::NoneType>()) {
|
||||
dimIsNone = true;
|
||||
dim = inputRank - 1;
|
||||
} else {
|
||||
if (!matchPattern(dimValue, m_TorchConstantInt(&dim)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unimplemented: dim not constant int");
|
||||
dim = toPositiveDim(dim, inputRank);
|
||||
}
|
||||
|
||||
dimValue =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(dim));
|
||||
Value dimValuePlusOne = rewriter.create<ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(dim + 1));
|
||||
|
||||
auto unsqueezedInfo = unsqueezeTensor(rewriter, op, self, dimValuePlusOne);
|
||||
if (failed(unsqueezedInfo))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"cannot generate unsqueeze tensor op");
|
||||
self = *unsqueezedInfo;
|
||||
|
||||
Value constMinusOne =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
|
||||
SmallVector<Value> expandShapeValueList(inputRank + 1, constMinusOne);
|
||||
expandShapeValueList[dim + 1] = rewriter.create<ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(repeats));
|
||||
Value expandShapeList = rewriter.create<PrimListConstructOp>(
|
||||
loc, ListType::get(IntType::get(context)), expandShapeValueList);
|
||||
Value constFalse =
|
||||
rewriter.create<ConstantBoolOp>(loc, rewriter.getBoolAttr(false));
|
||||
|
||||
SmallVector<int64_t> expandShape(inputRank + 1);
|
||||
for (int64_t i = 0; i <= dim; i++) {
|
||||
expandShape[i] = selfTy.getSizes()[i];
|
||||
}
|
||||
expandShape[dim + 1] = repeats;
|
||||
for (int64_t i = dim + 1; i < inputRank; i++) {
|
||||
expandShape[i + 1] = selfTy.getSizes()[i];
|
||||
}
|
||||
|
||||
BaseTensorType expandTy = rewriter.getType<ValueTensorType>(
|
||||
expandShape, selfTy.getOptionalDtype());
|
||||
|
||||
Value expandSelf = rewriter.create<AtenExpandOp>(
|
||||
loc, expandTy, self, expandShapeList, constFalse);
|
||||
|
||||
Value result;
|
||||
if (dimIsNone) {
|
||||
Value constZero =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
||||
result = rewriter.create<AtenFlattenUsingIntsOp>(
|
||||
loc, resType, expandSelf, constZero, constMinusOne);
|
||||
} else {
|
||||
result = rewriter.create<PrimsCollapseOp>(loc, resType, expandSelf,
|
||||
dimValue, dimValuePlusOne);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Decompose aten.flatten.using_ints into aten.view op.
|
||||
namespace {
|
||||
class DecomposeAtenFlattenUsingIntsOp
|
||||
|
@ -7465,6 +7559,8 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenStackOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRollOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRepeatOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRepeatInterleaveSelfIntOp>(
|
||||
patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenExpandOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenFlattenUsingIntsOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenUnflattenIntOp>(patterns);
|
||||
|
|
|
@ -377,6 +377,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenStackOp>();
|
||||
target.addIllegalOp<AtenRollOp>();
|
||||
target.addIllegalOp<AtenRepeatOp>();
|
||||
target.addIllegalOp<AtenRepeatInterleaveSelfIntOp>();
|
||||
target.addIllegalOp<AtenExpandOp>();
|
||||
target.addIllegalOp<AtenFlattenUsingIntsOp>();
|
||||
target.addIllegalOp<AtenWhereScalarOp>();
|
||||
|
|
|
@ -588,6 +588,8 @@ STABLEHLO_PASS_SET = {
|
|||
"ChunkListUnpackUneven_Module_basic",
|
||||
"ChunkListUnpack_Module_basic",
|
||||
"CloneModule_basic",
|
||||
"CollapseAllDimensionsModule_basic",
|
||||
"CollapseStaticModule_basic",
|
||||
"ConstantBoolParameterModule_basic",
|
||||
"ContainsIntList_False",
|
||||
"ContainsIntList_True",
|
||||
|
@ -853,6 +855,8 @@ STABLEHLO_PASS_SET = {
|
|||
"ReduceSumFloatModule_basic",
|
||||
"ReduceSumSignedIntModule_basic",
|
||||
"ReduceSumUnsignedIntModule_basic",
|
||||
"RepeatInterleaveSelfIntModule_basic",
|
||||
"RepeatInterleaveSelfIntNoDimModule_basic",
|
||||
"ReturnThreeTensorFloat32_basic",
|
||||
"ReturnTwoTensorF32I64_basic",
|
||||
"RollModule_basic",
|
||||
|
@ -1390,6 +1394,7 @@ TOSA_PASS_SET = {
|
|||
"ReduceSumSignedIntModule_basic",
|
||||
"ReduceSumUnsignedIntModule_basic",
|
||||
"RepeatModule_basic",
|
||||
"RepeatInterleaveSelfIntNoDimModule_basic",
|
||||
"ResNet18StaticModule_basic",
|
||||
"ReshapeAliasCollapseModule_basic",
|
||||
"ReshapeAliasExpandModule_basic",
|
||||
|
@ -1512,6 +1517,7 @@ MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
|
|||
"TensorIntModule_basic",
|
||||
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
|
||||
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
|
||||
"RepeatInterleaveSelfIntModule_basic",
|
||||
"TorchPrimLoopForLikeTensorArgModule_basic",
|
||||
"ViewSizeDimFollowedByCollapsedOnesModule_basic",
|
||||
"ViewSizeDimFollowedByExpandedOnesModule_basic",
|
||||
|
@ -2352,6 +2358,12 @@ if torch_version_for_comparison() >= version.parse("2.4.0.dev"):
|
|||
"ReduceL1NormWithDTypeModule_basic",
|
||||
}
|
||||
|
||||
if torch_version_for_comparison() < version.parse('2.3.0.dev'):
|
||||
ONNX_XFAIL_SET = ONNX_XFAIL_SET | {
|
||||
# ERROR: shape (torch.Size([6, 4, 5])) is not equal to golden shape (torch.Size([120]))
|
||||
"RepeatInterleaveSelfIntNoDimModule_basic",
|
||||
}
|
||||
|
||||
|
||||
ONNX_CRASHING_SET = {
|
||||
"FakeQuantizePerTensorAffineModule_basic",
|
||||
|
|
|
@ -726,6 +726,15 @@ def aten〇repeat〡shape(self: List[int], repeats: List[int]) -> List[int]:
|
|||
out.append(self[i] * repeats[i + leading_rank])
|
||||
return out
|
||||
|
||||
def aten〇repeat_interleave〇self_int〡shape(self: List[int], repeats: int, dim: Optional[int] = None, output_size: Optional[int] = None) -> List[int]:
|
||||
if dim is None:
|
||||
flatten_size = upstream_shape_functions.flatten(self, 0, -1)[0]
|
||||
return [flatten_size * repeats]
|
||||
else:
|
||||
out = self[:dim] + [self[dim] * repeats] + self[dim + 1:]
|
||||
return out
|
||||
|
||||
|
||||
@check_shape_function([
|
||||
Invocation(TensorOfShape(3, 2, 8), [2, 2]), # dims_length < self_length
|
||||
Invocation(TensorOfShape(3, 2, 8), [2, 2, 2]) # dims_length >= self_length
|
||||
|
@ -2625,6 +2634,11 @@ def aten〇repeat〡dtype(self_rank_dtype: Tuple[int, int], repeats: List[int])
|
|||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, repeats=1))
|
||||
def aten〇repeat_interleave〇self_int〡dtype(self_rank_dtype: Tuple[int, int], repeats: int, dim: Optional[int] = None, output_size: Optional[int] = None) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[1]))
|
||||
def aten〇tile〡dtype(self_rank_dtype: Tuple[int, int], dims: List[int]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
|
|
|
@ -648,6 +648,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::numel : (Tensor) -> (int)", has_canonicalizer=True)
|
||||
emit("aten::repeat : (Tensor, int[]) -> (Tensor)")
|
||||
emit("aten::repeat_interleave.self_int : (Tensor, int, int?, int?) -> (Tensor)")
|
||||
emit("aten::tile : (Tensor, int[]) -> (Tensor)")
|
||||
emit("aten::reshape : (Tensor, int[]) -> (Tensor)")
|
||||
emit("aten::reshape_as : (Tensor, Tensor) -> (Tensor)")
|
||||
|
|
|
@ -1842,6 +1842,47 @@ def RepeatModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class RepeatInterleaveSelfIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([3, 4, 5], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return x.repeat_interleave(2, 1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: RepeatInterleaveSelfIntModule())
|
||||
def RepeatInterleaveSelfIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class RepeatInterleaveSelfIntNoDimModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([3, 4, 5], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return x.repeat_interleave(2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: RepeatInterleaveSelfIntNoDimModule())
|
||||
def RepeatInterleaveSelfIntNoDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class TileSmallDimsSizeModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
|
Loading…
Reference in New Issue