Revert hyperbolic trigonometric decompositions (#3271)

We should be using the `torch` path and handling decomposition in the
`math` dialect.
pull/3285/head
Rob Suderman 2024-05-03 09:06:44 -07:00 committed by GitHub
parent 67d6a665a4
commit 321b844df7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 64 additions and 176 deletions

View File

@ -300,27 +300,15 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.op, resultType, operand); binder.op, resultType, operand);
return success(); return success();
}); });
patterns.onOp( patterns.onOp("Asinh", 9,
"Asinh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType; Torch::ValueTensorType resultType;
Value operand; Value operand;
if (binder.tensorOperand(operand) || if (binder.tensorOperand(operand) ||
binder.tensorResultType(resultType)) binder.tensorResultType(resultType))
return failure(); return failure();
rewriter.replaceOpWithNewOp<Torch::AtenAsinhOp>(
// log(x + sqrt(x**2 + 1)) binder.op, resultType, operand);
Value square = rewriter.create<Torch::AtenSquareOp>(
binder.getLoc(), resultType, operand);
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
Value add0 = rewriter.create<Torch::AtenAddScalarOp>(
binder.getLoc(), resultType, square, cstOne, cstOne);
Value sqrt = rewriter.create<Torch::AtenSqrtOp>(binder.getLoc(),
resultType, add0);
Value add1 = rewriter.create<Torch::AtenAddTensorOp>(
binder.getLoc(), resultType, operand, sqrt, cstOne);
rewriter.replaceOpWithNewOp<Torch::AtenLogOp>(binder.op, resultType,
add1);
return success(); return success();
}); });
patterns.onOp("Atan", 7, patterns.onOp("Atan", 7,
@ -334,31 +322,15 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.op, resultType, operand); binder.op, resultType, operand);
return success(); return success();
}); });
patterns.onOp( patterns.onOp("Atanh", 9,
"Atanh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType; Torch::ValueTensorType resultType;
Value operand; Value operand;
if (binder.tensorOperand(operand) || if (binder.tensorOperand(operand) ||
binder.tensorResultType(resultType)) binder.tensorResultType(resultType))
return failure(); return failure();
rewriter.replaceOpWithNewOp<Torch::AtenAtanhOp>(
// 1/2 * log((1 + x) / (1 - x)) binder.op, resultType, operand);
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
Value add = rewriter.create<Torch::AtenAddScalarOp>(
binder.getLoc(), resultType, operand, cstOne, cstOne);
Value neg = rewriter.create<Torch::AtenNegOp>(binder.getLoc(),
resultType, operand);
Value sub = rewriter.create<Torch::AtenAddScalarOp>(
binder.getLoc(), resultType, neg, cstOne, cstOne);
Value div = rewriter.create<Torch::AtenDivTensorOp>(
binder.getLoc(), resultType, add, sub);
Value log =
rewriter.create<Torch::AtenLogOp>(binder.getLoc(), resultType, div);
Value cstTwo = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(2));
rewriter.replaceOpWithNewOp<Torch::AtenDivScalarOp>(
binder.op, resultType, log, cstTwo);
return success(); return success();
}); });
patterns.onOp("Acos", 7, patterns.onOp("Acos", 7,
@ -372,27 +344,15 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.op, resultType, operand); binder.op, resultType, operand);
return success(); return success();
}); });
patterns.onOp( patterns.onOp("Acosh", 9,
"Acosh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType; Torch::ValueTensorType resultType;
Value operand; Value operand;
if (binder.tensorOperand(operand) || if (binder.tensorOperand(operand) ||
binder.tensorResultType(resultType)) binder.tensorResultType(resultType))
return failure(); return failure();
rewriter.replaceOpWithNewOp<Torch::AtenAcoshOp>(
// log(x + sqrt(x**2 - 1)) binder.op, resultType, operand);
Value square = rewriter.create<Torch::AtenSquareOp>(
binder.getLoc(), resultType, operand);
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
Value sub = rewriter.create<Torch::AtenSubScalarOp>(
binder.getLoc(), resultType, square, cstOne, cstOne);
Value sqrt = rewriter.create<Torch::AtenSqrtOp>(binder.getLoc(),
resultType, sub);
Value add = rewriter.create<Torch::AtenAddTensorOp>(
binder.getLoc(), resultType, operand, sqrt, cstOne);
rewriter.replaceOpWithNewOp<Torch::AtenLogOp>(binder.op, resultType,
add);
return success(); return success();
}); });
patterns.onOp("BatchNormalization", 15, patterns.onOp("BatchNormalization", 15,
@ -1490,29 +1450,15 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.op, resultType, operand); binder.op, resultType, operand);
return success(); return success();
}); });
patterns.onOp( patterns.onOp("Cosh", 9,
"Cosh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType; Torch::ValueTensorType resultType;
Value operand; Value operand;
if (binder.tensorOperand(operand) || if (binder.tensorOperand(operand) ||
binder.tensorResultType(resultType)) binder.tensorResultType(resultType))
return failure(); return failure();
rewriter.replaceOpWithNewOp<Torch::AtenCoshOp>(
// 1/2 * (exp(x) + exp(-x)) binder.op, resultType, operand);
Value x = rewriter.create<Torch::AtenExpOp>(binder.getLoc(), resultType,
operand);
Value neg = rewriter.create<Torch::AtenNegOp>(binder.getLoc(),
resultType, operand);
Value y =
rewriter.create<Torch::AtenExpOp>(binder.getLoc(), resultType, neg);
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
Value z = rewriter.create<Torch::AtenAddTensorOp>(
binder.getLoc(), resultType, x, y, cstOne);
Value cstTwo = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(2));
rewriter.replaceOpWithNewOp<Torch::AtenDivScalarOp>(
binder.op, resultType, z, cstTwo);
return success(); return success();
}); });
patterns.onOp( patterns.onOp(

View File

@ -1439,29 +1439,16 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
return success(); return success();
}); });
patterns.onOp( patterns.onOp("Sinh", 9,
"Sinh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType; Torch::ValueTensorType resultType;
Value operand; Value operand;
if (binder.tensorOperand(operand) || if (binder.tensorOperand(operand) ||
binder.tensorResultType(resultType)) binder.tensorResultType(resultType))
return failure(); return failure();
// 1/2 * (exp(x) exp(-x)) rewriter.replaceOpWithNewOp<Torch::AtenSinhOp>(
Value x = rewriter.create<Torch::AtenExpOp>(binder.getLoc(), resultType, binder.op, resultType, operand);
operand);
Value neg = rewriter.create<Torch::AtenNegOp>(binder.getLoc(),
resultType, operand);
Value y =
rewriter.create<Torch::AtenExpOp>(binder.getLoc(), resultType, neg);
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
Value z = rewriter.create<Torch::AtenSubTensorOp>(
binder.getLoc(), resultType, x, y, cstOne);
Value cstTwo = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(2));
rewriter.replaceOpWithNewOp<Torch::AtenDivScalarOp>(
binder.op, resultType, z, cstTwo);
return success(); return success();
}); });

View File

@ -201,14 +201,7 @@ func.func @test_atan(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,
// CHECK-LABEL: @test_atanh // CHECK-LABEL: @test_atanh
func.func @test_atanh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { func.func @test_atanh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[C1:.*]] = torch.constant.int 1 // CHECK: torch.aten.atanh %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
// CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %arg0, %[[C1]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4,5],f32>
// CHECK: %[[NEG:.*]] = torch.aten.neg %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
// CHECK: %[[SUB:.*]] = torch.aten.add.Scalar %[[NEG]], %[[C1]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4,5],f32>
// CHECK: %[[DIV:.*]] = torch.aten.div.Tensor %[[ADD]], %[[SUB]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
// CHECK: %[[LOG:.*]] = torch.aten.log %[[DIV]] : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
// CHECK: %[[C2:.*]] = torch.constant.int 2
// CHECK: torch.aten.div.Scalar %[[LOG]], %[[C2]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32>
%0 = torch.operator "onnx.Atanh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> %0 = torch.operator "onnx.Atanh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
return %0 : !torch.vtensor<[3,4,5],f32> return %0 : !torch.vtensor<[3,4,5],f32>
} }
@ -672,13 +665,7 @@ func.func @test_cos(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5
// CHECK-LABEL: @test_cosh_example // CHECK-LABEL: @test_cosh_example
func.func @test_cosh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { func.func @test_cosh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[X:.+]] = torch.aten.exp %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> // CHECK: torch.aten.cosh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
// CHECK: %[[NEG:.+]] = torch.aten.neg %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
// CHECK: %[[Y:.+]] = torch.aten.exp %[[NEG]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
// CHECK: %[[C1:.+]] = torch.constant.int 1
// CHECK: %[[ADD:.+]] = torch.aten.add.Tensor %[[X]], %[[Y]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
// CHECK: %[[C2:.+]] = torch.constant.int 2
// CHECK: torch.aten.div.Scalar %[[ADD]], %[[C2]] : !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
%0 = torch.operator "onnx.Cosh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> %0 = torch.operator "onnx.Cosh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
return %0 : !torch.vtensor<[3],f32> return %0 : !torch.vtensor<[3],f32>
} }
@ -687,13 +674,7 @@ func.func @test_cosh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[
// CHECK-LABEL: @test_cosh // CHECK-LABEL: @test_cosh
func.func @test_cosh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { func.func @test_cosh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[X:.+]] = torch.aten.exp %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> // CHECK: torch.aten.cosh %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
// CHECK: %[[NEG:.+]] = torch.aten.neg %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
// CHECK: %[[Y:.+]] = torch.aten.exp %[[NEG]] : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
// CHECK: %[[C1:.+]] = torch.constant.int 1
// CHECK: %[[ADD:.+]] = torch.aten.add.Tensor %[[X]], %[[Y]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32>
// CHECK: %[[C2:.+]] = torch.constant.int 2
// CHECK: torch.aten.div.Scalar %[[ADD]], %[[C2]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32>
%0 = torch.operator "onnx.Cosh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> %0 = torch.operator "onnx.Cosh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
return %0 : !torch.vtensor<[3,4,5],f32> return %0 : !torch.vtensor<[3,4,5],f32>
} }
@ -702,12 +683,7 @@ func.func @test_cosh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,
// CHECK-LABEL: @test_acosh_example // CHECK-LABEL: @test_acosh_example
func.func @test_acosh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { func.func @test_acosh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[SQUARE:.+]] = torch.aten.square %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> // CHECK: torch.aten.acosh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
// CHECK: %[[C1:.+]] = torch.constant.int 1
// CHECK: %[[SUB:.+]] = torch.aten.sub.Scalar %[[SQUARE]], %[[C1]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.int, !torch.int -> !torch.vtensor<[3],f32>
// CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[SUB]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
// CHECK: %[[ADD:.+]] = torch.aten.add.Tensor %arg0, %[[SQRT]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
// CHECK: torch.aten.log %[[ADD]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
%0 = torch.operator "onnx.Acosh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> %0 = torch.operator "onnx.Acosh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
return %0 : !torch.vtensor<[3],f32> return %0 : !torch.vtensor<[3],f32>
} }
@ -716,12 +692,7 @@ func.func @test_acosh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<
// CHECK-LABEL: @test_acosh // CHECK-LABEL: @test_acosh
func.func @test_acosh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { func.func @test_acosh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[SQUARE:.+]] = torch.aten.square %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> // CHECK: torch.aten.acosh %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
// CHECK: %[[C1:.+]] = torch.constant.int 1
// CHECK: %[[SUB:.+]] = torch.aten.sub.Scalar %[[SQUARE]], %[[C1]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4,5],f32>
// CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[SUB]] : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
// CHECK: %[[ADD:.+]] = torch.aten.add.Tensor %arg0, %[[SQRT]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32>
// CHECK: torch.aten.log %[[ADD]] : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
%0 = torch.operator "onnx.Acosh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> %0 = torch.operator "onnx.Acosh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
return %0 : !torch.vtensor<[3,4,5],f32> return %0 : !torch.vtensor<[3,4,5],f32>
} }
@ -748,12 +719,7 @@ func.func @test_asin(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,
// CHECK-LABEL: @test_asinh_example // CHECK-LABEL: @test_asinh_example
func.func @test_asinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { func.func @test_asinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[SQUARE:.+]] = torch.aten.square %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> // CHECK: torch.aten.asinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
// CHECK: %[[C1:.+]] = torch.constant.int 1
// CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %[[SQUARE]], %[[C1]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.int, !torch.int -> !torch.vtensor<[3],f32>
// CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[ADD]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
// CHECK: %[[ADD_0:.+]] = torch.aten.add.Tensor %arg0, %[[SQRT]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
// CHECK: torch.aten.log %[[ADD_0]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
%0 = torch.operator "onnx.Asinh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> %0 = torch.operator "onnx.Asinh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
return %0 : !torch.vtensor<[3],f32> return %0 : !torch.vtensor<[3],f32>
} }
@ -762,12 +728,7 @@ func.func @test_asinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<
// CHECK-LABEL: @test_asinh // CHECK-LABEL: @test_asinh
func.func @test_asinh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { func.func @test_asinh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[SQUARE:.+]] = torch.aten.square %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> // CHECK: torch.aten.asinh %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
// CHECK: %[[C1:.+]] = torch.constant.int 1
// CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %[[SQUARE]], %[[C1]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4,5],f32>
// CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[ADD]] : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
// CHECK: %[[ADD_0:.+]] = torch.aten.add.Tensor %arg0, %[[SQRT]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32>
// CHECK: torch.aten.log %[[ADD_0]] : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
%0 = torch.operator "onnx.Asinh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> %0 = torch.operator "onnx.Asinh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
return %0 : !torch.vtensor<[3,4,5],f32> return %0 : !torch.vtensor<[3,4,5],f32>
} }

View File

@ -1341,15 +1341,9 @@ func.func @test_reduce_prod_keepdims_random(%arg0: !torch.vtensor<[3,2,2],f32>,
// ----- // -----
// CHECK-LABEL: func.func @test_sinh_example // CHECK-LABEL: func.func @test_sinh
func.func @test_sinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64} { func.func @test_sinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64} {
// CHECK: %[[X:.+]] = torch.aten.exp %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> // CHECK: torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
// CHECK: %[[NEG:.+]] = torch.aten.neg %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
// CHECK: %[[Y:.+]] = torch.aten.exp %[[NEG]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
// CHECK: %[[C1:.+]] = torch.constant.int 1
// CHECK: %[[SUB:.+]] = torch.aten.sub.Tensor %[[X]], %[[Y]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
// CHECK: %[[C2:.+]] = torch.constant.int 2
// CHECK: torch.aten.div.Scalar %[[SUB]], %[[C2]] : !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
%0 = torch.operator "onnx.Sinh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> %0 = torch.operator "onnx.Sinh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
return %0 : !torch.vtensor<[3],f32> return %0 : !torch.vtensor<[3],f32>
} }