[ONNX] Add OnnxToTorch lowering for SpaceToDepth op (#3393)

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
Vivek Khandelwal 2024-06-03 20:29:39 +05:30 committed by GitHub
parent 285b087a5d
commit 6382dbbcc0
7 changed files with 269 additions and 18 deletions


@@ -96,6 +96,16 @@ m_OnnxListOfConstantInts(SmallVectorImpl<int64_t> &bind_values) {
std::optional<int64_t> onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx);
LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter,
Location loc, Value input, int64_t dimA,
int64_t dimB, Value &transposed);
LogicalResult createTorchPermuteOp(OpBinder binder,
ConversionPatternRewriter &rewriter,
Location loc, Value input,
SmallVector<int64_t> permuteDims,
Value &permuted);
} // namespace mlir::torch::onnx_c
#endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H


@@ -147,6 +147,10 @@ LogicalResult getTransposedType(BaseTensorType inType, int64_t dimA,
// Torch flags, user options, etc).
Type getDefaultAccType(PatternRewriter &rewriter, Type inputType);
LogicalResult getPermutedType(BaseTensorType inType,
SmallVector<int64_t> permuteDims,
Type &permutedType);
} // namespace Torch
} // namespace torch
} // namespace mlir


@@ -18,23 +18,6 @@ using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::onnx_c;
static LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter,
Location loc, Value input,
int64_t dimA, int64_t dimB,
Value &transposed) {
Type transposedType;
if (failed(getTransposedType(cast<Torch::BaseTensorType>(input.getType()),
dimA, dimB, transposedType)))
return failure();
Value cstDimA = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(dimA));
Value cstDimB = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(dimB));
transposed = rewriter.create<Torch::AtenTransposeIntOp>(
loc, transposedType, input, cstDimA, cstDimB);
return success();
}
namespace {
LogicalResult windowFunctionImpl(OpBinder binder,
                                 ConversionPatternRewriter &rewriter,


@@ -2952,4 +2952,102 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
/*Torch_BoolType:$antialias*/ cstFalse);
return success();
});
patterns.onOp(
"SpaceToDepth", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value input;
int64_t blockSize;
std::string mode;
if (binder.tensorOperand(input) ||
binder.s64IntegerAttr(blockSize, "blocksize") ||
binder.customOpNameStringAttr(mode, "mode", "DCR") ||
binder.tensorResultType(resultType))
return failure();
auto inputTy = dyn_cast<Torch::BaseTensorType>(input.getType());
if (!inputTy || !inputTy.hasSizes()) {
return rewriter.notifyMatchFailure(
binder.op, "Expected input type having sizes");
}
SmallVector<int64_t> inputSizes{inputTy.getSizes()};
if (inputSizes.size() != 4) {
return rewriter.notifyMatchFailure(binder.op,
"Expected input rank to be 4");
}
Value b = rewriter.create<Torch::AtenSizeIntOp>(
binder.getLoc(), input,
rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0)));
Value c = rewriter.create<Torch::AtenSizeIntOp>(
binder.getLoc(), input,
rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1)));
Value h = rewriter.create<Torch::AtenSizeIntOp>(
binder.getLoc(), input,
rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(2)));
Value w = rewriter.create<Torch::AtenSizeIntOp>(
binder.getLoc(), input,
rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(3)));
Value cstBlockSize = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(blockSize));
Value cstBlockSizeSquare = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(blockSize * blockSize));
Value hDivBlockSize = rewriter.create<Torch::AtenDivIntOp>(
binder.getLoc(), h, cstBlockSize);
Value wDivBlockSize = rewriter.create<Torch::AtenDivIntOp>(
binder.getLoc(), w, cstBlockSize);
hDivBlockSize = rewriter.create<Torch::AtenIntFloatOp>(binder.getLoc(),
hDivBlockSize);
wDivBlockSize = rewriter.create<Torch::AtenIntFloatOp>(binder.getLoc(),
wDivBlockSize);
// The implementation is as follows:
// tmp = np.reshape(
// x, [b, c, h // blocksize, blocksize, w // blocksize, blocksize]
// )
// tmp = np.transpose(tmp, [0, 3, 5, 1, 2, 4])
// y = np.reshape(tmp, [b, c * (blocksize**2), h // blocksize, w //
// blocksize])
Value reshapeSizesList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(input.getContext())),
llvm::SmallVector<Value>{b, c, hDivBlockSize, cstBlockSize,
wDivBlockSize, cstBlockSize});
int64_t hDivBlockSizeInt = inputSizes[2] == Torch::kUnknownSize
? Torch::kUnknownSize
: inputSizes[2] / blockSize;
int64_t wDivBlockSizeInt = inputSizes[3] == Torch::kUnknownSize
? Torch::kUnknownSize
: inputSizes[3] / blockSize;
SmallVector<int64_t, 6> reshapeSizesInt{inputSizes[0], inputSizes[1],
hDivBlockSizeInt, blockSize,
wDivBlockSizeInt, blockSize};
Value reshapedInput = rewriter.create<Torch::AtenReshapeOp>(
binder.getLoc(),
inputTy.getWithSizesAndDtype(reshapeSizesInt,
inputTy.getOptionalDtype()),
input, reshapeSizesList);
SmallVector<int64_t, 6> permuteDimsInt{0, 3, 5, 1, 2, 4};
Value permutedInput;
if (failed(createTorchPermuteOp(binder, rewriter, binder.getLoc(),
reshapedInput, permuteDimsInt,
permutedInput)))
return rewriter.notifyMatchFailure(
binder.op, "Failed to create Torch Permute op");
Value cMulBlockSizeSquare = rewriter.create<Torch::AtenMulIntOp>(
binder.getLoc(), c, cstBlockSizeSquare);
reshapeSizesList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(input.getContext())),
llvm::SmallVector<Value>{b, cMulBlockSizeSquare, hDivBlockSize,
wDivBlockSize});
rewriter.replaceOpWithNewOp<Torch::AtenReshapeOp>(
binder.op, resultType, permutedInput, reshapeSizesList);
return success();
});
}
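
For reference, the reshape -> permute -> reshape sequence emitted above follows the numpy recipe quoted in the pattern's comment. Below is a minimal numpy sketch of that recipe (not part of the patch; it assumes numpy is available) that checks the output shape against the first lit test added further down, where a [1,1,4,6] input with blocksize 2 yields [1,4,2,3]. The C++ pattern computes the same intermediate sizes dynamically via torch.aten.size.int, which is what lets the lowering also handle the dynamic-dim test case.

import numpy as np

def space_to_depth_reference(x: np.ndarray, blocksize: int) -> np.ndarray:
    # Mirrors the comment in the lowering:
    #   reshape to [b, c, h//bs, bs, w//bs, bs], transpose [0,3,5,1,2,4],
    #   then collapse the block dims into the channel dim.
    b, c, h, w = x.shape
    tmp = x.reshape(b, c, h // blocksize, blocksize, w // blocksize, blocksize)
    tmp = tmp.transpose(0, 3, 5, 1, 2, 4)
    return tmp.reshape(b, c * blocksize * blocksize, h // blocksize, w // blocksize)

# Shape check against the @test_spacetodepth_example lit test below.
assert space_to_depth_reference(np.zeros((1, 1, 4, 6), dtype=np.float32), 2).shape == (1, 4, 2, 3)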


@@ -97,3 +97,33 @@ mlir::torch::onnx_c::onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) {
return dtypeIntTorch;
}
LogicalResult mlir::torch::onnx_c::createTorchTransposeOp(
ConversionPatternRewriter &rewriter, Location loc, Value input,
int64_t dimA, int64_t dimB, Value &transposed) {
Type transposedType;
if (failed(getTransposedType(cast<Torch::BaseTensorType>(input.getType()),
dimA, dimB, transposedType)))
return failure();
Value cstDimA = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(dimA));
Value cstDimB = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(dimB));
transposed = rewriter.create<Torch::AtenTransposeIntOp>(
loc, transposedType, input, cstDimA, cstDimB);
return success();
}
LogicalResult mlir::torch::onnx_c::createTorchPermuteOp(
OpBinder binder, ConversionPatternRewriter &rewriter, Location loc,
Value input, SmallVector<int64_t> permuteDims, Value &permuted) {
Type permutedType;
if (failed(
Torch::getPermutedType(cast<Torch::BaseTensorType>(input.getType()),
permuteDims, permutedType)))
return failure();
Value permuteDimsList = createConstantIntList(binder, rewriter, permuteDims);
permuted = rewriter.create<Torch::AtenPermuteOp>(loc, permutedType, input,
permuteDimsList);
return success();
}


@@ -570,6 +570,24 @@ LogicalResult Torch::getTransposedType(BaseTensorType inType, int64_t dimA,
return success();
}
LogicalResult Torch::getPermutedType(BaseTensorType inType,
SmallVector<int64_t> permuteDims,
Type &permutedType) {
if (!inType.hasSizes())
return failure();
SmallVector<int64_t> shape(inType.getSizes());
if (shape.size() != permuteDims.size())
return failure();
SmallVector<int64_t> permutedShape;
for (unsigned i = 0; i < shape.size(); i++)
permutedShape.push_back(shape[permuteDims[i]]);
permutedType = inType.getWithSizesAndDtype(llvm::ArrayRef(permutedShape),
inType.getOptionalDtype());
return success();
}
Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) {
if (inputType.isF16())
return rewriter.getF32Type();
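
As a worked example of the static-shape rule in Torch::getPermutedType (result dim i takes the size of input dim permuteDims[i]), the standalone Python sketch below (not part of the patch) reproduces the shapes from the first lit test, where [1,1,2,2,3,2] permuted by [0,3,5,1,2,4] becomes [1,2,2,1,2,3].

def permuted_shape(shape, permute_dims):
    # Same rule as getPermutedType: permutedShape[i] = shape[permuteDims[i]].
    assert len(shape) == len(permute_dims)
    return [shape[d] for d in permute_dims]

assert permuted_shape([1, 1, 2, 2, 3, 2], [0, 3, 5, 1, 2, 4]) == [1, 2, 2, 1, 2, 3]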


@@ -2189,3 +2189,111 @@ f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_ve
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?,?],f32>
}
// -----
// CHECK-LABEL: @test_spacetodepth_example
func.func @test_spacetodepth_example(%arg0: !torch.vtensor<[1,1,4,6],f32>) -> !torch.vtensor<[1,4,2,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[C0:.*]] = torch.constant.int 0
// CHECK: %[[B:.*]] = torch.aten.size.int %arg0, %[[C0]] : !torch.vtensor<[1,1,4,6],f32>, !torch.int -> !torch.int
// CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: %[[C:.*]] = torch.aten.size.int %arg0, %[[C1]] : !torch.vtensor<[1,1,4,6],f32>, !torch.int -> !torch.int
// CHECK: %[[C2:.*]] = torch.constant.int 2
// CHECK: %[[H:.*]] = torch.aten.size.int %arg0, %[[C2]] : !torch.vtensor<[1,1,4,6],f32>, !torch.int -> !torch.int
// CHECK: %[[C3:.*]] = torch.constant.int 3
// CHECK: %[[W:.*]] = torch.aten.size.int %arg0, %[[C3]] : !torch.vtensor<[1,1,4,6],f32>, !torch.int -> !torch.int
// CHECK: %[[C2_0:.*]] = torch.constant.int 2
// CHECK: %[[C4:.*]] = torch.constant.int 4
// CHECK: %[[H_DIV_BS:.*]] = torch.aten.div.int %[[H]], %[[C2_0]] : !torch.int, !torch.int -> !torch.float
// CHECK: %[[W_DIV_BS:.*]] = torch.aten.div.int %[[W]], %[[C2_0]] : !torch.int, !torch.int -> !torch.float
// CHECK: %[[H_DIV_BS_INT:.*]] = torch.aten.Int.float %[[H_DIV_BS]] : !torch.float -> !torch.int
// CHECK: %[[W_DIV_BS_INT:.*]] = torch.aten.Int.float %[[W_DIV_BS]] : !torch.float -> !torch.int
// CHECK: %[[RESHAPE_LIST:.*]] = torch.prim.ListConstruct %[[B]], %[[C]], %[[H_DIV_BS_INT]], %[[C2_0]], %[[W_DIV_BS_INT]], %[[C2_0]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[RESHAPE:.*]] = torch.aten.reshape %arg0, %[[RESHAPE_LIST]] : !torch.vtensor<[1,1,4,6],f32>, !torch.list<int> -> !torch.vtensor<[1,1,2,2,3,2],f32>
// CHECK: %[[C0_0:.*]] = torch.constant.int 0
// CHECK: %[[C3_0:.*]] = torch.constant.int 3
// CHECK: %[[C5:.*]] = torch.constant.int 5
// CHECK: %[[C1_0:.*]] = torch.constant.int 1
// CHECK: %[[C2_1:.*]] = torch.constant.int 2
// CHECK: %[[C4_0:.*]] = torch.constant.int 4
// CHECK: %[[PERMUTE_DIMS:.*]] = torch.prim.ListConstruct %[[C0_0]], %[[C3_0]], %[[C5]], %[[C1_0]], %[[C2_1]], %[[C4_0]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[RESHAPE]], %[[PERMUTE_DIMS]] : !torch.vtensor<[1,1,2,2,3,2],f32>, !torch.list<int> -> !torch.vtensor<[1,2,2,1,2,3],f32>
// CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[C]], %[[C4]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[RESHAPE_LIST_0:.*]] = torch.prim.ListConstruct %[[B]], %[[MUL]], %[[H_DIV_BS_INT]], %[[W_DIV_BS_INT]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[RESULT:.*]] = torch.aten.reshape %[[PERMUTE]], %[[RESHAPE_LIST_0]] : !torch.vtensor<[1,2,2,1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[1,4,2,3],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1,4,2,3],f32>
%0 = torch.operator "onnx.SpaceToDepth"(%arg0) {torch.onnx.blocksize = 2 : si64} : (!torch.vtensor<[1,1,4,6],f32>) -> !torch.vtensor<[1,4,2,3],f32>
return %0 : !torch.vtensor<[1,4,2,3],f32>
}
// -----
// CHECK-LABEL: @test_spacetodepth
func.func @test_spacetodepth(%arg0: !torch.vtensor<[2,2,6,6],f32>) -> !torch.vtensor<[2,8,3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[C0:.*]] = torch.constant.int 0
// CHECK: %[[B:.*]] = torch.aten.size.int %arg0, %[[C0]] : !torch.vtensor<[2,2,6,6],f32>, !torch.int -> !torch.int
// CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: %[[C:.*]] = torch.aten.size.int %arg0, %[[C1]] : !torch.vtensor<[2,2,6,6],f32>, !torch.int -> !torch.int
// CHECK: %[[C2:.*]] = torch.constant.int 2
// CHECK: %[[H:.*]] = torch.aten.size.int %arg0, %[[C2]] : !torch.vtensor<[2,2,6,6],f32>, !torch.int -> !torch.int
// CHECK: %[[C3:.*]] = torch.constant.int 3
// CHECK: %[[W:.*]] = torch.aten.size.int %arg0, %[[C3]] : !torch.vtensor<[2,2,6,6],f32>, !torch.int -> !torch.int
// CHECK: %[[C2_0:.*]] = torch.constant.int 2
// CHECK: %[[C4:.*]] = torch.constant.int 4
// CHECK: %[[H_DIV_BS:.*]] = torch.aten.div.int %[[H]], %[[C2_0]] : !torch.int, !torch.int -> !torch.float
// CHECK: %[[W_DIV_BS:.*]] = torch.aten.div.int %[[W]], %[[C2_0]] : !torch.int, !torch.int -> !torch.float
// CHECK: %[[H_DIV_BS_INT:.*]] = torch.aten.Int.float %[[H_DIV_BS]] : !torch.float -> !torch.int
// CHECK: %[[W_DIV_BS_INT:.*]] = torch.aten.Int.float %[[W_DIV_BS]] : !torch.float -> !torch.int
// CHECK: %[[RESHAPE_LIST:.*]] = torch.prim.ListConstruct %[[B]], %[[C]], %[[H_DIV_BS_INT]], %[[C2_0]], %[[W_DIV_BS_INT]], %[[C2_0]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[RESHAPE:.*]] = torch.aten.reshape %arg0, %[[RESHAPE_LIST]] : !torch.vtensor<[2,2,6,6],f32>, !torch.list<int> -> !torch.vtensor<[2,2,3,2,3,2],f32>
// CHECK: %[[C0_0:.*]] = torch.constant.int 0
// CHECK: %[[C3_0:.*]] = torch.constant.int 3
// CHECK: %[[C5:.*]] = torch.constant.int 5
// CHECK: %[[C1_0:.*]] = torch.constant.int 1
// CHECK: %[[C2_1:.*]] = torch.constant.int 2
// CHECK: %[[C4_0:.*]] = torch.constant.int 4
// CHECK: %[[PERMUTE_DIMS:.*]] = torch.prim.ListConstruct %[[C0_0]], %[[C3_0]], %[[C5]], %[[C1_0]], %[[C2_1]], %[[C4_0]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[RESHAPE]], %[[PERMUTE_DIMS]] : !torch.vtensor<[2,2,3,2,3,2],f32>, !torch.list<int> -> !torch.vtensor<[2,2,2,2,3,3],f32>
// CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[C]], %[[C4]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[RESHAPE_LIST_0:.*]] = torch.prim.ListConstruct %[[B]], %[[MUL]], %[[H_DIV_BS_INT]], %[[W_DIV_BS_INT]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[RESULT:.*]] = torch.aten.reshape %[[PERMUTE]], %[[RESHAPE_LIST_0]] : !torch.vtensor<[2,2,2,2,3,3],f32>, !torch.list<int> -> !torch.vtensor<[2,8,3,3],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[2,8,3,3],f32>
%0 = torch.operator "onnx.SpaceToDepth"(%arg0) {torch.onnx.blocksize = 2 : si64} : (!torch.vtensor<[2,2,6,6],f32>) -> !torch.vtensor<[2,8,3,3],f32>
return %0 : !torch.vtensor<[2,8,3,3],f32>
}
// -----
// CHECK-LABEL: @test_spacetodepth_dynamic_dims
func.func @test_spacetodepth_dynamic_dims(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[C0:.*]] = torch.constant.int 0
// CHECK: %[[B:.*]] = torch.aten.size.int %arg0, %[[C0]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: %[[C:.*]] = torch.aten.size.int %arg0, %[[C1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[C2:.*]] = torch.constant.int 2
// CHECK: %[[H:.*]] = torch.aten.size.int %arg0, %[[C2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[C3:.*]] = torch.constant.int 3
// CHECK: %[[W:.*]] = torch.aten.size.int %arg0, %[[C3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[C2_0:.*]] = torch.constant.int 2
// CHECK: %[[C4:.*]] = torch.constant.int 4
// CHECK: %[[H_DIV_BS:.*]] = torch.aten.div.int %[[H]], %[[C2_0]] : !torch.int, !torch.int -> !torch.float
// CHECK: %[[W_DIV_BS:.*]] = torch.aten.div.int %[[W]], %[[C2_0]] : !torch.int, !torch.int -> !torch.float
// CHECK: %[[H_DIV_BS_INT:.*]] = torch.aten.Int.float %[[H_DIV_BS]] : !torch.float -> !torch.int
// CHECK: %[[W_DIV_BS_INT:.*]] = torch.aten.Int.float %[[W_DIV_BS]] : !torch.float -> !torch.int
// CHECK: %[[RESHAPE_LIST:.*]] = torch.prim.ListConstruct %[[B]], %[[C]], %[[H_DIV_BS_INT]], %[[C2_0]], %[[W_DIV_BS_INT]], %[[C2_0]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[RESHAPE:.*]] = torch.aten.reshape %arg0, %[[RESHAPE_LIST]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?,2,?,2],f32>
// CHECK: %[[C0_0:.*]] = torch.constant.int 0
// CHECK: %[[C3_0:.*]] = torch.constant.int 3
// CHECK: %[[C5:.*]] = torch.constant.int 5
// CHECK: %[[C1_0:.*]] = torch.constant.int 1
// CHECK: %[[C2_1:.*]] = torch.constant.int 2
// CHECK: %[[C4_0:.*]] = torch.constant.int 4
// CHECK: %[[PERMUTE_DIMS:.*]] = torch.prim.ListConstruct %[[C0_0]], %[[C3_0]], %[[C5]], %[[C1_0]], %[[C2_1]], %[[C4_0]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[RESHAPE]], %[[PERMUTE_DIMS]] : !torch.vtensor<[?,?,?,2,?,2],f32>, !torch.list<int> -> !torch.vtensor<[?,2,2,?,?,?],f32>
// CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[C]], %[[C4]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[RESHAPE_LIST_0:.*]] = torch.prim.ListConstruct %[[B]], %[[MUL]], %[[H_DIV_BS_INT]], %[[W_DIV_BS_INT]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[RESULT:.*]] = torch.aten.reshape %[[PERMUTE]], %[[RESHAPE_LIST_0]] : !torch.vtensor<[?,2,2,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?,?],f32>
%0 = torch.operator "onnx.SpaceToDepth"(%arg0) {torch.onnx.blocksize = 2 : si64} : (!torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?,?],f32>
}