mirror of https://github.com/llvm/torch-mlir
[fx] Support ExportedProgram buffer mutation. (#3080)
In the prior state when I supported mutation of user inputs by treating them as mutable-tensor SSA values, I had left the case of buffer mutation only vaguely implemented until a concrete use emerged. This patch reworks this buffer mutation support by assuming that buffers must be resolved via the hooks symbolically and treated with load/store semantics. This is implied in the structure since we have no SSA value that represents a buffer and we already assume that reading parameters happens via such a mechanism.pull/3093/head
parent
fe2fb9d9f5
commit
826786bdd0
|
@ -358,7 +358,8 @@ class InputInfo:
|
|||
input_spec: TypingInputSpec
|
||||
node: Node
|
||||
ir_type: IrType
|
||||
mutable_producer_node_name: Optional[str]
|
||||
mutable_producer_node_name: Optional[str] = None
|
||||
store_producer_node: Optional[str] = None
|
||||
|
||||
|
||||
class FxImporterHooks:
|
||||
|
@ -387,6 +388,22 @@ class FxImporterHooks:
|
|||
"""
|
||||
return None
|
||||
|
||||
def store_produced_value(
|
||||
self,
|
||||
gni: "GraphNodeImporter",
|
||||
py_value: Any,
|
||||
produced_ir_value: Any,
|
||||
info: InputInfo,
|
||||
):
|
||||
"""Given a load/store semantic mutatation, issues the store.
|
||||
|
||||
This style is used for buffer and parameter updates, which are assumed to be
|
||||
non-SSA updates that are otherwise in the value-tensor domain.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f"Store of a mutation to {info} is not supported (from {produced_ir_value})"
|
||||
)
|
||||
|
||||
|
||||
class FxImporter:
|
||||
"""Main entry-point for importing an fx.GraphModule.
|
||||
|
@ -596,7 +613,11 @@ class FxImporter:
|
|||
elif input_spec.kind == InputKind.BUFFER and isinstance(
|
||||
arg, TensorArgument
|
||||
):
|
||||
# Remember buffer binding.
|
||||
# Remember buffer binding. Unlike user input mutations, buffers
|
||||
# are assumed to be represented with load/store semantics based
|
||||
# on a symbolic or other non-SSA association. As such, they
|
||||
# are not modeled with mutable IR but will trigger an output
|
||||
# store hook when the final value is produced.
|
||||
value = prog.state_dict.get(input_spec.target)
|
||||
assert (
|
||||
not input_spec.persistent or value is not None
|
||||
|
@ -605,9 +626,7 @@ class FxImporter:
|
|||
mutable_producer_node_name = mutable_buffer_target_producers.get(
|
||||
input_spec.target
|
||||
)
|
||||
node_ir_type = self._cc.node_val_to_type(
|
||||
node, mutable=bool(mutable_producer_node_name)
|
||||
)
|
||||
node_ir_type = self._cc.node_val_to_type(node, mutable=False)
|
||||
buffer_bindings[node] = (
|
||||
value,
|
||||
InputInfo(
|
||||
|
@ -615,7 +634,7 @@ class FxImporter:
|
|||
input_spec,
|
||||
node=node,
|
||||
ir_type=node_ir_type,
|
||||
mutable_producer_node_name=mutable_producer_node_name,
|
||||
store_producer_node=mutable_producer_node_name,
|
||||
),
|
||||
)
|
||||
else:
|
||||
|
@ -1136,17 +1155,17 @@ class GraphNodeImporter:
|
|||
self.bind_node_value(node, _on_access)
|
||||
|
||||
if info.mutable_producer_node_name is not None:
|
||||
raise NotImplementedError("NYI: Mutable SSA buffer updates")
|
||||
|
||||
if info.store_producer_node is not None:
|
||||
|
||||
def on_produced(value: Value):
|
||||
mutable_buffer_value = self.resolve_node_value(node)
|
||||
with loc, InsertionPoint(self._b):
|
||||
Operation.create(
|
||||
"torch.overwrite.tensor.contents",
|
||||
results=[],
|
||||
operands=[value, mutable_buffer_value],
|
||||
self.fx_importer._hooks.store_produced_value(
|
||||
self, buffer_value, value, info
|
||||
)
|
||||
|
||||
self._on_node_produced[info.mutable_producer_node_name] = on_produced
|
||||
self._on_node_produced[info.store_producer_node] = on_produced
|
||||
|
||||
def return_node_values(self, loc, nodes: List[Node]):
|
||||
with loc, InsertionPoint(self._b):
|
||||
|
|
|
@ -15,6 +15,7 @@ from torch_mlir import fx
|
|||
|
||||
from torch_mlir.ir import (
|
||||
Operation,
|
||||
StringAttr,
|
||||
)
|
||||
|
||||
|
||||
|
@ -110,15 +111,24 @@ class ExternalBufferHooks(fx.FxImporterHooks):
|
|||
|
||||
def resolve_input(self, gni, value, info):
|
||||
return Operation.create(
|
||||
"my_dialect.import_buffer", results=[info.ir_type]
|
||||
"my_dialect.import_buffer",
|
||||
results=[info.ir_type],
|
||||
attributes={"name": StringAttr.get(info.input_spec.target)},
|
||||
).result
|
||||
|
||||
def store_produced_value(self, gni, py_value, produced_ir_value, info):
|
||||
Operation.create(
|
||||
"my_dialect.store_buffer",
|
||||
operands=[produced_ir_value],
|
||||
attributes={"name": StringAttr.get(info.input_spec.target)},
|
||||
)
|
||||
|
||||
|
||||
@run
|
||||
# CHECK-LABEL: test_mutable_buffer
|
||||
# CHECK: %[[buffer:.+]] = "my_dialect.import_buffer"() : () -> !torch.tensor<[3,4],f32>
|
||||
# CHECK: %[[buffer:.+]] = "my_dialect.import_buffer"() {name = "buffer"} : () -> !torch.vtensor<[3,4],f32>
|
||||
# CHECK: %[[mul:.+]] = torch.aten.mul.Tensor %[[buffer]], %arg0
|
||||
# CHECK: torch.overwrite.tensor.contents %[[mul]] overwrites %[[buffer]]
|
||||
# CHECK: "my_dialect.store_buffer"(%[[mul]]) {name = "buffer"} : (!torch.vtensor<[3,4],f32>) -> ()
|
||||
# CHECK: return %arg0
|
||||
def test_mutable_buffer():
|
||||
class Basic(nn.Module):
|
||||
|
@ -141,9 +151,9 @@ def test_mutable_buffer():
|
|||
|
||||
|
||||
@run
|
||||
# CHECK-LABEL: test_mutable_buffer_not_supported_from_literal
|
||||
# CHECK: ERROR: Cannot import {{.*}} as a literal because it is mutable
|
||||
def test_mutable_buffer_not_supported_from_literal():
|
||||
# CHECK-LABEL: test_mutable_buffer_not_supported_without_hooks
|
||||
# CHECK: EXPECTED ERROR: Store of a mutation to {{.*}} is not supported
|
||||
def test_mutable_buffer_not_supported_without_hooks():
|
||||
class Basic(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -159,5 +169,5 @@ def test_mutable_buffer_not_supported_from_literal():
|
|||
torch.randn(3, 4),
|
||||
experimental_support_mutation=True,
|
||||
)
|
||||
except ValueError as e:
|
||||
print("ERROR:", e)
|
||||
except NotImplementedError as e:
|
||||
print("EXPECTED ERROR:", str(e))
|
||||
|
|
Loading…
Reference in New Issue