torch-mlir/lib/Conversion/TorchToLinalg/Uncategorized.cpp

3002 lines
129 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/Matchers.h"
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/APSInt.h"
#include <numeric>
#include <type_traits>
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
// Check if a ranked-tensor has the specified element type.
template <typename elementType> static bool hasElementType(Value tensor) {
auto tensorType = cast<RankedTensorType>(tensor.getType());
Type tensorElementType = tensorType.getElementType();
return isa<elementType>(tensorElementType);
}
template <arith::CmpFPredicate fpred, arith::CmpIPredicate iupred,
arith::CmpIPredicate ispred>
static Value createComparisonTemplate(OpBuilder &b, Location loc, Type type,
Value lhs, Value rhs) {
if (isa<mlir::FloatType>(type))
return b.create<arith::CmpFOp>(loc, fpred, lhs, rhs);
if (IntegerType intType = dyn_cast<mlir::IntegerType>(type)) {
if (intType.isUnsigned())
return b.create<arith::CmpIOp>(loc, iupred, lhs, rhs);
if (intType.isSigned())
return b.create<arith::CmpIOp>(loc, ispred, lhs, rhs);
assert(intType.getWidth() == 1);
return b.create<arith::CmpIOp>(loc, iupred, lhs, rhs);
}
llvm_unreachable("Unhandled element type for comparison");
}
static Value getZeroPoint(Value value) {
if (auto make = value.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
return make.getZeroPoint();
}
return nullptr;
}
static Value createGreaterThan(OpBuilder &b, Location loc, Type elementalType,
Value lhs, Value rhs) {
return createComparisonTemplate<arith::CmpFPredicate::OGT,
arith::CmpIPredicate::ugt,
arith::CmpIPredicate::sgt>(
b, loc, elementalType, lhs, rhs);
}
static Value createGreaterThanOrEqual(OpBuilder &b, Location loc,
Type elementalType, Value lhs,
Value rhs) {
return createComparisonTemplate<arith::CmpFPredicate::OGE,
arith::CmpIPredicate::uge,
arith::CmpIPredicate::sge>(
b, loc, elementalType, lhs, rhs);
}
static Value createLessThan(OpBuilder &b, Location loc, Type elementalType,
Value lhs, Value rhs) {
return createComparisonTemplate<arith::CmpFPredicate::OLT,
arith::CmpIPredicate::ult,
arith::CmpIPredicate::slt>(
b, loc, elementalType, lhs, rhs);
}
static Value createLessThanOrEqual(OpBuilder &b, Location loc,
Type elementalType, Value lhs, Value rhs) {
return createComparisonTemplate<arith::CmpFPredicate::OLE,
arith::CmpIPredicate::ule,
arith::CmpIPredicate::sle>(
b, loc, elementalType, lhs, rhs);
}
static Value createEqual(OpBuilder &b, Location loc, Type elementalType,
Value lhs, Value rhs) {
return createComparisonTemplate<arith::CmpFPredicate::OEQ,
arith::CmpIPredicate::eq,
arith::CmpIPredicate::eq>(
b, loc, elementalType, lhs, rhs);
}
static Value createNotEqual(OpBuilder &b, Location loc, Type elementalType,
Value lhs, Value rhs) {
return createComparisonTemplate<arith::CmpFPredicate::UNE,
arith::CmpIPredicate::ne,
arith::CmpIPredicate::ne>(
b, loc, elementalType, lhs, rhs);
}
static Value buildNormalCdf(OpBuilder &b, Location &loc, Value x, Value mean,
Value sigma) {
Type elementType = x.getType();
Value xMinusMean = b.create<arith::SubFOp>(loc, x, mean);
Value two = b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 2));
Value sqrt2 = b.create<math::SqrtOp>(loc, two);
Value erfArg = b.create<arith::DivFOp>(loc, xMinusMean, sqrt2);
Value erf = b.create<math::ErfOp>(loc, erfArg);
Value one = b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1));
Value erfPlus1 = b.create<arith::AddFOp>(loc, one, erf);
Value oneHalf =
b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.5));
Value normalCdf = b.create<arith::MulFOp>(loc, oneHalf, erfPlus1);
return normalCdf;
}
static Value buildUnitNormalCdf(OpBuilder &b, Location &loc, Value x) {
Type elementType = x.getType();
Value zero = b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0));
Value one = b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1));
return buildNormalCdf(b, loc, x, zero, one);
}
template <typename MathOpTy>
static Value createFpOpWithDtype(OpBuilder &b, const TypeConverter *converter,
Value payloadArg, Operation *op) {
Type inTTy = cast<ValueTensorType>(op->getOperand(0).getType()).getDtype();
Type outTTy = cast<ValueTensorType>(op->getResult(0).getType()).getDtype();
Type outTy =
cast<RankedTensorType>(converter->convertType(op->getResult(0).getType()))
.getElementType();
Type computeTy = outTy;
if (isa<IntegerType>(computeTy))
computeTy = b.getF32Type();
Location loc = op->getLoc();
Value arg = convertScalarToDtype(b, loc, payloadArg, computeTy, inTTy);
auto newOp = b.create<MathOpTy>(loc, arg);
return convertScalarToDtype(b, loc, newOp, outTy, std::nullopt, outTTy);
}
template <typename OpTy>
static Value createCompareTensorOp(OpBuilder &b, Location loc, OpTy op,
Value lhs, Value rhs) {
static_assert(std::is_same<OpTy, AtenLtTensorOp>() ||
std::is_same<OpTy, AtenLeTensorOp>() ||
std::is_same<OpTy, AtenGtTensorOp>() ||
std::is_same<OpTy, AtenGeTensorOp>() ||
std::is_same<OpTy, AtenEqTensorOp>() ||
std::is_same<OpTy, AtenNeTensorOp>(),
"unimplemented: op type not supported");
Type lhsDtype = lhs.getType();
Type rhsDtype = rhs.getType();
// TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
// to be handled.
if (lhsDtype != rhsDtype) {
op.emitError("unimplemented: lhs and rhs dtype must be same");
return nullptr;
}
Type elementalType = cast<BaseTensorType>(op.getSelf().getType()).getDtype();
if constexpr (std::is_same<OpTy, AtenLtTensorOp>()) {
return createLessThan(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenLeTensorOp>()) {
return createLessThanOrEqual(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenGtTensorOp>()) {
return createGreaterThan(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenGeTensorOp>()) {
return createGreaterThanOrEqual(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenEqTensorOp>()) {
return createEqual(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenNeTensorOp>()) {
return createNotEqual(b, loc, elementalType, lhs, rhs);
}
llvm_unreachable("unimplemented: op type not supported");
}
template <typename OpTy>
static Value createCompareScalarOp(OpBuilder &b, Location loc, OpTy op,
Value lhs, Value rhs) {
static_assert(std::is_same<OpTy, AtenLtScalarOp>() ||
std::is_same<OpTy, AtenLeScalarOp>() ||
std::is_same<OpTy, AtenEqScalarOp>() ||
std::is_same<OpTy, AtenNeScalarOp>() ||
std::is_same<OpTy, AtenGtScalarOp>() ||
std::is_same<OpTy, AtenGeScalarOp>(),
"unimplemented: op type not supported");
Type lhsDtype = lhs.getType();
Type rhsDtype = rhs.getType();
Type elementalType = cast<BaseTensorType>(op.getSelf().getType()).getDtype();
if (lhsDtype.isIntOrFloat() && rhsDtype.isIntOrFloat()) {
if (isa<mlir::FloatType>(lhsDtype) && isa<mlir::IntegerType>(rhsDtype)) {
rhs = convertScalarToDtype(b, loc, rhs, lhsDtype);
elementalType = lhsDtype;
} else if (isa<mlir::IntegerType>(lhsDtype) &&
isa<mlir::FloatType>(rhsDtype)) {
lhs = convertScalarToDtype(b, loc, lhs, rhsDtype);
elementalType = rhsDtype;
} else {
// Both are either Integer or Float types, but the bit width might be
// different.
if (lhsDtype.getIntOrFloatBitWidth() > rhsDtype.getIntOrFloatBitWidth()) {
rhs = convertScalarToDtype(b, loc, rhs, lhsDtype);
} else {
lhs = convertScalarToDtype(b, loc, lhs, rhsDtype);
}
}
} else {
op.emitError("unimplemented: type promotion from tensor to scalar.");
return nullptr;
}
if constexpr (std::is_same<OpTy, AtenLtScalarOp>()) {
return createLessThan(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenLeScalarOp>()) {
return createLessThanOrEqual(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenGtScalarOp>()) {
return createGreaterThan(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenGeScalarOp>()) {
return createGreaterThanOrEqual(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenEqScalarOp>()) {
return createEqual(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenNeScalarOp>()) {
return createNotEqual(b, loc, elementalType, lhs, rhs);
}
llvm_unreachable("unimplemented: op type not supported");
}
template <arith::CmpIPredicate predicate>
static LogicalResult
createTriangularMatrix(OpBuilder &b, Location loc, ValueRange payloadArgs,
Operation *op, ArrayRef<Value> operands, Value &result) {
auto inputType = cast<RankedTensorType>(operands[0].getType());
uint64_t inputRank = inputType.getRank();
// Use the indices of the two innermost dimensions.
auto rowIndex = b.create<linalg::IndexOp>(loc, inputRank - 2);
Value rowIndexI64 = castIndexToInt64(b, loc, rowIndex);
auto colIndex = b.create<linalg::IndexOp>(loc, inputRank - 1);
Value colIndexI64 = castIndexToInt64(b, loc, colIndex);
// columnIndex >= rowIndex + diagonal?
auto sum =
b.create<arith::AddIOp>(loc, rowIndexI64, /*diagonal=*/operands[1]);
auto pred = b.create<arith::CmpIOp>(loc, predicate, colIndexI64, sum);
Value scalar = payloadArgs[0];
Type elementType = inputType.getElementType();
Value zero = getConstant(b, loc, 0, elementType);
result = b.create<arith::SelectOp>(loc, pred, scalar, zero);
return success();
}
template <typename OpT>
Value createDivModePayload(OpBuilder &b, Location loc,
const TypeConverter *converter,
ValueRange payloadArgs, OpT op,
ArrayRef<Value> operands) {
static_assert(std::is_same_v<OpT, AtenDivTensorModeOp> ||
std::is_same_v<OpT, AtenDivScalarModeOp>,
"template type must be a tensor/scalar div mode");
typename OpT::Adaptor adaptor(operands);
Type dtype = cast<RankedTensorType>(converter->convertType(op.getType()))
.getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(
b, loc,
std::is_same_v<OpT, AtenDivScalarModeOp> ? operands[1] : payloadArgs[1],
dtype);
Value quotient;
if (isa<mlir::FloatType>(dtype)) {
quotient = b.create<arith::DivFOp>(loc, lhs, rhs);
} else if (dtype.isUnsignedInteger()) {
quotient = b.create<arith::DivUIOp>(loc, lhs, rhs);
} else {
assert(dtype.isInteger() &&
"dtype should be an integer (signless or signed)");
quotient = b.create<arith::DivSIOp>(loc, lhs, rhs);
}
if (isa<Torch::NoneType>(op.getRoundingMode().getType()))
return quotient;
std::string roundingMode;
if (!matchPattern(op.getRoundingMode(), m_TorchConstantStr(roundingMode))) {
op.emitError("only support constant str rounding mode");
return nullptr;
}
assert((roundingMode == "trunc" || roundingMode == "floor") &&
"unsupported rounding mode");
if (roundingMode == "trunc") {
// "trunc" - rounds the results of the division towards zero. Equivalent
// to C-style integer division.
if (!isa<mlir::FloatType>(dtype)) {
// nothing to do for integers
return quotient;
}
// float
Value ceil = b.create<math::CeilOp>(loc, quotient);
Value floor = b.create<math::FloorOp>(loc, quotient);
Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
quotient, cstZero);
return b.create<arith::SelectOp>(loc, pred, ceil, floor);
}
if (roundingMode == "floor") {
// "floor" - rounds the results of the division down. Equivalent to
// floor division in Python (the // operator)
if (isa<mlir::FloatType>(dtype))
return b.create<math::FloorOp>(loc, quotient);
if (!dtype.isUnsignedInteger()) {
Type defaultIntToFloatType = b.getF64Type();
lhs = convertScalarToDtype(b, loc, lhs, defaultIntToFloatType);
rhs = convertScalarToDtype(b, loc, rhs, defaultIntToFloatType);
quotient = b.create<arith::DivFOp>(loc, lhs, rhs);
Value floor = b.create<math::FloorOp>(loc, quotient);
Value convert = convertScalarToDtype(b, loc, floor, dtype);
return convert;
}
}
return quotient;
}
static Value createLinalgPayloadCalculationForElementwiseOp(
OpBuilder &b, Location loc, const TypeConverter *converter,
ValueRange payloadArgs, Operation *op, ArrayRef<Value> operands) {
if (isa<AtenFloorOp>(op))
return b.create<math::FloorOp>(loc, payloadArgs[0]);
if (isa<AtenCeilOp>(op))
return b.create<math::CeilOp>(loc, payloadArgs[0]);
if (isa<AtenExpOp>(op)) {
return createFpOpWithDtype<math::ExpOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenExpm1Op>(op)) {
return createFpOpWithDtype<math::ExpM1Op>(b, converter, payloadArgs[0], op);
}
if (isa<AtenLogOp>(op)) {
return createFpOpWithDtype<math::LogOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenLog2Op>(op)) {
return createFpOpWithDtype<math::Log2Op>(b, converter, payloadArgs[0], op);
}
if (isa<AtenLog10Op>(op)) {
return createFpOpWithDtype<math::Log10Op>(b, converter, payloadArgs[0], op);
}
if (isa<AtenLog1pOp>(op)) {
return createFpOpWithDtype<math::Log1pOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenErfOp>(op)) {
return createFpOpWithDtype<math::ErfOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenSqrtOp>(op)) {
return createFpOpWithDtype<math::SqrtOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenRsqrtOp>(op)) {
return createFpOpWithDtype<math::RsqrtOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenNegOp>(op)) {
return createFpOpWithDtype<arith::NegFOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenSinOp>(op)) {
return createFpOpWithDtype<math::SinOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenSinhOp>(op)) {
return createFpOpWithDtype<math::SinhOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenAsinOp>(op)) {
return createFpOpWithDtype<math::AsinOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenAsinhOp>(op)) {
return createFpOpWithDtype<math::AsinhOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenCosOp>(op)) {
return createFpOpWithDtype<math::CosOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenCoshOp>(op)) {
return createFpOpWithDtype<math::CoshOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenAcosOp>(op)) {
return createFpOpWithDtype<math::AcosOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenAcoshOp>(op)) {
return createFpOpWithDtype<math::AcoshOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenTanOp>(op)) {
return createFpOpWithDtype<math::TanOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenTanhOp>(op)) {
return createFpOpWithDtype<math::TanhOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenAtanOp>(op)) {
return createFpOpWithDtype<math::AtanOp>(b, converter, payloadArgs[0], op);
}
if (isa<AtenAtanhOp>(op)) {
return createFpOpWithDtype<math::AtanhOp>(b, converter, payloadArgs[0], op);
}
if (auto clone = dyn_cast<AtenCloneOp>(op)) {
int64_t memoryFormat;
if (!isa<Torch::NoneType>(clone.getMemoryFormat().getType()) &&
(!matchPattern(clone.getMemoryFormat(),
m_TorchConstantInt(&memoryFormat)) ||
(memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
memoryFormat != torch_upstream::MemoryFormat::ChannelsLast))) {
clone.emitError("unimplemented: only contiguous and channels last memory "
"format is supported");
return nullptr;
}
return payloadArgs[0];
}
if (auto bitwiseAndTensor = dyn_cast<AtenBitwiseAndTensorOp>(op)) {
if (isa<mlir::FloatType>(
cast<ValueTensorType>(bitwiseAndTensor.getType()).getDtype())) {
bitwiseAndTensor.emitError(
"Bitwise_And does not support floating point dtype");
return nullptr;
}
Type dtype = cast<RankedTensorType>(
converter->convertType(bitwiseAndTensor.getType()))
.getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
return b.create<arith::AndIOp>(loc, lhs, rhs);
}
if (auto bitwiseAndScalar = dyn_cast<AtenBitwiseAndScalarOp>(op)) {
Type dtype = cast<RankedTensorType>(
converter->convertType(bitwiseAndScalar.getType()))
.getElementType();
if (!isa<mlir::IntegerType>(dtype)) {
bitwiseAndScalar.emitError(
"bitwise_and.Scalar does not support non-integer input dtype.");
return nullptr;
}
Type resultElementType =
cast<BaseTensorType>(bitwiseAndScalar.getType()).getDtype();
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype,
/*srcOriginalDtype=*/std::nullopt,
/*dstOriginalDtype=*/resultElementType);
Value other = convertScalarToDtype(b, loc, operands[1], dtype,
/*srcOriginalDtype=*/std::nullopt,
/*dstOriginalDtype=*/resultElementType);
return b.create<arith::AndIOp>(loc, self, other);
}
if (auto bitwiseOrTensor = dyn_cast<AtenBitwiseOrTensorOp>(op)) {
if (isa<mlir::FloatType>(
cast<ValueTensorType>(bitwiseOrTensor.getType()).getDtype())) {
bitwiseOrTensor.emitError(
"Bitwise_Or does not support floating point dtype");
return nullptr;
}
Type dtype = cast<RankedTensorType>(
converter->convertType(bitwiseOrTensor.getType()))
.getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
return b.create<arith::OrIOp>(loc, lhs, rhs);
}
if (auto bitwiseXorTensor = dyn_cast<AtenBitwiseXorTensorOp>(op)) {
if (isa<mlir::FloatType>(
cast<ValueTensorType>(bitwiseXorTensor.getType()).getDtype())) {
bitwiseXorTensor.emitError(
"Bitwise_Xor does not support floating point dtype");
return nullptr;
}
Type dtype = cast<RankedTensorType>(
converter->convertType(bitwiseXorTensor.getType()))
.getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
return b.create<arith::XOrIOp>(loc, lhs, rhs);
}
if (auto bitwiseRightShiftTensor =
dyn_cast<AtenBitwiseRightShiftTensorOp>(op)) {
Type dtype = cast<RankedTensorType>(
converter->convertType(bitwiseRightShiftTensor.getType()))
.getElementType();
if (!isa<mlir::IntegerType>(dtype)) {
bitwiseRightShiftTensor.emitError(
"Bitwise_Right_Shift op does not support non-integer input dtype.");
return nullptr;
}
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
return b.create<arith::ShRSIOp>(loc, lhs, rhs);
}
if (auto bitwiseLeftShiftTensor =
dyn_cast<AtenBitwiseLeftShiftTensorOp>(op)) {
Type dtype = cast<RankedTensorType>(
converter->convertType(bitwiseLeftShiftTensor.getType()))
.getElementType();
if (!isa<mlir::IntegerType>(dtype)) {
bitwiseLeftShiftTensor.emitError(
"Bitwise_Left_Shift op does not support non-integer input dtype.");
return nullptr;
}
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
return b.create<arith::ShLIOp>(loc, lhs, rhs);
}
if (isa<AtenLogicalOrOp, AtenLogicalAndOp, AtenLogicalXorOp>(op)) {
MLIRContext *context = op->getContext();
Type floatDtype = mlir::FloatType::getF64(context);
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], floatDtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], floatDtype);
Value zero =
b.create<arith::ConstantOp>(loc, b.getFloatAttr(floatDtype, 0));
Value lhsTest = createNotEqual(b, loc, floatDtype, lhs, zero);
Value rhsTest = createNotEqual(b, loc, floatDtype, rhs, zero);
if (isa<AtenLogicalOrOp>(op)) {
return b.create<arith::OrIOp>(loc, lhsTest, rhsTest);
}
if (isa<AtenLogicalAndOp>(op)) {
return b.create<arith::AndIOp>(loc, lhsTest, rhsTest);
}
if (isa<AtenLogicalXorOp>(op)) {
return b.create<arith::XOrIOp>(loc, lhsTest, rhsTest);
}
llvm_unreachable("Unknown op type");
}
if (isa<AtenLogicalNotOp>(op)) {
MLIRContext *context = op->getContext();
Type floatDtype = mlir::FloatType::getF64(context);
Value self = convertScalarToDtype(b, loc, payloadArgs[0], floatDtype);
Value zero =
b.create<arith::ConstantOp>(loc, b.getFloatAttr(floatDtype, 0));
return createEqual(b, loc, floatDtype, self, zero);
}
if (isa<AtenAbsOp>(op)) {
if (isa<IntegerType>(payloadArgs[0].getType()))
return b.create<math::AbsIOp>(loc, payloadArgs[0]);
return b.create<math::AbsFOp>(loc, payloadArgs[0]);
}
if (isa<AtenIsinfOp>(op)) {
Value abs = b.create<math::AbsFOp>(loc, payloadArgs[0]);
Value infinity = b.create<arith::ConstantOp>(
loc,
b.getFloatAttr(abs.getType(), std::numeric_limits<double>::infinity()));
return createEqual(b, loc, abs.getType(), abs, infinity);
}
if (isa<AtenSigmoidOp>(op)) {
Type inTTy = cast<ValueTensorType>(op->getOperand(0).getType()).getDtype();
Type outTTy = cast<ValueTensorType>(op->getResult(0).getType()).getDtype();
Type outTy = cast<RankedTensorType>(
converter->convertType(op->getResult(0).getType()))
.getElementType();
Type computeTy = outTy;
if (isa<IntegerType>(computeTy))
computeTy = b.getF32Type();
Value arg = payloadArgs[0];
arg = convertScalarToDtype(b, loc, payloadArgs[0], computeTy, inTTy);
auto negate = b.create<arith::NegFOp>(loc, arg);
auto one =
b.create<arith::ConstantOp>(loc, FloatAttr::get(negate.getType(), 1));
auto exp = b.create<math::ExpOp>(loc, negate);
auto added = b.create<arith::AddFOp>(loc, exp, one);
auto div = b.create<arith::DivFOp>(loc, one, added);
return convertScalarToDtype(b, loc, div, outTy, std::nullopt, outTTy);
}
if (auto relu = dyn_cast<AtenReluOp>(op)) {
Value zeroPoint = getZeroPoint(relu.getSelf());
Value arg = payloadArgs[0];
auto intType = dyn_cast<mlir::IntegerType>(arg.getType());
if (zeroPoint && !intType) {
relu.emitError("unimplemented: non-integer quantized Relu.");
return nullptr;
}
auto reluTorchType = cast<ValueTensorType>(relu.getType());
bool isUnsigned =
torch_to_linalg::isUnsignedTorchType(reluTorchType.getDtype());
if (zeroPoint) {
int64_t zeroPointInt;
int64_t width = intType.getWidth();
assert(width < 64);
int64_t minForIntType = isUnsigned ? 0 : -(1 << (width - 1));
int64_t maxForIntType =
isUnsigned ? (1 << (width + 1)) - 1 : (1 << (width - 1)) - 1;
// check for constant zero point edge-cases:
if (matchPattern(zeroPoint, m_TorchConstantInt(&zeroPointInt))) {
if (zeroPointInt > maxForIntType) {
// TODO: figure out how to handle this case:
// current impl. quantizes output like input.
// If zero point > maxForIntType, ordinary relu should return 0.
// However, 0 isn't represented in such a quantization scheme.
relu.emitError(
"unimplemented: quantized relu for zero-point > max qint");
return nullptr;
}
if (zeroPointInt < minForIntType)
return arg;
}
zeroPoint = converter->materializeTargetConversion(
b, loc, converter->convertType(zeroPoint.getType()), zeroPoint);
auto minForIntTypeValue = b.create<arith::ConstantOp>(
loc, b.getIntegerAttr(zeroPoint.getType(), minForIntType));
auto maxForIntTypeValue = b.create<arith::ConstantOp>(
loc, b.getIntegerAttr(zeroPoint.getType(), maxForIntType));
auto zpLtMax = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
zeroPoint, maxForIntTypeValue);
b.create<cf::AssertOp>(
loc, zpLtMax,
b.getStringAttr("Invalid Quantization: quantized relu with "
"zero-point > max qint"));
auto zpLtMin = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
zeroPoint, minForIntTypeValue);
zeroPoint = b.create<arith::SelectOp>(loc, zpLtMin, minForIntTypeValue,
zeroPoint);
zeroPoint = b.create<arith::TruncIOp>(loc, arg.getType(), zeroPoint);
} else {
zeroPoint =
b.create<arith::ConstantOp>(loc, b.getZeroAttr(arg.getType()));
}
Value cmp;
if (intType) {
auto pred =
isUnsigned ? arith::CmpIPredicate::ugt : arith::CmpIPredicate::sgt;
cmp = b.create<arith::CmpIOp>(loc, pred, arg, zeroPoint);
} else {
cmp = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, arg,
zeroPoint);
}
return b.create<arith::SelectOp>(loc, cmp, arg, zeroPoint);
}
if (auto round = dyn_cast<AtenRoundOp>(op)) {
if (!isa<mlir::FloatType>(
cast<ValueTensorType>(round.getType()).getDtype())) {
round.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
return b.create<math::RoundEvenOp>(loc, payloadArgs[0]);
}
if (auto prelu = dyn_cast<AtenPreluOp>(op)) {
if (!isa<mlir::FloatType>(
cast<ValueTensorType>(prelu.getType()).getDtype())) {
prelu.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
Type elementType = payloadArgs[0].getType();
Value constZero =
b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
payloadArgs[0], constZero);
Value positivePart =
b.create<arith::SelectOp>(loc, pred, payloadArgs[0], constZero);
Value negativePart =
b.create<arith::SelectOp>(loc, pred, constZero, payloadArgs[0]);
Value scale = convertScalarToDtype(b, loc, payloadArgs[1], elementType);
Value scaledNegativePart =
b.create<arith::MulFOp>(loc, negativePart, scale);
return b.create<arith::AddFOp>(loc, positivePart, scaledNegativePart);
}
if (auto gelu = dyn_cast<AtenGeluOp>(op)) {
if (!isa<mlir::FloatType>(
cast<ValueTensorType>(gelu.getType()).getDtype())) {
gelu.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
// TODO: Take approximation into account.
std::string approximate;
if (!matchPattern(gelu.getApproximate(), m_TorchConstantStr(approximate))) {
gelu.emitError(
"unimplemented: expected approximate to be a constant str");
return nullptr;
}
if (approximate == "none") {
Value multiplier = buildUnitNormalCdf(b, loc, payloadArgs[0]);
return b.create<arith::MulFOp>(loc, payloadArgs[0], multiplier);
}
if (approximate == "tanh") {
// GELU(x)=0.5x(1+Tanh((2/π)^1/2 * (x+0.044715x^3)))
// Ref: https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
Value cstThree = b.create<arith::ConstantOp>(
loc, IntegerAttr::get(IntegerType::get(op->getContext(), 64), 3));
Value xCube = b.create<math::FPowIOp>(loc, payloadArgs[0], cstThree);
Type elementType = payloadArgs[0].getType();
Value cstAlpha = b.create<arith::ConstantOp>(
loc, FloatAttr::get(elementType, 0.044715));
Value xCubeMulAlpha = b.create<arith::MulFOp>(loc, xCube, cstAlpha);
Value xPlusXCubeMulAlpha =
b.create<arith::AddFOp>(loc, payloadArgs[0], xCubeMulAlpha);
Value cstBeta = b.create<arith::ConstantOp>(
loc, FloatAttr::get(elementType, 0.7977240352174656));
Value betaMulX =
b.create<arith::MulFOp>(loc, cstBeta, xPlusXCubeMulAlpha);
Value tanh = b.create<math::TanhOp>(loc, betaMulX);
Value cstOne =
b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1.0));
Value onePlusTanh = b.create<arith::AddFOp>(loc, cstOne, tanh);
Value cstHalf =
b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.5));
Value multiplier = b.create<arith::MulFOp>(loc, cstHalf, onePlusTanh);
return b.create<arith::MulFOp>(loc, payloadArgs[0], multiplier);
}
gelu.emitError("unimplemented: approximate value should be none or tanh");
return nullptr;
}
if (auto geluBackward = dyn_cast<AtenGeluBackwardOp>(op)) {
if (!isa<mlir::FloatType>(
cast<ValueTensorType>(geluBackward.getType()).getDtype())) {
geluBackward.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
// TODO: Take approximation into account.
std::string approximate;
if (!matchPattern(geluBackward.getApproximate(),
m_TorchConstantStr(approximate)) ||
approximate != "none")
return nullptr;
Type elementType = payloadArgs[1].getType();
Value cstAlpha0 = b.create<arith::ConstantOp>(
loc, FloatAttr::get(elementType, 1.12837916709551257390));
Value cstAlpha1 = b.create<arith::ConstantOp>(
loc, FloatAttr::get(elementType, 0.70710678118654752440));
Value oneHalf =
b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.5));
Value kAlpha = b.create<arith::MulFOp>(loc, cstAlpha0, cstAlpha1);
Value kAlphaHalf = b.create<arith::MulFOp>(loc, kAlpha, oneHalf);
Value negOneHalf =
b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, -0.5));
Value inputSquared =
b.create<arith::MulFOp>(loc, payloadArgs[1], payloadArgs[1]);
Value negHalfInputSquared =
b.create<arith::MulFOp>(loc, inputSquared, negOneHalf);
Value dinput = b.create<math::ExpOp>(loc, negHalfInputSquared);
Value cdf = buildUnitNormalCdf(b, loc, payloadArgs[1]);
Value dinputInput = b.create<arith::MulFOp>(loc, dinput, payloadArgs[1]);
Value dinputInputAlpha =
b.create<arith::MulFOp>(loc, dinputInput, kAlphaHalf);
Value cdfExt = b.create<arith::AddFOp>(loc, dinputInputAlpha, cdf);
return b.create<arith::MulFOp>(loc, payloadArgs[0], cdfExt);
}
if (auto hardtanhBackward = dyn_cast<AtenHardtanhBackwardOp>(op)) {
AtenHardtanhBackwardOp::Adaptor adaptor(operands);
if (!isa<mlir::FloatType>(
cast<ValueTensorType>(hardtanhBackward.getType()).getDtype())) {
hardtanhBackward.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
Value gradOutput = payloadArgs[0];
Type elementType = gradOutput.getType();
Value self = convertScalarToDtype(b, loc, payloadArgs[1], elementType);
Value constantZero =
b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.0));
Value min = convertScalarToDtype(b, loc, adaptor.getMinVal(), elementType);
Value max = convertScalarToDtype(b, loc, adaptor.getMaxVal(), elementType);
Value lesser =
b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT, self, min);
Value greater =
b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, self, max);
Value cmp = b.create<arith::OrIOp>(loc, lesser, greater);
return b.create<arith::SelectOp>(loc, cmp, constantZero, gradOutput);
}
if (auto add = dyn_cast<AtenAddTensorOp>(op)) {
AtenAddTensorOp::Adaptor adaptor(operands);
Type resultElementType = cast<BaseTensorType>(add.getType()).getDtype();
Type dtype = cast<RankedTensorType>(converter->convertType(add.getType()))
.getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype,
/*srcOriginalDtype=*/std::nullopt,
/*dstOriginalDtype=*/resultElementType);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype,
/*srcOriginalDtype=*/std::nullopt,
/*dstOriginalDtype=*/resultElementType);
Value alpha = convertScalarToDtype(b, loc, adaptor.getAlpha(), dtype,
/*srcOriginalDtype=*/std::nullopt,
/*dstOriginalDtype=*/resultElementType);
if (isa<mlir::FloatType>(dtype)) {
Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha);
return b.create<arith::AddFOp>(loc, lhs, scaled);
} else {
Value scaled = b.create<arith::MulIOp>(loc, rhs, alpha);
return b.create<arith::AddIOp>(loc, lhs, scaled);
}
}
if (auto sub = dyn_cast<AtenSubTensorOp>(op)) {
AtenSubTensorOp::Adaptor adaptor(operands);
Type dtype = cast<RankedTensorType>(converter->convertType(sub.getType()))
.getElementType();
Type resultElementType = cast<BaseTensorType>(sub.getType()).getDtype();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype,
/*srcOriginalDtype=*/std::nullopt,
/*dstOriginalDtype=*/resultElementType);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype,
/*srcOriginalDtype=*/std::nullopt,
/*dstOriginalDtype=*/resultElementType);
Value alpha = convertScalarToDtype(b, loc, adaptor.getAlpha(), dtype,
/*srcOriginalDtype=*/std::nullopt,
/*dstOriginalDtype=*/resultElementType,
/*originalScalar=*/sub.getAlpha());
if (isa<mlir::FloatType>(dtype)) {
Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha);
return b.create<arith::SubFOp>(loc, lhs, scaled);
} else {
Value scaled = b.create<arith::MulIOp>(loc, rhs, alpha);
return b.create<arith::SubIOp>(loc, lhs, scaled);
}
}
if (auto subScalar = dyn_cast<AtenSubScalarOp>(op)) {
Type dtype =
cast<RankedTensorType>(converter->convertType(subScalar.getType()))
.getElementType();
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
Value alpha = convertScalarToDtype(
b, loc, operands[2], dtype, /*srcOriginalDtype=*/operands[2].getType(),
/*dstOriginalDtype=*/dtype);
if (isa<mlir::FloatType>(dtype)) {
Value mult = b.create<arith::MulFOp>(loc, other, alpha);
return b.create<arith::SubFOp>(loc, self, mult);
} else if (isa<mlir::IntegerType>(dtype)) {
Value mult = b.create<arith::MulIOp>(loc, other, alpha);
return b.create<arith::SubIOp>(loc, self, mult);
}
subScalar.emitError("unimplemented: dtype other than float and integer "
"types are not supported.");
return nullptr;
}
if (auto addScalar = dyn_cast<AtenAddScalarOp>(op)) {
Type dtype =
cast<RankedTensorType>(converter->convertType(addScalar.getType()))
.getElementType();
Type resultElementType =
cast<BaseTensorType>(addScalar.getType()).getDtype();
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype,
/*srcOriginalDtype=*/std::nullopt,
/*dstOriginalDtype=*/resultElementType);
Value other = convertScalarToDtype(b, loc, operands[1], dtype,
/*srcOriginalDtype=*/std::nullopt,
/*dstOriginalDtype=*/resultElementType);
Value alpha = convertScalarToDtype(b, loc, operands[2], dtype,
/*srcOriginalDtype=*/std::nullopt,
/*dstOriginalDtype=*/resultElementType);
if (isa<mlir::FloatType>(dtype)) {
Value mult = b.create<arith::MulFOp>(loc, other, alpha);
return b.create<arith::AddFOp>(loc, self, mult);
} else if (isa<mlir::IntegerType>(dtype)) {
Value mult = b.create<arith::MulIOp>(loc, other, alpha);
return b.create<arith::AddIOp>(loc, self, mult);
}
addScalar.emitError("unimplemented: dtype other than float and integer "
"types are not supported.");
return nullptr;
}
if (auto mul = dyn_cast<AtenMulTensorOp>(op)) {
AtenMulTensorOp::Adaptor adaptor(operands);
Type dtype = cast<RankedTensorType>(converter->convertType(mul.getType()))
.getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
if (isa<mlir::FloatType>(dtype)) {
return b.create<arith::MulFOp>(loc, lhs, rhs);
} else if (isa<mlir::ComplexType>(dtype)) {
return b.create<complex::MulOp>(loc, lhs, rhs);
} else {
return b.create<arith::MulIOp>(loc, lhs, rhs);
}
}
if (auto atan2 = dyn_cast<AtenAtan2Op>(op)) {
Type dtype = cast<RankedTensorType>(converter->convertType(atan2.getType()))
.getElementType();
if (!isa<mlir::FloatType>(dtype)) {
atan2.emitError("Atan2 requires floating point result type");
return nullptr;
}
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
return b.create<math::Atan2Op>(loc, lhs, rhs);
}
if (auto ltTensor = dyn_cast<AtenLtTensorOp>(op)) {
return createCompareTensorOp(b, loc, ltTensor, payloadArgs[0],
payloadArgs[1]);
}
if (auto leTensor = dyn_cast<AtenLeTensorOp>(op)) {
return createCompareTensorOp(b, loc, leTensor, payloadArgs[0],
payloadArgs[1]);
}
if (auto gtTensor = dyn_cast<AtenGtTensorOp>(op)) {
return createCompareTensorOp(b, loc, gtTensor, payloadArgs[0],
payloadArgs[1]);
}
if (auto geTensor = dyn_cast<AtenGeTensorOp>(op)) {
return createCompareTensorOp(b, loc, geTensor, payloadArgs[0],
payloadArgs[1]);
}
if (auto eqTensor = dyn_cast<AtenEqTensorOp>(op)) {
return createCompareTensorOp(b, loc, eqTensor, payloadArgs[0],
payloadArgs[1]);
}
if (auto neTensor = dyn_cast<AtenNeTensorOp>(op)) {
return createCompareTensorOp(b, loc, neTensor, payloadArgs[0],
payloadArgs[1]);
}
if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
AtenDivTensorOp::Adaptor adaptor(operands);
Type dtype = cast<RankedTensorType>(converter->convertType(div.getType()))
.getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
if (isa<mlir::FloatType>(dtype))
return b.create<arith::DivFOp>(loc, lhs, rhs);
else if (isa<mlir::IntegerType>(dtype)) {
if (dtype.isUnsignedInteger())
return b.create<arith::DivUIOp>(loc, lhs, rhs);
return b.create<arith::DivSIOp>(loc, lhs, rhs);
}
div.emitError("unimplemented: non-floating point and non-integer dtype");
return nullptr;
}
if (auto divScalarMode = dyn_cast<AtenDivScalarModeOp>(op)) {
return createDivModePayload(b, loc, converter, payloadArgs, divScalarMode,
operands);
}
if (auto divTensorMode = dyn_cast<AtenDivTensorModeOp>(op)) {
return createDivModePayload(b, loc, converter, payloadArgs, divTensorMode,
operands);
}
if (auto pow = dyn_cast<AtenPowScalarOp>(op)) {
Type dtype = cast<ValueTensorType>(pow.getType()).getDtype();
if (!isa<mlir::FloatType>(dtype)) {
pow.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
Value selfPromoted = convertScalarToDtype(b, loc, operands[0], dtype);
Value expPromoted = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
return b.create<math::PowFOp>(loc, selfPromoted, expPromoted);
}
if (auto pow = dyn_cast<AtenPowTensorScalarOp>(op)) {
if (!isa<mlir::FloatType>(
cast<ValueTensorType>(pow.getType()).getDtype())) {
pow.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
Type dtype = cast<ValueTensorType>(pow.getSelf().getType()).getDtype();
Value expPromoted = convertScalarToDtype(b, loc, operands[1], dtype);
return b.create<math::PowFOp>(loc, payloadArgs[0], expPromoted);
}
if (auto pow = dyn_cast<AtenPowTensorTensorOp>(op)) {
Type dtype = cast<RankedTensorType>(converter->convertType(pow.getType()))
.getElementType();
if (!isa<mlir::FloatType>(dtype)) {
pow.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
return b.create<math::PowFOp>(loc, lhs, rhs);
}
if (auto imag = dyn_cast<AtenImagOp>(op)) {
Type dtype = cast<RankedTensorType>(converter->convertType(imag.getType()))
.getElementType();
if (!isa<mlir::FloatType>(dtype)) {
imag.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
Value imagVal = b.create<complex::ImOp>(loc, payloadArgs[0]);
return imagVal;
}
if (auto real = dyn_cast<AtenRealOp>(op)) {
Type dtype = cast<RankedTensorType>(converter->convertType(real.getType()))
.getElementType();
if (!isa<mlir::FloatType>(dtype)) {
real.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
Value realVal = b.create<complex::ReOp>(loc, payloadArgs[0]);
return realVal;
}
if (auto gtScalar = dyn_cast<AtenGtScalarOp>(op)) {
return createCompareScalarOp(b, loc, gtScalar, payloadArgs[0], operands[1]);
}
if (auto geScalar = dyn_cast<AtenGeScalarOp>(op)) {
return createCompareScalarOp(b, loc, geScalar, payloadArgs[0], operands[1]);
}
if (auto eqScalar = dyn_cast<AtenEqScalarOp>(op)) {
return createCompareScalarOp(b, loc, eqScalar, payloadArgs[0], operands[1]);
}
if (auto neScalar = dyn_cast<AtenNeScalarOp>(op)) {
return createCompareScalarOp(b, loc, neScalar, payloadArgs[0], operands[1]);
}
if (auto ltScalar = dyn_cast<AtenLtScalarOp>(op)) {
return createCompareScalarOp(b, loc, ltScalar, payloadArgs[0], operands[1]);
}
if (auto leScalar = dyn_cast<AtenLeScalarOp>(op)) {
return createCompareScalarOp(b, loc, leScalar, payloadArgs[0], operands[1]);
}
if (auto whereSelf = dyn_cast<AtenWhereSelfOp>(op)) {
Type dtype =
cast<RankedTensorType>(converter->convertType(whereSelf.getType()))
.getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[2], dtype);
return b.create<arith::SelectOp>(loc, payloadArgs[0], lhs, rhs);
}
if (auto lerp = dyn_cast<AtenLerpTensorOp>(op)) {
if (!isa<mlir::FloatType>(
cast<ValueTensorType>(lerp.getType()).getDtype())) {
lerp.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
AtenLerpTensorOp::Adaptor adaptor(payloadArgs);
auto start = adaptor.getSelf();
auto end = adaptor.getEnd();
auto weight = adaptor.getWeight();
auto delta = b.create<arith::SubFOp>(loc, end, start);
auto weightedDelta = b.create<arith::MulFOp>(loc, delta, weight);
return b.create<arith::AddFOp>(loc, start, weightedDelta);
}
if (auto minimum = dyn_cast<AtenMinimumOp>(op)) {
Type dtype = cast<BaseTensorType>(minimum.getType()).getDtype();
Type elemTy =
cast<RankedTensorType>(converter->convertType(minimum.getType()))
.getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], elemTy);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], elemTy);
Value pred = createLessThan(b, loc, dtype, lhs, rhs);
return b.create<arith::SelectOp>(loc, pred, lhs, rhs);
}
if (auto maximum = dyn_cast<AtenMaximumOp>(op)) {
Type dtype = cast<BaseTensorType>(maximum.getType()).getDtype();
Type elemTy =
cast<RankedTensorType>(converter->convertType(maximum.getType()))
.getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], elemTy);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], elemTy);
Value pred = createGreaterThan(b, loc, dtype, lhs, rhs);
return b.create<arith::SelectOp>(loc, pred, lhs, rhs);
}
if (auto clamp = dyn_cast<AtenClampOp>(op)) {
AtenClampOp::Adaptor adaptor(operands);
auto min = adaptor.getMin();
auto max = adaptor.getMax();
if (isa<Torch::OptionalType>(min.getType()) ||
isa<Torch::OptionalType>(max.getType())) {
clamp.emitError("unimplemented: runtime optional type");
return nullptr;
}
Type dtype = cast<RankedTensorType>(converter->convertType(clamp.getType()))
.getElementType();
if (!isa<mlir::FloatType, mlir::IntegerType>(dtype)) {
clamp.emitError("unimplement type for clamp");
return nullptr;
}
Type dstOriginalDtype = cast<BaseTensorType>(clamp.getType()).getDtype();
bool isUnsigned = isa<QUInt8Type>(dstOriginalDtype);
if (auto intTy = dyn_cast<IntegerType>(dstOriginalDtype)) {
isUnsigned = intTy.isUnsigned();
}
auto cmpSelect = [&](Value input, Value clamp, bool getMax) -> Value {
clamp = convertScalarToDtype(b, loc, clamp, dtype,
/*srcOriginalDtype=*/std::nullopt,
/*dstOriginalDtype=*/dstOriginalDtype);
Value pred;
if (isa<mlir::FloatType>(dtype)) {
auto cmp =
getMax ? arith::CmpFPredicate::UGT : arith::CmpFPredicate::ULT;
pred = b.create<arith::CmpFOp>(loc, cmp, input, clamp);
} else if (isa<mlir::IntegerType>(dtype)) {
auto cmp =
isUnsigned ? arith::CmpIPredicate::ult : arith::CmpIPredicate::slt;
if (getMax)
cmp = arith::invertPredicate(cmp);
pred = b.create<arith::CmpIOp>(loc, cmp, input, clamp);
}
return b.create<arith::SelectOp>(loc, pred, clamp, input);
};
auto result = payloadArgs[0];
if (!isa<Torch::NoneType>(min.getType()))
result = cmpSelect(result, min, /*getMax=*/false);
if (!isa<Torch::NoneType>(max.getType()))
result = cmpSelect(result, max, /*getMax=*/true);
return result;
}
if (auto clampTensor = dyn_cast<AtenClampTensorOp>(op)) {
AtenClampTensorOp::Adaptor adaptor(operands);
auto min = adaptor.getMin();
auto max = adaptor.getMax();
if (isa<Torch::OptionalType>(min.getType()) ||
isa<Torch::OptionalType>(max.getType())) {
clampTensor.emitError("unimplemented: runtime optional type");
return nullptr;
}
Type dtype =
cast<RankedTensorType>(converter->convertType(clampTensor.getType()))
.getElementType();
bool isMinNone = true;
auto result = payloadArgs[0];
if (!isa<Torch::NoneType>(min.getType())) {
isMinNone = false;
auto minPromoted = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
Value pred;
if (isa<mlir::FloatType>(dtype)) {
pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT, result,
minPromoted);
} else if (isa<mlir::IntegerType>(dtype)) {
pred = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, result,
minPromoted);
} else {
clampTensor.emitError(
"unimplemented: dtype other than float and integer "
"types are not supported.");
return nullptr;
}
result = b.create<arith::SelectOp>(loc, pred, minPromoted, result);
}
if (!isa<Torch::NoneType>(max.getType())) {
max = isMinNone ? payloadArgs[1] : payloadArgs[2];
auto maxPromoted = convertScalarToDtype(b, loc, max, dtype);
Value pred;
if (isa<mlir::FloatType>(dtype)) {
pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, result,
maxPromoted);
} else if (isa<mlir::IntegerType>(dtype)) {
pred = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, result,
maxPromoted);
} else {
clampTensor.emitError(
"unimplemented: dtype other than float and integer "
"types are not supported.");
return nullptr;
}
result = b.create<arith::SelectOp>(loc, pred, maxPromoted, result);
}
return result;
}
if (auto rsub = dyn_cast<AtenRsubScalarOp>(op)) {
Type dtype = cast<RankedTensorType>(converter->convertType(rsub.getType()))
.getElementType();
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
Value alpha = convertScalarToDtype(
b, loc, operands[2], dtype, /*srcOriginalDtype=*/operands[2].getType(),
/*dstOriginalDtype=*/dtype);
if (isa<mlir::FloatType>(dtype)) {
Value mult = b.create<arith::MulFOp>(loc, self, alpha);
return b.create<arith::SubFOp>(loc, other, mult);
} else if (isa<mlir::IntegerType>(dtype)) {
Value mult = b.create<arith::MulIOp>(loc, self, alpha);
return b.create<arith::SubIOp>(loc, other, mult);
}
rsub.emitError("unimplemented: dtype other than float and integer "
"types are not supported.");
return nullptr;
}
if (auto mulScalar = dyn_cast<AtenMulScalarOp>(op)) {
Type dtype =
cast<RankedTensorType>(converter->convertType(mulScalar.getType()))
.getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, operands[1], dtype);
if (isa<mlir::FloatType>(dtype))
return b.create<arith::MulFOp>(loc, lhs, rhs);
if (isa<mlir::IntegerType>(dtype))
return b.create<arith::MulIOp>(loc, lhs, rhs);
mulScalar.emitError("unimplemented: Only integer/float dtype supported");
return nullptr;
}
if (auto atenToDtype = dyn_cast<AtenToDtypeOp>(op)) {
Value input = payloadArgs[0];
Type inputElementType =
cast<BaseTensorType>(atenToDtype.getSelf().getType()).getDtype();
Type dtype =
cast<RankedTensorType>(converter->convertType(atenToDtype.getType()))
.getElementType();
Type resultElementType;
int64_t dtypeInt;
if (!matchPattern(atenToDtype.getDtype(), m_TorchConstantInt(&dtypeInt))) {
atenToDtype.emitError("unimplemented: dtype must be a constant integer");
return nullptr;
}
FailureOr<Type> maybeResultElementType =
torch_to_linalg::getBackendTypeForScalarType(
atenToDtype->getContext(), (torch_upstream::ScalarType)dtypeInt);
if (failed(maybeResultElementType)) {
atenToDtype.emitError("unable to convert `dtypeInt` to builtin type");
return nullptr;
}
resultElementType = *maybeResultElementType;
Value result = convertScalarToDtype(b, loc, input, dtype,
/*srcOriginalDtype=*/inputElementType,
/*dstOriginalDtype=*/resultElementType);
return result;
}
if (auto divScalar = dyn_cast<AtenDivScalarOp>(op)) {
Type dtype =
cast<RankedTensorType>(converter->convertType(divScalar.getType()))
.getElementType();
if (!isa<mlir::FloatType>(dtype)) {
divScalar.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
return b.create<arith::DivFOp>(loc, self, other);
}
if (auto remScalar = dyn_cast<AtenRemainderScalarOp>(op)) {
Type newResultType =
cast<RankedTensorType>(converter->convertType(remScalar.getType()))
.getElementType();
Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType);
Value other = convertScalarToDtype(b, loc, operands[1], newResultType);
Value result;
if (isa<mlir::FloatType>(newResultType)) {
result = b.create<arith::RemFOp>(loc, self, other);
} else if (isa<mlir::IntegerType>(newResultType)) {
result = b.create<arith::RemSIOp>(loc, self, other);
} else {
remScalar.emitError(
"Unsupported type encountered for AtenRemainderScalarOp.");
}
return result;
}
if (auto remTensor = dyn_cast<AtenRemainderTensorOp>(op)) {
Type newResultType =
cast<RankedTensorType>(converter->convertType(remTensor.getType()))
.getElementType();
Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType);
Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType);
Value result;
if (isa<mlir::FloatType>(newResultType)) {
result = b.create<arith::RemFOp>(loc, self, other);
} else if (isa<mlir::IntegerType>(newResultType)) {
result = b.create<arith::RemSIOp>(loc, self, other);
} else {
remTensor.emitError(
"Unsupported type encountered for AtenRemainderTensorOp.");
}
return result;
}
if (auto fmod = dyn_cast<AtenFmodTensorOp>(op)) {
Type newResultType =
cast<RankedTensorType>(converter->convertType(fmod.getType()))
.getElementType();
Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType);
Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType);
Value result;
if (isa<mlir::FloatType>(newResultType)) {
Value n = b.create<arith::DivFOp>(loc, self, other);
n = b.create<math::TruncOp>(loc, n);
Value n_y = b.create<arith::MulFOp>(loc, n, other);
result = b.create<arith::SubFOp>(loc, self, n_y);
} else if (isa<mlir::IntegerType>(newResultType)) {
Value n = b.create<arith::DivSIOp>(loc, self, other);
Value n_y = b.create<arith::MulIOp>(loc, n, other);
result = b.create<arith::SubIOp>(loc, self, n_y);
} else {
fmod.emitError("Unsupported type encountered for AtenFmodTensorOp.");
}
return result;
}
if (auto reciprocal = dyn_cast<AtenReciprocalOp>(op)) {
Type dtype =
cast<RankedTensorType>(converter->convertType(reciprocal.getType()))
.getElementType();
Value arg = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Type elementType = arg.getType();
// assert(element != 0)
auto zero =
b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.0));
auto pred =
b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ONE, arg, zero);
b.create<cf::AssertOp>(
loc, pred, b.getStringAttr("unimplemented: tensor with zero element"));
auto one =
b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1.0));
return b.create<arith::DivFOp>(loc, one, arg);
}
if (auto thresholdOp = dyn_cast<AtenThresholdOp>(op)) {
// The approach used here is as follows:
// result = self <= threshold ? value : self
AtenThresholdOp::Adaptor adaptor(operands);
Type dtype =
cast<RankedTensorType>(converter->convertType(thresholdOp.getType()))
.getElementType();
Value self = payloadArgs[0];
Value threshold =
convertScalarToDtype(b, loc, adaptor.getThreshold(), dtype);
Value value = convertScalarToDtype(b, loc, adaptor.getValue(), dtype);
Value predicate;
if (isa<mlir::FloatType>(dtype))
predicate = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, self,
threshold);
else
predicate = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle, self,
threshold);
return b.create<arith::SelectOp>(loc, predicate, value, self);
}
if (auto thresholdBackward = dyn_cast<AtenThresholdBackwardOp>(op)) {
// The approach used here is as follows:
// result = self <= threshold ? 0 : grad
AtenThresholdBackwardOp::Adaptor adaptor(operands);
Type dtype = cast<RankedTensorType>(
converter->convertType(thresholdBackward.getType()))
.getElementType();
Value grad = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value self = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
Value threshold =
convertScalarToDtype(b, loc, adaptor.getThreshold(), dtype);
Value constantZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
Value predicate;
if (isa<mlir::FloatType>(dtype))
predicate = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, self,
threshold);
else
predicate = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle, self,
threshold);
return b.create<arith::SelectOp>(loc, predicate, constantZero, grad);
}
if (auto fillScalar = dyn_cast<AtenFillScalarOp>(op)) {
AtenFillScalarOp::Adaptor adaptor(operands);
Type dtype =
cast<RankedTensorType>(converter->convertType(fillScalar.getType()))
.getElementType();
return convertScalarToDtype(b, loc, adaptor.getValue(), dtype);
}
if (auto maskedFillTensor = dyn_cast<AtenMaskedFillTensorOp>(op)) {
AtenMaskedFillScalarOp::Adaptor adaptor(operands);
Type dtype = cast<RankedTensorType>(
converter->convertType(maskedFillTensor.getType()))
.getElementType();
Value input = payloadArgs[0];
Value mask = payloadArgs[1];
Value fillValue = convertScalarToDtype(b, loc, payloadArgs[2], dtype);
return b.create<arith::SelectOp>(loc, mask, fillValue, input);
}
if (auto fillTensor = dyn_cast<AtenFillTensorOp>(op)) {
AtenFillTensorOp::Adaptor adaptor(operands);
Type dtype =
cast<RankedTensorType>(converter->convertType(fillTensor.getType()))
.getElementType();
return convertScalarToDtype(b, loc, payloadArgs[1], dtype);
}
if (auto triu = dyn_cast<AtenTriuOp>(op)) {
Value result;
if (failed(createTriangularMatrix<arith::CmpIPredicate::sge>(
b, loc, payloadArgs, op, operands, result)))
return nullptr;
return result;
}
if (auto tril = dyn_cast<AtenTrilOp>(op)) {
Value result;
if (failed(createTriangularMatrix<arith::CmpIPredicate::sle>(
b, loc, payloadArgs, op, operands, result)))
return nullptr;
return result;
}
if (auto bitwiseNot = dyn_cast<AtenBitwiseNotOp>(op)) {
Type elementType =
cast<RankedTensorType>(converter->convertType(bitwiseNot.getType()))
.getElementType();
if (isa<mlir::FloatType>(elementType)) {
bitwiseNot.emitError("Bitwise_Not does not support floating point dtype");
return nullptr;
}
Value allOnesVal = b.create<arith::ConstantOp>(
loc, b.getIntegerAttr(
elementType,
APSInt::getAllOnes(elementType.getIntOrFloatBitWidth())));
return b.create<arith::XOrIOp>(loc, payloadArgs[0], allOnesVal);
}
if (isa<AtenDequantizeTensorOp, AtenDequantizeSelfOp>(op)) {
auto value = payloadArgs[0];
auto valueTy = value.getType();
auto qtensor = op->getOperand(0);
auto qtensorTy = cast<ValueTensorType>(qtensor.getType()).getDtype();
Value zp, scale;
if (auto makeQTensor =
qtensor.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
zp = makeQTensor.getZeroPoint();
scale = makeQTensor.getScale();
}
if (auto quant = qtensor.getDefiningOp<AtenQuantizePerTensorOp>()) {
zp = quant.getZeroPoint();
scale = quant.getScale();
}
if (!zp || !scale) {
return nullptr;
}
auto outFpTy = payloadArgs[1].getType();
auto outBw = outFpTy.getIntOrFloatBitWidth();
auto outIntTy = b.getIntegerType(outBw);
if (valueTy != outIntTy) {
if (torch_to_linalg::isUnsignedTorchType(qtensorTy)) {
value = b.create<arith::ExtUIOp>(loc, outIntTy, value);
} else {
value = b.create<arith::ExtSIOp>(loc, outIntTy, value);
}
}
zp = converter->materializeTargetConversion(
b, loc, converter->convertType(zp.getType()), zp);
auto zpTy = zp.getType();
if (zpTy != outIntTy) {
zp = b.create<arith::TruncIOp>(loc, outIntTy, zp);
}
value = b.create<arith::SubIOp>(loc, value, zp);
// treat the i32 as a signed int regardless of original signed-ness
// this will prevent overflow from subtraction for unsigned quantizations.
value = b.create<arith::SIToFPOp>(loc, outFpTy, value);
scale = converter->materializeTargetConversion(
b, loc, converter->convertType(scale.getType()), scale);
if (scale.getType() != value.getType()) {
scale = b.create<arith::TruncFOp>(loc, value.getType(), scale);
}
value = b.create<arith::MulFOp>(loc, value, scale);
return value;
}
if (auto quant = dyn_cast<AtenQuantizePerTensorOp>(op)) {
Value value = payloadArgs[0];
Value scale = quant.getScale();
Value zp = quant.getZeroPoint();
auto valueTy = value.getType();
zp = converter->materializeTargetConversion(
b, loc, converter->convertType(zp.getType()), zp);
zp = b.create<arith::SIToFPOp>(loc, valueTy, zp);
scale = converter->materializeTargetConversion(
b, loc, converter->convertType(scale.getType()), scale);
scale = b.create<arith::TruncFOp>(loc, valueTy, scale);
value = b.create<arith::DivFOp>(loc, value, scale);
value = b.create<math::RoundOp>(loc, value);
value = b.create<arith::AddFOp>(loc, value, zp);
auto destTy = payloadArgs[1].getType();
auto bitwidth = destTy.getIntOrFloatBitWidth();
bool isUnsigned = torch_to_linalg::isUnsignedTorchType(quant.getType());
APInt min = isUnsigned ? APInt::getMinValue(bitwidth)
: APInt::getSignedMinValue(bitwidth);
APInt max = isUnsigned ? APInt::getMaxValue(bitwidth)
: APInt::getSignedMaxValue(bitwidth);
double minI = isUnsigned ? static_cast<double>(min.getZExtValue())
: static_cast<double>(min.getSExtValue());
double maxI = isUnsigned ? static_cast<double>(max.getZExtValue())
: static_cast<double>(max.getSExtValue());
Value minVal =
b.create<arith::ConstantOp>(loc, b.getFloatAttr(valueTy, minI));
Value maxVal =
b.create<arith::ConstantOp>(loc, b.getFloatAttr(valueTy, maxI));
Value minCmp =
b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT, value, minVal);
Value maxCmp =
b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, value, maxVal);
value = b.create<arith::SelectOp>(loc, minCmp, minVal, value);
value = b.create<arith::SelectOp>(loc, maxCmp, maxVal, value);
if (isUnsigned) {
value = b.create<arith::FPToUIOp>(loc, destTy, value);
} else {
value = b.create<arith::FPToSIOp>(loc, destTy, value);
}
return value;
}
op->emitError("unimplemented lowering in "
"createLinalgPayloadCalculationForElementwiseOp");
return nullptr;
}
namespace {
// Converts an elementwise op.
// This specifically includes:
// - converting elementwise ops of any tensor arity
// - converting elementwise ops with any number of scalar captures (such as a
// scalar alpha to torch.aten.Add)
// - broadcasting of static size-1 dimensions
//
// Currently, we adopt the behavior that "size 1" broadcasting is a runtime
// error if it happens dynamically.
//
// Looking forward a bit, eventually, it probably makes sense to have
// a "linalg.generic-like" op for modeling a fused subgraph of numpy-broadcasted
// operands. Modeling elementwise ops that way is potentially useful to allow a
// more centralized reasoning about multiversioning. However a cost model will
// be needed for "pre-fusing" elementwise ops that way, as it can potentially be
// a pessimization. A mild extension of this pattern should work for such a
// general op.
class ConvertElementwiseOp : public ConversionPattern {
public:
ConvertElementwiseOp(TypeConverter &typeConverter, MLIRContext *context)
: ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
context) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (!isa<AtenTanOp, AtenTanhOp, AtenSinhOp, AtenCoshOp, AtenReluOp,
AtenPreluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp,
AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp,
AtenDivScalarModeOp, AtenSubTensorOp, AtenAtan2Op,
AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenExpm1Op,
AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
AtenClampTensorOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp,
AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenPowScalarOp,
AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op,
AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp,
AtenRemainderScalarOp, AtenRemainderTensorOp, AtenFmodTensorOp,
AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp,
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp,
AtenAddScalarOp, AtenThresholdOp, AtenThresholdBackwardOp,
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
AtenNeScalarOp, AtenNegOp, AtenMaskedFillTensorOp, AtenLogicalOrOp,
AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp,
AtenTriuOp, AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp,
AtenFillScalarOp, AtenFillTensorOp, AtenAtanOp, AtenAcosOp,
AtenAtanhOp, AtenAcoshOp, AtenAsinOp, AtenAsinhOp, AtenRealOp,
AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
AtenQuantizePerTensorOp>(op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op->getLoc();
auto tensorOperands = llvm::to_vector<6>(llvm::make_filter_range(
operands, [](Value v) { return isa<RankedTensorType>(v.getType()); }));
auto resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
bool hadErrorCreatingPayload = false;
Value generic = torch_to_linalg::createElementwiseLinalgGeneric(
rewriter, loc, tensorOperands, resultType.getElementType(),
[&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
Value result = createLinalgPayloadCalculationForElementwiseOp(
b, loc, getTypeConverter(), payloadArgs, op, operands);
if (!result) {
hadErrorCreatingPayload = true;
return;
}
b.create<linalg::YieldOp>(loc, result);
});
if (hadErrorCreatingPayload)
return failure();
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, generic);
return success();
}
};
} // namespace
// Given `input`, `target`, `nll_loss_forward` is given by:
// for i in range(0, len(target)):
// indi = target[i];
// nll_loss_forward[i] = -(input[i][indi]);
// TODO: `weight`operand is still to be taken care of.
namespace {
class ConvertAtenNllLossForwardOp
: public OpConversionPattern<AtenNllLossForwardOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenNllLossForwardOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op->getLoc();
Value input = adaptor.getSelf();
Value target = adaptor.getTarget();
Value weight = adaptor.getWeight();
int64_t reduction;
if (!matchPattern(op.getReduction(), m_TorchConstantInt(&reduction)))
return rewriter.notifyMatchFailure(op, "dim must be constant");
// TODO: Incorporate the weight argument.
if (!isa<mlir::torch::Torch::NoneType>(weight.getType()))
return rewriter.notifyMatchFailure(
op, "Unimplemented, the weight operand is not incorporated.");
Value ignoreIndex = adaptor.getIgnoreIndex();
Value ignoreIndexVal = castIntToIndex(rewriter, loc, ignoreIndex);
unsigned inputRank = cast<RankedTensorType>(input.getType()).getRank();
unsigned targetRank = cast<RankedTensorType>(target.getType()).getRank();
// TODO: Add support for k-dim loss.
if (inputRank > 2) {
return rewriter.notifyMatchFailure(
op, "expected input and target to be rank <= 2");
}
RankedTensorType resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
Type elementType = resultType.getElementType();
Value zeroVal = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(elementType));
Value finalRes = torch_to_linalg::createElementwiseLinalgGeneric(
rewriter, loc, {target}, elementType,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value targetVal = args[0];
Value indTarget = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIndexType(), targetVal);
// The final result is given by:
// final_res = (indTarget == ignoreIndexVal) ? 0 :
// input[indI][IndTarget]
Value cmpEq = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, indTarget, ignoreIndexVal);
SmallVector<Value> extractionIndices{indTarget};
if (inputRank == 2) {
Value indI = rewriter.create<linalg::IndexOp>(loc, 0);
extractionIndices.insert(extractionIndices.begin(), indI);
}
Value result =
rewriter.create<tensor::ExtractOp>(loc, input, extractionIndices);
Value negate =
rewriter.create<arith::NegFOp>(loc, elementType, result);
Value selectFinal =
rewriter.create<arith::SelectOp>(loc, cmpEq, zeroVal, negate);
b.create<linalg::YieldOp>(loc, selectFinal);
});
llvm::iota_range<int64_t> dimsToReduce(0, targetRank,
/*inclusive=*/false);
DenseSet<int64_t> dimSet(dimsToReduce.begin(), dimsToReduce.end());
if (reduction == torch_upstream::Reduction::Sum ||
reduction == torch_upstream::Reduction::Mean) {
Value numOfElems = getTensorSize(rewriter, loc, finalRes);
numOfElems = convertScalarToDtype(rewriter, loc, numOfElems, elementType);
auto opInfo = torch_to_linalg::ReductionOpInfo{false, finalRes, dimSet};
finalRes = torch_to_linalg::createReductionLinalgGeneric(
rewriter, loc, opInfo,
/*initElem=*/zeroVal,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value newVal = args[0];
Value accumulator = args[1];
if (reduction == torch_upstream::Reduction::Mean)
newVal = b.create<arith::DivFOp>(loc, newVal, numOfElems);
Value result = b.create<arith::AddFOp>(loc, newVal, accumulator);
b.create<linalg::YieldOp>(loc, result);
});
}
// The implementation for the `total_weight` has been adopted from here:
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/LossNLL.cpp#L154-L294
// As per the ref link, the `total_weight` value when the `weight` is
// `None`, is equal to `total_weight = batch_size - num_ignored_index`,
// where `batch_size` is equal to `target.shape[0]` when rank(target) > 0,
// otherwise 1. The value `num_ignored_index` is the number of elements of
// the `target` tensors that have been ignored.
if (reduction == torch_upstream::Reduction::None && inputRank == 2) {
Value totalWeight = createZeroInitTensor(rewriter, loc, {}, elementType);
rewriter.replaceOp(op, {finalRes, totalWeight});
return success();
}
Value numIgnoredIndex;
if (targetRank == 0) {
Value targetVal = rewriter.create<tensor::ExtractOp>(loc, target);
numIgnoredIndex = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, targetVal, ignoreIndex);
numIgnoredIndex = convertScalarToDtype(rewriter, loc, numIgnoredIndex,
ignoreIndex.getType());
} else {
Value zeroCstInt = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(ignoreIndex.getType()));
auto opInfo =
torch_to_linalg::ReductionOpInfo{/*keepDim=*/false, target, dimSet};
numIgnoredIndex = torch_to_linalg::createReductionLinalgGeneric(
rewriter, loc, opInfo,
/*initElem=*/zeroCstInt,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value targetVal = args[0];
Value accumulator = args[1];
Value result = b.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, targetVal, ignoreIndex);
result = b.create<arith::AddIOp>(
loc,
convertScalarToDtype(rewriter, loc, result,
ignoreIndex.getType()),
accumulator);
b.create<linalg::YieldOp>(loc, result);
});
numIgnoredIndex =
rewriter.create<tensor::ExtractOp>(loc, numIgnoredIndex);
}
Value numtargetElems = getTensorSize(rewriter, loc, target);
Value totalWeightVal =
rewriter.create<arith::SubIOp>(loc, numtargetElems, numIgnoredIndex);
Value totalWeight = createInitTensor(
rewriter, loc, {}, elementType,
convertScalarToDtype(rewriter, loc, totalWeightVal, elementType));
rewriter.replaceOp(op, {finalRes, totalWeight});
return success();
}
};
} // namespace
/// Inverted STD: rSTD = 1 / sqrt(var + eps).
static Value calculateRSTD(OpBuilder &b, Location loc, Type elemTy, Value eps,
Value var) {
// The eps is always f64.
Value truncatedEps = b.create<arith::TruncFOp>(loc, elemTy, eps);
Value varPlusEps = b.create<arith::AddFOp>(loc, var, truncatedEps);
Value rSTD = b.create<math::RsqrtOp>(loc, varPlusEps);
return rSTD;
}
// Normalization formula:
// ((input - mean) * rSTD * weight + bias
static Value createLinalgPayloadCalculationForNormOpsWithRSTD(
OpBuilder &b, Location loc, Type elemTy, Value input, Value mean,
Value rSTD, Value eps, Value weight, Value bias) {
Value inputSubMean = b.create<arith::SubFOp>(loc, input, mean);
Value temp = b.create<arith::MulFOp>(loc, inputSubMean, rSTD);
Value timesWeight = b.create<arith::MulFOp>(loc, temp, weight);
Value plusBias = b.create<arith::AddFOp>(loc, timesWeight, bias);
return plusBias;
}
static Value createLinalgPayloadCalculationForNormOpsWithVar(
OpBuilder &b, Location loc, Type elemTy, Value input, Value mean, Value var,
Value eps, Value weight, Value bias) {
Value rSTD = calculateRSTD(b, loc, elemTy, eps, var);
Value result = createLinalgPayloadCalculationForNormOpsWithRSTD(
b, loc, elemTy, input, mean, rSTD, eps, weight, bias);
return result;
}
namespace {
class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenBatchNormOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MLIRContext *context = op->getContext();
Location loc = op->getLoc();
Value input = adaptor.getInput();
Value weight = adaptor.getWeight();
Value bias = adaptor.getBias();
Value runningMean = adaptor.getRunningMean();
Value runningVar = adaptor.getRunningVar();
Value training = adaptor.getTraining();
Value eps = adaptor.getEps();
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
// TODO: Handle the None cases for the optional parameters:
// weight, bias.
if (failed(checkNotNone(rewriter, op, weight)) ||
failed(checkNotNone(rewriter, op, bias)) ||
failed(checkNotNone(rewriter, op, runningMean)) ||
failed(checkNotNone(rewriter, op, runningVar)))
return failure();
auto inputType = cast<RankedTensorType>(input.getType());
auto weightType = cast<RankedTensorType>(weight.getType());
auto biasType = cast<RankedTensorType>(bias.getType());
auto runningMeanType = cast<RankedTensorType>(runningMean.getType());
auto runningVarType = cast<RankedTensorType>(runningVar.getType());
auto inputRank = inputType.getRank();
if (inputRank < 2)
return rewriter.notifyMatchFailure(
op, "input should have rank larger than 1");
if (weightType.getRank() != 1 || biasType.getRank() != 1 ||
runningMeanType.getRank() != 1 || runningVarType.getRank() != 1) {
return rewriter.notifyMatchFailure(
op, "expect weight, bias, running_mean and running_var to be rank 1");
}
// TODO: Add support for training.
auto constFalse = rewriter.create<arith::ConstantOp>(
loc, IntegerAttr::get(IntegerType::get(context, 1), 0));
auto trainingFalse = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, training, constFalse);
rewriter.create<cf::AssertOp>(
loc, trainingFalse,
rewriter.getStringAttr("training is not supported for now"));
// num_features C from an expected input of size (N,C,D,H,W ...)
Value numFeatures = rewriter.create<tensor::DimOp>(loc, input, 1);
auto contractingDim0EqualsNumFeatures = [&](Value v) {
auto dim0 = rewriter.create<tensor::DimOp>(loc, v, 0);
auto dim0Equal = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, numFeatures, dim0);
rewriter.create<cf::AssertOp>(
loc, dim0Equal,
rewriter.getStringAttr(
"expect the size of dim 0 equal to the number of features"));
};
if (!isAssumingStrictSymbolicShapes(rewriter)) {
contractingDim0EqualsNumFeatures(weight);
contractingDim0EqualsNumFeatures(bias);
contractingDim0EqualsNumFeatures(runningMean);
contractingDim0EqualsNumFeatures(runningVar);
}
auto indexingMap = AffineMap::get(
/*dimCount=*/inputRank,
/*symbolCount=*/0, rewriter.getAffineDimExpr(1), context);
SmallVector<AffineMap> indexingMaps = {
rewriter.getMultiDimIdentityMap(inputRank), // input
indexingMap, // weight
indexingMap, // bias
indexingMap, // runningMean
indexingMap, // runningVar
rewriter.getMultiDimIdentityMap(inputRank), // output
};
SmallVector<utils::IteratorType> iteratorTypes(
inputRank, utils::IteratorType::parallel);
Value batchNorm =
rewriter
.create<linalg::GenericOp>(
loc, input.getType(),
ValueRange{input, weight, bias, runningMean, runningVar}, input,
/*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value input = args[0], weight = args[1], bias = args[2],
mean = args[3], var = args[4];
Value result =
createLinalgPayloadCalculationForNormOpsWithVar(
b, loc, var.getType(), input, mean, var, eps, weight,
bias);
b.create<linalg::YieldOp>(loc, result);
})
.getResult(0);
Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, batchNorm);
return success();
}
};
} // namespace
namespace {
class ConvertAtenNllLossBackwardOp
: public OpConversionPattern<AtenNllLossBackwardOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenNllLossBackwardOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op->getLoc();
Value gradOutput = adaptor.getGradOutput();
Value input = adaptor.getSelf();
Value target = adaptor.getTarget();
Value weight = adaptor.getWeight();
bool weightIsNone = isa<Torch::NoneType>(op.getWeight().getType());
Value ignoreIndex = castIntToIndex(rewriter, loc, adaptor.getIgnoreIndex());
Value totalWeight = adaptor.getTotalWeight();
auto inputType = cast<RankedTensorType>(input.getType());
int inputRank = inputType.getRank();
auto gradOutputType = cast<RankedTensorType>(gradOutput.getType());
Type resultElementType = gradOutputType.getElementType();
int64_t reduction;
if (!matchPattern(op.getReduction(), m_TorchConstantInt(&reduction)))
return rewriter.notifyMatchFailure(op, "dim must be constant");
if (!hasElementType<mlir::FloatType>(gradOutput) ||
!hasElementType<mlir::FloatType>(gradOutput) ||
(!weightIsNone && !hasElementType<mlir::FloatType>(weight))) {
return rewriter.notifyMatchFailure(
op, "`gradOutput`, 'weight', and `totalWeight` must be tensors of "
"type float");
}
if (!hasElementType<mlir::IntegerType>(target)) {
return rewriter.notifyMatchFailure(
op, "`target` must be a tensor of integer type");
}
auto outputSize = getTensorSizes(rewriter, loc, input);
Value gradInputTensor =
createZeroInitTensor(rewriter, loc, outputSize, resultElementType);
auto getAffineMapForSingleElementTensor = [&](Value tensor) {
auto tensorType = cast<RankedTensorType>(tensor.getType());
SmallVector<AffineExpr> affineExprs(tensorType.getRank(),
rewriter.getAffineConstantExpr(0));
return AffineMap::get(inputRank, /*symbolCount=*/0, affineExprs,
op->getContext());
};
AffineMap gradOutMap = AffineMap::get(inputRank, /*symbolCount=*/0,
rewriter.getAffineDimExpr(0));
if (reduction != torch_upstream::Reduction::None || inputRank == 1)
gradOutMap = getAffineMapForSingleElementTensor(gradOutput);
AffineMap targetMap = AffineMap::get(inputRank, /*symbolCount=*/0,
rewriter.getAffineDimExpr(0));
if (inputRank == 1)
targetMap = getAffineMapForSingleElementTensor(target);
AffineMap totalWeightMap = getAffineMapForSingleElementTensor(totalWeight);
AffineMap resultMap = rewriter.getMultiDimIdentityMap(inputRank);
SmallVector<AffineMap> indexingMaps{gradOutMap, targetMap, totalWeightMap,
resultMap};
SmallVector<utils::IteratorType> iteratorTypes(
inputRank, utils::IteratorType::parallel);
// The code generation is equivalent to the following pseudo-code:
//
// for batch_index in len(input.size(0)):
// for class_index in len(input.size(1)):
// target_elem = target[batch_index]
//
// if reduction == None:
// grad_out_elem = grad_output[batchIndex]
// else:
// grad_out_elem = grad_output[0]
//
// if reduction == Mean:
// total_weight_elem = total_weight[0]
// grad_out_elem /= total_weight_elem
//
// weight_elem = weight[target_elem] if weight != None else 1
//
// if target_elem != class_index or target_elem == ignore_index:
// grad_input_elem = -weight_elem * grad_out_elem
// else:
// grad_input_elem = 0
// grad_input[batch_index, target_elem] = grad_input_elem
//
// NOTE: In the case of not batch dimension, `batch_index` essentially
// becomes zero.
Value gradInput =
rewriter
.create<linalg::GenericOp>(
loc, gradInputTensor.getType(),
ValueRange{gradOutput, target, totalWeight}, gradInputTensor,
indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value gradOutElem = args[0];
Value targetElem = castIntToIndex(b, loc, args[1]);
Value totalWeightElem = args[2];
Value classIndex =
b.create<linalg::IndexOp>(loc, inputRank - 1);
if (reduction == torch_upstream::Reduction::Mean) {
gradOutElem = b.create<arith::DivFOp>(loc, gradOutElem,
totalWeightElem);
}
Value negGradOutElem =
b.create<arith::NegFOp>(loc, gradOutElem);
Value weightElem = getConstant(b, loc, 1, resultElementType);
if (!weightIsNone) {
weightElem =
b.create<tensor::ExtractOp>(loc, weight, targetElem);
}
Value weightedNegGradOutElem =
b.create<arith::MulFOp>(loc, weightElem, negGradOutElem);
Value targetNeqClassIndex = b.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ne, targetElem, classIndex);
Value targetEqIgnoreIndex = b.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, targetElem, ignoreIndex);
Value gradInputIsZero = b.create<arith::OrIOp>(
loc, targetNeqClassIndex, targetEqIgnoreIndex);
Value zero = getConstant(b, loc, 0, resultElementType);
Value gradInElem = b.create<arith::SelectOp>(
loc, gradInputIsZero, zero, weightedNegGradOutElem);
b.create<linalg::YieldOp>(loc, gradInElem);
})
->getResult(0);
RankedTensorType resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, gradInput);
return success();
}
};
} // namespace
namespace {
class ConvertAtenDetachOp : public OpConversionPattern<AtenDetachOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenDetachOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Type resultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
adaptor.getSelf());
return success();
}
};
} // namespace
namespace {
class ConvertPrimsSplitDimOp : public OpConversionPattern<PrimsSplitDimOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(PrimsSplitDimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
auto aRankedTensorType = cast<RankedTensorType>(adaptor.getA().getType());
const TypeConverter *typeConverter = getTypeConverter();
auto resultRankedTensorType =
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
// The dimension being split must be statically known.
int64_t dimInt;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dimInt)))
return failure();
SmallVector<ReassociationIndices> associations;
associations.reserve(aRankedTensorType.getRank());
for (unsigned i = 0; i < dimInt; ++i) {
associations.push_back(ReassociationIndices{i});
}
associations.push_back(ReassociationIndices{dimInt, dimInt + 1});
for (int i = dimInt + 2; i < resultRankedTensorType.getRank(); ++i) {
associations.push_back(ReassociationIndices{i});
}
auto expanded = rewriter.createOrFold<tensor::ExpandShapeOp>(
op.getLoc(), resultRankedTensorType, adaptor.getA(), associations);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultRankedTensorType,
expanded);
return success();
}
};
} // namespace
namespace {
class ConvertPrimsCollapseOp : public OpConversionPattern<PrimsCollapseOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(PrimsCollapseOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
auto aRankedTensorType = cast<RankedTensorType>(adaptor.getA().getType());
const TypeConverter *typeConverter = getTypeConverter();
auto resultRankedTensorType =
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
// Collapse range must be statically known.
int64_t startInt;
if (!matchPattern(op.getStart(), m_TorchConstantInt(&startInt)))
return failure();
int64_t endInt;
if (!matchPattern(op.getEnd(), m_TorchConstantInt(&endInt)))
return failure();
// Upstream MLIR is overly strict -- it fails verification if the
// collapse_shape is the identity op (i.e. when no dimensions are
// collapsed). We manually fold this case here.
if (startInt == endInt) {
rewriter.replaceOp(op, adaptor.getA());
return success();
}
SmallVector<ReassociationIndices> associations;
associations.reserve(resultRankedTensorType.getRank());
// An example of is where input shape is [3,4,5,6] and
// start = 1, and end = 2. The collapsed shape is then [3,4*5,6],
// with reassociation indices of [0], [1,2], and [3].
// Append the singleton dimensions before the collapsed dimensions.
for (unsigned i = 0; i < startInt; ++i) {
associations.push_back(ReassociationIndices{i});
}
// Append the collapsed dimensions.
ReassociationIndices collapseDims(endInt + 1 - startInt);
std::iota(collapseDims.begin(), collapseDims.end(), startInt);
associations.push_back(collapseDims);
// Append the singleton dimensions after the collapsed dimensions.
for (int i = endInt + 1; i < aRankedTensorType.getRank(); ++i) {
associations.push_back(ReassociationIndices{i});
}
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
op, resultRankedTensorType, adaptor.getA(), associations);
return success();
}
};
} // namespace
namespace {
class ConvertTensorStaticInfoCastOp
: public OpConversionPattern<TensorStaticInfoCastOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(TensorStaticInfoCastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
RankedTensorType resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
adaptor.getOperand());
return success();
}
};
} // namespace
namespace {
class ConvertLogitOp : public OpConversionPattern<AtenLogitOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenLogitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
Value input = adaptor.getSelf();
Value eps = adaptor.getEps();
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
bool handleEps = false;
if (succeeded(checkNotNone(rewriter, op, eps)))
handleEps = true;
if (handleEps && !isa<mlir::FloatType>(eps.getType())) {
op.emitError("Logit does not support non-floating point type");
return failure();
}
auto inputType = cast<RankedTensorType>(input.getType());
auto inputElementType = inputType.getElementType();
if (!isa<mlir::FloatType>(inputElementType)) {
op.emitError("Logit does not support non-floating point type");
return failure();
}
auto inputRank = inputType.getRank();
SmallVector<AffineMap> indexingMaps = {
rewriter.getMultiDimIdentityMap(inputRank), // input
rewriter.getMultiDimIdentityMap(inputRank), // output
};
SmallVector<utils::IteratorType> iteratorTypes(
inputRank, utils::IteratorType::parallel);
Value logit =
rewriter
.create<linalg::GenericOp>(
loc, input.getType(),
/*ins=*/input,
/*outs=*/input,
/*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value input = args[0];
TypedAttr oneAttr = b.getFloatAttr(inputElementType, 1.0);
Value oneValue = b.create<arith::ConstantOp>(loc, oneAttr);
Value zI;
if (!handleEps) {
zI = input;
} else {
Value truncEps =
b.create<arith::TruncFOp>(loc, inputElementType, eps);
Value oneMinusEps =
b.create<arith::SubFOp>(loc, oneValue, truncEps);
Value min =
b.create<arith::MinimumFOp>(loc, input, oneMinusEps);
Value clampedInput =
b.create<arith::MaximumFOp>(loc, min, truncEps);
zI = clampedInput;
}
Value probability =
b.create<arith::SubFOp>(loc, oneValue, zI);
Value odds = b.create<arith::DivFOp>(loc, zI, probability);
Value result = b.create<math::LogOp>(loc, odds);
b.create<linalg::YieldOp>(loc, result);
})
.getResult(0);
Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, logit);
return success();
}
};
} // namespace
namespace {
class ConvertAtenIntReprOp : public OpConversionPattern<AtenIntReprOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenIntReprOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
RankedTensorType resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
adaptor.getSelf());
return success();
}
};
} // namespace
namespace {
class ConvertDequantizePerChannel
: public OpConversionPattern<AtenDequantizeSelfOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenDequantizeSelfOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto qoperand = op.getOperand();
auto make = qoperand.getDefiningOp<Aten_MakePerChannelQuantizedTensorOp>();
if (!make) {
return rewriter.notifyMatchFailure(op, "did not find per channel qint");
}
auto converter = getTypeConverter();
auto operand = make.getOperand(0);
auto scale = make.getScale();
auto zeropoint = make.getZeroPoint();
auto axis = make.getAxis();
IntegerAttr axisAttr;
if (!matchPattern(axis, m_Constant(&axisAttr))) {
return failure();
}
auto operandDTy = cast<ValueTensorType>(operand.getType()).getDtype();
auto zeropointDTy = cast<ValueTensorType>(zeropoint.getType()).getDtype();
operand = converter->materializeTargetConversion(
rewriter, loc, converter->convertType(operand.getType()), operand);
scale = converter->materializeTargetConversion(
rewriter, loc, converter->convertType(scale.getType()), scale);
zeropoint = converter->materializeTargetConversion(
rewriter, loc, converter->convertType(zeropoint.getType()), zeropoint);
auto resultType = cast<RankedTensorType>(
converter->convertType(op->getResult(0).getType()));
llvm::SmallVector<Value> dynSizes;
for (auto [index, dim] : llvm::enumerate(resultType.getShape())) {
if (ShapedType::isDynamic(dim)) {
dynSizes.push_back(rewriter.create<tensor::DimOp>(loc, operand, index));
}
}
llvm::SmallVector<utils::IteratorType> iterators(
resultType.getRank(), utils::IteratorType::parallel);
llvm::SmallVector<AffineMap> maps(
4, {rewriter.getMultiDimIdentityMap(resultType.getRank())});
auto broadcastMap = AffineMap::get(
resultType.getRank(), /*symbolCount=*/0,
{rewriter.getAffineDimExpr(axisAttr.getInt())}, rewriter.getContext());
maps[1] = broadcastMap;
maps[2] = broadcastMap;
auto empty =
rewriter.create<tensor::EmptyOp>(op.getLoc(), resultType, dynSizes);
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, resultType, ValueRange{operand, scale, zeropoint},
ValueRange{empty}, maps, iterators,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value operand = args[0];
Value scale = args[1];
Value zeropoint = args[2];
if (operandDTy.isUnsignedInteger(8)) {
operand = b.create<arith::ExtUIOp>(loc, b.getI32Type(), operand);
} else if (operandDTy.isSignedInteger(8)) {
operand = b.create<arith::ExtSIOp>(loc, b.getI32Type(), operand);
}
if (zeropointDTy.isUnsignedInteger(8)) {
zeropoint =
b.create<arith::ExtUIOp>(loc, b.getI32Type(), zeropoint);
} else if (zeropointDTy.isSignedInteger(8)) {
zeropoint =
b.create<arith::ExtSIOp>(loc, b.getI32Type(), zeropoint);
}
Value sub = rewriter.create<arith::SubIOp>(loc, operand, zeropoint);
Value fp =
rewriter.create<arith::SIToFPOp>(loc, args[3].getType(), sub);
Value mul = rewriter.create<arith::MulFOp>(loc, fp, scale);
b.create<linalg::YieldOp>(loc, mul);
});
rewriter.replaceOp(op, linalgOp.getResults());
return success();
}
};
} // namespace
namespace {
template <typename OpTy>
class ConvertCastEquivalentOp : public OpConversionPattern<OpTy> {
using OpConversionPattern<OpTy>::OpConversionPattern;
using OpAdaptor = typename OpTy::Adaptor;
LogicalResult
matchAndRewrite(OpTy op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto converter = this->getTypeConverter();
RankedTensorType resultType = cast<RankedTensorType>(
converter->convertType(op->getResult(0).getType()));
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
adaptor.getSelf());
return success();
}
};
} // namespace
namespace {
class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenGridSamplerOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
Type int64type = rewriter.getI64Type();
Type floatType = rewriter.getF32Type();
Value zeroIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value oneIndex = rewriter.create<arith::ConstantIndexOp>(loc, 1);
Value twoIndex = rewriter.create<arith::ConstantIndexOp>(loc, 2);
Value zeroFloat = rewriter.create<arith::ConstantOp>(
loc, rewriter.getFloatAttr(floatType, 0.0));
Value oneFloat = rewriter.create<arith::ConstantOp>(
loc, rewriter.getFloatAttr(floatType, 1.0));
Value twoFloat = rewriter.create<arith::ConstantOp>(
loc, rewriter.getFloatAttr(floatType, 2.0));
Value input = adaptor.getInput();
auto inputType = cast<RankedTensorType>(input.getType());
auto inputShape = inputType.getShape();
Value innerDim0a = rewriter.create<tensor::DimOp>(loc, input, 2);
Value innerDim1a = rewriter.create<tensor::DimOp>(loc, input, 3);
Value innerDim0b =
rewriter.create<arith::SubIOp>(loc, innerDim0a, oneIndex);
Value innerDim1b =
rewriter.create<arith::SubIOp>(loc, innerDim1a, oneIndex);
Value innerDim0c =
rewriter.create<arith::IndexCastOp>(loc, int64type, innerDim0b);
Value innerDim1c =
rewriter.create<arith::IndexCastOp>(loc, int64type, innerDim1b);
Value innerDim0d =
rewriter.create<arith::SIToFPOp>(loc, floatType, innerDim0c);
Value innerDim1d =
rewriter.create<arith::SIToFPOp>(loc, floatType, innerDim1c);
Value innerDim0e =
rewriter.create<arith::DivFOp>(loc, innerDim0d, twoFloat);
Value innerDim1e =
rewriter.create<arith::DivFOp>(loc, innerDim1d, twoFloat);
Value grid = adaptor.getGrid();
auto gridType = cast<RankedTensorType>(grid.getType());
auto gridShape = gridType.getShape();
auto gridRank = gridType.getRank();
SmallVector<Value> extractGridOffsets0(gridRank, zeroIndex);
SmallVector<Value> extractGridShape = getTensorSizes(rewriter, loc, grid);
SmallVector<Value> extractGridStride(gridRank, oneIndex);
int64_t lastGridDim = gridRank - 1;
extractGridShape[lastGridDim] = oneIndex;
extractGridStride[lastGridDim] = twoIndex;
SmallVector<Value> extractGridOffsets1(gridRank, zeroIndex);
extractGridOffsets1[lastGridDim] = oneIndex;
SmallVector<int64_t> gridShapeExtracted(gridShape);
gridShapeExtracted.back() = 1;
SmallVector<int64_t> gridShapeCollapsed{gridShape[0], gridShape[1],
gridShape[2]};
auto grid0 = rewriter.create<tensor::ExtractSliceOp>(
loc, grid, extractGridOffsets0, extractGridShape, extractGridStride);
auto grid1 = rewriter.create<tensor::ExtractSliceOp>(
loc, grid, extractGridOffsets1, extractGridShape, extractGridStride);
SmallVector<ReassociationIndices> associations{ReassociationIndices{0},
ReassociationIndices{1},
ReassociationIndices{2, 3}};
auto gridCollapsed0 =
rewriter.create<tensor::CollapseShapeOp>(loc, grid0, associations);
auto gridCollapsed1 =
rewriter.create<tensor::CollapseShapeOp>(loc, grid1, associations);
AffineMap gridMap = AffineMap::get(4, 0,
{rewriter.getAffineDimExpr(0),
rewriter.getAffineDimExpr(2),
rewriter.getAffineDimExpr(3)},
op->getContext());
SmallVector<AffineMap> gridMaps{gridMap, gridMap,
rewriter.getMultiDimIdentityMap(gridRank)};
SmallVector<utils::IteratorType> gridIterators(
gridRank, utils::IteratorType::parallel);
SmallVector<int64_t> resultShape{inputShape[0], inputShape[1], gridShape[1],
gridShape[2]};
auto lambdaExtract = [](OpBuilder &b, Location loc, Value input, Value idxA,
Value idxB, Value idxC, Value idxD) -> Value {
SmallVector<Value> index{idxA, idxB, idxC, idxD};
Value result = b.create<tensor::ExtractOp>(loc, input, index);
return result;
};
auto lambdaLinear = [&](OpBuilder &b, Location loc, Value x, Value y,
Value d) -> Value {
Value dm = b.create<arith::SubFOp>(loc, oneFloat, d);
Value ra = b.create<arith::MulFOp>(loc, x, dm);
Value rb = b.create<arith::MulFOp>(loc, y, d);
Value res = b.create<arith::AddFOp>(loc, ra, rb);
return res;
};
auto lambdaNearest = [&](OpBuilder &b, Location loc, Value x, Value y,
Value d) -> Value {
Value halfConst = rewriter.create<arith::ConstantOp>(
loc, rewriter.getFloatAttr(floatType, 0.5));
Value checkClosest =
b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, d, halfConst);
Value res = b.create<arith::SelectOp>(loc, checkClosest, x, y);
return res;
};
auto lambdaInterpolate = [&](OpBuilder &b, Location loc, Value iMode,
Value x, Value y, Value d) -> Value {
Value linear = lambdaLinear(b, loc, x, y, d);
Value nearest = lambdaNearest(b, loc, x, y, d);
Value zeroInt =
b.create<arith::ConstantOp>(loc, b.getIntegerAttr(int64type, 0));
Value checkMode = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
iMode, zeroInt);
Value res = b.create<arith::SelectOp>(loc, checkMode, linear, nearest);
return res;
};
auto resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op.getResult().getType()));
SmallVector<Value> resultSize{};
if (resultType.isDynamicDim(0))
resultSize.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
if (resultType.isDynamicDim(1))
resultSize.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
if (resultType.isDynamicDim(2))
resultSize.push_back(rewriter.create<tensor::DimOp>(loc, grid, 1));
if (resultType.isDynamicDim(3))
resultSize.push_back(rewriter.create<tensor::DimOp>(loc, grid, 2));
Value alignCorners = adaptor.getAlignCorners();
Value interMode = adaptor.getInterpolationMode();
Value resultFinal =
rewriter.create<tensor::EmptyOp>(loc, resultType, resultSize);
auto sGrid = rewriter.create<linalg::GenericOp>(
loc, TypeRange{resultType}, ValueRange{gridCollapsed0, gridCollapsed1},
ValueRange(resultFinal), gridMaps, gridIterators,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value gr0 = args[1];
Value gr1 = args[0];
Value gr0Half = b.create<arith::DivFOp>(loc, gr0, twoFloat);
Value gr1Half = b.create<arith::DivFOp>(loc, gr1, twoFloat);
Value gr0HalfSelect =
b.create<arith::SelectOp>(loc, alignCorners, zeroFloat, gr0Half);
Value gr1HalfSelect =
b.create<arith::SelectOp>(loc, alignCorners, zeroFloat, gr1Half);
Value gplus0 = b.create<arith::AddFOp>(loc, gr0, oneFloat);
Value gplus1 = b.create<arith::AddFOp>(loc, gr1, oneFloat);
Value gPlusMul0 = b.create<arith::MulFOp>(loc, gplus0, innerDim0e);
Value gPlusMul1 = b.create<arith::MulFOp>(loc, gplus1, innerDim1e);
Value result0 =
b.create<arith::AddFOp>(loc, gPlusMul0, gr0HalfSelect);
Value result1 =
b.create<arith::AddFOp>(loc, gPlusMul1, gr1HalfSelect);
Value checkLowerBound0 = b.create<arith::CmpFOp>(
loc, arith::CmpFPredicate::OLT, result0, zeroFloat);
Value checkLowerBound1 = b.create<arith::CmpFOp>(
loc, arith::CmpFPredicate::OLT, result1, zeroFloat);
Value lowerOrig0 = b.create<arith::FPToSIOp>(loc, int64type, result0);
Value lowerOrig1 = b.create<arith::FPToSIOp>(loc, int64type, result1);
Value zeroInt =
b.create<arith::ConstantOp>(loc, b.getIntegerAttr(int64type, 0));
Value oneInt =
b.create<arith::ConstantOp>(loc, b.getIntegerAttr(int64type, 1));
Value lowerSub0 = b.create<arith::SubIOp>(loc, lowerOrig0, oneInt);
Value lowerSub1 = b.create<arith::SubIOp>(loc, lowerOrig1, oneInt);
Value lower0 = b.create<arith::SelectOp>(loc, checkLowerBound0,
lowerSub0, lowerOrig0);
Value lower1 = b.create<arith::SelectOp>(loc, checkLowerBound1,
lowerSub1, lowerOrig1);
Value lowerValid0 =
b.create<arith::SelectOp>(loc, checkLowerBound0, zeroInt, lower0);
Value lowerValid1 =
b.create<arith::SelectOp>(loc, checkLowerBound1, zeroInt, lower1);
Value upper0 =
b.create<arith::AddIOp>(loc, int64type, lower0, oneInt);
Value upper1 =
b.create<arith::AddIOp>(loc, int64type, lower1, oneInt);
Value notValidUpper0 = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sgt, upper0, innerDim0c);
Value notValidUpper1 = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sgt, upper1, innerDim1c);
Value upperValid0 =
b.create<arith::SelectOp>(loc, notValidUpper0, lower0, upper0);
Value upperValid1 =
b.create<arith::SelectOp>(loc, notValidUpper1, lower1, upper1);
Value lw0 =
b.create<arith::IndexCastOp>(loc, b.getIndexType(), lowerValid0);
Value lw1 =
b.create<arith::IndexCastOp>(loc, b.getIndexType(), lowerValid1);
Value up0 =
b.create<arith::IndexCastOp>(loc, b.getIndexType(), upperValid0);
Value up1 =
b.create<arith::IndexCastOp>(loc, b.getIndexType(), upperValid1);
Value N = b.create<linalg::IndexOp>(loc, 0);
Value C = b.create<linalg::IndexOp>(loc, 1);
Value result00 = lambdaExtract(b, loc, input, N, C, lw0, lw1);
Value result00a = b.create<arith::SelectOp>(loc, checkLowerBound0,
zeroFloat, result00);
Value result00b = b.create<arith::SelectOp>(loc, checkLowerBound1,
zeroFloat, result00a);
Value result01 = lambdaExtract(b, loc, input, N, C, lw0, up1);
Value result01a = b.create<arith::SelectOp>(loc, notValidUpper1,
zeroFloat, result01);
Value result01b = b.create<arith::SelectOp>(loc, checkLowerBound0,
zeroFloat, result01a);
Value result10 = lambdaExtract(b, loc, input, N, C, up0, lw1);
Value result10a = b.create<arith::SelectOp>(loc, notValidUpper0,
zeroFloat, result10);
Value result10b = b.create<arith::SelectOp>(loc, checkLowerBound1,
zeroFloat, result10a);
Value result11 = lambdaExtract(b, loc, input, N, C, up0, up1);
Value result11a = b.create<arith::SelectOp>(loc, notValidUpper0,
zeroFloat, result11);
Value result11b = b.create<arith::SelectOp>(loc, notValidUpper1,
zeroFloat, result11a);
Value lw0a = b.create<arith::SIToFPOp>(loc, floatType, lower0);
Value lw1a = b.create<arith::SIToFPOp>(loc, floatType, lower1);
Value d1 = b.create<arith::SubFOp>(loc, result0, lw0a);
Value d0 = b.create<arith::SubFOp>(loc, result1, lw1a);
Value resultScaled0 =
lambdaInterpolate(b, loc, interMode, result00b, result01b, d0);
Value resultScaled1 =
lambdaInterpolate(b, loc, interMode, result10b, result11b, d0);
Value resultScaled = lambdaInterpolate(
b, loc, interMode, resultScaled0, resultScaled1, d1);
b.create<linalg::YieldOp>(loc, resultScaled);
});
rewriter.replaceOp(op, sGrid.getResults());
return success();
}
};
} // namespace
static Value NearestInterpolate(OpBuilder &b, Location loc,
SmallVector<Value> outputSizes, Value input,
SmallVector<Value> inputSizes,
SmallVector<Value> scaleValues,
std::string coordStr) {
auto inputType = cast<RankedTensorType>(input.getType());
auto inputRank = inputType.getRank();
SmallVector<Value> indices;
for (unsigned i = 0; i < inputRank; i++) {
indices.push_back(b.create<linalg::IndexOp>(loc, i));
}
for (unsigned i = 2; i < inputRank; i++) {
Value outIndex = indices[i];
Value inputSizeFP =
b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizes[i - 2]);
Value outputSizeFP =
b.create<arith::SIToFPOp>(loc, b.getF32Type(), outputSizes[i - 2]);
// scale = length_resized / length_original
// x_original = x_resized / scale
Value scale;
if (scaleValues.empty())
scale = b.create<arith::DivFOp>(loc, outputSizeFP, inputSizeFP);
else
scale = scaleValues[i - 2];
Value outInt = b.create<arith::IndexCastOp>(loc, b.getI64Type(), outIndex);
Value outFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), outInt);
Value proj = b.create<arith::DivFOp>(loc, outFP, scale);
// get nearest pixel using floor
Value nearestFP = b.create<math::FloorOp>(loc, proj);
Value nearestInt =
b.create<arith::FPToSIOp>(loc, b.getI64Type(), nearestFP);
Value nearest =
b.create<arith::IndexCastOp>(loc, b.getIndexType(), nearestInt);
indices[i] = nearest;
}
Value retVal = b.create<tensor::ExtractOp>(loc, input, indices);
return retVal;
}
static Value BilinearInterpolate(OpBuilder &b,
Aten__InterpolateSizeListScaleListOp op,
Location loc, SmallVector<Value> outputSizes,
Value input, SmallVector<Value> inputSizes,
SmallVector<Value> scaleValues,
std::string coordStr) {
unsigned dimOffset = 2;
auto inputType = cast<RankedTensorType>(input.getType());
auto inputRank = inputType.getRank();
Value cstOneEps =
b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.000001));
Value cstOneFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.0));
Value cstHalf = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.5));
Value zero = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.0));
bool alignCornersBool;
matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool));
SmallVector<Value> indices;
for (unsigned i = 0; i < inputRank; i++) {
indices.push_back(b.create<linalg::IndexOp>(loc, i));
}
SmallVector<Value> proj, projEps, high, low, highFP, lowFP;
for (unsigned i = 0; i < inputRank - dimOffset; i++) {
// length_original
Value inputFP =
b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizes[i]);
// length_resized
Value outputSizeFP =
b.create<arith::SIToFPOp>(loc, b.getF32Type(), outputSizes[i]);
// scale = length_resized/length_original
Value scale;
if (alignCornersBool) {
// x_original = x_resized * (length_original - 1) / (length_resized - 1)
Value inputSubOne = b.create<arith::SubFOp>(loc, inputFP, cstOneFloat);
Value outputSizeSubOne =
b.create<arith::SubFOp>(loc, outputSizeFP, cstOneFloat);
Value cmp = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UEQ,
outputSizeSubOne, zero);
scale = b.create<arith::DivFOp>(loc, inputSubOne, outputSizeSubOne);
scale = b.create<arith::SelectOp>(loc, cmp, zero, scale);
coordStr = "_align_corners";
} else if (scaleValues.empty())
scale = b.create<arith::DivFOp>(loc, outputSizeFP, inputFP);
else
scale = scaleValues[i];
// y_resized
Value outInt = b.create<arith::IndexCastOp>(loc, b.getI64Type(),
indices[i + dimOffset]);
Value outFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), outInt);
Value preClip;
if (coordStr == "_align_corners") {
preClip = b.create<arith::MulFOp>(loc, outFP, scale);
}
if (coordStr == "_asymmetric") {
preClip = b.create<arith::DivFOp>(loc, outFP, scale);
}
if (coordStr == "_pytorch_half_pixel" || coordStr == "") {
// half-pixel modes
// y_resized + 0.5
Value outPlusHalf = b.create<arith::AddFOp>(loc, outFP, cstHalf);
// (y_resized + 0.5) / scale
Value outDivScale = b.create<arith::DivFOp>(loc, outPlusHalf, scale);
// _ - 0.5
preClip = b.create<arith::SubFOp>(loc, outDivScale, cstHalf);
}
// for pytorch half pixel , special case for length_resized == 1:
if (coordStr == "_pytorch_half_pixel") {
Value cmp = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UEQ,
outputSizeFP, cstOneFloat);
preClip = b.create<arith::SelectOp>(loc, cmp, zero, preClip);
}
// clip to 0,inf
Value max = b.create<arith::MaximumFOp>(loc, preClip, zero);
// length_original - 1.001
Value inputSubOneEps = b.create<arith::SubFOp>(loc, inputFP, cstOneEps);
Value inputSubOne = b.create<arith::SubFOp>(loc, inputFP, cstOneFloat);
// clip to [0,length_original - 1.001]
projEps.push_back(b.create<arith::MinimumFOp>(loc, max, inputSubOneEps));
proj.push_back(b.create<arith::MinimumFOp>(loc, max, inputSubOne));
lowFP.push_back(b.create<math::FloorOp>(loc, projEps[i]));
Value projPlusOne = b.create<arith::AddFOp>(loc, cstOneFloat, projEps[i]);
highFP.push_back(b.create<math::FloorOp>(loc, projPlusOne));
Value lowInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), lowFP[i]);
low.push_back(b.create<arith::IndexCastOp>(loc, b.getIndexType(), lowInt));
Value highInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), highFP[i]);
high.push_back(
b.create<arith::IndexCastOp>(loc, b.getIndexType(), highInt));
}
SmallVector<Value> cornerValues;
indices[dimOffset] = low[0];
indices[dimOffset + 1] = low[1];
Value p00 = b.create<tensor::ExtractOp>(loc, input, indices);
indices[dimOffset] = low[0];
indices[dimOffset + 1] = high[1];
Value p01 = b.create<tensor::ExtractOp>(loc, input, indices);
indices[dimOffset] = high[0];
indices[dimOffset + 1] = low[1];
Value p10 = b.create<tensor::ExtractOp>(loc, input, indices);
indices[dimOffset] = high[0];
indices[dimOffset + 1] = high[1];
Value p11 = b.create<tensor::ExtractOp>(loc, input, indices);
// Let Aij := area rect((yProj,xProj) <-> (y_i*,x_j*)),
// where i* = i+1 mod 2 and x_0 = xLow, x_1 = xHigh etc.
// We interpolate via the weighted average of pij by weights Aij
// the formula is retval = Sum(pij*Aij for i and j in range(2))
// Note: we do not need to divide by total rect area == 1
// lengths : Aij == dyi*dxj
Value dy0 = b.create<arith::SubFOp>(loc, highFP[0], proj[0]);
Value dy1 = b.create<arith::SubFOp>(loc, proj[0], lowFP[0]);
Value dx0 = b.create<arith::SubFOp>(loc, highFP[1], proj[1]);
Value dx1 = b.create<arith::SubFOp>(loc, proj[1], lowFP[1]);
// left = A00*p00 + A01*p01 = dy0(dx0p00 + dx1p01)
Value dx0p00 = b.create<arith::MulFOp>(loc, dx0, p00);
Value dx1p01 = b.create<arith::MulFOp>(loc, dx1, p01);
Value sum = b.create<arith::AddFOp>(loc, dx0p00, dx1p01);
Value left = b.create<arith::MulFOp>(loc, dy0, sum);
// right = A10*p10 + A11*p11 = dy1(dx0p10 + dx1p11)
Value dx0p10 = b.create<arith::MulFOp>(loc, dx0, p10);
Value dx1p11 = b.create<arith::MulFOp>(loc, dx1, p11);
sum = b.create<arith::AddFOp>(loc, dx0p10, dx1p11);
Value right = b.create<arith::MulFOp>(loc, dy1, sum);
return b.create<arith::AddFOp>(loc, left, right);
}
namespace {
class ConvertInterpolateOp
: public OpConversionPattern<Aten__InterpolateSizeListScaleListOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(Aten__InterpolateSizeListScaleListOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
std::string mode;
// note: to support onnx.Resize, we are passing some extra options through
// the mode attribute. For example, onnx.Resize with mode="linear" and
// coordinate_transformation_mode="asymmetric" will lower to an interpolate
// op with the non-standard mode="bilinear_asymmetric".
matchPattern(op.getMode(), m_TorchConstantStr(mode));
if (mode.substr(0, 8) != "bilinear" && mode.substr(0, 7) != "nearest") {
return failure();
}
Location loc = op->getLoc();
Value input = adaptor.getInput();
auto inputType = cast<RankedTensorType>(input.getType());
auto inputRank = inputType.getRank();
if (mode.substr(0, 8) == "bilinear" && inputRank != 4)
return rewriter.notifyMatchFailure(
op,
"cannot perform bilinear interpolation when input spatial dims != 2");
SmallVector<Value> outputSizeIntValues;
SmallVector<Value> inputSizes;
SmallVector<Value> ScaleFactorFloatValues;
for (unsigned i = 2; i < inputRank; i++) {
Value inputSize = getDimOp(rewriter, loc, input, i);
inputSizes.push_back(rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIntegerType(64), inputSize));
}
if (!isa<Torch::NoneType>(op.getScaleFactor().getType())) {
bool recompScale;
if (!matchPattern(op.getRecomputeScaleFactor(),
m_TorchConstantBool(&recompScale)))
recompScale = false;
SmallVector<Value> ScaleFactorTorchFloat;
if (!getListConstructElements(op.getScaleFactor(), ScaleFactorTorchFloat))
return rewriter.notifyMatchFailure(
op, "unimplemented: the output_size is not constructed from "
"ListConstruct");
ScaleFactorFloatValues = getTypeConvertedValues(
rewriter, loc, getTypeConverter(), ScaleFactorTorchFloat);
for (unsigned i = 0; i < inputRank - 2; i++) {
Value inputSizeFP = rewriter.create<arith::SIToFPOp>(
loc, rewriter.getF32Type(), inputSizes[i]);
ScaleFactorFloatValues[i] = rewriter.create<arith::TruncFOp>(
loc, inputSizeFP.getType(), ScaleFactorFloatValues[i]);
Value outputSize = rewriter.create<arith::MulFOp>(
loc, inputSizeFP, ScaleFactorFloatValues[i]);
outputSize = rewriter.create<math::FloorOp>(loc, outputSize);
outputSize = rewriter.create<arith::FPToSIOp>(
loc, rewriter.getI64Type(), outputSize);
outputSizeIntValues.push_back(outputSize);
}
if (recompScale)
ScaleFactorFloatValues.clear();
} else {
SmallVector<Value> outputSizeTorchInt;
if (!getListConstructElements(op.getSize(), outputSizeTorchInt))
return rewriter.notifyMatchFailure(
op, "unimplemented: the output_size is not constructed from "
"ListConstruct");
outputSizeIntValues = getTypeConvertedValues(
rewriter, loc, getTypeConverter(), outputSizeTorchInt);
}
SmallVector<Value> dims = getTensorSizesUntilDim(rewriter, loc, input, 1);
for (unsigned i = 2; i < inputRank; i++) {
dims.push_back(castIntToIndex(rewriter, loc, outputSizeIntValues[i - 2]));
}
Value outTensor = rewriter.create<tensor::EmptyOp>(
loc, getAsOpFoldResult(dims), inputType.getElementType());
AffineMap idMap = rewriter.getMultiDimIdentityMap(inputRank);
SmallVector<utils::IteratorType> iteratorTypes(
inputRank, utils::IteratorType::parallel);
Value finalRes =
rewriter
.create<linalg::GenericOp>(
loc, outTensor.getType(), ValueRange{}, outTensor,
/*indexingMaps=*/idMap,
/*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value retVal;
if (mode.substr(0, 7) == "nearest") {
retVal = NearestInterpolate(
b, loc, outputSizeIntValues, input, inputSizes,
ScaleFactorFloatValues, mode.substr(7));
} else if (mode.substr(0, 8) == "bilinear") {
retVal = BilinearInterpolate(
b, op, loc, outputSizeIntValues, input, inputSizes,
ScaleFactorFloatValues, mode.substr(8));
}
b.create<linalg::YieldOp>(loc, retVal);
})
.getResult(0);
Type newResultType =
getTypeConverter()->convertType(op.getResult().getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, finalRes);
return success();
}
};
} // namespace
void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
MLIRContext *context = patterns.getContext();
target.addIllegalOp<
AtenTanOp, AtenTanhOp, AtenSinhOp, AtenCoshOp, AtenAtanhOp, AtenAcoshOp,
AtenAsinOp, AtenAsinhOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp,
AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp,
AtenDivScalarModeOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp,
AtenMinimumOp, AtenAtan2Op, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
AtenClampTensorOp, AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp,
AtenFloorOp, AtenCeilOp, AtenPreluOp, AtenPowScalarOp,
AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op,
AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp,
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp,
AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp,
AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
AtenTrilOp, AtenRemainderScalarOp, AtenFmodTensorOp,
AtenRemainderTensorOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
AtenFillTensorOp, AtenRealOp, AtenImagOp, AtenDequantizeSelfOp,
AtenDequantizeTensorOp, AtenQuantizePerTensorOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context);
target.addIllegalOp<AtenNllLossForwardOp>();
patterns.add<ConvertAtenDetachOp>(typeConverter, context);
target.addIllegalOp<AtenDetachOp>();
patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
target.addIllegalOp<AtenBatchNormOp>();
patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
target.addIllegalOp<AtenLogitOp>();
patterns.add<ConvertLogitOp>(typeConverter, context);
target.addIllegalOp<PrimsCollapseOp>();
patterns.add<ConvertPrimsCollapseOp>(typeConverter, context);
target.addIllegalOp<PrimsSplitDimOp>();
patterns.add<ConvertPrimsSplitDimOp>(typeConverter, context);
target.addIllegalOp<AtenNllLossBackwardOp>();
patterns.add<ConvertAtenNllLossBackwardOp>(typeConverter, context);
patterns.add<ConvertTensorStaticInfoCastOp>(typeConverter, context);
target.addIllegalOp<TensorStaticInfoCastOp>();
patterns.add<ConvertAtenIntReprOp>(typeConverter, context);
target.addIllegalOp<AtenIntReprOp>();
patterns.add<ConvertCastEquivalentOp<Aten_MakePerChannelQuantizedTensorOp>>(
typeConverter, context);
target.addIllegalOp<Aten_MakePerChannelQuantizedTensorOp>();
patterns.add<ConvertCastEquivalentOp<Aten_MakePerTensorQuantizedTensorOp>>(
typeConverter, context);
target.addIllegalOp<Aten_MakePerTensorQuantizedTensorOp>();
patterns.add<ConvertDequantizePerChannel>(typeConverter, context);
target.addIllegalOp<AtenGridSamplerOp>();
patterns.add<ConvertAtenGridSamplerOp>(typeConverter, context);
target.addIllegalOp<Aten__InterpolateSizeListScaleListOp>();
patterns.add<ConvertInterpolateOp>(typeConverter, context);
}