mirror of https://github.com/llvm/torch-mlir
Add builtin_ufunc and generic_ufunc ops.
parent
25e22aa4a5
commit
e845db8a20
|
@ -21,6 +21,9 @@ def Numpy_Dialect : Dialect {
|
||||||
}
|
}
|
||||||
|
|
||||||
class Numpy_Op<string mnemonic, list<OpTrait> traits = []> :
|
class Numpy_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||||
Op<Numpy_Dialect, mnemonic, traits>;
|
Op<Numpy_Dialect, mnemonic, traits> {
|
||||||
|
let parser = [{ return parse$cppClass(parser, &result); }];
|
||||||
|
let printer = [{ return print$cppClass(p, *this); }];
|
||||||
|
}
|
||||||
|
|
||||||
#endif // NPCOMP_DIALECT_NUMPY_NUMPY_DIALECT
|
#endif // NPCOMP_DIALECT_NUMPY_NUMPY_DIALECT
|
||||||
|
|
|
@ -9,8 +9,12 @@
|
||||||
#ifndef NPCOMP_DIALECT_NUMPY_NUMPY_OPS_H
|
#ifndef NPCOMP_DIALECT_NUMPY_NUMPY_OPS_H
|
||||||
#define NPCOMP_DIALECT_NUMPY_NUMPY_OPS_H
|
#define NPCOMP_DIALECT_NUMPY_NUMPY_OPS_H
|
||||||
|
|
||||||
|
#include "mlir/IR/Attributes.h"
|
||||||
#include "mlir/IR/Dialect.h"
|
#include "mlir/IR/Dialect.h"
|
||||||
|
#include "mlir/IR/FunctionSupport.h"
|
||||||
#include "mlir/IR/OpDefinition.h"
|
#include "mlir/IR/OpDefinition.h"
|
||||||
|
#include "mlir/IR/StandardTypes.h"
|
||||||
|
#include "mlir/IR/SymbolTable.h"
|
||||||
#include "mlir/Interfaces/SideEffects.h"
|
#include "mlir/Interfaces/SideEffects.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
|
|
@ -11,20 +11,78 @@
|
||||||
|
|
||||||
include "NumpyDialect.td"
|
include "NumpyDialect.td"
|
||||||
include "mlir/Interfaces/SideEffects.td"
|
include "mlir/Interfaces/SideEffects.td"
|
||||||
|
include "mlir/IR/SymbolInterfaces.td"
|
||||||
|
|
||||||
def Numpy_FooOp : Numpy_Op<"foo", [NoSideEffect,
|
def Numpy_BuiltinUfuncOp : Numpy_Op<"builtin_ufunc", [Symbol]> {
|
||||||
SameOperandsAndResultType]> {
|
let summary = "References a built-in universal function";
|
||||||
let summary = "Temp op";
|
|
||||||
let description = [{
|
let description = [{
|
||||||
Temp op
|
This module-level op binds by name to a fully-qualified numpy built-in
|
||||||
}];
|
ufunc (i.e. "numpy.add") and carries metadata associated with it.
|
||||||
|
|
||||||
let arguments = (ins I32:$input);
|
|
||||||
let results = (outs I32:$res);
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
|
||||||
$input attr-dict `:` type($input)
|
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Numpy_GenericUfuncOp : Numpy_Op<"generic_ufunc", [
|
||||||
|
IsolatedFromAbove,
|
||||||
|
Symbol]> {
|
||||||
|
let summary = "Defines a ufunc in terms of overloaded element-wise functions";
|
||||||
|
let description = [{
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
TypeArrayAttr:$overload_types);
|
||||||
|
|
||||||
|
let regions = (region
|
||||||
|
VariadicRegion<AnyRegion>:$overloads);
|
||||||
|
}
|
||||||
|
|
||||||
|
def Numpy_UfuncReturnOp : Numpy_Op<"ufunc_return", [
|
||||||
|
Terminator,
|
||||||
|
HasParent<"NUMPY::GenericUfuncOp">]> {
|
||||||
|
let summary = "Return a value from a generic_ufunc";
|
||||||
|
let description = [{
|
||||||
|
Must terminate the basic block of a generic_ufunc overload.
|
||||||
|
}];
|
||||||
|
let arguments = (ins
|
||||||
|
Variadic<AnyType>:$operands
|
||||||
|
);
|
||||||
|
|
||||||
|
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
|
||||||
|
}
|
||||||
|
|
||||||
|
// def Numpy_GenericUfuncOp : Numpy_Op<"generic_ufunc", [
|
||||||
|
// IsolatedFromAbove,
|
||||||
|
// Symbol,
|
||||||
|
// NativeOpTrait<"FunctionLike">]> {
|
||||||
|
// let summary = "Defines a ufunc in terms of elementwise operations";
|
||||||
|
// let description = [{
|
||||||
|
// Defines a universal-function operator in terms of a region, containing
|
||||||
|
// a basic block of element-wise ops.
|
||||||
|
// }];
|
||||||
|
|
||||||
|
// let regions = (region AnyRegion:$body);
|
||||||
|
|
||||||
|
// let extraClassDeclaration = [{
|
||||||
|
// /// Returns the type of this function.
|
||||||
|
// FunctionType getType() {
|
||||||
|
// return getAttrOfType<TypeAttr>(getTypeAttrName())
|
||||||
|
// .getValue()
|
||||||
|
// .cast<FunctionType>();
|
||||||
|
// }
|
||||||
|
|
||||||
|
// /// Hook for OpTrait::FunctionLike, returns the number of function
|
||||||
|
// /// arguments. Depends on the type attribute being correct as checked by
|
||||||
|
// /// verifyType.
|
||||||
|
// unsigned getNumFuncArguments() { return getType().getInputs().size(); }
|
||||||
|
|
||||||
|
// /// Hook for OpTrait::FunctionLike, returns the number of function results.
|
||||||
|
// /// Depends on the type attribute being correct as checked by verifyType.
|
||||||
|
// unsigned getNumFuncResults() { return getType().getResults().size(); }
|
||||||
|
|
||||||
|
// /// Hook for OpTrait::FunctionLike, called after verifying that the 'type'
|
||||||
|
// /// attribute is present. This can check for preconditions of the
|
||||||
|
// /// getNumArguments hook not failing.
|
||||||
|
// LogicalResult verifyType();
|
||||||
|
// }];
|
||||||
|
// }
|
||||||
|
|
||||||
#endif // NPCOMP_DIALECT_NUMPY_NUMPY_OPS
|
#endif // NPCOMP_DIALECT_NUMPY_NUMPY_OPS
|
||||||
|
|
|
@ -7,12 +7,144 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "npcomp/Dialect/Numpy/NumpyOps.h"
|
#include "npcomp/Dialect/Numpy/NumpyOps.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/FunctionImplementation.h"
|
||||||
#include "mlir/IR/OpImplementation.h"
|
#include "mlir/IR/OpImplementation.h"
|
||||||
#include "npcomp/Dialect/Numpy/NumpyDialect.h"
|
#include "npcomp/Dialect/Numpy/NumpyDialect.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace npcomp {
|
namespace npcomp {
|
||||||
namespace NUMPY {
|
namespace NUMPY {
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// BuildinUfuncOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
static ParseResult parseBuiltinUfuncOp(OpAsmParser &parser,
|
||||||
|
OperationState *result) {
|
||||||
|
// Parse the name as a symbol.
|
||||||
|
StringAttr nameAttr;
|
||||||
|
if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
|
||||||
|
result->attributes)) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (failed(parser.parseOptionalAttrDict(result->attributes))) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
static void printBuiltinUfuncOp(OpAsmPrinter &p, BuiltinUfuncOp op) {
|
||||||
|
p << op.getOperationName() << " ";
|
||||||
|
p.printSymbolName(op.getName());
|
||||||
|
p.printOptionalAttrDict(op.getAttrs(), {SymbolTable::getSymbolAttrName()});
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// GenericUfuncOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
static ParseResult parseGenericUfuncOp(OpAsmParser &parser,
|
||||||
|
OperationState *result) {
|
||||||
|
Builder b(result->getContext());
|
||||||
|
|
||||||
|
// Parse the name as a symbol.
|
||||||
|
StringAttr nameAttr;
|
||||||
|
if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
|
||||||
|
result->attributes))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Parse the body of overloads.
|
||||||
|
if (parser.parseLParen())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<Attribute, 4> overloadTypes;
|
||||||
|
for (bool first = true;; first = false) {
|
||||||
|
if (first) {
|
||||||
|
if (parser.parseOptionalKeyword("overload"))
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (!first) {
|
||||||
|
if (parser.parseOptionalComma())
|
||||||
|
break;
|
||||||
|
if (parser.parseKeyword("overload"))
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
SmallVector<OpAsmParser::OperandType, 2> argNames;
|
||||||
|
SmallVector<Type, 2> argTypes;
|
||||||
|
SmallVector<Type, 1> resultTypes;
|
||||||
|
SmallVector<SmallVector<NamedAttribute, 2>, 1> unusedAttrs;
|
||||||
|
bool isVariadic = false;
|
||||||
|
if (::mlir::impl::parseFunctionSignature(parser, false, argNames, argTypes,
|
||||||
|
unusedAttrs, isVariadic,
|
||||||
|
resultTypes, unusedAttrs))
|
||||||
|
return failure();
|
||||||
|
overloadTypes.push_back(TypeAttr::get(
|
||||||
|
FunctionType::get(argTypes, resultTypes, result->getContext())));
|
||||||
|
auto *region = result->addRegion();
|
||||||
|
if (parser.parseRegion(*region, argNames, argTypes))
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (parser.parseRParen())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Parse 'attributes {...}'
|
||||||
|
if (parser.parseOptionalAttrDictWithKeyword(result->attributes))
|
||||||
|
return failure();
|
||||||
|
result->addAttribute(b.getIdentifier("overload_types"),
|
||||||
|
b.getArrayAttr(overloadTypes));
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
static void printGenericUfuncOp(OpAsmPrinter &p, GenericUfuncOp op) {
|
||||||
|
p << op.getOperationName() << " @" << op.getName() << "(";
|
||||||
|
bool first = true;
|
||||||
|
for (auto it : llvm::enumerate(op.getRegions())) {
|
||||||
|
auto *region = it.value();
|
||||||
|
if (first)
|
||||||
|
first = false;
|
||||||
|
else
|
||||||
|
p << ", ";
|
||||||
|
if (region->empty()) {
|
||||||
|
p << "<<OVERLOAD_ERROR>>";
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
Block &entryBlock = region->front();
|
||||||
|
p << "overload(";
|
||||||
|
if (it.index() >= op.overload_types().size()) {
|
||||||
|
p << "<<MISSING OVERLOAD TYPE>>";
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
TypeAttr tattr = op.overload_types()[it.index()].cast<TypeAttr>();
|
||||||
|
FunctionType overloadType = tattr.getValue().dyn_cast<FunctionType>();
|
||||||
|
if (!overloadType) {
|
||||||
|
p << "<<ILLEGAL OVERLOAD TYPE>>";
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (overloadType.getNumInputs() != entryBlock.getNumArguments()) {
|
||||||
|
p << "<<OVERLOAD ARG MISMATCH>>";
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto argTypes = entryBlock.getArgumentTypes();
|
||||||
|
for (unsigned i = 0, e = entryBlock.getNumArguments(); i < e; ++i) {
|
||||||
|
auto arg = entryBlock.getArgument(i);
|
||||||
|
if (i > 0)
|
||||||
|
p << ", ";
|
||||||
|
p.printOperand(arg);
|
||||||
|
p << ": ";
|
||||||
|
p.printType(overloadType.getInputs()[i]);
|
||||||
|
}
|
||||||
|
p << ")";
|
||||||
|
p.printArrowTypeList(overloadType.getResults());
|
||||||
|
p.printRegion(*region, false, true);
|
||||||
|
}
|
||||||
|
p << ")";
|
||||||
|
}
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "npcomp/Dialect/Numpy/NumpyOps.cpp.inc"
|
#include "npcomp/Dialect/Numpy/NumpyOps.cpp.inc"
|
||||||
} // namespace NUMPY
|
} // namespace NUMPY
|
||||||
|
|
|
@ -1,9 +1,29 @@
|
||||||
// RUN: npcomp-opt %s | npcomp-opt | FileCheck %s
|
// RUN: npcomp-opt -split-input-file %s | npcomp-opt | FileCheck --dump-input=fail %s
|
||||||
|
// -----
|
||||||
// CHECK-LABEL: func @foo()
|
// CHECK-LABEL: @builtin_ufunc
|
||||||
func @foo() -> i32 {
|
module @builtin_ufunc {
|
||||||
%0 = constant 1 : i32
|
// CHECK: numpy.builtin_ufunc @numpy.add
|
||||||
// CHECK: %{{.*}} = numpy.foo %{{.*}} : i32
|
numpy.builtin_ufunc @numpy.add
|
||||||
%res = numpy.foo %0 : i32
|
// CHECK: numpy.builtin_ufunc @numpy.custom_sub {some_attr = "foobar"}
|
||||||
return %res : i32
|
numpy.builtin_ufunc @numpy.custom_sub { some_attr = "foobar" }
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// CHECK-LABEL: @example_generic_ufunc
|
||||||
|
module @example_generic_ufunc {
|
||||||
|
// CHECK: numpy.generic_ufunc @numpy.add(
|
||||||
|
numpy.generic_ufunc @numpy.add (
|
||||||
|
// CHECK-SAME: overload(%arg0: i32, %arg1: i32) -> i32 {
|
||||||
|
overload(%arg0: i32, %arg1: i32) -> i32 {
|
||||||
|
// CHECK: addi
|
||||||
|
%0 = addi %arg0, %arg1 : i32
|
||||||
|
numpy.ufunc_return %0 : i32
|
||||||
|
},
|
||||||
|
// CHECK: overload(%arg0: f32, %arg1: f32) -> f32 {
|
||||||
|
overload(%arg0: f32, %arg1: f32) -> f32 {
|
||||||
|
// CHECK: addf
|
||||||
|
%0 = addf %arg0, %arg1 : f32
|
||||||
|
numpy.ufunc_return %0 : f32
|
||||||
|
}
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,6 +6,7 @@ td="$(realpath $(dirname $0)/..)"
|
||||||
build_dir="$td/build"
|
build_dir="$td/build"
|
||||||
install_mlir="$td/install-mlir"
|
install_mlir="$td/install-mlir"
|
||||||
build_mlir="$td/build-mlir"
|
build_mlir="$td/build-mlir"
|
||||||
|
declare -a extra_opts
|
||||||
|
|
||||||
if ! [ -d "$install_mlir/include/mlir" ]; then
|
if ! [ -d "$install_mlir/include/mlir" ]; then
|
||||||
echo "MLIR install path does not appear valid: $install_mlir"
|
echo "MLIR install path does not appear valid: $install_mlir"
|
||||||
|
@ -21,6 +22,16 @@ if [ -z "$python_exe" ]; then
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
# Detect linker.
|
||||||
|
# TODO: Generalize this.
|
||||||
|
for probe_linker in /usr/bin/ld.lld-10; do
|
||||||
|
if which ld.lld-10; then
|
||||||
|
echo "Using linker $probe_linker"
|
||||||
|
extra_opts+=("-DCMAKE_EXE_LINKER_FLAGS=-fuse-ld=$probe_linker")
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
set -x
|
set -x
|
||||||
cmake -GNinja \
|
cmake -GNinja \
|
||||||
"-H$td" \
|
"-H$td" \
|
||||||
|
@ -28,4 +39,5 @@ cmake -GNinja \
|
||||||
"-DPYTHON_EXECUTABLE=$python_exe" \
|
"-DPYTHON_EXECUTABLE=$python_exe" \
|
||||||
"-DMLIR_DIR=$install_mlir/lib/cmake/mlir" \
|
"-DMLIR_DIR=$install_mlir/lib/cmake/mlir" \
|
||||||
"-DLLVM_EXTERNAL_LIT=$build_mlir/bin/llvm-lit" \
|
"-DLLVM_EXTERNAL_LIT=$build_mlir/bin/llvm-lit" \
|
||||||
|
"${extra_opts[@]}" \
|
||||||
"$@"
|
"$@"
|
||||||
|
|
|
@ -32,4 +32,6 @@ for i in "$@"; do
|
||||||
done
|
done
|
||||||
|
|
||||||
set -x
|
set -x
|
||||||
cd $build_dir/test && python3 "$lit_exe" ${lit_args[@]}
|
cd $build_dir
|
||||||
|
ninja npcomp-opt
|
||||||
|
cd test && python3 "$lit_exe" ${lit_args[@]}
|
||||||
|
|
Loading…
Reference in New Issue