mirror of https://github.com/llvm/torch-mlir
Add static shape for scalar tensors (#833)
* Assume zero rank tensors are scalar * Run RefineTypes pass on JIT Graph * Rollback assumption that zero rank tensors are scalar * Set numSizes to -1 for non-ranked tensors * Rename RefineTypes to RefineTupleTypespull/1125/head
parent
de5b380143
commit
0c35e607b3
|
@ -164,6 +164,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNonValueTensor(MlirType t);
|
|||
|
||||
/// Gets a !torch.tensor type.
|
||||
///
|
||||
/// - `numSizes` having a value of -1 denotes an unranked tensor.
|
||||
/// - `optionalSizes` is allowed to be null, meaning that no size
|
||||
/// information is present (and `numSizes` is ignored in that case). -
|
||||
/// `optionalDtype` is allowed to be null, meaning that no dtype
|
||||
|
@ -190,6 +191,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchValueTensor(MlirType t);
|
|||
|
||||
/// Gets a !torch.vtensor type.
|
||||
///
|
||||
/// - `numSizes` having a value of -1 denotes an unranked tensor.
|
||||
/// - `optionalSizes` is allowed to be null, meaning that no size
|
||||
/// information is present (and `numSizes` is ignored in that case).
|
||||
/// - `optionalDtype` is allowed to be null, meaning that no dtype
|
||||
|
|
|
@ -199,7 +199,8 @@ MlirType torchMlirTorchNonValueTensorTypeGet(MlirContext context,
|
|||
const int64_t *optionalSizes,
|
||||
MlirType optionalDtype) {
|
||||
Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
|
||||
if (optionalSizes)
|
||||
// if numSizes == -1, then it is unranked.
|
||||
if (numSizes > -1)
|
||||
optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
|
||||
return wrap(Torch::NonValueTensorType::get(
|
||||
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
|
||||
|
@ -231,7 +232,8 @@ MlirType torchMlirTorchValueTensorTypeGet(MlirContext context,
|
|||
const int64_t *optionalSizes,
|
||||
MlirType optionalDtype) {
|
||||
Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
|
||||
if (optionalSizes)
|
||||
// if numSizes == -1, then it is unranked.
|
||||
if (numSizes > -1)
|
||||
optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
|
||||
return wrap(Torch::ValueTensorType::get(
|
||||
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
#include <iostream>
|
||||
|
||||
#include <torch/csrc/jit/api/compilation_unit.h>
|
||||
#include <torch/csrc/jit/passes/refine_tuple_types.h>
|
||||
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
|
||||
|
||||
#include "../../dialects/torch/importer/jit_ir/csrc/function_importer.h"
|
||||
|
@ -108,6 +109,10 @@ void TorchMlirLoweringContext::AddParameter(
|
|||
ComputationPtr TorchMlirLoweringContext::Build() {
|
||||
PRINT_FUNCTION();
|
||||
|
||||
// Since we mutated the types of some nodes to insert shape information, we
|
||||
// must perform this pass to ensure tuples have up to date output types.
|
||||
torch::jit::RefineTupleTypes(graph_);
|
||||
|
||||
// Insert return values into graph.
|
||||
for (torch::jit::Value* output : root_tuple_) {
|
||||
graph_->block()->registerOutput(output);
|
||||
|
|
|
@ -144,7 +144,7 @@ torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
|
|||
if (!sizes.rank()) {
|
||||
// Unranked.
|
||||
return getMlirTensorType(context,
|
||||
/*numSizes=*/0,
|
||||
/*numSizes=*/-1,
|
||||
/*optionalSizes=*/nullptr,
|
||||
/*optionalDtype=*/
|
||||
elementType);
|
||||
|
|
Loading…
Reference in New Issue