More progress on CPA.

* Added transitivity propagation rules.
* Fixed up some copy-n-paste inversions from the old algorithm.
pull/1/head
Stella Laurenzo 2020-07-02 18:56:05 -07:00
parent 74b8bed7e3
commit 1a13c38033
6 changed files with 174 additions and 14 deletions

View File

@ -0,0 +1,54 @@
//===- Algorithm.h - Main algorithm ---------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Support types and utilities for the Cartesian Product Algorithm for
// Type Inference.
//
// See:
// http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.30.8177
// http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.129.2756
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_TYPING_CPA_ALGORITHM_H
#define NPCOMP_TYPING_CPA_ALGORITHM_H
#include "npcomp/Typing/CPA/Support.h"
#include "llvm/ADT/DenseSet.h"
namespace mlir {
namespace NPCOMP {
namespace Typing {
namespace CPA {
/// Propagates constraints in an environment.
class PropagationWorklist {
public:
PropagationWorklist(Environment &env);
/// Propagates any current constraints that match the transitivity rule:
/// τv <: t, t <: τ (τv=ValueType, t=TypeVar, τ=TypeBase)
/// Expanding to:
/// τv <: τ
/// (τv=ValueType, t=TypeVar, τ=TypeBase)
void propagateTransitivity();
/// Commits the current round, returning true if any new constraints were
/// added.
bool commit();
private:
Environment &env;
llvm::DenseSet<Constraint *> currentConstraints;
int newConstraintCount = 0;
};
} // namespace CPA
} // namespace Typing
} // namespace NPCOMP
} // namespace mlir
#endif // NPCOMP_TYPING_CPA_ALGORITHM_H

View File

@ -226,7 +226,7 @@ private:
class ValueType : public TypeBase {
public:
using TypeBase::TypeBase;
bool classof(ObjectBase *ob) {
static bool classof(const ObjectBase *ob) {
return ob->getKind() >= Kind::FIRST_VALUE_TYPE &&
ob->getKind() <= Kind::LAST_VALUE_TYPE;
}
@ -238,7 +238,7 @@ public:
IRValueType(mlir::Type irType)
: ValueType(Kind::IRValueType, llvm::hash_combine(irType)),
irType(irType) {}
static bool classof(ObjectBase *ob) {
static bool classof(const ObjectBase *ob) {
return ob->getKind() == Kind::IRValueType;
}
@ -254,7 +254,7 @@ private:
/// Referred to as 'obj(δ, [ li : τi ])'
class ObjectValueType : public ValueType {
public:
static bool classof(ObjectBase *ob) {
static bool classof(const ObjectBase *ob) {
return ob->getKind() == Kind::ObjectValueType;
}
@ -291,8 +291,8 @@ public:
return ob->getKind() == Kind::Constraint;
}
TypeBase *getT1() { return t1; }
TypeBase *getT2() { return t2; }
TypeBase *getLhs() { return t1; }
TypeBase *getRhs() { return t2; }
void setContextOp(Operation *contextOp) { this->contextOp = contextOp; }
@ -334,17 +334,18 @@ private:
/// Referred to as: 'C'
class ConstraintSet : public ObjectBase {
public:
using CollectionTy = llvm::simple_ilist<Constraint>;
static bool classof(ObjectBase *ob) {
return ob->getKind() == Kind::ConstraintSet;
}
llvm::simple_ilist<Constraint> &getConstraints() { return constraints; }
CollectionTy &getContents() { return constraints; }
void print(raw_ostream &os, bool brief = false) override;
private:
ConstraintSet() : ObjectBase(Kind::ConstraintSet){};
llvm::simple_ilist<Constraint> constraints;
CollectionTy constraints;
friend class Context;
};

View File

@ -14,6 +14,7 @@
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
#include "npcomp/Dialect/Basicpy/Transforms/Passes.h"
#include "npcomp/Typing/CPA/Algorithm.h"
#include "npcomp/Typing/CPA/Interfaces.h"
#include "npcomp/Typing/CPA/Support.h"
#include "llvm/Support/Debug.h"
@ -50,7 +51,7 @@ public:
auto subVt = resolveValueType(subValue);
CPA::Constraint *c = env.getContext().getConstraint(superVt, subVt);
c->setContextOp(contextOp);
env.getConstraints()->getConstraints().push_back(*c);
env.getConstraints()->getContents().push_back(*c);
}
LogicalResult runOnFunction(FuncOp funcOp) {
@ -78,6 +79,7 @@ public:
// Note that the condition is always i1 and not subject to type
// inference.
addSubtypeConstraint(op.true_value(), op.false_value(), op);
addSubtypeConstraint(op.false_value(), op.true_value(), op);
return WalkResult::advance();
}
if (auto op = dyn_cast<ToBooleanOp>(childOp)) {
@ -108,20 +110,24 @@ public:
return WalkResult::advance();
}
if (auto op = dyn_cast<UnknownCastOp>(childOp)) {
addSubtypeConstraint(op.operand(), op.result(), op);
addSubtypeConstraint(op.result(), op.operand(), op);
// addSubtypeConstraint(op.operand(), op.result(), op);
return WalkResult::advance();
}
if (auto op = dyn_cast<BinaryExprOp>(childOp)) {
// TODO: This should really be applying arithmetic promotion, not
// strict equality.
addSubtypeConstraint(op.left(), op.right(), op);
addSubtypeConstraint(op.left(), op.result(), op);
addSubtypeConstraint(op.result(), op.left(), op);
addSubtypeConstraint(op.result(), op.right(), op);
// addSubtypeConstraint(op.left(), op.right(), op);
// addSubtypeConstraint(op.left(), op.result(), op);
return WalkResult::advance();
}
if (auto op = dyn_cast<BinaryCompareOp>(childOp)) {
// TODO: This should really be applying arithmetic promotion, not
// strict equality.
addSubtypeConstraint(op.left(), op.right(), op);
addSubtypeConstraint(op.right(), op.left(), op);
return WalkResult::advance();
}
@ -142,7 +148,9 @@ public:
}
for (auto it : llvm::zip(funcReturnOp->getOperands(),
childOp->getOperands())) {
addSubtypeConstraint(std::get<0>(it), std::get<1>(it), childOp);
// addSubtypeConstraint(std::get<0>(it), std::get<1>(it),
// childOp);
addSubtypeConstraint(std::get<1>(it), std::get<0>(it), childOp);
}
}
funcReturnOp = childOp;
@ -183,11 +191,18 @@ public:
llvm::errs() << "CONSTRAINTS:\n";
llvm::errs() << "------------\n";
env.getConstraints()->print(llvm::errs());
env.getConstraints()->print(llvm::errs(), true);
llvm::errs() << "TYPEVARS:\n";
llvm::errs() << "\nTYPEVARS:\n";
llvm::errs() << "---------\n";
env.getTypeVars()->print(llvm::errs());
CPA::PropagationWorklist prop(env);
do {
llvm::errs() << "\nPROPAGATE CLOSURE:\n";
llvm::errs() << "------------------\n";
prop.propagateTransitivity();
} while (prop.commit());
}
};

View File

@ -0,0 +1,66 @@
//===- Algorith.cpp - Main algorithm --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp/Typing/CPA/Algorithm.h"
using namespace mlir::NPCOMP::Typing::CPA;
PropagationWorklist::PropagationWorklist(Environment &env) : env(env) {
auto &contents = env.getConstraints()->getContents();
currentConstraints.reserve(contents.size() * 2);
for (auto &c : contents) {
currentConstraints.insert(&c);
}
}
bool PropagationWorklist::commit() {
bool hadNew = newConstraintCount > 0;
newConstraintCount = 0;
return hadNew;
}
void PropagationWorklist::propagateTransitivity() {
// Prepare for join.
constexpr size_t N = 8;
llvm::DenseMap<TypeVar *, llvm::SmallVector<ValueType *, N>> varToValueType;
llvm::DenseMap<TypeVar *, llvm::SmallVector<TypeBase *, N>> varToAny;
for (auto *c : currentConstraints) {
auto *lhsVar = llvm::dyn_cast<TypeVar>(c->getLhs());
auto *rhsVar = llvm::dyn_cast<TypeVar>(c->getRhs());
if (lhsVar) {
varToAny[lhsVar].push_back(c->getRhs());
}
if (rhsVar) {
if (auto *vt = llvm::dyn_cast<ValueType>(c->getLhs())) {
varToValueType[rhsVar].push_back(vt);
}
}
}
// Expand join.
for (auto vtIt : varToValueType) {
auto &lhsSet = vtIt.second;
auto anyIt = varToAny.find(vtIt.first);
if (anyIt == varToAny.end())
continue;
auto &rhsSet = anyIt->second;
for (ValueType *lhsItem : lhsSet) {
for (TypeBase *rhsItem : rhsSet) {
Constraint *newC = env.getContext().getConstraint(lhsItem, rhsItem);
if (currentConstraints.insert(newC).second) {
llvm::errs() << "-->ADD TRANS CONSTRAINT: ";
newC->print(llvm::errs());
llvm::errs() << "\n";
newConstraintCount += 1;
}
}
}
}
}

View File

@ -1,4 +1,5 @@
add_library(NPCOMPTypingCPA
Algorithm.cpp
Interfaces.cpp
Support.cpp
)

View File

@ -0,0 +1,23 @@
# RUN: %PYTHON %s | npcomp-opt -split-input-file -basicpy-cpa-type-inference | FileCheck %s --dump-input=fail
from npcomp.compiler import test_config
import_global = test_config.create_import_dump_decorator()
# CHECK-LABEL: func @arithmetic_expression
@import_global
def arithmetic_expression():
return 1 + 2 - 3 * 4
# CHECK-LABEL: func @arg_inference
@import_global
def arg_inference(a, b):
return a + 2 * b
# CHECK-LABEL: func @conditional_inference
@import_global
def conditional_inference(cond, a, b):
return a if cond + 1 else not (b * 4)