[onnx] Fix `torch` lowering for determinant (#3639)

The determinant lowering had some extract / insert shape mismatches.
Replumbed shape manipulations to correctly implement the determinant
operation.
pull/3641/head
Rob Suderman 2024-08-15 15:41:50 -07:00 committed by GitHub
parent 43e3118eb9
commit f09cb766dc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 40 additions and 8 deletions

View File

@ -3032,11 +3032,28 @@ public:
Value cstZeroF = getConstant(rewriter, loc, 0, elemTy);
// get some shapes
SmallVector<int64_t> inputShape(inputType.getShape());
SmallVector<int64_t> sliceShape(inputShape);
sliceShape.pop_back();
SmallVector<int64_t> diagShape({isBatched ? inputType.getShape()[0] : 1});
sliceShape[sliceShape.size() - 2] = 1;
SmallVector<int64_t> diagShape(inputType.getShape());
diagShape[diagShape.size() - 2] = 1;
diagShape[diagShape.size() - 1] = 1;
ArrayRef<int64_t> diagCollapseShape(diagShape);
diagCollapseShape = diagCollapseShape.drop_back();
auto sliceTy = RankedTensorType::get(sliceShape, elemTy);
auto diagTy = RankedTensorType::get(diagShape, elemTy);
auto diagCollapseTy = RankedTensorType::get(diagCollapseShape, elemTy);
SmallVector<ReassociationIndices> diagReassociations;
diagReassociations.reserve(diagCollapseShape.size());
int64_t diagRank = diagCollapseShape.size();
for (int i = 0, s = diagRank - 1; i < s; ++i)
diagReassociations.push_back(ReassociationIndices{i});
diagReassociations.push_back(ReassociationIndices{diagRank - 1, diagRank});
// get some sizes
SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
Value chDim = isBatched ? inputSizes[0] : cstOne;
@ -3072,6 +3089,10 @@ public:
// offsets = [0, row, row], sizes = [C, 1, 1] -> diag(row,row)
Value diag = b.create<tensor::ExtractSliceOp>(
loc, diagTy, vals[0], offsets, sizes, strides);
Value diagCollapse = b.create<tensor::CollapseShapeOp>(
loc, diagCollapseTy, diag, diagReassociations);
SmallVector<OpFoldResult> diagOffsets(inputRank - 1, cstZeroFold);
diagOffsets.back() = row;
SmallVector<OpFoldResult> diagStrides(inputRank - 1, cstOneFold);
@ -3079,7 +3100,7 @@ public:
diagSizes.back() = cstOneFold;
// offsets = [0, row], sizes = [C, 1] insert to [C,N]
Value updatedDiags = b.create<tensor::InsertSliceOp>(
loc, diag, vals[1], diagOffsets, diagSizes, diagStrides);
loc, diagCollapse, vals[1], diagOffsets, diagSizes, diagStrides);
// the subpivot matrix column size, as a Value, is matDim - row -
// cstOne. This can't be statically converted to an int64_t, since row
// is the loop index, so this is left as a dynamic dim.
@ -3117,11 +3138,16 @@ public:
if (isBatched) {
rowIterator.push_back(allDims[1]);
colIterator.push_back(allDims[0]);
colIterator.push_back(rewriter.getAffineConstantExpr(0));
colIterator.push_back(allDims[2]);
batchIterator.push_back(allDims[0]);
batchIterator.push_back(getAffineConstantExpr(0, context));
batchIterator.push_back(getAffineConstantExpr(0, context));
} else {
colIterator.push_back(rewriter.getAffineConstantExpr(0));
colIterator.push_back(allDims[1]);
batchIterator.push_back(getAffineConstantExpr(0, context));
batchIterator.push_back(getAffineConstantExpr(0, context));
}
SmallVector<AffineMap> indexingMaps;
indexingMaps.push_back(
@ -3183,6 +3209,10 @@ public:
offsets.pop_back();
strides.pop_back();
sizes.pop_back();
lastDiag = rewriter.create<tensor::CollapseShapeOp>(
loc, diagCollapseTy, lastDiag, diagReassociations);
Value allDiags = rewriter.create<tensor::InsertSliceOp>(
loc, lastDiag, allDiagsExceptLast, offsets, sizes, strides);
// linalg generic to do reduce prod for allDiags along back dim.
@ -3193,7 +3223,8 @@ public:
: getAffineConstantExpr(0, context);
indexingMaps.push_back(AffineMap::get(inputRank - 1, 0, resultExpr));
SmallVector<utils::IteratorType> iteratorTypes(
inputRank - 1, utils::IteratorType::parallel);
inputRank - 2, utils::IteratorType::parallel);
iteratorTypes.push_back(utils::IteratorType::reduction);
Value initDet = createInitTensor(rewriter, loc, ValueRange{chDim}, elemTy,
getConstant(rewriter, loc, 1.0, elemTy));
Value determinant =
@ -3213,10 +3244,11 @@ public:
determinant);
return success();
}
Value detVal = rewriter.create<tensor::ExtractOp>(
loc, determinant, SmallVector<Value>(1, cstZero));
rewriter.replaceOpWithNewOp<tensor::FromElementsOp>(op, newResultType,
ValueRange{detVal});
determinant = rewriter.create<tensor::CollapseShapeOp>(
loc, newResultType, determinant,
llvm::ArrayRef<ReassociationIndices>{});
rewriter.replaceOp(op, ValueRange{determinant});
return success();
}
};