mirror of https://github.com/llvm/torch-mlir
Add static shape for scalar tensors (#833)
* Assume zero rank tensors are scalar * Run RefineTypes pass on JIT Graph * Rollback assumption that zero rank tensors are scalar * Set numSizes to -1 for non-ranked tensors * Rename RefineTypes to RefineTupleTypespull/1125/head
parent
de5b380143
commit
0c35e607b3
|
@ -164,6 +164,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNonValueTensor(MlirType t);
|
||||||
|
|
||||||
/// Gets a !torch.tensor type.
|
/// Gets a !torch.tensor type.
|
||||||
///
|
///
|
||||||
|
/// - `numSizes` having a value of -1 denotes an unranked tensor.
|
||||||
/// - `optionalSizes` is allowed to be null, meaning that no size
|
/// - `optionalSizes` is allowed to be null, meaning that no size
|
||||||
/// information is present (and `numSizes` is ignored in that case). -
|
/// information is present (and `numSizes` is ignored in that case). -
|
||||||
/// `optionalDtype` is allowed to be null, meaning that no dtype
|
/// `optionalDtype` is allowed to be null, meaning that no dtype
|
||||||
|
@ -190,6 +191,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchValueTensor(MlirType t);
|
||||||
|
|
||||||
/// Gets a !torch.vtensor type.
|
/// Gets a !torch.vtensor type.
|
||||||
///
|
///
|
||||||
|
/// - `numSizes` having a value of -1 denotes an unranked tensor.
|
||||||
/// - `optionalSizes` is allowed to be null, meaning that no size
|
/// - `optionalSizes` is allowed to be null, meaning that no size
|
||||||
/// information is present (and `numSizes` is ignored in that case).
|
/// information is present (and `numSizes` is ignored in that case).
|
||||||
/// - `optionalDtype` is allowed to be null, meaning that no dtype
|
/// - `optionalDtype` is allowed to be null, meaning that no dtype
|
||||||
|
|
|
@ -199,7 +199,8 @@ MlirType torchMlirTorchNonValueTensorTypeGet(MlirContext context,
|
||||||
const int64_t *optionalSizes,
|
const int64_t *optionalSizes,
|
||||||
MlirType optionalDtype) {
|
MlirType optionalDtype) {
|
||||||
Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
|
Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
|
||||||
if (optionalSizes)
|
// if numSizes == -1, then it is unranked.
|
||||||
|
if (numSizes > -1)
|
||||||
optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
|
optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
|
||||||
return wrap(Torch::NonValueTensorType::get(
|
return wrap(Torch::NonValueTensorType::get(
|
||||||
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
|
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
|
||||||
|
@ -231,7 +232,8 @@ MlirType torchMlirTorchValueTensorTypeGet(MlirContext context,
|
||||||
const int64_t *optionalSizes,
|
const int64_t *optionalSizes,
|
||||||
MlirType optionalDtype) {
|
MlirType optionalDtype) {
|
||||||
Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
|
Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
|
||||||
if (optionalSizes)
|
// if numSizes == -1, then it is unranked.
|
||||||
|
if (numSizes > -1)
|
||||||
optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
|
optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
|
||||||
return wrap(Torch::ValueTensorType::get(
|
return wrap(Torch::ValueTensorType::get(
|
||||||
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
|
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
||||||
#include <torch/csrc/jit/api/compilation_unit.h>
|
#include <torch/csrc/jit/api/compilation_unit.h>
|
||||||
|
#include <torch/csrc/jit/passes/refine_tuple_types.h>
|
||||||
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
|
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
|
||||||
|
|
||||||
#include "../../dialects/torch/importer/jit_ir/csrc/function_importer.h"
|
#include "../../dialects/torch/importer/jit_ir/csrc/function_importer.h"
|
||||||
|
@ -108,6 +109,10 @@ void TorchMlirLoweringContext::AddParameter(
|
||||||
ComputationPtr TorchMlirLoweringContext::Build() {
|
ComputationPtr TorchMlirLoweringContext::Build() {
|
||||||
PRINT_FUNCTION();
|
PRINT_FUNCTION();
|
||||||
|
|
||||||
|
// Since we mutated the types of some nodes to insert shape information, we
|
||||||
|
// must perform this pass to ensure tuples have up to date output types.
|
||||||
|
torch::jit::RefineTupleTypes(graph_);
|
||||||
|
|
||||||
// Insert return values into graph.
|
// Insert return values into graph.
|
||||||
for (torch::jit::Value* output : root_tuple_) {
|
for (torch::jit::Value* output : root_tuple_) {
|
||||||
graph_->block()->registerOutput(output);
|
graph_->block()->registerOutput(output);
|
||||||
|
|
|
@ -144,7 +144,7 @@ torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
|
||||||
if (!sizes.rank()) {
|
if (!sizes.rank()) {
|
||||||
// Unranked.
|
// Unranked.
|
||||||
return getMlirTensorType(context,
|
return getMlirTensorType(context,
|
||||||
/*numSizes=*/0,
|
/*numSizes=*/-1,
|
||||||
/*optionalSizes=*/nullptr,
|
/*optionalSizes=*/nullptr,
|
||||||
/*optionalDtype=*/
|
/*optionalDtype=*/
|
||||||
elementType);
|
elementType);
|
||||||
|
|
Loading…
Reference in New Issue