From 8f52f5a4ed6dda42005ccaaf404f031cc83df041 Mon Sep 17 00:00:00 2001 From: Dixin Zhou Date: Thu, 31 Oct 2024 14:20:32 -0400 Subject: [PATCH] [Fx Importer] fix mutation importer with non persistent buffer (#3798) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A non-persistent buffer will not be a part of this module’s `state_dict`. Hence when setting `experimental_support_mutation=True` and have non-persistent buffer, the current fx importer will fail to retrieve a value from `state_dict` and produce `torch.constant.none` to represent the buffer. This fix get value of non-persistent buffer from the module's `constants`. --------- Co-authored-by: Dixin Zhou --- python/torch_mlir/extras/fx_importer.py | 15 +++++++++---- .../fx_importer/v2.3/mutation_import.py | 21 +++++++++++++++++++ 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index a8556c54d..cfaa666fd 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -723,10 +723,17 @@ class FxImporter: # on a symbolic or other non-SSA association. As such, they # are not modeled with mutable IR but will trigger an output # store hook when the final value is produced. - value = prog.state_dict.get(input_spec.target) - assert ( - not input_spec.persistent or value is not None - ), "Expected state_dict value for persistent value" + if input_spec.persistent: + value = prog.state_dict.get(input_spec.target) + assert ( + value is not None + ), "Expected state_dict value for persistent buffer" + else: + value = prog.constants.get(input_spec.target) + assert ( + value is not None + ), "Expected constants value for non-persistent buffer" + node = placeholder_nodes[arg.name] mutable_producer_node_name = mutable_buffer_target_producers.get( input_spec.target diff --git a/test/python/fx_importer/v2.3/mutation_import.py b/test/python/fx_importer/v2.3/mutation_import.py index ee829e455..c2e5d9f14 100644 --- a/test/python/fx_importer/v2.3/mutation_import.py +++ b/test/python/fx_importer/v2.3/mutation_import.py @@ -107,6 +107,27 @@ def test_frozen_buffer(): m.operation.verify() +@run +# CHECK-LABEL: test_frozen_buffer_non_persistent +# CHECK: %[[buffer_literal:.+]] = torch.vtensor.literal +# CHECK: %[[mul:.+]] = torch.aten.mul.Tensor %arg0, %0 +# CHECK: return %[[mul]] +def test_frozen_buffer_non_persistent(): + class Basic(nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("buffer", torch.randn(3, 4), persistent=False) + + def forward(self, x): + return x * self.buffer + + m = fx.export_and_import( + Basic(), torch.randn(3, 4), experimental_support_mutation=True + ) + print(m) + m.operation.verify() + + class ExternalBufferHooks(fx.FxImporterHooks): def prepare_module(self, module_op: Operation): module_op.context.allow_unregistered_dialects = True