Add static type information support to `aten.mm`

This commit adds static type information support to `aten.mm`. This is
needed for the forward pass of Bert training.
bert-staging
Ramiro Leal-Cavazos 2022-02-15 23:08:36 +00:00
parent 54357ea378
commit fa6cf0bed8
3 changed files with 53 additions and 27 deletions

View File

@@ -56,6 +56,7 @@ TOSA_PASS_SET = {
"ReturnTwoTensorF32I64_basic",
"ElementwisePowModule_basic",
"BmmModule_basic",
"MmDagModule_basic",
"Matmul_dot",
"Matmul_3d",
"RsubModule_basic",

View File

@@ -830,32 +830,19 @@ ChangeResult TypeAnalyzer::visitAtenMmOp(
auto &rhs = operands[1]->getValue();
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
// `aten.mm` expects both operands to be rank-2 tensors.
if (lhs.sizes.size() != 2 || rhs.sizes.size() != 2)
return getLatticeElement(op->getResult(0)).join(knowledge);
// If static information is available, check that both tensors are compatible.
if (lhs.sizes[1] != kUnknownSize && rhs.sizes[0] != kUnknownSize &&
lhs.sizes[1] != rhs.sizes[0])
return getLatticeElement(op->getResult(0)).join(knowledge);
knowledge.hasSizes = true;
// WARNING: We could be more precise here by calculating the output
// shape as "(lhs.shape[0], rhs.shape[1])". However, that is really tricky
// at this stage in the compiler because we don't really have many static
// guarantees about the input ranks because `aten` ops do dynamic error
// checking and safely abort the program. There is nothing preventing us
// from (correctly!) statically inferring the shapes of the operands to
// shapes that are guaranteed to cause an error at runtime.
//
// Example: Suppose a user program calls `aten.mm` with two rank-0
// operands. The program emits an error when invoked, but when running
// this pass, we will (correctly!) infer `lhs.hasSizes && lhs.sizes.size()
// == 0 && rhs.hasSizes && rhs.sizes.size() == 0` -- it's not safe to
// access `lhs.sizes[0]` / `rhs.sizes[1]`! So when writing this transfer
// function, it's not as simple as taking `lhs.sizes[0]` and
// `rhs.sizes[1]`, as both of those might read out of bounds of the array.
// It would require more complicated logic.
//
// Just knowing dtypes and ranks is sufficient at this stage
// in the compiler. The precise per-dimension size propagation is best
// done lower in the stack, such as at the linalg level, where we have
// more static guarantees and more structure.
knowledge.sizes.resize(2, kUnknownSize);
// TODO: Investigate promotion rules if element types mismatch.
// This is conservatively correct, assuming that if both element types are
// the same, then the result is of that same element type.
knowledge.sizes = {lhs.sizes[0], rhs.sizes[1]};
knowledge.dtype =
getPromotedResultTypeAssumingNonZeroRank(op->getContext(), {&lhs, &rhs});
return getLatticeElement(op->getResult(0)).join(knowledge);

View File

@@ -41,8 +41,8 @@ builtin.func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {
// CHECK-LABEL: func @f(
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[2,?],f32>,
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor {
// CHECK: %[[MM:.*]] = torch.aten.mm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[MM]] : !torch.vtensor<[?,?],f32> to !torch.vtensor
// CHECK: %[[MM:.*]] = torch.aten.mm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[2,?],f32>
// CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[MM]] : !torch.vtensor<[2,?],f32> to !torch.vtensor
// CHECK: return %[[SHAPE_ERASED]] : !torch.vtensor
builtin.func @f(%arg0: !torch.vtensor<[2,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor {
%1 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[2,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor
@@ -51,6 +51,44 @@ builtin.func @f(%arg0: !torch.vtensor<[2,?],f32>, %arg1: !torch.vtensor<[?,?],f3
// -----
// CHECK-LABEL: func @g(
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[2,3],f32>,
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor {
// CHECK: %[[MM:.*]] = torch.aten.mm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[2,4],f32>
// CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[MM]] : !torch.vtensor<[2,4],f32> to !torch.vtensor
// CHECK: return %[[SHAPE_ERASED]] : !torch.vtensor
builtin.func @g(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !torch.vtensor<[3,4],f32>) -> !torch.vtensor {
%1 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor
return %1 : !torch.vtensor
}
// -----
// CHECK-LABEL: func @h(
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[2,?],f32>,
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor {
// CHECK: %[[MM:.*]] = torch.aten.mm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,?],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[2,4],f32>
// CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[MM]] : !torch.vtensor<[2,4],f32> to !torch.vtensor
// CHECK: return %[[SHAPE_ERASED]] : !torch.vtensor
builtin.func @h(%arg0: !torch.vtensor<[2,?],f32>, %arg1: !torch.vtensor<[3,4],f32>) -> !torch.vtensor {
%1 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[2,?],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor
return %1 : !torch.vtensor
}
// -----
// CHECK-LABEL: func @i(
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[2,5],f32>,
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor {
// CHECK: %[[MM:.*]] = torch.aten.mm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,5],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor
// CHECK: return %[[MM]] : !torch.vtensor
builtin.func @i(%arg0: !torch.vtensor<[2,5],f32>, %arg1: !torch.vtensor<[3,4],f32>) -> !torch.vtensor {
%1 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[2,5],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor
return %1 : !torch.vtensor
}
// -----
// CHECK-LABEL: func @f(
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,3],f32>,
// CHECK-SAME: %[[WEIGHT:.*]]: !torch.vtensor<[5,3],f32>,