From b75d0e3f8b5267eadbfce67d136427cc9621a65b Mon Sep 17 00:00:00 2001 From: Jiawei Wu Date: Tue, 5 Nov 2024 19:15:11 +0800 Subject: [PATCH] [stablehlo] fix: enhance torch's index-like op lowering to stablehlo's gather/scatter (#3829) In torch.index_put-like ops, `values` is only required to be broadcastable to `input[indices]`, rather than an exact dimension match. This patch fixes the problem by adding an additional stablehlo.dynamic_broadcast_in_dim before creating the stablehlo.scatter op. BTW, this patch also enhances the `getBroadcastResultShape` utility in the hlo namespace. --- .../TorchToStablehlo/StablehloLegalizeUtils.h | 6 +- .../TorchToStablehlo/GatherScatter.cpp | 90 +++++++++++++------ .../StablehloLegalizeUtils.cpp | 28 +++--- projects/pt1/e2e_testing/xfail_sets.py | 1 + 4 files changed, 83 insertions(+), 42 deletions(-) diff --git a/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h index 1c3188001..9067b7e24 100644 --- a/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h +++ b/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h @@ -52,9 +52,9 @@ Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter, Value promoteType(PatternRewriter &rewriter, Location loc, Value input, Type outElementType); -FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter, - Operation *op, ArrayRef<Value> tensors, - size_t dimSizeIndexBits); +FailureOr<std::pair<Value, SmallVector<int64_t>>> +getBroadcastResultShape(PatternRewriter &rewriter, Operation *op, + ArrayRef<Value> tensors, size_t dimSizeIndexBits); Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input, TensorType outType, diff --git a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp index dc8289b71..c7a67abeb 100644 --- a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp +++ b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp @@ -220,16 +220,10 @@ namespace 
{ FailureOr broadcastAndConcatIndices(Operation *op, ConversionPatternRewriter &rewriter, SmallVector indexTensors, - llvm::ArrayRef inputShape, size_t dimSizeIndexBits, int &maxIndexRank) { // Step 1: broadcast indices tensors - SmallVector indicesShape; - SmallVector expandShape; - SmallVector concatShape; - bool allIndexStaticShape = true; - Value bcastSizeTensor; // concat index tensor into to indices tensor for concat for (size_t i = 0; i < indexTensors.size(); i++) { @@ -242,20 +236,15 @@ FailureOr broadcastAndConcatIndices(Operation *op, maxIndexRank = std::max(maxIndexRank, (int)indexTensorType.getRank()); } - if (!allIndexStaticShape) { - auto bcastSizeTensorInfo = hlo::getBroadcastResultShape( - rewriter, op, indexTensors, dimSizeIndexBits); - if (failed(bcastSizeTensorInfo)) { - return failure(); - } - bcastSizeTensor = *bcastSizeTensorInfo; - } - - for (int i = 0; i < maxIndexRank; i++) { - indicesShape.push_back(inputShape[i]); - expandShape.push_back(inputShape[i]); - concatShape.push_back(inputShape[i]); + auto bcastSizeInfo = hlo::getBroadcastResultShape(rewriter, op, indexTensors, + dimSizeIndexBits); + if (failed(bcastSizeInfo)) { + return failure(); } + Value bcastSizeTensor = (*bcastSizeInfo).first; + auto indicesShape = (*bcastSizeInfo).second; + SmallVector expandShape(indicesShape.begin(), indicesShape.end()); + SmallVector concatShape(indicesShape.begin(), indicesShape.end()); expandShape.push_back(1); concatShape.push_back(indexTensors.size()); @@ -879,7 +868,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto inputTensorType = cast(input.getType()); auto outType = cast(getTypeConverter()->convertType(op.getType())); - auto outShape = outType.getShape(); Value indexList = op.getIndices(); SmallVector indicesTorchType; if (!getListConstructElements(indexList, indicesTorchType)) @@ -890,9 +878,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( indicesTorchType); int maxIndexRank = -1; - auto gatherIndicesInfo = - 
broadcastAndConcatIndices(op, rewriter, indexTensors, outShape, - options.dimSizeIndexBits, maxIndexRank); + auto gatherIndicesInfo = broadcastAndConcatIndices( + op, rewriter, indexTensors, options.dimSizeIndexBits, maxIndexRank); if (failed(gatherIndicesInfo)) { return rewriter.notifyMatchFailure( op, "failed to generate broadcasted indices"); @@ -949,6 +936,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto outType = cast(getTypeConverter()->convertType(op.getType())); auto inputType = cast(input.getType()); + auto inputShape = inputType.getShape(); + auto inputRank = inputType.getRank(); auto valuesType = cast(values.getType()); int64_t valueRank = valuesType.getRank(); auto valuesShape = valuesType.getShape(); @@ -968,15 +957,58 @@ LogicalResult ConvertAtenOp::matchAndRewrite( indicesTorchType); int maxIndexRank = -1; - auto scatterIndicesInfo = - broadcastAndConcatIndices(op, rewriter, indexTensors, valuesShape, - options.dimSizeIndexBits, maxIndexRank); + auto scatterIndicesInfo = broadcastAndConcatIndices( + op, rewriter, indexTensors, options.dimSizeIndexBits, maxIndexRank); if (failed(scatterIndicesInfo)) { return rewriter.notifyMatchFailure( op, "failed to generate broadcasted indices"); } auto scatterIndices = *scatterIndicesInfo; + // broadcast `values` tensor to match expectedValuesShape. 
+ SmallVector scatterIndicesDims; + for (int64_t i = 0; i < maxIndexRank; ++i) { + scatterIndicesDims.push_back(i); + } + auto expectedValuesShapeTensorInfo = + hlo::getDimSizesOfTensor(rewriter, op, scatterIndices, scatterIndicesDims, + options.dimSizeIndexBits); + if (failed(expectedValuesShapeTensorInfo)) { + return rewriter.notifyMatchFailure( + op, "failed to get shape of broadcasted indices"); + } + auto expectedValuesShapeTensors = *expectedValuesShapeTensorInfo; + SmallVector trailingInputDims; + for (int64_t i = indexCnt; i < inputRank; ++i) { + trailingInputDims.push_back(i); + } + auto trailingInputShapeTensorInfo = hlo::getDimSizesOfTensor( + rewriter, op, input, trailingInputDims, options.dimSizeIndexBits); + if (failed(trailingInputShapeTensorInfo)) { + return rewriter.notifyMatchFailure(op, "failed to get shape of input"); + } + expectedValuesShapeTensors.append((*trailingInputShapeTensorInfo).begin(), + (*trailingInputShapeTensorInfo).end()); + + llvm::ArrayRef scatterIndicesShape = + (cast(scatterIndices.getType())).getShape(); + SmallVector expectedValuesShape( + scatterIndicesShape.begin(), scatterIndicesShape.begin() + maxIndexRank); + for (int64_t i = indexCnt; i < inputRank; i++) { + expectedValuesShape.push_back(inputShape[i]); + } + + valuesType = + RankedTensorType::get(expectedValuesShape, valuesType.getElementType()); + values = + hlo::promoteAndBroadcast(rewriter, values, valuesType, + rewriter + .create( + op->getLoc(), expectedValuesShapeTensors) + .getResult()); + valueRank = valuesType.getRank(); + valuesShape = valuesType.getShape(); + // create stablehlo::ScatterOp int64_t indexVecDim = maxIndexRank; SmallVector scatterDimOperandDimMap; @@ -1216,9 +1248,9 @@ Value getSummand(ConversionPatternRewriter &rewriter, Operation *op, SmallVector indexTensors{Nidx, CIdx, idxY, idxX}; int maxIndexRank = -1; - auto gatherIndicesInfo = broadcastAndConcatIndices( - input.getDefiningOp(), rewriter, indexTensors, outType.getShape(), - 
dimSizeIndexBits, maxIndexRank); + auto gatherIndicesInfo = + broadcastAndConcatIndices(input.getDefiningOp(), rewriter, indexTensors, + dimSizeIndexBits, maxIndexRank); auto gatherIndices = *gatherIndicesInfo; int64_t numIndicesDim = indexTensors.size(); int64_t indexVecDim = maxIndexRank; diff --git a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp index 8b2ec2ed5..b22dc3e6e 100644 --- a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp +++ b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp @@ -322,9 +322,9 @@ getDimIndexOfTensor(PatternRewriter &rewriter, Operation *op, Value value) { return getDimIndexOfTensor(rewriter, op, value, dims); } -FailureOr getBroadcastResultShape(PatternRewriter &rewriter, - Operation *op, ArrayRef tensors, - size_t dimSizeIndexBits) { +FailureOr>> +getBroadcastResultShape(PatternRewriter &rewriter, Operation *op, + ArrayRef tensors, size_t dimSizeIndexBits) { SmallVector> tensorSizes; int maxRank = 0; @@ -337,10 +337,11 @@ FailureOr getBroadcastResultShape(PatternRewriter &rewriter, } SmallVector bcastSizeTensors; + SmallVector bcastSizes; for (int outDim = 0; outDim < maxRank; ++outDim) { // loop dimensions. 
int dynamicDimCnt = 0; int staticDimCnt = 0; - int64_t staticDimSize; + int64_t dimSize = -1; Value dimSizeTensor = rewriter.create( op->getLoc(), rewriter.getIntegerAttr(rewriter.getIntegerType(dimSizeIndexBits), 1)); @@ -351,12 +352,16 @@ FailureOr getBroadcastResultShape(PatternRewriter &rewriter, continue; // dim size: 1 - if (tensorSizes[i][inDim] == 1) + if (tensorSizes[i][inDim] == 1) { + if (dimSize == -1) + dimSize = 1; continue; + } // dim size: dynamic if (tensorSizes[i][inDim] == ShapedType::kDynamic || tensorSizes[i][inDim] == kUnknownSize) { dynamicDimCnt++; + dimSize = ShapedType::kDynamic; auto dimSizeTensorInfo = hlo::getDimSizesOfTensor( rewriter, op, tensors[i], {inDim}, dimSizeIndexBits); if (failed(dimSizeTensorInfo)) { @@ -371,12 +376,12 @@ FailureOr getBroadcastResultShape(PatternRewriter &rewriter, return failure(); } // we already found static dim size not equal with this, fail. - if (staticDimCnt > 0 && staticDimSize != tensorSizes[i][inDim]) { + if (staticDimCnt > 0 && dimSize != tensorSizes[i][inDim]) { return failure(); } staticDimCnt++; - staticDimSize = tensorSizes[i][inDim]; + dimSize = tensorSizes[i][inDim]; auto dimSizeTensorInfo = hlo::getDimSizesOfTensor( rewriter, op, tensors[i], {inDim}, dimSizeIndexBits); if (failed(dimSizeTensorInfo)) { @@ -389,12 +394,15 @@ FailureOr getBroadcastResultShape(PatternRewriter &rewriter, // if (dynamicDimCnt > 1) { // return failure(); // } - + bcastSizes.push_back(dimSize); bcastSizeTensors.push_back(dimSizeTensor); } + std::reverse(bcastSizes.begin(), bcastSizes.end()); std::reverse(bcastSizeTensors.begin(), bcastSizeTensors.end()); - return rewriter.create(op->getLoc(), bcastSizeTensors) - .getResult(); + return std::pair>( + rewriter.create(op->getLoc(), bcastSizeTensors) + .getResult(), + bcastSizes); } FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 377154586..df84fce90 
100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -760,6 +760,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = { "IndexPutImpl3DFloatAccumulateModule_basic", "IndexPutImpl3DFloatNonAccumulateModule_basic", "IndexPutImplIndexWithNoneModule_basic", + "IndexPutWithNoneAndBroadcastModule_basic", "IndexSelectRank0IdxModule_basic", "IndexTensorNegativeIndexModule_basic", "IntFloatModule_basic",