//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"

#include "../PassDetail.h"
#include "PopulatePatterns.h"
#include "Utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

static Value createElementwiseLinalgGeneric(
    OpBuilder &b, Location loc, ValueRange tensorOperands,
    Type resultElementType,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
  // The overall error handling strategy here is best viewed by thinking about
  // what happens for a single result dimension. This loop is not structured
  // that way because it is hard to create the affine maps for each operand
  // unless we structure the loop to iterate over tensor operands as the outer
  // loop instead of the inner loop. This pseudocode gives better intuition:
  // ```
  // for each result dimension:
  //   for each tensor operand:
  //     if it doesn't even have high enough rank relative to the result:
  //       continue
  //     if it is a static size-1 along this result dimension:
  //       continue
  //     if this is the first tensor operand that didn't continue above:
  //       take its dimension size as the size of the non-broadcasted
  //       traversal along this dimension (this may include a dynamic size-1,
  //       **non-broadcasted** traversal!)
  //     emit error check "if the size does not match the non-broadcasted
  //     traversal size along this dimension, error"
  // ```
  SmallVector<int64_t> operandRanks;
  operandRanks.resize(tensorOperands.size());
  llvm::transform(tensorOperands, operandRanks.begin(), [](Value tensor) {
    return tensor.getType().dyn_cast<RankedTensorType>().getRank();
  });

  auto resultRankIt =
      std::max_element(operandRanks.begin(), operandRanks.end());
  assert(resultRankIt != operandRanks.end() && "Unable to get result rank.");
  int64_t resultRank = *resultRankIt;

  // Initialize the resultShape to all 1's, as a fallback in case
  // all sizes along that result dimension are statically 1.
  auto c1 = b.create<arith::ConstantIndexOp>(loc, /*value=*/1);
  SmallVector<Value> resultShape(resultRank, c1);
  SmallVector<AffineMap> indexingMaps;
  for (Value tensorOperand : tensorOperands) {
    SmallVector<AffineExpr> exprs;
    auto type = tensorOperand.getType().cast<RankedTensorType>();
    for (auto size : llvm::enumerate(type.getShape())) {
      // If the size is statically known to be 1, we don't want any
      // error guards to be spuriously emitted, since we are specifically
      // allowing size-1 broadcasts in this case, as they correspond to a
      // constant-0 indexing map.
      if (size.value() == 1) {
        exprs.push_back(b.getAffineConstantExpr(0));
        continue;
      }

      // The rank of this operand might be smaller than the overall rank of
      // the broadcast. Add an offset to correlate it to the correct
      // dimension of the result.
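      // For example (illustrative only): broadcasting a rank-1 operand of
      // shape [5] against a rank-3 result maps the operand's dim 0 to result
      // dim 0 + (3 - 1) = 2, i.e. the trailing result dimension, matching
      // PyTorch's right-aligned broadcasting rules.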
      auto resultDim = size.index() + (resultRank - type.getRank());
      // The generated linalg op will now be iterating along the full size
      // of this dimension. Record that fact.
      exprs.push_back(b.getAffineDimExpr(resultDim));

      // Now, we need to ensure that such iteration is not going to trigger
      // undefined behavior, by doing appropriate checks against the current
      // dimension size.
      auto currentDimSize = getDimOp(b, loc, tensorOperand, size.index());

      // If the result size of this dimension has so far only hit the
      // statically-known-to-be-1 case above (i.e., we have not yet assigned a
      // new Value to `resultShape[resultDim]`), then we have no other dynamic
      // values to check against, and merely need to record the current
      // dimension size.
      if (resultShape[resultDim] == c1) {
        resultShape[resultDim] = currentDimSize;
        continue;
      }

      // We prohibit the size-1 dynamic broadcasting scenario, so just check
      // for exact equality with the running result size.
      // This is the check which protects against the undefined behavior of
      // the generated linalg op in the case of iterating two operands with
      // dimensions sizes that are expected to match.
      auto equalToRunning =
          b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                  resultShape[resultDim], currentDimSize);
      b.create<cf::AssertOp>(loc, equalToRunning,
                             "mismatched size for broadcast");
    }
    indexingMaps.push_back(AffineMap::get(
        /*dimCount=*/resultRank, /*symbolCount=*/0, exprs, b.getContext()));
  }

  SmallVector<StringRef> iteratorTypes(resultRank,
                                       getParallelIteratorTypeName());
  // Add the indexing map for the outs init tensor.
  indexingMaps.push_back(b.getMultiDimIdentityMap(resultRank));

  Value initTensor = b.create<linalg::InitTensorOp>(
      loc, getAsOpFoldResult(resultShape), resultElementType);
  return b
      .create<linalg::GenericOp>(loc,
                                 /*resultTensorTypes=*/initTensor.getType(),
                                 /*inputs=*/tensorOperands,
                                 /*outputs=*/initTensor, indexingMaps,
                                 iteratorTypes, bodyBuild)
      .getResult(0);
}

template <arith::CmpFPredicate fpred, arith::CmpIPredicate iupred,
          arith::CmpIPredicate ispred>
static Value createComparisonTemplate(OpBuilder &b, Location loc, Type type,
                                      Value lhs, Value rhs) {
  if (type.isa<mlir::FloatType>())
    return b.create<arith::CmpFOp>(loc, fpred, lhs, rhs);
  if (IntegerType intType = type.dyn_cast<mlir::IntegerType>()) {
    if (intType.isUnsigned())
      return b.create<arith::CmpIOp>(loc, iupred, lhs, rhs);
    if (intType.isSigned())
      return b.create<arith::CmpIOp>(loc, ispred, lhs, rhs);
  }
  assert(false && "Unhandled element type for comparison");
}

static Value createGreaterThan(OpBuilder &b, Location loc, Type elementalType,
                               Value lhs, Value rhs) {
  return createComparisonTemplate<arith::CmpFPredicate::UGT,
                                  arith::CmpIPredicate::ugt,
                                  arith::CmpIPredicate::sgt>(
      b, loc, elementalType, lhs, rhs);
}

static Value createLessThan(OpBuilder &b, Location loc, Type elementalType,
                            Value lhs, Value rhs) {
  return createComparisonTemplate<arith::CmpFPredicate::ULT,
                                  arith::CmpIPredicate::ult,
                                  arith::CmpIPredicate::slt>(
      b, loc, elementalType, lhs, rhs);
}
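// The helpers below build the normal CDF
//   Phi(x) = 0.5 * (1 + erf((x - mean) / sqrt(2)))
// used by the GeLU lowerings further down. Note that `sigma` is accepted but
// the erf argument is only divided by sqrt(2); unit variance is assumed.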
static Value buildNormalCdf(OpBuilder &b, Location &loc, Value x, Value mean,
                            Value sigma) {
  Type elementType = x.getType();
  Value xMinusMean = b.create<arith::SubFOp>(loc, x, mean);
  Value two = b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 2));
  Value sqrt2 = b.create<math::SqrtOp>(loc, two);
  Value erfArg = b.create<arith::DivFOp>(loc, xMinusMean, sqrt2);
  Value erf = b.create<math::ErfOp>(loc, erfArg);
  Value one = b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1));
  Value erfPlus1 = b.create<arith::AddFOp>(loc, one, erf);
  Value oneHalf =
      b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.5));
  Value normalCdf = b.create<arith::MulFOp>(loc, oneHalf, erfPlus1);
  return normalCdf;
}

static Value buildUnitNormalCdf(OpBuilder &b, Location &loc, Value x) {
  Type elementType = x.getType();
  Value zero = b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0));
  Value one = b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1));
  return buildNormalCdf(b, loc, x, zero, one);
}

static Value createLinalgPayloadCalculationForElementwiseOp(
    OpBuilder &b, Location loc, TypeConverter *converter,
    ValueRange payloadArgs, Operation *op, ArrayRef<Value> operands) {
  if (isa<AtenTanhOp>(op))
    return b.create<math::TanhOp>(loc, payloadArgs[0]);
  if (isa<AtenExpOp>(op))
    return b.create<math::ExpOp>(loc, payloadArgs[0]);
  if (isa<AtenFloorOp>(op))
    return b.create<math::FloorOp>(loc, payloadArgs[0]);
  if (isa<AtenCeilOp>(op))
    return b.create<math::CeilOp>(loc, payloadArgs[0]);
  if (isa<AtenLogOp>(op))
    return b.create<math::LogOp>(loc, payloadArgs[0]);
  if (isa<AtenSqrtOp>(op))
    return b.create<math::SqrtOp>(loc, payloadArgs[0]);
  if (isa<AtenRsqrtOp>(op))
    return b.create<math::RsqrtOp>(loc, payloadArgs[0]);
  if (isa<AtenLog2Op>(op))
    return b.create<math::Log2Op>(loc, payloadArgs[0]);
  if (auto clone = dyn_cast<AtenCloneOp>(op)) {
    int64_t memoryFormat;
    if (!clone.memory_format().getType().isa<Torch::NoneType>() &&
        (!matchPattern(clone.memory_format(),
                       m_TorchConstantInt(&memoryFormat)) ||
         memoryFormat != torch_upstream::MemoryFormat::Contiguous)) {
      clone.emitError("unimplemented: only default memory format is supported");
      return nullptr;
    }
    return payloadArgs[0];
  }
  if (auto bitwiseAndTensor = dyn_cast<AtenBitwiseAndTensorOp>(op)) {
    if (bitwiseAndTensor.getType()
            .cast<ValueTensorType>()
            .getDtype()
            .isa<mlir::FloatType>()) {
      bitwiseAndTensor.emitError(
          "Bitwise_And does not support floating point dtype");
      return nullptr;
    }
    Type dtype = converter->convertType(bitwiseAndTensor.getType())
                     .cast<RankedTensorType>()
                     .getElementType();
    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
    Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
    return b.create<arith::AndIOp>(loc, lhs, rhs);
  }
  if (isa<AtenAbsOp>(op))
    return b.create<math::AbsOp>(loc, payloadArgs[0]);
  if (isa<AtenErfOp>(op))
    return b.create<math::ErfOp>(loc, payloadArgs[0]);
  if (isa<AtenSigmoidOp>(op)) {
    Type elementType = payloadArgs[0].getType();
    auto one = b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1));
    auto negate = b.create<arith::NegFOp>(loc, payloadArgs[0]);
    auto exp = b.create<math::ExpOp>(loc, negate);
    auto added = b.create<arith::AddFOp>(loc, exp, one);
    return b.create<arith::DivFOp>(loc, one, added);
  }
  if (auto relu = dyn_cast<AtenReluOp>(op)) {
    if (!relu.getType()
             .cast<ValueTensorType>()
             .getDtype()
             .isa<mlir::FloatType>()) {
      relu.emitError("unimplemented: non-floating point dtype");
      return nullptr;
    }
    Type elementType = payloadArgs[0].getType();
    Value constZero =
        b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
    Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
                                         payloadArgs[0], constZero);
    return b.create<arith::SelectOp>(loc, pred, payloadArgs[0], constZero);
  }
  if (auto lrelu = dyn_cast<AtenLeakyReluOp>(op)) {
    if (!lrelu.getType()
             .cast<ValueTensorType>()
             .getDtype()
             .isa<mlir::FloatType>()) {
      lrelu.emitError("unimplemented: non-floating point dtype");
      return nullptr;
    }
    Type elementType = payloadArgs[0].getType();
    Value constZero =
        b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
    Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
                                         payloadArgs[0], constZero);
    Value positivePart =
        b.create<arith::SelectOp>(loc, pred, payloadArgs[0], constZero);
    Value negativePart =
        b.create<arith::SelectOp>(loc, pred, constZero, payloadArgs[0]);
    Value scale = convertScalarToDtype(b, loc, operands[1], elementType);
    Value scaledNegativePart =
        b.create<arith::MulFOp>(loc, negativePart, scale);
    return b.create<arith::AddFOp>(loc, positivePart, scaledNegativePart);
  }
  if (auto gelu = dyn_cast<AtenGeluOp>(op)) {
    if (!gelu.getType()
             .cast<ValueTensorType>()
             .getDtype()
             .isa<mlir::FloatType>()) {
      gelu.emitError("unimplemented: non-floating point dtype");
      return nullptr;
    }
    // TODO: Take approximation into account.
    std::string approximate;
    if (!matchPattern(gelu.approximate(), m_TorchConstantStr(approximate)) ||
        approximate != "none")
      return nullptr;
    Value cdf = buildUnitNormalCdf(b, loc, payloadArgs[0]);
    return b.create<arith::MulFOp>(loc, payloadArgs[0], cdf);
  }
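  // A quick derivation for the backward handler below: GeLU(x) = x * Phi(x),
  // so d/dx GeLU(x) = Phi(x) + x * phi(x), where phi is the unit normal PDF.
  // The constants compose as kAlphaHalf = 0.5 * (2/sqrt(pi)) * (1/sqrt(2))
  // = 1/sqrt(2*pi), the PDF's normalization factor, so cdfExt below is
  // Phi(x) + x * phi(x), and the result is grad_output times that.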
  if (auto geluBackward = dyn_cast<AtenGeluBackwardOp>(op)) {
    if (!geluBackward.getType()
             .cast<ValueTensorType>()
             .getDtype()
             .isa<mlir::FloatType>()) {
      geluBackward.emitError("unimplemented: non-floating point dtype");
      return nullptr;
    }
    // TODO: Take approximation into account.
    std::string approximate;
    if (!matchPattern(geluBackward.approximate(),
                      m_TorchConstantStr(approximate)) ||
        approximate != "none")
      return nullptr;
    Type elementType = payloadArgs[1].getType();
    Value cstAlpha0 = b.create<arith::ConstantOp>(
        loc, FloatAttr::get(elementType, 1.12837916709551257390));
    Value cstAlpha1 = b.create<arith::ConstantOp>(
        loc, FloatAttr::get(elementType, 0.70710678118654752440));
    Value oneHalf =
        b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.5));
    Value kAlpha = b.create<arith::MulFOp>(loc, cstAlpha0, cstAlpha1);
    Value kAlphaHalf = b.create<arith::MulFOp>(loc, kAlpha, oneHalf);
    Value negOneHalf =
        b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, -0.5));
    Value inputSquared =
        b.create<arith::MulFOp>(loc, payloadArgs[1], payloadArgs[1]);
    Value negHalfInputSquared =
        b.create<arith::MulFOp>(loc, inputSquared, negOneHalf);
    Value dinput = b.create<math::ExpOp>(loc, negHalfInputSquared);
    Value cdf = buildUnitNormalCdf(b, loc, payloadArgs[1]);
    Value dinputInput = b.create<arith::MulFOp>(loc, dinput, payloadArgs[1]);
    Value dinputInputAlpha =
        b.create<arith::MulFOp>(loc, dinputInput, kAlphaHalf);
    Value cdfExt = b.create<arith::AddFOp>(loc, dinputInputAlpha, cdf);
    return b.create<arith::MulFOp>(loc, payloadArgs[0], cdfExt);
  }
  if (auto add = dyn_cast<AtenAddTensorOp>(op)) {
    AtenAddTensorOp::Adaptor adaptor(operands);
    Type dtype = converter->convertType(add.getType())
                     .cast<RankedTensorType>()
                     .getElementType();
    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
    Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
    Value alpha = convertScalarToDtype(b, loc, adaptor.alpha(), dtype);
    if (dtype.isa<mlir::FloatType>()) {
      Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha);
      return b.create<arith::AddFOp>(loc, lhs, scaled);
    } else {
      Value scaled = b.create<arith::MulIOp>(loc, rhs, alpha);
      return b.create<arith::AddIOp>(loc, lhs, scaled);
    }
  }
  if (auto sub = dyn_cast<AtenSubTensorOp>(op)) {
    AtenSubTensorOp::Adaptor adaptor(operands);
    Type dtype = converter->convertType(sub.getType())
                     .cast<RankedTensorType>()
                     .getElementType();
    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
    Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
    Value alpha = convertScalarToDtype(b, loc, adaptor.alpha(), dtype);
    if (dtype.isa<mlir::FloatType>()) {
      Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha);
      return b.create<arith::SubFOp>(loc, lhs, scaled);
    } else {
      Value scaled = b.create<arith::MulIOp>(loc, rhs, alpha);
      return b.create<arith::SubIOp>(loc, lhs, scaled);
    }
  }
  if (auto subScalar = dyn_cast<AtenSubScalarOp>(op)) {
    Type dtype = converter->convertType(subScalar.getType())
                     .cast<RankedTensorType>()
                     .getElementType();
    Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
    Value other = convertScalarToDtype(b, loc, operands[1], dtype);
    Value alpha = convertScalarToDtype(b, loc, operands[2], dtype);
    if (dtype.isa<mlir::FloatType>()) {
      Value mult = b.create<arith::MulFOp>(loc, other, alpha);
      return b.create<arith::SubFOp>(loc, self, mult);
    } else if (dtype.isa<mlir::IntegerType>()) {
      Value mult = b.create<arith::MulIOp>(loc, other, alpha);
      return b.create<arith::SubIOp>(loc, self, mult);
    }
    subScalar.emitError("unimplemented: dtype other than float and integer "
                        "types are not supported.");
    return nullptr;
  }
  if (auto addScalar = dyn_cast<AtenAddScalarOp>(op)) {
    Type dtype = converter->convertType(addScalar.getType())
                     .cast<RankedTensorType>()
                     .getElementType();
    Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
    Value other = convertScalarToDtype(b, loc, operands[1], dtype);
    Value alpha = convertScalarToDtype(b, loc, operands[2], dtype);
    if (dtype.isa<mlir::FloatType>()) {
      Value mult = b.create<arith::MulFOp>(loc, other, alpha);
      return b.create<arith::AddFOp>(loc, self, mult);
    } else if (dtype.isa<mlir::IntegerType>()) {
      Value mult = b.create<arith::MulIOp>(loc, other, alpha);
      return b.create<arith::AddIOp>(loc, self, mult);
    }
    addScalar.emitError("unimplemented: dtype other than float and integer "
                        "types are not supported.");
    return nullptr;
  }
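  // Note for the scalar variants above: `other` and `alpha` are scalar SSA
  // values taken from `operands` (they are not tensor payload arguments), and
  // both handlers compute self +/- alpha * other after promoting all values
  // to the result dtype with convertScalarToDtype.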
  if (auto mul = dyn_cast<AtenMulTensorOp>(op)) {
    AtenMulTensorOp::Adaptor adaptor(operands);
    Type dtype = converter->convertType(mul.getType())
                     .cast<RankedTensorType>()
                     .getElementType();
    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
    Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
    if (dtype.isa<mlir::FloatType>()) {
      return b.create<arith::MulFOp>(loc, lhs, rhs);
    } else {
      return b.create<arith::MulIOp>(loc, lhs, rhs);
    }
  }
  if (auto gtTensor = dyn_cast<AtenGtTensorOp>(op)) {
    AtenGtTensorOp::Adaptor adaptor(operands);
    Type lhsDtype = payloadArgs[0].getType();
    Type rhsDtype = payloadArgs[1].getType();

    // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype`
    // needs to be handled.
    if (lhsDtype != rhsDtype) {
      gtTensor.emitError("unimplemented: different lhs and rhs dtype");
      return nullptr;
    }

    Type elementalType =
        gtTensor.self().getType().cast<BaseTensorType>().getDtype();
    return createGreaterThan(b, loc, elementalType, payloadArgs[0],
                             payloadArgs[1]);
  }
  if (auto eqTensor = dyn_cast<AtenEqTensorOp>(op)) {
    AtenEqTensorOp::Adaptor adaptor(operands);
    Type lhsDtype = payloadArgs[0].getType();
    Type rhsDtype = payloadArgs[1].getType();

    // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype`
    // needs to be handled.
    if (lhsDtype != rhsDtype) {
      eqTensor.emitError("unimplemented: lhs and rhs dtype must be same");
      return nullptr;
    }

    Type elementalType =
        eqTensor.self().getType().cast<BaseTensorType>().getDtype();

    if (elementalType.isa<mlir::FloatType>())
      return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UEQ,
                                     payloadArgs[0], payloadArgs[1]);
    if (elementalType.isa<mlir::IntegerType>()) {
      return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                     payloadArgs[0], payloadArgs[1]);
    }
    eqTensor.emitError("unimplemented: dtype isn't supported.");
    return nullptr;
  }
  if (auto ltTensor = dyn_cast<AtenLtTensorOp>(op)) {
    AtenLtTensorOp::Adaptor adaptor(operands);
    Type lhsDtype = payloadArgs[0].getType();
    Type rhsDtype = payloadArgs[1].getType();

    // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype`
    // needs to be handled.
    if (lhsDtype != rhsDtype) {
      ltTensor.emitError("unimplemented: lhs and rhs dtype must be same");
      return nullptr;
    }

    Type elementalType =
        ltTensor.self().getType().cast<BaseTensorType>().getDtype();
    return createLessThan(b, loc, elementalType, payloadArgs[0],
                          payloadArgs[1]);
  }
  if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
    AtenDivTensorOp::Adaptor adaptor(operands);
    Type dtype = converter->convertType(div.getType())
                     .cast<RankedTensorType>()
                     .getElementType();
    if (!dtype.isa<mlir::FloatType>()) {
      div.emitError("unimplemented: non-floating point dtype");
      return nullptr;
    }
    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
    Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
    return b.create<arith::DivFOp>(loc, lhs, rhs);
  }
  if (auto pow = dyn_cast<AtenPowTensorScalarOp>(op)) {
    if (!pow.getType()
             .cast<ValueTensorType>()
             .getDtype()
             .isa<mlir::FloatType>()) {
      pow.emitError("unimplemented: non-floating point dtype");
      return nullptr;
    }
    Type dtype = pow.self().getType().cast<ValueTensorType>().getDtype();
    Value expPromoted = convertScalarToDtype(b, loc, operands[1], dtype);
    return b.create<math::PowFOp>(loc, payloadArgs[0], expPromoted);
  }
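  // The scalar comparison handlers below all follow the same scheme: promote
  // the scalar `other` to the payload element type, then pick the predicate
  // by signedness for integers. The unordered float predicates used
  // throughout (UGT, UGE, UEQ, ULT, ULE) evaluate to true if either operand
  // is NaN.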
  if (auto gtScalar = dyn_cast<AtenGtScalarOp>(op)) {
    Type dtype = gtScalar.self().getType().cast<BaseTensorType>().getDtype();

    // TODO: `gtTensor` and `gtScalar` share similar code and can be called
    // from one static function.
    Value otherPromoted =
        convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());

    if (dtype.isa<mlir::FloatType>())
      return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
                                     payloadArgs[0], otherPromoted);
    if (IntegerType intType = dtype.dyn_cast<mlir::IntegerType>()) {
      if (!operands[1].getType().isa<mlir::IntegerType>()) {
        // TODO: Promote tensor args from integer to float.
        gtScalar.emitError(
            "unimplemented: type promotion from tensor to scalar.");
        return nullptr;
      }
      if (intType.isUnsigned())
        return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
                                       payloadArgs[0], otherPromoted);
      if (intType.isSigned())
        return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
                                       payloadArgs[0], otherPromoted);
    }
    gtScalar.emitError("unimplemented: dtype isn't supported.");
    return nullptr;
  }
  if (auto geScalar = dyn_cast<AtenGeScalarOp>(op)) {
    Type dtype = geScalar.self().getType().cast<BaseTensorType>().getDtype();

    // TODO: The `AtenGeScalarOp` and `AtenGtScalarOp` share a lot of code
    // that can be refactored.
    Value otherPromoted =
        convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());

    if (dtype.isa<mlir::FloatType>())
      return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGE,
                                     payloadArgs[0], otherPromoted);
    if (IntegerType intType = dtype.dyn_cast<mlir::IntegerType>()) {
      if (!operands[1].getType().isa<mlir::IntegerType>()) {
        // TODO: Promote tensor args from integer to float.
        geScalar.emitError(
            "unimplemented: type promotion from tensor to scalar.");
        return nullptr;
      }
      if (intType.isUnsigned())
        return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge,
                                       payloadArgs[0], otherPromoted);
      if (intType.isSigned())
        return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
                                       payloadArgs[0], otherPromoted);
    }
    geScalar.emitError("unimplemented: dtype isn't supported.");
    return nullptr;
  }
  if (auto eqScalar = dyn_cast<AtenEqScalarOp>(op)) {
    Type dtype = eqScalar.self().getType().cast<BaseTensorType>().getDtype();
    Value otherPromoted =
        convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());

    if (dtype.isa<mlir::FloatType>())
      return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UEQ,
                                     payloadArgs[0], otherPromoted);
    if (dtype.isa<mlir::IntegerType>()) {
      if (!operands[1].getType().isa<mlir::IntegerType>()) {
        // TODO: Promote tensor operand from integer to float.
        eqScalar.emitError(
            "unimplemented: type promotion from tensor to scalar");
        return nullptr;
      }
      return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                     payloadArgs[0], otherPromoted);
    }
    eqScalar.emitError("unimplemented: dtype isn't supported");
    return nullptr;
  }
  if (auto ltScalar = dyn_cast<AtenLtScalarOp>(op)) {
    Type dtype = ltScalar.self().getType().cast<BaseTensorType>().getDtype();
    Value otherPromoted =
        convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());

    // TODO: Both tensor and scalar variants of `aten.gt` and `aten.lt` share
    // a lot of code that can be refactored.
    if (dtype.isa<mlir::FloatType>())
      return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
                                     payloadArgs[0], otherPromoted);
    if (IntegerType intType = dtype.dyn_cast<mlir::IntegerType>()) {
      if (!operands[1].getType().isa<mlir::IntegerType>()) {
        // TODO: Promote tensor operand from integer to float.
        ltScalar.emitError(
            "unimplemented: type promotion from tensor to scalar");
        return nullptr;
      }
      if (intType.isUnsigned())
        return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                       payloadArgs[0], otherPromoted);
      if (intType.isSigned())
        return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                       payloadArgs[0], otherPromoted);
    }
    ltScalar.emitError("unimplemented: dtype isn't supported.");
    return nullptr;
  }
  if (auto leScalar = dyn_cast<AtenLeScalarOp>(op)) {
    Type dtype = leScalar.self().getType().cast<BaseTensorType>().getDtype();
    Value otherPromoted =
        convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());

    // TODO: The `AtenLeScalarOp` and `AtenLtScalarOp` share a lot of code
    // that can be refactored.
    if (dtype.isa<mlir::FloatType>())
      return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE,
                                     payloadArgs[0], otherPromoted);
    if (IntegerType intType = dtype.dyn_cast<mlir::IntegerType>()) {
      if (!operands[1].getType().isa<mlir::IntegerType>()) {
        // TODO: Promote tensor operand from integer to float.
        leScalar.emitError(
            "unimplemented: type promotion from tensor to scalar");
        return nullptr;
      }
      if (intType.isUnsigned())
        return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
                                       payloadArgs[0], otherPromoted);
      if (intType.isSigned())
        return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle,
                                       payloadArgs[0], otherPromoted);
    }
    leScalar.emitError("unimplemented: dtype isn't supported.");
    return nullptr;
  }
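  // aten.where.self below lowers to a single select on the boolean condition
  // tensor; aten.lerp.Tensor computes the usual linear interpolation
  // start + weight * (end - start).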
  if (auto whereSelf = dyn_cast<AtenWhereSelfOp>(op)) {
    Type dtype = converter->convertType(whereSelf.getType())
                     .cast<RankedTensorType>()
                     .getElementType();
    Value lhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
    Value rhs = convertScalarToDtype(b, loc, payloadArgs[2], dtype);
    return b.create<arith::SelectOp>(loc, payloadArgs[0], lhs, rhs);
  }
  if (auto lerp = dyn_cast<AtenLerpTensorOp>(op)) {
    if (!lerp.getType()
             .cast<ValueTensorType>()
             .getDtype()
             .isa<mlir::FloatType>()) {
      lerp.emitError("unimplemented: non-floating point dtype");
      return nullptr;
    }
    AtenLerpTensorOp::Adaptor adaptor(payloadArgs);
    auto start = adaptor.self();
    auto end = adaptor.end();
    auto weight = adaptor.weight();
    auto delta = b.create<arith::SubFOp>(loc, end, start);
    auto weightedDelta = b.create<arith::MulFOp>(loc, delta, weight);
    return b.create<arith::AddFOp>(loc, start, weightedDelta);
  }
  if (auto minimum = dyn_cast<AtenMinimumOp>(op)) {
    Type dtype = minimum.getType().cast<BaseTensorType>().getDtype();
    Type elemTy = converter->convertType(minimum.getType())
                      .cast<RankedTensorType>()
                      .getElementType();
    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], elemTy);
    Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], elemTy);
    Value pred = createLessThan(b, loc, dtype, lhs, rhs);
    return b.create<arith::SelectOp>(loc, pred, lhs, rhs);
  }
  if (auto maximum = dyn_cast<AtenMaximumOp>(op)) {
    Type dtype = maximum.getType().cast<BaseTensorType>().getDtype();
    Type elemTy = converter->convertType(maximum.getType())
                      .cast<RankedTensorType>()
                      .getElementType();
    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], elemTy);
    Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], elemTy);
    Value pred = createGreaterThan(b, loc, dtype, lhs, rhs);
    return b.create<arith::SelectOp>(loc, pred, lhs, rhs);
  }
  if (auto clamp = dyn_cast<AtenClampOp>(op)) {
    Type dtype = converter->convertType(clamp.getType())
                     .cast<RankedTensorType>()
                     .getElementType();
    if (!dtype.isa<mlir::FloatType>()) {
      clamp.emitError("unimplemented: non-floating point dtype");
      return nullptr;
    }
    AtenClampOp::Adaptor adaptor(operands);
    auto min = adaptor.min();
    auto max = adaptor.max();
    if (min.getType().isa<Torch::OptionalType>() ||
        max.getType().isa<Torch::OptionalType>()) {
      clamp.emitError("unimplemented: runtime optional type");
      return nullptr;
    }
    auto result = payloadArgs[0];
    if (!min.getType().isa<Torch::NoneType>()) {
      auto minPromoted = convertScalarToDtype(b, loc, min, dtype);
      auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
                                          result, minPromoted);
      result = b.create<arith::SelectOp>(loc, pred, minPromoted, result);
    }
    if (!max.getType().isa<Torch::NoneType>()) {
      auto maxPromoted = convertScalarToDtype(b, loc, max, dtype);
      auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
                                          result, maxPromoted);
      result = b.create<arith::SelectOp>(loc, pred, maxPromoted, result);
    }
    return result;
  }
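  // aten.rsub.Scalar below is the reversed subtraction other - alpha * self:
  // unlike aten.sub.Tensor, it is `self` that gets scaled and subtracted.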
  if (auto rsub = dyn_cast<AtenRsubScalarOp>(op)) {
    Type dtype = converter->convertType(rsub.getType())
                     .cast<RankedTensorType>()
                     .getElementType();
    if (!dtype.isa<mlir::FloatType>()) {
      rsub.emitError("unimplemented: non-floating point dtype");
      return nullptr;
    }
    Value self = payloadArgs[0];
    Value other = convertScalarToDtype(b, loc, operands[1], dtype);
    Value alpha = convertScalarToDtype(b, loc, operands[2], dtype);
    Value mult = b.create<arith::MulFOp>(loc, self, alpha);
    return b.create<arith::SubFOp>(loc, other, mult);
  }
  if (auto mulScalar = dyn_cast<AtenMulScalarOp>(op)) {
    Type dtype = converter->convertType(mulScalar.getType())
                     .cast<RankedTensorType>()
                     .getElementType();
    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
    Value rhs = convertScalarToDtype(b, loc, operands[1], dtype);
    if (dtype.isa<mlir::FloatType>())
      return b.create<arith::MulFOp>(loc, lhs, rhs);
    if (dtype.isa<mlir::IntegerType>())
      return b.create<arith::MulIOp>(loc, lhs, rhs);
    mulScalar.emitError("unimplemented: Only integer/float dtype supported");
    return nullptr;
  }
  if (auto atenToDtype = dyn_cast<AtenToDtypeOp>(op)) {
    Value input = payloadArgs[0];
    Type dtype = converter->convertType(atenToDtype.getType())
                     .cast<RankedTensorType>()
                     .getElementType();
    Value result = convertScalarToDtype(b, loc, input, dtype);
    return result;
  }
  if (auto divScalar = dyn_cast<AtenDivScalarOp>(op)) {
    Type dtype = converter->convertType(divScalar.getType())
                     .cast<RankedTensorType>()
                     .getElementType();
    if (!dtype.isa<mlir::FloatType>()) {
      divScalar.emitError("unimplemented: non-floating point dtype");
      return nullptr;
    }
    Value self = payloadArgs[0];
    Value other = convertScalarToDtype(b, loc, operands[1], dtype);
    return b.create<arith::DivFOp>(loc, self, other);
  }
  if (auto reciprocal = dyn_cast<AtenReciprocalOp>(op)) {
    if (!reciprocal.getType()
             .cast<ValueTensorType>()
             .getDtype()
             .isa<mlir::FloatType>()) {
      reciprocal.emitError("unimplemented: non-floating point dtype");
      return nullptr;
    }
    Type elementType = payloadArgs[0].getType();
    // assert(element != 0)
    auto zero =
        b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.0));
    auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ONE,
                                        payloadArgs[0], zero);
    b.create<cf::AssertOp>(
        loc, pred,
        b.getStringAttr("unimplemented: tensor with zero element"));

    auto one =
        b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1.0));
    return b.create<arith::DivFOp>(loc, one, payloadArgs[0]);
  }
  if (auto thresholdOp = dyn_cast<AtenThresholdOp>(op)) {
    // The approach used here is as follows:
    //   result = self <= threshold ? value : self
    AtenThresholdOp::Adaptor adaptor(operands);
    Type dtype = converter->convertType(thresholdOp.getType())
                     .cast<RankedTensorType>()
                     .getElementType();

    Value self = payloadArgs[0];
    Value threshold = convertScalarToDtype(b, loc, adaptor.threshold(), dtype);
    Value value = convertScalarToDtype(b, loc, adaptor.value(), dtype);

    Value predicate;
    if (dtype.isa<mlir::FloatType>())
      predicate = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, self,
                                          threshold);
    else
      predicate = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle, self,
                                          threshold);
    return b.create<arith::SelectOp>(loc, predicate, value, self);
  }
  if (auto thresholdBackward = dyn_cast<AtenThresholdBackwardOp>(op)) {
    // The approach used here is as follows:
    //   result = self <= threshold ? 0 : grad
    AtenThresholdBackwardOp::Adaptor adaptor(operands);
    Type dtype = converter->convertType(thresholdBackward.getType())
                     .cast<RankedTensorType>()
                     .getElementType();

    Value grad = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
    Value self = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
    Value threshold = convertScalarToDtype(b, loc, adaptor.threshold(), dtype);
    Value constantZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));

    Value predicate;
    if (dtype.isa<mlir::FloatType>())
      predicate = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, self,
                                          threshold);
    else
      predicate = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle, self,
                                          threshold);
    return b.create<arith::SelectOp>(loc, predicate, constantZero, grad);
  }

  op->emitError("unimplemented lowering in "
                "createLinalgPayloadCalculationForElementwiseOp");
  return nullptr;
}
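// For intuition, a lowering produced by the pattern below turns (illustrative
// IR only; details vary with the type converter):
//   %0 = torch.aten.tanh %arg0 : !torch.vtensor<[?],f32> -> !torch.vtensor<[?],f32>
// into a linalg.generic over the converted tensor<?xf32> whose region holds
// the scalar payload built above (here, a single math.tanh + linalg.yield).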
namespace {
// Converts an elementwise op.
// This specifically includes:
// - converting elementwise ops of any tensor arity
// - converting elementwise ops with any number of scalar captures (such as a
//   scalar alpha to torch.aten.Add)
// - broadcasting of static size-1 dimensions
//
// Currently, we adopt the behavior that "size 1" broadcasting is a runtime
// error if it happens dynamically.
//
// Looking forward a bit, eventually, it probably makes sense to have
// a "linalg.generic-like" op for modeling a fused subgraph of numpy-broadcasted
// operands. Modeling elementwise ops that way is potentially useful to allow a
// more centralized reasoning about multiversioning. However a cost model will
// be needed for "pre-fusing" elementwise ops that way, as it can potentially be
// a pessimization. A mild extension of this pattern should work for such a
// general op.
class ConvertElementwiseOp : public ConversionPattern {
public:
  ConvertElementwiseOp(TypeConverter &typeConverter, MLIRContext *context)
      : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
                          context) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    if (!isa<AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp,
             AtenGeluBackwardOp, AtenAddTensorOp, AtenMulTensorOp,
             AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp,
             AtenExpOp, AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp,
             AtenClampOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp,
             AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp,
             AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp, AtenDivScalarOp,
             AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
             AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
             AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp,
             AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
             AtenThresholdBackwardOp, AtenCloneOp>(op))
      return rewriter.notifyMatchFailure(op, "not a supported elementwise op");

    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();

    Location loc = op->getLoc();
    auto tensorOperands = llvm::to_vector<6>(llvm::make_filter_range(
        operands, [](Value v) { return v.getType().isa<RankedTensorType>(); }));
    auto resultType = getTypeConverter()
                          ->convertType(op->getResult(0).getType())
                          .cast<RankedTensorType>();
    bool hadErrorCreatingPayload = false;
    Value generic = createElementwiseLinalgGeneric(
        rewriter, loc, tensorOperands, resultType.getElementType(),
        [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
          Value result = createLinalgPayloadCalculationForElementwiseOp(
              b, loc, getTypeConverter(), payloadArgs, op, operands);
          if (!result) {
            hadErrorCreatingPayload = true;
            return;
          }
          b.create<linalg::YieldOp>(loc, result);
        });
    if (hadErrorCreatingPayload)
      return failure();
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, generic);
    return success();
  }
};
} // namespace

// Given `input`, `target`, `nll_loss_forward` is given by:
//   for i in range(0, len(target)):
//     indi = target[i];
//     nll_loss_forward[i] = -(input[i][indi]);
// TODO: `weight` operand is still to be taken care of.
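// Worked example (illustrative): for input = [[0.1, 0.2], [0.3, 0.4]] and
// target = [1, 0], the result is [-0.2, -0.3]; any entry whose target equals
// ignore_index is forced to 0 instead.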
namespace {
class ConvertAtenNllLossForwardOp
    : public OpConversionPattern<AtenNllLossForwardOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenNllLossForwardOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    Location loc = op->getLoc();
    Value input = adaptor.self();
    Value target = adaptor.target();
    Value weight = adaptor.weight();

    int64_t reduction;
    if (!matchPattern(op.reduction(), m_TorchConstantInt(&reduction)))
      return rewriter.notifyMatchFailure(op, "dim must be constant");

    // TODO: Incorporate the weight argument.
    if (!weight.getType().isa<Torch::NoneType>())
      return rewriter.notifyMatchFailure(
          op, "Unimplemented, the weight operand is not incorporated.");

    Value ignoreIndex = adaptor.ignore_index();
    Value ignoreIndexVal = castIntToIndex(rewriter, loc, ignoreIndex);

    unsigned inputRank = input.getType().cast<RankedTensorType>().getRank();
    unsigned targetRank = target.getType().cast<RankedTensorType>().getRank();

    // TODO: Add support for k-dim loss.
    if (inputRank > 2) {
      return rewriter.notifyMatchFailure(
          op, "expected input and target to be rank <= 2");
    }
    RankedTensorType resultType = getTypeConverter()
                                      ->convertType(op->getResult(0).getType())
                                      .cast<RankedTensorType>();
    Type elementType = resultType.getElementType();

    Value zeroVal = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementType));

    Value finalRes = createElementwiseLinalgGeneric(
        rewriter, loc, {target}, elementType,
        [&](OpBuilder &b, Location loc, ValueRange args) {
          Value targetVal = args[0];
          Value indTarget = rewriter.create<arith::IndexCastOp>(
              loc, rewriter.getIndexType(), targetVal);

          // The final result is given by:
          // final_res = (indTarget == ignoreIndexVal) ? 0 :
          //             -input[indI][indTarget]
          Value cmpEq = rewriter.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::eq, indTarget, ignoreIndexVal);

          SmallVector<Value> extractionIndices{indTarget};
          if (inputRank == 2) {
            Value indI = rewriter.create<linalg::IndexOp>(loc, 0);
            extractionIndices.insert(extractionIndices.begin(), indI);
          }

          Value result = rewriter.create<tensor::ExtractOp>(
              loc, input, extractionIndices);

          Value negate =
              rewriter.create<arith::NegFOp>(loc, elementType, result);
          Value selectFinal =
              rewriter.create<arith::SelectOp>(loc, cmpEq, zeroVal, negate);
          b.create<linalg::YieldOp>(loc, selectFinal);
        });

    if (reduction == torch_upstream::Reduction::Sum ||
        reduction == torch_upstream::Reduction::Mean) {
      Value numOfElems = getTensorSize(rewriter, loc, finalRes);
      numOfElems = convertScalarToDtype(rewriter, loc, numOfElems, elementType);
      llvm::iota_range<int64_t> dimsToReduce(0, targetRank,
                                             /*inclusive=*/false);
      DenseSet<int64_t> dimSet(dimsToReduce.begin(), dimsToReduce.end());

      finalRes = torch_to_linalg::createReductionLinalgGeneric(
          rewriter, loc, finalRes, dimSet, /*keepDim=*/false,
          /*initElem=*/zeroVal,
          [&](OpBuilder &b, Location loc, ValueRange args) {
            Value newVal = args[0];
            Value accumulator = args[1];
            if (reduction == torch_upstream::Reduction::Mean)
              newVal = b.create<arith::DivFOp>(loc, newVal, numOfElems);
            Value result = b.create<arith::AddFOp>(loc, newVal, accumulator);
            b.create<linalg::YieldOp>(loc, result);
          });
    }

    // TODO: Update the second result tensor.
    Value weightUpdated = createZeroInitTensor(rewriter, loc, {}, elementType);
    rewriter.replaceOp(op, {finalRes, weightUpdated});
    return success();
  }
};
} // namespace

// Normalization formula:
//   ((input - mean) / sqrt(var + eps)) * weight + bias
static Value createLinalgPayloadCalculationForNormOps(
    OpBuilder &b, Location loc, Type elemTy, Value input, Value mean, Value var,
    Value eps, Value weight, Value bias) {
  Value inputSubMean = b.create<arith::SubFOp>(loc, input, mean);
  // The eps is always f64.
  Value truncatedEps = b.create<arith::TruncFOp>(loc, elemTy, eps);
  Value varPlusEps = b.create<arith::AddFOp>(loc, var, truncatedEps);
  Value rSTD = b.create<math::RsqrtOp>(loc, varPlusEps);
  Value temp = b.create<arith::MulFOp>(loc, inputSubMean, rSTD);
  Value timesWeight = b.create<arith::MulFOp>(loc, temp, weight);
  Value plusBias = b.create<arith::AddFOp>(loc, timesWeight, bias);
  return plusBias;
}
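// This helper is shared by the batch norm and layer norm lowerings below.
// Note that it multiplies by rsqrt(var + eps) rather than dividing by
// sqrt(var + eps), and that eps always arrives as f64 and is first truncated
// to the element type.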
namespace {
class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenBatchNormOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *context = op->getContext();
    Location loc = op->getLoc();
    Value input = adaptor.input();
    Value weight = adaptor.weight();
    Value bias = adaptor.bias();
    Value runningMean = adaptor.running_mean();
    Value runningVar = adaptor.running_var();
    Value training = adaptor.training();
    Value eps = adaptor.eps();

    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();

    // TODO: Handle the None cases for the optional parameters:
    // weight, bias.
    if (failed(checkNotNone(rewriter, op, weight)) ||
        failed(checkNotNone(rewriter, op, bias)) ||
        failed(checkNotNone(rewriter, op, runningMean)) ||
        failed(checkNotNone(rewriter, op, runningVar)))
      return failure();

    auto inputType = input.getType().cast<RankedTensorType>();
    auto weightType = weight.getType().cast<RankedTensorType>();
    auto biasType = bias.getType().cast<RankedTensorType>();
    auto runningMeanType = runningMean.getType().cast<RankedTensorType>();
    auto runningVarType = runningVar.getType().cast<RankedTensorType>();

    auto inputRank = inputType.getRank();
    if (inputRank <= 2)
      return rewriter.notifyMatchFailure(
          op, "input should have rank larger than 2");

    if (weightType.getRank() != 1 || biasType.getRank() != 1 ||
        runningMeanType.getRank() != 1 || runningVarType.getRank() != 1) {
      return rewriter.notifyMatchFailure(
          op, "expect weight, bias, running_mean and running_var to be rank 1");
    }

    // TODO: Add support for training.
    auto constFalse = rewriter.create<arith::ConstantOp>(
        loc, IntegerAttr::get(IntegerType::get(context, 1), 0));
    auto trainingFalse = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, training, constFalse);
    rewriter.create<cf::AssertOp>(
        loc, trainingFalse,
        rewriter.getStringAttr("training is not supported for now"));

    // num_features - C from an expected input of size (N, C, D, H, W, ...)
    Value numFeatures = rewriter.create<tensor::DimOp>(loc, input, 1);
    auto contractingDim0EqualsNumFeatures = [&](Value v) {
      auto dim0 = rewriter.create<tensor::DimOp>(loc, v, 0);
      auto dim0Equal = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, numFeatures, dim0);
      rewriter.create<cf::AssertOp>(
          loc, dim0Equal,
          rewriter.getStringAttr(
              "expect the size of dim 0 equal to the number of features"));
    };
    contractingDim0EqualsNumFeatures(weight);
    contractingDim0EqualsNumFeatures(bias);
    contractingDim0EqualsNumFeatures(runningMean);
    contractingDim0EqualsNumFeatures(runningVar);

    auto indexingMap = AffineMap::get(
        /*dimCount=*/inputRank,
        /*symbolCount=*/0, rewriter.getAffineDimExpr(1), context);
    SmallVector<AffineMap> indexingMaps = {
        rewriter.getMultiDimIdentityMap(inputRank), // input
        indexingMap,                                // weight
        indexingMap,                                // bias
        indexingMap,                                // runningMean
        indexingMap,                                // runningVar
        rewriter.getMultiDimIdentityMap(inputRank), // output
    };
    SmallVector<StringRef> iteratorTypes(inputRank, "parallel");
    Value batchNorm =
        rewriter
            .create<linalg::GenericOp>(
                loc, input.getType(),
                ValueRange{input, weight, bias, runningMean, runningVar},
                input,
                /*indexingMaps=*/indexingMaps,
                /*iteratorTypes=*/iteratorTypes,
                [&](OpBuilder &b, Location loc, ValueRange args) {
                  Value input = args[0], weight = args[1], bias = args[2],
                        mean = args[3], var = args[4];
                  Value result = createLinalgPayloadCalculationForNormOps(
                      b, loc, var.getType(), input, mean, var, eps, weight,
                      bias);
                  b.create<linalg::YieldOp>(loc, result);
                })
            .getResult(0);
    Type newResultType = getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, batchNorm);
    return success();
  }
};
} // namespace

// For layernorm, the mean and standard-deviation are calculated separately
// over the last certain number of dimensions, which have to be of the shape
// specified by normalized_shape.
//
// The shapes of the different parts are as follows:
// +-------------------+--------------------+
// |  meanAndVarShape  |   normalizedShape  |
// +-------------------+--------------------+
// <-------------- inputShape -------------->
// There are the following steps:
// Step 1. Check if all the arguments meet the requirements.
// Step 2. Common parts to be used for getting mean and var.
//         This includes elements count, affineMap and iteratorTypes.
// Step 3. Get mean.
// Step 4. Get var.
// Step 5. Get layernorm.
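// Worked example (illustrative): for an input of shape [2, 3, 4, 5] with
// normalized_shape = [4, 5], meanAndVarShape is [2, 3] (one mean/var pair per
// leading index) and normalizedShape is [4, 5], so each mean and var is taken
// over 4 * 5 = 20 elements.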
namespace {
class ConvertAtenNativeLayerNormOp
    : public OpConversionPattern<AtenNativeLayerNormOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenNativeLayerNormOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *context = op->getContext();
    Location loc = op->getLoc();
    Value input = adaptor.input();
    Value weight = adaptor.weight();
    Value bias = adaptor.bias();
    Value eps = adaptor.eps();
    Value normalizedShape = op.normalized_shape();

    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();

    // TODO: Handle the None cases for the optional parameters:
    // weight, bias.
    if (failed(checkNotNone(rewriter, op, weight)) ||
        failed(checkNotNone(rewriter, op, bias)))
      return failure();

    auto inputType = input.getType().cast<RankedTensorType>();
    auto weightType = weight.getType().cast<RankedTensorType>();
    auto biasType = bias.getType().cast<RankedTensorType>();
    int64_t inputRank = inputType.getRank();
    Type elemTy = inputType.getElementType();

    // Step 1. Check if all the arguments meet the requirements.
    SmallVector<Value> normalizedShapeSizesTorchInt;
    if (!getListConstructElements(normalizedShape,
                                  normalizedShapeSizesTorchInt)) {
      return rewriter.notifyMatchFailure(op,
                                         "Unimplemented normalized_shape not"
                                         "constructed from ListConstruct");
    }
    SmallVector<Value> normalizedShapeSizesInt = getTypeConvertedValues(
        rewriter, loc, getTypeConverter(), normalizedShapeSizesTorchInt);
    int64_t normalizedShapeRank = normalizedShapeSizesInt.size();
    if (weightType.getRank() != normalizedShapeRank ||
        biasType.getRank() != normalizedShapeRank ||
        inputRank < normalizedShapeRank || normalizedShapeRank < 1)
      return rewriter.notifyMatchFailure(op, "Input or weight or bias shape or"
                                             "normalized shape not compatible");

    // Check all the dimensions match the normalized_shape.
    int64_t meanAndVarShapeRank = inputRank - normalizedShapeSizesInt.size();
    for (auto en : llvm::enumerate(normalizedShapeSizesInt)) {
      auto index = en.index();
      auto inputDim =
          getDimOp(rewriter, loc, input, index + meanAndVarShapeRank);
      auto weightDim = getDimOp(rewriter, loc, weight, index);
      auto biasDim = getDimOp(rewriter, loc, bias, index);

      auto expectedSize = en.value();
      checkDimEqualHelper(rewriter, loc, inputDim, expectedSize);
      checkDimEqualHelper(rewriter, loc, weightDim, expectedSize);
      checkDimEqualHelper(rewriter, loc, biasDim, expectedSize);
    }

    // Get iterator types for input shape.
    SmallVector<StringRef> normalizedShapeIteratorTypes(
        normalizedShapeRank, getReductionIteratorTypeName());
    SmallVector<StringRef> meanAndVarIterationTypes(
        meanAndVarShapeRank, getParallelIteratorTypeName());
    SmallVector<StringRef> inputShapeIteratorTypes = meanAndVarIterationTypes;
    inputShapeIteratorTypes.append(normalizedShapeIteratorTypes);

    // Step 2. Common parts to be used for getting mean and var.

    // Get sizes and affineMaps needed for mean and var.
    AffineMap inputShapeAffineMap = rewriter.getMultiDimIdentityMap(inputRank);
    SmallVector<AffineExpr> meanAndVarShapeExprs;
    for (int i = 0; i < meanAndVarShapeRank; i++)
      meanAndVarShapeExprs.push_back(mlir::getAffineDimExpr(i, context));
    auto meanAndVarShapeAffineMap =
        AffineMap::get(/*dimCount=*/inputRank,
                       /*symbolCount=*/0, meanAndVarShapeExprs, context);
    SmallVector<Value> meanAndVarShapeSizes =
        getTensorSizesUntilDim(rewriter, loc, input, meanAndVarShapeRank - 1);

    // Get number of elements to be used for calculating mean and var.
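    // The element count below is the product of the normalized_shape sizes,
    // i.e. the N in mean = sum(x) / N and var = sum((x - mean)^2) / N.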
    Value elemCnts = normalizedShapeSizesInt[0];
    for (int i = 1; i < normalizedShapeRank; i++) {
      elemCnts = rewriter.create<arith::MulIOp>(loc, elemCnts,
                                                normalizedShapeSizesInt[i]);
    }
    Value elemCntsFloat =
        rewriter.create<arith::SIToFPOp>(loc, elemTy, elemCnts);

    // Helper to calculate mean and var.
    auto genMeanOrVarCalculation = [&](Value sumOrSquareSum) {
      SmallVector<AffineMap> indexingMaps(
          2, rewriter.getMultiDimIdentityMap(meanAndVarShapeRank));
      Value initShapeTensor = rewriter.create<linalg::InitTensorOp>(
          loc, meanAndVarShapeSizes, elemTy);
      return rewriter
          .create<linalg::GenericOp>(
              loc, initShapeTensor.getType(), sumOrSquareSum, initShapeTensor,
              /*indexingMaps=*/indexingMaps,
              /*iteratorTypes=*/meanAndVarIterationTypes,
              [&](OpBuilder &b, Location loc, ValueRange args) {
                Value sumOrSquareSum = args[0];
                Value result = b.create<arith::DivFOp>(loc, sumOrSquareSum,
                                                       elemCntsFloat);
                b.create<linalg::YieldOp>(loc, result);
              })
          .getResult(0);
    };

    // Step 3. Get mean.

    // Get sum to be used for calculating mean.
    SmallVector<AffineMap, 2> sumIndexingMaps = {
        inputShapeAffineMap,      // input
        meanAndVarShapeAffineMap, // output
    };
    auto initSumTensor =
        createZeroInitTensor(rewriter, loc, meanAndVarShapeSizes, elemTy);
    Value sum = rewriter
                    .create<linalg::GenericOp>(
                        loc, initSumTensor.getType(), input, initSumTensor,
                        /*indexingMaps=*/sumIndexingMaps,
                        /*iteratorTypes=*/inputShapeIteratorTypes,
                        [&](OpBuilder &b, Location loc, ValueRange args) {
                          Value input = args[0], sum = args[1];
                          Value result =
                              rewriter.create<arith::AddFOp>(loc, sum, input);
                          b.create<linalg::YieldOp>(loc, result);
                        })
                    .getResult(0);
    Value mean = genMeanOrVarCalculation(sum);

    // Step 4. Get var.

    // Calculate squareSum for the layer.
    SmallVector<AffineMap> squareSumIndexingMaps{
        inputShapeAffineMap,
        meanAndVarShapeAffineMap,
        meanAndVarShapeAffineMap,
    };
    auto initSquareSumTensor =
        createZeroInitTensor(rewriter, loc, meanAndVarShapeSizes, elemTy);
    Value squareSum =
        rewriter
            .create<linalg::GenericOp>(
                loc, initSquareSumTensor.getType(), ValueRange{input, mean},
                initSquareSumTensor,
                /*indexingMaps=*/squareSumIndexingMaps,
                /*iteratorTypes=*/inputShapeIteratorTypes,
                [&](OpBuilder &b, Location loc, ValueRange args) {
                  Value input = args[0], mean = args[1], squareSum = args[2];
                  Value sub = rewriter.create<arith::SubFOp>(loc, input, mean);
                  Value square = rewriter.create<arith::MulFOp>(loc, sub, sub);
                  Value result =
                      rewriter.create<arith::AddFOp>(loc, squareSum, square);
                  b.create<linalg::YieldOp>(loc, result);
                })
            .getResult(0);
    Value var = genMeanOrVarCalculation(squareSum);
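    // As in PyTorch's layer_norm, this is the biased estimator: the square
    // sum is divided by the element count N, not by N - 1.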
    // Step 5. Get layernorm.

    // Get affineMap for normalized shape.
    SmallVector<AffineExpr> normalizedShapeExprs;
    for (int i = meanAndVarShapeRank; i < inputRank; i++)
      normalizedShapeExprs.push_back(mlir::getAffineDimExpr(i, context));
    auto normalizedShapeAffineMap =
        AffineMap::get(/*dimCount=*/inputRank,
                       /*symbolCount=*/0, normalizedShapeExprs, context);
    auto inputSizes = getTensorSizes(rewriter, loc, input);
    Value initLayerNormTensor =
        rewriter.create<linalg::InitTensorOp>(loc, inputSizes, elemTy);
    SmallVector<AffineMap> indexingMaps(1, inputShapeAffineMap);
    indexingMaps.resize(3, meanAndVarShapeAffineMap);
    indexingMaps.resize(5, normalizedShapeAffineMap);
    indexingMaps.push_back(inputShapeAffineMap);
    SmallVector<StringRef> layerNormIterationTypes(
        inputRank, getParallelIteratorTypeName());
    Value layerNorm =
        rewriter
            .create<linalg::GenericOp>(
                loc, initLayerNormTensor.getType(),
                ValueRange{input, mean, var, weight, bias},
                initLayerNormTensor,
                /*indexingMaps=*/indexingMaps,
                /*iteratorTypes=*/layerNormIterationTypes,
                [&](OpBuilder &b, Location loc, ValueRange args) {
                  Value input = args[0], mean = args[1], var = args[2],
                        weight = args[3], bias = args[4];
                  Value result = createLinalgPayloadCalculationForNormOps(
                      b, loc, elemTy, input, mean, var, eps, weight, bias);
                  b.create<linalg::YieldOp>(loc, result);
                })
            .getResult(0);
    Type layerNormResultType = getTypeConverter()->convertType(op.getType(0));
    Type meanResultType = getTypeConverter()->convertType(op.getType(1));
    Type varResultType = getTypeConverter()->convertType(op.getType(2));
    Value layerNorm_ =
        rewriter.create<tensor::CastOp>(loc, layerNormResultType, layerNorm);
    Value mean_ = rewriter.create<tensor::CastOp>(loc, meanResultType, mean);
    Value var_ = rewriter.create<tensor::CastOp>(loc, varResultType, var);
    rewriter.replaceOp(op, {layerNorm_, mean_, var_});
    return success();
  }
};
} // namespace

// Given `grad_output`, `input`, `target`, `nll_loss_backward` is given by:
//   for i in range(0, len(input[0])):
//     for j in range(0, len(input[1])):
//       nll_loss_backward[i][j] = (j == target[i]) ? -grad_output[i] : 0
// TODO: `weight` and `reduction` operands are still to be taken care of.
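// Worked example (illustrative): with grad_output = [g0, g1] and
// target = [1, 0], grad_input is [[0, -g0], [-g1, 0]]; rows whose target
// equals ignore_index stay entirely at 0.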
namespace {
class ConvertAtenNllLossBackwardOp
    : public OpConversionPattern<AtenNllLossBackwardOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenNllLossBackwardOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    Location loc = op->getLoc();
    Value input = adaptor.self();
    Value target = adaptor.target();
    Value weight = adaptor.weight();
    Value gradOutput = adaptor.grad_output();

    int64_t reduction;
    if (!matchPattern(op.reduction(), m_TorchConstantInt(&reduction)))
      return rewriter.notifyMatchFailure(op, "dim must be constant");

    // TODO: Handle reduction.
    if (reduction != torch_upstream::Reduction::None)
      return rewriter.notifyMatchFailure(
          op, "reduction along dimensions is not supported.");

    // TODO: Incorporate the weight argument.
    if (!weight.getType().isa<Torch::NoneType>())
      return rewriter.notifyMatchFailure(
          op, "Unimplemented, the weight operand is not incorporated.");

    Value ignoreIndex = adaptor.ignore_index();
    Value ignoreIndexVal = castIntToIndex(rewriter, loc, ignoreIndex);

    unsigned inputRank = input.getType().cast<RankedTensorType>().getRank();
    unsigned targetRank = target.getType().cast<RankedTensorType>().getRank();

    // TODO: Cases with targetRank != 1 where `Mean` or `Sum` reduction is
    // required.
    if (inputRank != 2 || targetRank != 1) {
      return rewriter.notifyMatchFailure(
          op, "expected input and target to be rank 2 and 1 respectively");
    }
    RankedTensorType resultType = getTypeConverter()
                                      ->convertType(op->getResult(0).getType())
                                      .cast<RankedTensorType>();
    Type elementType = resultType.getElementType();

    // Given there is no reduction `grad_input` size is equal to `input` size.
    auto outputSize = getTensorSizes(rewriter, loc, input);
    Value initTensor0 =
        createZeroInitTensor(rewriter, loc, outputSize, elementType);
    Value zeroVal = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementType));

    SmallVector<AffineExpr> targetExpr{rewriter.getAffineDimExpr(0)};
    SmallVector<AffineExpr> resultExpr{rewriter.getAffineDimExpr(0),
                                       rewriter.getAffineDimExpr(1)};
    SmallVector<StringRef> iteratorTypes{getParallelIteratorTypeName(),
                                         getParallelIteratorTypeName()};
    auto indexingMaps =
        AffineMap::inferFromExprList({targetExpr, targetExpr, resultExpr});
    Value finalRes =
        rewriter
            .create<linalg::GenericOp>(
                loc, initTensor0.getType(), ValueRange{target, gradOutput},
                initTensor0,
                /*indexingMaps=*/indexingMaps,
                /*iteratorTypes=*/iteratorTypes,
                [&](OpBuilder &b, Location loc, ValueRange args) {
                  Value indTarget = rewriter.create<arith::IndexCastOp>(
                      loc, rewriter.getIndexType(), args[0]);
                  Value indJ = rewriter.create<linalg::IndexOp>(loc, 1);

                  // The final result is given by:
                  // grad_input[i][j] = (j == target[i]) ? -grad_output[i] : 0
                  Value cmpEq = rewriter.create<arith::CmpIOp>(
                      loc, arith::CmpIPredicate::eq, indJ, indTarget);

                  // The target index shouldn't be equal to `ignoreIndex`.
                  Value cmpNe = rewriter.create<arith::CmpIOp>(
                      loc, arith::CmpIPredicate::ne, ignoreIndexVal, indTarget);
                  Value finalPredicate =
                      rewriter.create<arith::AndIOp>(loc, cmpEq, cmpNe);
                  Value negate = rewriter.create<arith::NegFOp>(
                      loc, elementType, args[1]);
                  Value selectFinal = rewriter.create<arith::SelectOp>(
                      loc, finalPredicate, negate, zeroVal);
                  b.create<linalg::YieldOp>(loc, selectFinal);
                })
            .getResult(0);

    rewriter.replaceOp(op, finalRes);
    return success();
  }
};
} // namespace

namespace {
class ConvertTensorStaticInfoCastOp
    : public OpConversionPattern<TensorStaticInfoCastOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(TensorStaticInfoCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    RankedTensorType resultType = getTypeConverter()
                                      ->convertType(op->getResult(0).getType())
                                      .cast<RankedTensorType>();
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
                                                adaptor.operand());
    return success();
  }
};
} // namespace

void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target) {
  MLIRContext *context = patterns.getContext();
  target.addIllegalOp<
      AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp,
      AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
      AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenMinimumOp, AtenMaximumOp,
      AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp,
      AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp,
      AtenLog2Op, AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp, AtenReciprocalOp,
      AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp,
      AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
      AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp,
      AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp>();
  patterns.add<ConvertElementwiseOp>(typeConverter, context);
  target.addIllegalOp<AtenNllLossForwardOp>();
  patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
  target.addIllegalOp<AtenBatchNormOp>();
  patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
  target.addIllegalOp<AtenNativeLayerNormOp>();
  patterns.add<ConvertAtenNativeLayerNormOp>(typeConverter, context);
  target.addIllegalOp<AtenNllLossBackwardOp>();
  patterns.add<ConvertAtenNllLossBackwardOp>(typeConverter, context);
  patterns.add<ConvertTensorStaticInfoCastOp>(typeConverter, context);
  target.addIllegalOp<TensorStaticInfoCastOp>();
}