[MLIR][TORCH] Add E2E support for aten.index_put op

This commit decomposes `aten.index_put` op into
`valsem.aten.index_put_impl` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
pull/679/head
Vivek Khandelwal 2022-03-10 20:48:08 +05:30
parent 3d95c3d6c9
commit 8da7d90611
6 changed files with 131 additions and 2 deletions

View File

@ -11,7 +11,6 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
# ==============================================================================
class IndexPutImplOneDimFloatNonAccumulateModule(torch.nn.Module):
def __init__(self):
@ -110,3 +109,92 @@ class IndexPutImplOneDimIntAccumulateModule(torch.nn.Module):
def IndexPutImplOneDimIntAccumulateModule_basic(module, tu: TestUtils):
module.forward(torch.randint(100, (10,)), torch.randint(10, (10,)),
torch.randint(1000, (10,)))
# ==============================================================================
class IndexPutOneDimFloatNonAccumulateModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1], torch.float32, True),
([-1], torch.int64, True),
([-1], torch.float32, True),
])
def forward(self, input, index, value):
return torch.ops.aten.index_put(input, (index,), value, accumulate=False)
@register_test_case(module_factory=lambda: IndexPutOneDimFloatNonAccumulateModule())
def IndexPutOneDimFloatNonAccumulateModule_basic(module, tu: TestUtils):
module.forward(tu.rand(100), torch.randint(100, (250,)),
tu.rand(250))
class IndexPutOneDimIntNonAccumulateModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1], torch.int64, True),
([-1], torch.int64, True),
([-1], torch.int64, True),
])
def forward(self, input, index, value):
return torch.ops.aten.index_put(input, (index,), value, accumulate=False)
@register_test_case(module_factory=lambda: IndexPutOneDimIntNonAccumulateModule())
def IndexPutOneDimIntNonAccumulateModule_basic(module, tu: TestUtils):
module.forward(torch.randint(1000, (200,)), torch.randint(100, (300,)),
torch.randint(10000, (300,)))
class IndexPutOneDimFloatAccumulateModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1], torch.float32, True),
([-1], torch.int64, True),
([-1], torch.float32, True),
])
def forward(self, input, index, value):
return torch.ops.aten.index_put(input, (index,), value, accumulate=True)
@register_test_case(module_factory=lambda: IndexPutOneDimFloatAccumulateModule())
def IndexPutOneDimFloatAccumulateModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1000), torch.randint(10, (500,)),
tu.rand(500))
class IndexPutOneDimIntAccumulateModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1], torch.int64, True),
([-1], torch.int64, True),
([-1], torch.int64, True),
])
def forward(self, input, index, value):
return torch.ops.aten.index_put(input, (index,), value, accumulate=True)
@register_test_case(module_factory=lambda: IndexPutOneDimIntAccumulateModule())
def IndexPutOneDimIntAccumulateModule_basic(module, tu: TestUtils):
module.forward(torch.randint(100, (10,)), torch.randint(10, (10,)),
torch.randint(1000, (10,)))

View File

@ -1412,6 +1412,22 @@ public:
};
} // namespace
namespace {
// Decompose `aten.index_put` op into `valsem.aten.index_put_impl` op.
class DecomposeAtenIndexPutOp : public OpRewritePattern<AtenIndexPutOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenIndexPutOp op,
PatternRewriter &rewriter) const override {
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
rewriter.replaceOpWithNewOp<ValsemVariantAtenIndexPutImplOp>(
op, op.getType(), op.self(), op.indices(), op.values(), op.accumulate(),
/*unsafe=*/cstFalse);
return success();
}
};
} // namespace
namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@ -1515,6 +1531,8 @@ class DecomposeComplexOpsPass
target.addIllegalOp<AtenFullOp>();
patterns.add<DecomposeAtenFullLikeOp>(context);
target.addIllegalOp<AtenFullLikeOp>();
patterns.add<DecomposeAtenIndexPutOp>(context);
target.addIllegalOp<AtenIndexPutOp>();
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {

View File

@ -516,7 +516,8 @@ ChangeResult TypeAnalyzer::visitOperation(
AtenResize_Op, AtenTransposeIntOp, AtenTOp, AtenPermuteOp,
AtenIndexSelectOp, AtenSelectIntOp, AtenSliceTensorOp, AtenGatherOp,
AtenExpandOp, AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp,
AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp>(op)) {
AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp>(
op)) {
ValueKnowledge knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
knowledge.dtype = operands[0]->getValue().dtype;

View File

@ -2409,6 +2409,10 @@ module {
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.index_select(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
func @"__torch_mlir_shape_fn.aten.index_put"(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>, %arg2: !torch.list<int>, %arg3: !torch.bool) -> !torch.list<int> {
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
func @"__torch_mlir_shape_fn.aten.nll_loss_forward"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>> {
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2

View File

@ -742,6 +742,9 @@ def atenselectint(self: List[int], dim: int, index: int) -> List[int]:
def atenindex_select(self: List[int], dim: int, index: List[int]) -> List[int]:
return upstream_shape_helpers.index_select(self, dim, index)
def atenindex_put(self: List[int], indices: List[Optional[List[int]]], values: List[int], accumulate: bool = False) -> List[int]:
return upstream_shape_helpers.unary(self)
def atenembedding(weight: List[int], indices: List[int], padding_idx: int = -1, scale_grad_by_freq: bool = False, sparse: bool = False) -> List[int]:
return upstream_shape_helpers.embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse)

View File

@ -647,3 +647,18 @@ func @torch.aten.full_like(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
%0 = torch.aten.full_like %arg0, %int5, %none, %none, %none, %none, %none : !torch.vtensor<[?,?],f32>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.index_put(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?],f32>, %[[INDEX:.*]]: !torch.vtensor<[?],si64>,
// CHECK-SAME: %[[VALUES:.*]]: !torch.vtensor<[?],f32>,
// CHECK-SAME: %[[ACCUM:.*]]: !torch.bool) -> !torch.vtensor<[?],f32> {
// CHECK: %[[INDICES:.*]] = torch.prim.ListConstruct %[[INDEX]] : (!torch.vtensor<[?],si64>) -> !torch.list<vtensor>
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[RES:.*]] = torch.valsem.aten.index_put_impl %[[INP]], %[[INDICES]], %[[VALUES]], %[[ACCUM]], %[[FALSE]] : !torch.vtensor<[?],f32>, !torch.list<vtensor>, !torch.vtensor<[?],f32>, !torch.bool, !torch.bool -> !torch.vtensor<[?],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[?],f32>
func @torch.aten.index_put(%input: !torch.vtensor<[?],f32>, %index: !torch.vtensor<[?],si64>, %values: !torch.vtensor<[?],f32>, %accumulate : !torch.bool) -> !torch.vtensor<[?],f32> {
%indices = torch.prim.ListConstruct %index : (!torch.vtensor<[?],si64>) -> !torch.list<vtensor>
%0 = torch.aten.index_put %input, %indices, %values, %accumulate : !torch.vtensor<[?],f32>, !torch.list<vtensor>, !torch.vtensor<[?],f32>, !torch.bool -> !torch.vtensor<[?],f32>
return %0 : !torch.vtensor<[?],f32>
}