diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py
index ae19437da..111fd1613 100644
--- a/e2e_testing/torchscript/xfail_sets.py
+++ b/e2e_testing/torchscript/xfail_sets.py
@@ -152,4 +152,5 @@ TOSA_PASS_SET = {
     "ConvolutionModule2DStatic_basic",
     "ElementwiseNegModule_basic",
     "TestMultipleTensorReturn_basic",
+    "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
 }
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 65fddcca4..75d8f6da8 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -2218,9 +2218,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
     newShape.push_back(1);
   auto newType = RankedTensorType::get(newShape, selfType.getElementType());
 
-  auto reshapeOp =
-      rewriter.create<tosa::ReshapeOp>(op.getLoc(), newType, adaptor.self(),
-                                       rewriter.getI64ArrayAttr(newShape));
+  auto reshapeOp = rewriter.create<tosa::ReshapeOp>(
+      op.getLoc(), newType, adaptor.self(), rewriter.getI64ArrayAttr(newShape));
 
   rewriter.replaceOpWithNewOp<tensor::CastOp>(
       op, getTypeConverter()->convertType(op.getType()), reshapeOp);
@@ -2507,7 +2506,8 @@ static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter,
   auto mean = zero;
   Value xMinusMean = rewriter.create<tosa::SubOp>(loc, outType, x, mean);
   // rsqrt of 2
-  Value rsqrt2 = tosa::getConstTensor<float>(rewriter, op, 0.70710678, {}).getValue();
+  Value rsqrt2 =
+      tosa::getConstTensor<float>(rewriter, op, 0.70710678, {}).getValue();
   Value erfArg = rewriter.create<tosa::MulOp>(loc, outType, xMinusMean, rsqrt2,
                                               /*shift=*/0);
   Value erf = approximateErfOp(rewriter, op, erfArg);
@@ -2623,9 +2623,9 @@ public:
         op, "Unimplemented pooling input parsing function");
   }
 
-  int64_t getOutputDim(int64_t inputDim, int64_t kernelDim, int64_t stride,
-                       int64_t padBefore, int64_t padAfter,
-                       int64_t dilation) const {
+  static int64_t getOutputDim(int64_t inputDim, int64_t kernelDim,
+                              int64_t stride, int64_t padBefore,
+                              int64_t padAfter, int64_t dilation) {
     if (inputDim == ShapedType::kDynamicSize) {
       return ShapedType::kDynamicSize;
     } else {
@@ -2799,78 +2799,130 @@ public:
   }
 };
 
-template <typename AtenOpT, typename TosaOpT>
-class ConvertAtenPoolingOp : public ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT> {
+template <typename AtenOpT, typename tosaOp>
+static Type getOutputTypeForNonAdaptivePoolingOp(
+    RankedTensorType inputTy, SmallVectorImpl<int64_t> &kernelSize,
+    SmallVectorImpl<int64_t> &strideArray, SmallVectorImpl<int64_t> &padArray,
+    SmallVectorImpl<int64_t> &dilationArray) {
+  auto inputShape = inputTy.getShape();
+  auto inputRank = inputTy.getRank();
+  auto inputElemTy = inputTy.getElementType();
+
+  int64_t outputHDim = ConvertAtenPoolingBaseOp<AtenOpT, tosaOp>::getOutputDim(
+      inputShape[inputRank - 2], kernelSize[0], strideArray[0], padArray[0],
+      padArray[0], dilationArray[0]);
+  int64_t outputWDim = ConvertAtenPoolingBaseOp<AtenOpT, tosaOp>::getOutputDim(
+      inputShape[inputRank - 1], kernelSize[1], strideArray[1], padArray[1],
+      padArray[1], dilationArray[1]);
+  SmallVector<int64_t> outputShape;
+  if (inputRank > 3)
+    outputShape.push_back(inputShape[0]);
+  outputShape.push_back(outputHDim);
+  outputShape.push_back(outputWDim);
+  outputShape.push_back(inputShape[inputRank - 3]);
+  return RankedTensorType::get(outputShape, inputElemTy);
+}
+
+// Checks the validity of pooling parameters and stores them in the respective
+// vector. Also, gets the output type for the pooling op.
+template <typename AtenOpT, typename tosaOp>
+static LogicalResult getOutputTypeAndPoolingParameters(
+    AtenOpT op, ConversionPatternRewriter &rewriter, Value inputXchw,
+    SmallVectorImpl<int64_t> &dilationArray, Type &outputTy, ArrayAttr &kernel,
+    ArrayAttr &stride, ArrayAttr &pad) {
+
+  RankedTensorType inputTy = inputXchw.getType().cast<RankedTensorType>();
+  if (!inputTy)
+    return op.emitError("Pooling op requires ranked tensor input");
+
+  auto inputRank = inputTy.getRank();
+  // Rank sanity check.
+  if (inputTy.getRank() != 4 && inputRank != 3)
+    return op.emitError("NCHW->NHWC transpose requires 3D or 4D tensor");
+
+  SmallVector<int64_t, 2> kernelSizeInts, strideInts, paddingInts;
+  if (!matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSizeInts)))
+    return rewriter.notifyMatchFailure(
+        op, "Non-const kernel_size for pooling op unsupported");
+  if (!matchPattern(op.stride(), m_TorchConstantIntList(strideInts)))
+    return rewriter.notifyMatchFailure(
+        op, "Non-const stride for pooling op unsupported");
+  if (!matchPattern(op.padding(), m_TorchConstantIntList(paddingInts)))
+    return rewriter.notifyMatchFailure(
+        op, "Non-const padding factor for pooling op unsupported");
+
+  kernel = rewriter.getI64ArrayAttr(kernelSizeInts);
+  stride = rewriter.getI64ArrayAttr(strideInts);
+  pad = rewriter.getI64ArrayAttr(
+      {paddingInts[0], paddingInts[0], paddingInts[1], paddingInts[1]});
+
+  // FIXME: add ceil_mode support.
+  bool ceilMode;
+  if (!matchPattern(op.ceil_mode(), m_TorchConstantBool(&ceilMode)))
+    return rewriter.notifyMatchFailure(
+        op, "only support constant bool ceil_mode for pooling op");
+  if (ceilMode)
+    return rewriter.notifyMatchFailure(
+        op, "only support ceil_mode equals to False for pooling op");
+
+  outputTy = getOutputTypeForNonAdaptivePoolingOp<AtenOpT, tosaOp>(
+      inputTy, kernelSizeInts, strideInts, paddingInts, dilationArray);
+
+  return success();
+}
+
+class ConvertAtenMaxPool2dOp
+    : public ConvertAtenPoolingBaseOp<AtenMaxPool2dOp, tosa::MaxPool2dOp> {
 public:
-  using ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::ConvertAtenPoolingBaseOp;
-  using OpAdaptor = typename AtenOpT::Adaptor;
-  LogicalResult processInputs(AtenOpT op, OpAdaptor adaptor,
+  using ConvertAtenPoolingBaseOp<AtenMaxPool2dOp,
+                                 tosa::MaxPool2dOp>::ConvertAtenPoolingBaseOp;
+  LogicalResult processInputs(AtenMaxPool2dOp op, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter, Value &input,
                               ArrayAttr &kernel, ArrayAttr &stride,
                               ArrayAttr &pad, Type &outputTy) const override {
-
-    auto inputXchw = adaptor.self();
-    auto inputTy = inputXchw.getType().template cast<RankedTensorType>();
-    if (!inputTy)
-      return op.emitError("Adaptive avgpool requires ranked tensor input");
-
-    auto inputShape = inputTy.getShape();
-    auto inputRank = inputTy.getRank();
-    auto inputElemTy = inputTy.getElementType();
-
-    // Rank sanity check.
-    if (inputTy.getRank() != 4 && inputRank != 3)
-      return op.emitError("NCHW->NHWC transpose requires 3D or 4D tensor");
-
-    // Transpose to xHWC
-    input =
-        ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::transposePoolingInputToHwc(
-            op, rewriter, inputXchw);
-
-    SmallVector<int64_t> kernelSize;
-    if (!matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSize)))
-      return rewriter.notifyMatchFailure(
-          op, "Non-const kernel_size for adaptive pooling unsupported.");
-    kernel = rewriter.getI64ArrayAttr(kernelSize);
-
-    SmallVector<int64_t> strideArray;
-    if (!matchPattern(op.stride(), m_TorchConstantIntList(strideArray)))
-      return rewriter.notifyMatchFailure(
-          op, "Non-const stride for adaptive pooling unsupported.");
-    stride = rewriter.getI64ArrayAttr(strideArray);
-
-    SmallVector<int64_t> padArray;
-    if (!matchPattern(op.padding(), m_TorchConstantIntList(padArray)))
-      return rewriter.notifyMatchFailure(
-          op, "Non-const pad for adaptive pooling unsupported.");
-    pad = rewriter.getI64ArrayAttr(
-        {padArray[0], padArray[0], padArray[1], padArray[1]});
-
-    SmallVector<int64_t> dilationArray;
+    SmallVector<int64_t, 2> dilationArray;
     if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationArray)))
       return rewriter.notifyMatchFailure(
-          op, "Non-const dilation for adaptive pooling unsupported.");
+          op, "Non-const dilation for pooling op unsupported.");
     // TOSA pooling only supports unit dilation.
     if (dilationArray[0] > 1 || dilationArray[1] > 1)
       return op.emitError("Cannot process non-unit pooling dilation.");
 
-    // FIXME: add ceil_mode support.
+    if (failed(getOutputTypeAndPoolingParameters<AtenMaxPool2dOp,
+                                                 tosa::MaxPool2dOp>(
+            op, rewriter, adaptor.self(), dilationArray, outputTy, kernel,
+            stride, pad)))
+      return rewriter.notifyMatchFailure(
+          op, "invalid pooling parameters or input type");
 
-    int64_t outputHDim =
-        ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::getOutputDim(
-            inputShape[inputRank - 2], kernelSize[0], strideArray[0],
-            padArray[0], padArray[0], dilationArray[0]);
-    int64_t outputWDim =
-        ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::getOutputDim(
-            inputShape[inputRank - 1], kernelSize[1], strideArray[1],
-            padArray[1], padArray[1], dilationArray[1]);
-    SmallVector<int64_t> outputShape;
-    if (inputRank > 3)
-      outputShape.push_back(inputShape[0]);
-    outputShape.push_back(outputHDim);
-    outputShape.push_back(outputWDim);
-    outputShape.push_back(inputShape[inputRank - 3]);
-    outputTy = RankedTensorType::get(outputShape, inputElemTy);
+    // Transpose to xHWC
+    input = ConvertAtenPoolingBaseOp<AtenMaxPool2dOp, tosa::MaxPool2dOp>::
+        transposePoolingInputToHwc(op, rewriter, adaptor.self());
+
+    return success();
+  }
+};
+
+class ConvertAtenAvgPool2dOp
+    : public ConvertAtenPoolingBaseOp<AtenAvgPool2dOp, tosa::AvgPool2dOp> {
+public:
+  using ConvertAtenPoolingBaseOp<AtenAvgPool2dOp,
+                                 tosa::AvgPool2dOp>::ConvertAtenPoolingBaseOp;
+  LogicalResult processInputs(AtenAvgPool2dOp op, OpAdaptor adaptor,
+                              ConversionPatternRewriter &rewriter, Value &input,
+                              ArrayAttr &kernel, ArrayAttr &stride,
+                              ArrayAttr &pad, Type &outputTy) const override {
+    SmallVector<int64_t, 2> dilationArray{1, 1};
+    if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool2dOp,
+                                                 tosa::AvgPool2dOp>(
+            op, rewriter, adaptor.self(), dilationArray, outputTy, kernel,
+            stride, pad)))
+      return rewriter.notifyMatchFailure(
+          op, "invalid pooling parameters or input type");
+
+    // Transpose to xHWC
+    input = ConvertAtenPoolingBaseOp<AtenAvgPool2dOp, tosa::AvgPool2dOp>::
+        transposePoolingInputToHwc(op, rewriter, adaptor.self());
 
     return success();
   }
@@ -3109,12 +3161,6 @@ public:
     INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp);
 #undef INSERT_LINEAR_ATEMOP_PATTERN
 
-#define INSERT_POOLING_ATENOP_PATTERN(AtenOp, TosaOpT)                        \
-  target.addIllegalOp<AtenOp>();                                              \
-  patterns.add<ConvertAtenPoolingOp<AtenOp, TosaOpT>>(typeConverter, context);
-    INSERT_POOLING_ATENOP_PATTERN(AtenMaxPool2dOp, tosa::MaxPool2dOp);
-#undef INSERT_POOLING_ATEMOP_PATTERN
-
 #define INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenOp, TosaOpT)               \
   target.addIllegalOp<AtenOp>();                                              \
   patterns.add<ConvertAtenAdaptivePoolingOp<AtenOp, TosaOpT>>(typeConverter,  \
@@ -3123,6 +3169,12 @@ public:
                                            tosa::AvgPool2dOp);
 #undef INSERT_ADAPTIVE_POOLING_ATEMOP_PATTERN
 
+    target.addIllegalOp<AtenMaxPool2dOp>();
+    patterns.add<ConvertAtenMaxPool2dOp>(typeConverter, context);
+
+    target.addIllegalOp<AtenAvgPool2dOp>();
+    patterns.add<ConvertAtenAvgPool2dOp>(typeConverter, context);
+
 #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal)                         \
   target.addIllegalOp<AtenOp>();                                              \
   patterns.add<ConvertAtenConstPatternOp<AtenOp, fillVal>>(typeConverter,     \
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index 7b4d0b7de..8534c694e 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -757,3 +757,40 @@ func.func @torch.aten.dropout$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch
   %0 = torch.aten.dropout %arg0, %float0.000000e00, %false : !torch.vtensor<[?,?],f32>, !torch.float, !torch.bool -> !torch.vtensor<[?,?],f32>
   return %0 : !torch.vtensor<[?,?],f32>
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten.avg_pool2d$basic(
+// CHECK-SAME:                                           %[[VAL_0:.*]]: !torch.vtensor<[1,512,7,7],f32>) -> !torch.vtensor<[1,512,1,1],f32> {
+// CHECK:           %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,512,7,7],f32> -> tensor<1x512x7x7xf32>
+// CHECK:           %[[VAL_2:.*]] = torch.constant.int 7
+// CHECK:           %[[VAL_3:.*]] = torch.constant.int 1
+// CHECK:           %[[VAL_4:.*]] = torch.constant.int 0
+// CHECK:           %[[VAL_5:.*]] = torch.constant.bool false
+// CHECK:           %[[VAL_6:.*]] = torch.constant.bool true
+// CHECK:           %[[VAL_7:.*]] = torch.constant.none
+// CHECK:           %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_2]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK:           %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK:           %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK:           %[[VAL_11:.*]] = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+// CHECK:           %[[VAL_12:.*]] = "tosa.transpose"(%[[VAL_1]], %[[VAL_11]]) : (tensor<1x512x7x7xf32>, tensor<4xi32>) -> tensor<1x7x7x512xf32>
+// CHECK:           %[[VAL_13:.*]] = "tosa.avg_pool2d"(%[[VAL_12]]) {kernel = [7, 7], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x7x7x512xf32>) -> tensor<1x1x1x512xf32>
+// CHECK:           %[[VAL_14:.*]] = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+// CHECK:           %[[VAL_15:.*]] = "tosa.transpose"(%[[VAL_13]], %[[VAL_14]]) : (tensor<1x1x1x512xf32>, tensor<4xi32>) -> tensor<1x512x1x1xf32>
+// CHECK:           %[[VAL_16:.*]] = tensor.cast %[[VAL_15]] : tensor<1x512x1x1xf32> to tensor<1x512x1x1xf32>
+// CHECK:           %[[VAL_17:.*]] = torch_c.from_builtin_tensor %[[VAL_16]] : tensor<1x512x1x1xf32> -> !torch.vtensor<[1,512,1,1],f32>
+// CHECK:           return %[[VAL_17]] : !torch.vtensor<[1,512,1,1],f32>
+// CHECK:         }
+func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32> ) -> !torch.vtensor<[1,512,1,1],f32> {
+  %int7 = torch.constant.int 7
+  %int1 = torch.constant.int 1
+  %int0 = torch.constant.int 0
+  %false = torch.constant.bool false
+  %true = torch.constant.bool true
+  %none = torch.constant.none
+  %kernel = torch.prim.ListConstruct %int7, %int7 : (!torch.int, !torch.int) -> !torch.list<int>
+  %stride = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %0 = torch.aten.avg_pool2d %arg0, %kernel, %stride, %padding, %false, %true, %none : !torch.vtensor<[1,512,7,7],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,512,1,1],f32>
+  return %0 : !torch.vtensor<[1,512,1,1],f32>
+}