//===- acap_dispatch.cpp --------------------------------------------------===// // // This file is licensed under a pytorch-style license // See frontends/pytorch/LICENSE for license information. // //===----------------------------------------------------------------------===// #include "acap_dispatch.h" #include "debug.h" #include "mlir_utils.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "npcomp-c/Types.h" #include "npcomp/Python/PybindUtils.h" #include #include #include #include #include using namespace torch_mlir; namespace py = pybind11; using c10::FunctionSchema; using c10::IValue; using c10::OperatorHandle; using c10::Stack; // TODO: Private use dispatch keys are not made for real uses. Allocate a proper // dispatch key in upstream PyTorch (DispatchKey.h) prior to maturity. Note // that the TORCH_LIBRARY_* macros expand this by name and other APIs use its // enum value, so we define both. We can get rid of both once we have our // own key. // TODO: Ask the PT devs why conv is special and only shows up if dispatching // through the autograd keys. // https://github.com/llvm/mlir-npcomp/issues/86 #define ACAP_DISPATCH_KEY PrivateUse2 #define ACAP_GRAD_DISPATCH_KEY AutogradPrivateUse2 static c10::DispatchKey kAcapDispatchKey = c10::DispatchKey::ACAP_DISPATCH_KEY; static c10::DispatchKey kAcapGradDispatchKey = c10::DispatchKey::ACAP_GRAD_DISPATCH_KEY; AcapController::TracedKernelCallBuilder::TracedKernelCallBuilder( AcapController &parent, MlirContext context, MlirLocation loc, const c10::OperatorHandle &opHandle, c10::optional overrideKernelName) : KernelCallBuilder(context, loc, overrideKernelName ? *overrideKernelName : opHandle.operator_name().name, opHandle.schema()), parent(parent), opHandle(opHandle) {} void AcapController::TracedKernelCallBuilder::addOperand(const IValue &value) { MlirValue mlirValue = parent.mapIValueToMlirValue(loc, value); if (mlirValueIsNull(mlirValue)) { std::stringstream out; const std::string &kernelName = opHandle.operator_name().name; out << "Unsupported capture value passed to kernel '" << kernelName << "' (" << value.tagKind() << "): " << value; throw std::invalid_argument(out.str()); } KernelCallBuilder::addOperand(mlirValue); } void AcapController::TracedKernelCallBuilder::addResult(const IValue &value) { MlirType resultType = parent.mapIValueToMlirType(loc, value); if (mlirTypeIsNull(resultType)) { std::stringstream out; const std::string &kernelName = opHandle.operator_name().name; out << "Unsupported capture value returned from kernel '" << kernelName << "' (" << value.tagKind() << "): " << value; throw std::invalid_argument(out.str()); } if (value.isTensor()) { resultIndexToTensorMap.emplace_back(resultCount++, value.toTensor()); } KernelCallBuilder::addResultType(resultType); } MlirOperation AcapController::TracedKernelCallBuilder::create() { MlirOperation op = KernelCallBuilder::create(); parent.funcBuilder->getEntryBlockBuilder().insertBeforeTerminator(op); // Map result tensors. for (auto &it : resultIndexToTensorMap) { MlirValue result = mlirOperationGetResult(op, it.first); parent.funcBuilder->mapTensor(it.second, result); } return op; } std::list & AcapController::getThreadLocalActiveStack() { static thread_local std::list threadLocalActiveStack; return threadLocalActiveStack; } py::object AcapController::contextEnter() { auto &stack = getThreadLocalActiveStack(); stack.emplace_front(shared_from_this()); Activation ¤t = stack.front(); c10::DispatchKeySet keySet{kAcapDispatchKey, kAcapGradDispatchKey}; current.includeGuard = std::make_unique(keySet); return py::cast(this); } void AcapController::contextExit(py::object exc_type, py::object exc_val, py::object exc_tb) { auto &stack = getThreadLocalActiveStack(); if (stack.empty() || stack.front().controller.get() != this) { throw py::raisePyError(PyExc_RuntimeError, "Mismatched context manager __exit__"); } stack.pop_front(); if (!hasReturned) { returns({}); } } void AcapController::returns(std::vector tensors) { verifyHasNotReturned(); std::vector returnsTypes; std::vector returnsValues; for (auto &tensor : tensors) { MlirValue v = funcBuilder->lookupTensor(tensor); if (mlirValueIsNull(v)) { debugTrace( "Return of imported-constant tensor (intentional memorization?)"); v = importTensorByValue(tensor); } returnsTypes.push_back(mlirValueGetType(v)); returnsValues.push_back(v); } MlirLocation loc = getCurrentLocation(); OperationStateHolder s("std.return", loc); mlirOperationStateAddOperands(s, returnsValues.size(), returnsValues.data()); funcBuilder->getEntryBlockBuilder().insertBeforeTerminator( s.createOperation()); funcBuilder->rewriteFuncReturnTypes(returnsTypes); hasReturned = true; } std::shared_ptr AcapController::getCurrentThreadAcapController() { auto &stack = getThreadLocalActiveStack(); if (stack.empty()) return nullptr; return stack.front().controller; } void AcapController::verifyHasNotReturned() { if (hasReturned) { throw std::runtime_error( "Function has already returned. Cannot trace more operations."); } } /* static */ void AcapController::fallbackKernel(const OperatorHandle &opHandle, Stack *stack) { auto redispatchCallback = [&]() { // Exclude recursive dispatch to this kernel. c10::impl::ExcludeDispatchKeyGuard exclusion(kAcapDispatchKey); // Passthrough. auto &dispatcher = c10::Dispatcher::singleton(); dispatcher.callBoxed(opHandle, stack); }; auto current = getCurrentThreadAcapController(); if (!current) { redispatchCallback(); return; } current->fallbackKernelImpl(opHandle, stack, redispatchCallback); } at::Tensor AcapController::convolutionKernel( const at::Tensor &input, const at::Tensor &weight, const c10::optional &bias, const at::IntArrayRef stride, const at::IntArrayRef padding, const at::IntArrayRef dilation, const bool transposed, const at::IntArrayRef output_padding, const int64_t groups) { static c10::OperatorName opName{"aten::convolution", ""}; auto &dispatcher = c10::Dispatcher::singleton(); auto opHandle = dispatcher.findOp(opName); assert(opHandle && "could not find convolution op"); if (isDebugTraceEnabled()) { std::stringstream s; s << "Convolution (unboxed) dispatch: " << opHandle->schema(); debugTrace(s.str()); } auto opTyped = opHandle->typed &, const at::IntArrayRef, const at::IntArrayRef, const at::IntArrayRef, const bool, const at::IntArrayRef, const int64_t)>(); // Exclude recursive calls: convolution is completely emitted by this // kernel. c10::DispatchKeySet keySet{kAcapDispatchKey, kAcapGradDispatchKey}; c10::impl::ExcludeDispatchKeyGuard exclusion(keySet); auto current = getCurrentThreadAcapController(); if (!current) { return opTyped.callWithDispatchKey(c10::DispatchKey::AutogradOther, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups); } MlirContext context = current->funcBuilder->getContext(); MlirLocation loc = current->getCurrentLocation(); std::string kernelName{"aten::convolution"}; TracedKernelCallBuilder callBuilder{*current, context, loc, *opHandle}; callBuilder.addOperand(IValue(input)); callBuilder.addOperand(IValue(weight)); // This is really sad: instead of storing a none in the optional, it stores // an undefined tensor, which cannot convert to an IValue :( // TODO: File PyTorch bug. Perhaps this is why they don't support boxing // for it. IValue biasIValue; if (bias && bias->defined()) { biasIValue = IValue(bias); } else { biasIValue = IValue(c10::optional()); } callBuilder.addOperand(biasIValue); callBuilder.addOperand(IValue(stride)); callBuilder.addOperand(IValue(padding)); callBuilder.addOperand(IValue(dilation)); callBuilder.addOperand(IValue(transposed)); callBuilder.addOperand(IValue(output_padding)); callBuilder.addOperand(IValue(groups)); auto result = opTyped.callWithDispatchKey( c10::DispatchKey::AutogradOther, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups); callBuilder.addResult(result); callBuilder.create(); return result; } std::tuple AcapController::mklConvolutionBackward( const at::Tensor &input, const at::Tensor &grad_output, const at::Tensor &weight, const at::IntArrayRef padding, const at::IntArrayRef stride, const at::IntArrayRef dilation, const int64_t groups, std::array output_mask) { static c10::OperatorName opName{"aten::mkldnn_convolution_backward", ""}; auto &dispatcher = c10::Dispatcher::singleton(); auto opHandle = dispatcher.findOp(opName); assert(opHandle && "could not find mkldnn_convolution_backward op"); if (isDebugTraceEnabled()) { std::stringstream s; s << "mkldnn_convolution_backward dispatch: " << opHandle->schema(); debugTrace(s.str()); } auto opTyped = opHandle->typed( const at::Tensor &input, const at::Tensor &grad_output, const at::Tensor &weight, const at::IntArrayRef padding, const at::IntArrayRef stride, const at::IntArrayRef dilation, const int64_t groups, std::array output_mask)>(); // Exclude recursive calls: convolution is completely emitted by this // kernel. c10::DispatchKeySet keySet{kAcapDispatchKey, kAcapGradDispatchKey}; c10::impl::ExcludeDispatchKeyGuard exclusion(keySet); auto current = getCurrentThreadAcapController(); if (!current) { return opTyped.callWithDispatchKey(c10::DispatchKey::AutogradOther, input, grad_output, weight, padding, stride, dilation, groups, output_mask); } // Emit the call as if to aten::convolution_overridable, the generic, full // parameterized versions that backends are supposed to implement. // Requires some parameter swizzling. // It has the signature: // convolution_backward_overrideable(Tensor grad_output, Tensor input, // Tensor weight, int[] stride, int[] padding, int[] dilation, // bool transposed, int[] output_padding, int groups, // bool[3] output_mask) -> // (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) MlirContext context = current->funcBuilder->getContext(); MlirLocation loc = current->getCurrentLocation(); std::string kernelName{"aten::convolution_backward"}; static c10::OperatorName emitOpName{"aten::convolution_backward_overrideable", ""}; auto emitOpHandle = dispatcher.findOp(emitOpName); assert(emitOpHandle && "could not find convolution_backward_overrideable op"); TracedKernelCallBuilder callBuilder{*current, context, loc, *emitOpHandle, kernelName}; callBuilder.addOperand(IValue(grad_output)); callBuilder.addOperand(IValue(input)); callBuilder.addOperand(IValue(weight)); callBuilder.addOperand(IValue(stride)); callBuilder.addOperand(IValue(padding)); callBuilder.addOperand(IValue(dilation)); callBuilder.addOperand(IValue(false)); std::vector output_padding(padding.size()); // Not provided. callBuilder.addOperand(IValue(at::IntArrayRef(output_padding))); callBuilder.addOperand(IValue(groups)); callBuilder.addOperand(IValue(output_mask)); auto results = opTyped.callWithDispatchKey( c10::DispatchKey::AutogradCPU, input, grad_output, weight, padding, stride, dilation, groups, output_mask); callBuilder.addResult(std::get<0>(results)); callBuilder.addResult(std::get<1>(results)); callBuilder.addResult(std::get<2>(results)); callBuilder.create(); return results; } at::Tensor &AcapController::copyUnderKernel(at::Tensor &self, const at::Tensor &src, bool non_blocking) { static c10::OperatorName opName{"aten::copy_", ""}; auto &dispatcher = c10::Dispatcher::singleton(); auto opHandle = dispatcher.findOp(opName); assert(opHandle && "could not find copy_ op"); if (isDebugTraceEnabled()) { std::stringstream s; s << "copy_ dispatch: " << opHandle->schema(); debugTrace(s.str()); } auto opTyped = opHandle->typed(); // Exclude recursive calls. c10::DispatchKeySet keySet{kAcapDispatchKey, kAcapGradDispatchKey}; c10::impl::ExcludeDispatchKeyGuard exclusion(keySet); auto current = getCurrentThreadAcapController(); if (!current) { return opTyped.callWithDispatchKey(c10::DispatchKey::AutogradOther, self, src, non_blocking); } MlirContext context = current->funcBuilder->getContext(); MlirLocation loc = current->getCurrentLocation(); TracedKernelCallBuilder callBuilder{*current, context, loc, *opHandle}; callBuilder.addOperand(IValue(self)); callBuilder.addOperand(IValue(src)); auto &result = opTyped.callWithDispatchKey(c10::DispatchKey::CPU, self, src, non_blocking); callBuilder.addResult(result); callBuilder.create(); return result; } at::Tensor AcapController::arangeBackendSelectKernel( at::Scalar end, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory) { static c10::OperatorName opName{"aten::arange", ""}; auto &dispatcher = c10::Dispatcher::singleton(); auto opHandle = dispatcher.findOp(opName); assert(opHandle && "could not find arange op"); // Exclude recursive calls. c10::DispatchKeySet keySet{kAcapDispatchKey, kAcapGradDispatchKey}; c10::impl::ExcludeDispatchKeyGuard exclusion(keySet); // Dispatching in this fashion replicates the exact way that PyTorch // built-in handlers dispatch to BackendSelect kernels. auto targetDk = c10::computeDispatchKey(dtype, layout, device); auto opTyped = opHandle->typed dtype, c10::optional layout, c10::optional device, c10::optional pin_memory)>(); return opTyped.callWithDispatchKey(targetDk, end, dtype, layout, device, pin_memory); } MlirLocation AcapController::getCurrentLocation() { return mlirLocationUnknownGet(funcBuilder->getContext()); } void AcapController::fallbackKernelImpl( const OperatorHandle &opHandle, Stack *stack, std::function redispatchCallback) { verifyHasNotReturned(); if (isDebugTraceEnabled()) { std::stringstream s; s << "Fallback (boxed) dispatch: " << opHandle.schema() << " (stack size=" << stack->size() << ")"; debugTrace(s.str()); } // Exclude recursive dispatch to this kernel. c10::impl::ExcludeDispatchKeyGuard exclusion(kAcapDispatchKey); const FunctionSchema &schema = opHandle.schema(); // Check for unsupported. if (schema.is_vararg() || schema.is_varret()) { throw std::invalid_argument( "Cannot capture ops with variable arguments or returns"); } MlirContext context = funcBuilder->getContext(); MlirLocation loc = getCurrentLocation(); auto kernelName = schema.name(); TracedKernelCallBuilder callBuilder{*this, context, loc, opHandle}; // Map arguments to operands. // This must be accumulated into the OperationState prior to re-dispatch // since the stack is modified at that point. size_t argCount = schema.arguments().size(); assert(stack->size() >= argCount && "stack too short"); for (auto argIt = stack->end() - argCount; argIt != stack->end(); ++argIt) { callBuilder.addOperand(*argIt); } // Invoke the original kernel. redispatchCallback(); // Map returns to results. size_t returnCount = schema.returns().size(); assert(stack->size() >= returnCount && "stack too short"); for (auto returnIt = stack->end() - returnCount; returnIt != stack->end(); ++returnIt) { callBuilder.addResult(*returnIt); } callBuilder.create(); } MlirValue AcapController::mapIValueToMlirValue(MlirLocation loc, const IValue &ival) { if (ival.isScalar()) { return funcBuilder->getScalarConstant(loc, ival.toScalar()); } if (ival.isTensor()) { auto tensor = ival.toTensor(); if (!tensor.defined()) { // Optional tensors ("Tensor?" type) are represented as Tensor ivals // that are undefined. return funcBuilder->getNoneConstant(loc); } // Is it an already mapped tensor? MlirValue mappedValue = funcBuilder->lookupTensor(ival.toTensor()); if (!mlirValueIsNull(mappedValue)) { return mappedValue; } mappedValue = importTensorByValue(ival.toTensor()); assert(mappedValue.ptr); return mappedValue; } if (ival.isBool()) { // TODO: Switch to the numpy.bool type as that is a closer domain match. return funcBuilder->getBoolConstant(loc, ival.toBool()); } if (ival.isList()) { auto list = ival.toList(); std::vector elements; for (IValue element : list) { elements.push_back(mapIValueToMlirValue(loc, element)); } return funcBuilder->buildList(loc, elements); } if (ival.isNone()) { return funcBuilder->getNoneConstant(loc); } if (ival.isDevice()) { // TODO: Do we need to model/preserve device? Currently, just None'ing // it out. return funcBuilder->getNoneConstant(loc); } return {nullptr}; // TODO: Implement mappings for the whole set (relevant to this use case): // _(Tensor) // _(Double) // _(Int) // _(Tuple) // _(String) // _(Blob) // _(GenericList) // _(GenericDict) // _(Future) // _(Device) // _(Object) // _(PyObject) // _(Uninitialized) // _(Capsule) // _(RRef) // _(Generator) } MlirType AcapController::mapIValueToMlirType(MlirLocation loc, const IValue &ival) { if (ival.isScalar()) { return typeMapper.mapFromTorchScalarType(ival.toScalar().type()); } if (ival.isTensor()) { return typeMapper.forwardTensorToType(ival.toTensor()); } if (ival.isBool()) { // TODO: Switch to the numpy.bool type as that is a closer domain match. return mlirIntegerTypeGet(funcBuilder->getContext(), 1); } if (ival.isList()) { return npcompListTypeGet(funcBuilder->getContext()); } if (ival.isNone()) { return npcompNoneTypeGet(funcBuilder->getContext()); } if (ival.isDevice()) { return npcompNoneTypeGet(funcBuilder->getContext()); } return {nullptr}; } MlirValue AcapController::importTensorByValue(at::Tensor tensor) { auto loc = getCurrentLocation(); MlirAttribute valueAttribute = converTensorToMlirElementsAttr(tensor, loc); MlirValue constTensorValue = funcBuilder->getGeneralConstant(loc, valueAttribute); // Create an array from the tensor constant via the // numpy.create_array_from_tensor op. MlirType constArrayType = npcompNdArrayTypeGetFromShaped(mlirAttributeGetType(valueAttribute)); MlirOperationState state = mlirOperationStateGet( toMlirStringRef("numpy.create_array_from_tensor"), loc); mlirOperationStateAddOperands(&state, 1, &constTensorValue); mlirOperationStateAddResults(&state, 1, &constArrayType); MlirOperation constArrayOp = mlirOperationCreate(&state); funcBuilder->getEntryBlockBuilder().insertBeforeTerminator(constArrayOp); MlirValue constArrayValue = mlirOperationGetResult(constArrayOp, 0); funcBuilder->mapTensor(tensor, constArrayValue); return constArrayValue; } TORCH_LIBRARY_IMPL(aten, BackendSelect, m) { // PyTorch logs a warning when kernels are overriden, which is unavoidable // for factory-function BackendSelect kernels (there is not yet a "safe" // override mechanism). So, just silence it. Any of them here are coded // to be a superset of the default functionality, and there are only a few. auto orig_log_level = FLAGS_caffe2_log_level; FLAGS_caffe2_log_level = c10::GLOG_ERROR; // Disable capture of arange: causes it to memorize the resulting tensor. m.impl("arange", &AcapController::arangeBackendSelectKernel); // Restore log level. FLAGS_caffe2_log_level = orig_log_level; } TORCH_LIBRARY_IMPL(_, ACAP_DISPATCH_KEY, m) { m.fallback(torch::CppFunction::makeFromBoxedFunction< &AcapController::fallbackKernel>()); } TORCH_LIBRARY_IMPL(aten, ACAP_DISPATCH_KEY, m) { m.impl("copy_", &AcapController::copyUnderKernel); } TORCH_LIBRARY_IMPL(aten, ACAP_GRAD_DISPATCH_KEY, m) { // The at::convolution op is special in several ways. First, it presently // does not support boxing, so all of the usual fanciness does not apply // and it cannot be intercepted by generic fallthroughs, which is what // would usually allow us to avoid intercepting it at the gradient phase. // Second, the default implementation (see // aten/src/ATen/native/Convolution.cpp) is very switchy based on hard-coded // assumptions about device type. If we do nothing here, we will at best // intercept an mkldnn_convolution, cudnn_convolution, etc on the backend // dispatch keys. Non standard backends that don't have these switches // just route to aten::convolution_overrideable (see the else in // aten::convolution) as a convenience, but that is mostly a pass-through // (except for 3d convolutions which contain a trailing squeeze that needs // special casing). Therefore, we just intercept the aten::convolution op, // record it specially, and then mask ourselves off and ask the CPU backend // to invoke it. Not awesome. // Presumably this is on someone's list to adapt to the dispatch machinery // in a more appropriate way, but as the core of what the framework is, // perhaps people are reticent to touch it. Maybe someday, this can go away. m.impl_UNBOXED("convolution", &AcapController::convolutionKernel); // Sadly, there is no easy intercept point for the backwards convolution // kernel which allows for chaining to an existing backend. And convolution // is exceptionally special cased in this way, moreso than other ops. // The "solution" is to intercept the backend specific backward convolution // ops, emit it with the signature of the more generic // "convolution_backward_overrideable" op, which is available for generic // backends, and twiddle the parameters needed to get back to that form. // For MKL, which is effectively the CPU implementation, this just means that // some parameters are swapped and the full generality is not supported. // The "right" answer at some point is probably just to implement a // convolution kernel that fully does what is needed and delegates to an // appropriate implementation behind the scenes. m.impl("mkldnn_convolution_backward", AcapController::mklConvolutionBackward); }