//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "../PassDetail.h" #include "PopulatePatterns.h" #include "Utils.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Matchers.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; // Check if a ranked-tensor has the specified element type. 
template static bool hasElementType(Value tensor) { auto tensorType = tensor.getType().cast(); Type tensorElementType = tensorType.getElementType(); return tensorElementType.isa(); } template static Value createComparisonTemplate(OpBuilder &b, Location loc, Type type, Value lhs, Value rhs) { if (type.isa()) return b.create(loc, fpred, lhs, rhs); if (IntegerType intType = type.dyn_cast()) { if (intType.isUnsigned()) return b.create(loc, iupred, lhs, rhs); if (intType.isSigned()) return b.create(loc, ispred, lhs, rhs); } llvm_unreachable("Unhandled element type for comparison"); } static Value createGreaterThan(OpBuilder &b, Location loc, Type elementalType, Value lhs, Value rhs) { return createComparisonTemplate( b, loc, elementalType, lhs, rhs); } static Value createLessThan(OpBuilder &b, Location loc, Type elementalType, Value lhs, Value rhs) { return createComparisonTemplate( b, loc, elementalType, lhs, rhs); } static Value createEqual(OpBuilder &b, Location loc, Type elementalType, Value lhs, Value rhs) { return createComparisonTemplate( b, loc, elementalType, lhs, rhs); } static Value createNotEqual(OpBuilder &b, Location loc, Type elementalType, Value lhs, Value rhs) { return createComparisonTemplate( b, loc, elementalType, lhs, rhs); } static Value buildNormalCdf(OpBuilder &b, Location &loc, Value x, Value mean, Value sigma) { Type elementType = x.getType(); Value xMinusMean = b.create(loc, x, mean); Value two = b.create(loc, FloatAttr::get(elementType, 2)); Value sqrt2 = b.create(loc, two); Value erfArg = b.create(loc, xMinusMean, sqrt2); Value erf = b.create(loc, erfArg); Value one = b.create(loc, FloatAttr::get(elementType, 1)); Value erfPlus1 = b.create(loc, one, erf); Value oneHalf = b.create(loc, FloatAttr::get(elementType, 0.5)); Value normalCdf = b.create(loc, oneHalf, erfPlus1); return normalCdf; } static Value buildUnitNormalCdf(OpBuilder &b, Location &loc, Value x) { Type elementType = x.getType(); Value zero = b.create(loc, 
FloatAttr::get(elementType, 0)); Value one = b.create(loc, FloatAttr::get(elementType, 1)); return buildNormalCdf(b, loc, x, zero, one); } template static Value createCalculationForMathOpWithDtypeConversion( OpBuilder &b, TypeConverter *converter, Value payloadArg, Operation *op) { Type dtype = converter->convertType(op->getResult(0).getType()) .template cast() .getElementType(); Location loc = op->getLoc(); Value arg = convertScalarToDtype(b, loc, payloadArg, dtype); return b.create(loc, arg); } static Value createLinalgPayloadCalculationForElementwiseOp( OpBuilder &b, Location loc, TypeConverter *converter, ValueRange payloadArgs, Operation *op, ArrayRef operands) { if (isa(op)) return b.create(loc, payloadArgs[0]); if (isa(op)) return b.create(loc, payloadArgs[0]); if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } if (isa(op)) { return 
createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } if (auto clone = dyn_cast(op)) { int64_t memoryFormat; if (!clone.memory_format().getType().isa() && (!matchPattern(clone.memory_format(), m_TorchConstantInt(&memoryFormat)) || memoryFormat != torch_upstream::MemoryFormat::Contiguous)) { clone.emitError("unimplemented: only default memory format is supported"); return nullptr; } return payloadArgs[0]; } if (auto bitwiseAndTensor = dyn_cast(op)) { if (bitwiseAndTensor.getType() .cast() .getDtype() .isa()) { bitwiseAndTensor.emitError( "Bitwise_And does not support floating point dtype"); return nullptr; } Type dtype = converter->convertType(bitwiseAndTensor.getType()) .cast() .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); return b.create(loc, lhs, rhs); } if (auto logicalOr = dyn_cast(op)) { MLIRContext *context = op->getContext(); Type floatDtype = mlir::FloatType::getF64(context); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], floatDtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], floatDtype); Value zero = b.create(loc, b.getFloatAttr(floatDtype, 0)); Value lhsTest = createNotEqual(b, loc, floatDtype, lhs, zero); Value rhsTest = createNotEqual(b, loc, floatDtype, rhs, zero); return b.create(loc, lhsTest, rhsTest); } if (isa(op)) return b.create(loc, payloadArgs[0]); if (isa(op)) { auto negate = createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); auto one = b.create(loc, FloatAttr::get(negate.getType(), 1)); auto exp = b.create(loc, negate); auto added = b.create(loc, exp, one); return b.create(loc, one, added); } if (auto relu = dyn_cast(op)) { if (!relu.getType() .cast() .getDtype() .isa()) { relu.emitError("unimplemented: non-floating point dtype"); return nullptr; } Type elementType = payloadArgs[0].getType(); Value constZero = b.create(loc, b.getZeroAttr(elementType)); 
Value pred = b.create(loc, arith::CmpFPredicate::UGT, payloadArgs[0], constZero); return b.create(loc, pred, payloadArgs[0], constZero); } if (auto lrelu = dyn_cast(op)) { if (!lrelu.getType() .cast() .getDtype() .isa()) { lrelu.emitError("unimplemented: non-floating point dtype"); return nullptr; } Type elementType = payloadArgs[0].getType(); Value constZero = b.create(loc, b.getZeroAttr(elementType)); Value pred = b.create(loc, arith::CmpFPredicate::UGT, payloadArgs[0], constZero); Value positivePart = b.create(loc, pred, payloadArgs[0], constZero); Value negativePart = b.create(loc, pred, constZero, payloadArgs[0]); Value scale = convertScalarToDtype(b, loc, operands[1], elementType); Value scaledNegativePart = b.create(loc, negativePart, scale); return b.create(loc, positivePart, scaledNegativePart); } if (auto gelu = dyn_cast(op)) { if (!gelu.getType() .cast() .getDtype() .isa()) { gelu.emitError("unimplemented: non-floating point dtype"); return nullptr; } // TODO: Take approximation into account. std::string approximate; if (!matchPattern(gelu.approximate(), m_TorchConstantStr(approximate)) || approximate != "none") return nullptr; Value cdf = buildUnitNormalCdf(b, loc, payloadArgs[0]); return b.create(loc, payloadArgs[0], cdf); } if (auto geluBackward = dyn_cast(op)) { if (!geluBackward.getType() .cast() .getDtype() .isa()) { geluBackward.emitError("unimplemented: non-floating point dtype"); return nullptr; } // TODO: Take approximation into account. 
std::string approximate; if (!matchPattern(geluBackward.approximate(), m_TorchConstantStr(approximate)) || approximate != "none") return nullptr; Type elementType = payloadArgs[1].getType(); Value cstAlpha0 = b.create( loc, FloatAttr::get(elementType, 1.12837916709551257390)); Value cstAlpha1 = b.create( loc, FloatAttr::get(elementType, 0.70710678118654752440)); Value oneHalf = b.create(loc, FloatAttr::get(elementType, 0.5)); Value kAlpha = b.create(loc, cstAlpha0, cstAlpha1); Value kAlphaHalf = b.create(loc, kAlpha, oneHalf); Value negOneHalf = b.create(loc, FloatAttr::get(elementType, -0.5)); Value inputSquared = b.create(loc, payloadArgs[1], payloadArgs[1]); Value negHalfInputSquared = b.create(loc, inputSquared, negOneHalf); Value dinput = b.create(loc, negHalfInputSquared); Value cdf = buildUnitNormalCdf(b, loc, payloadArgs[1]); Value dinputInput = b.create(loc, dinput, payloadArgs[1]); Value dinputInputAlpha = b.create(loc, dinputInput, kAlphaHalf); Value cdfExt = b.create(loc, dinputInputAlpha, cdf); return b.create(loc, payloadArgs[0], cdfExt); } if (auto add = dyn_cast(op)) { AtenAddTensorOp::Adaptor adaptor(operands); Type dtype = converter->convertType(add.getType()) .cast() .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); Value alpha = convertScalarToDtype(b, loc, adaptor.alpha(), dtype); if (dtype.isa()) { Value scaled = b.create(loc, rhs, alpha); return b.create(loc, lhs, scaled); } else { Value scaled = b.create(loc, rhs, alpha); return b.create(loc, lhs, scaled); } } if (auto sub = dyn_cast(op)) { AtenSubTensorOp::Adaptor adaptor(operands); Type dtype = converter->convertType(sub.getType()) .cast() .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); Value alpha = convertScalarToDtype(b, loc, adaptor.alpha(), dtype); if (dtype.isa()) { Value scaled = 
b.create(loc, rhs, alpha); return b.create(loc, lhs, scaled); } else { Value scaled = b.create(loc, rhs, alpha); return b.create(loc, lhs, scaled); } } if (auto subScalar = dyn_cast(op)) { Type dtype = converter->convertType(subScalar.getType()) .cast() .getElementType(); Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value other = convertScalarToDtype(b, loc, operands[1], dtype); Value alpha = convertScalarToDtype(b, loc, operands[2], dtype); if (dtype.isa()) { Value mult = b.create(loc, other, alpha); return b.create(loc, self, mult); } else if (dtype.isa()) { Value mult = b.create(loc, other, alpha); return b.create(loc, self, mult); } subScalar.emitError("unimplemented: dtype other than float and integer " "types are not supported."); return nullptr; } if (auto addScalar = dyn_cast(op)) { Type dtype = converter->convertType(addScalar.getType()) .cast() .getElementType(); Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value other = convertScalarToDtype(b, loc, operands[1], dtype); Value alpha = convertScalarToDtype(b, loc, operands[2], dtype); if (dtype.isa()) { Value mult = b.create(loc, other, alpha); return b.create(loc, self, mult); } else if (dtype.isa()) { Value mult = b.create(loc, other, alpha); return b.create(loc, self, mult); } addScalar.emitError("unimplemented: dtype other than float and integer " "types are not supported."); return nullptr; } if (auto mul = dyn_cast(op)) { AtenMulTensorOp::Adaptor adaptor(operands); Type dtype = converter->convertType(mul.getType()) .cast() .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); if (dtype.isa()) { return b.create(loc, lhs, rhs); } else { return b.create(loc, lhs, rhs); } } if (auto atan2 = dyn_cast(op)) { Type dtype = converter->convertType(atan2.getType()) .cast() .getElementType(); if (!dtype.isa()) { atan2.emitError("Atan2 requires floating point result type"); 
return nullptr; } Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); return b.create(loc, lhs, rhs); } if (auto gtTensor = dyn_cast(op)) { AtenGtTensorOp::Adaptor adaptor(operands); Type lhsDtype = payloadArgs[0].getType(); Type rhsDtype = payloadArgs[1].getType(); // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs // to be handled. if (lhsDtype != rhsDtype) { gtTensor.emitError("unimplemented: different lhs and rhs dtype"); return nullptr; } Type elementalType = gtTensor.self().getType().cast().getDtype(); return createGreaterThan(b, loc, elementalType, payloadArgs[0], payloadArgs[1]); } if (auto eqTensor = dyn_cast(op)) { AtenEqTensorOp::Adaptor adaptor(operands); Type lhsDtype = payloadArgs[0].getType(); Type rhsDtype = payloadArgs[1].getType(); // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs // to be handled. if (lhsDtype != rhsDtype) { eqTensor.emitError("unimplemented: lhs and rhs dtype must be same"); return nullptr; } Type elementalType = eqTensor.self().getType().cast().getDtype(); if (elementalType.isa()) return b.create(loc, arith::CmpFPredicate::UEQ, payloadArgs[0], payloadArgs[1]); if (elementalType.isa()) { return b.create(loc, arith::CmpIPredicate::eq, payloadArgs[0], payloadArgs[1]); } eqTensor.emitError("unimplemented: dtype isn't supported."); return nullptr; } if (auto ltTensor = dyn_cast(op)) { AtenLtTensorOp::Adaptor adaptor(operands); Type lhsDtype = payloadArgs[0].getType(); Type rhsDtype = payloadArgs[1].getType(); // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs // to be handled. 
if (lhsDtype != rhsDtype) { ltTensor.emitError("unimplemented: lhs and rhs dtype must be same"); return nullptr; } Type elementalType = ltTensor.self().getType().cast().getDtype(); return createLessThan(b, loc, elementalType, payloadArgs[0], payloadArgs[1]); } if (auto div = dyn_cast(op)) { AtenDivTensorOp::Adaptor adaptor(operands); Type dtype = converter->convertType(div.getType()) .cast() .getElementType(); if (!dtype.isa()) { div.emitError("unimplemented: non-floating point dtype"); return nullptr; } Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); return b.create(loc, lhs, rhs); } if (auto divTensorMode = dyn_cast(op)) { AtenDivTensorModeOp::Adaptor adaptor(operands); Type dtype = converter->convertType(divTensorMode.getType()) .cast() .getElementType(); if (!dtype.isa()) { divTensorMode.emitError("unimplemented: non-floating point dtype"); return nullptr; } Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); Value div = b.create(loc, lhs, rhs); if (divTensorMode.rounding_mode().getType().isa()) return div; std::string roundingMode; if (!matchPattern(divTensorMode.rounding_mode(), m_TorchConstantStr(roundingMode))) { divTensorMode.emitError("only support constant str rounding mode"); return nullptr; } if (roundingMode == "trunc") { // "trunc" - rounds the results of the division towards zero. Equivalent // to C-style integer division. Value ceil = b.create(loc, div); Value floor = b.create(loc, div); Value cstZero = b.create(loc, b.getZeroAttr(dtype)); Value pred = b.create(loc, arith::CmpFPredicate::ULT, div, cstZero); return b.create(loc, pred, ceil, floor); } if (roundingMode == "floor") { // "floor" - rounds the results of the division down. 
Equivalent to // floor division in Python (the // operator) return b.create(loc, div); } divTensorMode.emitError("invalid rounding mode"); return nullptr; } if (auto pow = dyn_cast(op)) { if (!pow.getType() .cast() .getDtype() .isa()) { pow.emitError("unimplemented: non-floating point dtype"); return nullptr; } Type dtype = pow.self().getType().cast().getDtype(); Value expPromoted = convertScalarToDtype(b, loc, operands[1], dtype); return b.create(loc, payloadArgs[0], expPromoted); } if (auto gtScalar = dyn_cast(op)) { Type dtype = gtScalar.self().getType().cast().getDtype(); // TODO: `gtTensor` and `gtScalar` share similar code and can be called from // one static function. Value otherPromoted = convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); if (dtype.isa()) return b.create(loc, arith::CmpFPredicate::UGT, payloadArgs[0], otherPromoted); if (IntegerType intType = dtype.dyn_cast()) { if (!operands[1].getType().isa()) { // TODO: Promote tensor args from integer to float. gtScalar.emitError( "unimplemented: type promotion from tensor to scalar."); return nullptr; } if (intType.isUnsigned()) return b.create(loc, arith::CmpIPredicate::ugt, payloadArgs[0], otherPromoted); if (intType.isSigned()) return b.create(loc, arith::CmpIPredicate::sgt, payloadArgs[0], otherPromoted); } gtScalar.emitError("unimplemented: dtype isn't supported."); return nullptr; } if (auto geScalar = dyn_cast(op)) { Type dtype = geScalar.self().getType().cast().getDtype(); // TODO: The `AtenGeScalarOp` and `AtenGtScalarOp` share a lot of code that // can be refactored. Value otherPromoted = convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); if (dtype.isa()) return b.create(loc, arith::CmpFPredicate::UGE, payloadArgs[0], otherPromoted); if (IntegerType intType = dtype.dyn_cast()) { if (!operands[1].getType().isa()) { // TODO: Promote tensor args from integer to float. 
geScalar.emitError( "unimplemented: type promotion from tensor to scalar."); return nullptr; } if (intType.isUnsigned()) return b.create(loc, arith::CmpIPredicate::uge, payloadArgs[0], otherPromoted); if (intType.isSigned()) return b.create(loc, arith::CmpIPredicate::sge, payloadArgs[0], otherPromoted); } geScalar.emitError("unimplemented: dtype isn't supported."); return nullptr; } if (auto eqScalar = dyn_cast(op)) { Type dtype = eqScalar.self().getType().cast().getDtype(); Value otherPromoted = convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); if (dtype.isa()) { if (!operands[1].getType().isa()) { // TODO: Promote tensor operand from integer to float. eqScalar.emitError( "unimplemented: type promotion from tensor to scalar"); return nullptr; } } return createEqual(b, loc, dtype, payloadArgs[0], otherPromoted); } if (auto neScalar = dyn_cast(op)) { Type dtype = neScalar.self().getType().cast().getDtype(); Value otherPromoted = convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); if (dtype.isa()) { if (!operands[1].getType().isa()) { // TODO: Promote tensor operand from integer to float. neScalar.emitError( "unimplemented: type promotion from tensor to scalar"); return nullptr; } } return createNotEqual(b, loc, dtype, payloadArgs[0], otherPromoted); } if (auto ltScalar = dyn_cast(op)) { Type dtype = ltScalar.self().getType().cast().getDtype(); Value otherPromoted = convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); // TODO: Both tensor and scalar variants of `aten.gt` and `aten.lt` share // a lot of code that can be refactored. if (dtype.isa()) return b.create(loc, arith::CmpFPredicate::ULT, payloadArgs[0], otherPromoted); if (IntegerType intType = dtype.dyn_cast()) { if (!operands[1].getType().isa()) { // TODO: Promote tensor operand from integer to float. 
ltScalar.emitError( "unimplemented: type promotion from tensor to scalar"); return nullptr; } if (intType.isUnsigned()) return b.create(loc, arith::CmpIPredicate::ult, payloadArgs[0], otherPromoted); if (intType.isSigned()) return b.create(loc, arith::CmpIPredicate::slt, payloadArgs[0], otherPromoted); } ltScalar.emitError("unimplemented: dtype isn't supported."); return nullptr; } if (auto leScalar = dyn_cast(op)) { Type dtype = leScalar.self().getType().cast().getDtype(); Value otherPromoted = convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); // TODO: The `AtenLeScalarOp` and `AtenLtScalarOp` share a lot of code // that can be refactored. if (dtype.isa()) return b.create(loc, arith::CmpFPredicate::ULE, payloadArgs[0], otherPromoted); if (IntegerType intType = dtype.dyn_cast()) { if (!operands[1].getType().isa()) { // TODO: Promote tensor operand from integer to float. leScalar.emitError( "unimplemented: type promotion from tensor to scalar"); return nullptr; } if (intType.isUnsigned()) return b.create(loc, arith::CmpIPredicate::ule, payloadArgs[0], otherPromoted); if (intType.isSigned()) return b.create(loc, arith::CmpIPredicate::sle, payloadArgs[0], otherPromoted); } leScalar.emitError("unimplemented: dtype isn't supported."); return nullptr; } if (auto whereSelf = dyn_cast(op)) { Type dtype = converter->convertType(whereSelf.getType()) .cast() .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[2], dtype); return b.create(loc, payloadArgs[0], lhs, rhs); } if (auto lerp = dyn_cast(op)) { if (!lerp.getType() .cast() .getDtype() .isa()) { lerp.emitError("unimplemented: non-floating point dtype"); return nullptr; } AtenLerpTensorOp::Adaptor adaptor(payloadArgs); auto start = adaptor.self(); auto end = adaptor.end(); auto weight = adaptor.weight(); auto delta = b.create(loc, end, start); auto weightedDelta = b.create(loc, delta, weight); return b.create(loc, 
start, weightedDelta); } if (auto minimum = dyn_cast(op)) { Type dtype = minimum.getType().cast().getDtype(); Type elemTy = converter->convertType(minimum.getType()) .cast() .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], elemTy); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], elemTy); Value pred = createLessThan(b, loc, dtype, lhs, rhs); return b.create(loc, pred, lhs, rhs); } if (auto maximum = dyn_cast(op)) { Type dtype = maximum.getType().cast().getDtype(); Type elemTy = converter->convertType(maximum.getType()) .cast() .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], elemTy); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], elemTy); Value pred = createGreaterThan(b, loc, dtype, lhs, rhs); return b.create(loc, pred, lhs, rhs); } if (auto clamp = dyn_cast(op)) { Type dtype = converter->convertType(clamp.getType()) .cast() .getElementType(); if (!dtype.isa()) { clamp.emitError("unimplemented: non-floating point dtype"); return nullptr; } AtenClampOp::Adaptor adaptor(operands); auto min = adaptor.min(); auto max = adaptor.max(); if (min.getType().isa() || max.getType().isa()) { clamp.emitError("unimplemented: runtime optional type"); return nullptr; } auto result = payloadArgs[0]; if (!min.getType().isa()) { auto minPromoted = convertScalarToDtype(b, loc, min, dtype); auto pred = b.create(loc, arith::CmpFPredicate::ULT, result, minPromoted); result = b.create(loc, pred, minPromoted, result); } if (!max.getType().isa()) { auto maxPromoted = convertScalarToDtype(b, loc, max, dtype); auto pred = b.create(loc, arith::CmpFPredicate::UGT, result, maxPromoted); result = b.create(loc, pred, maxPromoted, result); } return result; } if (auto rsub = dyn_cast(op)) { Type dtype = converter->convertType(rsub.getType()) .cast() .getElementType(); Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value other = convertScalarToDtype(b, loc, operands[1], dtype); Value alpha = 
convertScalarToDtype(b, loc, operands[2], dtype); if (dtype.isa()) { Value mult = b.create(loc, self, alpha); return b.create(loc, other, mult); } else if (dtype.isa()) { Value mult = b.create(loc, self, alpha); return b.create(loc, other, mult); } rsub.emitError("unimplemented: dtype other than float and integer " "types are not supported."); return nullptr; } if (auto mulScalar = dyn_cast(op)) { Type dtype = converter->convertType(mulScalar.getType()) .cast() .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, operands[1], dtype); if (dtype.isa()) return b.create(loc, lhs, rhs); if (dtype.isa()) return b.create(loc, lhs, rhs); mulScalar.emitError("unimplemented: Only integer/float dtype supported"); return nullptr; } if (auto atenToDtype = dyn_cast(op)) { Value input = payloadArgs[0]; Type dtype = converter->convertType(atenToDtype.getType()) .cast() .getElementType(); Value result = convertScalarToDtype(b, loc, input, dtype); return result; } if (auto divScalar = dyn_cast(op)) { Type dtype = converter->convertType(divScalar.getType()) .cast() .getElementType(); if (!dtype.isa()) { divScalar.emitError("unimplemented: non-floating point dtype"); return nullptr; } Value self = payloadArgs[0]; Value other = convertScalarToDtype(b, loc, operands[1], dtype); return b.create(loc, self, other); } if (auto remScalar = dyn_cast(op)) { Type newResultType = converter->convertType(remScalar.getType()) .cast() .getElementType(); Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType); Value other = convertScalarToDtype(b, loc, operands[1], newResultType); Value result; if (newResultType.isa()) { result = b.create(loc, self, other); } else { result = b.create(loc, self, other); } return result; } if (auto reciprocal = dyn_cast(op)) { Type dtype = converter->convertType(reciprocal.getType()) .cast() .getElementType(); Value arg = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Type 
elementType = arg.getType(); // assert(element != 0) auto zero = b.create(loc, FloatAttr::get(elementType, 0.0)); auto pred = b.create(loc, arith::CmpFPredicate::ONE, arg, zero); b.create( loc, pred, b.getStringAttr("unimplemented: tensor with zero element")); auto one = b.create(loc, FloatAttr::get(elementType, 1.0)); return b.create(loc, one, arg); } if (auto thresholdOp = dyn_cast(op)) { // The approach used here is as follows: // result = self <= threshold ? value : self AtenThresholdOp::Adaptor adaptor(operands); Type dtype = converter->convertType(thresholdOp.getType()) .cast() .getElementType(); Value self = payloadArgs[0]; Value threshold = convertScalarToDtype(b, loc, adaptor.threshold(), dtype); Value value = convertScalarToDtype(b, loc, adaptor.value(), dtype); Value predicate; if (dtype.isa()) predicate = b.create(loc, arith::CmpFPredicate::ULE, self, threshold); else predicate = b.create(loc, arith::CmpIPredicate::sle, self, threshold); return b.create(loc, predicate, value, self); } if (auto thresholdBackward = dyn_cast(op)) { // The approach used here is as follows: // result = self <= threshold ? 
0 : grad AtenThresholdBackwardOp::Adaptor adaptor(operands); Type dtype = converter->convertType(thresholdBackward.getType()) .cast() .getElementType(); Value grad = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value self = convertScalarToDtype(b, loc, payloadArgs[1], dtype); Value threshold = convertScalarToDtype(b, loc, adaptor.threshold(), dtype); Value constantZero = b.create(loc, b.getZeroAttr(dtype)); Value predicate; if (dtype.isa()) predicate = b.create(loc, arith::CmpFPredicate::ULE, self, threshold); else predicate = b.create(loc, arith::CmpIPredicate::sle, self, threshold); return b.create(loc, predicate, constantZero, grad); } if (auto maskedFill = dyn_cast(op)) { AtenMaskedFillScalarOp::Adaptor adaptor(operands); Type dtype = converter->convertType(maskedFill.getType()) .cast() .getElementType(); Value input = payloadArgs[0]; Value mask = payloadArgs[1]; Value fillValue = convertScalarToDtype(b, loc, adaptor.value(), dtype); return b.create(loc, mask, fillValue, input); } if (auto triu = dyn_cast(op)) { // Check if the rank of the input tensor is valid. AtenTriuOp::Adaptor adaptor(operands); auto inputType = adaptor.self().getType().cast(); uint64_t inputRank = inputType.getRank(); if (inputRank < 2) { triu.emitError("too few dimensions to compute triangular part of matrix"); return nullptr; } // Use the indices of the two innermost dimensions. auto rowIndex = b.create(loc, inputRank - 2); Value rowIndexI64 = castIndexToInt64(b, loc, rowIndex); auto colIndex = b.create(loc, inputRank - 1); Value colIndexI64 = castIndexToInt64(b, loc, colIndex); // columnIndex >= rowIndex + diagonal? 
auto sum = b.create(loc, rowIndexI64, adaptor.diagonal()); auto pred = b.create(loc, arith::CmpIPredicate::sge, colIndexI64, sum); Value scalar = payloadArgs[0]; Type elementType = inputType.getElementType(); Value zero = getConstant(b, loc, 0, elementType); return b.create(loc, pred, scalar, zero); } op->emitError("unimplemented lowering in " "createLinalgPayloadCalculationForElementwiseOp"); return nullptr; } namespace { // Converts an elementwise op. // This specifically includes: // - converting elementwise ops of any tensor arity // - converting elementwise ops with any number of scalar captures (such as a // scalar alpha to torch.aten.Add) // - broadcasting of static size-1 dimensions // // Currently, we adopt the behavior that "size 1" broadcasting is a runtime // error if it happens dynamically. // // Looking forward a bit, eventually, it probably makes sense to have // a "linalg.generic-like" op for modeling a fused subgraph of numpy-broadcasted // operands. Modeling elementwise ops that way is potentially useful to allow a // more centralized reasoning about multiversioning. However a cost model will // be needed for "pre-fusing" elementwise ops that way, as it can potentially be // a pessimization. A mild extension of this pattern should work for such a // general op. 
/// Conversion pattern that lowers any supported elementwise Torch op to a
/// `linalg.generic` whose body is produced by
/// `createLinalgPayloadCalculationForElementwiseOp`, followed by a
/// `tensor.cast` to the converted result type.
class ConvertElementwiseOp : public ConversionPattern {
public:
  ConvertElementwiseOp(TypeConverter &typeConverter, MLIRContext *context)
      : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
                          context) {}

  /// Matches the ops listed below; fails (without touching the IR visible to
  /// the caller) for anything else or when the payload cannot be built.
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // This list must stay in sync with the ops handled by
    // createLinalgPayloadCalculationForElementwiseOp.
    if (!isa<AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp,
             AtenGeluBackwardOp, AtenAddTensorOp, AtenMulTensorOp,
             AtenDivTensorOp, AtenDivTensorModeOp, AtenSubTensorOp,
             AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenExpm1Op,
             AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
             AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp, AtenErfOp,
             AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op,
             AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp, AtenRemainderScalarOp,
             AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
             AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
             AtenLeScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp,
             AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp,
             AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
             AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
             AtenLogicalOrOp, AtenTriuOp, AtenAtan2Op>(op))
      return rewriter.notifyMatchFailure(op, "not a supported elementwise op");

    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();

    Location loc = op->getLoc();
    // Only ranked-tensor operands become inputs of the generic; scalar
    // operands are captured by the payload builder through `operands`.
    auto tensorOperands = llvm::to_vector<6>(llvm::make_filter_range(
        operands, [](Value v) { return v.getType().isa<RankedTensorType>(); }));
    auto resultType = getTypeConverter()
                          ->convertType(op->getResult(0).getType())
                          .cast<RankedTensorType>();
    // The body builder cannot return a status, so failures are recorded here.
    bool hadErrorCreatingPayload = false;
    Value generic = torch_to_linalg::createElementwiseLinalgGeneric(
        rewriter, loc, tensorOperands, resultType.getElementType(),
        [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
          Value result = createLinalgPayloadCalculationForElementwiseOp(
              b, loc, getTypeConverter(), payloadArgs, op, operands);
          if (!result) {
            hadErrorCreatingPayload = true;
            return;
          }
          b.create<linalg::YieldOp>(loc, result);
        });
    if (hadErrorCreatingPayload)
      return failure();
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, generic);
    return success();
  }
};
} // namespace

// Given `input`, `target`, `nll_loss_forward` is given by:
//   for i in range(0, len(target)):
//     indi = target[i];
//     nll_loss_forward[i] = -(input[i][indi]);
// TODO: `weight`operand is still to be taken care of.
namespace { class ConvertAtenNllLossForwardOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenNllLossForwardOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op->getLoc(); Value input = adaptor.self(); Value target = adaptor.target(); Value weight = adaptor.weight(); int64_t reduction; if (!matchPattern(op.reduction(), m_TorchConstantInt(&reduction))) return rewriter.notifyMatchFailure(op, "dim must be constant"); // TODO: Incorporate the weight argument. if (!weight.getType().isa()) return rewriter.notifyMatchFailure( op, "Unimplemented, the weight operand is not incorporated."); Value ignoreIndex = adaptor.ignore_index(); Value ignoreIndexVal = castIntToIndex(rewriter, loc, ignoreIndex); unsigned inputRank = input.getType().cast().getRank(); unsigned targetRank = target.getType().cast().getRank(); // TODO: Add support for k-dim loss. if (inputRank > 2) { return rewriter.notifyMatchFailure( op, "expected input and target to be rank <= 2"); } RankedTensorType resultType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); Type elementType = resultType.getElementType(); Value zeroVal = rewriter.create( loc, rewriter.getZeroAttr(elementType)); Value finalRes = torch_to_linalg::createElementwiseLinalgGeneric( rewriter, loc, {target}, elementType, [&](OpBuilder &b, Location loc, ValueRange args) { Value targetVal = args[0]; Value indTarget = rewriter.create( loc, rewriter.getIndexType(), targetVal); // The final result is given by: // final_res = (indTarget == ignoreIndexVal) ? 
0 : // input[indI][IndTarget] Value cmpEq = rewriter.create( loc, arith::CmpIPredicate::eq, indTarget, ignoreIndexVal); SmallVector extractionIndices{indTarget}; if (inputRank == 2) { Value indI = rewriter.create(loc, 0); extractionIndices.insert(extractionIndices.begin(), indI); } Value result = rewriter.create(loc, input, extractionIndices); Value negate = rewriter.create(loc, elementType, result); Value selectFinal = rewriter.create(loc, cmpEq, zeroVal, negate); b.create(loc, selectFinal); }); if (reduction == torch_upstream::Reduction::Sum || reduction == torch_upstream::Reduction::Mean) { Value numOfElems = getTensorSize(rewriter, loc, finalRes); numOfElems = convertScalarToDtype(rewriter, loc, numOfElems, elementType); llvm::iota_range dimsToReduce(0, targetRank, /*inclusive=*/false); DenseSet dimSet(dimsToReduce.begin(), dimsToReduce.end()); auto opInfo = torch_to_linalg::ReductionOpInfo{false, finalRes, dimSet}; finalRes = torch_to_linalg::createReductionLinalgGeneric( rewriter, loc, opInfo, /*initElem=*/zeroVal, [&](OpBuilder &b, Location loc, ValueRange args) { Value newVal = args[0]; Value accumulator = args[1]; if (reduction == torch_upstream::Reduction::Mean) newVal = b.create(loc, newVal, numOfElems); Value result = b.create(loc, newVal, accumulator); b.create(loc, result); }); } // TODO: Update the second result tensor. Value weightUpdated = createZeroInitTensor(rewriter, loc, {}, elementType); rewriter.replaceOp(op, {finalRes, weightUpdated}); return success(); } }; } // namespace /// Inverted STD: rSTD = 1 / sqrt(var + eps). static Value calculateRSTD(OpBuilder &b, Location loc, Type elemTy, Value eps, Value var) { // The eps is always f64. 
Value truncatedEps = b.create(loc, elemTy, eps); Value varPlusEps = b.create(loc, var, truncatedEps); Value rSTD = b.create(loc, varPlusEps); return rSTD; } // Normalization formula: // ((input - mean) * rSTD * weight + bias static Value createLinalgPayloadCalculationForNormOpsWithRSTD( OpBuilder &b, Location loc, Type elemTy, Value input, Value mean, Value rSTD, Value eps, Value weight, Value bias) { Value inputSubMean = b.create(loc, input, mean); Value temp = b.create(loc, inputSubMean, rSTD); Value timesWeight = b.create(loc, temp, weight); Value plusBias = b.create(loc, timesWeight, bias); return plusBias; } static Value createLinalgPayloadCalculationForNormOpsWithVar( OpBuilder &b, Location loc, Type elemTy, Value input, Value mean, Value var, Value eps, Value weight, Value bias) { Value rSTD = calculateRSTD(b, loc, elemTy, eps, var); Value result = createLinalgPayloadCalculationForNormOpsWithRSTD( b, loc, elemTy, input, mean, rSTD, eps, weight, bias); return result; } namespace { class ConvertAtenBatchNormOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenBatchNormOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { MLIRContext *context = op->getContext(); Location loc = op->getLoc(); Value input = adaptor.input(); Value weight = adaptor.weight(); Value bias = adaptor.bias(); Value runningMean = adaptor.running_mean(); Value runningVar = adaptor.running_var(); Value training = adaptor.training(); Value eps = adaptor.eps(); if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); // TODO: Handle the None cases for the optional parameters: // weight, bias. 
if (failed(checkNotNone(rewriter, op, weight)) || failed(checkNotNone(rewriter, op, bias)) || failed(checkNotNone(rewriter, op, runningMean)) || failed(checkNotNone(rewriter, op, runningVar))) return failure(); auto inputType = input.getType().cast(); auto weightType = weight.getType().cast(); auto biasType = bias.getType().cast(); auto runningMeanType = runningMean.getType().cast(); auto runningVarType = runningVar.getType().cast(); auto inputRank = inputType.getRank(); if (inputRank < 2) return rewriter.notifyMatchFailure( op, "input should have rank larger than 1"); if (weightType.getRank() != 1 || biasType.getRank() != 1 || runningMeanType.getRank() != 1 || runningVarType.getRank() != 1) { return rewriter.notifyMatchFailure( op, "expect weight, bias, running_mean and running_var to be rank 1"); } // TODO: Add support for training. auto constFalse = rewriter.create( loc, IntegerAttr::get(IntegerType::get(context, 1), 0)); auto trainingFalse = rewriter.create( loc, arith::CmpIPredicate::eq, training, constFalse); rewriter.create( loc, trainingFalse, rewriter.getStringAttr("training is not supported for now")); // num_features – C from an expected input of size (N,C,D,H,W ...) 
Value numFeatures = rewriter.create(loc, input, 1); auto contractingDim0EqualsNumFeatures = [&](Value v) { auto dim0 = rewriter.create(loc, v, 0); auto dim0Equal = rewriter.create( loc, arith::CmpIPredicate::eq, numFeatures, dim0); rewriter.create( loc, dim0Equal, rewriter.getStringAttr( "expect the size of dim 0 equal to the number of features")); }; contractingDim0EqualsNumFeatures(weight); contractingDim0EqualsNumFeatures(bias); contractingDim0EqualsNumFeatures(runningMean); contractingDim0EqualsNumFeatures(runningVar); auto indexingMap = AffineMap::get( /*dimCount=*/inputRank, /*symbolCount=*/0, rewriter.getAffineDimExpr(1), context); SmallVector indexingMaps = { rewriter.getMultiDimIdentityMap(inputRank), // input indexingMap, // weight indexingMap, // bias indexingMap, // runningMean indexingMap, // runningVar rewriter.getMultiDimIdentityMap(inputRank), // output }; SmallVector iteratorTypes(inputRank, "parallel"); Value batchNorm = rewriter .create( loc, input.getType(), ValueRange{input, weight, bias, runningMean, runningVar}, input, /*indexingMaps=*/indexingMaps, /*iteratorTypes=*/iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { Value input = args[0], weight = args[1], bias = args[2], mean = args[3], var = args[4]; Value result = createLinalgPayloadCalculationForNormOpsWithVar( b, loc, var.getType(), input, mean, var, eps, weight, bias); b.create(loc, result); }) .getResult(0); Type newResultType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, newResultType, batchNorm); return success(); } }; } // namespace // For layernorm, the mean and standard-deviation are calculated separately over // the last certain number dimensions which have to be of the shape specified by // normalized_shape. 
// // The shapes of different parts are as the following: // +-------------------+--------------------+ // | meanAndVarShape | normalizedShape | // +-------------------+--------------------- // <------------+ inputShape +--------------> // There are the following steps: // Step 1. Check if all the arguments meet the requirements. // Step 2. Common parts to be used for getting mean and var. // This includes elements count, affineMap and iteratorTypes. // Step 3. Get mean. // Step 4. Get rSTD. // Step 5. Get layernorm. namespace { class ConvertAtenNativeLayerNormOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenNativeLayerNormOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { MLIRContext *context = op->getContext(); Location loc = op->getLoc(); Value input = adaptor.input(); Value weight = adaptor.weight(); Value bias = adaptor.bias(); Value eps = adaptor.eps(); Value normalizedShape = op.normalized_shape(); if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); // TODO: Handle the None cases for the optional parameters: // weight, bias. if (failed(checkNotNone(rewriter, op, weight)) || failed(checkNotNone(rewriter, op, bias))) return failure(); auto inputType = input.getType().cast(); auto weightType = weight.getType().cast(); auto biasType = bias.getType().cast(); int64_t inputRank = inputType.getRank(); Type elemTy = inputType.getElementType(); // Step 1. Check if all the arguments meet the requirements. 
SmallVector normalizedShapeSizesTorchInt; if (!getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt)) { return rewriter.notifyMatchFailure(op, "Unimplemented normalized_shape not" "constructed from ListConstruct"); } SmallVector normalizedShapeSizesInt = getTypeConvertedValues( rewriter, loc, getTypeConverter(), normalizedShapeSizesTorchInt); int64_t normalizedShapeRank = normalizedShapeSizesInt.size(); if (weightType.getRank() != normalizedShapeRank || biasType.getRank() != normalizedShapeRank || inputRank < normalizedShapeRank || normalizedShapeRank < 1) return rewriter.notifyMatchFailure(op, "Input or weight or bias shape or" "normalized shape not compatible"); // Check all the dimensions match the normalized_shape int64_t meanAndVarShapeRank = inputRank - normalizedShapeSizesInt.size(); for (auto en : enumerate((normalizedShapeSizesInt))) { auto index = en.index(); auto inputDim = getDimOp(rewriter, loc, input, index + meanAndVarShapeRank); auto weightDim = getDimOp(rewriter, loc, weight, index); auto biasDim = getDimOp(rewriter, loc, bias, index); auto expectedSize = en.value(); checkDimEqualHelper(rewriter, loc, inputDim, expectedSize); checkDimEqualHelper(rewriter, loc, weightDim, expectedSize); checkDimEqualHelper(rewriter, loc, biasDim, expectedSize); } // Get iterator types for input shape. SmallVector normalizedShapeIteratorTypes( normalizedShapeRank, getReductionIteratorTypeName()); SmallVector meanAndVarIterationTypes( meanAndVarShapeRank, getParallelIteratorTypeName()); SmallVector inputShapeIteratorTypes = meanAndVarIterationTypes; inputShapeIteratorTypes.append(normalizedShapeIteratorTypes); // Step 2. Common parts to be used for getting mean and var. // Get sizes and affineMaps needed for mean and var. 
AffineMap inputShapeAffineMap = rewriter.getMultiDimIdentityMap(inputRank); SmallVector meanAndVarShapeExprs; for (int i = 0; i < meanAndVarShapeRank; i++) meanAndVarShapeExprs.push_back(mlir::getAffineDimExpr(i, context)); auto meanAndVarShapeAffineMap = AffineMap::get( /*dimCount=*/inputRank, /*symbolCount=*/0, meanAndVarShapeExprs, context); SmallVector meanAndVarShapeSizes = getTensorSizesUntilDim(rewriter, loc, input, meanAndVarShapeRank - 1); // Get number of elements to be used for calculating mean and var. Value elemCnts = normalizedShapeSizesInt[0]; for (int i = 1; i < normalizedShapeRank; i++) { elemCnts = rewriter.create(loc, elemCnts, normalizedShapeSizesInt[i]); } Value elemCntsFloat = rewriter.create(loc, elemTy, elemCnts); // Helper to calculate mean and var. auto genMeanOrVarCalculation = [&](Value sumOrSquareSum) { SmallVector indexingMaps( 2, rewriter.getMultiDimIdentityMap(meanAndVarShapeRank)); Value initShapeTensor = rewriter.create( loc, meanAndVarShapeSizes, elemTy); return rewriter .create( loc, initShapeTensor.getType(), sumOrSquareSum, initShapeTensor, /*indexingMaps=*/indexingMaps, /*iteratorTypes=*/meanAndVarIterationTypes, [&](OpBuilder &b, Location loc, ValueRange args) { Value sumOrSqureSum = args[0]; Value result = b.create(loc, sumOrSqureSum, elemCntsFloat); b.create(loc, result); }) .getResult(0); }; // Step 3. Get mean. // Get sum to be used for calculating mean. 
SmallVector sumIndexingMaps = { inputShapeAffineMap, // input meanAndVarShapeAffineMap, // output }; auto initSumTensor = createZeroInitTensor(rewriter, loc, meanAndVarShapeSizes, elemTy); Value sum = rewriter .create( loc, initSumTensor.getType(), input, initSumTensor, /*indexingMaps=*/sumIndexingMaps, /*iteratorTypes=*/inputShapeIteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { Value input = args[0], sum = args[1]; Value result = rewriter.create(loc, sum, input); b.create(loc, result); }) .getResult(0); Value mean = genMeanOrVarCalculation(sum); // Step 4. Get rSTD. // Calculate squareSum for the layer. SmallVector squareSumIndexingMaps{ inputShapeAffineMap, meanAndVarShapeAffineMap, meanAndVarShapeAffineMap, }; auto initSquareSumTensor = createZeroInitTensor(rewriter, loc, meanAndVarShapeSizes, elemTy); Value squareSum = rewriter .create( loc, initSquareSumTensor.getType(), ValueRange{input, mean}, initSquareSumTensor, /*indexingMaps=*/squareSumIndexingMaps, /*iteratorTypes=*/inputShapeIteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { Value input = args[0], mean = args[1], squareSum = args[2]; Value sub = rewriter.create(loc, input, mean); Value square = rewriter.create(loc, sub, sub); Value result = rewriter.create(loc, squareSum, square); b.create(loc, result); }) .getResult(0); Value var = genMeanOrVarCalculation(squareSum); Value rSTDTensor = rewriter.create( loc, meanAndVarShapeSizes, elemTy); SmallVector rSTDIndexingMap( 2, rewriter.getMultiDimIdentityMap(meanAndVarShapeRank)); Value rSTD = rewriter .create( loc, rSTDTensor.getType(), var, rSTDTensor, rSTDIndexingMap, meanAndVarIterationTypes, [&](OpBuilder &b, Location loc, ValueRange args) { Value result = calculateRSTD(b, loc, elemTy, eps, args[0]); b.create(loc, result); }) .getResult(0); // Step 5. Get layernorm. // Get affineMap for normalized shape. 
SmallVector normalizedShapeExprs; for (int i = meanAndVarShapeRank; i < inputRank; i++) normalizedShapeExprs.push_back(mlir::getAffineDimExpr(i, context)); auto normalizedShapeAffineMap = AffineMap::get( /*dimCount=*/inputRank, /*symbolCount=*/0, normalizedShapeExprs, context); auto inputSizes = getTensorSizes(rewriter, loc, input); Value initLayerNormTensor = rewriter.create(loc, inputSizes, elemTy); SmallVector indexingMaps(1, inputShapeAffineMap); indexingMaps.resize(3, meanAndVarShapeAffineMap); indexingMaps.resize(5, normalizedShapeAffineMap); indexingMaps.push_back(inputShapeAffineMap); SmallVector layerNormIterationTypes( inputRank, getParallelIteratorTypeName()); Value layerNorm = rewriter .create( loc, initLayerNormTensor.getType(), ValueRange{input, mean, rSTD, weight, bias}, initLayerNormTensor, /*indexingMaps=*/indexingMaps, /*iteratorTypes=*/layerNormIterationTypes, [&](OpBuilder &b, Location loc, ValueRange args) { Value input = args[0], mean = args[1], rSTD = args[2], weight = args[3], bias = args[4]; Value result = createLinalgPayloadCalculationForNormOpsWithRSTD( b, loc, elemTy, input, mean, rSTD, eps, weight, bias); b.create(loc, result); }) .getResult(0); SmallVector expandShape(inputRank, 1); for (int i = 0; i < meanAndVarShapeRank; i++) { // `mean` and `rstd` are not yet casted, so they will be having dynamic // shape. Hence to match them, for each dimension corresponding to `mean` // or `rstd` assign -1. 
expandShape[i] = -1; } auto expandShapeType = RankedTensorType::get(expandShape, elemTy); SmallVector reassociation(meanAndVarShapeRank); for (auto i : llvm::seq(0, meanAndVarShapeRank)) { reassociation[i].push_back(i); if (i == meanAndVarShapeRank - 1) { for (auto j : llvm::seq(0, normalizedShapeRank)) reassociation[i].push_back(i + j + 1); } } Value meanResult = rewriter.create( loc, expandShapeType, mean, reassociation); Value rSTDResult = rewriter.create( loc, expandShapeType, rSTD, reassociation); Type layerNormResultType = getTypeConverter()->convertType(op.getType(0)); Type meanResultType = getTypeConverter()->convertType(op.getType(1)); Type rSTDResultType = getTypeConverter()->convertType(op.getType(2)); Value layerNorm_ = rewriter.create(loc, layerNormResultType, layerNorm); Value mean_ = rewriter.create(loc, meanResultType, meanResult); Value var_ = rewriter.create(loc, rSTDResultType, rSTDResult); rewriter.replaceOp(op, {layerNorm_, mean_, var_}); return success(); } }; } // namespace namespace { class ConvertAtenNllLossBackwardOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenNllLossBackwardOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op->getLoc(); Value gradOutput = adaptor.grad_output(); Value input = adaptor.self(); Value target = adaptor.target(); Value weight = adaptor.weight(); bool weightIsNone = op.weight().getType().isa(); Value ignoreIndex = castIntToIndex(rewriter, loc, adaptor.ignore_index()); Value totalWeight = adaptor.total_weight(); auto inputType = input.getType().cast(); int inputRank = inputType.getRank(); auto gradOutputType = gradOutput.getType().cast(); Type resultElementType = gradOutputType.getElementType(); int64_t reduction; if (!matchPattern(op.reduction(), m_TorchConstantInt(&reduction))) return rewriter.notifyMatchFailure(op, "dim 
must be constant"); if (!hasElementType(gradOutput) || !hasElementType(gradOutput) || (!weightIsNone && !hasElementType(weight))) { return rewriter.notifyMatchFailure( op, "`gradOutput`, 'weight', and `totalWeight` must be tensors of " "type float"); } if (!hasElementType(target)) { return rewriter.notifyMatchFailure( op, "`target` must be a tensor of integer type"); } auto outputSize = getTensorSizes(rewriter, loc, input); Value gradInputTensor = createZeroInitTensor(rewriter, loc, outputSize, resultElementType); auto getAffineMapForSingleElementTensor = [&](Value tensor) { auto tensorType = tensor.getType().cast(); SmallVector affineExprs(tensorType.getRank(), rewriter.getAffineConstantExpr(0)); return AffineMap::get(inputRank, /*symbolCount=*/0, affineExprs, op->getContext()); }; AffineMap gradOutMap = AffineMap::get(inputRank, /*symbolCount=*/0, rewriter.getAffineDimExpr(0)); if (reduction != torch_upstream::Reduction::None || inputRank == 1) gradOutMap = getAffineMapForSingleElementTensor(gradOutput); AffineMap targetMap = AffineMap::get(inputRank, /*symbolCount=*/0, rewriter.getAffineDimExpr(0)); if (inputRank == 1) targetMap = getAffineMapForSingleElementTensor(target); AffineMap totalWeightMap = getAffineMapForSingleElementTensor(totalWeight); AffineMap resultMap = rewriter.getMultiDimIdentityMap(inputRank); SmallVector indexingMaps{gradOutMap, targetMap, totalWeightMap, resultMap}; SmallVector iteratorTypes(inputRank, getParallelIteratorTypeName()); // The code generation is equivalent to the following pseudo-code: // // for batch_index in len(input.size(0)): // for class_index in len(input.size(1)): // target_elem = target[batch_index] // // if reduction == None: // grad_out_elem = grad_output[batchIndex] // else: // grad_out_elem = grad_output[0] // // if reduction == Mean: // total_weight_elem = total_weight[0] // grad_out_elem /= total_weight_elem // // weight_elem = weight[target_elem] if weight != None else 1 // // if target_elem != class_index or 
target_elem == ignore_index: // grad_input_elem = -weight_elem * grad_out_elem // else: // grad_input_elem = 0 // grad_input[batch_index, target_elem] = grad_input_elem // // NOTE: In the case of not batch dimension, `batch_index` essentially // becomes zero. Value gradInput = rewriter .create( loc, gradInputTensor.getType(), ValueRange{gradOutput, target, totalWeight}, gradInputTensor, indexingMaps, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { Value gradOutElem = args[0]; Value targetElem = castIntToIndex(b, loc, args[1]); Value totalWeightElem = args[2]; Value classIndex = b.create(loc, inputRank - 1); if (reduction == torch_upstream::Reduction::Mean) { gradOutElem = b.create(loc, gradOutElem, totalWeightElem); } Value negGradOutElem = b.create(loc, gradOutElem); Value weightElem = getConstant(b, loc, 1, resultElementType); if (!weightIsNone) { weightElem = b.create(loc, weight, targetElem); } Value weightedNegGradOutElem = b.create(loc, weightElem, negGradOutElem); Value targetNeqClassIndex = b.create( loc, arith::CmpIPredicate::ne, targetElem, classIndex); Value targetEqIgnoreIndex = b.create( loc, arith::CmpIPredicate::eq, targetElem, ignoreIndex); Value gradInputIsZero = b.create( loc, targetNeqClassIndex, targetEqIgnoreIndex); Value zero = getConstant(b, loc, 0, resultElementType); Value gradInElem = b.create( loc, gradInputIsZero, zero, weightedNegGradOutElem); b.create(loc, gradInElem); }) ->getResult(0); RankedTensorType resultType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); rewriter.replaceOpWithNewOp(op, resultType, gradInput); return success(); } }; } // namespace namespace { class ConvertAtenDetachOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenDetachOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Type resultType = 
getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, resultType, adaptor.self()); return success(); } }; } // namespace namespace { class ConvertTensorStaticInfoCastOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(TensorStaticInfoCastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { RankedTensorType resultType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); rewriter.replaceOpWithNewOp(op, resultType, adaptor.operand()); return success(); } }; } // namespace void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { MLIRContext *context = patterns.getContext(); target.addIllegalOp< AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp, AtenLog2Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp, AtenLogicalOrOp, AtenTriuOp, AtenRemainderScalarOp>(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); patterns.add(typeConverter, context); 
target.addIllegalOp(); }