From d030bffc624860b57d43dc918e3bd2a55d33e077 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 1 Mar 2024 12:31:07 -0800 Subject: [PATCH] [torch] Support `aten.view` rank-0 collapse (#2965) Collapsing to a rank-0 tensor using `aten.view` was currently bailing out. Added the special case. --- lib/Conversion/TorchToLinalg/DataMovement.cpp | 22 +++++++++++-------- projects/pt1/e2e_testing/xfail_sets.py | 1 - 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 42aacceab..e6ae601dc 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -795,17 +795,21 @@ public: auto resultType = typeConverter->convertType(op.getType()).cast(); int64_t resultRank = resultType.getRank(); - if (resultRank == 0) - return rewriter.notifyMatchFailure(op, - "result shape of rank 0 is invalid"); + if (resultRank == 0) { + rewriter + .replaceOpWithNewOp( + op, resultType, input, ArrayRef()) + .getResult(); + return success(); + } if (inputRank == 0) { - Value expanded = - rewriter - .create(loc, resultType, input, - ArrayRef()) - .getResult(); - rewriter.replaceOp(op, expanded); + llvm::SmallVector outshape(resultRank, 1); + auto expandTy = + RankedTensorType::get(outshape, resultType.getElementType()); + Value expand = rewriter.create( + op.getLoc(), expandTy, input, ArrayRef()); + rewriter.replaceOpWithNewOp(op, resultType, expand); return success(); } diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index e2b198839..05ca09922 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2144,7 +2144,6 @@ ONNX_XFAIL_SET = { "ReduceMaxUnsignedIntModule_basic", # Failure - torch.aten.view lower - "AddSizeIntModule_basic", "ElementwiseFlattenBroadcastModule_basic", "IndexTensorDyanmicInputContiguousWithNoneModule_basic", "IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",