mirror of https://github.com/llvm/torch-mlir
[custom op] Generalize shape library logic to work with dtypes (#1594)
* [custom op] Generalize shape library logic to work with dtypes This commit generalizes the shape library logic, so that dtype rules for ops can also be expressed using the same mechanism. In other words, each op can now have a shape function and a dtype function specified in Python that is imported during lowering to calculate the shapes and dtypes throught a program. For more information about how to specify a dtype function, see the updated `docs/adding_a_shape_and_dtype_function.md`. For those not familiar with how the shape library works, the file `docs/calculations_lib.md` provides an overview.pull/1716/head
parent
2acf7da63c
commit
a710237437
|
@ -61,14 +61,14 @@ jobs:
|
||||||
echo "PT_RELEASE=${PT_RELEASE}" >> ${GITHUB_ENV}
|
echo "PT_RELEASE=${PT_RELEASE}" >> ${GITHUB_ENV}
|
||||||
echo "PT_HASH_CHANGED=${PT_HASH_CHANGED}" >> ${GITHUB_ENV}
|
echo "PT_HASH_CHANGED=${PT_HASH_CHANGED}" >> ${GITHUB_ENV}
|
||||||
|
|
||||||
- name: Build and test (in-tree), also update ODS and shape library
|
- name: Build and test (in-tree), also update ODS and abstract interpretation library
|
||||||
if: env.PT_HASH_CHANGED != '0'
|
if: env.PT_HASH_CHANGED != '0'
|
||||||
run: |
|
run: |
|
||||||
cd ${GITHUB_WORKSPACE}
|
cd ${GITHUB_WORKSPACE}
|
||||||
TM_PACKAGES="in-tree" TM_USE_PYTORCH_BINARY="OFF" \
|
TM_PACKAGES="in-tree" TM_USE_PYTORCH_BINARY="OFF" \
|
||||||
TORCH_MLIR_SRC_PYTORCH_BRANCH="${{ env.PT_HASH }}" \
|
TORCH_MLIR_SRC_PYTORCH_BRANCH="${{ env.PT_HASH }}" \
|
||||||
TORCH_MLIR_SRC_PYTORCH_RELEASE="${{ env.PT_RELEASE }}" \
|
TORCH_MLIR_SRC_PYTORCH_RELEASE="${{ env.PT_RELEASE }}" \
|
||||||
TM_UPDATE_ODS_AND_SHAPE_LIB="ON" \
|
TM_UPDATE_ODS_AND_ABSTRACT_INTERP_LIB="ON" \
|
||||||
./build_tools/python_deploy/build_linux_packages.sh
|
./build_tools/python_deploy/build_linux_packages.sh
|
||||||
|
|
||||||
- name: Push changes to main branch
|
- name: Push changes to main branch
|
||||||
|
@ -79,7 +79,7 @@ jobs:
|
||||||
git config user.name "Roll PyTorch Action"
|
git config user.name "Roll PyTorch Action"
|
||||||
git fetch --recurse-submodules=no
|
git fetch --recurse-submodules=no
|
||||||
git checkout main
|
git checkout main
|
||||||
git add pytorch-hash.txt pytorch-requirements.txt lib/Dialect/Torch/Transforms/ShapeLibrary.cpp include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
|
git add pytorch-hash.txt pytorch-requirements.txt lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
|
||||||
git diff --cached --exit-code || (git commit -m "update PyTorch version to ${{ env.PT_RELEASE }}" && git push --set-upstream origin main)
|
git diff --cached --exit-code || (git commit -m "update PyTorch version to ${{ env.PT_RELEASE }}" && git push --set-upstream origin main)
|
||||||
|
|
||||||
- name: Update PyTorch Build Cache (if running on main branch)
|
- name: Update PyTorch Build Cache (if running on main branch)
|
||||||
|
|
|
@ -53,8 +53,8 @@ TM_PACKAGES="${TM_PACKAGES:-torch-mlir}"
|
||||||
TM_USE_PYTORCH_BINARY="${TM_USE_PYTORCH_BINARY:-ON}"
|
TM_USE_PYTORCH_BINARY="${TM_USE_PYTORCH_BINARY:-ON}"
|
||||||
# Skip running tests if you want quick iteration
|
# Skip running tests if you want quick iteration
|
||||||
TM_SKIP_TESTS="${TM_SKIP_TESTS:-OFF}"
|
TM_SKIP_TESTS="${TM_SKIP_TESTS:-OFF}"
|
||||||
# Update ODS and shape library files
|
# Update ODS and abstract interpretation library files
|
||||||
TM_UPDATE_ODS_AND_SHAPE_LIB="${TM_UPDATE_ODS_AND_SHAPE_LIB:-OFF}"
|
TM_UPDATE_ODS_AND_ABSTRACT_INTERP_LIB="${TM_UPDATE_ODS_AND_ABSTRACT_INTERP_LIB:-OFF}"
|
||||||
|
|
||||||
PKG_VER_FILE="${repo_root}"/torch_mlir_package_version ; [ -f "$PKG_VER_FILE" ] && . "$PKG_VER_FILE"
|
PKG_VER_FILE="${repo_root}"/torch_mlir_package_version ; [ -f "$PKG_VER_FILE" ] && . "$PKG_VER_FILE"
|
||||||
TORCH_MLIR_PYTHON_PACKAGE_VERSION="${TORCH_MLIR_PYTHON_PACKAGE_VERSION:-0.0.1}"
|
TORCH_MLIR_PYTHON_PACKAGE_VERSION="${TORCH_MLIR_PYTHON_PACKAGE_VERSION:-0.0.1}"
|
||||||
|
@ -119,7 +119,7 @@ function run_on_host() {
|
||||||
-e "TM_PYTHON_VERSIONS=${TM_PYTHON_VERSIONS}" \
|
-e "TM_PYTHON_VERSIONS=${TM_PYTHON_VERSIONS}" \
|
||||||
-e "TM_PACKAGES=${package}" \
|
-e "TM_PACKAGES=${package}" \
|
||||||
-e "TM_SKIP_TESTS=${TM_SKIP_TESTS}" \
|
-e "TM_SKIP_TESTS=${TM_SKIP_TESTS}" \
|
||||||
-e "TM_UPDATE_ODS_AND_SHAPE_LIB=${TM_UPDATE_ODS_AND_SHAPE_LIB}" \
|
-e "TM_UPDATE_ODS_AND_ABSTRACT_INTERP_LIB=${TM_UPDATE_ODS_AND_ABSTRACT_INTERP_LIB}" \
|
||||||
-e "TM_USE_PYTORCH_BINARY=${TM_USE_PYTORCH_BINARY}" \
|
-e "TM_USE_PYTORCH_BINARY=${TM_USE_PYTORCH_BINARY}" \
|
||||||
-e "TORCH_MLIR_SRC_PYTORCH_REPO=${TORCH_MLIR_SRC_PYTORCH_REPO}" \
|
-e "TORCH_MLIR_SRC_PYTORCH_REPO=${TORCH_MLIR_SRC_PYTORCH_REPO}" \
|
||||||
-e "TORCH_MLIR_SRC_PYTORCH_BRANCH=${TORCH_MLIR_SRC_PYTORCH_BRANCH}" \
|
-e "TORCH_MLIR_SRC_PYTORCH_BRANCH=${TORCH_MLIR_SRC_PYTORCH_BRANCH}" \
|
||||||
|
@ -164,10 +164,10 @@ function run_in_docker() {
|
||||||
in-tree)
|
in-tree)
|
||||||
setup_venv "$python_version"
|
setup_venv "$python_version"
|
||||||
build_in_tree "$TM_USE_PYTORCH_BINARY" "$python_version"
|
build_in_tree "$TM_USE_PYTORCH_BINARY" "$python_version"
|
||||||
if [ "${TM_UPDATE_ODS_AND_SHAPE_LIB}" == "ON" ]; then
|
if [ "${TM_UPDATE_ODS_AND_ABSTRACT_INTERP_LIB}" == "ON" ]; then
|
||||||
pushd /main_checkout/torch-mlir
|
pushd /main_checkout/torch-mlir
|
||||||
./build_tools/update_torch_ods.sh
|
./build_tools/update_torch_ods.sh
|
||||||
./build_tools/update_shape_lib.sh
|
./build_tools/update_abstract_interp_lib.sh
|
||||||
popd
|
popd
|
||||||
fi
|
fi
|
||||||
if [ "${TM_SKIP_TESTS}" == "OFF" ]; then
|
if [ "${TM_SKIP_TESTS}" == "OFF" ]; then
|
||||||
|
@ -253,8 +253,8 @@ function test_in_tree() {
|
||||||
cd /main_checkout/torch-mlir/
|
cd /main_checkout/torch-mlir/
|
||||||
export PYTHONPATH="/main_checkout/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir"
|
export PYTHONPATH="/main_checkout/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir"
|
||||||
|
|
||||||
echo ":::: Check that update_shape_lib.sh has been run"
|
echo ":::: Check that update_abstract_interp_lib.sh has been run"
|
||||||
_check_file_not_changed_by ./build_tools/update_shape_lib.sh lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
|
_check_file_not_changed_by ./build_tools/update_abstract_interp_lib.sh lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
|
||||||
|
|
||||||
echo ":::: Check that update_torch_ods.sh has been run"
|
echo ":::: Check that update_torch_ods.sh has been run"
|
||||||
_check_file_not_changed_by ./build_tools/update_torch_ods.sh include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
|
_check_file_not_changed_by ./build_tools/update_torch_ods.sh include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
# Updates auto-generated shape library files for the `torch` dialect.
|
# Updates auto-generated abstract interpretation library files for the
|
||||||
|
# `torch` dialect.
|
||||||
#
|
#
|
||||||
# Environment variables:
|
# Environment variables:
|
||||||
# TORCH_MLIR_EXT_MODULES: comma-separated list of python module names
|
# TORCH_MLIR_EXT_MODULES: comma-separated list of python module names
|
||||||
|
@ -41,6 +42,6 @@ if [ ! -z ${TORCH_MLIR_EXT_MODULES} ]; then
|
||||||
fi
|
fi
|
||||||
|
|
||||||
PYTHONPATH="${pypath}" python \
|
PYTHONPATH="${pypath}" python \
|
||||||
-m torch_mlir.dialects.torch.importer.jit_ir.build_tools.shape_lib_gen \
|
-m torch_mlir.dialects.torch.importer.jit_ir.build_tools.abstract_interp_lib_gen \
|
||||||
--pytorch_op_extensions=${ext_module:-""} \
|
--pytorch_op_extensions=${ext_module:-""} \
|
||||||
--torch_transforms_cpp_dir="${torch_transforms_cpp_dir}"
|
--torch_transforms_cpp_dir="${torch_transforms_cpp_dir}"
|
|
@ -0,0 +1,130 @@
|
||||||
|
# Torch-MLIR Abstract Interpretation Library Infrastructure
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The Torch-MLIR project has an infrastructure for maintaining a library of
|
||||||
|
calculation functions for different Torch operators, which supply extra
|
||||||
|
information such as result dtypes and shapes as well as decompositions. These
|
||||||
|
functions are fully executable specifications of the shape/dtype/decomposition
|
||||||
|
functions for each operator and can be authored and tested from Python for
|
||||||
|
convenience. These are then brought into the compiler and can be manipulated /
|
||||||
|
transformed for various purposes. Additionally, in the case of shape functions,
|
||||||
|
this effort is synergistic with upstream PyTorch efforts to maintain a library
|
||||||
|
of shape functions.
|
||||||
|
|
||||||
|
The two main use cases are:
|
||||||
|
|
||||||
|
- Refinement / inference. The `torch-shape-refinement-pipeline` and
|
||||||
|
`torch-dtype-refinement-pipeline` pass pipelines orchestrate a series of
|
||||||
|
passes that use the available information in the program to further refine the
|
||||||
|
types in the program.
|
||||||
|
|
||||||
|
- Error guard insertion for backends (Not Yet Implemented). The executable
|
||||||
|
functions can include error guards / assertions that abort the program in case
|
||||||
|
of invalid input (such as a matmul with a mismatching contracting dimension).
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
Functions are defined as TorchScript-able Python functions in
|
||||||
|
`python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py`.
|
||||||
|
The signatures of the functions are systematically derived from Torch JIT
|
||||||
|
operator registry. Most shape functions are expected to reuse the upstream
|
||||||
|
helper functions
|
||||||
|
[`torch/jit/_shape_functions.py`](https://github.com/pytorch/pytorch/blob/279634f384662b7c3a9f8bf7ccc3a6afd2f05657/torch/jit/_shape_functions.py#L1),
|
||||||
|
and any new shape functions should be added there.
|
||||||
|
|
||||||
|
The `build_tools/update_abstract_interp_lib.sh` script invokes
|
||||||
|
`abstract_interp_lib_gen.py` to generate an MLIR module containing the functions,
|
||||||
|
which is currently embedded as a string literal in
|
||||||
|
`lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp`.
|
||||||
|
|
||||||
|
The function `StringRef mlir::torch::Torch::getAbstractInterpLibrary()` is
|
||||||
|
available for use inside the compiler any time that the library is needed.
|
||||||
|
|
||||||
|
## Shape and Dtype Refinement Pipeline Architecture
|
||||||
|
|
||||||
|
One of the main services that Torch-MLIR provides for backends is to normalize
|
||||||
|
all Torch frontends into a common form which includes tensor shapes and dtypes
|
||||||
|
that are as precise as possible. This alleviates the need for backends to solve
|
||||||
|
this problem themselves. This process of shape and dtype refinement is
|
||||||
|
accomplished in Torch-MLIR through a pipeline of passes which uses the abstract
|
||||||
|
interpretation library combined with abstract interpretation of the calculation
|
||||||
|
functions to calculate shapes and dtypes that are as precise as possible.
|
||||||
|
|
||||||
|
The pipeline works as follows:
|
||||||
|
|
||||||
|
1. Calculations are reified. The `torch-reify-shape-calculations` and
|
||||||
|
`torch-reify-dtype-calculations` passes reify (i.e., materializes into the
|
||||||
|
IR) the functions for each op with a function in the calculation library. To
|
||||||
|
do this, the passes wrap those ops in a `torch.shape.calculate` or
|
||||||
|
`torch.dtype.calculate` op, respectively, which has two regions: 1) a body
|
||||||
|
with the op itself, and 2) the shape or dtype calculation, which calculates
|
||||||
|
the shapes or dtypes of the tensors yielded by the body.
|
||||||
|
|
||||||
|
2. Simplifying the functions and propagating the shapes and dtypes. After the
|
||||||
|
functions are reified, we then attempt to "optimize hard enough" until the
|
||||||
|
shapes and dtypes yielded by the calculation regions become obvious in the IR.
|
||||||
|
Those results are propagated through the IR, which usually reveals more
|
||||||
|
opportunities for simplification.
|
||||||
|
|
||||||
|
a. After reification, the functions are just a loose collection of
|
||||||
|
functions, which are difficult to analyze. The first step is to inline them.
|
||||||
|
|
||||||
|
b. After inlining, the `torch-simplify-shape-calculations` and
|
||||||
|
`torch-simplify-dtype-calculations` passes are used to simplify the
|
||||||
|
calculations. These passes bring in a number of targeted canonicalization
|
||||||
|
patterns and folds, along with a few specific patterns such as unrolling
|
||||||
|
fixed-trip-count loops and abstractly interpreting list operations (an
|
||||||
|
example is turning a series of "append" operations into a list
|
||||||
|
literal). These passes also look at the values yielded by the calculation
|
||||||
|
regions, and if the resulting shape or dtype can be deduced by looking at the
|
||||||
|
IR (for example, the shape is the list literal `[1, 2, 3]`), it will refine
|
||||||
|
the types of the `torch.shape.calculate` and `torch.dtype.calculate`
|
||||||
|
ops. This usually yields more opportunities for simplification. This process
|
||||||
|
runs to a fixed-point.
|
||||||
|
|
||||||
|
3. Dropping the calculations. Once all the types in the program have been
|
||||||
|
refined as much as possible, the ops that were originally wrapped in
|
||||||
|
`torch.shape.calculate` and `torch.dtype.calculate` are unwrapped by the
|
||||||
|
`torch-drop-abstract-interp-calculations` pass which drops the reified
|
||||||
|
calculations, leaving behind the shape and dtype refined program.
|
||||||
|
|
||||||
|
Inferring precise shapes and dtypes often is needed for correctness by
|
||||||
|
backends. That said, requiring "optimizing hard enough" for correctness is
|
||||||
|
usually considered quite brittle in a compiler flow. In this case, the saving
|
||||||
|
grace is that we are only optimizing the functions, which are authored by
|
||||||
|
compiler developers (not users), and thus there is some give-and-take in terms
|
||||||
|
of understanding the optimizable constructs while authoring the functions, or
|
||||||
|
improving the optimizations to enable easier authoring. Some brittleness is
|
||||||
|
likely to escape to users, unfortunately, since there will always be situations
|
||||||
|
where, for example, a statically shaped program allows the shape functions to be
|
||||||
|
simplified to a greater extent than in a dynamically shaped program (for
|
||||||
|
example, if the shape function checks "is this dimension of size 1"). We hope
|
||||||
|
that this is minimal.
|
||||||
|
|
||||||
|
## Adding to the abstract interpretation library
|
||||||
|
|
||||||
|
See [Adding a Shape and Dtype Function](adding_a_shape_and_dtype_function.md)
|
||||||
|
for details on how to add a shape and dtype function for an operator.
|
||||||
|
|
||||||
|
## Rationale
|
||||||
|
|
||||||
|
### Use of full operator signatures
|
||||||
|
|
||||||
|
The use of the full operator signature such as
|
||||||
|
`def aten〇add〇Tensor(self: List[int], other: List[int], alpha: float = 1) -> List[int]:`
|
||||||
|
for defining calculation functions is somewhat verbose and repetitive, especially when
|
||||||
|
there are multiple identical functions. Upstream uses a map with key-value
|
||||||
|
pairs like `"aten.add.Tensor": upstream_shape_functions.broadcast`, which is
|
||||||
|
more compact and less repetitive in some ways (upstream also allows trailing
|
||||||
|
arguments beyond those accepted by the shape function to be ignored, allowing
|
||||||
|
further deduplication). The decision to do it the more verbose way in Torch-MLIR
|
||||||
|
was based on the following goals:
|
||||||
|
|
||||||
|
- To make the system very easy to debug and test.
|
||||||
|
|
||||||
|
- To make the system maximally consistent between functions that are
|
||||||
|
implemented with the upstream shape helpers and the ones that are manually
|
||||||
|
written, which are still a fairly large and non-trivial set.
|
||||||
|
|
||||||
|
- To make it as mechanical as possible to add a new function.
|
|
@ -1,75 +0,0 @@
|
||||||
# Adding a Shape Function
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
As part of adding support for a Torch operator in Torch-MLIR, it is usually
|
|
||||||
necessary to define a shape function so that the compiler can infer the shapes
|
|
||||||
of result tensors for the operator. We use the [shape library](shape_lib.md) for this process.
|
|
||||||
|
|
||||||
## Step-by-step guide
|
|
||||||
|
|
||||||
We will use the example of adding support for the `torch.aten.tanh` op.
|
|
||||||
|
|
||||||
1. First, you need to find the shape function signature for the operator you are
|
|
||||||
implementing a shape function for. This can be found in
|
|
||||||
`include/torch-mlir/Dialect/Torch/IR/JITOperatorRegistryDump.txt` generated
|
|
||||||
by the `build_tools/update_torch_ods.sh` script. That file is the "rosetta
|
|
||||||
stone" that allows translating between e.g. `torch.aten.tanh`, `AtenTanhOp`,
|
|
||||||
and the shape function signature
|
|
||||||
`def aten〇tanh(self: List[int]) -> List[int]:`. Note the use of `〇` as a
|
|
||||||
separator since `.` or `::` aren't legal in a Python identifier.
|
|
||||||
|
|
||||||
2. Paste the shape function signature into `shape_lib_gen.py` in an appropriate
|
|
||||||
place (ideally near other functions with a similar shape function). Note that
|
|
||||||
`shape_lib_gen.py` will check that this signature is verbatim identical with
|
|
||||||
the one given in `JITOperatorRegistryDump.txt` -- this ensures that the shape
|
|
||||||
functions don't get outdated if Torch changes an operator signature.
|
|
||||||
|
|
||||||
3. Fill in the body of the shape function. Ideally this will just be a call into
|
|
||||||
a helper function from
|
|
||||||
[`torch/jit/_shape_functions.py`](https://github.com/pytorch/pytorch/blob/279634f384662b7c3a9f8bf7ccc3a6afd2f05657/torch/jit/_shape_functions.py#L1).
|
|
||||||
But in general, you will need to write the shape function and test it (see
|
|
||||||
the comments about "Shape function testing infrastructure" in
|
|
||||||
`shape_lib_gen.py`). New shape functions should be added upstream following
|
|
||||||
the example of [this PR](https://github.com/pytorch/pytorch/pull/76889),
|
|
||||||
though it can be useful to iterate locally in `shape_lib_gen.py` first.
|
|
||||||
|
|
||||||
4. Re-run the `build_tools/update_shape_lib.sh` script to update the shape
|
|
||||||
library. After this step happens, ideally everything "just works" and the
|
|
||||||
shape is now correctly inferred for the operator.
|
|
||||||
|
|
||||||
## When things go wrong
|
|
||||||
|
|
||||||
It is possible that the shape refinement pipeline (see
|
|
||||||
[Shape Refinement Pipeline Architecture](shape_lib.md#shape-refinement-pipeline-architecture))
|
|
||||||
is not able to infer the shape of a tensor with a given shape function. This
|
|
||||||
usually means that there is something about the shape function which the
|
|
||||||
optimizations in `torch-simplify-shape-functions`
|
|
||||||
(`lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp`) cannot handle.
|
|
||||||
|
|
||||||
To debug this, the overall goal is to pinpoint the IR construct that is not
|
|
||||||
being simplified. This is usually accomplished by a combination of looking at
|
|
||||||
the Python code for the shape function and the IR dumps. The best IR dump to
|
|
||||||
look at varies, but frequently the IR dump right before `DropShapeCalculations`
|
|
||||||
is the most useful, because it has already been simplified as much as possible,
|
|
||||||
making it is easy to see what is blocking further simplification. Examples of
|
|
||||||
issues you might see:
|
|
||||||
|
|
||||||
- You might find that there is a loop with a non-constant trip count, but based
|
|
||||||
on your understanding of the shape function, you would expect it to be
|
|
||||||
simplified to a constant trip count -- you can then look at the trip count
|
|
||||||
calculation and see if there is a missing fold or canonicalization.
|
|
||||||
|
|
||||||
- You might find that there is a list operation that is not currently understood
|
|
||||||
by the optimizations. You can then teach the optimizations about that
|
|
||||||
operation.
|
|
||||||
|
|
||||||
- You might find that there is an `Optional` value that you would expect to be
|
|
||||||
resolved to either a concrete value or `None`. You can then look at the calculation that produces the optional value and see what folds or canonicalizations are missing.
|
|
||||||
|
|
||||||
See [this video](https://www.youtube.com/watch?v=E5epCJOtrf8) for general
|
|
||||||
guidance on debugging Torch-MLIR.
|
|
||||||
|
|
||||||
As a last resort, you can rewrite the shape function using constructs that
|
|
||||||
`torch-simplify-shape-functions` can handle (look at other shape functions for
|
|
||||||
examples, sometimes it requires writing things a little awkwardly).
|
|
|
@ -0,0 +1,99 @@
|
||||||
|
# Adding Abstract Interpretation Functions
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
As part of adding support for a Torch operator in Torch-MLIR, it is usually
|
||||||
|
necessary to define a shape and dtype function so that the compiler can infer
|
||||||
|
the shapes and dtypes of result tensors for the operator. We use the
|
||||||
|
[abstract interpretation library](abstract_interp_lib.md) for this process.
|
||||||
|
|
||||||
|
## Step-by-step guide
|
||||||
|
|
||||||
|
We will use the example of adding support for the `torch.aten.tanh` op.
|
||||||
|
|
||||||
|
1. First, you need to find the shape and dtype function signatures for
|
||||||
|
the operator you are implementing a functions for. This can be
|
||||||
|
found in
|
||||||
|
`include/torch-mlir/Dialect/Torch/IR/JITOperatorRegistryDump.txt`
|
||||||
|
generated by the `build_tools/update_torch_ods.sh` script. That
|
||||||
|
file is the "rosetta stone" that allows translating between
|
||||||
|
e.g. `torch.aten.tanh`, `AtenTanhOp`, and the shape and dtype
|
||||||
|
function signatures are:
|
||||||
|
|
||||||
|
- `def aten〇tanh〡shape(self: List[int]) -> List[int]:`
|
||||||
|
- `def aten〇tanh〡dtype(self_rank: int, self_dtype: int) -> int:`
|
||||||
|
|
||||||
|
Note the use of `〇` as a separator since `.` or `::` aren't legal
|
||||||
|
in a Python identifier.
|
||||||
|
|
||||||
|
2. Paste the function signature into `abstract_interp_lib_gen.py` in an
|
||||||
|
appropriate place (ideally near other functions with a similar
|
||||||
|
functions). Note that `abstract_interp_lib_gen.py` will check that
|
||||||
|
these signatures are verbatim identical with the ones given in
|
||||||
|
`JITOperatorRegistryDump.txt` -- this ensures that the functions
|
||||||
|
don't get outdated if Torch changes an operator signature.
|
||||||
|
|
||||||
|
3. Fill in the body of the function. Ideally this will just be a call
|
||||||
|
into a helper function from
|
||||||
|
[`torch/jit/_shape_functions.py`](https://github.com/pytorch/pytorch/blob/279634f384662b7c3a9f8bf7ccc3a6afd2f05657/torch/jit/_shape_functions.py#L1).
|
||||||
|
But in general, you will need to write the function and test it
|
||||||
|
(see the comments about "Shape, dtype, and decomposition function
|
||||||
|
testing infrastructure" in `testing_framework.py`). New shape
|
||||||
|
functions should be added upstream following the example of [this PR](https://github.com/pytorch/pytorch/pull/76889),
|
||||||
|
though it can be useful to iterate locally in `abstract_interp_lib_gen.py`
|
||||||
|
first.
|
||||||
|
|
||||||
|
Similarly, dtype functions should ideally just be a call to the helper
|
||||||
|
`promote_dtypes` defined in `library_generator.py`. However, some ops will
|
||||||
|
require some extra logic to calculate the right result types. While dtypes
|
||||||
|
are expressed as `int`s in the arguments of the dtype function, using PyTorch
|
||||||
|
dtypes, such as `torch.int` and `torch.float32`, in the body of the dtype
|
||||||
|
function is fully supported. Dtype functions are also expected to be fully
|
||||||
|
tested.
|
||||||
|
|
||||||
|
4. Re-run the `build_tools/update_abstract_interp_lib.sh` script to
|
||||||
|
update the library. After this step happens, ideally everything
|
||||||
|
"just works" and the functions are now correctly inferred for the
|
||||||
|
operator.
|
||||||
|
|
||||||
|
## When things go wrong
|
||||||
|
|
||||||
|
It is possible that the refinement pipeline (see [Shape and Dtype Refinement Pipeline Architecture](abstract_interp_lib.md#shape-and-dtype-refinement-pipeline-architecture))
|
||||||
|
is not able to infer the shape or dtype of a tensor with a given
|
||||||
|
abstract interpretation function. This usually means that there is something
|
||||||
|
about the function which the optimizations in
|
||||||
|
`torch-simplify-shape-functions` and `torch-simplify-dtype-functions`
|
||||||
|
(`lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp`,
|
||||||
|
`lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp`)
|
||||||
|
cannot handle.
|
||||||
|
|
||||||
|
To debug this, the overall goal is to pinpoint the IR construct that is not
|
||||||
|
being simplified. This is usually accomplished by a combination of looking at
|
||||||
|
the Python code for the function and the IR dumps. The best IR dump to look at
|
||||||
|
varies, but frequently the IR dump right before `DropAbstractInterpCalculations`
|
||||||
|
is the most useful, because it has already been simplified as much as possible,
|
||||||
|
making it is easy to see what is blocking further simplification. Examples of
|
||||||
|
issues you might see:
|
||||||
|
|
||||||
|
- You might find that there is a loop with a non-constant trip count,
|
||||||
|
but based on your understanding of the function, you would expect it
|
||||||
|
to be simplified to a constant trip count -- you can then look at
|
||||||
|
the trip count calculation and see if there is a missing fold or
|
||||||
|
canonicalization.
|
||||||
|
|
||||||
|
- You might find that there is a list operation that is not currently understood
|
||||||
|
by the optimizations. You can then teach the optimizations about that
|
||||||
|
operation.
|
||||||
|
|
||||||
|
- You might find that there is an `Optional` value that you would
|
||||||
|
expect to be resolved to either a concrete value or `None`. You can
|
||||||
|
then look at the calculation that produces the optional value and
|
||||||
|
see what folds or canonicalizations are missing.
|
||||||
|
|
||||||
|
See [this video](https://www.youtube.com/watch?v=E5epCJOtrf8) for general
|
||||||
|
guidance on debugging Torch-MLIR.
|
||||||
|
|
||||||
|
As a last resort, you can rewrite the function using constructs that
|
||||||
|
`torch-simplify-shape-functions` and `torch-simplify-dtype-functions` can handle
|
||||||
|
(look at other functions for examples, sometimes it requires writing things a
|
||||||
|
little awkwardly).
|
|
@ -446,10 +446,12 @@ LLVM updates in other projects like ONNX-MLIR and MLIR-HLO.
|
||||||
working on). If these fixes are too complex, please file a work-in-progress
|
working on). If these fixes are too complex, please file a work-in-progress
|
||||||
PR explaining the issues you are running into asking for help so that someone
|
PR explaining the issues you are running into asking for help so that someone
|
||||||
from the community can help.
|
from the community can help.
|
||||||
5. **Update Shape Library**: Run `build_tools/update_shape_lib.sh`. This is
|
5. **Update Abstract Interpretation Library**: Run
|
||||||
sometimes needed because upstream changes can affect canonicalization and
|
`build_tools/update_abstract_interp_lib.sh`. This is sometimes needed
|
||||||
other minor details of the IR in the shape library. See
|
because upstream changes can affect canonicalization and other minor details
|
||||||
[docs/shape_lib.md](docs/shape_lib.md) for more details on the shape library.
|
of the IR in the shape library. See
|
||||||
|
[docs/abstract_interp_lib.md](docs/abstract_interp_lib.md) for more details
|
||||||
|
on the abstract interpretation library.
|
||||||
|
|
||||||
|
|
||||||
Here are some examples of PRs updating the LLVM and MLIR-HLO submodules:
|
Here are some examples of PRs updating the LLVM and MLIR-HLO submodules:
|
||||||
|
|
|
@ -1,121 +0,0 @@
|
||||||
# Torch-MLIR Shape Library Infrastructure
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
The Torch-MLIR project has an infrastructure for maintaining a library of shape
|
|
||||||
functions for different Torch operators. These shape functions are fully
|
|
||||||
executable specifications of the shape functions for each operator and can be
|
|
||||||
authored and tested from Python for convenience. These are then brought into the
|
|
||||||
compiler and can be manipulated / transformed for various purposes.
|
|
||||||
Additionally, this effort is synergistic with upstream PyTorch efforts to
|
|
||||||
maintain a library of shape functions.
|
|
||||||
|
|
||||||
The two main use cases are:
|
|
||||||
|
|
||||||
- Shape refinement / shape inference. The `torch-shape-refinement-pipeline` pass
|
|
||||||
pipeline orchestrates a series of passes that use the available shape information in the program to further refine the types in the program.
|
|
||||||
|
|
||||||
- Error guard insertion for backends (Not Yet Implemented). The executable shape
|
|
||||||
functions can include error guards / assertions that abort the program in case
|
|
||||||
of invalid input (such as a matmul with a mismatching contracting dimension).
|
|
||||||
|
|
||||||
## Architecture
|
|
||||||
|
|
||||||
Shape functions are defined as TorchScript-able Python functions in
|
|
||||||
`python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py`.
|
|
||||||
The signatures of the shape functions are systematically derived from Torch JIT
|
|
||||||
operator registry (mainly by replacing `Tensor` with `List[int]` in the operator
|
|
||||||
signatures). Most shape functions are expected to reuse the upstream helper
|
|
||||||
functions [`torch/jit/_shape_functions.py`](https://github.com/pytorch/pytorch/blob/279634f384662b7c3a9f8bf7ccc3a6afd2f05657/torch/jit/_shape_functions.py#L1),
|
|
||||||
and any new shape functions should be added there.
|
|
||||||
|
|
||||||
The `build_tools/update_shape_lib.sh` script invokes `shape_lib_gen.py` to
|
|
||||||
generate an MLIR module containing the shape functions, which is currently
|
|
||||||
embedded as a string literal in `lib/Dialect/Torch/Transforms/ShapeLibrary.cpp`.
|
|
||||||
|
|
||||||
The function `StringRef mlir::torch::Torch::getShapeLibrary()` is available for
|
|
||||||
use inside the compiler any time that the shape library is needed.
|
|
||||||
|
|
||||||
## Shape Refinement Pipeline Architecture
|
|
||||||
|
|
||||||
One of the main services that Torch-MLIR provides for backends is to normalize
|
|
||||||
all Torch frontends into a common form which includes tensor shapes that are as
|
|
||||||
precise as possible. This alleviates the need for backends to solve this problem
|
|
||||||
themselves. This process of shape refinement is accomplished in Torch-MLIR
|
|
||||||
through a pipeline of passes which uses the shape library combined with abstract
|
|
||||||
interpretation of the shape functions to calculate shapes that are as precise as
|
|
||||||
possible.
|
|
||||||
|
|
||||||
The pipeline works as follows:
|
|
||||||
|
|
||||||
1. Shape calculations are reified. The `torch-reify-shape-calculations` reifies
|
|
||||||
(i.e., materializes into the IR) the shape functions for each op with a shape
|
|
||||||
function in the shape library. To do this, it wraps those ops in a
|
|
||||||
`torch.shape.calculate` op, which has two regions: 1) a body with the op
|
|
||||||
itself, and 2) the shape calculation, which calculates the shapes of the
|
|
||||||
tensors yielded by the body.
|
|
||||||
|
|
||||||
2. Simplifying the shape functions and propagating the shapes. After the shape
|
|
||||||
functions are reified, we then attempt to "optimize hard enough" until the
|
|
||||||
shapes yielded by the shape calculation regions become obvious in the IR.
|
|
||||||
Those shapes are propagated through the IR, which usually reveals more
|
|
||||||
opportunities for simplification.
|
|
||||||
|
|
||||||
a. After reification, the shape functions are just a loose collection of
|
|
||||||
functions, which are difficult to analyze. The first step is to inline them.
|
|
||||||
|
|
||||||
b. After inlining, the `torch-simplify-shape-calculations` pass is used to
|
|
||||||
simplify the shape calculations. This pass brings in a number of targeted
|
|
||||||
canonicalization patterns and folds, along with a few specific patterns such
|
|
||||||
as unrolling fixed-trip-count loops and abstractly interpreting list
|
|
||||||
operations (an example is turning a series of "append" operations into a list
|
|
||||||
literal). This pass also looks at the values yielded by the shape calculation
|
|
||||||
regions, and if the resulting shape can be deduced by looking at the IR (for
|
|
||||||
example, the shape is the list literal `[1, 2, 3]`), it will refine the types
|
|
||||||
of the `torch.shape.calculate` op. This usually yields more opportunities for
|
|
||||||
simplification. This process runs to a fixed-point.
|
|
||||||
|
|
||||||
3. Dropping the shape calculations. Once all the types in the program have been
|
|
||||||
refined as much as possible, the ops that were originally wrapped in
|
|
||||||
`torch.shape.calculate` are unwrapped by the `torch-drop-shape-calculations`
|
|
||||||
pass which drops the reified shape calculations, leaving behind the shape-refined program.
|
|
||||||
|
|
||||||
Inferring precise shape often is needed for correctness by backends. That said,
|
|
||||||
requiring "optimizing hard enough" for correctness is usually considered quite
|
|
||||||
brittle in a compiler flow. In this case, the saving grace is that we are only
|
|
||||||
optimizing the shape functions, which are authored by compiler developers (not
|
|
||||||
users), and thus there is some give-and-take in terms of understanding the
|
|
||||||
optimizable constructs while authoring the shape functions, or improving the
|
|
||||||
optimizations to enable easier authoring. Some brittleness is likely to escape
|
|
||||||
to users, unfortunately, since there will always be situations where, for
|
|
||||||
example, a statically shaped program allows the shape functions to be simplified
|
|
||||||
to a greater extent than in a dynamically shaped program (for example, if the
|
|
||||||
shape function checks "is this dimension of size 1"). We hope that this is
|
|
||||||
minimal.
|
|
||||||
|
|
||||||
## Adding to the shape library
|
|
||||||
|
|
||||||
See [Adding a Shape Function](adding_a_shape_function.md) for details on how to
|
|
||||||
add a shpae function for an operator.
|
|
||||||
|
|
||||||
## Rationale
|
|
||||||
|
|
||||||
### Use of full operator signatures
|
|
||||||
|
|
||||||
The use of the full operator signature such as
|
|
||||||
`def aten〇add〇Tensor(self: List[int], other: List[int], alpha: float = 1) -> List[int]:`
|
|
||||||
for defining shape functions is somewhat verbose and repetitive, especially when
|
|
||||||
there are multiple identical shape functions. Upstream uses a map with key-value
|
|
||||||
pairs like `"aten.add.Tensor": upstream_shape_functions.broadcast`, which is
|
|
||||||
more compact and less repetitive in some ways (upstream also allows trailing
|
|
||||||
arguments beyond those accepted by the shape function to be ignored, allowing
|
|
||||||
further deduplication). The decision to do it the more verbose way in Torch-MLIR
|
|
||||||
was based on the following goals:
|
|
||||||
|
|
||||||
- To make the system very easy to debug and test.
|
|
||||||
|
|
||||||
- To make the system maximally consistent between shape functions that are
|
|
||||||
implemented with the upstream shape helpers and the ones that are manually
|
|
||||||
written, which are still a fairly large and non-trivial set.
|
|
||||||
|
|
||||||
- To make it as mechanical as possible to add a new shape function.
|
|
|
@ -171,6 +171,42 @@ m_TorchListOfConstantInts(SmallVectorImpl<int64_t> &bind_values) {
|
||||||
return detail::torch_list_of_constant_ints_op_binder(bind_values);
|
return detail::torch_list_of_constant_ints_op_binder(bind_values);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
/// Matches the optional constant integers stored in a `torch.ListConstruct`.
|
||||||
|
struct torch_list_of_optional_constant_ints_op_binder {
|
||||||
|
SmallVectorImpl<Optional<int64_t>> &bind_values;
|
||||||
|
|
||||||
|
/// Creates a matcher instance that binds the value to bvs if match succeeds.
|
||||||
|
torch_list_of_optional_constant_ints_op_binder(
|
||||||
|
SmallVectorImpl<Optional<int64_t>> &bvs)
|
||||||
|
: bind_values(bvs) {}
|
||||||
|
|
||||||
|
bool match(Operation *op) {
|
||||||
|
auto listConstruct = dyn_cast<Torch::PrimListConstructOp>(op);
|
||||||
|
if (!listConstruct)
|
||||||
|
return false;
|
||||||
|
for (Value value : listConstruct.getElements()) {
|
||||||
|
int64_t num;
|
||||||
|
if (matchPattern(value, m_TorchConstantInt(&num)))
|
||||||
|
bind_values.push_back(num);
|
||||||
|
else if (value.getType().isa<Torch::NoneType>())
|
||||||
|
bind_values.push_back(llvm::None);
|
||||||
|
else
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
/// Matches the optional constant integers stored in a
|
||||||
|
/// `torch.prim.ListConstruct`.
|
||||||
|
inline detail::torch_list_of_optional_constant_ints_op_binder
|
||||||
|
m_TorchListOfOptionalConstantInts(
|
||||||
|
SmallVectorImpl<Optional<int64_t>> &bind_values) {
|
||||||
|
return detail::torch_list_of_optional_constant_ints_op_binder(bind_values);
|
||||||
|
}
|
||||||
|
|
||||||
namespace detail {
|
namespace detail {
|
||||||
/// Matches the constant bools stored in a `torch.ListConstruct`.
|
/// Matches the constant bools stored in a `torch.ListConstruct`.
|
||||||
struct torch_list_of_constant_bools_op_binder {
|
struct torch_list_of_constant_bools_op_binder {
|
||||||
|
|
|
@ -1101,6 +1101,40 @@ def Torch_ValsemVariantAtenBernoulliFloatOp: Torch_Op<"valsem.aten.bernoulli.flo
|
||||||
let assemblyFormat = "$self `,` $p `,` $generator attr-dict `:` type($self) `,` type($p) `,` type($generator) `->` type($result)";
|
let assemblyFormat = "$self `,` $p `,` $generator attr-dict `:` type($self) `,` type($p) `,` type($generator) `->` type($result)";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_PromoteDtypesOp: Torch_Op<"promote_dtypes", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "`promote_dtypes op : (int?[], int[]) -> (int)`";
|
||||||
|
let description = [{
|
||||||
|
This op is generated when the python function
|
||||||
|
`__torch_mlir_internal_promote_dtypes` is used in a dtype refinement
|
||||||
|
function. It represents the type promotion logic used by PyTorch to
|
||||||
|
determine result types.
|
||||||
|
|
||||||
|
The first argument is a list of optional ranks for each of the inputs
|
||||||
|
being used for promotion. The ranks are optional to allow representing
|
||||||
|
`Scalar` inputs, which follow their own set of promotion rules.
|
||||||
|
|
||||||
|
The second argument is a list of dtypes for each of the inputs being used
|
||||||
|
for promotion.
|
||||||
|
|
||||||
|
The order of the values in each list must be the same. In other words,
|
||||||
|
the ith rank and the ith dtype must be from the same Scalar/Tensor.
|
||||||
|
|
||||||
|
It is an error to call this op with empty lists or lists of different size.
|
||||||
|
}];
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchListOfOptionalIntType:$ranks,
|
||||||
|
AnyTorchListOfTorchIntType:$dtypes
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
Torch_IntType:$result
|
||||||
|
);
|
||||||
|
let assemblyFormat = "$ranks `,` $dtypes attr-dict `:` functional-type(operands, results)";
|
||||||
|
}
|
||||||
|
|
||||||
// To handle runtime assertions, torchscript provides us `torch._assert` operation.
|
// To handle runtime assertions, torchscript provides us `torch._assert` operation.
|
||||||
// But TS compiler introduces control flow for `torch._assert` operation. The
|
// But TS compiler introduces control flow for `torch._assert` operation. The
|
||||||
// `torch._assert` would introduce control flow like:
|
// `torch._assert` would introduce control flow like:
|
||||||
|
@ -1140,31 +1174,31 @@ def Torch_ShapeCalculateOp : Torch_Op<"shape.calculate", [
|
||||||
let summary = "Shape calculation encapsulation op";
|
let summary = "Shape calculation encapsulation op";
|
||||||
let description = [{
|
let description = [{
|
||||||
The `torch.shape.calculate` op captures a shape calculation
|
The `torch.shape.calculate` op captures a shape calculation
|
||||||
(in the region `shapeCalculation`) which calculates the shapes for
|
(in the region `calculation`) which calculates the shapes for
|
||||||
the set of values yielded by the `body` region.
|
the set of values yielded by the `body` region.
|
||||||
|
|
||||||
The `shapeCalculation` region yields a `!torch.list<int>` for each
|
The `calculation` region yields a `!torch.list<int>` for each
|
||||||
value yielded by the `body` region.
|
value yielded by the `body` region.
|
||||||
|
|
||||||
Conceptually, the `shapeCalculation` region executes first, then `body`
|
Conceptually, the `calculation` region executes first, then `body`
|
||||||
region. So the `shapeCalculation` region can also contain arbitrary
|
region. So the `calculation` region can also contain arbitrary
|
||||||
assertions or side-effecting code which guard the validity of the execution
|
assertions or side-effecting code which guard the validity of the execution
|
||||||
of the body (typically by terminating the program with a
|
of the body (typically by terminating the program with a
|
||||||
torch.prim.RaiseException op).
|
torch.prim.RaiseException op).
|
||||||
|
|
||||||
The program has undefined behavior if the values yielded by the `body`
|
The program has undefined behavior if the values yielded by the `body`
|
||||||
region do not have the shapes yielded by the `shapeCalculation` region.
|
region do not have the shapes yielded by the `calculation` region.
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins);
|
let arguments = (ins);
|
||||||
let results = (outs Variadic<AnyTorchType>:$results);
|
let results = (outs Variadic<AnyTorchType>:$results);
|
||||||
let regions = (region
|
let regions = (region
|
||||||
SizedRegion<1>:$body,
|
SizedRegion<1>:$body,
|
||||||
SizedRegion<1>:$shapeCalculation
|
SizedRegion<1>:$calculation
|
||||||
);
|
);
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
$body `shapes` $shapeCalculation attr-dict `:` type($results)
|
$body `shapes` $calculation attr-dict `:` type($results)
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1211,4 +1245,84 @@ def Torch_ShapeCalculateYieldShapesOp : Torch_Op<"shape.calculate.yield.shapes",
|
||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Dtype calculation modeling ops.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def Torch_DtypeCalculateOp : Torch_Op<"dtype.calculate", [
|
||||||
|
DeclareOpInterfaceMethods<RegionBranchOpInterface>]> {
|
||||||
|
let summary = "Dtype calculation encapsulation op";
|
||||||
|
let description = [{
|
||||||
|
The `torch.dtype.calculate` op captures a dtype calculation
|
||||||
|
(in the region `calculation`) which calculates the dtypes for
|
||||||
|
the set of values yielded by the `body` region.
|
||||||
|
|
||||||
|
The `calculation` region yields a `!torch.int` for each
|
||||||
|
value yielded by the `body` region.
|
||||||
|
|
||||||
|
Conceptually, the `calculation` region executes first, then `body`
|
||||||
|
region. So the `calculation` region can also contain arbitrary
|
||||||
|
assertions or side-effecting code which guard the validity of the execution
|
||||||
|
of the body (typically by terminating the program with a
|
||||||
|
torch.prim.RaiseException op).
|
||||||
|
|
||||||
|
The program has undefined behavior if the values yielded by the `body`
|
||||||
|
region do not have the dtypes yielded by the `calculation` region.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins);
|
||||||
|
let results = (outs Variadic<AnyTorchType>:$results);
|
||||||
|
let regions = (region
|
||||||
|
SizedRegion<1>:$body,
|
||||||
|
SizedRegion<1>:$calculation
|
||||||
|
);
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$body `dtypes` $calculation attr-dict `:` type($results)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_DtypeCalculateYieldOp : Torch_Op<"dtype.calculate.yield", [
|
||||||
|
Terminator,
|
||||||
|
ReturnLike,
|
||||||
|
HasParent<"::mlir::torch::Torch::DtypeCalculateOp">]> {
|
||||||
|
let summary = "yield-like terminator for torch.dtype.calculate";
|
||||||
|
let description = [{
|
||||||
|
This op terminates the `body` region of a `torch.dtype.calculate` op.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
Variadic<AnyTorchType>:$results
|
||||||
|
);
|
||||||
|
let results = (outs);
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
attr-dict ($results^ `:` type($results))?
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_DtypeCalculateYieldDtypesOp : Torch_Op<"dtype.calculate.yield.dtypes", [
|
||||||
|
DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>,
|
||||||
|
Terminator,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly,
|
||||||
|
HasParent<"::mlir::torch::Torch::DtypeCalculateOp">]> {
|
||||||
|
let summary = "yield-like terminator for torch.dtype.calculate shape region";
|
||||||
|
let description = [{
|
||||||
|
This op terminates the `dtypeCalculation` region of a
|
||||||
|
`torch.dtype.calculate` op.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
Variadic<Torch_IntType>:$results
|
||||||
|
);
|
||||||
|
let results = (outs);
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
attr-dict ($results^ `:` type($results))?
|
||||||
|
}];
|
||||||
|
|
||||||
|
let hasVerifier = 1;
|
||||||
|
}
|
||||||
|
|
||||||
#endif // TORCH_OPS
|
#endif // TORCH_OPS
|
||||||
|
|
|
@ -390,6 +390,9 @@ def AnyTorchListOfTensorType:
|
||||||
def AnyTorchListOfOptionalTensorType :
|
def AnyTorchListOfOptionalTensorType :
|
||||||
ListOf<[AnyTorchOptionalTensorType],
|
ListOf<[AnyTorchOptionalTensorType],
|
||||||
"Any optional tensor list type (Tensor?[])">;
|
"Any optional tensor list type (Tensor?[])">;
|
||||||
|
def AnyTorchListOfOptionalIntType :
|
||||||
|
ListOf<[AnyTorchOptionalIntType],
|
||||||
|
"List of optional ints type (int?[])">;
|
||||||
def AnyTorchOptionalListOfTorchIntType : OptionalOf<AnyTorchListOfTorchIntType, "Optional torch int list type (int[]?)">;
|
def AnyTorchOptionalListOfTorchIntType : OptionalOf<AnyTorchListOfTorchIntType, "Optional torch int list type (int[]?)">;
|
||||||
def AnyTorchOptionalListOfTorchFloatType : OptionalOf<AnyTorchListOfTorchFloatType, "Optional torch float list type (float[]?)">;
|
def AnyTorchOptionalListOfTorchFloatType : OptionalOf<AnyTorchListOfTorchFloatType, "Optional torch float list type (float[]?)">;
|
||||||
// Note: TorchScript does not consider !torch.bool to be a Scalar.
|
// Note: TorchScript does not consider !torch.bool to be a Scalar.
|
||||||
|
|
|
@ -80,6 +80,9 @@ void createTorchSimplificationPipeline(
|
||||||
/// Creates a pipeline that refines shapes of tensor operations in the program.
|
/// Creates a pipeline that refines shapes of tensor operations in the program.
|
||||||
void createTorchShapeRefinementPipeline(OpPassManager &pm);
|
void createTorchShapeRefinementPipeline(OpPassManager &pm);
|
||||||
|
|
||||||
|
/// Creates a pipeline that refines dtype of tensor operations in the program.
|
||||||
|
void createTorchDtypeRefinementPipeline(OpPassManager &pm);
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();
|
std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<func::FuncOp>> createRefineTypesPass();
|
std::unique_ptr<OperationPass<func::FuncOp>> createRefineTypesPass();
|
||||||
|
@ -102,7 +105,13 @@ std::unique_ptr<OperationPass<ModuleOp>> createReifyShapeCalculationsPass();
|
||||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||||
createSimplifyShapeCalculationsPass();
|
createSimplifyShapeCalculationsPass();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<func::FuncOp>> createDropShapeCalculationsPass();
|
std::unique_ptr<OperationPass<ModuleOp>> createReifyDtypeCalculationsPass();
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||||
|
createSimplifyDtypeCalculationsPass();
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||||
|
createDropAbstractInterpCalculationsPass();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
createEraseModuleInitializerPass();
|
createEraseModuleInitializerPass();
|
||||||
|
@ -113,7 +122,7 @@ createLowerToBackendContractPass(int maxIterations, bool decompose,
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> createVerifyBackendContractPass();
|
std::unique_ptr<OperationPass<ModuleOp>> createVerifyBackendContractPass();
|
||||||
|
|
||||||
StringRef getShapeLibrary();
|
StringRef getAbstractInterpLibrary();
|
||||||
|
|
||||||
} // namespace Torch
|
} // namespace Torch
|
||||||
|
|
||||||
|
|
|
@ -239,7 +239,7 @@ def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "func::FuncOp"> {
|
||||||
}
|
}
|
||||||
|
|
||||||
def ReifyShapeCalculations : Pass<"torch-reify-shape-calculations", "ModuleOp"> {
|
def ReifyShapeCalculations : Pass<"torch-reify-shape-calculations", "ModuleOp"> {
|
||||||
let summary = "Decompose complicated torch operations";
|
let summary = "Reify shape calculations.";
|
||||||
let constructor = "mlir::torch::Torch::createReifyShapeCalculationsPass()";
|
let constructor = "mlir::torch::Torch::createReifyShapeCalculationsPass()";
|
||||||
let description = [{
|
let description = [{
|
||||||
}];
|
}];
|
||||||
|
@ -253,9 +253,23 @@ def SimplifyShapeCalculations : Pass<"torch-simplify-shape-calculations", "func:
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def DropShapeCalculations : Pass<"torch-drop-shape-calculations", "func::FuncOp"> {
|
def ReifyDtypeCalculations : Pass<"torch-reify-dtype-calculations", "ModuleOp"> {
|
||||||
let summary = "Drop reified shape calculations.";
|
let summary = "Reify dtype calculations.";
|
||||||
let constructor = "mlir::torch::Torch::createDropShapeCalculationsPass()";
|
let constructor = "mlir::torch::Torch::createReifyDtypeCalculationsPass()";
|
||||||
|
let description = [{
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def SimplifyDtypeCalculations : Pass<"torch-simplify-dtype-calculations", "func::FuncOp"> {
|
||||||
|
let summary = "Simplify reified dtype calculations.";
|
||||||
|
let constructor = "mlir::torch::Torch::createSimplifyDtypeCalculationsPass()";
|
||||||
|
let description = [{
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def DropAbstractInterpCalculations : Pass<"torch-drop-abstract-interp-calculations", "func::FuncOp"> {
|
||||||
|
let summary = "Drop reified abstract interpretation calculations.";
|
||||||
|
let constructor = "mlir::torch::Torch::createDropAbstractInterpCalculationsPass()";
|
||||||
let description = [{
|
let description = [{
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,7 +11,6 @@
|
||||||
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
#include "mlir/Support/LLVM.h"
|
|
||||||
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
|
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
@ -35,9 +34,24 @@ Type getTypeForTorchType(
|
||||||
MLIRContext *context, Type type,
|
MLIRContext *context, Type type,
|
||||||
mlir::IntegerType::SignednessSemantics signedness = IntegerType::Signed);
|
mlir::IntegerType::SignednessSemantics signedness = IntegerType::Signed);
|
||||||
|
|
||||||
Type getTorchTypeForScalarType(MLIRContext *context,
|
FailureOr<Type> getTorchTypeForScalarType(MLIRContext *context,
|
||||||
torch_upstream::ScalarType dtypeInt);
|
torch_upstream::ScalarType dtypeInt);
|
||||||
|
|
||||||
|
// This is the type rule used for deciding dtype for:
|
||||||
|
// 1. A new tensor created from given data.
|
||||||
|
// 2. The scalar type for type promotion when a scalar is an operand of a tensor
|
||||||
|
// operation (such as AtenMulScalarOp, AtenAddScalarOp etc)
|
||||||
|
// If the data is floating-point, the `dtype` is inferred to be the
|
||||||
|
// default dtype, see `torch.get_default_dtype`.
|
||||||
|
Type getDefaultDtypeForTorchScalar(Type type);
|
||||||
|
|
||||||
|
// This is the type rule used for deciding builtin type for:
|
||||||
|
// 1. The dtype of the result tensor when converting a Scalar into a Tensor like
|
||||||
|
// PrimNumToTensorScalarOp.
|
||||||
|
// 2. The scalar type for type promotion when a scalar is an operand of scalar
|
||||||
|
// only operation like AtenAddOp.
|
||||||
|
Type getBuiltInTypeForTorchScalar(Type type);
|
||||||
|
|
||||||
Value getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
|
Value getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
|
||||||
Type dtype);
|
Type dtype);
|
||||||
// Helper to convert a tensor to a specific scalar type.
|
// Helper to convert a tensor to a specific scalar type.
|
||||||
|
|
|
@ -2294,24 +2294,40 @@ OpFoldResult PrimMinSelfIntOp::fold(ArrayRef<Attribute> operands) {
|
||||||
// ShapeCalculateOp
|
// ShapeCalculateOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
void ShapeCalculateOp::getSuccessorRegions(
|
template <typename CalculateOp>
|
||||||
Optional<unsigned> index, ArrayRef<Attribute> operands,
|
static void
|
||||||
|
getSuccessorRegionsForCalculateOp(CalculateOp op, Optional<unsigned> index,
|
||||||
|
ArrayRef<Attribute> operands,
|
||||||
SmallVectorImpl<RegionSuccessor> ®ions) {
|
SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||||
(void)operands;
|
|
||||||
|
|
||||||
if (!index.has_value()) {
|
if (!index.has_value()) {
|
||||||
// First thing the op does is branch into the shape calculation.
|
// First thing the op does is branch into the calculation.
|
||||||
regions.emplace_back(&getShapeCalculation());
|
regions.emplace_back(&op.getCalculation());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (*index == 0) {
|
if (*index == 0) {
|
||||||
// Body returns control to the outer op, passing through results.
|
// Body returns control to the outer op, passing through results.
|
||||||
regions.emplace_back(getResults());
|
regions.emplace_back(op.getResults());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
assert(*index == 1);
|
assert(*index == 1);
|
||||||
// Shape calculation branches to the body.
|
// Calculation branches to the body.
|
||||||
regions.emplace_back(&getBody());
|
regions.emplace_back(&op.getBody());
|
||||||
|
}
|
||||||
|
|
||||||
|
void ShapeCalculateOp::getSuccessorRegions(
|
||||||
|
Optional<unsigned> index, ArrayRef<Attribute> operands,
|
||||||
|
SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||||
|
getSuccessorRegionsForCalculateOp(*this, index, operands, regions);
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// DtypeCalculateOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void DtypeCalculateOp::getSuccessorRegions(
|
||||||
|
Optional<unsigned> index, ArrayRef<Attribute> operands,
|
||||||
|
SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||||
|
getSuccessorRegionsForCalculateOp(*this, index, operands, regions);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -2333,6 +2349,25 @@ LogicalResult ShapeCalculateYieldShapesOp::verify() {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// DtypeCalculateYieldDtypesOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
MutableOperandRange DtypeCalculateYieldDtypesOp::getMutableSuccessorOperands(
|
||||||
|
Optional<unsigned> index) {
|
||||||
|
// The dtype operands don't get forwarded to the body.
|
||||||
|
// MutableOperandRange always has an owning operation, even if empty, so
|
||||||
|
// create a 0-length range.
|
||||||
|
return MutableOperandRange(*this, /*start=*/0, /*length=*/0);
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult DtypeCalculateYieldDtypesOp::verify() {
|
||||||
|
auto parent = cast<DtypeCalculateOp>(getOperation()->getParentOp());
|
||||||
|
if (parent.getNumResults() != getNumOperands())
|
||||||
|
return emitOpError("expected number of dtypes to match number of results");
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// GlobalSlotModuleInitializerOp
|
// GlobalSlotModuleInitializerOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -8,7 +8,7 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
//
|
//
|
||||||
// This file is auto-generated! Do not edit!!!
|
// This file is auto-generated! Do not edit!!!
|
||||||
// Generated with the script `build_tools/update_shape_lib.sh`.
|
// Generated with the script `build_tools/update_abstract_interp_lib.sh`.
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
StringRef mlir::torch::Torch::getShapeLibrary() {
|
StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
#ifndef _MSC_VER
|
#ifndef _MSC_VER
|
||||||
#pragma clang diagnostic push
|
#pragma clang diagnostic push
|
||||||
#pragma clang diagnostic ignored "-Woverlength-strings"
|
#pragma clang diagnostic ignored "-Woverlength-strings"
|
||||||
|
@ -5470,6 +5470,25 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.tanh\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||||
|
" %int6 = torch.constant.int 6\n"
|
||||||
|
" %int15 = torch.constant.int 15\n"
|
||||||
|
" %true = torch.constant.bool true\n"
|
||||||
|
" %int7 = torch.constant.int 7\n"
|
||||||
|
" %0 = torch.aten.eq.int %arg1, %int7 : !torch.int, !torch.int -> !torch.bool\n"
|
||||||
|
" %1 = torch.prim.If %0 -> (!torch.bool) {\n"
|
||||||
|
" torch.prim.If.yield %true : !torch.bool\n"
|
||||||
|
" } else {\n"
|
||||||
|
" %3 = torch.aten.eq.int %arg1, %int15 : !torch.int, !torch.int -> !torch.bool\n"
|
||||||
|
" torch.prim.If.yield %3 : !torch.bool\n"
|
||||||
|
" }\n"
|
||||||
|
" %2 = torch.prim.If %1 -> (!torch.int) {\n"
|
||||||
|
" torch.prim.If.yield %arg1 : !torch.int\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.If.yield %int6 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" return %2 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.erf\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.erf\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
|
@ -5618,11 +5637,11 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.to.dtype\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.prims.convert_element_type\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.prims.convert_element_type\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.to.dtype\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> {\n"
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
@ -5721,6 +5740,23 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union<float, int>, %arg3: !torch.union<float, int>) -> !torch.int {\n"
|
||||||
|
" %none = torch.constant.none\n"
|
||||||
|
" %0 = torch.prim.ListConstruct %arg0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
|
||||||
|
" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union<float, int>) -> !torch.int\n"
|
||||||
|
" %2 = torch.prim.ListConstruct %arg1, %1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||||
|
" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %2) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||||
|
" return %3 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%arg0: !torch.list<optional<int>>, %arg1: !torch.list<int>) -> !torch.int {\n"
|
||||||
|
" %0 = torch.promote_dtypes %arg0, %arg1 : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||||
|
" return %0 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0: !torch.union<float, int>) -> !torch.int {\n"
|
||||||
|
" %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.union<float, int> -> !torch.tensor\n"
|
||||||
|
" %1 = torch.prim.dtype %0 : !torch.tensor -> !torch.int\n"
|
||||||
|
" return %1 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.leaky_relu\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.leaky_relu\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
|
@ -6414,6 +6450,12 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
|
||||||
|
" %0 = torch.prim.ListConstruct %arg0, %arg2 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
|
||||||
|
" %1 = torch.prim.ListConstruct %arg1, %arg3 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||||
|
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||||
|
" return %2 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.atan2\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.atan2\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
|
@ -7086,6 +7128,15 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
|
||||||
" %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
" %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
||||||
" return %4 : !torch.list<int>\n"
|
" return %4 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.add\"(%arg0: !torch.union<float, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
|
||||||
|
" %none = torch.constant.none\n"
|
||||||
|
" %0 = torch.prim.ListConstruct %none, %none : (!torch.none, !torch.none) -> !torch.list<optional<int>>\n"
|
||||||
|
" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union<float, int>) -> !torch.int\n"
|
||||||
|
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
|
||||||
|
" %3 = torch.prim.ListConstruct %1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||||
|
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||||
|
" return %4 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
"}\n"
|
"}\n"
|
||||||
"";
|
"";
|
||||||
// clang-format on
|
// clang-format on
|
|
@ -1,7 +1,7 @@
|
||||||
add_mlir_library(TorchMLIRTorchPasses
|
add_mlir_library(TorchMLIRTorchPasses
|
||||||
AdjustCallingConventions.cpp
|
AdjustCallingConventions.cpp
|
||||||
DecomposeComplexOps.cpp
|
DecomposeComplexOps.cpp
|
||||||
DropShapeCalculations.cpp
|
DropAbstractInterpCalculations.cpp
|
||||||
EraseModuleInitializer.cpp
|
EraseModuleInitializer.cpp
|
||||||
Passes.cpp
|
Passes.cpp
|
||||||
GlobalizeObjectGraph.cpp
|
GlobalizeObjectGraph.cpp
|
||||||
|
@ -13,8 +13,12 @@ add_mlir_library(TorchMLIRTorchPasses
|
||||||
RefinePublicReturn.cpp
|
RefinePublicReturn.cpp
|
||||||
RefineTypes.cpp
|
RefineTypes.cpp
|
||||||
ReifyShapeCalculations.cpp
|
ReifyShapeCalculations.cpp
|
||||||
ShapeLibrary.cpp
|
ReifyDtypeCalculations.cpp
|
||||||
|
ReifyAbstractInterpCalculationsUtils.cpp
|
||||||
|
AbstractInterpLibrary.cpp
|
||||||
SimplifyShapeCalculations.cpp
|
SimplifyShapeCalculations.cpp
|
||||||
|
SimplifyDtypeCalculations.cpp
|
||||||
|
SimplifyAbstractInterpCalculationsUtils.cpp
|
||||||
|
|
||||||
ADDITIONAL_HEADER_DIRS
|
ADDITIONAL_HEADER_DIRS
|
||||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/Torch/Transforms
|
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/Torch/Transforms
|
||||||
|
|
|
@ -9,29 +9,22 @@
|
||||||
|
|
||||||
#include "PassDetail.h"
|
#include "PassDetail.h"
|
||||||
|
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
||||||
#include "mlir/IR/BlockAndValueMapping.h"
|
|
||||||
#include "mlir/IR/Builders.h"
|
|
||||||
#include "mlir/Parser/Parser.h"
|
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
#include "mlir/Transforms/InliningUtils.h"
|
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
|
||||||
#include "llvm/ADT/StringExtras.h"
|
|
||||||
#include "llvm/ADT/StringSet.h"
|
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::torch;
|
using namespace mlir::torch;
|
||||||
using namespace mlir::torch::Torch;
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class DropShapeCalculateOp : public OpConversionPattern<ShapeCalculateOp> {
|
template <typename CalculateOp>
|
||||||
|
class DropCalculateOp : public OpConversionPattern<CalculateOp> {
|
||||||
public:
|
public:
|
||||||
using OpConversionPattern::OpConversionPattern;
|
using OpConversionPattern<CalculateOp>::OpConversionPattern;
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(ShapeCalculateOp op, OpAdaptor adaptor,
|
matchAndRewrite(CalculateOp op, typename CalculateOp::Adaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Block *block = &op.getBody().front();
|
Block *block = &op.getBody().front();
|
||||||
Operation *terminator = block->getTerminator();
|
Operation *terminator = block->getTerminator();
|
||||||
|
@ -45,16 +38,18 @@ public:
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class DropShapeCalculationsPass
|
class DropAbstractInterpCalculationsPass
|
||||||
: public DropShapeCalculationsBase<DropShapeCalculationsPass> {
|
: public DropAbstractInterpCalculationsBase<
|
||||||
|
DropAbstractInterpCalculationsPass> {
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
MLIRContext *context = &getContext();
|
MLIRContext *context = &getContext();
|
||||||
|
|
||||||
RewritePatternSet patterns(context);
|
RewritePatternSet patterns(context);
|
||||||
patterns.insert<DropShapeCalculateOp>(context);
|
patterns.insert<DropCalculateOp<DtypeCalculateOp>>(context);
|
||||||
|
patterns.insert<DropCalculateOp<ShapeCalculateOp>>(context);
|
||||||
ConversionTarget target(*context);
|
ConversionTarget target(*context);
|
||||||
target.addLegalDialect<Torch::TorchDialect>();
|
target.addLegalDialect<Torch::TorchDialect>();
|
||||||
target.addIllegalOp<ShapeCalculateOp>();
|
target.addIllegalOp<DtypeCalculateOp, ShapeCalculateOp>();
|
||||||
target.addLegalOp<func::FuncOp>();
|
target.addLegalOp<func::FuncOp>();
|
||||||
|
|
||||||
if (failed(applyPartialConversion(getOperation(), target,
|
if (failed(applyPartialConversion(getOperation(), target,
|
||||||
|
@ -66,6 +61,6 @@ class DropShapeCalculationsPass
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||||
mlir::torch::Torch::createDropShapeCalculationsPass() {
|
mlir::torch::Torch::createDropAbstractInterpCalculationsPass() {
|
||||||
return std::make_unique<DropShapeCalculationsPass>();
|
return std::make_unique<DropAbstractInterpCalculationsPass>();
|
||||||
}
|
}
|
|
@ -90,8 +90,8 @@ static LogicalResult checkType(Operation *op, Type type,
|
||||||
->emitError(
|
->emitError(
|
||||||
"unsupported by backend contract: tensor with unknown rank")
|
"unsupported by backend contract: tensor with unknown rank")
|
||||||
.attachNote()
|
.attachNote()
|
||||||
.append("this is likely due to a missing shape transfer function "
|
.append("this is likely due to a missing transfer function "
|
||||||
"in shape_lib_gen.py");
|
"in abstract_interp_lib_gen.py");
|
||||||
} else {
|
} else {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
|
@ -121,6 +121,7 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
|
||||||
// inference), because Torch type promotion rules actually depend on the shape
|
// inference), because Torch type promotion rules actually depend on the shape
|
||||||
// of the operand.
|
// of the operand.
|
||||||
createTorchShapeRefinementPipeline(pm);
|
createTorchShapeRefinementPipeline(pm);
|
||||||
|
createTorchDtypeRefinementPipeline(pm);
|
||||||
// Refine types in the program, which mainly means inferring dtypes of ops.
|
// Refine types in the program, which mainly means inferring dtypes of ops.
|
||||||
pm.addNestedPass<func::FuncOp>(Torch::createRefineTypesPass());
|
pm.addNestedPass<func::FuncOp>(Torch::createRefineTypesPass());
|
||||||
// Propagate to ABI return types the shape/dtype information discovered by
|
// Propagate to ABI return types the shape/dtype information discovered by
|
||||||
|
@ -137,23 +138,41 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlir::torch::Torch::createTorchShapeRefinementPipeline(OpPassManager &pm) {
|
static void createRefinementPipeline(
|
||||||
// Reify the shape functions for each op that is present in the shape library.
|
mlir::OpPassManager &pm,
|
||||||
pm.addPass(Torch::createReifyShapeCalculationsPass());
|
llvm::function_ref<std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>()>
|
||||||
|
reifyCalculationsPass,
|
||||||
|
llvm::function_ref<
|
||||||
|
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>()>
|
||||||
|
simplifyCalculationsPass) {
|
||||||
|
// Reify the library functions for each op that is present in the library.
|
||||||
|
pm.addPass(reifyCalculationsPass());
|
||||||
|
|
||||||
// Inline the shape functions to enable analysis and transformation.
|
// Inline the library functions to enable analysis and transformation.
|
||||||
// TODO: Only inline shape functions (this will currently inline everything).
|
// TODO: Only inline library functions (this will currently inline
|
||||||
pm.addPass(createInlinerPass());
|
// everything).
|
||||||
|
pm.addPass(mlir::createInlinerPass());
|
||||||
|
|
||||||
// Now, try to simplify shape calculations. This is unfortunately a "optimize
|
// Now, try to simplify calculations. This is unfortunately a "optimize
|
||||||
// as hard as possible" kind of thing, so it's inherently somewhat brittle.
|
// as hard as possible" kind of thing, so it's inherently somewhat brittle.
|
||||||
// The idea is to keep strengthening what we do here to support the shape
|
// The idea is to keep strengthening what we do here to support the
|
||||||
// library. We don't need to support arbitrary programs, thankfully.
|
// library functions. We don't need to support arbitrary programs, thankfully.
|
||||||
pm.addNestedPass<func::FuncOp>(Torch::createSimplifyShapeCalculationsPass());
|
pm.addNestedPass<mlir::func::FuncOp>(simplifyCalculationsPass());
|
||||||
// Run CSE, then see if we can simplify further.
|
// Run CSE, then see if we can simplify further.
|
||||||
pm.addNestedPass<func::FuncOp>(createCSEPass());
|
pm.addNestedPass<mlir::func::FuncOp>(mlir::createCSEPass());
|
||||||
pm.addNestedPass<func::FuncOp>(Torch::createSimplifyShapeCalculationsPass());
|
pm.addNestedPass<mlir::func::FuncOp>(simplifyCalculationsPass());
|
||||||
|
|
||||||
// Drop shape calculations, leaving behind the shape-refined program.
|
// Drop calculations, leaving behind the-refined program.
|
||||||
pm.addNestedPass<func::FuncOp>(Torch::createDropShapeCalculationsPass());
|
pm.addNestedPass<mlir::func::FuncOp>(
|
||||||
|
mlir::torch::Torch::createDropAbstractInterpCalculationsPass());
|
||||||
|
}
|
||||||
|
|
||||||
|
void mlir::torch::Torch::createTorchShapeRefinementPipeline(OpPassManager &pm) {
|
||||||
|
createRefinementPipeline(pm, Torch::createReifyShapeCalculationsPass,
|
||||||
|
Torch::createSimplifyShapeCalculationsPass);
|
||||||
|
}
|
||||||
|
|
||||||
|
void mlir::torch::Torch::createTorchDtypeRefinementPipeline(OpPassManager &pm) {
|
||||||
|
createRefinementPipeline(pm, Torch::createReifyDtypeCalculationsPass,
|
||||||
|
Torch::createSimplifyDtypeCalculationsPass);
|
||||||
}
|
}
|
||||||
|
|
|
@ -491,44 +491,6 @@ private:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// This is the type rule used for deciding dtype for:
|
|
||||||
// 1. A new tensor created from given data.
|
|
||||||
// 2. The scalar type for type promotion when a scalar is an operand of a tensor
|
|
||||||
// operation (such as AtenMulScalarOp, AtenAddScalarOp etc)
|
|
||||||
// If the data is floating-point, the `dtype` is inferred to be the
|
|
||||||
// default dtype, see `torch.get_default_dtype`.
|
|
||||||
static Type getDefaultDtypeForTorchScalar(Type type) {
|
|
||||||
MLIRContext *context = type.getContext();
|
|
||||||
if (type.isa<Torch::FloatType>()) {
|
|
||||||
// For now, use float32 which is the initial default dtype returned by
|
|
||||||
// `torch.get_default_dtype`.
|
|
||||||
return Float32Type::get(context);
|
|
||||||
}
|
|
||||||
if (type.isa<Torch::IntType>())
|
|
||||||
return IntegerType::get(context, 64, IntegerType::Signed);
|
|
||||||
if (type.isa<Torch::BoolType>())
|
|
||||||
return IntegerType::get(context, 1);
|
|
||||||
llvm_unreachable(
|
|
||||||
"getDefaultDtypeForTorchScalar called on an unsupported type");
|
|
||||||
}
|
|
||||||
|
|
||||||
// This is the type rule used for deciding builtin type for:
|
|
||||||
// 1. The dtype of the result tensor when converting a Scalar into a Tensor like
|
|
||||||
// PrimNumToTensorScalarOp.
|
|
||||||
// 2. The scalar type for type promotion when a scalar is an operand of scalar
|
|
||||||
// only operation like AtenAddOp.
|
|
||||||
static Type getBuiltInTypeForTorchScalar(Type type) {
|
|
||||||
MLIRContext *context = type.getContext();
|
|
||||||
if (type.isa<Torch::FloatType>())
|
|
||||||
return Float64Type::get(context);
|
|
||||||
if (type.isa<Torch::IntType>())
|
|
||||||
return IntegerType::get(context, 64, IntegerType::Signed);
|
|
||||||
if (type.isa<Torch::BoolType>())
|
|
||||||
return IntegerType::get(context, 1);
|
|
||||||
llvm_unreachable(
|
|
||||||
"getBuiltInTypeForTorchScalar called on an unsupported type");
|
|
||||||
}
|
|
||||||
|
|
||||||
static torch_upstream::ResultTypeState
|
static torch_upstream::ResultTypeState
|
||||||
updateResultTypeState(Type scalarType,
|
updateResultTypeState(Type scalarType,
|
||||||
const torch_upstream::ResultTypeState &inState) {
|
const torch_upstream::ResultTypeState &inState) {
|
||||||
|
@ -583,8 +545,11 @@ static Type getPromotedResultScalarType(ArrayRef<Type> scalarTypes) {
|
||||||
state =
|
state =
|
||||||
updateResultTypeState(getBuiltInTypeForTorchScalar(scalarType), state);
|
updateResultTypeState(getBuiltInTypeForTorchScalar(scalarType), state);
|
||||||
}
|
}
|
||||||
return getTorchTypeForScalarType(scalarTypes[0].getContext(),
|
FailureOr<Type> result = getTorchTypeForScalarType(
|
||||||
result_type(state));
|
scalarTypes[0].getContext(), result_type(state));
|
||||||
|
if (failed(result))
|
||||||
|
return Type();
|
||||||
|
return *result;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns most generic type Type() if the tensor dtype is unknown.
|
// Returns most generic type Type() if the tensor dtype is unknown.
|
||||||
|
@ -707,9 +672,9 @@ void TypeAnalysis::visitOperation(Operation *op,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dtype is always float32, except for bfloat16, float16, float64 and nullptr.
|
// Dtype is always float32, except for bfloat16, float16, float64 and nullptr.
|
||||||
if (isa<AtenTanhOp, AtenExpOp, AtenExpm1Op, AtenSinOp, AtenCosOp,
|
if (isa<AtenExpOp, AtenExpm1Op, AtenSinOp, AtenCosOp, AtenSigmoidOp,
|
||||||
AtenSigmoidOp, AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op,
|
AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op, AtenLog1pOp,
|
||||||
AtenLog1pOp, AtenRsqrtOp, AtenErfOp, AtenSoftplusOp, AtenFrobeniusNormDimOp>(op)) {
|
AtenRsqrtOp, AtenErfOp, AtenSoftplusOp, AtenFrobeniusNormDimOp>(op)) {
|
||||||
ValueKnowledge knowledge =
|
ValueKnowledge knowledge =
|
||||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||||
Type dtype = operands[0]->getValue().dtype;
|
Type dtype = operands[0]->getValue().dtype;
|
||||||
|
@ -770,7 +735,7 @@ void TypeAnalysis::visitOperation(Operation *op,
|
||||||
if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp, AtenDivTensorOp,
|
if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp, AtenDivTensorOp,
|
||||||
AtenDivTensorModeOp, Aten__And__TensorOp, AtenMinimumOp,
|
AtenDivTensorModeOp, Aten__And__TensorOp, AtenMinimumOp,
|
||||||
AtenMaximumOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp,
|
AtenMaximumOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp,
|
||||||
AtenThresholdBackwardOp, AtenFloorDivideOp>(op)) {
|
AtenThresholdBackwardOp>(op)) {
|
||||||
auto knowledge =
|
auto knowledge =
|
||||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||||
knowledge.dtype = getPromotedResultType(
|
knowledge.dtype = getPromotedResultType(
|
||||||
|
@ -816,7 +781,7 @@ void TypeAnalysis::visitOperation(Operation *op,
|
||||||
// Promote LHS with scalar RHS.
|
// Promote LHS with scalar RHS.
|
||||||
if (isa<AtenAddScalarOp, AtenSubScalarOp, AtenMulScalarOp, AtenDivScalarOp,
|
if (isa<AtenAddScalarOp, AtenSubScalarOp, AtenMulScalarOp, AtenDivScalarOp,
|
||||||
AtenFmodScalarOp, AtenFloorDivideScalarOp, AtenPowTensorScalarOp,
|
AtenFmodScalarOp, AtenFloorDivideScalarOp, AtenPowTensorScalarOp,
|
||||||
AtenRsubScalarOp, AtenLeakyReluOp, AtenRemainderScalarOp>(op)) {
|
AtenLeakyReluOp, AtenRemainderScalarOp>(op)) {
|
||||||
auto lhs = operands[0]->getValue();
|
auto lhs = operands[0]->getValue();
|
||||||
Value scalar = op->getOperand(1);
|
Value scalar = op->getOperand(1);
|
||||||
auto knowledge =
|
auto knowledge =
|
||||||
|
|
|
@ -0,0 +1,290 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "ReifyAbstractInterpCalculationsUtils.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
|
#include "llvm/ADT/StringSet.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::torch;
|
||||||
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
|
static std::string getLibraryFunctionPrefix(LibraryFunctionKind libFuncKind) {
|
||||||
|
if (libFuncKind == LibraryFunctionKind::ShapeFunction)
|
||||||
|
return "__torch_mlir_shape_fn.";
|
||||||
|
else if (libFuncKind == LibraryFunctionKind::DtypeFunction)
|
||||||
|
return "__torch_mlir_dtype_fn.";
|
||||||
|
llvm_unreachable(
|
||||||
|
"`getLibraryFunctionPrefix` called with an unsupported `CalculateOp`");
|
||||||
|
}
|
||||||
|
|
||||||
|
static Operation *createCalculateOp(OpBuilder &b, Location loc,
|
||||||
|
TypeRange resultTypes,
|
||||||
|
LibraryFunctionKind libFuncKind) {
|
||||||
|
if (libFuncKind == LibraryFunctionKind::ShapeFunction)
|
||||||
|
return b.create<ShapeCalculateOp>(loc, resultTypes);
|
||||||
|
else if (libFuncKind == LibraryFunctionKind::DtypeFunction)
|
||||||
|
return b.create<DtypeCalculateOp>(loc, resultTypes);
|
||||||
|
llvm_unreachable(
|
||||||
|
"`createCalculateOp` called with an unsupported `LibraryFunctionKind`");
|
||||||
|
}
|
||||||
|
|
||||||
|
static Operation *createCalculateYieldOp(OpBuilder &b, Location loc,
|
||||||
|
ValueRange results,
|
||||||
|
LibraryFunctionKind libFuncKind) {
|
||||||
|
if (libFuncKind == LibraryFunctionKind::ShapeFunction)
|
||||||
|
return b.create<ShapeCalculateYieldOp>(loc, results);
|
||||||
|
else if (libFuncKind == LibraryFunctionKind::DtypeFunction)
|
||||||
|
return b.create<DtypeCalculateYieldOp>(loc, results);
|
||||||
|
llvm_unreachable("`createCalculateYieldOp` called with an unsupported "
|
||||||
|
"`LibraryFunctionKind`");
|
||||||
|
}
|
||||||
|
|
||||||
|
static Operation *
|
||||||
|
createCalculateYieldCalculationOp(OpBuilder &b, Location loc,
|
||||||
|
ValueRange results,
|
||||||
|
LibraryFunctionKind libFuncKind) {
|
||||||
|
if (libFuncKind == LibraryFunctionKind::ShapeFunction)
|
||||||
|
return b.create<ShapeCalculateYieldShapesOp>(loc, results);
|
||||||
|
else if (libFuncKind == LibraryFunctionKind::DtypeFunction)
|
||||||
|
return b.create<DtypeCalculateYieldDtypesOp>(loc, results);
|
||||||
|
llvm_unreachable("`createCalculateYieldCalculationOp` called with an "
|
||||||
|
"unsupported `LibraryFunctionKind`");
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult Torch::wrapWithCalculateOpIfLibraryFunctionAvailable(
|
||||||
|
Operation *op, ModuleOp library, LibraryFunctionKind libFuncKind,
|
||||||
|
SmallVector<std::string> &libFuncNamesUsed,
|
||||||
|
function_ref<FailureOr<SmallVector<Value>>(OpBuilder &, Location,
|
||||||
|
ValueRange, func::FuncOp)>
|
||||||
|
libFuncArgsBuilder) {
|
||||||
|
Location loc = op->getLoc();
|
||||||
|
MLIRContext *context = op->getContext();
|
||||||
|
auto name = op->getName().stripDialect();
|
||||||
|
// For value-semantic variant ops, i.e. valsem-ops (ops that are
|
||||||
|
// mechanically consistent with existing torch conventions of in-place vs.
|
||||||
|
// out-of-place (value-semantic) variants), remove the prefix when
|
||||||
|
// looking them up in the library.
|
||||||
|
if (name.startswith("valsem."))
|
||||||
|
name = name.drop_front(strlen("valsem."));
|
||||||
|
std::string libFuncName =
|
||||||
|
(getLibraryFunctionPrefix(libFuncKind) + Twine(name)).str();
|
||||||
|
auto libFunc = library.lookupSymbol<func::FuncOp>(libFuncName);
|
||||||
|
if (!libFunc)
|
||||||
|
return success();
|
||||||
|
libFuncNamesUsed.push_back(libFuncName);
|
||||||
|
OpBuilder b(op);
|
||||||
|
Operation *calculate =
|
||||||
|
createCalculateOp(b, loc, op->getResultTypes(), libFuncKind);
|
||||||
|
op->replaceAllUsesWith(calculate);
|
||||||
|
{
|
||||||
|
// Move the op into the body of the `torch.{libFuncType}.calculate` op
|
||||||
|
// and yield its results.
|
||||||
|
OpBuilder b(context);
|
||||||
|
Block *bodyBlock = b.createBlock(&calculate->getRegion(0));
|
||||||
|
op->moveBefore(bodyBlock, bodyBlock->end());
|
||||||
|
b.setInsertionPointAfter(op);
|
||||||
|
createCalculateYieldOp(b, loc, op->getResults(), libFuncKind);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
OpBuilder b(context);
|
||||||
|
b.createBlock(&calculate->getRegion(1));
|
||||||
|
// Create the call to the library function!
|
||||||
|
FailureOr<SmallVector<Value>> libFuncArgs =
|
||||||
|
libFuncArgsBuilder(b, loc, op->getOperands(), libFunc);
|
||||||
|
if (failed(libFuncArgs))
|
||||||
|
return failure();
|
||||||
|
auto call = b.create<mlir::func::CallOp>(loc, libFunc, *libFuncArgs);
|
||||||
|
|
||||||
|
// Python models multiple results with a tuple, so we need to unpack it
|
||||||
|
// if the op has multiple results.
|
||||||
|
SmallVector<Value> unpackedResults;
|
||||||
|
assert(call.getNumResults() == 1 &&
|
||||||
|
"Multiple results are packed in a tuple in Python!");
|
||||||
|
Value result = call.getResult(0);
|
||||||
|
if (auto tupleType = result.getType().dyn_cast<Torch::TupleType>()) {
|
||||||
|
auto unpack = b.create<PrimTupleUnpackOp>(
|
||||||
|
loc, tupleType.getContainedTypes(), result);
|
||||||
|
llvm::append_range(unpackedResults, unpack.getResults());
|
||||||
|
} else {
|
||||||
|
unpackedResults.push_back(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Terminate the region.
|
||||||
|
createCalculateYieldCalculationOp(b, loc, unpackedResults, libFuncKind);
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
void Torch::importLibraryFunctions(ModuleOp module, ModuleOp library,
|
||||||
|
SmallVector<std::string> functionsNeeded) {
|
||||||
|
// Import just the functions we need. This includes transitive callees,
|
||||||
|
// so we use a worklist algorithm.
|
||||||
|
llvm::StringSet<> importedFunctions;
|
||||||
|
while (!functionsNeeded.empty()) {
|
||||||
|
std::string symName = functionsNeeded.pop_back_val();
|
||||||
|
if (importedFunctions.contains(symName))
|
||||||
|
continue;
|
||||||
|
auto func = library.lookupSymbol<func::FuncOp>(symName);
|
||||||
|
assert(func && "broken library");
|
||||||
|
// Move the function from the library to the module this pass
|
||||||
|
// is running on. (this mutates the library, but we re-parse it each time
|
||||||
|
// so this is safe to do).
|
||||||
|
func->moveBefore(&module.getBody()->front());
|
||||||
|
// Set the visibility to private so that the functions go away
|
||||||
|
// nicely after we are done with them.
|
||||||
|
func.setVisibility(SymbolTable::Visibility::Private);
|
||||||
|
// Continue the DFS.
|
||||||
|
importedFunctions.insert(symName);
|
||||||
|
func.walk([&](func::CallOp op) {
|
||||||
|
functionsNeeded.push_back(op.getCallee().str());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<Value> Torch::adjustFunctionArg(
|
||||||
|
OpBuilder &b, Location loc, Value operand, Type desiredType,
|
||||||
|
function_ref<Value(OpBuilder &, Location, Value, Type)> baseTransformation) {
|
||||||
|
operand = baseTransformation(b, loc, operand, desiredType);
|
||||||
|
|
||||||
|
// No need for adjustment if they already match.
|
||||||
|
auto operandType = operand.getType();
|
||||||
|
if (operandType == desiredType)
|
||||||
|
return operand;
|
||||||
|
|
||||||
|
if (desiredType.isa<Torch::AnyType>()) {
|
||||||
|
// Generator's are currently passed as Any because TorchScript cannot
|
||||||
|
// compile a function with Generator type arguments.
|
||||||
|
// Ignoring that hack, this is a correct handling of Any type should we need
|
||||||
|
// to actually support it in the future.
|
||||||
|
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
// !torch.union<int, float> is the type used for `Scalar` inputs. At
|
||||||
|
// compile time, such inputs will usually be resolved to an `int` or a `float`
|
||||||
|
// so we need to derefine to match the library function signature.
|
||||||
|
if (auto unionType = desiredType.dyn_cast<Torch::UnionType>()) {
|
||||||
|
if (llvm::all_of(unionType.getContainedTypes(), [](Type containedType) {
|
||||||
|
return containedType.isa<Torch::IntType, Torch::FloatType>();
|
||||||
|
}))
|
||||||
|
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the operand is NoneType, then we just need to derefine it to the
|
||||||
|
// optional type in the function signature.
|
||||||
|
if (operandType.isa<Torch::NoneType>()) {
|
||||||
|
assert(desiredType.isa<Torch::OptionalType>() &&
|
||||||
|
"Don't expect library functions to have NoneType parameters");
|
||||||
|
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the operand type is statically !torch.optional, then we need to do
|
||||||
|
// different things for the None and non-None cases.
|
||||||
|
// For the None case, we just need to derefine it to the desired type.
|
||||||
|
// For the non-None case, we need to unwrap the optional type and then adjust
|
||||||
|
// it recursively (which also takes care of derefining it to ultimate desired
|
||||||
|
// type).
|
||||||
|
// A case where this happens is `!torch.optional<vtensor>` ->
|
||||||
|
// `!torch.optional<list<int>>>`.
|
||||||
|
if (auto operandOptionalType = operandType.dyn_cast<Torch::OptionalType>()) {
|
||||||
|
if (desiredType.isa<Torch::OptionalType>()) {
|
||||||
|
// if optional is None:
|
||||||
|
// return derefine(None)
|
||||||
|
// else:
|
||||||
|
// return adjust(unchecked_cast(optional))
|
||||||
|
auto none = b.create<ConstantNoneOp>(loc);
|
||||||
|
auto isNone = b.create<Aten__Is__Op>(loc, operand, none);
|
||||||
|
auto primIf = b.create<PrimIfOp>(loc, desiredType, isNone);
|
||||||
|
{
|
||||||
|
Region &thenRegion = primIf.getThenRegion();
|
||||||
|
b.createBlock(&thenRegion, thenRegion.end());
|
||||||
|
auto derefineNone = b.create<DerefineOp>(loc, desiredType, none);
|
||||||
|
b.create<PrimIfYieldOp>(loc, ValueRange{derefineNone});
|
||||||
|
}
|
||||||
|
{
|
||||||
|
Region &elseRegion = primIf.getElseRegion();
|
||||||
|
b.createBlock(&elseRegion, elseRegion.end());
|
||||||
|
auto downcasted = b.create<PrimUncheckedCastOp>(
|
||||||
|
loc, operandOptionalType.getContainedType(), operand);
|
||||||
|
FailureOr<Value> adjusted = adjustFunctionArg(
|
||||||
|
b, loc, downcasted, desiredType, baseTransformation);
|
||||||
|
if (failed(adjusted))
|
||||||
|
return failure();
|
||||||
|
b.create<PrimIfYieldOp>(loc, *adjusted);
|
||||||
|
}
|
||||||
|
b.setInsertionPointAfter(primIf);
|
||||||
|
return primIf.getResult(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the desired type is OptionalType, then recursively adjust the operand to
|
||||||
|
// the contained type, then derefine it to `!torch.optional`. For example,
|
||||||
|
// `!torch.vtensor -> !torch.optional<list<int>>>`.
|
||||||
|
if (auto desiredOptionalType = desiredType.dyn_cast<Torch::OptionalType>()) {
|
||||||
|
FailureOr<Value> adjusted = adjustFunctionArg(
|
||||||
|
b, loc, operand, desiredOptionalType.getContainedType(),
|
||||||
|
baseTransformation);
|
||||||
|
if (failed(adjusted))
|
||||||
|
return failure();
|
||||||
|
return b.create<DerefineOp>(loc, desiredType, *adjusted).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto desiredListType = desiredType.dyn_cast<Torch::ListType>()) {
|
||||||
|
// Pseudocode:
|
||||||
|
//
|
||||||
|
// operand = ...
|
||||||
|
// adjusted_list = []
|
||||||
|
// for i in range(len(operand)):
|
||||||
|
// adjusted_list.append(adjust(operand[i]))
|
||||||
|
// return adjusted_list
|
||||||
|
auto providedType = operand.getType().cast<Torch::ListType>();
|
||||||
|
Value adjustedList =
|
||||||
|
b.create<PrimListConstructOp>(loc, desiredListType, ValueRange({}));
|
||||||
|
// Create a for-like PrimLoopOp.
|
||||||
|
Value maxTripCount = b.create<AtenLenTOp>(loc, operand);
|
||||||
|
Value cTrue = b.create<Torch::ConstantBoolOp>(loc, true);
|
||||||
|
auto loop = b.create<PrimLoopOp>(loc, TypeRange({}), maxTripCount,
|
||||||
|
/*initialCondition=*/cTrue,
|
||||||
|
/*iterArgsInit=*/ValueRange({}));
|
||||||
|
|
||||||
|
// Create the loop body.
|
||||||
|
{
|
||||||
|
OpBuilder::InsertionGuard guard(b);
|
||||||
|
Block *body =
|
||||||
|
b.createBlock(&loop.getRegion(), loop.getRegion().begin(),
|
||||||
|
TypeRange({b.getType<Torch::IntType>()}), {loc});
|
||||||
|
Value iterationNumber = body->getArgument(0);
|
||||||
|
Value element = b.create<Aten__Getitem__TOp>(
|
||||||
|
loc, providedType.getContainedType(), operand, iterationNumber);
|
||||||
|
FailureOr<Value> adjustedElement =
|
||||||
|
adjustFunctionArg(b, loc, element, desiredListType.getContainedType(),
|
||||||
|
baseTransformation);
|
||||||
|
if (failed(adjustedElement))
|
||||||
|
return failure();
|
||||||
|
b.create<AtenAppendTOp>(loc, adjustedList.getType(), adjustedList,
|
||||||
|
*adjustedElement);
|
||||||
|
b.create<PrimLoopConditionOp>(loc, /*shouldContinue=*/cTrue,
|
||||||
|
/*iterArgs=*/ValueRange({}));
|
||||||
|
}
|
||||||
|
|
||||||
|
return adjustedList;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The library functions use `float` where the operator
|
||||||
|
// signature uses `Scalar` (see comments in torch_ods_gen.py for
|
||||||
|
// explanation).
|
||||||
|
if (desiredType.isa<Torch::FloatType>() &&
|
||||||
|
operand.getType().isa<Torch::IntType>()) {
|
||||||
|
return b.create<AtenFloatScalarOp>(loc, desiredType, operand).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pass the operand as-is.
|
||||||
|
return operand;
|
||||||
|
}
|
|
@ -0,0 +1,67 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef TORCHMLIR_DIALECT_TORCH_TRANSFORMS_REIFY_ABSTRACT_INTERP_CALCULATIONS_UTILS_H
|
||||||
|
#define TORCHMLIR_DIALECT_TORCH_TRANSFORMS_REIFY_ABSTRACT_INTERP_CALCULATIONS_UTILS_H
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
#include "mlir/IR/OperationSupport.h"
|
||||||
|
#include "mlir/Support/LogicalResult.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace torch {
|
||||||
|
namespace Torch {
|
||||||
|
|
||||||
|
enum class LibraryFunctionKind { ShapeFunction, DtypeFunction, Decomposition };
|
||||||
|
|
||||||
|
// Searches the function library for an abstract interpretation function for
|
||||||
|
// `op`. If one is found, wraps the op in a `CalculateOp`, with the op placed in
|
||||||
|
// the first region, and a call to the abstract interpretation function is
|
||||||
|
// inserted into the second region.
|
||||||
|
//
|
||||||
|
// Note: this returns success if no abstract interpretation function is found,
|
||||||
|
// since some abstract interpretation functions (such as decompositions) are
|
||||||
|
// optional.
|
||||||
|
//
|
||||||
|
// Note: This function does *not* import the abstract interpretation function
|
||||||
|
// from the library into the IR.
|
||||||
|
LogicalResult wrapWithCalculateOpIfLibraryFunctionAvailable(
|
||||||
|
Operation *op, ModuleOp library, LibraryFunctionKind funcKind,
|
||||||
|
SmallVector<std::string> &libFuncNamesUsed,
|
||||||
|
function_ref<FailureOr<SmallVector<Value>>(OpBuilder &, Location,
|
||||||
|
ValueRange, func::FuncOp)>
|
||||||
|
libFuncArgsBuilder);
|
||||||
|
|
||||||
|
// Imports the functions in `functionsNeeded` from the library into the module.
|
||||||
|
// This function assumes that all functions needed exist in the library.
|
||||||
|
//
|
||||||
|
// Note: This function modifies the library.
|
||||||
|
void importLibraryFunctions(ModuleOp module, ModuleOp library,
|
||||||
|
SmallVector<std::string> functionsNeeded);
|
||||||
|
|
||||||
|
// Recursively adjust `operand` to match `desiredType`.
|
||||||
|
//
|
||||||
|
// This function by default handles a few types such as `UnionType`,
|
||||||
|
// `OptionalType`, and `ListType`, to name a few. Handling of base element types
|
||||||
|
// can be customized by defining `baseTransformation`, which gets called at the
|
||||||
|
// beginning of each recursive call. This function can be thought of as mapping
|
||||||
|
// `baseTransformation` across `UnionType/OptionalType/ListType`.
|
||||||
|
FailureOr<Value> adjustFunctionArg(
|
||||||
|
OpBuilder &b, Location loc, Value operand, Type desiredType,
|
||||||
|
function_ref<Value(OpBuilder &, Location, Value, Type)> baseTransformation =
|
||||||
|
[](OpBuilder &, Location, Value operand, Type) { return operand; });
|
||||||
|
} // namespace Torch
|
||||||
|
} // namespace torch
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // TORCHMLIR_DIALECT_TORCH_TRANSFORMS_REIFY_ABSTRACT_INTERP_CALCULATIONS_UTILS_H
|
|
@ -0,0 +1,136 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "PassDetail.h"
|
||||||
|
|
||||||
|
#include "ReifyAbstractInterpCalculationsUtils.h"
|
||||||
|
#include "mlir/Parser/Parser.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::torch;
|
||||||
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
|
static bool isTensorTypeOrWrappedTensorType(Type type) {
|
||||||
|
// Allowing tuples as arguments to dtype calculation functions can cause
|
||||||
|
// issues. For example, if an argument is a tuple of tensors and ints, there
|
||||||
|
// would be no way of differentiating the original ints from the ints created
|
||||||
|
// to represent the dtype and rank of the tensors. Therefore, to avoid this
|
||||||
|
// and keep things simple, the tuple type is not allowed. This works well in
|
||||||
|
// practice, since PyTorch op signatures don't seem to take tuples as inputs.
|
||||||
|
assert(!type.isa<Torch::TupleType>() &&
|
||||||
|
"dtype calculation functions are expected to not have tuples of "
|
||||||
|
"tensors as arguments");
|
||||||
|
|
||||||
|
if (type.isa<Torch::BaseTensorType>())
|
||||||
|
return true;
|
||||||
|
|
||||||
|
if (auto optionalType = type.dyn_cast<Torch::OptionalType>()) {
|
||||||
|
return isTensorTypeOrWrappedTensorType(optionalType.getContainedType());
|
||||||
|
} else if (auto listType = type.dyn_cast<Torch::ListType>()) {
|
||||||
|
return isTensorTypeOrWrappedTensorType(listType.getContainedType());
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Massage the op operands to match the dtype function signature.
|
||||||
|
// The dtype function generally takes the same operands as the op, with a few
|
||||||
|
// systematic modifications, such as replacing tensors with a rank and dtype
|
||||||
|
// argument.
|
||||||
|
static FailureOr<SmallVector<Value>>
|
||||||
|
dtypeFunctionArgsBuilder(OpBuilder &b, Location loc,
|
||||||
|
ValueRange originalOperands, func::FuncOp dtypeFunc) {
|
||||||
|
// Turns a tensor operand into an operand representing the rank of the tensor
|
||||||
|
auto rankArgAdjuster = [](OpBuilder &b, Location loc, Value operand,
|
||||||
|
Type desiredType) -> Value {
|
||||||
|
if (desiredType.isa<Torch::IntType>() &&
|
||||||
|
operand.getType().isa<Torch::BaseTensorType>()) {
|
||||||
|
auto sizeListType =
|
||||||
|
Torch::ListType::get(Torch::IntType::get(b.getContext()));
|
||||||
|
Value size = b.create<AtenSizeOp>(loc, sizeListType, operand);
|
||||||
|
return b.create<AtenLenTOp>(loc, desiredType, size);
|
||||||
|
}
|
||||||
|
return operand;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Turns a tensor operand into an operand representing the dtype of the tensor
|
||||||
|
auto dtypeArgAdjuster = [](OpBuilder &b, Location loc, Value operand,
|
||||||
|
Type desiredType) -> Value {
|
||||||
|
if (desiredType.isa<Torch::IntType>() &&
|
||||||
|
operand.getType().isa<Torch::BaseTensorType>()) {
|
||||||
|
return b.create<PrimDtypeOp>(loc, desiredType, operand);
|
||||||
|
}
|
||||||
|
return operand;
|
||||||
|
};
|
||||||
|
|
||||||
|
SmallVector<Value> dtypeFuncArgs;
|
||||||
|
ArrayRef<Type> desiredTypes = dtypeFunc.getArgumentTypes();
|
||||||
|
for (auto operand : originalOperands) {
|
||||||
|
assert(!desiredTypes.empty() &&
|
||||||
|
"`dtypeFunc` should have at least one argument for each argument in "
|
||||||
|
"`originalOperands`");
|
||||||
|
Type desiredType = desiredTypes.front();
|
||||||
|
if (isTensorTypeOrWrappedTensorType(operand.getType())) {
|
||||||
|
assert(desiredTypes.size() >= 2 &&
|
||||||
|
"`dtypeFunc` should have two arguments for each tensor argument "
|
||||||
|
"in `originalOperands`");
|
||||||
|
FailureOr<Value> rankArg, dtypeArg;
|
||||||
|
if (failed(rankArg = adjustFunctionArg(b, loc, operand, desiredType,
|
||||||
|
rankArgAdjuster)))
|
||||||
|
return failure();
|
||||||
|
desiredTypes = desiredTypes.drop_front();
|
||||||
|
desiredType = desiredTypes.front();
|
||||||
|
if (failed(dtypeArg = adjustFunctionArg(b, loc, operand, desiredType,
|
||||||
|
dtypeArgAdjuster)))
|
||||||
|
return failure();
|
||||||
|
dtypeFuncArgs.append({*rankArg, *dtypeArg});
|
||||||
|
} else {
|
||||||
|
FailureOr<Value> otherArg;
|
||||||
|
if (failed(otherArg = adjustFunctionArg(b, loc, operand, desiredType)))
|
||||||
|
return failure();
|
||||||
|
dtypeFuncArgs.push_back(*otherArg);
|
||||||
|
}
|
||||||
|
desiredTypes = desiredTypes.drop_front();
|
||||||
|
}
|
||||||
|
|
||||||
|
return dtypeFuncArgs;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class ReifyDtypeCalculationsPass
|
||||||
|
: public ReifyDtypeCalculationsBase<ReifyDtypeCalculationsPass> {
|
||||||
|
void runOnOperation() override {
|
||||||
|
MLIRContext *context = &getContext();
|
||||||
|
ModuleOp module = getOperation();
|
||||||
|
OwningOpRef<ModuleOp> library =
|
||||||
|
parseSourceString<ModuleOp>(getAbstractInterpLibrary(), context);
|
||||||
|
|
||||||
|
// Walk all the operations, and if we have a dtype function, wrap the op
|
||||||
|
// in a `torch.dtype.calculate` op.
|
||||||
|
SmallVector<std::string> functionsNeeded;
|
||||||
|
WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult {
|
||||||
|
return wrapWithCalculateOpIfLibraryFunctionAvailable(
|
||||||
|
op, *library, LibraryFunctionKind::DtypeFunction, functionsNeeded,
|
||||||
|
dtypeFunctionArgsBuilder);
|
||||||
|
});
|
||||||
|
|
||||||
|
if (walkResult.wasInterrupted())
|
||||||
|
return signalPassFailure();
|
||||||
|
importLibraryFunctions(module, *library, std::move(functionsNeeded));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
|
Torch::createReifyDtypeCalculationsPass() {
|
||||||
|
return std::make_unique<ReifyDtypeCalculationsPass>();
|
||||||
|
}
|
|
@ -9,209 +9,49 @@
|
||||||
|
|
||||||
#include "PassDetail.h"
|
#include "PassDetail.h"
|
||||||
|
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "ReifyAbstractInterpCalculationsUtils.h"
|
||||||
#include "mlir/IR/BlockAndValueMapping.h"
|
|
||||||
#include "mlir/IR/Builders.h"
|
|
||||||
#include "mlir/Parser/Parser.h"
|
#include "mlir/Parser/Parser.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
#include "mlir/Transforms/InliningUtils.h"
|
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
|
||||||
#include "llvm/ADT/StringExtras.h"
|
|
||||||
#include "llvm/ADT/StringSet.h"
|
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::torch;
|
using namespace mlir::torch;
|
||||||
using namespace mlir::torch::Torch;
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
static Value adjustShapeFunctionArg(Value operand, Type desiredType,
|
static FailureOr<SmallVector<Value>>
|
||||||
OpBuilder &b, Location loc);
|
shapeFunctionArgsBuilder(OpBuilder &b, Location loc,
|
||||||
|
ValueRange originalOperands, func::FuncOp shapeFunc) {
|
||||||
static Value adjustListArg(Value operand, Torch::ListType desiredType,
|
|
||||||
OpBuilder &b, Location loc) {
|
|
||||||
auto providedType = operand.getType().cast<Torch::ListType>();
|
|
||||||
|
|
||||||
// Pseudocode:
|
|
||||||
//
|
|
||||||
// operand = ...
|
|
||||||
// adjusted_list = []
|
|
||||||
// for i in range(len(operand)):
|
|
||||||
// adjusted_list.append(adjust(operand[i]))
|
|
||||||
// return adjusted_list
|
|
||||||
Value adjustedList =
|
|
||||||
b.create<PrimListConstructOp>(loc, desiredType, ValueRange({}));
|
|
||||||
// Create a for-like PrimLoopOp.
|
|
||||||
Value maxTripCount = b.create<AtenLenTOp>(loc, operand);
|
|
||||||
Value cTrue = b.create<Torch::ConstantBoolOp>(loc, true);
|
|
||||||
auto loop = b.create<PrimLoopOp>(loc, TypeRange({}), maxTripCount,
|
|
||||||
/*initialCondition=*/cTrue,
|
|
||||||
/*iterArgsInit=*/ValueRange({}));
|
|
||||||
|
|
||||||
// Create the loop body.
|
|
||||||
{
|
|
||||||
OpBuilder::InsertionGuard guard(b);
|
|
||||||
Block *body =
|
|
||||||
b.createBlock(&loop.getRegion(), loop.getRegion().begin(),
|
|
||||||
TypeRange({b.getType<Torch::IntType>()}), {loc});
|
|
||||||
Value iterationNumber = body->getArgument(0);
|
|
||||||
Value element = b.create<Aten__Getitem__TOp>(
|
|
||||||
loc, providedType.getContainedType(), operand, iterationNumber);
|
|
||||||
Value adjustedElement =
|
|
||||||
adjustShapeFunctionArg(element, desiredType.getContainedType(), b, loc);
|
|
||||||
b.create<AtenAppendTOp>(loc, adjustedList.getType(), adjustedList,
|
|
||||||
adjustedElement);
|
|
||||||
b.create<PrimLoopConditionOp>(loc, /*shouldContinue=*/cTrue,
|
|
||||||
/*iterArgs=*/ValueRange({}));
|
|
||||||
}
|
|
||||||
|
|
||||||
return adjustedList;
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value adjustShapeFunctionArg(Value operand, Type desiredType,
|
|
||||||
OpBuilder &b, Location loc) {
|
|
||||||
auto operandType = operand.getType();
|
|
||||||
|
|
||||||
// No need for adjustment if they already match.
|
|
||||||
if (operandType == desiredType)
|
|
||||||
return operand;
|
|
||||||
|
|
||||||
if (desiredType.isa<Torch::AnyType>()) {
|
|
||||||
// Generator's are currently passed as Any because TorchScript cannot
|
|
||||||
// compile a function with Generator type arguments.
|
|
||||||
// Ignoring that hack, this is a correct handling of Any type should we need
|
|
||||||
// to actually support it in the future.
|
|
||||||
return b.create<DerefineOp>(loc, desiredType, operand);
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the operand is NoneType, then we just need to derefine it to the
|
|
||||||
// optional type in the shape function signature.
|
|
||||||
if (operandType.isa<Torch::NoneType>()) {
|
|
||||||
assert(desiredType.isa<Torch::OptionalType>() &&
|
|
||||||
"Don't expect shape functions to have NoneType parameters");
|
|
||||||
return b.create<DerefineOp>(loc, desiredType, operand);
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the operand type is statically !torch.optional, then we need to do
|
|
||||||
// different things for the None and non-None cases.
|
|
||||||
// For the None case, we just need to derefine it to the desired type.
|
|
||||||
// For the non-None case, we need to unwrap the optional type and then adjust
|
|
||||||
// it recursively (which also takes care of derefining it to ultimate desired
|
|
||||||
// type).
|
|
||||||
// A case where this happens is `!torch.optional<vtensor>` ->
|
|
||||||
// `!torch.optional<list<int>>>`.
|
|
||||||
if (auto operandOptionalType = operandType.dyn_cast<Torch::OptionalType>()) {
|
|
||||||
if (desiredType.isa<Torch::OptionalType>()) {
|
|
||||||
// if optional is None:
|
|
||||||
// return derefine(None)
|
|
||||||
// else:
|
|
||||||
// return adjust(unchecked_cast(optional))
|
|
||||||
auto none = b.create<ConstantNoneOp>(loc);
|
|
||||||
auto isNone = b.create<Aten__Is__Op>(loc, operand, none);
|
|
||||||
auto primIf = b.create<PrimIfOp>(loc, desiredType, isNone);
|
|
||||||
{
|
|
||||||
Region &thenRegion = primIf.getThenRegion();
|
|
||||||
b.createBlock(&thenRegion, thenRegion.end());
|
|
||||||
auto derefineNone = b.create<DerefineOp>(loc, desiredType, none);
|
|
||||||
b.create<PrimIfYieldOp>(loc, ValueRange{derefineNone});
|
|
||||||
}
|
|
||||||
{
|
|
||||||
Region &elseRegion = primIf.getElseRegion();
|
|
||||||
b.createBlock(&elseRegion, elseRegion.end());
|
|
||||||
auto downcasted = b.create<PrimUncheckedCastOp>(
|
|
||||||
loc, operandOptionalType.getContainedType(), operand);
|
|
||||||
auto adjusted = adjustShapeFunctionArg(downcasted, desiredType, b, loc);
|
|
||||||
b.create<PrimIfYieldOp>(loc, adjusted);
|
|
||||||
}
|
|
||||||
b.setInsertionPointAfter(primIf);
|
|
||||||
return primIf.getResult(0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the desired type is OptionalType, then recursively adjust the operand to
|
|
||||||
// the contained type, then derefine it to `!torch.optional`. For example,
|
|
||||||
// `!torch.vtensor -> !torch.optional<list<int>>>`.
|
|
||||||
if (auto desiredOptionalType = desiredType.dyn_cast<Torch::OptionalType>()) {
|
|
||||||
auto adjusted = adjustShapeFunctionArg(
|
|
||||||
operand, desiredOptionalType.getContainedType(), b, loc);
|
|
||||||
return b.create<DerefineOp>(loc, desiredType, adjusted);
|
|
||||||
}
|
|
||||||
|
|
||||||
// The shape library functions have tensor operands replaced with
|
|
||||||
// `!torch.list<int>` types for the shape. Get the sizes.
|
|
||||||
if (operand.getType().isa<Torch::BaseTensorType>()) {
|
|
||||||
assert(desiredType.isa<Torch::ListType>() &&
|
|
||||||
"Don't expect shape functions to have tensor parameters");
|
|
||||||
return b.create<AtenSizeOp>(loc, desiredType, operand);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run this after `operand.getType().isa<Torch::BaseTensorType>()` so that
|
|
||||||
// `!torch.vtensor` -> `!torch.list<int>` is handled there specially
|
|
||||||
// first.
|
|
||||||
if (auto desiredListType = desiredType.dyn_cast<Torch::ListType>()) {
|
|
||||||
return adjustListArg(operand, desiredListType, b, loc);
|
|
||||||
}
|
|
||||||
|
|
||||||
// The shape library functions use `float` where the operator
|
|
||||||
// signature uses `Scalar` (see comments in torch_ods_gen.py for
|
|
||||||
// explanation).
|
|
||||||
if (desiredType.isa<Torch::FloatType>() &&
|
|
||||||
operand.getType().isa<Torch::IntType>()) {
|
|
||||||
return b.create<AtenFloatScalarOp>(loc, desiredType, operand);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pass the operand as-is.
|
|
||||||
return operand;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Populates the shape calculation region with a call to the shape function
|
|
||||||
// from the shape library.
|
|
||||||
static LogicalResult
|
|
||||||
populateShapeCalculationRegion(ShapeCalculateOp op, ValueRange originalOperands,
|
|
||||||
mlir::func::FuncOp shapeFunction) {
|
|
||||||
// Create a call to the shape function in the `shapeCalculation` region.
|
|
||||||
// We will import the callee from the shape library later.
|
|
||||||
OpBuilder b(op.getContext());
|
|
||||||
Location loc = op->getLoc();
|
|
||||||
b.createBlock(&op.getShapeCalculation());
|
|
||||||
// Massage the op operands to match the shape function signature.
|
// Massage the op operands to match the shape function signature.
|
||||||
// The shape function generally takes the same operands as the op, with a few
|
// The shape function generally takes the same operands as the op, with a few
|
||||||
// systematic modifications, such as replacing tensors with their shapes.
|
// systematic modifications, such as replacing tensors with their shapes.
|
||||||
SmallVector<Value> shapeFunctionArgs;
|
SmallVector<Value> shapeFuncArgs;
|
||||||
for (auto operandAndDesiredType :
|
for (auto operandAndDesiredType :
|
||||||
llvm::zip(originalOperands, shapeFunction.getArgumentTypes())) {
|
llvm::zip(originalOperands, shapeFunc.getArgumentTypes())) {
|
||||||
Value operand;
|
Value operand;
|
||||||
Type desiredType;
|
Type desiredType;
|
||||||
std::tie(operand, desiredType) = operandAndDesiredType;
|
std::tie(operand, desiredType) = operandAndDesiredType;
|
||||||
Value shapeFunctionArg =
|
FailureOr<Value> shapeFuncArg = adjustFunctionArg(
|
||||||
adjustShapeFunctionArg(operand, desiredType, b, loc);
|
b, loc, operand, desiredType,
|
||||||
if (!shapeFunctionArg)
|
[](OpBuilder &b, Location loc, Value operand,
|
||||||
|
Type desiredType) -> Value {
|
||||||
|
// The shape library functions have tensor operands replaced with
|
||||||
|
// `!torch.list<int>` types for the shape. Get the sizes.
|
||||||
|
auto desiredListType = desiredType.dyn_cast<Torch::ListType>();
|
||||||
|
if (!desiredListType)
|
||||||
|
return operand;
|
||||||
|
if (operand.getType().isa<Torch::BaseTensorType>() &&
|
||||||
|
desiredListType.getContainedType().isa<Torch::IntType>()) {
|
||||||
|
return b.create<AtenSizeOp>(loc, desiredType, operand);
|
||||||
|
}
|
||||||
|
return operand;
|
||||||
|
});
|
||||||
|
if (failed(shapeFuncArg))
|
||||||
return failure();
|
return failure();
|
||||||
shapeFunctionArgs.push_back(shapeFunctionArg);
|
shapeFuncArgs.push_back(*shapeFuncArg);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the call to the shape function!
|
return shapeFuncArgs;
|
||||||
auto call =
|
|
||||||
b.create<mlir::func::CallOp>(loc, shapeFunction, shapeFunctionArgs);
|
|
||||||
|
|
||||||
// Python models multiple results with a tuple, so we need to unpack it
|
|
||||||
// if the op has multiple results.
|
|
||||||
SmallVector<Value> unpackedResults;
|
|
||||||
assert(call.getNumResults() == 1 &&
|
|
||||||
"Multiple results are packed in a tuple in Python!");
|
|
||||||
Value result = call.getResult(0);
|
|
||||||
if (auto tupleType = result.getType().dyn_cast<Torch::TupleType>()) {
|
|
||||||
auto unpack =
|
|
||||||
b.create<PrimTupleUnpackOp>(loc, tupleType.getContainedTypes(), result);
|
|
||||||
llvm::append_range(unpackedResults, unpack.getResults());
|
|
||||||
} else {
|
|
||||||
unpackedResults.push_back(result);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Terminate the region.
|
|
||||||
b.create<ShapeCalculateYieldShapesOp>(loc, unpackedResults);
|
|
||||||
return success();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -222,74 +62,23 @@ class ReifyShapeCalculationsPass
|
||||||
ModuleOp module = getOperation();
|
ModuleOp module = getOperation();
|
||||||
|
|
||||||
// TODO: Find a way to not have to parse this every time.
|
// TODO: Find a way to not have to parse this every time.
|
||||||
// The shape library is O(#ops we know about), and this pass should be
|
// The library is O(#ops we know about), and this pass should be
|
||||||
// O(#ops in the program) ideally.
|
// O(#ops in the program) ideally.
|
||||||
auto shapeLibrary = parseSourceString<ModuleOp>(getShapeLibrary(), context);
|
OwningOpRef<ModuleOp> library =
|
||||||
|
parseSourceString<ModuleOp>(getAbstractInterpLibrary(), context);
|
||||||
|
|
||||||
// Walk all the operations, and if we have a shape function, wrap the op
|
// Walk all the operations, and if we have a shape function, wrap the op
|
||||||
// in a `torch.shape.calculate` op.
|
// in a `torch.shape.calculate` op.
|
||||||
SmallVector<std::string> neededShapeFunctions;
|
SmallVector<std::string> functionsNeeded;
|
||||||
bool hadError = false;
|
WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult {
|
||||||
module.walk([&](Operation *op) {
|
return wrapWithCalculateOpIfLibraryFunctionAvailable(
|
||||||
Location loc = op->getLoc();
|
op, *library, LibraryFunctionKind::ShapeFunction, functionsNeeded,
|
||||||
auto name = op->getName().stripDialect();
|
shapeFunctionArgsBuilder);
|
||||||
// For value-semantic variant ops, i.e. valsem-ops (ops that are
|
|
||||||
// mechanically consistent with existing torch conventions of in-place vs.
|
|
||||||
// out-of-place (value-semantic) variants), remove the prefix when
|
|
||||||
// looking them up in the shape library.
|
|
||||||
if (name.startswith("valsem."))
|
|
||||||
name = name.drop_front(strlen("valsem."));
|
|
||||||
auto shapeFunctionName = ("__torch_mlir_shape_fn." + Twine(name)).str();
|
|
||||||
auto shapeFunction =
|
|
||||||
shapeLibrary->lookupSymbol<func::FuncOp>(shapeFunctionName);
|
|
||||||
if (!shapeFunction)
|
|
||||||
return;
|
|
||||||
neededShapeFunctions.push_back(shapeFunctionName);
|
|
||||||
auto shapeCalculate =
|
|
||||||
OpBuilder(op).create<ShapeCalculateOp>(loc, op->getResultTypes());
|
|
||||||
op->replaceAllUsesWith(shapeCalculate);
|
|
||||||
{
|
|
||||||
// Move the op into the body of the `torch.shape.calculate` op and yield
|
|
||||||
// its results.
|
|
||||||
OpBuilder b(context);
|
|
||||||
Block *block = b.createBlock(&shapeCalculate.getBody());
|
|
||||||
op->moveBefore(block, block->end());
|
|
||||||
b.setInsertionPointAfter(op);
|
|
||||||
b.create<ShapeCalculateYieldOp>(loc, op->getResults());
|
|
||||||
}
|
|
||||||
if (failed(populateShapeCalculationRegion(
|
|
||||||
shapeCalculate, op->getOperands(), shapeFunction))) {
|
|
||||||
hadError = true;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
|
|
||||||
if (hadError)
|
if (walkResult.wasInterrupted())
|
||||||
return signalPassFailure();
|
return signalPassFailure();
|
||||||
|
importLibraryFunctions(module, *library, std::move(functionsNeeded));
|
||||||
// Import just the functions we need. This includes transitive callees,
|
|
||||||
// so we use a worklist algorithm.
|
|
||||||
llvm::StringSet<> importedFunctions;
|
|
||||||
SmallVector<std::string> worklist;
|
|
||||||
llvm::append_range(worklist, neededShapeFunctions);
|
|
||||||
while (!worklist.empty()) {
|
|
||||||
auto symName = worklist.pop_back_val();
|
|
||||||
if (importedFunctions.count(symName))
|
|
||||||
continue;
|
|
||||||
auto func = shapeLibrary->lookupSymbol<mlir::func::FuncOp>(symName);
|
|
||||||
assert(func && "broken shape library");
|
|
||||||
// Move the shape function from the library to the module this pass
|
|
||||||
// is running on. (this mutates the library, but we re-parse it each time
|
|
||||||
// so this is safe to do).
|
|
||||||
func->moveBefore(&module.getBody()->front());
|
|
||||||
// Set the visibility to private so that the shape functions go away
|
|
||||||
// nicely after we are done with them.
|
|
||||||
func.setVisibility(SymbolTable::Visibility::Private);
|
|
||||||
// Continue the DFS.
|
|
||||||
importedFunctions.insert(symName);
|
|
||||||
func.walk(
|
|
||||||
[&](func::CallOp op) { worklist.push_back(op.getCallee().str()); });
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
@ -0,0 +1,99 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "SimplifyAbstractInterpCalculationsUtils.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::torch;
|
||||||
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
|
LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp,
|
||||||
|
int resultNum,
|
||||||
|
Type newResultType,
|
||||||
|
PatternRewriter &rewriter) {
|
||||||
|
Location loc = calculateOp->getLoc();
|
||||||
|
auto result = calculateOp->getResult(resultNum);
|
||||||
|
Type originalResultType = result.getType();
|
||||||
|
Type updatedType;
|
||||||
|
if (auto originalBaseTensorType =
|
||||||
|
originalResultType.template dyn_cast<BaseTensorType>()) {
|
||||||
|
// If we didn't get any new information, there is nothing left for us to do.
|
||||||
|
updatedType = meetTensorTypes(originalBaseTensorType,
|
||||||
|
newResultType.cast<BaseTensorType>());
|
||||||
|
if (!updatedType || updatedType == originalBaseTensorType)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
calculateOp, "New type information does not refine old type");
|
||||||
|
} else if (auto originalResultType =
|
||||||
|
result.getType().template dyn_cast<Torch::NumberType>()) {
|
||||||
|
if (!newResultType.isa<Torch::FloatType, Torch::IntType>()) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
calculateOp,
|
||||||
|
"Refinement of `NumberType` must be a `FloatType` or `IntType`");
|
||||||
|
}
|
||||||
|
updatedType = newResultType;
|
||||||
|
} else {
|
||||||
|
return rewriter.notifyMatchFailure(calculateOp,
|
||||||
|
"Unimplemented: Expected result type to "
|
||||||
|
"be `BaseTensorType` or `NumberType`");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update all the uses of the result type to the new type, if possible. Insert
|
||||||
|
// a TensorStaticInfoCastOp for any users that might require the exact
|
||||||
|
// previous type.
|
||||||
|
Value originalTypedValue;
|
||||||
|
for (OpOperand &use : llvm::make_early_inc_range(result.getUses())) {
|
||||||
|
if (use.getOwner()
|
||||||
|
->hasTrait<mlir::torch::Torch::OpTrait::AllowsTypeRefinement>()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (!originalTypedValue) {
|
||||||
|
rewriter.setInsertionPointAfter(calculateOp);
|
||||||
|
if (originalResultType.isa<BaseTensorType>()) {
|
||||||
|
originalTypedValue = rewriter.create<TensorStaticInfoCastOp>(
|
||||||
|
loc, originalResultType, result);
|
||||||
|
} else if (originalResultType.isa<Torch::NumberType>()) {
|
||||||
|
originalTypedValue =
|
||||||
|
rewriter.create<DerefineOp>(loc, originalResultType, result);
|
||||||
|
} else {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
calculateOp, "Unimplemented: Expected result type to "
|
||||||
|
"be `BaseTensorType` or `NumberType`");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
use.set(originalTypedValue);
|
||||||
|
}
|
||||||
|
result.setType(updatedType);
|
||||||
|
|
||||||
|
// Update the value yielded from the body to match the new result type. If we
|
||||||
|
// can refine the def in place, do that, otherwise insert a
|
||||||
|
// TensorStaticInfoCastOp.
|
||||||
|
Operation *yieldValues = calculateOp->getRegion(0).front().getTerminator();
|
||||||
|
OpOperand &use = yieldValues->getOpOperand(resultNum);
|
||||||
|
Value def = use.get();
|
||||||
|
Value newYieldedValue;
|
||||||
|
if (def.isa<OpResult>() &&
|
||||||
|
def.cast<OpResult>()
|
||||||
|
.getDefiningOp()
|
||||||
|
->hasTrait<mlir::torch::Torch::OpTrait::AllowsTypeRefinement>()) {
|
||||||
|
newYieldedValue = def;
|
||||||
|
} else {
|
||||||
|
rewriter.setInsertionPoint(yieldValues);
|
||||||
|
if (updatedType.isa<BaseTensorType>()) {
|
||||||
|
newYieldedValue =
|
||||||
|
rewriter.create<TensorStaticInfoCastOp>(loc, updatedType, def);
|
||||||
|
} else {
|
||||||
|
newYieldedValue =
|
||||||
|
rewriter.create<PrimUncheckedCastOp>(loc, updatedType, def);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
use.set(newYieldedValue);
|
||||||
|
newYieldedValue.setType(updatedType);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
|
@ -0,0 +1,30 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef TORCHMLIR_DIALECT_TORCH_TRANSFORMS_SIMPLIFY_ABSTRACT_INTERP_CALCULATIONS_UTILS_H
|
||||||
|
#define TORCHMLIR_DIALECT_TORCH_TRANSFORMS_SIMPLIFY_ABSTRACT_INTERP_CALCULATIONS_UTILS_H
|
||||||
|
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace torch {
|
||||||
|
namespace Torch {
|
||||||
|
|
||||||
|
// Updates the type of result `resultNum` of both `calculateOp` and the torch op
|
||||||
|
// being wrapped by `calculateOp` to the type `newResultType`.
|
||||||
|
LogicalResult updateCalculateOpResultTypes(Operation *calculateOp,
|
||||||
|
int resultNum, Type newResultType,
|
||||||
|
PatternRewriter &rewriter);
|
||||||
|
|
||||||
|
} // namespace Torch
|
||||||
|
} // namespace torch
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // TORCHMLIR_DIALECT_TORCH_TRANSFORMS_SIMPLIFY_ABSTRACT_INTERP_CALCULATIONS_UTILS_H
|
|
@ -0,0 +1,209 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "PassDetail.h"
|
||||||
|
|
||||||
|
#include "SimplifyAbstractInterpCalculationsUtils.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::torch;
|
||||||
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
|
static LogicalResult refineDtypeCalculateResult(DtypeCalculateOp op,
|
||||||
|
int resultNum,
|
||||||
|
PatternRewriter &rewriter) {
|
||||||
|
auto yieldDtypes = op.getCalculation().front().getTerminator();
|
||||||
|
auto dtype = yieldDtypes->getOperand(resultNum);
|
||||||
|
auto result = op->getResult(resultNum);
|
||||||
|
|
||||||
|
int64_t dtypeInt;
|
||||||
|
if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Expected result from the DtypeCalculateOp calculation to be a "
|
||||||
|
"constant int");
|
||||||
|
auto dtypeScalarType = static_cast<torch_upstream::ScalarType>(dtypeInt);
|
||||||
|
|
||||||
|
// Calculate the updated type incorporating the new information.
|
||||||
|
Type impliedTypeFromDtype;
|
||||||
|
if (result.getType().isa<Torch::NumberType>()) {
|
||||||
|
FailureOr<Type> torchType =
|
||||||
|
getTorchTypeForScalarType(op->getContext(), dtypeScalarType);
|
||||||
|
if (failed(torchType)) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Failed to convert result dtype to `Torch::FloatType` or "
|
||||||
|
"`Torch::IntType`");
|
||||||
|
}
|
||||||
|
impliedTypeFromDtype = *torchType;
|
||||||
|
} else if (auto originalResultType =
|
||||||
|
result.getType().dyn_cast<BaseTensorType>()) {
|
||||||
|
impliedTypeFromDtype =
|
||||||
|
originalResultType.cast<BaseTensorType>().getWithSizesAndDtype(
|
||||||
|
originalResultType.getOptionalSizes(),
|
||||||
|
getTypeForScalarType(op->getContext(), dtypeScalarType));
|
||||||
|
} else {
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"Unimplemented: Expected result type to "
|
||||||
|
"be `BaseTensorType` or `NumberType`");
|
||||||
|
}
|
||||||
|
return updateCalculateOpResultTypes(op, resultNum, impliedTypeFromDtype,
|
||||||
|
rewriter);
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// This pattern propagates information out of the dtype calculation region and
|
||||||
|
// into the DtypeCalculateOp result types.
|
||||||
|
class RefineDtypeCalculateOp : public OpRewritePattern<DtypeCalculateOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(DtypeCalculateOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
LogicalResult result = failure();
|
||||||
|
for (int i = 0, e = op->getNumResults(); i != e; i++) {
|
||||||
|
if (succeeded(refineDtypeCalculateResult(op, i, rewriter)))
|
||||||
|
result = success();
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class DecomposePromoteDtypesOp : public OpRewritePattern<PromoteDtypesOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(PromoteDtypesOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
SmallVector<Optional<int64_t>> ranks;
|
||||||
|
SmallVector<int64_t> dtypes;
|
||||||
|
if (!matchPattern(op.getRanks(), m_TorchListOfOptionalConstantInts(ranks))) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Expected `ranks` to be a list of optional constant ints");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!matchPattern(op.getDtypes(), m_TorchListOfConstantInts(dtypes))) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Expected `dtypes` to be a list of constant ints");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ranks.empty() || dtypes.empty()) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "`ranks` list and `dtypes` list must be non-empty");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ranks.size() != dtypes.size()) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "`ranks` list and `dtypes` list must have the same size");
|
||||||
|
}
|
||||||
|
|
||||||
|
torch_upstream::ResultTypeState state{};
|
||||||
|
for (auto ranksAndDtypes : llvm::zip(ranks, dtypes)) {
|
||||||
|
Optional<int64_t> rank;
|
||||||
|
int64_t dtype;
|
||||||
|
std::tie(rank, dtype) = ranksAndDtypes;
|
||||||
|
auto scalarType = static_cast<torch_upstream::ScalarType>(dtype);
|
||||||
|
|
||||||
|
bool isScalarOnlyOp = llvm::all_of(
|
||||||
|
ranks, [](Optional<int64_t> rank) { return !rank.has_value(); });
|
||||||
|
|
||||||
|
if (!rank.has_value()) {
|
||||||
|
// If `rank` does not have a value, then we are dealing with a scalar
|
||||||
|
// input. For the type promotion, the behavior of a scalar argument is
|
||||||
|
// dependent on whether the op is performing an operation with only
|
||||||
|
// scalars (such as AtenAddOp) or with scalars and tensors (such as
|
||||||
|
// AtenAddScalarOp). Therefore, we convert back to the original torch
|
||||||
|
// type of the scalar first, and then determine the right scalar type to
|
||||||
|
// use for promotion based on whether the op involves only scalars or
|
||||||
|
// scalars and tensors.
|
||||||
|
FailureOr<Type> torchType =
|
||||||
|
getTorchTypeForScalarType(op->getContext(), scalarType);
|
||||||
|
if (failed(torchType)) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Dtypes for arguments scalars must be convertible to "
|
||||||
|
"`Torch::FloatType` or `Torch::IntType`");
|
||||||
|
}
|
||||||
|
Type builtinType = isScalarOnlyOp
|
||||||
|
? getBuiltInTypeForTorchScalar(*torchType)
|
||||||
|
: getDefaultDtypeForTorchScalar(*torchType);
|
||||||
|
scalarType = getScalarTypeForType(builtinType);
|
||||||
|
state.wrappedResult =
|
||||||
|
promote_skip_undefined(state.wrappedResult, scalarType);
|
||||||
|
} else if (rank.value() == 0) {
|
||||||
|
state.zeroResult = promote_skip_undefined(state.zeroResult, scalarType);
|
||||||
|
} else if (rank.value() > 0) {
|
||||||
|
state.dimResult = promote_skip_undefined(state.dimResult, scalarType);
|
||||||
|
} else {
|
||||||
|
return rewriter.notifyMatchFailure(op, "Rank should not be negative");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto resultType = static_cast<int64_t>(result_type(state));
|
||||||
|
rewriter.replaceOpWithNewOp<ConstantIntOp>(
|
||||||
|
op, rewriter.getI64IntegerAttr(resultType));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class RefineNumToTensorScalarOpType
|
||||||
|
: public OpRewritePattern<PrimNumToTensorScalarOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(PrimNumToTensorScalarOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto originalResultType = op.getResult().getType().cast<BaseTensorType>();
|
||||||
|
if (originalResultType.hasDtype())
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "`PrimNumToTensorScalarOp` already has a dtype");
|
||||||
|
|
||||||
|
Type inputType = getBuiltInTypeForTorchScalar(op.getA().getType());
|
||||||
|
auto impliedTypeFromInputType =
|
||||||
|
originalResultType.cast<BaseTensorType>()
|
||||||
|
.getWithSizesAndDtype(originalResultType.getOptionalSizes(),
|
||||||
|
inputType)
|
||||||
|
.cast<BaseTensorType>();
|
||||||
|
|
||||||
|
op.getResult().setType(impliedTypeFromInputType);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class SimplifyDtypeCalculationsPass
|
||||||
|
: public SimplifyDtypeCalculationsBase<SimplifyDtypeCalculationsPass> {
|
||||||
|
void runOnOperation() override {
|
||||||
|
MLIRContext *context = &getContext();
|
||||||
|
|
||||||
|
RewritePatternSet patterns(context);
|
||||||
|
patterns.insert<RefineDtypeCalculateOp>(context);
|
||||||
|
patterns.insert<DecomposePromoteDtypesOp>(context);
|
||||||
|
patterns.insert<RefineNumToTensorScalarOpType>(context);
|
||||||
|
|
||||||
|
// TODO: Debug visitation order to make this more efficient.
|
||||||
|
// A single linear scan should suffice.
|
||||||
|
GreedyRewriteConfig config;
|
||||||
|
config.useTopDownTraversal = true;
|
||||||
|
config.maxIterations = GreedyRewriteConfig::kNoIterationLimit;
|
||||||
|
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
|
||||||
|
config))) {
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||||
|
mlir::torch::Torch::createSimplifyDtypeCalculationsPass() {
|
||||||
|
return std::make_unique<SimplifyDtypeCalculationsPass>();
|
||||||
|
}
|
|
@ -9,20 +9,11 @@
|
||||||
|
|
||||||
#include "PassDetail.h"
|
#include "PassDetail.h"
|
||||||
|
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "SimplifyAbstractInterpCalculationsUtils.h"
|
||||||
#include "mlir/IR/BlockAndValueMapping.h"
|
#include "mlir/IR/BlockAndValueMapping.h"
|
||||||
#include "mlir/IR/Builders.h"
|
|
||||||
#include "mlir/Parser/Parser.h"
|
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
|
||||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
#include "mlir/Transforms/InliningUtils.h"
|
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
|
||||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||||
#include "llvm/ADT/BitVector.h"
|
|
||||||
#include "llvm/ADT/StringExtras.h"
|
|
||||||
#include "llvm/ADT/StringSet.h"
|
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::torch;
|
using namespace mlir::torch;
|
||||||
|
@ -40,10 +31,12 @@ public:
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
MLIRContext *context = op->getContext();
|
MLIRContext *context = op->getContext();
|
||||||
if (!op.isForLike())
|
if (!op.isForLike())
|
||||||
return failure();
|
return rewriter.notifyMatchFailure(op, "Loop is not for-like");
|
||||||
int64_t maxTripCount;
|
int64_t maxTripCount;
|
||||||
if (!matchPattern(op.getMaxTripCount(), m_TorchConstantInt(&maxTripCount)))
|
if (!matchPattern(op.getMaxTripCount(), m_TorchConstantInt(&maxTripCount)))
|
||||||
return failure();
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Expected `maxTripCount` to be a constant int");
|
||||||
|
;
|
||||||
SmallVector<Value> indices;
|
SmallVector<Value> indices;
|
||||||
for (int64_t i = 0; i < maxTripCount; i++) {
|
for (int64_t i = 0; i < maxTripCount; i++) {
|
||||||
// TODO: Add convenience builder.
|
// TODO: Add convenience builder.
|
||||||
|
@ -149,7 +142,8 @@ public:
|
||||||
for (Operation *user : usersToInterpret) {
|
for (Operation *user : usersToInterpret) {
|
||||||
if (auto append = dyn_cast<AtenAppendTOp>(user)) {
|
if (auto append = dyn_cast<AtenAppendTOp>(user)) {
|
||||||
if (!append.use_empty())
|
if (!append.use_empty())
|
||||||
return failure();
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Expected `AtenAppendTOp` to not have users");
|
||||||
if (append.getSelf() == op) {
|
if (append.getSelf() == op) {
|
||||||
runningList.push_back(append.getEl());
|
runningList.push_back(append.getEl());
|
||||||
generatedNewLiteral = true;
|
generatedNewLiteral = true;
|
||||||
|
@ -159,13 +153,16 @@ public:
|
||||||
}
|
}
|
||||||
if (auto insert = dyn_cast<AtenInsertTOp>(user)) {
|
if (auto insert = dyn_cast<AtenInsertTOp>(user)) {
|
||||||
if (!insert.use_empty())
|
if (!insert.use_empty())
|
||||||
return failure();
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Expected `AtenInsertTOp` to not have users");
|
||||||
int64_t index;
|
int64_t index;
|
||||||
if (!matchPattern(insert.getIdx(), m_TorchConstantInt(&index)))
|
if (!matchPattern(insert.getIdx(), m_TorchConstantInt(&index)))
|
||||||
return failure();
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Expected `idx` of `AtenInsertTOp` to be a constant int");
|
||||||
// The index might be statically out of bounds.
|
// The index might be statically out of bounds.
|
||||||
if (index < 0 || index > static_cast<int64_t>(runningList.size()))
|
if (index < 0 || index > static_cast<int64_t>(runningList.size()))
|
||||||
return failure();
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Index in `AtenInsertTOp` is out of bounds");
|
||||||
if (insert.getSelf() == op) {
|
if (insert.getSelf() == op) {
|
||||||
runningList.insert(runningList.begin() + index, insert.getEl());
|
runningList.insert(runningList.begin() + index, insert.getEl());
|
||||||
generatedNewLiteral = true;
|
generatedNewLiteral = true;
|
||||||
|
@ -175,13 +172,15 @@ public:
|
||||||
}
|
}
|
||||||
if (auto setItem = dyn_cast<Aten_SetItemTOp>(user)) {
|
if (auto setItem = dyn_cast<Aten_SetItemTOp>(user)) {
|
||||||
if (!setItem.use_empty())
|
if (!setItem.use_empty())
|
||||||
return failure();
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Expected `Aten_SetItemTOp` to not have users");
|
||||||
llvm::Optional<int64_t> indexOpt =
|
llvm::Optional<int64_t> indexOpt =
|
||||||
matchLegalConstantIndexIntoListOfSize(setItem.getIdx(),
|
matchLegalConstantIndexIntoListOfSize(setItem.getIdx(),
|
||||||
runningList.size());
|
runningList.size());
|
||||||
// The index might be statically out of bounds.
|
// The index might be statically out of bounds.
|
||||||
if (!indexOpt)
|
if (!indexOpt)
|
||||||
return failure();
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Index in `Aten_SetItemTOp` is out of bounds");
|
||||||
if (setItem.getL() == op) {
|
if (setItem.getL() == op) {
|
||||||
runningList[*indexOpt] = setItem.getEl();
|
runningList[*indexOpt] = setItem.getEl();
|
||||||
generatedNewLiteral = true;
|
generatedNewLiteral = true;
|
||||||
|
@ -196,7 +195,7 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!generatedNewLiteral)
|
if (!generatedNewLiteral)
|
||||||
return failure();
|
return rewriter.notifyMatchFailure(op, "No new literal created");
|
||||||
|
|
||||||
// Rewrite all users to use the appropriate list literals.
|
// Rewrite all users to use the appropriate list literals.
|
||||||
Value latestLiteral = rewriter.create<PrimListConstructOp>(
|
Value latestLiteral = rewriter.create<PrimListConstructOp>(
|
||||||
|
@ -285,11 +284,10 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
static void refineShapeCalculateResult(ShapeCalculateOp op, int resultNum,
|
static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op,
|
||||||
PatternRewriter &rewriter,
|
int resultNum,
|
||||||
bool &madeChange) {
|
PatternRewriter &rewriter) {
|
||||||
auto yieldValues = op.getBody().front().getTerminator();
|
auto yieldShapes = op.getCalculation().front().getTerminator();
|
||||||
auto yieldShapes = op.getShapeCalculation().front().getTerminator();
|
|
||||||
auto shape = yieldShapes->getOperand(resultNum);
|
auto shape = yieldShapes->getOperand(resultNum);
|
||||||
auto result = op->getResult(resultNum);
|
auto result = op->getResult(resultNum);
|
||||||
|
|
||||||
|
@ -298,7 +296,9 @@ static void refineShapeCalculateResult(ShapeCalculateOp op, int resultNum,
|
||||||
// much as possible to literals.
|
// much as possible to literals.
|
||||||
auto listConstruct = shape.getDefiningOp<PrimListConstructOp>();
|
auto listConstruct = shape.getDefiningOp<PrimListConstructOp>();
|
||||||
if (!listConstruct)
|
if (!listConstruct)
|
||||||
return;
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Expected result from ShapeCalculateOp calculation to be a "
|
||||||
|
"`PrimListConstructOp`");
|
||||||
llvm::BitVector clobberedElements(listConstruct->getNumOperands());
|
llvm::BitVector clobberedElements(listConstruct->getNumOperands());
|
||||||
// Analyze the users to determine if we can refine the shape.
|
// Analyze the users to determine if we can refine the shape.
|
||||||
for (Operation *user : listConstruct->getUsers()) {
|
for (Operation *user : listConstruct->getUsers()) {
|
||||||
|
@ -320,7 +320,7 @@ static void refineShapeCalculateResult(ShapeCalculateOp op, int resultNum,
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
// An unhandled op! We can't make any assumptions about the shape.
|
// An unhandled op! We can't make any assumptions about the shape.
|
||||||
return;
|
return rewriter.notifyMatchFailure(op, "Unhandled op that mutates lists");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Construct the list of sizes implied by the yielded shape.
|
// Construct the list of sizes implied by the yielded shape.
|
||||||
|
@ -334,55 +334,15 @@ static void refineShapeCalculateResult(ShapeCalculateOp op, int resultNum,
|
||||||
sizes.push_back(kUnknownSize);
|
sizes.push_back(kUnknownSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculate the updated type incorporating the new shape information.
|
auto originalResultType = result.getType().cast<BaseTensorType>();
|
||||||
Type originalResultType = result.getType();
|
|
||||||
auto impliedTypesFromShape =
|
auto impliedTypesFromShape =
|
||||||
originalResultType.cast<BaseTensorType>().getWithSizesAndDtype(
|
originalResultType.cast<BaseTensorType>()
|
||||||
makeArrayRef(sizes), nullptr);
|
.getWithSizesAndDtype(makeArrayRef(sizes),
|
||||||
auto updatedType =
|
originalResultType.getOptionalDtype())
|
||||||
meetTensorTypes(originalResultType.cast<BaseTensorType>(),
|
.cast<BaseTensorType>();
|
||||||
impliedTypesFromShape.cast<BaseTensorType>());
|
|
||||||
// If we didn't get any new information, there is nothing left for us to do.
|
|
||||||
if (!updatedType || updatedType == originalResultType)
|
|
||||||
return;
|
|
||||||
|
|
||||||
// Update all the uses of the result type to the new type, if possible. Insert
|
return updateCalculateOpResultTypes(op, resultNum, impliedTypesFromShape,
|
||||||
// a TensorStaticInfoCastOp for any users that might require the exact
|
rewriter);
|
||||||
// previous type.
|
|
||||||
Value originalTypedValue;
|
|
||||||
for (OpOperand &use : llvm::make_early_inc_range(result.getUses())) {
|
|
||||||
if (use.getOwner()
|
|
||||||
->hasTrait<mlir::torch::Torch::OpTrait::AllowsTypeRefinement>()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (!originalTypedValue) {
|
|
||||||
rewriter.setInsertionPointAfter(op);
|
|
||||||
originalTypedValue = rewriter.create<TensorStaticInfoCastOp>(
|
|
||||||
op->getLoc(), originalResultType, result);
|
|
||||||
}
|
|
||||||
use.set(originalTypedValue);
|
|
||||||
}
|
|
||||||
result.setType(updatedType);
|
|
||||||
madeChange = true;
|
|
||||||
|
|
||||||
// Update the value yielded from the body to match the new result type. If we
|
|
||||||
// can refine the def in place, do that, otherwise insert a
|
|
||||||
// TensorStaticInfoCastOp.
|
|
||||||
OpOperand &use = op.getBody().front().getTerminator()->getOpOperand(resultNum);
|
|
||||||
Value def = use.get();
|
|
||||||
Value newYieldedValue;
|
|
||||||
if (def.isa<OpResult>() &&
|
|
||||||
def.cast<OpResult>()
|
|
||||||
.getDefiningOp()
|
|
||||||
->hasTrait<mlir::torch::Torch::OpTrait::AllowsTypeRefinement>()) {
|
|
||||||
newYieldedValue = def;
|
|
||||||
} else {
|
|
||||||
rewriter.setInsertionPoint(yieldValues);
|
|
||||||
newYieldedValue =
|
|
||||||
rewriter.create<TensorStaticInfoCastOp>(op->getLoc(), updatedType, def);
|
|
||||||
}
|
|
||||||
use.set(newYieldedValue);
|
|
||||||
newYieldedValue.setType(updatedType);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -393,10 +353,11 @@ public:
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
LogicalResult matchAndRewrite(ShapeCalculateOp op,
|
LogicalResult matchAndRewrite(ShapeCalculateOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
bool madeChange = false;
|
LogicalResult result = failure();
|
||||||
for (int i = 0, e = op->getNumResults(); i != e; i++)
|
for (int i = 0, e = op->getNumResults(); i != e; i++)
|
||||||
refineShapeCalculateResult(op, i, rewriter, madeChange);
|
if (succeeded(refineShapeCalculateResult(op, i, rewriter)))
|
||||||
return success(madeChange);
|
result = success();
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
@ -100,7 +100,8 @@ Type Torch::getTypeForScalarType(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Type Torch::getTorchTypeForScalarType(MLIRContext *context,
|
FailureOr<Type>
|
||||||
|
Torch::getTorchTypeForScalarType(MLIRContext *context,
|
||||||
torch_upstream::ScalarType dtypeInt) {
|
torch_upstream::ScalarType dtypeInt) {
|
||||||
switch (dtypeInt) {
|
switch (dtypeInt) {
|
||||||
case torch_upstream::ScalarType::Double:
|
case torch_upstream::ScalarType::Double:
|
||||||
|
@ -108,11 +109,37 @@ Type Torch::getTorchTypeForScalarType(MLIRContext *context,
|
||||||
case torch_upstream::ScalarType::Long:
|
case torch_upstream::ScalarType::Long:
|
||||||
return Torch::IntType::get(context);
|
return Torch::IntType::get(context);
|
||||||
default:
|
default:
|
||||||
llvm::report_fatal_error(
|
return failure();
|
||||||
"Unsupported scalar type to Torch type conversion");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Type Torch::getDefaultDtypeForTorchScalar(Type type) {
|
||||||
|
MLIRContext *context = type.getContext();
|
||||||
|
if (type.isa<Torch::FloatType>()) {
|
||||||
|
// For now, use float32 which is the initial default dtype returned by
|
||||||
|
// `torch.get_default_dtype`.
|
||||||
|
return Float32Type::get(context);
|
||||||
|
}
|
||||||
|
if (type.isa<Torch::IntType>())
|
||||||
|
return IntegerType::get(context, 64, IntegerType::Signed);
|
||||||
|
if (type.isa<Torch::BoolType>())
|
||||||
|
return IntegerType::get(context, 1);
|
||||||
|
llvm_unreachable(
|
||||||
|
"getDefaultDtypeForTorchScalar called on an unsupported type");
|
||||||
|
}
|
||||||
|
|
||||||
|
Type Torch::getBuiltInTypeForTorchScalar(Type type) {
|
||||||
|
MLIRContext *context = type.getContext();
|
||||||
|
if (type.isa<Torch::FloatType>())
|
||||||
|
return Float64Type::get(context);
|
||||||
|
if (type.isa<Torch::IntType>())
|
||||||
|
return IntegerType::get(context, 64, IntegerType::Signed);
|
||||||
|
if (type.isa<Torch::BoolType>())
|
||||||
|
return IntegerType::get(context, 1);
|
||||||
|
llvm_unreachable(
|
||||||
|
"getBuiltInTypeForTorchScalar called on an unsupported type");
|
||||||
|
}
|
||||||
|
|
||||||
Value Torch::getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
|
Value Torch::getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
|
||||||
Type dtype) {
|
Type dtype) {
|
||||||
int intType = (int)getScalarTypeForType(dtype);
|
int intType = (int)getScalarTypeForType(dtype);
|
||||||
|
|
|
@ -104,7 +104,8 @@ endif()
|
||||||
|
|
||||||
################################################################################
|
################################################################################
|
||||||
# Custom op example
|
# Custom op example
|
||||||
# Required for running the update_torch_ods.sh and update_shape_lib.sh scripts.
|
# Required for running the update_torch_ods.sh and update_abstract_interp_lib.sh
|
||||||
|
# scripts.
|
||||||
################################################################################
|
################################################################################
|
||||||
|
|
||||||
# add_subdirectory(torch_mlir/_torch_mlir_custom_op_example)
|
# add_subdirectory(torch_mlir/_torch_mlir_custom_op_example)
|
||||||
|
|
|
@ -5,12 +5,12 @@ If you're reading this, you're likely looking to create or support a third-party
|
||||||
This isn't much different than [adding a new PyTorch op to torch-mlir](https://github.com/llvm/torch-mlir/wiki/Torch-ops-E2E-implementation).
|
This isn't much different than [adding a new PyTorch op to torch-mlir](https://github.com/llvm/torch-mlir/wiki/Torch-ops-E2E-implementation).
|
||||||
You'll still go through the exact same process, with just a small change:
|
You'll still go through the exact same process, with just a small change:
|
||||||
|
|
||||||
- Before running `update_torch_ods.sh` or `update_shape_lib.sh`, you'll want to set the `TORCH_MLIR_EXT_PYTHONPATH` to point to wherever your extension lives and the `TORCH_MLIR_EXT_MODULES` to the name of the python module.
|
- Before running `update_torch_ods.sh` or `update_abstract_interp_lib.sh`, you'll want to set the `TORCH_MLIR_EXT_PYTHONPATH` to point to wherever your extension lives and the `TORCH_MLIR_EXT_MODULES` to the name of the python module.
|
||||||
|
|
||||||
For instance, let's say you've written a python package called `my_torch_ops`.
|
For instance, let's say you've written a python package called `my_torch_ops`.
|
||||||
If `my_torch_ops` lives in `/example/subdirectory/`, then you'll want to set `TORCH_MLIR_EXT_PYTHONPATH=/example/subdirectory` and `TORCH_MLIR_EXT_MODULES=my_torch_ops`.
|
If `my_torch_ops` lives in `/example/subdirectory/`, then you'll want to set `TORCH_MLIR_EXT_PYTHONPATH=/example/subdirectory` and `TORCH_MLIR_EXT_MODULES=my_torch_ops`.
|
||||||
If you've installed your package (with `pip`, for instance), then you'll only need to set `TORCH_MLIR_EXT_MODULES=my_torch_ops`.
|
If you've installed your package (with `pip`, for instance), then you'll only need to set `TORCH_MLIR_EXT_MODULES=my_torch_ops`.
|
||||||
Note that the `update_torch_ods.sh` and `update_shape_lib.sh` scripts do not use the `PYTHONPATH` environment variable in your current shell.
|
Note that the `update_torch_ods.sh` and `update_abstract_interp_lib.sh` scripts do not use the `PYTHONPATH` environment variable in your current shell.
|
||||||
This is on purpose, but it means that you either need to set `TORCH_MLIR_EXT_PYTHONPATH` to include your package or to include the paths set in your shell's `PYTHONPATH` variable.
|
This is on purpose, but it means that you either need to set `TORCH_MLIR_EXT_PYTHONPATH` to include your package or to include the paths set in your shell's `PYTHONPATH` variable.
|
||||||
If you have more than one PyTorch extension, you can add them all by including each path in `TORCH_MLIR_EXT_PYTHONPATH` separated by colons (`:`) and each module in `TORCH_MLIR_EXT_MODULES` separated by commas (`,`).
|
If you have more than one PyTorch extension, you can add them all by including each path in `TORCH_MLIR_EXT_PYTHONPATH` separated by colons (`:`) and each module in `TORCH_MLIR_EXT_MODULES` separated by commas (`,`).
|
||||||
|
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,189 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import re
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
|
||||||
|
from torch_mlir.passmanager import PassManager
|
||||||
|
|
||||||
|
from .registry import Registry
|
||||||
|
|
||||||
|
def get_dtype_of_scalar(scalar: Union[int, float]) -> int:
|
||||||
|
# This is hacky. `NumToTensor` is the only PyTorch op for scalars
|
||||||
|
# that when `jit.script`ed converts a float scalar to a tensor
|
||||||
|
# with dtype that corresponds to Python's `float`.
|
||||||
|
#
|
||||||
|
# See definition of `NumToTensor`: https://github.com/pytorch/pytorch/blob/c09929659ce8ba2f1b7b2f6e50084ccbf854d44b/torch/csrc/jit/ir/ir.cpp#L1850
|
||||||
|
# Definition of `fromNumberType` used by `NumToTensor` to
|
||||||
|
# calculate dtype: https://github.com/pytorch/pytorch/blob/c09929659ce8ba2f1b7b2f6e50084ccbf854d44b/aten/src/ATen/core/jit_type.h#L1679
|
||||||
|
#
|
||||||
|
# Note that doing something like
|
||||||
|
# `torch.tensor(scalar).to(type(scalar)).dtype` does not work because
|
||||||
|
# `torch.tensor` needs to know the type of the input at compile time
|
||||||
|
# and there is no `jit.script` support for Python's `type`.
|
||||||
|
#
|
||||||
|
# TODO: A better way to handle this would be to add support for
|
||||||
|
# `isinstance` in torch-mlir, which requires adding a new torch dialect
|
||||||
|
# op.
|
||||||
|
return torch.ops.prim.NumToTensor(scalar).dtype
|
||||||
|
|
||||||
|
# When we import into torch-mlir, only the calls to
|
||||||
|
# `__torch_mlir_internal_promote_dtypes` are used to generate the
|
||||||
|
# `torch.promote_dtypes` ops. Therefore, to avoid generating extra
|
||||||
|
# MLIR code in the library, all calls made inside
|
||||||
|
# `__torch_mlir_internal_promote_dtypes` are `jit.ignore`d.
|
||||||
|
@torch.jit.ignore
|
||||||
|
def _get_scalar_with_dtype(dtype: torch.dtype) -> Union[int, float]:
|
||||||
|
if dtype == torch.int64:
|
||||||
|
return 0
|
||||||
|
elif dtype == torch.float64:
|
||||||
|
return 0.0
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unhandled dtype: {dtype}")
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def _promote_scalar_tensor(scalar_dtype: torch.dtype, tensor_rank: int,
|
||||||
|
tensor_dtype: torch.dtype) -> torch.dtype:
|
||||||
|
scalar = _get_scalar_with_dtype(scalar_dtype)
|
||||||
|
tensor = torch.rand([1] * tensor_rank).to(tensor_dtype)
|
||||||
|
return torch.result_type(scalar, tensor)
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def _promote_tensor_tensor(lhs_rank: int, lhs_dtype: torch.dtype,
|
||||||
|
rhs_rank: int, rhs_dtype: torch.dtype) -> torch.dtype:
|
||||||
|
lhs_tensor = torch.rand([1] * lhs_rank).to(lhs_dtype)
|
||||||
|
rhs_tensor = torch.rand([1] * rhs_rank).to(rhs_dtype)
|
||||||
|
return torch.result_type(lhs_tensor, rhs_tensor)
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def _promote_scalar_scalar(lhs_dtype: torch.dtype,
|
||||||
|
rhs_dtype: torch.dtype) -> torch.dtype:
|
||||||
|
# When `torch.result_type` is used on two scalars, the result
|
||||||
|
# dtype is the dtype one would expect for an op with signature
|
||||||
|
# (Scalar, Scalar) -> (Tensor). However, once a module gets
|
||||||
|
# jit.scripted, all math ops that work on scalars becomes
|
||||||
|
# (Scalar, Scalar) -> (Scalar) ops. So to get the right result
|
||||||
|
# dtype, we use the tensor-tensor promotion rules.
|
||||||
|
return _promote_tensor_tensor(0, lhs_dtype, 0, rhs_dtype)
|
||||||
|
|
||||||
|
def promote_dtypes(ranks: List[Optional[int]],
|
||||||
|
dtypes: List[torch.dtype]) -> torch.dtype:
|
||||||
|
"""Apply PyTorch dtype promotion rules and return the result type.
|
||||||
|
"""
|
||||||
|
return __torch_mlir_internal_promote_dtypes(ranks, dtypes)
|
||||||
|
|
||||||
|
def __torch_mlir_internal_promote_dtypes(ranks: List[Optional[int]],
|
||||||
|
dtypes: List[torch.dtype]
|
||||||
|
) -> torch.dtype:
|
||||||
|
"""Apply PyTorch dtype promotion rules and return the result type.
|
||||||
|
|
||||||
|
This function serves two purposes:
|
||||||
|
1. It is handled in a special way during import into Torch-MLIR,
|
||||||
|
generating `torch.promote_dtypes` ops
|
||||||
|
2. Computes the actual promotion logic at the Python level in order
|
||||||
|
to be able to test dtype calculation functions against PyTorch
|
||||||
|
"""
|
||||||
|
lhs_optional_rank = ranks[0]
|
||||||
|
lhs_dtype = dtypes[0]
|
||||||
|
for rhs_optional_rank, rhs_dtype in zip(ranks, dtypes):
|
||||||
|
if lhs_optional_rank is None and rhs_optional_rank is None:
|
||||||
|
lhs_dtype = _promote_scalar_scalar(lhs_dtype, rhs_dtype)
|
||||||
|
elif lhs_optional_rank is None and rhs_optional_rank is not None:
|
||||||
|
lhs_dtype = _promote_scalar_tensor(
|
||||||
|
lhs_dtype, rhs_optional_rank, rhs_dtype)
|
||||||
|
lhs_optional_rank = rhs_optional_rank
|
||||||
|
elif lhs_optional_rank is not None and rhs_optional_rank is None:
|
||||||
|
lhs_dtype = _promote_scalar_tensor(
|
||||||
|
rhs_dtype, lhs_optional_rank, lhs_dtype)
|
||||||
|
elif lhs_optional_rank is not None and rhs_optional_rank is not None:
|
||||||
|
lhs_dtype = _promote_tensor_tensor(
|
||||||
|
lhs_optional_rank, lhs_dtype, rhs_optional_rank, rhs_dtype)
|
||||||
|
lhs_optional_rank = max(lhs_optional_rank, rhs_optional_rank)
|
||||||
|
return lhs_dtype
|
||||||
|
|
||||||
|
def not_present_in_registry(f):
|
||||||
|
"""Decorator for abstract interpretation functions not present in the registry.
|
||||||
|
|
||||||
|
This can happen for "valsem" ops that we have in Torch-MLIR, such as
|
||||||
|
torch.valsem.aten.fill.Scalar, which are consistent with PyTorch conventions
|
||||||
|
(e.g. being the value-semantic correspondent of torch.aten.fill_.Scalar),
|
||||||
|
but that for whatever reason are not present in PyTorch. Such ops are useful
|
||||||
|
to keep certain passes within Torch-MLIR as consistent as possible.
|
||||||
|
For such ops, in the shape library generator, we treat them as if they
|
||||||
|
were registered torch ops (so we don't put "valsem" on them), to keep the
|
||||||
|
generator consistent.
|
||||||
|
|
||||||
|
To check if this decorator has been applied, use
|
||||||
|
`hasattr(f, "_not_present_in_registry")`.
|
||||||
|
"""
|
||||||
|
f._not_present_in_registry = None
|
||||||
|
return f
|
||||||
|
|
||||||
|
def _verify_signature_matches_registry(f, registry: Registry):
|
||||||
|
source = inspect.getsource(f)
|
||||||
|
signature = None
|
||||||
|
for line in source.splitlines():
|
||||||
|
if line.startswith("def "):
|
||||||
|
signature = line
|
||||||
|
break
|
||||||
|
assert signature is not None, f"Could not find signature for {f.__name__}"
|
||||||
|
assert "〡" in signature, f"Malformed signature {signature}. Signature missing the character `〡`"
|
||||||
|
function_name, function_kind = f.__name__.split("〡")
|
||||||
|
atoms = function_name.split("〇")
|
||||||
|
if len(atoms) == 2:
|
||||||
|
atoms += [""]
|
||||||
|
operator = registry.get_by_triple(tuple(atoms))
|
||||||
|
if function_kind == "shape":
|
||||||
|
expected_signature = operator.get_shape_function_signature()
|
||||||
|
elif function_kind == "dtype":
|
||||||
|
expected_signature = operator.get_dtype_function_signature()
|
||||||
|
elif function_kind == "decomposition":
|
||||||
|
expected_signature = operator.get_decomposition_function_signature()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid Op signature function kind: '{function_kind}'")
|
||||||
|
if signature != expected_signature:
|
||||||
|
raise ValueError(f"Signature mismatch for {f.__name__!r}: expected {expected_signature!r}, got {signature!r}")
|
||||||
|
|
||||||
|
def generate_library(globals_) -> str:
|
||||||
|
"""Convert all op functions in `globals()` into MLIR."""
|
||||||
|
mb = ModuleBuilder()
|
||||||
|
# We use the registry to ensure that the shape functions are consistent
|
||||||
|
# with the ops.
|
||||||
|
registry = Registry.load()
|
||||||
|
for k, v in globals_.items():
|
||||||
|
if "〇" not in k:
|
||||||
|
continue
|
||||||
|
if not hasattr(v, "_not_present_in_registry"):
|
||||||
|
_verify_signature_matches_registry(v, registry)
|
||||||
|
# Add it to the compilation unit.
|
||||||
|
torch.jit.script(v)
|
||||||
|
for function in torch.jit._state._python_cu.get_functions():
|
||||||
|
# Calls to the function `__torch_mlir_internal_promote_dtypes`
|
||||||
|
# will get converted to the torch-dialect op `torch.promote_dtypes`
|
||||||
|
# during import, so there is no need to import the actual
|
||||||
|
# function.
|
||||||
|
if function.name == "__torch_mlir_internal_promote_dtypes":
|
||||||
|
continue
|
||||||
|
mb.import_function(function)
|
||||||
|
# Clean up the IR a bit before writing it out.
|
||||||
|
pm = PassManager.parse("builtin.module(canonicalize)", context=mb.module.context)
|
||||||
|
pm.run(mb.module)
|
||||||
|
# Munge the IR a bit to make it more systematically accessible.
|
||||||
|
asm = mb.module.operation.get_asm()
|
||||||
|
# We'd like a unique function prefix to avoid collisions with user-
|
||||||
|
# defined symbols. Since all of our shape functions conveniently have
|
||||||
|
# a `〇` in them, we replace the torch namespace with our prefix. E.g.:
|
||||||
|
# __torch__.aten〇add〇Scalar -> __torch_mlir_shape_fn.aten〇add〇Scalar
|
||||||
|
asm = re.sub(r"__torch__\.([^.(]+)\\E3\\80\\87([^.(]+)\\E3\\80\\A1([^.(\"]+)",
|
||||||
|
r"__torch_mlir_\3_fn.\1\\E3\\80\\87\2",
|
||||||
|
asm)
|
||||||
|
|
||||||
|
# Put the `〇` back to a regular `.`.
|
||||||
|
asm = asm.replace("\\E3\\80\\87", ".")
|
||||||
|
return asm
|
|
@ -5,7 +5,7 @@
|
||||||
|
|
||||||
"""Access to the Torch JIT operator registry."""
|
"""Access to the Torch JIT operator registry."""
|
||||||
|
|
||||||
from typing import Dict, List, Tuple, Union
|
from typing import Dict, List, Tuple, Union, Callable
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import itertools
|
import itertools
|
||||||
|
@ -15,6 +15,48 @@ from .utils import TextEmitter
|
||||||
# Note that this utility exists only in the c-extension.
|
# Note that this utility exists only in the c-extension.
|
||||||
from torch_mlir._mlir_libs._jit_ir_importer import get_registered_ops # pytype: disable=import-error
|
from torch_mlir._mlir_libs._jit_ir_importer import get_registered_ops # pytype: disable=import-error
|
||||||
|
|
||||||
|
def _rename_python_keyword_parameter_name(parameter_name: str) -> str:
|
||||||
|
if parameter_name == "from":
|
||||||
|
parameter_name = "from_" # Avoid using a Python keyword.
|
||||||
|
return parameter_name
|
||||||
|
|
||||||
|
def _get_default_value(arg: "SIG_ATTR_TYPE") -> str:
|
||||||
|
default = ""
|
||||||
|
if "default_debug" in arg:
|
||||||
|
if "List" in arg["pytype"]:
|
||||||
|
# TorchScript doesn't allow lists as default parameters due
|
||||||
|
# to the weird Python semantics of mutable default
|
||||||
|
# arguments. So munge them into tuples, which work
|
||||||
|
# fine here. We only need these to simplify the invocation
|
||||||
|
# of the abstract interpretation functions as valid Python for
|
||||||
|
# testing against the real ops, and tuples work fine in all
|
||||||
|
# the places this kicks in (e.g. conv dilations -- we aren't
|
||||||
|
# mutating those lists).
|
||||||
|
default_debug = arg["default_debug"].replace(
|
||||||
|
'[', '(').replace(']', ')')
|
||||||
|
elif arg["pytype"] == "str":
|
||||||
|
default_debug = repr(arg["default_debug"]).replace("'", '"')
|
||||||
|
else:
|
||||||
|
default_debug = arg["default_debug"]
|
||||||
|
default = f" = {default_debug}"
|
||||||
|
return default
|
||||||
|
|
||||||
|
def _pytype_to_fn_pytype_common(pytype: str) -> str:
|
||||||
|
if "number" in pytype:
|
||||||
|
return pytype.replace("number", "Union[int, float]")
|
||||||
|
# `torch.device` is lowercase.
|
||||||
|
if pytype == "Device":
|
||||||
|
return "device"
|
||||||
|
if pytype == "Optional[Device]":
|
||||||
|
return "Optional[device]"
|
||||||
|
# Generators don't contribute to shapes, and are not scriptable currently.
|
||||||
|
# So just hack them to be passed as "Any".
|
||||||
|
if pytype == "Generator":
|
||||||
|
return "Any"
|
||||||
|
if pytype == "Optional[Generator]":
|
||||||
|
return "Any"
|
||||||
|
return pytype
|
||||||
|
|
||||||
def _pytype_to_shape_fn_pytype(pytype: str) -> str:
|
def _pytype_to_shape_fn_pytype(pytype: str) -> str:
|
||||||
"""Convert a JitOperator pytype to the type relevant in shape functions.
|
"""Convert a JitOperator pytype to the type relevant in shape functions.
|
||||||
|
|
||||||
|
@ -28,31 +70,29 @@ def _pytype_to_shape_fn_pytype(pytype: str) -> str:
|
||||||
# logically real-valued), so it doesn't really matter much, and
|
# logically real-valued), so it doesn't really matter much, and
|
||||||
# `float` helps make it clearer that it's not part of the shape
|
# `float` helps make it clearer that it's not part of the shape
|
||||||
# function.
|
# function.
|
||||||
if pytype == "number":
|
# TODO: This is no longer needed. Scalars can be represented by
|
||||||
return "float"
|
# Union[int, float].
|
||||||
if pytype == "Optional[number]":
|
if "number" in pytype:
|
||||||
return "Optional[float]"
|
return pytype.replace("number", "float")
|
||||||
# `torch.device` is lowercase.
|
if "Tensor" in pytype:
|
||||||
if pytype == "Device":
|
return pytype.replace("Tensor", "List[int]")
|
||||||
return "device"
|
return _pytype_to_fn_pytype_common(pytype)
|
||||||
if pytype == "Optional[Device]":
|
|
||||||
return "Optional[device]"
|
def _pytype_to_dtype_fn_pytype(pytype: str) -> str:
|
||||||
# Shape functions only care about the shape of tensors.
|
"""Convert a JitOperator pytype to the type relevant in dtype functions.
|
||||||
if pytype == "Tensor":
|
|
||||||
return "List[int]"
|
In particular, this converts `Tensor` to `int`, along with a few
|
||||||
if pytype == "Optional[Tensor]":
|
other special cases.
|
||||||
return "Optional[List[int]]"
|
"""
|
||||||
if pytype == "List[Tensor]":
|
# Dtype functions only care about the rank and dtype of tensors.
|
||||||
return "List[List[int]]"
|
if "Tensor" in pytype:
|
||||||
if pytype == "List[Optional[Tensor]]":
|
return pytype.replace("Tensor", "int")
|
||||||
return "List[Optional[List[int]]]"
|
return _pytype_to_fn_pytype_common(pytype)
|
||||||
# Generators don't contribute to shapes, and are not scriptable currently.
|
|
||||||
# So just hack them to be passed as "Any".
|
def _pytype_to_decomposition_fn_pytype(pytype: str) -> str:
|
||||||
if pytype == "Generator":
|
"""Convert a JitOperator pytype to the type relevant in decomposition functions.
|
||||||
return "Any"
|
"""
|
||||||
if pytype == "Optional[Generator]":
|
return _pytype_to_fn_pytype_common(pytype)
|
||||||
return "Any"
|
|
||||||
return pytype
|
|
||||||
|
|
||||||
class JitOperator:
|
class JitOperator:
|
||||||
"""Information about a single registered `torch::jit::Operator`"""
|
"""Information about a single registered `torch::jit::Operator`"""
|
||||||
|
@ -142,6 +182,23 @@ class JitOperator:
|
||||||
cpp_class_name = cpp_class_name.lstrip("_")
|
cpp_class_name = cpp_class_name.lstrip("_")
|
||||||
return op_name, cpp_class_name
|
return op_name, cpp_class_name
|
||||||
|
|
||||||
|
def _get_function_signature(self, function_kind: str,
|
||||||
|
parameter_decl_builder: Callable["SIG_ATTR_TYPE", str],
|
||||||
|
ret_decl_builder: Callable["SIG_ATTR_TYPE", str]) -> str:
|
||||||
|
mlir_op_name, _ = self.get_mlir_names()
|
||||||
|
# Replace `.` with a valid Python identifier character.
|
||||||
|
# `〇` vaguely looks like `.`.
|
||||||
|
def_name = "〇".join(mlir_op_name.split("."))
|
||||||
|
def_name += f"〡{function_kind}"
|
||||||
|
parameter_decls = list(map(parameter_decl_builder, self.arguments))
|
||||||
|
ret_decls = list(map(ret_decl_builder, self.returns))
|
||||||
|
parameters = ", ".join(parameter_decls)
|
||||||
|
result = ", ".join(ret_decls)
|
||||||
|
if len(ret_decls) >= 2:
|
||||||
|
result = f"Tuple[{result}]"
|
||||||
|
|
||||||
|
return f"def {def_name}({parameters}) -> {result}:"
|
||||||
|
|
||||||
def get_shape_function_signature(self):
|
def get_shape_function_signature(self):
|
||||||
"""Gets the Python function signature for this op's shape function.
|
"""Gets the Python function signature for this op's shape function.
|
||||||
|
|
||||||
|
@ -150,45 +207,66 @@ class JitOperator:
|
||||||
ops have extra default arguments and stuff that are tedious to write out
|
ops have extra default arguments and stuff that are tedious to write out
|
||||||
right.
|
right.
|
||||||
"""
|
"""
|
||||||
mlir_op_name, _ = self.get_mlir_names()
|
def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
|
||||||
# Replace `.` with a valid Python identifier character.
|
|
||||||
# `〇` vaguely looks like `.`.
|
|
||||||
def_name = "〇".join(mlir_op_name.split("."))
|
|
||||||
parameter_decls = []
|
|
||||||
for arg in self.arguments:
|
|
||||||
pytype = _pytype_to_shape_fn_pytype(arg["pytype"])
|
pytype = _pytype_to_shape_fn_pytype(arg["pytype"])
|
||||||
default = ""
|
default = _get_default_value(arg)
|
||||||
if "default_debug" in arg:
|
parameter_name = _rename_python_keyword_parameter_name(arg["name"])
|
||||||
if "List" in arg["pytype"]:
|
return f"{parameter_name}: {pytype}{default}"
|
||||||
# TorchScript doesn't allow lists as default parameters due
|
|
||||||
# to the weird Python semantics of mutable default
|
|
||||||
# arguments. So munge them into tuples, which work
|
|
||||||
# fine here. We only need these to simplify the invocation
|
|
||||||
# of the shape functions as valid Python for testing against
|
|
||||||
# the real ops, and tuples work fine in all the places this
|
|
||||||
# kicks in (e.g. conv dilations -- we aren't mutating those
|
|
||||||
# lists).
|
|
||||||
default_debug = arg["default_debug"].replace(
|
|
||||||
'[', '(').replace(']', ')')
|
|
||||||
elif arg["pytype"] == "str":
|
|
||||||
default_debug = repr(arg["default_debug"]).replace("'", '"')
|
|
||||||
else:
|
|
||||||
default_debug = arg["default_debug"]
|
|
||||||
default = f" = {default_debug}"
|
|
||||||
parameter_name = arg["name"]
|
|
||||||
if parameter_name == "from":
|
|
||||||
parameter_name = "from_" # Avoid using a Python keyword.
|
|
||||||
parameter_decls.append(f"{parameter_name}: {pytype}{default}")
|
|
||||||
ret_decls = []
|
|
||||||
for ret in self.returns:
|
|
||||||
pytype = _pytype_to_shape_fn_pytype(ret["pytype"])
|
|
||||||
ret_decls.append(f"{pytype}")
|
|
||||||
parameters = ", ".join(parameter_decls)
|
|
||||||
result = ", ".join(ret_decls)
|
|
||||||
if len(ret_decls) >= 2:
|
|
||||||
result = f"Tuple[{result}]"
|
|
||||||
|
|
||||||
return f"def {def_name}({parameters}) -> {result}:"
|
def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
|
||||||
|
return _pytype_to_shape_fn_pytype(arg["pytype"])
|
||||||
|
|
||||||
|
return self._get_function_signature(
|
||||||
|
"shape", parameter_decl_builder, ret_decl_builder)
|
||||||
|
|
||||||
|
def get_dtype_function_signature(self):
|
||||||
|
"""Gets the Python function signature for this op's dtype function.
|
||||||
|
|
||||||
|
While this is technically debug-only output, it is useful to copy-paste
|
||||||
|
it from the debug dump into the library definitions, as many
|
||||||
|
ops have extra default arguments and stuff that are tedious to write out
|
||||||
|
right.
|
||||||
|
"""
|
||||||
|
def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
|
||||||
|
pytype = _pytype_to_dtype_fn_pytype(arg["pytype"])
|
||||||
|
default = _get_default_value(arg)
|
||||||
|
parameter_name = _rename_python_keyword_parameter_name(arg["name"])
|
||||||
|
if "Tensor" in arg["pytype"]:
|
||||||
|
return ", ".join([f"{parameter_name}_rank: {pytype}{default}",
|
||||||
|
f"{parameter_name}_dtype: {pytype}{default}"])
|
||||||
|
return f"{parameter_name}: {pytype}{default}"
|
||||||
|
|
||||||
|
def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
|
||||||
|
# The dtype function is expected to return dtypes for scalar
|
||||||
|
# results of type `number`. Here we handle this case because
|
||||||
|
# `_pytype_to_dtype_fn_pytype` will replace `number` with
|
||||||
|
# `Union[int, float]`.
|
||||||
|
if arg["pytype"] == "number":
|
||||||
|
return "int"
|
||||||
|
return _pytype_to_dtype_fn_pytype(arg["pytype"])
|
||||||
|
|
||||||
|
return self._get_function_signature(
|
||||||
|
"dtype", parameter_decl_builder, ret_decl_builder)
|
||||||
|
|
||||||
|
def get_decomposition_function_signature(self):
|
||||||
|
"""Gets the Python function signature for this op's decomposition function.
|
||||||
|
|
||||||
|
While this is technically debug-only output, it is useful to copy-paste
|
||||||
|
it from the debug dump into the shape library definitions, as many
|
||||||
|
ops have extra default arguments and stuff that are tedious to write out
|
||||||
|
right.
|
||||||
|
"""
|
||||||
|
def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
|
||||||
|
pytype = _pytype_to_decomposition_fn_pytype(arg["pytype"])
|
||||||
|
default = _get_default_value(arg)
|
||||||
|
parameter_name = _rename_python_keyword_parameter_name(arg["name"])
|
||||||
|
return f"{parameter_name}: {pytype}{default}"
|
||||||
|
|
||||||
|
def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str:
|
||||||
|
return _pytype_to_decomposition_fn_pytype(arg["pytype"])
|
||||||
|
|
||||||
|
return self._get_function_signature(
|
||||||
|
"decomposition", parameter_decl_builder, ret_decl_builder)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
f = io.StringIO()
|
f = io.StringIO()
|
||||||
|
@ -212,6 +290,10 @@ class JitOperator:
|
||||||
p(f"is_mutable = {self.is_mutable}")
|
p(f"is_mutable = {self.is_mutable}")
|
||||||
if any(ret["type"] == "Tensor" for ret in self.returns):
|
if any(ret["type"] == "Tensor" for ret in self.returns):
|
||||||
p(f"shape_function_signature = {self.get_shape_function_signature()}")
|
p(f"shape_function_signature = {self.get_shape_function_signature()}")
|
||||||
|
p(f"decomposition_function_signature = {self.get_decomposition_function_signature()}")
|
||||||
|
if any(ret["type"] in ["Tensor", "Scalar"] for ret in self.returns):
|
||||||
|
p(f"dtype_function_signature = {self.get_dtype_function_signature()}")
|
||||||
|
|
||||||
if not self.arguments:
|
if not self.arguments:
|
||||||
p("arguments = []")
|
p("arguments = []")
|
||||||
else:
|
else:
|
||||||
|
@ -300,12 +382,13 @@ class Registry:
|
||||||
return self.by_triple[key]
|
return self.by_triple[key]
|
||||||
|
|
||||||
|
|
||||||
# A List[Dict[str, _]] mapping attribute names to:
|
# A Dict[str, _] mapping attribute names to:
|
||||||
# - str (e.g. {'name': 'dim'} )
|
# - str (e.g. {'name': 'dim'} )
|
||||||
# - int (e.g. {'N': 1} )
|
# - int (e.g. {'N': 1} )
|
||||||
# - Dict[str, List[str]]
|
# - Dict[str, List[str]]
|
||||||
# (e.g. {'alias_info': {'before': ['alias::a'], 'after': ['alias::a']}} )
|
# (e.g. {'alias_info': {'before': ['alias::a'], 'after': ['alias::a']}} )
|
||||||
SIGLIST_TYPE = List[Dict[str, Union[str, int, Dict[str, List[str]]]]]
|
SIG_ATTR_TYPE = Dict[str, Union[str, int, Dict[str, List[str]]]]
|
||||||
|
SIGLIST_TYPE = List[SIG_ATTR_TYPE]
|
||||||
# A Dict[str, _] describing a registered op. Each field is either
|
# A Dict[str, _] describing a registered op. Each field is either
|
||||||
# - bool (e.g. {'is_mutable': False} )
|
# - bool (e.g. {'is_mutable': False} )
|
||||||
# - Tuple[str] (e.g. {'name': ('aten::size', 'int')} )
|
# - Tuple[str] (e.g. {'name': ('aten::size', 'int')} )
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,315 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
|
from typing import Any, List, Iterable, Optional, Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
# Shape, dtype, and decomposition function testing infrastructure.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
# We expect all functions to be adequately tested. For functions
|
||||||
|
# implemented with upstream helpers, additional testing is usually not needed.
|
||||||
|
# But for functions that are authored/maintained by the Torch-MLIR
|
||||||
|
# project, we expect adequate testing.
|
||||||
|
#
|
||||||
|
# To do this, we provide decorators
|
||||||
|
# - `@check_shape_function`
|
||||||
|
# - `@check_dtype_function`
|
||||||
|
# - `@check_decomposition_function`
|
||||||
|
# which can be used to specify a series of operator invocations (such as "call
|
||||||
|
# this operator with two arguments -- a first tensor of size [2, 3] and a second
|
||||||
|
# tensor of size [3, 4]"). These tests are then run as part of this script, and
|
||||||
|
# any mismatches from the real op's behavior will be reported.
|
||||||
|
#
|
||||||
|
# A typical use of the decorator might look like:
|
||||||
|
# ```
|
||||||
|
# @check_shape_function([
|
||||||
|
# Invocation(TensorOfShape(2, 3, 4)), # Basic case.
|
||||||
|
# Invocation(TensorOfShape(2, 3, 4), dim=0), # Test explicit `dim`.
|
||||||
|
# Invocation(TensorOfShape(2, 3, 4), dim=0, keepdim=True), # `keepdim`.
|
||||||
|
# Invocation(TensorOfShape(2, 3, 4), dim=-3), # Negative `dim`.
|
||||||
|
# Invocation(TensorOfShape(2, 3, 4), dim=2), # Maximum valid `dim`.
|
||||||
|
# ErrorInvocation(TensorOfShape(2, 3, 4), dim=-4), # `dim` out of bounds.
|
||||||
|
# ErrorInvocation(TensorOfShape(2, 3, 4), dim=3), # `dim` out of bounds.
|
||||||
|
# ])
|
||||||
|
# ```
|
||||||
|
# Each `Invocation` takes a list of args/kwargs which will be passed to both the
|
||||||
|
# shape function and the real op and the results compared.
|
||||||
|
# We expect both the successful and error cases to be tested.
|
||||||
|
#
|
||||||
|
# The typical iteration flow is to add invocations to the list and then re-run
|
||||||
|
# `build_tools/update_abstract_interp_lib.sh` to re-run the tests.
|
||||||
|
|
||||||
|
class TensorOfShape:
|
||||||
|
"""Symbolic placeholder for a tensor argument to an operation.
|
||||||
|
|
||||||
|
Shape functions take tensor arguments as `List[int]`, whereas the real ops
|
||||||
|
take them as `Tensor`, so we need a symbolic representation of a tensor
|
||||||
|
argument to an op in order to represent an invocation that can drive both
|
||||||
|
the shape function and the real op (see `Invocation`).
|
||||||
|
|
||||||
|
A plain list doesn't work, because plain lists are actually legal arguments
|
||||||
|
to a shape function (e.g. conv dilations), and we don't want them to receive
|
||||||
|
this special treatment.
|
||||||
|
|
||||||
|
This class also tracks a dtype of the tensor, since some ops require a
|
||||||
|
specific dtype.
|
||||||
|
"""
|
||||||
|
def __init__(self, *shape: int, dtype: torch.dtype = torch.float32):
|
||||||
|
self.shape = list(shape)
|
||||||
|
self.dtype = dtype
|
||||||
|
def __repr__(self):
|
||||||
|
args_str = ", ".join(repr(x) for x in self.shape)
|
||||||
|
if self.dtype is torch.float32:
|
||||||
|
return f"TensorOfShape({args_str})"
|
||||||
|
else:
|
||||||
|
return f"TensorOfShape({args_str}, dtype={self.dtype})"
|
||||||
|
|
||||||
|
def LongTensorOfShape(*args, **kwargs):
|
||||||
|
"""Helper for indicating a TensorOfShape with integer type."""
|
||||||
|
return TensorOfShape(*args, **kwargs, dtype=torch.long)
|
||||||
|
|
||||||
|
def NonZeroDTensorWithDtype(dtype):
|
||||||
|
"""Helper for indicating a non-zero dim tensor with custom type."""
|
||||||
|
return TensorOfShape(1, dtype=dtype)
|
||||||
|
|
||||||
|
def ZeroDTensorWithDtype(dtype):
|
||||||
|
"""Helper for indicating a zero dim tensor with custom type."""
|
||||||
|
return TensorOfShape(dtype=dtype)
|
||||||
|
|
||||||
|
def _recursively_transform_tensor_args(
|
||||||
|
o: Any,
|
||||||
|
tensor_transformer: Callable[[TensorOfShape], Any]) -> Any:
|
||||||
|
"""Replace `TensorOfShape` with the result of `tensor_transformer`"""
|
||||||
|
if o is None or isinstance(o, (float, int)):
|
||||||
|
return o
|
||||||
|
if isinstance(o, TensorOfShape):
|
||||||
|
return tensor_transformer(o)
|
||||||
|
if isinstance(o, list):
|
||||||
|
return [_recursively_transform_tensor_args(x, tensor_transformer) for x in o]
|
||||||
|
if isinstance(o, tuple):
|
||||||
|
return tuple(_recursively_transform_tensor_args(x, tensor_transformer) for x in o)
|
||||||
|
raise Exception(f"Unhandled type {type(o)}")
|
||||||
|
|
||||||
|
def _convert_to_dtype_function_args(arguments: Iterable[Any]) -> List[Any]:
|
||||||
|
"""Converts an Invocation argument to a dtype function argument.
|
||||||
|
|
||||||
|
TensorOfShape is replaced with two ints representing the rank
|
||||||
|
and dtype of the tensor, respectively.
|
||||||
|
"""
|
||||||
|
def contains_tensor(o: Any) -> bool:
|
||||||
|
if o is None or isinstance(o, (float, int)):
|
||||||
|
return False
|
||||||
|
if isinstance(o, TensorOfShape):
|
||||||
|
return True
|
||||||
|
if isinstance(o, (list, tuple)):
|
||||||
|
for elem in o:
|
||||||
|
if contains_tensor(elem):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
raise Exception(f"Unhandled type {type(o)}")
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for arg in arguments:
|
||||||
|
if contains_tensor(arg):
|
||||||
|
rank_arg = _recursively_transform_tensor_args(
|
||||||
|
arg, lambda x: len(x.shape))
|
||||||
|
dtype_arg = _recursively_transform_tensor_args(
|
||||||
|
arg, lambda x: x.dtype)
|
||||||
|
result += [rank_arg, dtype_arg]
|
||||||
|
else:
|
||||||
|
result.append(arg)
|
||||||
|
return result
|
||||||
|
|
||||||
|
class Invocation:
|
||||||
|
"""Representation of a single op invocation (i.e. list of args to the op).
|
||||||
|
|
||||||
|
This class is used to represent a single invocation of an op in a way that
|
||||||
|
we can use to both invoke the abstract interpretation function and invoke
|
||||||
|
the actual op, which have slightly different signatures.
|
||||||
|
|
||||||
|
Specifically, this class has special knowledge of `TensorOfShape` and
|
||||||
|
translates it appropriately to either a tensor (for the real op), a
|
||||||
|
`List[int]` for the shape function, and two `int`s representing
|
||||||
|
the tensor rank and dtype in the case of a dtype function.
|
||||||
|
|
||||||
|
This class also tracks whether the invocation is expected to raise an
|
||||||
|
exception for greater precision when interpreting errors raised during
|
||||||
|
testing.
|
||||||
|
"""
|
||||||
|
def __init__(self, *args: Any, **kwargs: Any):
|
||||||
|
self.args = list(args)
|
||||||
|
# We assume kwargs don't contain tensors, so they don't need any
|
||||||
|
# special handling.
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
def is_expected_to_raise_exception(self) -> bool:
|
||||||
|
"""Returns true if the invocation is expected to raise an exception.
|
||||||
|
|
||||||
|
The subclass ErrorInvocation overrides this to indicate an Invocation
|
||||||
|
that is expected to raise an exception.
|
||||||
|
"""
|
||||||
|
return False
|
||||||
|
|
||||||
|
def to_shape_function_args(self):
|
||||||
|
"""Gets positional arguments appropriate for a shape function."""
|
||||||
|
# Make a copy of the size list, since a shape function might
|
||||||
|
# modify it in-place. In the compiler, the lowering always
|
||||||
|
# produces a new list via a fresh invocation of `AtenSizeOp`,
|
||||||
|
# which allocates a new, unaliased list. So in-place mutations
|
||||||
|
# are ok since they make it a bit easier to write some shape
|
||||||
|
# functions.
|
||||||
|
tensor_transformer = lambda o: list(o.shape)
|
||||||
|
return _recursively_transform_tensor_args(
|
||||||
|
self.args, tensor_transformer)
|
||||||
|
|
||||||
|
def to_dtype_function_args(self):
|
||||||
|
"""Gets positional arguments appropriate for a dtype function."""
|
||||||
|
return _convert_to_dtype_function_args(self.args)
|
||||||
|
|
||||||
|
def to_real_op_args(self):
|
||||||
|
"""Gets positional arguments appropriate for the real op."""
|
||||||
|
tensor_transformer = lambda o: torch.ones(o.shape, dtype=o.dtype)
|
||||||
|
return _recursively_transform_tensor_args(self.args, tensor_transformer)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
args_str = ", ".join(repr(x) for x in self.args)
|
||||||
|
kwargs_str = ""
|
||||||
|
if self.kwargs:
|
||||||
|
kwargs_str = ", " + ", ".join(f"{k}={v}" for k, v in self.kwargs.items())
|
||||||
|
return f"Invocation({args_str}{kwargs_str})"
|
||||||
|
|
||||||
|
class ErrorInvocation(Invocation):
|
||||||
|
"""An Invocation that raises an exception.
|
||||||
|
|
||||||
|
Explicitly knowing that an invocation is expected to raise an exception
|
||||||
|
avoids certain failure modes of the test infrastructure where a bug
|
||||||
|
slips through when both the abstract interpretation function and the real
|
||||||
|
op raise exceptions due to independent bugs (that cancel each other out and
|
||||||
|
spurioiusly make the two appear to "agree" that an exception needs to be
|
||||||
|
raised).
|
||||||
|
"""
|
||||||
|
def is_expected_to_raise_exception(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _normalize_multiple_results_to_list(t: Any):
|
||||||
|
"""Returns a flat list of results.
|
||||||
|
|
||||||
|
This normalizes the fact that Python represents multiple returns with a
|
||||||
|
tuple, but single returns as a single value. We just want a list with
|
||||||
|
N elements for N results.
|
||||||
|
"""
|
||||||
|
if isinstance(t, tuple):
|
||||||
|
return list(t)
|
||||||
|
# Shape functions return List[int] instead of tensors.
|
||||||
|
if isinstance(t, (Tensor, list, torch.dtype, int, float)):
|
||||||
|
return [t]
|
||||||
|
raise ValueError(f"Unexpected type {type(t)}")
|
||||||
|
|
||||||
|
def _report(f, invocation: Invocation, error_message: str):
|
||||||
|
fn_type = f.__name__.split("〡")[-1]
|
||||||
|
raise ValueError(f"For {fn_type} function {f.__name__!r} with invocation {invocation}: {error_message}")
|
||||||
|
|
||||||
|
def _get_fn_and_golden_results(f, invocation: List[Invocation]):
|
||||||
|
"""Run the invocation on the library function and torch op.
|
||||||
|
|
||||||
|
If no unexpected errors are detected, returns a tuple wth the first
|
||||||
|
element being the results from the library function and the second
|
||||||
|
element being the results from the torch op. The results will be `None`
|
||||||
|
if the library function and torch op expectedly result in errors.
|
||||||
|
"""
|
||||||
|
fn_name_without_fn_type, fn_type = f.__name__.split("〡")
|
||||||
|
fn_name_parts = fn_name_without_fn_type.split("〇")
|
||||||
|
ns, unqual = fn_name_parts[:2]
|
||||||
|
overload = "default" if len(fn_name_parts) != 3 else fn_name_parts[-1]
|
||||||
|
op = getattr(getattr(getattr(torch.ops, ns), unqual), overload)
|
||||||
|
fn_error, op_error, fn_results, golden_results = None, None, None, None
|
||||||
|
try:
|
||||||
|
fn_results = _normalize_multiple_results_to_list(f(
|
||||||
|
*(getattr(invocation, f"to_{fn_type}_function_args")()),
|
||||||
|
**invocation.kwargs))
|
||||||
|
except Exception as e:
|
||||||
|
fn_error = f"{e}"
|
||||||
|
try:
|
||||||
|
golden_results = _normalize_multiple_results_to_list(op(
|
||||||
|
*invocation.to_real_op_args(),
|
||||||
|
**invocation.kwargs))
|
||||||
|
except Exception as e:
|
||||||
|
op_error = f"{e}"
|
||||||
|
|
||||||
|
# Check for error behavior.
|
||||||
|
if invocation.is_expected_to_raise_exception():
|
||||||
|
if fn_error is None and op_error is None:
|
||||||
|
_report(f, invocation, f"Expected to raise an exception, but neither {fn_type} function n or op raised an exception")
|
||||||
|
if fn_error is None:
|
||||||
|
_report(f, invocation, f"Op raised error {op_error!r}, but shape function did not.")
|
||||||
|
if op_error is None:
|
||||||
|
_report(f, invocation, f"{fn_type} function raised error {fn_error!r}, but op did not.")
|
||||||
|
else:
|
||||||
|
if fn_error is not None and op_error is not None:
|
||||||
|
_report(f, invocation, f"Both {fn_type} function and op raised errors, but were not expected to. {fn_type} function raised error {fn_error!r} and op raised error {op_error!r}.")
|
||||||
|
if fn_error is not None:
|
||||||
|
_report(f, invocation, f"{fn_type} function raised error {fn_error!r} but op did not raise any error.")
|
||||||
|
if op_error is not None:
|
||||||
|
_report(f, invocation, f"Op raised error {op_error!r} but {fn_type} function did not raise any error.")
|
||||||
|
|
||||||
|
return fn_results, golden_results
|
||||||
|
|
||||||
|
def check_shape_function(invocations: List[Invocation]):
|
||||||
|
"""Decorator that automatically tests a shape function.
|
||||||
|
|
||||||
|
The shape function, which is expected to be named systematically with
|
||||||
|
`〇` instead of `.`, is tested against the corresponding op in
|
||||||
|
`torch.ops.*` function using the given invocations.
|
||||||
|
"""
|
||||||
|
def decorator(f):
|
||||||
|
for invocation in invocations:
|
||||||
|
result_shapes, golden_results = _get_fn_and_golden_results(f, invocation)
|
||||||
|
if invocation.is_expected_to_raise_exception():
|
||||||
|
continue
|
||||||
|
# Check for matching results.
|
||||||
|
if len(result_shapes) != len(golden_results):
|
||||||
|
_report(f, invocation, f"Expected {len(golden_results)} result shapes, got {len(result_shapes)}")
|
||||||
|
for result_shape, golden_result in zip(result_shapes, golden_results):
|
||||||
|
result_rank = len(result_shape)
|
||||||
|
golden_rank = len(golden_result.shape)
|
||||||
|
if result_rank != golden_rank:
|
||||||
|
_report(f, invocation, f"Expected result rank {golden_rank}, got {result_rank}")
|
||||||
|
for dimension_size, golden_dimension_size in zip(result_shape, golden_result.shape):
|
||||||
|
if dimension_size != golden_dimension_size:
|
||||||
|
_report(f, invocation, f"Expected result shape {golden_result.shape}, got {result_shape}")
|
||||||
|
return f
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
def check_dtype_function(invocations: List[Invocation]):
|
||||||
|
"""Decorator that automatically tests a dtype function.
|
||||||
|
|
||||||
|
The dtype function, which is expected to be named systematically with
|
||||||
|
`〇` instead of `.`, is tested against the corresponding op in
|
||||||
|
`torch.ops.*` function using the given invocations.
|
||||||
|
"""
|
||||||
|
def decorator(f):
|
||||||
|
for invocation in invocations:
|
||||||
|
result_dtypes, golden_results = _get_fn_and_golden_results(f, invocation)
|
||||||
|
if invocation.is_expected_to_raise_exception():
|
||||||
|
continue
|
||||||
|
|
||||||
|
if len(result_dtypes) != len(golden_results):
|
||||||
|
_report(f, invocation, f"Expected {len(golden_results)} result dtypes, got {len(result_dtypes)}")
|
||||||
|
for result_dtype, golden_result in zip(result_dtypes, golden_results):
|
||||||
|
if isinstance(golden_result, torch.Tensor):
|
||||||
|
golden_dtype = golden_result.dtype
|
||||||
|
elif isinstance(golden_result, (int, float)):
|
||||||
|
# Turn Python type to PyTorch dtype
|
||||||
|
golden_dtype = torch.tensor([]).to(type(golden_result)).dtype
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unhandled return type {type(golden_result)}")
|
||||||
|
if result_dtype != golden_dtype:
|
||||||
|
_report(f, invocation, f"Expected result dtype {golden_dtype}, got {result_dtype}")
|
||||||
|
return f
|
||||||
|
return decorator
|
|
@ -327,13 +327,26 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
|
||||||
auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) {
|
auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) {
|
||||||
return getMlirTypeFromTorchType(loc, v->type(), importOptions);
|
return getMlirTypeFromTorchType(loc, v->type(), importOptions);
|
||||||
});
|
});
|
||||||
MlirOperation operation = createMlirOperationAtEnd(
|
std::string functionName = node->input(0)->node()->s(c10::attr::name);
|
||||||
appendToBlock, "func.call_indirect", loc,
|
std::vector<MlirType> resultTypes =
|
||||||
getMlirTypesFromValues(loc, node->outputs(), importOptions),
|
getMlirTypesFromValues(loc, node->outputs(), importOptions);
|
||||||
lookupMappedValue(node->input(0)),
|
std::vector<MlirValue> adjustedFuncArgs = adjustStaticInformationForValues(
|
||||||
adjustStaticInformationForValues(
|
|
||||||
appendToBlock, loc, lookupMappedValues(node->inputs().slice(1)),
|
appendToBlock, loc, lookupMappedValues(node->inputs().slice(1)),
|
||||||
expectedTypes, /*userAllowsRefinement=*/false));
|
expectedTypes, /*userAllowsRefinement=*/false);
|
||||||
|
MlirOperation operation;
|
||||||
|
// `__torch_mlir_internal_promote_dtypes` is a special python function that
|
||||||
|
// users can use in dtype refinement function definitions to get the
|
||||||
|
// promoted result dtype for a PyTorch computation. Here we turn the call to
|
||||||
|
// this function to the torch dialect equivalent op `torch.promote_dtypes`.
|
||||||
|
if (functionName == "__torch_mlir_internal_promote_dtypes") {
|
||||||
|
operation =
|
||||||
|
createMlirOperationAtEnd(appendToBlock, "torch.promote_dtypes", loc,
|
||||||
|
resultTypes, adjustedFuncArgs);
|
||||||
|
} else {
|
||||||
|
operation = createMlirOperationAtEnd(
|
||||||
|
appendToBlock, "func.call_indirect", loc, resultTypes,
|
||||||
|
lookupMappedValue(node->input(0)), adjustedFuncArgs);
|
||||||
|
}
|
||||||
mapResults(node, operation);
|
mapResults(node, operation);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
// RUN: torch-mlir-opt -torch-drop-shape-calculations -split-input-file %s | FileCheck %s
|
// RUN: torch-mlir-opt -torch-drop-abstract-interp-calculations -split-input-file %s | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @basic(
|
// CHECK-LABEL: func.func @basic$shape_calculate(
|
||||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,?],unk>) -> !torch.vtensor {
|
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,?],unk>) -> !torch.vtensor {
|
||||||
// CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG]] : !torch.vtensor<[2,?],unk> -> !torch.vtensor<[2,?],unk>
|
// CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG]] : !torch.vtensor<[2,?],unk> -> !torch.vtensor<[2,?],unk>
|
||||||
// CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[TANH]] : !torch.vtensor<[2,?],unk> to !torch.vtensor
|
// CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[TANH]] : !torch.vtensor<[2,?],unk> to !torch.vtensor
|
||||||
// CHECK: return %[[ERASED]] : !torch.vtensor
|
// CHECK: return %[[ERASED]] : !torch.vtensor
|
||||||
func.func @basic(%arg0: !torch.vtensor<[2,?],unk>) -> !torch.vtensor {
|
func.func @basic$shape_calculate(%arg0: !torch.vtensor<[2,?],unk>) -> !torch.vtensor {
|
||||||
%int2 = torch.constant.int 2
|
%int2 = torch.constant.int 2
|
||||||
%int1 = torch.constant.int 1
|
%int1 = torch.constant.int 1
|
||||||
%0 = torch.shape.calculate {
|
%0 = torch.shape.calculate {
|
||||||
|
@ -22,6 +22,26 @@ func.func @basic(%arg0: !torch.vtensor<[2,?],unk>) -> !torch.vtensor {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @basic$dtype_calculate(
|
||||||
|
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
||||||
|
// CHECK: %[[INT_6:.*]] = torch.constant.int 6
|
||||||
|
// CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
||||||
|
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TANH]] : !torch.vtensor<*,f32> to !torch.vtensor
|
||||||
|
// CHECK: return %[[CAST]] : !torch.vtensor
|
||||||
|
func.func @basic$dtype_calculate(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
||||||
|
%int6 = torch.constant.int 6
|
||||||
|
%0 = torch.dtype.calculate {
|
||||||
|
%2 = torch.aten.tanh %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
||||||
|
torch.dtype.calculate.yield %2 : !torch.vtensor<*,f32>
|
||||||
|
} dtypes {
|
||||||
|
torch.dtype.calculate.yield.dtypes %int6 : !torch.int
|
||||||
|
} : !torch.vtensor<*,f32>
|
||||||
|
%1 = torch.tensor_static_info_cast %0 : !torch.vtensor<*,f32> to !torch.vtensor
|
||||||
|
return %1 : !torch.vtensor
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @shape_calc_in_loop(
|
// CHECK-LABEL: func.func @shape_calc_in_loop(
|
||||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,?],unk>) -> !torch.vtensor<[2,?],unk> {
|
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,?],unk>) -> !torch.vtensor<[2,?],unk> {
|
||||||
func.func @shape_calc_in_loop(%arg: !torch.vtensor<[2,?],unk>) -> !torch.vtensor<[2,?],unk> {
|
func.func @shape_calc_in_loop(%arg: !torch.vtensor<[2,?],unk>) -> !torch.vtensor<[2,?],unk> {
|
|
@ -142,6 +142,22 @@ func.func @shape_calculations(%arg0: !torch.vtensor) -> !torch.vtensor {
|
||||||
return %0 : !torch.vtensor
|
return %0 : !torch.vtensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func.func @dtype_calculations(%arg0: !torch.vtensor) -> !torch.vtensor {
|
||||||
|
%0 = torch.dtype.calculate {
|
||||||
|
%1 = torch.aten.tanh %arg0 : !torch.vtensor -> !torch.vtensor
|
||||||
|
torch.dtype.calculate.yield %1 : !torch.vtensor
|
||||||
|
} dtypes {
|
||||||
|
%2 = torch.prim.dtype %arg0 : !torch.vtensor -> !torch.int
|
||||||
|
torch.dtype.calculate.yield.dtypes %2 : !torch.int
|
||||||
|
} : !torch.vtensor
|
||||||
|
return %0 : !torch.vtensor
|
||||||
|
}
|
||||||
|
|
||||||
|
func.func @promote_dtypes(%ranks: !torch.list<optional<int>>, %dtypes: !torch.list<int>) -> !torch.int {
|
||||||
|
%0 = torch.promote_dtypes %ranks, %dtypes : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int
|
||||||
|
return %0 : !torch.int
|
||||||
|
}
|
||||||
|
|
||||||
func.func @number_type_subtypes(%arg0: !torch.tensor, %arg1: !torch.list<int>, %arg2: !torch.union<float, int>) {
|
func.func @number_type_subtypes(%arg0: !torch.tensor, %arg1: !torch.list<int>, %arg2: !torch.union<float, int>) {
|
||||||
%0 = torch.aten.constant_pad_nd %arg0, %arg1, %arg2 : !torch.tensor, !torch.list<int>, !torch.union<float, int> -> !torch.tensor
|
%0 = torch.aten.constant_pad_nd %arg0, %arg1, %arg2 : !torch.tensor, !torch.list<int>, !torch.union<float, int> -> !torch.tensor
|
||||||
return
|
return
|
||||||
|
|
|
@ -1,109 +0,0 @@
|
||||||
// RUN: torch-mlir-opt -torch-refine-types -split-input-file %s | FileCheck %s
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @tensor_tensor$same_category_different_width(
|
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1],f32>,
|
|
||||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1],f64>,
|
|
||||||
// CHECK-SAME: %[[VAL_2:.*]]: !torch.float) {
|
|
||||||
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
|
|
||||||
// CHECK-SAME: !torch.vtensor<[1],f32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],f64>
|
|
||||||
// CHECK: return
|
|
||||||
func.func @tensor_tensor$same_category_different_width(%t0: !torch.vtensor<[1],f32>,
|
|
||||||
%t1: !torch.vtensor<[1],f64>,
|
|
||||||
%alpha: !torch.float) {
|
|
||||||
%1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],f32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],unk>
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @tensor_tensor$different_category(
|
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1],si32>,
|
|
||||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1],f64>,
|
|
||||||
// CHECK-SAME: %[[VAL_2:.*]]: !torch.float) {
|
|
||||||
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
|
|
||||||
// CHECK-SAME: !torch.vtensor<[1],si32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],f64>
|
|
||||||
// CHECK: return
|
|
||||||
func.func @tensor_tensor$different_category(%t0: !torch.vtensor<[1],si32>,
|
|
||||||
%t1: !torch.vtensor<[1],f64>,
|
|
||||||
%alpha: !torch.float) {
|
|
||||||
%1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],si32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],unk>
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @tensor_tensor$same_category_zero_rank_wider(
|
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1],f32>,
|
|
||||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[],f64>,
|
|
||||||
// CHECK-SAME: %[[VAL_2:.*]]: !torch.int) {
|
|
||||||
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
|
|
||||||
// CHECK-SAME: !torch.vtensor<[1],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[1],f32>
|
|
||||||
// CHECK: return
|
|
||||||
func.func @tensor_tensor$same_category_zero_rank_wider(
|
|
||||||
%t0: !torch.vtensor<[1],f32>,
|
|
||||||
%t1: !torch.vtensor<[],f64>,
|
|
||||||
%alpha: !torch.int) {
|
|
||||||
%1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[1],unk>
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @tensor_tensor$zero_rank_higher_category(
|
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1],si64>,
|
|
||||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[],f32>,
|
|
||||||
// CHECK-SAME: %[[VAL_2:.*]]: !torch.int) {
|
|
||||||
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
|
|
||||||
// CHECK-SAME: !torch.vtensor<[1],si64>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[1],f32>
|
|
||||||
// CHECK: return
|
|
||||||
func.func @tensor_tensor$zero_rank_higher_category(%t0: !torch.vtensor<[1],si64>,
|
|
||||||
%t1: !torch.vtensor<[],f32>,
|
|
||||||
%alpha: !torch.int) {
|
|
||||||
%1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],si64>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[1],unk>
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @tensor_tensor$alpha_wider_no_contribution(
|
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1],f32>, %[[VAL_1:.*]]: !torch.vtensor<[1],f32>,
|
|
||||||
// CHECK-SAME: %[[VAL_2:.*]]: !torch.float) {
|
|
||||||
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
|
|
||||||
// CHECK-SAME: !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[1],f32>
|
|
||||||
// CHECK: return
|
|
||||||
func.func @tensor_tensor$alpha_wider_no_contribution(%t0: !torch.vtensor<[1],f32>,
|
|
||||||
%t1: !torch.vtensor<[1],f32>,
|
|
||||||
%alpha: !torch.float) {
|
|
||||||
%1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[1],unk>
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @tensor_scalar$scalar_higher_category(
|
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1],si64>,
|
|
||||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.float,
|
|
||||||
// CHECK-SAME: %[[VAL_2:.*]]: !torch.int) {
|
|
||||||
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Scalar %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
|
|
||||||
// CHECK-SAME: !torch.vtensor<[1],si64>, !torch.float, !torch.int -> !torch.vtensor<[1],f32>
|
|
||||||
// CHECK: return
|
|
||||||
func.func @tensor_scalar$scalar_higher_category(%t0: !torch.vtensor<[1],si64>, %scalar: !torch.float, %alpha: !torch.int) {
|
|
||||||
%1 = torch.aten.add.Scalar %t0, %scalar, %alpha: !torch.vtensor<[1], si64>, !torch.float, !torch.int -> !torch.vtensor<[1],unk>
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @tensor_scalar$scalar_same_category_wider(
|
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1],si32>,
|
|
||||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.int,
|
|
||||||
// CHECK-SAME: %[[VAL_2:.*]]: !torch.int) {
|
|
||||||
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Scalar %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
|
|
||||||
// CHECK-SAME: !torch.vtensor<[1],si32>, !torch.int, !torch.int -> !torch.vtensor<[1],si32>
|
|
||||||
// CHECK: return
|
|
||||||
func.func @tensor_scalar$scalar_same_category_wider(%t0: !torch.vtensor<[1],si32>, %scalar: !torch.int, %alpha: !torch.int) {
|
|
||||||
%1 = torch.aten.add.Scalar %t0, %scalar, %alpha: !torch.vtensor<[1], si32>, !torch.int, !torch.int -> !torch.vtensor<[1],unk>
|
|
||||||
return
|
|
||||||
}
|
|
|
@ -121,14 +121,14 @@ func.func @prim.loop$region_arg_to_internal(%none: !torch.none) -> !torch.option
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @f
|
// CHECK-LABEL: func.func @f
|
||||||
// CHECK: %[[ATEN:.*]] = torch.aten.tanh %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32>
|
// CHECK: %[[ATEN:.*]] = torch.aten.cos %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32>
|
||||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ATEN]] : !torch.vtensor<*,f32> to !torch.vtensor
|
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ATEN]] : !torch.vtensor<*,f32> to !torch.vtensor
|
||||||
// CHECK: return %[[CAST]] : !torch.vtensor
|
// CHECK: return %[[CAST]] : !torch.vtensor
|
||||||
func.func @f(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
func.func @f(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
||||||
%cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor
|
%cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor
|
||||||
cf.br ^bb1(%cast: !torch.vtensor)
|
cf.br ^bb1(%cast: !torch.vtensor)
|
||||||
^bb1(%arg1: !torch.vtensor):
|
^bb1(%arg1: !torch.vtensor):
|
||||||
%1 = torch.aten.tanh %arg1 : !torch.vtensor -> !torch.vtensor
|
%1 = torch.aten.cos %arg1 : !torch.vtensor -> !torch.vtensor
|
||||||
return %1 : !torch.vtensor
|
return %1 : !torch.vtensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -136,11 +136,11 @@ func.func @f(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @f
|
// CHECK-LABEL: func.func @f
|
||||||
// CHECK: func.func private @callee
|
// CHECK: func.func private @callee
|
||||||
// CHECK-NEXT: torch.aten.tanh %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32>
|
// CHECK-NEXT: torch.aten.cos %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32>
|
||||||
func.func @f() {
|
func.func @f() {
|
||||||
builtin.module {
|
builtin.module {
|
||||||
func.func private @callee(%arg0: !torch.vtensor) {
|
func.func private @callee(%arg0: !torch.vtensor) {
|
||||||
%1 = torch.aten.tanh %arg0 : !torch.vtensor -> !torch.vtensor
|
%1 = torch.aten.cos %arg0 : !torch.vtensor -> !torch.vtensor
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
func.func @caller(%arg0: !torch.vtensor<*,f32>) {
|
func.func @caller(%arg0: !torch.vtensor<*,f32>) {
|
||||||
|
|
|
@ -9,36 +9,36 @@
|
||||||
// -----
|
// -----
|
||||||
// CHECK-LABEL: func.func @basic(
|
// CHECK-LABEL: func.func @basic(
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
||||||
// CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
// CHECK: %[[COS:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
||||||
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[TANH]] : !torch.vtensor<*,f32> to !torch.vtensor
|
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[COS]] : !torch.vtensor<*,f32> to !torch.vtensor
|
||||||
// CHECK: return %[[RESULT]] : !torch.vtensor
|
// CHECK: return %[[RESULT]] : !torch.vtensor
|
||||||
func.func @basic(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
func.func @basic(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
||||||
%1 = torch.aten.tanh %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
|
%1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
|
||||||
return %1 : !torch.vtensor
|
return %1 : !torch.vtensor
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
// CHECK-LABEL: func.func @keep_existing_shape_information(
|
// CHECK-LABEL: func.func @keep_existing_shape_information(
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> {
|
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> {
|
||||||
// CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<[2],f32>
|
// CHECK: %[[COS:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<[2],f32>
|
||||||
// CHECK: return %[[TANH]] : !torch.vtensor<[2],f32>
|
// CHECK: return %[[COS]] : !torch.vtensor<[2],f32>
|
||||||
func.func @keep_existing_shape_information(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> {
|
func.func @keep_existing_shape_information(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> {
|
||||||
%1 = torch.aten.tanh %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor<[2], f32>
|
%1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor<[2], f32>
|
||||||
return %1 : !torch.vtensor<[2],f32>
|
return %1 : !torch.vtensor<[2],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
// CHECK-LABEL: func.func @propagate_through_multiple_ops(
|
// CHECK-LABEL: func.func @propagate_through_multiple_ops(
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
||||||
// CHECK: %[[TANH0:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
// CHECK: %[[COS0:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
||||||
// CHECK: %[[TANH1:.*]] = torch.aten.tanh %[[TANH0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
// CHECK: %[[COS1:.*]] = torch.aten.cos %[[COS0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
||||||
// CHECK: %[[TANH2:.*]] = torch.aten.tanh %[[TANH1]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
// CHECK: %[[COS2:.*]] = torch.aten.cos %[[COS1]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
||||||
// CHECK: %[[TANH3:.*]] = torch.tensor_static_info_cast %[[TANH2]] : !torch.vtensor<*,f32> to !torch.vtensor
|
// CHECK: %[[COS3:.*]] = torch.tensor_static_info_cast %[[COS2]] : !torch.vtensor<*,f32> to !torch.vtensor
|
||||||
// CHECK: return %[[TANH3]] : !torch.vtensor
|
// CHECK: return %[[COS3]] : !torch.vtensor
|
||||||
func.func @propagate_through_multiple_ops(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
func.func @propagate_through_multiple_ops(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
||||||
%1 = torch.aten.tanh %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
|
%1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
|
||||||
%2 = torch.aten.tanh %1 : !torch.vtensor -> !torch.vtensor
|
%2 = torch.aten.cos %1 : !torch.vtensor -> !torch.vtensor
|
||||||
%3 = torch.aten.tanh %2 : !torch.vtensor -> !torch.vtensor
|
%3 = torch.aten.cos %2 : !torch.vtensor -> !torch.vtensor
|
||||||
return %3 : !torch.vtensor
|
return %3 : !torch.vtensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -47,99 +47,18 @@ func.func @propagate_through_multiple_ops(%arg0: !torch.vtensor<*,f32>) -> !torc
|
||||||
// refinement.
|
// refinement.
|
||||||
// CHECK-LABEL: func.func @mixed_allowing_not_allowing_type_refinement(
|
// CHECK-LABEL: func.func @mixed_allowing_not_allowing_type_refinement(
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> (!torch.vtensor, !torch.vtensor) {
|
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> (!torch.vtensor, !torch.vtensor) {
|
||||||
// CHECK: %[[TANH0:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
// CHECK: %[[COS0:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
||||||
// CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[TANH0]] : !torch.vtensor<*,f32> to !torch.vtensor
|
// CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[COS0]] : !torch.vtensor<*,f32> to !torch.vtensor
|
||||||
// CHECK: %[[TANH1:.*]] = torch.aten.tanh %[[TANH0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
// CHECK: %[[COS1:.*]] = torch.aten.cos %[[COS0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
||||||
// CHECK: return %[[ERASED]], %[[ERASED]] : !torch.vtensor, !torch.vtensor
|
// CHECK: return %[[ERASED]], %[[ERASED]] : !torch.vtensor, !torch.vtensor
|
||||||
func.func @mixed_allowing_not_allowing_type_refinement(%arg0: !torch.vtensor<*,f32>) -> (!torch.vtensor, !torch.vtensor) {
|
func.func @mixed_allowing_not_allowing_type_refinement(%arg0: !torch.vtensor<*,f32>) -> (!torch.vtensor, !torch.vtensor) {
|
||||||
%1 = torch.aten.tanh %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
|
%1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
|
||||||
%3 = torch.aten.tanh %1 : !torch.vtensor -> !torch.vtensor
|
%3 = torch.aten.cos %1 : !torch.vtensor -> !torch.vtensor
|
||||||
return %1, %1 : !torch.vtensor, !torch.vtensor
|
return %1, %1 : !torch.vtensor, !torch.vtensor
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
// CHECK-LABEL: func.func @type_promotion$same_category_different_width(
|
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],si32>,
|
|
||||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],si64>) -> !torch.vtensor<[?],unk> {
|
|
||||||
// CHECK: %[[ALPHA:.*]] = torch.constant.int 3
|
|
||||||
// CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[ARG0]], %[[ARG1]], %[[ALPHA]] : !torch.vtensor<[?],si32>, !torch.vtensor<[?],si64>, !torch.int -> !torch.vtensor<[?],si64>
|
|
||||||
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[ADD]] : !torch.vtensor<[?],si64> to !torch.vtensor<[?],unk>
|
|
||||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[?],unk>
|
|
||||||
func.func @type_promotion$same_category_different_width(%arg0: !torch.vtensor<[?],si32>, %arg1: !torch.vtensor<[?],si64>) -> !torch.vtensor<[?],unk> {
|
|
||||||
%int3 = torch.constant.int 3
|
|
||||||
%0 = torch.aten.add.Tensor %arg0, %arg1, %int3 : !torch.vtensor<[?],si32>, !torch.vtensor<[?],si64>, !torch.int -> !torch.vtensor<[?],unk>
|
|
||||||
return %0 : !torch.vtensor<[?],unk>
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
// CHECK-LABEL: func.func @type_promotion$different_category(
|
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],si64>,
|
|
||||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],unk> {
|
|
||||||
// CHECK: %[[ALPHA:.*]] = torch.constant.int 3
|
|
||||||
// CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[ARG0]], %[[ARG1]], %[[ALPHA]] : !torch.vtensor<[?],si64>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32>
|
|
||||||
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[ADD]] : !torch.vtensor<[?],f32> to !torch.vtensor<[?],unk>
|
|
||||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[?],unk>
|
|
||||||
func.func @type_promotion$different_category(%arg0: !torch.vtensor<[?],si64>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],unk> {
|
|
||||||
%int3 = torch.constant.int 3
|
|
||||||
%0 = torch.aten.add.Tensor %arg0, %arg1, %int3 : !torch.vtensor<[?],si64>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],unk>
|
|
||||||
return %0 : !torch.vtensor<[?],unk>
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
// CHECK-LABEL: func.func @type_promotion$same_category_zero_rank_wider(
|
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
|
|
||||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],f64>) -> !torch.vtensor<[?],unk> {
|
|
||||||
// CHECK: %[[ALPHA:.*]] = torch.constant.float 2.300000e+00
|
|
||||||
// CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[ARG0]], %[[ARG1]], %[[ALPHA]] : !torch.vtensor<[?],f32>, !torch.vtensor<[],f64>, !torch.float -> !torch.vtensor<[?],f32>
|
|
||||||
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[ADD]] : !torch.vtensor<[?],f32> to !torch.vtensor<[?],unk>
|
|
||||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[?],unk>
|
|
||||||
func.func @type_promotion$same_category_zero_rank_wider(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[],f64>) -> !torch.vtensor<[?],unk> {
|
|
||||||
%float2.300000e00 = torch.constant.float 2.300000e+00
|
|
||||||
%0 = torch.aten.add.Tensor %arg0, %arg1, %float2.300000e00 : !torch.vtensor<[?],f32>, !torch.vtensor<[],f64>, !torch.float -> !torch.vtensor<[?],unk>
|
|
||||||
return %0 : !torch.vtensor<[?],unk>
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
// CHECK-LABEL: func.func @type_promotion$zero_rank_higher_category(
|
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],si64>,
|
|
||||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],unk> {
|
|
||||||
// CHECK: %[[ALPHA:.*]] = torch.constant.int 2
|
|
||||||
// CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[ARG0]], %[[ARG1]], %[[ALPHA]] : !torch.vtensor<[?],si64>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[?],f32>
|
|
||||||
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[ADD]] : !torch.vtensor<[?],f32> to !torch.vtensor<[?],unk>
|
|
||||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[?],unk>
|
|
||||||
func.func @type_promotion$zero_rank_higher_category(%arg0: !torch.vtensor<[?],si64>, %arg1: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],unk> {
|
|
||||||
%int2 = torch.constant.int 2
|
|
||||||
%0 = torch.aten.add.Tensor %arg0, %arg1, %int2 : !torch.vtensor<[?],si64>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[?],unk>
|
|
||||||
return %0 : !torch.vtensor<[?],unk>
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
// CHECK-LABEL: func.func @type_promotion$alpha_wider(
|
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
|
|
||||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],unk> {
|
|
||||||
// CHECK: %[[ALPHA:.*]] = torch.constant.float 2.300000e+00
|
|
||||||
// CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[ARG0]], %[[ARG1]], %[[ALPHA]] : !torch.vtensor<[?],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[?],f32>
|
|
||||||
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[ADD]] : !torch.vtensor<[?],f32> to !torch.vtensor<[?],unk>
|
|
||||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[?],unk>
|
|
||||||
func.func @type_promotion$alpha_wider(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],unk> {
|
|
||||||
%float2.300000e00 = torch.constant.float 2.300000e+00
|
|
||||||
%0 = torch.aten.add.Tensor %arg0, %arg1, %float2.300000e00 : !torch.vtensor<[?],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[?],unk>
|
|
||||||
return %0 : !torch.vtensor<[?],unk>
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
// CHECK-LABEL: func.func @type_promotion_scalar_operation(
|
|
||||||
// CHECK-SAME: %[[FLOAT:.*]]: !torch.float,
|
|
||||||
// CHECK-SAME: %[[INT:.*]]: !torch.int) -> !torch.number {
|
|
||||||
// CHECK: %[[ADD:.*]] = torch.aten.add %[[FLOAT]], %[[INT]] : !torch.float, !torch.int -> !torch.float
|
|
||||||
// CHECK: %[[RET:.*]] = torch.derefine %[[ADD]] : !torch.float to !torch.number
|
|
||||||
// CHECK: return %[[RET]] : !torch.number
|
|
||||||
func.func @type_promotion_scalar_operation(%float: !torch.float, %int: !torch.int) -> !torch.number {
|
|
||||||
%ret = torch.aten.add %float, %int : !torch.float, !torch.int -> !torch.number
|
|
||||||
return %ret : !torch.number
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
// CHECK-LABEL: func.func @torch.overwrite.tensor.contents$dynamic_overwrites_static(
|
// CHECK-LABEL: func.func @torch.overwrite.tensor.contents$dynamic_overwrites_static(
|
||||||
// CHECK-SAME: %[[STATIC:.*]]: !torch.vtensor<[2],f32>,
|
// CHECK-SAME: %[[STATIC:.*]]: !torch.vtensor<[2],f32>,
|
||||||
// CHECK-SAME: %[[DYNAMIC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
|
// CHECK-SAME: %[[DYNAMIC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
|
||||||
|
|
|
@ -0,0 +1,56 @@
|
||||||
|
// RUN: torch-mlir-opt -torch-reify-dtype-calculations -split-input-file %s | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK: module {
|
||||||
|
// CHECK: func.func private @__torch_mlir_dtype_fn.aten.tanh(
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @basic(
|
||||||
|
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor {
|
||||||
|
// CHECK: %[[RESULT:.*]] = torch.dtype.calculate {
|
||||||
|
// CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG]] : !torch.vtensor -> !torch.vtensor
|
||||||
|
// CHECK: torch.dtype.calculate.yield %[[TANH]] : !torch.vtensor
|
||||||
|
// CHECK: } dtypes {
|
||||||
|
// CHECK: %[[SIZE:.*]] = torch.aten.size %[[ARG]] : !torch.vtensor -> !torch.list<int>
|
||||||
|
// CHECK: %[[RANK:.*]] = torch.aten.len.t %[[SIZE]] : !torch.list<int> -> !torch.int
|
||||||
|
// CHECK: %[[DTYPE:.*]] = torch.prim.dtype %[[ARG]] : !torch.vtensor -> !torch.int
|
||||||
|
// CHECK: %[[RESULT_DTYPE:.*]] = func.call @__torch_mlir_dtype_fn.aten.tanh(%[[RANK]], %[[DTYPE]]) : (!torch.int, !torch.int) -> !torch.int
|
||||||
|
// CHECK: torch.dtype.calculate.yield.dtypes %[[RESULT_DTYPE]] : !torch.int
|
||||||
|
// CHECK: } : !torch.vtensor
|
||||||
|
// CHECK: return %[[RESULT:.*]] : !torch.vtensor
|
||||||
|
func.func @basic(%arg0: !torch.vtensor) -> !torch.vtensor {
|
||||||
|
%0 = torch.aten.tanh %arg0 : !torch.vtensor -> !torch.vtensor
|
||||||
|
return %0 : !torch.vtensor
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func private @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(
|
||||||
|
// CHECK: {{.*}} = torch.promote_dtypes {{.*}} : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func private @__torch_mlir_dtype_fn.aten.floor_divide(
|
||||||
|
// CHECK: {{.*}} = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes({{.*}}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @op_with_dtype_promotion(
|
||||||
|
// CHECK: {{.*}} = func.call @__torch_mlir_dtype_fn.aten.floor_divide({{.*}}
|
||||||
|
func.func @op_with_dtype_promotion(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
|
||||||
|
%0 = torch.aten.floor_divide %arg0, %arg1 : !torch.vtensor, !torch.vtensor -> !torch.vtensor
|
||||||
|
return %0 : !torch.vtensor
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func private @__torch_mlir_dtype_fn.aten.floor_divide(
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @turn_tensors_into_rank_and_dtype_args(
|
||||||
|
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
|
||||||
|
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor) -> !torch.vtensor {
|
||||||
|
// CHECK: %[[SIZE0:.*]] = torch.aten.size %[[ARG0]] : !torch.vtensor -> !torch.list<int>
|
||||||
|
// CHECK: %[[RANK0:.*]] = torch.aten.len.t %[[SIZE0]] : !torch.list<int> -> !torch.int
|
||||||
|
// CHECK: %[[DTYPE0:.*]] = torch.prim.dtype %[[ARG0]] : !torch.vtensor -> !torch.int
|
||||||
|
// CHECK: %[[SIZE1:.*]] = torch.aten.size %[[ARG1]] : !torch.vtensor -> !torch.list<int>
|
||||||
|
// CHECK: %[[RANK1:.*]] = torch.aten.len.t %[[SIZE1]] : !torch.list<int> -> !torch.int
|
||||||
|
// CHECK: %[[DTYPE1:.*]] = torch.prim.dtype %[[ARG1]] : !torch.vtensor -> !torch.int
|
||||||
|
// CHECK: {{.*}} = func.call @__torch_mlir_dtype_fn.aten.floor_divide(%[[RANK0]], %[[DTYPE0]], %[[RANK1]], %[[DTYPE1]]) : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.int
|
||||||
|
func.func @turn_tensors_into_rank_and_dtype_args(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
|
||||||
|
%0 = torch.aten.floor_divide %arg0, %arg1 : !torch.vtensor, !torch.vtensor -> !torch.vtensor
|
||||||
|
return %0 : !torch.vtensor
|
||||||
|
}
|
|
@ -0,0 +1,285 @@
|
||||||
|
// RUN: torch-mlir-opt -torch-simplify-dtype-calculations -split-input-file %s | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @basic(
|
||||||
|
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
||||||
|
// CHECK: %[[DTYPE_INT:.*]] = torch.constant.int 6
|
||||||
|
// CHECK: %[[RESULT:.*]] = torch.dtype.calculate {
|
||||||
|
// CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
||||||
|
// CHECK: torch.dtype.calculate.yield %[[TANH]] : !torch.vtensor<*,f32>
|
||||||
|
// CHECK: } dtypes {
|
||||||
|
// CHECK: torch.dtype.calculate.yield.dtypes %[[DTYPE_INT]] : !torch.int
|
||||||
|
// CHECK: } : !torch.vtensor<*,f32>
|
||||||
|
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RESULT]] : !torch.vtensor<*,f32> to !torch.vtensor
|
||||||
|
// CHECK: return %[[CAST]] : !torch.vtensor
|
||||||
|
func.func @basic(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
||||||
|
%0 = torch.dtype.calculate {
|
||||||
|
%1 = torch.aten.tanh %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
|
||||||
|
torch.dtype.calculate.yield %1 : !torch.vtensor
|
||||||
|
} dtypes {
|
||||||
|
%2 = torch.prim.dtype %arg0 : !torch.vtensor<*,f32> -> !torch.int
|
||||||
|
torch.dtype.calculate.yield.dtypes %2 : !torch.int
|
||||||
|
} : !torch.vtensor
|
||||||
|
return %0 : !torch.vtensor
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @promote_dtypes$tensor_tensor_same_category_different_width(
|
||||||
|
// CHECK: {{.*}} = torch.aten.add.Tensor {{.*}} -> !torch.vtensor<[1],f64>
|
||||||
|
func.func @promote_dtypes$tensor_tensor_same_category_different_width(%arg0: !torch.vtensor<[1],f32>, %arg1: !torch.vtensor<[1],f64>, %arg2: !torch.float) {
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%0 = torch.dtype.calculate {
|
||||||
|
%1 = torch.aten.add.Tensor %arg0, %arg1, %arg2 : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],unk>
|
||||||
|
torch.dtype.calculate.yield %1 : !torch.vtensor<[1],unk>
|
||||||
|
} dtypes {
|
||||||
|
%ranks = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<optional<int>>
|
||||||
|
%f32_dtype = torch.prim.dtype %arg0 : !torch.vtensor<[1],f32> -> !torch.int
|
||||||
|
%f64_dtype = torch.prim.dtype %arg1 : !torch.vtensor<[1],f64> -> !torch.int
|
||||||
|
%dtypes = torch.prim.ListConstruct %f32_dtype, %f64_dtype : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%3 = torch.promote_dtypes %ranks, %dtypes : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int
|
||||||
|
torch.dtype.calculate.yield.dtypes %3 : !torch.int
|
||||||
|
} : !torch.vtensor<[1],unk>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @promote_dtypes$tensor_tensor_different_category(
|
||||||
|
// CHECK: {{.*}} = torch.aten.add.Tensor {{.*}} -> !torch.vtensor<[1],f64>
|
||||||
|
func.func @promote_dtypes$tensor_tensor_different_category(%arg0: !torch.vtensor<[1],si32>, %arg1: !torch.vtensor<[1],f64>, %arg2: !torch.float) {
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%0 = torch.dtype.calculate {
|
||||||
|
%1 = torch.aten.add.Tensor %arg0, %arg1, %arg2 : !torch.vtensor<[1],si32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],unk>
|
||||||
|
torch.dtype.calculate.yield %1 : !torch.vtensor<[1],unk>
|
||||||
|
} dtypes {
|
||||||
|
%si32_dtype = torch.prim.dtype %arg0 : !torch.vtensor<[1],si32> -> !torch.int
|
||||||
|
%f64_dtype = torch.prim.dtype %arg1 : !torch.vtensor<[1],f64> -> !torch.int
|
||||||
|
%ranks = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<optional<int>>
|
||||||
|
%dtypes = torch.prim.ListConstruct %si32_dtype, %f64_dtype : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%3 = torch.promote_dtypes %ranks, %dtypes : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int
|
||||||
|
torch.dtype.calculate.yield.dtypes %3 : !torch.int
|
||||||
|
} : !torch.vtensor<[1],unk>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @promote_dtypes$tensor_tensor_same_category_zero_rank_wider(
|
||||||
|
// CHECK: {{.*}} = torch.aten.add.Tensor {{.*}} -> !torch.vtensor<[1],f32>
|
||||||
|
func.func @promote_dtypes$tensor_tensor_same_category_zero_rank_wider(%arg0: !torch.vtensor<[1],f32>, %arg1: !torch.vtensor<[],f64>, %arg2: !torch.int) {
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%0 = torch.dtype.calculate {
|
||||||
|
%1 = torch.aten.add.Tensor %arg0, %arg1, %arg2 : !torch.vtensor<[1],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[1],unk>
|
||||||
|
torch.dtype.calculate.yield %1 : !torch.vtensor<[1],unk>
|
||||||
|
} dtypes {
|
||||||
|
%f32_dtype = torch.prim.dtype %arg0 : !torch.vtensor<[1],f32> -> !torch.int
|
||||||
|
%f64_dtype = torch.prim.dtype %arg1 : !torch.vtensor<[],f64> -> !torch.int
|
||||||
|
%ranks = torch.prim.ListConstruct %int1, %int0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>
|
||||||
|
%dtypes = torch.prim.ListConstruct %f32_dtype, %f64_dtype : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%3 = torch.promote_dtypes %ranks, %dtypes : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int
|
||||||
|
torch.dtype.calculate.yield.dtypes %3 : !torch.int
|
||||||
|
} : !torch.vtensor<[1],unk>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @promote_dtypes$tensor_tensor_zero_rank_higher_category(
|
||||||
|
// CHECK: {{.*}} = torch.aten.add.Tensor {{.*}} -> !torch.vtensor<[1],f32>
|
||||||
|
func.func @promote_dtypes$tensor_tensor_zero_rank_higher_category(%arg0: !torch.vtensor<[1],si64>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.int) {
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%0 = torch.dtype.calculate {
|
||||||
|
%1 = torch.aten.add.Tensor %arg0, %arg1, %arg2 : !torch.vtensor<[1],si64>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[1],unk>
|
||||||
|
torch.dtype.calculate.yield %1 : !torch.vtensor<[1],unk>
|
||||||
|
} dtypes {
|
||||||
|
%si64_dtype = torch.prim.dtype %arg0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
%f32_dtype = torch.prim.dtype %arg1 : !torch.vtensor<[],f32> -> !torch.int
|
||||||
|
%ranks = torch.prim.ListConstruct %int1, %int0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>
|
||||||
|
%dtypes = torch.prim.ListConstruct %si64_dtype, %f32_dtype : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%3 = torch.promote_dtypes %ranks, %dtypes : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int
|
||||||
|
torch.dtype.calculate.yield.dtypes %3 : !torch.int
|
||||||
|
} : !torch.vtensor<[1],unk>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @promote_dtypes$tensor_tensor_alpha_wider_no_contribution(
|
||||||
|
// CHECK: {{.*}} = torch.aten.add.Tensor {{.*}} -> !torch.vtensor<[1],f32>
|
||||||
|
func.func @promote_dtypes$tensor_tensor_alpha_wider_no_contribution(%arg0: !torch.vtensor<[1],f32>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.float) {
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%none = torch.constant.none
|
||||||
|
%0 = torch.dtype.calculate {
|
||||||
|
%1 = torch.aten.add.Tensor %arg0, %arg1, %arg2 : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[1],unk>
|
||||||
|
torch.dtype.calculate.yield %1 : !torch.vtensor<[1],unk>
|
||||||
|
} dtypes {
|
||||||
|
%f32_dtype = torch.prim.dtype %arg0 : !torch.vtensor<[1],f32> -> !torch.int
|
||||||
|
%alpha_as_tensor = torch.prim.NumToTensor.Scalar %arg2 : !torch.float -> !torch.tensor<[],f64>
|
||||||
|
%f64_dtype = torch.prim.dtype %alpha_as_tensor : !torch.tensor<[],f64> -> !torch.int
|
||||||
|
%ranks = torch.prim.ListConstruct %int1, %int1, %none : (!torch.int, !torch.int, !torch.none) -> !torch.list<optional<int>>
|
||||||
|
%dtypes = torch.prim.ListConstruct %f32_dtype, %f32_dtype, %f64_dtype : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%3 = torch.promote_dtypes %ranks, %dtypes : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int
|
||||||
|
torch.dtype.calculate.yield.dtypes %3 : !torch.int
|
||||||
|
} : !torch.vtensor<[1],unk>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @promote_dtypes$tensor_scalar_scalar_higher_category(
|
||||||
|
// CHECK: {{.*}} = torch.aten.add.Scalar {{.*}} -> !torch.vtensor<[1],f32>
|
||||||
|
func.func @promote_dtypes$tensor_scalar_scalar_higher_category(%arg0: !torch.vtensor<[1],si64>, %arg1: !torch.float, %arg2: !torch.int) {
|
||||||
|
%none = torch.constant.none
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%0 = torch.dtype.calculate {
|
||||||
|
%1 = torch.aten.add.Scalar %arg0, %arg1, %arg2 : !torch.vtensor<[1],si64>, !torch.float, !torch.int -> !torch.vtensor<[1],unk>
|
||||||
|
torch.dtype.calculate.yield %1 : !torch.vtensor<[1],unk>
|
||||||
|
} dtypes {
|
||||||
|
%si64_dtype = torch.prim.dtype %arg0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
%ranks = torch.prim.ListConstruct %int1, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>
|
||||||
|
%arg1_as_tensor = torch.prim.NumToTensor.Scalar %arg1 : !torch.float -> !torch.tensor<[],f64>
|
||||||
|
%f64_dtype = torch.prim.dtype %arg1_as_tensor : !torch.tensor<[],f64> -> !torch.int
|
||||||
|
%dtypes = torch.prim.ListConstruct %si64_dtype, %f64_dtype : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%5 = torch.promote_dtypes %ranks, %dtypes : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int
|
||||||
|
torch.dtype.calculate.yield.dtypes %5 : !torch.int
|
||||||
|
} : !torch.vtensor<[1],unk>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @promote_dtypes$tensor_scalar_scalar_same_category_wider(
|
||||||
|
// CHECK: {{.*}} = torch.aten.add.Scalar {{.*}} -> !torch.vtensor<[1],si32>
|
||||||
|
func.func @promote_dtypes$tensor_scalar_scalar_same_category_wider(%arg0: !torch.vtensor<[1],si32>, %arg1: !torch.int, %arg2: !torch.int) {
|
||||||
|
%none = torch.constant.none
|
||||||
|
%int3 = torch.constant.int 3
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%0 = torch.dtype.calculate {
|
||||||
|
%1 = torch.aten.add.Scalar %arg0, %arg1, %arg2 : !torch.vtensor<[1],si32>, !torch.int, !torch.int -> !torch.vtensor<[1],unk>
|
||||||
|
torch.dtype.calculate.yield %1 : !torch.vtensor<[1],unk>
|
||||||
|
} dtypes {
|
||||||
|
%si32_dtype = torch.prim.dtype %arg0 : !torch.vtensor<[1],si32> -> !torch.int
|
||||||
|
%ranks = torch.prim.ListConstruct %int1, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>
|
||||||
|
%arg1_as_tensor = torch.prim.NumToTensor.Scalar %arg1 : !torch.int -> !torch.tensor<[],si64>
|
||||||
|
%si64_dtype = torch.prim.dtype %arg1_as_tensor : !torch.tensor<[],si64> -> !torch.int
|
||||||
|
%dtypes = torch.prim.ListConstruct %si32_dtype, %si64_dtype : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%5 = torch.promote_dtypes %ranks, %dtypes : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int
|
||||||
|
torch.dtype.calculate.yield.dtypes %5 : !torch.int
|
||||||
|
} : !torch.vtensor<[1],unk>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @promote_dtypes$scalar_scalar_different_category(
|
||||||
|
// CHECK: {{.*}} = torch.aten.add {{.*}} -> !torch.float
|
||||||
|
func.func @promote_dtypes$scalar_scalar_different_category(%arg0: !torch.float, %arg1: !torch.int) -> !torch.number {
|
||||||
|
%none = torch.constant.none
|
||||||
|
%0 = torch.dtype.calculate {
|
||||||
|
%1 = torch.aten.add %arg0, %arg1 : !torch.float, !torch.int -> !torch.number
|
||||||
|
torch.dtype.calculate.yield %1 : !torch.number
|
||||||
|
} dtypes {
|
||||||
|
%ranks = torch.prim.ListConstruct %none, %none : (!torch.none, !torch.none) -> !torch.list<optional<int>>
|
||||||
|
%arg0_as_tensor = torch.prim.NumToTensor.Scalar %arg0 : !torch.float -> !torch.tensor<[],f64>
|
||||||
|
%f64_dtype = torch.prim.dtype %arg0_as_tensor : !torch.tensor<[],f64> -> !torch.int
|
||||||
|
%arg1_as_tensor = torch.prim.NumToTensor.Scalar %arg1 : !torch.int -> !torch.tensor<[],si64>
|
||||||
|
%si64_dtype = torch.prim.dtype %arg1_as_tensor : !torch.tensor<[],si64> -> !torch.int
|
||||||
|
%dtypes = torch.prim.ListConstruct %f64_dtype, %si64_dtype : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%7 = torch.promote_dtypes %ranks, %dtypes : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int
|
||||||
|
torch.dtype.calculate.yield.dtypes %7 : !torch.int
|
||||||
|
} : !torch.number
|
||||||
|
return %0 : !torch.number
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @promote_dtypes$scalar_scalar_same_category(
|
||||||
|
// CHECK: {{.*}} = torch.aten.add {{.*}} -> !torch.int
|
||||||
|
func.func @promote_dtypes$scalar_scalar_same_category(%arg0: !torch.int, %arg1: !torch.int) -> !torch.number {
|
||||||
|
%none = torch.constant.none
|
||||||
|
%0 = torch.dtype.calculate {
|
||||||
|
%1 = torch.aten.add %arg0, %arg1 : !torch.int, !torch.int -> !torch.number
|
||||||
|
torch.dtype.calculate.yield %1 : !torch.number
|
||||||
|
} dtypes {
|
||||||
|
%ranks = torch.prim.ListConstruct %none, %none : (!torch.none, !torch.none) -> !torch.list<optional<int>>
|
||||||
|
%arg0_as_tensor = torch.prim.NumToTensor.Scalar %arg0 : !torch.int -> !torch.tensor<[],si64>
|
||||||
|
%si64_dtype = torch.prim.dtype %arg0_as_tensor : !torch.tensor<[],si64> -> !torch.int
|
||||||
|
%dtypes = torch.prim.ListConstruct %si64_dtype, %si64_dtype : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%7 = torch.promote_dtypes %ranks, %dtypes : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int
|
||||||
|
torch.dtype.calculate.yield.dtypes %7 : !torch.int
|
||||||
|
} : !torch.number
|
||||||
|
return %0 : !torch.number
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @refine_dtype$invalid_dtype_for_scalar(
|
||||||
|
// CHECK: {{.*}} = torch.aten.add {{.*}} -> !torch.number
|
||||||
|
func.func @refine_dtype$invalid_dtype_for_scalar(%arg0: !torch.int, %arg1: !torch.int) -> !torch.number {
|
||||||
|
%none = torch.constant.none
|
||||||
|
%0 = torch.dtype.calculate {
|
||||||
|
%1 = torch.aten.add %arg0, %arg1 : !torch.int, !torch.int -> !torch.number
|
||||||
|
torch.dtype.calculate.yield %1 : !torch.number
|
||||||
|
} dtypes {
|
||||||
|
// dtype int for int32
|
||||||
|
%int3 = torch.constant.int 3
|
||||||
|
torch.dtype.calculate.yield.dtypes %int3 : !torch.int
|
||||||
|
} : !torch.number
|
||||||
|
return %0 : !torch.number
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @refine_dtype$no_simplification
|
||||||
|
// CHECK: {{.*}} = torch.aten.add {{.*}} -> !torch.number
|
||||||
|
func.func @refine_dtype$no_simplification(%arg0: !torch.int, %arg1: !torch.int, %dtype: !torch.int) -> !torch.number {
|
||||||
|
%none = torch.constant.none
|
||||||
|
%0 = torch.dtype.calculate {
|
||||||
|
%1 = torch.aten.add %arg0, %arg1 : !torch.int, !torch.int -> !torch.number
|
||||||
|
torch.dtype.calculate.yield %1 : !torch.number
|
||||||
|
} dtypes {
|
||||||
|
torch.dtype.calculate.yield.dtypes %dtype : !torch.int
|
||||||
|
} : !torch.number
|
||||||
|
return %0 : !torch.number
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// If result type is already refined (even if wrong, as is the case here),
|
||||||
|
// don't make any changes to result type.
|
||||||
|
// TODO: This case should result in an error
|
||||||
|
// CHECK-LABEL: func.func @refine_dtype$result_type_already_refined
|
||||||
|
// CHECK: {{.*}} = torch.aten.add {{.*}} -> !torch.int
|
||||||
|
func.func @refine_dtype$result_type_already_refined(%arg0: !torch.float, %arg1: !torch.float) -> !torch.int {
|
||||||
|
%none = torch.constant.none
|
||||||
|
%0 = torch.dtype.calculate {
|
||||||
|
%1 = torch.aten.add %arg0, %arg1 : !torch.float, !torch.float -> !torch.int
|
||||||
|
torch.dtype.calculate.yield %1 : !torch.int
|
||||||
|
} dtypes {
|
||||||
|
// dtype int for float64
|
||||||
|
%int7 = torch.constant.int 7
|
||||||
|
torch.dtype.calculate.yield.dtypes %int7 : !torch.int
|
||||||
|
} : !torch.int
|
||||||
|
return %0 : !torch.int
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @refine_dtype$derefine_result_type(
|
||||||
|
// CHECK: {{.*}} = torch.aten.add {{.*}} -> !torch.int
|
||||||
|
// CHECK: %[[ERASED:.*]] = torch.derefine {{.*}} : !torch.int to !torch.number
|
||||||
|
// CHECK: return %[[ERASED]] : !torch.number
|
||||||
|
func.func @refine_dtype$derefine_result_type(%arg0: !torch.int, %arg1: !torch.int) -> !torch.number {
|
||||||
|
%none = torch.constant.none
|
||||||
|
%0 = torch.dtype.calculate {
|
||||||
|
%1 = torch.aten.add %arg0, %arg1 : !torch.int, !torch.int -> !torch.number
|
||||||
|
torch.dtype.calculate.yield %1 : !torch.number
|
||||||
|
} dtypes {
|
||||||
|
// dtype int for int64
|
||||||
|
%int4 = torch.constant.int 4
|
||||||
|
torch.dtype.calculate.yield.dtypes %int4 : !torch.int
|
||||||
|
} : !torch.number
|
||||||
|
return %0 : !torch.number
|
||||||
|
}
|
Loading…
Reference in New Issue