Add NdArrayType.

pull/1/head
Stella Laurenzo 2020-06-28 17:37:20 -07:00
parent bccfd5f6fc
commit efe8915901
3 changed files with 121 additions and 5 deletions

View File

@ -17,10 +17,18 @@ namespace NPCOMP {
namespace Numpy {
namespace NumpyTypes {
enum Kind { AnyDtypeType = TypeRanges::Numpy, LAST_NUMPY_TYPE = AnyDtypeType };
enum Kind {
AnyDtypeType = TypeRanges::Numpy,
NdArray,
LAST_NUMPY_TYPE = AnyDtypeType,
};
} // namespace NumpyTypes
// The singleton type representing an unknown dtype.
namespace detail {
struct NdArrayTypeStorage;
} // namespace detail
/// The singleton type representing an unknown dtype.
class AnyDtypeType : public Type::TypeBase<AnyDtypeType, Type> {
public:
using Base::Base;
@ -34,6 +42,15 @@ public:
}
};
class NdArrayType
: public Type::TypeBase<NdArrayType, Type, detail::NdArrayTypeStorage> {
public:
using Base::Base;
static bool kindof(unsigned kind) { return kind == NumpyTypes::NdArray; }
static NdArrayType get(Type optionalDtype, MLIRContext *context);
Type getOptionalDtype();
};
#include "npcomp/Dialect/Numpy/IR/NumpyOpsDialect.h.inc"
} // namespace Numpy

View File

@ -32,7 +32,7 @@ def Numpy_Dialect : Dialect {
class Numpy_Op<string mnemonic, list<OpTrait> traits = []> :
Op<Numpy_Dialect, mnemonic, traits> {
let parser = [{ return parse$cppClass(parser, &result); }];
let printer = [{ return print$cppClass(p, *this); }];
let printer = [{ return print$cppClass(p, *this); }];
}
//===----------------------------------------------------------------------===//
@ -47,6 +47,43 @@ def Numpy_AnyDtype : DialectType<Numpy_Dialect,
}];
}
def Numpy_NdArrayType : DialectType<Numpy_Dialect,
CPred<"$_self.isa<::mlir::NPCOMP::Numpy::NdArrayType>()">, "ndarray type">,
BuildableType<"$_builder.getType<::mlir::NPCOMP::Numpy::NdArrayType>()"> {
let typeDescription = [{
NdArrayType: Models a numpy.ndarray and compatible types.
Unlike lower level representations, this type solely exists to represent
top-level semantics and source-dialect transformations. As such, it
is not a general modeling like `tensor` or `memref`, instead being just
enough to infer proper lowerings to those types.
Like its numpy counterparts, NdArrayType represents a mutable array of
some value type (dtype), with a shape, strides, and various controls
around contiguity. Most of that is not modeled in this type, which focuses
on a representation sufficient to infer high level types and aliasing
based on program flow.
Note that most operations in numpy can be legally defined similar to the
following:
%0 = ... -> !numpy.ndarray<...>
%1 = numpy.copy_to_tensor %0 -> tensor<...>
%2 = numpy.some_operation %1
%4 = numpy.copy_from_tensor -> !numpy.ndarray<...>
(in other words, the operation does not alias any of its operands to its
results)
When this is the case, the operation will *only* be defined for tensors,
as staying in the value domain makes sense for as many operations as
can be reasonably represented as such. It is left to subsequent parts of
the compiler to transform the program in such a way as to elide the copies
that such sequences encode.
Only ops that mutate or alias their operands should accept and/or produce
ndarray types.
}];
}
//===----------------------------------------------------------------------===//
// Type predicates
//===----------------------------------------------------------------------===//
@ -57,7 +94,7 @@ def Numpy_AnyArray : TensorOf<[AnyType]>;
def Numpy_SliceTupleElement : AnyTypeOf<[
// Supports both "Index Arrays" and "Boolean mask index arrays".
Numpy_AnyArray,
// Indicates that an axis should be added (np.newaxis == None).
Basicpy_NoneType,

View File

@ -19,7 +19,7 @@ NumpyDialect::NumpyDialect(MLIRContext *context)
#define GET_OP_LIST
#include "npcomp/Dialect/Numpy/IR/NumpyOps.cpp.inc"
>();
addTypes<AnyDtypeType>();
addTypes<AnyDtypeType, NdArrayType>();
}
Type NumpyDialect::parseType(DialectAsmParser &parser) const {
@ -29,6 +29,22 @@ Type NumpyDialect::parseType(DialectAsmParser &parser) const {
if (keyword == "any_dtype")
return AnyDtypeType::get(getContext());
if (keyword == "ndarray") {
// Parse:
// ndarray<?>
// ndarray<i32>
Type dtype;
if (parser.parseLess())
return Type();
if (failed(parser.parseOptionalQuestion())) {
// Specified dtype.
if (parser.parseType(dtype))
return Type();
}
if (parser.parseGreater())
return Type();
return NdArrayType::get(dtype, getContext());
}
parser.emitError(parser.getNameLoc(), "unknown numpy type: ") << keyword;
return Type();
@ -39,7 +55,53 @@ void NumpyDialect::printType(Type type, DialectAsmPrinter &os) const {
case NumpyTypes::AnyDtypeType:
os << "any_dtype";
return;
case NumpyTypes::NdArray: {
auto ndarray = type.cast<NdArrayType>();
auto dtype = ndarray.getOptionalDtype();
os << "ndarray<";
if (dtype)
os.printType(dtype);
else
os << "?";
os << ">";
return;
}
default:
llvm_unreachable("unexpected 'numpy' type kind");
}
}
//----------------------------------------------------------------------------//
// Type and attribute detail
//----------------------------------------------------------------------------//
namespace mlir {
namespace NPCOMP {
namespace Numpy {
namespace detail {
struct NdArrayTypeStorage : public TypeStorage {
using KeyTy = Type;
NdArrayTypeStorage(Type optionalDtype) : optionalDtype(optionalDtype) {}
bool operator==(const KeyTy &other) const { return optionalDtype == other; }
static llvm::hash_code hashKey(const KeyTy &key) {
return llvm::hash_combine(key);
}
static NdArrayTypeStorage *construct(TypeStorageAllocator &allocator,
const KeyTy &key) {
return new (allocator.allocate<NdArrayTypeStorage>())
NdArrayTypeStorage(key);
}
Type optionalDtype;
};
} // namespace detail
} // namespace Numpy
} // namespace NPCOMP
} // namespace mlir
NdArrayType NdArrayType::get(Type optionalDtype, MLIRContext *context) {
return Base::get(context, NumpyTypes::NdArray, optionalDtype);
}
Type NdArrayType::getOptionalDtype() { return getImpl()->optionalDtype; }