mirror of https://github.com/llvm/torch-mlir
Fix output size computation for MaxPool2D for ceil_model = true.
parent
bdbc64a205
commit
b1550a2463
|
@ -13,6 +13,7 @@
|
|||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
|
||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
|
@ -116,6 +117,34 @@ Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc,
|
|||
else
|
||||
division = b.createOrFold<arith::FloorDivSIOp>(loc, dividend, strideInt);
|
||||
Value out = b.createOrFold<arith::AddIOp>(loc, division, c1);
|
||||
|
||||
if (ceilMode) {
|
||||
Value outMinusOneTimesStride =
|
||||
b.createOrFold<arith::MulIOp>(loc, division, strideInt);
|
||||
Value inAddLeftPadding = b.createOrFold<arith::AddIOp>(
|
||||
loc, castIndexToInt64(b, loc, in), paddingInt);
|
||||
|
||||
auto reduceOutputDim =
|
||||
b.createOrFold<arith::CmpIOp>(loc, arith::CmpIPredicate::uge,
|
||||
outMinusOneTimesStride, inAddLeftPadding);
|
||||
|
||||
// Emit 'then' region of 'scf.if'
|
||||
auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
|
||||
opBuilder.create<scf::YieldOp>(loc, division);
|
||||
};
|
||||
|
||||
// Emit 'else' region of 'scf.if'
|
||||
auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
|
||||
opBuilder.create<scf::YieldOp>(loc, out);
|
||||
};
|
||||
|
||||
// Emit 'scf.if' op
|
||||
auto ifOp = b.create<scf::IfOp>(loc, reduceOutputDim, emitThenRegion,
|
||||
emitElseRegion);
|
||||
|
||||
return castIntToIndex(b, loc, ifOp.getResult(0));
|
||||
}
|
||||
|
||||
return castIntToIndex(b, loc, out);
|
||||
}
|
||||
|
||||
|
@ -527,8 +556,8 @@ Value torch_to_linalg::convertTensorToElementType(OpBuilder &b, Location loc,
|
|||
Type elementType) {
|
||||
auto dtypePromoteBody = [&](OpBuilder &builder, Location loc,
|
||||
ValueRange payloadArgs) {
|
||||
Value elem =
|
||||
convertScalarToDtype(builder, loc, payloadArgs[0], elementType);
|
||||
Value elem = mlir::torch::Torch::convertScalarToDtype(
|
||||
builder, loc, payloadArgs[0], elementType);
|
||||
builder.create<linalg::YieldOp>(loc, elem);
|
||||
};
|
||||
return torch_to_linalg::createElementwiseLinalgGeneric(
|
||||
|
|
|
@ -5252,9 +5252,11 @@ public:
|
|||
} else {
|
||||
int64_t dimSize =
|
||||
inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1;
|
||||
if (ceilMode && (dimSize % stride != 0))
|
||||
return dimSize / stride + 2;
|
||||
return dimSize / stride + 1;
|
||||
int64_t outputDim = dimSize / stride + 1;
|
||||
if (ceilMode && (dimSize % stride != 0) &&
|
||||
(outputDim * stride < inputDim + padBefore))
|
||||
outputDim++;
|
||||
return outputDim;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -763,6 +763,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
|||
"LenStrModule_basic",
|
||||
"MaxPool2dCeilModeTrueModule_basic",
|
||||
"MaxPool2dStaticCeilModeTrueModule_basic",
|
||||
"MaxPool2dStaticCeilModeTrueReduceOutputModule_basic",
|
||||
"MaxPool2dWithIndicesBackwardDynamic3DModule_basic",
|
||||
"MaxPool2dWithIndicesBackwardDynamic4DModule_basic",
|
||||
"MaxPool2dWithIndicesBackwardStatic3DModule_basic",
|
||||
|
@ -2261,6 +2262,7 @@ TOSA_PASS_SET = {
|
|||
"MatmulStaticBroadcast_basic",
|
||||
"MaxPool2dEmptyStrideStaticModule_basic",
|
||||
"MaxPool2dStaticCeilModeTrueModule_basic",
|
||||
"MaxPool2dStaticCeilModeTrueReduceOutputModule_basic",
|
||||
"MaxPool2dStaticModule_basic",
|
||||
"MeanModule_basic",
|
||||
"MmDagModule_basic",
|
||||
|
@ -3000,6 +3002,7 @@ ONNX_XFAIL_SET = {
|
|||
"MaxPool1dCeilModeTrueModule_basic",
|
||||
"MaxPool1dModule_basic",
|
||||
"MaxPool2dCeilModeTrueModule_basic",
|
||||
"MaxPool2dStaticCeilModeTrueReduceOutputModule_basic",
|
||||
"MaxPool2dModule_basic",
|
||||
"MaxPool2dWithIndicesAllOnesModule_basic",
|
||||
"MaxPool2dWithIndicesBackwardDynamic3DModule_basic",
|
||||
|
@ -4516,6 +4519,7 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"MaxPool1dCeilModeTrueModule_basic",
|
||||
"MaxPool1dModule_basic",
|
||||
"MaxPool2dCeilModeTrueModule_basic",
|
||||
"MaxPool2dStaticCeilModeTrueReduceOutputModule_basic",
|
||||
"MaxPool2dModule_basic",
|
||||
"MaxPool2dWithIndicesAllNegativeValuesModule_basic",
|
||||
"MaxPool2dWithIndicesAllOnesModule_basic",
|
||||
|
|
|
@ -420,6 +420,35 @@ def MaxPool2dCeilModeTrueModule_basic(module, tu: TestUtils):
|
|||
module.forward(tu.rand(1, 1, 20, 20, low=0.5, high=1.0))
|
||||
|
||||
|
||||
class MaxPool2dStaticCeilModeTrueReduceOutputModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mp2d = torch.nn.MaxPool2d(
|
||||
kernel_size=6,
|
||||
stride=6,
|
||||
padding=3,
|
||||
dilation=1,
|
||||
ceil_mode=True,
|
||||
)
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([2, 6, 20, 10], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.mp2d(x)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: MaxPool2dStaticCeilModeTrueReduceOutputModule()
|
||||
)
|
||||
def MaxPool2dStaticCeilModeTrueReduceOutputModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 6, 20, 10, low=0.5, high=1.0))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue