From 9ae33e482ef02ff4bc345463128901a514480313 Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Mon, 25 Mar 2024 20:29:07 +0530
Subject: [PATCH] [MLIR][TORCH] Add OnnxToTorch lowering for ops (#3049)

This commit adds the OnnxToTorch lowering for the Mish, Softplus,
HardSwish, Trilu, ThresholdedRelu op

Signed-Off By: Vivek Khandelwal
---
 .../TorchOnnxToTorch/DefaultDomainGtoP.cpp    | 24 +++++
 .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp    | 63 ++++++++++++
 projects/pt1/e2e_testing/xfail_sets.py        | 19 ----
 .../TorchOnnxToTorch/simple_ops_g_to_p.mlir   | 19 ++++
 .../TorchOnnxToTorch/simple_ops_q_to_z.mlir   | 99 +++++++++++++++++++
 5 files changed, 205 insertions(+), 19 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index 46ecc5b70..22fd82371 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -1501,4 +1501,28 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             binder.op, resultType, self, other);
         return success();
       });
+  patterns.onOp("Mish", 18,
+                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+                  Torch::ValueTensorType resultType;
+                  Value input;
+                  if (binder.tensorOperand(input) ||
+                      binder.tensorResultType(resultType)) {
+                    return failure();
+                  }
+                  rewriter.replaceOpWithNewOp<Torch::AtenMishOp>(
+                      binder.op, resultType, input);
+                  return success();
+                });
+  patterns.onOp("HardSwish", 14,
+                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+                  Torch::ValueTensorType resultType;
+                  Value input;
+                  if (binder.tensorOperand(input) ||
+                      binder.tensorResultType(resultType)) {
+                    return failure();
+                  }
+                  rewriter.replaceOpWithNewOp<Torch::AtenHardswishOp>(
+                      binder.op, resultType, input);
+                  return success();
+                });
 }
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index b5e9162bc..9218f5376 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -2099,4 +2099,67 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
             binder.op, resultType, operand);
         return success();
       });
+  patterns.onOp(
+      "Softplus", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value input;
+        if (binder.tensorOperand(input) ||
+            binder.tensorResultType(resultType)) {
+          return failure();
+        }
+        // out = ln(exp(x) + 1)
+        Value exp = rewriter.create<Torch::AtenExpOp>(binder.getLoc(),
+                                                      resultType, input);
+        rewriter.replaceOpWithNewOp<Torch::AtenLog1pOp>(binder.op, resultType,
+                                                        exp);
+        return success();
+      });
+  patterns.onOp(
+      "Trilu", 14, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value input;
+        int64_t upper;
+        if (binder.tensorOperandAtIndex(input, 0) ||
+            binder.s64IntegerAttr(upper, "upper", 1) ||
+            binder.tensorResultType(resultType)) {
+          return failure();
+        }
+
+        Value diagonal;
+        if (binder.tensorOperandAtIndex(diagonal, 1)) {
+          diagonal = rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getI64IntegerAttr(0));
+        } else {
+          diagonal = rewriter.create<Torch::AtenItemOp>(
+              binder.getLoc(), rewriter.getType<Torch::IntType>(), diagonal);
+        }
+
+        if (upper) {
+          rewriter.replaceOpWithNewOp<Torch::AtenTriuOp>(binder.op, resultType,
+                                                         input, diagonal);
+          return success();
+        }
+        rewriter.replaceOpWithNewOp<Torch::AtenTrilOp>(binder.op, resultType,
+                                                       input, diagonal);
+        return success();
+      });
+  patterns.onOp("ThresholdedRelu", 10,
+                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+                  Torch::ValueTensorType resultType;
+                  Value input;
+                  float alpha;
+                  if (binder.tensorOperand(input) ||
+                      binder.f32FloatAttr(alpha, "alpha", 1.0)) {
+                    return failure();
+                  }
+                  Value cstAlpha = rewriter.create<Torch::ConstantFloatOp>(
+                      binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+                      rewriter.getFloatAttr(rewriter.getF64Type(), alpha));
+                  Value value = rewriter.create<Torch::ConstantFloatOp>(
+                      binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+                      rewriter.getFloatAttr(rewriter.getF64Type(), 0.0));
+                  rewriter.replaceOpWithNewOp<Torch::AtenThresholdOp>(
+                      binder.op, resultType, input, cstAlpha, value);
+                  return success();
+                });
 }
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 55b7d8883..9bbd115cd 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1924,11 +1924,6 @@ ONNX_XFAIL_SET = {
     "EinsumStaticFourDimensionModule_basic",
     "EinsumStaticModule_basic",
 
-    # Failure - onnx_lowering: onnx.HardSwish
-    "HardswishModule_basic",
-    "HardswishRandomModule_basic",
-    "MobilenetV3Module_basic",
-
     # Failure - onnx_lowering: onnx.MaxPool
     "MaxPool2dWithIndicesAllNegativeValuesModule_basic",
     "MaxPool2dWithIndicesNonDefaultPaddingModule_basic",
@@ -2043,10 +2038,6 @@ ONNX_XFAIL_SET = {
     "CrossEntropyLossModule_basic",
     "CrossEntropyLossNoReductionModule_basic",
 
-    # Failure - onnx_lowering: onnx.Softplus
-    "ElementwiseMishModule_basic",
-    "SoftplusModule_basic",
-
     # Failure - onnx_lowering: onnx.Squeeze
     "SqueezeModule_allUnitDim",
     "SqueezeModule_broadcast",
@@ -2059,16 +2050,6 @@ ONNX_XFAIL_SET = {
     "SortTensorSpecificDimension_basic",
     "SortTensor_basic",
 
-    # Failure - onnx_lowering: onnx.Trilu
-    "AtenTrilModule_basic",
-    "AtenTrilWithNegDiagonalModule_basic",
-    "AtenTrilWithPosDiagonalModule_basic",
-    "AtenTriuModule_basic",
-    "AtenTriuWithNegDiagonalModule_basic",
-    "AtenTriuWithPosDiagonalModule_basic",
-    "TriuBroadcastModule_basic",
-    "TriuModule_basic",
-
     # Failure - incorrect dtype
     "ReduceMaxAlongDimUnsignedInt_basic",
 
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index 46eda0496..3741e5c9f 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -912,3 +912,22 @@ func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],
   %0 = torch.operator "onnx.PRelu"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32>
   return %0 : !torch.vtensor<[3,4,5],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @test_mish
+func.func @test_mish(%arg0: !torch.vtensor<[10000],f32>) -> !torch.vtensor<[10000],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  %none = torch.constant.none
+  // CHECK: torch.aten.mish %arg0 : !torch.vtensor<[10000],f32> -> !torch.vtensor<[10000],f32>
+  %0 = torch.operator "onnx.Mish"(%arg0) : (!torch.vtensor<[10000],f32>) -> !torch.vtensor<[10000],f32>
+  return %0 : !torch.vtensor<[10000],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_hardswish
+func.func @test_hardswish(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: torch.aten.hardswish %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
+  %0 = torch.operator "onnx.HardSwish"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
+  return %0 : !torch.vtensor<[3,4,5],f32>
+}
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
index 508ed55d3..dc652bff9 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -1664,3 +1664,102 @@ func.func @test_size(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[],si
   return %0 : !torch.vtensor<[],si32>
 }
 
+// -----
+
+// CHECK-LABEL: func.func @test_softplus
+func.func @test_softplus(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[EXP:.*]] = torch.aten.exp %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
+  // CHECK: torch.aten.log1p %[[EXP]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
+  %0 = torch.operator "onnx.Softplus"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
+  return %0 : !torch.vtensor<[3],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_tril
+func.func @test_tril(%arg0: !torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[DIAGONAL:.*]] = torch.constant.int 0
+  // CHECK: torch.aten.tril %arg0, %[[DIAGONAL]] : !torch.vtensor<[4,5],si64>, !torch.int -> !torch.vtensor<[4,5],si64>
+  %0 = torch.operator "onnx.Trilu"(%arg0) {torch.onnx.upper = 0 : si64} : (!torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64>
+  return %0 : !torch.vtensor<[4,5],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_tril_neg
+func.func @test_tril_neg(%arg0: !torch.vtensor<[4,5],si64>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[DIAGONAL:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int
+  // CHECK: torch.aten.tril %arg0, %[[DIAGONAL]] : !torch.vtensor<[4,5],si64>, !torch.int -> !torch.vtensor<[4,5],si64>
+  %0 = torch.operator "onnx.Trilu"(%arg0, %arg1) {torch.onnx.upper = 0 : si64} : (!torch.vtensor<[4,5],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[4,5],si64>
+  return %0 : !torch.vtensor<[4,5],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_tril_one_row_neg
+func.func @test_tril_one_row_neg(%arg0: !torch.vtensor<[3,1,5],si64>) -> !torch.vtensor<[3,1,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[DIAGONAL:.*]] = torch.constant.int 0
+  // CHECK: torch.aten.tril %arg0, %[[DIAGONAL]] : !torch.vtensor<[3,1,5],si64>, !torch.int -> !torch.vtensor<[3,1,5],si64>
+  %0 = torch.operator "onnx.Trilu"(%arg0) {torch.onnx.upper = 0 : si64} : (!torch.vtensor<[3,1,5],si64>) -> !torch.vtensor<[3,1,5],si64>
+  return %0 : !torch.vtensor<[3,1,5],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_tril_square
+func.func @test_tril_square(%arg0: !torch.vtensor<[2,3,3],si64>) -> !torch.vtensor<[2,3,3],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[DIAGONAL:.*]] = torch.constant.int 0
+  // CHECK: torch.aten.tril %arg0, %[[DIAGONAL]] : !torch.vtensor<[2,3,3],si64>, !torch.int -> !torch.vtensor<[2,3,3],si64>
+  %0 = torch.operator "onnx.Trilu"(%arg0) {torch.onnx.upper = 0 : si64} : (!torch.vtensor<[2,3,3],si64>) -> !torch.vtensor<[2,3,3],si64>
+  return %0 : !torch.vtensor<[2,3,3],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_tril_zero
+func.func @test_tril_zero(%arg0: !torch.vtensor<[3,0,5],si64>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[3,0,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[DIAGONAL:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int
+  // CHECK: torch.aten.tril %arg0, %[[DIAGONAL]] : !torch.vtensor<[3,0,5],si64>, !torch.int -> !torch.vtensor<[3,0,5],si64>
+  %0 = torch.operator "onnx.Trilu"(%arg0, %arg1) {torch.onnx.upper = 0 : si64} : (!torch.vtensor<[3,0,5],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[3,0,5],si64>
+  return %0 : !torch.vtensor<[3,0,5],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_triu
+func.func @test_triu(%arg0: !torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[DIAGONAL:.*]] = torch.constant.int 0
+  // CHECK: torch.aten.triu %arg0, %[[DIAGONAL]] : !torch.vtensor<[4,5],si64>, !torch.int -> !torch.vtensor<[4,5],si64>
+  %0 = torch.operator "onnx.Trilu"(%arg0) : (!torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64>
+  return %0 : !torch.vtensor<[4,5],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_triu_one_row
+func.func @test_triu_one_row(%arg0: !torch.vtensor<[3,1,5],si64>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[3,1,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[DIAGONAL:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int
+  // CHECK: torch.aten.triu %arg0, %[[DIAGONAL]] : !torch.vtensor<[3,1,5],si64>, !torch.int -> !torch.vtensor<[3,1,5],si64>
+  %0 = torch.operator "onnx.Trilu"(%arg0, %arg1) : (!torch.vtensor<[3,1,5],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[3,1,5],si64>
+  return %0 : !torch.vtensor<[3,1,5],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_triu_square
+func.func @test_triu_square(%arg0: !torch.vtensor<[2,3,3],si64>) -> !torch.vtensor<[2,3,3],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[DIAGONAL:.*]] = torch.constant.int 0
+  // CHECK: torch.aten.triu %arg0, %[[DIAGONAL]] : !torch.vtensor<[2,3,3],si64>, !torch.int -> !torch.vtensor<[2,3,3],si64>
+  %0 = torch.operator "onnx.Trilu"(%arg0) : (!torch.vtensor<[2,3,3],si64>) -> !torch.vtensor<[2,3,3],si64>
+  return %0 : !torch.vtensor<[2,3,3],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_triu_zero
+func.func @test_triu_zero(%arg0: !torch.vtensor<[0,5],si64>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[0,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[DIAGONAL:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int
+  // CHECK: torch.aten.triu %arg0, %[[DIAGONAL]] : !torch.vtensor<[0,5],si64>, !torch.int -> !torch.vtensor<[0,5],si64>
+  %0 = torch.operator "onnx.Trilu"(%arg0, %arg1) : (!torch.vtensor<[0,5],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[0,5],si64>
+  return %0 : !torch.vtensor<[0,5],si64>
+}