mirror of https://github.com/llvm/torch-mlir
parent
4da3d714cc
commit
42b9eccdb3
|
@ -700,12 +700,18 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
|
|||
|
||||
SmallVector<int64_t> inputDims;
|
||||
SmallVector<int64_t> dims;
|
||||
|
||||
if (failed(checkNotNone(rewriter, op, op.getDim()))) {
|
||||
inputDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputTy.getRank()));
|
||||
} else {
|
||||
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) {
|
||||
return rewriter.notifyMatchFailure(op, "non-int dim list unsupported");
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "non-const integer `dim` is not supported");
|
||||
}
|
||||
if (inputDims.size() == 0) {
|
||||
inputDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputTy.getRank()));
|
||||
}
|
||||
}
|
||||
|
||||
for (auto d : inputDims) {
|
||||
d = toPositiveDim(d, inputTy.getRank());
|
||||
|
|
|
@ -728,7 +728,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
|||
"MaxPool3dModule_basic",
|
||||
"MaxPool3dStaticCeilModeTrueModule_basic",
|
||||
"MaxPool3dStaticModule_basic",
|
||||
"MeanDimNoneDimModule_basic",
|
||||
"MseLossMeanReductionModule_basic",
|
||||
"MseLossSumReductionWithDifferentElemTypeModule_basic",
|
||||
"MulFloatModule_basic",
|
||||
|
@ -1140,6 +1139,7 @@ STABLEHLO_PASS_SET = {
|
|||
"MaxPool2dStaticModule_basic",
|
||||
"MeanDimAllReduceModule_basic",
|
||||
"MeanDimEmptyDimModule_basic",
|
||||
"MeanDimNoneDimModule_basic",
|
||||
"MeanDtypeModule_basic",
|
||||
"MeanDynamicSizesModule_basic",
|
||||
"MeanModule_basic",
|
||||
|
|
Loading…
Reference in New Issue