torch-mlir/python/npcomp/compiler/frontend.py

290 lines
9.2 KiB
Python
Raw Normal View History

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""
Frontend to the compiler, allowing various ways to import code.
"""
import ast
import inspect
import sys
from _npcomp.mlir import ir
from _npcomp.mlir.dialect import ScfDialectHelper
from npcomp.dialect import Numpy
from . import logging
# Names exported by `from <this module> import *`.
__all__ = [
"ImportFrontend",
]
# TODO: Remove this hack in favor of a helper function that combines
# multiple dialect helpers so that we don't need to deal with the sharp
# edge of initializing multiple native base classes.
class AllDialectHelper(Numpy.DialectHelper, ScfDialectHelper):
    """Dialect helper combining the Numpy and SCF dialect helpers.

    Each base class is initialized explicitly, in declaration order, with the
    same arguments instead of going through a cooperative ``super()`` chain.
    """

    def __init__(self, *args, **kwargs):
        for helper_base in (Numpy.DialectHelper, ScfDialectHelper):
            helper_base.__init__(self, *args, **kwargs)
class ImportFrontend:
    """Frontend for importing various entities into a Module."""

    def __init__(self, ir_context: ir.MLIRContext = None):
        """Sets up the context, module and dialect helper used for imports.

        Args:
          ir_context: The context to import into. A new `ir.MLIRContext` is
            created when this is None.
        """
        # Use an identity check so that the decision to create a default
        # context never depends on the truthiness of a supplied context.
        if ir_context is None:
            ir_context = ir.MLIRContext()
        self._ir_context = ir_context
        self._ir_module = self._ir_context.new_module()
        self._helper = AllDialectHelper(self._ir_context,
                                        ir.OpBuilder(self._ir_context))

    @property
    def ir_context(self):
        """The context that entities are imported into."""
        return self._ir_context

    @property
    def ir_module(self):
        """The module created from the context at construction time."""
        return self._ir_module

    @property
    def ir_h(self):
        """The `AllDialectHelper` used to build ops."""
        return self._helper

    def import_global_function(self, f):
        """Imports a global function.

        This facility is not general and does not allow customization of the
        containing environment, method import, etc.

        Most errors are emitted via the MLIR context's diagnostic
        infrastructure, but errors related to extracting source, etc are
        raised directly.

        Args:
          f: The python callable.

        Returns:
          The function op created for `f` by the dialect helper.
        """
        h = self.ir_h
        ir_c = self.ir_context
        ir_m = self.ir_module
        filename = inspect.getsourcefile(f)
        source_lines, start_lineno = inspect.getsourcelines(f)
        source = "".join(source_lines)
        ast_root = ast.parse(source, filename=filename)
        # The snippet is parsed on its own, so shift its line numbers to match
        # the function's position in the original file.
        ast.increment_lineno(ast_root, start_lineno - 1)
        # NOTE(review): assumes the first parsed statement is the function
        # definition itself - confirm for callables such as lambdas.
        ast_fd = ast_root.body[0]
        filename_ident = ir_c.identifier(filename)

        # Define the function.
        # TODO: Much more needs to be done here (arg/result mapping, etc)
        logging.debug(":::::::")
        logging.debug("::: Importing global function {}:\n{}", ast_fd.name,
                      ast.dump(ast_fd, include_attributes=True))
        h.builder.set_file_line_col(filename_ident, ast_fd.lineno,
                                    ast_fd.col_offset)
        h.builder.insert_before_terminator(ir_m.first_block)
        ir_f_type = h.function_type([], [h.basicpy_UnknownType])
        ir_f = h.func_op(ast_fd.name, ir_f_type, create_entry_block=True)
        fctx = FunctionContext(ir_c=ir_c,
                               ir_f=ir_f,
                               ir_h=h,
                               filename_ident=filename_ident)
        fdimport = FunctionDefImporter(fctx, ast_fd)
        fdimport.import_body()
        return ir_f
class FunctionContext:
    """Bookkeeping shared by the importers working on a single function."""

    __slots__ = [
        "ir_c",
        "ir_f",
        "ir_h",
        "filename_ident",
        "local_name_value_map",
    ]

    def __init__(self, ir_c, ir_f, ir_h, filename_ident):
        self.ir_c, self.ir_f, self.ir_h = ir_c, ir_f, ir_h
        self.filename_ident = filename_ident
        # Local names mapped so far, keyed by identifier.
        self.local_name_value_map = {}

    def abort(self, message):
        """Emits an error diagnostic and raises an exception to abort."""
        where = self.current_loc
        ir.emit_error(where, message)
        raise EmittedError(where, message)

    @property
    def current_loc(self):
        """The location currently held by the helper's builder."""
        builder = self.ir_h.builder
        return builder.current_loc

    def update_loc(self, ast_node):
        """Sets the builder location to the line/column of `ast_node`."""
        builder = self.ir_h.builder
        line, column = ast_node.lineno, ast_node.col_offset
        builder.set_file_line_col(self.filename_ident, line, column)

    def map_local_name(self, name, value):
        """Records `value` as the current binding of local `name`."""
        bindings = self.local_name_value_map
        bindings[name] = value
        logging.debug("Map name({}) -> value({})", name, value)
class BaseNodeVisitor(ast.NodeVisitor):
    """Node visitor base class that aborts on any node it has no handler for."""

    IMPORTER_TYPE = "<unknown>"

    def __init__(self, fctx):
        super().__init__()
        self.fctx = fctx

    def visit(self, node):
        """Updates the function context's location, then dispatches `node`."""
        context = self.fctx
        context.update_loc(node)
        result = super().visit(node)
        return result

    def generic_visit(self, ast_node):
        """Fallback for nodes without a `visit_*` handler: log and abort."""
        logging.debug("UNHANDLED NODE: {}", ast.dump(ast_node))
        node_kind = ast_node.__class__.__name__
        message = "unhandled python %s AST node '%s'" % (self.IMPORTER_TYPE,
                                                         node_kind)
        self.fctx.abort(message)
class FunctionDefImporter(BaseNodeVisitor):
    """AST visitor for importing a function's statements.

    Handles nodes that are direct children of a FunctionDef.
    """
    IMPORTER_TYPE = "statement"

    def __init__(self, fctx, ast_fd):
        super().__init__(fctx)
        self.ast_fd = ast_fd

    def import_body(self):
        """Visits every statement in the function body, in source order."""
        for ast_stmt in self.ast_fd.body:
            logging.debug("STMT: {}",
                          ast.dump(ast_stmt, include_attributes=True))
            self.visit(ast_stmt)

    def visit_Assign(self, ast_node):
        """Imports the assigned expression and binds it to each target name.

        Only plain name targets are supported; anything else is reported
        through the function context's `abort`.
        """
        expr = ExpressionImporter(self.fctx)
        expr.visit(ast_node.value)
        for target in ast_node.targets:
            self.fctx.update_loc(target)
            # Tuple, attribute and subscript targets also carry a Store
            # context but have no `id`, so reject them with a diagnostic
            # instead of failing later with an AttributeError.
            if not isinstance(target, ast.Name):
                self.fctx.abort("Unsupported assignment target type %s" %
                                target.__class__.__name__)
            if not isinstance(target.ctx, ast.Store):
                # TODO: Del, AugStore, etc
                self.fctx.abort("Unsupported assignment context type %s" %
                                target.ctx.__class__.__name__)
            self.fctx.map_local_name(target.id, expr.value)

    def visit_Return(self, ast_node):
        """Imports the returned expression and emits the return op."""
        ir_h = self.fctx.ir_h
        expr = ExpressionImporter(self.fctx)
        if ast_node.value is None:
            # A bare `return` has no value node; it returns None in Python.
            expr.emit_constant(None)
        else:
            expr.visit(ast_node.value)
        # The function type declares a single unknown-typed result, so cast
        # whatever was produced to that type before returning it.
        casted = ir_h.basicpy_unknown_cast_op(ir_h.basicpy_UnknownType,
                                              expr.value).result
        ir_h.return_op([casted])
class ExpressionImporter(BaseNodeVisitor):
    """AST visitor that imports a single expression.

    After a successful `visit`, the IR value produced for the expression is
    available as `self.value`.
    """
    IMPORTER_TYPE = "expression"

    def __init__(self, fctx):
        super().__init__(fctx)
        self.value = None

    def visit(self, node):
        """Visits `node` and checks that its handler assigned `self.value`."""
        super().visit(node)
        # Compare against None explicitly (as visit_Name does) so the check
        # does not depend on the truthiness of the produced value.
        assert self.value is not None, (
            "ExpressionImporter did not assign a value (%r)" %
            (ast.dump(node),))

    def emit_constant(self, value):
        """Emits the op for a Python constant and stores its result.

        Args:
          value: The Python constant. Unsupported constant types are reported
            through the function context's `abort`.
        """
        ir_c = self.fctx.ir_c
        ir_h = self.fctx.ir_h
        # The bool singletons are matched by identity before the int check
        # because bool is a subclass of int.
        if value is True:
            self.value = ir_h.basicpy_bool_constant_op(True).result
        elif value is False:
            self.value = ir_h.basicpy_bool_constant_op(False).result
        elif value is None:
            self.value = ir_h.basicpy_singleton_op(
                ir_h.basicpy_NoneType).result
        elif isinstance(value, int):
            # TODO: Configurable type mapping
            ir_type = ir_h.i64_type
            ir_attr = ir_c.integer_attr(ir_type, value)
            self.value = ir_h.constant_op(ir_type, ir_attr).result
        elif isinstance(value, float):
            # TODO: Configurable type mapping
            ir_type = ir_h.f64_type
            ir_attr = ir_c.float_attr(ir_type, value)
            self.value = ir_h.constant_op(ir_type, ir_attr).result
        elif isinstance(value, str):
            self.value = ir_h.basicpy_str_constant_op(value).result
        elif isinstance(value, bytes):
            self.value = ir_h.basicpy_bytes_constant_op(value).result
        elif isinstance(value, type(...)):
            self.value = ir_h.basicpy_singleton_op(
                ir_h.basicpy_EllipsisType).result
        else:
            self.fctx.abort("unknown constant type '%r'" % (value,))

    def visit_BinOp(self, ast_node):
        """Imports both operands and emits a binary expression op."""
        ir_h = self.fctx.ir_h
        left = ExpressionImporter(self.fctx)
        left.visit(ast_node.left)
        right = ExpressionImporter(self.fctx)
        right.visit(ast_node.right)
        # The operator is identified by the name of its AST class.
        self.value = ir_h.basicpy_binary_expr_op(
            ir_h.basicpy_UnknownType, left.value, right.value,
            ast_node.op.__class__.__name__).result

    def visit_Compare(self, ast_node):
        """Imports a comparison with exactly one operator."""
        ir_h = self.fctx.ir_h
        if len(ast_node.ops) != 1:
            self.fctx.abort("unsupported short-circuit comparison")
        left = ExpressionImporter(self.fctx)
        left.visit(ast_node.left)
        right = ExpressionImporter(self.fctx)
        right.visit(ast_node.comparators[0])
        self.value = ir_h.basicpy_binary_compare_op(
            left.value, right.value,
            ast_node.ops[0].__class__.__name__).result

    def visit_Name(self, ast_node):
        """Resolves a name load against the function's mapped local names."""
        if not isinstance(ast_node.ctx, ast.Load):
            self.fctx.abort("Unsupported expression name context type %s" %
                            ast_node.ctx.__class__.__name__)
        # TODO: Need to apply scope rules: local, global, ...
        value = self.fctx.local_name_value_map.get(ast_node.id)
        if value is None:
            self.fctx.abort("Local variable '%s' has not been assigned" %
                            ast_node.id)
        self.value = value

    if sys.version_info < (3, 8, 0):
        # <3.8 breaks these out into separate AST classes.
        def visit_Num(self, ast_node):
            self.emit_constant(ast_node.n)

        def visit_Str(self, ast_node):
            self.emit_constant(ast_node.s)

        def visit_Bytes(self, ast_node):
            self.emit_constant(ast_node.s)

        def visit_NameConstant(self, ast_node):
            self.emit_constant(ast_node.value)

        def visit_Ellipsis(self, ast_node):
            self.emit_constant(...)
    else:

        def visit_Constant(self, ast_node):
            self.emit_constant(ast_node.value)
class EmittedError(Exception):
    """Signals that an error diagnostic has already been emitted.

    Raising it lets a higher level stop the import and handle the failure
    without reporting the same diagnostic a second time.
    """

    def __init__(self, loc, message):
        self.loc = loc
        description = "%s (at %r)" % (message, loc)
        super().__init__(description)