mirror of https://github.com/llvm/torch-mlir

[MLIR][TORCH] Add decomposition of aten.adaptive_avg_pool2d op

This commit adds the decomposition of the `aten.adaptive_avg_pool2d` op into the
`aten.avg_pool2d` op. The current decomposition only supports cases where the
input size is equal to the output size, or where the output size is (1, 1).

Signed-off-by: Vivek Khandelwal <vivek@nod-labs.com>

parent b76c8c82dc
commit 6f548fc3ad
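In PyTorch terms, the two supported cases collapse to ordinary average pooling, which is what the new decomposition pattern emits. A minimal sketch of the equivalence it relies on (illustrative only: the shapes and the (7, 7)/(1, 1) output sizes are arbitrary choices, and a working PyTorch install is assumed):

import torch
import torch.nn.functional as F

x = torch.randn(1, 512, 7, 7)

# Unit output size: the kernel covers the whole spatial extent.
assert torch.allclose(
    F.adaptive_avg_pool2d(x, (1, 1)),
    F.avg_pool2d(x, kernel_size=(7, 7), stride=(1, 1), padding=0))

# Output size equal to the input size: kernel (1, 1), stride 1, padding 0.
assert torch.allclose(
    F.adaptive_avg_pool2d(x, (7, 7)),
    F.avg_pool2d(x, kernel_size=(1, 1), stride=(1, 1), padding=0))

Only the kernel size differs between the two cases; the stride stays at one and the padding at zero, matching the operands the `DecomposeAtenAdaptiveAvgPool2dOp` pattern below constructs for `aten.avg_pool2d`.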
@@ -45,6 +45,10 @@ Value castIntToIndex(OpBuilder &b, Location loc, Value v);
 
 Value castIndexToInt64(OpBuilder &b, Location loc, Value idx);
 
+SmallVector<Value>
+castIntVectorToIndexVector(OpBuilder &b, Location loc,
+                           SmallVectorImpl<Value> &intValues);
+
 Value getDimOp(OpBuilder &b, Location loc, Value v, int dim);
 
 SmallVector<Value> getTensorSizesUntilDim(OpBuilder &b, Location loc,

@@ -32,15 +32,21 @@ using namespace mlir::torch::Torch;
 template <typename OpTy>
 static LogicalResult
 checkAndGetPoolingParameters(OpTy op, ConversionPatternRewriter &rewriter,
-                             bool &ceilMode,
-                             SmallVectorImpl<int64_t> &kernelSizeInts,
+                             TypeConverter *typeConverter, bool &ceilMode,
+                             SmallVectorImpl<Value> &kernelSizeIntValues,
                              SmallVectorImpl<int64_t> &strideInts,
                              SmallVectorImpl<int64_t> &paddingInts) {
   // Pattern match against the op's original operands, because otherwise we
   // will get the lowered version of the operands which is harder to pattern
   // match.
-  if (!matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSizeInts)))
-    return rewriter.notifyMatchFailure(op, "only support kernel size ints");
+  SmallVector<Value, 2> kernelSizeTorchInt;
+  if (!getListConstructElements(op.kernel_size(), kernelSizeTorchInt)) {
+    return rewriter.notifyMatchFailure(op,
+                                       "unimplemented: the kernel size is "
+                                       "not constructed from ListConstruct");
+  }
+  kernelSizeIntValues = getTypeConvertedValues(
+      rewriter, op.getLoc(), typeConverter, kernelSizeTorchInt);
   if (!matchPattern(op.stride(), m_TorchConstantIntList(strideInts)))
     return rewriter.notifyMatchFailure(op, "only support constant int strides");
   if (!matchPattern(op.padding(), m_TorchConstantIntList(paddingInts)))
@@ -58,7 +64,7 @@ template <typename OpTy>
 static LogicalResult createPoolingOp(
     Operation *op, ConversionPatternRewriter &rewriter, Value self,
     bool supportFPInputOnly, bool ceilMode,
-    SmallVectorImpl<int64_t> &kernelSizeInts,
+    SmallVectorImpl<Value> &kernelSizeIntValues,
    SmallVectorImpl<int64_t> &strideInts, SmallVectorImpl<int64_t> &paddingInts,
     SmallVectorImpl<int64_t> &dilationInts, Attribute initValueAttr,
     SmallVectorImpl<Value> &outTensorShape, Value &paddedInput, Value &result) {
@@ -88,8 +94,6 @@ static LogicalResult createPoolingOp(
       getAsConstantIntValues(rewriter, loc, paddingInts);
   SmallVector<Value> dilationIntValues =
       getAsConstantIntValues(rewriter, loc, dilationInts);
-  SmallVector<Value> kernelSizeIntValues =
-      getAsConstantIntValues(rewriter, loc, kernelSizeInts);
   SmallVector<Value> strideIntValues =
       getAsConstantIntValues(rewriter, loc, strideInts);
 
@@ -108,7 +112,7 @@ static LogicalResult createPoolingOp(
   auto stridesAttr = rewriter.getI64VectorAttr(strideInts);
   auto dilationAttr = rewriter.getI64VectorAttr(dilationInts);
   Value windowTensor = rewriter.create<linalg::InitTensorOp>(
-      loc, getAsConstantIndexValues(rewriter, loc, kernelSizeInts),
+      loc, castIntVectorToIndexVector(rewriter, loc, kernelSizeIntValues),
       elementType);
 
   result = rewriter
@@ -130,6 +134,7 @@ public:
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
 
+    TypeConverter *typeConverter = getTypeConverter();
     Value self = adaptor.self();
     int64_t selfRank = self.getType().cast<RankedTensorType>().getRank();
     // TODO: Add support for 3D inputs.
@@ -137,14 +142,15 @@ public:
       return rewriter.notifyMatchFailure(
           op, "unimplemented: only support 4D input");
 
-    SmallVector<int64_t, 2> kernelSizeInts, strideInts, paddingInts,
-        dilationInts;
+    bool ceilMode;
+    SmallVector<Value, 2> kernelSizeIntValues;
+    SmallVector<int64_t, 2> strideInts, paddingInts, dilationInts;
     if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts)))
       return rewriter.notifyMatchFailure(op,
                                          "only support constant int dilations");
-    bool ceilMode;
     if (failed(checkAndGetPoolingParameters<AtenMaxPool2dOp>(
-            op, rewriter, ceilMode, kernelSizeInts, strideInts, paddingInts)))
+            op, rewriter, typeConverter, ceilMode, kernelSizeIntValues,
+            strideInts, paddingInts)))
       return rewriter.notifyMatchFailure(op, "invalid pooling parameters");
 
     Type elementType = self.getType().cast<RankedTensorType>().getElementType();
@@ -158,7 +164,7 @@ public:
     Value maxPool2d, paddedInput;
     if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(
             op, rewriter, self, /*supportFPInput=*/false, ceilMode,
-            kernelSizeInts, strideInts, paddingInts, dilationInts,
+            kernelSizeIntValues, strideInts, paddingInts, dilationInts,
             smallestFPValueAttr, outTensorShape, paddedInput, maxPool2d)))
       return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");
     Type newResultType = getTypeConverter()->convertType(op.getType());
@@ -200,6 +206,7 @@ public:
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
     Location loc = op->getLoc();
+    TypeConverter *typeConverter = getTypeConverter();
     Value self = adaptor.self();
     RankedTensorType selfType = self.getType().cast<RankedTensorType>();
     Type elementType = selfType.getElementType();
@@ -213,14 +220,15 @@ public:
       return rewriter.notifyMatchFailure(
           op, "unimplemented: only support 4D input");
 
-    SmallVector<int64_t, 2> kernelSizeInts, strideInts, paddingInts,
-        dilationInts;
+    bool ceilMode;
+    SmallVector<Value, 2> kernelSizeIntValues;
+    SmallVector<int64_t, 2> strideInts, paddingInts, dilationInts;
     if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts)))
       return rewriter.notifyMatchFailure(op,
                                          "only support constant int dilations");
-    bool ceilMode;
     if (failed(checkAndGetPoolingParameters<AtenMaxPool2dWithIndicesOp>(
-            op, rewriter, ceilMode, kernelSizeInts, strideInts, paddingInts)))
+            op, rewriter, typeConverter, ceilMode, kernelSizeIntValues,
+            strideInts, paddingInts)))
       return rewriter.notifyMatchFailure(op, "invalid pooling parameters");
 
     // `maxpool2d` contains the result of maxpool2d operation over the input.
@@ -233,7 +241,7 @@ public:
     SmallVector<Value, 4> outTensorShape;
     if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(
             op, rewriter, self, /*supportFPInput=*/false, ceilMode,
-            kernelSizeInts, strideInts, paddingInts, dilationInts,
+            kernelSizeIntValues, strideInts, paddingInts, dilationInts,
             smallestFPValueAttr, outTensorShape, paddedInput, maxPool2d)))
       return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");
 
@@ -244,7 +252,7 @@ public:
         indicesRankedTensorType.getElementType(), cstMinusOne);
 
     SmallVector<Value> kernelSize =
-        getAsConstantIndexValues(rewriter, loc, kernelSizeInts);
+        castIntVectorToIndexVector(rewriter, loc, kernelSizeIntValues);
     SmallVector<Value> padding =
         getAsConstantIndexValues(rewriter, loc, paddingInts);
     SmallVector<Value> dilation =
@@ -344,105 +352,6 @@ public:
 };
 } // namespace
 
 namespace {
-class ConvertAtenAdaptiveAvgPool2dOp
-    : public OpConversionPattern<AtenAdaptiveAvgPool2dOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(AtenAdaptiveAvgPool2dOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
-    MLIRContext *context = op->getContext();
-    Value input = adaptor.self(); /* in form of N*C*H*W */
-    RankedTensorType inputType = input.getType().cast<RankedTensorType>();
-    Type elementType = inputType.getElementType();
-    if (!elementType.isa<mlir::FloatType>())
-      return op.emitError("unimplemented: non-floating point type");
-
-    auto inputRank = inputType.getRank();
-    if (inputRank != 4)
-      return rewriter.notifyMatchFailure(op, "input should be rank 4");
-
-    SmallVector<int64_t, 2> expects{1, 1};
-    // Pattern match against the op's original operands, because otherwise we
-    // will get the lowered version of the operands which is harder to pattern
-    // match.
-    if (!isConstantIntListMatching(op.output_size(), expects))
-      return rewriter.notifyMatchFailure(
-          op, "only support output_size with H and W both equal to constant 1");
-
-    Value N = getDimOp(rewriter, loc, input, 0);
-    Value C = getDimOp(rewriter, loc, input, 1);
-    Value initTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, ValueRange{N, C}, elementType);
-    Value c0 = rewriter.create<arith::ConstantOp>(
-        loc, FloatAttr::get(elementType, 0.0));
-    Value initTensor0 =
-        rewriter.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
-
-    SmallVector<AffineExpr, 2> ncExprs;
-    ncExprs.push_back(mlir::getAffineDimExpr(0, context));
-    ncExprs.push_back(mlir::getAffineDimExpr(1, context));
-    auto ncIndexingMap = AffineMap::get(
-        /*dimCount=*/4,
-        /*symbolCount=*/0, ncExprs, context);
-    SmallVector<AffineMap, 2> indexingMaps = {
-        rewriter.getMultiDimIdentityMap(4), // input
-        ncIndexingMap,                      // output
-    };
-    SmallVector<StringRef, 4> iteratorTypesSum{"parallel", "parallel",
-                                               "reduction", "reduction"};
-    Value sumPool2d = rewriter
-                          .create<linalg::GenericOp>(
-                              loc, initTensor0.getType(), input, initTensor0,
-                              /*indexingMaps=*/indexingMaps,
-                              /*iteratorTypes=*/iteratorTypesSum,
-                              [&](OpBuilder &b, Location loc, ValueRange args) {
-                                Value input = args[0], sum = args[1];
-                                Value result = rewriter.create<arith::AddFOp>(
-                                    loc, sum, input);
-                                b.create<linalg::YieldOp>(loc, result);
-                              })
-                          .getResult(0);
-
-    // Calculate H*W so that avg can be got from sum / (H*W)
-    Value H = getDimOp(rewriter, loc, input, 2);
-    Value W = getDimOp(rewriter, loc, input, 3);
-    auto castIndexToInt = [&](Value v) {
-      return rewriter.create<arith::IndexCastOp>(
-          loc, IntegerType::get(context, 64), v);
-    };
-    Value HtimesW = rewriter.create<arith::MulIOp>(loc, castIndexToInt(H),
-                                                   castIndexToInt(W));
-    Value HtimesWf =
-        rewriter.create<arith::SIToFPOp>(loc, elementType, HtimesW);
-
-    Value c1Index = rewriter.create<arith::ConstantIndexOp>(loc, /*value=*/1);
-    Value outputTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, ValueRange{N, C, c1Index, c1Index}, elementType);
-    SmallVector<AffineMap, 2> indexingMapsAvg{
-        ncIndexingMap, rewriter.getMultiDimIdentityMap(4)};
-    SmallVector<StringRef, 4> iteratorTypesAvg(4, "parallel");
-    Value avgPool2d =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, outputTensor.getType(), sumPool2d, outputTensor,
-                /*indexingMaps=*/indexingMapsAvg,
-                /*iteratorTypes=*/iteratorTypesAvg,
-                [&](OpBuilder &b, Location loc, ValueRange args) {
-                  Value avg = b.create<arith::DivFOp>(loc, args[0], HtimesWf);
-                  b.create<linalg::YieldOp>(loc, avg);
-                })
-            .getResult(0);
-
-    Type newResultType = getTypeConverter()->convertType(op.getType());
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, avgPool2d);
-    return success();
-  }
-};
-} // namespace
-
-namespace {
 class ConvertAtenAvgPool2dOp : public OpConversionPattern<AtenAvgPool2dOp> {
 public:
@@ -453,6 +362,7 @@ public:
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
     Location loc = op->getLoc();
+    TypeConverter *typeConverter = getTypeConverter();
     Value self = adaptor.self();
 
     Type inputElementType =
@@ -461,11 +371,12 @@ public:
     Type resultElementType =
         resultType.cast<RankedTensorType>().getElementType();
 
-    SmallVector<int64_t, 2> dilationInts{1, 1};
-    SmallVector<int64_t, 2> kernelSizeInts, strideInts, paddingInts;
     bool ceilMode;
+    SmallVector<Value, 2> kernelSizeIntValues;
+    SmallVector<int64_t, 2> strideInts, paddingInts, dilationInts{1, 1};
     if (failed(checkAndGetPoolingParameters<AtenAvgPool2dOp>(
-            op, rewriter, ceilMode, kernelSizeInts, strideInts, paddingInts)))
+            op, rewriter, typeConverter, ceilMode, kernelSizeIntValues,
+            strideInts, paddingInts)))
       return rewriter.notifyMatchFailure(op, "invalid pooling parameters");
 
     // TODO: Add support for count_include_pad equal to `False`.
@@ -484,13 +395,11 @@ public:
     SmallVector<Value, 4> outTensorShape;
     if (failed(createPoolingOp<linalg::PoolingNchwSumOp>(
             op, rewriter, self, /*supportFPInput=*/true, ceilMode,
-            kernelSizeInts, strideInts, paddingInts, dilationInts,
+            kernelSizeIntValues, strideInts, paddingInts, dilationInts,
             rewriter.getZeroAttr(inputElementType), outTensorShape, paddedInput,
             sumPool2d)))
       return rewriter.notifyMatchFailure(op, "unable to compute sumpool2d");
 
-    SmallVector<Value> kernelSizeIntValues =
-        getAsConstantIntValues(rewriter, loc, kernelSizeInts);
     Value kHtimeskW = rewriter.create<arith::MulIOp>(
         loc, kernelSizeIntValues[0], kernelSizeIntValues[1]);
     Value divisor = op.divisor_override().getType().isa<Torch::NoneType>()
@@ -534,8 +443,6 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
   patterns.add<ConvertAtenMaxPool2dOp>(typeConverter, context);
   target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
   patterns.add<ConvertAtenMaxPool2dWithIndicesOp>(typeConverter, context);
-  target.addIllegalOp<AtenAdaptiveAvgPool2dOp>();
-  patterns.add<ConvertAtenAdaptiveAvgPool2dOp>(typeConverter, context);
   target.addIllegalOp<AtenAvgPool2dOp>();
   patterns.add<ConvertAtenAvgPool2dOp>(typeConverter, context);
 }

@@ -143,6 +143,15 @@ Value castIndexToInt64(OpBuilder &b, Location loc, Value idx) {
   return b.create<arith::IndexCastOp>(loc, b.getI64Type(), idx);
 }
 
+SmallVector<Value>
+castIntVectorToIndexVector(OpBuilder &b, Location loc,
+                           SmallVectorImpl<Value> &intValues) {
+  SmallVector<Value> indexValues;
+  for (Value v : intValues)
+    indexValues.push_back(castIntToIndex(b, loc, v));
+  return indexValues;
+}
+
 Value getDimOp(OpBuilder &b, Location loc, Value v, int dim) {
   return b.createOrFold<tensor::DimOp>(loc, v, dim);
 }

@@ -1551,7 +1551,8 @@ class DecomposeAten_UnsafeViewOp : public OpRewritePattern<Aten_UnsafeViewOp> {
 // Note that this is the same decomposition as in AOTAutograd
 // https://github.com/pytorch/functorch/blob/a3042d94e616d4143813668b1372d9d4545be14e/functorch/_src/aot_autograd.py#L104
 namespace {
-class DecomposeAten_ReshapeAliasOp : public OpRewritePattern<Aten_ReshapeAliasOp> {
+class DecomposeAten_ReshapeAliasOp
+    : public OpRewritePattern<Aten_ReshapeAliasOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(Aten_ReshapeAliasOp op,
                                 PatternRewriter &rewriter) const override {
@@ -1775,6 +1776,116 @@ public:
 };
 } // namespace
 
+namespace {
+// Decompose `aten.adaptive_avg_pool2d` op into `aten.avg_pool2d` op.
+//
+// For the AdaptiveAvgPool2d op, when the input size is an integer multiple of
+// the output size, the kernel_size, stride and padding are calculated as:
+//   strideH = inH // outH
+//   strideW = inW // outW
+//   kernelH = inH - [(outH - 1) * strideH]
+//   kernelW = inW - [(outW - 1) * strideW]
+//   paddingH = 0, paddingW = 0
+//
+// For the special case, when the output size is one for all dimensions,
+// the kernel size is the same as the input size.
+class DecomposeAtenAdaptiveAvgPool2dOp
+    : public OpRewritePattern<AtenAdaptiveAvgPool2dOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenAdaptiveAvgPool2dOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    MLIRContext *context = op.getContext();
+
+    Value input = op.self();
+    int64_t rank = getTensorRank(input);
+    SmallVector<Value, 2> inputHW;
+    Value dimH = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(rank - 2));
+    inputHW.push_back(
+        /*inH=*/rewriter.create<AtenSizeIntOp>(loc, input, dimH));
+    Value dimW = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(rank - 1));
+    inputHW.push_back(
+        /*inW=*/rewriter.create<AtenSizeIntOp>(loc, input, dimW));
+
+    Value outputShape = op.output_size();
+    SmallVector<Value> outputShapeSizesTorchInt;
+    getListConstructElements(outputShape, outputShapeSizesTorchInt);
+
+    // TODO: Add support for cases other than:
+    // 1.) inH == outH and inW == outW.
+    // 2.) outH == outW == 1
+    bool unitOutputSize = true;
+    for (Value outShape : outputShapeSizesTorchInt) {
+      int64_t outShapeInt;
+      if (!matchPattern(outShape, m_TorchConstantInt(&outShapeInt))) {
+        return rewriter.notifyMatchFailure(
+            op, "output size is expected to be a constant");
+      }
+      if (outShapeInt != 1) {
+        unitOutputSize = false;
+        break;
+      }
+    }
+
+    Value constantOne = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(1));
+    Value constantZero = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(0));
+    Value constantFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+    Value constantTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
+    Value constantNone = rewriter.create<Torch::ConstantNoneOp>(loc);
+    SmallVector<Value, 2> kernelSize;
+
+    for (unsigned i = 0; i < inputHW.size(); i++) {
+      if (unitOutputSize) {
+        BaseTensorType inputTensorType = input.getType().cast<BaseTensorType>();
+        ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
+        kernelSize.push_back(inputShape[rank - 2 + i] == kUnknownSize
+                                 ? inputHW[i]
+                                 : rewriter.create<Torch::ConstantIntOp>(
+                                       loc, rewriter.getI64IntegerAttr(
+                                                inputShape[rank - 2 + i])));
+      } else {
+        Value cond = rewriter.create<AtenEqIntOp>(loc, inputHW[i],
+                                                  outputShapeSizesTorchInt[i]);
+        rewriter.create<RuntimeAssertOp>(
+            loc, cond,
+            "unimplemented: only support cases where input and output size are "
+            "equal for non-unit output size");
+
+        Value outMinusOne = rewriter.create<AtenSubIntOp>(
+            loc, outputShapeSizesTorchInt[i], constantOne);
+        kernelSize.push_back(
+            rewriter.create<AtenSubIntOp>(loc, inputHW[i], outMinusOne));
+      }
+    }
+
+    Value kernelSizeList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize);
+    // Currently we only support cases where the input size is equal to the output
+    // size or the output size is unit. For the former case, the stride is always
+    // equal to one and for the latter the stride value doesn't matter, since the
+    // kernel size is the same as the input size. Therefore, keep the stride as
+    // one for the latter case as well for ease of implementation.
+    Value strideList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(context)),
+        ValueRange{constantOne, constantOne});
+    Value paddingSizeList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(context)),
+        ValueRange{constantZero, constantZero});
+
+    rewriter.replaceOpWithNewOp<AtenAvgPool2dOp>(
+        op, op.getType(), input, kernelSizeList, strideList, paddingSizeList,
+        /*ceil_mode=*/constantFalse, /*count_include_pad=*/constantTrue,
+        /*divisor_override=*/constantNone);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
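The kernel and stride arithmetic in the comment above (the integer-multiple case) can be checked numerically against PyTorch; a small sketch with hypothetical sizes, assuming PyTorch is available:

import torch
import torch.nn.functional as F

inH, inW, outH, outW = 8, 12, 4, 3            # input divisible by output
strideH, strideW = inH // outH, inW // outW
kernelH = inH - (outH - 1) * strideH          # equals strideH when divisible
kernelW = inW - (outW - 1) * strideW          # equals strideW when divisible

x = torch.randn(2, 3, inH, inW)
expected = F.adaptive_avg_pool2d(x, (outH, outW))
decomposed = F.avg_pool2d(x, kernel_size=(kernelH, kernelW),
                          stride=(strideH, strideW), padding=0)
assert torch.allclose(expected, decomposed)

The pattern added here only emits the two cases it handles today (equal input/output size and unit output size); the general formula is documented for the follow-up listed in the TODO.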
@@ -1910,6 +2021,8 @@ class DecomposeComplexOpsPass
     patterns.add<DecomposeAtenPadOp>(context);
     patterns.add<DecomposeAtenToDtypeLayoutOp>(context);
     target.addIllegalOp<AtenToDtypeLayoutOp>();
+    patterns.add<DecomposeAtenAdaptiveAvgPool2dOp>(context);
+    target.addIllegalOp<AtenAdaptiveAvgPool2dOp>();
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns)))) {

@@ -12,7 +12,72 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
 # ==============================================================================
 
 
-class AdaptiveAvgPool2dModule(torch.nn.Module):
+class AdaptiveAvgPool2dNonUnitOutputSizeStaticModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.aap2d = torch.nn.AdaptiveAvgPool2d((7, 7))
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 512, 7, 7], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.aap2d(x)
+
+
+@register_test_case(
+    module_factory=lambda: AdaptiveAvgPool2dNonUnitOutputSizeStaticModule())
+def AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic(
+        module, tu: TestUtils):
+    module.forward(tu.rand(1, 512, 7, 7))
+
+
+class AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.aap2d = torch.nn.AdaptiveAvgPool2d((7, 7))
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.aap2d(x)
+
+
+@register_test_case(
+    module_factory=lambda: AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule())
+def AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic(
+        module, tu: TestUtils):
+    module.forward(tu.rand(1, 512, 7, 7))
+
+
+class AdaptiveAvgPool2dUnitOutputSizeStaticModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.aap2d = torch.nn.AdaptiveAvgPool2d((1, 1))
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 512, 7, 7], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.aap2d(x)
+
+
+@register_test_case(
+    module_factory=lambda: AdaptiveAvgPool2dUnitOutputSizeStaticModule())
+def AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 512, 7, 7))
+
+
+class AdaptiveAvgPool2dUnitOutputSizeDynamicModule(torch.nn.Module):
 
     def __init__(self):
         super().__init__()
@@ -27,9 +92,10 @@ class AdaptiveAvgPool2dModule(torch.nn.Module):
         return self.aap2d(x)
 
 
-@register_test_case(module_factory=lambda: AdaptiveAvgPool2dModule())
-def AdaptiveAvgPool2dModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(10, 3, 8, 9))
+@register_test_case(
+    module_factory=lambda: AdaptiveAvgPool2dUnitOutputSizeDynamicModule())
+def AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 512, 7, 7))
 
 
 # ==============================================================================

@@ -11,12 +11,14 @@ func.func @forward(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
   %int7 = torch.constant.int 7
   %int8 = torch.constant.int 8
   %false = torch.constant.bool false
+  // CHECK: %[[C1:.*]] = torch_c.to_i64 %int1
+  // CHECK: %[[C2:.*]] = torch_c.to_i64 %int2
   // CHECK: %[[NEUTRAL:.*]] = arith.constant -3.40282347E+38 : f32
   // CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}} low[0, 0, 5, 6] high[0, 0, 5, 6]
   // CHECK: %[[OUT:.*]] = linalg.fill ins(%[[NEUTRAL]] : f32) outs(%{{.*}} : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
-  // CHECK: %[[C1:.*]] = arith.constant 1 : index
-  // CHECK: %[[C2:.*]] = arith.constant 2 : index
-  // CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[C1]], %[[C2]]] : tensor<?x?xf32>
+  // CHECK: %[[T1:.*]] = arith.index_cast %[[C1]] : i64 to index
+  // CHECK: %[[T2:.*]] = arith.index_cast %[[C2]] : i64 to index
+  // CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[T1]], %[[T2]]] : tensor<?x?xf32>
   // CHECK: linalg.pooling_nchw_max {dilations = dense<[7, 8]> : vector<2xi64>, strides = dense<[3, 4]> : vector<2xi64>} ins(%[[PADDED]], %[[INIT]] : tensor<?x?x?x?xf32>, tensor<?x?xf32>) outs(%[[OUT]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
   %kernel_size = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
   %stride = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list<int>

@@ -949,3 +949,37 @@ func.func @torch.aten.to.dtype_layout(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f64> {
   %0 = torch.aten.to.dtype_layout %arg0, %int7, %int0, %none, %none, %false, %false, %none : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f64>
   return %0 : !torch.vtensor<[?,?],f64>
 }
+
+// -----
+// CHECK-LABEL: func @torch.aten.adaptive_avg_pool2d(
+// CHECK-SAME: %[[SELF:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+// CHECK: %[[CST7:.*]] = torch.constant.int 7
+// CHECK: %[[OUTPUT_SIZE:.*]] = torch.prim.ListConstruct %[[CST7]], %[[CST7]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[CST2:.*]] = torch.constant.int 2
+// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[SELF]], %[[CST2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
+// CHECK: %[[CST3:.*]] = torch.constant.int 3
+// CHECK: %[[DIM3:.*]] = torch.aten.size.int %[[SELF]], %[[CST3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
+// CHECK: %[[CST1:.*]] = torch.constant.int 1
+// CHECK: %[[CST0:.*]] = torch.constant.int 0
+// CHECK: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK: %[[TRUE:.*]] = torch.constant.bool true
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[COND1:.*]] = torch.aten.eq.int %[[DIM2]], %[[CST7]] : !torch.int, !torch.int -> !torch.bool
+// CHECK: torch.runtime.assert %[[COND1]], "unimplemented: only support cases where input and output size are equal for non-unit output size"
+// CHECK: %[[T1:.*]] = torch.aten.sub.int %[[CST7]], %[[CST1]] : !torch.int, !torch.int -> !torch.int
+// CHECK: %[[T2:.*]] = torch.aten.sub.int %[[DIM2]], %[[T1]] : !torch.int, !torch.int -> !torch.int
+// CHECK: %[[COND2:.*]] = torch.aten.eq.int %[[DIM3]], %[[CST7]] : !torch.int, !torch.int -> !torch.bool
+// CHECK: torch.runtime.assert %[[COND2]], "unimplemented: only support cases where input and output size are equal for non-unit output size"
+// CHECK: %[[T3:.*]] = torch.aten.sub.int %[[CST7]], %[[CST1]] : !torch.int, !torch.int -> !torch.int
+// CHECK: %[[T4:.*]] = torch.aten.sub.int %[[DIM3]], %[[T3]] : !torch.int, !torch.int -> !torch.int
+// CHECK: %[[T5:.*]] = torch.prim.ListConstruct %[[T2]], %[[T4]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[T6:.*]] = torch.prim.ListConstruct %[[CST1]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[T7:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST0]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[OUT:.*]] = torch.aten.avg_pool2d %[[SELF]], %[[T5]], %[[T6]], %[[T7]], %[[FALSE]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32>
+// CHECK: return %[[OUT]] : !torch.vtensor<[?,?,?,?],f32>
+func.func @torch.aten.adaptive_avg_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+  %int7 = torch.constant.int 7
+  %output_size = torch.prim.ListConstruct %int7, %int7 : (!torch.int, !torch.int) -> !torch.list<int>
+  %0 = torch.aten.adaptive_avg_pool2d %arg0, %output_size : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?,?],f32>
+  return %0 : !torch.vtensor<[?,?,?,?],f32>
+}