Add static shape for scalar tensors (#833)

* Assume zero rank tensors are scalar

* Run RefineTypes pass on JIT Graph

* Rollback assumption that zero rank tensors are scalar

* Set numSizes to -1 for non-ranked tensors

* Rename RefineTypes to RefineTupleTypes
pull/1125/head
Henry Tu 2022-05-11 21:00:06 -04:00 committed by Henry Tu
parent de5b380143
commit 0c35e607b3
4 changed files with 12 additions and 3 deletions

View File

@ -164,6 +164,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNonValueTensor(MlirType t);
/// Gets a !torch.tensor type.
///
/// - `numSizes` having a value of -1 denotes an unranked tensor.
/// - `optionalSizes` is allowed to be null, meaning that no size
/// information is present (and `numSizes` is ignored in that case).
/// - `optionalDtype` is allowed to be null, meaning that no dtype
@ -190,6 +191,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchValueTensor(MlirType t);
/// Gets a !torch.vtensor type.
///
/// - `numSizes` having a value of -1 denotes an unranked tensor.
/// - `optionalSizes` is allowed to be null, meaning that no size
/// information is present (and `numSizes` is ignored in that case).
/// - `optionalDtype` is allowed to be null, meaning that no dtype

View File

@ -199,7 +199,8 @@ MlirType torchMlirTorchNonValueTensorTypeGet(MlirContext context,
const int64_t *optionalSizes,
MlirType optionalDtype) {
Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
if (optionalSizes)
// If numSizes == -1, the tensor is unranked.
if (numSizes > -1)
optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
return wrap(Torch::NonValueTensorType::get(
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
@ -231,7 +232,8 @@ MlirType torchMlirTorchValueTensorTypeGet(MlirContext context,
const int64_t *optionalSizes,
MlirType optionalDtype) {
Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
if (optionalSizes)
// If numSizes == -1, the tensor is unranked.
if (numSizes > -1)
optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
return wrap(Torch::ValueTensorType::get(
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));

View File

@ -13,6 +13,7 @@
#include <iostream>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/passes/refine_tuple_types.h>
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
#include "../../dialects/torch/importer/jit_ir/csrc/function_importer.h"
@ -108,6 +109,10 @@ void TorchMlirLoweringContext::AddParameter(
ComputationPtr TorchMlirLoweringContext::Build() {
PRINT_FUNCTION();
// Since we mutated the types of some nodes to insert shape information, we
// must perform this pass to ensure tuples have up-to-date output types.
torch::jit::RefineTupleTypes(graph_);
// Insert return values into graph.
for (torch::jit::Value* output : root_tuple_) {
graph_->block()->registerOutput(output);

View File

@ -144,7 +144,7 @@ torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
if (!sizes.rank()) {
// Unranked.
return getMlirTensorType(context,
/*numSizes=*/0,
/*numSizes=*/-1,
/*optionalSizes=*/nullptr,
/*optionalDtype=*/
elementType);