[FxImporter] Refactor canonicalization to be table-driven (#3402)

penguin_wwy 2024-08-16 22:57:18 +08:00 committed by GitHub
parent f09cb766dc
commit 37e89828a1
1 changed file with 106 additions and 53 deletions
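
This change replaces the cascade of if/elif special cases inside _import_torch_op_overload with a lookup table, NODE_CANONICALIZE, that maps each TorchOpOverload to a canonicalization callback registered through a decorator. Reduced to a standalone sketch, the pattern looks roughly like this (the string keys and dict-based "nodes" are illustrative stand-ins for the importer's real torch_fx.Node and TorchOpOverload types, not its actual API):

# Minimal sketch of the table-driven canonicalization pattern (illustrative only).
from typing import Callable, Dict

CANONICALIZERS: Dict[str, Callable] = {}

def register(op: str):
    # Decorator: associate a canonicalization callback with an op key.
    def wrapper(func: Callable) -> Callable:
        CANONICALIZERS[op] = func
        return func
    return wrapper

@register("aten.lift_fresh_copy")
def _lift_fresh_copy(node: dict):
    # Rewrite in place, mirroring how the real handlers mutate node.target.
    node["target"] = "aten.clone"
    return node

@register("aten._assert_async")
def _assert_async(node: dict):
    # Returning None tells the caller to drop the node entirely.
    return None

def canonicalize(node: dict):
    # Registered ops get rewritten or dropped; unknown ops pass through unchanged.
    handler = CANONICALIZERS.get(node["target"])
    return handler(node) if handler is not None else node

print(canonicalize({"target": "aten.lift_fresh_copy"}))  # {'target': 'aten.clone'}
print(canonicalize({"target": "aten._assert_async"}))    # None
print(canonicalize({"target": "aten.mm"}))               # passes through unchanged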


@@ -1444,7 +1444,7 @@ class GraphNodeImporter:
                     self._import_symbolic_torch_op(loc, node, target)
                 elif isinstance(target, TorchOpOverload):
                     # Dispatch to an ATen op.
-                    self._import_torch_op_overload(loc, node, target)
+                    self._import_torch_op_overload(loc, node)
                 elif isinstance(target, HigherOrderOperator):
                     self._import_hop(loc, node, target)
                 else:
@@ -1615,59 +1615,18 @@ class GraphNodeImporter:
             self.bind_node_value(node, value, i + bind_none)
 
     def _import_torch_op_overload(
-        self, loc: Location, node: torch_fx.Node, target: TorchOpOverload
+        self,
+        loc: Location,
+        node: torch_fx.Node,
+        concrete_target: Optional[TorchOpOverload] = None,
     ):
-        # TODO: Convert this cascade of ifs to a table-driven
-        # replace lift_fresh_copy with clone op
-        if target == torch.ops.aten.lift_fresh_copy.default:
-            node.target = target = torch.ops.aten.clone.default
-            node.args = (node.args[0],)
-            node.kwargs = {"memory_format": None}
-        elif target == torch.ops.aten.lift_fresh_copy.out:
-            # TODO: It seems not possible to hit this case from user code.
-            # Retaining in case if it is triggered internally somehow, but
-            # it can most likely be removed once assuming full
-            # functionalization in all cases.
-            node.target = target = torch.ops.aten.clone.out
-            node.args = (node.args[0],)
-            node.kwargs = {"memory_format": None, "out": node.args[1]}
-        # TODO: generalize empty.memory_format in the future
-        # Currently, the aten.baddbmm.default op for Unet includes multiplying an
-        # empty.memory_format input with a constant, which creates NaN values
-        # because empty.memory_format contains uninitialized data. Converting
-        # aten.baddbmm.default -> aten.zeros.default fixes the correctness issue
-        elif target == torch.ops.aten.empty.memory_format:
-            if len(node.users) == 1:
-                for key_node in node.users:
-                    if key_node.target == torch.ops.aten.baddbmm.default:
-                        node.target = target = torch.ops.aten.zeros.default
-        elif target == torch.ops.aten._local_scalar_dense.default:
-            input_type = node.args[0].meta["tensor_meta"].dtype
-            if input_type.is_floating_point:
-                node.target = target = torch.ops.aten.Float.Tensor
-            else:
-                node.target = target = torch.ops.aten.Int.Tensor
-            node.args = (node.args[0],)
-        elif target == torch.ops.aten._assert_async.msg:
-            # TODO: A more suitable op to replace it?
-            return
-        elif target == torch.ops.aten._unsafe_index_put.default:
-            node.target = target = torch.ops.aten._unsafe_index_put.hacked_twin
-        elif target == torch.ops.aten._embedding_bag_forward_only.default:
-            node.target = target = torch.ops.aten.embedding_bag.padding_idx
-            embedding_bag_args = [
-                ("scale_grad_by_freq", False),
-                ("mode", 0),
-                ("sparse", False),
-                ("per_sample_weights", None),
-                ("include_last_offset", False),
-                ("padding_idx", None),
-            ]
-            node_kwargs = dict(node.kwargs)
-            for k, v in embedding_bag_args[len(node.args) - 3 :]:
-                if k not in node_kwargs:
-                    node_kwargs[k] = v
-            node.kwargs = node_kwargs
+        if concrete_target is None:
+            node = node_canonicalize(node)
+            if not node:
+                return
+            target = node.target
+        else:
+            target = concrete_target
         schema = target._schema
         assert isinstance(schema, FunctionSchema)
@@ -2401,3 +2360,97 @@ TENSOR_SCALAR_OP_CONVERTER = {
     "torch.aten.sub.Tensor": "torch.aten.sub.Scalar",
     "torch.aten.floor_divide": "torch.aten.floor_divide.Scalar",
 }
+
+
+NODE_CANONICALIZE: Dict[TorchOpOverload, Callable] = {}
+
+
+def register_canonicalize(op: TorchOpOverload):
+    def wrapper(func):
+        NODE_CANONICALIZE[op] = func
+        return func
+
+    return wrapper
+
+
+@register_canonicalize(torch.ops.aten.lift_fresh_copy.default)
+def lift_fresh_copy_default(node: torch_fx.Node):
+    # replace lift_fresh_copy with clone op
+    node.target = torch.ops.aten.clone.default
+    node.args = (node.args[0],)
+    node.kwargs = {"memory_format": None}
+    return node
+
+
+@register_canonicalize(torch.ops.aten.lift_fresh_copy.out)
+def lift_fresh_copy_out(node: torch_fx.Node):
+    # TODO: It seems not possible to hit this case from user code.
+    # Retaining in case if it is triggered internally somehow, but
+    # it can most likely be removed once assuming full
+    # functionalization in all cases.
+    node.target = target = torch.ops.aten.clone.out
+    node.args = (node.args[0],)
+    node.kwargs = {"memory_format": None, "out": node.args[1]}
+    return node
+
+
+@register_canonicalize(torch.ops.aten.empty.memory_format)
+def empty_memory_format(node: torch_fx.Node):
+    # TODO: generalize empty.memory_format in the future
+    # Currently, the aten.baddbmm.default op for Unet includes multiplying an
+    # empty.memory_format input with a constant, which creates NaN values
+    # because empty.memory_format contains uninitialized data. Converting
+    # aten.baddbmm.default -> aten.zeros.default fixes the correctness issue
+    if len(node.users) == 1:
+        for key_node in node.users:
+            if key_node.target == torch.ops.aten.baddbmm.default:
+                node.target = torch.ops.aten.zeros.default
+    return node
+
+
+@register_canonicalize(torch.ops.aten._local_scalar_dense.default)
+def aten__local_scalar_dense_default(node: torch_fx.Node):
+    input_type = node.args[0].meta["tensor_meta"].dtype
+    if input_type.is_floating_point:
+        node.target = torch.ops.aten.Float.Tensor
+    else:
+        node.target = torch.ops.aten.Int.Tensor
+    node.args = (node.args[0],)
+    return node
+
+
+@register_canonicalize(torch.ops.aten._assert_async.msg)
+def aten__assert_async_msg(node: torch_fx.Node):
+    # TODO: A more suitable op to replace it?
+    return None
+
+
+@register_canonicalize(torch.ops.aten._unsafe_index_put.default)
+def aten__unsafe_index_put_default(node: torch_fx.Node):
+    node.target = torch.ops.aten._unsafe_index_put.hacked_twin
+    return node
+
+
+@register_canonicalize(torch.ops.aten._embedding_bag_forward_only.default)
+def aten__embedding_bag_forward_only_default(node: torch_fx.Node):
+    node.target = torch.ops.aten.embedding_bag.padding_idx
+    embedding_bag_args = [
+        ("scale_grad_by_freq", False),
+        ("mode", 0),
+        ("sparse", False),
+        ("per_sample_weights", None),
+        ("include_last_offset", False),
+        ("padding_idx", None),
+    ]
+    node_kwargs = dict(node.kwargs)
+    for k, v in embedding_bag_args[len(node.args) - 3 :]:
+        if k not in node_kwargs:
+            node_kwargs[k] = v
+    node.kwargs = node_kwargs
+    return node
+
+
+def node_canonicalize(node: torch_fx.Node):
+    if node.target in NODE_CANONICALIZE:
+        return NODE_CANONICALIZE[node.target](node)
+    return node
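
Extending the table later only requires registering one more handler: node_canonicalize returns the (possibly rewritten) node, or None when the op should be dropped, and _import_torch_op_overload returns early in the None case. A purely hypothetical example of a new entry, shown only to illustrate the registration mechanics (the op choice and the rewrite are not part of this commit):

@register_canonicalize(torch.ops.aten.dropout.default)
def aten_dropout_default(node: torch_fx.Node):
    # Hypothetical handler, not part of this change: import eval-mode dropout
    # as a plain copy by rewriting it to aten.clone.
    node.target = torch.ops.aten.clone.default
    node.args = (node.args[0],)
    node.kwargs = {"memory_format": None}
    return node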