Upstream the ONNX importer. (#2636)

This is part 1 of 2; part 2 will upstream the FX importer. I started
with ONNX because it forces some project layout updates and is more
self-contained, making it an easier first step.

Deviating somewhat from the RFCs on project layout, I made the following
decisions:

* Located `onnx_importer.py` in `torch_mlir.extras`, since Maks has
already opened up that namespace and it seemed to fit. Better to
have fewer things at that level.
* Set up the build so that the root project only contains the MLIR Python
bindings and pure Python deps (like the importers); `projects/` can then
augment this with more, depending on which features are enabled.
* The default build continues to build everything, whereas
`TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS=1` mode builds a
`torch-mlir-core` wheel with only the pure contents (see the sketch after
this list).
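
As an illustration only, a minimal sketch of driving the two wheel flavors
via the environment variable that `setup.py` reads; the exact pip/cmake
invocation used by the packaging scripts may differ:

```shell
# Default: full `torch-mlir` wheel (JIT IR importer, LTC, etc.).
python -m pip wheel -v .

# Core-only: pure Python + MLIR bindings, published as `torch-mlir-core`.
TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS=1 python -m pip wheel -v .
```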

`onnx_importer.py` and `importer_smoke_test.py` are almost verbatim
copies from SHARK-Turbine. I made some minor local alterations to adapt
to paths and generalize the way they interact with the outer project. I
expect I can copy these back to Turbine verbatim from here. I also
updated the license boilerplate (they have the same license but slightly
different project norms for the headers) but retained the correct
copyright.

Other updates:

* Added the ONNX importer unit test (which can also generate test data)
in lit, conditioned on the availability of the Python `onnx` package. In
a follow-up, once I know everything is stable, I'll add another env var
that the CI can set to always enable this so we know conclusively whether
the tests pass.
* Moved the ONNX conversion readme to `docs/`.
* Renamed the CMake option `TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS` ->
`TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS` and inverted its sense. Made the
JIT IR importer and LTC options `cmake_dependent_option`s for robustness
(configure sketch below).
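
For illustration only, a hypothetical configure that disables the new gate
(other required LLVM/MLIR cache variables are unchanged and omitted here);
the dependent JIT IR importer and LTC options are then forced OFF:

```shell
# Sketch: configure with PyTorch native-extension features disabled.
cmake -S . -B build \
  -DMLIR_ENABLE_BINDINGS_PYTHON=ON \
  -DTORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS=OFF
```
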
Author: Stella Laurenzo, 2023-12-12 19:02:51 -08:00 (committed by GitHub)
Parent: f67249d34f
Commit: 74f7a0c9d6
18 changed files with 1208 additions and 149 deletions


@@ -25,6 +25,8 @@ project(torch-mlir LANGUAGES CXX C)
set(CMAKE_C_STANDARD 11)
set(CMAKE_CXX_STANDARD 17)

+include(CMakeDependentOption)
+
#-------------------------------------------------------------------------------
# Project options
#-------------------------------------------------------------------------------

@@ -43,24 +45,11 @@ endif()
option(TORCH_MLIR_OUT_OF_TREE_BUILD "Specifies an out of tree build" OFF)

-# PT1 options.
-option(TORCH_MLIR_ENABLE_PROJECT_PT1 "Enables the PyTorch1 project under projects/pt1" OFF)
-# TODO: Rename/scope these. They use historic names for now to ease migration
-# burden.
-option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON)
-option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF)
-option(TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS "Build Torch dialect MLIR Python bindings but neither JIT IR Importer nor LTC backend" OFF)
-if(TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS)
-  set(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER OFF)
-  set(TORCH_MLIR_ENABLE_LTC OFF)
-endif()
-
-# Force enable the PT1 project if either the JIT_IR_IMPORTER or LTC is enabled.
-if(NOT TORCH_MLIR_ENABLE_PROJECT_PT1)
-  if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER OR TORCH_MLIR_ENABLE_LTC)
-    message(STATUS "Enabling projects/pt1 because features requiring it are enabled")
-    set(TORCH_MLIR_ENABLE_PROJECT_PT1 ON)
-  endif()
-endif()
+# PyTorch native extension gate. If OFF, then no features which depend on
+# native extensions will be built.
+option(TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS "Enables PyTorch native extension features" ON)
+cmake_dependent_option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS OFF)
+cmake_dependent_option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS OFF)

#-------------------------------------------------------------------------------
# Configure out-of-tree vs in-tree build

@@ -235,4 +224,16 @@ endif()
# Sub-projects
#-------------------------------------------------------------------------------

+# Sub-projects can bundle additional PyTorch extensions by adding them to this
+# source target. It is typically empty unless if features are enabled.
+if(MLIR_ENABLE_BINDINGS_PYTHON)
+  declare_mlir_python_sources(TorchMLIRPythonTorchExtensionsSources)
+endif()
+
+# Build projects first as it may populate additional Python deps.
add_subdirectory(projects)
+
+# Finish with top-level Python bindings so it can handle additional deps.
+if(MLIR_ENABLE_BINDINGS_PYTHON)
+  add_subdirectory(python)
+endif()


@@ -351,7 +351,6 @@ function setup_venv() {
      echo ":::: Using stable dependencies"
      python3 -m pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu
      python3 -m pip install --no-cache-dir -r /main_checkout/torch-mlir/build-requirements.txt
-      python3 -m pip install --no-cache-dir -r /main_checkout/torch-mlir/test-requirements.txt
      ;;
    *)
      echo "Unrecognized torch version '$torch_version'"

@@ -359,6 +358,7 @@ function setup_venv() {
      ;;
  esac

+  python3 -m pip install --no-cache-dir -r /main_checkout/torch-mlir/test-requirements.txt
}

function build_out_of_tree() {


@@ -3,11 +3,8 @@

We enable the direct representation of many ONNX features directly in
the `torch` dialect as `torch.operator` custom ops with names like
`onnx.{OperatorName}`. The majority of ONNX operators are represented
-with a systematic transformation. See
-[onnx_importer.py](https://github.com/nod-ai/SHARK-Turbine/blob/main/python/shark_turbine/importers/onnx_importer.py)
-for the reference importer which complies with the rules below
-(this is planned to be upstreamed to torch-mlir proper in the near
-future).
+with a systematic transformation. `torch_mlir.extras.onnx_importer`
+for the reference importer which complies with the rules below.

## Adding new ONNX operators

@@ -26,10 +23,11 @@ are relatively straight-forward to map, following this general procedure:
* Open the corresponding implementation file `DefaultDomainXtoY.cpp`
  corresponding with the alphabetic sort of the op and add a conversion.
* Generate successful test cases:
-  * Either run the Turbine importer to produce MLIR output for all
-    ops/models in the ONNX test suite or use a dump that someone has
-    generated:
-    * [2023-Nov-21](https://drive.google.com/file/d/1P6QaRXGnCeApjdjNmykLxWa-yqMmIO-d/view?usp=sharing)
+  * All `onnx_importer.py` tests are dumped to the test temp dir (success
+    or failure). This is typically located under
+    `tools/torch-mlir/test/python/onnx_importer/Output`. The `.mlir` files
+    under there should provide good variants to drive lit test coverage of
+    conversion.
* There are often many variants of tests for checking conformance of
  different historic ONNX encodings, but these are often not load bearing
  at the MLIR level.


@@ -1,7 +1,31 @@
include(AddMLIRPython)

+################################################################################
+# PyTorch
# Configure PyTorch if we have any features enabled which require it.
+################################################################################
+
if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER OR TORCH_MLIR_ENABLE_LTC)
+  if (NOT TORCH_MLIR_USE_INSTALLED_PYTORCH)
+    # Source builds
+    message(STATUS "Building libtorch from source (features depend on it and NOT TORCH_MLIR_USE_INSTALLED_PYTORCH)")
+    set(ENV{TORCH_MLIR_SRC_PYTORCH_REPO} ${TORCH_MLIR_SRC_PYTORCH_REPO})
+    set(ENV{TORCH_MLIR_SRC_PYTORCH_BRANCH} ${TORCH_MLIR_SRC_PYTORCH_BRANCH})
+    set(ENV{TM_PYTORCH_INSTALL_WITHOUT_REBUILD} ${TM_PYTORCH_INSTALL_WITHOUT_REBUILD})
+    set(ENV{MACOSX_DEPLOYMENT_TARGET} ${MACOSX_DEPLOYMENT_TARGET})
+    set(ENV{CMAKE_OSX_ARCHITECTURES} ${CMAKE_OSX_ARCHITECTURES})
+    set(ENV{CMAKE_C_COMPILER_LAUNCHER} ${CMAKE_C_COMPILER_LAUNCHER})
+    set(ENV{CMAKE_CXX_COMPILER_LAUNCHER} ${CMAKE_CXX_COMPILER_LAUNCHER})
+    execute_process(
+      COMMAND ${TORCH_MLIR_SOURCE_DIR}/build_tools/build_libtorch.sh
+      RESULT_VARIABLE _result
+    )
+    if(_result)
+      message(FATAL_ERROR "Failed to run `build_libtorch.sh`")
+    endif()
+    set(TORCH_INSTALL_PREFIX "libtorch")
+  endif()
  message(STATUS "Enabling PyTorch C++ dep (features depend on it)")
  include(TorchMLIRPyTorch)

@@ -48,6 +72,6 @@ if(TORCH_MLIR_ENABLE_LTC)
endif()

# Include overall PT1 project.
-if(TORCH_MLIR_ENABLE_PROJECT_PT1)
+if(TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS)
  add_subdirectory(pt1)
endif()


@@ -7,79 +7,22 @@ set(CMAKE_PLATFORM_NO_VERSIONED_SONAME ON)
# argument.
set(TORCH_MLIR_PYTHON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/torch_mlir")

# We vendor our own MLIR instance in the `torch_mlir` namespace.
add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=torch_mlir.")

-################################################################################
-# PyTorch
-################################################################################
-
-if (NOT TORCH_MLIR_USE_INSTALLED_PYTORCH)
-  # Source builds
-  set(ENV{TORCH_MLIR_SRC_PYTORCH_REPO} ${TORCH_MLIR_SRC_PYTORCH_REPO})
-  set(ENV{TORCH_MLIR_SRC_PYTORCH_BRANCH} ${TORCH_MLIR_SRC_PYTORCH_BRANCH})
-  set(ENV{TM_PYTORCH_INSTALL_WITHOUT_REBUILD} ${TM_PYTORCH_INSTALL_WITHOUT_REBUILD})
-  set(ENV{MACOSX_DEPLOYMENT_TARGET} ${MACOSX_DEPLOYMENT_TARGET})
-  set(ENV{CMAKE_OSX_ARCHITECTURES} ${CMAKE_OSX_ARCHITECTURES})
-  set(ENV{CMAKE_C_COMPILER_LAUNCHER} ${CMAKE_C_COMPILER_LAUNCHER})
-  set(ENV{CMAKE_CXX_COMPILER_LAUNCHER} ${CMAKE_CXX_COMPILER_LAUNCHER})
-  execute_process(
-    COMMAND ${TORCH_MLIR_SOURCE_DIR}/build_tools/build_libtorch.sh
-    RESULT_VARIABLE _result
-  )
-  if(_result)
-    message(FATAL_ERROR "Failed to run `build_libtorch.sh`")
-  endif()
-  set(TORCH_INSTALL_PREFIX "libtorch")
-endif()
-
-################################################################################
-# Sources
-################################################################################
-
-declare_mlir_python_sources(TorchMLIRPythonSources)
-declare_mlir_python_sources(TorchMLIRPythonExtensions)
-
-if (NOT TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS)
-  declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel
-    ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
-    ADD_TO_PARENT TorchMLIRPythonSources
-    SOURCES
-      __init__.py
-      _dynamo_fx_importer.py
-      compiler_utils.py
-      dynamo.py
-      _version.py
-  )
-endif()
-
-declare_mlir_python_sources(TorchMLIRPythonSources.Dialects
-  ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
-  ADD_TO_PARENT TorchMLIRPythonSources
-)
-
-declare_mlir_dialect_python_bindings(
-  ADD_TO_PARENT TorchMLIRPythonSources.Dialects
-  ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
-  TD_FILE dialects/TorchBinding.td
-  SOURCES dialects/torch/__init__.py
-  DIALECT_NAME torch
-)
-
-################################################################################
-# Extensions
-################################################################################
-
-declare_mlir_python_extension(TorchMLIRPythonExtensions.Main
-  MODULE_NAME _torchMlir
-  ADD_TO_PARENT TorchMLIRPythonExtensions
-  SOURCES
-    TorchMLIRModule.cpp
-  EMBED_CAPI_LINK_LIBS
-    TorchMLIRCAPI
-  PRIVATE_LINK_LIBS
-    LLVMSupport
-)
+# ################################################################################
+# # Sources
+# ################################################################################
+
+declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel
+  ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
+  ADD_TO_PARENT TorchMLIRPythonTorchExtensionsSources
+  SOURCES
+    __init__.py
+    _dynamo_fx_importer.py
+    compiler_utils.py
+    dynamo.py
+    _version.py
+)

################################################################################

@@ -110,56 +53,23 @@ endif()
# add_subdirectory(torch_mlir/_torch_mlir_custom_op_example)

-################################################################################
-# Generate packages and shared library
-# Downstreams typically will not use these, but they are useful for local
-# testing.
-################################################################################
-
-set(_source_components
-  # TODO: Core is now implicitly building/registering all dialects, increasing
-  # build burden by ~5x. Make it stop.
-  # TODO: Reduce dependencies. We need ExecutionEngine and a bunch of passes
-  # for the reference backend, but logically they can be separate. But seemingly
-  # the only way to handle that is to create a separate mlir python package
-  # tree, which seems excessive.
-  MLIRPythonSources
-  MLIRPythonExtension.Core
-  MLIRPythonExtension.RegisterEverything
-  TorchMLIRPythonSources
-  TorchMLIRPythonExtensions
-)
-
-add_mlir_python_common_capi_library(TorchMLIRAggregateCAPI
-  INSTALL_COMPONENT TorchMLIRPythonModules
-  INSTALL_DESTINATION python_packages/torch_mlir/torch_mlir/_mlir_libs
-  OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs"
-  RELATIVE_INSTALL_ROOT "../../../.."
-  DECLARED_SOURCES ${_source_components}
-)
-
-add_mlir_python_modules(TorchMLIRPythonModules
-  ROOT_PREFIX "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir"
-  INSTALL_PREFIX "python_packages/torch_mlir/torch_mlir"
-  DECLARED_SOURCES ${_source_components}
-  COMMON_CAPI_LINK_LIBS
-    TorchMLIRAggregateCAPI
-)
-
# TODO: Find a cleaner way to do this.
# Can we build the JIT IR importer with `declare_mlir_python_extension`?
# Then it would "just work".
if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER)
-  add_dependencies(TorchMLIRPythonModules TorchMLIRJITIRImporter)
-  add_dependencies(TorchMLIRPythonModules TorchMLIRJITIRImporterPybind)
-  # Build the E2E Tests (which depend on the JIT IR importer now).
-  add_dependencies(TorchMLIRPythonModules TorchMLIRE2ETestPythonModules)
+  add_dependencies(TorchMLIRPythonTorchExtensionsSources
+    TorchMLIRJITIRImporter
+    TorchMLIRJITIRImporterPybind
+    TorchMLIRE2ETestPythonModules
+  )
endif()

if(TORCH_MLIR_ENABLE_LTC)
  # Add Torch-MLIR LTC backend as dependency
-  add_dependencies(TorchMLIRPythonModules torch_mlir_ltc_backend)
-  add_dependencies(TorchMLIRPythonModules reference_lazy_backend)
+  add_dependencies(TorchMLIRPythonTorchExtensionsSources
+    torch_mlir_ltc_backend
+    reference_lazy_backend
+  )
endif()

add_subdirectory(test)


@@ -4,9 +4,9 @@

## Declare the sources of the Python module.

-declare_mlir_python_sources(TorchMLIRPythonSources.JitIRImporter
+declare_mlir_python_sources(TorchMLIRPythonTorchExtensionsSources.JitIRImporter
  ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
-  ADD_TO_PARENT TorchMLIRPythonSources
+  ADD_TO_PARENT TorchMLIRPythonTorchExtensionsSources
  SOURCES_GLOB
    jit_ir_importer/*.py
)


@@ -0,0 +1,94 @@
# Disables generation of "version soname" (i.e. libFoo.so.<version>), which
# causes pure duplication as part of Python wheels.
set(CMAKE_PLATFORM_NO_VERSIONED_SONAME ON)
# The directory at which the Python import tree begins.
# See documentation for `declare_mlir_python_sources`'s ROOT_DIR
# argument.
set(TORCH_MLIR_PYTHON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/torch_mlir")
# We vendor our own MLIR instance in the `torch_mlir` namespace.
add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=torch_mlir.")
################################################################################
# Sources
################################################################################
declare_mlir_python_sources(TorchMLIRPythonSources)
declare_mlir_python_sources(TorchMLIRPythonExtensions)
declare_mlir_python_sources(TorchMLIRPythonSources.Dialects
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
ADD_TO_PARENT TorchMLIRPythonSources
)
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT TorchMLIRPythonSources.Dialects
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
TD_FILE dialects/TorchBinding.td
SOURCES dialects/torch/__init__.py
DIALECT_NAME torch
)
declare_mlir_python_sources(TorchMLIRPythonSources.Importers
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
ADD_TO_PARENT TorchMLIRPythonSources
SOURCES
extras/onnx_importer.py
)
################################################################################
# Extensions
################################################################################
declare_mlir_python_extension(TorchMLIRPythonExtensions.Main
MODULE_NAME _torchMlir
ADD_TO_PARENT TorchMLIRPythonExtensions
SOURCES
TorchMLIRModule.cpp
EMBED_CAPI_LINK_LIBS
TorchMLIRCAPI
PRIVATE_LINK_LIBS
LLVMSupport
)
################################################################################
# Generate packages and shared library
# Downstreams typically will not use these, but they are useful for local
# testing.
################################################################################
set(_source_components
# TODO: Core is now implicitly building/registering all dialects, increasing
# build burden by ~5x. Make it stop.
# TODO: Reduce dependencies. We need ExecutionEngine and a bunch of passes
# for the reference backend, but logically they can be separate. But seemingly
# the only way to handle that is to create a separate mlir python package
# tree, which seems excessive.
MLIRPythonSources
MLIRPythonExtension.Core
MLIRPythonExtension.RegisterEverything
TorchMLIRPythonSources
TorchMLIRPythonExtensions
# Sources related to optional Torch extension dependent features. Typically
# empty unless if project features are enabled.
TorchMLIRPythonTorchExtensionsSources
)
add_mlir_python_common_capi_library(TorchMLIRAggregateCAPI
INSTALL_COMPONENT TorchMLIRPythonModules
INSTALL_DESTINATION python_packages/torch_mlir/torch_mlir/_mlir_libs
OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs"
RELATIVE_INSTALL_ROOT ".."
DECLARED_SOURCES ${_source_components}
)
add_mlir_python_modules(TorchMLIRPythonModules
ROOT_PREFIX "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir"
INSTALL_PREFIX "python_packages/torch_mlir/torch_mlir"
DECLARED_SOURCES ${_source_components}
COMMON_CAPI_LINK_LIBS
TorchMLIRAggregateCAPI
)


@@ -0,0 +1,607 @@
# Based on code Copyright (c) Advanced Micro Devices, Inc.
#
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
"""Imports ONNX graphs to `torch` dialect ops.
See documentation:
https://github.com/llvm/torch-mlir/blob/main/docs/importers/onnx_importer.md
This file is distributed/forked verbatim into various downstream projects, and
it must abide by several rules above and beyond the rest of the codebase:
- It must be standalone, only depending on:
- `onnx`
- `..ir` relative imports to the main IR directory
- `..dialects.func` relative import to the `func` dialect (TODO:
we are looking to eliminate this dep).
- Python standard library
- It does not directly use the ODS generated `torch` dialect Python
wrappers. This allows it to be used in contexts that only build a C++
compiler with minimal IR Python bindings.
- It is intended as an enabler for full onnx compilation, only handling
the import from ONNX -> the `torch` dialect. Testing, full pipelines,
and utilities belong elsewhere.
"""
try:
import onnx
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
"The onnx package (`pip install onnx`) is required to use the onnx importer"
) from e
from typing import Optional
from dataclasses import dataclass
import numpy as np
from ..ir import (
ArrayAttr,
Attribute,
Block,
Context,
DenseElementsAttr,
DenseResourceElementsAttr,
DictAttr,
FloatAttr,
BF16Type,
ComplexType,
F16Type,
F32Type,
F64Type,
Float8E4M3FNType,
Float8E5M2FNUZType,
Float8E5M2Type,
FunctionType,
InsertionPoint,
IntegerAttr,
IntegerType,
MLIRError,
RankedTensorType,
Location,
Module,
Operation,
StringAttr,
Type as IrType,
Value,
)
from ..dialects import (
func as func_dialect,
)
@dataclass
class Config:
"""Various configuration settings for the importer."""
# Ancient ONNX exporters would often add a model input for anything that
# might be mutable, providing an initializer for it as well. More modern
# tools realized this is a really bad idea for a lot of reasons.
# We choose to assume more recent norms, even if encountering older
# models. Setting this to False probably won't do what you want but
# should produce interesting errors to waste your time deciphering.
# We mainly use it as a way to document in the code that we are
# making an assumption.
elide_initialized_inputs: bool = True
class ModelInfo:
"""Top-level accounting and accessors for an ONNX model."""
def __init__(self, model_proto: onnx.ModelProto, *, config: Config = Config()):
self.config = config
self.model_proto = model_proto
assert model_proto.graph, "Model must contain a main Graph"
self.main_graph = GraphInfo(self, model_proto.graph)
def create_module(self, context: Optional[Context] = None) -> Operation:
if not context:
context = Context()
module_op = Module.create(Location.unknown(context)).operation
# TODO: Populate module level metadata from the ModelProto
return module_op
class GraphInfo:
"""Information about a Graph within a model."""
def __init__(self, model_info: ModelInfo, graph_proto: onnx.GraphProto):
self.model_info = model_info
self.graph_proto = graph_proto
self.initializer_map: dict[str, onnx.TensorProto] = {
n.name: n for n in graph_proto.initializer
}
self.value_info_map: dict[str, onnx.ValueInfoProto] = {
n.name: n for n in graph_proto.value_info
}
self.declared_input_map: dict[str, onnx.ValueInfoProto] = {
n.name: n for n in graph_proto.input
}
self.output_map: dict[str, onnx.ValueInfoProto] = {
n.name: n for n in graph_proto.output
}
# Generate the effective input map, which for old models can be a
# subset of the input map.
if model_info.config.elide_initialized_inputs:
self.input_map = {
k: v
for k, v in self.declared_input_map.items()
if k not in self.initializer_map
}
else:
self.input_map = self.declared_input_map
illegal_input_keys = self.input_map.keys() - (
self.input_map.keys() - self.initializer_map.keys()
)
assert self.input_map.keys().isdisjoint(self.initializer_map.keys()), (
f"When not in elide_initialized_inputs=True, we expect inputs to not "
f"have an initial value (got {illegal_input_keys})."
)
def find_type_proto_for_name(self, name: str) -> onnx.TypeProto:
# Node outputs don't typically have type information, but shape inference
# will associate them in the value_info. If not there, it may be a
# graph output, which must have type information.
value_info = self.value_info_map.get(name) or self.output_map.get(name)
if value_info is not None:
return value_info.type
raise OnnxImportError(
f"No type information associated with '{name}'. Run shape inference?"
)
class OnnxImportError(Exception):
...
class NodeImporter:
"""Imports graph nodes into MLIR.
Typically, the top level graph will be imported into a func whereas dependent
graphs may just be imported with references to pre-existing values.
Note that ONNX requires that graphs be sorted topologically and free of cycles,
so we don't take any special steps to order them for dominance.
"""
__slots__ = [
"_c",
"_cc",
"_gi",
"_p",
"_b",
"_nv_map",
]
def __init__(
self,
graph_info: GraphInfo,
*,
parent_op: Operation,
block: Block,
context_cache: "ContextCache",
):
self._c = parent_op.context
self._cc = context_cache
self._gi = graph_info
self._p = parent_op
self._b = block
self._nv_map: dict[str, Value] = {}
@classmethod
def define_function(
cls, graph_info: GraphInfo, module_op: Operation
) -> "NodeImporter":
cc = ContextCache(module_op.context)
with module_op.context, Location.name(f"graph:{graph_info.graph_proto.name}"):
body = module_op.regions[0].blocks[0]
func_name = graph_info.graph_proto.name
input_types = [
cc.type_proto_to_type(inp.type) for inp in graph_info.input_map.values()
]
output_types = [
cc.type_proto_to_type(out.type)
for out in graph_info.output_map.values()
]
ftype = FunctionType.get(input_types, output_types)
func_op = func_dialect.FuncOp(func_name, ftype, ip=InsertionPoint(body))
block = func_op.add_entry_block(
[Location.name(k) for k in graph_info.input_map.keys()]
)
imp = NodeImporter(graph_info, parent_op=func_op, block=block, context_cache=cc)
for node_name, input_value in zip(graph_info.input_map.keys(), block.arguments):
imp._nv_map[node_name] = input_value
imp._populate_graph_attrs(func_op)
return imp
def _populate_graph_attrs(self, container_op: Operation):
"""Populates graph level meta attributes on the given container op."""
m = self._gi.model_info.model_proto
with container_op.context:
i64_type = IntegerType.get_signed(64)
default_opset_version = 0
opset_versions: dict[str, IntegerAttr] = {}
for opset_import in m.opset_import:
if opset_import.domain:
opset_versions[opset_import.domain] = IntegerAttr.get(
i64_type, opset_import.version
)
else:
default_opset_version = opset_import.version
if default_opset_version:
container_op.attributes[
"torch.onnx_meta.opset_version"
] = IntegerAttr.get(i64_type, default_opset_version)
if opset_versions:
container_op.attributes[
"torch.onnx_meta.opset_versions"
] = DictAttr.get(opset_versions)
container_op.attributes["torch.onnx_meta.ir_version"] = IntegerAttr.get(
IntegerType.get_signed(64), m.ir_version
)
container_op.attributes["torch.onnx_meta.producer_name"] = StringAttr.get(
m.producer_name
)
container_op.attributes[
"torch.onnx_meta.producer_version"
] = StringAttr.get(m.producer_version)
def import_all(self):
"""Imports all nodes topologically."""
# TODO: Consider pulling in initializers on demand since there can be so
# much unused crap.
for init in self._gi.initializer_map.values():
self.import_initializer(init)
for node in self._gi.graph_proto.node:
self.import_node(node)
outputs = []
for output_name in self._gi.output_map.keys():
try:
outputs.append(self._nv_map[output_name])
except KeyError:
raise OnnxImportError(
f"Non topologically produced ONNX graph output '{output_name}'"
)
with InsertionPoint(self._b), Location.unknown():
func_dialect.ReturnOp(outputs)
def import_node(self, node: onnx.NodeProto):
with InsertionPoint(self._b), Location.name(node.name):
op_type = node.op_type
# Handle special op types that materialize to non-op IR constructs.
special_key = f"_handle_node_{op_type}"
if hasattr(self, special_key):
getattr(self, special_key)(node)
return
# General node import.
input_values = []
for input_name in node.input:
try:
input_values.append(self._nv_map[input_name])
except KeyError:
raise OnnxImportError(
f"Non topologically produced ONNX node input '{input_name}': {node}"
)
output_names = list(node.output)
output_types = [
self._cc.type_proto_to_type(self._gi.find_type_proto_for_name(n))
for n in output_names
]
# TODO: Attributes.
attrs = {
"name": StringAttr.get(f"onnx.{op_type}"),
}
self.import_attributes(node.attribute, attrs)
custom_op = Operation.create(
name="torch.operator",
results=output_types,
operands=input_values,
attributes=attrs,
)
for output_name, output_value in zip(output_names, custom_op.results):
self._nv_map[output_name] = output_value
def import_attributes(
self, onnx_attrs: list[onnx.AttributeProto], attrs: dict[str, Attribute]
):
for onnx_attr in onnx_attrs:
attr_type = onnx_attr.type
if attr_type not in ATTRIBUTE_TYPE_HANDLERS:
raise OnnxImportError(
f"Unhandled ONNX attribute type code {attr_type}: {onnx_attr}"
)
handler = ATTRIBUTE_TYPE_HANDLERS[attr_type]
if handler is None:
# Active skip.
continue
elif handler is False:
# Active error.
raise OnnxImportError(
f"ONNX importer does not support generic node attribute type {attr_type}. "
f"This likely means that this is a special node which requires specific "
f"handling in the importer: {onnx_attr}"
)
attrs[f"torch.onnx.{onnx_attr.name}"] = handler(onnx_attr, self._cc)
def import_initializer(self, initializer: onnx.TensorProto) -> Value:
with InsertionPoint(self._b), Location.name(initializer.name):
value_attr = self._cc.tensor_proto_to_attr(initializer)
vtensor_type = self._cc.tensor_proto_to_type(initializer)
literal_op = Operation.create(
name="torch.vtensor.literal",
results=[vtensor_type],
attributes={"value": value_attr},
)
self._nv_map[initializer.name] = literal_op.result
return literal_op.result
def _get_immediate_tensor(self, name: str) -> np.array:
try:
initializer = self._gi.initializer_map[name]
except KeyError:
raise OnnxImportError(
f"An immediate value for '{name}' was required but it is dynamically produced."
)
try:
dtype = ELEM_TYPE_TO_NUMPY_DTYPE[initializer.data_type]
except KeyError:
raise OnnxImportError(
f"Unknown ONNX tensor element type to numpy dtype mapping: {initializer.data_type}"
)
raw_data = initializer.raw_data
if raw_data:
return np.frombuffer(raw_data, dtype=dtype).reshape(tuple(initializer.dims))
else:
raise OnnxImportError(
f"Unhandled ONNX TensorProto immediate data: {initializer}"
)
def _handle_node_ConstantOfShape(self, node: onnx.NodeProto):
# This op is special: It has an input of the shape, and in full generality
# could involve eager production of constants of variable size. In
# practice, the DNN profile for ONNX makes this very difficult to do
# and we hard-assert that the input can be resolved to an immediate
# value.
assert len(node.input) == 1
assert len(node.output) == 1
shape = self._get_immediate_tensor(node.input[0]).astype(np.int64)
value_proto = _get_attr(node, "value")
assert value_proto.type == onnx.AttributeProto.AttributeType.TENSOR
tensor_proto = value_proto.t
element_type = self._cc.tensor_element_type(tensor_proto.data_type)
vtensor_type = self._cc.get_vtensor_type(tuple(shape), element_type)
assert len(tensor_proto.dims) == 1 and tensor_proto.dims[0] == 1
try:
cb = ELEM_TYPE_SPLAT_TENSOR_PROTO_CB[tensor_proto.data_type]
except KeyError:
raise OnnxImportError(
f"Unhandled splat type for ConstantOfShape: {node} (possible missing mapping in ELEM_TYPE_SPLAT_TENSOR_PROTO_CB)"
)
value_attr = cb(tensor_proto, tuple(shape))
literal_op = Operation.create(
name="torch.vtensor.literal",
results=[vtensor_type],
attributes={"value": value_attr},
)
self._nv_map[node.output[0]] = literal_op.result
class ContextCache:
"""Caches per-context lookups of various things."""
__slots__ = [
"_c",
"_elem_type_map",
"_vtensor_type_map",
]
def __init__(self, context: Context):
self._c = context
self._elem_type_map: dict[int, IrType] = {}
self._vtensor_type_map: dict[tuple[tuple[Optional[int]], IrType], IrType] = {}
def tensor_element_type(self, elem_type: int) -> IrType:
t = self._elem_type_map.get(elem_type)
if t is None:
try:
with self._c:
t = ELEM_TYPE_TO_IR_TYPE_CB[elem_type]()
except KeyError:
raise OnnxImportError(f"Unknown ONNX tensor element type: {elem_type}")
self._elem_type_map[elem_type] = t
return t
def get_vtensor_type(
self, dims: tuple[Optional[int]], element_type: IrType
) -> IrType:
key = (dims, element_type)
t = self._vtensor_type_map.get(key)
if t is None:
shape_asm = ",".join("?" if d is None else str(d) for d in dims)
asm = f"!torch.vtensor<[{shape_asm}],{str(element_type)}>"
try:
t = IrType.parse(asm, context=self._c)
except MLIRError as e:
raise OnnxImportError(
f"Unparseable torch type (MLIR asm format bug?): {asm}"
) from e
self._vtensor_type_map[key] = t
return t
def tensor_proto_to_type(self, tp: onnx.TensorProto) -> IrType:
element_type = self.tensor_element_type(tp.data_type)
return self.get_vtensor_type(tuple(tp.dims), element_type)
def tensor_proto_to_builtin_type(self, tp: onnx.TensorProto) -> IrType:
element_type = self.tensor_element_type(tp.data_type)
# TODO: Fixme upstream: RankedTensorType.get should not require a location.
with Location.unknown():
return RankedTensorType.get(tuple(tp.dims), element_type)
def type_proto_to_type(self, tp: onnx.TypeProto) -> IrType:
if tp.tensor_type:
tt = tp.tensor_type
if not tt.shape:
raise OnnxImportError(
f"Unsupported Tensor type without shape (run shape inference?): {tp}"
)
element_type = self.tensor_element_type(tt.elem_type)
dims = tuple(
(d.dim_value if not d.dim_param else None) for d in tt.shape.dim
)
return self.get_vtensor_type(dims, element_type)
else:
# TODO: Others if ever needed. Or we consider ourselves DNN-only.
# See TypeProto: sequence_type, map_type, optional_type, sparse_tensor_type.
raise OnnxImportError(f"Unsupported ONNX TypeProto: {tp}")
def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute:
tensor_type = self.tensor_proto_to_builtin_type(tp)
if tp.HasField("raw_data"):
# Conveniently, DenseResourceElementsAttr shares the raw data
# format. We just give it maximum numeric alignment.
return DenseResourceElementsAttr.get_from_buffer(
tp.raw_data, tp.name, tensor_type, alignment=8
)
else:
# We have to do a data type specific instantiation from proto fields.
# Since this is typically used for small tensor constants, we instantiate
# as a DenseElementsAttr.
handler = ELEM_TYPE_INLINE_TENSOR_PROTO_CB.get(tp.data_type)
if handler is None:
raise OnnxImportError(f"Unhandled ONNX TensorProto data: {tp}")
return handler(tp)
ELEM_TYPE_TO_IR_TYPE_CB = {
onnx.TensorProto.DataType.FLOAT: lambda: F32Type.get(),
onnx.TensorProto.DataType.UINT8: lambda: IntegerType.get_unsigned(8),
onnx.TensorProto.DataType.INT8: lambda: IntegerType.get_signed(8),
onnx.TensorProto.DataType.UINT16: lambda: IntegerType.get_unsigned(16),
onnx.TensorProto.DataType.INT16: lambda: IntegerType.get_signed(16),
onnx.TensorProto.DataType.INT32: lambda: IntegerType.get_signed(32),
onnx.TensorProto.DataType.INT64: lambda: IntegerType.get_signed(64),
onnx.TensorProto.DataType.BOOL: lambda: IntegerType.get_signless(1),
onnx.TensorProto.DataType.FLOAT16: lambda: F16Type.get(),
onnx.TensorProto.DataType.DOUBLE: lambda: F64Type.get(),
onnx.TensorProto.DataType.UINT32: lambda: IntegerType.get_unsigned(32),
onnx.TensorProto.DataType.UINT64: lambda: IntegerType.get_unsigned(64),
onnx.TensorProto.DataType.COMPLEX64: lambda: ComplexType.get(F32Type.get()),
onnx.TensorProto.DataType.COMPLEX128: lambda: ComplexType.get(F64Type.get()),
onnx.TensorProto.DataType.BFLOAT16: lambda: BF16Type.get(),
onnx.TensorProto.DataType.FLOAT8E4M3FN: lambda: Float8E4M3FNType.get(),
onnx.TensorProto.DataType.FLOAT8E4M3FNUZ: lambda: Float8E5M2FNUZType.get(),
onnx.TensorProto.DataType.FLOAT8E5M2: lambda: Float8E5M2Type.get(),
onnx.TensorProto.DataType.FLOAT8E5M2FNUZ: lambda: Float8E5M2FNUZType.get(),
# Omitted: STRING,
}
ELEM_TYPE_SPLAT_TENSOR_PROTO_CB = {
onnx.TensorProto.DataType.FLOAT: lambda tp, shape: DenseElementsAttr.get_splat(
RankedTensorType.get(shape, F32Type.get()), FloatAttr.get_f32(tp.float_data[0])
),
# TODO: All the rest from ELEM_TYPE_TO_IR_TYPE_CB
}
# Mapping of TensorProto.DataType to lambda TensorProto, returning a DenseElementsAttr
# of the builtin tensor type for cases where the tensor data is inlined as typed
# values instead of raw_data.
ELEM_TYPE_INLINE_TENSOR_PROTO_CB = {
onnx.TensorProto.DataType.FLOAT: lambda tp: DenseElementsAttr.get(
np.asarray(tp.float_data, dtype=np.float32).reshape(tp.dims), signless=False
),
onnx.TensorProto.DataType.INT32: lambda tp: DenseElementsAttr.get(
np.asarray(tp.int32_data, dtype=np.int32).reshape(tp.dims), signless=False
),
onnx.TensorProto.DataType.INT64: lambda tp: DenseElementsAttr.get(
np.asarray(tp.int64_data, dtype=np.int64).reshape(tp.dims), signless=False
),
onnx.TensorProto.DataType.DOUBLE: lambda tp: DenseElementsAttr.get(
np.asarray(tp.double_data, dtype=np.float64).reshape(tp.dims)
),
onnx.TensorProto.DataType.UINT32: lambda tp: DenseElementsAttr.get(
# Special case. See proto
np.asarray(tp.uint64_data, dtype=np.uint32).reshape(tp.dims),
signless=False,
),
onnx.TensorProto.DataType.UINT64: lambda tp: DenseElementsAttr.get(
np.asarray(tp.uint64_data, dtype=np.uint64).reshape(tp.dims), signless=False
)
# Intentionally unsupported: STRING
}
ELEM_TYPE_TO_NUMPY_DTYPE = {
onnx.TensorProto.DataType.FLOAT: np.float32,
onnx.TensorProto.DataType.UINT8: np.uint8,
onnx.TensorProto.DataType.INT8: np.int8,
onnx.TensorProto.DataType.UINT16: np.uint16,
onnx.TensorProto.DataType.INT16: np.int16,
onnx.TensorProto.DataType.INT32: np.int32,
onnx.TensorProto.DataType.INT64: np.int64,
onnx.TensorProto.DataType.BOOL: np.bool_,
onnx.TensorProto.DataType.FLOAT16: np.float16,
onnx.TensorProto.DataType.DOUBLE: np.float64,
onnx.TensorProto.DataType.UINT32: np.uint32,
onnx.TensorProto.DataType.UINT64: np.uint64,
onnx.TensorProto.DataType.COMPLEX64: np.complex64,
onnx.TensorProto.DataType.COMPLEX128: np.complex128,
# onnx.TensorProto.DataType.BFLOAT16:
# onnx.TensorProto.DataType.FLOAT8E4M3FN:
# onnx.TensorProto.DataType.FLOAT8E4M3FNUZ:
# onnx.TensorProto.DataType.FLOAT8E5M2:
# onnx.TensorProto.DataType.FLOAT8E5M2FNUZ:
# Omitted: STRING,
}
# Mapping of AttributeType code to one of:
# None: Ignore attribute and do not output to MLIR
# False: Error if an attribute of this type is present
# lambda a:AttributeProto, cc: ContextCache that returns an MLIR Attribute
ATTRIBUTE_TYPE_HANDLERS = {
onnx.AttributeProto.AttributeType.UNDEFINED: False,
onnx.AttributeProto.AttributeType.FLOAT: lambda a, cc: FloatAttr.get(
F32Type.get(), a.f
),
onnx.AttributeProto.AttributeType.INT: lambda a, cc: IntegerAttr.get(
IntegerType.get_signed(64), a.i
),
onnx.AttributeProto.AttributeType.STRING: lambda a, cc: StringAttr.get(a.s),
onnx.AttributeProto.AttributeType.TENSOR: lambda a, cc: cc.tensor_proto_to_attr(
a.t
),
onnx.AttributeProto.AttributeType.GRAPH: False,
onnx.AttributeProto.AttributeType.SPARSE_TENSOR: False,
onnx.AttributeProto.AttributeType.TYPE_PROTO: False,
onnx.AttributeProto.AttributeType.FLOATS: lambda a, cc: ArrayAttr.get(
[FloatAttr.get(F32Type.get(), f) for f in a.floats]
),
onnx.AttributeProto.AttributeType.INTS: lambda a, cc: ArrayAttr.get(
[IntegerAttr.get(IntegerType.get_signed(64), i) for i in a.ints]
),
onnx.AttributeProto.AttributeType.STRINGS: lambda a, cc: ArrayAttr.get(
[StringAttr.get(s) for s in a.strings]
),
onnx.AttributeProto.AttributeType.TENSORS: lambda a, cc: ArrayAttr.get(
[cc.tensor_proto_to_attr(t) for t in a.tensors]
),
onnx.AttributeProto.AttributeType.GRAPHS: False,
onnx.AttributeProto.AttributeType.SPARSE_TENSORS: False,
onnx.AttributeProto.AttributeType.TYPE_PROTOS: False,
}
def _get_attr(node: onnx.NodeProto, attr_name: str) -> onnx.AttributeProto:
for attr in node.attribute:
if attr.name == attr_name:
return attr
else:
raise OnnxImportError(f"Required attribute {attr_name} not found in {node}")


@@ -47,8 +47,6 @@ PACKAGE_VERSION = os.environ.get("TORCH_MLIR_PYTHON_PACKAGE_VERSION") or "0.0.1"
# If true, enable LTC build by default
TORCH_MLIR_ENABLE_LTC_DEFAULT = True
TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS = int(os.environ.get('TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS', False))
-if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS:
-    import torch

# Build phase discovery is unreliable. Just tell it what phases to run.
class CustomBuild(_build):

@@ -91,7 +89,7 @@ class CMakeBuild(build_py):
            f"-DCMAKE_C_VISIBILITY_PRESET=hidden",
            f"-DCMAKE_CXX_VISIBILITY_PRESET=hidden",
            f"-DTORCH_MLIR_ENABLE_LTC={'ON' if enable_ltc else 'OFF'}",
-            f"-DTORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS={'ON' if TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else 'OFF'}",
+            f"-DTORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS={'OFF' if TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else 'ON'}",
        ]

        os.makedirs(cmake_build_dir, exist_ok=True)

@@ -145,8 +143,31 @@ with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

+# Requires and extension modules depend on whether building PyTorch
+# extensions.
+INSTALL_REQUIRES = [
+    "numpy",
+    "packaging",
+]
+EXT_MODULES = [
+    CMakeExtension("torch_mlir._mlir_libs._torchMlir"),
+]
+NAME = "torch-mlir-core"
+
+# If building PyTorch extensions, customize.
+if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS:
+    import torch
+    NAME = "torch-mlir"
+    INSTALL_REQUIRES.extend([
+        f"torch=={torch.__version__}".split("+", 1)[0],
+    ])
+    EXT_MODULES.extend([
+        CMakeExtension("torch_mlir._mlir_libs._jit_ir_importer"),
+    ])
+
setup(
-    name="torch-mlir" if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else "torch-mlir-core",
+    name=NAME,
    version=f"{PACKAGE_VERSION}",
    author="Sean Silva",
    author_email="silvasean@google.com",

@@ -159,10 +180,12 @@ setup(
        "built_ext": NoopBuildExtension,
        "build_py": CMakeBuild,
    },
-    ext_modules=[
-        CMakeExtension("torch_mlir._mlir_libs._jit_ir_importer"),
-    ] if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else [CMakeExtension("torch_mlir._mlir_libs._torchMlir")],
-    install_requires=["numpy", "packaging"] + (
-        [f"torch=={torch.__version__}".split("+", 1)[0], ] if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else []),
+    ext_modules=EXT_MODULES,
+    install_requires=INSTALL_REQUIRES,
+    extras_require={
+        "onnx": [
+            "onnx>=1.15.0",
+        ],
+    },
    zip_safe=False,
)


@@ -1,3 +1,4 @@
pillow
dill
multiprocess
+onnx==1.15.0


@@ -0,0 +1,2 @@
if not config.enable_bindings_python:
    config.unsupported = True


@@ -0,0 +1 @@
output/


@@ -0,0 +1,19 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

# RUN: %PYTHON %s

"""This file exists so that the tests can find/configure torch_mlir.
It allows the test file to be standalone and used verbatim in other
projects (i.e. by just providing this file on the side).
"""

from torch_mlir import ir
from torch_mlir.extras import onnx_importer


def configure_context(context):
    from torch_mlir.dialects import torch as torch_d
    torch_d.register_dialect(context)


@@ -0,0 +1,374 @@
# Based on code Copyright (c) Advanced Micro Devices, Inc.
#
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
# RUN: %PYTHON %s --output %t
from glob import glob
from pathlib import Path
import logging
import sys
import unittest
import onnx
from _torch_mlir_config import (
configure_context,
ir,
onnx_importer,
)
# Accept the output path on the command line or default to a sibling
# to this file. We have to pop this off explicitly or else unittest
# won't understand.
if len(sys.argv) > 1 and sys.argv[1] == "--output":
OUTPUT_PATH = Path(sys.argv[2])
del sys.argv[1:3]
else:
OUTPUT_PATH = Path(__file__).resolve().parent / "output"
# TODO: Add some verification and overrides. For now, just use the
# onnx package install for onnx test files, since they were nice
# enough to include the test suite in the deployable.
import onnx.backend.test.data
ONNX_TEST_DATA_DIR = Path(onnx.backend.test.__file__).resolve().parent / "data"
print(f"ONNX Test Data Dir: {ONNX_TEST_DATA_DIR}")
ONNX_REL_PATHS = glob(f"**/*.onnx", root_dir=ONNX_TEST_DATA_DIR, recursive=True)
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)
TEST_CAST_XFAILS = [
"light_light_bvlc_alexnet",
"light_light_inception_v1",
"light_light_squeezenet",
"light_light_vgg19",
"node_test_affine_grid_2d_align_corners_expanded_model",
"node_test_affine_grid_2d_expanded_model",
"node_test_affine_grid_3d_align_corners_expanded_model",
"node_test_affine_grid_3d_expanded_model",
"node_test_ai_onnx_ml_label_encoder_string_int_model",
"node_test_ai_onnx_ml_label_encoder_string_int_no_default_model",
"node_test_ai_onnx_ml_label_encoder_tensor_mapping_model",
"node_test_ai_onnx_ml_label_encoder_tensor_value_only_mapping_model",
"node_test_cast_FLOAT16_to_FLOAT8E4M3FNUZ_model",
"node_test_cast_FLOAT16_to_FLOAT8E4M3FN_model",
"node_test_cast_FLOAT16_to_FLOAT8E5M2FNUZ_model",
"node_test_cast_FLOAT16_to_FLOAT8E5M2_model",
"node_test_cast_FLOAT8E4M3FNUZ_to_FLOAT16_model",
"node_test_cast_FLOAT8E4M3FNUZ_to_FLOAT_model",
"node_test_cast_FLOAT8E4M3FN_to_FLOAT16_model",
"node_test_cast_FLOAT8E4M3FN_to_FLOAT_model",
"node_test_cast_FLOAT8E5M2FNUZ_to_FLOAT16_model",
"node_test_cast_FLOAT8E5M2FNUZ_to_FLOAT_model",
"node_test_cast_FLOAT8E5M2_to_FLOAT16_model",
"node_test_cast_FLOAT8E5M2_to_FLOAT_model",
"node_test_cast_FLOAT_to_FLOAT8E4M3FNUZ_model",
"node_test_cast_FLOAT_to_FLOAT8E4M3FN_model",
"node_test_cast_FLOAT_to_FLOAT8E5M2FNUZ_model",
"node_test_cast_FLOAT_to_FLOAT8E5M2_model",
"node_test_cast_FLOAT_to_STRING_model",
"node_test_cast_STRING_to_FLOAT_model",
"node_test_cast_no_saturate_FLOAT16_to_FLOAT8E4M3FNUZ_model",
"node_test_cast_no_saturate_FLOAT16_to_FLOAT8E4M3FN_model",
"node_test_cast_no_saturate_FLOAT16_to_FLOAT8E5M2FNUZ_model",
"node_test_cast_no_saturate_FLOAT16_to_FLOAT8E5M2_model",
"node_test_cast_no_saturate_FLOAT_to_FLOAT8E4M3FNUZ_model",
"node_test_cast_no_saturate_FLOAT_to_FLOAT8E4M3FN_model",
"node_test_cast_no_saturate_FLOAT_to_FLOAT8E5M2FNUZ_model",
"node_test_cast_no_saturate_FLOAT_to_FLOAT8E5M2_model",
"node_test_castlike_FLOAT8E4M3FNUZ_to_FLOAT_expanded_model",
"node_test_castlike_FLOAT8E4M3FNUZ_to_FLOAT_model",
"node_test_castlike_FLOAT8E4M3FN_to_FLOAT_expanded_model",
"node_test_castlike_FLOAT8E4M3FN_to_FLOAT_model",
"node_test_castlike_FLOAT8E5M2FNUZ_to_FLOAT_expanded_model",
"node_test_castlike_FLOAT8E5M2FNUZ_to_FLOAT_model",
"node_test_castlike_FLOAT8E5M2_to_FLOAT_expanded_model",
"node_test_castlike_FLOAT8E5M2_to_FLOAT_model",
"node_test_castlike_FLOAT_to_FLOAT8E4M3FNUZ_expanded_model",
"node_test_castlike_FLOAT_to_FLOAT8E4M3FNUZ_model",
"node_test_castlike_FLOAT_to_FLOAT8E4M3FN_expanded_model",
"node_test_castlike_FLOAT_to_FLOAT8E4M3FN_model",
"node_test_castlike_FLOAT_to_FLOAT8E5M2FNUZ_expanded_model",
"node_test_castlike_FLOAT_to_FLOAT8E5M2FNUZ_model",
"node_test_castlike_FLOAT_to_FLOAT8E5M2_expanded_model",
"node_test_castlike_FLOAT_to_FLOAT8E5M2_model",
"node_test_castlike_FLOAT_to_STRING_expanded_model",
"node_test_castlike_FLOAT_to_STRING_model",
"node_test_castlike_STRING_to_FLOAT_expanded_model",
"node_test_castlike_STRING_to_FLOAT_model",
"node_test_center_crop_pad_crop_axes_chw_expanded_model",
"node_test_center_crop_pad_crop_axes_hwc_expanded_model",
"node_test_center_crop_pad_crop_negative_axes_hwc_expanded_model",
"node_test_clip_default_inbounds_model",
"node_test_clip_default_int8_inbounds_model",
"node_test_clip_default_int8_max_model",
"node_test_clip_default_max_model",
"node_test_constantofshape_float_ones_model",
"node_test_constantofshape_int_shape_zero_model",
"node_test_constantofshape_int_zeros_model",
"node_test_dequantizelinear_e4m3fn_model",
"node_test_dequantizelinear_e4m3fn_zero_point_model",
"node_test_dequantizelinear_e5m2_model",
"node_test_dft_axis_model",
"node_test_dft_inverse_model",
"node_test_dft_model",
"node_test_equal_string_broadcast_model",
"node_test_equal_string_model",
"node_test_gru_defaults_model",
"node_test_gru_seq_length_model",
"node_test_gru_with_initial_bias_model",
"node_test_identity_opt_model",
"node_test_identity_sequence_model",
"node_test_if_model",
"node_test_if_opt_model",
"node_test_if_seq_model",
"node_test_layer_normalization_2d_axis0_expanded_model",
"node_test_layer_normalization_2d_axis0_expanded_ver18_model",
"node_test_layer_normalization_2d_axis1_expanded_model",
"node_test_layer_normalization_2d_axis1_expanded_ver18_model",
"node_test_layer_normalization_2d_axis_negative_1_expanded_model",
"node_test_layer_normalization_2d_axis_negative_1_expanded_ver18_model",
"node_test_layer_normalization_2d_axis_negative_2_expanded_model",
"node_test_layer_normalization_2d_axis_negative_2_expanded_ver18_model",
"node_test_layer_normalization_3d_axis0_epsilon_expanded_model",
"node_test_layer_normalization_3d_axis0_epsilon_expanded_ver18_model",
"node_test_layer_normalization_3d_axis1_epsilon_expanded_model",
"node_test_layer_normalization_3d_axis1_epsilon_expanded_ver18_model",
"node_test_layer_normalization_3d_axis2_epsilon_expanded_model",
"node_test_layer_normalization_3d_axis2_epsilon_expanded_ver18_model",
"node_test_layer_normalization_3d_axis_negative_1_epsilon_expanded_model",
"node_test_layer_normalization_3d_axis_negative_1_epsilon_expanded_ver18_model",
"node_test_layer_normalization_3d_axis_negative_2_epsilon_expanded_model",
"node_test_layer_normalization_3d_axis_negative_2_epsilon_expanded_ver18_model",
"node_test_layer_normalization_3d_axis_negative_3_epsilon_expanded_model",
"node_test_layer_normalization_3d_axis_negative_3_epsilon_expanded_ver18_model",
"node_test_layer_normalization_4d_axis0_expanded_model",
"node_test_layer_normalization_4d_axis0_expanded_ver18_model",
"node_test_layer_normalization_4d_axis1_expanded_model",
"node_test_layer_normalization_4d_axis1_expanded_ver18_model",
"node_test_layer_normalization_4d_axis2_expanded_model",
"node_test_layer_normalization_4d_axis2_expanded_ver18_model",
"node_test_layer_normalization_4d_axis3_expanded_model",
"node_test_layer_normalization_4d_axis3_expanded_ver18_model",
"node_test_layer_normalization_4d_axis_negative_1_expanded_model",
"node_test_layer_normalization_4d_axis_negative_1_expanded_ver18_model",
"node_test_layer_normalization_4d_axis_negative_2_expanded_model",
"node_test_layer_normalization_4d_axis_negative_2_expanded_ver18_model",
"node_test_layer_normalization_4d_axis_negative_3_expanded_model",
"node_test_layer_normalization_4d_axis_negative_3_expanded_ver18_model",
"node_test_layer_normalization_4d_axis_negative_4_expanded_model",
"node_test_layer_normalization_4d_axis_negative_4_expanded_ver18_model",
"node_test_layer_normalization_default_axis_expanded_model",
"node_test_layer_normalization_default_axis_expanded_ver18_model",
"node_test_loop11_model",
"node_test_loop13_seq_model",
"node_test_loop16_seq_none_model",
"node_test_lstm_defaults_model",
"node_test_lstm_with_initial_bias_model",
"node_test_lstm_with_peepholes_model",
"node_test_optional_get_element_optional_sequence_model",
"node_test_optional_get_element_optional_tensor_model",
"node_test_optional_get_element_sequence_model",
"node_test_optional_has_element_empty_no_input_name_optional_input_model",
"node_test_optional_has_element_empty_no_input_name_tensor_input_model",
"node_test_optional_has_element_empty_optional_input_model",
"node_test_optional_has_element_optional_input_model",
"node_test_optional_has_element_tensor_input_model",
"node_test_quantizelinear_e4m3fn_model",
"node_test_quantizelinear_e5m2_model",
"node_test_range_float_type_positive_delta_expanded_model",
"node_test_range_int32_type_negative_delta_expanded_model",
"node_test_regex_full_match_basic_model",
"node_test_regex_full_match_email_domain_model",
"node_test_regex_full_match_empty_model",
"node_test_resize_downsample_scales_cubic_A_n0p5_exclude_outside_model",
"node_test_resize_downsample_scales_cubic_align_corners_model",
"node_test_resize_downsample_scales_cubic_antialias_model",
"node_test_resize_downsample_scales_cubic_model",
"node_test_resize_downsample_scales_linear_align_corners_model",
"node_test_resize_downsample_scales_linear_antialias_model",
"node_test_resize_downsample_scales_linear_half_pixel_symmetric_model",
"node_test_resize_downsample_scales_linear_model",
"node_test_resize_downsample_scales_nearest_model",
"node_test_resize_downsample_sizes_cubic_antialias_model",
"node_test_resize_downsample_sizes_cubic_model",
"node_test_resize_downsample_sizes_linear_antialias_model",
"node_test_resize_downsample_sizes_linear_pytorch_half_pixel_model",
"node_test_resize_downsample_sizes_nearest_model",
"node_test_resize_downsample_sizes_nearest_not_larger_model",
"node_test_resize_downsample_sizes_nearest_not_smaller_model",
"node_test_resize_tf_crop_and_resize_axes_2_3_model",
"node_test_resize_tf_crop_and_resize_axes_3_2_model",
"node_test_resize_tf_crop_and_resize_model",
"node_test_resize_upsample_scales_cubic_A_n0p5_exclude_outside_model",
"node_test_resize_upsample_scales_cubic_align_corners_model",
"node_test_resize_upsample_scales_cubic_asymmetric_model",
"node_test_resize_upsample_scales_cubic_model",
"node_test_resize_upsample_scales_linear_align_corners_model",
"node_test_resize_upsample_scales_linear_half_pixel_symmetric_model",
"node_test_resize_upsample_scales_linear_model",
"node_test_resize_upsample_scales_nearest_axes_2_3_model",
"node_test_resize_upsample_scales_nearest_axes_3_2_model",
"node_test_resize_upsample_scales_nearest_model",
"node_test_resize_upsample_sizes_cubic_model",
"node_test_resize_upsample_sizes_nearest_axes_2_3_model",
"node_test_resize_upsample_sizes_nearest_axes_3_2_model",
"node_test_resize_upsample_sizes_nearest_ceil_half_pixel_model",
"node_test_resize_upsample_sizes_nearest_floor_align_corners_model",
"node_test_resize_upsample_sizes_nearest_model",
"node_test_resize_upsample_sizes_nearest_not_larger_model",
"node_test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric_model",
"node_test_rnn_seq_length_model",
"node_test_scan9_sum_model",
"node_test_scan_sum_model",
"node_test_sequence_insert_at_back_model",
"node_test_sequence_insert_at_front_model",
"node_test_sequence_map_add_1_sequence_1_tensor_expanded_model",
"node_test_sequence_map_add_1_sequence_1_tensor_model",
"node_test_sequence_map_add_2_sequences_expanded_model",
"node_test_sequence_map_add_2_sequences_model",
"node_test_sequence_map_extract_shapes_expanded_model",
"node_test_sequence_map_extract_shapes_model",
"node_test_sequence_map_identity_1_sequence_1_tensor_expanded_model",
"node_test_sequence_map_identity_1_sequence_1_tensor_model",
"node_test_sequence_map_identity_1_sequence_expanded_model",
"node_test_sequence_map_identity_1_sequence_model",
"node_test_sequence_map_identity_2_sequences_expanded_model",
"node_test_sequence_map_identity_2_sequences_model",
"node_test_simple_rnn_defaults_model",
"node_test_simple_rnn_with_initial_bias_model",
"node_test_split_to_sequence_1_model",
"node_test_split_to_sequence_2_model",
"node_test_split_to_sequence_nokeepdims_model",
"node_test_stft_model",
"node_test_string_concat_broadcasting_model",
"node_test_string_concat_empty_string_model",
"node_test_string_concat_model",
"node_test_string_concat_utf8_model",
"node_test_string_concat_zero_dimensional_model",
"node_test_string_split_basic_model",
"node_test_string_split_consecutive_delimiters_model",
"node_test_string_split_empty_string_delimiter_model",
"node_test_string_split_empty_tensor_model",
"node_test_string_split_maxsplit_model",
"node_test_string_split_no_delimiter_model",
"node_test_strnormalizer_export_monday_casesensintive_lower_model",
"node_test_strnormalizer_export_monday_casesensintive_nochangecase_model",
"node_test_strnormalizer_export_monday_casesensintive_upper_model",
"node_test_strnormalizer_export_monday_empty_output_model",
"node_test_strnormalizer_export_monday_insensintive_upper_twodim_model",
"node_test_strnormalizer_nostopwords_nochangecase_model",
"simple_test_sequence_model1_model",
"simple_test_sequence_model2_model",
"simple_test_sequence_model3_model",
"simple_test_sequence_model4_model",
"simple_test_sequence_model5_model",
"simple_test_sequence_model6_model",
"simple_test_sequence_model7_model",
"simple_test_sequence_model8_model",
"simple_test_strnorm_model_monday_casesensintive_lower_model",
"simple_test_strnorm_model_monday_casesensintive_nochangecase_model",
"simple_test_strnorm_model_monday_casesensintive_upper_model",
"simple_test_strnorm_model_monday_empty_output_model",
"simple_test_strnorm_model_monday_insensintive_upper_twodim_model",
"simple_test_strnorm_model_nostopwords_nochangecase_model",
]
class ImportSmokeTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.unexpected_failure_count = 0
ImportSmokeTest.actual_failures = []
@classmethod
def tearDownClass(cls):
if cls.unexpected_failure_count:
# Print a helpful message with copy-paste XFAIL def.
failure_report_path = OUTPUT_PATH / "import_smoke_test_report.txt"
print(
"Unexpected failures. Writing copy/paste report to:",
failure_report_path,
)
with open(failure_report_path, "wt") as f:
lines = [f' "{s}",' for s in ImportSmokeTest.actual_failures]
print(
f"Unexpected failures in the following. Copy/paste to update `TEST_CAST_XFAILS`:",
file=f,
)
print(f"TEST_CAST_XFAILS = [", file=f)
[print(l, file=f) for l in lines]
print(f"]", file=f)
ImportSmokeTest.actual_failures.clear()
def load_onnx_model(self, file_path: Path) -> onnx.ModelProto:
raw_model = onnx.load(file_path)
try:
inferred_model = onnx.shape_inference.infer_shapes(raw_model)
except onnx.onnx_cpp2py_export.shape_inference.InferenceError as e:
print("WARNING: Shape inference failure (skipping test):", e)
self.skipTest(reason="shape inference failure")
# inferred_model = raw_model
return inferred_model
def run_import_test(self, norm_name: str, rel_path: str):
context = ir.Context()
configure_context(context)
model_info = onnx_importer.ModelInfo(
self.load_onnx_model(ONNX_TEST_DATA_DIR / rel_path),
)
m = model_info.create_module(context=context)
try:
imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m)
imp.import_all()
m.verify()
finally:
# Use a ".txt" extension to avoid lit test discovery.
with open(OUTPUT_PATH / f"{norm_name}.mlir", "wt") as f:
print(m.get_asm(), file=f)
def testExists(self):
# We expect a lot of test cases. Die if not the case (i.e. if paths change
# or something).
self.assertGreater(len(ONNX_REL_PATHS), 10)
# Generate test methods for each onnx file.
for _rel_path in ONNX_REL_PATHS:
def attach_test(rel_path):
norm_name = rel_path.removesuffix(".onnx").replace("/", "_")
def test_method(self: ImportSmokeTest):
try:
self.run_import_test(norm_name, rel_path)
except onnx_importer.OnnxImportError as e:
# All legitimate failures should be caught and reported
# as an OnnxImportError.
ImportSmokeTest.actual_failures.append(norm_name)
if norm_name not in TEST_CAST_XFAILS:
ImportSmokeTest.unexpected_failure_count += 1
raise e
test_method.__name__ = f"test_{norm_name}"
if norm_name in TEST_CAST_XFAILS:
test_method = unittest.expectedFailure(test_method)
setattr(ImportSmokeTest, test_method.__name__, test_method)
attach_test(_rel_path)
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()


@@ -0,0 +1,5 @@
try:
    import onnx
except ModuleNotFoundError:
    print("Skipping onnx tests.. no onnx")
    config.unsupported = True