diff --git a/e2e_testing/torchscript/basic.py b/e2e_testing/torchscript/basic.py index 588d10f49..fa327d0ed 100644 --- a/e2e_testing/torchscript/basic.py +++ b/e2e_testing/torchscript/basic.py @@ -222,6 +222,24 @@ def TransposeIntModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 2)) +class TransposeIntNegDimsModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 4, 2], torch.float32, True), + ]) + def forward(self, x): + return torch.transpose(x, -1, -2) + + +@register_test_case(module_factory=lambda: TransposeIntNegDimsModule()) +def TransposeIntNegDimsModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 2)) + + class TensorsConcatModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/e2e_testing/torchscript/elementwise.py b/e2e_testing/torchscript/elementwise.py index 2bd03e45b..e98aa1099 100644 --- a/e2e_testing/torchscript/elementwise.py +++ b/e2e_testing/torchscript/elementwise.py @@ -132,6 +132,31 @@ def ElementwiseUnsqueezeBroadcastModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseUnsqueezeNegDimsModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + # As mentioned in `unsqueeze` docstring, + # valid dim values are [-input.dim()-1, input.dim()+1). + # This tests the lower bound + return torch.unsqueeze(a, -3) + + +@register_test_case( + module_factory=lambda: ElementwiseUnsqueezeNegDimsModule()) +def ElementwiseUnsqueezeNegDimsModule_basic(module, tu: TestUtils): + module.forward(tu.rand(4, 3)) + + +# ============================================================================== + + class ElementwiseFlattenBroadcastModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp index c60075cc6..121f41b2c 100644 --- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp +++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp @@ -1890,13 +1890,11 @@ public: .cast(); auto elementType = inType.getElementType(); - if (dim0 < 0) - dim0 += inputRank + 1; - if (dim0 < 0 || dim0 >= inputRank) + dim0 = toPositiveDim(dim0, inputRank); + if (!isValidDim(dim0, inputRank)) return rewriter.notifyMatchFailure(op, "dim0 out of range"); - if (dim1 < 0) - dim1 += inputRank + 1; - if (dim1 < 0 || dim1 >= inputRank) + dim1 = toPositiveDim(dim1, inputRank); + if (!isValidDim(dim1, inputRank)) return rewriter.notifyMatchFailure(op, "dim1 out of range"); auto loc = op.getLoc();