//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/IR/Matchers.h" #include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; static SmallVector getIndexIntsAsOpFoldResult(OpBuilder &b, SmallVectorImpl &ints) { return llvm::to_vector<4>(llvm::map_range( ints, [&](int64_t val) -> OpFoldResult { return b.getIndexAttr(val); })); } // Helper function to get the padding tensor given the padding int values. 
// `lowPaddingInts` / `highPaddingInts` hold the static padding amounts that
// are turned into index attributes for the low / high side of the pad op;
// `pad` is the value passed to the pad op as its padding value. The result
// type is inferred by tensor::PadOp from the input type and the paddings.
Value torch_to_linalg::getPaddedTensor(
    Operation *op, OpBuilder &b, Value &input,
    SmallVectorImpl<int64_t> &lowPaddingInts,
    SmallVectorImpl<int64_t> &highPaddingInts, Value pad) {
  Location loc = op->getLoc();
  Type rankedTensorType =
      tensor::PadOp::inferResultType(input.getType().cast<RankedTensorType>(),
                                     lowPaddingInts, highPaddingInts);
  SmallVector<OpFoldResult> lowPaddings =
      getIndexIntsAsOpFoldResult(b, lowPaddingInts);
  SmallVector<OpFoldResult> highPaddings =
      getIndexIntsAsOpFoldResult(b, highPaddingInts);
  Value paddedInput =
      b.create<tensor::PadOp>(loc, rankedTensorType, input, /*low=*/lowPaddings,
                              /*high=*/highPaddings, pad);
  return paddedInput;
}

// Helper function to get the padding tensor given the padding int values.
// It's assumed that the padding on the low end and high end are the same,
// and that zero padding is required.
Value torch_to_linalg::getZeroPaddedTensor(
    Operation *op, OpBuilder &b, Value &input,
    SmallVectorImpl<int64_t> &paddingInts) {
  assert(input.getType().isa<RankedTensorType>() &&
         "input must be RankedTensorType");
  Location loc = op->getLoc();
  // The pad value is the zero attribute of the input's element type.
  Value c0 = b.create<arith::ConstantOp>(
      loc,
      b.getZeroAttr(input.getType().cast<RankedTensorType>().getElementType()));
  // The same padding list is used for both the low and the high side.
  return getPaddedTensor(op, b, input, paddingInts, paddingInts, c0);
}

// Helper function that adds dynamic padding to a tensor, ignoring unpaddedDims
// dimensions at the beginning. The high and low padding are the same, and the
// padding value is zero.
// `padding` supplies one Value per padded (trailing) dimension; each entry is
// passed through castIntToIndex. `unpaddedDims` leading dimensions receive a
// padding of zero. `unpaddedDims + padding.size()` must equal the input rank.
// The result type has every dimension marked as unknown size.
Value torch_to_linalg::getDynamicZeroPaddedTensor(
    Operation *op, OpBuilder &b, Value &input, SmallVectorImpl<Value> &padding,
    int unpaddedDims) {
  assert(input.getType().isa<RankedTensorType>() &&
         "input must be RankedTensorType");
  unsigned int inRank = input.getType().cast<RankedTensorType>().getRank();
  Location loc = op->getLoc();

  // NOTE(review): `inputDims` is never read below. It is kept because
  // getTensorSizes presumably emits ops through `b` -- confirm before removing.
  SmallVector<Value> inputDims = getTensorSizes(b, loc, input);
  Value c0 = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(0));
  // Leading `unpaddedDims` entries are zero, followed by the caller's padding.
  SmallVector<Value> paddingIncludingUnchanged(unpaddedDims, c0);
  paddingIncludingUnchanged.append(padding);
  assert(unpaddedDims + padding.size() == inRank &&
         "sum of unpaddedDims and padding.size() must equal to inputRank");
  for (Value &pad : paddingIncludingUnchanged)
    pad = castIntToIndex(b, loc, pad);

  Type elementType = input.getType().cast<RankedTensorType>().getElementType();
  // TODO: audit possibility of sparsity on this tensor
  SmallVector<int64_t> dynamicShape(inRank, kUnknownSize);
  Type inputType = RankedTensorType::get(
      makeShapeLLVMCompatible(llvm::ArrayRef<int64_t>(dynamicShape)),
      elementType);

  // Use the zero attribute of the element type (as getZeroPaddedTensor does)
  // so that non-float element types also get a well-typed zero pad value. For
  // float element types this is the same 0.0 float attribute as before.
  Value cf0 = b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
  SmallVector<OpFoldResult> paddingValues =
      getAsOpFoldResult(paddingIncludingUnchanged);
  return b.create<tensor::PadOp>(loc, inputType, input, /*low=*/paddingValues,
                                 /*high=*/paddingValues, cf0);
}

// Emits the arithmetic computing one output dimension of a convolution-like
// op:
//   div(in + 2 * padding - dilation * (kernelSize - 1) - 1, stride) + 1
// where `div` is a ceiling division when `ceilMode` is set and a floor
// division otherwise. `in` is cast with castIndexToInt64 before use and the
// result is cast back with castIntToIndex.
Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc,
                                              Value in, Value paddingInt,
                                              Value dilationInt,
                                              Value kernelSizeInt,
                                              Value strideInt, bool ceilMode) {
  Value c1 = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(1));
  Value c2 = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(2));

  Value doublePadding = b.create<arith::MulIOp>(loc, paddingInt, c2);
  // in + 2 * padding
  Value inAddDoublePadding =
      b.create<arith::AddIOp>(loc, castIndexToInt64(b, loc, in), doublePadding);

  // dilation * (kernelSize - 1)
  Value kernelSizeSub1 = b.create<arith::SubIOp>(loc, kernelSizeInt, c1);
  Value dilationTimesKernelSize =
      b.create<arith::MulIOp>(loc, dilationInt, kernelSizeSub1);

  Value temp =
      b.create<arith::SubIOp>(loc, inAddDoublePadding, dilationTimesKernelSize);
  Value dividend = b.create<arith::SubIOp>(loc, temp, c1);
  Value division;
  if (ceilMode)
    division = b.create<arith::CeilDivSIOp>(loc, dividend, strideInt);
  else
    division = b.create<arith::FloorDivSIOp>(loc, dividend, strideInt);
  Value out = b.create<arith::AddIOp>(loc, division, c1);
  return castIntToIndex(b, loc, out);
}

// Emits the arithmetic computing one output dimension of a transposed
// convolution:
//   (in - 1) * stride - 2 * padding + (kernelSize - 1) * dilation
//     + outputPadding + 1
// `in` is cast with castIndexToInt64 before use and the result is cast back
// with castIntToIndex.
Value torch_to_linalg::getOutputDimForConvTransposeOps(
    OpBuilder &b, Location loc, Value in, Value paddingInt, Value dilationInt,
    Value kernelSizeInt, Value strideInt, Value outputPaddingInt) {
  Value c1 = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(1));
  Value c2 = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(2));

  // (in - 1) * stride
  Value inStrided =
      b.create<arith::SubIOp>(loc, castIndexToInt64(b, loc, in), c1);
  inStrided = b.create<arith::MulIOp>(loc, inStrided, strideInt);

  // 2 * padding
  Value doublePadding = b.create<arith::MulIOp>(loc, paddingInt, c2);

  // (kernelSize - 1) * dilation
  Value kernelDilated = b.create<arith::SubIOp>(loc, kernelSizeInt, c1);
  kernelDilated = b.create<arith::MulIOp>(loc, kernelDilated, dilationInt);

  Value out = b.create<arith::SubIOp>(loc, inStrided, doublePadding);
  out = b.create<arith::AddIOp>(loc, out, kernelDilated);
  out = b.create<arith::AddIOp>(loc, out, outputPaddingInt);
  out = b.create<arith::AddIOp>(loc, out, c1);

  return castIntToIndex(b, loc, out);
}

// Builds a linalg.generic that reduces `opInfo.tensorOperand` over the
// dimensions in `opInfo.dimSet`. The accumulator tensor is created with
// createInitTensor from `initElem`, and `bodyBuild` populates the payload
// region. Returns the single result of the generic op.
Value torch_to_linalg::createReductionLinalgGeneric(
    OpBuilder &b, Location loc, const ReductionOpInfo &opInfo, Value initElem,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
  auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>();

  // Get the result shape by obtaining the size of each
  // dimension in the input tensor that is not getting reduced.
  // If `opInfo.keepDim` is true, the rank of the output tensor
  // is kept the same as the rank of the input tensor, and the
  // reduced dimensions are set to have size 1.
  auto c1 = b.create<arith::ConstantIndexOp>(loc, /*value=*/1);
  SmallVector<Value> resultShape;
  for (int64_t i = 0; i < inputType.getRank(); i++) {
    auto currentDimSize = b.create<tensor::DimOp>(loc, opInfo.tensorOperand, i);
    if (!opInfo.dimSet.contains(i))
      resultShape.push_back(currentDimSize);
    else if (opInfo.keepDim)
      resultShape.push_back(c1);
  }

  // Create the affine expressions that will be used to
  // iterate over the input and output tensors.
  // Here we also set the type of iterator: parallel or reduction.
  SmallVector<AffineExpr> exprs;
  SmallVector<utils::IteratorType> iteratorTypes;
  SmallVector<AffineExpr> resultExprs;
  for (auto size :
       llvm::enumerate(makeShapeTorchCompatible(inputType.getShape()))) {
    exprs.push_back(b.getAffineDimExpr(size.index()));

    if (opInfo.dimSet.contains(size.index())) {
      iteratorTypes.push_back(utils::IteratorType::reduction);
      // If `opInfo.keepDim`, create affine map to the first element
      // in the current dimension.
      if (opInfo.keepDim)
        resultExprs.push_back(b.getAffineConstantExpr(0));
    } else {
      iteratorTypes.push_back(utils::IteratorType::parallel);
      resultExprs.push_back(b.getAffineDimExpr(size.index()));
    }
  }

  auto indexingMaps = AffineMap::inferFromExprList({exprs, resultExprs});
  Value accumulator =
      createInitTensor(b, loc, resultShape, initElem.getType(), initElem);

  return b
      .create<linalg::GenericOp>(
          loc, /*resultTensorTypes=*/accumulator.getType(),
          /*inputs=*/opInfo.tensorOperand,
          /*outputs=*/accumulator, indexingMaps, iteratorTypes, bodyBuild)
      .getResult(0);
}

// Builds an all-parallel linalg.generic over `tensorOperands` whose result has
// element type `resultElementType` and rank equal to the largest operand rank.
// Statically size-1 operand dimensions are mapped to a constant-0 index
// (broadcast); all other dimensions must agree, which is checked at runtime
// unless strict symbolic shapes are assumed. `bodyBuild` populates the payload
// region. Returns the single result of the generic op.
Value torch_to_linalg::createElementwiseLinalgGeneric(
    OpBuilder &b, Location loc, ValueRange tensorOperands,
    Type resultElementType,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
  // The overall error handling strategy here is best viewed by thinking about
  // what happens for a single result dimension. This loop is not structured
  // that way because it is hard to create the affine maps for each operand
  // unless we structure the loop to iterate over tensor operands as the outer
  // loop instead of inner loop. This pseudocode gives better intuition:
  // ```
  // for each result dimension:
  //   for each tensor operand:
  //     if it doesn't even have high enough rank relative to the result:
  //       continue
  //     if it is a static size-1 along this result dimension:
  //       continue
  //     if this is the first tensor operand that didn't continue above:
  //       take its dimension size as the size of the non-broadcasted
  //       traversal along this dimension (this may include a dynamic size-1,
  //       **non-broadcasted** traversal unless
  //       isAssumingStrictSymbolicShapes!)
  //     emit error check "if the size does not match the non-broadcasted
  //     traversal size along this dimension, error"
  // ```
  SmallVector<int64_t> operandRanks;
  operandRanks.resize(tensorOperands.size());
  // Every operand is required to be a ranked tensor; `cast` asserts this
  // instead of dereferencing a null type as an unchecked `dyn_cast` would.
  llvm::transform(tensorOperands, operandRanks.begin(), [](Value tensor) {
    return tensor.getType().cast<RankedTensorType>().getRank();
  });

  auto resultRankIt =
      std::max_element(operandRanks.begin(), operandRanks.end());
  assert(resultRankIt != operandRanks.end() && "Unable to get result rank.");
  int64_t resultRank = *resultRankIt;

  // Initialize the resultShape to all 1's, as a fallback in case
  // all sizes along that result dimension are statically 1.
  auto c1 = b.create<arith::ConstantIndexOp>(loc, /*value=*/1);
  SmallVector<Value> resultShape(resultRank, c1);
  SmallVector<AffineMap> indexingMaps;
  bool elideDynamicBroadcastCheck = isAssumingStrictSymbolicShapes(b);

  for (Value tensorOperand : tensorOperands) {
    SmallVector<AffineExpr> exprs;
    auto type = tensorOperand.getType().cast<RankedTensorType>();
    for (auto size :
         llvm::enumerate(makeShapeTorchCompatible(type.getShape()))) {
      // If the size is statically known to be 1, we don't want any
      // error guards to be spuriously emitted, since we are specifically
      // allowing size-1 broadcasts in this case, as they correspond to a
      // constant-0 indexing map.
      if (size.value() == 1) {
        exprs.push_back(b.getAffineConstantExpr(0));
        continue;
      }

      // The rank of this operand might be smaller than the overall rank of
      // the broadcast. Add an offset to correlate it to the correct
      // dimension of the result.
      auto resultDim = size.index() + (resultRank - type.getRank());

      // The generated linalg op will now be iterating along the full size
      // of this dimension. Record that fact.
      exprs.push_back(b.getAffineDimExpr(resultDim));

      // Now, we need to ensure that such iteration is not going to trigger
      // undefined behavior, by doing appropriate checks against the current
      // dimension size.
      auto currentDimSize = getDimOp(b, loc, tensorOperand, size.index());

      // If the result size of this dimension has so far only hit the
      // statically-known-to-be-1 case above (i.e., we have not yet assigned a
      // new Value to `resultShape[resultDim]`), then we have no other dynamic
      // values to check against, and merely need to record the current
      // dimension size.
      if (resultShape[resultDim] == c1) {
        resultShape[resultDim] = currentDimSize;
        continue;
      }

      // We prohibit the size-1 dynamic broadcasting scenario, so just check
      // for exact equality with the running result size.
      // This is the check which protects against the undefined behavior of
      // the generated linalg op in the case of iterating two operands with
      // dimensions sizes that are expected to match.
      if (!elideDynamicBroadcastCheck) {
        auto equalToRunning =
            b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                    resultShape[resultDim], currentDimSize);
        b.create<cf::AssertOp>(loc, equalToRunning,
                               "mismatched size for broadcast");
      }
    }
    indexingMaps.push_back(AffineMap::get(
        /*dimCount=*/resultRank, /*symbolCount=*/0, exprs, b.getContext()));
  }

  SmallVector<utils::IteratorType> iteratorTypes(resultRank,
                                                 utils::IteratorType::parallel);
  // Add the indexing map for the outs init tensor.
  indexingMaps.push_back(b.getMultiDimIdentityMap(resultRank));

  Value initTensor = b.create<tensor::EmptyOp>(
      loc, getAsOpFoldResult(resultShape), resultElementType);
  return b
      .create<linalg::GenericOp>(loc,
                                 /*resultTensorTypes=*/initTensor.getType(),
                                 /*inputs=*/tensorOperands,
                                 /*outputs=*/initTensor, indexingMaps,
                                 iteratorTypes, bodyBuild)
      .getResult(0);
}

// Broadcasts input tensor based on the broadcastToShape.
LogicalResult torch_to_linalg::broadcastToGivenShape( Operation *op, PatternRewriter &rewriter, Value input, SmallVector broadcastToShape, RankedTensorType broadcastType, Value &result, SmallVector useBroadcastToShape) { RankedTensorType inputType = input.getType().cast(); int64_t inputRank = inputType.getRank(); int64_t outputRank = broadcastToShape.size(); ArrayRef outputShape = broadcastType.getShape(); SmallVector inputShape = makeShapeTorchCompatible(inputType.getShape()); if (outputRank < inputRank) { return rewriter.notifyMatchFailure( op, "invalid shape: broadcastToShape size must not be smaller than the " "size of the input shape"); } Type elementType = inputType.getElementType(); Location loc = op->getLoc(); SmallVector outShape; bool elideDynamicBroadcastCheck = isAssumingStrictSymbolicShapes(rewriter); // Vector indicating broadcasted status when assuming strict symbolic shapes. SmallVector broadcastedStatus; // Create affine map and shapes for tensor initialization. SmallVector outExpr; Value zero = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); Value zeroIndex = rewriter.create(loc, rewriter.getIndexAttr(0)); Value oneIndex = rewriter.create(loc, rewriter.getIndexAttr(1)); size_t diff = outputRank - inputRank; bool hasDynamicNumpyBroadcast = false; for (size_t i = 0, e = outputRank; i < e; i++) { Value shapeValue = broadcastToShape[i]; size_t j = i - diff; bool isDynamic = i >= diff && inputShape[j] == kUnknownSize; // Inherit static output shapes if present. 
if (outputShape[i] != ShapedType::kDynamic) { outShape.push_back(rewriter.getIndexAttr(outputShape[i])); if (i < diff) { if (outputShape[i] < 0) { return rewriter.notifyMatchFailure( op, "invalid shape: negative values not allowed in new broadcast " "dimensions"); } continue; } if (isDynamic) { hasDynamicNumpyBroadcast = true; } else if (inputShape[j] != outputShape[i] && inputShape[j] != 1) { return rewriter.notifyMatchFailure( op, "invalid shape: static mismatch in input and output broadcast " "shapes"); } // If strict symbolic shapes are assumed and the input shape is dynamic, // we can assume that dim is not broadcasted. broadcastedStatus.push_back(inputShape[j] != outputShape[i] && !isDynamic); continue; } if (i < diff) { if (!elideDynamicBroadcastCheck) { Value isValid = rewriter.create( loc, arith::CmpIPredicate::sge, shapeValue, zero); rewriter.create( loc, isValid, rewriter.getStringAttr( "negative values not allowed in new dimensions")); } outShape.push_back(castIntToIndex(rewriter, loc, shapeValue)); continue; } if (inputShape[j] == 1) { // Broadcast singleton dimension Value isNegative = rewriter.create( loc, arith::CmpIPredicate::slt, shapeValue, zero); Value select = rewriter.create( loc, isNegative, oneIndex, castIntToIndex(rewriter, loc, shapeValue)); outShape.push_back(select); broadcastedStatus.push_back(true); continue; } // Case of dynamic input dimension wherein the shape to broadcast will // yield us the dimension size of the output. 
Value dim; if (!useBroadcastToShape.empty() && useBroadcastToShape[j]) { dim = castIntToIndex(rewriter, loc, broadcastToShape[i]); if (isDynamic) { hasDynamicNumpyBroadcast = true; } if (!elideDynamicBroadcastCheck) { Value isValid = rewriter.create( loc, arith::CmpIPredicate::sge, shapeValue, zero); rewriter.create( loc, isValid, rewriter.getStringAttr( "unimplemented: dynamic negative broadcast sizes")); } } else { dim = getDimOp(rewriter, loc, input, j); } // We can safely assume this dimension is not broadcasted with strict // symbols. broadcastedStatus.push_back(false); outShape.push_back(dim); } Value outTensor = rewriter.create(loc, outShape, elementType); // If we know there are no ? -> ? broadcasted dims, or we are assuming // strict symbols, we can safely use standard linalg style broadcasting // semantics. if (!hasDynamicNumpyBroadcast || elideDynamicBroadcastCheck) { // If no dims are broadcasted and the rank doesn't change, we can just fold // the op away entirely. if (!llvm::any_of(broadcastedStatus, [](bool b) { return b; }) && inputRank == outputRank) { result = rewriter.create(loc, outTensor.getType(), input); return success(); } SmallVector inputExprs; for (int64_t i = 0, e = inputRank; i < e; ++i) { if (broadcastedStatus[i]) { inputExprs.push_back(rewriter.getAffineConstantExpr(0)); continue; } inputExprs.push_back(rewriter.getAffineDimExpr(i + diff)); } SmallVector indexingMaps = { AffineMap::get(outputRank, 0, inputExprs, rewriter.getContext()), rewriter.getMultiDimIdentityMap(outputRank)}; SmallVector iteratorTypes( outputRank, utils::IteratorType::parallel); result = rewriter .create( loc, outTensor.getType(), input, outTensor, indexingMaps, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { b.create(loc, args[0]); }) .getResult(0); return success(); } // Fall back to numpy-style dynamic broadcasting in the form of a single // linalg op. 
SmallVector indexingMaps = { rewriter.getMultiDimIdentityMap(outputRank)}; SmallVector iteratorTypes(outputRank, utils::IteratorType::parallel); result = rewriter .create( loc, outTensor.getType(), ValueRange(), outTensor, indexingMaps, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { // `loopIndices` contains IV of the linalg loops which // would be used to extract values from the input tensor // later on. SmallVector loopIndices; for (size_t i = 0, e = outputRank; i < e; ++i) { if (i < diff) continue; loopIndices.push_back(b.create(loc, i)); } // `inputIndicesToExtract` contains i-th linalg loop IV if // the i-th input dimension is not 1, else it contains a // zero index. SmallVector inputIndicesToExtract; for (size_t i = 0, n = inputRank; i < n; i++) { if (inputShape[i] == 1) { inputIndicesToExtract.push_back(zeroIndex); } else { Value inputDim = getDimOp(b, loc, input, i); Value isEqual = b.create( loc, arith::CmpIPredicate::eq, inputDim, oneIndex); Value select = rewriter.create( loc, isEqual, zeroIndex, loopIndices[i]); inputIndicesToExtract.push_back(select); } } // Extract and yield the value from input tensor at // `inputIndicesToExtract` indices. 
Value result = b.create( loc, input, inputIndicesToExtract); b.create(loc, result); }) .getResult(0); return success(); } Value torch_to_linalg::removeSizeInformation(OpBuilder &b, Location loc, Value tensor) { auto tensorType = tensor.getType().cast(); auto rank = tensorType.getRank(); SmallVector unknownSizes(rank, kUnknownSize); return b.create( loc, tensorType.clone(makeShapeLLVMCompatible(unknownSizes)), tensor); } Value torch_to_linalg::convertTensorToElementType(OpBuilder &b, Location loc, Value tensor, Type elementType) { auto dtypePromoteBody = [&](OpBuilder &builder, Location loc, ValueRange payloadArgs) { Value elem = convertScalarToDtype(builder, loc, payloadArgs[0], elementType); builder.create(loc, elem); }; return torch_to_linalg::createElementwiseLinalgGeneric( b, loc, {tensor}, elementType, dtypePromoteBody); } FailureOr torch_to_linalg::getBackendTypeForScalarType( MLIRContext *context, torch_upstream::ScalarType dtypeInt) { FailureOr maybeType = getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt); if (failed(maybeType)) { return failure(); } Type type = *maybeType; // The linalg-on-tensors backend currently expects integers to be signless. if (auto intType = type.dyn_cast()) { type = IntegerType::get(context, intType.getWidth(), IntegerType::Signless); } return type; } bool torch_to_linalg::isUnsignedTorchType(Type type) { if (auto tty = dyn_cast(type)) return isUnsignedTorchType(tty.getDtype()); if (isa(type)) return false; if (isa(type)) return false; if (isa(type)) return true; if (isa(type)) return false; if (auto intTy = dyn_cast(type)) return intTy.isUnsigned(); llvm_unreachable("Unknown type checked for signedness"); return false; }