mirror of https://github.com/llvm/torch-mlir
Fix: 0 sizes tensor being regarded as unknown rank (#923)
parent
0a7ba62438
commit
0d4445eaf9
|
@ -149,8 +149,15 @@ MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
|
|||
auto shapeSymbol = symbolicShape[i];
|
||||
dims[i] = shapeSymbol.is_static() ? shapeSymbol.static_size() : -1;
|
||||
}
|
||||
|
||||
// `std::vector`'s `.data()` method can return nullptr when the
|
||||
// size is 0. This triggers the "nothing known about sizes" case in
|
||||
// the C API constructor, when we want the "we know we have 0 sizes"
|
||||
// case. So use a dummy data pointer.
|
||||
int64_t dummy;
|
||||
int64_t *dimsData = dims.size() == 0 ? &dummy : dims.data();
|
||||
return torchMlirTorchNonValueTensorTypeGet(context, dims.size(),
|
||||
/*optionalSizes=*/dims.data(),
|
||||
/*optionalSizes=*/dimsData,
|
||||
/*optionalDtype=*/
|
||||
elementType);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue