[tosa] Support for AtenAvgPool2d op

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
pull/882/merge snapshot-20220527.476
Vivek Khandelwal 2022-05-19 18:05:59 +05:30
parent 6f548fc3ad
commit 06750815d1
3 changed files with 164 additions and 74 deletions


@@ -152,4 +152,5 @@ TOSA_PASS_SET = {
"ConvolutionModule2DStatic_basic",
"ElementwiseNegModule_basic",
"TestMultipleTensorReturn_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
}


@@ -2218,9 +2218,8 @@ LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite(
newShape.push_back(1);
auto newType = RankedTensorType::get(newShape, selfType.getElementType());
auto reshapeOp =
rewriter.create<tosa::ReshapeOp>(op.getLoc(), newType, adaptor.self(),
rewriter.getI64ArrayAttr(newShape));
auto reshapeOp = rewriter.create<tosa::ReshapeOp>(
op.getLoc(), newType, adaptor.self(), rewriter.getI64ArrayAttr(newShape));
rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, getTypeConverter()->convertType(op.getType()), reshapeOp);
@@ -2507,7 +2506,8 @@ static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter,
auto mean = zero;
Value xMinusMean = rewriter.create<tosa::SubOp>(loc, outType, x, mean);
// rsqrt of 2
Value rsqrt2 = tosa::getConstTensor<float>(rewriter, op, 0.70710678, {}).getValue();
Value rsqrt2 =
tosa::getConstTensor<float>(rewriter, op, 0.70710678, {}).getValue();
Value erfArg = rewriter.create<tosa::MulOp>(loc, outType, xMinusMean, rsqrt2,
/*shift=*/0);
Value erf = approximateErfOp(rewriter, op, erfArg);
@@ -2623,9 +2623,9 @@ public:
op, "Unimplemented pooling input parsing function");
}
int64_t getOutputDim(int64_t inputDim, int64_t kernelDim, int64_t stride,
int64_t padBefore, int64_t padAfter,
int64_t dilation) const {
static int64_t getOutputDim(int64_t inputDim, int64_t kernelDim,
int64_t stride, int64_t padBefore,
int64_t padAfter, int64_t dilation) {
if (inputDim == ShapedType::kDynamicSize) {
return ShapedType::kDynamicSize;
} else {
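For reference (not part of the diff), the static branch that this hunk clips off computes the usual floor-mode pooling output size; a sketch of that expression, assuming the standard ceil_mode=false convention and the parameter names of getOutputDim:

    // Sketch only: spatial output size of one pooling dimension, ceil_mode=false.
    return (inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1) /
               stride +
           1;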
@@ -2799,78 +2799,130 @@ public:
}
};
template <typename AtenOpT, typename TosaOpT>
class ConvertAtenPoolingOp : public ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT> {
template <typename AtenOpT, typename tosaOp>
static Type getOutputTypeForNonAdaptivePoolingOp(
RankedTensorType inputTy, SmallVectorImpl<int64_t> &kernelSize,
SmallVectorImpl<int64_t> &strideArray, SmallVectorImpl<int64_t> &padArray,
SmallVectorImpl<int64_t> &dilationArray) {
auto inputShape = inputTy.getShape();
auto inputRank = inputTy.getRank();
auto inputElemTy = inputTy.getElementType();
int64_t outputHDim = ConvertAtenPoolingBaseOp<AtenOpT, tosaOp>::getOutputDim(
inputShape[inputRank - 2], kernelSize[0], strideArray[0], padArray[0],
padArray[0], dilationArray[0]);
int64_t outputWDim = ConvertAtenPoolingBaseOp<AtenOpT, tosaOp>::getOutputDim(
inputShape[inputRank - 1], kernelSize[1], strideArray[1], padArray[1],
padArray[1], dilationArray[1]);
SmallVector<int64_t> outputShape;
if (inputRank > 3)
outputShape.push_back(inputShape[0]);
outputShape.push_back(outputHDim);
outputShape.push_back(outputWDim);
outputShape.push_back(inputShape[inputRank - 3]);
return RankedTensorType::get(outputShape, inputElemTy);
}
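As an illustrative check (not part of the diff), plugging in the static shapes from the new test at the end of this change (NCHW input 1x512x7x7, kernel {7, 7}, stride {1, 1}, pad {0, 0}, unit dilation) gives (7 + 0 + 0 - 1*(7 - 1) - 1)/1 + 1 = 1 for both spatial dimensions; the channel dimension is appended last, so the helper would return the NHWC type tensor<1x1x1x512xf32>:

    // Hypothetical call, for illustration only.
    // inputTy: RankedTensorType of the 1x512x7x7 input.
    SmallVector<int64_t, 2> kernel{7, 7}, stride{1, 1}, pad{0, 0}, dilation{1, 1};
    Type outTy = getOutputTypeForNonAdaptivePoolingOp<AtenAvgPool2dOp,
                                                      tosa::AvgPool2dOp>(
        inputTy, kernel, stride, pad, dilation); // tensor<1x1x1x512xf32>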
// Checks the validity of pooling parameters and stores them in the respective
// vector. Also, gets the output type for the pooling op.
template <typename AtenOpT, typename tosaOp>
static LogicalResult getOutputTypeAndPoolingParameters(
AtenOpT op, ConversionPatternRewriter &rewriter, Value inputXchw,
SmallVectorImpl<int64_t> &dilationArray, Type &outputTy, ArrayAttr &kernel,
ArrayAttr &stride, ArrayAttr &pad) {
RankedTensorType inputTy = inputXchw.getType().cast<RankedTensorType>();
if (!inputTy)
return op.emitError("Pooling op requires ranked tensor input");
auto inputRank = inputTy.getRank();
// Rank sanity check.
if (inputTy.getRank() != 4 && inputRank != 3)
return op.emitError("NCHW->NHWC transpose requires 3D or 4D tensor");
SmallVector<int64_t, 2> kernelSizeInts, strideInts, paddingInts;
if (!matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSizeInts)))
return rewriter.notifyMatchFailure(
op, "Non-const kernel_size for pooling op unsupported");
if (!matchPattern(op.stride(), m_TorchConstantIntList(strideInts)))
return rewriter.notifyMatchFailure(
op, "Non-const stride for pooling op unsupported");
if (!matchPattern(op.padding(), m_TorchConstantIntList(paddingInts)))
return rewriter.notifyMatchFailure(
op, "Non-const padding factor for pooling op unsupported");
kernel = rewriter.getI64ArrayAttr(kernelSizeInts);
stride = rewriter.getI64ArrayAttr(strideInts);
pad = rewriter.getI64ArrayAttr(
{paddingInts[0], paddingInts[0], paddingInts[1], paddingInts[1]});
// FIXME: add ceil_mode support.
bool ceilMode;
if (!matchPattern(op.ceil_mode(), m_TorchConstantBool(&ceilMode)))
return rewriter.notifyMatchFailure(
op, "only support constant bool ceil_mode for pooling op");
if (ceilMode)
return rewriter.notifyMatchFailure(
op, "only support ceil_mode equals to False for pooling op");
outputTy = getOutputTypeForNonAdaptivePoolingOp<AtenOpT, tosaOp>(
inputTy, kernelSizeInts, strideInts, paddingInts, dilationArray);
return success();
}
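One detail worth noting in getOutputTypeAndPoolingParameters above: torch pooling ops carry a symmetric two-element padding list {padH, padW}, which is replicated into TOSA's four-element pad attribute in [top, bottom, left, right] order; for example, padding {1, 2} becomes pad = [1, 1, 2, 2].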
class ConvertAtenMaxPool2dOp
: public ConvertAtenPoolingBaseOp<AtenMaxPool2dOp, tosa::MaxPool2dOp> {
public:
using ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::ConvertAtenPoolingBaseOp;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult processInputs(AtenOpT op, OpAdaptor adaptor,
using ConvertAtenPoolingBaseOp<AtenMaxPool2dOp,
tosa::MaxPool2dOp>::ConvertAtenPoolingBaseOp;
LogicalResult processInputs(AtenMaxPool2dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Value &input,
ArrayAttr &kernel, ArrayAttr &stride,
ArrayAttr &pad, Type &outputTy) const override {
auto inputXchw = adaptor.self();
auto inputTy = inputXchw.getType().template cast<RankedTensorType>();
if (!inputTy)
return op.emitError("Adaptive avgpool requires ranked tensor input");
auto inputShape = inputTy.getShape();
auto inputRank = inputTy.getRank();
auto inputElemTy = inputTy.getElementType();
// Rank sanity check.
if (inputTy.getRank() != 4 && inputRank != 3)
return op.emitError("NCHW->NHWC transpose requires 3D or 4D tensor");
// Transpose to xHWC
input =
ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::transposePoolingInputToHwc(
op, rewriter, inputXchw);
SmallVector<int64_t> kernelSize;
if (!matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSize)))
return rewriter.notifyMatchFailure(
op, "Non-const kernel_size for adaptive pooling unsupported.");
kernel = rewriter.getI64ArrayAttr(kernelSize);
SmallVector<int64_t> strideArray;
if (!matchPattern(op.stride(), m_TorchConstantIntList(strideArray)))
return rewriter.notifyMatchFailure(
op, "Non-const stride for adaptive pooling unsupported.");
stride = rewriter.getI64ArrayAttr(strideArray);
SmallVector<int64_t> padArray;
if (!matchPattern(op.padding(), m_TorchConstantIntList(padArray)))
return rewriter.notifyMatchFailure(
op, "Non-const pad for adaptive pooling unsupported.");
pad = rewriter.getI64ArrayAttr(
{padArray[0], padArray[0], padArray[1], padArray[1]});
SmallVector<int64_t> dilationArray;
SmallVector<int64_t, 2> dilationArray;
if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationArray)))
return rewriter.notifyMatchFailure(
op, "Non-const dilation for adaptive pooling unsupported.");
op, "Non-const dilation for pooling op unsupported.");
// TOSA pooling only supports unit dilation.
if (dilationArray[0] > 1 || dilationArray[1] > 1)
return op.emitError("Cannot process non-unit pooling dilation.");
// FIXME: add ceil_mode support.
if (failed(getOutputTypeAndPoolingParameters<AtenMaxPool2dOp,
tosa::MaxPool2dOp>(
op, rewriter, adaptor.self(), dilationArray, outputTy, kernel,
stride, pad)))
return rewriter.notifyMatchFailure(
op, "invalid pooling parameters or input type");
int64_t outputHDim =
ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::getOutputDim(
inputShape[inputRank - 2], kernelSize[0], strideArray[0],
padArray[0], padArray[0], dilationArray[0]);
int64_t outputWDim =
ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::getOutputDim(
inputShape[inputRank - 1], kernelSize[1], strideArray[1],
padArray[1], padArray[1], dilationArray[1]);
SmallVector<int64_t> outputShape;
if (inputRank > 3)
outputShape.push_back(inputShape[0]);
outputShape.push_back(outputHDim);
outputShape.push_back(outputWDim);
outputShape.push_back(inputShape[inputRank - 3]);
outputTy = RankedTensorType::get(outputShape, inputElemTy);
// Transpose to xHWC
input = ConvertAtenPoolingBaseOp<AtenMaxPool2dOp, tosa::MaxPool2dOp>::
transposePoolingInputToHwc(op, rewriter, adaptor.self());
return success();
}
};
class ConvertAtenAvgPool2dOp
: public ConvertAtenPoolingBaseOp<AtenAvgPool2dOp, tosa::AvgPool2dOp> {
public:
using ConvertAtenPoolingBaseOp<AtenAvgPool2dOp,
tosa::AvgPool2dOp>::ConvertAtenPoolingBaseOp;
LogicalResult processInputs(AtenAvgPool2dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Value &input,
ArrayAttr &kernel, ArrayAttr &stride,
ArrayAttr &pad, Type &outputTy) const override {
SmallVector<int64_t, 2> dilationArray{1, 1};
if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool2dOp,
tosa::AvgPool2dOp>(
op, rewriter, adaptor.self(), dilationArray, outputTy, kernel,
stride, pad)))
return rewriter.notifyMatchFailure(
op, "invalid pooling parameters or input type");
// Transpose to xHWC
input = ConvertAtenPoolingBaseOp<AtenAvgPool2dOp, tosa::AvgPool2dOp>::
transposePoolingInputToHwc(op, rewriter, adaptor.self());
return success();
}
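The AvgPool2d conversion hard-codes dilationArray{1, 1} because aten.avg_pool2d, unlike aten.max_pool2d, has no dilation operand (its operands are self, kernel_size, stride, padding, ceil_mode, count_include_pad and divisor_override, as the test below also shows), so unit dilation is simply forwarded to the shared parameter helper.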
@@ -3109,12 +3161,6 @@ public:
INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp);
#undef INSERT_LINEAR_ATEMOP_PATTERN
#define INSERT_POOLING_ATENOP_PATTERN(AtenOp, TosaOpT) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenPoolingOp<AtenOp, TosaOpT>>(typeConverter, context);
INSERT_POOLING_ATENOP_PATTERN(AtenMaxPool2dOp, tosa::MaxPool2dOp);
#undef INSERT_POOLING_ATEMOP_PATTERN
#define INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenOp, TosaOpT) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenAdaptivePoolingOp<AtenOp, TosaOpT>>(typeConverter, \
@@ -3123,6 +3169,12 @@ public:
tosa::AvgPool2dOp);
#undef INSERT_ADAPTIVE_POOLING_ATEMOP_PATTERN
target.addIllegalOp<AtenMaxPool2dOp>();
patterns.add<ConvertAtenMaxPool2dOp>(typeConverter, context);
target.addIllegalOp<AtenAvgPool2dOp>();
patterns.add<ConvertAtenAvgPool2dOp>(typeConverter, context);
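With the op-specific ConvertAtenMaxPool2dOp and ConvertAtenAvgPool2dOp classes above, the generic INSERT_POOLING_ATENOP_PATTERN macro (and the templated ConvertAtenPoolingOp it instantiated) is no longer needed, so the two concrete patterns are registered directly against their now-illegal ops.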
#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenConstPatternOp<AtenOp, fillVal>>(typeConverter, \


@@ -757,3 +757,40 @@ func.func @torch.aten.dropout$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch
%0 = torch.aten.dropout %arg0, %float0.000000e00, %false : !torch.vtensor<[?,?],f32>, !torch.float, !torch.bool -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.avg_pool2d$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,512,7,7],f32>) -> !torch.vtensor<[1,512,1,1],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,512,7,7],f32> -> tensor<1x512x7x7xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 7
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
// CHECK: %[[VAL_4:.*]] = torch.constant.int 0
// CHECK: %[[VAL_5:.*]] = torch.constant.bool false
// CHECK: %[[VAL_6:.*]] = torch.constant.bool true
// CHECK: %[[VAL_7:.*]] = torch.constant.none
// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_11:.*]] = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: %[[VAL_12:.*]] = "tosa.transpose"(%[[VAL_1]], %[[VAL_11]]) : (tensor<1x512x7x7xf32>, tensor<4xi32>) -> tensor<1x7x7x512xf32>
// CHECK: %[[VAL_13:.*]] = "tosa.avg_pool2d"(%[[VAL_12]]) {kernel = [7, 7], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x7x7x512xf32>) -> tensor<1x1x1x512xf32>
// CHECK: %[[VAL_14:.*]] = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: %[[VAL_15:.*]] = "tosa.transpose"(%[[VAL_13]], %[[VAL_14]]) : (tensor<1x1x1x512xf32>, tensor<4xi32>) -> tensor<1x512x1x1xf32>
// CHECK: %[[VAL_16:.*]] = tensor.cast %[[VAL_15]] : tensor<1x512x1x1xf32> to tensor<1x512x1x1xf32>
// CHECK: %[[VAL_17:.*]] = torch_c.from_builtin_tensor %[[VAL_16]] : tensor<1x512x1x1xf32> -> !torch.vtensor<[1,512,1,1],f32>
// CHECK: return %[[VAL_17]] : !torch.vtensor<[1,512,1,1],f32>
// CHECK: }
func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32> ) -> !torch.vtensor<[1,512,1,1],f32> {
%int7 = torch.constant.int 7
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%false = torch.constant.bool false
%true = torch.constant.bool true
%none = torch.constant.none
%kernel = torch.prim.ListConstruct %int7, %int7 : (!torch.int, !torch.int) -> !torch.list<int>
%stride = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
%0 = torch.aten.avg_pool2d %arg0, %kernel, %stride, %padding, %false, %true, %none : !torch.vtensor<[1,512,7,7],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,512,1,1],f32>
return %0 : !torch.vtensor<[1,512,1,1],f32>
}
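Note the structure the CHECK lines verify: since tosa.avg_pool2d operates on NHWC tensors while the torch tensors here are NCHW, the lowering wraps the pool in a pair of tosa.transpose ops, permuting with [0, 2, 3, 1] into NHWC before the pool and with [0, 3, 1, 2] back to NCHW afterwards, followed by a tensor.cast to the converted result type.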