* [tosa] Support for Rsqrt legalization (#480)

Signed-off-by: Anup Gangwar <anup.gangwar@arm.com>

Co-authored-by: Anup Gangwar <anup.gangwar@arm.com>
pull/482/head snapshot-20211214.144
Anup Gangwar 2021-12-14 12:03:58 -06:00 committed by GitHub
parent 6dabf185f5
commit cce490d71d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 15 additions and 0 deletions

View File

@ -43,4 +43,5 @@ TOSA_PASS_SET = {
"BoolTensorReturnFalseModule_basic", "BoolTensorReturnFalseModule_basic",
"BoolTensorReturnTrueModule_basic", "BoolTensorReturnTrueModule_basic",
"BoolTensorReturnMixedModule_basic", "BoolTensorReturnMixedModule_basic",
"ElementwiseRsqrtModule_basic",
} }

View File

@ -442,6 +442,7 @@ public:
patterns.add<ConvertAtenUnaryOp<AtenOp, TosaOp>>(typeConverter, context); patterns.add<ConvertAtenUnaryOp<AtenOp, TosaOp>>(typeConverter, context);
INSERT_UNARY_PATTERN(AtenNegOp, tosa::NegateOp) INSERT_UNARY_PATTERN(AtenNegOp, tosa::NegateOp)
INSERT_UNARY_PATTERN(AtenFloorOp, tosa::FloorOp) INSERT_UNARY_PATTERN(AtenFloorOp, tosa::FloorOp)
INSERT_UNARY_PATTERN(AtenRsqrtOp, tosa::RsqrtOp)
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp) INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp)
#undef INSERT_UNARY_PATTERN #undef INSERT_UNARY_PATTERN

View File

@ -285,3 +285,16 @@ func @test_reduce_any$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtens
return %0 : !torch.vtensor<[1],i1> return %0 : !torch.vtensor<[1],i1>
} }
// -----
// CHECK-LABEL: func @torch.aten.rsqrt$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = "tosa.rsqrt"(%[[VAL_1]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.rsqrt %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}