[ONNX] add some args to the onnx importer to assist shape_inference (#3445)

Adds the following arguments:
- "--clear-domain": enabling this flag (default False) will delete the
domain attribute from each node in the onnx model before importing.
Shape inference does not seem to work for onnx ops in custom domains. In
the rare case when these ops have a corresponding counterpart in base
onnx, enabling this flag might allow shape inference to work properly.
- "--opset-version": allows setting the opset version manually. This
will cause the importer to attempt to update the opset_version of the
onnx model before importing. Newer opset versions sometimes have more
robust shape inference patterns.
pull/3454/head
zjgarvey 2024-06-12 10:55:14 -05:00 committed by GitHub
parent de28c8540b
commit c0eb6d89c0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 30 additions and 0 deletions

View File

@ -34,6 +34,7 @@ except ModuleNotFoundError as e:
) from e
from typing import Optional, List, Dict, Tuple
import warnings
from dataclasses import dataclass
@ -579,6 +580,10 @@ class ContextCache:
def type_proto_to_type(self, tp: onnx.TypeProto) -> IrType:
if tp == "":
warnings.warn(
"Found a node without a valid type proto. Consider updating the opset_version of"
" the model and/or running the importer with the flag '--clear-domain'."
)
return self.get_none_type()
tt = tp.tensor_type

View File

@ -20,6 +20,7 @@ import shutil
import sys
import onnx
import onnx.version
from ...extras import onnx_importer
@ -81,6 +82,16 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto:
raw_model = onnx.load(args.input_file, load_external_data=False)
onnx.load_external_data_for_model(raw_model, args.data_dir)
if args.opset_version:
raw_model = onnx.version_converter.convert_version(
raw_model, args.opset_version
)
if args.clear_domain:
graph = raw_model.graph
for n in graph.node:
n.ClearField("domain")
# Run the checker to test whether the file is above the threshold for
# in-memory shape inference. If not, go ahead and do the shape inference.
try:
@ -149,6 +160,14 @@ def parse_arguments(argv=None) -> argparse.Namespace:
action=argparse.BooleanOptionalAction,
help="Toggle data propogation for onnx shape inference",
)
parser.add_argument(
"--clear-domain",
dest="clear_domain",
default=False,
action=argparse.BooleanOptionalAction,
help="If enabled, this will clear the domain attribute from each node"
" in the onnx graph before performing shape inference.",
)
parser.add_argument(
"--keep-temps", action="store_true", help="Keep intermediate files"
)
@ -170,6 +189,12 @@ def parse_arguments(argv=None) -> argparse.Namespace:
" Defaults to the directory of the input file.",
type=Path,
)
parser.add_argument(
"--opset-version",
help="Allows specification of a newer opset_version to update the model"
" to before importing to MLIR. This can sometime assist with shape inference.",
type=int,
)
args = parser.parse_args(argv)
return args