mirror of https://github.com/llvm/torch-mlir
Add NdArray type inference conversion.
parent
4a2f7c0b5f
commit
34861b18f4
|
@ -11,6 +11,7 @@
|
|||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "npcomp/Dialect/Common.h"
|
||||
#include "npcomp/Typing/CPA/Interfaces.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
|
@ -43,12 +44,17 @@ public:
|
|||
};
|
||||
|
||||
class NdArrayType
|
||||
: public Type::TypeBase<NdArrayType, Type, detail::NdArrayTypeStorage> {
|
||||
: public Type::TypeBase<NdArrayType, Type, detail::NdArrayTypeStorage,
|
||||
Typing::CPA::TypeMapInterface::Trait> {
|
||||
public:
|
||||
using Base::Base;
|
||||
static bool kindof(unsigned kind) { return kind == NumpyTypes::NdArray; }
|
||||
static NdArrayType get(Type optionalDtype);
|
||||
bool hasKnownDtype();
|
||||
Type getDtype();
|
||||
|
||||
// CPA::TypeMapInterface methods.
|
||||
Typing::CPA::TypeNode *mapToCPAType(Typing::CPA::Context &context);
|
||||
};
|
||||
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyOpsDialect.h.inc"
|
||||
|
|
|
@ -269,15 +269,16 @@ public:
|
|||
|
||||
private:
|
||||
ObjectValueType(Identifier *typeIdentifier, size_t fieldCount,
|
||||
Identifier **fieldIdentifiers, TypeNode **fieldTypes)
|
||||
Identifier *const *fieldIdentifiers,
|
||||
TypeNode *const *fieldTypes)
|
||||
// TODO: Real hashcode.
|
||||
: ValueType(Kind::ObjectValueType, 0), typeIdentifier(typeIdentifier),
|
||||
fieldCount(fieldCount), fieldIdentifiers(fieldIdentifiers),
|
||||
fieldTypes(fieldTypes) {}
|
||||
Identifier *typeIdentifier;
|
||||
size_t fieldCount;
|
||||
Identifier **fieldIdentifiers;
|
||||
TypeNode **fieldTypes;
|
||||
Identifier *const *fieldIdentifiers;
|
||||
TypeNode *const *fieldTypes;
|
||||
friend class Context;
|
||||
};
|
||||
|
||||
|
@ -385,6 +386,11 @@ private:
|
|||
/// analysis.
|
||||
class Context {
|
||||
public:
|
||||
Context();
|
||||
|
||||
/// Gets the current environment (roughly call scope).
|
||||
Environment *getCurrentEnvironment() { return currentEnvironment; }
|
||||
|
||||
/// Maps an IR Type to a CPA TypeNode.
|
||||
/// This is currently not overridable but a hook may need to be provided
|
||||
/// eventually.
|
||||
|
@ -395,6 +401,7 @@ public:
|
|||
TypeVar *newTypeVar() {
|
||||
TypeVar *tv = allocator.Allocate<TypeVar>(1);
|
||||
new (tv) TypeVar(++typeVarCounter);
|
||||
currentEnvironment->getTypeVars().insert(tv);
|
||||
return tv;
|
||||
}
|
||||
|
||||
|
@ -419,19 +426,28 @@ public:
|
|||
|
||||
/// Creates a new ObjectValueType.
|
||||
/// Object value types are not uniqued.
|
||||
// ObjectValueType *
|
||||
// newObjectValueType(Identifier *typeIdentifier,
|
||||
// llvm::ArrayRef<Identifier *> fieldIdentifiers) {
|
||||
// size_t n = fieldIdentifiers.size();
|
||||
// Identifier **allocFieldIdentifiers = allocator.Allocate<Identifier *>(n);
|
||||
// std::copy_n(fieldIdentifiers.begin(), n, allocFieldIdentifiers);
|
||||
// TypeNode **allocFieldTypes = allocator.Allocate<TypeNode *>(n);
|
||||
// std::fill_n(allocFieldTypes, n, nullptr);
|
||||
// auto *ovt = allocator.Allocate<ObjectValueType>(1);
|
||||
// new (ovt) ObjectValueType(typeIdentifier, n, allocFieldIdentifiers,
|
||||
// allocFieldTypes);
|
||||
// return ovt;
|
||||
// }
|
||||
ObjectValueType *
|
||||
newObjectValueType(Identifier *typeIdentifier,
|
||||
llvm::ArrayRef<Identifier *> fieldIdentifiers,
|
||||
llvm::ArrayRef<TypeNode *> fieldTypes) {
|
||||
assert(fieldIdentifiers.size() == fieldTypes.size());
|
||||
size_t n = fieldIdentifiers.size();
|
||||
|
||||
Identifier **allocFieldIdentifiers = allocator.Allocate<Identifier *>(n);
|
||||
std::copy_n(fieldIdentifiers.begin(), n, allocFieldIdentifiers);
|
||||
TypeNode **allocFieldTypes = allocator.Allocate<TypeNode *>(n);
|
||||
std::copy_n(fieldTypes.begin(), n, allocFieldTypes);
|
||||
auto *ovt = allocator.Allocate<ObjectValueType>(1);
|
||||
new (ovt) ObjectValueType(typeIdentifier, n, allocFieldIdentifiers,
|
||||
allocFieldTypes);
|
||||
return ovt;
|
||||
}
|
||||
|
||||
/// Creates an array object type with a possibly unknown element type.
|
||||
/// By convention, arrays have a single type slot for the element type
|
||||
/// named 'e'.
|
||||
ObjectValueType *newArrayType(Identifier *typeIdentifier,
|
||||
llvm::Optional<TypeNode *> elementType);
|
||||
|
||||
/// Gets a CastType.
|
||||
CastType *getCastType(Identifier *typeIdentifier, TypeVar *typeVar) {
|
||||
|
@ -460,6 +476,7 @@ public:
|
|||
new (av) Constraint(v); // Copy ctor
|
||||
*it.first = av; // Replace key pointer with durable allocation.
|
||||
addConstraintToGraph(av);
|
||||
currentEnvironment->getConstraints().insert(av);
|
||||
return av;
|
||||
}
|
||||
|
||||
|
@ -514,6 +531,9 @@ private:
|
|||
llvm::DenseSet<Constraint *, Constraint::PtrInfo> constraintUniquer;
|
||||
int typeVarCounter = 0;
|
||||
|
||||
// Singletons created for the context.
|
||||
Identifier *arrayElementIdent;
|
||||
|
||||
// Graph management.
|
||||
llvm::DenseMap<TypeNode *, ConstraintSet> fwdNodeToConstraintMap;
|
||||
llvm::DenseMap<Constraint *, TypeNodeSet> fwdConstraintToNodeMap;
|
||||
|
@ -526,6 +546,10 @@ private:
|
|||
/// Constraints that are pending propagation.
|
||||
ConstraintSet pendingConstraints;
|
||||
ConstraintSet pendingConstraintWorklist;
|
||||
|
||||
// Environment management.
|
||||
std::vector<std::unique_ptr<Environment>> environmentStack;
|
||||
Environment *currentEnvironment;
|
||||
};
|
||||
|
||||
inline bool TypeNode::operator==(const TypeNode &that) const {
|
||||
|
|
|
@ -51,7 +51,6 @@ public:
|
|||
auto subVt = resolveValueType(subValue);
|
||||
CPA::Constraint *c = env.getContext().getConstraint(superVt, subVt);
|
||||
c->setContextOp(contextOp);
|
||||
env.getConstraints().insert(c);
|
||||
}
|
||||
|
||||
LogicalResult runOnFunction(FuncOp funcOp) {
|
||||
|
@ -181,7 +180,7 @@ public:
|
|||
return;
|
||||
|
||||
CPA::Context cpaContext;
|
||||
CPA::Environment env(cpaContext);
|
||||
auto &env = *cpaContext.getCurrentEnvironment();
|
||||
|
||||
InitialConstraintGenerator p(env);
|
||||
p.runOnFunction(func);
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
#include "npcomp/Dialect/Numpy/IR/NumpyOps.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::Numpy;
|
||||
|
||||
NumpyDialect::NumpyDialect(MLIRContext *context)
|
||||
|
@ -107,4 +108,20 @@ NdArrayType NdArrayType::get(Type dtype) {
|
|||
return Base::get(dtype.getContext(), NumpyTypes::NdArray, dtype);
|
||||
}
|
||||
|
||||
bool NdArrayType::hasKnownDtype() {
|
||||
return getDtype() != Basicpy::UnknownType::get(getContext());
|
||||
}
|
||||
|
||||
Type NdArrayType::getDtype() { return getImpl()->optionalDtype; }
|
||||
|
||||
Typing::CPA::TypeNode *
|
||||
NdArrayType::mapToCPAType(Typing::CPA::Context &context) {
|
||||
llvm::Optional<Typing::CPA::TypeNode *> dtype;
|
||||
if (hasKnownDtype()) {
|
||||
// TODO: This should be using a general mechanism for resolving the dtype,
|
||||
// but we don't have that yet, and for NdArray, these must be primitives
|
||||
// anyway.
|
||||
dtype = context.getIRValueType(getDtype());
|
||||
}
|
||||
return context.newArrayType(context.getIdentifier("!NdArray"), dtype);
|
||||
}
|
||||
|
|
|
@ -66,6 +66,25 @@ TypeNode *Environment::mapValueToType(Value value) {
|
|||
// Context
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Context::Context() {
|
||||
environmentStack.emplace_back(std::make_unique<Environment>(*this));
|
||||
currentEnvironment = environmentStack.back().get();
|
||||
arrayElementIdent = getIdentifier("e");
|
||||
}
|
||||
|
||||
ObjectValueType *Context::newArrayType(Identifier *typeIdentifier,
|
||||
llvm::Optional<TypeNode *> elementType) {
|
||||
TypeNode *concreteElementType;
|
||||
if (elementType) {
|
||||
concreteElementType = *elementType;
|
||||
} else {
|
||||
concreteElementType = newTypeVar();
|
||||
}
|
||||
|
||||
return newObjectValueType(typeIdentifier, {arrayElementIdent},
|
||||
{concreteElementType});
|
||||
}
|
||||
|
||||
TypeNode *Context::mapIrType(::mlir::Type irType) {
|
||||
// First, see if the type knows how to map itself.
|
||||
assert(irType);
|
||||
|
@ -190,7 +209,7 @@ void IRValueType::print(Context &context, raw_ostream &os, bool brief) {
|
|||
}
|
||||
|
||||
void ObjectValueType::print(Context &context, raw_ostream &os, bool brief) {
|
||||
os << "object(" << *typeIdentifier;
|
||||
os << "object(" << *typeIdentifier << ",[";
|
||||
bool first = true;
|
||||
for (auto it : llvm::zip(getFieldIdentifiers(), getFieldTypes())) {
|
||||
if (!first)
|
||||
|
@ -204,7 +223,7 @@ void ObjectValueType::print(Context &context, raw_ostream &os, bool brief) {
|
|||
else
|
||||
os << "NULL";
|
||||
}
|
||||
os << ")";
|
||||
os << "])";
|
||||
}
|
||||
|
||||
void Constraint::print(Context &context, raw_ostream &os, bool brief) {
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
# RUN: %PYTHON %s | npcomp-opt -split-input-file -basicpy-cpa-type-inference | FileCheck %s --dump-input=fail
|
||||
|
||||
import numpy as np
|
||||
from npcomp.compiler import test_config
|
||||
from npcomp.compiler.frontend import EmittedError
|
||||
|
||||
import_global = test_config.create_import_dump_decorator()
|
||||
|
||||
global_data = (np.zeros((2, 3)) + [1.0, 2.0, 3.0] * np.reshape([1.0, 2.0],
|
||||
(2, 1)))
|
||||
|
||||
a = np.asarray([1.0, 2.0])
|
||||
b = np.asarray([3.0, 4.0])
|
||||
|
||||
|
||||
# Test the basic flow of invoking a ufunc call with constants captured from
|
||||
# a global using explicit function syntax (np.add(a, b)).
|
||||
# CHECK-LABEL: func @global_add
|
||||
@import_global
|
||||
def global_add():
|
||||
return np.add(a, b)
|
Loading…
Reference in New Issue