mirror of https://github.com/llvm/torch-mlir
[stablehlo] support dynamic-shaped index in stablehlo conversion for aten.index-like ops (#3322)
For now, at most one dynamic dim of the index tensors in aten.index/aten.index_put-like ops is supported.

Branch: pull/3516/head
Parent: 7f475e174e
Commit: edc87fc577
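In short, the change threads an optional runtime shape tensor through the broadcast helpers: when the index tensors have a dynamic dimension, a shape tensor is computed with the new getBroadcastResultShape and passed to promoteAndBroadcast, which then emits stablehlo.dynamic_broadcast_in_dim instead of stablehlo.broadcast_in_dim. The sketch below shows that wiring for a single index tensor. It is illustrative only; the wrapper name broadcastIndexDynamically and the include path are assumptions, not part of the commit, and the real call sites are in broadcastAndConcatIndices in the diff.

// Minimal sketch of the new dynamic-shape path, assuming the torch-mlir
// StableHLO legalization utilities are available under this header path.
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"

#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Hypothetical wrapper (not in the commit): broadcast one index tensor to the
// common shape of all index tensors, even if some extents are dynamic.
static FailureOr<Value>
broadcastIndexDynamically(ConversionPatternRewriter &rewriter, Operation *op,
                          ArrayRef<Value> indexTensors, Value index,
                          RankedTensorType bcastType, size_t dimSizeIndexBits) {
  // Materialize the broadcast result shape as a 1-D tensor of dim sizes.
  auto bcastSizeTensor = hlo::getBroadcastResultShape(rewriter, op, indexTensors,
                                                      dimSizeIndexBits);
  if (failed(bcastSizeTensor))
    return failure();
  // With a shape tensor supplied, promoteAndBroadcast emits
  // stablehlo::DynamicBroadcastInDimOp; with std::nullopt it keeps the old
  // static BroadcastInDimOp behavior.
  return hlo::promoteAndBroadcast(rewriter, index, bcastType, *bcastSizeTensor);
}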
@@ -52,8 +52,13 @@ Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
 Value promoteType(PatternRewriter &rewriter, Location loc, Value input,
                   Type outElementType);
 
+FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
+                                         Operation *op, ArrayRef<Value> tensors,
+                                         size_t dimSizeIndexBits);
+
 Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
-                          TensorType outType);
+                          TensorType outType,
+                          std::optional<Value> bcastSizeTensor);
 
 SmallVector<int64_t> toPositiveDims(ArrayRef<int64_t> dims, int64_t rank);
@@ -768,7 +768,8 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
       getTypeConverter()->convertType(op->getResult(0).getType()));
 
   if (options.enableStaticShape && selfTy.hasStaticShape()) {
-    Value bcastOp = hlo::promoteAndBroadcast(rewriter, self, outType);
+    Value bcastOp =
+        hlo::promoteAndBroadcast(rewriter, self, outType, std::nullopt);
     rewriter.replaceOp(op, bcastOp);
     return success();
   }

@@ -1488,8 +1489,10 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
           .value());
 
   // Apply affine transform: output x weight + bias [element-wise]
-  auto bcastedWeight = hlo::promoteAndBroadcast(rewriter, weight, outputTy);
-  auto bcastedBias = hlo::promoteAndBroadcast(rewriter, bias, outputTy);
+  auto bcastedWeight =
+      hlo::promoteAndBroadcast(rewriter, weight, outputTy, std::nullopt);
+  auto bcastedBias =
+      hlo::promoteAndBroadcast(rewriter, bias, outputTy, std::nullopt);
   auto outputMulWeight =
       rewriter.create<stablehlo::MulOp>(op->getLoc(), output, bcastedWeight);
   auto finalOuput = rewriter.create<stablehlo::AddOp>(

@@ -1634,8 +1637,10 @@ LogicalResult ConvertAtenOp<AtenClampTensorOp>::matchAndRewrite(
     maxValue = *maxInfo;
   }
   if (inputType.hasStaticShape()) {
-    minValue = hlo::promoteAndBroadcast(rewriter, minValue, inputType);
-    maxValue = hlo::promoteAndBroadcast(rewriter, maxValue, inputType);
+    minValue =
+        hlo::promoteAndBroadcast(rewriter, minValue, inputType, std::nullopt);
+    maxValue =
+        hlo::promoteAndBroadcast(rewriter, maxValue, inputType, std::nullopt);
   }
   rewriter.replaceOpWithNewOp<stablehlo::ClampOp>(op, minValue, input,
                                                   maxValue);

@@ -2021,7 +2026,7 @@ LogicalResult ConvertAtenOp<AtenBitwiseLeftShiftTensorOp>::matchAndRewrite(
 
   auto resultType =
       cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
-  rhs = hlo::promoteAndBroadcast(rewriter, rhs, resultType);
+  rhs = hlo::promoteAndBroadcast(rewriter, rhs, resultType, std::nullopt);
   rewriter.replaceOpWithNewOp<stablehlo::ShiftLeftOp>(op, lhs, rhs);
   return success();
 }

@@ -2036,7 +2041,7 @@ LogicalResult ConvertAtenOp<AtenBitwiseRightShiftTensorOp>::matchAndRewrite(
 
   auto resultType =
       cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
-  rhs = hlo::promoteAndBroadcast(rewriter, rhs, resultType);
+  rhs = hlo::promoteAndBroadcast(rewriter, rhs, resultType, std::nullopt);
   rewriter.replaceOpWithNewOp<stablehlo::ShiftRightArithmeticOp>(op, lhs, rhs);
   return success();
 }
@@ -221,32 +221,40 @@ FailureOr<Value> broadcastAndConcatIndices(Operation *op,
                                            ConversionPatternRewriter &rewriter,
                                            SmallVector<Value> indexTensors,
                                            llvm::ArrayRef<int64_t> inputShape,
+                                           size_t dimSizeIndexBits,
                                            int &maxIndexRank) {
   // Step 1: broadcast indices tensors
   SmallVector<int64_t> indicesShape;
   SmallVector<int64_t> expandShape;
   SmallVector<int64_t> concatShape;
 
+  bool allIndexStaticShape = true;
+  Value bcastSizeTensor;
+
   // concat index tensor into to indices tensor for concat
   for (size_t i = 0; i < indexTensors.size(); i++) {
     auto indexTensor = indexTensors[i];
     auto indexTensorType = cast<RankedTensorType>(indexTensor.getType());
     for (int64_t size : makeShapeTorchCompatible(indexTensorType.getShape())) {
       if (size == kUnknownSize)
-        return failure();
+        allIndexStaticShape = false;
     }
     maxIndexRank = std::max(maxIndexRank, (int)indexTensorType.getRank());
   }
 
-  SmallVector<int64_t> refinedInputShape = makeShapeTorchCompatible(inputShape);
-  for (int64_t size : refinedInputShape) {
-    if (size == kUnknownSize) {
+  if (!allIndexStaticShape) {
+    auto bcastSizeTensorInfo = hlo::getBroadcastResultShape(
+        rewriter, op, indexTensors, dimSizeIndexBits);
+    if (failed(bcastSizeTensorInfo)) {
       return failure();
     }
+    bcastSizeTensor = *bcastSizeTensorInfo;
   }
 
   for (int i = 0; i < maxIndexRank; i++) {
-    indicesShape.push_back(refinedInputShape[i]);
-    expandShape.push_back(refinedInputShape[i]);
-    concatShape.push_back(refinedInputShape[i]);
+    indicesShape.push_back(inputShape[i]);
+    expandShape.push_back(inputShape[i]);
+    concatShape.push_back(inputShape[i]);
   }
   expandShape.push_back(1);
   concatShape.push_back(indexTensors.size());

@@ -256,12 +264,29 @@ FailureOr<Value> broadcastAndConcatIndices(Operation *op,
   RankedTensorType bcastIndexType =
       RankedTensorType::get(indicesShape, indexElemTy);
   for (auto indexTensor : indexTensors) {
-    Value bcastVal =
-        hlo::promoteAndBroadcast(rewriter, indexTensor, bcastIndexType);
+    Value bcastVal;
     RankedTensorType reshapeType =
         RankedTensorType::get(expandShape, indexElemTy);
-    bcastVal = rewriter.create<stablehlo::ReshapeOp>(op->getLoc(), reshapeType,
-                                                     bcastVal);
+    if (allIndexStaticShape) {
+      bcastVal = hlo::promoteAndBroadcast(rewriter, indexTensor, bcastIndexType,
+                                          std::nullopt);
+      bcastVal = rewriter.create<stablehlo::ReshapeOp>(op->getLoc(),
+                                                       reshapeType, bcastVal);
+    } else {
+      bcastVal = hlo::promoteAndBroadcast(rewriter, indexTensor, bcastIndexType,
+                                          bcastSizeTensor);
+      auto bcastValShapeTensorVec =
+          *hlo::getDimSizesOfTensor(rewriter, op, bcastVal, dimSizeIndexBits);
+      bcastValShapeTensorVec.push_back(rewriter.create<mlir::arith::ConstantOp>(
+          op->getLoc(), rewriter.getIntegerAttr(
+                            rewriter.getIntegerType(dimSizeIndexBits), 1)));
+      Value bcastValShapeTensor = rewriter
+                                      .create<tensor::FromElementsOp>(
+                                          op->getLoc(), bcastValShapeTensorVec)
+                                      .getResult();
+      bcastVal = rewriter.create<stablehlo::DynamicReshapeOp>(
+          op->getLoc(), reshapeType, bcastVal, bcastValShapeTensor);
+    }
     broadcastedIndices.push_back(bcastVal);
   }
 

@@ -797,8 +822,9 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
                                      indicesTorchType);
 
   int maxIndexRank = -1;
-  auto gatherIndicesInfo = broadcastAndConcatIndices(op, rewriter, indexTensors,
-                                                     outShape, maxIndexRank);
+  auto gatherIndicesInfo =
+      broadcastAndConcatIndices(op, rewriter, indexTensors, outShape,
+                                options.dimSizeIndexBits, maxIndexRank);
   if (failed(gatherIndicesInfo)) {
     return rewriter.notifyMatchFailure(
         op, "failed to generate broadcasted indices");

@@ -874,8 +900,9 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
                                      indicesTorchType);
 
   int maxIndexRank = -1;
-  auto scatterIndicesInfo = broadcastAndConcatIndices(
-      op, rewriter, indexTensors, valuesShape, maxIndexRank);
+  auto scatterIndicesInfo =
+      broadcastAndConcatIndices(op, rewriter, indexTensors, valuesShape,
+                                options.dimSizeIndexBits, maxIndexRank);
   if (failed(scatterIndicesInfo)) {
     return rewriter.notifyMatchFailure(
         op, "failed to generate broadcasted indices");

@@ -1109,7 +1136,8 @@ SmallVector<Value> clip(ConversionPatternRewriter &rewriter, Operation *op,
 Value getSummand(ConversionPatternRewriter &rewriter, Operation *op,
                  Value input, Value ix, Value iy, Value w, int64_t N,
                  int64_t oH, int64_t oW, int64_t iH, int64_t iW, Value Nidx,
-                 Value CIdx, RankedTensorType outType, Type elemTy) {
+                 Value CIdx, RankedTensorType outType, Type elemTy,
+                 size_t dimSizeIndexBits) {
   Location loc = op->getLoc();
   auto inputTensorType = cast<RankedTensorType>(input.getType());
   SmallVector<Value> clipValues =

@@ -1120,9 +1148,9 @@ Value getSummand(ConversionPatternRewriter &rewriter, Operation *op,
   SmallVector<Value> indexTensors{Nidx, CIdx, idxY, idxX};
 
   int maxIndexRank = -1;
-  auto gatherIndicesInfo =
-      broadcastAndConcatIndices(input.getDefiningOp(), rewriter, indexTensors,
-                                outType.getShape(), maxIndexRank);
+  auto gatherIndicesInfo = broadcastAndConcatIndices(
+      input.getDefiningOp(), rewriter, indexTensors, outType.getShape(),
+      dimSizeIndexBits, maxIndexRank);
   auto gatherIndices = *gatherIndicesInfo;
   int64_t numIndicesDim = indexTensors.size();
   int64_t indexVecDim = maxIndexRank;

@@ -1310,14 +1338,18 @@ LogicalResult ConvertAtenOp<AtenGridSamplerOp>::matchAndRewrite(
         rewriter.create<chlo::BroadcastSubOp>(loc, iy, iy_nw, bcastDimensions),
         bcastDimensions);
 
-    Value summand_nw = getSummand(rewriter, op, input, ix_nw, iy_nw, w_nw, N,
-                                  oH, oW, iH, iW, Nidx, Cidx, outTy, elemTy);
-    Value summand_ne = getSummand(rewriter, op, input, ix_ne, iy_ne, w_ne, N,
-                                  oH, oW, iH, iW, Nidx, Cidx, outTy, elemTy);
-    Value summand_sw = getSummand(rewriter, op, input, ix_sw, iy_sw, w_sw, N,
-                                  oH, oW, iH, iW, Nidx, Cidx, outTy, elemTy);
-    Value summand_se = getSummand(rewriter, op, input, ix_se, iy_se, w_se, N,
-                                  oH, oW, iH, iW, Nidx, Cidx, outTy, elemTy);
+    Value summand_nw =
+        getSummand(rewriter, op, input, ix_nw, iy_nw, w_nw, N, oH, oW, iH, iW,
+                   Nidx, Cidx, outTy, elemTy, options.dimSizeIndexBits);
+    Value summand_ne =
+        getSummand(rewriter, op, input, ix_ne, iy_ne, w_ne, N, oH, oW, iH, iW,
+                   Nidx, Cidx, outTy, elemTy, options.dimSizeIndexBits);
+    Value summand_sw =
+        getSummand(rewriter, op, input, ix_sw, iy_sw, w_sw, N, oH, oW, iH, iW,
+                   Nidx, Cidx, outTy, elemTy, options.dimSizeIndexBits);
+    Value summand_se =
+        getSummand(rewriter, op, input, ix_se, iy_se, w_se, N, oH, oW, iH, iW,
+                   Nidx, Cidx, outTy, elemTy, options.dimSizeIndexBits);
 
     // summand_nw + summand_ne + summand_sw + summand_se
     Value sum = rewriter.create<stablehlo::AddOp>(loc, summand_nw, summand_ne);

@@ -1332,9 +1364,9 @@ LogicalResult ConvertAtenOp<AtenGridSamplerOp>::matchAndRewrite(
     Value ix_round = rewriter.create<stablehlo::RoundOp>(loc, ix);
     Value iy_round = rewriter.create<stablehlo::RoundOp>(loc, iy);
     Value oneTensor = getConstantLike(rewriter, loc, 1.0, ix_round);
-    Value summand =
-        getSummand(rewriter, op, input, ix_round, iy_round, oneTensor, N, oH,
-                   oW, iH, iW, Nidx, Cidx, outTy, elemTy);
+    Value summand = getSummand(rewriter, op, input, ix_round, iy_round,
+                               oneTensor, N, oH, oW, iH, iW, Nidx, Cidx, outTy,
+                               elemTy, options.dimSizeIndexBits);
     rewriter.replaceOp(op, summand);
   }
   return success();
@@ -179,12 +179,15 @@ Value promoteType(PatternRewriter &rewriter, Location loc, Value input,
 }
 
 Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
-                          TensorType outType) {
+                          TensorType outType,
+                          std::optional<Value> bcastSizeTensor) {
   // Two tensors are “broadcastable” if the following rules hold:
   //   - Each tensor has at least one dimension.
   //   - When iterating over the dimension sizes, starting at the trailing
   //   dimension, the dimension sizes must either be equal, one of them is 1, or
   //   one of them does not exist.
+  // If one provide bcastSizeTensor, we emit stablehlo::DynamicBroadcastInDimOp
+  // instead of stablehlo::BroadcastInDimOp to support dynamic shape.
   Operation *op = input.getDefiningOp();
   TensorType in_type = dyn_cast<TensorType>(input.getType());
 

@@ -222,6 +225,11 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
     return input;
   }
   auto bcast_attr = rewriter.getDenseI64ArrayAttr(bcastDims);
+  if (bcastSizeTensor.has_value()) {
+    auto bcast_op = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
+        op->getLoc(), outType, input, bcastSizeTensor.value(), bcast_attr);
+    return bcast_op.getResult();
+  }
   auto bcast_op = rewriter.create<stablehlo::BroadcastInDimOp>(
       op->getLoc(), outType, input, bcast_attr);
   return bcast_op.getResult();

@@ -314,6 +322,81 @@ getDimIndexOfTensor(PatternRewriter &rewriter, Operation *op, Value value) {
   return getDimIndexOfTensor(rewriter, op, value, dims);
 }
 
+FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
+                                         Operation *op, ArrayRef<Value> tensors,
+                                         size_t dimSizeIndexBits) {
+  SmallVector<ArrayRef<int64_t>> tensorSizes;
+
+  int maxRank = 0;
+  for (auto tensor : tensors) {
+    auto tensorType = cast<RankedTensorType>(tensor.getType());
+    auto tensorRank = tensorType.getRank();
+
+    tensorSizes.emplace_back(tensorType.getShape());
+    maxRank = std::max(maxRank, static_cast<int>(tensorRank));
+  }
+
+  SmallVector<Value> bcastSizeTensors;
+  for (int outDim = 0; outDim < maxRank; ++outDim) { // loop dimensions.
+    int dynamicDimCnt = 0;
+    int staticDimCnt = 0;
+    int64_t staticDimSize;
+    Value dimSizeTensor = rewriter.create<mlir::arith::ConstantOp>(
+        op->getLoc(),
+        rewriter.getIntegerAttr(rewriter.getIntegerType(dimSizeIndexBits), 1));
+
+    for (size_t i = 0; i < tensorSizes.size(); ++i) { // loop tensors.
+      int inDim = tensorSizes[i].size() - 1 - outDim;
+      if (inDim < 0)
+        continue;
+
+      // dim size: 1
+      if (tensorSizes[i][inDim] == 1)
+        continue;
+      // dim size: dynamic
+      if (tensorSizes[i][inDim] == ShapedType::kDynamic ||
+          tensorSizes[i][inDim] == kUnknownSize) {
+        dynamicDimCnt++;
+        auto dimSizeTensorInfo = hlo::getDimSizesOfTensor(
+            rewriter, op, tensors[i], {inDim}, dimSizeIndexBits);
+        if (failed(dimSizeTensorInfo)) {
+          return failure();
+        }
+        dimSizeTensor = (*dimSizeTensorInfo)[0];
+        continue;
+      }
+      // dim size: static
+      // we already found dynamic dim size, fail.
+      if (dynamicDimCnt > 0) {
+        return failure();
+      }
+      // we already found static dim size not equal with this, fail.
+      if (staticDimCnt > 0 && staticDimSize != tensorSizes[i][inDim]) {
+        return failure();
+      }
+
+      staticDimCnt++;
+      staticDimSize = tensorSizes[i][inDim];
+      auto dimSizeTensorInfo = hlo::getDimSizesOfTensor(
+          rewriter, op, tensors[i], {inDim}, dimSizeIndexBits);
+      if (failed(dimSizeTensorInfo)) {
+        return failure();
+      }
+      dimSizeTensor = (*dimSizeTensorInfo)[0];
+    }
+
+    // TODO: Relax this check, by assuming all dynamic shape is same.
+    // if (dynamicDimCnt > 1) {
+    //   return failure();
+    // }
+
+    bcastSizeTensors.push_back(dimSizeTensor);
+  }
+  std::reverse(bcastSizeTensors.begin(), bcastSizeTensors.end());
+  return rewriter.create<tensor::FromElementsOp>(op->getLoc(), bcastSizeTensors)
+      .getResult();
+}
+
 FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
                                  Value tensor,
                                  ArrayRef<int64_t> inputUnsqzDims) {
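For reference, getBroadcastResultShape walks the output dimensions from the trailing side and, for each one, skips extents of 1, takes the runtime size of a dynamic extent, and requires any remaining static extents to agree. The fragment below restates that per-dimension rule on static extents only as a self-contained sketch; it is a simplification for illustration (the real helper materializes dynamic extents as SSA values, and its treatment of mixed dynamic/static extents differs in detail), not code from the commit.

#include <algorithm>
#include <cstdint>
#include <optional>
#include <vector>

// Stand-in for an unknown (dynamic) extent, mirroring ShapedType::kDynamic.
constexpr int64_t kDynamicSize = -1;

// Combine two extents of the same output dimension under broadcast rules.
// Returns std::nullopt when the extents cannot be broadcast together.
std::optional<int64_t> combineExtents(int64_t a, int64_t b) {
  if (a == 1)
    return b; // size-1 extents never constrain the result
  if (b == 1)
    return a;
  if (a == kDynamicSize && b == kDynamicSize)
    return kDynamicSize; // both unknown: resolved at runtime
  if (a == kDynamicSize || b == kDynamicSize)
    return std::nullopt; // mixing unknown with a static size > 1 is rejected here
  if (a == b)
    return a; // equal static sizes
  return std::nullopt; // conflicting static sizes
}

// Right-aligned broadcast of two shapes, as used for the index tensors.
// e.g. broadcastShapes({-1, 1}, {1, 4}) yields {-1, 4}.
std::optional<std::vector<int64_t>>
broadcastShapes(const std::vector<int64_t> &lhs,
                const std::vector<int64_t> &rhs) {
  size_t rank = std::max(lhs.size(), rhs.size());
  std::vector<int64_t> result(rank);
  for (size_t i = 0; i < rank; ++i) {
    // Align from the trailing dimension; missing dims behave like size 1.
    int64_t a = i < lhs.size() ? lhs[lhs.size() - 1 - i] : 1;
    int64_t b = i < rhs.size() ? rhs[rhs.size() - 1 - i] : 1;
    auto combined = combineExtents(a, b);
    if (!combined)
      return std::nullopt;
    result[rank - 1 - i] = *combined;
  }
  return result;
}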