mirror of https://github.com/llvm/torch-mlir
[fx] Support ExportedProgram buffer mutation. (#3080)
In the prior state when I supported mutation of user inputs by treating them as mutable-tensor SSA values, I had left the case of buffer mutation only vaguely implemented until a concrete use emerged. This patch reworks this buffer mutation support by assuming that buffers must be resolved via the hooks symbolically and treated with load/store semantics. This is implied in the structure since we have no SSA value that represents a buffer and we already assume that reading parameters happens via such a mechanism.pull/3093/head
parent
fe2fb9d9f5
commit
826786bdd0
|
@ -358,7 +358,8 @@ class InputInfo:
|
||||||
input_spec: TypingInputSpec
|
input_spec: TypingInputSpec
|
||||||
node: Node
|
node: Node
|
||||||
ir_type: IrType
|
ir_type: IrType
|
||||||
mutable_producer_node_name: Optional[str]
|
mutable_producer_node_name: Optional[str] = None
|
||||||
|
store_producer_node: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class FxImporterHooks:
|
class FxImporterHooks:
|
||||||
|
@ -387,6 +388,22 @@ class FxImporterHooks:
|
||||||
"""
|
"""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def store_produced_value(
|
||||||
|
self,
|
||||||
|
gni: "GraphNodeImporter",
|
||||||
|
py_value: Any,
|
||||||
|
produced_ir_value: Any,
|
||||||
|
info: InputInfo,
|
||||||
|
):
|
||||||
|
"""Given a load/store semantic mutatation, issues the store.
|
||||||
|
|
||||||
|
This style is used for buffer and parameter updates, which are assumed to be
|
||||||
|
non-SSA updates that are otherwise in the value-tensor domain.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Store of a mutation to {info} is not supported (from {produced_ir_value})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FxImporter:
|
class FxImporter:
|
||||||
"""Main entry-point for importing an fx.GraphModule.
|
"""Main entry-point for importing an fx.GraphModule.
|
||||||
|
@ -596,7 +613,11 @@ class FxImporter:
|
||||||
elif input_spec.kind == InputKind.BUFFER and isinstance(
|
elif input_spec.kind == InputKind.BUFFER and isinstance(
|
||||||
arg, TensorArgument
|
arg, TensorArgument
|
||||||
):
|
):
|
||||||
# Remember buffer binding.
|
# Remember buffer binding. Unlike user input mutations, buffers
|
||||||
|
# are assumed to be represented with load/store semantics based
|
||||||
|
# on a symbolic or other non-SSA association. As such, they
|
||||||
|
# are not modeled with mutable IR but will trigger an output
|
||||||
|
# store hook when the final value is produced.
|
||||||
value = prog.state_dict.get(input_spec.target)
|
value = prog.state_dict.get(input_spec.target)
|
||||||
assert (
|
assert (
|
||||||
not input_spec.persistent or value is not None
|
not input_spec.persistent or value is not None
|
||||||
|
@ -605,9 +626,7 @@ class FxImporter:
|
||||||
mutable_producer_node_name = mutable_buffer_target_producers.get(
|
mutable_producer_node_name = mutable_buffer_target_producers.get(
|
||||||
input_spec.target
|
input_spec.target
|
||||||
)
|
)
|
||||||
node_ir_type = self._cc.node_val_to_type(
|
node_ir_type = self._cc.node_val_to_type(node, mutable=False)
|
||||||
node, mutable=bool(mutable_producer_node_name)
|
|
||||||
)
|
|
||||||
buffer_bindings[node] = (
|
buffer_bindings[node] = (
|
||||||
value,
|
value,
|
||||||
InputInfo(
|
InputInfo(
|
||||||
|
@ -615,7 +634,7 @@ class FxImporter:
|
||||||
input_spec,
|
input_spec,
|
||||||
node=node,
|
node=node,
|
||||||
ir_type=node_ir_type,
|
ir_type=node_ir_type,
|
||||||
mutable_producer_node_name=mutable_producer_node_name,
|
store_producer_node=mutable_producer_node_name,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -1136,17 +1155,17 @@ class GraphNodeImporter:
|
||||||
self.bind_node_value(node, _on_access)
|
self.bind_node_value(node, _on_access)
|
||||||
|
|
||||||
if info.mutable_producer_node_name is not None:
|
if info.mutable_producer_node_name is not None:
|
||||||
|
raise NotImplementedError("NYI: Mutable SSA buffer updates")
|
||||||
|
|
||||||
|
if info.store_producer_node is not None:
|
||||||
|
|
||||||
def on_produced(value: Value):
|
def on_produced(value: Value):
|
||||||
mutable_buffer_value = self.resolve_node_value(node)
|
|
||||||
with loc, InsertionPoint(self._b):
|
with loc, InsertionPoint(self._b):
|
||||||
Operation.create(
|
self.fx_importer._hooks.store_produced_value(
|
||||||
"torch.overwrite.tensor.contents",
|
self, buffer_value, value, info
|
||||||
results=[],
|
|
||||||
operands=[value, mutable_buffer_value],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self._on_node_produced[info.mutable_producer_node_name] = on_produced
|
self._on_node_produced[info.store_producer_node] = on_produced
|
||||||
|
|
||||||
def return_node_values(self, loc, nodes: List[Node]):
|
def return_node_values(self, loc, nodes: List[Node]):
|
||||||
with loc, InsertionPoint(self._b):
|
with loc, InsertionPoint(self._b):
|
||||||
|
|
|
@ -15,6 +15,7 @@ from torch_mlir import fx
|
||||||
|
|
||||||
from torch_mlir.ir import (
|
from torch_mlir.ir import (
|
||||||
Operation,
|
Operation,
|
||||||
|
StringAttr,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -110,15 +111,24 @@ class ExternalBufferHooks(fx.FxImporterHooks):
|
||||||
|
|
||||||
def resolve_input(self, gni, value, info):
|
def resolve_input(self, gni, value, info):
|
||||||
return Operation.create(
|
return Operation.create(
|
||||||
"my_dialect.import_buffer", results=[info.ir_type]
|
"my_dialect.import_buffer",
|
||||||
|
results=[info.ir_type],
|
||||||
|
attributes={"name": StringAttr.get(info.input_spec.target)},
|
||||||
).result
|
).result
|
||||||
|
|
||||||
|
def store_produced_value(self, gni, py_value, produced_ir_value, info):
|
||||||
|
Operation.create(
|
||||||
|
"my_dialect.store_buffer",
|
||||||
|
operands=[produced_ir_value],
|
||||||
|
attributes={"name": StringAttr.get(info.input_spec.target)},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@run
|
@run
|
||||||
# CHECK-LABEL: test_mutable_buffer
|
# CHECK-LABEL: test_mutable_buffer
|
||||||
# CHECK: %[[buffer:.+]] = "my_dialect.import_buffer"() : () -> !torch.tensor<[3,4],f32>
|
# CHECK: %[[buffer:.+]] = "my_dialect.import_buffer"() {name = "buffer"} : () -> !torch.vtensor<[3,4],f32>
|
||||||
# CHECK: %[[mul:.+]] = torch.aten.mul.Tensor %[[buffer]], %arg0
|
# CHECK: %[[mul:.+]] = torch.aten.mul.Tensor %[[buffer]], %arg0
|
||||||
# CHECK: torch.overwrite.tensor.contents %[[mul]] overwrites %[[buffer]]
|
# CHECK: "my_dialect.store_buffer"(%[[mul]]) {name = "buffer"} : (!torch.vtensor<[3,4],f32>) -> ()
|
||||||
# CHECK: return %arg0
|
# CHECK: return %arg0
|
||||||
def test_mutable_buffer():
|
def test_mutable_buffer():
|
||||||
class Basic(nn.Module):
|
class Basic(nn.Module):
|
||||||
|
@ -141,9 +151,9 @@ def test_mutable_buffer():
|
||||||
|
|
||||||
|
|
||||||
@run
|
@run
|
||||||
# CHECK-LABEL: test_mutable_buffer_not_supported_from_literal
|
# CHECK-LABEL: test_mutable_buffer_not_supported_without_hooks
|
||||||
# CHECK: ERROR: Cannot import {{.*}} as a literal because it is mutable
|
# CHECK: EXPECTED ERROR: Store of a mutation to {{.*}} is not supported
|
||||||
def test_mutable_buffer_not_supported_from_literal():
|
def test_mutable_buffer_not_supported_without_hooks():
|
||||||
class Basic(nn.Module):
|
class Basic(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -159,5 +169,5 @@ def test_mutable_buffer_not_supported_from_literal():
|
||||||
torch.randn(3, 4),
|
torch.randn(3, 4),
|
||||||
experimental_support_mutation=True,
|
experimental_support_mutation=True,
|
||||||
)
|
)
|
||||||
except ValueError as e:
|
except NotImplementedError as e:
|
||||||
print("ERROR:", e)
|
print("EXPECTED ERROR:", str(e))
|
||||||
|
|
Loading…
Reference in New Issue