From 551a4e45f36574b6f0bf892a2d67d877e0253441 Mon Sep 17 00:00:00 2001
From: Andreas Falkenberg <149819731+afalkenberg1@users.noreply.github.com>
Date: Thu, 7 Mar 2024 15:58:38 -0800
Subject: [PATCH] [onnx] Add support for `onnx.Gemm` with no bias (#2993)

The previous gemm lowering required a bias vector. This provides an
alternate path to `Torch::AtenMm` when no bias operand is present.
---
 .../TorchOnnxToTorch/DefaultDomainGtoP.cpp  | 11 ++++++++++-
 projects/pt1/e2e_testing/xfail_sets.py      |  8 --------
 .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 13 +++++++++++--
 3 files changed, 21 insertions(+), 11 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index b956666f8..8a677b8ce 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -687,7 +687,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         int64_t transA, transB;
         if (binder.tensorOperandAtIndex(a, 0) ||
             binder.tensorOperandAtIndex(b, 1) ||
-            binder.tensorOperandAtIndex(c, 2) ||
             binder.s64IntegerAttr(transA, "transA", 0) ||
             binder.s64IntegerAttr(transB, "transB", 0) ||
             binder.f32FloatAttr(alpha, "alpha", 1.0f) ||
@@ -724,6 +723,16 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
           b = transpose(b);
         }
 
+        if (binder.getNumOperands() == 2) {
+          rewriter.replaceOpWithNewOp<Torch::AtenMmOp>(binder.op, resultType,
+                                                       a, b);
+          return success();
+        }
+
+        if (binder.tensorOperandAtIndex(c, 2))
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "Expected either 2 or 3 inputs");
+
         Value mm =
             rewriter.create<Torch::AtenMmOp>(binder.getLoc(), resultType, a, b);
         if (alpha == 1.0 && beta == 1.0) {
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 9f5db2c86..d16a20893 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1920,14 +1920,6 @@ ONNX_XFAIL_SET = {
     "EinsumStaticFourDimensionModule_basic",
     "EinsumStaticModule_basic",
-    # Failure - onnx_lowering: onnx.Gemm
-    "AtenMmFloatTypes_basic",
-    "AtenMmIntTypes_basic",
-    "MmDagModule_basic",
-    "MmModule_basic",
-    "MmModule_chained",
-    "MmTanhModule_basic",
     # Failure - onnx_lowering: onnx.HardSwish
     "HardswishModule_basic",
     "HardswishRandomModule_basic",
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index dcd0d2893..d1f4307d4 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -97,8 +97,17 @@ func.func @test_gather_elements(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torc
 
 // -----
 
-// CHECK-LABEL: func.func @test_gemm_default
-func.func @test_gemm_default(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
+// CHECK-LABEL: func.func @test_gemm_defaultA
+func.func @test_gemm_defaultA(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
+  // CHECK-DAG: %[[MM:.+]] = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32>
+  %0 = torch.operator "onnx.Gemm"(%arg0, %arg1) : (!torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32>) -> !torch.vtensor<[3,4],f32>
+  return %0 : !torch.vtensor<[3,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_gemm_defaultB
+func.func @test_gemm_defaultB(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {