mirror of https://github.com/llvm/torch-mlir
parent
261074f594
commit
1d4859699b
|
@ -185,6 +185,12 @@ namespace {
|
|||
|
||||
template <typename T> struct DimensionTraits {};
|
||||
|
||||
template <> struct DimensionTraits<AtenMaxPool1dOp> {
|
||||
static constexpr int64_t Dim = 1;
|
||||
// unused const variable warning suppression:
|
||||
static_assert(Dim == Dim);
|
||||
};
|
||||
|
||||
template <> struct DimensionTraits<AtenMaxPool2dOp> {
|
||||
static constexpr int64_t Dim = 2;
|
||||
// unused const variable warning suppression:
|
||||
|
@ -328,7 +334,24 @@ public:
|
|||
|
||||
Type elementType = cast<RankedTensorType>(self.getType()).getElementType();
|
||||
|
||||
if constexpr (Dim == 2) {
|
||||
if constexpr (Dim == 1) {
|
||||
SmallVector<Value, 4> outTensorShape;
|
||||
Value maxPool1d, paddedInput;
|
||||
TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
|
||||
elementType,
|
||||
APFloat::getInf(
|
||||
cast<mlir::FloatType>(elementType).getFloatSemantics(),
|
||||
/*Negative=*/true));
|
||||
if (failed(createPoolingOp<linalg::PoolingNcwMaxOp>(
|
||||
op, rewriter, self, /*supportNonFPInput=*/true, ceilMode,
|
||||
/*dimensionality=*/1, kernelSizeIntValues, strideInts,
|
||||
paddingInts, dilationInts, smallestFPValueAttr, outTensorShape,
|
||||
paddedInput, maxPool1d)))
|
||||
return rewriter.notifyMatchFailure(op, "unable to compute maxpool1d");
|
||||
Type newResultType = this->getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, maxPool1d);
|
||||
return success();
|
||||
} else if constexpr (Dim == 2) {
|
||||
SmallVector<Value, 4> outTensorShape;
|
||||
// `maxpool2d` contains the result of maxpool2d operation over the input.
|
||||
Value maxPool2d, paddedInput;
|
||||
|
@ -1090,8 +1113,10 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
|
|||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
target.addIllegalOp<AtenMaxPool1dOp>();
|
||||
target.addIllegalOp<AtenMaxPool2dOp>();
|
||||
target.addIllegalOp<AtenMaxPool3dOp>();
|
||||
patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool1dOp>>(typeConverter, context);
|
||||
patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool2dOp>>(typeConverter, context);
|
||||
patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool3dOp>>(typeConverter, context);
|
||||
|
||||
|
|
|
@ -10481,6 +10481,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.max_pool1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.max_pool2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
|
|
|
@ -1078,6 +1078,8 @@ STABLEHLO_PASS_SET = {
|
|||
"Matmul_matvec",
|
||||
"Matmul_vecmat",
|
||||
"MatmulStaticBroadcast_basic",
|
||||
"MaxPool1dStaticModule_basic",
|
||||
"MaxPool1dEmptyStrideStaticModule_basic",
|
||||
"MaxPool2dStaticModule_basic",
|
||||
"MaxPool2dEmptyStrideStaticModule_basic",
|
||||
"MaxPool3dStaticModule_basic",
|
||||
|
@ -1905,6 +1907,9 @@ MAKE_FX_TOSA_PASS_SET = (
|
|||
TOSA_PASS_SET
|
||||
| {
|
||||
### Tests additionally passing in make_fx_tosa
|
||||
"MaxPool1dEmptyStrideStaticModule_basic",
|
||||
"MaxPool1dStaticCeilModeTrueModule_basic",
|
||||
"MaxPool1dStaticModule_basic",
|
||||
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
|
||||
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
|
||||
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
|
||||
|
@ -2361,6 +2366,11 @@ ONNX_XFAIL_SET = {
|
|||
"LinalgNormKeepDimComplexModule_basic",
|
||||
"LinalgVectorNormComplexModule_basic",
|
||||
"LogSoftmaxBackwardModule_basic",
|
||||
"MaxPool1dCeilModeTrueModule_basic",
|
||||
"MaxPool1dEmptyStrideStaticModule_basic",
|
||||
"MaxPool1dModule_basic",
|
||||
"MaxPool1dStaticCeilModeTrueModule_basic",
|
||||
"MaxPool1dStaticModule_basic",
|
||||
"MaxPool2dCeilModeTrueModule_basic",
|
||||
"MaxPool2dModule_basic",
|
||||
"MaxPool2dWithIndicesAllOnesModule_basic",
|
||||
|
|
|
@ -2612,6 +2612,11 @@ def aten〇masked_select〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dty
|
|||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5)], kernel_size=[2]))
|
||||
def aten〇max_pool1d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), dilation: List[int] = (1,), ceil_mode: bool = False) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2]))
|
||||
def aten〇max_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), ceil_mode: bool = False) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
|
|
|
@ -157,6 +157,126 @@ def AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class MaxPool1dModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mp1d = torch.nn.MaxPool1d(
|
||||
kernel_size=[6], stride=[2], padding=[3], dilation=2
|
||||
)
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.mp1d(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MaxPool1dModule())
|
||||
def MaxPool1dModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 1, 20, low=-1))
|
||||
|
||||
|
||||
class MaxPool1dEmptyStrideStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([1, 1, 20], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.max_pool1d(x, kernel_size=2, stride=[])
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MaxPool1dEmptyStrideStaticModule())
|
||||
def MaxPool1dEmptyStrideStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 1, 20, low=-1))
|
||||
|
||||
|
||||
class MaxPool1dStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mp1d = torch.nn.MaxPool1d(
|
||||
kernel_size=[3], stride=[2], padding=[1], dilation=[1]
|
||||
)
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([1, 64, 112], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.mp1d(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MaxPool1dStaticModule())
|
||||
def MaxPool1dStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 64, 112))
|
||||
|
||||
|
||||
class MaxPool1dStaticCeilModeTrueModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mp1d = torch.nn.MaxPool1d(
|
||||
kernel_size=[3], stride=[2], padding=[1], dilation=[1], ceil_mode=True
|
||||
)
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([1, 64, 112], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.mp1d(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MaxPool1dStaticCeilModeTrueModule())
|
||||
def MaxPool1dStaticCeilModeTrueModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 64, 112))
|
||||
|
||||
|
||||
class MaxPool1dCeilModeTrueModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mp1d = torch.nn.MaxPool1d(
|
||||
kernel_size=[6], stride=[2], padding=[3], dilation=2, ceil_mode=True
|
||||
)
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.mp1d(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MaxPool1dCeilModeTrueModule())
|
||||
def MaxPool1dCeilModeTrueModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 1, 20, low=0.5, high=1.0))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MaxPool2dModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
|
@ -1,5 +1,27 @@
|
|||
// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @forward_max_pool1d
|
||||
func.func @forward_max_pool1d(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
|
||||
%int1 = torch.constant.int 1
|
||||
%int2 = torch.constant.int 2
|
||||
%int3 = torch.constant.int 3
|
||||
%int4 = torch.constant.int 4
|
||||
%false = torch.constant.bool false
|
||||
// CHECK: %[[C1:.*]] = torch_c.to_i64 %int1
|
||||
// CHECK: %[[NEUTRAL:.*]] = arith.constant 0xFF800000 : f32
|
||||
// CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}} low[0, 0, 3] high[0, 0, 3]
|
||||
// CHECK: %[[OUT:.*]] = linalg.fill ins(%[[NEUTRAL]] : f32) outs(%{{.*}} : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = arith.index_cast %[[C1]] : i64 to index
|
||||
// CHECK: %[[INIT:.*]] = tensor.empty(%[[T1]]) : tensor<?xf32>
|
||||
// CHECK: linalg.pooling_ncw_max {dilations = dense<4> : vector<1xi64>, strides = dense<2> : vector<1xi64>} ins(%[[PADDED]], %[[INIT]] : tensor<?x?x?xf32>, tensor<?xf32>) outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
%kernel_size = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
|
||||
%stride = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
|
||||
%padding = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
|
||||
%dilation = torch.prim.ListConstruct %int4 : (!torch.int) -> !torch.list<int>
|
||||
%4 = torch.aten.max_pool1d %arg0, %kernel_size, %stride, %padding, %dilation, %false : !torch.vtensor<[?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[?,?,?],f32>
|
||||
return %4 : !torch.vtensor<[?,?,?],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @forward_max_pool2d
|
||||
func.func @forward_max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
|
||||
%int1 = torch.constant.int 1
|
||||
|
|
Loading…
Reference in New Issue