[torch] Support `aten.view` rank-0 collapse (#2965)

Collapsing to a rank-0 tensor using `aten.view` was currently bailing
out. Added the special case.
pull/2975/head
Rob Suderman 2024-03-01 12:31:07 -08:00 committed by GitHub
parent e7d90a4b82
commit d030bffc62
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 13 additions and 10 deletions

View File

@ -795,17 +795,21 @@ public:
auto resultType =
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
int64_t resultRank = resultType.getRank();
if (resultRank == 0)
return rewriter.notifyMatchFailure(op,
"result shape of rank 0 is invalid");
if (resultRank == 0) {
rewriter
.replaceOpWithNewOp<tensor::CollapseShapeOp>(
op, resultType, input, ArrayRef<ReassociationIndices>())
.getResult();
return success();
}
if (inputRank == 0) {
Value expanded =
rewriter
.create<tensor::ExpandShapeOp>(loc, resultType, input,
ArrayRef<ReassociationIndices>())
.getResult();
rewriter.replaceOp(op, expanded);
llvm::SmallVector<int64_t> outshape(resultRank, 1);
auto expandTy =
RankedTensorType::get(outshape, resultType.getElementType());
Value expand = rewriter.create<tensor::ExpandShapeOp>(
op.getLoc(), expandTy, input, ArrayRef<ReassociationIndices>());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, expand);
return success();
}

View File

@ -2144,7 +2144,6 @@ ONNX_XFAIL_SET = {
"ReduceMaxUnsignedIntModule_basic",
# Failure - torch.aten.view lower
"AddSizeIntModule_basic",
"ElementwiseFlattenBroadcastModule_basic",
"IndexTensorDyanmicInputContiguousWithNoneModule_basic",
"IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",