Adds onnx ConvTranspose support for autopadding. (#3797)

Adds onnx ConvTranspose support for autopadding
(https://github.com/nod-ai/SHARK-ModelDev/issues/839).

- Adds support for the attribute auto_pad="SAME_UPPER" or "SAME_LOWER", which
automatically calculates the padding of the input based on the output shape
(see the sketch after the note below).
- Adds support, during auto-padding, for output_shape=[H,W], which overrides
the default output shape of input_shape[i]*stride[i] (for spatial dimensions
only).
- Adds lit tests for auto-padding.
- Tests are added in https://github.com/nod-ai/SHARK-TestSuite/pull/370


NOTE: ConvTranspose still doesn't support asymmetric padding, so several of
the original ONNX tests still won't pass.
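For reference, here is a minimal standalone sketch of the per-axis padding
computation that the auto-padding path performs; the function name and
signature are illustrative and not part of the change:

// Sketch: per-axis begin/end padding for auto_pad="SAME_UPPER"/"SAME_LOWER",
// mirroring the formula used in the lowering. Names are illustrative.
#include <cstdint>
#include <utility>

std::pair<int64_t, int64_t>
autoPadForAxis(int64_t inputDim, int64_t kernelDim, int64_t stride,
               int64_t dilation, int64_t outputPadding, int64_t outputDim,
               bool sameUpper) {
  // Total padding needed so the transposed convolution yields outputDim;
  // outputDim defaults to inputDim * stride when output_shape is absent.
  int64_t total = stride * (inputDim - 1) + outputPadding +
                  ((kernelDim - 1) * dilation + 1) - outputDim;
  int64_t half = total / 2;
  int64_t rest = total - half;
  // SAME_UPPER places `half` at the beginning and the remainder at the end;
  // SAME_LOWER is the reverse. The lowering currently rejects odd totals
  // (i.e. asymmetric padding).
  return sameUpper ? std::make_pair(half, rest) : std::make_pair(rest, half);
}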
David Tanner 2024-10-18 13:31:33 -04:00 committed by GitHub
parent 9c7067649b
commit 02327af998
2 changed files with 131 additions and 22 deletions

@@ -1690,20 +1690,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         std::string autoPad;
         if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
           return failure();
-        if (autoPad != "NOTSET") {
-          // TODO: Add support for `auto_pad` != "NOTSET"
-          return rewriter.notifyMatchFailure(
-              binder.op, "unsupported conversion: auto_pad != NOTSET");
-        }
-        SmallVector<int64_t> outputShape;
-        if (binder.s64IntegerArrayAttr(outputShape, "output_shape", {}))
-          return failure();
-        if (outputShape.size()) {
-          // TODO: Add support for non-None output_shape value.
-          return rewriter.notifyMatchFailure(
-              binder.op,
-              "unsupported conversion: output_shape should be absent");
-        }
         Torch::ValueTensorType resultType;
         Value input, weight;
         int64_t group;
@@ -1737,6 +1723,10 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
               }
             }
           }
+        } else {
+          for (unsigned i = 0; i < weightShape.size() - 2; i++) {
+            kernelShape.push_back(weightShape[i + 2]);
+          }
         }
         // Determine the rank of input tensor.
@@ -1746,7 +1736,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
                                            "Unimplemented: unranked tensor");
         unsigned rank = *maybeRank;
-        SmallVector<int64_t> padding, strides, dilations, outputPadding;
+        SmallVector<int64_t> padding, strides, dilations, outputPadding,
+            outputShape;
         SmallVector<int64_t> defaultPadding, defaultStrides, defaultDilations,
             defaultOutputPadding;
         for (unsigned i = 0; i < rank - 2; i++) {
@@ -1762,13 +1753,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added
         // at the beginning of axis i and xi_end, the number of pixels added at
         // the end of axis i.
-        if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) {
-          return failure();
-        }
-        if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) {
-          return rewriter.notifyMatchFailure(
-              binder.op, "padding list size does not match the number of axes");
-        }
         if (binder.s64IntegerArrayAttr(dilations, "dilations",
                                        defaultDilations)) {
           return failure();
@@ -1794,7 +1778,60 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
               binder.op,
               "output_padding list size does not match the number of axes");
         }
+        auto inputTensorType = cast<Torch::ValueTensorType>(input.getType());
+        if (!inputTensorType || !inputTensorType.hasSizes()) {
+          return rewriter.notifyMatchFailure(
+              binder.op, "Expected input type having sizes");
+        }
+        ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
+        if (autoPad == "VALID") {
+          // Zero padding.
+          padding = defaultPadding;
+        } else if (autoPad == "NOTSET") {
+          // Explicit padding; read pads with defaults.
+          if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding))
+            return failure();
+        } else { // autopad == SAME_UPPER or SAME_LOWER
+          // Auto-padding; output_shape defaults to input_shape * strides.
+          SmallVector<int64_t> defaultOutputShape;
+          for (unsigned i = 0; i < rank - 2; i++) {
+            defaultOutputShape.push_back(inputShape[2 + i] * strides[i]);
+          }
+          if (binder.s64IntegerArrayAttr(outputShape, "output_shape",
+                                         defaultOutputShape))
+            return failure();
+          SmallVector<int64_t> paddingEnd;
+          for (unsigned i = 0; i < rank - 2; i++) {
+            int64_t totalPadding =
+                strides[i] * (inputShape[2 + i] - 1) + outputPadding[i] +
+                ((kernelShape[i] - 1) * dilations[i] + 1) - outputShape[i];
+            if (totalPadding % 2) {
+              // TODO: Add support for different padding values for the
+              // beginning and ending along each spatial axis.
+              return rewriter.notifyMatchFailure(
+                  binder.op,
+                  "unsupported conversion: the combination of stride, "
+                  "input_shape, kernel_shape, dilation, output_padding and "
+                  "output_shape caused auto-padding to produce asymmetric "
+                  "padding which isn't currently supported.");
+            }
+            int64_t half = totalPadding / 2;
+            int64_t remainder = totalPadding - half;
+            if (autoPad == "SAME_UPPER") {
+              padding.push_back(half);
+              paddingEnd.push_back(remainder);
+            } else {
+              padding.push_back(remainder);
+              paddingEnd.push_back(half);
+            }
+          }
+          padding.insert(padding.end(), paddingEnd.begin(), paddingEnd.end());
+        }
+        if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) {
+          return rewriter.notifyMatchFailure(
+              binder.op, "padding list size does not match the number of axes");
+        }
         SmallVector<Value> cstPadding, cstStrides, cstDilations,
             cstOutputPadding;
         if (padding.size() != 2 * (rank - 2)) {
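As a quick sanity check against the new lit tests below: with a 1x1x3x3 input, a 4x4 kernel, stride 2, dilation 1, and output_padding 0, the default output_shape is [3*2, 3*2] = [6, 6], so the per-axis total padding is 2*(3-1) + 0 + ((4-1)*1 + 1) - 6 = 2, split as 1 at each end. That is the [1, 1] padding list the SAME_UPPER and SAME_LOWER CHECK lines expect. In the VALID case the padding stays zero and the result type is [1,2,8,8].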

@@ -1329,6 +1329,78 @@ func.func @test_convtranspose(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torc
 // -----

+// CHECK-LABEL: @test_convtranspose_autopad_same_upper
+func.func @test_convtranspose_autopad_same_upper(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,4,4],f32>) -> !torch.vtensor<[1,2,6,6],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "user-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[C1:.*]] = torch.constant.int 1
+  // CHECK: %[[C1_0:.*]] = torch.constant.int 1
+  // CHECK: %[[C1_1:.*]] = torch.constant.int 1
+  // CHECK: %[[C1_2:.*]] = torch.constant.int 1
+  // CHECK: %[[C2:.*]] = torch.constant.int 2
+  // CHECK: %[[C2_3:.*]] = torch.constant.int 2
+  // CHECK: %[[C0:.*]] = torch.constant.int 0
+  // CHECK: %[[C0_4:.*]] = torch.constant.int 0
+  // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_4]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool true
+  // CHECK: %[[BIAS:.*]] = torch.constant.none
+  // CHECK: %[[GROUPS:.*]] = torch.constant.int 1
+  // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,4,4],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,2,6,6],f32>
+  %4 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.auto_pad="SAME_UPPER", torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,4,4],f32>) -> !torch.vtensor<[1,2,6,6],f32>
+  return %4 : !torch.vtensor<[1,2,6,6],f32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_convtranspose_autopad_same_lower
+func.func @test_convtranspose_autopad_same_lower(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,4,4],f32>) -> !torch.vtensor<[1,2,6,6],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "user-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[C1:.*]] = torch.constant.int 1
+  // CHECK: %[[C1_0:.*]] = torch.constant.int 1
+  // CHECK: %[[C1_1:.*]] = torch.constant.int 1
+  // CHECK: %[[C1_2:.*]] = torch.constant.int 1
+  // CHECK: %[[C2:.*]] = torch.constant.int 2
+  // CHECK: %[[C2_3:.*]] = torch.constant.int 2
+  // CHECK: %[[C0:.*]] = torch.constant.int 0
+  // CHECK: %[[C0_4:.*]] = torch.constant.int 0
+  // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_4]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool true
+  // CHECK: %[[BIAS:.*]] = torch.constant.none
+  // CHECK: %[[GROUPS:.*]] = torch.constant.int 1
+  // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,4,4],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,2,6,6],f32>
+  %4 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.auto_pad="SAME_LOWER", torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,4,4],f32>) -> !torch.vtensor<[1,2,6,6],f32>
+  return %4 : !torch.vtensor<[1,2,6,6],f32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_convtranspose_autopad_valid
+func.func @test_convtranspose_autopad_valid(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,4,4],f32>) -> !torch.vtensor<[1,2,8,8],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "user-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[C0:.*]] = torch.constant.int 0
+  // CHECK: %[[C0_0:.*]] = torch.constant.int 0
+  // CHECK: %[[C1:.*]] = torch.constant.int 1
+  // CHECK: %[[C1_1:.*]] = torch.constant.int 1
+  // CHECK: %[[C2:.*]] = torch.constant.int 2
+  // CHECK: %[[C2_2:.*]] = torch.constant.int 2
+  // CHECK: %[[C0_3:.*]] = torch.constant.int 0
+  // CHECK: %[[C0_4:.*]] = torch.constant.int 0
+  // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_1]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_2]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0_3]], %[[C0_4]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool true
+  // CHECK: %[[BIAS:.*]] = torch.constant.none
+  // CHECK: %[[GROUPS:.*]] = torch.constant.int 1
+  // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,4,4],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,2,8,8],f32>
+  %4 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.auto_pad="VALID", torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,4,4],f32>) -> !torch.vtensor<[1,2,8,8],f32>
+  return %4 : !torch.vtensor<[1,2,8,8],f32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_batchnorm_epsilon
 func.func @test_batchnorm_epsilon(%arg0: !torch.vtensor<[2,3,4,5],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>, %arg3: !torch.vtensor<[3],f32>, %arg4: !torch.vtensor<[3],f32>) -> !torch.vtensor<[2,3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[FALSE:.*]] = torch.constant.bool false