//===- acap_dispatch.h ------------------------------------------*- C++ -*-===//
//
// This file is licensed under a pytorch-style license
// See frontends/pytorch/LICENSE for license information.
//
//===----------------------------------------------------------------------===//
// "ATen Capture" dispatcher: Defines facility for capturing programs by
// registering dispatch keys to intercept op execution.
// References:
//   http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/
//
//===----------------------------------------------------------------------===//

#ifndef NPCOMP_FRONTENDS_PYTORCH_CSRC_BUILDER_ACAP_DISPATCH_H
#define NPCOMP_FRONTENDS_PYTORCH_CSRC_BUILDER_ACAP_DISPATCH_H

#include <list>
#include <memory>

#include "../pybind.h"

#include "func_builder.h"

#include "mlir-c/IR.h"

#include <ATen/Tensor.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/core/impl/LocalDispatchKeySet.h>

namespace torch_mlir {

/// Main entry point for managing device capture.
class AcapController : public std::enable_shared_from_this<AcapController> {
public:
  AcapController(TypeMapper &typeMapper,
                 std::unique_ptr<FuncBuilder> funcBuilder)
      : typeMapper(typeMapper), funcBuilder(std::move(funcBuilder)) {}

  // Enter and exit the context manager.
  pybind11::object contextEnter();
  void contextExit(pybind11::object exc_type, pybind11::object exc_val,
                   pybind11::object exc_tb);

  // Terminates capture and returns tensors from the function.
  void returns(std::vector<at::Tensor> tensors);

  // Returns the current AcapController if one has been activated on this
  // thread; returns nullptr if none is active on the current thread.
  static std::shared_ptr<AcapController> getCurrentThreadAcapController();

  // The fallback boxed kernel that we route captured dispatches through.
  static void fallbackKernel(const c10::OperatorHandle &opHandle,
                             c10::Stack *stack);

  // Kernel implementation for the boxing-incompatible convolution kernel.
  static at::Tensor
  convolutionKernel(const at::Tensor &input, const at::Tensor &weight,
                    const c10::optional<at::Tensor> &bias,
                    const at::IntArrayRef stride, const at::IntArrayRef padding,
                    const at::IntArrayRef dilation, const bool transposed,
                    const at::IntArrayRef output_padding, const int64_t groups);

  // Kernel implementation for the boxing-incompatible convolution backward
  // kernel.
  static std::tuple<at::Tensor, at::Tensor, at::Tensor> mklConvolutionBackward(
      const at::Tensor &input, const at::Tensor &grad_output,
      const at::Tensor &weight, const at::IntArrayRef padding,
      const at::IntArrayRef stride, const at::IntArrayRef dilation,
      const int64_t groups, std::array<bool, 3> output_mask);

  // Implementation for the aten::copy_ kernel.
  static at::Tensor &copyUnderKernel(at::Tensor &self, const at::Tensor &src,
                                     bool non_blocking);

  // Backend select kernel for the arange factory function.
  static at::Tensor arangeBackendSelectKernel(
      const at::Scalar &end, c10::optional<at::ScalarType> dtype,
      c10::optional<at::Layout> layout, c10::optional<at::Device> device,
      c10::optional<bool> pin_memory);

private:
  /// Builds a kernel call step by step.
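  /// A rough usage sketch (hypothetical call-site names; the real call sites
  /// live in the corresponding .cpp and may differ): operands are added from
  /// the dispatch stack, results are recorded after the op actually runs, and
  /// the traced MLIR operation is then emitted.
  ///
  ///   TracedKernelCallBuilder callBuilder(*this, context, loc, opHandle);
  ///   for (const c10::IValue &input : inputs)
  ///     callBuilder.addOperand(input);
  ///   // ... redispatch so the real kernel produces outputs ...
  ///   for (const c10::IValue &output : outputs)
  ///     callBuilder.addResult(output);
  ///   MlirOperation op = callBuilder.create();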
  class TracedKernelCallBuilder : private KernelCallBuilder {
  public:
    TracedKernelCallBuilder(
        AcapController &parent, MlirContext context, MlirLocation loc,
        const c10::OperatorHandle &opHandle,
        c10::optional<std::string> overrideKernelName = c10::nullopt);
    void addOperand(const c10::IValue &value);
    void addResult(const c10::IValue &result);
    MlirOperation create();

  private:
    AcapController &parent;
    const c10::OperatorHandle &opHandle;
    int resultCount = 0;
    std::vector<std::pair<int, at::Tensor>> resultIndexToTensorMap;
  };

  MlirLocation getCurrentLocation();
  void redispatch(const c10::OperatorHandle &opHandle, c10::Stack *stack);
  void fallbackKernelImpl(const c10::OperatorHandle &opHandle,
                          c10::Stack *stack,
                          std::function<void()> redispatchCallback);
  MlirValue mapIValueToMlirValue(MlirLocation loc, const c10::IValue &ival);
  MlirType mapIValueToMlirType(MlirLocation loc, const c10::IValue &ival);
  /// Imports a tensor by value (as a constant), remembering the association.
  MlirValue importTensorByValue(at::Tensor tensor);
  void verifyHasNotReturned();

  struct Activation {
    Activation(std::shared_ptr<AcapController> controller)
        : controller(std::move(controller)) {}
    std::shared_ptr<AcapController> controller;
    // The RAII dispatch key guard is not movable, so heap allocate it. This
    // is a bit outside of its intended design, but since this is thread
    // local as well, it should be fine.
    std::unique_ptr<c10::impl::IncludeDispatchKeyGuard> includeGuard;
    std::unique_ptr<c10::impl::ExcludeDispatchKeyGuard> excludeGuard;
  };
  // Gets the thread local stack of active acap controllers.
  static std::list<Activation> &getThreadLocalActiveStack();

  TypeMapper &typeMapper;
  std::unique_ptr<FuncBuilder> funcBuilder;
  bool hasReturned = false;
};

} // namespace torch_mlir

#endif // NPCOMP_FRONTENDS_PYTORCH_CSRC_BUILDER_ACAP_DISPATCH_H
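// Note on wiring: a minimal sketch of how the boxed fallback above can be
// attached to a dispatch key so that captured dispatches are routed through
// it. This is an assumption for illustration only; the authoritative
// registration lives in the corresponding .cpp, and the actual dispatch key
// may differ.
//
//   TORCH_LIBRARY_IMPL(_, PrivateUse2, m) {
//     m.fallback(torch::CppFunction::makeFromBoxedFunction<
//                &torch_mlir::AcapController::fallbackKernel>());
//   }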