[Stablehlo] Improve the lowering of pool op in stablehlo (#3259)

1. Handle case stride == None
2. add avgpool3d maxpool1d  maxpool3d lowering
pull/3269/head
Xinyu Yang 2024-05-01 00:06:13 +08:00 committed by GitHub
parent fb8aed0907
commit f32ada993d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 216 additions and 122 deletions

View File

@ -6637,6 +6637,34 @@ def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [
}]; }];
} }
def Torch_AtenMaxPool1dOp : Torch_Op<"aten.max_pool1d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$kernel_size,
AnyTorchListOfTorchIntType:$stride,
AnyTorchListOfTorchIntType:$padding,
AnyTorchListOfTorchIntType:$dilation,
Torch_BoolType:$ceil_mode
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMaxPool1dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenMaxPool1dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}
def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,

View File

@ -36,7 +36,7 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
auto constType = RankedTensorType::get({}, elementTy); auto constType = RankedTensorType::get({}, elementTy);
// Avg pooling // Avg pooling
if (isa<AtenAvgPool1dOp, AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp, if (isa<AtenAvgPool1dOp, AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp,
AtenCumsumOp>(op)) { AtenAvgPool3dOp, AtenCumsumOp>(op)) {
if (isa<mlir::FloatType>(elementTy)) { if (isa<mlir::FloatType>(elementTy)) {
auto constAttr = DenseElementsAttr::get( auto constAttr = DenseElementsAttr::get(
constType, {APFloat::getZero( constType, {APFloat::getZero(
@ -54,7 +54,8 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
} }
// Max pooling // Max pooling
if (isa<AtenMaxPool2dOp, AtenMaxPool2dWithIndicesOp>(op)) { if (isa<AtenMaxPool1dOp, AtenMaxPool2dOp, AtenMaxPool3dOp,
AtenMaxPool2dWithIndicesOp>(op)) {
if (isa<mlir::FloatType>(elementTy)) { if (isa<mlir::FloatType>(elementTy)) {
auto constAttr = DenseElementsAttr::get( auto constAttr = DenseElementsAttr::get(
constType, constType,
@ -75,101 +76,6 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
return nullptr; return nullptr;
} }
// AtenMaxPool2dOp
template <>
LogicalResult ConvertAtenOp<AtenMaxPool2dOp>::matchAndRewrite(
AtenMaxPool2dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf();
auto inputTy = cast<RankedTensorType>(input.getType());
auto inputElemTy = inputTy.getElementType();
auto inputRank = inputTy.getRank();
auto outTy =
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
if (inputRank <= 2) {
return op.emitError(
"max_pooling2d only supports inputs with rank higher than 2");
}
SmallVector<int64_t, 2> padding, kernelSize, stride, dilation;
bool ceilMode = false;
if (!(matchPattern(op.getKernelSize(),
m_TorchListOfConstantInts(kernelSize)))) {
return rewriter.notifyMatchFailure(
op, "non-const int kernel size unsupported!");
}
if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) {
return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!");
}
if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) {
return rewriter.notifyMatchFailure(op,
"non-const int padding unsupported!");
}
if (!(matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilation)))) {
return rewriter.notifyMatchFailure(op,
"non-const int dilation unsupported!");
}
if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) {
return rewriter.notifyMatchFailure(op,
"non-const bool ceil_mode unsupported!");
}
// prepend 1 to kernelSize, stride, dilation until they are of same rank as
// input
SmallVector<int64_t> stablehloStride(inputRank, 1);
SmallVector<int64_t> stablehloDilation(inputRank, 1);
SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
std::copy(dilation.begin(), dilation.end(),
stablehloDilation.begin() + inputRank - 2);
std::copy(stride.begin(), stride.end(),
stablehloStride.begin() + inputRank - 2);
std::copy(kernelSize.begin(), kernelSize.end(),
stablehloKernelSize.begin() + inputRank - 2);
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
stablehloPadding[stablehloPadding.size() - 4] = padding[0];
stablehloPadding[stablehloPadding.size() - 3] = padding[0];
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
DenseI64ArrayAttr baseDilations;
auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
rewriter.getI64Type()),
stablehloPadding);
auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
baseDilations, windowDilations, pad);
Block &block = reduceWindowOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputElemTy);
block.addArgument(blockArgumentTy, op->getLoc());
block.addArgument(blockArgumentTy, op->getLoc());
auto *firstArg = block.args_begin();
auto secondArg = block.args_rbegin();
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value result =
rewriter.create<stablehlo::MaxOp>(op->getLoc(), *firstArg, *secondArg);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), result);
}
rewriter.replaceOp(op, reduceWindowOp.getResults());
return success();
}
// AtenMaxPool2dWithIndicesOp // AtenMaxPool2dWithIndicesOp
template <> template <>
LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
@ -356,6 +262,129 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
return success(); return success();
} }
namespace {
template <typename AtenOpT, int Dim>
class ConvertAtenMaxPoolOp : public ConvertAtenOp<AtenOpT> {
public:
using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value input = adaptor.getSelf();
auto inputTy = cast<RankedTensorType>(input.getType());
auto inputElemTy = inputTy.getElementType();
auto inputRank = inputTy.getRank();
auto outTy = cast<RankedTensorType>(
ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType()));
if (inputRank <= Dim) {
return op.emitError(
"max_pooling1d/2d only supports inputs with rank higher than 1/2");
}
SmallVector<int64_t, Dim> padding, kernelSize, stride, dilation;
bool ceilMode = false;
if (!(matchPattern(op.getKernelSize(),
m_TorchListOfConstantInts(kernelSize)))) {
return rewriter.notifyMatchFailure(
op, "non-const int kernel size unsupported!");
}
if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) {
return rewriter.notifyMatchFailure(op,
"non-const int stride unsupported!");
}
if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) {
return rewriter.notifyMatchFailure(op,
"non-const int padding unsupported!");
}
if (!(matchPattern(op.getDilation(),
m_TorchListOfConstantInts(dilation)))) {
return rewriter.notifyMatchFailure(op,
"non-const int dilation unsupported!");
}
if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) {
return rewriter.notifyMatchFailure(
op, "non-const bool ceil_mode unsupported!");
}
if (stride.empty()) {
stride = kernelSize;
}
// prepend 1 to kernelSize, stride, dilation until they are of same rank
// as input
SmallVector<int64_t> stablehloStride(inputRank, 1);
SmallVector<int64_t> stablehloDilation(inputRank, 1);
SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
std::copy(dilation.begin(), dilation.end(),
stablehloDilation.begin() + inputRank - Dim);
std::copy(stride.begin(), stride.end(),
stablehloStride.begin() + inputRank - Dim);
std::copy(kernelSize.begin(), kernelSize.end(),
stablehloKernelSize.begin() + inputRank - Dim);
Value initVal =
createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
if (Dim == 1) {
stablehloPadding[stablehloPadding.size() - 2] = padding[0];
stablehloPadding[stablehloPadding.size() - 1] = padding[0];
} else if (Dim == 2) {
stablehloPadding[stablehloPadding.size() - 4] = padding[0];
stablehloPadding[stablehloPadding.size() - 3] = padding[0];
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
} else if (Dim == 3) {
stablehloPadding[stablehloPadding.size() - 6] = padding[0];
stablehloPadding[stablehloPadding.size() - 5] = padding[0];
stablehloPadding[stablehloPadding.size() - 4] = padding[1];
stablehloPadding[stablehloPadding.size() - 3] = padding[1];
stablehloPadding[stablehloPadding.size() - 2] = padding[2];
stablehloPadding[stablehloPadding.size() - 1] = padding[2];
} else {
assert(false && "Unsupported pooling dimension");
}
auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
DenseI64ArrayAttr baseDilations;
auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
rewriter.getI64Type()),
stablehloPadding);
auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
baseDilations, windowDilations, pad);
Block &block = reduceWindowOp.getBody().emplaceBlock();
// Add bb argument
auto blockArgumentType = RankedTensorType::get({}, inputElemTy);
block.addArgument(blockArgumentType, op->getLoc());
block.addArgument(blockArgumentType, op->getLoc());
auto *firstArg = block.args_begin();
auto secondArg = block.args_rbegin();
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value result = rewriter.create<stablehlo::MaxOp>(op->getLoc(), *firstArg,
*secondArg);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), result);
}
rewriter.replaceOp(op, reduceWindowOp.getResults());
return success();
}
};
} // namespace
namespace { namespace {
template <typename AtenOpT, int Dim> template <typename AtenOpT, int Dim>
class ConvertAtenAvgPoolOp : public ConvertAtenOp<AtenOpT> { class ConvertAtenAvgPoolOp : public ConvertAtenOp<AtenOpT> {
@ -375,8 +404,8 @@ public:
auto outShape = outTy.getShape(); auto outShape = outTy.getShape();
if (inputRank <= Dim) { if (inputRank <= Dim) {
return op.emitError( return op.emitError("avg_pooling1d/2d/3d only supports inputs with rank "
"avg_pooling1d/2d only supports inputs with rank higher than 1/2"); "higher than 1/2/3");
} }
SmallVector<int64_t, Dim> padding, kernelSize, stride; SmallVector<int64_t, Dim> padding, kernelSize, stride;
bool ceilMode = false; bool ceilMode = false;
@ -405,6 +434,10 @@ public:
op, "non-const bool count_include_pad unsupported!"); op, "non-const bool count_include_pad unsupported!");
} }
if (stride.empty()) {
stride = kernelSize;
}
if constexpr (std::is_same<AtenOpT, AtenAvgPool2dOp>()) { if constexpr (std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
if (succeeded(checkNotNone(rewriter, op, op.getDivisorOverride()))) if (succeeded(checkNotNone(rewriter, op, op.getDivisorOverride())))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -425,11 +458,20 @@ public:
if (Dim == 1) { if (Dim == 1) {
stablehloPadding[stablehloPadding.size() - 2] = padding[0]; stablehloPadding[stablehloPadding.size() - 2] = padding[0];
stablehloPadding[stablehloPadding.size() - 1] = padding[0]; stablehloPadding[stablehloPadding.size() - 1] = padding[0];
} else { } else if (Dim == 2) {
stablehloPadding[stablehloPadding.size() - 4] = padding[0]; stablehloPadding[stablehloPadding.size() - 4] = padding[0];
stablehloPadding[stablehloPadding.size() - 3] = padding[0]; stablehloPadding[stablehloPadding.size() - 3] = padding[0];
stablehloPadding[stablehloPadding.size() - 2] = padding[1]; stablehloPadding[stablehloPadding.size() - 2] = padding[1];
stablehloPadding[stablehloPadding.size() - 1] = padding[1]; stablehloPadding[stablehloPadding.size() - 1] = padding[1];
} else if (Dim == 3) {
stablehloPadding[stablehloPadding.size() - 6] = padding[0];
stablehloPadding[stablehloPadding.size() - 5] = padding[0];
stablehloPadding[stablehloPadding.size() - 4] = padding[1];
stablehloPadding[stablehloPadding.size() - 3] = padding[1];
stablehloPadding[stablehloPadding.size() - 2] = padding[2];
stablehloPadding[stablehloPadding.size() - 1] = padding[2];
} else {
assert(false && "Unsupported pooling dimension");
} }
Value initVal = Value initVal =
@ -474,10 +516,17 @@ public:
divisor = divisor =
hlo::getConstTensor<int64_t>(rewriter, op, {kernelSize[0]}, {}) hlo::getConstTensor<int64_t>(rewriter, op, {kernelSize[0]}, {})
.value(); .value();
} else { } else if (Dim == 2) {
divisor = hlo::getConstTensor<int64_t>( divisor = hlo::getConstTensor<int64_t>(
rewriter, op, {kernelSize[0] * kernelSize[1]}, {}) rewriter, op, {kernelSize[0] * kernelSize[1]}, {})
.value(); .value();
} else if (Dim == 3) {
divisor = hlo::getConstTensor<int64_t>(
rewriter, op,
{kernelSize[0] * kernelSize[1] * kernelSize[2]}, {})
.value();
} else {
assert(false && "Unsupported pooling dimension");
} }
divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy); divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy);
DenseI64ArrayAttr bcastDimensions; DenseI64ArrayAttr bcastDimensions;
@ -611,22 +660,28 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns, TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options) { ConversionTarget &target, const TorchToStablehloOptions &options) {
MLIRContext *context = patterns.getContext(); MLIRContext *context = patterns.getContext();
target.addIllegalOp<AtenAvgPool1dOp>(); #define INSERT_ATEN_POOLING_PATTERN(AtenOp) \
patterns.add<ConvertAtenOp<AtenAvgPool1dOp>>(typeConverter, context, options); target.addIllegalOp<AtenOp>(); \
target.addIllegalOp<AtenMaxPool2dOp>(); patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
patterns.add<ConvertAtenOp<AtenMaxPool2dOp>>(typeConverter, context, options); INSERT_ATEN_POOLING_PATTERN(AtenMaxPool2dWithIndicesOp);
target.addIllegalOp<AtenAvgPool2dOp>(); INSERT_ATEN_POOLING_PATTERN(AtenCumsumOp);
patterns.add<ConvertAtenOp<AtenAvgPool2dOp>>(typeConverter, context, options); #undef INSERT_ATEN_POOLING_PATTERN
target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
patterns.add<ConvertAtenOp<AtenMaxPool2dWithIndicesOp>>(typeConverter, #define INSERT_ATEN_MAXPOOL_PATTERN(AtenOp, Dim) \
context, options); target.addIllegalOp<AtenOp>(); \
target.addIllegalOp<AtenCumsumOp>(); patterns.add<ConvertAtenMaxPoolOp<AtenOp, Dim>>(typeConverter, context, \
patterns.add<ConvertAtenOp<AtenCumsumOp>>(typeConverter, context, options); options)
INSERT_ATEN_MAXPOOL_PATTERN(AtenMaxPool1dOp, 1);
INSERT_ATEN_MAXPOOL_PATTERN(AtenMaxPool2dOp, 2);
INSERT_ATEN_MAXPOOL_PATTERN(AtenMaxPool3dOp, 3);
#undef INSERT_ATEN_MAXPOOL_PATTERN
#define INSERT_ATEN_AVGPOOL_PATTERN(AtenOp, Dim) \ #define INSERT_ATEN_AVGPOOL_PATTERN(AtenOp, Dim) \
target.addIllegalOp<AtenOp>(); \ target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenAvgPoolOp<AtenOp, Dim>>(typeConverter, context, \ patterns.add<ConvertAtenAvgPoolOp<AtenOp, Dim>>(typeConverter, context, \
options) options)
INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool1dOp, 1); INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool1dOp, 1);
INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool2dOp, 2); INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool2dOp, 2);
INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool3dOp, 3);
#undef INSERT_ATEN_AVGPOOL_PATTERN #undef INSERT_ATEN_AVGPOOL_PATTERN
} }

View File

@ -7845,19 +7845,19 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" return %arg2 : !torch.list<int>\n" " return %arg2 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.avg_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.avg_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.avg_pool1d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool) -> !torch.list<int>\n" " %0 = call @__torch__.pool1d(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @__torch__.avg_pool1d(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list<int> {\n" " func.func @__torch__.pool1d(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool) -> !torch.list<int> {\n"
" %int-1 = torch.constant.int -1\n" " %int-1 = torch.constant.int -1\n"
" %int-2 = torch.constant.int -2\n" " %int-2 = torch.constant.int -2\n"
" %int-3 = torch.constant.int -3\n" " %int-3 = torch.constant.int -3\n"
" %str = torch.constant.str \"AssertionError: \"\n" " %str = torch.constant.str \"AssertionError: \"\n"
" %str_0 = torch.constant.str \"AssertionError: avg_pool1d: padding must be a single int\"\n" " %str_0 = torch.constant.str \"AssertionError: pool1d: padding must be a single int\"\n"
" %str_1 = torch.constant.str \"AssertionError: avg_pool1d: stride must either be omitted, or a single int\"\n" " %str_1 = torch.constant.str \"AssertionError: pool1d: stride must either be omitted, or a single int\"\n"
" %true = torch.constant.bool true\n" " %true = torch.constant.bool true\n"
" %none = torch.constant.none\n" " %none = torch.constant.none\n"
" %str_2 = torch.constant.str \"AssertionError: avg_pool1d: kernel_size must be a single int\"\n" " %str_2 = torch.constant.str \"AssertionError: pool1d: kernel_size must be a single int\"\n"
" %int1 = torch.constant.int 1\n" " %int1 = torch.constant.int 1\n"
" %int0 = torch.constant.int 0\n" " %int0 = torch.constant.int 0\n"
" %int2 = torch.constant.int 2\n" " %int2 = torch.constant.int 2\n"
@ -7940,6 +7940,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n" " }\n"
" return %23 : !torch.list<int>\n" " return %23 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.max_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.pool1d(%arg0, %arg1, %arg2, %arg3, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.adaptive_avg_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.adaptive_avg_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.adaptive_avg_pool1d(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.adaptive_avg_pool1d(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"

View File

@ -1075,6 +1075,9 @@ STABLEHLO_PASS_SET = {
"Matmul_vecmat", "Matmul_vecmat",
"MatmulStaticBroadcast_basic", "MatmulStaticBroadcast_basic",
"MaxPool2dStaticModule_basic", "MaxPool2dStaticModule_basic",
"MaxPool2dEmptyStrideStaticModule_basic",
"MaxPool3dStaticModule_basic",
"MaxPool3dEmptyStrideStaticModule_basic",
"MeanDimAllReduceModule_basic", "MeanDimAllReduceModule_basic",
"MeanDimEmptyDimModule_basic", "MeanDimEmptyDimModule_basic",
"MeanDimNoneDimModule_basic", "MeanDimNoneDimModule_basic",

View File

@ -961,14 +961,14 @@ def avg_pool2d(input: List[int], kernel_size: List[int], stride: List[int], padd
# TODO: This should be upstreamed. # TODO: This should be upstreamed.
# See https://github.com/pytorch/pytorch/pull/76889 for an example. # See https://github.com/pytorch/pytorch/pull/76889 for an example.
def avg_pool1d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool, count_include_pad: bool): def pool1d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool):
assert len(kernel_size) == 1, "avg_pool1d: kernel_size must be a single int" assert len(kernel_size) == 1, "pool1d: kernel_size must be a single int"
kL = kernel_size[0] kL = kernel_size[0]
assert len(stride) == 0 or len(stride) == 1, "avg_pool1d: stride must either be omitted, or a single int" assert len(stride) == 0 or len(stride) == 1, "pool1d: stride must either be omitted, or a single int"
dL = kL if len(stride) == 0 else stride[0] dL = kL if len(stride) == 0 else stride[0]
assert len(padding) == 1, "avg_pool1d: padding must be a single int" assert len(padding) == 1, "pool1d: padding must be a single int"
padL = padding[0] padL = padding[0]
dilationL = 1 dilationL = 1
@ -1004,7 +1004,10 @@ def adaptive_avg_pool1d(self: List[int], out: List[int]):
return shape return shape
def atenavg_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> List[int]: def atenavg_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> List[int]:
return avg_pool1d(self, kernel_size, stride, padding, ceil_mode, count_include_pad) return pool1d(self, kernel_size, stride, padding, ceil_mode)
def atenmax_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), dilation: List[int] = (1,), ceil_mode: bool = False) -> List[int]:
return pool1d(self, kernel_size, stride, padding, ceil_mode)
def atenadaptive_avg_pool1d〡shape(self: List[int], output_size: List[int]) -> List[int]: def atenadaptive_avg_pool1d〡shape(self: List[int], output_size: List[int]) -> List[int]:
return adaptive_avg_pool1d(self, output_size) return adaptive_avg_pool1d(self, output_size)

View File

@ -591,6 +591,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit( emit(
"aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)" "aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
) )
emit("aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)") emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
emit( emit(
"aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)" "aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"