[Stablehlo] Fix AtenSumDimIntListOp when dim==None (#3216)

as titile
pull/3222/head
Xinyu Yang 2024-04-24 11:25:46 +08:00 committed by GitHub
parent 4da3d714cc
commit 42b9eccdb3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 11 additions and 5 deletions

View File

@ -700,11 +700,17 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
SmallVector<int64_t> inputDims; SmallVector<int64_t> inputDims;
SmallVector<int64_t> dims; SmallVector<int64_t> dims;
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) {
return rewriter.notifyMatchFailure(op, "non-int dim list unsupported"); if (failed(checkNotNone(rewriter, op, op.getDim()))) {
}
if (inputDims.size() == 0) {
inputDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputTy.getRank())); inputDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputTy.getRank()));
} else {
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) {
return rewriter.notifyMatchFailure(
op, "non-const integer `dim` is not supported");
}
if (inputDims.size() == 0) {
inputDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputTy.getRank()));
}
} }
for (auto d : inputDims) { for (auto d : inputDims) {

View File

@ -728,7 +728,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"MaxPool3dModule_basic", "MaxPool3dModule_basic",
"MaxPool3dStaticCeilModeTrueModule_basic", "MaxPool3dStaticCeilModeTrueModule_basic",
"MaxPool3dStaticModule_basic", "MaxPool3dStaticModule_basic",
"MeanDimNoneDimModule_basic",
"MseLossMeanReductionModule_basic", "MseLossMeanReductionModule_basic",
"MseLossSumReductionWithDifferentElemTypeModule_basic", "MseLossSumReductionWithDifferentElemTypeModule_basic",
"MulFloatModule_basic", "MulFloatModule_basic",
@ -1140,6 +1139,7 @@ STABLEHLO_PASS_SET = {
"MaxPool2dStaticModule_basic", "MaxPool2dStaticModule_basic",
"MeanDimAllReduceModule_basic", "MeanDimAllReduceModule_basic",
"MeanDimEmptyDimModule_basic", "MeanDimEmptyDimModule_basic",
"MeanDimNoneDimModule_basic",
"MeanDtypeModule_basic", "MeanDtypeModule_basic",
"MeanDynamicSizesModule_basic", "MeanDynamicSizesModule_basic",
"MeanModule_basic", "MeanModule_basic",