mirror of https://github.com/llvm/torch-mlir
Add !numpy.any_dtype dialect type.
parent
b4425fe1d2
commit
d3632af675
|
@ -13,11 +13,32 @@
|
|||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
namespace numpy {
|
||||
namespace Numpy {
|
||||
|
||||
namespace NumpyTypes {
|
||||
enum Kind {
|
||||
AnyDtypeType = Type::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE,
|
||||
LAST_NUMPY_TYPE = AnyDtypeType
|
||||
};
|
||||
} // namespace NumpyTypes
|
||||
|
||||
// The singleton type representing an unknown dtype.
|
||||
class AnyDtypeType : public Type::TypeBase<AnyDtypeType, Type> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
static AnyDtypeType get(MLIRContext *context) {
|
||||
return Base::get(context, NumpyTypes::Kind::AnyDtypeType);
|
||||
}
|
||||
|
||||
static bool kindof(unsigned kind) {
|
||||
return kind == NumpyTypes::Kind::AnyDtypeType;
|
||||
}
|
||||
};
|
||||
|
||||
#include "npcomp/Dialect/Numpy/NumpyOpsDialect.h.inc"
|
||||
|
||||
} // namespace numpy
|
||||
} // namespace Numpy
|
||||
} // namespace NPCOMP
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ def Numpy_Dialect : Dialect {
|
|||
let description = [{
|
||||
Dialect of types and core numpy ops and abstractions.
|
||||
}];
|
||||
let cppNamespace = "numpy";
|
||||
let cppNamespace = "Numpy";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -34,6 +34,18 @@ class Numpy_Op<string mnemonic, list<OpTrait> traits = []> :
|
|||
let printer = [{ return print$cppClass(p, *this); }];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dialect types
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Numpy_AnyDtype : DialectType<Numpy_Dialect,
|
||||
CPred<"$_self.isa<::mlir::NPCOMP::Numpy::AnyDtypeType>()">, "any dtype">,
|
||||
BuildableType<"$_builder.getType::mlir::NPCOMP::Numpy::AnyDtypeType()"> {
|
||||
let typeDescription = [{
|
||||
Placeholder for an unknown dtype in a tensor.
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type predicates
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -19,12 +19,12 @@
|
|||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
namespace numpy {
|
||||
namespace Numpy {
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "npcomp/Dialect/Numpy/NumpyOps.h.inc"
|
||||
|
||||
} // namespace numpy
|
||||
} // namespace Numpy
|
||||
} // namespace NPCOMP
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ def Numpy_GenericUfuncOp : Numpy_Op<"generic_ufunc", [
|
|||
|
||||
def Numpy_UfuncReturnOp : Numpy_Op<"ufunc_return", [
|
||||
Terminator,
|
||||
HasParent<"numpy::GenericUfuncOp">]> {
|
||||
HasParent<"Numpy::GenericUfuncOp">]> {
|
||||
let summary = "Return a value from a generic_ufunc";
|
||||
let description = [{
|
||||
Must terminate the basic block of a generic_ufunc overload.
|
||||
|
|
|
@ -7,10 +7,11 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/Dialect/Numpy/NumpyDialect.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "npcomp/Dialect/Numpy/NumpyOps.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP::numpy;
|
||||
using namespace mlir::NPCOMP::Numpy;
|
||||
|
||||
NumpyDialect::NumpyDialect(MLIRContext *context)
|
||||
: Dialect(getDialectNamespace(), context) {
|
||||
|
@ -18,4 +19,27 @@ NumpyDialect::NumpyDialect(MLIRContext *context)
|
|||
#define GET_OP_LIST
|
||||
#include "npcomp/Dialect/Numpy/NumpyOps.cpp.inc"
|
||||
>();
|
||||
addTypes<AnyDtypeType>();
|
||||
}
|
||||
|
||||
Type NumpyDialect::parseType(DialectAsmParser &parser) const {
|
||||
StringRef keyword;
|
||||
if (parser.parseKeyword(&keyword))
|
||||
return Type();
|
||||
|
||||
if (keyword == "any_dtype")
|
||||
return AnyDtypeType::get(getContext());
|
||||
|
||||
parser.emitError(parser.getNameLoc(), "unknown numpy type: ") << keyword;
|
||||
return Type();
|
||||
}
|
||||
|
||||
void NumpyDialect::printType(Type type, DialectAsmPrinter &os) const {
|
||||
switch (type.getKind()) {
|
||||
case NumpyTypes::AnyDtypeType:
|
||||
os << "any_dtype";
|
||||
return;
|
||||
default:
|
||||
llvm_unreachable("unexpected 'numpy' type kind");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
namespace numpy {
|
||||
namespace Numpy {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BuildinUfuncOp
|
||||
|
@ -147,6 +147,6 @@ static void printGenericUfuncOp(OpAsmPrinter &p, GenericUfuncOp op) {
|
|||
|
||||
#define GET_OP_CLASSES
|
||||
#include "npcomp/Dialect/Numpy/NumpyOps.cpp.inc"
|
||||
} // namespace numpy
|
||||
} // namespace Numpy
|
||||
} // namespace NPCOMP
|
||||
} // namespace mlir
|
||||
|
|
|
@ -1,4 +1,10 @@
|
|||
// RUN: npcomp-opt -split-input-file %s | npcomp-opt | FileCheck --dump-input=fail %s
|
||||
|
||||
// CHECK-LABEL: @any_dtype
|
||||
func @any_dtype(%arg0: tensor<*x!numpy.any_dtype>) -> (tensor<*x!numpy.any_dtype>) {
|
||||
return %arg0 : tensor<*x!numpy.any_dtype>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: @builtin_ufunc
|
||||
module @builtin_ufunc {
|
||||
|
|
|
@ -60,7 +60,7 @@ int main(int argc, char **argv) {
|
|||
mlir::registerAllDialects();
|
||||
mlir::registerAllPasses();
|
||||
|
||||
mlir::registerDialect<mlir::NPCOMP::numpy::NumpyDialect>();
|
||||
mlir::registerDialect<mlir::NPCOMP::Numpy::NumpyDialect>();
|
||||
// TODO: Register standalone passes here.
|
||||
|
||||
llvm::InitLLVM y(argc, argv);
|
||||
|
|
Loading…
Reference in New Issue