Fix onnx sinh lowering (#3253)

iree tests `test_sinh` and `test_sinh_example` passed
pull/3269/head
jinchen 2024-04-30 00:44:41 -07:00 committed by GitHub
parent 122cf22cc2
commit b64c22cfc1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 32 additions and 13 deletions

View File

@ -1449,16 +1449,29 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
return success(); return success();
}); });
patterns.onOp("Sinh", 9, patterns.onOp(
[](OpBinder binder, ConversionPatternRewriter &rewriter) { "Sinh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType; Torch::ValueTensorType resultType;
Value operand; Value operand;
if (binder.tensorOperand(operand) || if (binder.tensorOperand(operand) ||
binder.tensorResultType(resultType)) binder.tensorResultType(resultType))
return failure(); return failure();
rewriter.replaceOpWithNewOp<Torch::AtenSinhOp>( // 1/2 * (exp(x) exp(-x))
binder.op, resultType, operand); Value x = rewriter.create<Torch::AtenExpOp>(binder.getLoc(), resultType,
operand);
Value neg = rewriter.create<Torch::AtenNegOp>(binder.getLoc(),
resultType, operand);
Value y =
rewriter.create<Torch::AtenExpOp>(binder.getLoc(), resultType, neg);
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
Value z = rewriter.create<Torch::AtenSubTensorOp>(
binder.getLoc(), resultType, x, y, cstOne);
Value cstTwo = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(2));
rewriter.replaceOpWithNewOp<Torch::AtenDivScalarOp>(
binder.op, resultType, z, cstTwo);
return success(); return success();
}); });

View File

@ -1265,9 +1265,15 @@ func.func @test_reduce_prod_keepdims_random(%arg0: !torch.vtensor<[3,2,2],f32>,
// ----- // -----
// CHECK-LABEL: func.func @test_sinh // CHECK-LABEL: func.func @test_sinh_example
func.func @test_sinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64} { func.func @test_sinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64} {
// CHECK: torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> // CHECK: %[[X:.+]] = torch.aten.exp %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
// CHECK: %[[NEG:.+]] = torch.aten.neg %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
// CHECK: %[[Y:.+]] = torch.aten.exp %[[NEG]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
// CHECK: %[[C1:.+]] = torch.constant.int 1
// CHECK: %[[SUB:.+]] = torch.aten.sub.Tensor %[[X]], %[[Y]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
// CHECK: %[[C2:.+]] = torch.constant.int 2
// CHECK: torch.aten.div.Scalar %[[SUB]], %[[C2]] : !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
%0 = torch.operator "onnx.Sinh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> %0 = torch.operator "onnx.Sinh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
return %0 : !torch.vtensor<[3],f32> return %0 : !torch.vtensor<[3],f32>
} }