diff --git a/e2e_testing/torchscript/index_put.py b/e2e_testing/torchscript/index_put.py
index f92a13f14..0f2fc3698 100644
--- a/e2e_testing/torchscript/index_put.py
+++ b/e2e_testing/torchscript/index_put.py
@@ -11,7 +11,6 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
 
 # ==============================================================================
 
-
 class IndexPutImplOneDimFloatNonAccumulateModule(torch.nn.Module):
 
     def __init__(self):
@@ -110,3 +109,92 @@ class IndexPutImplOneDimIntAccumulateModule(torch.nn.Module):
 def IndexPutImplOneDimIntAccumulateModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(100, (10,)), torch.randint(10, (10,)),
                    torch.randint(1000, (10,)))
+
+# ==============================================================================
+
+class IndexPutOneDimFloatNonAccumulateModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1], torch.float32, True),
+    ])
+    def forward(self, input, index, value):
+        return torch.ops.aten.index_put(input, (index,), value, accumulate=False)
+
+
+@register_test_case(module_factory=lambda: IndexPutOneDimFloatNonAccumulateModule())
+def IndexPutOneDimFloatNonAccumulateModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(100), torch.randint(100, (250,)),
+                   tu.rand(250))
+
+
+class IndexPutOneDimIntNonAccumulateModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.int64, True),
+        ([-1], torch.int64, True),
+        ([-1], torch.int64, True),
+    ])
+    def forward(self, input, index, value):
+        return torch.ops.aten.index_put(input, (index,), value, accumulate=False)
+
+
+@register_test_case(module_factory=lambda: IndexPutOneDimIntNonAccumulateModule())
+def IndexPutOneDimIntNonAccumulateModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(1000, (200,)), torch.randint(100, (300,)),
+                   torch.randint(10000, (300,)))
+
+
+class IndexPutOneDimFloatAccumulateModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1], torch.float32, True),
+    ])
+    def forward(self, input, index, value):
+        return torch.ops.aten.index_put(input, (index,), value, accumulate=True)
+
+
+@register_test_case(module_factory=lambda: IndexPutOneDimFloatAccumulateModule())
+def IndexPutOneDimFloatAccumulateModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1000), torch.randint(10, (500,)),
+                   tu.rand(500))
+
+
+class IndexPutOneDimIntAccumulateModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.int64, True),
+        ([-1], torch.int64, True),
+        ([-1], torch.int64, True),
+    ])
+    def forward(self, input, index, value):
+        return torch.ops.aten.index_put(input, (index,), value, accumulate=True)
+
+
+@register_test_case(module_factory=lambda: IndexPutOneDimIntAccumulateModule())
+def IndexPutOneDimIntAccumulateModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(100, (10,)), torch.randint(10, (10,)),
+                   torch.randint(1000, (10,)))
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 89fdaedb1..3572f2332 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1412,6 +1412,22 @@ public:
 };
 } // namespace
 
+namespace {
+// Decompose `aten.index_put` op into `valsem.aten.index_put_impl` op.
+class DecomposeAtenIndexPutOp : public OpRewritePattern<AtenIndexPutOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenIndexPutOp op,
+                                PatternRewriter &rewriter) const override {
+    Value cstFalse = rewriter.create<ConstantBoolOp>(op.getLoc(), false);
+    rewriter.replaceOpWithNewOp<ValsemVariantAtenIndexPutImplOp>(
+        op, op.getType(), op.self(), op.indices(), op.values(), op.accumulate(),
+        /*unsafe=*/cstFalse);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -1515,6 +1531,8 @@ class DecomposeComplexOpsPass
     target.addIllegalOp();
     patterns.add(context);
     target.addIllegalOp();
+    patterns.add<DecomposeAtenIndexPutOp>(context);
+    target.addIllegalOp<AtenIndexPutOp>();
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns)))) {
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 2f0771b49..5ba4463c3 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -516,7 +516,8 @@ ChangeResult TypeAnalyzer::visitOperation(
           AtenResize_Op, AtenTransposeIntOp, AtenTOp, AtenPermuteOp,
           AtenIndexSelectOp, AtenSelectIntOp, AtenSliceTensorOp, AtenGatherOp,
           AtenExpandOp, AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp,
-          AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp>(op)) {
+          AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp>(
+          op)) {
     ValueKnowledge knowledge =
         ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
     knowledge.dtype = operands[0]->getValue().dtype;
diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
index 40dd85f96..4e84f37fe 100644
--- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -2409,6 +2409,10 @@
     %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.index_select(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.list<int>) -> !torch.list<int>
     return %0 : !torch.list<int>
   }
+  func @"__torch_mlir_shape_fn.aten.index_put"(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>, %arg2: !torch.list<int>, %arg3: !torch.bool) -> !torch.list<int> {
+    %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
+    return %0 : !torch.list<int>
+  }
   func @"__torch_mlir_shape_fn.aten.nll_loss_forward"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>> {
     %int1 = torch.constant.int 1
     %int2 = torch.constant.int 2
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
index 8634f18cc..be2bc0e81 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
@@ -742,6 +742,9 @@ def aten〇select〇int(self: List[int], dim: int, index: int) -> List[int]:
 def aten〇index_select(self: List[int], dim: int, index: List[int]) -> List[int]:
     return upstream_shape_helpers.index_select(self, dim, index)
 
+def aten〇index_put(self: List[int], indices: List[Optional[List[int]]], values: List[int], accumulate: bool = False) -> List[int]:
+    return upstream_shape_helpers.unary(self)
+
 def aten〇embedding(weight: List[int], indices: List[int],
                    padding_idx: int = -1, scale_grad_by_freq: bool = False, sparse: bool = False) -> List[int]:
     return upstream_shape_helpers.embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse)
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index 04703c958..033549399 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -647,3 +647,18 @@ func @torch.aten.full_like(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
   %0 = torch.aten.full_like %arg0, %int5, %none, %none, %none, %none, %none : !torch.vtensor<[?,?],f32>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?],f32>
   return %0 : !torch.vtensor<[?,?],f32>
 }
+
+// -----
+// CHECK-LABEL:   func @torch.aten.index_put(
+// CHECK-SAME:        %[[INP:.*]]: !torch.vtensor<[?],f32>, %[[INDEX:.*]]: !torch.vtensor<[?],si64>,
+// CHECK-SAME:        %[[VALUES:.*]]: !torch.vtensor<[?],f32>,
+// CHECK-SAME:        %[[ACCUM:.*]]: !torch.bool) -> !torch.vtensor<[?],f32> {
+// CHECK:           %[[INDICES:.*]] = torch.prim.ListConstruct %[[INDEX]] : (!torch.vtensor<[?],si64>) -> !torch.list<vtensor>
+// CHECK:           %[[FALSE:.*]] = torch.constant.bool false
+// CHECK:           %[[RES:.*]] = torch.valsem.aten.index_put_impl %[[INP]], %[[INDICES]], %[[VALUES]], %[[ACCUM]], %[[FALSE]] : !torch.vtensor<[?],f32>, !torch.list<vtensor>, !torch.vtensor<[?],f32>, !torch.bool, !torch.bool -> !torch.vtensor<[?],f32>
+// CHECK:           return %[[RES]] : !torch.vtensor<[?],f32>
+func @torch.aten.index_put(%input: !torch.vtensor<[?],f32>, %index: !torch.vtensor<[?],si64>, %values: !torch.vtensor<[?],f32>, %accumulate : !torch.bool) -> !torch.vtensor<[?],f32> {
+  %indices = torch.prim.ListConstruct %index : (!torch.vtensor<[?],si64>) -> !torch.list<vtensor>
+  %0 = torch.aten.index_put %input, %indices, %values, %accumulate : !torch.vtensor<[?],f32>, !torch.list<vtensor>, !torch.vtensor<[?],f32>, !torch.bool -> !torch.vtensor<[?],f32>
+  return %0 : !torch.vtensor<[?],f32>
+}
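
Note (illustrative, not part of the patch): the e2e tests added above exercise `torch.ops.aten.index_put`, which scatters `values` into a copy of the input at the positions given by `indices`, overwriting when accumulate=False and summing repeated contributions when accumulate=True. A minimal standalone PyTorch sketch of that semantics, using made-up tensor values chosen only for illustration:

import torch

base = torch.zeros(5)

# accumulate=False: an out-of-place scatter-write, like `base[index] = values`.
index = torch.tensor([1, 3])
values = torch.tensor([10.0, 30.0])
put = torch.ops.aten.index_put(base, (index,), values, accumulate=False)
# put == tensor([ 0., 10.,  0., 30.,  0.])

# accumulate=True: repeated indices add up instead of overwriting.
dup_index = torch.tensor([1, 1, 3])
dup_values = torch.tensor([10.0, 20.0, 30.0])
summed = torch.ops.aten.index_put(base, (dup_index,), dup_values, accumulate=True)
# summed == tensor([ 0., 30.,  0., 30.,  0.])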