mirror of https://github.com/llvm/torch-mlir
Fix bug with transpose of negative dims
Summary: This commit fixes an off-by-one error in how negative dimensiosn were being handled in the lowering of transpose. This commit also adds tests to transpose and unsqueeze to test negative dimensions.pull/376/head
parent
a23d77100b
commit
8bfb819d35
|
@ -222,6 +222,24 @@ def TransposeIntModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 4, 2))
|
module.forward(tu.rand(3, 4, 2))
|
||||||
|
|
||||||
|
|
||||||
|
class TransposeIntNegDimsModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([3, 4, 2], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.transpose(x, -1, -2)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: TransposeIntNegDimsModule())
|
||||||
|
def TransposeIntNegDimsModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 4, 2))
|
||||||
|
|
||||||
|
|
||||||
class TensorsConcatModule(torch.nn.Module):
|
class TensorsConcatModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
|
@ -132,6 +132,31 @@ def ElementwiseUnsqueezeBroadcastModule_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseUnsqueezeNegDimsModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
# As mentioned in `unsqueeze` docstring,
|
||||||
|
# valid dim values are [-input.dim()-1, input.dim()+1).
|
||||||
|
# This tests the lower bound
|
||||||
|
return torch.unsqueeze(a, -3)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ElementwiseUnsqueezeNegDimsModule())
|
||||||
|
def ElementwiseUnsqueezeNegDimsModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(4, 3))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseFlattenBroadcastModule(torch.nn.Module):
|
class ElementwiseFlattenBroadcastModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
|
@ -1890,13 +1890,11 @@ public:
|
||||||
.cast<RankedTensorType>();
|
.cast<RankedTensorType>();
|
||||||
auto elementType = inType.getElementType();
|
auto elementType = inType.getElementType();
|
||||||
|
|
||||||
if (dim0 < 0)
|
dim0 = toPositiveDim(dim0, inputRank);
|
||||||
dim0 += inputRank + 1;
|
if (!isValidDim(dim0, inputRank))
|
||||||
if (dim0 < 0 || dim0 >= inputRank)
|
|
||||||
return rewriter.notifyMatchFailure(op, "dim0 out of range");
|
return rewriter.notifyMatchFailure(op, "dim0 out of range");
|
||||||
if (dim1 < 0)
|
dim1 = toPositiveDim(dim1, inputRank);
|
||||||
dim1 += inputRank + 1;
|
if (!isValidDim(dim1, inputRank))
|
||||||
if (dim1 < 0 || dim1 >= inputRank)
|
|
||||||
return rewriter.notifyMatchFailure(op, "dim1 out of range");
|
return rewriter.notifyMatchFailure(op, "dim1 out of range");
|
||||||
|
|
||||||
auto loc = op.getLoc();
|
auto loc = op.getLoc();
|
||||||
|
|
Loading…
Reference in New Issue