mirror of https://github.com/llvm/torch-mlir
[Torch] Extract TensorPlaceholder to a common interface (#3668)
parent
eb7bf78a9c
commit
6eba5bc9ee
|
@ -21,59 +21,12 @@ from torch_mlir.compiler_utils import (
|
||||||
run_pipeline_with_repro_report,
|
run_pipeline_with_repro_report,
|
||||||
OutputType,
|
OutputType,
|
||||||
lower_mlir_module,
|
lower_mlir_module,
|
||||||
|
TensorPlaceholder,
|
||||||
)
|
)
|
||||||
from torch_mlir.jit_ir_importer import ClassAnnotator, ImportOptions, ModuleBuilder
|
from torch_mlir.jit_ir_importer import ClassAnnotator, ImportOptions, ModuleBuilder
|
||||||
from torch_mlir.jit_ir_importer.build_tools.library_generator import generate_library
|
from torch_mlir.jit_ir_importer.build_tools.library_generator import generate_library
|
||||||
|
|
||||||
|
|
||||||
class TensorPlaceholder:
|
|
||||||
"""A class that represents a formal parameter of a given shape and dtype.
|
|
||||||
|
|
||||||
This class can be constructed explicitly from a shape and dtype:
|
|
||||||
```python
|
|
||||||
placeholder = TensorPlaceholder([3, 4], torch.float32)
|
|
||||||
```
|
|
||||||
|
|
||||||
This class can also be constructed from a `torch.Tensor` which is already
|
|
||||||
known to be a valid input to the function. In this case, a set of
|
|
||||||
dynamic axes are allowed to be specified.
|
|
||||||
```python
|
|
||||||
placeholder = TensorPlaceholder.like(torch.ones(3, 4), dynamic_axes=[1])
|
|
||||||
# Equivalent to `TensorPlaceholder([3, -1], torch.float32)`
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, shape: List[int], dtype: torch.dtype):
|
|
||||||
"""Create a tensor with shape `shape` and dtype `dtype`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
shape: The shape of the tensor. A size of `-1` indicates that the
|
|
||||||
dimension has an unknown size.
|
|
||||||
dtype: The dtype of the tensor.
|
|
||||||
"""
|
|
||||||
self.shape = shape
|
|
||||||
self.dtype = dtype
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def like(tensor: torch.Tensor, dynamic_axes: List[int] = None):
|
|
||||||
"""Create a tensor placeholder that is like the given tensor.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tensor: The tensor to create a placeholder for.
|
|
||||||
dynamic_axes: A list of dynamic axes. If specified, the compiled
|
|
||||||
module will allow those axes to be any size at runtime.
|
|
||||||
"""
|
|
||||||
if dynamic_axes is None:
|
|
||||||
dynamic_axes = []
|
|
||||||
shape = []
|
|
||||||
for i, dim in enumerate(tensor.shape):
|
|
||||||
if i in dynamic_axes:
|
|
||||||
shape.append(-1)
|
|
||||||
else:
|
|
||||||
shape.append(dim)
|
|
||||||
return TensorPlaceholder(shape, tensor.dtype)
|
|
||||||
|
|
||||||
|
|
||||||
_example_arg = Union[TensorPlaceholder, torch.Tensor]
|
_example_arg = Union[TensorPlaceholder, torch.Tensor]
|
||||||
_example_args_for_one_method = Union[_example_arg, Sequence[_example_arg]]
|
_example_args_for_one_method = Union[_example_arg, Sequence[_example_arg]]
|
||||||
_example_args = Union[_example_args_for_one_method, "ExampleArgs"]
|
_example_args = Union[_example_args_for_one_method, "ExampleArgs"]
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
# Also available under a BSD-style license. See LICENSE.
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
from torch_mlir.torchscript import TensorPlaceholder
|
from torch_mlir.compiler_utils import TensorPlaceholder
|
||||||
from torch_mlir_e2e_test.annotations import TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME
|
from torch_mlir_e2e_test.annotations import TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -7,12 +7,61 @@ from io import StringIO
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
from typing import Union
|
from typing import Union, List
|
||||||
|
|
||||||
|
import torch
|
||||||
from torch_mlir.passmanager import PassManager
|
from torch_mlir.passmanager import PassManager
|
||||||
from torch_mlir.ir import StringAttr
|
from torch_mlir.ir import StringAttr
|
||||||
|
|
||||||
|
|
||||||
|
class TensorPlaceholder:
|
||||||
|
"""A class that represents a formal parameter of a given shape and dtype.
|
||||||
|
|
||||||
|
This class can be constructed explicitly from a shape and dtype:
|
||||||
|
```python
|
||||||
|
placeholder = TensorPlaceholder([3, 4], torch.float32)
|
||||||
|
```
|
||||||
|
|
||||||
|
This class can also be constructed from a `torch.Tensor` which is already
|
||||||
|
known to be a valid input to the function. In this case, a set of
|
||||||
|
dynamic axes are allowed to be specified.
|
||||||
|
```python
|
||||||
|
placeholder = TensorPlaceholder.like(torch.ones(3, 4), dynamic_axes=[1])
|
||||||
|
# Equivalent to `TensorPlaceholder([3, -1], torch.float32)`
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, shape: List[int], dtype: torch.dtype):
|
||||||
|
"""Create a tensor with shape `shape` and dtype `dtype`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shape: The shape of the tensor. A size of `-1` indicates that the
|
||||||
|
dimension has an unknown size.
|
||||||
|
dtype: The dtype of the tensor.
|
||||||
|
"""
|
||||||
|
self.shape = shape
|
||||||
|
self.dtype = dtype
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def like(tensor: torch.Tensor, dynamic_axes: List[int] = None):
|
||||||
|
"""Create a tensor placeholder that is like the given tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor: The tensor to create a placeholder for.
|
||||||
|
dynamic_axes: A list of dynamic axes. If specified, the compiled
|
||||||
|
module will allow those axes to be any size at runtime.
|
||||||
|
"""
|
||||||
|
if dynamic_axes is None:
|
||||||
|
dynamic_axes = []
|
||||||
|
shape = []
|
||||||
|
for i, dim in enumerate(tensor.shape):
|
||||||
|
if i in dynamic_axes:
|
||||||
|
shape.append(-1)
|
||||||
|
else:
|
||||||
|
shape.append(dim)
|
||||||
|
return TensorPlaceholder(shape, tensor.dtype)
|
||||||
|
|
||||||
|
|
||||||
def get_module_name_for_debug_dump(module):
|
def get_module_name_for_debug_dump(module):
|
||||||
"""Gets a name suitable for a debug dump.
|
"""Gets a name suitable for a debug dump.
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue