mirror of https://github.com/llvm/torch-mlir
[onnx] Add support for `onnx.Gemm` with no bias (#2993)
Previous gemm version required a bias vector. This provides an alternate path to `Torch::AtenMm` with no bias operation.pull/2995/head
parent
1964208d19
commit
551a4e45f3
|
@ -687,7 +687,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
int64_t transA, transB;
|
||||
if (binder.tensorOperandAtIndex(a, 0) ||
|
||||
binder.tensorOperandAtIndex(b, 1) ||
|
||||
binder.tensorOperandAtIndex(c, 2) ||
|
||||
binder.s64IntegerAttr(transA, "transA", 0) ||
|
||||
binder.s64IntegerAttr(transB, "transB", 0) ||
|
||||
binder.f32FloatAttr(alpha, "alpha", 1.0f) ||
|
||||
|
@ -724,6 +723,16 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
b = transpose(b);
|
||||
}
|
||||
|
||||
if (binder.getNumOperands() == 2) {
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenMmOp>(binder.op, resultType, a,
|
||||
b);
|
||||
return success();
|
||||
}
|
||||
|
||||
if (binder.tensorOperandAtIndex(c, 2))
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"Expected either 2 or 3 inputs");
|
||||
|
||||
Value mm =
|
||||
rewriter.create<Torch::AtenMmOp>(binder.getLoc(), resultType, a, b);
|
||||
if (alpha == 1.0 && beta == 1.0) {
|
||||
|
|
|
@ -1920,14 +1920,6 @@ ONNX_XFAIL_SET = {
|
|||
"EinsumStaticFourDimensionModule_basic",
|
||||
"EinsumStaticModule_basic",
|
||||
|
||||
# Failure - onnx_lowering: onnx.Gemm
|
||||
"AtenMmFloatTypes_basic",
|
||||
"AtenMmIntTypes_basic",
|
||||
"MmDagModule_basic",
|
||||
"MmModule_basic",
|
||||
"MmModule_chained",
|
||||
"MmTanhModule_basic",
|
||||
|
||||
# Failure - onnx_lowering: onnx.HardSwish
|
||||
"HardswishModule_basic",
|
||||
"HardswishRandomModule_basic",
|
||||
|
|
|
@ -97,8 +97,17 @@ func.func @test_gather_elements(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torc
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_gemm_default
|
||||
func.func @test_gemm_default(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
|
||||
// CHECK-LABEL: func.func @test_gemm_defaultA
|
||||
func.func @test_gemm_defaultA(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
|
||||
// CHECK-DAG: %[[MM:.+]] = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32>
|
||||
%0 = torch.operator "onnx.Gemm"(%arg0, %arg1) : (!torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32>) -> !torch.vtensor<[3,4],f32>
|
||||
return %0 : !torch.vtensor<[3,4],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_gemm_defaultB
|
||||
func.func @test_gemm_defaultB(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
|
||||
// CHECK: %[[I1:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[MM:.+]] = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32>
|
||||
// CHECK: torch.aten.add.Tensor %[[MM]], %arg2, %[[I1]] : !torch.vtensor<[3,4],f32>, !torch.vtensor<[1,4],f32>, !torch.int -> !torch.vtensor<[3,4],f32>
|
||||
|
|
Loading…
Reference in New Issue