From edc87fc577b699a0c9dbfff94f6cb38e2831223d Mon Sep 17 00:00:00 2001
From: Jiawei Wu
Date: Thu, 1 Aug 2024 10:41:09 +0800
Subject: [PATCH] [stablehlo] support dynamic-shaped index in stablehlo
 conversion for aten.index-like ops (#3322)

For now, at most one dynamic dim of the index tensors in
aten.index/aten.index_put-like ops is supported.
---
 .../TorchToStablehlo/StablehloLegalizeUtils.h |  7 +-
 lib/Conversion/TorchToStablehlo/Basic.cpp     | 19 ++--
 .../TorchToStablehlo/GatherScatter.cpp        | 92 +++++++++++++------
 .../StablehloLegalizeUtils.cpp                | 85 ++++++++++++++++-
 4 files changed, 164 insertions(+), 39 deletions(-)

diff --git a/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h
index 78a1aba7e..1c3188001 100644
--- a/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h
+++ b/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h
@@ -52,8 +52,13 @@ Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
 Value promoteType(PatternRewriter &rewriter, Location loc, Value input,
                   Type outElementType);
 
+FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
+                                         Operation *op, ArrayRef<Value> tensors,
+                                         size_t dimSizeIndexBits);
+
 Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
-                          TensorType outType);
+                          TensorType outType,
+                          std::optional<Value> bcastSizeTensor);
 
 SmallVector<int64_t> toPositiveDims(ArrayRef<int64_t> dims, int64_t rank);
 
diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp
index 5e3ab2114..1f21a1afe 100644
--- a/lib/Conversion/TorchToStablehlo/Basic.cpp
+++ b/lib/Conversion/TorchToStablehlo/Basic.cpp
@@ -768,7 +768,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
       getTypeConverter()->convertType(op->getResult(0).getType()));
 
   if (options.enableStaticShape && selfTy.hasStaticShape()) {
-    Value bcastOp = hlo::promoteAndBroadcast(rewriter, self, outType);
+    Value bcastOp =
+        hlo::promoteAndBroadcast(rewriter, self, outType, std::nullopt);
     rewriter.replaceOp(op, bcastOp);
     return success();
   }
@@ -1488,8 +1489,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
                       .value());
 
   // Apply affine transform: output x weight + bias [element-wise]
-  auto bcastedWeight = hlo::promoteAndBroadcast(rewriter, weight, outputTy);
-  auto bcastedBias = hlo::promoteAndBroadcast(rewriter, bias, outputTy);
+  auto bcastedWeight =
+      hlo::promoteAndBroadcast(rewriter, weight, outputTy, std::nullopt);
+  auto bcastedBias =
+      hlo::promoteAndBroadcast(rewriter, bias, outputTy, std::nullopt);
   auto outputMulWeight =
       rewriter.create<stablehlo::MulOp>(op->getLoc(), output, bcastedWeight);
   auto finalOuput = rewriter.create<stablehlo::AddOp>(
@@ -1634,8 +1637,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
     maxValue = *maxInfo;
   }
   if (inputType.hasStaticShape()) {
-    minValue = hlo::promoteAndBroadcast(rewriter, minValue, inputType);
-    maxValue = hlo::promoteAndBroadcast(rewriter, maxValue, inputType);
+    minValue =
+        hlo::promoteAndBroadcast(rewriter, minValue, inputType, std::nullopt);
+    maxValue =
+        hlo::promoteAndBroadcast(rewriter, maxValue, inputType, std::nullopt);
   }
   rewriter.replaceOpWithNewOp<stablehlo::ClampOp>(op, minValue, input,
                                                   maxValue);
@@ -2021,7 +2026,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
 
   auto resultType =
       cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
-  rhs = hlo::promoteAndBroadcast(rewriter, rhs, resultType);
+  rhs = hlo::promoteAndBroadcast(rewriter, rhs, resultType, std::nullopt);
   rewriter.replaceOpWithNewOp(op, lhs, rhs);
   return success();
 }
@@ -2036,7 +2041,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
 
   auto resultType =
      cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
-  rhs = hlo::promoteAndBroadcast(rewriter, rhs, resultType);
+  rhs = hlo::promoteAndBroadcast(rewriter, rhs, resultType, std::nullopt);
   rewriter.replaceOpWithNewOp(op, lhs, rhs);
   return success();
 }
diff --git a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp
index e3168004b..528a0718b 100644
--- a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp
+++ b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp
@@ -221,32 +221,40 @@ FailureOr<Value> broadcastAndConcatIndices(Operation *op,
                                            ConversionPatternRewriter &rewriter,
                                            SmallVector<Value> indexTensors,
                                            llvm::ArrayRef<int64_t> inputShape,
+                                           size_t dimSizeIndexBits,
                                            int &maxIndexRank) {
   // Step 1: broadcast indices tensors
   SmallVector<int64_t> indicesShape;
   SmallVector<int64_t> expandShape;
   SmallVector<int64_t> concatShape;
+
+  bool allIndexStaticShape = true;
+  Value bcastSizeTensor;
+
   // concat index tensor into to indices tensor for concat
   for (size_t i = 0; i < indexTensors.size(); i++) {
     auto indexTensor = indexTensors[i];
     auto indexTensorType = cast<RankedTensorType>(indexTensor.getType());
     for (int64_t size : makeShapeTorchCompatible(indexTensorType.getShape())) {
       if (size == kUnknownSize)
-        return failure();
+        allIndexStaticShape = false;
     }
     maxIndexRank = std::max(maxIndexRank, (int)indexTensorType.getRank());
   }
 
-  SmallVector<int64_t> refinedInputShape = makeShapeTorchCompatible(inputShape);
-  for (int64_t size : refinedInputShape) {
-    if (size == kUnknownSize) {
+  if (!allIndexStaticShape) {
+    auto bcastSizeTensorInfo = hlo::getBroadcastResultShape(
+        rewriter, op, indexTensors, dimSizeIndexBits);
+    if (failed(bcastSizeTensorInfo)) {
       return failure();
     }
+    bcastSizeTensor = *bcastSizeTensorInfo;
   }
+
   for (int i = 0; i < maxIndexRank; i++) {
-    indicesShape.push_back(refinedInputShape[i]);
-    expandShape.push_back(refinedInputShape[i]);
-    concatShape.push_back(refinedInputShape[i]);
+    indicesShape.push_back(inputShape[i]);
+    expandShape.push_back(inputShape[i]);
+    concatShape.push_back(inputShape[i]);
   }
   expandShape.push_back(1);
   concatShape.push_back(indexTensors.size());
@@ -256,12 +264,29 @@ FailureOr<Value> broadcastAndConcatIndices(Operation *op,
   RankedTensorType bcastIndexType =
       RankedTensorType::get(indicesShape, indexElemTy);
   for (auto indexTensor : indexTensors) {
-    Value bcastVal =
-        hlo::promoteAndBroadcast(rewriter, indexTensor, bcastIndexType);
+    Value bcastVal;
     RankedTensorType reshapeType =
         RankedTensorType::get(expandShape, indexElemTy);
-    bcastVal = rewriter.create<stablehlo::ReshapeOp>(op->getLoc(), reshapeType,
-                                                     bcastVal);
+    if (allIndexStaticShape) {
+      bcastVal = hlo::promoteAndBroadcast(rewriter, indexTensor, bcastIndexType,
+                                          std::nullopt);
+      bcastVal = rewriter.create<stablehlo::ReshapeOp>(op->getLoc(),
+                                                       reshapeType, bcastVal);
+    } else {
+      bcastVal = hlo::promoteAndBroadcast(rewriter, indexTensor, bcastIndexType,
+                                          bcastSizeTensor);
+      auto bcastValShapeTensorVec =
+          *hlo::getDimSizesOfTensor(rewriter, op, bcastVal, dimSizeIndexBits);
+      bcastValShapeTensorVec.push_back(rewriter.create<mlir::arith::ConstantOp>(
+          op->getLoc(), rewriter.getIntegerAttr(
+                            rewriter.getIntegerType(dimSizeIndexBits), 1)));
+      Value bcastValShapeTensor = rewriter
+                                      .create<tensor::FromElementsOp>(
+                                          op->getLoc(), bcastValShapeTensorVec)
+                                      .getResult();
+      bcastVal = rewriter.create<stablehlo::DynamicReshapeOp>(
+          op->getLoc(), reshapeType, bcastVal, bcastValShapeTensor);
+    }
     broadcastedIndices.push_back(bcastVal);
   }
 
@@ -797,8 +822,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
                                          indicesTorchType);
 
   int maxIndexRank = -1;
-  auto gatherIndicesInfo = broadcastAndConcatIndices(op, rewriter, indexTensors,
-                                                     outShape, maxIndexRank);
+  auto gatherIndicesInfo =
+      broadcastAndConcatIndices(op, rewriter, indexTensors, outShape,
+                                options.dimSizeIndexBits, maxIndexRank);
   if (failed(gatherIndicesInfo)) {
     return rewriter.notifyMatchFailure(
         op, "failed to generate broadcasted indices");
@@ -874,8 +900,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
                                          indicesTorchType);
 
   int maxIndexRank = -1;
-  auto scatterIndicesInfo = broadcastAndConcatIndices(
-      op, rewriter, indexTensors, valuesShape, maxIndexRank);
+  auto scatterIndicesInfo =
+      broadcastAndConcatIndices(op, rewriter, indexTensors, valuesShape,
+                                options.dimSizeIndexBits, maxIndexRank);
   if (failed(scatterIndicesInfo)) {
     return rewriter.notifyMatchFailure(
         op, "failed to generate broadcasted indices");
@@ -1109,7 +1136,8 @@ SmallVector<Value> clip(ConversionPatternRewriter &rewriter, Operation *op,
 Value getSummand(ConversionPatternRewriter &rewriter, Operation *op,
                  Value input, Value ix, Value iy, Value w, int64_t N,
                  int64_t oH, int64_t oW, int64_t iH, int64_t iW, Value Nidx,
-                 Value CIdx, RankedTensorType outType, Type elemTy) {
+                 Value CIdx, RankedTensorType outType, Type elemTy,
+                 size_t dimSizeIndexBits) {
   Location loc = op->getLoc();
   auto inputTensorType = cast<RankedTensorType>(input.getType());
   SmallVector<Value> clipValues =
@@ -1120,9 +1148,9 @@ Value getSummand(ConversionPatternRewriter &rewriter, Operation *op,
   SmallVector<Value> indexTensors{Nidx, CIdx, idxY, idxX};
 
   int maxIndexRank = -1;
-  auto gatherIndicesInfo =
-      broadcastAndConcatIndices(input.getDefiningOp(), rewriter, indexTensors,
-                                outType.getShape(), maxIndexRank);
+  auto gatherIndicesInfo = broadcastAndConcatIndices(
+      input.getDefiningOp(), rewriter, indexTensors, outType.getShape(),
+      dimSizeIndexBits, maxIndexRank);
   auto gatherIndices = *gatherIndicesInfo;
   int64_t numIndicesDim = indexTensors.size();
   int64_t indexVecDim = maxIndexRank;
@@ -1310,14 +1338,18 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
             rewriter.create(loc, iy, iy_nw, bcastDimensions),
         bcastDimensions);
 
-    Value summand_nw = getSummand(rewriter, op, input, ix_nw, iy_nw, w_nw, N,
-                                  oH, oW, iH, iW, Nidx, Cidx, outTy, elemTy);
-    Value summand_ne = getSummand(rewriter, op, input, ix_ne, iy_ne, w_ne, N,
-                                  oH, oW, iH, iW, Nidx, Cidx, outTy, elemTy);
-    Value summand_sw = getSummand(rewriter, op, input, ix_sw, iy_sw, w_sw, N,
-                                  oH, oW, iH, iW, Nidx, Cidx, outTy, elemTy);
-    Value summand_se = getSummand(rewriter, op, input, ix_se, iy_se, w_se, N,
-                                  oH, oW, iH, iW, Nidx, Cidx, outTy, elemTy);
+    Value summand_nw =
+        getSummand(rewriter, op, input, ix_nw, iy_nw, w_nw, N, oH, oW, iH, iW,
+                   Nidx, Cidx, outTy, elemTy, options.dimSizeIndexBits);
+    Value summand_ne =
+        getSummand(rewriter, op, input, ix_ne, iy_ne, w_ne, N, oH, oW, iH, iW,
+                   Nidx, Cidx, outTy, elemTy, options.dimSizeIndexBits);
+    Value summand_sw =
+        getSummand(rewriter, op, input, ix_sw, iy_sw, w_sw, N, oH, oW, iH, iW,
+                   Nidx, Cidx, outTy, elemTy, options.dimSizeIndexBits);
+    Value summand_se =
+        getSummand(rewriter, op, input, ix_se, iy_se, w_se, N, oH, oW, iH, iW,
+                   Nidx, Cidx, outTy, elemTy, options.dimSizeIndexBits);
 
     // summand_nw + summand_ne + summand_sw + summand_se
     Value sum = rewriter.create(loc, summand_nw, summand_ne);
@@ -1332,9 +1364,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
     Value ix_round = rewriter.create(loc, ix);
     Value iy_round = rewriter.create(loc, iy);
     Value oneTensor = getConstantLike(rewriter, loc, 1.0, ix_round);
-    Value summand =
-        getSummand(rewriter, op, input, ix_round, iy_round, oneTensor, N, oH,
-                   oW, iH, iW, Nidx, Cidx, outTy, elemTy);
+    Value summand = getSummand(rewriter, op, input, ix_round, iy_round,
+                               oneTensor, N, oH, oW, iH, iW, Nidx, Cidx, outTy,
+                               elemTy, options.dimSizeIndexBits);
     rewriter.replaceOp(op, summand);
   }
   return success();
diff --git a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp
index cf31ba281..8b2ec2ed5 100644
--- a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp
+++ b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp
@@ -179,12 +179,15 @@ Value promoteType(PatternRewriter &rewriter, Location loc, Value input,
 }
 
 Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
-                          TensorType outType) {
+                          TensorType outType,
+                          std::optional<Value> bcastSizeTensor) {
   // Two tensors are “broadcastable” if the following rules hold:
  //   - Each tensor has at least one dimension.
  //   - When iterating over the dimension sizes, starting at the trailing
  //   dimension, the dimension sizes must either be equal, one of them is 1, or
  //   one of them does not exist.
+  // If bcastSizeTensor is provided, we emit stablehlo::DynamicBroadcastInDimOp
+  // instead of stablehlo::BroadcastInDimOp to support dynamic shape.
   Operation *op = input.getDefiningOp();
   TensorType in_type = dyn_cast<TensorType>(input.getType());
 
@@ -222,6 +225,11 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
     return input;
   }
   auto bcast_attr = rewriter.getDenseI64ArrayAttr(bcastDims);
+  if (bcastSizeTensor.has_value()) {
+    auto bcast_op = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
+        op->getLoc(), outType, input, bcastSizeTensor.value(), bcast_attr);
+    return bcast_op.getResult();
+  }
   auto bcast_op = rewriter.create<stablehlo::BroadcastInDimOp>(
       op->getLoc(), outType, input, bcast_attr);
   return bcast_op.getResult();
 }
@@ -314,6 +322,81 @@ getDimIndexOfTensor(PatternRewriter &rewriter, Operation *op, Value value) {
   return getDimIndexOfTensor(rewriter, op, value, dims);
 }
 
+FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
+                                         Operation *op, ArrayRef<Value> tensors,
+                                         size_t dimSizeIndexBits) {
+  SmallVector<ArrayRef<int64_t>> tensorSizes;
+
+  int maxRank = 0;
+  for (auto tensor : tensors) {
+    auto tensorType = cast<RankedTensorType>(tensor.getType());
+    auto tensorRank = tensorType.getRank();
+
+    tensorSizes.emplace_back(tensorType.getShape());
+    maxRank = std::max(maxRank, static_cast<int>(tensorRank));
+  }
+
+  SmallVector<Value> bcastSizeTensors;
+  for (int outDim = 0; outDim < maxRank; ++outDim) { // loop dimensions.
+    int dynamicDimCnt = 0;
+    int staticDimCnt = 0;
+    int64_t staticDimSize;
+    Value dimSizeTensor = rewriter.create<mlir::arith::ConstantOp>(
+        op->getLoc(),
+        rewriter.getIntegerAttr(rewriter.getIntegerType(dimSizeIndexBits), 1));
+
+    for (size_t i = 0; i < tensorSizes.size(); ++i) { // loop tensors.
+      int inDim = tensorSizes[i].size() - 1 - outDim;
+      if (inDim < 0)
+        continue;
+
+      // dim size: 1
+      if (tensorSizes[i][inDim] == 1)
+        continue;
+      // dim size: dynamic
+      if (tensorSizes[i][inDim] == ShapedType::kDynamic ||
+          tensorSizes[i][inDim] == kUnknownSize) {
+        dynamicDimCnt++;
+        auto dimSizeTensorInfo = hlo::getDimSizesOfTensor(
+            rewriter, op, tensors[i], {inDim}, dimSizeIndexBits);
+        if (failed(dimSizeTensorInfo)) {
+          return failure();
+        }
+        dimSizeTensor = (*dimSizeTensorInfo)[0];
+        continue;
+      }
+      // dim size: static
+      // we already found dynamic dim size, fail.
+      if (dynamicDimCnt > 0) {
+        return failure();
+      }
+      // we already found static dim size not equal with this, fail.
+      if (staticDimCnt > 0 && staticDimSize != tensorSizes[i][inDim]) {
+        return failure();
+      }
+
+      staticDimCnt++;
+      staticDimSize = tensorSizes[i][inDim];
+      auto dimSizeTensorInfo = hlo::getDimSizesOfTensor(
+          rewriter, op, tensors[i], {inDim}, dimSizeIndexBits);
+      if (failed(dimSizeTensorInfo)) {
+        return failure();
+      }
+      dimSizeTensor = (*dimSizeTensorInfo)[0];
+    }
+
+    // TODO: Relax this check, by assuming all dynamic shape is same.
+    // if (dynamicDimCnt > 1) {
+    //   return failure();
+    // }
+
+    bcastSizeTensors.push_back(dimSizeTensor);
+  }
+  std::reverse(bcastSizeTensors.begin(), bcastSizeTensors.end());
+  return rewriter.create<tensor::FromElementsOp>(op->getLoc(), bcastSizeTensors)
+      .getResult();
+}
+
 FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
                                  Value tensor,
                                  ArrayRef<int64_t> inputUnsqzDims) {
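
A minimal usage sketch of the two extended helpers, not taken from the patch itself; `idx0`, `idx1`, `bcastIndexType`, and `options` are hypothetical values that a surrounding conversion pattern would already have in scope:

    // Inside a hypothetical ConvertAtenOp<...>::matchAndRewrite body.
    SmallVector<Value> indexTensors{idx0, idx1};

    // Compute the runtime broadcast shape of the index tensors; this returns
    // failure for dynamic/static size combinations the helper cannot reconcile.
    FailureOr<Value> bcastSizeTensor = hlo::getBroadcastResultShape(
        rewriter, op, indexTensors, options.dimSizeIndexBits);
    if (failed(bcastSizeTensor))
      return rewriter.notifyMatchFailure(op, "failed to compute broadcast shape");

    // Passing the shape tensor makes promoteAndBroadcast emit
    // stablehlo::DynamicBroadcastInDimOp; passing std::nullopt keeps the old
    // static stablehlo::BroadcastInDimOp behavior.
    Value bcastIdx0 = hlo::promoteAndBroadcast(rewriter, idx0, bcastIndexType,
                                               *bcastSizeTensor);
    Value bcastIdx1 = hlo::promoteAndBroadcast(rewriter, idx1, bcastIndexType,
                                               std::nullopt);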