//===- OnnxImporter.cpp - ONNX protobuf to MLIR importer -------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
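// This file translates an in-memory ONNX ModelProto into torch-mlir's "onnx
// dialect" form (torch.operator "onnx.*" ops over !torch.vtensor values),
// using only the MLIR C API. ModelInfo and GraphInfo index the proto,
// ContextCache memoizes type and attribute conversions, and NodeImporter
// emits the IR.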
#include "OnnxImporter.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include <cctype>
#include <cstdio>
#include <cstring>
#include <functional>
using namespace torch_mlir_onnx;
namespace {
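// Makes an ONNX name safe to use as an MLIR resource name/identifier: a
// leading non-alphanumeric character gets a '_' prefix, and ':' and '/' are
// replaced with '_'. For example, ":stats/count" becomes "__stats_count".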
std::string SanitizeNameAsIdentifier(std::string_view in) {
std::string out;
if (!in.empty() && !std::isalnum(static_cast<unsigned char>(in.front()))) {
out.append("_");
}
out.append(in);
for (char &c : out) {
if (c == ':' || c == '/')
c = '_';
}
return out;
}
template <typename T>
void AppendDelimitedStrings(std::string &into, const T &container) {
bool first = true;
for (auto &item : container) {
if (first) {
first = false;
} else {
into.append(", ");
}
into.append(item);
}
}
inline MlirStringRef toMlirStringRef(std::string_view s) {
return mlirStringRefCreate(s.data(), s.size());
}
inline MlirStringRef toMlirStringRef(const std::string &s) {
return mlirStringRefCreate(s.data(), s.size());
}
inline MlirStringRef toMlirStringRef(const char *s) {
return mlirStringRefCreate(s, std::strlen(s));
}
inline MlirNamedAttribute toMlirNamedAttribute(const char *s,
MlirAttribute attr) {
MlirContext context = mlirAttributeGetContext(attr);
MlirIdentifier ident = mlirIdentifierGet(context, toMlirStringRef(s));
return mlirNamedAttributeGet(ident, attr);
}
std::string getMlirAsm(MlirType t) {
std::string result;
mlirTypePrint(
t,
+[](MlirStringRef sr, void *userData) {
std::string *s = static_cast<std::string *>(userData);
s->append(sr.data, sr.length);
},
static_cast<void *>(&result));
return result;
}
// C++ helpers to create operations.
void addToMlirOperationState(MlirOperationState &state,
MlirNamedAttribute namedAttr) {
mlirOperationStateAddAttributes(&state, 1, &namedAttr);
}
void addToMlirOperationState(
MlirOperationState &state,
std::vector<std::pair<std::string, MlirAttribute>> &attrs) {
for (auto &p : attrs) {
addToMlirOperationState(state,
toMlirNamedAttribute(p.first.c_str(), p.second));
}
}
void addToMlirOperationState(MlirOperationState &state, MlirRegion region) {
mlirOperationStateAddOwnedRegions(&state, 1, &region);
}
[[maybe_unused]] void addToMlirOperationState(MlirOperationState &state,
MlirValue value) {
mlirOperationStateAddOperands(&state, 1, &value);
}
void addToMlirOperationState(MlirOperationState &state,
const std::vector<MlirValue> &values) {
mlirOperationStateAddOperands(&state, values.size(), values.data());
}
void addToMlirOperationState(MlirOperationState &state, MlirType resultType) {
mlirOperationStateAddResults(&state, 1, &resultType);
}
void addToMlirOperationState(MlirOperationState &state,
const std::vector<MlirType> &resultTypes) {
mlirOperationStateAddResults(&state, resultTypes.size(), resultTypes.data());
}
[[maybe_unused]] void addToMlirOperationState(MlirOperationState &state) {}
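// Variadic dispatcher: peels arguments off one at a time and routes each to
// the matching addToMlirOperationState overload above (attribute, region,
// operand(s), or result type(s)), so callers can mix them freely.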
template <typename T, typename U, typename... Ts>
void addToMlirOperationState(MlirOperationState &state, T &&t, U &&u,
Ts &&...ts) {
addToMlirOperationState(state, std::forward<T>(t));
addToMlirOperationState(state, std::forward<U>(u), std::forward<Ts>(ts)...);
}
template <typename... Ts>
MlirOperation createMlirOperation(std::string name, MlirLocation loc,
Ts &&...ts) {
MlirOperationState state = mlirOperationStateGet(toMlirStringRef(name), loc);
addToMlirOperationState(state, std::forward<Ts>(ts)...);
return mlirOperationCreate(&state);
}
template <typename... Ts>
MlirOperation createMlirOperationAtEnd(MlirBlock block, std::string name,
MlirLocation loc, Ts &&...ts) {
MlirOperation operation =
createMlirOperation(name, loc, std::forward<Ts>(ts)...);
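// If the block has no terminator yet, mlirBlockGetTerminator returns null,
// in which case the C API appends the operation at the end of the block.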
mlirBlockInsertOwnedOperationBefore(block, mlirBlockGetTerminator(block),
operation);
return operation;
}
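// Illustrative use (names hypothetical): create a one-result op with two
// operands and a named attribute, appended at the end of `block`:
//   MlirOperation op = createMlirOperationAtEnd(
//       block, "torch.operator", loc, resultType,
//       std::vector<MlirValue>{lhs, rhs},
//       toMlirNamedAttribute("name", nameAttr));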
} // namespace
// ---------------------------------------------------------------------------//
// ModelInfo
// ---------------------------------------------------------------------------//
ModelInfo::ModelInfo() = default;
void ModelInfo::DebugDumpProto() {
std::string debug_string = model_proto_.DebugString();
fprintf(stderr, "%s\n", debug_string.c_str());
}
Status ModelInfo::Initialize() {
if (!model_proto_.has_graph()) {
return SetError("ONNX ModelProto has no main graph");
}
main_graph_ = std::make_unique<GraphInfo>(*this, model_proto_.graph());
if (failed(main_graph_->Initialize())) {
return failure();
}
return success();
}
// ---------------------------------------------------------------------------//
// GraphInfo
// ---------------------------------------------------------------------------//
Status GraphInfo::Initialize() {
// Initialize look up tables.
for (const onnx::TensorProto &t : graph_proto_.initializer()) {
initializer_map_.emplace(t.name(), t);
}
for (const onnx::ValueInfoProto &v : graph_proto_.value_info()) {
value_info_map_.emplace(v.name(), v);
}
for (const onnx::ValueInfoProto &v : graph_proto_.input()) {
declared_inputs_.emplace_back(&v);
}
for (const onnx::ValueInfoProto &v : graph_proto_.output()) {
outputs_.emplace_back(&v);
}
// Generate the effective input map, which for old models can be a subset of
// the declared inputs.
if (model_info_.config().elide_initialized_inputs) {
// Default: add declared inputs to the input map unless they appear as an
// initializer.
for (const onnx::ValueInfoProto *it : declared_inputs_) {
std::string_view key = it->name();
if (initializer_map_.find(key) != initializer_map_.end()) {
// In initializers. Skip.
continue;
}
inputs_.emplace_back(it);
}
} else {
// Fallback for legacy compatibility: keep all declared inputs.
inputs_ = declared_inputs_;
std::vector<std::string_view> illegal_keys;
for (const onnx::ValueInfoProto *it : inputs_) {
std::string_view key = it->name();
if (initializer_map_.find(key) != initializer_map_.end()) {
illegal_keys.push_back(key);
}
}
if (!illegal_keys.empty()) {
std::string error = "When not in elide_initialized_inputs=true mode, we "
"expect inputs to not have an initial value (got ";
AppendDelimitedStrings(error, illegal_keys);
error.append(")");
return model_info_.SetError(std::move(error));
}
}
// Index the inputs and outputs.
for (auto *input : inputs_) {
input_map_.emplace(input->name(), *input);
}
for (auto *output : outputs_) {
output_map_.emplace(output->name(), *output);
}
return success();
}
const onnx::TypeProto *GraphInfo::FindTypeProtoForName(std::string_view name) {
// Node outputs don't typically carry type information, but shape inference
// populates value_info with it. Failing that, the name may be a graph
// output, which must have type information.
{
auto it = value_info_map_.find(name);
if (it != value_info_map_.end()) {
return &it->second.type();
}
}
{
auto it = output_map_.find(name);
if (it != output_map_.end()) {
return &it->second.type();
}
}
std::string msg = "No type information associated with '";
msg.append(name);
msg.append("'. Run shape inference?");
model_info_.SetError(std::move(msg));
return nullptr;
}
// ---------------------------------------------------------------------------//
// ContextCache
// ---------------------------------------------------------------------------//
MlirType ContextCache::ConvertTypeProto(const onnx::TypeProto &tp) {
if (tp.has_tensor_type()) {
// Convert Tensor TypeProto.
const onnx::TypeProto_Tensor &tt = tp.tensor_type();
if (!tt.has_shape()) {
std::string msg =
"Unsupported Tensor type without shape (run shape inference?): ";
msg.append(tt.DebugString());
model_info_.SetError(std::move(msg));
return {nullptr};
}
MlirType element_type = ConvertTensorElementType(tt.elem_type());
if (mlirTypeIsNull(element_type)) {
return {nullptr};
}
shared_dims_.clear();
shared_dims_.reserve(6);
for (const onnx::TensorShapeProto::Dimension &dim : tt.shape().dim()) {
if (dim.has_dim_value()) {
// Static.
shared_dims_.push_back(dim.dim_value());
} else {
// Dynamic.
shared_dims_.push_back(-1);
}
}
return GetVtensorType(shared_dims_, element_type);
} else {
std::string msg = "Unsupported ONNX TypeProto: ";
msg.append(tp.DebugString());
model_info_.SetError(std::move(msg));
return {nullptr};
}
}
MlirType ContextCache::ConvertTensorElementType(int elem_type) {
auto it = elem_type_map_.find(elem_type);
if (it != elem_type_map_.end()) {
return it->second;
}
MlirType t = {nullptr};
switch (elem_type) {
case onnx::TensorProto::FLOAT:
t = mlirF32TypeGet(context_);
break;
case onnx::TensorProto::UINT8:
t = mlirIntegerTypeUnsignedGet(context_, 8);
break;
case onnx::TensorProto::INT8:
t = mlirIntegerTypeSignedGet(context_, 8);
break;
case onnx::TensorProto::UINT16:
t = mlirIntegerTypeUnsignedGet(context_, 16);
break;
case onnx::TensorProto::INT16:
t = mlirIntegerTypeSignedGet(context_, 16);
break;
case onnx::TensorProto::INT32:
t = mlirIntegerTypeSignedGet(context_, 32);
break;
case onnx::TensorProto::UINT32:
t = mlirIntegerTypeUnsignedGet(context_, 32);
break;
case onnx::TensorProto::INT64:
t = mlirIntegerTypeSignedGet(context_, 64);
break;
case onnx::TensorProto::UINT64:
t = mlirIntegerTypeUnsignedGet(context_, 64);
break;
case onnx::TensorProto::BOOL:
t = mlirIntegerTypeGet(context_, 1);
break;
case onnx::TensorProto::FLOAT16:
t = mlirF16TypeGet(context_);
break;
case onnx::TensorProto::DOUBLE:
t = mlirF64TypeGet(context_);
break;
case onnx::TensorProto::COMPLEX64:
t = mlirComplexTypeGet(mlirF32TypeGet(context_));
break;
case onnx::TensorProto::COMPLEX128:
t = mlirComplexTypeGet(mlirF64TypeGet(context_));
break;
case onnx::TensorProto::BFLOAT16:
t = mlirBF16TypeGet(context_);
break;
case onnx::TensorProto::FLOAT8E4M3FN:
t = mlirFloat8E4M3FNTypeGet(context_);
break;
case onnx::TensorProto::FLOAT8E4M3FNUZ:
t = mlirFloat8E4M3FNUZTypeGet(context_);
break;
case onnx::TensorProto::FLOAT8E5M2:
t = mlirFloat8E5M2TypeGet(context_);
break;
case onnx::TensorProto::FLOAT8E5M2FNUZ:
t = mlirFloat8E5M2FNUZTypeGet(context_);
break;
default: {
std::string msg = "Unknown ONNX tensor element type: ";
msg.append(std::to_string(elem_type));
model_info_.SetError(std::move(msg));
return {nullptr};
}
}
assert(t.ptr && "did not convert type");
elem_type_map_[elem_type] = t;
return t;
}
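// Converts a TensorProto's payload to an MLIR elements attribute. The proto
// has two storage layouts: raw_data (opaque bytes, imported below as an
// unmanaged DenseResourceElementsAttr without copying) and per-type inline
// repeated fields, which map to the corresponding DenseElementsAttr getters.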
MlirAttribute
ContextCache::ConvertTensorProtoToAttr(const onnx::TensorProto &tp) {
MlirType tensor_type = ConvertTensorProtoToBuiltinType(tp);
if (tp.has_raw_data()) {
std::string sanitized_name = SanitizeNameAsIdentifier(tp.name());
// Conveniently, DenseResourceElementsAttr shares the raw data
// format. We just give it maximum numeric alignment.
return mlirUnmanagedDenseResourceElementsAttrGet(
tensor_type, toMlirStringRef(sanitized_name),
const_cast<void *>(static_cast<const void *>(tp.raw_data().data())),
tp.raw_data().size(), /*dataAlignment=*/8, /*dataIsMutable=*/false,
/*deleter=*/nullptr, /*userData=*/nullptr);
} else {
switch (tp.data_type()) {
case onnx::TensorProto::DataType::TensorProto_DataType_FLOAT:
return mlirDenseElementsAttrFloatGet(tensor_type, tp.float_data_size(),
tp.float_data().data());
case onnx::TensorProto::DataType::TensorProto_DataType_INT32:
return mlirDenseElementsAttrInt32Get(tensor_type, tp.int32_data_size(),
tp.int32_data().data());
case onnx::TensorProto::DataType::TensorProto_DataType_INT64:
return mlirDenseElementsAttrInt64Get(tensor_type, tp.int64_data_size(),
tp.int64_data().data());
case onnx::TensorProto::DataType::TensorProto_DataType_DOUBLE:
return mlirDenseElementsAttrDoubleGet(tensor_type, tp.double_data_size(),
tp.double_data().data());
case onnx::TensorProto::DataType::TensorProto_DataType_UINT32: {
// Special case: the proto stores uint32 inline data in the uint64_data field.
std::vector<uint32_t> stupid_conversion;
stupid_conversion.reserve(tp.uint64_data_size());
for (uint64_t v : tp.uint64_data())
stupid_conversion.push_back(v);
return mlirDenseElementsAttrUInt32Get(
tensor_type, stupid_conversion.size(), stupid_conversion.data());
}
case onnx::TensorProto::DataType::TensorProto_DataType_UINT64:
return mlirDenseElementsAttrUInt64Get(tensor_type, tp.uint64_data_size(),
tp.uint64_data().data());
}
}
std::string message =
"Unable to convert ONNX TensorProto to MLIR attribute: ";
message.append(tp.DebugString());
model_info_.SetError(std::move(message));
return {nullptr};
}
MlirType
ContextCache::ConvertTensorProtoToBuiltinType(const onnx::TensorProto &tp) {
MlirType element_type = ConvertTensorElementType(tp.data_type());
if (mlirTypeIsNull(element_type))
return {nullptr};
shared_dims_.clear();
for (auto dim : tp.dims()) {
shared_dims_.push_back(dim);
}
return mlirRankedTensorTypeGet(shared_dims_.size(), shared_dims_.data(),
element_type,
/*encoding=*/{nullptr});
}
MlirType
ContextCache::ConvertTensorProtoToVtensorType(const onnx::TensorProto &tp) {
MlirType element_type = ConvertTensorElementType(tp.data_type());
if (mlirTypeIsNull(element_type))
return {nullptr};
shared_dims_.clear();
for (auto dim : tp.dims()) {
shared_dims_.push_back(dim);
}
return GetVtensorType(shared_dims_, element_type);
}
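// Builds a !torch.vtensor type by rendering its assembly form and parsing it
// back, caching the result keyed on the assembly string. For example, dims
// {2, -1} with an f32 element type yield "!torch.vtensor<[2,?],f32>".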
MlirType ContextCache::GetVtensorType(const std::vector<int64_t> &dims,
MlirType element_type) {
std::string type_asm = "!torch.vtensor<[";
// Add dimension list.
bool first_dim = true;
for (int64_t dim : dims) {
if (first_dim)
first_dim = false;
else
type_asm.push_back(',');
if (dim < 0)
type_asm.push_back('?');
else
type_asm.append(std::to_string(dim));
}
type_asm.append("],");
// Add element type.
type_asm.append(getMlirAsm(element_type));
type_asm.push_back('>');
// Look in cache.
auto found_it = asm_type_map_.find(type_asm);
if (found_it != asm_type_map_.end()) {
return found_it->second;
}
// Parse.
MlirType t = mlirTypeParseGet(context_, toMlirStringRef(type_asm));
if (mlirTypeIsNull(t)) {
std::string message =
"internal error: could not parse !torch.vtensor type: ";
message.append(type_asm);
model_info_.SetError(std::move(message));
return t;
}
asm_type_map_[std::move(type_asm)] = t;
return t;
}
// ---------------------------------------------------------------------------//
// NodeImporter
// ---------------------------------------------------------------------------//
NodeImporter::NodeImporter(GraphInfo &graph_info, ContextCache &cc,
MlirOperation module_op)
: graph_info_(graph_info), cc_(cc),
context_(mlirOperationGetContext(module_op)), module_op_(module_op),
func_op_({nullptr}), body_block_({nullptr}) {
std::string locName = "graph:";
locName.append(graph_info.graph_proto().name());
default_loc_ = mlirLocationNameGet(context_, toMlirStringRef(locName),
/*childLoc=*/{nullptr});
}
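// Creates the func.func that the graph is imported into. The result has the
// shape (illustrative; names and types depend on the model):
//   func.func @main_graph(%arg0: !torch.vtensor<[1,3],f32>)
//       -> !torch.vtensor<[1],f32>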
Status NodeImporter::DefineFunction(std::optional<std::string> name) {
const onnx::GraphProto &p = graph_info_.graph_proto();
MlirRegion moduleBodyRegion = mlirOperationGetRegion(module_op_, 0);
MlirBlock moduleBody = mlirRegionGetFirstBlock(moduleBodyRegion);
MlirAttribute nameAttr;
if (name) {
// Explicitly named.
nameAttr = mlirStringAttrGet(context_, toMlirStringRef(*name));
} else {
// Name it according to the graph.
nameAttr = mlirStringAttrGet(context_, toMlirStringRef(p.name()));
}
// Derive the FunctionType.
std::vector<MlirType> input_types;
std::vector<MlirLocation> input_locs;
std::vector<MlirType> output_types;
for (auto *input : graph_info_.inputs()) {
MlirType t = cc_.ConvertTypeProto(input->type());
if (mlirTypeIsNull(t)) {
return failure();
}
input_types.push_back(t);
input_locs.push_back(default_loc_);
}
for (auto *output : graph_info_.outputs()) {
MlirType t = cc_.ConvertTypeProto(output->type());
if (mlirTypeIsNull(t)) {
return failure();
}
output_types.push_back(t);
}
MlirType ftype =
mlirFunctionTypeGet(context_, input_types.size(), input_types.data(),
output_types.size(), output_types.data());
// Create func.func.
func_op_ = createMlirOperationAtEnd(
moduleBody, "func.func", default_loc_, mlirRegionCreate(),
toMlirNamedAttribute("function_type", mlirTypeAttrGet(ftype)),
toMlirNamedAttribute("sym_name", nameAttr));
// Add entry block.
body_block_ = mlirBlockCreate(input_types.size(), input_types.data(),
input_locs.data());
MlirRegion bodyRegion = mlirOperationGetRegion(func_op_, 0);
mlirRegionAppendOwnedBlock(bodyRegion, body_block_);
// Map the block args to names and store for evaluation.
for (int i = 0, e = graph_info_.inputs().size(); i < e; ++i) {
std::string_view name = graph_info_.inputs()[i]->name();
MlirValue value = mlirBlockGetArgument(body_block_, i);
nv_map_[name] = value;
}
PopulateGraphAttrs(func_op_);
return success();
}
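// Attaches ONNX metadata to the function as discardable attributes, e.g.
// (values illustrative):
//   torch.onnx_meta.ir_version = 8 : si64
//   torch.onnx_meta.opset_version = 17 : si64
//   torch.onnx_meta.producer_name = "pytorch"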
void NodeImporter::PopulateGraphAttrs(MlirOperation container_op) {
const onnx::ModelProto &m = graph_info_.model_info().model_proto();
MlirType i64_type = mlirIntegerTypeSignedGet(context_, 64);
int64_t default_opset_version = 0;
std::unordered_map<std::string_view, MlirAttribute> opset_versions;
// Determine model level opset versions.
for (const onnx::OperatorSetIdProto &opset_import : m.opset_import()) {
if (opset_import.has_domain()) {
opset_versions[opset_import.domain()] =
mlirIntegerAttrGet(i64_type, opset_import.version());
} else {
default_opset_version = opset_import.version();
}
}
// Set the default domain version.
if (default_opset_version != 0) {
mlirOperationSetDiscardableAttributeByName(
container_op, toMlirStringRef("torch.onnx_meta.opset_version"),
mlirIntegerAttrGet(i64_type, default_opset_version));
}
// Set versions for other domains.
if (!opset_versions.empty()) {
std::vector<MlirNamedAttribute> version_attrs;
for (const auto &it : opset_versions) {
version_attrs.push_back(mlirNamedAttributeGet(
mlirIdentifierGet(context_, toMlirStringRef(it.first)), it.second));
}
MlirAttribute dict_attr = mlirDictionaryAttrGet(
context_, version_attrs.size(), version_attrs.data());
mlirOperationSetDiscardableAttributeByName(
container_op, toMlirStringRef("torch.onnx_meta.opset_versions"),
dict_attr);
}
// IR version and producer.
mlirOperationSetDiscardableAttributeByName(
container_op, toMlirStringRef("torch.onnx_meta.ir_version"),
mlirIntegerAttrGet(i64_type, m.ir_version()));
mlirOperationSetDiscardableAttributeByName(
container_op, toMlirStringRef("torch.onnx_meta.producer_name"),
mlirStringAttrGet(context_, toMlirStringRef(m.producer_name())));
mlirOperationSetDiscardableAttributeByName(
container_op, toMlirStringRef("torch.onnx_meta.producer_version"),
mlirStringAttrGet(context_, toMlirStringRef(m.producer_version())));
}
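// Imports all initializers and nodes, recording produced values in nv_map_.
// Nodes are visited in proto order, so the graph must be topologically
// sorted; a use of a not-yet-defined value is reported as an error.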
Status NodeImporter::ImportAll() {
// TODO: Consider pulling in initializers on demand, since many may be
// unused.
for (const auto &it : graph_info_.initializer_map()) {
if (failed(ImportInitializer(it.second)))
return failure();
}
for (const onnx::NodeProto &it : graph_info_.graph_proto().node()) {
if (failed(ImportNode(it)))
return failure();
}
// Lookup the outputs, which should all be in the nv_map if the graph was
// properly formed.
std::vector<MlirValue> output_values;
for (const auto *output : graph_info_.outputs()) {
std::string_view name = output->name();
auto found_it = nv_map_.find(name);
if (found_it == nv_map_.end()) {
std::string msg = "Non topologically produced ONNX graph output '";
msg.append(name);
msg.append("'");
return SetError(std::move(msg));
}
output_values.push_back(found_it->second);
}
createMlirOperationAtEnd(body_block_, "func.return", default_loc_,
output_values);
return success();
}
Status NodeImporter::ImportInitializer(const onnx::TensorProto &initializer) {
std::string_view name = initializer.name();
MlirLocation loc = mlirLocationNameGet(context_, toMlirStringRef(name),
/*childLoc=*/{nullptr});
MlirAttribute value_attr = cc_.ConvertTensorProtoToAttr(initializer);
MlirType vtensor_type = cc_.ConvertTensorProtoToVtensorType(initializer);
if (mlirAttributeIsNull(value_attr) || mlirTypeIsNull(vtensor_type))
return failure();
MlirOperation op = createMlirOperationAtEnd(
body_block_, "torch.vtensor.literal", loc, vtensor_type,
toMlirNamedAttribute("value", value_attr));
MlirValue result = mlirOperationGetResult(op, 0);
auto inserted = nv_map_.insert(std::make_pair(name, result));
if (!inserted.second) {
std::string msg = "Multiple nodes produced a value for '";
msg.append(name);
msg.append("', most recent from ");
msg.append(initializer.DebugString());
return SetError(std::move(msg));
}
return success();
}
Status NodeImporter::ImportNode(const onnx::NodeProto &node) {
std::string_view op_type = node.op_type();
// Handle special-form op types that do not go down the generic path.
if (op_type == "ConstantOfShape") {
return ImportConstantOfShapeNode(node);
}
return ImportGeneralNode(node);
}
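// Imports a node through the generic escape hatch, e.g. (illustrative):
//   %0 = torch.operator "onnx.Add"(%a, %b)
//       : (!torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>)
//       -> !torch.vtensor<[4],f32>
// ONNX attributes are carried over verbatim under a "torch.onnx." prefix.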
Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) {
MlirLocation loc = mlirLocationNameGet(context_, toMlirStringRef(node.name()),
/*childLoc=*/{nullptr});
// Map inputs to values.
std::vector<MlirValue> input_values;
for (auto &input_name : node.input()) {
auto found_it = nv_map_.find(input_name);
if (found_it == nv_map_.end()) {
std::string msg = "Non topologically produced ONNX node input '";
msg.append(input_name);
msg.append("'");
return SetError(std::move(msg));
}
input_values.push_back(found_it->second);
}
// Map outputs to types.
std::vector<MlirType> output_types;
for (auto &output_name : node.output()) {
const onnx::TypeProto *type_proto =
graph_info_.FindTypeProtoForName(output_name);
if (!type_proto)
return failure();
MlirType t = cc_.ConvertTypeProto(*type_proto);
if (mlirTypeIsNull(t))
return failure();
output_types.push_back(t);
}
// Derive the op name.
std::string op_name = "onnx.";
op_name.append(node.op_type());
MlirAttribute op_name_attr =
mlirStringAttrGet(context_, toMlirStringRef(op_name));
// General attributes.
std::vector<std::pair<std::string, MlirAttribute>> general_attributes;
for (auto &onnx_attr : node.attribute()) {
MlirAttribute attr = ImportGeneralAttribute(onnx_attr);
if (mlirAttributeIsNull(attr))
return failure();
std::string full_name = "torch.onnx.";
full_name.append(onnx_attr.name());
general_attributes.push_back(std::make_pair(full_name, attr));
}
// Create op.
MlirOperation op = createMlirOperationAtEnd(
body_block_, "torch.operator", loc, output_types, input_values,
toMlirNamedAttribute("name", op_name_attr), general_attributes);
// Record the result values.
for (int i = 0, e = output_types.size(); i < e; ++i) {
MlirValue result = mlirOperationGetResult(op, i);
std::string_view name = node.output(i);
auto inserted = nv_map_.insert(std::make_pair(name, result));
if (!inserted.second) {
std::string msg = "Multiple nodes produced a value for '";
msg.append(name);
msg.append("', most recent from ");
msg.append(node.DebugString());
return SetError(std::move(msg));
}
}
return success();
}
MlirAttribute
NodeImporter::ImportGeneralAttribute(const onnx::AttributeProto &onnx_attr) {
switch (onnx_attr.type()) {
case onnx::AttributeProto::UNDEFINED:
SetError("'UNDEFINED' attribute type not supported");
return {nullptr};
case onnx::AttributeProto::FLOAT:
return mlirFloatAttrDoubleGet(context_, mlirF32TypeGet(context_),
onnx_attr.f());
case onnx::AttributeProto::INT:
return mlirIntegerAttrGet(mlirIntegerTypeSignedGet(context_, 64),
onnx_attr.i());
case onnx::AttributeProto::STRING:
return mlirStringAttrGet(context_, toMlirStringRef(onnx_attr.s()));
case onnx::AttributeProto::TENSOR:
return cc_.ConvertTensorProtoToAttr(onnx_attr.t());
case onnx::AttributeProto::GRAPH:
SetError("'GRAPH' attribute type not supported on this node");
return {nullptr};
case onnx::AttributeProto::SPARSE_TENSOR:
SetError("'SPARSE_TENSOR' attribute type not supported on this node");
return {nullptr};
case onnx::AttributeProto::TYPE_PROTO:
SetError("'TYPE_PROTO' attribute type not supported on this node");
return {nullptr};
case onnx::AttributeProto::FLOATS: {
std::vector<MlirAttribute> attrs;
for (auto f : onnx_attr.floats())
attrs.push_back(
mlirFloatAttrDoubleGet(context_, mlirF32TypeGet(context_), f));
return mlirArrayAttrGet(context_, attrs.size(), attrs.data());
}
case onnx::AttributeProto::INTS: {
std::vector<MlirAttribute> attrs;
for (auto i : onnx_attr.ints())
attrs.push_back(
mlirIntegerAttrGet(mlirIntegerTypeSignedGet(context_, 64), i));
return mlirArrayAttrGet(context_, attrs.size(), attrs.data());
}
case onnx::AttributeProto::STRINGS: {
std::vector<MlirAttribute> attrs;
for (const auto &s : onnx_attr.strings())
attrs.push_back(mlirStringAttrGet(context_, toMlirStringRef(s)));
return mlirArrayAttrGet(context_, attrs.size(), attrs.data());
}
case onnx::AttributeProto::TENSORS: {
std::vector<MlirAttribute> attrs;
for (auto &t : onnx_attr.tensors()) {
MlirAttribute attr = cc_.ConvertTensorProtoToAttr(t);
if (mlirAttributeIsNull(attr))
return {nullptr};
attrs.push_back(attr);
}
return mlirArrayAttrGet(context_, attrs.size(), attrs.data());
}
case onnx::AttributeProto::GRAPHS:
SetError("'GRAPHS' attribute type not supported on this node");
return {nullptr};
case onnx::AttributeProto::SPARSE_TENSORS:
SetError("'SPARSE_TENSORS' attribute type not supported on this node");
return {nullptr};
case onnx::AttributeProto::TYPE_PROTOS:
SetError("'TYPE_PROTOS' attribute type not supported on this node");
return {nullptr};
}
std::string msg = "Unhandled ONNX attribute type code ";
msg.append(std::to_string(onnx_attr.type()));
msg.append(": ");
msg.append(onnx_attr.DebugString());
SetError(std::move(msg));
return {nullptr};
}
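// Imports ConstantOfShape as a torch.vtensor.literal splat, e.g.
// (illustrative): a [2, 3] shape initializer with float value 0.0 yields
//   %0 = torch.vtensor.literal(dense<0.0> : tensor<2x3xf32>)
//       : !torch.vtensor<[2,3],f32>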
Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) {
std::string_view name = node.name();
MlirLocation loc = mlirLocationNameGet(context_, toMlirStringRef(name),
/*childLoc=*/{nullptr});
// This op is special: It has an input of the shape, and in full generality
// could involve eager production of constants of variable size. In
// practice, the DNN profile for ONNX makes this very difficult to do
// and we hard-assert that the input can be resolved to an immediate
// value.
if (node.input_size() != 1 || node.output_size() != 1) {
return SetError("ConstantOfShape node must have one input and output");
}
// Shape.
std::vector<int64_t> shape;
if (failed(GetImmediateShapeTensor(node.input(0), shape)))
return failure();
// Value.
const onnx::AttributeProto *value_proto = nullptr;
for (auto &attr : node.attribute()) {
if (attr.name() == "value") {
value_proto = &attr;
break;
}
}
if (!value_proto) {
return SetError("ConstantOfShape node must have a 'value' attribute");
}
if (value_proto->type() != onnx::AttributeProto_AttributeType_TENSOR) {
return SetError("ConstantOfShape node must have a tensor value attribute");
}
// Create the splat.
const onnx::TensorProto &tensor_proto = value_proto->t();
if (tensor_proto.dims_size() != 1 || tensor_proto.dims(0) != 1) {
return SetError("ConstantOfShape node expected a scalar tensor value");
}
auto tensorTypeFor = [&](MlirType element_type) {
return mlirRankedTensorTypeGet(shape.size(), shape.data(), element_type,
/*encoding*/ {nullptr});
};
MlirAttribute splat_attr = {nullptr};
switch (tensor_proto.data_type()) {
case onnx::TensorProto::DataType::TensorProto_DataType_FLOAT:
splat_attr = mlirDenseElementsAttrFloatSplatGet(
tensorTypeFor(mlirF32TypeGet(context_)), tensor_proto.float_data(0));
break;
case onnx::TensorProto::DataType::TensorProto_DataType_INT32:
splat_attr = mlirDenseElementsAttrInt32SplatGet(
tensorTypeFor(mlirIntegerTypeSignedGet(context_, 32)),
tensor_proto.int32_data(0));
break;
case onnx::TensorProto::DataType::TensorProto_DataType_INT64:
splat_attr = mlirDenseElementsAttrInt64SplatGet(
tensorTypeFor(mlirIntegerTypeSignedGet(context_, 64)),
tensor_proto.int64_data(0));
break;
case onnx::TensorProto::DataType::TensorProto_DataType_DOUBLE:
splat_attr = mlirDenseElementsAttrDoubleSplatGet(
tensorTypeFor(mlirF64TypeGet(context_)), tensor_proto.double_data(0));
break;
case onnx::TensorProto::DataType::TensorProto_DataType_UINT64:
splat_attr = mlirDenseElementsAttrUInt64SplatGet(
tensorTypeFor(mlirIntegerTypeUnsignedGet(context_, 64)),
tensor_proto.uint64_data(0));
break;
case onnx::TensorProto::DataType::TensorProto_DataType_UINT32:
// Special case: the proto stores uint32 inline data in the uint64_data
// field.
splat_attr = mlirDenseElementsAttrUInt32SplatGet(
tensorTypeFor(mlirIntegerTypeUnsignedGet(context_, 32)),
static_cast<uint32_t>(tensor_proto.uint64_data(0)));
break;
}
if (mlirAttributeIsNull(splat_attr)) {
std::string message =
"ConstantOfShape node has an unsupported splat data type: ";
message.append(tensor_proto.DebugString());
return SetError(std::move(message));
}
// Create the vtensor type for the result.
MlirType splat_type = mlirAttributeGetType(splat_attr);
MlirType element_type = mlirShapedTypeGetElementType(splat_type);
MlirType vtensor_type = cc_.GetVtensorType(shape, element_type);
if (mlirTypeIsNull(vtensor_type))
return failure();
MlirOperation op = createMlirOperationAtEnd(
body_block_, "torch.vtensor.literal", loc, vtensor_type,
toMlirNamedAttribute("value", splat_attr));
MlirValue result = mlirOperationGetResult(op, 0);
// Export to the nv_map.
auto inserted = nv_map_.insert(std::make_pair(name, result));
if (!inserted.second) {
std::string msg = "Multiple nodes produced a value for '";
msg.append(name);
msg.append("', most recent from ");
msg.append(node.DebugString());
return SetError(std::move(msg));
}
return success();
}
Status NodeImporter::GetImmediateShapeTensor(const std::string &name,
std::vector<int64_t> &shape) {
auto found_it = graph_info_.initializer_map().find(name);
if (found_it == graph_info_.initializer_map().end()) {
std::string message = "An immediate shape value for '";
message.append(name);
message.append("' was required but it is dynamically produced");
return SetError(std::move(message));
}
const onnx::TensorProto &tp = found_it->second;
shape.clear();
// Since this is being interpreted as a shape, we only support some limited
// types.
size_t raw_data_size = 0;
switch (tp.data_type()) {
case onnx::TensorProto::DataType::TensorProto_DataType_INT32: {
auto *raw_data = graph_info_.GetOptionalRawData<int32_t>(tp, raw_data_size);
if (raw_data) {
std::copy(raw_data, raw_data + raw_data_size, std::back_inserter(shape));
} else {
for (auto v : tp.int32_data())
shape.push_back(v);
}
return success();
}
case onnx::TensorProto::DataType::TensorProto_DataType_INT64: {
auto *raw_data = graph_info_.GetOptionalRawData<int64_t>(tp, raw_data_size);
if (raw_data) {
std::copy(raw_data, raw_data + raw_data_size, std::back_inserter(shape));
} else {
for (auto v : tp.int64_data())
shape.push_back(v);
}
return success();
}
case onnx::TensorProto::DataType::TensorProto_DataType_UINT32: {
auto *raw_data =
graph_info_.GetOptionalRawData<uint32_t>(tp, raw_data_size);
if (raw_data) {
std::copy(raw_data, raw_data + raw_data_size, std::back_inserter(shape));
} else {
// Special case: the proto stores uint32 inline data in the uint64_data field.
for (auto v : tp.uint64_data())
shape.push_back(v);
}
return success();
}
case onnx::TensorProto::DataType::TensorProto_DataType_UINT64: {
auto *raw_data =
graph_info_.GetOptionalRawData<uint64_t>(tp, raw_data_size);
if (raw_data) {
std::copy(raw_data, raw_data + raw_data_size, std::back_inserter(shape));
} else {
for (auto v : tp.uint64_data())
shape.push_back(v);
}
return success();
}
}
{
std::string message =
"An immediate shape value could not be converted from TensorProto: ";
message.append(tp.DebugString());
return SetError(std::move(message));
}
}
void NodeImporter::DebugDumpModule() {
auto callback = +[](MlirStringRef sr, void *) {
fwrite(sr.data, sizeof(char), sr.length, stderr);
};
mlirOperationPrint(module_op_, callback, nullptr);
}