[ONNX] Add AveragePool dilations support (#3490)

- Fixes the dilations issue: https://github.com/llvm/torch-mlir/issues/3428
- Tested by: https://github.com/nod-ai/SHARK-TestSuite/pull/268
Chi_Liu 2024-06-21 17:24:57 -07:00 committed by GitHub
parent 98c6971a01
commit fc19709daa
2 changed files with 43 additions and 20 deletions
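
At a glance: the torch.aten.avg_pool1d/2d/3d ops take no dilations argument, so the importer appends the ONNX dilations to the tail of the strides list and the TorchToLinalg lowering splits the list back apart. Below is a minimal standalone sketch of that round trip (plain C++; encode/decode are hypothetical names for illustration, not code from this patch):

// Minimal sketch of the strides/dilations packing used by this patch.
#include <cassert>
#include <cstdint>
#include <vector>

// Importer side: [s1, ..., sN] and [d1, ..., dN] -> [s1, ..., sN, d1, ..., dN].
std::vector<int64_t> encode(std::vector<int64_t> strides,
                            const std::vector<int64_t> &dilations) {
  strides.insert(strides.end(), dilations.begin(), dilations.end());
  return strides;
}

// Lowering side: a 2*dim list gets split; a dim-sized list passes through.
void decode(size_t dim, std::vector<int64_t> &strides,
            std::vector<int64_t> &dilations) {
  if (strides.size() == 2 * dim) {
    dilations.assign(strides.begin() + dim, strides.end());
    strides.resize(dim);
  }
}

int main() {
  std::vector<int64_t> strides = encode({2, 2}, {3, 3}); // [2, 2, 3, 3]
  std::vector<int64_t> dilations{1, 1};
  decode(/*dim=*/2, strides, dilations);
  assert((strides == std::vector<int64_t>{2, 2}));
  assert((dilations == std::vector<int64_t>{3, 3}));
}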

lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp

@@ -379,7 +379,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
       "AveragePool", 11,
       [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         std::string autoPad;
-        SmallVector<int64_t> dilation;
+        SmallVector<int64_t> dilations;
         if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
           return failure();
         if (autoPad != "NOTSET") {
@@ -387,13 +387,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
           return rewriter.notifyMatchFailure(
               binder.op, "unsupported conversion: auto_pad != NOTSET");
         }
-        if (binder.s64IntegerArrayAttr(dilation, "dilations", {})) {
-          return failure();
-        }
-        if (dilation.size() > 0) {
-          return rewriter.notifyMatchFailure(
-              binder.op, "dilation is not supported by torch.aten.avgpool op");
-        }
         Torch::ValueTensorType resultType;
         Value operand;
@@ -436,7 +429,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
               binder.op, "strides list size does not match the number of axes");
         }
-        SmallVector<Value> cstKernel, cstPadding, cstStrides;
+        SmallVector<Value> cstKernel, cstPadding, cstStridesDilations;
         for (int64_t i : kernel) {
           cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
               binder.getLoc(), rewriter.getI64IntegerAttr(i)));
@@ -454,9 +447,24 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
               binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
         }
         for (int64_t i : strides) {
-          cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>(
+          cstStridesDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
               binder.getLoc(), rewriter.getI64IntegerAttr(i)));
         }
+
+        // The PyTorch avgpool ops have no dilations attribute, so encode the
+        // dilations into the strides list here and let the subsequent
+        // TorchToLinalg lowering decode it back into strides + dilations:
+        // [strideDim1, strideDim2, ..., dilationDim1, dilationDim2, ...]
+        if (binder.s64IntegerArrayAttr(
+                dilations, "dilations",
+                llvm::SmallVector<int64_t>(rank - 2, 1))) {
+          return failure();
+        }
+        for (auto dilation : dilations) {
+          cstStridesDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getI64IntegerAttr(dilation)));
+        }
+
         Value kernelSizeList = rewriter.create<Torch::PrimListConstructOp>(
             binder.getLoc(),
             Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
@@ -465,10 +473,12 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
             binder.getLoc(),
             Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
             cstPadding);
-        Value stridesList = rewriter.create<Torch::PrimListConstructOp>(
-            binder.getLoc(),
-            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
-            cstStrides);
+        Value stridesDilationsList =
+            rewriter.create<Torch::PrimListConstructOp>(
+                binder.getLoc(),
+                Torch::ListType::get(
+                    Torch::IntType::get(binder.op->getContext())),
+                cstStridesDilations);
         Value cstCeilMode =
             rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), ceilMode);
         Value cstCountIncludePad = rewriter.create<Torch::ConstantBoolOp>(
@@ -477,19 +487,22 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         if (rank == 3) {
           rewriter.replaceOpWithNewOp<Torch::AtenAvgPool1dOp>(
-              binder.op, resultType, operand, kernelSizeList, stridesList,
-              paddingList, cstCeilMode, cstCountIncludePad);
+              binder.op, resultType, operand, kernelSizeList,
+              stridesDilationsList, paddingList, cstCeilMode,
+              cstCountIncludePad);
           return success();
         } else if (rank == 4) {
           rewriter.replaceOpWithNewOp<Torch::AtenAvgPool2dOp>(
-              binder.op, resultType, operand, kernelSizeList, stridesList,
-              paddingList, cstCeilMode, cstCountIncludePad,
+              binder.op, resultType, operand, kernelSizeList,
+              stridesDilationsList, paddingList, cstCeilMode,
+              cstCountIncludePad,
               /*divisor_override=*/cstNone);
           return success();
         } else if (rank == 5) {
           rewriter.replaceOpWithNewOp<Torch::AtenAvgPool3dOp>(
-              binder.op, resultType, operand, kernelSizeList, stridesList,
-              paddingList, cstCeilMode, cstCountIncludePad,
+              binder.op, resultType, operand, kernelSizeList,
+              stridesDilationsList, paddingList, cstCeilMode,
+              cstCountIncludePad,
               /*divisor_override=*/cstNone);
           return success();
         }
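
Note the default in the hunk above: when the ONNX node carries no dilations attribute, binder.s64IntegerArrayAttr falls back to rank - 2 ones, so an ordinary pool still lowers with unit dilations. A quick standalone check of the resulting layout (plain C++, illustrative only):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Rank-4 (NCHW) input: two spatial dims, so the default is two 1s.
  const int64_t rank = 4;
  std::vector<int64_t> dilations(rank - 2, /*value=*/1);
  std::vector<int64_t> packed{2, 2}; // strides come first
  packed.insert(packed.end(), dilations.begin(), dilations.end());
  assert((packed == std::vector<int64_t>{2, 2, 1, 1}));
}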

lib/Conversion/TorchToLinalg/Pooling.cpp

@@ -612,6 +612,16 @@ public:
                              strideInts, paddingInts)))
       return rewriter.notifyMatchFailure(op, "invalid pooling parameters");
 
+    // Decode a packed strides list back into strideInts and dilationInts.
+    if (strideInts.size() == 2 * Dim) {
+      for (int i = 0; i < Dim; i++) {
+        dilationInts[i] = strideInts[Dim + i];
+      }
+      for (int i = 0; i < Dim; i++) {
+        strideInts.pop_back();
+      }
+    }
+
     // TODO: Add support for count_include_pad equal to `False`.
     bool countIncludePad;
     if (!matchPattern(op.getCountIncludePad(),
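
The decode step only fires when the strides list holds 2 * Dim entries, so Dim-sized lists reaching this lowering from other frontends pass through untouched and keep their default dilations. A standalone rendering of that guard (plain C++; the resize is equivalent to the Dim pop_back calls in the patch):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const size_t Dim = 2;
  std::vector<int64_t> strideInts{2, 2};   // legacy, un-packed strides
  std::vector<int64_t> dilationInts{1, 1}; // defaults stay in effect
  if (strideInts.size() == 2 * Dim) {      // false here: nothing to decode
    for (size_t i = 0; i < Dim; i++)
      dilationInts[i] = strideInts[Dim + i];
    strideInts.resize(Dim);
  }
  assert((strideInts == std::vector<int64_t>{2, 2}));
  assert((dilationInts == std::vector<int64_t>{1, 1}));
}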