mirror of https://github.com/llvm/torch-mlir
More progress on CPA.
* Added transitivity propagation rules. * Fixed up some copy-n-paste inversions from the old algorithm.pull/1/head
parent
74b8bed7e3
commit
1a13c38033
|
@ -0,0 +1,54 @@
|
|||
//===- Algorithm.h - Main algorithm ---------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Support types and utilities for the Cartesian Product Algorithm for
|
||||
// Type Inference.
|
||||
//
|
||||
// See:
|
||||
// http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.30.8177
|
||||
// http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.129.2756
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_TYPING_CPA_ALGORITHM_H
|
||||
#define NPCOMP_TYPING_CPA_ALGORITHM_H
|
||||
|
||||
#include "npcomp/Typing/CPA/Support.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
namespace Typing {
|
||||
namespace CPA {
|
||||
|
||||
/// Propagates constraints in an environment.
|
||||
class PropagationWorklist {
|
||||
public:
|
||||
PropagationWorklist(Environment &env);
|
||||
|
||||
/// Propagates any current constraints that match the transitivity rule:
|
||||
/// τv <: t, t <: τ (τv=ValueType, t=TypeVar, τ=TypeBase)
|
||||
/// Expanding to:
|
||||
/// τv <: τ
|
||||
/// (τv=ValueType, t=TypeVar, τ=TypeBase)
|
||||
void propagateTransitivity();
|
||||
|
||||
/// Commits the current round, returning true if any new constraints were
|
||||
/// added.
|
||||
bool commit();
|
||||
|
||||
private:
|
||||
Environment &env;
|
||||
llvm::DenseSet<Constraint *> currentConstraints;
|
||||
int newConstraintCount = 0;
|
||||
};
|
||||
|
||||
} // namespace CPA
|
||||
} // namespace Typing
|
||||
} // namespace NPCOMP
|
||||
} // namespace mlir
|
||||
|
||||
#endif // NPCOMP_TYPING_CPA_ALGORITHM_H
|
|
@ -226,7 +226,7 @@ private:
|
|||
class ValueType : public TypeBase {
|
||||
public:
|
||||
using TypeBase::TypeBase;
|
||||
bool classof(ObjectBase *ob) {
|
||||
static bool classof(const ObjectBase *ob) {
|
||||
return ob->getKind() >= Kind::FIRST_VALUE_TYPE &&
|
||||
ob->getKind() <= Kind::LAST_VALUE_TYPE;
|
||||
}
|
||||
|
@ -238,7 +238,7 @@ public:
|
|||
IRValueType(mlir::Type irType)
|
||||
: ValueType(Kind::IRValueType, llvm::hash_combine(irType)),
|
||||
irType(irType) {}
|
||||
static bool classof(ObjectBase *ob) {
|
||||
static bool classof(const ObjectBase *ob) {
|
||||
return ob->getKind() == Kind::IRValueType;
|
||||
}
|
||||
|
||||
|
@ -254,7 +254,7 @@ private:
|
|||
/// Referred to as 'obj(δ, [ li : τi ])'
|
||||
class ObjectValueType : public ValueType {
|
||||
public:
|
||||
static bool classof(ObjectBase *ob) {
|
||||
static bool classof(const ObjectBase *ob) {
|
||||
return ob->getKind() == Kind::ObjectValueType;
|
||||
}
|
||||
|
||||
|
@ -291,8 +291,8 @@ public:
|
|||
return ob->getKind() == Kind::Constraint;
|
||||
}
|
||||
|
||||
TypeBase *getT1() { return t1; }
|
||||
TypeBase *getT2() { return t2; }
|
||||
TypeBase *getLhs() { return t1; }
|
||||
TypeBase *getRhs() { return t2; }
|
||||
|
||||
void setContextOp(Operation *contextOp) { this->contextOp = contextOp; }
|
||||
|
||||
|
@ -334,17 +334,18 @@ private:
|
|||
/// Referred to as: 'C'
|
||||
class ConstraintSet : public ObjectBase {
|
||||
public:
|
||||
using CollectionTy = llvm::simple_ilist<Constraint>;
|
||||
static bool classof(ObjectBase *ob) {
|
||||
return ob->getKind() == Kind::ConstraintSet;
|
||||
}
|
||||
|
||||
llvm::simple_ilist<Constraint> &getConstraints() { return constraints; }
|
||||
CollectionTy &getContents() { return constraints; }
|
||||
|
||||
void print(raw_ostream &os, bool brief = false) override;
|
||||
|
||||
private:
|
||||
ConstraintSet() : ObjectBase(Kind::ConstraintSet){};
|
||||
llvm::simple_ilist<Constraint> constraints;
|
||||
CollectionTy constraints;
|
||||
friend class Context;
|
||||
};
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
|
||||
#include "npcomp/Dialect/Basicpy/Transforms/Passes.h"
|
||||
#include "npcomp/Typing/CPA/Algorithm.h"
|
||||
#include "npcomp/Typing/CPA/Interfaces.h"
|
||||
#include "npcomp/Typing/CPA/Support.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
@ -50,7 +51,7 @@ public:
|
|||
auto subVt = resolveValueType(subValue);
|
||||
CPA::Constraint *c = env.getContext().getConstraint(superVt, subVt);
|
||||
c->setContextOp(contextOp);
|
||||
env.getConstraints()->getConstraints().push_back(*c);
|
||||
env.getConstraints()->getContents().push_back(*c);
|
||||
}
|
||||
|
||||
LogicalResult runOnFunction(FuncOp funcOp) {
|
||||
|
@ -78,6 +79,7 @@ public:
|
|||
// Note that the condition is always i1 and not subject to type
|
||||
// inference.
|
||||
addSubtypeConstraint(op.true_value(), op.false_value(), op);
|
||||
addSubtypeConstraint(op.false_value(), op.true_value(), op);
|
||||
return WalkResult::advance();
|
||||
}
|
||||
if (auto op = dyn_cast<ToBooleanOp>(childOp)) {
|
||||
|
@ -108,20 +110,24 @@ public:
|
|||
return WalkResult::advance();
|
||||
}
|
||||
if (auto op = dyn_cast<UnknownCastOp>(childOp)) {
|
||||
addSubtypeConstraint(op.operand(), op.result(), op);
|
||||
addSubtypeConstraint(op.result(), op.operand(), op);
|
||||
// addSubtypeConstraint(op.operand(), op.result(), op);
|
||||
return WalkResult::advance();
|
||||
}
|
||||
if (auto op = dyn_cast<BinaryExprOp>(childOp)) {
|
||||
// TODO: This should really be applying arithmetic promotion, not
|
||||
// strict equality.
|
||||
addSubtypeConstraint(op.left(), op.right(), op);
|
||||
addSubtypeConstraint(op.left(), op.result(), op);
|
||||
addSubtypeConstraint(op.result(), op.left(), op);
|
||||
addSubtypeConstraint(op.result(), op.right(), op);
|
||||
// addSubtypeConstraint(op.left(), op.right(), op);
|
||||
// addSubtypeConstraint(op.left(), op.result(), op);
|
||||
return WalkResult::advance();
|
||||
}
|
||||
if (auto op = dyn_cast<BinaryCompareOp>(childOp)) {
|
||||
// TODO: This should really be applying arithmetic promotion, not
|
||||
// strict equality.
|
||||
addSubtypeConstraint(op.left(), op.right(), op);
|
||||
addSubtypeConstraint(op.right(), op.left(), op);
|
||||
return WalkResult::advance();
|
||||
}
|
||||
|
||||
|
@ -142,7 +148,9 @@ public:
|
|||
}
|
||||
for (auto it : llvm::zip(funcReturnOp->getOperands(),
|
||||
childOp->getOperands())) {
|
||||
addSubtypeConstraint(std::get<0>(it), std::get<1>(it), childOp);
|
||||
// addSubtypeConstraint(std::get<0>(it), std::get<1>(it),
|
||||
// childOp);
|
||||
addSubtypeConstraint(std::get<1>(it), std::get<0>(it), childOp);
|
||||
}
|
||||
}
|
||||
funcReturnOp = childOp;
|
||||
|
@ -183,11 +191,18 @@ public:
|
|||
|
||||
llvm::errs() << "CONSTRAINTS:\n";
|
||||
llvm::errs() << "------------\n";
|
||||
env.getConstraints()->print(llvm::errs());
|
||||
env.getConstraints()->print(llvm::errs(), true);
|
||||
|
||||
llvm::errs() << "TYPEVARS:\n";
|
||||
llvm::errs() << "\nTYPEVARS:\n";
|
||||
llvm::errs() << "---------\n";
|
||||
env.getTypeVars()->print(llvm::errs());
|
||||
|
||||
CPA::PropagationWorklist prop(env);
|
||||
do {
|
||||
llvm::errs() << "\nPROPAGATE CLOSURE:\n";
|
||||
llvm::errs() << "------------------\n";
|
||||
prop.propagateTransitivity();
|
||||
} while (prop.commit());
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
//===- Algorith.cpp - Main algorithm --------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/Typing/CPA/Algorithm.h"
|
||||
|
||||
using namespace mlir::NPCOMP::Typing::CPA;
|
||||
|
||||
PropagationWorklist::PropagationWorklist(Environment &env) : env(env) {
|
||||
auto &contents = env.getConstraints()->getContents();
|
||||
currentConstraints.reserve(contents.size() * 2);
|
||||
for (auto &c : contents) {
|
||||
currentConstraints.insert(&c);
|
||||
}
|
||||
}
|
||||
|
||||
bool PropagationWorklist::commit() {
|
||||
bool hadNew = newConstraintCount > 0;
|
||||
newConstraintCount = 0;
|
||||
return hadNew;
|
||||
}
|
||||
|
||||
void PropagationWorklist::propagateTransitivity() {
|
||||
// Prepare for join.
|
||||
constexpr size_t N = 8;
|
||||
llvm::DenseMap<TypeVar *, llvm::SmallVector<ValueType *, N>> varToValueType;
|
||||
llvm::DenseMap<TypeVar *, llvm::SmallVector<TypeBase *, N>> varToAny;
|
||||
for (auto *c : currentConstraints) {
|
||||
auto *lhsVar = llvm::dyn_cast<TypeVar>(c->getLhs());
|
||||
auto *rhsVar = llvm::dyn_cast<TypeVar>(c->getRhs());
|
||||
|
||||
if (lhsVar) {
|
||||
varToAny[lhsVar].push_back(c->getRhs());
|
||||
}
|
||||
if (rhsVar) {
|
||||
if (auto *vt = llvm::dyn_cast<ValueType>(c->getLhs())) {
|
||||
varToValueType[rhsVar].push_back(vt);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Expand join.
|
||||
for (auto vtIt : varToValueType) {
|
||||
auto &lhsSet = vtIt.second;
|
||||
auto anyIt = varToAny.find(vtIt.first);
|
||||
if (anyIt == varToAny.end())
|
||||
continue;
|
||||
auto &rhsSet = anyIt->second;
|
||||
|
||||
for (ValueType *lhsItem : lhsSet) {
|
||||
for (TypeBase *rhsItem : rhsSet) {
|
||||
Constraint *newC = env.getContext().getConstraint(lhsItem, rhsItem);
|
||||
if (currentConstraints.insert(newC).second) {
|
||||
llvm::errs() << "-->ADD TRANS CONSTRAINT: ";
|
||||
newC->print(llvm::errs());
|
||||
llvm::errs() << "\n";
|
||||
newConstraintCount += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,4 +1,5 @@
|
|||
add_library(NPCOMPTypingCPA
|
||||
Algorithm.cpp
|
||||
Interfaces.cpp
|
||||
Support.cpp
|
||||
)
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
# RUN: %PYTHON %s | npcomp-opt -split-input-file -basicpy-cpa-type-inference | FileCheck %s --dump-input=fail
|
||||
|
||||
from npcomp.compiler import test_config
|
||||
|
||||
import_global = test_config.create_import_dump_decorator()
|
||||
|
||||
|
||||
# CHECK-LABEL: func @arithmetic_expression
|
||||
@import_global
|
||||
def arithmetic_expression():
|
||||
return 1 + 2 - 3 * 4
|
||||
|
||||
|
||||
# CHECK-LABEL: func @arg_inference
|
||||
@import_global
|
||||
def arg_inference(a, b):
|
||||
return a + 2 * b
|
||||
|
||||
|
||||
# CHECK-LABEL: func @conditional_inference
|
||||
@import_global
|
||||
def conditional_inference(cond, a, b):
|
||||
return a if cond + 1 else not (b * 4)
|
Loading…
Reference in New Issue