mirror of https://github.com/llvm/torch-mlir
[Torch] support adaptive_max_pool1d when return_indices equals False (#3783)
parent
52ecff831b
commit
fa26bfc0d6
|
@ -7117,10 +7117,19 @@ class DecomposeAtenAdaptiveMaxPool1dOp
|
||||||
loc, Torch::ListType::get(Torch::IntType::get(context)),
|
loc, Torch::ListType::get(Torch::IntType::get(context)),
|
||||||
ValueRange{constantOne});
|
ValueRange{constantOne});
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<AtenMaxPool1dWithIndicesOp>(
|
if (op.getResult(1).use_empty()) {
|
||||||
op, op.getType(0), op.getType(1), input, kernelSizeList, strideList,
|
auto maxPool = rewriter.create<AtenMaxPool1dOp>(
|
||||||
paddingSizeList, dialationList,
|
loc, op.getType(0), input, kernelSizeList, strideList,
|
||||||
/*ceil_mode=*/constantFalse);
|
paddingSizeList, dialationList,
|
||||||
|
/*ceil_mode=*/constantFalse);
|
||||||
|
rewriter.replaceOp(op, {maxPool.getResult(), Value()});
|
||||||
|
} else {
|
||||||
|
auto maxPool = rewriter.create<AtenMaxPool1dWithIndicesOp>(
|
||||||
|
loc, op.getType(0), op.getType(1), input, kernelSizeList, strideList,
|
||||||
|
paddingSizeList, dialationList,
|
||||||
|
/*ceil_mode=*/constantFalse);
|
||||||
|
rewriter.replaceOp(op, maxPool.getResults());
|
||||||
|
}
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
Loading…
Reference in New Issue