mirror of https://github.com/llvm/torch-mlir

Add hopefully short-lived mnist-playground utility.

This unblocks backend progress while the PyTorch frontend work is coming online. Hopefully we can delete this soon. See tools/mnist-playground/README.md for more context on what this tool is for, next steps, and current status.

parent 8022dfaf1a
commit dd1fa2607f

@@ -1,2 +1,3 @@
add_subdirectory(npcomp-opt)
add_subdirectory(mnist-playground)
add_subdirectory(npcomp-run-mlir)

@@ -28,6 +28,19 @@ npcomp-run-mlir() {
    -shared-libs="${build_dir}/lib/libNPCOMPCompilerRuntimeShlib.so" "$@"
}

mnist-playground() {
  # Helper for building and invoking mnist-playground
  #
  # This also automatically builds and adds the npcomp runtime shared
  # library.
  #
  # Usage:
  # $ mnist-playground <regular mnist-playground options>
  ninja -C "$build_dir" mnist-playground NPCOMPCompilerRuntimeShlib 1>&2 || return 1
  $build_dir/tools/mnist-playground/mnist-playground \
    -shared-libs="${build_dir}/lib/libNPCOMPCompilerRuntimeShlib.so" "$@"
}

# Go to the root of your npcomp checkout.
alias npd="cd $td"
# Handy so that `npctest -v` runs lit with -v and thus prints out errors,

@@ -0,0 +1,44 @@
# TODO: This is copied from frontends/pytorch/csrc/c10_dispatch/CMakeLists.txt
# What is the idiomatic way of sharing this in CMake?
include_directories(
  ${TORCH_INCLUDE_DIRS}
  ${TORCH_INSTALL_PREFIX}/include/TH
  ${TORCH_INSTALL_PREFIX}/include/THC/opt/pytorch/pytorch
  ${CMAKE_CURRENT_SOURCE_DIR}
  ${CMAKE_CURRENT_BINARY_DIR}
  ${PYTHON_INCLUDE_DIRS}
)
link_directories("${TORCH_INSTALL_PREFIX}/lib")

set(LLVM_LINK_COMPONENTS
  Core
  Support
  nativecodegen
)

add_llvm_tool(mnist-playground
  mnist-playground.cpp
)
llvm_update_compile_flags(mnist-playground)
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
target_link_libraries(mnist-playground PRIVATE
  MLIRAnalysis
  MLIREDSC
  MLIRExecutionEngine
  MLIRIR
  MLIRJitRunner
  MLIRLLVMIR
  MLIRParser
  MLIRTargetLLVMIR
  MLIRSupport
  NPCOMPInitAll
  NPCOMPJITRuntime
  ${conversion_libs}
  ${dialect_libs}
  ${TORCH_LIBRARIES}
)

add_dependencies(mnist-playground
  NPCOMPCompilerRuntimeShlib
)

@@ -0,0 +1,42 @@
# mnist-playground

This is intended to be a short-lived "playground" for doing various experiments, guided by a real model use case, for improving the npcomp reference backend.

It's expected that utilities developed here will graduate to a more general utility, or that this utility will be obsoleted by Python-driven flows once those come online.

## Goals:

- Obtain a performance-grounded analysis of the TCF/TCP design + reference backend design, and improve the designs.

- Make forward progress on TCF/TCP + reference backend while the PyTorch frontend is being brought up.

## Rough sketch of how we intend to get there:

1. Link against PyTorch, and write a simple routine to do inference on a simple FC MNIST (a sketch of such a routine follows this list).

2. Write a similar routine in TCF, extending TCF and the reference backend as needed for functional completeness. The PyTorch code serves as a numerical correctness reference.

3. Run and profile the reference backend and obtain a set of action items for design improvements, both to performance and stability. The PyTorch code serves as a performance baseline.

4. Implement important action items on a priority basis, and document remaining major design issues that don't make sense to address at this time, along with a justification for why the current design doesn't prevent us from eventually solving them. Iterate the previous step and this one as makes sense.

5. (Stretch) Add support for convolutional MNIST and/or training.
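
For concreteness, the step-1 routine amounts to a single matmul plus a bias add in ATen. A minimal sketch, assuming the tensor shapes used by the current test setup (the hardcoded calls live in `mnist-playground.cpp`; the helper name here is purely illustrative):

```
#include <torch/torch.h>

// Sketch of the step-1 PyTorch reference: one fully connected layer applied
// to a batch of 100 flattened 28x28 (= 784-element) images.
at::Tensor fcInferenceReference() {
  torch::manual_seed(42);
  at::Tensor image = at::rand({784, 100});   // columns are flattened images
  at::Tensor weights = at::rand({10, 784});  // 10 output classes
  at::Tensor biases = at::rand({10, 1});
  return at::add(at::matmul(weights, image), biases); // 10x100 class scores
}
```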

## Current Status

Step 1. DONE

Step 2. MOSTLY DONE. The op set still needs to be improved to make the FC MNIST more complete; in particular, reshape and softmax still need to be implemented.

Step 3. STARTING. Initial performance on 10x784x100 (a 784-input, 10-output fully connected layer at batch size 100) is 66x off from PyTorch. No profiling done yet.

Example command line (the .mlir file and `-invoke` are similar to npcomp-run-mlir):

```
$ mnist-playground tools/mnist-playground/fc.mlir -invoke fc
PyTorch: numRuns: 16384 nsPerRun: 3.947563e+05
RefE2E: numRuns: 256 nsPerRun: 2.471073e+07
Ratio (RefE2E / PyTorch): 62.5974
```

There is currently a fragile dependency between the hardcoded `at::` function calls in the .cpp file and the TCF code in the `.mlir` file. A correctness check is done to make sure they agree. Once we have a PyTorch frontend and/or an ATen round-trip backend online, we can avoid this fragility.
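
The correctness check mentioned above is essentially an elementwise `at::allclose` comparison between the two result lists. A minimal sketch of that check (the helper name is illustrative; the real logic lives inside `doIt` in `mnist-playground.cpp`):

```
#include <torch/torch.h>
#include <cstddef>
#include <vector>

// Returns true if the PyTorch reference outputs and the RefE2E (compiled
// fc.mlir) outputs agree in arity and, elementwise, within tolerance.
bool resultsAgree(const std::vector<at::Tensor> &pt,
                  const std::vector<at::Tensor> &refE2E) {
  if (pt.size() != refE2E.size())
    return false;
  for (std::size_t i = 0; i < pt.size(); i++)
    if (!at::allclose(pt[i], refE2E[i]))
      return false;
  return true;
}
```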

@@ -0,0 +1,15 @@
func @fc(
    // TODO: Implement "reshape" so that %image can be passed as a batch of 2D tensors.
    %image: tensor<?x?xf32>,
    %weights: tensor<?x?xf32>,
    %biases: tensor<?x?xf32>)
    -> (
      tensor<?x?xf32>
    ) {
  %0 = tcf.matmul %weights, %image : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  %1 = tcf.add %0, %biases : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  // TODO: Implement softmax for classification.
  // For now, this returns a not-terribly-useful number.
  return %1 : tensor<?x?xf32>
}

@@ -0,0 +1,300 @@
//===- mnist-playground.cpp -----------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/AsmState.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "npcomp/InitAll.h"
#include "npcomp/JITRuntime/JITModule.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/TargetSelect.h"

#include <torch/torch.h>

#include <chrono>

using namespace mlir;
using llvm::Error;
using llvm::ErrorOr;
using llvm::Expected;
using llvm::StringError;
using llvm::Twine;

//===----------------------------------------------------------------------===//
// Utilities
//===----------------------------------------------------------------------===//

/// Wrap a string into an llvm::StringError.
static Error make_string_error(const Twine &message) {
  return llvm::make_error<StringError>(message.str(),
                                       llvm::inconvertibleErrorCode());
}

Expected<std::unique_ptr<npcomp::JITModule>>
createJITModule(std::string mlirFile, mlir::DialectRegistry &registry,
                ArrayRef<StringRef> sharedLibs, bool optimize) {
  MLIRContext context;
  registry.loadAll(&context);
  OwningModuleRef moduleRef = parseSourceFile(mlirFile, &context);
  if (!moduleRef)
    return make_string_error(Twine("could not open ") + mlirFile);

  ModuleOp module = *moduleRef;

  // Compile.
  PassManager pm(module.getContext(), /*verifyPasses=*/true);
  applyPassManagerCLOptions(pm);
  npcomp::JITModule::buildBackendCompilationPipeline(pm, optimize);
  if (failed(pm.run(module)))
    return make_string_error(Twine("error compiling to jit backend"));

  return npcomp::JITModule::fromCompiledModule(module, sharedLibs);
}

//===----------------------------------------------------------------------===//
// Benchmarking / correctness-testing code.
//===----------------------------------------------------------------------===//

static Expected<std::vector<at::Tensor>>
invokeJITModuleWithATenTensors(npcomp::JITModule &jitModule,
                               StringRef invokeFunction,
                               std::vector<at::Tensor> &args) {

  // Do a bit of checking. We don't handle all possible tensors right now.
  std::vector<at::TensorArg> tensorArgs;
  for (auto arg : llvm::enumerate(args))
    tensorArgs.push_back(at::TensorArg(arg.value(), "arg", arg.index()));
  at::CheckedFrom c = "converting to npcomprt::Tensor";
  for (auto &tensorArg : tensorArgs)
    at::checkScalarType(c, tensorArg, at::ScalarType::Float);
  at::checkAllContiguous(c, tensorArgs);

  SmallVector<npcomprt::Ref<npcomprt::Tensor>, 6> npcomprtInputs;
  for (at::Tensor arg : args) {
    SmallVector<int32_t, 6> extents(arg.sizes().begin(), arg.sizes().end());
    float *data = arg.storage().data<float>();
    // This does a deep copy of the data. Let's see if it shows up on the
    // profile.
    npcomprtInputs.push_back(npcomprt::Tensor::create(
        npcomprt::ArrayRef<int32_t>(extents.data(), extents.size()),
        npcomprt::ElementType::F32, data));
  }

  // Invoke the RefE2E function.
  // TODO: The mishmash of terminology "npcomprt", "refe2e", "npcomp" in this
  // file is getting out of hand.
  auto expectedOutputs = jitModule.invoke(invokeFunction, npcomprtInputs);
  if (!expectedOutputs)
    return expectedOutputs.takeError();
  auto npcomprtOutputs = std::move(*expectedOutputs);

  std::vector<at::Tensor> results;
  for (auto output : npcomprtOutputs) {
    std::vector<int64_t> sizes(output->getExtents().data(),
                               output->getExtents().data() +
                                   output->getExtents().size());
    // Make a copy for passing to at::from_blob, which does its own internal
    // reference counting.
    auto *dataCopy = std::malloc(output->getDataByteSize());
    std::memcpy(dataCopy, output->getData(), output->getDataByteSize());
    results.push_back(at::from_blob(
        dataCopy, sizes, [](void *p) { std::free(p); }, at::kFloat));
  }
  return results;
}

using InvocationFunction =
    std::function<Expected<std::vector<at::Tensor>>(std::vector<at::Tensor>)>;

struct BenchmarkResult {
  int numRuns;
  float nsPerRun;
};

std::ostream &operator<<(std::ostream &os, const BenchmarkResult &result) {
  os << "numRuns: " << result.numRuns << " nsPerRun: " << std::scientific
     << result.nsPerRun << std::defaultfloat;
  return os;
}

Expected<BenchmarkResult> benchmark(std::function<Error()> f) {
  // Keep doubling the number of iterations per batch until a batch takes
  // long enough to be timed reliably.
  for (int itersAtATime = 1;; itersAtATime *= 2) {
    auto start = std::chrono::steady_clock::now();
    for (int i = 0; i < itersAtATime; i++) {
      auto error = f();
      if (error)
        return std::move(error);
    }
    auto end = std::chrono::steady_clock::now();
    std::chrono::duration<float> elapsed = end - start;

    // If the runtime is longer than 0.5 seconds, it's reliable enough.
    if (elapsed.count() > 0.5f) {
      BenchmarkResult result;
      result.numRuns = itersAtATime;
      // Convert elapsed seconds to nanoseconds per run.
      result.nsPerRun = elapsed.count() * 1e9 / itersAtATime;
      return result;
    }
  }
  return make_string_error("too short running to benchmark!");
}

static Error doIt(InvocationFunction ptFunc, InvocationFunction refE2EFunc,
                  bool doBenchmark, int numCorrectnessTests) {

  torch::manual_seed(42);
  torch::set_num_threads(1);

  std::vector<at::Tensor> args;
  args.push_back(at::rand({784, 100}));
  args.push_back(at::rand({10, 784}));
  args.push_back(at::rand({10, 1}));

  // Initial correctness check of the two functions.
  for (int correctnessTest = 0; correctnessTest < numCorrectnessTests;
       correctnessTest++) {
    auto expectedPt = ptFunc(args);
    auto expectedRefE2E = refE2EFunc(args);
    if (!expectedPt)
      return expectedPt.takeError();
    if (!expectedRefE2E)
      return expectedRefE2E.takeError();
    auto pt = std::move(*expectedPt);
    auto refE2E = std::move(*expectedRefE2E);
    if (pt.size() != refE2E.size())
      return make_string_error("mismatch in result arity!");
    for (int i = 0, e = pt.size(); i < e; i++) {
      if (!at::allclose(pt[i], refE2E[i])) {
        std::cout << "PyTorch:\n" << pt[i] << "\n";
        std::cout << "RefE2E:\n" << refE2E[i] << "\n";
        return make_string_error(Twine("mismatch in result contents ") +
                                 Twine(i) + Twine(" on correctness test #") +
                                 Twine(correctnessTest));
      }
    }
  }
  if (!doBenchmark)
    return Error::success();

  // Benchmark the two against each other.
  BenchmarkResult ptBenchmarkResult;
  BenchmarkResult refE2EBenchmarkResult;
  {
    auto expectedResult =
        benchmark([&]() -> Error { return ptFunc(args).takeError(); });
    if (!expectedResult)
      return expectedResult.takeError();
    ptBenchmarkResult = std::move(*expectedResult);
  }

  {
    auto expectedResult =
        benchmark([&]() -> Error { return refE2EFunc(args).takeError(); });
    if (!expectedResult)
      return expectedResult.takeError();
    refE2EBenchmarkResult = std::move(*expectedResult);
  }
  std::cout << "PyTorch: " << ptBenchmarkResult << "\n";
  std::cout << "RefE2E: " << refE2EBenchmarkResult << "\n";
  std::cout << "Ratio (RefE2E / PyTorch): "
            << refE2EBenchmarkResult.nsPerRun / ptBenchmarkResult.nsPerRun
            << "\n";

  // TODO: Check for memory leaks?

  return Error::success();
}

//===----------------------------------------------------------------------===//
// Main-related init and option parsing.
//===----------------------------------------------------------------------===//

namespace {
namespace cl = llvm::cl;
struct Options {
  cl::opt<std::string> inputFile{
      cl::Positional, cl::desc("the input .mlir file"), cl::init("-")};
  cl::opt<std::string> invokeFunction{"invoke", cl::Required,
                                      cl::desc("function to invoke")};

  cl::list<std::string> sharedLibs{"shared-libs", cl::ZeroOrMore,
                                   cl::MiscFlags::CommaSeparated,
                                   cl::desc("Libraries to link dynamically")};
  cl::opt<bool> optimize{
      "optimize", cl::Optional,
      cl::desc("whether the e2e pass pipeline should run optimizations"),
      cl::init(false)};

  cl::opt<bool> benchmark{"benchmark", cl::Optional,
                          cl::desc("whether to do a benchmark comparison"),
                          cl::init(true)};

  cl::opt<uint32_t> numCorrectnessTests{
      "num-correctness-tests", cl::Optional,
      cl::desc("how many correctness tests to run (useful for nondeterministic "
               "correctness failures)"),
      cl::init(1)};
};
} // namespace

int main(int argc, char **argv) {
  mlir::DialectRegistry registry;
  mlir::registerAllDialects(registry);
  mlir::registerAllPasses();
  mlir::NPCOMP::registerAllDialects(registry);
  mlir::NPCOMP::registerAllPasses();

  llvm::InitLLVM y(argc, argv);
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
  mlir::initializeLLVMPasses();

  mlir::registerAsmPrinterCLOptions();
  mlir::registerPassManagerCLOptions();

  Options options;
  llvm::cl::ParseCommandLineOptions(argc, argv, "mnist playground utility\n");

  SmallVector<StringRef, 6> sharedLibs(options.sharedLibs.begin(),
                                       options.sharedLibs.end());
  auto expectedJITModule = createJITModule(options.inputFile, registry,
                                           sharedLibs, options.optimize);
  if (Error error = expectedJITModule.takeError())
    llvm::report_fatal_error(llvm::toString(std::move(error)),
                             /*gen_crash_diag=*/false);
  auto jitModule = std::move(*expectedJITModule);

  Error error = doIt(
      // The hardcoded ATen reference computation (must match fc.mlir).
      [](std::vector<at::Tensor> args) {
        auto image = args[0];
        auto weights = args[1];
        auto biases = args[2];
        auto v0 = at::matmul(weights, image);
        auto v1 = at::add(v0, biases);
        return std::vector<at::Tensor>{v1};
      },
      // The compiled .mlir function, invoked through the npcomp JIT runtime.
      [&](std::vector<at::Tensor> args) {
        return invokeJITModuleWithATenTensors(*jitModule,
                                              options.invokeFunction, args);
      },
      options.benchmark, options.numCorrectnessTests);

  int exitCode = EXIT_SUCCESS;
  llvm::handleAllErrors(std::move(error),
                        [&exitCode](const llvm::ErrorInfoBase &info) {
                          llvm::errs() << "Error: ";
                          info.log(llvm::errs());
                          llvm::errs() << '\n';
                          exitCode = EXIT_FAILURE;
                        });
  return exitCode;
}