diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index acc28a510..bab4ad600 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -364,6 +364,7 @@ TOSA_PASS_SET = { "ArgmaxModule_keepDim", "ArgmaxModule_with_dim", "_LogSoftmaxModuleStable_basic", + "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic", "BroadcastToIdentityCaseStaticModule_basic", } diff --git a/python/torch_mlir_e2e_test/test_suite/reduction.py b/python/torch_mlir_e2e_test/test_suite/reduction.py index 90b18e6a2..b28d78a12 100644 --- a/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -125,6 +125,25 @@ def ReduceSumDimIntListKeepDimFloatModule_basic(module, tu: TestUtils): # ============================================================================== +class ReduceSumDimIntListKeepDimNegativeDimStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 12, 7, 7], torch.float32, True), + ]) + def forward(self, a): + return torch.sum(a, dim=(-1), keepdim=True) + + +@register_test_case(module_factory=lambda: ReduceSumDimIntListKeepDimNegativeDimStaticModule()) +def ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 12, 7, 7)) + +# ============================================================================== + class ReduceSumDimIntListEmptyDimModule(torch.nn.Module): def __init__(self): super().__init__()