diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index a8556c54d..cfaa666fd 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -723,10 +723,17 @@ class FxImporter: # on a symbolic or other non-SSA association. As such, they # are not modeled with mutable IR but will trigger an output # store hook when the final value is produced. - value = prog.state_dict.get(input_spec.target) - assert ( - not input_spec.persistent or value is not None - ), "Expected state_dict value for persistent value" + if input_spec.persistent: + value = prog.state_dict.get(input_spec.target) + assert ( + value is not None + ), "Expected state_dict value for persistent buffer" + else: + value = prog.constants.get(input_spec.target) + assert ( + value is not None + ), "Expected constants value for non-persistent buffer" + node = placeholder_nodes[arg.name] mutable_producer_node_name = mutable_buffer_target_producers.get( input_spec.target diff --git a/test/python/fx_importer/v2.3/mutation_import.py b/test/python/fx_importer/v2.3/mutation_import.py index ee829e455..c2e5d9f14 100644 --- a/test/python/fx_importer/v2.3/mutation_import.py +++ b/test/python/fx_importer/v2.3/mutation_import.py @@ -107,6 +107,27 @@ def test_frozen_buffer(): m.operation.verify() +@run +# CHECK-LABEL: test_frozen_buffer_non_persistent +# CHECK: %[[buffer_literal:.+]] = torch.vtensor.literal +# CHECK: %[[mul:.+]] = torch.aten.mul.Tensor %arg0, %0 +# CHECK: return %[[mul]] +def test_frozen_buffer_non_persistent(): + class Basic(nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("buffer", torch.randn(3, 4), persistent=False) + + def forward(self, x): + return x * self.buffer + + m = fx.export_and_import( + Basic(), torch.randn(3, 4), experimental_support_mutation=True + ) + print(m) + m.operation.verify() + + class ExternalBufferHooks(fx.FxImporterHooks): def prepare_module(self, module_op: Operation): module_op.context.allow_unregistered_dialects = True