[aten] Make `torch.aten.matmul` to `linalg` work for non-broadcasting case (#2659)

Broadcasting for `torch.aten.matmul` is optional so a MxN with NxK
matmul should be legalized to a `linalg.matmul`.
pull/2411/merge
Rob Suderman 2023-12-20 10:09:10 -08:00 committed by GitHub
parent 8fa81d181b
commit a24aadbfab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 36 additions and 2 deletions

View File

@ -191,8 +191,9 @@ public:
Value lhs = adaptor.getSelf(); Value lhs = adaptor.getSelf();
Value rhs = adaptor.getOther(); Value rhs = adaptor.getOther();
if (failed(verifyLinalgCompatibleTypes(op, rewriter))) if (failed(verifyLinalgCompatibleTypes(op, rewriter))) {
return failure(); return failure();
}
auto lhsType = lhs.getType().cast<RankedTensorType>(); auto lhsType = lhs.getType().cast<RankedTensorType>();
auto rhsType = rhs.getType().cast<RankedTensorType>(); auto rhsType = rhs.getType().cast<RankedTensorType>();
@ -260,7 +261,26 @@ public:
return success(); return success();
} }
// Fourth Case: Batch-Matrix Multiplication. // Fourth Case: Vec-Vec Multiplication.
if (lhsRank == 2 && rhsRank == 2) {
Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1);
Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
Value rhsDim1 = getDimOp(rewriter, loc, rhs, 1);
checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);
Value zeroTensor = createZeroInitTensor(
rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType);
Value matmul =
rewriter
.create<linalg::MatmulOp>(loc, zeroTensor.getType(),
ValueRange{lhs, rhs}, zeroTensor)
.getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
return success();
}
// Fifth Case: Batch-Matrix Multiplication.
// TODO: Handle batch matrix multiplication when one of the matrix is unity // TODO: Handle batch matrix multiplication when one of the matrix is unity
// rank and the other has batch dimension. // rank and the other has batch dimension.
if (lhsRank > 1 && rhsRank > 1) { if (lhsRank > 1 && rhsRank > 1) {

View File

@ -29,6 +29,20 @@ func.func @torch.aten.mm$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.v
// ----- // -----
// CHECK-LABEL: func.func @torch.aten.matmul.2d
func.func @torch.aten.matmul.2d(%arg0: !torch.vtensor<[8,16],f32>, %arg1: !torch.vtensor<[16,8],f32>) -> !torch.vtensor<[8,8],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[LHS:.+]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[8,16],f32> -> tensor<8x16xf32>
// CHECK-DAG: %[[RHS:.+]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[16,8],f32> -> tensor<16x8xf32>
// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<8x8xf32>
// CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[EMPTY]] : tensor<8x8xf32>) -> tensor<8x8xf32>
// CHECK: %[[MATMUL:.+]] = linalg.matmul ins(%[[LHS]], %[[RHS]] : tensor<8x16xf32>, tensor<16x8xf32>) outs(%[[FILL]] : tensor<8x8xf32>) -> tensor<8x8xf32>
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[8,16],f32>, !torch.vtensor<[16,8],f32> -> !torch.vtensor<[8,8],f32>
return %0 : !torch.vtensor<[8,8],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.mm$basic_strict( // CHECK-LABEL: func.func @torch.aten.mm$basic_strict(
// CHECK-NOT: assert // CHECK-NOT: assert
func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32> func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32>