Extend type inference so that it works across conditional boundaries.

* The implementation is still limited but gives something to build on.
pull/1/head
Stella Laurenzo 2020-06-10 21:33:17 -07:00
parent c84ce17573
commit 750541e9a9
3 changed files with 140 additions and 46 deletions

View File

@ -8,6 +8,8 @@
#include "PassDetail.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
@ -149,6 +151,10 @@ public:
/// Print a report of the equations for debugging.
void report(raw_ostream &os) {
os << "Type variable map:\n";
for (auto it : llvm::enumerate(ordinalToVarNode)) {
os << ": " << it.index() << " = " << it.value()->getDef() << "\n";
}
os << "Type equations:\n";
os << "---------------\n";
for (auto &eq : equations) {
@ -158,11 +164,9 @@ public:
simple_ilist<TypeEquation> &getEquations() { return equations; }
void applySubst(unsigned ordinal, TypeNode *resolved) {
assert(ordinal >= 0 && ordinal < ordinalToVarNode.size());
Type constType = resolved->getConstType();
TypeNode *varNode = ordinalToVarNode[ordinal];
varNode->getDef().setType(constType);
TypeNode *lookupVarOrdinal(unsigned ordinal) {
assert(ordinal < ordinalToVarNode.size());
return ordinalToVarNode[ordinal];
}
private:
@ -211,6 +215,23 @@ public:
return subst;
}
Type resolveSubst(TypeNode *typeNode, const Optional<SubstMap> &subst) {
if (!subst)
return nullptr;
if (typeNode->getDiscrim() == TypeNode::Discrim::CONST_TYPE) {
return typeNode->getConstType();
}
if (typeNode->getDiscrim() == TypeNode::Discrim::VAR_ORDINAL) {
auto foundIt = subst->find(typeNode->getVarOrdinal());
if (foundIt != subst->end()) {
return resolveSubst(foundIt->second, subst);
} else {
return nullptr;
}
}
return nullptr;
}
Optional<SubstMap> unify(TypeNode *typeX, TypeNode *typeY,
Optional<SubstMap> subst) {
LLVM_DEBUG(llvm::dbgs() << "+ UNIFY: " << *typeX << ", " << *typeY << "\n");
@ -233,10 +254,13 @@ public:
Optional<SubstMap> unifyVariable(TypeNode *varNode, TypeNode *typeNode,
SubstMap subst) {
assert(varNode->getDiscrim() == TypeNode::Discrim::VAR_ORDINAL);
LLVM_DEBUG(llvm::dbgs() << " - UNIFY VARIABLE: " << *varNode << " <- "
<< *typeNode << "\n");
// Var node in subst?
auto it = subst.find(varNode->getVarOrdinal());
if (it != subst.end()) {
TypeNode *found = it->second;
LLVM_DEBUG(llvm::dbgs() << " --> FOUND VAR: " << *found << "\n");
return unify(found, typeNode, std::move(subst));
}
@ -245,14 +269,36 @@ public:
it = subst.find(typeNode->getVarOrdinal());
if (it != subst.end()) {
TypeNode *found = it->second;
LLVM_DEBUG(llvm::dbgs() << " --> FOUND TYPE: " << *found << "\n");
return unify(varNode, found, std::move(subst));
}
}
// Does the variable appear in the type?
if (occursCheck(varNode, typeNode, subst)) {
LLVM_DEBUG(llvm::dbgs() << "FAILED OCCURS_CHECK\n");
return None;
}
// varNode is not yet in subst and cannot simplify typeNode. Extend.
subst[varNode->getVarOrdinal()] = typeNode;
return std::move(subst);
}
bool occursCheck(TypeNode *varNode, TypeNode *typeNode, SubstMap &subst) {
if (*varNode == *typeNode)
return true;
if (typeNode->getDiscrim() == TypeNode::Discrim::VAR_ORDINAL) {
unsigned typeOrdinal = typeNode->getVarOrdinal();
auto foundIt = subst.find(typeOrdinal);
if (foundIt != subst.end()) {
return occursCheck(varNode, foundIt->second, subst);
}
}
return false;
}
};
class TypeEquationPopulator {
@ -262,50 +308,68 @@ public:
/// If a return op was visited, this will be one of them.
Operation *getLastReturnOp() { return funcReturnOp; }
/// Gets any ReturnLike ops that do not return from the outer function.
/// This is used to fixup parent SCF ops and the like.
llvm::SmallVectorImpl<Operation *> &getInnerReturnLikeOps() {
return innerReturnLikeOps;
}
LogicalResult runOnFunction(FuncOp funcOp) {
// Iterate and create type nodes for entry block arguments, as these
// must be resolved no matter what.
if (funcOp.getBody().empty())
return success();
auto &entryBlock = funcOp.getBody().front();
for (auto blockArg : entryBlock.getArguments()) {
equations.getTypeNode(blockArg);
}
// Then walk ops, creating equations.
LLVM_DEBUG(llvm::dbgs() << "POPULATE CHILD OPS:\n");
auto result = funcOp.walk([&](Operation *childOp) -> WalkResult {
if (childOp == funcOp)
return WalkResult::advance();
// Trait based equations.
// ----------------------
// Function returns must all have the same types.
if (childOp->hasTrait<OpTrait::ReturnLike>() &&
childOp->getParentOp() == funcOp) {
if (funcReturnOp) {
if (funcReturnOp->getNumOperands() != childOp->getNumOperands()) {
childOp->emitOpError() << "different arity of function returns";
return WalkResult::interrupt();
LLVM_DEBUG(llvm::dbgs() << " + POPULATE: " << *childOp << "\n");
// Special op handling.
// Many of these (that are not standard ops) should become op
// interfaces.
// --------------------
if (auto op = dyn_cast<SelectOp>(childOp)) {
// Note that the condition is always i1 and not subject to type
// inference.
equations.addTypeEqualityEquation(op.true_value(), op.false_value(),
op);
return WalkResult::advance();
}
if (auto op = dyn_cast<ToBooleanOp>(childOp)) {
// Note that the result is always i1 and not subject to type
// inference.
equations.getTypeNode(op.operand());
return WalkResult::advance();
}
if (auto op = dyn_cast<scf::IfOp>(childOp)) {
// Note that the condition is always i1 and not subject to type
// inference.
for (auto result : op.getResults()) {
equations.getTypeNode(result);
}
return WalkResult::advance();
}
if (auto yieldOp = dyn_cast<scf::YieldOp>(childOp)) {
auto scfParentOp = yieldOp.getParentOp();
if (scfParentOp->getNumResults() != yieldOp.getNumOperands()) {
yieldOp.emitWarning()
<< "cannot run type inference on yield due to arity mismatch";
return WalkResult::advance();
}
for (auto it :
llvm::zip(funcReturnOp->getOperands(), childOp->getOperands())) {
equations.addTypeEqualityEquation(std::get<0>(it), std::get<1>(it),
childOp);
llvm::zip(scfParentOp->getResults(), yieldOp.getOperands())) {
equations.addTypeEqualityEquation(std::get<1>(it), std::get<0>(it),
yieldOp);
}
}
funcReturnOp = childOp;
return WalkResult::advance();
}
// Ensure that constant nodes get assigned a constant type.
if (childOp->hasTrait<OpTrait::ConstantLike>()) {
equations.getTypeNode(childOp->getResult(0));
return WalkResult::advance();
}
// Special op handling.
// Many of these (that are not standard ops) should become op interfaces.
// --------------------
if (auto op = dyn_cast<UnknownCastOp>(childOp)) {
equations.addTypeEqualityEquation(op.operand(), op.result(), op);
return WalkResult::advance();
@ -318,6 +382,34 @@ public:
return WalkResult::advance();
}
// Fallback trait based equations.
// ----------------------
// Ensure that constant nodes get assigned a constant type.
if (childOp->hasTrait<OpTrait::ConstantLike>()) {
equations.getTypeNode(childOp->getResult(0));
return WalkResult::advance();
}
// Function returns must all have the same types.
if (childOp->hasTrait<OpTrait::ReturnLike>()) {
if (childOp->getParentOp() == funcOp) {
if (funcReturnOp) {
if (funcReturnOp->getNumOperands() != childOp->getNumOperands()) {
childOp->emitOpError() << "different arity of function returns";
return WalkResult::interrupt();
}
for (auto it : llvm::zip(funcReturnOp->getOperands(),
childOp->getOperands())) {
equations.addTypeEqualityEquation(std::get<0>(it),
std::get<1>(it), childOp);
}
}
funcReturnOp = childOp;
return WalkResult::advance();
} else {
innerReturnLikeOps.push_back(childOp);
}
}
childOp->emitWarning() << "unhandled op in type inference";
return WalkResult::advance();
@ -329,6 +421,7 @@ public:
private:
// The last encountered ReturnLike op.
Operation *funcReturnOp = nullptr;
llvm::SmallVector<Operation *, 4> innerReturnLikeOps;
TypeEquations &equations;
};
@ -359,7 +452,13 @@ public:
llvm::dbgs() << " " << it.first << " -> " << *it.second << "\n";
});
for (auto it : *substMap) {
equations.applySubst(it.first, it.second);
TypeNode *varNode = equations.lookupVarOrdinal(it.first);
Type resolvedType = unifier.resolveSubst(it.second, substMap);
if (!resolvedType) {
emitError(varNode->getDef().getLoc()) << "unable to infer type";
continue;
}
varNode->getDef().setType(resolvedType);
}
// Now rewrite the function type based on actual types of entry block

View File

@ -30,3 +30,10 @@ def arg_inference(a, b):
# CHECK: basicpy.unknown_cast{{.*}} : i64 -> i64
# CHECK: return{{.*}} : i64
return a + 2 * b
# CHECK-LABEL: func @conditional_inference
# CHECK-SAME: (%arg0: i64, %arg1: !basicpy.BoolType, %arg2: i64) -> !basicpy.BoolType
@import_global
def conditional_inference(cond, a, b):
# CHECK-NOT: UnknownType
return a if cond + 1 else not(b * 4)

View File

@ -1,12 +0,0 @@
from npcomp.compiler.frontend import *
def import_global(f):
fe = ImportFrontend()
fe.import_global_function(f)
print(fe.ir_module.to_asm())
return f
@import_global
def arithmetic_expression():
return 1 + 2 - 3 * 4