mirror of https://github.com/llvm/torch-mlir
Fix dispatch of arange.
* Fixes #107
* I wouldn't say I love what had to be done here. Worth a conversation with the PT devs (probably as part of a rollup of a bunch of this stuff).
parent b4c7ae1e0c
commit e359167562

@@ -176,14 +176,9 @@ void AcapController::returns(std::vector<at::Tensor> tensors) {
   for (auto &tensor : tensors) {
     MlirValue v = funcBuilder->lookupTensor(tensor);
     if (mlirValueIsNull(v)) {
-      // Finalize this instance so that everything that goes into printing an
-      // error message does not capture.
-      hasReturned = true;
-      // Exclude recursive dispatch in order to print tensor.
-      c10::impl::ExcludeDispatchKeyGuard exclusion(kAcapDispatchKey);
-      std::stringstream msg;
-      msg << "Cannot return a tensor that is not from the capture context";
-      throw std::invalid_argument(msg.str());
+      debugTrace(
+          "Return of imported-constant tensor (intentional memorization?)");
+      v = importTensorByValue(tensor);
     }

     returnsTypes.push_back(mlirValueGetType(v));

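Note: the hunk above changes AcapController::returns() to fall back to importing an unknown tensor by value (as a constant) instead of throwing. A minimal standalone sketch of that lookup-or-import pattern, with hypothetical stand-ins (a std::map for the builder's tensor map, strings for MLIR values):

#include <cstdio>
#include <map>
#include <string>

using Value = std::string;
static std::map<int, Value> capturedValues;  // stand-in for funcBuilder's tensor map

// Materialize a tensor as a constant and remember the association.
static Value importByValue(int tensorId) {
  Value v = "const#" + std::to_string(tensorId);
  capturedValues[tensorId] = v;
  return v;
}

// New behavior: fall back to import-by-value instead of throwing.
static Value lookupOrImport(int tensorId) {
  auto it = capturedValues.find(tensorId);
  return it != capturedValues.end() ? it->second : importByValue(tensorId);
}

int main() {
  capturedValues[1] = "traced#1";
  std::printf("%s\n", lookupOrImport(1).c_str());  // traced#1 (was captured)
  std::printf("%s\n", lookupOrImport(2).c_str());  // const#2 (imported constant)
  return 0;
}
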
@@ -217,12 +212,20 @@ void AcapController::verifyHasNotReturned() {
 /* static */
 void AcapController::fallbackKernel(const OperatorHandle &opHandle,
                                     Stack *stack) {
+  auto redispatchCallback = [&]() {
+    // Exclude recursive dispatch to this kernel.
+    c10::impl::ExcludeDispatchKeyGuard exclusion(kAcapDispatchKey);
+    // Passthrough.
+    auto &dispatcher = c10::Dispatcher::singleton();
+    dispatcher.callBoxed(opHandle, stack);
+  };
+
   auto current = getCurrentThreadAcapController();
   if (!current) {
-    current->redispatch(opHandle, stack);
+    redispatchCallback();
     return;
   }
-  current->fallbackKernelImpl(opHandle, stack);
+  current->fallbackKernelImpl(opHandle, stack, redispatchCallback);
 }

 at::Tensor AcapController::convolutionKernel(

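Note: the refactor above packages "reach the real kernel" as a std::function so the capture logic can run work before and after the redispatch. A minimal sketch of that callback shape, with hypothetical names (nothing here is the c10 API):

#include <cstdio>
#include <functional>

// Stand-in for AcapController::fallbackKernelImpl: do capture work, let the
// callback reach the original kernel, then map the results back.
static void fallbackImpl(std::function<void()> redispatchCallback) {
  std::printf("map operands to MLIR values\n");
  redispatchCallback();  // invoke the real kernel via the dispatcher
  std::printf("map results back from the stack\n");
}

int main() {
  int opId = 42;  // stand-in for (opHandle, stack) captured by reference
  fallbackImpl([&]() { std::printf("redispatch op %d\n", opId); });
  return 0;
}
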
@@ -406,25 +409,42 @@ at::Tensor &AcapController::copyUnderKernel(at::Tensor &self,
   return result;
 }

+at::Tensor AcapController::arangeBackendSelectKernel(
+    at::Scalar end, c10::optional<at::ScalarType> dtype,
+    c10::optional<at::Layout> layout, c10::optional<at::Device> device,
+    c10::optional<bool> pin_memory) {
+  static c10::OperatorName opName{"aten::arange", ""};
+  auto &dispatcher = c10::Dispatcher::singleton();
+  auto opHandle = dispatcher.findOp(opName);
+  assert(opHandle && "could not find arange op");
+
+  // Exclude recursive calls.
+  c10::DispatchKeySet keySet{kAcapDispatchKey, kAcapGradDispatchKey};
+  c10::impl::ExcludeDispatchKeyGuard exclusion(keySet);
+
+  // Dispatching in this fashion replicates the exact way that PyTorch
+  // built-in handlers dispatch to BackendSelect kernels.
+  auto targetDk = c10::computeDispatchKey(dtype, layout, device);
+  auto opTyped = opHandle->typed<at::Tensor(
+      at::Scalar end, c10::optional<at::ScalarType> dtype,
+      c10::optional<at::Layout> layout, c10::optional<at::Device> device,
+      c10::optional<bool> pin_memory)>();
+  return opTyped.callWithDispatchKey(targetDk, end, dtype, layout, device,
+                                     pin_memory);
+}
+
 MlirLocation AcapController::getCurrentLocation() {
   return mlirLocationUnknownGet(funcBuilder->getContext());
 }

 void AcapController::redispatch(const c10::OperatorHandle &opHandle,
                                 c10::Stack *stack) {
   // Exclude recursive dispatch to this kernel.
   c10::impl::ExcludeDispatchKeyGuard exclusion(kAcapDispatchKey);
   // Passthrough.
   auto &dispatcher = c10::Dispatcher::singleton();
   dispatcher.callBoxed(opHandle, stack);
 }

-void AcapController::fallbackKernelImpl(const OperatorHandle &opHandle,
-                                        Stack *stack) {
+void AcapController::fallbackKernelImpl(
+    const OperatorHandle &opHandle, Stack *stack,
+    std::function<void()> redispatchCallback) {
   verifyHasNotReturned();
   if (isDebugTraceEnabled()) {
     std::stringstream s;
-    s << "Fallback (boxed) dispatch: " << opHandle.schema();
+    s << "Fallback (boxed) dispatch: " << opHandle.schema()
+      << " (stack size=" << stack->size() << ")";
     debugTrace(s.str());
   }

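Note: arangeBackendSelectKernel mirrors what PyTorch's generated BackendSelect kernels do: derive a dispatch key from the factory function's (dtype, layout, device) options, then invoke the operator directly at that key. A toy sketch of the routing idea in plain C++ (a stand-in table, not the ATen dispatcher):

#include <cstdio>
#include <map>

enum class Key { CPU, CUDA };
using ArangeFn = void (*)(long end);

// Stand-in for the dispatcher's per-key kernel table.
static const std::map<Key, ArangeFn> kernelTable = {
    {Key::CPU, [](long end) { std::printf("cpu arange(0..%ld)\n", end); }},
    {Key::CUDA, [](long end) { std::printf("cuda arange(0..%ld)\n", end); }},
};

// Stand-in for c10::computeDispatchKey(dtype, layout, device).
static Key computeKey(bool onCuda) { return onCuda ? Key::CUDA : Key::CPU; }

int main() {
  Key target = computeKey(/*onCuda=*/false);
  kernelTable.at(target)(10);  // analogous to callWithDispatchKey(targetDk, ...)
  return 0;
}
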
@@ -454,7 +474,7 @@ void AcapController::fallbackKernelImpl(const OperatorHandle &opHandle,
   }

   // Invoke the original kernel.
-  redispatch(opHandle, stack);
+  redispatchCallback();

   // Map returns to results.
   size_t returnCount = schema.returns().size();

@@ -642,6 +662,21 @@ MlirValue AcapController::importTensorByValue(at::Tensor tensor) {
   return constArrayValue;
 }

+TORCH_LIBRARY_IMPL(aten, BackendSelect, m) {
+  // PyTorch logs a warning when kernels are overridden, which is unavoidable
+  // for factory-function BackendSelect kernels (there is not yet a "safe"
+  // override mechanism). So, just silence it. Any of them here are coded
+  // to be a superset of the default functionality, and there are only a few.
+  auto orig_log_level = FLAGS_caffe2_log_level;
+  FLAGS_caffe2_log_level = c10::GLOG_ERROR;
+
+  // Disable capture of arange: causes it to memorize the resulting tensor.
+  m.impl("arange", &AcapController::arangeBackendSelectKernel);
+
+  // Restore log level.
+  FLAGS_caffe2_log_level = orig_log_level;
+}
+
 TORCH_LIBRARY_IMPL(_, ACAP_DISPATCH_KEY, m) {
   m.fallback(torch::CppFunction::makeFromBoxedFunction<
              &AcapController::fallbackKernel>());

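Note: the registration block above saves and restores FLAGS_caffe2_log_level by hand. One hedged alternative, not part of this change, is an RAII guard; sketched here against a stand-in global:

#include <cstdio>

static int g_log_level = 2;  // stand-in for FLAGS_caffe2_log_level

// Restores the previous level when the scope exits, even on early return.
struct ScopedLogLevel {
  int saved;
  explicit ScopedLogLevel(int level) : saved(g_log_level) { g_log_level = level; }
  ~ScopedLogLevel() { g_log_level = saved; }
};

int main() {
  {
    ScopedLogLevel quiet(/*level=*/4);  // e.g. errors only while registering
    std::printf("registering kernels at level %d\n", g_log_level);
  }
  std::printf("restored level %d\n", g_log_level);
  return 0;
}
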
@@ -71,6 +71,13 @@ public:
   static at::Tensor &copyUnderKernel(at::Tensor &self, const at::Tensor &src,
                                      bool non_blocking);

+  // Backend select kernel for arange factory function.
+  static at::Tensor
+  arangeBackendSelectKernel(at::Scalar end, c10::optional<at::ScalarType> dtype,
+                            c10::optional<at::Layout> layout,
+                            c10::optional<at::Device> device,
+                            c10::optional<bool> pin_memory);
+
 private:
   /// Builds a kernel call step by step.
   class KernelCallBuilder {

@@ -97,7 +104,8 @@ private:
   MlirLocation getCurrentLocation();
   void redispatch(const c10::OperatorHandle &opHandle, c10::Stack *stack);
   void fallbackKernelImpl(const c10::OperatorHandle &opHandle,
-                          c10::Stack *stack);
+                          c10::Stack *stack,
+                          std::function<void()> redispatchCallback);
   MlirValue mapIValueToMlirValue(MlirLocation loc, const c10::IValue &ival);
   MlirType mapIValueToMlirType(MlirLocation loc, const c10::IValue &ival);
   /// Imports a tensor by value (as a constant), remembering the association.

@@ -0,0 +1,19 @@
+# -*- Python -*-
+# This file is licensed under a pytorch-style license
+# See frontends/pytorch/LICENSE for license information.
+
+import torch
+import torch_mlir
+
+# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
+torch_mlir.debug_trace_to_stderr()
+
+mb = torch_mlir.ModuleBuilder()
+with mb.capture_function("arange_test", []) as f:
+  x = torch.arange(10)
+  f.returns([x])
+
+# CHECK: %[[CST:.*]] = constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xi64>
+# CHECK: %[[R:.*]] = numpy.create_array_from_tensor %[[CST]]
+# CHECK: return %[[R]]
+mb.module.operation.print()