mirror of https://github.com/llvm/torch-mlir
fximporter: support newer torch versions (#2999)
uses version checking since attributes exist in both versions, the only thing that changes is what we're receiving as an fx graphpull/3045/head
parent
6b3a7d07c2
commit
80c7bc3f7a
|
@ -220,6 +220,20 @@ PY_BUILTIN_TO_TORCH_OP = {
|
||||||
"gt": torch.ops.aten.gt,
|
"gt": torch.ops.aten.gt,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# torch with cuda has a __version__ that looks like "2.1.0+cu113",
|
||||||
|
# so split by + and 0 index will always give the base version
|
||||||
|
_IS_TORCH_2_1_OR_EARLIER = torch.__version__.split("+")[0] <= "2.1.0"
|
||||||
|
|
||||||
|
# The following are maps from symbolic ops to their non symbolic equivalents.
|
||||||
|
# In <=2.1.0, imported fx graphs come with a type inspecific torch.ops.aten.sym_size
|
||||||
|
# We identify it using the number of args in the node, 1 being default, 2 being int
|
||||||
|
# In the mapping below (torch.aten.sym_size, 2) indicates len(args)=2 therefore
|
||||||
|
# map to torch.aten.size.int.
|
||||||
|
# Thankfully, newer versions provide a specific torch.ops.aten.sym_size.<type>.
|
||||||
|
# Once we drop support for <2.1.0, we can get rid of the the SYMBOLIC_TORCH_OPS
|
||||||
|
# set and just check key existence in SYMBOLIC_OP_TO_TORCH_OP
|
||||||
|
|
||||||
|
if _IS_TORCH_2_1_OR_EARLIER:
|
||||||
SYMBOLIC_TORCH_OPS = {
|
SYMBOLIC_TORCH_OPS = {
|
||||||
torch.ops.aten.sym_size,
|
torch.ops.aten.sym_size,
|
||||||
torch.ops.aten.sym_stride,
|
torch.ops.aten.sym_stride,
|
||||||
|
@ -233,6 +247,20 @@ SYMBOLIC_OP_TO_TORCH_OP = {
|
||||||
(torch.ops.aten.sym_stride, 2): torch.ops.aten.stride.int,
|
(torch.ops.aten.sym_stride, 2): torch.ops.aten.stride.int,
|
||||||
(torch.ops.aten.sym_numel, 1): torch.ops.aten.numel.default,
|
(torch.ops.aten.sym_numel, 1): torch.ops.aten.numel.default,
|
||||||
}
|
}
|
||||||
|
else:
|
||||||
|
SYMBOLIC_TORCH_OPS = {
|
||||||
|
torch.ops.aten.sym_size.int,
|
||||||
|
torch.ops.aten.sym_stride.int,
|
||||||
|
torch.ops.aten.sym_numel.default,
|
||||||
|
}
|
||||||
|
|
||||||
|
SYMBOLIC_OP_TO_TORCH_OP = {
|
||||||
|
torch.ops.aten.sym_size.default: torch.ops.aten.size.default,
|
||||||
|
torch.ops.aten.sym_size.int: torch.ops.aten.size.int,
|
||||||
|
torch.ops.aten.sym_stride.default: torch.ops.aten.stride.default,
|
||||||
|
torch.ops.aten.sym_stride.int: torch.ops.aten.stride.int,
|
||||||
|
torch.ops.aten.sym_numel.default: torch.ops.aten.numel.default,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
|
@ -638,7 +666,9 @@ class FxImporter:
|
||||||
node_importer.return_node_values(loc, user_outputs)
|
node_importer.return_node_values(loc, user_outputs)
|
||||||
self.symbol_table.insert(func_op)
|
self.symbol_table.insert(func_op)
|
||||||
|
|
||||||
def import_frozen_program(self, prog: torch.export.ExportedProgram, func_name: str = "main"):
|
def import_frozen_program(
|
||||||
|
self, prog: torch.export.ExportedProgram, func_name: str = "main"
|
||||||
|
):
|
||||||
"""Imports a consolidated torch.export.ExportedProgram instance.
|
"""Imports a consolidated torch.export.ExportedProgram instance.
|
||||||
|
|
||||||
If using the new torch.export path (vs a lower level precursor), then this is
|
If using the new torch.export path (vs a lower level precursor), then this is
|
||||||
|
@ -1137,14 +1167,14 @@ class GraphNodeImporter:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"General getitem access to non-multi-result ops"
|
f"General getitem access to non-multi-result ops"
|
||||||
)
|
)
|
||||||
elif isinstance(target, TorchOpOverload):
|
|
||||||
# Dispatch to an ATen op.
|
|
||||||
self._import_torch_op_overload(loc, node, target)
|
|
||||||
elif target in SYMBOLIC_TORCH_OPS or (
|
elif target in SYMBOLIC_TORCH_OPS or (
|
||||||
is_symbolic(node.meta.get("val"))
|
is_symbolic(node.meta.get("val"))
|
||||||
and is_builtin_function_or_method(target)
|
and is_builtin_function_or_method(target)
|
||||||
):
|
):
|
||||||
self._import_symbolic_torch_op(loc, node, target)
|
self._import_symbolic_torch_op(loc, node, target)
|
||||||
|
elif isinstance(target, TorchOpOverload):
|
||||||
|
# Dispatch to an ATen op.
|
||||||
|
self._import_torch_op_overload(loc, node, target)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"FIX ME: Unimplemented call_function: target={node.target}, {node.meta}"
|
f"FIX ME: Unimplemented call_function: target={node.target}, {node.meta}"
|
||||||
|
@ -1227,7 +1257,10 @@ class GraphNodeImporter:
|
||||||
), f"Unsupported builtin function for symbolic types: {target} with args {node.args}"
|
), f"Unsupported builtin function for symbolic types: {target} with args {node.args}"
|
||||||
concrete_target = getattr(torch_op, op_overload)
|
concrete_target = getattr(torch_op, op_overload)
|
||||||
else:
|
else:
|
||||||
|
if _IS_TORCH_2_1_OR_EARLIER:
|
||||||
concrete_target = SYMBOLIC_OP_TO_TORCH_OP.get((target, len(node.args)))
|
concrete_target = SYMBOLIC_OP_TO_TORCH_OP.get((target, len(node.args)))
|
||||||
|
else:
|
||||||
|
concrete_target = SYMBOLIC_OP_TO_TORCH_OP.get(target)
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
concrete_target is not None
|
concrete_target is not None
|
||||||
|
@ -1628,8 +1661,7 @@ class TypeSubclassMap:
|
||||||
|
|
||||||
# Opaque value to indicate something is empty. Used in cases where 'None'
|
# Opaque value to indicate something is empty. Used in cases where 'None'
|
||||||
# may have a different meaning.
|
# may have a different meaning.
|
||||||
class EmptyType:
|
class EmptyType: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
Empty = EmptyType()
|
Empty = EmptyType()
|
||||||
|
|
Loading…
Reference in New Issue