mirror of https://github.com/llvm/torch-mlir
[tosa] Support for AtenAvgPool2d op
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
pull/882/merge snapshot-20220527.476
parent 6f548fc3ad
commit 06750815d1
@@ -152,4 +152,5 @@ TOSA_PASS_SET = {
    "ConvolutionModule2DStatic_basic",
    "ElementwiseNegModule_basic",
    "TestMultipleTensorReturn_basic",
    "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
}
@@ -2218,9 +2218,8 @@ LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite(
newShape.push_back(1);

auto newType = RankedTensorType::get(newShape, selfType.getElementType());
auto reshapeOp =
rewriter.create<tosa::ReshapeOp>(op.getLoc(), newType, adaptor.self(),
rewriter.getI64ArrayAttr(newShape));
auto reshapeOp = rewriter.create<tosa::ReshapeOp>(
op.getLoc(), newType, adaptor.self(), rewriter.getI64ArrayAttr(newShape));

rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, getTypeConverter()->convertType(op.getType()), reshapeOp);
@@ -2507,7 +2506,8 @@ static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter,
auto mean = zero;
Value xMinusMean = rewriter.create<tosa::SubOp>(loc, outType, x, mean);
// rsqrt of 2
Value rsqrt2 = tosa::getConstTensor<float>(rewriter, op, 0.70710678, {}).getValue();
Value rsqrt2 =
tosa::getConstTensor<float>(rewriter, op, 0.70710678, {}).getValue();
Value erfArg = rewriter.create<tosa::MulOp>(loc, outType, xMinusMean, rsqrt2,
/*shift=*/0);
Value erf = approximateErfOp(rewriter, op, erfArg);
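Beyond the line-wrapping change in this hunk, the constant is worth a note: 0.70710678 is 1/sqrt(2). The unit-normal CDF is commonly expressed through the error function as

\Phi(x) = \frac{1}{2}\left(1 + \operatorname{erf}\left(\frac{x - \mu}{\sqrt{2}}\right)\right), \qquad \frac{1}{\sqrt{2}} \approx 0.70710678

which is why the centered value xMinusMean is scaled by rsqrt2 before being passed to approximateErfOp.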
@@ -2623,9 +2623,9 @@ public:
op, "Unimplemented pooling input parsing function");
}

int64_t getOutputDim(int64_t inputDim, int64_t kernelDim, int64_t stride,
int64_t padBefore, int64_t padAfter,
int64_t dilation) const {
static int64_t getOutputDim(int64_t inputDim, int64_t kernelDim,
int64_t stride, int64_t padBefore,
int64_t padAfter, int64_t dilation) {
if (inputDim == ShapedType::kDynamicSize) {
return ShapedType::kDynamicSize;
} else {
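Only the dynamic-size branch of getOutputDim is visible in this hunk; the static-size branch lies outside it. For reference, a minimal standalone sketch of the conventional pooling output-size arithmetic such a helper computes (a sketch under that assumption, not code copied from the patch):

#include <cstdint>

// Illustrative only: conventional pooling output size with explicit padding
// and dilation; the integer division gives floor semantics, i.e. the
// ceil_mode == false case handled by this patch.
static int64_t poolOutputDim(int64_t inputDim, int64_t kernelDim,
                             int64_t stride, int64_t padBefore,
                             int64_t padAfter, int64_t dilation) {
  return (inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1) /
             stride +
         1;
}

// Example matching the lit test added at the end of this commit:
// poolOutputDim(7, 7, 1, 0, 0, 1) == (7 + 0 + 0 - 6 - 1) / 1 + 1 == 1.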
@@ -2799,78 +2799,130 @@ public:
}
};

template <typename AtenOpT, typename TosaOpT>
class ConvertAtenPoolingOp : public ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT> {
template <typename AtenOpT, typename tosaOp>
static Type getOutputTypeForNonAdaptivePoolingOp(
RankedTensorType inputTy, SmallVectorImpl<int64_t> &kernelSize,
SmallVectorImpl<int64_t> &strideArray, SmallVectorImpl<int64_t> &padArray,
SmallVectorImpl<int64_t> &dilationArray) {
auto inputShape = inputTy.getShape();
auto inputRank = inputTy.getRank();
auto inputElemTy = inputTy.getElementType();

int64_t outputHDim = ConvertAtenPoolingBaseOp<AtenOpT, tosaOp>::getOutputDim(
inputShape[inputRank - 2], kernelSize[0], strideArray[0], padArray[0],
padArray[0], dilationArray[0]);
int64_t outputWDim = ConvertAtenPoolingBaseOp<AtenOpT, tosaOp>::getOutputDim(
inputShape[inputRank - 1], kernelSize[1], strideArray[1], padArray[1],
padArray[1], dilationArray[1]);
SmallVector<int64_t> outputShape;
if (inputRank > 3)
outputShape.push_back(inputShape[0]);
outputShape.push_back(outputHDim);
outputShape.push_back(outputWDim);
outputShape.push_back(inputShape[inputRank - 3]);
return RankedTensorType::get(outputShape, inputElemTy);
}

// Checks the validity of pooling parameters and stores them in the respective
// vector. Also, gets the output type for the pooling op.
template <typename AtenOpT, typename tosaOp>
static LogicalResult getOutputTypeAndPoolingParameters(
AtenOpT op, ConversionPatternRewriter &rewriter, Value inputXchw,
SmallVectorImpl<int64_t> &dilationArray, Type &outputTy, ArrayAttr &kernel,
ArrayAttr &stride, ArrayAttr &pad) {

RankedTensorType inputTy = inputXchw.getType().cast<RankedTensorType>();
if (!inputTy)
return op.emitError("Pooling op requires ranked tensor input");

auto inputRank = inputTy.getRank();
// Rank sanity check.
if (inputTy.getRank() != 4 && inputRank != 3)
return op.emitError("NCHW->NHWC transpose requires 3D or 4D tensor");

SmallVector<int64_t, 2> kernelSizeInts, strideInts, paddingInts;
if (!matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSizeInts)))
return rewriter.notifyMatchFailure(
op, "Non-const kernel_size for pooling op unsupported");
if (!matchPattern(op.stride(), m_TorchConstantIntList(strideInts)))
return rewriter.notifyMatchFailure(
op, "Non-const stride for pooling op unsupported");
if (!matchPattern(op.padding(), m_TorchConstantIntList(paddingInts)))
return rewriter.notifyMatchFailure(
op, "Non-const padding factor for pooling op unsupported");

kernel = rewriter.getI64ArrayAttr(kernelSizeInts);
stride = rewriter.getI64ArrayAttr(strideInts);
pad = rewriter.getI64ArrayAttr(
{paddingInts[0], paddingInts[0], paddingInts[1], paddingInts[1]});

// FIXME: add ceil_mode support.
bool ceilMode;
if (!matchPattern(op.ceil_mode(), m_TorchConstantBool(&ceilMode)))
return rewriter.notifyMatchFailure(
op, "only support constant bool ceil_mode for pooling op");
if (ceilMode)
return rewriter.notifyMatchFailure(
op, "only support ceil_mode equals to False for pooling op");

outputTy = getOutputTypeForNonAdaptivePoolingOp<AtenOpT, tosaOp>(
inputTy, kernelSizeInts, strideInts, paddingInts, dilationArray);

return success();
}

class ConvertAtenMaxPool2dOp
: public ConvertAtenPoolingBaseOp<AtenMaxPool2dOp, tosa::MaxPool2dOp> {
public:
using ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::ConvertAtenPoolingBaseOp;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult processInputs(AtenOpT op, OpAdaptor adaptor,
using ConvertAtenPoolingBaseOp<AtenMaxPool2dOp,
tosa::MaxPool2dOp>::ConvertAtenPoolingBaseOp;
LogicalResult processInputs(AtenMaxPool2dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Value &input,
ArrayAttr &kernel, ArrayAttr &stride,
ArrayAttr &pad, Type &outputTy) const override {

auto inputXchw = adaptor.self();
auto inputTy = inputXchw.getType().template cast<RankedTensorType>();
if (!inputTy)
return op.emitError("Adaptive avgpool requires ranked tensor input");

auto inputShape = inputTy.getShape();
auto inputRank = inputTy.getRank();
auto inputElemTy = inputTy.getElementType();

// Rank sanity check.
if (inputTy.getRank() != 4 && inputRank != 3)
return op.emitError("NCHW->NHWC transpose requires 3D or 4D tensor");

// Transpose to xHWC
input =
ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::transposePoolingInputToHwc(
op, rewriter, inputXchw);

SmallVector<int64_t> kernelSize;
if (!matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSize)))
return rewriter.notifyMatchFailure(
op, "Non-const kernel_size for adaptive pooling unsupported.");
kernel = rewriter.getI64ArrayAttr(kernelSize);

SmallVector<int64_t> strideArray;
if (!matchPattern(op.stride(), m_TorchConstantIntList(strideArray)))
return rewriter.notifyMatchFailure(
op, "Non-const stride for adaptive pooling unsupported.");
stride = rewriter.getI64ArrayAttr(strideArray);

SmallVector<int64_t> padArray;
if (!matchPattern(op.padding(), m_TorchConstantIntList(padArray)))
return rewriter.notifyMatchFailure(
op, "Non-const pad for adaptive pooling unsupported.");
pad = rewriter.getI64ArrayAttr(
{padArray[0], padArray[0], padArray[1], padArray[1]});

SmallVector<int64_t> dilationArray;
SmallVector<int64_t, 2> dilationArray;
if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationArray)))
return rewriter.notifyMatchFailure(
op, "Non-const dilation for adaptive pooling unsupported.");
op, "Non-const dilation for pooling op unsupported.");
// TOSA pooling only supports unit dilation.
if (dilationArray[0] > 1 || dilationArray[1] > 1)
return op.emitError("Cannot process non-unit pooling dilation.");

// FIXME: add ceil_mode support.
if (failed(getOutputTypeAndPoolingParameters<AtenMaxPool2dOp,
tosa::MaxPool2dOp>(
op, rewriter, adaptor.self(), dilationArray, outputTy, kernel,
stride, pad)))
return rewriter.notifyMatchFailure(
op, "invalid pooling parameters or input type");

int64_t outputHDim =
ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::getOutputDim(
inputShape[inputRank - 2], kernelSize[0], strideArray[0],
padArray[0], padArray[0], dilationArray[0]);
int64_t outputWDim =
ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::getOutputDim(
inputShape[inputRank - 1], kernelSize[1], strideArray[1],
padArray[1], padArray[1], dilationArray[1]);
SmallVector<int64_t> outputShape;
if (inputRank > 3)
outputShape.push_back(inputShape[0]);
outputShape.push_back(outputHDim);
outputShape.push_back(outputWDim);
outputShape.push_back(inputShape[inputRank - 3]);
outputTy = RankedTensorType::get(outputShape, inputElemTy);
// Transpose to xHWC
input = ConvertAtenPoolingBaseOp<AtenMaxPool2dOp, tosa::MaxPool2dOp>::
transposePoolingInputToHwc(op, rewriter, adaptor.self());

return success();
}
};

class ConvertAtenAvgPool2dOp
: public ConvertAtenPoolingBaseOp<AtenAvgPool2dOp, tosa::AvgPool2dOp> {
public:
using ConvertAtenPoolingBaseOp<AtenAvgPool2dOp,
tosa::AvgPool2dOp>::ConvertAtenPoolingBaseOp;
LogicalResult processInputs(AtenAvgPool2dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Value &input,
ArrayAttr &kernel, ArrayAttr &stride,
ArrayAttr &pad, Type &outputTy) const override {
SmallVector<int64_t, 2> dilationArray{1, 1};
if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool2dOp,
tosa::AvgPool2dOp>(
op, rewriter, adaptor.self(), dilationArray, outputTy, kernel,
stride, pad)))
return rewriter.notifyMatchFailure(
op, "invalid pooling parameters or input type");

// Transpose to xHWC
input = ConvertAtenPoolingBaseOp<AtenAvgPool2dOp, tosa::AvgPool2dOp>::
transposePoolingInputToHwc(op, rewriter, adaptor.self());

return success();
}
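The net effect of the two new converters is easiest to see on the concrete case covered by the lit test added at the end of this commit: the NCHW input is transposed to NHWC for the TOSA pooling op, the pooled output shape is assembled as N, H, W, C (as in getOutputTypeForNonAdaptivePoolingOp above), and the result is transposed back to NCHW. A minimal, self-contained C++ sketch of that shape bookkeeping (illustrative names only, not code from the patch):

#include <array>
#include <cstdint>
#include <cstdio>

int main() {
  // aten.avg_pool2d sees an NCHW tensor; tosa.avg_pool2d expects NHWC.
  std::array<int64_t, 4> nchwIn{1, 512, 7, 7};
  int64_t kernel = 7, stride = 1, pad = 0, dilation = 1;

  // Conventional output-size arithmetic (ceil_mode == false), per dimension.
  int64_t outH =
      (nchwIn[2] + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1;
  int64_t outW =
      (nchwIn[3] + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1;

  // Shape of the TOSA pooling result (channels last), then the final shape
  // after transposing back to channels first.
  std::array<int64_t, 4> nhwcPooled{nchwIn[0], outH, outW, nchwIn[1]};
  std::array<int64_t, 4> nchwOut{nchwIn[0], nchwIn[1], outH, outW};

  std::printf("pooled NHWC: %lldx%lldx%lldx%lld\n", (long long)nhwcPooled[0],
              (long long)nhwcPooled[1], (long long)nhwcPooled[2],
              (long long)nhwcPooled[3]);
  std::printf("final NCHW:  %lldx%lldx%lldx%lld\n", (long long)nchwOut[0],
              (long long)nchwOut[1], (long long)nchwOut[2],
              (long long)nchwOut[3]);
  return 0;
}

For the 1x512x7x7 input with a 7x7 kernel, stride 1 and no padding, this prints a pooled NHWC shape of 1x1x1x512 and a final NCHW shape of 1x512x1x1, matching the tensor types in the CHECK lines of the test below.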
@@ -3109,12 +3161,6 @@ public:
INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp);
#undef INSERT_LINEAR_ATEMOP_PATTERN

#define INSERT_POOLING_ATENOP_PATTERN(AtenOp, TosaOpT) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenPoolingOp<AtenOp, TosaOpT>>(typeConverter, context);
INSERT_POOLING_ATENOP_PATTERN(AtenMaxPool2dOp, tosa::MaxPool2dOp);
#undef INSERT_POOLING_ATEMOP_PATTERN

#define INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenOp, TosaOpT) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenAdaptivePoolingOp<AtenOp, TosaOpT>>(typeConverter, \
@@ -3123,6 +3169,12 @@ public:
tosa::AvgPool2dOp);
#undef INSERT_ADAPTIVE_POOLING_ATEMOP_PATTERN

target.addIllegalOp<AtenMaxPool2dOp>();
patterns.add<ConvertAtenMaxPool2dOp>(typeConverter, context);

target.addIllegalOp<AtenAvgPool2dOp>();
patterns.add<ConvertAtenAvgPool2dOp>(typeConverter, context);

#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenConstPatternOp<AtenOp, fillVal>>(typeConverter, \
@@ -757,3 +757,40 @@ func.func @torch.aten.dropout$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch
  %0 = torch.aten.dropout %arg0, %float0.000000e00, %false : !torch.vtensor<[?,?],f32>, !torch.float, !torch.bool -> !torch.vtensor<[?,?],f32>
  return %0 : !torch.vtensor<[?,?],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.avg_pool2d$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,512,7,7],f32>) -> !torch.vtensor<[1,512,1,1],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,512,7,7],f32> -> tensor<1x512x7x7xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 7
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
// CHECK: %[[VAL_4:.*]] = torch.constant.int 0
// CHECK: %[[VAL_5:.*]] = torch.constant.bool false
// CHECK: %[[VAL_6:.*]] = torch.constant.bool true
// CHECK: %[[VAL_7:.*]] = torch.constant.none
// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_11:.*]] = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: %[[VAL_12:.*]] = "tosa.transpose"(%[[VAL_1]], %[[VAL_11]]) : (tensor<1x512x7x7xf32>, tensor<4xi32>) -> tensor<1x7x7x512xf32>
// CHECK: %[[VAL_13:.*]] = "tosa.avg_pool2d"(%[[VAL_12]]) {kernel = [7, 7], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x7x7x512xf32>) -> tensor<1x1x1x512xf32>
// CHECK: %[[VAL_14:.*]] = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: %[[VAL_15:.*]] = "tosa.transpose"(%[[VAL_13]], %[[VAL_14]]) : (tensor<1x1x1x512xf32>, tensor<4xi32>) -> tensor<1x512x1x1xf32>
// CHECK: %[[VAL_16:.*]] = tensor.cast %[[VAL_15]] : tensor<1x512x1x1xf32> to tensor<1x512x1x1xf32>
// CHECK: %[[VAL_17:.*]] = torch_c.from_builtin_tensor %[[VAL_16]] : tensor<1x512x1x1xf32> -> !torch.vtensor<[1,512,1,1],f32>
// CHECK: return %[[VAL_17]] : !torch.vtensor<[1,512,1,1],f32>
// CHECK: }
func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32> ) -> !torch.vtensor<[1,512,1,1],f32> {
  %int7 = torch.constant.int 7
  %int1 = torch.constant.int 1
  %int0 = torch.constant.int 0
  %false = torch.constant.bool false
  %true = torch.constant.bool true
  %none = torch.constant.none
  %kernel = torch.prim.ListConstruct %int7, %int7 : (!torch.int, !torch.int) -> !torch.list<int>
  %stride = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
  %padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
  %0 = torch.aten.avg_pool2d %arg0, %kernel, %stride, %padding, %false, %true, %none : !torch.vtensor<[1,512,7,7],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,512,1,1],f32>
  return %0 : !torch.vtensor<[1,512,1,1],f32>
}