[MLIR][TORCH] Add E2E support for aten.index_select op

This commit adds lowering of `aten.index_select` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
pull/471/head snapshot-20211209.134
Vivek Khandelwal 2021-12-03 17:26:21 +05:30 committed by Prashant Kumar
parent 0a0a1b4476
commit 8130354c09
3 changed files with 224 additions and 0 deletions

View File

@ -0,0 +1,145 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import torch
from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
# ==============================================================================
class IndexSelectSingleIdxModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([4, 5, 6], torch.float32, True),
([1], torch.int64, True),
])
def forward(self, input, indices):
return torch.index_select(input, 1, indices)
@register_test_case(module_factory=lambda: IndexSelectSingleIdxModule())
def IndexSelectSingleIdxModule_basic(module, tu: TestUtils):
module.forward(torch.randn(4, 5, 6), torch.tensor([2]))
class IndexSelectTwoIdxModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([4, 5, 6], torch.float32, True),
([2], torch.int64, True),
])
def forward(self, input, indices):
return torch.index_select(input, 2, indices)
@register_test_case(module_factory=lambda: IndexSelectTwoIdxModule())
def IndexSelectTwoIdxModule_basic(module, tu: TestUtils):
module.forward(torch.randn(4, 5, 6), torch.tensor([2, 4]))
class IndexSelectWholeDimensionModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([4, 5, 6], torch.float32, True),
([4], torch.int64, True),
])
def forward(self, input, indices):
return torch.index_select(input, 0, indices)
@register_test_case(module_factory=lambda: IndexSelectWholeDimensionModule())
def IndexSelectWholeDimensionModule_basic(module, tu: TestUtils):
module.forward(torch.randn(4, 5, 6), torch.tensor([0, 1, 2, 3]))
class IndexSelectWholeTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3], torch.float32, True),
([3], torch.int64, True),
])
def forward(self, input, indices):
return torch.index_select(input, 0, indices)
@register_test_case(module_factory=lambda: IndexSelectWholeTensorModule())
def IndexSelectWholeTensorModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3), torch.tensor([0, 1, 2]))
class IndexSelectDynamicModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
([-1], torch.int64, True),
])
def forward(self, input, indices):
return torch.index_select(input, 2, indices)
@register_test_case(module_factory=lambda: IndexSelectDynamicModule())
def IndexSelectDynamicModulebasic(module, tu: TestUtils):
module.forward(torch.randn(4, 5, 6), torch.tensor([0, 4]))
class IndexSelectDynamicInputSizeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
([2], torch.int64, True),
])
def forward(self, input, indices):
return torch.index_select(input, 2, indices)
@register_test_case(module_factory=lambda: IndexSelectDynamicInputSizeModule())
def IndexSelectDynamicInputSizeModule_basic(module, tu: TestUtils):
module.forward(torch.randn(4, 5, 6), torch.tensor([0, 2]))
class IndexSelectDynamicIndexSizeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([4, 5, 6], torch.float32, True),
([-1], torch.int64, True),
])
def forward(self, input, indices):
return torch.index_select(input, 1, indices)
@register_test_case(module_factory=lambda: IndexSelectDynamicIndexSizeModule())
def IndexSelectDynamicIndexSizeModule_basic(module, tu: TestUtils):
module.forward(torch.randn(4, 5, 6), torch.tensor([1, 2]))

View File

@ -44,6 +44,7 @@ from . import scalar
from . import squeeze
from . import slice_like
from . import nll_loss
from . import index_select
def _get_argparse():
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']

View File

@ -3439,6 +3439,82 @@ public:
};
} // namespace
namespace {
// Let's say we have an input tensor: initialized with some random values of
// size [4, 5, 6]. An index tensor (always 1-d): [0, 2] of size [2], and an
// integer argument dim = 1. The size of the output tensor will be [4, 2, 6].
// The approach is as follows:
//
// for i in range(input.size[0])
// for j in range(index.size[0])
// for k in range(input.size[2])
// indexValue = index[j]
// output[i,j,k] = input[i,indexValue,k]
class ConvertAtenIndexSelectOp : public OpConversionPattern<AtenIndexSelectOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenIndexSelectOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
Value input = adaptor.self();
Value indices = adaptor.index();
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
RankedTensorType resultType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
Type elementType = resultType.getElementType();
unsigned inputRank = inputType.getRank();
int64_t dimInt;
if (!matchPattern(op.dim(), m_TorchConstantInt(&dimInt)))
return op->emitError("unimplemented: dim is not constant");
SmallVector<Value> resultShape = getTensorSizes(rewriter, loc, input);
resultShape[dimInt] = getTensorSizes(rewriter, loc, indices)[0];
Value initTensor =
rewriter.create<linalg::InitTensorOp>(loc, resultShape, elementType);
SmallVector<AffineExpr> resultExpr;
AffineExpr indicesExpr = rewriter.getAffineDimExpr(dimInt);
SmallVector<StringRef> iteratorTypes;
for (unsigned i = 0; i < inputRank; i++) {
resultExpr.push_back(rewriter.getAffineDimExpr(i));
iteratorTypes.push_back(getParallelIteratorTypeName());
}
auto indexingMaps = AffineMap::inferFromExprList({indicesExpr, resultExpr});
Value finalRes =
rewriter
.create<linalg::GenericOp>(
loc, initTensor.getType(), ValueRange{indices}, initTensor,
/*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value index = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIndexType(), args[0]);
SmallVector<Value> indexTarget;
for (unsigned i = 0; i < inputRank; i++)
indexTarget.push_back(b.create<linalg::IndexOp>(loc, i));
indexTarget[dimInt] = index;
Value extractedElement =
b.create<tensor::ExtractOp>(loc, input, indexTarget);
b.create<linalg::YieldOp>(loc, extractedElement);
})
.getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, finalRes);
return success();
}
};
} // namespace
// -----------------------------------------------------------------------------
// The pass
// -----------------------------------------------------------------------------
@ -3539,6 +3615,8 @@ public:
patterns.add<ConvertAtenSliceTensorOp>(typeConverter, context);
target.addIllegalOp<AtenNllLossForwardOp>();
patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
target.addIllegalOp<AtenIndexSelectOp>();
patterns.add<ConvertAtenIndexSelectOp>(typeConverter, context);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))