mirror of https://github.com/llvm/torch-mlir
[Linalg] Bring back onnx AveragePool padding asymmetric support
parent
ae6f5e8251
commit
2f2dfb7e44
|
@ -441,17 +441,9 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
|
cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||||
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
|
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
|
||||||
}
|
}
|
||||||
// Onnx pads format: [x1_begin, x2_begin…x1_end, x2_end,…]
|
for (int64_t i : padding) {
|
||||||
// Pytorch pads format: [x1, x2,...] or [x], assume begin==end for all
|
|
||||||
// axes x.
|
|
||||||
int64_t paddingSizeHalf = padding.size() / 2;
|
|
||||||
for (int64_t i = 0; i < paddingSizeHalf; ++i) {
|
|
||||||
// Check if onnx padding attribute is symmetric.
|
|
||||||
if (padding[i] != padding[i + paddingSizeHalf])
|
|
||||||
return rewriter.notifyMatchFailure(
|
|
||||||
binder.op, "onnx padding attribute is not symmetric");
|
|
||||||
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
|
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||||
binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
|
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
|
||||||
}
|
}
|
||||||
for (int64_t i : strides) {
|
for (int64_t i : strides) {
|
||||||
cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>(
|
cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
|
|
@ -641,7 +641,7 @@ public:
|
||||||
// Case1: AtenAvgPool1d/2dOp with countIncludePad=false support.
|
// Case1: AtenAvgPool1d/2dOp with countIncludePad=false support.
|
||||||
if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) {
|
if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) {
|
||||||
auto selfType = cast<RankedTensorType>(self.getType());
|
auto selfType = cast<RankedTensorType>(self.getType());
|
||||||
const int64_t selfRank = selfType.getRank();
|
unsigned selfRank = selfType.getRank();
|
||||||
int64_t wDim = toPositiveDim(-1, selfRank);
|
int64_t wDim = toPositiveDim(-1, selfRank);
|
||||||
int64_t hDim = toPositiveDim(-2, selfRank);
|
int64_t hDim = toPositiveDim(-2, selfRank);
|
||||||
Value inputHeight = getDimOp(rewriter, loc, self, hDim);
|
Value inputHeight = getDimOp(rewriter, loc, self, hDim);
|
||||||
|
@ -657,6 +657,12 @@ public:
|
||||||
/*indexingMaps=*/indexingMapsAvg,
|
/*indexingMaps=*/indexingMapsAvg,
|
||||||
/*iteratorTypes=*/iteratorTypesAvg,
|
/*iteratorTypes=*/iteratorTypesAvg,
|
||||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||||
|
if (!isa<Torch::NoneType>(
|
||||||
|
op.getDivisorOverride().getType())) {
|
||||||
|
// AtenAvgPool2/3dOp has an optional divisor_override
|
||||||
|
// attribute while AtenAvgPool1dOp does not.
|
||||||
|
divisor = adaptor.getDivisorOverride();
|
||||||
|
} else {
|
||||||
// The algorithm for computing the divisor with
|
// The algorithm for computing the divisor with
|
||||||
// count_include_pad is manily based on pytorch
|
// count_include_pad is manily based on pytorch
|
||||||
// implementation. The following code is comment
|
// implementation. The following code is comment
|
||||||
|
@ -683,6 +689,15 @@ public:
|
||||||
loc, rewriter.getI64IntegerAttr(paddingInts[1]));
|
loc, rewriter.getI64IntegerAttr(paddingInts[1]));
|
||||||
Value owDW = b.create<arith::MulIOp>(loc, ow, dW);
|
Value owDW = b.create<arith::MulIOp>(loc, ow, dW);
|
||||||
Value iw0 = b.create<arith::SubIOp>(loc, owDW, padW);
|
Value iw0 = b.create<arith::SubIOp>(loc, owDW, padW);
|
||||||
|
// onnx average pool may pass asymmetric padding,
|
||||||
|
// so modify the padding values to now represent high
|
||||||
|
// padding.
|
||||||
|
if (paddingInts.size() == 2 * (selfRank - 2)) {
|
||||||
|
padH = rewriter.create<arith::ConstantOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(paddingInts[2]));
|
||||||
|
padW = rewriter.create<arith::ConstantOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(paddingInts[3]));
|
||||||
|
}
|
||||||
// int64_t ih1 = std::min(ih0 + kH, input_height + padH);
|
// int64_t ih1 = std::min(ih0 + kH, input_height + padH);
|
||||||
Value ih = castIndexToInt64(b, loc, inputHeight);
|
Value ih = castIndexToInt64(b, loc, inputHeight);
|
||||||
Value ih0KH = b.create<arith::AddIOp>(
|
Value ih0KH = b.create<arith::AddIOp>(
|
||||||
|
@ -725,18 +740,13 @@ public:
|
||||||
divisor = convertScalarToDtype(b, loc, poolSize,
|
divisor = convertScalarToDtype(b, loc, poolSize,
|
||||||
resultElementType);
|
resultElementType);
|
||||||
} else {
|
} else {
|
||||||
Value ih1_ih0 =
|
Value ih1_ih0 = b.create<arith::SubIOp>(loc, ih1Clamped,
|
||||||
b.create<arith::SubIOp>(loc, ih1Clamped, ih0Clamped);
|
ih0Clamped);
|
||||||
Value iw1_iw0 =
|
Value iw1_iw0 = b.create<arith::SubIOp>(loc, iw1Clamped,
|
||||||
b.create<arith::SubIOp>(loc, iw1Clamped, iw0Clamped);
|
iw0Clamped);
|
||||||
divisor = b.create<arith::MulIOp>(loc, ih1_ih0, iw1_iw0);
|
divisor =
|
||||||
|
b.create<arith::MulIOp>(loc, ih1_ih0, iw1_iw0);
|
||||||
}
|
}
|
||||||
// AtenAvgPool2/3dOp has an optional divisor_override
|
|
||||||
// attribute while AtenAvgPool1dOp does not.
|
|
||||||
if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) {
|
|
||||||
if (!isa<Torch::NoneType>(
|
|
||||||
op.getDivisorOverride().getType()))
|
|
||||||
divisor = adaptor.getDivisorOverride();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
divisor = convertScalarToDtype(b, loc, divisor,
|
divisor = convertScalarToDtype(b, loc, divisor,
|
||||||
|
|
Loading…
Reference in New Issue