[onnx] Add torch-mlir-import-onnx tool. (#2637)

Simple Python console script to import an ONNX protobuf to the torch
dialect for additional processing.

For installed wheels, this can be used with something like:

```
torch-mlir-import-onnx test/python/onnx_importer/LeakyReLU.onnx
```

Or from a dev setup:

```
python -m torch_mlir.tools.import_onnx ...
```
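
From Python, the same import pipeline can also be driven directly through the importer API. A minimal sketch that mirrors what the new `__main__.py` (shown below) does, assuming a wheel install where `torch_mlir.extras.onnx_importer` is available:

```
import onnx

from torch_mlir.extras import onnx_importer
from torch_mlir.dialects import torch as torch_d
from torch_mlir.ir import Context

# Load the ONNX protobuf and apply shape inference, as the tool does.
model_proto = onnx.shape_inference.infer_shapes(onnx.load("LeakyReLU.onnx"))

# Import the main graph into a torch-dialect module.
context = Context()
torch_d.register_dialect(context)
model_info = onnx_importer.ModelInfo(model_proto)
m = model_info.create_module(context=context)
imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m)
imp.import_all()

m.verify()
print(m.get_asm(assume_verified=True))
```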
Author: Stella Laurenzo
Date: 2023-12-12 22:01:30 -08:00
Commit: ed4df38e8d (parent 7cf52ae73f)
6 changed files with 109 additions and 2 deletions

python/CMakeLists.txt

@@ -38,6 +38,13 @@ declare_mlir_python_sources(TorchMLIRPythonSources.Importers
     extras/onnx_importer.py
 )
 
+declare_mlir_python_sources(TorchMLIRPythonSources.Tools
+  ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
+  ADD_TO_PARENT TorchMLIRPythonSources
+  SOURCES
+    tools/import_onnx/__main__.py
+)
+
 ################################################################################
 # Extensions
 ################################################################################

python/torch_mlir/tools/import_onnx/__main__.py

@@ -0,0 +1,77 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
"""Console tool for converting an ONNX proto to torch IR.

Typically, when installed from a wheel, this can be invoked as:

  torch-mlir-import-onnx some.pb

Or from Python:

  python -m torch_mlir.tools.import_onnx ...
"""
import argparse
from pathlib import Path
import sys

import onnx

from ...extras import onnx_importer
from ...dialects import torch as torch_d
from ...ir import (
    Context,
)


def main(args):
    model_proto = load_onnx_model(args.input_file)
    context = Context()
    torch_d.register_dialect(context)
    model_info = onnx_importer.ModelInfo(model_proto)
    m = model_info.create_module(context=context)
    imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m)
    imp.import_all()
    if not args.no_verify:
        m.verify()

    # TODO: This isn't very efficient output. If these files ever
    # get large, enable bytecode and direct binary emission to save
    # some copies.
    if args.output_file and args.output_file != "-":
        with open(args.output_file, "wt") as f:
            print(m.get_asm(assume_verified=not args.no_verify), file=f)
    else:
        print(m.get_asm(assume_verified=not args.no_verify))


def load_onnx_model(file_path: Path) -> onnx.ModelProto:
    raw_model = onnx.load(file_path)
    inferred_model = onnx.shape_inference.infer_shapes(raw_model)
    return inferred_model


def parse_arguments(argv=None):
    parser = argparse.ArgumentParser(description="Torch-mlir ONNX import tool")
    parser.add_argument("input_file", help="ONNX protobuf input", type=Path)
    parser.add_argument(
        "-o", dest="output_file", help="Output path (or '-' for stdout)"
    )
    parser.add_argument(
        "--no-verify",
        action="store_true",
        help="Disable verification prior to printing",
    )
    args = parser.parse_args(argv)
    return args


def _cli_main():
    sys.exit(main(parse_arguments()))


if __name__ == "__main__":
    _cli_main()
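
Because `parse_arguments()` takes an explicit argv, the tool can also be exercised in-process, e.g. from a test harness. A small sketch; the file paths are illustrative:

```
from torch_mlir.tools.import_onnx.__main__ import main, parse_arguments

# Run the importer end-to-end and write the torch-dialect IR to a file.
# "model.onnx" and "out.mlir" are illustrative paths.
args = parse_arguments(["model.onnx", "-o", "out.mlir"])
main(args)
```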

setup.py

@@ -186,6 +186,11 @@ setup(
         "onnx": [
             "onnx>=1.15.0",
         ],
-    }
+    },
+    entry_points={
+        "console_scripts": [
+            "torch-mlir-import-onnx = torch_mlir.tools.import_onnx:_cli_main",
+        ],
+    },
     zip_safe=False,
 )
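
The console_scripts entry makes pip generate a `torch-mlir-import-onnx` launcher on the PATH at install time. Conceptually, the generated wrapper amounts to no more than the following sketch (not the literal file setuptools writes):

```
# Rough equivalent of the generated torch-mlir-import-onnx launcher for the
# entry point "torch_mlir.tools.import_onnx:_cli_main".
import sys
from torch_mlir.tools.import_onnx import _cli_main

if __name__ == "__main__":
    sys.exit(_cli_main())
```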

test/lit.cfg.py

@@ -24,7 +24,7 @@ config.name = 'TORCH_MLIR'
 config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
 
 # suffixes: A list of file extensions to treat as test files.
-config.suffixes = ['.mlir', '.py']
+config.suffixes = ['.mlir', '.py', '.runlit']
 
 # test_source_root: The root path where tests are located.
 config.test_source_root = os.path.dirname(__file__)

test/python/onnx_importer/LeakyReLU.onnx

@ -0,0 +1,15 @@
pytorch0.3:h
"
01" LeakyRelu*
alpha
×#< torch-jit-exportZ
0



b
1



B

test/python/onnx_importer/import_onnx_tool.runlit

@@ -0,0 +1,3 @@
# RUN: %PYTHON -m torch_mlir.tools.import_onnx %S/LeakyReLU.onnx | FileCheck %s
# CHECK: torch.operator "onnx.LeakyRelu"
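
Outside of lit, the RUN/CHECK pair above can be reproduced with plain Python. A sketch, assuming it is run from a tree where `torch_mlir` is importable and with the model path adjusted as needed:

```
import subprocess
import sys

# Import the test model and assert that the ONNX op survives as a
# torch.operator in the printed IR, as the CHECK line expects.
out = subprocess.run(
    [sys.executable, "-m", "torch_mlir.tools.import_onnx",
     "test/python/onnx_importer/LeakyReLU.onnx"],
    capture_output=True, text=True, check=True,
).stdout
assert 'torch.operator "onnx.LeakyRelu"' in out
```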