From f32ada993d393581ae1e70ac6b47dbdd4a70dca1 Mon Sep 17 00:00:00 2001
From: Xinyu Yang
Date: Wed, 1 May 2024 00:06:13 +0800
Subject: [PATCH] [Stablehlo] Improve the lowering of pool op in stablehlo
 (#3259)

1. Handle the case where stride is None (it then defaults to kernel_size).
2. Add lowerings for max_pool1d, max_pool3d, and avg_pool3d.
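For context, PyTorch's pooling ops treat an omitted stride as "use the
kernel size", which is the behavior item 1 makes the lowering match. An
illustrative snippet (not part of the patch):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 4, 16)
    # With stride omitted, PyTorch defaults it to kernel_size, so these two
    # calls are equivalent; the lowering now substitutes kernel_size whenever
    # the stride list arrives empty.
    y1 = F.max_pool1d(x, kernel_size=3)
    y2 = F.max_pool1d(x, kernel_size=3, stride=3)
    assert torch.equal(y1, y2)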
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     |  28 ++
 lib/Conversion/TorchToStablehlo/Pooling.cpp   | 279 +++++++++++-------
 .../Transforms/AbstractInterpLibrary.cpp      |  14 +-
 projects/pt1/e2e_testing/xfail_sets.py        |   3 +
 .../build_tools/abstract_interp_lib_gen.py    |  13 +-
 .../build_tools/torch_ods_gen.py              |   1 +
 6 files changed, 216 insertions(+), 122 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 8ebd7b162..cb08ffd53 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -6637,6 +6637,34 @@ def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [
   }];
 }
 
+def Torch_AtenMaxPool1dOp : Torch_Op<"aten.max_pool1d", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchListOfTorchIntType:$kernel_size,
+    AnyTorchListOfTorchIntType:$stride,
+    AnyTorchListOfTorchIntType:$padding,
+    AnyTorchListOfTorchIntType:$dilation,
+    Torch_BoolType:$ceil_mode
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenMaxPool1dOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 6, 1);
+    }
+    void AtenMaxPool1dOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 6, 1);
+    }
+  }];
+}
+
 def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchToStablehlo/Pooling.cpp b/lib/Conversion/TorchToStablehlo/Pooling.cpp
index 132410a2a..9219b4af3 100644
--- a/lib/Conversion/TorchToStablehlo/Pooling.cpp
+++ b/lib/Conversion/TorchToStablehlo/Pooling.cpp
@@ -36,7 +36,7 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
   auto constType = RankedTensorType::get({}, elementTy);
   // Avg pooling
   if (isa<AtenAvgPool1dOp, AtenAvgPool2dOp,
-          AtenCumsumOp>(op)) {
+          AtenAvgPool3dOp, AtenCumsumOp>(op)) {
     if (isa<mlir::FloatType>(elementTy)) {
       auto constAttr = DenseElementsAttr::get(
           constType, {APFloat::getZero(
@@ -54,7 +54,8 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
   }
 
   // Max pooling
-  if (isa<AtenMaxPool2dOp, AtenMaxPool2dWithIndicesOp>(op)) {
+  if (isa<AtenMaxPool1dOp, AtenMaxPool2dOp, AtenMaxPool3dOp,
+          AtenMaxPool2dWithIndicesOp>(op)) {
     if (isa<mlir::FloatType>(elementTy)) {
       auto constAttr = DenseElementsAttr::get(
           constType,
@@ -75,101 +76,6 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
   return nullptr;
 }
 
-// AtenMaxPool2dOp
-template <>
-LogicalResult ConvertAtenOp<AtenMaxPool2dOp>::matchAndRewrite(
-    AtenMaxPool2dOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  Value input = adaptor.getSelf();
-  auto inputTy = cast<RankedTensorType>(input.getType());
-  auto inputElemTy = inputTy.getElementType();
-
-  auto inputRank = inputTy.getRank();
-  auto outTy =
-      cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
-
-  if (inputRank <= 2) {
-    return op.emitError(
-        "max_pooling2d only supports inputs with rank higher than 2");
-  }
-  SmallVector<int64_t, 2> padding, kernelSize, stride, dilation;
-  bool ceilMode = false;
-
-  if (!(matchPattern(op.getKernelSize(),
-                     m_TorchListOfConstantInts(kernelSize)))) {
-    return rewriter.notifyMatchFailure(
-        op, "non-const int kernel size unsupported!");
-  }
-  if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) {
-    return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!");
-  }
-  if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) {
-    return rewriter.notifyMatchFailure(op,
-                                       "non-const int padding unsupported!");
-  }
-  if (!(matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilation)))) {
-    return rewriter.notifyMatchFailure(op,
-                                       "non-const int dilation unsupported!");
-  }
-  if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) {
-    return rewriter.notifyMatchFailure(op,
-                                       "non-const bool ceil_mode unsupported!");
-  }
-
-  // prepend 1 to kernelSize, stride, dilation until they are of same rank as
-  // input
-  SmallVector<int64_t> stablehloStride(inputRank, 1);
-  SmallVector<int64_t> stablehloDilation(inputRank, 1);
-  SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
-  SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
-  std::copy(dilation.begin(), dilation.end(),
-            stablehloDilation.begin() + inputRank - 2);
-  std::copy(stride.begin(), stride.end(),
-            stablehloStride.begin() + inputRank - 2);
-  std::copy(kernelSize.begin(), kernelSize.end(),
-            stablehloKernelSize.begin() + inputRank - 2);
-
-  Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
-
-  stablehloPadding[stablehloPadding.size() - 4] = padding[0];
-  stablehloPadding[stablehloPadding.size() - 3] = padding[0];
-  stablehloPadding[stablehloPadding.size() - 2] = padding[1];
-  stablehloPadding[stablehloPadding.size() - 1] = padding[1];
-
-  auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
-  auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
-  DenseI64ArrayAttr baseDilations;
-  auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
-  DenseIntElementsAttr pad = DenseIntElementsAttr::get(
-      RankedTensorType::get(
-          {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
-          rewriter.getI64Type()),
-      stablehloPadding);
-  auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
-      op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
-      baseDilations, windowDilations, pad);
-
-  Block &block = reduceWindowOp.getBody().emplaceBlock();
-
-  auto blockArgumentTy = RankedTensorType::get({}, inputElemTy);
-  block.addArgument(blockArgumentTy, op->getLoc());
-  block.addArgument(blockArgumentTy, op->getLoc());
-
-  auto *firstArg = block.args_begin();
-  auto secondArg = block.args_rbegin();
-
-  {
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(&block);
-    Value result =
-        rewriter.create<stablehlo::MaxOp>(op->getLoc(), *firstArg, *secondArg);
-    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), result);
-  }
-
-  rewriter.replaceOp(op, reduceWindowOp.getResults());
-  return success();
-}
-
 // AtenMaxPool2dWithIndicesOp
 template <>
 LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
@@ -356,6 +262,129 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
   return success();
 }
 
+namespace {
+template <typename AtenOpT, int Dim>
+class ConvertAtenMaxPoolOp : public ConvertAtenOp<AtenOpT> {
+public:
+  using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value input = adaptor.getSelf();
+    auto inputTy = cast<RankedTensorType>(input.getType());
+    auto inputElemTy = inputTy.getElementType();
+    auto inputRank = inputTy.getRank();
+    auto outTy = cast<RankedTensorType>(
+        ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType()));
+
+    if (inputRank <= Dim) {
+      return op.emitError("max_pooling1d/2d/3d only supports inputs with rank "
+                          "higher than 1/2/3");
+    }
+    SmallVector<int64_t, 2> padding, kernelSize, stride, dilation;
+    bool ceilMode = false;
+
+    if (!(matchPattern(op.getKernelSize(),
+                       m_TorchListOfConstantInts(kernelSize)))) {
+      return rewriter.notifyMatchFailure(
+          op, "non-const int kernel size unsupported!");
+    }
+    if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "non-const int stride unsupported!");
+    }
+    if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "non-const int padding unsupported!");
+    }
+    if (!(matchPattern(op.getDilation(),
+                       m_TorchListOfConstantInts(dilation)))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "non-const int dilation unsupported!");
+    }
+    if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) {
+      return rewriter.notifyMatchFailure(
+          op, "non-const bool ceil_mode unsupported!");
+    }
+
+    if (stride.empty()) {
+      stride = kernelSize;
+    }
+
+    // prepend 1 to kernelSize, stride, dilation until they are of same rank
+    // as input
+    SmallVector<int64_t> stablehloStride(inputRank, 1);
+    SmallVector<int64_t> stablehloDilation(inputRank, 1);
+    SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
+    SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
+    std::copy(dilation.begin(), dilation.end(),
+              stablehloDilation.begin() + inputRank - Dim);
+    std::copy(stride.begin(), stride.end(),
+              stablehloStride.begin() + inputRank - Dim);
+    std::copy(kernelSize.begin(), kernelSize.end(),
+              stablehloKernelSize.begin() + inputRank - Dim);
+
+    Value initVal =
+        createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
+
+    if (Dim == 1) {
+      stablehloPadding[stablehloPadding.size() - 2] = padding[0];
+      stablehloPadding[stablehloPadding.size() - 1] = padding[0];
+    } else if (Dim == 2) {
+      stablehloPadding[stablehloPadding.size() - 4] = padding[0];
+      stablehloPadding[stablehloPadding.size() - 3] = padding[0];
+      stablehloPadding[stablehloPadding.size() - 2] = padding[1];
+      stablehloPadding[stablehloPadding.size() - 1] = padding[1];
+    } else if (Dim == 3) {
+      stablehloPadding[stablehloPadding.size() - 6] = padding[0];
+      stablehloPadding[stablehloPadding.size() - 5] = padding[0];
+      stablehloPadding[stablehloPadding.size() - 4] = padding[1];
+      stablehloPadding[stablehloPadding.size() - 3] = padding[1];
+      stablehloPadding[stablehloPadding.size() - 2] = padding[2];
+      stablehloPadding[stablehloPadding.size() - 1] = padding[2];
+    } else {
+      assert(false && "Unsupported pooling dimension");
+    }
+    auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
+    auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
+    DenseI64ArrayAttr baseDilations;
+    auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
+
+    DenseIntElementsAttr pad = DenseIntElementsAttr::get(
+        RankedTensorType::get(
+            {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
+            rewriter.getI64Type()),
+        stablehloPadding);
+
+    auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
+        op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
+        baseDilations, windowDilations, pad);
+
+    Block &block = reduceWindowOp.getBody().emplaceBlock();
+
+    // Add bb argument
+    auto blockArgumentType = RankedTensorType::get({}, inputElemTy);
+    block.addArgument(blockArgumentType, op->getLoc());
+    block.addArgument(blockArgumentType, op->getLoc());
+    auto *firstArg = block.args_begin();
+    auto secondArg = block.args_rbegin();
+
+    {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(&block);
+
+      Value result = rewriter.create<stablehlo::MaxOp>(op->getLoc(), *firstArg,
+                                                       *secondArg);
+      rewriter.create<stablehlo::ReturnOp>(op->getLoc(), result);
+    }
+
+    rewriter.replaceOp(op, reduceWindowOp.getResults());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 template <typename AtenOpT, int Dim>
 class ConvertAtenAvgPoolOp : public ConvertAtenOp<AtenOpT> {
@@ -375,8 +404,8 @@ public:
     auto outShape = outTy.getShape();
 
     if (inputRank <= Dim) {
-      return op.emitError(
-          "avg_pooling1d/2d only supports inputs with rank higher than 1/2");
+      return op.emitError("avg_pooling1d/2d/3d only supports inputs with rank "
+                          "higher than 1/2/3");
     }
     SmallVector<int64_t, 2> padding, kernelSize, stride;
     bool ceilMode = false;
@@ -405,6 +434,10 @@ public:
           op, "non-const bool count_include_pad unsupported!");
     }
 
+    if (stride.empty()) {
+      stride = kernelSize;
+    }
+
     if constexpr (std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
       if (succeeded(checkNotNone(rewriter, op, op.getDivisorOverride())))
         return rewriter.notifyMatchFailure(
@@ -425,11 +458,20 @@ public:
     if (Dim == 1) {
       stablehloPadding[stablehloPadding.size() - 2] = padding[0];
       stablehloPadding[stablehloPadding.size() - 1] = padding[0];
-    } else {
+    } else if (Dim == 2) {
       stablehloPadding[stablehloPadding.size() - 4] = padding[0];
       stablehloPadding[stablehloPadding.size() - 3] = padding[0];
       stablehloPadding[stablehloPadding.size() - 2] = padding[1];
       stablehloPadding[stablehloPadding.size() - 1] = padding[1];
+    } else if (Dim == 3) {
+      stablehloPadding[stablehloPadding.size() - 6] = padding[0];
+      stablehloPadding[stablehloPadding.size() - 5] = padding[0];
+      stablehloPadding[stablehloPadding.size() - 4] = padding[1];
+      stablehloPadding[stablehloPadding.size() - 3] = padding[1];
+      stablehloPadding[stablehloPadding.size() - 2] = padding[2];
+      stablehloPadding[stablehloPadding.size() - 1] = padding[2];
+    } else {
+      assert(false && "Unsupported pooling dimension");
     }
 
     Value initVal =
@@ -474,10 +516,17 @@ public:
       divisor =
           hlo::getConstTensor<int64_t>(rewriter, op, {kernelSize[0]}, {})
              .value();
-    } else {
+    } else if (Dim == 2) {
      divisor = hlo::getConstTensor<int64_t>(
                    rewriter, op, {kernelSize[0] * kernelSize[1]}, {})
                    .value();
+    } else if (Dim == 3) {
+      divisor = hlo::getConstTensor<int64_t>(
+                    rewriter, op,
+                    {kernelSize[0] * kernelSize[1] * kernelSize[2]}, {})
+                    .value();
+    } else {
+      assert(false && "Unsupported pooling dimension");
     }
     divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy);
     DenseI64ArrayAttr bcastDimensions;
@@ -611,22 +660,28 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target, const TorchToStablehloOptions &options) {
   MLIRContext *context = patterns.getContext();
-  target.addIllegalOp<AtenMaxPool2dOp>();
-  patterns.add<ConvertAtenOp<AtenMaxPool2dOp>>(typeConverter, context, options);
-  target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
-  patterns.add<ConvertAtenOp<AtenMaxPool2dWithIndicesOp>>(typeConverter, context, options);
-  target.addIllegalOp<AtenAvgPool1dOp>();
-  patterns.add<ConvertAtenAvgPoolOp<AtenAvgPool1dOp, 1>>(typeConverter, context, options);
-  target.addIllegalOp<AtenAvgPool2dOp>();
-  patterns.add<ConvertAtenAvgPoolOp<AtenAvgPool2dOp, 2>>(typeConverter,
-                                                         context, options);
-  target.addIllegalOp<AtenCumsumOp>();
-  patterns.add<ConvertAtenOp<AtenCumsumOp>>(typeConverter, context, options);
+#define INSERT_ATEN_POOLING_PATTERN(AtenOp)                                    \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
+  INSERT_ATEN_POOLING_PATTERN(AtenMaxPool2dWithIndicesOp);
+  INSERT_ATEN_POOLING_PATTERN(AtenCumsumOp);
+#undef INSERT_ATEN_POOLING_PATTERN
+
+#define INSERT_ATEN_MAXPOOL_PATTERN(AtenOp, Dim)                               \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenMaxPoolOp<AtenOp, Dim>>(typeConverter, context,      \
+                                                  options)
+  INSERT_ATEN_MAXPOOL_PATTERN(AtenMaxPool1dOp, 1);
+  INSERT_ATEN_MAXPOOL_PATTERN(AtenMaxPool2dOp, 2);
+  INSERT_ATEN_MAXPOOL_PATTERN(AtenMaxPool3dOp, 3);
+#undef INSERT_ATEN_MAXPOOL_PATTERN
+
 #define INSERT_ATEN_AVGPOOL_PATTERN(AtenOp, Dim)                               \
   target.addIllegalOp<AtenOp>();                                               \
   patterns.add<ConvertAtenAvgPoolOp<AtenOp, Dim>>(typeConverter, context,      \
                                                   options)
   INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool1dOp, 1);
   INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool2dOp, 2);
+  INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool3dOp, 3);
 #undef INSERT_ATEN_AVGPOOL_PATTERN
 }
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 553a8dc74..d9ac7a6d0 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7845,19 +7845,19 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    return %arg2 : !torch.list<int>\n"
 "  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.avg_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list<int> {\n"
-"    %0 = call @__torch__.avg_pool1d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool) -> !torch.list<int>\n"
+"    %0 = call @__torch__.pool1d(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
-"  func.func @__torch__.avg_pool1d(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list<int> {\n"
+"  func.func @__torch__.pool1d(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool) -> !torch.list<int> {\n"
 "    %int-1 = torch.constant.int -1\n"
 "    %int-2 = torch.constant.int -2\n"
 "    %int-3 = torch.constant.int -3\n"
 "    %str = torch.constant.str \"AssertionError: \"\n"
-"    %str_0 = torch.constant.str \"AssertionError: avg_pool1d: padding must be a single int\"\n"
-"    %str_1 = torch.constant.str \"AssertionError: avg_pool1d: stride must either be omitted, or a single int\"\n"
+"    %str_0 = torch.constant.str \"AssertionError: pool1d: padding must be a single int\"\n"
+"    %str_1 = torch.constant.str \"AssertionError: pool1d: stride must either be omitted, or a single int\"\n"
 "    %true = torch.constant.bool true\n"
 "    %none = torch.constant.none\n"
-"    %str_2 = torch.constant.str \"AssertionError: avg_pool1d: kernel_size must be a single int\"\n"
+"    %str_2 = torch.constant.str \"AssertionError: pool1d: kernel_size must be a single int\"\n"
 "    %int1 = torch.constant.int 1\n"
 "    %int0 = torch.constant.int 0\n"
 "    %int2 = torch.constant.int 2\n"
@@ -7940,6 +7940,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %23 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.max_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.pool1d(%arg0, %arg1, %arg2, %arg3, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.adaptive_avg_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.adaptive_avg_pool1d(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
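The shared pool1d shape helper implements the standard 1-D pooling output
size: floor((L_in + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1, with
ceil instead of floor when ceil_mode is set. A worked example (illustrative,
not part of the patch; it omits the ceil-mode boundary adjustment that
PyTorch additionally applies):

    import math

    def pool1d_out_len(l_in, kernel, stride=None, padding=0, dilation=1,
                       ceil_mode=False):
        stride = kernel if stride is None else stride  # stride defaults to kernel
        numerator = l_in + 2 * padding - dilation * (kernel - 1) - 1
        rounding = math.ceil if ceil_mode else math.floor
        return rounding(numerator / stride) + 1

    # max_pool1d with kernel_size=3 and stride omitted on a length-16 input:
    assert pool1d_out_len(16, 3) == 5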
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 10c24b657..8ffe8d1c3 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1075,6 +1075,9 @@ STABLEHLO_PASS_SET = {
     "Matmul_vecmat",
     "MatmulStaticBroadcast_basic",
     "MaxPool2dStaticModule_basic",
+    "MaxPool2dEmptyStrideStaticModule_basic",
+    "MaxPool3dStaticModule_basic",
+    "MaxPool3dEmptyStrideStaticModule_basic",
     "MeanDimAllReduceModule_basic",
     "MeanDimEmptyDimModule_basic",
     "MeanDimNoneDimModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index da486fe46..eb6062056 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -961,14 +961,14 @@ def avg_pool2d(input: List[int], kernel_size: List[int], stride: List[int], padd
 
 # TODO: This should be upstreamed.
 # See https://github.com/pytorch/pytorch/pull/76889 for an example.
-def avg_pool1d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool, count_include_pad: bool):
-    assert len(kernel_size) == 1, "avg_pool1d: kernel_size must be a single int"
+def pool1d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool):
+    assert len(kernel_size) == 1, "pool1d: kernel_size must be a single int"
     kL = kernel_size[0]
 
-    assert len(stride) == 0 or len(stride) == 1, "avg_pool1d: stride must either be omitted, or a single int"
+    assert len(stride) == 0 or len(stride) == 1, "pool1d: stride must either be omitted, or a single int"
     dL = kL if len(stride) == 0 else stride[0]
 
-    assert len(padding) == 1, "avg_pool1d: padding must be a single int"
+    assert len(padding) == 1, "pool1d: padding must be a single int"
     padL = padding[0]
 
     dilationL = 1
@@ -1004,7 +1004,10 @@ def adaptive_avg_pool1d(self: List[int], out: List[int]):
     return shape
 
 def aten〇avg_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> List[int]:
-    return avg_pool1d(self, kernel_size, stride, padding, ceil_mode, count_include_pad)
+    return pool1d(self, kernel_size, stride, padding, ceil_mode)
+
+def aten〇max_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), dilation: List[int] = (1,), ceil_mode: bool = False) -> List[int]:
+    return pool1d(self, kernel_size, stride, padding, ceil_mode)
 
 def aten〇adaptive_avg_pool1d〡shape(self: List[int], output_size: List[int]) -> List[int]:
     return adaptive_avg_pool1d(self, output_size)
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index eea8d31a9..e0329c8df 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -591,6 +591,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit(
         "aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
     )
+    emit("aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
     emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
     emit(
"aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"