# torch-mlir/python/torch_mlir/extras/onnx_importer.py
#
# 786 lines
# 31 KiB
# Python

# Based on code Copyright (c) Advanced Micro Devices, Inc.
#
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
"""Imports ONNX graphs to `torch` dialect ops.

See documentation:
  https://github.com/llvm/torch-mlir/blob/main/docs/importers/onnx_importer.md

This file is distributed/forked verbatim into various downstream projects, and
it must abide by several rules above and beyond the rest of the codebase:
  - It must be standalone, only depending on:
    - `onnx`
    - `..ir` relative imports to the main IR directory
    - `..dialects.func` relative import to the `func` dialect (TODO:
      we are looking to eliminate this dep).
    - Python standard library
  - It does not directly use the ODS generated `torch` dialect Python
    wrappers. This allows it to be used in contexts that only build a C++
    compiler with minimal IR Python bindings.
  - It is intended as an enabler for full onnx compilation, only handling
    the import from ONNX -> the `torch` dialect. Testing, full pipelines,
    and utilities belong elsewhere.
"""
from dataclasses import dataclass
import re
from typing import Optional, List, Dict, Tuple

import numpy as np

try:
    import onnx
except ModuleNotFoundError as e:
    raise ModuleNotFoundError(
        "The onnx package (`pip install onnx`) is required to use the onnx importer"
    ) from e

from ..ir import (
    ArrayAttr,
    Attribute,
    Block,
    Context,
    DenseElementsAttr,
    DenseResourceElementsAttr,
    DictAttr,
    FloatAttr,
    BF16Type,
    ComplexType,
    F16Type,
    F32Type,
    F64Type,
    Float8E4M3FNType,
    Float8E4M3FNUZType,
    Float8E5M2FNUZType,
    Float8E5M2Type,
    FunctionType,
    InsertionPoint,
    IntegerAttr,
    IntegerType,
    MLIRError,
    RankedTensorType,
    Location,
    Module,
    Operation,
    StringAttr,
    Type as IrType,
    Value,
)

from ..dialects import (
    func as func_dialect,
)
@dataclass
class Config:
    """Knobs that control how the importer interprets a model."""

    # Very old ONNX exporters tended to declare a model input for anything
    # that might be mutable while also supplying an initializer for it. Newer
    # tools dropped that practice, and the importer assumes the newer norm
    # even when it meets an older model. Turning this off is unlikely to do
    # what you want (expect confusing errors); the flag mostly exists to
    # document, in code, that the assumption is being made.
    elide_initialized_inputs: bool = True
class ModelInfo:
    """Top-level accounting and accessors for an ONNX model.

    Attributes:
      config: The importer `Config` in effect for this model.
      model_proto: The wrapped `onnx.ModelProto`.
      main_graph: `GraphInfo` built from `model_proto.graph`.
    """

    def __init__(
        self, model_proto: onnx.ModelProto, *, config: Optional[Config] = None
    ):
        """Wraps `model_proto` and indexes its main graph.

        Args:
          model_proto: The ONNX model to import.
          config: Importer configuration. When omitted, a fresh default
            `Config` is created for this instance.
        """
        # A `Config()` default in the signature would be evaluated once at
        # definition time and shared (Config is a mutable dataclass), so one
        # model's tweak would leak into every other default-configured model.
        self.config = Config() if config is None else config
        self.model_proto = model_proto
        # NOTE(review): protobuf sub-message fields are usually truthy even
        # when unset, so this may never fire -- confirm whether
        # `HasField("graph")` was intended.
        assert model_proto.graph, "Model must contain a main Graph"
        self.main_graph = GraphInfo(self, model_proto.graph)

    def create_module(self, context: Optional[Context] = None) -> Module:
        """Creates an empty MLIR module to import into.

        Args:
          context: Context to create the module in; a new `Context` is
            created when none is given.

        Returns:
          The newly created `Module` (at an unknown location).
        """
        if not context:
            context = Context()
        module = Module.create(Location.unknown(context))
        # TODO: Populate module level metadata from the ModelProto
        return module
class GraphInfo:
    """Name-indexed view of a single ONNX graph.

    Builds lookup tables (by name) for the graph's initializers, value infos,
    declared inputs and outputs, plus `input_map`, the effective set of
    inputs the importer treats as real function/block arguments.
    """

    def __init__(self, model_info: ModelInfo, graph_proto: onnx.GraphProto):
        self.model_info = model_info
        self.graph_proto = graph_proto

        def index_by_name(protos):
            return {proto.name: proto for proto in protos}

        self.initializer_map: Dict[str, onnx.TensorProto] = index_by_name(
            graph_proto.initializer
        )
        self.value_info_map: Dict[str, onnx.ValueInfoProto] = index_by_name(
            graph_proto.value_info
        )
        self.declared_input_map: Dict[str, onnx.ValueInfoProto] = index_by_name(
            graph_proto.input
        )
        self.output_map: Dict[str, onnx.ValueInfoProto] = index_by_name(
            graph_proto.output
        )

        # Compute the effective input map. For old models this can be a
        # strict subset of the declared inputs, because declared inputs that
        # also carry an initializer are dropped.
        if model_info and model_info.config.elide_initialized_inputs:
            initialized = self.initializer_map
            self.input_map = {
                name: info
                for name, info in self.declared_input_map.items()
                if name not in initialized
            }
            return

        # Not eliding (or no model info available): every declared input is
        # an input, and none of them may also be an initializer.
        self.input_map = self.declared_input_map
        illegal_input_keys = self.input_map.keys() - (
            self.input_map.keys() - self.initializer_map.keys()
        )
        assert self.input_map.keys().isdisjoint(self.initializer_map.keys()), (
            f"When not in elide_initialized_inputs=True, we expect inputs to not "
            f"have an initial value (got {illegal_input_keys})."
        )

    def find_type_proto_for_name(self, name: str) -> onnx.TypeProto:
        """Looks up the type recorded for the value called `name`.

        The graph's `value_info` is consulted first (node outputs usually
        only get types there, via shape inference), then the graph outputs.

        Returns:
          The `TypeProto` of the matching entry, or the empty string `""`
          when no type information is associated with the name (which can
          happen when the value is unused).
        """
        found = self.value_info_map.get(name) or self.output_map.get(name)
        if found is None:
            return ""
        return found.type
class OnnxImportError(Exception):
    """Raised by the importer when an ONNX construct cannot be imported."""
class NodeImporter:
    """Imports graph nodes into MLIR.

    Typically, the top level graph will be imported into a func whereas dependent
    graphs may just be imported with references to pre-existing values.

    Note that ONNX requires that graphs be sorted topologically and free of cycles,
    so we don't take any special steps to order them for dominance.
    """

    __slots__ = [
        "_c",
        "_cc",
        "_gi",
        "_p",
        "_b",
        "_nv_map",
    ]

    def __init__(
        self,
        graph_info: GraphInfo,
        *,
        parent_op: Operation,
        block: Block,
        context_cache: "ContextCache",
    ):
        """Creates an importer that emits ops for `graph_info` into `block`.

        Args:
          graph_info: The graph whose nodes will be imported.
          parent_op: The op owning `block`; its context is used.
          block: The block that receives the imported ops.
          context_cache: Shared type/attribute cache for the context.
        """
        self._c = parent_op.context
        self._cc = context_cache
        self._gi = graph_info
        self._p = parent_op
        self._b = block
        # Maps an ONNX value name to the MLIR value that defines it. The
        # empty name "" is reserved for the `none` constant (see get_none).
        self._nv_map: Dict[str, Value] = {}

    @classmethod
    def define_function(
        cls, graph_info: GraphInfo, module_op: Operation
    ) -> "NodeImporter":
        """Creates a `func.func` for the graph and an importer for its body.

        The function is named after the graph, takes one argument per entry
        of `graph_info.input_map` and returns one result per entry of
        `graph_info.output_map`. Block arguments are pre-registered under
        their ONNX input names and graph-level metadata attributes are
        attached to the function.

        Args:
          graph_info: The graph to import as a function.
          module_op: The module operation whose first block receives the
            function.

        Returns:
          A `NodeImporter` positioned on the new function's entry block.
        """
        cc = ContextCache(module_op.context)
        graph_name = graph_info.graph_proto.name
        with module_op.context, Location.name(f"graph:{graph_name}"):
            body = module_op.regions[0].blocks[0]
            func_name = graph_info.graph_proto.name
            input_types = [
                cc.type_proto_to_type(inp.type) for inp in graph_info.input_map.values()
            ]
            output_types = [
                cc.type_proto_to_type(out.type)
                for out in graph_info.output_map.values()
            ]
            ftype = FunctionType.get(input_types, output_types)
            func_op = func_dialect.FuncOp(func_name, ftype, ip=InsertionPoint(body))
            block = func_op.add_entry_block(
                [Location.name(k) for k in graph_info.input_map.keys()]
            )
        # Deliberately constructs NodeImporter (not `cls`): subclasses may
        # have a different constructor signature.
        imp = NodeImporter(graph_info, parent_op=func_op, block=block, context_cache=cc)
        for node_name, input_value in zip(graph_info.input_map.keys(), block.arguments):
            imp._nv_map[node_name] = input_value
        imp._populate_graph_attrs(func_op)
        return imp

    def _populate_graph_attrs(self, container_op: Operation):
        """Populates graph level meta attributes on the given container op.

        Sets `torch.onnx_meta.*` attributes for the default-domain opset
        version (if non-zero), the per-domain opset versions (if any), the
        IR version and the producer name/version of the model.
        """
        m = self._gi.model_info.model_proto
        with container_op.context:
            i64_type = IntegerType.get_signed(64)
            default_opset_version = 0
            opset_versions: Dict[str, IntegerAttr] = {}
            for opset_import in m.opset_import:
                if opset_import.domain:
                    opset_versions[opset_import.domain] = IntegerAttr.get(
                        i64_type, opset_import.version
                    )
                else:
                    # An import with an empty domain is the default opset.
                    default_opset_version = opset_import.version
            if default_opset_version:
                container_op.attributes["torch.onnx_meta.opset_version"] = (
                    IntegerAttr.get(i64_type, default_opset_version)
                )
            if opset_versions:
                container_op.attributes["torch.onnx_meta.opset_versions"] = (
                    DictAttr.get(opset_versions)
                )
            container_op.attributes["torch.onnx_meta.ir_version"] = IntegerAttr.get(
                i64_type, m.ir_version
            )
            container_op.attributes["torch.onnx_meta.producer_name"] = StringAttr.get(
                m.producer_name
            )
            container_op.attributes["torch.onnx_meta.producer_version"] = (
                StringAttr.get(m.producer_version)
            )

    def import_all(self, func=True):
        """Imports all nodes topologically.

        Imports every initializer, materializes the `none` constant, imports
        every node in graph order and finally emits the terminator.

        Args:
          func: When true the block is terminated with `func.return`;
            otherwise with `torch.operator_terminator`.

        Raises:
          OnnxImportError: If a graph output was never produced.
        """
        # TODO: Consider pulling in initializers on demand since there can be so
        # much unused crap.
        for init in self._gi.initializer_map.values():
            self.import_initializer(init)

        self.get_none()
        for node in self._gi.graph_proto.node:
            self.import_node(node)

        outputs = []
        for output_name in self._gi.output_map.keys():
            try:
                outputs.append(self._nv_map[output_name])
            except KeyError as e:
                raise OnnxImportError(
                    f"Non topologically produced ONNX graph output '{output_name}'"
                ) from e
        with InsertionPoint(self._b), Location.unknown():
            if func:
                func_dialect.ReturnOp(outputs)
            else:
                Operation.create(name="torch.operator_terminator", operands=outputs)

    def get_none(self):
        """Returns the block's `torch.constant.none` value, creating it once.

        The value is registered under the empty name `""`, so a node input
        with an empty name resolves to it.
        """
        if "" in self._nv_map:
            return self._nv_map[""]

        with InsertionPoint(self._b), Location.name("onnx_importer.none"):
            nne = Operation.create(
                name="torch.constant.none",
                results=[self._cc.get_none_type()],
                operands=[],
                attributes={},
            ).results[0]
            self._nv_map[""] = nne
            return nne

    def import_node(self, node: onnx.NodeProto):
        """Imports a single node.

        If a method named `_handle_node_<op_type>` exists and returns a true
        value, the node is considered handled. Otherwise a generic
        `torch.operator` op named `onnx.<op_type>` is emitted with the node's
        inputs, attributes and one region per GRAPH attribute, and its
        results are registered under the node's output names.

        Raises:
          OnnxImportError: If an input has not been produced yet.
        """
        with InsertionPoint(self._b), Location.name(node.name):
            op_type = node.op_type
            # Handle special op types that materialize to non-op IR constructs.
            # Handlers return True if the op was handled, else this function
            # should process it as a general node.
            special_key = f"_handle_node_{op_type}"
            if hasattr(self, special_key):
                was_handled = getattr(self, special_key)(node)
                if was_handled:
                    return

            # General node import.
            input_values = []
            for input_name in node.input:
                try:
                    input_values.append(self._nv_map[input_name])
                except KeyError as e:
                    raise OnnxImportError(
                        f"Non topologically produced ONNX node input '{input_name}': {node}"
                    ) from e

            output_names = list(node.output)
            output_types = [
                self._cc.type_proto_to_type(self._gi.find_type_proto_for_name(n))
                for n in output_names
            ]

            attrs = self.import_attributes(node.attribute)
            attrs["name"] = StringAttr.get(f"onnx.{op_type}")
            regions = self.count_regions(node.attribute)

            custom_op = Operation.create(
                name="torch.operator",
                results=output_types,
                operands=input_values,
                attributes=attrs,
                regions=regions,
            )

            self.import_regions(node.attribute, custom_op)
            for output_name, output_value in zip(output_names, custom_op.results):
                self._nv_map[output_name] = output_value

    def import_attributes(self, onnx_attrs: List[onnx.AttributeProto]):
        """Converts ONNX node attributes to a dict of MLIR attributes.

        Each converted attribute is keyed as `torch.onnx.<name>`. Attribute
        types whose handler is `None` are skipped.

        Raises:
          OnnxImportError: If an attribute type has no handler entry, or its
            handler entry is `False` (explicitly unsupported).
        """
        attrs = {}
        for onnx_attr in onnx_attrs:
            attr_type = onnx_attr.type
            if attr_type not in ATTRIBUTE_TYPE_HANDLERS:
                raise OnnxImportError(
                    f"Unhandled ONNX attribute type code {attr_type}: {onnx_attr}"
                )
            handler = ATTRIBUTE_TYPE_HANDLERS[attr_type]
            if handler is None:
                # Active skip.
                continue
            elif handler is False:
                # Active error.
                # try matching attribute type ID to name for a more descriptive error message
                try:
                    attr_type_name = onnx.AttributeProto.AttributeType.Name(attr_type)
                except ValueError:
                    attr_type_name = "UNKNOWN"
                raise OnnxImportError(
                    f"ONNX importer does not support generic node attribute type {attr_type_name} "
                    f"with ID {attr_type}. "
                    f"This likely means that this is a special node which requires specific "
                    f"handling in the importer: {onnx_attr}"
                )
            result = handler(onnx_attr, self._cc)
            attrs[f"torch.onnx.{onnx_attr.name}"] = result
        return attrs

    def count_regions(self, onnx_attrs: List[onnx.AttributeProto]):
        """Returns how many of `onnx_attrs` are GRAPH attributes."""
        graph_type = onnx.AttributeProto.AttributeType.GRAPH
        return sum(1 for onnx_attr in onnx_attrs if onnx_attr.type == graph_type)

    def import_regions(self, onnx_attrs: List[onnx.AttributeProto], op):
        """Imports each GRAPH attribute into a region of `op`.

        GRAPH attributes are paired with `op.regions` in sorted attribute-name
        order. Each region gets one block whose arguments mirror the
        subgraph's inputs; the subgraph is imported with a nested importer
        that can also see every value known to this importer.
        """
        graph_type = onnx.AttributeProto.AttributeType.GRAPH
        attr_map = {
            onnx_attr.name: onnx_attr
            for onnx_attr in onnx_attrs
            if onnx_attr.type == graph_type
        }

        for name, region in zip(sorted(attr_map.keys()), op.regions):
            attr = attr_map[name]
            block_types = [
                self._cc.type_proto_to_type(graph_input.type)
                for graph_input in attr.g.input
            ]
            block_names = [graph_input.name for graph_input in attr.g.input]
            region.blocks.append(
                *block_types, arg_locs=[op.location] * len(block_types)
            )
            block = region.blocks[0]
            # Subgraphs carry no model info of their own.
            graph_info = GraphInfo(None, attr.g)
            imp = NodeImporter(
                graph_info, parent_op=op, block=block, context_cache=self._cc
            )

            for node_name, input_value in zip(block_names, block.arguments):
                imp._nv_map[node_name] = input_value
            # Expose the enclosing scope's values to the subgraph. On a name
            # clash the outer value replaces the block argument.
            for k in self._nv_map:
                imp._nv_map[k] = self._nv_map[k]

            imp.import_all(False)

    def import_initializer(
        self, initializer: onnx.TensorProto, extern_name: Optional[str] = None
    ) -> Value:
        """Imports a tensor as an `onnx.Constant` `torch.operator` op.

        Args:
          initializer: The tensor to materialize.
          extern_name: Name to register the value under; when empty or
            omitted the tensor proto's own name is used.

        Returns:
          The result value of the created op.
        """
        # If an explicitly specified name is given, use that; otherwise, pick
        # up the name from the tensor proto itself
        iname = extern_name if extern_name else initializer.name
        with InsertionPoint(self._b), Location.name(iname):
            value_attr = self._cc.tensor_proto_to_attr(initializer)
            vtensor_type = self._cc.tensor_proto_to_type(initializer)
            attrs = {
                "name": StringAttr.get("onnx.Constant"),
                "torch.onnx.value": value_attr,
            }
            literal_op = Operation.create(
                name="torch.operator",
                results=[vtensor_type],
                attributes=attrs,
            )
            self._nv_map[iname] = literal_op.result
        return literal_op.result

    def _get_immediate_tensor(self, name: str) -> np.ndarray:
        """Returns the initializer called `name` as a numpy array.

        Only tensors stored via `raw_data` are supported; the bytes are
        reinterpreted with the mapped numpy dtype and reshaped to the
        tensor's dims.

        Raises:
          OnnxImportError: If `name` is not an initializer, its element type
            has no numpy mapping, or it has no `raw_data`.
        """
        try:
            initializer = self._gi.initializer_map[name]
        except KeyError as e:
            raise OnnxImportError(
                f"An immediate value for '{name}' was required but it is dynamically produced."
            ) from e
        try:
            dtype = ELEM_TYPE_TO_NUMPY_DTYPE[initializer.data_type]
        except KeyError as e:
            raise OnnxImportError(
                f"Unknown ONNX tensor element type to numpy dtype mapping: {initializer.data_type}"
            ) from e
        raw_data = initializer.raw_data
        if raw_data:
            return np.frombuffer(raw_data, dtype=dtype).reshape(tuple(initializer.dims))
        else:
            raise OnnxImportError(
                f"Unhandled ONNX TensorProto immediate data: {initializer}"
            )

    def _handle_node_Constant(self, node: onnx.NodeProto) -> bool:
        """Special-cases `Constant` nodes that carry a `value` attribute.

        Returns:
          True if the node was imported here; False if it has no `value`
          attribute and must go through the general node import.
        """
        # Special case only for constants specified by value attribute (for now)
        value_proto = _get_attr(node, "value", False)
        if not value_proto:
            return False

        # Produce an initializer for the constant, so that it can be used in
        # combination with other ops, such as ConstantOfShape, requiring
        # a constant input
        assert value_proto.type == onnx.AttributeProto.AttributeType.TENSOR
        assert len(node.output) == 1
        const_name = node.output[0]
        self.import_initializer(value_proto.t, const_name)
        self._gi.initializer_map[const_name] = value_proto.t
        return True
class ContextCache:
    """Caches per-context lookups of various things.

    Mostly memoizes the torch-dialect types that are built by parsing their
    textual asm form, keyed by the pieces used to build that text.
    """

    __slots__ = [
        "_c",
        "_elem_type_map",
        "_list_type_map",
        "_optional_type_map",
        "_vtensor_type_map",
    ]

    def __init__(self, context: Context):
        self._c = context
        self._elem_type_map: Dict[int, IrType] = {}
        self._list_type_map: Dict[str, IrType] = {}
        self._optional_type_map: Dict[str, IrType] = {}
        self._vtensor_type_map: Dict[Tuple[Tuple[Optional[int]], IrType], IrType] = {}

    def _parse_torch_type(self, asm: str) -> IrType:
        """Parses `asm` as a type in this context.

        Raises:
          OnnxImportError: If MLIR rejects the asm.
        """
        try:
            return IrType.parse(asm, context=self._c)
        except MLIRError as e:
            raise OnnxImportError(
                f"Unparseable torch type (MLIR asm format bug?): {asm}"
            ) from e

    @staticmethod
    def _dims_from_shape(shape) -> Tuple[Optional[int], ...]:
        """Returns the dims of a TensorShapeProto, with None for dynamic dims.

        A dimension is dynamic when `dim_param` is set *or* when neither
        `dim_value` nor `dim_param` is set. Reading `dim_value` on such a
        dimension yields the protobuf default 0, which must not be mistaken
        for a static extent of zero, hence the explicit presence check.
        """
        return tuple(
            d.dim_value if d.HasField("dim_value") else None for d in shape.dim
        )

    def _vtensor_asm_fragment(self, tt) -> str:
        """Spells a tensor TypeProto as `vtensor<[dims],elem>` (no `!torch.`)."""
        element_type = self.tensor_element_type(tt.elem_type)
        dims = self._dims_from_shape(tt.shape)
        shape_asm = ",".join("?" if d is None else str(d) for d in dims)
        return f"vtensor<[{shape_asm}],{element_type}>"

    def tensor_element_type(self, elem_type: int) -> IrType:
        """Returns (and caches) the element type for a TensorProto.DataType.

        Raises:
          OnnxImportError: If the data type code has no mapping.
        """
        t = self._elem_type_map.get(elem_type)
        if t is None:
            try:
                with self._c:
                    t = ELEM_TYPE_TO_IR_TYPE_CB[elem_type]()
            except KeyError as e:
                raise OnnxImportError(
                    f"Unknown ONNX tensor element type: {elem_type}"
                ) from e
            self._elem_type_map[elem_type] = t
        return t

    def get_none_type(self):
        """Returns the `!torch.none` type."""
        return IrType.parse("!torch.none", context=self._c)

    def get_list_type(self, element_type: IrType) -> IrType:
        """Returns (and caches) `!torch.list<element_type>`.

        `element_type` is interpolated via `str()`, so it may be a type or
        an already-spelled asm fragment.
        """
        key = str(element_type)
        t = self._list_type_map.get(key)
        if t is None:
            t = self._parse_torch_type(f"!torch.list<{key}>")
            self._list_type_map[key] = t
        return t

    def get_optional_type(self, element_type: IrType) -> IrType:
        """Returns (and caches) `!torch.optional<element_type>`.

        `element_type` is interpolated via `str()`, so it may be a type or
        an already-spelled asm fragment.
        """
        key = str(element_type)
        t = self._optional_type_map.get(key)
        if t is None:
            t = self._parse_torch_type(f"!torch.optional<{key}>")
            self._optional_type_map[key] = t
        return t

    def get_list_element_type(self, tp: onnx.TypeProto) -> str:
        """Returns the asm fragment for the element type of a sequence.

        Only tensor elements are supported.

        Returns:
          A string such as `vtensor<[?,3],f32>` suitable for
          `get_list_type`.

        Raises:
          OnnxImportError: If the element is not a tensor type.
        """
        tt = tp.tensor_type
        if tt.elem_type:
            return self._vtensor_asm_fragment(tt)

        raise OnnxImportError("Unsupport list element type")

    def get_optional_element_type(self, tp: onnx.TypeProto) -> str:
        """Returns the asm fragment for the element type of an optional.

        Tensor elements are tried first, then sequences of tensors.

        Returns:
          A string such as `vtensor<[?,3],f32>` or `list<vtensor<...>>`
          suitable for `get_optional_type`.

        Raises:
          OnnxImportError: If the element type is not supported.
        """
        st = tp.sequence_type
        tt = tp.tensor_type
        if tt.elem_type:
            return self._vtensor_asm_fragment(tt)

        if st.elem_type:
            element_type = self.get_list_element_type(st.elem_type)
            return f"list<{element_type}>"

        raise OnnxImportError("Unsupport optional element type")

    def get_vtensor_type(
        self, dims: Tuple[Optional[int]], element_type: IrType
    ) -> IrType:
        """Returns (and caches) `!torch.vtensor<[dims],element_type>`.

        Args:
          dims: Static extents, with None for dynamic dims (printed as `?`).
          element_type: Element type; interpolated via `str()`.
        """
        key = (dims, element_type)
        t = self._vtensor_type_map.get(key)
        if t is None:
            shape_asm = ",".join("?" if d is None else str(d) for d in dims)
            asm = f"!torch.vtensor<[{shape_asm}],{str(element_type)}>"
            t = self._parse_torch_type(asm)
            self._vtensor_type_map[key] = t
        return t

    def tensor_proto_to_type(self, tp: onnx.TensorProto) -> IrType:
        """Returns the `!torch.vtensor` type matching a TensorProto."""
        element_type = self.tensor_element_type(tp.data_type)
        return self.get_vtensor_type(tuple(tp.dims), element_type)

    def tensor_proto_to_builtin_type(self, tp: onnx.TensorProto) -> IrType:
        """Returns the builtin ranked tensor type matching a TensorProto.

        Raises:
          OnnxImportError: If `RankedTensorType.get` raises a TypeError.
        """
        element_type = self.tensor_element_type(tp.data_type)
        # TODO: Fixme upstream: RankedTensorType.get should not require a location.
        with Location.unknown():
            try:
                return RankedTensorType.get(tuple(tp.dims), element_type)
            except TypeError as e:
                raise OnnxImportError("Unsupported builtin tensor type") from e

    def type_proto_to_type(self, tp: onnx.TypeProto) -> IrType:
        """Converts an ONNX TypeProto to a torch dialect type.

        Handles the empty-string "no type information" marker (mapped to
        `!torch.none`), tensors, sequences and optionals.

        Raises:
          OnnxImportError: For any other kind of TypeProto.
        """
        if tp == "":
            return self.get_none_type()

        tt = tp.tensor_type
        if tt.elem_type:
            # NOTE(review): protobuf sub-message fields are usually truthy
            # even when unset, so this guard may never fire -- confirm
            # whether `HasField("shape")` was intended.
            if not tt.shape:
                raise OnnxImportError(
                    f"Unsupported Tensor type without shape (run shape inference?): {tp}"
                )
            element_type = self.tensor_element_type(tt.elem_type)
            dims = self._dims_from_shape(tt.shape)
            return self.get_vtensor_type(dims, element_type)

        st = tp.sequence_type
        if len(str(st.elem_type)) > 0:
            element_type = self.get_list_element_type(st.elem_type)
            return self.get_list_type(element_type)

        ot = tp.optional_type
        if len(str(ot.elem_type)) > 0:
            element_type = self.get_optional_element_type(ot.elem_type)
            return self.get_optional_type(element_type)

        # TODO: Others if ever needed. Or we consider ourselves DNN-only.
        # See TypeProto: sequence_type, map_type, optional_type, sparse_tensor_type.
        raise OnnxImportError(f"Unsupported ONNX TypeProto: {tp}")

    def _sanitize_name(self, name):
        """Makes `name` usable as a resource name.

        Prefixes `_` when the name is not a Python identifier, then replaces
        every `:` and `/` with `_`.
        """
        if not name.isidentifier():
            name = "_" + name
        return re.sub("[:/]", "_", name)

    def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute:
        """Converts a TensorProto's data to an elements attribute.

        Tensors with `raw_data` become a `DenseResourceElementsAttr` over
        those bytes; otherwise the typed proto fields are converted to a
        `DenseElementsAttr` by the per-data-type inline handler.

        Raises:
          OnnxImportError: If there is no inline handler for the data type.
        """
        tensor_type = self.tensor_proto_to_builtin_type(tp)
        if tp.HasField("raw_data"):
            # Conveniently, DenseResourceElementsAttr shares the raw data
            # format. We just give it maximum numeric alignment.
            resource = DenseResourceElementsAttr.get_from_buffer(
                tp.raw_data, self._sanitize_name(tp.name), tensor_type, alignment=8
            )
            return resource
        else:
            # We have to do a data type specific instantiation from proto fields.
            # Since this is typically used for small tensor constants, we instantiate
            # as a DenseElementsAttr.
            handler = ELEM_TYPE_INLINE_TENSOR_PROTO_CB.get(tp.data_type)
            if handler is None:
                raise OnnxImportError(f"Unhandled ONNX TensorProto data: {tp}")
            return handler(tp)
# Mapping of TensorProto.DataType to a zero-argument callback that builds the
# corresponding element type. The callbacks are invoked with an MLIR context
# active (see ContextCache.tensor_element_type).
ELEM_TYPE_TO_IR_TYPE_CB = {
    onnx.TensorProto.DataType.FLOAT: lambda: F32Type.get(),
    onnx.TensorProto.DataType.UINT8: lambda: IntegerType.get_unsigned(8),
    onnx.TensorProto.DataType.INT8: lambda: IntegerType.get_signed(8),
    onnx.TensorProto.DataType.UINT16: lambda: IntegerType.get_unsigned(16),
    onnx.TensorProto.DataType.INT16: lambda: IntegerType.get_signed(16),
    onnx.TensorProto.DataType.INT32: lambda: IntegerType.get_signed(32),
    onnx.TensorProto.DataType.INT64: lambda: IntegerType.get_signed(64),
    onnx.TensorProto.DataType.BOOL: lambda: IntegerType.get_signless(1),
    onnx.TensorProto.DataType.FLOAT16: lambda: F16Type.get(),
    onnx.TensorProto.DataType.DOUBLE: lambda: F64Type.get(),
    onnx.TensorProto.DataType.UINT32: lambda: IntegerType.get_unsigned(32),
    onnx.TensorProto.DataType.UINT64: lambda: IntegerType.get_unsigned(64),
    onnx.TensorProto.DataType.COMPLEX64: lambda: ComplexType.get(F32Type.get()),
    onnx.TensorProto.DataType.COMPLEX128: lambda: ComplexType.get(F64Type.get()),
    onnx.TensorProto.DataType.BFLOAT16: lambda: BF16Type.get(),
    onnx.TensorProto.DataType.FLOAT8E4M3FN: lambda: Float8E4M3FNType.get(),
    # E4M3FNUZ must map to the E4M3 format, not to E5M2FNUZ.
    onnx.TensorProto.DataType.FLOAT8E4M3FNUZ: lambda: Float8E4M3FNUZType.get(),
    onnx.TensorProto.DataType.FLOAT8E5M2: lambda: Float8E5M2Type.get(),
    onnx.TensorProto.DataType.FLOAT8E5M2FNUZ: lambda: Float8E5M2FNUZType.get(),
    # STRING yields the asm spelling as a plain `str` rather than a type
    # object; it is only usable where the element type is interpolated into
    # type asm.
    onnx.TensorProto.DataType.STRING: lambda: "!torch.str",
}
def _splat_f32_from_tensor_proto(tp, shape):
    """Splats the first `float_data` entry of `tp` over an f32 tensor of `shape`."""
    splat_type = RankedTensorType.get(shape, F32Type.get())
    return DenseElementsAttr.get_splat(splat_type, FloatAttr.get_f32(tp.float_data[0]))


def _splat_si64_from_tensor_proto(tp, shape):
    """Splats the single int64 value of `tp` over an si64 tensor of `shape`.

    The value is decoded from `raw_data` (little-endian, signed) when that
    field is present, otherwise taken from the first `int64_data` entry.
    """
    splat_type = RankedTensorType.get(shape, IntegerType.get_signed(64))
    scalar_type = IntegerType.get_signed(64)
    if tp.HasField("raw_data"):
        scalar = int.from_bytes(tp.raw_data, "little", signed=True)
    else:
        scalar = tp.int64_data[0]
    return DenseElementsAttr.get_splat(splat_type, IntegerAttr.get(scalar_type, scalar))


# Mapping of TensorProto.DataType to a callback `(TensorProto, shape)` that
# returns a splat DenseElementsAttr of the given shape.
ELEM_TYPE_SPLAT_TENSOR_PROTO_CB = {
    onnx.TensorProto.DataType.FLOAT: _splat_f32_from_tensor_proto,
    onnx.TensorProto.DataType.INT64: _splat_si64_from_tensor_proto,
    # TODO: All the rest from ELEM_TYPE_TO_IR_TYPE_CB
}
def _inline_dense_elements_cb(field_name, dtype, **get_kwargs):
    """Builds a TensorProto -> DenseElementsAttr callback.

    The callback reads the repeated proto field `field_name`, converts it to
    a numpy array of `dtype` reshaped to the tensor's dims, and forwards
    `get_kwargs` to `DenseElementsAttr.get`.
    """

    def build(tp):
        values = np.asarray(getattr(tp, field_name), dtype=dtype)
        return DenseElementsAttr.get(values.reshape(tp.dims), **get_kwargs)

    return build


def _inline_bool_elements(tp):
    """Builds the DenseElementsAttr for an inline BOOL tensor.

    Values come from `int32_data` and are bit-packed (little bit order,
    flattened) before being handed to `DenseElementsAttr.get`.
    """
    flags = np.asarray(tp.int32_data, dtype=np.bool_).reshape(tp.dims)
    packed = np.packbits(flags, axis=None, bitorder="little")
    return DenseElementsAttr.get(packed, signless=False)


# Mapping of TensorProto.DataType to a callback taking the TensorProto and
# returning a DenseElementsAttr of the builtin tensor type, for cases where
# the tensor data is inlined as typed values instead of raw_data.
ELEM_TYPE_INLINE_TENSOR_PROTO_CB = {
    onnx.TensorProto.DataType.FLOAT: _inline_dense_elements_cb(
        "float_data", np.float32, signless=False
    ),
    onnx.TensorProto.DataType.BOOL: _inline_bool_elements,
    onnx.TensorProto.DataType.INT8: _inline_dense_elements_cb(
        "int32_data", np.int8, signless=False
    ),
    onnx.TensorProto.DataType.INT16: _inline_dense_elements_cb(
        "int32_data", np.int16, signless=False
    ),
    onnx.TensorProto.DataType.INT32: _inline_dense_elements_cb(
        "int32_data", np.int32, signless=False
    ),
    onnx.TensorProto.DataType.INT64: _inline_dense_elements_cb(
        "int64_data", np.int64, signless=False
    ),
    onnx.TensorProto.DataType.DOUBLE: _inline_dense_elements_cb(
        "double_data", np.float64
    ),
    # Special case. See proto: UINT32 values live in `uint64_data`.
    onnx.TensorProto.DataType.UINT32: _inline_dense_elements_cb(
        "uint64_data", np.uint32, signless=False
    ),
    onnx.TensorProto.DataType.UINT64: _inline_dense_elements_cb(
        "uint64_data", np.uint64, signless=False
    ),
    # Intentionally unsupported: STRING
}
# Mapping of TensorProto.DataType to the numpy dtype used to reinterpret a
# tensor's `raw_data` bytes (see NodeImporter._get_immediate_tensor). Data
# types without an entry are reported as an import error by that lookup.
ELEM_TYPE_TO_NUMPY_DTYPE = {
    onnx.TensorProto.DataType.FLOAT: np.float32,
    onnx.TensorProto.DataType.UINT8: np.uint8,
    onnx.TensorProto.DataType.INT8: np.int8,
    onnx.TensorProto.DataType.UINT16: np.uint16,
    onnx.TensorProto.DataType.INT16: np.int16,
    onnx.TensorProto.DataType.INT32: np.int32,
    onnx.TensorProto.DataType.INT64: np.int64,
    onnx.TensorProto.DataType.BOOL: np.bool_,
    onnx.TensorProto.DataType.FLOAT16: np.float16,
    onnx.TensorProto.DataType.DOUBLE: np.float64,
    onnx.TensorProto.DataType.UINT32: np.uint32,
    onnx.TensorProto.DataType.UINT64: np.uint64,
    onnx.TensorProto.DataType.COMPLEX64: np.complex64,
    onnx.TensorProto.DataType.COMPLEX128: np.complex128,
    # Currently unmapped:
    #   onnx.TensorProto.DataType.BFLOAT16
    #   onnx.TensorProto.DataType.FLOAT8E4M3FN
    #   onnx.TensorProto.DataType.FLOAT8E4M3FNUZ
    #   onnx.TensorProto.DataType.FLOAT8E5M2
    #   onnx.TensorProto.DataType.FLOAT8E5M2FNUZ
    # Omitted: STRING
}
def _onnx_attr_f32(value):
    """Wraps a Python float as an f32 FloatAttr."""
    return FloatAttr.get(F32Type.get(), value)


def _onnx_attr_si64(value):
    """Wraps a Python int as a signed 64-bit IntegerAttr."""
    return IntegerAttr.get(IntegerType.get_signed(64), value)


_ONNX_ATTR_TYPE = onnx.AttributeProto.AttributeType

# Mapping of AttributeType code to one of:
#   None: Ignore attribute and do not output to MLIR
#   False: Error if an attribute of this type is present
#   lambda a:AttributeProto, cc: ContextCache that returns an MLIR Attribute
ATTRIBUTE_TYPE_HANDLERS = {
    _ONNX_ATTR_TYPE.UNDEFINED: False,
    _ONNX_ATTR_TYPE.FLOAT: lambda a, cc: _onnx_attr_f32(a.f),
    _ONNX_ATTR_TYPE.INT: lambda a, cc: _onnx_attr_si64(a.i),
    _ONNX_ATTR_TYPE.STRING: lambda a, cc: StringAttr.get(a.s),
    _ONNX_ATTR_TYPE.TENSOR: lambda a, cc: cc.tensor_proto_to_attr(a.t),
    # GRAPH attributes are not emitted as attributes (handler None = skip).
    _ONNX_ATTR_TYPE.GRAPH: None,
    _ONNX_ATTR_TYPE.SPARSE_TENSOR: False,
    _ONNX_ATTR_TYPE.TYPE_PROTO: False,
    _ONNX_ATTR_TYPE.FLOATS: lambda a, cc: ArrayAttr.get(
        [_onnx_attr_f32(f) for f in a.floats]
    ),
    _ONNX_ATTR_TYPE.INTS: lambda a, cc: ArrayAttr.get(
        [_onnx_attr_si64(i) for i in a.ints]
    ),
    _ONNX_ATTR_TYPE.STRINGS: lambda a, cc: ArrayAttr.get(
        [StringAttr.get(s) for s in a.strings]
    ),
    _ONNX_ATTR_TYPE.TENSORS: lambda a, cc: ArrayAttr.get(
        [cc.tensor_proto_to_attr(t) for t in a.tensors]
    ),
    _ONNX_ATTR_TYPE.GRAPHS: False,
    _ONNX_ATTR_TYPE.SPARSE_TENSORS: False,
    _ONNX_ATTR_TYPE.TYPE_PROTOS: False,
}
def _get_attr(
    node: onnx.NodeProto, attr_name: str, is_required: bool = True
) -> Optional[onnx.AttributeProto]:
    """Finds the attribute called `attr_name` on `node`.

    Args:
      node: The node whose attributes are searched.
      attr_name: Name of the attribute to look for.
      is_required: Whether a missing attribute is an error.

    Returns:
      The first attribute with a matching name, or None when it is absent
      and `is_required` is false.

    Raises:
      OnnxImportError: If the attribute is absent and `is_required` is true.
    """
    found = next((attr for attr in node.attribute if attr.name == attr_name), None)
    if found is None and is_required:
        raise OnnxImportError(f"Required attribute {attr_name} not found in {node}")
    return found