mirror of https://github.com/llvm/torch-mlir
parent
39d882f7c9
commit
919b599ebe
|
@ -565,15 +565,15 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
||||||
Value cstCeilMode =
|
Value cstCeilMode =
|
||||||
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), ceilMode);
|
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), ceilMode);
|
||||||
|
|
||||||
if (rank == 3)
|
|
||||||
return rewriter.notifyMatchFailure(binder.op,
|
|
||||||
"Unimplemented: AtenMaxPool1dOp");
|
|
||||||
|
|
||||||
if (binder.op->getNumResults() == 2) {
|
if (binder.op->getNumResults() == 2) {
|
||||||
Torch::ValueTensorType resultTypeIndices;
|
Torch::ValueTensorType resultTypeIndices;
|
||||||
if (binder.tensorResultTypeAtIndex(resultTypeIndices, 1))
|
if (binder.tensorResultTypeAtIndex(resultTypeIndices, 1))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
if (rank == 3)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
binder.op, "Unimplemented: AtenMaxPool1dWithIndicesOp");
|
||||||
|
|
||||||
if (rank == 4) {
|
if (rank == 4) {
|
||||||
rewriter.replaceOpWithNewOp<Torch::AtenMaxPool2dWithIndicesOp>(
|
rewriter.replaceOpWithNewOp<Torch::AtenMaxPool2dWithIndicesOp>(
|
||||||
binder.op, resultTypeOut, resultTypeIndices, operand,
|
binder.op, resultTypeOut, resultTypeIndices, operand,
|
||||||
|
@ -589,6 +589,12 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
if (rank == 3) {
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::AtenMaxPool1dOp>(
|
||||||
|
binder.op, resultTypeOut, operand, kernelSizeList, stridesList,
|
||||||
|
paddingList, dilationsList, cstCeilMode);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
if (rank == 4) {
|
if (rank == 4) {
|
||||||
rewriter.replaceOpWithNewOp<Torch::AtenMaxPool2dOp>(
|
rewriter.replaceOpWithNewOp<Torch::AtenMaxPool2dOp>(
|
||||||
binder.op, resultTypeOut, operand, kernelSizeList, stridesList,
|
binder.op, resultTypeOut, operand, kernelSizeList, stridesList,
|
||||||
|
|
|
@ -2418,10 +2418,7 @@ ONNX_XFAIL_SET = {
|
||||||
"LogSoftmaxBackwardModule_basic",
|
"LogSoftmaxBackwardModule_basic",
|
||||||
"MaskedScatterStaticBasic_basic",
|
"MaskedScatterStaticBasic_basic",
|
||||||
"MaxPool1dCeilModeTrueModule_basic",
|
"MaxPool1dCeilModeTrueModule_basic",
|
||||||
"MaxPool1dEmptyStrideStaticModule_basic",
|
|
||||||
"MaxPool1dModule_basic",
|
"MaxPool1dModule_basic",
|
||||||
"MaxPool1dStaticCeilModeTrueModule_basic",
|
|
||||||
"MaxPool1dStaticModule_basic",
|
|
||||||
"MaxPool2dCeilModeTrueModule_basic",
|
"MaxPool2dCeilModeTrueModule_basic",
|
||||||
"MaxPool2dModule_basic",
|
"MaxPool2dModule_basic",
|
||||||
"MaxPool2dWithIndicesAllOnesModule_basic",
|
"MaxPool2dWithIndicesAllOnesModule_basic",
|
||||||
|
|
Loading…
Reference in New Issue