[Stablehlo]Add support for AvgPool1dOp (#2268)

* Add support for AvgPool1d

* Update AbstractInterpLibrary

* support avgpool1d in linalg

* refactored code

* fix nit problem
pull/2294/head
JianzheXiao 2023-07-25 14:09:53 +08:00 committed by GitHub
parent d57f67e7f8
commit 31ef08b63d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 489 additions and 198 deletions

View File

@ -595,6 +595,7 @@ STABLEHLO_PASS_SET = {
"ViewOffsetBackwardTestStaticModule_basic", "ViewOffsetBackwardTestStaticModule_basic",
"NumToTensorFloatModule_basic", "NumToTensorFloatModule_basic",
"AtenToDeviceModule_basic", "AtenToDeviceModule_basic",
"AvgPool1dStaticModule_basic",
"AvgPool2dStaticModule_basic", "AvgPool2dStaticModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_basic", "Conv2dWithPaddingDilationStrideStaticModule_basic",
"Convolution2DStaticModule_basic", "Convolution2DStaticModule_basic",

View File

@ -5045,6 +5045,34 @@ def Torch_AtenAvgPool2dOp : Torch_Op<"aten.avg_pool2d", [
}]; }];
} }
def Torch_AtenAvgPool1dOp : Torch_Op<"aten.avg_pool1d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::avg_pool1d : (Tensor, int[], int[], int[], bool, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$kernel_size,
AnyTorchListOfTorchIntType:$stride,
AnyTorchListOfTorchIntType:$padding,
Torch_BoolType:$ceil_mode,
Torch_BoolType:$count_include_pad
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenAvgPool1dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenAvgPool1dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}
def Torch_AtenSoftmaxIntOp : Torch_Op<"aten.softmax.int", [ def Torch_AtenSoftmaxIntOp : Torch_Op<"aten.softmax.int", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,

View File

@ -39,7 +39,7 @@ checkAndGetPoolingParameters(OpTy op, ConversionPatternRewriter &rewriter,
// Pattern match against the op's original operands, because otherwise we // Pattern match against the op's original operands, because otherwise we
// will get the lowered version of the operands which is harder to pattern // will get the lowered version of the operands which is harder to pattern
// match. // match.
SmallVector<Value, 2> kernelSizeTorchInt; SmallVector<Value> kernelSizeTorchInt;
if (!getListConstructElements(op.getKernelSize(), kernelSizeTorchInt)) { if (!getListConstructElements(op.getKernelSize(), kernelSizeTorchInt)) {
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"unimplemented: the kernel size is " "unimplemented: the kernel size is "
@ -72,12 +72,13 @@ checkAndGetPoolingParameters(OpTy op, ConversionPatternRewriter &rewriter,
return success(); return success();
} }
// Creates a pooling operation based on the type specified by `OpTy` and // Creates a pooling operation based on the type specified by `OpTy` and
// arguments passed. // arguments passed.
template <typename OpTy> template <typename OpTy>
static LogicalResult createPoolingOp( static LogicalResult createPoolingOp(
Operation *op, ConversionPatternRewriter &rewriter, Value self, Operation *op, ConversionPatternRewriter &rewriter, Value self,
bool supportNonFPInput, bool ceilMode, bool supportNonFPInput, bool ceilMode, int64_t dimensionality,
SmallVectorImpl<Value> &kernelSizeIntValues, SmallVectorImpl<Value> &kernelSizeIntValues,
SmallVectorImpl<int64_t> &strideInts, SmallVectorImpl<int64_t> &paddingInts, SmallVectorImpl<int64_t> &strideInts, SmallVectorImpl<int64_t> &paddingInts,
SmallVectorImpl<int64_t> &dilationInts, Attribute initValueAttr, SmallVectorImpl<int64_t> &dilationInts, Attribute initValueAttr,
@ -87,13 +88,16 @@ static LogicalResult createPoolingOp(
if (!elementType.isa<mlir::FloatType>() && !supportNonFPInput) if (!elementType.isa<mlir::FloatType>() && !supportNonFPInput)
return op->emitError("unimplemented: non-floating point type"); return op->emitError("unimplemented: non-floating point type");
SmallVector<int64_t, 4> lowPaddingIncludingNC = {0, 0}; SmallVector<int64_t> lowPaddingIncludingNC = {0, 0};
lowPaddingIncludingNC.append(paddingInts); lowPaddingIncludingNC.append(paddingInts);
SmallVector<int64_t, 4> highPaddingIncludingNC = lowPaddingIncludingNC; SmallVector<int64_t> highPaddingIncludingNC = lowPaddingIncludingNC;
if (ceilMode) { if (ceilMode) {
highPaddingIncludingNC[2] += strideInts[0]; for (int64_t i = 0; i < dimensionality; ++i) {
highPaddingIncludingNC[3] += strideInts[1]; highPaddingIncludingNC[i + 2] += strideInts[i];
} }
}
Value initValue = rewriter.create<arith::ConstantOp>(loc, cast<TypedAttr>(initValueAttr)); Value initValue = rewriter.create<arith::ConstantOp>(loc, cast<TypedAttr>(initValueAttr));
paddedInput = torch_to_linalg::getPaddedTensor( paddedInput = torch_to_linalg::getPaddedTensor(
op, rewriter, self, lowPaddingIncludingNC, highPaddingIncludingNC, op, rewriter, self, lowPaddingIncludingNC, highPaddingIncludingNC,
@ -101,8 +105,6 @@ static LogicalResult createPoolingOp(
Value N = getDimOp(rewriter, loc, self, 0); Value N = getDimOp(rewriter, loc, self, 0);
Value C = getDimOp(rewriter, loc, self, 1); Value C = getDimOp(rewriter, loc, self, 1);
Value H = getDimOp(rewriter, loc, self, 2);
Value W = getDimOp(rewriter, loc, self, 3);
SmallVector<Value> paddingIntValues = SmallVector<Value> paddingIntValues =
getAsConstantIntValues(rewriter, loc, paddingInts); getAsConstantIntValues(rewriter, loc, paddingInts);
@ -111,15 +113,17 @@ static LogicalResult createPoolingOp(
SmallVector<Value> strideIntValues = SmallVector<Value> strideIntValues =
getAsConstantIntValues(rewriter, loc, strideInts); getAsConstantIntValues(rewriter, loc, strideInts);
Value hOut = torch_to_linalg::getOutputDimForConvOps( // Get dimension size for each dimension and calculate output size
rewriter, loc, H, paddingIntValues[0], dilationIntValues[0], for (int64_t i = dimensionality - 1; i > -1; --i) {
kernelSizeIntValues[0], strideIntValues[0], ceilMode); Value dimSize = getDimOp(rewriter, loc, self, i + 2);
Value wOut = torch_to_linalg::getOutputDimForConvOps( Value outDim = torch_to_linalg::getOutputDimForConvOps(
rewriter, loc, W, paddingIntValues[1], dilationIntValues[1], rewriter, loc, dimSize, paddingIntValues[i], dilationIntValues[i],
kernelSizeIntValues[1], strideIntValues[1], ceilMode); kernelSizeIntValues[i], strideIntValues[i], ceilMode);
outTensorShape.insert(outTensorShape.begin(), {outDim});
}
// Create output tensor initialized with smallest floating point value. // Create output tensor initialized with smallest floating point value.
outTensorShape.insert(outTensorShape.begin(), {N, C, hOut, wOut}); outTensorShape.insert(outTensorShape.begin(), {N, C});
Value outTensorInitialized = Value outTensorInitialized =
createInitTensor(rewriter, loc, outTensorShape, elementType, initValue); createInitTensor(rewriter, loc, outTensorShape, elementType, initValue);
@ -138,6 +142,7 @@ static LogicalResult createPoolingOp(
return success(); return success();
} }
namespace { namespace {
class ConvertAtenMaxPool2dOp : public OpConversionPattern<AtenMaxPool2dOp> { class ConvertAtenMaxPool2dOp : public OpConversionPattern<AtenMaxPool2dOp> {
public: public:
@ -177,8 +182,9 @@ public:
Value maxPool2d, paddedInput; Value maxPool2d, paddedInput;
if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>( if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(
op, rewriter, self, /*supportNonFPInput=*/false, ceilMode, op, rewriter, self, /*supportNonFPInput=*/false, ceilMode,
kernelSizeIntValues, strideInts, paddingInts, dilationInts, /*dimensionality=*/2, kernelSizeIntValues, strideInts, paddingInts,
smallestFPValueAttr, outTensorShape, paddedInput, maxPool2d))) dilationInts, smallestFPValueAttr, outTensorShape, paddedInput,
maxPool2d)))
return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d"); return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");
Type newResultType = getTypeConverter()->convertType(op.getType()); Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, maxPool2d); rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, maxPool2d);
@ -253,8 +259,9 @@ public:
SmallVector<Value, 4> outTensorShape; SmallVector<Value, 4> outTensorShape;
if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>( if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(
op, rewriter, self, /*supportNonFPInput=*/false, ceilMode, op, rewriter, self, /*supportNonFPInput=*/false, ceilMode,
kernelSizeIntValues, strideInts, paddingInts, dilationInts, /*dimensionality=*/2, kernelSizeIntValues, strideInts, paddingInts,
smallestFPValueAttr, outTensorShape, paddedInput, maxPool2d))) dilationInts, smallestFPValueAttr, outTensorShape, paddedInput,
maxPool2d)))
return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d"); return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");
Value cstMinusOne = Value cstMinusOne =
@ -366,29 +373,32 @@ public:
}; };
} // namespace } // namespace
namespace { namespace {
class ConvertAtenAvgPool2dOp : public OpConversionPattern<AtenAvgPool2dOp> { template <typename OpTy, typename PoolingOpTy, int Dim>
class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
public: public:
using OpConversionPattern::OpConversionPattern; using OpConversionPattern<OpTy>::OpConversionPattern;
LogicalResult LogicalResult
matchAndRewrite(AtenAvgPool2dOp op, OpAdaptor adaptor, matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter))) if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure(); return failure();
Location loc = op->getLoc(); Location loc = op->getLoc();
TypeConverter *typeConverter = getTypeConverter(); TypeConverter *typeConverter = this->getTypeConverter();
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
Type inputElementType = Type inputElementType =
self.getType().cast<RankedTensorType>().getElementType(); self.getType().cast<RankedTensorType>().getElementType();
Type resultType = getTypeConverter()->convertType(op.getType()); Type resultType = typeConverter->convertType(op.getType());
Type resultElementType = Type resultElementType =
resultType.cast<RankedTensorType>().getElementType(); resultType.cast<RankedTensorType>().getElementType();
bool ceilMode; bool ceilMode;
SmallVector<Value, 2> kernelSizeIntValues; SmallVector<Value, Dim> kernelSizeIntValues;
SmallVector<int64_t, 2> strideInts, paddingInts, dilationInts{1, 1}; SmallVector<int64_t, Dim> strideInts, paddingInts, dilationInts(Dim, 1);
if (failed(checkAndGetPoolingParameters<AtenAvgPool2dOp>( if (failed(checkAndGetPoolingParameters<OpTy>(
op, rewriter, typeConverter, ceilMode, kernelSizeIntValues, op, rewriter, typeConverter, ceilMode, kernelSizeIntValues,
strideInts, paddingInts))) strideInts, paddingInts)))
return rewriter.notifyMatchFailure(op, "invalid pooling parameters"); return rewriter.notifyMatchFailure(op, "invalid pooling parameters");
@ -404,34 +414,36 @@ public:
op, "unimplemented: count_include_pad is expected to be true"); op, "unimplemented: count_include_pad is expected to be true");
} }
// `sumPool2d` contains the result of sumpool2d operation over the input. // `sumPool` contains the result of sumpool operation over the input.
Value sumPool2d, paddedInput; Value sumPool, paddedInput;
SmallVector<Value, 4> outTensorShape; SmallVector<Value, Dim+2> outTensorShape;
if (failed(createPoolingOp<linalg::PoolingNchwSumOp>( if (failed(createPoolingOp<PoolingOpTy>(
op, rewriter, self, /*supportNonFPInput=*/true, ceilMode, op, rewriter, self, /*supportNonFPInput=*/true, ceilMode,
kernelSizeIntValues, strideInts, paddingInts, dilationInts, /*dimensionality=*/Dim, kernelSizeIntValues, strideInts, paddingInts,
rewriter.getZeroAttr(inputElementType), outTensorShape, paddedInput, dilationInts, rewriter.getZeroAttr(inputElementType), outTensorShape,
sumPool2d))) paddedInput, sumPool)))
return rewriter.notifyMatchFailure(op, "unable to compute sumpool2d"); return rewriter.notifyMatchFailure(op, "unable to compute sumpool");
Value divisor;
if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) {
Value kHtimeskW = rewriter.create<arith::MulIOp>( Value kHtimeskW = rewriter.create<arith::MulIOp>(
loc, kernelSizeIntValues[0], kernelSizeIntValues[1]); loc, kernelSizeIntValues[0], kernelSizeIntValues[1]);
Value divisor = op.getDivisorOverride().getType().isa<Torch::NoneType>() divisor = op.getDivisorOverride().getType().template isa<Torch::NoneType>()
? kHtimeskW ? kHtimeskW
: adaptor.getDivisorOverride(); : adaptor.getDivisorOverride();
} else {
divisor = kernelSizeIntValues[0];
}
divisor = convertScalarToDtype(rewriter, loc, divisor, resultElementType); divisor = convertScalarToDtype(rewriter, loc, divisor, resultElementType);
Value outputTensor = rewriter.create<tensor::EmptyOp>( Value outputTensor = rewriter.create<tensor::EmptyOp>(
loc, getAsOpFoldResult(outTensorShape), resultElementType); loc, getAsOpFoldResult(outTensorShape), resultElementType);
SmallVector<AffineMap> indexingMapsAvg(2, SmallVector<AffineMap> indexingMapsAvg(2, rewriter.getMultiDimIdentityMap(Dim+2));
rewriter.getMultiDimIdentityMap(4));
SmallVector<utils::IteratorType> iteratorTypesAvg( SmallVector<utils::IteratorType> iteratorTypesAvg(
4, utils::IteratorType::parallel); Dim+2, utils::IteratorType::parallel);
Value avgPool =
Value avgPool2d =
rewriter rewriter
.create<linalg::GenericOp>( .create<linalg::GenericOp>(
loc, outputTensor.getType(), sumPool2d, outputTensor, loc, outputTensor.getType(), sumPool, outputTensor,
/*indexingMaps=*/indexingMapsAvg, /*indexingMaps=*/indexingMapsAvg,
/*iteratorTypes=*/iteratorTypesAvg, /*iteratorTypes=*/iteratorTypesAvg,
[&](OpBuilder &b, Location loc, ValueRange args) { [&](OpBuilder &b, Location loc, ValueRange args) {
@ -444,11 +456,12 @@ public:
}) })
.getResult(0); .getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, avgPool2d); rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, avgPool);
return success(); return success();
} }
}; };
} // namespace }
void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality( void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns, TypeConverter &typeConverter, RewritePatternSet &patterns,
@ -458,6 +471,9 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
patterns.add<ConvertAtenMaxPool2dOp>(typeConverter, context); patterns.add<ConvertAtenMaxPool2dOp>(typeConverter, context);
target.addIllegalOp<AtenMaxPool2dWithIndicesOp>(); target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
patterns.add<ConvertAtenMaxPool2dWithIndicesOp>(typeConverter, context); patterns.add<ConvertAtenMaxPool2dWithIndicesOp>(typeConverter, context);
target.addIllegalOp<AtenAvgPool2dOp>(); target.addIllegalOp<AtenAvgPool1dOp, AtenAvgPool2dOp>();
patterns.add<ConvertAtenAvgPool2dOp>(typeConverter, context); patterns.add<ConvertAtenAvgPoolOp<AtenAvgPool1dOp, linalg::PoolingNcwSumOp, 1>>(
typeConverter, context);
patterns.add<ConvertAtenAvgPoolOp<AtenAvgPool2dOp, linalg::PoolingNchwSumOp, 2>>(
typeConverter, context);
} }

View File

@ -16,13 +16,13 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h" #include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
#include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
@ -35,7 +35,7 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
PatternRewriter &rewriter) { PatternRewriter &rewriter) {
auto constType = RankedTensorType::get({}, elementTy); auto constType = RankedTensorType::get({}, elementTy);
// Avg pooling // Avg pooling
if (isa<AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp, AtenCumsumOp>(op)) { if (isa<AtenAvgPool1dOp, AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp, AtenCumsumOp>(op)) {
if (elementTy.isa<mlir::FloatType>()) { if (elementTy.isa<mlir::FloatType>()) {
auto constAttr = DenseElementsAttr::get( auto constAttr = DenseElementsAttr::get(
constType, {APFloat::getZero( constType, {APFloat::getZero(
@ -373,24 +373,31 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
return success(); return success();
} }
// AtenAvgPool2dOp
template <> namespace {
LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite( template <typename AtenOpT, int Dim>
AtenAvgPool2dOp op, OpAdaptor adaptor, class ConvertAtenAvgPoolOp : public ConvertAtenOp<AtenOpT> {
ConversionPatternRewriter &rewriter) const { public:
using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
auto inputTy = input.getType().cast<RankedTensorType>(); RankedTensorType inputTy = input.getType().cast<RankedTensorType>();
auto inputElemTy = inputTy.getElementType(); Type inputElemTy = inputTy.getElementType();
auto inputRank = inputTy.getRank(); int64_t inputRank = inputTy.getRank();
auto outTy = RankedTensorType outTy = ConvertAtenOp<AtenOpT>::getTypeConverter()
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>(); ->convertType(op.getType())
.template cast<RankedTensorType>();
auto outShape = outTy.getShape(); auto outShape = outTy.getShape();
if (inputRank <= 2) {
if (inputRank <= Dim) {
return op.emitError( return op.emitError(
"avg_pooling2d only supports inputs with rank higher than 2"); "avg_pooling1d/2d only supports inputs with rank higher than 1/2");
} }
SmallVector<int64_t, 2> padding, kernelSize, stride; SmallVector<int64_t, Dim> padding, kernelSize, stride;
bool ceilMode = false; bool ceilMode = false;
bool countIncludePad = true; bool countIncludePad = true;
@ -415,26 +422,33 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "non-const bool count_include_pad unsupported!"); op, "non-const bool count_include_pad unsupported!");
} }
if (succeeded(checkNotNone(rewriter, op, op.getDivisorOverride()))) {
if constexpr (std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
if (succeeded(checkNotNone(rewriter, op, op.getDivisorOverride())))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "only None divisor_override supported for now!"); op, "only None divisor_override supported for now!");
} }
// prepend 1 to kernelSize, stride, dilation until they are of same rank as // Prepend 1 to kernelSize, stride, dilation until they are of same rank
// input // as input
SmallVector<int64_t> stablehloStride(inputRank, 1); SmallVector<int64_t> stablehloStride(inputRank, 1);
SmallVector<int64_t> stablehloDilation(inputRank, 1); SmallVector<int64_t> stablehloDilation(inputRank, 1);
SmallVector<int64_t> stablehloKernelSize(inputRank, 1); SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0); SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
std::copy(stride.begin(), stride.end(), std::copy(stride.begin(), stride.end(),
stablehloStride.begin() + inputRank - 2); stablehloStride.begin() + inputRank - Dim);
std::copy(kernelSize.begin(), kernelSize.end(), std::copy(kernelSize.begin(), kernelSize.end(),
stablehloKernelSize.begin() + inputRank - 2); stablehloKernelSize.begin() + inputRank - Dim);
if (Dim == 1) {
stablehloPadding[stablehloPadding.size() - 2] = padding[0];
stablehloPadding[stablehloPadding.size() - 1] = padding[0];
} else {
stablehloPadding[stablehloPadding.size() - 4] = padding[0]; stablehloPadding[stablehloPadding.size() - 4] = padding[0];
stablehloPadding[stablehloPadding.size() - 3] = padding[0]; stablehloPadding[stablehloPadding.size() - 3] = padding[0];
stablehloPadding[stablehloPadding.size() - 2] = padding[1]; stablehloPadding[stablehloPadding.size() - 2] = padding[1];
stablehloPadding[stablehloPadding.size() - 1] = padding[1]; stablehloPadding[stablehloPadding.size() - 1] = padding[1];
}
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
@ -467,23 +481,30 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
auto blockArgumentType = RankedTensorType::get({}, inputElemTy); auto blockArgumentType = RankedTensorType::get({}, inputElemTy);
sumBlock.addArgument(blockArgumentType, op->getLoc()); sumBlock.addArgument(blockArgumentType, op->getLoc());
sumBlock.addArgument(blockArgumentType, op->getLoc()); sumBlock.addArgument(blockArgumentType, op->getLoc());
auto *firstArg = sumBlock.args_begin(); auto firstArg = *sumBlock.args_begin();
auto secondArg = sumBlock.args_rbegin(); auto secondArg = *sumBlock.args_rbegin();
{ {
OpBuilder::InsertionGuard guard(rewriter); OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&sumBlock); rewriter.setInsertionPointToStart(&sumBlock);
Value sumResult = Value sumResult =
rewriter.create<stablehlo::AddOp>(op->getLoc(), *firstArg, *secondArg); rewriter.create<stablehlo::AddOp>(op->getLoc(), firstArg, secondArg);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult); rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult);
} }
// Use kernel size as the divisor // Use kernel size as the divisor
if (countIncludePad) { if (countIncludePad) {
Value divisor = hlo::getConstTensor<int64_t>( Value divisor;
if (Dim == 1) {
divisor =
hlo::getConstTensor<int64_t>(rewriter, op, {kernelSize[0]}, {})
.value();
} else {
divisor = hlo::getConstTensor<int64_t>(
rewriter, op, {kernelSize[0] * kernelSize[1]}, {}) rewriter, op, {kernelSize[0] * kernelSize[1]}, {})
.value(); .value();
}
divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy); divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy);
DenseIntElementsAttr bcastDimensions; DenseIntElementsAttr bcastDimensions;
rewriter.replaceOpWithNewOp<mlir::chlo::BroadcastDivOp>( rewriter.replaceOpWithNewOp<mlir::chlo::BroadcastDivOp>(
@ -491,12 +512,12 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
return success(); return success();
} }
// Use another stablehlo.ReduceWindowOp to get the divisor // Use another mhlo.ReduceWindowOp to get the divisor
Value windowSizeConst = Value windowSizeConst =
hlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value(); hlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
windowSizeConst = windowSizeConst =
hlo::promoteType(rewriter, op.getLoc(), windowSizeConst, outTy); hlo::promoteType(rewriter, op.getLoc(), windowSizeConst, outTy);
const auto &options = getOptions(); const auto &options = ConvertAtenOp<AtenOpT>::getOptions();
auto inputShapeVec = auto inputShapeVec =
*hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); *hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>( auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
@ -519,23 +540,28 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
blockArgumentType = RankedTensorType::get({}, inputElemTy); blockArgumentType = RankedTensorType::get({}, inputElemTy);
sizeBlock.addArgument(blockArgumentType, op->getLoc()); sizeBlock.addArgument(blockArgumentType, op->getLoc());
sizeBlock.addArgument(blockArgumentType, op->getLoc()); sizeBlock.addArgument(blockArgumentType, op->getLoc());
firstArg = sizeBlock.args_begin(); firstArg = *sizeBlock.args_begin();
secondArg = sizeBlock.args_rbegin(); secondArg = *sizeBlock.args_rbegin();
{ {
OpBuilder::InsertionGuard guard(rewriter); OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&sizeBlock); rewriter.setInsertionPointToStart(&sizeBlock);
Value sumResult = Value sumResult =
rewriter.create<stablehlo::AddOp>(op->getLoc(), *firstArg, *secondArg); rewriter.create<stablehlo::AddOp>(op->getLoc(), firstArg, secondArg);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult); rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult);
} }
rewriter.replaceOpWithNewOp<stablehlo::DivOp>( rewriter.replaceOpWithNewOp<stablehlo::DivOp>(
op, outTy, reduceWindowSum.getResult(0), reduceWindowSize.getResult(0)); op, outTy, reduceWindowSum.getResult(0), reduceWindowSize.getResult(0));
return success(); return success();
}
};
} }
// AtenCumsumOp // AtenCumsumOp
template <> template <>
LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite(
@ -621,6 +647,8 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns, TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options) { ConversionTarget &target, const TorchToStablehloOptions &options) {
MLIRContext *context = patterns.getContext(); MLIRContext *context = patterns.getContext();
target.addIllegalOp<AtenAvgPool1dOp>();
patterns.add<ConvertAtenOp<AtenAvgPool1dOp>>(typeConverter, context, options);
target.addIllegalOp<AtenMaxPool2dOp>(); target.addIllegalOp<AtenMaxPool2dOp>();
patterns.add<ConvertAtenOp<AtenMaxPool2dOp>>(typeConverter, context, options); patterns.add<ConvertAtenOp<AtenMaxPool2dOp>>(typeConverter, context, options);
target.addIllegalOp<AtenAvgPool2dOp>(); target.addIllegalOp<AtenAvgPool2dOp>();
@ -630,4 +658,11 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
context, options); context, options);
target.addIllegalOp<AtenCumsumOp>(); target.addIllegalOp<AtenCumsumOp>();
patterns.add<ConvertAtenOp<AtenCumsumOp>>(typeConverter, context, options); patterns.add<ConvertAtenOp<AtenCumsumOp>>(typeConverter, context, options);
#define INSERT_ATEN_AVGPOOL_PATTERN(AtenOp, Dim) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenAvgPoolOp<AtenOp, Dim>>( \
typeConverter, context, options)
INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool1dOp, 1);
INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool2dOp, 2);
#undef INSERT_ATEN_AVGPOOL_PATTERN
} }

View File

@ -6839,6 +6839,102 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<float>, %arg4: !torch.optional<float>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<float>, %arg4: !torch.optional<float>) -> !torch.list<int> {\n"
" return %arg2 : !torch.list<int>\n" " return %arg2 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.avg_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.avg_pool1d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @__torch__.avg_pool1d(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list<int> {\n"
" %int-1 = torch.constant.int -1\n"
" %int-2 = torch.constant.int -2\n"
" %int-3 = torch.constant.int -3\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %str_0 = torch.constant.str \"AssertionError: avg_pool1d: padding must be a single int\"\n"
" %str_1 = torch.constant.str \"AssertionError: avg_pool1d: stride must either be omitted, or a single int\"\n"
" %true = torch.constant.bool true\n"
" %none = torch.constant.none\n"
" %str_2 = torch.constant.str \"AssertionError: avg_pool1d: kernel_size must be a single int\"\n"
" %int1 = torch.constant.int 1\n"
" %int0 = torch.constant.int 0\n"
" %int2 = torch.constant.int 2\n"
" %int3 = torch.constant.int 3\n"
" %0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %2 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %3 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int\n"
" %4 = torch.aten.eq.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %5 = torch.prim.If %4 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %24 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int\n"
" %25 = torch.aten.eq.int %24, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %25 : !torch.bool\n"
" }\n"
" torch.prim.If %5 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %6 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int\n"
" %7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %8 = torch.prim.If %7 -> (!torch.int) {\n"
" torch.prim.If.yield %2 : !torch.int\n"
" } else {\n"
" %24 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" torch.prim.If.yield %24 : !torch.int\n"
" }\n"
" %9 = torch.aten.len.t %arg3 : !torch.list<int> -> !torch.int\n"
" %10 = torch.aten.eq.int %9, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %10 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %11 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %12 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %13 = torch.aten.eq.int %12, %int2 : !torch.int, !torch.int -> !torch.bool\n"
" %14 = torch.prim.If %13 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %24 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %25 = torch.aten.eq.int %24, %int3 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %25 : !torch.bool\n"
" }\n"
" torch.prim.If %14 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %15 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %16 = torch.aten.eq.int %15, %int3 : !torch.int, !torch.int -> !torch.bool\n"
" %17 = torch.prim.If %16 -> (!torch.int) {\n"
" %24 = torch.aten.__getitem__.t %arg0, %int-3 : !torch.list<int>, !torch.int -> !torch.int\n"
" torch.prim.If.yield %24 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %int1 : !torch.int\n"
" }\n"
" %18 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %19 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %20 = call @__torch__.torch.jit._shape_functions.pooling_output_shape(%19, %2, %11, %8, %int1, %arg4) : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.bool) -> !torch.int\n"
" %21 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %22 = torch.aten.eq.int %21, %int2 : !torch.int, !torch.int -> !torch.bool\n"
" %23 = torch.prim.If %22 -> (!torch.list<int>) {\n"
" %24 = torch.prim.ListConstruct %18, %20 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %24 : !torch.list<int>\n"
" } else {\n"
" %24 = torch.prim.ListConstruct %17, %18, %20 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %24 : !torch.list<int>\n"
" }\n"
" return %23 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.avg_pool2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.avg_pool2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.avg_pool2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.optional<int>) -> !torch.list<int>\n" " %0 = call @__torch__.avg_pool2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.optional<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
@ -8150,6 +8246,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n" " }\n"
" return %4 : !torch.int\n" " return %4 : !torch.int\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.avg_pool1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.adaptive_avg_pool2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.adaptive_avg_pool2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n" " return %0#1 : !torch.int\n"

View File

@ -526,6 +526,37 @@ def avg_pool2d(input: List[int], kernel_size: List[int], stride: List[int], padd
else: else:
return [nbatch, nInputPlane, outputHeight, outputWidth] return [nbatch, nInputPlane, outputHeight, outputWidth]
# TODO: This should be upstreamed.
# See https://github.com/pytorch/pytorch/pull/76889 for an example.
def avg_pool1d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool, count_include_pad: bool):
assert len(kernel_size) == 1, "avg_pool1d: kernel_size must be a single int"
kL = kernel_size[0]
assert len(stride) == 0 or len(stride) == 1, "avg_pool1d: stride must either be omitted, or a single int"
dL = kL if len(stride) == 0 else stride[0]
assert len(padding) == 1, "avg_pool1d: padding must be a single int"
padL = padding[0]
dilationL = 1
assert len(input) == 2 or len(input) == 3
nbatch = input[-3] if len(input) == 3 else 1
nInputPlane = input[-2]
inputLength = input[-1]
outputLength = upstream_shape_functions.pooling_output_shape(
inputLength, kL, padL, dL, dilationL, ceil_mode)
if len(input) == 2:
return [nInputPlane, outputLength]
else:
return [nbatch, nInputPlane, outputLength]
def atenavg_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> List[int]:
return avg_pool1d(self, kernel_size, stride, padding, ceil_mode, count_include_pad)
def atenavg_pool2d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> List[int]: def atenavg_pool2d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> List[int]:
return avg_pool2d(self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) return avg_pool2d(self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
@ -1362,6 +1393,11 @@ def atenabs〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
return torch.float32 return torch.float32
return self_dtype return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 7)], kernel_size=[2]))
def atenavg_pool1d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[2, 2])) @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[2, 2]))
def atenadaptive_avg_pool2d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int]) -> int: def atenadaptive_avg_pool2d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int]) -> int:
self_rank, self_dtype = self_rank_dtype self_rank, self_dtype = self_rank_dtype

View File

@ -402,6 +402,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit( emit(
"aten::avg_pool2d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)" "aten::avg_pool2d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)"
) )
emit(
"aten::avg_pool1d : (Tensor, int[], int[], int[], bool, bool) -> (Tensor)"
)
emit( emit(
"aten::softmax.int : (Tensor, int, int?) -> (Tensor)" "aten::softmax.int : (Tensor, int, int?) -> (Tensor)"
) )

View File

@ -700,3 +700,75 @@ class AvgPool2dCeilModeTrueModule(torch.nn.Module):
@register_test_case(module_factory=lambda: AvgPool2dCeilModeTrueModule()) @register_test_case(module_factory=lambda: AvgPool2dCeilModeTrueModule())
def AvgPool2dCeilModeTrueModule_basic(module, tu: TestUtils): def AvgPool2dCeilModeTrueModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4, 20, 20, low=0.5, high=1.0)) module.forward(tu.rand(2, 4, 20, 20, low=0.5, high=1.0))
# ==============================================================================
class AvgPool1dFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.ap1d = torch.nn.AvgPool1d(kernel_size=6,
stride=2,
padding=3,
ceil_mode=False,
count_include_pad=True)
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return self.ap1d(x)
@register_test_case(module_factory=lambda: AvgPool1dFloatModule())
def AvgPool1dFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4, 20, low=-1))
class AvgPool1dIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.ap1d = torch.nn.AvgPool1d(kernel_size=6,
stride=2,
padding=3,
ceil_mode=False,
count_include_pad=True)
@export
@annotate_args([
None,
([-1, -1, -1], torch.int64, True),
])
def forward(self, x):
return self.ap1d(x)
@register_test_case(module_factory=lambda: AvgPool1dIntModule())
def AvgPool1dIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(2, 4, 20, high=100))
class AvgPool1dStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.ap1d = torch.nn.AvgPool1d(kernel_size=6,
stride=2,
padding=3,
ceil_mode=False,
count_include_pad=True)
@export
@annotate_args([
None,
([2, 4, 20], torch.int64, True),
])
def forward(self, x):
return self.ap1d(x)
@register_test_case(module_factory=lambda: AvgPool1dStaticModule())
def AvgPool1dStaticModule_basic(module, tu: TestUtils):
module.forward(tu.randint(2, 4, 20, high=100))