mirror of https://github.com/llvm/torch-mlir
[onnx] Add torch-mlir-import-onnx tool. (#2637)
Simple Python console script to import an ONNX protobuf to the torch dialect for additional processing. For installed wheels, this can be used with something like: ``` torch-mlir-import-onnx test/python/onnx_importer/LeakyReLU.onnx ``` Or from a dev setup: ``` python -m torch_mlir.tools.import_onnx ... ```pull/2640/head snapshot-20231213.1051
parent
7cf52ae73f
commit
ed4df38e8d
|
@ -38,6 +38,13 @@ declare_mlir_python_sources(TorchMLIRPythonSources.Importers
|
|||
extras/onnx_importer.py
|
||||
)
|
||||
|
||||
declare_mlir_python_sources(TorchMLIRPythonSources.Tools
|
||||
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
|
||||
ADD_TO_PARENT TorchMLIRPythonSources
|
||||
SOURCES
|
||||
tools/import_onnx/__main__.py
|
||||
)
|
||||
|
||||
################################################################################
|
||||
# Extensions
|
||||
################################################################################
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
"""Console tool for converting an ONNX proto to torch IR.
|
||||
|
||||
Typically, when installed from a wheel, this can be invoked as:
|
||||
|
||||
torch-mlir-import-onnx some.pb
|
||||
|
||||
Or from Python:
|
||||
|
||||
python -m torch_mlir.tools.import_onnx ...
|
||||
"""
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
import onnx
|
||||
|
||||
from ...extras import onnx_importer
|
||||
|
||||
from ...dialects import torch as torch_d
|
||||
from ...ir import (
|
||||
Context,
|
||||
)
|
||||
|
||||
|
||||
def main(args):
|
||||
model_proto = load_onnx_model(args.input_file)
|
||||
context = Context()
|
||||
torch_d.register_dialect(context)
|
||||
model_info = onnx_importer.ModelInfo(model_proto)
|
||||
m = model_info.create_module(context=context)
|
||||
imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m)
|
||||
imp.import_all()
|
||||
if not args.no_verify:
|
||||
m.verify()
|
||||
|
||||
# TODO: This isn't very efficient output. If these files ever
|
||||
# get large, enable bytecode and direct binary emission to save
|
||||
# some copies.
|
||||
if args.output_file and args.output_file != "-":
|
||||
with open(args.output_file, "wt") as f:
|
||||
print(m.get_asm(assume_verified=not args.no_verify), file=f)
|
||||
else:
|
||||
print(m.get_asm(assume_verified=not args.no_verify))
|
||||
|
||||
|
||||
def load_onnx_model(file_path: Path) -> onnx.ModelProto:
|
||||
raw_model = onnx.load(file_path)
|
||||
inferred_model = onnx.shape_inference.infer_shapes(raw_model)
|
||||
return inferred_model
|
||||
|
||||
|
||||
def parse_arguments(argv=None):
|
||||
parser = argparse.ArgumentParser(description="Torch-mlir ONNX import tool")
|
||||
parser.add_argument("input_file", help="ONNX protobuf input", type=Path)
|
||||
parser.add_argument(
|
||||
"-o", dest="output_file", help="Output path (or '-' for stdout)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-verify",
|
||||
action="store_true",
|
||||
help="Disable verification prior to printing",
|
||||
)
|
||||
args = parser.parse_args(argv)
|
||||
return args
|
||||
|
||||
|
||||
def _cli_main():
|
||||
sys.exit(main(parse_arguments()))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_cli_main()
|
7
setup.py
7
setup.py
|
@ -186,6 +186,11 @@ setup(
|
|||
"onnx": [
|
||||
"onnx>=1.15.0",
|
||||
],
|
||||
}
|
||||
},
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
"torch-mlir-import-onnx = torch_mlir.tools.import_onnx:_cli_main",
|
||||
],
|
||||
},
|
||||
zip_safe=False,
|
||||
)
|
||||
|
|
|
@ -24,7 +24,7 @@ config.name = 'TORCH_MLIR'
|
|||
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
|
||||
|
||||
# suffixes: A list of file extensions to treat as test files.
|
||||
config.suffixes = ['.mlir', '.py']
|
||||
config.suffixes = ['.mlir', '.py', '.runlit']
|
||||
|
||||
# test_source_root: The root path where tests are located.
|
||||
config.test_source_root = os.path.dirname(__file__)
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
pytorch0.3:h
|
||||
"
|
||||
01" LeakyRelu*
|
||||
alpha
|
||||
×#< torch-jit-exportZ
|
||||
0
|
||||
|
||||
|
||||
|
||||
b
|
||||
1
|
||||
|
||||
|
||||
|
||||
B
|
|
@ -0,0 +1,3 @@
|
|||
# RUN: %PYTHON -m torch_mlir.tools.import_onnx %S/LeakyReLU.onnx | FileCheck %s
|
||||
|
||||
# CHECK: torch.operator "onnx.LeakyRelu"
|
Loading…
Reference in New Issue