//===------------------------------------------------------------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "npcomp/JITRuntime/JITModule.h" #include "mlir/ExecutionEngine/CRunnerUtils.h" #include "mlir/ExecutionEngine/OptUtils.h" #include "npcomp/E2E/E2E.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/TargetSelect.h" using namespace npcomp; using namespace mlir; using llvm::Error; using llvm::Expected; using llvm::StringError; using llvm::Twine; /// Wrap a string into an llvm::StringError. static Error make_string_error(const Twine &message) { return llvm::make_error(message.str(), llvm::inconvertibleErrorCode()); } JITModule::JITModule() {} void JITModule::buildBackendCompilationPipeline(PassManager &pm, bool optimize) { NPCOMP::E2ELoweringPipelineOptions options; options.optimize = optimize; NPCOMP::createE2ELoweringPipeline(pm, options); } llvm::Expected> JITModule::fromCompiledModule(mlir::ModuleOp module, llvm::ArrayRef sharedLibs) { auto expectedEngine = ExecutionEngine::create( module, [](llvm::Module *) { return Error::success(); }, /*jitCodeGenOptLevel=*/llvm::None, llvm::to_vector<6>(sharedLibs)); if (!expectedEngine) return expectedEngine.takeError(); std::unique_ptr ret(new JITModule); ret->engine = std::move(*expectedEngine); // Here we abuse mlir::ExecutionEngine a bit. It technically returns a // function pointer, but here we look up a module descriptor. auto expectedAddress = ret->engine->lookup("__npcomp_module_descriptor"); if (!expectedAddress) return expectedAddress.takeError(); ret->descriptor = reinterpret_cast(*expectedAddress); return ret; } // Converter for bridging to npcomprt llvm-lookalike data structures. static npcomprt::StringRef toNpcomprt(llvm::StringRef s) { return npcomprt::StringRef(s.data(), s.size()); } template static npcomprt::ArrayRef toNpcomprt(llvm::ArrayRef a) { return npcomprt::ArrayRef(a.data(), a.size()); } template static npcomprt::MutableArrayRef toNpcomprt(llvm::MutableArrayRef a) { return npcomprt::MutableArrayRef(a.data(), a.size()); } llvm::Expected, 6>> JITModule::invoke(llvm::StringRef functionName, llvm::ArrayRef> inputs) { npcomprt::FunctionMetadata metadata; if (npcomprt::failed(npcomprt::getMetadata( descriptor, toNpcomprt(functionName), metadata))) return make_string_error("unknown function: " + Twine(functionName)); SmallVector, 6> outputs(metadata.numOutputs); if (metadata.numInputs != static_cast(inputs.size())) return make_string_error("invoking '" + Twine(functionName) + "': expected " + Twine(metadata.numInputs) + " inputs"); npcomprt::invoke( descriptor, toNpcomprt(functionName), toNpcomprt(inputs), toNpcomprt(llvm::makeMutableArrayRef(outputs.data(), outputs.size()))); return outputs; }