Add static type information support to `aten.mm` (#602)

This commit adds static type information support to `aten.mm`, which is needed
for the forward pass of BERT training.
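
For context, `torch.mm` only accepts rank-2 operands and contracts the inner
dimension: an (m, k) matrix times a (k, n) matrix yields an (m, n) result, and
a mismatched inner dimension fails at runtime. A minimal PyTorch illustration
of that shape rule (for reference only, not part of this change):

import torch

lhs = torch.randn(2, 3)
rhs = torch.randn(3, 4)
print(torch.mm(lhs, rhs).shape)  # torch.Size([2, 4]): (2, 3) x (3, 4) -> (2, 4)

# Incompatible inner dimensions only fail at runtime:
# torch.mm(torch.randn(2, 5), torch.randn(3, 4))  # raises RuntimeError
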
Ramiro Leal-Cavazos 2022-02-18 09:56:48 -08:00 committed by GitHub
parent abbde7d439
commit 2823277f7c
3 changed files with 57 additions and 27 deletions

@@ -55,6 +55,7 @@ TOSA_PASS_SET = {
     "ReturnTwoTensorF32I64_basic",
     "ElementwisePowModule_basic",
     "BmmModule_basic",
+    "MmDagModule_basic",
     "Matmul_dot",
     "Matmul_3d",
     "RsubModule_basic",

@@ -823,32 +823,23 @@ ChangeResult TypeAnalyzer::visitAtenMmOp(
   auto &rhs = operands[1]->getValue();
   auto knowledge =
       ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+  auto isRank2 = [](const ValueKnowledge &operand) -> bool {
+    return operand.hasSizes && operand.sizes.size() == 2;
+  };
+  // `aten.mm` expects both operands to be rank-2 tensors.
+  if (!isRank2(lhs) || !isRank2(rhs))
+    return getLatticeElement(op->getResult(0)).join(knowledge);
+  // If static information is available, check that both tensors are compatible.
+  if (lhs.sizes[1] != kUnknownSize && rhs.sizes[0] != kUnknownSize &&
+      lhs.sizes[1] != rhs.sizes[0])
+    return getLatticeElement(op->getResult(0)).join(knowledge);
   knowledge.hasSizes = true;
-  // WARNING: We could be more precise here by calculating the output
-  // shape as "(lhs.shape[0], rhs.shape[1])". However, that is really tricky
-  // at this stage in the compiler because we don't really have many static
-  // guarantees about the input ranks because `aten` ops do dynamic error
-  // checking and safely abort the program. There is nothing preventing us
-  // from (correctly!) statically inferring the shapes of the operands to
-  // shapes that are guaranteed to cause an error at runtime.
-  //
-  // Example: Suppose a user program calls `aten.mm` with two rank-0
-  // operands. The program emits an error when invoked, but when running
-  // this pass, we will (correctly!) infer `lhs.hasSizes && lhs.sizes.size()
-  // == 0 && rhs.hasSizes && rhs.sizes.size() == 0` -- it's not safe to
-  // access `lhs.sizes[0]` / `rhs.sizes[1]`! So when writing this transfer
-  // function, it's not as simple as taking `lhs.sizes[0]` and
-  // `rhs.sizes[1]`, as both of those might read out of bounds of the array.
-  // It would require more complicated logic.
-  //
-  // Just knowing dtypes and ranks is sufficient at this stage
-  // in the compiler. The precise per-dimension size propagation is best
-  // done lower in the stack, such as at the linalg level, where we have
-  // more static guarantees and more structure.
-  knowledge.sizes.resize(2, kUnknownSize);
-  // TODO: Investigate promotion rules if element types mismatch.
-  // This is conservatively correct, assuming that if both element types are
-  // the same, then the result is of that same element type.
+  knowledge.sizes = {lhs.sizes[0], rhs.sizes[1]};
   knowledge.dtype =
       getPromotedResultTypeAssumingNonZeroRank(op->getContext(), {&lhs, &rhs});
   return getLatticeElement(op->getResult(0)).join(knowledge);
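
The refined transfer function above can be paraphrased as: require both
operands to be rank-2, bail out if statically known contracting dimensions
disagree, and otherwise report the result shape as (lhs.sizes[0],
rhs.sizes[1]) with the promoted dtype. A rough standalone Python sketch of the
shape part of that logic (illustrative only, not the torch-mlir API; `UNKNOWN`
plays the role of `kUnknownSize`):

UNKNOWN = -1  # plays the role of kUnknownSize: dimension not statically known

def refine_mm_result_shape(lhs_sizes, rhs_sizes):
    """Sketch of the shape refinement: returns the refined result shape,
    or None when nothing can be said statically."""
    # `aten.mm` expects both operands to be rank-2 tensors.
    if lhs_sizes is None or len(lhs_sizes) != 2:
        return None
    if rhs_sizes is None or len(rhs_sizes) != 2:
        return None
    # If both contracting dimensions are statically known, they must agree.
    if UNKNOWN not in (lhs_sizes[1], rhs_sizes[0]) and lhs_sizes[1] != rhs_sizes[0]:
        return None
    return [lhs_sizes[0], rhs_sizes[1]]

# These mirror the new MLIR test cases added below:
print(refine_mm_result_shape([2, 3], [3, 4]))        # [2, 4]
print(refine_mm_result_shape([2, UNKNOWN], [3, 4]))  # [2, 4]
print(refine_mm_result_shape([2, 5], [3, 4]))        # None: statically incompatible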

@@ -41,8 +41,8 @@ builtin.func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {
 // CHECK-LABEL: func @f(
 // CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[2,?],f32>,
 // CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor {
-// CHECK: %[[MM:.*]] = torch.aten.mm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
-// CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[MM]] : !torch.vtensor<[?,?],f32> to !torch.vtensor
+// CHECK: %[[MM:.*]] = torch.aten.mm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[2,?],f32>
+// CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[MM]] : !torch.vtensor<[2,?],f32> to !torch.vtensor
 // CHECK: return %[[SHAPE_ERASED]] : !torch.vtensor
 builtin.func @f(%arg0: !torch.vtensor<[2,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor {
   %1 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[2,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor
@@ -51,6 +51,44 @@ builtin.func @f(%arg0: !torch.vtensor<[2,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor {
 
 // -----
 
+// CHECK-LABEL: func @g(
+// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[2,3],f32>,
+// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor {
+// CHECK: %[[MM:.*]] = torch.aten.mm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[2,4],f32>
+// CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[MM]] : !torch.vtensor<[2,4],f32> to !torch.vtensor
+// CHECK: return %[[SHAPE_ERASED]] : !torch.vtensor
+builtin.func @g(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !torch.vtensor<[3,4],f32>) -> !torch.vtensor {
+  %1 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor
+  return %1 : !torch.vtensor
+}
+
+// -----
+
+// CHECK-LABEL: func @h(
+// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[2,?],f32>,
+// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor {
+// CHECK: %[[MM:.*]] = torch.aten.mm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,?],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[2,4],f32>
+// CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[MM]] : !torch.vtensor<[2,4],f32> to !torch.vtensor
+// CHECK: return %[[SHAPE_ERASED]] : !torch.vtensor
+builtin.func @h(%arg0: !torch.vtensor<[2,?],f32>, %arg1: !torch.vtensor<[3,4],f32>) -> !torch.vtensor {
+  %1 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[2,?],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor
+  return %1 : !torch.vtensor
+}
+
+// -----
+
+// CHECK-LABEL: func @i(
+// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[2,5],f32>,
+// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor {
+// CHECK: %[[MM:.*]] = torch.aten.mm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,5],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor
+// CHECK: return %[[MM]] : !torch.vtensor
+builtin.func @i(%arg0: !torch.vtensor<[2,5],f32>, %arg1: !torch.vtensor<[3,4],f32>) -> !torch.vtensor {
+  %1 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[2,5],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor
+  return %1 : !torch.vtensor
+}
+
+// -----
+
 // CHECK-LABEL: func @f(
 // CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,3],f32>,
 // CHECK-SAME: %[[WEIGHT:.*]]: !torch.vtensor<[5,3],f32>,