mirror of https://github.com/llvm/torch-mlir
[stablehlo] fix: enhance torch's index-like op lowering to stablehlo's gather/scatter (#3829)
In torch.index_put like ops, `values` is only required to be broadcastable to `input[indices]`, rather than exact dimension match. This patch fixes the problem by add additional stablehlo.dynamic_broadcast_in_dim before creating stablehlo.scatter op. BTW, this patch also enhance the `getBroadcastResultShape` utility in hlo namespace.pull/3804/head
parent
4c1518d365
commit
b75d0e3f8b
|
@ -52,9 +52,9 @@ Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
|
|||
Value promoteType(PatternRewriter &rewriter, Location loc, Value input,
|
||||
Type outElementType);
|
||||
|
||||
FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
|
||||
Operation *op, ArrayRef<Value> tensors,
|
||||
size_t dimSizeIndexBits);
|
||||
FailureOr<std::pair<Value, SmallVector<int64_t>>>
|
||||
getBroadcastResultShape(PatternRewriter &rewriter, Operation *op,
|
||||
ArrayRef<Value> tensors, size_t dimSizeIndexBits);
|
||||
|
||||
Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
|
||||
TensorType outType,
|
||||
|
|
|
@ -220,16 +220,10 @@ namespace {
|
|||
FailureOr<Value> broadcastAndConcatIndices(Operation *op,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
SmallVector<Value> indexTensors,
|
||||
llvm::ArrayRef<int64_t> inputShape,
|
||||
size_t dimSizeIndexBits,
|
||||
int &maxIndexRank) {
|
||||
// Step 1: broadcast indices tensors
|
||||
SmallVector<int64_t> indicesShape;
|
||||
SmallVector<int64_t> expandShape;
|
||||
SmallVector<int64_t> concatShape;
|
||||
|
||||
bool allIndexStaticShape = true;
|
||||
Value bcastSizeTensor;
|
||||
|
||||
// concat index tensor into to indices tensor for concat
|
||||
for (size_t i = 0; i < indexTensors.size(); i++) {
|
||||
|
@ -242,20 +236,15 @@ FailureOr<Value> broadcastAndConcatIndices(Operation *op,
|
|||
maxIndexRank = std::max(maxIndexRank, (int)indexTensorType.getRank());
|
||||
}
|
||||
|
||||
if (!allIndexStaticShape) {
|
||||
auto bcastSizeTensorInfo = hlo::getBroadcastResultShape(
|
||||
rewriter, op, indexTensors, dimSizeIndexBits);
|
||||
if (failed(bcastSizeTensorInfo)) {
|
||||
auto bcastSizeInfo = hlo::getBroadcastResultShape(rewriter, op, indexTensors,
|
||||
dimSizeIndexBits);
|
||||
if (failed(bcastSizeInfo)) {
|
||||
return failure();
|
||||
}
|
||||
bcastSizeTensor = *bcastSizeTensorInfo;
|
||||
}
|
||||
|
||||
for (int i = 0; i < maxIndexRank; i++) {
|
||||
indicesShape.push_back(inputShape[i]);
|
||||
expandShape.push_back(inputShape[i]);
|
||||
concatShape.push_back(inputShape[i]);
|
||||
}
|
||||
Value bcastSizeTensor = (*bcastSizeInfo).first;
|
||||
auto indicesShape = (*bcastSizeInfo).second;
|
||||
SmallVector<int64_t> expandShape(indicesShape.begin(), indicesShape.end());
|
||||
SmallVector<int64_t> concatShape(indicesShape.begin(), indicesShape.end());
|
||||
expandShape.push_back(1);
|
||||
concatShape.push_back(indexTensors.size());
|
||||
|
||||
|
@ -879,7 +868,6 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
|
|||
auto inputTensorType = cast<RankedTensorType>(input.getType());
|
||||
auto outType =
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
auto outShape = outType.getShape();
|
||||
Value indexList = op.getIndices();
|
||||
SmallVector<Value> indicesTorchType;
|
||||
if (!getListConstructElements(indexList, indicesTorchType))
|
||||
|
@ -890,9 +878,8 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
|
|||
indicesTorchType);
|
||||
|
||||
int maxIndexRank = -1;
|
||||
auto gatherIndicesInfo =
|
||||
broadcastAndConcatIndices(op, rewriter, indexTensors, outShape,
|
||||
options.dimSizeIndexBits, maxIndexRank);
|
||||
auto gatherIndicesInfo = broadcastAndConcatIndices(
|
||||
op, rewriter, indexTensors, options.dimSizeIndexBits, maxIndexRank);
|
||||
if (failed(gatherIndicesInfo)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to generate broadcasted indices");
|
||||
|
@ -949,6 +936,8 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
|
|||
auto outType =
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
auto inputShape = inputType.getShape();
|
||||
auto inputRank = inputType.getRank();
|
||||
auto valuesType = cast<RankedTensorType>(values.getType());
|
||||
int64_t valueRank = valuesType.getRank();
|
||||
auto valuesShape = valuesType.getShape();
|
||||
|
@ -968,15 +957,58 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
|
|||
indicesTorchType);
|
||||
|
||||
int maxIndexRank = -1;
|
||||
auto scatterIndicesInfo =
|
||||
broadcastAndConcatIndices(op, rewriter, indexTensors, valuesShape,
|
||||
options.dimSizeIndexBits, maxIndexRank);
|
||||
auto scatterIndicesInfo = broadcastAndConcatIndices(
|
||||
op, rewriter, indexTensors, options.dimSizeIndexBits, maxIndexRank);
|
||||
if (failed(scatterIndicesInfo)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to generate broadcasted indices");
|
||||
}
|
||||
auto scatterIndices = *scatterIndicesInfo;
|
||||
|
||||
// broadcast `values` tensor to match expectedValuesShape.
|
||||
SmallVector<int64_t> scatterIndicesDims;
|
||||
for (int64_t i = 0; i < maxIndexRank; ++i) {
|
||||
scatterIndicesDims.push_back(i);
|
||||
}
|
||||
auto expectedValuesShapeTensorInfo =
|
||||
hlo::getDimSizesOfTensor(rewriter, op, scatterIndices, scatterIndicesDims,
|
||||
options.dimSizeIndexBits);
|
||||
if (failed(expectedValuesShapeTensorInfo)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to get shape of broadcasted indices");
|
||||
}
|
||||
auto expectedValuesShapeTensors = *expectedValuesShapeTensorInfo;
|
||||
SmallVector<int64_t> trailingInputDims;
|
||||
for (int64_t i = indexCnt; i < inputRank; ++i) {
|
||||
trailingInputDims.push_back(i);
|
||||
}
|
||||
auto trailingInputShapeTensorInfo = hlo::getDimSizesOfTensor(
|
||||
rewriter, op, input, trailingInputDims, options.dimSizeIndexBits);
|
||||
if (failed(trailingInputShapeTensorInfo)) {
|
||||
return rewriter.notifyMatchFailure(op, "failed to get shape of input");
|
||||
}
|
||||
expectedValuesShapeTensors.append((*trailingInputShapeTensorInfo).begin(),
|
||||
(*trailingInputShapeTensorInfo).end());
|
||||
|
||||
llvm::ArrayRef<int64_t> scatterIndicesShape =
|
||||
(cast<RankedTensorType>(scatterIndices.getType())).getShape();
|
||||
SmallVector<int64_t> expectedValuesShape(
|
||||
scatterIndicesShape.begin(), scatterIndicesShape.begin() + maxIndexRank);
|
||||
for (int64_t i = indexCnt; i < inputRank; i++) {
|
||||
expectedValuesShape.push_back(inputShape[i]);
|
||||
}
|
||||
|
||||
valuesType =
|
||||
RankedTensorType::get(expectedValuesShape, valuesType.getElementType());
|
||||
values =
|
||||
hlo::promoteAndBroadcast(rewriter, values, valuesType,
|
||||
rewriter
|
||||
.create<tensor::FromElementsOp>(
|
||||
op->getLoc(), expectedValuesShapeTensors)
|
||||
.getResult());
|
||||
valueRank = valuesType.getRank();
|
||||
valuesShape = valuesType.getShape();
|
||||
|
||||
// create stablehlo::ScatterOp
|
||||
int64_t indexVecDim = maxIndexRank;
|
||||
SmallVector<int64_t> scatterDimOperandDimMap;
|
||||
|
@ -1216,8 +1248,8 @@ Value getSummand(ConversionPatternRewriter &rewriter, Operation *op,
|
|||
SmallVector<Value> indexTensors{Nidx, CIdx, idxY, idxX};
|
||||
|
||||
int maxIndexRank = -1;
|
||||
auto gatherIndicesInfo = broadcastAndConcatIndices(
|
||||
input.getDefiningOp(), rewriter, indexTensors, outType.getShape(),
|
||||
auto gatherIndicesInfo =
|
||||
broadcastAndConcatIndices(input.getDefiningOp(), rewriter, indexTensors,
|
||||
dimSizeIndexBits, maxIndexRank);
|
||||
auto gatherIndices = *gatherIndicesInfo;
|
||||
int64_t numIndicesDim = indexTensors.size();
|
||||
|
|
|
@ -322,9 +322,9 @@ getDimIndexOfTensor(PatternRewriter &rewriter, Operation *op, Value value) {
|
|||
return getDimIndexOfTensor(rewriter, op, value, dims);
|
||||
}
|
||||
|
||||
FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
|
||||
Operation *op, ArrayRef<Value> tensors,
|
||||
size_t dimSizeIndexBits) {
|
||||
FailureOr<std::pair<Value, SmallVector<int64_t>>>
|
||||
getBroadcastResultShape(PatternRewriter &rewriter, Operation *op,
|
||||
ArrayRef<Value> tensors, size_t dimSizeIndexBits) {
|
||||
SmallVector<ArrayRef<int64_t>> tensorSizes;
|
||||
|
||||
int maxRank = 0;
|
||||
|
@ -337,10 +337,11 @@ FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
|
|||
}
|
||||
|
||||
SmallVector<Value> bcastSizeTensors;
|
||||
SmallVector<int64_t> bcastSizes;
|
||||
for (int outDim = 0; outDim < maxRank; ++outDim) { // loop dimensions.
|
||||
int dynamicDimCnt = 0;
|
||||
int staticDimCnt = 0;
|
||||
int64_t staticDimSize;
|
||||
int64_t dimSize = -1;
|
||||
Value dimSizeTensor = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op->getLoc(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(dimSizeIndexBits), 1));
|
||||
|
@ -351,12 +352,16 @@ FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
|
|||
continue;
|
||||
|
||||
// dim size: 1
|
||||
if (tensorSizes[i][inDim] == 1)
|
||||
if (tensorSizes[i][inDim] == 1) {
|
||||
if (dimSize == -1)
|
||||
dimSize = 1;
|
||||
continue;
|
||||
}
|
||||
// dim size: dynamic
|
||||
if (tensorSizes[i][inDim] == ShapedType::kDynamic ||
|
||||
tensorSizes[i][inDim] == kUnknownSize) {
|
||||
dynamicDimCnt++;
|
||||
dimSize = ShapedType::kDynamic;
|
||||
auto dimSizeTensorInfo = hlo::getDimSizesOfTensor(
|
||||
rewriter, op, tensors[i], {inDim}, dimSizeIndexBits);
|
||||
if (failed(dimSizeTensorInfo)) {
|
||||
|
@ -371,12 +376,12 @@ FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
|
|||
return failure();
|
||||
}
|
||||
// we already found static dim size not equal with this, fail.
|
||||
if (staticDimCnt > 0 && staticDimSize != tensorSizes[i][inDim]) {
|
||||
if (staticDimCnt > 0 && dimSize != tensorSizes[i][inDim]) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
staticDimCnt++;
|
||||
staticDimSize = tensorSizes[i][inDim];
|
||||
dimSize = tensorSizes[i][inDim];
|
||||
auto dimSizeTensorInfo = hlo::getDimSizesOfTensor(
|
||||
rewriter, op, tensors[i], {inDim}, dimSizeIndexBits);
|
||||
if (failed(dimSizeTensorInfo)) {
|
||||
|
@ -389,12 +394,15 @@ FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
|
|||
// if (dynamicDimCnt > 1) {
|
||||
// return failure();
|
||||
// }
|
||||
|
||||
bcastSizes.push_back(dimSize);
|
||||
bcastSizeTensors.push_back(dimSizeTensor);
|
||||
}
|
||||
std::reverse(bcastSizes.begin(), bcastSizes.end());
|
||||
std::reverse(bcastSizeTensors.begin(), bcastSizeTensors.end());
|
||||
return rewriter.create<tensor::FromElementsOp>(op->getLoc(), bcastSizeTensors)
|
||||
.getResult();
|
||||
return std::pair<Value, SmallVector<int64_t>>(
|
||||
rewriter.create<tensor::FromElementsOp>(op->getLoc(), bcastSizeTensors)
|
||||
.getResult(),
|
||||
bcastSizes);
|
||||
}
|
||||
|
||||
FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
|
||||
|
|
|
@ -760,6 +760,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
|||
"IndexPutImpl3DFloatAccumulateModule_basic",
|
||||
"IndexPutImpl3DFloatNonAccumulateModule_basic",
|
||||
"IndexPutImplIndexWithNoneModule_basic",
|
||||
"IndexPutWithNoneAndBroadcastModule_basic",
|
||||
"IndexSelectRank0IdxModule_basic",
|
||||
"IndexTensorNegativeIndexModule_basic",
|
||||
"IntFloatModule_basic",
|
||||
|
|
Loading…
Reference in New Issue