[stablehlo]: fix aten.index_put_hacked_twin lowering to StableHlo (#3572)

Current StableHlo lowering strategy works well when `src` tensor's rank
is no bigger than `dst` tensor's. The new patch make it succeed in other
cases. The following is an example.
```
%190 = torch.prim.ListConstruct %arg4 : (!torch.vtensor<[1,1024],si64>) -> !torch.list<vtensor>
%191 = torch.aten.index_put.hacked_twin %189, %190, %186, %true : !torch.vtensor<[1024,768],f32>, !torch.list<vtensor>, !torch.vtensor<[1,1024,768],f32>, !torch.bool -> !torch.vtensor<[1024,768],f32>
```
pull/3576/head
Jiawei Wu 2024-07-31 22:33:57 +08:00 committed by GitHub
parent f49b9c14f1
commit 7b2902f6e2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 4 additions and 3 deletions

View File

@ -855,8 +855,8 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
auto outType =
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
auto inputType = cast<RankedTensorType>(input.getType());
int64_t inputRank = inputType.getRank();
auto valuesType = cast<RankedTensorType>(values.getType());
int64_t valueRank = valuesType.getRank();
auto valuesShape = valuesType.getShape();
bool accumulate;
if (!matchPattern(op.getAccumulate(), m_TorchConstantBool(&accumulate))) {
@ -868,6 +868,7 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
if (!getListConstructElements(indexList, indicesTorchType))
return op.emitError(
"unimplemented: the tensor list is not from list construct");
int64_t indexCnt = indicesTorchType.size();
auto indexTensors = getTypeConvertedValues(rewriter, loc, getTypeConverter(),
indicesTorchType);
@ -886,11 +887,11 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
SmallVector<int64_t> scatterDimOperandDimMap;
SmallVector<int64_t> insertedWindowDims;
SmallVector<int64_t> updateWindowDims;
for (int64_t i = 0; i < maxIndexRank; ++i) {
for (int64_t i = 0; i < indexCnt; ++i) {
scatterDimOperandDimMap.push_back(i);
insertedWindowDims.push_back(i);
}
for (int64_t i = maxIndexRank; i < inputRank; ++i) {
for (int64_t i = maxIndexRank; i < valueRank; ++i) {
updateWindowDims.push_back(i);
}
auto scatterDimensionNumbers = stablehlo::ScatterDimensionNumbersAttr::get(