mirror of https://github.com/llvm/torch-mlir
parent ba24a46910
commit 5706697e0b
@ -306,6 +306,7 @@ TORCHDYNAMO_CRASHING_SET = {
    "ToCopyModule_basic",
    "TransposeIntModule_basic",
    "TransposeIntNegDimsModule_basic",
    "IndexPutImpl2DNoneIndexStaticModule_basic",

    # See https://github.com/llvm/torch-mlir/issues/2178
    "Add_Module_basic"
@ -811,6 +812,7 @@ STABLEHLO_PASS_SET = {
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
    "IndexPutImpl2DNoneIndexStaticModule_basic",
    "AliasModule_basic",
    "MaxPool2dEmptyStrideStaticModule_basic",
    "ConstantBoolParameterModule_basic",
@ -1223,6 +1225,7 @@ LTC_XFAIL_SET = {
    "IndexPutImpl2DFloatAccumulateModule_basic",
    "IndexPutImpl2DFloatNonAccumulateModule_basic",
    "IndexPutImpl2DIndexModule_basic",
    "IndexPutImpl2DNoneIndexStaticModule_basic",
    "IndexPutImpl3DFloatAccumulateModule_basic",
    "IndexPutImpl3DFloatNonAccumulateModule_basic",
    "IndexPutImplIndexWithNoneModule_basic",
@ -58,6 +58,12 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter,
                                       Value params_value,
                                       Value indices_value);

std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
                                        Operation *op, Type outType,
                                        Value paramsValue, Value indicesValue,
                                        Value fillValues);


// Lowers ReduceAll to a sequence of TOSA ops.
std::optional<Value>
convertReduceAllOp(PatternRewriter &rewriter, Operation *op,
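The new convertScatterNdOp helper mirrors convertGatherNdOp but writes values instead of reading them; its semantics roughly correspond to a TensorFlow-style scatter_nd update. A minimal NumPy sketch of what it computes, with an illustrative helper name and example values that are not part of the patch:

import numpy as np

def scatter_nd_update(params, indices, updates):
    # Write `updates` into a copy of `params` at the N-D coordinates given by
    # `indices` (shape [W, ND]), analogous to tf.tensor_scatter_nd_update.
    result = params.copy()
    for coord, value in zip(indices, updates):
        result[tuple(coord)] = value
    return result

params = np.array([[0, 1, 2, 3]])             # 1x4 input
indices = np.array([[0, 1], [0, 2], [0, 3]])  # W = 3 coordinates, ND = 2
updates = np.array([4, 5, 6])
print(scatter_nd_update(params, indices, updates))  # [[0 4 5 6]]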
@ -3405,6 +3405,150 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
  return success();
}

template <>
LogicalResult ConvertAtenOp<Aten_IndexPutImplOp>::matchAndRewrite(
    Aten_IndexPutImplOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // a = torch.tensor([[0, 1, 2, 3]])
  // a[..., 1:] = torch.tensor([4, 5, 6])
  //   = a[..., 1:4] = torch.tensor([4, 5, 6])
  //   = a[[0, 0, 0], [1, 2, 3]] = torch.tensor([4, 5, 6])  # tensor([[0, 4, 5, 6]])
  //   = torch.ops.aten.index_put(torch.tensor([[0, 1, 2, 3]]),      # input
  //                              (torch.tensor([0, 0, 0]),
  //                               torch.tensor([1, 2, 3])),         # indices
  //                              torch.tensor([4, 5, 6]))           # value
  //   = torch.ops.aten.index_put(torch.tensor([[0, 1, 2, 3]]),      # input
  //                              (None, torch.tensor([1, 2, 3])),   # indices
  //                              torch.tensor([4, 5, 6]))           # value

  // Not a tensor type.
  auto input = adaptor.getSelf();
  auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
  if (!selfType)
    return rewriter.notifyMatchFailure(
        op, "Only tensor type inputs are currently supported");

  auto fillValues = adaptor.getValues();
  auto valuesType = adaptor.getValues().getType().dyn_cast<TensorType>();
  if (!valuesType)
    return rewriter.notifyMatchFailure(
        op, "Only tensor type values are currently supported");

  // Deal with torch.prim.ListConstruct of non-const values to get the indices.
  auto tensorList = op.getIndices();
  SmallVector<Value> tensorsTorchType;
  if (!getListConstructElements(tensorList, tensorsTorchType))
    return op.emitError(
        "unimplemented: the tensor list is not from list construct");
  auto indexTensors = getTypeConvertedValues(
      rewriter, op->getLoc(), getTypeConverter(), tensorsTorchType);

  auto outType = getTypeConverter()->convertType(op.getType());

  // Convert a list of indices containing None into index tensors without None,
  // e.g. indexTensors (None, [1,2,3]) -> ([0,0,0], [1,2,3]), which is later
  // combined as ([[0],[0],[0]], [[1],[2],[3]]) -> [[0,1],[0,2],[0,3]].
  if (indexTensors.size() <= 1) {
    return rewriter.notifyMatchFailure(
        op, "Only index_put with multiple indices is currently supported");
  }
  SmallVector<Value> indicesTfConcatTensors;
  SmallVector<int64_t> indexesRank;
  SmallVector<SmallVector<int64_t>> indexesShape;

  // Convert each index tensor into a TF-style index tensor for concatenation.
  for (size_t i = 0; i < indexTensors.size(); i++) {
    auto index = indexTensors[i];
    auto indexTorch = tensorsTorchType[i];
    // TODO: add support for a None index at positions other than i == 0,
    // e.g. (index0, None); only (None, index1) is handled here.
    if (i == 0 && indexTorch.getType().isa<Torch::NoneType>()) {
      // Convert None to [0,0,0].
      auto indexNext = indexTensors[i + 1];
      auto indexNextTorch = tensorsTorchType[i + 1];
      if (indexNextTorch.getType().isa<Torch::NoneType>()) {
        return rewriter.notifyMatchFailure(
            op, "Multiple None indices are not supported for now");
      }
      auto indexNextType = indexNext.getType().dyn_cast<RankedTensorType>();
      auto indexNextShape = indexNextType.getShape();

      int64_t size = 1;
      for (auto s : indexNextShape)
        size *= s;
      SmallVector<int32_t> values(size, i);
      index =
          tosa::getConstTensor<int32_t>(rewriter, op, values, indexNextShape)
              .value();
    }

    auto indexType = index.getType().dyn_cast<RankedTensorType>();
    auto indexShape = indexType.getShape();
    indexesShape.push_back(makeShapeTorchCompatible(indexShape));
    indexesRank.push_back(indexType.getRank());

    // Cast index from i64 to i32 for TOSA compatibility.
    if (indexType.getElementType() != rewriter.getIntegerType(32)) {
      index = rewriter.create<tosa::CastOp>(
          op->getLoc(),
          RankedTensorType::get(indexShape, rewriter.getIntegerType(32)),
          index);
    }

    // Expand the last dim of the index to TF indices, [3] -> [3,1],
    // i.e. convert [0,0,0] to [[0],[0],[0]].
    SmallVector<int64_t> indiceShapeOneDim;
    for (auto shape : indexShape) {
      indiceShapeOneDim.push_back(shape);
    }
    indiceShapeOneDim.push_back(1);
    auto indicesTfOneDim = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
        rewriter, op->getLoc(),
        RankedTensorType::get(indiceShapeOneDim, rewriter.getIntegerType(32)),
        index, rewriter.getDenseI64ArrayAttr(indiceShapeOneDim));

    // Collect the expanded tensors for the indicesTf concat,
    // e.g. ([[0],[0],[0]], [[1],[2],[3]]).
    indicesTfConcatTensors.push_back(indicesTfOneDim.getResult());
  }

  // Right now, only multiple indices with the same shape are supported.
  // TODO: for multiple indices with different shapes, add a broadcast_to for
  // the smaller shape.
  for (auto indexShapeOneDim : indexesShape) {
    if (!llvm::equal(indexesShape[0], indexShapeOneDim)) {
      return rewriter.notifyMatchFailure(
          op, "unimplemented: only multiple indices with the same shape are "
              "supported");
    }
  }

  // Concat each index into indicesTf: shape ([3,1],[3,1]) -> [3,2]
  // ([0,0,0],[1,2,3]) -> [[0,1],[0,2],[0,3]]
  auto indicesShapeConcat = indexesShape[0];
  uint64_t lastDim = indexesRank[0];
  indicesShapeConcat.push_back(indicesTfConcatTensors.size());
  auto indicesTf = tosa::CreateOpAndInfer<tosa::ConcatOp>(
      rewriter, op->getLoc(),
      GetTypeFromTensorShape(indicesShapeConcat, rewriter.getIntegerType(32)),
      indicesTfConcatTensors, lastDim);

  if (!indicesTf) {
    return rewriter.notifyMatchFailure(
        op, "Failed to convert Torch indices to TF-style indices");
  }
  // Run the TF scatterNd algorithm with the TF-style indices as input; the
  // algorithm is mostly taken from convertGatherNdOp.
  auto result = tosa::convertScatterNdOp(rewriter, op, outType, input,
                                         indicesTf.getResult(), fillValues);

  if (!result) {
    return rewriter.notifyMatchFailure(
        op, "Failed to convert ScatterNdOp for the index tensor");
  }
  rewriter.replaceOp(op, {result.value()});

  return success();
}
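For orientation, the index preprocessing in the conversion above (expanding a leading None index and concatenating the per-dimension index tensors into TF/TOSA-style [W, ND] coordinates) can be sketched in NumPy as follows; the helper name is hypothetical and only the (None, index) layout handled by this pattern is covered:

import numpy as np

def torch_indices_to_tf_indices(index_tuple):
    # (None, [1, 2, 3]) -> ([0, 0, 0], [1, 2, 3]) -> [[0, 1], [0, 2], [0, 3]]
    first, rest = index_tuple[0], index_tuple[1:]
    if first is None:
        # A leading None selects dimension 0 entirely; for the 1xN inputs this
        # pattern supports, that is the constant row index 0.
        first = np.zeros_like(np.asarray(rest[0]))
    columns = [np.asarray(idx).reshape(-1, 1) for idx in (first, *rest)]
    return np.concatenate(columns, axis=1)  # shape [W, ND]

print(torch_indices_to_tf_indices((None, [1, 2, 3])))
# [[0 1]
#  [0 2]
#  [0 3]]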

template <>
LogicalResult ConvertAtenOp<AtenIndexTensorOp>::matchAndRewrite(
    AtenIndexTensorOp op, OpAdaptor adaptor,
@ -3467,7 +3611,7 @@ LogicalResult ConvertAtenOp<AtenIndexTensorOp>::matchAndRewrite(
    indexesShape.push_back(makeShapeTorchCompatible(indexShape));
    indexesRank.push_back(indexType.getRank());

    // index i64 to i32 for tosa compatible
    // Make type of index tosa compatible, i64 to i32.
    if (indexType.getElementType() != rewriter.getIntegerType(32)) {
      index = rewriter.create<tosa::CastOp>(
          op->getLoc(),
@ -4819,6 +4963,7 @@ public:
    INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
    INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
    INSERT_ATENOP_PATTERN(AtenGatherOp);
    INSERT_ATENOP_PATTERN(Aten_IndexPutImplOp);
    INSERT_ATENOP_PATTERN(AtenIndexTensorOp);
    INSERT_ATENOP_PATTERN(AtenAbsOp);
    INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
@ -412,6 +412,277 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter,
      .getResult();
}

// Lower the index_put op to the tosa::scatter op.
// Mostly taken from the convertGatherNdOp() function above.
std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
                                        Operation *op, Type outType,
                                        Value paramsValue, Value indicesValue,
                                        Value fillValues) {
  auto resultType = outType.dyn_cast<ShapedType>();
  auto paramsType = paramsValue.getType().dyn_cast<RankedTensorType>();
  auto indicesType = indicesValue.getType().dyn_cast<RankedTensorType>();
  auto fillValuesType = fillValues.getType().dyn_cast<RankedTensorType>();

  if (!resultType || !paramsType || !indicesType)
    return std::nullopt;

  // N: number of batches
  //    Always 1 for ScatterOp
  //
  // Because TOSA's Scatter operator already uses the symbol 'N' for
  // the number of batches, we will use the symbol 'ND' to specify the
  // number of dimensions that are sliced from params instead of 'N' in
  // the TF MLIR documentation.
  //
  // ND: indices.shape[-1]
  //
  // W: number of indices in each batch
  //    Computed as:
  //    product(indices.shape[0:-1]) (all but the last dimension)
  //
  // K: range of each index
  //    Computed as:
  //    product(params.shape[0:ND-1])
  //
  // C: number of channels for each index
  //    Computed as:
  //    product(params.shape[ND:])
  //
  // The params tensor needs to be reshaped, but not transposed, to move the
  // dimensions into [N, K, C] order.
  //
  // The dimensions of the input params[] tensor are grouped in the following
  // order to begin with:
  //
  //  [ParamIndices, ParamChannels]
  //  |------------||-------------|
  //         K              C
  //
  // The reshape simply flattens the params tensor into a 2D [K, C] shape.
  //
  // Indices needs to be put in the form of [N, W], but a simple flattening
  // will not suffice, because the indices need to index into a [W]-shaped
  // vector instead of the params.shape[0:ND-1] tensor that we had before.
  //
  // To flatten the coordinates, first reshape indices to a [W, ND] matrix,
  // where the matrix now represents W ND-dimensional coordinates into the
  // params tensor.
  //
  // From here, we take each of the ND dimensions and multiply it with
  // the size of the next params dimension (or 1 for the last
  // dimension), then sum all these together with a reduce_sum
  // operator. This is exactly the same mathematics as one would use
  // to flatten the indices of an N-dimensional row-major array into a
  // 1-D array in C.
  //
  // More precisely, do an element-wise multiply with [params.shape[1
  // .. ND], 1] in axis 1, then reduce_sum in axis 1 to flatten to a
  // [W]-shaped tensor, then trivially reshape to [N=1, W] to be
  // compatible with the scatter operator's shape.
  //
  // Then perform the tosa.scatter() operation.
  //
  // Now we have result = [N, K, C].
  //
  // Reshape with a single, simple reshape to the final output shape of:
  //  [Indices, ParamChannels]
  //
  // Where Indices is indices.shape[0:ND-1]
  //
  // For ease of understanding, all of the following comments use concrete
  // values for each argument. Example: take TF-style indices as input
  //   torch.aten._index_put_impl %input, %indices, %fillValue, %false, %false :
  //       !torch.vtensor<[1,4],si64>, !torch.vtensor<[3,2],si64>,
  //       !torch.vtensor<[1,3],si64>, !torch.bool, !torch.bool ->
  //       !torch.vtensor<[1,4],si64>
  // Detailed algorithm visualization:

  int N = 1, W = 1, K = 1, fillK = 1, C = 1, ND = 1;

  int paramsRank = paramsType.getShape().size();   // 2
  int indicesRank = indicesType.getShape().size(); // 2

  // ND: indices.shape[-1]
  ND = indicesType.getShape()[indicesRank - 1]; // 2, the coordinate depth

  if (ND > paramsRank) {
    (void)rewriter.notifyMatchFailure(
        op, "size of last dimension of indices must be <= params rank");
    return std::nullopt;
  }

  // Calculate N, K, W, C. (N is always 1.)
  // W: number of indices/selected values in each batch,
  // product(indices.shape[0:-1]) (all but the last dimension). W = 1*3 = 3.
  for (int i = 0; i < (indicesRank - 1); i++) {
    W *= indicesType.getShape()[i];
  }

  // K: range of each index, i.e. the total number of input elements (that
  // could be scattered to) after flattening. K = 1*4 = 4.
  for (int i = 0; i < ND; i++) {
    K *= paramsType.getShape()[i];
  }

  // C: number of channels for each index, i.e. the number of values inside
  // each scatterable input element. C = product(params.shape[ND:]); with
  // ND = 2 == paramsRank, C = 1.
  for (int i = ND; i < paramsRank; i++) {
    C *= paramsType.getShape()[i];
  }

  // For the running example: N = 1, W = 3, K = 4, fillK = 3, C = 1, ND = 2.
  SmallVector<int64_t, 3> tosaInputValuesShape({N, K, C});     // {1,4,1}
  SmallVector<int64_t, 2> tosaIndicesShape({N, W});            // {1,3}
  SmallVector<int64_t, 2> indicesMatrixShape({W, ND});         // {3,2}
  SmallVector<int64_t, 2> indicesMatrixReducesumShape({W, 1}); // {3,1}

  // Preprocess the fill values. There are two possible shapes for fillValues:
  //   1. !torch.vtensor<[1,3],si64>
  //      [[0,0,0]] -> [[[0], [0], [0]]]
  //   2. !torch.vtensor<[],si64> (rank 0)
  //          reshape(1)   tile(W=3)   reshape(1,3)   reshape(1,3,1)
  //      [] ->  [0]    ->  [0,0,0] ->  [[0,0,0]]  ->  [[[0], [0], [0]]]
  // For case 2, reshape the scalar to [1] and tile it W times
  // (indicesValue.shape[0] in the example) so there is one fill value per index.
  if (fillValuesType.getRank() == 0) {
    // [] -> [0]
    SmallVector<int64_t, 1> oneShape({1}); // {1}
    auto tosaFillValuesOneReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
        rewriter, op->getLoc(),
        GetTypeFromTensorShape(oneShape, fillValuesType.getElementType()),
        fillValues, rewriter.getDenseI64ArrayAttr(oneShape));

    // [0] -> [0,0,0]
    SmallVector<int64_t, 1> tileShape({W}); // {3}
    auto tosaFillValuesTileOp = tosa::CreateOpAndInfer<tosa::TileOp>(
        rewriter, op->getLoc(),
        GetTypeFromTensorShape(tileShape, fillValuesType.getElementType()),
        tosaFillValuesOneReshapeOp.getResult(),
        rewriter.getDenseI64ArrayAttr(tileShape));

    // [0,0,0] -> [[0,0,0]]
    SmallVector<int64_t, 2> newTosaFillValuesShape({N, W}); // {1,3}
    auto newTosaFillValuesReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
        rewriter, op->getLoc(),
        GetTypeFromTensorShape(newTosaFillValuesShape,
                               fillValuesType.getElementType()),
        tosaFillValuesTileOp.getResult(),
        rewriter.getDenseI64ArrayAttr(newTosaFillValuesShape));
    fillValues = newTosaFillValuesReshapeOp.getResult();
    fillValuesType = fillValues.getType().dyn_cast<RankedTensorType>();
  }

  // fillK: range of each index, i.e. the total number of fill elements (that
  // could be scattered) after flattening. fillK = 1*3 = 3.
  for (int i = 0; i < ND; i++) {
    fillK *= fillValuesType.getShape()[i];
  }
  SmallVector<int64_t, 3> tosaFillValuesShape({N, fillK, C}); // {1,3,1}

  // Reshape/flatten fillValues to a 3D tensor.
  // [[0,0,0]] -> [[[0], [0], [0]]]
  // %10 = "tosa.reshape"(%1) {new_shape = array<i64: 1, 3, 1>} :
  //           (tensor<1x3xi64>) -> tensor<1x3x1xi64>
  auto tosaFillValuesReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
      rewriter, op->getLoc(),
      GetTypeFromTensorShape(tosaFillValuesShape,
                             fillValuesType.getElementType()),
      fillValues, rewriter.getDenseI64ArrayAttr(tosaFillValuesShape));

  // Reshape/flatten the input to a 3D tensor.
  // [[1, 2, 3, 4]] -> [[[1], [2], [3], [4]]]
  // %9 = "tosa.reshape"(%0) {new_shape = array<i64: 1, 4, 1>} :
  //          (tensor<1x4xi64>) -> tensor<1x4x1xi64>
  auto tosaValuesReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
      rewriter, op->getLoc(),
      GetTypeFromTensorShape(tosaInputValuesShape, paramsType.getElementType()),
      paramsValue, rewriter.getDenseI64ArrayAttr(tosaInputValuesShape));

  // Reshape/flatten the input indices tensor to a 2D [W, ND] matrix.
  // [[0, 1], [0, 2], [0, 3]] -> [[0, 1], [0, 2], [0, 3]]
  // %11 = "tosa.reshape"(%8) {new_shape = array<i64: 3, 2>} :
  //           (tensor<3x2xi32>) -> tensor<3x2xi32>
  auto indicesMatrixReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
      rewriter, op->getLoc(),
      GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()),
      indicesValue, rewriter.getDenseI64ArrayAttr(indicesMatrixShape));

  SmallVector<int32_t> flattenedCoeffVec;
  // flattenedCoeffVec = [params.shape[1], ..., params.shape[ND-1], 1] = [4, 1]
  for (int i = 1; i < ND; i++) {
    flattenedCoeffVec.push_back(paramsType.getShape()[i]);
  }
  flattenedCoeffVec.push_back(1);

  // Take the suffix products so each entry becomes the row-major stride of its
  // dimension; for the example, flattenedCoeffVec stays [4, 1].
  for (int i = ND - 1; i > 0; i--) {
    flattenedCoeffVec[i - 1] *= flattenedCoeffVec[i];
  }

  // Create the tosaConstTensor for the flattenedCoeffVec.
  // %12 = "tosa.const"() {value = dense<[4, 1]> : tensor<2xi32>} :
  //           () -> tensor<2xi32>
  auto flattenedCoeffValue =
      getConstTensor<int32_t>(rewriter, op, flattenedCoeffVec,
                              {static_cast<int64_t>(flattenedCoeffVec.size())});

  if (!flattenedCoeffValue)
    return std::nullopt;

  // Multiply the coefficients by the coordinates.
  // [[0, 1], [0, 2], [0, 3]] X [4, 1] -> [[4*0, 1*1], [4*0, 1*2], [4*0, 1*3]]
  // %13 = "tosa.mul"(%11, %12) {shift = 0 : i32} :
  //           (tensor<3x2xi32>, tensor<2xi32>) -> tensor<3x2xi32>
  auto flattenedIndicesMulOp = tosa::CreateOpAndInfer<tosa::MulOp>(
      rewriter, op->getLoc(),
      GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()),
      indicesMatrixReshapeOp.getResult(), flattenedCoeffValue.value(), 0);

  // Sum up the products of the coefficients and coordinates.
  // [[4*0 + 1*1], [4*0 + 1*2], [4*0 + 1*3]] = [[1],[2],[3]]
  // %14 = "tosa.reduce_sum"(%13) {axis = 1 : i64} :
  //           (tensor<3x2xi32>) -> tensor<3x1xi32>
  auto flattenedIndicesReduceOp = tosa::CreateOpAndInfer<tosa::ReduceSumOp>(
      rewriter, op->getLoc(),
      GetTypeFromTensorShape(indicesMatrixReducesumShape,
                             indicesType.getElementType()),
      flattenedIndicesMulOp.getResult(), rewriter.getI64IntegerAttr(1));

  // And reshape to [N, W].
  // [[1],[2],[3]] -> [[1,2,3]]
  // %15 = "tosa.reshape"(%14) {new_shape = array<i64: 1, 3>} :
  //           (tensor<3x1xi32>) -> tensor<1x3xi32>
  auto tosaIndicesReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
      rewriter, op->getLoc(),
      GetTypeFromTensorShape(tosaIndicesShape, indicesType.getElementType()),
      flattenedIndicesReduceOp.getResult(),
      rewriter.getDenseI64ArrayAttr(tosaIndicesShape));

  // Now the Scatter op itself.
  // %16 = "tosa.scatter"(%9, %15, %10) :
  //           (tensor<1x4x1xi64>, tensor<1x3xi32>, tensor<1x3x1xi64>)
  //           -> tensor<1x4x1xi64>
  // input = [[[1], [2], [3], [4]]], indices = [[1,2,3]],
  // fillValues = [[[0], [0], [0]]], result = [[[1], [0], [0], [0]]]
  auto tosaScatterOp = tosa::CreateOpAndInfer<tosa::ScatterOp>(
      rewriter, op->getLoc(),
      GetTypeFromTensorShape(tosaInputValuesShape, resultType.getElementType()),
      tosaValuesReshapeOp.getResult(), tosaIndicesReshapeOp.getResult(),
      tosaFillValuesReshapeOp.getResult());

  // Finally, reshape back to the original output shape of
  // [Indices, ParamChannels]: [[1, 0, 0, 0]]
  // %17 = "tosa.reshape"(%16) {new_shape = array<i64: 1, 4>} :
  //           (tensor<1x4x1xi64>) -> tensor<1x4xi64>
  return tosa::CreateOpAndInfer<tosa::ReshapeOp>(
             rewriter, op->getLoc(), resultType, tosaScatterOp.getResult(),
             rewriter.getDenseI64ArrayAttr(resultType.getShape()))
      .getResult();
}

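To make the coordinate-flattening arithmetic concrete, here is a hedged NumPy sketch of the same steps for the running example in the comments above (reshape to [W, ND], multiply by the row-major coefficients, reduce_sum, then scatter on the [N, K, C] view); the variable names are illustrative only:

import numpy as np

# Running example: params 1x4, W = 3, ND = 2, K = 4, C = 1.
params = np.array([[1, 2, 3, 4]])
indices = np.array([[0, 1], [0, 2], [0, 3]])  # [W, ND] coordinates
fills = np.array([[0, 0, 0]])                 # [N, W] fill values

# Row-major flattening coefficients: [params.shape[1], 1] = [4, 1].
coeffs = np.array([4, 1])
flat_idx = (indices * coeffs).sum(axis=1)     # tosa.mul + tosa.reduce_sum -> [1, 2, 3]

# tosa.scatter on the [N, K, C] view of params.
values_nkc = params.reshape(1, 4, 1).copy()
values_nkc[0, flat_idx, 0] = fills[0]
print(values_nkc.reshape(1, 4))               # [[1 0 0 0]]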

// Common function for lowering reduce operations to TOSA ops.
template <typename T>
std::optional<Value> convertReduceOpCommon(
@ -167,6 +167,8 @@ LOWERING_PIPELINE = "builtin.module(" + ",".join([
    "expand-strided-metadata",
    "finalize-memref-to-llvm",
    "lower-affine",
    "convert-bufferization-to-memref",
    "finalize-memref-to-llvm",
    "func.func(convert-arith-to-llvm)",
    "convert-func-to-llvm",
    "convert-cf-to-llvm",
@ -61,6 +61,30 @@ class IndexPutImpl2DFloatNonAccumulateModule(torch.nn.Module):
def IndexPutImpl2DFloatNonAccumulateModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(10, 8), tu.randint(5, high=4), tu.rand(5, 8))


class IndexPutImpl2DNoneIndexStaticModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([1, 4], torch.int64, True),
        ([3], torch.int64, True),
        ([1, 3], torch.int64, True),
    ])
    def forward(self, input, index, value):
        return torch.ops.aten._index_put_impl_(input, (None, index),
                                               value,
                                               accumulate=False,
                                               unsafe=False)


@register_test_case(
    module_factory=lambda: IndexPutImpl2DNoneIndexStaticModule())
def IndexPutImpl2DNoneIndexStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.randint(1, 4, high=3), tu.randint(3, high=3),
                   tu.randint(1, 3, high=1))


class IndexPutImpl3DFloatNonAccumulateModule(torch.nn.Module):
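As a quick eager-mode reference for what the new test exercises (the concrete values here are illustrative; the test itself drives the module with random tensors of the annotated shapes):

import torch

a = torch.tensor([[0, 1, 2, 3]])
torch.ops.aten._index_put_impl_(a, (None, torch.tensor([1, 2, 3])),
                                torch.tensor([[4, 5, 6]]),
                                accumulate=False, unsafe=False)
print(a)  # tensor([[0, 4, 5, 6]])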
@ -25,6 +25,7 @@ TOSA_TO_LINALG_FUNC_PIPELINE = ",".join([
    # ones in TOSA-to-Standard and the main conversions TOSA-to-LinAlg,
    # that depend on TOSA as well as TOSA-to-Standard.
    "tosa-to-arith",
    "tosa-to-scf",
    # Named ops must be legalized prior to general tosa-to-linalg
    "tosa-to-linalg-named",
    # TOSA-to-LinAlg may generate tosa.const() ops, so we want to lower them