Fix: 0 sizes tensor being regarded as unknown rank (#923)

pull/936/head
Tanyo Kwok 2022-06-14 09:58:50 +08:00 committed by GitHub
parent 0a7ba62438
commit 0d4445eaf9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 8 additions and 1 deletions

View File

@ -149,8 +149,15 @@ MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
auto shapeSymbol = symbolicShape[i];
dims[i] = shapeSymbol.is_static() ? shapeSymbol.static_size() : -1;
}
// `std::vector`'s `.data()` method can return nullptr when the
// size is 0. This triggers the "nothing known about sizes" case in
// the C API constructor, when we want the "we know we have 0 sizes"
// case. So use a dummy data pointer.
int64_t dummy;
int64_t *dimsData = dims.size() == 0 ? &dummy : dims.data();
return torchMlirTorchNonValueTensorTypeGet(context, dims.size(),
/*optionalSizes=*/dims.data(),
/*optionalSizes=*/dimsData,
/*optionalDtype=*/
elementType);
}