Add static shape for scalar tensors (#833)

* Assume zero rank tensors are scalar

* Run RefineTypes pass on JIT Graph

* Rollback assumption that zero rank tensors are scalar

* Set numSizes to -1 for non-ranked tensors

* Rename RefineTypes to RefineTupleTypes
pull/1125/head
Henry Tu 2022-05-11 21:00:06 -04:00 committed by Henry Tu
parent de5b380143
commit 0c35e607b3
4 changed files with 12 additions and 3 deletions

View File

@ -164,6 +164,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNonValueTensor(MlirType t);
/// Gets a !torch.tensor type. /// Gets a !torch.tensor type.
/// ///
/// - `numSizes` having a value of -1 denotes an unranked tensor.
/// - `optionalSizes` is allowed to be null, meaning that no size /// - `optionalSizes` is allowed to be null, meaning that no size
/// information is present (and `numSizes` is ignored in that case). - /// information is present (and `numSizes` is ignored in that case). -
/// `optionalDtype` is allowed to be null, meaning that no dtype /// `optionalDtype` is allowed to be null, meaning that no dtype
@ -190,6 +191,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchValueTensor(MlirType t);
/// Gets a !torch.vtensor type. /// Gets a !torch.vtensor type.
/// ///
/// - `numSizes` having a value of -1 denotes an unranked tensor.
/// - `optionalSizes` is allowed to be null, meaning that no size /// - `optionalSizes` is allowed to be null, meaning that no size
/// information is present (and `numSizes` is ignored in that case). /// information is present (and `numSizes` is ignored in that case).
/// - `optionalDtype` is allowed to be null, meaning that no dtype /// - `optionalDtype` is allowed to be null, meaning that no dtype

View File

@ -199,7 +199,8 @@ MlirType torchMlirTorchNonValueTensorTypeGet(MlirContext context,
const int64_t *optionalSizes, const int64_t *optionalSizes,
MlirType optionalDtype) { MlirType optionalDtype) {
Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None; Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
if (optionalSizes) // if numSizes == -1, then it is unranked.
if (numSizes > -1)
optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes); optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
return wrap(Torch::NonValueTensorType::get( return wrap(Torch::NonValueTensorType::get(
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype))); unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
@ -231,7 +232,8 @@ MlirType torchMlirTorchValueTensorTypeGet(MlirContext context,
const int64_t *optionalSizes, const int64_t *optionalSizes,
MlirType optionalDtype) { MlirType optionalDtype) {
Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None; Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
if (optionalSizes) // if numSizes == -1, then it is unranked.
if (numSizes > -1)
optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes); optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
return wrap(Torch::ValueTensorType::get( return wrap(Torch::ValueTensorType::get(
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype))); unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));

View File

@ -13,6 +13,7 @@
#include <iostream> #include <iostream>
#include <torch/csrc/jit/api/compilation_unit.h> #include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/passes/refine_tuple_types.h>
#include <torch/csrc/lazy/core/lazy_graph_executor.h> #include <torch/csrc/lazy/core/lazy_graph_executor.h>
#include "../../dialects/torch/importer/jit_ir/csrc/function_importer.h" #include "../../dialects/torch/importer/jit_ir/csrc/function_importer.h"
@ -108,6 +109,10 @@ void TorchMlirLoweringContext::AddParameter(
ComputationPtr TorchMlirLoweringContext::Build() { ComputationPtr TorchMlirLoweringContext::Build() {
PRINT_FUNCTION(); PRINT_FUNCTION();
// Since we mutated the types of some nodes to insert shape information, we
// must perform this pass to ensure tuples have up to date output types.
torch::jit::RefineTupleTypes(graph_);
// Insert return values into graph. // Insert return values into graph.
for (torch::jit::Value* output : root_tuple_) { for (torch::jit::Value* output : root_tuple_) {
graph_->block()->registerOutput(output); graph_->block()->registerOutput(output);

View File

@ -144,7 +144,7 @@ torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
if (!sizes.rank()) { if (!sizes.rank()) {
// Unranked. // Unranked.
return getMlirTensorType(context, return getMlirTensorType(context,
/*numSizes=*/0, /*numSizes=*/-1,
/*optionalSizes=*/nullptr, /*optionalSizes=*/nullptr,
/*optionalDtype=*/ /*optionalDtype=*/
elementType); elementType);