mirror of https://github.com/llvm/torch-mlir
[FxImporter] Fix sympy_int_to_int utility (#3657)
New sympy type is introduced to represent integer infinity in upstream PyTorch repo. Subsequently, sympy.oo is no longer used to represent infinity upper bound for dynamic dimensions where the upper bound is unknown. Instead `int_oo` is used to represent integer infinity. This commit updates the `_sympy_int_to_int` utility in light of this change.pull/3637/head
parent
f9766c89f6
commit
fa39d91357
|
@ -28,4 +28,8 @@ PYBIND11_MODULE(_torchMlir, m) {
|
|||
}
|
||||
},
|
||||
py::arg("context"), py::arg("load") = true);
|
||||
|
||||
m.def("get_int64_max", []() { return INT64_MAX; });
|
||||
|
||||
m.def("get_int64_min", []() { return INT64_MIN; });
|
||||
}
|
||||
|
|
|
@ -78,6 +78,16 @@ except ModuleNotFoundError:
|
|||
# conditional.
|
||||
ml_dtypes = None
|
||||
|
||||
try:
|
||||
from torch.utils._sympy.numbers import int_oo, IntInfinity, NegativeIntInfinity
|
||||
except ModuleNotFoundError:
|
||||
# This commit on PyTorch repo introduced IntInfinity and NegativeIntInfinity:
|
||||
# https://github.com/pytorch/pytorch/commit/2229884102ac95c9dda0aeadbded1b04295d892e
|
||||
# Required module may not be present in the stable version of PyTorch.
|
||||
int_oo = None
|
||||
IntInfinity = None
|
||||
NegativeIntInfinity = None
|
||||
|
||||
from torch.fx.node import (
|
||||
Argument as NodeArgument,
|
||||
)
|
||||
|
@ -125,6 +135,8 @@ from ..dialects import (
|
|||
func as func_dialect,
|
||||
)
|
||||
|
||||
from .._mlir_libs._torchMlir import get_int64_max, get_int64_min
|
||||
|
||||
__all__ = [
|
||||
"FxImporter",
|
||||
]
|
||||
|
@ -1165,22 +1177,32 @@ class ContextCache:
|
|||
self, prog: torch.export.ExportedProgram
|
||||
) -> Dict[str, RangeConstraint]:
|
||||
|
||||
# Recent PyTorch versions use `int_oo` to represent integer infinity.
|
||||
# Older PyTorch versions like PyTorch stable version may not have
|
||||
# `int_oo` defined just yet.
|
||||
infs = (sympy.oo, int_oo) if int_oo is not None else (sympy.oo,)
|
||||
|
||||
def _sympy_int_to_int(val: sympy.Expr, adjust_func: Callable):
|
||||
# Convert simple sympy Integers into concrete int
|
||||
if val == sympy.oo:
|
||||
return math.inf
|
||||
if val == -sympy.oo:
|
||||
return -math.inf
|
||||
if val in infs:
|
||||
return get_int64_max()
|
||||
if val in tuple(-inf for inf in infs):
|
||||
return get_int64_min()
|
||||
if isinstance(val, sympy.Integer):
|
||||
return int(val)
|
||||
# TODO: Remove this adjustment when fractional ranges are removed
|
||||
return adjust_func(val)
|
||||
|
||||
contains_symbolic_ints = False
|
||||
sym_int_types = (
|
||||
(sympy.Integer, IntInfinity, NegativeIntInfinity)
|
||||
if IntInfinity is not None
|
||||
else sympy.Integer
|
||||
)
|
||||
for val in prog.range_constraints.values():
|
||||
if (
|
||||
isinstance(val.lower, sympy.Integer)
|
||||
and isinstance(val.upper, sympy.Integer)
|
||||
isinstance(val.lower, sym_int_types)
|
||||
and isinstance(val.upper, sym_int_types)
|
||||
and not val.is_bool
|
||||
):
|
||||
contains_symbolic_ints = True
|
||||
|
|
|
@ -88,12 +88,13 @@ def test_import_frozen_exported_program_with_func_name():
|
|||
|
||||
@run
|
||||
# CHECK-LABEL: test_import_frozen_exported_program_with_dynamic_shapes
|
||||
# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,4],f32>) -> !torch.vtensor<[?,4],f32>
|
||||
# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,5],f32>) -> !torch.vtensor<[?,?,5],f32>
|
||||
# CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
|
||||
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 4)> : !torch.vtensor<[?,4],f32>
|
||||
# CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,4],f32> -> !torch.vtensor<[?,4],f32>
|
||||
# CHECK: torch.bind_symbolic_shape %[[TANH]], [%[[S0]]], affine_map<()[s0] -> (s0, 4)> : !torch.vtensor<[?,4],f32>
|
||||
# CHECK: return %[[TANH]] : !torch.vtensor<[?,4],f32>
|
||||
# CHECK: %[[S1:.*]] = torch.symbolic_int "s1" {min_val = 2, max_val = {{[0-9]+}}} : !torch.int
|
||||
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 5)> : !torch.vtensor<[?,?,5],f32>
|
||||
# CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,?,5],f32> -> !torch.vtensor<[?,?,5],f32>
|
||||
# CHECK: torch.bind_symbolic_shape %[[TANH]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 5)> : !torch.vtensor<[?,?,5],f32>
|
||||
# CHECK: return %[[TANH]] : !torch.vtensor<[?,?,5],f32>
|
||||
def test_import_frozen_exported_program_with_dynamic_shapes():
|
||||
class Basic(nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -103,10 +104,11 @@ def test_import_frozen_exported_program_with_dynamic_shapes():
|
|||
return torch.tanh(x)
|
||||
|
||||
batch = Dim("batch", max=10)
|
||||
dynamic_shapes = {"x": {0: batch}}
|
||||
channel = Dim("channel", min=2)
|
||||
dynamic_shapes = {"x": {0: batch, 1: channel}}
|
||||
m = fx.export_and_import(
|
||||
Basic(),
|
||||
torch.randn(3, 4),
|
||||
torch.randn(3, 4, 5),
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
func_name="test_net",
|
||||
import_symbolic_shape_expressions=True,
|
||||
|
|
Loading…
Reference in New Issue