mirror of https://github.com/llvm/torch-mlir
[stablehlo]: fix aten.index_put_hacked_twin lowering to StableHlo (#3572)
Current StableHlo lowering strategy works well when `src` tensor's rank is no bigger than `dst` tensor's. The new patch make it succeed in other cases. The following is an example. ``` %190 = torch.prim.ListConstruct %arg4 : (!torch.vtensor<[1,1024],si64>) -> !torch.list<vtensor> %191 = torch.aten.index_put.hacked_twin %189, %190, %186, %true : !torch.vtensor<[1024,768],f32>, !torch.list<vtensor>, !torch.vtensor<[1,1024,768],f32>, !torch.bool -> !torch.vtensor<[1024,768],f32> ```pull/3576/head
parent
f49b9c14f1
commit
7b2902f6e2
|
@ -855,8 +855,8 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
|
|||
auto outType =
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
int64_t inputRank = inputType.getRank();
|
||||
auto valuesType = cast<RankedTensorType>(values.getType());
|
||||
int64_t valueRank = valuesType.getRank();
|
||||
auto valuesShape = valuesType.getShape();
|
||||
bool accumulate;
|
||||
if (!matchPattern(op.getAccumulate(), m_TorchConstantBool(&accumulate))) {
|
||||
|
@ -868,6 +868,7 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
|
|||
if (!getListConstructElements(indexList, indicesTorchType))
|
||||
return op.emitError(
|
||||
"unimplemented: the tensor list is not from list construct");
|
||||
int64_t indexCnt = indicesTorchType.size();
|
||||
|
||||
auto indexTensors = getTypeConvertedValues(rewriter, loc, getTypeConverter(),
|
||||
indicesTorchType);
|
||||
|
@ -886,11 +887,11 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
|
|||
SmallVector<int64_t> scatterDimOperandDimMap;
|
||||
SmallVector<int64_t> insertedWindowDims;
|
||||
SmallVector<int64_t> updateWindowDims;
|
||||
for (int64_t i = 0; i < maxIndexRank; ++i) {
|
||||
for (int64_t i = 0; i < indexCnt; ++i) {
|
||||
scatterDimOperandDimMap.push_back(i);
|
||||
insertedWindowDims.push_back(i);
|
||||
}
|
||||
for (int64_t i = maxIndexRank; i < inputRank; ++i) {
|
||||
for (int64_t i = maxIndexRank; i < valueRank; ++i) {
|
||||
updateWindowDims.push_back(i);
|
||||
}
|
||||
auto scatterDimensionNumbers = stablehlo::ScatterDimensionNumbersAttr::get(
|
||||
|
|
Loading…
Reference in New Issue