From 927546b3c56e0e7aa4195b0de758690c51a03ac0 Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Tue, 6 Apr 2021 15:23:14 -0700 Subject: [PATCH] Add RefinePublicReturn pass. This pass allows shape information to be propagated to return types, which is nontrivial and cannot be cleanly put anywhere else as it changes the public ABI, which is a concern that we want to keep concentrated in one place. --- .../npcomp/Dialect/Numpy/Transforms/Passes.h | 1 + .../npcomp/Dialect/Numpy/Transforms/Passes.td | 22 +++++ lib/Dialect/Numpy/Transforms/CMakeLists.txt | 1 + .../Numpy/Transforms/RefinePublicReturn.cpp | 82 +++++++++++++++++++ test/Dialect/Numpy/refine-public-return.mlir | 51 ++++++++++++ 5 files changed, 157 insertions(+) create mode 100644 lib/Dialect/Numpy/Transforms/RefinePublicReturn.cpp create mode 100644 test/Dialect/Numpy/refine-public-return.mlir diff --git a/include/npcomp/Dialect/Numpy/Transforms/Passes.h b/include/npcomp/Dialect/Numpy/Transforms/Passes.h index 18eab6a5d..f0f4abbd4 100644 --- a/include/npcomp/Dialect/Numpy/Transforms/Passes.h +++ b/include/npcomp/Dialect/Numpy/Transforms/Passes.h @@ -19,6 +19,7 @@ namespace Numpy { std::unique_ptr> createPublicFunctionsToTensorPass(); std::unique_ptr> createArrayToTensorPass(); +std::unique_ptr> createRefinePublicReturnPass(); } // namespace Numpy diff --git a/include/npcomp/Dialect/Numpy/Transforms/Passes.td b/include/npcomp/Dialect/Numpy/Transforms/Passes.td index d1bc895f6..d7ceaccd1 100644 --- a/include/npcomp/Dialect/Numpy/Transforms/Passes.td +++ b/include/npcomp/Dialect/Numpy/Transforms/Passes.td @@ -38,4 +38,26 @@ def NumpyArrayToTensor : Pass<"numpy-array-to-tensor", "FuncOp"> { } +def NumpyRefinePublicReturn : Pass<"numpy-refine-public-return", "ModuleOp"> { + let summary = "Refine public return"; + let constructor = "mlir::NPCOMP::Numpy::createRefinePublicReturnPass()"; + let description = [{ + Refines types of values return from public functions based on + intraprocedural information. + + This pass effectively encodes an assumption by the pass pipeline author that + the public calling convention of the module can have its types refined, + without causing ABI mismatches. This is frequently true -- for example, in + many systems, `tensor`, `tensor<3x3xf32>` and + `tensor<*x!numpy.any_dtype>` are all the same data structure on calling + convention boundaries. + + This pass is expected to run after shape refinement has occurred to + otherwise resolve shapes, and is currently mainly useful to convert + rank/dtype-erased function boundaries to ranked, dtyped code for + compiler backends. + }]; +} + + #endif // NPCOMP_NUMPY_PASSES diff --git a/lib/Dialect/Numpy/Transforms/CMakeLists.txt b/lib/Dialect/Numpy/Transforms/CMakeLists.txt index f4591faf4..36101c474 100644 --- a/lib/Dialect/Numpy/Transforms/CMakeLists.txt +++ b/lib/Dialect/Numpy/Transforms/CMakeLists.txt @@ -2,6 +2,7 @@ add_npcomp_conversion_library(NPCOMPNumpyPasses ArrayToTensor.cpp Passes.cpp PublicFunctionToTensor.cpp + RefinePublicReturn.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Numpy/Transforms diff --git a/lib/Dialect/Numpy/Transforms/RefinePublicReturn.cpp b/lib/Dialect/Numpy/Transforms/RefinePublicReturn.cpp new file mode 100644 index 000000000..af27bb387 --- /dev/null +++ b/lib/Dialect/Numpy/Transforms/RefinePublicReturn.cpp @@ -0,0 +1,82 @@ +//===- RefinePublicReturn.cpp ------------------------------------*- C++-*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" + +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "npcomp/Dialect/Numpy/IR/NumpyOps.h" +#include "npcomp/Dialect/Numpy/Transforms/Passes.h" + +using namespace mlir; +using namespace mlir::NPCOMP; +using namespace mlir::NPCOMP::Numpy; + +namespace { + +class RefinePublicReturnPass + : public NumpyRefinePublicReturnBase { + void runOnOperation() override { + auto module = getOperation(); + module.walk([&](FuncOp func) { + if (func.getVisibility() != SymbolTable::Visibility::Public) + return; + if (func.isExternal()) + return; + auto uses = SymbolTable::getSymbolUses(func, module); + if (!uses || uses->begin() != uses->end()) { + func.emitError() << "unimplemented: cannot refine public return for " + << "for public function with uses"; + return signalPassFailure(); + } + rewriteSignature(func); + }); + } + + void rewriteSignature(FuncOp func) { + // Find the unique return op. + ReturnOp returnOp; + WalkResult walkResult = func.walk([&](ReturnOp op) { + if (returnOp) + return WalkResult::interrupt(); + returnOp = op; + return WalkResult::advance(); + }); + if (walkResult.wasInterrupted()) { + func.emitError() << "unimplemented: refining returns for function with " + "more than one return op"; + return signalPassFailure(); + } + + // Get the new operands. Either the original operand, or if there is a + // TensorStaticInfoCastOp then the pre-casted operand, which is presumed to + // have a more precise type. + SmallVector newOperands; + for (auto operand : returnOp.getOperands()) { + if (auto cast = operand.getDefiningOp()) { + newOperands.push_back(cast.getOperand()); + } else { + newOperands.push_back(operand); + } + } + returnOp->setOperands(newOperands); + + // Update the function type. + auto funcType = func.getType(); + func.setType(FunctionType::get(funcType.getContext(), funcType.getInputs(), + ValueRange(newOperands).getTypes())); + } +}; + +} // namespace + +std::unique_ptr> +mlir::NPCOMP::Numpy::createRefinePublicReturnPass() { + return std::make_unique(); +} diff --git a/test/Dialect/Numpy/refine-public-return.mlir b/test/Dialect/Numpy/refine-public-return.mlir new file mode 100644 index 000000000..0c9d69de5 --- /dev/null +++ b/test/Dialect/Numpy/refine-public-return.mlir @@ -0,0 +1,51 @@ +// RUN: npcomp-opt -split-input-file %s -verify-diagnostics -allow-unregistered-dialect -numpy-refine-public-return | FileCheck %s + +// CHECK-LABEL: func @basic( +// CHECK-SAME: %[[ARG:.*]]: tensor) -> (tensor, i1, tensor) { +// CHECK: %[[CTRUE:.*]] = constant true +// CHECK: %[[CAST:.*]] = numpy.tensor_static_info_cast %[[ARG]] : tensor to tensor<*x!numpy.any_dtype> +// CHECK: return %[[ARG]], %[[CTRUE]], %[[ARG]] : tensor, i1, tensor +func @basic(%arg0: tensor) -> (tensor, i1, tensor<*x!numpy.any_dtype>) { + %ctrue = std.constant true + %cast = numpy.tensor_static_info_cast %arg0 : tensor to tensor<*x!numpy.any_dtype> + return %arg0, %ctrue, %cast : tensor, i1, tensor<*x!numpy.any_dtype> +} + +// No conversion on private function. +// CHECK-LABEL: func private @basic_private( +// CHECK-SAME: %[[ARG:.*]]: tensor) -> (tensor, i1, tensor<*x!numpy.any_dtype>) { +// CHECK: %[[CTRUE:.*]] = constant true +// CHECK: %[[CASTED:.*]] = numpy.tensor_static_info_cast %[[ARG]] : tensor to tensor<*x!numpy.any_dtype> +// CHECK: return %[[ARG]], %[[CTRUE]], %[[CASTED]] : tensor, i1, tensor<*x!numpy.any_dtype> +func private @basic_private(%arg0: tensor) -> (tensor, i1, tensor<*x!numpy.any_dtype>) { + %ctrue = std.constant true + %cast = numpy.tensor_static_info_cast %arg0 : tensor to tensor<*x!numpy.any_dtype> + return %arg0, %ctrue, %cast : tensor, i1, tensor<*x!numpy.any_dtype> +} + + +// ----- + +// Call to public function. +// expected-error @+1 {{unimplemented}} +func @called(%arg0: tensor<*xf32>) -> tensor<*xf32> { + return %arg0 : tensor<*xf32> +} + +func private @caller(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %0 = call @called(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +// Multiple returns. +// expected-error @+1 {{unimplemented}} +func @called(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %ctrue = constant true + cond_br %ctrue, ^bb1, ^bb2 +^bb1: + return %arg0 : tensor<*xf32> +^bb2: + return %arg0 : tensor<*xf32> +}