mirror of https://github.com/llvm/torch-mlir
[Stablehlo] Add support for AvgPool1dOp (#2268)
* Add support for AvgPool1d
* Update AbstractInterpLibrary
* Support avgpool1d in linalg
* Refactored code
* Fix nit problem
parent: d57f67e7f8
commit: 31ef08b63d
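For orientation, the snippet below (not part of the patch) is a minimal PyTorch-level sketch of the operator this commit adds lowering support for; the kernel_size/stride/padding values are illustrative assumptions that happen to match the AvgPool1dStaticModule e2e test added at the end of the diff.

import torch

# Illustrative configuration only; any AvgPool1d settings route through
# aten.avg_pool1d, which this patch lowers to Linalg and Stablehlo.
pool = torch.nn.AvgPool1d(kernel_size=6, stride=2, padding=3,
                          ceil_mode=False, count_include_pad=True)
x = torch.rand(2, 4, 20)   # (N, C, L)
y = pool(x)                # traced/scripted as torch.aten.avg_pool1d
print(y.shape)             # torch.Size([2, 4, 11])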
@@ -595,6 +595,7 @@ STABLEHLO_PASS_SET = {
     "ViewOffsetBackwardTestStaticModule_basic",
     "NumToTensorFloatModule_basic",
     "AtenToDeviceModule_basic",
+    "AvgPool1dStaticModule_basic",
     "AvgPool2dStaticModule_basic",
     "Conv2dWithPaddingDilationStrideStaticModule_basic",
     "Convolution2DStaticModule_basic",
@@ -5045,6 +5045,34 @@ def Torch_AtenAvgPool2dOp : Torch_Op<"aten.avg_pool2d", [
   }];
 }

+def Torch_AtenAvgPool1dOp : Torch_Op<"aten.avg_pool1d", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::avg_pool1d : (Tensor, int[], int[], int[], bool, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchListOfTorchIntType:$kernel_size,
+    AnyTorchListOfTorchIntType:$stride,
+    AnyTorchListOfTorchIntType:$padding,
+    Torch_BoolType:$ceil_mode,
+    Torch_BoolType:$count_include_pad
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenAvgPool1dOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 6, 1);
+    }
+    void AtenAvgPool1dOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 6, 1);
+    }
+  }];
+}
+
 def Torch_AtenSoftmaxIntOp : Torch_Op<"aten.softmax.int", [
     AllowsTypeRefinement,
     HasValueSemantics,
@@ -39,7 +39,7 @@ checkAndGetPoolingParameters(OpTy op, ConversionPatternRewriter &rewriter,
   // Pattern match against the op's original operands, because otherwise we
   // will get the lowered version of the operands which is harder to pattern
   // match.
-  SmallVector<Value, 2> kernelSizeTorchInt;
+  SmallVector<Value> kernelSizeTorchInt;
   if (!getListConstructElements(op.getKernelSize(), kernelSizeTorchInt)) {
     return rewriter.notifyMatchFailure(op,
                                        "unimplemented: the kernel size is "
@@ -72,12 +72,13 @@ checkAndGetPoolingParameters(OpTy op, ConversionPatternRewriter &rewriter,
   return success();
 }

 // Creates a pooling operation based on the type specified by `OpTy` and
 // arguments passed.
 template <typename OpTy>
 static LogicalResult createPoolingOp(
     Operation *op, ConversionPatternRewriter &rewriter, Value self,
-    bool supportNonFPInput, bool ceilMode,
+    bool supportNonFPInput, bool ceilMode, int64_t dimensionality,
     SmallVectorImpl<Value> &kernelSizeIntValues,
     SmallVectorImpl<int64_t> &strideInts, SmallVectorImpl<int64_t> &paddingInts,
     SmallVectorImpl<int64_t> &dilationInts, Attribute initValueAttr,
@@ -87,22 +88,23 @@ static LogicalResult createPoolingOp(
   if (!elementType.isa<mlir::FloatType>() && !supportNonFPInput)
     return op->emitError("unimplemented: non-floating point type");

-  SmallVector<int64_t, 4> lowPaddingIncludingNC = {0, 0};
+  SmallVector<int64_t> lowPaddingIncludingNC = {0, 0};
   lowPaddingIncludingNC.append(paddingInts);
-  SmallVector<int64_t, 4> highPaddingIncludingNC = lowPaddingIncludingNC;
+  SmallVector<int64_t> highPaddingIncludingNC = lowPaddingIncludingNC;

   if (ceilMode) {
-    highPaddingIncludingNC[2] += strideInts[0];
-    highPaddingIncludingNC[3] += strideInts[1];
+    for (int64_t i = 0; i < dimensionality; ++i) {
+      highPaddingIncludingNC[i + 2] += strideInts[i];
+    }
   }

   Value initValue = rewriter.create<arith::ConstantOp>(loc, cast<TypedAttr>(initValueAttr));
   paddedInput = torch_to_linalg::getPaddedTensor(
       op, rewriter, self, lowPaddingIncludingNC, highPaddingIncludingNC,
       initValue);

   Value N = getDimOp(rewriter, loc, self, 0);
   Value C = getDimOp(rewriter, loc, self, 1);
-  Value H = getDimOp(rewriter, loc, self, 2);
-  Value W = getDimOp(rewriter, loc, self, 3);

   SmallVector<Value> paddingIntValues =
       getAsConstantIntValues(rewriter, loc, paddingInts);
@@ -111,15 +113,17 @@ static LogicalResult createPoolingOp(
   SmallVector<Value> strideIntValues =
       getAsConstantIntValues(rewriter, loc, strideInts);

-  Value hOut = torch_to_linalg::getOutputDimForConvOps(
-      rewriter, loc, H, paddingIntValues[0], dilationIntValues[0],
-      kernelSizeIntValues[0], strideIntValues[0], ceilMode);
-  Value wOut = torch_to_linalg::getOutputDimForConvOps(
-      rewriter, loc, W, paddingIntValues[1], dilationIntValues[1],
-      kernelSizeIntValues[1], strideIntValues[1], ceilMode);
+  // Get dimension size for each dimension and calculate output size
+  for (int64_t i = dimensionality - 1; i > -1; --i) {
+    Value dimSize = getDimOp(rewriter, loc, self, i + 2);
+    Value outDim = torch_to_linalg::getOutputDimForConvOps(
+        rewriter, loc, dimSize, paddingIntValues[i], dilationIntValues[i],
+        kernelSizeIntValues[i], strideIntValues[i], ceilMode);
+    outTensorShape.insert(outTensorShape.begin(), {outDim});
+  }

   // Create output tensor initialized with smallest floating point value.
-  outTensorShape.insert(outTensorShape.begin(), {N, C, hOut, wOut});
+  outTensorShape.insert(outTensorShape.begin(), {N, C});
   Value outTensorInitialized =
       createInitTensor(rewriter, loc, outTensorShape, elementType, initValue);
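The refactored createPoolingOp above computes one output size per spatial dimension instead of hard-coding H and W. The Python sketch below is only an illustration of that per-dimension rule; the actual lowering builds the arithmetic as SSA values via torch_to_linalg::getOutputDimForConvOps, and the helper name pool_out_dim is hypothetical.

import math

def pool_out_dim(in_size, pad, dilation, kernel, stride, ceil_mode):
    # Standard conv/pool output-size rule with symmetric padding, applied
    # once per spatial dim i (tensor dims 2 .. 1+dimensionality).
    numerator = in_size + 2 * pad - dilation * (kernel - 1) - 1
    if ceil_mode:
        return math.ceil(numerator / stride) + 1
    return numerator // stride + 1

# One spatial dim of the 1-D case: length 20, pad 3, kernel 6, stride 2 -> 11
print(pool_out_dim(20, 3, 1, 6, 2, ceil_mode=False))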
@@ -138,6 +142,7 @@ static LogicalResult createPoolingOp(
   return success();
 }

 namespace {
 class ConvertAtenMaxPool2dOp : public OpConversionPattern<AtenMaxPool2dOp> {
 public:
@@ -177,8 +182,9 @@ public:
     Value maxPool2d, paddedInput;
     if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(
             op, rewriter, self, /*supportNonFPInput=*/false, ceilMode,
-            kernelSizeIntValues, strideInts, paddingInts, dilationInts,
-            smallestFPValueAttr, outTensorShape, paddedInput, maxPool2d)))
+            /*dimensionality=*/2, kernelSizeIntValues, strideInts, paddingInts,
+            dilationInts, smallestFPValueAttr, outTensorShape, paddedInput,
+            maxPool2d)))
       return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");
     Type newResultType = getTypeConverter()->convertType(op.getType());
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, maxPool2d);
@@ -253,8 +259,9 @@ public:
     SmallVector<Value, 4> outTensorShape;
     if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(
             op, rewriter, self, /*supportNonFPInput=*/false, ceilMode,
-            kernelSizeIntValues, strideInts, paddingInts, dilationInts,
-            smallestFPValueAttr, outTensorShape, paddedInput, maxPool2d)))
+            /*dimensionality=*/2, kernelSizeIntValues, strideInts, paddingInts,
+            dilationInts, smallestFPValueAttr, outTensorShape, paddedInput,
+            maxPool2d)))
       return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");

     Value cstMinusOne =
@@ -366,29 +373,32 @@ public:
 };
 } // namespace

 namespace {
-class ConvertAtenAvgPool2dOp : public OpConversionPattern<AtenAvgPool2dOp> {
+template <typename OpTy, typename PoolingOpTy, int Dim>
+class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using OpConversionPattern<OpTy>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(AtenAvgPool2dOp op, OpAdaptor adaptor,
+  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();

     Location loc = op->getLoc();
-    TypeConverter *typeConverter = getTypeConverter();
+    TypeConverter *typeConverter = this->getTypeConverter();
     Value self = adaptor.getSelf();

     Type inputElementType =
         self.getType().cast<RankedTensorType>().getElementType();
-    Type resultType = getTypeConverter()->convertType(op.getType());
+    Type resultType = typeConverter->convertType(op.getType());
     Type resultElementType =
         resultType.cast<RankedTensorType>().getElementType();

     bool ceilMode;
-    SmallVector<Value, 2> kernelSizeIntValues;
-    SmallVector<int64_t, 2> strideInts, paddingInts, dilationInts{1, 1};
-    if (failed(checkAndGetPoolingParameters<AtenAvgPool2dOp>(
+    SmallVector<Value, Dim> kernelSizeIntValues;
+    SmallVector<int64_t, Dim> strideInts, paddingInts, dilationInts(Dim, 1);
+    if (failed(checkAndGetPoolingParameters<OpTy>(
             op, rewriter, typeConverter, ceilMode, kernelSizeIntValues,
             strideInts, paddingInts)))
       return rewriter.notifyMatchFailure(op, "invalid pooling parameters");
@@ -404,34 +414,36 @@ public:
           op, "unimplemented: count_include_pad is expected to be true");
     }

-    // `sumPool2d` contains the result of sumpool2d operation over the input.
-    Value sumPool2d, paddedInput;
-    SmallVector<Value, 4> outTensorShape;
-    if (failed(createPoolingOp<linalg::PoolingNchwSumOp>(
+    // `sumPool` contains the result of sumpool operation over the input.
+    Value sumPool, paddedInput;
+    SmallVector<Value, Dim+2> outTensorShape;
+    if (failed(createPoolingOp<PoolingOpTy>(
             op, rewriter, self, /*supportNonFPInput=*/true, ceilMode,
-            kernelSizeIntValues, strideInts, paddingInts, dilationInts,
-            rewriter.getZeroAttr(inputElementType), outTensorShape, paddedInput,
-            sumPool2d)))
-      return rewriter.notifyMatchFailure(op, "unable to compute sumpool2d");
-    Value kHtimeskW = rewriter.create<arith::MulIOp>(
-        loc, kernelSizeIntValues[0], kernelSizeIntValues[1]);
-    Value divisor = op.getDivisorOverride().getType().isa<Torch::NoneType>()
-                        ? kHtimeskW
-                        : adaptor.getDivisorOverride();
+            /*dimensionality=*/Dim, kernelSizeIntValues, strideInts, paddingInts,
+            dilationInts, rewriter.getZeroAttr(inputElementType), outTensorShape,
+            paddedInput, sumPool)))
+      return rewriter.notifyMatchFailure(op, "unable to compute sumpool");
+    Value divisor;
+    if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) {
+      Value kHtimeskW = rewriter.create<arith::MulIOp>(
+          loc, kernelSizeIntValues[0], kernelSizeIntValues[1]);
+      divisor = op.getDivisorOverride().getType().template isa<Torch::NoneType>()
+                    ? kHtimeskW
+                    : adaptor.getDivisorOverride();
+    } else {
+      divisor = kernelSizeIntValues[0];
+    }
     divisor = convertScalarToDtype(rewriter, loc, divisor, resultElementType);

     Value outputTensor = rewriter.create<tensor::EmptyOp>(
         loc, getAsOpFoldResult(outTensorShape), resultElementType);
-    SmallVector<AffineMap> indexingMapsAvg(2,
-                                           rewriter.getMultiDimIdentityMap(4));
+    SmallVector<AffineMap> indexingMapsAvg(2, rewriter.getMultiDimIdentityMap(Dim+2));
     SmallVector<utils::IteratorType> iteratorTypesAvg(
-        4, utils::IteratorType::parallel);
-    Value avgPool2d =
+        Dim+2, utils::IteratorType::parallel);
+    Value avgPool =
         rewriter
             .create<linalg::GenericOp>(
-                loc, outputTensor.getType(), sumPool2d, outputTensor,
+                loc, outputTensor.getType(), sumPool, outputTensor,
                 /*indexingMaps=*/indexingMapsAvg,
                 /*iteratorTypes=*/iteratorTypesAvg,
                 [&](OpBuilder &b, Location loc, ValueRange args) {
|
@ -444,11 +456,12 @@ public:
|
||||||
})
|
})
|
||||||
.getResult(0);
|
.getResult(0);
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, avgPool2d);
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, avgPool);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace
|
}
|
||||||
|
|
||||||
|
|
||||||
void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
|
void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
|
||||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||||
|
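The only place the templated linalg pattern distinguishes 1-D from 2-D averaging is the divisor fed into the final elementwise division. The Python sketch below illustrates that choice with plain numbers standing in for the SSA values; avg_from_sum is a hypothetical name, and divisor_override exists only on the 2-D op.

def avg_from_sum(sum_pooled, kernel_size, divisor_override=None):
    # 1-D: divide by kW; 2-D: divide by kH*kW unless divisor_override is set.
    if divisor_override is not None:
        divisor = divisor_override
    elif len(kernel_size) == 1:
        divisor = kernel_size[0]
    else:
        divisor = kernel_size[0] * kernel_size[1]
    return [s / divisor for s in sum_pooled]

print(avg_from_sum([12.0, 18.0], [6]))   # [2.0, 3.0]
print(avg_from_sum([36.0], [3, 3]))      # [4.0]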
@@ -458,6 +471,9 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
   patterns.add<ConvertAtenMaxPool2dOp>(typeConverter, context);
   target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
   patterns.add<ConvertAtenMaxPool2dWithIndicesOp>(typeConverter, context);
-  target.addIllegalOp<AtenAvgPool2dOp>();
-  patterns.add<ConvertAtenAvgPool2dOp>(typeConverter, context);
+  target.addIllegalOp<AtenAvgPool1dOp, AtenAvgPool2dOp>();
+  patterns.add<ConvertAtenAvgPoolOp<AtenAvgPool1dOp, linalg::PoolingNcwSumOp, 1>>(
+      typeConverter, context);
+  patterns.add<ConvertAtenAvgPoolOp<AtenAvgPool2dOp, linalg::PoolingNchwSumOp, 2>>(
+      typeConverter, context);
 }
@@ -16,13 +16,13 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "stablehlo/dialect/ChloOps.h"
 #include "stablehlo/dialect/StablehloOps.h"
+#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
-#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
 #include <iostream>
 #include <numeric>

@@ -35,7 +35,7 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
                                                 PatternRewriter &rewriter) {
   auto constType = RankedTensorType::get({}, elementTy);
   // Avg pooling
-  if (isa<AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp, AtenCumsumOp>(op)) {
+  if (isa<AtenAvgPool1dOp, AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp, AtenCumsumOp>(op)) {
     if (elementTy.isa<mlir::FloatType>()) {
       auto constAttr = DenseElementsAttr::get(
           constType, {APFloat::getZero(
@@ -373,169 +373,195 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
   return success();
 }

-// AtenAvgPool2dOp
-template <>
-LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
-    AtenAvgPool2dOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  Value input = adaptor.getSelf();
-  auto inputTy = input.getType().cast<RankedTensorType>();
-  auto inputElemTy = inputTy.getElementType();
-  auto inputRank = inputTy.getRank();
-  auto outTy =
-      getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
-  auto outShape = outTy.getShape();
-
-  if (inputRank <= 2) {
-    return op.emitError(
-        "avg_pooling2d only supports inputs with rank higher than 2");
-  }
-  SmallVector<int64_t, 2> padding, kernelSize, stride;
-  bool ceilMode = false;
-  bool countIncludePad = true;
-
-  if (!(matchPattern(op.getKernelSize(),
-                     m_TorchListOfConstantInts(kernelSize)))) {
-    return rewriter.notifyMatchFailure(
-        op, "non-const int kernel size unsupported!");
-  }
-  if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) {
-    return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!");
-  }
-  if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) {
-    return rewriter.notifyMatchFailure(op,
-                                       "non-const int padding unsupported!");
-  }
-  if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) {
-    return rewriter.notifyMatchFailure(op,
-                                       "non-const bool ceil_mode unsupported!");
-  }
-  if (!(matchPattern(op.getCountIncludePad(),
-                     m_TorchConstantBool(&countIncludePad)))) {
-    return rewriter.notifyMatchFailure(
-        op, "non-const bool count_include_pad unsupported!");
-  }
-  if (succeeded(checkNotNone(rewriter, op, op.getDivisorOverride()))) {
-    return rewriter.notifyMatchFailure(
-        op, "only None divisor_override supported for now!");
-  }
-
-  // prepend 1 to kernelSize, stride, dilation until they are of same rank as
-  // input
-  SmallVector<int64_t> stablehloStride(inputRank, 1);
-  SmallVector<int64_t> stablehloDilation(inputRank, 1);
-  SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
-  SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
-
-  std::copy(stride.begin(), stride.end(),
-            stablehloStride.begin() + inputRank - 2);
-  std::copy(kernelSize.begin(), kernelSize.end(),
-            stablehloKernelSize.begin() + inputRank - 2);
-  stablehloPadding[stablehloPadding.size() - 4] = padding[0];
-  stablehloPadding[stablehloPadding.size() - 3] = padding[0];
-  stablehloPadding[stablehloPadding.size() - 2] = padding[1];
-  stablehloPadding[stablehloPadding.size() - 1] = padding[1];
-
-  Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
-
-  DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
-      RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
-                            rewriter.getI64Type()),
-      stablehloKernelSize);
-  DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
-      RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
-                            rewriter.getI64Type()),
-      stablehloStride);
-  DenseIntElementsAttr baseDilations;
-  DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
-      RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
-                            rewriter.getI64Type()),
-      stablehloDilation);
-  DenseIntElementsAttr pad = DenseIntElementsAttr::get(
-      RankedTensorType::get(
-          {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
-          rewriter.getI64Type()),
-      stablehloPadding);
-
-  auto reduceWindowSum = rewriter.create<stablehlo::ReduceWindowOp>(
-      op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
-      baseDilations, windowDilations, pad);
-
-  Block &sumBlock = reduceWindowSum.getBody().emplaceBlock();
-
-  // Add bb argument
-  auto blockArgumentType = RankedTensorType::get({}, inputElemTy);
-  sumBlock.addArgument(blockArgumentType, op->getLoc());
-  sumBlock.addArgument(blockArgumentType, op->getLoc());
-  auto *firstArg = sumBlock.args_begin();
-  auto secondArg = sumBlock.args_rbegin();
-
-  {
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(&sumBlock);
-
-    Value sumResult =
-        rewriter.create<stablehlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
-    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult);
-  }
-
-  // Use kernel size as the divisor
-  if (countIncludePad) {
-    Value divisor = hlo::getConstTensor<int64_t>(
-                        rewriter, op, {kernelSize[0] * kernelSize[1]}, {})
-                        .value();
-    divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy);
-    DenseIntElementsAttr bcastDimensions;
-    rewriter.replaceOpWithNewOp<mlir::chlo::BroadcastDivOp>(
-        op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions);
-    return success();
-  }
-
-  // Use another stablehlo.ReduceWindowOp to get the divisor
-  Value windowSizeConst =
-      hlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
-  windowSizeConst =
-      hlo::promoteType(rewriter, op.getLoc(), windowSizeConst, outTy);
-  const auto &options = getOptions();
-  auto inputShapeVec =
-      *hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
-  auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
-      op->getLoc(), inputShapeVec);
-
-  windowSizeConst = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
-      op->getLoc(),
-      RankedTensorType::get(inputTy.getShape(), outTy.getElementType()),
-      windowSizeConst, inputShapeTensor, rewriter.getI64TensorAttr({}));
-
-  Value zero = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
-  auto reduceWindowSize = rewriter.create<stablehlo::ReduceWindowOp>(
-      op->getLoc(), RankedTensorType::get(outShape, inputElemTy),
-      windowSizeConst, zero, windowDimensions, windowStrides, baseDilations,
-      windowDilations, pad);
-
-  Block &sizeBlock = reduceWindowSize.getBody().emplaceBlock();
-
-  // Add bb argument
-  blockArgumentType = RankedTensorType::get({}, inputElemTy);
-  sizeBlock.addArgument(blockArgumentType, op->getLoc());
-  sizeBlock.addArgument(blockArgumentType, op->getLoc());
-  firstArg = sizeBlock.args_begin();
-  secondArg = sizeBlock.args_rbegin();
-
-  {
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(&sizeBlock);
-
-    Value sumResult =
-        rewriter.create<stablehlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
-    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult);
-  }
-
-  rewriter.replaceOpWithNewOp<stablehlo::DivOp>(
-      op, outTy, reduceWindowSum.getResult(0), reduceWindowSize.getResult(0));
-  return success();
-}
+namespace {
+template <typename AtenOpT, int Dim>
+class ConvertAtenAvgPoolOp : public ConvertAtenOp<AtenOpT> {
+public:
+  using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value input = adaptor.getSelf();
+    RankedTensorType inputTy = input.getType().cast<RankedTensorType>();
+    Type inputElemTy = inputTy.getElementType();
+    int64_t inputRank = inputTy.getRank();
+    RankedTensorType outTy = ConvertAtenOp<AtenOpT>::getTypeConverter()
+                                 ->convertType(op.getType())
+                                 .template cast<RankedTensorType>();
+    auto outShape = outTy.getShape();
+
+    if (inputRank <= Dim) {
+      return op.emitError(
+          "avg_pooling1d/2d only supports inputs with rank higher than 1/2");
+    }
+    SmallVector<int64_t, Dim> padding, kernelSize, stride;
+    bool ceilMode = false;
+    bool countIncludePad = true;
+
+    if (!(matchPattern(op.getKernelSize(),
+                       m_TorchListOfConstantInts(kernelSize)))) {
+      return rewriter.notifyMatchFailure(
+          op, "non-const int kernel size unsupported!");
+    }
+    if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) {
+      return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!");
+    }
+    if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "non-const int padding unsupported!");
+    }
+    if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "non-const bool ceil_mode unsupported!");
+    }
+    if (!(matchPattern(op.getCountIncludePad(),
+                       m_TorchConstantBool(&countIncludePad)))) {
+      return rewriter.notifyMatchFailure(
+          op, "non-const bool count_include_pad unsupported!");
+    }
+
+    if constexpr (std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
+      if (succeeded(checkNotNone(rewriter, op, op.getDivisorOverride())))
+        return rewriter.notifyMatchFailure(
+            op, "only None divisor_override supported for now!");
+    }
+
+    // Prepend 1 to kernelSize, stride, dilation until they are of same rank
+    // as input
+    SmallVector<int64_t> stablehloStride(inputRank, 1);
+    SmallVector<int64_t> stablehloDilation(inputRank, 1);
+    SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
+    SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
+
+    std::copy(stride.begin(), stride.end(),
+              stablehloStride.begin() + inputRank - Dim);
+    std::copy(kernelSize.begin(), kernelSize.end(),
+              stablehloKernelSize.begin() + inputRank - Dim);
+    if (Dim == 1) {
+      stablehloPadding[stablehloPadding.size() - 2] = padding[0];
+      stablehloPadding[stablehloPadding.size() - 1] = padding[0];
+    } else {
+      stablehloPadding[stablehloPadding.size() - 4] = padding[0];
+      stablehloPadding[stablehloPadding.size() - 3] = padding[0];
+      stablehloPadding[stablehloPadding.size() - 2] = padding[1];
+      stablehloPadding[stablehloPadding.size() - 1] = padding[1];
+    }
+
+    Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
+
+    DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
+        RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
+                              rewriter.getI64Type()),
+        stablehloKernelSize);
+    DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
+        RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
+                              rewriter.getI64Type()),
+        stablehloStride);
+    DenseIntElementsAttr baseDilations;
+    DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
+        RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
+                              rewriter.getI64Type()),
+        stablehloDilation);
+    DenseIntElementsAttr pad = DenseIntElementsAttr::get(
+        RankedTensorType::get(
+            {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
+            rewriter.getI64Type()),
+        stablehloPadding);
+
+    auto reduceWindowSum = rewriter.create<stablehlo::ReduceWindowOp>(
+        op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
+        baseDilations, windowDilations, pad);
+
+    Block &sumBlock = reduceWindowSum.getBody().emplaceBlock();
+
+    // Add bb argument
+    auto blockArgumentType = RankedTensorType::get({}, inputElemTy);
+    sumBlock.addArgument(blockArgumentType, op->getLoc());
+    sumBlock.addArgument(blockArgumentType, op->getLoc());
+    auto firstArg = *sumBlock.args_begin();
+    auto secondArg = *sumBlock.args_rbegin();
+
+    {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(&sumBlock);
+
+      Value sumResult =
+          rewriter.create<stablehlo::AddOp>(op->getLoc(), firstArg, secondArg);
+      rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult);
+    }
+
+    // Use kernel size as the divisor
+    if (countIncludePad) {
+      Value divisor;
+      if (Dim == 1) {
+        divisor =
+            hlo::getConstTensor<int64_t>(rewriter, op, {kernelSize[0]}, {})
+                .value();
+      } else {
+        divisor = hlo::getConstTensor<int64_t>(
+                      rewriter, op, {kernelSize[0] * kernelSize[1]}, {})
+                      .value();
+      }
+      divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy);
+      DenseIntElementsAttr bcastDimensions;
+      rewriter.replaceOpWithNewOp<mlir::chlo::BroadcastDivOp>(
+          op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions);
+      return success();
+    }
+
+    // Use another mhlo.ReduceWindowOp to get the divisor
+    Value windowSizeConst =
+        hlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
+    windowSizeConst =
+        hlo::promoteType(rewriter, op.getLoc(), windowSizeConst, outTy);
+    const auto &options = ConvertAtenOp<AtenOpT>::getOptions();
+    auto inputShapeVec =
+        *hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
+    auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
+        op->getLoc(), inputShapeVec);
+
+    windowSizeConst = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
+        op->getLoc(),
+        RankedTensorType::get(inputTy.getShape(), outTy.getElementType()),
+        windowSizeConst, inputShapeTensor, rewriter.getI64TensorAttr({}));
+
+    Value zero = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
+    auto reduceWindowSize = rewriter.create<stablehlo::ReduceWindowOp>(
+        op->getLoc(), RankedTensorType::get(outShape, inputElemTy),
+        windowSizeConst, zero, windowDimensions, windowStrides, baseDilations,
+        windowDilations, pad);
+
+    Block &sizeBlock = reduceWindowSize.getBody().emplaceBlock();
+
+    // Add bb argument
+    blockArgumentType = RankedTensorType::get({}, inputElemTy);
+    sizeBlock.addArgument(blockArgumentType, op->getLoc());
+    sizeBlock.addArgument(blockArgumentType, op->getLoc());
+    firstArg = *sizeBlock.args_begin();
+    secondArg = *sizeBlock.args_rbegin();
+
+    {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(&sizeBlock);
+
+      Value sumResult =
+          rewriter.create<stablehlo::AddOp>(op->getLoc(), firstArg, secondArg);
+      rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult);
+    }
+
+    rewriter.replaceOpWithNewOp<stablehlo::DivOp>(
+        op, outTy, reduceWindowSum.getResult(0), reduceWindowSize.getResult(0));
+    return success();
+  }
+};
+}

 // AtenCumsumOp
 template <>
 LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite(
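For reference, the Stablehlo lowering above expresses average pooling as a windowed sum (stablehlo.reduce_window) followed by a division; the divisor is the constant kernel size when count_include_pad is true, and otherwise a second windowed sum over an all-ones tensor that counts the non-padding elements in each window. The 1-D sketch below is a plain-Python illustration of that strategy under those assumptions, not the implementation itself.

def avg_pool1d_ref(x, k, stride, pad, count_include_pad):
    # Windowed sum over a zero-padded input, mirroring the first ReduceWindowOp.
    padded = [0.0] * pad + list(x) + [0.0] * pad
    # All-ones input with zeros in the padding, mirroring the second ReduceWindowOp.
    ones = [0.0] * pad + [1.0] * len(x) + [0.0] * pad
    out = []
    for start in range(0, len(padded) - k + 1, stride):
        window_sum = sum(padded[start:start + k])
        divisor = k if count_include_pad else sum(ones[start:start + k])
        out.append(window_sum / divisor)
    return out

print(avg_pool1d_ref([1.0, 2.0, 3.0, 4.0], k=2, stride=2, pad=0,
                     count_include_pad=True))   # [1.5, 3.5]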
@@ -621,6 +647,8 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target, const TorchToStablehloOptions &options) {
   MLIRContext *context = patterns.getContext();
+  target.addIllegalOp<AtenAvgPool1dOp>();
+  patterns.add<ConvertAtenOp<AtenAvgPool1dOp>>(typeConverter, context, options);
   target.addIllegalOp<AtenMaxPool2dOp>();
   patterns.add<ConvertAtenOp<AtenMaxPool2dOp>>(typeConverter, context, options);
   target.addIllegalOp<AtenAvgPool2dOp>();
@@ -630,4 +658,11 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
       context, options);
   target.addIllegalOp<AtenCumsumOp>();
   patterns.add<ConvertAtenOp<AtenCumsumOp>>(typeConverter, context, options);
+#define INSERT_ATEN_AVGPOOL_PATTERN(AtenOp, Dim) \
+  target.addIllegalOp<AtenOp>();                 \
+  patterns.add<ConvertAtenAvgPoolOp<AtenOp, Dim>>( \
+      typeConverter, context, options)
+  INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool1dOp, 1);
+  INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool2dOp, 2);
+#undef INSERT_ATEN_AVGPOOL_PATTERN
 }
@@ -6839,6 +6839,102 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<float>, %arg4: !torch.optional<float>) -> !torch.list<int> {\n"
 " return %arg2 : !torch.list<int>\n"
 " }\n"
+" func.func @\"__torch_mlir_shape_fn.aten.avg_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list<int> {\n"
+" %0 = call @__torch__.avg_pool1d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool) -> !torch.list<int>\n"
+" return %0 : !torch.list<int>\n"
+" }\n"
+" func.func @__torch__.avg_pool1d(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list<int> {\n"
+" %int-1 = torch.constant.int -1\n"
+" %int-2 = torch.constant.int -2\n"
+" %int-3 = torch.constant.int -3\n"
+" %str = torch.constant.str \"AssertionError: \"\n"
+" %str_0 = torch.constant.str \"AssertionError: avg_pool1d: padding must be a single int\"\n"
+" %str_1 = torch.constant.str \"AssertionError: avg_pool1d: stride must either be omitted, or a single int\"\n"
+" %true = torch.constant.bool true\n"
+" %none = torch.constant.none\n"
+" %str_2 = torch.constant.str \"AssertionError: avg_pool1d: kernel_size must be a single int\"\n"
+" %int1 = torch.constant.int 1\n"
+" %int0 = torch.constant.int 0\n"
+" %int2 = torch.constant.int 2\n"
+" %int3 = torch.constant.int 3\n"
+" %0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
+" %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n"
+" torch.prim.If %1 -> () {\n"
+" torch.prim.If.yield\n"
+" } else {\n"
+" torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none\n"
+" torch.prim.If.yield\n"
+" }\n"
+" %2 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
+" %3 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int\n"
+" %4 = torch.aten.eq.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+" %5 = torch.prim.If %4 -> (!torch.bool) {\n"
+" torch.prim.If.yield %true : !torch.bool\n"
+" } else {\n"
+" %24 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int\n"
+" %25 = torch.aten.eq.int %24, %int1 : !torch.int, !torch.int -> !torch.bool\n"
+" torch.prim.If.yield %25 : !torch.bool\n"
+" }\n"
+" torch.prim.If %5 -> () {\n"
+" torch.prim.If.yield\n"
+" } else {\n"
+" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
+" torch.prim.If.yield\n"
+" }\n"
+" %6 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int\n"
+" %7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+" %8 = torch.prim.If %7 -> (!torch.int) {\n"
+" torch.prim.If.yield %2 : !torch.int\n"
+" } else {\n"
+" %24 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
+" torch.prim.If.yield %24 : !torch.int\n"
+" }\n"
+" %9 = torch.aten.len.t %arg3 : !torch.list<int> -> !torch.int\n"
+" %10 = torch.aten.eq.int %9, %int1 : !torch.int, !torch.int -> !torch.bool\n"
+" torch.prim.If %10 -> () {\n"
+" torch.prim.If.yield\n"
+" } else {\n"
+" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
+" torch.prim.If.yield\n"
+" }\n"
+" %11 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
+" %12 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+" %13 = torch.aten.eq.int %12, %int2 : !torch.int, !torch.int -> !torch.bool\n"
+" %14 = torch.prim.If %13 -> (!torch.bool) {\n"
+" torch.prim.If.yield %true : !torch.bool\n"
+" } else {\n"
+" %24 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+" %25 = torch.aten.eq.int %24, %int3 : !torch.int, !torch.int -> !torch.bool\n"
+" torch.prim.If.yield %25 : !torch.bool\n"
+" }\n"
+" torch.prim.If %14 -> () {\n"
+" torch.prim.If.yield\n"
+" } else {\n"
+" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+" torch.prim.If.yield\n"
+" }\n"
+" %15 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+" %16 = torch.aten.eq.int %15, %int3 : !torch.int, !torch.int -> !torch.bool\n"
+" %17 = torch.prim.If %16 -> (!torch.int) {\n"
+" %24 = torch.aten.__getitem__.t %arg0, %int-3 : !torch.list<int>, !torch.int -> !torch.int\n"
+" torch.prim.If.yield %24 : !torch.int\n"
+" } else {\n"
+" torch.prim.If.yield %int1 : !torch.int\n"
+" }\n"
+" %18 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list<int>, !torch.int -> !torch.int\n"
+" %19 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
+" %20 = call @__torch__.torch.jit._shape_functions.pooling_output_shape(%19, %2, %11, %8, %int1, %arg4) : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.bool) -> !torch.int\n"
+" %21 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+" %22 = torch.aten.eq.int %21, %int2 : !torch.int, !torch.int -> !torch.bool\n"
+" %23 = torch.prim.If %22 -> (!torch.list<int>) {\n"
+" %24 = torch.prim.ListConstruct %18, %20 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+" torch.prim.If.yield %24 : !torch.list<int>\n"
+" } else {\n"
+" %24 = torch.prim.ListConstruct %17, %18, %20 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
+" torch.prim.If.yield %24 : !torch.list<int>\n"
+" }\n"
+" return %23 : !torch.list<int>\n"
+" }\n"
 " func.func @\"__torch_mlir_shape_fn.aten.avg_pool2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.list<int> {\n"
 " %0 = call @__torch__.avg_pool2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.optional<int>) -> !torch.list<int>\n"
 " return %0 : !torch.list<int>\n"
@@ -8150,6 +8246,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " }\n"
 " return %4 : !torch.int\n"
 " }\n"
+" func.func @\"__torch_mlir_dtype_fn.aten.avg_pool1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.int {\n"
+" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+" return %0#1 : !torch.int\n"
+" }\n"
 " func.func @\"__torch_mlir_dtype_fn.aten.adaptive_avg_pool2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
 " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 " return %0#1 : !torch.int\n"
@@ -526,6 +526,37 @@ def avg_pool2d(input: List[int], kernel_size: List[int], stride: List[int], padd
     else:
         return [nbatch, nInputPlane, outputHeight, outputWidth]

+# TODO: This should be upstreamed.
+# See https://github.com/pytorch/pytorch/pull/76889 for an example.
+def avg_pool1d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool, count_include_pad: bool):
+    assert len(kernel_size) == 1, "avg_pool1d: kernel_size must be a single int"
+    kL = kernel_size[0]
+
+    assert len(stride) == 0 or len(stride) == 1, "avg_pool1d: stride must either be omitted, or a single int"
+    dL = kL if len(stride) == 0 else stride[0]
+
+    assert len(padding) == 1, "avg_pool1d: padding must be a single int"
+    padL = padding[0]
+
+    dilationL = 1
+
+    assert len(input) == 2 or len(input) == 3
+
+    nbatch = input[-3] if len(input) == 3 else 1
+    nInputPlane = input[-2]
+    inputLength = input[-1]
+
+    outputLength = upstream_shape_functions.pooling_output_shape(
+        inputLength, kL, padL, dL, dilationL, ceil_mode)
+
+    if len(input) == 2:
+        return [nInputPlane, outputLength]
+    else:
+        return [nbatch, nInputPlane, outputLength]
+
+def aten〇avg_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> List[int]:
+    return avg_pool1d(self, kernel_size, stride, padding, ceil_mode, count_include_pad)
+
 def aten〇avg_pool2d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> List[int]:
     return avg_pool2d(self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)

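As a worked check of the shape function above (illustrative numbers only): with dilation 1 and ceil_mode False, pooling_output_shape reduces to the integer formula used here.

# kernel_size=[6], stride=[2], padding=[3] on an input of shape [2, 4, 20]
kL, dL, padL, inputLength = 6, 2, 3, 20
outputLength = (inputLength + 2 * padL - kL) // dL + 1
print([2, 4, outputLength])   # [2, 4, 11], matching torch.nn.AvgPool1d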
@@ -1362,6 +1393,11 @@ def aten〇abs〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
         return torch.float32
     return self_dtype

+@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 7)], kernel_size=[2]))
+def aten〇avg_pool1d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[2, 2]))
 def aten〇adaptive_avg_pool2d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int]) -> int:
     self_rank, self_dtype = self_rank_dtype
@@ -402,6 +402,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit(
         "aten::avg_pool2d : (Tensor, int[], int[], int[], bool, bool, int?) -> (Tensor)"
     )
+    emit(
+        "aten::avg_pool1d : (Tensor, int[], int[], int[], bool, bool) -> (Tensor)"
+    )
     emit(
         "aten::softmax.int : (Tensor, int, int?) -> (Tensor)"
     )
@@ -700,3 +700,75 @@ class AvgPool2dCeilModeTrueModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: AvgPool2dCeilModeTrueModule())
 def AvgPool2dCeilModeTrueModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 4, 20, 20, low=0.5, high=1.0))
+
+
+# ==============================================================================
+
+
+class AvgPool1dFloatModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.ap1d = torch.nn.AvgPool1d(kernel_size=6,
+                                       stride=2,
+                                       padding=3,
+                                       ceil_mode=False,
+                                       count_include_pad=True)
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.ap1d(x)
+
+@register_test_case(module_factory=lambda: AvgPool1dFloatModule())
+def AvgPool1dFloatModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 4, 20, low=-1))
+
+
+class AvgPool1dIntModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.ap1d = torch.nn.AvgPool1d(kernel_size=6,
+                                       stride=2,
+                                       padding=3,
+                                       ceil_mode=False,
+                                       count_include_pad=True)
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.int64, True),
+    ])
+    def forward(self, x):
+        return self.ap1d(x)
+
+@register_test_case(module_factory=lambda: AvgPool1dIntModule())
+def AvgPool1dIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(2, 4, 20, high=100))
+
+
+class AvgPool1dStaticModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.ap1d = torch.nn.AvgPool1d(kernel_size=6,
+                                       stride=2,
+                                       padding=3,
+                                       ceil_mode=False,
+                                       count_include_pad=True)
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 4, 20], torch.int64, True),
+    ])
+    def forward(self, x):
+        return self.ap1d(x)
+
+@register_test_case(module_factory=lambda: AvgPool1dStaticModule())
+def AvgPool1dStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(2, 4, 20, high=100))