mirror of https://github.com/llvm/torch-mlir
[ONNX] Fixes Issue with Dynamic Dims in GlobalAveragePool -> Torch Conversion (#3053)
Two e2e tests (AdaptiveAveragePool1/2dUnitOutputSizeDynamic) were failing due to numerics. This was as a result of passing -1 as the kernel size in the lowering for the corresponding onnx op GlobalAveragePool.pull/3073/head
parent
e6e7689a24
commit
c19fc9ba47
|
@ -1117,9 +1117,17 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(1));
|
||||
for (unsigned i = 2; i < inputRank; i++) {
|
||||
int64_t kernelSize = inputShape[i] - resultShape[i] + 1;
|
||||
cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(kernelSize)));
|
||||
if (inputShape[i] == Torch::kUnknownSize) {
|
||||
Value dim = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(i));
|
||||
Value inputDimSize = rewriter.create<Torch::AtenSizeIntOp>(
|
||||
binder.getLoc(), operand, dim);
|
||||
cstKernel.push_back(inputDimSize);
|
||||
} else {
|
||||
int64_t kernelSize = inputShape[i] - resultShape[i] + 1;
|
||||
cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(kernelSize)));
|
||||
}
|
||||
cstPadding.push_back(cstZero);
|
||||
cstStrides.push_back(cstOne);
|
||||
}
|
||||
|
|
|
@ -1485,8 +1485,6 @@ ONNX_XFAIL_SET = {
|
|||
"PermuteNegativeIndexModule_basic",
|
||||
|
||||
# Failure - incorrect numerics
|
||||
"AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic",
|
||||
"AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic",
|
||||
"ElementwiseAtan2TensorIntModule_basic",
|
||||
"ElementwiseLog10IntModule_basic",
|
||||
"ElementwiseLog2IntModule_basic",
|
||||
|
@ -1496,7 +1494,6 @@ ONNX_XFAIL_SET = {
|
|||
"HardsigmoidModule_basic",
|
||||
"HardsigmoidRandomModule_basic",
|
||||
"PixelShuffleModuleStaticRank4Float32_basic",
|
||||
"ResNet18Module_basic",
|
||||
"SliceCopyEndGreaterThanDimSize_Module_basic",
|
||||
"SliceCopyNegative_Module_basic",
|
||||
"SliceCopyNonZeroDim_Module_basic",
|
||||
|
|
Loading…
Reference in New Issue