From 30a42dea32c8306fa63a652ffe690f8342d5c323 Mon Sep 17 00:00:00 2001
From: Bairen Yi
Date: Mon, 15 Mar 2021 11:06:28 +0800
Subject: [PATCH] Also fallback autograd dispatch keys for torchvision::nms

Signed-off-by: Bairen Yi
---
 .../pytorch/csrc/builder/acap_dispatch.cpp    | 11 ++++++--
 .../test/acap_export/test_export_nms.py       | 25 +++++++++++++++++++
 2 files changed, 34 insertions(+), 2 deletions(-)
 create mode 100644 frontends/pytorch/test/acap_export/test_export_nms.py

diff --git a/frontends/pytorch/csrc/builder/acap_dispatch.cpp b/frontends/pytorch/csrc/builder/acap_dispatch.cpp
index ed8fb4485..a9ac1f2f0 100644
--- a/frontends/pytorch/csrc/builder/acap_dispatch.cpp
+++ b/frontends/pytorch/csrc/builder/acap_dispatch.cpp
@@ -168,7 +168,8 @@ void AcapController::fallbackKernel(const OperatorHandle &opHandle,
                                     Stack *stack) {
   auto redispatchCallback = [&]() {
     // Exclude recursive dispatch to this kernel.
-    c10::impl::ExcludeDispatchKeyGuard exclusion(kAcapDispatchKey);
+    c10::DispatchKeySet keySet{kAcapDispatchKey, kAcapGradDispatchKey};
+    c10::impl::ExcludeDispatchKeyGuard exclusion(keySet);
     // Passthrough.
     auto &dispatcher = c10::Dispatcher::singleton();
     dispatcher.callBoxed(opHandle, stack);
@@ -403,7 +404,8 @@ void AcapController::fallbackKernelImpl(
   }
 
   // Exclude recursive dispatch to this kernel.
-  c10::impl::ExcludeDispatchKeyGuard exclusion(kAcapDispatchKey);
+  c10::DispatchKeySet keySet{kAcapDispatchKey, kAcapGradDispatchKey};
+  c10::impl::ExcludeDispatchKeyGuard exclusion(keySet);
 
   const FunctionSchema &schema = opHandle.schema();
 
@@ -609,3 +611,8 @@ TORCH_LIBRARY_IMPL(aten, ACAP_GRAD_DISPATCH_KEY, m) {
   // appropriate implementation behind the scenes.
   m.impl("mkldnn_convolution_backward", AcapController::mklConvolutionBackward);
 }
+
+TORCH_LIBRARY_IMPL(_, ACAP_GRAD_DISPATCH_KEY, m) {
+  m.fallback(torch::CppFunction::makeFromBoxedFunction<
+             &AcapController::fallbackKernel>());
+}
diff --git a/frontends/pytorch/test/acap_export/test_export_nms.py b/frontends/pytorch/test/acap_export/test_export_nms.py
new file mode 100644
index 000000000..bda13aee0
--- /dev/null
+++ b/frontends/pytorch/test/acap_export/test_export_nms.py
@@ -0,0 +1,25 @@
+# -*- Python -*-
+# This file is licensed under a pytorch-style license
+# See frontends/pytorch/LICENSE for license information.
+
+import torch
+import torchvision
+import torch_mlir
+
+# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
+
+mb = torch_mlir.ModuleBuilder()
+
+boxes = torch.rand(50, 4)
+scores = torch.rand(50)
+
+with mb.capture_function("nms", [boxes, scores]) as f:
+  result = torch.ops.torchvision.nms(boxes, scores, 0.5)
+  f.returns([result])
+
+# CHECK-LABEL: func @nms(%arg0: !numpy.ndarray<[50,4]:f32>, %arg1: !numpy.ndarray<[50]:f32>) -> !numpy.ndarray<[50]:i64> {
+# CHECK:   %[[VAL_0:.*]] = constant 5.000000e-01 : f64
+# CHECK:   %[[VAL_1:.*]] = torch.kernel_call "torchvision::nms" %arg0, %arg1, %[[VAL_0]] : (!numpy.ndarray<[50,4]:f32>, !numpy.ndarray<[50]:f32>, f64) -> !numpy.ndarray<[50]:i64> {sigArgTypes = ["Tensor", "Tensor", "float"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
+# CHECK:   return %[[VAL_1]] : !numpy.ndarray<[50]:i64>
+# CHECK: }
+mb.module.operation.print(large_elements_limit=2)