[tosa] Add maxpool2d and adaptive_avgpool2d support (#550)

Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
Suraj Sudhir 2022-01-31 13:34:09 -08:00 committed by GitHub
parent 5d9a15263a
commit 0f083e770a
2 changed files with 287 additions and 1 deletion


@@ -90,4 +90,5 @@ TOSA_PASS_SET = {
"FlattenRank0Module_basic",
"ElementwiseFlattenBroadcastModule_basic",
"SquareModule_basic",
"MaxPool2dStaticModule_basic",
}
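The new pass-set entry enables a static-shape max-pool e2e test for the TOSA backend. As a rough illustration only (the class name, shapes, and parameters below are assumptions, not the actual test definition), the lowerings added in this commit target modules of this kind:

import torch

class PoolingExample(torch.nn.Module):
    """Illustrative module covering the two ops this patch lowers to TOSA."""

    def __init__(self):
        super().__init__()
        # Static kernel/stride/padding with unit dilation: the configuration
        # the max_pool2d lowering in this patch handles.
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1,
                                          dilation=1)
        # Adaptive average pooling; the lowering derives kernel and stride
        # from the requested output size.
        self.adaptive = torch.nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        return self.adaptive(self.maxpool(x))

# Fixed-shape NCHW input, matching the static-shape requirement of the lowering.
out = PoolingExample()(torch.randn(1, 64, 112, 112))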


@@ -17,9 +17,9 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Matchers.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
@@ -1987,6 +1987,277 @@ LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite(
return success();
}
template <typename AtenOpT, typename TosaOpT>
class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
// Different pooling variants need to process inputs differently, e.g.
// adaptive pooling generates the kernel size rather than receiving it. This
// function also transposes the input into the HWC layout.
virtual LogicalResult processInputs(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
Value &input, ArrayAttr &kernel,
ArrayAttr &stride, ArrayAttr &pad,
Type &outputTy) const {
return rewriter.notifyMatchFailure(
op, "Unimplemented pooling input parsing function");
}
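// Computes the size of one spatial output dimension using the usual floor
// formula. Illustrative example: inputDim = 112, kernelDim = 3, stride = 2,
// padBefore = padAfter = 1, dilation = 1 gives
// (112 + 1 + 1 - 2 - 1) / 2 + 1 = 56.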
int64_t getOutputDim(int64_t inputDim, int64_t kernelDim, int64_t stride,
int64_t padBefore, int64_t padAfter,
int64_t dilation) const {
if (inputDim == ShapedType::kDynamicSize) {
return ShapedType::kDynamicSize;
} else {
return (
(inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1) /
stride +
1);
}
}
// Apply the transposeDims vector on input to generate a transposed form.
Value transposeTensor(AtenOpT op, ConversionPatternRewriter &rewriter,
Value input, ArrayRef<int32_t> transposeDims) const {
auto inputTy = input.getType().template cast<RankedTensorType>();
auto inputElemTy = inputTy.getElementType();
auto inputShape = inputTy.getShape();
auto inputRank = inputTy.getRank();
llvm::Optional<Value> transposeDimsConst = tosa::getConstTensor<int32_t>(
rewriter, op,
/*vec=*/transposeDims,
/*shape=*/{static_cast<int32_t>(inputRank)});
SmallVector<int64_t> transposedInputShape;
for (auto &dim : transposeDims)
transposedInputShape.push_back(inputShape[dim]);
auto transposedInputType =
RankedTensorType::get(transposedInputShape, inputElemTy);
return rewriter
.create<tosa::TransposeOp>(op->getLoc(), transposedInputType, input,
transposeDimsConst.getValue())
.getResult();
}
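// Transposes a pooling input from the NCHW (or CHW) layout PyTorch uses to
// the NHWC (or HWC) layout expected by TOSA pooling ops.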
Value transposePoolingInputToHwc(AtenOpT op,
ConversionPatternRewriter &rewriter,
Value input) const {
auto inputRank =
input.getType().template cast<RankedTensorType>().getRank();
SmallVector<int32_t> nchwToNhwc4DTransposeDims({0, 2, 3, 1});
SmallVector<int32_t> chwToHwc3DTransposeDims({1, 2, 0});
return transposeTensor(op, rewriter, input,
inputRank == 3 ? chwToHwc3DTransposeDims
: nchwToNhwc4DTransposeDims);
}
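// Transposes the pooling result back from NHWC (or HWC) to NCHW (or CHW) so
// the replacement matches the original op's layout.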
Value transposePoolingOutputToChw(AtenOpT op,
ConversionPatternRewriter &rewriter,
Value input) const {
auto inputTy = input.getType().template cast<RankedTensorType>();
auto inputRank = inputTy.getRank();
SmallVector<int32_t> nhwcToNchw4DTransposeDims({0, 3, 1, 2});
SmallVector<int32_t> hwcToChw3DTransposeDims({2, 0, 1});
return transposeTensor(op, rewriter, input,
inputRank == 3 ? hwcToChw3DTransposeDims
: nhwcToNchw4DTransposeDims);
}
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value input;
ArrayAttr kernel, stride, pad;
Type outputTy;
// Attempts to read input and kernel parameters, or synthesize them in the
// case of adaptive pooling. Also performs input CHW->HWC transpose.
if (failed(processInputs(op, adaptor, rewriter, input, kernel, stride, pad,
outputTy)))
return op.emitError("Failed to process inputs for pooling");
auto pooledOutput =
rewriter
.create<TosaOpT>(op->getLoc(), outputTy, input, kernel, stride, pad)
.getResult();
auto transposedOutput =
ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::transposePoolingOutputToChw(
op, rewriter, pooledOutput);
rewriter.replaceOpWithNewOp<tensor::CastOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
transposedOutput);
return success();
}
};
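// Converts adaptive pooling ops (e.g. aten.adaptive_avg_pool2d), deriving the
// kernel size and stride from the requested output size.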
template <typename AtenOpT, typename TosaOpT>
class ConvertAtenAdaptivePoolingOp
: public ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT> {
public:
using ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::ConvertAtenPoolingBaseOp;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult processInputs(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Value &input,
ArrayAttr &kernel, ArrayAttr &stride,
ArrayAttr &pad, Type &outputTy) const override {
auto inputXchw = adaptor.self();
auto inputTy = inputXchw.getType().template dyn_cast<RankedTensorType>();
if (!inputTy)
return op.emitError("Adaptive avgpool requires ranked tensor input");
auto inputShape = inputTy.getShape();
auto inputRank = inputTy.getRank();
auto inputElemTy = inputTy.getElementType();
// Rank sanity check.
if (inputRank != 4 && inputRank != 3)
return op.emitError("NCHW->NHWC transpose requires 3D or 4D tensor");
int64_t inputHDim = inputShape[inputRank - 2];
int64_t inputWDim = inputShape[inputRank - 1];
SmallVector<int64_t> outputSize;
if (!matchPattern(op.output_size(), m_TorchConstantIntList(outputSize)))
return rewriter.notifyMatchFailure(
op, "Non-const output_size for adaptive pooling unsupported.");
SmallVector<int64_t> kernelDims;
int64_t outputHDim, outputWDim;
if (outputSize.size() == 1) {
outputHDim = outputWDim = outputSize[0];
} else {
if (outputSize.size() != 2)
return op.emitError(
"Adaptive avgpool output_size must have 1 or 2 elements.");
// Assumes 'None' (e.g. output_size=(None, 5)) is expressed as <= 0.
outputHDim =
(outputSize[0] <= 0) ? inputShape[inputRank - 2] : outputSize[0];
outputWDim =
(outputSize[1] <= 0) ? inputShape[inputRank - 1] : outputSize[1];
}
// In adaptive pooling,
// stride = inputDim // outputDim
// kernel = inputDim - (outputDim - 1) * stride
// pad = 0, dilation = 1
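// For example, a 7x7 input with output_size = (1, 1) gives stride = 7 and
// kernel = 7x7, i.e. a global average pool over the whole spatial extent.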
int64_t strideH = inputShape[inputRank - 2] / outputHDim;
int64_t strideW = inputShape[inputRank - 1] / outputWDim;
kernelDims.push_back(inputHDim - (outputHDim - 1) * strideH);
kernelDims.push_back(inputWDim - (outputWDim - 1) * strideW);
SmallVector<int64_t> outputShape;
if (inputRank > 3)
outputShape.push_back(inputShape[0]);
outputShape.push_back(outputHDim);
outputShape.push_back(outputWDim);
outputShape.push_back(inputShape[inputRank - 3]);
// Transpose to xHWC
input =
ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::transposePoolingInputToHwc(
op, rewriter, inputXchw);
kernel = rewriter.getI64ArrayAttr(kernelDims);
stride = rewriter.getI64ArrayAttr({strideH, strideW});
// Adaptive pooling does unit dilation and zero pad.
pad = rewriter.getI64ArrayAttr({0, 0, 0, 0});
outputTy = RankedTensorType::get(outputShape, inputElemTy);
return success();
}
};
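// Converts standard pooling ops (e.g. aten.max_pool2d) whose kernel size,
// stride, padding, and dilation are explicit operands.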
template <typename AtenOpT, typename TosaOpT>
class ConvertAtenPoolingOp : public ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT> {
public:
using ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::ConvertAtenPoolingBaseOp;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult processInputs(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Value &input,
ArrayAttr &kernel, ArrayAttr &stride,
ArrayAttr &pad, Type &outputTy) const override {
auto inputXchw = adaptor.self();
auto inputTy = inputXchw.getType().template dyn_cast<RankedTensorType>();
if (!inputTy)
return op.emitError("Pooling op requires ranked tensor input");
auto inputShape = inputTy.getShape();
auto inputRank = inputTy.getRank();
auto inputElemTy = inputTy.getElementType();
// Rank sanity check.
if (inputRank != 4 && inputRank != 3)
return op.emitError("NCHW->NHWC transpose requires 3D or 4D tensor");
// Transpose to xHWC
input =
ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::transposePoolingInputToHwc(
op, rewriter, inputXchw);
SmallVector<int64_t> kernelSize;
if (!matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSize)))
return rewriter.notifyMatchFailure(
op, "Non-const kernel_size for adaptive pooling unsupported.");
kernel = rewriter.getI64ArrayAttr(kernelSize);
SmallVector<int64_t> strideArray;
if (!matchPattern(op.stride(), m_TorchConstantIntList(strideArray)))
return rewriter.notifyMatchFailure(
op, "Non-const stride for adaptive pooling unsupported.");
stride = rewriter.getI64ArrayAttr(strideArray);
SmallVector<int64_t> padArray;
if (!matchPattern(op.padding(), m_TorchConstantIntList(padArray)))
return rewriter.notifyMatchFailure(
op, "Non-const pad for adaptive pooling unsupported.");
pad = rewriter.getI64ArrayAttr(
{padArray[0], padArray[0], padArray[1], padArray[1]});
SmallVector<int64_t> dilationArray;
if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationArray)))
return rewriter.notifyMatchFailure(
op, "Non-const dilation for adaptive pooling unsupported.");
// TOSA pooling only supports unit dilation.
if (dilationArray[0] > 1 || dilationArray[1] > 1)
return op.emitError("Cannot process non-unit pooling dilation.");
// FIXME: add ceil_mode support.
int64_t outputHDim =
ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::getOutputDim(
inputShape[inputRank - 2], kernelSize[0], strideArray[0],
padArray[0], padArray[0], dilationArray[0]);
int64_t outputWDim =
ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::getOutputDim(
inputShape[inputRank - 1], kernelSize[1], strideArray[1],
padArray[1], padArray[1], dilationArray[1]);
SmallVector<int64_t> outputShape;
if (inputRank > 3)
outputShape.push_back(inputShape[0]);
outputShape.push_back(outputHDim);
outputShape.push_back(outputWDim);
outputShape.push_back(inputShape[inputRank - 3]);
outputTy = RankedTensorType::get(outputShape, inputElemTy);
return success();
}
};
} // namespace
// -----------------------------------------------------------------------------
@@ -2131,6 +2402,20 @@ public:
INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp);
#undef INSERT_LINEAR_ATENOP_PATTERN
#define INSERT_POOLING_ATENOP_PATTERN(AtenOp, TosaOpT) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenPoolingOp<AtenOp, TosaOpT>>(typeConverter, context);
INSERT_POOLING_ATENOP_PATTERN(AtenMaxPool2dOp, tosa::MaxPool2dOp);
#undef INSERT_POOLING_ATENOP_PATTERN
#define INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenOp, TosaOpT) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenAdaptivePoolingOp<AtenOp, TosaOpT>>(typeConverter, \
context);
INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenAdaptiveAvgPool2dOp,
tosa::AvgPool2dOp);
#undef INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN
#define INSERT_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);