mirror of https://github.com/llvm/torch-mlir
[Torch] Fix PrimListUnpackOp::getCanonicalizationPatterns (#3140)
Fix the case PrimListUnpackOp's result num is not equal to PrimList length. See the following example: ```python def forward(self, x): if len(x.shape) == 5: b0, t, c0, h0, w0 = x.shape b, c, h, w = torch.mul(b0, t), c0, h0, w0 else: b1, c1, h1, w1 = x.shape b, c, h, w = b1, c1, h1, w1 res = torch.reshape(x, [b, c, h, w]) return res ``` Without this fix, the following error message will occur: ``` /root/torch-mlir/externals/llvm-project/mlir/lib/IR/PatternMatch.cpp:118: virtual void mlir::RewriterBase::replaceOp(mlir::Operation *, mlir::ValueRange): Assertion `op->getNumResults() == newValues.size() && "incorrect # of replacement values"' failed. ```pull/3141/head
parent
6524838bcb
commit
308c45e61a
|
@ -3088,6 +3088,9 @@ void PrimListUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
if (!listConstruct)
|
||||
return failure();
|
||||
|
||||
if (op->getNumResults() != listConstruct.getElements().size())
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOp(op, listConstruct.getElements());
|
||||
return success();
|
||||
});
|
||||
|
|
|
@ -657,6 +657,7 @@ STABLEHLO_PASS_SET = {
|
|||
"PermuteModule_basic",
|
||||
"PermuteNegativeIndexModule_basic",
|
||||
"PowIntFloatModule_basic",
|
||||
"PrimListUnpackNumMismatchModule_basic",
|
||||
"PrimMaxIntModule_basic",
|
||||
"PrimMinIntDynamicModule_basic",
|
||||
"PrimMinIntModule_basic",
|
||||
|
@ -1216,6 +1217,7 @@ TOSA_PASS_SET = {
|
|||
"Permute0RankModule_basic",
|
||||
"PermuteModule_basic",
|
||||
"PermuteNegativeIndexModule_basic",
|
||||
"PrimListUnpackNumMismatchModule_basic",
|
||||
"PrimsSqueezeEmptyDimensionsModule_basic",
|
||||
"PrimsSqueezeModule_basic",
|
||||
"PrimsViewOfModule_basic",
|
||||
|
@ -1391,6 +1393,7 @@ MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
|
|||
"ElementwisePreluStaticModule_basic",
|
||||
|
||||
# Shape Related failures
|
||||
"PrimListUnpackNumMismatchModule_basic",
|
||||
"ReshapeExpandModule_basic",
|
||||
"UnsafeViewCollapseModule_basic",
|
||||
"UnsafeViewDynamicExpandModule_basic",
|
||||
|
|
|
@ -674,6 +674,33 @@ def SliceCopyNonZeroDim_Module_basic(module, tu: TestUtils):
|
|||
module.forward(tu.rand(10, 4, 4), tu.rand(10, 2, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
class PrimListUnpackNumMismatchModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([5, 4, 3, 2, 1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
if len(x.shape) == 5:
|
||||
b0, t, c0, h0, w0 = x.shape
|
||||
b, c, h, w = torch.mul(b0, t), c0, h0, w0
|
||||
else:
|
||||
b1, c1, h1, w1 = x.shape
|
||||
b, c, h, w = b1, c1, h1, w1
|
||||
res = torch.reshape(x, [b, c, h, w])
|
||||
return res
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: PrimListUnpackNumMismatchModule())
|
||||
def PrimListUnpackNumMismatchModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 4, 3, 2, 1))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue