Add "template function" ops and importer code.

* This starts to lay down the infra for reasoning about calls
* Adds the importer code to generate IR for function calls of compiler recognized static functions.
pull/1/head
Stella Laurenzo 2020-06-26 18:36:36 -07:00
parent 26fd2a576e
commit 7bd5733d38
11 changed files with 415 additions and 7 deletions

View File

@ -143,3 +143,18 @@ This is accomplished with the following `PartialEvalHook` setup:
```
It is expected that this facility will evolve substantially, as it is the primary intended mechanism for remapping significant parts of the python namespace to builtin constructs (i.e. it will be the primary way to map `numpy` functions and values).
## Calls
This is very much a WIP. Relevant ops:
* `func_template`: Aggregates a list of function overloads to choose for a symbolic name.
* `func_template_call`: Performs a symbolic call with python source conventions.
The idea is that a library modules of `func_template` definitions is assembled with all concrete implementations that have compiler support. The python compiler will iterate over all such templates and bind partial evaluation rules in the environment to detect the calls. Then, when importing, `func_template_call` ops make the call.
See the `basicpy.func_template` op for more detailed notes. The intention is that compiler-supported functions, methods, attribute getter/setter, and dunder functions can all exist in the library, with concrete resolution carried out by type constraints and corresponding type inference passes. Upon matches, concrete functions are pulled into the module being compiled and possibly inlined. With enough type constraints and some iteration, this should converge reasonably to a statically typed program.
See the tests:
* [template_call.py](../pytest/Compiler/template_call.py)

View File

@ -16,6 +16,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

View File

@ -10,6 +10,7 @@
#define NPCOMP_DIALECT_BASICPY_IR_BASICPY_OPS
include "BasicpyDialect.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
@ -201,6 +202,127 @@ def Basicpy_ExecDiscardOp : Basicpy_Op<"exec_discard", [
let assemblyFormat = "operands attr-dict `:` type(operands)";
}
def Basicpy_FuncTemplateCallOp : Basicpy_Op<"func_template_call", []> {
let summary = "Calls a function template";
let description = [{
Most function calls start with this generic calling op, which binds
symbolically to a func_template. At this level, there are very few
semantics associated with the call, since, often, both types and the
specific concrete callee cannot be determined.
Per python calling conventions, all functions return one result, even if
None or a tuple (which may be syntactically unpacked to multiple results).
If specified, the `argNames` operand is right aligned to the list of
positional `args`, representing arguments that are special or have been
passed with a keyword. The following arg names are special:
'*': Indicates that the argument is a positional argument pack (must be
the first arg name, if present).
'**': Indicates that the argument is a keyword argument pack (must be
the last arg name, if present).
}];
let arguments = (ins
FlatSymbolRefAttr:$callee,
Variadic<AnyType>:$args,
StrArrayAttr:$arg_names);
let results = (outs AnyType:$result);
let assemblyFormat = "$callee `(` $args `)` `kw` $arg_names attr-dict `:` functional-type($args, results)";
let verifier = [{ return verifyBasicpyOp(*this); }];
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &result">,
];
}
def Basicpy_FuncTemplateOp : Basicpy_Op<"func_template", [
IsolatedFromAbove,
SingleBlockImplicitTerminator<"FuncTemplateTerminatorOp">,
NativeOpTrait<"SymbolTable">,
Symbol]> {
let summary = "Group of multiple overload-resolved concrete functions";
let description = [{
The outer func_template op acts as a module that can contain named concrete
functions that are interpreted as overloads. If the function signature is
sufficient to disambiguate (i.e. with nothing more than arity and MLIR
argument types), then this is all that is needed. However, in many cases,
additional attributes will need to be specified to further constrain types.
The first matching function signature is selected to satisfy a
`func_template_call` op.
TODO: Define this extended constraint matching.
Instantiation
-------------
Once a concrete function is selected as being applicable to a given call,
it will typically be instantiated as a standalone, unspecialized function
in the calling module (as a peer to the func_template). This function
will be uniquely identified by concating the outer func_template's symbol
name, '$', and the concrete instance's symbol name.
Note that the function may still be unspecialized (in that it contains
UnknownType arguments/results), and type inference is expected to further
specialize/inline/constrain it.
Naming
------
By convention, func_templates are named to avoid collision for various
uses:
- Global function templates: "__global$python.qualified.name"
- Method names: "__method$method_name"
- Attribute getter: "__getattr$attr_name"
- Attribute setter: "__setattr$attr_name"
As in user-level python, for functions that bind to an instance, the first
argument must be a concrete type for the bound instance type. In this way,
there is one `func_template` for every unique member name and the normal
type constraints system is used to select the overload, just as if it was
a normal function call. It is left to utility routines to merge libraries
in a way that preserves this invariant.
TODO: This needs to be fleshed out more as some additional rules about
ordering and conflict resolution are likely needed to make this correct.
Correlation with python runtime
-------------------------------
When extracting a program, it is typically necessary to create weak
references to specific python functions and correlate them back to a named
template defined here. Often times this can just be done lexically, but
to avoid fragility, any func_template that correlates to a python
runtime function will have an additional attribute `py_bind` that is an
array of StringAttr qualified names to resolve and bind to in the python
runtime. In cases of divergence, the symbol name of the template should
be chosen just for uniqueness (not significance).
The qualified name format for `py_bind` attribute is:
package.name#local.qualified.name
}];
let arguments = (ins);
let regions = (region SizedRegion<1>:$body);
let verifier = [{ return verifyBasicpyOp(*this); }];
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &result">,
];
let extraClassDeclaration = [{
OpBuilder getBodyBuilder() {
Block* body = getBody(0);
return OpBuilder::atBlockEnd(body);
}
}];
}
def Basicpy_FuncTemplateTerminatorOp : Basicpy_Op<"func_template_terminator", [
HasParent<"Basicpy::FuncTemplateOp">,
Terminator]> {
let summary = "Terminator pseudo-op for the FuncTemplateOp";
let parser = ?;
let printer = ?;
}
def Basicpy_SlotObjectMakeOp : Basicpy_Op<"slot_object_make", [
NoSideEffect]> {
let summary = "Creates an instance of a SlotObject type";

View File

@ -8,6 +8,7 @@
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
@ -47,17 +48,85 @@ void ExecOp::build(OpBuilder &builder, OperationState &result) {
static ParseResult parseExecOp(OpAsmParser &parser, OperationState *result) {
Region *bodyRegion = result->addRegion();
if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{}) ||
parser.parseOptionalAttrDict(result->attributes))
if (parser.parseOptionalAttrDictWithKeyword(result->attributes) ||
parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{}))
return failure();
return success();
}
static void printExecOp(OpAsmPrinter &p, ExecOp op) {
p << op.getOperationName();
p.printOptionalAttrDictWithKeyword(op.getAttrs());
p.printRegion(op.body());
}
//===----------------------------------------------------------------------===//
// FuncTemplateCallOp
//===----------------------------------------------------------------------===//
static LogicalResult verifyBasicpyOp(FuncTemplateCallOp op) {
auto argNames = op.arg_names();
if (argNames.size() > op.args().size()) {
return op.emitOpError() << "expected <= kw arg names vs args";
}
for (auto it : llvm::enumerate(argNames)) {
auto argName = it.value().cast<StringAttr>().getValue();
if (argName == "*" && it.index() != 0) {
return op.emitOpError() << "positional arg pack must be the first kw arg";
}
if (argName == "**" && it.index() != argNames.size() - 1) {
return op.emitOpError() << "kw arg pack must be the last kw arg";
}
}
return success();
}
//===----------------------------------------------------------------------===//
// FuncTemplateOp
//===----------------------------------------------------------------------===//
void FuncTemplateOp::build(OpBuilder &builder, OperationState &result) {
OpBuilder::InsertionGuard guard(builder);
ensureTerminator(*result.addRegion(), builder, result.location);
}
static ParseResult parseFuncTemplateOp(OpAsmParser &parser,
OperationState *result) {
Region *bodyRegion = result->addRegion();
StringAttr symbolName;
if (parser.parseSymbolName(symbolName, SymbolTable::getSymbolAttrName(),
result->attributes) ||
parser.parseOptionalAttrDictWithKeyword(result->attributes) ||
parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{}))
return failure();
FuncTemplateOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
result->location);
return success();
}
static void printFuncTemplateOp(OpAsmPrinter &p, FuncTemplateOp op) {
p << op.getOperationName() << " ";
p.printSymbolName(op.getName());
p.printOptionalAttrDictWithKeyword(op.getAttrs(),
{SymbolTable::getSymbolAttrName()});
p.printRegion(op.body());
}
static LogicalResult verifyBasicpyOp(FuncTemplateOp op) {
Block *body = op.getBody();
for (auto &childOp : body->getOperations()) {
if (!llvm::isa<FuncOp>(childOp) &&
!llvm::isa<FuncTemplateTerminatorOp>(childOp)) {
return childOp.emitOpError() << "illegal operation in func_template";
}
}
return success();
}
//===----------------------------------------------------------------------===//
// SlotObjectMakeOp
//===----------------------------------------------------------------------===//

View File

@ -504,6 +504,14 @@ void PyContext::bind(py::module m) {
}
return DictionaryAttr::get(attrs, &self.context);
})
.def("array_attr",
[](PyContext &self, py::list l) -> PyAttribute {
SmallVector<Attribute, 4> attrs;
for (auto &it : l) {
attrs.push_back(it.cast<PyAttribute>().attr);
}
return ArrayAttr::get(attrs, &self.context);
})
.def("dense_elements_attr",
[](PyContext &self, py::buffer array) -> PyAttribute {
// Request a contiguous view.

View File

@ -0,0 +1,28 @@
# RUN: %PYTHON %s | npcomp-opt -split-input-file | FileCheck %s --dump-input=fail
import math
from npcomp.compiler.frontend import *
def import_global(f):
fe = ImportFrontend()
fe.import_global_function(f)
print("// -----")
print(fe.ir_module.to_asm())
return f
# CHECK-LABEL: func @call_ceil_positional
@import_global
def call_ceil_positional(n):
# CHECK: basicpy.func_template_call @__global$math.ceil(%arg0) kw [] : (!basicpy.UnknownType) -> !basicpy.UnknownType
return math.ceil(n)
# CHECK-LABEL: func @call_isclose_kw
@import_global
def call_isclose_kw(n):
# CHECK-DAG: %[[RTOL:.*]] = constant 2.000000e-06
# CHECK-DAG: %[[ABSTOL:.*]] = constant 2.000000e-01
# CHECK: basicpy.func_template_call @__global$math.isclose(%arg0, %[[RTOL]], %[[ABSTOL]]) kw ["rtol", "abs_tol"] : (!basicpy.UnknownType, f64, f64) -> !basicpy.UnknownType
return math.isclose(n, rtol=2e-6, abs_tol=0.2)

View File

@ -176,6 +176,11 @@ class LiveValueRef:
"""Gets a named attribute from the live value."""
return PartialEvalResult.not_evaluated()
def resolve_call(self, env: "Environment", args,
keywords) -> PartialEvalResult:
"""Resolves a function call given 'args' and 'keywords'."""
return PartialEvalResult.not_evaluated()
def __repr__(self):
return "MacroValueRef({}, {})".format(self.__class__.__name__,
self.live_value)
@ -195,6 +200,31 @@ class ResolveAttrLiveValueRef(LiveValueRef):
return env.partial_eval_hook.resolve(attr_py_value)
class TemplateCallLiveValueRef(LiveValueRef):
"""Custom LiveValueRef that resolves calls to a func_template_call op."""
__slots__ = ["callee_name"]
def __init__(self, callee_name, live_value):
super().__init__(live_value)
self.callee_name = callee_name
def resolve_call(self, env: "Environment", args,
keywords) -> PartialEvalResult:
linear_args = list(args)
kw_arg_names = []
for kw_name, kw_value in keywords:
kw_arg_names.append(kw_name)
linear_args.append(kw_value)
ir_h = env.ir_h
result_ir_value = ir_h.basicpy_func_template_call_op(
result_type=ir_h.basicpy_UnknownType,
callee_symbol=self.callee_name,
args=linear_args,
arg_names=kw_arg_names).result
return PartialEvalResult.yields_ir_value(result_ir_value)
class PartialEvalHook:
"""Owned by an environment to customize partial evaluation."""
__slots__ = [
@ -248,6 +278,12 @@ class PartialEvalHook:
lambda pv: PartialEvalResult.yields_live_value(
ResolveAttrLiveValueRef(pv)), **kwargs)
def enable_template_call(self, callee_name, **kwargs):
""""Enables a global template call."""
self._bind(
lambda pv: PartialEvalResult.yields_live_value(
TemplateCallLiveValueRef(callee_name, pv)), **kwargs)
################################################################################
# Environment

View File

@ -167,12 +167,17 @@ class AllDialectHelper(Numpy.DialectHelper, ScfDialectHelper):
def build_default_partial_eval_hook() -> PartialEvalHook:
mr = PartialEvalHook()
pe = PartialEvalHook()
### Modules
mr.enable_getattr(for_type=ast.__class__) # The module we use is arbitrary.
pe.enable_getattr(for_type=ast.__class__) # The module we use is arbitrary.
### Tuples
# Enable attribute resolution on tuple, which includes namedtuple (which is
# really what we want).
mr.enable_getattr(for_type=tuple)
return mr
pe.enable_getattr(for_type=tuple)
### Temp: resolve a function to a template call for testing
import math
pe.enable_template_call("__global$math.ceil", for_ref=math.ceil)
pe.enable_template_call("__global$math.isclose", for_ref=math.isclose)
return pe

View File

@ -275,6 +275,39 @@ class ExpressionImporter(BaseNodeVisitor):
self.value = emit_next(ast_node.values)
def visit_Call(self, ast_node):
# Evaluate positional args.
evaluated_args = []
for raw_arg in ast_node.args:
evaluated_args.append(self.sub_evaluate(raw_arg))
# Evaluate keyword args.
keyword_args = []
for raw_kw_arg in ast_node.keywords:
keyword_args.append((raw_kw_arg.arg, self.sub_evaluate(raw_kw_arg.value)))
# Perform partial evaluation of the callee.
callee_importer = PartialEvalImporter(self.fctx)
callee_importer.visit(ast_node.func)
callee_result = callee_importer.partial_eval_result
if (callee_result and
callee_result.type == PartialEvalType.YIELDS_LIVE_VALUE):
# This is a function known to the compiler. Perform a template call.
call_result = callee_result.yields.resolve_call(self.fctx.environment,
evaluated_args,
keyword_args)
if call_result.type != PartialEvalType.NOT_EVALUATED:
# Partial evaluation success.
self.fctx.check_partial_evaluated(call_result)
self.value = self.fctx.emit_partial_eval_result(call_result)
return
# The function is not known to the compiler.
self.fctx.check_partial_evaluated(callee_result)
# TODO: Implement first class functions.
self.fctx.abort("unhandled (potentially first-class function): {}".format(
ast.dump(ast_node)))
def visit_Compare(self, ast_node):
# Short-circuit comparison (degenerates to binary comparison when just
# two operands).

View File

@ -15,7 +15,7 @@ class DialectHelper(_BaseDialectHelper):
>>> c = ir.MLIRContext()
>>> h = DialectHelper(c, ir.OpBuilder(c))
Dialect Types:
>>> h.basicpy_NoneType
!basicpy.NoneType
@ -96,6 +96,16 @@ class DialectHelper(_BaseDialectHelper):
def basicpy_unknown_cast_op(self, result_type, operand):
return self.op("basicpy.unknown_cast", [result_type], [operand])
def basicpy_func_template_call_op(self, result_type, callee_symbol, args,
arg_names):
"""Creates a basicpy.func_template_call op."""
c = self.context
attrs = c.dictionary_attr({
"callee": c.flat_symbol_ref_attr(callee_symbol),
"arg_names": c.array_attr([c.string_attr(n) for n in arg_names]),
})
return self.op("basicpy.func_template_call", [result_type], args, attrs)
if __name__ == "__main__":
import doctest

View File

@ -0,0 +1,81 @@
// RUN: npcomp-opt -split-input-file -verify-diagnostics %s | npcomp-opt -canonicalize | FileCheck --dump-input=fail %s
//===----------------------------------------------------------------------===//
// func_template_call
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @positional
func @positional(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
// CHECK: basicpy.func_template_call @foobar(%arg0, %arg1) kw []
%0 = basicpy.func_template_call @foobar(%arg0, %arg1) kw [] : (!basicpy.UnknownType, !basicpy.UnknownType) -> !basicpy.UnknownType
return %0 : !basicpy.UnknownType
}
// -----
// CHECK-LABEL: func @kwValid
func @kwValid(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
// CHECK: basicpy.func_template_call @foobar(%arg0, %arg1) kw ["second"]
%0 = basicpy.func_template_call @foobar(%arg0, %arg1) kw ["second"] : (!basicpy.UnknownType, !basicpy.UnknownType) -> !basicpy.UnknownType
return %0 : !basicpy.UnknownType
}
// -----
// CHECK-LABEL: func @posArgPack
func @posArgPack(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
// CHECK: basicpy.func_template_call @foobar(%arg0, %arg1) kw ["*"]
%0 = basicpy.func_template_call @foobar(%arg0, %arg1) kw ["*"] : (!basicpy.UnknownType, !basicpy.UnknownType) -> !basicpy.UnknownType
return %0 : !basicpy.UnknownType
}
// -----
// CHECK-LABEL: func @kwArgPack
func @kwArgPack(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
// CHECK: basicpy.func_template_call @foobar(%arg0, %arg1) kw ["**"]
%0 = basicpy.func_template_call @foobar(%arg0, %arg1) kw ["**"] : (!basicpy.UnknownType, !basicpy.UnknownType) -> !basicpy.UnknownType
return %0 : !basicpy.UnknownType
}
// -----
func @kwOverflow(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
// expected-error @+1 {{expected <= kw arg names vs args}}
%0 = basicpy.func_template_call @foobar(%arg0, %arg1) kw ["second", "third", "fourth"] : (!basicpy.UnknownType, !basicpy.UnknownType) -> !basicpy.UnknownType
return %0 : !basicpy.UnknownType
}
// -----
func @badPosArgPack(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
// expected-error @+1 {{positional arg pack must be the first kw arg}}
%0 = basicpy.func_template_call @foobar(%arg0, %arg1) kw ["*", "*"] : (!basicpy.UnknownType, !basicpy.UnknownType) -> !basicpy.UnknownType
return %0 : !basicpy.UnknownType
}
// -----
func @badKwArgPack(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
// expected-error @+1 {{kw arg pack must be the last kw arg}}
%0 = basicpy.func_template_call @foobar(%arg0, %arg1) kw ["**", "next"] : (!basicpy.UnknownType, !basicpy.UnknownType) -> !basicpy.UnknownType
return %0 : !basicpy.UnknownType
}
//===----------------------------------------------------------------------===//
// func_template
//===----------------------------------------------------------------------===//
// -----
// CHECK-LABEL: module @valid_template
module @valid_template {
// CHECK: basicpy.func_template @__global$pkg.foobar attributes {py_bind = ["#abs"]} {
basicpy.func_template @__global$pkg.foobar attributes {py_bind = ["#abs"]} {
// CHECK: func @forInts(%arg0: i32) -> i32
func @forInts(%arg0 : i32) -> i32 {
return %arg0 : i32
}
}
}
// -----
module @invalid_template {
basicpy.func_template @__global$pkg.foobar {
// expected-error @+1 {{illegal operation in func_template}}
module {}
}
}