Add builtin_ufunc and generic_ufunc ops.

pull/1/head
Stella Laurenzo 2020-04-28 20:32:49 -07:00
parent 25e22aa4a5
commit e845db8a20
7 changed files with 252 additions and 21 deletions

View File

@ -21,6 +21,9 @@ def Numpy_Dialect : Dialect {
}
class Numpy_Op<string mnemonic, list<OpTrait> traits = []> :
Op<Numpy_Dialect, mnemonic, traits>;
Op<Numpy_Dialect, mnemonic, traits> {
let parser = [{ return parse$cppClass(parser, &result); }];
let printer = [{ return print$cppClass(p, *this); }];
}
#endif // NPCOMP_DIALECT_NUMPY_NUMPY_DIALECT

View File

@ -9,8 +9,12 @@
#ifndef NPCOMP_DIALECT_NUMPY_NUMPY_OPS_H
#define NPCOMP_DIALECT_NUMPY_NUMPY_OPS_H
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionSupport.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/SideEffects.h"
namespace mlir {

View File

@ -11,20 +11,78 @@
include "NumpyDialect.td"
include "mlir/Interfaces/SideEffects.td"
include "mlir/IR/SymbolInterfaces.td"
def Numpy_FooOp : Numpy_Op<"foo", [NoSideEffect,
SameOperandsAndResultType]> {
let summary = "Temp op";
def Numpy_BuiltinUfuncOp : Numpy_Op<"builtin_ufunc", [Symbol]> {
let summary = "References a built-in universal function";
let description = [{
Temp op
}];
let arguments = (ins I32:$input);
let results = (outs I32:$res);
let assemblyFormat = [{
$input attr-dict `:` type($input)
This module-level op binds by name to a fully-qualified numpy built-in
ufunc (i.e. "numpy.add") and carries metadata associated with it.
}];
}
def Numpy_GenericUfuncOp : Numpy_Op<"generic_ufunc", [
IsolatedFromAbove,
Symbol]> {
let summary = "Defines a ufunc in terms of overloaded element-wise functions";
let description = [{
}];
let arguments = (ins
TypeArrayAttr:$overload_types);
let regions = (region
VariadicRegion<AnyRegion>:$overloads);
}
def Numpy_UfuncReturnOp : Numpy_Op<"ufunc_return", [
Terminator,
HasParent<"NUMPY::GenericUfuncOp">]> {
let summary = "Return a value from a generic_ufunc";
let description = [{
Must terminate the basic block of a generic_ufunc overload.
}];
let arguments = (ins
Variadic<AnyType>:$operands
);
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
}
// def Numpy_GenericUfuncOp : Numpy_Op<"generic_ufunc", [
// IsolatedFromAbove,
// Symbol,
// NativeOpTrait<"FunctionLike">]> {
// let summary = "Defines a ufunc in terms of elementwise operations";
// let description = [{
// Defines a universal-function operator in terms of a region, containing
// a basic block of element-wise ops.
// }];
// let regions = (region AnyRegion:$body);
// let extraClassDeclaration = [{
// /// Returns the type of this function.
// FunctionType getType() {
// return getAttrOfType<TypeAttr>(getTypeAttrName())
// .getValue()
// .cast<FunctionType>();
// }
// /// Hook for OpTrait::FunctionLike, returns the number of function
// /// arguments. Depends on the type attribute being correct as checked by
// /// verifyType.
// unsigned getNumFuncArguments() { return getType().getInputs().size(); }
// /// Hook for OpTrait::FunctionLike, returns the number of function results.
// /// Depends on the type attribute being correct as checked by verifyType.
// unsigned getNumFuncResults() { return getType().getResults().size(); }
// /// Hook for OpTrait::FunctionLike, called after verifying that the 'type'
// /// attribute is present. This can check for preconditions of the
// /// getNumArguments hook not failing.
// LogicalResult verifyType();
// }];
// }
#endif // NPCOMP_DIALECT_NUMPY_NUMPY_OPS

View File

@ -7,12 +7,144 @@
//===----------------------------------------------------------------------===//
#include "npcomp/Dialect/Numpy/NumpyOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "npcomp/Dialect/Numpy/NumpyDialect.h"
namespace mlir {
namespace npcomp {
namespace NUMPY {
//===----------------------------------------------------------------------===//
// BuildinUfuncOp
//===----------------------------------------------------------------------===//
static ParseResult parseBuiltinUfuncOp(OpAsmParser &parser,
OperationState *result) {
// Parse the name as a symbol.
StringAttr nameAttr;
if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
result->attributes)) {
return failure();
}
if (failed(parser.parseOptionalAttrDict(result->attributes))) {
return failure();
}
return success();
}
static void printBuiltinUfuncOp(OpAsmPrinter &p, BuiltinUfuncOp op) {
p << op.getOperationName() << " ";
p.printSymbolName(op.getName());
p.printOptionalAttrDict(op.getAttrs(), {SymbolTable::getSymbolAttrName()});
}
//===----------------------------------------------------------------------===//
// GenericUfuncOp
//===----------------------------------------------------------------------===//
static ParseResult parseGenericUfuncOp(OpAsmParser &parser,
OperationState *result) {
Builder b(result->getContext());
// Parse the name as a symbol.
StringAttr nameAttr;
if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
result->attributes))
return failure();
// Parse the body of overloads.
if (parser.parseLParen())
return failure();
SmallVector<Attribute, 4> overloadTypes;
for (bool first = true;; first = false) {
if (first) {
if (parser.parseOptionalKeyword("overload"))
break;
}
if (!first) {
if (parser.parseOptionalComma())
break;
if (parser.parseKeyword("overload"))
return failure();
}
SmallVector<OpAsmParser::OperandType, 2> argNames;
SmallVector<Type, 2> argTypes;
SmallVector<Type, 1> resultTypes;
SmallVector<SmallVector<NamedAttribute, 2>, 1> unusedAttrs;
bool isVariadic = false;
if (::mlir::impl::parseFunctionSignature(parser, false, argNames, argTypes,
unusedAttrs, isVariadic,
resultTypes, unusedAttrs))
return failure();
overloadTypes.push_back(TypeAttr::get(
FunctionType::get(argTypes, resultTypes, result->getContext())));
auto *region = result->addRegion();
if (parser.parseRegion(*region, argNames, argTypes))
return failure();
}
if (parser.parseRParen())
return failure();
// Parse 'attributes {...}'
if (parser.parseOptionalAttrDictWithKeyword(result->attributes))
return failure();
result->addAttribute(b.getIdentifier("overload_types"),
b.getArrayAttr(overloadTypes));
return success();
}
static void printGenericUfuncOp(OpAsmPrinter &p, GenericUfuncOp op) {
p << op.getOperationName() << " @" << op.getName() << "(";
bool first = true;
for (auto it : llvm::enumerate(op.getRegions())) {
auto *region = it.value();
if (first)
first = false;
else
p << ", ";
if (region->empty()) {
p << "<<OVERLOAD_ERROR>>";
continue;
}
Block &entryBlock = region->front();
p << "overload(";
if (it.index() >= op.overload_types().size()) {
p << "<<MISSING OVERLOAD TYPE>>";
continue;
}
TypeAttr tattr = op.overload_types()[it.index()].cast<TypeAttr>();
FunctionType overloadType = tattr.getValue().dyn_cast<FunctionType>();
if (!overloadType) {
p << "<<ILLEGAL OVERLOAD TYPE>>";
continue;
}
if (overloadType.getNumInputs() != entryBlock.getNumArguments()) {
p << "<<OVERLOAD ARG MISMATCH>>";
continue;
}
auto argTypes = entryBlock.getArgumentTypes();
for (unsigned i = 0, e = entryBlock.getNumArguments(); i < e; ++i) {
auto arg = entryBlock.getArgument(i);
if (i > 0)
p << ", ";
p.printOperand(arg);
p << ": ";
p.printType(overloadType.getInputs()[i]);
}
p << ")";
p.printArrowTypeList(overloadType.getResults());
p.printRegion(*region, false, true);
}
p << ")";
}
#define GET_OP_CLASSES
#include "npcomp/Dialect/Numpy/NumpyOps.cpp.inc"
} // namespace NUMPY

View File

@ -1,9 +1,29 @@
// RUN: npcomp-opt %s | npcomp-opt | FileCheck %s
// CHECK-LABEL: func @foo()
func @foo() -> i32 {
%0 = constant 1 : i32
// CHECK: %{{.*}} = numpy.foo %{{.*}} : i32
%res = numpy.foo %0 : i32
return %res : i32
// RUN: npcomp-opt -split-input-file %s | npcomp-opt | FileCheck --dump-input=fail %s
// -----
// CHECK-LABEL: @builtin_ufunc
module @builtin_ufunc {
// CHECK: numpy.builtin_ufunc @numpy.add
numpy.builtin_ufunc @numpy.add
// CHECK: numpy.builtin_ufunc @numpy.custom_sub {some_attr = "foobar"}
numpy.builtin_ufunc @numpy.custom_sub { some_attr = "foobar" }
}
// -----
// CHECK-LABEL: @example_generic_ufunc
module @example_generic_ufunc {
// CHECK: numpy.generic_ufunc @numpy.add(
numpy.generic_ufunc @numpy.add (
// CHECK-SAME: overload(%arg0: i32, %arg1: i32) -> i32 {
overload(%arg0: i32, %arg1: i32) -> i32 {
// CHECK: addi
%0 = addi %arg0, %arg1 : i32
numpy.ufunc_return %0 : i32
},
// CHECK: overload(%arg0: f32, %arg1: f32) -> f32 {
overload(%arg0: f32, %arg1: f32) -> f32 {
// CHECK: addf
%0 = addf %arg0, %arg1 : f32
numpy.ufunc_return %0 : f32
}
)
}

View File

@ -6,6 +6,7 @@ td="$(realpath $(dirname $0)/..)"
build_dir="$td/build"
install_mlir="$td/install-mlir"
build_mlir="$td/build-mlir"
declare -a extra_opts
if ! [ -d "$install_mlir/include/mlir" ]; then
echo "MLIR install path does not appear valid: $install_mlir"
@ -21,6 +22,16 @@ if [ -z "$python_exe" ]; then
exit 1
fi
# Detect linker.
# TODO: Generalize this.
for probe_linker in /usr/bin/ld.lld-10; do
if which ld.lld-10; then
echo "Using linker $probe_linker"
extra_opts+=("-DCMAKE_EXE_LINKER_FLAGS=-fuse-ld=$probe_linker")
break
fi
done
set -x
cmake -GNinja \
"-H$td" \
@ -28,4 +39,5 @@ cmake -GNinja \
"-DPYTHON_EXECUTABLE=$python_exe" \
"-DMLIR_DIR=$install_mlir/lib/cmake/mlir" \
"-DLLVM_EXTERNAL_LIT=$build_mlir/bin/llvm-lit" \
"${extra_opts[@]}" \
"$@"

View File

@ -32,4 +32,6 @@ for i in "$@"; do
done
set -x
cd $build_dir/test && python3 "$lit_exe" ${lit_args[@]}
cd $build_dir
ninja npcomp-opt
cd test && python3 "$lit_exe" ${lit_args[@]}