Add boilerplate to do device capture (pytorch 1.6).

* Uses the new dispatcher API.
* Just prints to the console for the moment when an op is captured.
* Executes the op through the existing implementation.
pull/61/head
Stella Laurenzo 2020-09-25 18:13:16 -07:00
parent 16c26ef57e
commit b5f010284f
9 changed files with 188 additions and 27 deletions

View File

@ -14,7 +14,7 @@ set -o xtrace
clang-format -i \
$(find_cc_sources include) \
$(find_cc_sources lib) \
$(find_cc_sources python_native)
$(find_cc_sources frontends/pytorch/csrc)
# Python sources.
yapf --recursive -i "$td/python" "$td/pytest"

View File

@ -8,6 +8,7 @@ include_directories(
)
link_directories("${TORCH_INSTALL_PREFIX}/lib")
add_library(npcomp_torch_c10_dispatch_bindings
acap_dispatch.cpp
python_bindings.cpp
)

View File

@ -0,0 +1,85 @@
//===- acap_dispatch.cpp --------------------------------------------------===//
//
// This file is licensed under a pytorch-style license
// See frontends/pytorch/LICENSE for license information.
//
//===----------------------------------------------------------------------===//
#include "acap_dispatch.h"
#include "npcomp/Python/PybindUtils.h"
#include <c10/core/DispatchKey.h>
#include <torch/library.h>
using namespace torch_mlir;
namespace py = pybind11;
// TODO: Private use dispatch keys are not made for real uses. Allocate a proper
// dispatch key in upstream PyTorch (DispatchKey.h) prior to maturity. Note
// that the TORCH_LIBRARY_* macros expand this by name and other APIs use its
// enum value, so we define both. We can get rid of both once we have our
// own key.
#define ACAP_DISPATCH_KEY PrivateUse1
// Enum-valued twin of ACAP_DISPATCH_KEY. It is never reassigned, so it is
// declared constexpr rather than as a mutable file-scope static.
static constexpr c10::DispatchKey kAcapDispatchKey =
    c10::DispatchKey::ACAP_DISPATCH_KEY;
// Returns this thread's stack of activations. The list is a function-local
// thread_local, so it is created lazily on first use by each thread.
std::list<AcapController::Activation> &
AcapController::getThreadLocalActiveStack() {
  static thread_local std::list<Activation> perThreadActivations;
  return perThreadActivations;
}
// Context manager __enter__: pushes this controller onto the front of the
// thread-local activation stack, installs an include guard for the capture
// dispatch key on the new activation, and returns this controller to Python.
py::object AcapController::contextEnter() {
  std::list<Activation> &activeStack = getThreadLocalActiveStack();
  activeStack.emplace_front(shared_from_this());
  // The guard is owned by the activation, so it lives until the activation is
  // popped in contextExit().
  activeStack.front().dispatchGuard =
      std::make_unique<c10::impl::IncludeDispatchKeyGuard>(kAcapDispatchKey);
  return py::cast(this);
}
// Context manager __exit__: pops the innermost activation, which must belong
// to this controller. The exception arguments are accepted but not inspected.
// Raises a Python RuntimeError if this controller is not the innermost
// activation on the calling thread.
void AcapController::contextExit(py::object exc_type, py::object exc_val,
                                 py::object exc_tb) {
  std::list<Activation> &activeStack = getThreadLocalActiveStack();
  const bool isInnermost =
      !activeStack.empty() && activeStack.front().controller.get() == this;
  if (!isInnermost) {
    throw py::raisePyError(PyExc_RuntimeError,
                           "Mismatched context manager __exit__");
  }
  // Destroying the activation also destroys its dispatch key guard.
  activeStack.pop_front();
}
// Returns every log line accumulated so far and leaves the controller's log
// empty.
std::vector<std::string> AcapController::getDebugLog() {
  std::vector<std::string> drained;
  drained.swap(captureLog);
  return drained;
}
std::shared_ptr<AcapController> AcapController::getCurrent() {
auto &stack = getThreadLocalActiveStack();
if (stack.empty())
return nullptr;
return stack.front().controller;
}
// Boxed fallback kernel for the capture dispatch key. If a controller is
// active on this thread, appends a "CAPTURE: <schema>" line to its log, then
// re-dispatches the op through the dispatcher so it still executes.
void AcapController::fallbackKernel(const c10::OperatorHandle &opHandle,
                                    c10::Stack *stack) {
  // Keep the capture key excluded for the rest of this call so that the
  // re-dispatch below does not route back into this kernel.
  c10::impl::ExcludeDispatchKeyGuard noReentry(kAcapDispatchKey);

  const std::shared_ptr<AcapController> active = getCurrent();
  if (active != nullptr) {
    // Record the schema of the intercepted op.
    std::stringstream entry;
    entry << "CAPTURE: " << opHandle.schema() << "\n";
    active->captureLog.push_back(entry.str());
  }

  c10::Dispatcher::singleton().callBoxed(opHandle, stack);
}
// Registers AcapController::fallbackKernel as the boxed fallback kernel for
// the ACAP dispatch key. The key is passed by macro name here (see the note at
// ACAP_DISPATCH_KEY above).
// NOTE(review): `_` is presumably the dispatcher's wildcard namespace, i.e.
// the fallback applies to ops from every library -- confirm against the
// TORCH_LIBRARY_IMPL documentation.
TORCH_LIBRARY_IMPL(_, ACAP_DISPATCH_KEY, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<
&AcapController::fallbackKernel>());
}

View File

@ -0,0 +1,60 @@
//===- acap_dispatch.h ------------------------------------------*- C++ -*-===//
//
// This file is licensed under a pytorch-style license
// See frontends/pytorch/LICENSE for license information.
//
//===----------------------------------------------------------------------===//
// "ATen Capture" dispatcher: Defines facility for capturing programs by
// registering dispatch keys to intercept op execution.
// References:
// http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/
//
//===----------------------------------------------------------------------===//
#include <list>
#include <memory>
#include <string>
#include <vector>
#include <pybind11/pybind11.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
namespace torch_mlir {
/// Main entry point for managing device capture.
/// Main entry point for managing device capture.
///
/// A controller is activated per thread through the context manager methods
/// below. Activations are kept on a thread-local stack, so they may nest and
/// must be exited in reverse order of entry. The class derives from
/// enable_shared_from_this because each activation stores a shared_ptr to its
/// controller; instances must therefore be owned by a std::shared_ptr.
class AcapController : public std::enable_shared_from_this<AcapController> {
public:
  AcapController() = default;

  // Enter and exit the context manager. contextEnter() activates this
  // controller on the calling thread and returns it; contextExit() deactivates
  // it and raises if this controller is not the innermost active one.
  pybind11::object contextEnter();
  void contextExit(pybind11::object exc_type, pybind11::object exc_val,
                   pybind11::object exc_tb);

  // Gets and clears the current debug log.
  std::vector<std::string> getDebugLog();

  // Returns the current AcapController (if it has been activated on this
  // thread). Returns nullptr if none.
  static std::shared_ptr<AcapController> getCurrent();

  // The fallback boxed kernel that we route captured dispatches through.
  static void fallbackKernel(const c10::OperatorHandle &opHandle,
                             c10::Stack *stack);

private:
  // One entry of the thread-local activation stack.
  struct Activation {
    // explicit: a shared_ptr should never silently convert to an Activation.
    explicit Activation(std::shared_ptr<AcapController> controller)
        : controller(std::move(controller)) {}
    std::shared_ptr<AcapController> controller;
    // The RAII dispatch key guard is not movable, so heap allocate it. This is
    // a bit outside of its intended design, but since this is thread local as
    // well, it should be fine.
    std::unique_ptr<c10::impl::IncludeDispatchKeyGuard> dispatchGuard;
  };

  // Gets the thread local stack of active acap controllers.
  static std::list<Activation> &getThreadLocalActiveStack();

  // Lines recorded by fallbackKernel(); drained by getDebugLog().
  std::vector<std::string> captureLog;
};
} // namespace torch_mlir

View File

@ -5,12 +5,14 @@
//
//===----------------------------------------------------------------------===//
#include "../init_python_bindings.h"
#include <pybind11/pybind11.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "../init_python_bindings.h"
#include "acap_dispatch.h"
using namespace torch_mlir;
namespace py = pybind11;
namespace {
@ -123,6 +125,12 @@ py::list GetRegisteredOps() {
}
// Populates module `m` with the AcapController class (constructor, context
// manager protocol, and debug log accessor) and the get_registered_ops
// function.
void InitModuleBindings(py::module &m) {
  // Held by shared_ptr because the controller uses shared_from_this().
  py::class_<AcapController, std::shared_ptr<AcapController>> acapController(
      m, "AcapController");
  acapController.def(py::init<>());
  acapController.def("__enter__", &AcapController::contextEnter);
  acapController.def("__exit__", &AcapController::contextExit);
  acapController.def("get_debug_log", &AcapController::getDebugLog);

  m.def("get_registered_ops", &GetRegisteredOps, kGetRegisteredOpsDocstring);
}

View File

@ -0,0 +1,24 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
# RUN: python %s | FileCheck %s
import torch
import _torch_mlir as m
# Two random 4x4 operands for the arithmetic below.
t0 = torch.randn((4,4))
t1 = torch.randn((4,4))
# Entering the controller activates capture on this thread; `c` is the
# controller itself, kept so its debug log can be read after the block.
with m.c10.AcapController() as c:
result = t0 + t1
result = result * t0
# NOTE: Ops involved with printing throw RuntimeError about calling a kernel
# from an unboxed API.
print(result)
# CHECK: CAPTURE: aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> (Tensor)
# CHECK-NOT: CAPTURE: aten::mul
# get_debug_log() returns the recorded lines and clears the log; echo them to
# stdout so FileCheck can match the expectations above.
log = c.get_debug_log()
for line in log: print(line)

View File

@ -33,8 +33,11 @@ namespace pybind11 {
/// Raises a python exception with the given message.
/// Correct usage:
// throw raisePyError(PyExc_ValueError, "Foobar'd");
pybind11::error_already_set raisePyError(PyObject *exc_class,
const char *message);
// Sets the Python error indicator to `exc_class` with `message`, then returns
// an error_already_set for the caller to throw. Defined inline so the header
// can be used without a separate translation unit.
inline pybind11::error_already_set raisePyError(PyObject *exc_class,
                                                const char *message) {
  PyErr_SetString(exc_class, message);
  return pybind11::error_already_set{};
}
/// Raises a value error with the given message.
/// Correct usage:

View File

@ -23,7 +23,6 @@ set(PYBIND_SOURCES
MlirIr.cpp
MlirPass.cpp
NpcompDialect.cpp
PybindUtils.cpp
CoreDialects.cpp
)

View File

@ -1,19 +0,0 @@
//===- PybindUtils.cpp - Utilities for interop with python ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp/Python/PybindUtils.h"
namespace pybind11 {

// Sets the Python error indicator to `exc_class` with `message` and returns
// the exception object that the caller is expected to throw.
pybind11::error_already_set raisePyError(PyObject *exc_class,
                                         const char *message) {
  PyErr_SetString(exc_class, message);
  return error_already_set{};
}

} // namespace pybind11