diff --git a/CMakeLists.txt b/CMakeLists.txt index ccbe7ccb3..f821d6003 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,6 +25,8 @@ project(torch-mlir LANGUAGES CXX C) set(CMAKE_C_STANDARD 11) set(CMAKE_CXX_STANDARD 17) +include(CMakeDependentOption) + #------------------------------------------------------------------------------- # Project options #------------------------------------------------------------------------------- @@ -43,24 +45,11 @@ endif() option(TORCH_MLIR_OUT_OF_TREE_BUILD "Specifies an out of tree build" OFF) -# PT1 options. -option(TORCH_MLIR_ENABLE_PROJECT_PT1 "Enables the PyTorch1 project under projects/pt1" OFF) -# TODO: Rename/scope these. They use historic names for now to ease migration -# burden. -option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON) -option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF) -option(TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS "Build Torch dialect MLIR Python bindings but neither JIT IR Importer nor LTC backend" OFF) -if(TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS) - set(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER OFF) - set(TORCH_MLIR_ENABLE_LTC OFF) -endif() -# Force enable the PT1 project if either the JIT_IR_IMPORTER or LTC is enabled. -if(NOT TORCH_MLIR_ENABLE_PROJECT_PT1) - if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER OR TORCH_MLIR_ENABLE_LTC) - message(STATUS "Enabling projects/pt1 because features requiring it are enabled") - set(TORCH_MLIR_ENABLE_PROJECT_PT1 ON) - endif() -endif() +# PyTorch native extension gate. If OFF, then no features which depend on +# native extensions will be built. +option(TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS "Enables PyTorch native extension features" ON) +cmake_dependent_option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS OFF) +cmake_dependent_option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS OFF) #------------------------------------------------------------------------------- # Configure out-of-tree vs in-tree build @@ -235,4 +224,16 @@ endif() # Sub-projects #------------------------------------------------------------------------------- +# Sub-projects can bundle additional PyTorch extensions by adding them to this +# source target. It is typically empty unless if features are enabled. +if(MLIR_ENABLE_BINDINGS_PYTHON) + declare_mlir_python_sources(TorchMLIRPythonTorchExtensionsSources) +endif() + +# Build projects first as it may populate additional Python deps. add_subdirectory(projects) + +# Finish with top-level Python bindings so it can handle additional deps. 
+if(MLIR_ENABLE_BINDINGS_PYTHON)
+  add_subdirectory(python)
+endif()
\ No newline at end of file
diff --git a/build_tools/python_deploy/build_linux_packages.sh b/build_tools/python_deploy/build_linux_packages.sh
index 2a909266f..f0336b2a1 100755
--- a/build_tools/python_deploy/build_linux_packages.sh
+++ b/build_tools/python_deploy/build_linux_packages.sh
@@ -351,7 +351,6 @@ function setup_venv() {
       echo ":::: Using stable dependencies"
       python3 -m pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu
       python3 -m pip install --no-cache-dir -r /main_checkout/torch-mlir/build-requirements.txt
-      python3 -m pip install --no-cache-dir -r /main_checkout/torch-mlir/test-requirements.txt
     ;;
     *)
       echo "Unrecognized torch version '$torch_version'"
@@ -359,6 +358,7 @@ function setup_venv() {
     ;;
   esac
 
+  python3 -m pip install --no-cache-dir -r /main_checkout/torch-mlir/test-requirements.txt
 }
 
 function build_out_of_tree() {
diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/README.md b/docs/importers/onnx_importer.md
similarity index 90%
rename from include/torch-mlir/Conversion/TorchOnnxToTorch/README.md
rename to docs/importers/onnx_importer.md
index 6de1cc923..acc45bb2e 100644
--- a/include/torch-mlir/Conversion/TorchOnnxToTorch/README.md
+++ b/docs/importers/onnx_importer.md
@@ -3,11 +3,8 @@
 We enable the direct representation of many ONNX features directly in
 the `torch` dialect as `torch.operator` custom ops with names like
 `onnx.{OperatorName}`. The majority of ONNX operators are represented
-with a systematic transformation. See
-[onnx_importer.py](https://github.com/nod-ai/SHARK-Turbine/blob/main/python/shark_turbine/importers/onnx_importer.py)
-for the reference importer which complies with the rules below
-(this is planned to be upstreamed to torch-mlir proper in the near
-future).
+with a systematic transformation. See `torch_mlir.extras.onnx_importer`
+for the reference importer, which complies with the rules below.
 
 ## Adding new ONNX operators
 
@@ -26,10 +23,11 @@ are relatively straight-forward to map, following this general procedure:
   * Open the corresponding implementation file `DefaultDomainXtoY.cpp`
     corresponding with the alphabetic sort of the op and add a conversion.
   * Generate successful test cases:
-    * Either run the Turbine importer to produce MLIR output for all
-      ops/models in the ONNX test suite or use a dump that someone has
-      generated:
-      * [2023-Nov-21](https://drive.google.com/file/d/1P6QaRXGnCeApjdjNmykLxWa-yqMmIO-d/view?usp=sharing)
+    * All `onnx_importer.py` tests are dumped to the test temp dir (success
+      or failure). This is typically located under
+      `tools/torch-mlir/test/python/onnx_importer/Output`. The `.mlir` files
+      under there should provide good variants to drive lit test coverage of
+      conversion.
     * There are often many variants of tests for checking conformance of
       different historic ONNX encodings, but these are often not load
       bearing at the MLIR level.
diff --git a/projects/CMakeLists.txt b/projects/CMakeLists.txt
index 4b54be65a..d4fead890 100644
--- a/projects/CMakeLists.txt
+++ b/projects/CMakeLists.txt
@@ -1,7 +1,31 @@
 include(AddMLIRPython)
 
+################################################################################
+# PyTorch
 # Configure PyTorch if we have any features enabled which require it.
+################################################################################ if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER OR TORCH_MLIR_ENABLE_LTC) + + if (NOT TORCH_MLIR_USE_INSTALLED_PYTORCH) + # Source builds + message(STATUS "Building libtorch from source (features depend on it and NOT TORCH_MLIR_USE_INSTALLED_PYTORCH)") + set(ENV{TORCH_MLIR_SRC_PYTORCH_REPO} ${TORCH_MLIR_SRC_PYTORCH_REPO}) + set(ENV{TORCH_MLIR_SRC_PYTORCH_BRANCH} ${TORCH_MLIR_SRC_PYTORCH_BRANCH}) + set(ENV{TM_PYTORCH_INSTALL_WITHOUT_REBUILD} ${TM_PYTORCH_INSTALL_WITHOUT_REBUILD}) + set(ENV{MACOSX_DEPLOYMENT_TARGET} ${MACOSX_DEPLOYMENT_TARGET}) + set(ENV{CMAKE_OSX_ARCHITECTURES} ${CMAKE_OSX_ARCHITECTURES}) + set(ENV{CMAKE_C_COMPILER_LAUNCHER} ${CMAKE_C_COMPILER_LAUNCHER}) + set(ENV{CMAKE_CXX_COMPILER_LAUNCHER} ${CMAKE_CXX_COMPILER_LAUNCHER}) + execute_process( + COMMAND ${TORCH_MLIR_SOURCE_DIR}/build_tools/build_libtorch.sh + RESULT_VARIABLE _result + ) + if(_result) + message(FATAL_ERROR "Failed to run `build_libtorch.sh`") + endif() + set(TORCH_INSTALL_PREFIX "libtorch") + endif() + message(STATUS "Enabling PyTorch C++ dep (features depend on it)") include(TorchMLIRPyTorch) @@ -48,6 +72,6 @@ if(TORCH_MLIR_ENABLE_LTC) endif() # Include overall PT1 project. -if(TORCH_MLIR_ENABLE_PROJECT_PT1) +if(TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS) add_subdirectory(pt1) endif() diff --git a/projects/pt1/python/CMakeLists.txt b/projects/pt1/python/CMakeLists.txt index e951772df..ce4042698 100644 --- a/projects/pt1/python/CMakeLists.txt +++ b/projects/pt1/python/CMakeLists.txt @@ -7,79 +7,22 @@ set(CMAKE_PLATFORM_NO_VERSIONED_SONAME ON) # argument. set(TORCH_MLIR_PYTHON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/torch_mlir") - # We vendor our own MLIR instance in the `torch_mlir` namespace. add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=torch_mlir.") -################################################################################ -# PyTorch -################################################################################ +# ################################################################################ +# # Sources +# ################################################################################ -if (NOT TORCH_MLIR_USE_INSTALLED_PYTORCH) - # Source builds - set(ENV{TORCH_MLIR_SRC_PYTORCH_REPO} ${TORCH_MLIR_SRC_PYTORCH_REPO}) - set(ENV{TORCH_MLIR_SRC_PYTORCH_BRANCH} ${TORCH_MLIR_SRC_PYTORCH_BRANCH}) - set(ENV{TM_PYTORCH_INSTALL_WITHOUT_REBUILD} ${TM_PYTORCH_INSTALL_WITHOUT_REBUILD}) - set(ENV{MACOSX_DEPLOYMENT_TARGET} ${MACOSX_DEPLOYMENT_TARGET}) - set(ENV{CMAKE_OSX_ARCHITECTURES} ${CMAKE_OSX_ARCHITECTURES}) - set(ENV{CMAKE_C_COMPILER_LAUNCHER} ${CMAKE_C_COMPILER_LAUNCHER}) - set(ENV{CMAKE_CXX_COMPILER_LAUNCHER} ${CMAKE_CXX_COMPILER_LAUNCHER}) - execute_process( - COMMAND ${TORCH_MLIR_SOURCE_DIR}/build_tools/build_libtorch.sh - RESULT_VARIABLE _result - ) - if(_result) - message(FATAL_ERROR "Failed to run `build_libtorch.sh`") - endif() - set(TORCH_INSTALL_PREFIX "libtorch") -endif() - -################################################################################ -# Sources -################################################################################ - -declare_mlir_python_sources(TorchMLIRPythonSources) -declare_mlir_python_sources(TorchMLIRPythonExtensions) - -if (NOT TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS) - declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel - ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" - ADD_TO_PARENT TorchMLIRPythonSources - SOURCES - __init__.py - _dynamo_fx_importer.py - compiler_utils.py - 
dynamo.py - _version.py - ) -endif() - -declare_mlir_python_sources(TorchMLIRPythonSources.Dialects +declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" - ADD_TO_PARENT TorchMLIRPythonSources -) - -declare_mlir_dialect_python_bindings( - ADD_TO_PARENT TorchMLIRPythonSources.Dialects - ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" - TD_FILE dialects/TorchBinding.td - SOURCES dialects/torch/__init__.py - DIALECT_NAME torch -) - -################################################################################ -# Extensions -################################################################################ - -declare_mlir_python_extension(TorchMLIRPythonExtensions.Main - MODULE_NAME _torchMlir - ADD_TO_PARENT TorchMLIRPythonExtensions + ADD_TO_PARENT TorchMLIRPythonTorchExtensionsSources SOURCES - TorchMLIRModule.cpp - EMBED_CAPI_LINK_LIBS - TorchMLIRCAPI - PRIVATE_LINK_LIBS - LLVMSupport + __init__.py + _dynamo_fx_importer.py + compiler_utils.py + dynamo.py + _version.py ) ################################################################################ @@ -110,56 +53,23 @@ endif() # add_subdirectory(torch_mlir/_torch_mlir_custom_op_example) -################################################################################ -# Generate packages and shared library -# Downstreams typically will not use these, but they are useful for local -# testing. -################################################################################ - -set(_source_components - # TODO: Core is now implicitly building/registering all dialects, increasing - # build burden by ~5x. Make it stop. - # TODO: Reduce dependencies. We need ExecutionEngine and a bunch of passes - # for the reference backend, but logically they can be separate. But seemingly - # the only way to handle that is to create a separate mlir python package - # tree, which seems excessive. - MLIRPythonSources - MLIRPythonExtension.Core - MLIRPythonExtension.RegisterEverything - TorchMLIRPythonSources - TorchMLIRPythonExtensions -) - -add_mlir_python_common_capi_library(TorchMLIRAggregateCAPI - INSTALL_COMPONENT TorchMLIRPythonModules - INSTALL_DESTINATION python_packages/torch_mlir/torch_mlir/_mlir_libs - OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs" - RELATIVE_INSTALL_ROOT "../../../.." - DECLARED_SOURCES ${_source_components} -) - -add_mlir_python_modules(TorchMLIRPythonModules - ROOT_PREFIX "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir" - INSTALL_PREFIX "python_packages/torch_mlir/torch_mlir" - DECLARED_SOURCES ${_source_components} - COMMON_CAPI_LINK_LIBS - TorchMLIRAggregateCAPI - ) - # TODO: Find a cleaner way to do this. # Can we build the JIT IR importer with `declare_mlir_python_extension`? # Then it would "just work". if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER) - add_dependencies(TorchMLIRPythonModules TorchMLIRJITIRImporter) - add_dependencies(TorchMLIRPythonModules TorchMLIRJITIRImporterPybind) - # Build the E2E Tests (which depend on the JIT IR importer now). 
- add_dependencies(TorchMLIRPythonModules TorchMLIRE2ETestPythonModules) + add_dependencies(TorchMLIRPythonTorchExtensionsSources + TorchMLIRJITIRImporter + TorchMLIRJITIRImporterPybind + TorchMLIRE2ETestPythonModules + ) endif() if(TORCH_MLIR_ENABLE_LTC) # Add Torch-MLIR LTC backend as dependency - add_dependencies(TorchMLIRPythonModules torch_mlir_ltc_backend) - add_dependencies(TorchMLIRPythonModules reference_lazy_backend) + add_dependencies(TorchMLIRPythonTorchExtensionsSources + torch_mlir_ltc_backend + reference_lazy_backend + ) endif() add_subdirectory(test) diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/CMakeLists.txt b/projects/pt1/python/torch_mlir/jit_ir_importer/CMakeLists.txt index c2883b3dc..6c2ccf62e 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/CMakeLists.txt +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/CMakeLists.txt @@ -4,9 +4,9 @@ ## Declare the sources of the Python module. -declare_mlir_python_sources(TorchMLIRPythonSources.JitIRImporter +declare_mlir_python_sources(TorchMLIRPythonTorchExtensionsSources.JitIRImporter ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" - ADD_TO_PARENT TorchMLIRPythonSources + ADD_TO_PARENT TorchMLIRPythonTorchExtensionsSources SOURCES_GLOB jit_ir_importer/*.py ) diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt new file mode 100644 index 000000000..7b9bf12f2 --- /dev/null +++ b/python/CMakeLists.txt @@ -0,0 +1,94 @@ +# Disables generation of "version soname" (i.e. libFoo.so.), which +# causes pure duplication as part of Python wheels. +set(CMAKE_PLATFORM_NO_VERSIONED_SONAME ON) + +# The directory at which the Python import tree begins. +# See documentation for `declare_mlir_python_sources`'s ROOT_DIR +# argument. +set(TORCH_MLIR_PYTHON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/torch_mlir") + + +# We vendor our own MLIR instance in the `torch_mlir` namespace. +add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=torch_mlir.") + +################################################################################ +# Sources +################################################################################ + +declare_mlir_python_sources(TorchMLIRPythonSources) +declare_mlir_python_sources(TorchMLIRPythonExtensions) + +declare_mlir_python_sources(TorchMLIRPythonSources.Dialects + ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" + ADD_TO_PARENT TorchMLIRPythonSources +) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT TorchMLIRPythonSources.Dialects + ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" + TD_FILE dialects/TorchBinding.td + SOURCES dialects/torch/__init__.py + DIALECT_NAME torch +) + +declare_mlir_python_sources(TorchMLIRPythonSources.Importers + ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" + ADD_TO_PARENT TorchMLIRPythonSources + SOURCES + extras/onnx_importer.py +) + +################################################################################ +# Extensions +################################################################################ + +declare_mlir_python_extension(TorchMLIRPythonExtensions.Main + MODULE_NAME _torchMlir + ADD_TO_PARENT TorchMLIRPythonExtensions + SOURCES + TorchMLIRModule.cpp + EMBED_CAPI_LINK_LIBS + TorchMLIRCAPI + PRIVATE_LINK_LIBS + LLVMSupport +) + +################################################################################ +# Generate packages and shared library +# Downstreams typically will not use these, but they are useful for local +# testing. 
+################################################################################ + +set(_source_components + # TODO: Core is now implicitly building/registering all dialects, increasing + # build burden by ~5x. Make it stop. + # TODO: Reduce dependencies. We need ExecutionEngine and a bunch of passes + # for the reference backend, but logically they can be separate. But seemingly + # the only way to handle that is to create a separate mlir python package + # tree, which seems excessive. + MLIRPythonSources + MLIRPythonExtension.Core + MLIRPythonExtension.RegisterEverything + TorchMLIRPythonSources + TorchMLIRPythonExtensions + + # Sources related to optional Torch extension dependent features. Typically + # empty unless if project features are enabled. + TorchMLIRPythonTorchExtensionsSources +) + +add_mlir_python_common_capi_library(TorchMLIRAggregateCAPI + INSTALL_COMPONENT TorchMLIRPythonModules + INSTALL_DESTINATION python_packages/torch_mlir/torch_mlir/_mlir_libs + OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs" + RELATIVE_INSTALL_ROOT ".." + DECLARED_SOURCES ${_source_components} +) + +add_mlir_python_modules(TorchMLIRPythonModules + ROOT_PREFIX "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir" + INSTALL_PREFIX "python_packages/torch_mlir/torch_mlir" + DECLARED_SOURCES ${_source_components} + COMMON_CAPI_LINK_LIBS + TorchMLIRAggregateCAPI + ) diff --git a/projects/pt1/python/TorchMLIRModule.cpp b/python/TorchMLIRModule.cpp similarity index 100% rename from projects/pt1/python/TorchMLIRModule.cpp rename to python/TorchMLIRModule.cpp diff --git a/projects/pt1/python/torch_mlir/dialects/TorchBinding.td b/python/torch_mlir/dialects/TorchBinding.td similarity index 100% rename from projects/pt1/python/torch_mlir/dialects/TorchBinding.td rename to python/torch_mlir/dialects/TorchBinding.td diff --git a/projects/pt1/python/torch_mlir/dialects/torch/__init__.py b/python/torch_mlir/dialects/torch/__init__.py similarity index 100% rename from projects/pt1/python/torch_mlir/dialects/torch/__init__.py rename to python/torch_mlir/dialects/torch/__init__.py diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py new file mode 100644 index 000000000..a9dd52253 --- /dev/null +++ b/python/torch_mlir/extras/onnx_importer.py @@ -0,0 +1,607 @@ +# Based on code Copyright (c) Advanced Micro Devices, Inc. +# +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +"""Imports ONNX graphs to `torch` dialect ops. + +See documentation: + https://github.com/llvm/torch-mlir/blob/main/docs/importers/onnx_importer.md + +This file is distributed/forked verbatim into various downstream projects, and +it must abide by several rules above and beyond the rest of the codebase: + - It must be standalone, only depending on: + - `onnx` + - `..ir` relative imports to the main IR directory + - `..dialects.func` relative import to the `func` dialect (TODO: + we are looking to eliminate this dep). + - Python standard library + - It does not directly use the ODS generated `torch` dialect Python + wrappers. This allows it to be used in contexts that only build a C++ + compiler with minimal IR Python bindings. + - It is intended as an enabler for full onnx compilation, only handling + the import from ONNX -> the `torch` dialect. 
Testing, full pipelines, + and utilities belong elsewhere. +""" + +try: + import onnx +except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "The onnx package (`pip install onnx`) is required to use the onnx importer" + ) from e + +from typing import Optional + +from dataclasses import dataclass + +import numpy as np + +from ..ir import ( + ArrayAttr, + Attribute, + Block, + Context, + DenseElementsAttr, + DenseResourceElementsAttr, + DictAttr, + FloatAttr, + BF16Type, + ComplexType, + F16Type, + F32Type, + F64Type, + Float8E4M3FNType, + Float8E5M2FNUZType, + Float8E5M2Type, + FunctionType, + InsertionPoint, + IntegerAttr, + IntegerType, + MLIRError, + RankedTensorType, + Location, + Module, + Operation, + StringAttr, + Type as IrType, + Value, +) + +from ..dialects import ( + func as func_dialect, +) + +@dataclass +class Config: + """Various configuration settings for the importer.""" + + # Ancient ONNX exporters would often add a model input for anything that + # might be mutable, providing an initializer for it as well. More modern + # tools tools realized this is a really bad idea for a lot of reasons. + # We choose to assume more recent norms, even if encountering older + # models. Setting this to False probably won't do what you want but + # should produce interesting errors to waste your time deciphering. + # We mainly use it as a way to document in the code that we are + # making an assumption. + elide_initialized_inputs: bool = True + + +class ModelInfo: + """Top-level accounting and accessors for an ONNX model.""" + + def __init__(self, model_proto: onnx.ModelProto, *, config: Config = Config()): + self.config = config + self.model_proto = model_proto + assert model_proto.graph, "Model must contain a main Graph" + self.main_graph = GraphInfo(self, model_proto.graph) + + def create_module(self, context: Optional[Context] = None) -> Operation: + if not context: + context = Context() + module_op = Module.create(Location.unknown(context)).operation + # TODO: Populate module level metadata from the ModelProto + return module_op + + +class GraphInfo: + """Information about a Graph within a model.""" + + def __init__(self, model_info: ModelInfo, graph_proto: onnx.GraphProto): + self.model_info = model_info + self.graph_proto = graph_proto + self.initializer_map: dict[str, onnx.TensorProto] = { + n.name: n for n in graph_proto.initializer + } + self.value_info_map: dict[str, onnx.ValueInfoProto] = { + n.name: n for n in graph_proto.value_info + } + self.declared_input_map: dict[str, onnx.ValueInfoProto] = { + n.name: n for n in graph_proto.input + } + self.output_map: dict[str, onnx.ValueInfoProto] = { + n.name: n for n in graph_proto.output + } + + # Generate the effective input map, which for old models can be a + # subset of the input map. + if model_info.config.elide_initialized_inputs: + self.input_map = { + k: v + for k, v in self.declared_input_map.items() + if k not in self.initializer_map + } + else: + self.input_map = self.declared_input_map + illegal_input_keys = self.input_map.keys() - ( + self.input_map.keys() - self.initializer_map.keys() + ) + assert self.input_map.keys().isdisjoint(self.initializer_map.keys()), ( + f"When not in elide_initialized_inputs=True, we expect inputs to not " + f"have an initial value (got {illegal_input_keys})." + ) + + def find_type_proto_for_name(self, name: str) -> onnx.TypeProto: + # Node outputs don't typically have type information, but shape inference + # will associate them in the value_info. 
If not there, it may be a + # graph output, which must have type information. + value_info = self.value_info_map.get(name) or self.output_map.get(name) + if value_info is not None: + return value_info.type + raise OnnxImportError( + f"No type information associated with '{name}'. Run shape inference?" + ) + + +class OnnxImportError(Exception): + ... + + +class NodeImporter: + """Imports graph nodes into MLIR. + + Typically, the top level graph will be imported into a func whereas dependent + graphs may just be imported with references to pre-existing values. + + Note that ONNX requires that graphs be sorted topologically and free of cycles, + so we don't take any special steps to order them for dominance. + """ + + __slots__ = [ + "_c", + "_cc", + "_gi", + "_p", + "_b", + "_nv_map", + ] + + def __init__( + self, + graph_info: GraphInfo, + *, + parent_op: Operation, + block: Block, + context_cache: "ContextCache", + ): + self._c = parent_op.context + self._cc = context_cache + self._gi = graph_info + self._p = parent_op + self._b = block + self._nv_map: dict[str, Value] = {} + + @classmethod + def define_function( + cls, graph_info: GraphInfo, module_op: Operation + ) -> "NodeImporter": + cc = ContextCache(module_op.context) + with module_op.context, Location.name(f"graph:{graph_info.graph_proto.name}"): + body = module_op.regions[0].blocks[0] + func_name = graph_info.graph_proto.name + input_types = [ + cc.type_proto_to_type(inp.type) for inp in graph_info.input_map.values() + ] + output_types = [ + cc.type_proto_to_type(out.type) + for out in graph_info.output_map.values() + ] + ftype = FunctionType.get(input_types, output_types) + func_op = func_dialect.FuncOp(func_name, ftype, ip=InsertionPoint(body)) + block = func_op.add_entry_block( + [Location.name(k) for k in graph_info.input_map.keys()] + ) + imp = NodeImporter(graph_info, parent_op=func_op, block=block, context_cache=cc) + for node_name, input_value in zip(graph_info.input_map.keys(), block.arguments): + imp._nv_map[node_name] = input_value + imp._populate_graph_attrs(func_op) + return imp + + def _populate_graph_attrs(self, container_op: Operation): + """Populates graph level meta attributes on the given container op.""" + m = self._gi.model_info.model_proto + with container_op.context: + i64_type = IntegerType.get_signed(64) + default_opset_version = 0 + opset_versions: dict[str, IntegerAttr] = {} + for opset_import in m.opset_import: + if opset_import.domain: + opset_versions[opset_import.domain] = IntegerAttr.get( + i64_type, opset_import.version + ) + else: + default_opset_version = opset_import.version + if default_opset_version: + container_op.attributes[ + "torch.onnx_meta.opset_version" + ] = IntegerAttr.get(i64_type, default_opset_version) + if opset_versions: + container_op.attributes[ + "torch.onnx_meta.opset_versions" + ] = DictAttr.get(opset_versions) + container_op.attributes["torch.onnx_meta.ir_version"] = IntegerAttr.get( + IntegerType.get_signed(64), m.ir_version + ) + container_op.attributes["torch.onnx_meta.producer_name"] = StringAttr.get( + m.producer_name + ) + container_op.attributes[ + "torch.onnx_meta.producer_version" + ] = StringAttr.get(m.producer_version) + + def import_all(self): + """Imports all nodes topologically.""" + # TODO: Consider pulling in initializers on demand since there can be so + # much unused crap. 
+ for init in self._gi.initializer_map.values(): + self.import_initializer(init) + for node in self._gi.graph_proto.node: + self.import_node(node) + + outputs = [] + for output_name in self._gi.output_map.keys(): + try: + outputs.append(self._nv_map[output_name]) + except KeyError: + raise OnnxImportError( + f"Non topologically produced ONNX graph output '{output_name}'" + ) + with InsertionPoint(self._b), Location.unknown(): + func_dialect.ReturnOp(outputs) + + def import_node(self, node: onnx.NodeProto): + with InsertionPoint(self._b), Location.name(node.name): + op_type = node.op_type + # Handle special op types that materialize to non-op IR constructs. + special_key = f"_handle_node_{op_type}" + if hasattr(self, special_key): + getattr(self, special_key)(node) + return + + # General node import. + input_values = [] + for input_name in node.input: + try: + input_values.append(self._nv_map[input_name]) + except KeyError: + raise OnnxImportError( + f"Non topologically produced ONNX node input '{input_name}': {node}" + ) + + output_names = list(node.output) + output_types = [ + self._cc.type_proto_to_type(self._gi.find_type_proto_for_name(n)) + for n in output_names + ] + + # TODO: Attributes. + attrs = { + "name": StringAttr.get(f"onnx.{op_type}"), + } + self.import_attributes(node.attribute, attrs) + custom_op = Operation.create( + name="torch.operator", + results=output_types, + operands=input_values, + attributes=attrs, + ) + for output_name, output_value in zip(output_names, custom_op.results): + self._nv_map[output_name] = output_value + + def import_attributes( + self, onnx_attrs: list[onnx.AttributeProto], attrs: dict[str, Attribute] + ): + for onnx_attr in onnx_attrs: + attr_type = onnx_attr.type + if attr_type not in ATTRIBUTE_TYPE_HANDLERS: + raise OnnxImportError( + f"Unhandled ONNX attribute type code {attr_type}: {onnx_attr}" + ) + handler = ATTRIBUTE_TYPE_HANDLERS[attr_type] + if handler is None: + # Active skip. + continue + elif handler is False: + # Active error. + raise OnnxImportError( + f"ONNX importer does not support generic node attribute type {attr_type}. " + f"This likely means that this is a special node which requires specific " + f"handling in the importer: {onnx_attr}" + ) + attrs[f"torch.onnx.{onnx_attr.name}"] = handler(onnx_attr, self._cc) + + def import_initializer(self, initializer: onnx.TensorProto) -> Value: + with InsertionPoint(self._b), Location.name(initializer.name): + value_attr = self._cc.tensor_proto_to_attr(initializer) + vtensor_type = self._cc.tensor_proto_to_type(initializer) + literal_op = Operation.create( + name="torch.vtensor.literal", + results=[vtensor_type], + attributes={"value": value_attr}, + ) + self._nv_map[initializer.name] = literal_op.result + return literal_op.result + + def _get_immediate_tensor(self, name: str) -> np.array: + try: + initializer = self._gi.initializer_map[name] + except KeyError: + raise OnnxImportError( + f"An immediate value for '{name}' was required but it is dynamically produced." 
+ ) + try: + dtype = ELEM_TYPE_TO_NUMPY_DTYPE[initializer.data_type] + except KeyError: + raise OnnxImportError( + f"Unknown ONNX tensor element type to numpy dtype mapping: {initializer.data_type}" + ) + raw_data = initializer.raw_data + if raw_data: + return np.frombuffer(raw_data, dtype=dtype).reshape(tuple(initializer.dims)) + else: + raise OnnxImportError( + f"Unhandled ONNX TensorProto immediate data: {initializer}" + ) + + def _handle_node_ConstantOfShape(self, node: onnx.NodeProto): + # This op is special: It has an input of the shape, and in full generality + # could involve eager production of constants of variable size. In + # practice, the DNN profile for ONNX makes this very difficult to do + # and we hard-assert that the input can be resolved to an immediate + # value. + assert len(node.input) == 1 + assert len(node.output) == 1 + shape = self._get_immediate_tensor(node.input[0]).astype(np.int64) + value_proto = _get_attr(node, "value") + assert value_proto.type == onnx.AttributeProto.AttributeType.TENSOR + tensor_proto = value_proto.t + element_type = self._cc.tensor_element_type(tensor_proto.data_type) + vtensor_type = self._cc.get_vtensor_type(tuple(shape), element_type) + assert len(tensor_proto.dims) == 1 and tensor_proto.dims[0] == 1 + try: + cb = ELEM_TYPE_SPLAT_TENSOR_PROTO_CB[tensor_proto.data_type] + except KeyError: + raise OnnxImportError( + f"Unhandled splat type for ConstantOfShape: {node} (possible missing mapping in ELEM_TYPE_SPLAT_TENSOR_PROTO_CB)" + ) + value_attr = cb(tensor_proto, tuple(shape)) + literal_op = Operation.create( + name="torch.vtensor.literal", + results=[vtensor_type], + attributes={"value": value_attr}, + ) + self._nv_map[node.output[0]] = literal_op.result + + +class ContextCache: + """Caches per-context lookups of various things.""" + + __slots__ = [ + "_c", + "_elem_type_map", + "_vtensor_type_map", + ] + + def __init__(self, context: Context): + self._c = context + self._elem_type_map: dict[int, IrType] = {} + self._vtensor_type_map: dict[tuple[tuple[Optional[int]], IrType], IrType] = {} + + def tensor_element_type(self, elem_type: int) -> IrType: + t = self._elem_type_map.get(elem_type) + if t is None: + try: + with self._c: + t = ELEM_TYPE_TO_IR_TYPE_CB[elem_type]() + except KeyError: + raise OnnxImportError(f"Unknown ONNX tensor element type: {elem_type}") + self._elem_type_map[elem_type] = t + return t + + def get_vtensor_type( + self, dims: tuple[Optional[int]], element_type: IrType + ) -> IrType: + key = (dims, element_type) + t = self._vtensor_type_map.get(key) + if t is None: + shape_asm = ",".join("?" if d is None else str(d) for d in dims) + asm = f"!torch.vtensor<[{shape_asm}],{str(element_type)}>" + try: + t = IrType.parse(asm, context=self._c) + except MLIRError as e: + raise OnnxImportError( + f"Unparseable torch type (MLIR asm format bug?): {asm}" + ) from e + self._vtensor_type_map[key] = t + return t + + def tensor_proto_to_type(self, tp: onnx.TensorProto) -> IrType: + element_type = self.tensor_element_type(tp.data_type) + return self.get_vtensor_type(tuple(tp.dims), element_type) + + def tensor_proto_to_builtin_type(self, tp: onnx.TensorProto) -> IrType: + element_type = self.tensor_element_type(tp.data_type) + # TODO: Fixme upstream: RankedTensorType.get should not require a location. 
+ with Location.unknown(): + return RankedTensorType.get(tuple(tp.dims), element_type) + + def type_proto_to_type(self, tp: onnx.TypeProto) -> IrType: + if tp.tensor_type: + tt = tp.tensor_type + if not tt.shape: + raise OnnxImportError( + f"Unsupported Tensor type without shape (run shape inference?): {tp}" + ) + element_type = self.tensor_element_type(tt.elem_type) + dims = tuple( + (d.dim_value if not d.dim_param else None) for d in tt.shape.dim + ) + return self.get_vtensor_type(dims, element_type) + else: + # TODO: Others if ever needed. Or we consider ourselves DNN-only. + # See TypeProto: sequence_type, map_type, optional_type, sparse_tensor_type. + raise OnnxImportError(f"Unsupported ONNX TypeProto: {tp}") + + def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute: + tensor_type = self.tensor_proto_to_builtin_type(tp) + if tp.HasField("raw_data"): + # Conveniently, DenseResourceElementsAttr shares the raw data + # format. We just give it maximum numeric alignment. + return DenseResourceElementsAttr.get_from_buffer( + tp.raw_data, tp.name, tensor_type, alignment=8 + ) + else: + # We have to do a data type specific instantiation from proto fields. + # Since this is typically used for small tensor constants, we instantiate + # as a DenseElementsAttr. + handler = ELEM_TYPE_INLINE_TENSOR_PROTO_CB.get(tp.data_type) + if handler is None: + raise OnnxImportError(f"Unhandled ONNX TensorProto data: {tp}") + return handler(tp) + + +ELEM_TYPE_TO_IR_TYPE_CB = { + onnx.TensorProto.DataType.FLOAT: lambda: F32Type.get(), + onnx.TensorProto.DataType.UINT8: lambda: IntegerType.get_unsigned(8), + onnx.TensorProto.DataType.INT8: lambda: IntegerType.get_signed(8), + onnx.TensorProto.DataType.UINT16: lambda: IntegerType.get_unsigned(16), + onnx.TensorProto.DataType.INT16: lambda: IntegerType.get_signed(16), + onnx.TensorProto.DataType.INT32: lambda: IntegerType.get_signed(32), + onnx.TensorProto.DataType.INT64: lambda: IntegerType.get_signed(64), + onnx.TensorProto.DataType.BOOL: lambda: IntegerType.get_signless(1), + onnx.TensorProto.DataType.FLOAT16: lambda: F16Type.get(), + onnx.TensorProto.DataType.DOUBLE: lambda: F64Type.get(), + onnx.TensorProto.DataType.UINT32: lambda: IntegerType.get_unsigned(32), + onnx.TensorProto.DataType.UINT64: lambda: IntegerType.get_unsigned(64), + onnx.TensorProto.DataType.COMPLEX64: lambda: ComplexType.get(F32Type.get()), + onnx.TensorProto.DataType.COMPLEX128: lambda: ComplexType.get(F64Type.get()), + onnx.TensorProto.DataType.BFLOAT16: lambda: BF16Type.get(), + onnx.TensorProto.DataType.FLOAT8E4M3FN: lambda: Float8E4M3FNType.get(), + onnx.TensorProto.DataType.FLOAT8E4M3FNUZ: lambda: Float8E5M2FNUZType.get(), + onnx.TensorProto.DataType.FLOAT8E5M2: lambda: Float8E5M2Type.get(), + onnx.TensorProto.DataType.FLOAT8E5M2FNUZ: lambda: Float8E5M2FNUZType.get(), + # Ommitted: STRING, +} + +ELEM_TYPE_SPLAT_TENSOR_PROTO_CB = { + onnx.TensorProto.DataType.FLOAT: lambda tp, shape: DenseElementsAttr.get_splat( + RankedTensorType.get(shape, F32Type.get()), FloatAttr.get_f32(tp.float_data[0]) + ), + # TODO: All the rest from ELEM_TYPE_TO_IR_TYPE_CB +} + +# Mapping of TensorProto.DataType to lambda TensorProto, returning a DenseElementsAttr +# of the builtin tensor type for cases where the tensor data is inlined as typed +# values instead of raw_data. 
+ELEM_TYPE_INLINE_TENSOR_PROTO_CB = { + onnx.TensorProto.DataType.FLOAT: lambda tp: DenseElementsAttr.get( + np.asarray(tp.float_data, dtype=np.float32).reshape(tp.dims), signless=False + ), + onnx.TensorProto.DataType.INT32: lambda tp: DenseElementsAttr.get( + np.asarray(tp.int32_data, dtype=np.int32).reshape(tp.dims), signless=False + ), + onnx.TensorProto.DataType.INT64: lambda tp: DenseElementsAttr.get( + np.asarray(tp.int64_data, dtype=np.int64).reshape(tp.dims), signless=False + ), + onnx.TensorProto.DataType.DOUBLE: lambda tp: DenseElementsAttr.get( + np.asarray(tp.double_data, dtype=np.float64).reshape(tp.dims) + ), + onnx.TensorProto.DataType.UINT32: lambda tp: DenseElementsAttr.get( + # Special case. See proto + np.asarray(tp.uint64_data, dtype=np.uint32).reshape(tp.dims), + signless=False, + ), + onnx.TensorProto.DataType.UINT64: lambda tp: DenseElementsAttr.get( + np.asarray(tp.uint64_data, dtype=np.uint64).reshape(tp.dims), signless=False + ) + # Intentionally unsupported: STRING +} + +ELEM_TYPE_TO_NUMPY_DTYPE = { + onnx.TensorProto.DataType.FLOAT: np.float32, + onnx.TensorProto.DataType.UINT8: np.uint8, + onnx.TensorProto.DataType.INT8: np.int8, + onnx.TensorProto.DataType.UINT16: np.uint16, + onnx.TensorProto.DataType.INT16: np.int16, + onnx.TensorProto.DataType.INT32: np.int32, + onnx.TensorProto.DataType.INT64: np.int64, + onnx.TensorProto.DataType.BOOL: np.bool_, + onnx.TensorProto.DataType.FLOAT16: np.float16, + onnx.TensorProto.DataType.DOUBLE: np.float64, + onnx.TensorProto.DataType.UINT32: np.uint32, + onnx.TensorProto.DataType.UINT64: np.uint64, + onnx.TensorProto.DataType.COMPLEX64: np.complex64, + onnx.TensorProto.DataType.COMPLEX128: np.complex128, + # onnx.TensorProto.DataType.BFLOAT16: + # onnx.TensorProto.DataType.FLOAT8E4M3FN: + # onnx.TensorProto.DataType.FLOAT8E4M3FNUZ: + # onnx.TensorProto.DataType.FLOAT8E5M2: + # onnx.TensorProto.DataType.FLOAT8E5M2FNUZ: + # Ommitted: STRING, +} + +# Mapping of AttributeType code to one of: +# None: Ignore attribute and do not output to MLIR +# False: Error if an attribute of this type is present +# lambda a:AttributeProto, cc: ContextCache that returns an MLIR Attribute +ATTRIBUTE_TYPE_HANDLERS = { + onnx.AttributeProto.AttributeType.UNDEFINED: False, + onnx.AttributeProto.AttributeType.FLOAT: lambda a, cc: FloatAttr.get( + F32Type.get(), a.f + ), + onnx.AttributeProto.AttributeType.INT: lambda a, cc: IntegerAttr.get( + IntegerType.get_signed(64), a.i + ), + onnx.AttributeProto.AttributeType.STRING: lambda a, cc: StringAttr.get(a.s), + onnx.AttributeProto.AttributeType.TENSOR: lambda a, cc: cc.tensor_proto_to_attr( + a.t + ), + onnx.AttributeProto.AttributeType.GRAPH: False, + onnx.AttributeProto.AttributeType.SPARSE_TENSOR: False, + onnx.AttributeProto.AttributeType.TYPE_PROTO: False, + onnx.AttributeProto.AttributeType.FLOATS: lambda a, cc: ArrayAttr.get( + [FloatAttr.get(F32Type.get(), f) for f in a.floats] + ), + onnx.AttributeProto.AttributeType.INTS: lambda a, cc: ArrayAttr.get( + [IntegerAttr.get(IntegerType.get_signed(64), i) for i in a.ints] + ), + onnx.AttributeProto.AttributeType.STRINGS: lambda a, cc: ArrayAttr.get( + [StringAttr.get(s) for s in a.strings] + ), + onnx.AttributeProto.AttributeType.TENSORS: lambda a, cc: ArrayAttr.get( + [cc.tensor_proto_to_attr(t) for t in a.tensors] + ), + onnx.AttributeProto.AttributeType.GRAPHS: False, + onnx.AttributeProto.AttributeType.SPARSE_TENSORS: False, + onnx.AttributeProto.AttributeType.TYPE_PROTOS: False, +} + + +def _get_attr(node: onnx.NodeProto, attr_name: str) 
-> onnx.AttributeProto:
+    for attr in node.attribute:
+        if attr.name == attr_name:
+            return attr
+    else:
+        raise OnnxImportError(f"Required attribute {attr_name} not found in {node}")
diff --git a/setup.py b/setup.py
index 46217d307..a4b42309d 100644
--- a/setup.py
+++ b/setup.py
@@ -47,8 +47,6 @@ PACKAGE_VERSION = os.environ.get("TORCH_MLIR_PYTHON_PACKAGE_VERSION") or "0.0.1"
 # If true, enable LTC build by default
 TORCH_MLIR_ENABLE_LTC_DEFAULT = True
 TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS = int(os.environ.get('TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS', False))
-if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS:
-    import torch
 
 # Build phase discovery is unreliable. Just tell it what phases to run.
 class CustomBuild(_build):
@@ -91,7 +89,7 @@ class CMakeBuild(build_py):
             f"-DCMAKE_C_VISIBILITY_PRESET=hidden",
             f"-DCMAKE_CXX_VISIBILITY_PRESET=hidden",
             f"-DTORCH_MLIR_ENABLE_LTC={'ON' if enable_ltc else 'OFF'}",
-            f"-DTORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS={'ON' if TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else 'OFF'}",
+            f"-DTORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS={'OFF' if TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else 'ON'}",
         ]
 
         os.makedirs(cmake_build_dir, exist_ok=True)
@@ -145,8 +143,31 @@ with open("README.md", "r", encoding="utf-8") as fh:
     long_description = fh.read()
 
+
+# Install requirements and extension modules depend on whether we are building
+# the PyTorch extensions.
+INSTALL_REQUIRES = [
+    "numpy",
+    "packaging",
+]
+EXT_MODULES = [
+    CMakeExtension("torch_mlir._mlir_libs._torchMlir"),
+]
+NAME = "torch-mlir-core"
+
+# If building PyTorch extensions, customize.
+if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS:
+    import torch
+    NAME = "torch-mlir"
+    INSTALL_REQUIRES.extend([
+        f"torch=={torch.__version__}".split("+", 1)[0],
+    ])
+    EXT_MODULES.extend([
+        CMakeExtension("torch_mlir._mlir_libs._jit_ir_importer"),
+    ])
+
+
 setup(
-    name="torch-mlir" if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else "torch-mlir-core",
+    name=NAME,
     version=f"{PACKAGE_VERSION}",
     author="Sean Silva",
     author_email="silvasean@google.com",
@@ -159,10 +180,12 @@ setup(
         "built_ext": NoopBuildExtension,
         "build_py": CMakeBuild,
     },
-    ext_modules=[
-        CMakeExtension("torch_mlir._mlir_libs._jit_ir_importer"),
-    ] if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else [CMakeExtension("torch_mlir._mlir_libs._torchMlir")],
-    install_requires=["numpy", "packaging"] + (
-        [f"torch=={torch.__version__}".split("+", 1)[0], ] if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else []),
+    ext_modules=EXT_MODULES,
+    install_requires=INSTALL_REQUIRES,
+    extras_require={
+        "onnx": [
+            "onnx>=1.15.0",
+        ],
+    },
     zip_safe=False,
 )
diff --git a/test-requirements.txt b/test-requirements.txt
index 523772dde..315e02130 100644
--- a/test-requirements.txt
+++ b/test-requirements.txt
@@ -1,3 +1,4 @@
 pillow
 dill
 multiprocess
+onnx==1.15.0
\ No newline at end of file
diff --git a/test/python/lit.local.cfg b/test/python/lit.local.cfg
new file mode 100644
index 000000000..4cfe04325
--- /dev/null
+++ b/test/python/lit.local.cfg
@@ -0,0 +1,2 @@
+if not config.enable_bindings_python:
+    config.unsupported = True
diff --git a/test/python/onnx_importer/.gitignore b/test/python/onnx_importer/.gitignore
new file mode 100644
index 000000000..ea1472ec1
--- /dev/null
+++ b/test/python/onnx_importer/.gitignore
@@ -0,0 +1 @@
+output/
diff --git a/test/python/onnx_importer/_torch_mlir_config.py b/test/python/onnx_importer/_torch_mlir_config.py
new file mode 100644
index 000000000..f597b63b4
--- /dev/null
+++ b/test/python/onnx_importer/_torch_mlir_config.py
@@ -0,0 +1,19 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +# RUN: %PYTHON %s + +"""This file exists so that the tests can find/configure torch_mlir. + +It allows the test file to be standalone and used verbatim in other +projects (i.e. by just providing this file on the side). +""" + +from torch_mlir import ir +from torch_mlir.extras import onnx_importer + +def configure_context(context): + from torch_mlir.dialects import torch as torch_d + torch_d.register_dialect(context) diff --git a/test/python/onnx_importer/import_smoke_test.py b/test/python/onnx_importer/import_smoke_test.py new file mode 100644 index 000000000..39a0b3098 --- /dev/null +++ b/test/python/onnx_importer/import_smoke_test.py @@ -0,0 +1,374 @@ +# Based on code Copyright (c) Advanced Micro Devices, Inc. +# +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +# RUN: %PYTHON %s --output %t + +from glob import glob +from pathlib import Path + +import logging +import sys +import unittest + +import onnx + +from _torch_mlir_config import ( + configure_context, + ir, + onnx_importer, +) + +# Accept the output path on the command line or default to a sibling +# to this file. We have to pop this off explicitly or else unittest +# won't understand. +if len(sys.argv) > 1 and sys.argv[1] == "--output": + OUTPUT_PATH = Path(sys.argv[2]) + del sys.argv[1:3] +else: + OUTPUT_PATH = Path(__file__).resolve().parent / "output" + + +# TODO: Add some verification and overrides. For now, just use the +# onnx package install for onnx test files, since they were nice +# enough to include the test suite in the deployable. 
+import onnx.backend.test.data + +ONNX_TEST_DATA_DIR = Path(onnx.backend.test.__file__).resolve().parent / "data" +print(f"ONNX Test Data Dir: {ONNX_TEST_DATA_DIR}") +ONNX_REL_PATHS = glob(f"**/*.onnx", root_dir=ONNX_TEST_DATA_DIR, recursive=True) + +OUTPUT_PATH.mkdir(parents=True, exist_ok=True) + +TEST_CAST_XFAILS = [ + "light_light_bvlc_alexnet", + "light_light_inception_v1", + "light_light_squeezenet", + "light_light_vgg19", + "node_test_affine_grid_2d_align_corners_expanded_model", + "node_test_affine_grid_2d_expanded_model", + "node_test_affine_grid_3d_align_corners_expanded_model", + "node_test_affine_grid_3d_expanded_model", + "node_test_ai_onnx_ml_label_encoder_string_int_model", + "node_test_ai_onnx_ml_label_encoder_string_int_no_default_model", + "node_test_ai_onnx_ml_label_encoder_tensor_mapping_model", + "node_test_ai_onnx_ml_label_encoder_tensor_value_only_mapping_model", + "node_test_cast_FLOAT16_to_FLOAT8E4M3FNUZ_model", + "node_test_cast_FLOAT16_to_FLOAT8E4M3FN_model", + "node_test_cast_FLOAT16_to_FLOAT8E5M2FNUZ_model", + "node_test_cast_FLOAT16_to_FLOAT8E5M2_model", + "node_test_cast_FLOAT8E4M3FNUZ_to_FLOAT16_model", + "node_test_cast_FLOAT8E4M3FNUZ_to_FLOAT_model", + "node_test_cast_FLOAT8E4M3FN_to_FLOAT16_model", + "node_test_cast_FLOAT8E4M3FN_to_FLOAT_model", + "node_test_cast_FLOAT8E5M2FNUZ_to_FLOAT16_model", + "node_test_cast_FLOAT8E5M2FNUZ_to_FLOAT_model", + "node_test_cast_FLOAT8E5M2_to_FLOAT16_model", + "node_test_cast_FLOAT8E5M2_to_FLOAT_model", + "node_test_cast_FLOAT_to_FLOAT8E4M3FNUZ_model", + "node_test_cast_FLOAT_to_FLOAT8E4M3FN_model", + "node_test_cast_FLOAT_to_FLOAT8E5M2FNUZ_model", + "node_test_cast_FLOAT_to_FLOAT8E5M2_model", + "node_test_cast_FLOAT_to_STRING_model", + "node_test_cast_STRING_to_FLOAT_model", + "node_test_cast_no_saturate_FLOAT16_to_FLOAT8E4M3FNUZ_model", + "node_test_cast_no_saturate_FLOAT16_to_FLOAT8E4M3FN_model", + "node_test_cast_no_saturate_FLOAT16_to_FLOAT8E5M2FNUZ_model", + "node_test_cast_no_saturate_FLOAT16_to_FLOAT8E5M2_model", + "node_test_cast_no_saturate_FLOAT_to_FLOAT8E4M3FNUZ_model", + "node_test_cast_no_saturate_FLOAT_to_FLOAT8E4M3FN_model", + "node_test_cast_no_saturate_FLOAT_to_FLOAT8E5M2FNUZ_model", + "node_test_cast_no_saturate_FLOAT_to_FLOAT8E5M2_model", + "node_test_castlike_FLOAT8E4M3FNUZ_to_FLOAT_expanded_model", + "node_test_castlike_FLOAT8E4M3FNUZ_to_FLOAT_model", + "node_test_castlike_FLOAT8E4M3FN_to_FLOAT_expanded_model", + "node_test_castlike_FLOAT8E4M3FN_to_FLOAT_model", + "node_test_castlike_FLOAT8E5M2FNUZ_to_FLOAT_expanded_model", + "node_test_castlike_FLOAT8E5M2FNUZ_to_FLOAT_model", + "node_test_castlike_FLOAT8E5M2_to_FLOAT_expanded_model", + "node_test_castlike_FLOAT8E5M2_to_FLOAT_model", + "node_test_castlike_FLOAT_to_FLOAT8E4M3FNUZ_expanded_model", + "node_test_castlike_FLOAT_to_FLOAT8E4M3FNUZ_model", + "node_test_castlike_FLOAT_to_FLOAT8E4M3FN_expanded_model", + "node_test_castlike_FLOAT_to_FLOAT8E4M3FN_model", + "node_test_castlike_FLOAT_to_FLOAT8E5M2FNUZ_expanded_model", + "node_test_castlike_FLOAT_to_FLOAT8E5M2FNUZ_model", + "node_test_castlike_FLOAT_to_FLOAT8E5M2_expanded_model", + "node_test_castlike_FLOAT_to_FLOAT8E5M2_model", + "node_test_castlike_FLOAT_to_STRING_expanded_model", + "node_test_castlike_FLOAT_to_STRING_model", + "node_test_castlike_STRING_to_FLOAT_expanded_model", + "node_test_castlike_STRING_to_FLOAT_model", + "node_test_center_crop_pad_crop_axes_chw_expanded_model", + "node_test_center_crop_pad_crop_axes_hwc_expanded_model", + 
"node_test_center_crop_pad_crop_negative_axes_hwc_expanded_model", + "node_test_clip_default_inbounds_model", + "node_test_clip_default_int8_inbounds_model", + "node_test_clip_default_int8_max_model", + "node_test_clip_default_max_model", + "node_test_constantofshape_float_ones_model", + "node_test_constantofshape_int_shape_zero_model", + "node_test_constantofshape_int_zeros_model", + "node_test_dequantizelinear_e4m3fn_model", + "node_test_dequantizelinear_e4m3fn_zero_point_model", + "node_test_dequantizelinear_e5m2_model", + "node_test_dft_axis_model", + "node_test_dft_inverse_model", + "node_test_dft_model", + "node_test_equal_string_broadcast_model", + "node_test_equal_string_model", + "node_test_gru_defaults_model", + "node_test_gru_seq_length_model", + "node_test_gru_with_initial_bias_model", + "node_test_identity_opt_model", + "node_test_identity_sequence_model", + "node_test_if_model", + "node_test_if_opt_model", + "node_test_if_seq_model", + "node_test_layer_normalization_2d_axis0_expanded_model", + "node_test_layer_normalization_2d_axis0_expanded_ver18_model", + "node_test_layer_normalization_2d_axis1_expanded_model", + "node_test_layer_normalization_2d_axis1_expanded_ver18_model", + "node_test_layer_normalization_2d_axis_negative_1_expanded_model", + "node_test_layer_normalization_2d_axis_negative_1_expanded_ver18_model", + "node_test_layer_normalization_2d_axis_negative_2_expanded_model", + "node_test_layer_normalization_2d_axis_negative_2_expanded_ver18_model", + "node_test_layer_normalization_3d_axis0_epsilon_expanded_model", + "node_test_layer_normalization_3d_axis0_epsilon_expanded_ver18_model", + "node_test_layer_normalization_3d_axis1_epsilon_expanded_model", + "node_test_layer_normalization_3d_axis1_epsilon_expanded_ver18_model", + "node_test_layer_normalization_3d_axis2_epsilon_expanded_model", + "node_test_layer_normalization_3d_axis2_epsilon_expanded_ver18_model", + "node_test_layer_normalization_3d_axis_negative_1_epsilon_expanded_model", + "node_test_layer_normalization_3d_axis_negative_1_epsilon_expanded_ver18_model", + "node_test_layer_normalization_3d_axis_negative_2_epsilon_expanded_model", + "node_test_layer_normalization_3d_axis_negative_2_epsilon_expanded_ver18_model", + "node_test_layer_normalization_3d_axis_negative_3_epsilon_expanded_model", + "node_test_layer_normalization_3d_axis_negative_3_epsilon_expanded_ver18_model", + "node_test_layer_normalization_4d_axis0_expanded_model", + "node_test_layer_normalization_4d_axis0_expanded_ver18_model", + "node_test_layer_normalization_4d_axis1_expanded_model", + "node_test_layer_normalization_4d_axis1_expanded_ver18_model", + "node_test_layer_normalization_4d_axis2_expanded_model", + "node_test_layer_normalization_4d_axis2_expanded_ver18_model", + "node_test_layer_normalization_4d_axis3_expanded_model", + "node_test_layer_normalization_4d_axis3_expanded_ver18_model", + "node_test_layer_normalization_4d_axis_negative_1_expanded_model", + "node_test_layer_normalization_4d_axis_negative_1_expanded_ver18_model", + "node_test_layer_normalization_4d_axis_negative_2_expanded_model", + "node_test_layer_normalization_4d_axis_negative_2_expanded_ver18_model", + "node_test_layer_normalization_4d_axis_negative_3_expanded_model", + "node_test_layer_normalization_4d_axis_negative_3_expanded_ver18_model", + "node_test_layer_normalization_4d_axis_negative_4_expanded_model", + "node_test_layer_normalization_4d_axis_negative_4_expanded_ver18_model", + "node_test_layer_normalization_default_axis_expanded_model", + 
"node_test_layer_normalization_default_axis_expanded_ver18_model", + "node_test_loop11_model", + "node_test_loop13_seq_model", + "node_test_loop16_seq_none_model", + "node_test_lstm_defaults_model", + "node_test_lstm_with_initial_bias_model", + "node_test_lstm_with_peepholes_model", + "node_test_optional_get_element_optional_sequence_model", + "node_test_optional_get_element_optional_tensor_model", + "node_test_optional_get_element_sequence_model", + "node_test_optional_has_element_empty_no_input_name_optional_input_model", + "node_test_optional_has_element_empty_no_input_name_tensor_input_model", + "node_test_optional_has_element_empty_optional_input_model", + "node_test_optional_has_element_optional_input_model", + "node_test_optional_has_element_tensor_input_model", + "node_test_quantizelinear_e4m3fn_model", + "node_test_quantizelinear_e5m2_model", + "node_test_range_float_type_positive_delta_expanded_model", + "node_test_range_int32_type_negative_delta_expanded_model", + "node_test_regex_full_match_basic_model", + "node_test_regex_full_match_email_domain_model", + "node_test_regex_full_match_empty_model", + "node_test_resize_downsample_scales_cubic_A_n0p5_exclude_outside_model", + "node_test_resize_downsample_scales_cubic_align_corners_model", + "node_test_resize_downsample_scales_cubic_antialias_model", + "node_test_resize_downsample_scales_cubic_model", + "node_test_resize_downsample_scales_linear_align_corners_model", + "node_test_resize_downsample_scales_linear_antialias_model", + "node_test_resize_downsample_scales_linear_half_pixel_symmetric_model", + "node_test_resize_downsample_scales_linear_model", + "node_test_resize_downsample_scales_nearest_model", + "node_test_resize_downsample_sizes_cubic_antialias_model", + "node_test_resize_downsample_sizes_cubic_model", + "node_test_resize_downsample_sizes_linear_antialias_model", + "node_test_resize_downsample_sizes_linear_pytorch_half_pixel_model", + "node_test_resize_downsample_sizes_nearest_model", + "node_test_resize_downsample_sizes_nearest_not_larger_model", + "node_test_resize_downsample_sizes_nearest_not_smaller_model", + "node_test_resize_tf_crop_and_resize_axes_2_3_model", + "node_test_resize_tf_crop_and_resize_axes_3_2_model", + "node_test_resize_tf_crop_and_resize_model", + "node_test_resize_upsample_scales_cubic_A_n0p5_exclude_outside_model", + "node_test_resize_upsample_scales_cubic_align_corners_model", + "node_test_resize_upsample_scales_cubic_asymmetric_model", + "node_test_resize_upsample_scales_cubic_model", + "node_test_resize_upsample_scales_linear_align_corners_model", + "node_test_resize_upsample_scales_linear_half_pixel_symmetric_model", + "node_test_resize_upsample_scales_linear_model", + "node_test_resize_upsample_scales_nearest_axes_2_3_model", + "node_test_resize_upsample_scales_nearest_axes_3_2_model", + "node_test_resize_upsample_scales_nearest_model", + "node_test_resize_upsample_sizes_cubic_model", + "node_test_resize_upsample_sizes_nearest_axes_2_3_model", + "node_test_resize_upsample_sizes_nearest_axes_3_2_model", + "node_test_resize_upsample_sizes_nearest_ceil_half_pixel_model", + "node_test_resize_upsample_sizes_nearest_floor_align_corners_model", + "node_test_resize_upsample_sizes_nearest_model", + "node_test_resize_upsample_sizes_nearest_not_larger_model", + "node_test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric_model", + "node_test_rnn_seq_length_model", + "node_test_scan9_sum_model", + "node_test_scan_sum_model", + "node_test_sequence_insert_at_back_model", + 
"node_test_sequence_insert_at_front_model", + "node_test_sequence_map_add_1_sequence_1_tensor_expanded_model", + "node_test_sequence_map_add_1_sequence_1_tensor_model", + "node_test_sequence_map_add_2_sequences_expanded_model", + "node_test_sequence_map_add_2_sequences_model", + "node_test_sequence_map_extract_shapes_expanded_model", + "node_test_sequence_map_extract_shapes_model", + "node_test_sequence_map_identity_1_sequence_1_tensor_expanded_model", + "node_test_sequence_map_identity_1_sequence_1_tensor_model", + "node_test_sequence_map_identity_1_sequence_expanded_model", + "node_test_sequence_map_identity_1_sequence_model", + "node_test_sequence_map_identity_2_sequences_expanded_model", + "node_test_sequence_map_identity_2_sequences_model", + "node_test_simple_rnn_defaults_model", + "node_test_simple_rnn_with_initial_bias_model", + "node_test_split_to_sequence_1_model", + "node_test_split_to_sequence_2_model", + "node_test_split_to_sequence_nokeepdims_model", + "node_test_stft_model", + "node_test_string_concat_broadcasting_model", + "node_test_string_concat_empty_string_model", + "node_test_string_concat_model", + "node_test_string_concat_utf8_model", + "node_test_string_concat_zero_dimensional_model", + "node_test_string_split_basic_model", + "node_test_string_split_consecutive_delimiters_model", + "node_test_string_split_empty_string_delimiter_model", + "node_test_string_split_empty_tensor_model", + "node_test_string_split_maxsplit_model", + "node_test_string_split_no_delimiter_model", + "node_test_strnormalizer_export_monday_casesensintive_lower_model", + "node_test_strnormalizer_export_monday_casesensintive_nochangecase_model", + "node_test_strnormalizer_export_monday_casesensintive_upper_model", + "node_test_strnormalizer_export_monday_empty_output_model", + "node_test_strnormalizer_export_monday_insensintive_upper_twodim_model", + "node_test_strnormalizer_nostopwords_nochangecase_model", + "simple_test_sequence_model1_model", + "simple_test_sequence_model2_model", + "simple_test_sequence_model3_model", + "simple_test_sequence_model4_model", + "simple_test_sequence_model5_model", + "simple_test_sequence_model6_model", + "simple_test_sequence_model7_model", + "simple_test_sequence_model8_model", + "simple_test_strnorm_model_monday_casesensintive_lower_model", + "simple_test_strnorm_model_monday_casesensintive_nochangecase_model", + "simple_test_strnorm_model_monday_casesensintive_upper_model", + "simple_test_strnorm_model_monday_empty_output_model", + "simple_test_strnorm_model_monday_insensintive_upper_twodim_model", + "simple_test_strnorm_model_nostopwords_nochangecase_model", +] + + +class ImportSmokeTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.unexpected_failure_count = 0 + ImportSmokeTest.actual_failures = [] + + @classmethod + def tearDownClass(cls): + if cls.unexpected_failure_count: + # Print a helpful message with copy-paste XFAIL def. + failure_report_path = OUTPUT_PATH / "import_smoke_test_report.txt" + print( + "Unexpected failures. Writing copy/paste report to:", + failure_report_path, + ) + with open(failure_report_path, "wt") as f: + lines = [f' "{s}",' for s in ImportSmokeTest.actual_failures] + print( + f"Unexpected failures in the following. 
Copy/paste to update `TEST_CAST_XFAILS`:", + file=f, + ) + print(f"TEST_CAST_XFAILS = [", file=f) + [print(l, file=f) for l in lines] + print(f"]", file=f) + + ImportSmokeTest.actual_failures.clear() + + def load_onnx_model(self, file_path: Path) -> onnx.ModelProto: + raw_model = onnx.load(file_path) + try: + inferred_model = onnx.shape_inference.infer_shapes(raw_model) + except onnx.onnx_cpp2py_export.shape_inference.InferenceError as e: + print("WARNING: Shape inference failure (skipping test):", e) + self.skipTest(reason="shape inference failure") + + # inferred_model = raw_model + return inferred_model + + def run_import_test(self, norm_name: str, rel_path: str): + context = ir.Context() + configure_context(context) + + model_info = onnx_importer.ModelInfo( + self.load_onnx_model(ONNX_TEST_DATA_DIR / rel_path), + ) + m = model_info.create_module(context=context) + try: + imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m) + imp.import_all() + m.verify() + finally: + # Use a ".txt" extension to avoid lit test discovery. + with open(OUTPUT_PATH / f"{norm_name}.mlir", "wt") as f: + print(m.get_asm(), file=f) + + def testExists(self): + # We expect a lot of test cases. Die if not the case (i.e. if paths change + # or something). + self.assertGreater(len(ONNX_REL_PATHS), 10) + + +# Generate test methods for each onnx file. +for _rel_path in ONNX_REL_PATHS: + + def attach_test(rel_path): + norm_name = rel_path.removesuffix(".onnx").replace("/", "_") + + def test_method(self: ImportSmokeTest): + try: + self.run_import_test(norm_name, rel_path) + except onnx_importer.OnnxImportError as e: + # All legitimate failures should be caught and reported + # as an OnnxImportError. + ImportSmokeTest.actual_failures.append(norm_name) + if norm_name not in TEST_CAST_XFAILS: + ImportSmokeTest.unexpected_failure_count += 1 + raise e + + test_method.__name__ = f"test_{norm_name}" + + if norm_name in TEST_CAST_XFAILS: + test_method = unittest.expectedFailure(test_method) + + setattr(ImportSmokeTest, test_method.__name__, test_method) + + attach_test(_rel_path) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/test/python/onnx_importer/lit.local.cfg b/test/python/onnx_importer/lit.local.cfg new file mode 100644 index 000000000..8e0adb7c1 --- /dev/null +++ b/test/python/onnx_importer/lit.local.cfg @@ -0,0 +1,5 @@ +try: + import onnx +except ModuleNotFoundError: + print("Skipping onnx tests.. no onnx") + config.unsupported = True