mirror of https://github.com/llvm/torch-mlir
Add boilerplate to do device capture (pytorch 1.6).
* Uses the new dispatcher API.
* Just prints to the console for the moment when an op is captured.
* Executes the op through the existing implementation.
pull/61/head
parent
16c26ef57e
commit
b5f010284f
|
@ -14,7 +14,7 @@ set -o xtrace
|
||||||
clang-format -i \
|
clang-format -i \
|
||||||
$(find_cc_sources include) \
|
$(find_cc_sources include) \
|
||||||
$(find_cc_sources lib) \
|
$(find_cc_sources lib) \
|
||||||
$(find_cc_sources python_native)
|
$(find_cc_sources frontends/pytorch/csrc)
|
||||||
|
|
||||||
# Python sources.
|
# Python sources.
|
||||||
yapf --recursive -i "$td/python" "$td/pytest"
|
yapf --recursive -i "$td/python" "$td/pytest"
|
||||||
|
|
|
@ -8,6 +8,7 @@ include_directories(
|
||||||
)
|
)
|
||||||
link_directories("${TORCH_INSTALL_PREFIX}/lib")
|
link_directories("${TORCH_INSTALL_PREFIX}/lib")
|
||||||
add_library(npcomp_torch_c10_dispatch_bindings
|
add_library(npcomp_torch_c10_dispatch_bindings
|
||||||
|
acap_dispatch.cpp
|
||||||
python_bindings.cpp
|
python_bindings.cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,85 @@
|
||||||
|
//===- acap_dispatch.cpp --------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file is licensed under a pytorch-style license
|
||||||
|
// See frontends/pytorch/LICENSE for license information.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "acap_dispatch.h"
|
||||||
|
|
||||||
|
#include "npcomp/Python/PybindUtils.h"
|
||||||
|
|
||||||
|
#include <c10/core/DispatchKey.h>
|
||||||
|
#include <torch/library.h>
|
||||||
|
|
||||||
|
using namespace torch_mlir;
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
// TODO: Private use dispatch keys are not made for real uses. Allocate a proper
|
||||||
|
// dispatch key in upstream PyTorch (DispatchKey.h) prior to maturity. Note
|
||||||
|
// that the TORCH_LIBRARY_* macros expand this by name and other APIs use its
|
||||||
|
// enum value, so we define both. We can get rid of both once we have our
|
||||||
|
// own key.
|
||||||
|
#define ACAP_DISPATCH_KEY PrivateUse1
|
||||||
|
static c10::DispatchKey kAcapDispatchKey = c10::DispatchKey::ACAP_DISPATCH_KEY;
|
||||||
|
|
||||||
|
std::list<AcapController::Activation> &
|
||||||
|
AcapController::getThreadLocalActiveStack() {
|
||||||
|
static thread_local std::list<Activation> threadLocalActiveStack;
|
||||||
|
return threadLocalActiveStack;
|
||||||
|
}
|
||||||
|
|
||||||
|
py::object AcapController::contextEnter() {
|
||||||
|
auto &stack = getThreadLocalActiveStack();
|
||||||
|
stack.emplace_front(shared_from_this());
|
||||||
|
Activation &current = stack.front();
|
||||||
|
current.dispatchGuard =
|
||||||
|
std::make_unique<c10::impl::IncludeDispatchKeyGuard>(kAcapDispatchKey);
|
||||||
|
return py::cast(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
void AcapController::contextExit(py::object exc_type, py::object exc_val,
|
||||||
|
py::object exc_tb) {
|
||||||
|
auto &stack = getThreadLocalActiveStack();
|
||||||
|
if (stack.empty() || stack.front().controller.get() != this) {
|
||||||
|
throw py::raisePyError(PyExc_RuntimeError,
|
||||||
|
"Mismatched context manager __exit__");
|
||||||
|
}
|
||||||
|
stack.pop_front();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> AcapController::getDebugLog() {
|
||||||
|
std::vector<std::string> copy;
|
||||||
|
captureLog.swap(copy);
|
||||||
|
return copy;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<AcapController> AcapController::getCurrent() {
|
||||||
|
auto &stack = getThreadLocalActiveStack();
|
||||||
|
if (stack.empty())
|
||||||
|
return nullptr;
|
||||||
|
return stack.front().controller;
|
||||||
|
}
|
||||||
|
|
||||||
|
void AcapController::fallbackKernel(const c10::OperatorHandle &opHandle,
|
||||||
|
c10::Stack *stack) {
|
||||||
|
// Exclude recursive dispatch to this kernel.
|
||||||
|
c10::impl::ExcludeDispatchKeyGuard exclusion(kAcapDispatchKey);
|
||||||
|
|
||||||
|
auto current = getCurrent();
|
||||||
|
if (current) {
|
||||||
|
// Capture the dispatch.
|
||||||
|
std::stringstream sout;
|
||||||
|
sout << "CAPTURE: " << opHandle.schema() << "\n";
|
||||||
|
current->captureLog.push_back(sout.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto &dispatcher = c10::Dispatcher::singleton();
|
||||||
|
dispatcher.callBoxed(opHandle, stack);
|
||||||
|
}
|
||||||
|
|
||||||
|
TORCH_LIBRARY_IMPL(_, ACAP_DISPATCH_KEY, m) {
|
||||||
|
m.fallback(torch::CppFunction::makeFromBoxedFunction<
|
||||||
|
&AcapController::fallbackKernel>());
|
||||||
|
}
|
|
@ -0,0 +1,60 @@
|
||||||
|
//===- acap_dispatch.h ------------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// This file is licensed under a pytorch-style license
|
||||||
|
// See frontends/pytorch/LICENSE for license information.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// "ATen Capture" dispatcher: Defines facility for capturing programs by
|
||||||
|
// registering dispatch keys to intercept op execution.
|
||||||
|
// References:
|
||||||
|
// http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include <list>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include <pybind11/pybind11.h>
|
||||||
|
|
||||||
|
#include <ATen/core/dispatch/Dispatcher.h>
|
||||||
|
#include <c10/core/impl/LocalDispatchKeySet.h>
|
||||||
|
|
||||||
|
namespace torch_mlir {
|
||||||
|
|
||||||
|
/// Main entry point for managing device capture.
|
||||||
|
class AcapController : public std::enable_shared_from_this<AcapController> {
|
||||||
|
public:
|
||||||
|
AcapController() = default;
|
||||||
|
|
||||||
|
// Enter and exit the context manager.
|
||||||
|
pybind11::object contextEnter();
|
||||||
|
void contextExit(pybind11::object exc_type, pybind11::object exc_val,
|
||||||
|
pybind11::object exc_tb);
|
||||||
|
|
||||||
|
// Gets and clears the current debug log.
|
||||||
|
std::vector<std::string> getDebugLog();
|
||||||
|
|
||||||
|
// Returns the current AcapController (if it has been activated on this
|
||||||
|
// thread). Returns nullptr if none.
|
||||||
|
static std::shared_ptr<AcapController> getCurrent();
|
||||||
|
|
||||||
|
// The fallback boxed kernel that we route captured dispatches through.
|
||||||
|
static void fallbackKernel(const c10::OperatorHandle &opHandle,
|
||||||
|
c10::Stack *stack);
|
||||||
|
|
||||||
|
private:
|
||||||
|
struct Activation {
|
||||||
|
Activation(std::shared_ptr<AcapController> controller)
|
||||||
|
: controller(std::move(controller)) {}
|
||||||
|
std::shared_ptr<AcapController> controller;
|
||||||
|
// The RAII dispatch key guard is not movable, so heap allocate it. This is
|
||||||
|
// a bit outside of its intended design, but since this is thread local as
|
||||||
|
// well, it should be fine.
|
||||||
|
std::unique_ptr<c10::impl::IncludeDispatchKeyGuard> dispatchGuard;
|
||||||
|
};
|
||||||
|
// Gets the thread local stack of active acap controllers.
|
||||||
|
static std::list<Activation> &getThreadLocalActiveStack();
|
||||||
|
std::vector<std::string> captureLog;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace torch_mlir
|
|
@ -5,12 +5,14 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "../init_python_bindings.h"
|
|
||||||
|
|
||||||
#include <pybind11/pybind11.h>
|
|
||||||
|
|
||||||
#include <ATen/core/dispatch/Dispatcher.h>
|
#include <ATen/core/dispatch/Dispatcher.h>
|
||||||
|
#include <pybind11/pybind11.h>
|
||||||
|
#include <pybind11/stl.h>
|
||||||
|
|
||||||
|
#include "../init_python_bindings.h"
|
||||||
|
#include "acap_dispatch.h"
|
||||||
|
|
||||||
|
using namespace torch_mlir;
|
||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -123,6 +125,12 @@ py::list GetRegisteredOps() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void InitModuleBindings(py::module &m) {
|
void InitModuleBindings(py::module &m) {
|
||||||
|
py::class_<AcapController, std::shared_ptr<AcapController>>(m,
|
||||||
|
"AcapController")
|
||||||
|
.def(py::init<>())
|
||||||
|
.def("__enter__", &AcapController::contextEnter)
|
||||||
|
.def("__exit__", &AcapController::contextExit)
|
||||||
|
.def("get_debug_log", &AcapController::getDebugLog);
|
||||||
m.def("get_registered_ops", &GetRegisteredOps, kGetRegisteredOpsDocstring);
|
m.def("get_registered_ops", &GetRegisteredOps, kGetRegisteredOpsDocstring);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,24 @@
|
||||||
|
# -*- Python -*-
|
||||||
|
# This file is licensed under a pytorch-style license
|
||||||
|
# See frontends/pytorch/LICENSE for license information.
|
||||||
|
# RUN: python %s | FileCheck %s
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import _torch_mlir as m
|
||||||
|
|
||||||
|
t0 = torch.randn((4,4))
|
||||||
|
t1 = torch.randn((4,4))
|
||||||
|
|
||||||
|
with m.c10.AcapController() as c:
|
||||||
|
result = t0 + t1
|
||||||
|
|
||||||
|
result = result * t0
|
||||||
|
|
||||||
|
# NOTE: Ops involved with printing throw RuntimeError about calling a kernel
|
||||||
|
# from an unboxed API.
|
||||||
|
print(result)
|
||||||
|
|
||||||
|
# CHECK: CAPTURE: aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> (Tensor)
|
||||||
|
# CHECK-NOT: CAPTURE: aten::mul
|
||||||
|
log = c.get_debug_log()
|
||||||
|
for line in log: print(line)
|
|
@ -33,8 +33,11 @@ namespace pybind11 {
|
||||||
/// Raises a python exception with the given message.
|
/// Raises a python exception with the given message.
|
||||||
/// Correct usage:
|
/// Correct usage:
|
||||||
// throw RaiseValueError(PyExc_ValueError, "Foobar'd");
|
// throw RaiseValueError(PyExc_ValueError, "Foobar'd");
|
||||||
pybind11::error_already_set raisePyError(PyObject *exc_class,
|
inline pybind11::error_already_set raisePyError(PyObject *exc_class,
|
||||||
const char *message);
|
const char *message) {
|
||||||
|
PyErr_SetString(exc_class, message);
|
||||||
|
return pybind11::error_already_set();
|
||||||
|
}
|
||||||
|
|
||||||
/// Raises a value error with the given message.
|
/// Raises a value error with the given message.
|
||||||
/// Correct usage:
|
/// Correct usage:
|
||||||
|
|
|
@ -23,7 +23,6 @@ set(PYBIND_SOURCES
|
||||||
MlirIr.cpp
|
MlirIr.cpp
|
||||||
MlirPass.cpp
|
MlirPass.cpp
|
||||||
NpcompDialect.cpp
|
NpcompDialect.cpp
|
||||||
PybindUtils.cpp
|
|
||||||
CoreDialects.cpp
|
CoreDialects.cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1,19 +0,0 @@
|
||||||
//===- PybindUtils.cpp - Utilities for interop with python ----------------===//
|
|
||||||
//
|
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
|
||||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
||||||
//
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
#include "npcomp/Python/PybindUtils.h"
|
|
||||||
|
|
||||||
namespace pybind11 {
|
|
||||||
|
|
||||||
pybind11::error_already_set raisePyError(PyObject *exc_class,
|
|
||||||
const char *message) {
|
|
||||||
PyErr_SetString(exc_class, message);
|
|
||||||
return pybind11::error_already_set();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace pybind11
|
|
Loading…
Reference in New Issue