From d3632af675938c98b2c7260152193a468b384b3c Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Wed, 29 Apr 2020 18:20:42 -0700 Subject: [PATCH] Add !numpy.any_dtype dialect type. --- include/npcomp/Dialect/Numpy/NumpyDialect.h | 25 +++++++++++++++++-- include/npcomp/Dialect/Numpy/NumpyDialect.td | 14 ++++++++++- include/npcomp/Dialect/Numpy/NumpyOps.h | 4 +-- include/npcomp/Dialect/Numpy/NumpyOps.td | 2 +- lib/Dialect/Numpy/NumpyDialect.cpp | 26 +++++++++++++++++++- lib/Dialect/Numpy/NumpyOps.cpp | 4 +-- test/Dialect/Numpy/ops.mlir | 6 +++++ tools/npcomp-opt/npcomp-opt.cpp | 2 +- 8 files changed, 73 insertions(+), 10 deletions(-) diff --git a/include/npcomp/Dialect/Numpy/NumpyDialect.h b/include/npcomp/Dialect/Numpy/NumpyDialect.h index b8bfe7956..11b5adcfc 100644 --- a/include/npcomp/Dialect/Numpy/NumpyDialect.h +++ b/include/npcomp/Dialect/Numpy/NumpyDialect.h @@ -13,11 +13,32 @@ namespace mlir { namespace NPCOMP { -namespace numpy { +namespace Numpy { + +namespace NumpyTypes { +enum Kind { + AnyDtypeType = Type::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE, + LAST_NUMPY_TYPE = AnyDtypeType +}; +} // namespace NumpyTypes + +// The singleton type representing an unknown dtype. +class AnyDtypeType : public Type::TypeBase { +public: + using Base::Base; + + static AnyDtypeType get(MLIRContext *context) { + return Base::get(context, NumpyTypes::Kind::AnyDtypeType); + } + + static bool kindof(unsigned kind) { + return kind == NumpyTypes::Kind::AnyDtypeType; + } +}; #include "npcomp/Dialect/Numpy/NumpyOpsDialect.h.inc" -} // namespace numpy +} // namespace Numpy } // namespace NPCOMP } // namespace mlir diff --git a/include/npcomp/Dialect/Numpy/NumpyDialect.td b/include/npcomp/Dialect/Numpy/NumpyDialect.td index af430b838..49fd8ffb6 100644 --- a/include/npcomp/Dialect/Numpy/NumpyDialect.td +++ b/include/npcomp/Dialect/Numpy/NumpyDialect.td @@ -21,7 +21,7 @@ def Numpy_Dialect : Dialect { let description = [{ Dialect of types and core numpy ops and abstractions. }]; - let cppNamespace = "numpy"; + let cppNamespace = "Numpy"; } //===----------------------------------------------------------------------===// @@ -34,6 +34,18 @@ class Numpy_Op traits = []> : let printer = [{ return print$cppClass(p, *this); }]; } +//===----------------------------------------------------------------------===// +// Dialect types +//===----------------------------------------------------------------------===// + +def Numpy_AnyDtype : DialectType()">, "any dtype">, + BuildableType<"$_builder.getType::mlir::NPCOMP::Numpy::AnyDtypeType()"> { + let typeDescription = [{ + Placeholder for an unknown dtype in a tensor. + }]; +} + //===----------------------------------------------------------------------===// // Type predicates //===----------------------------------------------------------------------===// diff --git a/include/npcomp/Dialect/Numpy/NumpyOps.h b/include/npcomp/Dialect/Numpy/NumpyOps.h index f6b0363bf..bd621cb8a 100644 --- a/include/npcomp/Dialect/Numpy/NumpyOps.h +++ b/include/npcomp/Dialect/Numpy/NumpyOps.h @@ -19,12 +19,12 @@ namespace mlir { namespace NPCOMP { -namespace numpy { +namespace Numpy { #define GET_OP_CLASSES #include "npcomp/Dialect/Numpy/NumpyOps.h.inc" -} // namespace numpy +} // namespace Numpy } // namespace NPCOMP } // namespace mlir diff --git a/include/npcomp/Dialect/Numpy/NumpyOps.td b/include/npcomp/Dialect/Numpy/NumpyOps.td index dfd09fb8a..3a4ef40db 100644 --- a/include/npcomp/Dialect/Numpy/NumpyOps.td +++ b/include/npcomp/Dialect/Numpy/NumpyOps.td @@ -37,7 +37,7 @@ def Numpy_GenericUfuncOp : Numpy_Op<"generic_ufunc", [ def Numpy_UfuncReturnOp : Numpy_Op<"ufunc_return", [ Terminator, - HasParent<"numpy::GenericUfuncOp">]> { + HasParent<"Numpy::GenericUfuncOp">]> { let summary = "Return a value from a generic_ufunc"; let description = [{ Must terminate the basic block of a generic_ufunc overload. diff --git a/lib/Dialect/Numpy/NumpyDialect.cpp b/lib/Dialect/Numpy/NumpyDialect.cpp index 71de11924..f1d67b78d 100644 --- a/lib/Dialect/Numpy/NumpyDialect.cpp +++ b/lib/Dialect/Numpy/NumpyDialect.cpp @@ -7,10 +7,11 @@ //===----------------------------------------------------------------------===// #include "npcomp/Dialect/Numpy/NumpyDialect.h" +#include "mlir/IR/DialectImplementation.h" #include "npcomp/Dialect/Numpy/NumpyOps.h" using namespace mlir; -using namespace mlir::NPCOMP::numpy; +using namespace mlir::NPCOMP::Numpy; NumpyDialect::NumpyDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) { @@ -18,4 +19,27 @@ NumpyDialect::NumpyDialect(MLIRContext *context) #define GET_OP_LIST #include "npcomp/Dialect/Numpy/NumpyOps.cpp.inc" >(); + addTypes(); +} + +Type NumpyDialect::parseType(DialectAsmParser &parser) const { + StringRef keyword; + if (parser.parseKeyword(&keyword)) + return Type(); + + if (keyword == "any_dtype") + return AnyDtypeType::get(getContext()); + + parser.emitError(parser.getNameLoc(), "unknown numpy type: ") << keyword; + return Type(); +} + +void NumpyDialect::printType(Type type, DialectAsmPrinter &os) const { + switch (type.getKind()) { + case NumpyTypes::AnyDtypeType: + os << "any_dtype"; + return; + default: + llvm_unreachable("unexpected 'numpy' type kind"); + } } diff --git a/lib/Dialect/Numpy/NumpyOps.cpp b/lib/Dialect/Numpy/NumpyOps.cpp index ccac24fa9..770338ffc 100644 --- a/lib/Dialect/Numpy/NumpyOps.cpp +++ b/lib/Dialect/Numpy/NumpyOps.cpp @@ -14,7 +14,7 @@ namespace mlir { namespace NPCOMP { -namespace numpy { +namespace Numpy { //===----------------------------------------------------------------------===// // BuildinUfuncOp @@ -147,6 +147,6 @@ static void printGenericUfuncOp(OpAsmPrinter &p, GenericUfuncOp op) { #define GET_OP_CLASSES #include "npcomp/Dialect/Numpy/NumpyOps.cpp.inc" -} // namespace numpy +} // namespace Numpy } // namespace NPCOMP } // namespace mlir diff --git a/test/Dialect/Numpy/ops.mlir b/test/Dialect/Numpy/ops.mlir index 20d4cb0dc..2ac649870 100644 --- a/test/Dialect/Numpy/ops.mlir +++ b/test/Dialect/Numpy/ops.mlir @@ -1,4 +1,10 @@ // RUN: npcomp-opt -split-input-file %s | npcomp-opt | FileCheck --dump-input=fail %s + +// CHECK-LABEL: @any_dtype +func @any_dtype(%arg0: tensor<*x!numpy.any_dtype>) -> (tensor<*x!numpy.any_dtype>) { + return %arg0 : tensor<*x!numpy.any_dtype> +} + // ----- // CHECK-LABEL: @builtin_ufunc module @builtin_ufunc { diff --git a/tools/npcomp-opt/npcomp-opt.cpp b/tools/npcomp-opt/npcomp-opt.cpp index 9f6fc667f..3279573c7 100644 --- a/tools/npcomp-opt/npcomp-opt.cpp +++ b/tools/npcomp-opt/npcomp-opt.cpp @@ -60,7 +60,7 @@ int main(int argc, char **argv) { mlir::registerAllDialects(); mlir::registerAllPasses(); - mlir::registerDialect(); + mlir::registerDialect(); // TODO: Register standalone passes here. llvm::InitLLVM y(argc, argv);