//===- acap_dispatch.h ------------------------------------------*- C++ -*-===// // // This file is licensed under a pytorch-style license // See frontends/pytorch/LICENSE for license information. // //===----------------------------------------------------------------------===// // "ATen Capture" dispatcher: Defines facility for capturing programs by // registering dispatch keys to intercept op execution. // References: // http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/ // //===----------------------------------------------------------------------===// #include #include #include #include #include namespace torch_mlir { /// Main entry point for managing device capture. class AcapController : public std::enable_shared_from_this { public: AcapController() = default; // Enter and exit the context manager. pybind11::object contextEnter(); void contextExit(pybind11::object exc_type, pybind11::object exc_val, pybind11::object exc_tb); // Gets and clears the current debug log. std::vector getDebugLog(); // Returns the current AcapController (if it has been activated on this // thread. Returns nullptr if none. static std::shared_ptr getCurrent(); // The fallback boxed kernel that we route captured dispatches through. static void fallbackKernel(const c10::OperatorHandle &opHandle, c10::Stack *stack); private: struct Activation { Activation(std::shared_ptr controller) : controller(std::move(controller)) {} std::shared_ptr controller; // The RAII dispatch key guard is not movable, so heap allocate it. This is // a bit outside of its intended design, but since this is thread local as // well, it should be fine. std::unique_ptr dispatchGuard; }; // Gets the thread local stack of active acap controllers. static std::list &getThreadLocalActiveStack(); std::vector captureLog; }; } // namespace torch_mlir