mirror of https://github.com/llvm/torch-mlir
fximporter: support newer torch versions (#2999)
uses version checking since attributes exist in both versions, the only thing that changes is what we're receiving as an fx graphpull/3045/head
parent
6b3a7d07c2
commit
80c7bc3f7a
|
@ -220,19 +220,47 @@ PY_BUILTIN_TO_TORCH_OP = {
|
|||
"gt": torch.ops.aten.gt,
|
||||
}
|
||||
|
||||
SYMBOLIC_TORCH_OPS = {
|
||||
torch.ops.aten.sym_size,
|
||||
torch.ops.aten.sym_stride,
|
||||
torch.ops.aten.sym_numel,
|
||||
}
|
||||
# torch with cuda has a __version__ that looks like "2.1.0+cu113",
|
||||
# so split by + and 0 index will always give the base version
|
||||
_IS_TORCH_2_1_OR_EARLIER = torch.__version__.split("+")[0] <= "2.1.0"
|
||||
|
||||
SYMBOLIC_OP_TO_TORCH_OP = {
|
||||
(torch.ops.aten.sym_size, 1): torch.ops.aten.size.default,
|
||||
(torch.ops.aten.sym_size, 2): torch.ops.aten.size.int,
|
||||
(torch.ops.aten.sym_stride, 1): torch.ops.aten.stride.default,
|
||||
(torch.ops.aten.sym_stride, 2): torch.ops.aten.stride.int,
|
||||
(torch.ops.aten.sym_numel, 1): torch.ops.aten.numel.default,
|
||||
}
|
||||
# The following are maps from symbolic ops to their non symbolic equivalents.
|
||||
# In <=2.1.0, imported fx graphs come with a type inspecific torch.ops.aten.sym_size
|
||||
# We identify it using the number of args in the node, 1 being default, 2 being int
|
||||
# In the mapping below (torch.aten.sym_size, 2) indicates len(args)=2 therefore
|
||||
# map to torch.aten.size.int.
|
||||
# Thankfully, newer versions provide a specific torch.ops.aten.sym_size.<type>.
|
||||
# Once we drop support for <2.1.0, we can get rid of the the SYMBOLIC_TORCH_OPS
|
||||
# set and just check key existence in SYMBOLIC_OP_TO_TORCH_OP
|
||||
|
||||
if _IS_TORCH_2_1_OR_EARLIER:
|
||||
SYMBOLIC_TORCH_OPS = {
|
||||
torch.ops.aten.sym_size,
|
||||
torch.ops.aten.sym_stride,
|
||||
torch.ops.aten.sym_numel,
|
||||
}
|
||||
|
||||
SYMBOLIC_OP_TO_TORCH_OP = {
|
||||
(torch.ops.aten.sym_size, 1): torch.ops.aten.size.default,
|
||||
(torch.ops.aten.sym_size, 2): torch.ops.aten.size.int,
|
||||
(torch.ops.aten.sym_stride, 1): torch.ops.aten.stride.default,
|
||||
(torch.ops.aten.sym_stride, 2): torch.ops.aten.stride.int,
|
||||
(torch.ops.aten.sym_numel, 1): torch.ops.aten.numel.default,
|
||||
}
|
||||
else:
|
||||
SYMBOLIC_TORCH_OPS = {
|
||||
torch.ops.aten.sym_size.int,
|
||||
torch.ops.aten.sym_stride.int,
|
||||
torch.ops.aten.sym_numel.default,
|
||||
}
|
||||
|
||||
SYMBOLIC_OP_TO_TORCH_OP = {
|
||||
torch.ops.aten.sym_size.default: torch.ops.aten.size.default,
|
||||
torch.ops.aten.sym_size.int: torch.ops.aten.size.int,
|
||||
torch.ops.aten.sym_stride.default: torch.ops.aten.stride.default,
|
||||
torch.ops.aten.sym_stride.int: torch.ops.aten.stride.int,
|
||||
torch.ops.aten.sym_numel.default: torch.ops.aten.numel.default,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
@ -638,7 +666,9 @@ class FxImporter:
|
|||
node_importer.return_node_values(loc, user_outputs)
|
||||
self.symbol_table.insert(func_op)
|
||||
|
||||
def import_frozen_program(self, prog: torch.export.ExportedProgram, func_name: str = "main"):
|
||||
def import_frozen_program(
|
||||
self, prog: torch.export.ExportedProgram, func_name: str = "main"
|
||||
):
|
||||
"""Imports a consolidated torch.export.ExportedProgram instance.
|
||||
|
||||
If using the new torch.export path (vs a lower level precursor), then this is
|
||||
|
@ -1137,14 +1167,14 @@ class GraphNodeImporter:
|
|||
raise NotImplementedError(
|
||||
f"General getitem access to non-multi-result ops"
|
||||
)
|
||||
elif isinstance(target, TorchOpOverload):
|
||||
# Dispatch to an ATen op.
|
||||
self._import_torch_op_overload(loc, node, target)
|
||||
elif target in SYMBOLIC_TORCH_OPS or (
|
||||
is_symbolic(node.meta.get("val"))
|
||||
and is_builtin_function_or_method(target)
|
||||
):
|
||||
self._import_symbolic_torch_op(loc, node, target)
|
||||
elif isinstance(target, TorchOpOverload):
|
||||
# Dispatch to an ATen op.
|
||||
self._import_torch_op_overload(loc, node, target)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"FIX ME: Unimplemented call_function: target={node.target}, {node.meta}"
|
||||
|
@ -1227,7 +1257,10 @@ class GraphNodeImporter:
|
|||
), f"Unsupported builtin function for symbolic types: {target} with args {node.args}"
|
||||
concrete_target = getattr(torch_op, op_overload)
|
||||
else:
|
||||
concrete_target = SYMBOLIC_OP_TO_TORCH_OP.get((target, len(node.args)))
|
||||
if _IS_TORCH_2_1_OR_EARLIER:
|
||||
concrete_target = SYMBOLIC_OP_TO_TORCH_OP.get((target, len(node.args)))
|
||||
else:
|
||||
concrete_target = SYMBOLIC_OP_TO_TORCH_OP.get(target)
|
||||
|
||||
assert (
|
||||
concrete_target is not None
|
||||
|
@ -1628,8 +1661,7 @@ class TypeSubclassMap:
|
|||
|
||||
# Opaque value to indicate something is empty. Used in cases where 'None'
|
||||
# may have a different meaning.
|
||||
class EmptyType:
|
||||
...
|
||||
class EmptyType: ...
|
||||
|
||||
|
||||
Empty = EmptyType()
|
||||
|
|
Loading…
Reference in New Issue