From 8bfb819d35cfe5857edc1d76706b3985166e82c6 Mon Sep 17 00:00:00 2001
From: Ramiro Leal-Cavazos <ramiroleal050@gmail.com>
Date: Mon, 25 Oct 2021 18:52:51 +0000
Subject: [PATCH] Fix bug with transpose of negative dims

Summary:
This commit fixes an off-by-one error in how negative dimensions were
being handled in the lowering of transpose. This commit also adds
tests to transpose and unsqueeze to test negative dimensions.
---
 e2e_testing/torchscript/basic.py       | 18 +++++++++++++
 e2e_testing/torchscript/elementwise.py | 25 +++++++++++++++++++
 .../TorchToLinalg/TorchToLinalg.cpp    | 10 +++-----
 3 files changed, 47 insertions(+), 6 deletions(-)

diff --git a/e2e_testing/torchscript/basic.py b/e2e_testing/torchscript/basic.py
index 588d10f49..fa327d0ed 100644
--- a/e2e_testing/torchscript/basic.py
+++ b/e2e_testing/torchscript/basic.py
@@ -222,6 +222,24 @@ def TransposeIntModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4, 2))
 
 
+class TransposeIntNegDimsModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([3, 4, 2], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.transpose(x, -1, -2)
+
+
+@register_test_case(module_factory=lambda: TransposeIntNegDimsModule())
+def TransposeIntNegDimsModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 2))
+
+
 class TensorsConcatModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/e2e_testing/torchscript/elementwise.py b/e2e_testing/torchscript/elementwise.py
index 2bd03e45b..e98aa1099 100644
--- a/e2e_testing/torchscript/elementwise.py
+++ b/e2e_testing/torchscript/elementwise.py
@@ -132,6 +132,31 @@ def ElementwiseUnsqueezeBroadcastModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseUnsqueezeNegDimsModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        # As mentioned in `unsqueeze` docstring,
+        # valid dim values are [-input.dim()-1, input.dim()+1).
+        # This tests the lower bound
+        return torch.unsqueeze(a, -3)
+
+
+@register_test_case(
+    module_factory=lambda: ElementwiseUnsqueezeNegDimsModule())
+def ElementwiseUnsqueezeNegDimsModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 3))
+
+
+# ==============================================================================
+
+
 class ElementwiseFlattenBroadcastModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index c60075cc6..121f41b2c 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -1890,13 +1890,11 @@ public:
                        .cast<RankedTensorType>();
     auto elementType = inType.getElementType();
 
-    if (dim0 < 0)
-      dim0 += inputRank + 1;
-    if (dim0 < 0 || dim0 >= inputRank)
+    dim0 = toPositiveDim(dim0, inputRank);
+    if (!isValidDim(dim0, inputRank))
       return rewriter.notifyMatchFailure(op, "dim0 out of range");
-    if (dim1 < 0)
-      dim1 += inputRank + 1;
-    if (dim1 < 0 || dim1 >= inputRank)
+    dim1 = toPositiveDim(dim1, inputRank);
+    if (!isValidDim(dim1, inputRank))
       return rewriter.notifyMatchFailure(op, "dim1 out of range");
 
     auto loc = op.getLoc();