[Stablehlo] Fix AtenSumDimIntListOp when dim==None (#3216)

as titile
pull/3222/head
Xinyu Yang 2024-04-24 11:25:46 +08:00 committed by GitHub
parent 4da3d714cc
commit 42b9eccdb3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 11 additions and 5 deletions

View File

@ -700,11 +700,17 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
SmallVector<int64_t> inputDims;
SmallVector<int64_t> dims;
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) {
return rewriter.notifyMatchFailure(op, "non-int dim list unsupported");
}
if (inputDims.size() == 0) {
if (failed(checkNotNone(rewriter, op, op.getDim()))) {
inputDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputTy.getRank()));
} else {
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) {
return rewriter.notifyMatchFailure(
op, "non-const integer `dim` is not supported");
}
if (inputDims.size() == 0) {
inputDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputTy.getRank()));
}
}
for (auto d : inputDims) {

View File

@ -728,7 +728,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"MaxPool3dModule_basic",
"MaxPool3dStaticCeilModeTrueModule_basic",
"MaxPool3dStaticModule_basic",
"MeanDimNoneDimModule_basic",
"MseLossMeanReductionModule_basic",
"MseLossSumReductionWithDifferentElemTypeModule_basic",
"MulFloatModule_basic",
@ -1140,6 +1139,7 @@ STABLEHLO_PASS_SET = {
"MaxPool2dStaticModule_basic",
"MeanDimAllReduceModule_basic",
"MeanDimEmptyDimModule_basic",
"MeanDimNoneDimModule_basic",
"MeanDtypeModule_basic",
"MeanDynamicSizesModule_basic",
"MeanModule_basic",