Fix bug with transpose of negative dims

Summary:
This commit fixes an off-by-one error in how negative dimensions were
being handled in the lowering of transpose. It also adds tests for
transpose and unsqueeze that exercise negative dimensions.
pull/376/head
Ramiro Leal-Cavazos 2021-10-25 18:52:51 +00:00 committed by Yi Zhang
parent a23d77100b
commit 8bfb819d35
3 changed files with 47 additions and 6 deletions
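
For context on the off-by-one: a valid transpose dim lies in [-rank, rank),
so a negative dim is normalized by adding rank. The old code added rank + 1,
which is the normalization rule for unsqueeze (whose valid range is
[-rank - 1, rank]). A minimal plain-Python sketch of the difference,
illustrative only rather than the actual lowering code:

def normalize_transpose_dim_buggy(dim: int, rank: int) -> int:
    # Old behavior: off by one, shifts every negative dim up by one.
    if dim < 0:
        dim += rank + 1
    return dim

def normalize_transpose_dim_fixed(dim: int, rank: int) -> int:
    # Fixed behavior: -1 maps to the last dim, rank - 1.
    if dim < 0:
        dim += rank
    return dim

rank = 3
assert normalize_transpose_dim_fixed(-1, rank) == 2  # last dim, valid
assert normalize_transpose_dim_buggy(-1, rank) == 3  # outside [0, 3)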


@@ -222,6 +222,24 @@ def TransposeIntModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4, 2))
 
 
+class TransposeIntNegDimsModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([3, 4, 2], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.transpose(x, -1, -2)
+
+
+@register_test_case(module_factory=lambda: TransposeIntNegDimsModule())
+def TransposeIntNegDimsModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 2))
+
+
 class TensorsConcatModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
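
As a sanity check on what the new test exercises (an editorial side note,
not part of the commit): on a rank-3 tensor, transpose with dims -1 and -2
must match its positive-dim counterpart.

import torch

x = torch.rand(3, 4, 2)
# -1 normalizes to 2 and -2 to 1, the same swap as transpose(x, 1, 2).
assert torch.equal(torch.transpose(x, -1, -2), torch.transpose(x, 1, 2))
assert torch.transpose(x, -1, -2).shape == torch.Size([3, 2, 4])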


@@ -132,6 +132,31 @@ def ElementwiseUnsqueezeBroadcastModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseUnsqueezeNegDimsModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        # As mentioned in the `unsqueeze` docstring,
+        # valid dim values are [-input.dim() - 1, input.dim() + 1).
+        # This tests the lower bound.
+        return torch.unsqueeze(a, -3)
+
+
+@register_test_case(
+    module_factory=lambda: ElementwiseUnsqueezeNegDimsModule())
+def ElementwiseUnsqueezeNegDimsModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 3))
+
+
+# ==============================================================================
+
+
 class ElementwiseFlattenBroadcastModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
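
The comment in the new test refers to the documented unsqueeze dim range
[-input.dim() - 1, input.dim() + 1). A quick illustration of why -3 is the
lower bound for a 2-D input (again an editorial side note, not part of the
commit):

import torch

a = torch.rand(4, 3)  # a.dim() == 2, so valid dims are [-3, 3)
# The lower bound -3 normalizes to -3 + 2 + 1 == 0, a new leading axis.
assert torch.unsqueeze(a, -3).shape == torch.Size([1, 4, 3])
assert torch.equal(torch.unsqueeze(a, -3), torch.unsqueeze(a, 0))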


@@ -1890,13 +1890,11 @@ public:
                        .cast<RankedTensorType>();
     auto elementType = inType.getElementType();
 
-    if (dim0 < 0)
-      dim0 += inputRank + 1;
-    if (dim0 < 0 || dim0 >= inputRank)
+    dim0 = toPositiveDim(dim0, inputRank);
+    if (!isValidDim(dim0, inputRank))
       return rewriter.notifyMatchFailure(op, "dim0 out of range");
-    if (dim1 < 0)
-      dim1 += inputRank + 1;
-    if (dim1 < 0 || dim1 >= inputRank)
+    dim1 = toPositiveDim(dim1, inputRank);
+    if (!isValidDim(dim1, inputRank))
       return rewriter.notifyMatchFailure(op, "dim1 out of range");
 
     auto loc = op.getLoc();
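
The fix replaces the hand-rolled normalization with the shared
toPositiveDim / isValidDim helpers. As used here, they appear to behave
like the following Python model (a sketch of the intended semantics under
that assumption, not the actual C++ helpers):

def to_positive_dim(dim: int, input_rank: int) -> int:
    # Negative dims count from the back: -1 is the last axis.
    return dim + input_rank if dim < 0 else dim

def is_valid_dim(dim: int, input_rank: int) -> bool:
    # After normalization, a dim must index an existing axis.
    return 0 <= dim < input_rank

# For a rank-3 tensor: -1 -> 2 and -2 -> 1 (both valid); -4 stays invalid.
assert is_valid_dim(to_positive_dim(-1, 3), 3)
assert not is_valid_dim(to_positive_dim(-4, 3), 3)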