mirror of https://github.com/llvm/torch-mlir
[ONNX] Add averagepool dilations support (#3490)
- To fix dilations issue: https://github.com/llvm/torch-mlir/issues/3428 - Test by: https://github.com/nod-ai/SHARK-TestSuite/pull/268pull/3494/head
parent
98c6971a01
commit
fc19709daa
|
@ -379,7 +379,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
"AveragePool", 11,
|
"AveragePool", 11,
|
||||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
std::string autoPad;
|
std::string autoPad;
|
||||||
SmallVector<int64_t> dilation;
|
SmallVector<int64_t> dilations;
|
||||||
if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
|
if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
|
||||||
return failure();
|
return failure();
|
||||||
if (autoPad != "NOTSET") {
|
if (autoPad != "NOTSET") {
|
||||||
|
@ -387,13 +387,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
binder.op, "unsupported conversion: auto_pad != NOTSET");
|
binder.op, "unsupported conversion: auto_pad != NOTSET");
|
||||||
}
|
}
|
||||||
if (binder.s64IntegerArrayAttr(dilation, "dilations", {})) {
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
if (dilation.size() > 0) {
|
|
||||||
return rewriter.notifyMatchFailure(
|
|
||||||
binder.op, "dilation is not supported by torch.aten.avgpool op");
|
|
||||||
}
|
|
||||||
|
|
||||||
Torch::ValueTensorType resultType;
|
Torch::ValueTensorType resultType;
|
||||||
Value operand;
|
Value operand;
|
||||||
|
@ -436,7 +429,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
binder.op, "strides list size does not match the number of axes");
|
binder.op, "strides list size does not match the number of axes");
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<Value> cstKernel, cstPadding, cstStrides;
|
SmallVector<Value> cstKernel, cstPadding, cstStridesDilations;
|
||||||
for (int64_t i : kernel) {
|
for (int64_t i : kernel) {
|
||||||
cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
|
cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||||
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
|
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
|
||||||
|
@ -454,9 +447,24 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
|
binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
|
||||||
}
|
}
|
||||||
for (int64_t i : strides) {
|
for (int64_t i : strides) {
|
||||||
cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>(
|
cstStridesDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||||
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
|
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// No dilations attribute in pytorch avgpool op, so use this trick to
|
||||||
|
// encode dilation into strides. Then in the following torchtolinalg
|
||||||
|
// lowering, decode strides into strides + dilation.
|
||||||
|
// [strideDim1,strideDim2,...,dilationDim1,dilationDim2,...]
|
||||||
|
if (binder.s64IntegerArrayAttr(
|
||||||
|
dilations, "dilations",
|
||||||
|
llvm::SmallVector<int64_t>(rank - 2, 1))) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
for (auto dilation : dilations) {
|
||||||
|
cstStridesDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getI64IntegerAttr(dilation)));
|
||||||
|
}
|
||||||
|
|
||||||
Value kernelSizeList = rewriter.create<Torch::PrimListConstructOp>(
|
Value kernelSizeList = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
binder.getLoc(),
|
binder.getLoc(),
|
||||||
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
||||||
|
@ -465,10 +473,12 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
binder.getLoc(),
|
binder.getLoc(),
|
||||||
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
||||||
cstPadding);
|
cstPadding);
|
||||||
Value stridesList = rewriter.create<Torch::PrimListConstructOp>(
|
Value stridesDilationsList =
|
||||||
|
rewriter.create<Torch::PrimListConstructOp>(
|
||||||
binder.getLoc(),
|
binder.getLoc(),
|
||||||
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
Torch::ListType::get(
|
||||||
cstStrides);
|
Torch::IntType::get(binder.op->getContext())),
|
||||||
|
cstStridesDilations);
|
||||||
Value cstCeilMode =
|
Value cstCeilMode =
|
||||||
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), ceilMode);
|
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), ceilMode);
|
||||||
Value cstCountIncludePad = rewriter.create<Torch::ConstantBoolOp>(
|
Value cstCountIncludePad = rewriter.create<Torch::ConstantBoolOp>(
|
||||||
|
@ -477,19 +487,22 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
|
|
||||||
if (rank == 3) {
|
if (rank == 3) {
|
||||||
rewriter.replaceOpWithNewOp<Torch::AtenAvgPool1dOp>(
|
rewriter.replaceOpWithNewOp<Torch::AtenAvgPool1dOp>(
|
||||||
binder.op, resultType, operand, kernelSizeList, stridesList,
|
binder.op, resultType, operand, kernelSizeList,
|
||||||
paddingList, cstCeilMode, cstCountIncludePad);
|
stridesDilationsList, paddingList, cstCeilMode,
|
||||||
|
cstCountIncludePad);
|
||||||
return success();
|
return success();
|
||||||
} else if (rank == 4) {
|
} else if (rank == 4) {
|
||||||
rewriter.replaceOpWithNewOp<Torch::AtenAvgPool2dOp>(
|
rewriter.replaceOpWithNewOp<Torch::AtenAvgPool2dOp>(
|
||||||
binder.op, resultType, operand, kernelSizeList, stridesList,
|
binder.op, resultType, operand, kernelSizeList,
|
||||||
paddingList, cstCeilMode, cstCountIncludePad,
|
stridesDilationsList, paddingList, cstCeilMode,
|
||||||
|
cstCountIncludePad,
|
||||||
/*divisor_override=*/cstNone);
|
/*divisor_override=*/cstNone);
|
||||||
return success();
|
return success();
|
||||||
} else if (rank == 5) {
|
} else if (rank == 5) {
|
||||||
rewriter.replaceOpWithNewOp<Torch::AtenAvgPool3dOp>(
|
rewriter.replaceOpWithNewOp<Torch::AtenAvgPool3dOp>(
|
||||||
binder.op, resultType, operand, kernelSizeList, stridesList,
|
binder.op, resultType, operand, kernelSizeList,
|
||||||
paddingList, cstCeilMode, cstCountIncludePad,
|
stridesDilationsList, paddingList, cstCeilMode,
|
||||||
|
cstCountIncludePad,
|
||||||
/*divisor_override=*/cstNone);
|
/*divisor_override=*/cstNone);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
|
@ -612,6 +612,16 @@ public:
|
||||||
strideInts, paddingInts)))
|
strideInts, paddingInts)))
|
||||||
return rewriter.notifyMatchFailure(op, "invalid pooling parameters");
|
return rewriter.notifyMatchFailure(op, "invalid pooling parameters");
|
||||||
|
|
||||||
|
// Decode strideInts into strideInts and dilation
|
||||||
|
if (strideInts.size() == 2 * Dim) {
|
||||||
|
for (int i = 0; i < Dim; i++) {
|
||||||
|
dilationInts[i] = strideInts[Dim + i];
|
||||||
|
}
|
||||||
|
for (int i = 0; i < Dim; i++) {
|
||||||
|
strideInts.pop_back();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: Add support for count_include_pad equal to `False`.
|
// TODO: Add support for count_include_pad equal to `False`.
|
||||||
bool countIncludePad;
|
bool countIncludePad;
|
||||||
if (!matchPattern(op.getCountIncludePad(),
|
if (!matchPattern(op.getCountIncludePad(),
|
||||||
|
|
Loading…
Reference in New Issue