//===------------------------------------------------------------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Utility binary for compiling and running code through the npcomp // compiler/runtime stack. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/AsmState.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" #include "mlir/Parser.h" #include "mlir/Pass/PassManager.h" #include "mlir/Target/LLVMIR.h" #include "npcomp-c/InitLLVM.h" #include "npcomp/InitAll.h" #include "npcomp/RefBackend/JITHelpers/JITModule.h" #include "llvm/Support/InitLLVM.h" using namespace mlir; using llvm::Error; using llvm::Expected; using llvm::StringError; using llvm::Twine; /// Wrap a string into an llvm::StringError. static Error make_string_error(const Twine &message) { return llvm::make_error(message.str(), llvm::inconvertibleErrorCode()); } static Expected> convertAttrToTensor(Attribute attr) { auto type = attr.getType().dyn_cast(); if (!type) return make_string_error("unhandled argument type; must be a tensor type"); auto extents = llvm::to_vector<6>(llvm::map_range( type.getShape(), [](int64_t x) { return static_cast(x); })); auto elementType = type.getElementType(); auto denseFp = attr.dyn_cast(); if (denseFp) { if (elementType.isF32()) { auto values = llvm::to_vector<100>(llvm::map_range( denseFp, [](APFloat f) { return f.convertToFloat(); })); return refbackrt::Tensor::create( refbackrt::ArrayRef(extents.data(), extents.size()), refbackrt::ElementType::F32, static_cast(values.data())); } } else { return make_string_error("unhandled argument; must be dense floating-point"); } return make_string_error("unhandled argument"); } static Expected> createInputs(ArrayRef argValues) { MLIRContext context; SmallVector ret; for (auto argValue : argValues) { auto attr = parseAttribute(argValue, &context); if (!attr) return make_string_error(Twine("could not parse arg value: ") + argValue); // TODO(brycearden): Handle multiple input types auto expectedTensor = convertAttrToTensor(attr); if (!expectedTensor) return expectedTensor.takeError(); ret.push_back(std::move(*expectedTensor)); } return ret; } static Type convertToMLIRType(refbackrt::ElementType type, Builder &builder) { switch (type) { case refbackrt::ElementType::F32: return builder.getF32Type(); } llvm_unreachable("unsupported dtype"); } static RankedTensorType getCorrespondingMLIRTensorType(refbackrt::Tensor &tensor, Builder &builder) { auto elementType = convertToMLIRType(tensor.getElementType(), builder); SmallVector extents; for (int i = 0, e = tensor.getRank(); i < e; i++) extents.push_back(tensor.getExtent(i)); return RankedTensorType::get(extents, elementType); } static Attribute convertToMLIRAttribute(refbackrt::Tensor &tensor, Builder &builder) { RankedTensorType type = getCorrespondingMLIRTensorType(tensor, builder); switch (tensor.getElementType()) { case refbackrt::ElementType::F32: { SmallVector values; auto *basePtr = tensor.getData(); for (int i = 0, e = type.getNumElements(); i < e; i++) values.push_back(basePtr[i]); return DenseFPElementsAttr::get(type, values); } } llvm_unreachable("unsupported dtype"); } static void printOutput(refbackrt::Tensor &tensor, llvm::raw_ostream &os) { MLIRContext context; Builder builder(&context); auto attr = convertToMLIRAttribute(tensor, builder); attr.print(os); } static void printOutputs(ArrayRef outputs, llvm::raw_ostream &os) { for (auto output : llvm::enumerate(outputs)) { assert(output.value().isTensor() && "only tensor outputs are supported."); os << "output #" << output.index() << ": "; printOutput(*output.value().toTensor().get(), os); os << "\n"; } } Error compileAndRun(std::string mlirFile, mlir::MLIRContext &context, std::string invokeFunction, ArrayRef argValues, ArrayRef sharedLibs, bool optimize) { OwningModuleRef moduleRef = parseSourceFile(mlirFile, &context); if (!moduleRef) return make_string_error(Twine("could not open ") + mlirFile); ModuleOp module = *moduleRef; // Compile. PassManager pm(module.getContext(), OpPassManager::Nesting::Implicit); applyPassManagerCLOptions(pm); refback::JITModule::buildBackendCompilationPipeline(pm, optimize); if (failed(pm.run(module))) { return make_string_error(Twine("error compiling to jit backend")); } auto expectedJitModule = refback::JITModule::fromCompiledModule(module, sharedLibs); if (!expectedJitModule) return expectedJitModule.takeError(); auto jitModule = std::move(*expectedJitModule); auto expectedInputs = createInputs(argValues); if (!expectedInputs) return expectedInputs.takeError(); auto expectedOutputs = jitModule->invoke(invokeFunction, *expectedInputs); if (!expectedOutputs) return expectedOutputs.takeError(); auto outputs = std::move(*expectedOutputs); printOutputs(outputs, llvm::outs()); llvm::outs() << "SUCCESS\n"; return Error::success(); } //===----------------------------------------------------------------------===// // Main-related init and option parsing. //===----------------------------------------------------------------------===// namespace { namespace cl = llvm::cl; struct Options { cl::opt inputFile{ cl::Positional, cl::desc("the input .mlir file"), cl::init("-")}; cl::opt invokeFunction{"invoke", cl::Required, cl::desc("function to invoke")}; cl::list argValues{"arg-value", cl::ZeroOrMore, cl::desc("Arguments to the called function")}; cl::list sharedLibs{"shared-libs", cl::ZeroOrMore, cl::MiscFlags::CommaSeparated, cl::desc("Libraries to link dynamically")}; cl::opt optimize{ "optimize", cl::Optional, cl::desc("whether the refback pass pipeline should run optimizations"), cl::init(false)}; }; } // namespace int main(int argc, char **argv) { mlir::DialectRegistry registry; mlir::registerAllDialects(registry); mlir::registerAllPasses(); mlir::NPCOMP::registerAllDialects(registry); mlir::NPCOMP::registerAllPasses(); MLIRContext context; context.appendDialectRegistry(registry); context.loadAllAvailableDialects(); llvm::InitLLVM y(argc, argv); npcompInitializeLLVMCodegen(); mlir::registerAsmPrinterCLOptions(); mlir::registerPassManagerCLOptions(); Options options; llvm::cl::ParseCommandLineOptions(argc, argv, "npcomp compile+run utility\n"); SmallVector sharedLibs(options.sharedLibs.begin(), options.sharedLibs.end()); SmallVector argValues(options.argValues.begin(), options.argValues.end()); Error error = compileAndRun(options.inputFile, context, options.invokeFunction, argValues, sharedLibs, options.optimize); int exitCode = EXIT_SUCCESS; llvm::handleAllErrors(std::move(error), [&exitCode](const llvm::ErrorInfoBase &info) { llvm::errs() << "Error: "; info.log(llvm::errs()); llvm::errs() << '\n'; exitCode = EXIT_FAILURE; }); return exitCode; }