2023-12-22 00:40:10 +08:00
|
|
|
# Copyright 2023 Advanced Micro Devices, Inc
|
|
|
|
#
|
|
|
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
# See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
# Also available under a BSD-style license. See LICENSE.
|
|
|
|
|
2024-01-26 09:01:47 +08:00
|
|
|
try:
|
|
|
|
from types import NoneType
|
|
|
|
except ImportError:
|
|
|
|
# python less than 3.10 doesn't have NoneType
|
|
|
|
NoneType = type(None)
|
|
|
|
|
2023-12-22 00:40:10 +08:00
|
|
|
import logging
|
|
|
|
import operator
|
|
|
|
import re
|
2024-02-13 08:10:57 +08:00
|
|
|
from dataclasses import dataclass
|
2024-01-26 09:01:47 +08:00
|
|
|
from types import BuiltinMethodType, BuiltinFunctionType
|
2023-12-22 00:40:10 +08:00
|
|
|
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
|
|
|
|
import weakref
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.export
|
|
|
|
import torch.fx as torch_fx
|
|
|
|
from torch.fx.passes.shape_prop import TensorMetadata
|
|
|
|
|
|
|
|
from torch import (
|
|
|
|
dtype as TorchDtype,
|
|
|
|
FunctionSchema,
|
|
|
|
)
|
|
|
|
|
|
|
|
from torch._ops import (
|
|
|
|
OpOverload as TorchOpOverload,
|
|
|
|
)
|
|
|
|
|
|
|
|
from torch._subclasses import (
|
|
|
|
FakeTensor as TorchFakeTensor,
|
|
|
|
)
|
|
|
|
|
|
|
|
from torch.fx import (
|
|
|
|
Graph,
|
|
|
|
GraphModule,
|
|
|
|
)
|
|
|
|
|
|
|
|
from torch.fx.node import (
|
|
|
|
Argument as NodeArgument,
|
|
|
|
)
|
|
|
|
|
|
|
|
from ..ir import (
|
|
|
|
Attribute,
|
|
|
|
Block,
|
|
|
|
Context,
|
2024-02-06 14:19:31 +08:00
|
|
|
DenseElementsAttr,
|
2023-12-22 00:40:10 +08:00
|
|
|
DenseResourceElementsAttr,
|
|
|
|
FloatAttr,
|
|
|
|
BF16Type,
|
|
|
|
ComplexType,
|
|
|
|
F16Type,
|
|
|
|
F32Type,
|
|
|
|
F64Type,
|
|
|
|
FunctionType,
|
|
|
|
InsertionPoint,
|
|
|
|
IntegerAttr,
|
|
|
|
IntegerType,
|
|
|
|
RankedTensorType,
|
|
|
|
Location,
|
|
|
|
Module,
|
|
|
|
Operation,
|
|
|
|
StringAttr,
|
|
|
|
SymbolTable,
|
|
|
|
Type as IrType,
|
|
|
|
Value,
|
|
|
|
)
|
|
|
|
|
|
|
|
from ..dialects import (
|
|
|
|
func as func_dialect,
|
|
|
|
)
|
|
|
|
|
|
|
|
# Public API of this module.
__all__ = ["FxImporter"]

# Hook invoked with (python_value, graph_node_importer). It may emit IR that
# materializes the value as a vtensor and return the resulting Value; returning
# None defers to the default literal resolution.
LiteralResolverCallback = Callable[[Any, "GraphNodeImporter"], Optional[Value]]
|
|
|
|
|
|
|
|
# Dialects that FxImporter._config_check requires to be registered in the
# MLIR context before importing.
REQUIRED_DIALCTS = ["builtin", "func", "torch"]
|
|
|
|
|
|
|
|
# Textual MLIR element-type spelling for each supported torch dtype, used when
# building `!torch.vtensor<...>` type strings. Integer types carry explicit
# signedness (si*/ui*), bool is the signless `i1`, and the quantized dtypes are
# spelled as torch dialect types.
TORCH_DTYPE_TO_MLIR_TYPE_ASM = dict(
    [
        # Floating point.
        (torch.float16, "f16"),
        (torch.bfloat16, "bf16"),
        (torch.float32, "f32"),
        (torch.float64, "f64"),
        # Integers (explicitly signed / unsigned).
        (torch.uint8, "ui8"),
        (torch.int8, "si8"),
        (torch.int16, "si16"),
        (torch.int32, "si32"),
        (torch.int64, "si64"),
        # Boolean.
        (torch.bool, "i1"),
        # Quantized.
        (torch.qint8, "!torch.qint8"),
        (torch.quint8, "!torch.quint8"),
        # Complex.
        (torch.complex32, "complex<f16>"),
        (torch.complex64, "complex<f32>"),
        (torch.complex128, "complex<f64>"),
    ]
)
|
|
|
|
|
|
|
|
# Per-dtype factories producing the MLIR element type object. Each value is a
# zero-argument callable, so no MLIR type is constructed at import time; the
# type is only built when the factory is invoked. Quantized dtypes map to their
# underlying 8-bit storage integer type.
TORCH_DTYPE_TO_MLIR_TYPE: Dict[torch.dtype, Callable[[], IrType]] = dict(
    [
        (torch.float16, lambda: F16Type.get()),
        (torch.bfloat16, lambda: BF16Type.get()),
        (torch.float32, lambda: F32Type.get()),
        (torch.float64, lambda: F64Type.get()),
        (torch.uint8, lambda: IntegerType.get_unsigned(8)),
        (torch.int8, lambda: IntegerType.get_signed(8)),
        (torch.int16, lambda: IntegerType.get_signed(16)),
        (torch.int32, lambda: IntegerType.get_signed(32)),
        (torch.int64, lambda: IntegerType.get_signed(64)),
        (torch.bool, lambda: IntegerType.get_signless(1)),
        (torch.qint8, lambda: IntegerType.get_signed(8)),
        (torch.quint8, lambda: IntegerType.get_unsigned(8)),
        (torch.complex32, lambda: ComplexType.get(F16Type.get())),
        (torch.complex64, lambda: ComplexType.get(F32Type.get())),
        (torch.complex128, lambda: ComplexType.get(F64Type.get())),
    ]
)
|
|
|
|
|
|
|
|
# torch dtype -> numpy scalar type, for dtypes that have a numpy equivalent.
# Intentionally absent (no matching numpy dtype, hence unsupported here):
#   torch.qint8, torch.quint8  - quantized, no numpy equivalent
#   torch.bfloat16             - no numpy equivalent
#   torch.complex32            - numpy has no half-precision complex type
TORCH_DTYPE_TO_NPY_TYPE = dict(
    [
        (torch.uint8, np.uint8),
        (torch.int8, np.int8),
        (torch.int16, np.int16),
        (torch.int32, np.int32),
        (torch.int64, np.int64),
        (torch.float16, np.float16),
        (torch.float32, np.float32),
        (torch.float64, np.float64),
        (torch.bool, np.bool_),
        (torch.complex64, np.complex64),
        (torch.complex128, np.complex128),
    ]
)
|
|
|
|
|
|
|
|
# Integer codes for torch dtypes. These must match the numbering of
# c10::ScalarType (c10/core/ScalarType.h):
#   Byte=0, Char=1, Short=2, Int=3, Long=4, Half=5, Float=6, Double=7,
#   ComplexHalf=8, ComplexFloat=9, ComplexDouble=10, Bool=11,
#   QInt8=12, QUInt8=13, QInt32=14, BFloat16=15.
TORCH_DTYPE_TO_INT = {
    torch.uint8: 0,
    torch.int8: 1,
    torch.int16: 2,
    torch.int32: 3,
    torch.int64: 4,
    torch.float16: 5,
    torch.float32: 6,
    torch.float64: 7,
    # torch.complex32 is ComplexHalf, the complex dtype with code 8; the
    # previous table was shifted by one here, colliding with complex128.
    torch.complex32: 8,
    torch.complex64: 9,
    torch.complex128: 10,
    torch.bool: 11,
    # torch.qint8: 12, # quantized dtypes are not supported in all backends, currently we do not support them
    # torch.quint8: 13,
    # torch.qint32 14
    torch.bfloat16: 15,
}

# Integer codes matching c10::MemoryFormat (c10/core/MemoryFormat.h).
TORCH_MEMORY_FORMAT_TO_INT = {
    torch.contiguous_format: 0,
    torch.preserve_format: 1,
    torch.channels_last: 2,
    torch.channels_last_3d: 3,
}

# Integer codes matching c10::Layout (c10/core/Layout.h):
#   Strided=0, Sparse=1, SparseCsr=2, Mkldnn=3, SparseCsc=4, SparseBsr=5,
#   SparseBsc=6.
# Mkldnn (3) sits between SparseCsr and SparseCsc in that enum and is not
# mapped here, so the CSC/BSR/BSC codes start at 4.
TORCH_LAYOUT_TO_INT = {
    torch.strided: 0,
    torch.sparse_coo: 1,
    torch.sparse_csr: 2,
    torch.sparse_csc: 4,
    torch.sparse_bsr: 5,
    torch.sparse_bsc: 6,
}
|
|
|
|
|
|
|
|
# Python builtin operator name (e.g. `operator.add.__name__`) -> ATen overload
# packet. Dynamo records arithmetic/comparisons on symbolic scalars as plain
# Python builtins; the concrete overload is picked later from argument types.
_aten = torch.ops.aten
PY_BUILTIN_TO_TORCH_OP = dict(
    truediv=_aten.div,
    mul=_aten.mul,
    add=_aten.add,
    sub=_aten.sub,
    lt=_aten.lt,
    le=_aten.le,
    ge=_aten.ge,
    ne=_aten.ne,
    gt=_aten.gt,
)

# Symbolic shape query ops that are handled by the symbolic-op import path.
SYMBOLIC_TORCH_OPS = {_aten.sym_size, _aten.sym_stride, _aten.sym_numel}

# (symbolic op packet, number of node args) -> concrete ATen overload.
SYMBOLIC_OP_TO_TORCH_OP = {
    (_aten.sym_size, 1): _aten.size.default,
    (_aten.sym_size, 2): _aten.size.int,
    (_aten.sym_stride, 1): _aten.stride.default,
    (_aten.sym_stride, 2): _aten.stride.int,
    (_aten.sym_numel, 1): _aten.numel.default,
}
del _aten
|
|
|
|
|
|
|
|
|
2024-02-13 08:10:57 +08:00
|
|
|
@dataclass(frozen=True)
class SparsityMeta:
    """Immutable description of a sparse tensor's layout.

    The dimensions of a sparse tensor are grouped as
    ``[<batch dims>, <sparse dims>, <dense dims>]``; ``batch_dim``,
    ``sparse_dim`` and ``dense_dim`` give the size of each group. Instances
    are frozen (and therefore hashable), so they can be used inside cache keys.
    """

    # Torch sparse layout, e.g. torch.sparse_coo / torch.sparse_csr / torch.sparse_csc.
    layout: torch.layout
    # Number of leading batch dimensions.
    batch_dim: int
    # Number of sparse dimensions that follow the batch dimensions.
    sparse_dim: int
    # Number of trailing dense dimensions.
    dense_dim: int
    # Position width, emitted as `posWidth` in the sparse tensor encoding.
    pos_width: int
    # Coordinate width, emitted as `crdWidth` in the sparse tensor encoding.
    crd_width: int
|
|
|
|
|
|
|
|
|
|
|
|
def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
    """Returns the sparse tensor encoding for the given sparse layout as a string.

    Args:
      shape: The tensor shape. Its rank must equal
        ``batch_dim + sparse_dim + dense_dim`` from ``sparsity``.
      sparsity: Sparsity metadata describing the layout and the
        position/coordinate widths.

    Returns:
      A ``#sparse_tensor.encoding<{map=(...)->(...),posWidth=..,crdWidth=..}>``
      attribute string.

    Raises:
      RuntimeError: If the layout is not COO, CSR or CSC.
    """
    assert sparsity is not None

    # Sparse tensors have the form
    #   [ <batch_dimensions> , <sparse_dimensions>, <dense_dimensions> ]
    # which map directly to MLIR types.
    batch_dim, sparse_dim, dense_dim = (
        sparsity.batch_dim,
        sparsity.sparse_dim,
        sparsity.dense_dim,
    )
    dim = batch_dim + sparse_dim + dense_dim
    assert dim == len(shape)

    dims = ",".join(f"d{d}" for d in range(dim))

    if sparsity.layout is torch.sparse_coo:
        assert sparse_dim == 2  # TODO: deeper sparse dims
        lvls = f"d{batch_dim}:compressed(nonunique),d{batch_dim+1}:singleton"
    elif sparsity.layout is torch.sparse_csr:
        assert sparse_dim == 2
        lvls = f"d{batch_dim}:dense,d{batch_dim+1}:compressed"
    elif sparsity.layout is torch.sparse_csc:
        assert sparse_dim == 2
        # Column-major: the second sparse dimension is the outer (dense) level.
        lvls = f"d{batch_dim+1}:dense,d{batch_dim}:compressed"
    else:
        # TODO: block format (derive block size!)
        # Report the actual layout; the previous code referenced an undefined
        # name here and raised NameError instead.
        raise RuntimeError(f"Unsupported sparse layout {sparsity.layout}")

    # Batch dimensions are leading dense levels.
    if batch_dim > 0:
        batch = ",".join(f"d{d}:dense" for d in range(batch_dim))
        lvls = f"{batch},{lvls}"

    # Dense dimensions are trailing dense levels.
    if dense_dim > 0:
        dense = ",".join(f"d{d}:dense" for d in range(batch_dim + sparse_dim, dim))
        lvls = f"{lvls},{dense}"

    posw, crdw = sparsity.pos_width, sparsity.crd_width
    return f"#sparse_tensor.encoding<{{map=({dims})->({lvls}),posWidth={posw},crdWidth={crdw}}}>"
|
2023-12-22 00:40:10 +08:00
|
|
|
|
|
|
|
|
|
|
|
def is_symbolic(obj: Any) -> bool:
    """Return True if *obj* is a torch symbolic scalar (SymInt, SymFloat or SymBool)."""
    symbolic_scalar_types = (torch.SymInt, torch.SymFloat, torch.SymBool)
    return isinstance(obj, symbolic_scalar_types)
|
|
|
|
|
|
|
|
|
|
|
|
def is_builtin_function_or_method(obj: Any) -> bool:
    """Return True if *obj* is a C-implemented builtin function or bound builtin method."""
    builtin_callable_types = (BuiltinMethodType, BuiltinFunctionType)
    return isinstance(obj, builtin_callable_types)
|
|
|
|
|
|
|
|
|
|
|
|
class FxImporter:
    """Main entry-point for importing an fx.GraphModule.

    The FxImporter is a low-level class intended for framework integrators.
    It provides several options for customization:

    * config_check: Optionally allows some per-import configuration safety
      checks to be skipped.
    * literal_resolver_callback: Callback that will be invoked when a literal,
      live torch.Tensor is encountered in the FX graph, allowing the default
      action (which is to inline the data as a DenseResourceElementsAttr) to
      be completely overridden.
    * py_attr_tracker: Weak reference tracker for live PyTorch objects used
      to unique them with respect to attributes. If not specified, there will
      be one reference tracker per import, but this can be injected to share
      the same uniqueing across imports (i.e. if building multiple functions
      into the same context or module).
    """

    __slots__ = [
        "_c",
        "_cc",
        "_literal_resolver_callback",
        "_m",
        "_m_ip",
        "_py_attr_tracker",
        "symbol_table",
    ]

    def __init__(
        self,
        *,
        module: Optional[Module] = None,
        context: Optional[Context] = None,
        config_check: bool = True,
        literal_resolver_callback: Optional[LiteralResolverCallback] = None,
        py_attr_tracker: Optional["RefTracker"] = None,
    ):
        """Creates an importer targeting *module* or a fresh module.

        Args:
          module: Existing module to import into. When given, *context* must be
            None and the module's own context is used.
          context: Context for a newly created module; a new Context is made
            when omitted.
          config_check: When True, verify that required dialects are registered.
          literal_resolver_callback: Optional hook for materializing literals.
          py_attr_tracker: Optional shared RefTracker; a new one is created
            when omitted.
        """
        if module is not None:
            assert context is None, "If configuring with a Module, context must be None"
            self._m = module
            self._c = module.context
        else:
            self._c = context if context else Context()
            self._m = Module.create(Location.unknown(self._c))
        if config_check:
            # Production code can disable this for a bit of a boost.
            self._config_check()
        self._py_attr_tracker = py_attr_tracker or RefTracker()
        self._cc = ContextCache(self._c, py_attr_tracker=self._py_attr_tracker)
        # New functions are appended at the end of the module body.
        self._m_ip = InsertionPoint(self._m.body)
        self._literal_resolver_callback = literal_resolver_callback
        self.symbol_table = SymbolTable(self._m.operation)

    def _config_check(self):
        """Verifies that each dialect in REQUIRED_DIALCTS is registered in the context.

        Raises:
          RuntimeError: If a required dialect is missing.
        """
        for dname in REQUIRED_DIALCTS:
            try:
                # The dialect lookup is expected to raise IndexError when the
                # dialect is not registered.
                self._c.dialects[dname]
                logging.debug("Context has registered dialect '%s'", dname)
            except IndexError as e:
                raise RuntimeError(
                    f"The MLIR context {self._c} is missing required dialect '{dname}'"
                ) from e

    @property
    def module(self) -> Module:
        """The MLIR module being populated."""
        return self._m

    @property
    def module_op(self) -> Operation:
        """The operation backing the MLIR module."""
        return self._m.operation

    def import_frozen_exported_program(self, prog: torch.export.ExportedProgram):
        """Imports a consolidated torch.export.ExportedProgram instance.

        If using the new torch.export path (vs a lower level precursor), then this is
        the recommended way to canonically use this importer.

        The ExportedProgram form differs from some of the earlier work primarily in
        how it deals with references to external tensors from "outside". In this form,
        all such references are checked to have originated from within the exported
        scope or from an @assume_constant_result wrapped function. Then they are
        transformed to graph inputs and stashed in one of two data structures on
        the ExportedProgram:
          inputs_to_buffers / buffers : For non-parameter buffers.
          inputs_to_parameters / parameters : For parameter buffers.
        The values of the mapping in inputs_to_{buffers|parameters} are in the
        state_dict. This replaces get_attr nodes that would have classically been
        present during lower level tracing.
        Newer torch versions additionally keep lifted tensor constants in
        `prog.constants` (inputs_to_lifted_tensor_constants).

        Historically, torch-mlir has assumed that all such external accesses are
        frozen, and this entry-point preserves this behavior, treating each distinct
        torch.Tensor encountered in such a way as a `torch.vtensor.literal` (or
        delegating to the literal_resolver_callback to make a policy decision).

        As we anticipate more nuanced treatment options in the future, we name this
        method to indicate that it is producing "frozen" modules. Additional top-level
        approaches to handling state can be introduced later as an addition.

        Note: the placeholders that are lifted are erased from ``prog.graph`` in
        place.

        Raises:
          AssertionError: If a lifted input has no backing state value.
        """
        sig = prog.graph_signature
        state_dict = prog.state_dict
        arg_replacements: Dict[str, Any] = {}

        # Relevant upstream patch: https://github.com/pytorch/pytorch/pull/118969
        # Versions with a "constants" attribute keep lifted tensor constants
        # there; older versions have no such attribute.
        constants = prog.constants if hasattr(prog, "constants") else None

        # Lift tensor constants.
        if constants is not None:
            for input_name, state_name in sig.inputs_to_lifted_tensor_constants.items():
                try:
                    state_value = constants[state_name]
                except KeyError as e:
                    raise AssertionError(
                        "Could not find state mapping for tensor constants"
                    ) from e
                arg_replacements[input_name] = state_value

        # Lift buffers. This must happen whether or not "constants" exists:
        # previously buffers were skipped when it did, leaving their
        # placeholders behind as extra function arguments. Buffers are looked
        # up in the state_dict first, then in "constants" (where
        # non-persistent buffers may be stored - TODO confirm per torch version).
        for input_name, state_name in sig.inputs_to_buffers.items():
            if state_name in state_dict:
                state_value = state_dict[state_name]
            elif constants is not None and state_name in constants:
                state_value = constants[state_name]
            else:
                raise AssertionError("Could not find state mapping for buffer")
            arg_replacements[input_name] = state_value

        # Lift parameters.
        for input_name, state_name in sig.inputs_to_parameters.items():
            try:
                state_value = state_dict[state_name]
            except KeyError as e:
                raise AssertionError(
                    "Could not find state mapping for parameter"
                ) from e
            arg_replacements[input_name] = state_value

        # Remove any lifted placeholders, replacing their uses with the state
        # replacement value. Iterate over a snapshot since nodes are erased.
        g = prog.graph
        for node in list(g.nodes):
            if node.op != "placeholder" or node.name not in arg_replacements:
                continue
            node.replace_all_uses_with(arg_replacements[node.name])
            g.erase_node(node)

        self.import_stateless_graph(g)

    def import_graph_module(self, gm: GraphModule):
        """Low-level import of a GraphModule assuming that it has been functionalized."""
        self.import_stateless_graph(gm.graph)

    def import_stateless_graph(self, g: Graph, func_name: str = "main"):
        """Low-level import of a functionalized, assumed stateless Graph as a func.

        The new func is appended to the module and registered in the module's
        symbol table.
        """
        ftype, loc = self._graph_to_function_meta(g)
        # TODO: The FuncOp constructor requires a context-manager context.
        # Fix upstream and then unnest.
        # See: https://github.com/nod-ai/SHARK-Turbine/issues/138
        with loc:
            func = func_dialect.FuncOp(
                func_name,
                ftype,
                ip=self._m_ip,
            )
            entry_block = Block.create_at_start(func.body, ftype.inputs)
        node_importer = GraphNodeImporter(
            self,
            self._c,
            self._cc,
            entry_block,
            literal_resolver_callback=self._literal_resolver_callback,
        )
        node_importer.import_nodes(g.nodes)
        self.symbol_table.insert(func)

    def _graph_to_function_meta(self, g: Graph) -> Tuple[FunctionType, Location]:
        """Extracts function metadata from the Graph.

        Principally, this includes the FunctionType, but in the future,
        it should also return other annotations (input strides, etc) that
        affect compilation and should be included as arg attrs.
        """
        input_types = []
        result_types = []
        loc = None
        for node in g.nodes:
            # Assume that the first node we can get a location for is about as
            # good as it gets as an overall function location.
            if loc is None:
                loc = self._cc.get_node_location(node)
            if node.op == "placeholder":
                input_types.append(self._cc.node_val_to_type(node))
            elif node.op == "output":
                # An output node's args[0] is the return value. This seems to
                # always be "boxed" as a tuple, which we emit as multi-results.
                for result_node in node.args[0]:
                    if result_node is None:
                        result_types.append(
                            IrType.parse("!torch.none", context=self._c)
                        )
                    else:
                        result_types.append(self._cc.node_val_to_type(result_node))
        return (
            FunctionType.get(input_types, result_types, context=self._c),
            loc if loc else Location.unknown(self._c),
        )
|
|
|
|
|
|
|
|
|
|
|
|
class ContextCache:
    """Caches per-context lookups of various things that we ask for repeatedly."""

    __slots__ = [
        "_c",
        "_dtype_to_type",
        "_tensor_metadata_cache",
        "_py_attr_tracker",
        # Types.
        "torch_bool_type",
        "torch_float_type",
        "torch_int_type",
        "torch_none_type",
        "torch_str_type",
        "torch_device_type",
    ]

    def __init__(
        self, context: Context, *, py_attr_tracker: Optional["RefTracker"] = None
    ):
        self._c = context
        self._dtype_to_type: Dict[TorchDtype, IrType] = {}
        # Keyed by (shape with symbolic dims replaced by their nodes, dtype,
        # sparsity metadata).
        self._tensor_metadata_cache: Dict[
            Tuple[Tuple[Any, ...], TorchDtype, Optional[SparsityMeta]], IrType
        ] = {}
        self._py_attr_tracker = py_attr_tracker or RefTracker()

        # Common types.
        with context:
            self.torch_bool_type = IrType.parse("!torch.bool")
            self.torch_float_type = IrType.parse("!torch.float")
            self.torch_int_type = IrType.parse("!torch.int")
            self.torch_none_type = IrType.parse("!torch.none")
            self.torch_str_type = IrType.parse("!torch.str")
            self.torch_device_type = IrType.parse("!torch.Device")

    def integer_attr(self, value: int, bits: int) -> Attribute:
        """Returns a signless integer attribute of width *bits* holding *value*."""
        c = self._c
        return IntegerAttr.get(IntegerType.get_signless(bits, c), value)

    def format_asm_shape(self, shape: torch.Size) -> str:
        """Formats *shape* as comma-separated dims, rendering symbolic dims as `?`."""
        return ",".join("?" if is_symbolic(d) else str(d) for d in shape)

    def get_vtensor_type(
        self,
        shape: torch.Size,
        dtype: torch.dtype,
        *,
        sparsity: Optional[SparsityMeta] = None,  # keyword-only
    ) -> IrType:
        """Return IrType for !torch.vtensor with the given shape and dtype.

        When *sparsity* is given, the corresponding `#sparse_tensor.encoding`
        is appended as the tensor encoding.
        """
        shape_asm = self.format_asm_shape(shape)
        mlir_dtype = str(self.dtype_to_type(dtype))
        if sparsity is not None:
            encoding = sparsity_encoding(shape, sparsity)
            return IrType.parse(
                f"!torch.vtensor<[{shape_asm}],{mlir_dtype},{encoding}>",
                context=self._c,
            )
        return IrType.parse(
            f"!torch.vtensor<[{shape_asm}],{mlir_dtype}>", context=self._c
        )

    def node_val_to_type(self, node: torch_fx.Node) -> IrType:
        """Returns the MLIR type of the value produced by an fx node.

        `tensor_meta` metadata is preferred; otherwise `val` is used (a
        FakeTensor or a Python scalar). An optional `sparsity` meta entry is
        forwarded to tensor types.

        Raises:
          NotImplementedError: For quantized tensor metadata, or when the node
            carries no supported metadata.
          RuntimeError: If accessing the node metadata raises KeyError.
        """
        try:
            tensor_meta = node.meta.get("tensor_meta")
            val = node.meta.get("val")
            sparsity = node.meta.get("sparsity", None)
            if tensor_meta is not None:
                assert isinstance(tensor_meta, TensorMetadata)
                # Quantized tensor meta data is not preserved in our lowering,
                # so throw error instead of silently doing wrong thing.
                if tensor_meta.is_quantized:
                    raise NotImplementedError(
                        "Quantized tensor meta data is not supported."
                    )
                return self.tensor_metadata_to_type(tensor_meta, sparsity=sparsity)
            elif val is not None:
                # some nodes with symbolic inputs pass a 'val' attribute rather than
                # tensor_meta
                if isinstance(val, TorchFakeTensor):
                    return self.get_vtensor_type(
                        val.size(), val.dtype, sparsity=sparsity
                    )

                t = SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get(type(val))
                if t is not None:
                    return IrType.parse(t, self._c)

            raise NotImplementedError(
                f"FIXME: Unsupported placeholder node (this often indicates that a necessary "
                f"fx preprocessing pass was not run): {node.meta}"
            )
        except KeyError as e:
            raise RuntimeError(
                f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})"
            )

    def tensor_metadata_to_type(
        self,
        tm: TensorMetadata,
        *,
        sparsity: Optional[SparsityMeta] = None,  # keyword-only
    ) -> IrType:
        """Returns (and caches) the vtensor type described by *tm* and *sparsity*."""
        # Symbolic dims are keyed by their underlying node so that the cache key
        # is hashable and distinguishes distinct symbols.
        tm_shape = tuple(
            item.node if is_symbolic(item) else item for item in tm.shape
        )

        key = (tm_shape, tm.dtype, sparsity)
        t = self._tensor_metadata_cache.get(key)
        if t is None:
            t = self.get_vtensor_type(tm.shape, tm.dtype, sparsity=sparsity)
            self._tensor_metadata_cache[key] = t
        return t

    def dtype_to_type(self, dtype: TorchDtype) -> IrType:
        """Returns (and caches) the MLIR element type for a torch dtype.

        Raises:
          ValueError: If the dtype has no known MLIR spelling.
        """
        t = self._dtype_to_type.get(dtype)
        if t is None:
            try:
                asm = TORCH_DTYPE_TO_MLIR_TYPE_ASM[dtype]
            except KeyError as e:
                # A dict lookup raises KeyError; the previous `except IndexError`
                # never fired, so unknown dtypes escaped as a bare KeyError.
                raise ValueError(f"Unknown conversion from {dtype} to IREE type") from e
            t = IrType.parse(asm, self._c)
            self._dtype_to_type[dtype] = t
        return t

    def tensor_to_vtensor_type(self, tensor: torch.Tensor) -> IrType:
        """Returns the !torch.vtensor type matching a live tensor's shape and dtype.

        Parses in this cache's context (the previous version relied on an
        ambient context being active).
        """
        return self.get_vtensor_type(tensor.size(), tensor.dtype)

    def get_node_location(self, node: torch_fx.Node) -> Optional[Location]:
        """Derives a file/line Location from a node's recorded stack trace.

        Returns None when the node has no stack trace, and an unknown Location
        when the stack trace cannot be parsed.
        """
        stack_trace = node.meta.get("stack_trace")
        if stack_trace is None:
            return None
        # Ugh.
        # TODO: Avoid needing to regex match this.
        # https://github.com/pytorch/pytorch/issues/91000
        m = re.search(r"""File "([^"]+)", line ([0-9]+),""", stack_trace)
        if m:
            filename, line = m.group(1), int(m.group(2))
            return Location.file(filename, line, col=0, context=self._c)
        return Location.unknown(context=self._c)
|
|
|
|
|
|
|
|
|
|
|
|
class GraphNodeImporter:
|
|
|
|
"""Imports graph nodes into an MLIR function.
|
|
|
|
|
|
|
|
The caller must have already created the function.
|
|
|
|
"""
|
|
|
|
|
|
|
|
__slots__ = [
|
|
|
|
"_b",
|
|
|
|
"_c",
|
|
|
|
"_cc",
|
|
|
|
"_literal_resolver_callback",
|
|
|
|
"_v",
|
|
|
|
"_multi_result_nodes",
|
|
|
|
"fx_importer",
|
|
|
|
]
|
|
|
|
|
|
|
|
def __init__(
    self,
    fx_importer: FxImporter,
    context: Context,
    context_cache: ContextCache,
    block: Block,
    *,
    literal_resolver_callback: Optional[LiteralResolverCallback] = None,
):
    """Binds this node importer to the MLIR block it will populate.

    Args:
      fx_importer: The owning FxImporter.
      context: MLIR context used for created IR.
      context_cache: Shared per-context cache of types and lookups.
      block: Block (already created by the caller) that receives the IR.
      literal_resolver_callback: Optional hook for materializing literals.
    """
    self.fx_importer = fx_importer
    self._literal_resolver_callback = literal_resolver_callback
    self._c, self._cc, self._b = context, context_cache, block
    # Nodes whose multiple results were emitted directly (not as a tuple);
    # getitem on them is short-circuited to the matching result.
    self._multi_result_nodes: Set[torch_fx.Node] = set()
    # (fx node, result index) -> MLIR Value produced for it.
    self._v: Dict[Tuple[torch_fx.Node, int], Value] = {}
|
|
|
|
|
|
|
|
def import_nodes(self, nodes: Sequence[torch_fx.Node]):
    """Imports *nodes*, in order, into this importer's block.

    Placeholders are bound to the block arguments in order, call_function
    nodes are dispatched by target, and the output node becomes a
    `func.return`. The location of the most recent node that provided one is
    used for emitted operations.

    Raises:
      RuntimeError: If a getitem on a multi-result node refers to a result
        that was never recorded.
      NotImplementedError: For getitem on non-multi-result nodes and for
        unsupported call_function targets.
    """
    with InsertionPoint(self._b):
        loc = Location.unknown()
        num_placeholders = 0
        for node in nodes:
            op = node.op
            # Attempt to extract locations. Not everything has them,
            # so we do our best.
            new_loc = self._cc.get_node_location(node)
            if new_loc is not None:
                loc = new_loc
            if op == "placeholder":
                # Associate the placeholder node with corresponding block
                # argument.
                self._v[(node, 0)] = self._b.arguments[num_placeholders]
                num_placeholders += 1
            elif op == "call_function":
                target = node.target
                if target == operator.getitem:
                    # Special case handling of getitem for when it is resolving
                    # against a function call that we know has returned multiple
                    # results. We short-circuit this case because we have modeled
                    # function calls to natively return multiple results vs tupling.
                    getitem_ref, getitem_index = node.args
                    if getitem_ref in self._multi_result_nodes:
                        try:
                            self._v[(node, 0)] = self._v[
                                (getitem_ref, getitem_index)
                            ]
                        except KeyError as e:
                            # self._v is a dict, so a missing entry raises
                            # KeyError (the previous IndexError never matched).
                            raise RuntimeError(
                                f"getitem de-aliasing failed. This likely "
                                f"indicates a programmer error that usually "
                                f"would have happened at runtime. Please "
                                f"notify developers if this case happens "
                                f"(at {loc})."
                            ) from e
                    else:
                        raise NotImplementedError(
                            f"General getitem access to non-multi-result ops"
                        )
                elif isinstance(target, TorchOpOverload):
                    # Dispatch to an ATen op.
                    self._import_torch_op_overload(loc, node, target)
                elif target in SYMBOLIC_TORCH_OPS or (
                    is_symbolic(node.meta.get("val"))
                    and is_builtin_function_or_method(target)
                ):
                    self._import_symbolic_torch_op(loc, node, target)
                else:
                    raise NotImplementedError(
                        f"FIX ME: Unimplemented call_function: target={node.target}, {node.meta}"
                    )
            elif op == "output":
                # args[0] is a singleton tuple that we flatten into multiple
                # results.
                operands = [self._import_argument(loc, arg) for arg in node.args[0]]
                func_dialect.ReturnOp(operands, loc=loc)
            # NOTE(review): other node kinds (e.g. get_attr, call_module,
            # call_method) are silently skipped here - confirm this is intended.
|
|
|
|
|
|
|
|
def _promote_symbolic_scalar_int_float(self, loc, graph, param):
    """Emits an `aten.Float.Scalar` conversion of the symbolic int node *param*.

    A synthetic fx node is constructed directly via ``torch.fx.Node`` (rather
    than through a graph insertion API), given a `val` of
    ``torch.sym_float(param.meta["val"])``, and imported immediately.

    Returns:
      The synthetic node, usable as an argument in place of *param*.
    """
    to_float = torch.ops.aten.Float.Scalar
    float_node = torch.fx.Node(
        graph=graph,
        name=f"{str(param)}_as_float",
        op="call_function",
        target=to_float,
        args=(param,),
        kwargs={},
        return_type=float,
    )
    float_node.meta["val"] = torch.sym_float(param.meta["val"])
    self._import_torch_op_overload(loc, float_node, to_float)
    return float_node
|
|
|
|
|
|
|
|
def _import_symbolic_torch_op(
    self,
    loc: Location,
    node: torch_fx.Node,
    target: Union[
        torch._ops.OpOverloadPacket, BuiltinMethodType, BuiltinFunctionType
    ],
):
    """Imports a symbolic scalar operation as a concrete ATen overload.

    Two kinds of target are handled:
    * Python builtins (e.g. ``operator.add``) applied to symbolic scalars.
      The ATen overload (`int`, `float` or `float_int`) is chosen from the
      Python types of the arguments; ``node.args`` may be rewritten to reorder
      operands or to promote an int operand to float.
    * ``aten.sym_size`` / ``sym_stride`` / ``sym_numel`` packets, resolved via
      SYMBOLIC_OP_TO_TORCH_OP using the number of node arguments.

    Raises:
      AssertionError: If no concrete ATen op can be determined.
    """
    # parse builtin operations like add, sub, mul, etc. because dynamo captures these
    # operations on symbolic arguments as regular python expressions rather than as torch ops
    if is_builtin_function_or_method(target):
        arg_types = [
            (
                arg.meta["val"].node.pytype
                if isinstance(arg, torch.fx.Node)
                else type(arg)
            )
            for arg in node.args
        ]
        is_int = [item == int for item in arg_types]
        if all(is_int):
            op_overload = "int"
        elif any(is_int):
            op_name = target.__name__
            float_is_second = arg_types[1] == float
            # The `float_int` overloads take (float, int). Swapping operands to
            # fit that signature is only valid for commutative ops; for ordered
            # comparisons a swap inverts the result (e.g. `1 < 2.5` would become
            # `2.5 < 1`), so those fall back to int->float promotion instead.
            if op_name in ("add", "lt", "ge", "ne", "gt") and (
                not float_is_second or op_name in ("add", "ne")
            ):
                op_overload = "float_int"
                # put float arg first, as expected in signature
                if float_is_second:
                    node.args = (node.args[1], node.args[0])
            else:
                # promote int argument to float - following torch-mlir convention
                arg0, arg1 = node.args
                if is_int[0]:
                    if isinstance(arg0, torch.fx.Node):
                        prom_arg = self._promote_symbolic_scalar_int_float(
                            loc, node.graph, arg0
                        )
                        new_args = (prom_arg, arg1)
                    else:
                        arg0 = float(arg0)
                        new_args = (arg0, arg1)
                else:
                    if isinstance(arg1, torch.fx.Node):
                        prom_arg = self._promote_symbolic_scalar_int_float(
                            loc, node.graph, arg1
                        )
                        new_args = (arg0, prom_arg)
                    else:
                        arg1 = float(arg1)
                        new_args = (arg0, arg1)

                node.args = new_args
                op_overload = "float"
        else:
            op_overload = "float"

        torch_op = PY_BUILTIN_TO_TORCH_OP.get(target.__name__)
        assert (
            torch_op is not None
        ), f"Unsupported builtin function for symbolic types: {target} with args {node.args}"
        concrete_target = getattr(torch_op, op_overload)
    else:
        concrete_target = SYMBOLIC_OP_TO_TORCH_OP.get((target, len(node.args)))

    assert (
        concrete_target is not None
    ), f"Unable to parse symbolic operation: {target} with args {node.args}"
    self._import_torch_op_overload(loc, node, concrete_target)
|
|
|
|
|
|
|
|
def _import_torch_op_overload(
    self, loc: Location, node: torch_fx.Node, target: TorchOpOverload
):
    """Emit the MLIR operation for an FX node that calls a torch OpOverload.

    Before emitting, a few targets are rewritten in place on ``node``:
    ``lift_fresh_copy`` becomes ``clone``, and an ``empty.memory_format``
    whose only user is ``baddbmm`` becomes ``zeros``. Ops listed in
    ``TENSOR_SCALAR_OP_CONVERTER`` whose second operand is a Python scalar are
    redirected to their ``.Scalar`` variant (and that variant's schema is used).

    Operands are gathered in schema order. Keyword-only parameters prefer
    ``node.kwargs``. Other parameters come from ``node.args`` by position and
    fall back to ``node.kwargs`` by name. Anything still missing takes the
    schema default. Ops that are not registered with the MLIR context are
    emitted as a generic ``torch.operator`` carrying the op name.

    Each result is recorded in ``self._v`` under ``(node, result_index)``.
    Nodes whose schema has more than one return are also added to
    ``self._multi_result_nodes``.
    """
    # Replace lift_fresh_copy with the equivalent clone op.
    if target == torch.ops.aten.lift_fresh_copy.default:
        node.target = target = torch.ops.aten.clone.default
        node.args = (node.args[0], None)
    elif target == torch.ops.aten.lift_fresh_copy.out:
        node.target = target = torch.ops.aten.clone.out
        node.args = (node.args[0], None, node.args[1])
    # TODO: generalize empty.memory_format in the future
    # Currently, the aten.baddbmm.default op for Unet includes multiplying an
    # empty.memory_format input with a constant, which creates NaN values
    # because empty.memory_format contains uninitialized data. Rewriting an
    # empty.memory_format whose only user is aten.baddbmm.default into
    # aten.zeros.default fixes the correctness issue.
    elif target == torch.ops.aten.empty.memory_format:
        if len(node.users) == 1:
            for key_node in node.users:
                if key_node.target == torch.ops.aten.baddbmm.default:
                    node.target = target = torch.ops.aten.zeros.default

    schema = target._schema
    assert isinstance(schema, FunctionSchema)

    # Map to a `torch` dialect name.
    namespace, sep, unqualified_name = schema.name.partition("::")
    assert sep, f"Malformed Torch op name {schema.name}"
    mlir_op_name = f"torch.{namespace}.{unqualified_name}"
    if schema.overload_name != "":
        mlir_op_name += f".{schema.overload_name}"

    # Intervening to use Scalar ops due to incorrect ops from AOT-autograd with
    # scalar arguments.
    if mlir_op_name in TENSOR_SCALAR_OP_CONVERTER:
        # The second operand may have been passed by keyword instead of
        # positionally, so look in both places rather than indexing node.args.
        if len(node.args) > 1:
            other = node.args[1]
        else:
            other = node.kwargs.get(schema.arguments[1].name)
        if isinstance(other, (float, int)):
            mlir_op_name = TENSOR_SCALAR_OP_CONVERTER[mlir_op_name]
            # We are dynamically changing which op is emitted here due to an
            # issue in torch dynamo where it emits the Tensor variant of ops
            # even when processing scalar arguments. Therefore we also retrieve
            # the schema, so that the correct typing information is used when
            # importing the function arguments and result types below, i.e.
            # this is `schema = torch.ops.aten.my_op.Scalar._schema`.
            op_overload = torch.ops
            for attr in mlir_op_name.split(".")[1:]:
                op_overload = getattr(op_overload, attr)
            schema = op_overload._schema

    return_count = len(schema.returns)
    if return_count == 1:
        # Unary return directly maps a single meta["val"] and cannot be subscripted.
        # if "tensor_meta" is None, this will throw unsupported placeholder node error
        result_types = [self._cc.node_val_to_type(node)]
    elif return_count == 0:
        # Some torch ops do have 0 returns, and these are supported with ZeroResults
        # op trait. Python bindings for IR creation allow us to pass empty result_types
        # for such ops. Therefore, we pass an empty result types for these cases.
        result_types = []
    else:
        # Multi-return will unpack the meta["val"] and trigger our getitem subscripting
        # short-circuit above. Note that if we ever choose to also fully reify Python
        # level result tuples, we will need to create a tuple-boxed version of this and
        # redirect to it for generic object access.
        result_types = tuple(
            self._cc.tensor_metadata_to_type(v) for v in node.meta["val"]
        )
        self._multi_result_nodes.add(node)

    # Unroll operands from formal parameters, args and kwargs.
    operands = []
    for i, parameter in enumerate(schema.arguments):
        if parameter.kwarg_only and parameter.name in node.kwargs:
            operand = self._import_argument(
                loc, node.kwargs[parameter.name], parameter.type
            )
        elif i < len(node.args):
            operand = self._import_argument(loc, node.args[i], parameter.type)
        elif parameter.name in node.kwargs:
            # A positional-or-keyword parameter that was passed by keyword;
            # without this branch it would silently take its default value.
            operand = self._import_argument(
                loc, node.kwargs[parameter.name], parameter.type
            )
        else:
            operand = self._import_default_value(
                loc, parameter.default_value, parameter.type
            )
        operands.append(operand)

    # Support unregistered torch ops using torch.operator.
    # torch.operator is used to represent ops from registry
    # which haven't been generated by torch_ods_gen.py.
    if not self._c.is_registered_operation(mlir_op_name):
        operation = Operation.create(
            "torch.operator",
            attributes={"name": StringAttr.get(mlir_op_name)},
            results=result_types,
            operands=operands,
            loc=loc,
        )
    else:
        operation = Operation.create(
            mlir_op_name,
            results=result_types,
            operands=operands,
            loc=loc,
        )

    # Record value mapping.
    for i, value in enumerate(operation.results):
        self._v[(node, i)] = value
|
|
|
|
|
|
|
|
def _import_argument(
    self, loc: Location, arg: NodeArgument, expected_jit_type=None
) -> Value:
    """Import an FX `Argument`, which must result to an MLIR `Value`.

    - FX nodes resolve to the Value recorded in ``self._v`` under
      ``(node, 0)``. A ``get_attr`` node that has no Value yet is imported on
      first use from the owning module's attribute and cached under that same
      key, so later references reuse it.
    - ``immutable_list`` arguments are delegated to ``_import_list_argument``.
    - Non-tensor values where the schema expects a ``torch.TensorType`` are
      wrapped via ``_import_scalar_as_tensor``.
    - Everything else is imported through ``_import_literal``.

    Raises:
        RuntimeError: if *arg* is a node with multiple results.
    """
    if isinstance(arg, torch_fx.Node):
        # If implementing boxed support for multi-result nodes, then
        # this will need to do something more intelligent.
        if arg in self._multi_result_nodes:
            raise RuntimeError("Attempt to de-reference a multi-result node")

        # Catch references to dynamically created constant attributes and make
        # sure they have an origin in our module. Membership is tested with the
        # same (node, 0) key the Value is stored under, so the attribute is
        # imported once and later references reuse that Value.
        if arg.op == "get_attr" and (arg, 0) not in self._v:
            gm = arg.graph.owning_module
            assert hasattr(
                gm, arg.target
            ), f"Attempting to retrieve attribute '{arg.target}' from module, but no such attribute exists"
            obj = getattr(gm, arg.target)
            with loc:
                self._v[(arg, 0)] = self._import_literal(obj)

        return self._v[(arg, 0)]
    elif isinstance(arg, torch_fx.immutable_collections.immutable_list):
        return self._import_list_argument(loc, arg, expected_jit_type)
    elif isinstance(expected_jit_type, torch.TensorType) and not isinstance(
        arg, torch.Tensor
    ):
        # promote scalars to tensor types as appropriate
        return self._import_scalar_as_tensor(loc, arg)
    else:
        with loc:
            return self._import_literal(arg)
|
|
|
|
|
|
|
|
def _import_literal(self, py_value: Any) -> Value:
    """Convert a Python value into an MLIR constant Value.

    A user-installed literal resolver callback, if present, is consulted
    first. When it returns something other than ``None``, that result (which
    must be a ``Value``) is used as-is. Otherwise the converter registered in
    ``LITERAL_CONVERTER_MAP`` for ``type(py_value)`` builds the constant.

    Raises:
        TypeError: if no converter is registered for the value's type.
    """
    callback = self._literal_resolver_callback
    if callback:
        resolved = callback(py_value, self)
        if resolved is not None:
            assert isinstance(resolved, Value)
            return resolved

    # Fall back to the built-in converters.
    convert = LITERAL_CONVERTER_MAP.lookup(type(py_value))
    if convert is not None:
        return convert(py_value, self, self._cc)
    raise TypeError(
        f"Unsupported argument -> literal conversion for {py_value.__class__}"
    )
|
|
|
|
|
|
|
|
def _import_scalar_as_tensor(self, loc: Location, arg: NodeArgument) -> Value:
    """Materialize a non-tensor argument as a tensor-typed Value.

    The result vtensor type is derived from ``torch.tensor(arg)`` (its size
    and dtype). The value is built by importing *arg* as a constant through
    ``LITERAL_CONVERTER_MAP`` and feeding it to
    ``torch.prim.NumToTensor.Scalar``.
    """
    probe = torch.tensor(arg)
    vtensor_type = self._cc.get_vtensor_type(probe.size(), probe.dtype)

    with loc:
        to_constant = LITERAL_CONVERTER_MAP.lookup(type(arg))
        scalar_value = to_constant(arg, self, self._cc)

    num_to_tensor = Operation.create(
        name="torch.prim.NumToTensor.Scalar",
        results=[vtensor_type],
        operands=[scalar_value],
        loc=loc,
    )
    return num_to_tensor.result
|
|
|
|
|
|
|
|
def _import_list_argument(
    self, loc: Location, arg: NodeArgument, expected_jit_type
) -> Value:
    """Import a list argument as the result of a ``torch.prim.ListConstruct``.

    Args:
        loc: Location attached to the created op.
        arg: The list to import. FX-node elements must already have a Value in
            ``self._v``. Every other element is imported as a constant via
            ``_import_default_value``.
        expected_jit_type: The schema type of the parameter: a
            ``torch.ListType``, a ``torch.OptionalType`` wrapping a list type,
            or ``None`` when no schema type is known. In the ``None`` case the
            element type is taken from ``type(arg[0])``.

    Returns:
        The single result of the created ``torch.prim.ListConstruct`` op.

    Raises:
        RuntimeError: if an element is a multi-result node.
        AssertionError: on an unexpected jit type or a heterogeneous list.
    """
    assert (
        isinstance(expected_jit_type, torch.ListType)
        or (
            isinstance(expected_jit_type, torch.OptionalType)
            and isinstance(expected_jit_type.getElementType(), torch.ListType)
        )
        or isinstance(expected_jit_type, NoneType)
    ), f"Unexpected jit type as list argument: {arg} of type {expected_jit_type}"

    # parse list type
    if expected_jit_type is None:
        # Without a schema type there is no jit element type or optional-ness
        # to propagate. Both names are still read by the operand loop and the
        # list-type selection below, so they must be bound here too.
        element_type = type(arg[0])
        element_jit_type = None
        is_optional_type = False
    else:
        element_jit_type = expected_jit_type.getElementType()

        # this branch is needed to handle Optional[List[]] types
        if isinstance(element_jit_type, torch.ListType):
            element_jit_type = element_jit_type.getElementType()

        # this handles getting the inner types for List[Optional[]] types
        is_optional_type = isinstance(element_jit_type, torch.OptionalType)
        if is_optional_type:
            element_jit_type = element_jit_type.getElementType()
        element_type = TORCH_TYPE_TO_PY_TYPE[type(element_jit_type)]

    # create list operands
    list_operands = []

    for operand in arg:
        operand_type = type(operand)
        if isinstance(operand, torch.fx.Node):
            if operand in self._multi_result_nodes:
                raise RuntimeError("Attempt to de-reference a multi-result node")
            val = self._v[(operand, 0)]
            val_type = str(val.type)
            assert (
                isinstance(element_type, str) and element_type in val_type
            ) or SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get(
                element_type
            ) == val_type, f"Heterogeneous lists are not supported: expected {element_type}, got {val_type}"
        else:
            assert (is_optional_type and operand_type is NoneType) or (
                element_type == operand_type
            ), f"Heterogeneous lists are not supported: expected {element_type}, got {operand_type}"

            operand_jit_type = (
                torch.NoneType if operand_type is NoneType else element_jit_type
            )
            val = self._import_default_value(loc, operand, operand_jit_type)

        list_operands.append(val)

    # construct list op
    if is_optional_type:
        list_type = PY_TYPE_TO_TORCH_OPTIONAL_LIST_TYPE[element_type]
    else:
        list_type = PY_TYPE_TO_TORCH_LIST_TYPE[element_type]

    result_type = IrType.parse(list_type, context=self._c)
    operation = Operation.create(
        "torch.prim.ListConstruct",
        results=[result_type],
        operands=list_operands,
        loc=loc,
    )

    return operation.result
|
|
|
|
|
|
|
|
def _import_default_value(self, loc: Location, arg, expected_jit_type) -> Value:
    """Imports a defaulted value for a known function schema.

    Lists are delegated to ``_import_list_argument``, which uses
    *expected_jit_type* to choose the list type. Any other value is converted
    under *loc* by the converter registered in ``LITERAL_CONVERTER_MAP`` for
    ``type(arg)``.

    Raises:
        RuntimeError: if no literal converter handles ``type(arg)``.
    """
    if isinstance(arg, list):
        return self._import_list_argument(loc, arg, expected_jit_type)

    # The LITERAL_CONVERTER_MAP maps each arg to its respective constant
    # of the expected jit IR type (types like torch.dtype will form a chain of
    # maps to get to constant of expected_jit_type).
    cvt = LITERAL_CONVERTER_MAP.lookup(type(arg))
    if cvt is None:
        raise RuntimeError(f"Unhandled default value ({arg.__class__}): {arg}")
    with loc:
        return cvt(arg, self, self._cc)
|
|
|
|
|
|
|
|
|
|
|
|
def _make_constant_op(
    op_name: str, value_attr: Attribute, result_type: Optional[IrType] = None
) -> Operation:
    """Create an op named *op_name* with a ``value`` attribute and one result.

    The result has type *result_type*. When *result_type* is omitted (``None``),
    the attribute's own type is used instead. The check is an explicit
    ``is None`` test rather than relying on the truthiness of an IR type
    object.
    """
    if result_type is None:
        result_type = value_attr.type
    return Operation.create(
        op_name,
        results=[result_type],
        attributes={"value": value_attr},
    )
|
|
|
|
|
|
|
|
|
|
|
|
def create_mlir_tensor_type(tensor: torch.Tensor) -> IrType:
    """Return the ``RankedTensorType`` matching *tensor*'s size and dtype.

    Raises:
        TypeError: if ``tensor.dtype`` has no entry in
            ``TORCH_DTYPE_TO_MLIR_TYPE``.
    """
    dtype = tensor.dtype
    # Keep the try scoped to the dtype lookup, so that a KeyError raised
    # elsewhere is not misreported as an unmappable dtype.
    try:
        element_type_factory = TORCH_DTYPE_TO_MLIR_TYPE[dtype]
    except KeyError as e:
        raise TypeError(f"Could not map Torch dtype {dtype} to an MLIR type") from e
    return RankedTensorType.get(tuple(tensor.size()), element_type_factory())
|
2023-12-22 00:40:10 +08:00
|
|
|
|
|
|
|
|
|
|
|
def _make_vtensor_literal_op(
    tensor: torch.Tensor, vtensor_type: IrType, py_attr_tracker: "RefTracker"
) -> Operation:
    """Build a ``torch.vtensor.literal`` op of type *vtensor_type* holding *tensor*'s data.

    The elements attribute is computed on the first call for a given tensor
    object and stored in that object's ``RefTracker`` mapping. Later calls for
    the same object reuse the stored attribute.

    Raises:
        AssertionError: if ``tensor.dtype`` has no entry in
            ``TORCH_DTYPE_TO_NPY_TYPE``.
        TypeError: if a single-element tensor's dtype has no MLIR element type.
    """
    mapping = py_attr_tracker.track(tensor)
    if not mapping.is_empty:
        elements_attr = mapping.value
    else:
        npy_dtype = TORCH_DTYPE_TO_NPY_TYPE.get(tensor.dtype)
        assert (
            npy_dtype is not None
        ), f"Can not create literal tensor for unsupported datatype: {tensor.dtype}"
        # Building an ElementsAttr needs a raw data buffer, but torch.Tensor
        # does not implement the Python buffer/array interface. Calling
        # Tensor.numpy() is not an option either: it internally calls detach(),
        # which errors under FakeTensorMode. Round-tripping through a Python
        # list (Tensor -> list -> numpy array) yields the buffer. It also limits
        # the supported dtypes to those in TORCH_DTYPE_TO_NPY_TYPE.
        host_array = np.array(tensor.tolist()).astype(npy_dtype)

        if host_array.size == 1:
            # Single-element constants are emitted as a splat DenseElementsAttr,
            # which is more optimizable. DenseResourceElementsAttr does not
            # support splats, and at the time of writing it also mishandles 0-d
            # tensors.
            dtype = tensor.dtype
            try:
                element_type = TORCH_DTYPE_TO_MLIR_TYPE[dtype]()
            except KeyError:
                raise TypeError(f"Could not map Torch dtype {dtype} to an MLIR type")
            elements_attr = DenseElementsAttr.get(
                type=element_type, array=host_array, shape=host_array.shape
            )
        else:
            dims = "_".join(str(d) for d in tensor.shape)
            elements_attr = DenseResourceElementsAttr.get_from_buffer(
                memoryview(host_array),
                f"torch_tensor_{dims}_{str(tensor.dtype)}",
                create_mlir_tensor_type(tensor),
            )

        mapping.value = elements_attr

    return Operation.create(
        name="torch.vtensor.literal",
        results=[vtensor_type],
        attributes={"value": elements_attr},
    )
|
|
|
|
|
|
|
|
|
|
|
|
################################################################################
|
|
|
|
# TypeSubclassMapping
|
|
|
|
################################################################################
|
|
|
|
|
|
|
|
|
|
|
|
class TypeSubclassMap:
    """Mapping of super-types to values.

    ``lookup(t)`` returns the value of the first registration (in registration
    order) whose type ``t`` is a subclass of, or ``None`` if nothing matches.
    An exact registration of ``t`` via ``map`` takes precedence over inherited
    matches. Lookup results, including misses, are memoized so repeated
    lookups avoid the linear scan. Memoized results are invalidated whenever a
    new type is registered.
    """

    __slots__ = [
        "_cache",
        "_mapping",
    ]

    def __init__(self):
        # The linear list of converters, in registration order.
        self._mapping: List[Tuple[type, Any]] = []
        # Memoized lookup results, keyed by the exact queried type.
        self._cache: Dict[type, Any] = {}

    def map(self, t: type, value: Any):
        """Register *value* for *t* and, through ``lookup``, its subclasses."""
        self._mapping.append((t, value))
        # Cached results for types that are not themselves registered came
        # from a scan of the *previous* mapping. A new registration can turn
        # such a cached miss into a hit, so drop them and keep only the exact
        # registrations, which ``map`` itself wrote.
        registered = {t_reg for t_reg, _ in self._mapping}
        self._cache = {
            t_cached: cached
            for t_cached, cached in self._cache.items()
            if t_cached in registered
        }
        self._cache[t] = value

    def lookup(self, t: type) -> Any:
        """Return the value for *t* or its first registered super-type, else ``None``."""
        try:
            return self._cache[t]
        except KeyError:
            pass
        for t_super, value in self._mapping:
            if issubclass(t, t_super):
                self._cache[t] = value
                return value
        self._cache[t] = None
        return None
|
|
|
|
|
|
|
|
|
|
|
|
###############################################################################
|
|
|
|
# Reference mapping
|
|
|
|
###############################################################################
|
|
|
|
|
|
|
|
|
|
|
|
# Sentinel meaning "nothing here yet", for places where None could be a
# legitimate value and therefore cannot signal emptiness.
class EmptyType:
    """Type of the module-level ``Empty`` sentinel."""


Empty = EmptyType()
|
|
|
|
|
|
|
|
|
|
|
|
class RefMapping:
    """Association between a weakly referenced Python object and a value.

    ``value`` starts as the ``Empty`` sentinel. ``is_empty`` reports whether it
    has been assigned yet.
    """

    __slots__ = [
        "_referrent",
        "value",
    ]

    def __init__(self, referrent: Any):
        # Hold real referrents weakly so tracking does not keep them alive.
        # The Empty sentinel is stored directly. The attribute must always be
        # set, because __slots__ classes raise AttributeError on unset slots
        # and __repr__ reads it.
        if referrent is not Empty:
            self._referrent = weakref.ref(referrent)
        else:
            self._referrent = Empty
        self.value = Empty

    @property
    def is_empty(self):
        """True while no value has been assigned to this mapping."""
        return self.value is Empty

    def __repr__(self):
        return (
            f"<RefMapping {id(self._referrent) if self._referrent is not Empty else 'empty'} -> "
            f"{self.value if self.value is not Empty else 'empty'}>"
        )
|
|
|
|
|
|
|
|
|
|
|
|
class RefTracker:
    """Tracks live references from Python values to symbolic associations.

    Each tracked object gets one ``RefMapping``, keyed by ``id()`` of the
    object. For real objects (anything but ``Empty``), a ``weakref.finalize``
    hook removes the entry when the object is collected.
    """

    def __init__(self):
        self._refs: Dict[int, RefMapping] = {}

    def track(self, referrent: Any) -> RefMapping:
        """Return the ``RefMapping`` for *referrent*, creating it on first use."""
        key = id(referrent)
        mapping = self._refs.get(key)
        if mapping is not None:
            return mapping

        mapping = RefMapping(referrent)
        if referrent is not Empty:
            weakref.finalize(referrent, self._ref_finalizer, key)
        self._refs[key] = mapping
        return mapping

    def _ref_finalizer(self, ref_id: int):
        """Forget the mapping of a tracked object that has been finalized."""
        del self._refs[ref_id]
|
|
|
|
|
|
|
|
|
|
|
|
################################################################################
|
|
|
|
# Mappings
|
|
|
|
################################################################################
|
|
|
|
|
|
|
|
LITERAL_CONVERTER_MAP = TypeSubclassMap()

# Literal converters, registered in order. Each one is invoked as
# ``converter(py_value, importer, context_cache)`` and returns the Value of
# the constant it creates. dtype, layout and memory_format values are first
# mapped to their integer encoding and then delegated to the ``int`` converter.
for _py_type, _to_literal in (
    (
        NoneType,
        lambda arg, gni, cc: Operation.create(
            "torch.constant.none", results=[cc.torch_none_type]
        ).result,
    ),
    (
        bool,
        lambda arg, gni, cc: _make_constant_op(
            "torch.constant.bool", cc.integer_attr(arg, 1), cc.torch_bool_type
        ).result,
    ),
    (
        int,
        lambda arg, gni, cc: _make_constant_op(
            "torch.constant.int", cc.integer_attr(arg, 64), cc.torch_int_type
        ).result,
    ),
    (
        float,
        lambda arg, gni, cc: _make_constant_op(
            "torch.constant.float", FloatAttr.get_f64(arg), cc.torch_float_type
        ).result,
    ),
    (
        str,
        lambda arg, gni, cc: _make_constant_op(
            "torch.constant.str", StringAttr.get(arg), cc.torch_str_type
        ).result,
    ),
    (
        torch.Tensor,
        lambda arg, gni, cc: _make_vtensor_literal_op(
            arg, cc.tensor_to_vtensor_type(arg), cc._py_attr_tracker
        ).result,
    ),
    (
        torch.device,
        lambda arg, gni, cc: _make_constant_op(
            "torch.constant.device", StringAttr.get(str(arg)), cc.torch_device_type
        ).result,
    ),
    (
        torch.dtype,
        lambda arg, gni, cc: LITERAL_CONVERTER_MAP.lookup(int)(
            TORCH_DTYPE_TO_INT[arg], gni, cc
        ),
    ),
    (
        torch.layout,
        lambda arg, gni, cc: LITERAL_CONVERTER_MAP.lookup(int)(
            TORCH_LAYOUT_TO_INT[arg], gni, cc
        ),
    ),
    (
        torch.memory_format,
        lambda arg, gni, cc: LITERAL_CONVERTER_MAP.lookup(int)(
            TORCH_MEMORY_FORMAT_TO_INT[arg], gni, cc
        ),
    ),
):
    LITERAL_CONVERTER_MAP.map(_py_type, _to_literal)
del _py_type, _to_literal
|
|
|
|
|
|
|
|
# Python-side element key for each TorchScript element type of a list
# parameter. These keys index PY_TYPE_TO_TORCH_LIST_TYPE and
# PY_TYPE_TO_TORCH_OPTIONAL_LIST_TYPE. Tensors have no Python scalar type and
# use the "vtensor" string key instead.
TORCH_TYPE_TO_PY_TYPE = dict(
    (
        (torch.IntType, int),
        (torch.FloatType, float),
        (torch.StringType, str),
        (torch.BoolType, bool),
        (torch.TensorType, "vtensor"),
    )
)
|
|
|
|
|
|
|
|
# Element key -> element spelling, shared by the two list-type tables below.
_TORCH_LIST_ELEMENT_NAMES = (
    (int, "int"),
    (float, "float"),
    (str, "str"),
    (bool, "bool"),
    ("tensor", "tensor"),
    ("vtensor", "vtensor"),
)

# `!torch.list<T>` type strings, keyed by Python element type or by the
# "tensor"/"vtensor" marker strings.
PY_TYPE_TO_TORCH_LIST_TYPE = {
    key: f"!torch.list<{name}>" for key, name in _TORCH_LIST_ELEMENT_NAMES
}

# `!torch.list<optional<T>>` type strings for lists whose elements may be None.
PY_TYPE_TO_TORCH_OPTIONAL_LIST_TYPE = {
    key: f"!torch.list<optional<{name}>>" for key, name in _TORCH_LIST_ELEMENT_NAMES
}

del _TORCH_LIST_ELEMENT_NAMES
|
|
|
|
|
|
|
|
SCALAR_TYPE_TO_TORCH_MLIR_TYPE = {
|
|
|
|
torch.SymInt: "!torch.int",
|
|
|
|
torch.SymFloat: "!torch.float",
|
|
|
|
torch.SymBool: "!torch.bool",
|
|
|
|
int: "!torch.int",
|
|
|
|
float: "!torch.float",
|
|
|
|
str: "!torch.str",
|
|
|
|
bool: "!torch.bool",
|
|
|
|
NoneType: "!torch.none",
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
# AOT-autograd sometimes emits the Tensor variant of an op even when the second
# operand is a Python scalar. This table redirects such ops to their Scalar
# variant. It may be removed once that behavior is fixed in the backend.
TENSOR_SCALAR_OP_CONVERTER = {
    f"torch.aten.{op}.Tensor": f"torch.aten.{op}.Scalar"
    for op in ("mul", "div", "add", "sub")
}
# floor_divide's tensor form is its default overload, whose name has no suffix.
TENSOR_SCALAR_OP_CONVERTER["torch.aten.floor_divide"] = "torch.aten.floor_divide.Scalar"
|