//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "../PassDetail.h" #include "PopulatePatterns.h" #include "Utils.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Matchers.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; // Checks the validity of pooling parameters and stores them in the respective // vector. template static LogicalResult checkAndGetPoolingParameters(OpTy op, ConversionPatternRewriter &rewriter, bool &ceilMode, SmallVectorImpl &kernelSizeInts, SmallVectorImpl &strideInts, SmallVectorImpl &paddingInts) { // Pattern match against the op's original operands, because otherwise we // will get the lowered version of the operands which is harder to pattern // match. if (!matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSizeInts))) return rewriter.notifyMatchFailure(op, "only support kernel size ints"); if (!matchPattern(op.stride(), m_TorchConstantIntList(strideInts))) return rewriter.notifyMatchFailure(op, "only support constant int strides"); if (!matchPattern(op.padding(), m_TorchConstantIntList(paddingInts))) return rewriter.notifyMatchFailure(op, "only support constant int paddings"); if (!matchPattern(op.ceil_mode(), m_TorchConstantBool(&ceilMode))) return rewriter.notifyMatchFailure(op, "only support constant bool ceil_mode"); return success(); } // Creates a pooling operation based on the type specified by `OpTy` and // arguments passed. template static LogicalResult createPoolingOp( Operation *op, ConversionPatternRewriter &rewriter, Value self, bool supportFPInputOnly, bool ceilMode, SmallVectorImpl &kernelSizeInts, SmallVectorImpl &strideInts, SmallVectorImpl &paddingInts, SmallVectorImpl &dilationInts, Attribute initValueAttr, SmallVectorImpl &outTensorShape, Value &paddedInput, Value &result) { Location loc = op->getLoc(); Type elementType = self.getType().cast().getElementType(); if (!elementType.isa() && !supportFPInputOnly) return op->emitError("unimplemented: non-floating point type"); SmallVector lowPaddingIncludingNC = {0, 0}; lowPaddingIncludingNC.append(paddingInts); SmallVector highPaddingIncludingNC = lowPaddingIncludingNC; if (ceilMode) { highPaddingIncludingNC[2] += strideInts[0]; highPaddingIncludingNC[3] += strideInts[1]; } Value initValue = rewriter.create(loc, initValueAttr); paddedInput = torch_to_linalg::getPaddedTensor( op, rewriter, self, lowPaddingIncludingNC, highPaddingIncludingNC, initValue); Value N = getDimOp(rewriter, loc, self, 0); Value C = getDimOp(rewriter, loc, self, 1); Value H = getDimOp(rewriter, loc, self, 2); Value W = getDimOp(rewriter, loc, self, 3); SmallVector paddingIntValues = getAsConstantIntValues(rewriter, loc, paddingInts); SmallVector dilationIntValues = getAsConstantIntValues(rewriter, loc, dilationInts); SmallVector kernelSizeIntValues = getAsConstantIntValues(rewriter, loc, kernelSizeInts); SmallVector strideIntValues = getAsConstantIntValues(rewriter, loc, strideInts); Value hOut = torch_to_linalg::getOutputDimForConvOps( rewriter, loc, H, paddingIntValues[0], dilationIntValues[0], kernelSizeIntValues[0], strideIntValues[0], ceilMode); Value wOut = torch_to_linalg::getOutputDimForConvOps( rewriter, loc, W, paddingIntValues[1], dilationIntValues[1], kernelSizeIntValues[1], strideIntValues[1], ceilMode); // Create output tensor initialized with smallest floating point value. outTensorShape.insert(outTensorShape.begin(), {N, C, hOut, wOut}); Value outTensorInitialized = createInitTensor(rewriter, loc, outTensorShape, elementType, initValue); auto stridesAttr = rewriter.getI64VectorAttr(strideInts); auto dilationAttr = rewriter.getI64VectorAttr(dilationInts); Value windowTensor = rewriter.create( loc, getAsConstantIndexValues(rewriter, loc, kernelSizeInts), elementType); result = rewriter .create(loc, outTensorInitialized.getType(), ValueRange{paddedInput, windowTensor}, outTensorInitialized, stridesAttr, dilationAttr) .getResult(0); return success(); } namespace { class ConvertAtenMaxPool2dOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenMaxPool2dOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Value self = adaptor.self(); int64_t selfRank = self.getType().cast().getRank(); // TODO: Add support for 3D inputs. if (selfRank == 3) return rewriter.notifyMatchFailure( op, "unimplemented: only support 4D input"); SmallVector kernelSizeInts, strideInts, paddingInts, dilationInts; if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts))) return rewriter.notifyMatchFailure(op, "only support constant int dilations"); bool ceilMode; if (failed(checkAndGetPoolingParameters( op, rewriter, ceilMode, kernelSizeInts, strideInts, paddingInts))) return rewriter.notifyMatchFailure(op, "invalid pooling parameters"); Type elementType = self.getType().cast().getElementType(); auto smallestFPValueAttr = rewriter.getFloatAttr( elementType, APFloat::getLargest( elementType.cast().getFloatSemantics(), /*Negative=*/true)); SmallVector outTensorShape; // `maxpool2d` contains the result of maxpool2d operation over the input. Value maxPool2d, paddedInput; if (failed(createPoolingOp( op, rewriter, self, /*supportFPInput=*/false, ceilMode, kernelSizeInts, strideInts, paddingInts, dilationInts, smallestFPValueAttr, outTensorShape, paddedInput, maxPool2d))) return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d"); Type newResultType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, newResultType, maxPool2d); return success(); } }; } // namespace namespace { // Returns the result of maxpool2d over the input tensor. And the corresponding // indices of the input tensor for the values of the result tensor. // // The result of the maxpool2d operation is calculated using the helper function // written above. For finding the indices, we follow the below method: // // Let's say the input tensor is a 4-d tensor. The maxpool2d and indices will // also be a 4-d tensor. Then: // for i in range(N): // for j in range(C): // for m in range(Hout): // for n in range(Wout): // for p in range(kH): // for r in range(kW): // indexH = m * stride[0] + p * dilation[0] // indexW = n * stride[0] + r * dilation[0] // if paddedInput[i, j, indexH, indexW] == // maxPool2d[i, j, m, n]: // indices[i, j, m, n] = (indexH - padding[0]) * W + // (indexW - padding[1]) // class ConvertAtenMaxPool2dWithIndicesOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenMaxPool2dWithIndicesOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op->getLoc(); Value self = adaptor.self(); RankedTensorType selfType = self.getType().cast(); Type elementType = selfType.getElementType(); RankedTensorType indicesRankedTensorType = getTypeConverter() ->convertType(op->getResult(1).getType()) .cast(); // TODO: Add support for 3D inputs. if (selfType.getRank() == 3) return rewriter.notifyMatchFailure( op, "unimplemented: only support 4D input"); SmallVector kernelSizeInts, strideInts, paddingInts, dilationInts; if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts))) return rewriter.notifyMatchFailure(op, "only support constant int dilations"); bool ceilMode; if (failed(checkAndGetPoolingParameters( op, rewriter, ceilMode, kernelSizeInts, strideInts, paddingInts))) return rewriter.notifyMatchFailure(op, "invalid pooling parameters"); // `maxpool2d` contains the result of maxpool2d operation over the input. auto smallestFPValueAttr = rewriter.getFloatAttr( elementType, APFloat::getLargest( elementType.cast().getFloatSemantics(), /*Negative=*/true)); Value maxPool2d, paddedInput; SmallVector outTensorShape; if (failed(createPoolingOp( op, rewriter, self, /*supportFPInput=*/false, ceilMode, kernelSizeInts, strideInts, paddingInts, dilationInts, smallestFPValueAttr, outTensorShape, paddedInput, maxPool2d))) return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d"); Value cstMinusOne = rewriter.create(loc, rewriter.getI64IntegerAttr(-1)); Value indicesTensor = createInitTensor(rewriter, loc, outTensorShape, indicesRankedTensorType.getElementType(), cstMinusOne); SmallVector kernelSize = getAsConstantIndexValues(rewriter, loc, kernelSizeInts); SmallVector padding = getAsConstantIndexValues(rewriter, loc, paddingInts); SmallVector dilation = getAsConstantIndexValues(rewriter, loc, dilationInts); SmallVector stride = getAsConstantIndexValues(rewriter, loc, strideInts); Value windowTensor = rewriter.create( loc, kernelSize, indicesRankedTensorType.getElementType()); SmallVector inputExprs, outputExprs, kernelExprs; for (unsigned i = 0; i < 4; i++) { inputExprs.push_back(rewriter.getAffineDimExpr(i)); outputExprs.push_back(rewriter.getAffineDimExpr(i)); } kernelExprs.push_back(rewriter.getAffineDimExpr(4)); kernelExprs.push_back(rewriter.getAffineDimExpr(5)); // Here we have six dimensions, each corresponding to N, C, Hout, Wout, kH, // and kW, respectively, as described in the algorithm above. SmallVector indexingMaps = AffineMap::inferFromExprList({inputExprs, kernelExprs, outputExprs}); SmallVector iteratorTypes(4, getParallelIteratorTypeName()); iteratorTypes.push_back(getReductionIteratorTypeName()); iteratorTypes.push_back(getReductionIteratorTypeName()); // Input format is : [N, C, H, W] Value inputShapeW = getDimOp(rewriter, loc, self, 3); Value indicesResult = rewriter .create( loc, /*resultTensorTypes=*/indicesTensor.getType(), /*inputs=*/ValueRange({maxPool2d, windowTensor}), /*outputs=*/indicesTensor, /*indexingMaps=*/indexingMaps, /*iteratorTypes=*/iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { Value maxVal = args[0], res = args[2]; Value i = b.create(loc, 0); Value j = b.create(loc, 1); Value m = b.create(loc, 2); Value n = b.create(loc, 3); Value p = b.create(loc, 4); Value r = b.create(loc, 5); Value mTimesStride = b.create(loc, m, stride[0]); Value pTimesDilation = b.create(loc, p, dilation[0]); Value indexH = b.create(loc, mTimesStride, pTimesDilation); Value nTimesStride = b.create(loc, n, stride[1]); Value rTimesDilation = b.create(loc, r, dilation[1]); Value indexW = b.create(loc, nTimesStride, rTimesDilation); Value input = b.create( loc, paddedInput, ValueRange{i, j, indexH, indexW}); Value pred = b.create( loc, arith::CmpFPredicate::OEQ, input, maxVal); Value indexHMinusPadding = b.create(loc, indexH, padding[0]); Value indexWMinusPadding = b.create(loc, indexW, padding[1]); Value outIndex = b.create( loc, indexHMinusPadding, inputShapeW); outIndex = b.create(loc, outIndex, indexWMinusPadding); Value result = b.create( loc, pred, castIndexToInt64(b, loc, outIndex), res); Value predInvalidIndex = b.create( loc, arith::CmpIPredicate::eq, res, cstMinusOne); Value out = b.create(loc, predInvalidIndex, result, res); b.create(loc, out); }) .getResult(0); Type maxPool2dResultType = getTypeConverter()->convertType(op->getResult(0).getType()); Type indicesResultType = getTypeConverter()->convertType(op->getResult(1).getType()); Value outMaxpool2d = rewriter.create(loc, maxPool2dResultType, maxPool2d); Value outIndices = rewriter.create(loc, indicesResultType, indicesResult); rewriter.replaceOp(op, {outMaxpool2d, outIndices}); return success(); } }; } // namespace namespace { class ConvertAtenAdaptiveAvgPool2dOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenAdaptiveAvgPool2dOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); MLIRContext *context = op->getContext(); Value input = adaptor.self(); /* in form of N*C*H*W */ RankedTensorType inputType = input.getType().cast(); Type elementType = inputType.getElementType(); if (!elementType.isa()) return op.emitError("unimplemented: non-floating point type"); auto inputRank = inputType.getRank(); if (inputRank != 4) return rewriter.notifyMatchFailure(op, "input should be rank 4"); SmallVector expects{1, 1}; // Pattern match against the op's original operands, because otherwise we // will get the lowered version of the operands which is harder to pattern // match. if (!isConstantIntListMatching(op.output_size(), expects)) return rewriter.notifyMatchFailure( op, "only support output_size with H and W both equal to constant 1"); Value N = getDimOp(rewriter, loc, input, 0); Value C = getDimOp(rewriter, loc, input, 1); Value initTensor = rewriter.create( loc, ValueRange{N, C}, elementType); Value c0 = rewriter.create( loc, FloatAttr::get(elementType, 0.0)); Value initTensor0 = rewriter.create(loc, c0, initTensor).getResult(0); SmallVector ncExprs; ncExprs.push_back(mlir::getAffineDimExpr(0, context)); ncExprs.push_back(mlir::getAffineDimExpr(1, context)); auto ncIndexingMap = AffineMap::get( /*dimCount=*/4, /*symbolCount=*/0, ncExprs, context); SmallVector indexingMaps = { rewriter.getMultiDimIdentityMap(4), // input ncIndexingMap, // output }; SmallVector iteratorTypesSum{"parallel", "parallel", "reduction", "reduction"}; Value sumPool2d = rewriter .create( loc, initTensor0.getType(), input, initTensor0, /*indexingMaps=*/indexingMaps, /*iteratorTypes=*/iteratorTypesSum, [&](OpBuilder &b, Location loc, ValueRange args) { Value input = args[0], sum = args[1]; Value result = rewriter.create( loc, sum, input); b.create(loc, result); }) .getResult(0); // Calculate H*W so that avg can be got from sum / (H*W) Value H = getDimOp(rewriter, loc, input, 2); Value W = getDimOp(rewriter, loc, input, 3); auto castIndexToInt = [&](Value v) { return rewriter.create( loc, IntegerType::get(context, 64), v); }; Value HtimesW = rewriter.create(loc, castIndexToInt(H), castIndexToInt(W)); Value HtimesWf = rewriter.create(loc, elementType, HtimesW); Value c1Index = rewriter.create(loc, /*value=*/1); Value outputTensor = rewriter.create( loc, ValueRange{N, C, c1Index, c1Index}, elementType); SmallVector indexingMapsAvg{ ncIndexingMap, rewriter.getMultiDimIdentityMap(4)}; SmallVector iteratorTypesAvg(4, "parallel"); Value avgPool2d = rewriter .create( loc, outputTensor.getType(), sumPool2d, outputTensor, /*indexingMaps=*/indexingMapsAvg, /*iteratorTypes=*/iteratorTypesAvg, [&](OpBuilder &b, Location loc, ValueRange args) { Value avg = b.create(loc, args[0], HtimesWf); b.create(loc, avg); }) .getResult(0); Type newResultType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, newResultType, avgPool2d); return success(); } }; } // namespace namespace { class ConvertAtenAvgPool2dOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenAvgPool2dOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op->getLoc(); Value self = adaptor.self(); Type inputElementType = self.getType().cast().getElementType(); Type resultType = getTypeConverter()->convertType(op.getType()); Type resultElementType = resultType.cast().getElementType(); SmallVector dilationInts{1, 1}; SmallVector kernelSizeInts, strideInts, paddingInts; bool ceilMode; if (failed(checkAndGetPoolingParameters( op, rewriter, ceilMode, kernelSizeInts, strideInts, paddingInts))) return rewriter.notifyMatchFailure(op, "invalid pooling parameters"); // TODO: Add support for count_include_pad equal to `False`. bool countIncludePad; if (!matchPattern(op.count_include_pad(), m_TorchConstantBool(&countIncludePad))) return rewriter.notifyMatchFailure( op, "count_include_pad must be a constant"); if (!countIncludePad) { return rewriter.notifyMatchFailure( op, "unimplemented: count_include_pad is expected to be true"); } // `sumPool2d` contains the result of sumpool2d operation over the input. Value sumPool2d, paddedInput; SmallVector outTensorShape; if (failed(createPoolingOp( op, rewriter, self, /*supportFPInput=*/true, ceilMode, kernelSizeInts, strideInts, paddingInts, dilationInts, rewriter.getZeroAttr(inputElementType), outTensorShape, paddedInput, sumPool2d))) return rewriter.notifyMatchFailure(op, "unable to compute sumpool2d"); SmallVector kernelSizeIntValues = getAsConstantIntValues(rewriter, loc, kernelSizeInts); Value kHtimeskW = rewriter.create( loc, kernelSizeIntValues[0], kernelSizeIntValues[1]); Value divisor = op.divisor_override().getType().isa() ? kHtimeskW : adaptor.divisor_override(); divisor = convertScalarToDtype(rewriter, loc, divisor, resultElementType); Value outputTensor = rewriter.create( loc, outTensorShape, resultElementType); SmallVector indexingMapsAvg(2, rewriter.getMultiDimIdentityMap(4)); SmallVector iteratorTypesAvg(4, "parallel"); Value avgPool2d = rewriter .create( loc, outputTensor.getType(), sumPool2d, outputTensor, /*indexingMaps=*/indexingMapsAvg, /*iteratorTypes=*/iteratorTypesAvg, [&](OpBuilder &b, Location loc, ValueRange args) { Value avg; if (resultElementType.isa()) avg = b.create(loc, args[0], divisor); else if (resultElementType.isa()) avg = b.create(loc, args[0], divisor); b.create(loc, avg); }) .getResult(0); rewriter.replaceOpWithNewOp(op, resultType, avgPool2d); return success(); } }; } // namespace void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { MLIRContext *context = patterns.getContext(); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); }