diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index dcba83438..e0ea4f950 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -447,6 +447,9 @@ LTC_XFAIL_SET = {
     "IndexTensorMultiInputNonContiguousOneDimDynamic_basic",
     "IndexTensorMultiInputNonContiguousDynamic_basic",
     "IndexTensorMultiInputNonContiguousMultipleStaticDims_basic",
+    "IndexTensorHackedTwinModule_basic",
+    "IndexTensorHackedTwinModule3dInput_basic",
+    "IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
     "Matmul_dot",
     "Matmul_matvec",
     "MulIntModule_basic",
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index b53425300..9dc9a4244 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -5661,6 +5661,30 @@ def Torch_AtenIndexTensorOp : Torch_Op<"aten.index.Tensor", [
   }];
 }
 
+def Torch_AtenIndexTensorHackedTwinOp : Torch_Op<"aten.index.Tensor_hacked_twin", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchListOfTensorType:$indices
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenIndexTensorHackedTwinOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenIndexTensorHackedTwinOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenIndexSelectOp : Torch_Op<"aten.index_select", [
    AllowsTypeRefinement,
    HasValueSemantics,
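`aten::index.Tensor_hacked_twin` is the JIT-registered overload of `aten::index` that takes a plain `Tensor[]` instead of `Tensor?[]`. TorchScript list types are invariant, so a `List[Tensor]` value cannot be passed where `List[Optional[Tensor]]` is expected; scripted calls with a non-optional tensor list therefore resolve to this overload, which is exactly what the e2e tests below exercise. A minimal sketch of how to observe the overload selection:

```python
import torch

def f(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # `[index]` is typed List[Tensor], not List[Optional[Tensor]], so the
    # JIT resolves this call to `aten::index.Tensor_hacked_twin`.
    return torch.ops.aten.index(x, [index])

# The printed graph is expected to contain an aten::index.Tensor_hacked_twin node.
print(torch.jit.script(f).graph)
```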
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 118a43d3f..d37b3fd9a 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -2762,6 +2762,21 @@ class DecomposeAtenLiftFreshCopyOp
 };
 } // namespace
 
+namespace {
+// Decompose `aten.index.Tensor_hacked_twin` op into `aten.index.Tensor` op.
+class DecomposeAtenIndexTensorHackedTwinOp
+    : public OpRewritePattern<AtenIndexTensorHackedTwinOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenIndexTensorHackedTwinOp op,
+                                PatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<AtenIndexTensorOp>(op, op.getType(), op.self(),
+                                                   op.indices());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -2949,6 +2964,8 @@ public:
     target.addIllegalOp<AtenRollOp>();
     patterns.add<DecomposeAtenLiftFreshCopyOp>(context);
     target.addIllegalOp<AtenLiftFreshCopyOp>();
+    patterns.add<DecomposeAtenIndexTensorHackedTwinOp>(context);
+    target.addIllegalOp<AtenIndexTensorHackedTwinOp>();
 
     for (std::string opName : legalOps) {
       target.addLegalOp(OperationName(opName, context));
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 6c2828d5b..8c4b4a8fd 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -665,7 +665,8 @@ void TypeAnalysis::visitOperation(Operation *op,
           AtenIndexPutOp, ValsemVariantAtenCopyOp, AtenZeroOp,
           AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
           PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp, AtenMaskedFillTensorOp,
-          AtenRollOp, AtenPowTensorTensorOp, AtenLiftFreshCopyOp>(op)) {
+          AtenRollOp, AtenPowTensorTensorOp, AtenLiftFreshCopyOp,
+          AtenIndexTensorHackedTwinOp>(op)) {
     return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
   }
 
diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
index ec0e81aeb..cba06c835 100644
--- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -6822,6 +6822,10 @@ module {
     return %0 : !torch.list<int>
   }
   func.func @"__torch_mlir_shape_fn.aten.index.Tensor"(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>) -> !torch.list<int> {
+    %0 = call @__torch__.index_tensor_like(%arg0, %arg1) : (!torch.list<int>, !torch.list<optional<list<int>>>) -> !torch.list<int>
+    return %0 : !torch.list<int>
+  }
+  func.func @__torch__.index_tensor_like(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>) -> !torch.list<int> {
     %false = torch.constant.bool false
     %int-1 = torch.constant.int -1
     %true = torch.constant.bool true
@@ -6932,6 +6936,19 @@ module {
     }
     return %9 : !torch.list<int>
   }
+  func.func @"__torch_mlir_shape_fn.aten.index.Tensor_hacked_twin"(%arg0: !torch.list<int>, %arg1: !torch.list<list<int>>) -> !torch.list<int> {
+    %true = torch.constant.bool true
+    %0 = torch.prim.ListConstruct : () -> !torch.list<optional<list<int>>>
+    %1 = torch.aten.len.t %arg1 : !torch.list<list<int>> -> !torch.int
+    torch.prim.Loop %1, %true, init() {
+    ^bb0(%arg2: !torch.int):
+      %3 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
+      %4 = torch.aten.append.t %0, %3 : !torch.list<optional<list<int>>>, !torch.list<int> -> !torch.list<optional<list<int>>>
+      torch.prim.Loop.condition %true, iter()
+    } : (!torch.int, !torch.bool) -> ()
+    %2 = call @__torch__.index_tensor_like(%arg0, %0) : (!torch.list<int>, !torch.list<optional<list<int>>>) -> !torch.list<int>
+    return %2 : !torch.list<int>
+  }
   func.func @"__torch_mlir_shape_fn.aten.cat"(%arg0: !torch.list<list<int>>, %arg1: !torch.int) -> !torch.list<int> {
     %0 = call @__torch__.torch.jit._shape_functions.cat(%arg0, %arg1) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>
     return %0 : !torch.list<int>
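The decomposition is a pure renaming rewrite: when no entry of the index list is `None`, `index.Tensor_hacked_twin` and `index.Tensor` compute the same thing, so the pattern just forwards `self` and `indices`. A quick eager-mode check of that premise (shapes chosen arbitrarily):

```python
import torch

x = torch.rand(5, 4)
idx = torch.randint(0, 5, (3,))
# With no None entries, aten::index agrees with plain advanced indexing,
# so rewriting the hacked_twin variant to aten.index.Tensor is sound.
assert torch.equal(torch.ops.aten.index(x, [idx]), x[idx])
```

Note that `ShapeLibrary.cpp` is generated output; the source of truth is the shape function added to `shape_lib_gen.py` below.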
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
index 94c0ea0cd..4f60a5e05 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
@@ -1081,20 +1081,7 @@ def aten〇constant_pad_nd(self: List[int], pad: List[int], value: float = 0) -> List[int]:
 def aten〇pad(self: List[int], pad: List[int], mode: str = "constant", value: Optional[float] = None) -> List[int]:
     return pad_shape_fn(self, pad)
 
-# See https://numpy.org/doc/stable/user/basics.indexing.html
-@check_shape_function([
-    Invocation(TensorOfShape(2), [LongTensorOfShape(4)]), # Basic case.
-    Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4), LongTensorOfShape(4)]), # More dimensions.
-    Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4), LongTensorOfShape(6, 4)]), # Multidimensional index tensor along a dimension.
-    Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4), None]), # Explicit None value.
-    Invocation(TensorOfShape(2, 3, 4, 5), [None, LongTensorOfShape(4), LongTensorOfShape(4)]), # Indexing tensors on consecutive dimensions.
-    Invocation(TensorOfShape(2, 3, 4, 5), [None, LongTensorOfShape(4), None, LongTensorOfShape(4)]), # Indexing tensors on non-consecutive dimensions.
-    Invocation(TensorOfShape(2, 3, 4, 5), [LongTensorOfShape(4, 2), None, LongTensorOfShape(2)]), # Indexing tensors on non-consecutive dimensions.
-    Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4, 5, 6), LongTensorOfShape(1, 5, 1)]), # Broadcasting of index tensors.
-    Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4)]), # Fewer index tensors than dimensions.
-    ErrorInvocation(TensorOfShape(2, 3), [LongTensorOfShape(4), LongTensorOfShape(4), LongTensorOfShape(4)]), # More index tensors than dimensions.
-])
-def aten〇index〇Tensor(self: List[int], indices: List[Optional[List[int]]]) -> List[int]:
+def index_tensor_like(self: List[int], indices: List[Optional[List[int]]]) -> List[int]:
     assert len(indices) <= len(self), "More indices than dimensions to index"
     broadcasted_shape: List[int] = []
     unused_dim_sizes: List[int] = []
@@ -1134,6 +1121,26 @@ def aten〇index〇Tensor(self: List[int], indices: List[Optional[List[int]]]) -> List[int]:
         result_shape.append(unused_dim_sizes[i])
     return result_shape
 
+# See https://numpy.org/doc/stable/user/basics.indexing.html
+@check_shape_function([
+    Invocation(TensorOfShape(2), [LongTensorOfShape(4)]), # Basic case.
+    Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4), LongTensorOfShape(4)]), # More dimensions.
+    Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4), LongTensorOfShape(6, 4)]), # Multidimensional index tensor along a dimension.
+    Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4), None]), # Explicit None value.
+    Invocation(TensorOfShape(2, 3, 4, 5), [None, LongTensorOfShape(4), LongTensorOfShape(4)]), # Indexing tensors on consecutive dimensions.
+    Invocation(TensorOfShape(2, 3, 4, 5), [None, LongTensorOfShape(4), None, LongTensorOfShape(4)]), # Indexing tensors on non-consecutive dimensions.
+    Invocation(TensorOfShape(2, 3, 4, 5), [LongTensorOfShape(4, 2), None, LongTensorOfShape(2)]), # Indexing tensors on non-consecutive dimensions.
+    Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4, 5, 6), LongTensorOfShape(1, 5, 1)]), # Broadcasting of index tensors.
+    Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4)]), # Fewer index tensors than dimensions.
+    ErrorInvocation(TensorOfShape(2, 3), [LongTensorOfShape(4), LongTensorOfShape(4), LongTensorOfShape(4)]), # More index tensors than dimensions.
+])
+def aten〇index〇Tensor(self: List[int], indices: List[Optional[List[int]]]) -> List[int]:
+    return index_tensor_like(self, indices)
+
+def aten〇index〇Tensor_hacked_twin(self: List[int], indices: List[List[int]]) -> List[int]:
+    optional_indices: List[Optional[List[int]]] = [x for x in indices]
+    return index_tensor_like(self, optional_indices)
+
 def aten〇cat(tensors: List[List[int]], dim: int = 0) -> List[int]:
     return upstream_shape_functions.cat(tensors, dim)
 
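The new `aten〇index〇Tensor_hacked_twin` shape function only wraps each index shape as a non-`None` optional and defers to the shared `index_tensor_like` helper, so the `check_shape_function` cases above effectively exercise both overloads. A worked example of the broadcast rule those cases encode:

```python
import torch

# Index tensors broadcast together, and the broadcast shape replaces the
# indexed dimensions (NumPy advanced-indexing semantics): (4, 5, 6) and
# (1, 5, 1) broadcast to (4, 5, 6), which becomes the result shape.
x = torch.rand(2, 3)
i = torch.randint(0, 2, (4, 5, 6))
j = torch.randint(0, 3, (1, 5, 1))
assert x[i, j].shape == torch.Size([4, 5, 6])
```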
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index e958d7bee..671411808 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -446,6 +446,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::expand_as : (Tensor, Tensor) -> (Tensor)")
     emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)")
     emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)")
+    emit("aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)")
     emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)")
     emit("aten::_index_put_impl_ : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)")
     emit("aten::item : (Tensor) -> (Scalar)")
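The `emit()` line is all that is needed for ODS generation, but its schema text must match the one in the JIT operator registry that the generator's `Registry` is loaded from. One way to double-check the schema string (assuming the usual `torch._C._jit_get_all_schemas()` API the registry itself is built on):

```python
import torch

# Print every registered overload of aten::index; the output should include
#   aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
#   aten::index.Tensor_hacked_twin(Tensor self, Tensor[] indices) -> Tensor
for schema in torch._C._jit_get_all_schemas():
    if schema.name == "aten::index":
        print(schema)
```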
diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py
index d74aa0778..c5c223ec7 100644
--- a/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -1963,6 +1963,83 @@ def IndexTensorMultiInputContiguousCenter_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
 
+class IndexTensorHackedTwinModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x, index):
+        return torch.ops.aten.index(x, [index])
+
+
+@register_test_case(module_factory=lambda: IndexTensorHackedTwinModule())
+def IndexTensorHackedTwinModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5), tu.randint(2, 3, high=4))
+
+
+# ==============================================================================
+
+
+class IndexTensorHackedTwinModule3dInput(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x, index):
+        return torch.ops.aten.index(x, [index])
+
+
+@register_test_case(
+    module_factory=lambda: IndexTensorHackedTwinModule3dInput())
+def IndexTensorHackedTwinModule3dInput_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 4, 3), tu.randint(2, 3, high=3))
+
+
+# ==============================================================================
+
+
+class IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims(
+        torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+        ([4, 1], torch.int64, True),
+        ([1, 3], torch.int64, True),
+        ([-1, 3], torch.int64, True),
+    ])
+    def forward(self, x, index1, index2, index3):
+        return torch.ops.aten.index(x, [index1, index2, index3])
+
+
+@register_test_case(
+    module_factory=lambda:
+    IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims())
+def IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic(
+        module, tu: TestUtils):
+    module.forward(tu.rand(5, 4, 3, 2), tu.randint(4, 1, high=3),
+                   tu.randint(1, 3, high=1), tu.randint(4, 3, high=1))
+
+
+# ==============================================================================
+
+
 class SquareModule(torch.nn.Module):
 
     def __init__(self):
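An eager-mode reference for the multi-input case above, showing the shape the e2e harness should observe: the index shapes (4, 1), (1, 3) and (4, 3) broadcast to (4, 3), which replaces the three indexed dimensions, leaving the unindexed trailing dimension of size 2:

```python
import torch

x = torch.rand(5, 4, 3, 2)
i1 = torch.randint(0, 3, (4, 1))
i2 = torch.randint(0, 1, (1, 3))
i3 = torch.randint(0, 1, (4, 3))
# Broadcast index shape (4, 3) + remaining dim (2,) -> result (4, 3, 2).
assert torch.ops.aten.index(x, [i1, i2, i3]).shape == torch.Size([4, 3, 2])
```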