diff --git a/include/npcomp/Dialect/Numpy/Transforms/Passes.h b/include/npcomp/Dialect/Numpy/Transforms/Passes.h
index 18eab6a5d..f0f4abbd4 100644
--- a/include/npcomp/Dialect/Numpy/Transforms/Passes.h
+++ b/include/npcomp/Dialect/Numpy/Transforms/Passes.h
@@ -19,6 +19,7 @@ namespace Numpy {
 
 std::unique_ptr<OperationPass<ModuleOp>> createPublicFunctionsToTensorPass();
 std::unique_ptr<OperationPass<FuncOp>> createArrayToTensorPass();
+std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();
 
 } // namespace Numpy
 
diff --git a/include/npcomp/Dialect/Numpy/Transforms/Passes.td b/include/npcomp/Dialect/Numpy/Transforms/Passes.td
index d1bc895f6..d7ceaccd1 100644
--- a/include/npcomp/Dialect/Numpy/Transforms/Passes.td
+++ b/include/npcomp/Dialect/Numpy/Transforms/Passes.td
@@ -38,4 +38,26 @@ def NumpyArrayToTensor : Pass<"numpy-array-to-tensor", "FuncOp"> {
 }
 
+def NumpyRefinePublicReturn : Pass<"numpy-refine-public-return", "ModuleOp"> {
+  let summary = "Refine public return";
+  let constructor = "mlir::NPCOMP::Numpy::createRefinePublicReturnPass()";
+  let description = [{
+    Refines types of values returned from public functions based on
+    intraprocedural information.
+
+    This pass effectively encodes an assumption by the pass pipeline author that
+    the public calling convention of the module can have its types refined,
+    without causing ABI mismatches. This is frequently true -- for example, in
+    many systems, `tensor<?x?xf32>`, `tensor<3x3xf32>` and
+    `tensor<*x!numpy.any_dtype>` are all the same data structure on calling
+    convention boundaries.
+
+    This pass is expected to run after shape refinement has occurred to
+    otherwise resolve shapes, and is currently mainly useful to convert
+    rank/dtype-erased function boundaries to ranked, dtyped code for
+    compiler backends.
+  }];
+}
+
+
 #endif // NPCOMP_NUMPY_PASSES
 
diff --git a/lib/Dialect/Numpy/Transforms/CMakeLists.txt b/lib/Dialect/Numpy/Transforms/CMakeLists.txt
index f4591faf4..36101c474 100644
--- a/lib/Dialect/Numpy/Transforms/CMakeLists.txt
+++ b/lib/Dialect/Numpy/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_npcomp_conversion_library(NPCOMPNumpyPasses
   ArrayToTensor.cpp
   Passes.cpp
   PublicFunctionToTensor.cpp
+  RefinePublicReturn.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Numpy/Transforms
diff --git a/lib/Dialect/Numpy/Transforms/RefinePublicReturn.cpp b/lib/Dialect/Numpy/Transforms/RefinePublicReturn.cpp
new file mode 100644
index 000000000..af27bb387
--- /dev/null
+++ b/lib/Dialect/Numpy/Transforms/RefinePublicReturn.cpp
@@ -0,0 +1,82 @@
+//===- RefinePublicReturn.cpp ------------------------------------*- C++-*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "npcomp/Dialect/Numpy/IR/NumpyOps.h"
+#include "npcomp/Dialect/Numpy/Transforms/Passes.h"
+
+using namespace mlir;
+using namespace mlir::NPCOMP;
+using namespace mlir::NPCOMP::Numpy;
+
+namespace {
+
+class RefinePublicReturnPass
+    : public NumpyRefinePublicReturnBase<RefinePublicReturnPass> {
+  void runOnOperation() override {
+    auto module = getOperation();
+    module.walk([&](FuncOp func) {
+      if (func.getVisibility() != SymbolTable::Visibility::Public)
+        return;
+      if (func.isExternal())
+        return;
+      auto uses = SymbolTable::getSymbolUses(func, module);
+      if (!uses || uses->begin() != uses->end()) {
+        func.emitError() << "unimplemented: cannot refine public return "
+                         << "for public function with uses";
+        return signalPassFailure();
+      }
+      rewriteSignature(func);
+    });
+  }
+
+  void rewriteSignature(FuncOp func) {
+    // Find the unique return op.
+    ReturnOp returnOp;
+    WalkResult walkResult = func.walk([&](ReturnOp op) {
+      if (returnOp)
+        return WalkResult::interrupt();
+      returnOp = op;
+      return WalkResult::advance();
+    });
+    if (walkResult.wasInterrupted()) {
+      func.emitError() << "unimplemented: refining returns for function with "
+                          "more than one return op";
+      return signalPassFailure();
+    }
+
+    // Get the new operands. Either the original operand, or if there is a
+    // TensorStaticInfoCastOp then the pre-casted operand, which is presumed to
+    // have a more precise type.
+    SmallVector<Value, 6> newOperands;
+    for (auto operand : returnOp.getOperands()) {
+      if (auto cast = operand.getDefiningOp<TensorStaticInfoCastOp>()) {
+        newOperands.push_back(cast.getOperand());
+      } else {
+        newOperands.push_back(operand);
+      }
+    }
+    returnOp->setOperands(newOperands);
+
+    // Update the function type.
+    auto funcType = func.getType();
+    func.setType(FunctionType::get(funcType.getContext(), funcType.getInputs(),
+                                   ValueRange(newOperands).getTypes()));
+  }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::NPCOMP::Numpy::createRefinePublicReturnPass() {
+  return std::make_unique<RefinePublicReturnPass>();
+}
diff --git a/test/Dialect/Numpy/refine-public-return.mlir b/test/Dialect/Numpy/refine-public-return.mlir
new file mode 100644
index 000000000..0c9d69de5
--- /dev/null
+++ b/test/Dialect/Numpy/refine-public-return.mlir
@@ -0,0 +1,51 @@
+// RUN: npcomp-opt -split-input-file %s -verify-diagnostics -allow-unregistered-dialect -numpy-refine-public-return | FileCheck %s
+
+// CHECK-LABEL: func @basic(
+// CHECK-SAME: %[[ARG:.*]]: tensor<?x?xf32>) -> (tensor<?x?xf32>, i1, tensor<?x?xf32>) {
+// CHECK: %[[CTRUE:.*]] = constant true
+// CHECK: %[[CAST:.*]] = numpy.tensor_static_info_cast %[[ARG]] : tensor<?x?xf32> to tensor<*x!numpy.any_dtype>
+// CHECK: return %[[ARG]], %[[CTRUE]], %[[ARG]] : tensor<?x?xf32>, i1, tensor<?x?xf32>
+func @basic(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>, i1, tensor<*x!numpy.any_dtype>) {
+  %ctrue = std.constant true
+  %cast = numpy.tensor_static_info_cast %arg0 : tensor<?x?xf32> to tensor<*x!numpy.any_dtype>
+  return %arg0, %ctrue, %cast : tensor<?x?xf32>, i1, tensor<*x!numpy.any_dtype>
+}
+
+// No conversion on private function.
+// CHECK-LABEL: func private @basic_private(
+// CHECK-SAME: %[[ARG:.*]]: tensor<?x?xf32>) -> (tensor<?x?xf32>, i1, tensor<*x!numpy.any_dtype>) {
+// CHECK: %[[CTRUE:.*]] = constant true
+// CHECK: %[[CASTED:.*]] = numpy.tensor_static_info_cast %[[ARG]] : tensor<?x?xf32> to tensor<*x!numpy.any_dtype>
+// CHECK: return %[[ARG]], %[[CTRUE]], %[[CASTED]] : tensor<?x?xf32>, i1, tensor<*x!numpy.any_dtype>
+func private @basic_private(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>, i1, tensor<*x!numpy.any_dtype>) {
+  %ctrue = std.constant true
+  %cast = numpy.tensor_static_info_cast %arg0 : tensor<?x?xf32> to tensor<*x!numpy.any_dtype>
+  return %arg0, %ctrue, %cast : tensor<?x?xf32>, i1, tensor<*x!numpy.any_dtype>
+}
+
+
+// -----
+
+// Call to public function.
+// expected-error @+1 {{unimplemented}}
+func @called(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  return %arg0 : tensor<*xf32>
+}
+
+func private @caller(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = call @called(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// Multiple returns.
+// expected-error @+1 {{unimplemented}}
+func @called(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  %ctrue = constant true
+  cond_br %ctrue, ^bb1, ^bb2
+^bb1:
+  return %arg0 : tensor<*xf32>
+^bb2:
+  return %arg0 : tensor<*xf32>
+}