[Fx Importer] fix mutation importer with non persistent buffer (#3798)

A non-persistent buffer will not be a part of this module’s
`state_dict`. Hence when setting `experimental_support_mutation=True`
and have non-persistent buffer, the current fx importer will fail to
retrieve a value from `state_dict` and produce `torch.constant.none` to
represent the buffer. This fix get value of non-persistent buffer from
the module's `constants`.

---------

Co-authored-by: Dixin Zhou <dzhou@vdi-ahddp-020.dhcp.mathworks.com>
pull/3820/head
Dixin Zhou 2024-10-31 14:20:32 -04:00 committed by GitHub
parent 9ce2a69703
commit 8f52f5a4ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 32 additions and 4 deletions

View File

@ -723,10 +723,17 @@ class FxImporter:
# on a symbolic or other non-SSA association. As such, they # on a symbolic or other non-SSA association. As such, they
# are not modeled with mutable IR but will trigger an output # are not modeled with mutable IR but will trigger an output
# store hook when the final value is produced. # store hook when the final value is produced.
value = prog.state_dict.get(input_spec.target) if input_spec.persistent:
assert ( value = prog.state_dict.get(input_spec.target)
not input_spec.persistent or value is not None assert (
), "Expected state_dict value for persistent value" value is not None
), "Expected state_dict value for persistent buffer"
else:
value = prog.constants.get(input_spec.target)
assert (
value is not None
), "Expected constants value for non-persistent buffer"
node = placeholder_nodes[arg.name] node = placeholder_nodes[arg.name]
mutable_producer_node_name = mutable_buffer_target_producers.get( mutable_producer_node_name = mutable_buffer_target_producers.get(
input_spec.target input_spec.target

View File

@ -107,6 +107,27 @@ def test_frozen_buffer():
m.operation.verify() m.operation.verify()
@run
# CHECK-LABEL: test_frozen_buffer_non_persistent
# CHECK: %[[buffer_literal:.+]] = torch.vtensor.literal
# CHECK: %[[mul:.+]] = torch.aten.mul.Tensor %arg0, %0
# CHECK: return %[[mul]]
def test_frozen_buffer_non_persistent():
class Basic(nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer", torch.randn(3, 4), persistent=False)
def forward(self, x):
return x * self.buffer
m = fx.export_and_import(
Basic(), torch.randn(3, 4), experimental_support_mutation=True
)
print(m)
m.operation.verify()
class ExternalBufferHooks(fx.FxImporterHooks): class ExternalBufferHooks(fx.FxImporterHooks):
def prepare_module(self, module_op: Operation): def prepare_module(self, module_op: Operation):
module_op.context.allow_unregistered_dialects = True module_op.context.allow_unregistered_dialects = True