diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 9981ed30e..f328bc5d0 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -1600,6 +1600,10 @@ class GraphNodeImporter: user_value = self.fx_importer._hooks.resolve_literal(self, py_value) if user_value is not None: assert isinstance(user_value, Value) + if orig_value is not None: + user_value = self._convert_type( + user_value, torch.Tensor, orig_value.dtype, orig_value.size() + ) return user_value # Default conversion path.