mirror of https://github.com/llvm/torch-mlir
49 lines
1.8 KiB
Python
49 lines
1.8 KiB
Python
# -*- Python -*-
|
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
# See https://llvm.org/LICENSE.txt for license information.
|
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
# Also available under a BSD-style license. See LICENSE.
|
|
#
|
|
# pylint: disable=no-member, no-name-in-module, invalid-name, missing-function-docstring, fixme
|
|
|
|
import ast
import inspect
import textwrap
from typing import Mapping

import torch.fx
|
|
|
|
class Annotation:
    """A name paired with a row and a column.

    NOTE(review): ``row``/``col`` presumably describe a source position;
    nothing in this file constructs an ``Annotation``, so confirm the
    intended units (0- vs 1-based) against the callers.
    """

    def __init__(self, name: str, row: int, col: int):
        """Store ``name``, ``row`` and ``col`` unchanged on the instance."""
        self.name = name
        self.row = row
        self.col = col

    def __repr__(self) -> str:
        # Without this, an Annotation held in a dict or list prints as an
        # opaque "<Annotation object at 0x...>", which is useless in debug dumps.
        return (f"{type(self).__name__}(name={self.name!r}, "
                f"row={self.row!r}, col={self.col!r})")
|
|
|
|
# Type alias: a mapping from FX graph nodes to their Annotation.
# NOTE(review): LocInspector.annotations in this file is keyed by definition
# name and holds (file, line, col) tuples, so it does not match this alias --
# confirm where this alias is meant to be used.
Annotations = Mapping[torch.fx.Node, Annotation]
|
|
|
|
class LocInspector:
    """Record where the definitions of a module's class live in its source.

    The source of ``type(module)`` is fetched with ``inspect`` and parsed
    with ``ast``. ``annotate_defs`` then fills ``self.annotations`` with one
    ``(source file, line, column)`` tuple per class / function definition,
    keyed by the definition's name.
    """

    # TODO: type of module?
    def __init__(self, graph: torch.fx.Graph, module: torch.nn.Module):
        """Load and parse the source code of ``type(module)``.

        Args:
            graph: Stored on ``self.graph``; not otherwise used by this class.
            module: Instance whose class source is located and parsed.

        Raises:
            OSError: From ``inspect`` when the source cannot be retrieved.
            TypeError: From ``inspect`` when the class is a built-in.
            SyntaxError: From ``ast.parse`` when the source does not parse.
        """
        self.annotations: dict = {}
        self.graph = graph
        self.module = module

        module_cls = type(module)
        module_lines, self.module_start_lineno = \
            inspect.getsourcelines(module_cls)
        raw_src = "".join(module_lines)
        # getsourcelines keeps the original indentation, so a class that is
        # nested (or defined inside a function) would make ast.parse raise
        # IndentationError. Strip the common indent before parsing ...
        module_src = textwrap.dedent(raw_src)
        # ... and remember how much was stripped so that reported columns
        # still refer to columns in the real file (0 for top-level classes).
        first_raw = raw_src.split("\n", 1)[0]
        first_dedented = module_src.split("\n", 1)[0]
        self._col_shift = len(first_raw) - len(first_dedented)

        self.src_file = inspect.getsourcefile(module_cls)
        self.module_ast = ast.parse(module_src)

    def __str__(self):
        """Return the annotations, source file and AST dump as one string."""
        newline = "\n\n"
        # Format src_file via an f-string: inspect.getsourcefile may return
        # None, which a plain "".join would reject with TypeError.
        values = ["Annotations: ", str(self.annotations), newline,
                  "Src File: ", f"{self.src_file}", newline,
                  "Module AST: ", ast.dump(self.module_ast)]
        return "".join(values)

    def annotate_defs(self) -> None:
        """Populate ``self.annotations`` from the parsed module source.

        Every class, function and async function definition found by
        ``ast.walk`` is recorded as ``name -> (src_file, lineno, col)``,
        where ``lineno`` is the 1-based line in the source file and ``col``
        is the column of the definition in that file. Definitions that
        share a name overwrite each other (the last one walked wins).
        """
        def_types = (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)
        for node in ast.walk(self.module_ast):
            if not isinstance(node, def_types):
                continue
            # subtract 1 because lineno's begin on 1
            lineno = node.lineno + self.module_start_lineno - 1
            # col_offset is relative to the dedented snippet; shift it back.
            col = node.col_offset + self._col_shift
            self.annotations[node.name] = (self.src_file, lineno, col)
|