mirror of https://github.com/llvm/torch-mlir
* [tosa] Support for Rsqrt legalization (#480)
Signed-off-by: Anup Gangwar <anup.gangwar@arm.com> Co-authored-by: Anup Gangwar <anup.gangwar@arm.com>pull/482/head snapshot-20211214.144
parent
6dabf185f5
commit
cce490d71d
|
@ -43,4 +43,5 @@ TOSA_PASS_SET = {
|
||||||
"BoolTensorReturnFalseModule_basic",
|
"BoolTensorReturnFalseModule_basic",
|
||||||
"BoolTensorReturnTrueModule_basic",
|
"BoolTensorReturnTrueModule_basic",
|
||||||
"BoolTensorReturnMixedModule_basic",
|
"BoolTensorReturnMixedModule_basic",
|
||||||
|
"ElementwiseRsqrtModule_basic",
|
||||||
}
|
}
|
||||||
|
|
|
@ -442,6 +442,7 @@ public:
|
||||||
patterns.add<ConvertAtenUnaryOp<AtenOp, TosaOp>>(typeConverter, context);
|
patterns.add<ConvertAtenUnaryOp<AtenOp, TosaOp>>(typeConverter, context);
|
||||||
INSERT_UNARY_PATTERN(AtenNegOp, tosa::NegateOp)
|
INSERT_UNARY_PATTERN(AtenNegOp, tosa::NegateOp)
|
||||||
INSERT_UNARY_PATTERN(AtenFloorOp, tosa::FloorOp)
|
INSERT_UNARY_PATTERN(AtenFloorOp, tosa::FloorOp)
|
||||||
|
INSERT_UNARY_PATTERN(AtenRsqrtOp, tosa::RsqrtOp)
|
||||||
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp)
|
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp)
|
||||||
#undef INSERT_UNARY_PATTERN
|
#undef INSERT_UNARY_PATTERN
|
||||||
|
|
||||||
|
|
|
@ -285,3 +285,16 @@ func @test_reduce_any$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtens
|
||||||
return %0 : !torch.vtensor<[1],i1>
|
return %0 : !torch.vtensor<[1],i1>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @torch.aten.rsqrt$basic(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_2:.*]] = "tosa.rsqrt"(%[[VAL_1]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
|
%0 = torch.aten.rsqrt %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||||
|
return %0 : !torch.vtensor<[?,?],f32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue