Add e2e native_group_norm test cases

Signed-off-by: rahul shrivastava <rahul.shrivastava@cerebras.net>
pull/2120/head
rahul shrivastava 2023-05-09 23:49:27 -07:00 committed by rahuls-cerebras
parent 40a2c501a1
commit 86429d9656
2 changed files with 49 additions and 0 deletions

View File

@ -7,6 +7,8 @@
# These represent further work needed in torch-mlir to lower them properly
# to the backend contract.
COMMON_TORCH_MLIR_LOWERING_XFAILS = {
    # Keep entries sorted alphabetically.
    "NativeGroupNormBackwardModule_basic",
    "NativeGroupNormModule_basic",
    "QuantizedMLP_basic",
    "ReduceMaxAlongDimUnsignedInt_basic",
}

View File

@ -217,6 +217,53 @@ def NativeBatchNormNoneWeightModule_basic(module, tu: TestUtils):
# ==============================================================================
class NativeGroupNormModule(torch.nn.Module):
    """E2e test module for the aten.native_group_norm op.

    Applies group normalization to a [2, 6, 2, 2] float32 input with
    per-channel affine weight and bias of shape [6].
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([2, 6, 2, 2], torch.float32, True),
        ([6], torch.float32, True),
        ([6], torch.float32, True),
    ])
    def forward(self, x, weight, bias):
        # aten.native_group_norm(input, weight, bias, N, C, HxW, group, eps):
        # N=2, C=6, HxW=2*2=4, group=3 (i.e. 2 channels per group), eps=1e-6.
        return torch.ops.aten.native_group_norm(
            x, weight, bias,
            2, 6, 4, 3, 0.000001)
@register_test_case(module_factory=lambda: NativeGroupNormModule())
def NativeGroupNormModule_basic(module, tu: TestUtils):
    # Drive the module with random inputs matching the annotated shapes.
    x = tu.rand(2, 6, 2, 2)
    weight = tu.rand(6)
    bias = tu.rand(6)
    module.forward(x, weight, bias)
class NativeGroupNormBackwardModule(torch.nn.Module):
    """E2e test module for the aten.native_group_norm_backward op.

    Computes the backward pass of group normalization for a [2, 6, 2, 2]
    input split into 3 groups; mean/rstd have shape [batch, group] = [2, 3].
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([2, 6, 2, 2], torch.float32, True),
        ([2, 6, 2, 2], torch.float32, True),
        ([2, 3], torch.float32, True),
        ([2, 3], torch.float32, True),
        ([6], torch.float32, True),
    ])
    def forward(self, grad_out, x, mean, rstd, weight):
        # aten.native_group_norm_backward(grad_out, input, mean, rstd, weight,
        #                                 N, C, HxW, group, output_mask):
        # N=2, C=6, HxW=2*2=4, group=3; output_mask requests gradients for
        # input, weight, and bias.
        return torch.ops.aten.native_group_norm_backward(
            grad_out, x, mean, rstd, weight,
            2, 6, 4, 3, [True, True, True])
@register_test_case(module_factory=lambda: NativeGroupNormBackwardModule())
def NativeGroupNormBackwardModule_basic(module, tu: TestUtils):
    # Drive the module with random inputs matching the annotated shapes.
    grad_out = tu.rand(2, 6, 2, 2)
    x = tu.rand(2, 6, 2, 2)
    mean = tu.rand(2, 3)
    rstd = tu.rand(2, 3)
    weight = tu.rand(6)
    module.forward(grad_out, x, mean, rstd, weight)
# ==============================================================================
class NativeLayerNormModule(torch.nn.Module):
def __init__(self):
super().__init__()