mirror of https://github.com/llvm/torch-mlir
[ONNX] add some args to the onnx importer to assist shape_inference (#3445)
Adds the following arguments: - "--clear-domain": enabling this flag (default False) will delete the domain attribute from each node in the onnx model before importing. Shape inference does not seem to work for onnx ops in custom domains. In the rare case when these ops have a corresponding counterpart in base onnx, enabling this flag might allow shape inference to work properly. - "--opset-version": allows setting the opset version manually. This will cause the importer to attempt to update the opset_version of the onnx model before importing. Newer opset versions sometimes have more robust shape inference patterns.pull/3454/head
parent
de28c8540b
commit
c0eb6d89c0
|
@ -34,6 +34,7 @@ except ModuleNotFoundError as e:
|
|||
) from e
|
||||
|
||||
from typing import Optional, List, Dict, Tuple
|
||||
import warnings
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
@ -579,6 +580,10 @@ class ContextCache:
|
|||
|
||||
def type_proto_to_type(self, tp: onnx.TypeProto) -> IrType:
|
||||
if tp == "":
|
||||
warnings.warn(
|
||||
"Found a node without a valid type proto. Consider updating the opset_version of"
|
||||
" the model and/or running the importer with the flag '--clear-domain'."
|
||||
)
|
||||
return self.get_none_type()
|
||||
|
||||
tt = tp.tensor_type
|
||||
|
|
|
@ -20,6 +20,7 @@ import shutil
|
|||
import sys
|
||||
|
||||
import onnx
|
||||
import onnx.version
|
||||
|
||||
from ...extras import onnx_importer
|
||||
|
||||
|
@ -81,6 +82,16 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto:
|
|||
raw_model = onnx.load(args.input_file, load_external_data=False)
|
||||
onnx.load_external_data_for_model(raw_model, args.data_dir)
|
||||
|
||||
if args.opset_version:
|
||||
raw_model = onnx.version_converter.convert_version(
|
||||
raw_model, args.opset_version
|
||||
)
|
||||
|
||||
if args.clear_domain:
|
||||
graph = raw_model.graph
|
||||
for n in graph.node:
|
||||
n.ClearField("domain")
|
||||
|
||||
# Run the checker to test whether the file is above the threshold for
|
||||
# in-memory shape inference. If not, go ahead and do the shape inference.
|
||||
try:
|
||||
|
@ -149,6 +160,14 @@ def parse_arguments(argv=None) -> argparse.Namespace:
|
|||
action=argparse.BooleanOptionalAction,
|
||||
help="Toggle data propogation for onnx shape inference",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clear-domain",
|
||||
dest="clear_domain",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="If enabled, this will clear the domain attribute from each node"
|
||||
" in the onnx graph before performing shape inference.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--keep-temps", action="store_true", help="Keep intermediate files"
|
||||
)
|
||||
|
@ -170,6 +189,12 @@ def parse_arguments(argv=None) -> argparse.Namespace:
|
|||
" Defaults to the directory of the input file.",
|
||||
type=Path,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--opset-version",
|
||||
help="Allows specification of a newer opset_version to update the model"
|
||||
" to before importing to MLIR. This can sometime assist with shape inference.",
|
||||
type=int,
|
||||
)
|
||||
args = parser.parse_args(argv)
|
||||
return args
|
||||
|
||||
|
|
Loading…
Reference in New Issue