fximporter: support newer torch versions (#2999)

Uses version checking, since the attributes exist in both versions; the only
thing that changes is what we receive as an fx graph.
Daniel Garvey 2024-03-08 14:58:50 -06:00 committed by GitHub
parent 6b3a7d07c2
commit 80c7bc3f7a
1 changed file with 51 additions and 19 deletions
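
The change keys off torch.__version__ rather than attribute probing because, as the commit message above notes, both the untyped overload packet and the typed overloads exist in every supported version; what differs is which object shows up as an fx node target. A minimal sketch of that distinction (stock PyTorch only, nothing here is torch-mlir code):

import torch

# Both attributes exist on any recent torch build; what differs is which one
# an exported fx graph uses as a node target.
packet = torch.ops.aten.sym_size        # overload packet; appears in graphs on torch <= 2.1
overload = torch.ops.aten.sym_size.int  # typed overload; appears in graphs on newer torch

print(type(packet).__name__)    # OpOverloadPacket
print(type(overload).__name__)  # OpOverload

# Mirrors the version gate used in the diff below: strip any "+cuXXX" suffix first.
base_version = torch.__version__.split("+")[0]
print(base_version <= "2.1.0")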


@@ -220,19 +220,47 @@ PY_BUILTIN_TO_TORCH_OP = {
"gt": torch.ops.aten.gt,
}
SYMBOLIC_TORCH_OPS = {
torch.ops.aten.sym_size,
torch.ops.aten.sym_stride,
torch.ops.aten.sym_numel,
}
# torch with cuda has a __version__ that looks like "2.1.0+cu113",
# so split by + and 0 index will always give the base version
_IS_TORCH_2_1_OR_EARLIER = torch.__version__.split("+")[0] <= "2.1.0"
SYMBOLIC_OP_TO_TORCH_OP = {
(torch.ops.aten.sym_size, 1): torch.ops.aten.size.default,
(torch.ops.aten.sym_size, 2): torch.ops.aten.size.int,
(torch.ops.aten.sym_stride, 1): torch.ops.aten.stride.default,
(torch.ops.aten.sym_stride, 2): torch.ops.aten.stride.int,
(torch.ops.aten.sym_numel, 1): torch.ops.aten.numel.default,
}
# The following are maps from symbolic ops to their non-symbolic equivalents.
# In <=2.1.0, imported fx graphs come with a type-unspecific torch.ops.aten.sym_size.
# We identify it using the number of args in the node, 1 being default, 2 being int.
# In the mapping below, (torch.ops.aten.sym_size, 2) indicates len(args)=2 and therefore
# maps to torch.ops.aten.size.int.
# Thankfully, newer versions provide a specific torch.ops.aten.sym_size.<type>.
# Once we drop support for <=2.1.0, we can get rid of the SYMBOLIC_TORCH_OPS
# set and just check key existence in SYMBOLIC_OP_TO_TORCH_OP.
if _IS_TORCH_2_1_OR_EARLIER:
    SYMBOLIC_TORCH_OPS = {
        torch.ops.aten.sym_size,
        torch.ops.aten.sym_stride,
        torch.ops.aten.sym_numel,
    }
    SYMBOLIC_OP_TO_TORCH_OP = {
        (torch.ops.aten.sym_size, 1): torch.ops.aten.size.default,
        (torch.ops.aten.sym_size, 2): torch.ops.aten.size.int,
        (torch.ops.aten.sym_stride, 1): torch.ops.aten.stride.default,
        (torch.ops.aten.sym_stride, 2): torch.ops.aten.stride.int,
        (torch.ops.aten.sym_numel, 1): torch.ops.aten.numel.default,
    }
else:
    SYMBOLIC_TORCH_OPS = {
        torch.ops.aten.sym_size.int,
        torch.ops.aten.sym_stride.int,
        torch.ops.aten.sym_numel.default,
    }
    SYMBOLIC_OP_TO_TORCH_OP = {
        torch.ops.aten.sym_size.default: torch.ops.aten.size.default,
        torch.ops.aten.sym_size.int: torch.ops.aten.size.int,
        torch.ops.aten.sym_stride.default: torch.ops.aten.stride.default,
        torch.ops.aten.sym_stride.int: torch.ops.aten.stride.int,
        torch.ops.aten.sym_numel.default: torch.ops.aten.numel.default,
    }
@dataclass(frozen=True)
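
To make the arity trick from the comment block concrete, here is a hypothetical helper (not part of the patch) showing how a symbolic fx node would be resolved under either branch. It assumes the _IS_TORCH_2_1_OR_EARLIER flag and SYMBOLIC_OP_TO_TORCH_OP table defined above are in scope, and it mirrors the lookup the patch performs later in _import_symbolic_torch_op:

import torch

def resolve_symbolic_target(node: "torch.fx.Node"):
    """Hypothetical helper: map a symbolic fx node target to its aten equivalent."""
    if _IS_TORCH_2_1_OR_EARLIER:
        # <=2.1.0: graphs carry the untyped packet, so arity disambiguates
        # e.g. aten.size.default (1 arg) from aten.size.int (2 args).
        return SYMBOLIC_OP_TO_TORCH_OP.get((node.target, len(node.args)))
    # Newer torch: the target is already a concrete overload and keys the map directly.
    return SYMBOLIC_OP_TO_TORCH_OP.get(node.target)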
@@ -638,7 +666,9 @@ class FxImporter:
        node_importer.return_node_values(loc, user_outputs)
        self.symbol_table.insert(func_op)

    def import_frozen_program(self, prog: torch.export.ExportedProgram, func_name: str = "main"):
    def import_frozen_program(
        self, prog: torch.export.ExportedProgram, func_name: str = "main"
    ):
        """Imports a consolidated torch.export.ExportedProgram instance.

        If using the new torch.export path (vs a lower level precursor), then this is
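
For context on how the reformatted import_frozen_program entry point is typically driven, a minimal usage sketch follows. The torch_mlir.extras.fx_importer module path matches this file, but the surrounding setup (creating an MLIR context, registering the torch dialect via torch_mlir.dialects.torch, and the module_op property) is assumed from the project's other examples rather than shown in this diff:

import torch
from torch_mlir import ir
from torch_mlir.dialects import torch as torch_d
from torch_mlir.extras.fx_importer import FxImporter

class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y

# torch.export produces the ExportedProgram consumed by import_frozen_program.
prog = torch.export.export(Add(), (torch.randn(3), torch.randn(3)))

context = ir.Context()
torch_d.register_dialect(context)  # assumed: importer expects the torch dialect registered

importer = FxImporter(context=context)
importer.import_frozen_program(prog, func_name="main")
print(importer.module_op)  # assumed property: the imported torch-dialect MLIR module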
@@ -1137,14 +1167,14 @@ class GraphNodeImporter:
                    raise NotImplementedError(
                        f"General getitem access to non-multi-result ops"
                    )
            elif isinstance(target, TorchOpOverload):
                # Dispatch to an ATen op.
                self._import_torch_op_overload(loc, node, target)
            elif target in SYMBOLIC_TORCH_OPS or (
                is_symbolic(node.meta.get("val"))
                and is_builtin_function_or_method(target)
            ):
                self._import_symbolic_torch_op(loc, node, target)
            elif isinstance(target, TorchOpOverload):
                # Dispatch to an ATen op.
                self._import_torch_op_overload(loc, node, target)
            else:
                raise NotImplementedError(
                    f"FIX ME: Unimplemented call_function: target={node.target}, {node.meta}"
@@ -1227,7 +1257,10 @@ class GraphNodeImporter:
            ), f"Unsupported builtin function for symbolic types: {target} with args {node.args}"
            concrete_target = getattr(torch_op, op_overload)
        else:
            concrete_target = SYMBOLIC_OP_TO_TORCH_OP.get((target, len(node.args)))
            if _IS_TORCH_2_1_OR_EARLIER:
                concrete_target = SYMBOLIC_OP_TO_TORCH_OP.get((target, len(node.args)))
            else:
                concrete_target = SYMBOLIC_OP_TO_TORCH_OP.get(target)

        assert (
            concrete_target is not None
@@ -1628,8 +1661,7 @@ class TypeSubclassMap:
# Opaque value to indicate something is empty. Used in cases where 'None'
# may have a different meaning.
class EmptyType:
    ...

class EmptyType: ...
Empty = EmptyType()