mirror of https://github.com/llvm/torch-mlir
[onnx] Add support for `onnx.sinh` (#2643)
Adds a lowering from `onnx.sinh` to `aten.sinh`. This includes adding the `aten.sinh` operator.pull/2662/head snapshot-20231216.1054
parent
b3e94208a8
commit
61888690bb
|
@ -526,6 +526,51 @@ def Torch_AtenSign_Op : Torch_Op<"aten.sign_", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenSinhOp : Torch_Op<"aten.sinh", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::sinh : (Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenSinhOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||
}
|
||||
void AtenSinhOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenSinh_Op : Torch_Op<"aten.sinh_", [
|
||||
IsTrailingUnderscoreInplaceVariant,
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::sinh_ : (Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
Torch_NonValueTensorType:$self
|
||||
);
|
||||
let results = (outs
|
||||
Torch_NonValueTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenSinh_Op::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||
}
|
||||
void AtenSinh_Op::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenSgnOp : Torch_Op<"aten.sgn", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -467,11 +467,23 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
if (binder.tensorOperand(operand) ||
|
||||
binder.tensorResultType(resultType))
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<Torch::Aten_ShapeAsTensorOp>(
|
||||
binder.op, resultType, operand);
|
||||
return success();
|
||||
});
|
||||
|
||||
patterns.onOp("Sinh", 9,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
Value operand;
|
||||
if (binder.tensorOperand(operand) ||
|
||||
binder.tensorResultType(resultType))
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenSinhOp>(
|
||||
binder.op, resultType, operand);
|
||||
return success();
|
||||
});
|
||||
|
||||
patterns.onOp(
|
||||
"Transpose", 13,
|
||||
|
|
|
@ -266,6 +266,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
"aten::selu : (Tensor) -> (Tensor)",
|
||||
"aten::sigmoid : (Tensor) -> (Tensor)",
|
||||
"aten::sign : (Tensor) -> (Tensor)",
|
||||
"aten::sinh : (Tensor) -> (Tensor)",
|
||||
"aten::sgn : (Tensor) -> (Tensor)",
|
||||
"aten::hardsigmoid : (Tensor) -> (Tensor)",
|
||||
"aten::hardswish : (Tensor) -> (Tensor)",
|
||||
|
|
|
@ -489,6 +489,15 @@ func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_sinh
|
||||
func.func @test_sinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64} {
|
||||
// CHECK: torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
|
||||
%0 = torch.operator "onnx.Sinh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
|
||||
return %0 : !torch.vtensor<[3],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_transpose_default
|
||||
func.func @test_transpose_default(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[4,3,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
|
||||
// CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
|
||||
|
|
Loading…
Reference in New Issue