mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add a test for sum.dim_IntList op working for tosa (#1387)
Signed-Off By: Vivek Khandelwal<vivek@nod-labs.com> Co-authored-by: Suraj Sudhir <16977902+sjarus@users.noreply.github.com>pull/1395/head
parent
1ffd42bbde
commit
5090ac9359
|
@ -364,6 +364,7 @@ TOSA_PASS_SET = {
|
|||
"ArgmaxModule_keepDim",
|
||||
"ArgmaxModule_with_dim",
|
||||
"_LogSoftmaxModuleStable_basic",
|
||||
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
|
||||
"BroadcastToIdentityCaseStaticModule_basic",
|
||||
}
|
||||
|
||||
|
|
|
@ -125,6 +125,25 @@ def ReduceSumDimIntListKeepDimFloatModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class ReduceSumDimIntListKeepDimNegativeDimStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([1, 12, 7, 7], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.sum(a, dim=(-1), keepdim=True)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceSumDimIntListKeepDimNegativeDimStaticModule())
|
||||
def ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 12, 7, 7))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ReduceSumDimIntListEmptyDimModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
Loading…
Reference in New Issue