mirror of https://github.com/llvm/torch-mlir

[MLIR][TORCH] Fix dynamic cases for aten.index.Tensor

parent 1e1759c2eb
commit 65d811e267
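This change relaxes the TorchToLinalg lowering of aten.index.Tensor: previously it bailed out whenever more than one index tensor was present and the corresponding result dimensions were not statically sized. Dynamically sized index tensors are now accepted as long as no two index tensors contribute a dynamic size to the same broadcast dimension; when a static size greater than one must agree with a dynamic one, a runtime cf.assert is emitted instead of rejecting the op.

A minimal PyTorch sketch of the pattern that now lowers (shapes mirror the new e2e tests added below; the shape comment assumes standard advanced-indexing broadcast semantics):

import torch

x = torch.rand(5, 4, 3)
index1 = torch.randint(4, (6, 1))  # annotated as dynamic ([-1, 1]) in the e2e test
index2 = torch.randint(3, (3,))    # annotated as dynamic ([-1]) in the e2e test

# None leaves dim 0 untouched; index1 and index2 broadcast to (6, 3) across dims 1-2.
out = torch.ops.aten.index(x, (None, index1, index2))
print(out.shape)  # torch.Size([5, 6, 3])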
@@ -273,6 +273,10 @@ LTC_XFAIL_SET = {
    "IndexTensorMultiInputThreeIndexers_basic",
    "IndexTensorMultiInput_basic",
    "IndexTensorSelectDimModule_basic",
    "IndexTensorMultiInputContiguousOneDimDynamic_basic",
    "IndexTensorMultiInputNonContiguousOneDimDynamic_basic",
    "IndexTensorMultiInputNonContiguousDynamic_basic",
    "IndexTensorMultiInputNonContiguousMultipleStaticDims_basic",
    "Matmul_dot",
    "Matmul_matvec",
    "MulIntModule_basic",
@@ -583,10 +583,11 @@ public:
    int replacedIndexCount = indexTensorDims.size();
    int64_t startIndex = contiguous ? firstIndexDim : 0;

    // Currently we only support statically sized index tensors
    // when there is more than one index tensor.
    // TODO: Add support for dynamic size index tensors. This will probably
    // require broadcasting the index tensors to a common shape.
    // Currently we only support statically sized index tensors or dynamic size
    // index tensors without overlapping dynamic dims when there is more than
    // one index tensor.
    // TODO: Add support for dynamic size index tensors with overlapping
    // dynamic dims.
    SmallVector<Value> broadcastedIndexShape;
    if (indexTensors.size() > 1) {
      int maxRank = -1;
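The first comment block above is the old restriction (all index tensors must be statically shaped when more than one is present); the second is the relaxed rule introduced here: dynamic sizes are fine provided they do not overlap, i.e. at most one index tensor is dynamic in any given broadcast dimension. A hedged PyTorch-level illustration of the distinction (shapes are hypothetical, chosen only to show which annotations would overlap):

import torch

x = torch.rand(5, 4, 3)

# Non-overlapping: in each broadcast dimension at most one index tensor has a
# dynamic (-1) size, so the lowering can read that size with tensor.dim.
i1 = torch.randint(4, (6, 1))   # annotated [-1, 1]
i2 = torch.randint(3, (3,))     # annotated [-1]
torch.ops.aten.index(x, (None, i1, i2))   # supported after this patch

# Overlapping: both index tensors would be dynamic along the trailing broadcast
# dimension, so proving that they broadcast together needs a runtime check the
# lowering does not emit yet; this case still hits notifyMatchFailure.
j1 = torch.randint(4, (6, 3))   # imagine this annotated [-1, -1]
j2 = torch.randint(3, (3,))     # imagine this annotated [-1]
torch.ops.aten.index(x, (None, j1, j2))   # fine in eager mode, not lowered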
@@ -602,12 +603,39 @@ public:
      for (auto i : llvm::seq(startIndex, startIndex + maxRank)) {
        auto resultDimSize = refinedResultShape[i];
        if (ShapedType::isDynamic(resultDimSize)) {
          return rewriter.notifyMatchFailure(
              op, "unimplemented: index tensors must have static shape if "
                  "there is more than one index tensor");
          SmallVector<Value> dynamicDims;
          int64_t staticDimSize = -1;
          for (auto indexTensor : indexTensors) {
            RankedTensorType indexTensorType =
                indexTensor.getType().cast<RankedTensorType>();
            int64_t indexTensorRank = indexTensorType.getRank();
            if ((maxRank - indexTensorRank) > (i - startIndex))
              continue;
            int64_t dim = i - startIndex - maxRank + indexTensorRank;
            if (ShapedType::isDynamic(indexTensorType.getShape()[dim]))
              dynamicDims.push_back(getDimOp(rewriter, loc, indexTensor, dim));
            else
              staticDimSize =
                  std::max(staticDimSize, indexTensorType.getShape()[dim]);
          }
          if (dynamicDims.size() >= 2)
            return rewriter.notifyMatchFailure(
                op,
                "unimplemented: index tensors with overlapping dynamic dims");
          if (staticDimSize > 1) {
            Value cstStaticDimSize = getConstant(rewriter, loc, staticDimSize,
                                                 rewriter.getIndexType());
            auto equalToRunning = rewriter.create<arith::CmpIOp>(
                loc, arith::CmpIPredicate::eq, cstStaticDimSize,
                dynamicDims[0]);
            rewriter.create<cf::AssertOp>(loc, equalToRunning,
                                          "mismatched size for broadcast");
          }
          broadcastedIndexShape.push_back(dynamicDims[0]);
        } else {
          broadcastedIndexShape.push_back(getConstant(
              rewriter, loc, resultDimSize, rewriter.getIndexType()));
        }
        broadcastedIndexShape.push_back(
            getConstant(rewriter, loc, resultDimSize, rewriter.getIndexType()));
      }
    } else {
      // For a single indexing tensor we can simply use its (dynamic) sizes
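Reading the hunk header (12 lines become 39), the unconditional notifyMatchFailure for dynamic result dims and the old unconditional broadcastedIndexShape.push_back(getConstant(...)) are the lines being removed; the added code walks every index tensor that participates in output dimension i, collects at most one dynamic extent via getDimOp plus the largest static extent, rejects the op when two dynamic extents overlap, and emits a cf.assert when a static extent greater than one must match the dynamic one at runtime. A rough per-dimension sketch of that rule in Python (the function name and the -1 convention are illustrative, not from the patch):

def broadcast_index_dim(sizes):
    # sizes: one entry per index tensor that reaches this broadcast dimension;
    # -1 stands for a dynamic extent, mirroring ShapedType::isDynamic above.
    dynamic = [s for s in sizes if s == -1]
    static_max = max((s for s in sizes if s != -1), default=-1)
    if len(dynamic) >= 2:
        raise NotImplementedError("overlapping dynamic dims")
    if dynamic:
        # Use the single runtime size; if a static size > 1 is also present,
        # the lowering asserts at runtime that the two sizes agree.
        return "runtime size (asserted == %d)" % static_max if static_max > 1 else "runtime size"
    return static_max

print(broadcast_index_dim([1, 3]))    # 3
print(broadcast_index_dim([-1, 1]))   # runtime size
print(broadcast_index_dim([-1, 3]))   # runtime size (asserted == 3)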
@@ -1748,6 +1748,125 @@ def IndexTensorMultiInputOneDim_basic(module, tu: TestUtils):

# ==============================================================================


class IndexTensorMultiInputContiguousOneDimDynamic(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
        ([-1, 1], torch.int64, True),
        ([-1], torch.int64, True),
    ])
    def forward(self, x, index1, index2):
        return torch.ops.aten.index(x, (
            None,
            index1,
            index2,
        ))


@register_test_case(
    module_factory=lambda: IndexTensorMultiInputContiguousOneDimDynamic())
def IndexTensorMultiInputContiguousOneDimDynamic_basic(module, tu: TestUtils):
    module.forward(tu.rand(5, 4, 3), torch.randint(4, (6, 1)),
                   torch.randint(3, (3, )))


# ==============================================================================


class IndexTensorMultiInputNonContiguousOneDimDynamic(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
        ([-1, 1], torch.int64, True),
        ([-1], torch.int64, True),
    ])
    def forward(self, x, index1, index2):
        return torch.ops.aten.index(x, (
            index1,
            None,
            index2,
        ))


@register_test_case(
    module_factory=lambda: IndexTensorMultiInputNonContiguousOneDimDynamic())
def IndexTensorMultiInputNonContiguousOneDimDynamic_basic(
        module, tu: TestUtils):
    module.forward(tu.rand(5, 4, 3), torch.randint(4, (6, 1)),
                   torch.randint(3, (3, )))


# ==============================================================================


class IndexTensorMultiInputNonContiguousDynamic(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
        ([-1, 2], torch.int64, True),
        ([-1], torch.int64, True),
    ])
    def forward(self, x, index1, index2):
        return torch.ops.aten.index(x, (
            index2,
            None,
            index1,
        ))


@register_test_case(
    module_factory=lambda: IndexTensorMultiInputNonContiguousDynamic())
def IndexTensorMultiInputNonContiguousDynamic_basic(module, tu: TestUtils):
    module.forward(tu.rand(5, 4, 3), torch.randint(2, (6, 2)),
                   torch.randint(3, (2, )))


# ==============================================================================


class IndexTensorMultiInputNonContiguousMultipleStaticDims(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1], torch.float32, True),
        ([4, 1], torch.int64, True),
        ([1, 3], torch.int64, True),
        ([-1, 3], torch.int64, True),
    ])
    def forward(self, x, index1, index2, index3):
        return torch.ops.aten.index(x, (index1, index2, index3))


@register_test_case(module_factory=lambda:
                    IndexTensorMultiInputNonContiguousMultipleStaticDims())
def IndexTensorMultiInputNonContiguousMultipleStaticDims_basic(
        module, tu: TestUtils):
    module.forward(tu.rand(5, 4, 3, 2), torch.randint(3, (4, 1)),
                   torch.randint(1, (1, 3)), torch.randint(1, (4, 3)))
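
For reference, under standard advanced-indexing semantics the new tests exercise both placements of the broadcast index dimensions: the contiguous case yields shape (5, 6, 3), with the broadcast block replacing dims 1-2 in place, while the non-contiguous cases move the broadcast block to the front. A quick eager-mode sanity sketch with the same shapes as the test inputs (not part of the test suite):

import torch

x = torch.rand(5, 4, 3)

# Index tensors separated by None: broadcast dims (6, 3) move to the front.
assert torch.ops.aten.index(
    x, (torch.randint(4, (6, 1)), None, torch.randint(3, (3,)))).shape == (6, 3, 4)

# (2,) and (6, 2) broadcast to (6, 2); again non-contiguous, so it leads.
assert torch.ops.aten.index(
    x, (torch.randint(3, (2,)), None, torch.randint(2, (6, 2)))).shape == (6, 2, 4)

# All three leading dims indexed: (4, 1), (1, 3) and (4, 3) broadcast to (4, 3).
assert torch.ops.aten.index(
    torch.rand(5, 4, 3, 2),
    (torch.randint(3, (4, 1)), torch.randint(1, (1, 3)),
     torch.randint(1, (4, 3)))).shape == (4, 3, 2)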


# ==============================================================================


class IndexTensorMultiInputNonContiguous(torch.nn.Module):

    def __init__(self):
@@ -2591,7 +2710,7 @@ class AtenEmbeddingBagSumExample(torch.nn.Module):

    @export
    @annotate_args([
        None,
        None,
        ([-1, -1], torch.float32, True),
        ([-1], torch.int64, True),
        ([-1], torch.int64, True),
@@ -2613,7 +2732,7 @@ class Aten_EmbeddingBagExample(torch.nn.Module):

    @export
    @annotate_args([
        None,
        None,
        ([-1, -1], torch.float32, True),
        ([-1], torch.int64, True),
        ([-1], torch.int64, True),
@@ -2645,4 +2764,4 @@ class AtenToDeviceModule(torch.nn.Module):

@register_test_case(module_factory=lambda: AtenToDeviceModule())
def AtenToDeviceModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(2, 4))
    module.forward(torch.randn(2, 4))
@@ -233,3 +233,38 @@ func.func @torch.aten.neg.f16(%arg0: !torch.vtensor<[?,?],f16>) -> !torch.vtenso
  %0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f16> -> !torch.vtensor<[?,?],f16>
  return %0 : !torch.vtensor<[?,?],f16>
}

// -----

// CHECK-LABEL: func.func @torch.aten.index.Tensor
// CHECK-SAME: (%[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?,1],si64>, %[[ARG2:.*]]: !torch.vtensor<[?],si64>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[T:.*]] = torch_c.to_builtin_tensor %[[INPUT]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INDICES:.*]] = torch.prim.ListConstruct %[[ARG1]], %[[NONE]], %[[ARG2]] : (!torch.vtensor<[?,1],si64>, !torch.none, !torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>>
// CHECK: %[[INDEX1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,1],si64> -> tensor<?x1xi64>
// CHECK: %[[INDEX2:.*]] = torch_c.to_builtin_tensor %[[ARG2]] : !torch.vtensor<[?],si64> -> tensor<?xi64>
// CHECK: %[[CST0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM0:.*]] = tensor.dim %[[INDEX1]], %[[CST0]] : tensor<?x1xi64>
// CHECK: %[[CST0_0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM1:.*]] = tensor.dim %[[INDEX2]], %[[CST0_0]] : tensor<?xi64>
// CHECK: %[[CST1:.*]] = arith.constant 1 : index
// CHECK: %[[DIM2:.*]] = tensor.dim %[[T]], %[[CST1]] : tensor<?x?x?xf32>
// CHECK: %[[OUT_T:.*]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]]] : tensor<?x?x?xf32>
// CHECK: %[[OUT:.*]] = linalg.generic {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[INDEX1]], %[[INDEX2]] : tensor<?x1xi64>, tensor<?xi64>) outs(%[[OUT_T]] : tensor<?x?x?xf32>) {
// CHECK: ^bb0(%[[IN1:.*]]: i64, %[[IN2:.*]]: i64, %[[IN3:.*]]: f32):
// CHECK: %[[INDEX_1:.*]] = arith.index_cast %[[IN1]] : i64 to index
// CHECK: %[[INDEX_2:.*]] = linalg.index 2 : index
// CHECK: %[[INDEX_3:.*]] = arith.index_cast %[[IN2]] : i64 to index
// CHECK: %[[RESULT:.*]] = tensor.extract %[[T]][%[[INDEX_1]], %[[INDEX_2]], %[[INDEX_3]]] : tensor<?x?x?xf32>
// CHECK: linalg.yield %[[RESULT]] : f32
// CHECK: } -> tensor<?x?x?xf32>
// CHECK: %[[OUT_CAST:.*]] = tensor.cast %[[OUT]] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
// CHECK: %[[VALUE_OUT_CAST:.*]] = torch_c.from_builtin_tensor %[[OUT_CAST]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[VALUE_OUT_CAST]] : !torch.vtensor<[?,?,?],f32>
func.func @torch.aten.index.Tensor(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[?,1],si64>, %arg2: !torch.vtensor<[?],si64>) -> !torch.vtensor<[?,?,?],f32> {
  %none = torch.constant.none
  %1 = torch.prim.ListConstruct %arg1, %none, %arg2 : (!torch.vtensor<[?,1],si64>, !torch.none, !torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>>
  %2 = torch.aten.index.Tensor %arg0, %1 : !torch.vtensor<[?,?,?],f32>, !torch.list<optional<vtensor>> -> !torch.vtensor<[?,?,?],f32>
  return %2 : !torch.vtensor<[?,?,?],f32>
}
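
The linalg.generic above gathers one element of the input per output element: the two index tensors drive the leading output dims (the indices are non-contiguous because of the None between them, so the broadcast block lands in front), linalg.index 2 walks the untouched middle dimension of the input, and tensor.extract performs the gather. A plain-Python reference of what the CHECK lines compute, with illustrative shapes and variable names not taken from the test:

import torch

x = torch.rand(5, 4, 3)                # %[[INPUT]]
index1 = torch.randint(5, (6, 1))      # %[[ARG1]], indexes dim 0 of x
index2 = torch.randint(3, (7,))        # %[[ARG2]], indexes dim 2 of x

# linalg.init_tensor [dim(index1, 0), dim(index2, 0), dim(x, 1)]
out = torch.empty(index1.shape[0], index2.shape[0], x.shape[1])
for i in range(out.shape[0]):          # broadcast dim coming from index1
    for j in range(out.shape[1]):      # broadcast dim coming from index2
        for k in range(out.shape[2]):  # un-indexed dim of x (linalg.index 2)
            out[i, j, k] = x[index1[i, 0], k, index2[j]]

assert torch.equal(out, torch.ops.aten.index(x, (index1, None, index2)))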