mirror of https://github.com/llvm/torch-mlir
[aten] Make `torch.aten.matmul` to `linalg` work for non-broadcasting case (#2659)
Broadcasting for `torch.aten.matmul` is optional so a MxN with NxK matmul should be legalized to a `linalg.matmul`.pull/2411/merge
parent
8fa81d181b
commit
a24aadbfab
|
@ -191,8 +191,9 @@ public:
|
||||||
Value lhs = adaptor.getSelf();
|
Value lhs = adaptor.getSelf();
|
||||||
Value rhs = adaptor.getOther();
|
Value rhs = adaptor.getOther();
|
||||||
|
|
||||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter))) {
|
||||||
return failure();
|
return failure();
|
||||||
|
}
|
||||||
auto lhsType = lhs.getType().cast<RankedTensorType>();
|
auto lhsType = lhs.getType().cast<RankedTensorType>();
|
||||||
auto rhsType = rhs.getType().cast<RankedTensorType>();
|
auto rhsType = rhs.getType().cast<RankedTensorType>();
|
||||||
|
|
||||||
|
@ -260,7 +261,26 @@ public:
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fourth Case: Batch-Matrix Multiplication.
|
// Fourth Case: Vec-Vec Multiplication.
|
||||||
|
if (lhsRank == 2 && rhsRank == 2) {
|
||||||
|
Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
|
||||||
|
Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1);
|
||||||
|
Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
|
||||||
|
Value rhsDim1 = getDimOp(rewriter, loc, rhs, 1);
|
||||||
|
checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);
|
||||||
|
|
||||||
|
Value zeroTensor = createZeroInitTensor(
|
||||||
|
rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType);
|
||||||
|
Value matmul =
|
||||||
|
rewriter
|
||||||
|
.create<linalg::MatmulOp>(loc, zeroTensor.getType(),
|
||||||
|
ValueRange{lhs, rhs}, zeroTensor)
|
||||||
|
.getResult(0);
|
||||||
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fifth Case: Batch-Matrix Multiplication.
|
||||||
// TODO: Handle batch matrix multiplication when one of the matrix is unity
|
// TODO: Handle batch matrix multiplication when one of the matrix is unity
|
||||||
// rank and the other has batch dimension.
|
// rank and the other has batch dimension.
|
||||||
if (lhsRank > 1 && rhsRank > 1) {
|
if (lhsRank > 1 && rhsRank > 1) {
|
||||||
|
|
|
@ -29,6 +29,20 @@ func.func @torch.aten.mm$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.v
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.matmul.2d
|
||||||
|
func.func @torch.aten.matmul.2d(%arg0: !torch.vtensor<[8,16],f32>, %arg1: !torch.vtensor<[16,8],f32>) -> !torch.vtensor<[8,8],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK-DAG: %[[LHS:.+]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[8,16],f32> -> tensor<8x16xf32>
|
||||||
|
// CHECK-DAG: %[[RHS:.+]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[16,8],f32> -> tensor<16x8xf32>
|
||||||
|
// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
|
||||||
|
// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<8x8xf32>
|
||||||
|
// CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[EMPTY]] : tensor<8x8xf32>) -> tensor<8x8xf32>
|
||||||
|
// CHECK: %[[MATMUL:.+]] = linalg.matmul ins(%[[LHS]], %[[RHS]] : tensor<8x16xf32>, tensor<16x8xf32>) outs(%[[FILL]] : tensor<8x8xf32>) -> tensor<8x8xf32>
|
||||||
|
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[8,16],f32>, !torch.vtensor<[16,8],f32> -> !torch.vtensor<[8,8],f32>
|
||||||
|
return %0 : !torch.vtensor<[8,8],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.mm$basic_strict(
|
// CHECK-LABEL: func.func @torch.aten.mm$basic_strict(
|
||||||
// CHECK-NOT: assert
|
// CHECK-NOT: assert
|
||||||
func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32>
|
func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32>
|
||||||
|
|
Loading…
Reference in New Issue