diff --git a/include/npcomp/Dialect/Numpy/NumpyDialect.td b/include/npcomp/Dialect/Numpy/NumpyDialect.td index a5df0476f..af430b838 100644 --- a/include/npcomp/Dialect/Numpy/NumpyDialect.td +++ b/include/npcomp/Dialect/Numpy/NumpyDialect.td @@ -11,6 +11,10 @@ include "mlir/IR/OpBase.td" +//===----------------------------------------------------------------------===// +// Dialect definition +//===----------------------------------------------------------------------===// + def Numpy_Dialect : Dialect { let name = "numpy"; let summary = "Core numpy dialect"; @@ -20,10 +24,20 @@ def Numpy_Dialect : Dialect { let cppNamespace = "numpy"; } +//===----------------------------------------------------------------------===// +// Op templates +//===----------------------------------------------------------------------===// + class Numpy_Op traits = []> : Op { let parser = [{ return parse$cppClass(parser, &result); }]; let printer = [{ return print$cppClass(p, *this); }]; } +//===----------------------------------------------------------------------===// +// Type predicates +//===----------------------------------------------------------------------===// + +def Numpy_AnyArray : TensorOf<[AnyType]>; + #endif // NPCOMP_DIALECT_NUMPY_NUMPY_DIALECT diff --git a/include/npcomp/Dialect/Numpy/NumpyOps.td b/include/npcomp/Dialect/Numpy/NumpyOps.td index b736c3f8d..dfd09fb8a 100644 --- a/include/npcomp/Dialect/Numpy/NumpyOps.td +++ b/include/npcomp/Dialect/Numpy/NumpyOps.td @@ -49,4 +49,24 @@ def Numpy_UfuncReturnOp : Numpy_Op<"ufunc_return", [ let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; } +def Numpy_UfuncCallOp : Numpy_Op<"ufunc_call", []> { + let summary = "Default operation on a func"; + let description = [{ + Invokes a ufunc with the given arguments. This variant models the __call__ + behavior of a python ufunc except that it does not model the `out` + parameter, which indicates an in-place update. + }]; + let arguments = (ins + FlatSymbolRefAttr:$ufunc_ref, + Variadic:$operands + ); + let results = (outs + Numpy_AnyArray:$result + ); + + let assemblyFormat = [{ + $ufunc_ref `(` operands `)` attr-dict `:` functional-type(operands, results) + }]; +} + #endif // NPCOMP_DIALECT_NUMPY_NUMPY_OPS diff --git a/test/Dialect/Numpy/ops.mlir b/test/Dialect/Numpy/ops.mlir index e7fdf14ed..20d4cb0dc 100644 --- a/test/Dialect/Numpy/ops.mlir +++ b/test/Dialect/Numpy/ops.mlir @@ -27,3 +27,20 @@ module @example_generic_ufunc { } ) } + +// ----- +// CHECK-LABEL: @ufunc_apply_ops +module @ufunc_apply_ops { + numpy.generic_ufunc @numpy.add ( + overload(%arg0: i32, %arg1: i32) -> i32 { + %0 = addi %arg0, %arg1 : i32 + numpy.ufunc_return %0 : i32 + } + ) + + func @example(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> { + %0 = numpy.ufunc_call @numpy.add(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) + -> tensor<*xi32> + return %0 : tensor<*xi32> + } +}