mirror of https://github.com/llvm/torch-mlir
[Torch] support adaptive_max_pool1d when return_indices equals False (#3783)
parent
52ecff831b
commit
fa26bfc0d6
|
@ -7117,10 +7117,19 @@ class DecomposeAtenAdaptiveMaxPool1dOp
|
|||
loc, Torch::ListType::get(Torch::IntType::get(context)),
|
||||
ValueRange{constantOne});
|
||||
|
||||
rewriter.replaceOpWithNewOp<AtenMaxPool1dWithIndicesOp>(
|
||||
op, op.getType(0), op.getType(1), input, kernelSizeList, strideList,
|
||||
if (op.getResult(1).use_empty()) {
|
||||
auto maxPool = rewriter.create<AtenMaxPool1dOp>(
|
||||
loc, op.getType(0), input, kernelSizeList, strideList,
|
||||
paddingSizeList, dialationList,
|
||||
/*ceil_mode=*/constantFalse);
|
||||
rewriter.replaceOp(op, {maxPool.getResult(), Value()});
|
||||
} else {
|
||||
auto maxPool = rewriter.create<AtenMaxPool1dWithIndicesOp>(
|
||||
loc, op.getType(0), op.getType(1), input, kernelSizeList, strideList,
|
||||
paddingSizeList, dialationList,
|
||||
/*ceil_mode=*/constantFalse);
|
||||
rewriter.replaceOp(op, maxPool.getResults());
|
||||
}
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue