mirror of https://github.com/llvm/torch-mlir
[onnx] Add support for `auto_pad` in `onnx.Conv` (#3670)
Add logic for the `auto_pad` attribute in the conversion of `onnx.Conv` to the Torch dialect. Add lit tests covering different configurations of `auto_pad`.
pull/3699/head · parent b5d95ff399 · commit b35675a78e
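For `SAME_UPPER` / `SAME_LOWER`, the patch derives the total padding per spatial axis from the ONNX rule `output_size = ceil(input_size / stride)`, splitting any odd remainder toward the end (`SAME_UPPER`) or the beginning (`SAME_LOWER`). Below is a minimal, standalone C++ sketch of that arithmetic; `computeAutoPad` and the driver in `main()` are illustrative names and take kernel sizes directly (the patch itself reads them from `weightShape[dimIdx + 2]`), so this is not the torch-mlir code, just the same formula in isolation.

```cpp
// Standalone sketch of the auto_pad arithmetic; computeAutoPad and the
// driver in main() are illustrative, not torch-mlir APIs.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Returns ONNX-style pads [x1_begin, x2_begin, ..., x1_end, x2_end, ...] for
// the spatial axes of a Conv, given the auto_pad mode.
std::vector<int64_t> computeAutoPad(const std::string &autoPad,
                                    const std::vector<int64_t> &inputSize,
                                    const std::vector<int64_t> &kernelSize,
                                    const std::vector<int64_t> &strides,
                                    const std::vector<int64_t> &dilations) {
  const size_t spatialRank = inputSize.size();
  std::vector<int64_t> pads(2 * spatialRank, 0);
  if (autoPad == "VALID")
    return pads; // VALID means no implicit padding.
  const bool isSameLower = (autoPad == "SAME_LOWER");
  for (size_t i = 0; i < spatialRank; ++i) {
    // Kernel extent once dilation is applied.
    const int64_t dilatedKernel = dilations[i] * (kernelSize[i] - 1) + 1;
    // SAME_* targets outputSize = ceil(inputSize / stride); totalPad is the
    // smallest padding that makes this output size reachable.
    const int64_t outSize = (inputSize[i] + strides[i] - 1) / strides[i];
    int64_t totalPad =
        (outSize - 1) * strides[i] + dilatedKernel - inputSize[i];
    totalPad = std::max<int64_t>(totalPad, 0);
    // An odd totalPad puts the extra pixel at the start for SAME_LOWER and
    // at the end for SAME_UPPER.
    pads[i] = isSameLower ? (totalPad + 1) / 2 : totalPad / 2;
    pads[spatialRank + i] = totalPad - pads[i];
  }
  return pads;
}

int main() {
  // Shapes from the asymmetric lit test below: 15x9 input, 4x4 kernel,
  // stride 4, SAME_UPPER -> pads [0, 1, 1, 2], i.e. the input is padded
  // from 15x9 to 16x12 before the convolution.
  for (int64_t p :
       computeAutoPad("SAME_UPPER", {15, 9}, {4, 4}, {4, 4}, {1, 1}))
    std::cout << p << ' ';
  std::cout << '\n'; // prints: 0 1 1 2
}
```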
@@ -1292,14 +1292,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
       });
   patterns.onOp(
       "Conv", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-        std::string autoPad;
-        if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
-          return failure();
-        if (autoPad != "NOTSET") {
-          // TODO: Add support for `auto_pad` != "NOTSET"
-          return rewriter.notifyMatchFailure(
-              binder.op, "unsupported conversion: auto_pad != NOTSET");
-        }
         Torch::ValueTensorType resultType;
         Value input, weight;
         int64_t group;
@@ -1349,20 +1341,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
           defaultStrides.push_back(1);
           defaultDilations.push_back(1);
         }
-        // Padding for the beginning and ending along each spatial axis, it can
-        // take any value greater than or equal to 0. The value represent the
-        // number of pixels added to the beginning and end part of the
-        // corresponding axis. pads format should be as follow [x1_begin,
-        // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added
-        // at the beginning of axis i and xi_end, the number of pixels added at
-        // the end of axis i.
-        if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) {
-          return failure();
-        }
-        if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) {
-          return rewriter.notifyMatchFailure(
-              binder.op, "padding list size does not match the number of axes");
-        }
         if (binder.s64IntegerArrayAttr(dilations, "dilations",
                                        defaultDilations)) {
           return failure();
@@ -1379,6 +1357,46 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
           return rewriter.notifyMatchFailure(
              binder.op, "strides list size does not match the number of axes");
        }
+        std::string autoPad;
+        if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
+          return failure();
+        auto inputTensorType = cast<Torch::ValueTensorType>(input.getType());
+        // Padding for the beginning and ending along each spatial axis, it can
+        // take any value greater than or equal to 0. The value represent the
+        // number of pixels added to the beginning and end part of the
+        // corresponding axis. pads format should be as follow [x1_begin,
+        // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added
+        // at the beginning of axis i and xi_end, the number of pixels added at
+        // the end of axis i.
+        if (autoPad == "NOTSET") {
+          if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) {
+            return failure();
+          }
+        } else if (autoPad == "VALID") {
+          padding = defaultPadding;
+        } else {
+          const bool isSameLower = autoPad == "SAME_LOWER";
+          const unsigned spatialRank = rank - 2;
+          ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
+          padding.resize_for_overwrite(2 * spatialRank);
+          for (unsigned dimIdx = 0; dimIdx < spatialRank; dimIdx++) {
+            const int64_t dilatedKernelSize =
+                dilations[dimIdx] * (weightShape[dimIdx + 2] - 1) + 1;
+            int64_t totalPad = ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) /
+                                    strides[dimIdx] -
+                                1) *
+                                   strides[dimIdx] +
+                               dilatedKernelSize - inputShape[dimIdx + 2];
+            totalPad = totalPad >= 0 ? totalPad : 0;
+            padding[dimIdx] =
+                isSameLower ? ((totalPad + 1) / 2) : (totalPad / 2);
+            padding[spatialRank + dimIdx] = totalPad - padding[dimIdx];
+          }
+        }
+        if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) {
+          return rewriter.notifyMatchFailure(
+              binder.op, "padding list size does not match the number of axes");
+        }
+
         SmallVector<Value> cstPadding, cstStrides, cstDilations,
             cstOutputPadding;
@@ -1452,8 +1470,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
           Value modeVal = rewriter.create<Torch::ConstantStrOp>(
               binder.getLoc(), rewriter.getStringAttr("constant"));
           Value constantValue;
-          auto inputTensorType =
-              cast<Torch::ValueTensorType>(input.getType());
           if (isa<IntegerType>(inputTensorType.getDtype()))
             constantValue = rewriter.create<Torch::ConstantIntOp>(
                 binder.getLoc(), rewriter.getI64IntegerAttr(0));
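Reading the new lit tests (next hunk) against this formula: for `test_conv_with_autopad` (12x7 input, 2x3 kernel, strides 4x3, `SAME_LOWER`) the computed pads are symmetric per axis ([0, 0] for height, [1, 1] for width), so they feed straight into the `padding` list of `torch.aten.convolution`; for the two asymmetric tests (15x9 input, 4x4 kernel, stride 4) the total pad is odd on both axes, so the same asymmetric-padding lowering used for explicit pads first emits a `torch.aten.pad` that grows the input to 16x12 and then convolves with zero `padding`.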
@@ -1062,6 +1062,93 @@ func.func @test_conv_with_asymmetric_padding(%arg0: !torch.vtensor<[1,1,7,5],f32

 // -----

+// CHECK-LABEL: @test_conv_with_autopad
+func.func @test_conv_with_autopad(%arg0: !torch.vtensor<[1,1,12,7],f32>, %arg1: !torch.vtensor<[1,1,2,3],f32>) -> !torch.vtensor<[1,1,3,3],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[C1:.*]] = torch.constant.int 0
+  // CHECK: %[[C1_0:.*]] = torch.constant.int 1
+  // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[C1_1:.*]] = torch.constant.int 1
+  // CHECK: %[[C1_2:.*]] = torch.constant.int 1
+  // CHECK: %[[C2:.*]] = torch.constant.int 4
+  // CHECK: %[[C2_0:.*]] = torch.constant.int 3
+  // CHECK: %[[C0:.*]] = torch.constant.int 0
+  // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool false
+  // CHECK: %[[BIAS:.*]] = torch.constant.none
+  // CHECK: %[[GROUPS:.*]] = torch.constant.int 1
+  // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,12,7],f32>, !torch.vtensor<[1,1,2,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,3,3],f32>
+  %0 = torch.operator "onnx.Conv"(%arg0, %arg1) {torch.onnx.kernel_shape = [2 : si64, 3 : si64], torch.onnx.auto_pad = "SAME_LOWER", torch.onnx.strides = [4 : si64, 3 : si64]} : (!torch.vtensor<[1,1,12,7],f32>, !torch.vtensor<[1,1,2,3],f32>) -> !torch.vtensor<[1,1,3,3],f32>
+  return %0 : !torch.vtensor<[1,1,3,3],f32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_conv_with_autopad_asymmetric
+func.func @test_conv_with_autopad_asymmetric(%arg0: !torch.vtensor<[1,1,15,9],f32>, %arg1: !torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,4,3],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[int1:.*]] = torch.constant.int 1
+  // CHECK: %[[int2:.*]] = torch.constant.int 2
+  // CHECK: %[[int0:.*]] = torch.constant.int 0
+  // CHECK: %[[int0_0:.*]] = torch.constant.int 0
+  // CHECK: %[[int1_1:.*]] = torch.constant.int 1
+  // CHECK: %[[int0_2:.*]] = torch.constant.int 0
+  // CHECK: %[[FakePADS:.*]] = torch.prim.ListConstruct %[[int0]], %[[int0_2]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[OGPADS:.*]] = torch.prim.ListConstruct %[[int1]], %[[int2]], %[[int0_0]], %[[int1_1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[str:.*]] = torch.constant.str "constant"
+  // CHECK: %[[float0:.*]] = torch.constant.float 0.000
+  // CHECK: %[[PrePad:.*]] = torch.aten.pad %arg0, %[[OGPADS]], %[[str]], %[[float0]] : !torch.vtensor<[1,1,15,9],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[1,1,16,12],f32>
+  // CHECK: %[[C1_1:.*]] = torch.constant.int 1
+  // CHECK: %[[C1_2:.*]] = torch.constant.int 1
+  // CHECK: %[[C4:.*]] = torch.constant.int 4
+  // CHECK: %[[C4_0:.*]] = torch.constant.int 4
+  // CHECK: %[[C0:.*]] = torch.constant.int 0
+  // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C4]], %[[C4_0]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool false
+  // CHECK: %[[BIAS:.*]] = torch.constant.none
+  // CHECK: %[[GROUPS:.*]] = torch.constant.int 1
+  // CHECK: %[[Conv:.*]] = torch.aten.convolution %[[PrePad]], %arg1, %[[BIAS]], %[[STRIDE]], %[[FakePADS]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,16,12],f32>, !torch.vtensor<[1,1,4,4],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,4,3],f32>
+  // CHECK: return %[[Conv]]
+  %0 = torch.operator "onnx.Conv"(%arg0, %arg1) {torch.onnx.kernel_shape = [4 : si64, 4 : si64], torch.onnx.auto_pad = "SAME_UPPER", torch.onnx.strides = [4 : si64, 4 : si64]} : (!torch.vtensor<[1,1,15,9],f32>, !torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,4,3],f32>
+  return %0 : !torch.vtensor<[1,1,4,3],f32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_conv_with_autopad_asymmetric_lower
+func.func @test_conv_with_autopad_asymmetric_lower(%arg0: !torch.vtensor<[1,1,15,9],f32>, %arg1: !torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,4,3],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[int2:.*]] = torch.constant.int 2
+  // CHECK: %[[int1:.*]] = torch.constant.int 1
+  // CHECK: %[[int0:.*]] = torch.constant.int 0
+  // CHECK: %[[int1_0:.*]] = torch.constant.int 1
+  // CHECK: %[[int0_1:.*]] = torch.constant.int 0
+  // CHECK: %[[int0_2:.*]] = torch.constant.int 0
+  // CHECK: %[[FakePADS:.*]] = torch.prim.ListConstruct %[[int0]], %[[int0_2]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[OGPADS:.*]] = torch.prim.ListConstruct %[[int2]], %[[int1]], %[[int1_0]], %[[int0_1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[str:.*]] = torch.constant.str "constant"
+  // CHECK: %[[float0:.*]] = torch.constant.float 0.000
+  // CHECK: %[[PrePad:.*]] = torch.aten.pad %arg0, %[[OGPADS]], %[[str]], %[[float0]] : !torch.vtensor<[1,1,15,9],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[1,1,16,12],f32>
+  // CHECK: %[[C1_1:.*]] = torch.constant.int 1
+  // CHECK: %[[C1_2:.*]] = torch.constant.int 1
+  // CHECK: %[[C4:.*]] = torch.constant.int 4
+  // CHECK: %[[C4_0:.*]] = torch.constant.int 4
+  // CHECK: %[[C0:.*]] = torch.constant.int 0
+  // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C4]], %[[C4_0]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool false
+  // CHECK: %[[BIAS:.*]] = torch.constant.none
+  // CHECK: %[[GROUPS:.*]] = torch.constant.int 1
+  // CHECK: %[[Conv:.*]] = torch.aten.convolution %[[PrePad]], %arg1, %[[BIAS]], %[[STRIDE]], %[[FakePADS]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,16,12],f32>, !torch.vtensor<[1,1,4,4],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,4,3],f32>
+  // CHECK: return %[[Conv]]
+  %0 = torch.operator "onnx.Conv"(%arg0, %arg1) {torch.onnx.kernel_shape = [4 : si64, 4 : si64], torch.onnx.auto_pad = "SAME_LOWER", torch.onnx.strides = [4 : si64, 4 : si64]} : (!torch.vtensor<[1,1,15,9],f32>, !torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,4,3],f32>
+  return %0 : !torch.vtensor<[1,1,4,3],f32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_conv_with_bias_strides_padding
 func.func @test_conv_with_bias_strides_padding(%arg0: !torch.vtensor<[?,?,224,224],f32>, %arg1: !torch.vtensor<[64,3,7,7],f32>, %arg2: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,64,112,112],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[C3:.*]] = torch.constant.int 3