torch-mlir/examples/torchfx/loc.py

49 lines
1.8 KiB
Python
Raw Normal View History

# -*- Python -*-
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
#
# pylint: disable=no-member, no-name-in-module, invalid-name, missing-function-docstring, fixme
from typing import Mapping
import inspect
import ast
import torch.fx
class Annotation:
    """A named position described by a row and a column.

    Attributes:
        name: label given to the annotated entity.
        row: row value supplied by the caller.
        col: column value supplied by the caller.
    """

    def __init__(self, name: str, row: int, col: int):
        self.name = name
        self.row = row
        self.col = col

    def __repr__(self) -> str:
        # Without this, instances print as "<Annotation object at 0x...>",
        # which is useless when a container of annotations is stringified.
        return (f"{type(self).__name__}(name={self.name!r}, "
                f"row={self.row!r}, col={self.col!r})")
# Type alias for a mapping from a torch.fx.Node to its Annotation.
# NOTE(review): LocInspector.annotate_defs fills its `annotations` dict with
# (src_file, lineno, col_offset) tuples keyed by definition name, which does
# not match this alias -- confirm which shape is intended.
Annotations = Mapping[torch.fx.Node, Annotation]
class LocInspector:
    """Collect source locations for the definitions in a module's class.

    The source of ``type(module)`` is fetched with ``inspect`` and parsed
    with ``ast``; ``annotate_defs`` then records one location for every
    class or function definition found in that source.

    Attributes:
        annotations: dict filled by ``annotate_defs``; maps a definition
            name to a ``(src_file, lineno, col_offset)`` tuple.
        graph: the graph passed to the constructor (stored only; this class
            does not read it).
        module: the module instance whose class source is inspected.
        module_start_lineno: line of ``src_file`` on which the class source
            starts, as reported by ``inspect.getsourcelines``.
        src_file: path reported by ``inspect.getsourcefile``; may be None.
        module_ast: ``ast.Module`` of the class source, with line numbers
            relative to the first line of that source.
    """

    # TODO: type of module?
    def __init__(self, graph: torch.fx.Graph, module: torch.nn.Module):
        """Fetch and parse the source of ``type(module)``.

        Raises whatever ``inspect.getsourcelines`` raises when the source
        cannot be retrieved (``OSError`` / ``TypeError``), and
        ``SyntaxError`` if the retrieved source does not parse.
        """
        self.annotations = {}
        self.graph = graph
        self.module = module
        module_lines, self.module_start_lineno = \
            inspect.getsourcelines(type(module))
        module_src = "".join(module_lines)
        self.src_file = inspect.getsourcefile(type(module))
        self.module_ast = self._parse_class_source(module_src)

    @staticmethod
    def _parse_class_source(source: str) -> ast.Module:
        """Parse ``source`` even when it is uniformly indented."""
        if not source[:1].isspace():
            return ast.parse(source)
        # A class defined inside a function or another class comes back
        # from inspect with its original indentation, which ast.parse
        # rejects at top level. Parse it as the body of a synthetic
        # ``if True:`` header instead, so column offsets still refer to
        # the real file columns.
        wrapped = ast.parse("if True:\n" + source)
        tree = ast.Module(body=wrapped.body[0].body, type_ignores=[])
        # Undo the one-line shift introduced by the synthetic header.
        ast.increment_lineno(tree, -1)
        return tree

    def __str__(self):
        """Return the annotations, source file and AST dump as text."""
        newline = "\n\n"
        # src_file can be None (inspect.getsourcefile), and str.join only
        # accepts strings, hence the explicit str().
        values = ["Annotations: ", str(self.annotations), newline,
                  "Src File: ", str(self.src_file), newline,
                  "Module AST: ", ast.dump(self.module_ast)]
        return "".join(values)

    def annotate_defs(self) -> None:
        """Record the file location of each class/function definition.

        Fills ``self.annotations[name]`` with
        ``(src_file, lineno, col_offset)``. A later definition with the
        same name overwrites an earlier entry.
        """
        for node in ast.walk(self.module_ast):
            if isinstance(node, (ast.ClassDef,
                                 ast.FunctionDef)):
                # subtract 1 because lineno's begin on 1
                lineno = node.lineno + self.module_start_lineno - 1
                self.annotations[node.name] = (self.src_file, lineno,
                                               node.col_offset)