[Torch] Fix PrimListUnpackOp::getCanonicalizationPatterns (#3140)

Fix the case PrimListUnpackOp's result num is not equal to PrimList
length.
See the following example:
```python
    def forward(self, x):
        if len(x.shape) == 5:
            b0, t, c0, h0, w0 = x.shape
            b, c, h, w = torch.mul(b0, t), c0, h0, w0
        else:
            b1, c1, h1, w1 = x.shape
            b, c, h, w = b1, c1, h1, w1
        res = torch.reshape(x, [b, c, h, w])
        return res
```
Without this fix, the following error message will occur:
```
/root/torch-mlir/externals/llvm-project/mlir/lib/IR/PatternMatch.cpp:118: virtual void mlir::RewriterBase::replaceOp(mlir::Operation *, mlir::ValueRange): Assertion `op->getNumResults() == newValues.size() && "incorrect # of replacement values"' failed.
```
pull/3141/head
Xinyu Yang 2024-04-11 19:48:49 +08:00 committed by GitHub
parent 6524838bcb
commit 308c45e61a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 33 additions and 0 deletions

View File

@ -3088,6 +3088,9 @@ void PrimListUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
if (!listConstruct)
return failure();
if (op->getNumResults() != listConstruct.getElements().size())
return failure();
rewriter.replaceOp(op, listConstruct.getElements());
return success();
});

View File

@ -657,6 +657,7 @@ STABLEHLO_PASS_SET = {
"PermuteModule_basic",
"PermuteNegativeIndexModule_basic",
"PowIntFloatModule_basic",
"PrimListUnpackNumMismatchModule_basic",
"PrimMaxIntModule_basic",
"PrimMinIntDynamicModule_basic",
"PrimMinIntModule_basic",
@ -1216,6 +1217,7 @@ TOSA_PASS_SET = {
"Permute0RankModule_basic",
"PermuteModule_basic",
"PermuteNegativeIndexModule_basic",
"PrimListUnpackNumMismatchModule_basic",
"PrimsSqueezeEmptyDimensionsModule_basic",
"PrimsSqueezeModule_basic",
"PrimsViewOfModule_basic",
@ -1391,6 +1393,7 @@ MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
"ElementwisePreluStaticModule_basic",
# Shape Related failures
"PrimListUnpackNumMismatchModule_basic",
"ReshapeExpandModule_basic",
"UnsafeViewCollapseModule_basic",
"UnsafeViewDynamicExpandModule_basic",

View File

@ -674,6 +674,33 @@ def SliceCopyNonZeroDim_Module_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 4, 4), tu.rand(10, 2, 4))
# ==============================================================================
class PrimListUnpackNumMismatchModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([5, 4, 3, 2, 1], torch.float32, True),
])
def forward(self, x):
if len(x.shape) == 5:
b0, t, c0, h0, w0 = x.shape
b, c, h, w = torch.mul(b0, t), c0, h0, w0
else:
b1, c1, h1, w1 = x.shape
b, c, h, w = b1, c1, h1, w1
res = torch.reshape(x, [b, c, h, w])
return res
@register_test_case(module_factory=lambda: PrimListUnpackNumMismatchModule())
def PrimListUnpackNumMismatchModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 4, 3, 2, 1))
# ==============================================================================