[MLIR][TORCH] Add E2E support for aten.index.Tensor_hacked_twin op

This commit adds lowering of the `aten.index.Tensor_hacked_twin` op by decomposing it into the `aten.index.Tensor` op.

Signed-off-by: Vivek Khandelwal <vivek@nod-labs.com>
pull/1363/head
Vivek Khandelwal 2022-09-06 18:59:24 +05:30
parent a12b9c4492
commit 71b1f0dd7a
8 changed files with 162 additions and 15 deletions

View File

@ -447,6 +447,9 @@ LTC_XFAIL_SET = {
"IndexTensorMultiInputNonContiguousOneDimDynamic_basic",
"IndexTensorMultiInputNonContiguousDynamic_basic",
"IndexTensorMultiInputNonContiguousMultipleStaticDims_basic",
"IndexTensorHackedTwinModule_basic",
"IndexTensorHackedTwinModule3dInput_basic",
"IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
"Matmul_dot",
"Matmul_matvec",
"MulIntModule_basic",

View File

@ -5661,6 +5661,30 @@ def Torch_AtenIndexTensorOp : Torch_Op<"aten.index.Tensor", [
}];
}
// Op definition for `aten.index.Tensor_hacked_twin`.
// Per the summary signature below, it takes a tensor and a list of tensors
// (`Tensor[]`) and produces a single tensor.
def Torch_AtenIndexTensorHackedTwinOp : Torch_Op<"aten.index.Tensor_hacked_twin", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)`";
// Operands: `self` is a tensor, `indices` is a list of tensors.
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTensorType:$indices
);
// Single tensor result.
let results = (outs
AnyTorchTensorType:$result
);
// Parsing and printing are delegated to the shared parseDefaultTorchOp /
// printDefaultTorchOp helpers. The trailing `2, 1` arguments presumably are
// the operand and result counts -- TODO confirm against the helpers.
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenIndexTensorHackedTwinOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenIndexTensorHackedTwinOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenIndexSelectOp : Torch_Op<"aten.index_select", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -2762,6 +2762,21 @@ class DecomposeAtenLiftFreshCopyOp
};
} // namespace
namespace {
/// Rewrites every `aten.index.Tensor_hacked_twin` op into an
/// `aten.index.Tensor` op.
///
/// The replacement keeps the original result type and forwards the `self`
/// and `indices` operands unchanged. The pattern always succeeds.
class DecomposeAtenIndexTensorHackedTwinOp
    : public OpRewritePattern<AtenIndexTensorHackedTwinOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(AtenIndexTensorHackedTwinOp op,
                                PatternRewriter &rewriter) const override {
    // Collect everything the replacement op needs from the original op.
    auto resultType = op.getType();
    auto input = op.self();
    auto indexList = op.indices();

    rewriter.replaceOpWithNewOp<AtenIndexTensorOp>(op, resultType, input,
                                                   indexList);
    return success();
  }
};
} // namespace
namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@ -2949,6 +2964,8 @@ public:
target.addIllegalOp<Aten_EmbeddingBagOp>();
patterns.add<DecomposeAtenLiftFreshCopyOp>(context);
target.addIllegalOp<AtenLiftFreshCopyOp>();
patterns.add<DecomposeAtenIndexTensorHackedTwinOp>(context);
target.addIllegalOp<AtenIndexTensorHackedTwinOp>();
for (std::string opName : legalOps) {
target.addLegalOp(OperationName(opName, context));

View File

@ -665,7 +665,8 @@ void TypeAnalysis::visitOperation(Operation *op,
AtenIndexPutOp, ValsemVariantAtenCopyOp, AtenZeroOp,
AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp, AtenMaskedFillTensorOp,
AtenRollOp, AtenPowTensorTensorOp, AtenLiftFreshCopyOp>(op)) {
AtenRollOp, AtenPowTensorTensorOp, AtenLiftFreshCopyOp,
AtenIndexTensorHackedTwinOp>(op)) {
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
}

View File

@ -6822,6 +6822,10 @@ module {
return %0 : !torch.list<int>
}
// Shape function for `aten.index.Tensor`.
// Forwards both arguments unchanged to the shared `index_tensor_like` helper
// and returns the list it produces.
func.func @"__torch_mlir_shape_fn.aten.index.Tensor"(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>) -> !torch.list<int> {
%0 = call @__torch__.index_tensor_like(%arg0, %arg1) : (!torch.list<int>, !torch.list<optional<list<int>>>) -> !torch.list<int>
return %0 : !torch.list<int>
}
func.func @__torch__.index_tensor_like(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>) -> !torch.list<int> {
%false = torch.constant.bool false
%int-1 = torch.constant.int -1
%true = torch.constant.bool true
@ -6932,6 +6936,19 @@ module {
}
return %9 : !torch.list<int>
}
// Shape function for `aten.index.Tensor_hacked_twin`.
// %arg1 is a list of non-optional int lists, whereas `index_tensor_like`
// takes a list of optional int lists, so the entries are first copied into a
// new list of the optional element type and that list is passed on.
func.func @"__torch_mlir_shape_fn.aten.index.Tensor_hacked_twin"(%arg0: !torch.list<int>, %arg1: !torch.list<list<int>>) -> !torch.list<int> {
%true = torch.constant.bool true
// Empty destination list with the optional element type.
%0 = torch.prim.ListConstruct : () -> !torch.list<optional<list<int>>>
%1 = torch.aten.len.t %arg1 : !torch.list<list<int>> -> !torch.int
// Loop over the length of %arg1, appending element %arg2 of %arg1 to %0
// each iteration.
torch.prim.Loop %1, %true, init() {
^bb0(%arg2: !torch.int):
%3 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>
%4 = torch.aten.append.t %0, %3 : !torch.list<optional<list<int>>>, !torch.list<int> -> !torch.list<optional<list<int>>>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
// Delegate the actual shape computation to the shared helper.
%2 = call @__torch__.index_tensor_like(%arg0, %0) : (!torch.list<int>, !torch.list<optional<list<int>>>) -> !torch.list<int>
return %2 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.cat"(%arg0: !torch.list<list<int>>, %arg1: !torch.int) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.cat(%arg0, %arg1) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>
return %0 : !torch.list<int>

View File

@ -1081,20 +1081,7 @@ def atenconstant_pad_nd(self: List[int], pad: List[int], value: float = 0) ->
# Shape function for `pad`: the result is whatever `pad_shape_fn` computes
# from `self` and `pad`.
def atenpad(self: List[int], pad: List[int], mode: str = "constant", value: Optional[float] = None) -> List[int]:
    # `mode` and `value` are accepted but not used; only `self` and `pad`
    # are forwarded to the helper.
    padded_shape = pad_shape_fn(self, pad)
    return padded_shape
# See https://numpy.org/doc/stable/user/basics.indexing.html
@check_shape_function([
Invocation(TensorOfShape(2), [LongTensorOfShape(4)]), # Basic case.
Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4), LongTensorOfShape(4)]), # More dimensions.
Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4), LongTensorOfShape(6, 4)]), # Multidimensional index tensor along a dimension.
Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4), None]), # Explicit None value.
Invocation(TensorOfShape(2, 3, 4, 5), [None, LongTensorOfShape(4), LongTensorOfShape(4)]), # Indexing tensors on consecutive dimensions.
Invocation(TensorOfShape(2, 3, 4, 5), [None, LongTensorOfShape(4), None, LongTensorOfShape(4)]), # Indexing tensors on non-consecutive dimensions.
Invocation(TensorOfShape(2, 3, 4, 5), [LongTensorOfShape(4, 2), None, LongTensorOfShape(2)]), # Indexing tensors on non-consecutive dimensions.
Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4, 5, 6), LongTensorOfShape(1, 5, 1)]), # Broadcasting of index tensors.
Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4)]), # Fewer index tensors than dimensions.
ErrorInvocation(TensorOfShape(2, 3), [LongTensorOfShape(4), LongTensorOfShape(4), LongTensorOfShape(4)]), # More index tensors than dimensions.
])
def atenindexTensor(self: List[int], indices: List[Optional[List[int]]]) -> List[int]:
def index_tensor_like(self: List[int], indices: List[Optional[List[int]]]) -> List[int]:
assert len(indices) <= len(self), "More indices than dimensions to index"
broadcasted_shape: List[int] = []
unused_dim_sizes: List[int] = []
@ -1134,6 +1121,26 @@ def atenindexTensor(self: List[int], indices: List[Optional[List[int]]]) -
result_shape.append(unused_dim_sizes[i])
return result_shape
# See https://numpy.org/doc/stable/user/basics.indexing.html
# Shape function for `index.Tensor`. The computation itself lives in
# `index_tensor_like`; this wrapper only forwards its arguments. The
# invocations below are the cases handed to `check_shape_function`.
@check_shape_function([
    Invocation(TensorOfShape(2), [LongTensorOfShape(4)]), # Basic case.
    Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4), LongTensorOfShape(4)]), # More dimensions.
    Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4), LongTensorOfShape(6, 4)]), # Multidimensional index tensor along a dimension.
    Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4), None]), # Explicit None value.
    Invocation(TensorOfShape(2, 3, 4, 5), [None, LongTensorOfShape(4), LongTensorOfShape(4)]), # Indexing tensors on consecutive dimensions.
    Invocation(TensorOfShape(2, 3, 4, 5), [None, LongTensorOfShape(4), None, LongTensorOfShape(4)]), # Indexing tensors on non-consecutive dimensions.
    Invocation(TensorOfShape(2, 3, 4, 5), [LongTensorOfShape(4, 2), None, LongTensorOfShape(2)]), # Indexing tensors on non-consecutive dimensions.
    Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4, 5, 6), LongTensorOfShape(1, 5, 1)]), # Broadcasting of index tensors.
    Invocation(TensorOfShape(2, 3), [LongTensorOfShape(4)]), # Fewer index tensors than dimensions.
    ErrorInvocation(TensorOfShape(2, 3), [LongTensorOfShape(4), LongTensorOfShape(4), LongTensorOfShape(4)]), # More index tensors than dimensions.
])
def atenindexTensor(self: List[int], indices: List[Optional[List[int]]]) -> List[int]:
    result_shape = index_tensor_like(self, indices)
    return result_shape
# Shape function for `index.Tensor_hacked_twin`. Unlike `index.Tensor`, the
# entries of `indices` are not Optional here, so they are copied into a list
# with an Optional element type before delegating to `index_tensor_like`.
def atenindexTensor_hacked_twin(self: List[int], indices: List[List[int]]) -> List[int]:
    optional_indices: List[Optional[List[int]]] = []
    for index_shape in indices:
        optional_indices.append(index_shape)
    return index_tensor_like(self, optional_indices)
# Shape function for `cat`: the result is whatever
# `upstream_shape_functions.cat` computes from `tensors` and `dim`.
def atencat(tensors: List[List[int]], dim: int = 0) -> List[int]:
    concatenated_shape = upstream_shape_functions.cat(tensors, dim)
    return concatenated_shape

View File

@ -446,6 +446,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::expand_as : (Tensor, Tensor) -> (Tensor)")
emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)")
emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)")
emit("aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)")
emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)")
emit("aten::_index_put_impl_ : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)")
emit("aten::item : (Tensor) -> (Scalar)")

View File

@ -1963,6 +1963,83 @@ def IndexTensorMultiInputContiguousCenter_basic(module, tu: TestUtils):
# ==============================================================================
# Test module whose `forward` calls `torch.ops.aten.index` on `x` with a
# one-element list holding `index`.
class IndexTensorHackedTwinModule(torch.nn.Module):
def __init__(self):
super().__init__()
# Annotated arguments: a float32 `x` with shape spec [-1] and an int64
# `index` with shape spec [-1, -1]. The -1 entries presumably mark dynamic
# dimensions -- TODO confirm against `annotate_args`.
@export
@annotate_args([
None,
([-1], torch.float32, True),
([-1, -1], torch.int64, True),
])
def forward(self, x, index):
return torch.ops.aten.index(x, [index])
# Registered test case for IndexTensorHackedTwinModule.
@register_test_case(module_factory=lambda: IndexTensorHackedTwinModule())
def IndexTensorHackedTwinModule_basic(module, tu: TestUtils):
    # Create the data tensor first and the index tensor second, matching the
    # order in which the inputs are passed to `forward`.
    x = tu.rand(5)
    index = tu.randint(2, 3, high=4)
    module.forward(x, index)
# ==============================================================================
# Test module whose `forward` calls `torch.ops.aten.index` on a `x` annotated
# with three dimensions, using a one-element list holding `index`.
class IndexTensorHackedTwinModule3dInput(torch.nn.Module):
def __init__(self):
super().__init__()
# Annotated arguments: a float32 `x` with shape spec [-1, -1, -1] and an
# int64 `index` with shape spec [-1, -1]. The -1 entries presumably mark
# dynamic dimensions -- TODO confirm against `annotate_args`.
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
([-1, -1], torch.int64, True),
])
def forward(self, x, index):
return torch.ops.aten.index(x, [index])
# Registered test case for IndexTensorHackedTwinModule3dInput.
@register_test_case(
    module_factory=lambda: IndexTensorHackedTwinModule3dInput())
def IndexTensorHackedTwinModule3dInput_basic(module, tu: TestUtils):
    # Create the data tensor first and the index tensor second, matching the
    # order in which the inputs are passed to `forward`.
    x = tu.rand(5, 4, 3)
    index = tu.randint(2, 3, high=3)
    module.forward(x, index)
# ==============================================================================
# Test module whose `forward` calls `torch.ops.aten.index` on `x` with a list
# of three index tensors.
class IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims(
torch.nn.Module):
def __init__(self):
super().__init__()
# Annotated arguments: a float32 `x` with shape spec [-1, -1, -1, -1] and
# three int64 index tensors with shape specs [4, 1], [1, 3] and [-1, 3].
# The -1 entries presumably mark dynamic dimensions -- TODO confirm against
# `annotate_args`.
@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
([4, 1], torch.int64, True),
([1, 3], torch.int64, True),
([-1, 3], torch.int64, True),
])
def forward(self, x, index1, index2, index3):
return torch.ops.aten.index(x, [index1, index2, index3])
# Registered test case for
# IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims.
@register_test_case(
    module_factory=lambda:
    IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims())
def IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic(
        module, tu: TestUtils):
    # Inputs are created in the same order they are passed to `forward`:
    # the data tensor, then the three index tensors.
    x = tu.rand(5, 4, 3, 2)
    index1 = tu.randint(4, 1, high=3)
    index2 = tu.randint(1, 3, high=1)
    index3 = tu.randint(4, 3, high=1)
    module.forward(x, index1, index2, index3)
# ==============================================================================
class SquareModule(torch.nn.Module):
def __init__(self):