mirror of https://github.com/llvm/torch-mlir
Merge pull request #4 from google/stella_dev
Add "template function" ops and importer code.pull/1/head
commit
d6b428fb60
|
@ -143,3 +143,18 @@ This is accomplished with the following `PartialEvalHook` setup:
|
|||
```
|
||||
|
||||
It is expected that this facility will evolve substantially, as it is the primary intended mechanism for remapping significant parts of the python namespace to builtin constructs (i.e. it will be the primary way to map `numpy` functions and values).
|
||||
|
||||
## Calls
|
||||
|
||||
This is very much a WIP. Relevant ops:
|
||||
|
||||
* `func_template`: Aggregates a list of function overloads to choose for a symbolic name.
|
||||
* `func_template_call`: Performs a symbolic call with python source conventions.
|
||||
|
||||
The idea is that a library modules of `func_template` definitions is assembled with all concrete implementations that have compiler support. The python compiler will iterate over all such templates and bind partial evaluation rules in the environment to detect the calls. Then, when importing, `func_template_call` ops make the call.
|
||||
|
||||
See the `basicpy.func_template` op for more detailed notes. The intention is that compiler-supported functions, methods, attribute getter/setter, and dunder functions can all exist in the library, with concrete resolution carried out by type constraints and corresponding type inference passes. Upon matches, concrete functions are pulled into the module being compiled and possibly inlined. With enough type constraints and some iteration, this should converge reasonably to a statically typed program.
|
||||
|
||||
See the tests:
|
||||
|
||||
* [template_call.py](../pytest/Compiler/template_call.py)
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/Interfaces/CallInterfaces.h"
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#define NPCOMP_DIALECT_BASICPY_IR_BASICPY_OPS
|
||||
|
||||
include "BasicpyDialect.td"
|
||||
include "mlir/Interfaces/CallInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/IR/SymbolInterfaces.td"
|
||||
|
@ -201,6 +202,127 @@ def Basicpy_ExecDiscardOp : Basicpy_Op<"exec_discard", [
|
|||
let assemblyFormat = "operands attr-dict `:` type(operands)";
|
||||
}
|
||||
|
||||
def Basicpy_FuncTemplateCallOp : Basicpy_Op<"func_template_call", []> {
|
||||
let summary = "Calls a function template";
|
||||
let description = [{
|
||||
Most function calls start with this generic calling op, which binds
|
||||
symbolically to a func_template. At this level, there are very few
|
||||
semantics associated with the call, since, often, both types and the
|
||||
specific concrete callee cannot be determined.
|
||||
|
||||
Per python calling conventions, all functions return one result, even if
|
||||
None or a tuple (which may be syntactically unpacked to multiple results).
|
||||
|
||||
If specified, the `argNames` operand is right aligned to the list of
|
||||
positional `args`, representing arguments that are special or have been
|
||||
passed with a keyword. The following arg names are special:
|
||||
'*': Indicates that the argument is a positional argument pack (must be
|
||||
the first arg name, if present).
|
||||
'**': Indicates that the argument is a keyword argument pack (must be
|
||||
the last arg name, if present).
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
FlatSymbolRefAttr:$callee,
|
||||
Variadic<AnyType>:$args,
|
||||
StrArrayAttr:$arg_names);
|
||||
let results = (outs AnyType:$result);
|
||||
let assemblyFormat = "$callee `(` $args `)` `kw` $arg_names attr-dict `:` functional-type($args, results)";
|
||||
let verifier = [{ return verifyBasicpyOp(*this); }];
|
||||
let skipDefaultBuilders = 1;
|
||||
let builders = [
|
||||
OpBuilder<"OpBuilder &builder, OperationState &result">,
|
||||
];
|
||||
}
|
||||
|
||||
def Basicpy_FuncTemplateOp : Basicpy_Op<"func_template", [
|
||||
IsolatedFromAbove,
|
||||
SingleBlockImplicitTerminator<"FuncTemplateTerminatorOp">,
|
||||
NativeOpTrait<"SymbolTable">,
|
||||
Symbol]> {
|
||||
let summary = "Group of multiple overload-resolved concrete functions";
|
||||
let description = [{
|
||||
The outer func_template op acts as a module that can contain named concrete
|
||||
functions that are interpreted as overloads. If the function signature is
|
||||
sufficient to disambiguate (i.e. with nothing more than arity and MLIR
|
||||
argument types), then this is all that is needed. However, in many cases,
|
||||
additional attributes will need to be specified to further constrain types.
|
||||
The first matching function signature is selected to satisfy a
|
||||
`func_template_call` op.
|
||||
|
||||
TODO: Define this extended constraint matching.
|
||||
|
||||
Instantiation
|
||||
-------------
|
||||
Once a concrete function is selected as being applicable to a given call,
|
||||
it will typically be instantiated as a standalone, unspecialized function
|
||||
in the calling module (as a peer to the func_template). This function
|
||||
will be uniquely identified by concating the outer func_template's symbol
|
||||
name, '$', and the concrete instance's symbol name.
|
||||
|
||||
Note that the function may still be unspecialized (in that it contains
|
||||
UnknownType arguments/results), and type inference is expected to further
|
||||
specialize/inline/constrain it.
|
||||
|
||||
Naming
|
||||
------
|
||||
By convention, func_templates are named to avoid collision for various
|
||||
uses:
|
||||
- Global function templates: "__global$python.qualified.name"
|
||||
- Method names: "__method$method_name"
|
||||
- Attribute getter: "__getattr$attr_name"
|
||||
- Attribute setter: "__setattr$attr_name"
|
||||
|
||||
As in user-level python, for functions that bind to an instance, the first
|
||||
argument must be a concrete type for the bound instance type. In this way,
|
||||
there is one `func_template` for every unique member name and the normal
|
||||
type constraints system is used to select the overload, just as if it was
|
||||
a normal function call. It is left to utility routines to merge libraries
|
||||
in a way that preserves this invariant.
|
||||
|
||||
TODO: This needs to be fleshed out more as some additional rules about
|
||||
ordering and conflict resolution are likely needed to make this correct.
|
||||
|
||||
Correlation with python runtime
|
||||
-------------------------------
|
||||
When extracting a program, it is typically necessary to create weak
|
||||
references to specific python functions and correlate them back to a named
|
||||
template defined here. Often times this can just be done lexically, but
|
||||
to avoid fragility, any func_template that correlates to a python
|
||||
runtime function will have an additional attribute `py_bind` that is an
|
||||
array of StringAttr qualified names to resolve and bind to in the python
|
||||
runtime. In cases of divergence, the symbol name of the template should
|
||||
be chosen just for uniqueness (not significance).
|
||||
|
||||
The qualified name format for `py_bind` attribute is:
|
||||
package.name#local.qualified.name
|
||||
}];
|
||||
let arguments = (ins);
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
|
||||
let verifier = [{ return verifyBasicpyOp(*this); }];
|
||||
|
||||
let skipDefaultBuilders = 1;
|
||||
let builders = [
|
||||
OpBuilder<"OpBuilder &builder, OperationState &result">,
|
||||
];
|
||||
let extraClassDeclaration = [{
|
||||
OpBuilder getBodyBuilder() {
|
||||
Block* body = getBody(0);
|
||||
return OpBuilder::atBlockEnd(body);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Basicpy_FuncTemplateTerminatorOp : Basicpy_Op<"func_template_terminator", [
|
||||
HasParent<"Basicpy::FuncTemplateOp">,
|
||||
Terminator]> {
|
||||
let summary = "Terminator pseudo-op for the FuncTemplateOp";
|
||||
|
||||
let parser = ?;
|
||||
let printer = ?;
|
||||
}
|
||||
|
||||
def Basicpy_SlotObjectMakeOp : Basicpy_Op<"slot_object_make", [
|
||||
NoSideEffect]> {
|
||||
let summary = "Creates an instance of a SlotObject type";
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/FunctionImplementation.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
@ -47,17 +48,85 @@ void ExecOp::build(OpBuilder &builder, OperationState &result) {
|
|||
|
||||
static ParseResult parseExecOp(OpAsmParser &parser, OperationState *result) {
|
||||
Region *bodyRegion = result->addRegion();
|
||||
if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{}) ||
|
||||
parser.parseOptionalAttrDict(result->attributes))
|
||||
if (parser.parseOptionalAttrDictWithKeyword(result->attributes) ||
|
||||
parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{}))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
static void printExecOp(OpAsmPrinter &p, ExecOp op) {
|
||||
p << op.getOperationName();
|
||||
p.printOptionalAttrDictWithKeyword(op.getAttrs());
|
||||
p.printRegion(op.body());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FuncTemplateCallOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verifyBasicpyOp(FuncTemplateCallOp op) {
|
||||
auto argNames = op.arg_names();
|
||||
if (argNames.size() > op.args().size()) {
|
||||
return op.emitOpError() << "expected <= kw arg names vs args";
|
||||
}
|
||||
|
||||
for (auto it : llvm::enumerate(argNames)) {
|
||||
auto argName = it.value().cast<StringAttr>().getValue();
|
||||
if (argName == "*" && it.index() != 0) {
|
||||
return op.emitOpError() << "positional arg pack must be the first kw arg";
|
||||
}
|
||||
if (argName == "**" && it.index() != argNames.size() - 1) {
|
||||
return op.emitOpError() << "kw arg pack must be the last kw arg";
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FuncTemplateOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void FuncTemplateOp::build(OpBuilder &builder, OperationState &result) {
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
ensureTerminator(*result.addRegion(), builder, result.location);
|
||||
}
|
||||
|
||||
static ParseResult parseFuncTemplateOp(OpAsmParser &parser,
|
||||
OperationState *result) {
|
||||
Region *bodyRegion = result->addRegion();
|
||||
StringAttr symbolName;
|
||||
|
||||
if (parser.parseSymbolName(symbolName, SymbolTable::getSymbolAttrName(),
|
||||
result->attributes) ||
|
||||
parser.parseOptionalAttrDictWithKeyword(result->attributes) ||
|
||||
parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{}))
|
||||
return failure();
|
||||
|
||||
FuncTemplateOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
|
||||
result->location);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static void printFuncTemplateOp(OpAsmPrinter &p, FuncTemplateOp op) {
|
||||
p << op.getOperationName() << " ";
|
||||
p.printSymbolName(op.getName());
|
||||
p.printOptionalAttrDictWithKeyword(op.getAttrs(),
|
||||
{SymbolTable::getSymbolAttrName()});
|
||||
p.printRegion(op.body());
|
||||
}
|
||||
|
||||
static LogicalResult verifyBasicpyOp(FuncTemplateOp op) {
|
||||
Block *body = op.getBody();
|
||||
for (auto &childOp : body->getOperations()) {
|
||||
if (!llvm::isa<FuncOp>(childOp) &&
|
||||
!llvm::isa<FuncTemplateTerminatorOp>(childOp)) {
|
||||
return childOp.emitOpError() << "illegal operation in func_template";
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SlotObjectMakeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -504,6 +504,14 @@ void PyContext::bind(py::module m) {
|
|||
}
|
||||
return DictionaryAttr::get(attrs, &self.context);
|
||||
})
|
||||
.def("array_attr",
|
||||
[](PyContext &self, py::list l) -> PyAttribute {
|
||||
SmallVector<Attribute, 4> attrs;
|
||||
for (auto &it : l) {
|
||||
attrs.push_back(it.cast<PyAttribute>().attr);
|
||||
}
|
||||
return ArrayAttr::get(attrs, &self.context);
|
||||
})
|
||||
.def("dense_elements_attr",
|
||||
[](PyContext &self, py::buffer array) -> PyAttribute {
|
||||
// Request a contiguous view.
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
# RUN: %PYTHON %s | npcomp-opt -split-input-file | FileCheck %s --dump-input=fail
|
||||
|
||||
import math
|
||||
from npcomp.compiler.frontend import *
|
||||
|
||||
|
||||
def import_global(f):
|
||||
fe = ImportFrontend()
|
||||
fe.import_global_function(f)
|
||||
print("// -----")
|
||||
print(fe.ir_module.to_asm())
|
||||
return f
|
||||
|
||||
|
||||
# CHECK-LABEL: func @call_ceil_positional
|
||||
@import_global
|
||||
def call_ceil_positional(n):
|
||||
# CHECK: basicpy.func_template_call @__global$math.ceil(%arg0) kw [] : (!basicpy.UnknownType) -> !basicpy.UnknownType
|
||||
return math.ceil(n)
|
||||
|
||||
|
||||
# CHECK-LABEL: func @call_isclose_kw
|
||||
@import_global
|
||||
def call_isclose_kw(n):
|
||||
# CHECK-DAG: %[[RTOL:.*]] = constant 2.000000e-06
|
||||
# CHECK-DAG: %[[ABSTOL:.*]] = constant 2.000000e-01
|
||||
# CHECK: basicpy.func_template_call @__global$math.isclose(%arg0, %[[RTOL]], %[[ABSTOL]]) kw ["rtol", "abs_tol"] : (!basicpy.UnknownType, f64, f64) -> !basicpy.UnknownType
|
||||
return math.isclose(n, rtol=2e-6, abs_tol=0.2)
|
|
@ -176,6 +176,11 @@ class LiveValueRef:
|
|||
"""Gets a named attribute from the live value."""
|
||||
return PartialEvalResult.not_evaluated()
|
||||
|
||||
def resolve_call(self, env: "Environment", args,
|
||||
keywords) -> PartialEvalResult:
|
||||
"""Resolves a function call given 'args' and 'keywords'."""
|
||||
return PartialEvalResult.not_evaluated()
|
||||
|
||||
def __repr__(self):
|
||||
return "MacroValueRef({}, {})".format(self.__class__.__name__,
|
||||
self.live_value)
|
||||
|
@ -195,6 +200,31 @@ class ResolveAttrLiveValueRef(LiveValueRef):
|
|||
return env.partial_eval_hook.resolve(attr_py_value)
|
||||
|
||||
|
||||
class TemplateCallLiveValueRef(LiveValueRef):
|
||||
"""Custom LiveValueRef that resolves calls to a func_template_call op."""
|
||||
__slots__ = ["callee_name"]
|
||||
|
||||
def __init__(self, callee_name, live_value):
|
||||
super().__init__(live_value)
|
||||
self.callee_name = callee_name
|
||||
|
||||
def resolve_call(self, env: "Environment", args,
|
||||
keywords) -> PartialEvalResult:
|
||||
linear_args = list(args)
|
||||
kw_arg_names = []
|
||||
for kw_name, kw_value in keywords:
|
||||
kw_arg_names.append(kw_name)
|
||||
linear_args.append(kw_value)
|
||||
|
||||
ir_h = env.ir_h
|
||||
result_ir_value = ir_h.basicpy_func_template_call_op(
|
||||
result_type=ir_h.basicpy_UnknownType,
|
||||
callee_symbol=self.callee_name,
|
||||
args=linear_args,
|
||||
arg_names=kw_arg_names).result
|
||||
return PartialEvalResult.yields_ir_value(result_ir_value)
|
||||
|
||||
|
||||
class PartialEvalHook:
|
||||
"""Owned by an environment to customize partial evaluation."""
|
||||
__slots__ = [
|
||||
|
@ -248,6 +278,12 @@ class PartialEvalHook:
|
|||
lambda pv: PartialEvalResult.yields_live_value(
|
||||
ResolveAttrLiveValueRef(pv)), **kwargs)
|
||||
|
||||
def enable_template_call(self, callee_name, **kwargs):
|
||||
""""Enables a global template call."""
|
||||
self._bind(
|
||||
lambda pv: PartialEvalResult.yields_live_value(
|
||||
TemplateCallLiveValueRef(callee_name, pv)), **kwargs)
|
||||
|
||||
|
||||
################################################################################
|
||||
# Environment
|
||||
|
|
|
@ -167,12 +167,17 @@ class AllDialectHelper(Numpy.DialectHelper, ScfDialectHelper):
|
|||
|
||||
|
||||
def build_default_partial_eval_hook() -> PartialEvalHook:
|
||||
mr = PartialEvalHook()
|
||||
pe = PartialEvalHook()
|
||||
### Modules
|
||||
mr.enable_getattr(for_type=ast.__class__) # The module we use is arbitrary.
|
||||
pe.enable_getattr(for_type=ast.__class__) # The module we use is arbitrary.
|
||||
|
||||
### Tuples
|
||||
# Enable attribute resolution on tuple, which includes namedtuple (which is
|
||||
# really what we want).
|
||||
mr.enable_getattr(for_type=tuple)
|
||||
return mr
|
||||
pe.enable_getattr(for_type=tuple)
|
||||
|
||||
### Temp: resolve a function to a template call for testing
|
||||
import math
|
||||
pe.enable_template_call("__global$math.ceil", for_ref=math.ceil)
|
||||
pe.enable_template_call("__global$math.isclose", for_ref=math.isclose)
|
||||
return pe
|
||||
|
|
|
@ -275,6 +275,39 @@ class ExpressionImporter(BaseNodeVisitor):
|
|||
|
||||
self.value = emit_next(ast_node.values)
|
||||
|
||||
def visit_Call(self, ast_node):
|
||||
# Evaluate positional args.
|
||||
evaluated_args = []
|
||||
for raw_arg in ast_node.args:
|
||||
evaluated_args.append(self.sub_evaluate(raw_arg))
|
||||
|
||||
# Evaluate keyword args.
|
||||
keyword_args = []
|
||||
for raw_kw_arg in ast_node.keywords:
|
||||
keyword_args.append((raw_kw_arg.arg, self.sub_evaluate(raw_kw_arg.value)))
|
||||
|
||||
# Perform partial evaluation of the callee.
|
||||
callee_importer = PartialEvalImporter(self.fctx)
|
||||
callee_importer.visit(ast_node.func)
|
||||
callee_result = callee_importer.partial_eval_result
|
||||
if (callee_result and
|
||||
callee_result.type == PartialEvalType.YIELDS_LIVE_VALUE):
|
||||
# This is a function known to the compiler. Perform a template call.
|
||||
call_result = callee_result.yields.resolve_call(self.fctx.environment,
|
||||
evaluated_args,
|
||||
keyword_args)
|
||||
if call_result.type != PartialEvalType.NOT_EVALUATED:
|
||||
# Partial evaluation success.
|
||||
self.fctx.check_partial_evaluated(call_result)
|
||||
self.value = self.fctx.emit_partial_eval_result(call_result)
|
||||
return
|
||||
|
||||
# The function is not known to the compiler.
|
||||
self.fctx.check_partial_evaluated(callee_result)
|
||||
# TODO: Implement first class functions.
|
||||
self.fctx.abort("unhandled (potentially first-class function): {}".format(
|
||||
ast.dump(ast_node)))
|
||||
|
||||
def visit_Compare(self, ast_node):
|
||||
# Short-circuit comparison (degenerates to binary comparison when just
|
||||
# two operands).
|
||||
|
|
|
@ -15,7 +15,7 @@ class DialectHelper(_BaseDialectHelper):
|
|||
|
||||
>>> c = ir.MLIRContext()
|
||||
>>> h = DialectHelper(c, ir.OpBuilder(c))
|
||||
|
||||
|
||||
Dialect Types:
|
||||
>>> h.basicpy_NoneType
|
||||
!basicpy.NoneType
|
||||
|
@ -96,6 +96,16 @@ class DialectHelper(_BaseDialectHelper):
|
|||
def basicpy_unknown_cast_op(self, result_type, operand):
|
||||
return self.op("basicpy.unknown_cast", [result_type], [operand])
|
||||
|
||||
def basicpy_func_template_call_op(self, result_type, callee_symbol, args,
|
||||
arg_names):
|
||||
"""Creates a basicpy.func_template_call op."""
|
||||
c = self.context
|
||||
attrs = c.dictionary_attr({
|
||||
"callee": c.flat_symbol_ref_attr(callee_symbol),
|
||||
"arg_names": c.array_attr([c.string_attr(n) for n in arg_names]),
|
||||
})
|
||||
return self.op("basicpy.func_template_call", [result_type], args, attrs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import doctest
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
// RUN: npcomp-opt -split-input-file -verify-diagnostics %s | npcomp-opt -canonicalize | FileCheck --dump-input=fail %s
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// func_template_call
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: func @positional
|
||||
func @positional(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
|
||||
// CHECK: basicpy.func_template_call @foobar(%arg0, %arg1) kw []
|
||||
%0 = basicpy.func_template_call @foobar(%arg0, %arg1) kw [] : (!basicpy.UnknownType, !basicpy.UnknownType) -> !basicpy.UnknownType
|
||||
return %0 : !basicpy.UnknownType
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @kwValid
|
||||
func @kwValid(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
|
||||
// CHECK: basicpy.func_template_call @foobar(%arg0, %arg1) kw ["second"]
|
||||
%0 = basicpy.func_template_call @foobar(%arg0, %arg1) kw ["second"] : (!basicpy.UnknownType, !basicpy.UnknownType) -> !basicpy.UnknownType
|
||||
return %0 : !basicpy.UnknownType
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @posArgPack
|
||||
func @posArgPack(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
|
||||
// CHECK: basicpy.func_template_call @foobar(%arg0, %arg1) kw ["*"]
|
||||
%0 = basicpy.func_template_call @foobar(%arg0, %arg1) kw ["*"] : (!basicpy.UnknownType, !basicpy.UnknownType) -> !basicpy.UnknownType
|
||||
return %0 : !basicpy.UnknownType
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @kwArgPack
|
||||
func @kwArgPack(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
|
||||
// CHECK: basicpy.func_template_call @foobar(%arg0, %arg1) kw ["**"]
|
||||
%0 = basicpy.func_template_call @foobar(%arg0, %arg1) kw ["**"] : (!basicpy.UnknownType, !basicpy.UnknownType) -> !basicpy.UnknownType
|
||||
return %0 : !basicpy.UnknownType
|
||||
}
|
||||
|
||||
// -----
|
||||
func @kwOverflow(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
|
||||
// expected-error @+1 {{expected <= kw arg names vs args}}
|
||||
%0 = basicpy.func_template_call @foobar(%arg0, %arg1) kw ["second", "third", "fourth"] : (!basicpy.UnknownType, !basicpy.UnknownType) -> !basicpy.UnknownType
|
||||
return %0 : !basicpy.UnknownType
|
||||
}
|
||||
|
||||
// -----
|
||||
func @badPosArgPack(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
|
||||
// expected-error @+1 {{positional arg pack must be the first kw arg}}
|
||||
%0 = basicpy.func_template_call @foobar(%arg0, %arg1) kw ["*", "*"] : (!basicpy.UnknownType, !basicpy.UnknownType) -> !basicpy.UnknownType
|
||||
return %0 : !basicpy.UnknownType
|
||||
}
|
||||
|
||||
// -----
|
||||
func @badKwArgPack(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
|
||||
// expected-error @+1 {{kw arg pack must be the last kw arg}}
|
||||
%0 = basicpy.func_template_call @foobar(%arg0, %arg1) kw ["**", "next"] : (!basicpy.UnknownType, !basicpy.UnknownType) -> !basicpy.UnknownType
|
||||
return %0 : !basicpy.UnknownType
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// func_template
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: module @valid_template
|
||||
module @valid_template {
|
||||
// CHECK: basicpy.func_template @__global$pkg.foobar attributes {py_bind = ["#abs"]} {
|
||||
basicpy.func_template @__global$pkg.foobar attributes {py_bind = ["#abs"]} {
|
||||
// CHECK: func @forInts(%arg0: i32) -> i32
|
||||
func @forInts(%arg0 : i32) -> i32 {
|
||||
return %arg0 : i32
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
module @invalid_template {
|
||||
basicpy.func_template @__global$pkg.foobar {
|
||||
// expected-error @+1 {{illegal operation in func_template}}
|
||||
module {}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue