mirror of https://github.com/llvm/torch-mlir
onnx: fix checks in TorchOnnxToTorch pass to match the ONNX spec (#2848)
This PR contains three commits to update the validation checks in the ONNX -> Torch conversion pass for the AveragePool, Pad, and Slice operators:

> onnx: fix preconditions for lowering AveragePool ops
>
> The `pads` attribute of the AveragePool operator specifies the value to pad at both the beginning as well as the end of each spatial axis (see https://onnx.ai/onnx/operators/onnx__AveragePool.html#attributes), so the size of this attribute should be twice the number of spatial axes of the input tensor. However, our TorchOnnxToTorch pass bails out early since it incorrectly compares the pads attribute with the rank (not twice the spatial rank) of the input tensor.
>
> This patch fixes the code to match the spec and adds a lit test.

> onnx: allow optional constant value for Pad operator
>
> The `constant_value` input of the onnx.Pad operator is optional (see https://onnx.ai/onnx/operators/onnx__Pad.html#inputs), but the existing logic for lowering the operator into the Torch dialect assumes that it is mandatory.
>
> This patch makes the input optional and constructs a default value (a zero tensor matching the data tensor's dtype) if the input was not specified.

> onnx: fix checks for axes and steps inputs of Slice operator
>
> The ONNX spec for the Slice operator allows the `starts` and `ends` inputs to have fewer indices than the dimensions of the `data` tensor (see https://onnx.ai/onnx/operators/onnx__Slice.html), but our code expects these inputs to have as many entries as the `data` tensor has dimensions.
>
> More precisely, the spec requires that the `starts` and `ends` inputs are only as long as the `axes` input, but since the `axes` input is optional, the default type for the `axes` input has to match the type of the `starts` and `ends` inputs. Moreover, the number of indices in the `steps` input also has to match those in the `axes` input (instead of matching the dimensions of the `data` input).
>
> This patch fixes the checks in the TorchOnnxToTorch conversion so that they match the ONNX spec.
parent 4df96616db
commit 21f070e95f
branch pull/2881/head
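For reference, the spec rules the three commits enforce can be written as standalone predicates. This is an illustrative C++ sketch; the helper names are invented here, not taken from the patch:

```cpp
// Illustrative sketch only; these helpers are not part of the patch.
#include <cstdint>
#include <optional>
#include <vector>

using Ints = std::vector<int64_t>;

// AveragePool: `pads` carries one begin and one end value per spatial
// axis, so its length must be twice the number of spatial axes
// (input rank minus the N and C dimensions).
bool averagePoolPadsOk(const Ints &pads, int64_t inputRank) {
  return static_cast<int64_t>(pads.size()) == 2 * (inputRank - 2);
}

// Slice: `starts` and `ends` must be as long as `axes`; `axes` defaults
// to [0, len(starts)), and `steps` must match `axes` too. None of them
// needs to be as long as the data tensor's rank.
bool sliceInputsOk(const Ints &starts, const Ints &ends,
                   const std::optional<Ints> &axes,
                   const std::optional<Ints> &steps) {
  const size_t n = axes ? axes->size() : starts.size();
  return starts.size() == n && ends.size() == n &&
         (!steps || steps->size() == n);
}
```

The Pad change is different in kind: it is about optionality rather than lengths. `constant_value` (input index 2) may simply be absent, in which case the lowering has to synthesize a zero itself.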
@@ -308,12 +308,14 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
           return rewriter.notifyMatchFailure(
               binder.op, "kernel list size does not match the number of axes");
         }
-        if (binder.s64IntegerArrayAttr(padding, "pads", {0})) {
+        SmallVector<int64_t> defaultPadding(2 * (rank - 2), 0);
+        if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) {
           return failure();
         }
-        if (padding.size() != 1 && padding.size() != rank - 2) {
+        if (padding.size() != 2 * (rank - 2)) {
           return rewriter.notifyMatchFailure(
-              binder.op, "padding list size does not match the number of axes");
+              binder.op,
+              "padding list size does not match twice the number of axes");
         }
         if (binder.s64IntegerArrayAttr(strides, "strides", {1})) {
           return failure();
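To make the `2 * (rank - 2)` count above concrete: the ONNX attribute lays the values out as all begin pads followed by all end pads. A minimal standalone reader, assuming that layout (the example values are made up):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Rank-4 NCHW input => 2 spatial axes => 4 pad values,
  // laid out as [h_begin, w_begin, h_end, w_end].
  std::vector<int64_t> pads = {1, 2, 3, 4};
  const size_t numSpatial = pads.size() / 2;
  for (size_t i = 0; i < numSpatial; ++i)
    std::printf("spatial axis %zu: begin=%lld, end=%lld\n", i,
                static_cast<long long>(pads[i]),
                static_cast<long long>(pads[i + numSpatial]));
  return 0;
}
```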
@@ -861,7 +861,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
   patterns.onOp(
       "Pad", 19, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         Torch::ValueTensorType resultType;
-        Value data, pads, constantValue, axes;
+        Value data, pads, axes;
         std::string mode;

         // TODO: The `axes` parameter is not supported yet.
@@ -871,12 +871,41 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         }
         if (binder.tensorOperandAtIndex(data, 0) ||
             binder.tensorOperandAtIndex(pads, 1) ||
-            binder.tensorOperandAtIndex(constantValue, 2) ||
             binder.tensorResultType(resultType) ||
             binder.customOpNameStringAttr(mode, "mode", "constant"))
           return failure();
         Location loc = binder.getLoc();

+        Value constantValue;
+        if (binder.getNumOperands() >= 3) {
+          if (binder.tensorOperandAtIndex(constantValue, 2)) {
+            llvm::errs() << "failed to bind to index 2\n";
+            return failure();
+          }
+        } else {
+          auto dataTensorType = data.getType().cast<Torch::ValueTensorType>();
+
+          auto maybeZeroAttr = [&]() -> std::optional<Attribute> {
+            if (dataTensorType.getDtype().isa<IntegerType>()) {
+              return rewriter.getI64IntegerAttr(0);
+            }
+            if (dataTensorType.getDtype().isa<FloatType>()) {
+              return rewriter.getFloatAttr(dataTensorType.getDtype(), 0.0f);
+            }
+            return std::nullopt;
+          }();
+
+          if (!maybeZeroAttr) {
+            return rewriter.notifyMatchFailure(
+                binder.op, "expected integer or float data tensor");
+          }
+
+          auto shapedType = dataTensorType.toBuiltinTensor();
+          auto splat = SplatElementsAttr::get(shapedType, *maybeZeroAttr);
+          constantValue = rewriter.create<Torch::ValueTensorLiteralOp>(
+              loc, dataTensorType, splat);
+        }
+
         // Get pads shape and rank. The pads tensor is expected to be 1-D
         // tensor.
         auto padsTensorType = pads.getType().cast<Torch::ValueTensorType>();
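The default `constant_value` in the hunk above is built with an immediately-invoked lambda returning `std::optional`, so an unsupported dtype becomes a clean match failure rather than a crash. The same pattern, distilled to standard C++ with plain stand-ins for the MLIR types (everything named here is hypothetical):

```cpp
#include <cstdint>
#include <optional>
#include <string>
#include <variant>

// Stand-in for an MLIR Attribute holding a zero of some element type.
using ZeroAttr = std::variant<int64_t, double>;

std::optional<ZeroAttr> makeZeroFor(const std::string &dtype) {
  // Immediately-invoked lambda: compute the optional in one expression,
  // keeping the fall-through-to-nullopt logic in one place.
  auto maybeZero = [&]() -> std::optional<ZeroAttr> {
    if (dtype == "int")
      return ZeroAttr{int64_t{0}};
    if (dtype == "float")
      return ZeroAttr{0.0};
    return std::nullopt; // caller reports "expected integer or float"
  }();
  return maybeZero;
}
```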
@@ -1531,18 +1531,14 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
            return failure();
          }
        } else {
-         // The default axes value is the range from 0 to the number of
-         // dimensions
+         // The default axes value is the range from 0 to the size of first
+         // dimension of `starts` and `ends`.
          Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
-         auto defaultAxesType = Torch::ValueTensorType::get(
-             context, ArrayRef<int64_t>{operandTy.getRank()},
-             rewriter.getIntegerType(64, /*signed*/ 1));
          Value arangeLength = rewriter.create<Torch::ConstantIntOp>(
              loc, rewriter.getType<Torch::IntType>(),
-             rewriter.getIntegerAttr(rewriter.getIntegerType(64),
-                                     operandTy.getRank()));
+             rewriter.getIntegerAttr(rewriter.getIntegerType(64), startSize));
          axes = rewriter.create<Torch::AtenArangeOp>(
-             loc, defaultAxesType, arangeLength, none, none, none, none);
+             loc, startsTorchTy, arangeLength, none, none, none, none);
        }

        // Binding `steps` from its arguments or through a default value
@@ -1553,22 +1549,18 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
          }
        } else {
          // The default `steps` value is a 1d tensor filled with ones with a
-         // size of the dimension of the operand
+         // size equal to the size of `starts` and `ends`.
          Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
-         auto defaultStepsType = Torch::ValueTensorType::get(
-             context, ArrayRef<int64_t>{operandTy.getRank()},
-             rewriter.getIntegerType(64, /*signed*/ 1));
          Value sizeStepInput = rewriter.create<Torch::ConstantIntOp>(
              loc, rewriter.getType<Torch::IntType>(),
-             rewriter.getIntegerAttr(rewriter.getIntegerType(64),
-                                     operandTy.getRank()));
+             rewriter.getIntegerAttr(rewriter.getIntegerType(64), startSize));
          Value sizeStepsInput = rewriter.create<Torch::PrimListConstructOp>(
              loc,
              Torch::ListType::get(
                  Torch::IntType::get(binder.op->getContext())),
              sizeStepInput);
          steps = rewriter.create<Torch::AtenOnesOp>(
-             loc, defaultStepsType, sizeStepsInput, none, none, none, none);
+             loc, startsTorchTy, sizeStepsInput, none, none, none, none);
        }

        if (!(endsTy.getRank() == 1 && startsTy.getRank() == 1 &&
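Both Slice hunks replace rank-sized defaults with `startSize`-sized ones. Separated from the IR-builder calls, the shape logic is just this (a sketch; `startSize` stands for the length of the `starts` input):

```cpp
#include <cstdint>
#include <numeric>
#include <vector>

struct SliceDefaults {
  std::vector<int64_t> axes;  // 0, 1, ..., startSize - 1 (the arange above)
  std::vector<int64_t> steps; // all ones (the ones tensor above)
};

SliceDefaults makeSliceDefaults(int64_t startSize) {
  SliceDefaults d;
  d.axes.resize(static_cast<size_t>(startSize));
  std::iota(d.axes.begin(), d.axes.end(), int64_t{0});
  d.steps.assign(static_cast<size_t>(startSize), 1);
  return d;
}
```

With `starts` and `ends` of length 1 over a rank-3 tensor, as in the new lit test below, both defaults therefore have one element each, matching the `axes` length rather than the data rank.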
@@ -699,13 +699,25 @@ func.func @test_averagepool_2d_ceil(%arg0: !torch.vtensor<[1,1,4,4],f32>) -> !to

 // CHECK-LABEL: @test_averagepool_3d_default
 func.func @test_averagepool_3d_default(%arg0: !torch.vtensor<[1,3,32,32,32],f32>) -> !torch.vtensor<[1,3,31,31,31],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: torch.aten.avg_pool3d %arg0, %0, %2, %1, %false, %false_2, %none : !torch.vtensor<[1,3,32,32,32],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,3,31,31,31],f32>
+  // CHECK: torch.aten.avg_pool3d %arg0, %0, %2, %1, %false, %false{{.*}}, %none : !torch.vtensor<[1,3,32,32,32],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,3,31,31,31],f32>
   %0 = torch.operator "onnx.AveragePool"(%arg0) {torch.onnx.kernel_shape = [2 : si64, 2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32,32],f32>) -> !torch.vtensor<[1,3,31,31,31],f32>
   return %0 : !torch.vtensor<[1,3,31,31,31],f32>
 }

 // -----

+// CHECK-LABEL: @test_averagepool_with_padding
+// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[1,20,64,48],f32>
+// CHECK: torch.aten.avg_pool2d %[[ARG]], {{.*}} : !torch.vtensor<[1,20,64,48],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,20,32,24],f32>
+
+func.func @test_averagepool_with_padding(%arg0: !torch.vtensor<[1,20,64,48],f32>) -> !torch.vtensor<[1,20,32,24],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 19 : si64} {
+
+  %0 = torch.operator "onnx.AveragePool"(%arg0) {torch.onnx.ceil_mode = 0 : si64, torch.onnx.kernel_shape = [2 : si64, 2 : si64], torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 0 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,20,64,48],f32>) -> !torch.vtensor<[1,20,32,24],f32>
+  return %0 : !torch.vtensor<[1,20,32,24],f32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_conv_with_strides_no_padding
 func.func @test_conv_with_strides_no_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, %arg1: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,3,2],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[C0:.*]] = torch.constant.int 0
@@ -447,6 +447,24 @@ func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4],

 // -----

+// CHECK-LABEL: @test_pad_optional_constant
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32>
+// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[4],si64>
+// CHECK: %[[CONST_STR:.*]] = torch.constant.str "constant"
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK: %[[SEVEN:.*]] = torch.constant.int 7
+// CHECK: %[[DTYPE:.*]] = torch.aten.to.dtype %0, %[[SEVEN]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f64>
+// CHECK: %[[ITEM:.*]] = torch.aten.item %[[DTYPE]] : !torch.vtensor<[],f64> -> !torch.float
+// CHECK: torch.aten.pad %[[ARG0]], %{{.*}}, %[[CONST_STR]], %[[ITEM]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32>
+
+func.func @test_pad_optional_constant(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+  %0 = torch.operator "onnx.Pad"(%arg0, %arg1) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32>
+  return %0 : !torch.vtensor<[5,4],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_pow
 func.func @test_pow(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
@@ -1205,6 +1205,31 @@ func.func @test_slice_default_axes_and_slices(%arg0: !torch.vtensor<[20,10,5],f3

 // -----

+// CHECK-LABEL: @test_slice_default_axes_and_steps
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[20,10,5],f32>,
+// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[1],si64>,
+// CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor<[1],si64>
+
+// CHECK: %[[ZERO0:.*]] = torch.constant.int 0
+// CHECK: %[[ZERO1:.*]] = torch.constant.int 0
+// CHECK: %[[SCALAR:.*]] = torch.prim.NumToTensor.Scalar %[[ZERO1]] : !torch.int -> !torch.vtensor<[1],si64>
+// CHECK: %[[SELECT0:.*]] = torch.aten.index_select %[[ARG1]], %[[ZERO0]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+// CHECK: %[[ITEM0:.*]] = torch.aten.item %[[SELECT0]] : !torch.vtensor<[1],si64> -> !torch.int
+// CHECK: %[[SELECT1:.*]] = torch.aten.index_select %[[ARG2]], %[[ZERO0]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+// CHECK: %[[ITEM1:.*]] = torch.aten.item %[[SELECT1]] : !torch.vtensor<[1],si64> -> !torch.int
+// CHECK: %[[SELECT2:.*]] = torch.aten.index_select %{{.*}}, %[[ZERO0]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+// CHECK: %[[ITEM2:.*]] = torch.aten.item %[[SELECT2]] : !torch.vtensor<[1],si64> -> !torch.int
+// CHECK: %[[SELECT3:.*]] = torch.aten.index_select %{{.*}}, %[[ZERO0]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+// CHECK: %[[ITEM3:.*]] = torch.aten.item %[[SELECT3]] : !torch.vtensor<[1],si64> -> !torch.int
+// CHECK: torch.aten.slice.Tensor %[[ARG0]], %[[ITEM2]], %[[ITEM0]], %[[ITEM1]], %[[ITEM3]] : !torch.vtensor<[20,10,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,1],f32>
+
+func.func @test_slice_default_axes_and_steps(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1: !torch.vtensor<[1],si64>, %arg2: !torch.vtensor<[1],si64>) -> !torch.vtensor<[20,10,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
+  %0 = torch.operator "onnx.Slice"(%arg0, %arg1, %arg2) : (!torch.vtensor<[20,10,5],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[20,10,1],f32>
+  return %0 : !torch.vtensor<[20,10,1],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_slice_default_steps
 func.func @test_slice_default_steps(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1: !torch.vtensor<[3],si64>, %arg2: !torch.vtensor<[3],si64>, %arg3: !torch.vtensor<[3],si64>) -> !torch.vtensor<[20,10,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   //CHECK: %[[NONE:.*]] = torch.constant.none