mirror of https://github.com/llvm/torch-mlir
Add numpy.ufunc_call op.
parent
c4a192d5c9
commit
b4425fe1d2
|
@ -11,6 +11,10 @@
|
|||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dialect definition
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Numpy_Dialect : Dialect {
|
||||
let name = "numpy";
|
||||
let summary = "Core numpy dialect";
|
||||
|
@ -20,10 +24,20 @@ def Numpy_Dialect : Dialect {
|
|||
let cppNamespace = "numpy";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Op templates
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class Numpy_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<Numpy_Dialect, mnemonic, traits> {
|
||||
let parser = [{ return parse$cppClass(parser, &result); }];
|
||||
let printer = [{ return print$cppClass(p, *this); }];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type predicates
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Numpy_AnyArray : TensorOf<[AnyType]>;
|
||||
|
||||
#endif // NPCOMP_DIALECT_NUMPY_NUMPY_DIALECT
|
||||
|
|
|
@ -49,4 +49,24 @@ def Numpy_UfuncReturnOp : Numpy_Op<"ufunc_return", [
|
|||
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
|
||||
}
|
||||
|
||||
def Numpy_UfuncCallOp : Numpy_Op<"ufunc_call", []> {
|
||||
let summary = "Default operation on a func";
|
||||
let description = [{
|
||||
Invokes a ufunc with the given arguments. This variant models the __call__
|
||||
behavior of a python ufunc except that it does not model the `out`
|
||||
parameter, which indicates an in-place update.
|
||||
}];
|
||||
let arguments = (ins
|
||||
FlatSymbolRefAttr:$ufunc_ref,
|
||||
Variadic<Numpy_AnyArray>:$operands
|
||||
);
|
||||
let results = (outs
|
||||
Numpy_AnyArray:$result
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$ufunc_ref `(` operands `)` attr-dict `:` functional-type(operands, results)
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // NPCOMP_DIALECT_NUMPY_NUMPY_OPS
|
||||
|
|
|
@ -27,3 +27,20 @@ module @example_generic_ufunc {
|
|||
}
|
||||
)
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: @ufunc_apply_ops
|
||||
module @ufunc_apply_ops {
|
||||
numpy.generic_ufunc @numpy.add (
|
||||
overload(%arg0: i32, %arg1: i32) -> i32 {
|
||||
%0 = addi %arg0, %arg1 : i32
|
||||
numpy.ufunc_return %0 : i32
|
||||
}
|
||||
)
|
||||
|
||||
func @example(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
|
||||
%0 = numpy.ufunc_call @numpy.add(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>)
|
||||
-> tensor<*xi32>
|
||||
return %0 : tensor<*xi32>
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue