test: add end-to-end test for aten.neg (#760)

pull/518/head snapshot-20220415.391
Ashay Rane 2022-04-15 12:37:57 -07:00 committed by GitHub
parent a893c7d5cf
commit d3c08376af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 0 deletions

View File

@ -160,4 +160,5 @@ TOSA_PASS_SET = {
"ElementwiseNeIntScalarModule_basic", "ElementwiseNeIntScalarModule_basic",
"ElementwiseNeFloatTensorModule_basic", "ElementwiseNeFloatTensorModule_basic",
"ConvolutionModule2DStatic_basic", "ConvolutionModule2DStatic_basic",
"ElementwiseNegModule_basic",
} }

View File

@ -1282,3 +1282,23 @@ class ElementwiseCosIntModule(torch.nn.Module):
@register_test_case(module_factory=lambda: ElementwiseCosIntModule()) @register_test_case(module_factory=lambda: ElementwiseCosIntModule())
def ElementwiseCosIntModule_basic(module, tu: TestUtils): def ElementwiseCosIntModule_basic(module, tu: TestUtils):
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32)) module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
# ==============================================================================
class ElementwiseNegModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.neg(a)
@register_test_case(module_factory=lambda: ElementwiseNegModule())
def ElementwiseNegModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))