[Stablehlo] Fix AtenSumDimIntListOp when dim==None (#3216)

as titile
pull/3222/head
Xinyu Yang 2024-04-24 11:25:46 +08:00 committed by GitHub
parent 4da3d714cc
commit 42b9eccdb3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 11 additions and 5 deletions

View File

@ -700,12 +700,18 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
SmallVector<int64_t> inputDims; SmallVector<int64_t> inputDims;
SmallVector<int64_t> dims; SmallVector<int64_t> dims;
if (failed(checkNotNone(rewriter, op, op.getDim()))) {
inputDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputTy.getRank()));
} else {
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) { if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) {
return rewriter.notifyMatchFailure(op, "non-int dim list unsupported"); return rewriter.notifyMatchFailure(
op, "non-const integer `dim` is not supported");
} }
if (inputDims.size() == 0) { if (inputDims.size() == 0) {
inputDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputTy.getRank())); inputDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputTy.getRank()));
} }
}
for (auto d : inputDims) { for (auto d : inputDims) {
d = toPositiveDim(d, inputTy.getRank()); d = toPositiveDim(d, inputTy.getRank());

View File

@ -728,7 +728,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"MaxPool3dModule_basic", "MaxPool3dModule_basic",
"MaxPool3dStaticCeilModeTrueModule_basic", "MaxPool3dStaticCeilModeTrueModule_basic",
"MaxPool3dStaticModule_basic", "MaxPool3dStaticModule_basic",
"MeanDimNoneDimModule_basic",
"MseLossMeanReductionModule_basic", "MseLossMeanReductionModule_basic",
"MseLossSumReductionWithDifferentElemTypeModule_basic", "MseLossSumReductionWithDifferentElemTypeModule_basic",
"MulFloatModule_basic", "MulFloatModule_basic",
@ -1140,6 +1139,7 @@ STABLEHLO_PASS_SET = {
"MaxPool2dStaticModule_basic", "MaxPool2dStaticModule_basic",
"MeanDimAllReduceModule_basic", "MeanDimAllReduceModule_basic",
"MeanDimEmptyDimModule_basic", "MeanDimEmptyDimModule_basic",
"MeanDimNoneDimModule_basic",
"MeanDtypeModule_basic", "MeanDtypeModule_basic",
"MeanDynamicSizesModule_basic", "MeanDynamicSizesModule_basic",
"MeanModule_basic", "MeanModule_basic",