mirror of https://github.com/llvm/torch-mlir
[ONNX] Implement Softsign op (#3373)
parent
c2c1c2cfa4
commit
fcf48872b3
|
@ -2404,6 +2404,28 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
exp);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp("Softsign", 22,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
Value input;
|
||||
if (binder.tensorOperand(input) ||
|
||||
binder.tensorResultType(resultType)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
Value absX = rewriter.create<Torch::AtenAbsOp>(
|
||||
binder.getLoc(), resultType, input);
|
||||
|
||||
Value constOne = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(1));
|
||||
|
||||
Value absXPlusOne = rewriter.create<Torch::AtenAddScalarOp>(
|
||||
binder.getLoc(), resultType, absX, constOne, constOne);
|
||||
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenDivTensorOp>(
|
||||
binder.op, resultType, input, absXPlusOne);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp(
|
||||
"Trilu", 14, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
|
|
|
@ -580,6 +580,21 @@ func.func @test_softmax_negative_axis(%arg0: !torch.vtensor<[3,4,5],f32>) -> !to
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_softsign
func.func @test_softsign(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // Softsign(x) = x / (1 + |x|): expect abs -> add.Scalar(+1) -> div.Tensor.
  // CHECK: %[[NONE:.*]] = torch.constant.none
  // CHECK: %[[ABS_X:.+]] = torch.aten.abs %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
  // CHECK: %[[ONE:.*]] = torch.constant.int 1
  // CHECK: %[[DENOM:.+]] = torch.aten.add.Scalar %[[ABS_X]], %[[ONE]], %[[ONE]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4,5],f32>
  // CHECK: %[[SOFTSIGN:.*]] = torch.aten.div.Tensor %arg0, %[[DENOM]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
  // CHECK: return %[[SOFTSIGN]] : !torch.vtensor<[3,4,5],f32>
  %none = torch.constant.none
  %0 = torch.operator "onnx.Softsign"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
  return %0 : !torch.vtensor<[3,4,5],f32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_selu
|
||||
func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 6 : si64} {
|
||||
// CHECK-DAG: %[[F1:.+]] = torch.constant.float 1
|
||||
|
|
Loading…
Reference in New Issue