diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 93fe9dc1c..9ee8ad895 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -2485,10 +2485,9 @@ LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite(
 
   // Not a ranked tensor type
   auto selfType = adaptor.getSelf().getType().dyn_cast<RankedTensorType>();
-  if (!selfType || !selfType.hasStaticShape())
-    return rewriter.notifyMatchFailure(
-        op,
-        "Only ranked tensor types with static shapes are currently supported");
+  if (!selfType)
+    return rewriter.notifyMatchFailure(op,
+                                       "Only ranked tensor types supported");
 
   int64_t selfRank = selfType.getRank();
 
@@ -2520,8 +2519,11 @@ LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite(
     } else {
       if (idx == start_dim)
         newShape.push_back(s.value());
-      else
+      // Only update the running product when both sizes are static
+      else if (s.value() != kUnknownSize && newShape.back() != kUnknownSize)
         newShape.back() *= s.value();
+      else
+        newShape.back() = kUnknownSize;
     }
   }
 
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 6c6666a28..82d86e739 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -885,6 +885,9 @@ TOSA_PASS_SET = {
     "ArangeStartNegativeStepFloatModule_basic",
     "ArangeStartOutDtypeModule_basic",
     "ArangeStartStepFloatModule_basic",
+    "ArgmaxIntModule_basic",
+    "ArgmaxIntModule_multiple_maxs",
+    "ArgmaxModule_basic",
     "ArgmaxModule_keepDim",
     "ArgmaxModule_with_dim",
     "AtenComplex64Module_basic",
@@ -1077,6 +1080,7 @@ TOSA_PASS_SET = {
     "EmbeddingModuleI32Static_basic",
     "FlattenRank0Module_basic",
     "FlattenStaticModule_basic",
+    "FlattenDynamicModuleCollapseAll_basic",
     "FullLikeModuleFloat3DStatic_basic",
     "FullLikeModuleInt2DStatic_basic",
     "FullModuleDefaultDtype_basic",
@@ -1292,6 +1296,7 @@ MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
 }) - {
     ### Test failing in make_fx_tosa but not in tosa
+    "FlattenDynamicModuleCollapseAll_basic",
     # Dynamic shape, has extra unsupported broadcast ops
     "Matmul_3d",
     "MatmulStaticBroadcast_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
index c5ef92d41..fba52e2e7 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -391,6 +391,25 @@ class FlattenDynamicModule(torch.nn.Module):
 def FlattenDynamicModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(10, 3, 8, 9, 3, 4))
 
+
+class FlattenDynamicModuleCollapseAll(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.flat = torch.nn.Flatten(0)
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, 9, 3, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.flat(x)
+
+
+@register_test_case(module_factory=lambda: FlattenDynamicModuleCollapseAll())
+def FlattenDynamicModuleCollapseAll_basic(module, tu: TestUtils):
+    module.forward(tu.rand(10, 3, 8, 9, 3, 4))
+
 
 # ==============================================================================
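
Note on the TorchToTosa change: the lowering now folds the sizes of the dimensions in [start_dim, end_dim] into a single output dimension, and any dynamic input dimension makes the folded dimension dynamic. Below is a minimal standalone sketch of that logic, not part of the patch: the helper name flattenShape is hypothetical, and it assumes kUnknownSize == -1, the Torch dialect's convention for dynamic sizes.

#include <cstdint>
#include <vector>

constexpr int64_t kUnknownSize = -1;

// Collapse shape[startDim..endDim] into one dimension, mirroring the
// loop in the lowering above.
std::vector<int64_t> flattenShape(const std::vector<int64_t> &shape,
                                  int64_t startDim, int64_t endDim) {
  std::vector<int64_t> newShape;
  for (int64_t idx = 0; idx < static_cast<int64_t>(shape.size()); ++idx) {
    int64_t s = shape[idx];
    if (idx < startDim || idx > endDim) {
      // Outside the flattened range: keep the dimension unchanged.
      newShape.push_back(s);
    } else if (idx == startDim) {
      // First dimension of the range seeds the running product.
      newShape.push_back(s);
    } else if (s != kUnknownSize && newShape.back() != kUnknownSize) {
      // Both sizes static: accumulate the product.
      newShape.back() *= s;
    } else {
      // Any dynamic size makes the collapsed dimension dynamic.
      newShape.back() = kUnknownSize;
    }
  }
  return newShape;
}

For the shape annotated in FlattenDynamicModuleCollapseAll, flattenShape({-1, -1, -1, 9, 3, -1}, 0, 5) returns {-1}: the static 9 and 3 are absorbed into the running product, and the dynamic dimensions force the single collapsed dimension to be dynamic, which is what lets this test pass through the TOSA backend.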