torch-mlir/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp

1045 lines
41 KiB
C++

//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include <climits>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <numeric>
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "llvm/Support/FormatVariadic.h"
namespace mlir {
namespace tosa {
using namespace mlir::torch::Torch;
// Builds the constant i32 index tensor for one dimension `dim` of a
// torch-style index tensor, i.e. the coordinate along `dim` of every element
// of a tensor of shape `indexShape`, enumerated in row-major order
// (the `dim`-th result of torch.stack(torch.meshgrid(...)) unsqueezed at the
// end).
//
//   rewriter, op        - used to materialize the constant via getConstTensor.
//   indicesOneDimShape  - shape given to the resulting constant
//                         (callers in this file pass indexShape + [1]).
//   dim                 - dimension of indexShape whose coordinates are
//                         produced; must be in [0, indexShape.size()).
//   indexShape          - shape of the torch index tensor.
//
// Returns std::nullopt when `dim` is out of range, when any extent of
// indexShape is negative (the same `< 0` test this file uses elsewhere to
// detect dynamic dimensions), or when getConstTensor fails.
//
// Example: indexShape=[1,4,2]
//   dim0: [0 0 0 0 0 0 0 0]
//   dim1: [0 0 1 1 2 2 3 3]
//   dim2: [0 1 0 1 0 1 0 1]   (Caution: not [0 0 0 0 1 1 1 1])
// Example: indexShape=[2,4,2]
//   dim1: [0 0 1 1 2 2 3 3 0 0 1 1 2 2 3 3]
std::optional<Value>
createOneDimTfIndices(PatternRewriter &rewriter, Operation *op,
                      SmallVector<int64_t> indicesOneDimShape, int32_t dim,
                      ArrayRef<int64_t> indexShape) {
  const int32_t indexRank = static_cast<int32_t>(indexShape.size());
  if (dim < 0 || dim >= indexRank)
    return std::nullopt;

  // The constant can only be enumerated when every extent is known.
  for (int64_t extent : indexShape) {
    if (extent < 0)
      return std::nullopt;
  }

  // In row-major order the coordinate along `dim` stays fixed while all
  // trailing dimensions are traversed (innerRepeat), and the whole
  // 0..extent-1 ramp restarts for every combination of the leading
  // dimensions (outerRepeat).
  //   indexShape=[1,4,2]: dim0 -> outer=1, inner=8
  //                       dim1 -> outer=1, inner=2
  //                       dim2 -> outer=4, inner=1
  int64_t outerRepeat = 1;
  for (int32_t i = 0; i < dim; ++i)
    outerRepeat *= indexShape[i];

  int64_t innerRepeat = 1;
  for (int32_t i = dim + 1; i < indexRank; ++i)
    innerRepeat *= indexShape[i];

  const int64_t dimExtent = indexShape[dim];

  SmallVector<int32_t> indicesVec; // input vec to create tosaConstant
  indicesVec.reserve(static_cast<size_t>(outerRepeat * dimExtent * innerRepeat));
  for (int64_t outer = 0; outer < outerRepeat; ++outer) {
    for (int64_t coord = 0; coord < dimExtent; ++coord) {
      for (int64_t inner = 0; inner < innerRepeat; ++inner) {
        indicesVec.push_back(static_cast<int32_t>(coord));
      }
    }
  }

  // Create tosa::ConstOp Tensor for indicesVec with target shape.
  // torch.unsqueeze(torch.stack(torch.meshgrid)))
  // dim0:          tensor([[ [ [0], [0] ],
  //                          [ [0], [0] ],
  //                          [ [0], [0] ],
  //                          [ [0], [0] ], ]]) 1*4*2*1
  // dim1:          tensor([[ [ [0], [0] ],
  //                          [ [1], [1] ],
  //                          [ [2], [2] ],
  //                          [ [3], [3] ], ]]) 1*4*2*1
  // dim2/last dim: tensor([[ [ [0], [1] ],
  //                          [ [0], [1] ],
  //                          [ [0], [1] ],
  //                          [ [0], [1] ], ]]) 1*4*2*1
  auto indicesDim = getConstTensor<int32_t>(rewriter, op,
                                            /*vec=*/indicesVec,
                                            /*shape=*/indicesOneDimShape);
  return indicesDim;
}
// Creates a tosa.mul of `lhs` and `rhs` after passing both operands through
// promoteType with `outType`; `shift` is forwarded unchanged to the op
// builder. The op is created at the location of `op`.
tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op,
                               TensorType outType, Value lhs, Value rhs,
                               int32_t shift) {
  Value promotedLhs = promoteType(rewriter, lhs, outType);
  Value promotedRhs = promoteType(rewriter, rhs, outType);
  return tosa::CreateOpAndInfer<tosa::MulOp>(
      rewriter, op->getLoc(), outType, promotedLhs, promotedRhs, shift);
}
// Specialization of createBinaryOpAndCast for tosa.div: both operands are
// passed through promoteType with `outType` and a tosa.div is created at the
// location of `op`.
template <>
tosa::DivOp createBinaryOpAndCast<DivOp>(PatternRewriter &rewriter,
                                         Operation *op, TensorType outType,
                                         Value lhs, Value rhs) {
  Type lhsElement = cast<TensorType>(lhs.getType()).getElementType();
  Type rhsElement = cast<TensorType>(rhs.getType()).getElementType();

  const bool hasFloatOperand =
      isa<mlir::FloatType>(lhsElement) || isa<mlir::FloatType>(rhsElement);
  if (hasFloatOperand) {
    // NOTE(review): the failure is only reported here; the return type
    // cannot express failure, so the op below is still created. Callers
    // presumably reject float operands beforehand -- confirm.
    (void)rewriter.notifyMatchFailure(op,
                                      "tosa.div only supports integer type");
  }

  Value promotedLhs = promoteType(rewriter, lhs, outType);
  Value promotedRhs = promoteType(rewriter, rhs, outType);
  return tosa::CreateOpAndInfer<tosa::DivOp>(rewriter, op->getLoc(), outType,
                                             promotedLhs, promotedRhs);
}
// Converts a torch-style gather index (one coordinate per element, along
// `axis`) into TF-style indices (a full coordinate tuple per element).
//
//   paramsValue - tensor being gathered from; must be a ranked tensor.
//   indexValue  - torch index tensor; must be a ranked tensor.
//   axis        - gather axis into params; negative values count from the
//                 last dimension.
//
// Returns the concatenated indices tensor of shape index.shape + [paramsRank],
// or std::nullopt when an operand is not a ranked tensor, `axis` is out of
// range, a non-axis params dimension has no counterpart in the index shape,
// or a per-dimension constant could not be created.
//
// For easy understanding of this algorithm, the following comments are with
// an exact example: torch.aten.gather(!torch.vtensor<[1,4,3],f32>, axis=2,
// !torch.vtensor<[1,4,2],si64>) -> !torch.vtensor<[1,4,2],f32>
// https://gist.github.com/AmosLewis/2f18434397025211da4491735bcc6db6
//
// Convert Torch Index to TF Indices
//    [[                         [[   d0 d1 d2  d0 d1 d2
//        [0,0],                     [[0, 0, 0],[0, 0, 0]],
//        [1,0],                     [[0, 1, 1],[0, 1, 0]],
//        [2,1],                     [[0, 2, 2],[0, 2, 1]],
//        [2,1]                      [[0, 3, 2],[0, 3, 1]]
//    ]] 1*4*2                   ]] 1*4*2*3
std::optional<Value> convertTorchIndexToTfIndices(PatternRewriter &rewriter,
                                                  Operation *op,
                                                  Value paramsValue,
                                                  Value indexValue,
                                                  int32_t axis) {
  auto paramsType = dyn_cast<RankedTensorType>(paramsValue.getType());
  auto indexType = dyn_cast<RankedTensorType>(indexValue.getType());
  if (!paramsType || !indexType)
    return std::nullopt;

  auto paramsShape = paramsType.getShape(); // [1 4 3]
  auto indexShape = indexType.getShape();   // [1 4 2]
  int paramsRank = paramsShape.size();      // 3
  int indexRank = indexShape.size();        // 3

  // Normalize a negative axis and reject anything still out of range;
  // otherwise no dimension would ever be replaced by the torch index.
  if (axis < 0)
    axis += paramsRank;
  if (axis < 0 || axis >= paramsRank)
    return std::nullopt;

  // Every non-axis params dimension is enumerated over the index shape, so
  // it has to exist there. Checked up front, before any op is created.
  for (int dim = 0; dim < paramsRank; ++dim) {
    if (dim != axis && dim >= indexRank)
      return std::nullopt;
  }

  // Initialize the final tf indices shape, and the shape of each dim that can
  // concat to this tf indices
  SmallVector<int64_t> indicesShape(indexShape.begin(),
                                    indexShape.end()); // [1 4 2 3]
  SmallVector<int64_t> indicesOneDimShape(indexShape.begin(),
                                          indexShape.end()); // [1 4 2 1]
  indicesShape.push_back(paramsRank);
  indicesOneDimShape.push_back(1);

  // Get the chosen axis index
  // indexValue reshape to indicesDim: shape append 1
  // [1 4 2] -> [1 4 2 1]
  // dim2:  tensor([[ [ [0], [0] ],
  //                  [ [1], [0] ],
  //                  [ [2], [1] ],
  //                  [ [2], [1] ], ]]) 1*4*2*1
  auto indicesChosenAxis = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
      rewriter, op->getLoc(),
      GetTypeFromTensorShape(indicesOneDimShape, indexType.getElementType()),
      indexValue, rewriter.getDenseI64ArrayAttr(indicesOneDimShape));

  SmallVector<Value> concatInputs;
  concatInputs.reserve(paramsRank);
  for (int dim = 0; dim < paramsRank; dim++) {
    if (dim == axis) {
      // the chosen axis indices will be replaced by index[i][j][k]
      concatInputs.push_back(indicesChosenAxis.getResult());
      continue;
    }
    auto indices = createOneDimTfIndices(rewriter, op, indicesOneDimShape, dim,
                                         indexShape);
    // Do not unwrap an empty optional: propagate the failure instead.
    if (!indices)
      return std::nullopt;
    concatInputs.push_back(indices.value());
  }

  // detailed example explanation
  // https://gist.github.com/AmosLewis/932a8dee3ba7657dcc6d09a4da4775d4
  // Get TF indices: 1*4*2*3
  // [[  d0 d1 d2  d0 d1 d2
  //    [[0, 0, 0],[0, 0, 0]],
  //    [[0, 1, 1],[0, 1, 0]],
  //    [[0, 2, 2],[0, 2, 1]],
  //    [[0, 3, 2],[0, 3, 1]]
  // ]]
  auto indicesTf = tosa::CreateOpAndInfer<tosa::ConcatOp>(
      rewriter, op->getLoc(),
      GetTypeFromTensorShape(indicesShape, rewriter.getIntegerType(32)),
      concatInputs, indexRank);

  return indicesTf.getResult();
}
// Lowers Gather operators to a sequence of TOSA ops.
// taken from
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc
//
//   outType      - result type; must be a ranked ShapedType.
//   paramsValue  - tensor gathered from; must be a statically shaped ranked
//                  tensor.
//   indicesValue - TF-style indices whose last dimension holds ND
//                  coordinates; must be a statically shaped ranked tensor of
//                  rank >= 1.
//
// Returns the reshaped gather result, or std::nullopt when a type
// precondition above does not hold, when ND exceeds the params rank, or when
// the coefficient constant cannot be created.
std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter, Operation *op,
                                       Type outType, Value paramsValue,
                                       Value indicesValue) {
  auto resultType = dyn_cast<ShapedType>(outType);
  auto paramsType = dyn_cast<RankedTensorType>(paramsValue.getType());
  auto indicesType = dyn_cast<RankedTensorType>(indicesValue.getType());

  if (!resultType || !paramsType || !indicesType)
    return std::nullopt;

  // The final reshape reads resultType.getShape(), which needs a rank.
  if (!resultType.hasRank()) {
    (void)rewriter.notifyMatchFailure(op, "result type must be ranked");
    return std::nullopt;
  }

  // indices.shape[-1] is read below, so indices needs at least one dimension.
  if (indicesType.getRank() < 1) {
    (void)rewriter.notifyMatchFailure(op, "indices must have rank >= 1");
    return std::nullopt;
  }

  // All sizes below are products of extents; they are only meaningful when
  // every extent is known.
  if (!paramsType.hasStaticShape() || !indicesType.hasStaticShape()) {
    (void)rewriter.notifyMatchFailure(
        op, "params and indices must have static shapes");
    return std::nullopt;
  }

  // N: number of batches
  // Always 1 for GatherND
  //
  // Because TOSA's GATHER operator already uses the symbol 'N' for
  // the number of batches, we will use the symbol 'ND' to specify the
  // number of dimensions that are sliced from params instead of 'N' in
  // the TF MLIR documentation.
  //
  // ND: indices.shape[-1]
  //
  // W: number of indices in each batch
  // Computed as:
  // product(indices.shape[0:-1]) (all but the last dimension)
  //
  // K: range of each index
  // Computed as:
  // product(params.shape[0:ND]) (the first ND dimensions)
  //
  // C: number of channels for each index
  // Computed as:
  // product(params.shape[ND:])
  //
  // The params tensor needs to be reshaped, but not transposed, to move the
  // dimensions into [N, K, C] order.
  //
  // The dimensions of the input params[] tensor are grouped in the following
  // order to begin with:
  //
  //  [ParamIndices, ParamChannels]
  //  |------------||-------------|
  //         K              C
  //
  // The reshape simply flattens the params tensor into a 2D [K, C] shape.
  //
  // Indices needs to be put in the form of [N, W], but a simple flattening
  // will not suffice, because each index must address the flattened K axis
  // instead of the params.shape[0:ND] tensor that we had before.
  //
  // To flatten the coordinates, first reshape indices to a [W, ND] matrix,
  // where the matrix now represents W ND-dimensional coordinates into the
  // params tensor.
  //
  // From here, we take each of the ND dimensions and multiply it with
  // the combined size of the following params dimensions (or 1 for the last
  // dimension), then sum all these together with a reduce_sum
  // operator. This is exactly the same mathematics as one would use to
  // flatten the indices of an N-dimensional row-major array into a
  // 1-D array in C.
  //
  // Then reshape to [N=1, W] to be compatible with the GATHER operator's
  // shape and perform the tosa.GATHER() operation.
  //
  // Now we have result = [N, W, C].
  //
  // Reshape with a single, simple reshape to the final output shape of:
  //  [Indices, ParamChannels]
  //
  // Where, Indices is indices.shape[0:-1]
  //
  // For easy understanding, all following comments take an exact value for each
  // argument Example: Take TF style indices as input
  //    func.func @torch.aten.gather(%arg0: !torch.vtensor<[1,4,3],f32>,
  //        %arg1: !torch.vtensor<[1,4,2,3],i32>) -> !torch.vtensor<[1,4,2],f32>
  // Detail algorithm visualization:
  // https://gist.github.com/AmosLewis/bb6e3a0ad9fd1705c9f9d42a2eefbb88

  ArrayRef<int64_t> paramsShape = paramsType.getShape();
  ArrayRef<int64_t> indicesShape = indicesType.getShape();
  const int64_t paramsRank = paramsType.getRank();   // 3
  const int64_t indicesRank = indicesType.getRank(); // 4

  // 64-bit accumulators: extents are int64_t and their products must not be
  // truncated.
  int64_t N = 1, W = 1, K = 1, C = 1;

  // ND: indices.shape[-1]
  const int64_t ND = indicesShape[indicesRank - 1]; // 3

  if (ND > paramsRank) {
    (void)rewriter.notifyMatchFailure(
        op, "size of last dimension of indices must be <= params rank");
    return std::nullopt;
  }

  // Calculate N, K, W, C.  (N is always 1)
  // number of indices in each batch. product(indices.shape[0:-1]) (all but the
  // last dimension) W = 1*4*2 = 8
  for (int64_t i = 0; i < indicesRank - 1; i++) {
    W *= indicesShape[i];
  }

  // K: range of each index, total number of inputs (could be gathered) after
  // flattening. K = 1*1*4*3 = 12
  for (int64_t i = 0; i < ND; i++) {
    K *= paramsShape[i];
  }

  // C: number of channels for each index: number of values inside each
  // input (could be gathered). C = product(params.shape[ND:]); with ND = 3 and
  // paramsRank = 3, C = 1
  for (int64_t i = ND; i < paramsRank; i++) {
    C *= paramsShape[i];
  }

  // N = 1, W = 8, K = 12, C = 1, ND = 3;
  SmallVector<int64_t, 3> tosaValuesShape({N, K, C});  // {1,12,1}
  SmallVector<int64_t, 2> tosaIndicesShape({N, W});    // {1,8}
  SmallVector<int64_t, 2> indicesMatrixShape({W, ND}); // {8,3}
  SmallVector<int64_t, 2> indicesMatrixReducesumShape(
      {W, 1}); // {8,1} This is different from tf tosa code
  SmallVector<int64_t, 3> tosaGatherResultShape({N, W, C}); // {1,8,1}

  // %2 = "tosa.reshape"(%0) {new_shape = [1, 12, 1]} : (tensor<1x4x3xf32>) ->
  // tensor<1x12x1xf32>
  auto tosaValuesReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
      rewriter, op->getLoc(),
      GetTypeFromTensorShape(tosaValuesShape, paramsType.getElementType()),
      paramsValue, rewriter.getDenseI64ArrayAttr(tosaValuesShape));

  // %3 = "tosa.reshape"(%1) {new_shape = [8, 3]} : (tensor<1x4x2x3xi32>) ->
  // tensor<8x3xi32> Flatten the input indices tensor to an [W, ND] matrix.
  auto indicesMatrixReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
      rewriter, op->getLoc(),
      GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()),
      indicesValue, rewriter.getDenseI64ArrayAttr(indicesMatrixShape));

  // Row-major strides of the first ND params dimensions.
  SmallVector<int32_t> flattenedCoeffVec; //  [12,3,1]
  // flattenedCoeffVec = [4,3,1]
  for (int64_t i = 1; i < ND; i++) {
    flattenedCoeffVec.push_back(paramsShape[i]);
  }
  flattenedCoeffVec.push_back(1);

  // Suffix product turns the extents into strides: [4,3,1] -> [12,3,1]
  for (int64_t i = ND - 1; i > 0; i--) {
    flattenedCoeffVec[i - 1] *= flattenedCoeffVec[i];
  }

  // Create the tosaConstTensor for the flattenedCoeffVec
  // %4 = "tosa.const"() {value = dense<[12, 3, 1]> : tensor<3xi32>} : () ->
  // tensor<3xi32>
  auto flattenedCoeffValue =
      getConstTensor<int32_t>(rewriter, op, flattenedCoeffVec,
                              {static_cast<int64_t>(flattenedCoeffVec.size())});

  if (!flattenedCoeffValue)
    return std::nullopt;

  // Multiply the coefficients by the coordinates
  // %5 = "tosa.mul"(%3, %4) {shift = 0 : i32} : (tensor<8x3xi32>,
  // tensor<3xi32>) -> tensor<8x3xi32>
  auto flattenedIndicesMulOp = tosa::CreateOpAndInfer<tosa::MulOp>(
      rewriter, op->getLoc(),
      GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()),
      indicesMatrixReshapeOp.getResult(), flattenedCoeffValue.value(), 0);

  // Sum up the products of the coefficients and coordinates
  // %6 = "tosa.reduce_sum"(%5) {axis = 1 : i64} : (tensor<8x3xi32>) ->
  // tensor<8x1xi32>
  auto flattenedIndicesReduceOp = tosa::CreateOpAndInfer<tosa::ReduceSumOp>(
      rewriter, op->getLoc(),
      GetTypeFromTensorShape(indicesMatrixReducesumShape,
                             indicesType.getElementType()),
      flattenedIndicesMulOp.getResult(), rewriter.getI32IntegerAttr(1));

  // And reshape to [N, W]
  // %7 = "tosa.reshape"(%6) {new_shape = [1, 8]} : (tensor<8x1xi32>) ->
  // tensor<1x8xi32>
  auto tosaIndicesReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
      rewriter, op->getLoc(),
      GetTypeFromTensorShape(tosaIndicesShape, indicesType.getElementType()),
      flattenedIndicesReduceOp.getResult(),
      rewriter.getDenseI64ArrayAttr(tosaIndicesShape));

  // Now the gather op itself
  // %9 = "tosa.gather"(%2, %7) : (tensor<1x12x1xf32>, tensor<1x8xi32>) ->
  // tensor<1x8x1xf32>
  auto tosaGatherOp = tosa::CreateOpAndInfer<tosa::GatherOp>(
      rewriter, op->getLoc(),
      GetTypeFromTensorShape(tosaGatherResultShape,
                             resultType.getElementType()),
      tosaValuesReshapeOp.getResult(), tosaIndicesReshapeOp.getResult());

  // Finally, reshape back to the original output shape of [Indices,
  // ParamChannels]. %10 = "tosa.reshape"(%9) {new_shape = [1, 4, 2]} :
  // (tensor<1x8x1xf32>) -> tensor<1x4x2xf32> %11 = torch_c.from_builtin_tensor
  // %10 : tensor<1x4x2xf32> -> !torch.vtensor<[1,4,2],f32>
  return tosa::CreateOpAndInfer<tosa::ReshapeOp>(
             rewriter, op->getLoc(), resultType, tosaGatherOp.getResult(),
             rewriter.getDenseI64ArrayAttr(resultType.getShape()))
      .getResult();
}
// Lower indexput op to tosa::scatter op
// Mostly taken from convertGatherNdOp() above.
//
//   outType      - result type; must be a ranked ShapedType.
//   paramsValue  - tensor scattered into; statically shaped ranked tensor.
//   indicesValue - TF-style indices whose last dimension holds ND
//                  coordinates; statically shaped ranked tensor, rank >= 1.
//   fillValues   - values to write; statically shaped ranked tensor. A rank-0
//                  tensor is tiled to one value per index.
//
// Returns the reshaped scatter result, or std::nullopt when a type
// precondition above does not hold, when ND exceeds the params rank or the
// (expanded) fillValues rank, or when the coefficient constant cannot be
// created.
std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
                                        Operation *op, Type outType,
                                        Value paramsValue, Value indicesValue,
                                        Value fillValues) {
  auto resultType = dyn_cast<ShapedType>(outType);
  auto paramsType = dyn_cast<RankedTensorType>(paramsValue.getType());
  auto indicesType = dyn_cast<RankedTensorType>(indicesValue.getType());
  auto fillValuesType = dyn_cast<RankedTensorType>(fillValues.getType());

  // fillValuesType is dereferenced below, so it is checked with the others.
  if (!resultType || !paramsType || !indicesType || !fillValuesType)
    return std::nullopt;

  // The final reshape reads resultType.getShape(), which needs a rank.
  if (!resultType.hasRank()) {
    (void)rewriter.notifyMatchFailure(op, "result type must be ranked");
    return std::nullopt;
  }

  // indices.shape[-1] is read below, so indices needs at least one dimension.
  if (indicesType.getRank() < 1) {
    (void)rewriter.notifyMatchFailure(op, "indices must have rank >= 1");
    return std::nullopt;
  }

  // All sizes below are products of extents; they are only meaningful when
  // every extent is known.
  if (!paramsType.hasStaticShape() || !indicesType.hasStaticShape() ||
      !fillValuesType.hasStaticShape()) {
    (void)rewriter.notifyMatchFailure(
        op, "params, indices and fill values must have static shapes");
    return std::nullopt;
  }

  // N: number of batches
  // Always 1 for ScatterOp
  //
  // Because TOSA's Scatter operator already uses the symbol 'N' for
  // the number of batches, we will use the symbol 'ND' to specify the
  // number of dimensions that are sliced from params instead of 'N' in
  // the TF MLIR documentation.
  //
  // ND: indices.shape[-1]
  //
  // W: number of indices in each batch
  // Computed as:
  // product(indices.shape[0:-1]) (all but the last dimension)
  //
  // K: range of each index
  // Computed as:
  // product(params.shape[0:ND]) (the first ND dimensions)
  //
  // C: number of channels for each index
  // Computed as:
  // product(params.shape[ND:])
  //
  // The params tensor needs to be reshaped, but not transposed, to move the
  // dimensions into [N, K, C] order.
  //
  // The dimensions of the input params[] tensor are grouped in the following
  // order to begin with:
  //
  //  [ParamIndices, ParamChannels]
  //  |------------||-------------|
  //         K              C
  //
  // The reshape simply flattens the params tensor into a 2D [K, C] shape.
  //
  // Indices needs to be put in the form of [N, W], but a simple flattening
  // will not suffice, because each index must address the flattened K axis
  // instead of the params.shape[0:ND] tensor that we had before.
  //
  // To flatten the coordinates, first reshape indices to a [W, ND] matrix,
  // where the matrix now represents W ND-dimensional coordinates into the
  // params tensor.
  //
  // From here, we take each of the ND dimensions and multiply it with
  // the combined size of the following params dimensions (or 1 for the last
  // dimension), then sum all these together with a reduce_sum
  // operator. This is exactly the same mathematics as one would use to
  // flatten the indices of an N-dimensional row-major array into a
  // 1-D array in C.
  //
  // Then reshape to [N=1, W] to be compatible with the scatter operator's
  // shape and perform the tosa.scatter() operation.
  //
  // Now we have result = [N, K, C].
  //
  // Reshape with a single, simple reshape to the final output shape.
  //
  // For easy understanding, all following comments take an exact value for each
  // argument Example: Take TF style indices as input
  //  torch.aten._index_put_impl %input, %indices, %fillValue, %false, %false :
  //    !torch.vtensor<[1,4],si64>, !torch.vtensor<[3,2],si64>,
  //    !torch.vtensor<[1,3],si64>, !torch.bool, !torch.bool ->
  //    !torch.vtensor<[1,4],si64>

  ArrayRef<int64_t> paramsShape = paramsType.getShape();
  ArrayRef<int64_t> indicesShape = indicesType.getShape();
  const int64_t paramsRank = paramsType.getRank();   // 2
  const int64_t indicesRank = indicesType.getRank(); // 2

  // 64-bit accumulators: extents are int64_t and their products must not be
  // truncated.
  int64_t N = 1, W = 1, K = 1, fillK = 1, C = 1;

  // ND: indices.shape[-1]
  const int64_t ND = indicesShape[indicesRank - 1]; // 2 depth of input

  if (ND > paramsRank) {
    (void)rewriter.notifyMatchFailure(
        op, "size of last dimension of indices must be <= params rank");
    return std::nullopt;
  }

  // fillK below multiplies the first ND extents of fillValues. A rank-0
  // fillValues is expanded to the rank-2 shape [N, W] first.
  const int64_t fillRank = fillValuesType.getRank();
  const int64_t effectiveFillRank = fillRank == 0 ? 2 : fillRank;
  if (ND > effectiveFillRank) {
    (void)rewriter.notifyMatchFailure(
        op, "size of last dimension of indices must be <= fill values rank");
    return std::nullopt;
  }

  // Calculate N, K, W, C.  (N is always 1)
  // number of indices/selected value in each batch product(indices.shape[0:-1])
  // (all but the last dimension) W = 1*3 = 3
  for (int64_t i = 0; i < indicesRank - 1; i++) {
    W *= indicesShape[i];
  }

  // K: range of each index, total number of inputs (could be scattered) after
  // flattening. K = 1*1*4 = 4
  for (int64_t i = 0; i < ND; i++) {
    K *= paramsShape[i];
  }

  // C: number of channels for each index: number of values inside each
  // input (could be scattered). C = product(params.shape[ND:]); with ND = 2
  // and paramsRank = 2, C = 1
  for (int64_t i = ND; i < paramsRank; i++) {
    C *= paramsShape[i];
  }

  // N = 1, W = 3, K = 4, fillK = 3, C = 1, ND = 2;
  SmallVector<int64_t, 3> tosaInputValuesShape({N, K, C});      // {1,4,1}
  SmallVector<int64_t, 2> tosaIndicesShape({N, W});             // {1,3}
  SmallVector<int64_t, 2> indicesMatrixShape({W, ND});          // {3,2}
  SmallVector<int64_t, 2> indicesMatrixReducesumShape({W, 1});  // {3,1}

  // Preprocess fill value.
  // There are 2 cases of fillValues,
  // 1. !torch.vtensor<[1,3],si64>
  //  [[0,0,0]] -> [[[0], [0], [0]]]
  // 2. !torch.vtensor<[],si64>
  //    reshape(1) tile(3)   reshape(1,3)   reshape(1,3,1)
  //  [] -> [0] -> [0,0,0] -> [[0,0,0]] -> [[[0], [0], [0]]]
  //  reshape to [1] and then tile to the number of indices W.
  if (fillValuesType.getRank() == 0) {
    // [] -> [0]
    SmallVector<int64_t, 1> oneShape({1}); // {1}
    auto tosaFillValuesOneReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
        rewriter, op->getLoc(),
        GetTypeFromTensorShape(oneShape, fillValuesType.getElementType()),
        fillValues, rewriter.getDenseI64ArrayAttr(oneShape));

    // [0] -> [0,0,0]
    SmallVector<int64_t, 1> tileShape({W}); // {3}
    auto tosaFillValuesTileOp = tosa::CreateOpAndInfer<tosa::TileOp>(
        rewriter, op->getLoc(),
        GetTypeFromTensorShape(tileShape, fillValuesType.getElementType()),
        tosaFillValuesOneReshapeOp.getResult(),
        rewriter.getDenseI64ArrayAttr(tileShape));

    // [0,0,0] -> [[0,0,0]]
    SmallVector<int64_t, 2> newTosaFillValuesShape({N, W}); // {1,3}
    auto newTosaFillValuesReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
        rewriter, op->getLoc(),
        GetTypeFromTensorShape(newTosaFillValuesShape,
                               fillValuesType.getElementType()),
        tosaFillValuesTileOp.getResult(),
        rewriter.getDenseI64ArrayAttr(newTosaFillValuesShape));
    fillValues = newTosaFillValuesReshapeOp.getResult();
    fillValuesType = dyn_cast<RankedTensorType>(fillValues.getType());
    if (!fillValuesType)
      return std::nullopt;
  }

  // fillK: range of each index, total number of fillInput (could be
  // scattered) after flattening. fillK = 1*1*3 = 3
  ArrayRef<int64_t> fillShape = fillValuesType.getShape();
  for (int64_t i = 0; i < ND; i++) {
    fillK *= fillShape[i];
  }
  SmallVector<int64_t, 3> tosaFillValuesShape({N, fillK, C}); // {1,3,1}

  // Reshape/Flatten fillValues to 3d tensor
  // [[0,0,0]] -> [[[0], [0], [0]]]
  // %10 = "tosa.reshape"(%1) {new_shape = array<i64: 1, 3, 1>} :
  // (tensor<1x3xi64>) -> tensor<1x3x1xi64>
  auto tosaFillValuesReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
      rewriter, op->getLoc(),
      GetTypeFromTensorShape(tosaFillValuesShape,
                             fillValuesType.getElementType()),
      fillValues, rewriter.getDenseI64ArrayAttr(tosaFillValuesShape));

  // Reshape/Flatten input to 3d tensor
  // [[1, 2, 3, 4]] -> [[[1], [2], [3], [4]]]
  // %9 = "tosa.reshape"(%0) {new_shape = array<i64: 1, 4, 1>} :
  // (tensor<1x4xi64>) -> tensor<1x4x1xi64>
  auto tosaValuesReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
      rewriter, op->getLoc(),
      GetTypeFromTensorShape(tosaInputValuesShape, paramsType.getElementType()),
      paramsValue, rewriter.getDenseI64ArrayAttr(tosaInputValuesShape));

  // Reshape/Flatten the input indices tensor to a 2d [W, ND] matrix.
  // [[0, 1], [0, 2], [0, 3]] -> [[0, 1], [0, 2], [0, 3]]
  // %11 = "tosa.reshape"(%8) {new_shape = array<i64: 3, 2>} : (tensor<3x2xi32>)
  // -> tensor<3x2xi32>
  auto indicesMatrixReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
      rewriter, op->getLoc(),
      GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()),
      indicesValue, rewriter.getDenseI64ArrayAttr(indicesMatrixShape));

  // Row-major strides of the first ND params dimensions.
  SmallVector<int32_t> flattenedCoeffVec; //  [4,1]
  // flattenedCoeffVec = [4,1]
  for (int64_t i = 1; i < ND; i++) {
    flattenedCoeffVec.push_back(paramsShape[i]);
  }
  flattenedCoeffVec.push_back(1);

  // Suffix product turns the extents into strides: [4,1] -> [4,1]
  for (int64_t i = ND - 1; i > 0; i--) {
    flattenedCoeffVec[i - 1] *= flattenedCoeffVec[i];
  }

  // Create the tosaConstTensor for the flattenedCoeffVec.
  // %12 = "tosa.const"() {value = dense<[4, 1]> : tensor<2xi32>} : () ->
  // tensor<2xi32>
  auto flattenedCoeffValue =
      getConstTensor<int32_t>(rewriter, op, flattenedCoeffVec,
                              {static_cast<int64_t>(flattenedCoeffVec.size())});

  if (!flattenedCoeffValue)
    return std::nullopt;

  // Multiply the coefficients by the coordinates.
  // [[0, 1], [0, 2], [0, 3]] X [4, 1] -> [[4*0, 1*1], [4*0, 1*2], [4*0, 1*3]]
  // %13 = "tosa.mul"(%11, %12) {shift = 0 : i32} : (tensor<3x2xi32>,
  // tensor<2xi32>) -> tensor<3x2xi32>
  auto flattenedIndicesMulOp = tosa::CreateOpAndInfer<tosa::MulOp>(
      rewriter, op->getLoc(),
      GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()),
      indicesMatrixReshapeOp.getResult(), flattenedCoeffValue.value(), 0);

  // Sum up the products of the coefficients and coordinates
  // [[4*0 + 1*1], [4*0 + 1*2], [4*0 + 1*3]] = [[1],[2],[3]]
  // %14 = "tosa.reduce_sum"(%13) {axis = 1 : i64} : (tensor<3x2xi32>) ->
  // tensor<3x1xi32>
  auto flattenedIndicesReduceOp = tosa::CreateOpAndInfer<tosa::ReduceSumOp>(
      rewriter, op->getLoc(),
      GetTypeFromTensorShape(indicesMatrixReducesumShape,
                             indicesType.getElementType()),
      flattenedIndicesMulOp.getResult(), rewriter.getI32IntegerAttr(1));

  // And reshape to [N, W]
  // [[1],[2],[3]] -> [[1,2,3]]
  // %15 = "tosa.reshape"(%14) {new_shape = array<i64: 1, 3>} :
  // (tensor<3x1xi32>) -> tensor<1x3xi32>
  auto tosaIndicesReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
      rewriter, op->getLoc(),
      GetTypeFromTensorShape(tosaIndicesShape, indicesType.getElementType()),
      flattenedIndicesReduceOp.getResult(),
      rewriter.getDenseI64ArrayAttr(tosaIndicesShape));

  // Now the Scatter op itself
  // %16 = "tosa.scatter"(%9, %15, %10) : (tensor<1x4x1xi64>, tensor<1x3xi32>,
  // tensor<1x3x1xi64>) -> tensor<1x4x1xi64> input = [[[1], [2], [3], [4]]],
  // indices = [[1,2,3]], fillValues= [[[0], [0], [0]]] result = [[[1], [0],
  // [0], [0]]]
  auto tosaScatterOp = tosa::CreateOpAndInfer<tosa::ScatterOp>(
      rewriter, op->getLoc(),
      GetTypeFromTensorShape(tosaInputValuesShape, resultType.getElementType()),
      tosaValuesReshapeOp.getResult(), tosaIndicesReshapeOp.getResult(),
      tosaFillValuesReshapeOp.getResult());

  // Finally, reshape back to the original output shape.
  // [[1, 0, 0, 0]]
  // %17 = "tosa.reshape"(%16) {new_shape = array<i64: 1, 4>} :
  // (tensor<1x4x1xi64>) -> tensor<1x4xi64>
  return tosa::CreateOpAndInfer<tosa::ReshapeOp>(
             rewriter, op->getLoc(), resultType, tosaScatterOp.getResult(),
             rewriter.getDenseI64ArrayAttr(resultType.getShape()))
      .getResult();
}
// Common function for lowering reduce operations to TOSA ops.
//
// Emits one reduce op of type T per entry of `axes_elems`, chaining them on
// the running value.
//
//   output_type          - final result type (used for the identity and the
//                          squeezing reshape).
//   input_value          - value to reduce; must be a ranked tensor.
//   axes_elems           - axes to reduce; negative entries count from the
//                          last dimension. Empty means "no reduction".
//   keep_dims            - when false, the result is reshaped to
//                          output_type's shape after reducing.
//   reduce_element_type  - element type of the intermediate reduce results.
//   is_quantized         - when true, the input is rescaled to int32 before
//                          reducing and rescaled to the output element type
//                          afterwards using the scale / zero-point arguments.
//
// Returns the reduced value, or std::nullopt when the input is not a ranked
// tensor or an axis lies outside [-rank, rank).
template <typename T>
std::optional<Value> convertReduceOpCommon(
    PatternRewriter &rewriter, Operation *op, RankedTensorType output_type,
    Value input_value, ElementsAttr axes_elems, bool keep_dims,
    Type reduce_element_type, bool is_quantized, double input_scale,
    int64_t input_zp, double output_scale, int64_t output_zp) {
  RankedTensorType input_type =
      dyn_cast<RankedTensorType>(input_value.getType());
  if (!input_type)
    return std::nullopt;

  ArrayRef<int64_t> input_shape = input_type.getShape();
  ArrayRef<int64_t> output_shape = output_type.getShape();
  const int64_t input_rank = static_cast<int64_t>(input_shape.size());
  Value val = input_value;

  if (axes_elems.getNumElements() == 0) {
    // No axes means return the original tensor.
    auto identity_op = CreateOpAndInfer<tosa::IdentityOp>(
        rewriter, op->getLoc(), output_type, val);
    val = identity_op.getResult();
    return val;
  }

  // Normalize and validate every axis before any op is created, so that an
  // out-of-range axis neither indexes shape_vec out of bounds nor leaves
  // partially built IR behind.
  SmallVector<int64_t> axes;
  for (IntegerAttr axis_attr : axes_elems.getValues<IntegerAttr>()) {
    int64_t axis_val = axis_attr.getInt();
    if (axis_val < 0)
      axis_val += input_rank;
    if (axis_val < 0 || axis_val >= input_rank)
      return std::nullopt;
    axes.push_back(axis_val);
  }

  // Reduce along each axis
  SmallVector<int64_t> shape_vec(input_shape.begin(), input_shape.end());

  if (is_quantized) {
    val = buildRescaleToInt32(rewriter, op, val, input_scale, input_zp);
  }

  for (int64_t axis_val : axes) {
    auto axis_attr = rewriter.getI32IntegerAttr(axis_val);

    // Each reduction collapses its axis to extent 1 but keeps the rank.
    shape_vec[axis_val] = 1;
    RankedTensorType reduce_type =
        RankedTensorType::get(shape_vec, reduce_element_type);

    auto reduce_op = CreateOpAndInfer<T>(rewriter, op->getLoc(), reduce_type,
                                         val, axis_attr);

    val = reduce_op.getResult();
  }

  if (is_quantized) {
    RankedTensorType output_rescale_type =
        RankedTensorType::get(shape_vec, output_type.getElementType());
    val = buildRescale(rewriter, op, output_rescale_type, val, output_scale,
                       0, output_zp, false, true);
  }

  // Optionally squeeze out the reduced axes.
  if (!keep_dims) {
    auto reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
        rewriter, op->getLoc(), output_type, val,
        rewriter.getDenseI64ArrayAttr(output_shape));
    val = reshape_op.getResult();
  }

  return val;
}
// Lowers ReduceAll to a sequence of TOSA ops.
// Delegates to convertReduceOpCommon<tosa::ReduceAllOp> in non-quantized
// mode, reducing in the output element type. Returns std::nullopt when the
// input is not a ranked tensor.
std::optional<Value>
convertReduceAllOp(PatternRewriter &rewriter, Operation *op,
                   RankedTensorType output_type, Value input_value,
                   ElementsAttr axes_elems, bool keep_dims) {
  if (!isa<RankedTensorType>(input_value.getType()))
    return std::nullopt;

  Type reduce_elem_type = output_type.getElementType();
  return convertReduceOpCommon<tosa::ReduceAllOp>(
      rewriter, op, output_type, input_value, axes_elems, keep_dims,
      reduce_elem_type, /*is_quantized=*/false, /*input_scale=*/1.0,
      /*input_zp=*/0, /*output_scale=*/1.0, /*output_zp=*/0);
}
// Lowers ReduceAny to a sequence of TOSA ops.
// Delegates to convertReduceOpCommon<tosa::ReduceAnyOp> in non-quantized
// mode, reducing in the output element type. Returns std::nullopt when the
// input is not a ranked tensor.
std::optional<Value>
convertReduceAnyOp(PatternRewriter &rewriter, Operation *op,
                   RankedTensorType output_type, Value input_value,
                   ElementsAttr axes_elems, bool keep_dims) {
  if (!isa<RankedTensorType>(input_value.getType()))
    return std::nullopt;

  Type reduce_elem_type = output_type.getElementType();
  return convertReduceOpCommon<tosa::ReduceAnyOp>(
      rewriter, op, output_type, input_value, axes_elems, keep_dims,
      reduce_elem_type, /*is_quantized=*/false, /*input_scale=*/1.0,
      /*input_zp=*/0, /*output_scale=*/1.0, /*output_zp=*/0);
}
// Lowers ReduceMin to a sequence of TOSA ops.
// Delegates to convertReduceOpCommon<tosa::ReduceMinOp> in non-quantized
// mode, reducing in the output element type. Returns std::nullopt when the
// input is not a ranked tensor.
std::optional<Value>
convertReduceMinOp(PatternRewriter &rewriter, Operation *op,
                   RankedTensorType output_type, Value input_value,
                   ElementsAttr axes_elems, bool keep_dims) {
  if (!isa<RankedTensorType>(input_value.getType()))
    return std::nullopt;

  Type reduce_elem_type = output_type.getElementType();
  return convertReduceOpCommon<tosa::ReduceMinOp>(
      rewriter, op, output_type, input_value, axes_elems, keep_dims,
      reduce_elem_type, /*is_quantized=*/false, /*input_scale=*/1.0,
      /*input_zp=*/0, /*output_scale=*/1.0, /*output_zp=*/0);
}
// Lowers ReduceMax to a sequence of TOSA ops.
// Delegates to convertReduceOpCommon<tosa::ReduceMaxOp> in non-quantized
// mode, reducing in the output element type. Returns std::nullopt when the
// input is not a ranked tensor.
std::optional<Value>
convertReduceMaxOp(PatternRewriter &rewriter, Operation *op,
                   RankedTensorType output_type, Value input_value,
                   ElementsAttr axes_elems, bool keep_dims) {
  if (!isa<RankedTensorType>(input_value.getType()))
    return std::nullopt;

  Type reduce_elem_type = output_type.getElementType();
  return convertReduceOpCommon<tosa::ReduceMaxOp>(
      rewriter, op, output_type, input_value, axes_elems, keep_dims,
      reduce_elem_type, /*is_quantized=*/false, /*input_scale=*/1.0,
      /*input_zp=*/0, /*output_scale=*/1.0, /*output_zp=*/0);
}
// Lowers ReduceProd to a sequence of TOSA ops.
// Quantized element types are rejected with an op error on either the input
// or the output; otherwise this delegates to
// convertReduceOpCommon<tosa::ReduceProdOp> in non-quantized mode, reducing
// in the output element type. Returns std::nullopt when the input is not a
// ranked tensor or a quantized type is involved.
std::optional<Value>
convertReduceProdOp(PatternRewriter &rewriter, Operation *op,
                    RankedTensorType output_type, Value input_value,
                    ElementsAttr axes_elems, bool keep_dims) {
  using QuantType = mlir::quant::UniformQuantizedType;

  auto ranked_input = dyn_cast<RankedTensorType>(input_value.getType());
  if (!ranked_input)
    return std::nullopt;

  Type out_elem_type = output_type.getElementType();
  const bool any_quantized = isa<QuantType>(ranked_input.getElementType()) ||
                             isa<QuantType>(out_elem_type);
  if (any_quantized) {
    op->emitOpError("ConvertReduceProdOp: input/output tensor should "
                    "be all floating-point.");
    return std::nullopt;
  }

  return convertReduceOpCommon<tosa::ReduceProdOp>(
      rewriter, op, output_type, input_value, axes_elems, keep_dims,
      out_elem_type, /*is_quantized=*/false, /*input_scale=*/1.0,
      /*input_zp=*/0, /*output_scale=*/1.0, /*output_zp=*/0);
}
// Lowers ReduceSum to a sequence of TOSA ops.
// Input and output must either both be uniform-quantized or both not be;
// a mismatch emits an op error. In the quantized case the reduction is done
// in i32 with scales derived from the quantized types and a fixed shift of
// 20 bits; otherwise the reduction uses the input element type. Returns
// std::nullopt when the input is not a ranked tensor or on a mismatch.
std::optional<Value>
convertReduceSumOp(PatternRewriter &rewriter, Operation *op,
                   RankedTensorType output_type, Value input_value,
                   ElementsAttr axes_elems, bool keep_dims) {
  using QuantType = mlir::quant::UniformQuantizedType;

  auto ranked_input = dyn_cast<RankedTensorType>(input_value.getType());
  if (!ranked_input)
    return std::nullopt;

  // A non-null result means the element type is uniform-quantized.
  auto in_qtype = dyn_cast<QuantType>(ranked_input.getElementType());
  auto out_qtype = dyn_cast<QuantType>(output_type.getElementType());
  const bool in_quantized = static_cast<bool>(in_qtype);
  const bool out_quantized = static_cast<bool>(out_qtype);

  if (in_quantized != out_quantized) {
    op->emitOpError("ConvertReduceSumOp: input/output tensor should "
                    "be all quantized or all floating-point.");
    return std::nullopt;
  }

  // Defaults for the non-quantized path.
  double in_scale = 1.0;
  double out_scale = 1.0;
  int64_t in_zp = 0;
  int64_t out_zp = 0;
  Type accum_elem_type = ranked_input.getElementType();

  if (in_quantized) {
    constexpr int32_t kInputShift = 20;
    const double shift_factor = static_cast<double>(1 << kInputShift);

    in_scale = shift_factor * in_qtype.getScale();
    out_scale = 1.0 / (out_qtype.getScale() * shift_factor);

    in_zp = in_qtype.getZeroPoint();
    out_zp = out_qtype.getZeroPoint();
    accum_elem_type = rewriter.getI32Type();
  }

  return convertReduceOpCommon<tosa::ReduceSumOp>(
      rewriter, op, output_type, input_value, axes_elems, keep_dims,
      accum_elem_type, in_quantized, in_scale, in_zp, out_scale, out_zp);
}
// Lowers ReduceMean to a sequence of TOSA ops.
//
// reduce_mean is lowered as follows:
//   op1 = reduce_sum(input)
//   op2 = mul(op1, 1.0 / num_elements_on_reduced_axis)
// For quantized tensors the 1/N factor is folded into the output scale that
// is handed to convertReduceOpCommon, so no separate mul is emitted.
//
// Returns std::nullopt (after emitting a diagnostic where noted) when:
//   - `input_value` is not a ranked tensor (no diagnostic);
//   - exactly one of input/output is quantized (op error);
//   - the tensors are non-quantized and the output element type is not a
//     FloatType (warning);
//   - a reduction axis is out of range for the input rank (op error);
//   - a reduced dimension is dynamic (op error);
//   - the underlying reduce_sum lowering fails.
std::optional<Value>
convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
                    RankedTensorType output_type, Value input_value,
                    ElementsAttr axes_elems, bool keep_dims) {
  RankedTensorType input_type =
      dyn_cast<RankedTensorType>(input_value.getType());
  if (!input_type)
    return std::nullopt;

  bool input_is_qtype =
      isa<mlir::quant::UniformQuantizedType>(input_type.getElementType());
  bool output_is_qtype =
      isa<mlir::quant::UniformQuantizedType>(output_type.getElementType());

  if (input_is_qtype != output_is_qtype) {
    op->emitOpError("ConvertReduceMeanOp: input/output tensor should "
                    "be all quantized or all floating-point.");
    return std::nullopt;
  }

  // Only supports float type mean() if it's non-quantized
  if (!input_is_qtype && !isa<mlir::FloatType>(output_type.getElementType())) {
    op->emitWarning(
        "Failed convertReduceMean: input unquantized type but output element "
        "not FloatType!");
    return std::nullopt;
  }

  int64_t input_rank = input_type.getRank();
  ArrayRef<int64_t> inputShape = input_type.getShape();

  // Count the elements collapsed by the reduction; this is the divisor N.
  int64_t num_elems_on_reduced_axis = 1;
  for (IntegerAttr axis_attr : axes_elems.getValues<IntegerAttr>()) {
    int64_t axis_val = axis_attr.getInt();
    if (axis_val < 0)
      axis_val += input_rank;
    // Guard the shape lookup below against an axis outside [0, rank).
    if (axis_val < 0 || axis_val >= input_rank) {
      op->emitOpError("Failed convertReduceMean: reduction axis out of range "
                      "for input rank");
      return std::nullopt;
    }
    // A dynamic extent is encoded as a negative value; N cannot be computed
    // statically in that case, so bail out instead of folding the sentinel
    // into the product.
    if (inputShape[axis_val] < 0) {
      op->emitOpError("Failed convertReduceMean: support for dynamic input "
                      "shape not implemented");
      return std::nullopt;
    }
    num_elems_on_reduced_axis *= inputShape[axis_val];
  }
  double div_scale = 1.0 / static_cast<double>(num_elems_on_reduced_axis);

  double input_scale = 1.0f;
  double output_scale = 1.0f;
  int64_t input_zp = 0;
  int64_t output_zp = 0;
  Type reduce_element_type = input_type.getElementType();

  if (input_is_qtype) {
    auto input_qtype =
        cast<mlir::quant::UniformQuantizedType>(input_type.getElementType());
    auto output_qtype =
        cast<mlir::quant::UniformQuantizedType>(output_type.getElementType());

    // Combine 'div_scale' as part of output rescale
    output_scale = div_scale * input_qtype.getScale() / output_qtype.getScale();

    input_zp = input_qtype.getZeroPoint();
    output_zp = output_qtype.getZeroPoint();
    reduce_element_type = rewriter.getI32Type();
  }

  auto val = convertReduceOpCommon<tosa::ReduceSumOp>(
      rewriter, op, output_type, input_value, axes_elems, keep_dims,
      reduce_element_type, input_is_qtype, input_scale, input_zp, output_scale,
      output_zp);

  if (!val.has_value())
    return std::nullopt;

  // Non-quantized path: apply the 1/N factor with an explicit multiply.
  if (!input_is_qtype) {
    Value div_const = getTosaConstTensorSingleF32(rewriter, op, div_scale);
    return CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), output_type,
                                         val.value(), div_const, 0)
        .getResult();
  }

  return val;
}
// Lowers LinalgVectorNorm to a sequence of TOSA ops.
//
// The emitted sequence is:
//   abs    = tosa.abs(input)
//   powed  = tosa.pow(abs, ord)
//   summed = reduce_sum(powed) over `axes_elems`
//   result = tosa.pow(summed, tosa.reciprocal(ord))
//
// `op` must be an AtenLinalgVectorNormOp (enforced by cast<>). Its `ord`
// operand must be a constant float or constant int.
//
// Returns std::nullopt when:
//   - `input_value` is not a ranked tensor (no diagnostic);
//   - the output element type is not a FloatType (op error);
//   - `ord` is not a constant float/int (op error);
//   - `ord` is (near) zero or +/- infinity, which are unimplemented
//     (op error);
//   - the `ord` constant cannot be materialized;
//   - the reduce_sum lowering fails.
std::optional<Value>
convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op,
                          RankedTensorType output_type, Value input_value,
                          ElementsAttr axes_elems, bool keep_dims) {
  RankedTensorType input_type =
      dyn_cast<RankedTensorType>(input_value.getType());
  if (!input_type)
    return std::nullopt;

  Type elemType = output_type.getElementType();
  if (!isa<mlir::FloatType>(elemType)) {
    op->emitOpError("Only floating-point datatype legalization supported for "
                    "AtenLinalgVectorNorm op");
    return std::nullopt;
  }

  auto linalgVectorNormOp = cast<AtenLinalgVectorNormOp>(op);

  // TODO: Add support for ord = {0, +inf, -inf}.
  constexpr double epsilon = 1e-5;
  double ordLiteralFloat = 1.0;
  int64_t ordLiteralInt = 1;

  // `ordAsDouble` drives validation; `ordAsFloat` is the value stored in the
  // constant tensor. Both are taken from whichever literal actually matched.
  double ordAsDouble = 1.0;
  float ordAsFloat = 1.0f;
  if (matchPattern(linalgVectorNormOp.getOrd(),
                   torch::Torch::m_TorchConstantFloat(&ordLiteralFloat))) {
    ordAsDouble = ordLiteralFloat;
    ordAsFloat = static_cast<float>(ordLiteralFloat);
  } else if (matchPattern(linalgVectorNormOp.getOrd(),
                          torch::Torch::m_TorchConstantInt(&ordLiteralInt))) {
    ordAsDouble = static_cast<double>(ordLiteralInt);
    ordAsFloat = static_cast<float>(ordLiteralInt);
  } else {
    op->emitOpError("only support FP or INT type ord parameter");
    return std::nullopt;
  }

  // Validate `ord` before creating any op so that the rejection paths below
  // leave the IR untouched.
  if (std::fabs(ordAsDouble) < epsilon) {
    op->emitOpError("unimplemented: L0 norm");
    return std::nullopt;
  }

  if (std::isinf(ordAsDouble)) {
    op->emitOpError("unimplemented: ord = +/- inf");
    return std::nullopt;
  }

  // Materialize `ord` as a rank-0 constant of the output element type.
  std::optional<Value> ordConst =
      tosa::getConstTensor<float>(rewriter, op, {ordAsFloat}, {}, elemType);
  if (!ordConst)
    return std::nullopt;
  Value ordVal = *ordConst;

  auto absVal = CreateOpAndInfer<tosa::AbsOp>(rewriter, op->getLoc(),
                                              input_type, input_value)
                    .getResult();
  auto powVal = CreateOpAndInfer<tosa::PowOp>(rewriter, op->getLoc(),
                                              input_type, absVal, ordVal)
                    .getResult();
  std::optional<Value> result = convertReduceSumOp(
      rewriter, op, output_type, powVal, axes_elems, keep_dims);
  if (!result)
    return std::nullopt;

  // Raise the reduced sum to the power 1/ord.
  auto reciprocalVal = CreateOpAndInfer<tosa::ReciprocalOp>(
                           rewriter, op->getLoc(), ordVal.getType(), ordVal)
                           .getResult();
  return CreateOpAndInfer<tosa::PowOp>(rewriter, op->getLoc(), output_type,
                                       result.value(), reciprocalVal)
      .getResult();
}
} // namespace tosa
} // namespace mlir