Allow the ndarray type to carry a shape.

pull/1/head
Stella Laurenzo 2020-07-05 17:34:03 -07:00
parent dc271dfb87
commit fae15ec5e7
6 changed files with 137 additions and 24 deletions

View File

@ -49,10 +49,21 @@ class NdArrayType
public:
using Base::Base;
static bool kindof(unsigned kind) { return kind == NumpyTypes::NdArray; }
static NdArrayType get(Type optionalDtype);
/// Constructs an NdArray with a dtype and no shape. Setting the dtype
/// to !basicpy.UnknownType will print as ?.
static NdArrayType get(Type dtype,
llvm::Optional<ArrayRef<int64_t>> shape = llvm::None);
/// Returns whether the dtype is a concrete type (versus
/// !basicpy.UnknownType).
bool hasKnownDtype();
Type getDtype();
/// If the shape has been partially specified, this will have a value.
/// unknown dimensions are -1.
llvm::Optional<ArrayRef<int64_t>> getOptionalShape();
// CPA::TypeMapInterface methods.
Typing::CPA::TypeNode *mapToCPAType(Typing::CPA::Context &context);
};

View File

@ -35,19 +35,71 @@ Type NumpyDialect::parseType(DialectAsmParser &parser) const {
return AnyDtypeType::get(getContext());
if (keyword == "ndarray") {
// Parse:
// ndarray<?>
// ndarray<i32>
// ndarray<*:?>
// ndarray<*:i32>
// ndarary<[1,2,3]:i32>
// Note that this is a different syntax than the built-ins as the dialect
// parser is not general enough to parse a dimension list with an optional
// element type (?). The built-in form is also remarkably ambiguous when
// considering extending it.
Type dtype = Basicpy::UnknownType::get(getContext());
bool hasShape = false;
llvm::SmallVector<int64_t, 4> shape;
if (parser.parseLess())
return Type();
if (succeeded(parser.parseOptionalStar())) {
// Unranked.
} else {
// Parse dimension list.
hasShape = true;
if (parser.parseLSquare())
return Type();
for (bool first = true;; first = false) {
if (!first) {
if (failed(parser.parseOptionalComma())) {
break;
}
}
if (succeeded(parser.parseOptionalQuestion())) {
shape.push_back(-1);
continue;
}
int64_t dim;
auto optionalPr = parser.parseOptionalInteger(dim);
if (optionalPr.hasValue()) {
if (failed(*optionalPr))
return Type();
shape.push_back(dim);
continue;
}
break;
}
if (parser.parseRSquare()) {
return Type();
}
}
// Parse colon dtype.
if (parser.parseColon()) {
return Type();
}
if (failed(parser.parseOptionalQuestion())) {
// Specified dtype.
if (parser.parseType(dtype))
if (parser.parseType(dtype)) {
return Type();
}
}
if (parser.parseGreater())
if (parser.parseGreater()) {
return Type();
return NdArrayType::get(dtype);
}
llvm::Optional<ArrayRef<int64_t>> optionalShape;
if (hasShape)
optionalShape = shape;
auto ndarray = NdArrayType::get(dtype, optionalShape);
return ndarray;
}
parser.emitError(parser.getNameLoc(), "unknown numpy type: ") << keyword;
@ -62,8 +114,23 @@ void NumpyDialect::printType(Type type, DialectAsmPrinter &os) const {
case NumpyTypes::NdArray: {
auto unknownType = Basicpy::UnknownType::get(getContext());
auto ndarray = type.cast<NdArrayType>();
auto shape = ndarray.getOptionalShape();
auto dtype = ndarray.getDtype();
os << "ndarray<";
if (!shape) {
os << "*:";
} else {
os << "[";
for (auto it : llvm::enumerate(*shape)) {
if (it.index() > 0)
os << ",";
if (it.value() < 0)
os << "?";
else
os << it.value();
}
os << "]:";
}
if (dtype != unknownType)
os.printType(dtype);
else
@ -85,19 +152,41 @@ namespace Numpy {
namespace detail {
struct NdArrayTypeStorage : public TypeStorage {
using KeyTy = Type;
NdArrayTypeStorage(Type optionalDtype) : optionalDtype(optionalDtype) {}
bool operator==(const KeyTy &other) const { return optionalDtype == other; }
using KeyTy = std::pair<Type, llvm::Optional<ArrayRef<int64_t>>>;
NdArrayTypeStorage(Type dtype, int rank, const int64_t *shapeElements)
: dtype(dtype), rank(rank), shapeElements(shapeElements) {}
bool operator==(const KeyTy &key) const {
return key == KeyTy(dtype, getOptionalShape());
}
static llvm::hash_code hashKey(const KeyTy &key) {
return llvm::hash_combine(key);
if (key.second) {
return llvm::hash_combine(key.first, *key.second);
} else {
return llvm::hash_combine(key.first, -1);
}
}
static NdArrayTypeStorage *construct(TypeStorageAllocator &allocator,
const KeyTy &key) {
int rank = -1;
const int64_t *shapeElements = nullptr;
if (key.second.hasValue()) {
auto allocElements = allocator.copyInto(*key.second);
rank = key.second->size();
shapeElements = allocElements.data();
}
return new (allocator.allocate<NdArrayTypeStorage>())
NdArrayTypeStorage(key);
NdArrayTypeStorage(key.first, rank, shapeElements);
}
Type optionalDtype;
llvm::Optional<ArrayRef<int64_t>> getOptionalShape() const {
if (rank < 0)
return llvm::None;
return ArrayRef<int64_t>(shapeElements, rank);
}
Type dtype;
int rank;
const int64_t *shapeElements;
};
} // namespace detail
@ -105,16 +194,21 @@ struct NdArrayTypeStorage : public TypeStorage {
} // namespace NPCOMP
} // namespace mlir
NdArrayType NdArrayType::get(Type dtype) {
NdArrayType NdArrayType::get(Type dtype,
llvm::Optional<ArrayRef<int64_t>> shape) {
assert(dtype && "dtype cannot be null");
return Base::get(dtype.getContext(), NumpyTypes::NdArray, dtype);
return Base::get(dtype.getContext(), NumpyTypes::NdArray, dtype, shape);
}
bool NdArrayType::hasKnownDtype() {
return getDtype() != Basicpy::UnknownType::get(getContext());
}
Type NdArrayType::getDtype() { return getImpl()->optionalDtype; }
Type NdArrayType::getDtype() { return getImpl()->dtype; }
llvm::Optional<ArrayRef<int64_t>> NdArrayType::getOptionalShape() {
return getImpl()->getOptionalShape();
}
Typing::CPA::TypeNode *
NdArrayType::mapToCPAType(Typing::CPA::Context &context) {

View File

@ -13,6 +13,6 @@ global_data = (np.zeros((2, 3)) + [1.0, 2.0, 3.0] * np.reshape([1.0, 2.0],
@import_global
def global_array_to_const():
# CHECK: %[[CST:.*]] = constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [2.000000e+00, 4.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
# CHECK: numpy.create_array_from_tensor %[[CST]] : (tensor<2x3xf64>) -> !numpy.ndarray<f64>
# CHECK: numpy.create_array_from_tensor %[[CST]] : (tensor<2x3xf64>) -> !numpy.ndarray<*:f64>
local_data = global_data
return local_data

View File

@ -16,7 +16,7 @@ b = np.asarray([3.0, 4.0])
# Test the basic flow of invoking a ufunc call with constants captured from
# a global using explicit function syntax (np.add(a, b)).
# CHECK-LABEL: func @global_add
# CHECK-SAME: -> !numpy.ndarray<f64>
# CHECK-SAME: -> !numpy.ndarray<*:f64>
@import_global
def global_add():
# CHECK-NOT: UnknownType

View File

@ -25,7 +25,7 @@ def global_add():
# CHECK-DAG: %[[A:.*]] = numpy.copy_to_tensor %[[A_ARRAY]]
# CHECK-DAG: %[[B:.*]] = numpy.copy_to_tensor %[[B_ARRAY]]
# CHECK: %[[R_TENSOR:.*]] = numpy.builtin_ufunc_call<"numpy.add"> (%[[A]], %[[B]]) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*x!basicpy.UnknownType>
# CHECK: numpy.create_array_from_tensor %[[R_TENSOR]] : (tensor<*x!basicpy.UnknownType>) -> !numpy.ndarray<?>
# CHECK: numpy.create_array_from_tensor %[[R_TENSOR]] : (tensor<*x!basicpy.UnknownType>) -> !numpy.ndarray<*:?>
return np.add(a, b)

View File

@ -1,13 +1,21 @@
// RUN: npcomp-opt -split-input-file %s | npcomp-opt | FileCheck --dump-input=fail %s
// CHECK-LABEL: @ndarray_no_dtype
// CHECK: !numpy.ndarray<?>
func @ndarray_no_dtype(%arg0 : !numpy.ndarray<?>) -> !numpy.ndarray<?> {
return %arg0 : !numpy.ndarray<?>
// CHECK: !numpy.ndarray<*:?>
func @ndarray_no_dtype(%arg0 : !numpy.ndarray<*:?>) -> !numpy.ndarray<*:?> {
return %arg0 : !numpy.ndarray<*:?>
}
// -----
// CHECK-LABEL: @ndarray_dtype
// CHECK: !numpy.ndarray<i32>
func @ndarray_dtype(%arg0 : !numpy.ndarray<i32>) -> !numpy.ndarray<i32> {
return %arg0 : !numpy.ndarray<i32>
// CHECK: !numpy.ndarray<*:i32>
func @ndarray_dtype(%arg0 : !numpy.ndarray<*:i32>) -> !numpy.ndarray<*:i32> {
return %arg0 : !numpy.ndarray<*:i32>
}
// -----
// CHECK-LABEL: @ndarray_ranked
// CHECK: !numpy.ndarray<[1,?,3]:i32>
func @ndarray_ranked(%arg0 : !numpy.ndarray<[1,?,3]:i32>) -> !numpy.ndarray<[1,?,3]:i32> {
return %arg0 : !numpy.ndarray<[1,?,3]:i32>
}