[LINALG] Fix handling of size-1 dims in `aten.view` again. (#992)

A previous fix to the handling of size-1 dims in
`aten.view` (https://github.com/llvm/torch-mlir/pull/962) resulted in
the wrong grouping of dimensions when size-1 dims were between two
dims of size greater than 1. This commit fixes that.
Ramiro Leal-Cavazos 2022-06-30 18:39:25 -05:00 committed by GitHub
parent f947443f98
commit f204210266
3 changed files with 62 additions and 14 deletions
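
For context, the pattern this commit fixes can be reproduced directly in PyTorch. Below is a minimal sketch (the shapes are taken from the new ViewExpandOnesMiddleModule test in this diff; the variable names are illustrative):

import torch

# Size-1 dims sit between two dims of size greater than 1 (the 3 and
# the 2). Lowering this view to LINALG requires grouping the six
# expanded dims under the three original dims; the previous fix
# produced the wrong grouping for exactly this kind of shape.
a = torch.rand(3, 1, 2)
b = a.view(3, 1, 1, 1, 1, 2)
assert b.shape == (3, 1, 1, 1, 1, 2)  # same 6 elements, regrouped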

@@ -142,6 +142,8 @@ TOSA_PASS_SET = {
     "DropoutModule_basic",
     "ViewExpandModule_basic",
     "ViewExpandOnesModule_basic",
+    "ViewExpandOnesBeforeAndAfterModule_basic",
+    "ViewExpandOnesMiddleModule_basic",
     "ViewCollapseInferredDimModule_basic",
     "ViewExpandInferredDimModule_basic",
     "ViewNoChangeStaticModule_basic",

@@ -321,22 +321,30 @@ public:
         reassociation[collapsedDim].push_back(expandedDim++);
       } else {
         int64_t remainingSizeToExpand = collapsedShape[collapsedDim];
-        for (int64_t i = expandedDim; i < expandedDimNext; i++) {
-          int64_t expandedDimSize = expandedShape[i];
-          if (expandedDimSize == kUnknownSize) {
-            return rewriter.notifyMatchFailure(
-                op, "expected expanded dim sizes to be known");
-          }
-          if (remainingSizeToExpand % expandedDimSize != 0) {
-            if (expandedDimSize > remainingSizeToExpand &&
-                remainingSizeToExpand == 1)
-              break;
+        // A do-while loop is used here to handle the cases where the
+        // collapsed shape tensor has a dimension of size 1.
+        do {
+          int64_t expandedDimSize = expandedShape[expandedDim];
+          if (expandedDim >= expandedDimNext ||
+              expandedShape[expandedDim] == kUnknownSize ||
+              remainingSizeToExpand % expandedDimSize != 0) {
             return rewriter.notifyMatchFailure(
                 op, "total number of elements mismatch in the expansion");
           }
-          remainingSizeToExpand /= expandedDimSize;
           reassociation[collapsedDim].push_back(expandedDim++);
-        }
+          remainingSizeToExpand /= expandedDimSize;
+        } while (remainingSizeToExpand != 1);
+        // If all dims until `expandedDimNext` are of size 1, then group those
+        // with the reassociation for the current `collapsedDim`.
+        auto expandedShapeSlice =
+            llvm::makeArrayRef(expandedShape)
+                .slice(expandedDim, expandedDimNext - expandedDim);
+        if (llvm::all_of(expandedShapeSlice,
+                         [](int64_t val) { return val == 1; })) {
+          reassociation[collapsedDim].append(
+              llvm::to_vector(llvm::seq(expandedDim, expandedDimNext)));
+          expandedDim = expandedDimNext;
+        }
       }
     }
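
To make the grouping logic above easier to follow, here is a rough Python sketch of the same idea. It is not the commit's implementation: group_expanded_dims is an invented name, and the sketch omits the expandedDimNext boundary that the real lowering derives from shape analysis, consuming expanded dims greedily instead.

def group_expanded_dims(collapsed_shape, expanded_shape):
    reassociation = []
    expanded_dim = 0
    for collapsed_size in collapsed_shape:
        group = []
        remaining = collapsed_size
        # Mirrors the do-while: runs at least once, so a size-1
        # collapsed dim still claims one expanded dim.
        while True:
            if (expanded_dim >= len(expanded_shape)
                    or remaining % expanded_shape[expanded_dim] != 0):
                raise ValueError("total number of elements mismatch")
            group.append(expanded_dim)
            remaining //= expanded_shape[expanded_dim]
            expanded_dim += 1
            if remaining == 1:
                break
        reassociation.append(group)
    # Trailing size-1 expanded dims join the last group, loosely
    # mirroring the llvm::all_of branch above.
    while (expanded_dim < len(expanded_shape)
           and expanded_shape[expanded_dim] == 1):
        reassociation[-1].append(expanded_dim)
        expanded_dim += 1
    if expanded_dim != len(expanded_shape):
        raise ValueError("unconsumed expanded dims")
    return reassociation

print(group_expanded_dims([1], [1, 1, 1, 1, 1]))           # [[0, 1, 2, 3, 4]]
print(group_expanded_dims([1, 3], [1, 1, 3, 1, 1]))        # [[0], [1, 2, 3, 4]]
print(group_expanded_dims([3, 1, 2], [3, 1, 1, 1, 1, 2]))  # [[0], [1], [2, 3, 4, 5]]

The three calls correspond to the three e2e tests touched below; each returns a valid reassociation whose groups multiply back to the original dim sizes.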

@@ -33,6 +33,25 @@ class ViewExpandOnesModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
 
     @export
     @annotate_args([
         None,
+        ([1], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(1, 1, 1, 1, 1)
+
+@register_test_case(module_factory=lambda: ViewExpandOnesModule())
+def ViewExpandOnesModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1))
+
+# ==============================================================================
+
+class ViewExpandOnesBeforeAndAfterModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
@@ -42,12 +61,31 @@ class ViewExpandOnesModule(torch.nn.Module):
     def forward(self, a):
         return a.view(1, 1, 3, 1, 1)
 
-@register_test_case(module_factory=lambda: ViewExpandOnesModule())
-def ViewExpandOnesModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: ViewExpandOnesBeforeAndAfterModule())
+def ViewExpandOnesBeforeAndAfterModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(1, 3))
 
 # ==============================================================================
 
+class ViewExpandOnesMiddleModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([3, 1, 2], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(3, 1, 1, 1, 1, 2)
+
+@register_test_case(module_factory=lambda: ViewExpandOnesMiddleModule())
+def ViewExpandOnesMiddleModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 1, 2))
+
+# ==============================================================================
+
 class ViewDynamicExpandModule(torch.nn.Module):
     def __init__(self):
         super().__init__()