//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "Utils.h" #include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/IR/Matchers.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; static SmallVector getIndexIntsAsOpFoldResult(OpBuilder &b, SmallVectorImpl &ints) { return llvm::to_vector<4>(llvm::map_range( ints, [&](int64_t val) -> OpFoldResult { return b.getIndexAttr(val); })); } // Helper function to get the padding tensor given the padding int values. Value torch_to_linalg::getPaddedTensor( Operation *op, OpBuilder &b, Value &input, SmallVectorImpl &lowPaddingInts, SmallVectorImpl &highPaddingInts, Value pad) { Location loc = op->getLoc(); Type rankedTensorType = tensor::PadOp::inferResultType(input.getType().cast(), lowPaddingInts, highPaddingInts); SmallVector lowPaddings = getIndexIntsAsOpFoldResult(b, lowPaddingInts); SmallVector highPaddings = getIndexIntsAsOpFoldResult(b, highPaddingInts); Value paddedInput = b.create(loc, rankedTensorType, input, /*low=*/lowPaddings, /*high=*/highPaddings, pad); return paddedInput; } // Helper function to get the padding tensor given the padding int values. // It's assumed that the padding on the low end and high end are the same, // and that zero padding is required. Value torch_to_linalg::getZeroPaddedTensor( Operation *op, OpBuilder &b, Value &input, SmallVectorImpl &paddingInts) { assert(input.getType().isa() && "input must be RankedTensorType"); Location loc = op->getLoc(); Value c0 = b.create( loc, b.getZeroAttr(input.getType().cast().getElementType())); return getPaddedTensor(op, b, input, paddingInts, paddingInts, c0); } Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc, Value in, Value paddingInt, Value dilationInt, Value kernelSizeInt, Value strideInt, bool ceilMode) { Value c1 = b.create(loc, b.getI64IntegerAttr(1)); Value c2 = b.create(loc, b.getI64IntegerAttr(2)); Value doublePadding = b.create(loc, paddingInt, c2); // in + 2 * padding Value inAddDoublePadding = b.create(loc, castIndexToInt64(b, loc, in), doublePadding); // dilation * (kernelSize - 1) Value kernelSizeSub1 = b.create(loc, kernelSizeInt, c1); Value dilationTimesKernelSize = b.create(loc, dilationInt, kernelSizeSub1); Value temp = b.create(loc, inAddDoublePadding, dilationTimesKernelSize); Value dividend = b.create(loc, temp, c1); Value division; if (ceilMode) division = b.create(loc, dividend, strideInt); else division = b.create(loc, dividend, strideInt); Value out = b.create(loc, division, c1); return castIntToIndex(b, loc, out); } Value torch_to_linalg::getOutputDimForConvTransposeOps( OpBuilder &b, Location loc, Value in, Value paddingInt, Value dilationInt, Value kernelSizeInt, Value strideInt) { Value c1 = b.create(loc, b.getI64IntegerAttr(1)); Value c2 = b.create(loc, b.getI64IntegerAttr(2)); // (in - 1) * stride Value inStrided = b.create(loc, castIndexToInt64(b, loc, in), c1); inStrided = b.create(loc, inStrided, strideInt); // 2 * padding Value doublePadding = b.create(loc, paddingInt, c2); // (kernelSize - 1) * dilation Value kernelDilated = b.create(loc, kernelSizeInt, c1); kernelDilated = b.create(loc, kernelDilated, dilationInt); Value out = b.create(loc, inStrided, doublePadding); out = b.create(loc, out, kernelDilated); out = b.create(loc, out, c1); return castIntToIndex(b, loc, out); } Value torch_to_linalg::createReductionLinalgGeneric( OpBuilder &b, Location loc, const ReductionOpInfo &opInfo, Value initElem, function_ref bodyBuild) { auto inputType = opInfo.tensorOperand.getType().cast(); // Get the result shape by obtaining the size of each // dimension in the input tensor that is not getting reduced. // If `opInfo.keepDim` is true, the rank of the output tensor // is kept the same as the rank of the input tensor, and the // reduced dimensions are set to have size 1. auto c1 = b.create(loc, /*value=*/1); SmallVector resultShape; for (int64_t i = 0; i < inputType.getRank(); i++) { auto currentDimSize = b.create(loc, opInfo.tensorOperand, i); if (!opInfo.dimSet.contains(i)) resultShape.push_back(currentDimSize); else if (opInfo.keepDim) resultShape.push_back(c1); } // Create the affine expressions that will be used to // iterate over the input and output tensors. // Here we also set the type of iterator: parallel or reduction. SmallVector exprs; SmallVector iteratorTypes; SmallVector resultExprs; for (auto size : llvm::enumerate(inputType.getShape())) { exprs.push_back(b.getAffineDimExpr(size.index())); if (opInfo.dimSet.contains(size.index())) { iteratorTypes.push_back(getReductionIteratorTypeName()); // If `opInfo.keepDim`, create affine map to the first element // in the current dimension. if (opInfo.keepDim) resultExprs.push_back(b.getAffineConstantExpr(0)); } else { iteratorTypes.push_back(getParallelIteratorTypeName()); resultExprs.push_back(b.getAffineDimExpr(size.index())); } } auto indexingMaps = AffineMap::inferFromExprList({exprs, resultExprs}); Value accumulator = createInitTensor(b, loc, resultShape, initElem.getType(), initElem); return b .create( loc, /*resultTensorTypes=*/accumulator.getType(), /*inputs=*/opInfo.tensorOperand, /*outputs=*/accumulator, indexingMaps, iteratorTypes, bodyBuild) .getResult(0); } Value torch_to_linalg::createElementwiseLinalgGeneric( OpBuilder &b, Location loc, ValueRange tensorOperands, Type resultElementType, function_ref bodyBuild) { // The overall error handling strategy here is best viewed by thinking about // what happens for a single result dimension. This loop not structured that // way because it is hard to create the affine maps for each operand unless // we structure the loop to iterate over tensor operands as the outer loop // instead of inner loop. This pseudocode gives better intuition: // ``` // for each result dimension: // for each tensor operand: // if it doesn't even have high enough rank relative to the result: // continue // if it is a static size-1 along this result dimension: // continue // if this is the first tensor operand that didn't continue above: // take its dimension size as the size of the non-broadcasted // traversal along this dimension (this may include a dynamic size-1, // **non-broadcasted** traversal!) // emit error check "if the size does not match the non-broadcasted // traversal size along this dimension, error" // ``` SmallVector operandRanks; operandRanks.resize(tensorOperands.size()); llvm::transform(tensorOperands, operandRanks.begin(), [](Value tensor) { return tensor.getType().dyn_cast().getRank(); }); auto resultRankIt = std::max_element(operandRanks.begin(), operandRanks.end()); assert(resultRankIt != operandRanks.end() && "Unable to get result rank."); int64_t resultRank = *resultRankIt; // Initialize the resultShape to all 1's, as a fallback in case // all sizes along that result dimension are statically 1. auto c1 = b.create(loc, /*value=*/1); SmallVector resultShape(resultRank, c1); SmallVector indexingMaps; for (Value tensorOperand : tensorOperands) { SmallVector exprs; auto type = tensorOperand.getType().cast(); for (auto size : llvm::enumerate(type.getShape())) { // If the size is statically known to be 1, we don't want any // error guards to be spuriously emitted, since we are specifically // allowing size-1 broadcasts in this case, as they correspond to a // constant-0 indexing map. if (size.value() == 1) { exprs.push_back(b.getAffineConstantExpr(0)); continue; } // The rank of this operand might be smaller than the overall rank of // the broadcast. Add an offset to correlate it to the correct // dimension of the result. auto resultDim = size.index() + (resultRank - type.getRank()); // The generated linalg op will now be iterating along the full size // of this dimension. Record that fact. exprs.push_back(b.getAffineDimExpr(resultDim)); // Now, we need to ensure that such iteration is not going to trigger // undefined behavior, by doing appropriate checks against the current // dimension size. auto currentDimSize = getDimOp(b, loc, tensorOperand, size.index()); // If the result size of this dimension has so far only hit the // statically-known-to-be-1 case above (i.e., we have not yet assigned a // new Value to `resultShape[resultDim]`), then we have no other dynamic // values to check against, and merely need to record the current // dimension size. if (resultShape[resultDim] == c1) { resultShape[resultDim] = currentDimSize; continue; } // We prohibit the size-1 dynamic broadcasting scenario, so just check // for exact equality with the running result size. // This is the check which protects against the undefined behavior of // the generated linalg op in the case of iterating two operands with // dimensions sizes that are expected to match. auto equalToRunning = b.create(loc, arith::CmpIPredicate::eq, resultShape[resultDim], currentDimSize); b.create(loc, equalToRunning, "mismatched size for broadcast"); } indexingMaps.push_back(AffineMap::get( /*dimCount=*/resultRank, /*symbolCount=*/0, exprs, b.getContext())); } SmallVector iteratorTypes(resultRank, getParallelIteratorTypeName()); // Add the indexing map for the outs init tensor. indexingMaps.push_back(b.getMultiDimIdentityMap(resultRank)); Value initTensor = b.create( loc, getAsOpFoldResult(resultShape), resultElementType); return b .create(loc, /*resultTensorTypes=*/initTensor.getType(), /*inputs=*/tensorOperands, /*outputs=*/initTensor, indexingMaps, iteratorTypes, bodyBuild) .getResult(0); } // Broadcasts input tensor based on the broadcastToShape. LogicalResult torch_to_linalg::broadcastToGivenShape( Operation *op, PatternRewriter &rewriter, Value input, SmallVector broadcastToShape, Value &result) { RankedTensorType inputType = input.getType().cast(); ArrayRef inputShape = inputType.getShape(); if (broadcastToShape.size() < inputShape.size()) { return rewriter.notifyMatchFailure( op, "invalid shape: broadcastToShape size must not be smaller than the " "size of the input shape"); } Type elementType = inputType.getElementType(); Location loc = op->getLoc(); MLIRContext *context = op->getContext(); SmallVector outShape; // Create affine map and shapes for tensor initialization. SmallVector outExpr; Value zero = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); size_t diff = broadcastToShape.size() - inputShape.size(); for (size_t i = 0; i < broadcastToShape.size(); i++) { Value shapeValue = broadcastToShape[i]; size_t j = i - diff; if (i < diff) { Value isValid = rewriter.create( loc, arith::CmpIPredicate::sge, shapeValue, zero); rewriter.create( loc, isValid, rewriter.getStringAttr( "negative values not allowed in new dimensions")); outShape.push_back(castIntToIndex(rewriter, loc, shapeValue)); continue; } if (inputShape[j] == 1) { // Broadcast singleton dimension Value one = rewriter.create(loc, rewriter.getIndexAttr(1)); Value isNegative = rewriter.create( loc, arith::CmpIPredicate::slt, shapeValue, zero); Value select = rewriter.create( loc, isNegative, one, castIntToIndex(rewriter, loc, shapeValue)); outShape.push_back(select); outExpr.push_back(mlir::getAffineConstantExpr(0, context)); continue; } // Non-broadcast case Value dim = getDimOp(rewriter, loc, input, j); Value isNegative = rewriter.create( loc, arith::CmpIPredicate::slt, shapeValue, zero); Value isEqual = rewriter.create( loc, arith::CmpIPredicate::eq, castIndexToInt64(rewriter, loc, dim), shapeValue); Value isValid = rewriter.create(loc, isNegative, isEqual); rewriter.create( loc, isValid, rewriter.getStringAttr( "only broadcasting singleton dimensions supported")); outShape.push_back(dim); outExpr.push_back(mlir::getAffineDimExpr(i, context)); } Value outTensor = rewriter.create( loc, getAsOpFoldResult(outShape), elementType); SmallVector indexingMaps = { AffineMap::get(broadcastToShape.size(), 0, outExpr, context), rewriter.getMultiDimIdentityMap(broadcastToShape.size())}; SmallVector iteratorTypes(broadcastToShape.size(), "parallel"); result = rewriter .create( loc, outTensor.getType(), input, outTensor, indexingMaps, iteratorTypes, [](OpBuilder &b, Location loc, ValueRange args) { b.create(loc, args[0]); }) .getResult(0); return success(); } Value torch_to_linalg::removeSizeInformation(OpBuilder &b, Location loc, Value tensor) { auto tensorType = tensor.getType().cast(); auto rank = tensorType.getRank(); SmallVector unknownSizes(rank, kUnknownSize); return b.create(loc, tensorType.clone(unknownSizes), tensor); }