onnx.MaxPool add atenMaxPool1d lowering support (#3452)

fixes #3422
pull/3406/merge
Phaneesh Barwaria 2024-06-13 15:37:11 +05:30 committed by GitHub
parent 39d882f7c9
commit 919b599ebe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 10 additions and 7 deletions

View File

@ -565,15 +565,15 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
Value cstCeilMode = Value cstCeilMode =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), ceilMode); rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), ceilMode);
if (rank == 3)
return rewriter.notifyMatchFailure(binder.op,
"Unimplemented: AtenMaxPool1dOp");
if (binder.op->getNumResults() == 2) { if (binder.op->getNumResults() == 2) {
Torch::ValueTensorType resultTypeIndices; Torch::ValueTensorType resultTypeIndices;
if (binder.tensorResultTypeAtIndex(resultTypeIndices, 1)) if (binder.tensorResultTypeAtIndex(resultTypeIndices, 1))
return failure(); return failure();
if (rank == 3)
return rewriter.notifyMatchFailure(
binder.op, "Unimplemented: AtenMaxPool1dWithIndicesOp");
if (rank == 4) { if (rank == 4) {
rewriter.replaceOpWithNewOp<Torch::AtenMaxPool2dWithIndicesOp>( rewriter.replaceOpWithNewOp<Torch::AtenMaxPool2dWithIndicesOp>(
binder.op, resultTypeOut, resultTypeIndices, operand, binder.op, resultTypeOut, resultTypeIndices, operand,
@ -589,6 +589,12 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
return success(); return success();
} }
} else { } else {
if (rank == 3) {
rewriter.replaceOpWithNewOp<Torch::AtenMaxPool1dOp>(
binder.op, resultTypeOut, operand, kernelSizeList, stridesList,
paddingList, dilationsList, cstCeilMode);
return success();
}
if (rank == 4) { if (rank == 4) {
rewriter.replaceOpWithNewOp<Torch::AtenMaxPool2dOp>( rewriter.replaceOpWithNewOp<Torch::AtenMaxPool2dOp>(
binder.op, resultTypeOut, operand, kernelSizeList, stridesList, binder.op, resultTypeOut, operand, kernelSizeList, stridesList,

View File

@ -2418,10 +2418,7 @@ ONNX_XFAIL_SET = {
"LogSoftmaxBackwardModule_basic", "LogSoftmaxBackwardModule_basic",
"MaskedScatterStaticBasic_basic", "MaskedScatterStaticBasic_basic",
"MaxPool1dCeilModeTrueModule_basic", "MaxPool1dCeilModeTrueModule_basic",
"MaxPool1dEmptyStrideStaticModule_basic",
"MaxPool1dModule_basic", "MaxPool1dModule_basic",
"MaxPool1dStaticCeilModeTrueModule_basic",
"MaxPool1dStaticModule_basic",
"MaxPool2dCeilModeTrueModule_basic", "MaxPool2dCeilModeTrueModule_basic",
"MaxPool2dModule_basic", "MaxPool2dModule_basic",
"MaxPool2dWithIndicesAllOnesModule_basic", "MaxPool2dWithIndicesAllOnesModule_basic",