[stablehlo] support dynamic-shaped index in stablehlo conversion for aten.index-like ops (#3322)

For now, at most one dynamic dim of index tensors in
aten.index/aten.index_put-like op is supported.
pull/3516/head
Jiawei Wu 2024-08-01 10:41:09 +08:00 committed by GitHub
parent 7f475e174e
commit edc87fc577
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 164 additions and 39 deletions

View File

@ -52,8 +52,13 @@ Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
Value promoteType(PatternRewriter &rewriter, Location loc, Value input,
Type outElementType);
FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
Operation *op, ArrayRef<Value> tensors,
size_t dimSizeIndexBits);
Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
TensorType outType);
TensorType outType,
std::optional<Value> bcastSizeTensor);
SmallVector<int64_t> toPositiveDims(ArrayRef<int64_t> dims, int64_t rank);

View File

@ -768,7 +768,8 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
getTypeConverter()->convertType(op->getResult(0).getType()));
if (options.enableStaticShape && selfTy.hasStaticShape()) {
Value bcastOp = hlo::promoteAndBroadcast(rewriter, self, outType);
Value bcastOp =
hlo::promoteAndBroadcast(rewriter, self, outType, std::nullopt);
rewriter.replaceOp(op, bcastOp);
return success();
}
@ -1488,8 +1489,10 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
.value());
// Apply affine transform: output x weight + bias [element-wise]
auto bcastedWeight = hlo::promoteAndBroadcast(rewriter, weight, outputTy);
auto bcastedBias = hlo::promoteAndBroadcast(rewriter, bias, outputTy);
auto bcastedWeight =
hlo::promoteAndBroadcast(rewriter, weight, outputTy, std::nullopt);
auto bcastedBias =
hlo::promoteAndBroadcast(rewriter, bias, outputTy, std::nullopt);
auto outputMulWeight =
rewriter.create<stablehlo::MulOp>(op->getLoc(), output, bcastedWeight);
auto finalOuput = rewriter.create<stablehlo::AddOp>(
@ -1634,8 +1637,10 @@ LogicalResult ConvertAtenOp<AtenClampTensorOp>::matchAndRewrite(
maxValue = *maxInfo;
}
if (inputType.hasStaticShape()) {
minValue = hlo::promoteAndBroadcast(rewriter, minValue, inputType);
maxValue = hlo::promoteAndBroadcast(rewriter, maxValue, inputType);
minValue =
hlo::promoteAndBroadcast(rewriter, minValue, inputType, std::nullopt);
maxValue =
hlo::promoteAndBroadcast(rewriter, maxValue, inputType, std::nullopt);
}
rewriter.replaceOpWithNewOp<stablehlo::ClampOp>(op, minValue, input,
maxValue);
@ -2021,7 +2026,7 @@ LogicalResult ConvertAtenOp<AtenBitwiseLeftShiftTensorOp>::matchAndRewrite(
auto resultType =
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
rhs = hlo::promoteAndBroadcast(rewriter, rhs, resultType);
rhs = hlo::promoteAndBroadcast(rewriter, rhs, resultType, std::nullopt);
rewriter.replaceOpWithNewOp<stablehlo::ShiftLeftOp>(op, lhs, rhs);
return success();
}
@ -2036,7 +2041,7 @@ LogicalResult ConvertAtenOp<AtenBitwiseRightShiftTensorOp>::matchAndRewrite(
auto resultType =
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
rhs = hlo::promoteAndBroadcast(rewriter, rhs, resultType);
rhs = hlo::promoteAndBroadcast(rewriter, rhs, resultType, std::nullopt);
rewriter.replaceOpWithNewOp<stablehlo::ShiftRightArithmeticOp>(op, lhs, rhs);
return success();
}

View File

@ -221,32 +221,40 @@ FailureOr<Value> broadcastAndConcatIndices(Operation *op,
ConversionPatternRewriter &rewriter,
SmallVector<Value> indexTensors,
llvm::ArrayRef<int64_t> inputShape,
size_t dimSizeIndexBits,
int &maxIndexRank) {
// Step 1: broadcast indices tensors
SmallVector<int64_t> indicesShape;
SmallVector<int64_t> expandShape;
SmallVector<int64_t> concatShape;
bool allIndexStaticShape = true;
Value bcastSizeTensor;
// concat index tensor into to indices tensor for concat
for (size_t i = 0; i < indexTensors.size(); i++) {
auto indexTensor = indexTensors[i];
auto indexTensorType = cast<RankedTensorType>(indexTensor.getType());
for (int64_t size : makeShapeTorchCompatible(indexTensorType.getShape())) {
if (size == kUnknownSize)
return failure();
allIndexStaticShape = false;
}
maxIndexRank = std::max(maxIndexRank, (int)indexTensorType.getRank());
}
SmallVector<int64_t> refinedInputShape = makeShapeTorchCompatible(inputShape);
for (int64_t size : refinedInputShape) {
if (size == kUnknownSize) {
if (!allIndexStaticShape) {
auto bcastSizeTensorInfo = hlo::getBroadcastResultShape(
rewriter, op, indexTensors, dimSizeIndexBits);
if (failed(bcastSizeTensorInfo)) {
return failure();
}
bcastSizeTensor = *bcastSizeTensorInfo;
}
for (int i = 0; i < maxIndexRank; i++) {
indicesShape.push_back(refinedInputShape[i]);
expandShape.push_back(refinedInputShape[i]);
concatShape.push_back(refinedInputShape[i]);
indicesShape.push_back(inputShape[i]);
expandShape.push_back(inputShape[i]);
concatShape.push_back(inputShape[i]);
}
expandShape.push_back(1);
concatShape.push_back(indexTensors.size());
@ -256,12 +264,29 @@ FailureOr<Value> broadcastAndConcatIndices(Operation *op,
RankedTensorType bcastIndexType =
RankedTensorType::get(indicesShape, indexElemTy);
for (auto indexTensor : indexTensors) {
Value bcastVal =
hlo::promoteAndBroadcast(rewriter, indexTensor, bcastIndexType);
Value bcastVal;
RankedTensorType reshapeType =
RankedTensorType::get(expandShape, indexElemTy);
bcastVal = rewriter.create<stablehlo::ReshapeOp>(op->getLoc(), reshapeType,
bcastVal);
if (allIndexStaticShape) {
bcastVal = hlo::promoteAndBroadcast(rewriter, indexTensor, bcastIndexType,
std::nullopt);
bcastVal = rewriter.create<stablehlo::ReshapeOp>(op->getLoc(),
reshapeType, bcastVal);
} else {
bcastVal = hlo::promoteAndBroadcast(rewriter, indexTensor, bcastIndexType,
bcastSizeTensor);
auto bcastValShapeTensorVec =
*hlo::getDimSizesOfTensor(rewriter, op, bcastVal, dimSizeIndexBits);
bcastValShapeTensorVec.push_back(rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(), rewriter.getIntegerAttr(
rewriter.getIntegerType(dimSizeIndexBits), 1)));
Value bcastValShapeTensor = rewriter
.create<tensor::FromElementsOp>(
op->getLoc(), bcastValShapeTensorVec)
.getResult();
bcastVal = rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), reshapeType, bcastVal, bcastValShapeTensor);
}
broadcastedIndices.push_back(bcastVal);
}
@ -797,8 +822,9 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
indicesTorchType);
int maxIndexRank = -1;
auto gatherIndicesInfo = broadcastAndConcatIndices(op, rewriter, indexTensors,
outShape, maxIndexRank);
auto gatherIndicesInfo =
broadcastAndConcatIndices(op, rewriter, indexTensors, outShape,
options.dimSizeIndexBits, maxIndexRank);
if (failed(gatherIndicesInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to generate broadcasted indices");
@ -874,8 +900,9 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
indicesTorchType);
int maxIndexRank = -1;
auto scatterIndicesInfo = broadcastAndConcatIndices(
op, rewriter, indexTensors, valuesShape, maxIndexRank);
auto scatterIndicesInfo =
broadcastAndConcatIndices(op, rewriter, indexTensors, valuesShape,
options.dimSizeIndexBits, maxIndexRank);
if (failed(scatterIndicesInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to generate broadcasted indices");
@ -1109,7 +1136,8 @@ SmallVector<Value> clip(ConversionPatternRewriter &rewriter, Operation *op,
Value getSummand(ConversionPatternRewriter &rewriter, Operation *op,
Value input, Value ix, Value iy, Value w, int64_t N,
int64_t oH, int64_t oW, int64_t iH, int64_t iW, Value Nidx,
Value CIdx, RankedTensorType outType, Type elemTy) {
Value CIdx, RankedTensorType outType, Type elemTy,
size_t dimSizeIndexBits) {
Location loc = op->getLoc();
auto inputTensorType = cast<RankedTensorType>(input.getType());
SmallVector<Value> clipValues =
@ -1120,9 +1148,9 @@ Value getSummand(ConversionPatternRewriter &rewriter, Operation *op,
SmallVector<Value> indexTensors{Nidx, CIdx, idxY, idxX};
int maxIndexRank = -1;
auto gatherIndicesInfo =
broadcastAndConcatIndices(input.getDefiningOp(), rewriter, indexTensors,
outType.getShape(), maxIndexRank);
auto gatherIndicesInfo = broadcastAndConcatIndices(
input.getDefiningOp(), rewriter, indexTensors, outType.getShape(),
dimSizeIndexBits, maxIndexRank);
auto gatherIndices = *gatherIndicesInfo;
int64_t numIndicesDim = indexTensors.size();
int64_t indexVecDim = maxIndexRank;
@ -1310,14 +1338,18 @@ LogicalResult ConvertAtenOp<AtenGridSamplerOp>::matchAndRewrite(
rewriter.create<chlo::BroadcastSubOp>(loc, iy, iy_nw, bcastDimensions),
bcastDimensions);
Value summand_nw = getSummand(rewriter, op, input, ix_nw, iy_nw, w_nw, N,
oH, oW, iH, iW, Nidx, Cidx, outTy, elemTy);
Value summand_ne = getSummand(rewriter, op, input, ix_ne, iy_ne, w_ne, N,
oH, oW, iH, iW, Nidx, Cidx, outTy, elemTy);
Value summand_sw = getSummand(rewriter, op, input, ix_sw, iy_sw, w_sw, N,
oH, oW, iH, iW, Nidx, Cidx, outTy, elemTy);
Value summand_se = getSummand(rewriter, op, input, ix_se, iy_se, w_se, N,
oH, oW, iH, iW, Nidx, Cidx, outTy, elemTy);
Value summand_nw =
getSummand(rewriter, op, input, ix_nw, iy_nw, w_nw, N, oH, oW, iH, iW,
Nidx, Cidx, outTy, elemTy, options.dimSizeIndexBits);
Value summand_ne =
getSummand(rewriter, op, input, ix_ne, iy_ne, w_ne, N, oH, oW, iH, iW,
Nidx, Cidx, outTy, elemTy, options.dimSizeIndexBits);
Value summand_sw =
getSummand(rewriter, op, input, ix_sw, iy_sw, w_sw, N, oH, oW, iH, iW,
Nidx, Cidx, outTy, elemTy, options.dimSizeIndexBits);
Value summand_se =
getSummand(rewriter, op, input, ix_se, iy_se, w_se, N, oH, oW, iH, iW,
Nidx, Cidx, outTy, elemTy, options.dimSizeIndexBits);
// summand_nw + summand_ne + summand_sw + summand_se
Value sum = rewriter.create<stablehlo::AddOp>(loc, summand_nw, summand_ne);
@ -1332,9 +1364,9 @@ LogicalResult ConvertAtenOp<AtenGridSamplerOp>::matchAndRewrite(
Value ix_round = rewriter.create<stablehlo::RoundOp>(loc, ix);
Value iy_round = rewriter.create<stablehlo::RoundOp>(loc, iy);
Value oneTensor = getConstantLike(rewriter, loc, 1.0, ix_round);
Value summand =
getSummand(rewriter, op, input, ix_round, iy_round, oneTensor, N, oH,
oW, iH, iW, Nidx, Cidx, outTy, elemTy);
Value summand = getSummand(rewriter, op, input, ix_round, iy_round,
oneTensor, N, oH, oW, iH, iW, Nidx, Cidx, outTy,
elemTy, options.dimSizeIndexBits);
rewriter.replaceOp(op, summand);
}
return success();

View File

@ -179,12 +179,15 @@ Value promoteType(PatternRewriter &rewriter, Location loc, Value input,
}
Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
TensorType outType) {
TensorType outType,
std::optional<Value> bcastSizeTensor) {
// Two tensors are “broadcastable” if the following rules hold:
// - Each tensor has at least one dimension.
// - When iterating over the dimension sizes, starting at the trailing
// dimension, the dimension sizes must either be equal, one of them is 1, or
// one of them does not exist.
// If one provide bcastSizeTensor, we emit stablehlo::DynamicBroadcastInDimOp
// instead of stablehlo::BroadcastInDimOp to support dynamic shape.
Operation *op = input.getDefiningOp();
TensorType in_type = dyn_cast<TensorType>(input.getType());
@ -222,6 +225,11 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
return input;
}
auto bcast_attr = rewriter.getDenseI64ArrayAttr(bcastDims);
if (bcastSizeTensor.has_value()) {
auto bcast_op = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
op->getLoc(), outType, input, bcastSizeTensor.value(), bcast_attr);
return bcast_op.getResult();
}
auto bcast_op = rewriter.create<stablehlo::BroadcastInDimOp>(
op->getLoc(), outType, input, bcast_attr);
return bcast_op.getResult();
@ -314,6 +322,81 @@ getDimIndexOfTensor(PatternRewriter &rewriter, Operation *op, Value value) {
return getDimIndexOfTensor(rewriter, op, value, dims);
}
FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
Operation *op, ArrayRef<Value> tensors,
size_t dimSizeIndexBits) {
SmallVector<ArrayRef<int64_t>> tensorSizes;
int maxRank = 0;
for (auto tensor : tensors) {
auto tensorType = cast<RankedTensorType>(tensor.getType());
auto tensorRank = tensorType.getRank();
tensorSizes.emplace_back(tensorType.getShape());
maxRank = std::max(maxRank, static_cast<int>(tensorRank));
}
SmallVector<Value> bcastSizeTensors;
for (int outDim = 0; outDim < maxRank; ++outDim) { // loop dimensions.
int dynamicDimCnt = 0;
int staticDimCnt = 0;
int64_t staticDimSize;
Value dimSizeTensor = rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(),
rewriter.getIntegerAttr(rewriter.getIntegerType(dimSizeIndexBits), 1));
for (size_t i = 0; i < tensorSizes.size(); ++i) { // loop tensors.
int inDim = tensorSizes[i].size() - 1 - outDim;
if (inDim < 0)
continue;
// dim size: 1
if (tensorSizes[i][inDim] == 1)
continue;
// dim size: dynamic
if (tensorSizes[i][inDim] == ShapedType::kDynamic ||
tensorSizes[i][inDim] == kUnknownSize) {
dynamicDimCnt++;
auto dimSizeTensorInfo = hlo::getDimSizesOfTensor(
rewriter, op, tensors[i], {inDim}, dimSizeIndexBits);
if (failed(dimSizeTensorInfo)) {
return failure();
}
dimSizeTensor = (*dimSizeTensorInfo)[0];
continue;
}
// dim size: static
// we already found dynamic dim size, fail.
if (dynamicDimCnt > 0) {
return failure();
}
// we already found static dim size not equal with this, fail.
if (staticDimCnt > 0 && staticDimSize != tensorSizes[i][inDim]) {
return failure();
}
staticDimCnt++;
staticDimSize = tensorSizes[i][inDim];
auto dimSizeTensorInfo = hlo::getDimSizesOfTensor(
rewriter, op, tensors[i], {inDim}, dimSizeIndexBits);
if (failed(dimSizeTensorInfo)) {
return failure();
}
dimSizeTensor = (*dimSizeTensorInfo)[0];
}
// TODO: Relax this check, by assuming all dynamic shape is same.
// if (dynamicDimCnt > 1) {
// return failure();
// }
bcastSizeTensors.push_back(dimSizeTensor);
}
std::reverse(bcastSizeTensors.begin(), bcastSizeTensors.end());
return rewriter.create<tensor::FromElementsOp>(op->getLoc(), bcastSizeTensors)
.getResult();
}
FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
Value tensor,
ArrayRef<int64_t> inputUnsqzDims) {