torch-mlir/python/npcomp/compiler/importer.py

486 lines
18 KiB
Python

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""
Importers for populating MLIR from AST.
"""
import ast
import sys
import traceback
from _npcomp.mlir import ir
from . import logging
from .interfaces import *
# Public names of this module: only these three are picked up by
# ``from <module> import *``. BaseNodeVisitor and PartialEvalImporter are not
# listed here.
__all__ = [
"FunctionContext",
"FunctionDefImporter",
"ExpressionImporter",
]
class FunctionContext:
    """Bookkeeping shared by the importers while one function is imported.

    Holds the IR handles (``ir_c``, ``ir_f``, ``ir_h``), the identifier of the
    file being imported and the environment used to resolve names and to code
    python values as constants. Only ``ir_h``, ``filename_ident`` and
    ``environment`` are read by the methods of this class; ``ir_c`` and
    ``ir_f`` are merely stored.
    """

    __slots__ = [
        "ir_c",
        "ir_f",
        "ir_h",
        "filename_ident",
        "environment",
    ]

    def __init__(self, ir_c, ir_f, ir_h, filename_ident, environment):
        self.ir_c = ir_c
        self.ir_f = ir_f
        self.ir_h = ir_h
        self.filename_ident = filename_ident
        self.environment = environment

    def abort(self, message):
        """Emits ``message`` as an error at the current location and raises.

        Raises:
          EmittedError: always, carrying the location and the message.
        """
        where = self.current_loc
        ir.emit_error(where, message)
        raise EmittedError(where, message)

    def check_partial_evaluated(self, result: PartialEvalResult):
        """Raises unless ``result`` holds a successfully evaluated value.

        An ERROR result is reported with the message of the captured
        exception (for a UserReportableError) or with its formatted traceback
        (for anything else). A NOT_EVALUATED result aborts with a generic
        message. Any other result type passes through silently.
        """
        outcome = result.type
        if outcome == PartialEvalType.ERROR:
            # For errors, ``yields`` is an (exc_type, exc_value, tb) triple.
            exc_type, exc_value, tb = result.yields
            where = self.current_loc
            if issubclass(exc_type, UserReportableError):
                message = exc_value.message
            else:
                formatted = traceback.format_exception(exc_type, exc_value, tb)
                message = ("Error while evaluating value from environment:\n" +
                           "".join(formatted))
            ir.emit_error(where, message)
            raise EmittedError(where, message)
        if outcome == PartialEvalType.NOT_EVALUATED:
            self.abort("Unable to evaluate expression")

    @property
    def current_loc(self):
        """The location currently set on the IR builder."""
        return self.ir_h.builder.current_loc

    def update_loc(self, ast_node):
        """Points the builder's location at ``ast_node``'s line and column."""
        builder = self.ir_h.builder
        builder.set_file_line_col(self.filename_ident, ast_node.lineno,
                                  ast_node.col_offset)

    def lookup_name(self, name) -> NameReference:
        """Resolves ``name`` in the environment, aborting if it is unknown."""
        resolved = self.environment.resolve_name(name)
        if resolved is None:
            self.abort("Could not resolve referenced name '{}'".format(name))
        logging.debug("Map name({}) -> {}", name, resolved)
        return resolved

    def emit_const_value(self, py_value) -> ir.Value:
        """Codes ``py_value`` as a constant and returns the resulting value.

        Aborts when the environment answers ``NotImplemented``.
        """
        coded = self.environment.code_py_value_as_const(py_value)
        if coded is NotImplemented:
            self.abort(
                "Cannot code python value as constant: {}".format(py_value))
        return coded

    def emit_partial_eval_result(self,
                                 partial_result: PartialEvalResult) -> ir.Value:
        """Turns a partial eval result into an IR value.

        A result that already yields an IR value is returned as is; a live
        python value is coded as a constant. Error / not-evaluated results and
        any other result type abort.
        """
        self.check_partial_evaluated(partial_result)
        kind = partial_result.type
        if kind == PartialEvalType.YIELDS_IR_VALUE:
            # Already an IR value: hand it back untouched.
            return partial_result.yields
        if kind == PartialEvalType.YIELDS_LIVE_VALUE:
            # Live python value: import it as a constant.
            return self.emit_const_value(partial_result.yields.live_value)
        self.abort(
            "Unhandled partial eval result type {}".format(partial_result))
class BaseNodeVisitor(ast.NodeVisitor):
    """Node visitor for which every node without a handler is an error.

    Subclasses set IMPORTER_TYPE to the word used in the "unhandled node"
    diagnostic and add ``visit_<NodeClass>`` methods for the nodes they
    support.
    """

    IMPORTER_TYPE = "<unknown>"
    __slots__ = [
        "fctx",
    ]

    def __init__(self, fctx):
        super().__init__()
        self.fctx = fctx

    def visit(self, node):
        """Updates the context's location to ``node``, then dispatches."""
        # Location first, so a diagnostic raised by the handler points at the
        # node being processed.
        self.fctx.update_loc(node)
        return super().visit(node)

    def generic_visit(self, ast_node):
        """Fallback for nodes without a handler: aborts via the context."""
        logging.debug("UNHANDLED NODE: {}", ast.dump(ast_node))
        node_kind = ast_node.__class__.__name__
        self.fctx.abort("unhandled python %s AST node '%s'" %
                        (self.IMPORTER_TYPE, node_kind))
class FunctionDefImporter(BaseNodeVisitor):
    """AST visitor for importing a function's statements.

    Handles nodes that are direct children of a FunctionDef. Statement kinds
    without a ``visit_*`` method here fall through to the base class, which
    aborts with an "unhandled node" diagnostic.
    """
    IMPORTER_TYPE = "statement"
    __slots__ = [
        "ast_fd",
        "_last_was_return",
    ]

    def __init__(self, fctx, ast_fd):
        super().__init__(fctx)
        self.ast_fd = ast_fd
        # True only while the most recently imported statement was a return.
        self._last_was_return = False

    def import_body(self):
        """Imports every statement of the function body in order.

        If the last statement was not a return, a default terminator that
        returns the None singleton is appended.
        """
        for ast_stmt in self.ast_fd.body:
            self._last_was_return = False
            logging.debug("STMT: {}",
                          ast.dump(ast_stmt, include_attributes=True))
            self.visit(ast_stmt)
        if not self._last_was_return:
            # Add a default terminator.
            self._emit_return(self._emit_none())

    def _emit_none(self):
        """Emits the None singleton and returns its IR value."""
        ir_h = self.fctx.ir_h
        return ir_h.basicpy_singleton_op(ir_h.basicpy_NoneType).result

    def _emit_return(self, value):
        """Casts ``value`` to the unknown type and emits the return op."""
        ir_h = self.fctx.ir_h
        casted = ir_h.basicpy_unknown_cast_op(ir_h.basicpy_UnknownType,
                                              value).result
        ir_h.return_op([casted])

    def visit_Assign(self, ast_node):
        """Imports ``target = value`` for plain-name targets.

        The right-hand side is imported once and stored to every target.
        Targets that are not plain names (tuples, attributes, subscripts, ...)
        abort with a diagnostic.
        """
        expr = ExpressionImporter(self.fctx)
        expr.visit(ast_node.value)
        for target in ast_node.targets:
            self.fctx.update_loc(target)
            if not isinstance(target, ast.Name):
                # Only ast.Name carries the ``id`` looked up below; anything
                # else would otherwise fail with a raw AttributeError.
                self.fctx.abort("Unsupported assignment target type %s" %
                                target.__class__.__name__)
            if not isinstance(target.ctx, ast.Store):
                # TODO: Del, AugStore, etc
                self.fctx.abort("Unsupported assignment context type %s" %
                                target.ctx.__class__.__name__)
            name_ref = self.fctx.lookup_name(target.id)
            try:
                name_ref.store(self.fctx.environment, expr.value)
                logging.debug("STORE: {} <- {}", name_ref, expr.value)
            except NotImplementedError:
                self.fctx.abort(
                    "Cannot assign to '{}': Store not supported".format(
                        name_ref))

    def visit_Expr(self, ast_node):
        """Imports an expression statement inside an exec op.

        The expression is emitted at the insertion point returned by
        ``basicpy_exec_op`` and its value is passed to the exec-discard op;
        the previous insertion point is restored afterwards.
        """
        ir_h = self.fctx.ir_h
        _, ip = ir_h.basicpy_exec_op()
        # Evaluate the expression in the exec body.
        orig_ip = ir_h.builder.insertion_point
        ir_h.builder.insertion_point = ip
        expr = ExpressionImporter(self.fctx)
        expr.visit(ast_node.value)
        ir_h.basicpy_exec_discard_op([expr.value])
        ir_h.builder.insertion_point = orig_ip

    def visit_Pass(self, ast_node):
        """``pass`` emits nothing."""
        pass

    def visit_Return(self, ast_node):
        """Imports a return statement.

        A bare ``return`` has no value node (``ast_node.value`` is None) and
        returns the None singleton, exactly like the default terminator.
        """
        if ast_node.value is None:
            result = self._emit_none()
        else:
            expr = ExpressionImporter(self.fctx)
            expr.visit(ast_node.value)
            result = expr.value
        self._emit_return(result)
        self._last_was_return = True
class ExpressionImporter(BaseNodeVisitor):
    """Imports expression nodes.

    Visitor methods should either raise an exception or set self.value to the
    IR value that the expression lowers to.
    """
    IMPORTER_TYPE = "expression"
    __slots__ = [
        "value",
    ]

    def __init__(self, fctx):
        super().__init__(fctx)
        self.value = None

    def visit(self, node):
        """Visits ``node`` and checks that the handler assigned a value.

        Nothing is returned; callers read the result from ``self.value``.
        """
        super().visit(node)
        assert self.value, ("ExpressionImporter did not assign a value (%r)" %
                            (ast.dump(node),))

    def sub_evaluate(self, sub_node):
        """Imports ``sub_node`` with a fresh importer and returns its value."""
        sub_importer = ExpressionImporter(self.fctx)
        sub_importer.visit(sub_node)
        return sub_importer.value

    def emit_constant(self, value):
        """Codes the python ``value`` as a constant and stores it as result.

        Aborts when the environment answers ``NotImplemented``.
        """
        env = self.fctx.environment
        ir_const_value = env.code_py_value_as_const(value)
        if ir_const_value is NotImplemented:
            self.fctx.abort("unknown constant type '%r'" % (value,))
        self.value = ir_const_value

    def visit_Attribute(self, ast_node):
        """Imports an attribute access by way of partial evaluation."""
        # Import the attribute's value recursively as a partial eval if
        # possible.
        pe_importer = PartialEvalImporter(self.fctx)
        pe_importer.visit(ast_node)
        if pe_importer.partial_eval_result:
            self.fctx.check_partial_evaluated(pe_importer.partial_eval_result)
            self.value = self.fctx.emit_partial_eval_result(
                pe_importer.partial_eval_result)
            return
        self.fctx.abort("unhandled attribute access mode: {}".format(
            ast.dump(ast_node)))

    def visit_BinOp(self, ast_node):
        """Imports a binary operation; the op is named by its AST class."""
        ir_h = self.fctx.ir_h
        left = self.sub_evaluate(ast_node.left)
        right = self.sub_evaluate(ast_node.right)
        self.value = ir_h.basicpy_binary_expr_op(
            ir_h.basicpy_UnknownType, left, right,
            ast_node.op.__class__.__name__).result

    def visit_BoolOp(self, ast_node):
        """Imports a short-circuiting ``and`` / ``or`` chain.

        Every operand except the last becomes the condition of an if op: one
        branch yields the operand itself (the short-circuit case), the other
        yields the result of evaluating the remaining operands.
        """
        ir_h = self.fctx.ir_h
        if isinstance(ast_node.op, ast.And):
            return_first_true = False
        elif isinstance(ast_node.op, ast.Or):
            return_first_true = True
        else:
            self.fctx.abort("unknown bool op %r" % (ast.dump(ast_node.op)))

        def emit_next(next_nodes):
            next_node = next_nodes[0]
            next_nodes = next_nodes[1:]
            next_value = self.sub_evaluate(next_node)
            if not next_nodes:
                # Last operand: its value is the result of the chain.
                return next_value
            condition_value = ir_h.basicpy_to_boolean_op(next_value).result
            if_op, then_ip, else_ip = ir_h.scf_if_op(
                [ir_h.basicpy_UnknownType], condition_value, True)
            orig_ip = ir_h.builder.insertion_point
            # Short-circuit return case.
            ir_h.builder.insertion_point = (then_ip
                                            if return_first_true else else_ip)
            next_value_casted = ir_h.basicpy_unknown_cast_op(
                ir_h.basicpy_UnknownType, next_value).result
            ir_h.scf_yield_op([next_value_casted])
            # Nested evaluate next case.
            ir_h.builder.insertion_point = (else_ip
                                            if return_first_true else then_ip)
            nested_value = emit_next(next_nodes)
            nested_value_casted = ir_h.basicpy_unknown_cast_op(
                ir_h.basicpy_UnknownType, nested_value).result
            ir_h.scf_yield_op([nested_value_casted])
            ir_h.builder.insertion_point = orig_ip
            return if_op.result

        self.value = emit_next(ast_node.values)

    def visit_Call(self, ast_node):
        """Imports a call whose callee partially evaluates to a live value.

        Arguments are imported first (positional, then keyword), then the
        callee is partially evaluated. Calls that cannot be resolved that way
        abort.
        """
        # Evaluate positional args.
        evaluated_args = []
        for raw_arg in ast_node.args:
            evaluated_args.append(self.sub_evaluate(raw_arg))
        # Evaluate keyword args.
        keyword_args = []
        for raw_kw_arg in ast_node.keywords:
            keyword_args.append(
                (raw_kw_arg.arg, self.sub_evaluate(raw_kw_arg.value)))
        # Perform partial evaluation of the callee.
        callee_importer = PartialEvalImporter(self.fctx)
        callee_importer.visit(ast_node.func)
        callee_result = callee_importer.partial_eval_result
        if (callee_result and
                callee_result.type == PartialEvalType.YIELDS_LIVE_VALUE):
            # This is a function known to the compiler. Perform a template
            # call.
            call_result = callee_result.yields.resolve_call(
                self.fctx.environment, evaluated_args, keyword_args)
            if call_result.type != PartialEvalType.NOT_EVALUATED:
                # Partial evaluation success.
                self.fctx.check_partial_evaluated(call_result)
                self.value = self.fctx.emit_partial_eval_result(call_result)
                return
        # The function is not known to the compiler.
        self.fctx.check_partial_evaluated(callee_result)
        # TODO: Implement first class functions.
        self.fctx.abort(
            "unhandled (potentially first-class function): {}".format(
                ast.dump(ast_node)))

    def visit_Compare(self, ast_node):
        """Imports a (possibly chained) comparison."""
        # Short-circuit comparison (degenerates to binary comparison when just
        # two operands).
        ir_h = self.fctx.ir_h
        false_value = ir_h.basicpy_bool_constant_op(False).result

        def emit_next(left_value, comparisons):
            operation, right_node = comparisons[0]
            comparisons = comparisons[1:]
            right_value = self.sub_evaluate(right_node)
            compare_result = ir_h.basicpy_binary_compare_op(
                left_value, right_value, operation.__class__.__name__).result
            # Terminate by yielding the final compare result.
            if not comparisons:
                return compare_result
            # Emit 'if' op and recurse. The if op takes an i1 (core dialect
            # requirement) and returns a basicpy.BoolType. Since this is an
            # 'and', all else clauses yield a false value.
            compare_result_i1 = ir_h.basicpy_bool_cast_op(
                ir_h.i1_type, compare_result).result
            if_op, then_ip, else_ip = ir_h.scf_if_op([ir_h.basicpy_BoolType],
                                                     compare_result_i1, True)
            orig_ip = ir_h.builder.insertion_point
            # Build the else clause.
            ir_h.builder.insertion_point = else_ip
            ir_h.scf_yield_op([false_value])
            # Build the then clause.
            ir_h.builder.insertion_point = then_ip
            nested_result = emit_next(right_value, comparisons)
            ir_h.scf_yield_op([nested_result])
            ir_h.builder.insertion_point = orig_ip
            return if_op.result

        self.value = emit_next(self.sub_evaluate(ast_node.left),
                               list(zip(ast_node.ops, ast_node.comparators)))

    def visit_IfExp(self, ast_node):
        """Imports ``body if test else orelse`` as an if op with two yields."""
        ir_h = self.fctx.ir_h
        test_result = ir_h.basicpy_to_boolean_op(
            self.sub_evaluate(ast_node.test)).result
        if_op, then_ip, else_ip = ir_h.scf_if_op([ir_h.basicpy_UnknownType],
                                                 test_result, True)
        orig_ip = ir_h.builder.insertion_point
        # Build the then clause.
        ir_h.builder.insertion_point = then_ip
        then_result = self.sub_evaluate(ast_node.body)
        ir_h.scf_yield_op([
            ir_h.basicpy_unknown_cast_op(ir_h.basicpy_UnknownType,
                                         then_result).result
        ])
        # Build the else clause.
        ir_h.builder.insertion_point = else_ip
        orelse_result = self.sub_evaluate(ast_node.orelse)
        ir_h.scf_yield_op([
            ir_h.basicpy_unknown_cast_op(ir_h.basicpy_UnknownType,
                                         orelse_result).result
        ])
        ir_h.builder.insertion_point = orig_ip
        self.value = if_op.result

    def visit_Name(self, ast_node):
        """Imports a name load; any other name context aborts."""
        if not isinstance(ast_node.ctx, ast.Load):
            self.fctx.abort("Unsupported expression name context type %s" %
                            ast_node.ctx.__class__.__name__)
        name_ref = self.fctx.lookup_name(ast_node.id)
        pe_result = name_ref.load(self.fctx.environment)
        logging.debug("LOAD {} -> {}", name_ref, pe_result)
        self.value = self.fctx.emit_partial_eval_result(pe_result)

    def visit_UnaryOp(self, ast_node):
        """Imports a unary operation; only logical ``not`` is supported."""
        ir_h = self.fctx.ir_h
        op = ast_node.op
        operand_value = self.sub_evaluate(ast_node.operand)
        if isinstance(op, ast.Not):
            # Special handling for logical-not: select False when the operand
            # converts to true, True otherwise.
            condition_value = ir_h.basicpy_to_boolean_op(operand_value).result
            true_value = ir_h.basicpy_bool_constant_op(True).result
            false_value = ir_h.basicpy_bool_constant_op(False).result
            self.value = ir_h.select_op(condition_value, false_value,
                                        true_value).result
        else:
            # abort() takes a single, already formatted message.
            self.fctx.abort("Unknown unary op %r" % (ast.dump(op),))

    if sys.version_info < (3, 8, 0):
        # <3.8 breaks these out into separate AST classes.
        def visit_Num(self, ast_node):
            self.emit_constant(ast_node.n)

        def visit_Str(self, ast_node):
            self.emit_constant(ast_node.s)

        def visit_Bytes(self, ast_node):
            self.emit_constant(ast_node.s)

        def visit_NameConstant(self, ast_node):
            self.emit_constant(ast_node.value)

        def visit_Ellipsis(self, ast_node):
            self.emit_constant(...)
    else:

        def visit_Constant(self, ast_node):
            self.emit_constant(ast_node.value)
class PartialEvalImporter(BaseNodeVisitor):
    """Importer for performing greedy partial evaluation.

    Concretely this is used for Attribute.value and Call resolution.
    Attribute resolution is not just treated as a normal expression because it
    is first subject to "partial evaluation", allowing the environment's
    partial eval hook to operate on live python values from the containing
    environment versus naively emitting code for attribute resolution for
    entities that can/should be considered constants from the hosting context.
    This is used, for example, to resolve attributes from modules without
    immediately dereferencing/transforming the intervening chain of
    attributes.

    The outcome of a visit is left in ``partial_eval_result``.
    """
    IMPORTER_TYPE = "partial_eval"
    __slots__ = [
        "partial_eval_result",
    ]

    def __init__(self, fctx):
        super().__init__(fctx)
        self.partial_eval_result = None

    def _partial_eval_base(self, base_node):
        """Evaluates the object an attribute is read from.

        Partial evaluation is tried first; if it leaves no result the node is
        imported as a regular expression and wrapped as an IR-value result.
        """
        base_importer = PartialEvalImporter(self.fctx)
        base_importer.visit(base_node)
        if base_importer.partial_eval_result:
            # Partial sub-evaluation successful.
            return base_importer.partial_eval_result
        # Need to evaluate it as an expression.
        expr_importer = ExpressionImporter(self.fctx)
        expr_importer.visit(base_node)
        assert expr_importer.value, (
            "Evaluated sub expression did not return a value: %r" % base_node)
        return PartialEvalResult.yields_ir_value(expr_importer.value)

    def visit_Attribute(self, ast_node):
        """Resolves ``value.attr`` statically when ``value`` is a live value.

        On success the getattr result is stored in ``partial_eval_result``.
        Every other path ends in an abort, since dynamically emitted getattr
        is not implemented.
        """
        # Sub-evaluate the 'value'.
        base_result = self._partial_eval_base(ast_node.value)
        # Attempt to perform a static getattr as a partial eval if still
        # operating on a live value.
        self.fctx.check_partial_evaluated(base_result)
        if base_result.type == PartialEvalType.YIELDS_LIVE_VALUE:
            logging.debug("STATIC getattr '{}' on {}", ast_node.attr,
                          base_result)
            attr_result = base_result.yields.resolve_getattr(
                self.fctx.environment, ast_node.attr)
            if attr_result.type != PartialEvalType.NOT_EVALUATED:
                self.fctx.check_partial_evaluated(attr_result)
                self.partial_eval_result = attr_result
                return
            # If a non-statically evaluable live value, then convert to a
            # constant and dynamic dispatch.
            ir_value = self.fctx.emit_const_value(
                base_result.yields.live_value)
        else:
            ir_value = base_result.yields
        # Yielding an IR value from a recursive partial evaluation means that
        # the entire chain needs to be hoisted to IR.
        # TODO: Implement.
        self.fctx.abort("dynamic-emitted getattr not yet supported: %r" %
                        (ir_value,))

    def visit_Name(self, ast_node):
        """Stores the result of loading the referenced name."""
        name_ref = self.fctx.lookup_name(ast_node.id)
        loaded = name_ref.load(self.fctx.environment)
        logging.debug("PARTIAL EVAL {} -> {}", name_ref, loaded)
        self.partial_eval_result = loaded