mirror of https://github.com/llvm/torch-mlir
[tosa] Fix TOSA batch matmul lowering to correct transpose ordering (#2959)
The corrective transpose at the end is computed incorrectly: it is actually computing the inverse transpose. Inverting the permutation fixes the issue. pull/2958/head
parent
4a7a7d76f8
commit
08bc013fcd
|
@ -1234,7 +1234,7 @@ public:
|
|||
return false;
|
||||
};
|
||||
|
||||
SmallVector<TensorShape_t> commonElems, lhsSqueezedElems, rhsSqueezedElems;
|
||||
SmallVector<TensorShape_t> batchElems, lhsSqueezedElems, rhsSqueezedElems;
|
||||
|
||||
if (!performBatchDimBroadcast) {
|
||||
// Simple with no broadcasting artifacts. Just reshape up to 3D
|
||||
|
@ -1288,7 +1288,7 @@ public:
|
|||
if (isDynamicDim ||
|
||||
lhsBroadcastedShape[dim] == rhsBroadcastedShape[dim]) {
|
||||
commonValue *= lhsBroadcastedShape[dim];
|
||||
commonElems.push_back({dim, lhsBroadcastedShape[dim]});
|
||||
batchElems.push_back({dim, lhsBroadcastedShape[dim]});
|
||||
}
|
||||
}
|
||||
commonValue = commonValue < 0 ? kUnknownSize : commonValue;
|
||||
|
@ -1315,9 +1315,9 @@ public:
|
|||
// Step: Create the tosa.transpose array. If this array has a
|
||||
// non-monotonic series of dims, perform transpose.
|
||||
// First the common_elems
|
||||
for (uint32_t i = 0; i < commonElems.size(); i++) {
|
||||
transposedLhsShape.push_back(commonElems[i].shape);
|
||||
transposedLhsDims.push_back(commonElems[i].dim);
|
||||
for (uint32_t i = 0; i < batchElems.size(); i++) {
|
||||
transposedLhsShape.push_back(batchElems[i].shape);
|
||||
transposedLhsDims.push_back(batchElems[i].dim);
|
||||
}
|
||||
// then the lhs_squeezed elems
|
||||
for (uint32_t i = 0; i < lhsSqueezedElems.size(); i++) {
|
||||
|
@ -1373,9 +1373,9 @@ public:
|
|||
// Step: Create the RHS transpose sequence
|
||||
// RHS = {common, matmul_dim, rhs_squeezed}
|
||||
// first the common_dims
|
||||
for (uint32_t i = 0; i < commonElems.size(); i++) {
|
||||
transposedRhsShape.push_back(commonElems[i].shape);
|
||||
transposedRhsDims.push_back(commonElems[i].dim);
|
||||
for (uint32_t i = 0; i < batchElems.size(); i++) {
|
||||
transposedRhsShape.push_back(batchElems[i].shape);
|
||||
transposedRhsDims.push_back(batchElems[i].dim);
|
||||
}
|
||||
// The matmul_dim of RHS
|
||||
transposedRhsDims.push_back(maxInputRank - 2);
|
||||
|
@ -1497,9 +1497,9 @@ public:
|
|||
|
||||
// Step: Construct the output transpose/reshape information
|
||||
// First the common_dims
|
||||
for (uint32_t i = 0; i < commonElems.size(); i++) {
|
||||
reshapedOpShape.push_back(commonElems[i].shape);
|
||||
transposedOpDims.push_back(commonElems[i].dim);
|
||||
for (uint32_t i = 0; i < batchElems.size(); i++) {
|
||||
reshapedOpShape.push_back(batchElems[i].shape);
|
||||
transposedOpDims.push_back(batchElems[i].dim);
|
||||
}
|
||||
|
||||
// Then the LHS squeezed dims
|
||||
|
@ -1532,6 +1532,14 @@ public:
|
|||
transposedOpDims.push_back(maxInputRank - 1);
|
||||
}
|
||||
|
||||
// The transposition order is the inverse of what we actually want,
|
||||
// inverting it should fix this:
|
||||
llvm::SmallVector<int> inverseTransposeDims(transposedOpDims.size());
|
||||
for (int i = 0, s = transposedOpDims.size(); i < s; ++i)
|
||||
inverseTransposeDims[transposedOpDims[i]] = i;
|
||||
|
||||
transposedOpDims = inverseTransposeDims;
|
||||
|
||||
// Final transposed output shape construction
|
||||
for (uint32_t i = 0; i < maxInputRank - 2; i++) {
|
||||
if (lhsBroadcastedTy.isDynamicDim(i)) {
|
||||
|
|
|
@ -1125,6 +1125,7 @@ TOSA_PASS_SET = {
|
|||
"Matmul4dStatic_basic",
|
||||
"Matmul_3d",
|
||||
"Matmul_dot",
|
||||
"MatmulStaticBroadcast_basic",
|
||||
"MaxPool2dEmptyStrideStaticModule_basic",
|
||||
"MaxPool2dStaticCeilModeTrueModule_basic",
|
||||
"MaxPool2dStaticModule_basic",
|
||||
|
@ -1303,6 +1304,7 @@ MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
|
|||
|
||||
# Dynamic shape, has extra unsupported broadcast ops
|
||||
"Matmul_3d",
|
||||
"MatmulStaticBroadcast_basic",
|
||||
|
||||
# failed to legalize operation 'torch.aten.max_pool2d_with_indices
|
||||
"MaxPool2dEmptyStrideStaticModule_basic",
|
||||
|
|
Loading…
Reference in New Issue