mirror of https://github.com/llvm/torch-mlir
parent
4da3d714cc
commit
42b9eccdb3
|
@ -700,11 +700,17 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
|
||||||
|
|
||||||
SmallVector<int64_t> inputDims;
|
SmallVector<int64_t> inputDims;
|
||||||
SmallVector<int64_t> dims;
|
SmallVector<int64_t> dims;
|
||||||
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) {
|
|
||||||
return rewriter.notifyMatchFailure(op, "non-int dim list unsupported");
|
if (failed(checkNotNone(rewriter, op, op.getDim()))) {
|
||||||
}
|
|
||||||
if (inputDims.size() == 0) {
|
|
||||||
inputDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputTy.getRank()));
|
inputDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputTy.getRank()));
|
||||||
|
} else {
|
||||||
|
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "non-const integer `dim` is not supported");
|
||||||
|
}
|
||||||
|
if (inputDims.size() == 0) {
|
||||||
|
inputDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputTy.getRank()));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto d : inputDims) {
|
for (auto d : inputDims) {
|
||||||
|
|
|
@ -728,7 +728,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
||||||
"MaxPool3dModule_basic",
|
"MaxPool3dModule_basic",
|
||||||
"MaxPool3dStaticCeilModeTrueModule_basic",
|
"MaxPool3dStaticCeilModeTrueModule_basic",
|
||||||
"MaxPool3dStaticModule_basic",
|
"MaxPool3dStaticModule_basic",
|
||||||
"MeanDimNoneDimModule_basic",
|
|
||||||
"MseLossMeanReductionModule_basic",
|
"MseLossMeanReductionModule_basic",
|
||||||
"MseLossSumReductionWithDifferentElemTypeModule_basic",
|
"MseLossSumReductionWithDifferentElemTypeModule_basic",
|
||||||
"MulFloatModule_basic",
|
"MulFloatModule_basic",
|
||||||
|
@ -1140,6 +1139,7 @@ STABLEHLO_PASS_SET = {
|
||||||
"MaxPool2dStaticModule_basic",
|
"MaxPool2dStaticModule_basic",
|
||||||
"MeanDimAllReduceModule_basic",
|
"MeanDimAllReduceModule_basic",
|
||||||
"MeanDimEmptyDimModule_basic",
|
"MeanDimEmptyDimModule_basic",
|
||||||
|
"MeanDimNoneDimModule_basic",
|
||||||
"MeanDtypeModule_basic",
|
"MeanDtypeModule_basic",
|
||||||
"MeanDynamicSizesModule_basic",
|
"MeanDynamicSizesModule_basic",
|
||||||
"MeanModule_basic",
|
"MeanModule_basic",
|
||||||
|
|
Loading…
Reference in New Issue