[Stablehlo] Add stablehlo support for aten.abs (#2068)

Co-authored-by: AmosLewis <Amos_Lewsi@foxmail.com>
Chi_Liu 2023-05-08 22:13:00 -07:00 committed by GitHub
parent c7a24c4d21
commit 51e0a2c933
3 changed files with 15 additions and 0 deletions

@@ -340,6 +340,7 @@ STABLEHLO_PASS_SET = {
     "ElementwiseSubScalarFloatModule_basic",
     "ElementwiseSubScalarIntModule_basic",
     "ElementwiseWhereScalarModule_basic",
+    "ElementwiseAbsModule_basic",
     "EmbeddingModule1DIndices_basic",
     "EmbeddingModuleI32Static_basic",
     "EmbeddingModuleI32_basic",

@@ -1451,6 +1451,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
   INSERT_UNARY_PATTERN(AtenNegOp, stablehlo::NegOp);
   INSERT_UNARY_PATTERN(AtenLogicalNotOp, stablehlo::NotOp);
   INSERT_UNARY_PATTERN(AtenBitwiseNotOp, stablehlo::NotOp);
+  INSERT_UNARY_PATTERN(AtenAbsOp, stablehlo::AbsOp);
 #undef INSERT_UNARY_PATTERN
 #define INSERT_UNARY_FPONLY_PATTERN(AtenOp, StablehloOp) \
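The hunk shows only the macro invocation; the pattern it instantiates is not part of this diff. As a rough sketch (names such as ConvertAtenUnaryOp and adaptor.getSelf() are assumptions inferred from the surrounding code, not confirmed by the hunk), the generic unary lowering is a one-to-one rewrite:

```cpp
// Sketch of the generic unary conversion the macro presumably wires up.
// Names here are illustrative; only INSERT_UNARY_PATTERN and the op pairs
// appear in the actual diff.
#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/dialect/StablehloOps.h"

namespace {
template <typename AtenOpT, typename StablehloOpT>
class ConvertAtenUnaryOp : public mlir::OpConversionPattern<AtenOpT> {
public:
  using mlir::OpConversionPattern<AtenOpT>::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;

  mlir::LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    // Convert the torch.vtensor result type to a builtin tensor type, then
    // replace the ATen op with its StableHLO counterpart on the converted
    // operand (e.g. torch.aten.abs -> stablehlo.abs).
    auto resultTy = this->getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<StablehloOpT>(op, resultTy,
                                              adaptor.getSelf());
    return mlir::success();
  }
};
} // namespace
```

With a pattern like this registered and AtenAbsOp marked illegal, the conversion framework fails loudly if any torch.aten.abs survives the lowering.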

@@ -624,3 +624,16 @@ func.func @torch.aten.div.Tensor_mode$floor(%arg0: !torch.vtensor<[?,?,?,?],f32>
   return %0 : !torch.vtensor<[?,?,?,?],f32>
 }
+
+// -----
+// CHECK-LABEL: func.func @torch.aten.abs(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[15,15],si64>) -> !torch.vtensor<[15,15],si64> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[15,15],si64> -> tensor<15x15xi64>
+// CHECK: %[[VAL_2:.*]] = stablehlo.abs %[[VAL_1]] : tensor<15x15xi64>
+// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<15x15xi64> -> !torch.vtensor<[15,15],si64>
+// CHECK: return %[[VAL_3]] : !torch.vtensor<[15,15],si64>
+// CHECK: }
+func.func @torch.aten.abs(%arg0: !torch.vtensor<[15,15],si64>) -> !torch.vtensor<[15,15],si64> {
+  %0 = torch.aten.abs %arg0 : !torch.vtensor<[15,15],si64> -> !torch.vtensor<[15,15],si64>
+  return %0 : !torch.vtensor<[15,15],si64>
+}
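For reference, lit tests in this file are typically driven by a RUN line along the lines of `torch-mlir-opt -convert-torch-to-stablehlo` piped into `FileCheck`; the CHECK lines above then verify that torch.aten.abs lowers to a single stablehlo.abs, with torch_c.to_builtin_tensor / torch_c.from_builtin_tensor handling the type boundary. The si64 element type is a deliberate choice to pin down: because the op is registered via INSERT_UNARY_PATTERN rather than INSERT_UNARY_FPONLY_PATTERN, integer abs must lower directly through stablehlo.abs.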