mirror of https://github.com/llvm/torch-mlir
Upstream the ONNX importer. (#2636)
This is part 1 of 2, which will also include upstreaming the FX importer. I started with ONNX because it forces some project layout updates and is more self contained/easier as a first step. Deviating somewhat from the RFCs on project layout, I made the following decisions: * Locating the `onnx_importer.py` into `torch_mlir.extras` as Maks already has opened up that namespace and it seemed to fit. Better to have fewer things at that level. * Setup the build so that the root project only contains MLIR Python and pure Python deps (like the importers), but this can be augmented with the `projects/` adding more depending on which features are enabled. * The default build continues to build everything whereas in `TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS=1` mode, it builds a `torch-mlir-core` wheel with the pure contents only. `onnx_importer.py` and `importer_smoke_test.py` are almost verbatim copies from SHARK-Turbine. I made some minor local alterations to adapt to paths and generalize the way they interact with the outer project. I expect I can copy these back to Turbine verbatim from here. I also updated the license boilerplate (they have the same license but slightly different project norms for the headers) but retained the correct copyright. Other updates: * Added the ONNX importer unit test (which also can generate test data) in lit, conditioned on the availability of the Python `onnx` package. In a followup once I know everything is stable, I'll add another env var that the CI can set to always enable this so we know conclusively if tests pass. * Moved the ONNX conversion readme to `docs/`. * Renamed CMake option `TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS` -> `TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS` and inverted the sense. Made the JitIR importer and LTC options `cmake_dependent_options` for robustness.pull/2637/head
parent
f67249d34f
commit
74f7a0c9d6
|
@ -25,6 +25,8 @@ project(torch-mlir LANGUAGES CXX C)
|
||||||
set(CMAKE_C_STANDARD 11)
|
set(CMAKE_C_STANDARD 11)
|
||||||
set(CMAKE_CXX_STANDARD 17)
|
set(CMAKE_CXX_STANDARD 17)
|
||||||
|
|
||||||
|
include(CMakeDependentOption)
|
||||||
|
|
||||||
#-------------------------------------------------------------------------------
|
#-------------------------------------------------------------------------------
|
||||||
# Project options
|
# Project options
|
||||||
#-------------------------------------------------------------------------------
|
#-------------------------------------------------------------------------------
|
||||||
|
@ -43,24 +45,11 @@ endif()
|
||||||
|
|
||||||
option(TORCH_MLIR_OUT_OF_TREE_BUILD "Specifies an out of tree build" OFF)
|
option(TORCH_MLIR_OUT_OF_TREE_BUILD "Specifies an out of tree build" OFF)
|
||||||
|
|
||||||
# PT1 options.
|
# PyTorch native extension gate. If OFF, then no features which depend on
|
||||||
option(TORCH_MLIR_ENABLE_PROJECT_PT1 "Enables the PyTorch1 project under projects/pt1" OFF)
|
# native extensions will be built.
|
||||||
# TODO: Rename/scope these. They use historic names for now to ease migration
|
option(TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS "Enables PyTorch native extension features" ON)
|
||||||
# burden.
|
cmake_dependent_option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS OFF)
|
||||||
option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON)
|
cmake_dependent_option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS OFF)
|
||||||
option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF)
|
|
||||||
option(TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS "Build Torch dialect MLIR Python bindings but neither JIT IR Importer nor LTC backend" OFF)
|
|
||||||
if(TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS)
|
|
||||||
set(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER OFF)
|
|
||||||
set(TORCH_MLIR_ENABLE_LTC OFF)
|
|
||||||
endif()
|
|
||||||
# Force enable the PT1 project if either the JIT_IR_IMPORTER or LTC is enabled.
|
|
||||||
if(NOT TORCH_MLIR_ENABLE_PROJECT_PT1)
|
|
||||||
if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER OR TORCH_MLIR_ENABLE_LTC)
|
|
||||||
message(STATUS "Enabling projects/pt1 because features requiring it are enabled")
|
|
||||||
set(TORCH_MLIR_ENABLE_PROJECT_PT1 ON)
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
#-------------------------------------------------------------------------------
|
#-------------------------------------------------------------------------------
|
||||||
# Configure out-of-tree vs in-tree build
|
# Configure out-of-tree vs in-tree build
|
||||||
|
@ -235,4 +224,16 @@ endif()
|
||||||
# Sub-projects
|
# Sub-projects
|
||||||
#-------------------------------------------------------------------------------
|
#-------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Sub-projects can bundle additional PyTorch extensions by adding them to this
|
||||||
|
# source target. It is typically empty unless if features are enabled.
|
||||||
|
if(MLIR_ENABLE_BINDINGS_PYTHON)
|
||||||
|
declare_mlir_python_sources(TorchMLIRPythonTorchExtensionsSources)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# Build projects first as it may populate additional Python deps.
|
||||||
add_subdirectory(projects)
|
add_subdirectory(projects)
|
||||||
|
|
||||||
|
# Finish with top-level Python bindings so it can handle additional deps.
|
||||||
|
if(MLIR_ENABLE_BINDINGS_PYTHON)
|
||||||
|
add_subdirectory(python)
|
||||||
|
endif()
|
|
@ -351,7 +351,6 @@ function setup_venv() {
|
||||||
echo ":::: Using stable dependencies"
|
echo ":::: Using stable dependencies"
|
||||||
python3 -m pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu
|
python3 -m pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu
|
||||||
python3 -m pip install --no-cache-dir -r /main_checkout/torch-mlir/build-requirements.txt
|
python3 -m pip install --no-cache-dir -r /main_checkout/torch-mlir/build-requirements.txt
|
||||||
python3 -m pip install --no-cache-dir -r /main_checkout/torch-mlir/test-requirements.txt
|
|
||||||
;;
|
;;
|
||||||
*)
|
*)
|
||||||
echo "Unrecognized torch version '$torch_version'"
|
echo "Unrecognized torch version '$torch_version'"
|
||||||
|
@ -359,6 +358,7 @@ function setup_venv() {
|
||||||
;;
|
;;
|
||||||
esac
|
esac
|
||||||
|
|
||||||
|
python3 -m pip install --no-cache-dir -r /main_checkout/torch-mlir/test-requirements.txt
|
||||||
}
|
}
|
||||||
|
|
||||||
function build_out_of_tree() {
|
function build_out_of_tree() {
|
||||||
|
|
|
@ -3,11 +3,8 @@
|
||||||
We enable the direct representation of many ONNX features directly in
|
We enable the direct representation of many ONNX features directly in
|
||||||
the `torch` dialect as `torch.operator` custom ops with names like
|
the `torch` dialect as `torch.operator` custom ops with names like
|
||||||
`onnx.{OperatorName}`. The majority of ONNX operators are represented
|
`onnx.{OperatorName}`. The majority of ONNX operators are represented
|
||||||
with a systematic transformation. See
|
with a systematic transformation. `torch_mlir.extras.onnx_importer`
|
||||||
[onnx_importer.py](https://github.com/nod-ai/SHARK-Turbine/blob/main/python/shark_turbine/importers/onnx_importer.py)
|
for the reference importer which complies with the rules below.
|
||||||
for the reference importer which complies with the rules below
|
|
||||||
(this is planned to be upstreamed to torch-mlir proper in the near
|
|
||||||
future).
|
|
||||||
|
|
||||||
## Adding new ONNX operators
|
## Adding new ONNX operators
|
||||||
|
|
||||||
|
@ -26,10 +23,11 @@ are relatively straight-forward to map, following this general procedure:
|
||||||
* Open the corresponding implementation file `DefaultDomainXtoY.cpp`
|
* Open the corresponding implementation file `DefaultDomainXtoY.cpp`
|
||||||
corresponding with the alphabetic sort of the op and add a conversion.
|
corresponding with the alphabetic sort of the op and add a conversion.
|
||||||
* Generate successful test cases:
|
* Generate successful test cases:
|
||||||
* Either run the Turbine importer to produce MLIR output for all
|
* All `onnx_importer.py` tests are dumped to the test temp dir (success
|
||||||
ops/models in the ONNX test suite or use a dump that someone has
|
or failure). This is typically located under
|
||||||
generated:
|
`tools/torch-mlir/test/python/onnx_importer/Output`. The `.mlir` files
|
||||||
* [2023-Nov-21](https://drive.google.com/file/d/1P6QaRXGnCeApjdjNmykLxWa-yqMmIO-d/view?usp=sharing)
|
under there should provide good variants to drive lit test coverage of
|
||||||
|
conversion.
|
||||||
* There are often many variants of tests for checking conformance of
|
* There are often many variants of tests for checking conformance of
|
||||||
different historic ONNX encodings, but these are often not load bearing
|
different historic ONNX encodings, but these are often not load bearing
|
||||||
at the MLIR level.
|
at the MLIR level.
|
|
@ -1,7 +1,31 @@
|
||||||
include(AddMLIRPython)
|
include(AddMLIRPython)
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# PyTorch
|
||||||
# Configure PyTorch if we have any features enabled which require it.
|
# Configure PyTorch if we have any features enabled which require it.
|
||||||
|
################################################################################
|
||||||
if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER OR TORCH_MLIR_ENABLE_LTC)
|
if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER OR TORCH_MLIR_ENABLE_LTC)
|
||||||
|
|
||||||
|
if (NOT TORCH_MLIR_USE_INSTALLED_PYTORCH)
|
||||||
|
# Source builds
|
||||||
|
message(STATUS "Building libtorch from source (features depend on it and NOT TORCH_MLIR_USE_INSTALLED_PYTORCH)")
|
||||||
|
set(ENV{TORCH_MLIR_SRC_PYTORCH_REPO} ${TORCH_MLIR_SRC_PYTORCH_REPO})
|
||||||
|
set(ENV{TORCH_MLIR_SRC_PYTORCH_BRANCH} ${TORCH_MLIR_SRC_PYTORCH_BRANCH})
|
||||||
|
set(ENV{TM_PYTORCH_INSTALL_WITHOUT_REBUILD} ${TM_PYTORCH_INSTALL_WITHOUT_REBUILD})
|
||||||
|
set(ENV{MACOSX_DEPLOYMENT_TARGET} ${MACOSX_DEPLOYMENT_TARGET})
|
||||||
|
set(ENV{CMAKE_OSX_ARCHITECTURES} ${CMAKE_OSX_ARCHITECTURES})
|
||||||
|
set(ENV{CMAKE_C_COMPILER_LAUNCHER} ${CMAKE_C_COMPILER_LAUNCHER})
|
||||||
|
set(ENV{CMAKE_CXX_COMPILER_LAUNCHER} ${CMAKE_CXX_COMPILER_LAUNCHER})
|
||||||
|
execute_process(
|
||||||
|
COMMAND ${TORCH_MLIR_SOURCE_DIR}/build_tools/build_libtorch.sh
|
||||||
|
RESULT_VARIABLE _result
|
||||||
|
)
|
||||||
|
if(_result)
|
||||||
|
message(FATAL_ERROR "Failed to run `build_libtorch.sh`")
|
||||||
|
endif()
|
||||||
|
set(TORCH_INSTALL_PREFIX "libtorch")
|
||||||
|
endif()
|
||||||
|
|
||||||
message(STATUS "Enabling PyTorch C++ dep (features depend on it)")
|
message(STATUS "Enabling PyTorch C++ dep (features depend on it)")
|
||||||
include(TorchMLIRPyTorch)
|
include(TorchMLIRPyTorch)
|
||||||
|
|
||||||
|
@ -48,6 +72,6 @@ if(TORCH_MLIR_ENABLE_LTC)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# Include overall PT1 project.
|
# Include overall PT1 project.
|
||||||
if(TORCH_MLIR_ENABLE_PROJECT_PT1)
|
if(TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS)
|
||||||
add_subdirectory(pt1)
|
add_subdirectory(pt1)
|
||||||
endif()
|
endif()
|
||||||
|
|
|
@ -7,79 +7,22 @@ set(CMAKE_PLATFORM_NO_VERSIONED_SONAME ON)
|
||||||
# argument.
|
# argument.
|
||||||
set(TORCH_MLIR_PYTHON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/torch_mlir")
|
set(TORCH_MLIR_PYTHON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/torch_mlir")
|
||||||
|
|
||||||
|
|
||||||
# We vendor our own MLIR instance in the `torch_mlir` namespace.
|
# We vendor our own MLIR instance in the `torch_mlir` namespace.
|
||||||
add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=torch_mlir.")
|
add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=torch_mlir.")
|
||||||
|
|
||||||
################################################################################
|
# ################################################################################
|
||||||
# PyTorch
|
# # Sources
|
||||||
################################################################################
|
# ################################################################################
|
||||||
|
|
||||||
if (NOT TORCH_MLIR_USE_INSTALLED_PYTORCH)
|
declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel
|
||||||
# Source builds
|
|
||||||
set(ENV{TORCH_MLIR_SRC_PYTORCH_REPO} ${TORCH_MLIR_SRC_PYTORCH_REPO})
|
|
||||||
set(ENV{TORCH_MLIR_SRC_PYTORCH_BRANCH} ${TORCH_MLIR_SRC_PYTORCH_BRANCH})
|
|
||||||
set(ENV{TM_PYTORCH_INSTALL_WITHOUT_REBUILD} ${TM_PYTORCH_INSTALL_WITHOUT_REBUILD})
|
|
||||||
set(ENV{MACOSX_DEPLOYMENT_TARGET} ${MACOSX_DEPLOYMENT_TARGET})
|
|
||||||
set(ENV{CMAKE_OSX_ARCHITECTURES} ${CMAKE_OSX_ARCHITECTURES})
|
|
||||||
set(ENV{CMAKE_C_COMPILER_LAUNCHER} ${CMAKE_C_COMPILER_LAUNCHER})
|
|
||||||
set(ENV{CMAKE_CXX_COMPILER_LAUNCHER} ${CMAKE_CXX_COMPILER_LAUNCHER})
|
|
||||||
execute_process(
|
|
||||||
COMMAND ${TORCH_MLIR_SOURCE_DIR}/build_tools/build_libtorch.sh
|
|
||||||
RESULT_VARIABLE _result
|
|
||||||
)
|
|
||||||
if(_result)
|
|
||||||
message(FATAL_ERROR "Failed to run `build_libtorch.sh`")
|
|
||||||
endif()
|
|
||||||
set(TORCH_INSTALL_PREFIX "libtorch")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
################################################################################
|
|
||||||
# Sources
|
|
||||||
################################################################################
|
|
||||||
|
|
||||||
declare_mlir_python_sources(TorchMLIRPythonSources)
|
|
||||||
declare_mlir_python_sources(TorchMLIRPythonExtensions)
|
|
||||||
|
|
||||||
if (NOT TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS)
|
|
||||||
declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel
|
|
||||||
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
|
|
||||||
ADD_TO_PARENT TorchMLIRPythonSources
|
|
||||||
SOURCES
|
|
||||||
__init__.py
|
|
||||||
_dynamo_fx_importer.py
|
|
||||||
compiler_utils.py
|
|
||||||
dynamo.py
|
|
||||||
_version.py
|
|
||||||
)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
declare_mlir_python_sources(TorchMLIRPythonSources.Dialects
|
|
||||||
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
|
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
|
||||||
ADD_TO_PARENT TorchMLIRPythonSources
|
ADD_TO_PARENT TorchMLIRPythonTorchExtensionsSources
|
||||||
)
|
|
||||||
|
|
||||||
declare_mlir_dialect_python_bindings(
|
|
||||||
ADD_TO_PARENT TorchMLIRPythonSources.Dialects
|
|
||||||
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
|
|
||||||
TD_FILE dialects/TorchBinding.td
|
|
||||||
SOURCES dialects/torch/__init__.py
|
|
||||||
DIALECT_NAME torch
|
|
||||||
)
|
|
||||||
|
|
||||||
################################################################################
|
|
||||||
# Extensions
|
|
||||||
################################################################################
|
|
||||||
|
|
||||||
declare_mlir_python_extension(TorchMLIRPythonExtensions.Main
|
|
||||||
MODULE_NAME _torchMlir
|
|
||||||
ADD_TO_PARENT TorchMLIRPythonExtensions
|
|
||||||
SOURCES
|
SOURCES
|
||||||
TorchMLIRModule.cpp
|
__init__.py
|
||||||
EMBED_CAPI_LINK_LIBS
|
_dynamo_fx_importer.py
|
||||||
TorchMLIRCAPI
|
compiler_utils.py
|
||||||
PRIVATE_LINK_LIBS
|
dynamo.py
|
||||||
LLVMSupport
|
_version.py
|
||||||
)
|
)
|
||||||
|
|
||||||
################################################################################
|
################################################################################
|
||||||
|
@ -110,56 +53,23 @@ endif()
|
||||||
|
|
||||||
# add_subdirectory(torch_mlir/_torch_mlir_custom_op_example)
|
# add_subdirectory(torch_mlir/_torch_mlir_custom_op_example)
|
||||||
|
|
||||||
################################################################################
|
|
||||||
# Generate packages and shared library
|
|
||||||
# Downstreams typically will not use these, but they are useful for local
|
|
||||||
# testing.
|
|
||||||
################################################################################
|
|
||||||
|
|
||||||
set(_source_components
|
|
||||||
# TODO: Core is now implicitly building/registering all dialects, increasing
|
|
||||||
# build burden by ~5x. Make it stop.
|
|
||||||
# TODO: Reduce dependencies. We need ExecutionEngine and a bunch of passes
|
|
||||||
# for the reference backend, but logically they can be separate. But seemingly
|
|
||||||
# the only way to handle that is to create a separate mlir python package
|
|
||||||
# tree, which seems excessive.
|
|
||||||
MLIRPythonSources
|
|
||||||
MLIRPythonExtension.Core
|
|
||||||
MLIRPythonExtension.RegisterEverything
|
|
||||||
TorchMLIRPythonSources
|
|
||||||
TorchMLIRPythonExtensions
|
|
||||||
)
|
|
||||||
|
|
||||||
add_mlir_python_common_capi_library(TorchMLIRAggregateCAPI
|
|
||||||
INSTALL_COMPONENT TorchMLIRPythonModules
|
|
||||||
INSTALL_DESTINATION python_packages/torch_mlir/torch_mlir/_mlir_libs
|
|
||||||
OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs"
|
|
||||||
RELATIVE_INSTALL_ROOT "../../../.."
|
|
||||||
DECLARED_SOURCES ${_source_components}
|
|
||||||
)
|
|
||||||
|
|
||||||
add_mlir_python_modules(TorchMLIRPythonModules
|
|
||||||
ROOT_PREFIX "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir"
|
|
||||||
INSTALL_PREFIX "python_packages/torch_mlir/torch_mlir"
|
|
||||||
DECLARED_SOURCES ${_source_components}
|
|
||||||
COMMON_CAPI_LINK_LIBS
|
|
||||||
TorchMLIRAggregateCAPI
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: Find a cleaner way to do this.
|
# TODO: Find a cleaner way to do this.
|
||||||
# Can we build the JIT IR importer with `declare_mlir_python_extension`?
|
# Can we build the JIT IR importer with `declare_mlir_python_extension`?
|
||||||
# Then it would "just work".
|
# Then it would "just work".
|
||||||
if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER)
|
if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER)
|
||||||
add_dependencies(TorchMLIRPythonModules TorchMLIRJITIRImporter)
|
add_dependencies(TorchMLIRPythonTorchExtensionsSources
|
||||||
add_dependencies(TorchMLIRPythonModules TorchMLIRJITIRImporterPybind)
|
TorchMLIRJITIRImporter
|
||||||
# Build the E2E Tests (which depend on the JIT IR importer now).
|
TorchMLIRJITIRImporterPybind
|
||||||
add_dependencies(TorchMLIRPythonModules TorchMLIRE2ETestPythonModules)
|
TorchMLIRE2ETestPythonModules
|
||||||
|
)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(TORCH_MLIR_ENABLE_LTC)
|
if(TORCH_MLIR_ENABLE_LTC)
|
||||||
# Add Torch-MLIR LTC backend as dependency
|
# Add Torch-MLIR LTC backend as dependency
|
||||||
add_dependencies(TorchMLIRPythonModules torch_mlir_ltc_backend)
|
add_dependencies(TorchMLIRPythonTorchExtensionsSources
|
||||||
add_dependencies(TorchMLIRPythonModules reference_lazy_backend)
|
torch_mlir_ltc_backend
|
||||||
|
reference_lazy_backend
|
||||||
|
)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
add_subdirectory(test)
|
add_subdirectory(test)
|
||||||
|
|
|
@ -4,9 +4,9 @@
|
||||||
|
|
||||||
## Declare the sources of the Python module.
|
## Declare the sources of the Python module.
|
||||||
|
|
||||||
declare_mlir_python_sources(TorchMLIRPythonSources.JitIRImporter
|
declare_mlir_python_sources(TorchMLIRPythonTorchExtensionsSources.JitIRImporter
|
||||||
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
|
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
|
||||||
ADD_TO_PARENT TorchMLIRPythonSources
|
ADD_TO_PARENT TorchMLIRPythonTorchExtensionsSources
|
||||||
SOURCES_GLOB
|
SOURCES_GLOB
|
||||||
jit_ir_importer/*.py
|
jit_ir_importer/*.py
|
||||||
)
|
)
|
||||||
|
|
|
@ -0,0 +1,94 @@
|
||||||
|
# Disables generation of "version soname" (i.e. libFoo.so.<version>), which
|
||||||
|
# causes pure duplication as part of Python wheels.
|
||||||
|
set(CMAKE_PLATFORM_NO_VERSIONED_SONAME ON)
|
||||||
|
|
||||||
|
# The directory at which the Python import tree begins.
|
||||||
|
# See documentation for `declare_mlir_python_sources`'s ROOT_DIR
|
||||||
|
# argument.
|
||||||
|
set(TORCH_MLIR_PYTHON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/torch_mlir")
|
||||||
|
|
||||||
|
|
||||||
|
# We vendor our own MLIR instance in the `torch_mlir` namespace.
|
||||||
|
add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=torch_mlir.")
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# Sources
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
declare_mlir_python_sources(TorchMLIRPythonSources)
|
||||||
|
declare_mlir_python_sources(TorchMLIRPythonExtensions)
|
||||||
|
|
||||||
|
declare_mlir_python_sources(TorchMLIRPythonSources.Dialects
|
||||||
|
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
|
||||||
|
ADD_TO_PARENT TorchMLIRPythonSources
|
||||||
|
)
|
||||||
|
|
||||||
|
declare_mlir_dialect_python_bindings(
|
||||||
|
ADD_TO_PARENT TorchMLIRPythonSources.Dialects
|
||||||
|
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
|
||||||
|
TD_FILE dialects/TorchBinding.td
|
||||||
|
SOURCES dialects/torch/__init__.py
|
||||||
|
DIALECT_NAME torch
|
||||||
|
)
|
||||||
|
|
||||||
|
declare_mlir_python_sources(TorchMLIRPythonSources.Importers
|
||||||
|
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
|
||||||
|
ADD_TO_PARENT TorchMLIRPythonSources
|
||||||
|
SOURCES
|
||||||
|
extras/onnx_importer.py
|
||||||
|
)
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# Extensions
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
declare_mlir_python_extension(TorchMLIRPythonExtensions.Main
|
||||||
|
MODULE_NAME _torchMlir
|
||||||
|
ADD_TO_PARENT TorchMLIRPythonExtensions
|
||||||
|
SOURCES
|
||||||
|
TorchMLIRModule.cpp
|
||||||
|
EMBED_CAPI_LINK_LIBS
|
||||||
|
TorchMLIRCAPI
|
||||||
|
PRIVATE_LINK_LIBS
|
||||||
|
LLVMSupport
|
||||||
|
)
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# Generate packages and shared library
|
||||||
|
# Downstreams typically will not use these, but they are useful for local
|
||||||
|
# testing.
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
set(_source_components
|
||||||
|
# TODO: Core is now implicitly building/registering all dialects, increasing
|
||||||
|
# build burden by ~5x. Make it stop.
|
||||||
|
# TODO: Reduce dependencies. We need ExecutionEngine and a bunch of passes
|
||||||
|
# for the reference backend, but logically they can be separate. But seemingly
|
||||||
|
# the only way to handle that is to create a separate mlir python package
|
||||||
|
# tree, which seems excessive.
|
||||||
|
MLIRPythonSources
|
||||||
|
MLIRPythonExtension.Core
|
||||||
|
MLIRPythonExtension.RegisterEverything
|
||||||
|
TorchMLIRPythonSources
|
||||||
|
TorchMLIRPythonExtensions
|
||||||
|
|
||||||
|
# Sources related to optional Torch extension dependent features. Typically
|
||||||
|
# empty unless if project features are enabled.
|
||||||
|
TorchMLIRPythonTorchExtensionsSources
|
||||||
|
)
|
||||||
|
|
||||||
|
add_mlir_python_common_capi_library(TorchMLIRAggregateCAPI
|
||||||
|
INSTALL_COMPONENT TorchMLIRPythonModules
|
||||||
|
INSTALL_DESTINATION python_packages/torch_mlir/torch_mlir/_mlir_libs
|
||||||
|
OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs"
|
||||||
|
RELATIVE_INSTALL_ROOT ".."
|
||||||
|
DECLARED_SOURCES ${_source_components}
|
||||||
|
)
|
||||||
|
|
||||||
|
add_mlir_python_modules(TorchMLIRPythonModules
|
||||||
|
ROOT_PREFIX "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir"
|
||||||
|
INSTALL_PREFIX "python_packages/torch_mlir/torch_mlir"
|
||||||
|
DECLARED_SOURCES ${_source_components}
|
||||||
|
COMMON_CAPI_LINK_LIBS
|
||||||
|
TorchMLIRAggregateCAPI
|
||||||
|
)
|
|
@ -0,0 +1,607 @@
|
||||||
|
# Based on code Copyright (c) Advanced Micro Devices, Inc.
|
||||||
|
#
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
|
"""Imports ONNX graphs to `torch` dialect ops.
|
||||||
|
|
||||||
|
See documentation:
|
||||||
|
https://github.com/llvm/torch-mlir/blob/main/docs/importers/onnx_importer.md
|
||||||
|
|
||||||
|
This file is distributed/forked verbatim into various downstream projects, and
|
||||||
|
it must abide by several rules above and beyond the rest of the codebase:
|
||||||
|
- It must be standalone, only depending on:
|
||||||
|
- `onnx`
|
||||||
|
- `..ir` relative imports to the main IR directory
|
||||||
|
- `..dialects.func` relative import to the `func` dialect (TODO:
|
||||||
|
we are looking to eliminate this dep).
|
||||||
|
- Python standard library
|
||||||
|
- It does not directly use the ODS generated `torch` dialect Python
|
||||||
|
wrappers. This allows it to be used in contexts that only build a C++
|
||||||
|
compiler with minimal IR Python bindings.
|
||||||
|
- It is intended as an enabler for full onnx compilation, only handling
|
||||||
|
the import from ONNX -> the `torch` dialect. Testing, full pipelines,
|
||||||
|
and utilities belong elsewhere.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
import onnx
|
||||||
|
except ModuleNotFoundError as e:
|
||||||
|
raise ModuleNotFoundError(
|
||||||
|
"The onnx package (`pip install onnx`) is required to use the onnx importer"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ..ir import (
|
||||||
|
ArrayAttr,
|
||||||
|
Attribute,
|
||||||
|
Block,
|
||||||
|
Context,
|
||||||
|
DenseElementsAttr,
|
||||||
|
DenseResourceElementsAttr,
|
||||||
|
DictAttr,
|
||||||
|
FloatAttr,
|
||||||
|
BF16Type,
|
||||||
|
ComplexType,
|
||||||
|
F16Type,
|
||||||
|
F32Type,
|
||||||
|
F64Type,
|
||||||
|
Float8E4M3FNType,
|
||||||
|
Float8E5M2FNUZType,
|
||||||
|
Float8E5M2Type,
|
||||||
|
FunctionType,
|
||||||
|
InsertionPoint,
|
||||||
|
IntegerAttr,
|
||||||
|
IntegerType,
|
||||||
|
MLIRError,
|
||||||
|
RankedTensorType,
|
||||||
|
Location,
|
||||||
|
Module,
|
||||||
|
Operation,
|
||||||
|
StringAttr,
|
||||||
|
Type as IrType,
|
||||||
|
Value,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ..dialects import (
|
||||||
|
func as func_dialect,
|
||||||
|
)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Config:
|
||||||
|
"""Various configuration settings for the importer."""
|
||||||
|
|
||||||
|
# Ancient ONNX exporters would often add a model input for anything that
|
||||||
|
# might be mutable, providing an initializer for it as well. More modern
|
||||||
|
# tools tools realized this is a really bad idea for a lot of reasons.
|
||||||
|
# We choose to assume more recent norms, even if encountering older
|
||||||
|
# models. Setting this to False probably won't do what you want but
|
||||||
|
# should produce interesting errors to waste your time deciphering.
|
||||||
|
# We mainly use it as a way to document in the code that we are
|
||||||
|
# making an assumption.
|
||||||
|
elide_initialized_inputs: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class ModelInfo:
|
||||||
|
"""Top-level accounting and accessors for an ONNX model."""
|
||||||
|
|
||||||
|
def __init__(self, model_proto: onnx.ModelProto, *, config: Config = Config()):
|
||||||
|
self.config = config
|
||||||
|
self.model_proto = model_proto
|
||||||
|
assert model_proto.graph, "Model must contain a main Graph"
|
||||||
|
self.main_graph = GraphInfo(self, model_proto.graph)
|
||||||
|
|
||||||
|
def create_module(self, context: Optional[Context] = None) -> Operation:
|
||||||
|
if not context:
|
||||||
|
context = Context()
|
||||||
|
module_op = Module.create(Location.unknown(context)).operation
|
||||||
|
# TODO: Populate module level metadata from the ModelProto
|
||||||
|
return module_op
|
||||||
|
|
||||||
|
|
||||||
|
class GraphInfo:
|
||||||
|
"""Information about a Graph within a model."""
|
||||||
|
|
||||||
|
def __init__(self, model_info: ModelInfo, graph_proto: onnx.GraphProto):
|
||||||
|
self.model_info = model_info
|
||||||
|
self.graph_proto = graph_proto
|
||||||
|
self.initializer_map: dict[str, onnx.TensorProto] = {
|
||||||
|
n.name: n for n in graph_proto.initializer
|
||||||
|
}
|
||||||
|
self.value_info_map: dict[str, onnx.ValueInfoProto] = {
|
||||||
|
n.name: n for n in graph_proto.value_info
|
||||||
|
}
|
||||||
|
self.declared_input_map: dict[str, onnx.ValueInfoProto] = {
|
||||||
|
n.name: n for n in graph_proto.input
|
||||||
|
}
|
||||||
|
self.output_map: dict[str, onnx.ValueInfoProto] = {
|
||||||
|
n.name: n for n in graph_proto.output
|
||||||
|
}
|
||||||
|
|
||||||
|
# Generate the effective input map, which for old models can be a
|
||||||
|
# subset of the input map.
|
||||||
|
if model_info.config.elide_initialized_inputs:
|
||||||
|
self.input_map = {
|
||||||
|
k: v
|
||||||
|
for k, v in self.declared_input_map.items()
|
||||||
|
if k not in self.initializer_map
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
self.input_map = self.declared_input_map
|
||||||
|
illegal_input_keys = self.input_map.keys() - (
|
||||||
|
self.input_map.keys() - self.initializer_map.keys()
|
||||||
|
)
|
||||||
|
assert self.input_map.keys().isdisjoint(self.initializer_map.keys()), (
|
||||||
|
f"When not in elide_initialized_inputs=True, we expect inputs to not "
|
||||||
|
f"have an initial value (got {illegal_input_keys})."
|
||||||
|
)
|
||||||
|
|
||||||
|
def find_type_proto_for_name(self, name: str) -> onnx.TypeProto:
|
||||||
|
# Node outputs don't typically have type information, but shape inference
|
||||||
|
# will associate them in the value_info. If not there, it may be a
|
||||||
|
# graph output, which must have type information.
|
||||||
|
value_info = self.value_info_map.get(name) or self.output_map.get(name)
|
||||||
|
if value_info is not None:
|
||||||
|
return value_info.type
|
||||||
|
raise OnnxImportError(
|
||||||
|
f"No type information associated with '{name}'. Run shape inference?"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxImportError(Exception):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class NodeImporter:
|
||||||
|
"""Imports graph nodes into MLIR.
|
||||||
|
|
||||||
|
Typically, the top level graph will be imported into a func whereas dependent
|
||||||
|
graphs may just be imported with references to pre-existing values.
|
||||||
|
|
||||||
|
Note that ONNX requires that graphs be sorted topologically and free of cycles,
|
||||||
|
so we don't take any special steps to order them for dominance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__slots__ = [
|
||||||
|
"_c",
|
||||||
|
"_cc",
|
||||||
|
"_gi",
|
||||||
|
"_p",
|
||||||
|
"_b",
|
||||||
|
"_nv_map",
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
graph_info: GraphInfo,
|
||||||
|
*,
|
||||||
|
parent_op: Operation,
|
||||||
|
block: Block,
|
||||||
|
context_cache: "ContextCache",
|
||||||
|
):
|
||||||
|
self._c = parent_op.context
|
||||||
|
self._cc = context_cache
|
||||||
|
self._gi = graph_info
|
||||||
|
self._p = parent_op
|
||||||
|
self._b = block
|
||||||
|
self._nv_map: dict[str, Value] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_function(
|
||||||
|
cls, graph_info: GraphInfo, module_op: Operation
|
||||||
|
) -> "NodeImporter":
|
||||||
|
cc = ContextCache(module_op.context)
|
||||||
|
with module_op.context, Location.name(f"graph:{graph_info.graph_proto.name}"):
|
||||||
|
body = module_op.regions[0].blocks[0]
|
||||||
|
func_name = graph_info.graph_proto.name
|
||||||
|
input_types = [
|
||||||
|
cc.type_proto_to_type(inp.type) for inp in graph_info.input_map.values()
|
||||||
|
]
|
||||||
|
output_types = [
|
||||||
|
cc.type_proto_to_type(out.type)
|
||||||
|
for out in graph_info.output_map.values()
|
||||||
|
]
|
||||||
|
ftype = FunctionType.get(input_types, output_types)
|
||||||
|
func_op = func_dialect.FuncOp(func_name, ftype, ip=InsertionPoint(body))
|
||||||
|
block = func_op.add_entry_block(
|
||||||
|
[Location.name(k) for k in graph_info.input_map.keys()]
|
||||||
|
)
|
||||||
|
imp = NodeImporter(graph_info, parent_op=func_op, block=block, context_cache=cc)
|
||||||
|
for node_name, input_value in zip(graph_info.input_map.keys(), block.arguments):
|
||||||
|
imp._nv_map[node_name] = input_value
|
||||||
|
imp._populate_graph_attrs(func_op)
|
||||||
|
return imp
|
||||||
|
|
||||||
|
def _populate_graph_attrs(self, container_op: Operation):
|
||||||
|
"""Populates graph level meta attributes on the given container op."""
|
||||||
|
m = self._gi.model_info.model_proto
|
||||||
|
with container_op.context:
|
||||||
|
i64_type = IntegerType.get_signed(64)
|
||||||
|
default_opset_version = 0
|
||||||
|
opset_versions: dict[str, IntegerAttr] = {}
|
||||||
|
for opset_import in m.opset_import:
|
||||||
|
if opset_import.domain:
|
||||||
|
opset_versions[opset_import.domain] = IntegerAttr.get(
|
||||||
|
i64_type, opset_import.version
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
default_opset_version = opset_import.version
|
||||||
|
if default_opset_version:
|
||||||
|
container_op.attributes[
|
||||||
|
"torch.onnx_meta.opset_version"
|
||||||
|
] = IntegerAttr.get(i64_type, default_opset_version)
|
||||||
|
if opset_versions:
|
||||||
|
container_op.attributes[
|
||||||
|
"torch.onnx_meta.opset_versions"
|
||||||
|
] = DictAttr.get(opset_versions)
|
||||||
|
container_op.attributes["torch.onnx_meta.ir_version"] = IntegerAttr.get(
|
||||||
|
IntegerType.get_signed(64), m.ir_version
|
||||||
|
)
|
||||||
|
container_op.attributes["torch.onnx_meta.producer_name"] = StringAttr.get(
|
||||||
|
m.producer_name
|
||||||
|
)
|
||||||
|
container_op.attributes[
|
||||||
|
"torch.onnx_meta.producer_version"
|
||||||
|
] = StringAttr.get(m.producer_version)
|
||||||
|
|
||||||
|
def import_all(self):
|
||||||
|
"""Imports all nodes topologically."""
|
||||||
|
# TODO: Consider pulling in initializers on demand since there can be so
|
||||||
|
# much unused crap.
|
||||||
|
for init in self._gi.initializer_map.values():
|
||||||
|
self.import_initializer(init)
|
||||||
|
for node in self._gi.graph_proto.node:
|
||||||
|
self.import_node(node)
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
for output_name in self._gi.output_map.keys():
|
||||||
|
try:
|
||||||
|
outputs.append(self._nv_map[output_name])
|
||||||
|
except KeyError:
|
||||||
|
raise OnnxImportError(
|
||||||
|
f"Non topologically produced ONNX graph output '{output_name}'"
|
||||||
|
)
|
||||||
|
with InsertionPoint(self._b), Location.unknown():
|
||||||
|
func_dialect.ReturnOp(outputs)
|
||||||
|
|
||||||
|
def import_node(self, node: onnx.NodeProto):
|
||||||
|
with InsertionPoint(self._b), Location.name(node.name):
|
||||||
|
op_type = node.op_type
|
||||||
|
# Handle special op types that materialize to non-op IR constructs.
|
||||||
|
special_key = f"_handle_node_{op_type}"
|
||||||
|
if hasattr(self, special_key):
|
||||||
|
getattr(self, special_key)(node)
|
||||||
|
return
|
||||||
|
|
||||||
|
# General node import.
|
||||||
|
input_values = []
|
||||||
|
for input_name in node.input:
|
||||||
|
try:
|
||||||
|
input_values.append(self._nv_map[input_name])
|
||||||
|
except KeyError:
|
||||||
|
raise OnnxImportError(
|
||||||
|
f"Non topologically produced ONNX node input '{input_name}': {node}"
|
||||||
|
)
|
||||||
|
|
||||||
|
output_names = list(node.output)
|
||||||
|
output_types = [
|
||||||
|
self._cc.type_proto_to_type(self._gi.find_type_proto_for_name(n))
|
||||||
|
for n in output_names
|
||||||
|
]
|
||||||
|
|
||||||
|
# TODO: Attributes.
|
||||||
|
attrs = {
|
||||||
|
"name": StringAttr.get(f"onnx.{op_type}"),
|
||||||
|
}
|
||||||
|
self.import_attributes(node.attribute, attrs)
|
||||||
|
custom_op = Operation.create(
|
||||||
|
name="torch.operator",
|
||||||
|
results=output_types,
|
||||||
|
operands=input_values,
|
||||||
|
attributes=attrs,
|
||||||
|
)
|
||||||
|
for output_name, output_value in zip(output_names, custom_op.results):
|
||||||
|
self._nv_map[output_name] = output_value
|
||||||
|
|
||||||
|
def import_attributes(
|
||||||
|
self, onnx_attrs: list[onnx.AttributeProto], attrs: dict[str, Attribute]
|
||||||
|
):
|
||||||
|
for onnx_attr in onnx_attrs:
|
||||||
|
attr_type = onnx_attr.type
|
||||||
|
if attr_type not in ATTRIBUTE_TYPE_HANDLERS:
|
||||||
|
raise OnnxImportError(
|
||||||
|
f"Unhandled ONNX attribute type code {attr_type}: {onnx_attr}"
|
||||||
|
)
|
||||||
|
handler = ATTRIBUTE_TYPE_HANDLERS[attr_type]
|
||||||
|
if handler is None:
|
||||||
|
# Active skip.
|
||||||
|
continue
|
||||||
|
elif handler is False:
|
||||||
|
# Active error.
|
||||||
|
raise OnnxImportError(
|
||||||
|
f"ONNX importer does not support generic node attribute type {attr_type}. "
|
||||||
|
f"This likely means that this is a special node which requires specific "
|
||||||
|
f"handling in the importer: {onnx_attr}"
|
||||||
|
)
|
||||||
|
attrs[f"torch.onnx.{onnx_attr.name}"] = handler(onnx_attr, self._cc)
|
||||||
|
|
||||||
|
def import_initializer(self, initializer: onnx.TensorProto) -> Value:
|
||||||
|
with InsertionPoint(self._b), Location.name(initializer.name):
|
||||||
|
value_attr = self._cc.tensor_proto_to_attr(initializer)
|
||||||
|
vtensor_type = self._cc.tensor_proto_to_type(initializer)
|
||||||
|
literal_op = Operation.create(
|
||||||
|
name="torch.vtensor.literal",
|
||||||
|
results=[vtensor_type],
|
||||||
|
attributes={"value": value_attr},
|
||||||
|
)
|
||||||
|
self._nv_map[initializer.name] = literal_op.result
|
||||||
|
return literal_op.result
|
||||||
|
|
||||||
|
def _get_immediate_tensor(self, name: str) -> np.array:
|
||||||
|
try:
|
||||||
|
initializer = self._gi.initializer_map[name]
|
||||||
|
except KeyError:
|
||||||
|
raise OnnxImportError(
|
||||||
|
f"An immediate value for '{name}' was required but it is dynamically produced."
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
dtype = ELEM_TYPE_TO_NUMPY_DTYPE[initializer.data_type]
|
||||||
|
except KeyError:
|
||||||
|
raise OnnxImportError(
|
||||||
|
f"Unknown ONNX tensor element type to numpy dtype mapping: {initializer.data_type}"
|
||||||
|
)
|
||||||
|
raw_data = initializer.raw_data
|
||||||
|
if raw_data:
|
||||||
|
return np.frombuffer(raw_data, dtype=dtype).reshape(tuple(initializer.dims))
|
||||||
|
else:
|
||||||
|
raise OnnxImportError(
|
||||||
|
f"Unhandled ONNX TensorProto immediate data: {initializer}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _handle_node_ConstantOfShape(self, node: onnx.NodeProto):
|
||||||
|
# This op is special: It has an input of the shape, and in full generality
|
||||||
|
# could involve eager production of constants of variable size. In
|
||||||
|
# practice, the DNN profile for ONNX makes this very difficult to do
|
||||||
|
# and we hard-assert that the input can be resolved to an immediate
|
||||||
|
# value.
|
||||||
|
assert len(node.input) == 1
|
||||||
|
assert len(node.output) == 1
|
||||||
|
shape = self._get_immediate_tensor(node.input[0]).astype(np.int64)
|
||||||
|
value_proto = _get_attr(node, "value")
|
||||||
|
assert value_proto.type == onnx.AttributeProto.AttributeType.TENSOR
|
||||||
|
tensor_proto = value_proto.t
|
||||||
|
element_type = self._cc.tensor_element_type(tensor_proto.data_type)
|
||||||
|
vtensor_type = self._cc.get_vtensor_type(tuple(shape), element_type)
|
||||||
|
assert len(tensor_proto.dims) == 1 and tensor_proto.dims[0] == 1
|
||||||
|
try:
|
||||||
|
cb = ELEM_TYPE_SPLAT_TENSOR_PROTO_CB[tensor_proto.data_type]
|
||||||
|
except KeyError:
|
||||||
|
raise OnnxImportError(
|
||||||
|
f"Unhandled splat type for ConstantOfShape: {node} (possible missing mapping in ELEM_TYPE_SPLAT_TENSOR_PROTO_CB)"
|
||||||
|
)
|
||||||
|
value_attr = cb(tensor_proto, tuple(shape))
|
||||||
|
literal_op = Operation.create(
|
||||||
|
name="torch.vtensor.literal",
|
||||||
|
results=[vtensor_type],
|
||||||
|
attributes={"value": value_attr},
|
||||||
|
)
|
||||||
|
self._nv_map[node.output[0]] = literal_op.result
|
||||||
|
|
||||||
|
|
||||||
|
class ContextCache:
|
||||||
|
"""Caches per-context lookups of various things."""
|
||||||
|
|
||||||
|
__slots__ = [
|
||||||
|
"_c",
|
||||||
|
"_elem_type_map",
|
||||||
|
"_vtensor_type_map",
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, context: Context):
|
||||||
|
self._c = context
|
||||||
|
self._elem_type_map: dict[int, IrType] = {}
|
||||||
|
self._vtensor_type_map: dict[tuple[tuple[Optional[int]], IrType], IrType] = {}
|
||||||
|
|
||||||
|
def tensor_element_type(self, elem_type: int) -> IrType:
|
||||||
|
t = self._elem_type_map.get(elem_type)
|
||||||
|
if t is None:
|
||||||
|
try:
|
||||||
|
with self._c:
|
||||||
|
t = ELEM_TYPE_TO_IR_TYPE_CB[elem_type]()
|
||||||
|
except KeyError:
|
||||||
|
raise OnnxImportError(f"Unknown ONNX tensor element type: {elem_type}")
|
||||||
|
self._elem_type_map[elem_type] = t
|
||||||
|
return t
|
||||||
|
|
||||||
|
def get_vtensor_type(
|
||||||
|
self, dims: tuple[Optional[int]], element_type: IrType
|
||||||
|
) -> IrType:
|
||||||
|
key = (dims, element_type)
|
||||||
|
t = self._vtensor_type_map.get(key)
|
||||||
|
if t is None:
|
||||||
|
shape_asm = ",".join("?" if d is None else str(d) for d in dims)
|
||||||
|
asm = f"!torch.vtensor<[{shape_asm}],{str(element_type)}>"
|
||||||
|
try:
|
||||||
|
t = IrType.parse(asm, context=self._c)
|
||||||
|
except MLIRError as e:
|
||||||
|
raise OnnxImportError(
|
||||||
|
f"Unparseable torch type (MLIR asm format bug?): {asm}"
|
||||||
|
) from e
|
||||||
|
self._vtensor_type_map[key] = t
|
||||||
|
return t
|
||||||
|
|
||||||
|
def tensor_proto_to_type(self, tp: onnx.TensorProto) -> IrType:
|
||||||
|
element_type = self.tensor_element_type(tp.data_type)
|
||||||
|
return self.get_vtensor_type(tuple(tp.dims), element_type)
|
||||||
|
|
||||||
|
def tensor_proto_to_builtin_type(self, tp: onnx.TensorProto) -> IrType:
|
||||||
|
element_type = self.tensor_element_type(tp.data_type)
|
||||||
|
# TODO: Fixme upstream: RankedTensorType.get should not require a location.
|
||||||
|
with Location.unknown():
|
||||||
|
return RankedTensorType.get(tuple(tp.dims), element_type)
|
||||||
|
|
||||||
|
def type_proto_to_type(self, tp: onnx.TypeProto) -> IrType:
|
||||||
|
if tp.tensor_type:
|
||||||
|
tt = tp.tensor_type
|
||||||
|
if not tt.shape:
|
||||||
|
raise OnnxImportError(
|
||||||
|
f"Unsupported Tensor type without shape (run shape inference?): {tp}"
|
||||||
|
)
|
||||||
|
element_type = self.tensor_element_type(tt.elem_type)
|
||||||
|
dims = tuple(
|
||||||
|
(d.dim_value if not d.dim_param else None) for d in tt.shape.dim
|
||||||
|
)
|
||||||
|
return self.get_vtensor_type(dims, element_type)
|
||||||
|
else:
|
||||||
|
# TODO: Others if ever needed. Or we consider ourselves DNN-only.
|
||||||
|
# See TypeProto: sequence_type, map_type, optional_type, sparse_tensor_type.
|
||||||
|
raise OnnxImportError(f"Unsupported ONNX TypeProto: {tp}")
|
||||||
|
|
||||||
|
def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute:
|
||||||
|
tensor_type = self.tensor_proto_to_builtin_type(tp)
|
||||||
|
if tp.HasField("raw_data"):
|
||||||
|
# Conveniently, DenseResourceElementsAttr shares the raw data
|
||||||
|
# format. We just give it maximum numeric alignment.
|
||||||
|
return DenseResourceElementsAttr.get_from_buffer(
|
||||||
|
tp.raw_data, tp.name, tensor_type, alignment=8
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# We have to do a data type specific instantiation from proto fields.
|
||||||
|
# Since this is typically used for small tensor constants, we instantiate
|
||||||
|
# as a DenseElementsAttr.
|
||||||
|
handler = ELEM_TYPE_INLINE_TENSOR_PROTO_CB.get(tp.data_type)
|
||||||
|
if handler is None:
|
||||||
|
raise OnnxImportError(f"Unhandled ONNX TensorProto data: {tp}")
|
||||||
|
return handler(tp)
|
||||||
|
|
||||||
|
|
||||||
|
ELEM_TYPE_TO_IR_TYPE_CB = {
|
||||||
|
onnx.TensorProto.DataType.FLOAT: lambda: F32Type.get(),
|
||||||
|
onnx.TensorProto.DataType.UINT8: lambda: IntegerType.get_unsigned(8),
|
||||||
|
onnx.TensorProto.DataType.INT8: lambda: IntegerType.get_signed(8),
|
||||||
|
onnx.TensorProto.DataType.UINT16: lambda: IntegerType.get_unsigned(16),
|
||||||
|
onnx.TensorProto.DataType.INT16: lambda: IntegerType.get_signed(16),
|
||||||
|
onnx.TensorProto.DataType.INT32: lambda: IntegerType.get_signed(32),
|
||||||
|
onnx.TensorProto.DataType.INT64: lambda: IntegerType.get_signed(64),
|
||||||
|
onnx.TensorProto.DataType.BOOL: lambda: IntegerType.get_signless(1),
|
||||||
|
onnx.TensorProto.DataType.FLOAT16: lambda: F16Type.get(),
|
||||||
|
onnx.TensorProto.DataType.DOUBLE: lambda: F64Type.get(),
|
||||||
|
onnx.TensorProto.DataType.UINT32: lambda: IntegerType.get_unsigned(32),
|
||||||
|
onnx.TensorProto.DataType.UINT64: lambda: IntegerType.get_unsigned(64),
|
||||||
|
onnx.TensorProto.DataType.COMPLEX64: lambda: ComplexType.get(F32Type.get()),
|
||||||
|
onnx.TensorProto.DataType.COMPLEX128: lambda: ComplexType.get(F64Type.get()),
|
||||||
|
onnx.TensorProto.DataType.BFLOAT16: lambda: BF16Type.get(),
|
||||||
|
onnx.TensorProto.DataType.FLOAT8E4M3FN: lambda: Float8E4M3FNType.get(),
|
||||||
|
onnx.TensorProto.DataType.FLOAT8E4M3FNUZ: lambda: Float8E5M2FNUZType.get(),
|
||||||
|
onnx.TensorProto.DataType.FLOAT8E5M2: lambda: Float8E5M2Type.get(),
|
||||||
|
onnx.TensorProto.DataType.FLOAT8E5M2FNUZ: lambda: Float8E5M2FNUZType.get(),
|
||||||
|
# Ommitted: STRING,
|
||||||
|
}
|
||||||
|
|
||||||
|
ELEM_TYPE_SPLAT_TENSOR_PROTO_CB = {
|
||||||
|
onnx.TensorProto.DataType.FLOAT: lambda tp, shape: DenseElementsAttr.get_splat(
|
||||||
|
RankedTensorType.get(shape, F32Type.get()), FloatAttr.get_f32(tp.float_data[0])
|
||||||
|
),
|
||||||
|
# TODO: All the rest from ELEM_TYPE_TO_IR_TYPE_CB
|
||||||
|
}
|
||||||
|
|
||||||
|
# Mapping of TensorProto.DataType to lambda TensorProto, returning a DenseElementsAttr
|
||||||
|
# of the builtin tensor type for cases where the tensor data is inlined as typed
|
||||||
|
# values instead of raw_data.
|
||||||
|
ELEM_TYPE_INLINE_TENSOR_PROTO_CB = {
|
||||||
|
onnx.TensorProto.DataType.FLOAT: lambda tp: DenseElementsAttr.get(
|
||||||
|
np.asarray(tp.float_data, dtype=np.float32).reshape(tp.dims), signless=False
|
||||||
|
),
|
||||||
|
onnx.TensorProto.DataType.INT32: lambda tp: DenseElementsAttr.get(
|
||||||
|
np.asarray(tp.int32_data, dtype=np.int32).reshape(tp.dims), signless=False
|
||||||
|
),
|
||||||
|
onnx.TensorProto.DataType.INT64: lambda tp: DenseElementsAttr.get(
|
||||||
|
np.asarray(tp.int64_data, dtype=np.int64).reshape(tp.dims), signless=False
|
||||||
|
),
|
||||||
|
onnx.TensorProto.DataType.DOUBLE: lambda tp: DenseElementsAttr.get(
|
||||||
|
np.asarray(tp.double_data, dtype=np.float64).reshape(tp.dims)
|
||||||
|
),
|
||||||
|
onnx.TensorProto.DataType.UINT32: lambda tp: DenseElementsAttr.get(
|
||||||
|
# Special case. See proto
|
||||||
|
np.asarray(tp.uint64_data, dtype=np.uint32).reshape(tp.dims),
|
||||||
|
signless=False,
|
||||||
|
),
|
||||||
|
onnx.TensorProto.DataType.UINT64: lambda tp: DenseElementsAttr.get(
|
||||||
|
np.asarray(tp.uint64_data, dtype=np.uint64).reshape(tp.dims), signless=False
|
||||||
|
)
|
||||||
|
# Intentionally unsupported: STRING
|
||||||
|
}
|
||||||
|
|
||||||
|
ELEM_TYPE_TO_NUMPY_DTYPE = {
|
||||||
|
onnx.TensorProto.DataType.FLOAT: np.float32,
|
||||||
|
onnx.TensorProto.DataType.UINT8: np.uint8,
|
||||||
|
onnx.TensorProto.DataType.INT8: np.int8,
|
||||||
|
onnx.TensorProto.DataType.UINT16: np.uint16,
|
||||||
|
onnx.TensorProto.DataType.INT16: np.int16,
|
||||||
|
onnx.TensorProto.DataType.INT32: np.int32,
|
||||||
|
onnx.TensorProto.DataType.INT64: np.int64,
|
||||||
|
onnx.TensorProto.DataType.BOOL: np.bool_,
|
||||||
|
onnx.TensorProto.DataType.FLOAT16: np.float16,
|
||||||
|
onnx.TensorProto.DataType.DOUBLE: np.float64,
|
||||||
|
onnx.TensorProto.DataType.UINT32: np.uint32,
|
||||||
|
onnx.TensorProto.DataType.UINT64: np.uint64,
|
||||||
|
onnx.TensorProto.DataType.COMPLEX64: np.complex64,
|
||||||
|
onnx.TensorProto.DataType.COMPLEX128: np.complex128,
|
||||||
|
# onnx.TensorProto.DataType.BFLOAT16:
|
||||||
|
# onnx.TensorProto.DataType.FLOAT8E4M3FN:
|
||||||
|
# onnx.TensorProto.DataType.FLOAT8E4M3FNUZ:
|
||||||
|
# onnx.TensorProto.DataType.FLOAT8E5M2:
|
||||||
|
# onnx.TensorProto.DataType.FLOAT8E5M2FNUZ:
|
||||||
|
# Ommitted: STRING,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Mapping of AttributeType code to one of:
|
||||||
|
# None: Ignore attribute and do not output to MLIR
|
||||||
|
# False: Error if an attribute of this type is present
|
||||||
|
# lambda a:AttributeProto, cc: ContextCache that returns an MLIR Attribute
|
||||||
|
ATTRIBUTE_TYPE_HANDLERS = {
|
||||||
|
onnx.AttributeProto.AttributeType.UNDEFINED: False,
|
||||||
|
onnx.AttributeProto.AttributeType.FLOAT: lambda a, cc: FloatAttr.get(
|
||||||
|
F32Type.get(), a.f
|
||||||
|
),
|
||||||
|
onnx.AttributeProto.AttributeType.INT: lambda a, cc: IntegerAttr.get(
|
||||||
|
IntegerType.get_signed(64), a.i
|
||||||
|
),
|
||||||
|
onnx.AttributeProto.AttributeType.STRING: lambda a, cc: StringAttr.get(a.s),
|
||||||
|
onnx.AttributeProto.AttributeType.TENSOR: lambda a, cc: cc.tensor_proto_to_attr(
|
||||||
|
a.t
|
||||||
|
),
|
||||||
|
onnx.AttributeProto.AttributeType.GRAPH: False,
|
||||||
|
onnx.AttributeProto.AttributeType.SPARSE_TENSOR: False,
|
||||||
|
onnx.AttributeProto.AttributeType.TYPE_PROTO: False,
|
||||||
|
onnx.AttributeProto.AttributeType.FLOATS: lambda a, cc: ArrayAttr.get(
|
||||||
|
[FloatAttr.get(F32Type.get(), f) for f in a.floats]
|
||||||
|
),
|
||||||
|
onnx.AttributeProto.AttributeType.INTS: lambda a, cc: ArrayAttr.get(
|
||||||
|
[IntegerAttr.get(IntegerType.get_signed(64), i) for i in a.ints]
|
||||||
|
),
|
||||||
|
onnx.AttributeProto.AttributeType.STRINGS: lambda a, cc: ArrayAttr.get(
|
||||||
|
[StringAttr.get(s) for s in a.strings]
|
||||||
|
),
|
||||||
|
onnx.AttributeProto.AttributeType.TENSORS: lambda a, cc: ArrayAttr.get(
|
||||||
|
[cc.tensor_proto_to_attr(t) for t in a.tensors]
|
||||||
|
),
|
||||||
|
onnx.AttributeProto.AttributeType.GRAPHS: False,
|
||||||
|
onnx.AttributeProto.AttributeType.SPARSE_TENSORS: False,
|
||||||
|
onnx.AttributeProto.AttributeType.TYPE_PROTOS: False,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_attr(node: onnx.NodeProto, attr_name: str) -> onnx.AttributeProto:
|
||||||
|
for attr in node.attribute:
|
||||||
|
if attr.name == attr_name:
|
||||||
|
return attr
|
||||||
|
else:
|
||||||
|
raise OnnxImportError(f"Required attribute {attr_name} not found in {node}")
|
41
setup.py
41
setup.py
|
@ -47,8 +47,6 @@ PACKAGE_VERSION = os.environ.get("TORCH_MLIR_PYTHON_PACKAGE_VERSION") or "0.0.1"
|
||||||
# If true, enable LTC build by default
|
# If true, enable LTC build by default
|
||||||
TORCH_MLIR_ENABLE_LTC_DEFAULT = True
|
TORCH_MLIR_ENABLE_LTC_DEFAULT = True
|
||||||
TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS = int(os.environ.get('TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS', False))
|
TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS = int(os.environ.get('TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS', False))
|
||||||
if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS:
|
|
||||||
import torch
|
|
||||||
|
|
||||||
# Build phase discovery is unreliable. Just tell it what phases to run.
|
# Build phase discovery is unreliable. Just tell it what phases to run.
|
||||||
class CustomBuild(_build):
|
class CustomBuild(_build):
|
||||||
|
@ -91,7 +89,7 @@ class CMakeBuild(build_py):
|
||||||
f"-DCMAKE_C_VISIBILITY_PRESET=hidden",
|
f"-DCMAKE_C_VISIBILITY_PRESET=hidden",
|
||||||
f"-DCMAKE_CXX_VISIBILITY_PRESET=hidden",
|
f"-DCMAKE_CXX_VISIBILITY_PRESET=hidden",
|
||||||
f"-DTORCH_MLIR_ENABLE_LTC={'ON' if enable_ltc else 'OFF'}",
|
f"-DTORCH_MLIR_ENABLE_LTC={'ON' if enable_ltc else 'OFF'}",
|
||||||
f"-DTORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS={'ON' if TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else 'OFF'}",
|
f"-DTORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS={'OFF' if TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else 'ON'}",
|
||||||
]
|
]
|
||||||
|
|
||||||
os.makedirs(cmake_build_dir, exist_ok=True)
|
os.makedirs(cmake_build_dir, exist_ok=True)
|
||||||
|
@ -145,8 +143,31 @@ with open("README.md", "r", encoding="utf-8") as fh:
|
||||||
long_description = fh.read()
|
long_description = fh.read()
|
||||||
|
|
||||||
|
|
||||||
|
# Requires and extension modules depend on whether building PyTorch
|
||||||
|
# extensions.
|
||||||
|
INSTALL_REQUIRES = [
|
||||||
|
"numpy",
|
||||||
|
"packaging",
|
||||||
|
]
|
||||||
|
EXT_MODULES = [
|
||||||
|
CMakeExtension("torch_mlir._mlir_libs._torchMlir"),
|
||||||
|
]
|
||||||
|
NAME = "torch-mlir-core"
|
||||||
|
|
||||||
|
# If building PyTorch extensions, customize.
|
||||||
|
if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS:
|
||||||
|
import torch
|
||||||
|
NAME = "torch-mlir"
|
||||||
|
INSTALL_REQUIRES.extend([
|
||||||
|
f"torch=={torch.__version__}".split("+", 1)[0],
|
||||||
|
])
|
||||||
|
EXT_MODULES.extend([
|
||||||
|
CMakeExtension("torch_mlir._mlir_libs._jit_ir_importer"),
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="torch-mlir" if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else "torch-mlir-core",
|
name=NAME,
|
||||||
version=f"{PACKAGE_VERSION}",
|
version=f"{PACKAGE_VERSION}",
|
||||||
author="Sean Silva",
|
author="Sean Silva",
|
||||||
author_email="silvasean@google.com",
|
author_email="silvasean@google.com",
|
||||||
|
@ -159,10 +180,12 @@ setup(
|
||||||
"built_ext": NoopBuildExtension,
|
"built_ext": NoopBuildExtension,
|
||||||
"build_py": CMakeBuild,
|
"build_py": CMakeBuild,
|
||||||
},
|
},
|
||||||
ext_modules=[
|
ext_modules=EXT_MODULES,
|
||||||
CMakeExtension("torch_mlir._mlir_libs._jit_ir_importer"),
|
install_requires=INSTALL_REQUIRES,
|
||||||
] if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else [CMakeExtension("torch_mlir._mlir_libs._torchMlir")],
|
extras_require={
|
||||||
install_requires=["numpy", "packaging"] + (
|
"onnx": [
|
||||||
[f"torch=={torch.__version__}".split("+", 1)[0], ] if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else []),
|
"onnx>=1.15.0",
|
||||||
|
],
|
||||||
|
}
|
||||||
zip_safe=False,
|
zip_safe=False,
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
pillow
|
pillow
|
||||||
dill
|
dill
|
||||||
multiprocess
|
multiprocess
|
||||||
|
onnx==1.15.0
|
|
@ -0,0 +1,2 @@
|
||||||
|
if not config.enable_bindings_python:
|
||||||
|
config.unsupported = True
|
|
@ -0,0 +1 @@
|
||||||
|
output/
|
|
@ -0,0 +1,19 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
|
# RUN: %PYTHON %s
|
||||||
|
|
||||||
|
"""This file exists so that the tests can find/configure torch_mlir.
|
||||||
|
|
||||||
|
It allows the test file to be standalone and used verbatim in other
|
||||||
|
projects (i.e. by just providing this file on the side).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from torch_mlir import ir
|
||||||
|
from torch_mlir.extras import onnx_importer
|
||||||
|
|
||||||
|
def configure_context(context):
|
||||||
|
from torch_mlir.dialects import torch as torch_d
|
||||||
|
torch_d.register_dialect(context)
|
|
@ -0,0 +1,374 @@
|
||||||
|
# Based on code Copyright (c) Advanced Micro Devices, Inc.
|
||||||
|
#
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
|
# RUN: %PYTHON %s --output %t
|
||||||
|
|
||||||
|
from glob import glob
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import onnx
|
||||||
|
|
||||||
|
from _torch_mlir_config import (
|
||||||
|
configure_context,
|
||||||
|
ir,
|
||||||
|
onnx_importer,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Accept the output path on the command line or default to a sibling
|
||||||
|
# to this file. We have to pop this off explicitly or else unittest
|
||||||
|
# won't understand.
|
||||||
|
if len(sys.argv) > 1 and sys.argv[1] == "--output":
|
||||||
|
OUTPUT_PATH = Path(sys.argv[2])
|
||||||
|
del sys.argv[1:3]
|
||||||
|
else:
|
||||||
|
OUTPUT_PATH = Path(__file__).resolve().parent / "output"
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Add some verification and overrides. For now, just use the
|
||||||
|
# onnx package install for onnx test files, since they were nice
|
||||||
|
# enough to include the test suite in the deployable.
|
||||||
|
import onnx.backend.test.data
|
||||||
|
|
||||||
|
ONNX_TEST_DATA_DIR = Path(onnx.backend.test.__file__).resolve().parent / "data"
|
||||||
|
print(f"ONNX Test Data Dir: {ONNX_TEST_DATA_DIR}")
|
||||||
|
ONNX_REL_PATHS = glob(f"**/*.onnx", root_dir=ONNX_TEST_DATA_DIR, recursive=True)
|
||||||
|
|
||||||
|
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
TEST_CAST_XFAILS = [
|
||||||
|
"light_light_bvlc_alexnet",
|
||||||
|
"light_light_inception_v1",
|
||||||
|
"light_light_squeezenet",
|
||||||
|
"light_light_vgg19",
|
||||||
|
"node_test_affine_grid_2d_align_corners_expanded_model",
|
||||||
|
"node_test_affine_grid_2d_expanded_model",
|
||||||
|
"node_test_affine_grid_3d_align_corners_expanded_model",
|
||||||
|
"node_test_affine_grid_3d_expanded_model",
|
||||||
|
"node_test_ai_onnx_ml_label_encoder_string_int_model",
|
||||||
|
"node_test_ai_onnx_ml_label_encoder_string_int_no_default_model",
|
||||||
|
"node_test_ai_onnx_ml_label_encoder_tensor_mapping_model",
|
||||||
|
"node_test_ai_onnx_ml_label_encoder_tensor_value_only_mapping_model",
|
||||||
|
"node_test_cast_FLOAT16_to_FLOAT8E4M3FNUZ_model",
|
||||||
|
"node_test_cast_FLOAT16_to_FLOAT8E4M3FN_model",
|
||||||
|
"node_test_cast_FLOAT16_to_FLOAT8E5M2FNUZ_model",
|
||||||
|
"node_test_cast_FLOAT16_to_FLOAT8E5M2_model",
|
||||||
|
"node_test_cast_FLOAT8E4M3FNUZ_to_FLOAT16_model",
|
||||||
|
"node_test_cast_FLOAT8E4M3FNUZ_to_FLOAT_model",
|
||||||
|
"node_test_cast_FLOAT8E4M3FN_to_FLOAT16_model",
|
||||||
|
"node_test_cast_FLOAT8E4M3FN_to_FLOAT_model",
|
||||||
|
"node_test_cast_FLOAT8E5M2FNUZ_to_FLOAT16_model",
|
||||||
|
"node_test_cast_FLOAT8E5M2FNUZ_to_FLOAT_model",
|
||||||
|
"node_test_cast_FLOAT8E5M2_to_FLOAT16_model",
|
||||||
|
"node_test_cast_FLOAT8E5M2_to_FLOAT_model",
|
||||||
|
"node_test_cast_FLOAT_to_FLOAT8E4M3FNUZ_model",
|
||||||
|
"node_test_cast_FLOAT_to_FLOAT8E4M3FN_model",
|
||||||
|
"node_test_cast_FLOAT_to_FLOAT8E5M2FNUZ_model",
|
||||||
|
"node_test_cast_FLOAT_to_FLOAT8E5M2_model",
|
||||||
|
"node_test_cast_FLOAT_to_STRING_model",
|
||||||
|
"node_test_cast_STRING_to_FLOAT_model",
|
||||||
|
"node_test_cast_no_saturate_FLOAT16_to_FLOAT8E4M3FNUZ_model",
|
||||||
|
"node_test_cast_no_saturate_FLOAT16_to_FLOAT8E4M3FN_model",
|
||||||
|
"node_test_cast_no_saturate_FLOAT16_to_FLOAT8E5M2FNUZ_model",
|
||||||
|
"node_test_cast_no_saturate_FLOAT16_to_FLOAT8E5M2_model",
|
||||||
|
"node_test_cast_no_saturate_FLOAT_to_FLOAT8E4M3FNUZ_model",
|
||||||
|
"node_test_cast_no_saturate_FLOAT_to_FLOAT8E4M3FN_model",
|
||||||
|
"node_test_cast_no_saturate_FLOAT_to_FLOAT8E5M2FNUZ_model",
|
||||||
|
"node_test_cast_no_saturate_FLOAT_to_FLOAT8E5M2_model",
|
||||||
|
"node_test_castlike_FLOAT8E4M3FNUZ_to_FLOAT_expanded_model",
|
||||||
|
"node_test_castlike_FLOAT8E4M3FNUZ_to_FLOAT_model",
|
||||||
|
"node_test_castlike_FLOAT8E4M3FN_to_FLOAT_expanded_model",
|
||||||
|
"node_test_castlike_FLOAT8E4M3FN_to_FLOAT_model",
|
||||||
|
"node_test_castlike_FLOAT8E5M2FNUZ_to_FLOAT_expanded_model",
|
||||||
|
"node_test_castlike_FLOAT8E5M2FNUZ_to_FLOAT_model",
|
||||||
|
"node_test_castlike_FLOAT8E5M2_to_FLOAT_expanded_model",
|
||||||
|
"node_test_castlike_FLOAT8E5M2_to_FLOAT_model",
|
||||||
|
"node_test_castlike_FLOAT_to_FLOAT8E4M3FNUZ_expanded_model",
|
||||||
|
"node_test_castlike_FLOAT_to_FLOAT8E4M3FNUZ_model",
|
||||||
|
"node_test_castlike_FLOAT_to_FLOAT8E4M3FN_expanded_model",
|
||||||
|
"node_test_castlike_FLOAT_to_FLOAT8E4M3FN_model",
|
||||||
|
"node_test_castlike_FLOAT_to_FLOAT8E5M2FNUZ_expanded_model",
|
||||||
|
"node_test_castlike_FLOAT_to_FLOAT8E5M2FNUZ_model",
|
||||||
|
"node_test_castlike_FLOAT_to_FLOAT8E5M2_expanded_model",
|
||||||
|
"node_test_castlike_FLOAT_to_FLOAT8E5M2_model",
|
||||||
|
"node_test_castlike_FLOAT_to_STRING_expanded_model",
|
||||||
|
"node_test_castlike_FLOAT_to_STRING_model",
|
||||||
|
"node_test_castlike_STRING_to_FLOAT_expanded_model",
|
||||||
|
"node_test_castlike_STRING_to_FLOAT_model",
|
||||||
|
"node_test_center_crop_pad_crop_axes_chw_expanded_model",
|
||||||
|
"node_test_center_crop_pad_crop_axes_hwc_expanded_model",
|
||||||
|
"node_test_center_crop_pad_crop_negative_axes_hwc_expanded_model",
|
||||||
|
"node_test_clip_default_inbounds_model",
|
||||||
|
"node_test_clip_default_int8_inbounds_model",
|
||||||
|
"node_test_clip_default_int8_max_model",
|
||||||
|
"node_test_clip_default_max_model",
|
||||||
|
"node_test_constantofshape_float_ones_model",
|
||||||
|
"node_test_constantofshape_int_shape_zero_model",
|
||||||
|
"node_test_constantofshape_int_zeros_model",
|
||||||
|
"node_test_dequantizelinear_e4m3fn_model",
|
||||||
|
"node_test_dequantizelinear_e4m3fn_zero_point_model",
|
||||||
|
"node_test_dequantizelinear_e5m2_model",
|
||||||
|
"node_test_dft_axis_model",
|
||||||
|
"node_test_dft_inverse_model",
|
||||||
|
"node_test_dft_model",
|
||||||
|
"node_test_equal_string_broadcast_model",
|
||||||
|
"node_test_equal_string_model",
|
||||||
|
"node_test_gru_defaults_model",
|
||||||
|
"node_test_gru_seq_length_model",
|
||||||
|
"node_test_gru_with_initial_bias_model",
|
||||||
|
"node_test_identity_opt_model",
|
||||||
|
"node_test_identity_sequence_model",
|
||||||
|
"node_test_if_model",
|
||||||
|
"node_test_if_opt_model",
|
||||||
|
"node_test_if_seq_model",
|
||||||
|
"node_test_layer_normalization_2d_axis0_expanded_model",
|
||||||
|
"node_test_layer_normalization_2d_axis0_expanded_ver18_model",
|
||||||
|
"node_test_layer_normalization_2d_axis1_expanded_model",
|
||||||
|
"node_test_layer_normalization_2d_axis1_expanded_ver18_model",
|
||||||
|
"node_test_layer_normalization_2d_axis_negative_1_expanded_model",
|
||||||
|
"node_test_layer_normalization_2d_axis_negative_1_expanded_ver18_model",
|
||||||
|
"node_test_layer_normalization_2d_axis_negative_2_expanded_model",
|
||||||
|
"node_test_layer_normalization_2d_axis_negative_2_expanded_ver18_model",
|
||||||
|
"node_test_layer_normalization_3d_axis0_epsilon_expanded_model",
|
||||||
|
"node_test_layer_normalization_3d_axis0_epsilon_expanded_ver18_model",
|
||||||
|
"node_test_layer_normalization_3d_axis1_epsilon_expanded_model",
|
||||||
|
"node_test_layer_normalization_3d_axis1_epsilon_expanded_ver18_model",
|
||||||
|
"node_test_layer_normalization_3d_axis2_epsilon_expanded_model",
|
||||||
|
"node_test_layer_normalization_3d_axis2_epsilon_expanded_ver18_model",
|
||||||
|
"node_test_layer_normalization_3d_axis_negative_1_epsilon_expanded_model",
|
||||||
|
"node_test_layer_normalization_3d_axis_negative_1_epsilon_expanded_ver18_model",
|
||||||
|
"node_test_layer_normalization_3d_axis_negative_2_epsilon_expanded_model",
|
||||||
|
"node_test_layer_normalization_3d_axis_negative_2_epsilon_expanded_ver18_model",
|
||||||
|
"node_test_layer_normalization_3d_axis_negative_3_epsilon_expanded_model",
|
||||||
|
"node_test_layer_normalization_3d_axis_negative_3_epsilon_expanded_ver18_model",
|
||||||
|
"node_test_layer_normalization_4d_axis0_expanded_model",
|
||||||
|
"node_test_layer_normalization_4d_axis0_expanded_ver18_model",
|
||||||
|
"node_test_layer_normalization_4d_axis1_expanded_model",
|
||||||
|
"node_test_layer_normalization_4d_axis1_expanded_ver18_model",
|
||||||
|
"node_test_layer_normalization_4d_axis2_expanded_model",
|
||||||
|
"node_test_layer_normalization_4d_axis2_expanded_ver18_model",
|
||||||
|
"node_test_layer_normalization_4d_axis3_expanded_model",
|
||||||
|
"node_test_layer_normalization_4d_axis3_expanded_ver18_model",
|
||||||
|
"node_test_layer_normalization_4d_axis_negative_1_expanded_model",
|
||||||
|
"node_test_layer_normalization_4d_axis_negative_1_expanded_ver18_model",
|
||||||
|
"node_test_layer_normalization_4d_axis_negative_2_expanded_model",
|
||||||
|
"node_test_layer_normalization_4d_axis_negative_2_expanded_ver18_model",
|
||||||
|
"node_test_layer_normalization_4d_axis_negative_3_expanded_model",
|
||||||
|
"node_test_layer_normalization_4d_axis_negative_3_expanded_ver18_model",
|
||||||
|
"node_test_layer_normalization_4d_axis_negative_4_expanded_model",
|
||||||
|
"node_test_layer_normalization_4d_axis_negative_4_expanded_ver18_model",
|
||||||
|
"node_test_layer_normalization_default_axis_expanded_model",
|
||||||
|
"node_test_layer_normalization_default_axis_expanded_ver18_model",
|
||||||
|
"node_test_loop11_model",
|
||||||
|
"node_test_loop13_seq_model",
|
||||||
|
"node_test_loop16_seq_none_model",
|
||||||
|
"node_test_lstm_defaults_model",
|
||||||
|
"node_test_lstm_with_initial_bias_model",
|
||||||
|
"node_test_lstm_with_peepholes_model",
|
||||||
|
"node_test_optional_get_element_optional_sequence_model",
|
||||||
|
"node_test_optional_get_element_optional_tensor_model",
|
||||||
|
"node_test_optional_get_element_sequence_model",
|
||||||
|
"node_test_optional_has_element_empty_no_input_name_optional_input_model",
|
||||||
|
"node_test_optional_has_element_empty_no_input_name_tensor_input_model",
|
||||||
|
"node_test_optional_has_element_empty_optional_input_model",
|
||||||
|
"node_test_optional_has_element_optional_input_model",
|
||||||
|
"node_test_optional_has_element_tensor_input_model",
|
||||||
|
"node_test_quantizelinear_e4m3fn_model",
|
||||||
|
"node_test_quantizelinear_e5m2_model",
|
||||||
|
"node_test_range_float_type_positive_delta_expanded_model",
|
||||||
|
"node_test_range_int32_type_negative_delta_expanded_model",
|
||||||
|
"node_test_regex_full_match_basic_model",
|
||||||
|
"node_test_regex_full_match_email_domain_model",
|
||||||
|
"node_test_regex_full_match_empty_model",
|
||||||
|
"node_test_resize_downsample_scales_cubic_A_n0p5_exclude_outside_model",
|
||||||
|
"node_test_resize_downsample_scales_cubic_align_corners_model",
|
||||||
|
"node_test_resize_downsample_scales_cubic_antialias_model",
|
||||||
|
"node_test_resize_downsample_scales_cubic_model",
|
||||||
|
"node_test_resize_downsample_scales_linear_align_corners_model",
|
||||||
|
"node_test_resize_downsample_scales_linear_antialias_model",
|
||||||
|
"node_test_resize_downsample_scales_linear_half_pixel_symmetric_model",
|
||||||
|
"node_test_resize_downsample_scales_linear_model",
|
||||||
|
"node_test_resize_downsample_scales_nearest_model",
|
||||||
|
"node_test_resize_downsample_sizes_cubic_antialias_model",
|
||||||
|
"node_test_resize_downsample_sizes_cubic_model",
|
||||||
|
"node_test_resize_downsample_sizes_linear_antialias_model",
|
||||||
|
"node_test_resize_downsample_sizes_linear_pytorch_half_pixel_model",
|
||||||
|
"node_test_resize_downsample_sizes_nearest_model",
|
||||||
|
"node_test_resize_downsample_sizes_nearest_not_larger_model",
|
||||||
|
"node_test_resize_downsample_sizes_nearest_not_smaller_model",
|
||||||
|
"node_test_resize_tf_crop_and_resize_axes_2_3_model",
|
||||||
|
"node_test_resize_tf_crop_and_resize_axes_3_2_model",
|
||||||
|
"node_test_resize_tf_crop_and_resize_model",
|
||||||
|
"node_test_resize_upsample_scales_cubic_A_n0p5_exclude_outside_model",
|
||||||
|
"node_test_resize_upsample_scales_cubic_align_corners_model",
|
||||||
|
"node_test_resize_upsample_scales_cubic_asymmetric_model",
|
||||||
|
"node_test_resize_upsample_scales_cubic_model",
|
||||||
|
"node_test_resize_upsample_scales_linear_align_corners_model",
|
||||||
|
"node_test_resize_upsample_scales_linear_half_pixel_symmetric_model",
|
||||||
|
"node_test_resize_upsample_scales_linear_model",
|
||||||
|
"node_test_resize_upsample_scales_nearest_axes_2_3_model",
|
||||||
|
"node_test_resize_upsample_scales_nearest_axes_3_2_model",
|
||||||
|
"node_test_resize_upsample_scales_nearest_model",
|
||||||
|
"node_test_resize_upsample_sizes_cubic_model",
|
||||||
|
"node_test_resize_upsample_sizes_nearest_axes_2_3_model",
|
||||||
|
"node_test_resize_upsample_sizes_nearest_axes_3_2_model",
|
||||||
|
"node_test_resize_upsample_sizes_nearest_ceil_half_pixel_model",
|
||||||
|
"node_test_resize_upsample_sizes_nearest_floor_align_corners_model",
|
||||||
|
"node_test_resize_upsample_sizes_nearest_model",
|
||||||
|
"node_test_resize_upsample_sizes_nearest_not_larger_model",
|
||||||
|
"node_test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric_model",
|
||||||
|
"node_test_rnn_seq_length_model",
|
||||||
|
"node_test_scan9_sum_model",
|
||||||
|
"node_test_scan_sum_model",
|
||||||
|
"node_test_sequence_insert_at_back_model",
|
||||||
|
"node_test_sequence_insert_at_front_model",
|
||||||
|
"node_test_sequence_map_add_1_sequence_1_tensor_expanded_model",
|
||||||
|
"node_test_sequence_map_add_1_sequence_1_tensor_model",
|
||||||
|
"node_test_sequence_map_add_2_sequences_expanded_model",
|
||||||
|
"node_test_sequence_map_add_2_sequences_model",
|
||||||
|
"node_test_sequence_map_extract_shapes_expanded_model",
|
||||||
|
"node_test_sequence_map_extract_shapes_model",
|
||||||
|
"node_test_sequence_map_identity_1_sequence_1_tensor_expanded_model",
|
||||||
|
"node_test_sequence_map_identity_1_sequence_1_tensor_model",
|
||||||
|
"node_test_sequence_map_identity_1_sequence_expanded_model",
|
||||||
|
"node_test_sequence_map_identity_1_sequence_model",
|
||||||
|
"node_test_sequence_map_identity_2_sequences_expanded_model",
|
||||||
|
"node_test_sequence_map_identity_2_sequences_model",
|
||||||
|
"node_test_simple_rnn_defaults_model",
|
||||||
|
"node_test_simple_rnn_with_initial_bias_model",
|
||||||
|
"node_test_split_to_sequence_1_model",
|
||||||
|
"node_test_split_to_sequence_2_model",
|
||||||
|
"node_test_split_to_sequence_nokeepdims_model",
|
||||||
|
"node_test_stft_model",
|
||||||
|
"node_test_string_concat_broadcasting_model",
|
||||||
|
"node_test_string_concat_empty_string_model",
|
||||||
|
"node_test_string_concat_model",
|
||||||
|
"node_test_string_concat_utf8_model",
|
||||||
|
"node_test_string_concat_zero_dimensional_model",
|
||||||
|
"node_test_string_split_basic_model",
|
||||||
|
"node_test_string_split_consecutive_delimiters_model",
|
||||||
|
"node_test_string_split_empty_string_delimiter_model",
|
||||||
|
"node_test_string_split_empty_tensor_model",
|
||||||
|
"node_test_string_split_maxsplit_model",
|
||||||
|
"node_test_string_split_no_delimiter_model",
|
||||||
|
"node_test_strnormalizer_export_monday_casesensintive_lower_model",
|
||||||
|
"node_test_strnormalizer_export_monday_casesensintive_nochangecase_model",
|
||||||
|
"node_test_strnormalizer_export_monday_casesensintive_upper_model",
|
||||||
|
"node_test_strnormalizer_export_monday_empty_output_model",
|
||||||
|
"node_test_strnormalizer_export_monday_insensintive_upper_twodim_model",
|
||||||
|
"node_test_strnormalizer_nostopwords_nochangecase_model",
|
||||||
|
"simple_test_sequence_model1_model",
|
||||||
|
"simple_test_sequence_model2_model",
|
||||||
|
"simple_test_sequence_model3_model",
|
||||||
|
"simple_test_sequence_model4_model",
|
||||||
|
"simple_test_sequence_model5_model",
|
||||||
|
"simple_test_sequence_model6_model",
|
||||||
|
"simple_test_sequence_model7_model",
|
||||||
|
"simple_test_sequence_model8_model",
|
||||||
|
"simple_test_strnorm_model_monday_casesensintive_lower_model",
|
||||||
|
"simple_test_strnorm_model_monday_casesensintive_nochangecase_model",
|
||||||
|
"simple_test_strnorm_model_monday_casesensintive_upper_model",
|
||||||
|
"simple_test_strnorm_model_monday_empty_output_model",
|
||||||
|
"simple_test_strnorm_model_monday_insensintive_upper_twodim_model",
|
||||||
|
"simple_test_strnorm_model_nostopwords_nochangecase_model",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class ImportSmokeTest(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.unexpected_failure_count = 0
|
||||||
|
ImportSmokeTest.actual_failures = []
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
if cls.unexpected_failure_count:
|
||||||
|
# Print a helpful message with copy-paste XFAIL def.
|
||||||
|
failure_report_path = OUTPUT_PATH / "import_smoke_test_report.txt"
|
||||||
|
print(
|
||||||
|
"Unexpected failures. Writing copy/paste report to:",
|
||||||
|
failure_report_path,
|
||||||
|
)
|
||||||
|
with open(failure_report_path, "wt") as f:
|
||||||
|
lines = [f' "{s}",' for s in ImportSmokeTest.actual_failures]
|
||||||
|
print(
|
||||||
|
f"Unexpected failures in the following. Copy/paste to update `TEST_CAST_XFAILS`:",
|
||||||
|
file=f,
|
||||||
|
)
|
||||||
|
print(f"TEST_CAST_XFAILS = [", file=f)
|
||||||
|
[print(l, file=f) for l in lines]
|
||||||
|
print(f"]", file=f)
|
||||||
|
|
||||||
|
ImportSmokeTest.actual_failures.clear()
|
||||||
|
|
||||||
|
def load_onnx_model(self, file_path: Path) -> onnx.ModelProto:
|
||||||
|
raw_model = onnx.load(file_path)
|
||||||
|
try:
|
||||||
|
inferred_model = onnx.shape_inference.infer_shapes(raw_model)
|
||||||
|
except onnx.onnx_cpp2py_export.shape_inference.InferenceError as e:
|
||||||
|
print("WARNING: Shape inference failure (skipping test):", e)
|
||||||
|
self.skipTest(reason="shape inference failure")
|
||||||
|
|
||||||
|
# inferred_model = raw_model
|
||||||
|
return inferred_model
|
||||||
|
|
||||||
|
def run_import_test(self, norm_name: str, rel_path: str):
|
||||||
|
context = ir.Context()
|
||||||
|
configure_context(context)
|
||||||
|
|
||||||
|
model_info = onnx_importer.ModelInfo(
|
||||||
|
self.load_onnx_model(ONNX_TEST_DATA_DIR / rel_path),
|
||||||
|
)
|
||||||
|
m = model_info.create_module(context=context)
|
||||||
|
try:
|
||||||
|
imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m)
|
||||||
|
imp.import_all()
|
||||||
|
m.verify()
|
||||||
|
finally:
|
||||||
|
# Use a ".txt" extension to avoid lit test discovery.
|
||||||
|
with open(OUTPUT_PATH / f"{norm_name}.mlir", "wt") as f:
|
||||||
|
print(m.get_asm(), file=f)
|
||||||
|
|
||||||
|
def testExists(self):
|
||||||
|
# We expect a lot of test cases. Die if not the case (i.e. if paths change
|
||||||
|
# or something).
|
||||||
|
self.assertGreater(len(ONNX_REL_PATHS), 10)
|
||||||
|
|
||||||
|
|
||||||
|
# Generate test methods for each onnx file.
|
||||||
|
for _rel_path in ONNX_REL_PATHS:
|
||||||
|
|
||||||
|
def attach_test(rel_path):
|
||||||
|
norm_name = rel_path.removesuffix(".onnx").replace("/", "_")
|
||||||
|
|
||||||
|
def test_method(self: ImportSmokeTest):
|
||||||
|
try:
|
||||||
|
self.run_import_test(norm_name, rel_path)
|
||||||
|
except onnx_importer.OnnxImportError as e:
|
||||||
|
# All legitimate failures should be caught and reported
|
||||||
|
# as an OnnxImportError.
|
||||||
|
ImportSmokeTest.actual_failures.append(norm_name)
|
||||||
|
if norm_name not in TEST_CAST_XFAILS:
|
||||||
|
ImportSmokeTest.unexpected_failure_count += 1
|
||||||
|
raise e
|
||||||
|
|
||||||
|
test_method.__name__ = f"test_{norm_name}"
|
||||||
|
|
||||||
|
if norm_name in TEST_CAST_XFAILS:
|
||||||
|
test_method = unittest.expectedFailure(test_method)
|
||||||
|
|
||||||
|
setattr(ImportSmokeTest, test_method.__name__, test_method)
|
||||||
|
|
||||||
|
attach_test(_rel_path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
unittest.main()
|
|
@ -0,0 +1,5 @@
|
||||||
|
try:
|
||||||
|
import onnx
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
print("Skipping onnx tests.. no onnx")
|
||||||
|
config.unsupported = True
|
Loading…
Reference in New Issue