[Torch] Extract TensorPlaceholder to a common interface (#3668)

pull/3174/merge
penguin_wwy 2024-08-27 23:31:28 +08:00 committed by GitHub
parent eb7bf78a9c
commit 6eba5bc9ee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 52 additions and 50 deletions

View File

@ -21,59 +21,12 @@ from torch_mlir.compiler_utils import (
run_pipeline_with_repro_report,
OutputType,
lower_mlir_module,
TensorPlaceholder,
)
from torch_mlir.jit_ir_importer import ClassAnnotator, ImportOptions, ModuleBuilder
from torch_mlir.jit_ir_importer.build_tools.library_generator import generate_library
class TensorPlaceholder:
"""A class that represents a formal parameter of a given shape and dtype.
This class can be constructed explicitly from a shape and dtype:
```python
placeholder = TensorPlaceholder([3, 4], torch.float32)
```
This class can also be constructed from a `torch.Tensor` which is already
known to be a valid input to the function. In this case, a set of
dynamic axes are allowed to be specified.
```python
placeholder = TensorPlaceholder.like(torch.ones(3, 4), dynamic_axes=[1])
# Equivalent to `TensorPlaceholder([3, -1], torch.float32)`
```
"""
def __init__(self, shape: List[int], dtype: torch.dtype):
"""Create a tensor with shape `shape` and dtype `dtype`.
Args:
shape: The shape of the tensor. A size of `-1` indicates that the
dimension has an unknown size.
dtype: The dtype of the tensor.
"""
self.shape = shape
self.dtype = dtype
@staticmethod
def like(tensor: torch.Tensor, dynamic_axes: List[int] = None):
"""Create a tensor placeholder that is like the given tensor.
Args:
tensor: The tensor to create a placeholder for.
dynamic_axes: A list of dynamic axes. If specified, the compiled
module will allow those axes to be any size at runtime.
"""
if dynamic_axes is None:
dynamic_axes = []
shape = []
for i, dim in enumerate(tensor.shape):
if i in dynamic_axes:
shape.append(-1)
else:
shape.append(dim)
return TensorPlaceholder(shape, tensor.dtype)
_example_arg = Union[TensorPlaceholder, torch.Tensor]
_example_args_for_one_method = Union[_example_arg, Sequence[_example_arg]]
_example_args = Union[_example_args_for_one_method, "ExampleArgs"]

View File

@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
from torch_mlir.torchscript import TensorPlaceholder
from torch_mlir.compiler_utils import TensorPlaceholder
from torch_mlir_e2e_test.annotations import TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME

View File

@ -7,12 +7,61 @@ from io import StringIO
import os
import sys
import tempfile
from typing import Union
from typing import Union, List
import torch
from torch_mlir.passmanager import PassManager
from torch_mlir.ir import StringAttr
class TensorPlaceholder:
"""A class that represents a formal parameter of a given shape and dtype.
This class can be constructed explicitly from a shape and dtype:
```python
placeholder = TensorPlaceholder([3, 4], torch.float32)
```
This class can also be constructed from a `torch.Tensor` which is already
known to be a valid input to the function. In this case, a set of
dynamic axes are allowed to be specified.
```python
placeholder = TensorPlaceholder.like(torch.ones(3, 4), dynamic_axes=[1])
# Equivalent to `TensorPlaceholder([3, -1], torch.float32)`
```
"""
def __init__(self, shape: List[int], dtype: torch.dtype):
"""Create a tensor with shape `shape` and dtype `dtype`.
Args:
shape: The shape of the tensor. A size of `-1` indicates that the
dimension has an unknown size.
dtype: The dtype of the tensor.
"""
self.shape = shape
self.dtype = dtype
@staticmethod
def like(tensor: torch.Tensor, dynamic_axes: List[int] = None):
"""Create a tensor placeholder that is like the given tensor.
Args:
tensor: The tensor to create a placeholder for.
dynamic_axes: A list of dynamic axes. If specified, the compiled
module will allow those axes to be any size at runtime.
"""
if dynamic_axes is None:
dynamic_axes = []
shape = []
for i, dim in enumerate(tensor.shape):
if i in dynamic_axes:
shape.append(-1)
else:
shape.append(dim)
return TensorPlaceholder(shape, tensor.dtype)
def get_module_name_for_debug_dump(module):
"""Gets a name suitable for a debug dump.