[MLIR][TORCH] Add a test for sum.dim_IntList op working for tosa (#1387)

Signed-Off By: Vivek Khandelwal<vivek@nod-labs.com>

Co-authored-by: Suraj Sudhir <16977902+sjarus@users.noreply.github.com>
pull/1395/head
Vivek Khandelwal 2022-09-21 00:08:09 +05:30 committed by GitHub
parent 1ffd42bbde
commit 5090ac9359
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 0 deletions

View File

@ -364,6 +364,7 @@ TOSA_PASS_SET = {
"ArgmaxModule_keepDim", "ArgmaxModule_keepDim",
"ArgmaxModule_with_dim", "ArgmaxModule_with_dim",
"_LogSoftmaxModuleStable_basic", "_LogSoftmaxModuleStable_basic",
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
"BroadcastToIdentityCaseStaticModule_basic", "BroadcastToIdentityCaseStaticModule_basic",
} }

View File

@ -125,6 +125,25 @@ def ReduceSumDimIntListKeepDimFloatModule_basic(module, tu: TestUtils):
# ============================================================================== # ==============================================================================
class ReduceSumDimIntListKeepDimNegativeDimStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([1, 12, 7, 7], torch.float32, True),
])
def forward(self, a):
return torch.sum(a, dim=(-1), keepdim=True)
@register_test_case(module_factory=lambda: ReduceSumDimIntListKeepDimNegativeDimStaticModule())
def ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 12, 7, 7))
# ==============================================================================
class ReduceSumDimIntListEmptyDimModule(torch.nn.Module): class ReduceSumDimIntListEmptyDimModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()