diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index edb36aee9..197d9c536 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -1449,18 +1449,31 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         return success();
       });

-  patterns.onOp("Sinh", 9,
-                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-                  Torch::ValueTensorType resultType;
-                  Value operand;
-                  if (binder.tensorOperand(operand) ||
-                      binder.tensorResultType(resultType))
-                    return failure();
-
-                  rewriter.replaceOpWithNewOp<Torch::AtenSinhOp>(
-                      binder.op, resultType, operand);
-                  return success();
-                });
+  patterns.onOp(
+      "Sinh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value operand;
+        if (binder.tensorOperand(operand) ||
+            binder.tensorResultType(resultType))
+          return failure();
+
+        // 1/2 * (exp(x) - exp(-x))
+        Value x = rewriter.create<Torch::AtenExpOp>(binder.getLoc(), resultType,
+                                                    operand);
+        Value neg = rewriter.create<Torch::AtenNegOp>(binder.getLoc(),
+                                                      resultType, operand);
+        Value y =
+            rewriter.create<Torch::AtenExpOp>(binder.getLoc(), resultType, neg);
+        Value cstOne = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(1));
+        Value z = rewriter.create<Torch::AtenSubTensorOp>(
+            binder.getLoc(), resultType, x, y, cstOne);
+        Value cstTwo = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(2));
+        rewriter.replaceOpWithNewOp<Torch::AtenDivScalarOp>(
+            binder.op, resultType, z, cstTwo);
+        return success();
+      });

   // split with fixed-size parts
   // Arguments:
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
index 5fe9c79d3..2748a640a 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -1265,9 +1265,15 @@ func.func @test_reduce_prod_keepdims_random(%arg0: !torch.vtensor<[3,2,2],f32>,

 // -----

-// CHECK-LABEL: func.func @test_sinh
+// CHECK-LABEL: func.func @test_sinh_example
 func.func @test_sinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64} {
-  // CHECK: torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
+  // CHECK: %[[X:.+]] = torch.aten.exp %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
+  // CHECK: %[[NEG:.+]] = torch.aten.neg %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
+  // CHECK: %[[Y:.+]] = torch.aten.exp %[[NEG]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
+  // CHECK: %[[C1:.+]] = torch.constant.int 1
+  // CHECK: %[[SUB:.+]] = torch.aten.sub.Tensor %[[X]], %[[Y]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
+  // CHECK: %[[C2:.+]] = torch.constant.int 2
+  // CHECK: torch.aten.div.Scalar %[[SUB]], %[[C2]] : !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
   %0 = torch.operator "onnx.Sinh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
   return %0 : !torch.vtensor<[3],f32>
 }
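
Not part of the patch: a minimal standalone sanity check of the identity the new lowering relies on, sinh(x) = (exp(x) - exp(-x)) / 2. The file name sinh_check.cpp is hypothetical; the snippet just compares std::sinh against the exp-based decomposition at a few sample points.

    // sinh_check.cpp (illustrative only, not part of the patch):
    // compare std::sinh against the decomposition (exp(x) - exp(-x)) / 2.
    #include <cmath>
    #include <cstdio>
    #include <initializer_list>

    int main() {
      for (double x : {-3.0, -0.5, 0.0, 1.0, 2.5}) {
        double decomposed = (std::exp(x) - std::exp(-x)) / 2.0;
        std::printf("x=%5.2f  std::sinh=%.12f  decomposed=%.12f\n", x,
                    std::sinh(x), decomposed);
      }
      return 0;
    }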