//===------------------------------------------------------------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "npcomp/RefBackend/JITHelpers/JITModule.h" #include "mlir/ExecutionEngine/CRunnerUtils.h" #include "mlir/ExecutionEngine/OptUtils.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "npcomp/RefBackend/RefBackend.h" #include using namespace refback; using namespace mlir; using llvm::Error; using llvm::Expected; using llvm::StringError; using llvm::Twine; /// Wrap a string into an llvm::StringError. static Error make_string_error(const Twine &message) { return llvm::make_error(message.str(), llvm::inconvertibleErrorCode()); } JITModule::JITModule() {} void JITModule::buildBackendCompilationPipeline(PassManager &pm, bool optimize) { NPCOMP::RefBackendLoweringPipelineOptions options; options.optimize = optimize; NPCOMP::createTCFRefBackendLoweringPipeline(pm, options); } llvm::Expected> JITModule::fromCompiledModule(mlir::ModuleOp module, llvm::ArrayRef sharedLibs) { // Ensure LLVM Dialect -> LLVM IR translations are available. mlir::registerLLVMDialectTranslation(*module->getContext()); // Build the JITModule. auto expectedEngine = ExecutionEngine::create( module, /*llvmModuleBuilder=*/nullptr, /*transformer=*/[](llvm::Module *) { return Error::success(); }, /*jitCodeGenOptLevel=*/llvm::None, llvm::to_vector<6>(sharedLibs)); if (!expectedEngine) return expectedEngine.takeError(); std::unique_ptr ret(new JITModule); ret->engine = std::move(*expectedEngine); // Here we abuse mlir::ExecutionEngine a bit. It technically returns a // function pointer, but here we look up a module descriptor. auto expectedAddress = ret->engine->lookup("__npcomp_module_descriptor"); if (!expectedAddress) return expectedAddress.takeError(); ret->descriptor = reinterpret_cast(*expectedAddress); return std::move(ret); } // Converter for bridging to refbackrt llvm-lookalike data structures. static refbackrt::StringRef toRefbackrt(llvm::StringRef s) { return refbackrt::StringRef(s.data(), s.size()); } template static refbackrt::ArrayRef toRefbackrt(llvm::ArrayRef a) { return refbackrt::ArrayRef(a.data(), a.size()); } template static refbackrt::MutableArrayRef toRefbackrt(llvm::MutableArrayRef a) { return refbackrt::MutableArrayRef(a.data(), a.size()); } static std::string stringifyShape(refbackrt::ArrayRef extents) { static constexpr char kDynamicDimAsString[] = "?"; std::stringstream ss; ss << "("; for (int i = 0, e = extents.size(); i < e; i++) { if (extents[i] < 0) ss << kDynamicDimAsString; else ss << extents[i]; if (i != e - 1) ss << "x"; } ss << ")"; return ss.str(); } llvm::Expected> JITModule::invoke(llvm::StringRef functionName, llvm::ArrayRef inputs) { refbackrt::FunctionMetadata metadata; if (refbackrt::failed(refbackrt::getMetadata( descriptor, toRefbackrt(functionName), metadata))) return make_string_error("unknown function: " + Twine(functionName)); SmallVector outputs(metadata.numOutputs); if (metadata.numInputs != static_cast(inputs.size())) return make_string_error("invoking '" + Twine(functionName) + "': expected " + Twine(metadata.numInputs) + " inputs"); // Verify user input types and shapes match what the compiler expects for (int i = 0; i < metadata.numInputs; i++) { auto &input = inputs[i]; auto &inputArgInfo = metadata.inputArgInfos[i]; if (refbackrt::failed(checkRtValueArgTypes(input, inputArgInfo))) return make_string_error( "invoking '" + Twine(functionName) + "': input argument type mismatch. actual (provided by user): " + Twine(inputs[i].tagKind().str()) + ", expected (from compiler): " + Twine(getArgTypeAsStringRef(inputArgInfo.argType).str())); if (refbackrt::failed(checkRtValueShapes(input, inputArgInfo))) return make_string_error( "invoking '" + Twine(functionName) + "': input shape mismatch (%arg" + Twine(i) + "). " + "actual (provided by user): " + stringifyShape(input.toTensor()->getExtents()) + ", expected (from compiler): " + stringifyShape(refbackrt::ArrayRef( inputArgInfo.extents.data(), inputArgInfo.rank))); } // Create the correct output RtValue based on FuncMetadata, // which contains the arg types (scalar, Tensor, etc.), element types (only // applicable if not scalar) and shapes (also only applicable if not scalar) // // Currently we have to give each RtValue an output type so that we know // how to pack / unpack the outputs properly across the ABI boundary in // refbackrt::invoke. As a result, we can't just rely on the default // construction of each output argument type (otherwise RtValue will have // Tag::kNone) currently without passing the ArgInfo structs down to the // Runtime level, so we deal with the output type creation here. for (int i = 0; i < metadata.numOutputs; i++) { outputs[i] = refbackrt::createRtValueFromOutputArgInfo(metadata.outputArgInfos[i]); } refbackrt::invoke( descriptor, toRefbackrt(functionName), toRefbackrt(inputs), toRefbackrt(llvm::makeMutableArrayRef(outputs.data(), outputs.size()))); return outputs; }