mirror of https://github.com/llvm/torch-mlir
55 lines
1.8 KiB
C++
55 lines
1.8 KiB
C++
//===------------------------------------------------------------*- C++ -*-===//
|
|
//
|
|
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#ifndef NPCOMP_JITRUNTIME_JITMODULE_H
|
|
#define NPCOMP_JITRUNTIME_JITMODULE_H
|
|
|
|
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
|
#include "mlir/IR/Module.h"
|
|
#include "npcomp/RefBackend/Runtime/UserAPI.h"
|
|
#include "llvm/ADT/ArrayRef.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
#include "llvm/ADT/StringRef.h"
|
|
#include "llvm/Support/Error.h"
|
|
|
|
#include <memory>
|
|
|
|
// Forward declaration only: this header refers to mlir::PassManager solely by
// reference, so the full class definition is not needed here.
namespace mlir { class PassManager; } // namespace mlir
|
|
|
|
namespace refback {
|
|
// Wrapper around refbackrt data structures and a JITted module, facilitating
|
|
// interaction.
|
|
class JITModule {
|
|
public:
|
|
/// Populates a PassManager with a pipeline that performs backend compilation.
|
|
/// The resulting module can be passed to fromCompiledModule().
|
|
static void buildBackendCompilationPipeline(mlir::PassManager &pm,
|
|
bool optimize = false);
|
|
|
|
/// Constructs a JITModule from a compiled Module.
|
|
/// The module should be the result of having run the backend compilation
|
|
/// pipeline successfully.
|
|
static llvm::Expected<std::unique_ptr<JITModule>>
|
|
fromCompiledModule(mlir::ModuleOp module,
|
|
llvm::ArrayRef<llvm::StringRef> sharedLibs);
|
|
|
|
llvm::Expected<llvm::SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6>>
|
|
invoke(llvm::StringRef functionName,
|
|
llvm::ArrayRef<refbackrt::Ref<refbackrt::Tensor>> inputs);
|
|
|
|
private:
|
|
JITModule();
|
|
std::unique_ptr<mlir::ExecutionEngine> engine;
|
|
refbackrt::ModuleDescriptor *descriptor;
|
|
};
|
|
} // namespace refback
|
|
|
|
#endif // NPCOMP_JITRUNTIME_JITMODULE_H
|