[TOSA] Add aten._index_put_impl support (#2031)

Add e2e support by adding "tosa-to-scf" to the TOSA-to-Linalg lowering pipeline.
Chi_Liu 2023-07-17 09:51:24 -07:00 committed by GitHub
parent ba24a46910
commit 5706697e0b
7 changed files with 453 additions and 1 deletion


@ -306,6 +306,7 @@ TORCHDYNAMO_CRASHING_SET = {
"ToCopyModule_basic",
"TransposeIntModule_basic",
"TransposeIntNegDimsModule_basic",
"IndexPutImpl2DNoneIndexStaticModule_basic",
# See https://github.com/llvm/torch-mlir/issues/2178
"Add_Module_basic"
@ -811,6 +812,7 @@ STABLEHLO_PASS_SET = {
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"IndexPutImpl2DNoneIndexStaticModule_basic",
"AliasModule_basic",
"MaxPool2dEmptyStrideStaticModule_basic",
"ConstantBoolParameterModule_basic",
@ -1223,6 +1225,7 @@ LTC_XFAIL_SET = {
"IndexPutImpl2DFloatAccumulateModule_basic",
"IndexPutImpl2DFloatNonAccumulateModule_basic",
"IndexPutImpl2DIndexModule_basic",
"IndexPutImpl2DNoneIndexStaticModule_basic",
"IndexPutImpl3DFloatAccumulateModule_basic",
"IndexPutImpl3DFloatNonAccumulateModule_basic",
"IndexPutImplIndexWithNoneModule_basic",


@ -58,6 +58,12 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter,
Value params_value,
Value indices_value);
std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
Operation *op, Type outType,
Value paramsValue, Value indicesValue,
Value fillValues);
// Lowers ReduceAll to a sequence of TOSA ops.
std::optional<Value>
convertReduceAllOp(PatternRewriter &rewriter, Operation *op,


@ -3405,6 +3405,150 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
return success();
}
template <>
LogicalResult ConvertAtenOp<Aten_IndexPutImplOp>::matchAndRewrite(
Aten_IndexPutImplOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Example:
//   a = torch.tensor([[0, 1, 2, 3]])
//   a[..., 1:] = torch.tensor([4, 5, 6])
// is equivalent to
//   a[..., 1:4] = torch.tensor([4, 5, 6])
//   a[[0, 0, 0], [1, 2, 3]] = torch.tensor([4, 5, 6])  # tensor([[0, 4, 5, 6]])
// which corresponds to
//   torch.ops.aten.index_put(torch.tensor([[0, 1, 2, 3]]),                       # input
//                            (torch.tensor([0, 0, 0]), torch.tensor([1, 2, 3])), # indices
//                            torch.tensor([4, 5, 6]))                            # value
// and, with a None index,
//   torch.ops.aten.index_put(torch.tensor([[0, 1, 2, 3]]),                       # input
//                            (None, torch.tensor([1, 2, 3])),                    # indices
//                            torch.tensor([4, 5, 6]))                            # value
// Bail out if the input or the fill values are not tensor types.
auto input = adaptor.getSelf();
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
if (!selfType)
return rewriter.notifyMatchFailure(
op, "Only tensor type inputs are currently supported");
auto fillValues = adaptor.getValues();
auto valuesType = adaptor.getValues().getType().dyn_cast<TensorType>();
if (!valuesType)
return rewriter.notifyMatchFailure(
op, "Only tensor type inputs are currently supported");
// Unpack the torch.prim.ListConstruct of (possibly non-constant) index tensors.
auto tensorList = op.getIndices();
SmallVector<Value> tensorsTorchType;
if (!getListConstructElements(tensorList, tensorsTorchType))
return op.emitError(
"unimplemented: the tensor list is not from list construct");
auto indexTensors = getTypeConvertedValues(
rewriter, op->getLoc(), getTypeConverter(), tensorsTorchType);
auto outType = getTypeConverter()->convertType(op.getType());
// Convert the list of indices (which may contain None) into index tensors
// without None, e.g. (None, [1,2,3]) -> ([0,0,0], [1,2,3]), which are later
// reshaped and concatenated: ([[0],[0],[0]], [[1],[2],[3]]) -> [[0,1],[0,2],[0,3]]
if (indexTensors.size() <= 1) {
return rewriter.notifyMatchFailure(
op, "Only index_put with multiple indices is currently supported.");
}
SmallVector<Value> indicesTfConcatTensors;
SmallVector<int64_t> indexesRank;
SmallVector<SmallVector<int64_t>> indexesShape;
// Reshape each index tensor and collect them for concatenation into a single indices tensor.
for (size_t i = 0; i < indexTensors.size(); i++) {
auto index = indexTensors[i];
auto indexTorch = tensorsTorchType[i];
// TODO: add support for a None index at positions other than i == 0,
// e.g. (index0, None)
if (i == 0 && indexTorch.getType().isa<Torch::NoneType>()) {
// Convert None into an explicit index of zeros, e.g. [0,0,0]
auto indexNext = indexTensors[i + 1];
auto indexNextTorch = tensorsTorchType[i + 1];
if (indexNextTorch.getType().isa<Torch::NoneType>()) {
return rewriter.notifyMatchFailure(
op, "Multiple None indices are not supported for now.");
}
auto indexNextType = indexNext.getType().dyn_cast<RankedTensorType>();
auto indexNextShape = indexNextType.getShape();
int64_t size = 1;
for (auto s : indexNextShape)
size *= s;
SmallVector<int32_t> values(size, i);
index =
tosa::getConstTensor<int32_t>(rewriter, op, values, indexNextShape)
.value();
}
auto indexType = index.getType().dyn_cast<RankedTensorType>();
auto indexShape = indexType.getShape();
indexesShape.push_back(makeShapeTorchCompatible(indexShape));
indexesRank.push_back(indexType.getRank());
// Cast the index from i64 to i32 for TOSA compatibility.
if (indexType.getElementType() != rewriter.getIntegerType(32)) {
index = rewriter.create<tosa::CastOp>(
op->getLoc(),
RankedTensorType::get(indexShape, rewriter.getIntegerType(32)),
index);
}
// Expand the last dim of the index to TF-style indices, e.g. [3] -> [3,1],
// i.e. convert [0,0,0] to [[0],[0],[0]]
SmallVector<int64_t> indiceShapeOneDim;
for (auto shape : indexShape) {
indiceShapeOneDim.push_back(shape);
}
indiceShapeOneDim.push_back(1);
auto indicesTfOneDim = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
rewriter, op->getLoc(),
RankedTensorType::get(indiceShapeOneDim, rewriter.getIntegerType(32)),
index, rewriter.getDenseI64ArrayAttr(indiceShapeOneDim));
// Collect the reshaped index tensors for concatenation,
// e.g. ([[0],[0],[0]], [[1],[2],[3]])
indicesTfConcatTensors.push_back(indicesTfOneDim.getResult());
}
// Right now, only multiple indices with the same shape are supported.
// TODO: for indices with different shapes, broadcast the smaller ones to the
// larger shape.
for (auto indexShapeOneDim : indexesShape) {
if (!llvm::equal(indexesShape[0], indexShapeOneDim)) {
return rewriter.notifyMatchFailure(
op, "unimplemented: Only support multi indexes with same shape");
}
}
// concat each indices into indicesTf: shape ([3,1],[3,1]) -> [3,2]
// ([0,0,0],[1,2,3]) -> [[0,1],[0,2], [0,3]]
auto indicesShapeConcat = indexesShape[0];
uint64_t lastDim = indexesRank[0];
indicesShapeConcat.push_back(indicesTfConcatTensors.size());
auto indicesTf = tosa::CreateOpAndInfer<tosa::ConcatOp>(
rewriter, op->getLoc(),
GetTypeFromTensorShape(indicesShapeConcat, rewriter.getIntegerType(32)),
indicesTfConcatTensors, lastDim);
if (!indicesTf) {
return rewriter.notifyMatchFailure(
op, "Convert TorchIndex To TfIndices fail.");
}
// Run the TF ScatterNd algorithm with TF-style indices as input; the
// implementation is mostly taken from convertGatherNdOp.
auto result = tosa::convertScatterNdOp(rewriter, op, outType, input,
indicesTf.getResult(), fillValues);
if (!result) {
return rewriter.notifyMatchFailure(
op, "Convert ScatterNdOp fail for index tensor.");
}
rewriter.replaceOp(op, {result.value()});
return success();
}
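For intuition, here is a small NumPy sketch (not part of the patch; the function name is illustrative) of the index preprocessing performed above: a leading None index is expanded to an explicit zero index, and the per-dimension index tensors are then stacked into TF-style [W, ND] indices.

import numpy as np

def torch_indices_to_tf_indices(index_list):
    # Use the first non-None index as the shape reference, mirroring the
    # lowering above, which builds the zero tensor from the next index.
    ref = next(np.asarray(i) for i in index_list if i is not None)
    expanded = [np.zeros_like(ref) if i is None else np.asarray(i)
                for i in index_list]
    # [3] -> [3,1] per index, then concatenate along the last dim -> [3,2]
    return np.stack(expanded, axis=-1).astype(np.int32)

# (None, [1,2,3]) -> [[0,1],[0,2],[0,3]]
print(torch_indices_to_tf_indices([None, [1, 2, 3]]))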
template <>
LogicalResult ConvertAtenOp<AtenIndexTensorOp>::matchAndRewrite(
AtenIndexTensorOp op, OpAdaptor adaptor,
@ -3467,7 +3611,7 @@ LogicalResult ConvertAtenOp<AtenIndexTensorOp>::matchAndRewrite(
indexesShape.push_back(makeShapeTorchCompatible(indexShape));
indexesRank.push_back(indexType.getRank());
// index i64 to i32 for tosa compatible
// Make type of index tosa compatible, i64 to i32.
if (indexType.getElementType() != rewriter.getIntegerType(32)) {
index = rewriter.create<tosa::CastOp>(
op->getLoc(),
@ -4819,6 +4963,7 @@ public:
INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
INSERT_ATENOP_PATTERN(AtenGatherOp);
INSERT_ATENOP_PATTERN(Aten_IndexPutImplOp);
INSERT_ATENOP_PATTERN(AtenIndexTensorOp);
INSERT_ATENOP_PATTERN(AtenAbsOp);
INSERT_ATENOP_PATTERN(AtenWhereSelfOp);


@ -412,6 +412,277 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter,
.getResult();
}
// Lowers the index_put op to tosa::scatter.
// Mostly adapted from convertGatherNdOp() above.
std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
Operation *op, Type outType,
Value paramsValue, Value indicesValue,
Value fillValues) {
auto resultType = outType.dyn_cast<ShapedType>();
auto paramsType = paramsValue.getType().dyn_cast<RankedTensorType>();
auto indicesType = indicesValue.getType().dyn_cast<RankedTensorType>();
auto fillValuesType = fillValues.getType().dyn_cast<RankedTensorType>();
if (!resultType || !paramsType || !indicesType)
return std::nullopt;
// N: number of batches
// Always 1 for ScatterOp
//
// Because TOSA's Scatter operator already uses the symbol 'N' for
// the number of batches, we will use the symbol 'ND' to specify the
// number of dimensions that are sliced from params, instead of 'N' as in
// the TF MLIR documentation.
//
// ND: indices.shape[-1]
//
// W: number of indices in each batch
// Computed as:
// product(indices.shape[0:-1]) (all but the last dimension)
//
// K: range of each index
// Computed as:
// product(params.shape[0:ND-1])
//
// C: number of channels for each index
// Computed as:
// product(params.shape[ND:])
//
// The params tensor needs to be reshaped, but not transposed, to move the
// dimensions into [N, K, C] order.
//
// The dimensions of the input params[] tensor are grouped in the following
// order to begin with:
//
// [ParamIndices, ParamChannels]
// |------------||-------------|
// K C
//
// The reshape simply flattens the params tensor into a 2D [K, C] shape.
//
// Indices need to be put in the form of [N, W], but a simple flattening
// will not suffice, because the indices need to index into a [W]-shape
// vector instead of the params.shape[0:ND-1] tensor that we had before.
//
// To flatten the coordinates, first reshape indices to a [W, ND] matrix,
// where the matrix now represents W ND-dimensional coordinates into the
// params tensor.
//
// From here, we take each of the ND dimensions and multiply it with
// the size of the next params dimension (or 1 for the last
// dimension), then sum all these together with a reduce_sum
// operator. This is exactly the same mathematics as one would use to
// flatten the indices of an N-dimensional row-major array into a
// 1-D array in C.
//
// More precisely, do an element-wise multiply with [params.shape[1
// .. ND], 1] in axis 1, then reduce_sum in axis 1 to flatten to a
// [W]-shaped tensor, then trivially reshape to [N=1, W] to be
// compatible with the scatter operator's shape.
//
// Then perform the tosa.scatter() operation.
//
// Now we have result = [N, K, C].
//
// Reshape with a single, simple reshape to the final output shape of:
// [Indices, ParamChannels]
//
// Where, Indices is indices.shape[0:ND-1]
//
// For ease of understanding, all of the following comments use concrete
// values for each argument.
// Example (TF-style indices as input):
//   torch.aten._index_put_impl %input, %indices, %fillValue, %false, %false :
//       !torch.vtensor<[1,4],si64>, !torch.vtensor<[3,2],si64>,
//       !torch.vtensor<[1,3],si64>, !torch.bool, !torch.bool ->
//       !torch.vtensor<[1,4],si64>
// Detailed algorithm visualization:
int N = 1, W = 1, K = 1, fillK = 1, C = 1, ND = 1;
int paramsRank = paramsType.getShape().size(); // 2
int indicesRank = indicesType.getShape().size(); // 2
// ND: indices.shape[-1]
ND = indicesType.getShape()[indicesRank - 1]; // 2 depth of input
if (ND > paramsRank) {
(void)rewriter.notifyMatchFailure(
op, "size of last dimension of indices must be <= params rank");
return std::nullopt;
}
// Calculate N, K, W, C (N is always 1).
// W: number of indices/selected values in each batch,
//    product(indices.shape[0:-1]) (all but the last dimension); W = 1*3 = 3
for (int i = 0; i < (indicesRank - 1); i++) {
W *= indicesType.getShape()[i];
}
// K: range of each index, i.e. the total number of input elements (scatter
// candidates) after flattening; K = product(params.shape[0:ND]) = 1*4 = 4
for (int i = 0; i < ND; i++) {
K *= paramsType.getShape()[i];
}
// C: number of channels for each index, i.e. the number of values inside
// each input element (scatter candidate); C = product(params.shape[ND:]).
// With ND = 2 and paramsRank = 2, C = 1
for (int i = ND; i < paramsRank; i++) {
C *= paramsType.getShape()[i];
}
// For this example: N = 1, W = 3, K = 4, fillK = 3, C = 1, ND = 2
SmallVector<int64_t, 3> tosaInputValuesShape({N, K, C}); // {1,4,1}
SmallVector<int64_t, 2> tosaIndicesShape({N, W}); // {1,3}
SmallVector<int64_t, 2> indicesMatrixShape({W, ND}); // {3,2}
SmallVector<int64_t, 2> indicesMatrixReducesumShape({W, 1}); // {3,1}
// Preprocess the fill values.
// There are two cases of fillValues:
// 1. !torch.vtensor<[1,3],si64>
//    [[0,0,0]] -> [[[0], [0], [0]]]
// 2. !torch.vtensor<[],si64> (rank 0)
//       reshape(1)  tile(3)    reshape(1,3)   reshape(1,3,1)
//    [] -> [0] ->   [0,0,0] -> [[0,0,0]]   -> [[[0], [0], [0]]]
//    i.e. reshape to [1], then tile to match indicesValue.shape[0],
//    before the common [N, fillK, C] reshape below.
if (fillValuesType.getRank() == 0) {
// [] -> [0]
SmallVector<int64_t, 1> oneShape({1}); // {1}
auto tosaFillValuesOneReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
rewriter, op->getLoc(),
GetTypeFromTensorShape(oneShape, fillValuesType.getElementType()),
fillValues, rewriter.getDenseI64ArrayAttr(oneShape));
// [0] -> [0,0,0]
SmallVector<int64_t, 1> tileShape({W}); // {3}
auto tosaFillValuesTileOp = tosa::CreateOpAndInfer<tosa::TileOp>(
rewriter, op->getLoc(),
GetTypeFromTensorShape(tileShape, fillValuesType.getElementType()),
tosaFillValuesOneReshapeOp.getResult(),
rewriter.getDenseI64ArrayAttr(tileShape));
// [0,0,0] -> [[0,0,0]]
SmallVector<int64_t, 2> newTosaFillValuesShape({N, W}); // {1,3}
auto newTosaFillValuesReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
rewriter, op->getLoc(),
GetTypeFromTensorShape(newTosaFillValuesShape,
fillValuesType.getElementType()),
tosaFillValuesTileOp.getResult(),
rewriter.getDenseI64ArrayAttr(newTosaFillValuesShape));
fillValues = newTosaFillValuesReshapeOp.getResult();
fillValuesType = fillValues.getType().dyn_cast<RankedTensorType>();
}
// fillK: total number of fill values (scatter candidates) after flattening;
// fillK = product(fillValues.shape[0:ND]) = 1*3 = 3
for (int i = 0; i < ND; i++) {
fillK *= fillValuesType.getShape()[i];
}
SmallVector<int64_t, 3> tosaFillValuesShape({N, fillK, C}); // {1,3,1}
// Reshape/Flatten fillValues to 3d tensor
// [[0,0,0]] -> [[[0], [0], [0]]]
// %10 = "tosa.reshape"(%1) {new_shape = array<i64: 1, 3, 1>} :
// (tensor<1x3xi64>) -> tensor<1x3x1xi64>
auto tosaFillValuesReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
rewriter, op->getLoc(),
GetTypeFromTensorShape(tosaFillValuesShape,
fillValuesType.getElementType()),
fillValues, rewriter.getDenseI64ArrayAttr(tosaFillValuesShape));
// Reshape/Flatten input to 3d tensor
// [[1, 2, 3, 4]] -> [[[1], [2], [3], [4]]]
// %9 = "tosa.reshape"(%0) {new_shape = array<i64: 1, 4, 1>} :
// (tensor<1x4xi64>) -> tensor<1x4x1xi64>
auto tosaValuesReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
rewriter, op->getLoc(),
GetTypeFromTensorShape(tosaInputValuesShape, paramsType.getElementType()),
paramsValue, rewriter.getDenseI64ArrayAttr(tosaInputValuesShape));
// Reshape/Flatten the input indices tensor to a 2d [W, ND] matrix.
// [[0, 1], [0, 2], [0, 3]] -> [[0, 1], [0, 2], [0, 3]]
// %11 = "tosa.reshape"(%8) {new_shape = array<i64: 3, 2>} : (tensor<3x2xi32>)
// -> tensor<3x2xi32>
auto indicesMatrixReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
rewriter, op->getLoc(),
GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()),
indicesValue, rewriter.getDenseI64ArrayAttr(indicesMatrixShape));
// Build the row-major flattening coefficients, e.g. [4,1] for params of shape [1,4].
SmallVector<int32_t> flattenedCoeffVec;
for (int i = 1; i < ND; i++) {
flattenedCoeffVec.push_back(paramsType.getShape()[i]);
}
flattenedCoeffVec.push_back(1);
// flattenedCoeffVec = [4,1]
for (int i = ND - 1; i > 0; i--) {
flattenedCoeffVec[i - 1] *= flattenedCoeffVec[i];
}
// Create the tosaConstTensor for the flattenedCoeffVec.
// %12 = "tosa.const"() {value = dense<[4, 1]> : tensor<2xi32>} : () ->
// tensor<2xi32>
auto flattenedCoeffValue =
getConstTensor<int32_t>(rewriter, op, flattenedCoeffVec,
{static_cast<int64_t>(flattenedCoeffVec.size())});
if (!flattenedCoeffValue)
return std::nullopt;
// Multiply the coefficients by the coordinates.
// [[0, 1], [0, 2], [0, 3]] X [4, 1] -> [[4*0, 1*1], [4*0, 1*2], [4*0, 1*3]]
// %13 = "tosa.mul"(%11, %12) {shift = 0 : i32} : (tensor<3x2xi32>,
// tensor<2xi32>) -> tensor<3x2xi32>
auto flattenedIndicesMulOp = tosa::CreateOpAndInfer<tosa::MulOp>(
rewriter, op->getLoc(),
GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()),
indicesMatrixReshapeOp.getResult(), flattenedCoeffValue.value(), 0);
// Sum up the products of the coefficients and coordinates
// [[4*0 + 1*1], [4*0 + 1*2], [4*0 + 1*3]] = [[1],[2],[3]]
// %14 = "tosa.reduce_sum"(%13) {axis = 1 : i64} : (tensor<3x2xi32>) ->
// tensor<3x1xi32>
auto flattenedIndicesReduceOp = tosa::CreateOpAndInfer<tosa::ReduceSumOp>(
rewriter, op->getLoc(),
GetTypeFromTensorShape(indicesMatrixReducesumShape,
indicesType.getElementType()),
flattenedIndicesMulOp.getResult(), rewriter.getI64IntegerAttr(1));
// And reshape to [N, W]
// [[1],[2],[3]] -> [[1,2,3]]
// %15 = "tosa.reshape"(%14) {new_shape = array<i64: 1, 3>} :
// (tensor<3x1xi32>) -> tensor<1x3xi32>
auto tosaIndicesReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
rewriter, op->getLoc(),
GetTypeFromTensorShape(tosaIndicesShape, indicesType.getElementType()),
flattenedIndicesReduceOp.getResult(),
rewriter.getDenseI64ArrayAttr(tosaIndicesShape));
// Now the Scatter op itself
// %16 = "tosa.scatter"(%9, %15, %10) : (tensor<1x4x1xi64>, tensor<1x3xi32>,
// tensor<1x3x1xi64>) -> tensor<1x4x1xi64> input = [[[1], [2], [3], [4]]],
// indices = [[1,2,3]], fillValues= [[[0], [0], [0]]] result = [[[1], [0],
// [0], [0]]]
auto tosaScatterOp = tosa::CreateOpAndInfer<tosa::ScatterOp>(
rewriter, op->getLoc(),
GetTypeFromTensorShape(tosaInputValuesShape, resultType.getElementType()),
tosaValuesReshapeOp.getResult(), tosaIndicesReshapeOp.getResult(),
tosaFillValuesReshapeOp.getResult());
// Finally, reshape back to the original output shape of [Indices,
// ParamChannels].
// [[1, 0, 0, 0]]
// %17 = "tosa.reshape"(%16) {new_shape = array<i64: 1, 4>} :
// (tensor<1x4x1xi64>) -> tensor<1x4xi64>
return tosa::CreateOpAndInfer<tosa::ReshapeOp>(
rewriter, op->getLoc(), resultType, tosaScatterOp.getResult(),
rewriter.getDenseI64ArrayAttr(resultType.getShape()))
.getResult();
}
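As a sanity check on the arithmetic above, here is a minimal NumPy sketch (not part of the patch) that reproduces the index flattening and scatter with the example values used in the comments: params [[1,2,3,4]], TF-style indices [[0,1],[0,2],[0,3]], fill values [[0,0,0]].

import numpy as np

params = np.array([[1, 2, 3, 4]])             # [1, 4]
indices = np.array([[0, 1], [0, 2], [0, 3]])  # TF-style [W, ND] = [3, 2]
fill = np.array([[0, 0, 0]])                  # [1, 3]

ND = indices.shape[-1]                        # 2
W = int(np.prod(indices.shape[:-1]))          # 3
K = int(np.prod(params.shape[:ND]))           # 4
C = int(np.prod(params.shape[ND:]))           # 1

# Row-major flattening coefficients: shape[1:ND] followed by 1, then a
# right-to-left cumulative product, e.g. [4, 1] here.
coeffs = np.append(np.array(params.shape[1:ND], dtype=np.int64), 1)
coeffs = np.cumprod(coeffs[::-1])[::-1]
flat_idx = (indices * coeffs).sum(axis=1)     # [1, 2, 3]

values = params.reshape(1, K, C).copy()       # [N, K, C]
values[0, flat_idx, :] = fill.reshape(1, W, C)[0]   # the scatter step
print(values.reshape(params.shape))           # [[1, 0, 0, 0]]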
// Common function for lowering reduce operations to TOSA ops.
template <typename T>
std::optional<Value> convertReduceOpCommon(


@ -167,6 +167,8 @@ LOWERING_PIPELINE = "builtin.module(" + ",".join([
"expand-strided-metadata",
"finalize-memref-to-llvm",
"lower-affine",
"convert-bufferization-to-memref",
"finalize-memref-to-llvm",
"func.func(convert-arith-to-llvm)",
"convert-func-to-llvm",
"convert-cf-to-llvm",


@ -61,6 +61,30 @@ class IndexPutImpl2DFloatNonAccumulateModule(torch.nn.Module):
def IndexPutImpl2DFloatNonAccumulateModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 8), tu.randint(5, high=4), tu.rand(5, 8))
class IndexPutImpl2DNoneIndexStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([1, 4], torch.int64, True),
([3], torch.int64, True),
([1, 3], torch.int64, True),
])
def forward(self, input, index, value):
return torch.ops.aten._index_put_impl_(input, (None, index),
value,
accumulate=False,
unsafe=False)
@register_test_case(
module_factory=lambda: IndexPutImpl2DNoneIndexStaticModule())
def IndexPutImpl2DNoneIndexStaticModule_basic(module, tu: TestUtils):
module.forward(tu.randint(1, 4, high=3), tu.randint(3, high=3), tu.randint(1, 3, high=1))
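For reference, the same op in eager PyTorch, showing the None-index semantics the new test exercises (the concrete values here are illustrative, not taken from the test):

import torch

a = torch.tensor([[0, 1, 2, 3]])
index = torch.tensor([1, 2, 3])
value = torch.tensor([[4, 5, 6]])
# A None index selects the whole leading dimension, so this is a[:, [1, 2, 3]] = value.
torch.ops.aten._index_put_impl_(a, (None, index), value,
                                accumulate=False, unsafe=False)
print(a)  # tensor([[0, 4, 5, 6]])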
class IndexPutImpl3DFloatNonAccumulateModule(torch.nn.Module):


@ -25,6 +25,7 @@ TOSA_TO_LINALG_FUNC_PIPELINE = ",".join([
# ones in TOSA-to-Standard and the main conversions TOSA-to-LinAlg,
# that depend on TOSA as well as TOSA-to-Standard.
"tosa-to-arith",
"tosa-to-scf",
# Named ops must be legalized prior to general tosa-to-linalg
"tosa-to-linalg-named",
# TOSA-to-LinAlg may generate tosa.const() ops, so we want to lower them