//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"

#include "../PassDetail.h"
#include "PopulatePatterns.h"
#include "Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

// Checks the validity of pooling parameters and stores them in the respective
// vectors.
template <typename OpTy>
static LogicalResult checkAndGetPoolingParameters(
    OpTy op, ConversionPatternRewriter &rewriter, TypeConverter *typeConverter,
    bool &ceilMode, SmallVectorImpl<Value> &kernelSizeIntValues,
    SmallVectorImpl<int64_t> &strideInts,
    SmallVectorImpl<int64_t> &paddingInts) {
  // Pattern match against the op's original operands, because otherwise we
  // will get the lowered version of the operands which is harder to pattern
  // match.
  SmallVector<Value> kernelSizeTorchInt;
  if (!getListConstructElements(op.getKernelSize(), kernelSizeTorchInt)) {
    return rewriter.notifyMatchFailure(op,
                                       "unimplemented: the kernel size is "
                                       "not constructed from ListConstruct");
  }
  kernelSizeIntValues = getTypeConvertedValues(
      rewriter, op.getLoc(), typeConverter, kernelSizeTorchInt);

  if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts)))
    return rewriter.notifyMatchFailure(op, "only support constant int strides");
  // If `stride` is not specified by the user, it is assigned an empty list
  // during import. For such a case, the stride value is the kernel size.
  // See:
  // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
  if (strideInts.empty()) {
    if (!matchPattern(op.getKernelSize(),
                      m_TorchListOfConstantInts(strideInts))) {
      return rewriter.notifyMatchFailure(
          op, "if stride is the empty list, kernel_size must be a list of "
              "constant ints");
    }
  }

  if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingInts)))
    return rewriter.notifyMatchFailure(op,
                                       "only support constant int paddings");
  if (!matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))
    return rewriter.notifyMatchFailure(op,
                                       "only support constant bool ceil_mode");
  return success();
}
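// Example of the stride fallback above (illustrative; the module and values
// are assumed, not taken from a specific test case): a graph built from
// `torch.nn.MaxPool2d(kernel_size=3)` imports with `stride` as an empty list,
// so checkAndGetPoolingParameters substitutes the kernel size and the
// effective stride becomes [3, 3], matching the PyTorch default
// `stride = kernel_size`.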
// Creates a pooling operation based on the type specified by `OpTy` and
// arguments passed.
template <typename OpTy>
static LogicalResult createPoolingOp(
    Operation *op, ConversionPatternRewriter &rewriter, Value self,
    bool supportNonFPInput, bool ceilMode, int64_t dimensionality,
    SmallVectorImpl<Value> &kernelSizeIntValues,
    SmallVectorImpl<int64_t> &strideInts, SmallVectorImpl<int64_t> &paddingInts,
    SmallVectorImpl<int64_t> &dilationInts, Attribute initValueAttr,
    SmallVectorImpl<Value> &outTensorShape, Value &paddedInput, Value &result) {
  Location loc = op->getLoc();
  Type elementType = self.getType().cast<RankedTensorType>().getElementType();
  if (!elementType.isa<mlir::FloatType>() && !supportNonFPInput)
    return op->emitError("unimplemented: non-floating point type");

  SmallVector<int64_t> lowPaddingIncludingNC = {0, 0};
  lowPaddingIncludingNC.append(paddingInts);
  SmallVector<int64_t> highPaddingIncludingNC = lowPaddingIncludingNC;

  if (ceilMode) {
    for (int64_t i = 0; i < dimensionality; ++i) {
      highPaddingIncludingNC[i + 2] += strideInts[i];
    }
  }

  Value initValue =
      rewriter.create<arith::ConstantOp>(loc, cast<TypedAttr>(initValueAttr));
  paddedInput = torch_to_linalg::getPaddedTensor(
      op, rewriter, self, lowPaddingIncludingNC, highPaddingIncludingNC,
      initValue);

  Value N = getDimOp(rewriter, loc, self, 0);
  Value C = getDimOp(rewriter, loc, self, 1);

  SmallVector<Value> paddingIntValues =
      getAsConstantIntValues(rewriter, loc, paddingInts);
  SmallVector<Value> dilationIntValues =
      getAsConstantIntValues(rewriter, loc, dilationInts);
  SmallVector<Value> strideIntValues =
      getAsConstantIntValues(rewriter, loc, strideInts);

  // Get the size of each spatial dimension and calculate the output size.
  for (int64_t i = dimensionality - 1; i > -1; --i) {
    Value dimSize = getDimOp(rewriter, loc, self, i + 2);
    Value outDim = torch_to_linalg::getOutputDimForConvOps(
        rewriter, loc, dimSize, paddingIntValues[i], dilationIntValues[i],
        kernelSizeIntValues[i], strideIntValues[i], ceilMode);
    outTensorShape.insert(outTensorShape.begin(), {outDim});
  }

  // Create the output tensor initialized with `initValue`.
  outTensorShape.insert(outTensorShape.begin(), {N, C});
  Value outTensorInitialized =
      createInitTensor(rewriter, loc, outTensorShape, elementType, initValue);

  auto stridesAttr = rewriter.getI64VectorAttr(strideInts);
  auto dilationAttr = rewriter.getI64VectorAttr(dilationInts);
  auto shape = castIntVectorToIndexVector(rewriter, loc, kernelSizeIntValues);
  Value windowTensor = rewriter.create<tensor::EmptyOp>(
      loc, getAsOpFoldResult(shape), elementType);

  result = rewriter
               .create<OpTy>(loc, outTensorInitialized.getType(),
                             ValueRange{paddedInput, windowTensor},
                             outTensorInitialized, stridesAttr, dilationAttr)
               .getResult(0);
  return success();
}
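// For reference, the output spatial size computed by
// torch_to_linalg::getOutputDimForConvOps above follows the usual PyTorch
// pooling shape formula (a sketch of the intent, not a restatement of the
// helper's implementation):
//   out = floor((in + 2 * padding - dilation * (kernel - 1) - 1) / stride) + 1
// with floor replaced by ceil when ceil_mode is true. E.g., in = 7,
// padding = 1, kernel = 3, dilation = 1, stride = 2 gives out = 4.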
namespace {
class ConvertAtenMaxPool2dOp : public OpConversionPattern<AtenMaxPool2dOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenMaxPool2dOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    TypeConverter *typeConverter = getTypeConverter();
    Value self = adaptor.getSelf();
    int64_t selfRank = self.getType().cast<RankedTensorType>().getRank();
    // TODO: Add support for 3D inputs.
    if (selfRank == 3)
      return rewriter.notifyMatchFailure(
          op, "unimplemented: only support 4D input");

    bool ceilMode;
    SmallVector<Value> kernelSizeIntValues;
    SmallVector<int64_t> strideInts, paddingInts, dilationInts;
    if (!matchPattern(op.getDilation(),
                      m_TorchListOfConstantInts(dilationInts)))
      return rewriter.notifyMatchFailure(
          op, "only support constant int dilations");
    if (failed(checkAndGetPoolingParameters<AtenMaxPool2dOp>(
            op, rewriter, typeConverter, ceilMode, kernelSizeIntValues,
            strideInts, paddingInts)))
      return rewriter.notifyMatchFailure(op, "invalid pooling parameters");

    Type elementType =
        self.getType().cast<RankedTensorType>().getElementType();
    TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
        elementType,
        APFloat::getInf(
            elementType.cast<mlir::FloatType>().getFloatSemantics(),
            /*Negative=*/true));
    SmallVector<Value> outTensorShape;
    // `maxPool2d` contains the result of the maxpool2d operation over the
    // input.
    Value maxPool2d, paddedInput;
    if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(
            op, rewriter, self, /*supportNonFPInput=*/false, ceilMode,
            /*dimensionality=*/2, kernelSizeIntValues, strideInts, paddingInts,
            dilationInts, smallestFPValueAttr, outTensorShape, paddedInput,
            maxPool2d)))
      return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");

    Type newResultType = getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, maxPool2d);
    return success();
  }
};
} // namespace

namespace {
// Returns the result of maxpool2d over the input tensor, along with the
// corresponding indices of the input tensor for the values of the result
// tensor.
//
// The result of the maxpool2d operation is calculated using the helper
// function written above. For finding the indices, we follow the below
// method:
//
// Let's say the input tensor is a 4-d tensor. The maxpool2d and indices will
// also be a 4-d tensor. Then:
// for i in range(N):
//   for j in range(C):
//     for m in range(Hout):
//       for n in range(Wout):
//         for p in range(kH):
//           for r in range(kW):
//             indexH = m * stride[0] + p * dilation[0]
//             indexW = n * stride[1] + r * dilation[1]
//             if paddedInput[i, j, indexH, indexW] ==
//                maxPool2d[i, j, m, n]:
//               indices[i, j, m, n] =
//                   (indexH - padding[0]) * W + (indexW - padding[1])
//
class ConvertAtenMaxPool2dWithIndicesOp
    : public OpConversionPattern<AtenMaxPool2dWithIndicesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenMaxPool2dWithIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    Location loc = op->getLoc();
    TypeConverter *typeConverter = getTypeConverter();
    Value self = adaptor.getSelf();
    RankedTensorType selfType = self.getType().cast<RankedTensorType>();
    Type elementType = selfType.getElementType();
    RankedTensorType indicesRankedTensorType =
        getTypeConverter()
            ->convertType(op->getResult(1).getType())
            .cast<RankedTensorType>();

    // TODO: Add support for 3D inputs.
    if (selfType.getRank() == 3)
      return rewriter.notifyMatchFailure(
          op, "unimplemented: only support 4D input");

    bool ceilMode;
    SmallVector<Value> kernelSizeIntValues;
    SmallVector<int64_t> strideInts, paddingInts, dilationInts;
    if (!matchPattern(op.getDilation(),
                      m_TorchListOfConstantInts(dilationInts)))
      return rewriter.notifyMatchFailure(
          op, "only support constant int dilations");
    if (failed(checkAndGetPoolingParameters<AtenMaxPool2dWithIndicesOp>(
            op, rewriter, typeConverter, ceilMode, kernelSizeIntValues,
            strideInts, paddingInts)))
      return rewriter.notifyMatchFailure(op, "invalid pooling parameters");
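
    // Illustrative example of the index bookkeeping implemented below
    // (assumed values, not taken from a particular test case): with
    // stride = [2, 2], padding = [1, 1], dilation = [1, 1], and input width
    // W = 8, a window maximum found at padded coordinates
    // (indexH, indexW) = (3, 4) for some output element is stored as the
    // flat index (3 - 1) * 8 + (4 - 1) = 19, per the formula in the class
    // comment above.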
    // `maxPool2d` contains the result of the maxpool2d operation over the
    // input.
    auto smallestFPValueAttr = rewriter.getFloatAttr(
        elementType,
        APFloat::getInf(
            elementType.cast<mlir::FloatType>().getFloatSemantics(),
            /*Negative=*/true));
    Value maxPool2d, paddedInput;
    SmallVector<Value> outTensorShape;
    if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(
            op, rewriter, self, /*supportNonFPInput=*/false, ceilMode,
            /*dimensionality=*/2, kernelSizeIntValues, strideInts, paddingInts,
            dilationInts, smallestFPValueAttr, outTensorShape, paddedInput,
            maxPool2d)))
      return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");

    Value cstMinusOne =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(-1));
    Value indicesTensor =
        createInitTensor(rewriter, loc, outTensorShape,
                         indicesRankedTensorType.getElementType(), cstMinusOne);

    SmallVector<Value> kernelSize =
        castIntVectorToIndexVector(rewriter, loc, kernelSizeIntValues);
    SmallVector<Value> padding =
        getAsConstantIndexValues(rewriter, loc, paddingInts);
    SmallVector<Value> dilation =
        getAsConstantIndexValues(rewriter, loc, dilationInts);
    SmallVector<Value> stride =
        getAsConstantIndexValues(rewriter, loc, strideInts);

    Value windowTensor = rewriter.create<tensor::EmptyOp>(
        loc, getAsOpFoldResult(kernelSize),
        indicesRankedTensorType.getElementType());

    SmallVector<AffineExpr> inputExprs, outputExprs, kernelExprs;
    for (unsigned i = 0; i < 4; i++) {
      inputExprs.push_back(rewriter.getAffineDimExpr(i));
      outputExprs.push_back(rewriter.getAffineDimExpr(i));
    }
    kernelExprs.push_back(rewriter.getAffineDimExpr(4));
    kernelExprs.push_back(rewriter.getAffineDimExpr(5));

    // Here we have six dimensions, each corresponding to N, C, Hout, Wout, kH,
    // and kW, respectively, as described in the algorithm above.
    SmallVector<AffineMap> indexingMaps =
        AffineMap::inferFromExprList({inputExprs, kernelExprs, outputExprs});
    SmallVector<utils::IteratorType> iteratorTypes(
        4, utils::IteratorType::parallel);
    iteratorTypes.push_back(utils::IteratorType::reduction);
    iteratorTypes.push_back(utils::IteratorType::reduction);

    // Input format is: [N, C, H, W]
    Value inputShapeW = getDimOp(rewriter, loc, self, 3);

    Value indicesResult =
        rewriter
            .create<linalg::GenericOp>(
                loc, /*resultTensorTypes=*/indicesTensor.getType(),
                /*inputs=*/ValueRange({maxPool2d, windowTensor}),
                /*outputs=*/indicesTensor,
                /*indexingMaps=*/indexingMaps,
                /*iteratorTypes=*/iteratorTypes,
                [&](OpBuilder &b, Location loc, ValueRange args) {
                  Value maxVal = args[0], res = args[2];

                  Value i = b.create<linalg::IndexOp>(loc, 0);
                  Value j = b.create<linalg::IndexOp>(loc, 1);
                  Value m = b.create<linalg::IndexOp>(loc, 2);
                  Value n = b.create<linalg::IndexOp>(loc, 3);
                  Value p = b.create<linalg::IndexOp>(loc, 4);
                  Value r = b.create<linalg::IndexOp>(loc, 5);

                  Value mTimesStride =
                      b.create<arith::MulIOp>(loc, m, stride[0]);
                  Value pTimesDilation =
                      b.create<arith::MulIOp>(loc, p, dilation[0]);
                  Value indexH = b.create<arith::AddIOp>(loc, mTimesStride,
                                                         pTimesDilation);
                  Value nTimesStride =
                      b.create<arith::MulIOp>(loc, n, stride[1]);
                  Value rTimesDilation =
                      b.create<arith::MulIOp>(loc, r, dilation[1]);
                  Value indexW = b.create<arith::AddIOp>(loc, nTimesStride,
                                                         rTimesDilation);
                  Value input = b.create<tensor::ExtractOp>(
                      loc, paddedInput, ValueRange{i, j, indexH, indexW});
                  Value pred = b.create<arith::CmpFOp>(
                      loc, arith::CmpFPredicate::OEQ, input, maxVal);

                  Value indexHMinusPadding =
                      b.create<arith::SubIOp>(loc, indexH, padding[0]);
                  Value indexWMinusPadding =
                      b.create<arith::SubIOp>(loc, indexW, padding[1]);
                  Value outIndex = b.create<arith::MulIOp>(
                      loc, indexHMinusPadding, inputShapeW);
                  outIndex = b.create<arith::AddIOp>(loc, outIndex,
                                                     indexWMinusPadding);
                  Value result = b.create<arith::SelectOp>(
                      loc, pred, castIndexToInt64(b, loc, outIndex), res);

                  Value predInvalidIndex = b.create<arith::CmpIOp>(
                      loc, arith::CmpIPredicate::eq, res, cstMinusOne);
                  Value out = b.create<arith::SelectOp>(loc, predInvalidIndex,
                                                        result, res);

                  b.create<linalg::YieldOp>(loc, out);
                })
            .getResult(0);

    Type maxPool2dResultType =
        getTypeConverter()->convertType(op->getResult(0).getType());
    Type indicesResultType =
        getTypeConverter()->convertType(op->getResult(1).getType());
    Value outMaxpool2d =
        rewriter.create<tensor::CastOp>(loc, maxPool2dResultType, maxPool2d);
    Value outIndices =
        rewriter.create<tensor::CastOp>(loc, indicesResultType, indicesResult);

    rewriter.replaceOp(op, {outMaxpool2d, outIndices});
    return success();
  }
};
} // namespace

namespace {
template <typename OpTy, typename PoolingOpTy, int Dim>
class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
public:
  using OpConversionPattern<OpTy>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    Location loc = op->getLoc();
    TypeConverter *typeConverter = this->getTypeConverter();
    Value self = adaptor.getSelf();

    Type inputElementType =
        self.getType().cast<RankedTensorType>().getElementType();
    Type resultType = typeConverter->convertType(op.getType());
    Type resultElementType =
        resultType.cast<RankedTensorType>().getElementType();

    bool ceilMode;
    SmallVector<Value> kernelSizeIntValues;
    SmallVector<int64_t> strideInts, paddingInts, dilationInts(Dim, 1);
    if (failed(checkAndGetPoolingParameters<OpTy>(
            op, rewriter, typeConverter, ceilMode, kernelSizeIntValues,
            strideInts, paddingInts)))
      return rewriter.notifyMatchFailure(op, "invalid pooling parameters");

    // TODO: Add support for count_include_pad equal to `False`.
    bool countIncludePad;
    if (!matchPattern(op.getCountIncludePad(),
                      m_TorchConstantBool(&countIncludePad)))
      return rewriter.notifyMatchFailure(
          op, "count_include_pad must be a constant");
    if (!countIncludePad) {
      return rewriter.notifyMatchFailure(
          op, "unimplemented: count_include_pad is expected to be true");
    }

    // `sumPool` contains the result of the sumpool operation over the input.
    Value sumPool, paddedInput;
    SmallVector<Value> outTensorShape;
    if (failed(createPoolingOp<PoolingOpTy>(
            op, rewriter, self, /*supportNonFPInput=*/true, ceilMode,
            /*dimensionality=*/Dim, kernelSizeIntValues, strideInts,
            paddingInts, dilationInts, rewriter.getZeroAttr(inputElementType),
            outTensorShape, paddedInput, sumPool)))
      return rewriter.notifyMatchFailure(op, "unable to compute sumpool");

    // The divisor is the number of elements in the pooling window (kH * kW
    // for the 2-d case, kW for the 1-d case), unless divisor_override is
    // provided.
    Value divisor;
    if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) {
      Value kHtimeskW = rewriter.create<arith::MulIOp>(
          loc, kernelSizeIntValues[0], kernelSizeIntValues[1]);
      divisor =
          op.getDivisorOverride().getType().template isa<Torch::NoneType>()
              ? kHtimeskW
              : adaptor.getDivisorOverride();
    } else {
      divisor = kernelSizeIntValues[0];
    }
    divisor = convertScalarToDtype(rewriter, loc, divisor, resultElementType);

    Value outputTensor = rewriter.create<tensor::EmptyOp>(
        loc, getAsOpFoldResult(outTensorShape), resultElementType);

    SmallVector<AffineMap> indexingMapsAvg(
        2, rewriter.getMultiDimIdentityMap(Dim + 2));
    SmallVector<utils::IteratorType> iteratorTypesAvg(
        Dim + 2, utils::IteratorType::parallel);
    Value avgPool =
        rewriter
            .create<linalg::GenericOp>(
                loc, outputTensor.getType(), sumPool, outputTensor,
                /*indexingMaps=*/indexingMapsAvg,
                /*iteratorTypes=*/iteratorTypesAvg,
                [&](OpBuilder &b, Location loc, ValueRange args) {
                  Value avg;
                  if (resultElementType.isa<mlir::FloatType>())
                    avg = b.create<arith::DivFOp>(loc, args[0], divisor);
                  else if (resultElementType.isa<mlir::IntegerType>())
                    avg = b.create<arith::DivSIOp>(loc, args[0], divisor);
                  b.create<linalg::YieldOp>(loc, avg);
                })
            .getResult(0);

    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, avgPool);
    return success();
  }
};
} // namespace

void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target) {
  MLIRContext *context = patterns.getContext();
  target.addIllegalOp<AtenMaxPool2dOp>();
  patterns.add<ConvertAtenMaxPool2dOp>(typeConverter, context);

  target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
  patterns.add<ConvertAtenMaxPool2dWithIndicesOp>(typeConverter, context);
  target.addIllegalOp<AtenAvgPool1dOp, AtenAvgPool2dOp>();
  patterns
      .add<ConvertAtenAvgPoolOp<AtenAvgPool1dOp, linalg::PoolingNcwSumOp, 1>>(
          typeConverter, context);
  patterns
      .add<ConvertAtenAvgPoolOp<AtenAvgPool2dOp, linalg::PoolingNchwSumOp, 2>>(
          typeConverter, context);
}