[MLIR][TORCH] Fix dynamic cases for aten.index.Tensor

pull/1252/head snapshot-20220819.570
Vivek Khandelwal 2022-08-16 18:16:42 +05:30
parent 1e1759c2eb
commit 65d811e267
4 changed files with 198 additions and 12 deletions


@@ -273,6 +273,10 @@ LTC_XFAIL_SET = {
    "IndexTensorMultiInputThreeIndexers_basic",
    "IndexTensorMultiInput_basic",
    "IndexTensorSelectDimModule_basic",
    "IndexTensorMultiInputContiguousOneDimDynamic_basic",
    "IndexTensorMultiInputNonContiguousOneDimDynamic_basic",
    "IndexTensorMultiInputNonContiguousDynamic_basic",
    "IndexTensorMultiInputNonContiguousMultipleStaticDims_basic",
    "Matmul_dot",
    "Matmul_matvec",
    "MulIntModule_basic",


@@ -583,10 +583,11 @@ public:
    int replacedIndexCount = indexTensorDims.size();
    int64_t startIndex = contiguous ? firstIndexDim : 0;
    // Currently we only support statically sized index tensors
    // when there is more than one index tensor.
    // TODO: Add support for dynamic size index tensors. This will probably
    // require broadcasting the index tensors to a common shape.
    // Currently we only support statically sized index tensors or dynamic size
    // index tensors without overlapping dynamic dims when there is more than
    // one index tensor.
    // TODO: Add support for dynamic size index tensors with overlapping
    // dynamic dims.
    SmallVector<Value> broadcastedIndexShape;
    if (indexTensors.size() > 1) {
      int maxRank = -1;
@@ -602,12 +603,39 @@ public:
      for (auto i : llvm::seq(startIndex, startIndex + maxRank)) {
        auto resultDimSize = refinedResultShape[i];
        if (ShapedType::isDynamic(resultDimSize)) {
          return rewriter.notifyMatchFailure(
              op, "unimplemented: index tensors must have static shape if "
                  "there is more than one index tensor");
          SmallVector<Value> dynamicDims;
          int64_t staticDimSize = -1;
          for (auto indexTensor : indexTensors) {
            RankedTensorType indexTensorType =
                indexTensor.getType().cast<RankedTensorType>();
            int64_t indexTensorRank = indexTensorType.getRank();
            if ((maxRank - indexTensorRank) > (i - startIndex))
              continue;
            int64_t dim = i - startIndex - maxRank + indexTensorRank;
            if (ShapedType::isDynamic(indexTensorType.getShape()[dim]))
              dynamicDims.push_back(getDimOp(rewriter, loc, indexTensor, dim));
            else
              staticDimSize =
                  std::max(staticDimSize, indexTensorType.getShape()[dim]);
          }
          if (dynamicDims.size() >= 2)
            return rewriter.notifyMatchFailure(
                op,
                "unimplemented: index tensors with overlapping dynamic dims");
          if (staticDimSize > 1) {
            Value cstStaticDimSize = getConstant(rewriter, loc, staticDimSize,
                                                 rewriter.getIndexType());
            auto equalToRunning = rewriter.create<arith::CmpIOp>(
                loc, arith::CmpIPredicate::eq, cstStaticDimSize,
                dynamicDims[0]);
            rewriter.create<cf::AssertOp>(loc, equalToRunning,
                                          "mismatched size for broadcast");
          }
          broadcastedIndexShape.push_back(dynamicDims[0]);
        } else {
          broadcastedIndexShape.push_back(getConstant(
              rewriter, loc, resultDimSize, rewriter.getIndexType()));
        }
        broadcastedIndexShape.push_back(
            getConstant(rewriter, loc, resultDimSize, rewriter.getIndexType()));
      }
    } else {
      // For a single indexing tensor we can simply use its (dynamic) sizes
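The rule implemented by the new lowering code can be summarized outside of MLIR: for each broadcasted index dimension, at most one index tensor may contribute a dynamic size, and any static size greater than one must equal that dynamic size, which the pattern checks at runtime with cf.assert. Below is a minimal Python sketch of that rule; the helper name and the "dyn" placeholder are illustrative only and not part of the patch.

# Illustrative sketch only: mirrors the broadcast rule in the C++ above.
# 'None' marks a dynamic dim; the helper name is hypothetical.
def broadcast_index_shape(index_shapes):
    max_rank = max(len(shape) for shape in index_shapes)
    result = []
    for i in range(max_rank):
        dynamic_count, static_size = 0, -1
        for shape in index_shapes:
            offset = max_rank - len(shape)
            if i < offset:  # this index tensor does not cover dim i
                continue
            size = shape[i - offset]
            if size is None:
                dynamic_count += 1
            else:
                static_size = max(static_size, size)
        if dynamic_count >= 2:
            # matches notifyMatchFailure("... overlapping dynamic dims")
            raise NotImplementedError(
                "index tensors with overlapping dynamic dims")
        if dynamic_count == 1:
            # A static size > 1 must equal the dynamic size; the lowering
            # emits cf.assert("mismatched size for broadcast") for this.
            result.append("dyn" if static_size <= 1 else f"dyn(=={static_size})")
        else:
            result.append(static_size)
    return result

# Shapes from the new lit test: [?, 1] and [?] broadcast to two dynamic dims.
print(broadcast_index_shape([[None, 1], [None]]))  # ['dyn', 'dyn']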


@@ -1748,6 +1748,125 @@ def IndexTensorMultiInputOneDim_basic(module, tu: TestUtils):
# ==============================================================================


class IndexTensorMultiInputContiguousOneDimDynamic(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
        ([-1, 1], torch.int64, True),
        ([-1], torch.int64, True),
    ])
    def forward(self, x, index1, index2):
        return torch.ops.aten.index(x, (
            None,
            index1,
            index2,
        ))


@register_test_case(
    module_factory=lambda: IndexTensorMultiInputContiguousOneDimDynamic())
def IndexTensorMultiInputContiguousOneDimDynamic_basic(module, tu: TestUtils):
    module.forward(tu.rand(5, 4, 3), torch.randint(4, (6, 1)),
                   torch.randint(3, (3, )))


# ==============================================================================


class IndexTensorMultiInputNonContiguousOneDimDynamic(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
        ([-1, 1], torch.int64, True),
        ([-1], torch.int64, True),
    ])
    def forward(self, x, index1, index2):
        return torch.ops.aten.index(x, (
            index1,
            None,
            index2,
        ))


@register_test_case(
    module_factory=lambda: IndexTensorMultiInputNonContiguousOneDimDynamic())
def IndexTensorMultiInputNonContiguousOneDimDynamic_basic(
        module, tu: TestUtils):
    module.forward(tu.rand(5, 4, 3), torch.randint(4, (6, 1)),
                   torch.randint(3, (3, )))


# ==============================================================================


class IndexTensorMultiInputNonContiguousDynamic(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
        ([-1, 2], torch.int64, True),
        ([-1], torch.int64, True),
    ])
    def forward(self, x, index1, index2):
        return torch.ops.aten.index(x, (
            index2,
            None,
            index1,
        ))


@register_test_case(
    module_factory=lambda: IndexTensorMultiInputNonContiguousDynamic())
def IndexTensorMultiInputNonContiguousDynamic_basic(module, tu: TestUtils):
    module.forward(tu.rand(5, 4, 3), torch.randint(2, (6, 2)),
                   torch.randint(3, (2, )))


# ==============================================================================


class IndexTensorMultiInputNonContiguousMultipleStaticDims(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1], torch.float32, True),
        ([4, 1], torch.int64, True),
        ([1, 3], torch.int64, True),
        ([-1, 3], torch.int64, True),
    ])
    def forward(self, x, index1, index2, index3):
        return torch.ops.aten.index(x, (index1, index2, index3))


@register_test_case(module_factory=lambda:
                    IndexTensorMultiInputNonContiguousMultipleStaticDims())
def IndexTensorMultiInputNonContiguousMultipleStaticDims_basic(
        module, tu: TestUtils):
    module.forward(tu.rand(5, 4, 3, 2), torch.randint(3, (4, 1)),
                   torch.randint(1, (1, 3)), torch.randint(1, (4, 3)))


# ==============================================================================


class IndexTensorMultiInputNonContiguous(torch.nn.Module):

    def __init__(self):
@@ -2591,7 +2710,7 @@ class AtenEmbeddingBagSumExample(torch.nn.Module):
    @export
    @annotate_args([
        None,
        None,
        ([-1, -1], torch.float32, True),
        ([-1], torch.int64, True),
        ([-1], torch.int64, True),
@@ -2613,7 +2732,7 @@ class Aten_EmbeddingBagExample(torch.nn.Module):
    @export
    @annotate_args([
        None,
        None,
        ([-1, -1], torch.float32, True),
        ([-1], torch.int64, True),
        ([-1], torch.int64, True),
@@ -2645,4 +2764,4 @@ class AtenToDeviceModule(torch.nn.Module):
@register_test_case(module_factory=lambda: AtenToDeviceModule())
def AtenToDeviceModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(2, 4))
    module.forward(torch.randn(2, 4))
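The new IndexTensorMultiInput*Dynamic tests added above differ mainly in where the None slot sits; in eager PyTorch that placement decides whether the broadcasted index dims replace the indexed dims in place or move to the front of the result. A standalone illustration (not part of the patch), reusing the shapes from IndexTensorMultiInputContiguousOneDimDynamic and its non-contiguous sibling:

import torch

# Illustration only: the contiguous / non-contiguous split the new tests
# exercise, with the same index shapes as the tests above.
x = torch.rand(5, 4, 3)
index1 = torch.randint(4, (6, 1))
index2 = torch.randint(3, (3,))

# Contiguous: the two index tensors sit next to each other, so the
# broadcasted index shape (6, 3) replaces dims 1 and 2 in place.
contiguous = torch.ops.aten.index(x, (None, index1, index2))
assert contiguous.shape == (5, 6, 3)

# Non-contiguous: a full slice separates them, so the broadcasted index
# dims move to the front of the result.
non_contiguous = torch.ops.aten.index(x, (index1, None, index2))
assert non_contiguous.shape == (6, 3, 4)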


@@ -233,3 +233,38 @@ func.func @torch.aten.neg.f16(%arg0: !torch.vtensor<[?,?],f16>) -> !torch.vtensor<[?,?],f16> {
  %0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f16> -> !torch.vtensor<[?,?],f16>
  return %0 : !torch.vtensor<[?,?],f16>
}
// -----
// CHECK-LABEL: func.func @torch.aten.index.Tensor
// CHECK-SAME: (%[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?,1],si64>, %[[ARG2:.*]]: !torch.vtensor<[?],si64>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[T:.*]] = torch_c.to_builtin_tensor %[[INPUT]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INDICES:.*]] = torch.prim.ListConstruct %[[ARG1]], %[[NONE]], %[[ARG2]] : (!torch.vtensor<[?,1],si64>, !torch.none, !torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>>
// CHECK: %[[INDEX1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,1],si64> -> tensor<?x1xi64>
// CHECK: %[[INDEX2:.*]] = torch_c.to_builtin_tensor %[[ARG2]] : !torch.vtensor<[?],si64> -> tensor<?xi64>
// CHECK: %[[CST0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM0:.*]] = tensor.dim %[[INDEX1]], %[[CST0]] : tensor<?x1xi64>
// CHECK: %[[CST0_0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM1:.*]] = tensor.dim %[[INDEX2]], %[[CST0_0]] : tensor<?xi64>
// CHECK: %[[CST1:.*]] = arith.constant 1 : index
// CHECK: %[[DIM2:.*]] = tensor.dim %[[T]], %[[CST1]] : tensor<?x?x?xf32>
// CHECK: %[[OUT_T:.*]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]]] : tensor<?x?x?xf32>
// CHECK: %[[OUT:.*]] = linalg.generic {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[INDEX1]], %[[INDEX2]] : tensor<?x1xi64>, tensor<?xi64>) outs(%[[OUT_T]] : tensor<?x?x?xf32>) {
// CHECK: ^bb0(%[[IN1:.*]]: i64, %[[IN2:.*]]: i64, %[[IN3:.*]]: f32):
// CHECK: %[[INDEX_1:.*]] = arith.index_cast %[[IN1]] : i64 to index
// CHECK: %[[INDEX_2:.*]] = linalg.index 2 : index
// CHECK: %[[INDEX_3:.*]] = arith.index_cast %[[IN2]] : i64 to index
// CHECK: %[[RESULT:.*]] = tensor.extract %[[T]][%[[INDEX_1]], %[[INDEX_2]], %[[INDEX_3]]] : tensor<?x?x?xf32>
// CHECK: linalg.yield %[[RESULT]] : f32
// CHECK: } -> tensor<?x?x?xf32>
// CHECK: %[[OUT_CAST:.*]] = tensor.cast %[[OUT]] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
// CHECK: %[[VALUE_OUT_CAST:.*]] = torch_c.from_builtin_tensor %[[OUT_CAST]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[VALUE_OUT_CAST]] : !torch.vtensor<[?,?,?],f32>
func.func @torch.aten.index.Tensor(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[?,1],si64>, %arg2: !torch.vtensor<[?],si64>) -> !torch.vtensor<[?,?,?],f32> {
  %none = torch.constant.none
  %1 = torch.prim.ListConstruct %arg1, %none, %arg2 : (!torch.vtensor<[?,1],si64>, !torch.none, !torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>>
  %2 = torch.aten.index.Tensor %arg0, %1 : !torch.vtensor<[?,?,?],f32>, !torch.list<optional<vtensor>> -> !torch.vtensor<[?,?,?],f32>
  return %2 : !torch.vtensor<[?,?,?],f32>
}
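The linalg.generic body checked above encodes the gather out[i, j, k] = input[index1[i, 0], k, index2[j]], with index1 broadcast along j and index2 along i. A quick eager-mode sanity check of that formula, as an illustration only (not part of the lit test):

import torch

# Sanity check of the gather the linalg.generic above performs:
# out[i, j, k] = x[index1[i, 0], k, index2[j]].
x = torch.rand(5, 4, 3)
index1 = torch.randint(5, (6, 1))
index2 = torch.randint(3, (7,))

out = torch.ops.aten.index(x, (index1, None, index2))
assert out.shape == (6, 7, 4)
for i in range(6):
    for j in range(7):
        for k in range(4):
            assert out[i, j, k] == x[index1[i, 0], k, index2[j]]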