mirror of https://github.com/llvm/torch-mlir
[onnx] Fix importer variable names to make `mlir` legal (#2690)
Some names for `onnx` identifiers are not legal in `mlir-ir`. Sanitize so that the generated `ir` is legal.pull/2691/head
parent
ccd469ca0d
commit
85b86b36a2
|
@ -38,6 +38,7 @@ from typing import Optional
|
|||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import re
|
||||
|
||||
from ..ir import (
|
||||
ArrayAttr,
|
||||
|
@ -464,13 +465,18 @@ class ContextCache:
|
|||
# See TypeProto: sequence_type, map_type, optional_type, sparse_tensor_type.
|
||||
raise OnnxImportError(f"Unsupported ONNX TypeProto: {tp}")
|
||||
|
||||
def _sanitize_name(self, name):
|
||||
if not name.isidentifier():
|
||||
name = "_" + name
|
||||
return re.sub("[:/]", "_", name)
|
||||
|
||||
def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute:
|
||||
tensor_type = self.tensor_proto_to_builtin_type(tp)
|
||||
if tp.HasField("raw_data"):
|
||||
# Conveniently, DenseResourceElementsAttr shares the raw data
|
||||
# format. We just give it maximum numeric alignment.
|
||||
return DenseResourceElementsAttr.get_from_buffer(
|
||||
tp.raw_data, tp.name, tensor_type, alignment=8
|
||||
tp.raw_data, self._sanitize_name(tp.name), tensor_type, alignment=8
|
||||
)
|
||||
else:
|
||||
# We have to do a data type specific instantiation from proto fields.
|
||||
|
|
Loading…
Reference in New Issue