torch-mlir/python/npcomp/compiler/importer.py

450 lines
16 KiB
Python
Raw Normal View History

2020-06-10 08:16:36 +08:00
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""
Importers for populating MLIR from AST.
"""
import ast
import sys
import traceback
2020-06-10 08:16:36 +08:00
from _npcomp.mlir import ir
from . import logging
from .environment import *
from .target import *
2020-06-10 08:16:36 +08:00
# Names exported by `from <this module> import *`. Other definitions in this
# file remain importable by explicit name.
__all__ = [
"FunctionContext",
"FunctionDefImporter",
"ExpressionImporter",
]
class FunctionContext:
    """Accounting information for importing a function.

    Bundles the IR handles, the target, the source file identifier and the
    name environment used while one function is imported, and provides the
    diagnostic, name-lookup and constant-emission helpers shared by the
    importers.
    """
    __slots__ = [
        "ir_c",
        "ir_f",
        "ir_h",
        "target",
        "filename_ident",
        "environment",
    ]

    def __init__(self, ir_c, ir_f, ir_h, target, filename_ident, environment):
        # Stored in the same order as __slots__.
        (self.ir_c, self.ir_f, self.ir_h, self.target, self.filename_ident,
         self.environment) = (ir_c, ir_f, ir_h, target, filename_ident,
                              environment)

    @property
    def current_loc(self):
        """The location currently held by the IR helper's builder."""
        builder = self.ir_h.builder
        return builder.current_loc

    def update_loc(self, ast_node):
        """Sets the builder location to the file/line/column of `ast_node`."""
        line, column = ast_node.lineno, ast_node.col_offset
        self.ir_h.builder.set_file_line_col(self.filename_ident, line, column)

    def abort(self, message):
        """Emits an error diagnostic at the current location and raises.

        Raises:
          EmittedError: always, carrying the location and `message`.
        """
        where = self.current_loc
        ir.emit_error(where, message)
        raise EmittedError(where, message)

    def check_macro_evaluated(self, result: MacroEvalResult):
        """Aborts unless `result` represents a successful macro evaluation.

        An ERROR result is reported together with the formatted traceback
        built from `result.yields`; a NOT_EVALUATED result is reported as an
        unevaluable expression. Any other result type passes silently.
        """
        outcome = result.type
        if outcome == MacroEvalType.ERROR:
            # `yields` is unpacked as the arguments of format_exception.
            formatted = "".join(traceback.format_exception(*result.yields))
            self.abort("Error while evaluating value from environment:\n" +
                       formatted)
        elif outcome == MacroEvalType.NOT_EVALUATED:
            self.abort("Unable to evaluate expression")

    def lookup_name(self, name) -> NameReference:
        """Resolves `name` in the environment, aborting if it is unknown."""
        resolved = self.environment.lookup(name)
        if resolved is None:
            self.abort("Could not resolve referenced name '{}'".format(name))
        logging.debug("Map name({}) -> {}", name, resolved)
        return resolved

    def emit_const_value(self, py_value) -> ir.Value:
        """Codes `py_value` as a constant and returns the resulting IR value.

        Aborts when the environment's value coder answers NotImplemented.
        """
        environment = self.environment
        coded = environment.value_coder.create_const(environment, py_value)
        if coded is NotImplemented:
            self.abort(
                "Cannot code python value as constant: {}".format(py_value))
        return coded

    def emit_macro_result(self, macro_result: MacroEvalResult) -> ir.Value:
        """Turns a macro result into an IR value.

        A result that already yields an IR value is returned directly; a live
        python value is imported as a constant. Anything else aborts.
        """
        self.check_macro_evaluated(macro_result)
        outcome = macro_result.type
        if outcome == MacroEvalType.YIELDS_IR_VALUE:
            return macro_result.yields
        if outcome == MacroEvalType.YIELDS_LIVE_VALUE:
            return self.emit_const_value(macro_result.yields.live_value)
        self.abort("Unhandled macro result type {}".format(macro_result))
2020-06-10 08:16:36 +08:00
class BaseNodeVisitor(ast.NodeVisitor):
    """Node visitor base class that rejects every node it has no handler for.

    Each dispatch first records the node's source position on the function
    context. Subclasses set IMPORTER_TYPE so the "unhandled node" diagnostic
    names the kind of importer that failed.
    """
    IMPORTER_TYPE = "<unknown>"

    def __init__(self, fctx):
        super().__init__()
        self.fctx = fctx

    def visit(self, node):
        """Updates the current location to `node`, then dispatches on it."""
        context = self.fctx
        context.update_loc(node)
        return super().visit(node)

    def generic_visit(self, ast_node):
        """Fallback for nodes without a `visit_*` method: log and abort."""
        logging.debug("UNHANDLED NODE: {}", ast.dump(ast_node))
        node_kind = type(ast_node).__name__
        self.fctx.abort("unhandled python %s AST node '%s'" %
                        (self.IMPORTER_TYPE, node_kind))
class FunctionDefImporter(BaseNodeVisitor):
    """AST visitor for importing a function's statements.

    Handles nodes that are direct children of a FunctionDef.
    """
    IMPORTER_TYPE = "statement"

    def __init__(self, fctx, ast_fd):
        super().__init__(fctx)
        self.ast_fd = ast_fd
        # True only while the most recently imported statement was a `return`.
        self._last_was_return = False

    def import_body(self):
        """Imports every statement of the function body in order.

        If the last statement was not a `return`, a terminator returning the
        None singleton (cast to the unknown type) is appended.
        """
        ir_h = self.fctx.ir_h
        for ast_stmt in self.ast_fd.body:
            self._last_was_return = False
            logging.debug("STMT: {}",
                          ast.dump(ast_stmt, include_attributes=True))
            self.visit(ast_stmt)
        if not self._last_was_return:
            # Add a default terminator.
            none_cast = ir_h.basicpy_unknown_cast_op(ir_h.basicpy_UnknownType,
                                                     self._emit_none()).result
            ir_h.return_op([none_cast])

    def _emit_none(self):
        """Emits the None singleton and returns its IR value."""
        ir_h = self.fctx.ir_h
        return ir_h.basicpy_singleton_op(ir_h.basicpy_NoneType).result

    def visit_Assign(self, ast_node):
        """Imports `target [= target ...] = value` for plain name targets.

        The value expression is imported once and stored to each target.
        Targets that are not simple names (tuples, attributes, subscripts,
        ...) and names whose context is not a store abort with a diagnostic.
        """
        expr = ExpressionImporter(self.fctx)
        expr.visit(ast_node.value)
        for target in ast_node.targets:
            self.fctx.update_loc(target)
            if not isinstance(target, ast.Name):
                # Only ast.Name carries the `.id` looked up below; report
                # other target kinds instead of failing on the attribute.
                # TODO: Tuple unpacking, attribute and subscript stores.
                self.fctx.abort("Unsupported assignment target type %s" %
                                target.__class__.__name__)
            if not isinstance(target.ctx, ast.Store):
                # TODO: Del, AugStore, etc
                self.fctx.abort("Unsupported assignment context type %s" %
                                target.ctx.__class__.__name__)
            name_ref = self.fctx.lookup_name(target.id)
            try:
                name_ref.store(self.fctx.environment, expr.value)
            except NotImplementedError:
                self.fctx.abort(
                    "Cannot assign to '{}': Store not supported".format(
                        name_ref))
            else:
                logging.debug("STORE: {} <- {}", name_ref, expr.value)

    def visit_Expr(self, ast_node):
        """Imports an expression statement inside an exec op.

        The expression is evaluated at the insertion point returned by the
        exec op, its value is discarded, and the previous insertion point is
        then restored.
        """
        ir_h = self.fctx.ir_h
        _, ip = ir_h.basicpy_exec_op()
        # Evaluate the expression in the exec body.
        orig_ip = ir_h.builder.insertion_point
        ir_h.builder.insertion_point = ip
        expr = ExpressionImporter(self.fctx)
        expr.visit(ast_node.value)
        ir_h.basicpy_exec_discard_op([expr.value])
        ir_h.builder.insertion_point = orig_ip

    def visit_Pass(self, ast_node):
        """`pass` emits nothing."""
        pass

    def visit_Return(self, ast_node):
        """Imports a `return` statement.

        A bare `return` has no value node, so it returns the None singleton,
        the same value the default terminator in `import_body` produces.
        """
        ir_h = self.fctx.ir_h
        if ast_node.value is None:
            returned = self._emit_none()
        else:
            expr = ExpressionImporter(self.fctx)
            expr.visit(ast_node.value)
            returned = expr.value
        casted = ir_h.basicpy_unknown_cast_op(ir_h.basicpy_UnknownType,
                                              returned).result
        ir_h.return_op([casted])
        self._last_was_return = True
2020-06-10 08:16:36 +08:00
class ExpressionImporter(BaseNodeVisitor):
    """Imports expression nodes.

    Visitor methods should either raise an exception or set self.value to the
    IR value that the expression lowers to.
    """
    IMPORTER_TYPE = "expression"

    def __init__(self, fctx):
        super().__init__(fctx)
        self.value = None

    def visit(self, node):
        """Dispatches on `node` and checks that a value was produced."""
        super().visit(node)
        assert self.value, ("ExpressionImporter did not assign a value (%r)" %
                            (ast.dump(node),))

    def sub_evaluate(self, sub_node):
        """Imports `sub_node` with a fresh importer and returns its value."""
        sub_importer = ExpressionImporter(self.fctx)
        sub_importer.visit(sub_node)
        return sub_importer.value

    def emit_constant(self, value):
        """Sets self.value to `value` coded as a constant, or aborts."""
        env = self.fctx.environment
        ir_const_value = env.value_coder.create_const(env, value)
        if ir_const_value is NotImplemented:
            self.fctx.abort("unknown constant type '%r'" % (value,))
        self.value = ir_const_value

    def visit_Attribute(self, ast_node):
        """Imports an attribute access by resolving it as a macro."""
        # Import the attribute's value recursively as a macro if possible.
        macro_importer = MacroImporter(self.fctx)
        macro_importer.visit(ast_node)
        if macro_importer.macro_result:
            self.fctx.check_macro_evaluated(macro_importer.macro_result)
            self.value = self.fctx.emit_macro_result(
                macro_importer.macro_result)
            return
        self.fctx.abort("unhandled attribute access mode: {}".format(
            ast.dump(ast_node)))

    def visit_BinOp(self, ast_node):
        """Imports a binary expression; the op is named by its AST class."""
        ir_h = self.fctx.ir_h
        left = self.sub_evaluate(ast_node.left)
        right = self.sub_evaluate(ast_node.right)
        self.value = ir_h.basicpy_binary_expr_op(
            ir_h.basicpy_UnknownType, left, right,
            ast_node.op.__class__.__name__).result

    def visit_BoolOp(self, ast_node):
        """Imports `and` / `or` chains as nested, short-circuiting if ops."""
        ir_h = self.fctx.ir_h
        if isinstance(ast_node.op, ast.And):
            return_first_true = False
        elif isinstance(ast_node.op, ast.Or):
            return_first_true = True
        else:
            self.fctx.abort("unknown bool op %r" % (ast.dump(ast_node.op),))

        def emit_next(next_nodes):
            next_node = next_nodes[0]
            next_nodes = next_nodes[1:]
            next_value = self.sub_evaluate(next_node)
            # The last operand is the result when nothing short-circuited.
            if not next_nodes:
                return next_value
            condition_value = ir_h.basicpy_to_boolean_op(next_value).result
            if_op, then_ip, else_ip = ir_h.scf_if_op(
                [ir_h.basicpy_UnknownType], condition_value, True)
            orig_ip = ir_h.builder.insertion_point
            # Short-circuit return case: `or` yields the operand from the
            # then branch, `and` yields it from the else branch.
            ir_h.builder.insertion_point = (then_ip
                                            if return_first_true else else_ip)
            next_value_casted = ir_h.basicpy_unknown_cast_op(
                ir_h.basicpy_UnknownType, next_value).result
            ir_h.scf_yield_op([next_value_casted])
            # Nested evaluate next case.
            ir_h.builder.insertion_point = (else_ip
                                            if return_first_true else then_ip)
            nested_value = emit_next(next_nodes)
            nested_value_casted = ir_h.basicpy_unknown_cast_op(
                ir_h.basicpy_UnknownType, nested_value).result
            ir_h.scf_yield_op([nested_value_casted])
            ir_h.builder.insertion_point = orig_ip
            return if_op.result

        self.value = emit_next(ast_node.values)

    def visit_Compare(self, ast_node):
        """Imports a (possibly chained) comparison."""
        # Short-circuit comparison (degenerates to binary comparison when just
        # two operands).
        ir_h = self.fctx.ir_h
        false_value = ir_h.basicpy_bool_constant_op(False).result

        def emit_next(left_value, comparisons):
            operation, right_node = comparisons[0]
            comparisons = comparisons[1:]
            right_value = self.sub_evaluate(right_node)
            compare_result = ir_h.basicpy_binary_compare_op(
                left_value, right_value, operation.__class__.__name__).result
            # Terminate by yielding the final compare result.
            if not comparisons:
                return compare_result
            # Emit 'if' op and recurse. The if op takes an i1 (core dialect
            # requirement) and returns a basicpy.BoolType. Since this is an
            # 'and', all else clauses yield a false value.
            compare_result_i1 = ir_h.basicpy_bool_cast_op(
                ir_h.i1_type, compare_result).result
            if_op, then_ip, else_ip = ir_h.scf_if_op([ir_h.basicpy_BoolType],
                                                     compare_result_i1, True)
            orig_ip = ir_h.builder.insertion_point
            # Build the else clause.
            ir_h.builder.insertion_point = else_ip
            ir_h.scf_yield_op([false_value])
            # Build the then clause.
            ir_h.builder.insertion_point = then_ip
            nested_result = emit_next(right_value, comparisons)
            ir_h.scf_yield_op([nested_result])
            ir_h.builder.insertion_point = orig_ip
            return if_op.result

        self.value = emit_next(self.sub_evaluate(ast_node.left),
                               list(zip(ast_node.ops, ast_node.comparators)))

    def visit_IfExp(self, ast_node):
        """Imports `body if test else orelse` as an if op yielding a value."""
        ir_h = self.fctx.ir_h
        test_result = ir_h.basicpy_to_boolean_op(
            self.sub_evaluate(ast_node.test)).result
        if_op, then_ip, else_ip = ir_h.scf_if_op([ir_h.basicpy_UnknownType],
                                                 test_result, True)
        orig_ip = ir_h.builder.insertion_point
        # Build the then clause.
        ir_h.builder.insertion_point = then_ip
        then_result = self.sub_evaluate(ast_node.body)
        ir_h.scf_yield_op([
            ir_h.basicpy_unknown_cast_op(ir_h.basicpy_UnknownType,
                                         then_result).result
        ])
        # Build the else clause.
        ir_h.builder.insertion_point = else_ip
        orelse_result = self.sub_evaluate(ast_node.orelse)
        ir_h.scf_yield_op([
            ir_h.basicpy_unknown_cast_op(ir_h.basicpy_UnknownType,
                                         orelse_result).result
        ])
        ir_h.builder.insertion_point = orig_ip
        self.value = if_op.result

    def visit_Name(self, ast_node):
        """Imports a name load by resolving it through the environment."""
        if not isinstance(ast_node.ctx, ast.Load):
            self.fctx.abort("Unsupported expression name context type %s" %
                            ast_node.ctx.__class__.__name__)
        name_ref = self.fctx.lookup_name(ast_node.id)
        macro_result = name_ref.load(self.fctx.environment)
        logging.debug("LOAD {} -> {}", name_ref, macro_result)
        self.value = self.fctx.emit_macro_result(macro_result)

    def visit_UnaryOp(self, ast_node):
        """Imports a unary expression; only logical `not` is supported."""
        ir_h = self.fctx.ir_h
        op = ast_node.op
        if not isinstance(op, ast.Not):
            # Reject before lowering the operand so the diagnostic is
            # attributed to the unary expression itself. abort() takes one
            # pre-formatted message.
            self.fctx.abort("Unknown unary op %r" % (ast.dump(op),))
        operand_value = self.sub_evaluate(ast_node.operand)
        # Special handling for logical-not: select the inverted constant.
        condition_value = ir_h.basicpy_to_boolean_op(operand_value).result
        true_value = ir_h.basicpy_bool_constant_op(True).result
        false_value = ir_h.basicpy_bool_constant_op(False).result
        self.value = ir_h.select_op(condition_value, false_value,
                                    true_value).result

    if sys.version_info < (3, 8, 0):
        # <3.8 breaks these out into separate AST classes.
        def visit_Num(self, ast_node):
            self.emit_constant(ast_node.n)

        def visit_Str(self, ast_node):
            self.emit_constant(ast_node.s)

        def visit_Bytes(self, ast_node):
            self.emit_constant(ast_node.s)

        def visit_NameConstant(self, ast_node):
            self.emit_constant(ast_node.value)

        def visit_Ellipsis(self, ast_node):
            self.emit_constant(...)
    else:

        def visit_Constant(self, ast_node):
            self.emit_constant(ast_node.value)
class MacroImporter(BaseNodeVisitor):
    """Importer for expressions that can resolve through the environment's
    macro system.

    Concretely this is used for Attribute.value and Call resolution.

    Attribute resolution is not just treated as a normal expression because
    it is first subject to "macro expansion", allowing the environment's
    macro resolution facility to operate on live python values from the
    containing environment versus naively emitting code for attribute
    resolution from entities that can/should be considered constants from
    the hosting context. This is used, for example, to resolve attributes
    from modules by immediately dereferencing/transforming the intervening
    chain of attributes rather than emitting code for each step.
    """
    IMPORTER_TYPE = "macro"

    def __init__(self, fctx):
        super().__init__(fctx)
        self.macro_result = None

    def _evaluate_base(self, base_node):
        """Evaluates the object an attribute is read from.

        Macro evaluation is tried first; if it leaves no result, the node is
        imported as an ordinary expression and wrapped as an IR-value result.
        """
        nested = MacroImporter(self.fctx)
        nested.visit(base_node)
        if nested.macro_result:
            # Macro sub-evaluation successful.
            return nested.macro_result
        # Need to evaluate it as an expression.
        as_expr = ExpressionImporter(self.fctx)
        as_expr.visit(base_node)
        assert as_expr.value, (
            "Macro sub expression did not return a value: %r" % (base_node))
        return MacroEvalResult.yields_ir_value(as_expr.value)

    def visit_Attribute(self, ast_node):
        """Resolves `value.attr`, statically when the base is a live value.

        On success self.macro_result holds the getattr result. Every path
        that would require emitting a dynamic getattr aborts, since that is
        not implemented yet.
        """
        fctx = self.fctx
        base = self._evaluate_base(ast_node.value)
        fctx.check_macro_evaluated(base)
        is_live = base.type == MacroEvalType.YIELDS_LIVE_VALUE
        if not is_live:
            # Yielding an IR value from a recursive macro evaluation means
            # that the entire chain needs to be hoisted to IR.
            ir_value = base.yields
        else:
            # Attempt a static getattr while still operating on a live value.
            logging.debug("STATIC getattr '{}' on {}", ast_node.attr, base)
            resolved = base.yields.resolve_getattr(fctx.environment,
                                                   ast_node.attr)
            if resolved.type != MacroEvalType.NOT_EVALUATED:
                fctx.check_macro_evaluated(resolved)
                self.macro_result = resolved
                return
            # If a non-statically evaluable live value, then convert to a
            # constant and dynamic dispatch.
            ir_value = fctx.emit_const_value(base.yields.live_value)
        # TODO: Implement.
        fctx.abort("dynamic-emitted getattr not yet supported: %r" %
                   (ir_value,))

    def visit_Name(self, ast_node):
        """Loads a name from the environment as the macro result."""
        name_ref = self.fctx.lookup_name(ast_node.id)
        loaded = name_ref.load(self.fctx.environment)
        logging.debug("LOAD MACRO {} -> {}", name_ref, loaded)
        self.macro_result = loaded
2020-06-10 08:16:36 +08:00
class EmittedError(Exception):
    """Raised after an error diagnostic has already been emitted.

    Raising lets the import be aborted and the failure handled at a higher
    level without the diagnostic being reported a second time. The location
    the diagnostic was attached to is kept on `loc`.
    """

    def __init__(self, loc, message):
        self.loc = loc
        super().__init__("{!s} (at {!r})".format(message, loc))