From fa39d91357e8ebbf375f211274dfb4bbbbbc5ccf Mon Sep 17 00:00:00 2001
From: Vimal <111337181+patel-vimal@users.noreply.github.com>
Date: Mon, 26 Aug 2024 22:01:17 +0530
Subject: [PATCH] [FxImporter] Fix sympy_int_to_int utility (#3657)

A new sympy type was introduced in the upstream PyTorch repo to
represent integer infinity. As a result, sympy.oo is no longer used as
the upper bound of dynamic dimensions whose upper bound is unknown;
`int_oo` is used to represent integer infinity instead. This commit
updates the `_sympy_int_to_int` utility in light of this change.
---
 python/TorchMLIRModule.cpp              |  4 +++
 python/torch_mlir/extras/fx_importer.py | 34 ++++++++++++++++++++-----
 test/python/fx_importer/basic_test.py   | 16 +++++++-----
 3 files changed, 41 insertions(+), 13 deletions(-)

diff --git a/python/TorchMLIRModule.cpp b/python/TorchMLIRModule.cpp
index 73abf5cd5..36e391867 100644
--- a/python/TorchMLIRModule.cpp
+++ b/python/TorchMLIRModule.cpp
@@ -28,4 +28,8 @@ PYBIND11_MODULE(_torchMlir, m) {
         }
       },
       py::arg("context"), py::arg("load") = true);
+
+  m.def("get_int64_max", []() { return INT64_MAX; });
+
+  m.def("get_int64_min", []() { return INT64_MIN; });
 }
diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py
index 6f936e50e..c498e0437 100644
--- a/python/torch_mlir/extras/fx_importer.py
+++ b/python/torch_mlir/extras/fx_importer.py
@@ -78,6 +78,16 @@ except ModuleNotFoundError:
     # conditional.
     ml_dtypes = None
 
+try:
+    from torch.utils._sympy.numbers import int_oo, IntInfinity, NegativeIntInfinity
+except ModuleNotFoundError:
+    # This commit in the PyTorch repo introduced IntInfinity and NegativeIntInfinity:
+    # https://github.com/pytorch/pytorch/commit/2229884102ac95c9dda0aeadbded1b04295d892e
+    # The required module may not be present in the stable version of PyTorch.
+    int_oo = None
+    IntInfinity = None
+    NegativeIntInfinity = None
+
 from torch.fx.node import (
     Argument as NodeArgument,
 )
@@ -125,6 +135,8 @@ from ..dialects import (
     func as func_dialect,
 )
 
+from .._mlir_libs._torchMlir import get_int64_max, get_int64_min
+
 __all__ = [
     "FxImporter",
 ]
@@ -1165,22 +1177,32 @@ class ContextCache:
         self, prog: torch.export.ExportedProgram
     ) -> Dict[str, RangeConstraint]:
 
+        # Recent PyTorch versions use `int_oo` to represent integer infinity;
+        # older PyTorch versions, such as the current stable release, may not
+        # define `int_oo` yet.
+        infs = (sympy.oo, int_oo) if int_oo is not None else (sympy.oo,)
+
         def _sympy_int_to_int(val: sympy.Expr, adjust_func: Callable):
             # Convert simple sympy Integers into concrete int
-            if val == sympy.oo:
-                return math.inf
-            if val == -sympy.oo:
-                return -math.inf
+            if val in infs:
+                return get_int64_max()
+            if val in tuple(-inf for inf in infs):
+                return get_int64_min()
             if isinstance(val, sympy.Integer):
                 return int(val)
             # TODO: Remove this adjustment when fractional ranges are removed
             return adjust_func(val)
 
         contains_symbolic_ints = False
+        sym_int_types = (
+            (sympy.Integer, IntInfinity, NegativeIntInfinity)
+            if IntInfinity is not None
+            else sympy.Integer
+        )
         for val in prog.range_constraints.values():
             if (
-                isinstance(val.lower, sympy.Integer)
-                and isinstance(val.upper, sympy.Integer)
+                isinstance(val.lower, sym_int_types)
+                and isinstance(val.upper, sym_int_types)
                 and not val.is_bool
             ):
                 contains_symbolic_ints = True
diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py
index 5c2ee65a3..be2235ec8 100644
--- a/test/python/fx_importer/basic_test.py
+++ b/test/python/fx_importer/basic_test.py
@@ -88,12 +88,13 @@ def test_import_frozen_exported_program_with_func_name():
 
 @run
 # CHECK-LABEL: test_import_frozen_exported_program_with_dynamic_shapes
-# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,4],f32>) -> !torch.vtensor<[?,4],f32>
+# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,5],f32>) -> !torch.vtensor<[?,?,5],f32>
 # CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
-# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 4)> : !torch.vtensor<[?,4],f32>
-# CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,4],f32> -> !torch.vtensor<[?,4],f32>
-# CHECK: torch.bind_symbolic_shape %[[TANH]], [%[[S0]]], affine_map<()[s0] -> (s0, 4)> : !torch.vtensor<[?,4],f32>
-# CHECK: return %[[TANH]] : !torch.vtensor<[?,4],f32>
+# CHECK: %[[S1:.*]] = torch.symbolic_int "s1" {min_val = 2, max_val = {{[0-9]+}}} : !torch.int
+# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 5)> : !torch.vtensor<[?,?,5],f32>
+# CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,?,5],f32> -> !torch.vtensor<[?,?,5],f32>
+# CHECK: torch.bind_symbolic_shape %[[TANH]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 5)> : !torch.vtensor<[?,?,5],f32>
+# CHECK: return %[[TANH]] : !torch.vtensor<[?,?,5],f32>
 def test_import_frozen_exported_program_with_dynamic_shapes():
     class Basic(nn.Module):
         def __init__(self):
@@ -103,10 +104,11 @@ def test_import_frozen_exported_program_with_dynamic_shapes():
             return torch.tanh(x)
 
     batch = Dim("batch", max=10)
-    dynamic_shapes = {"x": {0: batch}}
+    channel = Dim("channel", min=2)
+    dynamic_shapes = {"x": {0: batch, 1: channel}}
     m = fx.export_and_import(
         Basic(),
-        torch.randn(3, 4),
+        torch.randn(3, 4, 5),
         dynamic_shapes=dynamic_shapes,
         func_name="test_net",
         import_symbolic_shape_expressions=True,
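
Illustrative note (not part of the patch): the sketch below shows, in a
standalone form, the conversion behaviour the updated `_sympy_int_to_int`
aims for. It substitutes plain Python int64 limits for the
`get_int64_max`/`get_int64_min` bindings added above, and it assumes a
PyTorch build that provides `torch.utils._sympy.numbers.int_oo`; on older
stable releases the import falls back to `sympy.oo` only.

    import math
    import sympy

    try:
        from torch.utils._sympy.numbers import int_oo
    except ModuleNotFoundError:
        int_oo = None  # older PyTorch: only sympy.oo is available

    INT64_MAX = 2**63 - 1   # stand-in for get_int64_max()
    INT64_MIN = -(2**63)    # stand-in for get_int64_min()

    def sympy_int_to_int(val, adjust_func):
        # Treat both flavours of infinity (sympy.oo and, when present,
        # int_oo) as "unbounded" and clamp them to the int64 range.
        infs = (sympy.oo, int_oo) if int_oo is not None else (sympy.oo,)
        if val in infs:
            return INT64_MAX
        if val in tuple(-inf for inf in infs):
            return INT64_MIN
        if isinstance(val, sympy.Integer):
            return int(val)
        # Fractional bounds are rounded toward the valid range.
        return adjust_func(val)

    print(sympy_int_to_int(sympy.Integer(7), math.floor))  # 7
    print(sympy_int_to_int(sympy.oo, math.floor))          # 9223372036854775807
    if int_oo is not None:
        print(sympy_int_to_int(int_oo, math.floor))        # 9223372036854775807
        print(sympy_int_to_int(-int_oo, math.ceil))        # -9223372036854775808

Checking membership against both representations of infinity lets the
importer emit finite int64 bounds for unbounded dynamic dimensions on both
old and new PyTorch versions, instead of the `math.inf` values the previous
code returned.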