MaxPool1d lowering to linalg (#3295)

Co-authored-by: root <root@i32b01216.sqa.eu95>
pull/3328/head
NeverRaR 2024-05-11 00:35:26 +08:00 committed by GitHub
parent 261074f594
commit 1d4859699b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 187 additions and 1 deletions

View File

@ -185,6 +185,12 @@ namespace {
template <typename T> struct DimensionTraits {};
template <> struct DimensionTraits<AtenMaxPool1dOp> {
static constexpr int64_t Dim = 1;
// unused const variable warning suppression:
static_assert(Dim == Dim);
};
template <> struct DimensionTraits<AtenMaxPool2dOp> {
static constexpr int64_t Dim = 2;
// unused const variable warning suppression:
@ -328,7 +334,24 @@ public:
Type elementType = cast<RankedTensorType>(self.getType()).getElementType();
if constexpr (Dim == 2) {
if constexpr (Dim == 1) {
SmallVector<Value, 4> outTensorShape;
Value maxPool1d, paddedInput;
TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
elementType,
APFloat::getInf(
cast<mlir::FloatType>(elementType).getFloatSemantics(),
/*Negative=*/true));
if (failed(createPoolingOp<linalg::PoolingNcwMaxOp>(
op, rewriter, self, /*supportNonFPInput=*/true, ceilMode,
/*dimensionality=*/1, kernelSizeIntValues, strideInts,
paddingInts, dilationInts, smallestFPValueAttr, outTensorShape,
paddedInput, maxPool1d)))
return rewriter.notifyMatchFailure(op, "unable to compute maxpool1d");
Type newResultType = this->getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, maxPool1d);
return success();
} else if constexpr (Dim == 2) {
SmallVector<Value, 4> outTensorShape;
// `maxpool2d` contains the result of maxpool2d operation over the input.
Value maxPool2d, paddedInput;
@ -1090,8 +1113,10 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
MLIRContext *context = patterns.getContext();
target.addIllegalOp<AtenMaxPool1dOp>();
target.addIllegalOp<AtenMaxPool2dOp>();
target.addIllegalOp<AtenMaxPool3dOp>();
patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool1dOp>>(typeConverter, context);
patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool2dOp>>(typeConverter, context);
patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool3dOp>>(typeConverter, context);

View File

@ -10481,6 +10481,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.max_pool1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.max_pool2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"

View File

@ -1078,6 +1078,8 @@ STABLEHLO_PASS_SET = {
"Matmul_matvec",
"Matmul_vecmat",
"MatmulStaticBroadcast_basic",
"MaxPool1dStaticModule_basic",
"MaxPool1dEmptyStrideStaticModule_basic",
"MaxPool2dStaticModule_basic",
"MaxPool2dEmptyStrideStaticModule_basic",
"MaxPool3dStaticModule_basic",
@ -1905,6 +1907,9 @@ MAKE_FX_TOSA_PASS_SET = (
TOSA_PASS_SET
| {
### Tests additionally passing in make_fx_tosa
"MaxPool1dEmptyStrideStaticModule_basic",
"MaxPool1dStaticCeilModeTrueModule_basic",
"MaxPool1dStaticModule_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
@ -2361,6 +2366,11 @@ ONNX_XFAIL_SET = {
"LinalgNormKeepDimComplexModule_basic",
"LinalgVectorNormComplexModule_basic",
"LogSoftmaxBackwardModule_basic",
"MaxPool1dCeilModeTrueModule_basic",
"MaxPool1dEmptyStrideStaticModule_basic",
"MaxPool1dModule_basic",
"MaxPool1dStaticCeilModeTrueModule_basic",
"MaxPool1dStaticModule_basic",
"MaxPool2dCeilModeTrueModule_basic",
"MaxPool2dModule_basic",
"MaxPool2dWithIndicesAllOnesModule_basic",

View File

@ -2612,6 +2612,11 @@ def atenmasked_select〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dty
self_rank, self_dtype = self_rank_dtype
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5)], kernel_size=[2]))
def atenmax_pool1d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), dilation: List[int] = (1,), ceil_mode: bool = False) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2]))
def atenmax_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), ceil_mode: bool = False) -> int:
self_rank, self_dtype = self_rank_dtype

View File

@ -157,6 +157,126 @@ def AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic(module, tu: TestUtils):
# ==============================================================================
class MaxPool1dModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.mp1d = torch.nn.MaxPool1d(
kernel_size=[6], stride=[2], padding=[3], dilation=2
)
@export
@annotate_args(
[
None,
([-1, -1, -1], torch.float32, True),
]
)
def forward(self, x):
return self.mp1d(x)
@register_test_case(module_factory=lambda: MaxPool1dModule())
def MaxPool1dModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 1, 20, low=-1))
class MaxPool1dEmptyStrideStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([1, 1, 20], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.max_pool1d(x, kernel_size=2, stride=[])
@register_test_case(module_factory=lambda: MaxPool1dEmptyStrideStaticModule())
def MaxPool1dEmptyStrideStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 1, 20, low=-1))
class MaxPool1dStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.mp1d = torch.nn.MaxPool1d(
kernel_size=[3], stride=[2], padding=[1], dilation=[1]
)
@export
@annotate_args(
[
None,
([1, 64, 112], torch.float32, True),
]
)
def forward(self, x):
return self.mp1d(x)
@register_test_case(module_factory=lambda: MaxPool1dStaticModule())
def MaxPool1dStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 64, 112))
class MaxPool1dStaticCeilModeTrueModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.mp1d = torch.nn.MaxPool1d(
kernel_size=[3], stride=[2], padding=[1], dilation=[1], ceil_mode=True
)
@export
@annotate_args(
[
None,
([1, 64, 112], torch.float32, True),
]
)
def forward(self, x):
return self.mp1d(x)
@register_test_case(module_factory=lambda: MaxPool1dStaticCeilModeTrueModule())
def MaxPool1dStaticCeilModeTrueModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 64, 112))
class MaxPool1dCeilModeTrueModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.mp1d = torch.nn.MaxPool1d(
kernel_size=[6], stride=[2], padding=[3], dilation=2, ceil_mode=True
)
@export
@annotate_args(
[
None,
([-1, -1, -1], torch.float32, True),
]
)
def forward(self, x):
return self.mp1d(x)
@register_test_case(module_factory=lambda: MaxPool1dCeilModeTrueModule())
def MaxPool1dCeilModeTrueModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 1, 20, low=0.5, high=1.0))
# ==============================================================================
class MaxPool2dModule(torch.nn.Module):
def __init__(self):
super().__init__()

View File

@ -1,5 +1,27 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
// CHECK-LABEL: func @forward_max_pool1d
func.func @forward_max_pool1d(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%int4 = torch.constant.int 4
%false = torch.constant.bool false
// CHECK: %[[C1:.*]] = torch_c.to_i64 %int1
// CHECK: %[[NEUTRAL:.*]] = arith.constant 0xFF800000 : f32
// CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}} low[0, 0, 3] high[0, 0, 3]
// CHECK: %[[OUT:.*]] = linalg.fill ins(%[[NEUTRAL]] : f32) outs(%{{.*}} : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: %[[T1:.*]] = arith.index_cast %[[C1]] : i64 to index
// CHECK: %[[INIT:.*]] = tensor.empty(%[[T1]]) : tensor<?xf32>
// CHECK: linalg.pooling_ncw_max {dilations = dense<4> : vector<1xi64>, strides = dense<2> : vector<1xi64>} ins(%[[PADDED]], %[[INIT]] : tensor<?x?x?xf32>, tensor<?xf32>) outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
%kernel_size = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%stride = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
%padding = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
%dilation = torch.prim.ListConstruct %int4 : (!torch.int) -> !torch.list<int>
%4 = torch.aten.max_pool1d %arg0, %kernel_size, %stride, %padding, %dilation, %false : !torch.vtensor<[?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[?,?,?],f32>
return %4 : !torch.vtensor<[?,?,?],f32>
}
// CHECK-LABEL: func @forward_max_pool2d
func.func @forward_max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
%int1 = torch.constant.int 1