diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 2974882c4..7a61f311f 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -340,6 +340,7 @@ STABLEHLO_PASS_SET = { "ElementwiseSubScalarFloatModule_basic", "ElementwiseSubScalarIntModule_basic", "ElementwiseWhereScalarModule_basic", + "ElementwiseAbsModule_basic", "EmbeddingModule1DIndices_basic", "EmbeddingModuleI32Static_basic", "EmbeddingModuleI32_basic", diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 929ba7323..e7c264c55 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -1451,6 +1451,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_UNARY_PATTERN(AtenNegOp, stablehlo::NegOp); INSERT_UNARY_PATTERN(AtenLogicalNotOp, stablehlo::NotOp); INSERT_UNARY_PATTERN(AtenBitwiseNotOp, stablehlo::NotOp); + INSERT_UNARY_PATTERN(AtenAbsOp, stablehlo::AbsOp); #undef INSERT_UNARY_PATTERN #define INSERT_UNARY_FPONLY_PATTERN(AtenOp, StablehloOp) \ diff --git a/test/Conversion/TorchToStablehlo/elementwise.mlir b/test/Conversion/TorchToStablehlo/elementwise.mlir index b1d560e4f..367985233 100644 --- a/test/Conversion/TorchToStablehlo/elementwise.mlir +++ b/test/Conversion/TorchToStablehlo/elementwise.mlir @@ -624,3 +624,16 @@ func.func @torch.aten.div.Tensor_mode$floor(%arg0: !torch.vtensor<[?,?,?,?],f32> return %0 : !torch.vtensor<[?,?,?,?],f32> } +// ----- + +// CHECK-LABEL: func.func @torch.aten.abs( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[15,15],si64>) -> !torch.vtensor<[15,15],si64> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[15,15],si64> -> tensor<15x15xi64> +// CHECK: %[[VAL_2:.*]] = stablehlo.abs %[[VAL_1]] : tensor<15x15xi64> +// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<15x15xi64> -> !torch.vtensor<[15,15],si64> +// CHECK: return %[[VAL_3]] : !torch.vtensor<[15,15],si64> +// CHECK: } +func.func @torch.aten.abs(%arg0: !torch.vtensor<[15,15],si64>) -> !torch.vtensor<[15,15],si64>{ + %0 = torch.aten.abs %arg0 : !torch.vtensor<[15,15],si64> -> !torch.vtensor<[15,15],si64> + return %0 : !torch.vtensor<[15,15],si64> +} \ No newline at end of file