diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index a540f0b0d..576534d18 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -2404,6 +2404,28 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
                       exp);
                   return success();
                 });
+  patterns.onOp("Softsign", 22,
+                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+                  Torch::ValueTensorType resultType;
+                  Value input;
+                  if (binder.tensorOperand(input) ||
+                      binder.tensorResultType(resultType)) {
+                    return failure();
+                  }
+
+                  Value absX = rewriter.create<Torch::AtenAbsOp>(
+                      binder.getLoc(), resultType, input);
+
+                  Value constOne = rewriter.create<Torch::ConstantIntOp>(
+                      binder.getLoc(), rewriter.getI64IntegerAttr(1));
+
+                  Value absXPlusOne = rewriter.create<Torch::AtenAddScalarOp>(
+                      binder.getLoc(), resultType, absX, constOne, constOne);
+
+                  rewriter.replaceOpWithNewOp<Torch::AtenDivTensorOp>(
+                      binder.op, resultType, input, absXPlusOne);
+                  return success();
+                });
   patterns.onOp(
       "Trilu", 14, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         Torch::ValueTensorType resultType;
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
index 0fc82da74..65b7f08e6 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -580,6 +580,21 @@ func.func @test_softmax_negative_axis(%arg0: !torch.vtensor<[3,4,5],f32>) -> !to
 
 // -----
 
+// CHECK-LABEL: func.func @test_softsign
+func.func @test_softsign(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[NONE:.*]] = torch.constant.none
+  // CHECK: %[[ABS:.+]] = torch.aten.abs %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: %[[INT1:.*]] = torch.constant.int 1
+  // CHECK: %[[RES:.+]] = torch.aten.add.Scalar %[[ABS]], %[[INT1]], %[[INT1]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: %[[SCALE_T:.*]] = torch.aten.div.Tensor %arg0, %[[RES]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
+  // CHECK: return %[[SCALE_T]] : !torch.vtensor<[3,4,5],f32>
+  %none = torch.constant.none
+  %0 = torch.operator "onnx.Softsign"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
+  return %0 : !torch.vtensor<[3,4,5],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_selu
 func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 6 : si64} {
   // CHECK-DAG: %[[F1:.+]] = torch.constant.float 1