mirror of https://github.com/llvm/torch-mlir

[FxImporter] refactor canonicalize using table driven (#3402)

parent f09cb766dc
commit 37e89828a1
@@ -1444,7 +1444,7 @@ class GraphNodeImporter:
                         self._import_symbolic_torch_op(loc, node, target)
                     elif isinstance(target, TorchOpOverload):
                         # Dispatch to an ATen op.
-                        self._import_torch_op_overload(loc, node, target)
+                        self._import_torch_op_overload(loc, node)
                     elif isinstance(target, HigherOrderOperator):
                         self._import_hop(loc, node, target)
                     else:
@@ -1615,59 +1615,18 @@ class GraphNodeImporter:
             self.bind_node_value(node, value, i + bind_none)

     def _import_torch_op_overload(
-        self, loc: Location, node: torch_fx.Node, target: TorchOpOverload
+        self,
+        loc: Location,
+        node: torch_fx.Node,
+        concrete_target: Optional[TorchOpOverload] = None,
     ):
-        # TODO: Convert this cascade of ifs to a table-driven
-        # replace lift_fresh_copy with clone op
-        if target == torch.ops.aten.lift_fresh_copy.default:
-            node.target = target = torch.ops.aten.clone.default
-            node.args = (node.args[0],)
-            node.kwargs = {"memory_format": None}
-        elif target == torch.ops.aten.lift_fresh_copy.out:
-            # TODO: It seems not possible to hit this case from user code.
-            # Retaining in case if it is triggered internally somehow, but
-            # it can most likely be removed once assuming full
-            # functionalization in all cases.
-            node.target = target = torch.ops.aten.clone.out
-            node.args = (node.args[0],)
-            node.kwargs = {"memory_format": None, "out": node.args[1]}
-        # TODO: generalize empty.memory_format in the future
-        # Currently, the aten.baddbmm.default op for Unet includes multiplying an
-        # empty.memory_format input with a constant, which creates NaN values
-        # because empty.memory_format contains uninitialized data. Converting
-        # aten.baddbmm.default -> aten.zeros.default fixes the correctness issue
-        elif target == torch.ops.aten.empty.memory_format:
-            if len(node.users) == 1:
-                for key_node in node.users:
-                    if key_node.target == torch.ops.aten.baddbmm.default:
-                        node.target = target = torch.ops.aten.zeros.default
-        elif target == torch.ops.aten._local_scalar_dense.default:
-            input_type = node.args[0].meta["tensor_meta"].dtype
-            if input_type.is_floating_point:
-                node.target = target = torch.ops.aten.Float.Tensor
-            else:
-                node.target = target = torch.ops.aten.Int.Tensor
-            node.args = (node.args[0],)
-        elif target == torch.ops.aten._assert_async.msg:
-            # TODO: A more suitable op to replace it?
-            return
-        elif target == torch.ops.aten._unsafe_index_put.default:
-            node.target = target = torch.ops.aten._unsafe_index_put.hacked_twin
-        elif target == torch.ops.aten._embedding_bag_forward_only.default:
-            node.target = target = torch.ops.aten.embedding_bag.padding_idx
-            embedding_bag_args = [
-                ("scale_grad_by_freq", False),
-                ("mode", 0),
-                ("sparse", False),
-                ("per_sample_weights", None),
-                ("include_last_offset", False),
-                ("padding_idx", None),
-            ]
-            node_kwargs = dict(node.kwargs)
-            for k, v in embedding_bag_args[len(node.args) - 3 :]:
-                if k not in node_kwargs:
-                    node_kwargs[k] = v
-            node.kwargs = node_kwargs
+        if concrete_target is None:
+            node = node_canonicalize(node)
+            if not node:
+                return
+            target = node.target
+        else:
+            target = concrete_target

         schema = target._schema
         assert isinstance(schema, FunctionSchema)
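In this hunk the call site no longer supplies the target (see the first hunk); the node is first passed through node_canonicalize, which may rewrite it in place or return None to signal that the node should not be imported at all (for example aten._assert_async.msg). The optional concrete_target keyword lets a caller that has already resolved an overload skip canonicalization. A hedged sketch of the two call styles; importer, loc, node, and resolved_op are illustrative names, not part of the diff:

    # Sketch only: the normal path canonicalizes first and may drop the node.
    importer._import_torch_op_overload(loc, node)
    # A caller that already knows the concrete overload bypasses canonicalization.
    importer._import_torch_op_overload(loc, node, concrete_target=resolved_op)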
@@ -2401,3 +2360,97 @@ TENSOR_SCALAR_OP_CONVERTER = {
     "torch.aten.sub.Tensor": "torch.aten.sub.Scalar",
     "torch.aten.floor_divide": "torch.aten.floor_divide.Scalar",
 }
+
+
+NODE_CANONICALIZE: Dict[TorchOpOverload, Callable] = {}
+
+
+def register_canonicalize(op: TorchOpOverload):
+    def wrapper(func):
+        NODE_CANONICALIZE[op] = func
+        return func
+
+    return wrapper
+
+
+@register_canonicalize(torch.ops.aten.lift_fresh_copy.default)
+def lift_fresh_copy_default(node: torch_fx.Node):
+    # replace lift_fresh_copy with clone op
+    node.target = torch.ops.aten.clone.default
+    node.args = (node.args[0],)
+    node.kwargs = {"memory_format": None}
+    return node
+
+
+@register_canonicalize(torch.ops.aten.lift_fresh_copy.out)
+def lift_fresh_copy_out(node: torch_fx.Node):
+    # TODO: It seems not possible to hit this case from user code.
+    # Retaining in case if it is triggered internally somehow, but
+    # it can most likely be removed once assuming full
+    # functionalization in all cases.
+    node.target = target = torch.ops.aten.clone.out
+    node.args = (node.args[0],)
+    node.kwargs = {"memory_format": None, "out": node.args[1]}
+    return node
+
+
+@register_canonicalize(torch.ops.aten.empty.memory_format)
+def empty_memory_format(node: torch_fx.Node):
+    # TODO: generalize empty.memory_format in the future
+    # Currently, the aten.baddbmm.default op for Unet includes multiplying an
+    # empty.memory_format input with a constant, which creates NaN values
+    # because empty.memory_format contains uninitialized data. Converting
+    # aten.baddbmm.default -> aten.zeros.default fixes the correctness issue
+    if len(node.users) == 1:
+        for key_node in node.users:
+            if key_node.target == torch.ops.aten.baddbmm.default:
+                node.target = torch.ops.aten.zeros.default
+    return node
+
+
+@register_canonicalize(torch.ops.aten._local_scalar_dense.default)
+def aten__local_scalar_dense_default(node: torch_fx.Node):
+    input_type = node.args[0].meta["tensor_meta"].dtype
+    if input_type.is_floating_point:
+        node.target = torch.ops.aten.Float.Tensor
+    else:
+        node.target = torch.ops.aten.Int.Tensor
+    node.args = (node.args[0],)
+    return node
+
+
+@register_canonicalize(torch.ops.aten._assert_async.msg)
+def aten__assert_async_msg(node: torch_fx.Node):
+    # TODO: A more suitable op to replace it?
+    return None
+
+
+@register_canonicalize(torch.ops.aten._unsafe_index_put.default)
+def aten__unsafe_index_put_default(node: torch_fx.Node):
+    node.target = torch.ops.aten._unsafe_index_put.hacked_twin
+    return node
+
+
+@register_canonicalize(torch.ops.aten._embedding_bag_forward_only.default)
+def aten__embedding_bag_forward_only_default(node: torch_fx.Node):
+    node.target = torch.ops.aten.embedding_bag.padding_idx
+    embedding_bag_args = [
+        ("scale_grad_by_freq", False),
+        ("mode", 0),
+        ("sparse", False),
+        ("per_sample_weights", None),
+        ("include_last_offset", False),
+        ("padding_idx", None),
+    ]
+    node_kwargs = dict(node.kwargs)
+    for k, v in embedding_bag_args[len(node.args) - 3 :]:
+        if k not in node_kwargs:
+            node_kwargs[k] = v
+    node.kwargs = node_kwargs
+    return node
+
+
+def node_canonicalize(node: torch_fx.Node):
+    if node.target in NODE_CANONICALIZE:
+        return NODE_CANONICALIZE[node.target](node)
+    return node
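To see the table-driven pattern in isolation, here is a minimal, self-contained sketch of the same idea using toy names and plain dicts instead of torch_fx nodes; nothing below is part of the diff, it only mirrors the register/lookup structure introduced above.

    from typing import Callable, Dict, Optional

    CANON: Dict[str, Callable] = {}          # toy registry keyed by op name

    def register(op: str):
        def wrapper(func):
            CANON[op] = func                 # decorator side effect: populate the table
            return func
        return wrapper

    @register("lift_fresh_copy")
    def _lift_fresh_copy(node: dict) -> Optional[dict]:
        node["target"] = "clone"             # rewrite in place, keep the node
        return node

    @register("_assert_async")
    def _assert_async(node: dict) -> Optional[dict]:
        return None                          # drop the node entirely

    def canonicalize(node: dict) -> Optional[dict]:
        fn = CANON.get(node["target"])
        return fn(node) if fn else node      # unknown ops pass through unchanged

    # Usage: None means "skip this node", otherwise import node["target"] as usual.
    assert canonicalize({"target": "lift_fresh_copy"})["target"] == "clone"
    assert canonicalize({"target": "_assert_async"}) is None
    assert canonicalize({"target": "add"})["target"] == "add"

Adding a new canonicalization then becomes a single decorated function rather than another branch in the importer's if/elif cascade, which is the point of the refactor.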