mirror of https://github.com/llvm/torch-mlir
189 lines
6.2 KiB
Python
189 lines
6.2 KiB
Python
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
# See https://llvm.org/LICENSE.txt for license information.
|
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
"""
|
|
Frontend to the compiler, allowing various ways to import code.
|
|
"""
|
|
|
|
import ast
|
|
import inspect
|
|
from typing import Optional
|
|
|
|
from _npcomp.mlir import ir
|
|
from _npcomp.mlir.dialect import ScfDialectHelper
|
|
from npcomp.dialect import Numpy
|
|
|
|
from . import logging
|
|
from .importer import *
|
|
from .interfaces import *
|
|
from .name_resolver_base import *
|
|
from .value_coder_base import *
|
|
from .target import *
|
|
|
|
# Public API of this module.
__all__ = ["ImportFrontend"]
|
|
|
|
|
|
class ImportFrontend:
    """Frontend for importing various entities into a Module.

    Owns an MLIR context, a module created from that context, and a combined
    dialect helper (``AllDialectHelper``) wrapping an ``ir.OpBuilder`` on the
    same context. Imported functions are inserted into the module's first
    block, before its terminator.
    """

    __slots__ = [
        "_ir_context",
        "_ir_module",
        "_ir_h",
        "_config",
    ]

    def __init__(self,
                 *,
                 config: Configuration,
                 ir_context: Optional[ir.MLIRContext] = None):
        """Creates a frontend.

        Args:
          config: Import configuration. Its ``target_factory`` is invoked with
            the dialect helper for every imported function.
          ir_context: Context to import into. A fresh ``ir.MLIRContext`` is
            created when omitted (or ``None``).
        """
        super().__init__()
        self._ir_context = (ir.MLIRContext()
                            if ir_context is None else ir_context)
        self._ir_module = self._ir_context.new_module()
        self._ir_h = AllDialectHelper(self._ir_context,
                                      ir.OpBuilder(self._ir_context))
        self._config = config

    @property
    def ir_context(self):
        """The MLIR context everything is imported into."""
        return self._ir_context

    @property
    def ir_module(self):
        """The module that receives imported functions."""
        return self._ir_module

    @property
    def ir_h(self):
        """The combined dialect helper (Numpy + SCF) bound to ``ir_context``."""
        return self._ir_h

    def import_global_function(self, f):
        """Imports a global function.

        This facility is not general and does not allow customization of the
        containing environment, method import, etc.

        Most errors are emitted via the MLIR context's diagnostic infrastructure,
        but errors related to extracting source, etc are raised directly.

        Args:
          f: The python callable.

        Returns:
          The function op created in ``ir_module``.

        Raises:
          OSError: If the source of ``f`` cannot be retrieved (propagated from
            ``inspect.getsourcelines``).
          ValueError: If the retrieved source does not start with a ``def``
            statement (e.g. ``f`` is a lambda).
          TypeError: If ``f`` lacks ``__code__`` / ``__globals__``.
        """
        h = self.ir_h
        ir_c = self.ir_context
        ir_m = self.ir_module
        target = self._config.target_factory(h)

        # getsourcefile() may return None even when getsourcelines() succeeds
        # (e.g. pseudo-files such as "<string>" held in linecache); fall back to
        # the code object's filename so ast.parse/identifier get a real str.
        filename = inspect.getsourcefile(f) or inspect.getfile(f)
        source_lines, start_lineno = inspect.getsourcelines(f)
        source = "".join(source_lines)
        ast_root = ast.parse(source, filename=filename)
        # Shift line numbers so AST locations refer to the original file
        # rather than to the extracted snippet.
        ast.increment_lineno(ast_root, start_lineno - 1)
        ast_fd = ast_root.body[0]
        if not isinstance(ast_fd, (ast.FunctionDef, ast.AsyncFunctionDef)):
            raise ValueError(
                "Expected the source of {} to be a function definition, got {}".
                format(f, type(ast_fd).__name__))
        filename_ident = ir_c.identifier(filename)

        # Define the function.
        # TODO: Much more needs to be done here (arg/result mapping, etc)
        logging.debug(":::::::")
        logging.debug("::: Importing global function {}:\n{}", ast_fd.name,
                      ast.dump(ast_fd, include_attributes=True))

        # TODO: VERY BAD: Assumes all positional params.
        f_signature = inspect.signature(f)
        f_params = f_signature.parameters
        f_input_types = [
            self._resolve_signature_annotation(target, p.annotation)
            for p in f_params.values()
        ]
        f_return_type = self._resolve_signature_annotation(
            target, f_signature.return_annotation)
        ir_f_type = h.function_type(f_input_types, [f_return_type])

        h.builder.set_file_line_col(filename_ident, ast_fd.lineno,
                                    ast_fd.col_offset)
        h.builder.insert_before_terminator(ir_m.first_block)
        # TODO: Do not hardcode this IREE attribute.
        attrs = ir_c.dictionary_attr({"iree.module.export": ir_c.unit_attr})
        ir_f = h.func_op(ast_fd.name,
                         ir_f_type,
                         create_entry_block=True,
                         attrs=attrs)
        # Bind each Python parameter name to the matching entry-block argument.
        env = self._create_const_global_env(f,
                                            parameter_bindings=zip(
                                                f_params.keys(),
                                                ir_f.first_block.args),
                                            target=target)
        fctx = FunctionContext(ir_c=ir_c,
                               ir_f=ir_f,
                               ir_h=h,
                               filename_ident=filename_ident,
                               environment=env)

        fdimport = FunctionDefImporter(fctx, ast_fd)
        fdimport.import_body()
        return ir_f

    def _create_const_global_env(self, f, parameter_bindings, target):
        """Helper to generate an environment for a global function.

        This is a helper for the very common case and will be wholly insufficient
        for advanced cases, including mutable global state, closures, etc.
        Globals from the module are considered immutable.

        Name resolution order is: locals (``co_varnames``), then the function's
        module globals, then builtins.

        Args:
          f: The python function whose globals/builtins seed the environment.
          parameter_bindings: Iterable of ``(name, value)`` pairs stored into
            the locals resolver before returning.
          target: The import target (currently unused here).

        Returns:
          The populated ``Environment``.

        Raises:
          TypeError: If ``f`` lacks ``__code__`` / ``__globals__``.
        """
        import builtins

        ir_h = self._ir_h
        try:
            code = f.__code__
            globals_dict = f.__globals__
        except AttributeError as e:
            # Raise explicitly: an ``assert`` would be stripped under -O and
            # the code below would then fail with an UnboundLocalError.
            raise TypeError(
                "Function {} does not have required user-defined function attributes".
                format(f)) from e

        # CPython sets ``__builtins__`` to the builtins module in ``__main__``
        # but to that module's ``__dict__`` in imported modules; it may also be
        # absent from hand-built globals. Resolve the dict form as a dict.
        builtins_ref = globals_dict.get("__builtins__", builtins)
        if isinstance(builtins_ref, dict):
            builtins_resolver = ConstModuleNameResolver(builtins_ref,
                                                        as_dict=True)
        else:
            builtins_resolver = ConstModuleNameResolver(builtins_ref)

        # Locals resolver.
        # Note that co_varnames should include both parameter and local names.
        locals_resolver = LocalNameResolver(code.co_varnames)
        resolvers = (
            locals_resolver,
            ConstModuleNameResolver(globals_dict, as_dict=True),
            builtins_resolver,
        )
        env = Environment(config=self._config, ir_h=ir_h, name_resolvers=resolvers)

        # Bind parameters.
        for name, value in parameter_bindings:
            logging.debug("STORE PARAM: {} <- {}", name, value)
            locals_resolver.checked_resolve_name(name).store(env, value)
        return env

    def _resolve_signature_annotation(self, target: Target, annot):
        """Maps a Python signature annotation to an IR type.

        Missing annotations and any annotation other than ``int``, ``float``,
        ``bool`` or ``str`` map to ``basicpy_UnknownType``. ``int``/``float``
        use the target's implementation types.
        """
        ir_h = self._ir_h
        if annot is inspect.Signature.empty:
            return ir_h.basicpy_UnknownType

        # TODO: Do something real here once we need more than the primitive types.
        # Identity comparisons (not a dict lookup) so unhashable annotation
        # objects are tolerated.
        if annot is int:
            return target.impl_int_type
        elif annot is float:
            return target.impl_float_type
        elif annot is bool:
            return ir_h.basicpy_BoolType
        elif annot is str:
            return ir_h.basicpy_StrType
        else:
            return ir_h.basicpy_UnknownType
|
|
|
|
|
|
################################################################################
|
|
# Support
|
|
################################################################################
|
|
|
|
|
|
# TODO: Remove this hack in favor of a helper function that combines
# multiple dialect helpers so that we don't need to deal with the sharp
# edge of initializing multiple native base classes.
class AllDialectHelper(Numpy.DialectHelper, ScfDialectHelper):
    """Dialect helper that exposes both the Numpy and SCF helper APIs.

    Each base's ``__init__`` is called explicitly, Numpy first and then SCF,
    with identical constructor arguments.
    """

    def __init__(self, *args, **kwargs):
        for base_cls in (Numpy.DialectHelper, ScfDialectHelper):
            base_cls.__init__(self, *args, **kwargs)
|