add nondefault test case, add to illegal ops in backend contract

pull/2422/head snapshot-20230828.944
Arham Khan 2023-08-25 09:42:29 -05:00 committed by Vivek Khandelwal
parent 8855fa3ace
commit bc6bba9077
3 changed files with 24 additions and 0 deletions

View File

@ -454,6 +454,7 @@ STABLEHLO_PASS_SET = {
"ElementwiseFlattenBroadcastModule_basic",
"ElementwiseLeakyReluModule_basic",
"ElementwiseEluModule_basic",
"ElementwiseEluNonDefaultModule_basic",
"ElementwiseLogModule_basic",
"ElementwiseNegModule_basic",
"ElementwiseRsqrtModule_basic",
@ -908,6 +909,7 @@ TOSA_PASS_SET = {
"ElementwiseReluModule_basic",
"ElementwiseLeakyReluModule_basic",
"ElementwiseEluModule_basic",
"ElementwiseEluNonDefaultModule_basic",
"ElementwiseFloorModule_basic",
"ElementwiseLogModule_basic",
"ElementwiseBinaryStaticShapeModule_basic",

View File

@ -426,6 +426,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenRandLikeOp>();
target.addIllegalOp<AtenHardsigmoidOp>();
target.addIllegalOp<AtenRelu6Op>();
target.addIllegalOp<AtenEluOp>();
target.addIllegalOp<AtenHardswishOp>();
target.addIllegalOp<AtenSoftplusOp>();
target.addIllegalOp<AtenSiluOp>();

View File

@ -476,6 +476,27 @@ def ElementwiseLeakyReluStaticModule_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseEluNonDefaultModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.elu(x, scale=1.5, alpha=2.0, input_scale=3.0)
@register_test_case(module_factory=lambda: ElementwiseEluNonDefaultModule())
def ElementwiseEluNonDefaultModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5,3, low=-1, high=1))
# ==============================================================================
class ElementwiseEluModule(torch.nn.Module):
def __init__(self):