//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "PassDetail.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSet.h" #include #include using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; // Helper function to check whether the `dtype` is None or Float type. static bool isNoneOrFloatDtype(MLIRContext *context, Value dtype) { if (dtype.getType().isa()) return true; int64_t dtypeInt; if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt))) return false; FailureOr resDtype = getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt); if (failed(resDtype)) return false; return resDtype->isa(); } // Helper function to compute the return type of the reduction function. // `dim` specifies the dimension to reduce and `keepDim` preserves the rank of // the input tensor. 
static Type computeReductionType(PatternRewriter &rewriter, Operation *op, BaseTensorType tensorType, Value dim, bool keepDim) { SmallVector sizes; int64_t dimInt; if (tensorType.hasSizes()) { ArrayRef inputShape = tensorType.getSizes(); int64_t inputRank = inputShape.size(); if (matchPattern(dim, m_TorchConstantInt(&dimInt))) { dimInt = toPositiveDim(dimInt, inputRank); if (!isValidDim(dimInt, inputRank)) { (void)rewriter.notifyMatchFailure(op, "dim is not a valid dim"); return nullptr; } sizes.append(inputShape.begin(), inputShape.end()); // The dimension to be reduced is set to 1 when `keepDim` is true else it // is removed. if (keepDim) sizes[dimInt] = 1; else sizes.erase(sizes.begin() + dimInt); } else { unsigned reducedRank = keepDim ? inputRank : inputRank - 1; sizes.resize(reducedRank, kUnknownSize); } } Type resultType = tensorType.getWithSizesAndDtype( !tensorType.hasSizes() ? std::optional>() : llvm::ArrayRef(sizes), tensorType.getOptionalDtype()); return resultType; } // Reduction function to calculate sum along given `dim`. static Value createSumAlongDimension(PatternRewriter &rewriter, Location loc, Operation *op, Value input, Value dim, bool keepDim) { Value dimList = rewriter.create( loc, Torch::ListType::get(dim.getType()), dim); Value keepDimCst = rewriter.create(loc, keepDim); Value dtype = rewriter.create(loc); Type resultType = computeReductionType( rewriter, op, input.getType().cast(), dim, keepDim); if (!resultType) return nullptr; return rewriter.create(loc, resultType, input, dimList, keepDimCst, dtype); } // Reduction function to calculate max along given `dim`. 
// Builds a max reduction of `input` along `dim` and returns its `values`
// result (the `indices` result is left unused).
//
// The values type comes from `computeReductionType`; the indices type has the
// same sizes with a signed 64-bit integer dtype. Returns a null Value when the
// result type cannot be computed.
static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc,
                                     Operation *op, Value input, Value dim,
                                     bool keepDim) {
  Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
  // Check for failure *before* casting: `cast` asserts on a null Type, which
  // would make a null check after the cast unreachable.
  Type reducedType = computeReductionType(
      rewriter, op, input.getType().cast<BaseTensorType>(), dim, keepDim);
  if (!reducedType)
    return nullptr;
  BaseTensorType valueType = reducedType.cast<BaseTensorType>();
  BaseTensorType indexType =
      valueType
          .getWithSizesAndDtype(
              !valueType.hasSizes()
                  ? std::optional<ArrayRef<int64_t>>()
                  : llvm::ArrayRef(valueType.getSizes()),
              IntegerType::get(op->getContext(), 64, IntegerType::Signed))
          .cast<BaseTensorType>();
  return rewriter
      .create<AtenMaxDimOp>(loc, valueType, indexType, input, dim, keepDimCst)
      .getValues();
}

// Helper for creating `aten::sub_tensor_op`.
// Emits `lhs - rhs` with a constant float `alpha` of 1.
static Value createTensorSub(PatternRewriter &rewriter, Location loc,
                             Type tensorType, Value lhs, Value rhs) {
  Value alpha =
      rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
  Value sub =
      rewriter.create<AtenSubTensorOp>(loc, tensorType, lhs, rhs, alpha);
  return sub;
}

// Share code between `softmax_backward` and `log_softmax_backward` ops.
// Returns x - y * sum(z, dim).
static Value createSoftmaxBackwardCommonKernel(PatternRewriter &rewriter, Location loc, Operation *op, Type tensorType, Value x, Value y, Value z, Value dim) { Value sum = createSumAlongDimension(rewriter, loc, op, z, dim, /*keepDim=*/true); if (!sum) return nullptr; auto broadcastSizeType = Torch::ListType::get(Torch::IntType::get(op->getContext())); Value broadcastSize = rewriter.create(loc, broadcastSizeType, z); Value sumBroadcast = rewriter.create(loc, tensorType, sum, broadcastSize); Value temp = rewriter.create(loc, tensorType, y, sumBroadcast); Value sub = createTensorSub(rewriter, loc, tensorType, x, temp); return sub; } static SmallVector computeDimsOrderForMoveDim(int64_t srcDimInt, int64_t dstDimInt, unsigned inputRank) { llvm::iota_range dimsOrderIR(0, inputRank, /*inclusive=*/false); SmallVector dimsOrder(dimsOrderIR.begin(), dimsOrderIR.end()); dimsOrder.erase(dimsOrder.begin() + srcDimInt); dimsOrder.insert(dimsOrder.begin() + dstDimInt, srcDimInt); return dimsOrder; } static bool rewriteEquationWithEllipsisSlicing(std::string &equation, SmallVector &inputRanks) { // split equation into input and result size_t arrowPos = equation.find("->"); if (arrowPos == std::string::npos) { return false; } std::string inputStr = equation.substr(0, arrowPos); std::string resultStr = equation.substr(arrowPos + 2); // split input into tokens SmallVector inputTokens; size_t start = 0; size_t end = 0; std::set usedTokens; while (end < inputStr.size()) { end = inputStr.find(",", start); if (end == std::string::npos) { end = inputStr.size(); } std::string token = inputStr.substr(start, end - start); inputTokens.push_back(token); start = end + 1; } if (inputTokens.size() != inputRanks.size()) { return false; } // find the rank which ellipsis represents, and max ellipsis rank because a // tensor can be broadcasted SmallVector ellipsisRanks; int maxEllipsisRank = 0; for (const auto &[token, inputRank] : llvm::zip(inputTokens, inputRanks)) { int explictRank = 0; for (auto 
c : token) { if (std::isalpha(c)) { usedTokens.insert(c); explictRank++; } else if (c == '.' || c == ' ') { continue; } else { return false; } } int ellipsisRank = inputRank - explictRank; if (ellipsisRank > maxEllipsisRank) { maxEllipsisRank = ellipsisRank; } if (ellipsisRank < 0) { return false; } ellipsisRanks.push_back(inputRank - explictRank); } auto isTokenUsed = [&usedTokens](char c) { return usedTokens.find(c) != usedTokens.end(); }; std::string ellipsisToken; int usedCount = 0; // Iterate over the alphabet to create a new token for ellipsis for (char c = 'a'; c <= 'z'; ++c) { if (!isTokenUsed(c)) { ellipsisToken.push_back(c); usedCount++; if (usedCount == maxEllipsisRank) { break; } } } // replace ellipsis with ellipsisToken for (size_t i = 0; i < inputTokens.size(); i++) { size_t ellipsisPos = inputTokens[i].find("..."); if (ellipsisPos == std::string::npos) { continue; } if (ellipsisRanks[i] == maxEllipsisRank) { inputTokens[i].replace(ellipsisPos, 3, ellipsisToken); } else if (ellipsisRanks[i] == 0) { inputTokens[i].replace(ellipsisPos, 3, ""); } else { inputTokens[i].replace( ellipsisPos, 3, ellipsisToken.substr(ellipsisToken.size() - ellipsisRanks[i])); } } // replace ellipsis in result size_t ellipsisPos = resultStr.find("..."); if (ellipsisPos != std::string::npos) { resultStr.replace(ellipsisPos, 3, ellipsisToken); } // join input and result equation = llvm::join(inputTokens, ",") + " -> " + resultStr; return true; } static bool parseEquation(const std::string &equation, SmallVector> &inputTokens, SmallVector &resultTokens) { SmallVector inputToken; size_t index = 0; enum EquationVariable { kIsInput, kIsResult }; EquationVariable currentVariable = kIsInput; while (index < equation.size()) { if (std::isalpha(equation[index])) { if (currentVariable == kIsInput) { inputToken.push_back(equation[index]); } else { resultTokens.push_back(equation[index]); } } else if (equation[index] == ',') { inputTokens.push_back(inputToken); inputToken.clear(); } else 
if ((index < (equation.size() - 1)) && (equation.substr(index, 2).find("->") != std::string::npos)) { inputTokens.push_back(inputToken); inputToken.clear(); currentVariable = kIsResult; index++; } else if (equation[index] != ' ') { return false; } index++; } return true; } // [*batchingDims, *lhsOtherDims, *lhsReduceDims, *lhsContractingDims] => // [batchingDimsProd, lhsOtherDimsProd, lhsContractingDimsProd] static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc, Value input, int64_t batchDimsLength, int64_t contractingDimsLength, int64_t otherDimsLength, int64_t reduceDimsLength, bool isLhs) { auto inputType = input.getType().cast(); auto inputRank = batchDimsLength + contractingDimsLength + otherDimsLength + reduceDimsLength; SmallVector inputShapeTensor; for (auto i = 0; i < inputRank; ++i) { inputShapeTensor.emplace_back(rewriter.create( loc, input, rewriter.create(loc, rewriter.getI64IntegerAttr(i)))); } SmallVector outShapeTensor; Value constOne = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); auto dimOffset = 0; auto appendDims = [&](int64_t dimLength) { Value prod = constOne; for (auto i = 0; i < dimLength; ++i) { prod = rewriter.create(loc, prod, inputShapeTensor[i + dimOffset]); } outShapeTensor.emplace_back(prod); dimOffset += dimLength; }; appendDims(batchDimsLength); if (!isLhs) appendDims(contractingDimsLength); appendDims(otherDimsLength + reduceDimsLength); if (isLhs) appendDims(contractingDimsLength); auto outShapeValue = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(input.getContext())), outShapeTensor); auto outType = inputType.getWithSizesAndDtype(std::nullopt, inputType.getOptionalDtype()); return rewriter.create(loc, outType, input, outShapeValue); } // classify every dim token into different categories. Note that although we // parse out reduce dims, we delay their execution until // `performLastPermuteAndReduce`. 
static void parseDimTokens( SmallVector &lhsTokens, SmallVector &rhsTokens, SmallVector &finalResultTokens, SmallVector &contractingDims, SmallVector &lhsReduceDims, SmallVector &rhsReduceDims, SmallVector &batchingDims, SmallVector &lhsOtherDims, SmallVector &rhsOtherDims) { llvm::SmallDenseSet lhsTokenSet(lhsTokens.begin(), lhsTokens.end()); llvm::SmallDenseSet rhsTokenSet(rhsTokens.begin(), rhsTokens.end()); llvm::SmallDenseSet finalResultTokenSet(finalResultTokens.begin(), finalResultTokens.end()); for (size_t i = 0; i < lhsTokens.size(); ++i) { bool rhsContains = rhsTokenSet.contains(lhsTokens[i]); bool finalResultConatins = finalResultTokenSet.contains(lhsTokens[i]); // batching dim if (rhsContains && finalResultConatins) { batchingDims.push_back(lhsTokens[i]); // reduce dim of lhs } else if (!rhsContains && !finalResultConatins) { lhsReduceDims.push_back(lhsTokens[i]); // other dim of lhs } else if (finalResultConatins) { lhsOtherDims.push_back(lhsTokens[i]); // contracting dim of lhs } else if (rhsContains) { contractingDims.push_back(lhsTokens[i]); } } for (size_t i = 0; i < rhsTokens.size(); ++i) { bool lhsContains = lhsTokenSet.contains(rhsTokens[i]); bool finalResultConatins = finalResultTokenSet.contains(rhsTokens[i]); // batching dim if (lhsContains && finalResultConatins) { // reduce dim of rhs } else if (!lhsContains && !finalResultConatins) { rhsReduceDims.push_back(rhsTokens[i]); // other dim of rhs } else if (finalResultConatins) { rhsOtherDims.push_back(rhsTokens[i]); // contracting dim of rhs } else if (lhsContains) { } } } static void generateIdealReusltDimTokens(SmallVector &batchingDims, SmallVector &lhsOtherDims, SmallVector &rhsOtherDims, SmallVector &lhsReduceDims, SmallVector &rhsReduceDims, SmallVector &resultTokens) { // generate ideal result dims, i.e., // [*batchingDims, *lhsOtherDims, *lhsReduceDims, *rhsOtherDims, // *rhsReduceDims] resultTokens.insert(resultTokens.end(), batchingDims.begin(), batchingDims.end()); 
resultTokens.insert(resultTokens.end(), lhsOtherDims.begin(), lhsOtherDims.end()); resultTokens.insert(resultTokens.end(), lhsReduceDims.begin(), lhsReduceDims.end()); resultTokens.insert(resultTokens.end(), rhsOtherDims.begin(), rhsOtherDims.end()); resultTokens.insert(resultTokens.end(), rhsReduceDims.begin(), rhsReduceDims.end()); } static Value permuteTensorForMatmul(PatternRewriter &rewriter, Location loc, Value input, SmallVector &dimTokens, SmallVector &batchingDims, SmallVector &contractingDims, SmallVector &otherDims, SmallVector &reduceDims, bool isLhs) { auto inputType = input.getType().cast(); llvm::SmallDenseMap dimTokenMap; for (size_t idx = 0; idx < dimTokens.size(); ++idx) { dimTokenMap[dimTokens[idx]] = idx; } SmallVector permuteVec; auto appendDims = [&](SmallVector dimTokens) { for (auto d : dimTokens) { permuteVec.push_back(rewriter.create( loc, rewriter.getI64IntegerAttr(dimTokenMap[d]))); } }; appendDims(batchingDims); if (!isLhs) appendDims(contractingDims); appendDims(otherDims); appendDims(reduceDims); if (isLhs) appendDims(contractingDims); Value dstDims = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(rewriter.getContext())), permuteVec); auto outType = inputType.getWithSizesAndDtype(std::nullopt, inputType.getOptionalDtype()); return rewriter.create(loc, outType, input, dstDims); } static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc, Value lhs, SmallVector &lhsTokens, Value rhs, SmallVector &rhsTokens, Value &result, SmallVector &resultTokens, SmallVector &finalResultTokens) { auto lhsType = lhs.getType().cast(); auto rhsType = rhs.getType().cast(); Type outputDType = lhsType.hasDtype() ? 
lhsType.getOptionalDtype() : rhsType.getOptionalDtype(); llvm::SmallDenseMap lhsDimShapeMap; for (size_t idx = 0; idx < lhsTokens.size(); ++idx) { char d = lhsTokens[idx]; lhsDimShapeMap[d] = rewriter.create( loc, lhs, rewriter.create(loc, rewriter.getI64IntegerAttr(idx))); } llvm::SmallDenseMap rhsDimShapeMap; for (size_t idx = 0; idx < rhsTokens.size(); ++idx) { char d = rhsTokens[idx]; rhsDimShapeMap[d] = rewriter.create( loc, rhs, rewriter.create(loc, rewriter.getI64IntegerAttr(idx))); } // parse batch, contracting, other, reduce dims of lhs and rhs SmallVector contractingDims; SmallVector lhsReduceDims; SmallVector rhsReduceDims; SmallVector lhsOtherDims; SmallVector rhsOtherDims; SmallVector batchingDims; parseDimTokens(lhsTokens, rhsTokens, finalResultTokens, contractingDims, lhsReduceDims, rhsReduceDims, batchingDims, lhsOtherDims, rhsOtherDims); llvm::SmallDenseMap outDimShapeMap; auto generateOutDimShapeMap = [&](SmallVector &dims) { for (auto d : dims) { bool lhsContains = lhsDimShapeMap.count(d) > 0; bool rhsContains = rhsDimShapeMap.count(d) > 0; if (lhsContains && rhsContains) { outDimShapeMap[d] = rewriter.create( loc, lhsDimShapeMap[d], rhsDimShapeMap[d]); } else if (lhsContains) { outDimShapeMap[d] = lhsDimShapeMap[d]; } else if (rhsContains) { outDimShapeMap[d] = rhsDimShapeMap[d]; } } }; generateOutDimShapeMap(contractingDims); generateOutDimShapeMap(batchingDims); generateOutDimShapeMap(lhsReduceDims); generateOutDimShapeMap(rhsReduceDims); generateOutDimShapeMap(lhsOtherDims); generateOutDimShapeMap(rhsOtherDims); if (contractingDims.size() == 0 && lhsOtherDims.size() == 0 && rhsOtherDims.size() == 0) { return rewriter.notifyMatchFailure( loc, "Hadamard product is currently not supported"); } // shape: [*batchingDims, *lhsOtherDims, *lhsReduceDims, *lhsContractingDims] lhs = permuteTensorForMatmul(rewriter, loc, lhs, lhsTokens, batchingDims, contractingDims, lhsOtherDims, lhsReduceDims, true); // shape: [*batchingDims, *rhsContractingDims, 
*rhsOtherDims, *rhsReduceDims] rhs = permuteTensorForMatmul(rewriter, loc, rhs, rhsTokens, batchingDims, contractingDims, rhsOtherDims, rhsReduceDims, false); // shape: [batchingDimsProd, lhsOtherDimsProd, lhsContractingDimsProd] lhs = collapseDimForMatmul(rewriter, loc, lhs, batchingDims.size(), contractingDims.size(), lhsOtherDims.size(), lhsReduceDims.size(), true); // shape: [batchingDimsProd, rhsContractingDimsProd, rhsOtherDimsProd] rhs = collapseDimForMatmul(rewriter, loc, rhs, batchingDims.size(), contractingDims.size(), rhsOtherDims.size(), rhsReduceDims.size(), false); // perform matmul auto outType = lhsType.getWithSizesAndDtype(std::nullopt, outputDType); result = rewriter.create(loc, outType, lhs, rhs); // generate ideal result dims. generateIdealReusltDimTokens(batchingDims, lhsOtherDims, rhsOtherDims, lhsReduceDims, rhsReduceDims, resultTokens); // reshape matmul result to ideal shape: // [batchingDimsProd, lhsOtherDimsProd, rhsOtherDimsProd] => // [*batchingDims, *lhsOtherDims, *lhsReduceDims, *rhsOtherDims, // *rhsReduceDims] SmallVector outShapeTensors; for (char d : resultTokens) { outShapeTensors.emplace_back(outDimShapeMap[d]); } auto outResultShape = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(lhs.getContext())), outShapeTensors); result = rewriter.create( loc, lhsType.getWithSizesAndDtype(std::nullopt, outputDType), result, outResultShape); return success(); } static Value performLastReduceAndPermute(PatternRewriter &rewriter, Location loc, Type outType, Value input, SmallVector &inputTokens, SmallVector &outTokens) { auto inputType = input.getType().cast(); llvm::SmallDenseSet outTokenSet(outTokens.begin(), outTokens.end()); SmallVector sumDims; llvm::SmallDenseMap inputDimToIdx; int64_t idx = 0; for (size_t i = 0; i < inputTokens.size(); ++i) { char d = inputTokens[i]; if (!outTokenSet.contains(d)) { sumDims.emplace_back(i); } else { inputDimToIdx[d] = idx++; } } if (sumDims.size() > 0) { SmallVector sumDimsTensor; for 
(auto d : sumDims) { sumDimsTensor.emplace_back(rewriter.create( loc, rewriter.getI64IntegerAttr(d))); } auto sumDimsListValue = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(rewriter.getContext())), sumDimsTensor); auto falseValue = rewriter.create( loc, rewriter.getBoolAttr(false)); auto noneValue = rewriter.create(loc); input = rewriter.create( loc, inputType.getWithSizesAndDtype(std::nullopt, inputType.getOptionalDtype()), input, sumDimsListValue, falseValue, noneValue); } SmallVector permuteDimsTensor; for (auto d : outTokens) { permuteDimsTensor.emplace_back(rewriter.create( loc, rewriter.getI64IntegerAttr(inputDimToIdx[d]))); } auto permuteDimsListValue = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(input.getContext())), permuteDimsTensor); auto out = rewriter.create(loc, outType, input, permuteDimsListValue); return out; } namespace { /// We decompose aten.amax into a set of aten.max.dim op(s) depending on the /// number of dimensions across which the max needs to be computed. /// Eg: /// INPUT: /// final_output = aten.amax(initial_input, dim=(0, 2, 1), keepdim=False) /// /// OUTPUT: /// input_1 = aten.max.dim(initial_input, 2, keepdim) #1 /// input_2 = aten.max.dim(input_1, 1, keepdim) #2 /// final_output = aten.max.dim(input_2, 0, keepdim) #3 /// /// NOTE: We iterate over, in reverse order, every dimension included in `dim` /// of the `aten.amax` op and create an `aten.amax.dim` op. /// Input tensor to the next `aten.amax.dim` op is thus the output of the /// previous `aten.amax.dim` op. 
class DecomposeAtenAmaxOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenAmaxOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); SmallVector dims; if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dims))) return rewriter.notifyMatchFailure(op, "non-const dim parameter unsupported"); bool keepDim; if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) return rewriter.notifyMatchFailure( op, "Expected a constant boolean value for keepDim"); Value input = op.getSelf(); auto inputTy = input.getType().dyn_cast(); if (!inputTy || !inputTy.hasSizes()) { return rewriter.notifyMatchFailure(op, "Expected input type having sizes"); } // For every dimension included in `dim` of the op, iterated over in // reverse order, we create a call to aten.max.dim. std::sort(dims.rbegin(), dims.rend()); for (int64_t dimInt : dims) { int64_t inputRank = inputTy.getSizes().size(); dimInt = toPositiveDim(dimInt, inputRank); if (!isValidDim(dimInt, inputRank)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); Value dim = rewriter.create( loc, rewriter.getI64IntegerAttr(dimInt)); // The input to the next invocation of aten.max.dim is the output of the // previous aten.max.dim op. 
input = createMaxAlongDimension(rewriter, loc, op, input, dim, keepDim); } rewriter.replaceOp(op, input); return success(); } }; } // end namespace namespace { class DecomposeAtenTriuOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenTriuOp op, PatternRewriter &rewriter) const override { MLIRContext *context = op.getContext(); Location loc = op.getLoc(); Value input = op.getSelf(); auto inputType = input.getType().cast(); if (!inputType.hasSizes() || !inputType.hasDtype()) { return rewriter.notifyMatchFailure(op, "should have shape and dtype"); } if (inputType.getSizes().size() < 2) { return rewriter.notifyMatchFailure(op, "the rank of tensor should >= 2"); } auto baseType = ValueTensorType::getWithLeastStaticInformation(context); Value cstZero = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); Value cstOne = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); Value none = rewriter.create(loc); Value rowDim = rewriter.create( loc, rewriter.getI64IntegerAttr(-2)); Value colDim = rewriter.create( loc, rewriter.getI64IntegerAttr(-1)); Value rowSize = rewriter.create(loc, input, rowDim); Value colSize = rewriter.create(loc, input, colDim); Value rowArange = rewriter.create( loc, baseType, rowSize, /*dtype=*/none, /*layout=*/none, /*device=*/none, /*pin_memory=*/none); Value colArange = rewriter.create( loc, baseType, colSize, /*dtype=*/none, /*layout=*/none, /*device=*/none, /*pin_memory=*/none); Value unsqueezeRowArange = rewriter.create(loc, baseType, rowArange, cstOne); Value unsqueezeColArange = rewriter.create(loc, baseType, colArange, cstZero); Value unsqueezeRowArangePlusDiagonal = rewriter.create( loc, baseType, unsqueezeRowArange, op.getDiagonal(), cstOne); Value condTensor = rewriter.create( loc, baseType, unsqueezeColArange, unsqueezeRowArangePlusDiagonal); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), condTensor, input, cstZero); return success(); } }; } // namespace namespace 
{ class DecomposeAtenSizeOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenSizeOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value self = op.getSelf(); MLIRContext *context = op.getContext(); std::optional maybeRank = getTensorRank(self); if (!maybeRank) return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor"); unsigned rank = *maybeRank; SmallVector sizes; for (unsigned i = 0; i < rank; i++) { Value dim = rewriter.create( loc, rewriter.getI64IntegerAttr(i)); sizes.push_back(rewriter.create(loc, self, dim)); } Value sizeList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(context)), sizes); rewriter.replaceOp(op, sizeList); return success(); } }; } // namespace namespace { class DecomposeAtenSelectIntOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenSelectIntOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value start = op.getIndex(); Value dim = op.getDim(); Value self = op.getSelf(); auto resultTy = op.getType().cast(); if (!resultTy.hasSizes() || !resultTy.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected result type to have sizes and dtype"); } // convert `start` to non-negative: start += int(start < 0) * dimSize Value zero = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); Value isNegative = rewriter.create(loc, start, zero); isNegative = rewriter.create(loc, isNegative); Value dimSize = rewriter.create(loc, self, dim); Value indexOffset = rewriter.create(loc, isNegative, dimSize); start = rewriter.create(loc, start, indexOffset); Value one = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); Value startPlusOne = rewriter.create(loc, one.getType(), start, one); Value slice = rewriter.create( loc, computeReductionType(rewriter, op, self.getType().cast(), dim, /*keepDim=*/true), op.getSelf(), dim, start, startPlusOne, 
/*step=*/one); auto sliceTy = cast(slice.getType()); if (sliceTy.getSizes().size() == resultTy.getSizes().size()) { rewriter.replaceOp(op, slice); return success(); } // `aten.slice.tensor` doesn't squeeze the dim even when it's size 1 after // slicing, while `aten.select.int` does. rewriter.replaceOpWithNewOp(op, op.getResult().getType(), slice, op.getDim()); return success(); } }; } // namespace namespace { class DecomposePrimTolistOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(PrimTolistOp op, PatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto self = op.getOperands()[0]; auto selfTy = dyn_cast(self.getType()); if (!selfTy || !selfTy.hasSizes()) return rewriter.notifyMatchFailure(op, "Unknown self shape"); int64_t rank = selfTy.getSizes().size(); if (rank != 1) return rewriter.notifyMatchFailure(op, "Expected rank-1"); int64_t length = selfTy.getSizes().back(); if (length == Torch::kUnknownSize) return rewriter.notifyMatchFailure(op, "Tolist length is unknown"); auto resultTy = dyn_cast(op.getType(0)); if (!resultTy) return rewriter.notifyMatchFailure(op, "Result type is not list"); auto scalarTy = resultTy.getContainedType(); Value zero = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); auto extractTy = rewriter.getType( llvm::SmallVector{1}, selfTy.getOptionalDtype()); llvm::SmallVector results; llvm::SmallVector sizes(selfTy.getSizes()); for (int64_t i = 0; i < length; ++i) { Value iv = rewriter.create(loc, rewriter.getI64IntegerAttr(i)); Value extract = rewriter.create( loc, extractTy, self, /*dim=*/zero, /*index=*/iv); Value scalar = rewriter.create(loc, scalarTy, extract); results.push_back(scalar); } rewriter.replaceOpWithNewOp(op, resultTy, results); return failure(); } }; } // namespace namespace { class DecomposeAtenSplitSizesOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenSplitSizesOp op, PatternRewriter 
&rewriter) const override { rewriter.replaceOpWithNewOp( op, op->getResultTypes(), op.getSelf(), op.getSplitSize(), op.getDim()); return success(); } }; } // namespace namespace { class DecomposeAtenSplitWithSizesOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenSplitWithSizesOp op, PatternRewriter &rewriter) const override { auto loc = op.getLoc(); Value self = op.getSelf(); SmallVector splitSizes; if (!getListConstructElements(op.getSplitSizes(), splitSizes)) return rewriter.notifyMatchFailure(op, "Unable to get sizes"); if (splitSizes.empty()) return rewriter.notifyMatchFailure(op, "No split sizes"); auto selfTy = dyn_cast(self.getType()); if (!selfTy || !selfTy.hasSizes()) return rewriter.notifyMatchFailure(op, "Self shape unknown"); int64_t rank = selfTy.getSizes().size(); auto resultTy = dyn_cast(op.getResult().getType()); if (!resultTy) return rewriter.notifyMatchFailure(op, "Result type not a list"); auto sliceTy = dyn_cast_or_null(resultTy.getContainedType()); if (!isa(sliceTy)) return rewriter.notifyMatchFailure(op, "Slice type is unknown"); int64_t dimInt = 0; bool hasDim = matchPattern(op.getDim(), m_TorchConstantInt(&dimInt)); if (dimInt < 0) dimInt += rank; auto intTy = rewriter.getType(); Value one = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); Value begin = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); llvm::SmallVector slices; llvm::SmallVector sliceSizes(sliceTy.getSizes()); int64_t defaultLength = !hasDim ? 
Torch::kUnknownSize : sliceSizes[dimInt]; for (auto size : splitSizes) { Value end = rewriter.create(loc, intTy, begin, size); int64_t sizeInt; if (hasDim && matchPattern(size, m_TorchConstantInt(&sizeInt))) { sliceSizes[dimInt] = sizeInt; } else if (hasDim) { sliceSizes[dimInt] = defaultLength; } sliceTy = rewriter.getType(sliceSizes, sliceTy.getOptionalDtype()); Value slice = rewriter.create( loc, sliceTy, op.getSelf(), /*dim=*/op.getDim(), /*start=*/begin, /*end=*/end, /*step=*/one); slices.push_back(slice); begin = end; } rewriter.replaceOpWithNewOp(op, resultTy, slices); return success(); } }; } // namespace namespace { class DecomposeAtenNarrowOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenNarrowOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value start = op.getStart(); Value dim = op.getDim(); Value length = op.getLength(); Value one = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); Value startPlusLength = rewriter.create(loc, one.getType(), start, length); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), op.getSelf(), /*dim=*/dim, /*start=*/start, /*end=*/startPlusLength, /*step=*/one); return success(); } }; } // namespace namespace { // Decompose `aten.narrow.Tensor` to `aten.narrow` op class DecomposeAtenNarrowTensorOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenNarrowTensorOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); auto *context = op.getContext(); // PyTorch makes sure that `start` param is an 0-dim integral tensor. // REF: https://pytorch.org/docs/stable/generated/torch.narrow.html. 
auto start = rewriter.create( loc, Torch::IntType::get(context), op.getStart()); rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), op.getDim(), start, op.getLength()); return success(); } }; } // namespace namespace { class DecomposeAtenGluOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenGluOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value self = op.getSelf(); Value dim = op.getDim(); auto outputTy = op.getType().dyn_cast(); if (!outputTy || !outputTy.hasSizes() || !outputTy.hasDtype()) { return rewriter.notifyMatchFailure( op, "Expected output type having sizes and dtype"); } Value zero = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); Value dimSize = rewriter.create(loc, self, dim); Value two = rewriter.create(loc, rewriter.getI64IntegerAttr(2)); Value remainder = rewriter.create(loc, dimSize, two); Value eqOrNot = rewriter.create(loc, remainder, zero); rewriter.create( loc, eqOrNot, rewriter.getStringAttr("AtenGluOp's dim size must be multiple of 2")); Value splitLength = rewriter.create(loc, dimSize, two); Value a = rewriter.create(loc, outputTy, self, dim, zero, splitLength); Value b = rewriter.create(loc, outputTy, self, dim, splitLength, splitLength); // a⊗σ(b) Value sigmoidB = rewriter.create(loc, outputTy, b); Value result = rewriter.create(loc, outputTy, a, sigmoidB); rewriter.replaceOp(op, result); return success(); } }; } // namespace namespace { class DecomposeAtenZeroOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenZeroOp op, PatternRewriter &rewriter) const override { Value zero = rewriter.create(op.getLoc(), rewriter.getI64IntegerAttr(0)); rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), zero); return success(); } }; } // namespace namespace { class DecomposeAtenEyeOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult 
matchAndRewrite(AtenEyeOp op, PatternRewriter &rewriter) const override { Value n = op.getN(); Value m = op.getN(); rewriter.replaceOpWithNewOp(op, op.getType(), n, m, op.getDtype(), op.getLayout(), op.getDevice(), op.getPinMemory()); return success(); } }; } // namespace namespace { class DecomposeAtenEyeMOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenEyeMOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); int64_t n; if (!matchPattern(op.getN(), m_TorchConstantInt(&n))) return rewriter.notifyMatchFailure(op, "unimplemented: n must be constant"); int64_t m; if (!matchPattern(op.getM(), m_TorchConstantInt(&m))) return rewriter.notifyMatchFailure(op, "unimplemented: m must be constant"); Value none = rewriter.create(loc); auto outType = op.getType().dyn_cast(); if (!outType) return rewriter.notifyMatchFailure( op, "Only tensor types input are currently supported"); if (!outType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } if (n < 0) { return rewriter.notifyMatchFailure(op, "n must be greater or equal to 0"); } if (m < 0) { return rewriter.notifyMatchFailure(op, "m must be greater or equal to 0"); } auto context = op.getContext(); auto int64Dtype = getDtypeIntValueForType( rewriter, loc, rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true)); auto si64Type = IntegerType::get(context, 64, IntegerType::Signed); auto arangeType = outType.getWithSizesAndDtype(llvm::ArrayRef(n), si64Type); Value rangeN = rewriter.create( loc, arangeType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none, /*device=*/op.getDevice(), /*pin_memory=*/none); auto arangeType1 = outType.getWithSizesAndDtype(llvm::ArrayRef(m), si64Type); Value rangeM = rewriter.create( loc, arangeType1, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none, /*device=*/none, /*pin_memory=*/none); Value constMinusOne = rewriter.create( loc, rewriter.getI64IntegerAttr(-1)); auto 
unsqzTensorInfo = unsqueezeTensor(rewriter, op, rangeN, /*dim=*/constMinusOne); if (failed(unsqzTensorInfo)) { return rewriter.notifyMatchFailure(op, "cannot generate unsqueeze tensor"); } Value unsqzRangeN = *unsqzTensorInfo; // compare unsqueezed input with boundaries auto eqType = ValueTensorType::get( context, op.getType().cast().getSizes(), IntegerType::get(context, 1)); Value eqTensor = rewriter.create(loc, eqType, unsqzRangeN, rangeM); Value dtype = op.getDtype(); if (dtype.getType().isa()) { rewriter.replaceOp(op, eqTensor); return success(); } else { auto zero = rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); auto one = rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); Value outTensor = rewriter.create(loc, outType, eqTensor, one, zero); rewriter.replaceOp(op, outTensor); return success(); } } }; } // namespace namespace { class DecomposeAtenIsnanOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenIsnanOp op, PatternRewriter &rewriter) const override { Value input = op.getSelf(); // Create a new aten.ne operation with the same type and input value. 
rewriter.replaceOpWithNewOp(op, op.getType(), input, input); return success(); } }; } // namespace namespace { class DecomposeAtenIsinfOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenIsinfOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value self = op.getSelf(); mlir::FloatType f64Type = rewriter.getF64Type(); Value inf = rewriter.create( loc, rewriter.getFloatAttr( f64Type, APFloat::getInf(f64Type.getFloatSemantics()))); Value abs = rewriter.create(loc, self.getType(), self); rewriter.replaceOpWithNewOp(op, op.getType(), abs, inf); return success(); } }; } // namespace namespace { class DecomposeAtenIsneginfOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenIsneginfOp op, PatternRewriter &rewriter) const override { mlir::FloatType f64Type = rewriter.getF64Type(); Value inf = rewriter.create( op.getLoc(), rewriter.getFloatAttr( f64Type, APFloat::getInf(f64Type.getFloatSemantics(), true))); rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), inf); return success(); } }; } // namespace namespace { class DecomposeAtenIsposinfOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenIsposinfOp op, PatternRewriter &rewriter) const override { mlir::FloatType f64Type = rewriter.getF64Type(); Value inf = rewriter.create( op.getLoc(), rewriter.getFloatAttr(f64Type, APFloat::getInf(f64Type.getFloatSemantics()))); rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), inf); return success(); } }; } // namespace namespace { class DecomposeAtenReshapeOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenReshapeOp op, PatternRewriter &rewriter) const override { Value input = op.getSelf(); // TODO: Handle non value tensor type operands. 
if (!input.getType().isa()) { return rewriter.notifyMatchFailure( op, "unimplemented: only value tensor type operands are supported"); } rewriter.replaceOpWithNewOp(op, op.getType(), input, op.getShape()); return success(); } }; } // namespace namespace { // Decompose AtenEinsumOp to AtenMatmulOp, and supports possible reduce // operation and permute operation. Currently, this pass doesn't support // Hadamard product. The basic idea is that: // Step 1: split the string equation to input/result tokens and find // batchingDims, contractingDims, otherDims and reduceDims. // Step 2: permute and reshape input tensors suitable // for matmul operations. // Step 3: use AtenMatmulOp to get the result. // Step 4: iteratively execute step 2 & 3 until we get the final result. // Step 5: perform remaining permute and reduce operations. // notice: support static shape only class DecomposeAtenEinsumOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenEinsumOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); SmallVector inputTensors; if (!getListConstructElements(op.getTensors(), inputTensors)) { return rewriter.notifyMatchFailure( op, "input should comes from a PrimListConstructOp"); } auto allTensorHasSizes = [](Value tensor) { auto type = tensor.getType().dyn_cast(); if (!type || !type.hasSizes()) return false; return true; }; if (!llvm::all_of(inputTensors, allTensorHasSizes)) { return rewriter.notifyMatchFailure(op, "all input tensors should have sizes"); } std::string equation; if (!matchPattern(op.getEquation(), m_TorchConstantStr(equation))) { return rewriter.notifyMatchFailure(op, "Unsupported value of equation"); } // if "..." 
in equation, modify it if (equation.find("...") != std::string::npos) { SmallVector inputRanks; for (Value tensor : inputTensors) { auto type = tensor.getType().cast(); inputRanks.push_back(type.getSizes().size()); } if (!rewriteEquationWithEllipsisSlicing(equation, inputRanks)) { return rewriter.notifyMatchFailure( op, "Unexpected character in equations encountered"); } } SmallVector resultTokens; SmallVector> inputTokens; if (!parseEquation(equation, inputTokens, resultTokens)) { return rewriter.notifyMatchFailure( op, "Unexpected character in equations encountered"); } SmallVector lhsTokens = inputTokens[0]; Value lhs = inputTensors[0]; Value result; for (size_t i = 1; i < inputTensors.size(); ++i) { auto rhs = inputTensors[i]; auto rhsTokens = inputTokens[i]; SmallVector outTokens; if (failed(performMatmul(rewriter, loc, lhs, lhsTokens, rhs, rhsTokens, result, outTokens, resultTokens))) { return failure(); } lhs = result; lhsTokens = outTokens; } result = performLastReduceAndPermute(rewriter, loc, op.getType(), lhs, lhsTokens, resultTokens); rewriter.replaceOp(op, result); return success(); } }; } // namespace namespace { // Calculate the trace of the input tensor as the sum over its diagonal // elements. This computation is performed as: // // Step1: Obtain the diagonal using AtenDiagonalOp // Step2: Compute the trace using AtenSumOp. // // It is verified that the input tensor has rank two. 
class DecomposeAtenTraceOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenTraceOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value self = op.getSelf(); std::optional inRank = getTensorRank(self); if (inRank != 2) return rewriter.notifyMatchFailure( op, "Expected input tensor to have rank 2."); Value none = rewriter.create(loc); Value zero = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); Value one = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); BaseTensorType inputType = self.getType().cast(); Value output = op.getResult(); BaseTensorType outputType = output.getType().cast(); ArrayRef inputShape = inputType.getSizes(); int64_t diagonalSize = std::min(inputShape[0], inputShape[1]); SmallVector diagonalShape{diagonalSize}; Type elementType = inputType.getOptionalDtype(); Type diagonalType = inputType.getWithSizesAndDtype( llvm::ArrayRef(diagonalShape), elementType); Value diagonal = rewriter.create( loc, diagonalType, /*input=*/self, /*offset=*/zero, /*dim1=*/zero, /*dim2=*/one); Value sum = rewriter.create(loc, outputType, /*self=*/diagonal, /*dtype=*/none); rewriter.replaceOp(op, sum); return success(); } }; } // namespace // Calculates the softmax function on the given `input` tensor. Softmax(x) = // exp(x)/sum(exp(x)). 
// To avoid overflow we use the following decomposition rule: // x_max = max(input, dim, keepdim = True) // unnorm = aten.exp(input - x_max) // softmax = unnorm / sum(unnorm, dim, keepdim = True) template static Value getSoftmaxResult(OpTy op, Value self, Type resultType, Type accumulatorType, PatternRewriter &rewriter) { Location loc = op.getLoc(); Value dim = op.getDim(); if (resultType != accumulatorType) self = convertTensorToDtype(rewriter, loc, self, accumulatorType); Value xMax = createMaxAlongDimension(rewriter, loc, op, self, dim, /*keepDim=*/true); if (!xMax) return nullptr; Value unNormalized = createTensorSub(rewriter, loc, self.getType(), self, xMax); Value unNormalizedExp = rewriter.create(loc, self.getType(), unNormalized); Value sum = createSumAlongDimension(rewriter, loc, op, unNormalizedExp, dim, /*keepDim=*/true); if (!sum) return nullptr; Value result = rewriter.create(loc, self.getType(), unNormalizedExp, sum); if (resultType != accumulatorType) result = convertTensorToDtype(rewriter, loc, result, cast(resultType).getDtype()); return result; } // Decompose softmax into: exp(x) / sum(exp(x)) namespace { class DecomposeAtenSoftmaxIntOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenSoftmaxIntOp op, PatternRewriter &rewriter) const override { Value self = op.getSelf(); BaseTensorType resultTensorType = op.getType().cast(); if (!resultTensorType.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected result type to have a dtype"); } Type resultTensorDtype = resultTensorType.getDtype(); if (!isa(resultTensorDtype)) return rewriter.notifyMatchFailure(op, "Only support floating-point type"); // If `dtype` arg is non-none then convert the input to `dtype`. 
if (!op.getDtype().getType().isa()) { Location loc = op.getLoc(); Value none = rewriter.create(loc); Value cstFalse = rewriter.create(loc, false); self = rewriter.create( loc, resultTensorType, self, getDtypeIntValueForType(rewriter, loc, resultTensorDtype), /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none); } Type accumulatorTensorType = getDefaultAccType(rewriter, resultTensorDtype); Value result = getSoftmaxResult(op, self, resultTensorType, accumulatorTensorType, rewriter); if (!result) return failure(); rewriter.replaceOpWithNewOp(op, op.getType(), result); return success(); } }; } // namespace namespace { class DecomposeAten_SoftmaxOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Aten_SoftmaxOp op, PatternRewriter &rewriter) const override { Value self = op.getSelf(); BaseTensorType tensorType = self.getType().cast(); if (!tensorType.hasDtype() || !tensorType.getDtype().isa()) return rewriter.notifyMatchFailure(op, "Only support floating type"); bool halfToFloat; if (!matchPattern(op.getHalfToFloat(), m_TorchConstantBool(&halfToFloat))) return rewriter.notifyMatchFailure( op, "Expected a boolean value for half_to_float"); BaseTensorType resultTensorType = op.getType().cast(); if (!resultTensorType.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected result type to have a dtype"); } Type resultTensorDtype = resultTensorType.getDtype(); // `torch.ops.aten._softmax`'s softmax with half to float conversion is not // supported on CPU, but we go ahead with the decomposing. // TODO: Add an e2e test once upstream support is added. // If `half_to_float` is set, we convert the input's elemental type to match // that of output's. 
if (halfToFloat) { Location loc = op.getLoc(); Value none = rewriter.create(loc); Value cstFalse = rewriter.create(loc, false); self = rewriter.create( loc, resultTensorType, self, getDtypeIntValueForType(rewriter, loc, resultTensorDtype), /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none); } Type accumulatorTensorType = getDefaultAccType(rewriter, resultTensorDtype); Value result = getSoftmaxResult(op, self, resultTensorType, accumulatorTensorType, rewriter); if (!result) return op.emitError("failed to get softmax result"); rewriter.replaceOpWithNewOp(op, resultTensorType, result); return success(); } }; } // namespace // Aten_SoftmaxBackwardDataOp(gradOutput, output, dim) => // newGrad = gradOutput * output // result = newGrad - output * sum(newGrad, dim)) // // Refer to // https://github.com/pytorch/pytorch/blob/15fecc4c830a3907fde4b44c9962dc4144da50a4/torch/csrc/jit/codegen/cuda/ops/normalization.cpp#L31 namespace { class DecomposeAten_SoftmaxBackwardDataOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Aten_SoftmaxBackwardDataOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value gradOutput = op.getGradOutput(); Value output = op.getOutput(); Value dim = op.getDim(); BaseTensorType tensorType = gradOutput.getType().cast(); if (!tensorType.hasDtype() || !tensorType.getDtype().isa()) return rewriter.notifyMatchFailure(op, "Only support floating type"); Value newGrad = rewriter.create(loc, tensorType, gradOutput, output); Value result = createSoftmaxBackwardCommonKernel( rewriter, loc, op, tensorType, newGrad, output, newGrad, dim); if (!result) return rewriter.notifyMatchFailure( op, "nullptr returned by createSoftmaxBackwardCommonKernel function."); rewriter.replaceOp(op, result); return success(); } }; } // namespace // AtenTanhBackwardOp(gradOutput, output) => // result = gradOutput * (1 - output^2) // To get away from broadcasts the above formula is 
expanded i.e., // result = gradOutput - (gradOutput * output^2) namespace { class DecomposeAtenTanhBackwardOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenTanhBackwardOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value gradOutput = op.getGradOutput(); // `output` is the value flowing out from tanh. Hence, tanh(x) = output. // Since, dTanh(x) = (1 - tanh(x)^2) hence, dOutput = (1 - output^2). Value output = op.getOutput(); BaseTensorType tensorType = gradOutput.getType().cast(); if (!tensorType.hasDtype() || !tensorType.getDtype().isa()) return rewriter.notifyMatchFailure(op, "Only support floating type"); Value tanhSquare = rewriter.create(loc, tensorType, output, output); Value gradMulTanhSquare = rewriter.create( loc, tensorType, tanhSquare, gradOutput); Value newGrad = createTensorSub(rewriter, loc, tensorType, gradOutput, gradMulTanhSquare); rewriter.replaceOp(op, newGrad); return success(); } }; } // namespace // Aten_LogSoftmaxBackwardDataOp(gradOutput, output, dim) => // result = gradOutput - (exp(output) * sum(gradOutput, dim)) namespace { class DecomposeAten_LogSoftmaxBackwardDataOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Aten_LogSoftmaxBackwardDataOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value gradOutput = op.getGradOutput(); Value output = op.getOutput(); Value dim = op.getDim(); BaseTensorType tensorType = gradOutput.getType().cast(); if (!tensorType.hasDtype() || !tensorType.getDtype().isa()) return rewriter.notifyMatchFailure(op, "Only support floating type"); Value expOut = rewriter.create(loc, tensorType, output); Value result = createSoftmaxBackwardCommonKernel( rewriter, loc, op, tensorType, gradOutput, expOut, gradOutput, dim); if (!result) return rewriter.notifyMatchFailure( op, "nullptr returned by createSoftmaxBackwardCommonKernel 
function."); rewriter.replaceOp(op, result); return success(); } }; } // namespace namespace { class DecomposeAtenAMinMaxOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Torch::AtenAminOp op, PatternRewriter &rewriter) const override { llvm::SmallVector dimList; if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dimList))) { return rewriter.notifyMatchFailure(op, "dims not foldable constants"); } bool keepdim; if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepdim))) { return rewriter.notifyMatchFailure(op, "keepdims not foldable constants"); } auto loc = op.getLoc(); std::sort(dimList.begin(), dimList.end(), std::greater()); Value reduction = op.getSelf(); auto resultTy = cast(op.getType()); auto reductionTy = cast(reduction.getType()); llvm::SmallVector reductionShape(reductionTy.getSizes()); for (auto dim : dimList) { auto dimValue = rewriter.create( loc, rewriter.getI64IntegerAttr(dim)); reductionShape[dim] = 1; if (!keepdim) { for (int i = dim, s = reductionShape.size() - 1; i < s; ++i) reductionShape[i] = reductionShape[i + 1]; reductionShape.resize(reductionShape.size() - 1); } reductionTy = rewriter.getType( reductionShape, resultTy.getOptionalDtype()); auto idxTy = rewriter.getType( reductionShape, rewriter.getIntegerType(32, /*is_signed*/ true)); llvm::SmallVector types{reductionTy, idxTy}; reduction = rewriter .create(loc, types, reduction, dimValue, op.getKeepdim()) .getResult(0); } rewriter.replaceOp(op, reduction); return success(); } }; } // namespace // Decompose `AtenArgMaxOp` into `AtenMaxDimOp` as well as `AtenArgMinOp` into // `AtenMinDimOp` namespace { template class DecomposeAtenArgMinMaxOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getSelf(); Value dim = op.getDim(); Value keepDim = op.getKeepdim(); Value 
result = op.getResult(); BaseTensorType inputType = input.getType().cast(); BaseTensorType indicesTensorType = result.getType().cast(); std::optional maybeInputRank = getTensorRank(input); if (!maybeInputRank) { return rewriter.notifyMatchFailure( op, "expected input tensor to have a rank"); } unsigned inputRank = *maybeInputRank; if (!indicesTensorType.hasSizes()) return failure(); BaseTensorType valueTensorType = inputType .getWithSizesAndDtype(indicesTensorType.getOptionalSizes(), inputType.getOptionalDtype()) .cast(); // If the dim type is `NoneType` i.e. reduce along all the dimensions. // `AtenMaxDimOp` and `AtenMinDimOp` do not support dim as `NoneType` so // first the input tensor is flattened to 1d tensor and then the reduction // happens on the 0th dimension. if (dim.getType().isa()) { BaseTensorType flattenType = inputType .getWithSizesAndDtype({kUnknownSize}, inputType.getOptionalDtype()) .cast(); dim = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); Value end = rewriter.create( loc, rewriter.getI64IntegerAttr(inputRank - 1)); input = rewriter.create(loc, flattenType, input, dim, end); } Value resultArg = rewriter .create(loc, valueTensorType, indicesTensorType, input, dim, keepDim) .getIndices(); rewriter.replaceOp(op, resultArg); return success(); } }; } // namespace // Decompose `aten.bucketize` into the following op sequence: // // def aten_bucketize(input, boundaries, out_int32, right): // unsqz_input = input.unsqueeze(-1) // if not right: // comparison = unsqz_input <= boundaries // else: // comparison = unsqz_input < boundaries // indices = torch.argmax(comparison.float(), dim=-1) // within_bound = comparison[..., -1] // result = torch.where(within_bound, indices, boundaries.shape[0]) // if out_int32: // result = result.int() // return result // namespace { class DecomposeAtenBucketizeTensorOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenBucketizeTensorOp op, PatternRewriter 
&rewriter) const override { Location loc = op.getLoc(); Value input = op.getSelf(); auto inputType = input.getType().cast(); if (!inputType.hasSizes()) { return rewriter.notifyMatchFailure( op, "unimplemented: input must have known sizes"); } ArrayRef inputShape = inputType.getSizes(); Value boundaries = op.getBoundaries(); auto boundariesType = boundaries.getType().cast(); if (!boundariesType.hasSizes() || boundariesType.getSizes().size() != 1) { return rewriter.notifyMatchFailure(op, "unimplemented: boundaries must have " "known sizes and must be a 1D array"); } int64_t boundariesSize = boundariesType.getSizes()[0]; bool outInt32; if (!matchPattern(op.getOutInt32(), m_TorchConstantBool(&outInt32))) { return rewriter.notifyMatchFailure( op, "unimplemented: out_int32 must be a constant bool"); } bool right; if (!matchPattern(op.getRight(), m_TorchConstantBool(&right))) { return rewriter.notifyMatchFailure( op, "unimplemented: right must be a constant bool"); } // unsqueeze input at the last dim to make it broadcastable with boundaries Value constMinusOne = rewriter.create( loc, rewriter.getI64IntegerAttr(-1)); auto unsqzTensorInfo = unsqueezeTensor(rewriter, op, input, /*dim=*/constMinusOne); if (failed(unsqzTensorInfo)) { return rewriter.notifyMatchFailure(op, "cannot generate unsqueeze tensor"); } Value unsqzInput = *unsqzTensorInfo; // compare unsqueezed input with boundaries SmallVector compareShape(inputShape); compareShape.push_back(boundariesSize); Type compareType = inputType.getWithSizesAndDtype(compareShape, rewriter.getI1Type()); Value compare; if (!right) { compare = rewriter.create(loc, compareType, unsqzInput, boundaries); } else { compare = rewriter.create(loc, compareType, unsqzInput, boundaries); } // convert the comparison results to float32 as the argmax op input, // which does not support integer dtype in LINALG backend Value compareF32 = convertTensorToDtype(rewriter, loc, compare, rewriter.getF32Type()); // get the first boundary index where 
the input element is less than (or // equal to) the boundary value Type indicesType = inputType.getWithSizesAndDtype( inputShape, rewriter.getIntegerType(64, IntegerType::Signed)); Value constFalse = rewriter.create(loc, false); Value indices = rewriter.create(loc, indicesType, compareF32, /*dim=*/constMinusOne, /*keepdim=*/constFalse); // get the comparison results between each input element and the rightmost // boundary value Type withinUpperBoundType = inputType.getWithSizesAndDtype(inputShape, rewriter.getI1Type()); Value withinUpperBound = rewriter.create( loc, withinUpperBoundType, compare, /*dim=*/constMinusOne, /*index=*/constMinusOne); // If the input element is less than (or equal to) the rightmost boundary, // take the max index as result. Otherwise, the element is beyond the // rightmost boundary, so take the boundary size. Value constZero = rewriter.create( loc, rewriter.getI64IntegerAttr(0)); Value upperBound = rewriter.create(loc, boundaries, /*dim=*/constZero); Value result = rewriter.create( loc, indicesType, withinUpperBound, indices, upperBound); if (outInt32) { result = convertTensorToDtype( rewriter, loc, result, rewriter.getIntegerType(32, IntegerType::Signed)); } rewriter.replaceOp(op, result); return success(); } }; } // namespace // To avoid overflow we use the following decomposition rule: // x_max = aten.max(x, dim, keepdim=True)[0] // shifted = x - x_max // shifted_logsumexp = aten.log(aten.sum(aten.exp(shifted), dim, keepdim=True)) // log_softmax = shifted - shifted_logsumexp template static Value getLogSoftmaxResult(OpTy op, PatternRewriter &rewriter) { Location loc = op.getLoc(); Value dim = op.getDim(); Value self = op.getSelf(); BaseTensorType tensorType = self.getType().cast(); Value xMax = createMaxAlongDimension(rewriter, loc, op, self, dim, /*keepDim=*/true); if (!xMax) return nullptr; Value shifted = createTensorSub(rewriter, loc, tensorType, self, xMax); Value shiftedExp = rewriter.create(loc, tensorType, shifted); Value 
shiftedSumExp = createSumAlongDimension(rewriter, loc, op, shiftedExp, dim, /*keepDim=*/true); if (!shiftedSumExp) return nullptr; Value shiftedLogSumExp = rewriter.create(loc, shiftedSumExp.getType(), shiftedSumExp); Value result = createTensorSub(rewriter, loc, op.getType(), shifted, shiftedLogSumExp); return result; } namespace { class DecomposeAtenLogSoftmaxIntOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenLogSoftmaxIntOp op, PatternRewriter &rewriter) const override { Value self = op.getSelf(); if (!op.getDtype().getType().isa()) return rewriter.notifyMatchFailure( op, "Unimplemented non-None dtype for log_softmax"); BaseTensorType tensorType = self.getType().cast(); if (!tensorType.hasDtype() || !tensorType.getDtype().isa()) return rewriter.notifyMatchFailure(op, "Only support floating type"); Value logSoftmax = getLogSoftmaxResult(op, rewriter); if (!logSoftmax) return rewriter.notifyMatchFailure( op, "getLogSoftmaxResult function returned nullptr"); rewriter.replaceOp(op, logSoftmax); return success(); } }; } // namespace namespace { class DecomposeAten_LogSoftmaxOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Aten_LogSoftmaxOp op, PatternRewriter &rewriter) const override { bool halfToFloat; if (!matchPattern(op.getHalfToFloat(), m_TorchConstantBool(&halfToFloat))) return rewriter.notifyMatchFailure( op, "Expected a boolean value for half_to_float"); // Currently, setting `halfToFloat` is not supported as the E2E testing for // the same is not present on CPU. 
if (halfToFloat) return rewriter.notifyMatchFailure( op, "halfToFloat is currently not supported."); Value _logSoftmax = getLogSoftmaxResult(op, rewriter); if (!_logSoftmax) return rewriter.notifyMatchFailure( op, "getLogSoftmaxResult function returned nullptr"); rewriter.replaceOp(op, _logSoftmax); return success(); } }; } // namespace // Decompose aten.matmul into: aten.mm and aten.bmm according to ranks. namespace { class DecomposeAtenMatmulOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenMatmulOp op, PatternRewriter &rewriter) const override { Value lhs = op.getSelf(); Value rhs = op.getOther(); std::optional maybeLhsRank = getTensorRank(lhs); std::optional maybeRhsRank = getTensorRank(rhs); if (!maybeLhsRank || !maybeRhsRank) { return rewriter.notifyMatchFailure( op, "expected input tensors to have a rank"); } unsigned lhsRank = *maybeLhsRank; unsigned rhsRank = *maybeRhsRank; if (lhsRank == 2 && rhsRank == 2) { // If both lhs and rhs ranks are 2 then map it to `aten.mm` op. rewriter.replaceOpWithNewOp(op, op.getType(), lhs, rhs); } else if (lhsRank == 3 && rhsRank == 3) { // If both lhs and rhs ranks are 3 then map it to `aten.bmm` op. rewriter.replaceOpWithNewOp(op, op.getType(), lhs, rhs); } else { return failure(); } return success(); } }; } // namespace // Decompose aten.mv into: aten.matmul. namespace { class DecomposeAtenMvOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenMvOp op, PatternRewriter &rewriter) const override { Value lhs = op.getSelf(); Value rhs = op.getVec(); rewriter.replaceOpWithNewOp(op, op.getType(), lhs, rhs); return success(); } }; } // namespace // Decompose aten.linalg_cross into: aten.broadcast_to, aten.index_select, // aten.add.Tensor and aten.mull.Tensor. See // https://github.com/pytorch/pytorch/blob/ed3c256b61f05720843454a9282aa7c903da2c81/torch/_refs/linalg/__init__.py#L70. 
// def linalg_cross(self: Tensor, other: Tensor, dim: int = -1): // broadcast_shape = compute_broadcast_shape(self, other) // a = torch.broadcast_to(self, broadcast_shape) // b = torch.broadcast_to(other, broadcast_shape) // idx = torch.arange(3) // return a.index_select(dim, (idx + 1) % 3) * // b.index_select(dim, (idx + 2) % 3) - // a.index_select(dim, (idx + 2) % 3) * // b.index_select(dim, (idx + 1) % 3) namespace { class DecomposeAtenLinalgCrossOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenLinalgCrossOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value self = op.getSelf(); Value other = op.getOther(); Type opType = op.getType(); Value dim = op.getDim(); auto resType = self.getType().cast(); if (!resType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } Type dtype = resType.getDtype(); if (isa(dtype)) { return rewriter.notifyMatchFailure( op, "lowering of aten.linalg_cross for complex inputs dtype is " "currently unimplemented"); } // calculate common shape for broadcast SmallVector broadcastShape; SmallVector broadcastShapeValue; computeBroadcastShape(rewriter, loc, self, other, broadcastShape, broadcastShapeValue); Type broadcastType = ValueTensorType::get( op.getContext(), llvm::ArrayRef(broadcastShape), dtype); Value indexBroadcastShapeTorchList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), broadcastShapeValue); // broadcast tensors to common shape auto a = rewriter.create(loc, broadcastType, self, indexBroadcastShapeTorchList); auto b = rewriter.create(loc, broadcastType, other, indexBroadcastShapeTorchList); // create constants Value constOne = rewriter.create( loc, rewriter.getI64IntegerAttr(1)); Value constTwo = rewriter.create( loc, rewriter.getI64IntegerAttr(2)); Value constThree = rewriter.create( loc, rewriter.getI64IntegerAttr(3)); Value none = rewriter.create(loc); // idx = 
torch.arange(3) auto outType = dyn_cast(opType); auto arangeType = outType.getWithSizesAndDtype( llvm::ArrayRef(3), IntegerType::get(op.getContext(), 64, IntegerType::Signed)); auto idx = rewriter.create( loc, arangeType, constThree, /*dtype=*/none, /*layout=*/none, /*device=*/none, /*pin_memory=*/none); // (idx + 1) and (idx + 2) auto idxPlusOne = rewriter.create(loc, arangeType, idx, constOne, constOne); auto idxPlusTwo = rewriter.create(loc, arangeType, idx, constTwo, constOne); // (idx + 1) % 3 and (idx + 2) % 3 auto idxPlusOneRemainderThree = rewriter.create( loc, arangeType, idxPlusOne, constThree); auto idxPlusTwoRemainderThree = rewriter.create( loc, arangeType, idxPlusTwo, constThree); // a.index_select(dim, (idx + 1) % 3) * b.index_select(dim, (idx + 2) % 3) auto idxSelectAPlusOne = rewriter.create( loc, opType, a, dim, idxPlusOneRemainderThree); auto idxSelectBPlusTwo = rewriter.create( loc, opType, b, dim, idxPlusTwoRemainderThree); auto firstMul = rewriter.create( loc, opType, idxSelectAPlusOne, idxSelectBPlusTwo); // a.index_select(dim, (idx + 2) % 3) * b.index_select(dim, (idx + 1) % 3) auto idxSelectAPlusTwo = rewriter.create( loc, opType, a, dim, idxPlusTwoRemainderThree); auto idxSelectBPlusOne = rewriter.create( loc, opType, b, dim, idxPlusOneRemainderThree); auto secondMul = rewriter.create( loc, opType, idxSelectAPlusTwo, idxSelectBPlusOne); // subtract the results of the two multiplications from above rewriter.replaceOpWithNewOp(op, opType, firstMul, secondMul, constOne); return success(); } }; } // namespace // Decompose aten.pixel_shuffle into: prims.split_dim, aten.permute, and // prims.collapse operations. // // If input is a tensor of shape // (*leading_dims, C*r*r, H, W), // // where leading_dims is of size N, then // X = pixel_shuffle(input, upscale_factor) // // gets replaced with // X = input.split_dim(...) # shape (*leading_dims, C, r*r, H, W) // X = X.split_dim(...) 
# shape (*leading_dims, C, r, r, H, W) // X = X.permute(0, ..., N, N+3, N+1, N+4, N+2) // # shape (*leading_dims, C, H, r, W, r) // X = X.collapse(...) # shape (*leading_dims, C, r, H, r*W) // X = X.collapse(...) # shape (*leading_dims, C, r*H, r*W) // // 'r' above is referred to as the 'upscale factor' or just 'factor' below. namespace { class DecomposeAtenPixelShuffleOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenPixelShuffleOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value inValue = op.getSelf(); auto inType = inValue.getType().cast(); auto maybeSizes = inType.getOptionalSizes(); if (!maybeSizes) { return rewriter.notifyMatchFailure( op, "Expected input tensor to have known rank."); } auto inShape = maybeSizes.value(); auto inRank = inShape.size(); // The input tensor must have at least 3 dimensions: (1) the channel // dimension which gets smaller by 'factor*factor', (2) the H channel which // gets larger by 'factor' and (3) the W channel which get larger by // 'factor'. The total number of dimensions is 3 + N, where N is the number // of leading dimensions, and N >= 0 so the input must have rank at least 3. if (inRank < 3) return rewriter.notifyMatchFailure( op, "Expected input tensor to have rank greater than 2."); const auto inOptionalDType = inType.getOptionalDtype(); auto getTypeFromShape = [inOptionalDType](auto &&vals) { // Get a vector of integers from a vector of Values. auto getIntShape = [](auto &&vals) { SmallVector shape; shape.reserve(vals.size()); for (auto v : vals) { int64_t cst_val; if (matchPattern(v, m_TorchConstantInt(&cst_val))) { shape.push_back(cst_val); } else { shape.push_back(kUnknownSize); } } return shape; }; const auto intShape = getIntShape(vals); return ValueTensorType::get(vals[0].getContext(), llvm::ArrayRef(intShape), inOptionalDType); }; auto nLeadingDims = inRank - 3; // Get the size of the dimension 'i'. 
Note the use of 'createOrFold' instead // of 'create': if the dimension size is known, then the AtenSizeIntOp is // folded to a ConstantOp. auto getDimSize = [&](uint64_t i) -> Value { Value dim = rewriter.create(loc, rewriter.getI64IntegerAttr(i)); return rewriter.createOrFold(loc, inValue, dim); }; auto inC = getDimSize(inRank - 3); auto inH = getDimSize(inRank - 2); auto inW = getDimSize(inRank - 1); auto factor = op.getUpscaleFactor(); Value factorSquared = rewriter.createOrFold(loc, factor, factor); Value outC = rewriter.createOrFold(loc, inC, factorSquared); Value outH = rewriter.createOrFold(loc, inH, factor); Value outW = rewriter.createOrFold(loc, inW, factor); SmallVector dimensionConstants; dimensionConstants.reserve(inRank + 2); for (unsigned i = 0; i < inRank + 2; ++i) { dimensionConstants.push_back( rewriter.create(loc, rewriter.getI64IntegerAttr(i))); } SmallVector leadingDims; leadingDims.reserve(nLeadingDims); for (unsigned i = 0; i < nLeadingDims; ++i) { Value leadingDimSize = rewriter.createOrFold( loc, inValue, dimensionConstants[i]); leadingDims.push_back(leadingDimSize); } SmallVector partiallyExpandedShape = leadingDims; partiallyExpandedShape.append({outC, factorSquared, inH, inW}); SmallVector prePermuteShape = leadingDims; prePermuteShape.append({outC, factor, factor, inH, inW}); SmallVector postPermuteShape = leadingDims; postPermuteShape.append({outC, inH, factor, inW, factor}); SmallVector partiallyCollapsedShape = leadingDims; partiallyCollapsedShape.append({outC, inH, factor, outW}); SmallVector outShape = leadingDims; outShape.append({outC, outH, outW}); SmallVector permutation{dimensionConstants.begin(), dimensionConstants.begin() + nLeadingDims}; SmallVector permutationTail{0, 3, 1, 4, 2}; for (uint64_t d : permutationTail) { permutation.push_back(dimensionConstants[nLeadingDims + d]); } Value permuteDimsOrder = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(op->getContext())), permutation); // Split input channel 
inC -> (inC, factorSquared) auto partiallyExpanded = rewriter .create( loc, getTypeFromShape(partiallyExpandedShape), inValue, dimensionConstants[nLeadingDims], outC) .getResult(); // Split new dimension factorSquared -> (factor, factor) auto fullyExpanded = rewriter.create( loc, getTypeFromShape(prePermuteShape), partiallyExpanded, dimensionConstants[nLeadingDims + 1], factor); // Perform the permutation auto permuted = rewriter.create(loc, getTypeFromShape(postPermuteShape), fullyExpanded, permuteDimsOrder); // Collapse final 2 dimension auto partiallyCollapsed = rewriter.create( loc, getTypeFromShape(partiallyCollapsedShape), permuted, dimensionConstants[nLeadingDims + 3], dimensionConstants[nLeadingDims + 4]); // Collapse back to original rank rewriter.replaceOpWithNewOp( op, op.getType(), partiallyCollapsed, dimensionConstants[nLeadingDims + 1], dimensionConstants[nLeadingDims + 2]); return success(); } }; } // namespace // ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6) static Value getRelu6Results(PatternRewriter &rewriter, Location loc, Value input) { BaseTensorType inputType = input.getType().cast(); Value relu = rewriter.create(loc, inputType, input); Value cst6 = rewriter.create(loc, rewriter.getI64IntegerAttr(6)); Value sixTensor = createRank0Tensor(rewriter, loc, inputType, cst6); Value relu6Out = rewriter.create(loc, inputType, relu, sixTensor); return relu6Out; } namespace { class DecomposeAtenRelu6Op : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenRelu6Op op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); auto resType = op.getType().cast(); if (!resType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } Value relu6 = getRelu6Results(rewriter, loc, op.getSelf()); rewriter.replaceOp(op, relu6); return success(); } }; } // namespace // Hardswish(x) = x * Relu6(x+3)/6 namespace { class DecomposeAtenHardswishOp : public 
OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenHardswishOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getSelf(); Type inputType = input.getType(); Value constantOne = rewriter.create( loc, rewriter.getI64IntegerAttr(1)); Value constantThree = rewriter.create( loc, rewriter.getI64IntegerAttr(3)); Value constantSix = rewriter.create( loc, rewriter.getI64IntegerAttr(6)); Value inputPlusThree = rewriter.create( loc, inputType, input, constantThree, /*alpha=*/constantOne); Value relu6 = getRelu6Results(rewriter, loc, inputPlusThree); Value divTensor = rewriter.create(loc, inputType, relu6, constantSix); Value mulTensor = rewriter.create(loc, inputType, divTensor, input); rewriter.replaceOp(op, mulTensor); return success(); } }; } // namespace // LeakyRelu = max(0,x) + negative_slope * min(0,x) namespace { class DecomposeAtenLeakyReluOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenLeakyReluOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getSelf(); Value negativeSlope = op.getNegativeSlope(); auto resType = op.getType().cast(); if (!resType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } Value constantZero = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); Value constantOne = rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero); Value positiveOutput = rewriter.create(loc, resType, zeroTensor, input); Value negativeOutput = rewriter.create(loc, resType, zeroTensor, input); Value scaledNegativeOutput = rewriter.create( loc, resType, negativeOutput, negativeSlope); Value leakyReluOutput = rewriter.create( loc, resType, positiveOutput, scaledNegativeOutput, constantOne); rewriter.replaceOp(op, leakyReluOutput); return success(); } }; } // 
namespace // LeakyReluBackward = max(0,grad) + negative_slope * min(0,x) namespace { class DecomposeAtenLeakyReluBackwardOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenLeakyReluBackwardOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value gradOutput = op.getGradOutput(); Value input = op.getSelf(); Value negativeSlope = op.getNegativeSlope(); auto resType = op.getType().cast(); if (!resType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } bool selfIsResult = false; if (!matchPattern(op.getSelfIsResult(), m_TorchConstantBool(&selfIsResult)) || selfIsResult) return rewriter.notifyMatchFailure( op, "unimplemented: self_is_result should be false"); Value constantZero = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); Value constantOne = rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero); Value positiveOutput = rewriter.create(loc, resType, zeroTensor, gradOutput); Value negativeOutput = rewriter.create(loc, resType, zeroTensor, input); Value scaledNegativeOutput = rewriter.create( loc, resType, negativeOutput, negativeSlope); Value leakyReluBackwardOutput = rewriter.create( loc, resType, positiveOutput, scaledNegativeOutput, constantOne); rewriter.replaceOp(op, leakyReluBackwardOutput); return success(); } }; } // namespace namespace { class DecomposeAtenPreluOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenPreluOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getSelf(); Value weight = op.getWeight(); auto resType = op.getType().cast(); auto boolTensorType = rewriter.getType( resType.getOptionalSizes(), rewriter.getI1Type()); Value zero = rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); Value inputMulWeight = rewriter.create(loc, resType, 
input, weight); Value lessThanZero = rewriter.create(loc, boolTensorType, input, zero); Value preluOutput = rewriter.create( loc, resType, lessThanZero, inputMulWeight, input); rewriter.replaceOp(op, preluOutput); return success(); } }; } // namespace namespace { class DecomposeAtenLerpScalarOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenLerpScalarOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); auto resType = op.getType().cast(); if (!resType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } Value cstOne = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); auto start = op.getSelf(); auto inputType = start.getType().cast(); auto delta = rewriter.create(loc, inputType, op.getEnd(), start, cstOne); auto weightedDelta = rewriter.create(loc, inputType, delta, op.getWeight()); auto lerp = rewriter.create(loc, inputType, start, weightedDelta, cstOne); rewriter.replaceOp(op, lerp); return success(); } }; } // namespace // Elu = scale * max(0,x) + alpha * scale * (exp(min(0,x) * input_scale) - 1) namespace { class DecomposeAtenEluOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenEluOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getSelf(); Value alpha = op.getAlpha(); Value scale = op.getScale(); Value inputScale = op.getInputScale(); auto resType = op.getType().cast(); if (!resType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } Value constantZero = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); Value constantOne = rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero); Value maxZeroX = rewriter.create(loc, resType, zeroTensor, input); Value positiveOutput = rewriter.create(loc, resType, maxZeroX, scale); Value 
minZeroX = rewriter.create(loc, resType, zeroTensor, input); Value scaledMinZeroX = rewriter.create(loc, resType, minZeroX, inputScale); Value expX = rewriter.create(loc, resType, scaledMinZeroX); Value expXM1 = rewriter.create(loc, resType, expX, constantOne, constantOne); Value scaledExpXM1 = rewriter.create(loc, resType, expXM1, scale); Value negativeOutput = rewriter.create(loc, resType, scaledExpXM1, alpha); Value eluOutput = rewriter.create( loc, resType, positiveOutput, negativeOutput, constantOne); rewriter.replaceOp(op, eluOutput); return success(); } }; } // namespace // Selu = scale * (max(0,x) + min(0,alpha * (exp(x) − 1))) namespace { class DecomposeAtenSeluOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenSeluOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getSelf(); auto resType = op.getType().cast(); if (!resType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } // Define λ and α double scale = 1.0507009873554804934193349852946; double alpha = 1.6732632423543772848170429916717; // Create constants for λ and α Value scaleVal = rewriter.create( loc, rewriter.getF64FloatAttr(scale)); Value alphaVal = rewriter.create( loc, rewriter.getF64FloatAttr(alpha)); // Create zero tensor for comparison Value constantZero = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero); // Calculate positive and negative parts Value constantOne = rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); Value positiveOutput = rewriter.create(loc, resType, zeroTensor, input); Value minZeroX = rewriter.create(loc, resType, zeroTensor, input); Value expInput = rewriter.create(loc, resType, minZeroX); Value expInputMinusOne = rewriter.create( loc, resType, expInput, constantOne, constantOne); Value negativeOutput = rewriter.create( loc, resType, expInputMinusOne, 
alphaVal); // Multiply the result by λ Value seluOutput = rewriter.create( loc, resType, positiveOutput, negativeOutput, constantOne); seluOutput = rewriter.create(loc, resType, seluOutput, scaleVal); // Replace the original operation rewriter.replaceOp(op, seluOutput); return success(); } }; } // namespace namespace { class DecomposeAtenTOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenTOp op, PatternRewriter &rewriter) const override { Value lhs = op.getSelf(); std::optional lhsRank = getTensorRank(lhs); auto loc = op.getLoc(); if (!lhsRank) { return rewriter.notifyMatchFailure(op, "expected input to have a rank"); } else if (*lhsRank > 2) { std::string errorMessage = "t() expects a tensor with <=2 dimensions, but self is " + std::to_string(*lhsRank) + "D"; return rewriter.notifyMatchFailure(op, errorMessage.c_str()); } else if (*lhsRank < 2) rewriter.replaceOp(op, lhs); else { Value zero = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); Value one = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); rewriter.replaceOpWithNewOp(op, op.getType(), lhs, zero, one); } return success(); } }; } // namespace // Decompose `aten.stack` into `aten.unsqueeze` and `aten.cat`. 
namespace { class DecomposeAtenStackOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenStackOp op, PatternRewriter &rewriter) const override { SmallVector tensors; if (!getListConstructElements(op.getTensors(), tensors)) { return rewriter.notifyMatchFailure( op, "unimplemented: the tensor list is not from list construct"); } // Ensure all tensors have known sizes for (Value tensor : tensors) { BaseTensorType tensorType = tensor.getType().cast(); if (!tensorType.hasSizes()) { return rewriter.notifyMatchFailure( op, "unimplemented: one tensor does not have known sizes"); } } SmallVector unsqueezedTensors; for (Value tensor : tensors) { auto unsqueezedInfo = unsqueezeTensor(rewriter, op, tensor, op.getDim()); if (failed(unsqueezedInfo)) { return rewriter.notifyMatchFailure( op, "cannot generate unsqueeze tensor op"); } unsqueezedTensors.push_back(*unsqueezedInfo); } Type listElemType = op.getType().cast().getWithSizesAndDtype( /*optionalSizes=*/std::nullopt, /*optionalDtype=*/nullptr); Type listType = Torch::ListType::get(listElemType); Value unsqueezedTensorList = rewriter.create( op.getLoc(), listType, unsqueezedTensors); rewriter.replaceOpWithNewOp(op, op.getType(), unsqueezedTensorList, op.getDim()); return success(); } }; } // namespace // Decompose aten.roll into aten.slice and aten.cat ops. 
// https://pytorch.org/docs/stable/generated/torch.roll.html namespace { class DecomposeAtenRollOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenRollOp op, PatternRewriter &rewriter) const override { SmallVector shifts; if (!getListConstructElements(op.getShifts(), shifts)) return rewriter.notifyMatchFailure( op, "unimplemented: shifts not list of Scalar"); SmallVector dims; if (!getListConstructElements(op.getDims(), dims)) return rewriter.notifyMatchFailure( op, "unimplemented: dims not list of Scalar"); if (shifts.size() != dims.size()) return op.emitError("list sizes of shifts and dims are not the same"); auto loc = op.getLoc(); Value constNone = rewriter.create(loc); Value constZero = rewriter.create( loc, rewriter.getI64IntegerAttr(0)); Value constOne = rewriter.create( loc, rewriter.getI64IntegerAttr(1)); auto self = op.getSelf(); auto selfTy = self.getType().cast(); // roll(input, shift, dim) = cat({ // slice(input, dim, -shift, none), // slice(input, dim, 0, -shift)}, dim) auto imitateRoll = [&](Value input, Value shift, Value dim, int64_t cstDim) { Value negShift = rewriter.create(loc, shift); ArrayRef inputShape = selfTy.getSizes(); SmallVector sizes; sizes.append(inputShape.begin(), inputShape.end()); sizes[cstDim] = kUnknownSize; Type sliceTy = selfTy.getWithSizesAndDtype(llvm::ArrayRef(sizes), selfTy.getOptionalDtype()); Value slice0 = rewriter.create( loc, sliceTy, input, dim, negShift, constNone, constOne); Value slice1 = rewriter.create( loc, sliceTy, input, dim, constZero, negShift, constOne); Type listType = Torch::ListType::get(sliceTy); Value slices = rewriter.create( loc, listType, llvm::ArrayRef{slice0, slice1}); return rewriter.create(loc, self.getType(), slices, dim); }; std::optional maybeRank = getTensorRank(self); if (!maybeRank) return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor"); unsigned rank = *maybeRank; Value output = self; auto nShifts = 
shifts.size(); for (size_t k = 0; k < nShifts; ++k) { auto dim = dims[k]; int64_t cstDim = -1; if (!matchPattern(dim, m_TorchConstantInt(&cstDim))) return rewriter.notifyMatchFailure( op, "unimplemented: dim must be constant"); cstDim = toPositiveDim(cstDim, rank); output = imitateRoll(output, shifts[k], dim, cstDim); } rewriter.replaceOp(op, output); return success(); } }; } // namespace // Decompose aten.repeat into aten.squeeze, aten.unsqueeze, and aten.broadcast. // // Ref: https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html namespace { class DecomposeAtenRepeatOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenRepeatOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value self = op.getSelf(); auto selfTy = cast(self.getType()); if (!selfTy.hasSizes()) return rewriter.notifyMatchFailure( op, "Unimplemented: no implementation for rankless tensor"); SmallVector repeats; if (!getListConstructElements(op.getRepeats(), repeats)) return rewriter.notifyMatchFailure( op, "Unimplemented: repeats not list of Scalar"); int64_t rank = selfTy.getSizes().size(); if (rank > static_cast(repeats.size())) { return rewriter.notifyMatchFailure( op, "repeats are not matched with self's rank"); } int64_t repeatSz = repeats.size(); int64_t batch = repeatSz - rank; if (!selfTy.hasSizes()) return rewriter.notifyMatchFailure(op, "input sizes unknown"); // Materialize out 1 dimensions to broadcast along. This includes // materializing out preceding batch dimensions: for (int i = 0; i < repeatSz; ++i) { auto oldSizes = selfTy.getSizes(); llvm::SmallVector sizes; int64_t squeezeDim = i < batch ? 
i : i * 2 - batch; for (int j = 0; j < squeezeDim; ++j) sizes.push_back(oldSizes[j]); sizes.push_back(1); for (int j = squeezeDim, s = oldSizes.size(); j < s; j++) sizes.push_back(oldSizes[j]); Value dim = rewriter.create(loc, squeezeDim); selfTy = rewriter.getType(sizes, selfTy.getOptionalDtype()); self = rewriter.create(loc, selfTy, self, dim); } llvm::SmallVector lengths; for (int i = 0; i < repeatSz; ++i) { if (i < batch) { lengths.push_back(repeats[i]); continue; } Value iv = rewriter.create( loc, rewriter.getI64IntegerAttr(i * 2 + 1 - batch)); Value dim = rewriter.create(loc, self, /*dim=*/iv); lengths.push_back(repeats[i]); lengths.push_back(dim); } Value lengthv = rewriter.create( loc, ListType::get(rewriter.getType()), lengths); llvm::SmallVector expandShape(selfTy.getSizes()); for (int i = 0; i < repeatSz; ++i) { int64_t repeatDim = i < batch ? i : i * 2 - batch; int64_t repeat; if (!matchPattern(repeats[i], m_TorchConstantInt(&repeat))) repeat = Torch::kUnknownSize; expandShape[repeatDim] = repeat; } auto mulDim = [](int64_t lhs, int64_t rhs) { if (lhs == Torch::kUnknownSize || rhs == Torch::kUnknownSize) return Torch::kUnknownSize; return lhs * rhs; }; BaseTensorType expandTy = rewriter.getType( expandShape, selfTy.getOptionalDtype()); Value expand = rewriter.create(loc, expandTy, self, lengthv); for (int i = 0; i < rank; ++i) { auto oldShape = expandTy.getSizes(); llvm::SmallVector newShape; int64_t flattenDim = i + batch; for (int j = 0; j < flattenDim; ++j) newShape.push_back(oldShape[j]); newShape.push_back( mulDim(oldShape[flattenDim], oldShape[flattenDim + 1])); for (int j = flattenDim + 2, s = oldShape.size(); j < s; ++j) newShape.push_back(oldShape[j]); expandTy = rewriter.getType(newShape, expandTy.getOptionalDtype()); // Used to keep the return type the same on the last flatten: expandTy = i < rank - 1 ? 
expandTy : cast(op.getType()); Value start = rewriter.create( loc, rewriter.getI64IntegerAttr(flattenDim)); Value end = rewriter.create( loc, rewriter.getI64IntegerAttr(flattenDim + 1)); expand = rewriter.create(loc, expandTy, expand, start, end); } rewriter.replaceOp(op, expand); return success(); } }; } // namespace // Decompose aten.flatten.using_ints into aten.view op. namespace { class DecomposeAtenFlattenUsingIntsOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenFlattenUsingIntsOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value self = op.getSelf(); MLIRContext *context = op.getContext(); std::optional maybeRank = getTensorRank(self); if (!maybeRank) return rewriter.notifyMatchFailure(op, "unimplemented: unranked tensor"); unsigned rank = *maybeRank; int64_t start, end; if (!matchPattern(op.getStartDim(), m_TorchConstantInt(&start)) || !matchPattern(op.getEndDim(), m_TorchConstantInt(&end))) { return rewriter.notifyMatchFailure( op, "unimplemented: requires start and end dims to be constants"); } SmallVector newSizes; if (rank == 0) { Value one = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); newSizes.push_back(one); } else { start = toPositiveDim(start, rank); end = toPositiveDim(end, rank); if (start > end) { return rewriter.notifyMatchFailure( op, "expected end dim larger than start dim"); } newSizes.reserve(rank - end + start); for (int64_t k = 0; k < start; ++k) { Value dim = rewriter.create(loc, rewriter.getI64IntegerAttr(k)); newSizes.push_back( rewriter.create(loc, self, /*dim=*/dim)); } Value flattenDimSize = rewriter.create(loc, rewriter.getI64IntegerAttr(-1)); newSizes.push_back(flattenDimSize); for (int64_t k = end + 1; k < rank; ++k) { Value dim = rewriter.create(loc, rewriter.getI64IntegerAttr(k)); newSizes.push_back( rewriter.create(loc, self, /*dim=*/dim)); } } Value newSizeList = rewriter.create( loc, ListType::get(IntType::get(context)), 
newSizes); rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), newSizeList); return success(); } }; } // namespace // Decompose aten.unflatten.int into aten.view op. namespace { class DecomposeAtenUnflattenIntOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenUnflattenIntOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value self = op.getSelf(); MLIRContext *context = op.getContext(); BaseTensorType outputTensorType = op.getType().cast(); if (!outputTensorType.hasSizes()) return rewriter.notifyMatchFailure( op, "unimplemented: output must have known sizes"); std::optional maybeRank = getTensorRank(self); if (!maybeRank) return rewriter.notifyMatchFailure(op, "unimplemented: unranked tensor"); unsigned inputRank = *maybeRank; auto inputTensorType = self.getType().cast(); if (!inputTensorType || !inputTensorType.hasSizes()) { return rewriter.notifyMatchFailure(op, "Expected input type having sizes"); } ArrayRef inputShape = inputTensorType.getSizes(); SmallVector sizesInts; if (!matchPattern(op.getSizes(), m_TorchListOfConstantInts(sizesInts))) return rewriter.notifyMatchFailure( op, "sizes must be a list of constant ints"); bool inferred = false; if (llvm::count(sizesInts, -1) > 1) return rewriter.notifyMatchFailure( op, "only one of sizes' elements can be -1"); int64_t dimInt; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dimInt))) return rewriter.notifyMatchFailure( op, "unimplemented: requires dim to be constants"); dimInt = toPositiveDim(dimInt, inputRank); if (!isValidDim(dimInt, inputRank)) return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); SmallVector sizesTorchInt; if (!getListConstructElements(op.getSizes(), sizesTorchInt)) return rewriter.notifyMatchFailure( op, "Unimplemented: sizes not list of Scalar"); // Create new sizes based on the unflattened dimension. 
SmallVector newSizes; for (int64_t i = 0; i < inputRank; ++i) { Value dimValue = rewriter.create(loc, rewriter.getI64IntegerAttr(i)); Value dimSize = rewriter.create(loc, self, /*dim=*/dimValue); if (i == dimInt) { int64_t inferredSizeInt = inputShape[i]; int64_t inferredDim; for (unsigned j = 0; j < sizesInts.size(); ++j) { if (sizesInts[j] == -1) { inferred = true; inferredDim = j; } else { Value sizeValue = rewriter.create( loc, rewriter.getI64IntegerAttr(sizesInts[j])); newSizes.push_back(sizeValue); inferredSizeInt = inferredSizeInt / sizesInts[j]; } } if (inferred) { Value inferredSize = rewriter.create( loc, rewriter.getI64IntegerAttr(inferredSizeInt)); newSizes.insert(newSizes.begin() + inferredDim + i, inferredSize); } } else { newSizes.push_back(dimSize); } } // Create the AtenViewOp to replace the original op. Value newSizeList = rewriter.create( loc, ListType::get(IntType::get(context)), newSizes); rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), newSizeList); return success(); } }; } // namespace // Decompose aten.expand into aten.broadcast_to op. namespace { class DecomposeAtenExpandOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenExpandOp op, PatternRewriter &rewriter) const override { bool implicit = false; if (!matchPattern(op.getImplicit(), m_TorchConstantBool(&implicit)) || implicit) { return rewriter.notifyMatchFailure( op, "unimplemented: requires implicit to be false"); } rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), op.getSize()); return success(); } }; } // namespace // Decompose aten.where.Scalar into aten.where.self op. 
namespace { class DecomposeAtenWhereScalarOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenWhereScalarOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); auto resType = op.getType().cast(); if (!resType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } Value selfTensor = createRank0Tensor(rewriter, loc, resType, op.getSelf()); Value otherTensor = createRank0Tensor(rewriter, loc, resType, op.getOther()); rewriter.replaceOpWithNewOp(op, resType, op.getCondition(), selfTensor, otherTensor); return success(); } }; } // namespace // Decompose aten.where.ScalarOther into aten.where.self op. namespace { class DecomposeAtenWhereScalarOtherOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenWhereScalarOtherOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); auto resType = op.getType().cast(); if (!resType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } Value otherTensor = createRank0Tensor(rewriter, loc, resType, op.getOther()); rewriter.replaceOpWithNewOp(op, resType, op.getCondition(), op.getSelf(), otherTensor); return success(); } }; } // namespace // Decompose aten.where.ScalarSelf into aten.where.self op. 
namespace { class DecomposeAtenWhereScalarSelfOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenWhereScalarSelfOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); auto resType = op.getType().cast(); if (!resType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } Value selfTensor = createRank0Tensor(rewriter, loc, resType, op.getSelf()); rewriter.replaceOpWithNewOp(op, resType, op.getCondition(), selfTensor, op.getOther()); return success(); } }; } // namespace namespace { class DecomposeAtenNanToNumOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenNanToNumOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); mlir::FloatType f64Type = rewriter.getF64Type(); Value nan = op.getNan(); Value posinf = op.getPosinf(); Value neginf = op.getNeginf(); auto baseType = ValueTensorType::getWithLeastStaticInformation(op.getContext()); if (dyn_cast_or_null(nan.getDefiningOp())) nan = rewriter.create( loc, rewriter.getFloatAttr( f64Type, APFloat::getZero(f64Type.getFloatSemantics()))); if (dyn_cast_or_null(posinf.getDefiningOp())) posinf = rewriter.create( loc, rewriter.getFloatAttr( f64Type, APFloat::getInf(f64Type.getFloatSemantics()))); if (dyn_cast_or_null(neginf.getDefiningOp())) neginf = rewriter.create( loc, rewriter.getFloatAttr( f64Type, APFloat::getInf(f64Type.getFloatSemantics(), true))); Value isNan = rewriter.create(loc, baseType, op.getSelf()); Value where = rewriter.create( loc, baseType, isNan, nan, op.getSelf()); Value isposinf = rewriter.create(loc, baseType, where); where = rewriter.create( loc, baseType, isposinf, posinf, where); Value isneginf = rewriter.create(loc, baseType, where); rewriter.replaceOpWithNewOp( op, op.getType(), isneginf, neginf, where); return success(); } }; } // namespace // Decompose aten.masked_fill.Scalar into aten.where.self 
op. namespace { class DecomposeAtenMaskedFillScalarOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenMaskedFillScalarOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); auto resType = op.getType().cast(); if (!resType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } Value mask = op.getMask(); Value value = createRank0Tensor(rewriter, loc, resType, op.getValue()); rewriter.replaceOpWithNewOp(op, resType, mask, value, op.getSelf()); return success(); } }; } // namespace // Decompose aten._convolution-like to aten.convolution namespace { template class DecomposeAten_ConvolutionLikeOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ConvolutionLikeOp op, PatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(), op.getStride(), op.getPadding(), op.getDilation(), op.getTransposed(), op.getOutputPadding(), op.getGroups()); return success(); } }; } // namespace namespace { static LogicalResult createTorchTransposeOpForConvTbc(PatternRewriter &rewriter, Location loc, Value input, int64_t dimA, int64_t dimB, Value &transposed) { Type transposedType; if (failed(getTransposedType(input.getType().cast(), dimA, dimB, transposedType))) return failure(); Value cstDimA = rewriter.create( loc, rewriter.getI64IntegerAttr(dimA)); Value cstDimB = rewriter.create( loc, rewriter.getI64IntegerAttr(dimB)); transposed = rewriter.create( loc, transposedType, input, cstDimA, cstDimB); return success(); } class DecomposeAtenConvTbcOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenConvTbcOp op, PatternRewriter &rewriter) const override { Value emptyList = rewriter.create( op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), SmallVector()); 
Value cstFalse = rewriter.create(op.getLoc(), false); Value oneList = rewriter.create( op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), SmallVector{rewriter.create( op.getLoc(), rewriter.getI64IntegerAttr(1))}); Value padding = rewriter.create( op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), SmallVector{op.getPad()}); Value groups = rewriter.create( op.getLoc(), rewriter.getI64IntegerAttr(1)); // convtbc has WNC layout for input and output // and WCF layout for weight // whereas Convolution is going to use Conv1DNcwFcwOp for 1d // which means we need the inputs in NCW and the weight in FCW Value selfWnc = op.getSelf(); Value selfNwc; Value selfNcw; if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), selfWnc, 0, 1, selfNwc))) return rewriter.notifyMatchFailure(op, "failed to transpose input to Nwc"); if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), selfNwc, 1, 2, selfNcw))) return rewriter.notifyMatchFailure(op, "failed to transpose input to Ncw"); Value weightWcf = op.getWeight(); Value weightFcw; if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), weightWcf, 0, 2, weightFcw))) return rewriter.notifyMatchFailure(op, "failed to transpose weight to Fcw"); Value outputNcw = rewriter.create( op.getLoc(), op->getResultTypes(), selfNcw, weightFcw, op.getBias(), /*stride*/ oneList, /*padding*/ padding, /*dilation*/ oneList, /*transpose*/ cstFalse, /*output_padding*/ emptyList, groups); // convert output from Ncw to Wnc Value outputNwc; Value outputWnc; if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), outputNcw, 1, 2, outputNwc))) return rewriter.notifyMatchFailure(op, "failed to transpose output to Nwc"); if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), outputNwc, 0, 1, outputWnc))) return rewriter.notifyMatchFailure(op, "failed to transpose output to Wnc"); rewriter.replaceOp(op, outputWnc); return success(); } }; } // namespace // 
Decompose aten.conv1d to aten.convolution namespace { class DecomposeAtenConv1dOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenConv1dOp op, PatternRewriter &rewriter) const override { Value emptyList = rewriter.create( op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), SmallVector()); Value cstFalse = rewriter.create(op.getLoc(), false); rewriter.replaceOpWithNewOp( op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(), op.getStride(), op.getPadding(), op.getDilation(), cstFalse, emptyList, op.getGroups()); return success(); } }; } // namespace // Decompose aten.conv2d to aten.convolution namespace { class DecomposeAtenConv2dOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenConv2dOp op, PatternRewriter &rewriter) const override { Value emptyList = rewriter.create( op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), SmallVector()); Value cstFalse = rewriter.create(op.getLoc(), false); rewriter.replaceOpWithNewOp( op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(), op.getStride(), op.getPadding(), op.getDilation(), cstFalse, emptyList, op.getGroups()); return success(); } }; } // namespace // Decompose aten.conv3d to aten.convolution namespace { class DecomposeAtenConv3dOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenConv3dOp op, PatternRewriter &rewriter) const override { Value emptyList = rewriter.create( op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), SmallVector()); Value cstFalse = rewriter.create(op.getLoc(), false); rewriter.replaceOpWithNewOp( op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(), op.getStride(), op.getPadding(), op.getDilation(), cstFalse, emptyList, op.getGroups()); return success(); } }; } // namespace // Decompose 
aten.conv_transpose2d to aten.convolution namespace { class DecomposeAtenConvTranspose2dOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenConvTranspose2dInputOp op, PatternRewriter &rewriter) const override { Value cstTrue = rewriter.create(op.getLoc(), true); rewriter.replaceOpWithNewOp( op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(), op.getStride(), op.getPadding(), op.getDilation(), /*transposed=*/cstTrue, op.getOutputPadding(), op.getGroups()); return success(); } }; } // namespace // The convolution backward op is decomposed as follows: // inputH, inputW = input.shape[2:] // output_padding_ = [ // inputH // - 1 // + 2 * padding_[0] // - dilation_[0] * (weight.shape[2] - 1) // - (grad_output.shape[2] - 1) * stride_[0], // inputW // - 1 // + 2 * padding_[1] // - dilation_[1] * (weight.shape[3] - 1) // - (grad_output.shape[3] - 1) * stride_[1], // ] // // decomp_grad_input = torch.nn.functional.conv_transpose2d( // grad_output, // weight, // None, // stride_, // padding_, // output_padding_, // groups_, // dilation_, // ) // // input_transposed = torch.ops.aten.transpose(input, 0, 1) // grad_output_transposed = grad_output.view( // grad_output.shape[0] * grad_output.shape[1], 1, *grad_output.shape[2:] // ) // decomp_grad_weight = torch.ops.aten.convolution( // input_transposed, // grad_output_transposed, // bias=None, // stride=dilation_, // padding=padding_, // dilation=stride_, // transposed=False, // output_padding=[0, 0], // groups=input.shape[0], // ) // decomp_grad_weight = torch.narrow(decomp_grad_weight, 2, 0, weight.shape[2]) // decomp_grad_weight = torch.narrow(decomp_grad_weight, 3, 0, weight.shape[3]) // decomp_grad_weight = decomp_grad_weight.view( // input_transposed.shape[0], // input_transposed.shape[1], // grad_output.shape[1], // *decomp_grad_weight.shape[2:] // ) // decomp_grad_weight = decomp_grad_weight.movedim(0, 2) // decomp_grad_weight = 
decomp_grad_weight.sum(dim=0) // // decomp_grad_bias = torch.sum(grad_output, dim=[0, 2, 3]) namespace { class DecomposeAtenConvolutionBackwardOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenConvolutionBackwardOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); MLIRContext *context = op.getContext(); Value input = op.getInput(); Value weight = op.getWeight(); Value gradOutput = op.getGradOutput(); std::optional maybeGradRank = getTensorRank(gradOutput); if (!maybeGradRank) { return rewriter.notifyMatchFailure(op, "expected grad output to have a rank"); } unsigned gradRank = *maybeGradRank; if (gradRank != 4) return rewriter.notifyMatchFailure( op, "unimplemented: only 2D convolutions supported."); Value cstZero = rewriter.create( loc, rewriter.getI64IntegerAttr(0)); Value cstOne = rewriter.create( loc, rewriter.getI64IntegerAttr(1)); Value cstTwo = rewriter.create( loc, rewriter.getI64IntegerAttr(2)); Value cstNone = rewriter.create(loc); Value cstFalse = rewriter.create( loc, rewriter.getBoolAttr(false)); SmallVector padding, dilation, stride; SmallVector paddingInt, dilationInt, strideInt, outputPaddingInt; if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingInt))) return rewriter.notifyMatchFailure( op, "padding must be a list of constant ints"); if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInt))) return rewriter.notifyMatchFailure( op, "stride must be a list of constant ints"); if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilationInt))) return rewriter.notifyMatchFailure( op, "dilation must be a list of constant ints"); if (!llvm::all_of(dilationInt, [](int64_t dilationVal) { return dilationVal == 1; })) return rewriter.notifyMatchFailure( op, "unimplemented: only dilations of 1 supported."); if (!matchPattern(op.getOutputPadding(), m_TorchListOfConstantInts(outputPaddingInt))) return rewriter.notifyMatchFailure( 
op, "output padding must be a list of constant ints"); if (!llvm::all_of(outputPaddingInt, [](int64_t outPad) { return outPad == 0; })) return rewriter.notifyMatchFailure( op, "unimplemented: only output padding of 0 supported."); SmallVector outMask; if (!matchPattern(op.getOutputMask(), m_TorchListOfConstantBools(outMask))) return rewriter.notifyMatchFailure( op, "only constant bool output_mask is supported."); for (unsigned i = 0; i < outMask.size(); i++) { if (outMask[i] == false) { Value result = op->getResults()[i]; if (!result.getUsers().empty()) return rewriter.notifyMatchFailure( op, "unimplemented: false value supported for output_mask only " "when the result tensor corresponding to that has no users."); } } bool transposed; if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed))) return rewriter.notifyMatchFailure( op, "transposed arg should be a constant bool."); if (transposed) return rewriter.notifyMatchFailure( op, "unimplemented: transposed convolutions are not supported."); getListConstructElements(op.getPadding(), padding); getListConstructElements(op.getStride(), stride); getListConstructElements(op.getDilation(), dilation); // Computing Grad Input. // Calculate output padding for first convolution. 
// output_padding_ = [ // inputH - 1 + (2 * padding_[0]) - (dilation_[0] * (weight.size()[2] // - 1)) - ((grad_out.size()[2] - 1) * stride_[0]), inputW - 1 + (2 * // padding_[1]) - (dilation_[1] * (weight.size()[3] - 1)) - // ((grad_out.size()[3] - 1) * stride_[1]), // ] SmallVector outputPaddingValues; for (unsigned i = 2; i < gradRank; i++) { Value dim = rewriter.create( loc, rewriter.getI64IntegerAttr(i)); Value inputVecDim = rewriter.create(loc, input, dim); Value gradOutDim = rewriter.create(loc, gradOutput, dim); Value weightDim = rewriter.create(loc, weight, dim); Value inputVecDimMinusOne = rewriter.create(loc, inputVecDim, cstOne); Value gradOutDimMinusOne = rewriter.create(loc, gradOutDim, cstOne); Value weightDimMinusOne = rewriter.create(loc, weightDim, cstOne); Value twoTimesPadding = rewriter.create(loc, padding[i - 2], cstTwo); Value tmpA = rewriter.create(loc, weightDimMinusOne, dilation[i - 2]); Value tmpB = rewriter.create(loc, gradOutDimMinusOne, stride[i - 2]); Value outputPaddingVal = rewriter.create( loc, inputVecDimMinusOne, twoTimesPadding); outputPaddingVal = rewriter.create(loc, outputPaddingVal, tmpA); outputPaddingVal = rewriter.create(loc, outputPaddingVal, tmpB); outputPaddingValues.push_back(outputPaddingVal); } Value outputPaddingForGradInput = rewriter.create( loc, ListType::get(IntType::get(context)), outputPaddingValues); Value gradInput = rewriter.create( loc, op.getResultTypes()[0], gradOutput, weight, cstNone, op.getStride(), op.getPadding(), outputPaddingForGradInput, op.getGroups(), op.getDilation()); Type transposedType; if (failed(getTransposedType(input.getType().cast(), 0, 1, transposedType))) return failure(); Value inputTransposed = rewriter.create( loc, transposedType, input, cstZero, cstOne); // For the cases where the stride is non-unit, we compute the `GradWeight` // through this implementation. Value gradWeight; if (!llvm::all_of(strideInt, [](int64_t stride) { return stride == 1; })) { // Computing Grad Weight. 
SmallVector gradOutputSize; for (unsigned i = 0; i < gradRank; i++) { gradOutputSize.push_back(rewriter.create( loc, gradOutput, rewriter.create( loc, rewriter.getI64IntegerAttr(i)))); } Value gradOutputViewDimZero = rewriter.create( loc, gradOutputSize[0], gradOutputSize[1]); Value gradOutputViewShapeList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), ValueRange{gradOutputViewDimZero, cstOne, gradOutputSize[2], gradOutputSize[3]}); BaseTensorType gradOutputTy = gradOutput.getType().cast(); if (!gradOutputTy.hasSizes()) return failure(); SmallVector gradOutputSizesInt(gradOutputTy.getSizes()); SmallVector gradOutputViewSizesInt(gradOutputSizesInt); if (gradOutputViewSizesInt[0] != kUnknownSize && gradOutputViewSizesInt[1] != kUnknownSize) gradOutputViewSizesInt[0] *= gradOutputViewSizesInt[1]; else gradOutputViewSizesInt[0] = kUnknownSize; gradOutputViewSizesInt[1] = 1; BaseTensorType gradOutputTypeForView = gradOutputTy .getWithSizesAndDtype(llvm::ArrayRef(gradOutputViewSizesInt), gradOutputTy.getOptionalDtype()) .cast(); Value gradOutputView = rewriter.create( loc, gradOutputTypeForView, gradOutput, gradOutputViewShapeList); BaseTensorType inputTransposedTy = inputTransposed.getType().cast(); if (!inputTransposedTy.hasSizes()) return failure(); SmallVector inputTransposedSizesInt( inputTransposedTy.getSizes()); SmallVector gradWeightSizesInt{inputTransposedSizesInt[0], gradOutputViewSizesInt[0]}; for (unsigned i = 2; i < gradRank; i++) { if (inputTransposedSizesInt[i] != kUnknownSize && gradOutputViewSizesInt[i] != kUnknownSize) { int64_t kernelSizeInt = strideInt[i - 2] * (gradOutputViewSizesInt[i] - 1) + 1; gradWeightSizesInt.push_back( ((inputTransposedSizesInt[i] + (paddingInt[i - 2] * 2) - kernelSizeInt) / dilationInt[i - 2]) + 1); } else { gradWeightSizesInt.push_back(kUnknownSize); } } BaseTensorType gradWeightTy = inputTransposedTy .getWithSizesAndDtype(llvm::ArrayRef(gradWeightSizesInt), 
inputTransposedTy.getOptionalDtype()) .cast(); Value numGroup = rewriter.create(loc, input, cstZero); gradWeight = rewriter.create( loc, gradWeightTy, inputTransposed, gradOutputView, cstNone, /*stride=*/op.getDilation(), op.getPadding(), /*dilation=*/op.getStride(), op.getTransposed(), op.getOutputPadding(), numGroup); BaseTensorType weightTy = weight.getType().cast(); if (!weightTy.hasSizes()) return failure(); SmallVector weightSizes(weightTy.getSizes()); for (unsigned i = 0; i < gradWeightTy.getSizes().size() - 2; i++) { gradWeightSizesInt[i + 2] = weightSizes[i + 2]; BaseTensorType gradWeightNarrowTy = gradWeightTy .getWithSizesAndDtype(llvm::ArrayRef(gradWeightSizesInt), gradWeightTy.getOptionalDtype()) .cast(); Value dim = rewriter.create( loc, rewriter.getI64IntegerAttr(i + 2)); Value length = rewriter.create(loc, weight, dim); gradWeight = rewriter.create( loc, gradWeightNarrowTy, gradWeight, dim, /*start=*/cstZero, length); } SmallVector gradWeightViewShapeInt{ inputTransposedSizesInt[0], inputTransposedSizesInt[1]}; gradWeightViewShapeInt.push_back(gradOutputSizesInt[1]); gradWeightViewShapeInt.insert( gradWeightViewShapeInt.end(), {gradWeightSizesInt[2], gradWeightSizesInt[3]}); SmallVector gradWeightViewShapeValue; for (unsigned i = 0; i < gradWeightViewShapeInt.size(); i++) { gradWeightViewShapeValue.push_back( rewriter.create( loc, rewriter.getI64IntegerAttr(gradWeightViewShapeInt[i]))); } Value gradWeightViewShapeList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), gradWeightViewShapeValue); BaseTensorType gradWeightTypeForView = gradWeightTy .getWithSizesAndDtype(llvm::ArrayRef(gradWeightViewShapeInt), gradWeightTy.getOptionalDtype()) .cast(); gradWeight = rewriter.create( loc, gradWeightTypeForView, gradWeight, gradWeightViewShapeList); gradWeightTy = gradWeight.getType().cast(); SmallVector gradWeightDimsOrder = computeDimsOrderForMoveDim(0, 2, gradWeightViewShapeInt.size()); SmallVector gradWeightMoveDimShape; 
for (unsigned i = 0; i < gradWeightDimsOrder.size(); i++) { gradWeightMoveDimShape.push_back( gradWeightViewShapeInt[gradWeightDimsOrder[i]]); } BaseTensorType gradWeightTypeForMoveDim = gradWeightTy .getWithSizesAndDtype(llvm::ArrayRef(gradWeightMoveDimShape), gradWeightTy.getOptionalDtype()) .cast(); gradWeight = rewriter.create( loc, gradWeightTypeForMoveDim, gradWeight, /*source=*/cstZero, /*destination=*/cstTwo); Value gradIntList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), llvm::ArrayRef{cstZero}); gradWeight = rewriter.create( loc, op.getResultTypes()[1], /*self=*/gradWeight, /*dim=*/gradIntList, /*keepdim=*/cstFalse, /*dtype=*/cstNone); } else { if (failed(getTransposedType(gradOutput.getType().cast(), 0, 1, transposedType))) return failure(); Value gradOutputTransposed = rewriter.create( loc, transposedType, gradOutput, cstZero, cstOne); // Convolve input with grad_output. if (failed( getTransposedType(op.getResultTypes()[1].cast(), 0, 1, transposedType))) return failure(); gradWeight = rewriter.create( loc, transposedType, inputTransposed, gradOutputTransposed, cstNone, op.getStride(), op.getPadding(), op.getDilation(), op.getTransposed(), op.getOutputPadding(), op.getGroups()); gradWeight = rewriter.create( loc, op.getResultTypes()[1], gradWeight, cstZero, cstOne); } // Computing Grad Bias. SmallVector dimIntList{cstZero}; for (unsigned i = 2; i < gradRank; i++) dimIntList.push_back(rewriter.create( loc, rewriter.getI64IntegerAttr(i))); Value gradIntList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), dimIntList); // Sum grad_output along dim 1. Value gradBias = rewriter.create( loc, op.getResultTypes()[2], gradOutput, gradIntList, cstFalse, cstNone); rewriter.replaceOp(op, {gradInput, gradWeight, gradBias}); return success(); } }; } // namespace // Decompose aten.addmm into aten.mm and aten.add.Tensor op. 
namespace { class DecomposeAtenAddmmOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenAddmmOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getSelf(); Value mat1 = op.getMat1(); Value mat2 = op.getMat2(); std::optional mat1Rank = getTensorRank(mat1); std::optional mat2Rank = getTensorRank(mat2); // The operands `mat1`, `mat2` to aten.addmm must be of rank 2. if (!mat1Rank || !mat2Rank || *mat1Rank != 2 || *mat2Rank != 2) { return rewriter.notifyMatchFailure( op, "expected mat1, mat2 operands to aten.addmm to be rank 2"); } // TODO: Handle integer type operands. auto inputType = input.getType().cast(); if (!inputType.hasDtype() || !inputType.getDtype().isa()) { return rewriter.notifyMatchFailure( op, "unimplemented: non-floating point dtype"); } // matrix multiplication: matmul = mat1 @ mat2 Value matmul = rewriter.create(loc, op.getType(), mat1, mat2); // scaledInput = self * beta Value scaledInput = rewriter.create(loc, input.getType(), input, op.getBeta()); // result = scaledInput + alpha * matmul rewriter.replaceOpWithNewOp(op, op.getType(), scaledInput, matmul, op.getAlpha()); return success(); } }; } // namespace // Decompose aten.mean into: sum(x)/div(numTensorElements). 
namespace { class DecomposeAtenMeanOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenMeanOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getSelf(); Value output = op.getResult(); BaseTensorType outputTensorType = output.getType().cast(); Value sum = rewriter.create(loc, outputTensorType, input, op.getDtype()); Value numTensorElements = rewriter.create(loc, input); rewriter.replaceOpWithNewOp(op, outputTensorType, sum, numTensorElements); return success(); } }; } // namespace // productDimSize = product(size(dim) for dim in dims) // aten.mean(x, dims) = aten.sum(x, dims) / productDimSize. namespace { class DecomposeAtenMeanDimOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenMeanDimOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getSelf(); std::optional maybeInputRank = getTensorRank(input); if (!maybeInputRank) { return rewriter.notifyMatchFailure(op, "expected input to have a rank"); } unsigned inputRank = *maybeInputRank; Value dimList = op.getDim(); Value keepDim = op.getKeepdim(); Value dtype = op.getDtype(); Type outputType = op.getType(); MLIRContext *context = op.getContext(); BaseTensorType inputType = input.getType().cast(); if (!inputType.hasDtype() || !inputType.getDtype().isa() || !isNoneOrFloatDtype(context, dtype)) { return rewriter.notifyMatchFailure( op, "only floating-point type is supported"); } SmallVector dimListElements; if (!getListConstructElements(dimList, dimListElements) && !dimList.getType().isa()) { return rewriter.notifyMatchFailure( op, "expected `dim` to be `None` or constructed from list construct"); } // Compute sum along dimensions specified in `dimList`. Value sumAlongDims = rewriter.create( loc, outputType, input, dimList, keepDim, dtype); // `productDimSize` is product of sizes of dimensions to be reduced. 
Value productDimSize; // Case: Reduce along all dims. if (dimListElements.empty() && inputRank != 0) { productDimSize = rewriter.create(loc, input); } else { productDimSize = rewriter.create( loc, rewriter.getI64IntegerAttr(1)); for (Value dim : dimListElements) { Value dimSize = rewriter.create(loc, input, dim); productDimSize = rewriter.create(loc, productDimSize, dimSize); } } rewriter.replaceOpWithNewOp(op, outputType, sumAlongDims, productDimSize); return success(); } }; } // namespace namespace { class DecomposeAtenSquareOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenSquareOp op, PatternRewriter &rewriter) const override { Value self = op.getSelf(); rewriter.replaceOpWithNewOp(op, op.getType(), self, self); return success(); } }; } // namespace // Silu(x) = sigmoid(x) * x namespace { class DecomposeAtenSiluOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenSiluOp op, PatternRewriter &rewriter) const override { Value self = op.getSelf(); Value sigmoid = rewriter.create(op.getLoc(), op.getType(), self); rewriter.replaceOpWithNewOp(op, op.getType(), sigmoid, self); return success(); } }; } // namespace // pDash = 1.0 - p // boolMask = aten.rand_like(input) < pDash // dropout(input, p, train=True) = (boolMask * input) / pDash // dropout(input, p, train=False) = input namespace { class DecomposeAtenDropoutOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenDropoutOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getInput(); Value prob = op.getP(); bool train = false; if (!matchPattern(op.getTrain(), m_TorchConstantBool(&train))) return rewriter.notifyMatchFailure(op, "train must be a boolean constant"); if (!train) { rewriter.replaceOp(op, input); return success(); } BaseTensorType inputType = input.getType().cast(); if 
(!inputType.hasDtype() || !inputType.getDtype().isa()) return rewriter.notifyMatchFailure( op, "only support floating type input for training mode"); Value noneVal = rewriter.create(loc); Value floatOne = rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); Value oneMinusP = rewriter.create(loc, floatOne, prob); Value boolMask = rewriter.create( loc, inputType, input, oneMinusP, /*generator=*/noneVal); Value maskedInput = rewriter.create(loc, inputType, boolMask, input); rewriter.replaceOpWithNewOp(op, op.getType(), maskedInput, oneMinusP); return success(); } }; class DeomposeAtenNativeDropoutOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenNativeDropoutOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); MLIRContext *context = op->getContext(); Value input = op.getInput(); Value prob = op.getP(); bool train = false; if (!op.getTrain().getType().isa()) { if (!matchPattern(op.getTrain(), m_TorchConstantBool(&train))) { return rewriter.notifyMatchFailure( op, "train must be a boolean constant or none"); } } Value noneVal = rewriter.create(loc); if (!train) { Value i1Type = getDtypeIntValueForType(rewriter, loc, IntegerType::get(context, 1)); Value inputSize = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(context)), input); Value trueValue = rewriter.create(loc, 1); Value trueMask = rewriter.create( loc, op->getResultTypes()[1], inputSize, trueValue, i1Type, /*layout=*/noneVal, /*device=*/noneVal, /*pin_memory=*/noneVal); rewriter.replaceOp(op, ArrayRef{input, trueMask}); return success(); } BaseTensorType inputType = input.getType().cast(); if (!inputType.hasDtype() || !inputType.getDtype().isa()) { return rewriter.notifyMatchFailure( op, "only support floating type input for training mode"); } Value floatOne = rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); Value oneMinusP = rewriter.create(loc, floatOne, prob); Value boolMask = rewriter.create( loc, 
inputType, input, oneMinusP, /*generator=*/noneVal); Value maskedInput = rewriter.create(loc, inputType, boolMask, input); Value output = rewriter.create( loc, op->getResultTypes()[0], maskedInput, oneMinusP); rewriter.replaceOp( op, ArrayRef{ output, convertTensorToDtype(rewriter, loc, boolMask, IntegerType::get(context, 1))}); return success(); } }; } // namespace // Decompose aten.var into: aten.var.dim op. namespace { class DecomposeAtenVarOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenVarOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value self = op.getSelf(); std::optional maybeInputRank = getTensorRank(self); if (!maybeInputRank) { return rewriter.notifyMatchFailure(op, "expected input to have a rank"); } unsigned inputRank = *maybeInputRank; BaseTensorType rank0FloatTensorTy = op.getType().cast(); if (!rank0FloatTensorTy.hasSizes() || rank0FloatTensorTy.getSizes().size() != 0) { return rewriter.notifyMatchFailure( op, "expected aten.var to have a rank 0 tensor type"); } SmallVector dims; for (unsigned i = 0; i < inputRank; i++) dims.push_back(rewriter.create( loc, rewriter.getI64IntegerAttr(i))); Value dimList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), dims); Value cstFalse = rewriter.create(op.getLoc(), false); rewriter.replaceOpWithNewOp(op, rank0FloatTensorTy, self, dimList, op.getUnbiased(), /*keepdim=*/cstFalse); return success(); } }; } // namespace // Decompose aten.std to sqrt(var(x)) namespace { class DecomposeAtenStdOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenStdOp op, PatternRewriter &rewriter) const override { Value self = op.getSelf(); BaseTensorType inputTensorTy = self.getType().cast(); if (!inputTensorTy.hasDtype() || !inputTensorTy.getDtype().isa()) { return rewriter.notifyMatchFailure(op, "Only aten.std support floating type"); } 
Value var = rewriter.create(op->getLoc(), op.getType(), op.getSelf(), op.getUnbiased()); rewriter.replaceOpWithNewOp(op, op.getType(), var); return success(); } }; } // namespace // Softplus(x, beta, threshold) = // x * beta > threshold ? x : log(1 + exp(x * beta)) / beta namespace { class DecomposeAtenSoftplusOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenSoftplusOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getSelf(); BaseTensorType inputType = input.getType().cast(); Value inputTimesBeta = rewriter.create(loc, inputType, input, op.getBeta()); // out = log1p(exp(input * beta)) / beta Value exp = rewriter.create(loc, inputType, inputTimesBeta); Value log1p = rewriter.create(loc, inputType, exp); Value out = rewriter.create(loc, inputType, log1p, op.getBeta()); // Select where x * beta > threshold auto boolResType = inputType.getWithSizesAndDtype(inputType.getSizes(), rewriter.getI1Type()); Value condition = rewriter.create( loc, boolResType, inputTimesBeta, op.getThreshold()); rewriter.replaceOpWithNewOp(op, op.getType(), condition, input, out); return success(); } }; } // namespace // Decompose aten.std.dim to sqrt(var.dim(x)) namespace { class DecomposeAtenStdDimOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenStdDimOp op, PatternRewriter &rewriter) const override { Value self = op.getSelf(); BaseTensorType inputTensorType = self.getType().cast(); if (!inputTensorType.hasDtype() || !inputTensorType.getDtype().isa()) { return rewriter.notifyMatchFailure( op, "aten.std.dim expects input tensor of floating-point type"); } Value varDim = rewriter.create( op->getLoc(), op.getType(), self, op.getDim(), op.getUnbiased(), op.getKeepdim()); rewriter.replaceOpWithNewOp(op, op.getType(), varDim); return success(); } }; } // namespace // Decompose aten.std.correction to sqrt(var.correction(x)) 
namespace {
// Only matches when the input tensor has a dtype that passes the check below
// (the failure message states it must be floating-point). The variance is
// created with the original op's dim / correction / keepdim operands and
// result type, and then replaces the op through one further op.
class DecomposeAtenStdCorrectionOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenStdCorrectionOp op,
                                PatternRewriter &rewriter) const override {
    Value self = op.getSelf();
    BaseTensorType inputTensorType = self.getType().cast();
    if (!inputTensorType.hasDtype() ||
        !inputTensorType.getDtype().isa()) {
      return rewriter.notifyMatchFailure(
          op,
          "aten.std.correction expects input tensor of floating-point type");
    }

    Value varCorrection = rewriter.create(op->getLoc(), op.getType(), self,
                                          op.getDim(), op.getCorrection(),
                                          op.getKeepdim());
    rewriter.replaceOpWithNewOp(op, op.getType(), varCorrection);
    return success();
  }
};
} // namespace

// Hardsigmoid(x) = max(0, min(1, (x+3)/6))
namespace {
class DecomposeAtenHardsigmoidOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenHardsigmoidOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.getSelf();
    BaseTensorType inputType = input.getType().cast();

    // NOTE(review): `resType` is only used for this dtype check; every
    // intermediate below is typed with `inputType`.
    auto resType = op.getType().cast();
    if (!resType.hasDtype()) {
      return rewriter.notifyMatchFailure(op, "result should have dtype");
    }

    // outputTensor = (input + 3) / 6.
    Value constantOne = rewriter.create(loc, rewriter.getI64IntegerAttr(1));
    Value constantThree = rewriter.create(loc, rewriter.getI64IntegerAttr(3));
    Value constantSix = rewriter.create(loc, rewriter.getI64IntegerAttr(6));
    Value inputPlusThree = rewriter.create(loc, inputType, input,
                                           constantThree,
                                           /*alpha=*/constantOne);
    Value outputTensor =
        rewriter.create(loc, inputType, inputPlusThree, constantSix);

    // result = max(0, min(1, (input+3)/6))
    // The scalar bounds 1 and 0 are first turned into rank-0 tensors so they
    // can be passed as tensor operands.
    Value constantZero = rewriter.create(loc, rewriter.getI64IntegerAttr(0));
    Value oneTensor = createRank0Tensor(rewriter, loc, inputType, constantOne);
    Value minResult =
        rewriter.create(loc, inputType, oneTensor, outputTensor);
    Value zeroTensor =
        createRank0Tensor(rewriter, loc, inputType, constantZero);
    rewriter.replaceOpWithNewOp(op, op.getType(), zeroTensor, minResult);
    return success();
  }
};
} // namespace

namespace {
// Hardtanh(x, minVal, maxVal) = min(maxVal, max(minVal, x)), with both bounds
// taken from the op and materialized as rank-0 tensors.
class DecomposeAtenHardtanhOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenHardtanhOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.getSelf();
    BaseTensorType inputType = input.getType().cast();

    // NOTE(review): as in the hardsigmoid pattern, `resType` is only used for
    // the dtype check.
    auto resType = op.getType().cast();
    if (!resType.hasDtype()) {
      return rewriter.notifyMatchFailure(op, "result should have dtype");
    }

    // result = min(maxVal, max(minVal, x))
    Value minVal = createRank0Tensor(rewriter, loc, inputType, op.getMinVal());
    Value maxResult = rewriter.create(loc, inputType, input, minVal);
    Value maxVal = createRank0Tensor(rewriter, loc, inputType, op.getMaxVal());
    rewriter.replaceOpWithNewOp(op, op.getType(), maxVal, maxResult);
    return success();
  }
};
} // namespace

namespace {
// Rewrites AtenRandLikeOp as a tensor shaped like the input that is then
// filled from a uniform distribution over [0.0, 1.0). Only matches when the
// input dtype passes the check below (the failure message states
// floating-point only).
class DecomposeAtenRandLikeOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenRandLikeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.getSelf();
    Type resultType = op.getType();
    auto inputType = input.getType().cast();
    if (!inputType.hasDtype() || !inputType.getDtype().isa()) {
      return rewriter.notifyMatchFailure(op,
                                         "only support floating-point type");
    }

    // Create a uniform random op with low and high set to 0.0 and 1.0,
    // respectively.
    Value none = rewriter.create(loc);
    Value zero = rewriter.create(loc, rewriter.getF64FloatAttr(0.0));
    Value one = rewriter.create(loc, rewriter.getF64FloatAttr(1.0));
    // The placeholder tensor forwards the original op's dtype, layout, device,
    // pin-memory and memory-format operands; `zero` is passed as its value
    // operand.
    Value emptyTensor = rewriter.create(
        loc, resultType, input, zero, op.getDtype(), op.getLayout(),
        op.getDevice(), op.getPinMemory(), op.getMemoryFormat());
    rewriter.replaceOpWithNewOp(op, resultType, emptyTensor,
                                /*from=*/zero, /*to=*/one,
                                /*generator=*/none);
    return success();
  }
};
} // namespace

namespace {
// Bernoulli(x, p) = (randLike(float(x)) < p).cast(type(x)). Here,
// 1. p must be a float tensor.
// 2. The shape of p should be broadcastable to the shape of x.
// 3. Bernoulli(x, p) returns a tensor of the same type as that of x.
//
// Parameters:
//   rewriter - used to create the new ops and to report match failures.
//   op       - only used as the anchor for match-failure diagnostics.
//   loc      - location attached to every created op.
//   input    - tensor x; supplies the result shape and dtype.
//   prob     - tensor p of probabilities.
//   output   - out-parameter; assigned only on the success path.
// Returns failure (before creating any op) when either tensor lacks sizes or
// a dtype, or when `prob` fails the float-dtype check.
static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter,
                                              Operation *op, Location loc,
                                              Value input, Value prob,
                                              Value &output) {
  auto inputType = input.getType().cast();
  auto probType = prob.getType().cast();
  // Both the `input` and `prob` must be ranked tensors.
  if (!inputType.hasSizes() || !inputType.hasDtype() ||
      !probType.hasSizes() || !probType.hasDtype()) {
    return rewriter.notifyMatchFailure(
        op, "can't decompose bernoulli like ops without sizes or dtype");
  }
  // The `prob` is expected to be a float type tensor.
  if (!probType.getDtype().isa()) {
    return rewriter.notifyMatchFailure(
        op, "probabilities must be a float type tensor");
  }

  // Since the `aten.randLike` op expects float-type operand, create a
  // float-type tensor with the same shape as that of the `input`.
  Value floatTensor =
      convertTensorToDtype(rewriter, loc, input, rewriter.getF64Type());
  Value none = rewriter.create(loc);
  Value randomVal = rewriter.create(
      loc, floatTensor.getType(), floatTensor, /*dtype=*/none,
      /*layout=*/none, /*device=*/none, /*pinMemory=*/none,
      /*memoryFormat=*/none);

  // Bernoulli(x, p) = randLike(float(x)) < p.
  auto boolResType =
      inputType.getWithSizesAndDtype(inputType.getSizes(),
                                     rewriter.getI1Type());
  Value lessThanP = rewriter.create(loc, boolResType, randomVal, prob);

  // As the `output` is expected to be of the `input` type, convert the boolean
  // tensor `lessThanP` to a `input` type tensor.
  output = convertTensorToDtype(rewriter, loc, lessThanP, inputType.getDtype());
  return success();
}

// aten.bernoulli(x) = randLike(x) < x. Here, the input x is a tensor
// containing probabilities to be used for drawing the binary random number.
// The same value is therefore passed as both `input` and `prob` to the helper.
class DecomposeAtenBernoulliOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenBernoulliOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.getSelf();
    if (!op.getGenerator().getType().isa())
      return rewriter.notifyMatchFailure(
          op, "The generator has to be None because only global default "
              "generator is supported");
    Value output;
    if (failed(
            decomposeBernoulliLikeOp(rewriter, op, loc, input, input, output)))
      return rewriter.notifyMatchFailure(
          op, "decomposeBernoulliLikeOp failed to decompose the op");
    rewriter.replaceOp(op, output);
    return success();
  }
};

// aten.bernoulli.float(x, p) = (randLike(float(x)) < tensor(p)).cast(type(x)).
// Since the input x can be an integer tensor, it's important to cast it to
// float type before passing it to the `aten.randLike` op.
template class DecomposeAtenBernoulliLikeOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(BernoulliLikeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.getSelf();
    Value p = op.getP();
    if (!op.getGenerator().getType().template isa())
      return rewriter.notifyMatchFailure(
          op, "The generator has to be None because only global default "
              "generator is supported");

    // Wrap the scalar probability `p` in a tensor typed with an empty size
    // list and f64 element type so it can be handed to the tensor-based
    // helper.
    auto inputType = input.getType().cast();
    SmallVector empty;
    Type tensorType = inputType.getWithSizesAndDtype(llvm::ArrayRef(empty),
                                                     rewriter.getF64Type());
    Value prob = rewriter.create(loc, tensorType, p);
    Value output;
    // NOTE(review): `prob` has already been created when the helper can still
    // fail, so this failure path returns after new IR was inserted. Confirm
    // this is acceptable for the pattern driver in use.
    if (failed(
            decomposeBernoulliLikeOp(rewriter, op, loc, input, prob, output)))
      return rewriter.notifyMatchFailure(
          op, "decomposeBernoulliLikeOp failed to decompose the op");
    rewriter.replaceOp(op, output);
    return success();
  }
};

// aten.bernoulli.Tensor(x, p) = (randLike(float(x)) < p).cast(type(x)).
// Since the input x can be an integer tensor, it's important to cast it to
// float type before passing it to the `aten.randLike` op.
// Here `p` is already a tensor operand, so it is forwarded to the helper
// unchanged.
class DecomposeAtenBernoulliTensorOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenBernoulliTensorOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.getSelf();
    Value prob = op.getP();
    if (!op.getGenerator().getType().isa())
      return rewriter.notifyMatchFailure(
          op, "The generator has to be None because only global default "
              "generator is supported");
    Value output;
    if (failed(
            decomposeBernoulliLikeOp(rewriter, op, loc, input, prob, output)))
      return rewriter.notifyMatchFailure(
          op, "decomposeBernoulliLikeOp failed to decompose the op");
    rewriter.replaceOp(op, output);
    return success();
  }
};
} // namespace

namespace {
// Decompose exponential() to do inverse transform sampling.
// - https://en.wikipedia.org/wiki/Inverse_transform_sampling // With the exponential distribution, F(x) = 1 - exp(-lambda * x). Thus, // exponential() = - ln(1 - uniform(0, 1)) / lambda. class DecomposeAtenExponentialOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenExponentialOp op, PatternRewriter &rewriter) const override { if (!op.getGenerator().getType().isa()) return rewriter.notifyMatchFailure( op, "The generator has to be None because only global default " "generator is supported"); Location loc = op.getLoc(); Type resultType = op.getType(); // Create a uniform random op with low and high set to 0.0 and 1.0, // respectively. Value none = rewriter.create(loc); Value zero = rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); Value one = rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); Value emptyTensor = rewriter.create( loc, resultType, op.getSelf(), zero, /*dtype=*/none, /*layout=*/none, /*device=*/none, /*pin_memoty=*/none, /*memory_format=*/none); Value x = rewriter.create(loc, resultType, emptyTensor, /*from=*/zero, /*to=*/one, /*generator=*/none); Value negX = rewriter.create(loc, resultType, x); Value oneMinusX = rewriter.create(loc, resultType, negX, one, /*alpha=*/one); Value lnOneMinusX = rewriter.create(loc, resultType, oneMinusX); Value negLambda = rewriter.create(loc, op.getLambd()); rewriter.replaceOpWithNewOp(op, resultType, lnOneMinusX, negLambda); return success(); } }; // aten.normal_functional(mean, sigma) = randn() * sigma + mean. 
// Only matches when the generator operand passes the check below (the failure
// message states it has to be None).
class DecomposeAtenNormalFunctionalOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenNormalFunctionalOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.getGenerator().getType().isa())
      return rewriter.notifyMatchFailure(
          op, "The generator has to be None because only global default "
              "generator is supported");

    Location loc = op.getLoc();
    Type resultType = op.getType();
    Value std = op.getStd();
    Value mean = op.getMean();

    Value none = rewriter.create(loc);
    Value one = rewriter.create(loc, rewriter.getF64FloatAttr(1.0));
    // Random tensor shaped like `self`; all optional tensor options are None.
    Value randN = rewriter.create(
        loc, resultType, op.getSelf(), /*dtype=*/none, /*layout=*/none,
        /*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
    // randN * std, then + mean (alpha = 1.0).
    Value stdRandN = rewriter.create(loc, resultType, randN, std);
    rewriter.replaceOpWithNewOp(op, resultType, stdRandN, mean,
                                /*alpha=*/one);
    return success();
  }
};

// Shared pattern for ops that expose getSelf / getTensor1 / getTensor2 /
// getValue. An intermediate `product` is created from tensor1 and tensor2, and
// the op is then replaced by an op taking (input, product, value).
template class DecomposeAtenAddCLikeOp : public OpRewritePattern {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.getSelf();
    Value tensor1 = op.getTensor1();
    Value tensor2 = op.getTensor2();
    Value value = op.getValue();

    Value product = rewriter.create(loc, op.getType(), tensor1, tensor2);
    rewriter.replaceOpWithNewOp(op, op.getType(), input, product, value);
    return success();
  }
};

// Rewrites AtenLayerNormOp into a three-result op built from the same
// operands, and keeps only its first result. The two extra result types are
// derived from the input shape with every normalized dimension set to 1.
class DecomposeAtenLayerNormOp : public OpRewritePattern {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenLayerNormOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    auto input = op.getInput().getType().cast();
    if (!input.hasSizes())
      return rewriter.notifyMatchFailure(
          op, "input tensor should have known sizes.");
    int64_t inputRank = input.getSizes().size();
    Value normalizedShape = op.getNormalizedShape();
    SmallVector normalizedShapeSizesTorchInt;
    // NOTE(review): any status returned by getListConstructElements is
    // ignored. If the list cannot be extracted the vector stays empty and
    // `axis` becomes `inputRank` - confirm whether that case should fail the
    // match instead.
    getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt);
    // Leading dimensions [0, axis) are kept; trailing ones are normalized.
    int64_t axis = inputRank - normalizedShapeSizesTorchInt.size();
    std::vector meanVarSizes(inputRank, 1);
    for (int i = 0; i < axis; i++)
      meanVarSizes[i] = input.getSizes()[i];
    auto meanVarType = input.getWithSizesAndDtype(
        llvm::ArrayRef(meanVarSizes), input.getOptionalDtype());
    auto nativeLayerNorm = rewriter.create(
        loc, op.getType(), meanVarType, meanVarType, op.getInput(),
        op.getNormalizedShape(), op.getWeight(), op.getBias(), op.getEps());
    rewriter.replaceOp(op, nativeLayerNorm.getResult(0));
    return success();
  }
};
} // namespace

namespace {
// Rewrites AtenInstanceNormOp as
//   (x - mean(x)) * rsqrt(var(x) + eps) * weight + bias
// where mean and var are reduced over every dimension from index 2 onwards
// (keepdim is true) and weight / bias are unsqueezed to the input rank.
class DecomposeAtenInstanceNormOp : public OpRewritePattern {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenInstanceNormOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    auto context = op.getContext();

    auto inputTy = op.getInput().getType().cast();
    // NOTE(review): getSizes() is used without a hasSizes() guard, unlike the
    // layer-norm patterns in this file. Confirm unsized inputs cannot reach
    // this pattern.
    int64_t inputRank = inputTy.getSizes().size();
    SmallVector reducedShape(inputTy.getSizes());
    // NOTE(review): `reduceDimInts` is declared but never used.
    SmallVector reduceDimInts;
    SmallVector reduceDimVals;
    // Dimensions 2..rank-1 are reduced: their extent becomes 1 in the reduced
    // type and their index is recorded as a constant for the dim list.
    for (int i = 2; i < inputRank; ++i) {
      reducedShape[i] = 1;
      reduceDimVals.push_back(
          rewriter.create(loc, rewriter.getI64IntegerAttr(i)));
    }

    Type dtype = inputTy.getOptionalDtype();
    Type reducedTy = ValueTensorType::get(op.getContext(),
                                          llvm::ArrayRef(reducedShape), dtype);

    auto sizeListType = ListType::get(IntType::get(context));
    Value reduceDimList =
        rewriter.create(loc, sizeListType, reduceDimVals);
    Value cstTrue = rewriter.create(loc, true);
    Value none = rewriter.create(loc);

    Value one = rewriter.create(loc, rewriter.getI64IntegerAttr(1));

    // mean(x)
    Value inputMean = rewriter.create(loc, reducedTy, op.getInput(),
                                      reduceDimList, cstTrue, none);

    // x - mean(x)
    Value inputMeanExpanded =
        rewriter.create(loc, inputTy, inputMean, op.getInput());
    Value inputSubMean = rewriter.create(loc, inputTy, op.getInput(),
                                         inputMeanExpanded, one);
    // (x - mean(x))^2
    Value inputSubMeanSquare =
        rewriter.create(loc, inputTy, inputSubMean, inputSubMean);

    Value variancesum =
        rewriter.create(loc, reducedTy, inputSubMeanSquare, reduceDimList,
                        cstTrue, /*dtype=*/none);
    // Number of reduced elements, computed from the static type.
    // NOTE(review): if any of these sizes is the unknown-size marker used
    // elsewhere in this file, the product is not a real element count -
    // confirm dynamic shapes are excluded.
    int64_t elemCount = 1;
    for (int i = 2; i < inputRank; ++i)
      elemCount *= inputTy.getSizes()[i];

    Value hw = rewriter.create(loc, rewriter.getI64IntegerAttr(elemCount));
    Value inputVar = rewriter.create(loc, reducedTy, variancesum, hw);

    // rsqrt(var(x) + eps)
    Value inputVarPlusEps =
        rewriter.create(loc, reducedTy, inputVar, op.getEps(), one);
    Value inputRsqrtVar = rewriter.create(loc, reducedTy, inputVarPlusEps);

    // (x - mean(x)) * rsqrt(var(x) + eps)
    Value inputRsqrtVarExpanded =
        rewriter.create(loc, inputTy, inputRsqrtVar, op.getInput());
    Value inputNormalized =
        rewriter.create(loc, inputTy, inputSubMean, inputRsqrtVarExpanded);
    Value out =
        rewriter.create(loc, op.getResult().getType(), inputNormalized);

    // Reshape weight: prepend one unit dimension at index 0, then append unit
    // dimensions until its rank equals the input rank.
    // NOTE(review): weight and bias are cast to a tensor type unconditionally,
    // whereas other norm patterns in this file test them first. Confirm they
    // can never be None here.
    Value weight = op.getWeight();
    auto weightTy = weight.getType().cast();
    dtype = weightTy.getOptionalDtype();

    SmallVector weightShape(weightTy.getSizes());
    SmallVector newWeightShape;
    newWeightShape.push_back(1);
    newWeightShape.append(weightShape);

    Value zero = rewriter.create(loc, rewriter.getI64IntegerAttr(0));
    Type newWeightTy = ValueTensorType::get(
        op.getContext(), llvm::ArrayRef(newWeightShape), dtype);
    weight = rewriter.create(loc, newWeightTy, weight, zero);

    while (static_cast(newWeightShape.size()) < inputRank) {
      // The new unit dimension is inserted at the current end of the shape.
      Value i = rewriter.create(
          loc, rewriter.getI64IntegerAttr(newWeightShape.size()));
      newWeightShape.push_back(1);
      newWeightTy = ValueTensorType::get(op.getContext(),
                                         llvm::ArrayRef(newWeightShape), dtype);
      weight = rewriter.create(loc, newWeightTy, weight, i);
    }

    Value weightExpanded =
        rewriter.create(loc, inputTy, weight, op.getInput());

    // Same reshaping for bias.
    Value bias = op.getBias();
    auto biasTy = bias.getType().cast();
    dtype = biasTy.getOptionalDtype();

    SmallVector biasShape(biasTy.getSizes());
    SmallVector newBiasShape;
    newBiasShape.push_back(1);
    newBiasShape.append(biasShape);

    Type newBiasTy = ValueTensorType::get(op.getContext(),
                                          llvm::ArrayRef(newBiasShape), dtype);
    bias = rewriter.create(loc, newBiasTy, bias, zero);

    while (static_cast(newBiasShape.size()) < inputRank) {
      Value i = rewriter.create(
          loc, rewriter.getI64IntegerAttr(newBiasShape.size()));
      newBiasShape.push_back(1);
      newBiasTy = ValueTensorType::get(op.getContext(),
                                       llvm::ArrayRef(newBiasShape), dtype);
      bias = rewriter.create(loc, newBiasTy, bias, i);
    }

    Value biasExpanded = rewriter.create(loc, inputTy, bias, op.getInput());

    // out * weight + bias
    out = rewriter.create(loc, out.getType(), out, weightExpanded);
    out = rewriter.create(loc, out.getType(), out, biasExpanded, one);

    rewriter.replaceOp(op, out);
    return success();
  }
};
} // namespace

namespace {
// Rewrites AtenNativeLayerNormOp into its element-wise definition and replaces
// its three results with (normalized output, mean, rsqrt(var + eps)). The
// reduction runs over the trailing dimensions covered by `normalized_shape`,
// with keepdim set to true.
class DecomposeAtenNativeLayerNormOp : public OpRewritePattern {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenNativeLayerNormOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto context = op.getContext();

    auto inputTy = op.getInput().getType().cast();
    if (!inputTy.hasSizes())
      return rewriter.notifyMatchFailure(
          op, "input tensor should have known sizes.");
    int64_t inputRank = inputTy.getSizes().size();
    Value normalizedShape = op.getNormalizedShape();
    SmallVector normalizedShapeSizesTorchInt;
    // NOTE(review): any status returned by getListConstructElements is
    // ignored; with an empty vector `axis` equals `inputRank` and the reduce
    // dim list is empty. Confirm whether that should fail the match.
    getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt);
    int64_t axis = inputRank - normalizedShapeSizesTorchInt.size();
    // Dimensions [axis, inputRank) are reduced.
    auto reduceDimInts = llvm::to_vector<4>(llvm::seq(axis, inputRank));
    // The mean / variance intermediates use the type of the op's second
    // result.
    auto reducedTy = op.getResult(1).getType();
    auto sizeListType = ListType::get(IntType::get(context));

    // build reduce dims
    SmallVector reduceDimVals;
    reduceDimVals.reserve(reduceDimInts.size());
    std::transform(reduceDimInts.begin(), reduceDimInts.end(),
                   std::back_inserter(reduceDimVals), [&](int64_t d) {
                     return rewriter.create(loc,
                                            rewriter.getI64IntegerAttr(d));
                   });
    Value reduceDimList =
        rewriter.create(loc, sizeListType, reduceDimVals);
    Value one = rewriter.create(loc, rewriter.getI64IntegerAttr(1));

    Value cstTrue = rewriter.create(loc, true);
    Value none = rewriter.create(loc);
    // mean(x)
    Value inputMean = rewriter.create(loc, reducedTy, op.getInput(),
                                      reduceDimList, cstTrue, none);

    // x - mean(x)
    Value inputMeanExpanded =
        rewriter.create(loc, inputTy, inputMean, op.getInput());
    Value inputZeroMean = rewriter.create(loc, inputTy, op.getInput(),
                                          inputMeanExpanded, one);
    // var(x) = mean((x - mean(x))^2)
    Value inputZeroMeanSquare =
        rewriter.create(loc, inputTy, inputZeroMean, inputZeroMean);
    Value inputVar = rewriter.create(loc, reducedTy, inputZeroMeanSquare,
                                     reduceDimList, cstTrue, none);

    // rsqrt(var(x) + eps)
    Value inputVarPlusEps =
        rewriter.create(loc, reducedTy, inputVar, op.getEps(), one);
    Value inputRsqrtVar = rewriter.create(loc, reducedTy, inputVarPlusEps);

    // (x - mean(x)) * rsqrt(var(x) + eps)
    Value inputRsqrtVarExpanded =
        rewriter.create(loc, inputTy, inputRsqrtVar, op.getInput());
    Value inputNormalized =
        rewriter.create(loc, inputTy, inputZeroMean, inputRsqrtVarExpanded);
    Value out =
        rewriter.create(loc, op.getResult(0).getType(), inputNormalized);

    // weight and bias are optional; each is applied only when it passes the
    // type test below.
    Value weight = op.getWeight();
    Value bias = op.getBias();
    if (!weight.getType().isa()) {
      out = rewriter.create(loc, out.getType(), out, weight);
    }
    if (!bias.getType().isa()) {
      out = rewriter.create(loc, out.getType(), out, bias, one);
    }
    rewriter.replaceOp(op, {out, inputMean, inputRsqrtVar});

    return success();
  }
};
} // namespace

namespace {
// Decompose `aten.emptyLike` op into `aten.size` and `aten.empty` ops.
class DecomposeAtenEmptyLikeOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenEmptyLikeOp op,
                                PatternRewriter &rewriter) const override {
    auto sizeListType =
        Torch::ListType::get(Torch::IntType::get(op.getContext()));
    // The size list of `self` drives the allocation; all tensor options of the
    // original op are forwarded unchanged.
    Value sizeList =
        rewriter.create(op.getLoc(), sizeListType, op.getSelf());
    rewriter.replaceOpWithNewOp(
        op, op.getType(), sizeList, op.getDtype(), op.getLayout(),
        op.getDevice(), op.getPinMemory(), op.getMemoryFormat());
    return success();
  }
};
} // namespace

namespace {
// The `aten.arange` op is converted to `aten.arange.startStep` op.
class DecomposeAtenArangeOp : public OpRewritePattern {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenArangeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // The AtenArangeOp doesn't have a start and step value. Therefore we set
    // them as default values 0 and 1, respectively.
    Value start, step;
    start = rewriter.create(loc, rewriter.getI64IntegerAttr(0));
    step = rewriter.create(loc, rewriter.getI64IntegerAttr(1));
    rewriter.replaceOpWithNewOp(
        op, op.getType(), start, op.getEnd(), step, op.getDtype(),
        op.getLayout(), op.getDevice(), op.getPinMemory());
    return success();
  }
};
} // namespace

namespace {
// The `aten.arange.start` op is converted to `aten.arange.startStep` op.
class DecomposeAtenArangeStartOp : public OpRewritePattern {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenArangeStartOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // The AtenArangeStartOp doesn't have a step value. Therefore we set it as
    // default value 1.
    Value step;
    step = rewriter.create(loc, rewriter.getI64IntegerAttr(1));
    rewriter.replaceOpWithNewOp(
        op, op.getType(), op.getStart(), op.getEnd(), step, op.getDtype(),
        op.getLayout(), op.getDevice(), op.getPinMemory());
    return success();
  }
};
} // namespace

namespace {
// Decompose constant tensor full like ops.
// The op is replaced by one that takes `self`, an integer constant built from
// `fillVal`, and the original tensor options.
// NOTE(review): `fillVal` is not declared inside this class body; presumably
// it is a template parameter - confirm against the template header.
template class DecomposeConstantTensorAllocLikeOp : public OpRewritePattern {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value constVal =
        rewriter.create(loc, rewriter.getI64IntegerAttr(fillVal));
    rewriter.replaceOpWithNewOp(
        op, op.getType(), op.getSelf(), constVal, op.getDtype(),
        op.getLayout(), op.getDevice(), op.getPinMemory(),
        op.getMemoryFormat());
    return success();
  }
};
} // namespace

namespace {
// Rewrites AtenGroupNormOp into AtenNativeGroupNormOp, computing the extra
// N, C and HxW operands from the input at runtime. Only the first result of
// the new op replaces the original op; the two additional results are typed
// with the least-static tensor type.
class DecomposeAtenGroupNormOp : public OpRewritePattern {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenGroupNormOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    MLIRContext *context = op.getContext();

    Value input = op.getInput();
    Value weight = op.getWeight();
    Value bias = op.getBias();
    Value numGroups = op.getNumGroups();
    Value eps = op.getEps();

    Value cstZero = rewriter.create(loc, rewriter.getI64IntegerAttr(0));
    Value cstOne = rewriter.create(loc, rewriter.getI64IntegerAttr(1));
    auto baseType = ValueTensorType::getWithLeastStaticInformation(context);

    // N and C are derived from the input with indices 0 and 1; HxW is the
    // total element count divided first by N and then by C.
    Value N = rewriter.create(loc, input, cstZero);
    Value C = rewriter.create(loc, input, cstOne);
    Value numElements = rewriter.create(loc, input);
    Value numElementsDivN = rewriter.create(loc, numElements, N);
    Value HxW = rewriter.create(loc, numElementsDivN, C);

    AtenNativeGroupNormOp newOp = rewriter.create(
        loc, ArrayRef{op.getResult().getType(), baseType, baseType}, input,
        weight, bias, N, C, HxW, numGroups, eps);

    rewriter.replaceOp(op, newOp.getResult0());
    return success();
  }
};
} // namespace

namespace {
// Rewrites AtenNativeGroupNormOp: the input is reshaped to
// (N, numGroups, -1), normalized along the last dimension, reshaped back to
// the input shape, and then scaled / shifted by the optional weight and bias.
// The three results are replaced by (output, squeezed mean, squeezed
// rsqrt(var + eps)).
class DecomposeAtenNativeGroupNormOp : public OpRewritePattern {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenNativeGroupNormOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    MLIRContext *context = op.getContext();

    Value input = op.getInput();
    Value weight = op.getWeight();
    Value bias = op.getBias();
    Value numGroups = op.getGroup();
    Value eps = op.getEps();

    // Check the rank of the input/outputs tensor.
    // NOTE(review): `outputType` is only used in this guard.
    auto inputType = input.getType().cast();
    auto outputType = op.getResult0().getType().cast();
    auto meanType = op.getResult1().getType().cast();
    auto rsqrtVarType = op.getResult2().getType().cast();
    if (!inputType.hasSizes() || !outputType.hasSizes() ||
        !meanType.hasSizes() || !rsqrtVarType.hasSizes()) {
      return rewriter.notifyMatchFailure(
          op, "input/outputs tensor should have known sizes.");
    }

    Value none = rewriter.create(loc);
    Value cstZero = rewriter.create(loc, rewriter.getI64IntegerAttr(0));
    Value cstOne = rewriter.create(loc, rewriter.getI64IntegerAttr(1));
    // ("Negtive" is a misspelling of "Negative" in the identifier.)
    Value cstNegtiveOne =
        rewriter.create(loc, rewriter.getI64IntegerAttr(-1));
    Value cstTrue = rewriter.create(loc, true);
    Value cstFalse = rewriter.create(loc, false);
    auto baseType = ValueTensorType::getWithLeastStaticInformation(context);

    // GroupNorm requires the channel dimension (C) to be exactly divisible by
    // the number of groups.
    Value channel = rewriter.create(loc, input, cstOne);
    Value remainder = rewriter.create(loc, channel, numGroups);
    Value eqOrNot = rewriter.create(loc, remainder, cstZero);
    rewriter.create(
        loc, eqOrNot,
        rewriter.getStringAttr("the number of channels must be divisible by "
                               "the number of groups"));

    // Reshape the input tensor to (N, numGroups, -1) to apply normalization.
    SmallVector newShape;
    newShape.push_back(rewriter.create(loc, input, cstZero));
    newShape.push_back(numGroups);
    newShape.push_back(cstNegtiveOne);
    Value reshapedInput = rewriter.create(
        loc, baseType, input,
        rewriter.create(loc, Torch::ListType::get(IntType::get(context)),
                        newShape));

    // Now we proceed with the normalization steps across the 'groupSize'
    // Compute the mean and variance for each group
    Value dimList = rewriter.create(
        loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
        ArrayRef{cstNegtiveOne});
    auto mean = rewriter.create(loc, baseType, reshapedInput,
                                /*dims=*/dimList, /*keepdim=*/cstTrue,
                                /*dtype=*/none);
    auto var = rewriter.create(loc, baseType, reshapedInput,
                               /*dims=*/dimList, /*unbiased=*/cstFalse,
                               /*keepdim=*/cstTrue);

    // Compute the normalized output: (input - mean) * rsqrt(var + eps)
    auto varPlusEps =
        rewriter.create(loc, baseType, var, eps, /*alpha=*/cstOne);
    auto invStd = rewriter.create(loc, baseType, varPlusEps);
    auto inputSubMean = rewriter.create(loc, baseType, reshapedInput, mean,
                                        /*alpha=*/cstOne);
    auto normalizedOutput =
        rewriter.create(loc, baseType, inputSubMean, invStd);

    // Reshape normalized output back to the original input shape
    auto inputShape = rewriter.create(
        loc, Torch::ListType::get(IntType::get(context)), input);
    auto reshapedOutput = rewriter.create(loc, inputType, normalizedOutput,
                                          /*shape=*/inputShape);

    // Apply weight and bias if they are not None
    // Reshape weight and bias to C,1,1,...
    SmallVector viewShape = {channel};
    for (unsigned i = 2; i < inputType.getSizes().size(); i++) {
      viewShape.push_back(cstOne);
    }
    Value viewShapeSizeList = rewriter.create(
        loc, ListType::get(IntType::get(context)), viewShape);

    Value groupNormOutput = reshapedOutput;
    if (!weight.getType().isa()) {
      auto weightReshaped = rewriter.create(loc, baseType, weight,
                                            /*shape=*/viewShapeSizeList);
      groupNormOutput =
          rewriter.create(loc, inputType, groupNormOutput, weightReshaped);
    }
    if (!bias.getType().isa()) {
      auto biasReshaped = rewriter.create(loc, baseType, bias,
                                          /*shape=*/viewShapeSizeList);
      groupNormOutput = rewriter.create(loc, inputType, groupNormOutput,
                                        biasReshaped, /*alpha=*/cstOne);
    }

    // mean and invStd were computed with keepdim; drop the trailing dimension
    // again so they match the declared result types.
    Value squeezedMean =
        rewriter.create(loc, meanType, mean, cstNegtiveOne);
    Value squeezedRsqrtVar =
        rewriter.create(loc, rsqrtVarType, invStd, cstNegtiveOne);

    rewriter.replaceOp(
        op, ArrayRef{groupNormOutput, squeezedMean, squeezedRsqrtVar});

    return success();
  }
};
} // namespace

namespace {
// Rewrites AtenNativeBatchNormOp for inference mode only:
//   output = (input - runningMean) * rsqrt(runningVar + eps) * weight + bias
// with weight and bias applied only when present. The second and third
// results are replaced by tensors created from a size list holding a single 0.
class DecomposeAtenNativeBatchNormOp : public OpRewritePattern {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenNativeBatchNormOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    MLIRContext *context = op.getContext();
    Value input = op.getInput();
    Value weight = op.getWeight();
    Value bias = op.getBias();
    Value runningMean = op.getRunningMean();
    Value runningVar = op.getRunningVar();
    Value eps = op.getEps();

    // TODO: Add support for `training` mode.
    // The match also fails when `training` is not a compile-time constant.
    bool training = false;
    if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training)) ||
        training)
      return rewriter.notifyMatchFailure(
          op, "unimplemented: training mode is not supported");

    // Rank of the input tensor must be greater than or equal to 2. The shape of
    // the `input` is supposed to be (N, C, D?, H?, W?).
    std::optional maybeInputRank = getTensorRank(input);
    if (!maybeInputRank || *maybeInputRank < 2)
      return rewriter.notifyMatchFailure(
          op, "input must have rank greater than or equal to 2");
    unsigned inputRank = *maybeInputRank;

    // In the inference mode, the `runningMean` and `runningVar` must not be
    // None.
    if (runningMean.getType().isa() || runningVar.getType().isa())
      return rewriter.notifyMatchFailure(
          op, "running stats must not be None in inference mode");

    // Rank of `runningMean` and `runningVar` must be exactly 1.
    std::optional runningMeanRank = getTensorRank(runningMean);
    std::optional runningVarRank = getTensorRank(runningVar);
    if (!runningMeanRank || !runningVarRank || *runningMeanRank != 1 ||
        *runningVarRank != 1)
      return rewriter.notifyMatchFailure(
          op, "expected runningMean and runningVar to be rank 1");

    Value zero = rewriter.create(loc, rewriter.getI64IntegerAttr(0));
    Value one = rewriter.create(loc, rewriter.getI64IntegerAttr(1));
    Value numFeatures = rewriter.create(loc, input, /*dim=*/one);
    // TODO: Add Runtime Asserts to check the shape of weight, bias,
    // runningMean and runningVar to be (numFeatures).

    // The `runningMean` and `runningVar` must be reshaped to (1, C, 1?, 1?, 1?)
    // to make it broadcast-compatible with (N, C, D?, H?, W?).
    // 1. runningMean = runningMean.view(1, C, 1?, 1?, 1?)
    // 2. runningVar = runningVar.view(1, C, 1?, 1?, 1?)
    SmallVector runningStatsShape(inputRank, one);
    runningStatsShape[1] = numFeatures;
    Value runningStatsSizeList = rewriter.create(
        loc, ListType::get(IntType::get(context)), runningStatsShape);

    // Static counterpart of the shape above; the channel extent is taken from
    // the (rank-1) runningMean type.
    SmallVector runningStatsShapeInt(inputRank, 1);
    runningStatsShapeInt[1] = runningMean.getType().cast().getSizes()[0];
    Type dtype = input.getType().cast().getOptionalDtype();
    Type reshapeType = ValueTensorType::get(
        context, llvm::ArrayRef(runningStatsShapeInt), dtype);

    runningMean = rewriter.create(loc, reshapeType, runningMean,
                                  runningStatsSizeList);
    runningVar = rewriter.create(loc, reshapeType, runningVar,
                                 runningStatsSizeList);

    // normalizedInput = (input - runningMean) / (sqrt(runningVar + eps)).
    Value inputSubMean = rewriter.create(loc, input.getType(), input,
                                         runningMean, /*alpha=*/one);
    Value varEps = rewriter.create(loc, runningVar.getType(), runningVar, eps,
                                   /*alpha=*/one);
    Value invStd = rewriter.create(loc, varEps.getType(), varEps);
    Value normalizedInput = rewriter.create(loc, inputSubMean.getType(),
                                            inputSubMean, invStd);

    // The `weight` and `bias` must be reshaped to (1, C, 1?, 1?, 1?) to make it
    // broadcast-compatible with (N, C, D?, H?, W?).
    // 1. weight = weight.view(1, C, 1?, 1?, 1?)
    // 2. bias = bias.view(1, C, 1?, 1?, 1?)
    // 3. output = normalizedInput * weight + bias
    // NOTE(review): the two rank checks below return a match failure after
    // ops have already been created above. Confirm this is acceptable for the
    // pattern driver, or hoist the checks before the first create call.
    Value batchNormOutput = normalizedInput;
    if (!weight.getType().isa()) {
      // Rank of `weight` must be exactly 1.
      std::optional weightRank = getTensorRank(weight);
      if (!weightRank || *weightRank != 1)
        return rewriter.notifyMatchFailure(op, "expected weight to be rank 1");
      weight = rewriter.create(loc, reshapeType, weight,
                               runningStatsSizeList);
      batchNormOutput = rewriter.create(loc, batchNormOutput.getType(),
                                        batchNormOutput, weight);
    }
    if (!bias.getType().isa()) {
      // Rank of `bias` must be exactly 1.
      std::optional biasRank = getTensorRank(bias);
      if (!biasRank || *biasRank != 1)
        return rewriter.notifyMatchFailure(op, "expected bias to be rank 1");
      bias = rewriter.create(loc, reshapeType, bias, runningStatsSizeList);
      batchNormOutput =
          rewriter.create(loc, batchNormOutput.getType(), batchNormOutput,
                          bias, /*alpha=*/one);
    }

    // The `mean` and `invstd` outputs are empty tensors in inference mode.
    Value zeroList =
        rewriter.create(loc, Torch::ListType::get(zero.getType()), zero);
    Value none = rewriter.create(loc);
    Value emptyMeanTensor = rewriter.create(
        loc, op.getType(1), zeroList, /*dtype=*/none, /*layout=*/none,
        /*device=*/none, /*pinMemory=*/none, /*memoryFormat=*/none);
    Value emptyInvStdTensor = rewriter.create(
        loc, op.getType(2), zeroList, /*dtype=*/none, /*layout=*/none,
        /*device=*/none, /*pinMemory=*/none, /*memoryFormat=*/none);

    rewriter.replaceOp(op,
                       {batchNormOutput, emptyMeanTensor, emptyInvStdTensor});
    return success();
  }
};
} // namespace

// Decompose `Aten_UnsafeViewOp` into `AtenViewOp`. UnsafeView() differs from
// view() in that the returned tensor isn't treated as a view for the purposes
// of automatic differentiation. It's only safe to use if the `self` tensor is
// temporary. For example, the viewed tensor here (a + b) is discarded
// immediately after viewing:
//
//  res = UnsafeView(a + b, size);
//
// This is a hack because in-place operations on tensors treated like views
// can be much more expensive than the same operations on non-view tensors.
// Refer to // https://github.com/pytorch/pytorch/blob/364055b2771ecf9b54f1d67a8bf44bb5496476d4/aten/src/ATen/native/TensorShape.cpp#L2072 namespace { class DecomposeAten_UnsafeViewOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Aten_UnsafeViewOp op, PatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), op.getSize()); return success(); } }; } // namespace // In PyTorch, ReshapeAlias just uses an already computed stride. // See // https://github.com/pytorch/pytorch/blob/d8c31a819d4a65e732b5901e3b994e1869851f1a/aten/src/ATen/native/TensorShape.cpp#L1153 // Note that this is the same decomposition as in AOTAutograd // https://github.com/pytorch/functorch/blob/a3042d94e616d4143813668b1372d9d4545be14e/functorch/Src/aotAutograd.py#L104 namespace { class DecomposeAten_ReshapeAliasOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Aten_ReshapeAliasOp op, PatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), op.getSize()); return success(); } }; } // namespace namespace { // Decompose constant tensor like ops. 
template class DecomposeConstantTensorNewLikeOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { Value dtype = op.getDtype(); if (dtype.getType().isa()) { BaseTensorType tensorType = op.getSelf().getType().template cast(); if (!tensorType.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected input tensor to have a dtype"); } dtype = getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype()); } rewriter.replaceOpWithNewOp(op, op.getType(), op.getSize(), dtype, op.getLayout(), op.getDevice(), op.getPinMemory()); return success(); } }; } // namespace namespace { // Decompose `aten.full` op into `aten.broadcastTo` class DecomposeAtenFullOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenFullOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); BaseTensorType outTy = op.getType().template cast(); if (!outTy.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected result type to have a dtype"); } SmallVector empty; auto dtype = getTypeForTorchType(op.getContext(), op.getFillValue().getType()); Type tensorType = outTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype); Value fillVal = rewriter.create(loc, tensorType, op.getFillValue()); fillVal = convertTensorToDtype(rewriter, loc, fillVal, outTy.getDtype()); rewriter.replaceOpWithNewOp(op, op.getType(), fillVal, op.getSize()); return success(); } }; } // namespace namespace { // Decompose `aten.linear` op into `aten.matmul` and `aten.add` ops. 
class DecomposeAtenLinearOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenLinearOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getInput(); Value weight = op.getWeight(); Value bias = op.getBias(); BaseTensorType inputType = input.getType().cast(); if (!inputType.hasSizes() || inputType.getSizes().size() < 2) return rewriter.notifyMatchFailure( op, "expected input to be rank 2 or greater"); BaseTensorType weightType = weight.getType().cast(); // `weight` must be a rank 2 matrix. if (!weightType.hasSizes() || weightType.getSizes().size() != 2) return rewriter.notifyMatchFailure(op, "expected weight to be a rank 2"); SmallVector transposeShape = llvm::to_vector(llvm::reverse(weightType.getSizes())); Type transposeType = weightType.getWithSizesAndDtype( llvm::ArrayRef(transposeShape), weightType.getOptionalDtype()); Value transposeWeight = rewriter.create(loc, transposeType, weight); Value matmul = rewriter.create(loc, op.getType(), input, transposeWeight); if (bias.getType().isa()) { rewriter.replaceOp(op, matmul); return success(); } BaseTensorType biasType = bias.getType().cast(); if (!biasType.hasSizes() || biasType.getSizes().size() != 1) return rewriter.notifyMatchFailure(op, "expected bias to be rank 1"); Value alpha = rewriter.create(loc, rewriter.getF64FloatAttr(1)); rewriter.replaceOpWithNewOp(op, op.getType(), matmul, op.getBias(), alpha); return success(); } }; } // namespace namespace { // Decompose `aten.mish` op into `aten.tanh` and `aten.softplus` ops. 
// Mish(x) = x * Tanh(Softplus(x)) class DecomposeAtenMishOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenMishOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getSelf(); Type type = op.getType(); auto inputType = input.getType().cast(); if (!inputType.hasDtype()) return rewriter.notifyMatchFailure(op, "Dtype not present"); Type dType = inputType.getDtype(); // Form default Value tensors for `beta` and `threshold` operands // of `aten.softplus` op. Value beta = getConstantWithGivenDtypeAndValue(rewriter, loc, 1.0, dType); Value threshold = getConstantWithGivenDtypeAndValue(rewriter, loc, 20.0, dType); Value softplusOp = rewriter.create(loc, type, input, beta, threshold); Value tanhOp = rewriter.create(loc, type, softplusOp); rewriter.replaceOpWithNewOp(op, type, input, tanhOp); return success(); } }; } // namespace namespace { // Decompose `aten.fullLike` op into `aten.emptyLike` and `aten.fill` ops. class DecomposeAtenFullLikeOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenFullLikeOp op, PatternRewriter &rewriter) const override { BaseTensorType outTy = op.getType().template cast(); if (!outTy.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected result type to have a dtype"); } SmallVector empty; auto dtype = getTypeForTorchType(op.getContext(), op.getFillValue().getType()); Type tensorType = outTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype); Value fillVal = rewriter.create( op.getLoc(), tensorType, op.getFillValue()); fillVal = convertTensorToDtype(rewriter, op.getLoc(), fillVal, outTy.getDtype()); rewriter.replaceOpWithNewOp(op, op.getType(), fillVal, op.getSelf()); return success(); } }; } // namespace namespace { // Decompose `aten.new_full` op into `aten.full` op. 
class DecomposeAtenNewFullOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenNewFullOp op, PatternRewriter &rewriter) const override { Value dtype = op.getDtype(); if (dtype.getType().isa()) { BaseTensorType tensorType = op.getSelf().getType().cast(); if (!tensorType.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected input tensor to have a dtype"); } dtype = getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype()); } rewriter.replaceOpWithNewOp( op, op.getType(), op.getSize(), op.getFillValue(), dtype, op.getLayout(), op.getDevice(), op.getPinMemory()); return success(); } }; } // namespace namespace { // Decompose `aten.indexPut` op into `valsem.aten.indexPutImpl` op. class DecomposeAtenIndexPutOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenIndexPutOp op, PatternRewriter &rewriter) const override { Value cstFalse = rewriter.create(op.getLoc(), false); rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(), op.getAccumulate(), /*unsafe=*/cstFalse); return success(); } }; } // namespace namespace { class DecomposeAtenExpandAsOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenExpandAsOp op, PatternRewriter &rewriter) const override { auto sizeListType = Torch::ListType::get(Torch::IntType::get(op.getContext())); Value sizeList = rewriter.create(op.getLoc(), sizeListType, op.getOther()); rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), sizeList); return success(); } }; } // namespace namespace { // Decompose `aten.ToCopy` op into `valsem.aten.copy` op. 
class DecomposeAten_ToCopyOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Aten_ToCopyOp op, PatternRewriter &rewriter) const override { auto resultType = op.getType().cast(); if (!resultType.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected result type to have a dtype"); } Type resultDtype = resultType.getDtype(); Value zero = getConstantWithGivenDtypeAndValue(rewriter, op.getLoc(), 0.0, resultDtype); Value emptyTensor = rewriter.create( op.getLoc(), op.getType(), op.getSelf(), zero, op.getDtype(), op.getLayout(), op.getDevice(), op.getPinMemory(), op.getMemoryFormat()); rewriter.replaceOpWithNewOp(op, op.getType(), emptyTensor, op.getSelf(), op.getNonBlocking()); return success(); } }; } // namespace namespace { // Decompose `aten.copy` op into `aten.to.dtype` and `aten.expand_as`. class DecomposeAtenCopyOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenCopyOp op, PatternRewriter &rewriter) const override { auto resultType = op.getType().cast(); if (!resultType.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected result type to have a dtype"); } auto srcTy = op.getSrc().getType().cast(); if (!srcTy.hasSizes() || !srcTy.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected src type to have a known rank and dtype"); } Type resultDtype = resultType.getDtype(); Value srcToDtype = convertTensorToDtype(rewriter, op.getLoc(), op.getSrc(), resultDtype); rewriter.replaceOpWithNewOp(op, op.getType(), srcToDtype, op.getSelf()); return success(); } }; } // namespace namespace { // Decompose `aten.newEmpty` op into `aten.empty.memoryFormat` op. 
class DecomposeAtenNewEmptyOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenNewEmptyOp op, PatternRewriter &rewriter) const override { Value noneVal = rewriter.create(op.getLoc()); Value dtype = op.getDtype(); if (dtype.getType().isa()) { BaseTensorType tensorType = op.getSelf().getType().cast(); if (!tensorType.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected input tensor to have a dtype"); } dtype = getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype()); } rewriter.replaceOpWithNewOp( op, op.getType(), op.getSize(), dtype, op.getLayout(), op.getDevice(), op.getPinMemory(), /*memoryFormat=*/noneVal); return success(); } }; } // namespace namespace { // Decompose `aten.indexPut.hackedTwin` op into `valsem.aten.indexPutImpl` // op. class DecomposeAtenIndexPutHackedTwinOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenIndexPutHackedTwinOp op, PatternRewriter &rewriter) const override { Value cstFalse = rewriter.create(op.getLoc(), false); rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(), op.getAccumulate(), /*unsafe=*/cstFalse); return success(); } }; } // namespace namespace { // Decompose `aten._unsafe_indexPut.hackedTwin` op into `aten._index_put_impl` // op. class DecomposeAten_UnsafeIndexPutHackedTwinOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Aten_UnsafeIndexPutHackedTwinOp op, PatternRewriter &rewriter) const override { Value cstFalse = rewriter.create(op.getLoc(), false); rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(), op.getAccumulate(), /*unsafe=*/cstFalse); return success(); } }; } // namespace namespace { // Decompose `aten.pad` op into `aten.constantPadNd` op. 
class DecomposeAtenPadOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenPadOp op, PatternRewriter &rewriter) const override { Value value = op.getValue(); if (value.getType().isa()) return rewriter.notifyMatchFailure(op, "optional type not supported"); if (value.getType().isa()) value = rewriter.create( op.getLoc(), rewriter.getF64FloatAttr(0)); rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), op.getPad(), value); return success(); } }; } // namespace namespace { // Decompose `aten.to.dtypeLayout` op into `aten.to.dtype` op. class DecomposeAtenToDtypeLayoutOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenToDtypeLayoutOp op, PatternRewriter &rewriter) const override { // TODO: Add support for pinMemory arg equal to `True`. if (!op.getPinMemory().getType().isa()) { bool pinMemory; if (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory))) return rewriter.notifyMatchFailure( op, "unimplemented: pinMemory must be a constant"); else if (pinMemory) return rewriter.notifyMatchFailure( op, "unimplemented: pinMemory is expected to be false"); } // TODO: Add support for device arg other than cpu. if (!op.getDevice().getType().isa()) { std::string device; if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device))) return rewriter.notifyMatchFailure( op, "unimplemented: device must be a constant str"); else if (device != "cpu") return rewriter.notifyMatchFailure( op, "unimplemented: device is expected to be cpu"); } // TODO: Add support for non-strided layout. // torch.layout is by default strided i.e. 0. 
if (!op.getLayout().getType().isa()) { int64_t tensorLayout; if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout))) return rewriter.notifyMatchFailure( op, "unimplemented: layout must be a constant"); else if (tensorLayout != torch_upstream::Layout::Strided) return rewriter.notifyMatchFailure( op, "unimplemented: layout is expected to be strided"); } rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), op.getDtype(), op.getNonBlocking(), op.getCopy(), op.getMemoryFormat()); return success(); } }; } // namespace namespace { // Decompose `aten.to.prim_Device` op into `aten.to.dtype` op. class DecomposeAtenToPrimDeviceOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenToPrimDeviceOp op, PatternRewriter &rewriter) const override { // Device information isn't relevant to torch-mlir, so we can drop that info // here. auto loc = op.getLoc(); Value constNone = rewriter.create(loc); Value dtype = op.getDtype(); if (dtype.getType().template isa()) { dtype = rewriter.create(loc, op.getSelf()); } rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), dtype, op.getNonBlocking(), op.getCopy(), constNone); return success(); } }; } // namespace namespace { // Decompose `aten.to.device` op into `aten.to.dtype` op. class DecomposeAtenToDeviceOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenToDeviceOp op, PatternRewriter &rewriter) const override { // Device information isn't relevant to torch-mlir, so we can drop that info // here. rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), op.getDtype(), op.getNonBlocking(), op.getCopy(), op.getMemoryFormat()); return success(); } }; } // namespace namespace { // Decompose `aten.adaptive_avg_pool1d` op into `aten.avg_pool1d` op. 
// The logic of this decomposition is totally same with // the DecomposeAtenAdaptiveAvgPool2dOp, that means currently only following two // cases are supported: // 1. inputSize = outputSize // 2. outputSize = 1 class DecomposeAtenAdaptiveAvgPool1dOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenAdaptiveAvgPool1dOp op, PatternRewriter &rewriter) const override { Location loc = op->getLoc(); MLIRContext *context = op.getContext(); Value input = op.getSelf(); std::optional maybeRank = getTensorRank(input); if (!maybeRank) { return rewriter.notifyMatchFailure(op, "expected input to have a rank"); } unsigned rank = *maybeRank; Value sizeDim = rewriter.create( loc, rewriter.getI64IntegerAttr(rank - 1)); Value inputSize = rewriter.create(loc, input, sizeDim); Value outputShape = op.getOutputSize(); SmallVector outputShapeSizesTorchInt; getListConstructElements(outputShape, outputShapeSizesTorchInt); Value outputSize = outputShapeSizesTorchInt[0]; Value constantOne = rewriter.create( loc, rewriter.getI64IntegerAttr(1)); Value constantZero = rewriter.create( loc, rewriter.getI64IntegerAttr(0)); Value constantFalse = rewriter.create(loc, false); Value constantTrue = rewriter.create(loc, true); int64_t outputSizeInt; if (!matchPattern(outputSize, m_TorchConstantInt(&outputSizeInt))) { return rewriter.notifyMatchFailure( op, "the output size of adaptive_pool_1d must be a constant int"); } SmallVector kernelSize; if (outputSizeInt == 1) { BaseTensorType inputTensorType = input.getType().cast(); ArrayRef inputShape = inputTensorType.getSizes(); kernelSize.push_back( inputShape[rank - 1] == kUnknownSize ? 
inputSize : rewriter.create( loc, rewriter.getI64IntegerAttr(inputShape[rank - 1]))); } else { if (!isAssumingStrictSymbolicShapes(rewriter)) { Value cond = rewriter.create(loc, inputSize, outputSize); rewriter.create( loc, cond, "unimplemented: only support cases where input and output size are " "equal for non-unit output size"); } kernelSize.push_back(constantOne); } Value kernelSizeList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize); Value strideList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(context)), ValueRange{constantOne}); Value paddingSizeList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(context)), ValueRange{constantZero}); rewriter.replaceOpWithNewOp( op, op.getType(), input, kernelSizeList, strideList, paddingSizeList, /*ceil_mode=*/constantFalse, /*count_include_pad=*/constantTrue); return success(); } }; } // namespace namespace { // Decompose `aten.adaptiveAvgPool2d` op into `aten.avgPool2d` op. 
// // For AdaptiveAvgPool2d op, when the input size is an integer multiple of // output size the kernelSize, stride and padding is calculated as follows: // strideH = inH // outH // strideW = inH // outH // kernelH = inH - [(outH - 1) * strideH] = strideH // kernelW = inW - [(outW - 1) * strideW] = strideW // paddingH = 0, paddingW = 0 // class DecomposeAtenAdaptiveAvgPool2dOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenAdaptiveAvgPool2dOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); MLIRContext *context = op.getContext(); Value input = op.getSelf(); std::optional maybeRank = getTensorRank(input); if (!maybeRank) { return rewriter.notifyMatchFailure(op, "expected input to have a rank"); } unsigned rank = *maybeRank; SmallVector inputHW; Value dimH = rewriter.create( loc, rewriter.getI64IntegerAttr(rank - 2)); inputHW.push_back( /*inH=*/rewriter.create(loc, input, dimH)); Value dimW = rewriter.create( loc, rewriter.getI64IntegerAttr(rank - 1)); inputHW.push_back( /*inW=*/rewriter.create(loc, input, dimW)); Value outputShape = op.getOutputSize(); SmallVector outputShapeSizesTorchInt; getListConstructElements(outputShape, outputShapeSizesTorchInt); // TODO: Add support for cases other than: // inH % outH != 0 or inW % outW != 0 Value constantZero = rewriter.create( loc, rewriter.getI64IntegerAttr(0)); Value constantFalse = rewriter.create(loc, false); Value constantTrue = rewriter.create(loc, true); Value constantNone = rewriter.create(loc); SmallVector kernelSize; for (unsigned i = 0; i < inputHW.size(); i++) { Value remainder = rewriter.create( loc, inputHW[i], outputShapeSizesTorchInt[i]); Value cond = rewriter.create(loc, remainder, constantZero); rewriter.create(loc, cond, "unimplemented: only support cases " "input size is an integer multiple of " "output size"); Value stride = rewriter.create( loc, inputHW[i], outputShapeSizesTorchInt[i]); Value kernelSizeValue = stride; 
kernelSize.push_back(kernelSizeValue); } Value kernelSizeList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize); Value strideList = kernelSizeList; Value paddingSizeList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(context)), ValueRange{constantZero, constantZero}); rewriter.replaceOpWithNewOp( op, op.getType(), input, kernelSizeList, strideList, paddingSizeList, /*ceilMode=*/constantFalse, /*countIncludePad=*/constantTrue, /*divisorOverride=*/constantNone); return success(); } }; } // namespace namespace { // Decompose `aten.clampMin` op into `aten.clamp` op. class DecomposeAtenClampMinOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenClampMinOp op, PatternRewriter &rewriter) const override { Value constantNone = rewriter.create(op.getLoc()); rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), op.getMin(), /*max=*/constantNone); return success(); } }; } // namespace namespace { // Decompose `aten.clamp_min.Tensor` op into `aten.clamp.Tensor` op. class DecomposeAtenClampMinTensorOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenClampMinTensorOp op, PatternRewriter &rewriter) const override { Value constantNone = rewriter.create(op.getLoc()); rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), op.getMin(), /*max=*/constantNone); return success(); } }; } // namespace namespace { // Decompose `aten.clampMax` op into `aten.clamp` op. 
class DecomposeAtenClampMaxOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenClampMaxOp op, PatternRewriter &rewriter) const override { Value constantNone = rewriter.create(op.getLoc()); rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), /*min=*/constantNone, op.getMax()); return success(); } }; } // namespace namespace { class DecomposeAtenCosineSimilarityOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenCosineSimilarityOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value x1 = op.getX1(); Value x2 = op.getX2(); Value dim = op.getDim(); // Broadcast x1 and x2 to the same shape SmallVector indexBroadcastShapeInt; SmallVector indexBroadcastShapeValue; computeBroadcastShape(rewriter, loc, x1, x2, indexBroadcastShapeInt, indexBroadcastShapeValue); Type dtype = x1.getType().cast().getOptionalDtype(); Type broadcastType = ValueTensorType::get( op.getContext(), llvm::ArrayRef(indexBroadcastShapeInt), dtype); Value indexBroadcastShapeTorchList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), indexBroadcastShapeValue); x1 = rewriter.create(loc, broadcastType, x1, indexBroadcastShapeTorchList); x2 = rewriter.create(loc, broadcastType, x2, indexBroadcastShapeTorchList); // Compute the mul of A and B Value dotProduct = rewriter.create(loc, broadcastType, x1, x2); Value cstFalse = rewriter.create(loc, false); Value cstNone = rewriter.create(loc); Value dimList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(op->getContext())), ValueRange{dim}); Value sumDotProduct = rewriter.create( loc, op.getType(), /*self=*/dotProduct, /*dim=*/dimList, /*keepdim=*/cstFalse, /*dtype=*/cstNone); // Compute the norm of A and B Value ord = rewriter.create( loc, rewriter.getF64FloatAttr(2.0)); Value normA = rewriter.create( loc, op.getType(), x1, ord, dimList, /*keepdim=*/cstFalse, 
/*dtype=*/cstNone); Value normB = rewriter.create( loc, op.getType(), x2, ord, dimList, /*keepdim=*/cstFalse, /*dtype=*/cstNone); // Compute the product of the norms Value normProduct = rewriter.create(loc, op.getType(), normA, normB); Value normProductClamp = rewriter.create( loc, op.getType(), normProduct, op.getEps(), /*max=*/cstNone); // Compute the final cosine similarity by division rewriter.replaceOpWithNewOp( op, op.getType(), sumDotProduct, normProductClamp); return success(); } }; } // namespace namespace { // Decompose `aten.baddbmm` op into `aten.bmm`, `aten.mul.Scalar`, and // `aten.add.Tensor` op. class DecomposeAtenBaddbmmOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenBaddbmmOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value bmm = rewriter.create(loc, op.getType(), op.getBatch1(), op.getBatch2()); Value alphaTimesBmm = rewriter.create(loc, op.getType(), bmm, op.getAlpha()); Value input = op.getSelf(); BaseTensorType inputType = input.getType().cast(); BaseTensorType resultType = op->getResult(0).getType().cast(); if (inputType.hasDtype() && resultType.hasDtype() && inputType.getDtype() != resultType.getDtype()) { input = convertTensorToDtype(rewriter, loc, input, resultType.getDtype()); } rewriter.replaceOpWithNewOp( op, op.getType(), alphaTimesBmm, op.getSelf(), op.getBeta()); return success(); } }; } // namespace namespace { // Decompose `aten.floorDivide` op into `aten.div.TensorMode` op. class DecomposeAtenFloorDivideOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenFloorDivideOp op, PatternRewriter &rewriter) const override { // https://pytorch.org/docs/stable/generated/torch.floorDivide.html // PyTorch aten.floorDivide is a misnomer because it actually rounds // the quotient towards zero instead of taking its floor. 
Value cstStrFloor = rewriter.create(op.getLoc(), "trunc"); rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), op.getOther(), /*roundingMode=*/cstStrFloor); return success(); } }; } // namespace namespace { // Decompose `aten.numpyT` op into `aten.permute` op. class DecomposeAtenNumpyTOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenNumpyTOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value self = op.getSelf(); std::optional maybeInputRank = getTensorRank(self); if (!maybeInputRank) { return rewriter.notifyMatchFailure(op, "expected input to have a rank"); } unsigned inputRank = *maybeInputRank; SmallVector dimListElements; SmallVector dimListInts(llvm::reverse( llvm::iota_range(0, inputRank, /*inclusive=*/false))); for (int dimListInt : dimListInts) { dimListElements.push_back(rewriter.create( loc, rewriter.getI64IntegerAttr(dimListInt))); } Value dimList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(op->getContext())), dimListElements); rewriter.replaceOpWithNewOp(op, op.getType(), self, dimList); return success(); } }; } // namespace template static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter, bool unbiased, double correction) { Location loc = op.getLoc(); Value self = op.getSelf(); Value dimList = op.getDim(); Value keepDim = op.getKeepdim(); BaseTensorType inputTensorTy = self.getType().cast(); Type outputType = op.getType(); BaseTensorType outputTensorType = cast(outputType); if (!outputTensorType.hasDtype()) { return rewriter.notifyMatchFailure(op, "expected result type to have a dtype"); } Type newOutputType = outputTensorType.getWithSizesAndDtype( outputTensorType.getSizes(), rewriter.getF64Type()); if (!inputTensorTy.hasDtype() || !inputTensorTy.getDtype().isa()) { return rewriter.notifyMatchFailure( op, "support floating-point type input only"); } // Upcasting the input tensor to `F64` dtype for higher precision 
during the // computation of the result. if (inputTensorTy.getDtype().getIntOrFloatBitWidth() != 64) { self = convertTensorToDtype(rewriter, loc, self, rewriter.getF64Type()); inputTensorTy = self.getType().cast(); } std::optional maybeInputRank = getTensorRank(self); if (!maybeInputRank) { return rewriter.notifyMatchFailure(op, "expected input to have a rank"); } unsigned inputRank = *maybeInputRank; SmallVector dimListElements; bool isNoneOrEmpty = true; if (!dimList.getType().template isa()) { if (!getListConstructElements(dimList, dimListElements)) return rewriter.notifyMatchFailure( op, "expect dimList to be constructed from list construct"); if (!dimListElements.empty() || inputRank == 0) isNoneOrEmpty = false; } if (isNoneOrEmpty) { for (unsigned i = 0; i < inputRank; i++) dimListElements.push_back(rewriter.create( loc, rewriter.getI64IntegerAttr(i))); dimList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), dimListElements); } Type meanDimResultType = inputTensorTy; for (unsigned i = 0; i < dimListElements.size(); i++) meanDimResultType = computeReductionType( rewriter, op, cast(meanDimResultType), dimListElements[i], /*keepDim=*/true); Value constantNone = rewriter.create(loc); Value constantTrue = rewriter.create(loc, true); Value meanAlongDims = rewriter.create( loc, meanDimResultType, self, dimList, /*keepDim=*/constantTrue, /*dtype=*/constantNone); Value subMean = createTensorSub(rewriter, loc, inputTensorTy, self, meanAlongDims); Value square = rewriter.create(loc, inputTensorTy, subMean); if (!unbiased) { Value result = rewriter.create( loc, newOutputType, square, dimList, keepDim, /*dtype=*/constantNone); result = convertTensorToDtype(rewriter, loc, result, outputTensorType.getDtype()); rewriter.replaceOp(op, result); return success(); } // Divide the square sum by productDimSize - correction. 
Value squareSum = rewriter.create( loc, newOutputType, square, dimList, keepDim, /*dtype=*/constantNone); // `productDimSize` is product of sizes of dimensions to be reduced. Value constantOne = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); Value productDimSize = constantOne; for (Value dim : dimListElements) { Value dimSize = rewriter.create(loc, self, dim); productDimSize = rewriter.create(loc, productDimSize, dimSize); } productDimSize = rewriter.create(loc, productDimSize); constantOne = rewriter.create( loc, rewriter.getF64FloatAttr(1.0)); Value cstCorrection = rewriter.create( loc, rewriter.getF64FloatAttr(correction)); // The `correction` value should be less than or equal to `productDimSize + // 1`. if (!isAssumingStrictSymbolicShapes(rewriter)) { Value productDimSizePlusOne = rewriter.create( loc, productDimSize.getType(), productDimSize, constantOne); Value cond = rewriter.create(loc, productDimSizePlusOne, cstCorrection); rewriter.create( loc, cond, "correction value should be less than or equal to productDimSize + 1"); } Value productDimSizeSubCorrection = rewriter.create(loc, productDimSize, cstCorrection); Value result = rewriter.create(loc, newOutputType, squareSum, productDimSizeSubCorrection); result = convertTensorToDtype(rewriter, loc, result, outputTensorType.getDtype()); rewriter.replaceOp(op, result); return success(); } // Decompose aten.var(x, dims) into: // sub = aten.sub(x, aten.mean(x, dims)) // square = aten.square(sub) // For Unbiased case: // out = aten.sum(square, dims) / (productDimSize-1) // For Biased case: // out = aten.mean(square, dims) namespace { class DecomposeAtenVarDimOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenVarDimOp op, PatternRewriter &rewriter) const override { bool unbiased; if (!matchPattern(op.getUnbiased(), m_TorchConstantBool(&unbiased))) { return rewriter.notifyMatchFailure( op, "Only support constant unbiased for aten.var"); } double 
correction = unbiased ? 1.0 : 0.0; if (failed(calculateVariance(op, rewriter, unbiased, correction))) return rewriter.notifyMatchFailure(op, "invalid variance parameters"); return success(); } }; } // namespace // Decompose aten.var(x, dims) into: // sub = aten.sub(x, aten.mean(x, dims)) // square = aten.square(sub) // out = aten.sum(square, dims) / (productDimSize - correction) namespace { class DecomposeAtenVarCorrectionOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenVarCorrectionOp op, PatternRewriter &rewriter) const override { int64_t correctionValInt; double correctionValFloat = 1.0; if (!op.getCorrection().getType().isa()) { if (op.getCorrection().getType().isa()) { if (!matchPattern(op.getCorrection(), m_TorchConstantFloat(&correctionValFloat))) return rewriter.notifyMatchFailure( op, "Only support constant int or float correction value for " "aten.var"); } else if (op.getCorrection().getType().isa()) { if (!matchPattern(op.getCorrection(), m_TorchConstantInt(&correctionValInt))) return rewriter.notifyMatchFailure( op, "Only support constant int or float correction value for " "aten.var"); correctionValFloat = (double)correctionValInt; } else { return rewriter.notifyMatchFailure( op, "unimplemented: correction value should be only constant int " "or float for aten.var"); } } bool unbiased = correctionValFloat == 0.0 ? false : true; if (failed(calculateVariance(op, rewriter, unbiased, correctionValFloat))) return rewriter.notifyMatchFailure(op, "invalid variance parameters"); return success(); } }; } // namespace namespace { // Decompose the `aten.selectScatter` operation into `aten.sliceScatter` op. 
// The source tensor is first unsqueezed along `dim`; the replacement op then
// receives the range [index, index + 1) with step 1 along that dimension.
class DecomposeAtenSelectScatterOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenSelectScatterOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value start = op.getIndex();
    Value dim = op.getDim();
    Value self = op.getSelf();
    Value src = op.getSrc();

    // Constant 1: used to form `start + 1` and passed as the step.
    Value one =
        rewriter.create(loc, rewriter.getI64IntegerAttr(1));
    Value startPlusOne =
        rewriter.create(loc, one.getType(), start, one);

    // Unsqueeze `src` along `dim`; fail the match if the helper fails.
    auto unsqueezedInfo = unsqueezeTensor(rewriter, op, src, dim);
    if (failed(unsqueezedInfo)) {
      return rewriter.notifyMatchFailure(op,
                                         "cannot generate unsqueeze tensor op");
    }
    src = *unsqueezedInfo;
    rewriter.replaceOpWithNewOp(
        op, op.getSelf().getType(), self, src, dim, start, startPlusOne,
        /*step=*/one);

    return success();
  }
};
} // namespace

namespace {
// Replaces `Aten_EmbeddingBagOp` with a new op that takes the same nine
// operands and produces the same four result types.
class DecomposeAten_EmbeddingBagOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(Aten_EmbeddingBagOp op,
                                PatternRewriter &rewriter) const override {
    Value weight = op.getWeight();
    Value indices = op.getIndices();
    Value offsets = op.getOffsets();
    Value scaleGradByFreq = op.getScaleGradByFreq();
    Value mode = op.getMode();
    Value sparse = op.getSparse();
    Value perSampleWeights = op.getPerSampleWeights();
    Value includeLastOffset = op.getIncludeLastOffset();
    Value paddingIdx = op.getPaddingIdx();

    // Result types are forwarded unchanged, in order.
    auto resultType0 = op->getResult(0).getType();
    auto resultType1 = op->getResult(1).getType();
    auto resultType2 = op->getResult(2).getType();
    auto resultType3 = op->getResult(3).getType();

    llvm::SmallVector returnTypes{resultType0, resultType1, resultType2,
                                  resultType3};

    rewriter.replaceOpWithNewOp(
        op, returnTypes, weight, indices, offsets, scaleGradByFreq, mode,
        sparse, perSampleWeights, includeLastOffset, paddingIdx);

    return success();
  }
};
} // namespace

namespace {
// Decompose `aten.liftFreshCopy` op into `aten.clone` op.
class DecomposeAtenLiftFreshCopyOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenLiftFreshCopyOp op, PatternRewriter &rewriter) const override { Value constantNone = rewriter.create(op.getLoc()); rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), /*memoryFormat=*/constantNone); return success(); } }; } // namespace namespace { class DecomposeAtenMseLossOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenMseLossOp op, PatternRewriter &rewriter) const override { // The `reduction` arg would have only three valid values. // 0 means no reduction. // 1 means mean reduction. // 2 means sum reduction. int64_t reductionType; if (!matchPattern(op.getReduction(), m_TorchConstantInt(&reductionType))) return rewriter.notifyMatchFailure( op, "Expected a constant integer value for reduction"); Location loc = op.getLoc(); BaseTensorType resultType = op.getType().cast(); BaseTensorType inputType = op.getSelf().getType().cast(); if (!inputType.hasSizes()) return rewriter.notifyMatchFailure( op, "Expected the input tensor to have sizes"); BaseTensorType subType = inputType .getWithSizesAndDtype(llvm::ArrayRef(inputType.getSizes()), resultType.getOptionalDtype()) .cast(); Value sub = createTensorSub(rewriter, loc, subType, op.getSelf(), op.getTarget()); Value result = rewriter.create(loc, subType, sub); if (reductionType == torch_upstream::Reduction::None) { rewriter.replaceOp(op, result); return success(); } Value cstFalse = rewriter.create(loc, false); Value cstNone = rewriter.create(loc); if (reductionType == torch_upstream::Reduction::Mean) result = rewriter.create(loc, resultType, result, /*dim=*/cstNone, /*keepdim=*/cstFalse, /*dtype=*/cstNone); else result = rewriter.create( loc, resultType, result, /*dim=*/cstNone, /*keepdim=*/cstFalse, /*dtype=*/cstNone); rewriter.replaceOp(op, result); return success(); } }; } // namespace namespace { // 
Decompose `aten.norm.ScalarOpt_dim` op to `aten.linalg_vector_norm` op class DecomposeAtenNormScalarOptDimOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenNormScalarOptDimOp op, PatternRewriter &rewriter) const override { Location loc = op->getLoc(); Value none = rewriter.create(loc); Value ord = op.getP(); if (ord.getType().isa()) { ord = rewriter.create( loc, rewriter.getF64FloatAttr(2.0)); } rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), ord, op.getDim(), op.getKeepdim(), /*dtype=*/none); return success(); } }; } // namespace namespace { class DecomposeAtenRandintLowOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenRandintLowOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Type resultType = op.getType(); BaseTensorType resultTensorType = cast(resultType); if (!resultTensorType.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected result type to have a dtype"); } int64_t cstLow, cstHigh; if (!matchPattern(op.getLow(), m_TorchConstantInt(&cstLow))) return rewriter.notifyMatchFailure( op, "unimplemented: low must be a constant integer"); if (!matchPattern(op.getHigh(), m_TorchConstantInt(&cstHigh))) return rewriter.notifyMatchFailure( op, "unimplemented: high must be a constant integer"); Value none = rewriter.create(loc); Value cstFalse = rewriter.create(loc, false); Value low = rewriter.create( loc, rewriter.getF64FloatAttr((double)cstLow)); Value high = rewriter.create( loc, rewriter.getF64FloatAttr((double)cstHigh)); BaseTensorType floatResultType = resultTensorType .getWithSizesAndDtype(resultTensorType.getSizes(), rewriter.getF32Type()) .cast(); Value emptyTensor = rewriter.create( loc, floatResultType, op.getSize(), /*dtype=*/none, /*layout=*/op.getLayout(), /*device=*/op.getDevice(), /*pinMemory=*/op.getPinMemory(), /*memoryFormat=*/none); Value result = rewriter.create(loc, 
floatResultType, emptyTensor, /*from=*/low, /*to=*/high, /*generator=*/none); rewriter.replaceOpWithNewOp( op, resultType, result, getDtypeIntValueForType(rewriter, loc, resultTensorType.getDtype()), /*nonBlocking=*/cstFalse, /*copy=*/cstFalse, /*memoryFormat=*/none); return success(); } }; } // namespace namespace { class DecomposeAtenRandintOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenRandintOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Type resultType = op.getType(); Value low = rewriter.create( loc, rewriter.getI64IntegerAttr(0)); rewriter.replaceOpWithNewOp( op, resultType, low, op.getHigh(), op.getSize(), op.getDtype(), op.getLayout(), op.getDevice(), op.getPinMemory()); return success(); } }; } // namespace namespace { // Decompose `aten.varMean.correction` op into `aten.var.correction` and // `aten.mean.dim` op. class DecomposeAtenVarMeanCorrectionOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenVarMeanCorrectionOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value noneVal = rewriter.create(loc); Value var = rewriter.create( loc, op.getType(0), op.getSelf(), op.getDim(), op.getCorrection(), op.getKeepdim()); Value mean = rewriter.create( loc, op.getType(0), op.getSelf(), op.getDim(), op.getKeepdim(), /*dtype=*/noneVal); rewriter.replaceOp(op, {var, mean}); return success(); } }; } // namespace namespace { // Decompose `prims.convertElementType` op into `aten.to.dtype` op. 
// The input `a` and the requested `dtype` are forwarded; non_blocking and copy
// are the constant `false`, and the memory format is None.
class DecomposePrimsConvertElementTypeOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(PrimsConvertElementTypeOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    Value cstFalse = rewriter.create(loc, false);
    Value cstNone = rewriter.create(loc);
    rewriter.replaceOpWithNewOp(
        op, op.getType(), op.getA(), op.getDtype(), /*nonBlocking=*/cstFalse,
        /*copy=*/cstFalse, /*memoryFormat=*/cstNone);
    return success();
  }
};
} // namespace

namespace {
// Decompose `prims.var` op into `aten.var.correction` op.
// The match fails when the type check on `output_dtype` below does not hold;
// the replacement always passes keepdim = false.
class DecomposePrimsVarOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(PrimsVarOp op,
                                PatternRewriter &rewriter) const override {
    // Per the failure message, only a None output dtype is handled.
    if (!op.getOutputDtype().getType().isa())
      return rewriter.notifyMatchFailure(
          op, "Unimplemented non-None dtype for prims::var op");
    Value cstFalse = rewriter.create(op.getLoc(), false);
    rewriter.replaceOpWithNewOp(
        op, op.getType(), op.getInp(), op.getDims(), op.getCorrection(),
        /*keepdim=*/cstFalse);
    return success();
  }
};
} // namespace

namespace {
// Decompose `prims.sqrt` op into `aten.sqrt` op.
// One-to-one replacement: same result type, same single operand.
class DecomposePrimsSqrtOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(PrimsSqrtOp op,
                                PatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf());
    return success();
  }
};
} // namespace

namespace {
// The op is decomposed using the Box-Muller transform.
// Refer: https://en.wikipedia.org/wiki/Box-Muller_transform class DecomposeAtenRandnGeneratorOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenRandnGeneratorOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); auto resultType = op.getType().cast(); if (!resultType.hasDtype()) { return rewriter.notifyMatchFailure( op, "expected result type to have a dtype"); } Value dtype = getDtypeIntValueForType(rewriter, loc, resultType.getDtype()); Value none = rewriter.create(loc); Value low = rewriter.create( loc, rewriter.getF64FloatAttr((double)0.0)); Value high = rewriter.create( loc, rewriter.getF64FloatAttr((double)1.0)); Value cstMinusTwo = rewriter.create( loc, rewriter.getF64FloatAttr((double)-2.0)); Value cstTwoPie = rewriter.create( loc, rewriter.getF64FloatAttr((double)(2.0 * 3.14159))); Value emptyTensorA = rewriter.create( loc, resultType, op.getSize(), /*dtype=*/dtype, /*layout=*/op.getLayout(), /*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(), /*memory_format=*/none); Value emptyTensorB = rewriter.create( loc, resultType, op.getSize(), /*dtype=*/dtype, /*layout=*/op.getLayout(), /*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(), /*memory_format=*/none); Value uOne = rewriter.create(loc, resultType, emptyTensorA, /*from=*/low, /*to=*/high, /*generator=*/op.getGenerator()); Value uTwo = rewriter.create(loc, resultType, emptyTensorB, /*from=*/low, /*to=*/high, /*generator=*/op.getGenerator()); Value logUOne = rewriter.create(loc, resultType, uOne); Value minusTwoLogUOne = rewriter.create(loc, resultType, logUOne, cstMinusTwo); Value r = rewriter.create(loc, resultType, minusTwoLogUOne); Value theta = rewriter.create(loc, resultType, uTwo, cstTwoPie); Value cosTheta = rewriter.create(loc, resultType, theta); rewriter.replaceOpWithNewOp(op, op.getType(), r, cosTheta); return success(); } }; } // namespace namespace { // Decompose `aten.randn` op into 
`aten.randn.generator` op. class DecomposeAtenRandnOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenRandnOp op, PatternRewriter &rewriter) const override { Value none = rewriter.create(op.getLoc()); rewriter.replaceOpWithNewOp( op, op.getType(), op.getSize(), /*generator=*/none, op.getDtype(), op.getLayout(), op.getDevice(), op.getPinMemory()); return success(); } }; } // namespace namespace { // Decompose `aten.randn_like` op into `aten.randn.generator` op. class DecomposeAtenRandnLikeOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenRandnLikeOp op, PatternRewriter &rewriter) const override { // Only `none`, `contiguous` and `preserve` memory_format is supported. if (!op.getMemoryFormat().getType().isa()) { int64_t memoryFormat; if (!matchPattern(op.getMemoryFormat(), m_TorchConstantInt(&memoryFormat))) return rewriter.notifyMatchFailure( op, "unimplemented: the memory format should be specified in " "an integer constant"); if (memoryFormat != torch_upstream::MemoryFormat::Contiguous && memoryFormat != torch_upstream::MemoryFormat::Preserve) return rewriter.notifyMatchFailure( op, "unimplemented: only none, contiguous and preserve " "memory_format is supported"); } Value none = rewriter.create(op.getLoc()); auto sizeListType = Torch::ListType::get(Torch::IntType::get(op.getContext())); Value sizeList = rewriter.create(op.getLoc(), sizeListType, op.getSelf()); rewriter.replaceOpWithNewOp( op, op.getType(), sizeList, /*generator=*/none, op.getDtype(), op.getLayout(), op.getDevice(), op.getPinMemory()); return success(); } }; } // namespace namespace { class DecomposeAtenRandOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenRandOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); auto resultType = op.getType().cast(); if (!resultType.hasDtype()) { 
return rewriter.notifyMatchFailure( op, "expected result type to have a dtype"); } Value noneVal = rewriter.create(loc); Value low = rewriter.create( loc, rewriter.getF64FloatAttr((double)0.0)); Value high = rewriter.create( loc, rewriter.getF64FloatAttr((double)1.0)); Value emptyTensor = rewriter.create( loc, resultType, op.getSize(), /*dtype=*/op.getDtype(), /*layout=*/op.getLayout(), /*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(), /*memory_format=*/noneVal); rewriter.replaceOpWithNewOp(op, resultType, emptyTensor, /*from=*/low, /*to=*/high, /*generator=*/noneVal); return success(); } }; } // namespace namespace { class DecomposeAtenLinspaceOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenLinspaceOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); MLIRContext *context = getContext(); auto baseType = ValueTensorType::getWithLeastStaticInformation(context); Value none = rewriter.create(loc); Value falseVal = rewriter.create(loc, false); Value zero = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); Value one = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); Value addStart; int64_t steps; if (matchPattern(op.getSteps(), m_TorchConstantInt(&steps)) && steps == 1) { // specically handle steps == 1 Value arange = rewriter.create( loc, baseType, zero, op.getSteps(), /*dtype=*/none, op.getLayout(), op.getDevice(), op.getPinMemory()); addStart = rewriter.create(loc, baseType, arange, op.getStart(), one); } else { // handle steps != 1 or dynamic steps Value neOrNot = rewriter.create(loc, op.getSteps(), one); rewriter.create( loc, neOrNot, rewriter.getStringAttr("linspace's dynamic steps must not be 1")); // create arange: [0, ..., steps - 1] Value arange = rewriter.create( loc, baseType, zero, op.getSteps(), /*dtype=*/none, op.getLayout(), op.getDevice(), op.getPinMemory()); // calculate (end - start) / (steps - 1) Value sub; if (op.getEnd().getType().isa() || 
op.getStart().getType().isa()) { sub = rewriter.create(loc, Torch::FloatType::get(context), op.getEnd(), op.getStart()); } else { sub = rewriter.create(loc, op.getEnd(), op.getStart()); } Value div = rewriter.create( loc, sub, rewriter.create(loc, op.getSteps(), one)); // calculate [0, ..., steps - 1] * ((end - start) / (steps - 1)) + start Value mulScalar = rewriter.create(loc, baseType, arange, div); addStart = rewriter.create(loc, baseType, mulScalar, op.getStart(), one); } // to dtype Value result; if (!op.getDtype().getType().isa()) { result = rewriter.create( loc, op.getType(), addStart, op.getDtype(), /*non_blocking=*/falseVal, /*copy=*/falseVal, /*memory_format=*/none); } else { Value f32Type = rewriter.create( loc, (int)torch_upstream::ScalarType::Float); result = rewriter.create( loc, op.getType(), addStart, f32Type, /*non_blocking=*/falseVal, /*copy=*/falseVal, /*memory_format=*/none); } rewriter.replaceOp(op, result); return success(); } }; } // namespace namespace { class DecomposeAtenVarMeanOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenVarMeanOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value falseVal = rewriter.create(loc, false); Value noneVal = rewriter.create(loc); Value var = rewriter.create(loc, op.getType(0), op.getSelf(), /*dim=*/noneVal, op.getUnbiased(), /*keepdim=*/falseVal); Value mean = rewriter.create(loc, op.getType(0), op.getSelf(), /*dtype=*/noneVal); rewriter.replaceOp(op, {var, mean}); return success(); } }; } // namespace namespace { class DecomposeAtenNewEmptyStridedOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenNewEmptyStridedOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value opSize = op.getSize(); Value opStride = op.getStride(); if (failed(checkDefaultStrideHelper(op, rewriter, opSize, opStride, loc))) return 
rewriter.notifyMatchFailure( op, "Unable to determine if stride is default"); rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), op.getSize(), op.getDtype(), op.getLayout(), op.getDevice(), op.getPinMemory()); return success(); } }; } // namespace namespace { class DecomposeAtenEmptyStridedOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenEmptyStridedOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value opSize = op.getSize(); Value opStride = op.getStride(); if (failed(checkDefaultStrideHelper(op, rewriter, opSize, opStride, loc))) return rewriter.notifyMatchFailure( op, "Unable to determine if stride is default"); Value noneVal = rewriter.create(op.getLoc()); rewriter.replaceOpWithNewOp( op, op.getType(), op.getSize(), op.getDtype(), op.getLayout(), op.getDevice(), op.getPinMemory(), /*memoryFormat=*/noneVal); return success(); } }; } // namespace namespace { class DecomposePrimsSqueezeOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(PrimsSqueezeOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getA(); SmallVector dimensions; if (!matchPattern(op.getDimensions(), m_TorchListOfConstantInts(dimensions))) return rewriter.notifyMatchFailure( op, "all dimensions must be constant ints"); std::sort(dimensions.rbegin(), dimensions.rend()); if (dimensions.size() == 0) { rewriter.replaceOp(op, input); return success(); } Value result = input; for (unsigned i = 0; i < dimensions.size(); i++) { auto squeezeTensorInfo = squeezeTensor(rewriter, op, loc, dimensions[i], result); if (failed(squeezeTensorInfo)) { return rewriter.notifyMatchFailure(op, "cannot generate unsqueeze tensor"); } result = *squeezeTensorInfo; } rewriter.replaceOp(op, result); return success(); } }; } // namespace namespace { class DecomposeAtenMovedimIntOp : public OpRewritePattern { public: 
using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenMovedimIntOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getSelf(); std::optional maybeInputRank = getTensorRank(input); if (!maybeInputRank) { return rewriter.notifyMatchFailure( op, "expected input tensor to have a rank"); } unsigned inputRank = *maybeInputRank; if (inputRank <= 1) { rewriter.replaceOp(op, input); return success(); } int64_t srcDimInt, dstDimInt; if (matchPattern(op.getSource(), m_TorchConstantInt(&srcDimInt))) { srcDimInt = toPositiveDim(srcDimInt, inputRank); if (!isValidDim(srcDimInt, inputRank)) return rewriter.notifyMatchFailure(op, "source is not a valid dim"); } else { return rewriter.notifyMatchFailure(op, "source is not a constant int"); } if (matchPattern(op.getDestination(), m_TorchConstantInt(&dstDimInt))) { dstDimInt = toPositiveDim(dstDimInt, inputRank); if (!isValidDim(dstDimInt, inputRank)) return rewriter.notifyMatchFailure(op, "destination is not a valid dim"); } else { return rewriter.notifyMatchFailure(op, "destination is not a constant int"); } SmallVector dimsOrder = computeDimsOrderForMoveDim(srcDimInt, dstDimInt, inputRank); SmallVector cstDimsOrder; for (int64_t dim : dimsOrder) cstDimsOrder.push_back(rewriter.create( loc, rewriter.getI64IntegerAttr(dim))); Value permuteDimsOrder = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(op->getContext())), cstDimsOrder); rewriter.replaceOpWithNewOp(op, op.getType(), input, permuteDimsOrder); return success(); } }; } // namespace namespace { class DecomposeAtenCrossEntropyLossOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenCrossEntropyLossOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value self = op.getSelf(); Value target = op.getTarget(); std::optional maybeRank = getTensorRank(self); if (!maybeRank) return rewriter.notifyMatchFailure( op, 
"Unimplemented: unranked input tensor"); unsigned selfRank = maybeRank.value(); maybeRank = getTensorRank(target); if (!maybeRank) return rewriter.notifyMatchFailure( op, "Unimplemented: unranked target tensor"); unsigned targetRank = maybeRank.value(); // When the input is 2-d i.e. of the form [minibatch, C] and target is 1-d // of the form [minibatch] the cross entropy loss decomposes to the // combination of softmax and nll loss as follows: // cross_entropy_loss = NLLLoss(LogSoftmax(input, dim=1), target) // Currently, we only support the above-mentioned case. if (selfRank != 2 || targetRank != 1) { return rewriter.notifyMatchFailure( op, "unimplemented: only support cases with 2-d input and 1-d target"); } // TODO: Add support for label_smoothing value other than 0.0 (default // value). double labelSmoothing; if (!matchPattern(op.getLabelSmoothing(), m_TorchConstantFloat(&labelSmoothing))) { return rewriter.notifyMatchFailure( op, "Only support constant float label_smoothing value"); } else if (labelSmoothing != 0.0) { return rewriter.notifyMatchFailure(op, "unimplemented: only support default " "value of 0.0 for label_smoothing"); } Value noneVal = rewriter.create(loc); Value dim = rewriter.create( loc, rewriter.getI64IntegerAttr(1)); Value logSoftmax = rewriter.create( loc, self.getType(), self, dim, /*dtype=*/noneVal); Value nllLoss = rewriter .create( loc, op.getType(), target.getType(), logSoftmax, target, op.getWeight(), op.getReduction(), op.getIgnoreIndex()) ->getResult(0); rewriter.replaceOp(op, nllLoss); return success(); } }; } // namespace namespace { class DecomposeAtenOneHotOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenOneHotOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); auto context = op.getContext(); Value input = op.getSelf(); auto inputType = input.getType().cast(); if (!inputType.hasSizes()) return rewriter.notifyMatchFailure( op, "input tensor should 
have known sizes."); int64_t inputRank = inputType.getSizes().size(); int64_t numClasses; if (!matchPattern(op.getNumClasses(), m_TorchConstantInt(&numClasses))) return rewriter.notifyMatchFailure( op, "unimplemented: num_classes must be constant"); Value none = rewriter.create(loc); // arange tensor auto si64Type = IntegerType::get(context, 64, IntegerType::Signed); auto arangeType = ValueTensorType::get(context, llvm::ArrayRef(numClasses), si64Type); Value arangeTensor = rewriter.create( loc, arangeType, op.getNumClasses(), /*dtype=*/none, /*layout=*/none, /*device=*/none, /*pin_memory=*/none); // unsqueeze input llvm::SmallVector unsqueezeShape(inputType.getSizes()); unsqueezeShape.push_back(1); auto unsqueezeType = ValueTensorType::get(context, unsqueezeShape, si64Type); Value unsqueezeTensor = rewriter.create( loc, unsqueezeType, input, rewriter.create(loc, rewriter.getI64IntegerAttr(inputRank))); // compare auto eqType = ValueTensorType::get( context, op.getType().cast().getSizes(), IntegerType::get(context, 1)); Value eqTensor = rewriter.create( loc, eqType, unsqueezeTensor, arangeTensor); // convert to si64 Value result = convertTensorToDtype(rewriter, loc, eqTensor, si64Type); rewriter.replaceOp(op, result); return success(); } }; } // namespace namespace { // Decompose `aten.var_mean.dim` op into `aten.var.dim` and // `aten.mean.dim` op. 
class DecomposeAtenVarMeanDimOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenVarMeanDimOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value noneVal = rewriter.create(loc); Value var = rewriter.create(loc, op.getType(0), op.getSelf(), op.getDim(), op.getUnbiased(), op.getKeepdim()); Value mean = rewriter.create( loc, op.getType(0), op.getSelf(), op.getDim(), op.getKeepdim(), /*dtype=*/noneVal); rewriter.replaceOp(op, {var, mean}); return success(); } }; } // namespace namespace { // decompose aten.scalar_tensor to prim.NumToTensor.Scalar and // aten.to.dtype_layout class DecomposeAtenScalarTensor : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenScalarTensorOp op, PatternRewriter &rewriter) const override { auto resultTy = op.getResult().getType().cast(); auto scalarTy = getBuiltInTypeForTorchScalar(op.getS().getType()); Value numToTensor = rewriter.create( op.getLoc(), resultTy.getWithSizesAndDtype(resultTy.getOptionalSizes(), scalarTy), op.getS()); Value cstNone = rewriter.create(op.getLoc()); Value cstFalse = rewriter.create(op.getLoc(), false); Value dtype = getDtypeIntValueForType(rewriter, op.getLoc(), resultTy.getDtype()); Value toDTypeLayout = rewriter.create( op.getLoc(), op.getType(), numToTensor, dtype, op.getLayout(), op.getDevice(), op.getPinMemory(), /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/cstNone); rewriter.replaceOp(op, toDTypeLayout); return success(); } }; } // namespace namespace { // Decompose `aten.topk` op into `aten.sort` and `aten.slice.Tensor` op. 
// The input is sorted along `dim` (descending when `largest` is set), then the
// range [0, k) with step 1 is taken along `dim` from both the sorted values
// and the sorted indices. `sorted` must be the constant `true`.
class DecomposeAtenTopkOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenTopkOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto context = op.getContext();

    bool sorted;
    if (!matchPattern(op.getSorted(), m_TorchConstantBool(&sorted)))
      return rewriter.notifyMatchFailure(
          op, "Expected a constant boolean value for sorted");
    if (!sorted)
      return rewriter.notifyMatchFailure(
          op, "unimplemented: sorted value arg must be set to True");

    Value self = op.getSelf();
    Value dim = op.getDim();
    auto selfType = self.getType().cast();
    // Indices share the input's (optional) sizes with a signed 64-bit dtype.
    auto sortIndicesType = selfType.getWithSizesAndDtype(
        selfType.getOptionalSizes(),
        IntegerType::get(context, 64, IntegerType::Signed));
    auto sortOpResult = rewriter.create(
        loc, self.getType(), sortIndicesType, self, dim,
        /*descending=*/op.getLargest());
    Value start = rewriter.create(
        loc, rewriter.getI64IntegerAttr(0));
    Value step = rewriter.create(
        loc, rewriter.getI64IntegerAttr(1));
    // Result 0 of the sort holds the values, result 1 the indices.
    Value resultValue = rewriter.create(
        loc, op->getResultTypes()[0], sortOpResult->getResult(0), dim, start,
        /*end=*/op.getK(), step);
    Value resultIndices = rewriter.create(
        loc, op->getResultTypes()[1], sortOpResult->getResult(1), dim, start,
        /*end=*/op.getK(), step);
    rewriter.replaceOp(op, {resultValue, resultIndices});
    return success();
  }
};
} // namespace

namespace {
// Decompose `aten.scatter.value` op into `aten.scatter.src` op.
// NOTE(review): the angle-bracket template arguments in this region (on
// `OpRewritePattern`, `std::optional`, `SmallVector`, `rewriter.create`,
// `replaceOpWithNewOp`, `.cast()`, `.isa()`) are absent from the text as
// received; restore them from the original file before building.
//
// Materializes a `src` tensor that has the sizes of `index` and the dtype of
// `self`, initialized from the scalar `value`, and replaces the op with one
// that takes (self, dim, index, src).
class DecomposeAtenScatterValueOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenScatterValueOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    MLIRContext *context = op.getContext();
    Value self = op.getSelf();
    Value index = op.getIndex();
    std::optional maybeIndexRank = getTensorRank(index);
    if (!maybeIndexRank) {
      return rewriter.notifyMatchFailure(
          op, "expected index tensor to have a rank");
    }
    unsigned indexRank = *maybeIndexRank;
    // One size value per dimension of `index`, collected into a list below.
    SmallVector sizes;
    for (int64_t i = 0; i < indexRank; ++i) {
      Value dim =
          rewriter.create(loc, rewriter.getI64IntegerAttr(i));
      sizes.push_back(rewriter.create(loc, index, /*dim=*/dim));
    }
    Value sizeList = rewriter.create(
        loc, ListType::get(IntType::get(context)), sizes);

    auto selfType = self.getType().cast();
    auto indexType = index.getType().cast();
    // `src` type: sizes taken from `index`, dtype taken from `self`.
    BaseTensorType srcType =
        selfType
            .getWithSizesAndDtype(indexType.getOptionalSizes(),
                                  selfType.getOptionalDtype())
            .cast();
    Value src =
        createInitTensor(rewriter, loc, srcType, op.getValue(), sizeList);
    rewriter.replaceOpWithNewOp(op, op.getType(), self,
                                op.getDim(), index, src);
    return success();
  }
};
} // namespace

namespace {
// Decompose `aten.sgn` op into comparisons and aten.where.
//
// Requires the result type to carry a dtype and rejects complex dtypes.
class DecomposeAtenSgnOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenSgnOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto outType = op.getType().cast();
    if (!outType.hasDtype()) {
      return rewriter.notifyMatchFailure(op,
                                         "expected result type to have dtype");
    }
    // TODO: support complex type in future.
    if (outType.getDtype().isa()) {
      return rewriter.notifyMatchFailure(op,
                                         "doesn't support complex type now");
    }

    // Integer constants 0, 1 and -1 used as comparison operand / selections.
    auto zero =
        rewriter.create(loc, rewriter.getI64IntegerAttr(0));
    auto one =
        rewriter.create(loc, rewriter.getI64IntegerAttr(1));
    auto minusOne =
        rewriter.create(loc, rewriter.getI64IntegerAttr(-1));

    // Comparison results keep the output's sizes with an i1 element type.
    auto compTy = outType.getWithSizesAndDtype(outType.getOptionalSizes(),
                                               rewriter.getI1Type());

    auto greater =
        rewriter.create(loc, compTy, op.getSelf(), zero);
    auto less =
        rewriter.create(loc, compTy, op.getSelf(), zero);

    // Pseudo code:
    // if (in > 0)
    //   return 1
    // else if (in < 0)
    //   return -1
    // else
    //   return 0
    // note: return 0 if nan/0.0/-0.0
    //       return 1 if inf
    //       return -1 if -inf
    auto selectGreater =
        rewriter.create(loc, outType, greater, one, zero);

    rewriter.replaceOpWithNewOp(op, outType, less,
                                minusOne, selectGreater);
    return success();
  }
};
} // namespace

namespace {
// Unconditionally decompose `torch.type_as` into `prim.dtype` +
// `torch.to.dtype`.
//
// The target dtype is read from `other`; `non_blocking` and `copy` are both
// constant false and `memory_format` is none.
class DecomposeAtenTypeAsOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenTypeAsOp op,
                                PatternRewriter &rewriter) const override {
    auto input = op.getSelf();
    auto other = op.getOther();
    Location loc = op.getLoc();

    Value targetDtype = rewriter.create(loc, other);
    Value nonBlocking = rewriter.create(loc, false);
    Value copy = rewriter.create(loc, false);
    Value memoryFormat = rewriter.create(loc);
    rewriter.replaceOpWithNewOp(
        op, op.getType(), input, targetDtype, nonBlocking, copy, memoryFormat);
    return success();
  }
};
} // namespace

// AtenIndexTensorOp
namespace {
// The goal of this pattern is to eliminate none index in aten.Index.Tensor's
// `indices` param for the ease of various backend. The detailed steps are:
//   1. reorder input tensor so that the non-none index appears at adjacent
//      positions.
//   2. manually generate index tensor with some ops like iota, to replace the
//      none index in `indices`
//   3.
//      replace the old aten.Index.Tensor with a new
//      aten.Index.Tensor_hacked_twin.
//
// NOTE(review): the angle-bracket template arguments in this region (on
// `OpRewritePattern`, `FailureOr`, `SmallVector`, `rewriter.create`,
// `replaceOpWithNewOp`, `.cast()`, `.isa()`) are absent from the text as
// received; restore them from the original file before building.
class DecomposeAtenIndexTensorOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;

  // TODO: It might be better to use aten.view op instead of multiple
  // aten.unsqueeze. But currently, torch-to-linalg pass has limited support for
  // view on dynamic shapes, such as [?] -> [?,1,1,1]. Using aten.view op will
  // cause relevant e2e tests fail.
  //
  // Applies `unsqueezeTensor` with dim = -1 to `input` `count` times and
  // returns the final value, or failure as soon as one application fails.
  static FailureOr unsqueezeTensorAtTrailingDim(Operation *op,
                                                PatternRewriter &rewriter,
                                                Value input, int count) {
    Location loc = op->getLoc();
    Value constMinusOne = rewriter.create(
        loc, rewriter.getI64IntegerAttr(-1));
    Value result = input;
    while (count--) {
      auto unsqzTensorInfo =
          unsqueezeTensor(rewriter, op, result, /*dim=*/constMinusOne);
      if (failed(unsqzTensorInfo)) {
        return failure();
      }

      result = *unsqzTensorInfo;
    }
    return result;
  }

  // Creates the replacement for a `none` index along dimension `dimInt` of
  // `input`: a value typed as a 1-D signed 64-bit tensor of static size
  // `dimSize`, whose `end` operand is the runtime size of `input` along
  // `dimInt` and whose layout, device and pin_memory operands are none.
  static Value createIndexToReplaceNone(Operation *op,
                                        PatternRewriter &rewriter, Value input,
                                        int dimInt, int64_t dimSize) {
    Location loc = op->getLoc();
    MLIRContext *context = op->getContext();
    Value none = rewriter.create(loc);
    auto int64Dtype = getDtypeIntValueForType(
        rewriter, loc,
        rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));

    auto resultType = ValueTensorType::get(
        context, {dimSize},
        rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
    auto dim = rewriter.create(
        loc, rewriter.getI64IntegerAttr(dimInt));
    auto end = rewriter.create(loc, input, dim);
    auto v = rewriter.create(
        loc, resultType, /*end=*/end, /*dtype=*/int64Dtype, /*layout=*/none,
        /*device=*/none, /*pin_memory=*/none);
    return v;
  }

  // Requires `indices` to be a list construct and both the input and the
  // output types to carry sizes. If every index is a tensor the op is
  // replaced directly; otherwise the input is permuted when the tensor
  // indices are not adjacent, and every `none` position gets a generated
  // index tensor.
  LogicalResult matchAndRewrite(AtenIndexTensorOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    MLIRContext *context = op.getContext();
    SmallVector indices;
    if (!getListConstructElements(op.getIndices(), indices))
      return rewriter.notifyMatchFailure(op,
                                         "failed to get elements of `indices`");

    auto input = op.getSelf();
    auto inputType = input.getType().cast();
    if (!inputType.hasSizes()) {
      return rewriter.notifyMatchFailure(
          op, "only input with shape information is supported");
    }
    auto inputSizes = inputType.getSizes();
    int64_t inputRank = inputSizes.size();
    auto outputType = op.getType().cast();
    if (!outputType.hasSizes()) {
      return rewriter.notifyMatchFailure(
          op, "only output with shape information is supported");
    }
    auto outputRank = outputType.getSizes().size();

    // Predicate on the type of an index value; used to tell tensor indices
    // apart from the others.
    auto isTensor = [](Value v) {
      return v.getType().isa();
    };

    // directly replace aten.Index.Tensor with aten.index.Tensor_hacked_twin
    if (llvm::all_of(indices, isTensor)) {
      if (indices.size() == 0) {
        return rewriter.notifyMatchFailure(
            op, "the indices is empty, it should be folded as a nop");
      }
      // By default, we regard the first index type as the list element type.
      auto indexElemType = indices[0]
                               .getType()
                               .template cast()
                               .getWithSizesAndDtype(std::nullopt, nullptr);
      auto newIndex = rewriter.create(
          loc, Torch::ListType::get(indexElemType), indices);
      rewriter.replaceOpWithNewOp(op, op.getType(),
                                  input, newIndex);
      return success();
    }

    // indexUsed[d] is true when dimension d has a tensor index. Dimensions
    // beyond the end of `indices` are padded with false.
    SmallVector indexUsed =
        llvm::to_vector(llvm::map_range(indices, isTensor));
    for (int64_t i = indices.size(); i < inputRank; ++i)
      indexUsed.emplace_back(false);
    // The used indices are consecutive unless a used index follows an unused
    // one after the first used index has already been seen.
    bool indexIsConsecutive = true;
    int64_t firstUsedIndex = -1;
    for (size_t i = 0; i < indices.size(); ++i) {
      if (indexUsed[i] && firstUsedIndex == -1) {
        firstUsedIndex = i;
      } else if (indexUsed[i] && !indexUsed[i - 1]) {
        indexIsConsecutive = false;
        break;
      }
    }

    // use aten.permute to reorder the input
    Value newInput;
    // `dims` stores the mapping from new index to the old index of input
    // tensor.
    SmallVector dims;
    if (!indexIsConsecutive) {
      SmallVector dimValues;
      SmallVector permutedSizes;
      // Used dimensions first, in their original order ...
      for (int i = 0; i < inputRank; i++) {
        if (indexUsed[i]) {
          dims.emplace_back(i);
          dimValues.emplace_back(rewriter.create(
              loc, rewriter.getI64IntegerAttr(i)));
          permutedSizes.emplace_back(inputSizes[i]);
        }
      }
      // ... followed by the unused dimensions, also in original order.
      for (int i = 0; i < inputRank; i++) {
        if (!indexUsed[i]) {
          dims.emplace_back(i);
          dimValues.emplace_back(rewriter.create(
              loc, rewriter.getI64IntegerAttr(i)));
          permutedSizes.emplace_back(inputSizes[i]);
        }
      }
      auto dimValueList = rewriter.create(
          loc, Torch::ListType::get(Torch::IntType::get(context)), dimValues);
      newInput = rewriter.create(
          loc,
          inputType.getWithSizesAndDtype(permutedSizes,
                                         inputType.getOptionalDtype()),
          input, dimValueList);
    } else {
      // No permutation needed: identity mapping.
      newInput = input;
      for (int i = 0; i < inputRank; i++) {
        dims.emplace_back(i);
      }
    }

    // manually generate new indices.
    SmallVector listElements(inputRank);

    // `i` is shared by the three loops below, which walk the (possibly
    // permuted) dimensions from last to first; each loop resumes where the
    // previous one stopped.
    int64_t trailingDimCnt = 0;
    int64_t i;
    // handle trailing none index.
    for (i = inputRank - 1; i >= 0; --i) {
      int64_t oldI = dims[i];
      if (indexUsed[oldI])
        break;
      Value v = createIndexToReplaceNone(op, rewriter, newInput, i,
                                         inputSizes[oldI]);
      auto vInfo =
          unsqueezeTensorAtTrailingDim(op, rewriter, v, trailingDimCnt);
      if (failed(vInfo)) {
        return rewriter.notifyMatchFailure(op, "failed to unsqueeze tensor");
      }
      listElements[i] = *vInfo;
      trailingDimCnt++;
    }
    // handle non-none index in between.
    for (; i >= 0; --i) {
      int64_t oldI = dims[i];
      if (!indexUsed[oldI])
        break;
      auto vInfo = unsqueezeTensorAtTrailingDim(op, rewriter, indices[oldI],
                                                trailingDimCnt);
      if (failed(vInfo)) {
        return rewriter.notifyMatchFailure(op, "failed to unsqueeze tensor");
      }
      listElements[i] = *vInfo;
    }

    // handle possible leading none dimensions.
    for (; i >= 0; --i) {
      int64_t oldI = dims[i];
      if (indexUsed[oldI]) {
        return rewriter.notifyMatchFailure(
            op, "the indices are still unconsecutive after reordering input "
                "tensor");
      }
      Value v = createIndexToReplaceNone(op, rewriter, newInput, i,
                                         inputSizes[oldI]);
      auto vInfo =
          unsqueezeTensorAtTrailingDim(op, rewriter, v, outputRank - 1 - i);
      if (failed(vInfo)) {
        return rewriter.notifyMatchFailure(op, "failed to unsqueeze tensor");
      }
      listElements[i] = *vInfo;
    }

    // The list element type is a value tensor type without sizes or dtype.
    auto listElemType = ValueTensorType::get(context, std::nullopt, nullptr);
    auto newIndexList = rewriter.create(
        loc, Torch::ListType::get(listElemType), listElements);
    rewriter.replaceOpWithNewOp(
        op, op.getType(), newInput, newIndexList);
    return success();
  }
};
} // namespace

namespace {
// Unconditionally decompose `aten.tile` into `aten.repeat`.
//
// Requires `dims` to be a list construct and the input type to carry sizes.
// When `dims` has fewer entries than the input rank, constant ones are
// inserted at the front until the lengths match.
class DecomposeAtenTileOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenTileOp op,
                                PatternRewriter &rewriter) const override {
    auto input = op.getSelf();
    auto repeats = op.getDims();
    SmallVector dimsElements;
    if (!getListConstructElements(repeats, dimsElements)) {
      return rewriter.notifyMatchFailure(
          op, "failed to get elements of `dims` param");
    }
    auto dimsSize = dimsElements.size();
    auto inputType = input.getType().cast();
    if (!inputType.hasSizes()) {
      return rewriter.notifyMatchFailure(
          op, "only support input tensor with shape information");
    }
    auto inputRank = inputType.getSizes().size();
    if (dimsSize < inputRank) {
      auto constantOne = rewriter.create(
          op.getLoc(), rewriter.getI64IntegerAttr(1));
      for (auto i = dimsSize; i < inputRank; ++i) {
        dimsElements.insert(dimsElements.begin(), constantOne);
      }
      // Rebuild the list operand from the padded elements.
      repeats = rewriter.create(
          op.getLoc(),
          Torch::ListType::get(Torch::IntType::get(op.getContext())),
          dimsElements);
    }
    rewriter.replaceOpWithNewOp(op, op.getType(), input,
                                repeats);
    return success();
  }
};
} // namespace

namespace {
// Unconditionally decompose `aten.reshape_as` into `aten.size` +
// `aten.reshape`.
//
// NOTE(review): the angle-bracket template arguments in this region (on
// `OpRewritePattern`, `SmallVector`, `rewriter.create`, `replaceOpWithNewOp`,
// `.isa()`, `template`, `std::optional`, `ArrayRef`, `std::unique_ptr`,
// `std::make_unique`, and on every `addPatternIfTargetOpIsIllegal` call) are
// absent from the text as received; restore them from the original file
// before building.
class DecomposeAtenReshapeAsOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenReshapeAsOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    MLIRContext *context = op->getContext();
    Value input = op.getSelf();
    Value other = op.getOther();

    // List-of-int value computed from `other`; used as the target shape.
    auto otherShape = rewriter.create(
        loc, Torch::ListType::get(Torch::IntType::get(context)), other);
    rewriter.replaceOpWithNewOp(op, op.getType(), input,
                                otherShape);
    return success();
  }
};
} // namespace

namespace {
// Decompose AtenLinalgNormOp to AtenLinalgVectorNormOp only
//
// Only matches when `dim` is a list construct with exactly one element.
class DecomposeAtenLinalgNormOp : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenLinalgNormOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    SmallVector dimList;
    if (!getListConstructElements(op.getDim(), dimList)) {
      return rewriter.notifyMatchFailure(
          op, "dim should comes from a PrimListConstructOp");
    }
    if (dimList.size() != 1) {
      return rewriter.notifyMatchFailure(
          op, "Unimplemented: only dim size of 1 is supported");
    }

    // default ord value is 2 for vector_norm
    auto ord = op.getOrd();
    if (ord.getType().isa()) {
      ord = rewriter.create(loc, rewriter.getI64IntegerAttr(2));
    }
    rewriter.replaceOpWithNewOp(
        op, op.getType(), op.getSelf(), ord, op.getDim(), op.getKeepdim(),
        op.getDtype());
    return success();
  }
};
} // namespace

namespace {
// Expands `AtenFakeQuantizePerTensorAffineOp` into the chain spelled out by
// the inline comments below: divide by scale, round, add zero_point, clamp to
// [quant_min, quant_max], subtract zero_point, multiply by scale.
class DecomposeAtenFakeQuantizePerTensorAffineOp
    : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenFakeQuantizePerTensorAffineOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    MLIRContext *context = getContext();

    Value none = rewriter.create(loc);
    Value falseVal = rewriter.create(loc, false);
    // Passed as the trailing operand of the zero_point add and subtract.
    Value one =
        rewriter.create(loc, rewriter.getI64IntegerAttr(1));

    // Type given to the tensors built from quant_min / quant_max.
    auto baseType = ValueTensorType::getWithLeastStaticInformation(context);

    // input/scale
    Value divScale = rewriter.create(
        loc, op.getType(), op.getSelf(), op.getScale());
    // std::nearby_int(input/scale)
    Value round = rewriter.create(loc, op.getType(), divScale);
    // std::nearby_int(input/scale) + zero_point
    Value addZeroPoint = rewriter.create(
        loc, op.getType(), round, op.getZeroPoint(), one);
    // max(quant_min, std::nearby_int(input/scale) + zero_point)
    Value max = rewriter.create(
        loc, op.getType(), addZeroPoint,
        rewriter.create(loc, baseType, op.getQuantMin(),
                        /*dtype=*/none,
                        /*device=*/none,
                        /*requires_grad=*/falseVal));
    // min(quant_max, max(quant_min, std::nearby_int(input/scale) + zero_point))
    Value min = rewriter.create(
        loc, op.getType(), max,
        rewriter.create(loc, baseType, op.getQuantMax(),
                        /*dtype=*/none, /*device=*/none,
                        /*requires_grad=*/falseVal));
    // min(quant_max, max(quant_min, std::nearby_int(input/scale) + zero_point))
    // - zero_point
    Value subZeroPoint = rewriter.create(
        loc, op.getType(), min, op.getZeroPoint(), one);
    // (min(quant_max, max(quant_min, std::nearby_int(input/scale) +
    // zero_point)) - zero_point) * scale
    Value result = rewriter.create(
        loc, op.getType(), subZeroPoint, op.getScale());
    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace

namespace {
// Pass that greedily applies the decomposition patterns of this file, leaving
// out every pattern whose root op is named in `legalOps`.
class DecomposeComplexOpsPass
    : public DecomposeComplexOpsBase {
private:
  // Rebuilt from `legalOps` at the start of every `runOnOperation` call.
  llvm::StringSet<> legalOpsSet;

  // Adds the decomposition pattern to `patterns` unless the name of its root
  // op, after trimming with `kTorchOpPrefix`, is contained in `legalOpsSet`.
  template
  void addPatternIfTargetOpIsIllegal(RewritePatternSet &patterns) {
    MLIRContext *context = &getContext();
    std::optional opName =
        DecomposePattern(context).getRootKind();
    // Because the `DecomposeComplexOpsPass` uses a greedy algorithm
    // to apply patterns, only patterns that we for sure know we want to run
    // must be added. This restricts the set of patterns allowed in this file to
    // patterns that apply to a single op. In other words, patterns that match
    // on `Operation *` are not allowed, since there is no way of telling if
    // that pattern will match on an op in the `legalOpsSet` or not.
    assert(opName && "All decomposition patterns must target a single op");
    // NOTE(review): `StringRef::ltrim(Chars)` removes any leading characters
    // that occur in `Chars`; it does not remove a literal prefix. If
    // `kTorchOpPrefix` is a multi-character prefix string (its definition is
    // not visible here), an op name whose first characters after the prefix
    // also occur in it would be over-trimmed. Confirm, and consider
    // `consume_front` instead.
    if (!legalOpsSet.contains(opName->getStringRef().ltrim(kTorchOpPrefix)))
      patterns.add(context);
  }

public:
  DecomposeComplexOpsPass() = default;
  DecomposeComplexOpsPass(ArrayRef legalOps) {
    this->legalOps = legalOps;
  }
  // Builds the pattern set and applies it greedily to the current operation,
  // signalling pass failure if the greedy driver reports failure.
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);

    // The strings in the `legalOps` ArrayRef don't exist during the call to the
    // constructor `DecomposeComplexOpsPass`, so the creation of the
    // `legalOpsSet` must be delayed to when `runOnOperation` gets called.
    legalOpsSet.clear();
    legalOpsSet.insert(legalOps.begin(), legalOps.end());

    // NOTE(review): each call below originally named one decomposition
    // pattern as its template argument; with those arguments missing the
    // calls cannot be told apart. Restore the list from the original file and
    // verify the number and order of registrations against it.
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal<
        DecomposeConstantTensorAllocLikeOp>(patterns);
    addPatternIfTargetOpIsIllegal<
        DecomposeConstantTensorAllocLikeOp>(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(
        patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(
        patterns);
    addPatternIfTargetOpIsIllegal<
        DecomposeAtenAddCLikeOp>(patterns);
    addPatternIfTargetOpIsIllegal<
        DecomposeAtenAddCLikeOp>(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal<
        DecomposeAten_ConvolutionLikeOp>(patterns);
    addPatternIfTargetOpIsIllegal<
        DecomposeAten_ConvolutionLikeOp>(
        patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal<
        DecomposeAtenArgMinMaxOp>(patterns);
    addPatternIfTargetOpIsIllegal<
        DecomposeAtenArgMinMaxOp>(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal<
        DecomposeAtenBernoulliLikeOp>(
        patterns);
    addPatternIfTargetOpIsIllegal<
        DecomposeAtenBernoulliLikeOp>(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal<
        DecomposeConstantTensorNewLikeOp>(
        patterns);
    addPatternIfTargetOpIsIllegal<
        DecomposeConstantTensorNewLikeOp>(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(
        patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(
        patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    // More specific conv ops
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);

    // Top-down traversal with no cap on the number of driver iterations.
    GreedyRewriteConfig config;
    config.useTopDownTraversal = true;
    config.maxIterations = GreedyRewriteConfig::kNoLimit;

    if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
                                            config))) {
      return signalPassFailure();
    }
  }
};
} // namespace

// Factory for the pass; `legalOps` is forwarded to the pass constructor.
std::unique_ptr>
mlir::torch::Torch::createDecomposeComplexOpsPass(
    ArrayRef legalOps) {
  return std::make_unique(legalOps);
}