[linalg] Add handling for leadin and trailing size-1 dims in ViewOp

This commit adds to the lowering of `aten.view` handling for the
following cases:

- `(..., a.size(i))` -> `(..., a.size(i), 1, ..., 1)`
- `(..., a.size(i), 1, ..., 1)` -> `(..., a.size(i))`
- `(a.size(i), ...)` -> `(1, ..., 1, a.size(i), ...)`
- `(1, ..., 1, a.size(i), ...)` -> `(a.size(i), ...)`
pull/2502/head snapshot-20231004.981
Ramiro Leal-Cavazos 2023-10-03 19:24:01 +00:00
parent 1c508af0ba
commit 2e5d65064c
2 changed files with 128 additions and 4 deletions

View File

@ -193,6 +193,9 @@ public:
ArrayRef<int64_t> yDims, ArrayRef<int64_t> yDims,
SmallVector<int64_t> &xIndices, SmallVector<int64_t> &xIndices,
SmallVector<int64_t> &yIndices) { SmallVector<int64_t> &yIndices) {
if (xDims.empty() || yDims.empty())
return failure();
auto isValidReduction = [](int64_t expectedReductionProduct, auto isValidReduction = [](int64_t expectedReductionProduct,
ArrayRef<int64_t> arrayToReduce) -> bool { ArrayRef<int64_t> arrayToReduce) -> bool {
if (llvm::count(arrayToReduce, kUnknownSize) > 0 || if (llvm::count(arrayToReduce, kUnknownSize) > 0 ||
@ -262,6 +265,8 @@ public:
// all the dimensions in `outputShape`. // all the dimensions in `outputShape`.
static void calculateSingleDynamicSize(MutableArrayRef<int64_t> inputShape, static void calculateSingleDynamicSize(MutableArrayRef<int64_t> inputShape,
MutableArrayRef<int64_t> outputShape) { MutableArrayRef<int64_t> outputShape) {
if (inputShape.empty() || outputShape.empty())
return;
int64_t inputDynamicDimCount = llvm::count(inputShape, kUnknownSize); int64_t inputDynamicDimCount = llvm::count(inputShape, kUnknownSize);
int64_t outputDynamicDimCount = llvm::count(outputShape, kUnknownSize); int64_t outputDynamicDimCount = llvm::count(outputShape, kUnknownSize);
if (inputDynamicDimCount + outputDynamicDimCount != 1) if (inputDynamicDimCount + outputDynamicDimCount != 1)
@ -488,12 +493,29 @@ public:
outputDim = outputAssociations.back().back() + 1; outputDim = outputAssociations.back().back() + 1;
} }
// Append the associations for the dims matching `aten.size.int` // Handle any leading or trailing size-1 dimensions and append the
if (nextUnchangedInput != inputRank && // associations for the dims matching `aten.size.int`.
nextUnchangedOutput != resultRank) { if (nextUnchangedInput != inputRank) {
assert(nextUnchangedOutput != resultRank &&
"`nextUnchangedInput` and `nextUnchangedOutput` should equal "
"the respective input and output rank at the same time");
inputAssociations.emplace_back(); inputAssociations.emplace_back();
outputAssociations.emplace_back(); outputAssociations.emplace_back();
}
while (inputDim <= nextUnchangedInput && inputDim < inputRank) {
if (inputDim != nextUnchangedInput && inputShape[inputDim] != 1) {
return rewriter.notifyMatchFailure(
op, "unimplemented: only collapsing of static size-1 into "
"unchanged dim supported");
}
inputAssociations.back().push_back(inputDim++); inputAssociations.back().push_back(inputDim++);
}
while (outputDim <= nextUnchangedOutput && outputDim < resultRank) {
if (outputDim != nextUnchangedOutput && outputShape[outputDim] != 1) {
return rewriter.notifyMatchFailure(
op, "unimplemented: only expanding of static size-1 out of "
"unchanged dim supported");
}
outputAssociations.back().push_back(outputDim++); outputAssociations.back().push_back(outputDim++);
} }
} }

View File

@ -672,6 +672,108 @@ class ViewNegativeStaticModule(torch.nn.Module):
def ViewNegativeStaticModule_basic(module, tu: TestUtils): def ViewNegativeStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 128)) module.forward(tu.rand(1, 128))
class ViewSizeDimFollowedByExpandedOnesModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1], torch.float32, True),
])
def forward(self, a):
return a.view(a.size(0), 1, 1, 1)
@register_test_case(module_factory=lambda: ViewSizeDimFollowedByExpandedOnesModule())
def ViewSizeDimFollowedByExpandedOnesModule_basic(module, tu: TestUtils):
module.forward(tu.rand(128))
class ViewSizeDimFollowedByCollapsedOnesModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, 1, 1, 1], torch.float32, True),
])
def forward(self, a):
return a.view(a.size(0))
@register_test_case(module_factory=lambda: ViewSizeDimFollowedByCollapsedOnesModule())
def ViewSizeDimFollowedByCollapsedOnesModule_basic(module, tu: TestUtils):
module.forward(tu.rand(128, 1, 1, 1))
class ViewSizeDimLedByExpandedOnesModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1], torch.float32, True),
])
def forward(self, a):
return a.view(1, 1, 1, a.size(0))
@register_test_case(module_factory=lambda: ViewSizeDimLedByExpandedOnesModule())
def ViewSizeDimLedByExpandedOnesModule_basic(module, tu: TestUtils):
module.forward(tu.rand(128))
class ViewSizeDimLedByCollapsedOnesModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([1, 1, 1, -1], torch.float32, True),
])
def forward(self, a):
return a.view(a.size(3))
@register_test_case(module_factory=lambda: ViewSizeDimLedByCollapsedOnesModule())
def ViewSizeDimLedByCollapsedOnesModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 1, 1, 128))
class ViewSizeDimLedAndFollowedByExpandedOnesModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1], torch.float32, True),
])
def forward(self, a):
return a.view(1, 1, 1, a.size(0), 1, 1, 1)
@register_test_case(module_factory=lambda: ViewSizeDimLedAndFollowedByExpandedOnesModule())
def ViewSizeDimLedAndFollowedByExpandedOnesModule_basic(module, tu: TestUtils):
module.forward(tu.rand(128))
class ViewSizeDimLedAndFollowedByCollapsedOnesModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([1, 1, 1, -1, 1, 1, 1], torch.float32, True),
])
def forward(self, a):
return a.view(a.size(3))
@register_test_case(module_factory=lambda: ViewSizeDimLedAndFollowedByCollapsedOnesModule())
def ViewSizeDimLedAndFollowedByCollapsedOnesModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 1, 1, 128, 1, 1, 1))
# ============================================================================== # ==============================================================================
class ReshapeAliasExpandModule(torch.nn.Module): class ReshapeAliasExpandModule(torch.nn.Module):