[ONNX] Implement Softsign op (#3373)

pull/3334/head
RattataKing 2024-05-21 15:10:26 -04:00 committed by GitHub
parent c2c1c2cfa4
commit fcf48872b3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 37 additions and 0 deletions

View File

@ -2404,6 +2404,28 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
exp);
return success();
});
patterns.onOp("Softsign", 22,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value input;
if (binder.tensorOperand(input) ||
binder.tensorResultType(resultType)) {
return failure();
}
Value absX = rewriter.create<Torch::AtenAbsOp>(
binder.getLoc(), resultType, input);
Value constOne = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
Value absXPlusOne = rewriter.create<Torch::AtenAddScalarOp>(
binder.getLoc(), resultType, absX, constOne, constOne);
rewriter.replaceOpWithNewOp<Torch::AtenDivTensorOp>(
binder.op, resultType, input, absXPlusOne);
return success();
});
patterns.onOp(
"Trilu", 14, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;

View File

@ -580,6 +580,21 @@ func.func @test_softmax_negative_axis(%arg0: !torch.vtensor<[3,4,5],f32>) -> !to
// -----
// CHECK-LABEL: func.func @test_softsign
func.func @test_softsign(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[ABS:.+]] = torch.aten.abs %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[RES:.+]] = torch.aten.add.Scalar %[[ABS]], %[[INT1]], %[[INT1]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4,5],f32>
// CHECK: %[[SCALE_T:.*]] = torch.aten.div.Tensor %arg0, %[[RES]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
// CHECK: return %[[SCALE_T]] : !torch.vtensor<[3,4,5],f32>
%none = torch.constant.none
%0 = torch.operator "onnx.Softsign"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
return %0 : !torch.vtensor<[3,4,5],f32>
}
// -----
// CHECK-LABEL: func.func @test_selu
func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 6 : si64} {
// CHECK-DAG: %[[F1:.+]] = torch.constant.float 1