diff --git a/projects/pt1/python/torch_mlir/torchscript.py b/projects/pt1/python/torch_mlir/torchscript.py index f164e9384..585fa94d0 100644 --- a/projects/pt1/python/torch_mlir/torchscript.py +++ b/projects/pt1/python/torch_mlir/torchscript.py @@ -21,59 +21,12 @@ from torch_mlir.compiler_utils import ( run_pipeline_with_repro_report, OutputType, lower_mlir_module, + TensorPlaceholder, ) from torch_mlir.jit_ir_importer import ClassAnnotator, ImportOptions, ModuleBuilder from torch_mlir.jit_ir_importer.build_tools.library_generator import generate_library -class TensorPlaceholder: - """A class that represents a formal parameter of a given shape and dtype. - - This class can be constructed explicitly from a shape and dtype: - ```python - placeholder = TensorPlaceholder([3, 4], torch.float32) - ``` - - This class can also be constructed from a `torch.Tensor` which is already - known to be a valid input to the function. In this case, a set of - dynamic axes are allowed to be specified. - ```python - placeholder = TensorPlaceholder.like(torch.ones(3, 4), dynamic_axes=[1]) - # Equivalent to `TensorPlaceholder([3, -1], torch.float32)` - ``` - """ - - def __init__(self, shape: List[int], dtype: torch.dtype): - """Create a tensor with shape `shape` and dtype `dtype`. - - Args: - shape: The shape of the tensor. A size of `-1` indicates that the - dimension has an unknown size. - dtype: The dtype of the tensor. - """ - self.shape = shape - self.dtype = dtype - - @staticmethod - def like(tensor: torch.Tensor, dynamic_axes: List[int] = None): - """Create a tensor placeholder that is like the given tensor. - - Args: - tensor: The tensor to create a placeholder for. - dynamic_axes: A list of dynamic axes. If specified, the compiled - module will allow those axes to be any size at runtime. - """ - if dynamic_axes is None: - dynamic_axes = [] - shape = [] - for i, dim in enumerate(tensor.shape): - if i in dynamic_axes: - shape.append(-1) - else: - shape.append(dim) - return TensorPlaceholder(shape, tensor.dtype) - - _example_arg = Union[TensorPlaceholder, torch.Tensor] _example_args_for_one_method = Union[_example_arg, Sequence[_example_arg]] _example_args = Union[_example_args_for_one_method, "ExampleArgs"] diff --git a/projects/pt1/python/torch_mlir_e2e_test/utils.py b/projects/pt1/python/torch_mlir_e2e_test/utils.py index dd9f8d8f8..0ab47efa9 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/utils.py +++ b/projects/pt1/python/torch_mlir_e2e_test/utils.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. -from torch_mlir.torchscript import TensorPlaceholder +from torch_mlir.compiler_utils import TensorPlaceholder from torch_mlir_e2e_test.annotations import TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME diff --git a/python/torch_mlir/compiler_utils.py b/python/torch_mlir/compiler_utils.py index cb2799f85..ecf129d72 100644 --- a/python/torch_mlir/compiler_utils.py +++ b/python/torch_mlir/compiler_utils.py @@ -7,12 +7,61 @@ from io import StringIO import os import sys import tempfile -from typing import Union +from typing import Union, List +import torch from torch_mlir.passmanager import PassManager from torch_mlir.ir import StringAttr +class TensorPlaceholder: + """A class that represents a formal parameter of a given shape and dtype. + + This class can be constructed explicitly from a shape and dtype: + ```python + placeholder = TensorPlaceholder([3, 4], torch.float32) + ``` + + This class can also be constructed from a `torch.Tensor` which is already + known to be a valid input to the function. In this case, a set of + dynamic axes are allowed to be specified. + ```python + placeholder = TensorPlaceholder.like(torch.ones(3, 4), dynamic_axes=[1]) + # Equivalent to `TensorPlaceholder([3, -1], torch.float32)` + ``` + """ + + def __init__(self, shape: List[int], dtype: torch.dtype): + """Create a tensor with shape `shape` and dtype `dtype`. + + Args: + shape: The shape of the tensor. A size of `-1` indicates that the + dimension has an unknown size. + dtype: The dtype of the tensor. + """ + self.shape = shape + self.dtype = dtype + + @staticmethod + def like(tensor: torch.Tensor, dynamic_axes: List[int] = None): + """Create a tensor placeholder that is like the given tensor. + + Args: + tensor: The tensor to create a placeholder for. + dynamic_axes: A list of dynamic axes. If specified, the compiled + module will allow those axes to be any size at runtime. + """ + if dynamic_axes is None: + dynamic_axes = [] + shape = [] + for i, dim in enumerate(tensor.shape): + if i in dynamic_axes: + shape.append(-1) + else: + shape.append(dim) + return TensorPlaceholder(shape, tensor.dtype) + + def get_module_name_for_debug_dump(module): """Gets a name suitable for a debug dump.