mirror of https://github.com/llvm/torch-mlir
[ONNX] Fix AveragePool attributes support (#3235)
Issue was found here: https://github.com/nod-ai/SHARK-Turbine/issues/643 - [ONNX] Fix padding attributes for onnx.AveragePool - [Linalg] Add countIncludePad false support for AtenAvgPool1/2dOp - [Linalg] Add avg_pool2d countIncludePad False e2e tests - [Linalg] Fix conflict with AtenAvgPool3dOp - [Linalg] Fix e2e crash with AtenAvgPool1dOp - [Linalg] Add dynamic dim support for AtenAvgPool2dOp - [Linalg] Fix AvgPool2dDivisorOverrideModule crash pull/3455/head
parent
41d04a8995
commit
ae6f5e8251
|
@ -441,9 +441,17 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
|
cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||||
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
|
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
|
||||||
}
|
}
|
||||||
for (int64_t i : padding) {
|
// Onnx pads format: [x1_begin, x2_begin…x1_end, x2_end,…]
|
||||||
|
// Pytorch pads format: [x1, x2,...] or [x], assume begin==end for all
|
||||||
|
// axes x.
|
||||||
|
int64_t paddingSizeHalf = padding.size() / 2;
|
||||||
|
for (int64_t i = 0; i < paddingSizeHalf; ++i) {
|
||||||
|
// Check if onnx padding attribute is symmetric.
|
||||||
|
if (padding[i] != padding[i + paddingSizeHalf])
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
binder.op, "onnx padding attribute is not symmetric");
|
||||||
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
|
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||||
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
|
binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
|
||||||
}
|
}
|
||||||
for (int64_t i : strides) {
|
for (int64_t i : strides) {
|
||||||
cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>(
|
cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
|
|
@ -619,13 +619,6 @@ public:
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "count_include_pad must be a constant");
|
op, "count_include_pad must be a constant");
|
||||||
|
|
||||||
// If the padding is zero then there is no padding to include.
|
|
||||||
if (!countIncludePad &&
|
|
||||||
!llvm::all_of(paddingInts, [](int64_t p) { return p == 0; })) {
|
|
||||||
return rewriter.notifyMatchFailure(
|
|
||||||
op, "unimplemented: count_include_pad is expected to be true");
|
|
||||||
}
|
|
||||||
|
|
||||||
// `sumPool` contains the result of sumpool operation over the input.
|
// `sumPool` contains the result of sumpool operation over the input.
|
||||||
Value sumPool, paddedInput;
|
Value sumPool, paddedInput;
|
||||||
SmallVector<Value, Dim + 2> outTensorShape;
|
SmallVector<Value, Dim + 2> outTensorShape;
|
||||||
|
@ -635,9 +628,142 @@ public:
|
||||||
paddingInts, dilationInts, rewriter.getZeroAttr(inputElementType),
|
paddingInts, dilationInts, rewriter.getZeroAttr(inputElementType),
|
||||||
outTensorShape, paddedInput, sumPool)))
|
outTensorShape, paddedInput, sumPool)))
|
||||||
return rewriter.notifyMatchFailure(op, "unable to compute sumpool");
|
return rewriter.notifyMatchFailure(op, "unable to compute sumpool");
|
||||||
// }
|
|
||||||
|
|
||||||
Value divisor = kernelSizeIntValues[0];
|
// Compute the average of sumPool.
|
||||||
|
Value outputTensor = rewriter.create<tensor::EmptyOp>(
|
||||||
|
loc, getAsOpFoldResult(outTensorShape), resultElementType);
|
||||||
|
SmallVector<AffineMap> indexingMapsAvg(
|
||||||
|
2, rewriter.getMultiDimIdentityMap(Dim + 2));
|
||||||
|
SmallVector<utils::IteratorType> iteratorTypesAvg(
|
||||||
|
Dim + 2, utils::IteratorType::parallel);
|
||||||
|
Value avgPool;
|
||||||
|
Value divisor;
|
||||||
|
// Case1: AtenAvgPool1d/2dOp with countIncludePad=false support.
|
||||||
|
if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) {
|
||||||
|
auto selfType = cast<RankedTensorType>(self.getType());
|
||||||
|
const int64_t selfRank = selfType.getRank();
|
||||||
|
int64_t wDim = toPositiveDim(-1, selfRank);
|
||||||
|
int64_t hDim = toPositiveDim(-2, selfRank);
|
||||||
|
Value inputHeight = getDimOp(rewriter, loc, self, hDim);
|
||||||
|
Value inputWidth = getDimOp(rewriter, loc, self, wDim);
|
||||||
|
RankedTensorType sumPoolType = cast<RankedTensorType>(sumPool.getType());
|
||||||
|
const int64_t rank = sumPoolType.getRank();
|
||||||
|
int dimH = toPositiveDim(-2, rank);
|
||||||
|
int dimW = toPositiveDim(-1, rank);
|
||||||
|
avgPool =
|
||||||
|
rewriter
|
||||||
|
.create<linalg::GenericOp>(
|
||||||
|
loc, outputTensor.getType(), sumPool, outputTensor,
|
||||||
|
/*indexingMaps=*/indexingMapsAvg,
|
||||||
|
/*iteratorTypes=*/iteratorTypesAvg,
|
||||||
|
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||||
|
// The algorithm for computing the divisor with
|
||||||
|
// count_include_pad is mainly based on pytorch
|
||||||
|
// implementation. The following code is annotated
|
||||||
|
// with pytorch code.
|
||||||
|
// https://github.com/pytorch/pytorch/blob/4a6dfbe4806b361c43210dfd56db64c4097c66bb/aten/src/ATen/native/cpu/AvgPoolKernel.cpp#L78
|
||||||
|
Value indexOh =
|
||||||
|
b.create<linalg::IndexOp>(loc, /*value=*/dimH);
|
||||||
|
Value oh = castIndexToInt64(b, loc, indexOh);
|
||||||
|
Value indexOw =
|
||||||
|
b.create<linalg::IndexOp>(loc, /*value=*/dimW);
|
||||||
|
Value ow = castIndexToInt64(b, loc, indexOw);
|
||||||
|
|
||||||
|
// int64_t ih0 = oh * dH - padH;
|
||||||
|
Value dH = rewriter.create<arith::ConstantOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(strideInts[0]));
|
||||||
|
Value padH = rewriter.create<arith::ConstantOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(paddingInts[0]));
|
||||||
|
Value ohDH = b.create<arith::MulIOp>(loc, oh, dH);
|
||||||
|
Value ih0 = b.create<arith::SubIOp>(loc, ohDH, padH);
|
||||||
|
// int64_t iw0 = ow * dW - padW;
|
||||||
|
Value dW = rewriter.create<arith::ConstantOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(strideInts[1]));
|
||||||
|
Value padW = rewriter.create<arith::ConstantOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(paddingInts[1]));
|
||||||
|
Value owDW = b.create<arith::MulIOp>(loc, ow, dW);
|
||||||
|
Value iw0 = b.create<arith::SubIOp>(loc, owDW, padW);
|
||||||
|
// int64_t ih1 = std::min(ih0 + kH, input_height + padH);
|
||||||
|
Value ih = castIndexToInt64(b, loc, inputHeight);
|
||||||
|
Value ih0KH = b.create<arith::AddIOp>(
|
||||||
|
loc, ih0, kernelSizeIntValues[0]);
|
||||||
|
Value ihPadH = b.create<arith::AddIOp>(loc, ih, padH);
|
||||||
|
Value ih1 = b.create<arith::MinSIOp>(loc, ih0KH, ihPadH);
|
||||||
|
// int64_t iw1 = std::min(iw0 + kW, input_width + padW);
|
||||||
|
Value iw = castIndexToInt64(b, loc, inputWidth);
|
||||||
|
Value iw0KW = b.create<arith::AddIOp>(
|
||||||
|
loc, iw0, kernelSizeIntValues[1]);
|
||||||
|
Value iwPadW = b.create<arith::AddIOp>(loc, iw, padW);
|
||||||
|
Value iw1 = b.create<arith::MinSIOp>(loc, iw0KW, iwPadW);
|
||||||
|
// int64_t pool_size = (ih1 - ih0) * (iw1 - iw0);
|
||||||
|
Value ih1Ih0 = b.create<arith::SubIOp>(loc, ih1, ih0);
|
||||||
|
Value iw1Iw0 = b.create<arith::SubIOp>(loc, iw1, iw0);
|
||||||
|
Value poolSize =
|
||||||
|
b.create<arith::MulIOp>(loc, ih1Ih0, iw1Iw0);
|
||||||
|
// ih0 = std::max(ih0, 0);
|
||||||
|
Value cstZero = rewriter.create<arith::ConstantOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(0));
|
||||||
|
Value ih0Clamped =
|
||||||
|
b.create<arith::MaxSIOp>(loc, ih0, cstZero);
|
||||||
|
// iw0 = std::max(iw0, 0);
|
||||||
|
Value iw0Clamped =
|
||||||
|
b.create<arith::MaxSIOp>(loc, iw0, cstZero);
|
||||||
|
// ih1 = std::min(ih1, input_height);
|
||||||
|
Value ih1Clamped = b.create<arith::MinSIOp>(loc, ih1, ih);
|
||||||
|
// iw1 = std::min(iw1, input_width);
|
||||||
|
Value iw1Clamped = b.create<arith::MinSIOp>(loc, iw1, iw);
|
||||||
|
// if (divisor_override.has_value()) {
|
||||||
|
// divisor = divisor_override.value();
|
||||||
|
// } else {
|
||||||
|
// if(count_include_pad) {
|
||||||
|
// divisor = pool_size;
|
||||||
|
// } else {
|
||||||
|
// divisor = (ih1 - ih0) * (iw1 - iw0);
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
if (countIncludePad) {
|
||||||
|
divisor = convertScalarToDtype(b, loc, poolSize,
|
||||||
|
resultElementType);
|
||||||
|
} else {
|
||||||
|
Value ih1_ih0 =
|
||||||
|
b.create<arith::SubIOp>(loc, ih1Clamped, ih0Clamped);
|
||||||
|
Value iw1_iw0 =
|
||||||
|
b.create<arith::SubIOp>(loc, iw1Clamped, iw0Clamped);
|
||||||
|
divisor = b.create<arith::MulIOp>(loc, ih1_ih0, iw1_iw0);
|
||||||
|
}
|
||||||
|
// AtenAvgPool2/3dOp has an optional divisor_override
|
||||||
|
// attribute while AtenAvgPool1dOp does not.
|
||||||
|
if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) {
|
||||||
|
if (!isa<Torch::NoneType>(
|
||||||
|
op.getDivisorOverride().getType()))
|
||||||
|
divisor = adaptor.getDivisorOverride();
|
||||||
|
}
|
||||||
|
|
||||||
|
divisor = convertScalarToDtype(b, loc, divisor,
|
||||||
|
resultElementType);
|
||||||
|
Value avg;
|
||||||
|
if (isa<mlir::IntegerType>(resultElementType))
|
||||||
|
avg = b.create<arith::DivSIOp>(loc, args[0], divisor);
|
||||||
|
else if (isa<mlir::FloatType>(resultElementType))
|
||||||
|
avg = b.create<arith::DivFOp>(loc, args[0], divisor);
|
||||||
|
b.create<linalg::YieldOp>(loc, avg);
|
||||||
|
})
|
||||||
|
.getResult(0);
|
||||||
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, avgPool);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Add support for count_include_pad equal to `False` in
|
||||||
|
// AtenAvgPool1/3dOp.
|
||||||
|
if (!countIncludePad &&
|
||||||
|
!llvm::all_of(paddingInts, [](int64_t p) { return p == 0; })) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: count_include_pad is expected to be true for "
|
||||||
|
"AtenAvgPool3dOp");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Case2: AtenAvgPool1/3dOp without count_include_pad equal to `False`.
|
||||||
|
divisor = kernelSizeIntValues[0];
|
||||||
for (uint32_t i = 1; i < kernelSizeIntValues.size(); i++) {
|
for (uint32_t i = 1; i < kernelSizeIntValues.size(); i++) {
|
||||||
divisor =
|
divisor =
|
||||||
rewriter.create<arith::MulIOp>(loc, divisor, kernelSizeIntValues[i]);
|
rewriter.create<arith::MulIOp>(loc, divisor, kernelSizeIntValues[i]);
|
||||||
|
@ -648,29 +774,20 @@ public:
|
||||||
: adaptor.getDivisorOverride();
|
: adaptor.getDivisorOverride();
|
||||||
}
|
}
|
||||||
divisor = convertScalarToDtype(rewriter, loc, divisor, resultElementType);
|
divisor = convertScalarToDtype(rewriter, loc, divisor, resultElementType);
|
||||||
|
avgPool = rewriter
|
||||||
Value outputTensor = rewriter.create<tensor::EmptyOp>(
|
.create<linalg::GenericOp>(
|
||||||
loc, getAsOpFoldResult(outTensorShape), resultElementType);
|
loc, outputTensor.getType(), sumPool, outputTensor,
|
||||||
SmallVector<AffineMap> indexingMapsAvg(
|
/*indexingMaps=*/indexingMapsAvg,
|
||||||
2, rewriter.getMultiDimIdentityMap(Dim + 2));
|
/*iteratorTypes=*/iteratorTypesAvg,
|
||||||
SmallVector<utils::IteratorType> iteratorTypesAvg(
|
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||||
Dim + 2, utils::IteratorType::parallel);
|
Value avg;
|
||||||
Value avgPool =
|
if (isa<mlir::IntegerType>(resultElementType))
|
||||||
rewriter
|
avg = b.create<arith::DivSIOp>(loc, args[0], divisor);
|
||||||
.create<linalg::GenericOp>(
|
else if (isa<mlir::FloatType>(resultElementType))
|
||||||
loc, outputTensor.getType(), sumPool, outputTensor,
|
avg = b.create<arith::DivFOp>(loc, args[0], divisor);
|
||||||
/*indexingMaps=*/indexingMapsAvg,
|
b.create<linalg::YieldOp>(loc, avg);
|
||||||
/*iteratorTypes=*/iteratorTypesAvg,
|
})
|
||||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
.getResult(0);
|
||||||
Value avg;
|
|
||||||
if (isa<mlir::IntegerType>(resultElementType))
|
|
||||||
avg = b.create<arith::DivSIOp>(loc, args[0], divisor);
|
|
||||||
else if (isa<mlir::FloatType>(resultElementType))
|
|
||||||
avg = b.create<arith::DivFOp>(loc, args[0], divisor);
|
|
||||||
b.create<linalg::YieldOp>(loc, avg);
|
|
||||||
})
|
|
||||||
.getResult(0);
|
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, avgPool);
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, avgPool);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
|
@ -888,6 +888,7 @@ STABLEHLO_PASS_SET = {
|
||||||
"Aten_CastLongModule_basic",
|
"Aten_CastLongModule_basic",
|
||||||
"AvgPool1dStaticModule_basic",
|
"AvgPool1dStaticModule_basic",
|
||||||
"AvgPool2dStaticModule_basic",
|
"AvgPool2dStaticModule_basic",
|
||||||
|
"AvgPool2dCountIncludePadFalseStaticModule_basic",
|
||||||
"AvgPool3dStaticModule_basic",
|
"AvgPool3dStaticModule_basic",
|
||||||
"BaddbmmBroadcast1DInputModule_basic",
|
"BaddbmmBroadcast1DInputModule_basic",
|
||||||
"BaddbmmBroadcast2DInputModule_basic",
|
"BaddbmmBroadcast2DInputModule_basic",
|
||||||
|
@ -1479,6 +1480,7 @@ STABLEHLO_CRASHING_SET = set()
|
||||||
# Write the TOSA set as a "passing" set as it is very early in development
|
# Write the TOSA set as a "passing" set as it is very early in development
|
||||||
# and very few tests work yet.
|
# and very few tests work yet.
|
||||||
TOSA_PASS_SET = {
|
TOSA_PASS_SET = {
|
||||||
|
"AvgPool2dCountIncludePadFalseStaticModule_basic",
|
||||||
"TensorSplitSections_GetItemModule_basic",
|
"TensorSplitSections_GetItemModule_basic",
|
||||||
"TensorSplitSections_ListUnpackModule_basic",
|
"TensorSplitSections_ListUnpackModule_basic",
|
||||||
"AtenLinear2D_basic",
|
"AtenLinear2D_basic",
|
||||||
|
@ -1950,6 +1952,7 @@ MAKE_FX_TOSA_PASS_SET = (
|
||||||
TOSA_PASS_SET
|
TOSA_PASS_SET
|
||||||
| {
|
| {
|
||||||
### Tests additionally passing in make_fx_tosa
|
### Tests additionally passing in make_fx_tosa
|
||||||
|
"AvgPool2dCountIncludePadFalseStaticModule_basic",
|
||||||
"AtenLinear1D_basic",
|
"AtenLinear1D_basic",
|
||||||
"AtenLinearMatVec_basic",
|
"AtenLinearMatVec_basic",
|
||||||
"AtenLinearVecMatBias_basic",
|
"AtenLinearVecMatBias_basic",
|
||||||
|
|
|
@ -1017,6 +1017,35 @@ def AvgPool2dStaticModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(2, 2, 10, 20, low=-1))
|
module.forward(tu.rand(2, 2, 10, 20, low=-1))
|
||||||
|
|
||||||
|
|
||||||
|
class AvgPool2dCountIncludePadFalseStaticModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.ap2d = torch.nn.AvgPool2d(
|
||||||
|
kernel_size=[3, 3],
|
||||||
|
stride=[1, 1],
|
||||||
|
padding=[1, 1],
|
||||||
|
ceil_mode=False,
|
||||||
|
count_include_pad=False,
|
||||||
|
divisor_override=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([32, 384, 25, 25], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
return self.ap2d(x)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: AvgPool2dCountIncludePadFalseStaticModule())
|
||||||
|
def AvgPool2dCountIncludePadFalseStaticModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(32, 384, 25, 25, low=-1))
|
||||||
|
|
||||||
|
|
||||||
class AvgPool2dDivisorOverrideModule(torch.nn.Module):
|
class AvgPool2dDivisorOverrideModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
Loading…
Reference in New Issue