mirror of https://github.com/llvm/torch-mlir
[Torch] Extract TensorPlaceholder to a common interface (#3668)
parent
eb7bf78a9c
commit
6eba5bc9ee
|
@ -21,59 +21,12 @@ from torch_mlir.compiler_utils import (
|
|||
run_pipeline_with_repro_report,
|
||||
OutputType,
|
||||
lower_mlir_module,
|
||||
TensorPlaceholder,
|
||||
)
|
||||
from torch_mlir.jit_ir_importer import ClassAnnotator, ImportOptions, ModuleBuilder
|
||||
from torch_mlir.jit_ir_importer.build_tools.library_generator import generate_library
|
||||
|
||||
|
||||
class TensorPlaceholder:
|
||||
"""A class that represents a formal parameter of a given shape and dtype.
|
||||
|
||||
This class can be constructed explicitly from a shape and dtype:
|
||||
```python
|
||||
placeholder = TensorPlaceholder([3, 4], torch.float32)
|
||||
```
|
||||
|
||||
This class can also be constructed from a `torch.Tensor` which is already
|
||||
known to be a valid input to the function. In this case, a set of
|
||||
dynamic axes are allowed to be specified.
|
||||
```python
|
||||
placeholder = TensorPlaceholder.like(torch.ones(3, 4), dynamic_axes=[1])
|
||||
# Equivalent to `TensorPlaceholder([3, -1], torch.float32)`
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, shape: List[int], dtype: torch.dtype):
|
||||
"""Create a tensor with shape `shape` and dtype `dtype`.
|
||||
|
||||
Args:
|
||||
shape: The shape of the tensor. A size of `-1` indicates that the
|
||||
dimension has an unknown size.
|
||||
dtype: The dtype of the tensor.
|
||||
"""
|
||||
self.shape = shape
|
||||
self.dtype = dtype
|
||||
|
||||
@staticmethod
|
||||
def like(tensor: torch.Tensor, dynamic_axes: List[int] = None):
|
||||
"""Create a tensor placeholder that is like the given tensor.
|
||||
|
||||
Args:
|
||||
tensor: The tensor to create a placeholder for.
|
||||
dynamic_axes: A list of dynamic axes. If specified, the compiled
|
||||
module will allow those axes to be any size at runtime.
|
||||
"""
|
||||
if dynamic_axes is None:
|
||||
dynamic_axes = []
|
||||
shape = []
|
||||
for i, dim in enumerate(tensor.shape):
|
||||
if i in dynamic_axes:
|
||||
shape.append(-1)
|
||||
else:
|
||||
shape.append(dim)
|
||||
return TensorPlaceholder(shape, tensor.dtype)
|
||||
|
||||
|
||||
_example_arg = Union[TensorPlaceholder, torch.Tensor]
|
||||
_example_args_for_one_method = Union[_example_arg, Sequence[_example_arg]]
|
||||
_example_args = Union[_example_args_for_one_method, "ExampleArgs"]
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
from torch_mlir.torchscript import TensorPlaceholder
|
||||
from torch_mlir.compiler_utils import TensorPlaceholder
|
||||
from torch_mlir_e2e_test.annotations import TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME
|
||||
|
||||
|
||||
|
|
|
@ -7,12 +7,61 @@ from io import StringIO
|
|||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from typing import Union
|
||||
from typing import Union, List
|
||||
|
||||
import torch
|
||||
from torch_mlir.passmanager import PassManager
|
||||
from torch_mlir.ir import StringAttr
|
||||
|
||||
|
||||
class TensorPlaceholder:
|
||||
"""A class that represents a formal parameter of a given shape and dtype.
|
||||
|
||||
This class can be constructed explicitly from a shape and dtype:
|
||||
```python
|
||||
placeholder = TensorPlaceholder([3, 4], torch.float32)
|
||||
```
|
||||
|
||||
This class can also be constructed from a `torch.Tensor` which is already
|
||||
known to be a valid input to the function. In this case, a set of
|
||||
dynamic axes are allowed to be specified.
|
||||
```python
|
||||
placeholder = TensorPlaceholder.like(torch.ones(3, 4), dynamic_axes=[1])
|
||||
# Equivalent to `TensorPlaceholder([3, -1], torch.float32)`
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, shape: List[int], dtype: torch.dtype):
|
||||
"""Create a tensor with shape `shape` and dtype `dtype`.
|
||||
|
||||
Args:
|
||||
shape: The shape of the tensor. A size of `-1` indicates that the
|
||||
dimension has an unknown size.
|
||||
dtype: The dtype of the tensor.
|
||||
"""
|
||||
self.shape = shape
|
||||
self.dtype = dtype
|
||||
|
||||
@staticmethod
|
||||
def like(tensor: torch.Tensor, dynamic_axes: List[int] = None):
|
||||
"""Create a tensor placeholder that is like the given tensor.
|
||||
|
||||
Args:
|
||||
tensor: The tensor to create a placeholder for.
|
||||
dynamic_axes: A list of dynamic axes. If specified, the compiled
|
||||
module will allow those axes to be any size at runtime.
|
||||
"""
|
||||
if dynamic_axes is None:
|
||||
dynamic_axes = []
|
||||
shape = []
|
||||
for i, dim in enumerate(tensor.shape):
|
||||
if i in dynamic_axes:
|
||||
shape.append(-1)
|
||||
else:
|
||||
shape.append(dim)
|
||||
return TensorPlaceholder(shape, tensor.dtype)
|
||||
|
||||
|
||||
def get_module_name_for_debug_dump(module):
|
||||
"""Gets a name suitable for a debug dump.
|
||||
|
||||
|
|
Loading…
Reference in New Issue