mirror of https://github.com/llvm/torch-mlir
Create MLIR functions for ONNX operators that are functions (#3409)
Resolves #3384. Many ONNX operators are defined by functions and therefore could be expanded into simpler ONNX operations during importing, avoiding the need for tools downstream to support these operators directly. This commit adds this capability to onnx_importer.py. When importing a node, the schema for the node's operator is retrieved. If the schema provides a function for the operator, a specialized version for the node's types and attributes will be created and imported as an MLIR function with private visibility. An MLIR function call will then be emitted, instead of a normal operator node. Caching is used to avoid generating redundant functions within the same module. In order to avoid a disruptive change to the importer output for a large number of operators that already have TorchOnnxToTorch support, an allowlist strategy is used by default. With this commit, only one operator is allowlisted for expansion, MeanVarianceNormalization. However, many other operators can be correctly expanded by the current code, so hopefully the allowlist can be gradually extended. It is possible to disable the allowlist in the configuration, in which case all functions are expanded (useful for testing). Tools downstream of the importer may now need to do inlining when consuming the output of the importer, e.g.: cat imported.mlir | torch-mlir-opt --inline --convert-onnx-to-torch Explanations for subtle code changes: - Looking up the correct schema and function for an operator requires knowing the opset version. NodeImporter retrieves this from the opset imports on the ModelProto retained by the GraphInfo. Previously, the model_proto field on GraphInfo was None when importing a subgraph in import_regions, but this conflicts with the new need for opset version info. Since the apparent purpose of setting it to None was to control how GraphInfo generates its input map, a new flag is added to GraphInfo (is_subgraph) to control this behavior, so that the actual ModelProto can now be provided without breaking this. This also turned out to be useful for getting the Config via ModelInfo via GraphInfo. - Some operators' functions are context-dependent, which means the function definition depends on the types of the inputs. Therefore node importing now needs to look up the types of a node's inputs, not just its outputs as was the case previously. Consequently the operand to find_type_proto_for_name() may now be a graph input or initializer in some cases, so it has to be updated.pull/3461/merge
parent
d2b663ece7
commit
51902ec2dc
|
@ -97,7 +97,10 @@ def _module_lowering(
|
||||||
# Lower from ONNX to Torch
|
# Lower from ONNX to Torch
|
||||||
run_pipeline_with_repro_report(
|
run_pipeline_with_repro_report(
|
||||||
torch_mod,
|
torch_mod,
|
||||||
f"builtin.module(func.func({ONNX_TO_TORCH_FUNC_PIPELINE}))",
|
# The importer may produce additional MLIR functions corresponding to
|
||||||
|
# ONNX operators that are functions. In some cases they need to be
|
||||||
|
# inlined to avoid the backend choking on them.
|
||||||
|
f"builtin.module(inline, func.func({ONNX_TO_TORCH_FUNC_PIPELINE}))",
|
||||||
"Lowering Onnx backend contract to Linalg-on-Tensors backend contract",
|
"Lowering Onnx backend contract to Linalg-on-Tensors backend contract",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -36,7 +36,7 @@ except ModuleNotFoundError as e:
|
||||||
from typing import Optional, List, Dict, Tuple
|
from typing import Optional, List, Dict, Tuple
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import re
|
import re
|
||||||
|
@ -91,6 +91,45 @@ class Config:
|
||||||
# making an assumption.
|
# making an assumption.
|
||||||
elide_initialized_inputs: bool = True
|
elide_initialized_inputs: bool = True
|
||||||
|
|
||||||
|
# Some ONNX operators are defined by ONNX functions and will be
|
||||||
|
# automatically expanded (see get_operator_function() below) to MLIR
|
||||||
|
# functions by the importer. This option allows allowlisting functions that
|
||||||
|
# should be expanded. If this is None, then allowlisting is not used (all
|
||||||
|
# functions not explicitly denylisted will be expanded).
|
||||||
|
#
|
||||||
|
# Since function expansion has not always been supported, the default should
|
||||||
|
# be to use allowlisting, to avoid disruption.
|
||||||
|
function_expansion_allowlists_by_domain: Optional[Dict[str, set[str]]] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
# Default domain (ONNX built-in ops)
|
||||||
|
"": {
|
||||||
|
"MeanVarianceNormalization",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Some ONNX operators are defined by ONNX functions and will be
|
||||||
|
# automatically expanded (see get_operator_function() below) to MLIR
|
||||||
|
# functions by the importer. This option allows denylisting functions that
|
||||||
|
# should not be expanded.
|
||||||
|
function_expansion_denylists_by_domain: Dict[str, set[str]] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
# Default domain (ONNX built-in ops)
|
||||||
|
"": {
|
||||||
|
# CastLike's second input `target_type` is used only for its
|
||||||
|
# type (T2), from which its output's type is inferred, but
|
||||||
|
# because its value is unused, ONNX's shape inference doesn't
|
||||||
|
# annotate the input value with a type, so looking up the
|
||||||
|
# function by the provided input types will fail.
|
||||||
|
"CastLike",
|
||||||
|
# ONNX errors when trying to infer the type of the Loop op
|
||||||
|
# within this function: "[ShapeInferenceError] Inferred shape
|
||||||
|
# and existing shape differ in rank: (1) vs (0)"
|
||||||
|
"Range",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ModelInfo:
|
class ModelInfo:
|
||||||
"""Top-level accounting and accessors for an ONNX model."""
|
"""Top-level accounting and accessors for an ONNX model."""
|
||||||
|
@ -112,7 +151,12 @@ class ModelInfo:
|
||||||
class GraphInfo:
|
class GraphInfo:
|
||||||
"""Information about a Graph within a model."""
|
"""Information about a Graph within a model."""
|
||||||
|
|
||||||
def __init__(self, model_info: ModelInfo, graph_proto: onnx.GraphProto):
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_info: ModelInfo,
|
||||||
|
graph_proto: onnx.GraphProto,
|
||||||
|
is_subgraph: bool = False,
|
||||||
|
):
|
||||||
self.model_info = model_info
|
self.model_info = model_info
|
||||||
self.graph_proto = graph_proto
|
self.graph_proto = graph_proto
|
||||||
self.initializer_map: Dict[str, onnx.TensorProto] = {
|
self.initializer_map: Dict[str, onnx.TensorProto] = {
|
||||||
|
@ -130,7 +174,11 @@ class GraphInfo:
|
||||||
|
|
||||||
# Generate the effective input map, which for old models can be a
|
# Generate the effective input map, which for old models can be a
|
||||||
# subset of the input map.
|
# subset of the input map.
|
||||||
if model_info and model_info.config.elide_initialized_inputs:
|
if (
|
||||||
|
not is_subgraph
|
||||||
|
and model_info
|
||||||
|
and model_info.config.elide_initialized_inputs
|
||||||
|
):
|
||||||
self.input_map = {
|
self.input_map = {
|
||||||
k: v
|
k: v
|
||||||
for k, v in self.declared_input_map.items()
|
for k, v in self.declared_input_map.items()
|
||||||
|
@ -150,9 +198,20 @@ class GraphInfo:
|
||||||
# Node outputs don't typically have type information, but shape inference
|
# Node outputs don't typically have type information, but shape inference
|
||||||
# will associate them in the value_info. If not there, it may be a
|
# will associate them in the value_info. If not there, it may be a
|
||||||
# graph output, which must have type information.
|
# graph output, which must have type information.
|
||||||
value_info = self.value_info_map.get(name) or self.output_map.get(name)
|
value_info = (
|
||||||
|
self.value_info_map.get(name)
|
||||||
|
or self.output_map.get(name)
|
||||||
|
or self.declared_input_map.get(name)
|
||||||
|
)
|
||||||
if value_info is not None:
|
if value_info is not None:
|
||||||
return value_info.type
|
return value_info.type
|
||||||
|
|
||||||
|
tensor_proto = self.initializer_map.get(name)
|
||||||
|
if tensor_proto is not None:
|
||||||
|
return onnx.helper.make_tensor_type_proto(
|
||||||
|
tensor_proto.data_type, tensor_proto.dims
|
||||||
|
)
|
||||||
|
|
||||||
# No type information is associated, this can occur when the value is unused:
|
# No type information is associated, this can occur when the value is unused:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
@ -173,6 +232,8 @@ class NodeImporter:
|
||||||
__slots__ = [
|
__slots__ = [
|
||||||
"_c",
|
"_c",
|
||||||
"_cc",
|
"_cc",
|
||||||
|
"_m",
|
||||||
|
"_mc",
|
||||||
"_gi",
|
"_gi",
|
||||||
"_p",
|
"_p",
|
||||||
"_b",
|
"_b",
|
||||||
|
@ -186,9 +247,13 @@ class NodeImporter:
|
||||||
parent_op: Operation,
|
parent_op: Operation,
|
||||||
block: Block,
|
block: Block,
|
||||||
context_cache: "ContextCache",
|
context_cache: "ContextCache",
|
||||||
|
module_op: Operation,
|
||||||
|
module_cache: "ModuleCache",
|
||||||
):
|
):
|
||||||
self._c = parent_op.context
|
self._c = parent_op.context
|
||||||
self._cc = context_cache
|
self._cc = context_cache
|
||||||
|
self._m = module_op
|
||||||
|
self._mc = module_cache
|
||||||
self._gi = graph_info
|
self._gi = graph_info
|
||||||
self._p = parent_op
|
self._p = parent_op
|
||||||
self._b = block
|
self._b = block
|
||||||
|
@ -196,9 +261,19 @@ class NodeImporter:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_function(
|
def define_function(
|
||||||
cls, graph_info: GraphInfo, module_op: Operation
|
cls,
|
||||||
|
graph_info: GraphInfo,
|
||||||
|
module_op: Operation,
|
||||||
|
context_cache: Optional["ContextCache"] = None,
|
||||||
|
module_cache: Optional["ModuleCache"] = None,
|
||||||
|
private: bool = False,
|
||||||
) -> "NodeImporter":
|
) -> "NodeImporter":
|
||||||
cc = ContextCache(module_op.context)
|
cc = (
|
||||||
|
context_cache
|
||||||
|
if context_cache is not None
|
||||||
|
else ContextCache(module_op.context)
|
||||||
|
)
|
||||||
|
mc = module_cache if module_cache is not None else ModuleCache(module_op, cc)
|
||||||
with module_op.context, Location.name(f"graph:{graph_info.graph_proto.name}"):
|
with module_op.context, Location.name(f"graph:{graph_info.graph_proto.name}"):
|
||||||
body = module_op.regions[0].blocks[0]
|
body = module_op.regions[0].blocks[0]
|
||||||
func_name = graph_info.graph_proto.name
|
func_name = graph_info.graph_proto.name
|
||||||
|
@ -210,11 +285,23 @@ class NodeImporter:
|
||||||
for out in graph_info.output_map.values()
|
for out in graph_info.output_map.values()
|
||||||
]
|
]
|
||||||
ftype = FunctionType.get(input_types, output_types)
|
ftype = FunctionType.get(input_types, output_types)
|
||||||
func_op = func_dialect.FuncOp(func_name, ftype, ip=InsertionPoint(body))
|
func_op = func_dialect.FuncOp(
|
||||||
|
func_name,
|
||||||
|
ftype,
|
||||||
|
ip=InsertionPoint(body),
|
||||||
|
visibility="private" if private else None,
|
||||||
|
)
|
||||||
block = func_op.add_entry_block(
|
block = func_op.add_entry_block(
|
||||||
[Location.name(k) for k in graph_info.input_map.keys()]
|
[Location.name(k) for k in graph_info.input_map.keys()]
|
||||||
)
|
)
|
||||||
imp = NodeImporter(graph_info, parent_op=func_op, block=block, context_cache=cc)
|
imp = NodeImporter(
|
||||||
|
graph_info,
|
||||||
|
parent_op=func_op,
|
||||||
|
block=block,
|
||||||
|
context_cache=cc,
|
||||||
|
module_op=module_op,
|
||||||
|
module_cache=mc,
|
||||||
|
)
|
||||||
for node_name, input_value in zip(graph_info.input_map.keys(), block.arguments):
|
for node_name, input_value in zip(graph_info.input_map.keys(), block.arguments):
|
||||||
imp._nv_map[node_name] = input_value
|
imp._nv_map[node_name] = input_value
|
||||||
imp._populate_graph_attrs(func_op)
|
imp._populate_graph_attrs(func_op)
|
||||||
|
@ -294,6 +381,8 @@ class NodeImporter:
|
||||||
def import_node(self, node: onnx.NodeProto):
|
def import_node(self, node: onnx.NodeProto):
|
||||||
with InsertionPoint(self._b), Location.name(node.name):
|
with InsertionPoint(self._b), Location.name(node.name):
|
||||||
op_type = node.op_type
|
op_type = node.op_type
|
||||||
|
op_domain = node.domain
|
||||||
|
|
||||||
# Handle special op types that materialize to non-op IR constructs.
|
# Handle special op types that materialize to non-op IR constructs.
|
||||||
# Handlers return True if the op was handled, else this function
|
# Handlers return True if the op was handled, else this function
|
||||||
# should process it as a general node.
|
# should process it as a general node.
|
||||||
|
@ -304,33 +393,58 @@ class NodeImporter:
|
||||||
return
|
return
|
||||||
# General node import.
|
# General node import.
|
||||||
input_values = []
|
input_values = []
|
||||||
|
input_type_protos = []
|
||||||
for input_name in node.input:
|
for input_name in node.input:
|
||||||
try:
|
try:
|
||||||
input_values.append(self._nv_map[input_name])
|
input_values.append(self._nv_map[input_name])
|
||||||
|
# Missing optional arguments will have empty types
|
||||||
|
input_type_protos.append(
|
||||||
|
self._gi.find_type_proto_for_name(input_name)
|
||||||
|
or onnx.TypeProto()
|
||||||
|
)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise OnnxImportError(
|
raise OnnxImportError(
|
||||||
f"Non topologically produced ONNX node input '{input_name}': {node}"
|
f"Non topologically produced ONNX node input '{input_name}': {node}"
|
||||||
)
|
)
|
||||||
|
|
||||||
output_names = list(node.output)
|
output_names = []
|
||||||
output_types = [
|
output_type_protos = []
|
||||||
self._cc.type_proto_to_type(self._gi.find_type_proto_for_name(n))
|
output_types = []
|
||||||
for n in output_names
|
for output_name in node.output:
|
||||||
]
|
output_names.append(output_name)
|
||||||
|
type_proto = self._gi.find_type_proto_for_name(output_name)
|
||||||
|
output_type_protos.append(type_proto)
|
||||||
|
output_types.append(self._cc.type_proto_to_type(type_proto))
|
||||||
|
|
||||||
attrs = self.import_attributes(node.attribute)
|
for opset_import in self._gi.model_info.model_proto.opset_import:
|
||||||
attrs["name"] = StringAttr.get(f"onnx.{op_type}")
|
if opset_import.domain == op_domain:
|
||||||
regions = self.count_regions(node.attribute)
|
opset_version = opset_import.version
|
||||||
|
break
|
||||||
custom_op = Operation.create(
|
operator_func_op = self._mc.get_operator_function(
|
||||||
name="torch.operator",
|
op_type,
|
||||||
results=output_types,
|
op_domain,
|
||||||
operands=input_values,
|
opset_version,
|
||||||
attributes=attrs,
|
input_type_protos,
|
||||||
regions=regions,
|
output_type_protos,
|
||||||
|
node,
|
||||||
|
self._gi.model_info.config,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.import_regions(node.attribute, custom_op)
|
if operator_func_op is not None:
|
||||||
|
custom_op = func_dialect.CallOp(operator_func_op, input_values)
|
||||||
|
else:
|
||||||
|
attrs = self.import_attributes(node.attribute)
|
||||||
|
attrs["name"] = StringAttr.get(f"onnx.{op_type}")
|
||||||
|
regions = self.count_regions(node.attribute)
|
||||||
|
custom_op = Operation.create(
|
||||||
|
name="torch.operator",
|
||||||
|
results=output_types,
|
||||||
|
operands=input_values,
|
||||||
|
attributes=attrs,
|
||||||
|
regions=regions,
|
||||||
|
)
|
||||||
|
self.import_regions(node.attribute, custom_op)
|
||||||
|
|
||||||
for output_name, output_value in zip(output_names, custom_op.results):
|
for output_name, output_value in zip(output_names, custom_op.results):
|
||||||
self._nv_map[output_name] = output_value
|
self._nv_map[output_name] = output_value
|
||||||
|
|
||||||
|
@ -388,9 +502,14 @@ class NodeImporter:
|
||||||
*block_types, arg_locs=[op.location] * len(block_types)
|
*block_types, arg_locs=[op.location] * len(block_types)
|
||||||
)
|
)
|
||||||
block = region.blocks[0]
|
block = region.blocks[0]
|
||||||
graph_info = GraphInfo(None, attr.g)
|
graph_info = GraphInfo(self._gi.model_info, attr.g, is_subgraph=True)
|
||||||
imp = NodeImporter(
|
imp = NodeImporter(
|
||||||
graph_info, parent_op=op, block=block, context_cache=self._cc
|
graph_info,
|
||||||
|
parent_op=op,
|
||||||
|
block=block,
|
||||||
|
context_cache=self._cc,
|
||||||
|
module_op=self._m,
|
||||||
|
module_cache=self._mc,
|
||||||
)
|
)
|
||||||
|
|
||||||
for node_name, input_value in zip(block_names, block.arguments):
|
for node_name, input_value in zip(block_names, block.arguments):
|
||||||
|
@ -608,6 +727,11 @@ class ContextCache:
|
||||||
element_type = self.get_optional_element_type(ot.elem_type)
|
element_type = self.get_optional_element_type(ot.elem_type)
|
||||||
return self.get_optional_type(element_type)
|
return self.get_optional_type(element_type)
|
||||||
|
|
||||||
|
# Check if TypeProto is empty (sometimes happens for unused function
|
||||||
|
# arguments)
|
||||||
|
if tp.WhichOneof("value") is None:
|
||||||
|
return self.get_none_type()
|
||||||
|
|
||||||
# TODO: Others if ever needed. Or we consider ourselves DNN-only.
|
# TODO: Others if ever needed. Or we consider ourselves DNN-only.
|
||||||
# See TypeProto: sequence_type, map_type, optional_type, sparse_tensor_type.
|
# See TypeProto: sequence_type, map_type, optional_type, sparse_tensor_type.
|
||||||
raise OnnxImportError(f"Unsupported ONNX TypeProto: {tp}")
|
raise OnnxImportError(f"Unsupported ONNX TypeProto: {tp}")
|
||||||
|
@ -636,6 +760,323 @@ class ContextCache:
|
||||||
return handler(tp)
|
return handler(tp)
|
||||||
|
|
||||||
|
|
||||||
|
def _shallow_copy_and_clear_protobuf_list(protobuf_list) -> list:
|
||||||
|
"""
|
||||||
|
Workaround for .clear() not being available on protobuf lists for some
|
||||||
|
reason.
|
||||||
|
"""
|
||||||
|
copy = list(protobuf_list)
|
||||||
|
while len(protobuf_list) > 0:
|
||||||
|
protobuf_list.pop()
|
||||||
|
return copy
|
||||||
|
|
||||||
|
|
||||||
|
def _bind_attributes_on_node(
|
||||||
|
interior_node: onnx.NodeProto,
|
||||||
|
caller_node: onnx.NodeProto,
|
||||||
|
op_schema: onnx.defs.OpSchema,
|
||||||
|
) -> onnx.NodeProto:
|
||||||
|
"""
|
||||||
|
Helper for _specialize_function_and_create_model() that binds concrete
|
||||||
|
values to an attributes on a node in the interior of a function.
|
||||||
|
|
||||||
|
This should behave the same as ONNX's C++ attribute binder, please use it as
|
||||||
|
a reference: https://github.com/onnx/onnx/blob/88f8ef15cfaa3138d336f3502aed5018d802bf43/onnx/shape_inference/attribute_binder.h#L15-L64
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _bind_attributes_in_subgraph(
|
||||||
|
old_subgraph: onnx.GraphProto,
|
||||||
|
caller_node: onnx.NodeProto,
|
||||||
|
op_schema: onnx.defs.OpSchema,
|
||||||
|
) -> onnx.GraphProto:
|
||||||
|
"""
|
||||||
|
Recurse to bind attributes in a subgraph.
|
||||||
|
"""
|
||||||
|
new_subgraph.CopyFrom(old_subgraph)
|
||||||
|
old_nodes = _shallow_copy_and_clear_protobuf_list(new_subgraph.node)
|
||||||
|
for old_node in old_nodes:
|
||||||
|
new_subgraph.node.append(
|
||||||
|
_bind_attributes_on_node(old_node, caller_node, op_schema)
|
||||||
|
)
|
||||||
|
return new_subgraph
|
||||||
|
|
||||||
|
def _bind_attribute(
|
||||||
|
old_attribute: onnx.AttributeProto,
|
||||||
|
caller_node: onnx.NodeProto,
|
||||||
|
op_schema: onnx.defs.OpSchema,
|
||||||
|
) -> Optional[onnx.AttributeProto]:
|
||||||
|
"""
|
||||||
|
Bind a single attribute.
|
||||||
|
|
||||||
|
Bound values either come from attributes on the node calling the
|
||||||
|
function, or from default values. If the attribute is optional and has
|
||||||
|
no default value, and no value was provided by the caller, None is
|
||||||
|
returned and the attribute should be removed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ref_name = old_attribute.ref_attr_name
|
||||||
|
if not ref_name:
|
||||||
|
if not old_attribute.g or len(old_attribute.graphs) == 0:
|
||||||
|
return old_attribute
|
||||||
|
|
||||||
|
# Recurse to bind attributes on subgraphs. ONNX's implementation of
|
||||||
|
# attribute binding only does this for subgraphs that didn't come
|
||||||
|
# from a referenced attribute value, so this code doesn't either.
|
||||||
|
new_attribute = onnx.AttributeProto()
|
||||||
|
new_attribute.CopyFrom(old_attribute)
|
||||||
|
if new_attribute.g:
|
||||||
|
new_attribute.g = _bind_attributes_in_subgraph(
|
||||||
|
new_attribute.g, caller_node, op_schema
|
||||||
|
)
|
||||||
|
if new_attribute.graphs:
|
||||||
|
old_subgraphs = _shallow_copy_and_clear_protobuf_list(
|
||||||
|
new_attribute.graphs
|
||||||
|
)
|
||||||
|
for old_subgraph in old_subgraphs:
|
||||||
|
new_attribute.graphs.append(
|
||||||
|
_bind_attributes_in_subgraph(
|
||||||
|
old_subgraph, caller_node, op_schema
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return new_attribute
|
||||||
|
|
||||||
|
for call_attribute in caller_node.attribute:
|
||||||
|
if call_attribute.name == ref_name:
|
||||||
|
new_attribute = onnx.AttributeProto()
|
||||||
|
new_attribute.CopyFrom(call_attribute)
|
||||||
|
new_attribute.name = old_attribute.name
|
||||||
|
return new_attribute
|
||||||
|
|
||||||
|
# The default value is sometimes empty for optional attributes
|
||||||
|
# that don't have a default, in which case it is dropped.
|
||||||
|
default_value = op_schema.attributes[ref_name].default_value
|
||||||
|
if default_value and default_value.type:
|
||||||
|
new_attribute = onnx.AttributeProto()
|
||||||
|
new_attribute.CopyFrom(default_value)
|
||||||
|
new_attribute.name = old_attribute.name
|
||||||
|
return new_attribute
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
new_node = onnx.NodeProto()
|
||||||
|
new_node.CopyFrom(interior_node)
|
||||||
|
old_attributes = _shallow_copy_and_clear_protobuf_list(new_node.attribute)
|
||||||
|
for node_attribute in old_attributes:
|
||||||
|
new_attribute = _bind_attribute(node_attribute, caller_node, op_schema)
|
||||||
|
if new_attribute is not None:
|
||||||
|
new_node.attribute.append(new_attribute)
|
||||||
|
continue
|
||||||
|
return new_node
|
||||||
|
|
||||||
|
|
||||||
|
def _specialize_function_and_create_model(
|
||||||
|
function_proto: onnx.FunctionProto,
|
||||||
|
op_schema: onnx.defs.OpSchema,
|
||||||
|
name_to_give_model: str,
|
||||||
|
input_type_protos: list[onnx.TypeProto],
|
||||||
|
output_type_protos: list[onnx.TypeProto],
|
||||||
|
caller_node: onnx.NodeProto,
|
||||||
|
) -> onnx.ModelProto:
|
||||||
|
"""
|
||||||
|
Helper for ModuleCache::get_operator_function() that specializes a function
|
||||||
|
and coverts it to a model.
|
||||||
|
|
||||||
|
An ONNX function may be polymorphic, parameterized over the types of its
|
||||||
|
inputs and values of its attributes (~= compile-time constants). We need to
|
||||||
|
monomorphize it for importing into MLIR. It seems like the only practical
|
||||||
|
way to do this is by turning it into a model:
|
||||||
|
- models can have types on their inputs and outputs, unlike functions
|
||||||
|
- ONNX provides a function to do shape inference (providing concrete
|
||||||
|
types for everything in the body) for models, but not for functions
|
||||||
|
- the rest of the code in this importer can only handle models, not
|
||||||
|
functions
|
||||||
|
"""
|
||||||
|
|
||||||
|
graph_proto = onnx.GraphProto()
|
||||||
|
|
||||||
|
for input_name, input_type_proto in zip(function_proto.input, input_type_protos):
|
||||||
|
input_proto = onnx.ValueInfoProto()
|
||||||
|
input_proto.name = input_name
|
||||||
|
input_proto.type.CopyFrom(input_type_proto)
|
||||||
|
graph_proto.input.append(input_proto)
|
||||||
|
output_proto = onnx.ValueInfoProto()
|
||||||
|
|
||||||
|
for output_name, output_type_proto in zip(
|
||||||
|
function_proto.output, output_type_protos
|
||||||
|
):
|
||||||
|
output_proto.name = output_name
|
||||||
|
output_proto.type.CopyFrom(output_type_proto)
|
||||||
|
graph_proto.output.append(output_proto)
|
||||||
|
|
||||||
|
for node in function_proto.node:
|
||||||
|
# Import referenced attributes from call-site or default values
|
||||||
|
graph_proto.node.append(_bind_attributes_on_node(node, caller_node, op_schema))
|
||||||
|
|
||||||
|
graph_proto.name = name_to_give_model
|
||||||
|
|
||||||
|
model_proto = onnx.ModelProto()
|
||||||
|
model_proto.opset_import.extend(function_proto.opset_import)
|
||||||
|
# FIXME: is this the correct IR version, or should it be the latest, or the
|
||||||
|
# one used by the actual model, or something else?
|
||||||
|
model_proto.ir_version = onnx.helper.find_min_ir_version_for(
|
||||||
|
function_proto.opset_import
|
||||||
|
)
|
||||||
|
model_proto.graph.CopyFrom(graph_proto)
|
||||||
|
|
||||||
|
model_proto = onnx.shape_inference.infer_shapes(
|
||||||
|
model_proto, check_type=True, strict_mode=True, data_prop=True
|
||||||
|
)
|
||||||
|
graph_proto = model_proto.graph
|
||||||
|
|
||||||
|
# Useful for debugging.
|
||||||
|
# onnx.checker.check_model(model_proto, full_check=True)
|
||||||
|
|
||||||
|
return model_proto
|
||||||
|
|
||||||
|
|
||||||
|
class ModuleCache:
|
||||||
|
"""Caches per-module lookups of various things."""
|
||||||
|
|
||||||
|
__slots__ = [
|
||||||
|
"_m",
|
||||||
|
"_cc",
|
||||||
|
"_operator_function_map",
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, module_op: Operation, context_cache: ContextCache):
|
||||||
|
self._m = module_op
|
||||||
|
self._cc = context_cache
|
||||||
|
self._operator_function_map: Dict[str, func_dialect.FuncOp] = {}
|
||||||
|
|
||||||
|
def get_operator_function(
|
||||||
|
self,
|
||||||
|
op_name: str,
|
||||||
|
op_domain: str,
|
||||||
|
opset_version: int,
|
||||||
|
input_type_protos: list[onnx.TypeProto],
|
||||||
|
output_type_protos: list[onnx.TypeProto],
|
||||||
|
caller_node: onnx.NodeProto,
|
||||||
|
config: Config,
|
||||||
|
) -> Optional[func_dialect.FuncOp]:
|
||||||
|
"""
|
||||||
|
Get or create MLIR function corresponding to an ONNX operator.
|
||||||
|
|
||||||
|
Returns None for ONNX operators that aren't functions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
allowlists = config.function_expansion_allowlists_by_domain
|
||||||
|
denylists = config.function_expansion_denylists_by_domain
|
||||||
|
|
||||||
|
if allowlists is not None and not (
|
||||||
|
op_domain in allowlists and op_name in allowlists[op_domain]
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
|
||||||
|
if op_domain in denylists and op_name in denylists[op_domain]:
|
||||||
|
return None
|
||||||
|
|
||||||
|
op_schema = onnx.defs.get_schema(
|
||||||
|
op_name, domain=op_domain, max_inclusive_version=opset_version
|
||||||
|
)
|
||||||
|
|
||||||
|
# The get_schema() lookup above should get the right version of the
|
||||||
|
# operator definition, but the function body can change slightly
|
||||||
|
# within a single operator version, as explained in
|
||||||
|
# https://github.com/onnx/onnx/blob/093a8d335a66ea136eb1f16b3a1ce6237ee353ab/onnx/defs/schema.h#L1070-L1086
|
||||||
|
# There also seem to be cases where a function goes from being not
|
||||||
|
# context-dependent to context-dependent.
|
||||||
|
f = lambda ver: ver <= opset_version
|
||||||
|
ncd_function_version = max(
|
||||||
|
filter(f, op_schema.function_opset_versions),
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
cd_function_version = max(
|
||||||
|
filter(f, op_schema.context_dependent_function_opset_versions),
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
if ncd_function_version is None and cd_function_version is None:
|
||||||
|
# No relevant function definition
|
||||||
|
return None
|
||||||
|
if ncd_function_version is not None and (
|
||||||
|
cd_function_version is None or cd_function_version < ncd_function_version
|
||||||
|
):
|
||||||
|
specific_version = ncd_function_version
|
||||||
|
is_context_dependent = False
|
||||||
|
else:
|
||||||
|
specific_version = cd_function_version
|
||||||
|
is_context_dependent = True
|
||||||
|
|
||||||
|
# This is both a key for memoization of function importing and also a
|
||||||
|
# name mangling scheme, so it must include all information needed to
|
||||||
|
# uniquely identify a function and anything it might be parameterized
|
||||||
|
# over.
|
||||||
|
key = repr(
|
||||||
|
(
|
||||||
|
op_name,
|
||||||
|
op_domain,
|
||||||
|
opset_version,
|
||||||
|
input_type_protos,
|
||||||
|
# Though output types can be inferred from input types, it does
|
||||||
|
# not seem to be the case that there's only one legal set of
|
||||||
|
# outputs for a given set of inputs. When attemtping to always
|
||||||
|
# use onnx.shape_inference.infer_function_output_types instead
|
||||||
|
# of the caller-provided types, sometimes IR verification fails
|
||||||
|
output_type_protos,
|
||||||
|
# Avoid including the attributes twice (once on their own and
|
||||||
|
# once as part of the node) for context-dependent functions,
|
||||||
|
# avoid including unused parts of the node for other functions.
|
||||||
|
caller_node if is_context_dependent else caller_node.attribute,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
existing = self._operator_function_map.get(key)
|
||||||
|
if existing is not None:
|
||||||
|
return existing
|
||||||
|
|
||||||
|
if is_context_dependent:
|
||||||
|
function_proto_str = (
|
||||||
|
op_schema.get_context_dependent_function_with_opset_version(
|
||||||
|
specific_version,
|
||||||
|
caller_node.SerializeToString(),
|
||||||
|
[
|
||||||
|
t.SerializeToString() if not isinstance(t, bytes) else t
|
||||||
|
for t in input_type_protos
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
function_proto_str = op_schema.get_function_with_opset_version(
|
||||||
|
specific_version
|
||||||
|
)
|
||||||
|
if not function_proto_str:
|
||||||
|
raise OnnxImportError(
|
||||||
|
f"Function lookup for {op_name}/{op_domain}/{specific_version}/{is_context_dependent} failed unexpectedly. This probably indicates a bug."
|
||||||
|
)
|
||||||
|
function_proto = onnx.onnx_pb.FunctionProto()
|
||||||
|
function_proto.ParseFromString(function_proto_str)
|
||||||
|
|
||||||
|
tmp_model_proto = _specialize_function_and_create_model(
|
||||||
|
function_proto,
|
||||||
|
op_schema,
|
||||||
|
key,
|
||||||
|
input_type_protos,
|
||||||
|
output_type_protos,
|
||||||
|
caller_node,
|
||||||
|
)
|
||||||
|
|
||||||
|
tmp_model_info = ModelInfo(tmp_model_proto)
|
||||||
|
tmp_graph_info = GraphInfo(tmp_model_info, tmp_model_proto.graph)
|
||||||
|
# Mark function as private so it will be thrown away after inlining
|
||||||
|
imp = NodeImporter.define_function(
|
||||||
|
tmp_graph_info, self._m, self._cc, self, private=True
|
||||||
|
)
|
||||||
|
imp.import_all()
|
||||||
|
func_op = imp._p
|
||||||
|
|
||||||
|
self._operator_function_map[key] = func_op
|
||||||
|
return func_op
|
||||||
|
|
||||||
|
|
||||||
ELEM_TYPE_TO_IR_TYPE_CB = {
|
ELEM_TYPE_TO_IR_TYPE_CB = {
|
||||||
onnx.TensorProto.DataType.FLOAT: lambda: F32Type.get(),
|
onnx.TensorProto.DataType.FLOAT: lambda: F32Type.get(),
|
||||||
onnx.TensorProto.DataType.UINT8: lambda: IntegerType.get_unsigned(8),
|
onnx.TensorProto.DataType.UINT8: lambda: IntegerType.get_unsigned(8),
|
||||||
|
|
|
@ -31,10 +31,14 @@ from ...ir import (
|
||||||
|
|
||||||
|
|
||||||
def main(args: argparse.Namespace):
|
def main(args: argparse.Namespace):
|
||||||
|
config = onnx_importer.Config()
|
||||||
|
if args.disable_function_expansion_allowlist:
|
||||||
|
config.function_expansion_allowlists_by_domain = None
|
||||||
|
|
||||||
model_proto = load_onnx_model(args)
|
model_proto = load_onnx_model(args)
|
||||||
context = Context()
|
context = Context()
|
||||||
torch_d.register_dialect(context)
|
torch_d.register_dialect(context)
|
||||||
model_info = onnx_importer.ModelInfo(model_proto)
|
model_info = onnx_importer.ModelInfo(model_proto, config=config)
|
||||||
m = model_info.create_module(context=context).operation
|
m = model_info.create_module(context=context).operation
|
||||||
imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m)
|
imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m)
|
||||||
imp.import_all()
|
imp.import_all()
|
||||||
|
@ -195,6 +199,12 @@ def parse_arguments(argv=None) -> argparse.Namespace:
|
||||||
" to before importing to MLIR. This can sometime assist with shape inference.",
|
" to before importing to MLIR. This can sometime assist with shape inference.",
|
||||||
type=int,
|
type=int,
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--disable-function-expansion-allowlist",
|
||||||
|
action="store_true",
|
||||||
|
help="Disable the allowlist for ONNX function expansion,"
|
||||||
|
" allowing non-allowlisted functions to be expanded.",
|
||||||
|
)
|
||||||
args = parser.parse_args(argv)
|
args = parser.parse_args(argv)
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,18 @@
|
||||||
|
# Test that expansion of ONNX operators that are functions works for a simple
|
||||||
|
# example. The exact name mangling scheme used is not matched against, all that
|
||||||
|
# matters is that it has the name of the operator (GreaterOrEqual here) in it.
|
||||||
|
# Attributes are also not checked here. What we are interested in is the types
|
||||||
|
# and operations.
|
||||||
|
#
|
||||||
|
# The model comes from an upstream ONNX test: backend/test/data/node/test_greater_equal/model.onnx
|
||||||
|
|
||||||
|
# RUN: %PYTHON -m torch_mlir.tools.import_onnx --disable-function-expansion-allowlist %s.onnx | FileCheck %s
|
||||||
|
|
||||||
|
# CHECK-LABEL: func.func @test_greater_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1>
|
||||||
|
# CHECK: %0 = call @"{{.*}}GreaterOrEqual{{.*}}"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1>
|
||||||
|
|
||||||
|
# CHECK-LABEL: func.func private @"{{.*}}GreaterOrEqual{{.*}}"(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1>
|
||||||
|
# CHECK: %0 = torch.operator "onnx.Greater"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1>
|
||||||
|
# CHECK: %1 = torch.operator "onnx.Equal"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1>
|
||||||
|
# CHECK: %2 = torch.operator "onnx.Or"(%0, %1) : (!torch.vtensor<[3,4,5],i1>, !torch.vtensor<[3,4,5],i1>) -> !torch.vtensor<[3,4,5],i1>
|
||||||
|
# CHECK: return %2 : !torch.vtensor<[3,4,5],i1>
|
Binary file not shown.
|
@ -0,0 +1,22 @@
|
||||||
|
# Test the expansion of ONNX operators that are functions, specifically the
|
||||||
|
# propagation of attribute values from the call-site to nodes within the
|
||||||
|
# expanded function.
|
||||||
|
#
|
||||||
|
# In this case, the model has a ReduceSumSquare node with the attribute
|
||||||
|
# 'keepdims' set to 0, and the definition of this version of ReduceSumSquare
|
||||||
|
# contains a ReduceSum node that references the value of 'keepdims', so we
|
||||||
|
# expect to see this value propagated to the ReduceSum node in the expansion.
|
||||||
|
#
|
||||||
|
# This also tests that the absence of 'axes' (as an optional attribute with no
|
||||||
|
# default value) is propagated in the same way.
|
||||||
|
#
|
||||||
|
# The model comes from an upstream ONNX test: backend/test/data/node/test_reduce_sum_square_do_not_keepdims_example/model.onnx
|
||||||
|
|
||||||
|
# RUN: %PYTHON -m torch_mlir.tools.import_onnx --disable-function-expansion-allowlist %s.onnx | FileCheck %s
|
||||||
|
#
|
||||||
|
# CHECK-LABEL: func.func @test_reduce_sum_square_do_not_keepdims_example
|
||||||
|
# CHECK: %0 = call @"{{.*}}ReduceSumSquare{{.*}}"
|
||||||
|
#
|
||||||
|
# CHECK-LABEL: func.func private @"{{.*}}ReduceSumSquare{{.*}}"
|
||||||
|
# CHECK: %0 = torch.operator "onnx.Mul"
|
||||||
|
# CHECK: %1 = torch.operator "onnx.ReduceSum"{{.*}}{torch.onnx.keepdims = 0 : si64}
|
Binary file not shown.
|
@ -0,0 +1,23 @@
|
||||||
|
# Test the expansion of ONNX operators that are functions, specifically the
|
||||||
|
# propagation of attribute values from the call-site to nodes within the
|
||||||
|
# expanded function.
|
||||||
|
#
|
||||||
|
# In this case, the model has a ReduceSumSquare node with no attributes, but the
|
||||||
|
# definition of this version of ReduceSumSquare contains a ReduceSum node that
|
||||||
|
# references the value of 'keepdims', and the definition says its default value
|
||||||
|
# is 1, so we expect to see this value propagated to the ReduceSum node in the
|
||||||
|
# expansion.
|
||||||
|
#
|
||||||
|
# This also tests that the absence of 'axes' (as an optional attribute with no
|
||||||
|
# default value) is propagated in the same way.
|
||||||
|
#
|
||||||
|
# The model comes from an upstream ONNX test: backend/test/data/node/test_reduce_sum_square_empty_set/model.onnx
|
||||||
|
|
||||||
|
# RUN: %PYTHON -m torch_mlir.tools.import_onnx --disable-function-expansion-allowlist %s.onnx | FileCheck %s
|
||||||
|
#
|
||||||
|
# CHECK-LABEL: func.func @test_reduce_sum_square_empty_set
|
||||||
|
# CHECK: %0 = call @"{{.*}}ReduceSumSquare{{.*}}"
|
||||||
|
#
|
||||||
|
# CHECK-LABEL: func.func private @"{{.*}}ReduceSumSquare{{.*}}"
|
||||||
|
# CHECK: %0 = torch.operator "onnx.Mul"
|
||||||
|
# CHECK: %1 = torch.operator "onnx.ReduceSum"{{.*}}{torch.onnx.keepdims = 1 : si64}
|
Binary file not shown.
Loading…
Reference in New Issue