[fx_importer] Convert non-persistent buffers lifted as tensor constants (#2902)

The investigation is largely recorded in
https://github.com/llvm/torch-mlir/pull/2881, but this change allows us
to capture non-persistent buffers that were lifted as tensor constants
(after https://github.com/pytorch/pytorch/pull/118969 landed in upstream
PyTorch), and propagate them to `Torch` dialect as "frozen"
`torch.vtensor.literal`. I believe this patch should work with both
nightly and stable PyTorch, but will let CI confirm the same. Thanks
@stellaraccident for the valuable pointers and guidance.

---------

Co-authored-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
pull/2906/head
Sambhav Jain 2024-02-13 12:38:32 -08:00 committed by GitHub
parent 9b967f6b5a
commit 3e836d8dad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 23 additions and 10 deletions

View File

@ -373,13 +373,26 @@ class FxImporter:
sig = prog.graph_signature
state_dict = prog.state_dict
arg_replacements: dict[str, Any] = {}
# Lift buffers.
for input_name, state_name in sig.inputs_to_buffers.items():
try:
state_value = state_dict[state_name]
except KeyError as e:
raise AssertionError("Could not find state mapping for buffer") from e
arg_replacements[input_name] = state_value
# If there is no "constants" attribute, consult the "state_dict". Otherwise, only look
# at "constants". Relevant upstream patch: https://github.com/pytorch/pytorch/pull/118969
if hasattr(prog, "constants"):
constants = prog.constants
# Lift tensor constants.
for input_name, state_name in sig.inputs_to_lifted_tensor_constants.items():
try:
state_value = constants[state_name]
except KeyError as e:
raise AssertionError("Could not find state mapping for tensor constants") from e
arg_replacements[input_name] = state_value
else:
# Lift buffers.
for input_name, state_name in sig.inputs_to_buffers.items():
try:
state_value = state_dict[state_name]
except KeyError as e:
raise AssertionError("Could not find state mapping for buffer") from e
arg_replacements[input_name] = state_value
# Lift parameters.
for input_name, state_name in sig.inputs_to_parameters.items():

View File

@ -1 +1 @@
72fcb9ad662bb941a266e3d747835382634c2be6
3cbc8e89fd09b0ffb4914187b438f15c121e2302

View File

@ -1,3 +1,3 @@
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
--pre
torch==2.3.0.dev20240122
torch==2.3.0.dev20240207

View File

@ -1,3 +1,3 @@
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
--pre
torchvision==0.18.0.dev20240122
torchvision==0.18.0.dev20240207