mirror of https://github.com/llvm/torch-mlir
[tosa] Add maxpool2d and adaptive_avgpool2d support (#550)
Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
parent 5d9a15263a
commit 0f083e770a

@@ -90,4 +90,5 @@ TOSA_PASS_SET = {
    "FlattenRank0Module_basic",
    "ElementwiseFlattenBroadcastModule_basic",
    "SquareModule_basic",
    "MaxPool2dStaticModule_basic",
}

@@ -17,9 +17,9 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Matchers.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"

@@ -1987,6 +1987,277 @@ LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite(
  return success();
}

template <typename AtenOpT, typename TosaOpT>
class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
public:
  using OpConversionPattern<AtenOpT>::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;

  // Different pooling variants need to process inputs differently, e.g.
  // adaptive pooling generates the kernel size rather than receive it. This
  // function also transposes inputs.
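  // Overriding implementations are expected to hand back the (N)HWC-transposed
  // input, TOSA kernel/stride/pad attributes, and the (N)HWC output type (see
  // the two subclasses below).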
  virtual LogicalResult processInputs(AtenOpT op, OpAdaptor adaptor,
                                      ConversionPatternRewriter &rewriter,
                                      Value &input, ArrayAttr &kernel,
                                      ArrayAttr &stride, ArrayAttr &pad,
                                      Type &outputTy) const {
    return rewriter.notifyMatchFailure(
        op, "Unimplemented pooling input parsing function");
  }
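
  // Standard sliding-window output-size computation; e.g. inputDim = 7,
  // kernelDim = 3, stride = 2, padBefore = padAfter = 0, dilation = 1 gives
  // (7 + 0 + 0 - 1 * (3 - 1) - 1) / 2 + 1 = 3.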
  int64_t getOutputDim(int64_t inputDim, int64_t kernelDim, int64_t stride,
                       int64_t padBefore, int64_t padAfter,
                       int64_t dilation) const {
    if (inputDim == ShapedType::kDynamicSize) {
      return ShapedType::kDynamicSize;
    } else {
      return (
          (inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1) /
              stride +
          1);
    }
  }

  // Apply the transposeDims vector on input to generate a transposed form.
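  // For example, transposeDims = {0, 2, 3, 1} maps an NCHW tensor of shape
  // [N, C, H, W] to an NHWC tensor of shape [N, H, W, C].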
  Value transposeTensor(AtenOpT op, ConversionPatternRewriter &rewriter,
                        Value input, ArrayRef<int32_t> transposeDims) const {
    auto inputTy = input.getType().template cast<RankedTensorType>();
    auto inputElemTy = inputTy.getElementType();
    auto inputShape = inputTy.getShape();
    auto inputRank = inputTy.getRank();

    llvm::Optional<Value> transposeDimsConst = tosa::getConstTensor<int32_t>(
        rewriter, op,
        /*vec=*/transposeDims,
        /*shape=*/{static_cast<int32_t>(inputRank)});

    SmallVector<int64_t> transposedInputShape;
    for (auto &dim : transposeDims)
      transposedInputShape.push_back(inputShape[dim]);
    auto transposedInputType =
        RankedTensorType::get(transposedInputShape, inputElemTy);
    return rewriter
        .create<tosa::TransposeOp>(op->getLoc(), transposedInputType, input,
                                   transposeDimsConst.getValue())
        .getResult();
  }

  Value transposePoolingInputToHwc(AtenOpT op,
                                   ConversionPatternRewriter &rewriter,
                                   Value input) const {
    auto inputRank =
        input.getType().template cast<RankedTensorType>().getRank();

    SmallVector<int32_t> nchwToNhwc4DTransposeDims({0, 2, 3, 1});
    SmallVector<int32_t> chwToHwc3DTransposeDims({1, 2, 0});

    return transposeTensor(op, rewriter, input,
                           inputRank == 3 ? chwToHwc3DTransposeDims
                                          : nchwToNhwc4DTransposeDims);
  }

  Value transposePoolingOutputToChw(AtenOpT op,
                                    ConversionPatternRewriter &rewriter,
                                    Value input) const {
    auto inputTy = input.getType().template cast<RankedTensorType>();
    auto inputRank = inputTy.getRank();

    SmallVector<int32_t> nhwcToNchw4DTransposeDims({0, 3, 1, 2});
    SmallVector<int32_t> hwcToChw3DTransposeDims({2, 0, 1});

    return transposeTensor(op, rewriter, input,
                           inputRank == 3 ? hwcToChw3DTransposeDims
                                          : nhwcToNchw4DTransposeDims);
  }

  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value input;
    ArrayAttr kernel, stride, pad;
    Type outputTy;

    // Attempts to read input and kernel parameters, or synthesize them in the
    // case of adaptive pooling. Also performs input CHW->HWC transpose.
    if (failed(processInputs(op, adaptor, rewriter, input, kernel, stride, pad,
                             outputTy)))
      return op.emitError("Failed to process inputs for pooling");

    auto pooledOutput =
        rewriter
            .create<TosaOpT>(op->getLoc(), outputTy, input, kernel, stride, pad)
            .getResult();

    auto transposedOutput =
        ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::transposePoolingOutputToChw(
            op, rewriter, pooledOutput);

    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        op,
        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
            op.getType()),
        transposedOutput);

    return success();
  }
};

template <typename AtenOpT, typename TosaOpT>
class ConvertAtenAdaptivePoolingOp
    : public ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT> {
public:
  using ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::ConvertAtenPoolingBaseOp;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult processInputs(AtenOpT op, OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter, Value &input,
                              ArrayAttr &kernel, ArrayAttr &stride,
                              ArrayAttr &pad, Type &outputTy) const override {
    auto inputXchw = adaptor.self();
    auto inputTy = inputXchw.getType().template cast<RankedTensorType>();
    if (!inputTy)
      return op.emitError("Adaptive avgpool requires ranked tensor input");

    auto inputShape = inputTy.getShape();
    auto inputRank = inputTy.getRank();
    auto inputElemTy = inputTy.getElementType();

    // Rank sanity check.
    if (inputTy.getRank() != 4 && inputRank != 3)
      return op.emitError("NCHW->NHWC transpose requires 3D or 4D tensor");

    int64_t inputHDim = inputShape[inputRank - 2];
    int64_t inputWDim = inputShape[inputRank - 1];

    SmallVector<int64_t> outputSize;
    if (!matchPattern(op.output_size(), m_TorchConstantIntList(outputSize)))
      return rewriter.notifyMatchFailure(
          op, "Non-const output_size for adaptive pooling unsupported.");

    SmallVector<int64_t> kernelDims;
    int64_t outputHDim, outputWDim;
    if (outputSize.size() == 1) {
      outputHDim = outputWDim = outputSize[0];
    } else {
      if (outputSize.size() != 2)
        return op.emitError(
            "Adaptive avgpool output_size not 1 or 2 elements.");

      // Assumes 'None' (e.g. output_size=(None, 5)) is expressed as <=0.
      outputHDim =
          (outputSize[0] <= 0) ? inputShape[inputRank - 2] : outputSize[0];
      outputWDim =
          (outputSize[1] <= 0) ? inputShape[inputRank - 1] : outputSize[1];
    }

    // In adaptive pooling,
    // stride = inputDim // outputDim
    // kernel = inputDim - (outputDim - 1) * stride
    // pad = 0, dilation = 1
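    // e.g. inputDim = 7, outputDim = 3: stride = 7 / 3 = 2 and
    //      kernel = 7 - (3 - 1) * 2 = 3.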

    int64_t strideH = inputShape[inputRank - 2] / outputHDim;
    int64_t strideW = inputShape[inputRank - 1] / outputWDim;

    kernelDims.push_back(inputHDim - (outputHDim - 1) * strideH);
    kernelDims.push_back(inputWDim - (outputWDim - 1) * strideW);

    SmallVector<int64_t> outputShape;
    if (inputRank > 3)
      outputShape.push_back(inputShape[0]);
    outputShape.push_back(outputHDim);
    outputShape.push_back(outputWDim);
    outputShape.push_back(inputShape[inputRank - 3]);

    // Transpose to xHWC
    input =
        ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::transposePoolingInputToHwc(
            op, rewriter, inputXchw);
    kernel = rewriter.getI64ArrayAttr(kernelDims);
    stride = rewriter.getI64ArrayAttr({strideH, strideW});
    // Adaptive pooling does unit dilation and zero pad.
    pad = rewriter.getI64ArrayAttr({0, 0, 0, 0});
    outputTy = RankedTensorType::get(outputShape, inputElemTy);

    return success();
  }
};

template <typename AtenOpT, typename TosaOpT>
class ConvertAtenPoolingOp : public ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT> {
public:
  using ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::ConvertAtenPoolingBaseOp;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult processInputs(AtenOpT op, OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter, Value &input,
                              ArrayAttr &kernel, ArrayAttr &stride,
                              ArrayAttr &pad, Type &outputTy) const override {

    auto inputXchw = adaptor.self();
    auto inputTy = inputXchw.getType().template cast<RankedTensorType>();
    if (!inputTy)
      return op.emitError("Pooling op requires ranked tensor input");

    auto inputShape = inputTy.getShape();
    auto inputRank = inputTy.getRank();
    auto inputElemTy = inputTy.getElementType();

    // Rank sanity check.
    if (inputTy.getRank() != 4 && inputRank != 3)
      return op.emitError("NCHW->NHWC transpose requires 3D or 4D tensor");

    // Transpose to xHWC
    input =
        ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::transposePoolingInputToHwc(
            op, rewriter, inputXchw);

    SmallVector<int64_t> kernelSize;
    if (!matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSize)))
      return rewriter.notifyMatchFailure(
          op, "Non-const kernel_size for pooling op unsupported.");
    kernel = rewriter.getI64ArrayAttr(kernelSize);

    SmallVector<int64_t> strideArray;
    if (!matchPattern(op.stride(), m_TorchConstantIntList(strideArray)))
      return rewriter.notifyMatchFailure(
          op, "Non-const stride for pooling op unsupported.");
    stride = rewriter.getI64ArrayAttr(strideArray);

    SmallVector<int64_t> padArray;
    if (!matchPattern(op.padding(), m_TorchConstantIntList(padArray)))
      return rewriter.notifyMatchFailure(
          op, "Non-const pad for pooling op unsupported.");
    pad = rewriter.getI64ArrayAttr(
        {padArray[0], padArray[0], padArray[1], padArray[1]});

    SmallVector<int64_t> dilationArray;
    if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationArray)))
      return rewriter.notifyMatchFailure(
          op, "Non-const dilation for pooling op unsupported.");
    // TOSA pooling only supports unit dilation.
    if (dilationArray[0] > 1 || dilationArray[1] > 1)
      return op.emitError("Cannot process non-unit pooling dilation.");

    // FIXME: add ceil_mode support.

    int64_t outputHDim =
        ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::getOutputDim(
            inputShape[inputRank - 2], kernelSize[0], strideArray[0],
            padArray[0], padArray[0], dilationArray[0]);
    int64_t outputWDim =
        ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::getOutputDim(
            inputShape[inputRank - 1], kernelSize[1], strideArray[1],
            padArray[1], padArray[1], dilationArray[1]);
    SmallVector<int64_t> outputShape;
    if (inputRank > 3)
      outputShape.push_back(inputShape[0]);
    outputShape.push_back(outputHDim);
    outputShape.push_back(outputWDim);
    outputShape.push_back(inputShape[inputRank - 3]);
    outputTy = RankedTensorType::get(outputShape, inputElemTy);

    return success();
  }
};

} // namespace

// -----------------------------------------------------------------------------

@@ -2131,6 +2402,20 @@ public:
    INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp);
#undef INSERT_LINEAR_ATENOP_PATTERN

#define INSERT_POOLING_ATENOP_PATTERN(AtenOp, TosaOpT)                         \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenPoolingOp<AtenOp, TosaOpT>>(typeConverter, context);
    INSERT_POOLING_ATENOP_PATTERN(AtenMaxPool2dOp, tosa::MaxPool2dOp);
#undef INSERT_POOLING_ATENOP_PATTERN

#define INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenOp, TosaOpT)                \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenAdaptivePoolingOp<AtenOp, TosaOpT>>(typeConverter,   \
                                                              context);
    INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenAdaptiveAvgPool2dOp,
                                           tosa::AvgPool2dOp);
#undef INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN

#define INSERT_ATENOP_PATTERN(AtenOp)                                          \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);