mirror of https://github.com/llvm/torch-mlir
[Stablehlo]Add support for AvgPool1dOp (#2268)
* Add support for AvgPool1d * Update AbstractInterpLibrary * support avgpool1d in linalg * refactored code * fix nit problempull/2294/head
parent
d57f67e7f8
commit
31ef08b63d
|
@ -595,6 +595,7 @@ STABLEHLO_PASS_SET = {
|
|||
"ViewOffsetBackwardTestStaticModule_basic",
|
||||
"NumToTensorFloatModule_basic",
|
||||
"AtenToDeviceModule_basic",
|
||||
"AvgPool1dStaticModule_basic",
|
||||
"AvgPool2dStaticModule_basic",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_basic",
|
||||
"Convolution2DStaticModule_basic",
|
||||
|
|
|
@ -5045,6 +5045,34 @@ def Torch_AtenAvgPool2dOp : Torch_Op<"aten.avg_pool2d", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenAvgPool1dOp : Torch_Op<"aten.avg_pool1d", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::avg_pool1d : (Tensor, int[], int[], int[], bool, bool) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchListOfTorchIntType:$kernel_size,
|
||||
AnyTorchListOfTorchIntType:$stride,
|
||||
AnyTorchListOfTorchIntType:$padding,
|
||||
Torch_BoolType:$ceil_mode,
|
||||
Torch_BoolType:$count_include_pad
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenAvgPool1dOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 6, 1);
|
||||
}
|
||||
void AtenAvgPool1dOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 6, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenSoftmaxIntOp : Torch_Op<"aten.softmax.int", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -39,7 +39,7 @@ checkAndGetPoolingParameters(OpTy op, ConversionPatternRewriter &rewriter,
|
|||
// Pattern match against the op's original operands, because otherwise we
|
||||
// will get the lowered version of the operands which is harder to pattern
|
||||
// match.
|
||||
SmallVector<Value, 2> kernelSizeTorchInt;
|
||||
SmallVector<Value> kernelSizeTorchInt;
|
||||
if (!getListConstructElements(op.getKernelSize(), kernelSizeTorchInt)) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"unimplemented: the kernel size is "
|
||||
|
@ -72,12 +72,13 @@ checkAndGetPoolingParameters(OpTy op, ConversionPatternRewriter &rewriter,
|
|||
return success();
|
||||
}
|
||||
|
||||
|
||||
// Creates a pooling operation based on the type specified by `OpTy` and
|
||||
// arguments passed.
|
||||
template <typename OpTy>
|
||||
static LogicalResult createPoolingOp(
|
||||
Operation *op, ConversionPatternRewriter &rewriter, Value self,
|
||||
bool supportNonFPInput, bool ceilMode,
|
||||
bool supportNonFPInput, bool ceilMode, int64_t dimensionality,
|
||||
SmallVectorImpl<Value> &kernelSizeIntValues,
|
||||
SmallVectorImpl<int64_t> &strideInts, SmallVectorImpl<int64_t> &paddingInts,
|
||||
SmallVectorImpl<int64_t> &dilationInts, Attribute initValueAttr,
|
||||
|
@ -87,13 +88,16 @@ static LogicalResult createPoolingOp(
|
|||
if (!elementType.isa<mlir::FloatType>() && !supportNonFPInput)
|
||||
return op->emitError("unimplemented: non-floating point type");
|
||||
|
||||
SmallVector<int64_t, 4> lowPaddingIncludingNC = {0, 0};
|
||||
SmallVector<int64_t> lowPaddingIncludingNC = {0, 0};
|
||||
lowPaddingIncludingNC.append(paddingInts);
|
||||
SmallVector<int64_t, 4> highPaddingIncludingNC = lowPaddingIncludingNC;
|
||||
SmallVector<int64_t> highPaddingIncludingNC = lowPaddingIncludingNC;
|
||||
|
||||
if (ceilMode) {
|
||||
highPaddingIncludingNC[2] += strideInts[0];
|
||||
highPaddingIncludingNC[3] += strideInts[1];
|
||||
for (int64_t i = 0; i < dimensionality; ++i) {
|
||||
highPaddingIncludingNC[i + 2] += strideInts[i];
|
||||
}
|
||||
}
|
||||
|
||||
Value initValue = rewriter.create<arith::ConstantOp>(loc, cast<TypedAttr>(initValueAttr));
|
||||
paddedInput = torch_to_linalg::getPaddedTensor(
|
||||
op, rewriter, self, lowPaddingIncludingNC, highPaddingIncludingNC,
|
||||
|
@ -101,8 +105,6 @@ static LogicalResult createPoolingOp(
|
|||
|
||||
Value N = getDimOp(rewriter, loc, self, 0);
|
||||
Value C = getDimOp(rewriter, loc, self, 1);
|
||||
Value H = getDimOp(rewriter, loc, self, 2);
|
||||
Value W = getDimOp(rewriter, loc, self, 3);
|
||||
|
||||
SmallVector<Value> paddingIntValues =
|
||||
getAsConstantIntValues(rewriter, loc, paddingInts);
|
||||
|
@ -111,15 +113,17 @@ static LogicalResult createPoolingOp(
|
|||
SmallVector<Value> strideIntValues =
|
||||
getAsConstantIntValues(rewriter, loc, strideInts);
|
||||
|
||||
Value hOut = torch_to_linalg::getOutputDimForConvOps(
|
||||
rewriter, loc, H, paddingIntValues[0], dilationIntValues[0],
|
||||
kernelSizeIntValues[0], strideIntValues[0], ceilMode);
|
||||
Value wOut = torch_to_linalg::getOutputDimForConvOps(
|
||||
rewriter, loc, W, paddingIntValues[1], dilationIntValues[1],
|
||||
kernelSizeIntValues[1], strideIntValues[1], ceilMode);
|
||||
// Get dimension size for each dimension and calculate output size
|
||||
for (int64_t i = dimensionality - 1; i > -1; --i) {
|
||||
Value dimSize = getDimOp(rewriter, loc, self, i + 2);
|
||||
Value outDim = torch_to_linalg::getOutputDimForConvOps(
|
||||
rewriter, loc, dimSize, paddingIntValues[i], dilationIntValues[i],
|
||||
kernelSizeIntValues[i], strideIntValues[i], ceilMode);
|
||||
outTensorShape.insert(outTensorShape.begin(), {outDim});
|
||||
}
|
||||
|
||||
// Create output tensor initialized with smallest floating point value.
|
||||
outTensorShape.insert(outTensorShape.begin(), {N, C, hOut, wOut});
|
||||
outTensorShape.insert(outTensorShape.begin(), {N, C});
|
||||
Value outTensorInitialized =
|
||||
createInitTensor(rewriter, loc, outTensorShape, elementType, initValue);
|
||||
|
||||
|
@ -138,6 +142,7 @@ static LogicalResult createPoolingOp(
|
|||
return success();
|
||||
}
|
||||
|
||||
|
||||
namespace {
|
||||
class ConvertAtenMaxPool2dOp : public OpConversionPattern<AtenMaxPool2dOp> {
|
||||
public:
|
||||
|
@ -177,8 +182,9 @@ public:
|
|||
Value maxPool2d, paddedInput;
|
||||
if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(
|
||||
op, rewriter, self, /*supportNonFPInput=*/false, ceilMode,
|
||||
kernelSizeIntValues, strideInts, paddingInts, dilationInts,
|
||||
smallestFPValueAttr, outTensorShape, paddedInput, maxPool2d)))
|
||||
/*dimensionality=*/2, kernelSizeIntValues, strideInts, paddingInts,
|
||||
dilationInts, smallestFPValueAttr, outTensorShape, paddedInput,
|
||||
maxPool2d)))
|
||||
return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");
|
||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, maxPool2d);
|
||||
|
@ -253,8 +259,9 @@ public:
|
|||
SmallVector<Value, 4> outTensorShape;
|
||||
if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(
|
||||
op, rewriter, self, /*supportNonFPInput=*/false, ceilMode,
|
||||
kernelSizeIntValues, strideInts, paddingInts, dilationInts,
|
||||
smallestFPValueAttr, outTensorShape, paddedInput, maxPool2d)))
|
||||
/*dimensionality=*/2, kernelSizeIntValues, strideInts, paddingInts,
|
||||
dilationInts, smallestFPValueAttr, outTensorShape, paddedInput,
|
||||
maxPool2d)))
|
||||
return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");
|
||||
|
||||
Value cstMinusOne =
|
||||
|
@ -366,29 +373,32 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
|
||||
namespace {
|
||||
class ConvertAtenAvgPool2dOp : public OpConversionPattern<AtenAvgPool2dOp> {
|
||||
template <typename OpTy, typename PoolingOpTy, int Dim>
|
||||
class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
using OpConversionPattern<OpTy>::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenAvgPool2dOp op, OpAdaptor adaptor,
|
||||
matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
|
||||
Location loc = op->getLoc();
|
||||
TypeConverter *typeConverter = getTypeConverter();
|
||||
TypeConverter *typeConverter = this->getTypeConverter();
|
||||
Value self = adaptor.getSelf();
|
||||
|
||||
Type inputElementType =
|
||||
self.getType().cast<RankedTensorType>().getElementType();
|
||||
Type resultType = getTypeConverter()->convertType(op.getType());
|
||||
Type resultType = typeConverter->convertType(op.getType());
|
||||
Type resultElementType =
|
||||
resultType.cast<RankedTensorType>().getElementType();
|
||||
|
||||
bool ceilMode;
|
||||
SmallVector<Value, 2> kernelSizeIntValues;
|
||||
SmallVector<int64_t, 2> strideInts, paddingInts, dilationInts{1, 1};
|
||||
if (failed(checkAndGetPoolingParameters<AtenAvgPool2dOp>(
|
||||
SmallVector<Value, Dim> kernelSizeIntValues;
|
||||
SmallVector<int64_t, Dim> strideInts, paddingInts, dilationInts(Dim, 1);
|
||||
if (failed(checkAndGetPoolingParameters<OpTy>(
|
||||
op, rewriter, typeConverter, ceilMode, kernelSizeIntValues,
|
||||
strideInts, paddingInts)))
|
||||
return rewriter.notifyMatchFailure(op, "invalid pooling parameters");
|
||||
|
@ -404,34 +414,36 @@ public:
|
|||
op, "unimplemented: count_include_pad is expected to be true");
|
||||
}
|
||||
|
||||
// `sumPool2d` contains the result of sumpool2d operation over the input.
|
||||
Value sumPool2d, paddedInput;
|
||||
SmallVector<Value, 4> outTensorShape;
|
||||
if (failed(createPoolingOp<linalg::PoolingNchwSumOp>(
|
||||
// `sumPool` contains the result of sumpool operation over the input.
|
||||
Value sumPool, paddedInput;
|
||||
SmallVector<Value, Dim+2> outTensorShape;
|
||||
if (failed(createPoolingOp<PoolingOpTy>(
|
||||
op, rewriter, self, /*supportNonFPInput=*/true, ceilMode,
|
||||
kernelSizeIntValues, strideInts, paddingInts, dilationInts,
|
||||
rewriter.getZeroAttr(inputElementType), outTensorShape, paddedInput,
|
||||
sumPool2d)))
|
||||
return rewriter.notifyMatchFailure(op, "unable to compute sumpool2d");
|
||||
|
||||
/*dimensionality=*/Dim, kernelSizeIntValues, strideInts, paddingInts,
|
||||
dilationInts, rewriter.getZeroAttr(inputElementType), outTensorShape,
|
||||
paddedInput, sumPool)))
|
||||
return rewriter.notifyMatchFailure(op, "unable to compute sumpool");
|
||||
Value divisor;
|
||||
if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) {
|
||||
Value kHtimeskW = rewriter.create<arith::MulIOp>(
|
||||
loc, kernelSizeIntValues[0], kernelSizeIntValues[1]);
|
||||
Value divisor = op.getDivisorOverride().getType().isa<Torch::NoneType>()
|
||||
divisor = op.getDivisorOverride().getType().template isa<Torch::NoneType>()
|
||||
? kHtimeskW
|
||||
: adaptor.getDivisorOverride();
|
||||
} else {
|
||||
divisor = kernelSizeIntValues[0];
|
||||
}
|
||||
divisor = convertScalarToDtype(rewriter, loc, divisor, resultElementType);
|
||||
|
||||
Value outputTensor = rewriter.create<tensor::EmptyOp>(
|
||||
loc, getAsOpFoldResult(outTensorShape), resultElementType);
|
||||
SmallVector<AffineMap> indexingMapsAvg(2,
|
||||
rewriter.getMultiDimIdentityMap(4));
|
||||
SmallVector<AffineMap> indexingMapsAvg(2, rewriter.getMultiDimIdentityMap(Dim+2));
|
||||
SmallVector<utils::IteratorType> iteratorTypesAvg(
|
||||
4, utils::IteratorType::parallel);
|
||||
|
||||
Value avgPool2d =
|
||||
Dim+2, utils::IteratorType::parallel);
|
||||
Value avgPool =
|
||||
rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, outputTensor.getType(), sumPool2d, outputTensor,
|
||||
loc, outputTensor.getType(), sumPool, outputTensor,
|
||||
/*indexingMaps=*/indexingMapsAvg,
|
||||
/*iteratorTypes=*/iteratorTypesAvg,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
|
@ -444,11 +456,12 @@ public:
|
|||
})
|
||||
.getResult(0);
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, avgPool2d);
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, avgPool);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
}
|
||||
|
||||
|
||||
void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
|
@ -458,6 +471,9 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
|
|||
patterns.add<ConvertAtenMaxPool2dOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
|
||||
patterns.add<ConvertAtenMaxPool2dWithIndicesOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenAvgPool2dOp>();
|
||||
patterns.add<ConvertAtenAvgPool2dOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenAvgPool1dOp, AtenAvgPool2dOp>();
|
||||
patterns.add<ConvertAtenAvgPoolOp<AtenAvgPool1dOp, linalg::PoolingNcwSumOp, 1>>(
|
||||
typeConverter, context);
|
||||
patterns.add<ConvertAtenAvgPoolOp<AtenAvgPool2dOp, linalg::PoolingNchwSumOp, 2>>(
|
||||
typeConverter, context);
|
||||
}
|
||||
|
|
|
@ -16,13 +16,13 @@
|
|||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "stablehlo/dialect/ChloOps.h"
|
||||
#include "stablehlo/dialect/StablehloOps.h"
|
||||
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
|
||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
||||
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
|
||||
|
@ -35,7 +35,7 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
|
|||
PatternRewriter &rewriter) {
|
||||
auto constType = RankedTensorType::get({}, elementTy);
|
||||
// Avg pooling
|
||||
if (isa<AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp, AtenCumsumOp>(op)) {
|
||||
if (isa<AtenAvgPool1dOp, AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp, AtenCumsumOp>(op)) {
|
||||
if (elementTy.isa<mlir::FloatType>()) {
|
||||
auto constAttr = DenseElementsAttr::get(
|
||||
constType, {APFloat::getZero(
|
||||
|
@ -373,24 +373,31 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
// AtenAvgPool2dOp
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
|
||||
AtenAvgPool2dOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
namespace {
|
||||
template <typename AtenOpT, int Dim>
|
||||
class ConvertAtenAvgPoolOp : public ConvertAtenOp<AtenOpT> {
|
||||
public:
|
||||
using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
|
||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputTy = input.getType().cast<RankedTensorType>();
|
||||
auto inputElemTy = inputTy.getElementType();
|
||||
auto inputRank = inputTy.getRank();
|
||||
auto outTy =
|
||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||
RankedTensorType inputTy = input.getType().cast<RankedTensorType>();
|
||||
Type inputElemTy = inputTy.getElementType();
|
||||
int64_t inputRank = inputTy.getRank();
|
||||
RankedTensorType outTy = ConvertAtenOp<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<RankedTensorType>();
|
||||
auto outShape = outTy.getShape();
|
||||
|
||||
if (inputRank <= 2) {
|
||||
|
||||
if (inputRank <= Dim) {
|
||||
return op.emitError(
|
||||
"avg_pooling2d only supports inputs with rank higher than 2");
|
||||
"avg_pooling1d/2d only supports inputs with rank higher than 1/2");
|
||||
}
|
||||
SmallVector<int64_t, 2> padding, kernelSize, stride;
|
||||
SmallVector<int64_t, Dim> padding, kernelSize, stride;
|
||||
bool ceilMode = false;
|
||||
bool countIncludePad = true;
|
||||
|
||||
|
@ -415,26 +422,33 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "non-const bool count_include_pad unsupported!");
|
||||
}
|
||||
if (succeeded(checkNotNone(rewriter, op, op.getDivisorOverride()))) {
|
||||
|
||||
if constexpr (std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
|
||||
if (succeeded(checkNotNone(rewriter, op, op.getDivisorOverride())))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only None divisor_override supported for now!");
|
||||
}
|
||||
|
||||
// prepend 1 to kernelSize, stride, dilation until they are of same rank as
|
||||
// input
|
||||
// Prepend 1 to kernelSize, stride, dilation until they are of same rank
|
||||
// as input
|
||||
SmallVector<int64_t> stablehloStride(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloDilation(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
|
||||
|
||||
std::copy(stride.begin(), stride.end(),
|
||||
stablehloStride.begin() + inputRank - 2);
|
||||
stablehloStride.begin() + inputRank - Dim);
|
||||
std::copy(kernelSize.begin(), kernelSize.end(),
|
||||
stablehloKernelSize.begin() + inputRank - 2);
|
||||
stablehloKernelSize.begin() + inputRank - Dim);
|
||||
if (Dim == 1) {
|
||||
stablehloPadding[stablehloPadding.size() - 2] = padding[0];
|
||||
stablehloPadding[stablehloPadding.size() - 1] = padding[0];
|
||||
} else {
|
||||
stablehloPadding[stablehloPadding.size() - 4] = padding[0];
|
||||
stablehloPadding[stablehloPadding.size() - 3] = padding[0];
|
||||
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
|
||||
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
|
||||
}
|
||||
|
||||
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
|
||||
|
||||
|
@ -467,23 +481,30 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
|
|||
auto blockArgumentType = RankedTensorType::get({}, inputElemTy);
|
||||
sumBlock.addArgument(blockArgumentType, op->getLoc());
|
||||
sumBlock.addArgument(blockArgumentType, op->getLoc());
|
||||
auto *firstArg = sumBlock.args_begin();
|
||||
auto secondArg = sumBlock.args_rbegin();
|
||||
auto firstArg = *sumBlock.args_begin();
|
||||
auto secondArg = *sumBlock.args_rbegin();
|
||||
|
||||
{
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(&sumBlock);
|
||||
|
||||
Value sumResult =
|
||||
rewriter.create<stablehlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
|
||||
rewriter.create<stablehlo::AddOp>(op->getLoc(), firstArg, secondArg);
|
||||
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult);
|
||||
}
|
||||
|
||||
// Use kernel size as the divisor
|
||||
if (countIncludePad) {
|
||||
Value divisor = hlo::getConstTensor<int64_t>(
|
||||
Value divisor;
|
||||
if (Dim == 1) {
|
||||
divisor =
|
||||
hlo::getConstTensor<int64_t>(rewriter, op, {kernelSize[0]}, {})
|
||||
.value();
|
||||
} else {
|
||||
divisor = hlo::getConstTensor<int64_t>(
|
||||
rewriter, op, {kernelSize[0] * kernelSize[1]}, {})
|
||||
.value();
|
||||
}
|
||||
divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy);
|
||||
DenseIntElementsAttr bcastDimensions;
|
||||
rewriter.replaceOpWithNewOp<mlir::chlo::BroadcastDivOp>(
|
||||
|
@ -491,12 +512,12 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
// Use another stablehlo.ReduceWindowOp to get the divisor
|
||||
// Use another mhlo.ReduceWindowOp to get the divisor
|
||||
Value windowSizeConst =
|
||||
hlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
|
||||
windowSizeConst =
|
||||
hlo::promoteType(rewriter, op.getLoc(), windowSizeConst, outTy);
|
||||
const auto &options = getOptions();
|
||||
const auto &options = ConvertAtenOp<AtenOpT>::getOptions();
|
||||
auto inputShapeVec =
|
||||
*hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
||||
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
||||
|
@ -519,23 +540,28 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
|
|||
blockArgumentType = RankedTensorType::get({}, inputElemTy);
|
||||
sizeBlock.addArgument(blockArgumentType, op->getLoc());
|
||||
sizeBlock.addArgument(blockArgumentType, op->getLoc());
|
||||
firstArg = sizeBlock.args_begin();
|
||||
secondArg = sizeBlock.args_rbegin();
|
||||
firstArg = *sizeBlock.args_begin();
|
||||
secondArg = *sizeBlock.args_rbegin();
|
||||
|
||||
{
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(&sizeBlock);
|
||||
|
||||
Value sumResult =
|
||||
rewriter.create<stablehlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
|
||||
rewriter.create<stablehlo::AddOp>(op->getLoc(), firstArg, secondArg);
|
||||
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult);
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<stablehlo::DivOp>(
|
||||
op, outTy, reduceWindowSum.getResult(0), reduceWindowSize.getResult(0));
|
||||
return success();
|
||||
|
||||
}
|
||||
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
// AtenCumsumOp
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite(
|
||||
|
@ -621,6 +647,8 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
|
|||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
target.addIllegalOp<AtenAvgPool1dOp>();
|
||||
patterns.add<ConvertAtenOp<AtenAvgPool1dOp>>(typeConverter, context, options);
|
||||
target.addIllegalOp<AtenMaxPool2dOp>();
|
||||
patterns.add<ConvertAtenOp<AtenMaxPool2dOp>>(typeConverter, context, options);
|
||||
target.addIllegalOp<AtenAvgPool2dOp>();
|
||||
|
@ -630,4 +658,11 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
|
|||
context, options);
|
||||
target.addIllegalOp<AtenCumsumOp>();
|
||||
patterns.add<ConvertAtenOp<AtenCumsumOp>>(typeConverter, context, options);
|
||||
#define INSERT_ATEN_AVGPOOL_PATTERN(AtenOp, Dim) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenAvgPoolOp<AtenOp, Dim>>( \
|
||||
typeConverter, context, options)
|
||||
INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool1dOp, 1);
|
||||
INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool2dOp, 2);
|
||||
#undef INSERT_ATEN_AVGPOOL_PATTERN
|
||||
}
|
||||
|
|
|
@ -6839,6 +6839,102 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<float>, %arg4: !torch.optional<float>) -> !torch.list<int> {\n"
|
||||
" return %arg2 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.avg_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.avg_pool1d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @__torch__.avg_pool1d(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list<int> {\n"
|
||||
" %int-1 = torch.constant.int -1\n"
|
||||
" %int-2 = torch.constant.int -2\n"
|
||||
" %int-3 = torch.constant.int -3\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %str_0 = torch.constant.str \"AssertionError: avg_pool1d: padding must be a single int\"\n"
|
||||
" %str_1 = torch.constant.str \"AssertionError: avg_pool1d: stride must either be omitted, or a single int\"\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str_2 = torch.constant.str \"AssertionError: avg_pool1d: kernel_size must be a single int\"\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %int2 = torch.constant.int 2\n"
|
||||
" %int3 = torch.constant.int 3\n"
|
||||
" %0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
|
||||
" %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %1 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %2 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %3 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int\n"
|
||||
" %4 = torch.aten.eq.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %5 = torch.prim.If %4 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %24 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int\n"
|
||||
" %25 = torch.aten.eq.int %24, %int1 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %25 : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If %5 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %6 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int\n"
|
||||
" %7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %8 = torch.prim.If %7 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %2 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" %24 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" torch.prim.If.yield %24 : !torch.int\n"
|
||||
" }\n"
|
||||
" %9 = torch.aten.len.t %arg3 : !torch.list<int> -> !torch.int\n"
|
||||
" %10 = torch.aten.eq.int %9, %int1 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %10 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %11 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %12 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %13 = torch.aten.eq.int %12, %int2 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %14 = torch.prim.If %13 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %24 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %25 = torch.aten.eq.int %24, %int3 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %25 : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If %14 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %15 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %16 = torch.aten.eq.int %15, %int3 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %17 = torch.prim.If %16 -> (!torch.int) {\n"
|
||||
" %24 = torch.aten.__getitem__.t %arg0, %int-3 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" torch.prim.If.yield %24 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %int1 : !torch.int\n"
|
||||
" }\n"
|
||||
" %18 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %19 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %20 = call @__torch__.torch.jit._shape_functions.pooling_output_shape(%19, %2, %11, %8, %int1, %arg4) : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.bool) -> !torch.int\n"
|
||||
" %21 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %22 = torch.aten.eq.int %21, %int2 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %23 = torch.prim.If %22 -> (!torch.list<int>) {\n"
|
||||
" %24 = torch.prim.ListConstruct %18, %20 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield %24 : !torch.list<int>\n"
|
||||
" } else {\n"
|
||||
" %24 = torch.prim.ListConstruct %17, %18, %20 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield %24 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" return %23 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.avg_pool2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.avg_pool2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.optional<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
@ -8150,6 +8246,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %4 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.avg_pool1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.adaptive_avg_pool2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
|
|
|
@ -526,6 +526,37 @@ def avg_pool2d(input: List[int], kernel_size: List[int], stride: List[int], padd
|
|||
else:
|
||||
return [nbatch, nInputPlane, outputHeight, outputWidth]
|
||||
|
||||
# TODO: This should be upstreamed.
|
||||
# See https://github.com/pytorch/pytorch/pull/76889 for an example.
|
||||
def avg_pool1d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool, count_include_pad: bool):
|
||||
assert len(kernel_size) == 1, "avg_pool1d: kernel_size must be a single int"
|
||||
kL = kernel_size[0]
|
||||
|
||||
assert len(stride) == 0 or len(stride) == 1, "avg_pool1d: stride must either be omitted, or a single int"
|
||||
dL = kL if len(stride) == 0 else stride[0]
|
||||
|
||||
assert len(padding) == 1, "avg_pool1d: padding must be a single int"
|
||||
padL = padding[0]
|
||||
|
||||
dilationL = 1
|
||||
|
||||
assert len(input) == 2 or len(input) == 3
|
||||
|
||||
nbatch = input[-3] if len(input) == 3 else 1
|
||||
nInputPlane = input[-2]
|
||||
inputLength = input[-1]
|
||||
|
||||
outputLength = upstream_shape_functions.pooling_output_shape(
|
||||
inputLength, kL, padL, dL, dilationL, ceil_mode)
|
||||
|
||||
if len(input) == 2:
|
||||
return [nInputPlane, outputLength]
|
||||
else:
|
||||
return [nbatch, nInputPlane, outputLength]
|
||||
|
||||
def aten〇avg_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> List[int]:
|
||||
return avg_pool1d(self, kernel_size, stride, padding, ceil_mode, count_include_pad)
|
||||
|
||||
def aten〇avg_pool2d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> List[int]:
|
||||
return avg_pool2d(self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
|
||||
|
||||
|
@ -1362,6 +1393,11 @@ def aten〇abs〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
|||
return torch.float32
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 7)], kernel_size=[2]))
|
||||
def aten〇avg_pool1d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[2, 2]))
|
||||
def aten〇adaptive_avg_pool2d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
|
|
|
@ -402,6 +402,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit(
|
||||
"aten::avg_pool2d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::avg_pool1d : (Tensor, int[], int[], int[], bool, bool) -> (Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::softmax.int : (Tensor, int, int?) -> (Tensor)"
|
||||
)
|
||||
|
|
|
@ -700,3 +700,75 @@ class AvgPool2dCeilModeTrueModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: AvgPool2dCeilModeTrueModule())
|
||||
def AvgPool2dCeilModeTrueModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4, 20, 20, low=0.5, high=1.0))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AvgPool1dFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.ap1d = torch.nn.AvgPool1d(kernel_size=6,
|
||||
stride=2,
|
||||
padding=3,
|
||||
ceil_mode=False,
|
||||
count_include_pad=True)
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return self.ap1d(x)
|
||||
|
||||
@register_test_case(module_factory=lambda: AvgPool1dFloatModule())
|
||||
def AvgPool1dFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4, 20, low=-1))
|
||||
|
||||
|
||||
class AvgPool1dIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.ap1d = torch.nn.AvgPool1d(kernel_size=6,
|
||||
stride=2,
|
||||
padding=3,
|
||||
ceil_mode=False,
|
||||
count_include_pad=True)
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return self.ap1d(x)
|
||||
|
||||
@register_test_case(module_factory=lambda: AvgPool1dIntModule())
|
||||
def AvgPool1dIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(2, 4, 20, high=100))
|
||||
|
||||
|
||||
class AvgPool1dStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.ap1d = torch.nn.AvgPool1d(kernel_size=6,
|
||||
stride=2,
|
||||
padding=3,
|
||||
ceil_mode=False,
|
||||
count_include_pad=True)
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([2, 4, 20], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return self.ap1d(x)
|
||||
|
||||
@register_test_case(module_factory=lambda: AvgPool1dStaticModule())
|
||||
def AvgPool1dStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(2, 4, 20, high=100))
|
Loading…
Reference in New Issue