diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py
index 482c82038..a31120461 100644
--- a/e2e_testing/torchscript/xfail_sets.py
+++ b/e2e_testing/torchscript/xfail_sets.py
@@ -90,4 +90,5 @@ TOSA_PASS_SET = {
     "FlattenRank0Module_basic",
     "ElementwiseFlattenBroadcastModule_basic",
     "SquareModule_basic",
+    "MaxPool2dStaticModule_basic",
 }
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 1dffa2055..ddf20f6b7 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -17,9 +17,9 @@
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Traits.h"
 #include "mlir/IR/Matchers.h"
-#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
 #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
@@ -1987,6 +1987,277 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   return success();
 }
 
+template <typename AtenOpT, typename TosaOpT>
+class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+
+  // Different pooling variants need to process inputs differently, e.g.
+  // adaptive pooling generates the kernel size rather than receiving it. This
+  // function also transposes inputs.
+  virtual LogicalResult processInputs(AtenOpT op, OpAdaptor adaptor,
+                                      ConversionPatternRewriter &rewriter,
+                                      Value &input, ArrayAttr &kernel,
+                                      ArrayAttr &stride, ArrayAttr &pad,
+                                      Type &outputTy) const {
+    return rewriter.notifyMatchFailure(
+        op, "Unimplemented pooling input parsing function");
+  }
+
+  int64_t getOutputDim(int64_t inputDim, int64_t kernelDim, int64_t stride,
+                       int64_t padBefore, int64_t padAfter,
+                       int64_t dilation) const {
+    if (inputDim == ShapedType::kDynamicSize) {
+      return ShapedType::kDynamicSize;
+    } else {
+      return (
+          (inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1) /
+              stride +
+          1);
+    }
+  }
+
+  // Apply the transposeDims vector on input to generate a transposed form.
+  Value transposeTensor(AtenOpT op, ConversionPatternRewriter &rewriter,
+                        Value input, ArrayRef<int32_t> transposeDims) const {
+    auto inputTy = input.getType().template cast<RankedTensorType>();
+    auto inputElemTy = inputTy.getElementType();
+    auto inputShape = inputTy.getShape();
+    auto inputRank = inputTy.getRank();
+
+    llvm::Optional<Value> transposeDimsConst = tosa::getConstTensor<int32_t>(
+        rewriter, op,
+        /*vec=*/transposeDims,
+        /*shape=*/{static_cast<int32_t>(inputRank)});
+
+    SmallVector<int64_t> transposedInputShape;
+    for (auto &dim : transposeDims)
+      transposedInputShape.push_back(inputShape[dim]);
+    auto transposedInputType =
+        RankedTensorType::get(transposedInputShape, inputElemTy);
+    return rewriter
+        .create<tosa::TransposeOp>(op->getLoc(), transposedInputType, input,
+                                   transposeDimsConst.getValue())
+        .getResult();
+  }
+
+  Value transposePoolingInputToHwc(AtenOpT op,
+                                   ConversionPatternRewriter &rewriter,
+                                   Value input) const {
+    auto inputRank =
+        input.getType().template cast<RankedTensorType>().getRank();
+
+    SmallVector<int32_t> nchwToNhwc4DTransposeDims({0, 2, 3, 1});
+    SmallVector<int32_t> chwToHwc3DTransposeDims({1, 2, 0});
+
+    return transposeTensor(op, rewriter, input,
+                           inputRank == 3 ? chwToHwc3DTransposeDims
+                                          : nchwToNhwc4DTransposeDims);
+  }
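+
+  // For example, a 4D NCHW input of shape [1, 8, 56, 56] is transposed with
+  // dims {0, 2, 3, 1} into an NHWC tensor of shape [1, 56, 56, 8] (shapes
+  // illustrative only). TOSA pooling ops consume and produce NHWC (or HWC)
+  // tensors, which is why both an input and an output transpose are needed.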
+
+  Value transposePoolingOutputToChw(AtenOpT op,
+                                    ConversionPatternRewriter &rewriter,
+                                    Value input) const {
+    auto inputTy = input.getType().template cast<RankedTensorType>();
+    auto inputRank = inputTy.getRank();
+
+    SmallVector<int32_t> nhwcToNchw4DTransposeDims({0, 3, 1, 2});
+    SmallVector<int32_t> hwcToChw3DTransposeDims({2, 0, 1});
+
+    return transposeTensor(op, rewriter, input,
+                           inputRank == 3 ? hwcToChw3DTransposeDims
+                                          : nhwcToNchw4DTransposeDims);
+  }
+
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value input;
+    ArrayAttr kernel, stride, pad;
+    Type outputTy;
+
+    // Attempts to read input and kernel parameters, or synthesize them in the
+    // case of adaptive pooling. Also performs the input CHW->HWC transpose.
+    if (failed(processInputs(op, adaptor, rewriter, input, kernel, stride, pad,
+                             outputTy)))
+      return op.emitError("Failed to process inputs for pooling");
+
+    auto pooledOutput =
+        rewriter
+            .create<TosaOpT>(op->getLoc(), outputTy, input, kernel, stride,
+                             pad)
+            .getResult();
+
+    auto transposedOutput =
+        ConvertAtenPoolingBaseOp<AtenOpT,
+                                 TosaOpT>::transposePoolingOutputToChw(
+            op, rewriter, pooledOutput);
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        op,
+        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+            op.getType()),
+        transposedOutput);
+
+    return success();
+  }
+};
+
+template <typename AtenOpT, typename TosaOpT>
+class ConvertAtenAdaptivePoolingOp
+    : public ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT> {
+public:
+  using ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::ConvertAtenPoolingBaseOp;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult processInputs(AtenOpT op, OpAdaptor adaptor,
+                              ConversionPatternRewriter &rewriter, Value &input,
+                              ArrayAttr &kernel, ArrayAttr &stride,
+                              ArrayAttr &pad, Type &outputTy) const override {
+    auto inputXchw = adaptor.self();
+    auto inputTy = inputXchw.getType().template dyn_cast<RankedTensorType>();
+    if (!inputTy)
+      return op.emitError("Adaptive avgpool requires ranked tensor input");
+
+    auto inputShape = inputTy.getShape();
+    auto inputRank = inputTy.getRank();
+    auto inputElemTy = inputTy.getElementType();
+
+    // Rank sanity check.
+    if (inputRank != 4 && inputRank != 3)
+      return op.emitError("NCHW->NHWC transpose requires 3D or 4D tensor");
+
+    int64_t inputHDim = inputShape[inputRank - 2];
+    int64_t inputWDim = inputShape[inputRank - 1];
+
+    SmallVector<int64_t> outputSize;
+    if (!matchPattern(op.output_size(), m_TorchConstantIntList(outputSize)))
+      return rewriter.notifyMatchFailure(
+          op, "Non-const output_size for adaptive pooling unsupported.");
+
+    SmallVector<int64_t> kernelDims;
+    int64_t outputHDim, outputWDim;
+    if (outputSize.size() == 1) {
+      outputHDim = outputWDim = outputSize[0];
+    } else {
+      if (outputSize.size() != 2)
+        return op.emitError(
+            "Adaptive avgpool output_size not 1 or 2 elements.");
+
+      // Assumes 'None' (e.g. output_size=(None, 5)) is expressed as <= 0.
+      outputHDim = (outputSize[0] <= 0) ? inputHDim : outputSize[0];
+      outputWDim = (outputSize[1] <= 0) ? inputWDim : outputSize[1];
+    }
+
+    // In adaptive pooling,
+    //   stride = inputDim // outputDim
+    //   kernel = inputDim - (outputDim - 1) * stride
+    //   pad = 0, dilation = 1
+    int64_t strideH = inputHDim / outputHDim;
+    int64_t strideW = inputWDim / outputWDim;
+
+    kernelDims.push_back(inputHDim - (outputHDim - 1) * strideH);
+    kernelDims.push_back(inputWDim - (outputWDim - 1) * strideW);
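+    // E.g. pooling a 7x7 input down to 3x3 gives stride = 7 / 3 = 2 and
+    // kernel = 7 - (3 - 1) * 2 = 3, which indeed yields
+    // (7 - 1 * (3 - 1) - 1) / 2 + 1 = 3 output elements per dimension. Note
+    // this derivation matches PyTorch's adaptive pooling exactly only when
+    // inputDim is evenly divisible by outputDim.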
+
+    SmallVector<int64_t> outputShape;
+    if (inputRank > 3)
+      outputShape.push_back(inputShape[0]);
+    outputShape.push_back(outputHDim);
+    outputShape.push_back(outputWDim);
+    outputShape.push_back(inputShape[inputRank - 3]);
+
+    // Transpose to xHWC
+    input =
+        ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::transposePoolingInputToHwc(
+            op, rewriter, inputXchw);
+    kernel = rewriter.getI64ArrayAttr(kernelDims);
+    stride = rewriter.getI64ArrayAttr({strideH, strideW});
+    // Adaptive pooling does unit dilation and zero pad.
+    pad = rewriter.getI64ArrayAttr({0, 0, 0, 0});
+    outputTy = RankedTensorType::get(outputShape, inputElemTy);
+
+    return success();
+  }
+};
+
+template <typename AtenOpT, typename TosaOpT>
+class ConvertAtenPoolingOp
+    : public ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT> {
+public:
+  using ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::ConvertAtenPoolingBaseOp;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult processInputs(AtenOpT op, OpAdaptor adaptor,
+                              ConversionPatternRewriter &rewriter, Value &input,
+                              ArrayAttr &kernel, ArrayAttr &stride,
+                              ArrayAttr &pad, Type &outputTy) const override {
+    auto inputXchw = adaptor.self();
+    auto inputTy = inputXchw.getType().template dyn_cast<RankedTensorType>();
+    if (!inputTy)
+      return op.emitError("Pooling op requires ranked tensor input");
+
+    auto inputShape = inputTy.getShape();
+    auto inputRank = inputTy.getRank();
+    auto inputElemTy = inputTy.getElementType();
+
+    // Rank sanity check.
+    if (inputRank != 4 && inputRank != 3)
+      return op.emitError("NCHW->NHWC transpose requires 3D or 4D tensor");
+
+    // Transpose to xHWC
+    input =
+        ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::transposePoolingInputToHwc(
+            op, rewriter, inputXchw);
+
+    SmallVector<int64_t> kernelSize;
+    if (!matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSize)))
+      return rewriter.notifyMatchFailure(
+          op, "Non-const kernel_size for pooling op unsupported.");
+    kernel = rewriter.getI64ArrayAttr(kernelSize);
+
+    SmallVector<int64_t> strideArray;
+    if (!matchPattern(op.stride(), m_TorchConstantIntList(strideArray)))
+      return rewriter.notifyMatchFailure(
+          op, "Non-const stride for pooling op unsupported.");
+    stride = rewriter.getI64ArrayAttr(strideArray);
+
+    SmallVector<int64_t> padArray;
+    if (!matchPattern(op.padding(), m_TorchConstantIntList(padArray)))
+      return rewriter.notifyMatchFailure(
+          op, "Non-const pad for pooling op unsupported.");
+    pad = rewriter.getI64ArrayAttr(
+        {padArray[0], padArray[0], padArray[1], padArray[1]});
+
+    SmallVector<int64_t> dilationArray;
+    if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationArray)))
+      return rewriter.notifyMatchFailure(
+          op, "Non-const dilation for pooling op unsupported.");
+    // TOSA pooling only supports unit dilation.
+    if (dilationArray[0] > 1 || dilationArray[1] > 1)
+      return op.emitError("Cannot process non-unit pooling dilation.");
+
+    // FIXME: add ceil_mode support.
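+
+    // getOutputDim computes the standard floor-mode pooling output size,
+    //   (in + padBefore + padAfter - dilation * (kernel - 1) - 1) / stride + 1;
+    // e.g. in = 112, kernel = 3, stride = 2, pad = 1 yields
+    // (112 + 1 + 1 - 2 - 1) / 2 + 1 = 56.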
+
+    int64_t outputHDim =
+        ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::getOutputDim(
+            inputShape[inputRank - 2], kernelSize[0], strideArray[0],
+            padArray[0], padArray[0], dilationArray[0]);
+    int64_t outputWDim =
+        ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::getOutputDim(
+            inputShape[inputRank - 1], kernelSize[1], strideArray[1],
+            padArray[1], padArray[1], dilationArray[1]);
+    SmallVector<int64_t> outputShape;
+    if (inputRank > 3)
+      outputShape.push_back(inputShape[0]);
+    outputShape.push_back(outputHDim);
+    outputShape.push_back(outputWDim);
+    outputShape.push_back(inputShape[inputRank - 3]);
+    outputTy = RankedTensorType::get(outputShape, inputElemTy);
+
+    return success();
+  }
+};
+
 } // namespace
 
 // -----------------------------------------------------------------------------
@@ -2131,6 +2402,20 @@ public:
     INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp);
 #undef INSERT_LINEAR_ATEMOP_PATTERN
 
+#define INSERT_POOLING_ATENOP_PATTERN(AtenOp, TosaOpT)                        \
+  target.addIllegalOp<AtenOp>();                                              \
+  patterns.add<ConvertAtenPoolingOp<AtenOp, TosaOpT>>(typeConverter, context);
+    INSERT_POOLING_ATENOP_PATTERN(AtenMaxPool2dOp, tosa::MaxPool2dOp);
+#undef INSERT_POOLING_ATENOP_PATTERN
+
+#define INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenOp, TosaOpT)               \
+  target.addIllegalOp<AtenOp>();                                              \
+  patterns.add<ConvertAtenAdaptivePoolingOp<AtenOp, TosaOpT>>(typeConverter,  \
+                                                              context);
+    INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenAdaptiveAvgPool2dOp,
+                                           tosa::AvgPool2dOp);
+#undef INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN
+
 #define INSERT_ATENOP_PATTERN(AtenOp)                                         \
   target.addIllegalOp<AtenOp>();                                              \
   patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);