mirror of https://github.com/llvm/torch-mlir
Added support for Maxpool (Autopad) (#3774)
Added autopad support and passed 3 tests: test_maxpool_2d_precomputed_same_upper, test_maxpool_2d_same_lower, test_maxpool_2d_same_upper. Addresses: https://github.com/nod-ai/SHARK-ModelDev/issues/843. Two attributes are yet to be completed: storage_order, indices output.
parent 2f9a68cc1e
commit d6feb2179c
@@ -1087,9 +1087,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
      if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
        return rewriter.notifyMatchFailure(binder.op,
                                           "auto_pad bind failure");
      if (autoPad != "NOTSET")
        return rewriter.notifyMatchFailure(
            binder.op, "unsupported conversion: auto_pad != NOTSET");

      Torch::ValueTensorType resultTypeOut;
      Value operand;

@@ -1136,6 +1133,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
        return rewriter.notifyMatchFailure(binder.op,
                                           "dilations bind failure");

      // set default padding
      if (padding.empty())
        padding.resize(spatial, 0);
      if (strides.empty())

@@ -1143,6 +1141,34 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
      if (dilations.empty())
        dilations.resize(spatial, 1);

      auto inputTensorType = cast<Torch::ValueTensorType>(operand.getType());

      // Padding for the beginning and ending along each spatial axis, it can
      // take any value greater than or equal to 0. The value represent the
      // number of pixels added to the beginning and end part of the
      // corresponding axis. pads format should be as follow [x1_begin,
      // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added
      // at the beginning of axis i and xi_end, the number of pixels added at
      // the end of axis i.
      if (autoPad != "NOTSET" && autoPad != "VALID") {
        const bool isSameLower = autoPad == "SAME_LOWER";
        ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
        padding.resize_for_overwrite(2 * spatial);
        for (unsigned dimIdx = 0; dimIdx < spatial; dimIdx++) {
          const int64_t dilatedKernelSize =
              dilations[dimIdx] * (kernel[dimIdx] - 1) + 1;
          int64_t totalPad = ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) /
                                  strides[dimIdx] -
                              1) *
                                 strides[dimIdx] +
                             dilatedKernelSize - inputShape[dimIdx + 2];
          totalPad = totalPad >= 0 ? totalPad : 0;
          padding[dimIdx] =
              isSameLower ? ((totalPad + 1) / 2) : (totalPad / 2);
          padding[spatial + dimIdx] = totalPad - padding[dimIdx];
        }
      }

      // If the padding is symmetric we can push the padding operation to the
      // torch operator.
      if (padding.size() == static_cast<size_t>(2 * spatial)) {

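For readers checking the arithmetic, below is a standalone sketch of the SAME_UPPER/SAME_LOWER pad computation added in the hunk above, applied to the shapes used in the new tests further down. The function name computeSamePads and the driver in main are illustrative only and are not part of the patch.

    // Standalone sketch of the auto_pad computation above (illustrative only).
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // For one spatial dim: output = ceil(input / stride), and totalPad is the
    // extra extent needed so the dilated kernel covers that many windows.
    std::vector<int64_t> computeSamePads(int64_t input, int64_t kernel,
                                         int64_t stride, int64_t dilation,
                                         bool isSameLower) {
      const int64_t dilatedKernel = dilation * (kernel - 1) + 1;
      const int64_t outputSize = (input + stride - 1) / stride; // ceil div
      int64_t totalPad = (outputSize - 1) * stride + dilatedKernel - input;
      totalPad = totalPad >= 0 ? totalPad : 0;
      // SAME_LOWER puts the odd pixel at the beginning, SAME_UPPER at the end.
      const int64_t begin = isSameLower ? (totalPad + 1) / 2 : totalPad / 2;
      return {begin, totalPad - begin};
    }

    int main() {
      // test_maxpool_2d_same_lower / same_upper: input 32, kernel 2, stride 1.
      auto lower = computeSamePads(32, 2, 1, 1, /*isSameLower=*/true);
      auto upper = computeSamePads(32, 2, 1, 1, /*isSameLower=*/false);
      std::cout << lower[0] << "," << lower[1] << "\n"; // 1,0 (asymmetric)
      std::cout << upper[0] << "," << upper[1] << "\n"; // 0,1 (asymmetric)

      // test_maxpool_2d_precomputed_same_upper: input 5, kernel 3, stride 2.
      auto pre = computeSamePads(5, 3, 2, 1, /*isSameLower=*/false);
      std::cout << pre[0] << "," << pre[1] << "\n"; // 1,1 (symmetric)
      return 0;
    }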
@@ -730,6 +730,86 @@ func.func @test_maxpool_pad(%arg0: !torch.vtensor<[1,64,111,111],f32>) -> !torch
  return %0 : !torch.vtensor<[1,64,56,56],f32>
}

// -----

// CHECK-LABEL: func.func @test_maxpool_2d_same_lower
func.func @test_maxpool_2d_same_lower(%arg0: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,32,32],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
  // CHECK: %[[int1:.*]] = torch.constant.int 1
  // CHECK: %[[int0:.*]] = torch.constant.int 0
  // CHECK: %[[int1_0:.*]] = torch.constant.int 1
  // CHECK: %[[int0_1:.*]] = torch.constant.int 0
  // CHECK: %[[list0:.*]] = torch.prim.ListConstruct %[[int1]], %[[int0]], %[[int1_0]], %[[int0_1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[FLOAT0:.*]] = torch.constant.float -1.7976931348623157E+308
  // CHECK: %[[FUNC1:.*]] = torch.aten.constant_pad_nd %arg0, %[[list0]], %[[FLOAT0]] : !torch.vtensor<[1,3,32,32],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[1,3,33,33],f32>
  // CHECK: %[[int2:.*]] = torch.constant.int 2
  // CHECK: %[[int2_2:.*]] = torch.constant.int 2
  // CHECK: %[[list1:.*]] = torch.prim.ListConstruct %[[int2]], %[[int2_2]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[int0_3:.*]] = torch.constant.int 0
  // CHECK: %[[int0_4:.*]] = torch.constant.int 0
  // CHECK: %[[list2:.*]] = torch.prim.ListConstruct %[[int0_3]], %[[int0_4]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[int1_5:.*]] = torch.constant.int 1
  // CHECK: %[[int1_6:.*]] = torch.constant.int 1
  // CHECK: %[[list3:.*]] = torch.prim.ListConstruct %[[int1_5]], %[[int1_6]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[int1_7:.*]] = torch.constant.int 1
  // CHECK: %[[int1_8:.*]] = torch.constant.int 1
  // CHECK: %[[list4:.*]] = torch.prim.ListConstruct %[[int1_7]], %[[int1_8]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
  // CHECK: %[[FUNC6:.*]] = torch.aten.max_pool2d %[[FUNC1]], %[[list1]], %[[list3]], %[[list2]], %[[list4]], %[[FALSE]] : !torch.vtensor<[1,3,33,33],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,3,32,32],f32>
  %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.auto_pad = "SAME_LOWER", torch.onnx.kernel_shape = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,32,32],f32>
  return %0 : !torch.vtensor<[1,3,32,32],f32>
}

// -----

// CHECK-LABEL: func.func @test_maxpool_2d_same_upper
func.func @test_maxpool_2d_same_upper(%arg0: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,32,32],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[int0:.*]] = torch.constant.int 0
  // CHECK: %[[int1:.*]] = torch.constant.int 1
  // CHECK: %[[int0_0:.*]] = torch.constant.int 0
  // CHECK: %[[int1_1:.*]] = torch.constant.int 1
  // CHECK: %[[list0:.*]] = torch.prim.ListConstruct %[[int0]], %[[int1]], %[[int0_0]], %[[int1_1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[FLOAT0:.*]] = torch.constant.float -1.7976931348623157E+308
  // CHECK: %[[FUNC1:.*]] = torch.aten.constant_pad_nd %arg0, %[[list0]], %[[FLOAT0]] : !torch.vtensor<[1,3,32,32],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[1,3,33,33],f32>
  // CHECK: %[[int2:.*]] = torch.constant.int 2
  // CHECK: %[[int2_2:.*]] = torch.constant.int 2
  // CHECK: %[[list1:.*]] = torch.prim.ListConstruct %[[int2]], %[[int2_2]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[int0_3:.*]] = torch.constant.int 0
  // CHECK: %[[int0_4:.*]] = torch.constant.int 0
  // CHECK: %[[list2:.*]] = torch.prim.ListConstruct %[[int0_3]], %[[int0_4]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[int1_5:.*]] = torch.constant.int 1
  // CHECK: %[[int1_6:.*]] = torch.constant.int 1
  // CHECK: %[[list3:.*]] = torch.prim.ListConstruct %[[int1_5]], %[[int1_6]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[int1_7:.*]] = torch.constant.int 1
  // CHECK: %[[int1_8:.*]] = torch.constant.int 1
  // CHECK: %[[list4:.*]] = torch.prim.ListConstruct %[[int1_7]], %[[int1_8]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
  // CHECK: %[[FUNC6:.*]] = torch.aten.max_pool2d %[[FUNC1]], %[[list1]], %[[list3]], %[[list2]], %[[list4]], %[[FALSE]] : !torch.vtensor<[1,3,33,33],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,3,32,32],f32>
  %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.auto_pad = "SAME_UPPER", torch.onnx.kernel_shape = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,32,32],f32>
  return %0 : !torch.vtensor<[1,3,32,32],f32>
}

// -----

// CHECK-LABEL: func.func @test_maxpool_2d_precomputed_same_upper
func.func @test_maxpool_2d_precomputed_same_upper(%arg0: !torch.vtensor<[1,1,5,5],f32>) -> !torch.vtensor<[1,1,3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} {
  // CHECK: %[[int3:.*]] = torch.constant.int 3
  // CHECK: %[[int3_0:.*]] = torch.constant.int 3
  // CHECK: %[[list0:.*]] = torch.prim.ListConstruct %[[int3]], %[[int3_0]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[int1:.*]] = torch.constant.int 1
  // CHECK: %[[int1_1:.*]] = torch.constant.int 1
  // CHECK: %[[list1:.*]] = torch.prim.ListConstruct %[[int1]], %[[int1_1]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[int2:.*]] = torch.constant.int 2
  // CHECK: %[[int2_2:.*]] = torch.constant.int 2
  // CHECK: %[[list2:.*]] = torch.prim.ListConstruct %[[int2]], %[[int2_2]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[int1_3:.*]] = torch.constant.int 1
  // CHECK: %[[int1_4:.*]] = torch.constant.int 1
  // CHECK: %[[list3:.*]] = torch.prim.ListConstruct %[[int1_3]], %[[int1_4]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
  // CHECK: %[[FUNC4:.*]] = torch.aten.max_pool2d %arg0, %[[list0]], %[[list2]], %[[list1]], %[[list3]], %[[FALSE]] : !torch.vtensor<[1,1,5,5],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,1,3,3],f32>
  %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.auto_pad = "SAME_UPPER", torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,5,5],f32>) -> !torch.vtensor<[1,1,3,3],f32>
  return %0 : !torch.vtensor<[1,1,3,3],f32>
}

// -----
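The three tests exercise both outcomes of the pad computation: an asymmetric split (SAME_LOWER or SAME_UPPER on a 32x32 input with a 2x2 kernel gives a total pad of 1) is materialized as torch.aten.constant_pad_nd filled with the lowest double, so padded elements can never win the max, followed by max_pool2d with zero padding, while a symmetric split (the precomputed case, pad 1/1) is passed straight to torch.aten.max_pool2d, matching the "push the padding operation to the torch operator" comment in the conversion. Below is a small, self-contained check of the resulting spatial sizes, assuming the standard floor-mode pooling formula; the names are illustrative and not from the patch.

    #include <cstdint>
    #include <iostream>

    // Pooled extent of one dimension:
    // floor((in + padBegin + padEnd - dilation*(kernel-1) - 1) / stride) + 1.
    int64_t pooledSize(int64_t in, int64_t padBegin, int64_t padEnd,
                       int64_t kernel, int64_t stride, int64_t dilation) {
      return (in + padBegin + padEnd - dilation * (kernel - 1) - 1) / stride + 1;
    }

    int main() {
      // same_lower/same_upper: 32 is first padded to 33 by constant_pad_nd,
      // then pooled with kernel 2, stride 1, padding 0 -> 32.
      std::cout << pooledSize(32 + 1, 0, 0, 2, 1, 1) << "\n"; // 32
      // precomputed_same_upper: symmetric pad 1/1 handled by max_pool2d
      // itself; kernel 3, stride 2 on a size-5 input -> 3.
      std::cout << pooledSize(5, 1, 1, 3, 2, 1) << "\n"; // 3
      return 0;
    }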