From 336cfb64b531a3d34cae74f007448749420d64ac Mon Sep 17 00:00:00 2001 From: aldesilv <35313749+aldesilv@users.noreply.github.com> Date: Wed, 27 Dec 2023 10:50:08 -0800 Subject: [PATCH] OnnxToTorch support for onnx.Mul op (#2699) --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 13 ++++++++++++- .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 9 +++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index e7544d6c1..c24fd0c65 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -99,6 +99,18 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, lhs, rhs); return success(); }); + patterns.onOp("Mul", 7, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) { + return failure(); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); patterns.onOp("Greater", 16, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -342,4 +354,3 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return success(); }); } - diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 5dd0225b9..085c6ea6a 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -151,6 +151,15 @@ func.func @test_matmul_4d(%arg0: !torch.vtensor<[1,2,3,4],f32>, %arg1: !torch.vt // ----- +// CHECK-LABEL: func.func @test_mul + func.func @test_mul(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Mul"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> + } + +// ----- + // CHECK-LABEL: @test_gelu_default_1 func.func @test_gelu_default_1(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[STR1:.*]] = torch.constant.str "none"