From 1e357ae68017f2277e20533cb7a317839afa4cbe Mon Sep 17 00:00:00 2001
From: Sean Silva
Date: Mon, 5 Apr 2021 17:43:23 -0700
Subject: [PATCH] Add simple type refinement pass.

Currently implemented as a simple intraprocedural dataflow analysis over
a standard ShapedType lattice (hasRank, sizes, and elementType).

It currently hardcodes a few key pieces of information:
- the shape transfer functions
- whether it is legal to update the operand type of an op

Both of these obviously need to be made pluggable, and the core
propagation logic needs to move somewhere dialect-agnostic.
---
 .../npcomp/Dialect/Torch/Transforms/Passes.h  |   2 +
 .../npcomp/Dialect/Torch/Transforms/Passes.td |   9 +
 lib/Dialect/Torch/Transforms/CMakeLists.txt   |   1 +
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  | 312 ++++++++++++++++++
 test/Dialect/Torch/refine-types.mlir          |  37 +++
 5 files changed, 361 insertions(+)
 create mode 100644 lib/Dialect/Torch/Transforms/RefineTypes.cpp
 create mode 100644 test/Dialect/Torch/refine-types.mlir

diff --git a/include/npcomp/Dialect/Torch/Transforms/Passes.h b/include/npcomp/Dialect/Torch/Transforms/Passes.h
index a38df95d4..51e6ecc89 100644
--- a/include/npcomp/Dialect/Torch/Transforms/Passes.h
+++ b/include/npcomp/Dialect/Torch/Transforms/Passes.h
@@ -28,6 +28,8 @@ void createGlobalizePipeline(OpPassManager &pm);
 std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();
 
+std::unique_ptr<OperationPass<FuncOp>> createRefineTypesPass();
+
 } // namespace Torch
 
 /// Registers all Torch transformation passes.
diff --git a/include/npcomp/Dialect/Torch/Transforms/Passes.td b/include/npcomp/Dialect/Torch/Transforms/Passes.td
index 8c6dcf293..d1654a9a1 100644
--- a/include/npcomp/Dialect/Torch/Transforms/Passes.td
+++ b/include/npcomp/Dialect/Torch/Transforms/Passes.td
@@ -125,4 +125,13 @@ def AdjustCallingConventions
   }];
 }
 
+def RefineTypes : Pass<"torch-refine-types", "FuncOp"> {
+  let summary = "Refine types";
+  let constructor = "mlir::NPCOMP::Torch::createRefineTypesPass()";
+  let description = [{
+    Refines types of the program. Currently, this means shapes and dtypes of
+    tensors/arrays.
+  }];
+}
+
 #endif // NPCOMP_TORCH_PASSES
diff --git a/lib/Dialect/Torch/Transforms/CMakeLists.txt b/lib/Dialect/Torch/Transforms/CMakeLists.txt
index 7349d9a7e..a366316c8 100644
--- a/lib/Dialect/Torch/Transforms/CMakeLists.txt
+++ b/lib/Dialect/Torch/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_npcomp_conversion_library(NPCOMPTorchPasses
   Passes.cpp
   GlobalizeObjectGraph.cpp
   PrepareForGlobalizeObjectGraph.cpp
+  RefineTypes.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Torch/Transforms
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
new file mode 100644
index 000000000..90e451268
--- /dev/null
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -0,0 +1,312 @@
+//===- RefineTypes.cpp ----------------------------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
+#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
+#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
+#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
+#include "npcomp/Dialect/Numpy/IR/NumpyOps.h"
+#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
+#include "npcomp/Dialect/Torch/IR/TorchOps.h"
+#include "npcomp/Dialect/Torch/Transforms/Passes.h"
+
+using namespace mlir;
+using namespace mlir::NPCOMP;
+using namespace mlir::NPCOMP::Torch;
+
+// -----------------------------------------------------------------------------
+// Analysis.
+// -----------------------------------------------------------------------------
+
+constexpr int64_t kUnknownSize = -1;
+
+namespace {
+// Statically known information for a particular Value.
+//
+// This struct currently tracks only information relevant for tensor/array-like
+// shaped types. It is fine to associate a `ValueKnowledge` with a non-shaped
+// type as long as it is in the default "no knowledge" state returned by
+// `getMostConservativeKnowledge`. The important invariant is that we cannot
+// claim to know something about a value which is false.
+//
+// This class could also be called "dataflow facts", "lattice value", etc.
+struct ValueKnowledge {
+  ValueKnowledge() = delete;
+  // We enforce that `elementType` is always a valid type (possibly
+  // !numpy.any_dtype), and `sizes` is empty unless `hasRank`.
+  // So default construction is prohibited.
+  ValueKnowledge(bool hasRank, std::vector<int64_t> sizes, Type elementType)
+      : hasRank(hasRank), sizes(sizes), elementType(elementType) {
+    assert(elementType != nullptr);
+    assert(sizes.size() == 0 || hasRank);
+  }
+
+  // Get a safe "most conservative knowledge" default state.
+  static ValueKnowledge getMostConservativeKnowledge(MLIRContext *context) {
+    return ValueKnowledge(false, {}, Numpy::AnyDtypeType::get(context));
+  }
+
+  // Whether the Value is known to have a rank.
+  bool hasRank;
+  // If `hasRank`, the size along each dimension. Unknown sizes are
+  // represented as `kUnknownSize`.
+  std::vector<int64_t> sizes;
+  // The element type of a shaped type.
+  // This is equal to !numpy.any_dtype if it is not a concrete type.
+  Type elementType;
+};
+} // namespace
+
+bool operator==(const ValueKnowledge &lhs, const ValueKnowledge &rhs) {
+  return std::make_tuple(lhs.hasRank, lhs.sizes, lhs.elementType) ==
+         std::make_tuple(rhs.hasRank, rhs.sizes, rhs.elementType);
+}
+
+// static llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+//                                      ValueKnowledge &knowledge) {
+//   os << "hasRank = " << knowledge.hasRank << ", sizes = [";
+//   llvm::interleaveComma(knowledge.sizes, os);
+//   os << "]"
+//      << ", elementType = " << knowledge.elementType;
+//   return os;
+// }
+
+// Given two pieces of static knowledge, calculate conservatively the
+// information we can be sure about.
+ValueKnowledge join(const ValueKnowledge &lhs, const ValueKnowledge &rhs) {
+  // Mental model: All conditions are checking how to change from the safe "no
+  // knowledge" default-initialized state to a state with more knowledge
+  // consistent with lhs and rhs.
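+  //
+  // For example (an illustrative trace of the rules below, not extra logic):
+  //
+  //   join({hasRank=true, sizes=[2, kUnknownSize], f32},
+  //        {hasRank=true, sizes=[kUnknownSize, 3], f32})
+  //     == {hasRank=true, sizes=[2, 3], f32}
+  //
+  // Both arguments are assumed to be true facts about the same Value, so for
+  // each dimension we may keep whichever side knows the concrete size.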
+  ValueKnowledge result = ValueKnowledge::getMostConservativeKnowledge(
+      lhs.elementType.getContext());
+
+  if (lhs.hasRank && !rhs.hasRank) {
+    result.hasRank = true;
+    result.sizes = lhs.sizes;
+  } else if (!lhs.hasRank && rhs.hasRank) {
+    result.hasRank = true;
+    result.sizes = rhs.sizes;
+  } else if (lhs.hasRank && rhs.hasRank &&
+             lhs.sizes.size() == rhs.sizes.size()) {
+    result.hasRank = true;
+    result.sizes.resize(lhs.sizes.size(), kUnknownSize);
+    for (int i = 0, e = result.sizes.size(); i != e; i++) {
+      int64_t lhsSize = lhs.sizes[i];
+      int64_t rhsSize = rhs.sizes[i];
+      int64_t &resultSize = result.sizes[i];
+      if (lhsSize == kUnknownSize) {
+        resultSize = rhsSize;
+      } else if (rhsSize == kUnknownSize) {
+        resultSize = lhsSize;
+      } else if (lhsSize == rhsSize) {
+        resultSize = lhsSize;
+      }
+    }
+  }
+
+  if (!lhs.elementType || lhs.elementType.isa<Numpy::AnyDtypeType>()) {
+    result.elementType = rhs.elementType;
+  } else if (!rhs.elementType || rhs.elementType.isa<Numpy::AnyDtypeType>()) {
+    result.elementType = lhs.elementType;
+  } else if (lhs.elementType == rhs.elementType) {
+    result.elementType = lhs.elementType;
+  }
+  return result;
+}
+
+// Get the static knowledge intrinsic to `type`.
+ValueKnowledge getKnowledgeFromType(Type type) {
+  if (auto tensorType = type.dyn_cast<TensorType>()) {
+    ValueKnowledge result =
+        ValueKnowledge::getMostConservativeKnowledge(type.getContext());
+    if (tensorType.hasRank()) {
+      result.hasRank = true;
+      result.sizes = tensorType.getShape().vec();
+    }
+    result.elementType = tensorType.getElementType();
+    return result;
+  }
+  return ValueKnowledge::getMostConservativeKnowledge(type.getContext());
+}
+
+// Simple forward intraprocedural dataflow for type information.
+class TypeAnalyzer {
+public:
+  TypeAnalyzer(MLIRContext *context) : context(context) {}
+  void propagate(Region &region);
+  // Get the knowledge that is known about `v`.
+  ValueKnowledge &getKnowledge(Value v);
+
+private:
+  // Incorporate `knowledge` into what is known about `v`.
+  // Return true if new knowledge was obtained about `v`.
+  bool incorporateKnowledge(Value v, ValueKnowledge knowledge);
+  MLIRContext *context;
+  DenseMap<Value, ValueKnowledge> facts;
+};
+
+void TypeAnalyzer::propagate(Region &region) {
+  bool changed;
+  do {
+    changed = false;
+    // TODO: Find out why region.walk doesn't walk the blocks.
+    for (Block &block : region) {
+      for (Value v : block.getArguments())
+        changed |= incorporateKnowledge(v, getKnowledgeFromType(v.getType()));
+      for (Operation &op : block.getOperations()) {
+        for (Value v : op.getResults())
+          changed |=
+              incorporateKnowledge(v, getKnowledgeFromType(v.getType()));
+        // Hardcoded transfer functions: these ops propagate their operand's
+        // knowledge to their result unchanged.
+        if (isa<Numpy::TensorStaticInfoCastOp, aten::TanhOp>(op)) {
+          changed |= incorporateKnowledge(op.getResult(0),
+                                          getKnowledge(op.getOperand(0)));
+        }
+      }
+    }
+  } while (changed);
+}
+
+ValueKnowledge &TypeAnalyzer::getKnowledge(Value v) {
+  auto p =
+      facts.insert({v, ValueKnowledge::getMostConservativeKnowledge(context)});
+  return p.first->second;
+}
+
+bool TypeAnalyzer::incorporateKnowledge(Value v, ValueKnowledge knowledge) {
+  ValueKnowledge &currentKnowledge = getKnowledge(v);
+  ValueKnowledge updatedKnowledge = join(currentKnowledge, knowledge);
+  assert(join(updatedKnowledge, knowledge) == updatedKnowledge &&
+         "nonmonotonic!");
+  assert(join(updatedKnowledge, currentKnowledge) == updatedKnowledge &&
+         "nonmonotonic!");
+  if (currentKnowledge == updatedKnowledge)
+    return false;
+  currentKnowledge = updatedKnowledge;
+  return true;
+}
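+
+// An illustrative trace of the analysis (no new logic; compare the third test
+// case in test/Dialect/Torch/refine-types.mlir): given
+//   %1 = "aten.tanh"(%arg0) ...  // %arg0 : tensor<2x3x?xf32>
+//   %2 = "aten.tanh"(%1) ...
+// the first sweep of `propagate` joins %1's knowledge with %arg0's, then %2's
+// with %1's, so both reach {hasRank=true, sizes=[2, 3, kUnknownSize], f32};
+// the second sweep learns nothing new and the loop terminates.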
+
+// -----------------------------------------------------------------------------
+// Transforms.
+// -----------------------------------------------------------------------------
+
+// Get the most refined TensorType compatible with ValueKnowledge.
+static TensorType getTensorTypeFromKnowledge(MLIRContext *context,
+                                             ValueKnowledge &knowledge) {
+  Type elementType = knowledge.elementType ? knowledge.elementType
+                                           : Numpy::AnyDtypeType::get(context);
+  if (!knowledge.hasRank)
+    return UnrankedTensorType::get(elementType);
+  return RankedTensorType::get(knowledge.sizes, elementType);
+}
+
+// Get the most refined type compatible with ValueKnowledge, or null if that
+// is not possible.
+static Type getMostRefinedStaticType(Value v, TypeAnalyzer &analyzer) {
+  if (v.getType().isa<TensorType>())
+    return getTensorTypeFromKnowledge(v.getContext(),
+                                      analyzer.getKnowledge(v));
+  // TODO: Support !numpy.ndarray type.
+  return nullptr;
+}
+
+// Return true if the value `v` can have its type updated in place.
+// This is a function of the value itself and also its users.
+static bool canUpdateTypeInPlace(Value v) {
+  // TODO: There are really two different predicates here, which need to be
+  // properly interface-ized or otherwise made pluggable.
+  // 1. Whether an operation allows its result to be refined to a certain type.
+  // 2. Whether an operand of an operation can be refined to a certain
+  //    type.
+  //
+  // A simple first step that probably is enough in practice is a simple trait
+  // AllowsTypeRefinement which answers yes to both questions. In general, an
+  // op might allow refinement of some operands/results but not others, but
+  // that seems unlikely.
+  //
+  // Currently, we answer both with the same logic, which is just enough for
+  // our e2e bringup.
+  Dialect *atenDialect = v.getContext()->getOrLoadDialect<aten::ATenDialect>();
+  auto canValueIntrinsicallyBeUpdated = [&](Value v) {
+    // TODO: Update block arguments.
+    if (v.isa<BlockArgument>())
+      return false;
+    Operation *op = v.cast<OpResult>().getOwner();
+    if (op->getDialect() == atenDialect)
+      return true;
+    if (isa<Numpy::TensorStaticInfoCastOp>(op))
+      return true;
+    return false;
+  };
+  // TODO: Handle BranchOpInterface and RegionBranchOpInterface ops.
+  return canValueIntrinsicallyBeUpdated(v) &&
+         llvm::all_of(v.getUses(), [&](OpOperand &use) {
+           Operation *user = use.getOwner();
+           if (user->getDialect() == atenDialect)
+             return true;
+           if (isa<Numpy::TensorStaticInfoCastOp>(user))
+             return true;
+           return false;
+         });
+}
+
+void optimize(FuncOp func, TypeAnalyzer &analyzer) {
+  func.walk([&](Operation *op) {
+    for (Value v : op->getResults()) {
+      Type refinedType = getMostRefinedStaticType(v, analyzer);
+      // No type? Nothing to do.
+      if (!refinedType)
+        continue;
+      // Type is the same as the existing one? Nothing to do.
+      if (refinedType == v.getType())
+        continue;
+      if (canUpdateTypeInPlace(v)) {
+        // Update the type in place.
+        v.setType(refinedType);
+      } else if (v.getType().isa<TensorType>() && v.isa<OpResult>()) {
+        // Update the type in place, and cast the information away explicitly
+        // so that users observe the original type.
+        // TODO: Support updating !numpy.ndarray type too.
+        // TODO: Support updating block arguments.
+        // TODO: This update could be more fine-grained (RAUW with per-use
+        // canUpdateTypeInPlace information), or we could get the same effect
+        // by implementing a canonicalization that uses the fine-grained
+        // information behind canUpdateTypeInPlace to bypass
+        // numpy.tensor_static_info_cast when the consuming op is ok with that.
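+        //
+        // Illustrative before/after of this rewrite (mirrors the first test
+        // case in test/Dialect/Torch/refine-types.mlir):
+        //   %0 = numpy.tensor_static_info_cast %arg0
+        //       : tensor<2x3x?xf32> to tensor<*x!numpy.any_dtype>
+        //   return %0 : tensor<*x!numpy.any_dtype>
+        // becomes
+        //   %0 = numpy.tensor_static_info_cast %arg0
+        //       : tensor<2x3x?xf32> to tensor<2x3x?xf32>
+        //   %1 = numpy.tensor_static_info_cast %0
+        //       : tensor<2x3x?xf32> to tensor<*x!numpy.any_dtype>
+        //   return %1 : tensor<*x!numpy.any_dtype>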
+        Operation *op = v.cast<OpResult>().getOwner();
+        OpBuilder builder(op->getBlock(), std::next(op->getIterator()));
+        auto staticInfoCast = builder.create<Numpy::TensorStaticInfoCastOp>(
+            op->getLoc(), v.getType(), v);
+        SmallPtrSet<Operation *, 1> exceptions;
+        exceptions.insert(staticInfoCast);
+        v.replaceAllUsesExcept(staticInfoCast, exceptions);
+        v.setType(refinedType);
+      }
+    }
+  });
+}
+
+namespace {
+class RefineTypesPass : public RefineTypesBase<RefineTypesPass> {
+  void runOnOperation() override {
+    auto func = getOperation();
+    TypeAnalyzer analyzer(&getContext());
+    analyzer.propagate(func.getRegion());
+    optimize(func, analyzer);
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::NPCOMP::Torch::createRefineTypesPass() {
+  return std::make_unique<RefineTypesPass>();
+}
diff --git a/test/Dialect/Torch/refine-types.mlir b/test/Dialect/Torch/refine-types.mlir
new file mode 100644
index 000000000..33fbc01b4
--- /dev/null
+++ b/test/Dialect/Torch/refine-types.mlir
@@ -0,0 +1,37 @@
+// RUN: npcomp-opt -torch-refine-types -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func @f(
+// CHECK-SAME: %[[ARG:.*]]: tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype> {
+// CHECK: %[[SHAPED:.*]] = numpy.tensor_static_info_cast %[[ARG]] : tensor<2x3x?xf32> to tensor<2x3x?xf32>
+// CHECK: %[[SHAPE_ERASED:.*]] = numpy.tensor_static_info_cast %[[SHAPED]] : tensor<2x3x?xf32> to tensor<*x!numpy.any_dtype>
+// CHECK: return %[[SHAPE_ERASED]] : tensor<*x!numpy.any_dtype>
+func @f(%arg0: tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype> {
+  %0 = numpy.tensor_static_info_cast %arg0 : tensor<2x3x?xf32> to tensor<*x!numpy.any_dtype>
+  return %0 : tensor<*x!numpy.any_dtype>
+}
+
+// -----
+
+// CHECK-LABEL: func @f(
+// CHECK-SAME: %[[ARG:.*]]: tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype> {
+// CHECK: %[[SHAPED:.*]] = "aten.tanh"(%[[ARG]]) : (tensor<2x3x?xf32>) -> tensor<2x3x?xf32>
+// CHECK: %[[SHAPE_ERASED:.*]] = numpy.tensor_static_info_cast %[[SHAPED]] : tensor<2x3x?xf32> to tensor<*x!numpy.any_dtype>
+// CHECK: return %[[SHAPE_ERASED]] : tensor<*x!numpy.any_dtype>
+func @f(%arg0: tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype> {
+  %1 = "aten.tanh"(%arg0) : (tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype>
+  return %1 : tensor<*x!numpy.any_dtype>
+}
+
+// -----
+
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype> {
+  // Check propagation through multiple ops.
+  // CHECK: "aten.tanh"(%{{.*}}) : (tensor<2x3x?xf32>) -> tensor<2x3x?xf32>
+  // CHECK: "aten.tanh"(%{{.*}}) : (tensor<2x3x?xf32>) -> tensor<2x3x?xf32>
+  // CHECK: "aten.tanh"(%{{.*}}) : (tensor<2x3x?xf32>) -> tensor<2x3x?xf32>
+  %1 = "aten.tanh"(%arg0) : (tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype>
+  %2 = "aten.tanh"(%1) : (tensor<*x!numpy.any_dtype>) -> tensor<*x!numpy.any_dtype>
+  %3 = "aten.tanh"(%2) : (tensor<*x!numpy.any_dtype>) -> tensor<*x!numpy.any_dtype>
+  return %3 : tensor<*x!numpy.any_dtype>
+}
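+
+// -----
+
+// An additional illustrative case (hypothetical, combining the two patterns
+// above): the cast's result is refined in place because its only user is an
+// aten op, while the tanh's result gets a compensating erasing cast because
+// its only user is a return.
+// CHECK-LABEL: func @f(
+// CHECK-SAME: %[[ARG:.*]]: tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype> {
+// CHECK: %[[CASTED:.*]] = numpy.tensor_static_info_cast %[[ARG]] : tensor<2x3x?xf32> to tensor<2x3x?xf32>
+// CHECK: %[[SHAPED:.*]] = "aten.tanh"(%[[CASTED]]) : (tensor<2x3x?xf32>) -> tensor<2x3x?xf32>
+// CHECK: %[[SHAPE_ERASED:.*]] = numpy.tensor_static_info_cast %[[SHAPED]] : tensor<2x3x?xf32> to tensor<*x!numpy.any_dtype>
+// CHECK: return %[[SHAPE_ERASED]] : tensor<*x!numpy.any_dtype>
+func @f(%arg0: tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype> {
+  %0 = numpy.tensor_static_info_cast %arg0 : tensor<2x3x?xf32> to tensor<*x!numpy.any_dtype>
+  %1 = "aten.tanh"(%0) : (tensor<*x!numpy.any_dtype>) -> tensor<*x!numpy.any_dtype>
+  return %1 : tensor<*x!numpy.any_dtype>
+}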