mirror of https://github.com/llvm/torch-mlir
Add RefinePublicReturn pass.
This pass allows shape information to be propagated to return types, which is nontrivial and cannot be cleanly put anywhere else as it changes the public ABI, which is a concern that we want to keep concentrated in one place.pull/203/head
parent
1e357ae680
commit
927546b3c5
|
@ -19,6 +19,7 @@ namespace Numpy {
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> createPublicFunctionsToTensorPass();
|
std::unique_ptr<OperationPass<ModuleOp>> createPublicFunctionsToTensorPass();
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createArrayToTensorPass();
|
std::unique_ptr<OperationPass<FuncOp>> createArrayToTensorPass();
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();
|
||||||
|
|
||||||
} // namespace Numpy
|
} // namespace Numpy
|
||||||
|
|
||||||
|
|
|
@ -38,4 +38,26 @@ def NumpyArrayToTensor : Pass<"numpy-array-to-tensor", "FuncOp"> {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def NumpyRefinePublicReturn : Pass<"numpy-refine-public-return", "ModuleOp"> {
|
||||||
|
let summary = "Refine public return";
|
||||||
|
let constructor = "mlir::NPCOMP::Numpy::createRefinePublicReturnPass()";
|
||||||
|
let description = [{
|
||||||
|
Refines types of values return from public functions based on
|
||||||
|
intraprocedural information.
|
||||||
|
|
||||||
|
This pass effectively encodes an assumption by the pass pipeline author that
|
||||||
|
the public calling convention of the module can have its types refined,
|
||||||
|
without causing ABI mismatches. This is frequently true -- for example, in
|
||||||
|
many systems, `tensor<?x?xf32>`, `tensor<3x3xf32>` and
|
||||||
|
`tensor<*x!numpy.any_dtype>` are all the same data structure on calling
|
||||||
|
convention boundaries.
|
||||||
|
|
||||||
|
This pass is expected to run after shape refinement has occurred to
|
||||||
|
otherwise resolve shapes, and is currently mainly useful to convert
|
||||||
|
rank/dtype-erased function boundaries to ranked, dtyped code for
|
||||||
|
compiler backends.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
#endif // NPCOMP_NUMPY_PASSES
|
#endif // NPCOMP_NUMPY_PASSES
|
||||||
|
|
|
@ -2,6 +2,7 @@ add_npcomp_conversion_library(NPCOMPNumpyPasses
|
||||||
ArrayToTensor.cpp
|
ArrayToTensor.cpp
|
||||||
Passes.cpp
|
Passes.cpp
|
||||||
PublicFunctionToTensor.cpp
|
PublicFunctionToTensor.cpp
|
||||||
|
RefinePublicReturn.cpp
|
||||||
|
|
||||||
ADDITIONAL_HEADER_DIRS
|
ADDITIONAL_HEADER_DIRS
|
||||||
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Numpy/Transforms
|
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Numpy/Transforms
|
||||||
|
|
|
@ -0,0 +1,82 @@
|
||||||
|
//===- RefinePublicReturn.cpp ------------------------------------*- C++-*-===//
|
||||||
|
//
|
||||||
|
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "PassDetail.h"
|
||||||
|
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "npcomp/Dialect/Numpy/IR/NumpyOps.h"
|
||||||
|
#include "npcomp/Dialect/Numpy/Transforms/Passes.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::NPCOMP;
|
||||||
|
using namespace mlir::NPCOMP::Numpy;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class RefinePublicReturnPass
|
||||||
|
: public NumpyRefinePublicReturnBase<RefinePublicReturnPass> {
|
||||||
|
void runOnOperation() override {
|
||||||
|
auto module = getOperation();
|
||||||
|
module.walk([&](FuncOp func) {
|
||||||
|
if (func.getVisibility() != SymbolTable::Visibility::Public)
|
||||||
|
return;
|
||||||
|
if (func.isExternal())
|
||||||
|
return;
|
||||||
|
auto uses = SymbolTable::getSymbolUses(func, module);
|
||||||
|
if (!uses || uses->begin() != uses->end()) {
|
||||||
|
func.emitError() << "unimplemented: cannot refine public return for "
|
||||||
|
<< "for public function with uses";
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
|
rewriteSignature(func);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void rewriteSignature(FuncOp func) {
|
||||||
|
// Find the unique return op.
|
||||||
|
ReturnOp returnOp;
|
||||||
|
WalkResult walkResult = func.walk([&](ReturnOp op) {
|
||||||
|
if (returnOp)
|
||||||
|
return WalkResult::interrupt();
|
||||||
|
returnOp = op;
|
||||||
|
return WalkResult::advance();
|
||||||
|
});
|
||||||
|
if (walkResult.wasInterrupted()) {
|
||||||
|
func.emitError() << "unimplemented: refining returns for function with "
|
||||||
|
"more than one return op";
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the new operands. Either the original operand, or if there is a
|
||||||
|
// TensorStaticInfoCastOp then the pre-casted operand, which is presumed to
|
||||||
|
// have a more precise type.
|
||||||
|
SmallVector<Value> newOperands;
|
||||||
|
for (auto operand : returnOp.getOperands()) {
|
||||||
|
if (auto cast = operand.getDefiningOp<TensorStaticInfoCastOp>()) {
|
||||||
|
newOperands.push_back(cast.getOperand());
|
||||||
|
} else {
|
||||||
|
newOperands.push_back(operand);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
returnOp->setOperands(newOperands);
|
||||||
|
|
||||||
|
// Update the function type.
|
||||||
|
auto funcType = func.getType();
|
||||||
|
func.setType(FunctionType::get(funcType.getContext(), funcType.getInputs(),
|
||||||
|
ValueRange(newOperands).getTypes()));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
|
mlir::NPCOMP::Numpy::createRefinePublicReturnPass() {
|
||||||
|
return std::make_unique<RefinePublicReturnPass>();
|
||||||
|
}
|
|
@ -0,0 +1,51 @@
|
||||||
|
// RUN: npcomp-opt -split-input-file %s -verify-diagnostics -allow-unregistered-dialect -numpy-refine-public-return | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @basic(
|
||||||
|
// CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>) -> (tensor<?xf32>, i1, tensor<?xf32>) {
|
||||||
|
// CHECK: %[[CTRUE:.*]] = constant true
|
||||||
|
// CHECK: %[[CAST:.*]] = numpy.tensor_static_info_cast %[[ARG]] : tensor<?xf32> to tensor<*x!numpy.any_dtype>
|
||||||
|
// CHECK: return %[[ARG]], %[[CTRUE]], %[[ARG]] : tensor<?xf32>, i1, tensor<?xf32>
|
||||||
|
func @basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>, i1, tensor<*x!numpy.any_dtype>) {
|
||||||
|
%ctrue = std.constant true
|
||||||
|
%cast = numpy.tensor_static_info_cast %arg0 : tensor<?xf32> to tensor<*x!numpy.any_dtype>
|
||||||
|
return %arg0, %ctrue, %cast : tensor<?xf32>, i1, tensor<*x!numpy.any_dtype>
|
||||||
|
}
|
||||||
|
|
||||||
|
// No conversion on private function.
|
||||||
|
// CHECK-LABEL: func private @basic_private(
|
||||||
|
// CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>) -> (tensor<?xf32>, i1, tensor<*x!numpy.any_dtype>) {
|
||||||
|
// CHECK: %[[CTRUE:.*]] = constant true
|
||||||
|
// CHECK: %[[CASTED:.*]] = numpy.tensor_static_info_cast %[[ARG]] : tensor<?xf32> to tensor<*x!numpy.any_dtype>
|
||||||
|
// CHECK: return %[[ARG]], %[[CTRUE]], %[[CASTED]] : tensor<?xf32>, i1, tensor<*x!numpy.any_dtype>
|
||||||
|
func private @basic_private(%arg0: tensor<?xf32>) -> (tensor<?xf32>, i1, tensor<*x!numpy.any_dtype>) {
|
||||||
|
%ctrue = std.constant true
|
||||||
|
%cast = numpy.tensor_static_info_cast %arg0 : tensor<?xf32> to tensor<*x!numpy.any_dtype>
|
||||||
|
return %arg0, %ctrue, %cast : tensor<?xf32>, i1, tensor<*x!numpy.any_dtype>
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Call to public function.
|
||||||
|
// expected-error @+1 {{unimplemented}}
|
||||||
|
func @called(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||||
|
return %arg0 : tensor<*xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
func private @caller(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = call @called(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
return %0 : tensor<*xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Multiple returns.
|
||||||
|
// expected-error @+1 {{unimplemented}}
|
||||||
|
func @called(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||||
|
%ctrue = constant true
|
||||||
|
cond_br %ctrue, ^bb1, ^bb2
|
||||||
|
^bb1:
|
||||||
|
return %arg0 : tensor<*xf32>
|
||||||
|
^bb2:
|
||||||
|
return %arg0 : tensor<*xf32>
|
||||||
|
}
|
Loading…
Reference in New Issue