mirror of https://github.com/llvm/torch-mlir
[ONNX] Fixes Issue with Dynamic Dims in GlobalAveragePool -> Torch Conversion (#3053)
Two e2e tests (AdaptiveAveragePool1/2dUnitOutputSizeDynamic) were failing due to numerics. This was as a result of passing -1 as the kernel size in the lowering for the corresponding onnx op GlobalAveragePool.pull/3073/head
parent
e6e7689a24
commit
c19fc9ba47
|
@ -1117,9 +1117,17 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
||||||
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
|
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
|
||||||
binder.getLoc(), rewriter.getI64IntegerAttr(1));
|
binder.getLoc(), rewriter.getI64IntegerAttr(1));
|
||||||
for (unsigned i = 2; i < inputRank; i++) {
|
for (unsigned i = 2; i < inputRank; i++) {
|
||||||
int64_t kernelSize = inputShape[i] - resultShape[i] + 1;
|
if (inputShape[i] == Torch::kUnknownSize) {
|
||||||
cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
|
Value dim = rewriter.create<Torch::ConstantIntOp>(
|
||||||
binder.getLoc(), rewriter.getI64IntegerAttr(kernelSize)));
|
binder.getLoc(), rewriter.getI64IntegerAttr(i));
|
||||||
|
Value inputDimSize = rewriter.create<Torch::AtenSizeIntOp>(
|
||||||
|
binder.getLoc(), operand, dim);
|
||||||
|
cstKernel.push_back(inputDimSize);
|
||||||
|
} else {
|
||||||
|
int64_t kernelSize = inputShape[i] - resultShape[i] + 1;
|
||||||
|
cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getI64IntegerAttr(kernelSize)));
|
||||||
|
}
|
||||||
cstPadding.push_back(cstZero);
|
cstPadding.push_back(cstZero);
|
||||||
cstStrides.push_back(cstOne);
|
cstStrides.push_back(cstOne);
|
||||||
}
|
}
|
||||||
|
|
|
@ -1485,8 +1485,6 @@ ONNX_XFAIL_SET = {
|
||||||
"PermuteNegativeIndexModule_basic",
|
"PermuteNegativeIndexModule_basic",
|
||||||
|
|
||||||
# Failure - incorrect numerics
|
# Failure - incorrect numerics
|
||||||
"AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic",
|
|
||||||
"AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic",
|
|
||||||
"ElementwiseAtan2TensorIntModule_basic",
|
"ElementwiseAtan2TensorIntModule_basic",
|
||||||
"ElementwiseLog10IntModule_basic",
|
"ElementwiseLog10IntModule_basic",
|
||||||
"ElementwiseLog2IntModule_basic",
|
"ElementwiseLog2IntModule_basic",
|
||||||
|
@ -1496,7 +1494,6 @@ ONNX_XFAIL_SET = {
|
||||||
"HardsigmoidModule_basic",
|
"HardsigmoidModule_basic",
|
||||||
"HardsigmoidRandomModule_basic",
|
"HardsigmoidRandomModule_basic",
|
||||||
"PixelShuffleModuleStaticRank4Float32_basic",
|
"PixelShuffleModuleStaticRank4Float32_basic",
|
||||||
"ResNet18Module_basic",
|
|
||||||
"SliceCopyEndGreaterThanDimSize_Module_basic",
|
"SliceCopyEndGreaterThanDimSize_Module_basic",
|
||||||
"SliceCopyNegative_Module_basic",
|
"SliceCopyNegative_Module_basic",
|
||||||
"SliceCopyNonZeroDim_Module_basic",
|
"SliceCopyNonZeroDim_Module_basic",
|
||||||
|
|
Loading…
Reference in New Issue