[ONNX] Add OnnxToTorch support for CenterCropPad (#3496)

pull/3513/head
jinchen 2024-06-28 12:47:29 -07:00 committed by GitHub
parent 6fece25ff3
commit 3915db0a86
2 changed files with 262 additions and 0 deletions
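ONNX's CenterCropPad crops and/or pads the selected axes of the input to a requested shape, keeping the data centered. The lowering added here expresses a crop as a centered aten.slice.Tensor and a pad as aten.zeros followed by aten.slice_scatter. A minimal standalone sketch of the per-axis offset arithmetic the pattern emits (plain C++; the helper name and struct are illustrative, not part of the patch):

#include <cstdint>

// Illustrative helper (not part of the patch): the centered window that the
// lowering computes per axis with sub / div(floor) / add.
struct CenterWindow {
  int64_t start, end;
};

// For a crop (in > out) the result is input[start:end]; for a pad (in < out)
// the input is scattered into zeros[start:end]. Non-negative integer division
// in C++ truncates toward zero, which here matches the "floor" divide mode
// used by the emitted torch.aten.div.Scalar_mode.
CenterWindow centerWindow(int64_t in, int64_t out) {
  if (in > out) {
    int64_t start = (in - out) / 2;
    return {start, start + out};
  }
  int64_t start = (out - in) / 2;
  return {start, start + in};
}

For example, cropping 20 down to 10 keeps [5, 15), and padding 8 up to 10 writes the input into [1, 9) of a 10-wide zero tensor; both show up in the first test below.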


@@ -13,6 +13,7 @@
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/Support/FormatVariadic.h"
#include <numeric>

using namespace mlir;
using namespace mlir::torch;
@@ -729,6 +730,128 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
            binder.op, resultType, maxExpression, minExpression, constantOne);
        return success();
      });
  patterns.onOp(
      "CenterCropPad", 18,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        Torch::ValueTensorType resultType;
        Value input, shape;
        if (binder.tensorOperands(input, shape) ||
            binder.tensorResultType(resultType))
          return failure();
        auto inputTy = cast<Torch::ValueTensorType>(input.getType());
        SmallVector<int64_t> inputShape(inputTy.getSizes());
        SmallVector<int64_t> resultShape(resultType.getSizes());
        int64_t rank = inputShape.size();
        SmallVector<int64_t> axes, defaultAxes(rank);
        std::iota(defaultAxes.begin(), defaultAxes.end(), 0);
        if (binder.s64IntegerArrayAttr(axes, "axes", defaultAxes)) {
          return failure();
        }
        int64_t axesSize = axes.size();
        Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
        Value cstZero = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(0));
        Value cstOne = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(1));
        Value cstTwo = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(2));
        auto scalarTensorType = rewriter.getType<Torch::ValueTensorType>(
            ArrayRef<int64_t>{1}, rewriter.getIntegerType(64, /*signed*/ 1));
        // Dims that differ between input and result become dynamic in the
        // intermediate shape; only the op handling the last changed dim can
        // carry the static result type.
        int64_t lastChangeDim = 0;
        llvm::SmallVector<int64_t> interShape(inputShape);
        for (int i = 0; i < rank; i++) {
          if (inputShape[i] != resultShape[i]) {
            interShape[i] = -1;
            lastChangeDim = i;
          }
          if (interShape[i] == ShapedType::kDynamic)
            interShape[i] = Torch::kUnknownSize;
        }
        auto interType = rewriter.getType<Torch::ValueTensorType>(
            interShape, resultType.getOptionalDtype());
        Value modeVal = rewriter.create<Torch::ConstantStrOp>(
            binder.getLoc(), rewriter.getStringAttr("floor"));
        for (int i = 0; i < axesSize; i++) {
          if (axes[i] < 0)
            axes[i] += rank;
          if (inputShape[axes[i]] == resultShape[axes[i]])
            continue;
          auto opType = axes[i] == lastChangeDim ? resultType : interType;
          Value axis = rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), rewriter.getI64IntegerAttr(axes[i]));
          Value k = rewriter.create<Torch::ConstantIntOp>(
              binder.getLoc(), rewriter.getI64IntegerAttr(i));
          Value kTensor = rewriter.create<Torch::PrimNumToTensorScalarOp>(
              binder.getLoc(), scalarTensorType, k);
          Value sel = rewriter.create<Torch::AtenIndexSelectOp>(
              binder.getLoc(), scalarTensorType, shape, cstZero, kTensor);
          Value outputDimSize = rewriter.create<Torch::AtenItemOp>(
              binder.getLoc(), rewriter.getType<Torch::IntType>(), sel);
          Value inputDimSize = rewriter.create<Torch::AtenSizeIntOp>(
              binder.getLoc(), input,
              rewriter.create<Torch::ConstantIntOp>(
                  binder.getLoc(), rewriter.getI64IntegerAttr(axes[i])));
          if (inputShape[axes[i]] > resultShape[axes[i]]) {
            // Crop: slice out the centered window starting at
            // floor((inputDimSize - outputDimSize) / 2) along this axis.
            Value sub = rewriter.create<Torch::AtenSubIntOp>(
                binder.getLoc(), inputDimSize, outputDimSize);
            Value subTensor = rewriter.create<Torch::PrimNumToTensorScalarOp>(
                binder.getLoc(), scalarTensorType, sub);
            Value div = rewriter.create<Torch::AtenDivScalarModeOp>(
                binder.getLoc(), scalarTensorType, subTensor, cstTwo, modeVal);
            Value start = rewriter.create<Torch::AtenItemOp>(
                binder.getLoc(), rewriter.getType<Torch::IntType>(), div);
            Value end = rewriter.create<Torch::AtenAddIntOp>(
                binder.getLoc(), start, outputDimSize);
            input = rewriter.create<Torch::AtenSliceTensorOp>(
                binder.getLoc(), opType, input, axis, start, end, cstOne);
          } else {
            // Pad: scatter the input into a zero tensor of the padded size,
            // offset by floor((outputDimSize - inputDimSize) / 2) along this
            // axis.
            Value sub = rewriter.create<Torch::AtenSubIntOp>(
                binder.getLoc(), outputDimSize, inputDimSize);
            Value subTensor = rewriter.create<Torch::PrimNumToTensorScalarOp>(
                binder.getLoc(), scalarTensorType, sub);
            Value div = rewriter.create<Torch::AtenDivScalarModeOp>(
                binder.getLoc(), scalarTensorType, subTensor, cstTwo, modeVal);
            Value start = rewriter.create<Torch::AtenItemOp>(
                binder.getLoc(), rewriter.getType<Torch::IntType>(), div);
            Value end = rewriter.create<Torch::AtenAddIntOp>(
                binder.getLoc(), start, inputDimSize);
            SmallVector<Value> zerosShapeValues;
            for (int j = 0; j < rank; j++) {
              if (j == axes[i]) {
                zerosShapeValues.push_back(outputDimSize);
              } else {
                Value dimSize = rewriter.create<Torch::AtenSizeIntOp>(
                    binder.getLoc(), input,
                    rewriter.create<Torch::ConstantIntOp>(
                        binder.getLoc(), rewriter.getI64IntegerAttr(j)));
                zerosShapeValues.push_back(dimSize);
              }
            }
            Value zerosShapeList = rewriter.create<Torch::PrimListConstructOp>(
                binder.getLoc(),
                rewriter.getType<Torch::ListType>(
                    rewriter.getType<Torch::IntType>()),
                zerosShapeValues);
            Value zeros = rewriter.create<Torch::AtenZerosOp>(
                binder.getLoc(), opType, zerosShapeList, none, none, none,
                none);
            input = rewriter.create<Torch::AtenSliceScatterOp>(
                binder.getLoc(), opType, zeros, input, axis, start, end,
                cstOne);
          }
        }
        rewriter.replaceOp(binder.op, input);
        return success();
      });
  patterns.onOp(
      "Clip", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        // https://onnx.ai/onnx/operators/onnx__Clip.html
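One non-obvious detail above is the intermediate result type: every dim that changes between input and output is made dynamic, and only the op handling the last changed dim is given the static result type. A small hypothetical driver (not part of the patch) tracing that walk for the first test case:

#include <cstdint>
#include <vector>

int main() {
  // First CenterCropPad test below: input [20,8,3] -> result [10,10,3].
  std::vector<int64_t> inputShape{20, 8, 3}, resultShape{10, 10, 3};
  std::vector<int64_t> interShape(inputShape);
  int64_t lastChangeDim = 0;
  for (size_t i = 0; i < inputShape.size(); i++) {
    if (inputShape[i] != resultShape[i]) {
      interShape[i] = -1; // this dim changes, so it becomes dynamic
      lastChangeDim = static_cast<int64_t>(i);
    }
  }
  // interShape == {-1, -1, 3}, lastChangeDim == 1: the crop on dim 0 is
  // typed !torch.vtensor<[?,?,3],f32>, and the slice_scatter on dim 1 (the
  // last changed dim) carries the static [10,10,3] result type.
  return 0;
}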


@@ -2447,6 +2447,8 @@ func.func @test_col2im_dilations(%arg0: !torch.vtensor<[1,4,5],f32>, %arg1: !tor
  return %0 : !torch.vtensor<[1,1,6,6],f32>
}

// -----

// CHECK-LABEL: func.func @test_col2im_strides
func.func @test_col2im_strides(%arg0: !torch.vtensor<[1,9,4],f32>, %arg1: !torch.vtensor<[2],si64>, %arg2: !torch.vtensor<[2],si64>) -> !torch.vtensor<[1,1,5,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[INT1_0:.*]] = torch.constant.int 1
@@ -2483,6 +2485,141 @@ func.func @test_col2im_strides(%arg0: !torch.vtensor<[1,9,4],f32>, %arg1: !torch
// -----
// CHECK-LABEL: @test_center_crop_pad_crop_and_pad
func.func @test_center_crop_pad_crop_and_pad(%arg0: !torch.vtensor<[20,8,3],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[10,10,3],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[C0:.*]] = torch.constant.int 0
// CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: %[[C2:.*]] = torch.constant.int 2
// CHECK: %[[STR:.*]] = torch.constant.str "floor"
// CHECK: %[[C0_0:.*]] = torch.constant.int 0
// CHECK: %[[C0_1:.*]] = torch.constant.int 0
// CHECK: %[[INDEX:.*]] = torch.prim.NumToTensor.Scalar %[[C0_1]] : !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[SELECT:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[C0_2:.*]] = torch.constant.int 0
// CHECK: %[[SIZE:.*]] = torch.aten.size.int %arg0, %[[C0_2]] : !torch.vtensor<[20,8,3],f32>, !torch.int -> !torch.int
// CHECK: %[[SUB:.*]] = torch.aten.sub.int %[[SIZE]], %[[ITEM]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[SUB_TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[SUB]] : !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[DIV:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR]], %[[C2]], %[[STR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.str -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[DIV]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[ADD:.*]] = torch.aten.add.int %[[ITEM_0]], %[[ITEM]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %arg0, %[[C0_0]], %[[ITEM_0]], %[[ADD]], %[[C1]] : !torch.vtensor<[20,8,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?,3],f32>
// CHECK: %[[C1_0:.*]] = torch.constant.int 1
// CHECK: %[[C1_1:.*]] = torch.constant.int 1
// CHECK: %[[INDEX_0:.*]] = torch.prim.NumToTensor.Scalar %[[C1_1]] : !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[SELECT_0:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM_1:.*]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[C1_2:.*]] = torch.constant.int 1
// CHECK: %[[SIZE_0:.*]] = torch.aten.size.int %[[SLICE]], %[[C1_2]] : !torch.vtensor<[?,?,3],f32>, !torch.int -> !torch.int
// CHECK: %[[SUB_0:.*]] = torch.aten.sub.int %[[ITEM_1]], %[[SIZE_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[SUB_TENSOR_0:.*]] = torch.prim.NumToTensor.Scalar %[[SUB_0]] : !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[DIV_0:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR_0]], %[[C2]], %[[STR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.str -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM_2:.*]] = torch.aten.item %[[DIV_0]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[ADD_0:.*]] = torch.aten.add.int %[[ITEM_2]], %[[SIZE_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[C0_3:.*]] = torch.constant.int 0
// CHECK: %[[SIZE_1:.*]] = torch.aten.size.int %[[SLICE]], %[[C0_3]] : !torch.vtensor<[?,?,3],f32>, !torch.int -> !torch.int
// CHECK: %[[C2_0:.*]] = torch.constant.int 2
// CHECK: %[[SIZE_2:.*]] = torch.aten.size.int %[[SLICE]], %[[C2_0]] : !torch.vtensor<[?,?,3],f32>, !torch.int -> !torch.int
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[SIZE_1]], %[[ITEM_1]], %[[SIZE_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[ZEROS:.*]] = torch.aten.zeros %[[SHAPE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10,10,3],f32>
// CHECK: torch.aten.slice_scatter %[[ZEROS]], %[[SLICE]], %[[C1_0]], %[[ITEM_2]], %[[ADD_0]], %[[C1]] : !torch.vtensor<[10,10,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[10,10,3],f32>
  %0 = torch.operator "onnx.CenterCropPad"(%arg0, %arg1) : (!torch.vtensor<[20,8,3],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[10,10,3],f32>
  return %0 : !torch.vtensor<[10,10,3],f32>
}
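Worked through with the offset sketch above: dim 0 is cropped from 20 to 10 (start = floor(10/2) = 5, end = 15) and dim 1 is padded from 8 to 10 (start = floor(2/2) = 1, end = 9), which is exactly the slice / zeros + slice_scatter sequence the CHECK lines verify.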
// -----
// CHECK-LABEL: @test_center_crop_pad_crop_axes_chw
func.func @test_center_crop_pad_crop_axes_chw(%arg0: !torch.vtensor<[3,20,8],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,10,9],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[C0:.*]] = torch.constant.int 0
// CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: %[[C2:.*]] = torch.constant.int 2
// CHECK: %[[STR:.*]] = torch.constant.str "floor"
// CHECK: %[[C1_0:.*]] = torch.constant.int 1
// CHECK: %[[C0_0:.*]] = torch.constant.int 0
// CHECK: %[[INDEX:.*]] = torch.prim.NumToTensor.Scalar %[[C0_0]] : !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[SELECT:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[C1_1:.*]] = torch.constant.int 1
// CHECK: %[[SIZE:.*]] = torch.aten.size.int %arg0, %[[C1_1]] : !torch.vtensor<[3,20,8],f32>, !torch.int -> !torch.int
// CHECK: %[[SUB:.*]] = torch.aten.sub.int %[[SIZE]], %[[ITEM]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[SUB_TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[SUB]] : !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[DIV:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR]], %[[C2]], %[[STR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.str -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[DIV]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[ADD:.*]] = torch.aten.add.int %[[ITEM_0]], %[[ITEM]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %arg0, %[[C1_0]], %[[ITEM_0]], %[[ADD]], %[[C1]] : !torch.vtensor<[3,20,8],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,?,?],f32>
// CHECK: %[[C2_0:.*]] = torch.constant.int 2
// CHECK: %[[C1_2:.*]] = torch.constant.int 1
// CHECK: %[[INDEX_0:.*]] = torch.prim.NumToTensor.Scalar %[[C1_2]] : !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[SELECT_0:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_0]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM_1:.*]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[C2_1:.*]] = torch.constant.int 2
// CHECK: %[[SIZE_0:.*]] = torch.aten.size.int %[[SLICE]], %[[C2_1]] : !torch.vtensor<[3,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[SUB_0:.*]] = torch.aten.sub.int %[[ITEM_1]], %[[SIZE_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[SUB_TENSOR_0:.*]] = torch.prim.NumToTensor.Scalar %[[SUB_0]] : !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[DIV_0:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR_0]], %[[C2]], %[[STR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.str -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM_2:.*]] = torch.aten.item %[[DIV_0]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[ADD_0:.*]] = torch.aten.add.int %[[ITEM_2]], %[[SIZE_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[C0_1:.*]] = torch.constant.int 0
// CHECK: %[[SIZE_1:.*]] = torch.aten.size.int %[[SLICE]], %[[C0_1]] : !torch.vtensor<[3,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[C1_3:.*]] = torch.constant.int 1
// CHECK: %[[SIZE_2:.*]] = torch.aten.size.int %[[SLICE]], %[[C1_3]] : !torch.vtensor<[3,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[SIZE_1]], %[[SIZE_2]], %[[ITEM_1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[ZEROS:.*]] = torch.aten.zeros %[[SHAPE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,10,9],f32>
// CHECK: torch.aten.slice_scatter %[[ZEROS]], %[[SLICE]], %[[C2_0]], %[[ITEM_2]], %[[ADD_0]], %[[C1]] : !torch.vtensor<[3,10,9],f32>, !torch.vtensor<[3,?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,10,9],f32>
  %0 = torch.operator "onnx.CenterCropPad"(%arg0, %arg1) {torch.onnx.axes = [1 : si64, 2 : si64]} : (!torch.vtensor<[3,20,8],f32>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,10,9],f32>
  return %0 : !torch.vtensor<[3,10,9],f32>
}
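Same arithmetic with explicit axes = [1, 2] over a CHW layout: dim 1 is cropped from 20 to 10 (start 5, end 15) and dim 2 is padded from 8 to 9 with start = floor(1/2) = 0, end = 8, so the single zero column lands at the high end of dim 2.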
// -----
// CHECK-LABEL: @test_center_crop_pad_crop_negative_axes_hwc
func.func @test_center_crop_pad_crop_negative_axes_hwc(%arg0: !torch.vtensor<[20,8,3],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[10,9,3],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[C0:.*]] = torch.constant.int 0
// CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: %[[C2:.*]] = torch.constant.int 2
// CHECK: %[[STR:.*]] = torch.constant.str "floor"
// CHECK: %[[C0_0:.*]] = torch.constant.int 0
// CHECK: %[[C0_1:.*]] = torch.constant.int 0
// CHECK: %[[INDEX:.*]] = torch.prim.NumToTensor.Scalar %[[C0_1]] : !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[SELECT:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[C0_2:.*]] = torch.constant.int 0
// CHECK: %[[SIZE:.*]] = torch.aten.size.int %arg0, %[[C0_2]] : !torch.vtensor<[20,8,3],f32>, !torch.int -> !torch.int
// CHECK: %[[SUB:.*]] = torch.aten.sub.int %[[SIZE]], %[[ITEM]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[SUB_TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[SUB]] : !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[DIV:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR]], %[[C2]], %[[STR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.str -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[DIV]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[ADD:.*]] = torch.aten.add.int %[[ITEM_0]], %[[ITEM]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %arg0, %[[C0_0]], %[[ITEM_0]], %[[ADD]], %[[C1]] : !torch.vtensor<[20,8,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?,3],f32>
// CHECK: %[[C1_0:.*]] = torch.constant.int 1
// CHECK: %[[C1_1:.*]] = torch.constant.int 1
// CHECK: %[[INDEX_0:.*]] = torch.prim.NumToTensor.Scalar %[[C1_1]] : !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[SELECT_0:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_0]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM_1:.*]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[C1_2:.*]] = torch.constant.int 1
// CHECK: %[[SIZE_0:.*]] = torch.aten.size.int %[[SLICE]], %[[C1_2]] : !torch.vtensor<[?,?,3],f32>, !torch.int -> !torch.int
// CHECK: %[[SUB_0:.*]] = torch.aten.sub.int %[[ITEM_1]], %[[SIZE_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[SUB_TENSOR_0:.*]] = torch.prim.NumToTensor.Scalar %[[SUB_0]] : !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[DIV_0:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR_0]], %[[C2]], %[[STR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.str -> !torch.vtensor<[1],si64>
// CHECK: %[[ITEM_2:.*]] = torch.aten.item %[[DIV_0]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[ADD_0:.*]] = torch.aten.add.int %[[ITEM_2]], %[[SIZE_0]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[C0_3:.*]] = torch.constant.int 0
// CHECK: %[[SIZE_1:.*]] = torch.aten.size.int %[[SLICE]], %[[C0_3]] : !torch.vtensor<[?,?,3],f32>, !torch.int -> !torch.int
// CHECK: %[[C2_0:.*]] = torch.constant.int 2
// CHECK: %[[SIZE_2:.*]] = torch.aten.size.int %[[SLICE]], %[[C2_0]] : !torch.vtensor<[?,?,3],f32>, !torch.int -> !torch.int
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[SIZE_1]], %[[ITEM_1]], %[[SIZE_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[ZEROS:.*]] = torch.aten.zeros %[[SHAPE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10,9,3],f32>
// CHECK: torch.aten.slice_scatter %[[ZEROS]], %[[SLICE]], %[[C1_0]], %[[ITEM_2]], %[[ADD_0]], %[[C1]] : !torch.vtensor<[10,9,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[10,9,3],f32>
  %0 = torch.operator "onnx.CenterCropPad"(%arg0, %arg1) {torch.onnx.axes = [-3 : si64, -2 : si64]} : (!torch.vtensor<[20,8,3],f32>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[10,9,3],f32>
  return %0 : !torch.vtensor<[10,9,3],f32>
}
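The negative axes are handled by the `axes[i] += rank` normalization step in the lowering, so with rank 3 the attribute [-3, -2] becomes [0, 1] and this test lowers identically to the HWC case with explicit leading axes. As a one-line illustrative sketch (not part of the patch):

// Illustrative: ONNX-style normalization of a possibly-negative axis.
int64_t normalizeAxis(int64_t axis, int64_t rank) {
  return axis < 0 ? axis + rank : axis; // e.g. -3 + 3 == 0, -2 + 3 == 1
}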
// -----
// CHECK-LABEL: func.func @test_dft_fft
func.func @test_dft_fft(%arg0: !torch.vtensor<[10,10,1],f32>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "dft_example", torch.onnx_meta.producer_version = ""} {
// CHECK-SAME: %[[INPUT_SIGNAL:.*]]: !torch.vtensor<[10,10,1],f32>,
@@ -2506,6 +2643,8 @@ func.func @test_dft_fft(%arg0: !torch.vtensor<[10,10,1],f32>, %arg1: !torch.vten
  return %0 : !torch.vtensor<[10,10,2],f32>
}
// -----
// CHECK-LABEL: func.func @test_dft_inverse_real
func.func @test_dft_inverse_real(%arg0: !torch.vtensor<[10,10,1],f32>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "dft_example", torch.onnx_meta.producer_version = ""} {
// CHECK-SAME: %[[INPUT_SIGNAL:.*]]: !torch.vtensor<[10,10,1],f32>,