Fix dispatch of arange.

* Fixes #107
* I wouldn't say I love what had to be done here. Worth a conversation with
  the PT devs (probably as part of a rollup of a bunch of this stuff).

parent b4c7ae1e0c
commit e359167562
branches: pull/115/head, pytorch-1.3
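Background on the fix, condensed from the new kernel in the diff below (the
function name here is illustrative; the calls mirror the patch and assume the
c10 dispatcher API of the PyTorch version this branch targets): arange is a
factory function, so its dispatch goes through PyTorch's BackendSelect key and
never reaches the ACAP boxed fallback. The override computes the real backend
dispatch key from the factory arguments and redispatches explicitly, the same
way PyTorch's built-in BackendSelect handlers do (the actual patch also
excludes the ACAP keys first to prevent recursion):

#include <cassert>
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

at::Tensor arangeViaBackendSelect(at::Scalar end,
                                  c10::optional<at::ScalarType> dtype,
                                  c10::optional<at::Layout> layout,
                                  c10::optional<at::Device> device,
                                  c10::optional<bool> pin_memory) {
  // Look up the boxed aten::arange operator by name.
  static c10::OperatorName opName{"aten::arange", ""};
  auto opHandle = c10::Dispatcher::singleton().findOp(opName);
  assert(opHandle && "could not find arange op");
  auto opTyped = opHandle->typed<at::Tensor(
      at::Scalar, c10::optional<at::ScalarType>, c10::optional<at::Layout>,
      c10::optional<at::Device>, c10::optional<bool>)>();
  // Replicates how PyTorch's built-in handlers pick a BackendSelect target:
  // the key is derived from the tensor options, not from a tensor argument.
  auto targetDk = c10::computeDispatchKey(dtype, layout, device);
  return opTyped.callWithDispatchKey(targetDk, end, dtype, layout, device,
                                     pin_memory);
}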
AcapController implementation (.cpp). In returns(), a tensor that was not
produced under the capture context is now imported by value as a constant
(with a debug-trace note) instead of raising an error:

@@ -176,14 +176,9 @@ void AcapController::returns(std::vector<at::Tensor> tensors) {
   for (auto &tensor : tensors) {
     MlirValue v = funcBuilder->lookupTensor(tensor);
     if (mlirValueIsNull(v)) {
-      // Finalize this instance so that everything that goes into printing an
-      // error message does not capture.
-      hasReturned = true;
-      // Exclude recursive dispatch in order to print tensor.
-      c10::impl::ExcludeDispatchKeyGuard exclusion(kAcapDispatchKey);
-      std::stringstream msg;
-      msg << "Cannot return a tensor that is not from the capture context";
-      throw std::invalid_argument(msg.str());
+      debugTrace(
+          "Return of imported-constant tensor (intentional memorization?)");
+      v = importTensorByValue(tensor);
     }

     returnsTypes.push_back(mlirValueGetType(v));
fallbackKernel() now packages the guarded passthrough redispatch in a
callback, so the same code serves both the "no active controller" early-out
(which previously invoked redispatch() through a null current pointer) and the
capture path, which receives it via fallbackKernelImpl():

@@ -217,12 +212,20 @@ void AcapController::verifyHasNotReturned() {
 /* static */
 void AcapController::fallbackKernel(const OperatorHandle &opHandle,
                                     Stack *stack) {
+  auto redispatchCallback = [&]() {
+    // Exclude recursive dispatch to this kernel.
+    c10::impl::ExcludeDispatchKeyGuard exclusion(kAcapDispatchKey);
+    // Passthrough.
+    auto &dispatcher = c10::Dispatcher::singleton();
+    dispatcher.callBoxed(opHandle, stack);
+  };
+
   auto current = getCurrentThreadAcapController();
   if (!current) {
-    current->redispatch(opHandle, stack);
+    redispatchCallback();
     return;
   }
-  current->fallbackKernelImpl(opHandle, stack);
+  current->fallbackKernelImpl(opHandle, stack, redispatchCallback);
 }

 at::Tensor AcapController::convolutionKernel(
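A minimal, self-contained sketch of the callback refactor above (hypothetical
names, standard C++ only): packaging the guarded passthrough as a callable
lets the early-out and the capture path share exactly the same redispatch
code.

#include <functional>
#include <iostream>

// Stand-in for AcapController::fallbackKernelImpl: records the op, then
// invokes the original kernel via the callback supplied by the caller.
static void fallbackImpl(const std::function<void()> &redispatch) {
  std::cout << "capture op into trace\n"; // ... build the MLIR op here
  redispatch();                           // invoke the original kernel
  std::cout << "map results back\n";      // ... map returns to results
}

int main() {
  bool haveController = false; // stand-in for getCurrentThreadAcapController()
  auto redispatchCallback = [] {
    // In the real code this scope holds the ExcludeDispatchKeyGuard and
    // calls Dispatcher::callBoxed.
    std::cout << "redispatched\n";
  };
  if (!haveController) {
    redispatchCallback(); // passthrough without capturing
    return 0;
  }
  fallbackImpl(redispatchCallback);
  return 0;
}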
The redispatch() definition is folded into the callback above,
fallbackKernelImpl() gains the callback parameter, the new BackendSelect
kernel for arange is added, and the debug trace now reports the stack size:

@@ -406,25 +409,42 @@ at::Tensor &AcapController::copyUnderKernel(at::Tensor &self,
   return result;
 }

+at::Tensor AcapController::arangeBackendSelectKernel(
+    at::Scalar end, c10::optional<at::ScalarType> dtype,
+    c10::optional<at::Layout> layout, c10::optional<at::Device> device,
+    c10::optional<bool> pin_memory) {
+  static c10::OperatorName opName{"aten::arange", ""};
+  auto &dispatcher = c10::Dispatcher::singleton();
+  auto opHandle = dispatcher.findOp(opName);
+  assert(opHandle && "could not find arange op");
+
+  // Exclude recursive calls.
+  c10::DispatchKeySet keySet{kAcapDispatchKey, kAcapGradDispatchKey};
+  c10::impl::ExcludeDispatchKeyGuard exclusion(keySet);
+
+  // Dispatching in this fashion replicates the exact way that PyTorch
+  // built-in handlers dispatch to BackendSelect kernels.
+  auto targetDk = c10::computeDispatchKey(dtype, layout, device);
+  auto opTyped = opHandle->typed<at::Tensor(
+      at::Scalar end, c10::optional<at::ScalarType> dtype,
+      c10::optional<at::Layout> layout, c10::optional<at::Device> device,
+      c10::optional<bool> pin_memory)>();
+  return opTyped.callWithDispatchKey(targetDk, end, dtype, layout, device,
+                                     pin_memory);
+}
+
 MlirLocation AcapController::getCurrentLocation() {
   return mlirLocationUnknownGet(funcBuilder->getContext());
 }

-void AcapController::redispatch(const c10::OperatorHandle &opHandle,
-                                c10::Stack *stack) {
-  // Exclude recursive dispatch to this kernel.
-  c10::impl::ExcludeDispatchKeyGuard exclusion(kAcapDispatchKey);
-  // Passthrough.
-  auto &dispatcher = c10::Dispatcher::singleton();
-  dispatcher.callBoxed(opHandle, stack);
-}
-
-void AcapController::fallbackKernelImpl(const OperatorHandle &opHandle,
-                                        Stack *stack) {
+void AcapController::fallbackKernelImpl(
+    const OperatorHandle &opHandle, Stack *stack,
+    std::function<void()> redispatchCallback) {
   verifyHasNotReturned();
   if (isDebugTraceEnabled()) {
     std::stringstream s;
-    s << "Fallback (boxed) dispatch: " << opHandle.schema();
+    s << "Fallback (boxed) dispatch: " << opHandle.schema()
+      << " (stack size=" << stack->size() << ")";
     debugTrace(s.str());
   }

@@ -454,7 +474,7 @@ void AcapController::fallbackKernelImpl(const OperatorHandle &opHandle,
   }

   // Invoke the original kernel.
-  redispatch(opHandle, stack);
+  redispatchCallback();

   // Map returns to results.
   size_t returnCount = schema.returns().size();
The new kernel is registered as a BackendSelect override for aten::arange:

@@ -642,6 +662,21 @@ MlirValue AcapController::importTensorByValue(at::Tensor tensor) {
   return constArrayValue;
 }

+TORCH_LIBRARY_IMPL(aten, BackendSelect, m) {
+  // PyTorch logs a warning when kernels are overridden, which is unavoidable
+  // for factory-function BackendSelect kernels (there is not yet a "safe"
+  // override mechanism). So, just silence it. Any of them here are coded
+  // to be a superset of the default functionality, and there are only a few.
+  auto orig_log_level = FLAGS_caffe2_log_level;
+  FLAGS_caffe2_log_level = c10::GLOG_ERROR;
+
+  // Disable capture of arange: causes it to memorize the resulting tensor.
+  m.impl("arange", &AcapController::arangeBackendSelectKernel);
+
+  // Restore log level.
+  FLAGS_caffe2_log_level = orig_log_level;
+}
+
 TORCH_LIBRARY_IMPL(_, ACAP_DISPATCH_KEY, m) {
   m.fallback(torch::CppFunction::makeFromBoxedFunction<
              &AcapController::fallbackKernel>());
AcapController header (.h):

@@ -71,6 +71,13 @@ public:
   static at::Tensor &copyUnderKernel(at::Tensor &self, const at::Tensor &src,
                                      bool non_blocking);

+  // Backend select kernel for arange factory function.
+  static at::Tensor
+  arangeBackendSelectKernel(at::Scalar end, c10::optional<at::ScalarType> dtype,
+                            c10::optional<at::Layout> layout,
+                            c10::optional<at::Device> device,
+                            c10::optional<bool> pin_memory);
+
 private:
   /// Builds a kernel call step by step.
   class KernelCallBuilder {
@@ -97,7 +104,8 @@ private:
   MlirLocation getCurrentLocation();
   void redispatch(const c10::OperatorHandle &opHandle, c10::Stack *stack);
   void fallbackKernelImpl(const c10::OperatorHandle &opHandle,
-                          c10::Stack *stack);
+                          c10::Stack *stack,
+                          std::function<void()> redispatchCallback);
   MlirValue mapIValueToMlirValue(MlirLocation loc, const c10::IValue &ival);
   MlirType mapIValueToMlirType(MlirLocation loc, const c10::IValue &ival);
   /// Imports a tensor by value (as a constant), remembering the association.
New FileCheck test (Python) exercising arange under capture:

@@ -0,0 +1,19 @@
+# -*- Python -*-
+# This file is licensed under a pytorch-style license
+# See frontends/pytorch/LICENSE for license information.
+
+import torch
+import torch_mlir
+
+# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
+torch_mlir.debug_trace_to_stderr()
+
+mb = torch_mlir.ModuleBuilder()
+with mb.capture_function("arange_test", []) as f:
+  x = torch.arange(10)
+  f.returns([x])
+
+# CHECK: %[[CST:.*]] = constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xi64>
+# CHECK: %[[R:.*]] = numpy.create_array_from_tensor %[[CST]]
+# CHECK: return %[[R]]
+mb.module.operation.print()
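With the override installed, capturing torch.arange(10) no longer throws
"Cannot return a tensor that is not from the capture context": the result
tensor is imported by value, and the CHECK lines verify the imported form, a
dense i64 constant wrapped by numpy.create_array_from_tensor.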