[onnx] Add support for `onnx.sinh` (#2643)

Adds a lowering from `onnx.sinh` to `aten.sinh`. This includes adding
the `aten.sinh` operator.
pull/2662/head snapshot-20231216.1054
Rob Suderman 2023-12-15 21:23:51 -08:00 committed by GitHub
parent b3e94208a8
commit 61888690bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 68 additions and 1 deletions

View File

@ -526,6 +526,51 @@ def Torch_AtenSign_Op : Torch_Op<"aten.sign_", [
}];
}
def Torch_AtenSinhOp : Torch_Op<"aten.sinh", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::sinh : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenSinhOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenSinhOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}
def Torch_AtenSinh_Op : Torch_Op<"aten.sinh_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::sinh_ : (Tensor) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self
);
let results = (outs
Torch_NonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenSinh_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenSinh_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}
def Torch_AtenSgnOp : Torch_Op<"aten.sgn", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -467,11 +467,23 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
if (binder.tensorOperand(operand) ||
binder.tensorResultType(resultType))
return failure();
rewriter.replaceOpWithNewOp<Torch::Aten_ShapeAsTensorOp>(
binder.op, resultType, operand);
return success();
});
patterns.onOp("Sinh", 9,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value operand;
if (binder.tensorOperand(operand) ||
binder.tensorResultType(resultType))
return failure();
rewriter.replaceOpWithNewOp<Torch::AtenSinhOp>(
binder.op, resultType, operand);
return success();
});
patterns.onOp(
"Transpose", 13,

View File

@ -266,6 +266,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
"aten::selu : (Tensor) -> (Tensor)",
"aten::sigmoid : (Tensor) -> (Tensor)",
"aten::sign : (Tensor) -> (Tensor)",
"aten::sinh : (Tensor) -> (Tensor)",
"aten::sgn : (Tensor) -> (Tensor)",
"aten::hardsigmoid : (Tensor) -> (Tensor)",
"aten::hardswish : (Tensor) -> (Tensor)",

View File

@ -489,6 +489,15 @@ func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,
// -----
// CHECK-LABEL: func.func @test_sinh
func.func @test_sinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64} {
// CHECK: torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
%0 = torch.operator "onnx.Sinh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
return %0 : !torch.vtensor<[3],f32>
}
// -----
// CHECK-LABEL: func.func @test_transpose_default
func.func @test_transpose_default(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[4,3,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
// CHECK-DAG: %[[I0:.+]] = torch.constant.int 0