mirror of https://github.com/llvm/torch-mlir
[ONNX] Enables data propogation for onnx shape inference (#3280)
This small change seems to dramatically improve shape inference for complex models, and consequently, improves onnx importer reliability.pull/3300/head
parent
346a536c9f
commit
0abc5868b5
|
@ -27,6 +27,7 @@ from torch_mlir.ir import Context, Module
|
|||
def import_onnx(contents):
|
||||
# Import the ONNX model proto from the file contents:
|
||||
raw_model = onnx.load_from_string(contents)
|
||||
# since it does not affect current e2e tests, data_prop is left false here
|
||||
model_proto = onnx.shape_inference.infer_shapes(raw_model)
|
||||
|
||||
# Import the ONNX module into an MLIR module:
|
||||
|
|
|
@ -85,7 +85,9 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto:
|
|||
# in-memory shape inference. If not, go ahead and do the shape inference.
|
||||
try:
|
||||
onnx.checker.check_model(raw_model)
|
||||
inferred_model = onnx.shape_inference.infer_shapes(raw_model)
|
||||
inferred_model = onnx.shape_inference.infer_shapes(
|
||||
raw_model, data_prop=args.data_prop
|
||||
)
|
||||
return inferred_model
|
||||
except ValueError:
|
||||
pass
|
||||
|
@ -103,7 +105,9 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto:
|
|||
# Model is too big for in-memory inference: do file-based shape inference
|
||||
# to a temp file.
|
||||
temp_inferred_file = temp_dir / "inferred.onnx"
|
||||
onnx.shape_inference.infer_shapes_path(args.input_file, temp_inferred_file)
|
||||
onnx.shape_inference.infer_shapes_path(
|
||||
args.input_file, temp_inferred_file, data_prop=args.data_prop
|
||||
)
|
||||
|
||||
# Sanity check the shape-inferred model to be sure we have a good model
|
||||
# for the importer. This call uses the file-based method, as the
|
||||
|
@ -138,6 +142,13 @@ def parse_arguments(argv=None) -> argparse.Namespace:
|
|||
action="store_true",
|
||||
help="Disable verification prior to printing",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data-prop",
|
||||
dest="data_prop",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Toggle data propogation for onnx shape inference",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--keep-temps", action="store_true", help="Keep intermediate files"
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue