mirror of https://github.com/llvm/torch-mlir
Add !numpy.any_dtype dialect type.
parent
b4425fe1d2
commit
d3632af675
|
@ -13,11 +13,32 @@
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace NPCOMP {
|
namespace NPCOMP {
|
||||||
namespace numpy {
|
namespace Numpy {
|
||||||
|
|
||||||
|
namespace NumpyTypes {
|
||||||
|
enum Kind {
|
||||||
|
AnyDtypeType = Type::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE,
|
||||||
|
LAST_NUMPY_TYPE = AnyDtypeType
|
||||||
|
};
|
||||||
|
} // namespace NumpyTypes
|
||||||
|
|
||||||
|
// The singleton type representing an unknown dtype.
|
||||||
|
class AnyDtypeType : public Type::TypeBase<AnyDtypeType, Type> {
|
||||||
|
public:
|
||||||
|
using Base::Base;
|
||||||
|
|
||||||
|
static AnyDtypeType get(MLIRContext *context) {
|
||||||
|
return Base::get(context, NumpyTypes::Kind::AnyDtypeType);
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool kindof(unsigned kind) {
|
||||||
|
return kind == NumpyTypes::Kind::AnyDtypeType;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
#include "npcomp/Dialect/Numpy/NumpyOpsDialect.h.inc"
|
#include "npcomp/Dialect/Numpy/NumpyOpsDialect.h.inc"
|
||||||
|
|
||||||
} // namespace numpy
|
} // namespace Numpy
|
||||||
} // namespace NPCOMP
|
} // namespace NPCOMP
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,7 @@ def Numpy_Dialect : Dialect {
|
||||||
let description = [{
|
let description = [{
|
||||||
Dialect of types and core numpy ops and abstractions.
|
Dialect of types and core numpy ops and abstractions.
|
||||||
}];
|
}];
|
||||||
let cppNamespace = "numpy";
|
let cppNamespace = "Numpy";
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -34,6 +34,18 @@ class Numpy_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||||
let printer = [{ return print$cppClass(p, *this); }];
|
let printer = [{ return print$cppClass(p, *this); }];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Dialect types
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def Numpy_AnyDtype : DialectType<Numpy_Dialect,
|
||||||
|
CPred<"$_self.isa<::mlir::NPCOMP::Numpy::AnyDtypeType>()">, "any dtype">,
|
||||||
|
BuildableType<"$_builder.getType::mlir::NPCOMP::Numpy::AnyDtypeType()"> {
|
||||||
|
let typeDescription = [{
|
||||||
|
Placeholder for an unknown dtype in a tensor.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Type predicates
|
// Type predicates
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -19,12 +19,12 @@
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace NPCOMP {
|
namespace NPCOMP {
|
||||||
namespace numpy {
|
namespace Numpy {
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "npcomp/Dialect/Numpy/NumpyOps.h.inc"
|
#include "npcomp/Dialect/Numpy/NumpyOps.h.inc"
|
||||||
|
|
||||||
} // namespace numpy
|
} // namespace Numpy
|
||||||
} // namespace NPCOMP
|
} // namespace NPCOMP
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
|
|
@ -37,7 +37,7 @@ def Numpy_GenericUfuncOp : Numpy_Op<"generic_ufunc", [
|
||||||
|
|
||||||
def Numpy_UfuncReturnOp : Numpy_Op<"ufunc_return", [
|
def Numpy_UfuncReturnOp : Numpy_Op<"ufunc_return", [
|
||||||
Terminator,
|
Terminator,
|
||||||
HasParent<"numpy::GenericUfuncOp">]> {
|
HasParent<"Numpy::GenericUfuncOp">]> {
|
||||||
let summary = "Return a value from a generic_ufunc";
|
let summary = "Return a value from a generic_ufunc";
|
||||||
let description = [{
|
let description = [{
|
||||||
Must terminate the basic block of a generic_ufunc overload.
|
Must terminate the basic block of a generic_ufunc overload.
|
||||||
|
|
|
@ -7,10 +7,11 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "npcomp/Dialect/Numpy/NumpyDialect.h"
|
#include "npcomp/Dialect/Numpy/NumpyDialect.h"
|
||||||
|
#include "mlir/IR/DialectImplementation.h"
|
||||||
#include "npcomp/Dialect/Numpy/NumpyOps.h"
|
#include "npcomp/Dialect/Numpy/NumpyOps.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP::numpy;
|
using namespace mlir::NPCOMP::Numpy;
|
||||||
|
|
||||||
NumpyDialect::NumpyDialect(MLIRContext *context)
|
NumpyDialect::NumpyDialect(MLIRContext *context)
|
||||||
: Dialect(getDialectNamespace(), context) {
|
: Dialect(getDialectNamespace(), context) {
|
||||||
|
@ -18,4 +19,27 @@ NumpyDialect::NumpyDialect(MLIRContext *context)
|
||||||
#define GET_OP_LIST
|
#define GET_OP_LIST
|
||||||
#include "npcomp/Dialect/Numpy/NumpyOps.cpp.inc"
|
#include "npcomp/Dialect/Numpy/NumpyOps.cpp.inc"
|
||||||
>();
|
>();
|
||||||
|
addTypes<AnyDtypeType>();
|
||||||
|
}
|
||||||
|
|
||||||
|
Type NumpyDialect::parseType(DialectAsmParser &parser) const {
|
||||||
|
StringRef keyword;
|
||||||
|
if (parser.parseKeyword(&keyword))
|
||||||
|
return Type();
|
||||||
|
|
||||||
|
if (keyword == "any_dtype")
|
||||||
|
return AnyDtypeType::get(getContext());
|
||||||
|
|
||||||
|
parser.emitError(parser.getNameLoc(), "unknown numpy type: ") << keyword;
|
||||||
|
return Type();
|
||||||
|
}
|
||||||
|
|
||||||
|
void NumpyDialect::printType(Type type, DialectAsmPrinter &os) const {
|
||||||
|
switch (type.getKind()) {
|
||||||
|
case NumpyTypes::AnyDtypeType:
|
||||||
|
os << "any_dtype";
|
||||||
|
return;
|
||||||
|
default:
|
||||||
|
llvm_unreachable("unexpected 'numpy' type kind");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace NPCOMP {
|
namespace NPCOMP {
|
||||||
namespace numpy {
|
namespace Numpy {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// BuildinUfuncOp
|
// BuildinUfuncOp
|
||||||
|
@ -147,6 +147,6 @@ static void printGenericUfuncOp(OpAsmPrinter &p, GenericUfuncOp op) {
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "npcomp/Dialect/Numpy/NumpyOps.cpp.inc"
|
#include "npcomp/Dialect/Numpy/NumpyOps.cpp.inc"
|
||||||
} // namespace numpy
|
} // namespace Numpy
|
||||||
} // namespace NPCOMP
|
} // namespace NPCOMP
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -1,4 +1,10 @@
|
||||||
// RUN: npcomp-opt -split-input-file %s | npcomp-opt | FileCheck --dump-input=fail %s
|
// RUN: npcomp-opt -split-input-file %s | npcomp-opt | FileCheck --dump-input=fail %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: @any_dtype
|
||||||
|
func @any_dtype(%arg0: tensor<*x!numpy.any_dtype>) -> (tensor<*x!numpy.any_dtype>) {
|
||||||
|
return %arg0 : tensor<*x!numpy.any_dtype>
|
||||||
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
// CHECK-LABEL: @builtin_ufunc
|
// CHECK-LABEL: @builtin_ufunc
|
||||||
module @builtin_ufunc {
|
module @builtin_ufunc {
|
||||||
|
|
|
@ -60,7 +60,7 @@ int main(int argc, char **argv) {
|
||||||
mlir::registerAllDialects();
|
mlir::registerAllDialects();
|
||||||
mlir::registerAllPasses();
|
mlir::registerAllPasses();
|
||||||
|
|
||||||
mlir::registerDialect<mlir::NPCOMP::numpy::NumpyDialect>();
|
mlir::registerDialect<mlir::NPCOMP::Numpy::NumpyDialect>();
|
||||||
// TODO: Register standalone passes here.
|
// TODO: Register standalone passes here.
|
||||||
|
|
||||||
llvm::InitLLVM y(argc, argv);
|
llvm::InitLLVM y(argc, argv);
|
||||||
|
|
Loading…
Reference in New Issue