diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py
index 99528fcb2..e44ba8300 100644
--- a/e2e_testing/torchscript/xfail_sets.py
+++ b/e2e_testing/torchscript/xfail_sets.py
@@ -55,6 +55,7 @@ TOSA_PASS_SET = {
     "ReturnTwoTensorF32I64_basic",
     "ElementwisePowModule_basic",
     "BmmModule_basic",
+    "MmDagModule_basic",
     "Matmul_dot",
     "Matmul_3d",
     "RsubModule_basic",
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 091112b2f..7ebdc35cf 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -823,32 +823,23 @@ ChangeResult TypeAnalyzer::visitAtenMmOp(
   auto &rhs = operands[1]->getValue();
   auto knowledge =
       ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+
+  auto isRank2 = [](const ValueKnowledge &operand) -> bool {
+    return operand.hasSizes && operand.sizes.size() == 2;
+  };
+
+  // `aten.mm` expects both operands to be rank-2 tensors.
+  if (!isRank2(lhs) || !isRank2(rhs))
+    return getLatticeElement(op->getResult(0)).join(knowledge);
+
+  // If static information is available, check that both tensors are compatible.
+  if (lhs.sizes[1] != kUnknownSize && rhs.sizes[0] != kUnknownSize &&
+      lhs.sizes[1] != rhs.sizes[0])
+    return getLatticeElement(op->getResult(0)).join(knowledge);
+
   knowledge.hasSizes = true;
-  // WARNING: We could be more precise here by calculating the output
-  // shape as "(lhs.shape[0], rhs.shape[1])". However, that is really tricky
-  // at this stage in the compiler because we don't really have many static
-  // guarantees about the input ranks because `aten` ops do dynamic error
-  // checking and safely abort the program. There is nothing preventing us
-  // from (correctly!) statically inferring the shapes of the operands to
-  // shapes that are guaranteed to cause an error at runtime.
-  //
-  // Example: Suppose a user program calls `aten.mm` with two rank-0
-  // operands. The program emits an error when invoked, but when running
-  // this pass, we will (correctly!) infer `lhs.hasSizes && lhs.sizes.size()
-  // == 0 && rhs.hasSizes && rhs.sizes.size() == 0` -- it's not safe to
-  // access `lhs.sizes[0]` / `rhs.sizes[1]`! So when writing this transfer
-  // function, it's not as simple as taking `lhs.sizes[0]` and
-  // `rhs.sizes[1]`, as both of those might read out of bounds of the array.
-  // It would require more complicated logic.
-  //
-  // Just knowing dtypes and ranks is sufficient at this stage
-  // in the compiler. The precise per-dimension size propagation is best
-  // done lower in the stack, such as at the linalg level, where we have
-  // more static guarantees and more structure.
-  knowledge.sizes.resize(2, kUnknownSize);
-  // TODO: Investigate promotion rules if element types mismatch.
-  // This is conservatively correct, assuming that if both element types are
-  // the same, then the result is of that same element type.
+  knowledge.sizes = {lhs.sizes[0], rhs.sizes[1]};
+  knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(op->getContext(), {&lhs, &rhs});
   return getLatticeElement(op->getResult(0)).join(knowledge);
diff --git a/test/Dialect/Torch/refine-types.mlir b/test/Dialect/Torch/refine-types.mlir
index fedccd54a..84df11b2e 100644
--- a/test/Dialect/Torch/refine-types.mlir
+++ b/test/Dialect/Torch/refine-types.mlir
@@ -41,8 +41,8 @@ builtin.func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {
 // CHECK-LABEL: func @f(
 // CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[2,?],f32>,
 // CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor {
-// CHECK: %[[MM:.*]] = torch.aten.mm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
-// CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[MM]] : !torch.vtensor<[?,?],f32> to !torch.vtensor
+// CHECK: %[[MM:.*]] = torch.aten.mm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[2,?],f32>
+// CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[MM]] : !torch.vtensor<[2,?],f32> to !torch.vtensor
 // CHECK: return %[[SHAPE_ERASED]] : !torch.vtensor
 builtin.func @f(%arg0: !torch.vtensor<[2,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor {
   %1 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[2,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor
@@ -51,6 +51,44 @@ builtin.func @f(%arg0: !torch.vtensor<[2,?],f32>, %arg1: !torch.vtensor<[?,?],f3
 
 // -----
 
+// CHECK-LABEL: func @g(
+// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[2,3],f32>,
+// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor {
+// CHECK: %[[MM:.*]] = torch.aten.mm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[2,4],f32>
+// CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[MM]] : !torch.vtensor<[2,4],f32> to !torch.vtensor
+// CHECK: return %[[SHAPE_ERASED]] : !torch.vtensor
+builtin.func @g(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !torch.vtensor<[3,4],f32>) -> !torch.vtensor {
+  %1 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor
+  return %1 : !torch.vtensor
+}
+
+// -----
+
+// CHECK-LABEL: func @h(
+// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[2,?],f32>,
+// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor {
+// CHECK: %[[MM:.*]] = torch.aten.mm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,?],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[2,4],f32>
+// CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[MM]] : !torch.vtensor<[2,4],f32> to !torch.vtensor
+// CHECK: return %[[SHAPE_ERASED]] : !torch.vtensor
+builtin.func @h(%arg0: !torch.vtensor<[2,?],f32>, %arg1: !torch.vtensor<[3,4],f32>) -> !torch.vtensor {
+  %1 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[2,?],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor
+  return %1 : !torch.vtensor
+}
+
+// -----
+
+// CHECK-LABEL: func @i(
+// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[2,5],f32>,
+// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor {
+// CHECK: %[[MM:.*]] = torch.aten.mm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,5],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor
+// CHECK: return %[[MM]] : !torch.vtensor
+builtin.func @i(%arg0: !torch.vtensor<[2,5],f32>, %arg1: !torch.vtensor<[3,4],f32>) -> !torch.vtensor {
+  %1 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[2,5],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor
+  return %1 : !torch.vtensor
+}
+
+// -----
+
 // CHECK-LABEL: func @f(
 // CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,3],f32>,
 // CHECK-SAME: %[[WEIGHT:.*]]: !torch.vtensor<[5,3],f32>,