Fix output size computation for MaxPool2D when ceil_mode = true.

pull/3890/head
Sayan Saha 2024-11-22 14:30:59 -05:00
parent bdbc64a205
commit b1550a2463
4 changed files with 69 additions and 5 deletions

View File

@ -13,6 +13,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
#include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@ -116,6 +117,34 @@ Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc,
else else
division = b.createOrFold<arith::FloorDivSIOp>(loc, dividend, strideInt); division = b.createOrFold<arith::FloorDivSIOp>(loc, dividend, strideInt);
Value out = b.createOrFold<arith::AddIOp>(loc, division, c1); Value out = b.createOrFold<arith::AddIOp>(loc, division, c1);
if (ceilMode) {
Value outMinusOneTimesStride =
b.createOrFold<arith::MulIOp>(loc, division, strideInt);
Value inAddLeftPadding = b.createOrFold<arith::AddIOp>(
loc, castIndexToInt64(b, loc, in), paddingInt);
auto reduceOutputDim =
b.createOrFold<arith::CmpIOp>(loc, arith::CmpIPredicate::uge,
outMinusOneTimesStride, inAddLeftPadding);
// Emit 'then' region of 'scf.if'
auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
opBuilder.create<scf::YieldOp>(loc, division);
};
// Emit 'else' region of 'scf.if'
auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
opBuilder.create<scf::YieldOp>(loc, out);
};
// Emit 'scf.if' op
auto ifOp = b.create<scf::IfOp>(loc, reduceOutputDim, emitThenRegion,
emitElseRegion);
return castIntToIndex(b, loc, ifOp.getResult(0));
}
return castIntToIndex(b, loc, out); return castIntToIndex(b, loc, out);
} }
@ -527,8 +556,8 @@ Value torch_to_linalg::convertTensorToElementType(OpBuilder &b, Location loc,
Type elementType) { Type elementType) {
auto dtypePromoteBody = [&](OpBuilder &builder, Location loc, auto dtypePromoteBody = [&](OpBuilder &builder, Location loc,
ValueRange payloadArgs) { ValueRange payloadArgs) {
Value elem = Value elem = mlir::torch::Torch::convertScalarToDtype(
convertScalarToDtype(builder, loc, payloadArgs[0], elementType); builder, loc, payloadArgs[0], elementType);
builder.create<linalg::YieldOp>(loc, elem); builder.create<linalg::YieldOp>(loc, elem);
}; };
return torch_to_linalg::createElementwiseLinalgGeneric( return torch_to_linalg::createElementwiseLinalgGeneric(

View File

@ -5252,9 +5252,11 @@ public:
} else { } else {
int64_t dimSize = int64_t dimSize =
inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1; inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1;
if (ceilMode && (dimSize % stride != 0)) int64_t outputDim = dimSize / stride + 1;
return dimSize / stride + 2; if (ceilMode && (dimSize % stride != 0) &&
return dimSize / stride + 1; (outputDim * stride < inputDim + padBefore))
outputDim++;
return outputDim;
} }
} }

View File

@ -763,6 +763,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"LenStrModule_basic", "LenStrModule_basic",
"MaxPool2dCeilModeTrueModule_basic", "MaxPool2dCeilModeTrueModule_basic",
"MaxPool2dStaticCeilModeTrueModule_basic", "MaxPool2dStaticCeilModeTrueModule_basic",
"MaxPool2dStaticCeilModeTrueReduceOutputModule_basic",
"MaxPool2dWithIndicesBackwardDynamic3DModule_basic", "MaxPool2dWithIndicesBackwardDynamic3DModule_basic",
"MaxPool2dWithIndicesBackwardDynamic4DModule_basic", "MaxPool2dWithIndicesBackwardDynamic4DModule_basic",
"MaxPool2dWithIndicesBackwardStatic3DModule_basic", "MaxPool2dWithIndicesBackwardStatic3DModule_basic",
@ -2261,6 +2262,7 @@ TOSA_PASS_SET = {
"MatmulStaticBroadcast_basic", "MatmulStaticBroadcast_basic",
"MaxPool2dEmptyStrideStaticModule_basic", "MaxPool2dEmptyStrideStaticModule_basic",
"MaxPool2dStaticCeilModeTrueModule_basic", "MaxPool2dStaticCeilModeTrueModule_basic",
"MaxPool2dStaticCeilModeTrueReduceOutputModule_basic",
"MaxPool2dStaticModule_basic", "MaxPool2dStaticModule_basic",
"MeanModule_basic", "MeanModule_basic",
"MmDagModule_basic", "MmDagModule_basic",
@ -3000,6 +3002,7 @@ ONNX_XFAIL_SET = {
"MaxPool1dCeilModeTrueModule_basic", "MaxPool1dCeilModeTrueModule_basic",
"MaxPool1dModule_basic", "MaxPool1dModule_basic",
"MaxPool2dCeilModeTrueModule_basic", "MaxPool2dCeilModeTrueModule_basic",
"MaxPool2dStaticCeilModeTrueReduceOutputModule_basic",
"MaxPool2dModule_basic", "MaxPool2dModule_basic",
"MaxPool2dWithIndicesAllOnesModule_basic", "MaxPool2dWithIndicesAllOnesModule_basic",
"MaxPool2dWithIndicesBackwardDynamic3DModule_basic", "MaxPool2dWithIndicesBackwardDynamic3DModule_basic",
@ -4516,6 +4519,7 @@ ONNX_TOSA_XFAIL_SET = {
"MaxPool1dCeilModeTrueModule_basic", "MaxPool1dCeilModeTrueModule_basic",
"MaxPool1dModule_basic", "MaxPool1dModule_basic",
"MaxPool2dCeilModeTrueModule_basic", "MaxPool2dCeilModeTrueModule_basic",
"MaxPool2dStaticCeilModeTrueReduceOutputModule_basic",
"MaxPool2dModule_basic", "MaxPool2dModule_basic",
"MaxPool2dWithIndicesAllNegativeValuesModule_basic", "MaxPool2dWithIndicesAllNegativeValuesModule_basic",
"MaxPool2dWithIndicesAllOnesModule_basic", "MaxPool2dWithIndicesAllOnesModule_basic",

View File

@ -420,6 +420,35 @@ def MaxPool2dCeilModeTrueModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 1, 20, 20, low=0.5, high=1.0)) module.forward(tu.rand(1, 1, 20, 20, low=0.5, high=1.0))
class MaxPool2dStaticCeilModeTrueReduceOutputModule(torch.nn.Module):
    """Static-shape MaxPool2d test module with ``ceil_mode=True``.

    The pooling parameters (kernel 6, stride 6, padding 3, dilation 1) are
    paired with a ``[2, 6, 20, 10]`` float32 input. Along the height axis
    (size 20) the plain ceil-mode formula gives 5 windows, but the fifth
    window would begin at offset (5 - 1) * 6 = 24, which is not less than
    input + left padding = 23, so the output size has to be reduced to 4.
    Along the width axis (size 10) no reduction happens and the size is 3.
    """

    def __init__(self):
        super().__init__()
        # kernel_size is passed positionally; the remaining options are
        # spelled out so the ceil-mode configuration is explicit.
        pool = torch.nn.MaxPool2d(
            6,
            stride=6,
            padding=3,
            dilation=1,
            ceil_mode=True,
        )
        self.mp2d = pool

    @export
    @annotate_args([None, ([2, 6, 20, 10], torch.float32, True)])
    def forward(self, x):
        """Apply the configured max-pooling layer to ``x`` and return the result."""
        pooled = self.mp2d(x)
        return pooled
@register_test_case(
    module_factory=lambda: MaxPool2dStaticCeilModeTrueReduceOutputModule()
)
def MaxPool2dStaticCeilModeTrueReduceOutputModule_basic(module, tu: TestUtils):
    """Run the ceil-mode MaxPool2d module on one random ``[2, 6, 20, 10]`` input.

    The input is produced by ``tu.rand`` with ``low=0.5`` and ``high=1.0``;
    its shape matches the static shape annotated on the module's ``forward``.
    """
    sample = tu.rand(2, 6, 20, 10, low=0.5, high=1.0)
    module.forward(sample)
# ============================================================================== # ==============================================================================