Add NdArray type inference conversion.

pull/1/head
Stella Laurenzo 2020-07-03 16:38:10 -07:00
parent 4a2f7c0b5f
commit 34861b18f4
6 changed files with 107 additions and 21 deletions

View File

@ -11,6 +11,7 @@
#include "mlir/IR/Dialect.h"
#include "npcomp/Dialect/Common.h"
#include "npcomp/Typing/CPA/Interfaces.h"
namespace mlir {
namespace NPCOMP {
@ -43,12 +44,17 @@ public:
};
class NdArrayType
: public Type::TypeBase<NdArrayType, Type, detail::NdArrayTypeStorage> {
: public Type::TypeBase<NdArrayType, Type, detail::NdArrayTypeStorage,
Typing::CPA::TypeMapInterface::Trait> {
public:
using Base::Base;
static bool kindof(unsigned kind) { return kind == NumpyTypes::NdArray; }
static NdArrayType get(Type optionalDtype);
bool hasKnownDtype();
Type getDtype();
// CPA::TypeMapInterface methods.
Typing::CPA::TypeNode *mapToCPAType(Typing::CPA::Context &context);
};
#include "npcomp/Dialect/Numpy/IR/NumpyOpsDialect.h.inc"

View File

@ -269,15 +269,16 @@ public:
private:
ObjectValueType(Identifier *typeIdentifier, size_t fieldCount,
Identifier **fieldIdentifiers, TypeNode **fieldTypes)
Identifier *const *fieldIdentifiers,
TypeNode *const *fieldTypes)
// TODO: Real hashcode.
: ValueType(Kind::ObjectValueType, 0), typeIdentifier(typeIdentifier),
fieldCount(fieldCount), fieldIdentifiers(fieldIdentifiers),
fieldTypes(fieldTypes) {}
Identifier *typeIdentifier;
size_t fieldCount;
Identifier **fieldIdentifiers;
TypeNode **fieldTypes;
Identifier *const *fieldIdentifiers;
TypeNode *const *fieldTypes;
friend class Context;
};
@ -385,6 +386,11 @@ private:
/// analysis.
class Context {
public:
Context();
/// Gets the current environment (roughly call scope).
Environment *getCurrentEnvironment() { return currentEnvironment; }
/// Maps an IR Type to a CPA TypeNode.
/// This is currently not overridable but a hook may need to be provided
/// eventually.
@ -395,6 +401,7 @@ public:
TypeVar *newTypeVar() {
TypeVar *tv = allocator.Allocate<TypeVar>(1);
new (tv) TypeVar(++typeVarCounter);
currentEnvironment->getTypeVars().insert(tv);
return tv;
}
@ -419,19 +426,28 @@ public:
/// Creates a new ObjectValueType.
/// Object value types are not uniqued.
// ObjectValueType *
// newObjectValueType(Identifier *typeIdentifier,
// llvm::ArrayRef<Identifier *> fieldIdentifiers) {
// size_t n = fieldIdentifiers.size();
// Identifier **allocFieldIdentifiers = allocator.Allocate<Identifier *>(n);
// std::copy_n(fieldIdentifiers.begin(), n, allocFieldIdentifiers);
// TypeNode **allocFieldTypes = allocator.Allocate<TypeNode *>(n);
// std::fill_n(allocFieldTypes, n, nullptr);
// auto *ovt = allocator.Allocate<ObjectValueType>(1);
// new (ovt) ObjectValueType(typeIdentifier, n, allocFieldIdentifiers,
// allocFieldTypes);
// return ovt;
// }
ObjectValueType *
newObjectValueType(Identifier *typeIdentifier,
llvm::ArrayRef<Identifier *> fieldIdentifiers,
llvm::ArrayRef<TypeNode *> fieldTypes) {
assert(fieldIdentifiers.size() == fieldTypes.size());
size_t n = fieldIdentifiers.size();
Identifier **allocFieldIdentifiers = allocator.Allocate<Identifier *>(n);
std::copy_n(fieldIdentifiers.begin(), n, allocFieldIdentifiers);
TypeNode **allocFieldTypes = allocator.Allocate<TypeNode *>(n);
std::copy_n(fieldTypes.begin(), n, allocFieldTypes);
auto *ovt = allocator.Allocate<ObjectValueType>(1);
new (ovt) ObjectValueType(typeIdentifier, n, allocFieldIdentifiers,
allocFieldTypes);
return ovt;
}
/// Creates an array object type with a possibly unknown element type.
/// By convention, arrays have a single type slot for the element type
/// named 'e'.
ObjectValueType *newArrayType(Identifier *typeIdentifier,
llvm::Optional<TypeNode *> elementType);
/// Gets a CastType.
CastType *getCastType(Identifier *typeIdentifier, TypeVar *typeVar) {
@ -460,6 +476,7 @@ public:
new (av) Constraint(v); // Copy ctor
*it.first = av; // Replace key pointer with durable allocation.
addConstraintToGraph(av);
currentEnvironment->getConstraints().insert(av);
return av;
}
@ -514,6 +531,9 @@ private:
llvm::DenseSet<Constraint *, Constraint::PtrInfo> constraintUniquer;
int typeVarCounter = 0;
// Singletons created for the context.
Identifier *arrayElementIdent;
// Graph management.
llvm::DenseMap<TypeNode *, ConstraintSet> fwdNodeToConstraintMap;
llvm::DenseMap<Constraint *, TypeNodeSet> fwdConstraintToNodeMap;
@ -526,6 +546,10 @@ private:
/// Constraints that are pending propagation.
ConstraintSet pendingConstraints;
ConstraintSet pendingConstraintWorklist;
// Environment management.
std::vector<std::unique_ptr<Environment>> environmentStack;
Environment *currentEnvironment;
};
inline bool TypeNode::operator==(const TypeNode &that) const {

View File

@ -51,7 +51,6 @@ public:
auto subVt = resolveValueType(subValue);
CPA::Constraint *c = env.getContext().getConstraint(superVt, subVt);
c->setContextOp(contextOp);
env.getConstraints().insert(c);
}
LogicalResult runOnFunction(FuncOp funcOp) {
@ -181,7 +180,7 @@ public:
return;
CPA::Context cpaContext;
CPA::Environment env(cpaContext);
auto &env = *cpaContext.getCurrentEnvironment();
InitialConstraintGenerator p(env);
p.runOnFunction(func);

View File

@ -12,6 +12,7 @@
#include "npcomp/Dialect/Numpy/IR/NumpyOps.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::Numpy;
NumpyDialect::NumpyDialect(MLIRContext *context)
@ -107,4 +108,20 @@ NdArrayType NdArrayType::get(Type dtype) {
return Base::get(dtype.getContext(), NumpyTypes::NdArray, dtype);
}
bool NdArrayType::hasKnownDtype() {
return getDtype() != Basicpy::UnknownType::get(getContext());
}
Type NdArrayType::getDtype() { return getImpl()->optionalDtype; }
Typing::CPA::TypeNode *
NdArrayType::mapToCPAType(Typing::CPA::Context &context) {
llvm::Optional<Typing::CPA::TypeNode *> dtype;
if (hasKnownDtype()) {
// TODO: This should be using a general mechanism for resolving the dtype,
// but we don't have that yet, and for NdArray, these must be primitives
// anyway.
dtype = context.getIRValueType(getDtype());
}
return context.newArrayType(context.getIdentifier("!NdArray"), dtype);
}

View File

@ -66,6 +66,25 @@ TypeNode *Environment::mapValueToType(Value value) {
// Context
//===----------------------------------------------------------------------===//
Context::Context() {
environmentStack.emplace_back(std::make_unique<Environment>(*this));
currentEnvironment = environmentStack.back().get();
arrayElementIdent = getIdentifier("e");
}
ObjectValueType *Context::newArrayType(Identifier *typeIdentifier,
llvm::Optional<TypeNode *> elementType) {
TypeNode *concreteElementType;
if (elementType) {
concreteElementType = *elementType;
} else {
concreteElementType = newTypeVar();
}
return newObjectValueType(typeIdentifier, {arrayElementIdent},
{concreteElementType});
}
TypeNode *Context::mapIrType(::mlir::Type irType) {
// First, see if the type knows how to map itself.
assert(irType);
@ -190,7 +209,7 @@ void IRValueType::print(Context &context, raw_ostream &os, bool brief) {
}
void ObjectValueType::print(Context &context, raw_ostream &os, bool brief) {
os << "object(" << *typeIdentifier;
os << "object(" << *typeIdentifier << ",[";
bool first = true;
for (auto it : llvm::zip(getFieldIdentifiers(), getFieldTypes())) {
if (!first)
@ -204,7 +223,7 @@ void ObjectValueType::print(Context &context, raw_ostream &os, bool brief) {
else
os << "NULL";
}
os << ")";
os << "])";
}
void Constraint::print(Context &context, raw_ostream &os, bool brief) {

View File

@ -0,0 +1,21 @@
# RUN: %PYTHON %s | npcomp-opt -split-input-file -basicpy-cpa-type-inference | FileCheck %s --dump-input=fail
import numpy as np
from npcomp.compiler import test_config
from npcomp.compiler.frontend import EmittedError
import_global = test_config.create_import_dump_decorator()
global_data = (np.zeros((2, 3)) + [1.0, 2.0, 3.0] * np.reshape([1.0, 2.0],
(2, 1)))
a = np.asarray([1.0, 2.0])
b = np.asarray([3.0, 4.0])
# Test the basic flow of invoking a ufunc call with constants captured from
# a global using explicit function syntax (np.add(a, b)).
# CHECK-LABEL: func @global_add
@import_global
def global_add():
return np.add(a, b)