mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add E2E support for aten.index_put op
This commit decomposes `aten.index_put` op into `valsem.aten.index_put_impl` op. Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/679/head
parent
3d95c3d6c9
commit
8da7d90611
|
@ -11,7 +11,6 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class IndexPutImplOneDimFloatNonAccumulateModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -110,3 +109,92 @@ class IndexPutImplOneDimIntAccumulateModule(torch.nn.Module):
|
|||
def IndexPutImplOneDimIntAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (10,)), torch.randint(10, (10,)),
|
||||
torch.randint(1000, (10,)))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class IndexPutOneDimFloatNonAccumulateModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, (index,), value, accumulate=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexPutOneDimFloatNonAccumulateModule())
|
||||
def IndexPutOneDimFloatNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(100), torch.randint(100, (250,)),
|
||||
tu.rand(250))
|
||||
|
||||
|
||||
class IndexPutOneDimIntNonAccumulateModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, (index,), value, accumulate=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexPutOneDimIntNonAccumulateModule())
|
||||
def IndexPutOneDimIntNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1000, (200,)), torch.randint(100, (300,)),
|
||||
torch.randint(10000, (300,)))
|
||||
|
||||
|
||||
class IndexPutOneDimFloatAccumulateModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, (index,), value, accumulate=True)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexPutOneDimFloatAccumulateModule())
|
||||
def IndexPutOneDimFloatAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1000), torch.randint(10, (500,)),
|
||||
tu.rand(500))
|
||||
|
||||
|
||||
class IndexPutOneDimIntAccumulateModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, (index,), value, accumulate=True)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexPutOneDimIntAccumulateModule())
|
||||
def IndexPutOneDimIntAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (10,)), torch.randint(10, (10,)),
|
||||
torch.randint(1000, (10,)))
|
||||
|
|
|
@ -1412,6 +1412,22 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose `aten.index_put` op into `valsem.aten.index_put_impl` op.
|
||||
class DecomposeAtenIndexPutOp : public OpRewritePattern<AtenIndexPutOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenIndexPutOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
|
||||
rewriter.replaceOpWithNewOp<ValsemVariantAtenIndexPutImplOp>(
|
||||
op, op.getType(), op.self(), op.indices(), op.values(), op.accumulate(),
|
||||
/*unsafe=*/cstFalse);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeComplexOpsPass
|
||||
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
||||
|
@ -1515,6 +1531,8 @@ class DecomposeComplexOpsPass
|
|||
target.addIllegalOp<AtenFullOp>();
|
||||
patterns.add<DecomposeAtenFullLikeOp>(context);
|
||||
target.addIllegalOp<AtenFullLikeOp>();
|
||||
patterns.add<DecomposeAtenIndexPutOp>(context);
|
||||
target.addIllegalOp<AtenIndexPutOp>();
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns)))) {
|
||||
|
|
|
@ -516,7 +516,8 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
AtenResize_Op, AtenTransposeIntOp, AtenTOp, AtenPermuteOp,
|
||||
AtenIndexSelectOp, AtenSelectIntOp, AtenSliceTensorOp, AtenGatherOp,
|
||||
AtenExpandOp, AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp,
|
||||
AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp>(op)) {
|
||||
AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp>(
|
||||
op)) {
|
||||
ValueKnowledge knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
knowledge.dtype = operands[0]->getValue().dtype;
|
||||
|
|
|
@ -2409,6 +2409,10 @@ module {
|
|||
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.index_select(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.list<int>) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
}
|
||||
func @"__torch_mlir_shape_fn.aten.index_put"(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>, %arg2: !torch.list<int>, %arg3: !torch.bool) -> !torch.list<int> {
|
||||
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
}
|
||||
func @"__torch_mlir_shape_fn.aten.nll_loss_forward"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>> {
|
||||
%int1 = torch.constant.int 1
|
||||
%int2 = torch.constant.int 2
|
||||
|
|
|
@ -742,6 +742,9 @@ def aten〇select〇int(self: List[int], dim: int, index: int) -> List[int]:
|
|||
def aten〇index_select(self: List[int], dim: int, index: List[int]) -> List[int]:
|
||||
return upstream_shape_helpers.index_select(self, dim, index)
|
||||
|
||||
def aten〇index_put(self: List[int], indices: List[Optional[List[int]]], values: List[int], accumulate: bool = False) -> List[int]:
|
||||
return upstream_shape_helpers.unary(self)
|
||||
|
||||
def aten〇embedding(weight: List[int], indices: List[int], padding_idx: int = -1, scale_grad_by_freq: bool = False, sparse: bool = False) -> List[int]:
|
||||
return upstream_shape_helpers.embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse)
|
||||
|
||||
|
|
|
@ -647,3 +647,18 @@ func @torch.aten.full_like(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
|
|||
%0 = torch.aten.full_like %arg0, %int5, %none, %none, %none, %none, %none : !torch.vtensor<[?,?],f32>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @torch.aten.index_put(
|
||||
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?],f32>, %[[INDEX:.*]]: !torch.vtensor<[?],si64>,
|
||||
// CHECK-SAME: %[[VALUES:.*]]: !torch.vtensor<[?],f32>,
|
||||
// CHECK-SAME: %[[ACCUM:.*]]: !torch.bool) -> !torch.vtensor<[?],f32> {
|
||||
// CHECK: %[[INDICES:.*]] = torch.prim.ListConstruct %[[INDEX]] : (!torch.vtensor<[?],si64>) -> !torch.list<vtensor>
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[RES:.*]] = torch.valsem.aten.index_put_impl %[[INP]], %[[INDICES]], %[[VALUES]], %[[ACCUM]], %[[FALSE]] : !torch.vtensor<[?],f32>, !torch.list<vtensor>, !torch.vtensor<[?],f32>, !torch.bool, !torch.bool -> !torch.vtensor<[?],f32>
|
||||
// CHECK: return %[[RES]] : !torch.vtensor<[?],f32>
|
||||
func @torch.aten.index_put(%input: !torch.vtensor<[?],f32>, %index: !torch.vtensor<[?],si64>, %values: !torch.vtensor<[?],f32>, %accumulate : !torch.bool) -> !torch.vtensor<[?],f32> {
|
||||
%indices = torch.prim.ListConstruct %index : (!torch.vtensor<[?],si64>) -> !torch.list<vtensor>
|
||||
%0 = torch.aten.index_put %input, %indices, %values, %accumulate : !torch.vtensor<[?],f32>, !torch.list<vtensor>, !torch.vtensor<[?],f32>, !torch.bool -> !torch.vtensor<[?],f32>
|
||||
return %0 : !torch.vtensor<[?],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue