From 6382dbbcc00c35cc3deb8a75f0b181ae4701552a Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Mon, 3 Jun 2024 20:29:39 +0530 Subject: [PATCH] [ONNX] Add OnnxToTorch lowering for SpaceToDepth op (#3393) Signed-Off By: Vivek Khandelwal --- .../Conversion/TorchOnnxToTorch/Utils.h | 10 ++ .../torch-mlir/Dialect/Torch/Utils/Utils.h | 4 + .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 17 --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 98 ++++++++++++++++ lib/Conversion/TorchOnnxToTorch/Utils.cpp | 30 +++++ lib/Dialect/Torch/Utils/Utils.cpp | 18 +++ .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 110 +++++++++++++++++- 7 files changed, 269 insertions(+), 18 deletions(-) diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h index 97d004e36..d8d2534f9 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h @@ -96,6 +96,16 @@ m_OnnxListOfConstantInts(SmallVectorImpl &bind_values) { std::optional onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx); +LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter, + Location loc, Value input, int64_t dimA, + int64_t dimB, Value &transposed); + +LogicalResult createTorchPermuteOp(OpBinder binder, + ConversionPatternRewriter &rewriter, + Location loc, Value input, + SmallVector permuteDims, + Value &permuted); + } // namespace mlir::torch::onnx_c #endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h index 24db6f14f..62e6680f4 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h +++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h @@ -147,6 +147,10 @@ LogicalResult getTransposedType(BaseTensorType inType, int64_t dimA, // Torch flags, user options, etc). Type getDefaultAccType(PatternRewriter &rewriter, Type inputType); +LogicalResult getPermutedType(BaseTensorType inType, + SmallVector permuteDims, + Type &permutedType); + } // namespace Torch } // namespace torch } // namespace mlir diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index cb5affbbb..eb6bfbe76 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -18,23 +18,6 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::onnx_c; -static LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter, - Location loc, Value input, - int64_t dimA, int64_t dimB, - Value &transposed) { - Type transposedType; - if (failed(getTransposedType(cast(input.getType()), - dimA, dimB, transposedType))) - return failure(); - Value cstDimA = rewriter.create( - loc, rewriter.getI64IntegerAttr(dimA)); - Value cstDimB = rewriter.create( - loc, rewriter.getI64IntegerAttr(dimB)); - transposed = rewriter.create( - loc, transposedType, input, cstDimA, cstDimB); - return success(); -} - namespace { LogicalResult windowFunctionImpl(OpBinder binder, ConversionPatternRewriter &rewriter, diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index f0795d332..69e9ce6d9 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2952,4 +2952,102 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( /*Torch_BoolType:$antialias*/ cstFalse); return success(); }); + patterns.onOp( + "SpaceToDepth", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value input; + int64_t blockSize; + std::string mode; + if (binder.tensorOperand(input) || + binder.s64IntegerAttr(blockSize, "blocksize") || + binder.customOpNameStringAttr(mode, "mode", "DCR") || + binder.tensorResultType(resultType)) + return failure(); + auto inputTy = dyn_cast(input.getType()); + if (!inputTy || !inputTy.hasSizes()) { + return rewriter.notifyMatchFailure( + binder.op, "Expected input type having sizes"); + } + SmallVector inputSizes{inputTy.getSizes()}; + if (inputSizes.size() != 4) { + return rewriter.notifyMatchFailure(binder.op, + "Expected input rank to be 4"); + } + + Value b = rewriter.create( + binder.getLoc(), input, + rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0))); + Value c = rewriter.create( + binder.getLoc(), input, + rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1))); + Value h = rewriter.create( + binder.getLoc(), input, + rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(2))); + Value w = rewriter.create( + binder.getLoc(), input, + rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(3))); + Value cstBlockSize = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(blockSize)); + Value cstBlockSizeSquare = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(blockSize * blockSize)); + Value hDivBlockSize = rewriter.create( + binder.getLoc(), h, cstBlockSize); + Value wDivBlockSize = rewriter.create( + binder.getLoc(), w, cstBlockSize); + hDivBlockSize = rewriter.create(binder.getLoc(), + hDivBlockSize); + wDivBlockSize = rewriter.create(binder.getLoc(), + wDivBlockSize); + + // The implementation is as follows: + // tmp = np.reshape( + // x, [b, c, h // blocksize, blocksize, w // blocksize, blocksize] + // ) + // tmp = np.transpose(tmp, [0, 3, 5, 1, 2, 4]) + // y = np.reshape(tmp, [b, c * (blocksize**2), h // blocksize, w // + // blocksize]) + Value reshapeSizesList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(input.getContext())), + llvm::SmallVector{b, c, hDivBlockSize, cstBlockSize, + wDivBlockSize, cstBlockSize}); + int64_t hDivBlockSizeInt = inputSizes[2] == Torch::kUnknownSize + ? Torch::kUnknownSize + : inputSizes[2] / blockSize; + int64_t wDivBlockSizeInt = inputSizes[3] == Torch::kUnknownSize + ? Torch::kUnknownSize + : inputSizes[3] / blockSize; + SmallVector reshapeSizesInt{inputSizes[0], inputSizes[1], + hDivBlockSizeInt, blockSize, + wDivBlockSizeInt, blockSize}; + Value reshapedInput = rewriter.create( + binder.getLoc(), + inputTy.getWithSizesAndDtype(reshapeSizesInt, + inputTy.getOptionalDtype()), + input, reshapeSizesList); + + SmallVector permuteDimsInt{0, 3, 5, 1, 2, 4}; + Value permutedInput; + if (failed(createTorchPermuteOp(binder, rewriter, binder.getLoc(), + reshapedInput, permuteDimsInt, + permutedInput))) + return rewriter.notifyMatchFailure( + binder.op, "Failed to create Torch Permute op"); + + Value cMulBlockSizeSquare = rewriter.create( + binder.getLoc(), c, cstBlockSizeSquare); + reshapeSizesList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(input.getContext())), + llvm::SmallVector{b, cMulBlockSizeSquare, hDivBlockSize, + wDivBlockSize}); + rewriter.replaceOpWithNewOp( + binder.op, resultType, permutedInput, reshapeSizesList); + return success(); + }); } diff --git a/lib/Conversion/TorchOnnxToTorch/Utils.cpp b/lib/Conversion/TorchOnnxToTorch/Utils.cpp index dec134906..e7baf2e24 100644 --- a/lib/Conversion/TorchOnnxToTorch/Utils.cpp +++ b/lib/Conversion/TorchOnnxToTorch/Utils.cpp @@ -97,3 +97,33 @@ mlir::torch::onnx_c::onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) { return dtypeIntTorch; } + +LogicalResult mlir::torch::onnx_c::createTorchTransposeOp( + ConversionPatternRewriter &rewriter, Location loc, Value input, + int64_t dimA, int64_t dimB, Value &transposed) { + Type transposedType; + if (failed(getTransposedType(cast(input.getType()), + dimA, dimB, transposedType))) + return failure(); + Value cstDimA = rewriter.create( + loc, rewriter.getI64IntegerAttr(dimA)); + Value cstDimB = rewriter.create( + loc, rewriter.getI64IntegerAttr(dimB)); + transposed = rewriter.create( + loc, transposedType, input, cstDimA, cstDimB); + return success(); +} + +LogicalResult mlir::torch::onnx_c::createTorchPermuteOp( + OpBinder binder, ConversionPatternRewriter &rewriter, Location loc, + Value input, SmallVector permuteDims, Value &permuted) { + Type permutedType; + if (failed( + Torch::getPermutedType(cast(input.getType()), + permuteDims, permutedType))) + return failure(); + Value permuteDimsList = createConstantIntList(binder, rewriter, permuteDims); + permuted = rewriter.create(loc, permutedType, input, + permuteDimsList); + return success(); +} diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index 197f09c66..1c7e6f284 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -570,6 +570,24 @@ LogicalResult Torch::getTransposedType(BaseTensorType inType, int64_t dimA, return success(); } +LogicalResult Torch::getPermutedType(BaseTensorType inType, + SmallVector permuteDims, + Type &permutedType) { + if (!inType.hasSizes()) + return failure(); + + SmallVector shape(inType.getSizes()); + if (shape.size() != permuteDims.size()) + return failure(); + + SmallVector permutedShape; + for (unsigned i = 0; i < shape.size(); i++) + permutedShape.push_back(shape[permuteDims[i]]); + permutedType = inType.getWithSizesAndDtype(llvm::ArrayRef(permutedShape), + inType.getOptionalDtype()); + return success(); +} + Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) { if (inputType.isF16()) return rewriter.getF32Type(); diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index f520d961a..ed3dc10c9 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -220,7 +220,7 @@ func.func @test_scatter(%arg0: !torch.vtensor<[3,3],f32>, %arg1: !torch.vtensor< // CHECK: %[[INT0:.*]] = torch.constant.int 0 // CHECK: %[[RESULT:.*]] = torch.aten.scatter.src %arg0, %[[INT0]], %arg1, %arg2 : !torch.vtensor<[3,3],f32>, !torch.int, !torch.vtensor<[2,3],si64>, !torch.vtensor<[2,3],f32> -> !torch.vtensor<[3,3],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[3,3],f32> - %0 = torch.operator "onnx.Scatter"(%arg0, %arg1, %arg2) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,3],f32>, !torch.vtensor<[2,3],si64>, !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[3,3],f32> + %0 = torch.operator "onnx.Scatter"(%arg0, %arg1, %arg2) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,3],f32>, !torch.vtensor<[2,3],si64>, !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[3,3],f32> return %0 : !torch.vtensor<[3,3],f32> } @@ -2189,3 +2189,111 @@ f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_ve %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> return %0 : !torch.vtensor<[?,?,?,?],f32> } + +// ----- + +// CHECK-LABEL: @test_spacetodepth_example +func.func @test_spacetodepth_example(%arg0: !torch.vtensor<[1,1,4,6],f32>) -> !torch.vtensor<[1,4,2,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[B:.*]] = torch.aten.size.int %arg0, %[[C0]] : !torch.vtensor<[1,1,4,6],f32>, !torch.int -> !torch.int + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C:.*]] = torch.aten.size.int %arg0, %[[C1]] : !torch.vtensor<[1,1,4,6],f32>, !torch.int -> !torch.int + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[H:.*]] = torch.aten.size.int %arg0, %[[C2]] : !torch.vtensor<[1,1,4,6],f32>, !torch.int -> !torch.int + // CHECK: %[[C3:.*]] = torch.constant.int 3 + // CHECK: %[[W:.*]] = torch.aten.size.int %arg0, %[[C3]] : !torch.vtensor<[1,1,4,6],f32>, !torch.int -> !torch.int + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[C4:.*]] = torch.constant.int 4 + // CHECK: %[[H_DIV_BS:.*]] = torch.aten.div.int %[[H]], %[[C2_0]] : !torch.int, !torch.int -> !torch.float + // CHECK: %[[W_DIV_BS:.*]] = torch.aten.div.int %[[W]], %[[C2_0]] : !torch.int, !torch.int -> !torch.float + // CHECK: %[[H_DIV_BS_INT:.*]] = torch.aten.Int.float %[[H_DIV_BS]] : !torch.float -> !torch.int + // CHECK: %[[W_DIV_BS_INT:.*]] = torch.aten.Int.float %[[W_DIV_BS]] : !torch.float -> !torch.int + // CHECK: %[[RESHAPE_LIST:.*]] = torch.prim.ListConstruct %[[B]], %[[C]], %[[H_DIV_BS_INT]], %[[C2_0]], %[[W_DIV_BS_INT]], %[[C2_0]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[RESHAPE:.*]] = torch.aten.reshape %arg0, %[[RESHAPE_LIST]] : !torch.vtensor<[1,1,4,6],f32>, !torch.list -> !torch.vtensor<[1,1,2,2,3,2],f32> + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[C3_0:.*]] = torch.constant.int 3 + // CHECK: %[[C5:.*]] = torch.constant.int 5 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C2_1:.*]] = torch.constant.int 2 + // CHECK: %[[C4_0:.*]] = torch.constant.int 4 + // CHECK: %[[PERMUTE_DIMS:.*]] = torch.prim.ListConstruct %[[C0_0]], %[[C3_0]], %[[C5]], %[[C1_0]], %[[C2_1]], %[[C4_0]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[RESHAPE]], %[[PERMUTE_DIMS]] : !torch.vtensor<[1,1,2,2,3,2],f32>, !torch.list -> !torch.vtensor<[1,2,2,1,2,3],f32> + // CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[C]], %[[C4]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[RESHAPE_LIST_0:.*]] = torch.prim.ListConstruct %[[B]], %[[MUL]], %[[H_DIV_BS_INT]], %[[W_DIV_BS_INT]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[RESULT:.*]] = torch.aten.reshape %[[PERMUTE]], %[[RESHAPE_LIST_0]] : !torch.vtensor<[1,2,2,1,2,3],f32>, !torch.list -> !torch.vtensor<[1,4,2,3],f32> + // CHECK: return %[[RESULT]] : !torch.vtensor<[1,4,2,3],f32 + %0 = torch.operator "onnx.SpaceToDepth"(%arg0) {torch.onnx.blocksize = 2 : si64} : (!torch.vtensor<[1,1,4,6],f32>) -> !torch.vtensor<[1,4,2,3],f32> + return %0 : !torch.vtensor<[1,4,2,3],f32> +} + +// ----- + +// CHECK-LABEL: @test_spacetodepth +func.func @test_spacetodepth(%arg0: !torch.vtensor<[2,2,6,6],f32>) -> !torch.vtensor<[2,8,3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[B:.*]] = torch.aten.size.int %arg0, %[[C0]] : !torch.vtensor<[2,2,6,6],f32>, !torch.int -> !torch.int + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C:.*]] = torch.aten.size.int %arg0, %[[C1]] : !torch.vtensor<[2,2,6,6],f32>, !torch.int -> !torch.int + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[H:.*]] = torch.aten.size.int %arg0, %[[C2]] : !torch.vtensor<[2,2,6,6],f32>, !torch.int -> !torch.int + // CHECK: %[[C3:.*]] = torch.constant.int 3 + // CHECK: %[[W:.*]] = torch.aten.size.int %arg0, %[[C3]] : !torch.vtensor<[2,2,6,6],f32>, !torch.int -> !torch.int + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[C4:.*]] = torch.constant.int 4 + // CHECK: %[[H_DIV_BS:.*]] = torch.aten.div.int %[[H]], %[[C2_0]] : !torch.int, !torch.int -> !torch.float + // CHECK: %[[W_DIV_BS:.*]] = torch.aten.div.int %[[W]], %[[C2_0]] : !torch.int, !torch.int -> !torch.float + // CHECK: %[[H_DIV_BS_INT:.*]] = torch.aten.Int.float %[[H_DIV_BS]] : !torch.float -> !torch.int + // CHECK: %[[W_DIV_BS_INT:.*]] = torch.aten.Int.float %[[W_DIV_BS]] : !torch.float -> !torch.int + // CHECK: %[[RESHAPE_LIST:.*]] = torch.prim.ListConstruct %[[B]], %[[C]], %[[H_DIV_BS_INT]], %[[C2_0]], %[[W_DIV_BS_INT]], %[[C2_0]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[RESHAPE:.*]] = torch.aten.reshape %arg0, %[[RESHAPE_LIST]] : !torch.vtensor<[2,2,6,6],f32>, !torch.list -> !torch.vtensor<[2,2,3,2,3,2],f32> + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[C3_0:.*]] = torch.constant.int 3 + // CHECK: %[[C5:.*]] = torch.constant.int 5 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C2_1:.*]] = torch.constant.int 2 + // CHECK: %[[C4_0:.*]] = torch.constant.int 4 + // CHECK: %[[PERMUTE_DIMS:.*]] = torch.prim.ListConstruct %[[C0_0]], %[[C3_0]], %[[C5]], %[[C1_0]], %[[C2_1]], %[[C4_0]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[RESHAPE]], %[[PERMUTE_DIMS]] : !torch.vtensor<[2,2,3,2,3,2],f32>, !torch.list -> !torch.vtensor<[2,2,2,2,3,3],f32> + // CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[C]], %[[C4]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[RESHAPE_LIST_0:.*]] = torch.prim.ListConstruct %[[B]], %[[MUL]], %[[H_DIV_BS_INT]], %[[W_DIV_BS_INT]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[RESULT:.*]] = torch.aten.reshape %[[PERMUTE]], %[[RESHAPE_LIST_0]] : !torch.vtensor<[2,2,2,2,3,3],f32>, !torch.list -> !torch.vtensor<[2,8,3,3],f32> + // CHECK: return %[[RESULT]] : !torch.vtensor<[2,8,3,3],f32 + %0 = torch.operator "onnx.SpaceToDepth"(%arg0) {torch.onnx.blocksize = 2 : si64} : (!torch.vtensor<[2,2,6,6],f32>) -> !torch.vtensor<[2,8,3,3],f32> + return %0 : !torch.vtensor<[2,8,3,3],f32> +} + +// ----- + +// CHECK-LABEL: @test_spacetodepth +func.func @test_spacetodepth_dynamic_dims(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[B:.*]] = torch.aten.size.int %arg0, %[[C0]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C:.*]] = torch.aten.size.int %arg0, %[[C1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[H:.*]] = torch.aten.size.int %arg0, %[[C2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[C3:.*]] = torch.constant.int 3 + // CHECK: %[[W:.*]] = torch.aten.size.int %arg0, %[[C3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[C4:.*]] = torch.constant.int 4 + // CHECK: %[[H_DIV_BS:.*]] = torch.aten.div.int %[[H]], %[[C2_0]] : !torch.int, !torch.int -> !torch.float + // CHECK: %[[W_DIV_BS:.*]] = torch.aten.div.int %[[W]], %[[C2_0]] : !torch.int, !torch.int -> !torch.float + // CHECK: %[[H_DIV_BS_INT:.*]] = torch.aten.Int.float %[[H_DIV_BS]] : !torch.float -> !torch.int + // CHECK: %[[W_DIV_BS_INT:.*]] = torch.aten.Int.float %[[W_DIV_BS]] : !torch.float -> !torch.int + // CHECK: %[[RESHAPE_LIST:.*]] = torch.prim.ListConstruct %[[B]], %[[C]], %[[H_DIV_BS_INT]], %[[C2_0]], %[[W_DIV_BS_INT]], %[[C2_0]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[RESHAPE:.*]] = torch.aten.reshape %arg0, %[[RESHAPE_LIST]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list -> !torch.vtensor<[?,?,?,2,?,2],f32> + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[C3_0:.*]] = torch.constant.int 3 + // CHECK: %[[C5:.*]] = torch.constant.int 5 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C2_1:.*]] = torch.constant.int 2 + // CHECK: %[[C4_0:.*]] = torch.constant.int 4 + // CHECK: %[[PERMUTE_DIMS:.*]] = torch.prim.ListConstruct %[[C0_0]], %[[C3_0]], %[[C5]], %[[C1_0]], %[[C2_1]], %[[C4_0]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[RESHAPE]], %[[PERMUTE_DIMS]] : !torch.vtensor<[?,?,?,2,?,2],f32>, !torch.list -> !torch.vtensor<[?,2,2,?,?,?],f32> + // CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[C]], %[[C4]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[RESHAPE_LIST_0:.*]] = torch.prim.ListConstruct %[[B]], %[[MUL]], %[[H_DIV_BS_INT]], %[[W_DIV_BS_INT]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[RESULT:.*]] = torch.aten.reshape %[[PERMUTE]], %[[RESHAPE_LIST_0]] : !torch.vtensor<[?,2,2,?,?,?],f32>, !torch.list -> !torch.vtensor<[?,?,?,?],f32> + // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?,?],f32 + %0 = torch.operator "onnx.SpaceToDepth"(%arg0) {torch.onnx.blocksize = 2 : si64} : (!torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?,?],f32> +}