From b790061b69d05dab503ec90ad2b0ed333dd9b62f Mon Sep 17 00:00:00 2001 From: Christopher McGirr <7071833+chrsmcgrr@users.noreply.github.com> Date: Thu, 5 Sep 2024 18:53:11 +0200 Subject: [PATCH] [FxImporter] Add InputInfo to Resolve Literal Hook (#3688) --- python/torch_mlir/extras/fx_importer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index a8d2790e9..a8556c54d 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -470,7 +470,7 @@ class FxImporterHooks: ... def resolve_literal( - self, gni: "GraphNodeImporter", literal: Any + self, gni: "GraphNodeImporter", literal: Any, info: Optional[InputInfo] ) -> Optional[Value]: """User overridable hook to resolve a literal value.""" return None @@ -1826,13 +1826,13 @@ class GraphNodeImporter: name=op_name, results=[result_type], operands=operands ).result - def _import_literal(self, py_value: Any) -> Value: + def _import_literal(self, py_value: Any, info: Optional[InputInfo] = None) -> Value: orig_value = None if isinstance(py_value, torch.Tensor) and py_value.dtype == torch.bool: orig_value = py_value py_value = py_value.to(torch.uint8) # Apply the conversion callback. - user_value = self.fx_importer._hooks.resolve_literal(self, py_value) + user_value = self.fx_importer._hooks.resolve_literal(self, py_value, info) if user_value is not None: assert isinstance(user_value, Value) if orig_value is not None: @@ -1866,7 +1866,7 @@ class GraphNodeImporter: raise ValueError( f"Cannot import {info.input_spec} as a literal because it is mutable" ) - return self._import_literal(py_value) + return self._import_literal(py_value, info) def _import_scalar_as_tensor(self, loc: Location, arg: NodeArgument) -> Value: tensor_arg = torch.tensor(arg)