From 08bc013fcd3232cbf01ad029f057b2fc022e56e1 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Wed, 28 Feb 2024 09:46:58 -0800 Subject: [PATCH] [tosa] Fix TOSA batch matmul lowering to correct transpose ordering (#2959) The corrective transpose at the end is computed incorrectly. It is actually computing the inverse transpose. Inverting the permutations fixes the issue. --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 30 ++++++++++++++-------- projects/pt1/e2e_testing/xfail_sets.py | 2 ++ 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index ce0a1af2f..93fe9dc1c 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -1234,7 +1234,7 @@ public: return false; }; - SmallVector commonElems, lhsSqueezedElems, rhsSqueezedElems; + SmallVector batchElems, lhsSqueezedElems, rhsSqueezedElems; if (!performBatchDimBroadcast) { // Simple with no broadcasting artifacts. Just reshape up to 3D @@ -1288,7 +1288,7 @@ public: if (isDynamicDim || lhsBroadcastedShape[dim] == rhsBroadcastedShape[dim]) { commonValue *= lhsBroadcastedShape[dim]; - commonElems.push_back({dim, lhsBroadcastedShape[dim]}); + batchElems.push_back({dim, lhsBroadcastedShape[dim]}); } } commonValue = commonValue < 0 ? kUnknownSize : commonValue; @@ -1315,9 +1315,9 @@ public: // Step: Create the tosa.transpose array. If this array has a // non-monotonic series of dims, perform transpose. 
// First the common_elems - for (uint32_t i = 0; i < commonElems.size(); i++) { - transposedLhsShape.push_back(commonElems[i].shape); - transposedLhsDims.push_back(commonElems[i].dim); + for (uint32_t i = 0; i < batchElems.size(); i++) { + transposedLhsShape.push_back(batchElems[i].shape); + transposedLhsDims.push_back(batchElems[i].dim); } // then the lhs_squeezed elems for (uint32_t i = 0; i < lhsSqueezedElems.size(); i++) { @@ -1373,9 +1373,9 @@ public: // Step: Create the RHS transpose sequence // RHS = {common, matmul_dim, rhs_squeezed} // first the common_dims - for (uint32_t i = 0; i < commonElems.size(); i++) { - transposedRhsShape.push_back(commonElems[i].shape); - transposedRhsDims.push_back(commonElems[i].dim); + for (uint32_t i = 0; i < batchElems.size(); i++) { + transposedRhsShape.push_back(batchElems[i].shape); + transposedRhsDims.push_back(batchElems[i].dim); } // The matmul_dim of RHS transposedRhsDims.push_back(maxInputRank - 2); @@ -1497,9 +1497,9 @@ public: // Step: Construct the output transpose/reshape information // First the common_dims - for (uint32_t i = 0; i < commonElems.size(); i++) { - reshapedOpShape.push_back(commonElems[i].shape); - transposedOpDims.push_back(commonElems[i].dim); + for (uint32_t i = 0; i < batchElems.size(); i++) { + reshapedOpShape.push_back(batchElems[i].shape); + transposedOpDims.push_back(batchElems[i].dim); } // Then the LHS squeezed dims @@ -1532,6 +1532,14 @@ public: transposedOpDims.push_back(maxInputRank - 1); } + // The transposition order is the inverse of what we actually want, + // inversing should fix this: + llvm::SmallVector inverseTransposeDims(transposedOpDims.size()); + for (int i = 0, s = transposedOpDims.size(); i < s; ++i) + inverseTransposeDims[transposedOpDims[i]] = i; + + transposedOpDims = inverseTransposeDims; + // Final transposed output shape construction for (uint32_t i = 0; i < maxInputRank - 2; i++) { if (lhsBroadcastedTy.isDynamicDim(i)) { diff --git 
a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index a4ac58b1d..67a4f175d 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1125,6 +1125,7 @@ TOSA_PASS_SET = { "Matmul4dStatic_basic", "Matmul_3d", "Matmul_dot", + "MatmulStaticBroadcast_basic", "MaxPool2dEmptyStrideStaticModule_basic", "MaxPool2dStaticCeilModeTrueModule_basic", "MaxPool2dStaticModule_basic", @@ -1303,6 +1304,7 @@ MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | { # Dynamic shape, has extra unsupported broadcast ops "Matmul_3d", + "MatmulStaticBroadcast_basic", # failed to legalize operation 'torch.aten.max_pool2d_with_indices "MaxPool2dEmptyStrideStaticModule_basic",