diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 406727370..a37d4dddc 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -454,6 +454,7 @@ STABLEHLO_PASS_SET = { "ElementwiseFlattenBroadcastModule_basic", "ElementwiseLeakyReluModule_basic", "ElementwiseEluModule_basic", + "ElementwiseEluNonDefaultModule_basic", "ElementwiseLogModule_basic", "ElementwiseNegModule_basic", "ElementwiseRsqrtModule_basic", @@ -908,6 +909,7 @@ TOSA_PASS_SET = { "ElementwiseReluModule_basic", "ElementwiseLeakyReluModule_basic", "ElementwiseEluModule_basic", + "ElementwiseEluNonDefaultModule_basic", "ElementwiseFloorModule_basic", "ElementwiseLogModule_basic", "ElementwiseBinaryStaticShapeModule_basic", diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 09a64976a..2e2db90d3 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -426,6 +426,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py index c75e7d671..40bb9975a 100644 --- a/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -476,6 +476,27 @@ def ElementwiseLeakyReluStaticModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseEluNonDefaultModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.elu(x, scale=1.5, alpha=2.0, input_scale=3.0) + +@register_test_case(module_factory=lambda: ElementwiseEluNonDefaultModule()) +def ElementwiseEluNonDefaultModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5,3, low=-1, high=1)) + + +# ============================================================================== + + class ElementwiseEluModule(torch.nn.Module): def __init__(self):