mirror of https://github.com/llvm/torch-mlir
Extend type inference so that it works across conditional boundaries.
* The implementation is still limited but gives something to build on.pull/1/head
parent
c84ce17573
commit
750541e9a9
|
@ -8,6 +8,8 @@
|
|||
|
||||
#include "PassDetail.h"
|
||||
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
|
||||
|
@ -149,6 +151,10 @@ public:
|
|||
|
||||
/// Print a report of the equations for debugging.
|
||||
void report(raw_ostream &os) {
|
||||
os << "Type variable map:\n";
|
||||
for (auto it : llvm::enumerate(ordinalToVarNode)) {
|
||||
os << ": " << it.index() << " = " << it.value()->getDef() << "\n";
|
||||
}
|
||||
os << "Type equations:\n";
|
||||
os << "---------------\n";
|
||||
for (auto &eq : equations) {
|
||||
|
@ -158,11 +164,9 @@ public:
|
|||
|
||||
simple_ilist<TypeEquation> &getEquations() { return equations; }
|
||||
|
||||
void applySubst(unsigned ordinal, TypeNode *resolved) {
|
||||
assert(ordinal >= 0 && ordinal < ordinalToVarNode.size());
|
||||
Type constType = resolved->getConstType();
|
||||
TypeNode *varNode = ordinalToVarNode[ordinal];
|
||||
varNode->getDef().setType(constType);
|
||||
TypeNode *lookupVarOrdinal(unsigned ordinal) {
|
||||
assert(ordinal < ordinalToVarNode.size());
|
||||
return ordinalToVarNode[ordinal];
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -211,6 +215,23 @@ public:
|
|||
return subst;
|
||||
}
|
||||
|
||||
Type resolveSubst(TypeNode *typeNode, const Optional<SubstMap> &subst) {
|
||||
if (!subst)
|
||||
return nullptr;
|
||||
if (typeNode->getDiscrim() == TypeNode::Discrim::CONST_TYPE) {
|
||||
return typeNode->getConstType();
|
||||
}
|
||||
if (typeNode->getDiscrim() == TypeNode::Discrim::VAR_ORDINAL) {
|
||||
auto foundIt = subst->find(typeNode->getVarOrdinal());
|
||||
if (foundIt != subst->end()) {
|
||||
return resolveSubst(foundIt->second, subst);
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Optional<SubstMap> unify(TypeNode *typeX, TypeNode *typeY,
|
||||
Optional<SubstMap> subst) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "+ UNIFY: " << *typeX << ", " << *typeY << "\n");
|
||||
|
@ -233,10 +254,13 @@ public:
|
|||
Optional<SubstMap> unifyVariable(TypeNode *varNode, TypeNode *typeNode,
|
||||
SubstMap subst) {
|
||||
assert(varNode->getDiscrim() == TypeNode::Discrim::VAR_ORDINAL);
|
||||
LLVM_DEBUG(llvm::dbgs() << " - UNIFY VARIABLE: " << *varNode << " <- "
|
||||
<< *typeNode << "\n");
|
||||
// Var node in subst?
|
||||
auto it = subst.find(varNode->getVarOrdinal());
|
||||
if (it != subst.end()) {
|
||||
TypeNode *found = it->second;
|
||||
LLVM_DEBUG(llvm::dbgs() << " --> FOUND VAR: " << *found << "\n");
|
||||
return unify(found, typeNode, std::move(subst));
|
||||
}
|
||||
|
||||
|
@ -245,14 +269,36 @@ public:
|
|||
it = subst.find(typeNode->getVarOrdinal());
|
||||
if (it != subst.end()) {
|
||||
TypeNode *found = it->second;
|
||||
LLVM_DEBUG(llvm::dbgs() << " --> FOUND TYPE: " << *found << "\n");
|
||||
return unify(varNode, found, std::move(subst));
|
||||
}
|
||||
}
|
||||
|
||||
// Does the variable appear in the type?
|
||||
if (occursCheck(varNode, typeNode, subst)) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "FAILED OCCURS_CHECK\n");
|
||||
return None;
|
||||
}
|
||||
|
||||
// varNode is not yet in subst and cannot simplify typeNode. Extend.
|
||||
subst[varNode->getVarOrdinal()] = typeNode;
|
||||
return std::move(subst);
|
||||
}
|
||||
|
||||
bool occursCheck(TypeNode *varNode, TypeNode *typeNode, SubstMap &subst) {
|
||||
if (*varNode == *typeNode)
|
||||
return true;
|
||||
|
||||
if (typeNode->getDiscrim() == TypeNode::Discrim::VAR_ORDINAL) {
|
||||
unsigned typeOrdinal = typeNode->getVarOrdinal();
|
||||
auto foundIt = subst.find(typeOrdinal);
|
||||
if (foundIt != subst.end()) {
|
||||
return occursCheck(varNode, foundIt->second, subst);
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
class TypeEquationPopulator {
|
||||
|
@ -262,50 +308,68 @@ public:
|
|||
/// If a return op was visited, this will be one of them.
|
||||
Operation *getLastReturnOp() { return funcReturnOp; }
|
||||
|
||||
/// Gets any ReturnLike ops that do not return from the outer function.
|
||||
/// This is used to fixup parent SCF ops and the like.
|
||||
llvm::SmallVectorImpl<Operation *> &getInnerReturnLikeOps() {
|
||||
return innerReturnLikeOps;
|
||||
}
|
||||
|
||||
LogicalResult runOnFunction(FuncOp funcOp) {
|
||||
// Iterate and create type nodes for entry block arguments, as these
|
||||
// must be resolved no matter what.
|
||||
if (funcOp.getBody().empty())
|
||||
return success();
|
||||
|
||||
auto &entryBlock = funcOp.getBody().front();
|
||||
for (auto blockArg : entryBlock.getArguments()) {
|
||||
equations.getTypeNode(blockArg);
|
||||
}
|
||||
|
||||
// Then walk ops, creating equations.
|
||||
LLVM_DEBUG(llvm::dbgs() << "POPULATE CHILD OPS:\n");
|
||||
auto result = funcOp.walk([&](Operation *childOp) -> WalkResult {
|
||||
if (childOp == funcOp)
|
||||
return WalkResult::advance();
|
||||
|
||||
// Trait based equations.
|
||||
// ----------------------
|
||||
// Function returns must all have the same types.
|
||||
if (childOp->hasTrait<OpTrait::ReturnLike>() &&
|
||||
childOp->getParentOp() == funcOp) {
|
||||
if (funcReturnOp) {
|
||||
if (funcReturnOp->getNumOperands() != childOp->getNumOperands()) {
|
||||
childOp->emitOpError() << "different arity of function returns";
|
||||
return WalkResult::interrupt();
|
||||
LLVM_DEBUG(llvm::dbgs() << " + POPULATE: " << *childOp << "\n");
|
||||
// Special op handling.
|
||||
// Many of these (that are not standard ops) should become op
|
||||
// interfaces.
|
||||
// --------------------
|
||||
if (auto op = dyn_cast<SelectOp>(childOp)) {
|
||||
// Note that the condition is always i1 and not subject to type
|
||||
// inference.
|
||||
equations.addTypeEqualityEquation(op.true_value(), op.false_value(),
|
||||
op);
|
||||
return WalkResult::advance();
|
||||
}
|
||||
if (auto op = dyn_cast<ToBooleanOp>(childOp)) {
|
||||
// Note that the result is always i1 and not subject to type
|
||||
// inference.
|
||||
equations.getTypeNode(op.operand());
|
||||
return WalkResult::advance();
|
||||
}
|
||||
if (auto op = dyn_cast<scf::IfOp>(childOp)) {
|
||||
// Note that the condition is always i1 and not subject to type
|
||||
// inference.
|
||||
for (auto result : op.getResults()) {
|
||||
equations.getTypeNode(result);
|
||||
}
|
||||
return WalkResult::advance();
|
||||
}
|
||||
if (auto yieldOp = dyn_cast<scf::YieldOp>(childOp)) {
|
||||
auto scfParentOp = yieldOp.getParentOp();
|
||||
if (scfParentOp->getNumResults() != yieldOp.getNumOperands()) {
|
||||
yieldOp.emitWarning()
|
||||
<< "cannot run type inference on yield due to arity mismatch";
|
||||
return WalkResult::advance();
|
||||
}
|
||||
for (auto it :
|
||||
llvm::zip(funcReturnOp->getOperands(), childOp->getOperands())) {
|
||||
equations.addTypeEqualityEquation(std::get<0>(it), std::get<1>(it),
|
||||
childOp);
|
||||
llvm::zip(scfParentOp->getResults(), yieldOp.getOperands())) {
|
||||
equations.addTypeEqualityEquation(std::get<1>(it), std::get<0>(it),
|
||||
yieldOp);
|
||||
}
|
||||
}
|
||||
funcReturnOp = childOp;
|
||||
return WalkResult::advance();
|
||||
}
|
||||
|
||||
// Ensure that constant nodes get assigned a constant type.
|
||||
if (childOp->hasTrait<OpTrait::ConstantLike>()) {
|
||||
equations.getTypeNode(childOp->getResult(0));
|
||||
return WalkResult::advance();
|
||||
}
|
||||
|
||||
// Special op handling.
|
||||
// Many of these (that are not standard ops) should become op interfaces.
|
||||
// --------------------
|
||||
if (auto op = dyn_cast<UnknownCastOp>(childOp)) {
|
||||
equations.addTypeEqualityEquation(op.operand(), op.result(), op);
|
||||
return WalkResult::advance();
|
||||
|
@ -318,6 +382,34 @@ public:
|
|||
return WalkResult::advance();
|
||||
}
|
||||
|
||||
// Fallback trait based equations.
|
||||
// ----------------------
|
||||
// Ensure that constant nodes get assigned a constant type.
|
||||
if (childOp->hasTrait<OpTrait::ConstantLike>()) {
|
||||
equations.getTypeNode(childOp->getResult(0));
|
||||
return WalkResult::advance();
|
||||
}
|
||||
// Function returns must all have the same types.
|
||||
if (childOp->hasTrait<OpTrait::ReturnLike>()) {
|
||||
if (childOp->getParentOp() == funcOp) {
|
||||
if (funcReturnOp) {
|
||||
if (funcReturnOp->getNumOperands() != childOp->getNumOperands()) {
|
||||
childOp->emitOpError() << "different arity of function returns";
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
for (auto it : llvm::zip(funcReturnOp->getOperands(),
|
||||
childOp->getOperands())) {
|
||||
equations.addTypeEqualityEquation(std::get<0>(it),
|
||||
std::get<1>(it), childOp);
|
||||
}
|
||||
}
|
||||
funcReturnOp = childOp;
|
||||
return WalkResult::advance();
|
||||
} else {
|
||||
innerReturnLikeOps.push_back(childOp);
|
||||
}
|
||||
}
|
||||
|
||||
childOp->emitWarning() << "unhandled op in type inference";
|
||||
|
||||
return WalkResult::advance();
|
||||
|
@ -329,6 +421,7 @@ public:
|
|||
private:
|
||||
// The last encountered ReturnLike op.
|
||||
Operation *funcReturnOp = nullptr;
|
||||
llvm::SmallVector<Operation *, 4> innerReturnLikeOps;
|
||||
TypeEquations &equations;
|
||||
};
|
||||
|
||||
|
@ -359,7 +452,13 @@ public:
|
|||
llvm::dbgs() << " " << it.first << " -> " << *it.second << "\n";
|
||||
});
|
||||
for (auto it : *substMap) {
|
||||
equations.applySubst(it.first, it.second);
|
||||
TypeNode *varNode = equations.lookupVarOrdinal(it.first);
|
||||
Type resolvedType = unifier.resolveSubst(it.second, substMap);
|
||||
if (!resolvedType) {
|
||||
emitError(varNode->getDef().getLoc()) << "unable to infer type";
|
||||
continue;
|
||||
}
|
||||
varNode->getDef().setType(resolvedType);
|
||||
}
|
||||
|
||||
// Now rewrite the function type based on actual types of entry block
|
||||
|
|
|
@ -30,3 +30,10 @@ def arg_inference(a, b):
|
|||
# CHECK: basicpy.unknown_cast{{.*}} : i64 -> i64
|
||||
# CHECK: return{{.*}} : i64
|
||||
return a + 2 * b
|
||||
|
||||
# CHECK-LABEL: func @conditional_inference
|
||||
# CHECK-SAME: (%arg0: i64, %arg1: !basicpy.BoolType, %arg2: i64) -> !basicpy.BoolType
|
||||
@import_global
|
||||
def conditional_inference(cond, a, b):
|
||||
# CHECK-NOT: UnknownType
|
||||
return a if cond + 1 else not(b * 4)
|
||||
|
|
|
@ -1,12 +0,0 @@
|
|||
from npcomp.compiler.frontend import *
|
||||
|
||||
|
||||
def import_global(f):
|
||||
fe = ImportFrontend()
|
||||
fe.import_global_function(f)
|
||||
print(fe.ir_module.to_asm())
|
||||
return f
|
||||
|
||||
@import_global
|
||||
def arithmetic_expression():
|
||||
return 1 + 2 - 3 * 4
|
Loading…
Reference in New Issue