mirror of https://github.com/llvm/torch-mlir
parent
39d882f7c9
commit
919b599ebe
|
@ -565,15 +565,15 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
Value cstCeilMode =
|
||||
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), ceilMode);
|
||||
|
||||
if (rank == 3)
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"Unimplemented: AtenMaxPool1dOp");
|
||||
|
||||
if (binder.op->getNumResults() == 2) {
|
||||
Torch::ValueTensorType resultTypeIndices;
|
||||
if (binder.tensorResultTypeAtIndex(resultTypeIndices, 1))
|
||||
return failure();
|
||||
|
||||
if (rank == 3)
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "Unimplemented: AtenMaxPool1dWithIndicesOp");
|
||||
|
||||
if (rank == 4) {
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenMaxPool2dWithIndicesOp>(
|
||||
binder.op, resultTypeOut, resultTypeIndices, operand,
|
||||
|
@ -589,6 +589,12 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
return success();
|
||||
}
|
||||
} else {
|
||||
if (rank == 3) {
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenMaxPool1dOp>(
|
||||
binder.op, resultTypeOut, operand, kernelSizeList, stridesList,
|
||||
paddingList, dilationsList, cstCeilMode);
|
||||
return success();
|
||||
}
|
||||
if (rank == 4) {
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenMaxPool2dOp>(
|
||||
binder.op, resultTypeOut, operand, kernelSizeList, stridesList,
|
||||
|
|
|
@ -2418,10 +2418,7 @@ ONNX_XFAIL_SET = {
|
|||
"LogSoftmaxBackwardModule_basic",
|
||||
"MaskedScatterStaticBasic_basic",
|
||||
"MaxPool1dCeilModeTrueModule_basic",
|
||||
"MaxPool1dEmptyStrideStaticModule_basic",
|
||||
"MaxPool1dModule_basic",
|
||||
"MaxPool1dStaticCeilModeTrueModule_basic",
|
||||
"MaxPool1dStaticModule_basic",
|
||||
"MaxPool2dCeilModeTrueModule_basic",
|
||||
"MaxPool2dModule_basic",
|
||||
"MaxPool2dWithIndicesAllOnesModule_basic",
|
||||
|
|
Loading…
Reference in New Issue