Add numpy.ufunc_call op.

pull/1/head
Stella Laurenzo 2020-04-29 17:49:56 -07:00
parent c4a192d5c9
commit b4425fe1d2
3 changed files with 51 additions and 0 deletions

View File

@ -11,6 +11,10 @@
include "mlir/IR/OpBase.td"
//===----------------------------------------------------------------------===//
// Dialect definition
//===----------------------------------------------------------------------===//
def Numpy_Dialect : Dialect {
let name = "numpy";
let summary = "Core numpy dialect";
@ -20,10 +24,20 @@ def Numpy_Dialect : Dialect {
let cppNamespace = "numpy";
}
//===----------------------------------------------------------------------===//
// Op templates
//===----------------------------------------------------------------------===//
class Numpy_Op<string mnemonic, list<OpTrait> traits = []> :
Op<Numpy_Dialect, mnemonic, traits> {
let parser = [{ return parse$cppClass(parser, &result); }];
let printer = [{ return print$cppClass(p, *this); }];
}
//===----------------------------------------------------------------------===//
// Type predicates
//===----------------------------------------------------------------------===//
def Numpy_AnyArray : TensorOf<[AnyType]>;
#endif // NPCOMP_DIALECT_NUMPY_NUMPY_DIALECT

View File

@ -49,4 +49,24 @@ def Numpy_UfuncReturnOp : Numpy_Op<"ufunc_return", [
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
}
def Numpy_UfuncCallOp : Numpy_Op<"ufunc_call", []> {
let summary = "Default operation on a func";
let description = [{
Invokes a ufunc with the given arguments. This variant models the __call__
behavior of a python ufunc except that it does not model the `out`
parameter, which indicates an in-place update.
}];
let arguments = (ins
FlatSymbolRefAttr:$ufunc_ref,
Variadic<Numpy_AnyArray>:$operands
);
let results = (outs
Numpy_AnyArray:$result
);
let assemblyFormat = [{
$ufunc_ref `(` operands `)` attr-dict `:` functional-type(operands, results)
}];
}
#endif // NPCOMP_DIALECT_NUMPY_NUMPY_OPS

View File

@ -27,3 +27,20 @@ module @example_generic_ufunc {
}
)
}
// -----
// CHECK-LABEL: @ufunc_apply_ops
module @ufunc_apply_ops {
numpy.generic_ufunc @numpy.add (
overload(%arg0: i32, %arg1: i32) -> i32 {
%0 = addi %arg0, %arg1 : i32
numpy.ufunc_return %0 : i32
}
)
func @example(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
%0 = numpy.ufunc_call @numpy.add(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>)
-> tensor<*xi32>
return %0 : tensor<*xi32>
}
}