From f09cb766dc40fca6f72e8535d3a9014ba065919f Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Thu, 15 Aug 2024 15:41:50 -0700 Subject: [PATCH] [onnx] Fix `torch` lowering for determinant (#3639) The determinant lowering had some extract / insert shape mismatches. Replumbed shape manipulations to correctly implement the determinant operation. --- .../TorchToLinalg/Uncategorized.cpp | 48 +++++++++++++++---- 1 file changed, 40 insertions(+), 8 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 2936e72a2..7823138c9 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -3032,11 +3032,28 @@ public: Value cstZeroF = getConstant(rewriter, loc, 0, elemTy); // get some shapes SmallVector inputShape(inputType.getShape()); + SmallVector sliceShape(inputShape); - sliceShape.pop_back(); - SmallVector diagShape({isBatched ? inputType.getShape()[0] : 1}); + sliceShape[sliceShape.size() - 2] = 1; + + SmallVector diagShape(inputType.getShape()); + diagShape[diagShape.size() - 2] = 1; + diagShape[diagShape.size() - 1] = 1; + + ArrayRef diagCollapseShape(diagShape); + diagCollapseShape = diagCollapseShape.drop_back(); + auto sliceTy = RankedTensorType::get(sliceShape, elemTy); auto diagTy = RankedTensorType::get(diagShape, elemTy); + auto diagCollapseTy = RankedTensorType::get(diagCollapseShape, elemTy); + + SmallVector diagReassociations; + diagReassociations.reserve(diagCollapseShape.size()); + int64_t diagRank = diagCollapseShape.size(); + for (int i = 0, s = diagRank - 1; i < s; ++i) + diagReassociations.push_back(ReassociationIndices{i}); + diagReassociations.push_back(ReassociationIndices{diagRank - 1, diagRank}); + // get some sizes SmallVector inputSizes = getTensorSizes(rewriter, loc, input); Value chDim = isBatched ? inputSizes[0] : cstOne; @@ -3072,6 +3089,10 @@ public: // offsets = [0, row, row], sizes = [C, 1, 1] -> diag(row,row) Value diag = b.create( loc, diagTy, vals[0], offsets, sizes, strides); + + Value diagCollapse = b.create( + loc, diagCollapseTy, diag, diagReassociations); + SmallVector diagOffsets(inputRank - 1, cstZeroFold); diagOffsets.back() = row; SmallVector diagStrides(inputRank - 1, cstOneFold); @@ -3079,7 +3100,7 @@ public: diagSizes.back() = cstOneFold; // offsets = [0, row], sizes = [C, 1] insert to [C,N] Value updatedDiags = b.create( - loc, diag, vals[1], diagOffsets, diagSizes, diagStrides); + loc, diagCollapse, vals[1], diagOffsets, diagSizes, diagStrides); // the subpivot matrix column size, as a Value, is matDim - row - // cstOne. This can't be statically converted to an int64_t, since row // is the loop index, so this is left as a dynamic dim. @@ -3117,11 +3138,16 @@ public: if (isBatched) { rowIterator.push_back(allDims[1]); colIterator.push_back(allDims[0]); + colIterator.push_back(rewriter.getAffineConstantExpr(0)); colIterator.push_back(allDims[2]); batchIterator.push_back(allDims[0]); + batchIterator.push_back(getAffineConstantExpr(0, context)); + batchIterator.push_back(getAffineConstantExpr(0, context)); } else { + colIterator.push_back(rewriter.getAffineConstantExpr(0)); colIterator.push_back(allDims[1]); batchIterator.push_back(getAffineConstantExpr(0, context)); + batchIterator.push_back(getAffineConstantExpr(0, context)); } SmallVector indexingMaps; indexingMaps.push_back( @@ -3183,6 +3209,10 @@ public: offsets.pop_back(); strides.pop_back(); sizes.pop_back(); + + lastDiag = rewriter.create( + loc, diagCollapseTy, lastDiag, diagReassociations); + Value allDiags = rewriter.create( loc, lastDiag, allDiagsExceptLast, offsets, sizes, strides); // linalg generic to do reduce prod for allDiags along back dim. @@ -3193,7 +3223,8 @@ public: : getAffineConstantExpr(0, context); indexingMaps.push_back(AffineMap::get(inputRank - 1, 0, resultExpr)); SmallVector iteratorTypes( - inputRank - 1, utils::IteratorType::parallel); + inputRank - 2, utils::IteratorType::parallel); + iteratorTypes.push_back(utils::IteratorType::reduction); Value initDet = createInitTensor(rewriter, loc, ValueRange{chDim}, elemTy, getConstant(rewriter, loc, 1.0, elemTy)); Value determinant = @@ -3213,10 +3244,11 @@ public: determinant); return success(); } - Value detVal = rewriter.create( - loc, determinant, SmallVector(1, cstZero)); - rewriter.replaceOpWithNewOp(op, newResultType, - ValueRange{detVal}); + + determinant = rewriter.create( + loc, newResultType, determinant, + llvm::ArrayRef{}); + rewriter.replaceOp(op, ValueRange{determinant}); return success(); } };