mirror of https://github.com/llvm/torch-mlir
Add static type information support to `aten.mm`

This commit adds static type information support to `aten.mm`. This is needed for the forward pass of BERT training.

parent 54357ea378
commit fa6cf0bed8

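For context, `aten.mm` is PyTorch's rank-2 matrix multiply: a `[m,k]` operand times a `[k,n]` operand yields an `[m,n]` result, and the two contraction dimensions (`k`) must agree. That is exactly the shape rule this commit teaches the type analysis. Below is a minimal standalone C++ illustration of the contraction itself; `mm` is a made-up helper for this sketch, not torch-mlir or PyTorch code.

#include <cassert>
#include <iostream>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Plain rank-2 matrix multiply: [m,k] x [k,n] -> [m,n].
// The inner (contraction) dimensions must agree, which is the same
// constraint the transfer function in this commit checks statically.
Matrix mm(const Matrix &lhs, const Matrix &rhs) {
  std::size_t m = lhs.size(), k = lhs[0].size(), n = rhs[0].size();
  assert(rhs.size() == k && "contraction dimensions must agree");
  Matrix out(m, std::vector<double>(n, 0.0));
  for (std::size_t i = 0; i < m; ++i)
    for (std::size_t j = 0; j < n; ++j)
      for (std::size_t p = 0; p < k; ++p)
        out[i][j] += lhs[i][p] * rhs[p][j];
  return out;
}

int main() {
  Matrix a(2, std::vector<double>(3, 1.0)); // shape [2,3]
  Matrix b(3, std::vector<double>(4, 1.0)); // shape [3,4]
  Matrix c = mm(a, b);                      // shape [2,4]
  std::cout << c.size() << "x" << c[0].size() << "\n"; // prints 2x4
}
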
@@ -56,6 +56,7 @@ TOSA_PASS_SET = {
    "ReturnTwoTensorF32I64_basic",
    "ElementwisePowModule_basic",
    "BmmModule_basic",
    "MmDagModule_basic",
    "Matmul_dot",
    "Matmul_3d",
    "RsubModule_basic",

@@ -830,32 +830,19 @@ ChangeResult TypeAnalyzer::visitAtenMmOp(
  auto &rhs = operands[1]->getValue();
  auto knowledge =
      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());

  // `aten.mm` expects both operands to be rank-2 tensors.
  if (lhs.sizes.size() != 2 || rhs.sizes.size() != 2)
    return getLatticeElement(op->getResult(0)).join(knowledge);

  // If static information is available, check that both tensors are compatible.
  if (lhs.sizes[1] != kUnknownSize && rhs.sizes[0] != kUnknownSize &&
      lhs.sizes[1] != rhs.sizes[0])
    return getLatticeElement(op->getResult(0)).join(knowledge);

  knowledge.hasSizes = true;
  // WARNING: We could be more precise here by calculating the output
  // shape as "(lhs.shape[0], rhs.shape[1])". However, that is really tricky
  // at this stage in the compiler because we don't really have many static
  // guarantees about the input ranks because `aten` ops do dynamic error
  // checking and safely abort the program. There is nothing preventing us
  // from (correctly!) statically inferring the shapes of the operands to
  // shapes that are guaranteed to cause an error at runtime.
  //
  // Example: Suppose a user program calls `aten.mm` with two rank-0
  // operands. The program emits an error when invoked, but when running
  // this pass, we will (correctly!) infer `lhs.hasSizes && lhs.sizes.size()
  // == 0 && rhs.hasSizes && rhs.sizes.size() == 0` -- it's not safe to
  // access `lhs.sizes[0]` / `rhs.sizes[1]`! So when writing this transfer
  // function, it's not as simple as taking `lhs.sizes[0]` and
  // `rhs.sizes[1]`, as both of those might read out of bounds of the array.
  // It would require more complicated logic.
  //
  // Just knowing dtypes and ranks is sufficient at this stage
  // in the compiler. The precise per-dimension size propagation is best
  // done lower in the stack, such as at the linalg level, where we have
  // more static guarantees and more structure.
  knowledge.sizes.resize(2, kUnknownSize);
  // TODO: Investigate promotion rules if element types mismatch.
  // This is conservatively correct, assuming that if both element types are
  // the same, then the result is of that same element type.
  knowledge.sizes = {lhs.sizes[0], rhs.sizes[1]};

  knowledge.dtype =
      getPromotedResultTypeAssumingNonZeroRank(op->getContext(), {&lhs, &rhs});
  return getLatticeElement(op->getResult(0)).join(knowledge);

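To make the new transfer-function logic concrete, here is a self-contained C++ sketch of the same shape rule. The `Knowledge` struct and `inferMmShape` helper are made up for this sketch (they are not torch-mlir's `ValueKnowledge` or its API), and it assumes unknown dimensions are modeled as `kUnknownSize = -1`, matching the convention suggested by the code above.

#include <cstdint>
#include <iostream>
#include <vector>

constexpr int64_t kUnknownSize = -1;

// Stand-in for the size-related part of torch-mlir's ValueKnowledge.
struct Knowledge {
  bool hasSizes = false;
  std::vector<int64_t> sizes;
};

// Infer the result shape of `aten.mm` from the operand shapes, refusing to
// refine anything when the ranks are wrong or the contraction dims conflict.
Knowledge inferMmShape(const std::vector<int64_t> &lhs,
                       const std::vector<int64_t> &rhs) {
  Knowledge knowledge; // pessimistic default: no size information
  // `aten.mm` expects both operands to be rank-2 tensors.
  if (lhs.size() != 2 || rhs.size() != 2)
    return knowledge;
  // If static information is available, check that both tensors are compatible.
  if (lhs[1] != kUnknownSize && rhs[0] != kUnknownSize && lhs[1] != rhs[0])
    return knowledge;
  knowledge.hasSizes = true;
  knowledge.sizes = {lhs[0], rhs[1]};
  return knowledge;
}

int main() {
  // [2,3] x [3,4] -> [2,4]: fully static operands (like test @g below).
  auto g = inferMmShape({2, 3}, {3, 4});
  std::cout << g.sizes[0] << "," << g.sizes[1] << "\n"; // 2,4
  // [2,?] x [3,4] -> [2,4]: an unknown contraction dim is assumed compatible
  // (like test @h below).
  auto h = inferMmShape({2, kUnknownSize}, {3, 4});
  std::cout << h.sizes[0] << "," << h.sizes[1] << "\n"; // 2,4
  // [2,5] x [3,4]: statically incompatible, so nothing is refined
  // (like test @i below).
  auto i = inferMmShape({2, 5}, {3, 4});
  std::cout << std::boolalpha << i.hasSizes << "\n"; // false
}

The three cases in `main` mirror the new lit tests added below (`@g`, `@h`, `@i`); the real analysis additionally joins this knowledge into the lattice element and computes a promoted dtype, which the sketch does not model.
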
@@ -41,8 +41,8 @@ builtin.func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {
// CHECK-LABEL: func @f(
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[2,?],f32>,
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor {
// CHECK: %[[MM:.*]] = torch.aten.mm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[MM]] : !torch.vtensor<[?,?],f32> to !torch.vtensor
// CHECK: %[[MM:.*]] = torch.aten.mm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[2,?],f32>
// CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[MM]] : !torch.vtensor<[2,?],f32> to !torch.vtensor
// CHECK: return %[[SHAPE_ERASED]] : !torch.vtensor
builtin.func @f(%arg0: !torch.vtensor<[2,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor {
  %1 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[2,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor

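(In terms of the illustrative `inferMmShape` sketch above, the updated `@f` CHECK lines are the `inferMmShape({2, kUnknownSize}, {kUnknownSize, kUnknownSize})` case: the known row dimension is kept and the column stays unknown, so the result type refines from `!torch.vtensor<[?,?],f32>` to `!torch.vtensor<[2,?],f32>`.)
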
@@ -51,6 +51,44 @@ builtin.func @f(%arg0: !torch.vtensor<[2,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor {

// -----

// CHECK-LABEL: func @g(
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[2,3],f32>,
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor {
// CHECK: %[[MM:.*]] = torch.aten.mm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[2,4],f32>
// CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[MM]] : !torch.vtensor<[2,4],f32> to !torch.vtensor
// CHECK: return %[[SHAPE_ERASED]] : !torch.vtensor
builtin.func @g(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !torch.vtensor<[3,4],f32>) -> !torch.vtensor {
  %1 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor
  return %1 : !torch.vtensor
}

// -----

// CHECK-LABEL: func @h(
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[2,?],f32>,
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor {
// CHECK: %[[MM:.*]] = torch.aten.mm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,?],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[2,4],f32>
// CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[MM]] : !torch.vtensor<[2,4],f32> to !torch.vtensor
// CHECK: return %[[SHAPE_ERASED]] : !torch.vtensor
builtin.func @h(%arg0: !torch.vtensor<[2,?],f32>, %arg1: !torch.vtensor<[3,4],f32>) -> !torch.vtensor {
  %1 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[2,?],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor
  return %1 : !torch.vtensor
}

// -----

// CHECK-LABEL: func @i(
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[2,5],f32>,
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor {
// CHECK: %[[MM:.*]] = torch.aten.mm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,5],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor
// CHECK: return %[[MM]] : !torch.vtensor
builtin.func @i(%arg0: !torch.vtensor<[2,5],f32>, %arg1: !torch.vtensor<[3,4],f32>) -> !torch.vtensor {
  %1 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[2,5],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor
  return %1 : !torch.vtensor
}

// -----

// CHECK-LABEL: func @f(
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,3],f32>,
// CHECK-SAME: %[[WEIGHT:.*]]: !torch.vtensor<[5,3],f32>,