diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index 42aacceab..e6ae601dc 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -795,17 +795,21 @@ public:
     auto resultType =
         typeConverter->convertType(op.getType()).cast<RankedTensorType>();
     int64_t resultRank = resultType.getRank();
-    if (resultRank == 0)
-      return rewriter.notifyMatchFailure(op,
-                                         "result shape of rank 0 is invalid");
+    if (resultRank == 0) {
+      rewriter
+          .replaceOpWithNewOp<tensor::CollapseShapeOp>(
+              op, resultType, input, ArrayRef<ReassociationIndices>())
+          .getResult();
+      return success();
+    }

     if (inputRank == 0) {
-      Value expanded =
-          rewriter
-              .create<tensor::ExpandShapeOp>(loc, resultType, input,
-                                             ArrayRef<ReassociationIndices>())
-              .getResult();
-      rewriter.replaceOp(op, expanded);
+      llvm::SmallVector<int64_t> outshape(resultRank, 1);
+      auto expandTy =
+          RankedTensorType::get(outshape, resultType.getElementType());
+      Value expand = rewriter.create<tensor::ExpandShapeOp>(
+          op.getLoc(), expandTy, input, ArrayRef<ReassociationIndices>());
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, expand);
       return success();
     }

diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index e2b198839..05ca09922 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -2144,7 +2144,6 @@ ONNX_XFAIL_SET = {
     "ReduceMaxUnsignedIntModule_basic",

     # Failure - torch.aten.view lower
-    "AddSizeIntModule_basic",
     "ElementwiseFlattenBroadcastModule_basic",
     "IndexTensorDyanmicInputContiguousWithNoneModule_basic",
     "IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
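
For context, a minimal PyTorch sketch (hypothetical, not the actual AddSizeIntModule_basic e2e test) of the two view shapes the patch above now lowers: a view producing a rank-0 result (handled via tensor.collapse_shape with an empty reassociation), and a view from a rank-0 input (handled via tensor.expand_shape to an all-ones static shape followed by tensor.cast to the converted result type):

    import torch

    x = torch.tensor([7])         # rank-1 tensor with a single element
    scalar = x.view([])           # rank-0 result: previously a match failure, now collapse_shape
    restored = scalar.view(1, 1)  # rank-0 input: expand_shape to all-ones shape, then cast

    print(scalar.shape, restored.shape)  # torch.Size([]) torch.Size([1, 1])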