torch-mlir/python/npcomp/compiler/frontend.py

258 lines
7.9 KiB
Python

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""
Frontend to the compiler, allowing various ways to import code.
"""
import ast
import inspect
import sys
from _npcomp.mlir import ir
from npcomp.dialect import Numpy
from . import logging
# Only the frontend class is part of this module's public API.
__all__ = ["ImportFrontend"]
class ImportFrontend:
  """Frontend for importing various entities into a Module.

  Owns an MLIR context (created here unless one is passed in), a new module
  created in that context, and a Numpy dialect helper bound to the context
  that is used to build ops.
  """

  def __init__(self, ir_context: ir.MLIRContext = None):
    # Explicit None check: only create a private context when the caller did
    # not supply one, instead of relying on the context object's truthiness.
    self._ir_context = (ir.MLIRContext()
                        if ir_context is None else ir_context)
    self._ir_module = self._ir_context.new_module()
    self._helper = Numpy.DialectHelper(self._ir_context)

  @property
  def ir_context(self):
    """The MLIR context that owns `ir_module`."""
    return self._ir_context

  @property
  def ir_module(self):
    """The module that imported entities are inserted into."""
    return self._ir_module

  @property
  def ir_h(self):
    """The Numpy dialect helper bound to `ir_context`."""
    return self._helper

  def import_global_function(self, f):
    """Imports a global function.

    This facility is not general and does not allow customization of the
    containing environment, method import, etc.

    Most errors are emitted via the MLIR context's diagnostic infrastructure,
    but errors related to extracting source, etc are raised directly.

    Args:
      f: The python callable.
    Returns:
      The function op created in `ir_module`.
    Raises:
      ValueError: if the source of `f` is not a `def` (e.g. a lambda).
      OSError, TypeError: propagated from `inspect` when the source of `f`
        cannot be retrieved.
    """
    h = self.ir_h
    ir_c = self.ir_context
    ir_m = self.ir_module
    filename = inspect.getsourcefile(f)
    ast_fd = self._parse_function_def(f, filename)
    filename_ident = ir_c.identifier(filename)

    # Define the function.
    # TODO: Much more needs to be done here (arg/result mapping, etc)
    logging.debug("Importing global function {}:\n{}", ast_fd.name,
                  ast.dump(ast_fd, include_attributes=True))
    h.builder.set_file_line_col(filename_ident, ast_fd.lineno,
                                ast_fd.col_offset)
    h.builder.insert_before_terminator(ir_m.first_block)
    ir_f_type = h.function_type([], [h.basicpy_UnknownType])
    ir_f = h.func_op(ast_fd.name, ir_f_type, create_entry_block=True)
    fctx = FunctionContext(ir_c=ir_c,
                           ir_f=ir_f,
                           ir_h=h,
                           filename_ident=filename_ident)
    fdimport = FunctionDefImporter(fctx, ast_fd)
    fdimport.import_body()
    return ir_f

  @staticmethod
  def _parse_function_def(f, filename):
    """Parses the source of `f` and returns its function-definition node.

    Line numbers in the returned tree refer to lines of `filename`, not to
    lines of the extracted snippet. Column offsets are left unchanged.
    """
    source_lines, start_lineno = inspect.getsourcelines(f)
    source = "".join(source_lines)
    if source.startswith((" ", "\t")):
      # Methods and nested functions come back indented, and ast.parse
      # rejects that. Nest the snippet under a dummy `if True:` so it parses
      # as-is. Unlike textwrap.dedent, this keeps column offsets and the
      # contents of multi-line string literals unchanged. The extra leading
      # line is compensated for in the line-number shift.
      ast_root = ast.parse("if True:\n" + source, filename=filename)
      ast.increment_lineno(ast_root, start_lineno - 2)
      ast_fd = ast_root.body[0].body[0]
    else:
      ast_root = ast.parse(source, filename=filename)
      ast.increment_lineno(ast_root, start_lineno - 1)
      ast_fd = ast_root.body[0]
    if not isinstance(ast_fd, (ast.FunctionDef, ast.AsyncFunctionDef)):
      # e.g. a lambda, whose source line parses as an assignment/expression.
      raise ValueError(
          "can only import functions defined with 'def'; source of %r "
          "parses as %s" % (f, ast_fd.__class__.__name__))
    return ast_fd
class FunctionContext:
  """Per-function bookkeeping shared by the importer visitors.

  Bundles the IR context, the function op under construction, the dialect
  helper, and the filename identifier used for locations. Also holds the
  table that maps local Python variable names to IR values.
  """
  __slots__ = [
      "ir_c",
      "ir_f",
      "ir_h",
      "filename_ident",
      "local_name_value_map",
  ]

  def __init__(self, ir_c, ir_f, ir_h, filename_ident):
    self.ir_c, self.ir_f, self.ir_h = ir_c, ir_f, ir_h
    self.filename_ident = filename_ident
    # Local variable name -> IR value most recently bound to it.
    self.local_name_value_map = {}

  def abort(self, message):
    """Reports `message` as an error at the current location, then raises.

    The diagnostic is emitted first, and the raised EmittedError tells
    callers higher up that it has already been reported.
    """
    where = self.current_loc
    ir.emit_error(where, message)
    raise EmittedError(where, message)

  @property
  def current_loc(self):
    """The location currently set on the dialect helper's builder."""
    builder = self.ir_h.builder
    return builder.current_loc

  def update_loc(self, ast_node):
    """Points the builder's location at `ast_node`'s line and column."""
    builder = self.ir_h.builder
    builder.set_file_line_col(self.filename_ident, ast_node.lineno,
                              ast_node.col_offset)

  def map_local_name(self, name, value):
    """Binds local `name` to IR `value`; an existing binding is replaced."""
    self.local_name_value_map[name] = value
    logging.debug("Map name({}) -> value({})", name, value)
class BaseNodeVisitor(ast.NodeVisitor):
  """Common base for the importer visitors.

  Before dispatching, it moves the function context's current location to
  the node being visited. A node type with no dedicated visit_* handler
  turns into an aborting diagnostic.
  """
  # Category name used in the "unhandled node" diagnostic; subclasses
  # override it.
  IMPORTER_TYPE = "<unknown>"

  def __init__(self, fctx):
    super().__init__()
    self.fctx = fctx

  def visit(self, node):
    """Records `node`'s location, then dispatches to its handler."""
    fctx = self.fctx
    fctx.update_loc(node)
    dispatched = super().visit(node)
    return dispatched

  def generic_visit(self, ast_node):
    """Fallback for node types without a handler: logs and aborts."""
    logging.debug("UNHANDLED NODE: {}", ast.dump(ast_node))
    node_kind = type(ast_node).__name__
    self.fctx.abort("unhandled python %s AST node '%s'" %
                    (self.IMPORTER_TYPE, node_kind))
class FunctionDefImporter(BaseNodeVisitor):
  """AST visitor for importing a function's statements.

  Handles nodes that are direct children of a FunctionDef. Statement kinds
  without a visit_* method here abort through BaseNodeVisitor.generic_visit.
  """
  IMPORTER_TYPE = "statement"

  def __init__(self, fctx, ast_fd):
    super().__init__(fctx)
    self.ast_fd = ast_fd

  def import_body(self):
    """Visits each top-level statement of the function body, in order."""
    for ast_stmt in self.ast_fd.body:
      logging.debug("STMT: {}", ast.dump(ast_stmt, include_attributes=True))
      self.visit(ast_stmt)

  def visit_Assign(self, ast_node):
    """Imports `name = expr`, including chained `a = b = expr`.

    The right-hand side is imported once, and every target name is bound to
    the resulting IR value.
    """
    expr = ExpressionImporter(self.fctx)
    expr.visit(ast_node.value)
    for target in ast_node.targets:
      self.fctx.update_loc(target)
      if not isinstance(target, ast.Name):
        # Tuple/list unpacking, attribute and subscript stores have no
        # `.id`. Report them as a diagnostic instead of crashing with an
        # AttributeError.
        self.fctx.abort("Unsupported assignment target type %s" %
                        target.__class__.__name__)
      if not isinstance(target.ctx, ast.Store):
        # TODO: Del, AugStore, etc
        self.fctx.abort("Unsupported assignment context type %s" %
                        target.ctx.__class__.__name__)
      self.fctx.map_local_name(target.id, expr.value)

  def visit_Return(self, ast_node):
    """Imports `return expr` as a return op of the imported value."""
    ir_h = self.fctx.ir_h
    if ast_node.value is None:
      # A bare `return` has no expression. Visiting None would crash while
      # updating the location, so report it as unsupported instead. The
      # location already points at this Return node.
      self.fctx.abort("Unsupported bare return (no return value)")
    expr = ExpressionImporter(self.fctx)
    expr.visit(ast_node.value)
    ir_h.return_op([expr.value])
class ExpressionImporter(BaseNodeVisitor):
  """AST visitor that imports a single expression into an IR value.

  After visit() returns, `self.value` holds the IR value produced for the
  visited expression. Sub-expressions are imported with fresh
  ExpressionImporter instances.
  """
  IMPORTER_TYPE = "expression"

  def __init__(self, fctx):
    super().__init__(fctx)
    # Result of the visit; None until an expression has been imported.
    self.value = None

  def visit(self, node):
    super().visit(node)
    # Every visit_* handler either assigns self.value or aborts, which
    # raises. Reaching here without a value means an importer bug. Compare
    # against None rather than using the IR value's truthiness.
    assert self.value is not None, (
        "ExpressionImporter did not assign a value (%r)" % (node,))

  @staticmethod
  def _numeric_literal_value(ast_node):
    """Returns (is_numeric, value) for a literal AST node.

    bool literals are not treated as numeric, matching ast.Num's semantics.
    """
    if sys.version_info < (3, 8, 0):
      # Pre-3.8 parsers emit ast.Num, with the payload in `.n`.
      if isinstance(ast_node, ast.Num):
        return True, ast_node.n
      return False, None
    # 3.8+ parsers emit ast.Constant for every literal. ast.Num and `.n` are
    # deprecated aliases (removed in Python 3.14), so inspect `.value`.
    value = getattr(ast_node, "value", None)
    if isinstance(value, bool):
      # bool subclasses int but is not a numeric literal.
      return False, None
    return isinstance(value, (int, float, complex)), value

  def visit_Constant(self, ast_node):
    """Imports a numeric literal as a constant op.

    int literals become i64 constants and float literals become f64
    constants. Any other literal (str, bytes, bool, None, Ellipsis, complex)
    aborts with a diagnostic.
    """
    ir_c = self.fctx.ir_c
    ir_h = self.fctx.ir_h
    is_numeric, nval = self._numeric_literal_value(ast_node)
    if not is_numeric:
      self.fctx.abort("unknown constant type %s" %
                      (ast_node.__class__.__name__))
    if isinstance(nval, int):
      # TODO: Configurable type mapping
      ir_type = ir_h.i64_type
      ir_attr = ir_c.integer_attr(ir_type, nval)
    elif isinstance(nval, float):
      # TODO: Configurable type mapping
      ir_type = ir_h.f64_type
      ir_attr = ir_c.float_attr(ir_type, nval)
    else:
      # e.g. complex literals.
      self.fctx.abort("unsupported numeric constant type: %r" % (nval,))
    self.value = ir_h.constant_op(ir_type, ir_attr).result

  def visit_BinOp(self, ast_node):
    """Imports `left <op> right` as a generic basicpy.binary_expr op.

    The op's "operation" attribute is the AST operator class name (e.g.
    "Add"), and its result type is basicpy's unknown type.
    """
    ir_c = self.fctx.ir_c
    ir_h = self.fctx.ir_h
    left = ExpressionImporter(self.fctx)
    left.visit(ast_node.left)
    right = ExpressionImporter(self.fctx)
    right.visit(ast_node.right)
    ir_attrs = ir_c.dictionary_attr(
        {"operation": ir_c.string_attr(ast_node.op.__class__.__name__)})
    # Visiting the operands moved the builder location; point it back at the
    # BinOp node before creating the op.
    self.fctx.update_loc(ast_node)
    # TODO: Change to a registered op.
    self.value = ir_h.op("basicpy.binary_expr", [ir_h.basicpy_UnknownType],
                         [left.value, right.value], ir_attrs).result

  def visit_Name(self, ast_node):
    """Resolves a variable read to the IR value last bound to that name.

    Only the function's local name map is consulted. Unbound names abort.
    """
    if not isinstance(ast_node.ctx, ast.Load):
      self.fctx.abort("Unsupported expression name context type %s" %
                      ast_node.ctx.__class__.__name__)
    # TODO: Need to apply scope rules: local, global, ...
    value = self.fctx.local_name_value_map.get(ast_node.id)
    if value is None:
      self.fctx.abort("Local variable '%s' has not been assigned" %
                      ast_node.id)
    self.value = value

  if sys.version_info < (3, 8, 0):
    # Before 3.8 the parser emits a distinct node class per literal kind.
    # Route them all through visit_Constant, which accepts only numbers.
    visit_Num = visit_Constant
    visit_Str = visit_Constant
    visit_Bytes = visit_Constant
    visit_NameConstant = visit_Constant
    visit_Ellipsis = visit_Constant
  else:
    # From 3.8 the parser emits only ast.Constant, which NodeVisitor
    # dispatches to visit_Constant above. These legacy hooks are therefore
    # not expected to be reached, and are kept as no-ops.
    def ignore(self, ast_node):
      pass

    visit_Num = ignore
    visit_Str = ignore
    visit_Bytes = ignore
    visit_NameConstant = ignore
    visit_Ellipsis = ignore
class EmittedError(Exception):
  """Signals that an error diagnostic has already been emitted.

  Raising this lets code abort the import and handle the failure at a
  higher level without reporting the same problem a second time. The
  location the diagnostic was emitted at is kept in `loc`.
  """

  def __init__(self, loc, message):
    self.loc = loc
    text = "%s (at %r)" % (message, loc)
    super().__init__(text)