mirror of https://github.com/llvm/torch-mlir
parent
67ab708b63
commit
8a7340dfb5
|
@ -644,6 +644,7 @@ TOSA_PASS_SET = {
|
|||
"TypePromotionZeroRankHigherCategoryModule_basic",
|
||||
"GatherStaticModule_basic",
|
||||
"IndexTensorStaticModule_basic",
|
||||
"IndexTensorMultiIndexStaticModule_basic",
|
||||
"LiftFreshCopyModule_basic",
|
||||
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
|
||||
"ReduceSumDimIntListFloatModule_basic",
|
||||
|
@ -757,6 +758,7 @@ LTC_XFAIL_SET = {
|
|||
"IndexTensorModule3dInput_basic",
|
||||
"IndexTensorModule_basic",
|
||||
"IndexTensorStaticModule_basic",
|
||||
"IndexTensorMultiIndexStaticModule_basic",
|
||||
"IndexTensorMultiInputContiguousCenter_basic",
|
||||
"IndexTensorMultiInputNonContiguous_basic",
|
||||
"IndexTensorMultiInputOneDim_basic",
|
||||
|
|
|
@ -3328,17 +3328,106 @@ LogicalResult ConvertAtenOp<AtenIndexTensorOp>::matchAndRewrite(
|
|||
if (!getListConstructElements(tensorList, tensorsTorchType))
|
||||
return op.emitError(
|
||||
"unimplemented: the tensor list is not from list construct");
|
||||
auto tensors = getTypeConvertedValues(rewriter, op->getLoc(),
|
||||
getTypeConverter(), tensorsTorchType);
|
||||
auto indexTensors = getTypeConvertedValues(
|
||||
rewriter, op->getLoc(), getTypeConverter(), tensorsTorchType);
|
||||
|
||||
// TODO add support for multiple index
|
||||
if ( tensors.size() > 1){
|
||||
return op.emitError(
|
||||
"unimplemented: the index tensor list from list construct > 1");
|
||||
auto outType = getTypeConverter()->convertType(op.getType());
|
||||
|
||||
// Support for multiple indexes
|
||||
if (indexTensors.size() > 1) {
|
||||
// t[i, i]
|
||||
// = torch.ops.aten.index(t,(i,i))
|
||||
// = tensor([[ t[1,1], t[2,2], t[3,3]],
|
||||
// [ t[3,3], t[2,2], t[1,1]]])
|
||||
// = tensor([[ 7, 13, 19], [19, 13, 7]])
|
||||
// = tf.gather_nd(t,tf.ii_expand)
|
||||
// ii_expand
|
||||
// = tf.concat((i_expand,i_expand), dim=2)
|
||||
// = tf.constant([[[1,1],[2,2],[3,3]],
|
||||
// [[3,3],[2,2],[1,1]]]) # 2*3*2
|
||||
SmallVector<Value> indicesTfConcatTensors;
|
||||
SmallVector<int64_t> indexesRank;
|
||||
SmallVector<SmallVector<int64_t>> indexesShape;
|
||||
|
||||
// concat index tensor into to indices tensor for concat
|
||||
for (size_t i = 0; i < indexTensors.size(); i++) {
|
||||
auto index = indexTensors[i];
|
||||
auto indexTorch = tensorsTorchType[i];
|
||||
// TODO add support for none index input like torch.ops.aten.index(x,
|
||||
// (None, index1, index2, None))
|
||||
if (indexTorch.getType().isa<Torch::NoneType>())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only list ranked tensor types index are supported");
|
||||
|
||||
auto indexType = index.getType().dyn_cast<RankedTensorType>();
|
||||
auto indexShape = indexType.getShape();
|
||||
indexesShape.push_back(makeShapeTorchCompatible(indexShape));
|
||||
indexesRank.push_back(indexType.getRank());
|
||||
|
||||
// index i64 to i32 for tosa compatible
|
||||
if (indexType.getElementType() != rewriter.getIntegerType(32)) {
|
||||
index = rewriter.create<tosa::CastOp>(
|
||||
op->getLoc(),
|
||||
RankedTensorType::get(indexShape, rewriter.getIntegerType(32)),
|
||||
index);
|
||||
}
|
||||
|
||||
// Expand last dim of index to tf indices [2,3] -> [2,3,1]
|
||||
SmallVector<int64_t> indiceShapeOneDim;
|
||||
for (auto shape : indexShape) {
|
||||
indiceShapeOneDim.push_back(shape);
|
||||
}
|
||||
indiceShapeOneDim.push_back(1);
|
||||
auto indicesTfOneDim = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
|
||||
rewriter, op->getLoc(),
|
||||
RankedTensorType::get(indiceShapeOneDim, rewriter.getIntegerType(32)),
|
||||
index, rewriter.getDenseI64ArrayAttr(indiceShapeOneDim));
|
||||
|
||||
// create concat tensor for indicesTf
|
||||
indicesTfConcatTensors.push_back(indicesTfOneDim.getResult());
|
||||
}
|
||||
|
||||
// Right now only support multiple indexes with same shape
|
||||
// TODO for different shape multiple indexes, add broadcast_to for small
|
||||
// shape
|
||||
for (auto indexShapeOneDim : indexesShape) {
|
||||
if (!llvm::equal(indexesShape[0], indexShapeOneDim)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: Only support multi indexes with same shape");
|
||||
}
|
||||
}
|
||||
|
||||
// concat each indices into indicesTf: shape [2,3,1],[2,3,1] -> [2,3,2]
|
||||
auto indicesShapeConcat = indexesShape[0];
|
||||
uint64_t lastDim = indexesRank[0];
|
||||
indicesShapeConcat.push_back(indicesTfConcatTensors.size());
|
||||
auto indicesTf = tosa::CreateOpAndInfer<tosa::ConcatOp>(
|
||||
rewriter, op->getLoc(),
|
||||
GetTypeFromTensorShape(indicesShapeConcat, rewriter.getIntegerType(32)),
|
||||
indicesTfConcatTensors, lastDim);
|
||||
|
||||
if (!indicesTf) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Convert TorchIndex To TfIndices fail.");
|
||||
}
|
||||
// do the tf gathernp algorithm with tf style indices as input.
|
||||
auto result = tosa::convertGatherNdOp(rewriter, op, outType, input,
|
||||
indicesTf.getResult());
|
||||
|
||||
if (!result) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Convert GatherNdOp fail for index tensor.");
|
||||
}
|
||||
rewriter.replaceOp(op, {result.value()});
|
||||
|
||||
return success();
|
||||
}
|
||||
auto index = tensors[0];
|
||||
|
||||
// Support for multiple index
|
||||
auto index = indexTensors[0];
|
||||
auto indexTorch = tensorsTorchType[0];
|
||||
// TODO add support for none index input like torch.ops.aten.index(x, (None, index1, index2, None))
|
||||
if (!index.getImpl())
|
||||
if (indexTorch.getType().isa<Torch::NoneType>())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only list ranked tensor types index are supported");
|
||||
auto indexType = index.getType().dyn_cast<RankedTensorType>();
|
||||
|
@ -3350,8 +3439,6 @@ LogicalResult ConvertAtenOp<AtenIndexTensorOp>::matchAndRewrite(
|
|||
RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index);
|
||||
}
|
||||
|
||||
auto outType = getTypeConverter()->convertType(op.getType());
|
||||
|
||||
// Expand last dim of index to tf indices [2,3] -> [2,3,1]
|
||||
SmallVector<int64_t> indicesShape;
|
||||
for (auto shape : indexShape) {
|
||||
|
|
|
@ -1826,6 +1826,27 @@ class IndexTensorStaticModule(torch.nn.Module):
|
|||
def IndexTensorStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 5), tu.randint(2, 3, high=4))
|
||||
|
||||
# ==============================================================================
|
||||
class IndexTensorMultiIndexStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([4, 5], torch.float32, True),
|
||||
([2, 3], torch.int64, True),
|
||||
([2, 3], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, index1, index2):
|
||||
return torch.ops.aten.index(x, (index1, index2))
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexTensorMultiIndexStaticModule())
|
||||
def IndexTensorMultiIndexStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 5), tu.randint(2, 3, high=4), tu.randint(2, 3, high=4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
|
Loading…
Reference in New Issue