mirror of https://github.com/llvm/torch-mlir
Add NumpyPublicFunctionsToTensor pass.
* Rewrites public function signatures to operate on tensors (vs ndarray). * Most of our backends presume immutable tensors at public function boundaries.pull/1/head
parent
5ceb37c19b
commit
efbcf0aa44
|
@ -18,7 +18,6 @@ namespace NPCOMP {
|
|||
namespace Basicpy {
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createFunctionTypeInferencePass();
|
||||
std::unique_ptr<OperationPass<FuncOp>> createCPAFunctionTypeInferencePass();
|
||||
|
||||
} // namespace Basicpy
|
||||
} // namespace NPCOMP
|
||||
|
|
|
@ -1 +1,2 @@
|
|||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#define NPCOMP_DIALECT_NUMPY_IR_NUMPY_DIALECT_H
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "npcomp/Dialect/Common.h"
|
||||
#include "npcomp/Typing/Analysis/CPA/Interfaces.h"
|
||||
|
||||
|
@ -64,6 +65,9 @@ public:
|
|||
/// unknown dimensions are -1.
|
||||
llvm::Optional<ArrayRef<int64_t>> getOptionalShape();
|
||||
|
||||
/// Converts to an equivalent TensorType.
|
||||
TensorType toTensorType();
|
||||
|
||||
// CPA::TypeMapInterface methods.
|
||||
Typing::CPA::TypeNode *mapToCPAType(Typing::CPA::Context &context);
|
||||
};
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls)
|
||||
add_public_tablegen_target(NPCOMPNumpyPassIncGen)
|
||||
|
||||
add_mlir_doc(Passes -gen-pass-doc NPCOMPNumpyTransforms ./)
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
//===------------------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_DIALECT_NUMPY_TRANSFORMS_PASSES_H
|
||||
#define NPCOMP_DIALECT_NUMPY_TRANSFORMS_PASSES_H
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
namespace Numpy {
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createPublicFunctionsToTensorPass();
|
||||
|
||||
} // namespace Numpy
|
||||
} // namespace NPCOMP
|
||||
} // namespace mlir
|
||||
|
||||
#endif // NPCOMP_DIALECT_NUMPY_TRANSFORMS_PASSES_H
|
|
@ -0,0 +1,23 @@
|
|||
//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_NUMPY_PASSES
|
||||
#define NPCOMP_NUMPY_PASSES
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TypeInference
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def NumpyPublicFunctionsToTensor : Pass<"numpy-public-functions-to-tensor", "ModuleOp"> {
|
||||
let summary = "Converts public functions to operate on tensors (instead of ndarray)";
|
||||
let constructor = "mlir::NPCOMP::Numpy::createPublicFunctionsToTensorPass()";
|
||||
}
|
||||
|
||||
#endif // NPCOMP_NUMPY_PASSES
|
|
@ -40,6 +40,7 @@ add_mlir_library(NPCOMPInitAll
|
|||
NPCOMPBasicpyDialect
|
||||
NPCOMPBasicpyPasses
|
||||
NPCOMPNumpyDialect
|
||||
NPCOMPNumpyPasses
|
||||
NPCOMPTypingPasses
|
||||
|
||||
${ALL_DEPENDS}
|
||||
|
|
|
@ -1 +1,2 @@
|
|||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
||||
|
|
|
@ -210,6 +210,15 @@ llvm::Optional<ArrayRef<int64_t>> NdArrayType::getOptionalShape() {
|
|||
return getImpl()->getOptionalShape();
|
||||
}
|
||||
|
||||
TensorType NdArrayType::toTensorType() {
|
||||
auto shape = getOptionalShape();
|
||||
if (shape) {
|
||||
return RankedTensorType::get(*shape, getDtype());
|
||||
} else {
|
||||
return UnrankedTensorType::get(getDtype());
|
||||
}
|
||||
}
|
||||
|
||||
Typing::CPA::TypeNode *
|
||||
NdArrayType::mapToCPAType(Typing::CPA::Context &context) {
|
||||
llvm::Optional<Typing::CPA::TypeNode *> dtype;
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
add_mlir_conversion_library(NPCOMPNumpyPasses
|
||||
PublicFunctionToTensor.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Numpy/Transforms
|
||||
|
||||
DEPENDS
|
||||
NPCOMPNumpyPassIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRPass
|
||||
NPCOMPNumpyDialect
|
||||
)
|
|
@ -0,0 +1,25 @@
|
|||
//===- PassDetail.h - Pass details ------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_DIALECT_NUMPY_TRANSFORMS_PASSDETAIL_H
|
||||
#define NPCOMP_DIALECT_NUMPY_TRANSFORMS_PASSDETAIL_H
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
namespace Numpy {
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "npcomp/Dialect/Numpy/Transforms/Passes.h.inc"
|
||||
|
||||
} // namespace Numpy
|
||||
} // namespace NPCOMP
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // NPCOMP_DIALECT_NUMPY_TRANSFORMS_PASSDETAIL_H
|
|
@ -0,0 +1,98 @@
|
|||
//===- PublicFunctionToTensor.cpp - Type inference passes --------*- C++-*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PassDetail.h"
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyOps.h"
|
||||
#include "npcomp/Dialect/Numpy/Transforms/Passes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP::Numpy;
|
||||
|
||||
namespace {
|
||||
|
||||
class PublicFunctionsToTensorPass
|
||||
: public NumpyPublicFunctionsToTensorBase<PublicFunctionsToTensorPass> {
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
module.walk([&](FuncOp func) {
|
||||
if (func.getVisibility() != SymbolTable::Visibility::Public)
|
||||
return;
|
||||
if (func.isExternal())
|
||||
return;
|
||||
auto uses = SymbolTable::getSymbolUses(func, module);
|
||||
if (!uses || uses->begin() != uses->end()) {
|
||||
func.emitWarning() << "unimplemented: cannot convert ndarray->tensor "
|
||||
<< "signature for public function with uses";
|
||||
return;
|
||||
}
|
||||
rewriteSignature(func);
|
||||
});
|
||||
}
|
||||
|
||||
void rewriteSignature(FuncOp func) {
|
||||
auto &entryBlock = func.getBlocks().front();
|
||||
auto funcType = func.getType();
|
||||
auto loc = func.getLoc();
|
||||
|
||||
// Rewrite inputs.
|
||||
auto builder = OpBuilder::atBlockBegin(&entryBlock);
|
||||
auto inputTypes = llvm::to_vector<4>(funcType.getInputs());
|
||||
for (unsigned i = 0; i < inputTypes.size(); ++i) {
|
||||
auto arrayType = inputTypes[i].dyn_cast<NdArrayType>();
|
||||
if (!arrayType)
|
||||
continue;
|
||||
Type tensorType = arrayType.toTensorType();
|
||||
BlockArgument argument = entryBlock.getArgument(i);
|
||||
argument.setType(tensorType);
|
||||
auto createOp =
|
||||
builder.create<CreateArrayFromTensorOp>(loc, arrayType, argument);
|
||||
argument.replaceAllUsesExcept(createOp,
|
||||
SmallPtrSet<Operation *, 1>{createOp});
|
||||
inputTypes[i] = tensorType;
|
||||
}
|
||||
|
||||
// Rewrite result signature.
|
||||
auto resultTypes = llvm::to_vector<4>(funcType.getResults());
|
||||
for (auto &resultType : resultTypes) {
|
||||
auto arrayType = resultType.dyn_cast<NdArrayType>();
|
||||
if (arrayType)
|
||||
resultType = arrayType.toTensorType();
|
||||
}
|
||||
|
||||
// Update signature.
|
||||
funcType =
|
||||
FunctionType::get(inputTypes, resultTypes, funcType.getContext());
|
||||
func.setType(funcType);
|
||||
|
||||
// Rewrite all return terminators.
|
||||
func.walk([&](ReturnOp term) {
|
||||
OpBuilder builder(term);
|
||||
for (unsigned i = 0; i < term.getNumOperands(); ++i) {
|
||||
Value operand = term.getOperand(i);
|
||||
auto arrayType = operand.getType().dyn_cast<NdArrayType>();
|
||||
if (!arrayType)
|
||||
continue;
|
||||
Type tensorType = arrayType.toTensorType();
|
||||
auto copyOp = builder.create<CopyToTensorOp>(loc, tensorType, operand);
|
||||
term.setOperand(i, copyOp);
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::NPCOMP::Numpy::createPublicFunctionsToTensorPass() {
|
||||
return std::make_unique<PublicFunctionsToTensorPass>();
|
||||
}
|
|
@ -12,6 +12,7 @@
|
|||
#include "npcomp/Dialect/Basicpy/Transforms/Passes.h"
|
||||
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtDialect.h"
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
||||
#include "npcomp/Dialect/Numpy/Transforms/Passes.h"
|
||||
#include "npcomp/Dialect/TCF/IR/TCFDialect.h"
|
||||
#include "npcomp/Dialect/TCP/IR/TCPDialect.h"
|
||||
#include "npcomp/Typing/Transforms/Passes.h"
|
||||
|
@ -93,6 +94,8 @@ void mlir::NPCOMP::registerAllPasses() {
|
|||
#define GEN_PASS_REGISTRATION
|
||||
#include "npcomp/Dialect/Basicpy/Transforms/Passes.h.inc"
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "npcomp/Dialect/Numpy/Transforms/Passes.h.inc"
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "npcomp/Typing/Transforms/Passes.h.inc"
|
||||
registerDependencyPasses();
|
||||
}
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
// RUN: npcomp-opt -split-input-file %s -verify-diagnostics -allow-unregistered-dialect -numpy-public-functions-to-tensor | FileCheck --dump-input=fail %s
|
||||
|
||||
// CHECK-LABEL: legalConversion
|
||||
module @legalConversion {
|
||||
// CHECK: @f(%arg0: tensor<3x?xf32>, %arg1: i32, %arg2: tensor<*xf32>) -> (i32, tensor<3x?xf32>, tensor<*xf32>)
|
||||
func @f(%arg0: !numpy.ndarray<[3,?]:f32>, %arg1: i32, %arg2: !numpy.ndarray<*:f32>) ->
|
||||
(i32, !numpy.ndarray<[3,?]:f32>, !numpy.ndarray<*:f32>) {
|
||||
// CHECK: %[[CREATE0:.+]] = numpy.create_array_from_tensor %arg0
|
||||
// CHECK: %[[CREATE1:.+]] = numpy.create_array_from_tensor %arg2
|
||||
// CHECK: %[[R0:.+]] = "unfoldable_arg0"(%[[CREATE0]]) : (!numpy.ndarray<[3,?]:f32>) -> !numpy.ndarray<[3,?]:f32>
|
||||
// CHECK: %[[R1:.+]] = "unfoldable_arg1"(%[[CREATE1]]) : (!numpy.ndarray<*:f32>) -> !numpy.ndarray<*:f32>
|
||||
%0 = "unfoldable_arg0"(%arg0) : (!numpy.ndarray<[3,?]:f32>) -> !numpy.ndarray<[3,?]:f32>
|
||||
%1 = "unfoldable_arg1"(%arg2) : (!numpy.ndarray<*:f32>) -> !numpy.ndarray<*:f32>
|
||||
// CHECK: %[[COPY0:.+]] = numpy.copy_to_tensor %[[R0]]
|
||||
// CHECK: %[[COPY1:.+]] = numpy.copy_to_tensor %[[R1]]
|
||||
// CHECK: return %arg1, %[[COPY0]], %[[COPY1]] : i32, tensor<3x?xf32>, tensor<*xf32>
|
||||
return %arg1, %0, %1 : i32, !numpy.ndarray<[3,?]:f32>, !numpy.ndarray<*:f32>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: @nonPublic
|
||||
module @nonPublic {
|
||||
// CHECK: @f(%arg0: !numpy.ndarray<[3,?]:f32>) -> !numpy.ndarray<[3,?]:f32>
|
||||
func @f(%arg0: !numpy.ndarray<[3,?]:f32>) -> (!numpy.ndarray<[3,?]:f32>)
|
||||
attributes { sym_visibility = "private" } {
|
||||
return %arg0 : !numpy.ndarray<[3,?]:f32>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: @called
|
||||
module @called {
|
||||
// CHECK: @f(%arg0: !numpy.ndarray<*:f32>) -> !numpy.ndarray<*:f32>
|
||||
// expected-warning @+1 {{unimplemented: cannot convert}}
|
||||
func @f(%arg0: !numpy.ndarray<*:f32>) -> !numpy.ndarray<*:f32> {
|
||||
return %arg0 : !numpy.ndarray<*:f32>
|
||||
}
|
||||
|
||||
func @caller(%arg0: !numpy.ndarray<*:f32>) -> !numpy.ndarray<*:f32>
|
||||
attributes { sym_visibility = "private" } {
|
||||
%0 = call @f(%arg0) : (!numpy.ndarray<*:f32>) -> !numpy.ndarray<*:f32>
|
||||
return %0 : !numpy.ndarray<*:f32>
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue