mirror of https://github.com/llvm/torch-mlir
[Stablehlo] Improve the lowering of pool op in stablehlo (#3259)
1. Handle case stride == None 2. add avgpool3d maxpool1d maxpool3d loweringpull/3269/head
parent
fb8aed0907
commit
f32ada993d
|
@ -6637,6 +6637,34 @@ def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenMaxPool1dOp : Torch_Op<"aten.max_pool1d", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchListOfTorchIntType:$kernel_size,
|
||||
AnyTorchListOfTorchIntType:$stride,
|
||||
AnyTorchListOfTorchIntType:$padding,
|
||||
AnyTorchListOfTorchIntType:$dilation,
|
||||
Torch_BoolType:$ceil_mode
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenMaxPool1dOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 6, 1);
|
||||
}
|
||||
void AtenMaxPool1dOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 6, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -36,7 +36,7 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
|
|||
auto constType = RankedTensorType::get({}, elementTy);
|
||||
// Avg pooling
|
||||
if (isa<AtenAvgPool1dOp, AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp,
|
||||
AtenCumsumOp>(op)) {
|
||||
AtenAvgPool3dOp, AtenCumsumOp>(op)) {
|
||||
if (isa<mlir::FloatType>(elementTy)) {
|
||||
auto constAttr = DenseElementsAttr::get(
|
||||
constType, {APFloat::getZero(
|
||||
|
@ -54,7 +54,8 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
|
|||
}
|
||||
|
||||
// Max pooling
|
||||
if (isa<AtenMaxPool2dOp, AtenMaxPool2dWithIndicesOp>(op)) {
|
||||
if (isa<AtenMaxPool1dOp, AtenMaxPool2dOp, AtenMaxPool3dOp,
|
||||
AtenMaxPool2dWithIndicesOp>(op)) {
|
||||
if (isa<mlir::FloatType>(elementTy)) {
|
||||
auto constAttr = DenseElementsAttr::get(
|
||||
constType,
|
||||
|
@ -75,101 +76,6 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
// AtenMaxPool2dOp
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenMaxPool2dOp>::matchAndRewrite(
|
||||
AtenMaxPool2dOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||
auto inputElemTy = inputTy.getElementType();
|
||||
|
||||
auto inputRank = inputTy.getRank();
|
||||
auto outTy =
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
|
||||
if (inputRank <= 2) {
|
||||
return op.emitError(
|
||||
"max_pooling2d only supports inputs with rank higher than 2");
|
||||
}
|
||||
SmallVector<int64_t, 2> padding, kernelSize, stride, dilation;
|
||||
bool ceilMode = false;
|
||||
|
||||
if (!(matchPattern(op.getKernelSize(),
|
||||
m_TorchListOfConstantInts(kernelSize)))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "non-const int kernel size unsupported!");
|
||||
}
|
||||
if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) {
|
||||
return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!");
|
||||
}
|
||||
if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"non-const int padding unsupported!");
|
||||
}
|
||||
if (!(matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilation)))) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"non-const int dilation unsupported!");
|
||||
}
|
||||
if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"non-const bool ceil_mode unsupported!");
|
||||
}
|
||||
|
||||
// prepend 1 to kernelSize, stride, dilation until they are of same rank as
|
||||
// input
|
||||
SmallVector<int64_t> stablehloStride(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloDilation(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
|
||||
std::copy(dilation.begin(), dilation.end(),
|
||||
stablehloDilation.begin() + inputRank - 2);
|
||||
std::copy(stride.begin(), stride.end(),
|
||||
stablehloStride.begin() + inputRank - 2);
|
||||
std::copy(kernelSize.begin(), kernelSize.end(),
|
||||
stablehloKernelSize.begin() + inputRank - 2);
|
||||
|
||||
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
|
||||
|
||||
stablehloPadding[stablehloPadding.size() - 4] = padding[0];
|
||||
stablehloPadding[stablehloPadding.size() - 3] = padding[0];
|
||||
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
|
||||
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
|
||||
|
||||
auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
|
||||
auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
|
||||
DenseI64ArrayAttr baseDilations;
|
||||
auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
|
||||
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(
|
||||
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
|
||||
rewriter.getI64Type()),
|
||||
stablehloPadding);
|
||||
auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
|
||||
op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
|
||||
baseDilations, windowDilations, pad);
|
||||
|
||||
Block &block = reduceWindowOp.getBody().emplaceBlock();
|
||||
|
||||
auto blockArgumentTy = RankedTensorType::get({}, inputElemTy);
|
||||
block.addArgument(blockArgumentTy, op->getLoc());
|
||||
block.addArgument(blockArgumentTy, op->getLoc());
|
||||
|
||||
auto *firstArg = block.args_begin();
|
||||
auto secondArg = block.args_rbegin();
|
||||
|
||||
{
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(&block);
|
||||
Value result =
|
||||
rewriter.create<stablehlo::MaxOp>(op->getLoc(), *firstArg, *secondArg);
|
||||
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), result);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, reduceWindowOp.getResults());
|
||||
return success();
|
||||
}
|
||||
|
||||
// AtenMaxPool2dWithIndicesOp
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
||||
|
@ -356,6 +262,129 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
template <typename AtenOpT, int Dim>
|
||||
class ConvertAtenMaxPoolOp : public ConvertAtenOp<AtenOpT> {
|
||||
public:
|
||||
using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
|
||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||
auto inputElemTy = inputTy.getElementType();
|
||||
auto inputRank = inputTy.getRank();
|
||||
auto outTy = cast<RankedTensorType>(
|
||||
ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType()));
|
||||
|
||||
if (inputRank <= Dim) {
|
||||
return op.emitError(
|
||||
"max_pooling1d/2d only supports inputs with rank higher than 1/2");
|
||||
}
|
||||
SmallVector<int64_t, Dim> padding, kernelSize, stride, dilation;
|
||||
bool ceilMode = false;
|
||||
|
||||
if (!(matchPattern(op.getKernelSize(),
|
||||
m_TorchListOfConstantInts(kernelSize)))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "non-const int kernel size unsupported!");
|
||||
}
|
||||
if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"non-const int stride unsupported!");
|
||||
}
|
||||
if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"non-const int padding unsupported!");
|
||||
}
|
||||
if (!(matchPattern(op.getDilation(),
|
||||
m_TorchListOfConstantInts(dilation)))) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"non-const int dilation unsupported!");
|
||||
}
|
||||
if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "non-const bool ceil_mode unsupported!");
|
||||
}
|
||||
|
||||
if (stride.empty()) {
|
||||
stride = kernelSize;
|
||||
}
|
||||
|
||||
// prepend 1 to kernelSize, stride, dilation until they are of same rank
|
||||
// as input
|
||||
SmallVector<int64_t> stablehloStride(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloDilation(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
|
||||
std::copy(dilation.begin(), dilation.end(),
|
||||
stablehloDilation.begin() + inputRank - Dim);
|
||||
std::copy(stride.begin(), stride.end(),
|
||||
stablehloStride.begin() + inputRank - Dim);
|
||||
std::copy(kernelSize.begin(), kernelSize.end(),
|
||||
stablehloKernelSize.begin() + inputRank - Dim);
|
||||
|
||||
Value initVal =
|
||||
createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
|
||||
|
||||
if (Dim == 1) {
|
||||
stablehloPadding[stablehloPadding.size() - 2] = padding[0];
|
||||
stablehloPadding[stablehloPadding.size() - 1] = padding[0];
|
||||
} else if (Dim == 2) {
|
||||
stablehloPadding[stablehloPadding.size() - 4] = padding[0];
|
||||
stablehloPadding[stablehloPadding.size() - 3] = padding[0];
|
||||
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
|
||||
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
|
||||
} else if (Dim == 3) {
|
||||
stablehloPadding[stablehloPadding.size() - 6] = padding[0];
|
||||
stablehloPadding[stablehloPadding.size() - 5] = padding[0];
|
||||
stablehloPadding[stablehloPadding.size() - 4] = padding[1];
|
||||
stablehloPadding[stablehloPadding.size() - 3] = padding[1];
|
||||
stablehloPadding[stablehloPadding.size() - 2] = padding[2];
|
||||
stablehloPadding[stablehloPadding.size() - 1] = padding[2];
|
||||
} else {
|
||||
assert(false && "Unsupported pooling dimension");
|
||||
}
|
||||
auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
|
||||
auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
|
||||
DenseI64ArrayAttr baseDilations;
|
||||
auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
|
||||
|
||||
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(
|
||||
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
|
||||
rewriter.getI64Type()),
|
||||
stablehloPadding);
|
||||
|
||||
auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
|
||||
op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
|
||||
baseDilations, windowDilations, pad);
|
||||
|
||||
Block &block = reduceWindowOp.getBody().emplaceBlock();
|
||||
|
||||
// Add bb argument
|
||||
auto blockArgumentType = RankedTensorType::get({}, inputElemTy);
|
||||
block.addArgument(blockArgumentType, op->getLoc());
|
||||
block.addArgument(blockArgumentType, op->getLoc());
|
||||
auto *firstArg = block.args_begin();
|
||||
auto secondArg = block.args_rbegin();
|
||||
|
||||
{
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(&block);
|
||||
|
||||
Value result = rewriter.create<stablehlo::MaxOp>(op->getLoc(), *firstArg,
|
||||
*secondArg);
|
||||
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), result);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, reduceWindowOp.getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
template <typename AtenOpT, int Dim>
|
||||
class ConvertAtenAvgPoolOp : public ConvertAtenOp<AtenOpT> {
|
||||
|
@ -375,8 +404,8 @@ public:
|
|||
auto outShape = outTy.getShape();
|
||||
|
||||
if (inputRank <= Dim) {
|
||||
return op.emitError(
|
||||
"avg_pooling1d/2d only supports inputs with rank higher than 1/2");
|
||||
return op.emitError("avg_pooling1d/2d/3d only supports inputs with rank "
|
||||
"higher than 1/2/3");
|
||||
}
|
||||
SmallVector<int64_t, Dim> padding, kernelSize, stride;
|
||||
bool ceilMode = false;
|
||||
|
@ -405,6 +434,10 @@ public:
|
|||
op, "non-const bool count_include_pad unsupported!");
|
||||
}
|
||||
|
||||
if (stride.empty()) {
|
||||
stride = kernelSize;
|
||||
}
|
||||
|
||||
if constexpr (std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
|
||||
if (succeeded(checkNotNone(rewriter, op, op.getDivisorOverride())))
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -425,11 +458,20 @@ public:
|
|||
if (Dim == 1) {
|
||||
stablehloPadding[stablehloPadding.size() - 2] = padding[0];
|
||||
stablehloPadding[stablehloPadding.size() - 1] = padding[0];
|
||||
} else {
|
||||
} else if (Dim == 2) {
|
||||
stablehloPadding[stablehloPadding.size() - 4] = padding[0];
|
||||
stablehloPadding[stablehloPadding.size() - 3] = padding[0];
|
||||
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
|
||||
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
|
||||
} else if (Dim == 3) {
|
||||
stablehloPadding[stablehloPadding.size() - 6] = padding[0];
|
||||
stablehloPadding[stablehloPadding.size() - 5] = padding[0];
|
||||
stablehloPadding[stablehloPadding.size() - 4] = padding[1];
|
||||
stablehloPadding[stablehloPadding.size() - 3] = padding[1];
|
||||
stablehloPadding[stablehloPadding.size() - 2] = padding[2];
|
||||
stablehloPadding[stablehloPadding.size() - 1] = padding[2];
|
||||
} else {
|
||||
assert(false && "Unsupported pooling dimension");
|
||||
}
|
||||
|
||||
Value initVal =
|
||||
|
@ -474,10 +516,17 @@ public:
|
|||
divisor =
|
||||
hlo::getConstTensor<int64_t>(rewriter, op, {kernelSize[0]}, {})
|
||||
.value();
|
||||
} else {
|
||||
} else if (Dim == 2) {
|
||||
divisor = hlo::getConstTensor<int64_t>(
|
||||
rewriter, op, {kernelSize[0] * kernelSize[1]}, {})
|
||||
.value();
|
||||
} else if (Dim == 3) {
|
||||
divisor = hlo::getConstTensor<int64_t>(
|
||||
rewriter, op,
|
||||
{kernelSize[0] * kernelSize[1] * kernelSize[2]}, {})
|
||||
.value();
|
||||
} else {
|
||||
assert(false && "Unsupported pooling dimension");
|
||||
}
|
||||
divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy);
|
||||
DenseI64ArrayAttr bcastDimensions;
|
||||
|
@ -611,22 +660,28 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
|
|||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
target.addIllegalOp<AtenAvgPool1dOp>();
|
||||
patterns.add<ConvertAtenOp<AtenAvgPool1dOp>>(typeConverter, context, options);
|
||||
target.addIllegalOp<AtenMaxPool2dOp>();
|
||||
patterns.add<ConvertAtenOp<AtenMaxPool2dOp>>(typeConverter, context, options);
|
||||
target.addIllegalOp<AtenAvgPool2dOp>();
|
||||
patterns.add<ConvertAtenOp<AtenAvgPool2dOp>>(typeConverter, context, options);
|
||||
target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
|
||||
patterns.add<ConvertAtenOp<AtenMaxPool2dWithIndicesOp>>(typeConverter,
|
||||
context, options);
|
||||
target.addIllegalOp<AtenCumsumOp>();
|
||||
patterns.add<ConvertAtenOp<AtenCumsumOp>>(typeConverter, context, options);
|
||||
#define INSERT_ATEN_POOLING_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
|
||||
INSERT_ATEN_POOLING_PATTERN(AtenMaxPool2dWithIndicesOp);
|
||||
INSERT_ATEN_POOLING_PATTERN(AtenCumsumOp);
|
||||
#undef INSERT_ATEN_POOLING_PATTERN
|
||||
|
||||
#define INSERT_ATEN_MAXPOOL_PATTERN(AtenOp, Dim) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenMaxPoolOp<AtenOp, Dim>>(typeConverter, context, \
|
||||
options)
|
||||
INSERT_ATEN_MAXPOOL_PATTERN(AtenMaxPool1dOp, 1);
|
||||
INSERT_ATEN_MAXPOOL_PATTERN(AtenMaxPool2dOp, 2);
|
||||
INSERT_ATEN_MAXPOOL_PATTERN(AtenMaxPool3dOp, 3);
|
||||
#undef INSERT_ATEN_MAXPOOL_PATTERN
|
||||
|
||||
#define INSERT_ATEN_AVGPOOL_PATTERN(AtenOp, Dim) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenAvgPoolOp<AtenOp, Dim>>(typeConverter, context, \
|
||||
options)
|
||||
INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool1dOp, 1);
|
||||
INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool2dOp, 2);
|
||||
INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool3dOp, 3);
|
||||
#undef INSERT_ATEN_AVGPOOL_PATTERN
|
||||
}
|
||||
|
|
|
@ -7845,19 +7845,19 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" return %arg2 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.avg_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.avg_pool1d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool) -> !torch.list<int>\n"
|
||||
" %0 = call @__torch__.pool1d(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @__torch__.avg_pool1d(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list<int> {\n"
|
||||
" func.func @__torch__.pool1d(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool) -> !torch.list<int> {\n"
|
||||
" %int-1 = torch.constant.int -1\n"
|
||||
" %int-2 = torch.constant.int -2\n"
|
||||
" %int-3 = torch.constant.int -3\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %str_0 = torch.constant.str \"AssertionError: avg_pool1d: padding must be a single int\"\n"
|
||||
" %str_1 = torch.constant.str \"AssertionError: avg_pool1d: stride must either be omitted, or a single int\"\n"
|
||||
" %str_0 = torch.constant.str \"AssertionError: pool1d: padding must be a single int\"\n"
|
||||
" %str_1 = torch.constant.str \"AssertionError: pool1d: stride must either be omitted, or a single int\"\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str_2 = torch.constant.str \"AssertionError: avg_pool1d: kernel_size must be a single int\"\n"
|
||||
" %str_2 = torch.constant.str \"AssertionError: pool1d: kernel_size must be a single int\"\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %int2 = torch.constant.int 2\n"
|
||||
|
@ -7940,6 +7940,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %23 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.max_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.pool1d(%arg0, %arg1, %arg2, %arg3, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.adaptive_avg_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.adaptive_avg_pool1d(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
|
|
@ -1075,6 +1075,9 @@ STABLEHLO_PASS_SET = {
|
|||
"Matmul_vecmat",
|
||||
"MatmulStaticBroadcast_basic",
|
||||
"MaxPool2dStaticModule_basic",
|
||||
"MaxPool2dEmptyStrideStaticModule_basic",
|
||||
"MaxPool3dStaticModule_basic",
|
||||
"MaxPool3dEmptyStrideStaticModule_basic",
|
||||
"MeanDimAllReduceModule_basic",
|
||||
"MeanDimEmptyDimModule_basic",
|
||||
"MeanDimNoneDimModule_basic",
|
||||
|
|
|
@ -961,14 +961,14 @@ def avg_pool2d(input: List[int], kernel_size: List[int], stride: List[int], padd
|
|||
|
||||
# TODO: This should be upstreamed.
|
||||
# See https://github.com/pytorch/pytorch/pull/76889 for an example.
|
||||
def avg_pool1d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool, count_include_pad: bool):
|
||||
assert len(kernel_size) == 1, "avg_pool1d: kernel_size must be a single int"
|
||||
def pool1d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool):
|
||||
assert len(kernel_size) == 1, "pool1d: kernel_size must be a single int"
|
||||
kL = kernel_size[0]
|
||||
|
||||
assert len(stride) == 0 or len(stride) == 1, "avg_pool1d: stride must either be omitted, or a single int"
|
||||
assert len(stride) == 0 or len(stride) == 1, "pool1d: stride must either be omitted, or a single int"
|
||||
dL = kL if len(stride) == 0 else stride[0]
|
||||
|
||||
assert len(padding) == 1, "avg_pool1d: padding must be a single int"
|
||||
assert len(padding) == 1, "pool1d: padding must be a single int"
|
||||
padL = padding[0]
|
||||
|
||||
dilationL = 1
|
||||
|
@ -1004,7 +1004,10 @@ def adaptive_avg_pool1d(self: List[int], out: List[int]):
|
|||
return shape
|
||||
|
||||
def aten〇avg_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> List[int]:
|
||||
return avg_pool1d(self, kernel_size, stride, padding, ceil_mode, count_include_pad)
|
||||
return pool1d(self, kernel_size, stride, padding, ceil_mode)
|
||||
|
||||
def aten〇max_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), dilation: List[int] = (1,), ceil_mode: bool = False) -> List[int]:
|
||||
return pool1d(self, kernel_size, stride, padding, ceil_mode)
|
||||
|
||||
def aten〇adaptive_avg_pool1d〡shape(self: List[int], output_size: List[int]) -> List[int]:
|
||||
return adaptive_avg_pool1d(self, output_size)
|
||||
|
|
|
@ -591,6 +591,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit(
|
||||
"aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
|
||||
)
|
||||
emit("aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
|
||||
emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
|
||||
emit(
|
||||
"aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
|
||||
|
|
Loading…
Reference in New Issue