//===- mnist-playground.cpp -------------------------------------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/ExecutionEngine/OptUtils.h" #include "mlir/IR/AsmState.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" #include "mlir/Parser.h" #include "mlir/Pass/PassManager.h" #include "npcomp/InitAll.h" #include "npcomp/RefBackend/JITHelpers/JITModule.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/TargetSelect.h" #include #include using namespace mlir; using llvm::Error; using llvm::ErrorOr; using llvm::Expected; using llvm::StringError; using llvm::Twine; //===----------------------------------------------------------------------===// // Utilities //===----------------------------------------------------------------------===// /// Wrap a string into an llvm::StringError. static Error make_string_error(const Twine &message) { return llvm::make_error(message.str(), llvm::inconvertibleErrorCode()); } Expected> createJITModule(std::string mlirFile, mlir::DialectRegistry ®istry, ArrayRef sharedLibs, bool optimize) { MLIRContext context; registry.loadAll(&context); OwningModuleRef moduleRef = parseSourceFile(mlirFile, &context); if (!moduleRef) return make_string_error(Twine("could not open ") + mlirFile); ModuleOp module = *moduleRef; // Compile. PassManager pm(module.getContext(), /*verifyPasses=*/true); applyPassManagerCLOptions(pm); refback::JITModule::buildBackendCompilationPipeline(pm, optimize); if (failed(pm.run(module))) return make_string_error(Twine("error compiling to jit backend")); return refback::JITModule::fromCompiledModule(module, sharedLibs); } //===----------------------------------------------------------------------===// // Benchmarking / correctness-testing code. //===----------------------------------------------------------------------===// static Expected> invokeJITModuleWithATenTensors(refback::JITModule &jitModule, StringRef invokeFunction, std::vector &args) { // Do a bit of checking. We don't handle all possible tensors right now. std::vector tensorArgs; for (auto arg : llvm::enumerate(args)) tensorArgs.push_back(at::TensorArg(arg.value(), "arg", arg.index())); at::CheckedFrom c = "converting to refbackrt::Tensor"; for (auto &tensorArg : tensorArgs) at::checkScalarType(c, tensorArg, at::ScalarType::Float); at::checkAllContiguous(c, tensorArgs); SmallVector, 6> refbackInputs; for (at::Tensor arg : args) { SmallVector extents(arg.sizes().begin(), arg.sizes().end()); float *data = arg.storage().data(); // This does a deep copy of the data. Let's see if it shows up on the // profile. refbackInputs.push_back(refbackrt::Tensor::create( refbackrt::ArrayRef(extents.data(), extents.size()), refbackrt::ElementType::F32, data)); } // Invoke the RefBackend function. auto expectedOutputs = jitModule.invoke(invokeFunction, refbackInputs); if (!expectedOutputs) return expectedOutputs.takeError(); auto refbackrtOutputs = std::move(*expectedOutputs); std::vector results; for (auto output : refbackrtOutputs) { std::vector sizes(output->getExtents().data(), output->getExtents().data() + output->getExtents().size()); // Make a copy for passing to at::from_blob, which does its own internal // reference counting. auto *dataCopy = std::malloc(output->getDataByteSize()); std::memcpy(dataCopy, output->getData(), output->getDataByteSize()); results.push_back(at::from_blob( dataCopy, sizes, [](void *p) { std::free(p); }, at::kFloat)); } return results; } using InvocationFunction = std::function>(std::vector)>; struct BenchmarkResult { int numRuns; float nsPerRun; }; std::ostream &operator<<(std::ostream &os, const BenchmarkResult &result) { os << "numRuns: " << result.numRuns << " nsPerRun: " << std::scientific << result.nsPerRun << std::defaultfloat; return os; } Expected benchmark(std::function f) { for (int itersAtATime = 1;; itersAtATime *= 2) { auto start = std::chrono::steady_clock::now(); for (int i = 0; i < itersAtATime; i++) { auto error = f(); if (error) return std::move(error); } auto end = std::chrono::steady_clock::now(); std::chrono::duration elapsed = end - start; // If the runtime is longer than 0.5 seconds, it's reliable enough. if (elapsed.count() > 0.5f) { BenchmarkResult result; result.numRuns = itersAtATime; result.nsPerRun = elapsed.count() * 10e9 / itersAtATime; return result; } } return make_string_error("too short running to benchmark!"); } static Error doIt(InvocationFunction ptFunc, InvocationFunction refBackendFunc, bool doBenchmark, int numCorrectnessTests) { torch::manual_seed(42); torch::set_num_threads(1); std::vector args; args.push_back(at::rand({784, 100})); args.push_back(at::rand({10, 784})); args.push_back(at::rand({10, 1})); // Initial correctness check of the two functions. for (int correctnessTest = 0; correctnessTest < numCorrectnessTests; correctnessTest++) { auto expectedPt = ptFunc(args); auto expectedRefBackend = refBackendFunc(args); if (!expectedPt) return expectedPt.takeError(); if (!expectedRefBackend) return expectedRefBackend.takeError(); auto pt = std::move(*expectedPt); auto refBackend = std::move(*expectedRefBackend); if (pt.size() != refBackend.size()) return make_string_error("mismatch in result arity!"); for (int i = 0, e = pt.size(); i < e; i++) { if (!at::allclose(pt[i], refBackend[i])) { std::cout << "PyTorch:\n" << pt[i] << "\n"; std::cout << "RefBackend:\n" << refBackend[i] << "\n"; return make_string_error(Twine("mismatch in result contents ") + Twine(i) + Twine(" on correctness test #") + Twine(correctnessTest)); } } } if (!doBenchmark) return Error::success(); // Benchmark the two against each other. BenchmarkResult ptBenchmarkResult; BenchmarkResult refBackendBenchmarkResult; { auto expectedResult = benchmark([&]() -> Error { return ptFunc(args).takeError(); }); if (!expectedResult) return expectedResult.takeError(); ptBenchmarkResult = std::move(*expectedResult); } { auto expectedResult = benchmark([&]() -> Error { return refBackendFunc(args).takeError(); }); if (!expectedResult) return expectedResult.takeError(); refBackendBenchmarkResult = std::move(*expectedResult); } std::cout << "PyTorch: " << ptBenchmarkResult << "\n"; std::cout << "RefBackend: " << refBackendBenchmarkResult << "\n"; std::cout << "Ratio (RefBackend / PyTorch): " << refBackendBenchmarkResult.nsPerRun / ptBenchmarkResult.nsPerRun << "\n"; // TODO: Check for memory leaks? return Error::success(); } //===----------------------------------------------------------------------===// // Main-related init and option parsing. //===----------------------------------------------------------------------===// namespace { namespace cl = llvm::cl; struct Options { cl::opt inputFile{ cl::Positional, cl::desc("the input .mlir file"), cl::init("-")}; cl::opt invokeFunction{"invoke", cl::Required, cl::desc("function to invoke")}; cl::list sharedLibs{"shared-libs", cl::ZeroOrMore, cl::MiscFlags::CommaSeparated, cl::desc("Libraries to link dynamically")}; cl::opt optimize{ "optimize", cl::Optional, cl::desc("whether the refback pass pipeline should run optimizations"), cl::init(false)}; cl::opt benchmark{"benchmark", cl::Optional, cl::desc("whether to do a benchmark comparison"), cl::init(true)}; cl::opt numCorrectnessTests{ "num-correctness-tests", cl::Optional, cl::desc("how many correctness tests to run (useful for nondeterministic " "correctness failures"), cl::init(1)}; }; } // namespace int main(int argc, char **argv) { mlir::DialectRegistry registry; mlir::registerAllDialects(registry); mlir::registerAllPasses(); mlir::NPCOMP::registerAllDialects(registry); mlir::NPCOMP::registerAllPasses(); llvm::InitLLVM y(argc, argv); llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); mlir::initializeLLVMPasses(); mlir::registerAsmPrinterCLOptions(); mlir::registerPassManagerCLOptions(); Options options; llvm::cl::ParseCommandLineOptions(argc, argv, "mnist playground utility\n"); SmallVector sharedLibs(options.sharedLibs.begin(), options.sharedLibs.end()); auto expectedJITModule = createJITModule(options.inputFile, registry, sharedLibs, options.optimize); if (Error error = expectedJITModule.takeError()) llvm::report_fatal_error(llvm::toString(std::move(error)), /*gen_crash_diag=*/false); auto jitModule = std::move(*expectedJITModule); Error error = doIt( [](std::vector args) { auto image = args[0]; auto weights = args[1]; auto biases = args[2]; auto v0 = at::matmul(weights, image); auto v1 = at::add(v0, biases); return std::vector{v1}; }, [&](std::vector args) { return invokeJITModuleWithATenTensors(*jitModule, options.invokeFunction, args); }, options.benchmark, options.numCorrectnessTests); int exitCode = EXIT_SUCCESS; llvm::handleAllErrors(std::move(error), [&exitCode](const llvm::ErrorInfoBase &info) { llvm::errs() << "Error: "; info.log(llvm::errs()); llvm::errs() << '\n'; exitCode = EXIT_FAILURE; }); return exitCode; }