[Symbolic Shapes] Test coverage for unbacked symint from data dependent ops (#3542)

We do have support for translating unbacked symbolic_ints that arise
from data-dependent ops like `aten.nonzero`. This PR adds the python lit
test coverage for the same.
pull/3543/head
Sambhav Jain 2024-07-14 11:52:03 -07:00 committed by GitHub
parent cdbcf519f7
commit 7411ff2f69
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 38 additions and 3 deletions

View File

@ -84,7 +84,7 @@ def test_tanh_sigmoid_cat():
# CHECK-LABEL: test_symbolic_dim_differ_by_one # CHECK-LABEL: test_symbolic_dim_differ_by_one
# CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>, %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> attributes {torch.assume_strict_symbolic_shapes} { # CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>, %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> attributes {torch.assume_strict_symbolic_shapes} {
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int # CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int
# This appears in torch-nightly, but not in torch-stable (re-enable once we've moved torch-stable to 2.4+) # FIXME: This appears in torch-nightly, but not in torch-stable (re-enable once we've moved torch-stable to 2.4+)
# CHECK-DISABLED: %[[S1:.+]] = torch.symbolic_int "s0 + 1" {min_val = 4, max_val = 7} : !torch.int # CHECK-DISABLED: %[[S1:.+]] = torch.symbolic_int "s0 + 1" {min_val = 4, max_val = 7} : !torch.int
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> # CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0 + 1)> : !torch.vtensor<[?],f32> # CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0 + 1)> : !torch.vtensor<[?],f32>
@ -262,7 +262,7 @@ def test_div_tensor_mixed_ranks():
@run @run
# CHECK-LABEL: test_shape_div # CHECK-LABEL: test_shape_div
# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?,7],f32>) -> !torch.vtensor<[?,5],f32> { # CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?,7],f32>) -> !torch.vtensor<[?,5],f32> {
# This appears in torch-nightly, but not in torch-stable (re-enable once we've moved torch-stable to 2.4+) # FIXME: This appears in torch-nightly, but not in torch-stable (re-enable once we've moved torch-stable to 2.4+)
# CHECK-DISABLED: %[[S0:.+]] = torch.symbolic_int "5*s1" {min_val = 0, max_val = 5000} : !torch.int # CHECK-DISABLED: %[[S0:.+]] = torch.symbolic_int "5*s1" {min_val = 0, max_val = 5000} : !torch.int
# CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = 2, max_val = 1000} : !torch.int # CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = 2, max_val = 1000} : !torch.int
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S1]]], affine_map<()[s0] -> (s0 * 5, 7)> : !torch.vtensor<[?,7],f32> # CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S1]]], affine_map<()[s0] -> (s0 * 5, 7)> : !torch.vtensor<[?,7],f32>
@ -433,7 +433,7 @@ def test_broadcast_unit_dim_to_dynamic_with_rank_increase():
@run @run
# CHECK-LABEL: test_gather_elements # CHECK-LABEL: test_gather_elements
# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?,3],f32>, %[[ARG1:.+]]: !torch.vtensor<[2,3],si64>) -> !torch.vtensor<[2,3],f32> { # CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?,3],f32>, %[[ARG1:.+]]: !torch.vtensor<[2,3],si64>) -> !torch.vtensor<[2,3],f32> {
# CHECK: %[[S0]] = torch.symbolic_int "s0" {min_val = 3, max_val = 9223372036854775806} : !torch.int # CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 3, max_val = 9223372036854775806} : !torch.int
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32> # CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
# CHECK: %[[GATHER:.+]] = torch.aten.gather %[[ARG0]], {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[?,3],f32>, !torch.int, !torch.vtensor<[2,3],si64>, !torch.bool -> !torch.vtensor<[2,3],f32> # CHECK: %[[GATHER:.+]] = torch.aten.gather %[[ARG0]], {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[?,3],f32>, !torch.int, !torch.vtensor<[2,3],si64>, !torch.bool -> !torch.vtensor<[2,3],f32>
# CHECK: return %[[GATHER]] : !torch.vtensor<[2,3],f32> # CHECK: return %[[GATHER]] : !torch.vtensor<[2,3],f32>
@ -461,3 +461,38 @@ def test_gather_elements():
import_symbolic_shape_expressions=True, import_symbolic_shape_expressions=True,
) )
print(m) print(m)
@run
# CHECK-LABEL: test_nonzero
# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?,3],f32>) -> !torch.vtensor<[?,2],si64> {
# FIXME: There's a bug in the torch 2.3 stable release which creates redundant symbolic_int ops for the nonzero
# output which is fixed in the 2.4 nightlies. Once we move to a 2.4 stable release, this check may be re-enabled
# CHECK-DISABLED: %[[U0:.+]] = torch.symbolic_int "u0" {min_val = 0, max_val = 9223372036854775806} : !torch.int
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 3, max_val = 10} : !torch.int
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
# CHECK: %[[NZERO:.+]] = torch.aten.nonzero %[[ARG0]] : !torch.vtensor<[?,3],f32> -> !torch.vtensor<[?,2],si64>
# CHECK-DISABLED: torch.bind_symbolic_shape %[[NZERO]], [%[[U0]]], affine_map<()[s0] -> (s0, 2)> : !torch.vtensor<[?,2],si64>
# CHECK: return %[[NZERO]] : !torch.vtensor<[?,2],si64>
def test_nonzero():
class Nonzero(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.nonzero(x)
# Sample inputs
x = torch.randn(4, 3)
# Dynamic dim constraints
batch = Dim("batch", min=3, max=10)
dynamic_shapes = {"x": {0: batch}}
m = fx.export_and_import(
Nonzero(),
x,
dynamic_shapes=dynamic_shapes,
import_symbolic_shape_expressions=True,
)
print(m)