[torch] Add edgecase for aten.shape_to_tensor for rank-0 input (#2962)

Currently lowering uses `tensor.from_elements` which does not allow zero
inputs. In this case we return a `tensor.empty` operation.
pull/2958/head
Rob Suderman 2024-02-28 09:47:06 -08:00 committed by GitHub
parent 08bc013fcd
commit dd673cfa8d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 6 additions and 0 deletions

View File

@ -84,6 +84,12 @@ public:
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>(); getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
int64_t rank = operandTy.getRank(); int64_t rank = operandTy.getRank();
if (rank == 0) {
rewriter.replaceOpWithNewOp<tensor::EmptyOp>(op, resultTy.getShape(),
resultTy.getElementType());
return success();
}
SmallVector<Value> dims; SmallVector<Value> dims;
for (int i = 0; i < rank; ++i) { for (int i = 0; i < rank; ++i) {
Value dim = rewriter.createOrFold<tensor::DimOp>(loc, operand, i); Value dim = rewriter.createOrFold<tensor::DimOp>(loc, operand, i);