mirror of https://github.com/llvm/torch-mlir
Factor name resolution and constant creation to a new environment facility.
parent
2242f48228
commit
f791909a25
|
@ -0,0 +1,23 @@
|
||||||
|
# XFAIL: *
|
||||||
|
# RUN: %PYTHON %s | npcomp-opt -split-input-file | FileCheck %s --dump-input=fail
|
||||||
|
|
||||||
|
from npcomp.compiler.frontend import *
|
||||||
|
|
||||||
|
|
||||||
|
def import_global(f):
|
||||||
|
fe = ImportFrontend()
|
||||||
|
fe.import_global_function(f)
|
||||||
|
print("// -----")
|
||||||
|
print(fe.ir_module.to_asm())
|
||||||
|
return f
|
||||||
|
|
||||||
|
|
||||||
|
OUTER_ONE = 1
|
||||||
|
OUTER_STRING = "Hello"
|
||||||
|
|
||||||
|
|
||||||
|
# CHECK-LABEL: func @outer_one
|
||||||
|
@import_global
|
||||||
|
def outer_one():
|
||||||
|
return OUTER_ONE
|
||||||
|
|
|
@ -0,0 +1,230 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from _npcomp.mlir import ir
|
||||||
|
|
||||||
|
from . import logging
|
||||||
|
from .target import *
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BuiltinsValueCoder",
|
||||||
|
"Environment",
|
||||||
|
"NameReference",
|
||||||
|
"NameResolver",
|
||||||
|
"ValueCoder",
|
||||||
|
"ValueCoderChain",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class ValueCoder:
|
||||||
|
"""Encodes values in various ways.
|
||||||
|
|
||||||
|
Instances are designed to be daisy-chained and should ignore types that they
|
||||||
|
don't understand. Functions return NotImplemented if they cannot handle a
|
||||||
|
case locally.
|
||||||
|
"""
|
||||||
|
__slots__ = []
|
||||||
|
|
||||||
|
def create_const(self, env: "Environment", py_value):
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
|
||||||
|
class ValueCoderChain(ValueCoder):
|
||||||
|
"""Codes values by delegating to sub-coders in order."""
|
||||||
|
__slots__ = ["_sub_coders"]
|
||||||
|
|
||||||
|
def __init__(self, sub_coders):
|
||||||
|
self._sub_coders = sub_coders
|
||||||
|
|
||||||
|
def create_const(self, env: "Environment", py_value):
|
||||||
|
for sc in self._sub_coders:
|
||||||
|
result = sc.create_const(env, py_value)
|
||||||
|
if result is not NotImplemented:
|
||||||
|
return result
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
|
||||||
|
class NameReference:
|
||||||
|
"""Abstract base class for performing operations on a name."""
|
||||||
|
__slots__ = [
|
||||||
|
"name",
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, name):
|
||||||
|
super().__init__()
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
def load(self, env: "Environment",
|
||||||
|
ir_h: ir.DialectHelper) -> Optional[ir.Value]:
|
||||||
|
"""Loads the IR Value associated with the name.
|
||||||
|
|
||||||
|
The load may either be direct, returning an existing value or
|
||||||
|
side-effecting, causing a read from an external context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ir_h: The dialect helper used to emit code.
|
||||||
|
Returns:
|
||||||
|
An SSA value containing the resolved value (or None if not bound).
|
||||||
|
Raises:
|
||||||
|
NotImplementedError if load is not supported for this name.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def store(self, env: "Environment", value: ir.Value, ir_h: ir.DialectHelper):
|
||||||
|
"""Stores a new value into the name.
|
||||||
|
|
||||||
|
A subsequent call to 'load' should yield the same value, subject to
|
||||||
|
typing constraints on value equality.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: The new value to store into the name.
|
||||||
|
ir_h: The dialect helper used to emit code.
|
||||||
|
Raises:
|
||||||
|
NotImplementedError if store is not supported for this name.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class NameResolver:
|
||||||
|
"""Abstract base class that can resolve a name.
|
||||||
|
|
||||||
|
Name resolvers are typically stacked.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def checked_lookup(self, name):
|
||||||
|
ref = self.lookup(name)
|
||||||
|
assert ref is not None, "Lookup of name {} is required".format(name)
|
||||||
|
return ref
|
||||||
|
|
||||||
|
def lookup(self, name) -> Optional[NameReference]:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class Environment(NameResolver):
|
||||||
|
"""Manages access to the environment of a code region.
|
||||||
|
|
||||||
|
This encapsulates name lookup, access to the containing module, etc.
|
||||||
|
"""
|
||||||
|
__slots__ = [
|
||||||
|
"ir_h",
|
||||||
|
"_name_resolvers",
|
||||||
|
"target",
|
||||||
|
"value_coder",
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
ir_h: ir.DialectHelper,
|
||||||
|
*,
|
||||||
|
target: Target,
|
||||||
|
name_resolvers=(),
|
||||||
|
value_coder):
|
||||||
|
super().__init__()
|
||||||
|
self.ir_h = ir_h
|
||||||
|
self.target = target
|
||||||
|
self._name_resolvers = name_resolvers
|
||||||
|
self.value_coder = value_coder
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def for_const_global_function(cls, ir_h: ir.DialectHelper, f, *,
|
||||||
|
parameter_bindings, **kwargs):
|
||||||
|
"""Helper to generate an environment for a global function.
|
||||||
|
|
||||||
|
This is a helper for the very common case and will be wholly insufficient
|
||||||
|
for advanced cases, including mutable global state, closures, etc.
|
||||||
|
Globals from the module are considered immutable.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
code = f.__code__
|
||||||
|
except AttributeError:
|
||||||
|
assert False, "Function {} does not have a __code__ attribute".format(f)
|
||||||
|
|
||||||
|
# Locals resolver.
|
||||||
|
# Note that co_varnames should include both parameter and local names.
|
||||||
|
locals_resolver = LocalNameResolver(code.co_varnames)
|
||||||
|
resolvers = (locals_resolver,)
|
||||||
|
env = cls(ir_h, name_resolvers=resolvers, **kwargs)
|
||||||
|
|
||||||
|
# Bind parameters.
|
||||||
|
for name, value in parameter_bindings:
|
||||||
|
logging.debug("STORE PARAM: {} <- {}", name, value)
|
||||||
|
locals_resolver.checked_lookup(name).store(env, value)
|
||||||
|
return env
|
||||||
|
|
||||||
|
def lookup(self, name) -> Optional[NameReference]:
|
||||||
|
for resolver in self._name_resolvers:
|
||||||
|
ref = resolver.lookup(name)
|
||||||
|
if ref is not None:
|
||||||
|
return ref
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class LocalNameReference(NameReference):
|
||||||
|
"""Holds an association between a name and SSA value."""
|
||||||
|
__slots__ = [
|
||||||
|
"_current_value",
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, name, initial_value=None):
|
||||||
|
super().__init__(name)
|
||||||
|
self._current_value = initial_value
|
||||||
|
|
||||||
|
def load(self, env: "Environment") -> Optional[ir.Value]:
|
||||||
|
return self._current_value
|
||||||
|
|
||||||
|
def store(self, env: "Environment", value: ir.Value):
|
||||||
|
self._current_value = value
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "<LocalNameReference({})>".format(self.name)
|
||||||
|
|
||||||
|
|
||||||
|
class LocalNameResolver(NameResolver):
|
||||||
|
"""Resolves names in a local cache of SSA values.
|
||||||
|
|
||||||
|
This is used to manage locals and arguments (that are not referenced through
|
||||||
|
a closure).
|
||||||
|
"""
|
||||||
|
__slots__ = [
|
||||||
|
"_name_refs",
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, names):
|
||||||
|
super().__init__()
|
||||||
|
self._name_refs = {name: LocalNameReference(name) for name in names}
|
||||||
|
|
||||||
|
def lookup(self, name) -> Optional[NameReference]:
|
||||||
|
return self._name_refs.get(name)
|
||||||
|
|
||||||
|
|
||||||
|
class BuiltinsValueCoder:
|
||||||
|
"""Value coder for builtin python types."""
|
||||||
|
__slots__ = []
|
||||||
|
|
||||||
|
def create_const(self, env: "Environment", py_value):
|
||||||
|
ir_h = env.ir_h
|
||||||
|
ir_c = ir_h.context
|
||||||
|
if py_value is True:
|
||||||
|
return ir_h.basicpy_bool_constant_op(True).result
|
||||||
|
elif py_value is False:
|
||||||
|
return ir_h.basicpy_bool_constant_op(False).result
|
||||||
|
elif py_value is None:
|
||||||
|
return ir_h.basicpy_singleton_op(ir_h.basicpy_NoneType).result
|
||||||
|
elif isinstance(py_value, int):
|
||||||
|
ir_type = env.target.impl_int_type
|
||||||
|
ir_attr = ir_c.integer_attr(ir_type, py_value)
|
||||||
|
return ir_h.constant_op(ir_type, ir_attr).result
|
||||||
|
elif isinstance(py_value, float):
|
||||||
|
ir_type = env.target.impl_float_type
|
||||||
|
ir_attr = ir_c.float_attr(ir_type, py_value)
|
||||||
|
return ir_h.constant_op(ir_type, ir_attr).result
|
||||||
|
elif isinstance(py_value, str):
|
||||||
|
return ir_h.basicpy_str_constant_op(py_value).result
|
||||||
|
elif isinstance(py_value, bytes):
|
||||||
|
return ir_h.basicpy_bytes_constant_op(py_value).result
|
||||||
|
elif isinstance(py_value, type(...)):
|
||||||
|
return ir_h.basicpy_singleton_op(ir_h.basicpy_EllipsisType).result
|
||||||
|
return NotImplemented
|
|
@ -7,12 +7,14 @@ Frontend to the compiler, allowing various ways to import code.
|
||||||
|
|
||||||
import ast
|
import ast
|
||||||
import inspect
|
import inspect
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from _npcomp.mlir import ir
|
from _npcomp.mlir import ir
|
||||||
from _npcomp.mlir.dialect import ScfDialectHelper
|
from _npcomp.mlir.dialect import ScfDialectHelper
|
||||||
from npcomp.dialect import Numpy
|
from npcomp.dialect import Numpy
|
||||||
|
|
||||||
from . import logging
|
from . import logging
|
||||||
|
from .environment import *
|
||||||
from .importer import *
|
from .importer import *
|
||||||
from .target import *
|
from .target import *
|
||||||
|
|
||||||
|
@ -38,16 +40,20 @@ class ImportFrontend:
|
||||||
"_ir_module",
|
"_ir_module",
|
||||||
"_helper",
|
"_helper",
|
||||||
"_target_factory",
|
"_target_factory",
|
||||||
|
"_value_coder",
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
ir_context: ir.MLIRContext = None,
|
ir_context: ir.MLIRContext = None,
|
||||||
target_factory: TargetFactory = GenericTarget64):
|
*,
|
||||||
|
target_factory: TargetFactory = GenericTarget64,
|
||||||
|
value_coder: Optional[ValueCoder] = None):
|
||||||
self._ir_context = ir.MLIRContext() if not ir_context else ir_context
|
self._ir_context = ir.MLIRContext() if not ir_context else ir_context
|
||||||
self._ir_module = self._ir_context.new_module()
|
self._ir_module = self._ir_context.new_module()
|
||||||
self._helper = AllDialectHelper(self._ir_context,
|
self._helper = AllDialectHelper(self._ir_context,
|
||||||
ir.OpBuilder(self._ir_context))
|
ir.OpBuilder(self._ir_context))
|
||||||
self._target_factory = target_factory
|
self._target_factory = target_factory
|
||||||
|
self._value_coder = value_coder if value_coder else BuiltinsValueCoder()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ir_context(self):
|
def ir_context(self):
|
||||||
|
@ -111,13 +117,19 @@ class ImportFrontend:
|
||||||
ir_f_type,
|
ir_f_type,
|
||||||
create_entry_block=True,
|
create_entry_block=True,
|
||||||
attrs=attrs)
|
attrs=attrs)
|
||||||
|
env = Environment.for_const_global_function(h,
|
||||||
|
f,
|
||||||
|
parameter_bindings=zip(
|
||||||
|
f_params.keys(),
|
||||||
|
ir_f.first_block.args),
|
||||||
|
value_coder=self._value_coder,
|
||||||
|
target=target)
|
||||||
fctx = FunctionContext(ir_c=ir_c,
|
fctx = FunctionContext(ir_c=ir_c,
|
||||||
ir_f=ir_f,
|
ir_f=ir_f,
|
||||||
ir_h=h,
|
ir_h=h,
|
||||||
filename_ident=filename_ident,
|
filename_ident=filename_ident,
|
||||||
target=target)
|
target=target,
|
||||||
for f_arg, ir_arg in zip(f_params.keys(), ir_f.first_block.args):
|
environment=env)
|
||||||
fctx.map_local_name(f_arg, ir_arg)
|
|
||||||
|
|
||||||
fdimport = FunctionDefImporter(fctx, ast_fd)
|
fdimport = FunctionDefImporter(fctx, ast_fd)
|
||||||
fdimport.import_body()
|
fdimport.import_body()
|
||||||
|
|
|
@ -10,6 +10,7 @@ import sys
|
||||||
from _npcomp.mlir import ir
|
from _npcomp.mlir import ir
|
||||||
|
|
||||||
from . import logging
|
from . import logging
|
||||||
|
from .environment import *
|
||||||
from .target import *
|
from .target import *
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
@ -27,16 +28,16 @@ class FunctionContext:
|
||||||
"ir_h",
|
"ir_h",
|
||||||
"target",
|
"target",
|
||||||
"filename_ident",
|
"filename_ident",
|
||||||
"local_name_value_map",
|
"environment",
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, ir_c, ir_f, ir_h, target, filename_ident):
|
def __init__(self, ir_c, ir_f, ir_h, target, filename_ident, environment):
|
||||||
self.ir_c = ir_c
|
self.ir_c = ir_c
|
||||||
self.ir_f = ir_f
|
self.ir_f = ir_f
|
||||||
self.ir_h = ir_h
|
self.ir_h = ir_h
|
||||||
self.target = target
|
self.target = target
|
||||||
self.filename_ident = filename_ident
|
self.filename_ident = filename_ident
|
||||||
self.local_name_value_map = dict()
|
self.environment = environment
|
||||||
|
|
||||||
def abort(self, message):
|
def abort(self, message):
|
||||||
"""Emits an error diagnostic and raises an exception to abort."""
|
"""Emits an error diagnostic and raises an exception to abort."""
|
||||||
|
@ -52,9 +53,12 @@ class FunctionContext:
|
||||||
self.ir_h.builder.set_file_line_col(self.filename_ident, ast_node.lineno,
|
self.ir_h.builder.set_file_line_col(self.filename_ident, ast_node.lineno,
|
||||||
ast_node.col_offset)
|
ast_node.col_offset)
|
||||||
|
|
||||||
def map_local_name(self, name, value):
|
def lookup_name(self, name) -> NameReference:
|
||||||
self.local_name_value_map[name] = value
|
ref = self.environment.lookup(name)
|
||||||
logging.debug("Map name({}) -> value({})", name, value)
|
if ref is None:
|
||||||
|
self.abort("Could not resolve referenced name '{}'".format(name))
|
||||||
|
logging.debug("Map name({}) -> {}", name, ref)
|
||||||
|
return ref
|
||||||
|
|
||||||
|
|
||||||
class BaseNodeVisitor(ast.NodeVisitor):
|
class BaseNodeVisitor(ast.NodeVisitor):
|
||||||
|
@ -109,7 +113,13 @@ class FunctionDefImporter(BaseNodeVisitor):
|
||||||
# TODO: Del, AugStore, etc
|
# TODO: Del, AugStore, etc
|
||||||
self.fctx.abort("Unsupported assignment context type %s" %
|
self.fctx.abort("Unsupported assignment context type %s" %
|
||||||
target.ctx.__class__.__name__)
|
target.ctx.__class__.__name__)
|
||||||
self.fctx.map_local_name(target.id, expr.value)
|
name_ref = self.fctx.lookup_name(target.id)
|
||||||
|
try:
|
||||||
|
name_ref.store(self.fctx.environment, expr.value)
|
||||||
|
logging.debug("STORE: {} <- {}", name_ref, expr.value)
|
||||||
|
except NotImplementedError:
|
||||||
|
self.fctx.abort(
|
||||||
|
"Cannot assign to '{}': Store not supported".format(name_ref))
|
||||||
|
|
||||||
def visit_Expr(self, ast_node):
|
def visit_Expr(self, ast_node):
|
||||||
ir_h = self.fctx.ir_h
|
ir_h = self.fctx.ir_h
|
||||||
|
@ -155,30 +165,11 @@ class ExpressionImporter(BaseNodeVisitor):
|
||||||
return sub_importer.value
|
return sub_importer.value
|
||||||
|
|
||||||
def emit_constant(self, value):
|
def emit_constant(self, value):
|
||||||
ir_c = self.fctx.ir_c
|
env = self.fctx.environment
|
||||||
ir_h = self.fctx.ir_h
|
ir_const_value = env.value_coder.create_const(env, value)
|
||||||
if value is True:
|
if ir_const_value is NotImplemented:
|
||||||
self.value = ir_h.basicpy_bool_constant_op(True).result
|
|
||||||
elif value is False:
|
|
||||||
self.value = ir_h.basicpy_bool_constant_op(False).result
|
|
||||||
elif value is None:
|
|
||||||
self.value = ir_h.basicpy_singleton_op(ir_h.basicpy_NoneType).result
|
|
||||||
elif isinstance(value, int):
|
|
||||||
ir_type = self._int_type
|
|
||||||
ir_attr = ir_c.integer_attr(ir_type, value)
|
|
||||||
self.value = ir_h.constant_op(ir_type, ir_attr).result
|
|
||||||
elif isinstance(value, float):
|
|
||||||
ir_type = self._float_type
|
|
||||||
ir_attr = ir_c.float_attr(ir_type, value)
|
|
||||||
self.value = ir_h.constant_op(ir_type, ir_attr).result
|
|
||||||
elif isinstance(value, str):
|
|
||||||
self.value = ir_h.basicpy_str_constant_op(value).result
|
|
||||||
elif isinstance(value, bytes):
|
|
||||||
self.value = ir_h.basicpy_bytes_constant_op(value).result
|
|
||||||
elif isinstance(value, type(...)):
|
|
||||||
self.value = ir_h.basicpy_singleton_op(ir_h.basicpy_EllipsisType).result
|
|
||||||
else:
|
|
||||||
self.fctx.abort("unknown constant type '%r'" % (value,))
|
self.fctx.abort("unknown constant type '%r'" % (value,))
|
||||||
|
self.value = ir_const_value
|
||||||
|
|
||||||
def visit_BinOp(self, ast_node):
|
def visit_BinOp(self, ast_node):
|
||||||
ir_h = self.fctx.ir_h
|
ir_h = self.fctx.ir_h
|
||||||
|
@ -290,10 +281,11 @@ class ExpressionImporter(BaseNodeVisitor):
|
||||||
if not isinstance(ast_node.ctx, ast.Load):
|
if not isinstance(ast_node.ctx, ast.Load):
|
||||||
self.fctx.abort("Unsupported expression name context type %s" %
|
self.fctx.abort("Unsupported expression name context type %s" %
|
||||||
ast_node.ctx.__class__.__name__)
|
ast_node.ctx.__class__.__name__)
|
||||||
# TODO: Need to apply scope rules: local, global, ...
|
name_ref = self.fctx.lookup_name(ast_node.id)
|
||||||
value = self.fctx.local_name_value_map.get(ast_node.id)
|
value = name_ref.load(self.fctx.environment)
|
||||||
|
logging.debug("LOAD {} -> {}", name_ref, value)
|
||||||
if value is None:
|
if value is None:
|
||||||
self.fctx.abort("Local variable '%s' has not been assigned" % ast_node.id)
|
self.fctx.abort("Name reference '{}' cannot be loaded".format(name_ref))
|
||||||
self.value = value
|
self.value = value
|
||||||
|
|
||||||
def visit_UnaryOp(self, ast_node):
|
def visit_UnaryOp(self, ast_node):
|
||||||
|
|
Loading…
Reference in New Issue