mirror of https://github.com/llvm/torch-mlir
[fx_importer] Convert non-persistent buffers lifted as tensor constants (#2902)
The investigation is largely recorded in https://github.com/llvm/torch-mlir/pull/2881, but this change allows us to capture non-persistent buffers that were lifted as tensor constants (after https://github.com/pytorch/pytorch/pull/118969 landed in upstream PyTorch), and propagate them to `Torch` dialect as "frozen" `torch.vtensor.literal`. I believe this patch should work with both nightly and stable PyTorch, but will let CI confirm the same. Thanks @stellaraccident for the valuable pointers and guidance. --------- Co-authored-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>pull/2906/head
parent
9b967f6b5a
commit
3e836d8dad
|
@ -373,6 +373,19 @@ class FxImporter:
|
||||||
sig = prog.graph_signature
|
sig = prog.graph_signature
|
||||||
state_dict = prog.state_dict
|
state_dict = prog.state_dict
|
||||||
arg_replacements: dict[str, Any] = {}
|
arg_replacements: dict[str, Any] = {}
|
||||||
|
|
||||||
|
# If there is no "constants" attribute, consult the "state_dict". Otherwise, only look
|
||||||
|
# at "constants". Relevant upstream patch: https://github.com/pytorch/pytorch/pull/118969
|
||||||
|
if hasattr(prog, "constants"):
|
||||||
|
constants = prog.constants
|
||||||
|
# Lift tensor constants.
|
||||||
|
for input_name, state_name in sig.inputs_to_lifted_tensor_constants.items():
|
||||||
|
try:
|
||||||
|
state_value = constants[state_name]
|
||||||
|
except KeyError as e:
|
||||||
|
raise AssertionError("Could not find state mapping for tensor constants") from e
|
||||||
|
arg_replacements[input_name] = state_value
|
||||||
|
else:
|
||||||
# Lift buffers.
|
# Lift buffers.
|
||||||
for input_name, state_name in sig.inputs_to_buffers.items():
|
for input_name, state_name in sig.inputs_to_buffers.items():
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
72fcb9ad662bb941a266e3d747835382634c2be6
|
3cbc8e89fd09b0ffb4914187b438f15c121e2302
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
|
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
|
||||||
--pre
|
--pre
|
||||||
torch==2.3.0.dev20240122
|
torch==2.3.0.dev20240207
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
|
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
|
||||||
--pre
|
--pre
|
||||||
torchvision==0.18.0.dev20240122
|
torchvision==0.18.0.dev20240207
|
||||||
|
|
Loading…
Reference in New Issue