mirror of https://github.com/llvm/torch-mlir
Allow the ndarray type to carry a shape.
parent
dc271dfb87
commit
fae15ec5e7
|
@ -49,10 +49,21 @@ class NdArrayType
|
|||
public:
|
||||
using Base::Base;
|
||||
static bool kindof(unsigned kind) { return kind == NumpyTypes::NdArray; }
|
||||
static NdArrayType get(Type optionalDtype);
|
||||
|
||||
/// Constructs an NdArray with a dtype and no shape. Setting the dtype
|
||||
/// to !basicpy.UnknownType will print as ?.
|
||||
static NdArrayType get(Type dtype,
|
||||
llvm::Optional<ArrayRef<int64_t>> shape = llvm::None);
|
||||
|
||||
/// Returns whether the dtype is a concrete type (versus
|
||||
/// !basicpy.UnknownType).
|
||||
bool hasKnownDtype();
|
||||
Type getDtype();
|
||||
|
||||
/// If the shape has been partially specified, this will have a value.
|
||||
/// unknown dimensions are -1.
|
||||
llvm::Optional<ArrayRef<int64_t>> getOptionalShape();
|
||||
|
||||
// CPA::TypeMapInterface methods.
|
||||
Typing::CPA::TypeNode *mapToCPAType(Typing::CPA::Context &context);
|
||||
};
|
||||
|
|
|
@ -35,19 +35,71 @@ Type NumpyDialect::parseType(DialectAsmParser &parser) const {
|
|||
return AnyDtypeType::get(getContext());
|
||||
if (keyword == "ndarray") {
|
||||
// Parse:
|
||||
// ndarray<?>
|
||||
// ndarray<i32>
|
||||
// ndarray<*:?>
|
||||
// ndarray<*:i32>
|
||||
// ndarary<[1,2,3]:i32>
|
||||
// Note that this is a different syntax than the built-ins as the dialect
|
||||
// parser is not general enough to parse a dimension list with an optional
|
||||
// element type (?). The built-in form is also remarkably ambiguous when
|
||||
// considering extending it.
|
||||
Type dtype = Basicpy::UnknownType::get(getContext());
|
||||
bool hasShape = false;
|
||||
llvm::SmallVector<int64_t, 4> shape;
|
||||
if (parser.parseLess())
|
||||
return Type();
|
||||
if (succeeded(parser.parseOptionalStar())) {
|
||||
// Unranked.
|
||||
} else {
|
||||
// Parse dimension list.
|
||||
hasShape = true;
|
||||
if (parser.parseLSquare())
|
||||
return Type();
|
||||
for (bool first = true;; first = false) {
|
||||
if (!first) {
|
||||
if (failed(parser.parseOptionalComma())) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (succeeded(parser.parseOptionalQuestion())) {
|
||||
shape.push_back(-1);
|
||||
continue;
|
||||
}
|
||||
|
||||
int64_t dim;
|
||||
auto optionalPr = parser.parseOptionalInteger(dim);
|
||||
if (optionalPr.hasValue()) {
|
||||
if (failed(*optionalPr))
|
||||
return Type();
|
||||
shape.push_back(dim);
|
||||
continue;
|
||||
}
|
||||
break;
|
||||
}
|
||||
if (parser.parseRSquare()) {
|
||||
return Type();
|
||||
}
|
||||
}
|
||||
|
||||
// Parse colon dtype.
|
||||
if (parser.parseColon()) {
|
||||
return Type();
|
||||
}
|
||||
|
||||
if (failed(parser.parseOptionalQuestion())) {
|
||||
// Specified dtype.
|
||||
if (parser.parseType(dtype))
|
||||
if (parser.parseType(dtype)) {
|
||||
return Type();
|
||||
}
|
||||
}
|
||||
if (parser.parseGreater())
|
||||
if (parser.parseGreater()) {
|
||||
return Type();
|
||||
return NdArrayType::get(dtype);
|
||||
}
|
||||
|
||||
llvm::Optional<ArrayRef<int64_t>> optionalShape;
|
||||
if (hasShape)
|
||||
optionalShape = shape;
|
||||
auto ndarray = NdArrayType::get(dtype, optionalShape);
|
||||
return ndarray;
|
||||
}
|
||||
|
||||
parser.emitError(parser.getNameLoc(), "unknown numpy type: ") << keyword;
|
||||
|
@ -62,8 +114,23 @@ void NumpyDialect::printType(Type type, DialectAsmPrinter &os) const {
|
|||
case NumpyTypes::NdArray: {
|
||||
auto unknownType = Basicpy::UnknownType::get(getContext());
|
||||
auto ndarray = type.cast<NdArrayType>();
|
||||
auto shape = ndarray.getOptionalShape();
|
||||
auto dtype = ndarray.getDtype();
|
||||
os << "ndarray<";
|
||||
if (!shape) {
|
||||
os << "*:";
|
||||
} else {
|
||||
os << "[";
|
||||
for (auto it : llvm::enumerate(*shape)) {
|
||||
if (it.index() > 0)
|
||||
os << ",";
|
||||
if (it.value() < 0)
|
||||
os << "?";
|
||||
else
|
||||
os << it.value();
|
||||
}
|
||||
os << "]:";
|
||||
}
|
||||
if (dtype != unknownType)
|
||||
os.printType(dtype);
|
||||
else
|
||||
|
@ -85,19 +152,41 @@ namespace Numpy {
|
|||
namespace detail {
|
||||
|
||||
struct NdArrayTypeStorage : public TypeStorage {
|
||||
using KeyTy = Type;
|
||||
NdArrayTypeStorage(Type optionalDtype) : optionalDtype(optionalDtype) {}
|
||||
bool operator==(const KeyTy &other) const { return optionalDtype == other; }
|
||||
using KeyTy = std::pair<Type, llvm::Optional<ArrayRef<int64_t>>>;
|
||||
NdArrayTypeStorage(Type dtype, int rank, const int64_t *shapeElements)
|
||||
: dtype(dtype), rank(rank), shapeElements(shapeElements) {}
|
||||
bool operator==(const KeyTy &key) const {
|
||||
return key == KeyTy(dtype, getOptionalShape());
|
||||
}
|
||||
static llvm::hash_code hashKey(const KeyTy &key) {
|
||||
return llvm::hash_combine(key);
|
||||
if (key.second) {
|
||||
return llvm::hash_combine(key.first, *key.second);
|
||||
} else {
|
||||
return llvm::hash_combine(key.first, -1);
|
||||
}
|
||||
}
|
||||
static NdArrayTypeStorage *construct(TypeStorageAllocator &allocator,
|
||||
const KeyTy &key) {
|
||||
int rank = -1;
|
||||
const int64_t *shapeElements = nullptr;
|
||||
if (key.second.hasValue()) {
|
||||
auto allocElements = allocator.copyInto(*key.second);
|
||||
rank = key.second->size();
|
||||
shapeElements = allocElements.data();
|
||||
}
|
||||
return new (allocator.allocate<NdArrayTypeStorage>())
|
||||
NdArrayTypeStorage(key);
|
||||
NdArrayTypeStorage(key.first, rank, shapeElements);
|
||||
}
|
||||
|
||||
Type optionalDtype;
|
||||
llvm::Optional<ArrayRef<int64_t>> getOptionalShape() const {
|
||||
if (rank < 0)
|
||||
return llvm::None;
|
||||
return ArrayRef<int64_t>(shapeElements, rank);
|
||||
}
|
||||
|
||||
Type dtype;
|
||||
int rank;
|
||||
const int64_t *shapeElements;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
@ -105,16 +194,21 @@ struct NdArrayTypeStorage : public TypeStorage {
|
|||
} // namespace NPCOMP
|
||||
} // namespace mlir
|
||||
|
||||
NdArrayType NdArrayType::get(Type dtype) {
|
||||
NdArrayType NdArrayType::get(Type dtype,
|
||||
llvm::Optional<ArrayRef<int64_t>> shape) {
|
||||
assert(dtype && "dtype cannot be null");
|
||||
return Base::get(dtype.getContext(), NumpyTypes::NdArray, dtype);
|
||||
return Base::get(dtype.getContext(), NumpyTypes::NdArray, dtype, shape);
|
||||
}
|
||||
|
||||
bool NdArrayType::hasKnownDtype() {
|
||||
return getDtype() != Basicpy::UnknownType::get(getContext());
|
||||
}
|
||||
|
||||
Type NdArrayType::getDtype() { return getImpl()->optionalDtype; }
|
||||
Type NdArrayType::getDtype() { return getImpl()->dtype; }
|
||||
|
||||
llvm::Optional<ArrayRef<int64_t>> NdArrayType::getOptionalShape() {
|
||||
return getImpl()->getOptionalShape();
|
||||
}
|
||||
|
||||
Typing::CPA::TypeNode *
|
||||
NdArrayType::mapToCPAType(Typing::CPA::Context &context) {
|
||||
|
|
|
@ -13,6 +13,6 @@ global_data = (np.zeros((2, 3)) + [1.0, 2.0, 3.0] * np.reshape([1.0, 2.0],
|
|||
@import_global
|
||||
def global_array_to_const():
|
||||
# CHECK: %[[CST:.*]] = constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [2.000000e+00, 4.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
|
||||
# CHECK: numpy.create_array_from_tensor %[[CST]] : (tensor<2x3xf64>) -> !numpy.ndarray<f64>
|
||||
# CHECK: numpy.create_array_from_tensor %[[CST]] : (tensor<2x3xf64>) -> !numpy.ndarray<*:f64>
|
||||
local_data = global_data
|
||||
return local_data
|
||||
|
|
|
@ -16,7 +16,7 @@ b = np.asarray([3.0, 4.0])
|
|||
# Test the basic flow of invoking a ufunc call with constants captured from
|
||||
# a global using explicit function syntax (np.add(a, b)).
|
||||
# CHECK-LABEL: func @global_add
|
||||
# CHECK-SAME: -> !numpy.ndarray<f64>
|
||||
# CHECK-SAME: -> !numpy.ndarray<*:f64>
|
||||
@import_global
|
||||
def global_add():
|
||||
# CHECK-NOT: UnknownType
|
||||
|
|
|
@ -25,7 +25,7 @@ def global_add():
|
|||
# CHECK-DAG: %[[A:.*]] = numpy.copy_to_tensor %[[A_ARRAY]]
|
||||
# CHECK-DAG: %[[B:.*]] = numpy.copy_to_tensor %[[B_ARRAY]]
|
||||
# CHECK: %[[R_TENSOR:.*]] = numpy.builtin_ufunc_call<"numpy.add"> (%[[A]], %[[B]]) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*x!basicpy.UnknownType>
|
||||
# CHECK: numpy.create_array_from_tensor %[[R_TENSOR]] : (tensor<*x!basicpy.UnknownType>) -> !numpy.ndarray<?>
|
||||
# CHECK: numpy.create_array_from_tensor %[[R_TENSOR]] : (tensor<*x!basicpy.UnknownType>) -> !numpy.ndarray<*:?>
|
||||
return np.add(a, b)
|
||||
|
||||
|
||||
|
|
|
@ -1,13 +1,21 @@
|
|||
// RUN: npcomp-opt -split-input-file %s | npcomp-opt | FileCheck --dump-input=fail %s
|
||||
|
||||
// CHECK-LABEL: @ndarray_no_dtype
|
||||
// CHECK: !numpy.ndarray<?>
|
||||
func @ndarray_no_dtype(%arg0 : !numpy.ndarray<?>) -> !numpy.ndarray<?> {
|
||||
return %arg0 : !numpy.ndarray<?>
|
||||
// CHECK: !numpy.ndarray<*:?>
|
||||
func @ndarray_no_dtype(%arg0 : !numpy.ndarray<*:?>) -> !numpy.ndarray<*:?> {
|
||||
return %arg0 : !numpy.ndarray<*:?>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: @ndarray_dtype
|
||||
// CHECK: !numpy.ndarray<i32>
|
||||
func @ndarray_dtype(%arg0 : !numpy.ndarray<i32>) -> !numpy.ndarray<i32> {
|
||||
return %arg0 : !numpy.ndarray<i32>
|
||||
// CHECK: !numpy.ndarray<*:i32>
|
||||
func @ndarray_dtype(%arg0 : !numpy.ndarray<*:i32>) -> !numpy.ndarray<*:i32> {
|
||||
return %arg0 : !numpy.ndarray<*:i32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: @ndarray_ranked
|
||||
// CHECK: !numpy.ndarray<[1,?,3]:i32>
|
||||
func @ndarray_ranked(%arg0 : !numpy.ndarray<[1,?,3]:i32>) -> !numpy.ndarray<[1,?,3]:i32> {
|
||||
return %arg0 : !numpy.ndarray<[1,?,3]:i32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue