mirror of https://github.com/llvm/torch-mlir
[ONNX] Add averagepool dilations support (#3490)
- Fixes the dilations issue: https://github.com/llvm/torch-mlir/issues/3428
- Tested by: https://github.com/nod-ai/SHARK-TestSuite/pull/268
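In outline (see the sketches after the hunks below): the torch.aten.avg_pool1d/2d/3d ops take no dilation argument, so the ONNX importer appends the ONNX dilations after the strides in a single list, [strideDim1, ..., dilationDim1, ...], and the torch-to-linalg pooling lowering splits that list back into separate strides and dilations whenever its length is twice the number of spatial dimensions.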
parent 98c6971a01
commit fc19709daa
@@ -379,7 +379,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
       "AveragePool", 11,
       [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         std::string autoPad;
-        SmallVector<int64_t> dilation;
+        SmallVector<int64_t> dilations;
         if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
           return failure();
         if (autoPad != "NOTSET") {
@@ -387,13 +387,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
           return rewriter.notifyMatchFailure(
               binder.op, "unsupported conversion: auto_pad != NOTSET");
         }
-        if (binder.s64IntegerArrayAttr(dilation, "dilations", {})) {
-          return failure();
-        }
-        if (dilation.size() > 0) {
-          return rewriter.notifyMatchFailure(
-              binder.op, "dilation is not supported by torch.aten.avgpool op");
-        }

         Torch::ValueTensorType resultType;
         Value operand;
@@ -436,7 +429,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
               binder.op, "strides list size does not match the number of axes");
         }

-        SmallVector<Value> cstKernel, cstPadding, cstStrides;
+        SmallVector<Value> cstKernel, cstPadding, cstStridesDilations;
         for (int64_t i : kernel) {
           cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
               binder.getLoc(), rewriter.getI64IntegerAttr(i)));
@@ -454,9 +447,24 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
               binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
         }
         for (int64_t i : strides) {
-          cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>(
+          cstStridesDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
               binder.getLoc(), rewriter.getI64IntegerAttr(i)));
         }
+
+        // The PyTorch avgpool ops have no dilations attribute, so encode the
+        // dilations into the strides list; the subsequent torch-to-linalg
+        // lowering decodes that list back into strides + dilations:
+        // [strideDim1, strideDim2, ..., dilationDim1, dilationDim2, ...]
+        if (binder.s64IntegerArrayAttr(
+                dilations, "dilations",
+                llvm::SmallVector<int64_t>(rank - 2, 1))) {
+          return failure();
+        }
+        for (auto dilation : dilations) {
+          cstStridesDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getI64IntegerAttr(dilation)));
+        }
+
         Value kernelSizeList = rewriter.create<Torch::PrimListConstructOp>(
             binder.getLoc(),
             Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
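A minimal standalone sketch of the encoding the hunk above performs, using std::vector in place of llvm::SmallVector so it compiles without LLVM; buildStridesDilations is a hypothetical name, not part of the patch:

#include <cstdint>
#include <vector>

// Pack strides and dilations into one list, mirroring the importer:
// [strideDim1, strideDim2, ..., dilationDim1, dilationDim2, ...]
std::vector<int64_t>
buildStridesDilations(const std::vector<int64_t> &strides,
                      const std::vector<int64_t> &dilations) {
  std::vector<int64_t> out(strides);
  out.insert(out.end(), dilations.begin(), dilations.end());
  return out;
}

For a 2-D AveragePool with strides {2, 2} and dilations {3, 3}, the AtenAvgPool ops below thus receive {2, 2, 3, 3} where the plain strides list used to go.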
@@ -465,10 +473,12 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
             binder.getLoc(),
             Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
             cstPadding);
-        Value stridesList = rewriter.create<Torch::PrimListConstructOp>(
-            binder.getLoc(),
-            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
-            cstStrides);
+        Value stridesDilationsList =
+            rewriter.create<Torch::PrimListConstructOp>(
+                binder.getLoc(),
+                Torch::ListType::get(
+                    Torch::IntType::get(binder.op->getContext())),
+                cstStridesDilations);
         Value cstCeilMode =
             rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), ceilMode);
         Value cstCountIncludePad = rewriter.create<Torch::ConstantBoolOp>(
@@ -477,19 +487,22 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(

         if (rank == 3) {
           rewriter.replaceOpWithNewOp<Torch::AtenAvgPool1dOp>(
-              binder.op, resultType, operand, kernelSizeList, stridesList,
-              paddingList, cstCeilMode, cstCountIncludePad);
+              binder.op, resultType, operand, kernelSizeList,
+              stridesDilationsList, paddingList, cstCeilMode,
+              cstCountIncludePad);
           return success();
         } else if (rank == 4) {
           rewriter.replaceOpWithNewOp<Torch::AtenAvgPool2dOp>(
-              binder.op, resultType, operand, kernelSizeList, stridesList,
-              paddingList, cstCeilMode, cstCountIncludePad,
+              binder.op, resultType, operand, kernelSizeList,
+              stridesDilationsList, paddingList, cstCeilMode,
+              cstCountIncludePad,
               /*divisor_override=*/cstNone);
           return success();
         } else if (rank == 5) {
           rewriter.replaceOpWithNewOp<Torch::AtenAvgPool3dOp>(
-              binder.op, resultType, operand, kernelSizeList, stridesList,
-              paddingList, cstCeilMode, cstCountIncludePad,
+              binder.op, resultType, operand, kernelSizeList,
+              stridesDilationsList, paddingList, cstCeilMode,
+              cstCountIncludePad,
               /*divisor_override=*/cstNone);
           return success();
         }
@@ -612,6 +612,16 @@ public:
                      strideInts, paddingInts)))
       return rewriter.notifyMatchFailure(op, "invalid pooling parameters");

+    // Decode strideInts back into strideInts and dilationInts.
+    if (strideInts.size() == 2 * Dim) {
+      for (int i = 0; i < Dim; i++) {
+        dilationInts[i] = strideInts[Dim + i];
+      }
+      for (int i = 0; i < Dim; i++) {
+        strideInts.pop_back();
+      }
+    }
+
     // TODO: Add support for count_include_pad equal to `False`.
     bool countIncludePad;
     if (!matchPattern(op.getCountIncludePad(),
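The matching decode step, as a standalone sketch of the hunk above; decodeStridesDilations is a hypothetical name, dilationInts is assumed pre-sized to Dim, and resize(Dim) stands in for the patch's pop_back loop with identical effect:

#include <cstdint>
#include <vector>

// Split the combined list back into strides and dilations; a no-op when the
// list holds only strides (size == Dim), so ordinary avg_pool IR is unaffected.
void decodeStridesDilations(int Dim, std::vector<int64_t> &strideInts,
                            std::vector<int64_t> &dilationInts) {
  if (strideInts.size() == static_cast<size_t>(2 * Dim)) {
    for (int i = 0; i < Dim; i++)
      dilationInts[i] = strideInts[Dim + i];
    strideInts.resize(Dim); // drop the trailing dilation entries
  }
}

int main() {
  // Round trip for a 2-D pool: strides {2, 2} and dilations {3, 3} were
  // encoded as {2, 2, 3, 3}; decoding recovers both.
  int Dim = 2;
  std::vector<int64_t> strideInts = {2, 2, 3, 3};
  std::vector<int64_t> dilationInts(Dim, 1);
  decodeStridesDilations(Dim, strideInts, dilationInts);
  // strideInts == {2, 2}, dilationInts == {3, 3}
  return 0;
}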