mirror of https://github.com/llvm/torch-mlir
[ONNX] Add OnnxToTorch support for CenterCropPad (#3496)
parent
6fece25ff3
commit
3915db0a86
|
@ -13,6 +13,7 @@
|
|||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include <numeric>
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
|
@ -729,6 +730,128 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
binder.op, resultType, maxExpression, minExpression, constantOne);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp(
|
||||
"CenterCropPad", 18,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
Value input, shape;
|
||||
if (binder.tensorOperands(input, shape) ||
|
||||
binder.tensorResultType(resultType))
|
||||
return failure();
|
||||
|
||||
auto inputTy = cast<Torch::ValueTensorType>(input.getType());
|
||||
SmallVector<int64_t> inputShape(inputTy.getSizes());
|
||||
SmallVector<int64_t> resultShape(resultType.getSizes());
|
||||
int64_t rank = inputShape.size();
|
||||
|
||||
SmallVector<int64_t> axes, defaultAxes(rank);
|
||||
std::iota(defaultAxes.begin(), defaultAxes.end(), 0);
|
||||
if (binder.s64IntegerArrayAttr(axes, "axes", defaultAxes)) {
|
||||
return failure();
|
||||
}
|
||||
int64_t axesSize = axes.size();
|
||||
|
||||
Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
||||
Value cstZero = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(0));
|
||||
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(1));
|
||||
Value cstTwo = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(2));
|
||||
auto scalarTensorType = rewriter.getType<Torch::ValueTensorType>(
|
||||
ArrayRef<int64_t>{1}, rewriter.getIntegerType(64, /*signed*/ 1));
|
||||
|
||||
int64_t lastChangeDim = 0;
|
||||
llvm::SmallVector<int64_t> interShape(inputShape);
|
||||
for (int i = 0; i < rank; i++) {
|
||||
if (inputShape[i] != resultShape[i]) {
|
||||
interShape[i] = -1;
|
||||
lastChangeDim = i;
|
||||
}
|
||||
if (interShape[i] == ShapedType::kDynamic)
|
||||
interShape[i] = Torch::kUnknownSize;
|
||||
}
|
||||
auto interType = rewriter.getType<Torch::ValueTensorType>(
|
||||
interShape, resultType.getOptionalDtype());
|
||||
|
||||
Value modeVal = rewriter.create<Torch::ConstantStrOp>(
|
||||
binder.getLoc(), rewriter.getStringAttr("floor"));
|
||||
for (int i = 0; i < axesSize; i++) {
|
||||
if (axes[i] < 0)
|
||||
axes[i] += rank;
|
||||
if (inputShape[axes[i]] == resultShape[axes[i]])
|
||||
continue;
|
||||
|
||||
auto opType = axes[i] == lastChangeDim ? resultType : interType;
|
||||
Value axis = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(axes[i]));
|
||||
Value k = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(i));
|
||||
Value kTensor = rewriter.create<Torch::PrimNumToTensorScalarOp>(
|
||||
binder.getLoc(), scalarTensorType, k);
|
||||
Value sel = rewriter.create<Torch::AtenIndexSelectOp>(
|
||||
binder.getLoc(), scalarTensorType, shape, cstZero, kTensor);
|
||||
Value outputDimSize = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(), sel);
|
||||
Value inputDimSize = rewriter.create<Torch::AtenSizeIntOp>(
|
||||
binder.getLoc(), input,
|
||||
rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(axes[i])));
|
||||
|
||||
if (inputShape[axes[i]] > resultShape[axes[i]]) {
|
||||
Value sub = rewriter.create<Torch::AtenSubIntOp>(
|
||||
binder.getLoc(), inputDimSize, outputDimSize);
|
||||
Value subTensor = rewriter.create<Torch::PrimNumToTensorScalarOp>(
|
||||
binder.getLoc(), scalarTensorType, sub);
|
||||
Value div = rewriter.create<Torch::AtenDivScalarModeOp>(
|
||||
binder.getLoc(), scalarTensorType, subTensor, cstTwo, modeVal);
|
||||
Value start = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(), div);
|
||||
Value end = rewriter.create<Torch::AtenAddIntOp>(
|
||||
binder.getLoc(), start, outputDimSize);
|
||||
input = rewriter.create<Torch::AtenSliceTensorOp>(
|
||||
binder.getLoc(), opType, input, axis, start, end, cstOne);
|
||||
} else {
|
||||
Value sub = rewriter.create<Torch::AtenSubIntOp>(
|
||||
binder.getLoc(), outputDimSize, inputDimSize);
|
||||
Value subTensor = rewriter.create<Torch::PrimNumToTensorScalarOp>(
|
||||
binder.getLoc(), scalarTensorType, sub);
|
||||
Value div = rewriter.create<Torch::AtenDivScalarModeOp>(
|
||||
binder.getLoc(), scalarTensorType, subTensor, cstTwo, modeVal);
|
||||
Value start = rewriter.create<Torch::AtenItemOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(), div);
|
||||
Value end = rewriter.create<Torch::AtenAddIntOp>(
|
||||
binder.getLoc(), start, inputDimSize);
|
||||
|
||||
SmallVector<Value> zerosShapeValues;
|
||||
for (int j = 0; j < rank; j++) {
|
||||
if (j == axes[i]) {
|
||||
zerosShapeValues.push_back(outputDimSize);
|
||||
} else {
|
||||
Value dimSize = rewriter.create<Torch::AtenSizeIntOp>(
|
||||
binder.getLoc(), input,
|
||||
rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(j)));
|
||||
zerosShapeValues.push_back(dimSize);
|
||||
}
|
||||
}
|
||||
Value zerosShapeList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getType<Torch::ListType>(
|
||||
rewriter.getType<Torch::IntType>()),
|
||||
zerosShapeValues);
|
||||
Value zeros = rewriter.create<Torch::AtenZerosOp>(
|
||||
binder.getLoc(), opType, zerosShapeList, none, none, none,
|
||||
none);
|
||||
input = rewriter.create<Torch::AtenSliceScatterOp>(
|
||||
binder.getLoc(), opType, zeros, input, axis, start, end,
|
||||
cstOne);
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.replaceOp(binder.op, input);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp(
|
||||
"Clip", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
// https://onnx.ai/onnx/operators/onnx__Clip.html
|
||||
|
|
|
@ -2447,6 +2447,8 @@ func.func @test_col2im_dilations(%arg0: !torch.vtensor<[1,4,5],f32>, %arg1: !tor
|
|||
return %0 : !torch.vtensor<[1,1,6,6],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_col2im_strides
|
||||
func.func @test_col2im_strides(%arg0: !torch.vtensor<[1,9,4],f32>, %arg1: !torch.vtensor<[2],si64>, %arg2: !torch.vtensor<[2],si64>) -> !torch.vtensor<[1,1,5,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK-DAG: %[[INT1_0:.*]] = torch.constant.int 1
|
||||
|
@ -2483,6 +2485,141 @@ func.func @test_col2im_strides(%arg0: !torch.vtensor<[1,9,4],f32>, %arg1: !torch
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_center_crop_pad_crop_and_pad
|
||||
func.func @test_center_crop_pad_crop_and_pad(%arg0: !torch.vtensor<[20,8,3],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[10,10,3],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[C0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[C1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[C2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[STR:.*]] = torch.constant.str "floor"
|
||||
// CHECK: %[[C0_0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[C0_1:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[INDEX:.*]] = torch.prim.NumToTensor.Scalar %[[C0_1]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[SELECT:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[ITEM:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: %[[C0_2:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[SIZE:.*]] = torch.aten.size.int %arg0, %[[C0_2]] : !torch.vtensor<[20,8,3],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[SUB:.*]] = torch.aten.sub.int %[[SIZE]], %[[ITEM]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[SUB_TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[SUB]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[DIV:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR]], %[[C2]], %[[STR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.str -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[DIV]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: %[[ADD:.*]] = torch.aten.add.int %[[ITEM_0]], %[[ITEM]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %arg0, %[[C0_0]], %[[ITEM_0]], %[[ADD]], %[[C1]] : !torch.vtensor<[20,8,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?,3],f32>
|
||||
// CHECK: %[[C1_0:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[C1_1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[INDEX_0:.*]] = torch.prim.NumToTensor.Scalar %[[C1_1]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[SELECT_0:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[ITEM_1:.*]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: %[[C1_2:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[SIZE_0:.*]] = torch.aten.size.int %[[SLICE]], %[[C1_2]] : !torch.vtensor<[?,?,3],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[SUB_0:.*]] = torch.aten.sub.int %[[ITEM_1]], %[[SIZE_0]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[SUB_TENSOR_0:.*]] = torch.prim.NumToTensor.Scalar %[[SUB_0]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[DIV_0:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR_0]], %[[C2]], %[[STR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.str -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[ITEM_2:.*]] = torch.aten.item %[[DIV_0]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: %[[ADD_0:.*]] = torch.aten.add.int %[[ITEM_2]], %[[SIZE_0]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[C0_3:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[SIZE_1:.*]] = torch.aten.size.int %[[SLICE]], %[[C0_3]] : !torch.vtensor<[?,?,3],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[C2_0:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[SIZE_2:.*]] = torch.aten.size.int %[[SLICE]], %[[C2_0]] : !torch.vtensor<[?,?,3],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[SIZE_1]], %[[ITEM_1]], %[[SIZE_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[ZEROS:.*]] = torch.aten.zeros %[[SHAPE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10,10,3],f32>
|
||||
// CHECK: torch.aten.slice_scatter %[[ZEROS]], %[[SLICE]], %[[C1_0]], %[[ITEM_2]], %[[ADD_0]], %[[C1]] : !torch.vtensor<[10,10,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[10,10,3],f32>
|
||||
%0 = torch.operator "onnx.CenterCropPad"(%arg0, %arg1) : (!torch.vtensor<[20,8,3],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[10,10,3],f32>
|
||||
return %0 : !torch.vtensor<[10,10,3],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_center_crop_pad_crop_axes_chw
|
||||
func.func @test_center_crop_pad_crop_axes_chw(%arg0: !torch.vtensor<[3,20,8],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,10,9],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[C0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[C1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[C2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[STR:.*]] = torch.constant.str "floor"
|
||||
// CHECK: %[[C1_0:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[C0_0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[INDEX:.*]] = torch.prim.NumToTensor.Scalar %[[C0_0]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[SELECT:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[ITEM:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: %[[C1_1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[SIZE:.*]] = torch.aten.size.int %arg0, %[[C1_1]] : !torch.vtensor<[3,20,8],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[SUB:.*]] = torch.aten.sub.int %[[SIZE]], %[[ITEM]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[SUB_TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[SUB]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[DIV:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR]], %[[C2]], %[[STR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.str -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[DIV]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: %[[ADD:.*]] = torch.aten.add.int %[[ITEM_0]], %[[ITEM]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %arg0, %[[C1_0]], %[[ITEM_0]], %[[ADD]], %[[C1]] : !torch.vtensor<[3,20,8],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,?,?],f32>
|
||||
// CHECK: %[[C2_0:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[C1_2:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[INDEX_0:.*]] = torch.prim.NumToTensor.Scalar %[[C1_2]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[SELECT_0:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_0]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[ITEM_1:.*]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: %[[C2_1:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[SIZE_0:.*]] = torch.aten.size.int %[[SLICE]], %[[C2_1]] : !torch.vtensor<[3,?,?],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[SUB_0:.*]] = torch.aten.sub.int %[[ITEM_1]], %[[SIZE_0]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[SUB_TENSOR_0:.*]] = torch.prim.NumToTensor.Scalar %[[SUB_0]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[DIV_0:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR_0]], %[[C2]], %[[STR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.str -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[ITEM_2:.*]] = torch.aten.item %[[DIV_0]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: %[[ADD_0:.*]] = torch.aten.add.int %[[ITEM_2]], %[[SIZE_0]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[C0_1:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[SIZE_1:.*]] = torch.aten.size.int %[[SLICE]], %[[C0_1]] : !torch.vtensor<[3,?,?],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[C1_3:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[SIZE_2:.*]] = torch.aten.size.int %[[SLICE]], %[[C1_3]] : !torch.vtensor<[3,?,?],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[SIZE_1]], %[[SIZE_2]], %[[ITEM_1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[ZEROS:.*]] = torch.aten.zeros %[[SHAPE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,10,9],f32>
|
||||
// CHECK: torch.aten.slice_scatter %[[ZEROS]], %[[SLICE]], %[[C2_0]], %[[ITEM_2]], %[[ADD_0]], %[[C1]] : !torch.vtensor<[3,10,9],f32>, !torch.vtensor<[3,?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,10,9],f32>
|
||||
%0 = torch.operator "onnx.CenterCropPad"(%arg0, %arg1) {torch.onnx.axes = [1 : si64, 2 : si64]} : (!torch.vtensor<[3,20,8],f32>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,10,9],f32>
|
||||
return %0 : !torch.vtensor<[3,10,9],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_center_crop_pad_crop_negative_axes_hwc
|
||||
func.func @test_center_crop_pad_crop_negative_axes_hwc(%arg0: !torch.vtensor<[20,8,3],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[10,9,3],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[C0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[C1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[C2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[STR:.*]] = torch.constant.str "floor"
|
||||
// CHECK: %[[C0_0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[C0_1:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[INDEX:.*]] = torch.prim.NumToTensor.Scalar %[[C0_1]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[SELECT:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[ITEM:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: %[[C0_2:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[SIZE:.*]] = torch.aten.size.int %arg0, %[[C0_2]] : !torch.vtensor<[20,8,3],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[SUB:.*]] = torch.aten.sub.int %[[SIZE]], %[[ITEM]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[SUB_TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[SUB]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[DIV:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR]], %[[C2]], %[[STR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.str -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[DIV]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: %[[ADD:.*]] = torch.aten.add.int %[[ITEM_0]], %[[ITEM]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %arg0, %[[C0_0]], %[[ITEM_0]], %[[ADD]], %[[C1]] : !torch.vtensor<[20,8,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?,3],f32>
|
||||
// CHECK: %[[C1_0:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[C1_1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[INDEX_0:.*]] = torch.prim.NumToTensor.Scalar %[[C1_1]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[SELECT_0:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_0]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[ITEM_1:.*]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: %[[C1_2:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[SIZE_0:.*]] = torch.aten.size.int %[[SLICE]], %[[C1_2]] : !torch.vtensor<[?,?,3],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[SUB_0:.*]] = torch.aten.sub.int %[[ITEM_1]], %[[SIZE_0]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[SUB_TENSOR_0:.*]] = torch.prim.NumToTensor.Scalar %[[SUB_0]] : !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[DIV_0:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR_0]], %[[C2]], %[[STR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.str -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[ITEM_2:.*]] = torch.aten.item %[[DIV_0]] : !torch.vtensor<[1],si64> -> !torch.int
|
||||
// CHECK: %[[ADD_0:.*]] = torch.aten.add.int %[[ITEM_2]], %[[SIZE_0]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[C0_3:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[SIZE_1:.*]] = torch.aten.size.int %[[SLICE]], %[[C0_3]] : !torch.vtensor<[?,?,3],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[C2_0:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[SIZE_2:.*]] = torch.aten.size.int %[[SLICE]], %[[C2_0]] : !torch.vtensor<[?,?,3],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[SIZE_1]], %[[ITEM_1]], %[[SIZE_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[ZEROS:.*]] = torch.aten.zeros %[[SHAPE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10,9,3],f32>
|
||||
// CHECK: torch.aten.slice_scatter %[[ZEROS]], %[[SLICE]], %[[C1_0]], %[[ITEM_2]], %[[ADD_0]], %[[C1]] : !torch.vtensor<[10,9,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[10,9,3],f32>
|
||||
%0 = torch.operator "onnx.CenterCropPad"(%arg0, %arg1) {torch.onnx.axes = [-3 : si64, -2 : si64]} : (!torch.vtensor<[20,8,3],f32>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[10,9,3],f32>
|
||||
return %0 : !torch.vtensor<[10,9,3],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_dft_fft
|
||||
func.func @test_dft_fft(%arg0: !torch.vtensor<[10,10,1],f32>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "dft_example", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK-SAME: %[[INPUT_SIGNAL:.*]]: !torch.vtensor<[10,10,1],f32>,
|
||||
|
@ -2506,6 +2643,8 @@ func.func @test_dft_fft(%arg0: !torch.vtensor<[10,10,1],f32>, %arg1: !torch.vten
|
|||
return %0 : !torch.vtensor<[10,10,2],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_dft_inverse_real
|
||||
func.func @test_dft_inverse_real(%arg0: !torch.vtensor<[10,10,1],f32>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "dft_example", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK-SAME: %[[INPUT_SIGNAL:.*]]: !torch.vtensor<[10,10,1],f32>,
|
||||
|
|
Loading…
Reference in New Issue