mirror of https://github.com/llvm/torch-mlir
[linalg] Add handling for leadin and trailing size-1 dims in ViewOp
This commit adds to the lowering of `aten.view` handling for the following cases: - `(..., a.size(i))` -> `(..., a.size(i), 1, ..., 1)` - `(..., a.size(i), 1, ..., 1)` -> `(..., a.size(i))` - `(a.size(i), ...)` -> `(1, ..., 1, a.size(i), ...)` - `(1, ..., 1, a.size(i), ...)` -> `(a.size(i), ...)`pull/2502/head snapshot-20231004.981
parent
1c508af0ba
commit
2e5d65064c
|
@ -193,6 +193,9 @@ public:
|
||||||
ArrayRef<int64_t> yDims,
|
ArrayRef<int64_t> yDims,
|
||||||
SmallVector<int64_t> &xIndices,
|
SmallVector<int64_t> &xIndices,
|
||||||
SmallVector<int64_t> &yIndices) {
|
SmallVector<int64_t> &yIndices) {
|
||||||
|
if (xDims.empty() || yDims.empty())
|
||||||
|
return failure();
|
||||||
|
|
||||||
auto isValidReduction = [](int64_t expectedReductionProduct,
|
auto isValidReduction = [](int64_t expectedReductionProduct,
|
||||||
ArrayRef<int64_t> arrayToReduce) -> bool {
|
ArrayRef<int64_t> arrayToReduce) -> bool {
|
||||||
if (llvm::count(arrayToReduce, kUnknownSize) > 0 ||
|
if (llvm::count(arrayToReduce, kUnknownSize) > 0 ||
|
||||||
|
@ -262,6 +265,8 @@ public:
|
||||||
// all the dimensions in `outputShape`.
|
// all the dimensions in `outputShape`.
|
||||||
static void calculateSingleDynamicSize(MutableArrayRef<int64_t> inputShape,
|
static void calculateSingleDynamicSize(MutableArrayRef<int64_t> inputShape,
|
||||||
MutableArrayRef<int64_t> outputShape) {
|
MutableArrayRef<int64_t> outputShape) {
|
||||||
|
if (inputShape.empty() || outputShape.empty())
|
||||||
|
return;
|
||||||
int64_t inputDynamicDimCount = llvm::count(inputShape, kUnknownSize);
|
int64_t inputDynamicDimCount = llvm::count(inputShape, kUnknownSize);
|
||||||
int64_t outputDynamicDimCount = llvm::count(outputShape, kUnknownSize);
|
int64_t outputDynamicDimCount = llvm::count(outputShape, kUnknownSize);
|
||||||
if (inputDynamicDimCount + outputDynamicDimCount != 1)
|
if (inputDynamicDimCount + outputDynamicDimCount != 1)
|
||||||
|
@ -488,12 +493,29 @@ public:
|
||||||
outputDim = outputAssociations.back().back() + 1;
|
outputDim = outputAssociations.back().back() + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Append the associations for the dims matching `aten.size.int`
|
// Handle any leading or trailing size-1 dimensions and append the
|
||||||
if (nextUnchangedInput != inputRank &&
|
// associations for the dims matching `aten.size.int`.
|
||||||
nextUnchangedOutput != resultRank) {
|
if (nextUnchangedInput != inputRank) {
|
||||||
|
assert(nextUnchangedOutput != resultRank &&
|
||||||
|
"`nextUnchangedInput` and `nextUnchangedOutput` should equal "
|
||||||
|
"the respective input and output rank at the same time");
|
||||||
inputAssociations.emplace_back();
|
inputAssociations.emplace_back();
|
||||||
outputAssociations.emplace_back();
|
outputAssociations.emplace_back();
|
||||||
|
}
|
||||||
|
while (inputDim <= nextUnchangedInput && inputDim < inputRank) {
|
||||||
|
if (inputDim != nextUnchangedInput && inputShape[inputDim] != 1) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: only collapsing of static size-1 into "
|
||||||
|
"unchanged dim supported");
|
||||||
|
}
|
||||||
inputAssociations.back().push_back(inputDim++);
|
inputAssociations.back().push_back(inputDim++);
|
||||||
|
}
|
||||||
|
while (outputDim <= nextUnchangedOutput && outputDim < resultRank) {
|
||||||
|
if (outputDim != nextUnchangedOutput && outputShape[outputDim] != 1) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: only expanding of static size-1 out of "
|
||||||
|
"unchanged dim supported");
|
||||||
|
}
|
||||||
outputAssociations.back().push_back(outputDim++);
|
outputAssociations.back().push_back(outputDim++);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -672,6 +672,108 @@ class ViewNegativeStaticModule(torch.nn.Module):
|
||||||
def ViewNegativeStaticModule_basic(module, tu: TestUtils):
|
def ViewNegativeStaticModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1, 128))
|
module.forward(tu.rand(1, 128))
|
||||||
|
|
||||||
|
class ViewSizeDimFollowedByExpandedOnesModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1], torch.float32, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, a):
|
||||||
|
return a.view(a.size(0), 1, 1, 1)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ViewSizeDimFollowedByExpandedOnesModule())
|
||||||
|
def ViewSizeDimFollowedByExpandedOnesModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(128))
|
||||||
|
|
||||||
|
class ViewSizeDimFollowedByCollapsedOnesModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, 1, 1, 1], torch.float32, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, a):
|
||||||
|
return a.view(a.size(0))
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ViewSizeDimFollowedByCollapsedOnesModule())
|
||||||
|
def ViewSizeDimFollowedByCollapsedOnesModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(128, 1, 1, 1))
|
||||||
|
|
||||||
|
class ViewSizeDimLedByExpandedOnesModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1], torch.float32, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, a):
|
||||||
|
return a.view(1, 1, 1, a.size(0))
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ViewSizeDimLedByExpandedOnesModule())
|
||||||
|
def ViewSizeDimLedByExpandedOnesModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(128))
|
||||||
|
|
||||||
|
class ViewSizeDimLedByCollapsedOnesModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([1, 1, 1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, a):
|
||||||
|
return a.view(a.size(3))
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ViewSizeDimLedByCollapsedOnesModule())
|
||||||
|
def ViewSizeDimLedByCollapsedOnesModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(1, 1, 1, 128))
|
||||||
|
|
||||||
|
class ViewSizeDimLedAndFollowedByExpandedOnesModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1], torch.float32, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, a):
|
||||||
|
return a.view(1, 1, 1, a.size(0), 1, 1, 1)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ViewSizeDimLedAndFollowedByExpandedOnesModule())
|
||||||
|
def ViewSizeDimLedAndFollowedByExpandedOnesModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(128))
|
||||||
|
|
||||||
|
class ViewSizeDimLedAndFollowedByCollapsedOnesModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([1, 1, 1, -1, 1, 1, 1], torch.float32, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, a):
|
||||||
|
return a.view(a.size(3))
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ViewSizeDimLedAndFollowedByCollapsedOnesModule())
|
||||||
|
def ViewSizeDimLedAndFollowedByCollapsedOnesModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(1, 1, 1, 128, 1, 1, 1))
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
class ReshapeAliasExpandModule(torch.nn.Module):
|
class ReshapeAliasExpandModule(torch.nn.Module):
|
||||||
|
@ -710,4 +812,4 @@ class ReshapeAliasCollapseModule(torch.nn.Module):
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: ReshapeAliasCollapseModule())
|
@register_test_case(module_factory=lambda: ReshapeAliasCollapseModule())
|
||||||
def ReshapeAliasCollapseModule_basic(module, tu: TestUtils):
|
def ReshapeAliasCollapseModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(2, 4))
|
module.forward(tu.rand(2, 4))
|
||||||
|
|
Loading…
Reference in New Issue