diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp
index 90b5b2af7..73f0984e5 100644
--- a/lib/Conversion/TorchToLinalg/Pooling.cpp
+++ b/lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/Matchers.h"
 #include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
@@ -208,8 +209,6 @@ static LogicalResult createPoolingOp(
   return success();
 }
 
-namespace {
-
 template <typename T> struct DimensionTraits {};
 
 template <> struct DimensionTraits<AtenMaxPool1dOp> {
@@ -238,247 +237,156 @@ template <>
 struct DimensionTraits<AtenMaxPool3dWithIndicesOp>
     : DimensionTraits<AtenMaxPool3dOp> {};
 
+template <typename OpTy>
+LogicalResult createCustomMaxPoolingOp(
+    OpTy &op, typename OpTy::Adaptor adaptor,
+    ConversionPatternRewriter &rewriter, const TypeConverter *typeConverter,
+    SmallVectorImpl<Value> &kernelSizeIntValues,
+    SmallVectorImpl<int64_t> &strideInts,
+    SmallVectorImpl<int64_t> &paddingInts,
+    SmallVectorImpl<int64_t> &dilationInts, bool ceilMode,
+    SmallVectorImpl<Value> &outTensorShape, Value &paddedInput,
+    ValueRange &results,
+    std::function<Value(OpBuilder &, Location, ValueRange)>
+        &&indicesComputation = nullptr) {
+  constexpr bool withIndices =
+      llvm::is_one_of<OpTy, AtenMaxPool2dWithIndicesOp,
+                      AtenMaxPool3dWithIndicesOp>::value;
+  constexpr int64_t Dim = DimensionTraits<OpTy>::Dim;
+
+  if (withIndices && !indicesComputation) {
+    return op->emitError("must provide an indices computation functor when "
+                         "lowering a maxpool-with-indices op");
+  }
+
+  Value self = adaptor.getSelf();
+  Type elementType = cast<RankedTensorType>(self.getType()).getElementType();
+  TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
+      elementType,
+      APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
+                      /*Negative=*/true));
+
+  Value initValue =
+      rewriter.create<arith::ConstantOp>(op->getLoc(), smallestFPValueAttr);
+
+  paddedInput = padInputTensor(op, rewriter, self, ceilMode, Dim, strideInts,
+                               paddingInts, initValue);
+
+  auto maxOutputInitialized = computeOutputTensor(
+      op, rewriter, self, Dim, ceilMode, strideInts, paddingInts, dilationInts,
+      kernelSizeIntValues, outTensorShape, initValue);
+
+  auto shape =
+      castIntVectorToIndexVector(rewriter, op->getLoc(), kernelSizeIntValues);
+  Value windowTensor = rewriter.create<tensor::EmptyOp>(
+      op->getLoc(), getAsOpFoldResult(shape), elementType);
+
+  MLIRContext *context = rewriter.getContext();
+
+  // The generic op iterates over 2 + 2 * Dim dimensions: batch and channel
+  // (dims 0 and 1), the output spatial dims (2 .. Dim + 1), and the kernel
+  // window dims (Dim + 2 .. 2 * Dim + 1).
+  SmallVector<AffineExpr> inputIndexing(
+      {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)});
+  SmallVector<AffineExpr> maxOutputIndexing = inputIndexing;
+  SmallVector<AffineExpr> kernelIndexing;
+  for (int i = 0; i < Dim; i++) {
+    mlir::AffineExpr poolingDim = rewriter.getAffineDimExpr(i + 2);
+    mlir::AffineExpr kernelDim = rewriter.getAffineDimExpr(i + 2 + Dim);
+    inputIndexing.push_back(
+        poolingDim * getAffineConstantExpr(strideInts[i], context) +
+        kernelDim * getAffineConstantExpr(dilationInts[i], context));
+    maxOutputIndexing.push_back(poolingDim);
+    kernelIndexing.push_back(kernelDim);
+  }
+
+  auto iteratorTypes =
+      SmallVector<utils::IteratorType>(2 + Dim, utils::IteratorType::parallel);
+  iteratorTypes.append(Dim, utils::IteratorType::reduction);
+  SmallVector<AffineMap> indexingMaps = {
+      mlir::AffineMap::get(2 + Dim * 2, 0, inputIndexing, context),
+      mlir::AffineMap::get(2 + Dim * 2, 0, kernelIndexing, context),
+      mlir::AffineMap::get(2 + Dim * 2, 0, maxOutputIndexing, context)};
+  SmallVector<Value> outputs({maxOutputInitialized});
+  SmallVector<Type> outTypes({maxOutputInitialized.getType()});
+
+  if constexpr (withIndices) {
+    // The indices tensor has the same indexing map and shape as the max
+    // value tensor.
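+    // It is initialized to -1 below; the reduction region relies on that
+    // sentinel to detect the first visit to each output element.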
+    indexingMaps.push_back(
+        mlir::AffineMap::get(2 + Dim * 2, 0, maxOutputIndexing, context));
+    RankedTensorType indicesRankedTensorType = cast<RankedTensorType>(
+        typeConverter->convertType(op->getResult(1).getType()));
+    Value cstMinusOne = rewriter.create<arith::ConstantOp>(
+        op->getLoc(), rewriter.getI64IntegerAttr(-1));
+    Value indicesOutputInitialized =
+        createInitTensor(rewriter, op->getLoc(), outTensorShape,
+                         indicesRankedTensorType.getElementType(), cstMinusOne);
+    outputs.push_back(indicesOutputInitialized);
+    outTypes.push_back(indicesOutputInitialized.getType());
+  }
+
+  results =
+      rewriter
+          .create<linalg::GenericOp>(
+              op->getLoc(),
+              /*resultTensorTypes=*/outTypes,
+              /*inputs=*/ValueRange({paddedInput, windowTensor}),
+              /*outputs=*/outputs,
+              /*indexingMaps=*/indexingMaps,
+              /*iteratorTypes=*/iteratorTypes,
+              [&](OpBuilder &b, Location loc, ValueRange args) {
+                Value currentVal = args[0], accMaxValue = args[2];
+                if constexpr (withIndices) {
+                  Value curIndex = args[3];
+                  SmallVector<Value> iterators;
+                  for (int i = 0; i < Dim * 2; i++) {
+                    iterators.push_back(b.create<linalg::IndexOp>(loc, i + 2));
+                  }
+                  Value pred = b.create<arith::CmpFOp>(
+                      loc, arith::CmpFPredicate::UGT, currentVal, accMaxValue);
+                  // Corner case: the max pooling result may equal the padding
+                  // value, which is -inf. In that case the first index of the
+                  // pooling window should be returned rather than -1, so also
+                  // take the "then" branch while the index is still unset.
+                  pred = b.create<arith::OrIOp>(
+                      loc, pred,
+                      b.create<arith::CmpIOp>(
+                          loc, arith::CmpIPredicate::eq, curIndex,
+                          b.create<arith::ConstantOp>(
+                              loc, b.getI64IntegerAttr(-1))));
+                  ValueRange outResults =
+                      b.create<scf::IfOp>(
+                           loc, pred,
+                           [&](OpBuilder &b, Location loc) {
+                             SmallVector<Value> curResults{
+                                 currentVal,
+                                 indicesComputation(b, loc, iterators)};
+                             b.create<scf::YieldOp>(loc, curResults);
+                           },
+                           [&](OpBuilder &b, Location loc) {
+                             b.create<scf::YieldOp>(loc,
+                                                    args.drop_front(/*n=*/2));
+                           })
+                          ->getResults();
+                  b.create<linalg::YieldOp>(loc, outResults);
+                } else {
+                  Value maxResult = b.create<arith::MaximumFOp>(
+                      loc, currentVal, accMaxValue);
+                  b.create<linalg::YieldOp>(loc, maxResult);
+                }
+              })
+          ->getResults();
+
+  return success();
+}
+
+namespace {
 template <typename OpTy>
 class ConvertAtenMaxPoolOp : public OpConversionPattern<OpTy> {
   using OpConversionPattern<OpTy>::OpConversionPattern;
 
-  static const bool withIndices =
-      llvm::is_one_of<OpTy, AtenMaxPool2dWithIndicesOp,
-                      AtenMaxPool3dWithIndicesOp>::value;
-
 private:
   static const int64_t Dim = DimensionTraits<OpTy>::Dim;
 
-  LogicalResult createPoolingMax3D(OpTy &op, typename OpTy::Adaptor adaptor,
-                                   ConversionPatternRewriter &rewriter,
-                                   SmallVectorImpl<Value> &kernelSizeIntValues,
-                                   SmallVectorImpl<int64_t> &strideInts,
-                                   SmallVectorImpl<int64_t> &paddingInts,
-                                   SmallVectorImpl<int64_t> &dilationInts,
-                                   bool ceilMode,
-                                   SmallVectorImpl<Value> &outTensorShape,
-                                   Value &paddedInput, Value &poolingOp) const {
-    static_assert(Dim == 3, "op must be MaxPool3d or MaxPool3dWithIndices");
-    Value self = adaptor.getSelf();
-    Type elementType = cast<RankedTensorType>(self.getType()).getElementType();
-    TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
-        elementType,
-        APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
-                        /*Negative=*/true));
-    Value initValue =
-        rewriter.create<arith::ConstantOp>(op->getLoc(), smallestFPValueAttr);
-
-    paddedInput = padInputTensor(op, rewriter, self, ceilMode, 3, strideInts,
-                                 paddingInts, initValue);
-
-    auto outTensorInitialized = computeOutputTensor(
-        op, rewriter, self, 3, ceilMode, strideInts, paddingInts, dilationInts,
-        kernelSizeIntValues, outTensorShape, initValue);
-
-    auto shape =
-        castIntVectorToIndexVector(rewriter, op->getLoc(), kernelSizeIntValues);
-    Value windowTensor = rewriter.create<tensor::EmptyOp>(
-        op->getLoc(), getAsOpFoldResult(shape), elementType);
-
-    MLIRContext *context = rewriter.getContext();
-
-    auto mapInput = mlir::AffineMap::get(
-        8, 0,
-        {
-            rewriter.getAffineDimExpr(0), // n
-            rewriter.getAffineDimExpr(1), // c
-            // dim_d * stride_d + kernal_d * dilation_d
-            rewriter.getAffineDimExpr(2) *
-                    getAffineConstantExpr(strideInts[0], context) +
-                rewriter.getAffineDimExpr(5) *
-                    getAffineConstantExpr(dilationInts[0], context),
-            // dim_h * stride_h + kernal_h * dilation_h
-            rewriter.getAffineDimExpr(3) *
-                    getAffineConstantExpr(strideInts[1], context) +
-                rewriter.getAffineDimExpr(6) *
-                    getAffineConstantExpr(dilationInts[1], context),
-            // dim_w * stride_w + kernal_w * dilation_w
-            rewriter.getAffineDimExpr(4) *
-                    getAffineConstantExpr(strideInts[2], context) +
-                rewriter.getAffineDimExpr(7) *
-                    getAffineConstantExpr(dilationInts[2], context),
-        },
-        context);
-    auto mapKernel =
-        mlir::AffineMap::get(8, 0,
-                             {
-                                 rewriter.getAffineDimExpr(5), // kd
-                                 rewriter.getAffineDimExpr(6), // kh
-                                 rewriter.getAffineDimExpr(7)  // kw
-                             },
-                             context);
-    auto mapOutput = mlir::AffineMap::get(
-        8, 0,
-        {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1),
-         rewriter.getAffineDimExpr(2), rewriter.getAffineDimExpr(3),
-         rewriter.getAffineDimExpr(4)},
-        context);
-    auto iteratorTypes =
-        SmallVector<utils::IteratorType>(5, utils::IteratorType::parallel);
-    iteratorTypes.append(3, utils::IteratorType::reduction);
-    SmallVector<AffineMap> indexingMaps = {mapInput, mapKernel, mapOutput};
-    poolingOp = rewriter
-                    .create<linalg::GenericOp>(
-                        op->getLoc(),
-                        /* result types */ outTensorInitialized.getType(),
-                        /* operands */ ValueRange({paddedInput, windowTensor}),
-                        /* outputs */ outTensorInitialized,
-                        /*indexingMaps=*/indexingMaps,
-                        /*iteratorTypes=*/iteratorTypes,
-                        [&](OpBuilder &b, Location loc, ValueRange args) {
-                          Value currentVal = args[0], accMaxValue = args[2];
-                          Value max_result = b.create<arith::MaximumFOp>(
-                              loc, currentVal, accMaxValue);
-                          b.create<linalg::YieldOp>(loc, max_result);
-                        })
-                    .getResult(0);
-
-    return success();
-  }
-
-  // Returns the corresponding indices of the input tensor for the max pooling
-  // result tensor.
-  //
-  // For finding the indices, we follow the below method:
-  //
-  // Take maxpool2d as an example to illustrate. Let's say the input tensor is a
-  // 4-d tensor. The maxpool2d and indices will also be a 4-d tensor. Then:
-  // for i in range(N):
-  //     for j in range(C):
-  //         for m in range(Hout):
-  //             for n in range(Wout):
-  //                 for p in range(kH):
-  //                     for r in range(kW):
-  //                         indexH = m * stride[0] + p * dilation[0]
-  //                         indexW = n * stride[0] + r * dilation[0]
-  //                         if paddedInput[i, j, indexH, indexW] ==
-  //                            maxPool2d[i, j, m, n]:
-  //                             indices[i, j, m, n] =
-  //                               (indexH - padding[0]) * W +
-  //                               (indexW - padding[1])
-  //
-  LogicalResult
-  computeMaxPoolingIndices(Value maxPool, Value paddedInput, OpTy &op,
-                           typename OpTy::Adaptor adaptor,
-                           ConversionPatternRewriter &rewriter,
-                           SmallVectorImpl<Value> &outTensorShape,
-                           SmallVectorImpl<Value> &kernelSizeIntValues,
-                           SmallVectorImpl<int64_t> &strideInts,
-                           SmallVectorImpl<int64_t> &paddingInts,
-                           SmallVectorImpl<int64_t> &dilationInts, int64_t rank,
-                           Value &indicesResult) const {
-    Location loc = op->getLoc();
-    RankedTensorType indicesRankedTensorType = cast<RankedTensorType>(
-        this->getTypeConverter()->convertType(op->getResult(1).getType()));
-    Value cstMinusOne =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(-1));
-    Value indicesTensor =
-        createInitTensor(rewriter, loc, outTensorShape,
-                         indicesRankedTensorType.getElementType(), cstMinusOne);
-
-    SmallVector<Value> kernelSize =
-        castIntVectorToIndexVector(rewriter, loc, kernelSizeIntValues);
-    SmallVector<Value> padding =
-        getAsConstantIndexValues(rewriter, loc, paddingInts);
-    SmallVector<Value> dilation =
-        getAsConstantIndexValues(rewriter, loc, dilationInts);
-    SmallVector<Value> kernelStride =
-        getAsConstantIndexValues(rewriter, loc, strideInts);
-
-    Value windowTensor = rewriter.create<tensor::EmptyOp>(
-        loc, getAsOpFoldResult(kernelSize),
-        indicesRankedTensorType.getElementType());
-
-    SmallVector<AffineExpr> inputExprs, outputExprs, kernelExprs;
-    for (unsigned i = 0; i < rank; i++) {
-      inputExprs.push_back(rewriter.getAffineDimExpr(i));
-      outputExprs.push_back(rewriter.getAffineDimExpr(i));
-    }
-    for (unsigned i = 0; i < rank - 2; i++) {
-      kernelExprs.push_back(rewriter.getAffineDimExpr(i + rank));
-    }
-
-    // If computing indices for maxpool2d, we have six dimensions here. Each
-    // corresponding to N, C, Hout, Wout, kH, and kW, respectively, as described
-    // in the algorithm above.
-    SmallVector<AffineMap> indexingMaps = AffineMap::inferFromExprList(
-        {inputExprs, kernelExprs, outputExprs}, rewriter.getContext());
-    SmallVector<utils::IteratorType> iteratorTypes(
-        rank, utils::IteratorType::parallel);
-    iteratorTypes.append(rank - 2, utils::IteratorType::reduction);
-
-    // Extract pooling dimensions of input shape.
-    SmallVector<Value> inputSubShape;
-    for (unsigned i = 0; i < rank - 2; i++) {
-      inputSubShape.push_back(
-          getDimOp(rewriter, loc, adaptor.getSelf(), i + 2));
-    }
-
-    indicesResult =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, /*resultTensorTypes=*/indicesTensor.getType(),
-                /*inputs=*/ValueRange({maxPool, windowTensor}),
-                /*outputs=*/indicesTensor,
-                /*indexingMaps=*/indexingMaps,
-                /*iteratorTypes=*/iteratorTypes,
-                [&](OpBuilder &b, Location loc, ValueRange args) {
-                  Value maxVal = args[0], res = args[2];
-
-                  SmallVector<Value> inputDims;
-                  inputDims.append({b.create<linalg::IndexOp>(loc, 0),
-                                    b.create<linalg::IndexOp>(loc, 1)});
-                  for (unsigned i = 2; i < rank; i++) {
-                    Value mainIndex = b.create<linalg::IndexOp>(loc, i);
-                    Value subIndex =
-                        b.create<linalg::IndexOp>(loc, i + rank - 2);
-                    Value origin = b.create<arith::MulIOp>(loc, mainIndex,
-                                                           kernelStride[i - 2]);
-                    Value offset =
-                        b.create<arith::MulIOp>(loc, subIndex, dilation[i - 2]);
-                    inputDims.push_back(
-                        b.create<arith::AddIOp>(loc, origin, offset));
-                  }
-
-                  Value input =
-                      b.create<tensor::ExtractOp>(loc, paddedInput, inputDims);
-                  Value pred = b.create<arith::CmpFOp>(
-                      loc, arith::CmpFPredicate::OEQ, input, maxVal);
-
-                  Value outIndex =
-                      b.create<arith::ConstantOp>(loc, b.getIndexAttr(0));
-                  Value curInputStride =
-                      b.create<arith::ConstantOp>(loc, b.getIndexAttr(1));
-                  for (unsigned i = 0; i < rank - 2; i++) {
-                    Value minusPadding = b.create<arith::SubIOp>(
-                        loc, inputDims[rank - 1 - i], padding[rank - 3 - i]);
-                    Value timesStride = b.create<arith::MulIOp>(
-                        loc, minusPadding, curInputStride);
-                    outIndex =
-                        b.create<arith::AddIOp>(loc, outIndex, timesStride);
-                    curInputStride = b.create<arith::MulIOp>(
-                        loc, curInputStride, inputSubShape[rank - 3 - i]);
-                  }
-                  Value result = b.create<arith::SelectOp>(
-                      loc, pred, castIndexToInt64(b, loc, outIndex), res);
-
-                  Value predInvalidIndex = b.create<arith::CmpIOp>(
-                      loc, arith::CmpIPredicate::eq, res, cstMinusOne);
-                  Value out = b.create<arith::SelectOp>(loc, predInvalidIndex,
-                                                        result, res);
-
-                  b.create<linalg::YieldOp>(loc, out);
-                })
-            .getResult(0);
-
-    return success();
-  }
-
 public:
   LogicalResult
   matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
@@ -546,32 +454,124 @@ public:
                                    paddedInput, maxPool)))
       return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");
     } else {
-      if (failed(createPoolingMax3D(op, adaptor, rewriter, kernelSizeIntValues,
-                                    strideInts, paddingInts, dilationInts,
-                                    ceilMode, outTensorShape, paddedInput,
-                                    maxPool)))
+      ValueRange poolingResults;
+      if (failed(createCustomMaxPoolingOp(
+              op, adaptor, rewriter, typeConverter, kernelSizeIntValues,
+              strideInts, paddingInts, dilationInts, ceilMode, outTensorShape,
+              paddedInput, poolingResults)))
        return rewriter.notifyMatchFailure(op, "unable to compute maxpool3d");
+      maxPool = poolingResults.front();
     }
 
-    Value outMaxPool = rewriter.create<tensor::CastOp>(
-        op->getLoc(), maxPoolResultType, maxPool);
-    SmallVector<Value> outResult({outMaxPool});
-    if (withIndices) {
-      Value indicesResult;
-      if (failed(computeMaxPoolingIndices(
-              maxPool, paddedInput, op, adaptor, rewriter, outTensorShape,
-              kernelSizeIntValues, strideInts, paddingInts, dilationInts,
-              selfRank, indicesResult)))
-        return rewriter.notifyMatchFailure(op,
-                                           "unable to compute maxpool indices");
-      Type indicesResultType =
-          typeConverter->convertType(op->getResult(1).getType());
-      Value outIndices = rewriter.create<tensor::CastOp>(
-          op->getLoc(), indicesResultType, indicesResult);
-      outResult.push_back(outIndices);
-    }
-    rewriter.replaceOp(op, outResult);
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, maxPoolResultType, maxPool);
+    return success();
+  }
+};
+} // namespace
+namespace {
+template <typename OpTy>
+class ConvertAtenMaxPoolWithIndicesOp : public OpConversionPattern<OpTy> {
+  using OpConversionPattern<OpTy>::OpConversionPattern;
+
+private:
+  static const int64_t Dim = DimensionTraits<OpTy>::Dim;
+
+public:
+  LogicalResult
+  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    Location loc = op->getLoc();
+    const TypeConverter *typeConverter = this->getTypeConverter();
+
+    bool ceilMode;
+    SmallVector<Value> kernelSizeIntValues;
+    SmallVector<int64_t> strideInts, paddingInts, dilationInts;
+    if (!matchPattern(op.getDilation(),
+                      m_TorchListOfConstantInts(dilationInts)))
+      return rewriter.notifyMatchFailure(op,
+                                         "only support constant int dilations");
+
+    if (failed(checkAndGetPoolingParameters<OpTy>(op, rewriter, typeConverter,
+                                                  ceilMode, kernelSizeIntValues,
+                                                  strideInts, paddingInts)))
+      return rewriter.notifyMatchFailure(op, "invalid pooling parameters");
+
+    // Materialize padding, dilation, and stride as index values; they are
+    // needed below to compute the input indices that correspond to the max
+    // pooling values.
+    SmallVector<Value> padding =
+        getAsConstantIndexValues(rewriter, loc, paddingInts);
+    SmallVector<Value> dilation =
+        getAsConstantIndexValues(rewriter, loc, dilationInts);
+    SmallVector<Value> kernelStride =
+        getAsConstantIndexValues(rewriter, loc, strideInts);
+    // Extract the pooling dimensions of the input shape.
+    SmallVector<Value> inputSubShape;
+    for (int i = 0; i < Dim; i++) {
+      inputSubShape.push_back(
+          getDimOp(rewriter, loc, adaptor.getSelf(), i + 2));
+    }
+
+    // Maps the iteration-space indices of one pooling-window element to the
+    // flat index PyTorch expects: subtract the padding, then linearize in
+    // row-major order; padding positions map to -1.
+    auto indicesComputation = [&](OpBuilder &b, Location loc,
+                                  ValueRange iteratorDims) -> Value {
+      SmallVector<Value> inputDims;
+      for (int i = 0; i < Dim; i++) {
+        Value origin =
+            b.create<arith::MulIOp>(loc, iteratorDims[i], kernelStride[i]);
+        Value offset =
+            b.create<arith::MulIOp>(loc, iteratorDims[i + Dim], dilation[i]);
+        inputDims.push_back(b.create<arith::AddIOp>(loc, origin, offset));
+      }
+
+      Value outIndex = b.create<arith::ConstantOp>(loc, b.getIndexAttr(0));
+      Value curInputStride =
+          b.create<arith::ConstantOp>(loc, b.getIndexAttr(1));
+      Value validIndex = b.create<arith::ConstantOp>(
+          loc, b.getIntegerAttr(b.getI1Type(), 1));
+      Value cstZero = b.create<arith::ConstantOp>(loc, b.getIndexAttr(0));
+      for (int i = 0; i < Dim; i++) {
+        Value minusPadding = b.create<arith::SubIOp>(
+            loc, inputDims[Dim - 1 - i], padding[Dim - 1 - i]);
+        validIndex = b.create<arith::AndIOp>(
+            loc, validIndex,
+            b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
+                                    minusPadding, cstZero));
+        Value timesStride =
+            b.create<arith::MulIOp>(loc, minusPadding, curInputStride);
+        outIndex = b.create<arith::AddIOp>(loc, outIndex, timesStride);
+        curInputStride = b.create<arith::MulIOp>(loc, curInputStride,
+                                                 inputSubShape[Dim - 1 - i]);
      }
+      return b.create<arith::SelectOp>(
+          loc, validIndex, castIndexToInt64(b, loc, outIndex),
+          b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(-1)));
+    };
+
+    Value paddedInput;
+    SmallVector<Value> outTensorShape;
+    ValueRange results;
+    if (failed(createCustomMaxPoolingOp(
+            op, adaptor, rewriter, typeConverter, kernelSizeIntValues,
+            strideInts, paddingInts, dilationInts, ceilMode, outTensorShape,
+            paddedInput, results, std::move(indicesComputation))))
+      return rewriter.notifyMatchFailure(
+          op, "unable to compute maxpool with indices");
+
+    Type maxPoolResultType =
+        typeConverter->convertType(op->getResult(0).getType());
+    Type indicesResultType =
+        typeConverter->convertType(op->getResult(1).getType());
+    Value outMaxpool = rewriter.create<tensor::CastOp>(loc, maxPoolResultType,
+                                                       results.front());
+    Value outIndices = rewriter.create<tensor::CastOp>(loc, indicesResultType,
+                                                       results.back());
+
+    rewriter.replaceOp(op, {outMaxpool, outIndices});
     return success();
   }
 };
@@ -1521,10 +1521,10 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
 
   target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
   target.addIllegalOp<AtenMaxPool3dWithIndicesOp>();
-  patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool2dWithIndicesOp>>(typeConverter,
-                                                                 context);
-  patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool3dWithIndicesOp>>(typeConverter,
-                                                                 context);
+  patterns.add<ConvertAtenMaxPoolWithIndicesOp<AtenMaxPool2dWithIndicesOp>>(
+      typeConverter, context);
+  patterns.add<ConvertAtenMaxPoolWithIndicesOp<AtenMaxPool3dWithIndicesOp>>(
+      typeConverter, context);
   target.addIllegalOp<AtenMaxPool2dWithIndicesBackwardOp>();
   patterns.add<ConvertAtenMaxPool2dWithIndicesBackwardOp>(typeConverter,
                                                           context);
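Aside from the diff: the sketch below is editorial, not part of the patch. It models in plain C++ (all names are hypothetical) the scalar update rule that createCustomMaxPoolingOp emits inside the linalg.generic region, specialized to a 2-D max pool on one (N, C) slice. It follows the same rules as the emitted scf.if and indicesComputation: a candidate wins when its value is greater than the running max or the running index is still -1 (first visit / all-padding window so far), and the flat index is formed by subtracting the padding and folding coordinates row-major, with positions that have a negative unpadded coordinate mapping to -1. One deliberate simplification: the patch uses an unordered UGT float comparison (NaNs also win), while this model uses plain `>`.

#include <cstdint>
#include <cstdio>
#include <limits>
#include <utility>
#include <vector>

struct Pool2dParams {
  int64_t H, W;     // unpadded input extents
  int64_t kH, kW;   // kernel extents
  int64_t strideH, strideW;
  int64_t padH, padW;
  int64_t dilationH, dilationW;
};

// `padded(h, w)` reads the padded input and returns -inf outside the original
// extents, matching padInputTensor's -inf init value.
template <typename PaddedFn>
std::pair<float, int64_t> maxPoolOneOutput(const Pool2dParams &p, int64_t oh,
                                           int64_t ow, PaddedFn padded) {
  float accMax = -std::numeric_limits<float>::infinity();
  int64_t accIndex = -1; // matches the -1-initialized indices tensor
  for (int64_t ph = 0; ph < p.kH; ++ph) {
    for (int64_t pw = 0; pw < p.kW; ++pw) {
      // Input position per the generic op's map: out * stride + k * dilation.
      int64_t ih = oh * p.strideH + ph * p.dilationH;
      int64_t iw = ow * p.strideW + pw * p.dilationW;
      float val = padded(ih, iw);
      // indicesComputation: unpad, then linearize row-major. Note the patch
      // only checks `>= 0` (sge), which is mirrored exactly here.
      int64_t uh = ih - p.padH, uw = iw - p.padW;
      bool valid = uh >= 0 && uw >= 0;
      int64_t candidate = valid ? uh * p.W + uw : -1;
      // The scf.if condition: strictly greater, or index still unset.
      if (val > accMax || accIndex == -1) {
        accMax = val;
        accIndex = candidate;
      }
    }
  }
  return {accMax, accIndex};
}

int main() {
  // 3x3 input, 2x2 kernel, stride 1, padding 1, dilation 1.
  Pool2dParams p{3, 3, 2, 2, 1, 1, 1, 1, 1, 1};
  std::vector<float> in = {1, 2, 3, 4, 5, 6, 7, 8, 9};
  auto padded = [&](int64_t h, int64_t w) {
    int64_t uh = h - p.padH, uw = w - p.padW;
    if (uh < 0 || uh >= p.H || uw < 0 || uw >= p.W)
      return -std::numeric_limits<float>::infinity();
    return in[uh * p.W + uw];
  };
  // Output (0, 0): three of the four window positions are padding; the first
  // in-bounds element (input[0][0] == 1, flat index 0) wins via the
  // index == -1 rule, so the result is max=1, index=0 rather than index=-1.
  auto [maxVal, index] = maxPoolOneOutput(p, 0, 0, padded);
  std::printf("max=%g index=%lld\n", maxVal, (long long)index);
  return 0;
}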