Also apply the fallback kernel to autograd dispatch keys for torchvision::nms

Signed-off-by: Bairen Yi <yibairen.byron@bytedance.com>
pull/193/head
Bairen Yi 2021-03-15 11:06:28 +08:00 committed by Sean Silva
parent e7b96ebefc
commit 30a42dea32
2 changed files with 34 additions and 2 deletions

View File

@ -168,7 +168,8 @@ void AcapController::fallbackKernel(const OperatorHandle &opHandle,
Stack *stack) {
auto redispatchCallback = [&]() {
// Exclude recursive dispatch to this kernel.
c10::impl::ExcludeDispatchKeyGuard exclusion(kAcapDispatchKey);
c10::DispatchKeySet keySet{kAcapDispatchKey, kAcapGradDispatchKey};
c10::impl::ExcludeDispatchKeyGuard exclusion(keySet);
// Passthrough.
auto &dispatcher = c10::Dispatcher::singleton();
dispatcher.callBoxed(opHandle, stack);
@ -403,7 +404,8 @@ void AcapController::fallbackKernelImpl(
}
// Exclude recursive dispatch to this kernel.
c10::impl::ExcludeDispatchKeyGuard exclusion(kAcapDispatchKey);
c10::DispatchKeySet keySet{kAcapDispatchKey, kAcapGradDispatchKey};
c10::impl::ExcludeDispatchKeyGuard exclusion(keySet);
const FunctionSchema &schema = opHandle.schema();
@ -609,3 +611,8 @@ TORCH_LIBRARY_IMPL(aten, ACAP_GRAD_DISPATCH_KEY, m) {
// appropriate implementation behind the scenes.
m.impl("mkldnn_convolution_backward", AcapController::mklConvolutionBackward);
}
// Registers AcapController::fallbackKernel as the boxed fallback kernel
// under ACAP_GRAD_DISPATCH_KEY (library namespace `_`).
// NOTE(review): presumably this is what lets ops that have no explicit
// ACAP_GRAD_DISPATCH_KEY impl (e.g. torchvision::nms) reach the capture
// path -- confirm against the dispatcher's fallback semantics.
TORCH_LIBRARY_IMPL(_, ACAP_GRAD_DISPATCH_KEY, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<
&AcapController::fallbackKernel>());
}

View File

@ -0,0 +1,25 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import torch
import torchvision
import torch_mlir
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s

# Captures a single call to torchvision's nms op and prints the resulting
# module so that the expected-output patterns below can be matched against it.
mb = torch_mlir.ModuleBuilder()

# Inputs: a (50, 4) tensor of boxes and a (50,) tensor of scores.
# NOTE(review): the values are random and unseeded, while the expected result
# type below is fixed at [50]; that presumes nms keeps every box -- confirm
# this test cannot flake.
boxes = torch.rand(50, 4)
scores = torch.rand(50)

# Ops executed inside this context are recorded into a function named "nms"
# that takes `boxes` and `scores` as its arguments.
with mb.capture_function("nms", [boxes, scores]) as f:
    # Third argument is the IoU threshold of torchvision's nms signature.
    result = torch.ops.torchvision.nms(boxes, scores, 0.5)
    f.returns([result])

# CHECK-LABEL: func @nms(%arg0: !numpy.ndarray<[50,4]:f32>, %arg1: !numpy.ndarray<[50]:f32>) -> !numpy.ndarray<[50]:i64> {
# CHECK: %[[VAL_0:.*]] = constant 5.000000e-01 : f64
# CHECK: %[[VAL_1:.*]] = torch.kernel_call "torchvision::nms" %arg0, %arg1, %[[VAL_0]] : (!numpy.ndarray<[50,4]:f32>, !numpy.ndarray<[50]:f32>, f64) -> !numpy.ndarray<[50]:i64> {sigArgTypes = ["Tensor", "Tensor", "float"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
# CHECK: return %[[VAL_1]] : !numpy.ndarray<[50]:i64>
# CHECK: }
mb.module.operation.print(large_elements_limit=2)