diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index b263786c3..7c5f2c88c 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -191,8 +191,9 @@ public: Value lhs = adaptor.getSelf(); Value rhs = adaptor.getOther(); - if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) { return failure(); + } auto lhsType = lhs.getType().cast(); auto rhsType = rhs.getType().cast(); @@ -260,7 +261,26 @@ public: return success(); } - // Fourth Case: Batch-Matrix Multiplication. + // Fourth Case: Vec-Vec Multiplication. + if (lhsRank == 2 && rhsRank == 2) { + Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0); + Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1); + Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0); + Value rhsDim1 = getDimOp(rewriter, loc, rhs, 1); + checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0); + + Value zeroTensor = createZeroInitTensor( + rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType); + Value matmul = + rewriter + .create(loc, zeroTensor.getType(), + ValueRange{lhs, rhs}, zeroTensor) + .getResult(0); + rewriter.replaceOpWithNewOp(op, newResultType, matmul); + return success(); + } + + // Fifth Case: Batch-Matrix Multiplication. // TODO: Handle batch matrix multiplication when one of the matrix is unity // rank and the other has batch dimension. if (lhsRank > 1 && rhsRank > 1) { diff --git a/test/Conversion/TorchToLinalg/basic.mlir b/test/Conversion/TorchToLinalg/basic.mlir index 0aaca941b..486b8b641 100644 --- a/test/Conversion/TorchToLinalg/basic.mlir +++ b/test/Conversion/TorchToLinalg/basic.mlir @@ -29,6 +29,20 @@ func.func @torch.aten.mm$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.v // ----- +// CHECK-LABEL: func.func @torch.aten.matmul.2d +func.func @torch.aten.matmul.2d(%arg0: !torch.vtensor<[8,16],f32>, %arg1: !torch.vtensor<[16,8],f32>) -> !torch.vtensor<[8,8],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[LHS:.+]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[8,16],f32> -> tensor<8x16xf32> + // CHECK-DAG: %[[RHS:.+]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[16,8],f32> -> tensor<16x8xf32> + // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32 + // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<8x8xf32> + // CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[EMPTY]] : tensor<8x8xf32>) -> tensor<8x8xf32> + // CHECK: %[[MATMUL:.+]] = linalg.matmul ins(%[[LHS]], %[[RHS]] : tensor<8x16xf32>, tensor<16x8xf32>) outs(%[[FILL]] : tensor<8x8xf32>) -> tensor<8x8xf32> + %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[8,16],f32>, !torch.vtensor<[16,8],f32> -> !torch.vtensor<[8,8],f32> + return %0 : !torch.vtensor<[8,8],f32> +} + +// ----- + // CHECK-LABEL: func.func @torch.aten.mm$basic_strict( // CHECK-NOT: assert func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32>