mirror of https://github.com/llvm/torch-mlir
Add NumpyPublicFunctionsToTensor pass.
* Rewrites public function signatures to operate on tensors (vs ndarray). * Most of our backends presume immutable tensors at public function boundaries.pull/1/head
parent
5ceb37c19b
commit
efbcf0aa44
|
@ -18,7 +18,6 @@ namespace NPCOMP {
|
||||||
namespace Basicpy {
|
namespace Basicpy {
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createFunctionTypeInferencePass();
|
std::unique_ptr<OperationPass<FuncOp>> createFunctionTypeInferencePass();
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createCPAFunctionTypeInferencePass();
|
|
||||||
|
|
||||||
} // namespace Basicpy
|
} // namespace Basicpy
|
||||||
} // namespace NPCOMP
|
} // namespace NPCOMP
|
||||||
|
|
|
@ -1 +1,2 @@
|
||||||
add_subdirectory(IR)
|
add_subdirectory(IR)
|
||||||
|
add_subdirectory(Transforms)
|
||||||
|
|
|
@ -10,6 +10,7 @@
|
||||||
#define NPCOMP_DIALECT_NUMPY_IR_NUMPY_DIALECT_H
|
#define NPCOMP_DIALECT_NUMPY_IR_NUMPY_DIALECT_H
|
||||||
|
|
||||||
#include "mlir/IR/Dialect.h"
|
#include "mlir/IR/Dialect.h"
|
||||||
|
#include "mlir/IR/StandardTypes.h"
|
||||||
#include "npcomp/Dialect/Common.h"
|
#include "npcomp/Dialect/Common.h"
|
||||||
#include "npcomp/Typing/Analysis/CPA/Interfaces.h"
|
#include "npcomp/Typing/Analysis/CPA/Interfaces.h"
|
||||||
|
|
||||||
|
@ -64,6 +65,9 @@ public:
|
||||||
/// unknown dimensions are -1.
|
/// unknown dimensions are -1.
|
||||||
llvm::Optional<ArrayRef<int64_t>> getOptionalShape();
|
llvm::Optional<ArrayRef<int64_t>> getOptionalShape();
|
||||||
|
|
||||||
|
/// Converts to an equivalent TensorType.
|
||||||
|
TensorType toTensorType();
|
||||||
|
|
||||||
// CPA::TypeMapInterface methods.
|
// CPA::TypeMapInterface methods.
|
||||||
Typing::CPA::TypeNode *mapToCPAType(Typing::CPA::Context &context);
|
Typing::CPA::TypeNode *mapToCPAType(Typing::CPA::Context &context);
|
||||||
};
|
};
|
||||||
|
|
|
@ -0,0 +1,6 @@
|
||||||
|
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||||
|
mlir_tablegen(Passes.h.inc -gen-pass-decls)
|
||||||
|
add_public_tablegen_target(NPCOMPNumpyPassIncGen)
|
||||||
|
|
||||||
|
add_mlir_doc(Passes -gen-pass-doc NPCOMPNumpyTransforms ./)
|
||||||
|
|
|
@ -0,0 +1,26 @@
|
||||||
|
//===------------------------------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef NPCOMP_DIALECT_NUMPY_TRANSFORMS_PASSES_H
|
||||||
|
#define NPCOMP_DIALECT_NUMPY_TRANSFORMS_PASSES_H
|
||||||
|
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace NPCOMP {
|
||||||
|
namespace Numpy {
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>> createPublicFunctionsToTensorPass();
|
||||||
|
|
||||||
|
} // namespace Numpy
|
||||||
|
} // namespace NPCOMP
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // NPCOMP_DIALECT_NUMPY_TRANSFORMS_PASSES_H
|
|
@ -0,0 +1,23 @@
|
||||||
|
//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef NPCOMP_NUMPY_PASSES
|
||||||
|
#define NPCOMP_NUMPY_PASSES
|
||||||
|
|
||||||
|
include "mlir/Pass/PassBase.td"
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// TypeInference
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def NumpyPublicFunctionsToTensor : Pass<"numpy-public-functions-to-tensor", "ModuleOp"> {
|
||||||
|
let summary = "Converts public functions to operate on tensors (instead of ndarray)";
|
||||||
|
let constructor = "mlir::NPCOMP::Numpy::createPublicFunctionsToTensorPass()";
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // NPCOMP_NUMPY_PASSES
|
|
@ -40,6 +40,7 @@ add_mlir_library(NPCOMPInitAll
|
||||||
NPCOMPBasicpyDialect
|
NPCOMPBasicpyDialect
|
||||||
NPCOMPBasicpyPasses
|
NPCOMPBasicpyPasses
|
||||||
NPCOMPNumpyDialect
|
NPCOMPNumpyDialect
|
||||||
|
NPCOMPNumpyPasses
|
||||||
NPCOMPTypingPasses
|
NPCOMPTypingPasses
|
||||||
|
|
||||||
${ALL_DEPENDS}
|
${ALL_DEPENDS}
|
||||||
|
|
|
@ -1 +1,2 @@
|
||||||
add_subdirectory(IR)
|
add_subdirectory(IR)
|
||||||
|
add_subdirectory(Transforms)
|
||||||
|
|
|
@ -210,6 +210,15 @@ llvm::Optional<ArrayRef<int64_t>> NdArrayType::getOptionalShape() {
|
||||||
return getImpl()->getOptionalShape();
|
return getImpl()->getOptionalShape();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TensorType NdArrayType::toTensorType() {
|
||||||
|
auto shape = getOptionalShape();
|
||||||
|
if (shape) {
|
||||||
|
return RankedTensorType::get(*shape, getDtype());
|
||||||
|
} else {
|
||||||
|
return UnrankedTensorType::get(getDtype());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Typing::CPA::TypeNode *
|
Typing::CPA::TypeNode *
|
||||||
NdArrayType::mapToCPAType(Typing::CPA::Context &context) {
|
NdArrayType::mapToCPAType(Typing::CPA::Context &context) {
|
||||||
llvm::Optional<Typing::CPA::TypeNode *> dtype;
|
llvm::Optional<Typing::CPA::TypeNode *> dtype;
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
add_mlir_conversion_library(NPCOMPNumpyPasses
|
||||||
|
PublicFunctionToTensor.cpp
|
||||||
|
|
||||||
|
ADDITIONAL_HEADER_DIRS
|
||||||
|
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Numpy/Transforms
|
||||||
|
|
||||||
|
DEPENDS
|
||||||
|
NPCOMPNumpyPassIncGen
|
||||||
|
|
||||||
|
LINK_COMPONENTS
|
||||||
|
Core
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRIR
|
||||||
|
MLIRPass
|
||||||
|
NPCOMPNumpyDialect
|
||||||
|
)
|
|
@ -0,0 +1,25 @@
|
||||||
|
//===- PassDetail.h - Pass details ------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef NPCOMP_DIALECT_NUMPY_TRANSFORMS_PASSDETAIL_H
|
||||||
|
#define NPCOMP_DIALECT_NUMPY_TRANSFORMS_PASSDETAIL_H
|
||||||
|
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace NPCOMP {
|
||||||
|
namespace Numpy {
|
||||||
|
|
||||||
|
#define GEN_PASS_CLASSES
|
||||||
|
#include "npcomp/Dialect/Numpy/Transforms/Passes.h.inc"
|
||||||
|
|
||||||
|
} // namespace Numpy
|
||||||
|
} // namespace NPCOMP
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
#endif // NPCOMP_DIALECT_NUMPY_TRANSFORMS_PASSDETAIL_H
|
|
@ -0,0 +1,98 @@
|
||||||
|
//===- PublicFunctionToTensor.cpp - Type inference passes --------*- C++-*-===//
|
||||||
|
//
|
||||||
|
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "PassDetail.h"
|
||||||
|
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/Module.h"
|
||||||
|
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
||||||
|
#include "npcomp/Dialect/Numpy/IR/NumpyOps.h"
|
||||||
|
#include "npcomp/Dialect/Numpy/Transforms/Passes.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::NPCOMP::Numpy;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class PublicFunctionsToTensorPass
|
||||||
|
: public NumpyPublicFunctionsToTensorBase<PublicFunctionsToTensorPass> {
|
||||||
|
void runOnOperation() override {
|
||||||
|
auto module = getOperation();
|
||||||
|
module.walk([&](FuncOp func) {
|
||||||
|
if (func.getVisibility() != SymbolTable::Visibility::Public)
|
||||||
|
return;
|
||||||
|
if (func.isExternal())
|
||||||
|
return;
|
||||||
|
auto uses = SymbolTable::getSymbolUses(func, module);
|
||||||
|
if (!uses || uses->begin() != uses->end()) {
|
||||||
|
func.emitWarning() << "unimplemented: cannot convert ndarray->tensor "
|
||||||
|
<< "signature for public function with uses";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
rewriteSignature(func);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void rewriteSignature(FuncOp func) {
|
||||||
|
auto &entryBlock = func.getBlocks().front();
|
||||||
|
auto funcType = func.getType();
|
||||||
|
auto loc = func.getLoc();
|
||||||
|
|
||||||
|
// Rewrite inputs.
|
||||||
|
auto builder = OpBuilder::atBlockBegin(&entryBlock);
|
||||||
|
auto inputTypes = llvm::to_vector<4>(funcType.getInputs());
|
||||||
|
for (unsigned i = 0; i < inputTypes.size(); ++i) {
|
||||||
|
auto arrayType = inputTypes[i].dyn_cast<NdArrayType>();
|
||||||
|
if (!arrayType)
|
||||||
|
continue;
|
||||||
|
Type tensorType = arrayType.toTensorType();
|
||||||
|
BlockArgument argument = entryBlock.getArgument(i);
|
||||||
|
argument.setType(tensorType);
|
||||||
|
auto createOp =
|
||||||
|
builder.create<CreateArrayFromTensorOp>(loc, arrayType, argument);
|
||||||
|
argument.replaceAllUsesExcept(createOp,
|
||||||
|
SmallPtrSet<Operation *, 1>{createOp});
|
||||||
|
inputTypes[i] = tensorType;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rewrite result signature.
|
||||||
|
auto resultTypes = llvm::to_vector<4>(funcType.getResults());
|
||||||
|
for (auto &resultType : resultTypes) {
|
||||||
|
auto arrayType = resultType.dyn_cast<NdArrayType>();
|
||||||
|
if (arrayType)
|
||||||
|
resultType = arrayType.toTensorType();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update signature.
|
||||||
|
funcType =
|
||||||
|
FunctionType::get(inputTypes, resultTypes, funcType.getContext());
|
||||||
|
func.setType(funcType);
|
||||||
|
|
||||||
|
// Rewrite all return terminators.
|
||||||
|
func.walk([&](ReturnOp term) {
|
||||||
|
OpBuilder builder(term);
|
||||||
|
for (unsigned i = 0; i < term.getNumOperands(); ++i) {
|
||||||
|
Value operand = term.getOperand(i);
|
||||||
|
auto arrayType = operand.getType().dyn_cast<NdArrayType>();
|
||||||
|
if (!arrayType)
|
||||||
|
continue;
|
||||||
|
Type tensorType = arrayType.toTensorType();
|
||||||
|
auto copyOp = builder.create<CopyToTensorOp>(loc, tensorType, operand);
|
||||||
|
term.setOperand(i, copyOp);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
|
mlir::NPCOMP::Numpy::createPublicFunctionsToTensorPass() {
|
||||||
|
return std::make_unique<PublicFunctionsToTensorPass>();
|
||||||
|
}
|
|
@ -12,6 +12,7 @@
|
||||||
#include "npcomp/Dialect/Basicpy/Transforms/Passes.h"
|
#include "npcomp/Dialect/Basicpy/Transforms/Passes.h"
|
||||||
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtDialect.h"
|
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtDialect.h"
|
||||||
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
||||||
|
#include "npcomp/Dialect/Numpy/Transforms/Passes.h"
|
||||||
#include "npcomp/Dialect/TCF/IR/TCFDialect.h"
|
#include "npcomp/Dialect/TCF/IR/TCFDialect.h"
|
||||||
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
|
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
|
||||||
#include "npcomp/Typing/Transforms/Passes.h"
|
#include "npcomp/Typing/Transforms/Passes.h"
|
||||||
|
@ -93,6 +94,8 @@ void mlir::NPCOMP::registerAllPasses() {
|
||||||
#define GEN_PASS_REGISTRATION
|
#define GEN_PASS_REGISTRATION
|
||||||
#include "npcomp/Dialect/Basicpy/Transforms/Passes.h.inc"
|
#include "npcomp/Dialect/Basicpy/Transforms/Passes.h.inc"
|
||||||
#define GEN_PASS_REGISTRATION
|
#define GEN_PASS_REGISTRATION
|
||||||
|
#include "npcomp/Dialect/Numpy/Transforms/Passes.h.inc"
|
||||||
|
#define GEN_PASS_REGISTRATION
|
||||||
#include "npcomp/Typing/Transforms/Passes.h.inc"
|
#include "npcomp/Typing/Transforms/Passes.h.inc"
|
||||||
registerDependencyPasses();
|
registerDependencyPasses();
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,45 @@
|
||||||
|
// RUN: npcomp-opt -split-input-file %s -verify-diagnostics -allow-unregistered-dialect -numpy-public-functions-to-tensor | FileCheck --dump-input=fail %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: legalConversion
|
||||||
|
module @legalConversion {
|
||||||
|
// CHECK: @f(%arg0: tensor<3x?xf32>, %arg1: i32, %arg2: tensor<*xf32>) -> (i32, tensor<3x?xf32>, tensor<*xf32>)
|
||||||
|
func @f(%arg0: !numpy.ndarray<[3,?]:f32>, %arg1: i32, %arg2: !numpy.ndarray<*:f32>) ->
|
||||||
|
(i32, !numpy.ndarray<[3,?]:f32>, !numpy.ndarray<*:f32>) {
|
||||||
|
// CHECK: %[[CREATE0:.+]] = numpy.create_array_from_tensor %arg0
|
||||||
|
// CHECK: %[[CREATE1:.+]] = numpy.create_array_from_tensor %arg2
|
||||||
|
// CHECK: %[[R0:.+]] = "unfoldable_arg0"(%[[CREATE0]]) : (!numpy.ndarray<[3,?]:f32>) -> !numpy.ndarray<[3,?]:f32>
|
||||||
|
// CHECK: %[[R1:.+]] = "unfoldable_arg1"(%[[CREATE1]]) : (!numpy.ndarray<*:f32>) -> !numpy.ndarray<*:f32>
|
||||||
|
%0 = "unfoldable_arg0"(%arg0) : (!numpy.ndarray<[3,?]:f32>) -> !numpy.ndarray<[3,?]:f32>
|
||||||
|
%1 = "unfoldable_arg1"(%arg2) : (!numpy.ndarray<*:f32>) -> !numpy.ndarray<*:f32>
|
||||||
|
// CHECK: %[[COPY0:.+]] = numpy.copy_to_tensor %[[R0]]
|
||||||
|
// CHECK: %[[COPY1:.+]] = numpy.copy_to_tensor %[[R1]]
|
||||||
|
// CHECK: return %arg1, %[[COPY0]], %[[COPY1]] : i32, tensor<3x?xf32>, tensor<*xf32>
|
||||||
|
return %arg1, %0, %1 : i32, !numpy.ndarray<[3,?]:f32>, !numpy.ndarray<*:f32>
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// CHECK-LABEL: @nonPublic
|
||||||
|
module @nonPublic {
|
||||||
|
// CHECK: @f(%arg0: !numpy.ndarray<[3,?]:f32>) -> !numpy.ndarray<[3,?]:f32>
|
||||||
|
func @f(%arg0: !numpy.ndarray<[3,?]:f32>) -> (!numpy.ndarray<[3,?]:f32>)
|
||||||
|
attributes { sym_visibility = "private" } {
|
||||||
|
return %arg0 : !numpy.ndarray<[3,?]:f32>
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// CHECK-LABEL: @called
|
||||||
|
module @called {
|
||||||
|
// CHECK: @f(%arg0: !numpy.ndarray<*:f32>) -> !numpy.ndarray<*:f32>
|
||||||
|
// expected-warning @+1 {{unimplemented: cannot convert}}
|
||||||
|
func @f(%arg0: !numpy.ndarray<*:f32>) -> !numpy.ndarray<*:f32> {
|
||||||
|
return %arg0 : !numpy.ndarray<*:f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
func @caller(%arg0: !numpy.ndarray<*:f32>) -> !numpy.ndarray<*:f32>
|
||||||
|
attributes { sym_visibility = "private" } {
|
||||||
|
%0 = call @f(%arg0) : (!numpy.ndarray<*:f32>) -> !numpy.ndarray<*:f32>
|
||||||
|
return %0 : !numpy.ndarray<*:f32>
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue