mirror of https://github.com/llvm/torch-mlir
[aten] Make `torch.aten.matmul` to `linalg` work for non-broadcasting case (#2659)
Broadcasting for `torch.aten.matmul` is optional, so an MxN × NxK matmul should be legalized directly to a `linalg.matmul`.
parent 8fa81d181b
commit a24aadbfab
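In IR terms, the change takes a rank-2 by rank-2 `torch.aten.matmul` straight to a `linalg.matmul` on a zero-filled init tensor. A minimal before/after sketch, distilled from the test added below (`%lhs` and `%rhs` stand in for the `torch_c.to_builtin_tensor` results; they are placeholder names, not values from the diff):

    // Before: Torch dialect.
    %0 = torch.aten.matmul %arg0, %arg1
        : !torch.vtensor<[8,16],f32>, !torch.vtensor<[16,8],f32> -> !torch.vtensor<[8,8],f32>

    // After convert-torch-to-linalg, roughly:
    %zero = arith.constant 0.000000e+00 : f32
    %empty = tensor.empty() : tensor<8x8xf32>
    %fill = linalg.fill ins(%zero : f32) outs(%empty : tensor<8x8xf32>) -> tensor<8x8xf32>
    %mm = linalg.matmul ins(%lhs, %rhs : tensor<8x16xf32>, tensor<16x8xf32>)
                        outs(%fill : tensor<8x8xf32>) -> tensor<8x8xf32>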
@@ -191,8 +191,9 @@ public:
     Value lhs = adaptor.getSelf();
     Value rhs = adaptor.getOther();
 
-    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter))) {
       return failure();
+    }
     auto lhsType = lhs.getType().cast<RankedTensorType>();
     auto rhsType = rhs.getType().cast<RankedTensorType>();
 
@@ -260,7 +261,26 @@ public:
       return success();
     }
 
-    // Fourth Case: Batch-Matrix Multiplication.
+    // Fourth Case: Mat-Mat Multiplication.
+    if (lhsRank == 2 && rhsRank == 2) {
+      Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
+      Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1);
+      Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
+      Value rhsDim1 = getDimOp(rewriter, loc, rhs, 1);
+      checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);
+
+      Value zeroTensor = createZeroInitTensor(
+          rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType);
+      Value matmul =
+          rewriter
+              .create<linalg::MatmulOp>(loc, zeroTensor.getType(),
+                                        ValueRange{lhs, rhs}, zeroTensor)
+              .getResult(0);
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
+      return success();
+    }
+
+    // Fifth Case: Batch-Matrix Multiplication.
     // TODO: Handle batch matrix multiplication when one of the matrix is unity
     // rank and the other has batch dimension.
     if (lhsRank > 1 && rhsRank > 1) {
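The new case follows the recipe of the neighboring matmul cases: fetch the runtime dims with `getDimOp`, guard the contracting dimension with `checkDimEqualHelper`, accumulate into a zero-initialized tensor from `createZeroInitTensor`, and `tensor.cast` the result back to `newResultType` (the computed tensor carries dynamic dims, while the declared result type may be static). For dynamically shaped operands the guard reduces to roughly the IR below; this is a sketch assuming `checkDimEqualHelper` emits an `arith.cmpi`/`cf.assert` pair, as it does for the other cases:

    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    // Contracting dims: columns of the LHS must match rows of the RHS.
    %k_lhs = tensor.dim %lhs, %c1 : tensor<?x?xf32>
    %k_rhs = tensor.dim %rhs, %c0 : tensor<?x?xf32>
    %eq = arith.cmpi eq, %k_lhs, %k_rhs : index
    cf.assert %eq, "mismatching contracting dimension"

With fully static shapes, as in the test added below, the comparison folds away and no assert survives; the existing `mm$basic_strict` test checks that same property with `CHECK-NOT: assert`.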
@@ -29,6 +29,20 @@ func.func @torch.aten.mm$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.v
 
 // -----
 
+// CHECK-LABEL: func.func @torch.aten.matmul.2d
+func.func @torch.aten.matmul.2d(%arg0: !torch.vtensor<[8,16],f32>, %arg1: !torch.vtensor<[16,8],f32>) -> !torch.vtensor<[8,8],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK-DAG: %[[LHS:.+]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[8,16],f32> -> tensor<8x16xf32>
+  // CHECK-DAG: %[[RHS:.+]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[16,8],f32> -> tensor<16x8xf32>
+  // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<8x8xf32>
+  // CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[EMPTY]] : tensor<8x8xf32>) -> tensor<8x8xf32>
+  // CHECK: %[[MATMUL:.+]] = linalg.matmul ins(%[[LHS]], %[[RHS]] : tensor<8x16xf32>, tensor<16x8xf32>) outs(%[[FILL]] : tensor<8x8xf32>) -> tensor<8x8xf32>
+  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[8,16],f32>, !torch.vtensor<[16,8],f32> -> !torch.vtensor<[8,8],f32>
+  return %0 : !torch.vtensor<[8,8],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @torch.aten.mm$basic_strict(
 // CHECK-NOT: assert
 func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32>
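The new test mirrors the existing static-shape cases in this file. For context, tests in this directory are driven by a RUN header at the top of the file (outside this hunk); the line below is the usual form for the TorchToLinalg tests, reproduced here as an assumption rather than as part of the diff:

    // RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s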