From 0d4445eaf9e270207b3670aa3eb4316e489c59d4 Mon Sep 17 00:00:00 2001 From: Tanyo Kwok Date: Tue, 14 Jun 2022 09:58:50 +0800 Subject: [PATCH] Fix: 0 sizes tensor being regarded as unknown rank (#923) --- .../torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp index 8a69a73a5..c023959fe 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp @@ -149,8 +149,15 @@ MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc, auto shapeSymbol = symbolicShape[i]; dims[i] = shapeSymbol.is_static() ? shapeSymbol.static_size() : -1; } + + // `std::vector`'s `.data()` method can return nullptr when the + // size is 0. This triggers the "nothing known about sizes" case in + // the C API constructor, when we want the "we know we have 0 sizes" + // case. So use a dummy data pointer. + int64_t dummy; + int64_t *dimsData = dims.size() == 0 ? &dummy : dims.data(); return torchMlirTorchNonValueTensorTypeGet(context, dims.size(), - /*optionalSizes=*/dims.data(), + /*optionalSizes=*/dimsData, /*optionalDtype=*/ elementType); }