mirror of https://github.com/llvm/torch-mlir
[onnx] Fix `torch` lowering for determinant (#3639)
The determinant lowering had some extract / insert shape mismatches. Replumbed the shape manipulations to correctly implement the determinant operation.
parent 43e3118eb9
commit f09cb766dc
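For context, the lowering computes the determinant by LU-style elimination: zero out each column below its pivot, then multiply the diagonal entries of the resulting upper-triangular matrix, per batch. A minimal standalone C++ sketch of that algorithm, using plain loops rather than the MLIR ops the pass actually emits (all names here are illustrative):

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Gaussian elimination without pivot swaps: reduce an n x n row-major
    // matrix to upper-triangular form; det = product of the diagonal.
    // Assumes nonzero pivots, as the lowering itself does here.
    double det(std::vector<double> a, std::size_t n) {
      assert(a.size() == n * n);
      for (std::size_t row = 0; row < n; ++row) {
        double pivot = a[row * n + row];
        for (std::size_t i = row + 1; i < n; ++i) {
          // Subtract a multiple of the pivot row from each subpivot row.
          double factor = a[i * n + row] / pivot;
          for (std::size_t j = row; j < n; ++j)
            a[i * n + j] -= factor * a[row * n + j];
        }
      }
      // Reduce-prod over the diagonal, matching the final linalg.generic.
      double prod = 1.0;
      for (std::size_t k = 0; k < n; ++k)
        prod *= a[k * n + k];
      return prod;
    }

The hunks below fix the shape bookkeeping around exactly these two steps: extracting and inserting the per-row diagonal entries, and the final product reduction.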
@@ -3032,11 +3032,28 @@ public:
     Value cstZeroF = getConstant(rewriter, loc, 0, elemTy);
     // get some shapes
     SmallVector<int64_t> inputShape(inputType.getShape());
+
     SmallVector<int64_t> sliceShape(inputShape);
-    sliceShape.pop_back();
-    SmallVector<int64_t> diagShape({isBatched ? inputType.getShape()[0] : 1});
+    sliceShape[sliceShape.size() - 2] = 1;
+
+    SmallVector<int64_t> diagShape(inputType.getShape());
+    diagShape[diagShape.size() - 2] = 1;
+    diagShape[diagShape.size() - 1] = 1;
+
+    ArrayRef<int64_t> diagCollapseShape(diagShape);
+    diagCollapseShape = diagCollapseShape.drop_back();
+
     auto sliceTy = RankedTensorType::get(sliceShape, elemTy);
     auto diagTy = RankedTensorType::get(diagShape, elemTy);
+    auto diagCollapseTy = RankedTensorType::get(diagCollapseShape, elemTy);
+
+    SmallVector<ReassociationIndices> diagReassociations;
+    diagReassociations.reserve(diagCollapseShape.size());
+    int64_t diagRank = diagCollapseShape.size();
+    for (int i = 0, s = diagRank - 1; i < s; ++i)
+      diagReassociations.push_back(ReassociationIndices{i});
+    diagReassociations.push_back(ReassociationIndices{diagRank - 1, diagRank});
+
     // get some sizes
     SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
     Value chDim = isBatched ? inputSizes[0] : cstOne;
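The replumbed shapes stay rank-consistent with the input: for a batched [C, N, N] input, the row slice is [C, 1, N] (rather than the popped-back [C, N]), the extracted diagonal entry is [C, 1, 1], and its collapsed form is [C, 1], with reassociation {{0}, {1, 2}} folding the two trailing unit dims. An illustrative std::vector mirror of this bookkeeping (the concrete shape values are assumed examples, not taken from the patch):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      std::vector<int64_t> inputShape = {4, 3, 3}; // assumed batched C x N x N

      // Row slice keeps the rank, with a unit second-to-last dim: [C, 1, N].
      std::vector<int64_t> sliceShape(inputShape);
      sliceShape[sliceShape.size() - 2] = 1;

      // One diagonal entry per batch: [C, 1, 1].
      std::vector<int64_t> diagShape(inputShape);
      diagShape[diagShape.size() - 2] = 1;
      diagShape[diagShape.size() - 1] = 1;

      // Drop the trailing unit dim: [C, 1], the shape actually inserted
      // into the [C, N] diagonal accumulator.
      std::vector<int64_t> diagCollapseShape(diagShape.begin(),
                                             diagShape.end() - 1);

      // Reassociation as built in the patch: {0}, ..., {rank-1, rank}.
      std::vector<std::vector<int64_t>> reassoc;
      int64_t diagRank = diagCollapseShape.size();
      for (int64_t i = 0; i + 1 < diagRank; ++i)
        reassoc.push_back({i});
      reassoc.push_back({diagRank - 1, diagRank});

      std::cout << "slice rank " << sliceShape.size() << ", diag rank "
                << diagShape.size() << ", collapsed rank "
                << diagCollapseShape.size() << "\n"; // 3, 3, 2
    }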
@@ -3072,6 +3089,10 @@ public:
           // offsets = [0, row, row], sizes = [C, 1, 1] -> diag(row,row)
           Value diag = b.create<tensor::ExtractSliceOp>(
               loc, diagTy, vals[0], offsets, sizes, strides);
+
+          Value diagCollapse = b.create<tensor::CollapseShapeOp>(
+              loc, diagCollapseTy, diag, diagReassociations);
+
           SmallVector<OpFoldResult> diagOffsets(inputRank - 1, cstZeroFold);
           diagOffsets.back() = row;
           SmallVector<OpFoldResult> diagStrides(inputRank - 1, cstOneFold);
@@ -3079,7 +3100,7 @@ public:
           diagSizes.back() = cstOneFold;
           // offsets = [0, row], sizes = [C, 1] insert to [C,N]
           Value updatedDiags = b.create<tensor::InsertSliceOp>(
-              loc, diag, vals[1], diagOffsets, diagSizes, diagStrides);
+              loc, diagCollapse, vals[1], diagOffsets, diagSizes, diagStrides);
           // the subpivot matrix column size, as a Value, is matDim - row -
           // cstOne. This can't be statically converted to an int64_t, since row
           // is the loop index, so this is left as a dynamic dim.
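This one-line swap is the heart of the fix: vals[1] accumulates the diagonal as a [C, N] tensor and the insert uses rank-2 offsets/sizes/strides, so the inserted source must be the rank-2 [C, 1] diagCollapse, not the rank-3 [C, 1, 1] diag. A simplified sketch of the footprint rule the old code violated (hypothetical helper; the real tensor::InsertSliceOp verifier is more permissive about rank-reducing forms, but a higher-rank source can never fit):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Simplified: an inserted source must match the `sizes` footprint
    // carved out of the destination.
    bool insertShapesMatch(const std::vector<int64_t> &srcShape,
                           const std::vector<int64_t> &sizes) {
      return srcShape == sizes;
    }

    int main() {
      std::vector<int64_t> sizes = {4, 1};        // [C, 1] slot in a [C, N] dest
      std::vector<int64_t> diag = {4, 1, 1};      // rank-3: the old mismatch
      std::vector<int64_t> diagCollapse = {4, 1}; // rank-2: the fix
      std::cout << insertShapesMatch(diag, sizes) << "\n";         // 0
      std::cout << insertShapesMatch(diagCollapse, sizes) << "\n"; // 1
    }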
@@ -3117,11 +3138,16 @@ public:
     if (isBatched) {
       rowIterator.push_back(allDims[1]);
       colIterator.push_back(allDims[0]);
+      colIterator.push_back(rewriter.getAffineConstantExpr(0));
       colIterator.push_back(allDims[2]);
       batchIterator.push_back(allDims[0]);
+      batchIterator.push_back(getAffineConstantExpr(0, context));
+      batchIterator.push_back(getAffineConstantExpr(0, context));
     } else {
+      colIterator.push_back(rewriter.getAffineConstantExpr(0));
       colIterator.push_back(allDims[1]);
       batchIterator.push_back(getAffineConstantExpr(0, context));
+      batchIterator.push_back(getAffineConstantExpr(0, context));
     }
     SmallVector<AffineMap> indexingMaps;
     indexingMaps.push_back(
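The constant-0 exprs pin the new unit dims: in the batched case the pivot diagonal [C, 1, 1] is read at (d0, 0, 0), broadcasting one pivot per batch across the whole subpivot update, and the pivot row [C, 1, N] is read at (d0, 0, d2). A plain-C++ sketch of what such an indexing map means (illustrative struct, not the MLIR AffineMap API):

    #include <array>
    #include <cstdint>
    #include <iostream>

    // A linalg-style indexing map entry: either a loop index d<dim> or a
    // constant. Constant 0 broadcasts the operand along that loop.
    struct IndexExpr {
      bool isDim;
      int dim;     // valid when isDim
      int64_t cst; // valid when !isDim
    };

    int64_t apply(const IndexExpr &e, const std::array<int64_t, 3> &loops) {
      return e.isDim ? loops[e.dim] : e.cst;
    }

    int main() {
      // batchIterator, batched case: (d0, d1, d2) -> (d0, 0, 0)
      std::array<IndexExpr, 3> batchMap = {
          {{true, 0, 0}, {false, 0, 0}, {false, 0, 0}}};
      std::array<int64_t, 3> loops = {2, 5, 7}; // batch 2, row 5, col 7
      for (const auto &e : batchMap)
        std::cout << apply(e, loops) << ' '; // prints: 2 0 0
      std::cout << '\n';
    }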
@@ -3183,6 +3209,10 @@ public:
     offsets.pop_back();
     strides.pop_back();
     sizes.pop_back();
+
+    lastDiag = rewriter.create<tensor::CollapseShapeOp>(
+        loc, diagCollapseTy, lastDiag, diagReassociations);
+
     Value allDiags = rewriter.create<tensor::InsertSliceOp>(
         loc, lastDiag, allDiagsExceptLast, offsets, sizes, strides);
     // linalg generic to do reduce prod for allDiags along back dim.
@@ -3193,7 +3223,8 @@ public:
                            : getAffineConstantExpr(0, context);
     indexingMaps.push_back(AffineMap::get(inputRank - 1, 0, resultExpr));
     SmallVector<utils::IteratorType> iteratorTypes(
-        inputRank - 1, utils::IteratorType::parallel);
+        inputRank - 2, utils::IteratorType::parallel);
+    iteratorTypes.push_back(utils::IteratorType::reduction);
     Value initDet = createInitTensor(rewriter, loc, ValueRange{chDim}, elemTy,
                                      getConstant(rewriter, loc, 1.0, elemTy));
     Value determinant =
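The product reduction previously declared all inputRank - 1 loops parallel; it now declares inputRank - 2 parallel dims plus one reduction dim, so the [C, N] diagonal tensor is actually reduced along N into a [C] accumulator initialized to 1.0. A loop-nest sketch of that iteration space (example data assumed):

    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main() {
      // allDiags: C x N diagonal entries, row-major (C = 2, N = 3 here).
      std::size_t C = 2, N = 3;
      std::vector<double> allDiags = {2, 3, 4, 1, 5, 7};

      // One parallel loop (batch) and one reduction loop (diagonal),
      // matching iteratorTypes = [parallel..., reduction].
      std::vector<double> det(C, 1.0); // initDet filled with 1.0
      for (std::size_t c = 0; c < C; ++c)   // parallel
        for (std::size_t n = 0; n < N; ++n) // reduction: product
          det[c] *= allDiags[c * N + n];

      for (double d : det)
        std::cout << d << ' '; // prints: 24 35
      std::cout << '\n';
    }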
@@ -3213,10 +3244,11 @@ public:
                          determinant);
       return success();
     }
-    Value detVal = rewriter.create<tensor::ExtractOp>(
-        loc, determinant, SmallVector<Value>(1, cstZero));
-    rewriter.replaceOpWithNewOp<tensor::FromElementsOp>(op, newResultType,
-                                                        ValueRange{detVal});
+
+    determinant = rewriter.create<tensor::CollapseShapeOp>(
+        loc, newResultType, determinant,
+        llvm::ArrayRef<ReassociationIndices>{});
+    rewriter.replaceOp(op, ValueRange{determinant});
     return success();
   }
};
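In the non-batched case the determinant now comes back as a [1] tensor, and a single collapse_shape with an empty reassociation turns it into the rank-0 result type directly, instead of extracting the element and rebuilding a tensor around it. A small sketch of the collapse-shape rule being relied on (illustrative function, not the MLIR verifier):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Each reassociation group folds its source dims into one result dim.
    // An empty reassociation is only valid when every source dim is 1,
    // yielding a rank-0 shape.
    std::vector<int64_t>
    collapse(const std::vector<int64_t> &src,
             const std::vector<std::vector<int64_t>> &groups) {
      std::vector<int64_t> out;
      for (const auto &g : groups) {
        int64_t extent = 1;
        for (int64_t d : g)
          extent *= src[d];
        out.push_back(extent);
      }
      return out;
    }

    int main() {
      // Non-batched determinant: [1] -> [] (rank 0), as in this hunk.
      assert(collapse({1}, {}).empty());
      // The diagonal collapse from earlier hunks: [C, 1, 1] -> [C, 1].
      assert((collapse({4, 1, 1}, {{0}, {1, 2}}) ==
              std::vector<int64_t>{4, 1}));
    }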