mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add E2E support for aten.index.Tensor_hacked_twin op
This commit adds lowering of `index.Tensor_hacked_twin` op. Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/1363/head
parent
a12b9c4492
commit
71b1f0dd7a
|
@ -447,6 +447,9 @@ LTC_XFAIL_SET = {
|
||||||
"IndexTensorMultiInputNonContiguousOneDimDynamic_basic",
|
"IndexTensorMultiInputNonContiguousOneDimDynamic_basic",
|
||||||
"IndexTensorMultiInputNonContiguousDynamic_basic",
|
"IndexTensorMultiInputNonContiguousDynamic_basic",
|
||||||
"IndexTensorMultiInputNonContiguousMultipleStaticDims_basic",
|
"IndexTensorMultiInputNonContiguousMultipleStaticDims_basic",
|
||||||
|
"IndexTensorHackedTwinModule_basic",
|
||||||
|
"IndexTensorHackedTwinModule3dInput_basic",
|
||||||
|
"IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
|
||||||
"Matmul_dot",
|
"Matmul_dot",
|
||||||
"Matmul_matvec",
|
"Matmul_matvec",
|
||||||
"MulIntModule_basic",
|
"MulIntModule_basic",
|
||||||
|
|
|
@ -5661,6 +5661,30 @@ def Torch_AtenIndexTensorOp : Torch_Op<"aten.index.Tensor", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenIndexTensorHackedTwinOp : Torch_Op<"aten.index.Tensor_hacked_twin", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
AnyTorchListOfTensorType:$indices
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenIndexTensorHackedTwinOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||||
|
}
|
||||||
|
void AtenIndexTensorHackedTwinOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 2, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenIndexSelectOp : Torch_Op<"aten.index_select", [
|
def Torch_AtenIndexSelectOp : Torch_Op<"aten.index_select", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -2762,6 +2762,21 @@ class DecomposeAtenLiftFreshCopyOp
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Decompose `aten.index.Tensor_hacked_twin` op into `aten.index.Tensor` op.
|
||||||
|
class DecomposeAtenIndexTensorHackedTwinOp
|
||||||
|
: public OpRewritePattern<AtenIndexTensorHackedTwinOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenIndexTensorHackedTwinOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
rewriter.replaceOpWithNewOp<AtenIndexTensorOp>(op, op.getType(), op.self(),
|
||||||
|
op.indices());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class DecomposeComplexOpsPass
|
class DecomposeComplexOpsPass
|
||||||
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
||||||
|
@ -2949,6 +2964,8 @@ public:
|
||||||
target.addIllegalOp<Aten_EmbeddingBagOp>();
|
target.addIllegalOp<Aten_EmbeddingBagOp>();
|
||||||
patterns.add<DecomposeAtenLiftFreshCopyOp>(context);
|
patterns.add<DecomposeAtenLiftFreshCopyOp>(context);
|
||||||
target.addIllegalOp<AtenLiftFreshCopyOp>();
|
target.addIllegalOp<AtenLiftFreshCopyOp>();
|
||||||
|
patterns.add<DecomposeAtenIndexTensorHackedTwinOp>(context);
|
||||||
|
target.addIllegalOp<AtenIndexTensorHackedTwinOp>();
|
||||||
|
|
||||||
for (std::string opName : legalOps) {
|
for (std::string opName : legalOps) {
|
||||||
target.addLegalOp(OperationName(opName, context));
|
target.addLegalOp(OperationName(opName, context));
|
||||||
|
|
|
@ -665,7 +665,8 @@ void TypeAnalysis::visitOperation(Operation *op,
|
||||||
AtenIndexPutOp, ValsemVariantAtenCopyOp, AtenZeroOp,
|
AtenIndexPutOp, ValsemVariantAtenCopyOp, AtenZeroOp,
|
||||||
AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
|
AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
|
||||||
PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp, AtenMaskedFillTensorOp,
|
PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp, AtenMaskedFillTensorOp,
|
||||||
AtenRollOp, AtenPowTensorTensorOp, AtenLiftFreshCopyOp>(op)) {
|
AtenRollOp, AtenPowTensorTensorOp, AtenLiftFreshCopyOp,
|
||||||
|
AtenIndexTensorHackedTwinOp>(op)) {
|
||||||
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
|
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -6822,6 +6822,10 @@ module {
|
||||||
return %0 : !torch.list<int>
|
return %0 : !torch.list<int>
|
||||||
}
|
}
|
||||||
func.func @"__torch_mlir_shape_fn.aten.index.Tensor"(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>) -> !torch.list<int> {
|
func.func @"__torch_mlir_shape_fn.aten.index.Tensor"(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>) -> !torch.list<int> {
|
||||||
|
%0 = call @__torch__.index_tensor_like(%arg0, %arg1) : (!torch.list<int>, !torch.list<optional<list<int>>>) -> !torch.list<int>
|
||||||
|
return %0 : !torch.list<int>
|
||||||
|
}
|
||||||
|
func.func @__torch__.index_tensor_like(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>) -> !torch.list<int> {
|
||||||
%false = torch.constant.bool false
|
%false = torch.constant.bool false
|
||||||
%int-1 = torch.constant.int -1
|
%int-1 = torch.constant.int -1
|
||||||
%true = torch.constant.bool true
|
%true = torch.constant.bool true
|
||||||
|
@ -6932,6 +6936,19 @@ module {
|
||||||
}
|
}
|
||||||
return %9 : !torch.list<int>
|
return %9 : !torch.list<int>
|
||||||
}
|
}
|
||||||
|
func.func @"__torch_mlir_shape_fn.aten.index.Tensor_hacked_twin"(%arg0: !torch.list<int>, %arg1: !torch.list<list<int>>) -> !torch.list<int> {
|
||||||
|
%true = torch.constant.bool true
|
||||||
|
%0 = torch.prim.ListConstruct : () -> !torch.list<optional<list<int>>>
|
||||||
|
%1 = torch.aten.len.t %arg1 : !torch.list<list<int>> -> !torch.int
|
||||||
|
torch.prim.Loop %1, %true, init() {
|
||||||
|
^bb0(%arg2: !torch.int):
|
||||||
|
%3 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
|
||||||
|
%4 = torch.aten.append.t %0, %3 : !torch.list<optional<list<int>>>, !torch.list<int> -> !torch.list<optional<list<int>>>
|
||||||
|
torch.prim.Loop.condition %true, iter()
|
||||||
|
} : (!torch.int, !torch.bool) -> ()
|
||||||
|
%2 = call @__torch__.index_tensor_like(%arg0, %0) : (!torch.list<int>, !torch.list<optional<list<int>>>) -> !torch.list<int>
|
||||||
|
return %2 : !torch.list<int>
|
||||||
|
}
|
||||||
func.func @"__torch_mlir_shape_fn.aten.cat"(%arg0: !torch.list<list<int>>, %arg1: !torch.int) -> !torch.list<int> {
|
func.func @"__torch_mlir_shape_fn.aten.cat"(%arg0: !torch.list<list<int>>, %arg1: !torch.int) -> !torch.list<int> {
|
||||||
%0 = call @__torch__.torch.jit._shape_functions.cat(%arg0, %arg1) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>
|
%0 = call @__torch__.torch.jit._shape_functions.cat(%arg0, %arg1) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>
|
||||||
return %0 : !torch.list<int>
|
return %0 : !torch.list<int>
|
||||||
|
|
|
@ -1081,20 +1081,7 @@ def aten〇constant_pad_nd(self: List[int], pad: List[int], value: float = 0) ->
|
||||||
def aten〇pad(self: List[int], pad: List[int], mode: str = "constant", value: Optional[float] = None) -> List[int]:
|
def aten〇pad(self: List[int], pad: List[int], mode: str = "constant", value: Optional[float] = None) -> List[int]:
|
||||||
return pad_shape_fn(self, pad)
|
return pad_shape_fn(self, pad)
|
||||||
|
|
||||||
# See https://numpy.org/doc/stable/user/basics.indexing.html
|
def index_tensor_like(self: List[int], indices: List[Optional[List[int]]]) -> List[int]:
|
||||||
@check_shape_function([
|
|
||||||
Invocation(TensorOfShape(2), [LongTensorOfShape(4)]), # Basic case.
|
|
||||||
Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4), LongTensorOfShape(4)]), # More dimensions.
|
|
||||||
Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4), LongTensorOfShape(6, 4)]), # Multidimensional index tensor along a dimension.
|
|
||||||
Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4), None]), # Explicit None value.
|
|
||||||
Invocation(TensorOfShape(2, 3, 4, 5), [None, LongTensorOfShape(4), LongTensorOfShape(4)]), # Indexing tensors on consecutive dimensions.
|
|
||||||
Invocation(TensorOfShape(2, 3, 4, 5), [None, LongTensorOfShape(4), None, LongTensorOfShape(4)]), # Indexing tensors on non-consecutive dimensions.
|
|
||||||
Invocation(TensorOfShape(2, 3, 4, 5), [LongTensorOfShape(4, 2), None, LongTensorOfShape(2)]), # Indexing tensors on non-consecutive dimensions.
|
|
||||||
Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4, 5, 6), LongTensorOfShape(1, 5, 1)]), # Broadcasting of index tensors.
|
|
||||||
Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4)]), # Fewer index tensors than dimensions.
|
|
||||||
ErrorInvocation(TensorOfShape(2, 3), [LongTensorOfShape(4), LongTensorOfShape(4), LongTensorOfShape(4)]), # More index tensors than dimensions.
|
|
||||||
])
|
|
||||||
def aten〇index〇Tensor(self: List[int], indices: List[Optional[List[int]]]) -> List[int]:
|
|
||||||
assert len(indices) <= len(self), "More indices than dimensions to index"
|
assert len(indices) <= len(self), "More indices than dimensions to index"
|
||||||
broadcasted_shape: List[int] = []
|
broadcasted_shape: List[int] = []
|
||||||
unused_dim_sizes: List[int] = []
|
unused_dim_sizes: List[int] = []
|
||||||
|
@ -1134,6 +1121,26 @@ def aten〇index〇Tensor(self: List[int], indices: List[Optional[List[int]]]) -
|
||||||
result_shape.append(unused_dim_sizes[i])
|
result_shape.append(unused_dim_sizes[i])
|
||||||
return result_shape
|
return result_shape
|
||||||
|
|
||||||
|
# See https://numpy.org/doc/stable/user/basics.indexing.html
|
||||||
|
@check_shape_function([
|
||||||
|
Invocation(TensorOfShape(2), [LongTensorOfShape(4)]), # Basic case.
|
||||||
|
Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4), LongTensorOfShape(4)]), # More dimensions.
|
||||||
|
Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4), LongTensorOfShape(6, 4)]), # Multidimensional index tensor along a dimension.
|
||||||
|
Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4), None]), # Explicit None value.
|
||||||
|
Invocation(TensorOfShape(2, 3, 4, 5), [None, LongTensorOfShape(4), LongTensorOfShape(4)]), # Indexing tensors on consecutive dimensions.
|
||||||
|
Invocation(TensorOfShape(2, 3, 4, 5), [None, LongTensorOfShape(4), None, LongTensorOfShape(4)]), # Indexing tensors on non-consecutive dimensions.
|
||||||
|
Invocation(TensorOfShape(2, 3, 4, 5), [LongTensorOfShape(4, 2), None, LongTensorOfShape(2)]), # Indexing tensors on non-consecutive dimensions.
|
||||||
|
Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4, 5, 6), LongTensorOfShape(1, 5, 1)]), # Broadcasting of index tensors.
|
||||||
|
Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4)]), # Fewer index tensors than dimensions.
|
||||||
|
ErrorInvocation(TensorOfShape(2, 3), [LongTensorOfShape(4), LongTensorOfShape(4), LongTensorOfShape(4)]), # More index tensors than dimensions.
|
||||||
|
])
|
||||||
|
def aten〇index〇Tensor(self: List[int], indices: List[Optional[List[int]]]) -> List[int]:
|
||||||
|
return index_tensor_like(self, indices)
|
||||||
|
|
||||||
|
def aten〇index〇Tensor_hacked_twin(self: List[int], indices: List[List[int]]) -> List[int]:
|
||||||
|
optional_indices: List[Optional[List[int]]] = [x for x in indices]
|
||||||
|
return index_tensor_like(self, optional_indices)
|
||||||
|
|
||||||
def aten〇cat(tensors: List[List[int]], dim: int = 0) -> List[int]:
|
def aten〇cat(tensors: List[List[int]], dim: int = 0) -> List[int]:
|
||||||
return upstream_shape_functions.cat(tensors, dim)
|
return upstream_shape_functions.cat(tensors, dim)
|
||||||
|
|
||||||
|
|
|
@ -446,6 +446,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::expand_as : (Tensor, Tensor) -> (Tensor)")
|
emit("aten::expand_as : (Tensor, Tensor) -> (Tensor)")
|
||||||
emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)")
|
emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)")
|
||||||
emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)")
|
emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)")
|
||||||
|
emit("aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)")
|
||||||
emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)")
|
emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)")
|
||||||
emit("aten::_index_put_impl_ : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)")
|
emit("aten::_index_put_impl_ : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)")
|
||||||
emit("aten::item : (Tensor) -> (Scalar)")
|
emit("aten::item : (Tensor) -> (Scalar)")
|
||||||
|
|
|
@ -1963,6 +1963,83 @@ def IndexTensorMultiInputContiguousCenter_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class IndexTensorHackedTwinModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1], torch.float32, True),
|
||||||
|
([-1, -1], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, x, index):
|
||||||
|
return torch.ops.aten.index(x, [index])
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: IndexTensorHackedTwinModule())
|
||||||
|
def IndexTensorHackedTwinModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(5), tu.randint(2, 3, high=4))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class IndexTensorHackedTwinModule3dInput(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
([-1, -1], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, x, index):
|
||||||
|
return torch.ops.aten.index(x, [index])
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: IndexTensorHackedTwinModule3dInput())
|
||||||
|
def IndexTensorHackedTwinModule3dInput_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(5, 4, 3), tu.randint(2, 3, high=3))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims(
|
||||||
|
torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1, -1], torch.float32, True),
|
||||||
|
([4, 1], torch.int64, True),
|
||||||
|
([1, 3], torch.int64, True),
|
||||||
|
([-1, 3], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, x, index1, index2, index3):
|
||||||
|
return torch.ops.aten.index(x, [index1, index2, index3])
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda:
|
||||||
|
IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims())
|
||||||
|
def IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic(
|
||||||
|
module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(5, 4, 3, 2), tu.randint(4, 1, high=3),
|
||||||
|
tu.randint(1, 3, high=1), tu.randint(4, 3, high=1))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class SquareModule(torch.nn.Module):
|
class SquareModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
Loading…
Reference in New Issue