mirror of https://github.com/llvm/torch-mlir
Fix output size computation for MaxPool2D for ceil_model = true.
parent
bdbc64a205
commit
b1550a2463
|
@ -13,6 +13,7 @@
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
|
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
|
||||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
|
@ -116,6 +117,34 @@ Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc,
|
||||||
else
|
else
|
||||||
division = b.createOrFold<arith::FloorDivSIOp>(loc, dividend, strideInt);
|
division = b.createOrFold<arith::FloorDivSIOp>(loc, dividend, strideInt);
|
||||||
Value out = b.createOrFold<arith::AddIOp>(loc, division, c1);
|
Value out = b.createOrFold<arith::AddIOp>(loc, division, c1);
|
||||||
|
|
||||||
|
if (ceilMode) {
|
||||||
|
Value outMinusOneTimesStride =
|
||||||
|
b.createOrFold<arith::MulIOp>(loc, division, strideInt);
|
||||||
|
Value inAddLeftPadding = b.createOrFold<arith::AddIOp>(
|
||||||
|
loc, castIndexToInt64(b, loc, in), paddingInt);
|
||||||
|
|
||||||
|
auto reduceOutputDim =
|
||||||
|
b.createOrFold<arith::CmpIOp>(loc, arith::CmpIPredicate::uge,
|
||||||
|
outMinusOneTimesStride, inAddLeftPadding);
|
||||||
|
|
||||||
|
// Emit 'then' region of 'scf.if'
|
||||||
|
auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
|
||||||
|
opBuilder.create<scf::YieldOp>(loc, division);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Emit 'else' region of 'scf.if'
|
||||||
|
auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
|
||||||
|
opBuilder.create<scf::YieldOp>(loc, out);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Emit 'scf.if' op
|
||||||
|
auto ifOp = b.create<scf::IfOp>(loc, reduceOutputDim, emitThenRegion,
|
||||||
|
emitElseRegion);
|
||||||
|
|
||||||
|
return castIntToIndex(b, loc, ifOp.getResult(0));
|
||||||
|
}
|
||||||
|
|
||||||
return castIntToIndex(b, loc, out);
|
return castIntToIndex(b, loc, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -527,8 +556,8 @@ Value torch_to_linalg::convertTensorToElementType(OpBuilder &b, Location loc,
|
||||||
Type elementType) {
|
Type elementType) {
|
||||||
auto dtypePromoteBody = [&](OpBuilder &builder, Location loc,
|
auto dtypePromoteBody = [&](OpBuilder &builder, Location loc,
|
||||||
ValueRange payloadArgs) {
|
ValueRange payloadArgs) {
|
||||||
Value elem =
|
Value elem = mlir::torch::Torch::convertScalarToDtype(
|
||||||
convertScalarToDtype(builder, loc, payloadArgs[0], elementType);
|
builder, loc, payloadArgs[0], elementType);
|
||||||
builder.create<linalg::YieldOp>(loc, elem);
|
builder.create<linalg::YieldOp>(loc, elem);
|
||||||
};
|
};
|
||||||
return torch_to_linalg::createElementwiseLinalgGeneric(
|
return torch_to_linalg::createElementwiseLinalgGeneric(
|
||||||
|
|
|
@ -5252,9 +5252,11 @@ public:
|
||||||
} else {
|
} else {
|
||||||
int64_t dimSize =
|
int64_t dimSize =
|
||||||
inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1;
|
inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1;
|
||||||
if (ceilMode && (dimSize % stride != 0))
|
int64_t outputDim = dimSize / stride + 1;
|
||||||
return dimSize / stride + 2;
|
if (ceilMode && (dimSize % stride != 0) &&
|
||||||
return dimSize / stride + 1;
|
(outputDim * stride < inputDim + padBefore))
|
||||||
|
outputDim++;
|
||||||
|
return outputDim;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -763,6 +763,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
||||||
"LenStrModule_basic",
|
"LenStrModule_basic",
|
||||||
"MaxPool2dCeilModeTrueModule_basic",
|
"MaxPool2dCeilModeTrueModule_basic",
|
||||||
"MaxPool2dStaticCeilModeTrueModule_basic",
|
"MaxPool2dStaticCeilModeTrueModule_basic",
|
||||||
|
"MaxPool2dStaticCeilModeTrueReduceOutputModule_basic",
|
||||||
"MaxPool2dWithIndicesBackwardDynamic3DModule_basic",
|
"MaxPool2dWithIndicesBackwardDynamic3DModule_basic",
|
||||||
"MaxPool2dWithIndicesBackwardDynamic4DModule_basic",
|
"MaxPool2dWithIndicesBackwardDynamic4DModule_basic",
|
||||||
"MaxPool2dWithIndicesBackwardStatic3DModule_basic",
|
"MaxPool2dWithIndicesBackwardStatic3DModule_basic",
|
||||||
|
@ -2261,6 +2262,7 @@ TOSA_PASS_SET = {
|
||||||
"MatmulStaticBroadcast_basic",
|
"MatmulStaticBroadcast_basic",
|
||||||
"MaxPool2dEmptyStrideStaticModule_basic",
|
"MaxPool2dEmptyStrideStaticModule_basic",
|
||||||
"MaxPool2dStaticCeilModeTrueModule_basic",
|
"MaxPool2dStaticCeilModeTrueModule_basic",
|
||||||
|
"MaxPool2dStaticCeilModeTrueReduceOutputModule_basic",
|
||||||
"MaxPool2dStaticModule_basic",
|
"MaxPool2dStaticModule_basic",
|
||||||
"MeanModule_basic",
|
"MeanModule_basic",
|
||||||
"MmDagModule_basic",
|
"MmDagModule_basic",
|
||||||
|
@ -3000,6 +3002,7 @@ ONNX_XFAIL_SET = {
|
||||||
"MaxPool1dCeilModeTrueModule_basic",
|
"MaxPool1dCeilModeTrueModule_basic",
|
||||||
"MaxPool1dModule_basic",
|
"MaxPool1dModule_basic",
|
||||||
"MaxPool2dCeilModeTrueModule_basic",
|
"MaxPool2dCeilModeTrueModule_basic",
|
||||||
|
"MaxPool2dStaticCeilModeTrueReduceOutputModule_basic",
|
||||||
"MaxPool2dModule_basic",
|
"MaxPool2dModule_basic",
|
||||||
"MaxPool2dWithIndicesAllOnesModule_basic",
|
"MaxPool2dWithIndicesAllOnesModule_basic",
|
||||||
"MaxPool2dWithIndicesBackwardDynamic3DModule_basic",
|
"MaxPool2dWithIndicesBackwardDynamic3DModule_basic",
|
||||||
|
@ -4516,6 +4519,7 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"MaxPool1dCeilModeTrueModule_basic",
|
"MaxPool1dCeilModeTrueModule_basic",
|
||||||
"MaxPool1dModule_basic",
|
"MaxPool1dModule_basic",
|
||||||
"MaxPool2dCeilModeTrueModule_basic",
|
"MaxPool2dCeilModeTrueModule_basic",
|
||||||
|
"MaxPool2dStaticCeilModeTrueReduceOutputModule_basic",
|
||||||
"MaxPool2dModule_basic",
|
"MaxPool2dModule_basic",
|
||||||
"MaxPool2dWithIndicesAllNegativeValuesModule_basic",
|
"MaxPool2dWithIndicesAllNegativeValuesModule_basic",
|
||||||
"MaxPool2dWithIndicesAllOnesModule_basic",
|
"MaxPool2dWithIndicesAllOnesModule_basic",
|
||||||
|
|
|
@ -420,6 +420,35 @@ def MaxPool2dCeilModeTrueModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(1, 1, 20, 20, low=0.5, high=1.0))
|
module.forward(tu.rand(1, 1, 20, 20, low=0.5, high=1.0))
|
||||||
|
|
||||||
|
|
||||||
|
class MaxPool2dStaticCeilModeTrueReduceOutputModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.mp2d = torch.nn.MaxPool2d(
|
||||||
|
kernel_size=6,
|
||||||
|
stride=6,
|
||||||
|
padding=3,
|
||||||
|
dilation=1,
|
||||||
|
ceil_mode=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([2, 6, 20, 10], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
return self.mp2d(x)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: MaxPool2dStaticCeilModeTrueReduceOutputModule()
|
||||||
|
)
|
||||||
|
def MaxPool2dStaticCeilModeTrueReduceOutputModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(2, 6, 20, 10, low=0.5, high=1.0))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue