[fx] Support ExportedProgram buffer mutation. (#3080)

When I previously added support for mutating user inputs by treating them
as mutable-tensor SSA values, I left the case of buffer mutation only
vaguely implemented, pending a concrete use.

This patch reworks buffer mutation support on the assumption that buffers
must be resolved symbolically via the hooks and treated with load/store
semantics. This is implied by the structure: there is no SSA value that
represents a buffer, and we already assume that parameter reads happen
through such a mechanism.
Stella Laurenzo 2024-04-01 14:18:12 -07:00 committed by GitHub
parent fe2fb9d9f5
commit 826786bdd0
2 changed files with 49 additions and 20 deletions
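
For context, the new path is exercised by modules that mutate a registered buffer in their forward pass. A minimal sketch of such a module, reconstructed from the FileCheck expectations in the test diff below (the module body itself is elided from the hunks; the name "buffer" and shape (3, 4) come from the CHECK lines):

import torch
from torch import nn

class Basic(nn.Module):
    def __init__(self):
        super().__init__()
        # A registered buffer is module state that, unlike a user input,
        # has no SSA value of its own in the exported program.
        self.register_buffer("buffer", torch.randn(3, 4))

    def forward(self, x):
        # In-place update; torch.export records this as a buffer mutation
        # whose final value is produced by the aten.mul node.
        self.buffer.mul_(x)
        return x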


@@ -358,7 +358,8 @@ class InputInfo:
     input_spec: TypingInputSpec
     node: Node
     ir_type: IrType
-    mutable_producer_node_name: Optional[str]
+    mutable_producer_node_name: Optional[str] = None
+    store_producer_node: Optional[str] = None


 class FxImporterHooks:
@@ -387,6 +388,22 @@ class FxImporterHooks:
         """
         return None
 
+    def store_produced_value(
+        self,
+        gni: "GraphNodeImporter",
+        py_value: Any,
+        produced_ir_value: Any,
+        info: InputInfo,
+    ):
+        """Given a mutation with load/store semantics, issues the store.
+
+        This style is used for buffer and parameter updates, which are assumed to be
+        non-SSA updates that are otherwise in the value-tensor domain.
+        """
+        raise NotImplementedError(
+            f"Store of a mutation to {info} is not supported (from {produced_ir_value})"
+        )
+
 
 class FxImporter:
     """Main entry-point for importing an fx.GraphModule.
@@ -596,7 +613,11 @@ class FxImporter:
             elif input_spec.kind == InputKind.BUFFER and isinstance(
                 arg, TensorArgument
             ):
-                # Remember buffer binding.
+                # Remember buffer binding. Unlike user input mutations, buffers
+                # are assumed to be represented with load/store semantics based
+                # on a symbolic or other non-SSA association. As such, they
+                # are not modeled with mutable IR but will trigger an output
+                # store hook when the final value is produced.
                 value = prog.state_dict.get(input_spec.target)
                 assert (
                     not input_spec.persistent or value is not None
@@ -605,9 +626,7 @@ class FxImporter:
                 mutable_producer_node_name = mutable_buffer_target_producers.get(
                     input_spec.target
                 )
-                node_ir_type = self._cc.node_val_to_type(
-                    node, mutable=bool(mutable_producer_node_name)
-                )
+                node_ir_type = self._cc.node_val_to_type(node, mutable=False)
                 buffer_bindings[node] = (
                     value,
                     InputInfo(
@@ -615,7 +634,7 @@ class FxImporter:
                         input_spec,
                         node=node,
                         ir_type=node_ir_type,
-                        mutable_producer_node_name=mutable_producer_node_name,
+                        store_producer_node=mutable_producer_node_name,
                     ),
                 )
             else:
@@ -1136,17 +1155,17 @@ class GraphNodeImporter:
         self.bind_node_value(node, _on_access)
 
         if info.mutable_producer_node_name is not None:
+            raise NotImplementedError("NYI: Mutable SSA buffer updates")
+
+        if info.store_producer_node is not None:
 
             def on_produced(value: Value):
-                mutable_buffer_value = self.resolve_node_value(node)
                 with loc, InsertionPoint(self._b):
-                    Operation.create(
-                        "torch.overwrite.tensor.contents",
-                        results=[],
-                        operands=[value, mutable_buffer_value],
-                    )
+                    self.fx_importer._hooks.store_produced_value(
+                        self, buffer_value, value, info
+                    )
 
-            self._on_node_produced[info.mutable_producer_node_name] = on_produced
+            self._on_node_produced[info.store_producer_node] = on_produced
 
     def return_node_values(self, loc, nodes: List[Node]):
         with loc, InsertionPoint(self._b):


@@ -15,6 +15,7 @@ from torch_mlir import fx
 from torch_mlir.ir import (
     Operation,
+    StringAttr,
 )
@@ -110,15 +111,24 @@ class ExternalBufferHooks(fx.FxImporterHooks):
     def resolve_input(self, gni, value, info):
         return Operation.create(
-            "my_dialect.import_buffer", results=[info.ir_type]
+            "my_dialect.import_buffer",
+            results=[info.ir_type],
+            attributes={"name": StringAttr.get(info.input_spec.target)},
         ).result
 
+    def store_produced_value(self, gni, py_value, produced_ir_value, info):
+        Operation.create(
+            "my_dialect.store_buffer",
+            operands=[produced_ir_value],
+            attributes={"name": StringAttr.get(info.input_spec.target)},
+        )
+
 
 @run
 # CHECK-LABEL: test_mutable_buffer
-# CHECK: %[[buffer:.+]] = "my_dialect.import_buffer"() : () -> !torch.tensor<[3,4],f32>
+# CHECK: %[[buffer:.+]] = "my_dialect.import_buffer"() {name = "buffer"} : () -> !torch.vtensor<[3,4],f32>
 # CHECK: %[[mul:.+]] = torch.aten.mul.Tensor %[[buffer]], %arg0
-# CHECK: torch.overwrite.tensor.contents %[[mul]] overwrites %[[buffer]]
+# CHECK: "my_dialect.store_buffer"(%[[mul]]) {name = "buffer"} : (!torch.vtensor<[3,4],f32>) -> ()
 # CHECK: return %arg0
 def test_mutable_buffer():
     class Basic(nn.Module):
@@ -141,9 +151,9 @@ def test_mutable_buffer():
 @run
-# CHECK-LABEL: test_mutable_buffer_not_supported_from_literal
-# CHECK: ERROR: Cannot import {{.*}} as a literal because it is mutable
-def test_mutable_buffer_not_supported_from_literal():
+# CHECK-LABEL: test_mutable_buffer_not_supported_without_hooks
+# CHECK: EXPECTED ERROR: Store of a mutation to {{.*}} is not supported
+def test_mutable_buffer_not_supported_without_hooks():
     class Basic(nn.Module):
         def __init__(self):
             super().__init__()
@@ -159,5 +169,5 @@ def test_mutable_buffer_not_supported_from_literal():
             torch.randn(3, 4),
             experimental_support_mutation=True,
         )
-    except ValueError as e:
-        print("ERROR:", e)
+    except NotImplementedError as e:
+        print("EXPECTED ERROR:", str(e))