Add simple type refinement pass.

Currently implemented as a simple intraprocedural dataflow analysis over
a standard ShapedType lattice (hasRank, sizes, and elementType).

It currently hardcodes a few key pieces of information:
- shape transfer functions
- whether it is legal to update the operand type of an op

This obviously needs to be made pluggable, and the core propagation logic
moved somewhere agnostic.
pull/203/head
Sean Silva 2021-04-05 17:43:23 -07:00
parent 6431b0f11f
commit 1e357ae680
5 changed files with 361 additions and 0 deletions
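
For a rough illustration of the intended effect (adapted from the second test case added in this change), the pass refines the result types of ops it knows transfer functions for, such as aten.tanh, and inserts numpy.tensor_static_info_cast ops so that downstream users still observe the original types:

// Before running -torch-refine-types:
func @f(%arg0: tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype> {
  %0 = "aten.tanh"(%arg0) : (tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype>
  return %0 : tensor<*x!numpy.any_dtype>
}

// After: the aten.tanh result is refined to tensor<2x3x?xf32>, and a cast
// re-erases the static info so the function signature is unchanged.
func @f(%arg0: tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype> {
  %0 = "aten.tanh"(%arg0) : (tensor<2x3x?xf32>) -> tensor<2x3x?xf32>
  %1 = numpy.tensor_static_info_cast %0 : tensor<2x3x?xf32> to tensor<*x!numpy.any_dtype>
  return %1 : tensor<*x!numpy.any_dtype>
}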


@@ -28,6 +28,8 @@ void createGlobalizePipeline(OpPassManager &pm);
std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();
std::unique_ptr<OperationPass<FuncOp>> createRefineTypesPass();
} // namespace Torch
/// Registers all Torch transformation passes.


@@ -125,4 +125,13 @@ def AdjustCallingConventions
}];
}
def RefineTypes : Pass<"torch-refine-types", "FuncOp"> {
let summary = "Refine types";
let constructor = "mlir::NPCOMP::Torch::createRefineTypesPass()";
let description = [{
Refines types of the program. Currently, this means shapes and dtypes of
tensors/arrays.
}];
}
#endif // NPCOMP_TORCH_PASSES


@@ -3,6 +3,7 @@ add_npcomp_conversion_library(NPCOMPTorchPasses
Passes.cpp
GlobalizeObjectGraph.cpp
PrepareForGlobalizeObjectGraph.cpp
RefineTypes.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Torch/Transforms


@@ -0,0 +1,312 @@
//===- RefineTypes.cpp ----------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
#include "npcomp/Dialect/Numpy/IR/NumpyOps.h"
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::Torch;
// -----------------------------------------------------------------------------
// Analysis.
// -----------------------------------------------------------------------------
constexpr int64_t kUnknownSize = -1;
namespace {
// Statically known information for a particular Value.
//
// This struct currently tracks only information relevant for tensor/array-like
// shaped types. It is fine to associate a `ValueKnowledge` with a non-shaped
// type as long as it is in the default "no knowledge" state returned by
// `getMostConservativeKnowledge`. The important invariant is that we cannot
// claim to know something about a value which is false.
//
// This class could also be called "dataflow facts", "lattice value", etc.
struct ValueKnowledge {
ValueKnowledge() = delete;
// We enforce that `elementType` is always a valid type (possibly
// !numpy.any_dtype), and `sizes` is empty unless `hasRank`.
// So default constructing is prohibited.
ValueKnowledge(bool hasRank, std::vector<int64_t> sizes, Type elementType)
: hasRank(hasRank), sizes(sizes), elementType(elementType) {
assert(elementType != nullptr);
assert(sizes.size() == 0 || hasRank);
}
// Get a safe "most conservative knowledge" default state.
static ValueKnowledge getMostConservativeKnowledge(MLIRContext *context) {
return ValueKnowledge(false, {}, Numpy::AnyDtypeType::get(context));
}
// Whether the Value is known to have a rank.
bool hasRank;
// If `hasRank` is true, the size along each dimension. Unknown sizes are
// represented as `kUnknownSize`.
std::vector<int64_t> sizes;
// The element type of a shaped type.
// This is equal to !numpy.any_dtype if it is not a concrete type.
Type elementType;
};
} // namespace
bool operator==(const ValueKnowledge &lhs, const ValueKnowledge &rhs) {
return std::make_tuple(lhs.hasRank, lhs.sizes, lhs.elementType) ==
std::make_tuple(rhs.hasRank, rhs.sizes, rhs.elementType);
}
// static llvm::raw_ostream &operator<<(llvm::raw_ostream &os, ValueKnowledge &knowledge) {
// os << "hasRank = " << knowledge.hasRank << ", sizes = [";
// llvm::interleaveComma(knowledge.sizes, os);
// os << "]"
// << ", elementType = " << knowledge.elementType;
// return os;
// }
// Given two pieces of static knowledge, calculate conservatively the
// information we can be sure about.
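// For example, joining sizes [2, kUnknownSize, 7] with [2, 3, kUnknownSize]
// yields [2, 3, 7], and joining element type f32 with !numpy.any_dtype yields
// f32; conflicting sizes or element types remain unknown.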
ValueKnowledge join(const ValueKnowledge &lhs, const ValueKnowledge &rhs) {
// Mental model: All conditions are checking how to change from the safe "no
// knowledge" default-initialized state to a state with more knowledge
// consistent with lhs and rhs.
ValueKnowledge result = ValueKnowledge::getMostConservativeKnowledge(
lhs.elementType.getContext());
if (lhs.hasRank && !rhs.hasRank) {
result.hasRank = true;
result.sizes = lhs.sizes;
} else if (!lhs.hasRank && rhs.hasRank) {
result.hasRank = true;
result.sizes = rhs.sizes;
} else if (lhs.hasRank && rhs.hasRank &&
lhs.sizes.size() == rhs.sizes.size()) {
result.hasRank = true;
result.sizes.resize(lhs.sizes.size(), kUnknownSize);
for (int i = 0, e = result.sizes.size(); i != e; i++) {
int64_t lhsSize = lhs.sizes[i];
int64_t rhsSize = rhs.sizes[i];
int64_t &resultSize = result.sizes[i];
if (lhsSize == kUnknownSize) {
resultSize = rhsSize;
} else if (rhsSize == kUnknownSize) {
resultSize = lhsSize;
} else if (lhsSize == rhsSize) {
resultSize = lhsSize;
}
}
}
if (!lhs.elementType || lhs.elementType.isa<Numpy::AnyDtypeType>()) {
result.elementType = rhs.elementType;
} else if (!rhs.elementType || rhs.elementType.isa<Numpy::AnyDtypeType>()) {
result.elementType = lhs.elementType;
} else if (lhs.elementType == rhs.elementType) {
result.elementType = lhs.elementType;
}
return result;
}
// Get the static knowledge intrinsic to `type`.
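// For example, a value of type tensor<2x?xf32> yields
// {hasRank=true, sizes=[2, kUnknownSize], elementType=f32}, while an unranked
// tensor<*x!numpy.any_dtype> yields the most conservative "no knowledge" state.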
ValueKnowledge getKnowledgeFromType(Type type) {
if (auto tensorType = type.dyn_cast<TensorType>()) {
ValueKnowledge result =
ValueKnowledge::getMostConservativeKnowledge(type.getContext());
if (tensorType.hasRank()) {
result.hasRank = true;
result.sizes = tensorType.getShape().vec();
}
result.elementType = tensorType.getElementType();
return result;
}
return ValueKnowledge::getMostConservativeKnowledge(type.getContext());
}
// Simple forward intraprocedural dataflow for type information.
class TypeAnalyzer {
public:
TypeAnalyzer(MLIRContext *context) : context(context) {}
void propagate(Region &region);
// Get the knowledge that is known about `v`.
ValueKnowledge &getKnowledge(Value v);
private:
// Incorporate `knowledge` into what is known about `v`.
// Return true if new knowledge was obtained about `v`.
bool incorporateKnowledge(Value v, ValueKnowledge knowledge);
MLIRContext *context;
DenseMap<Value, ValueKnowledge> facts;
};
void TypeAnalyzer::propagate(Region &region) {
bool changed;
do {
changed = false;
// TODO: Find out why region.walk doesn't walk the blocks.
for (Block &block : region) {
for (Value v : block.getArguments())
changed |= incorporateKnowledge(v, getKnowledgeFromType(v.getType()));
for (Operation &op : block.getOperations()) {
for (Value v : op.getResults())
changed |= incorporateKnowledge(v, getKnowledgeFromType(v.getType()));
if (isa<Numpy::TensorStaticInfoCastOp, aten::TanhOp>(op)) {
changed |= incorporateKnowledge(op.getResult(0),
getKnowledge(op.getOperand(0)));
}
}
};
} while (changed);
}
ValueKnowledge &TypeAnalyzer::getKnowledge(Value v) {
auto p =
facts.insert({v, ValueKnowledge::getMostConservativeKnowledge(context)});
return p.first->second;
}
bool TypeAnalyzer::incorporateKnowledge(Value v, ValueKnowledge knowledge) {
ValueKnowledge &currentKnowledge = getKnowledge(v);
ValueKnowledge updatedKnowledge = join(currentKnowledge, knowledge);
assert(join(updatedKnowledge, knowledge) == updatedKnowledge &&
"nonmonotonic!");
assert(join(updatedKnowledge, currentKnowledge) == updatedKnowledge &&
"nonmonotonic!");
if (currentKnowledge == updatedKnowledge)
return false;
currentKnowledge = updatedKnowledge;
return true;
}
// -----------------------------------------------------------------------------
// Transforms.
// -----------------------------------------------------------------------------
// Get the most refined TensorType compatible with ValueKnowledge.
static TensorType getTensorTypeFromKnowledge(MLIRContext *context,
ValueKnowledge &knowledge) {
Type elementType = knowledge.elementType ? knowledge.elementType
: Numpy::AnyDtypeType::get(context);
if (!knowledge.hasRank)
return UnrankedTensorType::get(elementType);
return RankedTensorType::get(knowledge.sizes, elementType);
}
// Get the most refined type compatible with ValueKnowledge, or null if that
// is not possible.
static Type getMostRefinedStaticType(Value v, TypeAnalyzer &analyzer) {
if (v.getType().isa<TensorType>())
return getTensorTypeFromKnowledge(v.getContext(), analyzer.getKnowledge(v));
// TODO: Support !numpy.ndarray type.
return nullptr;
}
// Return true if the value `v` can have its type updated in place.
// This is a function of the value itself and also its users.
static bool canUpdateTypeInPlace(Value v) {
// TODO: There are really two different predicates here, which need to be
// properly interface-ized or otherwise make pluggable.
// 1. Whether an operation allows its result to be refined to a certain type.
// 2. Whether an operand of an operation can be refined to a certain
// type.
//
// A simple first step that probably is enough in practice is a simple trait
// AllowsTypeRefinement which answers yes to both questions. In general, an op
// might allow refinement of some operands/results but not others, but that
// seems unlikely.
//
// Currently, we answer both with the same logic, which is just enough for our
// e2e bringup.
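//
// For example, under the current logic the result of an aten op whose only
// users are other aten ops (or numpy.tensor_static_info_cast) can be updated
// in place, while a value consumed by std.return cannot and instead gets a
// numpy.tensor_static_info_cast inserted by `optimize` below.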
Dialect *atenDialect = v.getContext()->getOrLoadDialect<aten::ATenDialect>();
auto canValueIntrinsicallyBeUpdated = [&](Value v) {
// TODO: Update block arguments.
if (v.isa<BlockArgument>())
return false;
Operation *op = v.cast<OpResult>().getOwner();
if (op->getDialect() == atenDialect)
return true;
if (isa<Numpy::TensorStaticInfoCastOp>(op))
return true;
return false;
};
// TODO: Handle BranchOpInterface and RegionBranchOpInterface ops.
return canValueIntrinsicallyBeUpdated(v) &&
llvm::all_of(v.getUses(), [&](OpOperand &use) {
Operation *user = use.getOwner();
if (user->getDialect() == atenDialect)
return true;
if (isa<Numpy::TensorStaticInfoCastOp>(user))
return true;
return false;
});
}
void optimize(FuncOp func, TypeAnalyzer &analyzer) {
func.walk([&](Operation *op) {
for (Value v : op->getResults()) {
Type refinedType = getMostRefinedStaticType(v, analyzer);
// No type? Nothing to do.
if (!refinedType)
return;
// Type is same as existing one? Nothing to do.
if (refinedType == v.getType())
return;
if (canUpdateTypeInPlace(v)) {
// Update type in place if possible.
v.setType(refinedType);
} else if (v.getType().isa<TensorType>() && v.isa<OpResult>()) {
// Update the type in place, and cast the information away explicitly so
// that users observe the original type.
// TODO: Support updating !numpy.ndarray type too.
// TODO: Support updating block arguments.
// TODO: This update could be more fine-grained (RAUW with per-use
// canUpdateTypeInPlace information), or we could get the same effect by
// implementing a canonicalization that uses the fine-grained information
// used by canUpdateTypeInPlace to bypass numpy.tensor_static_info_cast
// when the consuming op is ok with that.
Operation *op = v.cast<OpResult>().getOwner();
OpBuilder builder(op->getBlock(), std::next(op->getIterator()));
auto staticInfoCast = builder.create<Numpy::TensorStaticInfoCastOp>(
op->getLoc(), v.getType(), v);
SmallPtrSet<Operation *, 1> exceptions;
exceptions.insert(staticInfoCast);
v.replaceAllUsesExcept(staticInfoCast, exceptions);
v.setType(refinedType);
}
}
});
}
namespace {
class RefineTypesPass : public RefineTypesBase<RefineTypesPass> {
void runOnOperation() override {
auto func = getOperation();
TypeAnalyzer analyzer(&getContext());
analyzer.propagate(func.getRegion());
optimize(func, analyzer);
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::Torch::createRefineTypesPass() {
return std::make_unique<RefineTypesPass>();
}


@@ -0,0 +1,37 @@
// RUN: npcomp-opt -torch-refine-types -split-input-file %s | FileCheck %s
// CHECK-LABEL: func @f(
// CHECK-SAME: %[[ARG:.*]]: tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype> {
// CHECK: %[[SHAPED:.*]] = numpy.tensor_static_info_cast %[[ARG]] : tensor<2x3x?xf32> to tensor<2x3x?xf32>
// CHECK: %[[SHAPE_ERASED:.*]] = numpy.tensor_static_info_cast %[[SHAPED]] : tensor<2x3x?xf32> to tensor<*x!numpy.any_dtype>
// CHECK: return %[[SHAPE_ERASED]] : tensor<*x!numpy.any_dtype>
func @f(%arg0: tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype> {
%0 = numpy.tensor_static_info_cast %arg0 : tensor<2x3x?xf32> to tensor<*x!numpy.any_dtype>
return %0 : tensor<*x!numpy.any_dtype>
}
// -----
// CHECK-LABEL: func @f(
// CHECK-SAME: %[[ARG:.*]]: tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype> {
// CHECK: %[[SHAPED:.*]] = "aten.tanh"(%[[ARG]]) : (tensor<2x3x?xf32>) -> tensor<2x3x?xf32>
// CHECK: %[[SHAPE_ERASED:.*]] = numpy.tensor_static_info_cast %[[SHAPED]] : tensor<2x3x?xf32> to tensor<*x!numpy.any_dtype>
// CHECK: return %[[SHAPE_ERASED]] : tensor<*x!numpy.any_dtype>
func @f(%arg0: tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype> {
%1 = "aten.tanh"(%arg0) : (tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype>
return %1 : tensor<*x!numpy.any_dtype>
}
// -----
// CHECK-LABEL: func @f
func @f(%arg0: tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype> {
// Check propagation through multiple ops.
// CHECK: "aten.tanh"(%{{.*}}) : (tensor<2x3x?xf32>) -> tensor<2x3x?xf32>
// CHECK: "aten.tanh"(%{{.*}}) : (tensor<2x3x?xf32>) -> tensor<2x3x?xf32>
// CHECK: "aten.tanh"(%{{.*}}) : (tensor<2x3x?xf32>) -> tensor<2x3x?xf32>
%1 = "aten.tanh"(%arg0) : (tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype>
%2 = "aten.tanh"(%1) : (tensor<*x!numpy.any_dtype>) -> tensor<*x!numpy.any_dtype>
%3 = "aten.tanh"(%2) : (tensor<*x!numpy.any_dtype>) -> tensor<*x!numpy.any_dtype>
return %3 : tensor<*x!numpy.any_dtype>
}