mirror of https://github.com/llvm/torch-mlir
[LINALG] Fix handling of size-1 dims in `aten.view` again. (#992)
A previous fix to the handling of size-1 dims in `aten.view` (https://github.com/llvm/torch-mlir/pull/962) resulted in the wrong grouping of dimensions when size-1 dims where between two dims of size greater than 1. This commit fixes that.pull/1003/head
parent
f947443f98
commit
f204210266
|
@ -142,6 +142,8 @@ TOSA_PASS_SET = {
|
|||
"DropoutModule_basic",
|
||||
"ViewExpandModule_basic",
|
||||
"ViewExpandOnesModule_basic",
|
||||
"ViewExpandOnesBeforeAndAfterModule_basic",
|
||||
"ViewExpandOnesMiddleModule_basic",
|
||||
"ViewCollapseInferredDimModule_basic",
|
||||
"ViewExpandInferredDimModule_basic",
|
||||
"ViewNoChangeStaticModule_basic",
|
||||
|
|
|
@ -321,22 +321,30 @@ public:
|
|||
reassociation[collapsedDim].push_back(expandedDim++);
|
||||
} else {
|
||||
int64_t remainingSizeToExpand = collapsedShape[collapsedDim];
|
||||
for (int64_t i = expandedDim; i < expandedDimNext; i++) {
|
||||
int64_t expandedDimSize = expandedShape[i];
|
||||
if (expandedDimSize == kUnknownSize) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "expected expanded dim sizes to be known");
|
||||
}
|
||||
if (remainingSizeToExpand % expandedDimSize != 0) {
|
||||
if (expandedDimSize > remainingSizeToExpand &&
|
||||
remainingSizeToExpand == 1)
|
||||
break;
|
||||
// A do-while loop is used here to handle the cases where the
|
||||
// collapsed shape tensor has a dimension of size 1.
|
||||
do {
|
||||
int64_t expandedDimSize = expandedShape[expandedDim];
|
||||
if (expandedDim >= expandedDimNext ||
|
||||
expandedShape[expandedDim] == kUnknownSize ||
|
||||
remainingSizeToExpand % expandedDimSize != 0) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "total number of elements mismatch in the expansion");
|
||||
}
|
||||
|
||||
remainingSizeToExpand /= expandedDimSize;
|
||||
reassociation[collapsedDim].push_back(expandedDim++);
|
||||
remainingSizeToExpand /= expandedDimSize;
|
||||
} while (remainingSizeToExpand != 1);
|
||||
|
||||
// If all dims until `expandedDimNext` are of size 1, then group those
|
||||
// with the reassociation for the current `collapsedDim`.
|
||||
auto expandedShapeSlice =
|
||||
llvm::makeArrayRef(expandedShape)
|
||||
.slice(expandedDim, expandedDimNext - expandedDim);
|
||||
if (llvm::all_of(expandedShapeSlice,
|
||||
[](int64_t val) { return val == 1; })) {
|
||||
reassociation[collapsedDim].append(
|
||||
llvm::to_vector(llvm::seq(expandedDim, expandedDimNext)));
|
||||
expandedDim = expandedDimNext;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -33,6 +33,25 @@ class ViewExpandOnesModule(torch.nn.Module):
|
|||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([1], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(1, 1, 1, 1, 1)
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewExpandOnesModule())
|
||||
def ViewExpandOnesModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ViewExpandOnesBeforeAndAfterModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
|
@ -42,12 +61,31 @@ class ViewExpandOnesModule(torch.nn.Module):
|
|||
def forward(self, a):
|
||||
return a.view(1, 1, 3, 1, 1)
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewExpandOnesModule())
|
||||
def ViewExpandOnesModule_basic(module, tu: TestUtils):
|
||||
@register_test_case(module_factory=lambda: ViewExpandOnesBeforeAndAfterModule())
|
||||
def ViewExpandOnesBeforeAndAfterModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 3))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ViewExpandOnesMiddleModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([3, 1, 2], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return a.view(3, 1, 1, 1, 1, 2)
|
||||
|
||||
@register_test_case(module_factory=lambda: ViewExpandOnesMiddleModule())
|
||||
def ViewExpandOnesMiddleModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 1, 2))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ViewDynamicExpandModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
Loading…
Reference in New Issue