diff --git a/include/torch-mlir-c/TorchTypes.h b/include/torch-mlir-c/TorchTypes.h index 8cff8da86..f459960ee 100644 --- a/include/torch-mlir-c/TorchTypes.h +++ b/include/torch-mlir-c/TorchTypes.h @@ -164,6 +164,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNonValueTensor(MlirType t); /// Gets a !torch.tensor type. /// +/// - `numSizes` having a value of -1 denotes an unranked tensor. /// - `optionalSizes` is allowed to be null, meaning that no size /// information is present (and `numSizes` is ignored in that case). - /// `optionalDtype` is allowed to be null, meaning that no dtype @@ -190,6 +191,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchValueTensor(MlirType t); /// Gets a !torch.vtensor type. /// +/// - `numSizes` having a value of -1 denotes an unranked tensor. /// - `optionalSizes` is allowed to be null, meaning that no size /// information is present (and `numSizes` is ignored in that case). /// - `optionalDtype` is allowed to be null, meaning that no dtype diff --git a/lib/CAPI/TorchTypes.cpp b/lib/CAPI/TorchTypes.cpp index 0c67453f3..6d72e7e15 100644 --- a/lib/CAPI/TorchTypes.cpp +++ b/lib/CAPI/TorchTypes.cpp @@ -199,7 +199,8 @@ MlirType torchMlirTorchNonValueTensorTypeGet(MlirContext context, const int64_t *optionalSizes, MlirType optionalDtype) { Optional> optionalSizesArrayRef = None; - if (optionalSizes) + // if numSizes == -1, then it is unranked. + if (numSizes > -1) optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes); return wrap(Torch::NonValueTensorType::get( unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype))); @@ -231,7 +232,8 @@ MlirType torchMlirTorchValueTensorTypeGet(MlirContext context, const int64_t *optionalSizes, MlirType optionalDtype) { Optional> optionalSizesArrayRef = None; - if (optionalSizes) + // if numSizes == -1, then it is unranked. + if (numSizes > -1) optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes); return wrap(Torch::ValueTensorType::get( unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype))); diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp index f939a7a07..7faa5e98d 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp @@ -13,6 +13,7 @@ #include #include +#include #include #include "../../dialects/torch/importer/jit_ir/csrc/function_importer.h" @@ -108,6 +109,10 @@ void TorchMlirLoweringContext::AddParameter( ComputationPtr TorchMlirLoweringContext::Build() { PRINT_FUNCTION(); + // Since we mutated the types of some nodes to insert shape information, we + // must perform this pass to ensure tuples have up to date output types. + torch::jit::RefineTupleTypes(graph_); + // Insert return values into graph. for (torch::jit::Value* output : root_tuple_) { graph_->block()->registerOutput(output); diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp index 3cd4ae264..3da29416e 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp @@ -144,7 +144,7 @@ torch_mlir::getMlirTypeFromTorchType(MlirLocation loc, if (!sizes.rank()) { // Unranked. return getMlirTensorType(context, - /*numSizes=*/0, + /*numSizes=*/-1, /*optionalSizes=*/nullptr, /*optionalDtype=*/ elementType);