mirror of https://github.com/llvm/torch-mlir
[fx_importer] Convert non-persistent buffers lifted as tensor constants (#2902)
The investigation is largely recorded in https://github.com/llvm/torch-mlir/pull/2881, but this change allows us to capture non-persistent buffers that were lifted as tensor constants (after https://github.com/pytorch/pytorch/pull/118969 landed in upstream PyTorch), and propagate them to `Torch` dialect as "frozen" `torch.vtensor.literal`. I believe this patch should work with both nightly and stable PyTorch, but will let CI confirm the same. Thanks @stellaraccident for the valuable pointers and guidance. --------- Co-authored-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>pull/2906/head
parent
9b967f6b5a
commit
3e836d8dad
|
@ -373,6 +373,19 @@ class FxImporter:
|
|||
sig = prog.graph_signature
|
||||
state_dict = prog.state_dict
|
||||
arg_replacements: dict[str, Any] = {}
|
||||
|
||||
# If there is no "constants" attribute, consult the "state_dict". Otherwise, only look
|
||||
# at "constants". Relevant upstream patch: https://github.com/pytorch/pytorch/pull/118969
|
||||
if hasattr(prog, "constants"):
|
||||
constants = prog.constants
|
||||
# Lift tensor constants.
|
||||
for input_name, state_name in sig.inputs_to_lifted_tensor_constants.items():
|
||||
try:
|
||||
state_value = constants[state_name]
|
||||
except KeyError as e:
|
||||
raise AssertionError("Could not find state mapping for tensor constants") from e
|
||||
arg_replacements[input_name] = state_value
|
||||
else:
|
||||
# Lift buffers.
|
||||
for input_name, state_name in sig.inputs_to_buffers.items():
|
||||
try:
|
||||
|
|
|
@ -1 +1 @@
|
|||
72fcb9ad662bb941a266e3d747835382634c2be6
|
||||
3cbc8e89fd09b0ffb4914187b438f15c121e2302
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
|
||||
--pre
|
||||
torch==2.3.0.dev20240122
|
||||
torch==2.3.0.dev20240207
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
|
||||
--pre
|
||||
torchvision==0.18.0.dev20240122
|
||||
torchvision==0.18.0.dev20240207
|
||||
|
|
Loading…
Reference in New Issue