Add conversion operation for bool resolved_literal (#3410)

Resolving `bool` literals can result in a type change to uint8. This
needs to be converted back to the expected type before returning to the
wrapped `torch` operators.
pull/3436/head
Rob Suderman 2024-06-03 14:43:38 -07:00 committed by GitHub
parent 11c3281a8a
commit 0a6861b1e8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 4 additions and 0 deletions

View File

@ -1600,6 +1600,10 @@ class GraphNodeImporter:
user_value = self.fx_importer._hooks.resolve_literal(self, py_value) user_value = self.fx_importer._hooks.resolve_literal(self, py_value)
if user_value is not None: if user_value is not None:
assert isinstance(user_value, Value) assert isinstance(user_value, Value)
if orig_value is not None:
user_value = self._convert_type(
user_value, torch.Tensor, orig_value.dtype, orig_value.size()
)
return user_value return user_value
# Default conversion path. # Default conversion path.