diff --git a/python/TorchMLIRModule.cpp b/python/TorchMLIRModule.cpp index 73abf5cd5..36e391867 100644 --- a/python/TorchMLIRModule.cpp +++ b/python/TorchMLIRModule.cpp @@ -28,4 +28,8 @@ PYBIND11_MODULE(_torchMlir, m) { } }, py::arg("context"), py::arg("load") = true); + + m.def("get_int64_max", []() { return INT64_MAX; }); + + m.def("get_int64_min", []() { return INT64_MIN; }); } diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 6f936e50e..c498e0437 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -78,6 +78,16 @@ except ModuleNotFoundError: # conditional. ml_dtypes = None +try: + from torch.utils._sympy.numbers import int_oo, IntInfinity, NegativeIntInfinity +except ModuleNotFoundError: + # This commit on PyTorch repo introduced IntInfinity and NegativeIntInfinity: + # https://github.com/pytorch/pytorch/commit/2229884102ac95c9dda0aeadbded1b04295d892e + # Required module may not be present in the stable version of PyTorch. + int_oo = None + IntInfinity = None + NegativeIntInfinity = None + from torch.fx.node import ( Argument as NodeArgument, ) @@ -125,6 +135,8 @@ from ..dialects import ( func as func_dialect, ) +from .._mlir_libs._torchMlir import get_int64_max, get_int64_min + __all__ = [ "FxImporter", ] @@ -1165,22 +1177,32 @@ class ContextCache: self, prog: torch.export.ExportedProgram ) -> Dict[str, RangeConstraint]: + # Recent PyTorch versions use `int_oo` to represent integer infinity. + # Older PyTorch versions like PyTorch stable version may not have + # `int_oo` defined just yet. + infs = (sympy.oo, int_oo) if int_oo is not None else (sympy.oo,) + def _sympy_int_to_int(val: sympy.Expr, adjust_func: Callable): # Convert simple sympy Integers into concrete int - if val == sympy.oo: - return math.inf - if val == -sympy.oo: - return -math.inf + if val in infs: + return get_int64_max() + if val in tuple(-inf for inf in infs): + return get_int64_min() if isinstance(val, sympy.Integer): return int(val) # TODO: Remove this adjustment when fractional ranges are removed return adjust_func(val) contains_symbolic_ints = False + sym_int_types = ( + (sympy.Integer, IntInfinity, NegativeIntInfinity) + if IntInfinity is not None + else sympy.Integer + ) for val in prog.range_constraints.values(): if ( - isinstance(val.lower, sympy.Integer) - and isinstance(val.upper, sympy.Integer) + isinstance(val.lower, sym_int_types) + and isinstance(val.upper, sym_int_types) and not val.is_bool ): contains_symbolic_ints = True diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py index 5c2ee65a3..be2235ec8 100644 --- a/test/python/fx_importer/basic_test.py +++ b/test/python/fx_importer/basic_test.py @@ -88,12 +88,13 @@ def test_import_frozen_exported_program_with_func_name(): @run # CHECK-LABEL: test_import_frozen_exported_program_with_dynamic_shapes -# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,4],f32>) -> !torch.vtensor<[?,4],f32> +# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,5],f32>) -> !torch.vtensor<[?,?,5],f32> # CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int -# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 4)> : !torch.vtensor<[?,4],f32> -# CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,4],f32> -> !torch.vtensor<[?,4],f32> -# CHECK: torch.bind_symbolic_shape %[[TANH]], [%[[S0]]], affine_map<()[s0] -> (s0, 4)> : !torch.vtensor<[?,4],f32> -# CHECK: return %[[TANH]] : !torch.vtensor<[?,4],f32> +# CHECK: %[[S1:.*]] = torch.symbolic_int "s1" {min_val = 2, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 5)> : !torch.vtensor<[?,?,5],f32> +# CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,?,5],f32> -> !torch.vtensor<[?,?,5],f32> +# CHECK: torch.bind_symbolic_shape %[[TANH]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 5)> : !torch.vtensor<[?,?,5],f32> +# CHECK: return %[[TANH]] : !torch.vtensor<[?,?,5],f32> def test_import_frozen_exported_program_with_dynamic_shapes(): class Basic(nn.Module): def __init__(self): @@ -103,10 +104,11 @@ def test_import_frozen_exported_program_with_dynamic_shapes(): return torch.tanh(x) batch = Dim("batch", max=10) - dynamic_shapes = {"x": {0: batch}} + channel = Dim("channel", min=2) + dynamic_shapes = {"x": {0: batch, 1: channel}} m = fx.export_and_import( Basic(), - torch.randn(3, 4), + torch.randn(3, 4, 5), dynamic_shapes=dynamic_shapes, func_name="test_net", import_symbolic_shape_expressions=True,