Fix dispatch of arange.

* Fixes #107
* I wouldn't say I love what had to be done here. Worth a conversation with the PT devs (probably as part of a rollup of a bunch of this stuff).
pull/115/head pytorch-1.3
Stella Laurenzo 2020-11-12 17:42:19 -08:00
parent b4c7ae1e0c
commit e359167562
3 changed files with 86 additions and 24 deletions

View File

@ -176,14 +176,9 @@ void AcapController::returns(std::vector<at::Tensor> tensors) {
for (auto &tensor : tensors) {
MlirValue v = funcBuilder->lookupTensor(tensor);
if (mlirValueIsNull(v)) {
// Finalize this instance so that everything that goes into printing an
// error message does not capture.
hasReturned = true;
// Exclude recursive dispatch in order to print tensor.
c10::impl::ExcludeDispatchKeyGuard exclusion(kAcapDispatchKey);
std::stringstream msg;
msg << "Cannot return a tensor that is not from the capture context";
throw std::invalid_argument(msg.str());
debugTrace(
"Return of imported-constant tensor (intentional memorization?)");
v = importTensorByValue(tensor);
}
returnsTypes.push_back(mlirValueGetType(v));
@ -217,12 +212,20 @@ void AcapController::verifyHasNotReturned() {
/* static */
void AcapController::fallbackKernel(const OperatorHandle &opHandle,
Stack *stack) {
auto redispatchCallback = [&]() {
// Exclude recursive dispatch to this kernel.
c10::impl::ExcludeDispatchKeyGuard exclusion(kAcapDispatchKey);
// Passthrough.
auto &dispatcher = c10::Dispatcher::singleton();
dispatcher.callBoxed(opHandle, stack);
};
auto current = getCurrentThreadAcapController();
if (!current) {
current->redispatch(opHandle, stack);
redispatchCallback();
return;
}
current->fallbackKernelImpl(opHandle, stack);
current->fallbackKernelImpl(opHandle, stack, redispatchCallback);
}
at::Tensor AcapController::convolutionKernel(
@ -406,25 +409,42 @@ at::Tensor &AcapController::copyUnderKernel(at::Tensor &self,
return result;
}
at::Tensor AcapController::arangeBackendSelectKernel(
at::Scalar end, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
c10::optional<bool> pin_memory) {
static c10::OperatorName opName{"aten::arange", ""};
auto &dispatcher = c10::Dispatcher::singleton();
auto opHandle = dispatcher.findOp(opName);
assert(opHandle && "could not find arange op");
// Exclude recursive calls.
c10::DispatchKeySet keySet{kAcapDispatchKey, kAcapGradDispatchKey};
c10::impl::ExcludeDispatchKeyGuard exclusion(keySet);
// Dispatching in this fashion replicates the exact way that PyTorch
// built-in handlers dispatch to BackendSelect kernels.
auto targetDk = c10::computeDispatchKey(dtype, layout, device);
auto opTyped = opHandle->typed<at::Tensor(
at::Scalar end, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
c10::optional<bool> pin_memory)>();
return opTyped.callWithDispatchKey(targetDk, end, dtype, layout, device,
pin_memory);
}
MlirLocation AcapController::getCurrentLocation() {
return mlirLocationUnknownGet(funcBuilder->getContext());
}
void AcapController::redispatch(const c10::OperatorHandle &opHandle,
c10::Stack *stack) {
// Exclude recursive dispatch to this kernel.
c10::impl::ExcludeDispatchKeyGuard exclusion(kAcapDispatchKey);
// Passthrough.
auto &dispatcher = c10::Dispatcher::singleton();
dispatcher.callBoxed(opHandle, stack);
}
void AcapController::fallbackKernelImpl(const OperatorHandle &opHandle,
Stack *stack) {
void AcapController::fallbackKernelImpl(
const OperatorHandle &opHandle, Stack *stack,
std::function<void()> redispatchCallback) {
verifyHasNotReturned();
if (isDebugTraceEnabled()) {
std::stringstream s;
s << "Fallback (boxed) dispatch: " << opHandle.schema();
s << "Fallback (boxed) dispatch: " << opHandle.schema()
<< " (stack size=" << stack->size() << ")";
debugTrace(s.str());
}
@ -454,7 +474,7 @@ void AcapController::fallbackKernelImpl(const OperatorHandle &opHandle,
}
// Invoke the original kernel.
redispatch(opHandle, stack);
redispatchCallback();
// Map returns to results.
size_t returnCount = schema.returns().size();
@ -642,6 +662,21 @@ MlirValue AcapController::importTensorByValue(at::Tensor tensor) {
return constArrayValue;
}
TORCH_LIBRARY_IMPL(aten, BackendSelect, m) {
// PyTorch logs a warning when kernels are overriden, which is unavoidable
// for factory-function BackendSelect kernels (there is not yet a "safe"
// override mechanism). So, just silence it. Any of them here are coded
// to be a superset of the default functionality, and there are only a few.
auto orig_log_level = FLAGS_caffe2_log_level;
FLAGS_caffe2_log_level = c10::GLOG_ERROR;
// Disable capture of arange: causes it to memorize the resulting tensor.
m.impl("arange", &AcapController::arangeBackendSelectKernel);
// Restore log level.
FLAGS_caffe2_log_level = orig_log_level;
}
TORCH_LIBRARY_IMPL(_, ACAP_DISPATCH_KEY, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<
&AcapController::fallbackKernel>());

View File

@ -71,6 +71,13 @@ public:
static at::Tensor &copyUnderKernel(at::Tensor &self, const at::Tensor &src,
bool non_blocking);
// Backend select kernel for arange factory function.
static at::Tensor
arangeBackendSelectKernel(at::Scalar end, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout,
c10::optional<at::Device> device,
c10::optional<bool> pin_memory);
private:
/// Builds a kernel call step by step.
class KernelCallBuilder {
@ -97,7 +104,8 @@ private:
MlirLocation getCurrentLocation();
void redispatch(const c10::OperatorHandle &opHandle, c10::Stack *stack);
void fallbackKernelImpl(const c10::OperatorHandle &opHandle,
c10::Stack *stack);
c10::Stack *stack,
std::function<void()> redispatchCallback);
MlirValue mapIValueToMlirValue(MlirLocation loc, const c10::IValue &ival);
MlirType mapIValueToMlirType(MlirLocation loc, const c10::IValue &ival);
/// Imports a tensor by value (as a constant), remembering the association.

View File

@ -0,0 +1,19 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import torch
import torch_mlir
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
torch_mlir.debug_trace_to_stderr()
mb = torch_mlir.ModuleBuilder()
with mb.capture_function("arange_test", []) as f:
x = torch.arange(10)
f.returns([x])
# CHECK: %[[CST:.*]] = constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xi64>
# CHECK: %[[R:.*]] = numpy.create_array_from_tensor %[[CST]]
# CHECK: return %[[R]]
mb.module.operation.print()