Add !numpy.any_dtype dialect type.

pull/1/head
Stella Laurenzo 2020-04-29 18:20:42 -07:00
parent b4425fe1d2
commit d3632af675
8 changed files with 73 additions and 10 deletions

View File

@ -13,11 +13,32 @@
namespace mlir { namespace mlir {
namespace NPCOMP { namespace NPCOMP {
namespace numpy { namespace Numpy {
namespace NumpyTypes {
enum Kind {
AnyDtypeType = Type::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE,
LAST_NUMPY_TYPE = AnyDtypeType
};
} // namespace NumpyTypes
// The singleton type representing an unknown dtype.
class AnyDtypeType : public Type::TypeBase<AnyDtypeType, Type> {
public:
using Base::Base;
static AnyDtypeType get(MLIRContext *context) {
return Base::get(context, NumpyTypes::Kind::AnyDtypeType);
}
static bool kindof(unsigned kind) {
return kind == NumpyTypes::Kind::AnyDtypeType;
}
};
#include "npcomp/Dialect/Numpy/NumpyOpsDialect.h.inc" #include "npcomp/Dialect/Numpy/NumpyOpsDialect.h.inc"
} // namespace numpy } // namespace Numpy
} // namespace NPCOMP } // namespace NPCOMP
} // namespace mlir } // namespace mlir

View File

@ -21,7 +21,7 @@ def Numpy_Dialect : Dialect {
let description = [{ let description = [{
Dialect of types and core numpy ops and abstractions. Dialect of types and core numpy ops and abstractions.
}]; }];
let cppNamespace = "numpy"; let cppNamespace = "Numpy";
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -34,6 +34,18 @@ class Numpy_Op<string mnemonic, list<OpTrait> traits = []> :
let printer = [{ return print$cppClass(p, *this); }]; let printer = [{ return print$cppClass(p, *this); }];
} }
//===----------------------------------------------------------------------===//
// Dialect types
//===----------------------------------------------------------------------===//
def Numpy_AnyDtype : DialectType<Numpy_Dialect,
CPred<"$_self.isa<::mlir::NPCOMP::Numpy::AnyDtypeType>()">, "any dtype">,
BuildableType<"$_builder.getType::mlir::NPCOMP::Numpy::AnyDtypeType()"> {
let typeDescription = [{
Placeholder for an unknown dtype in a tensor.
}];
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Type predicates // Type predicates
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -19,12 +19,12 @@
namespace mlir { namespace mlir {
namespace NPCOMP { namespace NPCOMP {
namespace numpy { namespace Numpy {
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "npcomp/Dialect/Numpy/NumpyOps.h.inc" #include "npcomp/Dialect/Numpy/NumpyOps.h.inc"
} // namespace numpy } // namespace Numpy
} // namespace NPCOMP } // namespace NPCOMP
} // namespace mlir } // namespace mlir

View File

@ -37,7 +37,7 @@ def Numpy_GenericUfuncOp : Numpy_Op<"generic_ufunc", [
def Numpy_UfuncReturnOp : Numpy_Op<"ufunc_return", [ def Numpy_UfuncReturnOp : Numpy_Op<"ufunc_return", [
Terminator, Terminator,
HasParent<"numpy::GenericUfuncOp">]> { HasParent<"Numpy::GenericUfuncOp">]> {
let summary = "Return a value from a generic_ufunc"; let summary = "Return a value from a generic_ufunc";
let description = [{ let description = [{
Must terminate the basic block of a generic_ufunc overload. Must terminate the basic block of a generic_ufunc overload.

View File

@ -7,10 +7,11 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "npcomp/Dialect/Numpy/NumpyDialect.h" #include "npcomp/Dialect/Numpy/NumpyDialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "npcomp/Dialect/Numpy/NumpyOps.h" #include "npcomp/Dialect/Numpy/NumpyOps.h"
using namespace mlir; using namespace mlir;
using namespace mlir::NPCOMP::numpy; using namespace mlir::NPCOMP::Numpy;
NumpyDialect::NumpyDialect(MLIRContext *context) NumpyDialect::NumpyDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) { : Dialect(getDialectNamespace(), context) {
@ -18,4 +19,27 @@ NumpyDialect::NumpyDialect(MLIRContext *context)
#define GET_OP_LIST #define GET_OP_LIST
#include "npcomp/Dialect/Numpy/NumpyOps.cpp.inc" #include "npcomp/Dialect/Numpy/NumpyOps.cpp.inc"
>(); >();
addTypes<AnyDtypeType>();
}
Type NumpyDialect::parseType(DialectAsmParser &parser) const {
StringRef keyword;
if (parser.parseKeyword(&keyword))
return Type();
if (keyword == "any_dtype")
return AnyDtypeType::get(getContext());
parser.emitError(parser.getNameLoc(), "unknown numpy type: ") << keyword;
return Type();
}
void NumpyDialect::printType(Type type, DialectAsmPrinter &os) const {
switch (type.getKind()) {
case NumpyTypes::AnyDtypeType:
os << "any_dtype";
return;
default:
llvm_unreachable("unexpected 'numpy' type kind");
}
} }

View File

@ -14,7 +14,7 @@
namespace mlir { namespace mlir {
namespace NPCOMP { namespace NPCOMP {
namespace numpy { namespace Numpy {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// BuildinUfuncOp // BuildinUfuncOp
@ -147,6 +147,6 @@ static void printGenericUfuncOp(OpAsmPrinter &p, GenericUfuncOp op) {
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "npcomp/Dialect/Numpy/NumpyOps.cpp.inc" #include "npcomp/Dialect/Numpy/NumpyOps.cpp.inc"
} // namespace numpy } // namespace Numpy
} // namespace NPCOMP } // namespace NPCOMP
} // namespace mlir } // namespace mlir

View File

@ -1,4 +1,10 @@
// RUN: npcomp-opt -split-input-file %s | npcomp-opt | FileCheck --dump-input=fail %s // RUN: npcomp-opt -split-input-file %s | npcomp-opt | FileCheck --dump-input=fail %s
// CHECK-LABEL: @any_dtype
func @any_dtype(%arg0: tensor<*x!numpy.any_dtype>) -> (tensor<*x!numpy.any_dtype>) {
return %arg0 : tensor<*x!numpy.any_dtype>
}
// ----- // -----
// CHECK-LABEL: @builtin_ufunc // CHECK-LABEL: @builtin_ufunc
module @builtin_ufunc { module @builtin_ufunc {

View File

@ -60,7 +60,7 @@ int main(int argc, char **argv) {
mlir::registerAllDialects(); mlir::registerAllDialects();
mlir::registerAllPasses(); mlir::registerAllPasses();
mlir::registerDialect<mlir::NPCOMP::numpy::NumpyDialect>(); mlir::registerDialect<mlir::NPCOMP::Numpy::NumpyDialect>();
// TODO: Register standalone passes here. // TODO: Register standalone passes here.
llvm::InitLLVM y(argc, argv); llvm::InitLLVM y(argc, argv);