mirror of https://github.com/llvm/torch-mlir
Add e2e native_group_norm test-cases
Signed-off-by: rahul shrivastava <rahul.shrivastava@cerebras.net> (branch: pull/2120/head)
parent
40a2c501a1
commit
86429d9656
|
@ -7,6 +7,8 @@
|
|||
# These represent further work needed in torch-mlir to lower them properly
# to the backend contract.
COMMON_TORCH_MLIR_LOWERING_XFAILS = {
    "NativeGroupNormModule_basic",
    "NativeGroupNormBackwardModule_basic",
    "QuantizedMLP_basic",
    "ReduceMaxAlongDimUnsignedInt_basic",
}
|
|
|
@ -217,6 +217,53 @@ def NativeBatchNormNoneWeightModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================


class NativeGroupNormModule(torch.nn.Module):
    # E2E test module for aten.native_group_norm: a (2, 6, 2, 2) float input
    # normalized over 3 groups, with per-channel weight and bias of size 6.

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([2, 6, 2, 2], torch.float32, True),
        ([6], torch.float32, True),
        ([6], torch.float32, True),
    ])
    def forward(self, x, weight, bias):
        # Trailing scalar args are: N, C, HxW, group, eps — they must agree
        # with the annotated input shape (N=2, C=6, H*W=4, 3 groups).
        return torch.ops.aten.native_group_norm(
            x, weight, bias,
            2, 6, 4, 3, 0.000001)
|
||||
@register_test_case(module_factory=lambda: NativeGroupNormModule())
def NativeGroupNormModule_basic(module, tu: TestUtils):
    # Random inputs matching the module's annotated shapes:
    # input (2, 6, 2, 2), weight (6,), bias (6,).
    module.forward(tu.rand(2, 6, 2, 2), tu.rand(6), tu.rand(6))
|
||||
|
||||
class NativeGroupNormBackwardModule(torch.nn.Module):
    # E2E test module for aten.native_group_norm_backward: gradient of
    # group norm w.r.t. input, weight, and bias. mean/rstd are per
    # (batch, group), hence shape (2, 3).

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([2, 6, 2, 2], torch.float32, True),
        ([2, 6, 2, 2], torch.float32, True),
        ([2, 3], torch.float32, True),
        ([2, 3], torch.float32, True),
        ([6], torch.float32, True),
    ])
    def forward(self, grad_out, x, mean, rstd, weight):
        # Trailing args are: N, C, HxW, group, output_mask — the mask
        # requests all three gradients (input, weight, bias).
        return torch.ops.aten.native_group_norm_backward(
            grad_out, x, mean, rstd, weight,
            2, 6, 4, 3, [True, True, True])
|
||||
@register_test_case(module_factory=lambda: NativeGroupNormBackwardModule())
def NativeGroupNormBackwardModule_basic(module, tu: TestUtils):
    # Random inputs matching the module's annotated shapes:
    # grad_out (2, 6, 2, 2), input (2, 6, 2, 2), mean (2, 3),
    # rstd (2, 3), weight (6,).
    module.forward(tu.rand(2, 6, 2, 2), tu.rand(2, 6, 2, 2), tu.rand(2, 3),
                   tu.rand(2, 3), tu.rand(6))


# ==============================================================================
||||
|
||||
class NativeLayerNormModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
Loading…
Reference in New Issue