mirror of https://github.com/llvm/torch-mlir
[Fx Importer] fix mutation importer with non persistent buffer (#3798)
A non-persistent buffer will not be a part of this module’s `state_dict`. Hence when setting `experimental_support_mutation=True` and have non-persistent buffer, the current fx importer will fail to retrieve a value from `state_dict` and produce `torch.constant.none` to represent the buffer. This fix get value of non-persistent buffer from the module's `constants`. --------- Co-authored-by: Dixin Zhou <dzhou@vdi-ahddp-020.dhcp.mathworks.com>pull/3820/head
parent
9ce2a69703
commit
8f52f5a4ed
|
@ -723,10 +723,17 @@ class FxImporter:
|
|||
# on a symbolic or other non-SSA association. As such, they
|
||||
# are not modeled with mutable IR but will trigger an output
|
||||
# store hook when the final value is produced.
|
||||
if input_spec.persistent:
|
||||
value = prog.state_dict.get(input_spec.target)
|
||||
assert (
|
||||
not input_spec.persistent or value is not None
|
||||
), "Expected state_dict value for persistent value"
|
||||
value is not None
|
||||
), "Expected state_dict value for persistent buffer"
|
||||
else:
|
||||
value = prog.constants.get(input_spec.target)
|
||||
assert (
|
||||
value is not None
|
||||
), "Expected constants value for non-persistent buffer"
|
||||
|
||||
node = placeholder_nodes[arg.name]
|
||||
mutable_producer_node_name = mutable_buffer_target_producers.get(
|
||||
input_spec.target
|
||||
|
|
|
@ -107,6 +107,27 @@ def test_frozen_buffer():
|
|||
m.operation.verify()
|
||||
|
||||
|
||||
@run
|
||||
# CHECK-LABEL: test_frozen_buffer_non_persistent
|
||||
# CHECK: %[[buffer_literal:.+]] = torch.vtensor.literal
|
||||
# CHECK: %[[mul:.+]] = torch.aten.mul.Tensor %arg0, %0
|
||||
# CHECK: return %[[mul]]
|
||||
def test_frozen_buffer_non_persistent():
|
||||
class Basic(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.register_buffer("buffer", torch.randn(3, 4), persistent=False)
|
||||
|
||||
def forward(self, x):
|
||||
return x * self.buffer
|
||||
|
||||
m = fx.export_and_import(
|
||||
Basic(), torch.randn(3, 4), experimental_support_mutation=True
|
||||
)
|
||||
print(m)
|
||||
m.operation.verify()
|
||||
|
||||
|
||||
class ExternalBufferHooks(fx.FxImporterHooks):
|
||||
def prepare_module(self, module_op: Operation):
|
||||
module_op.context.allow_unregistered_dialects = True
|
||||
|
|
Loading…
Reference in New Issue