mirror of https://github.com/llvm/torch-mlir
[torch] Support `aten.view` rank-0 collapse (#2965)
Collapsing to a rank-0 tensor using `aten.view` was currently bailing out. Added the special case.pull/2975/head
parent
e7d90a4b82
commit
d030bffc62
|
@ -795,17 +795,21 @@ public:
|
|||
auto resultType =
|
||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
||||
int64_t resultRank = resultType.getRank();
|
||||
if (resultRank == 0)
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"result shape of rank 0 is invalid");
|
||||
if (resultRank == 0) {
|
||||
rewriter
|
||||
.replaceOpWithNewOp<tensor::CollapseShapeOp>(
|
||||
op, resultType, input, ArrayRef<ReassociationIndices>())
|
||||
.getResult();
|
||||
return success();
|
||||
}
|
||||
|
||||
if (inputRank == 0) {
|
||||
Value expanded =
|
||||
rewriter
|
||||
.create<tensor::ExpandShapeOp>(loc, resultType, input,
|
||||
ArrayRef<ReassociationIndices>())
|
||||
.getResult();
|
||||
rewriter.replaceOp(op, expanded);
|
||||
llvm::SmallVector<int64_t> outshape(resultRank, 1);
|
||||
auto expandTy =
|
||||
RankedTensorType::get(outshape, resultType.getElementType());
|
||||
Value expand = rewriter.create<tensor::ExpandShapeOp>(
|
||||
op.getLoc(), expandTy, input, ArrayRef<ReassociationIndices>());
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, expand);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -2144,7 +2144,6 @@ ONNX_XFAIL_SET = {
|
|||
"ReduceMaxUnsignedIntModule_basic",
|
||||
|
||||
# Failure - torch.aten.view lower
|
||||
"AddSizeIntModule_basic",
|
||||
"ElementwiseFlattenBroadcastModule_basic",
|
||||
"IndexTensorDyanmicInputContiguousWithNoneModule_basic",
|
||||
"IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
|
||||
|
|
Loading…
Reference in New Issue