mirror of https://github.com/llvm/torch-mlir
Add numpy.ufunc_call op.
parent
c4a192d5c9
commit
b4425fe1d2
|
@ -11,6 +11,10 @@
|
||||||
|
|
||||||
include "mlir/IR/OpBase.td"
|
include "mlir/IR/OpBase.td"
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Dialect definition
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def Numpy_Dialect : Dialect {
|
def Numpy_Dialect : Dialect {
|
||||||
let name = "numpy";
|
let name = "numpy";
|
||||||
let summary = "Core numpy dialect";
|
let summary = "Core numpy dialect";
|
||||||
|
@ -20,10 +24,20 @@ def Numpy_Dialect : Dialect {
|
||||||
let cppNamespace = "numpy";
|
let cppNamespace = "numpy";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Op templates
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
class Numpy_Op<string mnemonic, list<OpTrait> traits = []> :
|
class Numpy_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||||
Op<Numpy_Dialect, mnemonic, traits> {
|
Op<Numpy_Dialect, mnemonic, traits> {
|
||||||
let parser = [{ return parse$cppClass(parser, &result); }];
|
let parser = [{ return parse$cppClass(parser, &result); }];
|
||||||
let printer = [{ return print$cppClass(p, *this); }];
|
let printer = [{ return print$cppClass(p, *this); }];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Type predicates
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def Numpy_AnyArray : TensorOf<[AnyType]>;
|
||||||
|
|
||||||
#endif // NPCOMP_DIALECT_NUMPY_NUMPY_DIALECT
|
#endif // NPCOMP_DIALECT_NUMPY_NUMPY_DIALECT
|
||||||
|
|
|
@ -49,4 +49,24 @@ def Numpy_UfuncReturnOp : Numpy_Op<"ufunc_return", [
|
||||||
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
|
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Numpy_UfuncCallOp : Numpy_Op<"ufunc_call", []> {
|
||||||
|
let summary = "Default operation on a func";
|
||||||
|
let description = [{
|
||||||
|
Invokes a ufunc with the given arguments. This variant models the __call__
|
||||||
|
behavior of a python ufunc except that it does not model the `out`
|
||||||
|
parameter, which indicates an in-place update.
|
||||||
|
}];
|
||||||
|
let arguments = (ins
|
||||||
|
FlatSymbolRefAttr:$ufunc_ref,
|
||||||
|
Variadic<Numpy_AnyArray>:$operands
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
Numpy_AnyArray:$result
|
||||||
|
);
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$ufunc_ref `(` operands `)` attr-dict `:` functional-type(operands, results)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
#endif // NPCOMP_DIALECT_NUMPY_NUMPY_OPS
|
#endif // NPCOMP_DIALECT_NUMPY_NUMPY_OPS
|
||||||
|
|
|
@ -27,3 +27,20 @@ module @example_generic_ufunc {
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// CHECK-LABEL: @ufunc_apply_ops
|
||||||
|
module @ufunc_apply_ops {
|
||||||
|
numpy.generic_ufunc @numpy.add (
|
||||||
|
overload(%arg0: i32, %arg1: i32) -> i32 {
|
||||||
|
%0 = addi %arg0, %arg1 : i32
|
||||||
|
numpy.ufunc_return %0 : i32
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func @example(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
|
||||||
|
%0 = numpy.ufunc_call @numpy.add(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>)
|
||||||
|
-> tensor<*xi32>
|
||||||
|
return %0 : tensor<*xi32>
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue