mirror of https://github.com/llvm/torch-mlir
[FxImporter] refactor canonicalize using table driven (#3402)
parent f09cb766dc
commit 37e89828a1
@@ -1444,7 +1444,7 @@ class GraphNodeImporter:
                     self._import_symbolic_torch_op(loc, node, target)
                 elif isinstance(target, TorchOpOverload):
                     # Dispatch to an ATen op.
-                    self._import_torch_op_overload(loc, node, target)
+                    self._import_torch_op_overload(loc, node)
                 elif isinstance(target, HigherOrderOperator):
                     self._import_hop(loc, node, target)
                 else:
@@ -1615,59 +1615,18 @@ class GraphNodeImporter:
             self.bind_node_value(node, value, i + bind_none)
 
     def _import_torch_op_overload(
-        self, loc: Location, node: torch_fx.Node, target: TorchOpOverload
+        self,
+        loc: Location,
+        node: torch_fx.Node,
+        concrete_target: Optional[TorchOpOverload] = None,
     ):
-        # TODO: Convert this cascade of ifs to a table-driven
-        # replace lift_fresh_copy with clone op
-        if target == torch.ops.aten.lift_fresh_copy.default:
-            node.target = target = torch.ops.aten.clone.default
-            node.args = (node.args[0],)
-            node.kwargs = {"memory_format": None}
-        elif target == torch.ops.aten.lift_fresh_copy.out:
-            # TODO: It seems not possible to hit this case from user code.
-            # Retaining in case if it is triggered internally somehow, but
-            # it can most likely be removed once assuming full
-            # functionalization in all cases.
-            node.target = target = torch.ops.aten.clone.out
-            node.args = (node.args[0],)
-            node.kwargs = {"memory_format": None, "out": node.args[1]}
-        # TODO: generalize empty.memory_format in the future
-        # Currently, the aten.baddbmm.default op for Unet includes multiplying an
-        # empty.memory_format input with a constant, which creates NaN values
-        # because empty.memory_format contains uninitialized data. Converting
-        # aten.baddbmm.default -> aten.zeros.default fixes the correctness issue
-        elif target == torch.ops.aten.empty.memory_format:
-            if len(node.users) == 1:
-                for key_node in node.users:
-                    if key_node.target == torch.ops.aten.baddbmm.default:
-                        node.target = target = torch.ops.aten.zeros.default
-        elif target == torch.ops.aten._local_scalar_dense.default:
-            input_type = node.args[0].meta["tensor_meta"].dtype
-            if input_type.is_floating_point:
-                node.target = target = torch.ops.aten.Float.Tensor
-            else:
-                node.target = target = torch.ops.aten.Int.Tensor
-            node.args = (node.args[0],)
-        elif target == torch.ops.aten._assert_async.msg:
-            # TODO: A more suitable op to replace it?
-            return
-        elif target == torch.ops.aten._unsafe_index_put.default:
-            node.target = target = torch.ops.aten._unsafe_index_put.hacked_twin
-        elif target == torch.ops.aten._embedding_bag_forward_only.default:
-            node.target = target = torch.ops.aten.embedding_bag.padding_idx
-            embedding_bag_args = [
-                ("scale_grad_by_freq", False),
-                ("mode", 0),
-                ("sparse", False),
-                ("per_sample_weights", None),
-                ("include_last_offset", False),
-                ("padding_idx", None),
-            ]
-            node_kwargs = dict(node.kwargs)
-            for k, v in embedding_bag_args[len(node.args) - 3 :]:
-                if k not in node_kwargs:
-                    node_kwargs[k] = v
-            node.kwargs = node_kwargs
+        if concrete_target is None:
+            node = node_canonicalize(node)
+            if not node:
+                return
+            target = node.target
+        else:
+            target = concrete_target
 
         schema = target._schema
         assert isinstance(schema, FunctionSchema)
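Below is a simplified, standalone restatement (not the importer itself) of the entry logic added in this hunk, to make the three possible outcomes explicit: a table entry may rewrite the node, drop it by returning None, or leave it untouched, while a caller-supplied concrete target bypasses canonicalization entirely. The Node class and the string-keyed table here are illustrative stand-ins for torch_fx.Node and NODE_CANONICALIZE.

from typing import Callable, Dict, Optional


class Node:
    def __init__(self, target: str):
        self.target = target


CANONICALIZE: Dict[str, Callable] = {}


def import_overload(node: Node, concrete_target: Optional[str] = None) -> Optional[str]:
    if concrete_target is None:
        # Default path: let the table rewrite (or drop) the node first.
        fn = CANONICALIZE.get(node.target)
        node = fn(node) if fn else node
        if not node:
            return None  # canonicalizer returned None: the node is skipped entirely
        target = node.target
    else:
        # Caller already resolved the overload: canonicalization is bypassed.
        target = concrete_target
    # The real method continues past this point by emitting MLIR for `target`'s schema.
    return target


CANONICALIZE["aten._assert_async.msg"] = lambda node: None  # drop the node
CANONICALIZE["aten.lift_fresh_copy.default"] = lambda node: Node("aten.clone.default")  # rewrite it

assert import_overload(Node("aten._assert_async.msg")) is None
assert import_overload(Node("aten.lift_fresh_copy.default")) == "aten.clone.default"
assert import_overload(Node("aten.add.Tensor")) == "aten.add.Tensor"
assert import_overload(Node("aten.add.Tensor"), concrete_target="aten.add.Scalar") == "aten.add.Scalar"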
@@ -2401,3 +2360,97 @@ TENSOR_SCALAR_OP_CONVERTER = {
     "torch.aten.sub.Tensor": "torch.aten.sub.Scalar",
     "torch.aten.floor_divide": "torch.aten.floor_divide.Scalar",
 }
+
+
+NODE_CANONICALIZE: Dict[TorchOpOverload, Callable] = {}
+
+
+def register_canonicalize(op: TorchOpOverload):
+    def wrapper(func):
+        NODE_CANONICALIZE[op] = func
+        return func
+
+    return wrapper
+
+
+@register_canonicalize(torch.ops.aten.lift_fresh_copy.default)
+def lift_fresh_copy_default(node: torch_fx.Node):
+    # replace lift_fresh_copy with clone op
+    node.target = torch.ops.aten.clone.default
+    node.args = (node.args[0],)
+    node.kwargs = {"memory_format": None}
+    return node
+
+
+@register_canonicalize(torch.ops.aten.lift_fresh_copy.out)
+def lift_fresh_copy_out(node: torch_fx.Node):
+    # TODO: It seems not possible to hit this case from user code.
+    # Retaining in case if it is triggered internally somehow, but
+    # it can most likely be removed once assuming full
+    # functionalization in all cases.
+    node.target = target = torch.ops.aten.clone.out
+    node.args = (node.args[0],)
+    node.kwargs = {"memory_format": None, "out": node.args[1]}
+    return node
+
+
+@register_canonicalize(torch.ops.aten.empty.memory_format)
+def empty_memory_format(node: torch_fx.Node):
+    # TODO: generalize empty.memory_format in the future
+    # Currently, the aten.baddbmm.default op for Unet includes multiplying an
+    # empty.memory_format input with a constant, which creates NaN values
+    # because empty.memory_format contains uninitialized data. Converting
+    # aten.baddbmm.default -> aten.zeros.default fixes the correctness issue
+    if len(node.users) == 1:
+        for key_node in node.users:
+            if key_node.target == torch.ops.aten.baddbmm.default:
+                node.target = torch.ops.aten.zeros.default
+    return node
+
+
+@register_canonicalize(torch.ops.aten._local_scalar_dense.default)
+def aten__local_scalar_dense_default(node: torch_fx.Node):
+    input_type = node.args[0].meta["tensor_meta"].dtype
+    if input_type.is_floating_point:
+        node.target = torch.ops.aten.Float.Tensor
+    else:
+        node.target = torch.ops.aten.Int.Tensor
+    node.args = (node.args[0],)
+    return node
+
+
+@register_canonicalize(torch.ops.aten._assert_async.msg)
+def aten__assert_async_msg(node: torch_fx.Node):
+    # TODO: A more suitable op to replace it?
+    return None
+
+
+@register_canonicalize(torch.ops.aten._unsafe_index_put.default)
+def aten__unsafe_index_put_default(node: torch_fx.Node):
+    node.target = torch.ops.aten._unsafe_index_put.hacked_twin
+    return node
+
+
+@register_canonicalize(torch.ops.aten._embedding_bag_forward_only.default)
+def aten__embedding_bag_forward_only_default(node: torch_fx.Node):
+    node.target = torch.ops.aten.embedding_bag.padding_idx
+    embedding_bag_args = [
+        ("scale_grad_by_freq", False),
+        ("mode", 0),
+        ("sparse", False),
+        ("per_sample_weights", None),
+        ("include_last_offset", False),
+        ("padding_idx", None),
+    ]
+    node_kwargs = dict(node.kwargs)
+    for k, v in embedding_bag_args[len(node.args) - 3 :]:
+        if k not in node_kwargs:
+            node_kwargs[k] = v
+    node.kwargs = node_kwargs
+    return node
+
+
+def node_canonicalize(node: torch_fx.Node):
+    if node.target in NODE_CANONICALIZE:
+        return NODE_CANONICALIZE[node.target](node)
+    return node
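With the table in place, adding a canonicalization becomes a local, declarative change rather than another branch in a cascade. A hypothetical extension sketch follows: the module path torch_mlir.extras.fx_importer and the alias-to-clone rewrite are assumptions for illustration only; just the register_canonicalize decorator and the NODE_CANONICALIZE table come from the diff above.

# Hypothetical downstream extension; not part of this patch.
# Assumed: the importer module is importable as torch_mlir.extras.fx_importer.
import torch
from torch import fx as torch_fx
from torch_mlir.extras.fx_importer import register_canonicalize


@register_canonicalize(torch.ops.aten.alias.default)
def alias_default(node: torch_fx.Node):
    # Illustrative rewrite only: import aten.alias as aten.clone.
    node.target = torch.ops.aten.clone.default
    node.kwargs = {"memory_format": None}
    return node

Because node_canonicalize falls through to returning the node unchanged when no entry exists, registrations like this are strictly additive and do not affect ops the table does not mention.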