[custom op] Generalize shape library logic to work with dtypes (#1594)

* [custom op] Generalize shape library logic to work with dtypes

This commit generalizes the shape library logic, so that dtype rules
for ops can also be expressed using the same mechanism. In other
words, each op can now have a shape function and a dtype function
specified in Python that is imported during lowering to calculate the
shapes and dtypes throught a program. For more information about how
to specify a dtype function, see the updated
`docs/adding_a_shape_and_dtype_function.md`.

For those not familiar with how the shape library works, the file
`docs/calculations_lib.md` provides an overview.
pull/1716/head
Ramiro Leal-Cavazos 2022-12-13 08:25:41 -08:00 committed by GitHub
parent 2acf7da63c
commit a710237437
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
45 changed files with 3739 additions and 2312 deletions

View File

@ -61,14 +61,14 @@ jobs:
echo "PT_RELEASE=${PT_RELEASE}" >> ${GITHUB_ENV} echo "PT_RELEASE=${PT_RELEASE}" >> ${GITHUB_ENV}
echo "PT_HASH_CHANGED=${PT_HASH_CHANGED}" >> ${GITHUB_ENV} echo "PT_HASH_CHANGED=${PT_HASH_CHANGED}" >> ${GITHUB_ENV}
- name: Build and test (in-tree), also update ODS and shape library - name: Build and test (in-tree), also update ODS and abstract interpretation library
if: env.PT_HASH_CHANGED != '0' if: env.PT_HASH_CHANGED != '0'
run: | run: |
cd ${GITHUB_WORKSPACE} cd ${GITHUB_WORKSPACE}
TM_PACKAGES="in-tree" TM_USE_PYTORCH_BINARY="OFF" \ TM_PACKAGES="in-tree" TM_USE_PYTORCH_BINARY="OFF" \
TORCH_MLIR_SRC_PYTORCH_BRANCH="${{ env.PT_HASH }}" \ TORCH_MLIR_SRC_PYTORCH_BRANCH="${{ env.PT_HASH }}" \
TORCH_MLIR_SRC_PYTORCH_RELEASE="${{ env.PT_RELEASE }}" \ TORCH_MLIR_SRC_PYTORCH_RELEASE="${{ env.PT_RELEASE }}" \
TM_UPDATE_ODS_AND_SHAPE_LIB="ON" \ TM_UPDATE_ODS_AND_ABSTRACT_INTERP_LIB="ON" \
./build_tools/python_deploy/build_linux_packages.sh ./build_tools/python_deploy/build_linux_packages.sh
- name: Push changes to main branch - name: Push changes to main branch
@ -79,7 +79,7 @@ jobs:
git config user.name "Roll PyTorch Action" git config user.name "Roll PyTorch Action"
git fetch --recurse-submodules=no git fetch --recurse-submodules=no
git checkout main git checkout main
git add pytorch-hash.txt pytorch-requirements.txt lib/Dialect/Torch/Transforms/ShapeLibrary.cpp include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td git add pytorch-hash.txt pytorch-requirements.txt lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
git diff --cached --exit-code || (git commit -m "update PyTorch version to ${{ env.PT_RELEASE }}" && git push --set-upstream origin main) git diff --cached --exit-code || (git commit -m "update PyTorch version to ${{ env.PT_RELEASE }}" && git push --set-upstream origin main)
- name: Update PyTorch Build Cache (if running on main branch) - name: Update PyTorch Build Cache (if running on main branch)

View File

@ -53,8 +53,8 @@ TM_PACKAGES="${TM_PACKAGES:-torch-mlir}"
TM_USE_PYTORCH_BINARY="${TM_USE_PYTORCH_BINARY:-ON}" TM_USE_PYTORCH_BINARY="${TM_USE_PYTORCH_BINARY:-ON}"
# Skip running tests if you want quick iteration # Skip running tests if you want quick iteration
TM_SKIP_TESTS="${TM_SKIP_TESTS:-OFF}" TM_SKIP_TESTS="${TM_SKIP_TESTS:-OFF}"
# Update ODS and shape library files # Update ODS and abstract interpretation library files
TM_UPDATE_ODS_AND_SHAPE_LIB="${TM_UPDATE_ODS_AND_SHAPE_LIB:-OFF}" TM_UPDATE_ODS_AND_ABSTRACT_INTERP_LIB="${TM_UPDATE_ODS_AND_ABSTRACT_INTERP_LIB:-OFF}"
PKG_VER_FILE="${repo_root}"/torch_mlir_package_version ; [ -f "$PKG_VER_FILE" ] && . "$PKG_VER_FILE" PKG_VER_FILE="${repo_root}"/torch_mlir_package_version ; [ -f "$PKG_VER_FILE" ] && . "$PKG_VER_FILE"
TORCH_MLIR_PYTHON_PACKAGE_VERSION="${TORCH_MLIR_PYTHON_PACKAGE_VERSION:-0.0.1}" TORCH_MLIR_PYTHON_PACKAGE_VERSION="${TORCH_MLIR_PYTHON_PACKAGE_VERSION:-0.0.1}"
@ -119,7 +119,7 @@ function run_on_host() {
-e "TM_PYTHON_VERSIONS=${TM_PYTHON_VERSIONS}" \ -e "TM_PYTHON_VERSIONS=${TM_PYTHON_VERSIONS}" \
-e "TM_PACKAGES=${package}" \ -e "TM_PACKAGES=${package}" \
-e "TM_SKIP_TESTS=${TM_SKIP_TESTS}" \ -e "TM_SKIP_TESTS=${TM_SKIP_TESTS}" \
-e "TM_UPDATE_ODS_AND_SHAPE_LIB=${TM_UPDATE_ODS_AND_SHAPE_LIB}" \ -e "TM_UPDATE_ODS_AND_ABSTRACT_INTERP_LIB=${TM_UPDATE_ODS_AND_ABSTRACT_INTERP_LIB}" \
-e "TM_USE_PYTORCH_BINARY=${TM_USE_PYTORCH_BINARY}" \ -e "TM_USE_PYTORCH_BINARY=${TM_USE_PYTORCH_BINARY}" \
-e "TORCH_MLIR_SRC_PYTORCH_REPO=${TORCH_MLIR_SRC_PYTORCH_REPO}" \ -e "TORCH_MLIR_SRC_PYTORCH_REPO=${TORCH_MLIR_SRC_PYTORCH_REPO}" \
-e "TORCH_MLIR_SRC_PYTORCH_BRANCH=${TORCH_MLIR_SRC_PYTORCH_BRANCH}" \ -e "TORCH_MLIR_SRC_PYTORCH_BRANCH=${TORCH_MLIR_SRC_PYTORCH_BRANCH}" \
@ -164,10 +164,10 @@ function run_in_docker() {
in-tree) in-tree)
setup_venv "$python_version" setup_venv "$python_version"
build_in_tree "$TM_USE_PYTORCH_BINARY" "$python_version" build_in_tree "$TM_USE_PYTORCH_BINARY" "$python_version"
if [ "${TM_UPDATE_ODS_AND_SHAPE_LIB}" == "ON" ]; then if [ "${TM_UPDATE_ODS_AND_ABSTRACT_INTERP_LIB}" == "ON" ]; then
pushd /main_checkout/torch-mlir pushd /main_checkout/torch-mlir
./build_tools/update_torch_ods.sh ./build_tools/update_torch_ods.sh
./build_tools/update_shape_lib.sh ./build_tools/update_abstract_interp_lib.sh
popd popd
fi fi
if [ "${TM_SKIP_TESTS}" == "OFF" ]; then if [ "${TM_SKIP_TESTS}" == "OFF" ]; then
@ -253,8 +253,8 @@ function test_in_tree() {
cd /main_checkout/torch-mlir/ cd /main_checkout/torch-mlir/
export PYTHONPATH="/main_checkout/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir" export PYTHONPATH="/main_checkout/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir"
echo ":::: Check that update_shape_lib.sh has been run" echo ":::: Check that update_abstract_interp_lib.sh has been run"
_check_file_not_changed_by ./build_tools/update_shape_lib.sh lib/Dialect/Torch/Transforms/ShapeLibrary.cpp _check_file_not_changed_by ./build_tools/update_abstract_interp_lib.sh lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
echo ":::: Check that update_torch_ods.sh has been run" echo ":::: Check that update_torch_ods.sh has been run"
_check_file_not_changed_by ./build_tools/update_torch_ods.sh include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td _check_file_not_changed_by ./build_tools/update_torch_ods.sh include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

View File

@ -1,5 +1,6 @@
#!/bin/bash #!/bin/bash
# Updates auto-generated shape library files for the `torch` dialect. # Updates auto-generated abstract interpretation library files for the
# `torch` dialect.
# #
# Environment variables: # Environment variables:
# TORCH_MLIR_EXT_MODULES: comma-separated list of python module names # TORCH_MLIR_EXT_MODULES: comma-separated list of python module names
@ -41,6 +42,6 @@ if [ ! -z ${TORCH_MLIR_EXT_MODULES} ]; then
fi fi
PYTHONPATH="${pypath}" python \ PYTHONPATH="${pypath}" python \
-m torch_mlir.dialects.torch.importer.jit_ir.build_tools.shape_lib_gen \ -m torch_mlir.dialects.torch.importer.jit_ir.build_tools.abstract_interp_lib_gen \
--pytorch_op_extensions=${ext_module:-""} \ --pytorch_op_extensions=${ext_module:-""} \
--torch_transforms_cpp_dir="${torch_transforms_cpp_dir}" --torch_transforms_cpp_dir="${torch_transforms_cpp_dir}"

View File

@ -0,0 +1,130 @@
# Torch-MLIR Abstract Interpretation Library Infrastructure
## Overview
The Torch-MLIR project has an infrastructure for maintaining a library of
calculation functions for different Torch operators, which supply extra
information such as result dtypes and shapes as well as decompositions. These
functions are fully executable specifications of the shape/dtype/decomposition
functions for each operator and can be authored and tested from Python for
convenience. These are then brought into the compiler and can be manipulated /
transformed for various purposes. Additionally, in the case of shape functions,
this effort is synergistic with upstream PyTorch efforts to maintain a library
of shape functions.
The two main use cases are:
- Refinement / inference. The `torch-shape-refinement-pipeline` and
`torch-dtype-refinement-pipeline` pass pipelines orchestrate a series of
passes that use the available information in the program to further refine the
types in the program.
- Error guard insertion for backends (Not Yet Implemented). The executable
functions can include error guards / assertions that abort the program in case
of invalid input (such as a matmul with a mismatching contracting dimension).
## Architecture
Functions are defined as TorchScript-able Python functions in
`python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py`.
The signatures of the functions are systematically derived from Torch JIT
operator registry. Most shape functions are expected to reuse the upstream
helper functions
[`torch/jit/_shape_functions.py`](https://github.com/pytorch/pytorch/blob/279634f384662b7c3a9f8bf7ccc3a6afd2f05657/torch/jit/_shape_functions.py#L1),
and any new shape functions should be added there.
The `build_tools/update_abstract_interp_lib.sh` script invokes
`abstract_interp_lib_gen.py` to generate an MLIR module containing the functions,
which is currently embedded as a string literal in
`lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp`.
The function `StringRef mlir::torch::Torch::getAbstractInterpLibrary()` is
available for use inside the compiler any time that the library is needed.
## Shape and Dtype Refinement Pipeline Architecture
One of the main services that Torch-MLIR provides for backends is to normalize
all Torch frontends into a common form which includes tensor shapes and dtypes
that are as precise as possible. This alleviates the need for backends to solve
this problem themselves. This process of shape and dtype refinement is
accomplished in Torch-MLIR through a pipeline of passes which uses the abstract
interpretation library combined with abstract interpretation of the calculation
functions to calculate shapes and dtypes that are as precise as possible.
The pipeline works as follows:
1. Calculations are reified. The `torch-reify-shape-calculations` and
`torch-reify-dtype-calculations` passes reify (i.e., materializes into the
IR) the functions for each op with a function in the calculation library. To
do this, the passes wrap those ops in a `torch.shape.calculate` or
`torch.dtype.calculate` op, respectively, which has two regions: 1) a body
with the op itself, and 2) the shape or dtype calculation, which calculates
the shapes or dtypes of the tensors yielded by the body.
2. Simplifying the functions and propagating the shapes and dtypes. After the
functions are reified, we then attempt to "optimize hard enough" until the
shapes and dtypes yielded by the calculation regions become obvious in the IR.
Those results are propagated through the IR, which usually reveals more
opportunities for simplification.
a. After reification, the functions are just a loose collection of
functions, which are difficult to analyze. The first step is to inline them.
b. After inlining, the `torch-simplify-shape-calculations` and
`torch-simplify-dtype-calculations` passes are used to simplify the
calculations. These passes bring in a number of targeted canonicalization
patterns and folds, along with a few specific patterns such as unrolling
fixed-trip-count loops and abstractly interpreting list operations (an
example is turning a series of "append" operations into a list
literal). These passes also look at the values yielded by the calculation
regions, and if the resulting shape or dtype can be deduced by looking at the
IR (for example, the shape is the list literal `[1, 2, 3]`), it will refine
the types of the `torch.shape.calculate` and `torch.dtype.calculate`
ops. This usually yields more opportunities for simplification. This process
runs to a fixed-point.
3. Dropping the calculations. Once all the types in the program have been
refined as much as possible, the ops that were originally wrapped in
`torch.shape.calculate` and `torch.dtype.calculate` are unwrapped by the
`torch-drop-abstract-interp-calculations` pass which drops the reified
calculations, leaving behind the shape and dtype refined program.
Inferring precise shapes and dtypes often is needed for correctness by
backends. That said, requiring "optimizing hard enough" for correctness is
usually considered quite brittle in a compiler flow. In this case, the saving
grace is that we are only optimizing the functions, which are authored by
compiler developers (not users), and thus there is some give-and-take in terms
of understanding the optimizable constructs while authoring the functions, or
improving the optimizations to enable easier authoring. Some brittleness is
likely to escape to users, unfortunately, since there will always be situations
where, for example, a statically shaped program allows the shape functions to be
simplified to a greater extent than in a dynamically shaped program (for
example, if the shape function checks "is this dimension of size 1"). We hope
that this is minimal.
## Adding to the abstract interpretation library
See [Adding a Shape and Dtype Function](adding_a_shape_and_dtype_function.md)
for details on how to add a shape and dtype function for an operator.
## Rationale
### Use of full operator signatures
The use of the full operator signature such as
`def atenaddTensor(self: List[int], other: List[int], alpha: float = 1) -> List[int]:`
for defining calculation functions is somewhat verbose and repetitive, especially when
there are multiple identical functions. Upstream uses a map with key-value
pairs like `"aten.add.Tensor": upstream_shape_functions.broadcast`, which is
more compact and less repetitive in some ways (upstream also allows trailing
arguments beyond those accepted by the shape function to be ignored, allowing
further deduplication). The decision to do it the more verbose way in Torch-MLIR
was based on the following goals:
- To make the system very easy to debug and test.
- To make the system maximally consistent between functions that are
implemented with the upstream shape helpers and the ones that are manually
written, which are still a fairly large and non-trivial set.
- To make it as mechanical as possible to add a new function.

View File

@ -1,75 +0,0 @@
# Adding a Shape Function
## Overview
As part of adding support for a Torch operator in Torch-MLIR, it is usually
necessary to define a shape function so that the compiler can infer the shapes
of result tensors for the operator. We use the [shape library](shape_lib.md) for this process.
## Step-by-step guide
We will use the example of adding support for the `torch.aten.tanh` op.
1. First, you need to find the shape function signature for the operator you are
implementing a shape function for. This can be found in
`include/torch-mlir/Dialect/Torch/IR/JITOperatorRegistryDump.txt` generated
by the `build_tools/update_torch_ods.sh` script. That file is the "rosetta
stone" that allows translating between e.g. `torch.aten.tanh`, `AtenTanhOp`,
and the shape function signature
`def atentanh(self: List[int]) -> List[int]:`. Note the use of `` as a
separator since `.` or `::` aren't legal in a Python identifier.
2. Paste the shape function signature into `shape_lib_gen.py` in an appropriate
place (ideally near other functions with a similar shape function). Note that
`shape_lib_gen.py` will check that this signature is verbatim identical with
the one given in `JITOperatorRegistryDump.txt` -- this ensures that the shape
functions don't get outdated if Torch changes an operator signature.
3. Fill in the body of the shape function. Ideally this will just be a call into
a helper function from
[`torch/jit/_shape_functions.py`](https://github.com/pytorch/pytorch/blob/279634f384662b7c3a9f8bf7ccc3a6afd2f05657/torch/jit/_shape_functions.py#L1).
But in general, you will need to write the shape function and test it (see
the comments about "Shape function testing infrastructure" in
`shape_lib_gen.py`). New shape functions should be added upstream following
the example of [this PR](https://github.com/pytorch/pytorch/pull/76889),
though it can be useful to iterate locally in `shape_lib_gen.py` first.
4. Re-run the `build_tools/update_shape_lib.sh` script to update the shape
library. After this step happens, ideally everything "just works" and the
shape is now correctly inferred for the operator.
## When things go wrong
It is possible that the shape refinement pipeline (see
[Shape Refinement Pipeline Architecture](shape_lib.md#shape-refinement-pipeline-architecture))
is not able to infer the shape of a tensor with a given shape function. This
usually means that there is something about the shape function which the
optimizations in `torch-simplify-shape-functions`
(`lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp`) cannot handle.
To debug this, the overall goal is to pinpoint the IR construct that is not
being simplified. This is usually accomplished by a combination of looking at
the Python code for the shape function and the IR dumps. The best IR dump to
look at varies, but frequently the IR dump right before `DropShapeCalculations`
is the most useful, because it has already been simplified as much as possible,
making it is easy to see what is blocking further simplification. Examples of
issues you might see:
- You might find that there is a loop with a non-constant trip count, but based
on your understanding of the shape function, you would expect it to be
simplified to a constant trip count -- you can then look at the trip count
calculation and see if there is a missing fold or canonicalization.
- You might find that there is a list operation that is not currently understood
by the optimizations. You can then teach the optimizations about that
operation.
- You might find that there is an `Optional` value that you would expect to be
resolved to either a concrete value or `None`. You can then look at the calculation that produces the optional value and see what folds or canonicalizations are missing.
See [this video](https://www.youtube.com/watch?v=E5epCJOtrf8) for general
guidance on debugging Torch-MLIR.
As a last resort, you can rewrite the shape function using constructs that
`torch-simplify-shape-functions` can handle (look at other shape functions for
examples, sometimes it requires writing things a little awkwardly).

View File

@ -0,0 +1,99 @@
# Adding Abstract Interpretation Functions
## Overview
As part of adding support for a Torch operator in Torch-MLIR, it is usually
necessary to define a shape and dtype function so that the compiler can infer
the shapes and dtypes of result tensors for the operator. We use the
[abstract interpretation library](abstract_interp_lib.md) for this process.
## Step-by-step guide
We will use the example of adding support for the `torch.aten.tanh` op.
1. First, you need to find the shape and dtype function signatures for
the operator you are implementing a functions for. This can be
found in
`include/torch-mlir/Dialect/Torch/IR/JITOperatorRegistryDump.txt`
generated by the `build_tools/update_torch_ods.sh` script. That
file is the "rosetta stone" that allows translating between
e.g. `torch.aten.tanh`, `AtenTanhOp`, and the shape and dtype
function signatures are:
- `def atentanh〡shape(self: List[int]) -> List[int]:`
- `def atentanh〡dtype(self_rank: int, self_dtype: int) -> int:`
Note the use of `` as a separator since `.` or `::` aren't legal
in a Python identifier.
2. Paste the function signature into `abstract_interp_lib_gen.py` in an
appropriate place (ideally near other functions with a similar
functions). Note that `abstract_interp_lib_gen.py` will check that
these signatures are verbatim identical with the ones given in
`JITOperatorRegistryDump.txt` -- this ensures that the functions
don't get outdated if Torch changes an operator signature.
3. Fill in the body of the function. Ideally this will just be a call
into a helper function from
[`torch/jit/_shape_functions.py`](https://github.com/pytorch/pytorch/blob/279634f384662b7c3a9f8bf7ccc3a6afd2f05657/torch/jit/_shape_functions.py#L1).
But in general, you will need to write the function and test it
(see the comments about "Shape, dtype, and decomposition function
testing infrastructure" in `testing_framework.py`). New shape
functions should be added upstream following the example of [this PR](https://github.com/pytorch/pytorch/pull/76889),
though it can be useful to iterate locally in `abstract_interp_lib_gen.py`
first.
Similarly, dtype functions should ideally just be a call to the helper
`promote_dtypes` defined in `library_generator.py`. However, some ops will
require some extra logic to calculate the right result types. While dtypes
are expressed as `int`s in the arguments of the dtype function, using PyTorch
dtypes, such as `torch.int` and `torch.float32`, in the body of the dtype
function is fully supported. Dtype functions are also expected to be fully
tested.
4. Re-run the `build_tools/update_abstract_interp_lib.sh` script to
update the library. After this step happens, ideally everything
"just works" and the functions are now correctly inferred for the
operator.
## When things go wrong
It is possible that the refinement pipeline (see [Shape and Dtype Refinement Pipeline Architecture](abstract_interp_lib.md#shape-and-dtype-refinement-pipeline-architecture))
is not able to infer the shape or dtype of a tensor with a given
abstract interpretation function. This usually means that there is something
about the function which the optimizations in
`torch-simplify-shape-functions` and `torch-simplify-dtype-functions`
(`lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp`,
`lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp`)
cannot handle.
To debug this, the overall goal is to pinpoint the IR construct that is not
being simplified. This is usually accomplished by a combination of looking at
the Python code for the function and the IR dumps. The best IR dump to look at
varies, but frequently the IR dump right before `DropAbstractInterpCalculations`
is the most useful, because it has already been simplified as much as possible,
making it is easy to see what is blocking further simplification. Examples of
issues you might see:
- You might find that there is a loop with a non-constant trip count,
but based on your understanding of the function, you would expect it
to be simplified to a constant trip count -- you can then look at
the trip count calculation and see if there is a missing fold or
canonicalization.
- You might find that there is a list operation that is not currently understood
by the optimizations. You can then teach the optimizations about that
operation.
- You might find that there is an `Optional` value that you would
expect to be resolved to either a concrete value or `None`. You can
then look at the calculation that produces the optional value and
see what folds or canonicalizations are missing.
See [this video](https://www.youtube.com/watch?v=E5epCJOtrf8) for general
guidance on debugging Torch-MLIR.
As a last resort, you can rewrite the function using constructs that
`torch-simplify-shape-functions` and `torch-simplify-dtype-functions` can handle
(look at other functions for examples, sometimes it requires writing things a
little awkwardly).

View File

@ -446,10 +446,12 @@ LLVM updates in other projects like ONNX-MLIR and MLIR-HLO.
working on). If these fixes are too complex, please file a work-in-progress working on). If these fixes are too complex, please file a work-in-progress
PR explaining the issues you are running into asking for help so that someone PR explaining the issues you are running into asking for help so that someone
from the community can help. from the community can help.
5. **Update Shape Library**: Run `build_tools/update_shape_lib.sh`. This is 5. **Update Abstract Interpretation Library**: Run
sometimes needed because upstream changes can affect canonicalization and `build_tools/update_abstract_interp_lib.sh`. This is sometimes needed
other minor details of the IR in the shape library. See because upstream changes can affect canonicalization and other minor details
[docs/shape_lib.md](docs/shape_lib.md) for more details on the shape library. of the IR in the shape library. See
[docs/abstract_interp_lib.md](docs/abstract_interp_lib.md) for more details
on the abstract interpretation library.
Here are some examples of PRs updating the LLVM and MLIR-HLO submodules: Here are some examples of PRs updating the LLVM and MLIR-HLO submodules:

View File

@ -1,121 +0,0 @@
# Torch-MLIR Shape Library Infrastructure
## Overview
The Torch-MLIR project has an infrastructure for maintaining a library of shape
functions for different Torch operators. These shape functions are fully
executable specifications of the shape functions for each operator and can be
authored and tested from Python for convenience. These are then brought into the
compiler and can be manipulated / transformed for various purposes.
Additionally, this effort is synergistic with upstream PyTorch efforts to
maintain a library of shape functions.
The two main use cases are:
- Shape refinement / shape inference. The `torch-shape-refinement-pipeline` pass
pipeline orchestrates a series of passes that use the available shape information in the program to further refine the types in the program.
- Error guard insertion for backends (Not Yet Implemented). The executable shape
functions can include error guards / assertions that abort the program in case
of invalid input (such as a matmul with a mismatching contracting dimension).
## Architecture
Shape functions are defined as TorchScript-able Python functions in
`python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py`.
The signatures of the shape functions are systematically derived from Torch JIT
operator registry (mainly by replacing `Tensor` with `List[int]` in the operator
signatures). Most shape functions are expected to reuse the upstream helper
functions [`torch/jit/_shape_functions.py`](https://github.com/pytorch/pytorch/blob/279634f384662b7c3a9f8bf7ccc3a6afd2f05657/torch/jit/_shape_functions.py#L1),
and any new shape functions should be added there.
The `build_tools/update_shape_lib.sh` script invokes `shape_lib_gen.py` to
generate an MLIR module containing the shape functions, which is currently
embedded as a string literal in `lib/Dialect/Torch/Transforms/ShapeLibrary.cpp`.
The function `StringRef mlir::torch::Torch::getShapeLibrary()` is available for
use inside the compiler any time that the shape library is needed.
## Shape Refinement Pipeline Architecture
One of the main services that Torch-MLIR provides for backends is to normalize
all Torch frontends into a common form which includes tensor shapes that are as
precise as possible. This alleviates the need for backends to solve this problem
themselves. This process of shape refinement is accomplished in Torch-MLIR
through a pipeline of passes which uses the shape library combined with abstract
interpretation of the shape functions to calculate shapes that are as precise as
possible.
The pipeline works as follows:
1. Shape calculations are reified. The `torch-reify-shape-calculations` reifies
(i.e., materializes into the IR) the shape functions for each op with a shape
function in the shape library. To do this, it wraps those ops in a
`torch.shape.calculate` op, which has two regions: 1) a body with the op
itself, and 2) the shape calculation, which calculates the shapes of the
tensors yielded by the body.
2. Simplifying the shape functions and propagating the shapes. After the shape
functions are reified, we then attempt to "optimize hard enough" until the
shapes yielded by the shape calculation regions become obvious in the IR.
Those shapes are propagated through the IR, which usually reveals more
opportunities for simplification.
a. After reification, the shape functions are just a loose collection of
functions, which are difficult to analyze. The first step is to inline them.
b. After inlining, the `torch-simplify-shape-calculations` pass is used to
simplify the shape calculations. This pass brings in a number of targeted
canonicalization patterns and folds, along with a few specific patterns such
as unrolling fixed-trip-count loops and abstractly interpreting list
operations (an example is turning a series of "append" operations into a list
literal). This pass also looks at the values yielded by the shape calculation
regions, and if the resulting shape can be deduced by looking at the IR (for
example, the shape is the list literal `[1, 2, 3]`), it will refine the types
of the `torch.shape.calculate` op. This usually yields more opportunities for
simplification. This process runs to a fixed-point.
3. Dropping the shape calculations. Once all the types in the program have been
refined as much as possible, the ops that were originally wrapped in
`torch.shape.calculate` are unwrapped by the `torch-drop-shape-calculations`
pass which drops the reified shape calculations, leaving behind the shape-refined program.
Inferring precise shape often is needed for correctness by backends. That said,
requiring "optimizing hard enough" for correctness is usually considered quite
brittle in a compiler flow. In this case, the saving grace is that we are only
optimizing the shape functions, which are authored by compiler developers (not
users), and thus there is some give-and-take in terms of understanding the
optimizable constructs while authoring the shape functions, or improving the
optimizations to enable easier authoring. Some brittleness is likely to escape
to users, unfortunately, since there will always be situations where, for
example, a statically shaped program allows the shape functions to be simplified
to a greater extent than in a dynamically shaped program (for example, if the
shape function checks "is this dimension of size 1"). We hope that this is
minimal.
## Adding to the shape library
See [Adding a Shape Function](adding_a_shape_function.md) for details on how to
add a shpae function for an operator.
## Rationale
### Use of full operator signatures
The use of the full operator signature such as
`def atenaddTensor(self: List[int], other: List[int], alpha: float = 1) -> List[int]:`
for defining shape functions is somewhat verbose and repetitive, especially when
there are multiple identical shape functions. Upstream uses a map with key-value
pairs like `"aten.add.Tensor": upstream_shape_functions.broadcast`, which is
more compact and less repetitive in some ways (upstream also allows trailing
arguments beyond those accepted by the shape function to be ignored, allowing
further deduplication). The decision to do it the more verbose way in Torch-MLIR
was based on the following goals:
- To make the system very easy to debug and test.
- To make the system maximally consistent between shape functions that are
implemented with the upstream shape helpers and the ones that are manually
written, which are still a fairly large and non-trivial set.
- To make it as mechanical as possible to add a new shape function.

View File

@ -171,6 +171,42 @@ m_TorchListOfConstantInts(SmallVectorImpl<int64_t> &bind_values) {
return detail::torch_list_of_constant_ints_op_binder(bind_values); return detail::torch_list_of_constant_ints_op_binder(bind_values);
} }
namespace detail {
/// Matches the optional constant integers stored in a `torch.ListConstruct`.
struct torch_list_of_optional_constant_ints_op_binder {
SmallVectorImpl<Optional<int64_t>> &bind_values;
/// Creates a matcher instance that binds the value to bvs if match succeeds.
torch_list_of_optional_constant_ints_op_binder(
SmallVectorImpl<Optional<int64_t>> &bvs)
: bind_values(bvs) {}
bool match(Operation *op) {
auto listConstruct = dyn_cast<Torch::PrimListConstructOp>(op);
if (!listConstruct)
return false;
for (Value value : listConstruct.getElements()) {
int64_t num;
if (matchPattern(value, m_TorchConstantInt(&num)))
bind_values.push_back(num);
else if (value.getType().isa<Torch::NoneType>())
bind_values.push_back(llvm::None);
else
return false;
}
return true;
}
};
} // namespace detail
/// Matches the optional constant integers stored in a
/// `torch.prim.ListConstruct`.
inline detail::torch_list_of_optional_constant_ints_op_binder
m_TorchListOfOptionalConstantInts(
SmallVectorImpl<Optional<int64_t>> &bind_values) {
return detail::torch_list_of_optional_constant_ints_op_binder(bind_values);
}
namespace detail { namespace detail {
/// Matches the constant bools stored in a `torch.ListConstruct`. /// Matches the constant bools stored in a `torch.ListConstruct`.
struct torch_list_of_constant_bools_op_binder { struct torch_list_of_constant_bools_op_binder {

View File

@ -1101,6 +1101,40 @@ def Torch_ValsemVariantAtenBernoulliFloatOp: Torch_Op<"valsem.aten.bernoulli.flo
let assemblyFormat = "$self `,` $p `,` $generator attr-dict `:` type($self) `,` type($p) `,` type($generator) `->` type($result)"; let assemblyFormat = "$self `,` $p `,` $generator attr-dict `:` type($self) `,` type($p) `,` type($generator) `->` type($result)";
} }
def Torch_PromoteDtypesOp: Torch_Op<"promote_dtypes", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "`promote_dtypes op : (int?[], int[]) -> (int)`";
let description = [{
This op is generated when the python function
`__torch_mlir_internal_promote_dtypes` is used in a dtype refinement
function. It represents the type promotion logic used by PyTorch to
determine result types.
The first argument is a list of optional ranks for each of the inputs
being used for promotion. The ranks are optional to allow representing
`Scalar` inputs, which follow their own set of promotion rules.
The second argument is a list of dtypes for each of the inputs being used
for promotion.
The order of the values in each list must be the same. In other words,
the ith rank and the ith dtype must be from the same Scalar/Tensor.
It is an error to call this op with empty lists or lists of different size.
}];
let arguments = (ins
AnyTorchListOfOptionalIntType:$ranks,
AnyTorchListOfTorchIntType:$dtypes
);
let results = (outs
Torch_IntType:$result
);
let assemblyFormat = "$ranks `,` $dtypes attr-dict `:` functional-type(operands, results)";
}
// To handle runtime assertions, torchscript provides us `torch._assert` operation. // To handle runtime assertions, torchscript provides us `torch._assert` operation.
// But TS compiler introduces control flow for `torch._assert` operation. The // But TS compiler introduces control flow for `torch._assert` operation. The
// `torch._assert` would introduce control flow like: // `torch._assert` would introduce control flow like:
@ -1140,31 +1174,31 @@ def Torch_ShapeCalculateOp : Torch_Op<"shape.calculate", [
let summary = "Shape calculation encapsulation op"; let summary = "Shape calculation encapsulation op";
let description = [{ let description = [{
The `torch.shape.calculate` op captures a shape calculation The `torch.shape.calculate` op captures a shape calculation
(in the region `shapeCalculation`) which calculates the shapes for (in the region `calculation`) which calculates the shapes for
the set of values yielded by the `body` region. the set of values yielded by the `body` region.
The `shapeCalculation` region yields a `!torch.list<int>` for each The `calculation` region yields a `!torch.list<int>` for each
value yielded by the `body` region. value yielded by the `body` region.
Conceptually, the `shapeCalculation` region executes first, then `body` Conceptually, the `calculation` region executes first, then `body`
region. So the `shapeCalculation` region can also contain arbitrary region. So the `calculation` region can also contain arbitrary
assertions or side-effecting code which guard the validity of the execution assertions or side-effecting code which guard the validity of the execution
of the body (typically by terminating the program with a of the body (typically by terminating the program with a
torch.prim.RaiseException op). torch.prim.RaiseException op).
The program has undefined behavior if the values yielded by the `body` The program has undefined behavior if the values yielded by the `body`
region do not have the shapes yielded by the `shapeCalculation` region. region do not have the shapes yielded by the `calculation` region.
}]; }];
let arguments = (ins); let arguments = (ins);
let results = (outs Variadic<AnyTorchType>:$results); let results = (outs Variadic<AnyTorchType>:$results);
let regions = (region let regions = (region
SizedRegion<1>:$body, SizedRegion<1>:$body,
SizedRegion<1>:$shapeCalculation SizedRegion<1>:$calculation
); );
let assemblyFormat = [{ let assemblyFormat = [{
$body `shapes` $shapeCalculation attr-dict `:` type($results) $body `shapes` $calculation attr-dict `:` type($results)
}]; }];
} }
@ -1211,4 +1245,84 @@ def Torch_ShapeCalculateYieldShapesOp : Torch_Op<"shape.calculate.yield.shapes",
let hasVerifier = 1; let hasVerifier = 1;
} }
//===----------------------------------------------------------------------===//
// Dtype calculation modeling ops.
//===----------------------------------------------------------------------===//
def Torch_DtypeCalculateOp : Torch_Op<"dtype.calculate", [
DeclareOpInterfaceMethods<RegionBranchOpInterface>]> {
let summary = "Dtype calculation encapsulation op";
let description = [{
The `torch.dtype.calculate` op captures a dtype calculation
(in the region `calculation`) which calculates the dtypes for
the set of values yielded by the `body` region.
The `calculation` region yields a `!torch.int` for each
value yielded by the `body` region.
Conceptually, the `calculation` region executes first, then `body`
region. So the `calculation` region can also contain arbitrary
assertions or side-effecting code which guard the validity of the execution
of the body (typically by terminating the program with a
torch.prim.RaiseException op).
The program has undefined behavior if the values yielded by the `body`
region do not have the dtypes yielded by the `calculation` region.
}];
let arguments = (ins);
let results = (outs Variadic<AnyTorchType>:$results);
let regions = (region
SizedRegion<1>:$body,
SizedRegion<1>:$calculation
);
let assemblyFormat = [{
$body `dtypes` $calculation attr-dict `:` type($results)
}];
}
def Torch_DtypeCalculateYieldOp : Torch_Op<"dtype.calculate.yield", [
Terminator,
ReturnLike,
HasParent<"::mlir::torch::Torch::DtypeCalculateOp">]> {
let summary = "yield-like terminator for torch.dtype.calculate";
let description = [{
This op terminates the `body` region of a `torch.dtype.calculate` op.
}];
let arguments = (ins
Variadic<AnyTorchType>:$results
);
let results = (outs);
let assemblyFormat = [{
attr-dict ($results^ `:` type($results))?
}];
}
def Torch_DtypeCalculateYieldDtypesOp : Torch_Op<"dtype.calculate.yield.dtypes", [
DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>,
Terminator,
HasValueSemantics,
ReadOnly,
HasParent<"::mlir::torch::Torch::DtypeCalculateOp">]> {
let summary = "yield-like terminator for torch.dtype.calculate shape region";
let description = [{
This op terminates the `dtypeCalculation` region of a
`torch.dtype.calculate` op.
}];
let arguments = (ins
Variadic<Torch_IntType>:$results
);
let results = (outs);
let assemblyFormat = [{
attr-dict ($results^ `:` type($results))?
}];
let hasVerifier = 1;
}
#endif // TORCH_OPS #endif // TORCH_OPS

View File

@ -390,6 +390,9 @@ def AnyTorchListOfTensorType:
def AnyTorchListOfOptionalTensorType : def AnyTorchListOfOptionalTensorType :
ListOf<[AnyTorchOptionalTensorType], ListOf<[AnyTorchOptionalTensorType],
"Any optional tensor list type (Tensor?[])">; "Any optional tensor list type (Tensor?[])">;
def AnyTorchListOfOptionalIntType :
ListOf<[AnyTorchOptionalIntType],
"List of optional ints type (int?[])">;
def AnyTorchOptionalListOfTorchIntType : OptionalOf<AnyTorchListOfTorchIntType, "Optional torch int list type (int[]?)">; def AnyTorchOptionalListOfTorchIntType : OptionalOf<AnyTorchListOfTorchIntType, "Optional torch int list type (int[]?)">;
def AnyTorchOptionalListOfTorchFloatType : OptionalOf<AnyTorchListOfTorchFloatType, "Optional torch float list type (float[]?)">; def AnyTorchOptionalListOfTorchFloatType : OptionalOf<AnyTorchListOfTorchFloatType, "Optional torch float list type (float[]?)">;
// Note: TorchScript does not consider !torch.bool to be a Scalar. // Note: TorchScript does not consider !torch.bool to be a Scalar.

View File

@ -80,6 +80,9 @@ void createTorchSimplificationPipeline(
/// Creates a pipeline that refines shapes of tensor operations in the program. /// Creates a pipeline that refines shapes of tensor operations in the program.
void createTorchShapeRefinementPipeline(OpPassManager &pm); void createTorchShapeRefinementPipeline(OpPassManager &pm);
/// Creates a pipeline that refines dtype of tensor operations in the program.
void createTorchDtypeRefinementPipeline(OpPassManager &pm);
std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass(); std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();
std::unique_ptr<OperationPass<func::FuncOp>> createRefineTypesPass(); std::unique_ptr<OperationPass<func::FuncOp>> createRefineTypesPass();
@ -102,7 +105,13 @@ std::unique_ptr<OperationPass<ModuleOp>> createReifyShapeCalculationsPass();
std::unique_ptr<OperationPass<func::FuncOp>> std::unique_ptr<OperationPass<func::FuncOp>>
createSimplifyShapeCalculationsPass(); createSimplifyShapeCalculationsPass();
std::unique_ptr<OperationPass<func::FuncOp>> createDropShapeCalculationsPass(); std::unique_ptr<OperationPass<ModuleOp>> createReifyDtypeCalculationsPass();
std::unique_ptr<OperationPass<func::FuncOp>>
createSimplifyDtypeCalculationsPass();
std::unique_ptr<OperationPass<func::FuncOp>>
createDropAbstractInterpCalculationsPass();
std::unique_ptr<OperationPass<ModuleOp>> std::unique_ptr<OperationPass<ModuleOp>>
createEraseModuleInitializerPass(); createEraseModuleInitializerPass();
@ -113,7 +122,7 @@ createLowerToBackendContractPass(int maxIterations, bool decompose,
std::unique_ptr<OperationPass<ModuleOp>> createVerifyBackendContractPass(); std::unique_ptr<OperationPass<ModuleOp>> createVerifyBackendContractPass();
StringRef getShapeLibrary(); StringRef getAbstractInterpLibrary();
} // namespace Torch } // namespace Torch

View File

@ -239,7 +239,7 @@ def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "func::FuncOp"> {
} }
def ReifyShapeCalculations : Pass<"torch-reify-shape-calculations", "ModuleOp"> { def ReifyShapeCalculations : Pass<"torch-reify-shape-calculations", "ModuleOp"> {
let summary = "Decompose complicated torch operations"; let summary = "Reify shape calculations.";
let constructor = "mlir::torch::Torch::createReifyShapeCalculationsPass()"; let constructor = "mlir::torch::Torch::createReifyShapeCalculationsPass()";
let description = [{ let description = [{
}]; }];
@ -253,9 +253,23 @@ def SimplifyShapeCalculations : Pass<"torch-simplify-shape-calculations", "func:
}]; }];
} }
def DropShapeCalculations : Pass<"torch-drop-shape-calculations", "func::FuncOp"> { def ReifyDtypeCalculations : Pass<"torch-reify-dtype-calculations", "ModuleOp"> {
let summary = "Drop reified shape calculations."; let summary = "Reify dtype calculations.";
let constructor = "mlir::torch::Torch::createDropShapeCalculationsPass()"; let constructor = "mlir::torch::Torch::createReifyDtypeCalculationsPass()";
let description = [{
}];
}
def SimplifyDtypeCalculations : Pass<"torch-simplify-dtype-calculations", "func::FuncOp"> {
let summary = "Simplify reified dtype calculations.";
let constructor = "mlir::torch::Torch::createSimplifyDtypeCalculationsPass()";
let description = [{
}];
}
def DropAbstractInterpCalculations : Pass<"torch-drop-abstract-interp-calculations", "func::FuncOp"> {
let summary = "Drop reified abstract interpretation calculations.";
let constructor = "mlir::torch::Torch::createDropAbstractInterpCalculationsPass()";
let description = [{ let description = [{
}]; }];
} }

View File

@ -11,7 +11,6 @@
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
namespace mlir { namespace mlir {
@ -35,8 +34,23 @@ Type getTypeForTorchType(
MLIRContext *context, Type type, MLIRContext *context, Type type,
mlir::IntegerType::SignednessSemantics signedness = IntegerType::Signed); mlir::IntegerType::SignednessSemantics signedness = IntegerType::Signed);
Type getTorchTypeForScalarType(MLIRContext *context, FailureOr<Type> getTorchTypeForScalarType(MLIRContext *context,
torch_upstream::ScalarType dtypeInt); torch_upstream::ScalarType dtypeInt);
// This is the type rule used for deciding dtype for:
// 1. A new tensor created from given data.
// 2. The scalar type for type promotion when a scalar is an operand of a tensor
// operation (such as AtenMulScalarOp, AtenAddScalarOp etc)
// If the data is floating-point, the `dtype` is inferred to be the
// default dtype, see `torch.get_default_dtype`.
Type getDefaultDtypeForTorchScalar(Type type);
// This is the type rule used for deciding builtin type for:
// 1. The dtype of the result tensor when converting a Scalar into a Tensor like
// PrimNumToTensorScalarOp.
// 2. The scalar type for type promotion when a scalar is an operand of scalar
// only operation like AtenAddOp.
Type getBuiltInTypeForTorchScalar(Type type);
Value getDtypeIntValueForType(PatternRewriter &rewriter, Location loc, Value getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
Type dtype); Type dtype);

View File

@ -2294,24 +2294,40 @@ OpFoldResult PrimMinSelfIntOp::fold(ArrayRef<Attribute> operands) {
// ShapeCalculateOp // ShapeCalculateOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void ShapeCalculateOp::getSuccessorRegions( template <typename CalculateOp>
Optional<unsigned> index, ArrayRef<Attribute> operands, static void
SmallVectorImpl<RegionSuccessor> &regions) { getSuccessorRegionsForCalculateOp(CalculateOp op, Optional<unsigned> index,
(void)operands; ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> &regions) {
if (!index.has_value()) { if (!index.has_value()) {
// First thing the op does is branch into the shape calculation. // First thing the op does is branch into the calculation.
regions.emplace_back(&getShapeCalculation()); regions.emplace_back(&op.getCalculation());
return; return;
} }
if (*index == 0) { if (*index == 0) {
// Body returns control to the outer op, passing through results. // Body returns control to the outer op, passing through results.
regions.emplace_back(getResults()); regions.emplace_back(op.getResults());
return; return;
} }
assert(*index == 1); assert(*index == 1);
// Shape calculation branches to the body. // Calculation branches to the body.
regions.emplace_back(&getBody()); regions.emplace_back(&op.getBody());
}
void ShapeCalculateOp::getSuccessorRegions(
Optional<unsigned> index, ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> &regions) {
getSuccessorRegionsForCalculateOp(*this, index, operands, regions);
}
//===----------------------------------------------------------------------===//
// DtypeCalculateOp
//===----------------------------------------------------------------------===//
void DtypeCalculateOp::getSuccessorRegions(
Optional<unsigned> index, ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> &regions) {
getSuccessorRegionsForCalculateOp(*this, index, operands, regions);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -2333,6 +2349,25 @@ LogicalResult ShapeCalculateYieldShapesOp::verify() {
return success(); return success();
} }
//===----------------------------------------------------------------------===//
// DtypeCalculateYieldDtypesOp
//===----------------------------------------------------------------------===//
MutableOperandRange DtypeCalculateYieldDtypesOp::getMutableSuccessorOperands(
Optional<unsigned> index) {
// The dtype operands don't get forwarded to the body.
// MutableOperandRange always has an owning operation, even if empty, so
// create a 0-length range.
return MutableOperandRange(*this, /*start=*/0, /*length=*/0);
}
LogicalResult DtypeCalculateYieldDtypesOp::verify() {
auto parent = cast<DtypeCalculateOp>(getOperation()->getParentOp());
if (parent.getNumResults() != getNumOperands())
return emitOpError("expected number of dtypes to match number of results");
return success();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// GlobalSlotModuleInitializerOp // GlobalSlotModuleInitializerOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -8,7 +8,7 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// //
// This file is auto-generated! Do not edit!!! // This file is auto-generated! Do not edit!!!
// Generated with the script `build_tools/update_shape_lib.sh`. // Generated with the script `build_tools/update_abstract_interp_lib.sh`.
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -16,7 +16,7 @@
using namespace mlir; using namespace mlir;
StringRef mlir::torch::Torch::getShapeLibrary() { StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
#ifndef _MSC_VER #ifndef _MSC_VER
#pragma clang diagnostic push #pragma clang diagnostic push
#pragma clang diagnostic ignored "-Woverlength-strings" #pragma clang diagnostic ignored "-Woverlength-strings"
@ -5470,6 +5470,25 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.tanh\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %int15 = torch.constant.int 15\n"
" %true = torch.constant.bool true\n"
" %int7 = torch.constant.int 7\n"
" %0 = torch.aten.eq.int %arg1, %int7 : !torch.int, !torch.int -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %3 = torch.aten.eq.int %arg1, %int15 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %3 : !torch.bool\n"
" }\n"
" %2 = torch.prim.If %1 -> (!torch.int) {\n"
" torch.prim.If.yield %arg1 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %int6 : !torch.int\n"
" }\n"
" return %2 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.erf\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.erf\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
@ -5618,11 +5637,11 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.to.dtype\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.prims.convert_element_type\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.prims.convert_element_type\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.to.dtype\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
" }\n" " }\n"
@ -5721,6 +5740,23 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union<float, int>, %arg3: !torch.union<float, int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0 = torch.prim.ListConstruct %arg0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union<float, int>) -> !torch.int\n"
" %2 = torch.prim.ListConstruct %arg1, %1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %2) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %3 : !torch.int\n"
" }\n"
" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%arg0: !torch.list<optional<int>>, %arg1: !torch.list<int>) -> !torch.int {\n"
" %0 = torch.promote_dtypes %arg0, %arg1 : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0: !torch.union<float, int>) -> !torch.int {\n"
" %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.union<float, int> -> !torch.tensor\n"
" %1 = torch.prim.dtype %0 : !torch.tensor -> !torch.int\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.leaky_relu\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.leaky_relu\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
@ -6414,6 +6450,12 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
" %0 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %1 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %2 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.atan2\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.atan2\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
@ -7086,6 +7128,15 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
" %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n" " %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" return %4 : !torch.list<int>\n" " return %4 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.add\"(%arg0: !torch.union<float, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0 = torch.prim.ListConstruct %none, %none : (!torch.none, !torch.none) -> !torch.list<optional<int>>\n"
" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union<float, int>) -> !torch.int\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
" %3 = torch.prim.ListConstruct %1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
"}\n" "}\n"
""; "";
// clang-format on // clang-format on

View File

@ -1,7 +1,7 @@
add_mlir_library(TorchMLIRTorchPasses add_mlir_library(TorchMLIRTorchPasses
AdjustCallingConventions.cpp AdjustCallingConventions.cpp
DecomposeComplexOps.cpp DecomposeComplexOps.cpp
DropShapeCalculations.cpp DropAbstractInterpCalculations.cpp
EraseModuleInitializer.cpp EraseModuleInitializer.cpp
Passes.cpp Passes.cpp
GlobalizeObjectGraph.cpp GlobalizeObjectGraph.cpp
@ -13,8 +13,12 @@ add_mlir_library(TorchMLIRTorchPasses
RefinePublicReturn.cpp RefinePublicReturn.cpp
RefineTypes.cpp RefineTypes.cpp
ReifyShapeCalculations.cpp ReifyShapeCalculations.cpp
ShapeLibrary.cpp ReifyDtypeCalculations.cpp
ReifyAbstractInterpCalculationsUtils.cpp
AbstractInterpLibrary.cpp
SimplifyShapeCalculations.cpp SimplifyShapeCalculations.cpp
SimplifyDtypeCalculations.cpp
SimplifyAbstractInterpCalculationsUtils.cpp
ADDITIONAL_HEADER_DIRS ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/Torch/Transforms ${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/Torch/Transforms

View File

@ -9,29 +9,22 @@
#include "PassDetail.h" #include "PassDetail.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/InliningUtils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
using namespace mlir::torch::Torch; using namespace mlir::torch::Torch;
namespace { namespace {
class DropShapeCalculateOp : public OpConversionPattern<ShapeCalculateOp> { template <typename CalculateOp>
class DropCalculateOp : public OpConversionPattern<CalculateOp> {
public: public:
using OpConversionPattern::OpConversionPattern; using OpConversionPattern<CalculateOp>::OpConversionPattern;
LogicalResult LogicalResult
matchAndRewrite(ShapeCalculateOp op, OpAdaptor adaptor, matchAndRewrite(CalculateOp op, typename CalculateOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
Block *block = &op.getBody().front(); Block *block = &op.getBody().front();
Operation *terminator = block->getTerminator(); Operation *terminator = block->getTerminator();
@ -45,16 +38,18 @@ public:
} // namespace } // namespace
namespace { namespace {
class DropShapeCalculationsPass class DropAbstractInterpCalculationsPass
: public DropShapeCalculationsBase<DropShapeCalculationsPass> { : public DropAbstractInterpCalculationsBase<
DropAbstractInterpCalculationsPass> {
void runOnOperation() override { void runOnOperation() override {
MLIRContext *context = &getContext(); MLIRContext *context = &getContext();
RewritePatternSet patterns(context); RewritePatternSet patterns(context);
patterns.insert<DropShapeCalculateOp>(context); patterns.insert<DropCalculateOp<DtypeCalculateOp>>(context);
patterns.insert<DropCalculateOp<ShapeCalculateOp>>(context);
ConversionTarget target(*context); ConversionTarget target(*context);
target.addLegalDialect<Torch::TorchDialect>(); target.addLegalDialect<Torch::TorchDialect>();
target.addIllegalOp<ShapeCalculateOp>(); target.addIllegalOp<DtypeCalculateOp, ShapeCalculateOp>();
target.addLegalOp<func::FuncOp>(); target.addLegalOp<func::FuncOp>();
if (failed(applyPartialConversion(getOperation(), target, if (failed(applyPartialConversion(getOperation(), target,
@ -66,6 +61,6 @@ class DropShapeCalculationsPass
} // namespace } // namespace
std::unique_ptr<OperationPass<func::FuncOp>> std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createDropShapeCalculationsPass() { mlir::torch::Torch::createDropAbstractInterpCalculationsPass() {
return std::make_unique<DropShapeCalculationsPass>(); return std::make_unique<DropAbstractInterpCalculationsPass>();
} }

View File

@ -90,8 +90,8 @@ static LogicalResult checkType(Operation *op, Type type,
->emitError( ->emitError(
"unsupported by backend contract: tensor with unknown rank") "unsupported by backend contract: tensor with unknown rank")
.attachNote() .attachNote()
.append("this is likely due to a missing shape transfer function " .append("this is likely due to a missing transfer function "
"in shape_lib_gen.py"); "in abstract_interp_lib_gen.py");
} else { } else {
return failure(); return failure();
} }

View File

@ -121,6 +121,7 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
// inference), because Torch type promotion rules actually depend on the shape // inference), because Torch type promotion rules actually depend on the shape
// of the operand. // of the operand.
createTorchShapeRefinementPipeline(pm); createTorchShapeRefinementPipeline(pm);
createTorchDtypeRefinementPipeline(pm);
// Refine types in the program, which mainly means inferring dtypes of ops. // Refine types in the program, which mainly means inferring dtypes of ops.
pm.addNestedPass<func::FuncOp>(Torch::createRefineTypesPass()); pm.addNestedPass<func::FuncOp>(Torch::createRefineTypesPass());
// Propagate to ABI return types the shape/dtype information discovered by // Propagate to ABI return types the shape/dtype information discovered by
@ -137,23 +138,41 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
} }
} }
void mlir::torch::Torch::createTorchShapeRefinementPipeline(OpPassManager &pm) { static void createRefinementPipeline(
// Reify the shape functions for each op that is present in the shape library. mlir::OpPassManager &pm,
pm.addPass(Torch::createReifyShapeCalculationsPass()); llvm::function_ref<std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>()>
reifyCalculationsPass,
llvm::function_ref<
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>()>
simplifyCalculationsPass) {
// Reify the library functions for each op that is present in the library.
pm.addPass(reifyCalculationsPass());
// Inline the shape functions to enable analysis and transformation. // Inline the library functions to enable analysis and transformation.
// TODO: Only inline shape functions (this will currently inline everything). // TODO: Only inline library functions (this will currently inline
pm.addPass(createInlinerPass()); // everything).
pm.addPass(mlir::createInlinerPass());
// Now, try to simplify shape calculations. This is unfortunately a "optimize // Now, try to simplify calculations. This is unfortunately a "optimize
// as hard as possible" kind of thing, so it's inherently somewhat brittle. // as hard as possible" kind of thing, so it's inherently somewhat brittle.
// The idea is to keep strengthening what we do here to support the shape // The idea is to keep strengthening what we do here to support the
// library. We don't need to support arbitrary programs, thankfully. // library functions. We don't need to support arbitrary programs, thankfully.
pm.addNestedPass<func::FuncOp>(Torch::createSimplifyShapeCalculationsPass()); pm.addNestedPass<mlir::func::FuncOp>(simplifyCalculationsPass());
// Run CSE, then see if we can simplify further. // Run CSE, then see if we can simplify further.
pm.addNestedPass<func::FuncOp>(createCSEPass()); pm.addNestedPass<mlir::func::FuncOp>(mlir::createCSEPass());
pm.addNestedPass<func::FuncOp>(Torch::createSimplifyShapeCalculationsPass()); pm.addNestedPass<mlir::func::FuncOp>(simplifyCalculationsPass());
// Drop shape calculations, leaving behind the shape-refined program. // Drop calculations, leaving behind the-refined program.
pm.addNestedPass<func::FuncOp>(Torch::createDropShapeCalculationsPass()); pm.addNestedPass<mlir::func::FuncOp>(
mlir::torch::Torch::createDropAbstractInterpCalculationsPass());
}
void mlir::torch::Torch::createTorchShapeRefinementPipeline(OpPassManager &pm) {
createRefinementPipeline(pm, Torch::createReifyShapeCalculationsPass,
Torch::createSimplifyShapeCalculationsPass);
}
void mlir::torch::Torch::createTorchDtypeRefinementPipeline(OpPassManager &pm) {
createRefinementPipeline(pm, Torch::createReifyDtypeCalculationsPass,
Torch::createSimplifyDtypeCalculationsPass);
} }

View File

@ -491,44 +491,6 @@ private:
}; };
} // namespace } // namespace
// This is the type rule used for deciding dtype for:
// 1. A new tensor created from given data.
// 2. The scalar type for type promotion when a scalar is an operand of a tensor
// operation (such as AtenMulScalarOp, AtenAddScalarOp etc)
// If the data is floating-point, the `dtype` is inferred to be the
// default dtype, see `torch.get_default_dtype`.
static Type getDefaultDtypeForTorchScalar(Type type) {
MLIRContext *context = type.getContext();
if (type.isa<Torch::FloatType>()) {
// For now, use float32 which is the initial default dtype returned by
// `torch.get_default_dtype`.
return Float32Type::get(context);
}
if (type.isa<Torch::IntType>())
return IntegerType::get(context, 64, IntegerType::Signed);
if (type.isa<Torch::BoolType>())
return IntegerType::get(context, 1);
llvm_unreachable(
"getDefaultDtypeForTorchScalar called on an unsupported type");
}
// This is the type rule used for deciding builtin type for:
// 1. The dtype of the result tensor when converting a Scalar into a Tensor like
// PrimNumToTensorScalarOp.
// 2. The scalar type for type promotion when a scalar is an operand of scalar
// only operation like AtenAddOp.
static Type getBuiltInTypeForTorchScalar(Type type) {
MLIRContext *context = type.getContext();
if (type.isa<Torch::FloatType>())
return Float64Type::get(context);
if (type.isa<Torch::IntType>())
return IntegerType::get(context, 64, IntegerType::Signed);
if (type.isa<Torch::BoolType>())
return IntegerType::get(context, 1);
llvm_unreachable(
"getBuiltInTypeForTorchScalar called on an unsupported type");
}
static torch_upstream::ResultTypeState static torch_upstream::ResultTypeState
updateResultTypeState(Type scalarType, updateResultTypeState(Type scalarType,
const torch_upstream::ResultTypeState &inState) { const torch_upstream::ResultTypeState &inState) {
@ -583,8 +545,11 @@ static Type getPromotedResultScalarType(ArrayRef<Type> scalarTypes) {
state = state =
updateResultTypeState(getBuiltInTypeForTorchScalar(scalarType), state); updateResultTypeState(getBuiltInTypeForTorchScalar(scalarType), state);
} }
return getTorchTypeForScalarType(scalarTypes[0].getContext(), FailureOr<Type> result = getTorchTypeForScalarType(
result_type(state)); scalarTypes[0].getContext(), result_type(state));
if (failed(result))
return Type();
return *result;
} }
// Returns most generic type Type() if the tensor dtype is unknown. // Returns most generic type Type() if the tensor dtype is unknown.
@ -707,9 +672,9 @@ void TypeAnalysis::visitOperation(Operation *op,
} }
// Dtype is always float32, except for bfloat16, float16, float64 and nullptr. // Dtype is always float32, except for bfloat16, float16, float64 and nullptr.
if (isa<AtenTanhOp, AtenExpOp, AtenExpm1Op, AtenSinOp, AtenCosOp, if (isa<AtenExpOp, AtenExpm1Op, AtenSinOp, AtenCosOp, AtenSigmoidOp,
AtenSigmoidOp, AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op, AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op, AtenLog1pOp,
AtenLog1pOp, AtenRsqrtOp, AtenErfOp, AtenSoftplusOp, AtenFrobeniusNormDimOp>(op)) { AtenRsqrtOp, AtenErfOp, AtenSoftplusOp, AtenFrobeniusNormDimOp>(op)) {
ValueKnowledge knowledge = ValueKnowledge knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
Type dtype = operands[0]->getValue().dtype; Type dtype = operands[0]->getValue().dtype;
@ -770,7 +735,7 @@ void TypeAnalysis::visitOperation(Operation *op,
if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp, AtenDivTensorOp, if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp, AtenDivTensorOp,
AtenDivTensorModeOp, Aten__And__TensorOp, AtenMinimumOp, AtenDivTensorModeOp, Aten__And__TensorOp, AtenMinimumOp,
AtenMaximumOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp, AtenMaximumOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp,
AtenThresholdBackwardOp, AtenFloorDivideOp>(op)) { AtenThresholdBackwardOp>(op)) {
auto knowledge = auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
knowledge.dtype = getPromotedResultType( knowledge.dtype = getPromotedResultType(
@ -816,7 +781,7 @@ void TypeAnalysis::visitOperation(Operation *op,
// Promote LHS with scalar RHS. // Promote LHS with scalar RHS.
if (isa<AtenAddScalarOp, AtenSubScalarOp, AtenMulScalarOp, AtenDivScalarOp, if (isa<AtenAddScalarOp, AtenSubScalarOp, AtenMulScalarOp, AtenDivScalarOp,
AtenFmodScalarOp, AtenFloorDivideScalarOp, AtenPowTensorScalarOp, AtenFmodScalarOp, AtenFloorDivideScalarOp, AtenPowTensorScalarOp,
AtenRsubScalarOp, AtenLeakyReluOp, AtenRemainderScalarOp>(op)) { AtenLeakyReluOp, AtenRemainderScalarOp>(op)) {
auto lhs = operands[0]->getValue(); auto lhs = operands[0]->getValue();
Value scalar = op->getOperand(1); Value scalar = op->getOperand(1);
auto knowledge = auto knowledge =

View File

@ -0,0 +1,290 @@
//===----------------------------------------------------------------------===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "ReifyAbstractInterpCalculationsUtils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "llvm/ADT/StringSet.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
static std::string getLibraryFunctionPrefix(LibraryFunctionKind libFuncKind) {
if (libFuncKind == LibraryFunctionKind::ShapeFunction)
return "__torch_mlir_shape_fn.";
else if (libFuncKind == LibraryFunctionKind::DtypeFunction)
return "__torch_mlir_dtype_fn.";
llvm_unreachable(
"`getLibraryFunctionPrefix` called with an unsupported `CalculateOp`");
}
static Operation *createCalculateOp(OpBuilder &b, Location loc,
TypeRange resultTypes,
LibraryFunctionKind libFuncKind) {
if (libFuncKind == LibraryFunctionKind::ShapeFunction)
return b.create<ShapeCalculateOp>(loc, resultTypes);
else if (libFuncKind == LibraryFunctionKind::DtypeFunction)
return b.create<DtypeCalculateOp>(loc, resultTypes);
llvm_unreachable(
"`createCalculateOp` called with an unsupported `LibraryFunctionKind`");
}
static Operation *createCalculateYieldOp(OpBuilder &b, Location loc,
ValueRange results,
LibraryFunctionKind libFuncKind) {
if (libFuncKind == LibraryFunctionKind::ShapeFunction)
return b.create<ShapeCalculateYieldOp>(loc, results);
else if (libFuncKind == LibraryFunctionKind::DtypeFunction)
return b.create<DtypeCalculateYieldOp>(loc, results);
llvm_unreachable("`createCalculateYieldOp` called with an unsupported "
"`LibraryFunctionKind`");
}
static Operation *
createCalculateYieldCalculationOp(OpBuilder &b, Location loc,
ValueRange results,
LibraryFunctionKind libFuncKind) {
if (libFuncKind == LibraryFunctionKind::ShapeFunction)
return b.create<ShapeCalculateYieldShapesOp>(loc, results);
else if (libFuncKind == LibraryFunctionKind::DtypeFunction)
return b.create<DtypeCalculateYieldDtypesOp>(loc, results);
llvm_unreachable("`createCalculateYieldCalculationOp` called with an "
"unsupported `LibraryFunctionKind`");
}
LogicalResult Torch::wrapWithCalculateOpIfLibraryFunctionAvailable(
Operation *op, ModuleOp library, LibraryFunctionKind libFuncKind,
SmallVector<std::string> &libFuncNamesUsed,
function_ref<FailureOr<SmallVector<Value>>(OpBuilder &, Location,
ValueRange, func::FuncOp)>
libFuncArgsBuilder) {
Location loc = op->getLoc();
MLIRContext *context = op->getContext();
auto name = op->getName().stripDialect();
// For value-semantic variant ops, i.e. valsem-ops (ops that are
// mechanically consistent with existing torch conventions of in-place vs.
// out-of-place (value-semantic) variants), remove the prefix when
// looking them up in the library.
if (name.startswith("valsem."))
name = name.drop_front(strlen("valsem."));
std::string libFuncName =
(getLibraryFunctionPrefix(libFuncKind) + Twine(name)).str();
auto libFunc = library.lookupSymbol<func::FuncOp>(libFuncName);
if (!libFunc)
return success();
libFuncNamesUsed.push_back(libFuncName);
OpBuilder b(op);
Operation *calculate =
createCalculateOp(b, loc, op->getResultTypes(), libFuncKind);
op->replaceAllUsesWith(calculate);
{
// Move the op into the body of the `torch.{libFuncType}.calculate` op
// and yield its results.
OpBuilder b(context);
Block *bodyBlock = b.createBlock(&calculate->getRegion(0));
op->moveBefore(bodyBlock, bodyBlock->end());
b.setInsertionPointAfter(op);
createCalculateYieldOp(b, loc, op->getResults(), libFuncKind);
}
{
OpBuilder b(context);
b.createBlock(&calculate->getRegion(1));
// Create the call to the library function!
FailureOr<SmallVector<Value>> libFuncArgs =
libFuncArgsBuilder(b, loc, op->getOperands(), libFunc);
if (failed(libFuncArgs))
return failure();
auto call = b.create<mlir::func::CallOp>(loc, libFunc, *libFuncArgs);
// Python models multiple results with a tuple, so we need to unpack it
// if the op has multiple results.
SmallVector<Value> unpackedResults;
assert(call.getNumResults() == 1 &&
"Multiple results are packed in a tuple in Python!");
Value result = call.getResult(0);
if (auto tupleType = result.getType().dyn_cast<Torch::TupleType>()) {
auto unpack = b.create<PrimTupleUnpackOp>(
loc, tupleType.getContainedTypes(), result);
llvm::append_range(unpackedResults, unpack.getResults());
} else {
unpackedResults.push_back(result);
}
// Terminate the region.
createCalculateYieldCalculationOp(b, loc, unpackedResults, libFuncKind);
}
return success();
}
void Torch::importLibraryFunctions(ModuleOp module, ModuleOp library,
SmallVector<std::string> functionsNeeded) {
// Import just the functions we need. This includes transitive callees,
// so we use a worklist algorithm.
llvm::StringSet<> importedFunctions;
while (!functionsNeeded.empty()) {
std::string symName = functionsNeeded.pop_back_val();
if (importedFunctions.contains(symName))
continue;
auto func = library.lookupSymbol<func::FuncOp>(symName);
assert(func && "broken library");
// Move the function from the library to the module this pass
// is running on. (this mutates the library, but we re-parse it each time
// so this is safe to do).
func->moveBefore(&module.getBody()->front());
// Set the visibility to private so that the functions go away
// nicely after we are done with them.
func.setVisibility(SymbolTable::Visibility::Private);
// Continue the DFS.
importedFunctions.insert(symName);
func.walk([&](func::CallOp op) {
functionsNeeded.push_back(op.getCallee().str());
});
}
}
FailureOr<Value> Torch::adjustFunctionArg(
OpBuilder &b, Location loc, Value operand, Type desiredType,
function_ref<Value(OpBuilder &, Location, Value, Type)> baseTransformation) {
operand = baseTransformation(b, loc, operand, desiredType);
// No need for adjustment if they already match.
auto operandType = operand.getType();
if (operandType == desiredType)
return operand;
if (desiredType.isa<Torch::AnyType>()) {
// Generator's are currently passed as Any because TorchScript cannot
// compile a function with Generator type arguments.
// Ignoring that hack, this is a correct handling of Any type should we need
// to actually support it in the future.
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
}
// !torch.union<int, float> is the type used for `Scalar` inputs. At
// compile time, such inputs will usually be resolved to an `int` or a `float`
// so we need to derefine to match the library function signature.
if (auto unionType = desiredType.dyn_cast<Torch::UnionType>()) {
if (llvm::all_of(unionType.getContainedTypes(), [](Type containedType) {
return containedType.isa<Torch::IntType, Torch::FloatType>();
}))
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
}
// If the operand is NoneType, then we just need to derefine it to the
// optional type in the function signature.
if (operandType.isa<Torch::NoneType>()) {
assert(desiredType.isa<Torch::OptionalType>() &&
"Don't expect library functions to have NoneType parameters");
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
}
// If the operand type is statically !torch.optional, then we need to do
// different things for the None and non-None cases.
// For the None case, we just need to derefine it to the desired type.
// For the non-None case, we need to unwrap the optional type and then adjust
// it recursively (which also takes care of derefining it to ultimate desired
// type).
// A case where this happens is `!torch.optional<vtensor>` ->
// `!torch.optional<list<int>>>`.
if (auto operandOptionalType = operandType.dyn_cast<Torch::OptionalType>()) {
if (desiredType.isa<Torch::OptionalType>()) {
// if optional is None:
// return derefine(None)
// else:
// return adjust(unchecked_cast(optional))
auto none = b.create<ConstantNoneOp>(loc);
auto isNone = b.create<Aten__Is__Op>(loc, operand, none);
auto primIf = b.create<PrimIfOp>(loc, desiredType, isNone);
{
Region &thenRegion = primIf.getThenRegion();
b.createBlock(&thenRegion, thenRegion.end());
auto derefineNone = b.create<DerefineOp>(loc, desiredType, none);
b.create<PrimIfYieldOp>(loc, ValueRange{derefineNone});
}
{
Region &elseRegion = primIf.getElseRegion();
b.createBlock(&elseRegion, elseRegion.end());
auto downcasted = b.create<PrimUncheckedCastOp>(
loc, operandOptionalType.getContainedType(), operand);
FailureOr<Value> adjusted = adjustFunctionArg(
b, loc, downcasted, desiredType, baseTransformation);
if (failed(adjusted))
return failure();
b.create<PrimIfYieldOp>(loc, *adjusted);
}
b.setInsertionPointAfter(primIf);
return primIf.getResult(0);
}
}
// If the desired type is OptionalType, then recursively adjust the operand to
// the contained type, then derefine it to `!torch.optional`. For example,
// `!torch.vtensor -> !torch.optional<list<int>>>`.
if (auto desiredOptionalType = desiredType.dyn_cast<Torch::OptionalType>()) {
FailureOr<Value> adjusted = adjustFunctionArg(
b, loc, operand, desiredOptionalType.getContainedType(),
baseTransformation);
if (failed(adjusted))
return failure();
return b.create<DerefineOp>(loc, desiredType, *adjusted).getResult();
}
if (auto desiredListType = desiredType.dyn_cast<Torch::ListType>()) {
// Pseudocode:
//
// operand = ...
// adjusted_list = []
// for i in range(len(operand)):
// adjusted_list.append(adjust(operand[i]))
// return adjusted_list
auto providedType = operand.getType().cast<Torch::ListType>();
Value adjustedList =
b.create<PrimListConstructOp>(loc, desiredListType, ValueRange({}));
// Create a for-like PrimLoopOp.
Value maxTripCount = b.create<AtenLenTOp>(loc, operand);
Value cTrue = b.create<Torch::ConstantBoolOp>(loc, true);
auto loop = b.create<PrimLoopOp>(loc, TypeRange({}), maxTripCount,
/*initialCondition=*/cTrue,
/*iterArgsInit=*/ValueRange({}));
// Create the loop body.
{
OpBuilder::InsertionGuard guard(b);
Block *body =
b.createBlock(&loop.getRegion(), loop.getRegion().begin(),
TypeRange({b.getType<Torch::IntType>()}), {loc});
Value iterationNumber = body->getArgument(0);
Value element = b.create<Aten__Getitem__TOp>(
loc, providedType.getContainedType(), operand, iterationNumber);
FailureOr<Value> adjustedElement =
adjustFunctionArg(b, loc, element, desiredListType.getContainedType(),
baseTransformation);
if (failed(adjustedElement))
return failure();
b.create<AtenAppendTOp>(loc, adjustedList.getType(), adjustedList,
*adjustedElement);
b.create<PrimLoopConditionOp>(loc, /*shouldContinue=*/cTrue,
/*iterArgs=*/ValueRange({}));
}
return adjustedList;
}
// The library functions use `float` where the operator
// signature uses `Scalar` (see comments in torch_ods_gen.py for
// explanation).
if (desiredType.isa<Torch::FloatType>() &&
operand.getType().isa<Torch::IntType>()) {
return b.create<AtenFloatScalarOp>(loc, desiredType, operand).getResult();
}
// Pass the operand as-is.
return operand;
}

View File

@ -0,0 +1,67 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_DIALECT_TORCH_TRANSFORMS_REIFY_ABSTRACT_INTERP_CALCULATIONS_UTILS_H
#define TORCHMLIR_DIALECT_TORCH_TRANSFORMS_REIFY_ABSTRACT_INTERP_CALCULATIONS_UTILS_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Support/LogicalResult.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
namespace mlir {
namespace torch {
namespace Torch {
enum class LibraryFunctionKind { ShapeFunction, DtypeFunction, Decomposition };
// Searches the function library for an abstract interpretation function for
// `op`. If one is found, wraps the op in a `CalculateOp`, with the op placed in
// the first region, and a call to the abstract interpretation function is
// inserted into the second region.
//
// Note: this returns success if no abstract interpretation function is found,
// since some abstract interpretation functions (such as decompositions) are
// optional.
//
// Note: This function does *not* import the abstract interpretation function
// from the library into the IR.
LogicalResult wrapWithCalculateOpIfLibraryFunctionAvailable(
Operation *op, ModuleOp library, LibraryFunctionKind funcKind,
SmallVector<std::string> &libFuncNamesUsed,
function_ref<FailureOr<SmallVector<Value>>(OpBuilder &, Location,
ValueRange, func::FuncOp)>
libFuncArgsBuilder);
// Imports the functions in `functionsNeeded` from the library into the module.
// This function assumes that all functions needed exist in the library.
//
// Note: This function modifies the library.
void importLibraryFunctions(ModuleOp module, ModuleOp library,
SmallVector<std::string> functionsNeeded);
// Recursively adjust `operand` to match `desiredType`.
//
// This function by default handles a few types such as `UnionType`,
// `OptionalType`, and `ListType`, to name a few. Handling of base element types
// can be customized by defining `baseTransformation`, which gets called at the
// beginning of each recursive call. This function can be thought of as mapping
// `baseTransformation` across `UnionType/OptionalType/ListType`.
FailureOr<Value> adjustFunctionArg(
OpBuilder &b, Location loc, Value operand, Type desiredType,
function_ref<Value(OpBuilder &, Location, Value, Type)> baseTransformation =
[](OpBuilder &, Location, Value operand, Type) { return operand; });
} // namespace Torch
} // namespace torch
} // namespace mlir
#endif // TORCHMLIR_DIALECT_TORCH_TRANSFORMS_REIFY_ABSTRACT_INTERP_CALCULATIONS_UTILS_H

View File

@ -0,0 +1,136 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "ReifyAbstractInterpCalculationsUtils.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
static bool isTensorTypeOrWrappedTensorType(Type type) {
// Allowing tuples as arguments to dtype calculation functions can cause
// issues. For example, if an argument is a tuple of tensors and ints, there
// would be no way of differentiating the original ints from the ints created
// to represent the dtype and rank of the tensors. Therefore, to avoid this
// and keep things simple, the tuple type is not allowed. This works well in
// practice, since PyTorch op signatures don't seem to take tuples as inputs.
assert(!type.isa<Torch::TupleType>() &&
"dtype calculation functions are expected to not have tuples of "
"tensors as arguments");
if (type.isa<Torch::BaseTensorType>())
return true;
if (auto optionalType = type.dyn_cast<Torch::OptionalType>()) {
return isTensorTypeOrWrappedTensorType(optionalType.getContainedType());
} else if (auto listType = type.dyn_cast<Torch::ListType>()) {
return isTensorTypeOrWrappedTensorType(listType.getContainedType());
} else {
return false;
}
}
// Massage the op operands to match the dtype function signature.
// The dtype function generally takes the same operands as the op, with a few
// systematic modifications, such as replacing tensors with a rank and dtype
// argument.
static FailureOr<SmallVector<Value>>
dtypeFunctionArgsBuilder(OpBuilder &b, Location loc,
ValueRange originalOperands, func::FuncOp dtypeFunc) {
// Turns a tensor operand into an operand representing the rank of the tensor
auto rankArgAdjuster = [](OpBuilder &b, Location loc, Value operand,
Type desiredType) -> Value {
if (desiredType.isa<Torch::IntType>() &&
operand.getType().isa<Torch::BaseTensorType>()) {
auto sizeListType =
Torch::ListType::get(Torch::IntType::get(b.getContext()));
Value size = b.create<AtenSizeOp>(loc, sizeListType, operand);
return b.create<AtenLenTOp>(loc, desiredType, size);
}
return operand;
};
// Turns a tensor operand into an operand representing the dtype of the tensor
auto dtypeArgAdjuster = [](OpBuilder &b, Location loc, Value operand,
Type desiredType) -> Value {
if (desiredType.isa<Torch::IntType>() &&
operand.getType().isa<Torch::BaseTensorType>()) {
return b.create<PrimDtypeOp>(loc, desiredType, operand);
}
return operand;
};
SmallVector<Value> dtypeFuncArgs;
ArrayRef<Type> desiredTypes = dtypeFunc.getArgumentTypes();
for (auto operand : originalOperands) {
assert(!desiredTypes.empty() &&
"`dtypeFunc` should have at least one argument for each argument in "
"`originalOperands`");
Type desiredType = desiredTypes.front();
if (isTensorTypeOrWrappedTensorType(operand.getType())) {
assert(desiredTypes.size() >= 2 &&
"`dtypeFunc` should have two arguments for each tensor argument "
"in `originalOperands`");
FailureOr<Value> rankArg, dtypeArg;
if (failed(rankArg = adjustFunctionArg(b, loc, operand, desiredType,
rankArgAdjuster)))
return failure();
desiredTypes = desiredTypes.drop_front();
desiredType = desiredTypes.front();
if (failed(dtypeArg = adjustFunctionArg(b, loc, operand, desiredType,
dtypeArgAdjuster)))
return failure();
dtypeFuncArgs.append({*rankArg, *dtypeArg});
} else {
FailureOr<Value> otherArg;
if (failed(otherArg = adjustFunctionArg(b, loc, operand, desiredType)))
return failure();
dtypeFuncArgs.push_back(*otherArg);
}
desiredTypes = desiredTypes.drop_front();
}
return dtypeFuncArgs;
}
namespace {
class ReifyDtypeCalculationsPass
: public ReifyDtypeCalculationsBase<ReifyDtypeCalculationsPass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp module = getOperation();
OwningOpRef<ModuleOp> library =
parseSourceString<ModuleOp>(getAbstractInterpLibrary(), context);
// Walk all the operations, and if we have a dtype function, wrap the op
// in a `torch.dtype.calculate` op.
SmallVector<std::string> functionsNeeded;
WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult {
return wrapWithCalculateOpIfLibraryFunctionAvailable(
op, *library, LibraryFunctionKind::DtypeFunction, functionsNeeded,
dtypeFunctionArgsBuilder);
});
if (walkResult.wasInterrupted())
return signalPassFailure();
importLibraryFunctions(module, *library, std::move(functionsNeeded));
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
Torch::createReifyDtypeCalculationsPass() {
return std::make_unique<ReifyDtypeCalculationsPass>();
}

View File

@ -9,209 +9,49 @@
#include "PassDetail.h" #include "PassDetail.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "ReifyAbstractInterpCalculationsUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/Parser/Parser.h" #include "mlir/Parser/Parser.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/InliningUtils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
using namespace mlir::torch::Torch; using namespace mlir::torch::Torch;
static Value adjustShapeFunctionArg(Value operand, Type desiredType, static FailureOr<SmallVector<Value>>
OpBuilder &b, Location loc); shapeFunctionArgsBuilder(OpBuilder &b, Location loc,
ValueRange originalOperands, func::FuncOp shapeFunc) {
static Value adjustListArg(Value operand, Torch::ListType desiredType,
OpBuilder &b, Location loc) {
auto providedType = operand.getType().cast<Torch::ListType>();
// Pseudocode:
//
// operand = ...
// adjusted_list = []
// for i in range(len(operand)):
// adjusted_list.append(adjust(operand[i]))
// return adjusted_list
Value adjustedList =
b.create<PrimListConstructOp>(loc, desiredType, ValueRange({}));
// Create a for-like PrimLoopOp.
Value maxTripCount = b.create<AtenLenTOp>(loc, operand);
Value cTrue = b.create<Torch::ConstantBoolOp>(loc, true);
auto loop = b.create<PrimLoopOp>(loc, TypeRange({}), maxTripCount,
/*initialCondition=*/cTrue,
/*iterArgsInit=*/ValueRange({}));
// Create the loop body.
{
OpBuilder::InsertionGuard guard(b);
Block *body =
b.createBlock(&loop.getRegion(), loop.getRegion().begin(),
TypeRange({b.getType<Torch::IntType>()}), {loc});
Value iterationNumber = body->getArgument(0);
Value element = b.create<Aten__Getitem__TOp>(
loc, providedType.getContainedType(), operand, iterationNumber);
Value adjustedElement =
adjustShapeFunctionArg(element, desiredType.getContainedType(), b, loc);
b.create<AtenAppendTOp>(loc, adjustedList.getType(), adjustedList,
adjustedElement);
b.create<PrimLoopConditionOp>(loc, /*shouldContinue=*/cTrue,
/*iterArgs=*/ValueRange({}));
}
return adjustedList;
}
static Value adjustShapeFunctionArg(Value operand, Type desiredType,
OpBuilder &b, Location loc) {
auto operandType = operand.getType();
// No need for adjustment if they already match.
if (operandType == desiredType)
return operand;
if (desiredType.isa<Torch::AnyType>()) {
// Generator's are currently passed as Any because TorchScript cannot
// compile a function with Generator type arguments.
// Ignoring that hack, this is a correct handling of Any type should we need
// to actually support it in the future.
return b.create<DerefineOp>(loc, desiredType, operand);
}
// If the operand is NoneType, then we just need to derefine it to the
// optional type in the shape function signature.
if (operandType.isa<Torch::NoneType>()) {
assert(desiredType.isa<Torch::OptionalType>() &&
"Don't expect shape functions to have NoneType parameters");
return b.create<DerefineOp>(loc, desiredType, operand);
}
// If the operand type is statically !torch.optional, then we need to do
// different things for the None and non-None cases.
// For the None case, we just need to derefine it to the desired type.
// For the non-None case, we need to unwrap the optional type and then adjust
// it recursively (which also takes care of derefining it to ultimate desired
// type).
// A case where this happens is `!torch.optional<vtensor>` ->
// `!torch.optional<list<int>>>`.
if (auto operandOptionalType = operandType.dyn_cast<Torch::OptionalType>()) {
if (desiredType.isa<Torch::OptionalType>()) {
// if optional is None:
// return derefine(None)
// else:
// return adjust(unchecked_cast(optional))
auto none = b.create<ConstantNoneOp>(loc);
auto isNone = b.create<Aten__Is__Op>(loc, operand, none);
auto primIf = b.create<PrimIfOp>(loc, desiredType, isNone);
{
Region &thenRegion = primIf.getThenRegion();
b.createBlock(&thenRegion, thenRegion.end());
auto derefineNone = b.create<DerefineOp>(loc, desiredType, none);
b.create<PrimIfYieldOp>(loc, ValueRange{derefineNone});
}
{
Region &elseRegion = primIf.getElseRegion();
b.createBlock(&elseRegion, elseRegion.end());
auto downcasted = b.create<PrimUncheckedCastOp>(
loc, operandOptionalType.getContainedType(), operand);
auto adjusted = adjustShapeFunctionArg(downcasted, desiredType, b, loc);
b.create<PrimIfYieldOp>(loc, adjusted);
}
b.setInsertionPointAfter(primIf);
return primIf.getResult(0);
}
}
// If the desired type is OptionalType, then recursively adjust the operand to
// the contained type, then derefine it to `!torch.optional`. For example,
// `!torch.vtensor -> !torch.optional<list<int>>>`.
if (auto desiredOptionalType = desiredType.dyn_cast<Torch::OptionalType>()) {
auto adjusted = adjustShapeFunctionArg(
operand, desiredOptionalType.getContainedType(), b, loc);
return b.create<DerefineOp>(loc, desiredType, adjusted);
}
// The shape library functions have tensor operands replaced with
// `!torch.list<int>` types for the shape. Get the sizes.
if (operand.getType().isa<Torch::BaseTensorType>()) {
assert(desiredType.isa<Torch::ListType>() &&
"Don't expect shape functions to have tensor parameters");
return b.create<AtenSizeOp>(loc, desiredType, operand);
}
// Run this after `operand.getType().isa<Torch::BaseTensorType>()` so that
// `!torch.vtensor` -> `!torch.list<int>` is handled there specially
// first.
if (auto desiredListType = desiredType.dyn_cast<Torch::ListType>()) {
return adjustListArg(operand, desiredListType, b, loc);
}
// The shape library functions use `float` where the operator
// signature uses `Scalar` (see comments in torch_ods_gen.py for
// explanation).
if (desiredType.isa<Torch::FloatType>() &&
operand.getType().isa<Torch::IntType>()) {
return b.create<AtenFloatScalarOp>(loc, desiredType, operand);
}
// Pass the operand as-is.
return operand;
}
// Populates the shape calculation region with a call to the shape function
// from the shape library.
static LogicalResult
populateShapeCalculationRegion(ShapeCalculateOp op, ValueRange originalOperands,
mlir::func::FuncOp shapeFunction) {
// Create a call to the shape function in the `shapeCalculation` region.
// We will import the callee from the shape library later.
OpBuilder b(op.getContext());
Location loc = op->getLoc();
b.createBlock(&op.getShapeCalculation());
// Massage the op operands to match the shape function signature. // Massage the op operands to match the shape function signature.
// The shape function generally takes the same operands as the op, with a few // The shape function generally takes the same operands as the op, with a few
// systematic modifications, such as replacing tensors with their shapes. // systematic modifications, such as replacing tensors with their shapes.
SmallVector<Value> shapeFunctionArgs; SmallVector<Value> shapeFuncArgs;
for (auto operandAndDesiredType : for (auto operandAndDesiredType :
llvm::zip(originalOperands, shapeFunction.getArgumentTypes())) { llvm::zip(originalOperands, shapeFunc.getArgumentTypes())) {
Value operand; Value operand;
Type desiredType; Type desiredType;
std::tie(operand, desiredType) = operandAndDesiredType; std::tie(operand, desiredType) = operandAndDesiredType;
Value shapeFunctionArg = FailureOr<Value> shapeFuncArg = adjustFunctionArg(
adjustShapeFunctionArg(operand, desiredType, b, loc); b, loc, operand, desiredType,
if (!shapeFunctionArg) [](OpBuilder &b, Location loc, Value operand,
Type desiredType) -> Value {
// The shape library functions have tensor operands replaced with
// `!torch.list<int>` types for the shape. Get the sizes.
auto desiredListType = desiredType.dyn_cast<Torch::ListType>();
if (!desiredListType)
return operand;
if (operand.getType().isa<Torch::BaseTensorType>() &&
desiredListType.getContainedType().isa<Torch::IntType>()) {
return b.create<AtenSizeOp>(loc, desiredType, operand);
}
return operand;
});
if (failed(shapeFuncArg))
return failure(); return failure();
shapeFunctionArgs.push_back(shapeFunctionArg); shapeFuncArgs.push_back(*shapeFuncArg);
} }
// Create the call to the shape function! return shapeFuncArgs;
auto call =
b.create<mlir::func::CallOp>(loc, shapeFunction, shapeFunctionArgs);
// Python models multiple results with a tuple, so we need to unpack it
// if the op has multiple results.
SmallVector<Value> unpackedResults;
assert(call.getNumResults() == 1 &&
"Multiple results are packed in a tuple in Python!");
Value result = call.getResult(0);
if (auto tupleType = result.getType().dyn_cast<Torch::TupleType>()) {
auto unpack =
b.create<PrimTupleUnpackOp>(loc, tupleType.getContainedTypes(), result);
llvm::append_range(unpackedResults, unpack.getResults());
} else {
unpackedResults.push_back(result);
}
// Terminate the region.
b.create<ShapeCalculateYieldShapesOp>(loc, unpackedResults);
return success();
} }
namespace { namespace {
@ -222,74 +62,23 @@ class ReifyShapeCalculationsPass
ModuleOp module = getOperation(); ModuleOp module = getOperation();
// TODO: Find a way to not have to parse this every time. // TODO: Find a way to not have to parse this every time.
// The shape library is O(#ops we know about), and this pass should be // The library is O(#ops we know about), and this pass should be
// O(#ops in the program) ideally. // O(#ops in the program) ideally.
auto shapeLibrary = parseSourceString<ModuleOp>(getShapeLibrary(), context); OwningOpRef<ModuleOp> library =
parseSourceString<ModuleOp>(getAbstractInterpLibrary(), context);
// Walk all the operations, and if we have a shape function, wrap the op // Walk all the operations, and if we have a shape function, wrap the op
// in a `torch.shape.calculate` op. // in a `torch.shape.calculate` op.
SmallVector<std::string> neededShapeFunctions; SmallVector<std::string> functionsNeeded;
bool hadError = false; WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult {
module.walk([&](Operation *op) { return wrapWithCalculateOpIfLibraryFunctionAvailable(
Location loc = op->getLoc(); op, *library, LibraryFunctionKind::ShapeFunction, functionsNeeded,
auto name = op->getName().stripDialect(); shapeFunctionArgsBuilder);
// For value-semantic variant ops, i.e. valsem-ops (ops that are
// mechanically consistent with existing torch conventions of in-place vs.
// out-of-place (value-semantic) variants), remove the prefix when
// looking them up in the shape library.
if (name.startswith("valsem."))
name = name.drop_front(strlen("valsem."));
auto shapeFunctionName = ("__torch_mlir_shape_fn." + Twine(name)).str();
auto shapeFunction =
shapeLibrary->lookupSymbol<func::FuncOp>(shapeFunctionName);
if (!shapeFunction)
return;
neededShapeFunctions.push_back(shapeFunctionName);
auto shapeCalculate =
OpBuilder(op).create<ShapeCalculateOp>(loc, op->getResultTypes());
op->replaceAllUsesWith(shapeCalculate);
{
// Move the op into the body of the `torch.shape.calculate` op and yield
// its results.
OpBuilder b(context);
Block *block = b.createBlock(&shapeCalculate.getBody());
op->moveBefore(block, block->end());
b.setInsertionPointAfter(op);
b.create<ShapeCalculateYieldOp>(loc, op->getResults());
}
if (failed(populateShapeCalculationRegion(
shapeCalculate, op->getOperands(), shapeFunction))) {
hadError = true;
return;
}
}); });
if (hadError) if (walkResult.wasInterrupted())
return signalPassFailure(); return signalPassFailure();
importLibraryFunctions(module, *library, std::move(functionsNeeded));
// Import just the functions we need. This includes transitive callees,
// so we use a worklist algorithm.
llvm::StringSet<> importedFunctions;
SmallVector<std::string> worklist;
llvm::append_range(worklist, neededShapeFunctions);
while (!worklist.empty()) {
auto symName = worklist.pop_back_val();
if (importedFunctions.count(symName))
continue;
auto func = shapeLibrary->lookupSymbol<mlir::func::FuncOp>(symName);
assert(func && "broken shape library");
// Move the shape function from the library to the module this pass
// is running on. (this mutates the library, but we re-parse it each time
// so this is safe to do).
func->moveBefore(&module.getBody()->front());
// Set the visibility to private so that the shape functions go away
// nicely after we are done with them.
func.setVisibility(SymbolTable::Visibility::Private);
// Continue the DFS.
importedFunctions.insert(symName);
func.walk(
[&](func::CallOp op) { worklist.push_back(op.getCallee().str()); });
}
} }
}; };
} // namespace } // namespace

View File

@ -0,0 +1,99 @@
//===----------------------------------------------------------------------===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "SimplifyAbstractInterpCalculationsUtils.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp,
int resultNum,
Type newResultType,
PatternRewriter &rewriter) {
Location loc = calculateOp->getLoc();
auto result = calculateOp->getResult(resultNum);
Type originalResultType = result.getType();
Type updatedType;
if (auto originalBaseTensorType =
originalResultType.template dyn_cast<BaseTensorType>()) {
// If we didn't get any new information, there is nothing left for us to do.
updatedType = meetTensorTypes(originalBaseTensorType,
newResultType.cast<BaseTensorType>());
if (!updatedType || updatedType == originalBaseTensorType)
return rewriter.notifyMatchFailure(
calculateOp, "New type information does not refine old type");
} else if (auto originalResultType =
result.getType().template dyn_cast<Torch::NumberType>()) {
if (!newResultType.isa<Torch::FloatType, Torch::IntType>()) {
return rewriter.notifyMatchFailure(
calculateOp,
"Refinement of `NumberType` must be a `FloatType` or `IntType`");
}
updatedType = newResultType;
} else {
return rewriter.notifyMatchFailure(calculateOp,
"Unimplemented: Expected result type to "
"be `BaseTensorType` or `NumberType`");
}
// Update all the uses of the result type to the new type, if possible. Insert
// a TensorStaticInfoCastOp for any users that might require the exact
// previous type.
Value originalTypedValue;
for (OpOperand &use : llvm::make_early_inc_range(result.getUses())) {
if (use.getOwner()
->hasTrait<mlir::torch::Torch::OpTrait::AllowsTypeRefinement>()) {
continue;
}
if (!originalTypedValue) {
rewriter.setInsertionPointAfter(calculateOp);
if (originalResultType.isa<BaseTensorType>()) {
originalTypedValue = rewriter.create<TensorStaticInfoCastOp>(
loc, originalResultType, result);
} else if (originalResultType.isa<Torch::NumberType>()) {
originalTypedValue =
rewriter.create<DerefineOp>(loc, originalResultType, result);
} else {
return rewriter.notifyMatchFailure(
calculateOp, "Unimplemented: Expected result type to "
"be `BaseTensorType` or `NumberType`");
}
}
use.set(originalTypedValue);
}
result.setType(updatedType);
// Update the value yielded from the body to match the new result type. If we
// can refine the def in place, do that, otherwise insert a
// TensorStaticInfoCastOp.
Operation *yieldValues = calculateOp->getRegion(0).front().getTerminator();
OpOperand &use = yieldValues->getOpOperand(resultNum);
Value def = use.get();
Value newYieldedValue;
if (def.isa<OpResult>() &&
def.cast<OpResult>()
.getDefiningOp()
->hasTrait<mlir::torch::Torch::OpTrait::AllowsTypeRefinement>()) {
newYieldedValue = def;
} else {
rewriter.setInsertionPoint(yieldValues);
if (updatedType.isa<BaseTensorType>()) {
newYieldedValue =
rewriter.create<TensorStaticInfoCastOp>(loc, updatedType, def);
} else {
newYieldedValue =
rewriter.create<PrimUncheckedCastOp>(loc, updatedType, def);
}
}
use.set(newYieldedValue);
newYieldedValue.setType(updatedType);
return success();
}

View File

@ -0,0 +1,30 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_DIALECT_TORCH_TRANSFORMS_SIMPLIFY_ABSTRACT_INTERP_CALCULATIONS_UTILS_H
#define TORCHMLIR_DIALECT_TORCH_TRANSFORMS_SIMPLIFY_ABSTRACT_INTERP_CALCULATIONS_UTILS_H
#include "mlir/IR/PatternMatch.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
namespace mlir {
namespace torch {
namespace Torch {
// Updates the type of result `resultNum` of both `calculateOp` and the torch op
// being wrapped by `calculateOp` to the type `newResultType`.
LogicalResult updateCalculateOpResultTypes(Operation *calculateOp,
int resultNum, Type newResultType,
PatternRewriter &rewriter);
} // namespace Torch
} // namespace torch
} // namespace mlir
#endif // TORCHMLIR_DIALECT_TORCH_TRANSFORMS_SIMPLIFY_ABSTRACT_INTERP_CALCULATIONS_UTILS_H

View File

@ -0,0 +1,209 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "SimplifyAbstractInterpCalculationsUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
static LogicalResult refineDtypeCalculateResult(DtypeCalculateOp op,
int resultNum,
PatternRewriter &rewriter) {
auto yieldDtypes = op.getCalculation().front().getTerminator();
auto dtype = yieldDtypes->getOperand(resultNum);
auto result = op->getResult(resultNum);
int64_t dtypeInt;
if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
return rewriter.notifyMatchFailure(
op, "Expected result from the DtypeCalculateOp calculation to be a "
"constant int");
auto dtypeScalarType = static_cast<torch_upstream::ScalarType>(dtypeInt);
// Calculate the updated type incorporating the new information.
Type impliedTypeFromDtype;
if (result.getType().isa<Torch::NumberType>()) {
FailureOr<Type> torchType =
getTorchTypeForScalarType(op->getContext(), dtypeScalarType);
if (failed(torchType)) {
return rewriter.notifyMatchFailure(
op, "Failed to convert result dtype to `Torch::FloatType` or "
"`Torch::IntType`");
}
impliedTypeFromDtype = *torchType;
} else if (auto originalResultType =
result.getType().dyn_cast<BaseTensorType>()) {
impliedTypeFromDtype =
originalResultType.cast<BaseTensorType>().getWithSizesAndDtype(
originalResultType.getOptionalSizes(),
getTypeForScalarType(op->getContext(), dtypeScalarType));
} else {
return rewriter.notifyMatchFailure(op,
"Unimplemented: Expected result type to "
"be `BaseTensorType` or `NumberType`");
}
return updateCalculateOpResultTypes(op, resultNum, impliedTypeFromDtype,
rewriter);
}
namespace {
// This pattern propagates information out of the dtype calculation region and
// into the DtypeCalculateOp result types.
class RefineDtypeCalculateOp : public OpRewritePattern<DtypeCalculateOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(DtypeCalculateOp op,
PatternRewriter &rewriter) const override {
LogicalResult result = failure();
for (int i = 0, e = op->getNumResults(); i != e; i++) {
if (succeeded(refineDtypeCalculateResult(op, i, rewriter)))
result = success();
}
return result;
}
};
} // namespace
namespace {
class DecomposePromoteDtypesOp : public OpRewritePattern<PromoteDtypesOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(PromoteDtypesOp op,
PatternRewriter &rewriter) const override {
SmallVector<Optional<int64_t>> ranks;
SmallVector<int64_t> dtypes;
if (!matchPattern(op.getRanks(), m_TorchListOfOptionalConstantInts(ranks))) {
return rewriter.notifyMatchFailure(
op, "Expected `ranks` to be a list of optional constant ints");
}
if (!matchPattern(op.getDtypes(), m_TorchListOfConstantInts(dtypes))) {
return rewriter.notifyMatchFailure(
op, "Expected `dtypes` to be a list of constant ints");
}
if (ranks.empty() || dtypes.empty()) {
return rewriter.notifyMatchFailure(
op, "`ranks` list and `dtypes` list must be non-empty");
}
if (ranks.size() != dtypes.size()) {
return rewriter.notifyMatchFailure(
op, "`ranks` list and `dtypes` list must have the same size");
}
torch_upstream::ResultTypeState state{};
for (auto ranksAndDtypes : llvm::zip(ranks, dtypes)) {
Optional<int64_t> rank;
int64_t dtype;
std::tie(rank, dtype) = ranksAndDtypes;
auto scalarType = static_cast<torch_upstream::ScalarType>(dtype);
bool isScalarOnlyOp = llvm::all_of(
ranks, [](Optional<int64_t> rank) { return !rank.has_value(); });
if (!rank.has_value()) {
// If `rank` does not have a value, then we are dealing with a scalar
// input. For the type promotion, the behavior of a scalar argument is
// dependent on whether the op is performing an operation with only
// scalars (such as AtenAddOp) or with scalars and tensors (such as
// AtenAddScalarOp). Therefore, we convert back to the original torch
// type of the scalar first, and then determine the right scalar type to
// use for promotion based on whether the op involves only scalars or
// scalars and tensors.
FailureOr<Type> torchType =
getTorchTypeForScalarType(op->getContext(), scalarType);
if (failed(torchType)) {
return rewriter.notifyMatchFailure(
op, "Dtypes for arguments scalars must be convertible to "
"`Torch::FloatType` or `Torch::IntType`");
}
Type builtinType = isScalarOnlyOp
? getBuiltInTypeForTorchScalar(*torchType)
: getDefaultDtypeForTorchScalar(*torchType);
scalarType = getScalarTypeForType(builtinType);
state.wrappedResult =
promote_skip_undefined(state.wrappedResult, scalarType);
} else if (rank.value() == 0) {
state.zeroResult = promote_skip_undefined(state.zeroResult, scalarType);
} else if (rank.value() > 0) {
state.dimResult = promote_skip_undefined(state.dimResult, scalarType);
} else {
return rewriter.notifyMatchFailure(op, "Rank should not be negative");
}
}
auto resultType = static_cast<int64_t>(result_type(state));
rewriter.replaceOpWithNewOp<ConstantIntOp>(
op, rewriter.getI64IntegerAttr(resultType));
return success();
}
};
} // namespace
namespace {
class RefineNumToTensorScalarOpType
: public OpRewritePattern<PrimNumToTensorScalarOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(PrimNumToTensorScalarOp op,
PatternRewriter &rewriter) const override {
auto originalResultType = op.getResult().getType().cast<BaseTensorType>();
if (originalResultType.hasDtype())
return rewriter.notifyMatchFailure(
op, "`PrimNumToTensorScalarOp` already has a dtype");
Type inputType = getBuiltInTypeForTorchScalar(op.getA().getType());
auto impliedTypeFromInputType =
originalResultType.cast<BaseTensorType>()
.getWithSizesAndDtype(originalResultType.getOptionalSizes(),
inputType)
.cast<BaseTensorType>();
op.getResult().setType(impliedTypeFromInputType);
return success();
}
};
} // namespace
namespace {
class SimplifyDtypeCalculationsPass
: public SimplifyDtypeCalculationsBase<SimplifyDtypeCalculationsPass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.insert<RefineDtypeCalculateOp>(context);
patterns.insert<DecomposePromoteDtypesOp>(context);
patterns.insert<RefineNumToTensorScalarOpType>(context);
// TODO: Debug visitation order to make this more efficient.
// A single linear scan should suffice.
GreedyRewriteConfig config;
config.useTopDownTraversal = true;
config.maxIterations = GreedyRewriteConfig::kNoIterationLimit;
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config))) {
return signalPassFailure();
}
}
};
} // namespace
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createSimplifyDtypeCalculationsPass() {
return std::make_unique<SimplifyDtypeCalculationsPass>();
}

View File

@ -9,20 +9,11 @@
#include "PassDetail.h" #include "PassDetail.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "SimplifyAbstractInterpCalculationsUtils.h"
#include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/InliningUtils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
@ -40,10 +31,12 @@ public:
Location loc = op->getLoc(); Location loc = op->getLoc();
MLIRContext *context = op->getContext(); MLIRContext *context = op->getContext();
if (!op.isForLike()) if (!op.isForLike())
return failure(); return rewriter.notifyMatchFailure(op, "Loop is not for-like");
int64_t maxTripCount; int64_t maxTripCount;
if (!matchPattern(op.getMaxTripCount(), m_TorchConstantInt(&maxTripCount))) if (!matchPattern(op.getMaxTripCount(), m_TorchConstantInt(&maxTripCount)))
return failure(); return rewriter.notifyMatchFailure(
op, "Expected `maxTripCount` to be a constant int");
;
SmallVector<Value> indices; SmallVector<Value> indices;
for (int64_t i = 0; i < maxTripCount; i++) { for (int64_t i = 0; i < maxTripCount; i++) {
// TODO: Add convenience builder. // TODO: Add convenience builder.
@ -149,7 +142,8 @@ public:
for (Operation *user : usersToInterpret) { for (Operation *user : usersToInterpret) {
if (auto append = dyn_cast<AtenAppendTOp>(user)) { if (auto append = dyn_cast<AtenAppendTOp>(user)) {
if (!append.use_empty()) if (!append.use_empty())
return failure(); return rewriter.notifyMatchFailure(
op, "Expected `AtenAppendTOp` to not have users");
if (append.getSelf() == op) { if (append.getSelf() == op) {
runningList.push_back(append.getEl()); runningList.push_back(append.getEl());
generatedNewLiteral = true; generatedNewLiteral = true;
@ -159,13 +153,16 @@ public:
} }
if (auto insert = dyn_cast<AtenInsertTOp>(user)) { if (auto insert = dyn_cast<AtenInsertTOp>(user)) {
if (!insert.use_empty()) if (!insert.use_empty())
return failure(); return rewriter.notifyMatchFailure(
op, "Expected `AtenInsertTOp` to not have users");
int64_t index; int64_t index;
if (!matchPattern(insert.getIdx(), m_TorchConstantInt(&index))) if (!matchPattern(insert.getIdx(), m_TorchConstantInt(&index)))
return failure(); return rewriter.notifyMatchFailure(
op, "Expected `idx` of `AtenInsertTOp` to be a constant int");
// The index might be statically out of bounds. // The index might be statically out of bounds.
if (index < 0 || index > static_cast<int64_t>(runningList.size())) if (index < 0 || index > static_cast<int64_t>(runningList.size()))
return failure(); return rewriter.notifyMatchFailure(
op, "Index in `AtenInsertTOp` is out of bounds");
if (insert.getSelf() == op) { if (insert.getSelf() == op) {
runningList.insert(runningList.begin() + index, insert.getEl()); runningList.insert(runningList.begin() + index, insert.getEl());
generatedNewLiteral = true; generatedNewLiteral = true;
@ -175,13 +172,15 @@ public:
} }
if (auto setItem = dyn_cast<Aten_SetItemTOp>(user)) { if (auto setItem = dyn_cast<Aten_SetItemTOp>(user)) {
if (!setItem.use_empty()) if (!setItem.use_empty())
return failure(); return rewriter.notifyMatchFailure(
op, "Expected `Aten_SetItemTOp` to not have users");
llvm::Optional<int64_t> indexOpt = llvm::Optional<int64_t> indexOpt =
matchLegalConstantIndexIntoListOfSize(setItem.getIdx(), matchLegalConstantIndexIntoListOfSize(setItem.getIdx(),
runningList.size()); runningList.size());
// The index might be statically out of bounds. // The index might be statically out of bounds.
if (!indexOpt) if (!indexOpt)
return failure(); return rewriter.notifyMatchFailure(
op, "Index in `Aten_SetItemTOp` is out of bounds");
if (setItem.getL() == op) { if (setItem.getL() == op) {
runningList[*indexOpt] = setItem.getEl(); runningList[*indexOpt] = setItem.getEl();
generatedNewLiteral = true; generatedNewLiteral = true;
@ -196,7 +195,7 @@ public:
} }
if (!generatedNewLiteral) if (!generatedNewLiteral)
return failure(); return rewriter.notifyMatchFailure(op, "No new literal created");
// Rewrite all users to use the appropriate list literals. // Rewrite all users to use the appropriate list literals.
Value latestLiteral = rewriter.create<PrimListConstructOp>( Value latestLiteral = rewriter.create<PrimListConstructOp>(
@ -285,11 +284,10 @@ public:
}; };
} // namespace } // namespace
static void refineShapeCalculateResult(ShapeCalculateOp op, int resultNum, static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op,
PatternRewriter &rewriter, int resultNum,
bool &madeChange) { PatternRewriter &rewriter) {
auto yieldValues = op.getBody().front().getTerminator(); auto yieldShapes = op.getCalculation().front().getTerminator();
auto yieldShapes = op.getShapeCalculation().front().getTerminator();
auto shape = yieldShapes->getOperand(resultNum); auto shape = yieldShapes->getOperand(resultNum);
auto result = op->getResult(resultNum); auto result = op->getResult(resultNum);
@ -298,7 +296,9 @@ static void refineShapeCalculateResult(ShapeCalculateOp op, int resultNum,
// much as possible to literals. // much as possible to literals.
auto listConstruct = shape.getDefiningOp<PrimListConstructOp>(); auto listConstruct = shape.getDefiningOp<PrimListConstructOp>();
if (!listConstruct) if (!listConstruct)
return; return rewriter.notifyMatchFailure(
op, "Expected result from ShapeCalculateOp calculation to be a "
"`PrimListConstructOp`");
llvm::BitVector clobberedElements(listConstruct->getNumOperands()); llvm::BitVector clobberedElements(listConstruct->getNumOperands());
// Analyze the users to determine if we can refine the shape. // Analyze the users to determine if we can refine the shape.
for (Operation *user : listConstruct->getUsers()) { for (Operation *user : listConstruct->getUsers()) {
@ -320,7 +320,7 @@ static void refineShapeCalculateResult(ShapeCalculateOp op, int resultNum,
continue; continue;
} }
// An unhandled op! We can't make any assumptions about the shape. // An unhandled op! We can't make any assumptions about the shape.
return; return rewriter.notifyMatchFailure(op, "Unhandled op that mutates lists");
} }
// Construct the list of sizes implied by the yielded shape. // Construct the list of sizes implied by the yielded shape.
@ -334,55 +334,15 @@ static void refineShapeCalculateResult(ShapeCalculateOp op, int resultNum,
sizes.push_back(kUnknownSize); sizes.push_back(kUnknownSize);
} }
// Calculate the updated type incorporating the new shape information. auto originalResultType = result.getType().cast<BaseTensorType>();
Type originalResultType = result.getType();
auto impliedTypesFromShape = auto impliedTypesFromShape =
originalResultType.cast<BaseTensorType>().getWithSizesAndDtype( originalResultType.cast<BaseTensorType>()
makeArrayRef(sizes), nullptr); .getWithSizesAndDtype(makeArrayRef(sizes),
auto updatedType = originalResultType.getOptionalDtype())
meetTensorTypes(originalResultType.cast<BaseTensorType>(), .cast<BaseTensorType>();
impliedTypesFromShape.cast<BaseTensorType>());
// If we didn't get any new information, there is nothing left for us to do.
if (!updatedType || updatedType == originalResultType)
return;
// Update all the uses of the result type to the new type, if possible. Insert return updateCalculateOpResultTypes(op, resultNum, impliedTypesFromShape,
// a TensorStaticInfoCastOp for any users that might require the exact rewriter);
// previous type.
Value originalTypedValue;
for (OpOperand &use : llvm::make_early_inc_range(result.getUses())) {
if (use.getOwner()
->hasTrait<mlir::torch::Torch::OpTrait::AllowsTypeRefinement>()) {
continue;
}
if (!originalTypedValue) {
rewriter.setInsertionPointAfter(op);
originalTypedValue = rewriter.create<TensorStaticInfoCastOp>(
op->getLoc(), originalResultType, result);
}
use.set(originalTypedValue);
}
result.setType(updatedType);
madeChange = true;
// Update the value yielded from the body to match the new result type. If we
// can refine the def in place, do that, otherwise insert a
// TensorStaticInfoCastOp.
OpOperand &use = op.getBody().front().getTerminator()->getOpOperand(resultNum);
Value def = use.get();
Value newYieldedValue;
if (def.isa<OpResult>() &&
def.cast<OpResult>()
.getDefiningOp()
->hasTrait<mlir::torch::Torch::OpTrait::AllowsTypeRefinement>()) {
newYieldedValue = def;
} else {
rewriter.setInsertionPoint(yieldValues);
newYieldedValue =
rewriter.create<TensorStaticInfoCastOp>(op->getLoc(), updatedType, def);
}
use.set(newYieldedValue);
newYieldedValue.setType(updatedType);
} }
namespace { namespace {
@ -393,10 +353,11 @@ public:
using OpRewritePattern::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ShapeCalculateOp op, LogicalResult matchAndRewrite(ShapeCalculateOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
bool madeChange = false; LogicalResult result = failure();
for (int i = 0, e = op->getNumResults(); i != e; i++) for (int i = 0, e = op->getNumResults(); i != e; i++)
refineShapeCalculateResult(op, i, rewriter, madeChange); if (succeeded(refineShapeCalculateResult(op, i, rewriter)))
return success(madeChange); result = success();
return result;
} }
}; };
} // namespace } // namespace

View File

@ -100,19 +100,46 @@ Type Torch::getTypeForScalarType(
} }
} }
Type Torch::getTorchTypeForScalarType(MLIRContext *context, FailureOr<Type>
torch_upstream::ScalarType dtypeInt) { Torch::getTorchTypeForScalarType(MLIRContext *context,
torch_upstream::ScalarType dtypeInt) {
switch (dtypeInt) { switch (dtypeInt) {
case torch_upstream::ScalarType::Double: case torch_upstream::ScalarType::Double:
return Torch::FloatType::get(context); return Torch::FloatType::get(context);
case torch_upstream::ScalarType::Long: case torch_upstream::ScalarType::Long:
return Torch::IntType::get(context); return Torch::IntType::get(context);
default: default:
llvm::report_fatal_error( return failure();
"Unsupported scalar type to Torch type conversion");
} }
} }
Type Torch::getDefaultDtypeForTorchScalar(Type type) {
MLIRContext *context = type.getContext();
if (type.isa<Torch::FloatType>()) {
// For now, use float32 which is the initial default dtype returned by
// `torch.get_default_dtype`.
return Float32Type::get(context);
}
if (type.isa<Torch::IntType>())
return IntegerType::get(context, 64, IntegerType::Signed);
if (type.isa<Torch::BoolType>())
return IntegerType::get(context, 1);
llvm_unreachable(
"getDefaultDtypeForTorchScalar called on an unsupported type");
}
Type Torch::getBuiltInTypeForTorchScalar(Type type) {
MLIRContext *context = type.getContext();
if (type.isa<Torch::FloatType>())
return Float64Type::get(context);
if (type.isa<Torch::IntType>())
return IntegerType::get(context, 64, IntegerType::Signed);
if (type.isa<Torch::BoolType>())
return IntegerType::get(context, 1);
llvm_unreachable(
"getBuiltInTypeForTorchScalar called on an unsupported type");
}
Value Torch::getDtypeIntValueForType(PatternRewriter &rewriter, Location loc, Value Torch::getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
Type dtype) { Type dtype) {
int intType = (int)getScalarTypeForType(dtype); int intType = (int)getScalarTypeForType(dtype);

View File

@ -104,7 +104,8 @@ endif()
################################################################################ ################################################################################
# Custom op example # Custom op example
# Required for running the update_torch_ods.sh and update_shape_lib.sh scripts. # Required for running the update_torch_ods.sh and update_abstract_interp_lib.sh
# scripts.
################################################################################ ################################################################################
# add_subdirectory(torch_mlir/_torch_mlir_custom_op_example) # add_subdirectory(torch_mlir/_torch_mlir_custom_op_example)

View File

@ -5,12 +5,12 @@ If you're reading this, you're likely looking to create or support a third-party
This isn't much different than [adding a new PyTorch op to torch-mlir](https://github.com/llvm/torch-mlir/wiki/Torch-ops-E2E-implementation). This isn't much different than [adding a new PyTorch op to torch-mlir](https://github.com/llvm/torch-mlir/wiki/Torch-ops-E2E-implementation).
You'll still go through the exact same process, with just a small change: You'll still go through the exact same process, with just a small change:
- Before running `update_torch_ods.sh` or `update_shape_lib.sh`, you'll want to set the `TORCH_MLIR_EXT_PYTHONPATH` to point to wherever your extension lives and the `TORCH_MLIR_EXT_MODULES` to the name of the python module. - Before running `update_torch_ods.sh` or `update_abstract_interp_lib.sh`, you'll want to set the `TORCH_MLIR_EXT_PYTHONPATH` to point to wherever your extension lives and the `TORCH_MLIR_EXT_MODULES` to the name of the python module.
For instance, let's say you've written a python package called `my_torch_ops`. For instance, let's say you've written a python package called `my_torch_ops`.
If `my_torch_ops` lives in `/example/subdirectory/`, then you'll want to set `TORCH_MLIR_EXT_PYTHONPATH=/example/subdirectory` and `TORCH_MLIR_EXT_MODULES=my_torch_ops`. If `my_torch_ops` lives in `/example/subdirectory/`, then you'll want to set `TORCH_MLIR_EXT_PYTHONPATH=/example/subdirectory` and `TORCH_MLIR_EXT_MODULES=my_torch_ops`.
If you've installed your package (with `pip`, for instance), then you'll only need to set `TORCH_MLIR_EXT_MODULES=my_torch_ops`. If you've installed your package (with `pip`, for instance), then you'll only need to set `TORCH_MLIR_EXT_MODULES=my_torch_ops`.
Note that the `update_torch_ods.sh` and `update_shape_lib.sh` scripts do not use the `PYTHONPATH` environment variable in your current shell. Note that the `update_torch_ods.sh` and `update_abstract_interp_lib.sh` scripts do not use the `PYTHONPATH` environment variable in your current shell.
This is on purpose, but it means that you either need to set `TORCH_MLIR_EXT_PYTHONPATH` to include your package or to include the paths set in your shell's `PYTHONPATH` variable. This is on purpose, but it means that you either need to set `TORCH_MLIR_EXT_PYTHONPATH` to include your package or to include the paths set in your shell's `PYTHONPATH` variable.
If you have more than one PyTorch extension, you can add them all by including each path in `TORCH_MLIR_EXT_PYTHONPATH` separated by colons (`:`) and each module in `TORCH_MLIR_EXT_MODULES` separated by commas (`,`). If you have more than one PyTorch extension, you can add them all by including each path in `TORCH_MLIR_EXT_PYTHONPATH` separated by colons (`:`) and each module in `TORCH_MLIR_EXT_MODULES` separated by commas (`,`).

View File

@ -0,0 +1,189 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import inspect
import re
from typing import List, Optional, Union
import torch
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
from torch_mlir.passmanager import PassManager
from .registry import Registry
def get_dtype_of_scalar(scalar: Union[int, float]) -> int:
# This is hacky. `NumToTensor` is the only PyTorch op for scalars
# that when `jit.script`ed converts a float scalar to a tensor
# with dtype that corresponds to Python's `float`.
#
# See definition of `NumToTensor`: https://github.com/pytorch/pytorch/blob/c09929659ce8ba2f1b7b2f6e50084ccbf854d44b/torch/csrc/jit/ir/ir.cpp#L1850
# Definition of `fromNumberType` used by `NumToTensor` to
# calculate dtype: https://github.com/pytorch/pytorch/blob/c09929659ce8ba2f1b7b2f6e50084ccbf854d44b/aten/src/ATen/core/jit_type.h#L1679
#
# Note that doing something like
# `torch.tensor(scalar).to(type(scalar)).dtype` does not work because
# `torch.tensor` needs to know the type of the input at compile time
# and there is no `jit.script` support for Python's `type`.
#
# TODO: A better way to handle this would be to add support for
# `isinstance` in torch-mlir, which requires adding a new torch dialect
# op.
return torch.ops.prim.NumToTensor(scalar).dtype
# When we import into torch-mlir, only the calls to
# `__torch_mlir_internal_promote_dtypes` are used to generate the
# `torch.promote_dtypes` ops. Therefore, to avoid generating extra
# MLIR code in the library, all calls made inside
# `__torch_mlir_internal_promote_dtypes` are `jit.ignore`d.
@torch.jit.ignore
def _get_scalar_with_dtype(dtype: torch.dtype) -> Union[int, float]:
if dtype == torch.int64:
return 0
elif dtype == torch.float64:
return 0.0
else:
raise ValueError(f"Unhandled dtype: {dtype}")
@torch.jit.ignore
def _promote_scalar_tensor(scalar_dtype: torch.dtype, tensor_rank: int,
tensor_dtype: torch.dtype) -> torch.dtype:
scalar = _get_scalar_with_dtype(scalar_dtype)
tensor = torch.rand([1] * tensor_rank).to(tensor_dtype)
return torch.result_type(scalar, tensor)
@torch.jit.ignore
def _promote_tensor_tensor(lhs_rank: int, lhs_dtype: torch.dtype,
rhs_rank: int, rhs_dtype: torch.dtype) -> torch.dtype:
lhs_tensor = torch.rand([1] * lhs_rank).to(lhs_dtype)
rhs_tensor = torch.rand([1] * rhs_rank).to(rhs_dtype)
return torch.result_type(lhs_tensor, rhs_tensor)
@torch.jit.ignore
def _promote_scalar_scalar(lhs_dtype: torch.dtype,
rhs_dtype: torch.dtype) -> torch.dtype:
# When `torch.result_type` is used on two scalars, the result
# dtype is the dtype one would expect for an op with signature
# (Scalar, Scalar) -> (Tensor). However, once a module gets
# jit.scripted, all math ops that work on scalars becomes
# (Scalar, Scalar) -> (Scalar) ops. So to get the right result
# dtype, we use the tensor-tensor promotion rules.
return _promote_tensor_tensor(0, lhs_dtype, 0, rhs_dtype)
def promote_dtypes(ranks: List[Optional[int]],
dtypes: List[torch.dtype]) -> torch.dtype:
"""Apply PyTorch dtype promotion rules and return the result type.
"""
return __torch_mlir_internal_promote_dtypes(ranks, dtypes)
def __torch_mlir_internal_promote_dtypes(ranks: List[Optional[int]],
dtypes: List[torch.dtype]
) -> torch.dtype:
"""Apply PyTorch dtype promotion rules and return the result type.
This function serves two purposes:
1. It is handled in a special way during import into Torch-MLIR,
generating `torch.promote_dtypes` ops
2. Computes the actual promotion logic at the Python level in order
to be able to test dtype calculation functions against PyTorch
"""
lhs_optional_rank = ranks[0]
lhs_dtype = dtypes[0]
for rhs_optional_rank, rhs_dtype in zip(ranks, dtypes):
if lhs_optional_rank is None and rhs_optional_rank is None:
lhs_dtype = _promote_scalar_scalar(lhs_dtype, rhs_dtype)
elif lhs_optional_rank is None and rhs_optional_rank is not None:
lhs_dtype = _promote_scalar_tensor(
lhs_dtype, rhs_optional_rank, rhs_dtype)
lhs_optional_rank = rhs_optional_rank
elif lhs_optional_rank is not None and rhs_optional_rank is None:
lhs_dtype = _promote_scalar_tensor(
rhs_dtype, lhs_optional_rank, lhs_dtype)
elif lhs_optional_rank is not None and rhs_optional_rank is not None:
lhs_dtype = _promote_tensor_tensor(
lhs_optional_rank, lhs_dtype, rhs_optional_rank, rhs_dtype)
lhs_optional_rank = max(lhs_optional_rank, rhs_optional_rank)
return lhs_dtype
def not_present_in_registry(f):
"""Decorator for abstract interpretation functions not present in the registry.
This can happen for "valsem" ops that we have in Torch-MLIR, such as
torch.valsem.aten.fill.Scalar, which are consistent with PyTorch conventions
(e.g. being the value-semantic correspondent of torch.aten.fill_.Scalar),
but that for whatever reason are not present in PyTorch. Such ops are useful
to keep certain passes within Torch-MLIR as consistent as possible.
For such ops, in the shape library generator, we treat them as if they
were registered torch ops (so we don't put "valsem" on them), to keep the
generator consistent.
To check if this decorator has been applied, use
`hasattr(f, "_not_present_in_registry")`.
"""
f._not_present_in_registry = None
return f
def _verify_signature_matches_registry(f, registry: Registry):
source = inspect.getsource(f)
signature = None
for line in source.splitlines():
if line.startswith("def "):
signature = line
break
assert signature is not None, f"Could not find signature for {f.__name__}"
assert "" in signature, f"Malformed signature {signature}. Signature missing the character `〡`"
function_name, function_kind = f.__name__.split("")
atoms = function_name.split("")
if len(atoms) == 2:
atoms += [""]
operator = registry.get_by_triple(tuple(atoms))
if function_kind == "shape":
expected_signature = operator.get_shape_function_signature()
elif function_kind == "dtype":
expected_signature = operator.get_dtype_function_signature()
elif function_kind == "decomposition":
expected_signature = operator.get_decomposition_function_signature()
else:
raise ValueError(f"Invalid Op signature function kind: '{function_kind}'")
if signature != expected_signature:
raise ValueError(f"Signature mismatch for {f.__name__!r}: expected {expected_signature!r}, got {signature!r}")
def generate_library(globals_) -> str:
"""Convert all op functions in `globals()` into MLIR."""
mb = ModuleBuilder()
# We use the registry to ensure that the shape functions are consistent
# with the ops.
registry = Registry.load()
for k, v in globals_.items():
if "" not in k:
continue
if not hasattr(v, "_not_present_in_registry"):
_verify_signature_matches_registry(v, registry)
# Add it to the compilation unit.
torch.jit.script(v)
for function in torch.jit._state._python_cu.get_functions():
# Calls to the function `__torch_mlir_internal_promote_dtypes`
# will get converted to the torch-dialect op `torch.promote_dtypes`
# during import, so there is no need to import the actual
# function.
if function.name == "__torch_mlir_internal_promote_dtypes":
continue
mb.import_function(function)
# Clean up the IR a bit before writing it out.
pm = PassManager.parse("builtin.module(canonicalize)", context=mb.module.context)
pm.run(mb.module)
# Munge the IR a bit to make it more systematically accessible.
asm = mb.module.operation.get_asm()
# We'd like a unique function prefix to avoid collisions with user-
# defined symbols. Since all of our shape functions conveniently have
# a `` in them, we replace the torch namespace with our prefix. E.g.:
# __torch__.atenaddScalar -> __torch_mlir_shape_fn.atenaddScalar
asm = re.sub(r"__torch__\.([^.(]+)\\E3\\80\\87([^.(]+)\\E3\\80\\A1([^.(\"]+)",
r"__torch_mlir_\3_fn.\1\\E3\\80\\87\2",
asm)
# Put the `` back to a regular `.`.
asm = asm.replace("\\E3\\80\\87", ".")
return asm

View File

@ -5,7 +5,7 @@
"""Access to the Torch JIT operator registry.""" """Access to the Torch JIT operator registry."""
from typing import Dict, List, Tuple, Union from typing import Dict, List, Tuple, Union, Callable
import io import io
import itertools import itertools
@ -15,6 +15,48 @@ from .utils import TextEmitter
# Note that this utility exists only in the c-extension. # Note that this utility exists only in the c-extension.
from torch_mlir._mlir_libs._jit_ir_importer import get_registered_ops # pytype: disable=import-error from torch_mlir._mlir_libs._jit_ir_importer import get_registered_ops # pytype: disable=import-error
def _rename_python_keyword_parameter_name(parameter_name: str) -> str:
if parameter_name == "from":
parameter_name = "from_" # Avoid using a Python keyword.
return parameter_name
def _get_default_value(arg: "SIG_ATTR_TYPE") -> str:
default = ""
if "default_debug" in arg:
if "List" in arg["pytype"]:
# TorchScript doesn't allow lists as default parameters due
# to the weird Python semantics of mutable default
# arguments. So munge them into tuples, which work
# fine here. We only need these to simplify the invocation
# of the abstract interpretation functions as valid Python for
# testing against the real ops, and tuples work fine in all
# the places this kicks in (e.g. conv dilations -- we aren't
# mutating those lists).
default_debug = arg["default_debug"].replace(
'[', '(').replace(']', ')')
elif arg["pytype"] == "str":
default_debug = repr(arg["default_debug"]).replace("'", '"')
else:
default_debug = arg["default_debug"]
default = f" = {default_debug}"
return default
def _pytype_to_fn_pytype_common(pytype: str) -> str:
if "number" in pytype:
return pytype.replace("number", "Union[int, float]")
# `torch.device` is lowercase.
if pytype == "Device":
return "device"
if pytype == "Optional[Device]":
return "Optional[device]"
# Generators don't contribute to shapes, and are not scriptable currently.
# So just hack them to be passed as "Any".
if pytype == "Generator":
return "Any"
if pytype == "Optional[Generator]":
return "Any"
return pytype
def _pytype_to_shape_fn_pytype(pytype: str) -> str: def _pytype_to_shape_fn_pytype(pytype: str) -> str:
"""Convert a JitOperator pytype to the type relevant in shape functions. """Convert a JitOperator pytype to the type relevant in shape functions.
@ -28,31 +70,29 @@ def _pytype_to_shape_fn_pytype(pytype: str) -> str:
# logically real-valued), so it doesn't really matter much, and # logically real-valued), so it doesn't really matter much, and
# `float` helps make it clearer that it's not part of the shape # `float` helps make it clearer that it's not part of the shape
# function. # function.
if pytype == "number": # TODO: This is no longer needed. Scalars can be represented by
return "float" # Union[int, float].
if pytype == "Optional[number]": if "number" in pytype:
return "Optional[float]" return pytype.replace("number", "float")
# `torch.device` is lowercase. if "Tensor" in pytype:
if pytype == "Device": return pytype.replace("Tensor", "List[int]")
return "device" return _pytype_to_fn_pytype_common(pytype)
if pytype == "Optional[Device]":
return "Optional[device]" def _pytype_to_dtype_fn_pytype(pytype: str) -> str:
# Shape functions only care about the shape of tensors. """Convert a JitOperator pytype to the type relevant in dtype functions.
if pytype == "Tensor":
return "List[int]" In particular, this converts `Tensor` to `int`, along with a few
if pytype == "Optional[Tensor]": other special cases.
return "Optional[List[int]]" """
if pytype == "List[Tensor]": # Dtype functions only care about the rank and dtype of tensors.
return "List[List[int]]" if "Tensor" in pytype:
if pytype == "List[Optional[Tensor]]": return pytype.replace("Tensor", "int")
return "List[Optional[List[int]]]" return _pytype_to_fn_pytype_common(pytype)
# Generators don't contribute to shapes, and are not scriptable currently.
# So just hack them to be passed as "Any". def _pytype_to_decomposition_fn_pytype(pytype: str) -> str:
if pytype == "Generator": """Convert a JitOperator pytype to the type relevant in decomposition functions.
return "Any" """
if pytype == "Optional[Generator]": return _pytype_to_fn_pytype_common(pytype)
return "Any"
return pytype
class JitOperator: class JitOperator:
"""Information about a single registered `torch::jit::Operator`""" """Information about a single registered `torch::jit::Operator`"""
@ -142,6 +182,23 @@ class JitOperator:
cpp_class_name = cpp_class_name.lstrip("_") cpp_class_name = cpp_class_name.lstrip("_")
return op_name, cpp_class_name return op_name, cpp_class_name
def _get_function_signature(self, function_kind: str,
parameter_decl_builder: Callable["SIG_ATTR_TYPE", str],
ret_decl_builder: Callable["SIG_ATTR_TYPE", str]) -> str:
mlir_op_name, _ = self.get_mlir_names()
# Replace `.` with a valid Python identifier character.
# `` vaguely looks like `.`.
def_name = "".join(mlir_op_name.split("."))
def_name += f"{function_kind}"
parameter_decls = list(map(parameter_decl_builder, self.arguments))
ret_decls = list(map(ret_decl_builder, self.returns))
parameters = ", ".join(parameter_decls)
result = ", ".join(ret_decls)
if len(ret_decls) >= 2:
result = f"Tuple[{result}]"
return f"def {def_name}({parameters}) -> {result}:"
def get_shape_function_signature(self): def get_shape_function_signature(self):
"""Gets the Python function signature for this op's shape function. """Gets the Python function signature for this op's shape function.
@ -150,45 +207,66 @@ class JitOperator:
ops have extra default arguments and stuff that are tedious to write out ops have extra default arguments and stuff that are tedious to write out
right. right.
""" """
mlir_op_name, _ = self.get_mlir_names() def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
# Replace `.` with a valid Python identifier character.
# `` vaguely looks like `.`.
def_name = "".join(mlir_op_name.split("."))
parameter_decls = []
for arg in self.arguments:
pytype = _pytype_to_shape_fn_pytype(arg["pytype"]) pytype = _pytype_to_shape_fn_pytype(arg["pytype"])
default = "" default = _get_default_value(arg)
if "default_debug" in arg: parameter_name = _rename_python_keyword_parameter_name(arg["name"])
if "List" in arg["pytype"]: return f"{parameter_name}: {pytype}{default}"
# TorchScript doesn't allow lists as default parameters due
# to the weird Python semantics of mutable default
# arguments. So munge them into tuples, which work
# fine here. We only need these to simplify the invocation
# of the shape functions as valid Python for testing against
# the real ops, and tuples work fine in all the places this
# kicks in (e.g. conv dilations -- we aren't mutating those
# lists).
default_debug = arg["default_debug"].replace(
'[', '(').replace(']', ')')
elif arg["pytype"] == "str":
default_debug = repr(arg["default_debug"]).replace("'", '"')
else:
default_debug = arg["default_debug"]
default = f" = {default_debug}"
parameter_name = arg["name"]
if parameter_name == "from":
parameter_name = "from_" # Avoid using a Python keyword.
parameter_decls.append(f"{parameter_name}: {pytype}{default}")
ret_decls = []
for ret in self.returns:
pytype = _pytype_to_shape_fn_pytype(ret["pytype"])
ret_decls.append(f"{pytype}")
parameters = ", ".join(parameter_decls)
result = ", ".join(ret_decls)
if len(ret_decls) >= 2:
result = f"Tuple[{result}]"
return f"def {def_name}({parameters}) -> {result}:" def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
return _pytype_to_shape_fn_pytype(arg["pytype"])
return self._get_function_signature(
"shape", parameter_decl_builder, ret_decl_builder)
def get_dtype_function_signature(self):
"""Gets the Python function signature for this op's dtype function.
While this is technically debug-only output, it is useful to copy-paste
it from the debug dump into the library definitions, as many
ops have extra default arguments and stuff that are tedious to write out
right.
"""
def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
pytype = _pytype_to_dtype_fn_pytype(arg["pytype"])
default = _get_default_value(arg)
parameter_name = _rename_python_keyword_parameter_name(arg["name"])
if "Tensor" in arg["pytype"]:
return ", ".join([f"{parameter_name}_rank: {pytype}{default}",
f"{parameter_name}_dtype: {pytype}{default}"])
return f"{parameter_name}: {pytype}{default}"
def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
# The dtype function is expected to return dtypes for scalar
# results of type `number`. Here we handle this case because
# `_pytype_to_dtype_fn_pytype` will replace `number` with
# `Union[int, float]`.
if arg["pytype"] == "number":
return "int"
return _pytype_to_dtype_fn_pytype(arg["pytype"])
return self._get_function_signature(
"dtype", parameter_decl_builder, ret_decl_builder)
def get_decomposition_function_signature(self):
"""Gets the Python function signature for this op's decomposition function.
While this is technically debug-only output, it is useful to copy-paste
it from the debug dump into the shape library definitions, as many
ops have extra default arguments and stuff that are tedious to write out
right.
"""
def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
pytype = _pytype_to_decomposition_fn_pytype(arg["pytype"])
default = _get_default_value(arg)
parameter_name = _rename_python_keyword_parameter_name(arg["name"])
return f"{parameter_name}: {pytype}{default}"
def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
return _pytype_to_decomposition_fn_pytype(arg["pytype"])
return self._get_function_signature(
"decomposition", parameter_decl_builder, ret_decl_builder)
def __repr__(self): def __repr__(self):
f = io.StringIO() f = io.StringIO()
@ -212,6 +290,10 @@ class JitOperator:
p(f"is_mutable = {self.is_mutable}") p(f"is_mutable = {self.is_mutable}")
if any(ret["type"] == "Tensor" for ret in self.returns): if any(ret["type"] == "Tensor" for ret in self.returns):
p(f"shape_function_signature = {self.get_shape_function_signature()}") p(f"shape_function_signature = {self.get_shape_function_signature()}")
p(f"decomposition_function_signature = {self.get_decomposition_function_signature()}")
if any(ret["type"] in ["Tensor", "Scalar"] for ret in self.returns):
p(f"dtype_function_signature = {self.get_dtype_function_signature()}")
if not self.arguments: if not self.arguments:
p("arguments = []") p("arguments = []")
else: else:
@ -300,12 +382,13 @@ class Registry:
return self.by_triple[key] return self.by_triple[key]
# A List[Dict[str, _]] mapping attribute names to: # A Dict[str, _] mapping attribute names to:
# - str (e.g. {'name': 'dim'} ) # - str (e.g. {'name': 'dim'} )
# - int (e.g. {'N': 1} ) # - int (e.g. {'N': 1} )
# - Dict[str, List[str]] # - Dict[str, List[str]]
# (e.g. {'alias_info': {'before': ['alias::a'], 'after': ['alias::a']}} ) # (e.g. {'alias_info': {'before': ['alias::a'], 'after': ['alias::a']}} )
SIGLIST_TYPE = List[Dict[str, Union[str, int, Dict[str, List[str]]]]] SIG_ATTR_TYPE = Dict[str, Union[str, int, Dict[str, List[str]]]]
SIGLIST_TYPE = List[SIG_ATTR_TYPE]
# A Dict[str, _] describing a registered op. Each field is either # A Dict[str, _] describing a registered op. Each field is either
# - bool (e.g. {'is_mutable': False} ) # - bool (e.g. {'is_mutable': False} )
# - Tuple[str] (e.g. {'name': ('aten::size', 'int')} ) # - Tuple[str] (e.g. {'name': ('aten::size', 'int')} )

View File

@ -0,0 +1,315 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
from typing import Any, List, Iterable, Optional, Callable
import torch
from torch import Tensor
# ==============================================================================
# Shape, dtype, and decomposition function testing infrastructure.
# ==============================================================================
# We expect all functions to be adequately tested. For functions
# implemented with upstream helpers, additional testing is usually not needed.
# But for functions that are authored/maintained by the Torch-MLIR
# project, we expect adequate testing.
#
# To do this, we provide decorators
# - `@check_shape_function`
# - `@check_dtype_function`
# - `@check_decomposition_function`
# which can be used to specify a series of operator invocations (such as "call
# this operator with two arguments -- a first tensor of size [2, 3] and a second
# tensor of size [3, 4]"). These tests are then run as part of this script, and
# any mismatches from the real op's behavior will be reported.
#
# A typical use of the decorator might look like:
# ```
# @check_shape_function([
# Invocation(TensorOfShape(2, 3, 4)), # Basic case.
# Invocation(TensorOfShape(2, 3, 4), dim=0), # Test explicit `dim`.
# Invocation(TensorOfShape(2, 3, 4), dim=0, keepdim=True), # `keepdim`.
# Invocation(TensorOfShape(2, 3, 4), dim=-3), # Negative `dim`.
# Invocation(TensorOfShape(2, 3, 4), dim=2), # Maximum valid `dim`.
# ErrorInvocation(TensorOfShape(2, 3, 4), dim=-4), # `dim` out of bounds.
# ErrorInvocation(TensorOfShape(2, 3, 4), dim=3), # `dim` out of bounds.
# ])
# ```
# Each `Invocation` takes a list of args/kwargs which will be passed to both the
# shape function and the real op and the results compared.
# We expect both the successful and error cases to be tested.
#
# The typical iteration flow is to add invocations to the list and then re-run
# `build_tools/update_abstract_interp_lib.sh` to re-run the tests.
class TensorOfShape:
"""Symbolic placeholder for a tensor argument to an operation.
Shape functions take tensor arguments as `List[int]`, whereas the real ops
take them as `Tensor`, so we need a symbolic representation of a tensor
argument to an op in order to represent an invocation that can drive both
the shape function and the real op (see `Invocation`).
A plain list doesn't work, because plain lists are actually legal arguments
to a shape function (e.g. conv dilations), and we don't want them to receive
this special treatment.
This class also tracks a dtype of the tensor, since some ops require a
specific dtype.
"""
def __init__(self, *shape: int, dtype: torch.dtype = torch.float32):
self.shape = list(shape)
self.dtype = dtype
def __repr__(self):
args_str = ", ".join(repr(x) for x in self.shape)
if self.dtype is torch.float32:
return f"TensorOfShape({args_str})"
else:
return f"TensorOfShape({args_str}, dtype={self.dtype})"
def LongTensorOfShape(*args, **kwargs):
"""Helper for indicating a TensorOfShape with integer type."""
return TensorOfShape(*args, **kwargs, dtype=torch.long)
def NonZeroDTensorWithDtype(dtype):
"""Helper for indicating a non-zero dim tensor with custom type."""
return TensorOfShape(1, dtype=dtype)
def ZeroDTensorWithDtype(dtype):
"""Helper for indicating a zero dim tensor with custom type."""
return TensorOfShape(dtype=dtype)
def _recursively_transform_tensor_args(
o: Any,
tensor_transformer: Callable[[TensorOfShape], Any]) -> Any:
"""Replace `TensorOfShape` with the result of `tensor_transformer`"""
if o is None or isinstance(o, (float, int)):
return o
if isinstance(o, TensorOfShape):
return tensor_transformer(o)
if isinstance(o, list):
return [_recursively_transform_tensor_args(x, tensor_transformer) for x in o]
if isinstance(o, tuple):
return tuple(_recursively_transform_tensor_args(x, tensor_transformer) for x in o)
raise Exception(f"Unhandled type {type(o)}")
def _convert_to_dtype_function_args(arguments: Iterable[Any]) -> List[Any]:
"""Converts an Invocation argument to a dtype function argument.
TensorOfShape is replaced with two ints representing the rank
and dtype of the tensor, respectively.
"""
def contains_tensor(o: Any) -> bool:
if o is None or isinstance(o, (float, int)):
return False
if isinstance(o, TensorOfShape):
return True
if isinstance(o, (list, tuple)):
for elem in o:
if contains_tensor(elem):
return True
return False
raise Exception(f"Unhandled type {type(o)}")
result = []
for arg in arguments:
if contains_tensor(arg):
rank_arg = _recursively_transform_tensor_args(
arg, lambda x: len(x.shape))
dtype_arg = _recursively_transform_tensor_args(
arg, lambda x: x.dtype)
result += [rank_arg, dtype_arg]
else:
result.append(arg)
return result
class Invocation:
"""Representation of a single op invocation (i.e. list of args to the op).
This class is used to represent a single invocation of an op in a way that
we can use to both invoke the abstract interpretation function and invoke
the actual op, which have slightly different signatures.
Specifically, this class has special knowledge of `TensorOfShape` and
translates it appropriately to either a tensor (for the real op), a
`List[int]` for the shape function, and two `int`s representing
the tensor rank and dtype in the case of a dtype function.
This class also tracks whether the invocation is expected to raise an
exception for greater precision when interpreting errors raised during
testing.
"""
def __init__(self, *args: Any, **kwargs: Any):
self.args = list(args)
# We assume kwargs don't contain tensors, so they don't need any
# special handling.
self.kwargs = kwargs
def is_expected_to_raise_exception(self) -> bool:
"""Returns true if the invocation is expected to raise an exception.
The subclass ErrorInvocation overrides this to indicate an Invocation
that is expected to raise an exception.
"""
return False
def to_shape_function_args(self):
"""Gets positional arguments appropriate for a shape function."""
# Make a copy of the size list, since a shape function might
# modify it in-place. In the compiler, the lowering always
# produces a new list via a fresh invocation of `AtenSizeOp`,
# which allocates a new, unaliased list. So in-place mutations
# are ok since they make it a bit easier to write some shape
# functions.
tensor_transformer = lambda o: list(o.shape)
return _recursively_transform_tensor_args(
self.args, tensor_transformer)
def to_dtype_function_args(self):
"""Gets positional arguments appropriate for a dtype function."""
return _convert_to_dtype_function_args(self.args)
def to_real_op_args(self):
"""Gets positional arguments appropriate for the real op."""
tensor_transformer = lambda o: torch.ones(o.shape, dtype=o.dtype)
return _recursively_transform_tensor_args(self.args, tensor_transformer)
def __repr__(self) -> str:
args_str = ", ".join(repr(x) for x in self.args)
kwargs_str = ""
if self.kwargs:
kwargs_str = ", " + ", ".join(f"{k}={v}" for k, v in self.kwargs.items())
return f"Invocation({args_str}{kwargs_str})"
class ErrorInvocation(Invocation):
"""An Invocation that raises an exception.
Explicitly knowing that an invocation is expected to raise an exception
avoids certain failure modes of the test infrastructure where a bug
slips through when both the abstract interpretation function and the real
op raise exceptions due to independent bugs (that cancel each other out and
spurioiusly make the two appear to "agree" that an exception needs to be
raised).
"""
def is_expected_to_raise_exception(self) -> bool:
return True
def _normalize_multiple_results_to_list(t: Any):
"""Returns a flat list of results.
This normalizes the fact that Python represents multiple returns with a
tuple, but single returns as a single value. We just want a list with
N elements for N results.
"""
if isinstance(t, tuple):
return list(t)
# Shape functions return List[int] instead of tensors.
if isinstance(t, (Tensor, list, torch.dtype, int, float)):
return [t]
raise ValueError(f"Unexpected type {type(t)}")
def _report(f, invocation: Invocation, error_message: str):
fn_type = f.__name__.split("")[-1]
raise ValueError(f"For {fn_type} function {f.__name__!r} with invocation {invocation}: {error_message}")
def _get_fn_and_golden_results(f, invocation: List[Invocation]):
"""Run the invocation on the library function and torch op.
If no unexpected errors are detected, returns a tuple wth the first
element being the results from the library function and the second
element being the results from the torch op. The results will be `None`
if the library function and torch op expectedly result in errors.
"""
fn_name_without_fn_type, fn_type = f.__name__.split("")
fn_name_parts = fn_name_without_fn_type.split("")
ns, unqual = fn_name_parts[:2]
overload = "default" if len(fn_name_parts) != 3 else fn_name_parts[-1]
op = getattr(getattr(getattr(torch.ops, ns), unqual), overload)
fn_error, op_error, fn_results, golden_results = None, None, None, None
try:
fn_results = _normalize_multiple_results_to_list(f(
*(getattr(invocation, f"to_{fn_type}_function_args")()),
**invocation.kwargs))
except Exception as e:
fn_error = f"{e}"
try:
golden_results = _normalize_multiple_results_to_list(op(
*invocation.to_real_op_args(),
**invocation.kwargs))
except Exception as e:
op_error = f"{e}"
# Check for error behavior.
if invocation.is_expected_to_raise_exception():
if fn_error is None and op_error is None:
_report(f, invocation, f"Expected to raise an exception, but neither {fn_type} function n or op raised an exception")
if fn_error is None:
_report(f, invocation, f"Op raised error {op_error!r}, but shape function did not.")
if op_error is None:
_report(f, invocation, f"{fn_type} function raised error {fn_error!r}, but op did not.")
else:
if fn_error is not None and op_error is not None:
_report(f, invocation, f"Both {fn_type} function and op raised errors, but were not expected to. {fn_type} function raised error {fn_error!r} and op raised error {op_error!r}.")
if fn_error is not None:
_report(f, invocation, f"{fn_type} function raised error {fn_error!r} but op did not raise any error.")
if op_error is not None:
_report(f, invocation, f"Op raised error {op_error!r} but {fn_type} function did not raise any error.")
return fn_results, golden_results
def check_shape_function(invocations: List[Invocation]):
"""Decorator that automatically tests a shape function.
The shape function, which is expected to be named systematically with
`` instead of `.`, is tested against the corresponding op in
`torch.ops.*` function using the given invocations.
"""
def decorator(f):
for invocation in invocations:
result_shapes, golden_results = _get_fn_and_golden_results(f, invocation)
if invocation.is_expected_to_raise_exception():
continue
# Check for matching results.
if len(result_shapes) != len(golden_results):
_report(f, invocation, f"Expected {len(golden_results)} result shapes, got {len(result_shapes)}")
for result_shape, golden_result in zip(result_shapes, golden_results):
result_rank = len(result_shape)
golden_rank = len(golden_result.shape)
if result_rank != golden_rank:
_report(f, invocation, f"Expected result rank {golden_rank}, got {result_rank}")
for dimension_size, golden_dimension_size in zip(result_shape, golden_result.shape):
if dimension_size != golden_dimension_size:
_report(f, invocation, f"Expected result shape {golden_result.shape}, got {result_shape}")
return f
return decorator
def check_dtype_function(invocations: List[Invocation]):
"""Decorator that automatically tests a dtype function.
The dtype function, which is expected to be named systematically with
`` instead of `.`, is tested against the corresponding op in
`torch.ops.*` function using the given invocations.
"""
def decorator(f):
for invocation in invocations:
result_dtypes, golden_results = _get_fn_and_golden_results(f, invocation)
if invocation.is_expected_to_raise_exception():
continue
if len(result_dtypes) != len(golden_results):
_report(f, invocation, f"Expected {len(golden_results)} result dtypes, got {len(result_dtypes)}")
for result_dtype, golden_result in zip(result_dtypes, golden_results):
if isinstance(golden_result, torch.Tensor):
golden_dtype = golden_result.dtype
elif isinstance(golden_result, (int, float)):
# Turn Python type to PyTorch dtype
golden_dtype = torch.tensor([]).to(type(golden_result)).dtype
else:
raise ValueError(f"Unhandled return type {type(golden_result)}")
if result_dtype != golden_dtype:
_report(f, invocation, f"Expected result dtype {golden_dtype}, got {result_dtype}")
return f
return decorator

View File

@ -327,13 +327,26 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) { auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) {
return getMlirTypeFromTorchType(loc, v->type(), importOptions); return getMlirTypeFromTorchType(loc, v->type(), importOptions);
}); });
MlirOperation operation = createMlirOperationAtEnd( std::string functionName = node->input(0)->node()->s(c10::attr::name);
appendToBlock, "func.call_indirect", loc, std::vector<MlirType> resultTypes =
getMlirTypesFromValues(loc, node->outputs(), importOptions), getMlirTypesFromValues(loc, node->outputs(), importOptions);
lookupMappedValue(node->input(0)), std::vector<MlirValue> adjustedFuncArgs = adjustStaticInformationForValues(
adjustStaticInformationForValues( appendToBlock, loc, lookupMappedValues(node->inputs().slice(1)),
appendToBlock, loc, lookupMappedValues(node->inputs().slice(1)), expectedTypes, /*userAllowsRefinement=*/false);
expectedTypes, /*userAllowsRefinement=*/false)); MlirOperation operation;
// `__torch_mlir_internal_promote_dtypes` is a special python function that
// users can use in dtype refinement function definitions to get the
// promoted result dtype for a PyTorch computation. Here we turn the call to
// this function to the torch dialect equivalent op `torch.promote_dtypes`.
if (functionName == "__torch_mlir_internal_promote_dtypes") {
operation =
createMlirOperationAtEnd(appendToBlock, "torch.promote_dtypes", loc,
resultTypes, adjustedFuncArgs);
} else {
operation = createMlirOperationAtEnd(
appendToBlock, "func.call_indirect", loc, resultTypes,
lookupMappedValue(node->input(0)), adjustedFuncArgs);
}
mapResults(node, operation); mapResults(node, operation);
return; return;
} }

View File

@ -1,11 +1,11 @@
// RUN: torch-mlir-opt -torch-drop-shape-calculations -split-input-file %s | FileCheck %s // RUN: torch-mlir-opt -torch-drop-abstract-interp-calculations -split-input-file %s | FileCheck %s
// CHECK-LABEL: func.func @basic( // CHECK-LABEL: func.func @basic$shape_calculate(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,?],unk>) -> !torch.vtensor { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,?],unk>) -> !torch.vtensor {
// CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG]] : !torch.vtensor<[2,?],unk> -> !torch.vtensor<[2,?],unk> // CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG]] : !torch.vtensor<[2,?],unk> -> !torch.vtensor<[2,?],unk>
// CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[TANH]] : !torch.vtensor<[2,?],unk> to !torch.vtensor // CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[TANH]] : !torch.vtensor<[2,?],unk> to !torch.vtensor
// CHECK: return %[[ERASED]] : !torch.vtensor // CHECK: return %[[ERASED]] : !torch.vtensor
func.func @basic(%arg0: !torch.vtensor<[2,?],unk>) -> !torch.vtensor { func.func @basic$shape_calculate(%arg0: !torch.vtensor<[2,?],unk>) -> !torch.vtensor {
%int2 = torch.constant.int 2 %int2 = torch.constant.int 2
%int1 = torch.constant.int 1 %int1 = torch.constant.int 1
%0 = torch.shape.calculate { %0 = torch.shape.calculate {
@ -22,6 +22,26 @@ func.func @basic(%arg0: !torch.vtensor<[2,?],unk>) -> !torch.vtensor {
// ----- // -----
// CHECK-LABEL: func.func @basic$dtype_calculate(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
// CHECK: %[[INT_6:.*]] = torch.constant.int 6
// CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TANH]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func.func @basic$dtype_calculate(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
%int6 = torch.constant.int 6
%0 = torch.dtype.calculate {
%2 = torch.aten.tanh %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
torch.dtype.calculate.yield %2 : !torch.vtensor<*,f32>
} dtypes {
torch.dtype.calculate.yield.dtypes %int6 : !torch.int
} : !torch.vtensor<*,f32>
%1 = torch.tensor_static_info_cast %0 : !torch.vtensor<*,f32> to !torch.vtensor
return %1 : !torch.vtensor
}
// -----
// CHECK-LABEL: func.func @shape_calc_in_loop( // CHECK-LABEL: func.func @shape_calc_in_loop(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,?],unk>) -> !torch.vtensor<[2,?],unk> { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,?],unk>) -> !torch.vtensor<[2,?],unk> {
func.func @shape_calc_in_loop(%arg: !torch.vtensor<[2,?],unk>) -> !torch.vtensor<[2,?],unk> { func.func @shape_calc_in_loop(%arg: !torch.vtensor<[2,?],unk>) -> !torch.vtensor<[2,?],unk> {

View File

@ -142,6 +142,22 @@ func.func @shape_calculations(%arg0: !torch.vtensor) -> !torch.vtensor {
return %0 : !torch.vtensor return %0 : !torch.vtensor
} }
func.func @dtype_calculations(%arg0: !torch.vtensor) -> !torch.vtensor {
%0 = torch.dtype.calculate {
%1 = torch.aten.tanh %arg0 : !torch.vtensor -> !torch.vtensor
torch.dtype.calculate.yield %1 : !torch.vtensor
} dtypes {
%2 = torch.prim.dtype %arg0 : !torch.vtensor -> !torch.int
torch.dtype.calculate.yield.dtypes %2 : !torch.int
} : !torch.vtensor
return %0 : !torch.vtensor
}
func.func @promote_dtypes(%ranks: !torch.list<optional<int>>, %dtypes: !torch.list<int>) -> !torch.int {
%0 = torch.promote_dtypes %ranks, %dtypes : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int
return %0 : !torch.int
}
func.func @number_type_subtypes(%arg0: !torch.tensor, %arg1: !torch.list<int>, %arg2: !torch.union<float, int>) { func.func @number_type_subtypes(%arg0: !torch.tensor, %arg1: !torch.list<int>, %arg2: !torch.union<float, int>) {
%0 = torch.aten.constant_pad_nd %arg0, %arg1, %arg2 : !torch.tensor, !torch.list<int>, !torch.union<float, int> -> !torch.tensor %0 = torch.aten.constant_pad_nd %arg0, %arg1, %arg2 : !torch.tensor, !torch.list<int>, !torch.union<float, int> -> !torch.tensor
return return

View File

@ -1,109 +0,0 @@
// RUN: torch-mlir-opt -torch-refine-types -split-input-file %s | FileCheck %s
// -----
// CHECK-LABEL: func @tensor_tensor$same_category_different_width(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1],f64>,
// CHECK-SAME: %[[VAL_2:.*]]: !torch.float) {
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
// CHECK-SAME: !torch.vtensor<[1],f32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],f64>
// CHECK: return
func.func @tensor_tensor$same_category_different_width(%t0: !torch.vtensor<[1],f32>,
%t1: !torch.vtensor<[1],f64>,
%alpha: !torch.float) {
%1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],f32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],unk>
return
}
// -----
// CHECK-LABEL: func @tensor_tensor$different_category(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1],si32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1],f64>,
// CHECK-SAME: %[[VAL_2:.*]]: !torch.float) {
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
// CHECK-SAME: !torch.vtensor<[1],si32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],f64>
// CHECK: return
func.func @tensor_tensor$different_category(%t0: !torch.vtensor<[1],si32>,
%t1: !torch.vtensor<[1],f64>,
%alpha: !torch.float) {
%1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],si32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],unk>
return
}
// -----
// CHECK-LABEL: func @tensor_tensor$same_category_zero_rank_wider(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[],f64>,
// CHECK-SAME: %[[VAL_2:.*]]: !torch.int) {
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
// CHECK-SAME: !torch.vtensor<[1],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[1],f32>
// CHECK: return
func.func @tensor_tensor$same_category_zero_rank_wider(
%t0: !torch.vtensor<[1],f32>,
%t1: !torch.vtensor<[],f64>,
%alpha: !torch.int) {
%1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[1],unk>
return
}
// -----
// CHECK-LABEL: func @tensor_tensor$zero_rank_higher_category(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1],si64>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[],f32>,
// CHECK-SAME: %[[VAL_2:.*]]: !torch.int) {
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
// CHECK-SAME: !torch.vtensor<[1],si64>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[1],f32>
// CHECK: return
func.func @tensor_tensor$zero_rank_higher_category(%t0: !torch.vtensor<[1],si64>,
%t1: !torch.vtensor<[],f32>,
%alpha: !torch.int) {
%1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],si64>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[1],unk>
return
}
// -----
// CHECK-LABEL: func @tensor_tensor$alpha_wider_no_contribution(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1],f32>, %[[VAL_1:.*]]: !torch.vtensor<[1],f32>,
// CHECK-SAME: %[[VAL_2:.*]]: !torch.float) {
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
// CHECK-SAME: !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[1],f32>
// CHECK: return
func.func @tensor_tensor$alpha_wider_no_contribution(%t0: !torch.vtensor<[1],f32>,
%t1: !torch.vtensor<[1],f32>,
%alpha: !torch.float) {
%1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[1],unk>
return
}
// -----
// CHECK-LABEL: func @tensor_scalar$scalar_higher_category(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1],si64>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.float,
// CHECK-SAME: %[[VAL_2:.*]]: !torch.int) {
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Scalar %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
// CHECK-SAME: !torch.vtensor<[1],si64>, !torch.float, !torch.int -> !torch.vtensor<[1],f32>
// CHECK: return
func.func @tensor_scalar$scalar_higher_category(%t0: !torch.vtensor<[1],si64>, %scalar: !torch.float, %alpha: !torch.int) {
%1 = torch.aten.add.Scalar %t0, %scalar, %alpha: !torch.vtensor<[1], si64>, !torch.float, !torch.int -> !torch.vtensor<[1],unk>
return
}
// -----
// CHECK-LABEL: func @tensor_scalar$scalar_same_category_wider(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1],si32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.int,
// CHECK-SAME: %[[VAL_2:.*]]: !torch.int) {
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Scalar %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
// CHECK-SAME: !torch.vtensor<[1],si32>, !torch.int, !torch.int -> !torch.vtensor<[1],si32>
// CHECK: return
func.func @tensor_scalar$scalar_same_category_wider(%t0: !torch.vtensor<[1],si32>, %scalar: !torch.int, %alpha: !torch.int) {
%1 = torch.aten.add.Scalar %t0, %scalar, %alpha: !torch.vtensor<[1], si32>, !torch.int, !torch.int -> !torch.vtensor<[1],unk>
return
}

View File

@ -121,14 +121,14 @@ func.func @prim.loop$region_arg_to_internal(%none: !torch.none) -> !torch.option
// ----- // -----
// CHECK-LABEL: func.func @f // CHECK-LABEL: func.func @f
// CHECK: %[[ATEN:.*]] = torch.aten.tanh %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32> // CHECK: %[[ATEN:.*]] = torch.aten.cos %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ATEN]] : !torch.vtensor<*,f32> to !torch.vtensor // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ATEN]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor // CHECK: return %[[CAST]] : !torch.vtensor
func.func @f(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor { func.func @f(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
%cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor %cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor
cf.br ^bb1(%cast: !torch.vtensor) cf.br ^bb1(%cast: !torch.vtensor)
^bb1(%arg1: !torch.vtensor): ^bb1(%arg1: !torch.vtensor):
%1 = torch.aten.tanh %arg1 : !torch.vtensor -> !torch.vtensor %1 = torch.aten.cos %arg1 : !torch.vtensor -> !torch.vtensor
return %1 : !torch.vtensor return %1 : !torch.vtensor
} }
@ -136,11 +136,11 @@ func.func @f(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
// CHECK-LABEL: func.func @f // CHECK-LABEL: func.func @f
// CHECK: func.func private @callee // CHECK: func.func private @callee
// CHECK-NEXT: torch.aten.tanh %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32> // CHECK-NEXT: torch.aten.cos %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32>
func.func @f() { func.func @f() {
builtin.module { builtin.module {
func.func private @callee(%arg0: !torch.vtensor) { func.func private @callee(%arg0: !torch.vtensor) {
%1 = torch.aten.tanh %arg0 : !torch.vtensor -> !torch.vtensor %1 = torch.aten.cos %arg0 : !torch.vtensor -> !torch.vtensor
return return
} }
func.func @caller(%arg0: !torch.vtensor<*,f32>) { func.func @caller(%arg0: !torch.vtensor<*,f32>) {

View File

@ -9,36 +9,36 @@
// ----- // -----
// CHECK-LABEL: func.func @basic( // CHECK-LABEL: func.func @basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
// CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32> // CHECK: %[[COS:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[TANH]] : !torch.vtensor<*,f32> to !torch.vtensor // CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[COS]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: return %[[RESULT]] : !torch.vtensor // CHECK: return %[[RESULT]] : !torch.vtensor
func.func @basic(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor { func.func @basic(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
%1 = torch.aten.tanh %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor %1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
return %1 : !torch.vtensor return %1 : !torch.vtensor
} }
// ----- // -----
// CHECK-LABEL: func.func @keep_existing_shape_information( // CHECK-LABEL: func.func @keep_existing_shape_information(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> {
// CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<[2],f32> // CHECK: %[[COS:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<[2],f32>
// CHECK: return %[[TANH]] : !torch.vtensor<[2],f32> // CHECK: return %[[COS]] : !torch.vtensor<[2],f32>
func.func @keep_existing_shape_information(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> { func.func @keep_existing_shape_information(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> {
%1 = torch.aten.tanh %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor<[2], f32> %1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor<[2], f32>
return %1 : !torch.vtensor<[2],f32> return %1 : !torch.vtensor<[2],f32>
} }
// ----- // -----
// CHECK-LABEL: func.func @propagate_through_multiple_ops( // CHECK-LABEL: func.func @propagate_through_multiple_ops(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
// CHECK: %[[TANH0:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32> // CHECK: %[[COS0:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
// CHECK: %[[TANH1:.*]] = torch.aten.tanh %[[TANH0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32> // CHECK: %[[COS1:.*]] = torch.aten.cos %[[COS0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
// CHECK: %[[TANH2:.*]] = torch.aten.tanh %[[TANH1]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32> // CHECK: %[[COS2:.*]] = torch.aten.cos %[[COS1]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
// CHECK: %[[TANH3:.*]] = torch.tensor_static_info_cast %[[TANH2]] : !torch.vtensor<*,f32> to !torch.vtensor // CHECK: %[[COS3:.*]] = torch.tensor_static_info_cast %[[COS2]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: return %[[TANH3]] : !torch.vtensor // CHECK: return %[[COS3]] : !torch.vtensor
func.func @propagate_through_multiple_ops(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor { func.func @propagate_through_multiple_ops(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
%1 = torch.aten.tanh %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor %1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
%2 = torch.aten.tanh %1 : !torch.vtensor -> !torch.vtensor %2 = torch.aten.cos %1 : !torch.vtensor -> !torch.vtensor
%3 = torch.aten.tanh %2 : !torch.vtensor -> !torch.vtensor %3 = torch.aten.cos %2 : !torch.vtensor -> !torch.vtensor
return %3 : !torch.vtensor return %3 : !torch.vtensor
} }
@ -47,99 +47,18 @@ func.func @propagate_through_multiple_ops(%arg0: !torch.vtensor<*,f32>) -> !torc
// refinement. // refinement.
// CHECK-LABEL: func.func @mixed_allowing_not_allowing_type_refinement( // CHECK-LABEL: func.func @mixed_allowing_not_allowing_type_refinement(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> (!torch.vtensor, !torch.vtensor) { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> (!torch.vtensor, !torch.vtensor) {
// CHECK: %[[TANH0:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32> // CHECK: %[[COS0:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
// CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[TANH0]] : !torch.vtensor<*,f32> to !torch.vtensor // CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[COS0]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: %[[TANH1:.*]] = torch.aten.tanh %[[TANH0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32> // CHECK: %[[COS1:.*]] = torch.aten.cos %[[COS0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
// CHECK: return %[[ERASED]], %[[ERASED]] : !torch.vtensor, !torch.vtensor // CHECK: return %[[ERASED]], %[[ERASED]] : !torch.vtensor, !torch.vtensor
func.func @mixed_allowing_not_allowing_type_refinement(%arg0: !torch.vtensor<*,f32>) -> (!torch.vtensor, !torch.vtensor) { func.func @mixed_allowing_not_allowing_type_refinement(%arg0: !torch.vtensor<*,f32>) -> (!torch.vtensor, !torch.vtensor) {
%1 = torch.aten.tanh %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor %1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
%3 = torch.aten.tanh %1 : !torch.vtensor -> !torch.vtensor %3 = torch.aten.cos %1 : !torch.vtensor -> !torch.vtensor
return %1, %1 : !torch.vtensor, !torch.vtensor return %1, %1 : !torch.vtensor, !torch.vtensor
} }
// ----- // -----
// CHECK-LABEL: func.func @type_promotion$same_category_different_width(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],si32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],si64>) -> !torch.vtensor<[?],unk> {
// CHECK: %[[ALPHA:.*]] = torch.constant.int 3
// CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[ARG0]], %[[ARG1]], %[[ALPHA]] : !torch.vtensor<[?],si32>, !torch.vtensor<[?],si64>, !torch.int -> !torch.vtensor<[?],si64>
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[ADD]] : !torch.vtensor<[?],si64> to !torch.vtensor<[?],unk>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?],unk>
func.func @type_promotion$same_category_different_width(%arg0: !torch.vtensor<[?],si32>, %arg1: !torch.vtensor<[?],si64>) -> !torch.vtensor<[?],unk> {
%int3 = torch.constant.int 3
%0 = torch.aten.add.Tensor %arg0, %arg1, %int3 : !torch.vtensor<[?],si32>, !torch.vtensor<[?],si64>, !torch.int -> !torch.vtensor<[?],unk>
return %0 : !torch.vtensor<[?],unk>
}
// -----
// CHECK-LABEL: func.func @type_promotion$different_category(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],si64>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],unk> {
// CHECK: %[[ALPHA:.*]] = torch.constant.int 3
// CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[ARG0]], %[[ARG1]], %[[ALPHA]] : !torch.vtensor<[?],si64>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32>
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[ADD]] : !torch.vtensor<[?],f32> to !torch.vtensor<[?],unk>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?],unk>
func.func @type_promotion$different_category(%arg0: !torch.vtensor<[?],si64>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],unk> {
%int3 = torch.constant.int 3
%0 = torch.aten.add.Tensor %arg0, %arg1, %int3 : !torch.vtensor<[?],si64>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],unk>
return %0 : !torch.vtensor<[?],unk>
}
// -----
// CHECK-LABEL: func.func @type_promotion$same_category_zero_rank_wider(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],f64>) -> !torch.vtensor<[?],unk> {
// CHECK: %[[ALPHA:.*]] = torch.constant.float 2.300000e+00
// CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[ARG0]], %[[ARG1]], %[[ALPHA]] : !torch.vtensor<[?],f32>, !torch.vtensor<[],f64>, !torch.float -> !torch.vtensor<[?],f32>
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[ADD]] : !torch.vtensor<[?],f32> to !torch.vtensor<[?],unk>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?],unk>
func.func @type_promotion$same_category_zero_rank_wider(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[],f64>) -> !torch.vtensor<[?],unk> {
%float2.300000e00 = torch.constant.float 2.300000e+00
%0 = torch.aten.add.Tensor %arg0, %arg1, %float2.300000e00 : !torch.vtensor<[?],f32>, !torch.vtensor<[],f64>, !torch.float -> !torch.vtensor<[?],unk>
return %0 : !torch.vtensor<[?],unk>
}
// -----
// CHECK-LABEL: func.func @type_promotion$zero_rank_higher_category(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],si64>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],unk> {
// CHECK: %[[ALPHA:.*]] = torch.constant.int 2
// CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[ARG0]], %[[ARG1]], %[[ALPHA]] : !torch.vtensor<[?],si64>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[?],f32>
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[ADD]] : !torch.vtensor<[?],f32> to !torch.vtensor<[?],unk>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?],unk>
func.func @type_promotion$zero_rank_higher_category(%arg0: !torch.vtensor<[?],si64>, %arg1: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],unk> {
%int2 = torch.constant.int 2
%0 = torch.aten.add.Tensor %arg0, %arg1, %int2 : !torch.vtensor<[?],si64>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[?],unk>
return %0 : !torch.vtensor<[?],unk>
}
// -----
// CHECK-LABEL: func.func @type_promotion$alpha_wider(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],unk> {
// CHECK: %[[ALPHA:.*]] = torch.constant.float 2.300000e+00
// CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[ARG0]], %[[ARG1]], %[[ALPHA]] : !torch.vtensor<[?],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[?],f32>
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[ADD]] : !torch.vtensor<[?],f32> to !torch.vtensor<[?],unk>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?],unk>
func.func @type_promotion$alpha_wider(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],unk> {
%float2.300000e00 = torch.constant.float 2.300000e+00
%0 = torch.aten.add.Tensor %arg0, %arg1, %float2.300000e00 : !torch.vtensor<[?],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[?],unk>
return %0 : !torch.vtensor<[?],unk>
}
// -----
// CHECK-LABEL: func.func @type_promotion_scalar_operation(
// CHECK-SAME: %[[FLOAT:.*]]: !torch.float,
// CHECK-SAME: %[[INT:.*]]: !torch.int) -> !torch.number {
// CHECK: %[[ADD:.*]] = torch.aten.add %[[FLOAT]], %[[INT]] : !torch.float, !torch.int -> !torch.float
// CHECK: %[[RET:.*]] = torch.derefine %[[ADD]] : !torch.float to !torch.number
// CHECK: return %[[RET]] : !torch.number
func.func @type_promotion_scalar_operation(%float: !torch.float, %int: !torch.int) -> !torch.number {
%ret = torch.aten.add %float, %int : !torch.float, !torch.int -> !torch.number
return %ret : !torch.number
}
// -----
// CHECK-LABEL: func.func @torch.overwrite.tensor.contents$dynamic_overwrites_static( // CHECK-LABEL: func.func @torch.overwrite.tensor.contents$dynamic_overwrites_static(
// CHECK-SAME: %[[STATIC:.*]]: !torch.vtensor<[2],f32>, // CHECK-SAME: %[[STATIC:.*]]: !torch.vtensor<[2],f32>,
// CHECK-SAME: %[[DYNAMIC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> { // CHECK-SAME: %[[DYNAMIC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {

View File

@ -0,0 +1,56 @@
// RUN: torch-mlir-opt -torch-reify-dtype-calculations -split-input-file %s | FileCheck %s
// CHECK: module {
// CHECK: func.func private @__torch_mlir_dtype_fn.aten.tanh(
// CHECK-LABEL: func.func @basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: %[[RESULT:.*]] = torch.dtype.calculate {
// CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG]] : !torch.vtensor -> !torch.vtensor
// CHECK: torch.dtype.calculate.yield %[[TANH]] : !torch.vtensor
// CHECK: } dtypes {
// CHECK: %[[SIZE:.*]] = torch.aten.size %[[ARG]] : !torch.vtensor -> !torch.list<int>
// CHECK: %[[RANK:.*]] = torch.aten.len.t %[[SIZE]] : !torch.list<int> -> !torch.int
// CHECK: %[[DTYPE:.*]] = torch.prim.dtype %[[ARG]] : !torch.vtensor -> !torch.int
// CHECK: %[[RESULT_DTYPE:.*]] = func.call @__torch_mlir_dtype_fn.aten.tanh(%[[RANK]], %[[DTYPE]]) : (!torch.int, !torch.int) -> !torch.int
// CHECK: torch.dtype.calculate.yield.dtypes %[[RESULT_DTYPE]] : !torch.int
// CHECK: } : !torch.vtensor
// CHECK: return %[[RESULT:.*]] : !torch.vtensor
func.func @basic(%arg0: !torch.vtensor) -> !torch.vtensor {
%0 = torch.aten.tanh %arg0 : !torch.vtensor -> !torch.vtensor
return %0 : !torch.vtensor
}
// -----
// CHECK-LABEL: func.func private @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(
// CHECK: {{.*}} = torch.promote_dtypes {{.*}} : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int
// CHECK-LABEL: func.func private @__torch_mlir_dtype_fn.aten.floor_divide(
// CHECK: {{.*}} = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes({{.*}}
// CHECK-LABEL: func.func @op_with_dtype_promotion(
// CHECK: {{.*}} = func.call @__torch_mlir_dtype_fn.aten.floor_divide({{.*}}
func.func @op_with_dtype_promotion(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
%0 = torch.aten.floor_divide %arg0, %arg1 : !torch.vtensor, !torch.vtensor -> !torch.vtensor
return %0 : !torch.vtensor
}
// -----
// CHECK-LABEL: func.func private @__torch_mlir_dtype_fn.aten.floor_divide(
// CHECK-LABEL: func.func @turn_tensors_into_rank_and_dtype_args(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> !torch.vtensor {
// CHECK: %[[SIZE0:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<int>
// CHECK: %[[RANK0:.*]] = torch.aten.len.t %[[SIZE0]] : !torch.list<int> -> !torch.int
// CHECK: %[[DTYPE0:.*]] = torch.prim.dtype %[[ARG0]] : !torch.vtensor -> !torch.int
// CHECK: %[[SIZE1:.*]] = torch.aten.size %[[ARG1]] : !torch.vtensor -> !torch.list<int>
// CHECK: %[[RANK1:.*]] = torch.aten.len.t %[[SIZE1]] : !torch.list<int> -> !torch.int
// CHECK: %[[DTYPE1:.*]] = torch.prim.dtype %[[ARG1]] : !torch.vtensor -> !torch.int
// CHECK: {{.*}} = func.call @__torch_mlir_dtype_fn.aten.floor_divide(%[[RANK0]], %[[DTYPE0]], %[[RANK1]], %[[DTYPE1]]) : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.int
func.func @turn_tensors_into_rank_and_dtype_args(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
%0 = torch.aten.floor_divide %arg0, %arg1 : !torch.vtensor, !torch.vtensor -> !torch.vtensor
return %0 : !torch.vtensor
}

View File

@ -0,0 +1,285 @@
// RUN: torch-mlir-opt -torch-simplify-dtype-calculations -split-input-file %s | FileCheck %s
// CHECK-LABEL: func.func @basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
// CHECK: %[[DTYPE_INT:.*]] = torch.constant.int 6
// CHECK: %[[RESULT:.*]] = torch.dtype.calculate {
// CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
// CHECK: torch.dtype.calculate.yield %[[TANH]] : !torch.vtensor<*,f32>
// CHECK: } dtypes {
// CHECK: torch.dtype.calculate.yield.dtypes %[[DTYPE_INT]] : !torch.int
// CHECK: } : !torch.vtensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RESULT]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func.func @basic(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
%0 = torch.dtype.calculate {
%1 = torch.aten.tanh %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
torch.dtype.calculate.yield %1 : !torch.vtensor
} dtypes {
%2 = torch.prim.dtype %arg0 : !torch.vtensor<*,f32> -> !torch.int
torch.dtype.calculate.yield.dtypes %2 : !torch.int
} : !torch.vtensor
return %0 : !torch.vtensor
}
// -----
// CHECK-LABEL: func.func @promote_dtypes$tensor_tensor_same_category_different_width(
// CHECK: {{.*}} = torch.aten.add.Tensor {{.*}} -> !torch.vtensor<[1],f64>
func.func @promote_dtypes$tensor_tensor_same_category_different_width(%arg0: !torch.vtensor<[1],f32>, %arg1: !torch.vtensor<[1],f64>, %arg2: !torch.float) {
%int1 = torch.constant.int 1
%0 = torch.dtype.calculate {
%1 = torch.aten.add.Tensor %arg0, %arg1, %arg2 : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],unk>
torch.dtype.calculate.yield %1 : !torch.vtensor<[1],unk>
} dtypes {
%ranks = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<optional<int>>
%f32_dtype = torch.prim.dtype %arg0 : !torch.vtensor<[1],f32> -> !torch.int
%f64_dtype = torch.prim.dtype %arg1 : !torch.vtensor<[1],f64> -> !torch.int
%dtypes = torch.prim.ListConstruct %f32_dtype, %f64_dtype : (!torch.int, !torch.int) -> !torch.list<int>
%3 = torch.promote_dtypes %ranks, %dtypes : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int
torch.dtype.calculate.yield.dtypes %3 : !torch.int
} : !torch.vtensor<[1],unk>
return
}
// -----
// CHECK-LABEL: func.func @promote_dtypes$tensor_tensor_different_category(
// CHECK: {{.*}} = torch.aten.add.Tensor {{.*}} -> !torch.vtensor<[1],f64>
func.func @promote_dtypes$tensor_tensor_different_category(%arg0: !torch.vtensor<[1],si32>, %arg1: !torch.vtensor<[1],f64>, %arg2: !torch.float) {
%int1 = torch.constant.int 1
%0 = torch.dtype.calculate {
%1 = torch.aten.add.Tensor %arg0, %arg1, %arg2 : !torch.vtensor<[1],si32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],unk>
torch.dtype.calculate.yield %1 : !torch.vtensor<[1],unk>
} dtypes {
%si32_dtype = torch.prim.dtype %arg0 : !torch.vtensor<[1],si32> -> !torch.int
%f64_dtype = torch.prim.dtype %arg1 : !torch.vtensor<[1],f64> -> !torch.int
%ranks = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<optional<int>>
%dtypes = torch.prim.ListConstruct %si32_dtype, %f64_dtype : (!torch.int, !torch.int) -> !torch.list<int>
%3 = torch.promote_dtypes %ranks, %dtypes : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int
torch.dtype.calculate.yield.dtypes %3 : !torch.int
} : !torch.vtensor<[1],unk>
return
}
// -----
// CHECK-LABEL: func.func @promote_dtypes$tensor_tensor_same_category_zero_rank_wider(
// CHECK: {{.*}} = torch.aten.add.Tensor {{.*}} -> !torch.vtensor<[1],f32>
func.func @promote_dtypes$tensor_tensor_same_category_zero_rank_wider(%arg0: !torch.vtensor<[1],f32>, %arg1: !torch.vtensor<[],f64>, %arg2: !torch.int) {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%0 = torch.dtype.calculate {
%1 = torch.aten.add.Tensor %arg0, %arg1, %arg2 : !torch.vtensor<[1],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[1],unk>
torch.dtype.calculate.yield %1 : !torch.vtensor<[1],unk>
} dtypes {
%f32_dtype = torch.prim.dtype %arg0 : !torch.vtensor<[1],f32> -> !torch.int
%f64_dtype = torch.prim.dtype %arg1 : !torch.vtensor<[],f64> -> !torch.int
%ranks = torch.prim.ListConstruct %int1, %int0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>
%dtypes = torch.prim.ListConstruct %f32_dtype, %f64_dtype : (!torch.int, !torch.int) -> !torch.list<int>
%3 = torch.promote_dtypes %ranks, %dtypes : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int
torch.dtype.calculate.yield.dtypes %3 : !torch.int
} : !torch.vtensor<[1],unk>
return
}
// -----
// CHECK-LABEL: func.func @promote_dtypes$tensor_tensor_zero_rank_higher_category(
// CHECK: {{.*}} = torch.aten.add.Tensor {{.*}} -> !torch.vtensor<[1],f32>
func.func @promote_dtypes$tensor_tensor_zero_rank_higher_category(%arg0: !torch.vtensor<[1],si64>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.int) {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%0 = torch.dtype.calculate {
%1 = torch.aten.add.Tensor %arg0, %arg1, %arg2 : !torch.vtensor<[1],si64>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[1],unk>
torch.dtype.calculate.yield %1 : !torch.vtensor<[1],unk>
} dtypes {
%si64_dtype = torch.prim.dtype %arg0 : !torch.vtensor<[1],si64> -> !torch.int
%f32_dtype = torch.prim.dtype %arg1 : !torch.vtensor<[],f32> -> !torch.int
%ranks = torch.prim.ListConstruct %int1, %int0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>
%dtypes = torch.prim.ListConstruct %si64_dtype, %f32_dtype : (!torch.int, !torch.int) -> !torch.list<int>
%3 = torch.promote_dtypes %ranks, %dtypes : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int
torch.dtype.calculate.yield.dtypes %3 : !torch.int
} : !torch.vtensor<[1],unk>
return
}
// -----
// CHECK-LABEL: func.func @promote_dtypes$tensor_tensor_alpha_wider_no_contribution(
// CHECK: {{.*}} = torch.aten.add.Tensor {{.*}} -> !torch.vtensor<[1],f32>
func.func @promote_dtypes$tensor_tensor_alpha_wider_no_contribution(%arg0: !torch.vtensor<[1],f32>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.float) {
%int1 = torch.constant.int 1
%none = torch.constant.none
%0 = torch.dtype.calculate {
%1 = torch.aten.add.Tensor %arg0, %arg1, %arg2 : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[1],unk>
torch.dtype.calculate.yield %1 : !torch.vtensor<[1],unk>
} dtypes {
%f32_dtype = torch.prim.dtype %arg0 : !torch.vtensor<[1],f32> -> !torch.int
%alpha_as_tensor = torch.prim.NumToTensor.Scalar %arg2 : !torch.float -> !torch.tensor<[],f64>
%f64_dtype = torch.prim.dtype %alpha_as_tensor : !torch.tensor<[],f64> -> !torch.int
%ranks = torch.prim.ListConstruct %int1, %int1, %none : (!torch.int, !torch.int, !torch.none) -> !torch.list<optional<int>>
%dtypes = torch.prim.ListConstruct %f32_dtype, %f32_dtype, %f64_dtype : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%3 = torch.promote_dtypes %ranks, %dtypes : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int
torch.dtype.calculate.yield.dtypes %3 : !torch.int
} : !torch.vtensor<[1],unk>
return
}
// -----
// CHECK-LABEL: func.func @promote_dtypes$tensor_scalar_scalar_higher_category(
// CHECK: {{.*}} = torch.aten.add.Scalar {{.*}} -> !torch.vtensor<[1],f32>
func.func @promote_dtypes$tensor_scalar_scalar_higher_category(%arg0: !torch.vtensor<[1],si64>, %arg1: !torch.float, %arg2: !torch.int) {
%none = torch.constant.none
%int1 = torch.constant.int 1
%0 = torch.dtype.calculate {
%1 = torch.aten.add.Scalar %arg0, %arg1, %arg2 : !torch.vtensor<[1],si64>, !torch.float, !torch.int -> !torch.vtensor<[1],unk>
torch.dtype.calculate.yield %1 : !torch.vtensor<[1],unk>
} dtypes {
%si64_dtype = torch.prim.dtype %arg0 : !torch.vtensor<[1],si64> -> !torch.int
%ranks = torch.prim.ListConstruct %int1, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>
%arg1_as_tensor = torch.prim.NumToTensor.Scalar %arg1 : !torch.float -> !torch.tensor<[],f64>
%f64_dtype = torch.prim.dtype %arg1_as_tensor : !torch.tensor<[],f64> -> !torch.int
%dtypes = torch.prim.ListConstruct %si64_dtype, %f64_dtype : (!torch.int, !torch.int) -> !torch.list<int>
%5 = torch.promote_dtypes %ranks, %dtypes : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int
torch.dtype.calculate.yield.dtypes %5 : !torch.int
} : !torch.vtensor<[1],unk>
return
}
// -----
// CHECK-LABEL: func.func @promote_dtypes$tensor_scalar_scalar_same_category_wider(
// CHECK: {{.*}} = torch.aten.add.Scalar {{.*}} -> !torch.vtensor<[1],si32>
func.func @promote_dtypes$tensor_scalar_scalar_same_category_wider(%arg0: !torch.vtensor<[1],si32>, %arg1: !torch.int, %arg2: !torch.int) {
%none = torch.constant.none
%int3 = torch.constant.int 3
%int1 = torch.constant.int 1
%0 = torch.dtype.calculate {
%1 = torch.aten.add.Scalar %arg0, %arg1, %arg2 : !torch.vtensor<[1],si32>, !torch.int, !torch.int -> !torch.vtensor<[1],unk>
torch.dtype.calculate.yield %1 : !torch.vtensor<[1],unk>
} dtypes {
%si32_dtype = torch.prim.dtype %arg0 : !torch.vtensor<[1],si32> -> !torch.int
%ranks = torch.prim.ListConstruct %int1, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>
%arg1_as_tensor = torch.prim.NumToTensor.Scalar %arg1 : !torch.int -> !torch.tensor<[],si64>
%si64_dtype = torch.prim.dtype %arg1_as_tensor : !torch.tensor<[],si64> -> !torch.int
%dtypes = torch.prim.ListConstruct %si32_dtype, %si64_dtype : (!torch.int, !torch.int) -> !torch.list<int>
%5 = torch.promote_dtypes %ranks, %dtypes : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int
torch.dtype.calculate.yield.dtypes %5 : !torch.int
} : !torch.vtensor<[1],unk>
return
}
// -----
// CHECK-LABEL: func.func @promote_dtypes$scalar_scalar_different_category(
// CHECK: {{.*}} = torch.aten.add {{.*}} -> !torch.float
func.func @promote_dtypes$scalar_scalar_different_category(%arg0: !torch.float, %arg1: !torch.int) -> !torch.number {
%none = torch.constant.none
%0 = torch.dtype.calculate {
%1 = torch.aten.add %arg0, %arg1 : !torch.float, !torch.int -> !torch.number
torch.dtype.calculate.yield %1 : !torch.number
} dtypes {
%ranks = torch.prim.ListConstruct %none, %none : (!torch.none, !torch.none) -> !torch.list<optional<int>>
%arg0_as_tensor = torch.prim.NumToTensor.Scalar %arg0 : !torch.float -> !torch.tensor<[],f64>
%f64_dtype = torch.prim.dtype %arg0_as_tensor : !torch.tensor<[],f64> -> !torch.int
%arg1_as_tensor = torch.prim.NumToTensor.Scalar %arg1 : !torch.int -> !torch.tensor<[],si64>
%si64_dtype = torch.prim.dtype %arg1_as_tensor : !torch.tensor<[],si64> -> !torch.int
%dtypes = torch.prim.ListConstruct %f64_dtype, %si64_dtype : (!torch.int, !torch.int) -> !torch.list<int>
%7 = torch.promote_dtypes %ranks, %dtypes : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int
torch.dtype.calculate.yield.dtypes %7 : !torch.int
} : !torch.number
return %0 : !torch.number
}
// -----
// CHECK-LABEL: func.func @promote_dtypes$scalar_scalar_same_category(
// CHECK: {{.*}} = torch.aten.add {{.*}} -> !torch.int
func.func @promote_dtypes$scalar_scalar_same_category(%arg0: !torch.int, %arg1: !torch.int) -> !torch.number {
%none = torch.constant.none
%0 = torch.dtype.calculate {
%1 = torch.aten.add %arg0, %arg1 : !torch.int, !torch.int -> !torch.number
torch.dtype.calculate.yield %1 : !torch.number
} dtypes {
%ranks = torch.prim.ListConstruct %none, %none : (!torch.none, !torch.none) -> !torch.list<optional<int>>
%arg0_as_tensor = torch.prim.NumToTensor.Scalar %arg0 : !torch.int -> !torch.tensor<[],si64>
%si64_dtype = torch.prim.dtype %arg0_as_tensor : !torch.tensor<[],si64> -> !torch.int
%dtypes = torch.prim.ListConstruct %si64_dtype, %si64_dtype : (!torch.int, !torch.int) -> !torch.list<int>
%7 = torch.promote_dtypes %ranks, %dtypes : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int
torch.dtype.calculate.yield.dtypes %7 : !torch.int
} : !torch.number
return %0 : !torch.number
}
// -----
// CHECK-LABEL: func.func @refine_dtype$invalid_dtype_for_scalar(
// CHECK: {{.*}} = torch.aten.add {{.*}} -> !torch.number
func.func @refine_dtype$invalid_dtype_for_scalar(%arg0: !torch.int, %arg1: !torch.int) -> !torch.number {
%none = torch.constant.none
%0 = torch.dtype.calculate {
%1 = torch.aten.add %arg0, %arg1 : !torch.int, !torch.int -> !torch.number
torch.dtype.calculate.yield %1 : !torch.number
} dtypes {
// dtype int for int32
%int3 = torch.constant.int 3
torch.dtype.calculate.yield.dtypes %int3 : !torch.int
} : !torch.number
return %0 : !torch.number
}
// -----
// CHECK-LABEL: func.func @refine_dtype$no_simplification
// CHECK: {{.*}} = torch.aten.add {{.*}} -> !torch.number
func.func @refine_dtype$no_simplification(%arg0: !torch.int, %arg1: !torch.int, %dtype: !torch.int) -> !torch.number {
%none = torch.constant.none
%0 = torch.dtype.calculate {
%1 = torch.aten.add %arg0, %arg1 : !torch.int, !torch.int -> !torch.number
torch.dtype.calculate.yield %1 : !torch.number
} dtypes {
torch.dtype.calculate.yield.dtypes %dtype : !torch.int
} : !torch.number
return %0 : !torch.number
}
// -----
// If result type is already refined (even if wrong, as is the case here),
// don't make any changes to result type.
// TODO: This case should result in an error
// CHECK-LABEL: func.func @refine_dtype$result_type_already_refined
// CHECK: {{.*}} = torch.aten.add {{.*}} -> !torch.int
func.func @refine_dtype$result_type_already_refined(%arg0: !torch.float, %arg1: !torch.float) -> !torch.int {
%none = torch.constant.none
%0 = torch.dtype.calculate {
%1 = torch.aten.add %arg0, %arg1 : !torch.float, !torch.float -> !torch.int
torch.dtype.calculate.yield %1 : !torch.int
} dtypes {
// dtype int for float64
%int7 = torch.constant.int 7
torch.dtype.calculate.yield.dtypes %int7 : !torch.int
} : !torch.int
return %0 : !torch.int
}
// -----
// CHECK-LABEL: func.func @refine_dtype$derefine_result_type(
// CHECK: {{.*}} = torch.aten.add {{.*}} -> !torch.int
// CHECK: %[[ERASED:.*]] = torch.derefine {{.*}} : !torch.int to !torch.number
// CHECK: return %[[ERASED]] : !torch.number
func.func @refine_dtype$derefine_result_type(%arg0: !torch.int, %arg1: !torch.int) -> !torch.number {
%none = torch.constant.none
%0 = torch.dtype.calculate {
%1 = torch.aten.add %arg0, %arg1 : !torch.int, !torch.int -> !torch.number
torch.dtype.calculate.yield %1 : !torch.number
} dtypes {
// dtype int for int64
%int4 = torch.constant.int 4
torch.dtype.calculate.yield.dtypes %int4 : !torch.int
} : !torch.number
return %0 : !torch.number
}