//===- NumpyDialect.cpp - Core numpy dialect --------------------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "npcomp/Dialect/Numpy/IR/NumpyDialect.h" #include "mlir/IR/DialectImplementation.h" #include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h" #include "npcomp/Dialect/Numpy/IR/NumpyOps.h" #include "npcomp/Typing/Support/CPAIrHelpers.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; using namespace mlir::NPCOMP; using namespace mlir::NPCOMP::Numpy; void NumpyDialect::initialize() { addOperations< #define GET_OP_LIST #include "npcomp/Dialect/Numpy/IR/NumpyOps.cpp.inc" >(); addTypes(); getContext()->loadDialect(); } Type NumpyDialect::parseType(DialectAsmParser &parser) const { StringRef keyword; if (parser.parseKeyword(&keyword)) return Type(); if (keyword == "any_dtype") return AnyDtypeType::get(getContext()); if (keyword == "ndarray") { // Parse: // ndarray<*:?> // ndarray<*:i32> // ndarary<[1,2,3]:i32> // Note that this is a different syntax than the built-ins as the dialect // parser is not general enough to parse a dimension list with an optional // element type (?). The built-in form is also remarkably ambiguous when // considering extending it. Type dtype = Basicpy::UnknownType::get(getContext()); bool hasShape = false; llvm::SmallVector shape; if (parser.parseLess()) return Type(); if (succeeded(parser.parseOptionalStar())) { // Unranked. } else { // Parse dimension list. hasShape = true; if (parser.parseLSquare()) return Type(); for (bool first = true;; first = false) { if (!first) { if (failed(parser.parseOptionalComma())) { break; } } if (succeeded(parser.parseOptionalQuestion())) { shape.push_back(-1); continue; } int64_t dim; auto optionalPr = parser.parseOptionalInteger(dim); if (optionalPr.hasValue()) { if (failed(*optionalPr)) return Type(); shape.push_back(dim); continue; } break; } if (parser.parseRSquare()) { return Type(); } } // Parse colon dtype. if (parser.parseColon()) { return Type(); } if (failed(parser.parseOptionalQuestion())) { // Specified dtype. if (parser.parseType(dtype)) { return Type(); } } if (parser.parseGreater()) { return Type(); } llvm::Optional> optionalShape; if (hasShape) optionalShape = shape; auto ndarray = NdArrayType::get(dtype, optionalShape); return ndarray; } parser.emitError(parser.getNameLoc(), "unknown numpy type: ") << keyword; return Type(); } void NumpyDialect::printType(Type type, DialectAsmPrinter &os) const { TypeSwitch(type) .Case([&](Type) { os << "any_dtype"; }) .Case([&](NdArrayType t) { auto unknownType = Basicpy::UnknownType::get(getContext()); auto ndarray = type.cast(); auto shape = ndarray.getOptionalShape(); auto dtype = ndarray.getDtype(); os << "ndarray<"; if (!shape) { os << "*:"; } else { os << "["; for (auto it : llvm::enumerate(*shape)) { if (it.index() > 0) os << ","; if (it.value() < 0) os << "?"; else os << it.value(); } os << "]:"; } if (dtype != unknownType) os.printType(dtype); else os << "?"; os << ">"; }) .Default([&](Type) { llvm_unreachable("unexpected 'numpy' type kind"); }); } //----------------------------------------------------------------------------// // Type and attribute detail //----------------------------------------------------------------------------// namespace mlir { namespace NPCOMP { namespace Numpy { namespace detail { struct NdArrayTypeStorage : public TypeStorage { using KeyTy = std::pair>>; NdArrayTypeStorage(Type dtype, int rank, const int64_t *shapeElements) : dtype(dtype), rank(rank), shapeElements(shapeElements) {} bool operator==(const KeyTy &key) const { return key == KeyTy(dtype, getOptionalShape()); } static llvm::hash_code hashKey(const KeyTy &key) { if (key.second) { return llvm::hash_combine(key.first, *key.second); } else { return llvm::hash_combine(key.first, -1); } } static NdArrayTypeStorage *construct(TypeStorageAllocator &allocator, const KeyTy &key) { int rank = -1; const int64_t *shapeElements = nullptr; if (key.second.hasValue()) { auto allocElements = allocator.copyInto(*key.second); rank = key.second->size(); shapeElements = allocElements.data(); } return new (allocator.allocate()) NdArrayTypeStorage(key.first, rank, shapeElements); } llvm::Optional> getOptionalShape() const { if (rank < 0) return llvm::None; return ArrayRef(shapeElements, rank); } Type dtype; int rank; const int64_t *shapeElements; }; } // namespace detail } // namespace Numpy } // namespace NPCOMP } // namespace mlir NdArrayType NdArrayType::get(Type dtype, llvm::Optional> shape) { assert(dtype && "dtype cannot be null"); return Base::get(dtype.getContext(), dtype, shape); } NdArrayType NdArrayType::getFromShapedType(ShapedType shapedType) { llvm::Optional> shape; if (shapedType.hasRank()) shape = shapedType.getShape(); return get(shapedType.getElementType(), shape); } bool NdArrayType::hasKnownDtype() { return getDtype() != Basicpy::UnknownType::get(getContext()); } Type NdArrayType::getDtype() { return getImpl()->dtype; } llvm::Optional> NdArrayType::getOptionalShape() { return getImpl()->getOptionalShape(); } TensorType NdArrayType::toTensorType() { auto shape = getOptionalShape(); if (shape) { return RankedTensorType::get(*shape, getDtype()); } else { return UnrankedTensorType::get(getDtype()); } } Typing::CPA::TypeNode * NdArrayType::mapToCPAType(Typing::CPA::Context &context) { llvm::Optional dtype; if (hasKnownDtype()) { // TODO: This should be using a general mechanism for resolving the dtype, // but we don't have that yet, and for NdArray, these must be primitives // anyway. dtype = context.getIRValueType(getDtype()); } // Safe to capture an ArrayRef backed by type storage since it is uniqued. auto optionalShape = getOptionalShape(); auto irCtor = [optionalShape](Typing::CPA::ObjectValueType *ovt, llvm::ArrayRef fieldTypes, MLIRContext *mlirContext, llvm::Optional) { assert(fieldTypes.size() == 1); return NdArrayType::get(fieldTypes.front(), optionalShape); }; return Typing::CPA::newArrayType(context, irCtor, context.getIdentifier("!NdArray"), dtype); }