
4670 lines
199 KiB
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "../PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::TorchConversion;
using namespace mlir::torch::torch_upstream; // For ScalarType and type
// -----------------------------------------------------------------------------
// Patterns (as this grows, it should be organized into multiple files)
// -----------------------------------------------------------------------------
// This is going to eventually be O(#aten ops), which is in the 100s.
// Most of these patterns consist of:
// 1. Checking that the operand/result types and other static properties are
// good-enough to create a valid linalg op (such as operands being of
// ranks/dtypes acceptable to the linalg op).
// 2. Creating dynamic error guards, usually checking a predicate on the
// compatibility of operand shapes.
// 3. Creating init tensors for the computation op. Usually this involves
// reifying IR for a shape transfer function based on the operand shapes.
// 4. Creating a named linalg op to replace the original op.
// TODO: Use linalg OpDSL to autogenerate at least 1)/2)/3) such
// that these patterns become mostly mechanical associations of
// "aten.foo -> linalg.foo".
// Verifies that every operand type and every result type of `op` satisfies
// the `isValidLinalgType` predicate. Returns success() when all of them do;
// otherwise reports "type cannot be lowered to linalg" through
// `rewriter.notifyMatchFailure` and returns that failure.
static LogicalResult verifyLinalgCompatibleTypes(Operation *op,
PatternRewriter &rewriter) {
// Check the value tensor is ranked as expected by Linalg.
// TODO: Remove this check but use a separate verification pass to verify the
// invariants expected by later passes.
auto isValidLinalgType = [](Type type) {
// Types that are not a ValueTensorType are accepted unconditionally
// (`!tensor` short-circuits the check).
auto tensor = type.dyn_cast<ValueTensorType>();
// NOTE(review): the right-hand operand of `||` and the lambda's closing
// `};` are not present in this copy of the file. Presumably the missing
// operand tests that the tensor is ranked -- confirm against upstream.
return !tensor ||
bool valid = llvm::all_of(op->getOperandTypes(), isValidLinalgType) &&
llvm::all_of(op->getResultTypes(), isValidLinalgType);
if (!valid)
return rewriter.notifyMatchFailure(op, "type cannot be lowered to linalg");
return success();
// NOTE(review): the function's closing brace is missing in this copy.
// Rejects values whose type denotes an absent/optional argument. Reports
// "unimplemented None type arg" via `rewriter.notifyMatchFailure` when the
// type of `v` is one of the checked None-like types; returns success()
// otherwise.
static LogicalResult checkNotNone(PatternRewriter &rewriter, Operation *op,
Value v) {
Type type = v.getType();
// NOTE(review): the condition ends with a dangling `||`; its last operand
// and the closing `)` are missing in this copy of the file. Restore them from
// upstream before building.
if (type.isa<OptionalType>() || type.isa<Torch::NoneType>() ||
return rewriter.notifyMatchFailure(op, "unimplemented None type arg");
return success();
// NOTE(review): the function's closing brace is missing in this copy.
// Generate IR: dim = dim >= 0 ? dim : dim + inputRank
//
// Normalizes a possibly negative dimension index at runtime.
//   b         - builder used to emit the ops.
//   loc       - location attached to every emitted op.
//   dim       - integer-typed dimension index (asserted).
//   inputRank - rank value that is added to `dim` when `dim` is negative; it
//               is combined with `dim` by arith.addi, so both are expected to
//               have the same integer type.
// Returns the result of the emitted select op.
static Value toPositiveDimDynamic(OpBuilder &b, Location loc, Value dim,
                                  Value inputRank) {
  assert(dim.getType().isa<IntegerType>() &&
         "dim arg of toPositiveDimDynamic must be integer type");
  Value dimAddInputRank = b.create<arith::AddIOp>(loc, dim, inputRank);
  Value cst0 =
      b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputRank.getType()));
  Value predDimGEZero =
      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, dim, cst0);
  return b.create<SelectOp>(loc, predDimGEZero, dim, dimAddInputRank);
}
// Generate IR: assert(dim >= 0 && dim < inputRank)
//
// Emits two runtime assertions (AssertOp) on `dim`: one that it is
// non-negative and one that it is strictly less than `inputRank`.
//   dim       - integer-typed dimension index (asserted at compile time of
//               the pattern).
//   inputRank - value compared against `dim` with a signed less-than.
static void assertIsValidDim(OpBuilder &b, Location loc, Value dim,
                             Value inputRank) {
  assert(dim.getType().isa<IntegerType>() &&
         "dim arg of assertIsValidDim must be integer type");

  // dim >= 0
  Value cst0 =
      b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputRank.getType()));
  Value predGEZero =
      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, dim, cst0);
  b.create<AssertOp>(loc, predGEZero,
                     b.getStringAttr("dim must be greater or equal to zero"));

  // dim < inputRank
  Value predLTInputRank =
      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, dim, inputRank);
  b.create<AssertOp>(loc, predLTInputRank,
                     b.getStringAttr("dim must be smaller than inputRank"));
}
// Hack to deal with Torch list type arguments, which are not supported end
// to end. Constant values can be extracted directly; non-constant list values
// are not supported.
// TODO: loosen this constraint once list types are properly supported.
//
// Returns true iff `value` matches `m_TorchConstantIntList` and the extracted
// integers are equal to `expects`, element for element (same length, same
// order). Returns false in every other case.
static bool isConstantIntListMatching(Value value,
                                      SmallVectorImpl<int64_t> &expects) {
  SmallVector<int64_t> intValues;
  if (!matchPattern(value, m_TorchConstantIntList(intValues)))
    return false;

  if (intValues.size() != expects.size())
    return false;

  for (auto it : llvm::zip(intValues, expects)) {
    if (std::get<0>(it) != std::get<1>(it))
      return false;
  }
  return true;
}
// Emits an arith.index_cast that converts the integer-typed value `v` to the
// builder's index type and returns the cast result. `v` must have an
// IntegerType (asserted).
static Value castIntToIndex(OpBuilder &b, Location loc, Value v) {
  assert(v.getType().isa<IntegerType>() && "must be called with integer type");
  return b.create<arith::IndexCastOp>(loc, b.getIndexType(), v);
}
// Emits an arith.index_cast that converts the index-typed value `idx` to i64
// and returns the cast result. `idx` must have IndexType (asserted).
static Value castIndexToInt(OpBuilder &b, Location loc, Value idx) {
  // The check is on IndexType, so the diagnostic must say "index type".
  assert(idx.getType().isa<IndexType>() && "must be called with index type");
  return b.create<arith::IndexCastOp>(loc, b.getI64Type(), idx);
}
// Returns the size of dimension `dim` of `v` as produced by
// `createOrFold<tensor::DimOp>`, i.e. either a newly created tensor.dim
// result or whatever value the builder folds it to.
static Value getDimOp(OpBuilder &b, Location loc, Value v, int dim) {
  return b.createOrFold<tensor::DimOp>(loc, v, dim);
}
// Emits a runtime assertion (AssertOp) that `lhsDim` equals `rhsDim`.
// Each operand may be either integer-typed or index-typed; index-typed
// operands are first cast to i64 with castIndexToInt so that both sides of
// the comparison are integers.
static void checkDimEqualHelper(OpBuilder &b, Location loc, Value lhsDim,
                                Value rhsDim) {
  Type lhsType = lhsDim.getType();
  Type rhsType = rhsDim.getType();

  // Compile-time sanity check on the operand types. The disjunction is
  // parenthesized explicitly so the string literal only serves as the
  // assertion message.
  auto checkIntOrIndex = [](Type type) {
    (void)type; // Unused when assertions are compiled out.
    assert((type.isa<IntegerType>() || type.isa<IndexType>()) &&
           "must be either integer or index type");
  };
  checkIntOrIndex(lhsType);
  checkIntOrIndex(rhsType);

  Value lhsDimInt = lhsType.isIndex() ? castIndexToInt(b, loc, lhsDim) : lhsDim;
  Value rhsDimInt = rhsType.isIndex() ? castIndexToInt(b, loc, rhsDim) : rhsDim;
  Value contractingDimEqual = b.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, lhsDimInt, rhsDimInt);
  b.create<AssertOp>(loc, contractingDimEqual,
                     b.getStringAttr("mismatching contracting dimension"));
}
// Returns the sizes of dimensions 0 through `dim` (inclusive) of `tensor`,
// each obtained with getDimOp. `tensor` must be a RankedTensorType value and
// `dim` must be smaller than its rank (asserted). A `dim` of -1 yields an
// empty vector because the loop body never runs.
static SmallVector<Value> getTensorSizesUntilDim(OpBuilder &b, Location loc,
                                                 Value tensor, int dim) {
  RankedTensorType type = tensor.getType().cast<RankedTensorType>();
  assert(dim < type.getRank() &&
         "The given dim must be smaller than tensor rank");
  (void)type; // Only read by the assertion above.

  SmallVector<Value> sizes;
  for (int i = 0; i <= dim; i++)
    sizes.push_back(getDimOp(b, loc, tensor, i));
  return sizes;
}
// Returns the sizes of all dimensions of the ranked tensor `tensor`, in
// order, by delegating to getTensorSizesUntilDim with the last dimension
// index (rank - 1).
static SmallVector<Value> getTensorSizes(OpBuilder &b, Location loc,
                                         Value tensor) {
  RankedTensorType type = tensor.getType().cast<RankedTensorType>();
  return getTensorSizesUntilDim(b, loc, tensor, type.getRank() - 1);
}
// Creates a tensor with the given `sizes` and element type `elemTy`
// (linalg.init_tensor) and fills it with the zero value of `elemTy`
// (linalg.fill). Returns the filled tensor.
static Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
                                  Type elemTy) {
  Value initTensor = b.create<linalg::InitTensorOp>(loc, sizes, elemTy);
  // The init tensor was created with `elemTy`, so the zero attribute can be
  // built from it directly.
  Value c0 = b.create<arith::ConstantOp>(loc, b.getZeroAttr(elemTy));
  return b.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
}
// Creates a tensor with required `sizes` and `elemTy` and fills it with
// `initElem`. Returns the result of the linalg.fill op.
static Value createInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
                              Type elemTy, Value initElem) {
  Value initTensor = b.create<linalg::InitTensorOp>(loc, sizes, elemTy);
  return b.create<linalg::FillOp>(loc, initElem, initTensor).getResult(0);
}
// Creates a constant of type `elemType` with value `val`.
//
// Supported element types are float, index and integer types. Returns a null
// Value when `elemType` is none of these, so callers must check the result.
static Value getConstant(OpBuilder &b, Location loc, int64_t val,
                         Type elemType) {
  Attribute attr;
  if (elemType.isa<mlir::FloatType>()) {
    attr = b.getFloatAttr(elemType, static_cast<double>(val));
  } else if (elemType.isa<mlir::IndexType>()) {
    attr = b.getIndexAttr(val);
  } else if (auto intType = elemType.dyn_cast<mlir::IntegerType>()) {
    // `val` is a signed quantity: build the APInt as signed so that negative
    // values are sign-extended correctly for widths larger than 64 bits.
    attr = b.getIntegerAttr(
        elemType, APInt(intType.getWidth(), val, /*isSigned=*/true));
  }

  if (!attr)
    return nullptr;
  return b.create<arith::ConstantOp>(loc, elemType, attr);
}
// Helper function to calculate the output tensor dims for convolution-like
// ops. Along each dim:
// dim_out =
//  floor((dim_in + 2 * padding - dilation * (kernelSize - 1) - 1) / stride) + 1
//
//   in            - input dimension size; cast from index to i64 internally.
//   paddingInt, dilationInt, kernelSizeInt, strideInt
//                 - integer values combined with the i64 constants 1 and 2
//                   created here.
// Returns dim_out cast back to index type.
static Value getOutputDimForConvOps(OpBuilder &b, Location loc, Value in,
                                    Value paddingInt, Value dilationInt,
                                    Value kernelSizeInt, Value strideInt) {
  Value c1 = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(1));
  Value c2 = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(2));

  // in + 2 * padding
  Value doublePadding = b.create<arith::MulIOp>(loc, paddingInt, c2);
  Value inAddDoublePadding =
      b.create<arith::AddIOp>(loc, castIndexToInt(b, loc, in), doublePadding);

  // dilation * (kernelSize - 1)
  Value kernelSizeSub1 = b.create<arith::SubIOp>(loc, kernelSizeInt, c1);
  Value dilationTimesKernelSizeSub1 =
      b.create<arith::MulIOp>(loc, dilationInt, kernelSizeSub1);

  // (in + 2 * padding) - dilation * (kernelSize - 1) - 1
  Value temp = b.create<arith::SubIOp>(loc, inAddDoublePadding,
                                       dilationTimesKernelSizeSub1);
  Value dividend = b.create<arith::SubIOp>(loc, temp, c1);

  // floor(dividend / stride) + 1
  Value division = b.create<arith::FloorDivSIOp>(loc, dividend, strideInt);
  Value out = b.create<arith::AddIOp>(loc, division, c1);
  return castIntToIndex(b, loc, out);
}
// Materializes every entry of `ints` as an i64 arith.constant and returns the
// resulting values in the same order.
static SmallVector<Value>
getAsConstantIntValues(OpBuilder &b, Location loc,
                       SmallVectorImpl<int64_t> &ints) {
  return llvm::to_vector<4>(llvm::map_range(ints, [&](int64_t val) -> Value {
    return b.create<arith::ConstantOp>(loc,
                                       b.getIntegerAttr(b.getI64Type(), val));
  }));
}
// Materializes every entry of `ints` as an index-typed arith.constant and
// returns the resulting values in the same order.
static SmallVector<Value>
getAsConstantIndexValues(OpBuilder &b, Location loc,
                         SmallVectorImpl<int64_t> &ints) {
  return llvm::to_vector<4>(llvm::map_range(ints, [&](int64_t val) -> Value {
    return b.create<arith::ConstantOp>(loc, b.getIndexAttr(val));
  }));
}
// Converts every entry of `ints` to an OpFoldResult holding an index
// attribute. No ops are created; `loc` is unused and kept only so the
// signature parallels the other getAs* helpers.
static SmallVector<OpFoldResult>
getAsOpFoldResult(OpBuilder &b, Location loc, SmallVectorImpl<int64_t> &ints) {
  return llvm::to_vector<4>(llvm::map_range(
      ints, [&](int64_t val) -> OpFoldResult { return b.getIndexAttr(val); }));
}
// This is a temporary solution to deal with types that are not fully supported
// like list, dict. For those container types, this helper can be used to
// convert their elements to valid target type.
// TODO: remove this when list gets full support.
//
// For each value in `vs`, asks `converter` for the converted type and
// materializes a target conversion of the value to that type. Returns the
// materialized values in the same order as `vs`.
static SmallVector<Value> getTypeConvertedValues(OpBuilder &b, Location loc,
                                                 TypeConverter *converter,
                                                 SmallVectorImpl<Value> &vs) {
  return llvm::to_vector<4>(llvm::map_range(vs, [&](Value v) -> Value {
    return converter->materializeTargetConversion(
        b, loc, converter->convertType(v.getType()), v);
  }));
}
// Helper function to get the padding tensor given the padding int values.
//
//   op              - only used to obtain the location for the new ops.
//   input           - value of RankedTensorType to pad.
//   lowPaddingInts  - static padding amounts on the low end, one per dim.
//   highPaddingInts - static padding amounts on the high end, one per dim.
//   pad             - scalar value used to fill the padded region.
// Returns the padded tensor created by tensor::createPadScalarOp.
static Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
                             SmallVectorImpl<int64_t> &lowPaddingInts,
                             SmallVectorImpl<int64_t> &highPaddingInts,
                             Value pad) {
  Location loc = op->getLoc();
  // The result type is inferred from the source type and both static padding
  // lists.
  Type rankedTensorType = tensor::PadOp::inferResultType(
      input.getType().cast<RankedTensorType>(), lowPaddingInts,
      highPaddingInts);
  SmallVector<OpFoldResult> lowPaddings =
      getAsOpFoldResult(b, loc, lowPaddingInts);
  SmallVector<OpFoldResult> highPaddings =
      getAsOpFoldResult(b, loc, highPaddingInts);
  return tensor::createPadScalarOp(
      rankedTensorType, input, pad, /*low=*/lowPaddings, /*high=*/highPaddings,
      /*packing=*/false, loc, b);
}
// Helper function to get the padding tensor given the padding int values.
// It's assumed that the padding on the low end and high end are the same,
// and that zero padding is required.
//
// `input` must be a RankedTensorType value (asserted). The pad value is the
// zero constant of the input's element type; the actual padding is delegated
// to the low/high overload with `paddingInts` used for both ends.
static Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
                             SmallVectorImpl<int64_t> &paddingInts) {
  assert(input.getType().isa<RankedTensorType>() &&
         "input must be RankedTensorType");
  Location loc = op->getLoc();
  Value c0 = b.create<arith::ConstantOp>(
      loc,
      b.getZeroAttr(input.getType().cast<RankedTensorType>().getElementType()));
  return getPaddedTensor(op, b, input, paddingInts, paddingInts, c0);
}
// Emits scalar IR computing the CDF of a normal distribution at `x`:
//   cdf(x) = 0.5 * (1 + erf((x - mean) / (sigma * sqrt(2))))
//
// `x`, `mean` and `sigma` are scalar floating-point values of the same type;
// all constants are created with the type of `x`. Returns the CDF value.
static Value buildNormalCdf(OpBuilder &b, Location &loc, Value x, Value mean,
                            Value sigma) {
  Type elementType = x.getType();
  Value xMinusMean = b.create<arith::SubFOp>(loc, x, mean);
  Value two = b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 2));
  Value sqrt2 = b.create<math::SqrtOp>(loc, two);
  // Scale by sigma so the result is correct for non-unit standard deviations;
  // for sigma == 1 this multiplies by one and leaves the value unchanged.
  Value sigmaTimesSqrt2 = b.create<arith::MulFOp>(loc, sigma, sqrt2);
  Value erfArg = b.create<arith::DivFOp>(loc, xMinusMean, sigmaTimesSqrt2);
  Value erf = b.create<math::ErfOp>(loc, erfArg);
  Value one = b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1));
  Value erfPlus1 = b.create<arith::AddFOp>(loc, one, erf);
  Value oneHalf =
      b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.5));
  return b.create<arith::MulFOp>(loc, oneHalf, erfPlus1);
}
// Emits scalar IR computing the CDF of the unit normal distribution at `x`
// by calling buildNormalCdf with mean = 0 and sigma = 1, both created as
// constants of the type of `x`.
static Value buildUnitNormalCdf(OpBuilder &b, Location &loc, Value x) {
  Type elementType = x.getType();
  Value zero = b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0));
  Value one = b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1));
  return buildNormalCdf(b, loc, x, zero, one);
}
namespace {
// Conversion pattern for AtenAdaptiveAvgPool2dOp.
//
// Only handles a rank-4 floating-point input whose `output_size` is the
// constant list [1, 1]. The lowering visible here is done in two stages:
//   1. accumulate the input over two "reduction" iterators into a
//      zero-filled tensor of shape (N, C);
//   2. divide each sum by H*W while writing into a tensor of shape
//      (N, C, 1, 1).
// The final value is cast to the converted result type with tensor.cast.
//
// NOTE(review): this copy of the file has lost several lines of this class
// (access specifier, the return type of matchAndRewrite, the builder calls
// that consume the argument lists below, and most closing braces). The notes
// inline mark each gap; restore them from upstream before building.
class ConvertAtenAdaptiveAvgPool2dOp
: public OpConversionPattern<AtenAdaptiveAvgPool2dOp> {
using OpConversionPattern::OpConversionPattern;
// NOTE(review): the return type of matchAndRewrite is missing here.
matchAndRewrite(AtenAdaptiveAvgPool2dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
MLIRContext *context = op->getContext();
Value input = adaptor.self(); /* in form of N*C*H*W */
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
Type elementType = inputType.getElementType();
// NOTE(review): this check emits a hard error via op.emitError, whereas the
// other checks in this function use rewriter.notifyMatchFailure.
if (!elementType.isa<mlir::FloatType>())
return op.emitError("unimplemented: non-floating point type");
auto inputRank = inputType.getRank();
if (inputRank != 4)
return rewriter.notifyMatchFailure(op, "input should be rank 4");
SmallVector<int64_t, 2> expects{1, 1};
// Pattern match against the op's original operands, because otherwise we
// will get the lowered version of the operands which is harder to pattern
// match.
if (!isConstantIntListMatching(op.output_size(), expects))
return rewriter.notifyMatchFailure(
op, "only support output_size with H and W both equal to constant 1");
// Zero-filled (N, C) accumulator for the sum over the two reduced dims.
Value N = getDimOp(rewriter, loc, input, 0);
Value C = getDimOp(rewriter, loc, input, 1);
Value initTensor = rewriter.create<linalg::InitTensorOp>(
loc, ValueRange{N, C}, elementType);
Value c0 = rewriter.create<arith::ConstantOp>(
loc, FloatAttr::get(elementType, 0.0));
Value initTensor0 =
rewriter.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
// Affine map selecting (d0, d1), used for the (N, C) tensor.
SmallVector<AffineExpr, 2> ncExprs;
ncExprs.push_back(mlir::getAffineDimExpr(0, context));
ncExprs.push_back(mlir::getAffineDimExpr(1, context));
// NOTE(review): the leading argument(s) of AffineMap::get before
// `symbolCount` are missing in this copy.
auto ncIndexingMap = AffineMap::get(
/*symbolCount=*/0, ncExprs, context);
SmallVector<AffineMap, 2> indexingMaps = {
rewriter.getMultiDimIdentityMap(4), // input
ncIndexingMap, // output
// NOTE(review): the closing `};` of the initializer list is missing here.
SmallVector<StringRef, 4> iteratorTypesSum{"parallel", "parallel",
"reduction", "reduction"};
// NOTE(review): the builder call that consumes the arguments below (and the
// use of `indexingMaps` / `iteratorTypesSum`) is missing in this copy.
Value sumPool2d = rewriter
loc, initTensor0.getType(), input, initTensor0,
[&](OpBuilder &b, Location loc, ValueRange args) {
// Body: accumulate the current input element into the running sum.
// NOTE(review): the add is created through `rewriter` rather than the
// region builder `b` passed to this callback.
Value input = args[0], sum = args[1];
Value result = rewriter.create<arith::AddFOp>(
loc, sum, input);
b.create<linalg::YieldOp>(loc, result);
// Calculate H*W so that avg can be got from sum / (H*W)
Value H = getDimOp(rewriter, loc, input, 2);
Value W = getDimOp(rewriter, loc, input, 3);
// Local helper casting an index value to i64; it shadows the file-level
// castIndexToInt inside this function.
auto castIndexToInt = [&](Value v) {
return rewriter.create<arith::IndexCastOp>(
loc, IntegerType::get(context, 64), v);
// NOTE(review): the second operand of the multiplication and the closing
// `);` are missing in this copy.
Value HtimesW = rewriter.create<arith::MulIOp>(loc, castIndexToInt(H),
Value HtimesWf =
rewriter.create<arith::SIToFPOp>(loc, elementType, HtimesW);
// Output tensor of shape (N, C, 1, 1).
Value c1Index = rewriter.create<arith::ConstantIndexOp>(loc, /*value=*/1);
Value outputTensor = rewriter.create<linalg::InitTensorOp>(
loc, ValueRange{N, C, c1Index, c1Index}, elementType);
SmallVector<AffineMap, 2> indexingMapsAvg{
ncIndexingMap, rewriter.getMultiDimIdentityMap(4)};
SmallVector<StringRef, 4> iteratorTypesAvg(4, "parallel");
// NOTE(review): the builder call that consumes the arguments below is
// missing in this copy.
Value avgPool2d =
loc, outputTensor.getType(), sumPool2d, outputTensor,
[&](OpBuilder &b, Location loc, ValueRange args) {
// Body: divide the accumulated sum by H*W (as a float).
Value avg = b.create<arith::DivFOp>(loc, args[0], HtimesWf);
b.create<linalg::YieldOp>(loc, avg);
// Replace the op with a cast of the result to the converted result type.
Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, avgPool2d);
return success();
} // namespace
namespace {
// Conversion pattern for AtenConv2dOp.
//
// As far as this copy shows, the lowering:
//   - requires a floating-point element type and constant int lists for
//     padding, stride and dilation;
//   - emits a runtime assertion that `groups` equals 1;
//   - zero-pads the input on the spatial dims (padding list prefixed with
//     {0, 0} for the two leading dims);
//   - computes the output spatial sizes with getOutputDimForConvOps;
//   - initializes the output either with zeros (bias is None) or with the
//     rank-1 bias indexed by output dimension 1;
//   - builds the convolution from the padded input, the weight and the
//     initialized output, then casts to the converted result type.
//
// NOTE(review): this copy of the file has lost several lines of this class
// (access specifier, the return type of matchAndRewrite, some builder calls,
// trailing arguments and most closing braces). The notes inline mark each
// gap; restore them from upstream before building.
class ConvertAtenConv2dOp : public OpConversionPattern<AtenConv2dOp> {
using OpConversionPattern::OpConversionPattern;
// NOTE(review): the return type of matchAndRewrite is missing here.
matchAndRewrite(AtenConv2dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
MLIRContext *context = op->getContext();
Value input = adaptor.input(); /* in form of N*C*H*W */
Value weight = adaptor.weight(); /* in form of F*C*H*W */
Value groups = adaptor.groups();
// NOTE(review): the initializer of `elementType` is missing in this copy.
Type elementType =
// NOTE(review): this check emits a hard error via op.emitError, whereas the
// other checks in this function use rewriter.notifyMatchFailure.
if (!elementType.isa<mlir::FloatType>())
return op.emitError("unimplemented: non-floating point type");
Type intType = IntegerType::get(context, 64);
// Local helper casting an index value to i64; it shadows the file-level
// castIndexToInt inside this function.
auto castIndexToInt = [&](Value v) {
return rewriter.create<arith::IndexCastOp>(loc, intType, v);
// Dynamic sizes read from the input (dims 0, 2, 3) and the weight
// (dims 0, 2, 3).
Value N = getDimOp(rewriter, loc, input, 0);
Value Hin = getDimOp(rewriter, loc, input, 2);
Value Win = getDimOp(rewriter, loc, input, 3);
Value F = getDimOp(rewriter, loc, weight, 0);
Value weightH = getDimOp(rewriter, loc, weight, 2);
Value weightW = getDimOp(rewriter, loc, weight, 3);
// Pattern match against the op's original operands, because otherwise we
// will get the lowered version of the operands which is harder to pattern
// match.
SmallVector<int64_t> paddingInts;
if (!matchPattern(op.padding(), m_TorchConstantIntList(paddingInts))) {
return rewriter.notifyMatchFailure(
op, "only support constant padding values");
SmallVector<int64_t, 2> strideInts;
if (!matchPattern(op.stride(), m_TorchConstantIntList(strideInts)))
return rewriter.notifyMatchFailure(op,
"only support constant int strides");
SmallVector<int64_t, 2> dilationInts;
if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts)))
return rewriter.notifyMatchFailure(op,
"only support constant int dilations");
// Runtime guard: only groups == 1 is supported.
Value c1 =
rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(intType, 1));
Value groupEqual1 = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, groups, c1);
rewriter.create<AssertOp>(loc, groupEqual1,
rewriter.getStringAttr("expect groups to be 1"));
// Pad the input tensor according to padding.
SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
// NOTE(review): the last argument of insert() and the closing `);` are
// missing in this copy.
paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(),
Value paddedInput =
getPaddedTensor(op, rewriter, input, paddingIncludingNC);
SmallVector<Value> paddingIntValues =
getAsConstantIntValues(rewriter, loc, paddingInts);
SmallVector<Value> dilationIntValues =
getAsConstantIntValues(rewriter, loc, dilationInts);
SmallVector<Value> strideIntValues =
getAsConstantIntValues(rewriter, loc, strideInts);
// Output spatial sizes; index 0 of each list is used for H, index 1 for W.
Value Hout = getOutputDimForConvOps(
rewriter, loc, Hin, paddingIntValues[0], dilationIntValues[0],
castIndexToInt(weightH), strideIntValues[0]);
Value Wout = getOutputDimForConvOps(
rewriter, loc, Win, paddingIntValues[1], dilationIntValues[1],
castIndexToInt(weightW), strideIntValues[1]);
Value initTensor = rewriter.create<linalg::InitTensorOp>(
loc, ValueRange{N, F, Hout, Wout}, elementType);
Value bias = adaptor.bias();
Value biasInitTensor;
if (bias.getType().isa<Torch::NoneType>()) {
// No bias: start the accumulation from a zero-filled output tensor.
Value c0float = rewriter.create<arith::ConstantOp>(
loc, FloatAttr::get(elementType, 0.0));
// NOTE(review): the end of this statement is missing in this copy.
biasInitTensor = rewriter.create<linalg::FillOp>(loc, c0float, initTensor)
} else {
auto biasType = bias.getType().cast<RankedTensorType>();
if (biasType.getRank() != 1)
return rewriter.notifyMatchFailure(op, "expect bias to be rank 1");
if (elementType != biasType.getElementType())
return rewriter.notifyMatchFailure(op, "unimplemented: type promotion");
auto resultRank = initTensor.getType().cast<RankedTensorType>().getRank();
SmallVector<AffineMap> indexingMaps = {
// bias is used to initialize the channels - dimension 1 of output
AffineMap::get(/*dimCount=*/resultRank, /*symbolCount=*/0,
rewriter.getAffineDimExpr(1), context),
// NOTE(review): the remaining entries and the closing `};` of
// `indexingMaps` are missing in this copy.
SmallVector<StringRef> iteratorTypes(resultRank, "parallel");
// NOTE(review): the builder call that consumes the arguments below is
// missing in this copy.
biasInitTensor = rewriter
loc, initTensor.getType(), bias, initTensor,
indexingMaps, iteratorTypes,
[](OpBuilder &b, Location loc, ValueRange args) {
// Body: yield the bias element unchanged.
b.create<linalg::YieldOp>(loc, args[0]);
auto stridesAttr = rewriter.getI64VectorAttr(strideInts);
auto dilationAttr = rewriter.getI64VectorAttr(dilationInts);
// NOTE(review): the builder call that creates the convolution op from the
// arguments below is missing in this copy, as is the end of the statement.
Value conv2d =
loc, biasInitTensor.getType(), ValueRange{paddedInput, weight},
biasInitTensor, stridesAttr, dilationAttr)
// Replace the op with a cast of the result to the converted result type.
Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv2d);
return success();
} // namespace
// Normalization formula:
//   ((input - mean) / sqrt(var + eps)) * weight + bias
//
// Emits the scalar payload for normalization ops. The division by
// sqrt(var + eps) is implemented as a multiplication by rsqrt(var + eps).
//   elemTy - element type the computation is carried out in.
//   eps    - epsilon value; truncated to `elemTy` when its type differs.
// Returns the normalized scalar value.
static Value createLinalgPayloadCalculationForNormOps(
    OpBuilder &b, Location loc, Type elemTy, Value input, Value mean, Value var,
    Value eps, Value weight, Value bias) {
  Value inputSubMean = b.create<arith::SubFOp>(loc, input, mean);
  // The eps is always f64. arith.truncf requires a strictly narrower result
  // type, so only truncate when `elemTy` is not already the type of eps.
  Value epsInElemTy = eps;
  if (eps.getType() != elemTy)
    epsInElemTy = b.create<arith::TruncFOp>(loc, elemTy, eps);
  Value varPlusEps = b.create<arith::AddFOp>(loc, var, epsInElemTy);
  Value rSTD = b.create<math::RsqrtOp>(loc, varPlusEps);
  Value temp = b.create<arith::MulFOp>(loc, inputSubMean, rSTD);
  Value timesWeight = b.create<arith::MulFOp>(loc, temp, weight);
  return b.create<arith::AddFOp>(loc, timesWeight, bias);
}
// Emits the scalar payload for gather-like ops inside a linalg body: builds
// the index list used to read from `input`, asserts at runtime that `index`
// lies in [0, input.sizes[dim]), extracts the element and yields it.
//   input      - tensor to read from.
//   inputRank  - rank of `input`.
//   index      - integer-typed position along `dim`.
//   dim        - dimension of `input` that is indexed by `index`.
//   outputRank - rank of the iteration space / output.
static void createLinalgPayloadCalculationForGatherOps(
    OpBuilder &b, Location loc, Value input, int64_t inputRank, Value index,
    int64_t dim, int64_t outputRank) {
  SmallVector<Value> indices;
  for (int i = 0; i < inputRank; i++) {
    if (i == dim) {
      indices.push_back(castIntToIndex(b, loc, index));
    } else {
      // `outputRank` might be larger than `inputRank`. The `linalg::IndexOp`
      // takes in the dimension of the output. Add `inputDimOffset` to map to
      // the correct dimension of the output for dimensions larger than the
      // given `dim`.
      int64_t inputDimOffset = i < dim ? 0 : outputRank - inputRank;
      indices.push_back(b.create<linalg::IndexOp>(loc, i + inputDimOffset));
    }
  }

  // Assert index < input.sizes[dim]
  Value indexLTInputDim = b.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::slt, index,
      castIndexToInt(b, loc, getDimOp(b, loc, input, dim)));
  b.create<AssertOp>(loc, indexLTInputDim,
                     b.getStringAttr("index must be smaller than dim size"));

  // Assert index >= 0
  Value cst0 = b.create<arith::ConstantOp>(loc, b.getZeroAttr(index.getType()));
  Value indexGEThanZero =
      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, index, cst0);
  b.create<AssertOp>(loc, indexGEThanZero,
                     b.getStringAttr("index must be larger or equal to 0"));

  Value extract = b.create<tensor::ExtractOp>(loc, input, indices);
  b.create<linalg::YieldOp>(loc, extract);
}
namespace {
// Conversion pattern for AtenBatchNormOp.
//
// As far as this copy shows, the lowering:
//   - rejects None weight / bias / running_mean / running_var;
//   - requires an input of rank > 2 and rank-1 weight, bias, running_mean
//     and running_var;
//   - compares `training` against a constant false (training is not
//     supported);
//   - builds an all-parallel computation over the input where weight, bias,
//     running mean and running var are indexed by dimension 1, with the
//     payload computed by createLinalgPayloadCalculationForNormOps;
//   - casts the result to the converted result type.
//
// NOTE(review): this copy of the file has lost several lines of this class
// (access specifier, the return type of matchAndRewrite, some builder calls,
// trailing arguments and most closing braces). The notes inline mark each
// gap; restore them from upstream before building.
class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
using OpConversionPattern::OpConversionPattern;
// NOTE(review): the return type of matchAndRewrite is missing here.
matchAndRewrite(AtenBatchNormOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MLIRContext *context = op->getContext();
Location loc = op->getLoc();
Value input = adaptor.input();
Value weight = adaptor.weight();
Value bias = adaptor.bias();
Value runningMean = adaptor.running_mean();
Value runningVar = adaptor.running_var();
// NOTE(review): the initializer of `training` is missing in this copy.
Value training =;
Value eps = adaptor.eps();
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
// TODO: Handle the None cases for the optional parameters:
// weight, bias.
if (failed(checkNotNone(rewriter, op, weight)) ||
failed(checkNotNone(rewriter, op, bias)) ||
failed(checkNotNone(rewriter, op, runningMean)) ||
failed(checkNotNone(rewriter, op, runningVar)))
return failure();
auto inputType = input.getType().cast<RankedTensorType>();
auto weightType = weight.getType().cast<RankedTensorType>();
auto biasType = bias.getType().cast<RankedTensorType>();
auto runningMeanType = runningMean.getType().cast<RankedTensorType>();
auto runningVarType = runningVar.getType().cast<RankedTensorType>();
auto inputRank = inputType.getRank();
if (inputRank <= 2)
return rewriter.notifyMatchFailure(
op, "input should have rank larger than 2");
if (weightType.getRank() != 1 || biasType.getRank() != 1 ||
runningMeanType.getRank() != 1 || runningVarType.getRank() != 1) {
return rewriter.notifyMatchFailure(
op, "expect weight, bias, running_mean and running_var to be rank 1");
// TODO: Add support for training.
// Compare `training` with an i1 constant 0 (false).
auto constFalse = rewriter.create<arith::ConstantOp>(
loc, IntegerAttr::get(IntegerType::get(context, 1), 0));
auto trainingFalse = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, training, constFalse);
// NOTE(review): the call that consumes the arguments below is missing in
// this copy; presumably a runtime assertion on `trainingFalse` -- confirm.
loc, trainingFalse,
rewriter.getStringAttr("training is not supported for now"));
// num_features C from an expected input of size (N,C,D,H,W ...)
Value numFeatures = rewriter.create<tensor::DimOp>(loc, input, 1);
// Helper comparing dim 0 of `v` with `numFeatures`.
// NOTE(review): no call to this lambda is visible in this copy, and the
// call that consumes `dim0Equal` and the message below is also missing.
auto contractingDim0EqualsNumFeatures = [&](Value v) {
auto dim0 = rewriter.create<tensor::DimOp>(loc, v, 0);
auto dim0Equal = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, numFeatures, dim0);
loc, dim0Equal,
"expect the size of dim 0 equal to the number of features"));
// Affine map selecting d1, shared by the four rank-1 operands.
// NOTE(review): the leading argument(s) of AffineMap::get before
// `symbolCount` are missing in this copy.
auto indexingMap = AffineMap::get(
/*symbolCount=*/0, rewriter.getAffineDimExpr(1), context);
SmallVector<AffineMap> indexingMaps = {
rewriter.getMultiDimIdentityMap(inputRank), // input
indexingMap, // weight
indexingMap, // bias
indexingMap, // runningMean
indexingMap, // runningVar
rewriter.getMultiDimIdentityMap(inputRank), // output
// NOTE(review): the closing `};` of the initializer list is missing here.
SmallVector<StringRef> iteratorTypes(inputRank, "parallel");
// NOTE(review): the builder call that consumes the arguments below is
// missing in this copy. `input` is passed both as an operand and in the
// output position.
Value batchNorm =
loc, input.getType(),
ValueRange{input, weight, bias, runningMean, runningVar}, input,
[&](OpBuilder &b, Location loc, ValueRange args) {
// Body: normalize one element; the type of `var` is passed as the
// element type for the payload computation.
Value input = args[0], weight = args[1], bias = args[2],
mean = args[3], var = args[4];
// NOTE(review): the last argument of this call and the closing `);` are
// missing in this copy.
Value result = createLinalgPayloadCalculationForNormOps(
b, loc, var.getType(), input, mean, var, eps, weight,
b.create<linalg::YieldOp>(loc, result);
// Replace the op with a cast of the result to the converted result type.
Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, batchNorm);
return success();
} // namespace
// For layernorm, the mean and standard deviation are calculated separately over
// the trailing dimensions of the input, which must have the shape specified by
// normalized_shape.
// The shapes of the different parts are as follows:
// +-------------------+--------------------+
// | meanAndVarShape | normalizedShape |
// +-------------------+--------------------+
// <------------+ inputShape +-------------->
// There are the following steps:
// Step 1. Check if all the arguments meet the requirements.
// Step 2. Common parts to be used for getting mean and var.
// This includes elements count, affineMap and iteratorTypes.
// Step 3. Get mean.
// Step 4. Get var.
// Step 5. Get layernorm.
namespace {
class ConvertAtenNativeLayerNormOp
: public OpConversionPattern<AtenNativeLayerNormOp> {
using OpConversionPattern::OpConversionPattern;
matchAndRewrite(AtenNativeLayerNormOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MLIRContext *context = op->getContext();
Location loc = op->getLoc();
Value input = adaptor.input();
Value weight = adaptor.weight();
Value bias = adaptor.bias();
Value eps = adaptor.eps();
Value normalizedShape = op.normalized_shape();
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
// TODO: Handle the None cases for the optional parameters:
// weight, bias.
if (failed(checkNotNone(rewriter, op, weight)) ||
failed(checkNotNone(rewriter, op, bias)))
return failure();
auto inputType = input.getType().cast<RankedTensorType>();
auto weightType = weight.getType().cast<RankedTensorType>();
auto biasType = bias.getType().cast<RankedTensorType>();
int64_t inputRank = inputType.getRank();
Type elemTy = inputType.getElementType();
// Step 1. Check if all the arguments meet the requirements.
SmallVector<Value> normalizedShapeSizesTorchInt;
if (!getListConstructElements(normalizedShape,
normalizedShapeSizesTorchInt)) {
return rewriter.notifyMatchFailure(op,
"Unimplemented normalized_shape not"
"constructed from ListConstruct");
SmallVector<Value> normalizedShapeSizesInt = getTypeConvertedValues(
rewriter, loc, getTypeConverter(), normalizedShapeSizesTorchInt);
int64_t normalizedShapeRank = normalizedShapeSizesInt.size();
if (weightType.getRank() != normalizedShapeRank ||
biasType.getRank() != normalizedShapeRank ||
inputRank < normalizedShapeRank || normalizedShapeRank < 1)
return rewriter.notifyMatchFailure(op, "Input or weight or bias shape or"
"normalized shape not compatible");
// Check all the dimensions match the normalized_shape
int64_t meanAndVarShapeRank = inputRank - normalizedShapeSizesInt.size();
for (auto en : enumerate((normalizedShapeSizesInt))) {
auto index = en.index();
auto inputDim =
getDimOp(rewriter, loc, input, index + meanAndVarShapeRank);
auto weightDim = getDimOp(rewriter, loc, weight, index);
auto biasDim = getDimOp(rewriter, loc, bias, index);
auto expectedSize = en.value();
checkDimEqualHelper(rewriter, loc, inputDim, expectedSize);
checkDimEqualHelper(rewriter, loc, weightDim, expectedSize);
checkDimEqualHelper(rewriter, loc, biasDim, expectedSize);
// Get iterator types for input shape.
SmallVector<StringRef> normalizedShapeIteratorTypes(
normalizedShapeRank, getReductionIteratorTypeName());
SmallVector<StringRef> meanAndVarIterationTypes(
meanAndVarShapeRank, getParallelIteratorTypeName());
SmallVector<StringRef> inputShapeIteratorTypes = meanAndVarIterationTypes;
// Step 2. Common parts to be used for getting mean and var.
// Get sizes and affineMaps needed for mean and var.
AffineMap inputShapeAffineMap = rewriter.getMultiDimIdentityMap(inputRank);
SmallVector<AffineExpr> meanAndVarShapeExprs;
for (int i = 0; i < meanAndVarShapeRank; i++)
meanAndVarShapeExprs.push_back(mlir::getAffineDimExpr(i, context));
auto meanAndVarShapeAffineMap = AffineMap::get(
/*symbolCount=*/0, meanAndVarShapeExprs, context);
SmallVector<Value> meanAndVarShapeSizes =
getTensorSizesUntilDim(rewriter, loc, input, meanAndVarShapeRank - 1);
// Get number of elements to be used for calculating mean and var.
Value elemCnts = normalizedShapeSizesInt[0];
for (int i = 1; i < normalizedShapeRank; i++) {
elemCnts = rewriter.create<arith::MulIOp>(loc, elemCnts,
Value elemCntsFloat =
rewriter.create<arith::SIToFPOp>(loc, elemTy, elemCnts);
// Helper to calculate mean and var.
auto genMeanOrVarCalculation = [&](Value sumOrSquareSum) {
SmallVector<AffineMap> indexingMaps(
2, rewriter.getMultiDimIdentityMap(meanAndVarShapeRank));
Value initShapeTensor = rewriter.create<linalg::InitTensorOp>(
loc, meanAndVarShapeSizes, elemTy);
return rewriter
loc, initShapeTensor.getType(), sumOrSquareSum, initShapeTensor,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value sumOrSqureSum = args[0];
Value result =
b.create<arith::DivFOp>(loc, sumOrSqureSum, elemCntsFloat);
b.create<linalg::YieldOp>(loc, result);
// Step 3. Get mean.
// Get sum to be used for calculating mean.
SmallVector<AffineMap, 2> sumIndexingMaps = {
inputShapeAffineMap, // input
meanAndVarShapeAffineMap, // output
auto initSumTensor =
createZeroInitTensor(rewriter, loc, meanAndVarShapeSizes, elemTy);
Value sum = rewriter
loc, initSumTensor.getType(), input, initSumTensor,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value input = args[0], sum = args[1];
Value result =
rewriter.create<arith::AddFOp>(loc, sum, input);
b.create<linalg::YieldOp>(loc, result);
Value mean = genMeanOrVarCalculation(sum);
// Step 4. Get var.
// Calculate squareSum for the layer.
SmallVector<AffineMap> squareSumIndexingMaps{
auto initSquareSumTensor =
createZeroInitTensor(rewriter, loc, meanAndVarShapeSizes, elemTy);
Value squareSum =
loc, initSquareSumTensor.getType(), ValueRange{input, mean},
[&](OpBuilder &b, Location loc, ValueRange args) {
Value input = args[0], mean = args[1], squareSum = args[2];
Value sub = rewriter.create<arith::SubFOp>(loc, input, mean);
Value square = rewriter.create<arith::MulFOp>(loc, sub, sub);
Value result =
rewriter.create<arith::AddFOp>(loc, squareSum, square);
b.create<linalg::YieldOp>(loc, result);
Value var = genMeanOrVarCalculation(squareSum);
// Step 5. Get layernorm.
// Get affineMap for normalized shape.
SmallVector<AffineExpr> normalizedShapeExprs;
for (int i = meanAndVarShapeRank; i < inputRank; i++)
normalizedShapeExprs.push_back(mlir::getAffineDimExpr(i, context));
auto normalizedShapeAffineMap = AffineMap::get(
/*symbolCount=*/0, normalizedShapeExprs, context);
auto inputSizes = getTensorSizes(rewriter, loc, input);
Value initLayerNormTensor =
rewriter.create<linalg::InitTensorOp>(loc, inputSizes, elemTy);
SmallVector<AffineMap> indexingMaps(1, inputShapeAffineMap);
indexingMaps.resize(3, meanAndVarShapeAffineMap);
indexingMaps.resize(5, normalizedShapeAffineMap);
SmallVector<StringRef> layerNormIterationTypes(
inputRank, getParallelIteratorTypeName());
Value layerNorm =
loc, initLayerNormTensor.getType(),
ValueRange{input, mean, var, weight, bias}, initLayerNormTensor,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value input = args[0], mean = args[1], var = args[2],
weight = args[3], bias = args[4];
Value result = createLinalgPayloadCalculationForNormOps(
b, loc, elemTy, input, mean, var, eps, weight, bias);
b.create<linalg::YieldOp>(loc, result);
Type layerNormResultType = getTypeConverter()->convertType(op.getType(0));
Type meanResultType = getTypeConverter()->convertType(op.getType(1));
Type varResultType = getTypeConverter()->convertType(op.getType(2));
Value layerNorm_ =
rewriter.create<tensor::CastOp>(loc, layerNormResultType, layerNorm);
Value mean_ = rewriter.create<tensor::CastOp>(loc, meanResultType, mean);
Value var_ = rewriter.create<tensor::CastOp>(loc, varResultType, var);
rewriter.replaceOp(op, {layerNorm_, mean_, var_});
return success();
} // namespace
namespace {
/// Lowers `aten.mm` (rank-2 by rank-2 matrix multiplication) to
/// `linalg.matmul` accumulating into a zero-filled init tensor.
///
/// The pattern bails out with a match failure when the operands are not both
/// rank 2. A runtime assertion guards the contracting dimension
/// (`lhs.dim(1) == rhs.dim(0)`).
class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenMmOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Value lhs = adaptor.self();
    Value rhs = adaptor.mat2();

    // A user can write an erroneous program where `aten.mm` is in fact called
    // with operands of invalid rank or dtype. We cannot convert to linalg in
    // this case or we will get a verifier error, which corresponds to breaking
    // of *internal* compiler invariants, and for a user manifests as a compiler
    // crash in the worst case (such as we try to canonicalize/fold/print the
    // invalid op before the verifier gets to see it -- also release builds of a
    // mature compiler usually have the verifier turned off for compile time
    // reasons).
    //
    // The compiler cannot crash even if the user wrote an erroneous program!
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    if (lhs.getType().cast<RankedTensorType>().getRank() != 2 ||
        rhs.getType().cast<RankedTensorType>().getRank() != 2) {
      return rewriter.notifyMatchFailure(
          op, "expected both operands to aten.mm to be rank 2");
    }

    Value lhsDim0 = rewriter.create<tensor::DimOp>(loc, lhs, 0);
    Value lhsDim1 = rewriter.create<tensor::DimOp>(loc, lhs, 1);
    Value rhsDim0 = rewriter.create<tensor::DimOp>(loc, rhs, 0);
    Value rhsDim1 = rewriter.create<tensor::DimOp>(loc, rhs, 1);

    // Emit a runtime check that the contracting dimensions agree.
    Value contractingDimEqual = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, lhsDim1, rhsDim0);
    rewriter.create<AssertOp>(
        loc, contractingDimEqual,
        rewriter.getStringAttr(
            "mismatching contracting dimension for aten.mm"));

    Type newResultType = getTypeConverter()->convertType(op.getType());
    Type elementType = newResultType.cast<TensorType>().getElementType();
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, ValueRange{lhsDim0, rhsDim1}, elementType);
    // Use the generic zero attribute so that a non-float element type does not
    // trip the assertion inside `FloatAttr::get`.
    Value c0 = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementType));
    Value zeroFill =
        rewriter.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
    Value matmul = rewriter
                       .create<linalg::MatmulOp>(loc, zeroFill.getType(),
                                                 ValueRange{lhs, rhs}, zeroFill)
                       .getResult(0);
    // When constructed with just dynamic sizes, InitTensorOp will have a result
    // type which has all `?`'s for dimensions, which might not be the result
    // type of `op`. The constraints on later linalg ops means that the result
    // of the MatmulOp will have this type too. So cast it to the desired type
    // so that in the end we have the original result type.
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
    return success();
  }
};
} // namespace
namespace {
/// Lowers `aten.matmul` to linalg for the rank combinations handled below:
///   * 1-D x 1-D  -> `linalg.dot`
///   * 1-D x 2-D  -> `linalg.vecmat`
///   * 2-D x 1-D  -> `linalg.matvec`
///   * N-D x N-D (N >= 3, equal ranks, no batch broadcasting)
///                -> `linalg.generic` implementing a batched matmul.
/// Any other rank combination is reported as a match failure.
class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenMatmulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Value lhs = adaptor.self();
    Value rhs = adaptor.other();

    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();

    unsigned lhsRank = lhs.getType().cast<RankedTensorType>().getRank();
    unsigned rhsRank = rhs.getType().cast<RankedTensorType>().getRank();

    Type newResultType = getTypeConverter()->convertType(op.getType());
    Type elementType = newResultType.cast<TensorType>().getElementType();

    // The different cases of torch_matmul op is mentioned here:
    // https://pytorch.org/docs/stable/generated/torch.matmul.html

    // First Case: Dot Product.
    if (lhsRank == 1 && rhsRank == 1) {
      Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
      Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);

      checkDimEqualHelper(rewriter, loc, lhsDim0, rhsDim0);

      Value zeroTensor = createZeroInitTensor(rewriter, loc, {}, elementType);
      Value dotProd =
          rewriter
              .create<linalg::DotOp>(loc, zeroTensor.getType(),
                                     ValueRange{lhs, rhs}, zeroTensor)
              .getResult(0);
      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, dotProd);
      return success();
    }

    // Second Case: Vec-Mat Multiplication.
    if (lhsRank == 1 && rhsRank == 2) {
      Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
      Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
      Value rhsDim1 = getDimOp(rewriter, loc, rhs, 1);
      checkDimEqualHelper(rewriter, loc, lhsDim0, rhsDim0);

      Value zeroTensor =
          createZeroInitTensor(rewriter, loc, ValueRange{rhsDim1}, elementType);
      Value matmul =
          rewriter
              .create<linalg::VecmatOp>(loc, zeroTensor.getType(),
                                        ValueRange{lhs, rhs}, zeroTensor)
              .getResult(0);
      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
      return success();
    }

    // Third Case: Matrix-Vec Multiplication.
    if (lhsRank == 2 && rhsRank == 1) {
      Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
      Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1);
      Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
      checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);

      Value zeroTensor =
          createZeroInitTensor(rewriter, loc, ValueRange{lhsDim0}, elementType);
      Value matmul =
          rewriter
              .create<linalg::MatvecOp>(loc, zeroTensor.getType(),
                                        ValueRange{lhs, rhs}, zeroTensor)
              .getResult(0);
      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
      return success();
    }

    // Fourth Case: Batch-Matrix Multiplication.
    // TODO: Broadcasting of batch dimension is remaining.
    if (lhsRank >= 3 && rhsRank >= 3 && lhsRank == rhsRank) {
      // Iteration space: (batch..., m, k, n) where the batch dims occupy
      // positions [0, batchRank) and m/k/n follow at batchRank + {0, 1, 2}.
      unsigned batchRank = lhsRank - 2;
      SmallVector<Value, 4> resultShape;

      SmallVector<AffineExpr> lhsExpr;
      SmallVector<AffineExpr> rhsExpr;
      SmallVector<AffineExpr> outExpr;
      SmallVector<StringRef> iteratorTypes;

      // Since broadcasting is a TODO, check whether the lhs and rhs batch
      // dimension match.
      for (unsigned i = 0; i < batchRank; i++) {
        Value lhsBatch = getDimOp(rewriter, loc, lhs, i);
        Value rhsBatch = getDimOp(rewriter, loc, rhs, i);
        resultShape.push_back(lhsBatch);
        lhsExpr.push_back(rewriter.getAffineDimExpr(i));
        rhsExpr.push_back(rewriter.getAffineDimExpr(i));
        outExpr.push_back(rewriter.getAffineDimExpr(i));
        iteratorTypes.push_back("parallel");
        checkDimEqualHelper(rewriter, loc, lhsBatch, rhsBatch);
      }

      Value lhsDim0 = getDimOp(rewriter, loc, lhs, batchRank);
      Value lhsDim1 = getDimOp(rewriter, loc, lhs, batchRank + 1);
      Value rhsDim0 = getDimOp(rewriter, loc, rhs, batchRank);
      Value rhsDim1 = getDimOp(rewriter, loc, rhs, batchRank + 1);
      checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);

      // Push the final matrix dimension.
      resultShape.insert(resultShape.end(), {lhsDim0, rhsDim1});
      lhsExpr.insert(lhsExpr.end(), {rewriter.getAffineDimExpr(batchRank),
                                     rewriter.getAffineDimExpr(batchRank + 1)});
      rhsExpr.insert(rhsExpr.end(), {rewriter.getAffineDimExpr(batchRank + 1),
                                     rewriter.getAffineDimExpr(batchRank + 2)});
      outExpr.insert(outExpr.end(), {rewriter.getAffineDimExpr(batchRank),
                                     rewriter.getAffineDimExpr(batchRank + 2)});

      Value initTensor0 =
          createZeroInitTensor(rewriter, loc, resultShape, elementType);

      auto indexingMaps =
          AffineMap::inferFromExprList({lhsExpr, rhsExpr, outExpr});
      iteratorTypes.insert(iteratorTypes.end(),
                           {"parallel", "reduction", "parallel"});

      // The generic op's result type must be the type of its init tensor
      // (built from dynamic sizes); the cast below recovers `newResultType`.
      // NOTE(review): the payload uses MulFOp/AddFOp, so this case presumably
      // only supports floating point element types -- confirm.
      Value finalRes =
          rewriter
              .create<linalg::GenericOp>(
                  loc, initTensor0.getType(), ValueRange{lhs, rhs},
                  initTensor0,
                  /*indexingMaps=*/indexingMaps,
                  /*iteratorTypes=*/iteratorTypes,
                  [&](OpBuilder &b, Location loc, ValueRange args) {
                    Value l = args[0], r = args[1], res = args[2];
                    Value mul = b.create<arith::MulFOp>(loc, l, r);
                    Value add = b.create<arith::AddFOp>(loc, mul, res);
                    b.create<linalg::YieldOp>(loc, add);
                  })
              .getResult(0);

      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, finalRes);
      return success();
    }
    return rewriter.notifyMatchFailure(
        op, "unimplemented: unsupported operand ranks for aten.matmul");
  }
};
} // namespace
namespace {
/// Lowers `aten.bmm` (rank-3 by rank-3 batched matrix multiplication) to
/// `linalg.batch_matmul` accumulating into a zero-initialized tensor.
///
/// Match failures are reported for operands that are not rank 3, that are not
/// floating point, or whose element types differ. Runtime checks guard the
/// batch dimension and the contracting dimension.
class ConvertAtenBmmOp : public OpConversionPattern<AtenBmmOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenBmmOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    Location loc = op->getLoc();
    Value lhs = adaptor.self();
    Value rhs = adaptor.mat2();
    RankedTensorType lhsType = lhs.getType().cast<RankedTensorType>();
    RankedTensorType rhsType = rhs.getType().cast<RankedTensorType>();

    if (lhsType.getRank() != 3 || rhsType.getRank() != 3) {
      return rewriter.notifyMatchFailure(
          op, "expected both operands to aten.bmm to be rank 3");
    }
    // Report unsupported dtypes as a match failure (like the rank check above)
    // rather than emitting a hard error from inside the pattern.
    if (!lhsType.getElementType().isa<mlir::FloatType>() ||
        lhsType.getElementType() != rhsType.getElementType())
      return rewriter.notifyMatchFailure(
          op, "unimplemented: non floating point operands or operands of "
              "different types");

    Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
    Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1);
    Value lhsDim2 = getDimOp(rewriter, loc, lhs, 2);
    Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
    Value rhsDim1 = getDimOp(rewriter, loc, rhs, 1);
    Value rhsDim2 = getDimOp(rewriter, loc, rhs, 2);

    // Check the batch numbers are equal.
    checkDimEqualHelper(rewriter, loc, lhsDim0, rhsDim0);

    // Check the matrix shapes are valid for multiplication.
    checkDimEqualHelper(rewriter, loc, lhsDim2, rhsDim1);

    Type newResultType = getTypeConverter()->convertType(op.getType());
    Type elementType = newResultType.cast<TensorType>().getElementType();
    Value initTensor0 = createZeroInitTensor(
        rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2}, elementType);

    Value bmm =
        rewriter
            .create<linalg::BatchMatmulOp>(loc, initTensor0.getType(),
                                           ValueRange{lhs, rhs}, initTensor0)
            .getResult(0);
    // The init tensor is dynamically shaped; cast back to the converted
    // result type of the original op.
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, bmm);
    return success();
  }
};
} // namespace
namespace {
/// Lowers `aten.dropout` when `train` is a constant `false`: the op is then
/// the identity, so it is replaced by a cast of its input to the converted
/// result type. Training-mode dropout is not handled by this pattern.
class ConvertAtenDropoutOp : public OpConversionPattern<AtenDropoutOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenDropoutOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();

    bool train;
    if (!matchPattern(op.train(), m_TorchConstantBool(&train)))
      return rewriter.notifyMatchFailure(op,
                                         "Expected train to be constant bool.");

    // Only the inference (identity) behaviour is lowered here.
    if (train)
      return rewriter.notifyMatchFailure(
          op, "unimplemented: dropout in training mode");

    auto resultType = getTypeConverter()
                          ->convertType(op->getResult(0).getType())
                          .cast<RankedTensorType>();
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
                                                adaptor.input());
    return success();
  }
};
} // namespace
// Given `input`, `target`, `nll_loss_forward` is given by:
//   for i in range(0, len(target)):
//     indi = target[i];
//     nll_loss_forward[i] = (indi == ignore_index) ? 0 : -(input[i][indi]);
// TODO: `weight` and `reduction` operands are still to be taken care of.
namespace {
/// Lowers `aten.nll_loss_forward` for a rank-2 `input`, a rank-1 `target`,
/// no `weight` and no reduction, to a `linalg.generic` over `target`.
///
/// The op has two results; the second one (total weight) is currently
/// produced as a zero-initialized rank-0 tensor.
class ConvertAtenNllLossForwardOp
    : public OpConversionPattern<AtenNllLossForwardOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenNllLossForwardOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    Location loc = op->getLoc();
    Value input = adaptor.self();
    Value target = adaptor.target();
    Value weight = adaptor.weight();

    int64_t reduce_dim;
    if (!matchPattern(op.reduction(), m_TorchConstantInt(&reduce_dim)))
      return rewriter.notifyMatchFailure(op, "dim must be constant");

    // TODO: Handle reduction.
    if (reduce_dim != Reduction::None)
      return rewriter.notifyMatchFailure(
          op, "reduction along dimensions is not supported.");

    // TODO: Incorporate the weight argument.
    if (!weight.getType().isa<mlir::torch::Torch::NoneType>())
      return rewriter.notifyMatchFailure(
          op, "Unimplemented, the weight operand is not incorporated.");

    Value ignoreIndex = adaptor.ignore_index();
    Value ignoreIndexVal = castIntToIndex(rewriter, loc, ignoreIndex);

    unsigned inputRank = input.getType().cast<RankedTensorType>().getRank();
    unsigned targetRank = target.getType().cast<RankedTensorType>().getRank();

    // TODO: Cases with targetRank != 1 where `Mean` reduction is required.
    if (inputRank != 2 || targetRank != 1) {
      return rewriter.notifyMatchFailure(
          op, "expected input and target to be rank 2 and 1 respectively");
    }
    RankedTensorType resultType = getTypeConverter()
                                      ->convertType(op->getResult(0).getType())
                                      .cast<RankedTensorType>();

    Type elementType = resultType.getElementType();

    Value targetDim = getDimOp(rewriter, loc, target, 0);
    Value initTensor0 =
        createZeroInitTensor(rewriter, loc, {targetDim}, elementType);
    Value zeroVal = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementType));

    // Both `target` and the result are indexed by the single loop dim d0.
    SmallVector<AffineExpr> targetExpr;
    targetExpr.push_back(rewriter.getAffineDimExpr(0));
    SmallVector<StringRef> iteratorTypes{getParallelIteratorTypeName()};
    auto indexingMaps = AffineMap::inferFromExprList({targetExpr, targetExpr});
    // The generic op's result type must match its (dynamically shaped) init
    // tensor; the result is cast to `resultType` afterwards.
    Value finalRes =
        rewriter
            .create<linalg::GenericOp>(
                loc, initTensor0.getType(), ValueRange{target}, initTensor0,
                /*indexingMaps=*/indexingMaps,
                /*iteratorTypes=*/iteratorTypes,
                [&](OpBuilder &b, Location loc, ValueRange args) {
                  Value indTarget = b.create<arith::IndexCastOp>(
                      loc, b.getIndexType(), args[0]);
                  Value indI = b.create<linalg::IndexOp>(loc, 0);

                  // The final result is given by:
                  // final_res = (indTarget == ignoreIndexVal) ? 0 :
                  // -input[indI][indTarget]
                  // `ignore_index` designates a target class, so it is
                  // compared against the target value, not the row index.
                  Value cmpEq = b.create<arith::CmpIOp>(
                      loc, arith::CmpIPredicate::eq, indTarget,
                      ignoreIndexVal);
                  // NOTE(review): the extract below is emitted even when the
                  // target equals `ignore_index`; if that value is outside
                  // the class range the access is out of bounds -- confirm
                  // whether the index needs to be clamped first.
                  Value result = b.create<tensor::ExtractOp>(
                      loc, input, ValueRange{indI, indTarget});
                  Value negate =
                      b.create<arith::NegFOp>(loc, elementType, result);
                  Value selectFinal =
                      b.create<mlir::SelectOp>(loc, cmpEq, zeroVal, negate);
                  b.create<linalg::YieldOp>(loc, selectFinal);
                })
            .getResult(0);
    Value finalResCast =
        rewriter.create<tensor::CastOp>(loc, resultType, finalRes);

    // TODO: Update the second result tensor.
    Value weightUpdated =
        createZeroInitTensor(rewriter, loc, {}, elementType);
    rewriter.replaceOp(op, {finalResCast, weightUpdated});
    return success();
  }
};
} // namespace
// Given `grad_output`, `input`, `target`, `nll_loss_backward` is given by:
//   for i in range(0, input.shape[0]):
//     for j in range(0, input.shape[1]):
//       nll_loss_backward[i][j] =
//           (j == target[i] && target[i] != ignore_index) ? -grad_output[i] : 0
// TODO: `weight` and `reduction` operands are still to be taken care of.
namespace {
/// Lowers `aten.nll_loss_backward` for a rank-2 `input`, a rank-1 `target`,
/// no `weight` and no reduction, to a 2-D parallel `linalg.generic` that
/// produces a gradient with the same sizes as `input`.
class ConvertAtenNllLossBackwardOp
    : public OpConversionPattern<AtenNllLossBackwardOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenNllLossBackwardOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    Location loc = op->getLoc();
    Value input = adaptor.self();
    Value target = adaptor.target();
    Value weight = adaptor.weight();
    Value gradOutput = adaptor.grad_output();

    int64_t reduction;
    if (!matchPattern(op.reduction(), m_TorchConstantInt(&reduction)))
      return rewriter.notifyMatchFailure(op, "dim must be constant");

    // TODO: Handle reduction.
    if (reduction != Reduction::None)
      return rewriter.notifyMatchFailure(
          op, "reduction along dimensions is not supported.");

    // TODO: Incorporate the weight argument.
    if (!weight.getType().isa<Torch::NoneType>())
      return rewriter.notifyMatchFailure(
          op, "Unimplemented, the weight operand is not incorporated.");

    Value ignoreIndex = adaptor.ignore_index();
    Value ignoreIndexVal = castIntToIndex(rewriter, loc, ignoreIndex);

    unsigned inputRank = input.getType().cast<RankedTensorType>().getRank();
    unsigned targetRank = target.getType().cast<RankedTensorType>().getRank();

    // TODO: Cases with targetRank != 1 where `Mean` or `Sum` reduction is
    // required.
    if (inputRank != 2 || targetRank != 1) {
      return rewriter.notifyMatchFailure(
          op, "expected input and target to be rank 2 and 1 respectively");
    }
    RankedTensorType resultType = getTypeConverter()
                                      ->convertType(op->getResult(0).getType())
                                      .cast<RankedTensorType>();

    Type elementType = resultType.getElementType();

    // Given there is no reduction `grad_input` size is equal to `input` size.
    auto outputSize = getTensorSizes(rewriter, loc, input);
    Value initTensor0 =
        createZeroInitTensor(rewriter, loc, outputSize, elementType);
    Value zeroVal = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementType));

    // `target` and `grad_output` are indexed by d0 only; the result by
    // (d0, d1).
    SmallVector<AffineExpr> targetExpr{rewriter.getAffineDimExpr(0)};
    SmallVector<AffineExpr> resultExpr{rewriter.getAffineDimExpr(0),
                                       rewriter.getAffineDimExpr(1)};
    SmallVector<StringRef> iteratorTypes{getParallelIteratorTypeName(),
                                         getParallelIteratorTypeName()};
    auto indexingMaps =
        AffineMap::inferFromExprList({targetExpr, targetExpr, resultExpr});
    // The generic op's result type must match its (dynamically shaped) init
    // tensor; the result is cast to `resultType` afterwards.
    Value finalRes =
        rewriter
            .create<linalg::GenericOp>(
                loc, initTensor0.getType(), ValueRange{target, gradOutput},
                initTensor0,
                /*indexingMaps=*/indexingMaps,
                /*iteratorTypes=*/iteratorTypes,
                [&](OpBuilder &b, Location loc, ValueRange args) {
                  Value indTarget = b.create<arith::IndexCastOp>(
                      loc, b.getIndexType(), args[0]);
                  Value indJ = b.create<linalg::IndexOp>(loc, 1);

                  // The final result is given by:
                  // grad_input[i][j] = (j == target[i]) ? -grad_output[i] : 0
                  Value cmpEq = b.create<arith::CmpIOp>(
                      loc, arith::CmpIPredicate::eq, indJ, indTarget);

                  // The target index shouldn't be equal to `ignoreIndex`.
                  Value cmpNe = b.create<arith::CmpIOp>(
                      loc, arith::CmpIPredicate::ne, ignoreIndexVal,
                      indTarget);
                  Value finalPredicate =
                      b.create<arith::AndIOp>(loc, cmpEq, cmpNe);
                  Value negate =
                      b.create<arith::NegFOp>(loc, elementType, args[1]);
                  Value selectFinal = b.create<mlir::SelectOp>(
                      loc, finalPredicate, negate, zeroVal);
                  b.create<linalg::YieldOp>(loc, selectFinal);
                })
            .getResult(0);

    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, finalRes);
    return success();
  }
};
} // namespace
namespace {
// See comments at in convertMmOp and the heading for this section for general
// considerations. This function needs to be auto-generated.
//
/// Lowers `aten.linear` (`input * weight^T + bias`) for a rank-2 or rank-3
/// `input`, a rank-2 `weight` and a statically sized rank-1 `bias`.
///
/// The bias is first broadcast into the accumulator tensor, the weight is
/// transposed with a `linalg.generic`, and the product is then computed with
/// `linalg.matmul` (or `linalg.batch_matmul` for rank-3 input) accumulating
/// on top of the broadcast bias.
class ConvertAtenLinearOp : public OpConversionPattern<AtenLinearOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenLinearOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *context = op->getContext();
    Location loc = op->getLoc();
    Value input = adaptor.input();
    Value weight = adaptor.weight();
    Value bias = adaptor.bias();
    // TODO: Handle the case of bias being None (bias is optional).
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    // A `None` bias is not a tensor; reject it here instead of hitting the
    // asserting `cast` below.
    if (!bias.getType().isa<RankedTensorType>())
      return rewriter.notifyMatchFailure(
          op, "unimplemented: bias that is not a ranked tensor (e.g. None)");
    auto inputType = input.getType().cast<RankedTensorType>();
    auto weightType = weight.getType().cast<RankedTensorType>();
    auto biasType = bias.getType().cast<RankedTensorType>();

    if (inputType.getRank() != 2 && inputType.getRank() != 3) {
      return rewriter.notifyMatchFailure(
          op, "expected input to be rank 2 or rank 3");
    }

    // Only handle the case of rank 2 `weight` for now.
    // TODO: Insert the appropriate reshape to collapse any leading dimensions.
    if (weightType.getRank() != 2 || biasType.getRank() != 1) {
      return rewriter.notifyMatchFailure(
          op, "expected weight to be rank 2 and bias to be rank 1");
    }
    // TODO: Handle type promotion. What are ATen's promotion rules?
    if (inputType.getElementType() != weightType.getElementType() ||
        inputType.getElementType() != biasType.getElementType()) {
      return rewriter.notifyMatchFailure(op, "unimplemented: type promotion");
    }

    // TODO: We can handle a static size 1 here at some complexity cost, but the
    // dynamic case is not representable in linalg. We don't handle either for
    // now. Biases are generally statically shaped for most models (since for
    // inference they are constants, and for training they don't change shape
    // typically), so this is not too constraining.
    auto biasSize = biasType.getShape()[0];
    if (biasSize == 1 || biasSize == ShapedType::kDynamicSize)
      return rewriter.notifyMatchFailure(
          op, "unimplemented: size-1 broadcasting for aten::LinearOp");

    // For rank-3 input, dim 0 is the batch and the matrix dims start at 1.
    Value batchDim = nullptr;
    int restDim = 0;
    if (inputType.getRank() == 3) {
      batchDim = getDimOp(rewriter, loc, input, 0);
      restDim = 1;
    }

    Value inputDim0 = getDimOp(rewriter, loc, input, restDim + 0);
    Value inputDim1 = getDimOp(rewriter, loc, input, restDim + 1);
    Value weightDim0 = getDimOp(rewriter, loc, weight, 0);
    Value weightDim1 = getDimOp(rewriter, loc, weight, 1);
    Value biasDim0 = getDimOp(rewriter, loc, bias, 0);
    Value contractingDimEqual = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, inputDim1, weightDim1);
    rewriter.create<AssertOp>(
        loc, contractingDimEqual,
        rewriter.getStringAttr(
            "mismatching contracting dimension for aten.linear"));
    // Here we take advantage of ruling out the size-1 case above.
    // In the static-size-1 case, we will not emit this check at all.
    Value biasSizeCorrect = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, weightDim0, biasDim0);
    rewriter.create<AssertOp>(
        loc, biasSizeCorrect,
        rewriter.getStringAttr("mismatching bias size for aten.linear"));

    Value initTensor;
    SmallVector<AffineMap> broadcastIndexingMaps;
    Value transposedWeightInitTensor;
    if (inputType.getRank() > 2) {
      initTensor = rewriter.create<linalg::InitTensorOp>(
          loc, ValueRange{batchDim, inputDim0, weightDim0},
          inputType.getElementType());
      transposedWeightInitTensor = rewriter.create<linalg::InitTensorOp>(
          loc, ValueRange{batchDim, weightDim1, weightDim0},
          weightType.getElementType());
      broadcastIndexingMaps = {
          AffineMap::get(
              /*dimCount=*/inputType.getRank(), /*symbolCount=*/0,
              {rewriter.getAffineDimExpr(1 + restDim)}, context),
          rewriter.getMultiDimIdentityMap(inputType.getRank())};
    } else {
      initTensor = rewriter.create<linalg::InitTensorOp>(
          loc, ValueRange{inputDim0, weightDim0},
          inputType.getElementType());
      transposedWeightInitTensor = rewriter.create<linalg::InitTensorOp>(
          loc, ValueRange{weightDim1, weightDim0}, weightType.getElementType());
      broadcastIndexingMaps = {
          AffineMap::get(
              /*dimCount=*/inputType.getRank(), /*symbolCount=*/0,
              {rewriter.getAffineDimExpr(1)}, context),
          rewriter.getMultiDimIdentityMap(inputType.getRank())};
    }

    // Broadcast the bias along the leading dims into the accumulator.
    SmallVector<StringRef> iteratorTypes(inputType.getRank(), "parallel");
    Value broadcasted =
        rewriter
            .create<linalg::GenericOp>(
                loc, initTensor.getType(), bias, initTensor,
                /*indexingMaps=*/broadcastIndexingMaps,
                /*iteratorTypes=*/iteratorTypes,
                [](OpBuilder &b, Location loc, ValueRange args) {
                  b.create<linalg::YieldOp>(loc, args[0]);
                })
            .getResult(0);
    // We need a matmul with dimension ordering (N, K) * (M, K), so transpose
    // the weights to fit into linalg::MatmulOp which is (N, K) * (K, M).
    // TODO: This whole aten.linear lowering should eventually be generated from
    // a single linalg ODS generator statement. Both the bias and matmul part.
    SmallVector<AffineMap> transposeIndexingMaps = {
        AffineMap::get(
            /*dimCount=*/inputType.getRank(), /*symbolCount=*/0,
            {rewriter.getAffineDimExpr(1 + restDim),
             rewriter.getAffineDimExpr(0 + restDim)},
            context),
        rewriter.getMultiDimIdentityMap(inputType.getRank())};
    Value transposedWeights =
        rewriter
            .create<linalg::GenericOp>(
                loc, transposedWeightInitTensor.getType(), weight,
                transposedWeightInitTensor,
                /*indexingMaps=*/transposeIndexingMaps,
                /*iteratorTypes=*/iteratorTypes,
                [](OpBuilder &b, Location loc, ValueRange args) {
                  b.create<linalg::YieldOp>(loc, args[0]);
                })
            .getResult(0);
    Value matmul;
    if (batchDim)
      matmul = rewriter
                   .create<linalg::BatchMatmulOp>(
                       loc, broadcasted.getType(),
                       ValueRange{input, transposedWeights}, broadcasted)
                   .getResult(0);
    else
      matmul = rewriter
                   .create<linalg::MatmulOp>(
                       loc, broadcasted.getType(),
                       ValueRange{input, transposedWeights}, broadcasted)
                   .getResult(0);

    Type newResultType = getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
    return success();
  }
};
} // namespace
// Convert a scalar value to the target type. The scalar value can be an element
// from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype
// should already be converted to builtin types.
//
// Returns `scalar` unchanged when its type already equals `dtype`. Returns a
// null Value, after emitting an error at `loc`, for conversions that are not
// supported (8-bit integer source or destination, conversion to i1,
// equal-width but distinct float types, non-integer source for an integer
// destination).
static Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
                                  Type dtype) {
  Type scalarType = scalar.getType();
  if (scalarType == dtype)
    return scalar;

  // TODO: For the byte(ui8) or char(i8) case, we need the unconverted dtype to
  // be able to know if we need signed or unsigned conversion.
  auto isByteOrChar = [](Type type) {
    if (auto integerTy = type.dyn_cast<mlir::IntegerType>()) {
      return integerTy.getWidth() == 8;
    }
    return false;
  };

  if (isByteOrChar(scalarType) || isByteOrChar(dtype) ||
      dtype.isSignlessInteger(1)) {
    // TODO: Handle to-boolean conversion(from-boolean conversion is handled).
    mlir::emitError(loc)
        << "unsupported byte, char or bool type for convertScalarToDtype "
        << scalarType << "(scalar type) -> " << dtype << "(dtype)";
    return nullptr;
  }

  if (auto dtypeFloat = dtype.dyn_cast<mlir::FloatType>()) {
    if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>()) {
      if (scalarFloat.getWidth() > dtypeFloat.getWidth())
        return b.create<arith::TruncFOp>(loc, scalar, dtype);
      // Distinct float types of equal width can be neither extended nor
      // truncated into one another.
      if (scalarFloat.getWidth() == dtypeFloat.getWidth()) {
        mlir::emitError(loc)
            << "unsupported same-width float conversion for "
               "convertScalarToDtype "
            << scalarType << "(scalar type) -> " << dtype << "(dtype)";
        return nullptr;
      }
      // Only scalarFloat width < dtypeFloat width can reach here.
      return b.create<arith::ExtFOp>(loc, scalar, dtype);
    }
    assert(scalarType.isa<mlir::IntegerType>());
    if (scalarType.isSignlessInteger(1))
      return b.create<arith::UIToFPOp>(loc, scalar, dtype);
    // It's safe to use SIToFPOp because ui8/si8 are the only ones where
    // unsigned handling is needed, and we checked for that case above.
    return b.create<arith::SIToFPOp>(loc, scalar, dtype);
  }

  if (auto dtypeInteger = dtype.dyn_cast<mlir::IntegerType>()) {
    if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>())
      return b.create<arith::FPToSIOp>(loc, scalar, dtype);
    auto scalarInteger = scalarType.dyn_cast<mlir::IntegerType>();
    if (!scalarInteger) {
      // Neither float nor integer source: report instead of asserting in
      // `cast`.
      mlir::emitError(loc)
          << "unsupported scalar type for convertScalarToDtype " << scalarType
          << "(scalar type) -> " << dtype << "(dtype)";
      return nullptr;
    }
    if (scalarInteger.getWidth() > dtypeInteger.getWidth())
      return b.create<arith::TruncIOp>(loc, scalar, dtype);
    if (scalarType.isSignlessInteger(1))
      return b.create<arith::ExtUIOp>(loc, scalar, dtype);
    // Only scalarInteger width < dtypeInteger width can reach here.
    // It's safe to use ExtSIOp here because ui8/si8 are the only ones where
    // unsigned handling is needed, and we checked for that case above.
    return b.create<arith::ExtSIOp>(loc, scalar, dtype);
  }

  llvm_unreachable("convertScalarToDtype should handle all the types");
}
static Value createLinalgPayloadCalculationForElementwiseOp(
OpBuilder &b, Location loc, TypeConverter *converter,
ValueRange payloadArgs, Operation *op, ArrayRef<Value> operands) {
if (isa<AtenTanhOp>(op))
return b.create<math::TanhOp>(loc, payloadArgs[0]);
if (isa<AtenExpOp>(op))
return b.create<math::ExpOp>(loc, payloadArgs[0]);
if (isa<AtenFloorOp>(op))
return b.create<math::FloorOp>(loc, payloadArgs[0]);
if (isa<AtenCeilOp>(op))
return b.create<math::CeilOp>(loc, payloadArgs[0]);
if (isa<AtenLogOp>(op))
return b.create<math::LogOp>(loc, payloadArgs[0]);
if (isa<AtenSqrtOp>(op))
return b.create<math::SqrtOp>(loc, payloadArgs[0]);
if (isa<AtenRsqrtOp>(op))
return b.create<math::RsqrtOp>(loc, payloadArgs[0]);
if (auto clone = dyn_cast<AtenCloneOp>(op)) {
if (!clone.memory_format().getType().isa<Torch::NoneType>()) {
clone.emitError("unimplemented: only default memory format is supported");
return nullptr;
return payloadArgs[0];
if (auto bitwiseAndTensor = dyn_cast<AtenBitwiseAndTensorOp>(op)) {
if (bitwiseAndTensor.getType()
.isa<mlir::FloatType>()) {
"Bitwise_And does not support floating point dtype");
return nullptr;
Type dtype = converter->convertType(bitwiseAndTensor.getType())
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
return b.create<arith::AndIOp>(loc, lhs, rhs);
if (isa<AtenLog2Op>(op))
return b.create<math::Log2Op>(loc, payloadArgs[0]);
if (isa<AtenAbsOp>(op))
return b.create<math::AbsOp>(loc, payloadArgs[0]);
if (isa<AtenSigmoidOp>(op)) {
Type elementType = payloadArgs[0].getType();
auto one = b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1));
auto negate = b.create<arith::NegFOp>(loc, payloadArgs[0]);
auto exp = b.create<math::ExpOp>(loc, negate);
auto added = b.create<arith::AddFOp>(loc, exp, one);
return b.create<arith::DivFOp>(loc, one, added);
if (auto relu = dyn_cast<AtenReluOp>(op)) {
if (!relu.getType()
.isa<mlir::FloatType>()) {
relu.emitError("unimplemented: non-floating point dtype");
return nullptr;
Type elementType = payloadArgs[0].getType();
Value constZero =
b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
payloadArgs[0], constZero);
return b.create<SelectOp>(loc, pred, payloadArgs[0], constZero);
if (auto lrelu = dyn_cast<AtenLeakyReluOp>(op)) {
if (!lrelu.getType()
.isa<mlir::FloatType>()) {
lrelu.emitError("unimplemented: non-floating point dtype");
return nullptr;
Type elementType = payloadArgs[0].getType();
Value constZero =
b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
payloadArgs[0], constZero);
Value positivePart = b.create<SelectOp>(loc, pred, payloadArgs[0], constZero);
Value negativePart = b.create<SelectOp>(loc, pred, constZero, payloadArgs[0]);
Value scale = convertScalarToDtype(b, loc, operands[1], elementType);
Value scaledNegativePart = b.create<arith::MulFOp>(loc, negativePart, scale);
return b.create<arith::AddFOp>(loc, positivePart, scaledNegativePart);
if (auto gelu = dyn_cast<AtenGeluOp>(op)) {
if (!gelu.getType()
.isa<mlir::FloatType>()) {
gelu.emitError("unimplemented: non-floating point dtype");
return nullptr;
Value cdf = buildUnitNormalCdf(b, loc, payloadArgs[0]);
return b.create<arith::MulFOp>(loc, payloadArgs[0], cdf);
if (auto geluBackward = dyn_cast<AtenGeluBackwardOp>(op)) {
if (!geluBackward.getType()
.isa<mlir::FloatType>()) {
geluBackward.emitError("unimplemented: non-floating point dtype");
return nullptr;
Type elementType = payloadArgs[1].getType();
Value cstAlpha0 = b.create<arith::ConstantOp>(
loc, FloatAttr::get(elementType, 1.12837916709551257390));
Value cstAlpha1 = b.create<arith::ConstantOp>(
loc, FloatAttr::get(elementType, 0.70710678118654752440));
Value oneHalf =
b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.5));
Value kAlpha = b.create<arith::MulFOp>(loc, cstAlpha0, cstAlpha1);
Value kAlphaHalf = b.create<arith::MulFOp>(loc, kAlpha, oneHalf);
Value negOneHalf =
b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, -0.5));
Value inputSquared =
b.create<arith::MulFOp>(loc, payloadArgs[1], payloadArgs[1]);
Value negHalfInputSquared =
b.create<arith::MulFOp>(loc, inputSquared, negOneHalf);
Value dinput = b.create<math::ExpOp>(loc, negHalfInputSquared);
Value cdf = buildUnitNormalCdf(b, loc, payloadArgs[1]);
Value dinputInput = b.create<arith::MulFOp>(loc, dinput, payloadArgs[1]);
Value dinputInputAlpha =
b.create<arith::MulFOp>(loc, dinputInput, kAlphaHalf);
Value cdfExt = b.create<arith::AddFOp>(loc, dinputInputAlpha, cdf);
return b.create<arith::MulFOp>(loc, payloadArgs[0], cdfExt);
if (auto add = dyn_cast<AtenAddTensorOp>(op)) {
AtenAddTensorOp::Adaptor adaptor(operands);
Type dtype = converter->convertType(add.getType())
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
Value alpha = convertScalarToDtype(b, loc, adaptor.alpha(), dtype);
if (dtype.isa<mlir::FloatType>()) {
Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha);
return b.create<arith::AddFOp>(loc, lhs, scaled);
} else {
Value scaled = b.create<arith::MulIOp>(loc, rhs, alpha);
return b.create<arith::AddIOp>(loc, lhs, scaled);
if (auto sub = dyn_cast<AtenSubTensorOp>(op)) {
AtenSubTensorOp::Adaptor adaptor(operands);
Type dtype = converter->convertType(sub.getType())
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
Value alpha = convertScalarToDtype(b, loc, adaptor.alpha(), dtype);
if (dtype.isa<mlir::FloatType>()) {
Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha);
return b.create<arith::SubFOp>(loc, lhs, scaled);
} else {
Value scaled = b.create<arith::MulIOp>(loc, rhs, alpha);
return b.create<arith::SubIOp>(loc, lhs, scaled);
if (auto subScalar = dyn_cast<AtenSubScalarOp>(op)) {
Type dtype = converter->convertType(subScalar.getType())
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
Value alpha = convertScalarToDtype(b, loc, operands[2], dtype);
if (dtype.isa<mlir::FloatType>()) {
Value mult = b.create<arith::MulFOp>(loc, other, alpha);
return b.create<arith::SubFOp>(loc, self, mult);
} else if (dtype.isa<mlir::IntegerType>()) {
Value mult = b.create<arith::MulIOp>(loc, other, alpha);
return b.create<arith::SubIOp>(loc, self, mult);
subScalar.emitError("unimplemented: dtype other than float and integer "
"types are not supported.");
return nullptr;
if (auto addScalar = dyn_cast<AtenAddScalarOp>(op)) {
Type dtype = converter->convertType(addScalar.getType())
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
Value alpha = convertScalarToDtype(b, loc, operands[2], dtype);
if (dtype.isa<mlir::FloatType>()) {
Value mult = b.create<arith::MulFOp>(loc, other, alpha);
return b.create<arith::AddFOp>(loc, self, mult);
} else if (dtype.isa<mlir::IntegerType>()) {
Value mult = b.create<arith::MulIOp>(loc, other, alpha);
return b.create<arith::AddIOp>(loc, self, mult);
addScalar.emitError("unimplemented: dtype other than float and integer "
"types are not supported.");
return nullptr;
if (auto mul = dyn_cast<AtenMulTensorOp>(op)) {
AtenMulTensorOp::Adaptor adaptor(operands);
Type dtype = converter->convertType(mul.getType())
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
if (dtype.isa<mlir::FloatType>()) {
return b.create<arith::MulFOp>(loc, lhs, rhs);
} else {
return b.create<arith::MulIOp>(loc, lhs, rhs);
if (auto gtTensor = dyn_cast<AtenGtTensorOp>(op)) {
AtenGtTensorOp::Adaptor adaptor(operands);
Type lhsDtype = payloadArgs[0].getType();
Type rhsDtype = payloadArgs[1].getType();
// TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
// to be handled.
if (lhsDtype != rhsDtype) {
gtTensor.emitError("unimplemented: different lhs and rhs dtype");
return nullptr;
Type elementalType =
if (elementalType.isa<mlir::FloatType>())
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
payloadArgs[0], payloadArgs[1]);
if (IntegerType intType = elementalType.dyn_cast<mlir::IntegerType>()) {
if (intType.isUnsigned())
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
payloadArgs[0], payloadArgs[1]);
if (intType.isSigned())
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
payloadArgs[0], payloadArgs[1]);
gtTensor.emitError("unimplemented: dtype isn't supported.");
return nullptr;
if (auto eqTensor = dyn_cast<AtenEqTensorOp>(op)) {
AtenEqTensorOp::Adaptor adaptor(operands);
Type lhsDtype = payloadArgs[0].getType();
Type rhsDtype = payloadArgs[1].getType();
// TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
// to be handled.
if (lhsDtype != rhsDtype) {
eqTensor.emitError("unimplemented: lhs and rhs dtype must be same");
return nullptr;
Type elementalType =
if (elementalType.isa<mlir::FloatType>())
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UEQ,
payloadArgs[0], payloadArgs[1]);
if (elementalType.isa<mlir::IntegerType>()) {
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
payloadArgs[0], payloadArgs[1]);
eqTensor.emitError("unimplemented: dtype isn't supported.");
return nullptr;
if (auto ltTensor = dyn_cast<AtenLtTensorOp>(op)) {
AtenLtTensorOp::Adaptor adaptor(operands);
Type lhsDtype = payloadArgs[0].getType();
Type rhsDtype = payloadArgs[1].getType();
// TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
// to be handled.
if (lhsDtype != rhsDtype) {
ltTensor.emitError("unimplemented: lhs and rhs dtype must be same");
return nullptr;
Type elementalType =
if (elementalType.isa<mlir::FloatType>())
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
payloadArgs[0], payloadArgs[1]);
if (IntegerType intType = elementalType.dyn_cast<mlir::IntegerType>()) {
if (intType.isUnsigned())
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
payloadArgs[0], payloadArgs[1]);
if (intType.isSigned())
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
payloadArgs[0], payloadArgs[1]);
ltTensor.emitError("unimplemented: dtype isn't supported.");
return nullptr;
if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
AtenDivTensorOp::Adaptor adaptor(operands);
Type dtype = converter->convertType(div.getType())
if (!dtype.isa<mlir::FloatType>())
div.emitError("unimplemented: non-floating point dtype");
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
return b.create<arith::DivFOp>(loc, lhs, rhs);
if (auto pow = dyn_cast<AtenPowTensorScalarOp>(op)) {
if (!pow.getType()
.isa<mlir::FloatType>()) {
pow.emitError("unimplemented: non-floating point dtype");
return nullptr;
Type dtype = pow.self().getType().cast<ValueTensorType>().getDtype();
Value expPromoted = convertScalarToDtype(b, loc, operands[1], dtype);
return b.create<math::PowFOp>(loc, payloadArgs[0], expPromoted);
if (auto gtScalar = dyn_cast<AtenGtScalarOp>(op)) {
Type dtype = gtScalar.self().getType().cast<BaseTensorType>().getDtype();
// TODO: `gtTensor` and `gtScalar` share similar code and can be called from
// one static function.
Value otherPromoted =
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
if (dtype.isa<mlir::FloatType>())
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
payloadArgs[0], otherPromoted);
if (IntegerType intType = dtype.dyn_cast<mlir::IntegerType>()) {
if (!operands[1].getType().isa<mlir::IntegerType>()) {
// TODO: Promote tensor args from integer to float.
"unimplemented: type promotion from tensor to scalar.");
return nullptr;
if (intType.isUnsigned())
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
payloadArgs[0], otherPromoted);
if (intType.isSigned())
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
payloadArgs[0], otherPromoted);
gtScalar.emitError("unimplemented: dtype isn't supported.");
return nullptr;
if (auto eqScalar = dyn_cast<AtenEqScalarOp>(op)) {
Type dtype = eqScalar.self().getType().cast<BaseTensorType>().getDtype();
Value otherPromoted =
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
if (dtype.isa<mlir::FloatType>())
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UEQ,
payloadArgs[0], otherPromoted);
if (dtype.isa<mlir::IntegerType>()) {
if (!operands[1].getType().isa<mlir::IntegerType>()) {
// TODO: Promote tensor operand from integer to float.
"unimplemented: type promotion from tensor to scalar");
return nullptr;
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
payloadArgs[0], otherPromoted);
eqScalar.emitError("unimplemented: dtype isn't supported");
return nullptr;
if (auto ltScalar = dyn_cast<AtenLtScalarOp>(op)) {
Type dtype = ltScalar.self().getType().cast<BaseTensorType>().getDtype();
Value otherPromoted =
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
// TODO: Both tensor and scalar variants of `` and `` share a
// lot of code that can be refactored.
if (dtype.isa<mlir::FloatType>())
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
payloadArgs[0], otherPromoted);
if (IntegerType intType = dtype.dyn_cast<mlir::IntegerType>()) {
if (!operands[1].getType().isa<mlir::IntegerType>()) {
// TODO: Promote tensor operand from integer to float.
"unimplemented: type promotion from tensor to scalar");
return nullptr;
if (intType.isUnsigned())
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
payloadArgs[0], otherPromoted);
if (intType.isSigned())
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
payloadArgs[0], otherPromoted);
ltScalar.emitError("unimplemented: dtype isn't supported.");
return nullptr;
if (auto whereSelf = dyn_cast<AtenWhereSelfOp>(op)) {
Type dtype = converter->convertType(whereSelf.getType())
Value lhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[2], dtype);
return b.create<SelectOp>(loc, payloadArgs[0], lhs, rhs);
if (auto lerp = dyn_cast<AtenLerpTensorOp>(op)) {
if (!lerp.getType()
.isa<mlir::FloatType>()) {
lerp.emitError("unimplemented: non-floating point dtype");
return nullptr;
AtenLerpTensorOp::Adaptor adaptor(payloadArgs);
auto start = adaptor.self();
auto end = adaptor.end();
auto weight = adaptor.weight();
auto delta = b.create<arith::SubFOp>(loc, end, start);
auto weightedDelta = b.create<arith::MulFOp>(loc, delta, weight);
return b.create<arith::AddFOp>(loc, start, weightedDelta);
if (auto minimum = dyn_cast<AtenMinimumOp>(op)) {
if (!minimum.getType()
.isa<mlir::FloatType>()) {
minimum.emitError("unimplemented: non-floating point dtype");
return nullptr;
Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
payloadArgs[0], payloadArgs[1]);
return b.create<SelectOp>(loc, pred, payloadArgs[0], payloadArgs[1]);
if (auto maximum = dyn_cast<AtenMaximumOp>(op)) {
if (!maximum.getType()
.isa<mlir::FloatType>()) {
maximum.emitError("unimplemented: non-floating point dtype");
return nullptr;
Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
payloadArgs[0], payloadArgs[1]);
return b.create<SelectOp>(loc, pred, payloadArgs[0], payloadArgs[1]);
if (auto clamp = dyn_cast<AtenClampOp>(op)) {
Type dtype = converter->convertType(clamp.getType())
if (!dtype.isa<mlir::FloatType>()) {
clamp.emitError("unimplemented: non-floating point dtype");
return nullptr;
AtenClampOp::Adaptor adaptor(operands);
auto min = adaptor.min();
auto max = adaptor.max();
if (min.getType().isa<Torch::OptionalType>() ||
max.getType().isa<Torch::OptionalType>()) {
clamp.emitError("unimplemented: runtime optional type");
return nullptr;
auto result = payloadArgs[0];
if (!min.getType().isa<Torch::NoneType>()) {
auto minPromoted = convertScalarToDtype(b, loc, min, dtype);
auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
result, minPromoted);
result = b.create<SelectOp>(loc, pred, minPromoted, result);
if (!max.getType().isa<Torch::NoneType>()) {
auto maxPromoted = convertScalarToDtype(b, loc, max, dtype);
auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
result, maxPromoted);
result = b.create<SelectOp>(loc, pred, maxPromoted, result);
return result;
if (auto rsub = dyn_cast<AtenRsubScalarOp>(op)) {
Type dtype = converter->convertType(rsub.getType())
if (!dtype.isa<mlir::FloatType>()) {
rsub.emitError("unimplemented: non-floating point dtype");
return nullptr;
Value self = payloadArgs[0];
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
Value alpha = convertScalarToDtype(b, loc, operands[2], dtype);
Value mult = b.create<arith::MulFOp>(loc, self, alpha);
return b.create<arith::SubFOp>(loc, other, mult);
if (auto mulScalar = dyn_cast<AtenMulScalarOp>(op)) {
Type dtype = converter->convertType(mulScalar.getType())
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, operands[1], dtype);
if (dtype.isa<mlir::FloatType>())
return b.create<arith::MulFOp>(loc, lhs, rhs);
if (dtype.isa<mlir::IntegerType>())
return b.create<arith::MulIOp>(loc, lhs, rhs);
mulScalar.emitError("unimplemented: Only integer/float dtype supported");
return nullptr;
if (auto atenToDtype = dyn_cast<AtenToDtypeOp>(op)) {
Value input = payloadArgs[0];
Type dtype = converter->convertType(atenToDtype.getType())
Value result = convertScalarToDtype(b, loc, input, dtype);
return result;
if (auto divScalar = dyn_cast<AtenDivScalarOp>(op)) {
Type dtype = converter->convertType(divScalar.getType())
if (!dtype.isa<mlir::FloatType>()) {
divScalar.emitError("unimplemented: non-floating point dtype");
return nullptr;
Value self = payloadArgs[0];
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
return b.create<arith::DivFOp>(loc, self, other);
if (auto reciprocal = dyn_cast<AtenReciprocalOp>(op)) {
if (!reciprocal.getType()
.isa<mlir::FloatType>()) {
reciprocal.emitError("unimplemented: non-floating point dtype");
return nullptr;
Type elementType = payloadArgs[0].getType();
// assert(element != 0)
auto zero =
b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.0));
auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ONE,
payloadArgs[0], zero);
loc, pred, b.getStringAttr("unimplemented: tensor with zero element"));
auto one =
b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1.0));
return b.create<arith::DivFOp>(loc, one, payloadArgs[0]);
if (auto thresholdOp = dyn_cast<AtenThresholdOp>(op)) {
// The approach used here is as follows:
// result = self <= threshold ? value : self
AtenThresholdOp::Adaptor adaptor(operands);
Type dtype = converter->convertType(thresholdOp.getType())
Value self = payloadArgs[0];
Value threshold = convertScalarToDtype(b, loc, adaptor.threshold(), dtype);
Value value = convertScalarToDtype(b, loc, adaptor.value(), dtype);
Value predicate;
if (dtype.isa<mlir::FloatType>())
predicate = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, self,
predicate = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle, self,
return b.create<SelectOp>(loc, predicate, value, self);
if (auto thresholdBackward = dyn_cast<AtenThresholdBackwardOp>(op)) {
// The approach used here is as follows:
// result = self <= threshold ? 0 : grad
AtenThresholdBackwardOp::Adaptor adaptor(operands);
Type dtype = converter->convertType(thresholdBackward.getType())
Value grad = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value self = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
Value threshold = convertScalarToDtype(b, loc, adaptor.threshold(), dtype);
Value constantZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
Value predicate;
if (dtype.isa<mlir::FloatType>())
predicate = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, self,
predicate = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle, self,
return b.create<SelectOp>(loc, predicate, constantZero, grad);
op->emitError("unimplemented lowering in "
return nullptr;
/// Creates the constant used to seed the accumulator of a reduction lowering.
///
/// \param b           Builder used to create the `arith.constant`.
/// \param loc         Location attached to the created constant.
/// \param op          The reduction op being lowered; selects the seed value.
/// \param elementType Element type of the accumulator.
/// \returns The seed constant, or a null Value (after emitting an error on
///          `op`) when the op / element type combination is not handled.
///
/// Only floating-point element types are handled:
///   - AtenSumOp / AtenSumDimIntListOp: 0.0, the identity of addition.
///   - AtenMaxOp: negative infinity, the identity of max, so that the seed can
///     never win against an actual input element.
static Value createLinalgNeutralElementForReduceOp(OpBuilder &b, Location loc,
                                                   Operation *op,
                                                   Type elementType) {
  if (isa<AtenSumOp, AtenSumDimIntListOp>(op) &&
      elementType.isa<mlir::FloatType>())
    return b.create<arith::ConstantOp>(loc, b.getFloatAttr(elementType, 0.0));

  if (isa<AtenMaxOp>(op) && elementType.isa<mlir::FloatType>())
    return b.create<arith::ConstantOp>(
        loc, b.getFloatAttr(
                 elementType,
                 APFloat::getInf(
                     elementType.cast<mlir::FloatType>().getFloatSemantics(),
                     /*Negative=*/true)));

  op->emitError("unimplemented lowering in "
                "createLinalgNeutralElementForReduceOp");
  return nullptr;
}
/// Emits the scalar computation that forms the body of a reduction payload.
///
/// \param b                 Builder positioned inside the payload region.
/// \param loc               Location attached to the created ops.
/// \param payloadArgs       Block arguments of the payload. `payloadArgs[0]`
///                          is converted to `resultElementType` and combined
///                          with `payloadArgs[1]`.
/// \param op                The reduction op being lowered; selects the
///                          combining operation.
/// \param operands          Converted operands of `op` (not read here; kept
///                          for signature parity with the other payload
///                          helpers).
/// \param resultElementType Element type of the reduction result.
/// \returns The combined value, or a null Value (after emitting an error on
///          `op`) when the op / element type combination is not handled.
///
/// Only floating-point result types are handled: sums use `arith.addf`, max
/// uses `arith.maxf`.
static Value createLinalgPayloadCalculationForReduceOp(
    OpBuilder &b, Location loc, ValueRange payloadArgs, Operation *op,
    ArrayRef<Value> operands, Type resultElementType) {
  if (isa<AtenSumOp, AtenSumDimIntListOp>(op) &&
      resultElementType.isa<mlir::FloatType>()) {
    Value self =
        convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
    Value result = payloadArgs[1];
    return b.create<arith::AddFOp>(loc, self, result);
  }

  if (isa<AtenMaxOp>(op) && resultElementType.isa<mlir::FloatType>()) {
    Value self =
        convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
    Value result = payloadArgs[1];
    return b.create<arith::MaxFOp>(loc, self, result);
  }

  op->emitError("unimplemented lowering in "
                "createLinalgPayloadCalculationForReduceOp");
  return nullptr;
}
namespace {
// Lowers AtenMaxDimOp to a linalg::GenericOp (using linalg::IndexOp to read
// the position along the reduced dimension), producing two output buffers.
// The first output buffer contains the maximum value found. Per the original
// note it is initialized to the minimum representable value of the input
// element type (the fill-value expression is cut off in this view -- TODO
// confirm).
// The second output buffer contains the index of the found maximum value. It
// is initialized to 0 and has the integer element type of the index result.
// The payload updates both the maximum value and the index if the current
// value is strictly greater (OGT) than the running max.
// NOTE(review): in this view several statements are visibly incomplete
// (declarations without initializers, calls without a callee or leading
// argument) and closing braces are absent; confirm against the full file
// before relying on the details below.
class ConvertAtenMaxDimOp : public OpConversionPattern<AtenMaxDimOp> {
using OpConversionPattern<AtenMaxDimOp>::OpConversionPattern;
// Validates the op, builds the two initialized output tensors and the
// generic op, then replaces `maxDimOp` with two tensor.cast results.
// Returns failure (or a match-failure notification) when a precondition
// below is not met.
matchAndRewrite(AtenMaxDimOp maxDimOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = maxDimOp.getLoc();
Value input = adaptor.self();
// Target tensor types of the two results; they are used at the end as the
// destination types of the tensor.cast ops.
// NOTE(review): both initializers are missing in this view.
RankedTensorType valResultType =
RankedTensorType idxResultType =
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
// The index result must have an integer element type.
Type idxElementType = idxResultType.getElementType();
if (!idxElementType.isa<IntegerType>())
return rewriter.notifyMatchFailure(
"aten.max_dim to linalg.* requires integer-like result type");
// `keepdim` and `dim` are matched against the original (unconverted)
// operands and must both be compile-time constants.
bool keepDim = false;
if (!matchPattern(maxDimOp.keepdim(), m_TorchConstantBool(&keepDim)))
return failure();
int64_t dim;
if (!matchPattern(maxDimOp.dim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(
maxDimOp, "aten.max_dim to linalg.* requires int value for Dim");
// Normalize `dim` against the input rank, then reject it if it is still
// not a valid dimension.
dim = toPositiveDim(dim, inputType.getRank());
if (!isValidDim(dim, inputType.getRank()))
return rewriter.notifyMatchFailure(maxDimOp, "dim is not a valid dim");
// Only floating-point inputs are supported.
Type inElementType = inputType.getElementType();
if (!inElementType.isa<mlir::FloatType>()) {
return rewriter.notifyMatchFailure(
"aten.max_dim to linalg.* requires Float input element type");
// Constant op to account for the reduction along dim.
auto c1 = rewriter.create<arith::ConstantIndexOp>(loc, /*value=*/1);
// Result shape: walks every input dimension, treating `dim` specially when
// `keepDim` is set.
// NOTE(review): the statements that append to `resultShape` are not
// visible in this view.
SmallVector<Value> resultShape;
for (int64_t i = 0; i < inputType.getRank(); i++) {
if (dim != i) {
auto currentDimSize = rewriter.create<tensor::DimOp>(loc, input, i);
} else if (keepDim)
// First fill the output buffer for the index.
Value filledTensorIdx =
createZeroInitTensor(rewriter, loc, resultShape, idxElementType);
// Second fill the output buffer for the running max.
Value initTensorMax =
rewriter.create<linalg::InitTensorOp>(loc, resultShape, inElementType)
// Fill value for the running max, built from the float semantics of the
// input element type.
// NOTE(review): the middle of this expression is cut off in this view.
FloatAttr fillValueMaxAttr = rewriter.getFloatAttr(
inElementType.cast<mlir::FloatType>().getFloatSemantics(), true));
Value fillValueMax =
rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
Value filledTensorMax =
rewriter.create<linalg::FillOp>(loc, fillValueMax, initTensorMax)
// Create the affine expressions that will be used to
// iterate over the input and output tensors.
// Here we also set the type of iterator: parallel or reduction.
SmallVector<AffineExpr> exprs;
SmallVector<StringRef> iteratorTypes;
SmallVector<AffineExpr> resultExprs;
for (auto size : llvm::enumerate(inputType.getShape())) {
if (unsigned(dim) == size.index()) {
// If `keepDim`, create affine map to the first element
// in the current dimension.
if (keepDim)
} else {
// One map for the input and the same result map for both outputs.
auto maps = AffineMap::inferFromExprList({exprs, resultExprs, resultExprs});
auto linalgOp = rewriter.create<linalg::GenericOp>(
ArrayRef<Type>({filledTensorMax.getType(), filledTensorIdx.getType()}),
input, ValueRange({filledTensorMax, filledTensorIdx}), maps,
[&](OpBuilder &nestedBuilder, Location nestedLoc,
ValueRange blockArgs) {
// Block arguments: current input element, running max, running index.
Value newValue = blockArgs[0];
Value oldValue = blockArgs[1];
Value oldIndex = blockArgs[2];
// Position along `dim`, cast to the index result's element type.
// NOTE(review): the payload ops here are created through the captured
// `rewriter` (and the IndexOp with the outer `loc`) rather than through
// `nestedBuilder` / `nestedLoc` -- confirm this is intended.
Value newIndex = rewriter.create<arith::IndexCastOp>(
nestedLoc, oldIndex.getType(),
rewriter.create<linalg::IndexOp>(loc, dim));
// `predicate` is only assigned for float element types; non-float inputs
// were already rejected above.
Value predicate;
if (inElementType.isa<mlir::FloatType>())
predicate = rewriter.create<arith::CmpFOp>(
nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
// Keep the new value/index only when it is strictly greater.
auto resultMax = rewriter.create<mlir::SelectOp>(nestedLoc, predicate,
newValue, oldValue);
auto resultIndex = rewriter.create<mlir::SelectOp>(
nestedLoc, predicate, newIndex, oldIndex);
nestedLoc, ValueRange({resultMax, resultIndex}));
// This cast is required to fix the shape in the case of keepDim=True
Value maxValuesCast = rewriter.create<tensor::CastOp>(
loc, valResultType, linalgOp.getResult(0));
Value maxIdxCast = rewriter.create<tensor::CastOp>(loc, idxResultType,
rewriter.replaceOp(maxDimOp, {maxValuesCast, maxIdxCast});
return success();
} // namespace
namespace {
// Converts an elementwise op.
// This specifically includes:
// - converting elementwise ops of any tensor arity
// - converting elementwise ops with any number of scalar captures (such as a
// scalar alpha to torch.aten.Add)
// - broadcasting of static size-1 dimensions
// Currently, we adopt the behavior that "size 1" broadcasting is a runtime
// error if it happens dynamically.
// Looking forward a bit, eventually, it probably makes sense to have
// a "linalg.generic-like" op for modeling a fused subgraph of numpy-broadcasted
// operands. Modeling elementwise ops that way is potentially useful to allow a
// more centralized reasoning about multiversioning. However a cost model will
// be needed for "pre-fusing" elementwise ops that way, as it can potentially be
// a pessimization. A mild extension of this pattern should work for such a
// general op.
// NOTE(review): in this view several statements are visibly incomplete and
// closing braces are absent; confirm against the full file before relying on
// the details below.
struct ConvertElementwiseOp : ConversionPattern {
// Registers the pattern for any op type with benefit 1; the supported op
// set is filtered inside matchAndRewrite.
ConvertElementwiseOp(TypeConverter &typeConverter, MLIRContext *context)
: ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
context) {}
// Rewrites a supported elementwise op into a linalg::GenericOp whose payload
// is produced by createLinalgPayloadCalculationForElementwiseOp, followed by
// a tensor.cast to the converted result type.
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// Bail out early for anything that is not in the supported op list.
if (!isa<AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp,
AtenGeluBackwardOp, AtenAddTensorOp, AtenMulTensorOp,
AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp,
AtenExpOp, AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp,
AtenClampOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp,
AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op,
AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp, AtenReciprocalOp,
AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenEqScalarOp,
AtenLtScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp,
AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp,
AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp>(op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op->getLoc();
// Only ranked-tensor operands take part in the broadcast; the remaining
// operands are the scalar captures mentioned above.
auto tensorOperands = llvm::to_vector<6>(llvm::make_filter_range(
operands, [](Value v) { return v.getType().isa<RankedTensorType>(); }));
// NOTE(review): the rest of this initializer is cut off in this view.
auto resultType = getTypeConverter()
auto resultRank = resultType.getRank();
auto c1 = rewriter.create<arith::ConstantIndexOp>(loc, /*value=*/1);
// The overall error handling strategy here is best viewed by thinking about
// what happens for a single result dimension. This loop is not structured
// that way because it is hard to create the affine maps for each operand
// unless we structure the loop to iterate over tensor operands as the outer
// loop instead of inner loop. This pseudocode gives better intuition:
// ```
// for each result dimension:
// for each tensor operand:
// if it doesn't even have high enough rank relative to the result:
// continue
// if it is a static size-1 along this result dimension:
// continue
// if this is the first tensor operand that didn't continue above:
// take its dimension size as the size of the non-broadcasted
// traversal along this dimension (this may include a dynamic size-1,
// **non-broadcasted** traversal!)
// emit error check "if the size does not match the non-broadcasted
// traversal size along this dimension, error"
// ```
// Initialize the resultShape to all 1's, as a fallback in case
// all sizes along that result dimension are statically 1.
SmallVector<Value> resultShape(resultRank, c1);
SmallVector<AffineMap> indexingMaps;
for (Value tensorOperand : tensorOperands) {
SmallVector<AffineExpr> exprs;
auto type = tensorOperand.getType().cast<RankedTensorType>();
for (auto size : llvm::enumerate(type.getShape())) {
// If the size is statically known to be 1, we don't want any
// error guards to be spuriously emitted, since we are specifically
// allowing size-1 broadcasts in this case, as they correspond to a
// constant-0 indexing map.
if (size.value() == 1) {
// The rank of this operand might be smaller than the overall rank of
// the broadcast. Add an offset to correlate it to the correct
// dimension of the result.
auto resultDim = size.index() + (resultRank - type.getRank());
// The generated linalg op will now be iterating along the full size
// of this dimension. Record that fact.
// Now, we need to ensure that such iteration is not going to trigger
// undefined behavior, by doing appropriate checks against the current
// dimension size.
auto currentDimSize =
getDimOp(rewriter, loc, tensorOperand, size.index());
// If the result size of this dimension has so far only hit the
// statically-known-to-be-1 case above (i.e., we have not yet assigned a
// new Value to `resultShape[resultDim]`), then we have no other dynamic
// values to check against, and merely need to record the current
// dimension size.
if (resultShape[resultDim] == c1) {
resultShape[resultDim] = currentDimSize;
// We prohibit the size-1 dynamic broadcasting scenario, so just check
// for exact equality with the running result size.
// This is the check which protects against the undefined behavior of
// the generated linalg op in the case of iterating two operands with
// dimensions sizes that are expected to match.
auto equalToRunning = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, resultShape[resultDim],
rewriter.create<AssertOp>(loc, equalToRunning,
"mismatched size for broadcast");
/*dimCount=*/resultRank, /*symbolCount=*/0, exprs, getContext()));
// One iterator type per result dimension.
// NOTE(review): the iterator-type value is cut off in this view.
SmallVector<StringRef> iteratorTypes(resultRank,
// Add the indexing map for the outs init tensor.
Value initTensor = rewriter.create<linalg::InitTensorOp>(
loc, getAsOpFoldResult(resultShape), resultType.getElementType());
// The payload builder cannot return a status, so a failure inside it is
// recorded in this flag and turned into `failure()` afterwards.
bool hadErrorCreatingPayload = false;
auto generic = rewriter.create<linalg::GenericOp>(
loc, /*resultTensorTypes=*/initTensor.getType(),
[&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
Value result = createLinalgPayloadCalculationForElementwiseOp(
b, loc, getTypeConverter(), payloadArgs, op, operands);
if (!result) {
hadErrorCreatingPayload = true;
b.create<linalg::YieldOp>(loc, result);
if (hadErrorCreatingPayload)
return failure();
// Cast the generic op's result to the converted result type.
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
return success();
} // namespace
namespace {
// Converts the reduction ops AtenSumOp, AtenMaxOp and AtenSumDimIntListOp into
// a linalg::GenericOp that accumulates into a filled init tensor.
// NOTE(review): in this view several statements are visibly incomplete and
// closing braces are absent; confirm against the full file before relying on
// the details below.
struct ConvertReductionOp : ConversionPattern {
// Registers the pattern for any op type with benefit 1; the supported op
// set is filtered inside matchAndRewrite.
ConvertReductionOp(TypeConverter &typeConverter, MLIRContext *context)
: ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
context) {}
// This function is in charge of all the rewriting that will take
// place in `matchAndRewrite`. In particular, it converts
// the reduce operation into an `linalg.generic` operation
// to reduce the input tensor along the dimensions specified in
// `dimSet`.
createReductionLinalgGeneric(Operation *op, ArrayRef<Value> operands,
const DenseSet<int64_t> &dimSet, bool keepDim,
ConversionPatternRewriter &rewriter) const {
Location loc = op->getLoc();
// The tensor being reduced is always the first converted operand.
auto tensorOperand = operands[0];
auto inputType = tensorOperand.getType().cast<RankedTensorType>();
// NOTE(review): the rest of this initializer is cut off in this view.
auto resultType = getTypeConverter()
// Get the result shape by obtaining the size of each
// dimension in the input tensor that is not getting reduced.
// If `keepDim` is true, the rank of the output tensor
// is kept the same as the rank of the input tensor, and the
// reduced dimensions are set to have size 1.
auto c1 = rewriter.create<arith::ConstantIndexOp>(loc, /*value=*/1);
SmallVector<Value> resultShape;
for (int64_t i = 0; i < inputType.getRank(); i++) {
auto currentDimSize =
rewriter.create<tensor::DimOp>(loc, tensorOperand, i);
if (!dimSet.contains(i))
else if (keepDim)
// Create the affine expressions that will be used to
// iterate over the input and output tensors.
// Here we also set the type of iterator: parallel or reduction.
SmallVector<AffineExpr> exprs;
SmallVector<StringRef> iteratorTypes;
SmallVector<AffineExpr> resultExprs;
for (auto size : llvm::enumerate(inputType.getShape())) {
if (dimSet.contains(size.index())) {
// If `keepDim`, create affine map to the first element
// in the current dimension.
if (keepDim)
} else {
// One map for the input tensor and one for the accumulator.
auto indexingMaps = AffineMap::inferFromExprList({exprs, resultExprs});
// The accumulator is an init tensor filled with the op-specific neutral
// element.
Value initTensor = rewriter.create<linalg::InitTensorOp>(
loc, resultShape, resultType.getElementType());
Value initValue = createLinalgNeutralElementForReduceOp(
rewriter, loc, op, resultType.getElementType());
Value accumulator =
rewriter.create<linalg::FillOp>(loc, initValue, initTensor)
// The payload builder cannot return a status, so a failure inside it is
// recorded in this flag and turned into `failure()` afterwards.
bool hadErrorCreatingPayload = false;
auto generic = rewriter.create<linalg::GenericOp>(
loc, /*resultTensorTypes=*/accumulator.getType(),
[&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
Value result = createLinalgPayloadCalculationForReduceOp(
b, loc, payloadArgs, op, operands, resultType.getElementType());
if (!result) {
hadErrorCreatingPayload = true;
b.create<linalg::YieldOp>(loc, result);
if (hadErrorCreatingPayload)
return failure();
// Cast the generic op's result to the converted result type.
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
return success();
// Determines `dimSet` and `keepDim` for the matched reduction op and then
// delegates the rewrite to createReductionLinalgGeneric.
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
// Every reduce operation must set a value for the `dimSet` and
// `keepDim` in accordance with their specification.
DenseSet<int64_t> dimSet;
bool keepDim = false;
if (isa<AtenSumOp>(op) || isa<AtenMaxOp>(op)) {
auto tensorOperand = operands[0];
auto inputType = tensorOperand.getType().cast<RankedTensorType>();
// `AtenSumOp` and `AtenMaxOp` reduces along all the dimensions of the
// input tensor.
for (int64_t i = 0; i < inputType.getRank(); i++)
} else if (auto sumDimIntListOp = dyn_cast<AtenSumDimIntListOp>(op)) {
auto tensorOperand = operands[0];
auto inputType = tensorOperand.getType().cast<RankedTensorType>();
// `keepdim` and `dim` are matched against the original operands and must
// be compile-time constants; otherwise the pattern fails.
if (!matchPattern(sumDimIntListOp.keepdim(),
return failure();
SmallVector<int64_t> dimList;
if (!matchPattern(sumDimIntListOp.dim(), m_TorchConstantIntList(dimList)))
return failure();
for (auto dim : dimList) {
// Torch allows for negative values in dimSet to go in reverse
// order in the dimensions of the input tensor.
dim = dim >= 0 ? dim : dim + inputType.getRank();
// Drop invalid dimensions
// NOTE(review): only the upper bound is tested here; a value that is
// still negative after the adjustment above would pass this check.
if (dim < inputType.getRank())
} else {
return rewriter.notifyMatchFailure(op, "not a supported reduce op");
return createReductionLinalgGeneric(op, operands, dimSet, keepDim,
} // namespace
namespace {
// Converts AtenMaxPool2dOp: pads the input, computes the output spatial sizes,
// fills an output tensor with an initial value and feeds it, together with a
// window-shaped tensor, to a pooling op whose result is cast to the converted
// result type.
// NOTE(review): in this view several statements are visibly incomplete
// (including the name of the pooling op and the initial fill value) and
// closing braces are absent; confirm against the full file before relying on
// the details below.
class ConvertAtenMaxPool2dOp : public OpConversionPattern<AtenMaxPool2dOp> {
using OpConversionPattern::OpConversionPattern;
// Requires a floating-point element type and constant integer lists for
// stride, dilation, padding and kernel size.
matchAndRewrite(AtenMaxPool2dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op->getLoc();
Value self = adaptor.self();
Value ceilMode = adaptor.ceil_mode();
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
// NOTE(review): unlike the checks below, which use notifyMatchFailure,
// this one emits an error diagnostic on the op.
if (!elementType.isa<mlir::FloatType>())
return op.emitError("unimplemented: non-floating point type");
// Pattern match against the op's original operands, because otherwise we
// will get the lowered version of the operands which is harder to pattern
// match.
SmallVector<int64_t, 2> strideInts;
if (!matchPattern(op.stride(), m_TorchConstantIntList(strideInts)))
return rewriter.notifyMatchFailure(op,
"only support constant int strides");
SmallVector<int64_t, 2> dilationInts;
if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts)))
return rewriter.notifyMatchFailure(op,
"only support constant int dilations");
SmallVector<int64_t, 2> paddingInts;
if (!matchPattern(op.padding(), m_TorchConstantIntList(paddingInts)))
return rewriter.notifyMatchFailure(op,
"only support constant int paddings");
SmallVector<int64_t, 2> kernelSizeInts;
if (!matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSizeInts)))
return rewriter.notifyMatchFailure(op, "only support kernel size ints");
// `ceilMode` is compared for equality against a 1-bit constant 0; the
// result feeds a check carrying the message below.
// NOTE(review): the call that consumes `ceilModeFalse` is cut off in this
// view.
Value falseValue = rewriter.create<arith::ConstantOp>(
loc, IntegerAttr::get(rewriter.getIntegerType(1), 0));
Value ceilModeFalse = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, ceilMode, falseValue);
loc, ceilModeFalse,
rewriter.getStringAttr("only ceil_mode false is supported"));
// The padding list starts with two zeros (for the dimensions named N and C
// below) and is then extended from `paddingInts`.
SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(),
Value paddedInput = getPaddedTensor(op, rewriter, self, paddingIncludingNC);
// Sizes of dimensions 0..3 of the (unpadded) input.
Value N = getDimOp(rewriter, loc, self, 0);
Value C = getDimOp(rewriter, loc, self, 1);
Value H = getDimOp(rewriter, loc, self, 2);
Value W = getDimOp(rewriter, loc, self, 3);
// Materialize the constant parameters as values so they can be passed to
// getOutputDimForConvOps.
SmallVector<Value> paddingIntValues =
getAsConstantIntValues(rewriter, loc, paddingInts);
SmallVector<Value> dilationIntValues =
getAsConstantIntValues(rewriter, loc, dilationInts);
SmallVector<Value> kernelSizeIntValues =
getAsConstantIntValues(rewriter, loc, kernelSizeInts);
SmallVector<Value> strideIntValues =
getAsConstantIntValues(rewriter, loc, strideInts);
// Output height/width use entry 0 / entry 1 of each parameter list.
Value Hout = getOutputDimForConvOps(
rewriter, loc, H, paddingIntValues[0], dilationIntValues[0],
kernelSizeIntValues[0], strideIntValues[0]);
Value Wout = getOutputDimForConvOps(
rewriter, loc, W, paddingIntValues[1], dilationIntValues[1],
kernelSizeIntValues[1], strideIntValues[1]);
// Initialize output tensor with smallest floating point value
// NOTE(review): the expression that computes the initial value is cut off
// in this view -- confirm which value is actually used.
Value outTensor = rewriter.create<linalg::InitTensorOp>(
loc, ValueRange{N, C, Hout, Wout}, elementType);
auto initialAttr = rewriter.getFloatAttr(
/*Negative*/ true));
Value initValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);
Value outTensorInitialized =
rewriter.create<linalg::FillOp>(loc, initValue, outTensor).getResult(0);
auto stridesAttr = rewriter.getI64VectorAttr(strideInts);
auto dilationAttr = rewriter.getI64VectorAttr(dilationInts);
// Init tensor shaped by the kernel size; passed as the second input of the
// pooling op below.
Value windowTensor = rewriter.create<linalg::InitTensorOp>(
loc, getAsConstantIndexValues(rewriter, loc, kernelSizeInts),
Value maxPool2d = rewriter
loc, outTensorInitialized.getType(),
ValueRange{paddedInput, windowTensor},
outTensorInitialized, stridesAttr, dilationAttr)
// Cast the pooling result to the converted result type.
Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, maxPool2d);
return success();
} // namespace
namespace {
/// Converts AtenConstantPadNdOp into a padded tensor followed by a
/// tensor.cast to the converted result type.
///
/// The pad list must be a constant list of integers with an even number of
/// entries (one low/high pair per padded dimension) and may not describe more
/// dimensions than the input tensor has. The pad value is converted to the
/// result element type before padding.
class ConvertAtenConstantPadNdOp
    : public OpConversionPattern<AtenConstantPadNdOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  /// Performs the rewrite described on the class.
  /// \returns success once `op` has been replaced; a match failure when the
  ///          pad list is not constant, has odd length, or exceeds the rank
  ///          of the input tensor.
  LogicalResult
  matchAndRewrite(AtenConstantPadNdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();

    Location loc = op->getLoc();
    Value self = adaptor.self();
    auto type = self.getType().cast<RankedTensorType>();
    int64_t rank = type.getRank();

    // Pattern match against the op's original operands, because otherwise we
    // will get the lowered version of the operands which is harder to pattern
    // match.
    SmallVector<int64_t> padInts;
    if (!matchPattern(op.pad(), m_TorchConstantIntList(padInts)))
      return rewriter.notifyMatchFailure(
          op, "only support constant int pad ranges");

    // Each padded dimension contributes one (low, high) pair.
    if (padInts.size() % 2 != 0)
      return rewriter.notifyMatchFailure(op, "pad range size is not even");
    uint64_t padRank = padInts.size() / 2;
    if (rank < 0 || padRank > static_cast<uint64_t>(rank))
      return rewriter.notifyMatchFailure(op, "padding exceeds tensor rank");

    // Initialize low/high paddings with the dims that should not be padded.
    SmallVector<int64_t, 4> lowPadding(/*Size=*/rank - padRank, /*Value=*/0);
    SmallVector<int64_t, 4> highPadding(/*Size=*/rank - padRank, /*Value=*/0);

    // Add the requested padding - note op.pad() is highest dim first ordered
    // pairs of low,high. Walking the pairs from last to first therefore
    // appends them in increasing dimension order.
    for (uint64_t i = padRank; i > 0; --i) {
      lowPadding.push_back(padInts[i * 2 - 2]);
      highPadding.push_back(padInts[i * 2 - 1]);
    }

    // The pad value has to match the element type of the converted result.
    Type newResultType = getTypeConverter()->convertType(op.getType());
    Type elementType = newResultType.cast<RankedTensorType>().getElementType();
    Value castedValue =
        convertScalarToDtype(rewriter, loc, adaptor.value(), elementType);

    Value paddedInput = getPaddedTensor(op, rewriter, self, lowPadding,
                                        highPadding, castedValue);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, paddedInput);
    return success();
  }
};
} // namespace
namespace {
/// Lowers `AtenFlattenUsingIntsOp` to a `tensor::CollapseShapeOp` on the input
/// followed by a `tensor::CastOp` to the result type.
///
/// `start_dim` and `end_dim` must be constants; negative values are wrapped by
/// adding the input rank. A rank-0 input is handled separately with an empty
/// reassociation list.
// NOTE(review): this copy of the block appears to have lost lines (e.g. the
// initializer of `resultType`, the body of the reassociation loop and several
// closing braces); confirm against the upstream file before changing code.
class ConvertAtenFlattenUsingIntsOp
: public OpConversionPattern<AtenFlattenUsingIntsOp> {
using OpConversionPattern::OpConversionPattern;
matchAndRewrite(AtenFlattenUsingIntsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
int64_t startDim;
if (!matchPattern(op.start_dim(), m_TorchConstantInt(&startDim)))
return rewriter.notifyMatchFailure(op, "start_dim must be constant");
int64_t endDim;
if (!matchPattern(op.end_dim(), m_TorchConstantInt(&endDim)))
return rewriter.notifyMatchFailure(op, "end_dim must be constant");
auto type = adaptor.self().getType().cast<RankedTensorType>();
auto inputRank = type.getRank();
auto resultType =
// Wrap negative dims around the input rank.
if (startDim < 0)
startDim += inputRank;
if (endDim < 0)
endDim += inputRank;
// Rank-0 input: nothing to group, so the reassociation list stays empty.
// With inputRank == 0 the wrap above leaves the dims unchanged, so the
// range check below sees the values as written on the op.
if (inputRank == 0) {
SmallVector<ReassociationIndices> reassociation;
if (!(startDim >= -1 && startDim <= 0 && endDim >= -1 && endDim <= 0))
return rewriter.notifyMatchFailure(
op, "start_dim and end_dim must be in [-1, 0] when inputRank is 0");
op, resultType, adaptor.self(), reassociation);
return success();
// General case: both dims must be in range and ordered.
if (startDim < 0 || startDim >= inputRank || endDim < 0 ||
endDim >= inputRank || startDim > endDim)
return rewriter.notifyMatchFailure(
op, "statically invalid flattening dim range");
// One reassociation group per result dimension; `j` tracks the group that
// is currently being filled while walking the input dims.
SmallVector<ReassociationIndices> reassociation(resultType.getRank());
int j = 0;
for (auto i : llvm::seq<int64_t>(0, inputRank)) {
if (i < startDim || i >= endDim)
Value collapsedTensor = rewriter.create<tensor::CollapseShapeOp>(
op->getLoc(), adaptor.self(), reassociation);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
return success();
} // namespace
namespace {
/// The `ConvertAtenViewOp` conversion pattern converts `aten.View` op to
/// either a `tensor::CollapseShapeOp` or a `tensor::ExpandShapeOp`, depending
/// on whether the result rank is smaller or larger than the input rank. Views
/// that keep the rank, produce a rank-0 result, or start from a rank-0 input
/// are rejected with a match failure.
/// TODO: Handle all the other cases of `aten.View` op.
// NOTE(review): this copy of the block appears to have lost lines (e.g. the
// initializer of `resultType`, several loop bodies and many closing braces);
// confirm against the upstream file before changing code further.
class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
using OpConversionPattern::OpConversionPattern;
matchAndRewrite(AtenViewOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
Value input = adaptor.self();
auto inputType = input.getType().cast<RankedTensorType>();
ArrayRef<int64_t> inputShape = inputType.getShape();
int64_t inputRank = inputType.getRank();
TypeConverter *typeConverter = getTypeConverter();
auto resultType =
int64_t resultRank = resultType.getRank();
// Currently, we only handle the expanding OR collapsing cases, we do not
// handle expanding and collapsing happening at the same time or cases where
// it's neither collapsing nor expanding like view of [2,3] for 3x2 tensor.
// TODO: For the expanding and collapsing case, we will need to identify
// which dimensions are collapsing and which are expanding and do it in two
// steps.
// TODO: For neither collapsing nor expanding, we could find an intermediate
// shape to collapse and then expand to the target shape. Like [2,3] =>
// [6] => [3, 2].
if (inputRank == resultRank)
return rewriter.notifyMatchFailure(
op, "unimplemented: the view op is neither expanding nor collapsing");
if (resultRank == 0)
return rewriter.notifyMatchFailure(op,
"result shape of rank 0 is invalid");
// TODO: add support for case inputRank 0 expanded to size 1
if (inputRank == 0)
return rewriter.notifyMatchFailure(
op, "unimplemented: input rank 0 is not supported");
// The side with the smaller rank is the "collapsed" shape and the side with
// the larger rank is the "expanded" shape, regardless of direction.
bool isCollapse = inputRank > resultRank;
int64_t collapsedRank = isCollapse ? resultRank : inputRank;
int64_t expandedRank = isCollapse ? inputRank : resultRank;
// Extract the desired output size as a list of integers. This list should
// have been created using the operation `torch.prim.ListConstruct`.
SmallVector<Value> outputSizeTorchInt;
if (!getListConstructElements(op.size(), outputSizeTorchInt)) {
return rewriter.notifyMatchFailure(op,
"unimplemented: the target size is "
"not constructed from ListConstruct");
SmallVector<Value> outputSizeInt = getTypeConvertedValues(
rewriter, loc, typeConverter, outputSizeTorchInt);
if (resultRank != (int64_t)outputSizeInt.size()) {
return rewriter.notifyMatchFailure(
op, "desired size list length mismatches with the result type rank");
SmallVector<Value> inputSizeTorchInt = getTensorSizes(rewriter, loc, input);
ArrayRef<Value> expandedShapeTorchInt =
llvm::makeArrayRef(isCollapse ? inputSizeTorchInt : outputSizeInt);
ArrayRef<Value> collapsedShapeTorchInt =
llvm::makeArrayRef(isCollapse ? outputSizeInt : inputSizeTorchInt);
// Iterate through the view op size list to do the following:
// 1. Combine output size list and input tensor type info to get the most
// static outputShape.
// 2. Fill in the reassociation for size list items where the output dim size
// is a size query on the input tensor at dimension `inputDim` (as matched
// by `m_TorchTensorSizeInt`). We naively
// assume this means the corresponding dimension is not expanded or
// collapsed. Note this may technically not always be true.
// TODO: think of a better way to at least detect when this assumption
// is violated.
SmallVector<int64_t> outputShape(resultRank, kUnknownSize);
SmallVector<ReassociationIndices> reassociation(collapsedRank);
for (auto en : llvm::enumerate(outputSizeTorchInt)) {
int64_t inputDim;
int64_t outputDim = en.index();
// Match a size query of `op.self()` at a constant `inputDim`.
if (matchPattern(en.value(),
m_TorchTensorSizeInt(op.self(), &inputDim))) {
auto collapsedDim = isCollapse ? outputDim : inputDim;
auto expandedDim = isCollapse ? inputDim : outputDim;
if (!inputType.isDynamicDim(inputDim)) {
outputShape[outputDim] = inputShape[inputDim];
// A constant entry in the size list fixes the output dim statically.
int64_t size;
if (matchPattern(en.value(), m_TorchConstantInt(&size)))
outputShape[outputDim] = size;
SmallVector<int64_t> collapsedShape =
isCollapse ? outputShape : llvm::to_vector(inputShape);
SmallVector<int64_t> expandedShape =
isCollapse ? llvm::to_vector(inputShape) : outputShape;
// The while loop does the following:
// 1. Fill in the reassociation indices for dimensions that are expanded.
// Check the interval dimensions between two unchanged dims in the
// collapsedShape. If the interval is size 1, associate all the dims
// in the expandedShape shape until the next unchanged dim. If the interval
// is larger than size 1, figure out the associations with assumptions that
// dynamic dimensions are not split.
// 2. Set collapsedShape and expandedShape following the requirements by
// tensor.expand_shape verification code:
// a. As long as one or more of the related dimensions in the expanded
// shape is dynamic the collapsed dimension is dynamic.
// b. If all of the related dimensions are static, the collapsed
// dimension must be static. In other words, if a collapsed dimension is
// dynamic, at least one of the related dimensions need to be dynamic.
int64_t collapsedDim = 0, expandedDim = 0;
while (collapsedDim < collapsedRank && expandedDim < expandedRank) {
// Not empty means the associations has been filled in and the dimension
// is unchanged.
if (!reassociation[collapsedDim].empty()) {
if (expandedDim != reassociation[collapsedDim][0])
return op.emitOpError("Unsupported: expanded dims are off from the "
"expected dim got from reassociation");
// Collect the dims that are collapsed until hitting the next dim that's
// unchanged.
SmallVector<int64_t> collapsedDims;
while (collapsedDim < collapsedRank &&
reassociation[collapsedDim].empty()) {
// the next reassociation is for a dim that's unchanged.
int64_t expandedDimNext = collapsedDim != collapsedRank
? reassociation[collapsedDim][0]
: expandedRank;
if (collapsedDims.size() == 1) {
int64_t collapsedDimSize = 1;
int64_t collapsedDim = collapsedDims[0];
for (auto i : llvm::seq<int64_t>(expandedDim, expandedDimNext)) {
if (collapsedDimSize == kUnknownSize)
int64_t expandedDimSize = expandedShape[i];
if (expandedDimSize == kUnknownSize) {
collapsedDimSize = kUnknownSize;
collapsedDimSize *= expandedShape[i];
// To meet both requirements from tensor.expand_shape verification code.
collapsedShape[collapsedDim] = collapsedDimSize;
expandedDim = expandedDimNext;
// collapsedDims are expanded to [expandedDim, expandedDimNext)
// Stop here once the unsupported case is diagnosed: previously the error
// was emitted but the rewrite carried on regardless.
if (expandedDimNext - expandedDim < (int64_t)collapsedDims.size())
return op.emitError("unimplemented: mixed of expanding and collapsing "
"operations for view");
for (auto collapsedDim : collapsedDims) {
if (collapsedShape[collapsedDim] == kUnknownSize) {
if (expandedDim >= expandedDimNext) {
return rewriter.notifyMatchFailure(
"desired size is not compatible with the input tensor size");
checkDimEqualHelper(rewriter, loc,
// To meet the second requirement from tensor.expand_shape
// verification code.
expandedShape[expandedDim] = kUnknownSize;
} else {
int64_t remainingSizeToExpand = collapsedShape[collapsedDim];
// A do-while loop is used here to handle the cases where the
// collapsed shape tensor has a dimension of size 1.
do {
int64_t expandedDimSize = expandedShape[expandedDim];
if (expandedDim >= expandedDimNext ||
expandedShape[expandedDim] == kUnknownSize ||
remainingSizeToExpand % expandedDimSize != 0) {
return rewriter.notifyMatchFailure(
op, "total number of elements mismatch in the expansion");
remainingSizeToExpand /= expandedDimSize;
} while (remainingSizeToExpand != 1);
// Both cursors must have consumed their whole shape; otherwise the view
// could not be expressed as a single reassociation.
if (collapsedDim != collapsedRank || expandedDim != expandedRank)
return rewriter.notifyMatchFailure(op, "view shape is not supported");
Type adjustedResultType =
RankedTensorType::get(isCollapse ? collapsedShape : expandedShape,
Type adjustedInputType =
RankedTensorType::get(isCollapse ? expandedShape : collapsedShape,
// Cast the input to the adjusted type, reshape, then cast to the result
// type requested by the type converter.
Value castedInput =
rewriter.create<tensor::CastOp>(loc, adjustedInputType, input);
Value result =
? rewriter
.create<tensor::CollapseShapeOp>(loc, adjustedResultType,
castedInput, reassociation)
: rewriter
.create<tensor::ExpandShapeOp>(loc, adjustedResultType,
castedInput, reassociation)
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
return success();
} // namespace
namespace {
/// Lowers `AtenSqueezeOp` (squeeze of all unit dimensions). Statically
/// unit-sized dimensions are folded into a neighbouring dimension through a
/// reassociation list; when no unit dimension is found the op is replaced by a
/// plain `tensor::CastOp`.
///
/// Rank-0 inputs are rejected, and dynamic dimensions are guarded by a runtime
/// check that their size is not 1.
// NOTE(review): this copy of the block appears to have lost lines (e.g. the
// initializer of `resultType`, the op-creation calls that consume
// `reassociation`, loop bodies and closing braces); confirm against the
// upstream file before changing code.
class ConvertAtenSqueezeOp : public OpConversionPattern<AtenSqueezeOp> {
using OpConversionPattern::OpConversionPattern;
matchAndRewrite(AtenSqueezeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
Value input = adaptor.self();
auto inputType = input.getType().cast<RankedTensorType>();
int64_t inputRank = inputType.getRank();
TypeConverter *typeConverter = getTypeConverter();
auto resultType =
int64_t resultRank = resultType.getRank();
if (inputRank == 0) {
return rewriter.notifyMatchFailure(
op, "zero input rank should have been handled by the folder");
// In case the operand tensor type is statically shaped with all dimensions
// being unit extent, it will be collapsed to a 0-D tensor.
if (resultRank == 0) {
SmallVector<ReassociationIndices> reassociation;
op, resultType, input, reassociation);
return success();
// All the static size-1 dimensions at the beginning(going from higher to
// lower dimensions) will be collapsed into the first dynamic or first non
// size-1 static dimension. All the other static size-1 dimensions will be
// collapsed into its previous dynamic or non size-1 static dimension.
SmallVector<ReassociationIndices> reassociation(resultRank);
bool isSqueezed = false;
// Count the leading dimensions whose static size is exactly 1.
int64_t headOnesCount = 0;
while (headOnesCount < inputRank &&
inputType.getDimSize(headOnesCount) == 1) {
isSqueezed = true;
// TODO: Add support for size-1 dynamic dimensions.
// Index-typed constant 1, compared against dynamic dim sizes below.
Value one = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
// `j` is the result dimension currently being filled; it starts at -1 and
// is validated against `resultRank` after the loop.
int64_t j = -1;
for (auto i : llvm::seq<int64_t>(headOnesCount, inputRank)) {
if (inputType.isDynamicDim(i)) {
// Make sure that size-1 dynamic dimension does not exist.
Value dimSize = getDimOp(rewriter, loc, input, i);
Value dimSizeNotOne = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ne, dimSize, one);
loc, dimSizeNotOne,
"unimplemented: size 1 dynamic dimension is not supported"));
} else if (inputType.getDimSize(i) != 1) {
} else {
// `isSqueezed` checks if the operand tensor type contains at least one
// unit dimension.
isSqueezed = true;
if (j == resultRank)
// Make sure that result type rank is compatible with the squeezed size.
if (j != resultRank - 1)
return rewriter.notifyMatchFailure(
op, "expected output size mismatches with the result type rank");
if (isSqueezed) {
op, resultType, input, reassociation);
} else {
// If the operand tensor type does not have any unit dimension,
// `aten.squeeze` will behave as an identity operation.
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
return success();
} // namespace
namespace {
/// Lowers `AtenSqueezeDimOp` (squeeze of one given dimension). When the chosen
/// dimension is statically 1 the op becomes a `tensor::CollapseShapeOp`;
/// when it is statically something else the op becomes a `tensor::CastOp`.
///
/// The match fails for rank-0 inputs, a non-constant or out-of-range `dim`,
/// and a dynamic size at `dim`.
// NOTE(review): this copy of the block appears to have lost lines (e.g. the
// initializer of `resultType`, parts of the reassociation loop and closing
// braces); confirm against the upstream file before changing code.
class ConvertAtenSqueezeDimOp : public OpConversionPattern<AtenSqueezeDimOp> {
using OpConversionPattern::OpConversionPattern;
matchAndRewrite(AtenSqueezeDimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Value input = adaptor.self();
auto inputType = input.getType().cast<RankedTensorType>();
int64_t inputRank = inputType.getRank();
if (inputRank == 0) {
return rewriter.notifyMatchFailure(
op, "zero input rank should have been handled by the folder");
int64_t dim;
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(op, "dim must be constant");
// Normalize `dim` with the shared helper, then range-check it.
dim = toPositiveDim(dim, inputRank);
if (!isValidDim(dim, inputRank))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
// TODO: Handle the case where the dim(th) dimension is dynamic.
if (inputType.isDynamicDim(dim)) {
return rewriter.notifyMatchFailure(
op, "unimplemented: dim(th) dimension is not expected to be dynamic");
TypeConverter *typeConverter = getTypeConverter();
auto resultType =
int64_t resultRank = resultType.getRank();
// If the dim(th) dimension of operand tensor type is not statically unit,
// `aten.squeeze` will behave as an identity operation.
if (inputType.getDimSize(dim) != 1) {
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
return success();
// One reassociation group per result dimension. Once the squeezed dimension
// has been passed, result dim `i` maps to input dim `i + 1`.
SmallVector<ReassociationIndices> reassociationMap(resultRank);
bool alreadyCrossedSqueezedDim = false;
for (int i = 0; i != resultRank; i++) {
if (alreadyCrossedSqueezedDim) {
reassociationMap[i].push_back(i + 1);
} else {
if (dim != 0 && i != dim - 1)
alreadyCrossedSqueezedDim = true;
if (dim == 0)
if (i == dim - 1)
// Note: In case the operand tensor type is of unit rank and is statically
// shaped with unit dimension, the `reassociationMap` will be empty and the
// input will be collapsed to a 0-D tensor.
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(op, resultType, input,
return success();
} // namespace
namespace {
/// Lowers `AtenUnsqueezeOp` by building a reassociation map with one group per
/// input dimension and using it to reshape `self` to the converted result
/// type.
///
/// `dim` must be a constant; negative values are wrapped by adding
/// `inputRank + 1`, and the valid range afterwards is [0, inputRank].
// NOTE(review): this copy of the block appears to have lost lines (e.g. the
// initializers of `inputRank` and `resultType`, the op-creation call and
// closing braces); confirm against the upstream file before changing code.
class ConvertAtenUnsqueezeOp : public OpConversionPattern<AtenUnsqueezeOp> {
using OpConversionPattern::OpConversionPattern;
matchAndRewrite(AtenUnsqueezeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
int64_t dim;
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(op, "dim must be constant");
auto inputRank =
// The result has one more dimension than the input, hence the `+ 1` wrap.
if (dim < 0)
dim += inputRank + 1;
if (!(0 <= dim && dim <= inputRank))
return rewriter.notifyMatchFailure(op, "statically invalid");
SmallVector<ReassociationIndices> reassociationMap(inputRank);
// From the perspective of the reassociation map, the situation of
// unsqueezing before or after the last dimension is symmetrical.
// Normalize it to the "before" case.
// The 0 case is special here, since there is no last dimension to insert
// before -- we simply rely on the loop below iterating 0 times.
if (dim == inputRank && inputRank != 0)
dim = inputRank - 1;
// Groups after the inserted dimension are shifted by one; the group at
// `dim` additionally absorbs the new unit dimension.
bool alreadyCrossedExpandedDim = false;
for (int i = 0; i != inputRank; i++) {
if (alreadyCrossedExpandedDim) {
reassociationMap[i].push_back(i + 1);
} else {
if (i == dim) {
reassociationMap[i].push_back(i + 1);
alreadyCrossedExpandedDim = true;
auto resultType = getTypeConverter()
op, resultType, adaptor.self(), reassociationMap);
return success();
} // namespace
namespace {
/// Lowers `AtenTransposeIntOp` to an elementwise copy with swapped indexing:
/// an init tensor is created with the sizes of `dim0` and `dim1` exchanged,
/// the payload yields the input element unchanged, and the result is cast to
/// the converted output type.
///
/// Both dims must be constants and valid for the input rank after
/// normalization with `toPositiveDim`.
// NOTE(review): this copy of the block appears to have lost lines (e.g. the
// initializer of `outType`, the bodies that fill `swapExprs`, the generic-op
// creation call and closing braces); confirm against the upstream file.
class ConvertAtenTransposeIntOp
: public OpConversionPattern<AtenTransposeIntOp> {
using OpConversionPattern::OpConversionPattern;
matchAndRewrite(AtenTransposeIntOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
int64_t dim0;
if (!matchPattern(op.dim0(), m_TorchConstantInt(&dim0)))
return rewriter.notifyMatchFailure(op, "dim0 must be constant");
int64_t dim1;
if (!matchPattern(op.dim1(), m_TorchConstantInt(&dim1)))
return rewriter.notifyMatchFailure(op, "dim1 must be constant");
auto inVector = adaptor.self();
auto inType = inVector.getType().cast<RankedTensorType>();
auto inputRank = inType.getRank();
auto outType = getTypeConverter()
auto elementType = inType.getElementType();
// Normalize and range-check both dims.
dim0 = toPositiveDim(dim0, inputRank);
if (!isValidDim(dim0, inputRank))
return rewriter.notifyMatchFailure(op, "dim0 out of range");
dim1 = toPositiveDim(dim1, inputRank);
if (!isValidDim(dim1, inputRank))
return rewriter.notifyMatchFailure(op, "dim1 out of range");
auto loc = op.getLoc();
// Output sizes are the input sizes with the two transposed dims exchanged.
SmallVector<Value> outputDims;
for (auto i = 0; i < inputRank; i++)
outputDims.push_back(getDimOp(rewriter, loc, adaptor.self(), i));
std::swap(outputDims[dim0], outputDims[dim1]);
Value outVector =
rewriter.create<linalg::InitTensorOp>(loc, outputDims, elementType);
// `idExprs` is the identity indexing; `swapExprs` treats `dim0` and `dim1`
// specially.
SmallVector<AffineExpr> idExprs;
SmallVector<AffineExpr> swapExprs;
for (auto i = 0; i < inputRank; i++)
idExprs.push_back(getAffineDimExpr(i, rewriter.getContext()));
for (auto i = 0; i < inputRank; i++) {
if (i == dim0)
else if (i == dim1)
SmallVector<AffineMap> indexingMaps = {
AffineMap::get(inputRank, 0, idExprs, op.getContext()),
AffineMap::get(inputRank, 0, swapExprs, op.getContext())};
SmallVector<StringRef> iteratorTypes(inputRank, "parallel");
auto transpose = rewriter
loc, outVector.getType(), inVector, outVector,
indexingMaps, iteratorTypes,
[](OpBuilder &b, Location loc, ValueRange args) {
b.create<linalg::YieldOp>(loc, args[0]);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outType, transpose);
return success();
} // namespace
namespace {
/// Lowers `AtenPermuteOp` to an elementwise copy with permuted indexing: the
/// init tensor takes its i-th size from input dimension `dimensions[i]`, the
/// payload yields the input element unchanged, and the result is cast to the
/// converted output type.
///
/// `dims` must be a constant int list with exactly one entry per input
/// dimension, each valid for the input rank after normalization.
// NOTE(review): this copy of the block appears to have lost lines (e.g. the
// initializer of `outType`, the body that fills `swapExprs`, the generic-op
// creation call and closing braces); confirm against the upstream file.
class ConvertAtenPermuteOp : public OpConversionPattern<AtenPermuteOp> {
using OpConversionPattern::OpConversionPattern;
matchAndRewrite(AtenPermuteOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
SmallVector<int64_t> dimensions;
if (!matchPattern(op.dims(), m_TorchConstantIntList(dimensions)))
return rewriter.notifyMatchFailure(op, "all dimensions must be constant");
Value inVector = adaptor.self();
auto inType = inVector.getType().cast<RankedTensorType>();
int64_t inputRank = inType.getRank();
auto outType = getTypeConverter()
Type elementType = inType.getElementType();
// Check if the dimensions are a valid constants.
int64_t numDimensions = dimensions.size();
if (inputRank != numDimensions)
return rewriter.notifyMatchFailure(
op, "size of `dims` must be equal to the rank of the input");
// Normalize negative entries in place and range-check every entry.
for (unsigned i = 0; i < numDimensions; i++) {
if (dimensions[i] < 0)
dimensions[i] = toPositiveDim(dimensions[i], inputRank);
if (!isValidDim(dimensions[i], inputRank))
return rewriter.notifyMatchFailure(op, "dimension out of range");
Location loc = op.getLoc();
// Output dim `i` has the size of input dim `dimensions[i]`.
SmallVector<Value> outputDims;
for (unsigned i = 0; i < inputRank; i++)
outputDims.push_back(getDimOp(rewriter, loc, inVector, dimensions[i]));
Value outVector =
rewriter.create<linalg::InitTensorOp>(loc, outputDims, elementType);
// `idExprs` is the identity indexing; `swapExprs` carries the permuted one.
SmallVector<AffineExpr> idExprs;
SmallVector<AffineExpr> swapExprs;
for (unsigned i = 0; i < inputRank; i++)
idExprs.push_back(getAffineDimExpr(i, rewriter.getContext()));
for (unsigned i = 0; i < inputRank; i++)
SmallVector<AffineMap> indexingMaps =
AffineMap::inferFromExprList({idExprs, swapExprs});
SmallVector<StringRef> iteratorTypes(inputRank, "parallel");
auto transpose = rewriter
loc, outVector.getType(), inVector, outVector,
indexingMaps, iteratorTypes,
[](OpBuilder &b, Location loc, ValueRange args) {
b.create<linalg::YieldOp>(loc, args[0]);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outType, transpose);
return success();
} // namespace
namespace {
/// Lowers `AtenSliceTensorOp` to a `tensor::ExtractSliceOp` along one
/// dimension, optionally followed by a `tensor::CollapseShapeOp` when the
/// result rank is smaller than the input rank, and a final `tensor::CastOp`.
///
/// `dim` and `step` must be constants (`step` may also be None, meaning 1).
/// `start` and `end` are clamped to [0, dimSize] after negative values are
/// made positive, and `end` is raised to `start` when it is smaller.
// NOTE(review): this copy of the block appears to have lost lines (e.g. the
// initializer of `resultType`, the end of the `adjustStartOrEnd` lambda, part
// of the optional-type check and closing braces); confirm against the
// upstream file before changing code further.
class ConvertAtenSliceTensorOp : public OpConversionPattern<AtenSliceTensorOp> {
using OpConversionPattern::OpConversionPattern;
matchAndRewrite(AtenSliceTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
TypeConverter *typeConverter = getTypeConverter();
auto input = adaptor.self();
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
RankedTensorType resultType =
int64_t resultRank = resultType.getRank();
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
int64_t dim;
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
return op->emitError("unimplemented: dim is not constant");
// `dim` indexes several rank-sized vectors below, so a negative or
// out-of-range value must be normalized and rejected before first use.
dim = toPositiveDim(dim, inputType.getRank());
if (!isValidDim(dim, inputType.getRank()))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
Value dimSize = inputShape[dim];
// Turns a start/end operand into an index clamped to [0, dimSize]. A None
// operand yields `valueForNone` unchanged.
auto adjustStartOrEnd = [&](Value startOrEndTorchType,
Value startOrEndBuiltin, Value valueForNone) {
if (startOrEndTorchType.getType().isa<Torch::NoneType>())
return valueForNone;
auto dimSizeAsInt = castIndexToInt(rewriter, loc, dimSize);
Value startOrEndToPositive =
toPositiveDimDynamic(rewriter, loc, startOrEndBuiltin, dimSizeAsInt);
// startOrEnd < 0 ? 0 : startOrEnd
Value cst0 = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(dimSizeAsInt.getType()));
Value predDimSltZero = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, startOrEndToPositive, cst0);
Value startOrEndAtLeastZero = rewriter.create<SelectOp>(
loc, predDimSltZero, cst0, startOrEndToPositive);
// startOrEnd > dimSizeAsInt ? dimSizeAsInt : startOrEnd
Value startOrEndSgtDimSize = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sgt, startOrEndAtLeastZero, dimSizeAsInt);
Value startOrEndBoundedByDimSize = rewriter.create<SelectOp>(
loc, startOrEndSgtDimSize, dimSizeAsInt, startOrEndAtLeastZero);
return castIntToIndex(rewriter, loc, startOrEndBoundedByDimSize);
if (op.start().getType().isa<OptionalType>() ||
return rewriter.notifyMatchFailure(op, "unimplemented optional type arg");
Value start = adjustStartOrEnd(op.start(), adaptor.start(), zero);
Value end = adjustStartOrEnd(op.end(), adaptor.end(), dimSize);
// end >= start ? end : start
Value endSgeStart = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, end, start);
end = rewriter.create<SelectOp>(loc, endSgeStart, end, start);
int64_t step;
if (!matchPattern(op.step(), m_TorchConstantInt(&step))) {
if (!op.step().getType().isa<Torch::NoneType>())
return op->emitError("unimplemented: step is not constant");
step = 1;
// Slice logic: resultSize = floordiv(end - start + step - 1, step)
Value stepIndex = rewriter.create<arith::ConstantIndexOp>(loc, step);
Value len = rewriter.create<arith::SubIOp>(loc, end, start);
Value resultSize = rewriter.create<arith::AddIOp>(loc, len, stepIndex);
resultSize = rewriter.create<arith::SubIOp>(loc, resultSize, one);
resultSize =
rewriter.create<arith::FloorDivSIOp>(loc, resultSize, stepIndex);
// Only the sliced dimension differs from the input: its size, offset and
// stride are overridden; every other dim keeps offset 0 and stride 1.
SmallVector<Value> resultShape = getTensorSizes(rewriter, loc, input);
resultShape[dim] = resultSize;
SmallVector<Value> offsets(inputType.getRank(), zero);
SmallVector<Value> strides(inputType.getRank(), one);
offsets[dim] = start;
strides[dim] = rewriter.create<arith::MulIOp>(loc, strides[dim], stepIndex);
Value result = rewriter.create<tensor::ExtractSliceOp>(
loc, input, offsets, resultShape, strides);
// TODO: This code is for selectOp, remove once squeeze dim is added
if (resultRank < inputType.getRank()) {
SmallVector<ReassociationIndices> reassociation(resultRank);
int64_t resultIdx = 0;
for (auto i : llvm::seq<int64_t>(0, inputType.getRank())) {
if (resultIdx < resultRank)
if (i != dim)
result =
rewriter.create<tensor::CollapseShapeOp>(loc, result, reassociation);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
return success();
} // namespace
namespace {
/// Lowers `AtenCatOp` by creating an init tensor with the concatenated size
/// and inserting each input into it with `tensor::InsertSliceOp`, advancing
/// the offset along `dim` after every insertion. The result is cast to the
/// converted result type.
///
/// `dim` must be a constant and the tensor list must come from a list
/// construct with at least one element.
// NOTE(review): this copy of the block appears to have lost lines (e.g. the
// initializer of `newResultType`, the end of the `toOpFoldResult` lambda and
// closing braces); confirm against the upstream file before changing code
// further.
class ConvertAtenCatOp : public OpConversionPattern<AtenCatOp> {
using OpConversionPattern::OpConversionPattern;
matchAndRewrite(AtenCatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
TypeConverter *typeConverter = getTypeConverter();
Value dimValue = op.dim();
int64_t dim;
if (!matchPattern(dimValue, m_TorchConstantInt(&dim)))
return op.emitError("unimplemented: dim is not constant");
// Collect all the tensors to be concatenated.
auto tensorList = op.tensors();
SmallVector<Value> tensorsTorchType;
if (!getListConstructElements(tensorList, tensorsTorchType))
return op.emitError(
"unimplemented: the tensor list is not from list construct");
// The first tensor supplies the base sizes below, so the list must not be
// empty.
if (tensorsTorchType.empty())
return rewriter.notifyMatchFailure(
op, "expected at least one tensor to concatenate");
auto tensors =
getTypeConvertedValues(rewriter, loc, typeConverter, tensorsTorchType);
RankedTensorType newResultType =
int rank = newResultType.getRank();
// `dim` indexes `sizes` and `offsets` below, so a negative or out-of-range
// value must be normalized and rejected before first use.
dim = toPositiveDim(dim, rank);
if (!isValidDim(dim, rank))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
SmallVector<Value> offsets, sizes, strides;
strides.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 1));
offsets.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 0));
for (int i = 0; i < rank; ++i)
sizes.push_back(rewriter.createOrFold<tensor::DimOp>(loc, tensors[0], i));
// Calculate the size of the `dim` result dimension by adding the dim size
// of each tensor together.
Value resultDimSize = sizes[dim];
Value dimIndex = rewriter.createOrFold<arith::ConstantOp>(
loc, rewriter.getIndexAttr(dim));
for (auto tensor : makeArrayRef(tensors).drop_front()) {
auto size = rewriter.createOrFold<tensor::DimOp>(loc, tensor, dimIndex);
resultDimSize =
rewriter.createOrFold<arith::AddIOp>(loc, resultDimSize, size);
sizes[dim] = resultDimSize;
// Values defined by a constant index op are passed as static attributes;
// everything else stays a dynamic value.
auto toOpFoldResult = [](Value v) -> OpFoldResult {
auto op = v.getDefiningOp<arith::ConstantIndexOp>();
if (!op)
return v;
return op.getValue();
Value result = rewriter.create<linalg::InitTensorOp>(
loc, sizes, newResultType.getElementType());
// Insert each input at the running offset, then advance the offset by that
// input's size along `dim`. The inner `sizes` shadows the outer one and
// holds the sizes of the tensor being inserted.
for (auto tensor : tensors) {
SmallVector<Value> sizes = getTensorSizes(rewriter, loc, tensor);
result = rewriter.createOrFold<tensor::InsertSliceOp>(
loc, tensor, result,
llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)),
llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)),
llvm::to_vector(llvm::map_range(strides, toOpFoldResult)));
offsets[dim] =
rewriter.createOrFold<arith::AddIOp>(loc, offsets[dim], sizes[dim]);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, result);
return success();
} // namespace
namespace {
/// Lowers `AtenGatherOp` to a `linalg::GenericOp` that iterates over the
/// `index` tensor: the output is a zero-initialized tensor with the sizes of
/// `index`, and the payload is produced by
/// `createLinalgPayloadCalculationForGatherOps`.
///
/// `dim` must be a constant.
// NOTE(review): this copy of the block appears to have lost lines (e.g. the
// initializer of `newResultTy`, the remaining arguments of
// `createZeroInitTensor` and `affineMaps`, the yield of the payload and
// closing braces); confirm against the upstream file before changing code.
// NOTE(review): no normalization of a negative `dim` is visible here; whether
// the payload helper handles it cannot be told from this block - verify.
class ConvertAtenGatherOp : public OpConversionPattern<AtenGatherOp> {
using OpConversionPattern::OpConversionPattern;
matchAndRewrite(AtenGatherOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op->getLoc();
Value dimValue = op.dim();
int64_t dim;
if (!matchPattern(dimValue, m_TorchConstantInt(&dim)))
return op.emitError("unimplemented: dim is not constant");
Value indices = adaptor.index();
Value self = adaptor.self();
RankedTensorType newResultTy =
int64_t rank = newResultTy.getRank();
// The result takes its shape from the index tensor.
SmallVector<Value> sizes = getTensorSizes(rewriter, loc, indices);
Value result = createZeroInitTensor(rewriter, loc, sizes,
SmallVector<AffineMap, 2> affineMaps(2,
SmallVector<StringRef> iteratorTypes(rank, getParallelIteratorTypeName());
// `args[0]` is the current element of the index tensor; `self` is captured
// by reference and passed to the payload helper.
auto genericOp = rewriter.create<linalg::GenericOp>(
loc, newResultTy, indices, result, affineMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
auto index = args[0];
createLinalgPayloadCalculationForGatherOps(b, loc, self, rank, index,
dim, rank);
rewriter.replaceOp(op, genericOp.getResult(0));
return success();
} // namespace
namespace {
/// Lowers `AtenEmbeddingOp` to a generic op that iterates over `indices` and
/// reads rows of the rank-2 `weight` tensor (lookup along dimension 0). The
/// result is cast to the converted result type.
///
/// The match fails when `weight` is not rank 2.
// NOTE(review): this copy of the block appears to have lost lines (e.g. the
// initializer of `newResultType`, the body of the `indicesExprs` loop, the
// contents of `indexingMaps`, the generic-op creation call and closing
// braces); confirm against the upstream file before changing code further.
class ConvertAtenEmbeddingOp : public OpConversionPattern<AtenEmbeddingOp> {
using OpConversionPattern::OpConversionPattern;
matchAndRewrite(AtenEmbeddingOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op->getLoc();
Value weight = adaptor.weight();
Value indices = adaptor.indices();
RankedTensorType newResultType =
auto weightTy = weight.getType().cast<RankedTensorType>();
if (weightTy.getRank() != 2)
return rewriter.notifyMatchFailure(op, "weight must be rank 2");
// Dimension 1 of the weight is the embedding width.
Value embeddingDim = getDimOp(rewriter, loc, weight, 1);
Type elemTy = weightTy.getElementType();
SmallVector<Value> sizes = getTensorSizes(rewriter, loc, indices);
int64_t resultRank = sizes.size();
// The indexing map for `indices` needs one expression per dimension of the
// indices tensor, so the rank must come from `indices`. It was previously
// read from `weight`, which is always rank 2 at this point.
auto indicesTy = indices.getType().cast<RankedTensorType>();
int64_t indicesRank = indicesTy.getRank();
SmallVector<AffineExpr> indicesExprs;
for (int i = 0; i < indicesRank; i++)
auto indicesAffineMap = AffineMap::get(
/*symbolCount=*/0, indicesExprs, op->getContext());
SmallVector<AffineMap, 2> indexingMaps = {
SmallVector<StringRef> iteratorTypes(sizes.size(),
Value initTensor =
rewriter.create<linalg::InitTensorOp>(loc, sizes, elemTy);
Value embeddingResult =
loc, initTensor.getType(), indices, initTensor,
/*indexingMaps=*/indexingMaps, /*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value index = args[0];
b, loc, weight, weightTy.getRank(), index, /*dim=*/0,
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType,
return success();
} // namespace
namespace {
/// Lowers `AtenSizeIntOp` to a `tensor::DimOp` on `self`. The (possibly
/// dynamic) `dim` operand is made positive against the input rank, asserted to
/// be valid, converted to an index for the dim query, and the queried size is
/// converted back to an integer.
// NOTE(review): this copy of the block appears to have lost lines (e.g. the
// return type of `matchAndRewrite` and closing braces); confirm against the
// upstream file before changing code.
class ConvertAtenSizeIntOp : public OpConversionPattern<AtenSizeIntOp> {
using OpConversionPattern::OpConversionPattern;
matchAndRewrite(AtenSizeIntOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op->getLoc();
Value self = adaptor.self();
Value dim = adaptor.dim();
auto type = self.getType().cast<RankedTensorType>();
// The static rank is materialized as an i64 constant so it can be combined
// with the runtime `dim` value.
Value inputRank = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(type.getRank()));
Value dimPositive = toPositiveDimDynamic(rewriter, loc, dim, inputRank);
assertIsValidDim(rewriter, loc, dimPositive, inputRank);
Value size = rewriter.create<tensor::DimOp>(
loc, adaptor.self(), castIntToIndex(rewriter, loc, dimPositive));
rewriter.replaceOp(op, castIndexToInt(rewriter, loc, size));
return success();
} // namespace
// Casts a 0d integer tensor to elemental type.
namespace {
// Lowers `AtenIntTensorOp` by extracting the single element of its operand
// with `tensor::ExtractOp`. Match fails unless the operand is a rank-0 tensor.
class ConvertAtenIntTensorOp : public OpConversionPattern<AtenIntTensorOp> {
using OpConversionPattern::OpConversionPattern;
matchAndRewrite(AtenIntTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Value intTensor = adaptor.a();
auto tensorType = intTensor.getType().cast<RankedTensorType>();
if (tensorType.getRank() != 0)
return rewriter.notifyMatchFailure(
op, "invalid rank: the rank of the input tensor must be 0");
// No indices are passed to the extract: the operand is rank 0.
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, intTensor);
return success();
} // namespace
namespace {
// Lowers `AtenFill_ScalarOp` by creating a new tensor with the runtime sizes
// of `self` and the element type of `self`, initialized via
// `createInitTensor` with the (dtype-converted) scalar `value`. The op is
// replaced by that new tensor; `self` is only read for its sizes and type.
class ConvertAtenFill_ScalarOp : public OpConversionPattern<AtenFill_ScalarOp> {
using OpConversionPattern::OpConversionPattern;
matchAndRewrite(AtenFill_ScalarOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op->getLoc();
Value self = adaptor.self();
Value initVal = adaptor.value();
auto tensorType = self.getType().cast<RankedTensorType>();
// Convert the scalar before it is used as the fill value.
Value initValCasted = convertScalarToDtype(rewriter, loc, initVal,
Value result =
createInitTensor(rewriter, loc, getTensorSizes(rewriter, loc, self),
tensorType.getElementType(), initValCasted);
rewriter.replaceOp(op, result);
return success();
} // namespace
namespace {
// Lowers `AtenBroadcastToOp` to a linalg generic op that copies `self` into an
// init tensor of the requested shape (the payload simply yields `args[0]`).
//
// Match fails when:
//   * the `size` operand does not come from a list construct, or
//   * the requested shape has fewer entries than the rank of `self`.
//
// Let `diff` be the number of extra leading dimensions in the requested shape.
// For every output dimension `i`:
//   * `i < diff` (new leading dimension): the requested size is compared
//     against zero with `sge` and used as the output size.
//   * `self` has static size 1 at the matching dimension `j = i - diff`: the
//     input is read at constant index 0 along that dimension (broadcast). A
//     negative requested size selects the constant 1 instead.
//   * otherwise: the requested size must be negative or equal to the actual
//     size of `self` at `j`; the input is read at the output index `i`.
// The generic result is cast to the converted result type.
class ConvertAtenBroadcastToOp : public OpConversionPattern<AtenBroadcastToOp> {
using OpConversionPattern::OpConversionPattern;
matchAndRewrite(AtenBroadcastToOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Value self = adaptor.self();
auto selfType = self.getType().cast<RankedTensorType>();
ArrayRef<int64_t> selfShape = selfType.getShape();
Type elementType = selfType.getElementType();
Location loc = op.getLoc();
MLIRContext *context = op->getContext();
SmallVector<Value> inShape, outShape;
if (!getListConstructElements(adaptor.size(), inShape)) {
return rewriter.notifyMatchFailure(
op, "unimplemented: the size list is not from list construct");
SmallVector<Value> inShapeConverted =
getTypeConvertedValues(rewriter, loc, getTypeConverter(), inShape);
if (inShape.size() < selfShape.size())
return rewriter.notifyMatchFailure(
op, "invalid shape: must not be smaller than rank of tensor");
// Number of leading output dimensions that have no counterpart in `self`.
size_t diff = inShape.size() - selfShape.size();
// Create affine map and shapes for tensor initialization.
SmallVector<AffineExpr> outExpr;
Value zero =
rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(0));
for (size_t i = 0; i < inShape.size(); i++) {
Value shapeValue = inShapeConverted[i];
// `j` is the matching dimension of `self`. The subtraction is unsigned, so
// the value is only meaningful once `i >= diff`.
size_t j = i - diff;
if (i < diff) {
Value isValid = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, shapeValue, zero);
loc, isValid,
"negative values not allowed in new dimensions"));
outShape.push_back(castIntToIndex(rewriter, loc, shapeValue));
if (selfShape[j] == 1) {
// Broadcast singleton dimension
Value one =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
Value isNegative = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, shapeValue, zero);
Value select = rewriter.create<SelectOp>(
loc, isNegative, one, castIntToIndex(rewriter, loc, shapeValue));
// Always read element 0 of `self` along a broadcast dimension.
outExpr.push_back(mlir::getAffineConstantExpr(0, context));
// Non-broadcast case
Value dim = getDimOp(rewriter, loc, self, j);
Value isNegative = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, shapeValue, zero);
Value isEqual = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, castIndexToInt(rewriter, loc, dim),
Value isValid = rewriter.create<arith::OrIOp>(loc, isNegative, isEqual);
loc, isValid,
"only broadcasting singleton dimensions supported"));
outExpr.push_back(mlir::getAffineDimExpr(i, context));
Value outTensor =
rewriter.create<linalg::InitTensorOp>(loc, outShape, elementType);
SmallVector<AffineMap> indexingMaps = {
AffineMap::get(inShape.size(), 0, outExpr, context),
SmallVector<StringRef> iteratorTypes(inShape.size(), "parallel");
Value result = rewriter
loc, outTensor.getType(), self, outTensor,
indexingMaps, iteratorTypes,
[](OpBuilder &b, Location loc, ValueRange args) {
// Pure copy: forward the input element to the output.
b.create<linalg::YieldOp>(loc, args[0]);
Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, result);
return success();
} // namespace
namespace {
// Lowers `AtenContiguousOp` by replacing it with its (already converted)
// `self` operand; no new ops are created.
class ConvertAtenContiguousOp : public OpConversionPattern<AtenContiguousOp> {
using OpConversionPattern::OpConversionPattern;
matchAndRewrite(AtenContiguousOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
rewriter.replaceOp(op, adaptor.self());
return success();
} // namespace
namespace {
// Lowers the scalar-to-tensor ops `AtenTensorIntOp` and `AtenTensorFloatOp`.
//
// The pattern is registered for any op type (`MatchAnyOpTypeTag`) and filters
// on the two supported ops itself. The scalar operand `t` is converted to the
// element type of the converted result type and used as the initial value of
// a tensor created by `createInitTensor` with an empty size list; that tensor
// is then cast to the result type.
//
// Match fails for any other op, for non-linalg-compatible types, and when
// `dtype` or `device` is not `None`.
struct ConvertAtenScalarToTensorLike : ConversionPattern {
ConvertAtenScalarToTensorLike(TypeConverter &typeConverter,
MLIRContext *context)
: ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
context) {}
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (!isa<AtenTensorIntOp, AtenTensorFloatOp>(op))
return rewriter.notifyMatchFailure(
op, "not a supported Scalar to Tensor like op");
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op->getLoc();
// `elemVal` is read from the converted operands (through the adaptor);
// `dtype`, `device` and `requires_grad` are read from the original op.
// NOTE(review): `requires_grad` is assigned but no use of it is visible in
// this pattern -- confirm whether it is intentionally ignored.
Value elemVal, dtype, device, requires_grad;
if (AtenTensorIntOp tensorIntOp = dyn_cast<AtenTensorIntOp>(op)) {
AtenTensorIntOp::Adaptor adaptor(operands);
elemVal = adaptor.t();
dtype = tensorIntOp.dtype();
device = tensorIntOp.device();
requires_grad = tensorIntOp.requires_grad();
if (AtenTensorFloatOp tensorFloatOp = dyn_cast<AtenTensorFloatOp>(op)) {
AtenTensorFloatOp::Adaptor adaptor(operands);
elemVal = adaptor.t();
dtype = tensorFloatOp.dtype();
device = tensorFloatOp.device();
requires_grad = tensorFloatOp.requires_grad();
// TODO: Dtype conversion.
if (!dtype.getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(op, "Unimplemented non-None dtype");
// TODO: Device information.
if (!device.getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "Unimplemented non-None device information");
RankedTensorType resultType = getTypeConverter()
Type outElementType = resultType.getElementType();
Value elemValProm =
convertScalarToDtype(rewriter, loc, elemVal, outElementType);
// An empty size list is passed, i.e. no dimensions are requested.
Value zeroDTensor =
createInitTensor(rewriter, loc, {}, outElementType, elemValProm);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, zeroDTensor);
return success();
} // namespace
namespace {
// Converts constant tensor allocation like ops.
//
// Template parameters:
//   OpTy    - the op being converted; it must provide `layout()`,
//             `pin_memory()` and `size()` accessors.
//   fillVal - the integer constant every element of the result is set to.
//
// Match fails when `layout` is not `None`, when `pin_memory` is neither
// `None` nor a constant `false`, or when `size` does not come from a list
// construct.
template <typename OpTy, int fillVal>
class ConvertConstantTensorAllocOp : public OpConversionPattern<OpTy> {
using OpConversionPattern<OpTy>::OpConversionPattern;
matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
// TODO: Add support for layout, pin_memory features.
// Only `none` layout is supported.
if (!op.layout().getType().template isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "unimplemented: only default layout is supported");
// The pin_memory should be either `False` or `none`.
bool pinMemory;
if (!op.pin_memory().getType().template isa<Torch::NoneType>() &&
(!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)) ||
pinMemory)) {
return rewriter.notifyMatchFailure(
op, "unimplemented: pin_memory must be either None or false");
Location loc = op.getLoc();
TypeConverter *typeConverter = this->getTypeConverter();
// The sizes go through three forms: the original torch values, their
// type-converted counterparts, and finally the values produced by
// `castIntToIndex`, which are what the tensor creation consumes.
SmallVector<Value> resultSizeTorchInt, resultSize, resultSizeIndex;
if (!getListConstructElements(op.size(), resultSizeTorchInt)) {
return rewriter.notifyMatchFailure(
op, "unimplemented: size must be constructed using ListConstruct");
resultSize = getTypeConvertedValues(rewriter, loc, typeConverter,
for (auto size : resultSize)
resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size));
auto resultType =
typeConverter->convertType(op.getType()).template cast<RankedTensorType>();
Type outElemType = resultType.getElementType();
// Create an uninitialized tensor of `resultSize` shape and fill it with
// value `fillVal`.
Value constVal = getConstant(rewriter, loc, fillVal, outElemType);
Value outputTensor =
createInitTensor(rewriter, loc, resultSizeIndex, outElemType, constVal);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, outputTensor);
return success();
} // namespace
namespace {
// Converts `aten.empty` to `linalg.init_tensor` op.
//
// The result is a `linalg::InitTensorOp` with the requested sizes and the
// element type of the converted result type, cast to that result type. No
// fill is performed.
//
// Match fails when `layout` or `memory_format` is not `None`, when
// `pin_memory` is neither `None` nor a constant `false`, or when `size` does
// not come from a list construct.
class ConvertAtenEmptyMemoryFormatOp
: public OpConversionPattern<AtenEmptyMemoryFormatOp> {
using OpConversionPattern::OpConversionPattern;
matchAndRewrite(AtenEmptyMemoryFormatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
// TODO: Add support for layout, pin_memory and memory_format features.
// Only `none` layout is supported.
if (!op.layout().getType().template isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "unimplemented: only default layout is supported");
// The pin_memory should be either `False` or `none`.
bool pinMemory;
if (!op.pin_memory().getType().template isa<Torch::NoneType>() &&
(!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)) ||
return rewriter.notifyMatchFailure(
op, "unimplemented: pin_memory must be either None or false");
// Only `none` memory_format is supported.
if (!op.memory_format().getType().template isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "unimplemented: only default memory format is supported");
Location loc = op.getLoc();
TypeConverter *typeConverter = this->getTypeConverter();
// Torch size values -> type-converted values -> values produced by
// `castIntToIndex`.
SmallVector<Value> resultSizeTorchInt, resultSize, resultSizeIndex;
if (!getListConstructElements(op.size(), resultSizeTorchInt)) {
return rewriter.notifyMatchFailure(
op, "unimplemented: size must be constructed using ListConstruct");
resultSize = getTypeConvertedValues(rewriter, loc, typeConverter,
for (auto size : resultSize)
resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size));
auto resultType = typeConverter->convertType(op.getType())
.template cast<RankedTensorType>();
// Create an uninitialized tensor of `resultSize` shape.
Value initTensor = rewriter.create<linalg::InitTensorOp>(
loc, resultSizeIndex, resultType.getElementType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, initTensor);
return success();
} // namespace
namespace {
// Lowers `PrimNumToTensorScalarOp` to a `linalg::FillOp` that writes the
// converted scalar operand `a` into an init tensor created with no dynamic
// sizes and with the type of `a` as its element type.
class ConvertPrimNumToTensorScalarOp
: public OpConversionPattern<PrimNumToTensorScalarOp> {
using OpConversionPattern::OpConversionPattern;
matchAndRewrite(PrimNumToTensorScalarOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
Value a = adaptor.a();
// The element type is taken from the scalar itself, not from the result
// type of the op.
Value outTensor =
rewriter.create<linalg::InitTensorOp>(loc, ValueRange{}, a.getType())
rewriter.replaceOpWithNewOp<linalg::FillOp>(op, a, outTensor);
return success();
} // namespace
namespace {
// Lowers `AtenNumelOp` to a chain of `arith::MulIOp` that multiplies all
// dimension sizes of `self`, starting from the index constant 1. With no
// dimensions the result is therefore 1. The product is passed through
// `castIndexToInt` before replacing the op.
class ConvertAtenNumelOp : public OpConversionPattern<AtenNumelOp> {
using OpConversionPattern::OpConversionPattern;
matchAndRewrite(AtenNumelOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
Value self = adaptor.self();
SmallVector<Value> sizes(getTensorSizes(rewriter, loc, self));
// Running product, seeded with 1.
Value productResult =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
for (size_t i = 0; i < sizes.size(); i++)
productResult =
rewriter.create<arith::MulIOp>(loc, productResult, sizes[i]);
rewriter.replaceOp(op, castIndexToInt(rewriter, loc, productResult));
return success();
} // namespace
namespace {
// Let's say we have an input tensor: initialized with some random values of
// size [4, 5, 6]. An index tensor (always 1-d): [0, 2] of size [2], and an
// integer argument dim = 1. The size of the output tensor will be [4, 2, 6].
// The approach is as follows:
// for i in range(input.size[0])
// for j in range(index.size[0])
// for k in range(input.size[2])
// indexValue = index[j]
// output[i,j,k] = input[i,indexValue,k]
//
// `dim` must be a compile-time constant. A negative `dim` is counted from the
// last dimension of the input; a `dim` that is still outside [0, rank) after
// that adjustment is reported as a match failure, because it is used below to
// index `resultShape` and `indexTarget`.
class ConvertAtenIndexSelectOp : public OpConversionPattern<AtenIndexSelectOp> {
using OpConversionPattern::OpConversionPattern;
matchAndRewrite(AtenIndexSelectOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
Value input = adaptor.self();
Value indices = adaptor.index();
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
RankedTensorType resultType = getTypeConverter()
Type elementType = resultType.getElementType();
unsigned inputRank = inputType.getRank();
int64_t dimInt;
// Report a match failure (not a hard error) so that the conversion driver
// decides how to handle the op, like every other pattern in this file.
if (!matchPattern(op.dim(), m_TorchConstantInt(&dimInt)))
return rewriter.notifyMatchFailure(op, "unimplemented: dim is not constant");
// Normalize a negative dim and reject anything out of range before it is
// used as a container index.
if (dimInt < 0)
dimInt += static_cast<int64_t>(inputRank);
if (dimInt < 0 || dimInt >= static_cast<int64_t>(inputRank))
return rewriter.notifyMatchFailure(
op, "invalid dim: must be in range [-rank, rank)");
// The result has the shape of the input, except along `dim` where it has the
// length of the 1-d index tensor.
SmallVector<Value> resultShape = getTensorSizes(rewriter, loc, input);
resultShape[dimInt] = getTensorSizes(rewriter, loc, indices)[0];
Value initTensor =
rewriter.create<linalg::InitTensorOp>(loc, resultShape, elementType);
SmallVector<AffineExpr> resultExpr;
// The index tensor is addressed by the iteration index of dimension `dim`.
AffineExpr indicesExpr = rewriter.getAffineDimExpr(dimInt);
SmallVector<StringRef> iteratorTypes;
for (unsigned i = 0; i < inputRank; i++) {
auto indexingMaps = AffineMap::inferFromExprList({indicesExpr, resultExpr});
Value finalRes =
loc, initTensor.getType(), ValueRange{indices}, initTensor,
[&](OpBuilder &b, Location loc, ValueRange args) {
// Build the payload with the builder handed to the callback.
Value index = b.create<arith::IndexCastOp>(
loc, b.getIndexType(), args[0]);
// Address the input with the current iteration indices, except along
// `dim`, where the value read from the index tensor is used.
SmallVector<Value> indexTarget;
for (unsigned i = 0; i < inputRank; i++)
indexTarget.push_back(b.create<linalg::IndexOp>(loc, i));
indexTarget[dimInt] = index;
Value extractedElement =
b.create<tensor::ExtractOp>(loc, input, indexTarget);
b.create<linalg::YieldOp>(loc, extractedElement);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, finalRes);
return success();
} // namespace
namespace {
// Let's say the result of the `aten.arange.start_step` is `output` which is a
// 1-d output tensor. The approach used for generating the output tensor is as
// follows:
// for i in range(ceil((end-start)/step))
// output[i] = start + (i * step)
//
// `start`, `end` and `step` are first converted to the element type of the
// result. Both the length computation and the payload pick integer or
// floating-point arithmetic based on that element type.
//
// Match fails when `layout` is not `None` or when `pin_memory` is neither
// `None` nor a constant `false`.
class ConvertAtenArangeStartStepOp
: public OpConversionPattern<AtenArangeStartStepOp> {
using OpConversionPattern::OpConversionPattern;
matchAndRewrite(AtenArangeStartStepOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
// TODO: Add support for layout, pin_memory features.
// Only `none` layout is supported.
if (!op.layout().getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "unimplemented: only default layout is supported");
// The pin_memory should be either `False` or `none`.
bool pinMemory;
if (!op.pin_memory().getType().isa<Torch::NoneType>() &&
(!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)) ||
pinMemory)) {
return rewriter.notifyMatchFailure(
op, "unimplemented: pin_memory must be either None or false");
Location loc = op.getLoc();
TypeConverter *typeConverter = this->getTypeConverter();
RankedTensorType resultType =
Type dtype = resultType.getElementType();
Value start = convertScalarToDtype(rewriter, loc, adaptor.start(), dtype);
Value end = convertScalarToDtype(rewriter, loc, adaptor.end(), dtype);
Value step = convertScalarToDtype(rewriter, loc, adaptor.step(), dtype);
// The result will always be a 1-d tensor.
// The size of the result is calculated as follows:
// ceil((end - start)/step)
Value resultShape;
if (dtype.isa<mlir::IntegerType>()) {
// Integer element type: signed ceiling division.
Value subOut = rewriter.create<arith::SubIOp>(loc, end, start);
resultShape = rewriter.create<arith::CeilDivSIOp>(loc, subOut, step);
} else {
// Otherwise: floating-point divide, ceil, then an unsigned
// float-to-integer conversion to i64.
Value subOut = rewriter.create<arith::SubFOp>(loc, end, start);
Value divOut = rewriter.create<arith::DivFOp>(loc, subOut, step);
Value ceilOut = rewriter.create<math::CeilOp>(loc, divOut);
resultShape =
rewriter.create<arith::FPToUIOp>(loc, rewriter.getI64Type(), ceilOut);
resultShape = castIntToIndex(rewriter, loc, resultShape);
Value resultTensor =
rewriter.create<linalg::InitTensorOp>(loc, resultShape, dtype);
StringRef iteratorType = getParallelIteratorTypeName();
AffineMap indexingMap =
AffineMap::getMultiDimIdentityMap(1, op->getContext());
Value finalRes =
loc, /*resultTensorTypes=*/resultTensor.getType(),
[&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
// The element value depends only on the iteration index, which is
// converted to the element type before the arithmetic.
Value index = b.create<linalg::IndexOp>(loc, 0);
index = castIndexToInt(b, loc, index);
index = convertScalarToDtype(b, loc, index, dtype);
Value mulOut, result;
if (dtype.isa<mlir::FloatType>()) {
mulOut = b.create<arith::MulFOp>(loc, step, index);
result = b.create<arith::AddFOp>(loc, start, mulOut);
} else {
mulOut = b.create<arith::MulIOp>(loc, step, index);
result = b.create<arith::AddIOp>(loc, start, mulOut);
b.create<linalg::YieldOp>(loc, result);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, finalRes);
return success();
} // namespace
namespace {
// Lowers `AtenIndexTensorOp` for the single supported case: exactly one index
// tensor applied to a rank-1 input. The result takes its dimension sizes from
// that index tensor, and each element is extracted from the input at the
// position given by the corresponding index value.
//
// Match fails when the indices list does not come from a list construct, when
// the single index is `None`, and for every other combination of input rank
// and number of index tensors.
class ConvertAtenIndexTensorOp : public OpConversionPattern<AtenIndexTensorOp> {
using OpConversionPattern::OpConversionPattern;
matchAndRewrite(AtenIndexTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
Value input = adaptor.self();
// The list operand is read from the original op; its elements are
// type-converted individually below.
Value indices = op.indices();
SmallVector<Value> indicesTuple;
if (!getListConstructElements(indices, indicesTuple)) {
return rewriter.notifyMatchFailure(
op, "unimplemented: the indices list is not from a list construct");
SmallVector<Value> indicesVal =
getTypeConvertedValues(rewriter, loc, getTypeConverter(), indicesTuple);
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
RankedTensorType resultType = getTypeConverter()
Type elementType = resultType.getElementType();
unsigned inputRank = inputType.getRank();
unsigned numIndexTensors = indicesTuple.size();
SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
// Case 1 : When numIndexTensors == 1 and `input` is a 1-d tensor.
// TODO: generalize the implementation for other cases.
if (numIndexTensors == 1 && inputRank == 1) {
if (failed(checkNotNone(rewriter, op, indicesVal[0])))
return rewriter.notifyMatchFailure(op, "unimplemented None type arg");
unsigned resultRank =
SmallVector<Value> resultShape;
SmallVector<AffineExpr> indicesExpr, resultExpr;
SmallVector<StringRef> iteratorTypes;
// The result shape is read dimension by dimension from the index tensor.
for (unsigned i = 0; i < resultRank; i++)
resultShape.push_back(getDimOp(rewriter, loc, indicesVal[0], i));
Value initTensor =
rewriter.create<linalg::InitTensorOp>(loc, resultShape, elementType);
for (unsigned i = 0; i < resultRank; i++) {
auto indexingMaps =
AffineMap::inferFromExprList({indicesExpr, resultExpr});
Value finalRes =
loc, initTensor.getType(), ValueRange{indicesVal[0]},
[&](OpBuilder &b, Location loc, ValueRange args) {
// `args[0]` is the index value for the current output element; the
// input is rank 1, so a single index addresses it.
Value indexTarget = castIntToIndex(b, loc, args[0]);
Value extractedElement =
b.create<tensor::ExtractOp>(loc, input, indexTarget);
b.create<linalg::YieldOp>(loc, extractedElement);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, finalRes);
return success();
} else
return rewriter.notifyMatchFailure(
op, "unimplemented: support for this set of inputs not present");
} // namespace
namespace {
// Lowers `PseudoAtenUniformOp` to a linalg generic op with no inputs that
// fills a tensor of the same sizes as `self` with pseudo-random values in the
// range given by `from` and `to`.
//
// Each element is derived from its own indices and a seed obtained from
// `GetNextSeedOp`, using one linear congruential step per dimension. The
// resulting 64-bit integer is converted to the element type, scaled by
// (to - from) * 5.4210108E-20 and offset by `from`.
//
// Match fails when the element type is not a float type or when `generator`
// is not `None`.
class ConvertPseudoAtenUniformOp
: public OpConversionPattern<PseudoAtenUniformOp> {
using OpConversionPattern::OpConversionPattern;
matchAndRewrite(PseudoAtenUniformOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
Value self = adaptor.self();
Value from = adaptor.from();
// Upper bound of the range; converted to `max` below.
Value to = adaptor.to();
Value generator = adaptor.generator();
RankedTensorType resultType = self.getType().cast<RankedTensorType>();
Type elemTy = resultType.getElementType();
if (!elemTy.isa<mlir::FloatType>())
return rewriter.notifyMatchFailure(op, "This op only support float type");
if (!generator.getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "The generator has to be None because only global default "
"generator is supported");
// Build the core formula of LCG Algorithm that makes use of element index:
// For output matrix with rank N:
// temp1 = (cast(I64, index(D.0)) + seed) * multiplier + incrementStep
// ...
// tempN = (cast(I64, index(D.(N))) + tempN-1) * multiplier + incr
// The multiplier and incrementStep below are the constants of Knuth's MMIX
// linear congruential generator, which works modulo 2^64.
Value multiplier = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(6364136223846793005));
Value incrementStep = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(1442695040888963407));
// Tn = (index + Tn-1) * multiplier + incrementStep
auto getNextTemp = [&](OpBuilder &b, Value index, Value temp) {
Value castIndex =
b.create<arith::IndexCastOp>(loc, b.getI64Type(), index);
Value add = b.create<arith::AddIOp>(loc, castIndex, temp);
Value mult = b.create<arith::MulIOp>(loc, add, multiplier);
return b.create<arith::AddIOp>(loc, mult, incrementStep);
// Get initial seed, min and max used by `linalg.generic` compute payload.
Value initialSeed = rewriter.create<GetNextSeedOp>(loc);
Value min = convertScalarToDtype(rewriter, loc, from, elemTy);
Value max = convertScalarToDtype(rewriter, loc, to, elemTy);
// Construct the `linalg.generic` op.
auto resultRank = resultType.getRank();
SmallVector<AffineMap, 1> indexingMaps(
1, rewriter.getMultiDimIdentityMap(resultRank));
SmallVector<StringRef> iteratorTypes(resultRank,
SmallVector<Value> sizes = getTensorSizes(rewriter, loc, self);
Value initTensor =
rewriter.create<linalg::InitTensorOp>(loc, sizes, elemTy);
Value uniformRes =
loc, resultType, /*inputs=*/ValueRange{},
/*outputs=*/initTensor, indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
// Fold every iteration index into the running LCG state.
Value temp = initialSeed;
for (int i = 0; i < resultRank; i++) {
Value index = b.create<linalg::IndexOp>(loc, i);
temp = getNextTemp(b, index, temp);
// scale = (max - min) * const(elemTy, 5.4210108E-20)
// which is derived from rand(min,max) =
// rand()/(RAND_MAX/(max-min)) where RAND_MAX = 2^64 - 1
Value epsilon = b.create<arith::ConstantOp>(
loc, b.getFloatAttr(min.getType(), 5.4210108E-20));
Value range = b.create<arith::SubFOp>(loc, max, min);
Value scale = b.create<arith::MulFOp>(loc, range, epsilon);
// res = cast(elemTy, tempN) * scale + min
// The state is interpreted as an unsigned integer for the conversion.
Value updateFloat =
b.create<arith::UIToFPOp>(loc, elemTy, temp);
Value updateScaled =
b.create<arith::MulFOp>(loc, updateFloat, scale);
Value res = b.create<arith::AddFOp>(loc, updateScaled, min);
b.create<linalg::YieldOp>(loc, res);
Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, uniformRes);
return success();
} // namespace
// -----------------------------------------------------------------------------
// The pass
// -----------------------------------------------------------------------------
namespace {
// The Torch-to-Linalg conversion pass.
//
// `runOnOperation` sets up a conversion target and a type converter (identity
// conversion plus `setupBackendTypeConversion`), registers every conversion
// pattern of this file, and runs a partial conversion over the operation the
// pass is applied to. The pass is signalled as failed if that conversion
// fails.
class ConvertTorchToLinalg
: public ConvertTorchToLinalgBase<ConvertTorchToLinalg> {
void getDependentDialects(DialectRegistry &registry) const override {
void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
math::MathDialect, tensor::TensorDialect,
// Types are left unchanged by default; the backend type conversion adds
// its own rules on top of this identity conversion.
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(target, typeConverter);
RewritePatternSet patterns(context);
patterns.add<ConvertAtenMmOp>(typeConverter, context);
patterns.add<ConvertAtenMatmulOp>(typeConverter, context);
patterns.add<ConvertAtenBmmOp>(typeConverter, context);
patterns.add<ConvertAtenLinearOp>(typeConverter, context);
patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp,
AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, AtenMaximumOp,
AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, AtenSqrtOp,
AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp,
AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
AtenEqScalarOp, AtenLtScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp,
AtenThresholdBackwardOp, AtenCloneOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context);
patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);
patterns.add<ConvertAtenSqueezeDimOp>(typeConverter, context);
patterns.add<ConvertAtenUnsqueezeOp>(typeConverter, context);
patterns.add<ConvertAtenConv2dOp>(typeConverter, context);
patterns.add<ConvertAtenAdaptiveAvgPool2dOp>(typeConverter, context);
patterns.add<ConvertAtenFlattenUsingIntsOp>(typeConverter, context);
patterns.add<ConvertAtenViewOp>(typeConverter, context);
patterns.add<ConvertAtenMaxPool2dOp>(typeConverter, context);
patterns.add<ConvertAtenConstantPadNdOp>(typeConverter, context);
patterns.add<ConvertReductionOp>(typeConverter, context);
patterns.add<ConvertAtenTransposeIntOp>(typeConverter, context);
patterns.add<ConvertAtenPermuteOp>(typeConverter, context);
patterns.add<ConvertAtenCatOp>(typeConverter, context);
patterns.add<ConvertAtenGatherOp>(typeConverter, context);
patterns.add<ConvertAtenNativeLayerNormOp>(typeConverter, context);
patterns.add<ConvertAtenBroadcastToOp>(typeConverter, context);
patterns.add<ConvertAtenMaxDimOp>(typeConverter, context);
patterns.add<ConvertAtenSizeIntOp>(typeConverter, context);
patterns.add<ConvertAtenEmbeddingOp>(typeConverter, context);
patterns.add<ConvertAtenEmptyMemoryFormatOp>(typeConverter, context);
// The templated allocation pattern is instantiated once per op, with the
// fill constant as the second template argument.
target.addIllegalOp<AtenZerosOp, AtenOnesOp>();
patterns.add<ConvertConstantTensorAllocOp<AtenZerosOp, 0>>(typeConverter,
patterns.add<ConvertConstantTensorAllocOp<AtenOnesOp, 1>>(typeConverter,
patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
patterns.add<ConvertAtenIntTensorOp>(typeConverter, context);
patterns.add<ConvertPrimNumToTensorScalarOp>(typeConverter, context);
patterns.add<ConvertAtenDropoutOp>(typeConverter, context);
patterns.add<ConvertAtenFill_ScalarOp>(typeConverter, context);
patterns.add<ConvertAtenNumelOp>(typeConverter, context);
patterns.add<ConvertAtenSliceTensorOp>(typeConverter, context);
patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
patterns.add<ConvertAtenNllLossBackwardOp>(typeConverter, context);
patterns.add<ConvertAtenIndexSelectOp>(typeConverter, context);
// This pattern matches any op type and filters internally, so the ops it
// handles are marked illegal explicitly right after it is added.
patterns.add<ConvertAtenScalarToTensorLike>(typeConverter, context);
target.addIllegalOp<AtenTensorIntOp, AtenTensorFloatOp>();
patterns.add<ConvertAtenArangeStartStepOp>(typeConverter, context);
patterns.add<ConvertAtenIndexTensorOp>(typeConverter, context);
patterns.add<ConvertPseudoAtenUniformOp>(typeConverter, context);
if (failed(applyPartialConversion(getOperation(), target,
return signalPassFailure();
} // namespace
mlir::torch::createConvertTorchToLinalgPass() {
return std::make_unique<ConvertTorchToLinalg>();