[Torch] support adaptive_max_pool1d when return_indices equals False (#3783)

byteir
yyp0 2024-10-11 23:42:15 +08:00 committed by Yuanqiang Liu
parent 52ecff831b
commit fa26bfc0d6
1 changed files with 13 additions and 4 deletions

View File

@ -7117,10 +7117,19 @@ class DecomposeAtenAdaptiveMaxPool1dOp
loc, Torch::ListType::get(Torch::IntType::get(context)),
ValueRange{constantOne});
rewriter.replaceOpWithNewOp<AtenMaxPool1dWithIndicesOp>(
op, op.getType(0), op.getType(1), input, kernelSizeList, strideList,
paddingSizeList, dialationList,
/*ceil_mode=*/constantFalse);
if (op.getResult(1).use_empty()) {
auto maxPool = rewriter.create<AtenMaxPool1dOp>(
loc, op.getType(0), input, kernelSizeList, strideList,
paddingSizeList, dialationList,
/*ceil_mode=*/constantFalse);
rewriter.replaceOp(op, {maxPool.getResult(), Value()});
} else {
auto maxPool = rewriter.create<AtenMaxPool1dWithIndicesOp>(
loc, op.getType(0), op.getType(1), input, kernelSizeList, strideList,
paddingSizeList, dialationList,
/*ceil_mode=*/constantFalse);
rewriter.replaceOp(op, maxPool.getResults());
}
return success();
}
};