[stablehlo] fix: enhance torch's index-like op lowering to stablehlo's gather/scatter (#3829)

In torch.index_put-like ops, `values` is only required to be
broadcastable to `input[indices]`, rather than an exact dimension match.
This patch fixes the problem by adding an additional
stablehlo.dynamic_broadcast_in_dim before creating the stablehlo.scatter op.
It also enhances the `getBroadcastResultShape` utility in the
hlo namespace.
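
As a concrete illustration of the semantics being fixed (a minimal sketch; the shapes and tensor names below are made up, not taken from the patch):

    import torch

    inp = torch.zeros(4, 5, 6)
    indices = (torch.tensor([0, 2]), torch.tensor([1, 3]))
    # inp[indices] has shape (2, 6); `values` of shape (6,) merely
    # broadcasts to it -- an exact dimension match is not required.
    values = torch.arange(6, dtype=torch.float32)
    inp.index_put_(indices, values)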
Jiawei Wu 2024-11-05 19:15:11 +08:00 committed by GitHub
parent 4c1518d365
commit b75d0e3f8b
4 changed files with 83 additions and 42 deletions

View File

@@ -52,9 +52,9 @@ Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
 Value promoteType(PatternRewriter &rewriter, Location loc, Value input,
                   Type outElementType);

-FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
-                                         Operation *op, ArrayRef<Value> tensors,
-                                         size_t dimSizeIndexBits);
+FailureOr<std::pair<Value, SmallVector<int64_t>>>
+getBroadcastResultShape(PatternRewriter &rewriter, Operation *op,
+                        ArrayRef<Value> tensors, size_t dimSizeIndexBits);

 Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
                           TensorType outType,

View File

@@ -220,16 +220,10 @@ namespace {
 FailureOr<Value> broadcastAndConcatIndices(Operation *op,
                                            ConversionPatternRewriter &rewriter,
                                            SmallVector<Value> indexTensors,
-                                           llvm::ArrayRef<int64_t> inputShape,
                                            size_t dimSizeIndexBits,
                                            int &maxIndexRank) {
   // Step 1: broadcast indices tensors
-  SmallVector<int64_t> indicesShape;
-  SmallVector<int64_t> expandShape;
-  SmallVector<int64_t> concatShape;
   bool allIndexStaticShape = true;
-  Value bcastSizeTensor;

   // concat index tensor into to indices tensor for concat
   for (size_t i = 0; i < indexTensors.size(); i++) {
@@ -242,20 +236,15 @@ FailureOr<Value> broadcastAndConcatIndices(Operation *op,
     maxIndexRank = std::max(maxIndexRank, (int)indexTensorType.getRank());
   }

-  if (!allIndexStaticShape) {
-    auto bcastSizeTensorInfo = hlo::getBroadcastResultShape(
-        rewriter, op, indexTensors, dimSizeIndexBits);
-    if (failed(bcastSizeTensorInfo)) {
-      return failure();
-    }
-    bcastSizeTensor = *bcastSizeTensorInfo;
-  }
-
-  for (int i = 0; i < maxIndexRank; i++) {
-    indicesShape.push_back(inputShape[i]);
-    expandShape.push_back(inputShape[i]);
-    concatShape.push_back(inputShape[i]);
+  auto bcastSizeInfo = hlo::getBroadcastResultShape(rewriter, op, indexTensors,
+                                                    dimSizeIndexBits);
+  if (failed(bcastSizeInfo)) {
+    return failure();
   }
+  Value bcastSizeTensor = (*bcastSizeInfo).first;
+  auto indicesShape = (*bcastSizeInfo).second;
+  SmallVector<int64_t> expandShape(indicesShape.begin(), indicesShape.end());
+  SmallVector<int64_t> concatShape(indicesShape.begin(), indicesShape.end());
   expandShape.push_back(1);
   concatShape.push_back(indexTensors.size());
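
The reworked helper derives the common shape from the index tensors themselves rather than from the input. A rough Python analogue of the broadcast-and-concat step (the function name and shapes are illustrative only):

    import torch

    def broadcast_and_concat_indices(index_tensors):
        # Broadcast every index tensor to the common shape, then
        # concatenate along a fresh trailing index-vector dimension.
        bcast = torch.broadcast_tensors(*index_tensors)
        return torch.stack(bcast, dim=-1)

    idx = broadcast_and_concat_indices(
        [torch.tensor([0, 2]), torch.tensor([[1], [3]])])
    print(idx.shape)  # torch.Size([2, 2, 2])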
@@ -879,7 +868,6 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
   auto inputTensorType = cast<RankedTensorType>(input.getType());
   auto outType =
       cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
-  auto outShape = outType.getShape();
   Value indexList = op.getIndices();
   SmallVector<Value> indicesTorchType;
   if (!getListConstructElements(indexList, indicesTorchType))
@@ -890,9 +878,8 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
                                              indicesTorchType);

   int maxIndexRank = -1;
-  auto gatherIndicesInfo =
-      broadcastAndConcatIndices(op, rewriter, indexTensors, outShape,
-                                options.dimSizeIndexBits, maxIndexRank);
+  auto gatherIndicesInfo = broadcastAndConcatIndices(
+      op, rewriter, indexTensors, options.dimSizeIndexBits, maxIndexRank);
   if (failed(gatherIndicesInfo)) {
     return rewriter.notifyMatchFailure(
         op, "failed to generate broadcasted indices");
@@ -949,6 +936,8 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
   auto outType =
       cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
   auto inputType = cast<RankedTensorType>(input.getType());
+  auto inputShape = inputType.getShape();
+  auto inputRank = inputType.getRank();
   auto valuesType = cast<RankedTensorType>(values.getType());
   int64_t valueRank = valuesType.getRank();
   auto valuesShape = valuesType.getShape();
@@ -968,15 +957,58 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
                                              indicesTorchType);

   int maxIndexRank = -1;
-  auto scatterIndicesInfo =
-      broadcastAndConcatIndices(op, rewriter, indexTensors, valuesShape,
-                                options.dimSizeIndexBits, maxIndexRank);
+  auto scatterIndicesInfo = broadcastAndConcatIndices(
+      op, rewriter, indexTensors, options.dimSizeIndexBits, maxIndexRank);
   if (failed(scatterIndicesInfo)) {
     return rewriter.notifyMatchFailure(
         op, "failed to generate broadcasted indices");
   }
   auto scatterIndices = *scatterIndicesInfo;

+  // broadcast `values` tensor to match expectedValuesShape.
+  SmallVector<int64_t> scatterIndicesDims;
+  for (int64_t i = 0; i < maxIndexRank; ++i) {
+    scatterIndicesDims.push_back(i);
+  }
+  auto expectedValuesShapeTensorInfo =
+      hlo::getDimSizesOfTensor(rewriter, op, scatterIndices, scatterIndicesDims,
+                               options.dimSizeIndexBits);
+  if (failed(expectedValuesShapeTensorInfo)) {
+    return rewriter.notifyMatchFailure(
+        op, "failed to get shape of broadcasted indices");
+  }
+  auto expectedValuesShapeTensors = *expectedValuesShapeTensorInfo;
+  SmallVector<int64_t> trailingInputDims;
+  for (int64_t i = indexCnt; i < inputRank; ++i) {
+    trailingInputDims.push_back(i);
+  }
+  auto trailingInputShapeTensorInfo = hlo::getDimSizesOfTensor(
+      rewriter, op, input, trailingInputDims, options.dimSizeIndexBits);
+  if (failed(trailingInputShapeTensorInfo)) {
+    return rewriter.notifyMatchFailure(op, "failed to get shape of input");
+  }
+  expectedValuesShapeTensors.append((*trailingInputShapeTensorInfo).begin(),
+                                    (*trailingInputShapeTensorInfo).end());
+
+  llvm::ArrayRef<int64_t> scatterIndicesShape =
+      (cast<RankedTensorType>(scatterIndices.getType())).getShape();
+  SmallVector<int64_t> expectedValuesShape(
+      scatterIndicesShape.begin(), scatterIndicesShape.begin() + maxIndexRank);
+  for (int64_t i = indexCnt; i < inputRank; i++) {
+    expectedValuesShape.push_back(inputShape[i]);
+  }
+
+  valuesType =
+      RankedTensorType::get(expectedValuesShape, valuesType.getElementType());
+  values =
+      hlo::promoteAndBroadcast(rewriter, values, valuesType,
+                               rewriter
+                                   .create<tensor::FromElementsOp>(
+                                       op->getLoc(), expectedValuesShapeTensors)
+                                   .getResult());
+  valueRank = valuesType.getRank();
+  valuesShape = valuesType.getShape();
+
   // create stablehlo::ScatterOp
   int64_t indexVecDim = maxIndexRank;
   SmallVector<int64_t> scatterDimOperandDimMap;
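
The expected `values` shape assembled above is, in effect, the broadcast index shape followed by the un-indexed trailing dimensions of the input. A small Python rendering of that shape computation (names are illustrative, assuming static shapes):

    def expected_values_shape(scatter_indices_shape, max_index_rank,
                              input_shape, index_cnt):
        # Leading dims come from the broadcasted indices; trailing dims
        # come from the un-indexed tail of the input.
        return (list(scatter_indices_shape[:max_index_rank])
                + list(input_shape[index_cnt:]))

    # indices broadcast to (2, 2) with a trailing index-vector dim,
    # input is (4, 5, 6) and 2 of its dims are indexed:
    print(expected_values_shape([2, 2, 2], 2, [4, 5, 6], 2))  # [2, 2, 6]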
@@ -1216,9 +1248,9 @@ Value getSummand(ConversionPatternRewriter &rewriter, Operation *op,
   SmallVector<Value> indexTensors{Nidx, CIdx, idxY, idxX};

   int maxIndexRank = -1;
-  auto gatherIndicesInfo = broadcastAndConcatIndices(
-      input.getDefiningOp(), rewriter, indexTensors, outType.getShape(),
-      dimSizeIndexBits, maxIndexRank);
+  auto gatherIndicesInfo =
+      broadcastAndConcatIndices(input.getDefiningOp(), rewriter, indexTensors,
+                                dimSizeIndexBits, maxIndexRank);
   auto gatherIndices = *gatherIndicesInfo;
   int64_t numIndicesDim = indexTensors.size();
   int64_t indexVecDim = maxIndexRank;

View File

@@ -322,9 +322,9 @@ getDimIndexOfTensor(PatternRewriter &rewriter, Operation *op, Value value) {
   return getDimIndexOfTensor(rewriter, op, value, dims);
 }

-FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
-                                         Operation *op, ArrayRef<Value> tensors,
-                                         size_t dimSizeIndexBits) {
+FailureOr<std::pair<Value, SmallVector<int64_t>>>
+getBroadcastResultShape(PatternRewriter &rewriter, Operation *op,
+                        ArrayRef<Value> tensors, size_t dimSizeIndexBits) {
   SmallVector<ArrayRef<int64_t>> tensorSizes;

   int maxRank = 0;
@@ -337,10 +337,11 @@ FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
   }

   SmallVector<Value> bcastSizeTensors;
+  SmallVector<int64_t> bcastSizes;
   for (int outDim = 0; outDim < maxRank; ++outDim) { // loop dimensions.
     int dynamicDimCnt = 0;
     int staticDimCnt = 0;
-    int64_t staticDimSize;
+    int64_t dimSize = -1;
     Value dimSizeTensor = rewriter.create<mlir::arith::ConstantOp>(
         op->getLoc(),
         rewriter.getIntegerAttr(rewriter.getIntegerType(dimSizeIndexBits), 1));
@@ -351,12 +352,16 @@ FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
         continue;

       // dim size: 1
-      if (tensorSizes[i][inDim] == 1)
+      if (tensorSizes[i][inDim] == 1) {
+        if (dimSize == -1)
+          dimSize = 1;
         continue;
+      }

       // dim size: dynamic
       if (tensorSizes[i][inDim] == ShapedType::kDynamic ||
           tensorSizes[i][inDim] == kUnknownSize) {
         dynamicDimCnt++;
+        dimSize = ShapedType::kDynamic;
         auto dimSizeTensorInfo = hlo::getDimSizesOfTensor(
             rewriter, op, tensors[i], {inDim}, dimSizeIndexBits);
         if (failed(dimSizeTensorInfo)) {
@@ -371,12 +376,12 @@ FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
         return failure();
       }
       // we already found static dim size not equal with this, fail.
-      if (staticDimCnt > 0 && staticDimSize != tensorSizes[i][inDim]) {
+      if (staticDimCnt > 0 && dimSize != tensorSizes[i][inDim]) {
         return failure();
       }
       staticDimCnt++;
-      staticDimSize = tensorSizes[i][inDim];
+      dimSize = tensorSizes[i][inDim];
       auto dimSizeTensorInfo = hlo::getDimSizesOfTensor(
           rewriter, op, tensors[i], {inDim}, dimSizeIndexBits);
       if (failed(dimSizeTensorInfo)) {
@@ -389,12 +394,15 @@ FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
     // if (dynamicDimCnt > 1) {
     //   return failure();
     // }
+    bcastSizes.push_back(dimSize);
     bcastSizeTensors.push_back(dimSizeTensor);
   }
+  std::reverse(bcastSizes.begin(), bcastSizes.end());
   std::reverse(bcastSizeTensors.begin(), bcastSizeTensors.end());
-  return rewriter.create<tensor::FromElementsOp>(op->getLoc(), bcastSizeTensors)
-      .getResult();
+  return std::pair<Value, SmallVector<int64_t>>(
+      rewriter.create<tensor::FromElementsOp>(op->getLoc(), bcastSizeTensors)
+          .getResult(),
+      bcastSizes);
 }

 FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
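
The per-dimension rule implemented here mirrors NumPy-style broadcasting, with dynamic sizes treated leniently. A rough, simplified Python sketch of the static-shape logic (`DYNAMIC` stands in for `ShapedType::kDynamic`; this sketch prefers a known static size over a dynamic one):

    DYNAMIC = -1  # stand-in for ShapedType::kDynamic

    def broadcast_result_shape(shapes):
        # Right-align the shapes, then resolve each output dim across inputs.
        max_rank = max(len(s) for s in shapes)
        result = []
        for out_dim in range(1, max_rank + 1):  # walk from the trailing dim
            dim = 1
            for s in shapes:
                if out_dim > len(s) or s[-out_dim] == 1:
                    continue  # missing and size-1 dims always broadcast
                if s[-out_dim] == DYNAMIC:
                    if dim == 1:
                        dim = DYNAMIC  # keep a known static size if seen
                    continue
                if dim not in (1, DYNAMIC) and dim != s[-out_dim]:
                    raise ValueError("incompatible static dimensions")
                dim = s[-out_dim]
            result.append(dim)
        return result[::-1]

    print(broadcast_result_shape([[2, 1], [DYNAMIC, 3]]))  # [2, 3]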

View File

@@ -760,6 +760,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "IndexPutImpl3DFloatAccumulateModule_basic",
     "IndexPutImpl3DFloatNonAccumulateModule_basic",
     "IndexPutImplIndexWithNoneModule_basic",
+    "IndexPutWithNoneAndBroadcastModule_basic",
     "IndexSelectRank0IdxModule_basic",
     "IndexTensorNegativeIndexModule_basic",
     "IntFloatModule_basic",