mirror of https://github.com/llvm/torch-mlir
[onnx] Add support for `onnx.Gemm` with no bias (#2993)
The previous Gemm lowering required a bias vector. This provides an alternate path to `Torch::AtenMm` for the no-bias case.

pull/2995/head · parent 1964208d19 · commit 551a4e45f3
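For context (not part of this patch): ONNX `Gemm` computes `Y = alpha * A' * B' + beta * C`, where `A'`/`B'` are `A`/`B` optionally transposed and the bias `C` is an optional input. A minimal NumPy sketch of that contract follows; `gemm_reference` is a name of my own choosing, not torch-mlir code:

```python
import numpy as np

def gemm_reference(a, b, c=None, alpha=1.0, beta=1.0,
                   trans_a=False, trans_b=False):
    """Reference semantics of onnx.Gemm (illustrative only)."""
    if trans_a:
        a = a.T
    if trans_b:
        b = b.T
    y = alpha * (a @ b)
    if c is None:
        # The two-operand form this patch handles: a plain matmul,
        # which maps directly onto torch.aten.mm.
        return y
    return y + beta * c  # c broadcasts against y, as in the lit tests below

a = np.random.rand(3, 5).astype(np.float32)
b = np.random.rand(5, 4).astype(np.float32)
print(gemm_reference(a, b).shape)  # (3, 4)
```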
In `populateDefaultDomainGtoP`, the optional bias operand `c` is no longer bound in the eager operand/attribute check chain:

```diff
@@ -687,7 +687,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
       int64_t transA, transB;
       if (binder.tensorOperandAtIndex(a, 0) ||
           binder.tensorOperandAtIndex(b, 1) ||
-          binder.tensorOperandAtIndex(c, 2) ||
           binder.s64IntegerAttr(transA, "transA", 0) ||
           binder.s64IntegerAttr(transB, "transB", 0) ||
           binder.f32FloatAttr(alpha, "alpha", 1.0f) ||
```
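Why this hunk matters: the `if (... || ... || ...)` binder chain short-circuits and returns match failure at the first operand or attribute that fails to bind, so eagerly binding index 2 rejected every bias-free Gemm before any lowering could run. A toy model of that idiom in Python (hypothetical, not the torch-mlir `OpBinder` API):

```python
def chain_fails(*binders):
    # Mirrors the C++ idiom `if (bind(a) || bind(b) || ...) return failure();`
    # Each binder returns True on failure; any() stops at the first one.
    return any(bind() for bind in binders)
```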
After the transpose handling, the two-operand form now lowers straight to `Torch::AtenMm`; only the three-operand form binds `c`, and any other arity is rejected:

```diff
@@ -724,6 +723,16 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         b = transpose(b);
       }
 
+      if (binder.getNumOperands() == 2) {
+        rewriter.replaceOpWithNewOp<Torch::AtenMmOp>(binder.op, resultType, a,
+                                                     b);
+        return success();
+      }
+
+      if (binder.tensorOperandAtIndex(c, 2))
+        return rewriter.notifyMatchFailure(binder.op,
+                                           "Expected either 2 or 3 inputs");
+
       Value mm =
           rewriter.create<Torch::AtenMmOp>(binder.getLoc(), resultType, a, b);
       if (alpha == 1.0 && beta == 1.0) {
```
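Putting both hunks together, the match logic now dispatches on operand count. A compact Python model of the control flow (toy code with invented names, not the actual C++):

```python
def try_lower_gemm(operands):
    """Returns a tuple describing the lowered op, or None on match failure."""
    if len(operands) < 2:
        return None                           # a and b are required
    a, b = operands[0], operands[1]           # transpose handling omitted
    if len(operands) == 2:
        return ("torch.aten.mm", a, b)        # new bias-free fast path
    if len(operands) != 3:
        return None                           # "Expected either 2 or 3 inputs"
    c = operands[2]
    return ("torch.aten.add.Tensor", ("torch.aten.mm", a, b), c)
```

This mirrors the `alpha == 1.0 && beta == 1.0` path visible in the context lines; the scaled variants are outside this hunk.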
In the e2e test configuration, the `onnx.Gemm` entries drop out of `ONNX_XFAIL_SET` now that they lower:

```diff
@@ -1920,14 +1920,6 @@ ONNX_XFAIL_SET = {
     "EinsumStaticFourDimensionModule_basic",
     "EinsumStaticModule_basic",
 
-    # Failure - onnx_lowering: onnx.Gemm
-    "AtenMmFloatTypes_basic",
-    "AtenMmIntTypes_basic",
-    "MmDagModule_basic",
-    "MmModule_basic",
-    "MmModule_chained",
-    "MmTanhModule_basic",
-
     # Failure - onnx_lowering: onnx.HardSwish
     "HardswishModule_basic",
     "HardswishRandomModule_basic",
```
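With the new path in place, these `AtenMm*` and `Mm*` e2e tests exercise `onnx.Gemm` successfully and therefore leave the expected-failure set. Generically, such a set just gates which tests are expected to pass; a sketch of that gating (not the actual torch-mlir harness):

```python
ONNX_XFAIL_SET = {"HardswishModule_basic", "HardswishRandomModule_basic"}

def expected_to_pass(test_name):
    return test_name not in ONNX_XFAIL_SET

print(expected_to_pass("MmModule_basic"))  # True after this patch
```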
The lit test is split: `@test_gemm_defaultA` covers the new two-operand form, and the original three-operand case is renamed `@test_gemm_defaultB`:

```diff
@@ -97,8 +97,17 @@ func.func @test_gather_elements(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torc
 
 // -----
 
-// CHECK-LABEL: func.func @test_gemm_default
-func.func @test_gemm_default(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
+// CHECK-LABEL: func.func @test_gemm_defaultA
+func.func @test_gemm_defaultA(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
+  // CHECK-DAG: %[[MM:.+]] = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32>
+  %0 = torch.operator "onnx.Gemm"(%arg0, %arg1) : (!torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32>) -> !torch.vtensor<[3,4],f32>
+  return %0 : !torch.vtensor<[3,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_gemm_defaultB
+func.func @test_gemm_defaultB(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
   // CHECK: %[[I1:.+]] = torch.constant.int 1
   // CHECK: %[[MM:.+]] = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32>
   // CHECK: torch.aten.add.Tensor %[[MM]], %arg2, %[[I1]] : !torch.vtensor<[3,4],f32>, !torch.vtensor<[1,4],f32>, !torch.int -> !torch.vtensor<[3,4],f32>
```
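For completeness, a bias-free Gemm like the one in `@test_gemm_defaultA` is easy to produce with the `onnx` Python package; a minimal sketch (assuming `onnx` is installed):

```python
import onnx
from onnx import helper, TensorProto

# A Gemm node with only two inputs (no bias C), mirroring @test_gemm_defaultA.
node = helper.make_node("Gemm", inputs=["a", "b"], outputs=["y"])
graph = helper.make_graph(
    [node], "gemm_no_bias",
    [helper.make_tensor_value_info("a", TensorProto.FLOAT, [3, 5]),
     helper.make_tensor_value_info("b", TensorProto.FLOAT, [5, 4])],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 4])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
onnx.checker.check_model(model)  # confirms the two-input form is well-formed
```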