mirror of https://github.com/llvm/torch-mlir
Also fallback autograd dispatch keys for torchvision::nms
Signed-off-by: Bairen Yi <yibairen.byron@bytedance.com>
pull/193/head
parent
e7b96ebefc
commit
30a42dea32
|
@ -168,7 +168,8 @@ void AcapController::fallbackKernel(const OperatorHandle &opHandle,
|
|||
Stack *stack) {
|
||||
auto redispatchCallback = [&]() {
|
||||
// Exclude recursive dispatch to this kernel.
|
||||
c10::impl::ExcludeDispatchKeyGuard exclusion(kAcapDispatchKey);
|
||||
c10::DispatchKeySet keySet{kAcapDispatchKey, kAcapGradDispatchKey};
|
||||
c10::impl::ExcludeDispatchKeyGuard exclusion(keySet);
|
||||
// Passthrough.
|
||||
auto &dispatcher = c10::Dispatcher::singleton();
|
||||
dispatcher.callBoxed(opHandle, stack);
|
||||
|
@ -403,7 +404,8 @@ void AcapController::fallbackKernelImpl(
|
|||
}
|
||||
|
||||
// Exclude recursive dispatch to this kernel.
|
||||
c10::impl::ExcludeDispatchKeyGuard exclusion(kAcapDispatchKey);
|
||||
c10::DispatchKeySet keySet{kAcapDispatchKey, kAcapGradDispatchKey};
|
||||
c10::impl::ExcludeDispatchKeyGuard exclusion(keySet);
|
||||
|
||||
const FunctionSchema &schema = opHandle.schema();
|
||||
|
||||
|
@ -609,3 +611,8 @@ TORCH_LIBRARY_IMPL(aten, ACAP_GRAD_DISPATCH_KEY, m) {
|
|||
// appropriate implementation behind the scenes.
|
||||
m.impl("mkldnn_convolution_backward", AcapController::mklConvolutionBackward);
|
||||
}
|
||||
|
||||
// Registers the boxed AcapController::fallbackKernel as the fallback kernel
// for ACAP_GRAD_DISPATCH_KEY. The `_` namespace argument is the wildcard form
// of TORCH_LIBRARY_IMPL, so the registration is not limited to `aten`.
// NOTE(review): per the commit title this is what lets non-aten ops such as
// torchvision::nms reach the fallback under this key — confirm against the
// definition of ACAP_GRAD_DISPATCH_KEY.
TORCH_LIBRARY_IMPL(_, ACAP_GRAD_DISPATCH_KEY, m) {
|
||||
m.fallback(torch::CppFunction::makeFromBoxedFunction<
|
||||
&AcapController::fallbackKernel>());
|
||||
}
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
import torch_mlir
|
||||
|
||||
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||
|
||||
# Builder that accumulates the captured function(s) into a single module.
mb = torch_mlir.ModuleBuilder()

# Random inputs: 50 rows of 4 values for the boxes, and one score per row.
boxes = torch.rand(50, 4)
scores = torch.rand(50)

# Everything executed inside the `with` body is recorded into the function
# named "nms"; the two tensors listed here become its arguments.
with mb.capture_function("nms", [boxes, scores]) as f:
  # The 0.5 literal is expected to appear as an f64 constant operand of the
  # torchvision::nms kernel call in the printed IR (see directives below).
  result = torch.ops.torchvision.nms(boxes, scores, 0.5)
  f.returns([result])

# CHECK-LABEL: func @nms(%arg0: !numpy.ndarray<[50,4]:f32>, %arg1: !numpy.ndarray<[50]:f32>) -> !numpy.ndarray<[50]:i64> {
# CHECK: %[[VAL_0:.*]] = constant 5.000000e-01 : f64
# CHECK: %[[VAL_1:.*]] = torch.kernel_call "torchvision::nms" %arg0, %arg1, %[[VAL_0]] : (!numpy.ndarray<[50,4]:f32>, !numpy.ndarray<[50]:f32>, f64) -> !numpy.ndarray<[50]:i64> {sigArgTypes = ["Tensor", "Tensor", "float"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
# CHECK: return %[[VAL_1]] : !numpy.ndarray<[50]:i64>
# CHECK: }
# Emit the captured module so the lit pipeline can verify it.
mb.module.operation.print(large_elements_limit=2)
|
Loading…
Reference in New Issue