diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index ce837945f..f622a0b93 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -358,7 +358,8 @@ class InputInfo: input_spec: TypingInputSpec node: Node ir_type: IrType - mutable_producer_node_name: Optional[str] + mutable_producer_node_name: Optional[str] = None + store_producer_node: Optional[str] = None class FxImporterHooks: @@ -387,6 +388,22 @@ class FxImporterHooks: """ return None + def store_produced_value( + self, + gni: "GraphNodeImporter", + py_value: Any, + produced_ir_value: Any, + info: InputInfo, + ): + """Given a load/store semantic mutatation, issues the store. + + This style is used for buffer and parameter updates, which are assumed to be + non-SSA updates that are otherwise in the value-tensor domain. + """ + raise NotImplementedError( + f"Store of a mutation to {info} is not supported (from {produced_ir_value})" + ) + class FxImporter: """Main entry-point for importing an fx.GraphModule. @@ -596,7 +613,11 @@ class FxImporter: elif input_spec.kind == InputKind.BUFFER and isinstance( arg, TensorArgument ): - # Remember buffer binding. + # Remember buffer binding. Unlike user input mutations, buffers + # are assumed to be represented with load/store semantics based + # on a symbolic or other non-SSA association. As such, they + # are not modeled with mutable IR but will trigger an output + # store hook when the final value is produced. value = prog.state_dict.get(input_spec.target) assert ( not input_spec.persistent or value is not None @@ -605,9 +626,7 @@ class FxImporter: mutable_producer_node_name = mutable_buffer_target_producers.get( input_spec.target ) - node_ir_type = self._cc.node_val_to_type( - node, mutable=bool(mutable_producer_node_name) - ) + node_ir_type = self._cc.node_val_to_type(node, mutable=False) buffer_bindings[node] = ( value, InputInfo( @@ -615,7 +634,7 @@ class FxImporter: input_spec, node=node, ir_type=node_ir_type, - mutable_producer_node_name=mutable_producer_node_name, + store_producer_node=mutable_producer_node_name, ), ) else: @@ -1136,17 +1155,17 @@ class GraphNodeImporter: self.bind_node_value(node, _on_access) if info.mutable_producer_node_name is not None: + raise NotImplementedError("NYI: Mutable SSA buffer updates") + + if info.store_producer_node is not None: def on_produced(value: Value): - mutable_buffer_value = self.resolve_node_value(node) with loc, InsertionPoint(self._b): - Operation.create( - "torch.overwrite.tensor.contents", - results=[], - operands=[value, mutable_buffer_value], + self.fx_importer._hooks.store_produced_value( + self, buffer_value, value, info ) - self._on_node_produced[info.mutable_producer_node_name] = on_produced + self._on_node_produced[info.store_producer_node] = on_produced def return_node_values(self, loc, nodes: List[Node]): with loc, InsertionPoint(self._b): diff --git a/test/python/fx_importer/v2.3/mutation_import.py b/test/python/fx_importer/v2.3/mutation_import.py index ef293b8cb..c62b12706 100644 --- a/test/python/fx_importer/v2.3/mutation_import.py +++ b/test/python/fx_importer/v2.3/mutation_import.py @@ -15,6 +15,7 @@ from torch_mlir import fx from torch_mlir.ir import ( Operation, + StringAttr, ) @@ -110,15 +111,24 @@ class ExternalBufferHooks(fx.FxImporterHooks): def resolve_input(self, gni, value, info): return Operation.create( - "my_dialect.import_buffer", results=[info.ir_type] + "my_dialect.import_buffer", + results=[info.ir_type], + attributes={"name": StringAttr.get(info.input_spec.target)}, ).result + def store_produced_value(self, gni, py_value, produced_ir_value, info): + Operation.create( + "my_dialect.store_buffer", + operands=[produced_ir_value], + attributes={"name": StringAttr.get(info.input_spec.target)}, + ) + @run # CHECK-LABEL: test_mutable_buffer -# CHECK: %[[buffer:.+]] = "my_dialect.import_buffer"() : () -> !torch.tensor<[3,4],f32> +# CHECK: %[[buffer:.+]] = "my_dialect.import_buffer"() {name = "buffer"} : () -> !torch.vtensor<[3,4],f32> # CHECK: %[[mul:.+]] = torch.aten.mul.Tensor %[[buffer]], %arg0 -# CHECK: torch.overwrite.tensor.contents %[[mul]] overwrites %[[buffer]] +# CHECK: "my_dialect.store_buffer"(%[[mul]]) {name = "buffer"} : (!torch.vtensor<[3,4],f32>) -> () # CHECK: return %arg0 def test_mutable_buffer(): class Basic(nn.Module): @@ -141,9 +151,9 @@ def test_mutable_buffer(): @run -# CHECK-LABEL: test_mutable_buffer_not_supported_from_literal -# CHECK: ERROR: Cannot import {{.*}} as a literal because it is mutable -def test_mutable_buffer_not_supported_from_literal(): +# CHECK-LABEL: test_mutable_buffer_not_supported_without_hooks +# CHECK: EXPECTED ERROR: Store of a mutation to {{.*}} is not supported +def test_mutable_buffer_not_supported_without_hooks(): class Basic(nn.Module): def __init__(self): super().__init__() @@ -159,5 +169,5 @@ def test_mutable_buffer_not_supported_from_literal(): torch.randn(3, 4), experimental_support_mutation=True, ) - except ValueError as e: - print("ERROR:", e) + except NotImplementedError as e: + print("EXPECTED ERROR:", str(e))