[tosa] Fix TOSA batch matmul lowering to correct transpose ordering (#2959)

The corrective transpose at the end is computed incorrectly: it is actually the inverse of the intended permutation. Inverting the permutation fixes the issue.
Rob Suderman 2024-02-28 09:46:58 -08:00 committed by GitHub
parent 4a7a7d76f8
commit 08bc013fcd
2 changed files with 21 additions and 11 deletions
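
For context, the whole fix is the permutation inversion added to the output transpose below. The following is a minimal standalone sketch of that inversion with a made-up permutation for illustration; the interpretation of transposedOpDims here is an assumption based on the commit message, not taken from the lowering itself:

    #include <cstdio>
    #include <vector>

    int main() {
      // Hypothetical permutation standing in for transposedOpDims: assume
      // entry i records where dimension i ended up.
      std::vector<int> perm = {2, 0, 1};

      // Invert it the same way the patch does: if dimension i ended up at
      // position perm[i], then position perm[i] is fed from dimension i.
      std::vector<int> inverse(perm.size());
      for (int i = 0, s = static_cast<int>(perm.size()); i < s; ++i)
        inverse[perm[i]] = i;

      for (int d : inverse)
        std::printf("%d ", d); // prints "1 2 0", the inverse of {2, 0, 1}
      std::printf("\n");
      return 0;
    }

A permutation that is its own inverse (the identity, or a single swap) is unaffected by the mix-up, which would explain why the plain Matmul tests already passed and only the broadcasted case (MatmulStaticBroadcast_basic, enabled below) exposed the bug.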


@@ -1234,7 +1234,7 @@ public:
return false;
};
-SmallVector<TensorShape_t> commonElems, lhsSqueezedElems, rhsSqueezedElems;
+SmallVector<TensorShape_t> batchElems, lhsSqueezedElems, rhsSqueezedElems;
if (!performBatchDimBroadcast) {
// Simple with no broadcasting artifacts. Just reshape up to 3D
@@ -1288,7 +1288,7 @@ public:
if (isDynamicDim ||
lhsBroadcastedShape[dim] == rhsBroadcastedShape[dim]) {
commonValue *= lhsBroadcastedShape[dim];
-commonElems.push_back({dim, lhsBroadcastedShape[dim]});
+batchElems.push_back({dim, lhsBroadcastedShape[dim]});
}
}
commonValue = commonValue < 0 ? kUnknownSize : commonValue;
@@ -1315,9 +1315,9 @@ public:
// Step: Create the tosa.transpose array. If this array has a
// non-monotonic series of dims, perform transpose.
// First the common_elems
-for (uint32_t i = 0; i < commonElems.size(); i++) {
-  transposedLhsShape.push_back(commonElems[i].shape);
-  transposedLhsDims.push_back(commonElems[i].dim);
+for (uint32_t i = 0; i < batchElems.size(); i++) {
+  transposedLhsShape.push_back(batchElems[i].shape);
+  transposedLhsDims.push_back(batchElems[i].dim);
}
// then the lhs_squeezed elems
for (uint32_t i = 0; i < lhsSqueezedElems.size(); i++) {
@@ -1373,9 +1373,9 @@ public:
// Step: Create the RHS transpose sequence
// RHS = {common, matmul_dim, rhs_squeezed}
// first the common_dims
-for (uint32_t i = 0; i < commonElems.size(); i++) {
-  transposedRhsShape.push_back(commonElems[i].shape);
-  transposedRhsDims.push_back(commonElems[i].dim);
+for (uint32_t i = 0; i < batchElems.size(); i++) {
+  transposedRhsShape.push_back(batchElems[i].shape);
+  transposedRhsDims.push_back(batchElems[i].dim);
}
// The matmul_dim of RHS
transposedRhsDims.push_back(maxInputRank - 2);
@@ -1497,9 +1497,9 @@ public:
// Step: Construct the output transpose/reshape information
// First the common_dims
-for (uint32_t i = 0; i < commonElems.size(); i++) {
-  reshapedOpShape.push_back(commonElems[i].shape);
-  transposedOpDims.push_back(commonElems[i].dim);
+for (uint32_t i = 0; i < batchElems.size(); i++) {
+  reshapedOpShape.push_back(batchElems[i].shape);
+  transposedOpDims.push_back(batchElems[i].dim);
}
// Then the LHS squeezed dims
@@ -1532,6 +1532,14 @@ public:
transposedOpDims.push_back(maxInputRank - 1);
}
+// The transposition order is the inverse of what we actually want;
+// inverting it fixes this:
+llvm::SmallVector<int> inverseTransposeDims(transposedOpDims.size());
+for (int i = 0, s = transposedOpDims.size(); i < s; ++i)
+  inverseTransposeDims[transposedOpDims[i]] = i;
+transposedOpDims = inverseTransposeDims;
// Final transposed output shape construction
for (uint32_t i = 0; i < maxInputRank - 2; i++) {
if (lhsBroadcastedTy.isDynamicDim(i)) {


@@ -1125,6 +1125,7 @@ TOSA_PASS_SET = {
"Matmul4dStatic_basic",
"Matmul_3d",
"Matmul_dot",
"MatmulStaticBroadcast_basic",
"MaxPool2dEmptyStrideStaticModule_basic",
"MaxPool2dStaticCeilModeTrueModule_basic",
"MaxPool2dStaticModule_basic",
@@ -1303,6 +1304,7 @@ MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
# Dynamic shape, has extra unsupported broadcast ops
"Matmul_3d",
"MatmulStaticBroadcast_basic",
# failed to legalize operation 'torch.aten.max_pool2d_with_indices'
"MaxPool2dEmptyStrideStaticModule_basic",