[ONNX] Fixes Issue with Dynamic Dims in GlobalAveragePool -> Torch Conversion (#3053)

Two e2e tests (AdaptiveAveragePool1/2dUnitOutputSizeDynamic) were
failing due to numerics. This was as a result of passing -1 as the
kernel size in the lowering for the corresponding onnx op
GlobalAveragePool.
pull/3073/head
zjgarvey 2024-03-28 11:43:09 -05:00 committed by GitHub
parent e6e7689a24
commit c19fc9ba47
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 11 additions and 6 deletions

View File

@ -1117,9 +1117,17 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
for (unsigned i = 2; i < inputRank; i++) {
int64_t kernelSize = inputShape[i] - resultShape[i] + 1;
cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(kernelSize)));
if (inputShape[i] == Torch::kUnknownSize) {
Value dim = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(i));
Value inputDimSize = rewriter.create<Torch::AtenSizeIntOp>(
binder.getLoc(), operand, dim);
cstKernel.push_back(inputDimSize);
} else {
int64_t kernelSize = inputShape[i] - resultShape[i] + 1;
cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(kernelSize)));
}
cstPadding.push_back(cstZero);
cstStrides.push_back(cstOne);
}

View File

@ -1485,8 +1485,6 @@ ONNX_XFAIL_SET = {
"PermuteNegativeIndexModule_basic",
# Failure - incorrect numerics
"AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic",
"ElementwiseAtan2TensorIntModule_basic",
"ElementwiseLog10IntModule_basic",
"ElementwiseLog2IntModule_basic",
@ -1496,7 +1494,6 @@ ONNX_XFAIL_SET = {
"HardsigmoidModule_basic",
"HardsigmoidRandomModule_basic",
"PixelShuffleModuleStaticRank4Float32_basic",
"ResNet18Module_basic",
"SliceCopyEndGreaterThanDimSize_Module_basic",
"SliceCopyNegative_Module_basic",
"SliceCopyNonZeroDim_Module_basic",