[onnx] Add torch-mlir-import-onnx native port as an optional tool/library. (#2694)

As noted in the plan when this work started, we need to produce an ORT
EP plugin for a downstream project, and this necessitates a C-based
ONNX importer (as opposed to the existing Python one). Because this
comes with dependencies that we do not want to impose on other
projects, it is optional in torch-mlir. It is also factored so that it
can be used as standalone sources in downstreams that need it. Since it
only depends on public C APIs on the MLIR side, build coupling is much
looser: no C++ dependency on the compiler is needed, and the importer
is trivial to load dynamically.

Our original plan was just to maintain this fork off to the side in our
ORT plugin, but once work started, it seemed better to write it clean
and contribute it upstream for anyone to use. We expect that for non-ORT
use, the Python importer will have better ergonomics for most folks.

I will follow up with a test suite refactor so that we can drive either
the Python or the C importer.

This is a relatively mechanical port from Python to C, borrowing some
scaffolding from the old JitIR importer. It does attempt to lay some
groundwork for external data, which will need to be implemented on the
Python side as well.
Stella Laurenzo 2023-12-27 12:13:34 -08:00 committed by GitHub
parent 2d796b7502
commit 1b40b6384e
7 changed files with 1409 additions and 0 deletions

@ -51,6 +51,8 @@ option(TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS "Enables PyTorch native extension fe
cmake_dependent_option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS OFF)
cmake_dependent_option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS OFF)
option(TORCH_MLIR_ENABLE_ONNX_C_IMPORTER "Enables the ONNX C importer" OFF)
#-------------------------------------------------------------------------------
# Configure out-of-tree vs in-tree build
#-------------------------------------------------------------------------------
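With the option wired up, the importer is strictly opt-in. An illustrative
configure line (the remaining flags depend on your existing setup):

    cmake -GNinja -DTORCH_MLIR_ENABLE_ONNX_C_IMPORTER=ON <your usual flags>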

@ -1,5 +1,9 @@
include(AddMLIRPython)
if(TORCH_MLIR_ENABLE_ONNX_C_IMPORTER)
add_subdirectory(onnx_c_importer)
endif()
################################################################################
# PyTorch
# Configure PyTorch if we have any features enabled which require it.

@ -0,0 +1,44 @@
message(STATUS "Enabling onnx_c_importer...")

include(FetchContent)

find_package(Protobuf)
if(NOT Protobuf_FOUND)
  message(FATAL_ERROR
    "In order to build C ONNX support, the Protobuf package must be installed "
    "on the system. Without this, ONNX will attempt to build it in the project "
    "and the dependent ABSEIL build system is incompatible. "
    "On Ubuntu, install with: "
    "apt install libprotobuf-dev protobuf-compiler\n\n"
    "(or this entire component can be disabled with "
    "-DTORCH_MLIR_ENABLE_ONNX_C_IMPORTER=OFF)")
endif()

option(ONNX_DISABLE_EXCEPTIONS "For compatibility with LLVM build" ON)

FetchContent_Declare(
  onnx
  EXCLUDE_FROM_ALL
  GIT_REPOSITORY https://github.com/onnx/onnx.git
  GIT_TAG v1.15.0
  GIT_SHALLOW ON
  GIT_PROGRESS ON
)
FetchContent_MakeAvailable(onnx)

add_llvm_executable(
  torch-mlir-import-onnx
  PARTIAL_SOURCES_INTENDED
  import-onnx-main.cpp
  OnnxImporter.h
  OnnxImporter.cpp
)

target_link_libraries(
  torch-mlir-import-onnx
  LLVMSupport
  MLIRCAPIIR
  TorchMLIRCAPI
  onnx
)

File diff suppressed because it is too large (OnnxImporter.cpp).

@ -0,0 +1,240 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// Stand-alone ONNX -> MLIR importer.
// This library only depends on ONNX (and transitively protobuf, of course)
// and the MLIR C API. It does this to minimize its dependency surface area
// and make it possible to integrate as source code into other systems while
// retaining this implementation as the source of truth.
//
// It uses a hybrid of LLVM and Google C++ coding style, preferring the latter
// for class members/accessors because canonical protobuf coding presumes
// this kind of style.
#include "mlir-c/IR.h"
#include "onnx/onnx_pb.h"

#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <unordered_map>
#include <vector>
namespace torch_mlir_onnx {
struct Config;
class GraphInfo;
class ModelInfo;
struct Config {
// Ancient ONNX exporters would often add a model input for anything that
// might be mutable, providing an initializer for it as well. More modern
// tools realized this is a really bad idea for a lot of reasons. We choose
// to assume more recent norms, even when encountering older models.
// Setting this to false probably won't do what you want, but it should
// produce interesting errors to waste your time deciphering. We mainly use
// it as a way to document in the code that we are making an assumption.
bool elide_initialized_inputs = true;
};
/// A light-weight status. It only encapsulates success/failure.
/// Full error information will be set on the ModelInfo.
class Status {
public:
static Status success(bool isSuccess = true) { return Status(isSuccess); }
static Status failure(bool isFailure = true) { return Status(!isFailure); }
bool is_success() { return is_success_; }
private:
Status(bool is_success) : is_success_(is_success) {}
bool is_success_;
};
static inline Status success() { return Status::success(); }
static inline Status failure() { return Status::failure(); }
static inline bool succeeded(Status status) { return status.is_success(); }
static inline bool failed(Status status) { return !status.is_success(); }
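// Illustrative only (not part of this header): the free-function wrappers
// are meant to read like MLIR's LogicalResult idiom at call sites, where
// `SomeStep` stands in for any failable operation:
//
//   Status Import(ModelInfo &model_info) {
//     if (failed(SomeStep()))
//       return model_info.SetError("some step failed");
//     return success();
//   }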
// Accounting for a GraphProto.
class GraphInfo {
public:
GraphInfo(ModelInfo &model_info, const onnx::GraphProto &graph_proto)
: model_info_(model_info), graph_proto_(graph_proto) {}
ModelInfo &model_info() { return model_info_; }
const onnx::GraphProto &graph_proto() { return graph_proto_; }
/// Post-construction, failable initialization.
Status Initialize();
/// Finds a TypeProto for the given value name. If nullptr is returned,
/// an error will have been set.
const onnx::TypeProto *FindTypeProtoForName(std::string_view name);
/// Attempts to access the raw or external data of the TensorProto. If the
/// data is located in one of those positions, returns a typed pointer to it
/// and stores the number of elements in `out_size`. Otherwise, nullptr is
/// returned (and no error is set).
template <typename ElementType>
const ElementType *GetOptionalRawData(const onnx::TensorProto &tp,
size_t &out_size) {
if (tp.has_raw_data()) {
out_size = tp.raw_data().size() / sizeof(ElementType);
return reinterpret_cast<const ElementType *>(tp.raw_data().data());
}
return nullptr;
}
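// Illustrative only (hypothetical caller): fetch fp32 raw data if present.
//
//   size_t num_elements;
//   if (const float *data =
//           GetOptionalRawData<float>(tensor_proto, num_elements)) {
//     // ... consume num_elements floats ...
//   }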
std::vector<const onnx::ValueInfoProto *> &inputs() { return inputs_; }
std::unordered_map<std::string_view, const onnx::ValueInfoProto &> &
input_map() {
return input_map_;
}
std::vector<const onnx::ValueInfoProto *> &outputs() { return outputs_; }
std::unordered_map<std::string_view, const onnx::ValueInfoProto &> &
output_map() {
return output_map_;
}
std::unordered_map<std::string_view, const onnx::TensorProto &> &
initializer_map() {
return initializer_map_;
}
private:
ModelInfo &model_info_;
const onnx::GraphProto &graph_proto_;
std::unordered_map<std::string_view, const onnx::TensorProto &>
initializer_map_;
std::unordered_map<std::string_view, const onnx::ValueInfoProto &>
value_info_map_;
std::vector<const onnx::ValueInfoProto *> declared_inputs_;
std::vector<const onnx::ValueInfoProto *> inputs_;
std::vector<const onnx::ValueInfoProto *> outputs_;
std::unordered_map<std::string_view, const onnx::ValueInfoProto &> input_map_;
std::unordered_map<std::string_view, const onnx::ValueInfoProto &>
output_map_;
};
/// Top-level accounting and accessors for an ONNX model.
class ModelInfo {
public:
ModelInfo();
Config &config() { return config_; }
onnx::ModelProto &model_proto() { return model_proto_; }
/// Post-construction, failable initialization.
Status Initialize();
GraphInfo &main_graph() { return *main_graph_; }
const std::string &error_message() { return error_message_; }
Status SetError(std::string msg) {
error_message_ = std::move(msg);
return failure();
}
void DebugDumpProto();
private:
Config config_;
onnx::ModelProto model_proto_;
std::unique_ptr<GraphInfo> main_graph_;
std::string error_message_;
};
class ContextCache {
public:
ContextCache(ModelInfo &model_info, MlirContext context)
: model_info_(model_info), context_(context) {}
MlirContext context() { return context_; }
/// Converts the TypeProto to an MlirType, returning a null type and
/// setting an error if not possible.
MlirType ConvertTypeProto(const onnx::TypeProto &tp);
/// Converts the ONNX element type code to an MlirType, returning a null type
/// and setting an error if not possible.
MlirType ConvertTensorElementType(int element_type_code);
/// Converts an ONNX TensorProto to an MlirAttribute, returning a null
/// attribute and setting an error if not possible.
MlirAttribute ConvertTensorProtoToAttr(const onnx::TensorProto &tp);
/// Converts the ONNX TensorProto to a builtin MLIR RankedTensor type.
MlirType ConvertTensorProtoToBuiltinType(const onnx::TensorProto &tp);
/// Converts the ONNX TensorProto to a !torch.vtensor type.
MlirType ConvertTensorProtoToVtensorType(const onnx::TensorProto &tp);
/// Gets a !torch.vtensor type for the given dims and element type.
/// Dynamic dims are represented as -1.
/// If it was not possible to create the type, sets an error and returns
/// the null type.
MlirType GetVtensorType(const std::vector<int64_t> &dims,
MlirType element_type);
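// For example, dims {-1, 3, 224, 224} with an f32 element type yield
// !torch.vtensor<[?,3,224,224],f32>.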
private:
ModelInfo &model_info_;
MlirContext context_;
std::unordered_map<int, MlirType> elem_type_map_;
std::unordered_map<std::string, MlirType> asm_type_map_;
std::vector<int64_t> shared_dims_;
};
/// Imports graph nodes into a function.
class NodeImporter {
public:
NodeImporter(GraphInfo &graph_info, ContextCache &cc,
MlirOperation module_op);
/// Called after construction to define the function in the module. Must be
/// called prior to importing nodes.
Status DefineFunction(std::optional<std::string> name = {});
/// Imports all nodes topologically.
Status ImportAll();
void DebugDumpModule();
private:
void PopulateGraphAttrs(MlirOperation container_op);
Status ImportInitializer(const onnx::TensorProto &initializer);
Status ImportNode(const onnx::NodeProto &node);
MlirAttribute ImportGeneralAttribute(const onnx::AttributeProto &onnx_attr);
// Special-form nodes.
Status ImportGeneralNode(const onnx::NodeProto &node);
Status ImportConstantOfShapeNode(const onnx::NodeProto &node);
/// Looks for an initializer for `name` and attempts to treat it as a 1D
/// shape, filling `shape` if successful. Returns failure and sets an error
/// if not.
Status GetImmediateShapeTensor(const std::string &name,
std::vector<int64_t> &shape);
Status SetError(std::string msg) {
return graph_info_.model_info().SetError(std::move(msg));
}
GraphInfo &graph_info_;
ContextCache &cc_;
MlirContext context_;
MlirOperation module_op_;
MlirOperation func_op_;
MlirBlock body_block_;
MlirLocation default_loc_;
std::unordered_map<std::string_view, MlirValue> nv_map_;
};
} // namespace torch_mlir_onnx

@ -0,0 +1,7 @@
# ONNX C Importer
This project provides a C-based implementation of `onnx_importer.py`, which
remains the canonical source. It is provided as sample code for anyone who
wishes to integrate it into their own system. By design, it depends only on
the ONNX API and the MLIR C API (via the `mlir-c` headers), so it should be
easy to build into any system that already has those dependencies by adding
these sources.
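## Example

A minimal invocation, assuming the tool was built with
`-DTORCH_MLIR_ENABLE_ONNX_C_IMPORTER=ON` (the binary's location depends on
your build tree):

    torch-mlir-import-onnx model.onnx

Passing `-` as the input (the default) reads the serialized ModelProto from
stdin instead.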

@ -0,0 +1,103 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
// This main driver uses LLVM tool-making facilities and the support lib.
// The actual importer libraries, however, only depend on the C API so that
// they can be included in foreign projects more easily.
#include "torch-mlir-c/Registration.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/raw_ostream.h"
#include "OnnxImporter.h"
#include "onnx/onnx_pb.h"
#include <fstream>
#include <iostream>
using namespace llvm;
using namespace torch_mlir_onnx;
struct MlirState {
MlirState() {
context = mlirContextCreateWithThreading(false);
torchMlirRegisterAllDialects(context);
module = mlirModuleCreateEmpty(mlirLocationUnknownGet(context));
}
~MlirState() {
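// Destroy in reverse order of construction: the module must be released
// before the context that owns it.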
mlirModuleDestroy(module);
mlirContextDestroy(context);
}
MlirContext context;
MlirModule module;
};
int main(int argc, char **argv) {
static cl::opt<std::string> inputFilename(
cl::Positional, cl::desc("<input file>"), cl::init("-"));
static cl::opt<std::string> outputFilename("o", cl::desc("Output filename"),
cl::value_desc("filename"),
cl::init("-"));
InitLLVM y(argc, argv);
cl::ParseCommandLineOptions(argc, argv, "torch-mlir-onnx-import-c");
// Open the input as an istream because that is what protobuf likes.
std::unique_ptr<std::ifstream> alloced_input_stream;
std::istream *input_stream = nullptr;
if (inputFilename == "-") {
errs() << "(parsing from stdin)\n";
input_stream = &std::cin;
} else {
alloced_input_stream = std::make_unique<std::ifstream>(
inputFilename, std::ios::in | std::ios::binary);
if (!*alloced_input_stream) {
errs() << "error: could not open input file " << inputFilename << "\n";
return 1;
}
input_stream = alloced_input_stream.get();
}
// Parse the model proto.
ModelInfo model_info;
if (!model_info.model_proto().ParseFromIstream(input_stream)) {
errs() << "Failed to parse ONNX ModelProto from " << inputFilename << "\n";
return 2;
}
if (failed(model_info.Initialize())) {
errs() << "error: Import failure: " << model_info.error_message() << "\n";
model_info.DebugDumpProto();
return 3;
}
model_info.DebugDumpProto();
// Import.
MlirState owned_state;
ContextCache cc(model_info, owned_state.context);
NodeImporter importer(model_info.main_graph(), cc,
mlirModuleGetOperation(owned_state.module));
if (failed(importer.DefineFunction())) {
errs() << "error: Could not define MLIR function for graph: "
<< model_info.error_message() << "\n";
return 4;
}
if (failed(importer.ImportAll())) {
errs() << "error: Could not import one or more graph nodes: "
<< model_info.error_message() << "\n";
return 5;
}
importer.DebugDumpModule();
return 0;
}