[MLIR][TORCH] Add E2E support for aten.index_select op

This commit adds lowering of the `aten.index_select` op.

Signed-off-by: Vivek Khandelwal <vivek@nod-labs.com>

parent 0a0a1b4476
commit 8130354c09
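For reference, `aten.index_select` picks out, along a single dimension, the slices of the input tensor at the positions given by a 1-D index tensor. A minimal PyTorch sketch of the semantics this lowering has to match (the shapes mirror the example in the lowering comment below; the assertions are illustrative, not part of the commit):

    import torch

    input = torch.randn(4, 5, 6)
    indices = torch.tensor([0, 2])
    # Select slices 0 and 2 along dim 1; the result has shape [4, 2, 6].
    output = torch.index_select(input, 1, indices)
    assert output.shape == (4, 2, 6)
    assert torch.equal(output[:, 0, :], input[:, 0, :])
    assert torch.equal(output[:, 1, :], input[:, 2, :])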
@@ -0,0 +1,145 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

import torch

from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export

# ==============================================================================


class IndexSelectSingleIdxModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([4, 5, 6], torch.float32, True),
        ([1], torch.int64, True),
    ])
    def forward(self, input, indices):
        return torch.index_select(input, 1, indices)

@register_test_case(module_factory=lambda: IndexSelectSingleIdxModule())
def IndexSelectSingleIdxModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4, 5, 6), torch.tensor([2]))


class IndexSelectTwoIdxModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([4, 5, 6], torch.float32, True),
        ([2], torch.int64, True),
    ])
    def forward(self, input, indices):
        return torch.index_select(input, 2, indices)

@register_test_case(module_factory=lambda: IndexSelectTwoIdxModule())
def IndexSelectTwoIdxModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4, 5, 6), torch.tensor([2, 4]))


class IndexSelectWholeDimensionModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([4, 5, 6], torch.float32, True),
        ([4], torch.int64, True),
    ])
    def forward(self, input, indices):
        return torch.index_select(input, 0, indices)

@register_test_case(module_factory=lambda: IndexSelectWholeDimensionModule())
def IndexSelectWholeDimensionModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4, 5, 6), torch.tensor([0, 1, 2, 3]))


class IndexSelectWholeTensorModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3], torch.float32, True),
        ([3], torch.int64, True),
    ])
    def forward(self, input, indices):
        return torch.index_select(input, 0, indices)

@register_test_case(module_factory=lambda: IndexSelectWholeTensorModule())
def IndexSelectWholeTensorModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(3), torch.tensor([0, 1, 2]))

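# In @annotate_args, a -1 dimension is dynamic, so the modules below exercise
# the lowering when the input shape, the index count, or both are unknown at
# compile time.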
class IndexSelectDynamicModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
        ([-1], torch.int64, True),
    ])
    def forward(self, input, indices):
        return torch.index_select(input, 2, indices)

@register_test_case(module_factory=lambda: IndexSelectDynamicModule())
def IndexSelectDynamicModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4, 5, 6), torch.tensor([0, 4]))


class IndexSelectDynamicInputSizeModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
        ([2], torch.int64, True),
    ])
    def forward(self, input, indices):
        return torch.index_select(input, 2, indices)

@register_test_case(module_factory=lambda: IndexSelectDynamicInputSizeModule())
def IndexSelectDynamicInputSizeModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4, 5, 6), torch.tensor([0, 2]))


class IndexSelectDynamicIndexSizeModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([4, 5, 6], torch.float32, True),
        ([-1], torch.int64, True),
    ])
    def forward(self, input, indices):
        return torch.index_select(input, 1, indices)

@register_test_case(module_factory=lambda: IndexSelectDynamicIndexSizeModule())
def IndexSelectDynamicIndexSizeModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4, 5, 6), torch.tensor([1, 2]))
@@ -44,6 +44,7 @@ from . import scalar
from . import squeeze
from . import slice_like
from . import nll_loss
from . import index_select

def _get_argparse():
    config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
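Registering `from . import index_select` here is what makes the new tests discoverable by the e2e harness. A hedged invocation sketch, assuming the harness entry point is this main module and that `_get_argparse()` above defines `--config` and `--filter` options (the flag definitions are not shown in this hunk):

    python -m e2e_testing.torchscript.main --config=refbackend --filter IndexSelect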
@@ -3439,6 +3439,82 @@ public:
};
} // namespace

namespace {
// Let's say we have an input tensor of size [4, 5, 6] initialized with some
// random values, an index tensor (always 1-d) [0, 2] of size [2], and an
// integer argument dim = 1. The size of the output tensor will be [4, 2, 6].
// The approach is as follows:
//
// for i in range(input.size[0])
//   for j in range(index.size[0])
//     for k in range(input.size[2])
//       indexValue = index[j]
//       output[i,j,k] = input[i,indexValue,k]
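//
// Below, this loop nest is expressed as a single linalg.generic whose
// iteration space is the output shape: the index tensor is the only input
// operand, read through an affine map that projects out just the `dim` loop,
// while the output uses the identity map. The body gathers the input element
// with tensor.extract after swapping the loaded index value into the
// `dim`-th coordinate.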
class ConvertAtenIndexSelectOp : public OpConversionPattern<AtenIndexSelectOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenIndexSelectOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();

    Location loc = op.getLoc();
    Value input = adaptor.self();
    Value indices = adaptor.index();
    RankedTensorType inputType = input.getType().cast<RankedTensorType>();
    RankedTensorType resultType = getTypeConverter()
                                      ->convertType(op->getResult(0).getType())
                                      .cast<RankedTensorType>();
    Type elementType = resultType.getElementType();
    unsigned inputRank = inputType.getRank();

    int64_t dimInt;
    if (!matchPattern(op.dim(), m_TorchConstantInt(&dimInt)))
      return op->emitError("unimplemented: dim is not constant");

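    // The result shape is the input shape, with the size along `dim` replaced
    // by the number of indices.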
    SmallVector<Value> resultShape = getTensorSizes(rewriter, loc, input);
    resultShape[dimInt] = getTensorSizes(rewriter, loc, indices)[0];
    Value initTensor =
        rewriter.create<linalg::InitTensorOp>(loc, resultShape, elementType);
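    // Build the indexing maps and iterator types: the index operand reads
    // only the `dim` loop, the result maps to all loops (identity), and every
    // loop is parallel.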
    SmallVector<AffineExpr> resultExpr;
    AffineExpr indicesExpr = rewriter.getAffineDimExpr(dimInt);
    SmallVector<StringRef> iteratorTypes;

    for (unsigned i = 0; i < inputRank; i++) {
      resultExpr.push_back(rewriter.getAffineDimExpr(i));
      iteratorTypes.push_back(getParallelIteratorTypeName());
    }

    auto indexingMaps = AffineMap::inferFromExprList({indicesExpr, resultExpr});

    Value finalRes =
        rewriter
            .create<linalg::GenericOp>(
                loc, initTensor.getType(), ValueRange{indices}, initTensor,
                /*indexingMaps=*/indexingMaps,
                /*iteratorTypes=*/iteratorTypes,
                [&](OpBuilder &b, Location loc, ValueRange args) {
                  Value index = rewriter.create<arith::IndexCastOp>(
                      loc, rewriter.getIndexType(), args[0]);
                  SmallVector<Value> indexTarget;
                  for (unsigned i = 0; i < inputRank; i++)
                    indexTarget.push_back(b.create<linalg::IndexOp>(loc, i));
                  indexTarget[dimInt] = index;
                  Value extractedElement =
                      b.create<tensor::ExtractOp>(loc, input, indexTarget);
                  b.create<linalg::YieldOp>(loc, extractedElement);
                })
            .getResult(0);

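    // Cast to the converted result type; this reconciles any static sizes in
    // the result type with the dynamically-sized tensor produced above.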
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, finalRes);
    return success();
  }
};
} // namespace

// -----------------------------------------------------------------------------
// The pass
// -----------------------------------------------------------------------------
@@ -3539,6 +3615,8 @@ public:
    patterns.add<ConvertAtenSliceTensorOp>(typeConverter, context);
    target.addIllegalOp<AtenNllLossForwardOp>();
    patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
    target.addIllegalOp<AtenIndexSelectOp>();
    patterns.add<ConvertAtenIndexSelectOp>(typeConverter, context);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))