[FxImporter] Fix sympy_int_to_int utility (#3657)

A new sympy type was introduced in the upstream PyTorch repo to represent integer infinity. As a result, sympy.oo is no longer used as the upper bound for dynamic dimensions whose upper bound is unknown; `int_oo` is used to represent integer infinity instead. This commit updates the `_sympy_int_to_int` utility in light of this change.
Vimal 2024-08-26 22:01:17 +05:30 committed by GitHub
parent f9766c89f6
commit fa39d91357
3 changed files with 41 additions and 13 deletions
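
The practical effect of the upstream change: on a recent PyTorch build, an exported program with an unbounded dynamic dimension reports its upper range bound as `int_oo` rather than `sympy.oo`. A minimal sketch of where this shows up (assuming a PyTorch version that ships `torch.utils._sympy.numbers`; on stable builds the import fails and the new type is simply absent):

import torch
from torch.export import Dim, export

try:
    from torch.utils._sympy.numbers import int_oo
except ModuleNotFoundError:
    int_oo = None  # stable PyTorch: the new type is not available

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1

# One dynamic dimension with no explicit upper bound.
prog = export(M(), (torch.randn(4),), dynamic_shapes={"x": {0: Dim("n")}})

for sym, vr in prog.range_constraints.items():
    # On recent builds the unbounded upper bound is int_oo, not sympy.oo,
    # which is exactly what _sympy_int_to_int now has to recognize.
    print(sym, vr.lower, vr.upper, int_oo is not None and vr.upper == int_oo)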

@@ -28,4 +28,8 @@ PYBIND11_MODULE(_torchMlir, m) {
         }
       },
       py::arg("context"), py::arg("load") = true);
+  m.def("get_int64_max", []() { return INT64_MAX; });
+  m.def("get_int64_min", []() { return INT64_MIN; });
 }
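
The two new bindings just expose the C++ int64 limits so Python code can clamp infinite bounds to a representable integer range. A quick sanity check, assuming the extension is built and importable as `torch_mlir._mlir_libs._torchMlir` (the relative import added below resolves to that path):

from torch_mlir._mlir_libs._torchMlir import get_int64_max, get_int64_min

# INT64_MAX / INT64_MIN, i.e. 2**63 - 1 and -2**63.
assert get_int64_max() == 2**63 - 1
assert get_int64_min() == -(2**63)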

@@ -78,6 +78,16 @@ except ModuleNotFoundError:
     # conditional.
     ml_dtypes = None
 
+try:
+    from torch.utils._sympy.numbers import int_oo, IntInfinity, NegativeIntInfinity
+except ModuleNotFoundError:
+    # This commit on PyTorch repo introduced IntInfinity and NegativeIntInfinity:
+    # https://github.com/pytorch/pytorch/commit/2229884102ac95c9dda0aeadbded1b04295d892e
+    # Required module may not be present in the stable version of PyTorch.
+    int_oo = None
+    IntInfinity = None
+    NegativeIntInfinity = None
+
 from torch.fx.node import (
     Argument as NodeArgument,
 )
@@ -125,6 +135,8 @@ from ..dialects import (
     func as func_dialect,
 )
 
+from .._mlir_libs._torchMlir import get_int64_max, get_int64_min
+
 __all__ = [
     "FxImporter",
 ]
@@ -1165,22 +1177,32 @@ class ContextCache:
         self, prog: torch.export.ExportedProgram
     ) -> Dict[str, RangeConstraint]:
+        # Recent PyTorch versions use `int_oo` to represent integer infinity.
+        # Older PyTorch versions like PyTorch stable version may not have
+        # `int_oo` defined just yet.
+        infs = (sympy.oo, int_oo) if int_oo is not None else (sympy.oo,)
+
         def _sympy_int_to_int(val: sympy.Expr, adjust_func: Callable):
             # Convert simple sympy Integers into concrete int
-            if val == sympy.oo:
-                return math.inf
-            if val == -sympy.oo:
-                return -math.inf
+            if val in infs:
+                return get_int64_max()
+            if val in tuple(-inf for inf in infs):
+                return get_int64_min()
             if isinstance(val, sympy.Integer):
                 return int(val)
             # TODO: Remove this adjustment when fractional ranges are removed
             return adjust_func(val)
 
         contains_symbolic_ints = False
+        sym_int_types = (
+            (sympy.Integer, IntInfinity, NegativeIntInfinity)
+            if IntInfinity is not None
+            else sympy.Integer
+        )
         for val in prog.range_constraints.values():
             if (
-                isinstance(val.lower, sympy.Integer)
-                and isinstance(val.upper, sympy.Integer)
+                isinstance(val.lower, sym_int_types)
+                and isinstance(val.upper, sym_int_types)
                 and not val.is_bool
             ):
                 contains_symbolic_ints = True
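
Taken out of the importer, the updated conversion behaves roughly as follows. This is a standalone sketch: `INT64_MAX`/`INT64_MIN` stand in for the `get_int64_max()`/`get_int64_min()` bindings, and the `adjust_func` default is purely illustrative.

import math
import sympy

try:
    from torch.utils._sympy.numbers import int_oo
except ModuleNotFoundError:
    int_oo = None

INT64_MAX, INT64_MIN = 2**63 - 1, -(2**63)
infs = (sympy.oo, int_oo) if int_oo is not None else (sympy.oo,)

def sympy_int_to_int(val, adjust_func=math.floor):
    # Either spelling of an infinite bound clamps to the int64 range instead
    # of leaking a float math.inf into integer min_val/max_val attributes.
    if val in infs:
        return INT64_MAX
    if val in tuple(-inf for inf in infs):
        return INT64_MIN
    if isinstance(val, sympy.Integer):
        return int(val)
    return adjust_func(val)

assert sympy_int_to_int(sympy.Integer(7)) == 7
assert sympy_int_to_int(sympy.oo) == INT64_MAX
if int_oo is not None:
    assert sympy_int_to_int(int_oo) == INT64_MAX
    assert sympy_int_to_int(-int_oo) == INT64_MIN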

@@ -88,12 +88,13 @@ def test_import_frozen_exported_program_with_func_name():
 @run
 # CHECK-LABEL: test_import_frozen_exported_program_with_dynamic_shapes
-# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,4],f32>) -> !torch.vtensor<[?,4],f32>
+# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,5],f32>) -> !torch.vtensor<[?,?,5],f32>
 # CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
-# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 4)> : !torch.vtensor<[?,4],f32>
-# CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,4],f32> -> !torch.vtensor<[?,4],f32>
-# CHECK: torch.bind_symbolic_shape %[[TANH]], [%[[S0]]], affine_map<()[s0] -> (s0, 4)> : !torch.vtensor<[?,4],f32>
-# CHECK: return %[[TANH]] : !torch.vtensor<[?,4],f32>
+# CHECK: %[[S1:.*]] = torch.symbolic_int "s1" {min_val = 2, max_val = {{[0-9]+}}} : !torch.int
+# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 5)> : !torch.vtensor<[?,?,5],f32>
+# CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,?,5],f32> -> !torch.vtensor<[?,?,5],f32>
+# CHECK: torch.bind_symbolic_shape %[[TANH]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 5)> : !torch.vtensor<[?,?,5],f32>
+# CHECK: return %[[TANH]] : !torch.vtensor<[?,?,5],f32>
 def test_import_frozen_exported_program_with_dynamic_shapes():
     class Basic(nn.Module):
         def __init__(self):
@@ -103,10 +104,11 @@ def test_import_frozen_exported_program_with_dynamic_shapes():
             return torch.tanh(x)
 
     batch = Dim("batch", max=10)
-    dynamic_shapes = {"x": {0: batch}}
+    channel = Dim("channel", min=2)
+    dynamic_shapes = {"x": {0: batch, 1: channel}}
     m = fx.export_and_import(
         Basic(),
-        torch.randn(3, 4),
+        torch.randn(3, 4, 5),
         dynamic_shapes=dynamic_shapes,
         func_name="test_net",
         import_symbolic_shape_expressions=True,