[onnx] Add support for `onnx.Gemm` with no bias (#2993)

The previous `onnx.Gemm` lowering required the bias operand `C` to be present.
This adds an alternate path that lowers a two-operand Gemm directly to
`Torch::AtenMm`, with no bias addition.
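
As a minimal sketch of the new path (hypothetical value names `%a`/`%b`; the lit tests below use the same shapes), a two-operand `onnx.Gemm` now rewrites straight to `torch.aten.mm`, while the three-operand form still goes through the bias-add path:

// Before conversion (no bias operand):
%0 = torch.operator "onnx.Gemm"(%a, %b) : (!torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32>) -> !torch.vtensor<[3,4],f32>
// After conversion:
%0 = torch.aten.mm %a, %b : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32>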
Andreas Falkenberg 2024-03-07 15:58:38 -08:00 committed by GitHub
parent 1964208d19
commit 551a4e45f3
3 changed files with 21 additions and 11 deletions


@@ -687,7 +687,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         int64_t transA, transB;
         if (binder.tensorOperandAtIndex(a, 0) ||
             binder.tensorOperandAtIndex(b, 1) ||
-            binder.tensorOperandAtIndex(c, 2) ||
             binder.s64IntegerAttr(transA, "transA", 0) ||
             binder.s64IntegerAttr(transB, "transB", 0) ||
             binder.f32FloatAttr(alpha, "alpha", 1.0f) ||
@@ -724,6 +723,16 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
           b = transpose(b);
         }

+        if (binder.getNumOperands() == 2) {
+          rewriter.replaceOpWithNewOp<Torch::AtenMmOp>(binder.op, resultType, a,
+                                                       b);
+          return success();
+        }
+
+        if (binder.tensorOperandAtIndex(c, 2))
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "Expected either 2 or 3 inputs");
+
         Value mm =
             rewriter.create<Torch::AtenMmOp>(binder.getLoc(), resultType, a, b);
         if (alpha == 1.0 && beta == 1.0) {
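
For context, ONNX `Gemm` computes `Y = alpha * A' * B' + beta * C`, where `A'`/`B'` are optionally transposed per `transA`/`transB`. The `alpha == 1.0 && beta == 1.0` check above guards the simple three-operand path, which, per the updated lit test at the bottom of this commit, emits a plain matmul followed by a bias add. Roughly (hypothetical value names, mirroring the CHECK lines of test_gemm_defaultB):

%int1 = torch.constant.int 1
%mm = torch.aten.mm %a, %b : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32>
%y = torch.aten.add.Tensor %mm, %c, %int1 : !torch.vtensor<[3,4],f32>, !torch.vtensor<[1,4],f32>, !torch.int -> !torch.vtensor<[3,4],f32>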


@@ -1920,14 +1920,6 @@ ONNX_XFAIL_SET = {
     "EinsumStaticFourDimensionModule_basic",
     "EinsumStaticModule_basic",
-    # Failure - onnx_lowering: onnx.Gemm
-    "AtenMmFloatTypes_basic",
-    "AtenMmIntTypes_basic",
-    "MmDagModule_basic",
-    "MmModule_basic",
-    "MmModule_chained",
-    "MmTanhModule_basic",
     # Failure - onnx_lowering: onnx.HardSwish
     "HardswishModule_basic",
     "HardswishRandomModule_basic",


@@ -97,8 +97,17 @@ func.func @test_gather_elements(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torc

 // -----

-// CHECK-LABEL: func.func @test_gemm_default
-func.func @test_gemm_default(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
+// CHECK-LABEL: func.func @test_gemm_defaultA
+func.func @test_gemm_defaultA(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
+  // CHECK-DAG: %[[MM:.+]] = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32>
+  %0 = torch.operator "onnx.Gemm"(%arg0, %arg1) : (!torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32>) -> !torch.vtensor<[3,4],f32>
+  return %0 : !torch.vtensor<[3,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_gemm_defaultB
+func.func @test_gemm_defaultB(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
   // CHECK: %[[I1:.+]] = torch.constant.int 1
   // CHECK: %[[MM:.+]] = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32>
   // CHECK: torch.aten.add.Tensor %[[MM]], %arg2, %[[I1]] : !torch.vtensor<[3,4],f32>, !torch.vtensor<[1,4],f32>, !torch.int -> !torch.vtensor<[3,4],f32>