mirror of https://github.com/llvm/torch-mlir
[Symbolic Shapes] Test coverage for unbacked symint from data dependent ops (#3542)
We do have support for translating unbacked symbolic_ints that arise from data-dependent ops like `aten.nonzero`. This PR adds the python lit test coverage for the same.pull/3543/head
parent
cdbcf519f7
commit
7411ff2f69
|
@ -84,7 +84,7 @@ def test_tanh_sigmoid_cat():
|
||||||
# CHECK-LABEL: test_symbolic_dim_differ_by_one
|
# CHECK-LABEL: test_symbolic_dim_differ_by_one
|
||||||
# CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>, %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> attributes {torch.assume_strict_symbolic_shapes} {
|
# CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>, %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> attributes {torch.assume_strict_symbolic_shapes} {
|
||||||
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int
|
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int
|
||||||
# This appears in torch-nightly, but not in torch-stable (re-enable once we've moved torch-stable to 2.4+)
|
# FIXME: This appears in torch-nightly, but not in torch-stable (re-enable once we've moved torch-stable to 2.4+)
|
||||||
# CHECK-DISABLED: %[[S1:.+]] = torch.symbolic_int "s0 + 1" {min_val = 4, max_val = 7} : !torch.int
|
# CHECK-DISABLED: %[[S1:.+]] = torch.symbolic_int "s0 + 1" {min_val = 4, max_val = 7} : !torch.int
|
||||||
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
|
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
|
||||||
# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0 + 1)> : !torch.vtensor<[?],f32>
|
# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0 + 1)> : !torch.vtensor<[?],f32>
|
||||||
|
@ -262,7 +262,7 @@ def test_div_tensor_mixed_ranks():
|
||||||
@run
|
@run
|
||||||
# CHECK-LABEL: test_shape_div
|
# CHECK-LABEL: test_shape_div
|
||||||
# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?,7],f32>) -> !torch.vtensor<[?,5],f32> {
|
# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?,7],f32>) -> !torch.vtensor<[?,5],f32> {
|
||||||
# This appears in torch-nightly, but not in torch-stable (re-enable once we've moved torch-stable to 2.4+)
|
# FIXME: This appears in torch-nightly, but not in torch-stable (re-enable once we've moved torch-stable to 2.4+)
|
||||||
# CHECK-DISABLED: %[[S0:.+]] = torch.symbolic_int "5*s1" {min_val = 0, max_val = 5000} : !torch.int
|
# CHECK-DISABLED: %[[S0:.+]] = torch.symbolic_int "5*s1" {min_val = 0, max_val = 5000} : !torch.int
|
||||||
# CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = 2, max_val = 1000} : !torch.int
|
# CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = 2, max_val = 1000} : !torch.int
|
||||||
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S1]]], affine_map<()[s0] -> (s0 * 5, 7)> : !torch.vtensor<[?,7],f32>
|
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S1]]], affine_map<()[s0] -> (s0 * 5, 7)> : !torch.vtensor<[?,7],f32>
|
||||||
|
@ -433,7 +433,7 @@ def test_broadcast_unit_dim_to_dynamic_with_rank_increase():
|
||||||
@run
|
@run
|
||||||
# CHECK-LABEL: test_gather_elements
|
# CHECK-LABEL: test_gather_elements
|
||||||
# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?,3],f32>, %[[ARG1:.+]]: !torch.vtensor<[2,3],si64>) -> !torch.vtensor<[2,3],f32> {
|
# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?,3],f32>, %[[ARG1:.+]]: !torch.vtensor<[2,3],si64>) -> !torch.vtensor<[2,3],f32> {
|
||||||
# CHECK: %[[S0]] = torch.symbolic_int "s0" {min_val = 3, max_val = 9223372036854775806} : !torch.int
|
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 3, max_val = 9223372036854775806} : !torch.int
|
||||||
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
|
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
|
||||||
# CHECK: %[[GATHER:.+]] = torch.aten.gather %[[ARG0]], {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[?,3],f32>, !torch.int, !torch.vtensor<[2,3],si64>, !torch.bool -> !torch.vtensor<[2,3],f32>
|
# CHECK: %[[GATHER:.+]] = torch.aten.gather %[[ARG0]], {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[?,3],f32>, !torch.int, !torch.vtensor<[2,3],si64>, !torch.bool -> !torch.vtensor<[2,3],f32>
|
||||||
# CHECK: return %[[GATHER]] : !torch.vtensor<[2,3],f32>
|
# CHECK: return %[[GATHER]] : !torch.vtensor<[2,3],f32>
|
||||||
|
@ -461,3 +461,38 @@ def test_gather_elements():
|
||||||
import_symbolic_shape_expressions=True,
|
import_symbolic_shape_expressions=True,
|
||||||
)
|
)
|
||||||
print(m)
|
print(m)
|
||||||
|
|
||||||
|
|
||||||
|
@run
|
||||||
|
# CHECK-LABEL: test_nonzero
|
||||||
|
# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?,3],f32>) -> !torch.vtensor<[?,2],si64> {
|
||||||
|
# FIXME: There's a bug in the torch 2.3 stable release which creates redundant symbolic_int ops for the nonzero
|
||||||
|
# output which is fixed in the 2.4 nightlies. Once we move to a 2.4 stable release, this check may be re-enabled
|
||||||
|
# CHECK-DISABLED: %[[U0:.+]] = torch.symbolic_int "u0" {min_val = 0, max_val = 9223372036854775806} : !torch.int
|
||||||
|
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 3, max_val = 10} : !torch.int
|
||||||
|
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
|
||||||
|
# CHECK: %[[NZERO:.+]] = torch.aten.nonzero %[[ARG0]] : !torch.vtensor<[?,3],f32> -> !torch.vtensor<[?,2],si64>
|
||||||
|
# CHECK-DISABLED: torch.bind_symbolic_shape %[[NZERO]], [%[[U0]]], affine_map<()[s0] -> (s0, 2)> : !torch.vtensor<[?,2],si64>
|
||||||
|
# CHECK: return %[[NZERO]] : !torch.vtensor<[?,2],si64>
|
||||||
|
def test_nonzero():
|
||||||
|
class Nonzero(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return torch.nonzero(x)
|
||||||
|
|
||||||
|
# Sample inputs
|
||||||
|
x = torch.randn(4, 3)
|
||||||
|
|
||||||
|
# Dynamic dim constraints
|
||||||
|
batch = Dim("batch", min=3, max=10)
|
||||||
|
dynamic_shapes = {"x": {0: batch}}
|
||||||
|
|
||||||
|
m = fx.export_and_import(
|
||||||
|
Nonzero(),
|
||||||
|
x,
|
||||||
|
dynamic_shapes=dynamic_shapes,
|
||||||
|
import_symbolic_shape_expressions=True,
|
||||||
|
)
|
||||||
|
print(m)
|
||||||
|
|
Loading…
Reference in New Issue