mirror of https://github.com/llvm/torch-mlir
[stablehlo] fix: enhance torch's index-like op lowering to stablehlo's gather/scatter (#3829)
In torch.index_put-like ops, `values` is only required to be broadcastable to `input[indices]`, rather than to match its shape exactly. This patch fixes the problem by adding a stablehlo.dynamic_broadcast_in_dim before creating the stablehlo.scatter op. It also enhances the `getBroadcastResultShape` utility in the hlo namespace.
parent 4c1518d365
commit b75d0e3f8b
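For context, the semantics being fixed are visible at the PyTorch level: `values` only needs to broadcast to the shape of `input[indices]`. A minimal sketch (names and shapes chosen here for illustration, not taken from the commit):

    import torch

    # `values` need only be broadcastable to input[indices].shape; before this
    # patch the lowering did not insert the needed broadcast before
    # stablehlo.scatter.
    inp = torch.zeros(3, 4)
    rows = torch.tensor([0, 2])        # inp[rows] has shape (2, 4)
    values = torch.tensor([1.0])       # shape (1,), broadcastable to (2, 4)
    inp.index_put_((rows,), values)    # rows 0 and 2 become all ones
    print(inp)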
@@ -52,9 +52,9 @@ Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
 Value promoteType(PatternRewriter &rewriter, Location loc, Value input,
                   Type outElementType);
 
-FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
-                                         Operation *op, ArrayRef<Value> tensors,
-                                         size_t dimSizeIndexBits);
+FailureOr<std::pair<Value, SmallVector<int64_t>>>
+getBroadcastResultShape(PatternRewriter &rewriter, Operation *op,
+                        ArrayRef<Value> tensors, size_t dimSizeIndexBits);
 
 Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
                           TensorType outType,
@@ -220,16 +220,10 @@ namespace {
 FailureOr<Value> broadcastAndConcatIndices(Operation *op,
                                            ConversionPatternRewriter &rewriter,
                                            SmallVector<Value> indexTensors,
-                                           llvm::ArrayRef<int64_t> inputShape,
                                            size_t dimSizeIndexBits,
                                            int &maxIndexRank) {
   // Step 1: broadcast indices tensors
-  SmallVector<int64_t> indicesShape;
-  SmallVector<int64_t> expandShape;
-  SmallVector<int64_t> concatShape;
-
   bool allIndexStaticShape = true;
-  Value bcastSizeTensor;
 
   // concat index tensor into to indices tensor for concat
   for (size_t i = 0; i < indexTensors.size(); i++) {
@@ -242,20 +236,15 @@ FailureOr<Value> broadcastAndConcatIndices(Operation *op,
     maxIndexRank = std::max(maxIndexRank, (int)indexTensorType.getRank());
   }
 
-  if (!allIndexStaticShape) {
-    auto bcastSizeTensorInfo = hlo::getBroadcastResultShape(
-        rewriter, op, indexTensors, dimSizeIndexBits);
-    if (failed(bcastSizeTensorInfo)) {
-      return failure();
-    }
-    bcastSizeTensor = *bcastSizeTensorInfo;
-  }
-
-  for (int i = 0; i < maxIndexRank; i++) {
-    indicesShape.push_back(inputShape[i]);
-    expandShape.push_back(inputShape[i]);
-    concatShape.push_back(inputShape[i]);
-  }
+  auto bcastSizeInfo = hlo::getBroadcastResultShape(rewriter, op, indexTensors,
+                                                    dimSizeIndexBits);
+  if (failed(bcastSizeInfo)) {
+    return failure();
+  }
+  Value bcastSizeTensor = (*bcastSizeInfo).first;
+  auto indicesShape = (*bcastSizeInfo).second;
+  SmallVector<int64_t> expandShape(indicesShape.begin(), indicesShape.end());
+  SmallVector<int64_t> concatShape(indicesShape.begin(), indicesShape.end());
   expandShape.push_back(1);
   concatShape.push_back(indexTensors.size());
 
@@ -879,7 +868,6 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
   auto inputTensorType = cast<RankedTensorType>(input.getType());
   auto outType =
       cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
-  auto outShape = outType.getShape();
   Value indexList = op.getIndices();
   SmallVector<Value> indicesTorchType;
   if (!getListConstructElements(indexList, indicesTorchType))
@@ -890,9 +878,8 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
       indicesTorchType);
 
   int maxIndexRank = -1;
-  auto gatherIndicesInfo =
-      broadcastAndConcatIndices(op, rewriter, indexTensors, outShape,
-                                options.dimSizeIndexBits, maxIndexRank);
+  auto gatherIndicesInfo = broadcastAndConcatIndices(
+      op, rewriter, indexTensors, options.dimSizeIndexBits, maxIndexRank);
   if (failed(gatherIndicesInfo)) {
     return rewriter.notifyMatchFailure(
         op, "failed to generate broadcasted indices");
@@ -949,6 +936,8 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
   auto outType =
       cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
   auto inputType = cast<RankedTensorType>(input.getType());
+  auto inputShape = inputType.getShape();
+  auto inputRank = inputType.getRank();
   auto valuesType = cast<RankedTensorType>(values.getType());
   int64_t valueRank = valuesType.getRank();
   auto valuesShape = valuesType.getShape();
@@ -968,15 +957,58 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
       indicesTorchType);
 
   int maxIndexRank = -1;
-  auto scatterIndicesInfo =
-      broadcastAndConcatIndices(op, rewriter, indexTensors, valuesShape,
-                                options.dimSizeIndexBits, maxIndexRank);
+  auto scatterIndicesInfo = broadcastAndConcatIndices(
+      op, rewriter, indexTensors, options.dimSizeIndexBits, maxIndexRank);
   if (failed(scatterIndicesInfo)) {
     return rewriter.notifyMatchFailure(
         op, "failed to generate broadcasted indices");
   }
   auto scatterIndices = *scatterIndicesInfo;
 
+  // broadcast `values` tensor to match expectedValuesShape.
+  SmallVector<int64_t> scatterIndicesDims;
+  for (int64_t i = 0; i < maxIndexRank; ++i) {
+    scatterIndicesDims.push_back(i);
+  }
+  auto expectedValuesShapeTensorInfo =
+      hlo::getDimSizesOfTensor(rewriter, op, scatterIndices, scatterIndicesDims,
+                               options.dimSizeIndexBits);
+  if (failed(expectedValuesShapeTensorInfo)) {
+    return rewriter.notifyMatchFailure(
+        op, "failed to get shape of broadcasted indices");
+  }
+  auto expectedValuesShapeTensors = *expectedValuesShapeTensorInfo;
+  SmallVector<int64_t> trailingInputDims;
+  for (int64_t i = indexCnt; i < inputRank; ++i) {
+    trailingInputDims.push_back(i);
+  }
+  auto trailingInputShapeTensorInfo = hlo::getDimSizesOfTensor(
+      rewriter, op, input, trailingInputDims, options.dimSizeIndexBits);
+  if (failed(trailingInputShapeTensorInfo)) {
+    return rewriter.notifyMatchFailure(op, "failed to get shape of input");
+  }
+  expectedValuesShapeTensors.append((*trailingInputShapeTensorInfo).begin(),
+                                    (*trailingInputShapeTensorInfo).end());
+
+  llvm::ArrayRef<int64_t> scatterIndicesShape =
+      (cast<RankedTensorType>(scatterIndices.getType())).getShape();
+  SmallVector<int64_t> expectedValuesShape(
+      scatterIndicesShape.begin(), scatterIndicesShape.begin() + maxIndexRank);
+  for (int64_t i = indexCnt; i < inputRank; i++) {
+    expectedValuesShape.push_back(inputShape[i]);
+  }
+
+  valuesType =
+      RankedTensorType::get(expectedValuesShape, valuesType.getElementType());
+  values =
+      hlo::promoteAndBroadcast(rewriter, values, valuesType,
+                               rewriter
+                                   .create<tensor::FromElementsOp>(
+                                       op->getLoc(), expectedValuesShapeTensors)
+                                   .getResult());
+  valueRank = valuesType.getRank();
+  valuesShape = valuesType.getShape();
+
   // create stablehlo::ScatterOp
   int64_t indexVecDim = maxIndexRank;
   SmallVector<int64_t> scatterDimOperandDimMap;
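The block added above computes the expected `values` shape as the leading maxIndexRank dims of the broadcasted indices followed by the un-indexed trailing dims of the input, then broadcasts `values` to it. A rough Python rendering of just the shape rule, with hypothetical inputs (not code from the patch):

    # Hypothetical sketch of the expectedValuesShape rule used above.
    def expected_values_shape(scatter_indices_shape, input_shape,
                              max_index_rank, index_cnt):
        # Leading dims: the broadcast shape of the index tensors.
        shape = list(scatter_indices_shape[:max_index_rank])
        # Trailing dims: the input dims that are not indexed.
        shape += list(input_shape[index_cnt:])
        return shape

    # Indices broadcast to (2, 3); one index tensor applies to a (5, 4) input:
    print(expected_values_shape((2, 3, 1), (5, 4),
                                max_index_rank=2, index_cnt=1))  # [2, 3, 4]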
@@ -1216,9 +1248,9 @@ Value getSummand(ConversionPatternRewriter &rewriter, Operation *op,
   SmallVector<Value> indexTensors{Nidx, CIdx, idxY, idxX};
 
   int maxIndexRank = -1;
-  auto gatherIndicesInfo = broadcastAndConcatIndices(
-      input.getDefiningOp(), rewriter, indexTensors, outType.getShape(),
-      dimSizeIndexBits, maxIndexRank);
+  auto gatherIndicesInfo =
+      broadcastAndConcatIndices(input.getDefiningOp(), rewriter, indexTensors,
+                                dimSizeIndexBits, maxIndexRank);
   auto gatherIndices = *gatherIndicesInfo;
   int64_t numIndicesDim = indexTensors.size();
   int64_t indexVecDim = maxIndexRank;
@@ -322,9 +322,9 @@ getDimIndexOfTensor(PatternRewriter &rewriter, Operation *op, Value value) {
   return getDimIndexOfTensor(rewriter, op, value, dims);
 }
 
-FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
-                                         Operation *op, ArrayRef<Value> tensors,
-                                         size_t dimSizeIndexBits) {
+FailureOr<std::pair<Value, SmallVector<int64_t>>>
+getBroadcastResultShape(PatternRewriter &rewriter, Operation *op,
+                        ArrayRef<Value> tensors, size_t dimSizeIndexBits) {
   SmallVector<ArrayRef<int64_t>> tensorSizes;
 
   int maxRank = 0;
@@ -337,10 +337,11 @@ FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
   }
 
   SmallVector<Value> bcastSizeTensors;
+  SmallVector<int64_t> bcastSizes;
   for (int outDim = 0; outDim < maxRank; ++outDim) { // loop dimensions.
     int dynamicDimCnt = 0;
     int staticDimCnt = 0;
-    int64_t staticDimSize;
+    int64_t dimSize = -1;
     Value dimSizeTensor = rewriter.create<mlir::arith::ConstantOp>(
         op->getLoc(),
         rewriter.getIntegerAttr(rewriter.getIntegerType(dimSizeIndexBits), 1));
@@ -351,12 +352,16 @@ FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
         continue;
 
       // dim size: 1
-      if (tensorSizes[i][inDim] == 1)
+      if (tensorSizes[i][inDim] == 1) {
+        if (dimSize == -1)
+          dimSize = 1;
         continue;
+      }
       // dim size: dynamic
       if (tensorSizes[i][inDim] == ShapedType::kDynamic ||
           tensorSizes[i][inDim] == kUnknownSize) {
         dynamicDimCnt++;
+        dimSize = ShapedType::kDynamic;
         auto dimSizeTensorInfo = hlo::getDimSizesOfTensor(
             rewriter, op, tensors[i], {inDim}, dimSizeIndexBits);
         if (failed(dimSizeTensorInfo)) {
@@ -371,12 +376,12 @@ FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
         return failure();
       }
       // we already found static dim size not equal with this, fail.
-      if (staticDimCnt > 0 && staticDimSize != tensorSizes[i][inDim]) {
+      if (staticDimCnt > 0 && dimSize != tensorSizes[i][inDim]) {
        return failure();
       }
 
       staticDimCnt++;
-      staticDimSize = tensorSizes[i][inDim];
+      dimSize = tensorSizes[i][inDim];
       auto dimSizeTensorInfo = hlo::getDimSizesOfTensor(
           rewriter, op, tensors[i], {inDim}, dimSizeIndexBits);
       if (failed(dimSizeTensorInfo)) {
@@ -389,12 +394,15 @@ FailureOr<Value> getBroadcastResultShape(PatternRewriter &rewriter,
     // if (dynamicDimCnt > 1) {
     //   return failure();
     // }
+    bcastSizes.push_back(dimSize);
     bcastSizeTensors.push_back(dimSizeTensor);
   }
+  std::reverse(bcastSizes.begin(), bcastSizes.end());
   std::reverse(bcastSizeTensors.begin(), bcastSizeTensors.end());
-  return rewriter.create<tensor::FromElementsOp>(op->getLoc(), bcastSizeTensors)
-      .getResult();
+  return std::pair<Value, SmallVector<int64_t>>(
+      rewriter.create<tensor::FromElementsOp>(op->getLoc(), bcastSizeTensors)
+          .getResult(),
+      bcastSizes);
 }
 
 FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
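Per output dimension, the updated utility merges sizes right-to-left: size 1 yields, matching statics agree, conflicting statics fail, and any dynamic size makes the result dynamic. A rough Python model of that merge, using -1 where the C++ uses ShapedType::kDynamic (a sketch, not the library code):

    # Rough model of the per-dimension merge in getBroadcastResultShape.
    def broadcast_result_shape(shapes):
        max_rank = max(len(s) for s in shapes)
        out = []
        for out_dim in range(max_rank):         # walk dims right-to-left
            dim = 1
            for s in shapes:
                in_dim = len(s) - 1 - out_dim
                if in_dim < 0 or s[in_dim] == 1:
                    continue                    # size-1 dims always yield
                if s[in_dim] == -1:
                    dim = -1                    # dynamic dominates
                elif dim not in (1, -1) and dim != s[in_dim]:
                    raise ValueError("conflicting static dim sizes")
                elif dim != -1:
                    dim = s[in_dim]
            out.append(dim)
        return list(reversed(out))

    print(broadcast_result_shape([(2, 1, 4), (3, 1)]))  # [2, 3, 4]
    print(broadcast_result_shape([(-1, 4), (3, 1)]))    # [-1, 4]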
@@ -760,6 +760,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "IndexPutImpl3DFloatAccumulateModule_basic",
     "IndexPutImpl3DFloatNonAccumulateModule_basic",
     "IndexPutImplIndexWithNoneModule_basic",
+    "IndexPutWithNoneAndBroadcastModule_basic",
     "IndexSelectRank0IdxModule_basic",
     "IndexTensorNegativeIndexModule_basic",
     "IntFloatModule_basic",
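The name of the newly xfailed test suggests an index_put with a None index and broadcastable values; a hypothetical reproduction of that pattern (the real module lives in the torch-mlir e2e suite, this sketch only illustrates the shape combination):

    import torch

    # Hypothetical sketch: None index at dim 0 plus values that broadcast.
    x = torch.zeros(2, 3, 4)
    idx = torch.tensor([0, 2])
    vals = torch.tensor([5.0])   # broadcasts to x[:, idx].shape == (2, 2, 4)
    x[:, idx] = vals             # equivalent to index_put_ with (None, idx)
    print(x[:, idx])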