mirror of https://github.com/llvm/torch-mlir
Add builtin_ufunc and generic_ufunc ops.
parent
25e22aa4a5
commit
e845db8a20
|
@ -21,6 +21,9 @@ def Numpy_Dialect : Dialect {
|
|||
}
|
||||
|
||||
class Numpy_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<Numpy_Dialect, mnemonic, traits>;
|
||||
Op<Numpy_Dialect, mnemonic, traits> {
|
||||
let parser = [{ return parse$cppClass(parser, &result); }];
|
||||
let printer = [{ return print$cppClass(p, *this); }];
|
||||
}
|
||||
|
||||
#endif // NPCOMP_DIALECT_NUMPY_NUMPY_DIALECT
|
||||
|
|
|
@ -9,8 +9,12 @@
|
|||
#ifndef NPCOMP_DIALECT_NUMPY_NUMPY_OPS_H
|
||||
#define NPCOMP_DIALECT_NUMPY_NUMPY_OPS_H
|
||||
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/FunctionSupport.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/Interfaces/SideEffects.h"
|
||||
|
||||
namespace mlir {
|
||||
|
|
|
@ -11,20 +11,78 @@
|
|||
|
||||
include "NumpyDialect.td"
|
||||
include "mlir/Interfaces/SideEffects.td"
|
||||
include "mlir/IR/SymbolInterfaces.td"
|
||||
|
||||
def Numpy_FooOp : Numpy_Op<"foo", [NoSideEffect,
|
||||
SameOperandsAndResultType]> {
|
||||
let summary = "Temp op";
|
||||
def Numpy_BuiltinUfuncOp : Numpy_Op<"builtin_ufunc", [Symbol]> {
|
||||
let summary = "References a built-in universal function";
|
||||
let description = [{
|
||||
Temp op
|
||||
}];
|
||||
|
||||
let arguments = (ins I32:$input);
|
||||
let results = (outs I32:$res);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$input attr-dict `:` type($input)
|
||||
This module-level op binds by name to a fully-qualified numpy built-in
|
||||
ufunc (i.e. "numpy.add") and carries metadata associated with it.
|
||||
}];
|
||||
}
|
||||
|
||||
def Numpy_GenericUfuncOp : Numpy_Op<"generic_ufunc", [
|
||||
IsolatedFromAbove,
|
||||
Symbol]> {
|
||||
let summary = "Defines a ufunc in terms of overloaded element-wise functions";
|
||||
let description = [{
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TypeArrayAttr:$overload_types);
|
||||
|
||||
let regions = (region
|
||||
VariadicRegion<AnyRegion>:$overloads);
|
||||
}
|
||||
|
||||
def Numpy_UfuncReturnOp : Numpy_Op<"ufunc_return", [
|
||||
Terminator,
|
||||
HasParent<"NUMPY::GenericUfuncOp">]> {
|
||||
let summary = "Return a value from a generic_ufunc";
|
||||
let description = [{
|
||||
Must terminate the basic block of a generic_ufunc overload.
|
||||
}];
|
||||
let arguments = (ins
|
||||
Variadic<AnyType>:$operands
|
||||
);
|
||||
|
||||
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
|
||||
}
|
||||
|
||||
// def Numpy_GenericUfuncOp : Numpy_Op<"generic_ufunc", [
|
||||
// IsolatedFromAbove,
|
||||
// Symbol,
|
||||
// NativeOpTrait<"FunctionLike">]> {
|
||||
// let summary = "Defines a ufunc in terms of elementwise operations";
|
||||
// let description = [{
|
||||
// Defines a universal-function operator in terms of a region, containing
|
||||
// a basic block of element-wise ops.
|
||||
// }];
|
||||
|
||||
// let regions = (region AnyRegion:$body);
|
||||
|
||||
// let extraClassDeclaration = [{
|
||||
// /// Returns the type of this function.
|
||||
// FunctionType getType() {
|
||||
// return getAttrOfType<TypeAttr>(getTypeAttrName())
|
||||
// .getValue()
|
||||
// .cast<FunctionType>();
|
||||
// }
|
||||
|
||||
// /// Hook for OpTrait::FunctionLike, returns the number of function
|
||||
// /// arguments. Depends on the type attribute being correct as checked by
|
||||
// /// verifyType.
|
||||
// unsigned getNumFuncArguments() { return getType().getInputs().size(); }
|
||||
|
||||
// /// Hook for OpTrait::FunctionLike, returns the number of function results.
|
||||
// /// Depends on the type attribute being correct as checked by verifyType.
|
||||
// unsigned getNumFuncResults() { return getType().getResults().size(); }
|
||||
|
||||
// /// Hook for OpTrait::FunctionLike, called after verifying that the 'type'
|
||||
// /// attribute is present. This can check for preconditions of the
|
||||
// /// getNumArguments hook not failing.
|
||||
// LogicalResult verifyType();
|
||||
// }];
|
||||
// }
|
||||
|
||||
#endif // NPCOMP_DIALECT_NUMPY_NUMPY_OPS
|
||||
|
|
|
@ -7,12 +7,144 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/Dialect/Numpy/NumpyOps.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/FunctionImplementation.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "npcomp/Dialect/Numpy/NumpyDialect.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace npcomp {
|
||||
namespace NUMPY {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BuildinUfuncOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseBuiltinUfuncOp(OpAsmParser &parser,
|
||||
OperationState *result) {
|
||||
// Parse the name as a symbol.
|
||||
StringAttr nameAttr;
|
||||
if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
|
||||
result->attributes)) {
|
||||
return failure();
|
||||
}
|
||||
if (failed(parser.parseOptionalAttrDict(result->attributes))) {
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
static void printBuiltinUfuncOp(OpAsmPrinter &p, BuiltinUfuncOp op) {
|
||||
p << op.getOperationName() << " ";
|
||||
p.printSymbolName(op.getName());
|
||||
p.printOptionalAttrDict(op.getAttrs(), {SymbolTable::getSymbolAttrName()});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GenericUfuncOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseGenericUfuncOp(OpAsmParser &parser,
|
||||
OperationState *result) {
|
||||
Builder b(result->getContext());
|
||||
|
||||
// Parse the name as a symbol.
|
||||
StringAttr nameAttr;
|
||||
if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
|
||||
result->attributes))
|
||||
return failure();
|
||||
|
||||
// Parse the body of overloads.
|
||||
if (parser.parseLParen())
|
||||
return failure();
|
||||
|
||||
SmallVector<Attribute, 4> overloadTypes;
|
||||
for (bool first = true;; first = false) {
|
||||
if (first) {
|
||||
if (parser.parseOptionalKeyword("overload"))
|
||||
break;
|
||||
}
|
||||
if (!first) {
|
||||
if (parser.parseOptionalComma())
|
||||
break;
|
||||
if (parser.parseKeyword("overload"))
|
||||
return failure();
|
||||
}
|
||||
SmallVector<OpAsmParser::OperandType, 2> argNames;
|
||||
SmallVector<Type, 2> argTypes;
|
||||
SmallVector<Type, 1> resultTypes;
|
||||
SmallVector<SmallVector<NamedAttribute, 2>, 1> unusedAttrs;
|
||||
bool isVariadic = false;
|
||||
if (::mlir::impl::parseFunctionSignature(parser, false, argNames, argTypes,
|
||||
unusedAttrs, isVariadic,
|
||||
resultTypes, unusedAttrs))
|
||||
return failure();
|
||||
overloadTypes.push_back(TypeAttr::get(
|
||||
FunctionType::get(argTypes, resultTypes, result->getContext())));
|
||||
auto *region = result->addRegion();
|
||||
if (parser.parseRegion(*region, argNames, argTypes))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseRParen())
|
||||
return failure();
|
||||
|
||||
// Parse 'attributes {...}'
|
||||
if (parser.parseOptionalAttrDictWithKeyword(result->attributes))
|
||||
return failure();
|
||||
result->addAttribute(b.getIdentifier("overload_types"),
|
||||
b.getArrayAttr(overloadTypes));
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static void printGenericUfuncOp(OpAsmPrinter &p, GenericUfuncOp op) {
|
||||
p << op.getOperationName() << " @" << op.getName() << "(";
|
||||
bool first = true;
|
||||
for (auto it : llvm::enumerate(op.getRegions())) {
|
||||
auto *region = it.value();
|
||||
if (first)
|
||||
first = false;
|
||||
else
|
||||
p << ", ";
|
||||
if (region->empty()) {
|
||||
p << "<<OVERLOAD_ERROR>>";
|
||||
continue;
|
||||
}
|
||||
|
||||
Block &entryBlock = region->front();
|
||||
p << "overload(";
|
||||
if (it.index() >= op.overload_types().size()) {
|
||||
p << "<<MISSING OVERLOAD TYPE>>";
|
||||
continue;
|
||||
}
|
||||
TypeAttr tattr = op.overload_types()[it.index()].cast<TypeAttr>();
|
||||
FunctionType overloadType = tattr.getValue().dyn_cast<FunctionType>();
|
||||
if (!overloadType) {
|
||||
p << "<<ILLEGAL OVERLOAD TYPE>>";
|
||||
continue;
|
||||
}
|
||||
if (overloadType.getNumInputs() != entryBlock.getNumArguments()) {
|
||||
p << "<<OVERLOAD ARG MISMATCH>>";
|
||||
continue;
|
||||
}
|
||||
|
||||
auto argTypes = entryBlock.getArgumentTypes();
|
||||
for (unsigned i = 0, e = entryBlock.getNumArguments(); i < e; ++i) {
|
||||
auto arg = entryBlock.getArgument(i);
|
||||
if (i > 0)
|
||||
p << ", ";
|
||||
p.printOperand(arg);
|
||||
p << ": ";
|
||||
p.printType(overloadType.getInputs()[i]);
|
||||
}
|
||||
p << ")";
|
||||
p.printArrowTypeList(overloadType.getResults());
|
||||
p.printRegion(*region, false, true);
|
||||
}
|
||||
p << ")";
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "npcomp/Dialect/Numpy/NumpyOps.cpp.inc"
|
||||
} // namespace NUMPY
|
||||
|
|
|
@ -1,9 +1,29 @@
|
|||
// RUN: npcomp-opt %s | npcomp-opt | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @foo()
|
||||
func @foo() -> i32 {
|
||||
%0 = constant 1 : i32
|
||||
// CHECK: %{{.*}} = numpy.foo %{{.*}} : i32
|
||||
%res = numpy.foo %0 : i32
|
||||
return %res : i32
|
||||
// RUN: npcomp-opt -split-input-file %s | npcomp-opt | FileCheck --dump-input=fail %s
|
||||
// -----
|
||||
// CHECK-LABEL: @builtin_ufunc
|
||||
module @builtin_ufunc {
|
||||
// CHECK: numpy.builtin_ufunc @numpy.add
|
||||
numpy.builtin_ufunc @numpy.add
|
||||
// CHECK: numpy.builtin_ufunc @numpy.custom_sub {some_attr = "foobar"}
|
||||
numpy.builtin_ufunc @numpy.custom_sub { some_attr = "foobar" }
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: @example_generic_ufunc
|
||||
module @example_generic_ufunc {
|
||||
// CHECK: numpy.generic_ufunc @numpy.add(
|
||||
numpy.generic_ufunc @numpy.add (
|
||||
// CHECK-SAME: overload(%arg0: i32, %arg1: i32) -> i32 {
|
||||
overload(%arg0: i32, %arg1: i32) -> i32 {
|
||||
// CHECK: addi
|
||||
%0 = addi %arg0, %arg1 : i32
|
||||
numpy.ufunc_return %0 : i32
|
||||
},
|
||||
// CHECK: overload(%arg0: f32, %arg1: f32) -> f32 {
|
||||
overload(%arg0: f32, %arg1: f32) -> f32 {
|
||||
// CHECK: addf
|
||||
%0 = addf %arg0, %arg1 : f32
|
||||
numpy.ufunc_return %0 : f32
|
||||
}
|
||||
)
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ td="$(realpath $(dirname $0)/..)"
|
|||
build_dir="$td/build"
|
||||
install_mlir="$td/install-mlir"
|
||||
build_mlir="$td/build-mlir"
|
||||
declare -a extra_opts
|
||||
|
||||
if ! [ -d "$install_mlir/include/mlir" ]; then
|
||||
echo "MLIR install path does not appear valid: $install_mlir"
|
||||
|
@ -21,6 +22,16 @@ if [ -z "$python_exe" ]; then
|
|||
exit 1
|
||||
fi
|
||||
|
||||
# Detect linker.
|
||||
# TODO: Generalize this.
|
||||
for probe_linker in /usr/bin/ld.lld-10; do
|
||||
if which ld.lld-10; then
|
||||
echo "Using linker $probe_linker"
|
||||
extra_opts+=("-DCMAKE_EXE_LINKER_FLAGS=-fuse-ld=$probe_linker")
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
||||
set -x
|
||||
cmake -GNinja \
|
||||
"-H$td" \
|
||||
|
@ -28,4 +39,5 @@ cmake -GNinja \
|
|||
"-DPYTHON_EXECUTABLE=$python_exe" \
|
||||
"-DMLIR_DIR=$install_mlir/lib/cmake/mlir" \
|
||||
"-DLLVM_EXTERNAL_LIT=$build_mlir/bin/llvm-lit" \
|
||||
"${extra_opts[@]}" \
|
||||
"$@"
|
||||
|
|
|
@ -32,4 +32,6 @@ for i in "$@"; do
|
|||
done
|
||||
|
||||
set -x
|
||||
cd $build_dir/test && python3 "$lit_exe" ${lit_args[@]}
|
||||
cd $build_dir
|
||||
ninja npcomp-opt
|
||||
cd test && python3 "$lit_exe" ${lit_args[@]}
|
||||
|
|
Loading…
Reference in New Issue