mirror of https://github.com/llvm/torch-mlir
[ONNX] Add OnnxToTorch lowering for SpaceToDepth op (#3393)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>pull/3412/head
parent
285b087a5d
commit
6382dbbcc0
|
@ -96,6 +96,16 @@ m_OnnxListOfConstantInts(SmallVectorImpl<int64_t> &bind_values) {
|
||||||
|
|
||||||
std::optional<int64_t> onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx);
|
std::optional<int64_t> onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx);
|
||||||
|
|
||||||
|
LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter,
|
||||||
|
Location loc, Value input, int64_t dimA,
|
||||||
|
int64_t dimB, Value &transposed);
|
||||||
|
|
||||||
|
LogicalResult createTorchPermuteOp(OpBinder binder,
|
||||||
|
ConversionPatternRewriter &rewriter,
|
||||||
|
Location loc, Value input,
|
||||||
|
SmallVector<int64_t> permuteDims,
|
||||||
|
Value &permuted);
|
||||||
|
|
||||||
} // namespace mlir::torch::onnx_c
|
} // namespace mlir::torch::onnx_c
|
||||||
|
|
||||||
#endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H
|
#endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H
|
||||||
|
|
|
@ -147,6 +147,10 @@ LogicalResult getTransposedType(BaseTensorType inType, int64_t dimA,
|
||||||
// Torch flags, user options, etc).
|
// Torch flags, user options, etc).
|
||||||
Type getDefaultAccType(PatternRewriter &rewriter, Type inputType);
|
Type getDefaultAccType(PatternRewriter &rewriter, Type inputType);
|
||||||
|
|
||||||
|
LogicalResult getPermutedType(BaseTensorType inType,
|
||||||
|
SmallVector<int64_t> permuteDims,
|
||||||
|
Type &permutedType);
|
||||||
|
|
||||||
} // namespace Torch
|
} // namespace Torch
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -18,23 +18,6 @@ using namespace mlir;
|
||||||
using namespace mlir::torch;
|
using namespace mlir::torch;
|
||||||
using namespace mlir::torch::onnx_c;
|
using namespace mlir::torch::onnx_c;
|
||||||
|
|
||||||
static LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter,
|
|
||||||
Location loc, Value input,
|
|
||||||
int64_t dimA, int64_t dimB,
|
|
||||||
Value &transposed) {
|
|
||||||
Type transposedType;
|
|
||||||
if (failed(getTransposedType(cast<Torch::BaseTensorType>(input.getType()),
|
|
||||||
dimA, dimB, transposedType)))
|
|
||||||
return failure();
|
|
||||||
Value cstDimA = rewriter.create<Torch::ConstantIntOp>(
|
|
||||||
loc, rewriter.getI64IntegerAttr(dimA));
|
|
||||||
Value cstDimB = rewriter.create<Torch::ConstantIntOp>(
|
|
||||||
loc, rewriter.getI64IntegerAttr(dimB));
|
|
||||||
transposed = rewriter.create<Torch::AtenTransposeIntOp>(
|
|
||||||
loc, transposedType, input, cstDimA, cstDimB);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
LogicalResult windowFunctionImpl(OpBinder binder,
|
LogicalResult windowFunctionImpl(OpBinder binder,
|
||||||
ConversionPatternRewriter &rewriter,
|
ConversionPatternRewriter &rewriter,
|
||||||
|
|
|
@ -2952,4 +2952,102 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
/*Torch_BoolType:$antialias*/ cstFalse);
|
/*Torch_BoolType:$antialias*/ cstFalse);
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
|
patterns.onOp(
|
||||||
|
"SpaceToDepth", 1,
|
||||||
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
Torch::ValueTensorType resultType;
|
||||||
|
Value input;
|
||||||
|
int64_t blockSize;
|
||||||
|
std::string mode;
|
||||||
|
if (binder.tensorOperand(input) ||
|
||||||
|
binder.s64IntegerAttr(blockSize, "blocksize") ||
|
||||||
|
binder.customOpNameStringAttr(mode, "mode", "DCR") ||
|
||||||
|
binder.tensorResultType(resultType))
|
||||||
|
return failure();
|
||||||
|
auto inputTy = dyn_cast<Torch::BaseTensorType>(input.getType());
|
||||||
|
if (!inputTy || !inputTy.hasSizes()) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
binder.op, "Expected input type having sizes");
|
||||||
|
}
|
||||||
|
SmallVector<int64_t> inputSizes{inputTy.getSizes()};
|
||||||
|
if (inputSizes.size() != 4) {
|
||||||
|
return rewriter.notifyMatchFailure(binder.op,
|
||||||
|
"Expected input rank to be 4");
|
||||||
|
}
|
||||||
|
|
||||||
|
Value b = rewriter.create<Torch::AtenSizeIntOp>(
|
||||||
|
binder.getLoc(), input,
|
||||||
|
rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getI64IntegerAttr(0)));
|
||||||
|
Value c = rewriter.create<Torch::AtenSizeIntOp>(
|
||||||
|
binder.getLoc(), input,
|
||||||
|
rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getI64IntegerAttr(1)));
|
||||||
|
Value h = rewriter.create<Torch::AtenSizeIntOp>(
|
||||||
|
binder.getLoc(), input,
|
||||||
|
rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getI64IntegerAttr(2)));
|
||||||
|
Value w = rewriter.create<Torch::AtenSizeIntOp>(
|
||||||
|
binder.getLoc(), input,
|
||||||
|
rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getI64IntegerAttr(3)));
|
||||||
|
Value cstBlockSize = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getI64IntegerAttr(blockSize));
|
||||||
|
Value cstBlockSizeSquare = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getI64IntegerAttr(blockSize * blockSize));
|
||||||
|
Value hDivBlockSize = rewriter.create<Torch::AtenDivIntOp>(
|
||||||
|
binder.getLoc(), h, cstBlockSize);
|
||||||
|
Value wDivBlockSize = rewriter.create<Torch::AtenDivIntOp>(
|
||||||
|
binder.getLoc(), w, cstBlockSize);
|
||||||
|
hDivBlockSize = rewriter.create<Torch::AtenIntFloatOp>(binder.getLoc(),
|
||||||
|
hDivBlockSize);
|
||||||
|
wDivBlockSize = rewriter.create<Torch::AtenIntFloatOp>(binder.getLoc(),
|
||||||
|
wDivBlockSize);
|
||||||
|
|
||||||
|
// The implementation is as follows:
|
||||||
|
// tmp = np.reshape(
|
||||||
|
// x, [b, c, h // blocksize, blocksize, w // blocksize, blocksize]
|
||||||
|
// )
|
||||||
|
// tmp = np.transpose(tmp, [0, 3, 5, 1, 2, 4])
|
||||||
|
// y = np.reshape(tmp, [b, c * (blocksize**2), h // blocksize, w //
|
||||||
|
// blocksize])
|
||||||
|
Value reshapeSizesList = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
|
binder.getLoc(),
|
||||||
|
Torch::ListType::get(Torch::IntType::get(input.getContext())),
|
||||||
|
llvm::SmallVector<Value>{b, c, hDivBlockSize, cstBlockSize,
|
||||||
|
wDivBlockSize, cstBlockSize});
|
||||||
|
int64_t hDivBlockSizeInt = inputSizes[2] == Torch::kUnknownSize
|
||||||
|
? Torch::kUnknownSize
|
||||||
|
: inputSizes[2] / blockSize;
|
||||||
|
int64_t wDivBlockSizeInt = inputSizes[3] == Torch::kUnknownSize
|
||||||
|
? Torch::kUnknownSize
|
||||||
|
: inputSizes[3] / blockSize;
|
||||||
|
SmallVector<int64_t, 6> reshapeSizesInt{inputSizes[0], inputSizes[1],
|
||||||
|
hDivBlockSizeInt, blockSize,
|
||||||
|
wDivBlockSizeInt, blockSize};
|
||||||
|
Value reshapedInput = rewriter.create<Torch::AtenReshapeOp>(
|
||||||
|
binder.getLoc(),
|
||||||
|
inputTy.getWithSizesAndDtype(reshapeSizesInt,
|
||||||
|
inputTy.getOptionalDtype()),
|
||||||
|
input, reshapeSizesList);
|
||||||
|
|
||||||
|
SmallVector<int64_t, 6> permuteDimsInt{0, 3, 5, 1, 2, 4};
|
||||||
|
Value permutedInput;
|
||||||
|
if (failed(createTorchPermuteOp(binder, rewriter, binder.getLoc(),
|
||||||
|
reshapedInput, permuteDimsInt,
|
||||||
|
permutedInput)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
binder.op, "Failed to create Torch Permute op");
|
||||||
|
|
||||||
|
Value cMulBlockSizeSquare = rewriter.create<Torch::AtenMulIntOp>(
|
||||||
|
binder.getLoc(), c, cstBlockSizeSquare);
|
||||||
|
reshapeSizesList = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
|
binder.getLoc(),
|
||||||
|
Torch::ListType::get(Torch::IntType::get(input.getContext())),
|
||||||
|
llvm::SmallVector<Value>{b, cMulBlockSizeSquare, hDivBlockSize,
|
||||||
|
wDivBlockSize});
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::AtenReshapeOp>(
|
||||||
|
binder.op, resultType, permutedInput, reshapeSizesList);
|
||||||
|
return success();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
|
@ -97,3 +97,33 @@ mlir::torch::onnx_c::onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) {
|
||||||
|
|
||||||
return dtypeIntTorch;
|
return dtypeIntTorch;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult mlir::torch::onnx_c::createTorchTransposeOp(
|
||||||
|
ConversionPatternRewriter &rewriter, Location loc, Value input,
|
||||||
|
int64_t dimA, int64_t dimB, Value &transposed) {
|
||||||
|
Type transposedType;
|
||||||
|
if (failed(getTransposedType(cast<Torch::BaseTensorType>(input.getType()),
|
||||||
|
dimA, dimB, transposedType)))
|
||||||
|
return failure();
|
||||||
|
Value cstDimA = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(dimA));
|
||||||
|
Value cstDimB = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(dimB));
|
||||||
|
transposed = rewriter.create<Torch::AtenTransposeIntOp>(
|
||||||
|
loc, transposedType, input, cstDimA, cstDimB);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult mlir::torch::onnx_c::createTorchPermuteOp(
|
||||||
|
OpBinder binder, ConversionPatternRewriter &rewriter, Location loc,
|
||||||
|
Value input, SmallVector<int64_t> permuteDims, Value &permuted) {
|
||||||
|
Type permutedType;
|
||||||
|
if (failed(
|
||||||
|
Torch::getPermutedType(cast<Torch::BaseTensorType>(input.getType()),
|
||||||
|
permuteDims, permutedType)))
|
||||||
|
return failure();
|
||||||
|
Value permuteDimsList = createConstantIntList(binder, rewriter, permuteDims);
|
||||||
|
permuted = rewriter.create<Torch::AtenPermuteOp>(loc, permutedType, input,
|
||||||
|
permuteDimsList);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
|
@ -570,6 +570,24 @@ LogicalResult Torch::getTransposedType(BaseTensorType inType, int64_t dimA,
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult Torch::getPermutedType(BaseTensorType inType,
|
||||||
|
SmallVector<int64_t> permuteDims,
|
||||||
|
Type &permutedType) {
|
||||||
|
if (!inType.hasSizes())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<int64_t> shape(inType.getSizes());
|
||||||
|
if (shape.size() != permuteDims.size())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<int64_t> permutedShape;
|
||||||
|
for (unsigned i = 0; i < shape.size(); i++)
|
||||||
|
permutedShape.push_back(shape[permuteDims[i]]);
|
||||||
|
permutedType = inType.getWithSizesAndDtype(llvm::ArrayRef(permutedShape),
|
||||||
|
inType.getOptionalDtype());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) {
|
Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) {
|
||||||
if (inputType.isF16())
|
if (inputType.isF16())
|
||||||
return rewriter.getF32Type();
|
return rewriter.getF32Type();
|
||||||
|
|
|
@ -2189,3 +2189,111 @@ f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_ve
|
||||||
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
|
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
|
||||||
return %0 : !torch.vtensor<[?,?,?,?],f32>
|
return %0 : !torch.vtensor<[?,?,?,?],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_spacetodepth_example
|
||||||
|
func.func @test_spacetodepth_example(%arg0: !torch.vtensor<[1,1,4,6],f32>) -> !torch.vtensor<[1,4,2,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[C0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[B:.*]] = torch.aten.size.int %arg0, %[[C0]] : !torch.vtensor<[1,1,4,6],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[C1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[C:.*]] = torch.aten.size.int %arg0, %[[C1]] : !torch.vtensor<[1,1,4,6],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[C2:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[H:.*]] = torch.aten.size.int %arg0, %[[C2]] : !torch.vtensor<[1,1,4,6],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[C3:.*]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[W:.*]] = torch.aten.size.int %arg0, %[[C3]] : !torch.vtensor<[1,1,4,6],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[C2_0:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[C4:.*]] = torch.constant.int 4
|
||||||
|
// CHECK: %[[H_DIV_BS:.*]] = torch.aten.div.int %[[H]], %[[C2_0]] : !torch.int, !torch.int -> !torch.float
|
||||||
|
// CHECK: %[[W_DIV_BS:.*]] = torch.aten.div.int %[[W]], %[[C2_0]] : !torch.int, !torch.int -> !torch.float
|
||||||
|
// CHECK: %[[H_DIV_BS_INT:.*]] = torch.aten.Int.float %[[H_DIV_BS]] : !torch.float -> !torch.int
|
||||||
|
// CHECK: %[[W_DIV_BS_INT:.*]] = torch.aten.Int.float %[[W_DIV_BS]] : !torch.float -> !torch.int
|
||||||
|
// CHECK: %[[RESHAPE_LIST:.*]] = torch.prim.ListConstruct %[[B]], %[[C]], %[[H_DIV_BS_INT]], %[[C2_0]], %[[W_DIV_BS_INT]], %[[C2_0]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[RESHAPE:.*]] = torch.aten.reshape %arg0, %[[RESHAPE_LIST]] : !torch.vtensor<[1,1,4,6],f32>, !torch.list<int> -> !torch.vtensor<[1,1,2,2,3,2],f32>
|
||||||
|
// CHECK: %[[C0_0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[C3_0:.*]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[C5:.*]] = torch.constant.int 5
|
||||||
|
// CHECK: %[[C1_0:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[C2_1:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[C4_0:.*]] = torch.constant.int 4
|
||||||
|
// CHECK: %[[PERMUTE_DIMS:.*]] = torch.prim.ListConstruct %[[C0_0]], %[[C3_0]], %[[C5]], %[[C1_0]], %[[C2_1]], %[[C4_0]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[RESHAPE]], %[[PERMUTE_DIMS]] : !torch.vtensor<[1,1,2,2,3,2],f32>, !torch.list<int> -> !torch.vtensor<[1,2,2,1,2,3],f32>
|
||||||
|
// CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[C]], %[[C4]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[RESHAPE_LIST_0:.*]] = torch.prim.ListConstruct %[[B]], %[[MUL]], %[[H_DIV_BS_INT]], %[[W_DIV_BS_INT]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[RESULT:.*]] = torch.aten.reshape %[[PERMUTE]], %[[RESHAPE_LIST_0]] : !torch.vtensor<[1,2,2,1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[1,4,2,3],f32>
|
||||||
|
// CHECK: return %[[RESULT]] : !torch.vtensor<[1,4,2,3],f32
|
||||||
|
%0 = torch.operator "onnx.SpaceToDepth"(%arg0) {torch.onnx.blocksize = 2 : si64} : (!torch.vtensor<[1,1,4,6],f32>) -> !torch.vtensor<[1,4,2,3],f32>
|
||||||
|
return %0 : !torch.vtensor<[1,4,2,3],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_spacetodepth
|
||||||
|
func.func @test_spacetodepth(%arg0: !torch.vtensor<[2,2,6,6],f32>) -> !torch.vtensor<[2,8,3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[C0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[B:.*]] = torch.aten.size.int %arg0, %[[C0]] : !torch.vtensor<[2,2,6,6],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[C1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[C:.*]] = torch.aten.size.int %arg0, %[[C1]] : !torch.vtensor<[2,2,6,6],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[C2:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[H:.*]] = torch.aten.size.int %arg0, %[[C2]] : !torch.vtensor<[2,2,6,6],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[C3:.*]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[W:.*]] = torch.aten.size.int %arg0, %[[C3]] : !torch.vtensor<[2,2,6,6],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[C2_0:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[C4:.*]] = torch.constant.int 4
|
||||||
|
// CHECK: %[[H_DIV_BS:.*]] = torch.aten.div.int %[[H]], %[[C2_0]] : !torch.int, !torch.int -> !torch.float
|
||||||
|
// CHECK: %[[W_DIV_BS:.*]] = torch.aten.div.int %[[W]], %[[C2_0]] : !torch.int, !torch.int -> !torch.float
|
||||||
|
// CHECK: %[[H_DIV_BS_INT:.*]] = torch.aten.Int.float %[[H_DIV_BS]] : !torch.float -> !torch.int
|
||||||
|
// CHECK: %[[W_DIV_BS_INT:.*]] = torch.aten.Int.float %[[W_DIV_BS]] : !torch.float -> !torch.int
|
||||||
|
// CHECK: %[[RESHAPE_LIST:.*]] = torch.prim.ListConstruct %[[B]], %[[C]], %[[H_DIV_BS_INT]], %[[C2_0]], %[[W_DIV_BS_INT]], %[[C2_0]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[RESHAPE:.*]] = torch.aten.reshape %arg0, %[[RESHAPE_LIST]] : !torch.vtensor<[2,2,6,6],f32>, !torch.list<int> -> !torch.vtensor<[2,2,3,2,3,2],f32>
|
||||||
|
// CHECK: %[[C0_0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[C3_0:.*]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[C5:.*]] = torch.constant.int 5
|
||||||
|
// CHECK: %[[C1_0:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[C2_1:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[C4_0:.*]] = torch.constant.int 4
|
||||||
|
// CHECK: %[[PERMUTE_DIMS:.*]] = torch.prim.ListConstruct %[[C0_0]], %[[C3_0]], %[[C5]], %[[C1_0]], %[[C2_1]], %[[C4_0]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[RESHAPE]], %[[PERMUTE_DIMS]] : !torch.vtensor<[2,2,3,2,3,2],f32>, !torch.list<int> -> !torch.vtensor<[2,2,2,2,3,3],f32>
|
||||||
|
// CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[C]], %[[C4]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[RESHAPE_LIST_0:.*]] = torch.prim.ListConstruct %[[B]], %[[MUL]], %[[H_DIV_BS_INT]], %[[W_DIV_BS_INT]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[RESULT:.*]] = torch.aten.reshape %[[PERMUTE]], %[[RESHAPE_LIST_0]] : !torch.vtensor<[2,2,2,2,3,3],f32>, !torch.list<int> -> !torch.vtensor<[2,8,3,3],f32>
|
||||||
|
// CHECK: return %[[RESULT]] : !torch.vtensor<[2,8,3,3],f32
|
||||||
|
%0 = torch.operator "onnx.SpaceToDepth"(%arg0) {torch.onnx.blocksize = 2 : si64} : (!torch.vtensor<[2,2,6,6],f32>) -> !torch.vtensor<[2,8,3,3],f32>
|
||||||
|
return %0 : !torch.vtensor<[2,8,3,3],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_spacetodepth
|
||||||
|
func.func @test_spacetodepth_dynamic_dims(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[C0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[B:.*]] = torch.aten.size.int %arg0, %[[C0]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[C1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[C:.*]] = torch.aten.size.int %arg0, %[[C1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[C2:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[H:.*]] = torch.aten.size.int %arg0, %[[C2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[C3:.*]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[W:.*]] = torch.aten.size.int %arg0, %[[C3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[C2_0:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[C4:.*]] = torch.constant.int 4
|
||||||
|
// CHECK: %[[H_DIV_BS:.*]] = torch.aten.div.int %[[H]], %[[C2_0]] : !torch.int, !torch.int -> !torch.float
|
||||||
|
// CHECK: %[[W_DIV_BS:.*]] = torch.aten.div.int %[[W]], %[[C2_0]] : !torch.int, !torch.int -> !torch.float
|
||||||
|
// CHECK: %[[H_DIV_BS_INT:.*]] = torch.aten.Int.float %[[H_DIV_BS]] : !torch.float -> !torch.int
|
||||||
|
// CHECK: %[[W_DIV_BS_INT:.*]] = torch.aten.Int.float %[[W_DIV_BS]] : !torch.float -> !torch.int
|
||||||
|
// CHECK: %[[RESHAPE_LIST:.*]] = torch.prim.ListConstruct %[[B]], %[[C]], %[[H_DIV_BS_INT]], %[[C2_0]], %[[W_DIV_BS_INT]], %[[C2_0]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[RESHAPE:.*]] = torch.aten.reshape %arg0, %[[RESHAPE_LIST]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?,2,?,2],f32>
|
||||||
|
// CHECK: %[[C0_0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[C3_0:.*]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[C5:.*]] = torch.constant.int 5
|
||||||
|
// CHECK: %[[C1_0:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[C2_1:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[C4_0:.*]] = torch.constant.int 4
|
||||||
|
// CHECK: %[[PERMUTE_DIMS:.*]] = torch.prim.ListConstruct %[[C0_0]], %[[C3_0]], %[[C5]], %[[C1_0]], %[[C2_1]], %[[C4_0]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[RESHAPE]], %[[PERMUTE_DIMS]] : !torch.vtensor<[?,?,?,2,?,2],f32>, !torch.list<int> -> !torch.vtensor<[?,2,2,?,?,?],f32>
|
||||||
|
// CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[C]], %[[C4]] : !torch.int, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[RESHAPE_LIST_0:.*]] = torch.prim.ListConstruct %[[B]], %[[MUL]], %[[H_DIV_BS_INT]], %[[W_DIV_BS_INT]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[RESULT:.*]] = torch.aten.reshape %[[PERMUTE]], %[[RESHAPE_LIST_0]] : !torch.vtensor<[?,2,2,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?,?],f32>
|
||||||
|
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?,?],f32
|
||||||
|
%0 = torch.operator "onnx.SpaceToDepth"(%arg0) {torch.onnx.blocksize = 2 : si64} : (!torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32>
|
||||||
|
return %0 : !torch.vtensor<[?,?,?,?],f32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue