diff --git a/lib/Dialect/Basicpy/Transforms/TypeInference.cpp b/lib/Dialect/Basicpy/Transforms/TypeInference.cpp index 2fcbc7da2..859cff54f 100644 --- a/lib/Dialect/Basicpy/Transforms/TypeInference.cpp +++ b/lib/Dialect/Basicpy/Transforms/TypeInference.cpp @@ -8,6 +8,8 @@ #include "PassDetail.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h" #include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h" @@ -149,6 +151,10 @@ public: /// Print a report of the equations for debugging. void report(raw_ostream &os) { + os << "Type variable map:\n"; + for (auto it : llvm::enumerate(ordinalToVarNode)) { + os << ": " << it.index() << " = " << it.value()->getDef() << "\n"; + } os << "Type equations:\n"; os << "---------------\n"; for (auto &eq : equations) { @@ -158,11 +164,9 @@ public: simple_ilist &getEquations() { return equations; } - void applySubst(unsigned ordinal, TypeNode *resolved) { - assert(ordinal >= 0 && ordinal < ordinalToVarNode.size()); - Type constType = resolved->getConstType(); - TypeNode *varNode = ordinalToVarNode[ordinal]; - varNode->getDef().setType(constType); + TypeNode *lookupVarOrdinal(unsigned ordinal) { + assert(ordinal < ordinalToVarNode.size()); + return ordinalToVarNode[ordinal]; } private: @@ -211,6 +215,23 @@ public: return subst; } + Type resolveSubst(TypeNode *typeNode, const Optional &subst) { + if (!subst) + return nullptr; + if (typeNode->getDiscrim() == TypeNode::Discrim::CONST_TYPE) { + return typeNode->getConstType(); + } + if (typeNode->getDiscrim() == TypeNode::Discrim::VAR_ORDINAL) { + auto foundIt = subst->find(typeNode->getVarOrdinal()); + if (foundIt != subst->end()) { + return resolveSubst(foundIt->second, subst); + } else { + return nullptr; + } + } + return nullptr; + } + Optional unify(TypeNode *typeX, TypeNode *typeY, Optional subst) { LLVM_DEBUG(llvm::dbgs() << "+ UNIFY: " << *typeX << ", " << *typeY << "\n"); @@ -233,10 +254,13 @@ public: Optional unifyVariable(TypeNode *varNode, TypeNode *typeNode, SubstMap subst) { assert(varNode->getDiscrim() == TypeNode::Discrim::VAR_ORDINAL); + LLVM_DEBUG(llvm::dbgs() << " - UNIFY VARIABLE: " << *varNode << " <- " + << *typeNode << "\n"); // Var node in subst? auto it = subst.find(varNode->getVarOrdinal()); if (it != subst.end()) { TypeNode *found = it->second; + LLVM_DEBUG(llvm::dbgs() << " --> FOUND VAR: " << *found << "\n"); return unify(found, typeNode, std::move(subst)); } @@ -245,14 +269,36 @@ public: it = subst.find(typeNode->getVarOrdinal()); if (it != subst.end()) { TypeNode *found = it->second; + LLVM_DEBUG(llvm::dbgs() << " --> FOUND TYPE: " << *found << "\n"); return unify(varNode, found, std::move(subst)); } } + // Does the variable appear in the type? + if (occursCheck(varNode, typeNode, subst)) { + LLVM_DEBUG(llvm::dbgs() << "FAILED OCCURS_CHECK\n"); + return None; + } + // varNode is not yet in subst and cannot simplify typeNode. Extend. subst[varNode->getVarOrdinal()] = typeNode; return std::move(subst); } + + bool occursCheck(TypeNode *varNode, TypeNode *typeNode, SubstMap &subst) { + if (*varNode == *typeNode) + return true; + + if (typeNode->getDiscrim() == TypeNode::Discrim::VAR_ORDINAL) { + unsigned typeOrdinal = typeNode->getVarOrdinal(); + auto foundIt = subst.find(typeOrdinal); + if (foundIt != subst.end()) { + return occursCheck(varNode, foundIt->second, subst); + } + } + + return false; + } }; class TypeEquationPopulator { @@ -262,50 +308,68 @@ public: /// If a return op was visited, this will be one of them. Operation *getLastReturnOp() { return funcReturnOp; } + /// Gets any ReturnLike ops that do not return from the outer function. + /// This is used to fixup parent SCF ops and the like. + llvm::SmallVectorImpl &getInnerReturnLikeOps() { + return innerReturnLikeOps; + } + LogicalResult runOnFunction(FuncOp funcOp) { // Iterate and create type nodes for entry block arguments, as these // must be resolved no matter what. if (funcOp.getBody().empty()) return success(); + auto &entryBlock = funcOp.getBody().front(); for (auto blockArg : entryBlock.getArguments()) { equations.getTypeNode(blockArg); } // Then walk ops, creating equations. + LLVM_DEBUG(llvm::dbgs() << "POPULATE CHILD OPS:\n"); auto result = funcOp.walk([&](Operation *childOp) -> WalkResult { if (childOp == funcOp) return WalkResult::advance(); - - // Trait based equations. - // ---------------------- - // Function returns must all have the same types. - if (childOp->hasTrait() && - childOp->getParentOp() == funcOp) { - if (funcReturnOp) { - if (funcReturnOp->getNumOperands() != childOp->getNumOperands()) { - childOp->emitOpError() << "different arity of function returns"; - return WalkResult::interrupt(); - } - for (auto it : - llvm::zip(funcReturnOp->getOperands(), childOp->getOperands())) { - equations.addTypeEqualityEquation(std::get<0>(it), std::get<1>(it), - childOp); - } - } - funcReturnOp = childOp; - return WalkResult::advance(); - } - - // Ensure that constant nodes get assigned a constant type. - if (childOp->hasTrait()) { - equations.getTypeNode(childOp->getResult(0)); - return WalkResult::advance(); - } - + LLVM_DEBUG(llvm::dbgs() << " + POPULATE: " << *childOp << "\n"); // Special op handling. - // Many of these (that are not standard ops) should become op interfaces. + // Many of these (that are not standard ops) should become op + // interfaces. // -------------------- + if (auto op = dyn_cast(childOp)) { + // Note that the condition is always i1 and not subject to type + // inference. + equations.addTypeEqualityEquation(op.true_value(), op.false_value(), + op); + return WalkResult::advance(); + } + if (auto op = dyn_cast(childOp)) { + // Note that the result is always i1 and not subject to type + // inference. + equations.getTypeNode(op.operand()); + return WalkResult::advance(); + } + if (auto op = dyn_cast(childOp)) { + // Note that the condition is always i1 and not subject to type + // inference. + for (auto result : op.getResults()) { + equations.getTypeNode(result); + } + return WalkResult::advance(); + } + if (auto yieldOp = dyn_cast(childOp)) { + auto scfParentOp = yieldOp.getParentOp(); + if (scfParentOp->getNumResults() != yieldOp.getNumOperands()) { + yieldOp.emitWarning() + << "cannot run type inference on yield due to arity mismatch"; + return WalkResult::advance(); + } + for (auto it : + llvm::zip(scfParentOp->getResults(), yieldOp.getOperands())) { + equations.addTypeEqualityEquation(std::get<1>(it), std::get<0>(it), + yieldOp); + } + return WalkResult::advance(); + } if (auto op = dyn_cast(childOp)) { equations.addTypeEqualityEquation(op.operand(), op.result(), op); return WalkResult::advance(); @@ -318,6 +382,34 @@ public: return WalkResult::advance(); } + // Fallback trait based equations. + // ---------------------- + // Ensure that constant nodes get assigned a constant type. + if (childOp->hasTrait()) { + equations.getTypeNode(childOp->getResult(0)); + return WalkResult::advance(); + } + // Function returns must all have the same types. + if (childOp->hasTrait()) { + if (childOp->getParentOp() == funcOp) { + if (funcReturnOp) { + if (funcReturnOp->getNumOperands() != childOp->getNumOperands()) { + childOp->emitOpError() << "different arity of function returns"; + return WalkResult::interrupt(); + } + for (auto it : llvm::zip(funcReturnOp->getOperands(), + childOp->getOperands())) { + equations.addTypeEqualityEquation(std::get<0>(it), + std::get<1>(it), childOp); + } + } + funcReturnOp = childOp; + return WalkResult::advance(); + } else { + innerReturnLikeOps.push_back(childOp); + } + } + childOp->emitWarning() << "unhandled op in type inference"; return WalkResult::advance(); @@ -329,6 +421,7 @@ public: private: // The last encountered ReturnLike op. Operation *funcReturnOp = nullptr; + llvm::SmallVector innerReturnLikeOps; TypeEquations &equations; }; @@ -359,7 +452,13 @@ public: llvm::dbgs() << " " << it.first << " -> " << *it.second << "\n"; }); for (auto it : *substMap) { - equations.applySubst(it.first, it.second); + TypeNode *varNode = equations.lookupVarOrdinal(it.first); + Type resolvedType = unifier.resolveSubst(it.second, substMap); + if (!resolvedType) { + emitError(varNode->getDef().getLoc()) << "unable to infer type"; + continue; + } + varNode->getDef().setType(resolvedType); } // Now rewrite the function type based on actual types of entry block diff --git a/pytest/Compiler/type_inference.py b/pytest/Compiler/type_inference.py index 2c1a3239a..b75fdda51 100644 --- a/pytest/Compiler/type_inference.py +++ b/pytest/Compiler/type_inference.py @@ -30,3 +30,10 @@ def arg_inference(a, b): # CHECK: basicpy.unknown_cast{{.*}} : i64 -> i64 # CHECK: return{{.*}} : i64 return a + 2 * b + +# CHECK-LABEL: func @conditional_inference +# CHECK-SAME: (%arg0: i64, %arg1: !basicpy.BoolType, %arg2: i64) -> !basicpy.BoolType +@import_global +def conditional_inference(cond, a, b): + # CHECK-NOT: UnknownType + return a if cond + 1 else not(b * 4) diff --git a/python/samples/type_inference.py b/python/samples/type_inference.py deleted file mode 100644 index 7f2606d76..000000000 --- a/python/samples/type_inference.py +++ /dev/null @@ -1,12 +0,0 @@ -from npcomp.compiler.frontend import * - - -def import_global(f): - fe = ImportFrontend() - fe.import_global_function(f) - print(fe.ir_module.to_asm()) - return f - -@import_global -def arithmetic_expression(): - return 1 + 2 - 3 * 4