Add RefinePublicReturn pass.

This pass allows shape information to be propagated to return types,
which is nontrivial and cannot be cleanly put anywhere else as it
changes the public ABI, which is a concern that we want to keep
concentrated in one place.
pull/203/head
Sean Silva 2021-04-06 15:23:14 -07:00
parent 1e357ae680
commit 927546b3c5
5 changed files with 157 additions and 0 deletions

View File

@ -19,6 +19,7 @@ namespace Numpy {
std::unique_ptr<OperationPass<ModuleOp>> createPublicFunctionsToTensorPass();
std::unique_ptr<OperationPass<FuncOp>> createArrayToTensorPass();
std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();
} // namespace Numpy

View File

@ -38,4 +38,26 @@ def NumpyArrayToTensor : Pass<"numpy-array-to-tensor", "FuncOp"> {
}
def NumpyRefinePublicReturn : Pass<"numpy-refine-public-return", "ModuleOp"> {
let summary = "Refine public return";
let constructor = "mlir::NPCOMP::Numpy::createRefinePublicReturnPass()";
let description = [{
Refines types of values return from public functions based on
intraprocedural information.
This pass effectively encodes an assumption by the pass pipeline author that
the public calling convention of the module can have its types refined,
without causing ABI mismatches. This is frequently true -- for example, in
many systems, `tensor<?x?xf32>`, `tensor<3x3xf32>` and
`tensor<*x!numpy.any_dtype>` are all the same data structure on calling
convention boundaries.
This pass is expected to run after shape refinement has occurred to
otherwise resolve shapes, and is currently mainly useful to convert
rank/dtype-erased function boundaries to ranked, dtyped code for
compiler backends.
}];
}
#endif // NPCOMP_NUMPY_PASSES

View File

@ -2,6 +2,7 @@ add_npcomp_conversion_library(NPCOMPNumpyPasses
ArrayToTensor.cpp
Passes.cpp
PublicFunctionToTensor.cpp
RefinePublicReturn.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Numpy/Transforms

View File

@ -0,0 +1,82 @@
//===- RefinePublicReturn.cpp ------------------------------------*- C++-*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "npcomp/Dialect/Numpy/IR/NumpyOps.h"
#include "npcomp/Dialect/Numpy/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::Numpy;
namespace {
class RefinePublicReturnPass
: public NumpyRefinePublicReturnBase<RefinePublicReturnPass> {
void runOnOperation() override {
auto module = getOperation();
module.walk([&](FuncOp func) {
if (func.getVisibility() != SymbolTable::Visibility::Public)
return;
if (func.isExternal())
return;
auto uses = SymbolTable::getSymbolUses(func, module);
if (!uses || uses->begin() != uses->end()) {
func.emitError() << "unimplemented: cannot refine public return for "
<< "for public function with uses";
return signalPassFailure();
}
rewriteSignature(func);
});
}
void rewriteSignature(FuncOp func) {
// Find the unique return op.
ReturnOp returnOp;
WalkResult walkResult = func.walk([&](ReturnOp op) {
if (returnOp)
return WalkResult::interrupt();
returnOp = op;
return WalkResult::advance();
});
if (walkResult.wasInterrupted()) {
func.emitError() << "unimplemented: refining returns for function with "
"more than one return op";
return signalPassFailure();
}
// Get the new operands. Either the original operand, or if there is a
// TensorStaticInfoCastOp then the pre-casted operand, which is presumed to
// have a more precise type.
SmallVector<Value> newOperands;
for (auto operand : returnOp.getOperands()) {
if (auto cast = operand.getDefiningOp<TensorStaticInfoCastOp>()) {
newOperands.push_back(cast.getOperand());
} else {
newOperands.push_back(operand);
}
}
returnOp->setOperands(newOperands);
// Update the function type.
auto funcType = func.getType();
func.setType(FunctionType::get(funcType.getContext(), funcType.getInputs(),
ValueRange(newOperands).getTypes()));
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
mlir::NPCOMP::Numpy::createRefinePublicReturnPass() {
return std::make_unique<RefinePublicReturnPass>();
}

View File

@ -0,0 +1,51 @@
// RUN: npcomp-opt -split-input-file %s -verify-diagnostics -allow-unregistered-dialect -numpy-refine-public-return | FileCheck %s
// CHECK-LABEL: func @basic(
// CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>) -> (tensor<?xf32>, i1, tensor<?xf32>) {
// CHECK: %[[CTRUE:.*]] = constant true
// CHECK: %[[CAST:.*]] = numpy.tensor_static_info_cast %[[ARG]] : tensor<?xf32> to tensor<*x!numpy.any_dtype>
// CHECK: return %[[ARG]], %[[CTRUE]], %[[ARG]] : tensor<?xf32>, i1, tensor<?xf32>
func @basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>, i1, tensor<*x!numpy.any_dtype>) {
%ctrue = std.constant true
%cast = numpy.tensor_static_info_cast %arg0 : tensor<?xf32> to tensor<*x!numpy.any_dtype>
return %arg0, %ctrue, %cast : tensor<?xf32>, i1, tensor<*x!numpy.any_dtype>
}
// No conversion on private function.
// CHECK-LABEL: func private @basic_private(
// CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>) -> (tensor<?xf32>, i1, tensor<*x!numpy.any_dtype>) {
// CHECK: %[[CTRUE:.*]] = constant true
// CHECK: %[[CASTED:.*]] = numpy.tensor_static_info_cast %[[ARG]] : tensor<?xf32> to tensor<*x!numpy.any_dtype>
// CHECK: return %[[ARG]], %[[CTRUE]], %[[CASTED]] : tensor<?xf32>, i1, tensor<*x!numpy.any_dtype>
func private @basic_private(%arg0: tensor<?xf32>) -> (tensor<?xf32>, i1, tensor<*x!numpy.any_dtype>) {
%ctrue = std.constant true
%cast = numpy.tensor_static_info_cast %arg0 : tensor<?xf32> to tensor<*x!numpy.any_dtype>
return %arg0, %ctrue, %cast : tensor<?xf32>, i1, tensor<*x!numpy.any_dtype>
}
// -----
// Call to public function.
// expected-error @+1 {{unimplemented}}
func @called(%arg0: tensor<*xf32>) -> tensor<*xf32> {
return %arg0 : tensor<*xf32>
}
func private @caller(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = call @called(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// -----
// Multiple returns.
// expected-error @+1 {{unimplemented}}
func @called(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%ctrue = constant true
cond_br %ctrue, ^bb1, ^bb2
^bb1:
return %arg0 : tensor<*xf32>
^bb2:
return %arg0 : tensor<*xf32>
}