mirror of https://github.com/llvm/torch-mlir
[Fx Importer] fix mutation importer with non persistent buffer (#3798)
A non-persistent buffer will not be a part of this module’s `state_dict`. Hence when setting `experimental_support_mutation=True` and have non-persistent buffer, the current fx importer will fail to retrieve a value from `state_dict` and produce `torch.constant.none` to represent the buffer. This fix get value of non-persistent buffer from the module's `constants`. --------- Co-authored-by: Dixin Zhou <dzhou@vdi-ahddp-020.dhcp.mathworks.com>pull/3820/head
parent
9ce2a69703
commit
8f52f5a4ed
|
@ -723,10 +723,17 @@ class FxImporter:
|
||||||
# on a symbolic or other non-SSA association. As such, they
|
# on a symbolic or other non-SSA association. As such, they
|
||||||
# are not modeled with mutable IR but will trigger an output
|
# are not modeled with mutable IR but will trigger an output
|
||||||
# store hook when the final value is produced.
|
# store hook when the final value is produced.
|
||||||
|
if input_spec.persistent:
|
||||||
value = prog.state_dict.get(input_spec.target)
|
value = prog.state_dict.get(input_spec.target)
|
||||||
assert (
|
assert (
|
||||||
not input_spec.persistent or value is not None
|
value is not None
|
||||||
), "Expected state_dict value for persistent value"
|
), "Expected state_dict value for persistent buffer"
|
||||||
|
else:
|
||||||
|
value = prog.constants.get(input_spec.target)
|
||||||
|
assert (
|
||||||
|
value is not None
|
||||||
|
), "Expected constants value for non-persistent buffer"
|
||||||
|
|
||||||
node = placeholder_nodes[arg.name]
|
node = placeholder_nodes[arg.name]
|
||||||
mutable_producer_node_name = mutable_buffer_target_producers.get(
|
mutable_producer_node_name = mutable_buffer_target_producers.get(
|
||||||
input_spec.target
|
input_spec.target
|
||||||
|
|
|
@ -107,6 +107,27 @@ def test_frozen_buffer():
|
||||||
m.operation.verify()
|
m.operation.verify()
|
||||||
|
|
||||||
|
|
||||||
|
@run
|
||||||
|
# CHECK-LABEL: test_frozen_buffer_non_persistent
|
||||||
|
# CHECK: %[[buffer_literal:.+]] = torch.vtensor.literal
|
||||||
|
# CHECK: %[[mul:.+]] = torch.aten.mul.Tensor %arg0, %0
|
||||||
|
# CHECK: return %[[mul]]
|
||||||
|
def test_frozen_buffer_non_persistent():
|
||||||
|
class Basic(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.register_buffer("buffer", torch.randn(3, 4), persistent=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x * self.buffer
|
||||||
|
|
||||||
|
m = fx.export_and_import(
|
||||||
|
Basic(), torch.randn(3, 4), experimental_support_mutation=True
|
||||||
|
)
|
||||||
|
print(m)
|
||||||
|
m.operation.verify()
|
||||||
|
|
||||||
|
|
||||||
class ExternalBufferHooks(fx.FxImporterHooks):
|
class ExternalBufferHooks(fx.FxImporterHooks):
|
||||||
def prepare_module(self, module_op: Operation):
|
def prepare_module(self, module_op: Operation):
|
||||||
module_op.context.allow_unregistered_dialects = True
|
module_op.context.allow_unregistered_dialects = True
|
||||||
|
|
Loading…
Reference in New Issue