From 7c289d95222f0297b7f1b36dc123663b95341136 Mon Sep 17 00:00:00 2001 From: Angel Zhang <68571948+angelz913@users.noreply.github.com> Date: Fri, 10 May 2024 11:58:46 -0400 Subject: [PATCH] [ONNX] Handle one-input case for `onnx.Max` operator (#3325) This commit handles the one-input case for the "Max" ONNX operator. A new unit test has also been added. --- lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp | 2 +- test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 1f1e2e5d7..d419b0b5b 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -651,7 +651,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( result = rewriter.create( binder.getLoc(), resultType, result, operands[i]); } - rewriter.replaceOp(binder.op, result.getDefiningOp()); + rewriter.replaceOp(binder.op, result); return success(); }); patterns.onOp( diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index d280d5f6b..6041bae1c 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -740,6 +740,15 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3 // ----- +// CHECK-LABEL: func.func @test_max_one_input_example + func.func @test_max_one_input_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: return %arg0 : !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Max"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> + } + +// ----- + // CHECK-LABEL: func.func @test_min_example func.func @test_min_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.minimum %arg0, %arg1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>