mirror of https://github.com/llvm/torch-mlir
Factor name resolution and constant creation to a new environment facility.
parent
2242f48228
commit
f791909a25
|
@ -0,0 +1,23 @@
|
|||
# XFAIL: *
|
||||
# RUN: %PYTHON %s | npcomp-opt -split-input-file | FileCheck %s --dump-input=fail
|
||||
|
||||
from npcomp.compiler.frontend import *
|
||||
|
||||
|
||||
def import_global(f):
|
||||
fe = ImportFrontend()
|
||||
fe.import_global_function(f)
|
||||
print("// -----")
|
||||
print(fe.ir_module.to_asm())
|
||||
return f
|
||||
|
||||
|
||||
OUTER_ONE = 1
|
||||
OUTER_STRING = "Hello"
|
||||
|
||||
|
||||
# CHECK-LABEL: func @outer_one
|
||||
@import_global
|
||||
def outer_one():
|
||||
return OUTER_ONE
|
||||
|
|
@ -0,0 +1,230 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import inspect
|
||||
from typing import Optional, Union
|
||||
|
||||
from _npcomp.mlir import ir
|
||||
|
||||
from . import logging
|
||||
from .target import *
|
||||
|
||||
__all__ = [
|
||||
"BuiltinsValueCoder",
|
||||
"Environment",
|
||||
"NameReference",
|
||||
"NameResolver",
|
||||
"ValueCoder",
|
||||
"ValueCoderChain",
|
||||
]
|
||||
|
||||
|
||||
class ValueCoder:
|
||||
"""Encodes values in various ways.
|
||||
|
||||
Instances are designed to be daisy-chained and should ignore types that they
|
||||
don't understand. Functions return NotImplemented if they cannot handle a
|
||||
case locally.
|
||||
"""
|
||||
__slots__ = []
|
||||
|
||||
def create_const(self, env: "Environment", py_value):
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class ValueCoderChain(ValueCoder):
|
||||
"""Codes values by delegating to sub-coders in order."""
|
||||
__slots__ = ["_sub_coders"]
|
||||
|
||||
def __init__(self, sub_coders):
|
||||
self._sub_coders = sub_coders
|
||||
|
||||
def create_const(self, env: "Environment", py_value):
|
||||
for sc in self._sub_coders:
|
||||
result = sc.create_const(env, py_value)
|
||||
if result is not NotImplemented:
|
||||
return result
|
||||
return NotImplemented
|
||||
|
||||
|
||||
class NameReference:
|
||||
"""Abstract base class for performing operations on a name."""
|
||||
__slots__ = [
|
||||
"name",
|
||||
]
|
||||
|
||||
def __init__(self, name):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
|
||||
def load(self, env: "Environment",
|
||||
ir_h: ir.DialectHelper) -> Optional[ir.Value]:
|
||||
"""Loads the IR Value associated with the name.
|
||||
|
||||
The load may either be direct, returning an existing value or
|
||||
side-effecting, causing a read from an external context.
|
||||
|
||||
Args:
|
||||
ir_h: The dialect helper used to emit code.
|
||||
Returns:
|
||||
An SSA value containing the resolved value (or None if not bound).
|
||||
Raises:
|
||||
NotImplementedError if load is not supported for this name.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def store(self, env: "Environment", value: ir.Value, ir_h: ir.DialectHelper):
|
||||
"""Stores a new value into the name.
|
||||
|
||||
A subsequent call to 'load' should yield the same value, subject to
|
||||
typing constraints on value equality.
|
||||
|
||||
Args:
|
||||
value: The new value to store into the name.
|
||||
ir_h: The dialect helper used to emit code.
|
||||
Raises:
|
||||
NotImplementedError if store is not supported for this name.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class NameResolver:
|
||||
"""Abstract base class that can resolve a name.
|
||||
|
||||
Name resolvers are typically stacked.
|
||||
"""
|
||||
|
||||
def checked_lookup(self, name):
|
||||
ref = self.lookup(name)
|
||||
assert ref is not None, "Lookup of name {} is required".format(name)
|
||||
return ref
|
||||
|
||||
def lookup(self, name) -> Optional[NameReference]:
|
||||
return None
|
||||
|
||||
|
||||
class Environment(NameResolver):
|
||||
"""Manages access to the environment of a code region.
|
||||
|
||||
This encapsulates name lookup, access to the containing module, etc.
|
||||
"""
|
||||
__slots__ = [
|
||||
"ir_h",
|
||||
"_name_resolvers",
|
||||
"target",
|
||||
"value_coder",
|
||||
]
|
||||
|
||||
def __init__(self,
|
||||
ir_h: ir.DialectHelper,
|
||||
*,
|
||||
target: Target,
|
||||
name_resolvers=(),
|
||||
value_coder):
|
||||
super().__init__()
|
||||
self.ir_h = ir_h
|
||||
self.target = target
|
||||
self._name_resolvers = name_resolvers
|
||||
self.value_coder = value_coder
|
||||
|
||||
@classmethod
|
||||
def for_const_global_function(cls, ir_h: ir.DialectHelper, f, *,
|
||||
parameter_bindings, **kwargs):
|
||||
"""Helper to generate an environment for a global function.
|
||||
|
||||
This is a helper for the very common case and will be wholly insufficient
|
||||
for advanced cases, including mutable global state, closures, etc.
|
||||
Globals from the module are considered immutable.
|
||||
"""
|
||||
try:
|
||||
code = f.__code__
|
||||
except AttributeError:
|
||||
assert False, "Function {} does not have a __code__ attribute".format(f)
|
||||
|
||||
# Locals resolver.
|
||||
# Note that co_varnames should include both parameter and local names.
|
||||
locals_resolver = LocalNameResolver(code.co_varnames)
|
||||
resolvers = (locals_resolver,)
|
||||
env = cls(ir_h, name_resolvers=resolvers, **kwargs)
|
||||
|
||||
# Bind parameters.
|
||||
for name, value in parameter_bindings:
|
||||
logging.debug("STORE PARAM: {} <- {}", name, value)
|
||||
locals_resolver.checked_lookup(name).store(env, value)
|
||||
return env
|
||||
|
||||
def lookup(self, name) -> Optional[NameReference]:
|
||||
for resolver in self._name_resolvers:
|
||||
ref = resolver.lookup(name)
|
||||
if ref is not None:
|
||||
return ref
|
||||
return None
|
||||
|
||||
|
||||
class LocalNameReference(NameReference):
|
||||
"""Holds an association between a name and SSA value."""
|
||||
__slots__ = [
|
||||
"_current_value",
|
||||
]
|
||||
|
||||
def __init__(self, name, initial_value=None):
|
||||
super().__init__(name)
|
||||
self._current_value = initial_value
|
||||
|
||||
def load(self, env: "Environment") -> Optional[ir.Value]:
|
||||
return self._current_value
|
||||
|
||||
def store(self, env: "Environment", value: ir.Value):
|
||||
self._current_value = value
|
||||
|
||||
def __repr__(self):
|
||||
return "<LocalNameReference({})>".format(self.name)
|
||||
|
||||
|
||||
class LocalNameResolver(NameResolver):
|
||||
"""Resolves names in a local cache of SSA values.
|
||||
|
||||
This is used to manage locals and arguments (that are not referenced through
|
||||
a closure).
|
||||
"""
|
||||
__slots__ = [
|
||||
"_name_refs",
|
||||
]
|
||||
|
||||
def __init__(self, names):
|
||||
super().__init__()
|
||||
self._name_refs = {name: LocalNameReference(name) for name in names}
|
||||
|
||||
def lookup(self, name) -> Optional[NameReference]:
|
||||
return self._name_refs.get(name)
|
||||
|
||||
|
||||
class BuiltinsValueCoder:
|
||||
"""Value coder for builtin python types."""
|
||||
__slots__ = []
|
||||
|
||||
def create_const(self, env: "Environment", py_value):
|
||||
ir_h = env.ir_h
|
||||
ir_c = ir_h.context
|
||||
if py_value is True:
|
||||
return ir_h.basicpy_bool_constant_op(True).result
|
||||
elif py_value is False:
|
||||
return ir_h.basicpy_bool_constant_op(False).result
|
||||
elif py_value is None:
|
||||
return ir_h.basicpy_singleton_op(ir_h.basicpy_NoneType).result
|
||||
elif isinstance(py_value, int):
|
||||
ir_type = env.target.impl_int_type
|
||||
ir_attr = ir_c.integer_attr(ir_type, py_value)
|
||||
return ir_h.constant_op(ir_type, ir_attr).result
|
||||
elif isinstance(py_value, float):
|
||||
ir_type = env.target.impl_float_type
|
||||
ir_attr = ir_c.float_attr(ir_type, py_value)
|
||||
return ir_h.constant_op(ir_type, ir_attr).result
|
||||
elif isinstance(py_value, str):
|
||||
return ir_h.basicpy_str_constant_op(py_value).result
|
||||
elif isinstance(py_value, bytes):
|
||||
return ir_h.basicpy_bytes_constant_op(py_value).result
|
||||
elif isinstance(py_value, type(...)):
|
||||
return ir_h.basicpy_singleton_op(ir_h.basicpy_EllipsisType).result
|
||||
return NotImplemented
|
|
@ -7,12 +7,14 @@ Frontend to the compiler, allowing various ways to import code.
|
|||
|
||||
import ast
|
||||
import inspect
|
||||
from typing import Optional
|
||||
|
||||
from _npcomp.mlir import ir
|
||||
from _npcomp.mlir.dialect import ScfDialectHelper
|
||||
from npcomp.dialect import Numpy
|
||||
|
||||
from . import logging
|
||||
from .environment import *
|
||||
from .importer import *
|
||||
from .target import *
|
||||
|
||||
|
@ -38,16 +40,20 @@ class ImportFrontend:
|
|||
"_ir_module",
|
||||
"_helper",
|
||||
"_target_factory",
|
||||
"_value_coder",
|
||||
]
|
||||
|
||||
def __init__(self,
|
||||
ir_context: ir.MLIRContext = None,
|
||||
target_factory: TargetFactory = GenericTarget64):
|
||||
*,
|
||||
target_factory: TargetFactory = GenericTarget64,
|
||||
value_coder: Optional[ValueCoder] = None):
|
||||
self._ir_context = ir.MLIRContext() if not ir_context else ir_context
|
||||
self._ir_module = self._ir_context.new_module()
|
||||
self._helper = AllDialectHelper(self._ir_context,
|
||||
ir.OpBuilder(self._ir_context))
|
||||
self._target_factory = target_factory
|
||||
self._value_coder = value_coder if value_coder else BuiltinsValueCoder()
|
||||
|
||||
@property
|
||||
def ir_context(self):
|
||||
|
@ -111,13 +117,19 @@ class ImportFrontend:
|
|||
ir_f_type,
|
||||
create_entry_block=True,
|
||||
attrs=attrs)
|
||||
env = Environment.for_const_global_function(h,
|
||||
f,
|
||||
parameter_bindings=zip(
|
||||
f_params.keys(),
|
||||
ir_f.first_block.args),
|
||||
value_coder=self._value_coder,
|
||||
target=target)
|
||||
fctx = FunctionContext(ir_c=ir_c,
|
||||
ir_f=ir_f,
|
||||
ir_h=h,
|
||||
filename_ident=filename_ident,
|
||||
target=target)
|
||||
for f_arg, ir_arg in zip(f_params.keys(), ir_f.first_block.args):
|
||||
fctx.map_local_name(f_arg, ir_arg)
|
||||
target=target,
|
||||
environment=env)
|
||||
|
||||
fdimport = FunctionDefImporter(fctx, ast_fd)
|
||||
fdimport.import_body()
|
||||
|
|
|
@ -10,6 +10,7 @@ import sys
|
|||
from _npcomp.mlir import ir
|
||||
|
||||
from . import logging
|
||||
from .environment import *
|
||||
from .target import *
|
||||
|
||||
__all__ = [
|
||||
|
@ -27,16 +28,16 @@ class FunctionContext:
|
|||
"ir_h",
|
||||
"target",
|
||||
"filename_ident",
|
||||
"local_name_value_map",
|
||||
"environment",
|
||||
]
|
||||
|
||||
def __init__(self, ir_c, ir_f, ir_h, target, filename_ident):
|
||||
def __init__(self, ir_c, ir_f, ir_h, target, filename_ident, environment):
|
||||
self.ir_c = ir_c
|
||||
self.ir_f = ir_f
|
||||
self.ir_h = ir_h
|
||||
self.target = target
|
||||
self.filename_ident = filename_ident
|
||||
self.local_name_value_map = dict()
|
||||
self.environment = environment
|
||||
|
||||
def abort(self, message):
|
||||
"""Emits an error diagnostic and raises an exception to abort."""
|
||||
|
@ -52,9 +53,12 @@ class FunctionContext:
|
|||
self.ir_h.builder.set_file_line_col(self.filename_ident, ast_node.lineno,
|
||||
ast_node.col_offset)
|
||||
|
||||
def map_local_name(self, name, value):
|
||||
self.local_name_value_map[name] = value
|
||||
logging.debug("Map name({}) -> value({})", name, value)
|
||||
def lookup_name(self, name) -> NameReference:
|
||||
ref = self.environment.lookup(name)
|
||||
if ref is None:
|
||||
self.abort("Could not resolve referenced name '{}'".format(name))
|
||||
logging.debug("Map name({}) -> {}", name, ref)
|
||||
return ref
|
||||
|
||||
|
||||
class BaseNodeVisitor(ast.NodeVisitor):
|
||||
|
@ -109,7 +113,13 @@ class FunctionDefImporter(BaseNodeVisitor):
|
|||
# TODO: Del, AugStore, etc
|
||||
self.fctx.abort("Unsupported assignment context type %s" %
|
||||
target.ctx.__class__.__name__)
|
||||
self.fctx.map_local_name(target.id, expr.value)
|
||||
name_ref = self.fctx.lookup_name(target.id)
|
||||
try:
|
||||
name_ref.store(self.fctx.environment, expr.value)
|
||||
logging.debug("STORE: {} <- {}", name_ref, expr.value)
|
||||
except NotImplementedError:
|
||||
self.fctx.abort(
|
||||
"Cannot assign to '{}': Store not supported".format(name_ref))
|
||||
|
||||
def visit_Expr(self, ast_node):
|
||||
ir_h = self.fctx.ir_h
|
||||
|
@ -155,30 +165,11 @@ class ExpressionImporter(BaseNodeVisitor):
|
|||
return sub_importer.value
|
||||
|
||||
def emit_constant(self, value):
|
||||
ir_c = self.fctx.ir_c
|
||||
ir_h = self.fctx.ir_h
|
||||
if value is True:
|
||||
self.value = ir_h.basicpy_bool_constant_op(True).result
|
||||
elif value is False:
|
||||
self.value = ir_h.basicpy_bool_constant_op(False).result
|
||||
elif value is None:
|
||||
self.value = ir_h.basicpy_singleton_op(ir_h.basicpy_NoneType).result
|
||||
elif isinstance(value, int):
|
||||
ir_type = self._int_type
|
||||
ir_attr = ir_c.integer_attr(ir_type, value)
|
||||
self.value = ir_h.constant_op(ir_type, ir_attr).result
|
||||
elif isinstance(value, float):
|
||||
ir_type = self._float_type
|
||||
ir_attr = ir_c.float_attr(ir_type, value)
|
||||
self.value = ir_h.constant_op(ir_type, ir_attr).result
|
||||
elif isinstance(value, str):
|
||||
self.value = ir_h.basicpy_str_constant_op(value).result
|
||||
elif isinstance(value, bytes):
|
||||
self.value = ir_h.basicpy_bytes_constant_op(value).result
|
||||
elif isinstance(value, type(...)):
|
||||
self.value = ir_h.basicpy_singleton_op(ir_h.basicpy_EllipsisType).result
|
||||
else:
|
||||
env = self.fctx.environment
|
||||
ir_const_value = env.value_coder.create_const(env, value)
|
||||
if ir_const_value is NotImplemented:
|
||||
self.fctx.abort("unknown constant type '%r'" % (value,))
|
||||
self.value = ir_const_value
|
||||
|
||||
def visit_BinOp(self, ast_node):
|
||||
ir_h = self.fctx.ir_h
|
||||
|
@ -290,10 +281,11 @@ class ExpressionImporter(BaseNodeVisitor):
|
|||
if not isinstance(ast_node.ctx, ast.Load):
|
||||
self.fctx.abort("Unsupported expression name context type %s" %
|
||||
ast_node.ctx.__class__.__name__)
|
||||
# TODO: Need to apply scope rules: local, global, ...
|
||||
value = self.fctx.local_name_value_map.get(ast_node.id)
|
||||
name_ref = self.fctx.lookup_name(ast_node.id)
|
||||
value = name_ref.load(self.fctx.environment)
|
||||
logging.debug("LOAD {} -> {}", name_ref, value)
|
||||
if value is None:
|
||||
self.fctx.abort("Local variable '%s' has not been assigned" % ast_node.id)
|
||||
self.fctx.abort("Name reference '{}' cannot be loaded".format(name_ref))
|
||||
self.value = value
|
||||
|
||||
def visit_UnaryOp(self, ast_node):
|
||||
|
|
Loading…
Reference in New Issue