From 42b9eccdb3f51811ea9c1ec2f89379c3e64824f7 Mon Sep 17 00:00:00 2001 From: Xinyu Yang Date: Wed, 24 Apr 2024 11:25:46 +0800 Subject: [PATCH] [Stablehlo] Fix AtenSumDimIntListOp when dim==None (#3216) as titile --- lib/Conversion/TorchToStablehlo/Reduction.cpp | 14 ++++++++++---- projects/pt1/e2e_testing/xfail_sets.py | 2 +- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp index 367acefc9..1e494f433 100644 --- a/lib/Conversion/TorchToStablehlo/Reduction.cpp +++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp @@ -700,11 +700,17 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( SmallVector inputDims; SmallVector dims; - if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) { - return rewriter.notifyMatchFailure(op, "non-int dim list unsupported"); - } - if (inputDims.size() == 0) { + + if (failed(checkNotNone(rewriter, op, op.getDim()))) { inputDims = llvm::to_vector<4>(llvm::seq(0, inputTy.getRank())); + } else { + if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) { + return rewriter.notifyMatchFailure( + op, "non-const integer `dim` is not supported"); + } + if (inputDims.size() == 0) { + inputDims = llvm::to_vector<4>(llvm::seq(0, inputTy.getRank())); + } } for (auto d : inputDims) { diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 4657bc0a4..1f16a25a9 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -728,7 +728,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = { "MaxPool3dModule_basic", "MaxPool3dStaticCeilModeTrueModule_basic", "MaxPool3dStaticModule_basic", - "MeanDimNoneDimModule_basic", "MseLossMeanReductionModule_basic", "MseLossSumReductionWithDifferentElemTypeModule_basic", "MulFloatModule_basic", @@ -1140,6 +1139,7 @@ STABLEHLO_PASS_SET = { "MaxPool2dStaticModule_basic", "MeanDimAllReduceModule_basic", "MeanDimEmptyDimModule_basic", + "MeanDimNoneDimModule_basic", "MeanDtypeModule_basic", "MeanDynamicSizesModule_basic", "MeanModule_basic",