//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "../PassDetail.h" #include "PopulatePatterns.h" #include "Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Matchers.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "llvm/ADT/APSInt.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; // Check if a ranked-tensor has the specified element type. template static bool hasElementType(Value tensor) { auto tensorType = tensor.getType().cast(); Type tensorElementType = tensorType.getElementType(); return tensorElementType.isa(); } template static Value createComparisonTemplate(OpBuilder &b, Location loc, Type type, Value lhs, Value rhs) { if (type.isa()) return b.create(loc, fpred, lhs, rhs); if (IntegerType intType = type.dyn_cast()) { if (intType.isUnsigned()) return b.create(loc, iupred, lhs, rhs); if (intType.isSigned()) return b.create(loc, ispred, lhs, rhs); } llvm_unreachable("Unhandled element type for comparison"); } static Value createGreaterThan(OpBuilder &b, Location loc, Type elementalType, Value lhs, Value rhs) { return createComparisonTemplate( b, loc, elementalType, lhs, rhs); } static Value createGreaterThanOrEqual(OpBuilder &b, Location loc, Type elementalType, Value lhs, Value rhs) { return createComparisonTemplate( b, loc, elementalType, lhs, rhs); } static Value createLessThan(OpBuilder &b, Location loc, Type elementalType, Value lhs, Value rhs) { return createComparisonTemplate( b, loc, elementalType, lhs, rhs); } static Value createLessThanOrEqual(OpBuilder &b, Location loc, Type elementalType, Value lhs, Value rhs) { return createComparisonTemplate( b, loc, elementalType, lhs, rhs); } static Value createEqual(OpBuilder &b, Location loc, Type elementalType, Value lhs, Value rhs) { return createComparisonTemplate( b, loc, elementalType, lhs, rhs); } static Value createNotEqual(OpBuilder &b, Location loc, Type elementalType, Value lhs, Value rhs) { return createComparisonTemplate( b, loc, elementalType, lhs, rhs); } static Value buildNormalCdf(OpBuilder &b, Location &loc, Value x, Value mean, Value sigma) { Type elementType = x.getType(); Value xMinusMean = b.create(loc, x, mean); Value two = b.create(loc, FloatAttr::get(elementType, 2)); Value sqrt2 = b.create(loc, two); Value erfArg = b.create(loc, xMinusMean, sqrt2); Value erf = b.create(loc, erfArg); Value one = b.create(loc, FloatAttr::get(elementType, 1)); Value erfPlus1 = b.create(loc, one, erf); Value oneHalf = b.create(loc, FloatAttr::get(elementType, 0.5)); Value normalCdf = b.create(loc, oneHalf, erfPlus1); return normalCdf; } static Value buildUnitNormalCdf(OpBuilder &b, Location &loc, Value x) { Type elementType = x.getType(); Value zero = b.create(loc, FloatAttr::get(elementType, 0)); Value one = b.create(loc, FloatAttr::get(elementType, 1)); return buildNormalCdf(b, loc, x, zero, one); } template static Value createCalculationForMathOpWithDtypeConversion( OpBuilder &b, TypeConverter *converter, Value payloadArg, Operation *op) { Type dtype = converter->convertType(op->getResult(0).getType()) .template cast() .getElementType(); Location loc = op->getLoc(); Value arg = convertScalarToDtype(b, loc, payloadArg, dtype); return b.create(loc, arg); } template static Value createCompareTensorOp(OpBuilder &b, Location loc, OpTy op, Value lhs, Value rhs) { static_assert(std::is_same() || std::is_same() || std::is_same() || std::is_same() || std::is_same(), "unimplemented: op type not supported"); Type lhsDtype = lhs.getType(); Type rhsDtype = rhs.getType(); // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs // to be handled. if (lhsDtype != rhsDtype) { op.emitError("unimplemented: lhs and rhs dtype must be same"); return nullptr; } Type elementalType = op.getSelf().getType().template cast().getDtype(); if constexpr (std::is_same()) { return createLessThan(b, loc, elementalType, lhs, rhs); } if constexpr (std::is_same()) { return createLessThanOrEqual(b, loc, elementalType, lhs, rhs); } if constexpr (std::is_same()) { return createGreaterThan(b, loc, elementalType, lhs, rhs); } if constexpr (std::is_same()) { return createGreaterThanOrEqual(b, loc, elementalType, lhs, rhs); } if constexpr (std::is_same()) { return createEqual(b, loc, elementalType, lhs, rhs); } llvm_unreachable("unimplemented: op type not supported"); } static Value createLinalgPayloadCalculationForElementwiseOp( OpBuilder &b, Location loc, TypeConverter *converter, ValueRange payloadArgs, Operation *op, ArrayRef operands) { if (isa(op)) return b.create(loc, payloadArgs[0]); if (isa(op)) return b.create(loc, payloadArgs[0]); if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } if (auto clone = dyn_cast(op)) { int64_t memoryFormat; if (!clone.getMemoryFormat().getType().isa() && (!matchPattern(clone.getMemoryFormat(), m_TorchConstantInt(&memoryFormat)) || (memoryFormat != torch_upstream::MemoryFormat::Contiguous && memoryFormat != torch_upstream::MemoryFormat::ChannelsLast))) { clone.emitError("unimplemented: only contiguous and channels last memory " "format is supported"); return nullptr; } return payloadArgs[0]; } if (auto bitwiseAndTensor = dyn_cast(op)) { if (bitwiseAndTensor.getType() .cast() .getDtype() .isa()) { bitwiseAndTensor.emitError( "Bitwise_And does not support floating point dtype"); return nullptr; } Type dtype = converter->convertType(bitwiseAndTensor.getType()) .cast() .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); return b.create(loc, lhs, rhs); } if (auto bitwiseOrTensor = dyn_cast(op)) { if (bitwiseOrTensor.getType() .cast() .getDtype() .isa()) { bitwiseOrTensor.emitError( "Bitwise_Or does not support floating point dtype"); return nullptr; } Type dtype = converter->convertType(bitwiseOrTensor.getType()) .cast() .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); return b.create(loc, lhs, rhs); } if (auto bitwiseXorTensor = dyn_cast(op)) { if (bitwiseXorTensor.getType() .cast() .getDtype() .isa()) { bitwiseXorTensor.emitError( "Bitwise_Xor does not support floating point dtype"); return nullptr; } Type dtype = converter->convertType(bitwiseXorTensor.getType()) .cast() .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); return b.create(loc, lhs, rhs); } if (isa(op)) { MLIRContext *context = op->getContext(); Type floatDtype = mlir::FloatType::getF64(context); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], floatDtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], floatDtype); Value zero = b.create(loc, b.getFloatAttr(floatDtype, 0)); Value lhsTest = createNotEqual(b, loc, floatDtype, lhs, zero); Value rhsTest = createNotEqual(b, loc, floatDtype, rhs, zero); if (isa(op)) { return b.create(loc, lhsTest, rhsTest); } if (isa(op)) { return b.create(loc, lhsTest, rhsTest); } if (isa(op)) { return b.create(loc, lhsTest, rhsTest); } llvm_unreachable("Unknown op type"); } if (isa(op)) { MLIRContext *context = op->getContext(); Type floatDtype = mlir::FloatType::getF64(context); Value self = convertScalarToDtype(b, loc, payloadArgs[0], floatDtype); Value zero = b.create(loc, b.getFloatAttr(floatDtype, 0)); return createEqual(b, loc, floatDtype, self, zero); } if (isa(op)) return b.create(loc, payloadArgs[0]); if (isa(op)) { auto negate = createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); auto one = b.create(loc, FloatAttr::get(negate.getType(), 1)); auto exp = b.create(loc, negate); auto added = b.create(loc, exp, one); return b.create(loc, one, added); } if (auto relu = dyn_cast(op)) { if (!relu.getType() .cast() .getDtype() .isa()) { relu.emitError("unimplemented: non-floating point dtype"); return nullptr; } Type elementType = payloadArgs[0].getType(); Value constZero = b.create(loc, b.getZeroAttr(elementType)); Value pred = b.create(loc, arith::CmpFPredicate::UGT, payloadArgs[0], constZero); return b.create(loc, pred, payloadArgs[0], constZero); } if (auto round = dyn_cast(op)) { if (!round.getType() .cast() .getDtype() .isa()) { round.emitError("unimplemented: non-floating point dtype"); return nullptr; } return b.create(loc, payloadArgs[0]); } if (auto prelu = dyn_cast(op)) { if (!prelu.getType() .cast() .getDtype() .isa()) { prelu.emitError("unimplemented: non-floating point dtype"); return nullptr; } Type elementType = payloadArgs[0].getType(); Value constZero = b.create(loc, b.getZeroAttr(elementType)); Value pred = b.create(loc, arith::CmpFPredicate::UGT, payloadArgs[0], constZero); Value positivePart = b.create(loc, pred, payloadArgs[0], constZero); Value negativePart = b.create(loc, pred, constZero, payloadArgs[0]); Value scale = convertScalarToDtype(b, loc, payloadArgs[1], elementType); Value scaledNegativePart = b.create(loc, negativePart, scale); return b.create(loc, positivePart, scaledNegativePart); } if (auto gelu = dyn_cast(op)) { if (!gelu.getType() .cast() .getDtype() .isa()) { gelu.emitError("unimplemented: non-floating point dtype"); return nullptr; } // TODO: Take approximation into account. std::string approximate; if (!matchPattern(gelu.getApproximate(), m_TorchConstantStr(approximate)) || approximate != "none") return nullptr; Value cdf = buildUnitNormalCdf(b, loc, payloadArgs[0]); return b.create(loc, payloadArgs[0], cdf); } if (auto geluBackward = dyn_cast(op)) { if (!geluBackward.getType() .cast() .getDtype() .isa()) { geluBackward.emitError("unimplemented: non-floating point dtype"); return nullptr; } // TODO: Take approximation into account. std::string approximate; if (!matchPattern(geluBackward.getApproximate(), m_TorchConstantStr(approximate)) || approximate != "none") return nullptr; Type elementType = payloadArgs[1].getType(); Value cstAlpha0 = b.create( loc, FloatAttr::get(elementType, 1.12837916709551257390)); Value cstAlpha1 = b.create( loc, FloatAttr::get(elementType, 0.70710678118654752440)); Value oneHalf = b.create(loc, FloatAttr::get(elementType, 0.5)); Value kAlpha = b.create(loc, cstAlpha0, cstAlpha1); Value kAlphaHalf = b.create(loc, kAlpha, oneHalf); Value negOneHalf = b.create(loc, FloatAttr::get(elementType, -0.5)); Value inputSquared = b.create(loc, payloadArgs[1], payloadArgs[1]); Value negHalfInputSquared = b.create(loc, inputSquared, negOneHalf); Value dinput = b.create(loc, negHalfInputSquared); Value cdf = buildUnitNormalCdf(b, loc, payloadArgs[1]); Value dinputInput = b.create(loc, dinput, payloadArgs[1]); Value dinputInputAlpha = b.create(loc, dinputInput, kAlphaHalf); Value cdfExt = b.create(loc, dinputInputAlpha, cdf); return b.create(loc, payloadArgs[0], cdfExt); } if (auto hardtanhBackward = dyn_cast(op)) { AtenHardtanhBackwardOp::Adaptor adaptor(operands); if (!hardtanhBackward.getType() .cast() .getDtype() .isa()) { hardtanhBackward.emitError("unimplemented: non-floating point dtype"); return nullptr; } Value gradOutput = payloadArgs[0]; Type elementType = gradOutput.getType(); Value self = convertScalarToDtype(b, loc, payloadArgs[1], elementType); Value constantZero = b.create(loc, FloatAttr::get(elementType, 0.0)); Value min = convertScalarToDtype(b, loc, adaptor.getMinVal(), elementType); Value max = convertScalarToDtype(b, loc, adaptor.getMaxVal(), elementType); Value lesser = b.create(loc, arith::CmpFPredicate::ULT, self, min); Value greater = b.create(loc, arith::CmpFPredicate::UGT, self, max); Value cmp = b.create(loc, lesser, greater); return b.create(loc, cmp, constantZero, gradOutput); } if (auto add = dyn_cast(op)) { AtenAddTensorOp::Adaptor adaptor(operands); Type dtype = converter->convertType(add.getType()) .cast() .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); Value alpha = convertScalarToDtype(b, loc, adaptor.getAlpha(), dtype); if (dtype.isa()) { Value scaled = b.create(loc, rhs, alpha); return b.create(loc, lhs, scaled); } else { Value scaled = b.create(loc, rhs, alpha); return b.create(loc, lhs, scaled); } } if (auto sub = dyn_cast(op)) { AtenSubTensorOp::Adaptor adaptor(operands); Type dtype = converter->convertType(sub.getType()) .cast() .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); Value alpha = convertScalarToDtype(b, loc, adaptor.getAlpha(), dtype); if (dtype.isa()) { Value scaled = b.create(loc, rhs, alpha); return b.create(loc, lhs, scaled); } else { Value scaled = b.create(loc, rhs, alpha); return b.create(loc, lhs, scaled); } } if (auto subScalar = dyn_cast(op)) { Type dtype = converter->convertType(subScalar.getType()) .cast() .getElementType(); Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value other = convertScalarToDtype(b, loc, operands[1], dtype); Value alpha = convertScalarToDtype(b, loc, operands[2], dtype); if (dtype.isa()) { Value mult = b.create(loc, other, alpha); return b.create(loc, self, mult); } else if (dtype.isa()) { Value mult = b.create(loc, other, alpha); return b.create(loc, self, mult); } subScalar.emitError("unimplemented: dtype other than float and integer " "types are not supported."); return nullptr; } if (auto addScalar = dyn_cast(op)) { Type dtype = converter->convertType(addScalar.getType()) .cast() .getElementType(); Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value other = convertScalarToDtype(b, loc, operands[1], dtype); Value alpha = convertScalarToDtype(b, loc, operands[2], dtype); if (dtype.isa()) { Value mult = b.create(loc, other, alpha); return b.create(loc, self, mult); } else if (dtype.isa()) { Value mult = b.create(loc, other, alpha); return b.create(loc, self, mult); } addScalar.emitError("unimplemented: dtype other than float and integer " "types are not supported."); return nullptr; } if (auto mul = dyn_cast(op)) { AtenMulTensorOp::Adaptor adaptor(operands); Type dtype = converter->convertType(mul.getType()) .cast() .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); if (dtype.isa()) { return b.create(loc, lhs, rhs); } else { return b.create(loc, lhs, rhs); } } if (auto atan2 = dyn_cast(op)) { Type dtype = converter->convertType(atan2.getType()) .cast() .getElementType(); if (!dtype.isa()) { atan2.emitError("Atan2 requires floating point result type"); return nullptr; } Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); return b.create(loc, lhs, rhs); } if (auto ltTensor = dyn_cast(op)) { return createCompareTensorOp(b, loc, ltTensor, payloadArgs[0], payloadArgs[1]); } if (auto leTensor = dyn_cast(op)) { return createCompareTensorOp(b, loc, leTensor, payloadArgs[0], payloadArgs[1]); } if (auto gtTensor = dyn_cast(op)) { return createCompareTensorOp(b, loc, gtTensor, payloadArgs[0], payloadArgs[1]); } if (auto geTensor = dyn_cast(op)) { return createCompareTensorOp(b, loc, geTensor, payloadArgs[0], payloadArgs[1]); } if (auto eqTensor = dyn_cast(op)) { return createCompareTensorOp(b, loc, eqTensor, payloadArgs[0], payloadArgs[1]); } if (auto div = dyn_cast(op)) { AtenDivTensorOp::Adaptor adaptor(operands); Type dtype = converter->convertType(div.getType()) .cast() .getElementType(); if (!dtype.isa()) { div.emitError("unimplemented: non-floating point dtype"); return nullptr; } Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); return b.create(loc, lhs, rhs); } if (auto divTensorMode = dyn_cast(op)) { AtenDivTensorModeOp::Adaptor adaptor(operands); Type dtype = converter->convertType(divTensorMode.getType()) .cast() .getElementType(); if (!dtype.isa()) { divTensorMode.emitError("unimplemented: non-floating point dtype"); return nullptr; } Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); Value div = b.create(loc, lhs, rhs); if (divTensorMode.getRoundingMode().getType().isa()) return div; std::string roundingMode; if (!matchPattern(divTensorMode.getRoundingMode(), m_TorchConstantStr(roundingMode))) { divTensorMode.emitError("only support constant str rounding mode"); return nullptr; } if (roundingMode == "trunc") { // "trunc" - rounds the results of the division towards zero. Equivalent // to C-style integer division. Value ceil = b.create(loc, div); Value floor = b.create(loc, div); Value cstZero = b.create(loc, b.getZeroAttr(dtype)); Value pred = b.create(loc, arith::CmpFPredicate::ULT, div, cstZero); return b.create(loc, pred, ceil, floor); } if (roundingMode == "floor") { // "floor" - rounds the results of the division down. Equivalent to // floor division in Python (the // operator) return b.create(loc, div); } divTensorMode.emitError("invalid rounding mode"); return nullptr; } if (auto pow = dyn_cast(op)) { if (!pow.getType() .cast() .getDtype() .isa()) { pow.emitError("unimplemented: non-floating point dtype"); return nullptr; } Type dtype = pow.getSelf().getType().cast().getDtype(); Value expPromoted = convertScalarToDtype(b, loc, operands[1], dtype); return b.create(loc, payloadArgs[0], expPromoted); } if (auto pow = dyn_cast(op)) { Type dtype = converter->convertType(pow.getType()) .cast() .getElementType(); if (!dtype.isa()) { pow.emitError("unimplemented: non-floating point dtype"); return nullptr; } Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); return b.create(loc, lhs, rhs); } if (auto imag = dyn_cast(op)) { Type dtype = converter->convertType(imag.getType()) .cast() .getElementType(); if (!dtype.isa()) { imag.emitError("unimplemented: non-floating point dtype"); return nullptr; } Value imagVal = b.create(loc, payloadArgs[0]); return imagVal; } if (auto real = dyn_cast(op)) { Type dtype = converter->convertType(real.getType()) .cast() .getElementType(); if (!dtype.isa()) { real.emitError("unimplemented: non-floating point dtype"); return nullptr; } Value realVal = b.create(loc, payloadArgs[0]); return realVal; } if (auto gtScalar = dyn_cast(op)) { Type dtype = gtScalar.getSelf().getType().cast().getDtype(); // TODO: `gtTensor` and `gtScalar` share similar code and can be called from // one static function. Value otherPromoted = convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); if (dtype.isa()) return b.create(loc, arith::CmpFPredicate::UGT, payloadArgs[0], otherPromoted); if (IntegerType intType = dtype.dyn_cast()) { if (!operands[1].getType().isa()) { // TODO: Promote tensor args from integer to float. gtScalar.emitError( "unimplemented: type promotion from tensor to scalar."); return nullptr; } if (intType.isUnsigned()) return b.create(loc, arith::CmpIPredicate::ugt, payloadArgs[0], otherPromoted); if (intType.isSigned()) return b.create(loc, arith::CmpIPredicate::sgt, payloadArgs[0], otherPromoted); } gtScalar.emitError("unimplemented: dtype isn't supported."); return nullptr; } if (auto geScalar = dyn_cast(op)) { Type dtype = geScalar.getSelf().getType().cast().getDtype(); // TODO: The `AtenGeScalarOp` and `AtenGtScalarOp` share a lot of code that // can be refactored. Value otherPromoted = convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); if (dtype.isa()) return b.create(loc, arith::CmpFPredicate::UGE, payloadArgs[0], otherPromoted); if (IntegerType intType = dtype.dyn_cast()) { if (!operands[1].getType().isa()) { // TODO: Promote tensor args from integer to float. geScalar.emitError( "unimplemented: type promotion from tensor to scalar."); return nullptr; } if (intType.isUnsigned()) return b.create(loc, arith::CmpIPredicate::uge, payloadArgs[0], otherPromoted); if (intType.isSigned()) return b.create(loc, arith::CmpIPredicate::sge, payloadArgs[0], otherPromoted); } geScalar.emitError("unimplemented: dtype isn't supported."); return nullptr; } if (auto eqScalar = dyn_cast(op)) { Type dtype = eqScalar.getSelf().getType().cast().getDtype(); Value otherPromoted = convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); if (dtype.isa()) { if (!operands[1].getType().isa()) { // TODO: Promote tensor operand from integer to float. eqScalar.emitError( "unimplemented: type promotion from tensor to scalar"); return nullptr; } } return createEqual(b, loc, dtype, payloadArgs[0], otherPromoted); } if (auto neScalar = dyn_cast(op)) { Type dtype = neScalar.getSelf().getType().cast().getDtype(); Value otherPromoted = convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); if (dtype.isa()) { if (!operands[1].getType().isa()) { // TODO: Promote tensor operand from integer to float. neScalar.emitError( "unimplemented: type promotion from tensor to scalar"); return nullptr; } } return createNotEqual(b, loc, dtype, payloadArgs[0], otherPromoted); } if (auto ltScalar = dyn_cast(op)) { Type dtype = ltScalar.getSelf().getType().cast().getDtype(); Value otherPromoted = convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); // TODO: Both tensor and scalar variants of `aten.gt` and `aten.lt` share // a lot of code that can be refactored. if (dtype.isa()) return b.create(loc, arith::CmpFPredicate::ULT, payloadArgs[0], otherPromoted); if (IntegerType intType = dtype.dyn_cast()) { if (!operands[1].getType().isa()) { // TODO: Promote tensor operand from integer to float. ltScalar.emitError( "unimplemented: type promotion from tensor to scalar"); return nullptr; } if (intType.isUnsigned()) return b.create(loc, arith::CmpIPredicate::ult, payloadArgs[0], otherPromoted); if (intType.isSigned()) return b.create(loc, arith::CmpIPredicate::slt, payloadArgs[0], otherPromoted); } ltScalar.emitError("unimplemented: dtype isn't supported."); return nullptr; } if (auto leScalar = dyn_cast(op)) { Type dtype = leScalar.getSelf().getType().cast().getDtype(); Value otherPromoted = convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); // TODO: The `AtenLeScalarOp` and `AtenLtScalarOp` share a lot of code // that can be refactored. if (dtype.isa()) return b.create(loc, arith::CmpFPredicate::ULE, payloadArgs[0], otherPromoted); if (IntegerType intType = dtype.dyn_cast()) { if (!operands[1].getType().isa()) { // TODO: Promote tensor operand from integer to float. leScalar.emitError( "unimplemented: type promotion from tensor to scalar"); return nullptr; } if (intType.isUnsigned()) return b.create(loc, arith::CmpIPredicate::ule, payloadArgs[0], otherPromoted); if (intType.isSigned()) return b.create(loc, arith::CmpIPredicate::sle, payloadArgs[0], otherPromoted); } leScalar.emitError("unimplemented: dtype isn't supported."); return nullptr; } if (auto whereSelf = dyn_cast(op)) { Type dtype = converter->convertType(whereSelf.getType()) .cast() .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[2], dtype); return b.create(loc, payloadArgs[0], lhs, rhs); } if (auto lerp = dyn_cast(op)) { if (!lerp.getType() .cast() .getDtype() .isa()) { lerp.emitError("unimplemented: non-floating point dtype"); return nullptr; } AtenLerpTensorOp::Adaptor adaptor(payloadArgs); auto start = adaptor.getSelf(); auto end = adaptor.getEnd(); auto weight = adaptor.getWeight(); auto delta = b.create(loc, end, start); auto weightedDelta = b.create(loc, delta, weight); return b.create(loc, start, weightedDelta); } if (auto minimum = dyn_cast(op)) { Type dtype = minimum.getType().cast().getDtype(); Type elemTy = converter->convertType(minimum.getType()) .cast() .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], elemTy); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], elemTy); Value pred = createLessThan(b, loc, dtype, lhs, rhs); return b.create(loc, pred, lhs, rhs); } if (auto maximum = dyn_cast(op)) { Type dtype = maximum.getType().cast().getDtype(); Type elemTy = converter->convertType(maximum.getType()) .cast() .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], elemTy); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], elemTy); Value pred = createGreaterThan(b, loc, dtype, lhs, rhs); return b.create(loc, pred, lhs, rhs); } if (auto clamp = dyn_cast(op)) { Type dtype = converter->convertType(clamp.getType()) .cast() .getElementType(); if (!dtype.isa()) { clamp.emitError("unimplemented: non-floating point dtype"); return nullptr; } AtenClampOp::Adaptor adaptor(operands); auto min = adaptor.getMin(); auto max = adaptor.getMax(); if (min.getType().isa() || max.getType().isa()) { clamp.emitError("unimplemented: runtime optional type"); return nullptr; } auto result = payloadArgs[0]; if (!min.getType().isa()) { auto minPromoted = convertScalarToDtype(b, loc, min, dtype); auto pred = b.create(loc, arith::CmpFPredicate::ULT, result, minPromoted); result = b.create(loc, pred, minPromoted, result); } if (!max.getType().isa()) { auto maxPromoted = convertScalarToDtype(b, loc, max, dtype); auto pred = b.create(loc, arith::CmpFPredicate::UGT, result, maxPromoted); result = b.create(loc, pred, maxPromoted, result); } return result; } if (auto rsub = dyn_cast(op)) { Type dtype = converter->convertType(rsub.getType()) .cast() .getElementType(); Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value other = convertScalarToDtype(b, loc, operands[1], dtype); Value alpha = convertScalarToDtype(b, loc, operands[2], dtype); if (dtype.isa()) { Value mult = b.create(loc, self, alpha); return b.create(loc, other, mult); } else if (dtype.isa()) { Value mult = b.create(loc, self, alpha); return b.create(loc, other, mult); } rsub.emitError("unimplemented: dtype other than float and integer " "types are not supported."); return nullptr; } if (auto mulScalar = dyn_cast(op)) { Type dtype = converter->convertType(mulScalar.getType()) .cast() .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, operands[1], dtype); if (dtype.isa()) return b.create(loc, lhs, rhs); if (dtype.isa()) return b.create(loc, lhs, rhs); mulScalar.emitError("unimplemented: Only integer/float dtype supported"); return nullptr; } if (auto atenToDtype = dyn_cast(op)) { Value input = payloadArgs[0]; Type dtype = converter->convertType(atenToDtype.getType()) .cast() .getElementType(); Value result = convertScalarToDtype(b, loc, input, dtype); return result; } if (auto divScalar = dyn_cast(op)) { Type dtype = converter->convertType(divScalar.getType()) .cast() .getElementType(); if (!dtype.isa()) { divScalar.emitError("unimplemented: non-floating point dtype"); return nullptr; } Value self = payloadArgs[0]; Value other = convertScalarToDtype(b, loc, operands[1], dtype); return b.create(loc, self, other); } if (auto remScalar = dyn_cast(op)) { Type newResultType = converter->convertType(remScalar.getType()) .cast() .getElementType(); Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType); Value other = convertScalarToDtype(b, loc, operands[1], newResultType); Value result; if (newResultType.isa()) { result = b.create(loc, self, other); } else if (newResultType.isa()) { result = b.create(loc, self, other); } else { remScalar.emitError( "Unsupported type encountered for AtenRemainderScalarOp."); } return result; } if (auto reciprocal = dyn_cast(op)) { Type dtype = converter->convertType(reciprocal.getType()) .cast() .getElementType(); Value arg = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Type elementType = arg.getType(); // assert(element != 0) auto zero = b.create(loc, FloatAttr::get(elementType, 0.0)); auto pred = b.create(loc, arith::CmpFPredicate::ONE, arg, zero); b.create( loc, pred, b.getStringAttr("unimplemented: tensor with zero element")); auto one = b.create(loc, FloatAttr::get(elementType, 1.0)); return b.create(loc, one, arg); } if (auto thresholdOp = dyn_cast(op)) { // The approach used here is as follows: // result = self <= threshold ? value : self AtenThresholdOp::Adaptor adaptor(operands); Type dtype = converter->convertType(thresholdOp.getType()) .cast() .getElementType(); Value self = payloadArgs[0]; Value threshold = convertScalarToDtype(b, loc, adaptor.getThreshold(), dtype); Value value = convertScalarToDtype(b, loc, adaptor.getValue(), dtype); Value predicate; if (dtype.isa()) predicate = b.create(loc, arith::CmpFPredicate::ULE, self, threshold); else predicate = b.create(loc, arith::CmpIPredicate::sle, self, threshold); return b.create(loc, predicate, value, self); } if (auto thresholdBackward = dyn_cast(op)) { // The approach used here is as follows: // result = self <= threshold ? 0 : grad AtenThresholdBackwardOp::Adaptor adaptor(operands); Type dtype = converter->convertType(thresholdBackward.getType()) .cast() .getElementType(); Value grad = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value self = convertScalarToDtype(b, loc, payloadArgs[1], dtype); Value threshold = convertScalarToDtype(b, loc, adaptor.getThreshold(), dtype); Value constantZero = b.create(loc, b.getZeroAttr(dtype)); Value predicate; if (dtype.isa()) predicate = b.create(loc, arith::CmpFPredicate::ULE, self, threshold); else predicate = b.create(loc, arith::CmpIPredicate::sle, self, threshold); return b.create(loc, predicate, constantZero, grad); } if (auto fillScalar = dyn_cast(op)) { AtenFillScalarOp::Adaptor adaptor(operands); Type dtype = converter->convertType(fillScalar.getType()) .cast() .getElementType(); return convertScalarToDtype(b, loc, adaptor.getValue(), dtype); } if (auto maskedFillTensor = dyn_cast(op)) { AtenMaskedFillScalarOp::Adaptor adaptor(operands); Type dtype = converter->convertType(maskedFillTensor.getType()) .cast() .getElementType(); Value input = payloadArgs[0]; Value mask = payloadArgs[1]; Value fillValue = convertScalarToDtype(b, loc, payloadArgs[2], dtype); return b.create(loc, mask, fillValue, input); } if (auto fillTensor = dyn_cast(op)) { AtenFillTensorOp::Adaptor adaptor(operands); Type dtype = converter->convertType(fillTensor.getType()) .cast() .getElementType(); return convertScalarToDtype(b, loc, payloadArgs[1], dtype); } if (auto triu = dyn_cast(op)) { // Check if the rank of the input tensor is valid. AtenTriuOp::Adaptor adaptor(operands); auto inputType = adaptor.getSelf().getType().cast(); uint64_t inputRank = inputType.getRank(); if (inputRank < 2) { triu.emitError("too few dimensions to compute triangular part of matrix"); return nullptr; } // Use the indices of the two innermost dimensions. auto rowIndex = b.create(loc, inputRank - 2); Value rowIndexI64 = castIndexToInt64(b, loc, rowIndex); auto colIndex = b.create(loc, inputRank - 1); Value colIndexI64 = castIndexToInt64(b, loc, colIndex); // columnIndex >= rowIndex + diagonal? auto sum = b.create(loc, rowIndexI64, adaptor.getDiagonal()); auto pred = b.create(loc, arith::CmpIPredicate::sge, colIndexI64, sum); Value scalar = payloadArgs[0]; Type elementType = inputType.getElementType(); Value zero = getConstant(b, loc, 0, elementType); return b.create(loc, pred, scalar, zero); } if (auto bitwiseNot = dyn_cast(op)) { Type elementType = converter->convertType(bitwiseNot.getType()) .cast() .getElementType(); if (elementType.isa()) { bitwiseNot.emitError("Bitwise_Not does not support floating point dtype"); return nullptr; } Value allOnesVal = b.create( loc, b.getIntegerAttr( elementType, APSInt::getAllOnes(elementType.getIntOrFloatBitWidth()))); return b.create(loc, payloadArgs[0], allOnesVal); } op->emitError("unimplemented lowering in " "createLinalgPayloadCalculationForElementwiseOp"); return nullptr; } namespace { // Converts an elementwise op. // This specifically includes: // - converting elementwise ops of any tensor arity // - converting elementwise ops with any number of scalar captures (such as a // scalar alpha to torch.aten.Add) // - broadcasting of static size-1 dimensions // // Currently, we adopt the behavior that "size 1" broadcasting is a runtime // error if it happens dynamically. // // Looking forward a bit, eventually, it probably makes sense to have // a "linalg.generic-like" op for modeling a fused subgraph of numpy-broadcasted // operands. Modeling elementwise ops that way is potentially useful to allow a // more centralized reasoning about multiversioning. However a cost model will // be needed for "pre-fusing" elementwise ops that way, as it can potentially be // a pessimization. A mild extension of this pattern should work for such a // general op. class ConvertElementwiseOp : public ConversionPattern { public: ConvertElementwiseOp(TypeConverter &typeConverter, MLIRContext *context) : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1, context) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { if (!isa(op)) return rewriter.notifyMatchFailure(op, "not a supported elementwise op"); if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op->getLoc(); auto tensorOperands = llvm::to_vector<6>(llvm::make_filter_range( operands, [](Value v) { return v.getType().isa(); })); auto resultType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); bool hadErrorCreatingPayload = false; Value generic = torch_to_linalg::createElementwiseLinalgGeneric( rewriter, loc, tensorOperands, resultType.getElementType(), [&](OpBuilder &b, Location loc, ValueRange payloadArgs) { Value result = createLinalgPayloadCalculationForElementwiseOp( b, loc, getTypeConverter(), payloadArgs, op, operands); if (!result) { hadErrorCreatingPayload = true; return; } b.create(loc, result); }); if (hadErrorCreatingPayload) return failure(); rewriter.replaceOpWithNewOp(op, resultType, generic); return success(); } }; } // namespace // Given `input`, `target`, `nll_loss_forward` is given by: // for i in range(0, len(target)): // indi = target[i]; // nll_loss_forward[i] = -(input[i][indi]); // TODO: `weight`operand is still to be taken care of. namespace { class ConvertAtenNllLossForwardOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenNllLossForwardOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op->getLoc(); Value input = adaptor.getSelf(); Value target = adaptor.getTarget(); Value weight = adaptor.getWeight(); int64_t reduction; if (!matchPattern(op.getReduction(), m_TorchConstantInt(&reduction))) return rewriter.notifyMatchFailure(op, "dim must be constant"); // TODO: Incorporate the weight argument. if (!weight.getType().isa()) return rewriter.notifyMatchFailure( op, "Unimplemented, the weight operand is not incorporated."); Value ignoreIndex = adaptor.getIgnoreIndex(); Value ignoreIndexVal = castIntToIndex(rewriter, loc, ignoreIndex); unsigned inputRank = input.getType().cast().getRank(); unsigned targetRank = target.getType().cast().getRank(); // TODO: Add support for k-dim loss. if (inputRank > 2) { return rewriter.notifyMatchFailure( op, "expected input and target to be rank <= 2"); } RankedTensorType resultType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); Type elementType = resultType.getElementType(); Value zeroVal = rewriter.create( loc, rewriter.getZeroAttr(elementType)); Value finalRes = torch_to_linalg::createElementwiseLinalgGeneric( rewriter, loc, {target}, elementType, [&](OpBuilder &b, Location loc, ValueRange args) { Value targetVal = args[0]; Value indTarget = rewriter.create( loc, rewriter.getIndexType(), targetVal); // The final result is given by: // final_res = (indTarget == ignoreIndexVal) ? 0 : // input[indI][IndTarget] Value cmpEq = rewriter.create( loc, arith::CmpIPredicate::eq, indTarget, ignoreIndexVal); SmallVector extractionIndices{indTarget}; if (inputRank == 2) { Value indI = rewriter.create(loc, 0); extractionIndices.insert(extractionIndices.begin(), indI); } Value result = rewriter.create(loc, input, extractionIndices); Value negate = rewriter.create(loc, elementType, result); Value selectFinal = rewriter.create(loc, cmpEq, zeroVal, negate); b.create(loc, selectFinal); }); if (reduction == torch_upstream::Reduction::Sum || reduction == torch_upstream::Reduction::Mean) { Value numOfElems = getTensorSize(rewriter, loc, finalRes); numOfElems = convertScalarToDtype(rewriter, loc, numOfElems, elementType); llvm::iota_range dimsToReduce(0, targetRank, /*inclusive=*/false); DenseSet dimSet(dimsToReduce.begin(), dimsToReduce.end()); auto opInfo = torch_to_linalg::ReductionOpInfo{false, finalRes, dimSet}; finalRes = torch_to_linalg::createReductionLinalgGeneric( rewriter, loc, opInfo, /*initElem=*/zeroVal, [&](OpBuilder &b, Location loc, ValueRange args) { Value newVal = args[0]; Value accumulator = args[1]; if (reduction == torch_upstream::Reduction::Mean) newVal = b.create(loc, newVal, numOfElems); Value result = b.create(loc, newVal, accumulator); b.create(loc, result); }); } // TODO: Update the second result tensor. Value weightUpdated = createZeroInitTensor(rewriter, loc, {}, elementType); rewriter.replaceOp(op, {finalRes, weightUpdated}); return success(); } }; } // namespace /// Inverted STD: rSTD = 1 / sqrt(var + eps). static Value calculateRSTD(OpBuilder &b, Location loc, Type elemTy, Value eps, Value var) { // The eps is always f64. Value truncatedEps = b.create(loc, elemTy, eps); Value varPlusEps = b.create(loc, var, truncatedEps); Value rSTD = b.create(loc, varPlusEps); return rSTD; } // Normalization formula: // ((input - mean) * rSTD * weight + bias static Value createLinalgPayloadCalculationForNormOpsWithRSTD( OpBuilder &b, Location loc, Type elemTy, Value input, Value mean, Value rSTD, Value eps, Value weight, Value bias) { Value inputSubMean = b.create(loc, input, mean); Value temp = b.create(loc, inputSubMean, rSTD); Value timesWeight = b.create(loc, temp, weight); Value plusBias = b.create(loc, timesWeight, bias); return plusBias; } static Value createLinalgPayloadCalculationForNormOpsWithVar( OpBuilder &b, Location loc, Type elemTy, Value input, Value mean, Value var, Value eps, Value weight, Value bias) { Value rSTD = calculateRSTD(b, loc, elemTy, eps, var); Value result = createLinalgPayloadCalculationForNormOpsWithRSTD( b, loc, elemTy, input, mean, rSTD, eps, weight, bias); return result; } namespace { class ConvertAtenBatchNormOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenBatchNormOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { MLIRContext *context = op->getContext(); Location loc = op->getLoc(); Value input = adaptor.getInput(); Value weight = adaptor.getWeight(); Value bias = adaptor.getBias(); Value runningMean = adaptor.getRunningMean(); Value runningVar = adaptor.getRunningVar(); Value training = adaptor.getTraining(); Value eps = adaptor.getEps(); if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); // TODO: Handle the None cases for the optional parameters: // weight, bias. if (failed(checkNotNone(rewriter, op, weight)) || failed(checkNotNone(rewriter, op, bias)) || failed(checkNotNone(rewriter, op, runningMean)) || failed(checkNotNone(rewriter, op, runningVar))) return failure(); auto inputType = input.getType().cast(); auto weightType = weight.getType().cast(); auto biasType = bias.getType().cast(); auto runningMeanType = runningMean.getType().cast(); auto runningVarType = runningVar.getType().cast(); auto inputRank = inputType.getRank(); if (inputRank < 2) return rewriter.notifyMatchFailure( op, "input should have rank larger than 1"); if (weightType.getRank() != 1 || biasType.getRank() != 1 || runningMeanType.getRank() != 1 || runningVarType.getRank() != 1) { return rewriter.notifyMatchFailure( op, "expect weight, bias, running_mean and running_var to be rank 1"); } // TODO: Add support for training. auto constFalse = rewriter.create( loc, IntegerAttr::get(IntegerType::get(context, 1), 0)); auto trainingFalse = rewriter.create( loc, arith::CmpIPredicate::eq, training, constFalse); rewriter.create( loc, trainingFalse, rewriter.getStringAttr("training is not supported for now")); // num_features – C from an expected input of size (N,C,D,H,W ...) Value numFeatures = rewriter.create(loc, input, 1); auto contractingDim0EqualsNumFeatures = [&](Value v) { auto dim0 = rewriter.create(loc, v, 0); auto dim0Equal = rewriter.create( loc, arith::CmpIPredicate::eq, numFeatures, dim0); rewriter.create( loc, dim0Equal, rewriter.getStringAttr( "expect the size of dim 0 equal to the number of features")); }; contractingDim0EqualsNumFeatures(weight); contractingDim0EqualsNumFeatures(bias); contractingDim0EqualsNumFeatures(runningMean); contractingDim0EqualsNumFeatures(runningVar); auto indexingMap = AffineMap::get( /*dimCount=*/inputRank, /*symbolCount=*/0, rewriter.getAffineDimExpr(1), context); SmallVector indexingMaps = { rewriter.getMultiDimIdentityMap(inputRank), // input indexingMap, // weight indexingMap, // bias indexingMap, // runningMean indexingMap, // runningVar rewriter.getMultiDimIdentityMap(inputRank), // output }; SmallVector iteratorTypes( inputRank, utils::IteratorType::parallel); Value batchNorm = rewriter .create( loc, input.getType(), ValueRange{input, weight, bias, runningMean, runningVar}, input, /*indexingMaps=*/indexingMaps, /*iteratorTypes=*/iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { Value input = args[0], weight = args[1], bias = args[2], mean = args[3], var = args[4]; Value result = createLinalgPayloadCalculationForNormOpsWithVar( b, loc, var.getType(), input, mean, var, eps, weight, bias); b.create(loc, result); }) .getResult(0); Type newResultType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, newResultType, batchNorm); return success(); } }; } // namespace namespace { class ConvertAtenNllLossBackwardOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenNllLossBackwardOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op->getLoc(); Value gradOutput = adaptor.getGradOutput(); Value input = adaptor.getSelf(); Value target = adaptor.getTarget(); Value weight = adaptor.getWeight(); bool weightIsNone = op.getWeight().getType().isa(); Value ignoreIndex = castIntToIndex(rewriter, loc, adaptor.getIgnoreIndex()); Value totalWeight = adaptor.getTotalWeight(); auto inputType = input.getType().cast(); int inputRank = inputType.getRank(); auto gradOutputType = gradOutput.getType().cast(); Type resultElementType = gradOutputType.getElementType(); int64_t reduction; if (!matchPattern(op.getReduction(), m_TorchConstantInt(&reduction))) return rewriter.notifyMatchFailure(op, "dim must be constant"); if (!hasElementType(gradOutput) || !hasElementType(gradOutput) || (!weightIsNone && !hasElementType(weight))) { return rewriter.notifyMatchFailure( op, "`gradOutput`, 'weight', and `totalWeight` must be tensors of " "type float"); } if (!hasElementType(target)) { return rewriter.notifyMatchFailure( op, "`target` must be a tensor of integer type"); } auto outputSize = getTensorSizes(rewriter, loc, input); Value gradInputTensor = createZeroInitTensor(rewriter, loc, outputSize, resultElementType); auto getAffineMapForSingleElementTensor = [&](Value tensor) { auto tensorType = tensor.getType().cast(); SmallVector affineExprs(tensorType.getRank(), rewriter.getAffineConstantExpr(0)); return AffineMap::get(inputRank, /*symbolCount=*/0, affineExprs, op->getContext()); }; AffineMap gradOutMap = AffineMap::get(inputRank, /*symbolCount=*/0, rewriter.getAffineDimExpr(0)); if (reduction != torch_upstream::Reduction::None || inputRank == 1) gradOutMap = getAffineMapForSingleElementTensor(gradOutput); AffineMap targetMap = AffineMap::get(inputRank, /*symbolCount=*/0, rewriter.getAffineDimExpr(0)); if (inputRank == 1) targetMap = getAffineMapForSingleElementTensor(target); AffineMap totalWeightMap = getAffineMapForSingleElementTensor(totalWeight); AffineMap resultMap = rewriter.getMultiDimIdentityMap(inputRank); SmallVector indexingMaps{gradOutMap, targetMap, totalWeightMap, resultMap}; SmallVector iteratorTypes( inputRank, utils::IteratorType::parallel); // The code generation is equivalent to the following pseudo-code: // // for batch_index in len(input.size(0)): // for class_index in len(input.size(1)): // target_elem = target[batch_index] // // if reduction == None: // grad_out_elem = grad_output[batchIndex] // else: // grad_out_elem = grad_output[0] // // if reduction == Mean: // total_weight_elem = total_weight[0] // grad_out_elem /= total_weight_elem // // weight_elem = weight[target_elem] if weight != None else 1 // // if target_elem != class_index or target_elem == ignore_index: // grad_input_elem = -weight_elem * grad_out_elem // else: // grad_input_elem = 0 // grad_input[batch_index, target_elem] = grad_input_elem // // NOTE: In the case of not batch dimension, `batch_index` essentially // becomes zero. Value gradInput = rewriter .create( loc, gradInputTensor.getType(), ValueRange{gradOutput, target, totalWeight}, gradInputTensor, indexingMaps, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { Value gradOutElem = args[0]; Value targetElem = castIntToIndex(b, loc, args[1]); Value totalWeightElem = args[2]; Value classIndex = b.create(loc, inputRank - 1); if (reduction == torch_upstream::Reduction::Mean) { gradOutElem = b.create(loc, gradOutElem, totalWeightElem); } Value negGradOutElem = b.create(loc, gradOutElem); Value weightElem = getConstant(b, loc, 1, resultElementType); if (!weightIsNone) { weightElem = b.create(loc, weight, targetElem); } Value weightedNegGradOutElem = b.create(loc, weightElem, negGradOutElem); Value targetNeqClassIndex = b.create( loc, arith::CmpIPredicate::ne, targetElem, classIndex); Value targetEqIgnoreIndex = b.create( loc, arith::CmpIPredicate::eq, targetElem, ignoreIndex); Value gradInputIsZero = b.create( loc, targetNeqClassIndex, targetEqIgnoreIndex); Value zero = getConstant(b, loc, 0, resultElementType); Value gradInElem = b.create( loc, gradInputIsZero, zero, weightedNegGradOutElem); b.create(loc, gradInElem); }) ->getResult(0); RankedTensorType resultType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); rewriter.replaceOpWithNewOp(op, resultType, gradInput); return success(); } }; } // namespace namespace { class ConvertAtenDetachOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenDetachOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Type resultType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, resultType, adaptor.getSelf()); return success(); } }; } // namespace namespace { class ConvertTensorStaticInfoCastOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(TensorStaticInfoCastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { RankedTensorType resultType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); rewriter.replaceOpWithNewOp(op, resultType, adaptor.getOperand()); return success(); } }; } // namespace void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { MLIRContext *context = patterns.getContext(); target.addIllegalOp< AtenTanhOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, AtenAtan2Op, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPreluOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, AtenRealOp, AtenImagOp>(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); patterns.add(typeConverter, context); target.addIllegalOp(); }