//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/IR/Matchers.h" #include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "llvm/ADT/APSInt.h" #include #include using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; // Check if a ranked-tensor has the specified element type. 
template static bool hasElementType(Value tensor) {
  // Cast the value's type (cast target is the template's choice) and test
  // its element type with the template-parameterized `isa` check.
  auto tensorType = cast(tensor.getType());
  Type tensorElementType = tensorType.getElementType();
  return isa(tensorElementType);
}

// Builds a scalar comparison of `lhs` and `rhs`, picking the predicate from
// the template parameters according to the element type `type`:
//   - first `isa` check (presumably floating point, since it uses `fpred`);
//   - unsigned integers  -> `iupred`;
//   - signed integers    -> `ispred`;
//   - signless integers  -> `iupred`, and only width 1 is allowed (assert).
// Any other element type reaches llvm_unreachable.
// Callers in this file pass the Torch-level dtype (from `getDtype()`) as
// `type`, which is what lets signedness be distinguished here.
template static Value createComparisonTemplate(OpBuilder &b, Location loc,
                                               Type type, Value lhs,
                                               Value rhs) {
  if (isa(type))
    return b.create(loc, fpred, lhs, rhs);
  if (IntegerType intType = dyn_cast(type)) {
    if (intType.isUnsigned())
      return b.create(loc, iupred, lhs, rhs);
    if (intType.isSigned())
      return b.create(loc, ispred, lhs, rhs);
    // Signless integers are only accepted when they are 1 bit wide; they
    // are compared with the unsigned predicate.
    assert(intType.getWidth() == 1);
    return b.create(loc, iupred, lhs, rhs);
  }
  llvm_unreachable("Unhandled element type for comparison");
}

// Returns the zero point of `value` when its defining op matches the
// (template-selected) op type that exposes getZeroPoint(); otherwise returns
// a null Value. NOTE(review): presumably the per-tensor quantize constructor
// op -- confirm the getDefiningOp template argument.
static Value getZeroPoint(Value value) {
  if (auto make = value.getDefiningOp()) {
    return make.getZeroPoint();
  }
  return nullptr;
}

// Thin wrappers around createComparisonTemplate. Each forwards its arguments
// unchanged; the comparison kind is fixed by the wrapper's template
// arguments (float / unsigned-int / signed-int predicates).
// NOTE(review): confirm each wrapper's predicate triple matches its name.
static Value createGreaterThan(OpBuilder &b, Location loc, Type elementalType,
                               Value lhs, Value rhs) {
  return createComparisonTemplate(
      b, loc, elementalType, lhs, rhs);
}

static Value createGreaterThanOrEqual(OpBuilder &b, Location loc,
                                      Type elementalType, Value lhs,
                                      Value rhs) {
  return createComparisonTemplate(
      b, loc, elementalType, lhs, rhs);
}

static Value createLessThan(OpBuilder &b, Location loc, Type elementalType,
                            Value lhs, Value rhs) {
  return createComparisonTemplate(
      b, loc, elementalType, lhs, rhs);
}

static Value createLessThanOrEqual(OpBuilder &b, Location loc,
                                   Type elementalType, Value lhs, Value rhs) {
  return createComparisonTemplate(
      b, loc, elementalType, lhs, rhs);
}

static Value createEqual(OpBuilder &b, Location loc, Type elementalType,
                         Value lhs, Value rhs) {
  return createComparisonTemplate(
      b, loc, elementalType, lhs, rhs);
}

static Value createNotEqual(OpBuilder &b, Location loc, Type elementalType,
                            Value lhs, Value rhs) {
  return createComparisonTemplate(
      b, loc, elementalType, lhs, rhs);
}

// Builds a normal-CDF expression for scalar `x`. Judging by the value names,
// the emitted computation is 0.5 * (1 + erf((x - mean) / sqrt(2))); the op
// kinds come from the create<> template arguments -- confirm.
// NOTE(review): `sigma` is accepted but never used, so the result is only
// the normal CDF when sigma == 1 (the value buildUnitNormalCdf passes). A
// general version would divide (x - mean) by sigma * sqrt(2).
static Value buildNormalCdf(OpBuilder &b, Location &loc, Value x, Value mean,
                            Value sigma) {
  Type elementType = x.getType();
  Value xMinusMean = b.create(loc, x, mean);
  Value two = b.create(loc, FloatAttr::get(elementType, 2));
  Value sqrt2 = b.create(loc, two);
  Value erfArg = b.create(loc, xMinusMean, sqrt2);
  Value erf = b.create(loc, erfArg);
  Value one = b.create(loc, FloatAttr::get(elementType, 1));
  Value erfPlus1 = b.create(loc, one, erf);
  Value oneHalf = b.create(loc, FloatAttr::get(elementType, 0.5));
  Value normalCdf = b.create(loc, oneHalf, erfPlus1);
  return normalCdf;
}

// Standard normal CDF: calls buildNormalCdf with constants 0 (mean) and
// 1 (sigma) of x's element type. Used by the GELU / GELU-backward payloads.
static Value buildUnitNormalCdf(OpBuilder &b, Location &loc, Value x) {
  Type elementType = x.getType();
  Value zero = b.create(loc, FloatAttr::get(elementType, 0));
  Value one = b.create(loc, FloatAttr::get(elementType, 1));
  return buildNormalCdf(b, loc, x, zero, one);
}

// Emits a unary math op (the template parameter) on `payloadArg`, with dtype
// handling around it:
//   - the input is converted from operand 0's Torch dtype to `computeTy`;
//   - `computeTy` is the converted result element type, except that when the
//     `isa` check matches (presumably integer results) it is promoted to f32;
//   - the op result is converted back to the result element type, passing
//     the result's Torch dtype as the destination original dtype.
template static Value createFpOpWithDtype(OpBuilder &b,
                                          const TypeConverter *converter,
                                          Value payloadArg, Operation *op) {
  Type inTTy = cast(op->getOperand(0).getType()).getDtype();
  Type outTTy = cast(op->getResult(0).getType()).getDtype();
  Type outTy =
      cast(converter->convertType(op->getResult(0).getType()))
          .getElementType();
  Type computeTy = outTy;
  if (isa(computeTy))
    computeTy = b.getF32Type();
  Location loc = op->getLoc();
  Value arg = convertScalarToDtype(b, loc, payloadArg, computeTy, inTTy);
  auto newOp = b.create(loc, arg);
  return convertScalarToDtype(b, loc, newOp, outTy, std::nullopt, outTTy);
}

// Lowers a tensor-vs-tensor comparison op to a scalar comparison of the
// payload values `lhs` and `rhs`.
// - The static_assert restricts OpTy to six supported comparison op types.
// - The comparison's signedness comes from the `self` operand's Torch dtype.
// - If lhs and rhs have different types, an error is emitted and nullptr is
//   returned (no promotion is implemented, see TODO).
// - The `if constexpr` chain dispatches to the matching create* helper.
template static Value createCompareTensorOp(OpBuilder &b, Location loc,
                                            OpTy op, Value lhs, Value rhs) {
  static_assert(std::is_same() ||
                    std::is_same() ||
                    std::is_same() ||
                    std::is_same() ||
                    std::is_same() ||
                    std::is_same(),
                "unimplemented: op type not supported");

  Type lhsDtype = lhs.getType();
  Type rhsDtype = rhs.getType();

  // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
  // to be handled.
  if (lhsDtype != rhsDtype) {
    op.emitError("unimplemented: lhs and rhs dtype must be same");
    return nullptr;
  }

  Type elementalType = cast(op.getSelf().getType()).getDtype();
  if constexpr (std::is_same()) {
    return createLessThan(b, loc, elementalType, lhs, rhs);
  }
  if constexpr (std::is_same()) {
    return createLessThanOrEqual(b, loc, elementalType, lhs, rhs);
  }
  if constexpr (std::is_same()) {
    return createGreaterThan(b, loc, elementalType, lhs, rhs);
  }
  if constexpr (std::is_same()) {
    return createGreaterThanOrEqual(b, loc, elementalType, lhs, rhs);
  }
  if constexpr (std::is_same()) {
    return createEqual(b, loc, elementalType, lhs, rhs);
  }
  if constexpr (std::is_same()) {
    return createNotEqual(b, loc, elementalType, lhs, rhs);
  }
  llvm_unreachable("unimplemented: op type not supported");
}

// Lowers a tensor-vs-scalar comparison op. Before comparing, lhs (payload
// element) and rhs (scalar operand) are brought to a common type:
//   - when the first pair of `isa` checks matches, rhs is converted to lhs's
//     type and the comparison uses lhs's type;
//   - when the second pair matches, lhs is converted to rhs's type and the
//     comparison uses rhs's type;
//   - otherwise (both the same kind), the narrower one is widened to the
//     wider bit width and the comparison keeps `self`'s Torch dtype.
// Operands that are not int-or-float are rejected with an error + nullptr.
// NOTE(review): the widening calls pass no srcOriginalDtype, so the
// conversion cannot tell whether an integer operand was unsigned -- confirm
// unsigned inputs widen as intended.
template static Value createCompareScalarOp(OpBuilder &b, Location loc,
                                            OpTy op, Value lhs, Value rhs) {
  static_assert(std::is_same() ||
                    std::is_same() ||
                    std::is_same() ||
                    std::is_same() ||
                    std::is_same() ||
                    std::is_same(),
                "unimplemented: op type not supported");

  Type lhsDtype = lhs.getType();
  Type rhsDtype = rhs.getType();
  Type elementalType = cast(op.getSelf().getType()).getDtype();

  if (lhsDtype.isIntOrFloat() && rhsDtype.isIntOrFloat()) {
    if (isa(lhsDtype) && isa(rhsDtype)) {
      rhs = convertScalarToDtype(b, loc, rhs, lhsDtype);
      elementalType = lhsDtype;
    } else if (isa(lhsDtype) && isa(rhsDtype)) {
      lhs = convertScalarToDtype(b, loc, lhs, rhsDtype);
      elementalType = rhsDtype;
    } else {
      // Both are either Integer or Float types, but the bit width might be
      // different.
      if (lhsDtype.getIntOrFloatBitWidth() > rhsDtype.getIntOrFloatBitWidth()) {
        rhs = convertScalarToDtype(b, loc, rhs, lhsDtype);
      } else {
        lhs = convertScalarToDtype(b, loc, lhs, rhsDtype);
      }
    }
  } else {
    op.emitError("unimplemented: type promotion from tensor to scalar.");
    return nullptr;
  }

  if constexpr (std::is_same()) {
    return createLessThan(b, loc, elementalType, lhs, rhs);
  }
  if constexpr (std::is_same()) {
    return createLessThanOrEqual(b, loc, elementalType, lhs, rhs);
  }
  if constexpr (std::is_same()) {
    return createGreaterThan(b, loc, elementalType, lhs, rhs);
  }
  if constexpr (std::is_same()) {
    return createGreaterThanOrEqual(b, loc, elementalType, lhs, rhs);
  }
  if constexpr (std::is_same()) {
    return createEqual(b, loc, elementalType, lhs, rhs);
  }
  if constexpr (std::is_same()) {
    return createNotEqual(b, loc, elementalType, lhs, rhs);
  }
  llvm_unreachable("unimplemented: op type not supported");
}

// Payload for a triangular-mask op.
// - Keeps payloadArgs[0] where `predicate(colIndex, rowIndex + diagonal)`
//   holds and writes a zero of the input element type elsewhere.
// - Row and column indices are the iteration indices of dims rank-2 and
//   rank-1; operands[1] supplies the diagonal offset.
// - Writes the value to `result` and always returns success(); `op` is
//   unused.
// NOTE(review): `inputRank - 2` is unsigned arithmetic; callers are assumed
// to guarantee rank >= 2 -- confirm.
template static LogicalResult
createTriangularMatrix(OpBuilder &b, Location loc, ValueRange payloadArgs,
                       Operation *op, ArrayRef operands, Value &result) {
  auto inputType = cast(operands[0].getType());
  uint64_t inputRank = inputType.getRank();

  // Use the indices of the two innermost dimensions.
  auto rowIndex = b.create(loc, inputRank - 2);
  Value rowIndexI64 = castIndexToInt64(b, loc, rowIndex);
  auto colIndex = b.create(loc, inputRank - 1);
  Value colIndexI64 = castIndexToInt64(b, loc, colIndex);

  // columnIndex >= rowIndex + diagonal?
  auto sum = b.create(loc, rowIndexI64, /*diagonal=*/operands[1]);
  auto pred = b.create(loc, predicate, colIndexI64, sum);

  Value scalar = payloadArgs[0];
  Type elementType = inputType.getElementType();
  Value zero = getConstant(b, loc, 0, elementType);
  result = b.create(loc, pred, scalar, zero);
  return success();
}

// Payload for division with a rounding mode (tensor or scalar divisor, per
// the static_assert message).
// Operands:
// - lhs is payloadArgs[0]; rhs is operands[1] or payloadArgs[1], depending
//   on OpT.
// - Both are converted to the converted result element type and divided,
//   with a separate path for unsigned integers and another for
//   signed/signless integers.
// Rounding mode handling:
// - If the rounding mode's type matches the `isa` check (presumably None),
//   the plain quotient is returned.
// - A non-constant rounding mode emits an error and returns nullptr.
// - "trunc": float quotients take ceil when negative (ULT 0) and floor
//   otherwise; integer quotients are returned unchanged.
// - "floor": floats apply floor; signed/signless integers are recomputed in
//   f64, floored and converted back; unsigned integers keep the quotient.
// NOTE(review): the f64 round-trip for integer floor division can lose
// precision once |operands| exceed 2^53 (e.g. large int64 values).
// `adaptor` is constructed but not used.
template
Value createDivModePayload(OpBuilder &b, Location loc,
                           const TypeConverter *converter,
                           ValueRange payloadArgs, OpT op,
                           ArrayRef operands) {
  static_assert(std::is_same_v || std::is_same_v,
                "template type must be a tensor/scalar div mode");
  typename OpT::Adaptor adaptor(operands);
  Type dtype = cast(converter->convertType(op.getType()))
                   .getElementType();
  Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
  Value rhs = convertScalarToDtype(
      b, loc, std::is_same_v ? operands[1] : payloadArgs[1], dtype);

  Value quotient;
  if (isa(dtype)) {
    quotient = b.create(loc, lhs, rhs);
  } else if (dtype.isUnsignedInteger()) {
    quotient = b.create(loc, lhs, rhs);
  } else {
    assert(dtype.isInteger() &&
           "dtype should be an integer (signless or signed)");
    quotient = b.create(loc, lhs, rhs);
  }

  if (isa(op.getRoundingMode().getType()))
    return quotient;

  std::string roundingMode;
  if (!matchPattern(op.getRoundingMode(), m_TorchConstantStr(roundingMode))) {
    op.emitError("only support constant str rounding mode");
    return nullptr;
  }
  assert((roundingMode == "trunc" || roundingMode == "floor") &&
         "unsupported rounding mode");
  if (roundingMode == "trunc") {
    // "trunc" - rounds the results of the division towards zero. Equivalent
    // to C-style integer division.
    if (!isa(dtype)) {
      // nothing to do for integers
      return quotient;
    }

    // float
    Value ceil = b.create(loc, quotient);
    Value floor = b.create(loc, quotient);
    Value cstZero = b.create(loc, b.getZeroAttr(dtype));
    // Negative quotients round up (ceil), non-negative ones round down.
    Value pred = b.create(loc, arith::CmpFPredicate::ULT, quotient,
                          cstZero);
    return b.create(loc, pred, ceil, floor);
  }
  if (roundingMode == "floor") {
    // "floor" - rounds the results of the division down. Equivalent to
    // floor division in Python (the // operator)
    if (isa(dtype))
      return b.create(loc, quotient);
    if (!dtype.isUnsignedInteger()) {
      Type defaultIntToFloatType = b.getF64Type();
      lhs = convertScalarToDtype(b, loc, lhs, defaultIntToFloatType);
      rhs = convertScalarToDtype(b, loc, rhs, defaultIntToFloatType);
      quotient = b.create(loc, lhs, rhs);
      Value floor = b.create(loc, quotient);
      Value convert = convertScalarToDtype(b, loc, floor, dtype);
      return convert;
    }
  }
  return quotient;
}

// Builds the scalar payload for one element of an elementwise op by
// dispatching on the concrete op type. In the branches visible here,
// unsupported cases emit an error and return nullptr.
// NOTE(review): issues observed further down in this function:
//   - the quantized relu unsigned max is computed as (1 << (width + 1)) - 1;
//     the largest unsigned `width`-bit value is (1 << width) - 1, and the
//     shift is done in `int`;
//   - the GELU "tanh" constant 0.7977240352174656 differs from
//     sqrt(2/pi) ~= 0.7978845608028654;
//   - the gelu_backward path returns nullptr without a diagnostic when
//     `approximate` is not "none".
static Value createLinalgPayloadCalculationForElementwiseOp(
    OpBuilder &b, Location loc, const TypeConverter *converter,
    ValueRange payloadArgs, Operation *op, ArrayRef operands) {
  if (isa(op))
    return b.create(loc, payloadArgs[0]);
  if (isa(op))
    return b.create(loc, payloadArgs[0]);
  if (isa(op)) {
    return createFpOpWithDtype(b, converter, payloadArgs[0], op);
  }
  if (isa(op)) {
    return createFpOpWithDtype(b, converter, payloadArgs[0], op);
  }
  if (isa(op)) {
    return createFpOpWithDtype(b, converter, payloadArgs[0], op);
  }
  if (isa(op)) {
    return createFpOpWithDtype(b, converter, payloadArgs[0], op);
  }
  if (isa(op)) {
    return createFpOpWithDtype(b, converter, payloadArgs[0], op);
  }
  if (isa(op)) {
    return createFpOpWithDtype(b, converter, payloadArgs[0], op);
  }
  if (isa(op)) {
    return createFpOpWithDtype(b, converter, payloadArgs[0], op);
  }
  if (isa(op)) {
    return createFpOpWithDtype(b, converter, payloadArgs[0], op);
  }
  if (isa(op)) {
    return createFpOpWithDtype(b, converter, payloadArgs[0], op);
  }
  if (isa(op)) {
    return createFpOpWithDtype(b, converter, payloadArgs[0], op);
  }
  if (isa(op)) {
    return createFpOpWithDtype(b, converter, payloadArgs[0], op);
  }
  if (isa(op)) {
    return createFpOpWithDtype(b, converter, payloadArgs[0], op);
  }
  if (isa(op)) {
    return createFpOpWithDtype(b, converter, payloadArgs[0], op);
  }
  if (isa(op)) {
    return createFpOpWithDtype(b, converter, payloadArgs[0], op);
  }
  if (isa(op)) {
    return createFpOpWithDtype(b, converter, payloadArgs[0], op);
  }
  if (isa(op)) {
    return
createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (auto clone = dyn_cast(op)) { int64_t memoryFormat; if (!clone.getMemoryFormat().getType().isa() && (!matchPattern(clone.getMemoryFormat(), m_TorchConstantInt(&memoryFormat)) || (memoryFormat != torch_upstream::MemoryFormat::Contiguous && memoryFormat != torch_upstream::MemoryFormat::ChannelsLast))) { clone.emitError("unimplemented: only contiguous and channels last memory " "format is supported"); return nullptr; } return payloadArgs[0]; } if (auto bitwiseAndTensor = dyn_cast(op)) { if (bitwiseAndTensor.getType() .cast() .getDtype() .isa()) { bitwiseAndTensor.emitError( "Bitwise_And does not support floating point dtype"); return nullptr; } Type dtype = converter->convertType(bitwiseAndTensor.getType()) .cast() .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); return b.create(loc, lhs, rhs); } if (auto bitwiseAndScalar = dyn_cast(op)) { Type dtype = converter->convertType(bitwiseAndScalar.getType()) .cast() .getElementType(); if (!isa(dtype)) { bitwiseAndScalar.emitError( "bitwise_and.Scalar does not support non-integer input dtype."); return nullptr; } Type resultElementType = cast(bitwiseAndScalar.getType()).getDtype(); Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype, /*srcOriginalDtype=*/std::nullopt, /*dstOriginalDtype=*/resultElementType); Value other = convertScalarToDtype(b, loc, operands[1], dtype, 
/*srcOriginalDtype=*/std::nullopt, /*dstOriginalDtype=*/resultElementType); return b.create(loc, self, other); } if (auto bitwiseOrTensor = dyn_cast(op)) { if (bitwiseOrTensor.getType() .cast() .getDtype() .isa()) { bitwiseOrTensor.emitError( "Bitwise_Or does not support floating point dtype"); return nullptr; } Type dtype = converter->convertType(bitwiseOrTensor.getType()) .cast() .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); return b.create(loc, lhs, rhs); } if (auto bitwiseXorTensor = dyn_cast(op)) { if (bitwiseXorTensor.getType() .cast() .getDtype() .isa()) { bitwiseXorTensor.emitError( "Bitwise_Xor does not support floating point dtype"); return nullptr; } Type dtype = converter->convertType(bitwiseXorTensor.getType()) .cast() .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); return b.create(loc, lhs, rhs); } if (auto bitwiseRightShiftTensor = dyn_cast(op)) { Type dtype = converter->convertType(bitwiseRightShiftTensor.getType()) .cast() .getElementType(); if (!isa(dtype)) { bitwiseRightShiftTensor.emitError( "Bitwise_Right_Shift op does not support non-integer input dtype."); return nullptr; } Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); return b.create(loc, lhs, rhs); } if (auto bitwiseLeftShiftTensor = dyn_cast(op)) { Type dtype = converter->convertType(bitwiseLeftShiftTensor.getType()) .cast() .getElementType(); if (!isa(dtype)) { bitwiseLeftShiftTensor.emitError( "Bitwise_Left_Shift op does not support non-integer input dtype."); return nullptr; } Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); return b.create(loc, lhs, rhs); } if (isa(op)) { MLIRContext *context = op->getContext(); Type 
floatDtype = mlir::FloatType::getF64(context); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], floatDtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], floatDtype); Value zero = b.create(loc, b.getFloatAttr(floatDtype, 0)); Value lhsTest = createNotEqual(b, loc, floatDtype, lhs, zero); Value rhsTest = createNotEqual(b, loc, floatDtype, rhs, zero); if (isa(op)) { return b.create(loc, lhsTest, rhsTest); } if (isa(op)) { return b.create(loc, lhsTest, rhsTest); } if (isa(op)) { return b.create(loc, lhsTest, rhsTest); } llvm_unreachable("Unknown op type"); } if (isa(op)) { MLIRContext *context = op->getContext(); Type floatDtype = mlir::FloatType::getF64(context); Value self = convertScalarToDtype(b, loc, payloadArgs[0], floatDtype); Value zero = b.create(loc, b.getFloatAttr(floatDtype, 0)); return createEqual(b, loc, floatDtype, self, zero); } if (isa(op)) { if (payloadArgs[0].getType().isa()) return b.create(loc, payloadArgs[0]); return b.create(loc, payloadArgs[0]); } if (isa(op)) { Value abs = b.create(loc, payloadArgs[0]); Value infinity = b.create( loc, b.getFloatAttr(abs.getType(), std::numeric_limits::infinity())); return createEqual(b, loc, abs.getType(), abs, infinity); } if (isa(op)) { Type inTTy = cast(op->getOperand(0).getType()).getDtype(); Type outTTy = cast(op->getResult(0).getType()).getDtype(); Type outTy = cast( converter->convertType(op->getResult(0).getType())) .getElementType(); Type computeTy = outTy; if (isa(computeTy)) computeTy = b.getF32Type(); Value arg = payloadArgs[0]; arg = convertScalarToDtype(b, loc, payloadArgs[0], computeTy, inTTy); auto negate = b.create(loc, arg); auto one = b.create(loc, FloatAttr::get(negate.getType(), 1)); auto exp = b.create(loc, negate); auto added = b.create(loc, exp, one); auto div = b.create(loc, one, added); return convertScalarToDtype(b, loc, div, outTy, std::nullopt, outTTy); } if (auto relu = dyn_cast(op)) { Value zeroPoint = getZeroPoint(relu.getSelf()); Value arg = payloadArgs[0]; 
auto intType = dyn_cast(arg.getType()); if (zeroPoint && !intType) { relu.emitError("unimplemented: non-integer quantized Relu."); return nullptr; } auto reluTorchType = cast(relu.getType()); bool isUnsigned = torch_to_linalg::isUnsignedTorchType(reluTorchType.getDtype()); if (zeroPoint) { int64_t zeroPointInt; int64_t width = intType.getWidth(); assert(width < 64); int64_t minForIntType = isUnsigned ? 0 : -(1 << (width - 1)); int64_t maxForIntType = isUnsigned ? (1 << (width + 1)) - 1 : (1 << (width - 1)) - 1; // check for constant zero point edge-cases: if (matchPattern(zeroPoint, m_TorchConstantInt(&zeroPointInt))) { if (zeroPointInt > maxForIntType) { // TODO: figure out how to handle this case: // current impl. quantizes output like input. // If zero point > maxForIntType, ordinary relu should return 0. // However, 0 isn't represented in such a quantization scheme. relu.emitError( "unimplemented: quantized relu for zero-point > max qint"); return nullptr; } if (zeroPointInt < minForIntType) return arg; } zeroPoint = converter->materializeTargetConversion( b, loc, converter->convertType(zeroPoint.getType()), zeroPoint); auto minForIntTypeValue = b.create( loc, b.getIntegerAttr(zeroPoint.getType(), minForIntType)); auto maxForIntTypeValue = b.create( loc, b.getIntegerAttr(zeroPoint.getType(), maxForIntType)); auto zpLtMax = b.create(loc, arith::CmpIPredicate::slt, zeroPoint, maxForIntTypeValue); b.create( loc, zpLtMax, b.getStringAttr("Invalid Quantization: quantized relu with " "zero-point > max qint")); auto zpLtMin = b.create(loc, arith::CmpIPredicate::slt, zeroPoint, minForIntTypeValue); zeroPoint = b.create(loc, zpLtMin, minForIntTypeValue, zeroPoint); zeroPoint = b.create(loc, arg.getType(), zeroPoint); } else { zeroPoint = b.create(loc, b.getZeroAttr(arg.getType())); } Value cmp; if (intType) { auto pred = isUnsigned ? 
arith::CmpIPredicate::ugt : arith::CmpIPredicate::sgt; cmp = b.create(loc, pred, arg, zeroPoint); } else { cmp = b.create(loc, arith::CmpFPredicate::UGT, arg, zeroPoint); } return b.create(loc, cmp, arg, zeroPoint); } if (auto round = dyn_cast(op)) { if (!round.getType() .cast() .getDtype() .isa()) { round.emitError("unimplemented: non-floating point dtype"); return nullptr; } return b.create(loc, payloadArgs[0]); } if (auto prelu = dyn_cast(op)) { if (!prelu.getType() .cast() .getDtype() .isa()) { prelu.emitError("unimplemented: non-floating point dtype"); return nullptr; } Type elementType = payloadArgs[0].getType(); Value constZero = b.create(loc, b.getZeroAttr(elementType)); Value pred = b.create(loc, arith::CmpFPredicate::UGT, payloadArgs[0], constZero); Value positivePart = b.create(loc, pred, payloadArgs[0], constZero); Value negativePart = b.create(loc, pred, constZero, payloadArgs[0]); Value scale = convertScalarToDtype(b, loc, payloadArgs[1], elementType); Value scaledNegativePart = b.create(loc, negativePart, scale); return b.create(loc, positivePart, scaledNegativePart); } if (auto gelu = dyn_cast(op)) { if (!gelu.getType() .cast() .getDtype() .isa()) { gelu.emitError("unimplemented: non-floating point dtype"); return nullptr; } // TODO: Take approximation into account. 
std::string approximate; if (!matchPattern(gelu.getApproximate(), m_TorchConstantStr(approximate))) { gelu.emitError( "unimplemented: expected approximate to be a constant str"); return nullptr; } if (approximate == "none") { Value multiplier = buildUnitNormalCdf(b, loc, payloadArgs[0]); return b.create(loc, payloadArgs[0], multiplier); } if (approximate == "tanh") { // GELU(x)=0.5∗x∗(1+Tanh((2/π)^1/2 * (x+0.044715∗x^3))) // Ref: https://pytorch.org/docs/stable/generated/torch.nn.GELU.html Value cstThree = b.create( loc, IntegerAttr::get(IntegerType::get(op->getContext(), 64), 3)); Value xCube = b.create(loc, payloadArgs[0], cstThree); Type elementType = payloadArgs[0].getType(); Value cstAlpha = b.create( loc, FloatAttr::get(elementType, 0.044715)); Value xCubeMulAlpha = b.create(loc, xCube, cstAlpha); Value xPlusXCubeMulAlpha = b.create(loc, payloadArgs[0], xCubeMulAlpha); Value cstBeta = b.create( loc, FloatAttr::get(elementType, 0.7977240352174656)); Value betaMulX = b.create(loc, cstBeta, xPlusXCubeMulAlpha); Value tanh = b.create(loc, betaMulX); Value cstOne = b.create(loc, FloatAttr::get(elementType, 1.0)); Value onePlusTanh = b.create(loc, cstOne, tanh); Value cstHalf = b.create(loc, FloatAttr::get(elementType, 0.5)); Value multiplier = b.create(loc, cstHalf, onePlusTanh); return b.create(loc, payloadArgs[0], multiplier); } gelu.emitError("unimplemented: approximate value should be none or tanh"); return nullptr; } if (auto geluBackward = dyn_cast(op)) { if (!geluBackward.getType() .cast() .getDtype() .isa()) { geluBackward.emitError("unimplemented: non-floating point dtype"); return nullptr; } // TODO: Take approximation into account. 
std::string approximate; if (!matchPattern(geluBackward.getApproximate(), m_TorchConstantStr(approximate)) || approximate != "none") return nullptr; Type elementType = payloadArgs[1].getType(); Value cstAlpha0 = b.create( loc, FloatAttr::get(elementType, 1.12837916709551257390)); Value cstAlpha1 = b.create( loc, FloatAttr::get(elementType, 0.70710678118654752440)); Value oneHalf = b.create(loc, FloatAttr::get(elementType, 0.5)); Value kAlpha = b.create(loc, cstAlpha0, cstAlpha1); Value kAlphaHalf = b.create(loc, kAlpha, oneHalf); Value negOneHalf = b.create(loc, FloatAttr::get(elementType, -0.5)); Value inputSquared = b.create(loc, payloadArgs[1], payloadArgs[1]); Value negHalfInputSquared = b.create(loc, inputSquared, negOneHalf); Value dinput = b.create(loc, negHalfInputSquared); Value cdf = buildUnitNormalCdf(b, loc, payloadArgs[1]); Value dinputInput = b.create(loc, dinput, payloadArgs[1]); Value dinputInputAlpha = b.create(loc, dinputInput, kAlphaHalf); Value cdfExt = b.create(loc, dinputInputAlpha, cdf); return b.create(loc, payloadArgs[0], cdfExt); } if (auto hardtanhBackward = dyn_cast(op)) { AtenHardtanhBackwardOp::Adaptor adaptor(operands); if (!hardtanhBackward.getType() .cast() .getDtype() .isa()) { hardtanhBackward.emitError("unimplemented: non-floating point dtype"); return nullptr; } Value gradOutput = payloadArgs[0]; Type elementType = gradOutput.getType(); Value self = convertScalarToDtype(b, loc, payloadArgs[1], elementType); Value constantZero = b.create(loc, FloatAttr::get(elementType, 0.0)); Value min = convertScalarToDtype(b, loc, adaptor.getMinVal(), elementType); Value max = convertScalarToDtype(b, loc, adaptor.getMaxVal(), elementType); Value lesser = b.create(loc, arith::CmpFPredicate::ULT, self, min); Value greater = b.create(loc, arith::CmpFPredicate::UGT, self, max); Value cmp = b.create(loc, lesser, greater); return b.create(loc, cmp, constantZero, gradOutput); } if (auto add = dyn_cast(op)) { AtenAddTensorOp::Adaptor 
adaptor(operands); Type resultElementType = cast(add.getType()).getDtype(); Type dtype = cast(converter->convertType(add.getType())) .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype, /*srcOriginalDtype=*/std::nullopt, /*dstOriginalDtype=*/resultElementType); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype, /*srcOriginalDtype=*/std::nullopt, /*dstOriginalDtype=*/resultElementType); Value alpha = convertScalarToDtype(b, loc, adaptor.getAlpha(), dtype, /*srcOriginalDtype=*/std::nullopt, /*dstOriginalDtype=*/resultElementType); if (isa(dtype)) { Value scaled = b.create(loc, rhs, alpha); return b.create(loc, lhs, scaled); } else { Value scaled = b.create(loc, rhs, alpha); return b.create(loc, lhs, scaled); } } if (auto sub = dyn_cast(op)) { AtenSubTensorOp::Adaptor adaptor(operands); Type dtype = cast(converter->convertType(sub.getType())) .getElementType(); Type resultElementType = cast(sub.getType()).getDtype(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype, /*srcOriginalDtype=*/std::nullopt, /*dstOriginalDtype=*/resultElementType); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype, /*srcOriginalDtype=*/std::nullopt, /*dstOriginalDtype=*/resultElementType); Value alpha = convertScalarToDtype(b, loc, adaptor.getAlpha(), dtype, /*srcOriginalDtype=*/std::nullopt, /*dstOriginalDtype=*/resultElementType, /*originalScalar=*/sub.getAlpha()); if (isa(dtype)) { Value scaled = b.create(loc, rhs, alpha); return b.create(loc, lhs, scaled); } else { Value scaled = b.create(loc, rhs, alpha); return b.create(loc, lhs, scaled); } } if (auto subScalar = dyn_cast(op)) { Type dtype = cast(converter->convertType(subScalar.getType())) .getElementType(); Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value other = convertScalarToDtype(b, loc, operands[1], dtype); Value alpha = convertScalarToDtype( b, loc, operands[2], dtype, /*srcOriginalDtype=*/operands[2].getType(), 
/*dstOriginalDtype=*/dtype); if (isa(dtype)) { Value mult = b.create(loc, other, alpha); return b.create(loc, self, mult); } else if (isa(dtype)) { Value mult = b.create(loc, other, alpha); return b.create(loc, self, mult); } subScalar.emitError("unimplemented: dtype other than float and integer " "types are not supported."); return nullptr; } if (auto addScalar = dyn_cast(op)) { Type dtype = cast(converter->convertType(addScalar.getType())) .getElementType(); Type resultElementType = cast(addScalar.getType()).getDtype(); Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype, /*srcOriginalDtype=*/std::nullopt, /*dstOriginalDtype=*/resultElementType); Value other = convertScalarToDtype(b, loc, operands[1], dtype, /*srcOriginalDtype=*/std::nullopt, /*dstOriginalDtype=*/resultElementType); Value alpha = convertScalarToDtype(b, loc, operands[2], dtype, /*srcOriginalDtype=*/std::nullopt, /*dstOriginalDtype=*/resultElementType); if (isa(dtype)) { Value mult = b.create(loc, other, alpha); return b.create(loc, self, mult); } else if (isa(dtype)) { Value mult = b.create(loc, other, alpha); return b.create(loc, self, mult); } addScalar.emitError("unimplemented: dtype other than float and integer " "types are not supported."); return nullptr; } if (auto mul = dyn_cast(op)) { AtenMulTensorOp::Adaptor adaptor(operands); Type dtype = cast(converter->convertType(mul.getType())) .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); if (isa(dtype)) { return b.create(loc, lhs, rhs); } else if (isa(dtype)) { return b.create(loc, lhs, rhs); } else { return b.create(loc, lhs, rhs); } } if (auto atan2 = dyn_cast(op)) { Type dtype = cast(converter->convertType(atan2.getType())) .getElementType(); if (!isa(dtype)) { atan2.emitError("Atan2 requires floating point result type"); return nullptr; } Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = 
convertScalarToDtype(b, loc, payloadArgs[1], dtype); return b.create(loc, lhs, rhs); } if (auto ltTensor = dyn_cast(op)) { return createCompareTensorOp(b, loc, ltTensor, payloadArgs[0], payloadArgs[1]); } if (auto leTensor = dyn_cast(op)) { return createCompareTensorOp(b, loc, leTensor, payloadArgs[0], payloadArgs[1]); } if (auto gtTensor = dyn_cast(op)) { return createCompareTensorOp(b, loc, gtTensor, payloadArgs[0], payloadArgs[1]); } if (auto geTensor = dyn_cast(op)) { return createCompareTensorOp(b, loc, geTensor, payloadArgs[0], payloadArgs[1]); } if (auto eqTensor = dyn_cast(op)) { return createCompareTensorOp(b, loc, eqTensor, payloadArgs[0], payloadArgs[1]); } if (auto neTensor = dyn_cast(op)) { return createCompareTensorOp(b, loc, neTensor, payloadArgs[0], payloadArgs[1]); } if (auto div = dyn_cast(op)) { AtenDivTensorOp::Adaptor adaptor(operands); Type dtype = cast(converter->convertType(div.getType())) .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); if (isa(dtype)) return b.create(loc, lhs, rhs); else if (isa(dtype)) { if (dtype.isUnsignedInteger()) return b.create(loc, lhs, rhs); return b.create(loc, lhs, rhs); } div.emitError("unimplemented: non-floating point and non-integer dtype"); return nullptr; } if (auto divScalarMode = dyn_cast(op)) { return createDivModePayload(b, loc, converter, payloadArgs, divScalarMode, operands); } if (auto divTensorMode = dyn_cast(op)) { return createDivModePayload(b, loc, converter, payloadArgs, divTensorMode, operands); } if (auto pow = dyn_cast(op)) { Type dtype = cast(pow.getType()).getDtype(); if (!isa(dtype)) { pow.emitError("unimplemented: non-floating point dtype"); return nullptr; } Value selfPromoted = convertScalarToDtype(b, loc, operands[0], dtype); Value expPromoted = convertScalarToDtype(b, loc, payloadArgs[0], dtype); return b.create(loc, selfPromoted, expPromoted); } if (auto pow = dyn_cast(op)) { if 
(!pow.getType() .cast() .getDtype() .isa()) { pow.emitError("unimplemented: non-floating point dtype"); return nullptr; } Type dtype = cast(pow.getSelf().getType()).getDtype(); Value expPromoted = convertScalarToDtype(b, loc, operands[1], dtype); return b.create(loc, payloadArgs[0], expPromoted); } if (auto pow = dyn_cast(op)) { Type dtype = cast(converter->convertType(pow.getType())) .getElementType(); if (!isa(dtype)) { pow.emitError("unimplemented: non-floating point dtype"); return nullptr; } Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); return b.create(loc, lhs, rhs); } if (auto imag = dyn_cast(op)) { Type dtype = cast(converter->convertType(imag.getType())) .getElementType(); if (!isa(dtype)) { imag.emitError("unimplemented: non-floating point dtype"); return nullptr; } Value imagVal = b.create(loc, payloadArgs[0]); return imagVal; } if (auto real = dyn_cast(op)) { Type dtype = cast(converter->convertType(real.getType())) .getElementType(); if (!isa(dtype)) { real.emitError("unimplemented: non-floating point dtype"); return nullptr; } Value realVal = b.create(loc, payloadArgs[0]); return realVal; } if (auto gtScalar = dyn_cast(op)) { return createCompareScalarOp(b, loc, gtScalar, payloadArgs[0], operands[1]); } if (auto geScalar = dyn_cast(op)) { return createCompareScalarOp(b, loc, geScalar, payloadArgs[0], operands[1]); } if (auto eqScalar = dyn_cast(op)) { return createCompareScalarOp(b, loc, eqScalar, payloadArgs[0], operands[1]); } if (auto neScalar = dyn_cast(op)) { return createCompareScalarOp(b, loc, neScalar, payloadArgs[0], operands[1]); } if (auto ltScalar = dyn_cast(op)) { return createCompareScalarOp(b, loc, ltScalar, payloadArgs[0], operands[1]); } if (auto leScalar = dyn_cast(op)) { return createCompareScalarOp(b, loc, leScalar, payloadArgs[0], operands[1]); } if (auto whereSelf = dyn_cast(op)) { Type dtype = cast(converter->convertType(whereSelf.getType())) 
.getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[2], dtype); return b.create(loc, payloadArgs[0], lhs, rhs); } if (auto lerp = dyn_cast(op)) { if (!lerp.getType() .cast() .getDtype() .isa()) { lerp.emitError("unimplemented: non-floating point dtype"); return nullptr; } AtenLerpTensorOp::Adaptor adaptor(payloadArgs); auto start = adaptor.getSelf(); auto end = adaptor.getEnd(); auto weight = adaptor.getWeight(); auto delta = b.create(loc, end, start); auto weightedDelta = b.create(loc, delta, weight); return b.create(loc, start, weightedDelta); } if (auto minimum = dyn_cast(op)) { Type dtype = cast(minimum.getType()).getDtype(); Type elemTy = converter->convertType(minimum.getType()) .cast() .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], elemTy); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], elemTy); Value pred = createLessThan(b, loc, dtype, lhs, rhs); return b.create(loc, pred, lhs, rhs); } if (auto maximum = dyn_cast(op)) { Type dtype = cast(maximum.getType()).getDtype(); Type elemTy = converter->convertType(maximum.getType()) .cast() .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], elemTy); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], elemTy); Value pred = createGreaterThan(b, loc, dtype, lhs, rhs); return b.create(loc, pred, lhs, rhs); } if (auto clamp = dyn_cast(op)) { AtenClampOp::Adaptor adaptor(operands); auto min = adaptor.getMin(); auto max = adaptor.getMax(); if (min.getType().isa() || max.getType().isa()) { clamp.emitError("unimplemented: runtime optional type"); return nullptr; } Type dtype = cast(converter->convertType(clamp.getType())) .getElementType(); if (!isa(dtype)) { clamp.emitError("unimplement type for clamp"); return nullptr; } Type dstOriginalDtype = cast(clamp.getType()).getDtype(); bool isUnsigned = isa(dstOriginalDtype); if (auto intTy = dyn_cast(dstOriginalDtype)) { 
isUnsigned = intTy.isUnsigned(); } auto cmpSelect = [&](Value input, Value clamp, bool getMax) -> Value { clamp = convertScalarToDtype(b, loc, clamp, dtype, /*srcOriginalDtype=*/std::nullopt, /*dstOriginalDtype=*/dstOriginalDtype); Value pred; if (isa(dtype)) { auto cmp = getMax ? arith::CmpFPredicate::UGT : arith::CmpFPredicate::ULT; pred = b.create(loc, cmp, input, clamp); } else if (isa(dtype)) { auto cmp = isUnsigned ? arith::CmpIPredicate::ult : arith::CmpIPredicate::slt; if (getMax) cmp = arith::invertPredicate(cmp); pred = b.create(loc, cmp, input, clamp); } return b.create(loc, pred, clamp, input); }; auto result = payloadArgs[0]; if (!min.getType().isa()) result = cmpSelect(result, min, /*getMax=*/false); if (!max.getType().isa()) result = cmpSelect(result, max, /*getMax=*/true); return result; } if (auto clampTensor = dyn_cast(op)) { AtenClampTensorOp::Adaptor adaptor(operands); auto min = adaptor.getMin(); auto max = adaptor.getMax(); if (min.getType().isa() || max.getType().isa()) { clampTensor.emitError("unimplemented: runtime optional type"); return nullptr; } Type dtype = cast(converter->convertType(clampTensor.getType())) .getElementType(); bool isMinNone = true; auto result = payloadArgs[0]; if (!min.getType().isa()) { isMinNone = false; auto minPromoted = convertScalarToDtype(b, loc, payloadArgs[1], dtype); Value pred; if (isa(dtype)) { pred = b.create(loc, arith::CmpFPredicate::ULT, result, minPromoted); } else if (isa(dtype)) { pred = b.create(loc, arith::CmpIPredicate::slt, result, minPromoted); } else { clampTensor.emitError( "unimplemented: dtype other than float and integer " "types are not supported."); return nullptr; } result = b.create(loc, pred, minPromoted, result); } if (!max.getType().isa()) { max = isMinNone ? 
payloadArgs[1] : payloadArgs[2]; auto maxPromoted = convertScalarToDtype(b, loc, max, dtype); Value pred; if (isa(dtype)) { pred = b.create(loc, arith::CmpFPredicate::UGT, result, maxPromoted); } else if (isa(dtype)) { pred = b.create(loc, arith::CmpIPredicate::sgt, result, maxPromoted); } else { clampTensor.emitError( "unimplemented: dtype other than float and integer " "types are not supported."); return nullptr; } result = b.create(loc, pred, maxPromoted, result); } return result; } if (auto rsub = dyn_cast(op)) { Type dtype = cast(converter->convertType(rsub.getType())) .getElementType(); Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value other = convertScalarToDtype(b, loc, operands[1], dtype); Value alpha = convertScalarToDtype( b, loc, operands[2], dtype, /*srcOriginalDtype=*/operands[2].getType(), /*dstOriginalDtype=*/dtype); if (isa(dtype)) { Value mult = b.create(loc, self, alpha); return b.create(loc, other, mult); } else if (isa(dtype)) { Value mult = b.create(loc, self, alpha); return b.create(loc, other, mult); } rsub.emitError("unimplemented: dtype other than float and integer " "types are not supported."); return nullptr; } if (auto mulScalar = dyn_cast(op)) { Type dtype = cast(converter->convertType(mulScalar.getType())) .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, operands[1], dtype); if (isa(dtype)) return b.create(loc, lhs, rhs); if (isa(dtype)) return b.create(loc, lhs, rhs); mulScalar.emitError("unimplemented: Only integer/float dtype supported"); return nullptr; } if (auto atenToDtype = dyn_cast(op)) { Value input = payloadArgs[0]; Type dtype = cast(converter->convertType(atenToDtype.getType())) .getElementType(); Type resultElementType; int64_t dtypeInt; if (!matchPattern(atenToDtype.getDtype(), m_TorchConstantInt(&dtypeInt))) { atenToDtype.emitError("unimplemented: dtype must be a constant integer"); return nullptr; } FailureOr 
maybeResultElementType = torch_to_linalg::getBackendTypeForScalarType( atenToDtype->getContext(), (torch_upstream::ScalarType)dtypeInt); if (failed(maybeResultElementType)) { atenToDtype.emitError("unable to convert `dtypeInt` to builtin type"); return nullptr; } resultElementType = *maybeResultElementType; Value result = convertScalarToDtype(b, loc, input, dtype, /*srcOriginalDtype=*/std::nullopt, /*dstOriginalDtype=*/resultElementType); return result; } if (auto divScalar = dyn_cast(op)) { Type dtype = cast(converter->convertType(divScalar.getType())) .getElementType(); if (!isa(dtype)) { divScalar.emitError("unimplemented: non-floating point dtype"); return nullptr; } Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value other = convertScalarToDtype(b, loc, operands[1], dtype); return b.create(loc, self, other); } if (auto remScalar = dyn_cast(op)) { Type newResultType = converter->convertType(remScalar.getType()) .cast() .getElementType(); Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType); Value other = convertScalarToDtype(b, loc, operands[1], newResultType); Value result; if (isa(newResultType)) { result = b.create(loc, self, other); } else if (isa(newResultType)) { result = b.create(loc, self, other); } else { remScalar.emitError( "Unsupported type encountered for AtenRemainderScalarOp."); } return result; } if (auto remTensor = dyn_cast(op)) { Type newResultType = converter->convertType(remTensor.getType()) .cast() .getElementType(); Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType); Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType); Value result; if (isa(newResultType)) { result = b.create(loc, self, other); } else if (isa(newResultType)) { result = b.create(loc, self, other); } else { remTensor.emitError( "Unsupported type encountered for AtenRemainderTensorOp."); } return result; } if (auto fmod = dyn_cast(op)) { Type newResultType = 
converter->convertType(fmod.getType()) .cast() .getElementType(); Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType); Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType); Value result; if (isa(newResultType)) { Value n = b.create(loc, self, other); n = b.create(loc, n); Value n_y = b.create(loc, n, other); result = b.create(loc, self, n_y); } else if (isa(newResultType)) { Value n = b.create(loc, self, other); Value n_y = b.create(loc, n, other); result = b.create(loc, self, n_y); } else { fmod.emitError("Unsupported type encountered for AtenFmodTensorOp."); } return result; } if (auto reciprocal = dyn_cast(op)) { Type dtype = cast(converter->convertType(reciprocal.getType())) .getElementType(); Value arg = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Type elementType = arg.getType(); // assert(element != 0) auto zero = b.create(loc, FloatAttr::get(elementType, 0.0)); auto pred = b.create(loc, arith::CmpFPredicate::ONE, arg, zero); b.create( loc, pred, b.getStringAttr("unimplemented: tensor with zero element")); auto one = b.create(loc, FloatAttr::get(elementType, 1.0)); return b.create(loc, one, arg); } if (auto thresholdOp = dyn_cast(op)) { // The approach used here is as follows: // result = self <= threshold ? value : self AtenThresholdOp::Adaptor adaptor(operands); Type dtype = cast(converter->convertType(thresholdOp.getType())) .getElementType(); Value self = payloadArgs[0]; Value threshold = convertScalarToDtype(b, loc, adaptor.getThreshold(), dtype); Value value = convertScalarToDtype(b, loc, adaptor.getValue(), dtype); Value predicate; if (isa(dtype)) predicate = b.create(loc, arith::CmpFPredicate::ULE, self, threshold); else predicate = b.create(loc, arith::CmpIPredicate::sle, self, threshold); return b.create(loc, predicate, value, self); } if (auto thresholdBackward = dyn_cast(op)) { // The approach used here is as follows: // result = self <= threshold ? 
0 : grad AtenThresholdBackwardOp::Adaptor adaptor(operands); Type dtype = cast( converter->convertType(thresholdBackward.getType())) .getElementType(); Value grad = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value self = convertScalarToDtype(b, loc, payloadArgs[1], dtype); Value threshold = convertScalarToDtype(b, loc, adaptor.getThreshold(), dtype); Value constantZero = b.create(loc, b.getZeroAttr(dtype)); Value predicate; if (isa(dtype)) predicate = b.create(loc, arith::CmpFPredicate::ULE, self, threshold); else predicate = b.create(loc, arith::CmpIPredicate::sle, self, threshold); return b.create(loc, predicate, constantZero, grad); } if (auto fillScalar = dyn_cast(op)) { AtenFillScalarOp::Adaptor adaptor(operands); Type dtype = cast(converter->convertType(fillScalar.getType())) .getElementType(); return convertScalarToDtype(b, loc, adaptor.getValue(), dtype); } if (auto maskedFillTensor = dyn_cast(op)) { AtenMaskedFillScalarOp::Adaptor adaptor(operands); Type dtype = cast( converter->convertType(maskedFillTensor.getType())) .getElementType(); Value input = payloadArgs[0]; Value mask = payloadArgs[1]; Value fillValue = convertScalarToDtype(b, loc, payloadArgs[2], dtype); return b.create(loc, mask, fillValue, input); } if (auto fillTensor = dyn_cast(op)) { AtenFillTensorOp::Adaptor adaptor(operands); Type dtype = cast(converter->convertType(fillTensor.getType())) .getElementType(); return convertScalarToDtype(b, loc, payloadArgs[1], dtype); } if (auto triu = dyn_cast(op)) { Value result; if (failed(createTriangularMatrix( b, loc, payloadArgs, op, operands, result))) return nullptr; return result; } if (auto tril = dyn_cast(op)) { Value result; if (failed(createTriangularMatrix( b, loc, payloadArgs, op, operands, result))) return nullptr; return result; } if (auto bitwiseNot = dyn_cast(op)) { Type elementType = converter->convertType(bitwiseNot.getType()) .cast() .getElementType(); if (isa(elementType)) { bitwiseNot.emitError("Bitwise_Not does not 
support floating point dtype"); return nullptr; } Value allOnesVal = b.create( loc, b.getIntegerAttr( elementType, APSInt::getAllOnes(elementType.getIntOrFloatBitWidth()))); return b.create(loc, payloadArgs[0], allOnesVal); } if (isa(op)) { auto value = payloadArgs[0]; auto valueTy = value.getType(); auto qtensor = op->getOperand(0); auto qtensorTy = cast(qtensor.getType()).getDtype(); Value zp, scale; if (auto makeQTensor = qtensor.getDefiningOp()) { zp = makeQTensor.getZeroPoint(); scale = makeQTensor.getScale(); } if (auto quant = qtensor.getDefiningOp()) { zp = quant.getZeroPoint(); scale = quant.getScale(); } if (!zp || !scale) { return nullptr; } auto outFpTy = payloadArgs[1].getType(); auto outBw = outFpTy.getIntOrFloatBitWidth(); auto outIntTy = b.getIntegerType(outBw); if (valueTy != outIntTy) { if (torch_to_linalg::isUnsignedTorchType(qtensorTy)) { value = b.create(loc, outIntTy, value); } else { value = b.create(loc, outIntTy, value); } } zp = converter->materializeTargetConversion( b, loc, converter->convertType(zp.getType()), zp); auto zpTy = zp.getType(); if (zpTy != outIntTy) { zp = b.create(loc, outIntTy, zp); } value = b.create(loc, value, zp); // treat the i32 as a signed int regardless of original signed-ness // this will prevent overflow from subtraction for unsigned quantizations. 
value = b.create(loc, outFpTy, value); scale = converter->materializeTargetConversion( b, loc, converter->convertType(scale.getType()), scale); if (scale.getType() != value.getType()) { scale = b.create(loc, value.getType(), scale); } value = b.create(loc, value, scale); return value; } if (auto quant = dyn_cast(op)) { Value value = payloadArgs[0]; Value scale = quant.getScale(); Value zp = quant.getZeroPoint(); auto valueTy = value.getType(); zp = converter->materializeTargetConversion( b, loc, converter->convertType(zp.getType()), zp); zp = b.create(loc, valueTy, zp); scale = converter->materializeTargetConversion( b, loc, converter->convertType(scale.getType()), scale); scale = b.create(loc, valueTy, scale); value = b.create(loc, value, scale); value = b.create(loc, value); value = b.create(loc, value, zp); auto destTy = payloadArgs[1].getType(); auto bitwidth = destTy.getIntOrFloatBitWidth(); bool isUnsigned = torch_to_linalg::isUnsignedTorchType(quant.getType()); APInt min = isUnsigned ? APInt::getMinValue(bitwidth) : APInt::getSignedMinValue(bitwidth); APInt max = isUnsigned ? APInt::getMaxValue(bitwidth) : APInt::getSignedMaxValue(bitwidth); double minI = isUnsigned ? static_cast(min.getZExtValue()) : static_cast(min.getSExtValue()); double maxI = isUnsigned ? static_cast(max.getZExtValue()) : static_cast(max.getSExtValue()); Value minVal = b.create(loc, b.getFloatAttr(valueTy, minI)); Value maxVal = b.create(loc, b.getFloatAttr(valueTy, maxI)); Value minCmp = b.create(loc, arith::CmpFPredicate::ULT, value, minVal); Value maxCmp = b.create(loc, arith::CmpFPredicate::UGT, value, maxVal); value = b.create(loc, minCmp, minVal, value); value = b.create(loc, maxCmp, maxVal, value); if (isUnsigned) { value = b.create(loc, destTy, value); } else { value = b.create(loc, destTy, value); } return value; } op->emitError("unimplemented lowering in " "createLinalgPayloadCalculationForElementwiseOp"); return nullptr; } namespace { // Converts an elementwise op. 
// This specifically includes: // - converting elementwise ops of any tensor arity // - converting elementwise ops with any number of scalar captures (such as a // scalar alpha to torch.aten.Add) // - broadcasting of static size-1 dimensions // // Currently, we adopt the behavior that "size 1" broadcasting is a runtime // error if it happens dynamically. // // Looking forward a bit, eventually, it probably makes sense to have // a "linalg.generic-like" op for modeling a fused subgraph of numpy-broadcasted // operands. Modeling elementwise ops that way is potentially useful to allow a // more centralized reasoning about multiversioning. However a cost model will // be needed for "pre-fusing" elementwise ops that way, as it can potentially be // a pessimization. A mild extension of this pattern should work for such a // general op. class ConvertElementwiseOp : public ConversionPattern { public: ConvertElementwiseOp(TypeConverter &typeConverter, MLIRContext *context) : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1, context) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { if (!isa(op)) return rewriter.notifyMatchFailure(op, "not a supported elementwise op"); if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op->getLoc(); auto tensorOperands = llvm::to_vector<6>(llvm::make_filter_range( operands, [](Value v) { return v.getType().isa(); })); auto resultType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); bool hadErrorCreatingPayload = false; Value generic = torch_to_linalg::createElementwiseLinalgGeneric( rewriter, loc, tensorOperands, resultType.getElementType(), [&](OpBuilder &b, Location loc, ValueRange payloadArgs) { Value result = createLinalgPayloadCalculationForElementwiseOp( b, loc, getTypeConverter(), payloadArgs, op, operands); if (!result) { hadErrorCreatingPayload = true; return; } b.create(loc, result); 
}); if (hadErrorCreatingPayload) return failure(); rewriter.replaceOpWithNewOp(op, resultType, generic); return success(); } }; } // namespace // Given `input`, `target`, `nll_loss_forward` is given by: // for i in range(0, len(target)): // indi = target[i]; // nll_loss_forward[i] = -(input[i][indi]); // TODO: `weight`operand is still to be taken care of. namespace { class ConvertAtenNllLossForwardOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenNllLossForwardOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op->getLoc(); Value input = adaptor.getSelf(); Value target = adaptor.getTarget(); Value weight = adaptor.getWeight(); int64_t reduction; if (!matchPattern(op.getReduction(), m_TorchConstantInt(&reduction))) return rewriter.notifyMatchFailure(op, "dim must be constant"); // TODO: Incorporate the weight argument. if (!weight.getType().isa()) return rewriter.notifyMatchFailure( op, "Unimplemented, the weight operand is not incorporated."); Value ignoreIndex = adaptor.getIgnoreIndex(); Value ignoreIndexVal = castIntToIndex(rewriter, loc, ignoreIndex); unsigned inputRank = cast(input.getType()).getRank(); unsigned targetRank = cast(target.getType()).getRank(); // TODO: Add support for k-dim loss. 
if (inputRank > 2) { return rewriter.notifyMatchFailure( op, "expected input and target to be rank <= 2"); } RankedTensorType resultType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); Type elementType = resultType.getElementType(); Value zeroVal = rewriter.create( loc, rewriter.getZeroAttr(elementType)); Value finalRes = torch_to_linalg::createElementwiseLinalgGeneric( rewriter, loc, {target}, elementType, [&](OpBuilder &b, Location loc, ValueRange args) { Value targetVal = args[0]; Value indTarget = rewriter.create( loc, rewriter.getIndexType(), targetVal); // The final result is given by: // final_res = (indTarget == ignoreIndexVal) ? 0 : // input[indI][IndTarget] Value cmpEq = rewriter.create( loc, arith::CmpIPredicate::eq, indTarget, ignoreIndexVal); SmallVector extractionIndices{indTarget}; if (inputRank == 2) { Value indI = rewriter.create(loc, 0); extractionIndices.insert(extractionIndices.begin(), indI); } Value result = rewriter.create(loc, input, extractionIndices); Value negate = rewriter.create(loc, elementType, result); Value selectFinal = rewriter.create(loc, cmpEq, zeroVal, negate); b.create(loc, selectFinal); }); llvm::iota_range dimsToReduce(0, targetRank, /*inclusive=*/false); DenseSet dimSet(dimsToReduce.begin(), dimsToReduce.end()); if (reduction == torch_upstream::Reduction::Sum || reduction == torch_upstream::Reduction::Mean) { Value numOfElems = getTensorSize(rewriter, loc, finalRes); numOfElems = convertScalarToDtype(rewriter, loc, numOfElems, elementType); auto opInfo = torch_to_linalg::ReductionOpInfo{false, finalRes, dimSet}; finalRes = torch_to_linalg::createReductionLinalgGeneric( rewriter, loc, opInfo, /*initElem=*/zeroVal, [&](OpBuilder &b, Location loc, ValueRange args) { Value newVal = args[0]; Value accumulator = args[1]; if (reduction == torch_upstream::Reduction::Mean) newVal = b.create(loc, newVal, numOfElems); Value result = b.create(loc, newVal, accumulator); b.create(loc, result); }); } // The 
implementation for the `total_weight` has been adopted from here: // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/LossNLL.cpp#L154-L294 // As per the ref link, the `total_weight` value when the `weight` is // `None`, is equal to `total_weight = batch_size - num_ignored_index`, // where `batch_size` is equal to `target.shape[0]` when rank(target) > 0, // otherwise 1. The value `num_ignored_index` is the number of elements of // the `target` tensors that have been ignored. if (reduction == torch_upstream::Reduction::None && inputRank == 2) { Value totalWeight = createZeroInitTensor(rewriter, loc, {}, elementType); rewriter.replaceOp(op, {finalRes, totalWeight}); return success(); } Value numIgnoredIndex; if (targetRank == 0) { Value targetVal = rewriter.create(loc, target); numIgnoredIndex = rewriter.create( loc, arith::CmpIPredicate::eq, targetVal, ignoreIndex); numIgnoredIndex = convertScalarToDtype(rewriter, loc, numIgnoredIndex, ignoreIndex.getType()); } else { Value zeroCstInt = rewriter.create( loc, rewriter.getZeroAttr(ignoreIndex.getType())); auto opInfo = torch_to_linalg::ReductionOpInfo{/*keepDim=*/false, target, dimSet}; numIgnoredIndex = torch_to_linalg::createReductionLinalgGeneric( rewriter, loc, opInfo, /*initElem=*/zeroCstInt, [&](OpBuilder &b, Location loc, ValueRange args) { Value targetVal = args[0]; Value accumulator = args[1]; Value result = b.create( loc, arith::CmpIPredicate::eq, targetVal, ignoreIndex); result = b.create( loc, convertScalarToDtype(rewriter, loc, result, ignoreIndex.getType()), accumulator); b.create(loc, result); }); numIgnoredIndex = rewriter.create(loc, numIgnoredIndex); } Value numtargetElems = getTensorSize(rewriter, loc, target); Value totalWeightVal = rewriter.create(loc, numtargetElems, numIgnoredIndex); Value totalWeight = createInitTensor( rewriter, loc, {}, elementType, convertScalarToDtype(rewriter, loc, totalWeightVal, elementType)); rewriter.replaceOp(op, {finalRes, totalWeight}); return 
success(); } }; } // namespace /// Inverted STD: rSTD = 1 / sqrt(var + eps). static Value calculateRSTD(OpBuilder &b, Location loc, Type elemTy, Value eps, Value var) { // The eps is always f64. Value truncatedEps = b.create(loc, elemTy, eps); Value varPlusEps = b.create(loc, var, truncatedEps); Value rSTD = b.create(loc, varPlusEps); return rSTD; } // Normalization formula: // ((input - mean) * rSTD * weight + bias static Value createLinalgPayloadCalculationForNormOpsWithRSTD( OpBuilder &b, Location loc, Type elemTy, Value input, Value mean, Value rSTD, Value eps, Value weight, Value bias) { Value inputSubMean = b.create(loc, input, mean); Value temp = b.create(loc, inputSubMean, rSTD); Value timesWeight = b.create(loc, temp, weight); Value plusBias = b.create(loc, timesWeight, bias); return plusBias; } static Value createLinalgPayloadCalculationForNormOpsWithVar( OpBuilder &b, Location loc, Type elemTy, Value input, Value mean, Value var, Value eps, Value weight, Value bias) { Value rSTD = calculateRSTD(b, loc, elemTy, eps, var); Value result = createLinalgPayloadCalculationForNormOpsWithRSTD( b, loc, elemTy, input, mean, rSTD, eps, weight, bias); return result; } namespace { class ConvertAtenBatchNormOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenBatchNormOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { MLIRContext *context = op->getContext(); Location loc = op->getLoc(); Value input = adaptor.getInput(); Value weight = adaptor.getWeight(); Value bias = adaptor.getBias(); Value runningMean = adaptor.getRunningMean(); Value runningVar = adaptor.getRunningVar(); Value training = adaptor.getTraining(); Value eps = adaptor.getEps(); if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); // TODO: Handle the None cases for the optional parameters: // weight, bias. 
if (failed(checkNotNone(rewriter, op, weight)) || failed(checkNotNone(rewriter, op, bias)) || failed(checkNotNone(rewriter, op, runningMean)) || failed(checkNotNone(rewriter, op, runningVar))) return failure(); auto inputType = cast(input.getType()); auto weightType = cast(weight.getType()); auto biasType = cast(bias.getType()); auto runningMeanType = cast(runningMean.getType()); auto runningVarType = cast(runningVar.getType()); auto inputRank = inputType.getRank(); if (inputRank < 2) return rewriter.notifyMatchFailure( op, "input should have rank larger than 1"); if (weightType.getRank() != 1 || biasType.getRank() != 1 || runningMeanType.getRank() != 1 || runningVarType.getRank() != 1) { return rewriter.notifyMatchFailure( op, "expect weight, bias, running_mean and running_var to be rank 1"); } // TODO: Add support for training. auto constFalse = rewriter.create( loc, IntegerAttr::get(IntegerType::get(context, 1), 0)); auto trainingFalse = rewriter.create( loc, arith::CmpIPredicate::eq, training, constFalse); rewriter.create( loc, trainingFalse, rewriter.getStringAttr("training is not supported for now")); // num_features – C from an expected input of size (N,C,D,H,W ...) 
Value numFeatures = rewriter.create(loc, input, 1); auto contractingDim0EqualsNumFeatures = [&](Value v) { auto dim0 = rewriter.create(loc, v, 0); auto dim0Equal = rewriter.create( loc, arith::CmpIPredicate::eq, numFeatures, dim0); rewriter.create( loc, dim0Equal, rewriter.getStringAttr( "expect the size of dim 0 equal to the number of features")); }; if (!isAssumingStrictSymbolicShapes(rewriter)) { contractingDim0EqualsNumFeatures(weight); contractingDim0EqualsNumFeatures(bias); contractingDim0EqualsNumFeatures(runningMean); contractingDim0EqualsNumFeatures(runningVar); } auto indexingMap = AffineMap::get( /*dimCount=*/inputRank, /*symbolCount=*/0, rewriter.getAffineDimExpr(1), context); SmallVector indexingMaps = { rewriter.getMultiDimIdentityMap(inputRank), // input indexingMap, // weight indexingMap, // bias indexingMap, // runningMean indexingMap, // runningVar rewriter.getMultiDimIdentityMap(inputRank), // output }; SmallVector iteratorTypes( inputRank, utils::IteratorType::parallel); Value batchNorm = rewriter .create( loc, input.getType(), ValueRange{input, weight, bias, runningMean, runningVar}, input, /*indexingMaps=*/indexingMaps, /*iteratorTypes=*/iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { Value input = args[0], weight = args[1], bias = args[2], mean = args[3], var = args[4]; Value result = createLinalgPayloadCalculationForNormOpsWithVar( b, loc, var.getType(), input, mean, var, eps, weight, bias); b.create(loc, result); }) .getResult(0); Type newResultType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, newResultType, batchNorm); return success(); } }; } // namespace namespace { class ConvertAtenNllLossBackwardOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenNllLossBackwardOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); 
Location loc = op->getLoc(); Value gradOutput = adaptor.getGradOutput(); Value input = adaptor.getSelf(); Value target = adaptor.getTarget(); Value weight = adaptor.getWeight(); bool weightIsNone = op.getWeight().getType().isa(); Value ignoreIndex = castIntToIndex(rewriter, loc, adaptor.getIgnoreIndex()); Value totalWeight = adaptor.getTotalWeight(); auto inputType = cast(input.getType()); int inputRank = inputType.getRank(); auto gradOutputType = cast(gradOutput.getType()); Type resultElementType = gradOutputType.getElementType(); int64_t reduction; if (!matchPattern(op.getReduction(), m_TorchConstantInt(&reduction))) return rewriter.notifyMatchFailure(op, "dim must be constant"); if (!hasElementType(gradOutput) || !hasElementType(gradOutput) || (!weightIsNone && !hasElementType(weight))) { return rewriter.notifyMatchFailure( op, "`gradOutput`, 'weight', and `totalWeight` must be tensors of " "type float"); } if (!hasElementType(target)) { return rewriter.notifyMatchFailure( op, "`target` must be a tensor of integer type"); } auto outputSize = getTensorSizes(rewriter, loc, input); Value gradInputTensor = createZeroInitTensor(rewriter, loc, outputSize, resultElementType); auto getAffineMapForSingleElementTensor = [&](Value tensor) { auto tensorType = cast(tensor.getType()); SmallVector affineExprs(tensorType.getRank(), rewriter.getAffineConstantExpr(0)); return AffineMap::get(inputRank, /*symbolCount=*/0, affineExprs, op->getContext()); }; AffineMap gradOutMap = AffineMap::get(inputRank, /*symbolCount=*/0, rewriter.getAffineDimExpr(0)); if (reduction != torch_upstream::Reduction::None || inputRank == 1) gradOutMap = getAffineMapForSingleElementTensor(gradOutput); AffineMap targetMap = AffineMap::get(inputRank, /*symbolCount=*/0, rewriter.getAffineDimExpr(0)); if (inputRank == 1) targetMap = getAffineMapForSingleElementTensor(target); AffineMap totalWeightMap = getAffineMapForSingleElementTensor(totalWeight); AffineMap resultMap = 
rewriter.getMultiDimIdentityMap(inputRank); SmallVector indexingMaps{gradOutMap, targetMap, totalWeightMap, resultMap}; SmallVector iteratorTypes( inputRank, utils::IteratorType::parallel); // The code generation is equivalent to the following pseudo-code: // // for batch_index in len(input.size(0)): // for class_index in len(input.size(1)): // target_elem = target[batch_index] // // if reduction == None: // grad_out_elem = grad_output[batchIndex] // else: // grad_out_elem = grad_output[0] // // if reduction == Mean: // total_weight_elem = total_weight[0] // grad_out_elem /= total_weight_elem // // weight_elem = weight[target_elem] if weight != None else 1 // // if target_elem != class_index or target_elem == ignore_index: // grad_input_elem = -weight_elem * grad_out_elem // else: // grad_input_elem = 0 // grad_input[batch_index, target_elem] = grad_input_elem // // NOTE: In the case of not batch dimension, `batch_index` essentially // becomes zero. Value gradInput = rewriter .create( loc, gradInputTensor.getType(), ValueRange{gradOutput, target, totalWeight}, gradInputTensor, indexingMaps, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { Value gradOutElem = args[0]; Value targetElem = castIntToIndex(b, loc, args[1]); Value totalWeightElem = args[2]; Value classIndex = b.create(loc, inputRank - 1); if (reduction == torch_upstream::Reduction::Mean) { gradOutElem = b.create(loc, gradOutElem, totalWeightElem); } Value negGradOutElem = b.create(loc, gradOutElem); Value weightElem = getConstant(b, loc, 1, resultElementType); if (!weightIsNone) { weightElem = b.create(loc, weight, targetElem); } Value weightedNegGradOutElem = b.create(loc, weightElem, negGradOutElem); Value targetNeqClassIndex = b.create( loc, arith::CmpIPredicate::ne, targetElem, classIndex); Value targetEqIgnoreIndex = b.create( loc, arith::CmpIPredicate::eq, targetElem, ignoreIndex); Value gradInputIsZero = b.create( loc, targetNeqClassIndex, targetEqIgnoreIndex); Value zero = 
getConstant(b, loc, 0, resultElementType); Value gradInElem = b.create( loc, gradInputIsZero, zero, weightedNegGradOutElem); b.create(loc, gradInElem); }) ->getResult(0); RankedTensorType resultType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); rewriter.replaceOpWithNewOp(op, resultType, gradInput); return success(); } }; } // namespace namespace { class ConvertAtenDetachOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenDetachOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Type resultType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, resultType, adaptor.getSelf()); return success(); } }; } // namespace namespace { class ConvertPrimsSplitDimOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(PrimsSplitDimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); auto aRankedTensorType = cast(adaptor.getA().getType()); const TypeConverter *typeConverter = getTypeConverter(); auto resultRankedTensorType = cast(typeConverter->convertType(op.getType())); // The dimension being split must be statically known. 
int64_t dimInt; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dimInt))) return failure(); SmallVector associations; associations.reserve(aRankedTensorType.getRank()); for (unsigned i = 0; i < dimInt; ++i) { associations.push_back(ReassociationIndices{i}); } associations.push_back(ReassociationIndices{dimInt, dimInt + 1}); for (int i = dimInt + 2; i < resultRankedTensorType.getRank(); ++i) { associations.push_back(ReassociationIndices{i}); } auto expanded = rewriter.createOrFold( op.getLoc(), resultRankedTensorType, adaptor.getA(), associations); rewriter.replaceOpWithNewOp(op, resultRankedTensorType, expanded); return success(); } }; } // namespace namespace { class ConvertPrimsCollapseOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(PrimsCollapseOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); auto aRankedTensorType = cast(adaptor.getA().getType()); const TypeConverter *typeConverter = getTypeConverter(); auto resultRankedTensorType = cast(typeConverter->convertType(op.getType())); // Collapse range must be statically known. int64_t startInt; if (!matchPattern(op.getStart(), m_TorchConstantInt(&startInt))) return failure(); int64_t endInt; if (!matchPattern(op.getEnd(), m_TorchConstantInt(&endInt))) return failure(); // Upstream MLIR is overly strict -- it fails verification if the // collapse_shape is the identity op (i.e. when no dimensions are // collapsed). We manually fold this case here. if (startInt == endInt) { rewriter.replaceOp(op, adaptor.getA()); return success(); } SmallVector associations; associations.reserve(resultRankedTensorType.getRank()); // An example of is where input shape is [3,4,5,6] and // start = 1, and end = 2. The collapsed shape is then [3,4*5,6], // with reassociation indices of [0], [1,2], and [3]. 
// Append the singleton dimensions before the collapsed dimensions. for (unsigned i = 0; i < startInt; ++i) { associations.push_back(ReassociationIndices{i}); } // Append the collapsed dimensions. ReassociationIndices collapseDims(endInt + 1 - startInt); std::iota(collapseDims.begin(), collapseDims.end(), startInt); associations.push_back(collapseDims); // Append the singleton dimensions after the collapsed dimensions. for (int i = endInt + 1; i < aRankedTensorType.getRank(); ++i) { associations.push_back(ReassociationIndices{i}); } rewriter.replaceOpWithNewOp( op, resultRankedTensorType, adaptor.getA(), associations); return success(); } }; } // namespace namespace { class ConvertTensorStaticInfoCastOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(TensorStaticInfoCastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { RankedTensorType resultType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); rewriter.replaceOpWithNewOp(op, resultType, adaptor.getOperand()); return success(); } }; } // namespace namespace { class ConvertLogitOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenLogitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); Value input = adaptor.getSelf(); Value eps = adaptor.getEps(); if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); bool handleEps = false; if (succeeded(checkNotNone(rewriter, op, eps))) handleEps = true; if (handleEps && !eps.getType().isa()) { op.emitError("Logit does not support non-floating point type"); return failure(); } auto inputType = cast(input.getType()); auto inputElementType = inputType.getElementType(); if (!isa(inputElementType)) { op.emitError("Logit does not support non-floating point type"); return failure(); } auto inputRank = inputType.getRank(); 
SmallVector indexingMaps = { rewriter.getMultiDimIdentityMap(inputRank), // input rewriter.getMultiDimIdentityMap(inputRank), // output }; SmallVector iteratorTypes( inputRank, utils::IteratorType::parallel); Value logit = rewriter .create( loc, input.getType(), /*ins=*/input, /*outs=*/input, /*indexingMaps=*/indexingMaps, /*iteratorTypes=*/iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { Value input = args[0]; TypedAttr oneAttr = b.getFloatAttr(inputElementType, 1.0); Value oneValue = b.create(loc, oneAttr); Value zI; if (!handleEps) { zI = input; } else { Value truncEps = b.create(loc, inputElementType, eps); Value oneMinusEps = b.create(loc, oneValue, truncEps); Value min = b.create(loc, input, oneMinusEps); Value clampedInput = b.create(loc, min, truncEps); zI = clampedInput; } Value probability = b.create(loc, oneValue, zI); Value odds = b.create(loc, zI, probability); Value result = b.create(loc, odds); b.create(loc, result); }) .getResult(0); Type newResultType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, newResultType, logit); return success(); } }; } // namespace namespace { class ConvertAtenIntReprOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenIntReprOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { RankedTensorType resultType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); rewriter.replaceOpWithNewOp(op, resultType, adaptor.getSelf()); return success(); } }; } // namespace namespace { class ConvertDequantizePerChannel : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenDequantizeSelfOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto qoperand = op.getOperand(); auto make = qoperand.getDefiningOp(); if (!make) { return rewriter.notifyMatchFailure(op, 
"did not find per channel qint"); } auto converter = getTypeConverter(); auto operand = make.getOperand(0); auto scale = make.getScale(); auto zeropoint = make.getZeroPoint(); auto axis = make.getAxis(); IntegerAttr axisAttr; if (!matchPattern(axis, m_Constant(&axisAttr))) { return failure(); } auto operandDTy = cast(operand.getType()).getDtype(); auto zeropointDTy = cast(zeropoint.getType()).getDtype(); operand = converter->materializeTargetConversion( rewriter, loc, converter->convertType(operand.getType()), operand); scale = converter->materializeTargetConversion( rewriter, loc, converter->convertType(scale.getType()), scale); zeropoint = converter->materializeTargetConversion( rewriter, loc, converter->convertType(zeropoint.getType()), zeropoint); auto resultType = converter->convertType(op->getResult(0).getType()) .cast(); llvm::SmallVector dynSizes; for (auto [index, dim] : llvm::enumerate(resultType.getShape())) { if (ShapedType::isDynamic(dim)) { dynSizes.push_back(rewriter.create(loc, operand, index)); } } llvm::SmallVector iterators( resultType.getRank(), utils::IteratorType::parallel); llvm::SmallVector maps( 4, {rewriter.getMultiDimIdentityMap(resultType.getRank())}); auto broadcastMap = AffineMap::get( resultType.getRank(), /*symbolCount=*/0, {rewriter.getAffineDimExpr(axisAttr.getInt())}, rewriter.getContext()); maps[1] = broadcastMap; maps[2] = broadcastMap; auto empty = rewriter.create(op.getLoc(), resultType, dynSizes); auto linalgOp = rewriter.create( loc, resultType, ValueRange{operand, scale, zeropoint}, ValueRange{empty}, maps, iterators, [&](OpBuilder &b, Location loc, ValueRange args) { Value operand = args[0]; Value scale = args[1]; Value zeropoint = args[2]; if (operandDTy.isUnsignedInteger(8)) { operand = b.create(loc, b.getI32Type(), operand); } else if (operandDTy.isSignedInteger(8)) { operand = b.create(loc, b.getI32Type(), operand); } if (zeropointDTy.isUnsignedInteger(8)) { zeropoint = b.create(loc, b.getI32Type(), zeropoint); } else 
if (zeropointDTy.isSignedInteger(8)) { zeropoint = b.create(loc, b.getI32Type(), zeropoint); } Value sub = rewriter.create(loc, operand, zeropoint); Value fp = rewriter.create(loc, args[3].getType(), sub); Value mul = rewriter.create(loc, fp, scale); b.create(loc, mul); }); rewriter.replaceOp(op, linalgOp.getResults()); return success(); } }; } // namespace namespace { template class ConvertCastEquivalentOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename OpTy::Adaptor; LogicalResult matchAndRewrite(OpTy op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto converter = this->getTypeConverter(); RankedTensorType resultType = cast( converter->convertType(op->getResult(0).getType())); rewriter.replaceOpWithNewOp(op, resultType, adaptor.getSelf()); return success(); } }; } // namespace namespace { class ConvertAtenGridSamplerOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenGridSamplerOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); Type int64type = rewriter.getI64Type(); Type floatType = rewriter.getF32Type(); Value zeroIndex = rewriter.create(loc, 0); Value oneIndex = rewriter.create(loc, 1); Value twoIndex = rewriter.create(loc, 2); Value zeroFloat = rewriter.create( loc, rewriter.getFloatAttr(floatType, 0.0)); Value oneFloat = rewriter.create( loc, rewriter.getFloatAttr(floatType, 1.0)); Value twoFloat = rewriter.create( loc, rewriter.getFloatAttr(floatType, 2.0)); Value input = adaptor.getInput(); auto inputType = cast(input.getType()); auto inputShape = inputType.getShape(); Value innerDim0a = rewriter.create(loc, input, 2); Value innerDim1a = rewriter.create(loc, input, 3); Value innerDim0b = rewriter.create(loc, innerDim0a, oneIndex); Value innerDim1b = rewriter.create(loc, innerDim1a, oneIndex); Value innerDim0c = rewriter.create(loc, int64type, 
innerDim0b); Value innerDim1c = rewriter.create(loc, int64type, innerDim1b); Value innerDim0d = rewriter.create(loc, floatType, innerDim0c); Value innerDim1d = rewriter.create(loc, floatType, innerDim1c); Value innerDim0e = rewriter.create(loc, innerDim0d, twoFloat); Value innerDim1e = rewriter.create(loc, innerDim1d, twoFloat); Value grid = adaptor.getGrid(); auto gridType = cast(grid.getType()); auto gridShape = gridType.getShape(); auto gridRank = gridType.getRank(); SmallVector extractGridOffsets0(gridRank, zeroIndex); SmallVector extractGridShape = getTensorSizes(rewriter, loc, grid); SmallVector extractGridStride(gridRank, oneIndex); int64_t lastGridDim = gridRank - 1; extractGridShape[lastGridDim] = oneIndex; extractGridStride[lastGridDim] = twoIndex; SmallVector extractGridOffsets1(gridRank, zeroIndex); extractGridOffsets1[lastGridDim] = oneIndex; SmallVector gridShapeExtracted(gridShape); gridShapeExtracted.back() = 1; SmallVector gridShapeCollapsed{gridShape[0], gridShape[1], gridShape[2]}; auto grid0 = rewriter.create( loc, grid, extractGridOffsets0, extractGridShape, extractGridStride); auto grid1 = rewriter.create( loc, grid, extractGridOffsets1, extractGridShape, extractGridStride); SmallVector associations{ReassociationIndices{0}, ReassociationIndices{1}, ReassociationIndices{2, 3}}; auto gridCollapsed0 = rewriter.create(loc, grid0, associations); auto gridCollapsed1 = rewriter.create(loc, grid1, associations); AffineMap gridMap = AffineMap::get(4, 0, {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(2), rewriter.getAffineDimExpr(3)}, op->getContext()); SmallVector gridMaps{gridMap, gridMap, rewriter.getMultiDimIdentityMap(gridRank)}; SmallVector gridIterators( gridRank, utils::IteratorType::parallel); SmallVector resultShape{inputShape[0], inputShape[1], gridShape[1], gridShape[2]}; auto lambdaExtract = [](OpBuilder &b, Location loc, Value input, Value idxA, Value idxB, Value idxC, Value idxD) -> Value { SmallVector index{idxA, idxB, idxC, 
idxD}; Value result = b.create(loc, input, index); return result; }; auto lambdaLinear = [&](OpBuilder &b, Location loc, Value x, Value y, Value d) -> Value { Value dm = b.create(loc, oneFloat, d); Value ra = b.create(loc, x, dm); Value rb = b.create(loc, y, d); Value res = b.create(loc, ra, rb); return res; }; auto lambdaNearest = [&](OpBuilder &b, Location loc, Value x, Value y, Value d) -> Value { Value halfConst = rewriter.create( loc, rewriter.getFloatAttr(floatType, 0.5)); Value checkClosest = b.create(loc, arith::CmpFPredicate::OLT, d, halfConst); Value res = b.create(loc, checkClosest, x, y); return res; }; auto lambdaInterpolate = [&](OpBuilder &b, Location loc, Value iMode, Value x, Value y, Value d) -> Value { Value linear = lambdaLinear(b, loc, x, y, d); Value nearest = lambdaNearest(b, loc, x, y, d); Value zeroInt = b.create(loc, b.getIntegerAttr(int64type, 0)); Value checkMode = b.create(loc, arith::CmpIPredicate::eq, iMode, zeroInt); Value res = b.create(loc, checkMode, linear, nearest); return res; }; auto resultType = getTypeConverter() ->convertType(op.getResult().getType()) .cast(); SmallVector resultSize{}; if (resultType.isDynamicDim(0)) resultSize.push_back(rewriter.create(loc, input, 0)); if (resultType.isDynamicDim(1)) resultSize.push_back(rewriter.create(loc, input, 1)); if (resultType.isDynamicDim(2)) resultSize.push_back(rewriter.create(loc, grid, 1)); if (resultType.isDynamicDim(3)) resultSize.push_back(rewriter.create(loc, grid, 2)); Value alignCorners = adaptor.getAlignCorners(); Value interMode = adaptor.getInterpolationMode(); Value resultFinal = rewriter.create(loc, resultType, resultSize); auto sGrid = rewriter.create( loc, TypeRange{resultType}, ValueRange{gridCollapsed0, gridCollapsed1}, ValueRange(resultFinal), gridMaps, gridIterators, [&](OpBuilder &b, Location loc, ValueRange args) { Value gr0 = args[1]; Value gr1 = args[0]; Value gr0Half = b.create(loc, gr0, twoFloat); Value gr1Half = b.create(loc, gr1, twoFloat); Value 
gr0HalfSelect = b.create(loc, alignCorners, zeroFloat, gr0Half); Value gr1HalfSelect = b.create(loc, alignCorners, zeroFloat, gr1Half); Value gplus0 = b.create(loc, gr0, oneFloat); Value gplus1 = b.create(loc, gr1, oneFloat); Value gPlusMul0 = b.create(loc, gplus0, innerDim0e); Value gPlusMul1 = b.create(loc, gplus1, innerDim1e); Value result0 = b.create(loc, gPlusMul0, gr0HalfSelect); Value result1 = b.create(loc, gPlusMul1, gr1HalfSelect); Value checkLowerBound0 = b.create( loc, arith::CmpFPredicate::OLT, result0, zeroFloat); Value checkLowerBound1 = b.create( loc, arith::CmpFPredicate::OLT, result1, zeroFloat); Value lowerOrig0 = b.create(loc, int64type, result0); Value lowerOrig1 = b.create(loc, int64type, result1); Value zeroInt = b.create(loc, b.getIntegerAttr(int64type, 0)); Value oneInt = b.create(loc, b.getIntegerAttr(int64type, 1)); Value lowerSub0 = b.create(loc, lowerOrig0, oneInt); Value lowerSub1 = b.create(loc, lowerOrig1, oneInt); Value lower0 = b.create(loc, checkLowerBound0, lowerSub0, lowerOrig0); Value lower1 = b.create(loc, checkLowerBound1, lowerSub1, lowerOrig1); Value lowerValid0 = b.create(loc, checkLowerBound0, zeroInt, lower0); Value lowerValid1 = b.create(loc, checkLowerBound1, zeroInt, lower1); Value upper0 = b.create(loc, int64type, lower0, oneInt); Value upper1 = b.create(loc, int64type, lower1, oneInt); Value notValidUpper0 = rewriter.create( loc, arith::CmpIPredicate::sgt, upper0, innerDim0c); Value notValidUpper1 = rewriter.create( loc, arith::CmpIPredicate::sgt, upper1, innerDim1c); Value upperValid0 = b.create(loc, notValidUpper0, lower0, upper0); Value upperValid1 = b.create(loc, notValidUpper1, lower1, upper1); Value lw0 = b.create(loc, b.getIndexType(), lowerValid0); Value lw1 = b.create(loc, b.getIndexType(), lowerValid1); Value up0 = b.create(loc, b.getIndexType(), upperValid0); Value up1 = b.create(loc, b.getIndexType(), upperValid1); Value N = b.create(loc, 0); Value C = b.create(loc, 1); Value result00 = lambdaExtract(b, 
loc, input, N, C, lw0, lw1); Value result00a = b.create(loc, checkLowerBound0, zeroFloat, result00); Value result00b = b.create(loc, checkLowerBound1, zeroFloat, result00a); Value result01 = lambdaExtract(b, loc, input, N, C, lw0, up1); Value result01a = b.create(loc, notValidUpper1, zeroFloat, result01); Value result01b = b.create(loc, checkLowerBound0, zeroFloat, result01a); Value result10 = lambdaExtract(b, loc, input, N, C, up0, lw1); Value result10a = b.create(loc, notValidUpper0, zeroFloat, result10); Value result10b = b.create(loc, checkLowerBound1, zeroFloat, result10a); Value result11 = lambdaExtract(b, loc, input, N, C, up0, up1); Value result11a = b.create(loc, notValidUpper0, zeroFloat, result11); Value result11b = b.create(loc, notValidUpper1, zeroFloat, result11a); Value lw0a = b.create(loc, floatType, lower0); Value lw1a = b.create(loc, floatType, lower1); Value d1 = b.create(loc, result0, lw0a); Value d0 = b.create(loc, result1, lw1a); Value resultScaled0 = lambdaInterpolate(b, loc, interMode, result00b, result01b, d0); Value resultScaled1 = lambdaInterpolate(b, loc, interMode, result10b, result11b, d0); Value resultScaled = lambdaInterpolate( b, loc, interMode, resultScaled0, resultScaled1, d1); b.create(loc, resultScaled); }); rewriter.replaceOp(op, sGrid.getResults()); return success(); } }; } // namespace static Value NearestInterpolate(OpBuilder &b, Location loc, SmallVector outputSizes, Value input, SmallVector inputSizes) { auto inputType = input.getType().cast(); auto inputRank = inputType.getRank(); SmallVector indices; for (unsigned i = 0; i < inputRank; i++) { indices.push_back(b.create(loc, i)); } for (unsigned i = 2; i < inputRank; i++) { Value outIndex = indices[i]; Value inputSizeFP = b.create(loc, b.getF32Type(), inputSizes[i - 2]); Value outputSizeFP = b.create(loc, b.getF32Type(), outputSizes[i - 2]); // scale = length_resized / length_original // x_original = x_resized / scale Value scale = b.create(loc, outputSizeFP, 
inputSizeFP); Value outInt = b.create(loc, b.getI64Type(), outIndex); Value outFP = b.create(loc, b.getF32Type(), outInt); Value proj = b.create(loc, outFP, scale); // get nearest pixel using floor Value nearestFP = b.create(loc, proj); Value nearestInt = b.create(loc, b.getI64Type(), nearestFP); Value nearest = b.create(loc, b.getIndexType(), nearestInt); indices[i] = nearest; } Value retVal = b.create(loc, input, indices); return retVal; } static Value BilinearInterpolate(OpBuilder &b, Aten__InterpolateSizeListScaleListOp op, Location loc, SmallVector outputSizes, Value input, SmallVector inputSizes) { Value inputSizeH = inputSizes[0]; Value inputSizeW = inputSizes[1]; Value outputSizeH = outputSizes[0]; Value outputSizeW = outputSizes[1]; int hDimOffset = 2; auto inputType = input.getType().cast(); auto inputRank = inputType.getRank(); Value cstOneEps = b.create(loc, b.getF32FloatAttr(1.001)); Value cstOneFloat = b.create(loc, b.getF32FloatAttr(1.0)); Value cstHalf = b.create(loc, b.getF32FloatAttr(0.5)); Value zero = b.create(loc, b.getF32FloatAttr(0.0)); Value yOut = b.create(loc, 2); Value xOut = b.create(loc, 3); bool alignCornersBool; matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool)); Value yProj, xProj; if (alignCornersBool) { // x_original = x_resized * (length_original - 1) / (length_resized - 1) Value inputHFP = b.create(loc, b.getF32Type(), inputSizeH); Value outputSizeHFP = b.create(loc, b.getF32Type(), outputSizeH); Value yOutInt = b.create(loc, b.getI64Type(), yOut); Value yOutFP = b.create(loc, b.getF32Type(), yOutInt); Value inputHSubOne = b.create(loc, inputHFP, cstOneFloat); Value outputSizeHSubOne = b.create(loc, outputSizeHFP, cstOneFloat); Value hScale = b.create(loc, inputHSubOne, outputSizeHSubOne); Value yProjBeforeClamp = b.create(loc, yOutFP, hScale); Value yMax = b.create(loc, yProjBeforeClamp, zero); Value outputSizeHSubOneEps = b.create(loc, outputSizeHFP, cstOneEps); yProj = b.create(loc, 
outputSizeHSubOneEps, yMax); Value inputWFP = b.create(loc, b.getF32Type(), inputSizeW); Value outputSizeWFP = b.create(loc, b.getF32Type(), outputSizeW); Value xOutInt = b.create(loc, b.getI64Type(), xOut); Value xOutFP = b.create(loc, b.getF32Type(), xOutInt); Value inputWSubOne = b.create(loc, inputWFP, cstOneFloat); Value outputSizeWSubOne = b.create(loc, outputSizeWFP, cstOneFloat); Value wScale = b.create(loc, inputWSubOne, outputSizeWSubOne); Value xProjBeforeClamp = b.create(loc, xOutFP, wScale); Value xMax = b.create(loc, xProjBeforeClamp, zero); Value outputSizeWSubOneEps = b.create(loc, outputSizeWFP, cstOneEps); xProj = b.create(loc, outputSizeWSubOneEps, xMax); } else { // y_original = (y_resized + 0.5) / scale - 0.5 Value inputHFP = b.create(loc, b.getF32Type(), inputSizeH); Value outputSizeHFP = b.create(loc, b.getF32Type(), outputSizeH); Value hScale = b.create(loc, outputSizeHFP, inputHFP); Value yOutInt = b.create(loc, b.getI64Type(), yOut); Value yOutFP = b.create(loc, b.getF32Type(), yOutInt); Value yPlusHalf = b.create(loc, yOutFP, cstHalf); Value yDivScale = b.create(loc, yPlusHalf, hScale); Value ySubHalf = b.create(loc, yDivScale, cstHalf); Value yMax = b.create(loc, ySubHalf, zero); Value inputHSubOne = b.create(loc, inputHFP, cstOneEps); yProj = b.create(loc, yMax, inputHSubOne); Value inputWFP = b.create(loc, b.getF32Type(), inputSizeW); Value outputSizeWFP = b.create(loc, b.getF32Type(), outputSizeW); Value wScale = b.create(loc, outputSizeWFP, inputWFP); Value xOutInt = b.create(loc, b.getI64Type(), xOut); Value xOutFP = b.create(loc, b.getF32Type(), xOutInt); Value xPlusHalf = b.create(loc, xOutFP, cstHalf); Value xDivScale = b.create(loc, xPlusHalf, wScale); Value xSubHalf = b.create(loc, xDivScale, cstHalf); // clamp Value xMax = b.create(loc, xSubHalf, zero); Value inputWSubOne = b.create(loc, inputWFP, cstOneEps); xProj = b.create(loc, xMax, inputWSubOne); } Value yLow = b.create(loc, yProj); Value yProjPlusOne = b.create(loc, 
cstOneFloat, yProj); Value yHigh = b.create(loc, yProjPlusOne); Value xLow = b.create(loc, xProj); Value xProjPlusOne = b.create(loc, cstOneFloat, xProj); Value xHigh = b.create(loc, xProjPlusOne); SmallVector indices; for (unsigned i = 0; i < inputRank; i++) { indices.push_back(b.create(loc, i)); } Value yLowInt = b.create(loc, b.getI64Type(), yLow); Value yLowIdx = b.create(loc, b.getIndexType(), yLowInt); Value xLowInt = b.create(loc, b.getI64Type(), xLow); Value xLowIdx = b.create(loc, b.getIndexType(), xLowInt); Value yHighInt = b.create(loc, b.getI64Type(), yHigh); Value yHighIdx = b.create(loc, b.getIndexType(), yHighInt); Value xHighInt = b.create(loc, b.getI64Type(), xHigh); Value xHighIdx = b.create(loc, b.getIndexType(), xHighInt); indices[hDimOffset] = yLowIdx; indices[hDimOffset + 1] = xLowIdx; Value p00 = b.create(loc, input, indices); indices[hDimOffset] = yLowIdx; indices[hDimOffset + 1] = xHighIdx; Value p01 = b.create(loc, input, indices); indices[hDimOffset] = yHighIdx; indices[hDimOffset + 1] = xLowIdx; Value p10 = b.create(loc, input, indices); indices[hDimOffset] = yHighIdx; indices[hDimOffset + 1] = xHighIdx; Value p11 = b.create(loc, input, indices); // p00 p01 // p10 p11 // (xhigh - xproj) / (xhigh - xlow) * p00 + (xproj - xlow) / // (xhigh - xlow) * p01 Value xHighMinusxProj = b.create(loc, xHigh, xProj); Value xHighMinusxLow = b.create(loc, xHigh, xLow); Value w0 = b.create(loc, xHighMinusxProj, xHighMinusxLow); Value lhs = b.create(loc, w0, p00); Value xProjMinusxLow = b.create(loc, xProj, xLow); Value w1 = b.create(loc, xProjMinusxLow, xHighMinusxLow); Value rhs = b.create(loc, w1, p01); Value xInter = b.create(loc, lhs, rhs); // (xhigh - xproj) / (xhigh - xlow) * p10 + (xproj - xlow) / // (xhigh - xlow) * p11 lhs = b.create(loc, w0, p10); rhs = b.create(loc, w1, p11); Value xInter1 = b.create(loc, lhs, rhs); // (yhigh - yproj) / (yhigh - ylow) * xInter + (yproj - ylow) // / (yhigh - ylow) * xInter1 Value yHighMinusyProj = b.create(loc, 
yHigh, yProj); Value yHighMinusyLow = b.create(loc, yHigh, yLow); w0 = b.create(loc, yHighMinusyProj, yHighMinusyLow); lhs = b.create(loc, w0, xInter); Value yProjMinusyLow = b.create(loc, yProj, yLow); w1 = b.create(loc, yProjMinusyLow, yHighMinusyLow); rhs = b.create(loc, w1, xInter1); Value retVal = b.create(loc, lhs, rhs); return retVal; } namespace { class ConvertInterpolateOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(Aten__InterpolateSizeListScaleListOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { std::string mode; matchPattern(op.getMode(), m_TorchConstantStr(mode)); if (mode != "bilinear" && mode != "nearest") { return failure(); } Location loc = op->getLoc(); Value input = adaptor.getInput(); auto inputType = input.getType().cast(); auto inputRank = inputType.getRank(); if (mode == "bilinear" && inputRank != 4) return rewriter.notifyMatchFailure( op, "cannot perform bilinear interpolation when input spatial dims != 2"); SmallVector outputSizeIntValues; SmallVector inputSizes; for (unsigned i = 2; i < inputRank; i++) { Value inputSize = getDimOp(rewriter, loc, input, 2); inputSizes.push_back(rewriter.create( loc, rewriter.getIntegerType(64), inputSize)); } if (!op.getScaleFactor().getType().isa()) { SmallVector ScaleFactorTorchFloat; if (!getListConstructElements(op.getScaleFactor(), ScaleFactorTorchFloat)) return rewriter.notifyMatchFailure( op, "unimplemented: the output_size is not constructed from " "ListConstruct"); SmallVector ScaleFactorFloatValues; ScaleFactorFloatValues = getTypeConvertedValues( rewriter, loc, getTypeConverter(), ScaleFactorTorchFloat); for (unsigned i = 0; i < inputRank - 2; i++) { Value inputSizeFP = rewriter.create( loc, rewriter.getF32Type(), inputSizes[i]); Value scale = rewriter.create( loc, inputSizeFP.getType(), ScaleFactorFloatValues[i]); Value outputSize = rewriter.create(loc, inputSizeFP, scale); outputSize = 
rewriter.create(loc, outputSize); outputSize = rewriter.create( loc, rewriter.getI64Type(), outputSize); outputSizeIntValues.push_back(outputSize); } } else { SmallVector outputSizeTorchInt; if (!getListConstructElements(op.getSize(), outputSizeTorchInt)) return rewriter.notifyMatchFailure( op, "unimplemented: the output_size is not constructed from " "ListConstruct"); outputSizeIntValues = getTypeConvertedValues( rewriter, loc, getTypeConverter(), outputSizeTorchInt); } SmallVector dims = getTensorSizesUntilDim(rewriter, loc, input, 1); for (unsigned i = 2; i < inputRank; i++) { dims.push_back(castIntToIndex(rewriter, loc, outputSizeIntValues[i - 2])); } Value outTensor = rewriter.create( loc, getAsOpFoldResult(dims), inputType.getElementType()); AffineMap idMap = rewriter.getMultiDimIdentityMap(inputRank); SmallVector iteratorTypes( inputRank, utils::IteratorType::parallel); Value finalRes = rewriter .create( loc, outTensor.getType(), ValueRange{}, outTensor, /*indexingMaps=*/idMap, /*iteratorTypes=*/iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { Value retVal; if (mode == "nearest") { retVal = NearestInterpolate(b, loc, outputSizeIntValues, input, inputSizes); } else if (mode == "bilinear") { retVal = BilinearInterpolate( b, op, loc, outputSizeIntValues, input, inputSizes); } b.create(loc, retVal); }) .getResult(0); Type newResultType = getTypeConverter()->convertType(op.getResult().getType()); rewriter.replaceOpWithNewOp(op, newResultType, finalRes); return success(); } }; } // namespace void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { MLIRContext *context = patterns.getContext(); target.addIllegalOp< AtenTanOp, AtenTanhOp, AtenSinhOp, AtenCoshOp, AtenAtanhOp, AtenAcoshOp, AtenAsinOp, AtenAsinhOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp, 
AtenDivScalarModeOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, AtenAtan2Op, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenClampTensorOp, AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPreluOp, AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp, AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, AtenTrilOp, AtenRemainderScalarOp, AtenFmodTensorOp, AtenRemainderTensorOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, AtenRealOp, AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp, AtenQuantizePerTensorOp>(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add>( typeConverter, context); target.addIllegalOp(); patterns.add>( typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); 
target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); }