From e845db8a20b3a7d131d7c5ace0f2cf5fcfacfa45 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 28 Apr 2020 20:32:49 -0700 Subject: [PATCH] Add builtin_ufunc and generic_ufunc ops. --- include/npcomp/Dialect/Numpy/NumpyDialect.td | 5 +- include/npcomp/Dialect/Numpy/NumpyOps.h | 4 + include/npcomp/Dialect/Numpy/NumpyOps.td | 80 +++++++++-- lib/Dialect/Numpy/NumpyOps.cpp | 132 +++++++++++++++++++ test/Dialect/Numpy/ops.mlir | 36 +++-- tools/cmake_configure.sh | 12 ++ tools/run_lit.sh | 4 +- 7 files changed, 252 insertions(+), 21 deletions(-) diff --git a/include/npcomp/Dialect/Numpy/NumpyDialect.td b/include/npcomp/Dialect/Numpy/NumpyDialect.td index afcb96e58..883f886b1 100644 --- a/include/npcomp/Dialect/Numpy/NumpyDialect.td +++ b/include/npcomp/Dialect/Numpy/NumpyDialect.td @@ -21,6 +21,9 @@ def Numpy_Dialect : Dialect { } class Numpy_Op traits = []> : - Op; + Op { + let parser = [{ return parse$cppClass(parser, &result); }]; + let printer = [{ return print$cppClass(p, *this); }]; +} #endif // NPCOMP_DIALECT_NUMPY_NUMPY_DIALECT diff --git a/include/npcomp/Dialect/Numpy/NumpyOps.h b/include/npcomp/Dialect/Numpy/NumpyOps.h index c6f066189..3824cda92 100644 --- a/include/npcomp/Dialect/Numpy/NumpyOps.h +++ b/include/npcomp/Dialect/Numpy/NumpyOps.h @@ -9,8 +9,12 @@ #ifndef NPCOMP_DIALECT_NUMPY_NUMPY_OPS_H #define NPCOMP_DIALECT_NUMPY_NUMPY_OPS_H +#include "mlir/IR/Attributes.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/FunctionSupport.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/SideEffects.h" namespace mlir { diff --git a/include/npcomp/Dialect/Numpy/NumpyOps.td b/include/npcomp/Dialect/Numpy/NumpyOps.td index f75f7d1e4..7873793f5 100644 --- a/include/npcomp/Dialect/Numpy/NumpyOps.td +++ b/include/npcomp/Dialect/Numpy/NumpyOps.td @@ -11,20 +11,78 @@ include "NumpyDialect.td" include "mlir/Interfaces/SideEffects.td" +include "mlir/IR/SymbolInterfaces.td" -def Numpy_FooOp : Numpy_Op<"foo", [NoSideEffect, - SameOperandsAndResultType]> { - let summary = "Temp op"; +def Numpy_BuiltinUfuncOp : Numpy_Op<"builtin_ufunc", [Symbol]> { + let summary = "References a built-in universal function"; let description = [{ - Temp op - }]; - - let arguments = (ins I32:$input); - let results = (outs I32:$res); - - let assemblyFormat = [{ - $input attr-dict `:` type($input) + This module-level op binds by name to a fully-qualified numpy built-in + ufunc (i.e. "numpy.add") and carries metadata associated with it. }]; } +def Numpy_GenericUfuncOp : Numpy_Op<"generic_ufunc", [ + IsolatedFromAbove, + Symbol]> { + let summary = "Defines a ufunc in terms of overloaded element-wise functions"; + let description = [{ + }]; + + let arguments = (ins + TypeArrayAttr:$overload_types); + + let regions = (region + VariadicRegion:$overloads); +} + +def Numpy_UfuncReturnOp : Numpy_Op<"ufunc_return", [ + Terminator, + HasParent<"NUMPY::GenericUfuncOp">]> { + let summary = "Return a value from a generic_ufunc"; + let description = [{ + Must terminate the basic block of a generic_ufunc overload. + }]; + let arguments = (ins + Variadic:$operands + ); + + let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; +} + +// def Numpy_GenericUfuncOp : Numpy_Op<"generic_ufunc", [ +// IsolatedFromAbove, +// Symbol, +// NativeOpTrait<"FunctionLike">]> { +// let summary = "Defines a ufunc in terms of elementwise operations"; +// let description = [{ +// Defines a universal-function operator in terms of a region, containing +// a basic block of element-wise ops. +// }]; + +// let regions = (region AnyRegion:$body); + +// let extraClassDeclaration = [{ +// /// Returns the type of this function. +// FunctionType getType() { +// return getAttrOfType(getTypeAttrName()) +// .getValue() +// .cast(); +// } + +// /// Hook for OpTrait::FunctionLike, returns the number of function +// /// arguments. Depends on the type attribute being correct as checked by +// /// verifyType. +// unsigned getNumFuncArguments() { return getType().getInputs().size(); } + +// /// Hook for OpTrait::FunctionLike, returns the number of function results. +// /// Depends on the type attribute being correct as checked by verifyType. +// unsigned getNumFuncResults() { return getType().getResults().size(); } + +// /// Hook for OpTrait::FunctionLike, called after verifying that the 'type' +// /// attribute is present. This can check for preconditions of the +// /// getNumArguments hook not failing. +// LogicalResult verifyType(); +// }]; +// } + #endif // NPCOMP_DIALECT_NUMPY_NUMPY_OPS diff --git a/lib/Dialect/Numpy/NumpyOps.cpp b/lib/Dialect/Numpy/NumpyOps.cpp index 32330b2c8..89be312b5 100644 --- a/lib/Dialect/Numpy/NumpyOps.cpp +++ b/lib/Dialect/Numpy/NumpyOps.cpp @@ -7,12 +7,144 @@ //===----------------------------------------------------------------------===// #include "npcomp/Dialect/Numpy/NumpyOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/FunctionImplementation.h" #include "mlir/IR/OpImplementation.h" #include "npcomp/Dialect/Numpy/NumpyDialect.h" namespace mlir { namespace npcomp { namespace NUMPY { + +//===----------------------------------------------------------------------===// +// BuildinUfuncOp +//===----------------------------------------------------------------------===// + +static ParseResult parseBuiltinUfuncOp(OpAsmParser &parser, + OperationState *result) { + // Parse the name as a symbol. + StringAttr nameAttr; + if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), + result->attributes)) { + return failure(); + } + if (failed(parser.parseOptionalAttrDict(result->attributes))) { + return failure(); + } + return success(); +} + +static void printBuiltinUfuncOp(OpAsmPrinter &p, BuiltinUfuncOp op) { + p << op.getOperationName() << " "; + p.printSymbolName(op.getName()); + p.printOptionalAttrDict(op.getAttrs(), {SymbolTable::getSymbolAttrName()}); +} + +//===----------------------------------------------------------------------===// +// GenericUfuncOp +//===----------------------------------------------------------------------===// + +static ParseResult parseGenericUfuncOp(OpAsmParser &parser, + OperationState *result) { + Builder b(result->getContext()); + + // Parse the name as a symbol. + StringAttr nameAttr; + if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), + result->attributes)) + return failure(); + + // Parse the body of overloads. + if (parser.parseLParen()) + return failure(); + + SmallVector overloadTypes; + for (bool first = true;; first = false) { + if (first) { + if (parser.parseOptionalKeyword("overload")) + break; + } + if (!first) { + if (parser.parseOptionalComma()) + break; + if (parser.parseKeyword("overload")) + return failure(); + } + SmallVector argNames; + SmallVector argTypes; + SmallVector resultTypes; + SmallVector, 1> unusedAttrs; + bool isVariadic = false; + if (::mlir::impl::parseFunctionSignature(parser, false, argNames, argTypes, + unusedAttrs, isVariadic, + resultTypes, unusedAttrs)) + return failure(); + overloadTypes.push_back(TypeAttr::get( + FunctionType::get(argTypes, resultTypes, result->getContext()))); + auto *region = result->addRegion(); + if (parser.parseRegion(*region, argNames, argTypes)) + return failure(); + } + + if (parser.parseRParen()) + return failure(); + + // Parse 'attributes {...}' + if (parser.parseOptionalAttrDictWithKeyword(result->attributes)) + return failure(); + result->addAttribute(b.getIdentifier("overload_types"), + b.getArrayAttr(overloadTypes)); + + return success(); +} + +static void printGenericUfuncOp(OpAsmPrinter &p, GenericUfuncOp op) { + p << op.getOperationName() << " @" << op.getName() << "("; + bool first = true; + for (auto it : llvm::enumerate(op.getRegions())) { + auto *region = it.value(); + if (first) + first = false; + else + p << ", "; + if (region->empty()) { + p << "<>"; + continue; + } + + Block &entryBlock = region->front(); + p << "overload("; + if (it.index() >= op.overload_types().size()) { + p << "<>"; + continue; + } + TypeAttr tattr = op.overload_types()[it.index()].cast(); + FunctionType overloadType = tattr.getValue().dyn_cast(); + if (!overloadType) { + p << "<>"; + continue; + } + if (overloadType.getNumInputs() != entryBlock.getNumArguments()) { + p << "<>"; + continue; + } + + auto argTypes = entryBlock.getArgumentTypes(); + for (unsigned i = 0, e = entryBlock.getNumArguments(); i < e; ++i) { + auto arg = entryBlock.getArgument(i); + if (i > 0) + p << ", "; + p.printOperand(arg); + p << ": "; + p.printType(overloadType.getInputs()[i]); + } + p << ")"; + p.printArrowTypeList(overloadType.getResults()); + p.printRegion(*region, false, true); + } + p << ")"; +} + #define GET_OP_CLASSES #include "npcomp/Dialect/Numpy/NumpyOps.cpp.inc" } // namespace NUMPY diff --git a/test/Dialect/Numpy/ops.mlir b/test/Dialect/Numpy/ops.mlir index e26ce5cd5..e7fdf14ed 100644 --- a/test/Dialect/Numpy/ops.mlir +++ b/test/Dialect/Numpy/ops.mlir @@ -1,9 +1,29 @@ -// RUN: npcomp-opt %s | npcomp-opt | FileCheck %s - -// CHECK-LABEL: func @foo() -func @foo() -> i32 { - %0 = constant 1 : i32 - // CHECK: %{{.*}} = numpy.foo %{{.*}} : i32 - %res = numpy.foo %0 : i32 - return %res : i32 +// RUN: npcomp-opt -split-input-file %s | npcomp-opt | FileCheck --dump-input=fail %s +// ----- +// CHECK-LABEL: @builtin_ufunc +module @builtin_ufunc { + // CHECK: numpy.builtin_ufunc @numpy.add + numpy.builtin_ufunc @numpy.add + // CHECK: numpy.builtin_ufunc @numpy.custom_sub {some_attr = "foobar"} + numpy.builtin_ufunc @numpy.custom_sub { some_attr = "foobar" } +} + +// ----- +// CHECK-LABEL: @example_generic_ufunc +module @example_generic_ufunc { + // CHECK: numpy.generic_ufunc @numpy.add( + numpy.generic_ufunc @numpy.add ( + // CHECK-SAME: overload(%arg0: i32, %arg1: i32) -> i32 { + overload(%arg0: i32, %arg1: i32) -> i32 { + // CHECK: addi + %0 = addi %arg0, %arg1 : i32 + numpy.ufunc_return %0 : i32 + }, + // CHECK: overload(%arg0: f32, %arg1: f32) -> f32 { + overload(%arg0: f32, %arg1: f32) -> f32 { + // CHECK: addf + %0 = addf %arg0, %arg1 : f32 + numpy.ufunc_return %0 : f32 + } + ) } diff --git a/tools/cmake_configure.sh b/tools/cmake_configure.sh index b8da60d43..045bc64d6 100755 --- a/tools/cmake_configure.sh +++ b/tools/cmake_configure.sh @@ -6,6 +6,7 @@ td="$(realpath $(dirname $0)/..)" build_dir="$td/build" install_mlir="$td/install-mlir" build_mlir="$td/build-mlir" +declare -a extra_opts if ! [ -d "$install_mlir/include/mlir" ]; then echo "MLIR install path does not appear valid: $install_mlir" @@ -21,6 +22,16 @@ if [ -z "$python_exe" ]; then exit 1 fi +# Detect linker. +# TODO: Generalize this. +for probe_linker in /usr/bin/ld.lld-10; do + if which ld.lld-10; then + echo "Using linker $probe_linker" + extra_opts+=("-DCMAKE_EXE_LINKER_FLAGS=-fuse-ld=$probe_linker") + break + fi +done + set -x cmake -GNinja \ "-H$td" \ @@ -28,4 +39,5 @@ cmake -GNinja \ "-DPYTHON_EXECUTABLE=$python_exe" \ "-DMLIR_DIR=$install_mlir/lib/cmake/mlir" \ "-DLLVM_EXTERNAL_LIT=$build_mlir/bin/llvm-lit" \ + "${extra_opts[@]}" \ "$@" diff --git a/tools/run_lit.sh b/tools/run_lit.sh index b8006ffa3..70671d21e 100755 --- a/tools/run_lit.sh +++ b/tools/run_lit.sh @@ -32,4 +32,6 @@ for i in "$@"; do done set -x -cd $build_dir/test && python3 "$lit_exe" ${lit_args[@]} +cd $build_dir +ninja npcomp-opt +cd test && python3 "$lit_exe" ${lit_args[@]}