mirror of https://github.com/llvm/torch-mlir
Merge main into dtype-functions-staging (#1935)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com> Signed-off-by: Prateek Gupta <prateek.gupta2@cerebras.net> Co-authored-by: Jiahao Li <liplus17@163.com> Co-authored-by: Yuanqiang Liu <liuyuanqiang.yqliu@bytedance.com> Co-authored-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com> Co-authored-by: Chi_Liu <chi@nod.ai> Co-authored-by: Victor Guerra <vguerra@gmail.com> Co-authored-by: Victor Guerra <vm.guerramoran@criteo.com> Co-authored-by: powderluv <powderluv@users.noreply.github.com> Co-authored-by: Ashay Rane <ashay@users.noreply.github.com> Co-authored-by: Eric Kunze <eric.kunze@arm.com> Co-authored-by: Gleb Kazantaev <gleb.nnstu@gmail.com> Co-authored-by: Matthias Gehre <93204396+mgehre-amd@users.noreply.github.com> Co-authored-by: Yi Wang <yi.wang.2005@gmail.com> Co-authored-by: Sean Silva <silvasean@google.com> Co-authored-by: Zachary Cetinic <zachattack242@Hotmail.com> Co-authored-by: Tanyo Kwok <tianyou.gty@alibaba-inc.com> Co-authored-by: Zachary Cetinic <zacharycetinic@gmail.com> Co-authored-by: Kunwar Grover <51270680+Groverkss@users.noreply.github.com> Co-authored-by: Ziheng Jiang <ziheng@apache.org> Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com> Co-authored-by: Maksim Levental <maksim.levental@gmail.com> Co-authored-by: Gaurav Shukla <gaurav@nod-labs.com> Co-authored-by: Prateek Gupta <108802984+prateekgu-cerebras@users.noreply.github.com> Co-authored-by: nvda <nvda@stanford.edu> Co-authored-by: Ahmed S. Taei <asaadaldien@users.noreply.github.com> Co-authored-by: Priya Savithiri <104089347+PriyaBSavithiri@users.noreply.github.com> Co-authored-by: Zhekun Zhang <32320144+zhekunz2@users.noreply.github.com> Co-authored-by: zhekun.zhang <zhekun.zhang@bytedance.com> Co-authored-by: Kan Chen <chenkanhw@163.com> Co-authored-by: gpetters94 <gpetters@protonmail.com>pull/1943/head
parent
ce7abf4911
commit
042d58b699
|
@ -17,7 +17,7 @@ runs:
|
|||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Install MLIR Python depends
|
||||
run: |
|
||||
|
@ -26,7 +26,8 @@ runs:
|
|||
|
||||
- name: Install PyTorch nightly depends
|
||||
run: |
|
||||
python -m pip install -r requirements.txt
|
||||
python -m pip install -r pytorch-requirements.txt
|
||||
python -m pip install -r build-requirements.txt
|
||||
shell: bash
|
||||
|
||||
- name: Install prerequisites (Linux)
|
||||
|
|
|
@ -8,12 +8,19 @@ on:
|
|||
jobs:
|
||||
build_linux:
|
||||
name: Manylinux Build
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: a100
|
||||
|
||||
# Don't run this in everyone's forks.
|
||||
if: github.repository == 'llvm/torch-mlir'
|
||||
|
||||
steps:
|
||||
|
||||
- name: Prepare workspace
|
||||
run: |
|
||||
# Clear the workspace directory so that we don't run into errors about
|
||||
# existing lock files.
|
||||
sudo rm -rf $GITHUB_WORKSPACE/*
|
||||
|
||||
- name: Get torch-mlir
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
|
@ -31,6 +38,7 @@ jobs:
|
|||
|
||||
cd ${GITHUB_WORKSPACE}
|
||||
python -m pip install wheel
|
||||
sudo apt-get install unzip
|
||||
|
||||
# Fetch the most recent nightly torchvision release
|
||||
VISION_RELEASE=$(python -m pip index versions -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre torchvision | grep "Available versions" | tr ' ' '\n' | grep "^[0-9]" | sort --version-sort --reverse | head -n1 | tr -d ',' | sed 's/\([^+]*\).*/\1/')
|
||||
|
@ -44,7 +52,8 @@ jobs:
|
|||
# Read the version from the downloaded whl file without extracting it
|
||||
PT_RELEASE=$(unzip -p torch-*.whl 'torch-*/METADATA' | grep "^Version:" | awk '{ print $2 }' | sed 's/\([^+]*\).*/\1/')
|
||||
echo "Found torch release ${PT_RELEASE}"
|
||||
printf -- "-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html\n--pre\ntorch==%s\ntorchvision==%s\n" "${PT_RELEASE}" "${VISION_RELEASE}" > pytorch-requirements.txt
|
||||
printf -- "-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html\n--pre\ntorch==%s\n" "${PT_RELEASE}" > pytorch-requirements.txt
|
||||
printf -- "-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html\n--pre\ntorchvision==%s\n" "${VISION_RELEASE}" > torchvision-requirements.txt
|
||||
|
||||
# Read the commit hash from the downloaded whl file without extracting it
|
||||
PT_HASH=$(unzip -p torch-"${PT_RELEASE}"*.whl torch/version.py | grep git_version | awk '{ print $3 }' | tr -d "'")
|
||||
|
@ -96,7 +105,7 @@ jobs:
|
|||
git fetch --recurse-submodules=no
|
||||
git checkout main
|
||||
git pull origin main
|
||||
git add pytorch-hash.txt pytorch-requirements.txt lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
|
||||
git add pytorch-hash.txt pytorch-requirements.txt torchvision-requirements.txt lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
|
||||
git diff --cached --exit-code || (git commit -m "update PyTorch version to ${{ env.PT_RELEASE }}" && git push --set-upstream origin main)
|
||||
|
||||
- name: Update PyTorch Build Cache (if running on main branch)
|
||||
|
|
|
@ -20,6 +20,12 @@ jobs:
|
|||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Prepare workspace
|
||||
run: |
|
||||
# Clear the workspace directory so that we don't run into errors about
|
||||
# existing lock files.
|
||||
sudo rm -rf $GITHUB_WORKSPACE/*
|
||||
|
||||
- name: Checkout torch-mlir
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
|
|
|
@ -51,6 +51,14 @@ jobs:
|
|||
runs-on: ${{ matrix.os }}
|
||||
|
||||
steps:
|
||||
|
||||
- name: Prepare workspace
|
||||
if: ${{ matrix.os-arch == 'ubuntu-x86_64' }}
|
||||
run: |
|
||||
# Clear the workspace directory so that we don't run into errors about
|
||||
# existing lock files.
|
||||
sudo rm -rf $GITHUB_WORKSPACE/*
|
||||
|
||||
- name: Checkout torch-mlir
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
|
@ -113,7 +121,7 @@ jobs:
|
|||
-DLLVM_USE_HOST_TOOLS=ON \
|
||||
-DLLVM_ENABLE_ZSTD=OFF \
|
||||
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
|
||||
-DTORCH_MLIR_ENABLE_MHLO=OFF \
|
||||
-DTORCH_MLIR_ENABLE_STABLEHLO=OFF \
|
||||
-DTORCH_MLIR_ENABLE_LTC=OFF \
|
||||
-DTORCH_MLIR_USE_INSTALLED_PYTORCH="${{ matrix.torch-binary }}" \
|
||||
-DMACOSX_DEPLOYMENT_TARGET=12.0 \
|
||||
|
|
|
@ -13,8 +13,25 @@ on:
|
|||
jobs:
|
||||
build_linux:
|
||||
name: Manylinux Build
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: a100
|
||||
strategy:
|
||||
matrix:
|
||||
package: [ torch-mlir, torch-mlir-core ]
|
||||
py_version: [ cp38-cp38, cp310-cp310, cp311-cp311 ]
|
||||
exclude:
|
||||
- package: torch-mlir-core
|
||||
py_version: cp38-cp38
|
||||
- package: torch-mlir-core
|
||||
py_version: cp310-cp310
|
||||
|
||||
steps:
|
||||
|
||||
- name: Prepare workspace
|
||||
run: |
|
||||
# Clear the workspace directory so that we don't run into errors about
|
||||
# existing lock files.
|
||||
sudo rm -rf $GITHUB_WORKSPACE/*
|
||||
|
||||
- name: Get torch-mlir
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
|
@ -28,7 +45,7 @@ jobs:
|
|||
python -m pip install wheel
|
||||
TM_PACKAGE_VERSION=${{ github.event.inputs.python_package_version }}
|
||||
printf "TORCH_MLIR_PYTHON_PACKAGE_VERSION=%s\n" $TM_PACKAGE_VERSION > ./torch_mlir_package_version
|
||||
./build_tools/python_deploy/build_linux_packages.sh
|
||||
TM_PYTHON_VERSIONS=${{ matrix.py_version }} TM_PACKAGES=${{ matrix.package }} ./build_tools/python_deploy/build_linux_packages.sh
|
||||
|
||||
# If we were given a release_id, then upload the package we just built
|
||||
# to the github releases page.
|
||||
|
@ -56,7 +73,7 @@ jobs:
|
|||
run: mkdir dist
|
||||
- name: Copy releases to publish to dist directory
|
||||
if: github.event.inputs.release_id != ''
|
||||
run: cp build_tools/python_deploy/wheelhouse/torch_mlir-*.whl dist/
|
||||
run: cp build_tools/python_deploy/wheelhouse/torch_mlir*.whl dist/
|
||||
|
||||
# Wheels must be published from a linux environment.
|
||||
#
|
||||
|
@ -70,6 +87,9 @@ jobs:
|
|||
build_macos:
|
||||
name: MacOS Build
|
||||
runs-on: macos-latest
|
||||
strategy:
|
||||
matrix:
|
||||
package: [ torch-mlir, torch-mlir-core ]
|
||||
steps:
|
||||
- name: Get torch-mlir
|
||||
uses: actions/checkout@v3
|
||||
|
@ -85,7 +105,7 @@ jobs:
|
|||
TM_PACKAGE_VERSION=${{ github.event.inputs.python_package_version }}
|
||||
printf "TORCH_MLIR_PYTHON_PACKAGE_VERSION=%s\n" $TM_PACKAGE_VERSION > ./torch_mlir_package_version
|
||||
sudo ./build_tools/python_deploy/install_macos_deps.sh
|
||||
TORCH_MLIR_PYTHON_VERSIONS="3.10" ./build_tools/python_deploy/build_macos_packages.sh
|
||||
packages=${{ matrix.package }} TORCH_MLIR_PYTHON_VERSIONS="3.11" ./build_tools/python_deploy/build_macos_packages.sh
|
||||
|
||||
# If we were given a release_id, then upload the package we just built
|
||||
# to the github releases page.
|
||||
|
@ -113,7 +133,7 @@ jobs:
|
|||
run: mkdir dist
|
||||
- name: Copy releases to publish to dist directory
|
||||
if: github.event.inputs.release_id != ''
|
||||
run: cp build_tools/python_deploy/wheelhouse/torch_mlir-*.whl dist/
|
||||
run: cp build_tools/python_deploy/wheelhouse/torch_mlir*.whl dist/
|
||||
|
||||
# Wheels must be published from a linux environment.
|
||||
#
|
||||
|
@ -127,6 +147,9 @@ jobs:
|
|||
build_windows:
|
||||
name: Windows Build
|
||||
runs-on: windows-latest
|
||||
strategy:
|
||||
matrix:
|
||||
package: [ torch-mlir, torch-mlir-core ]
|
||||
steps:
|
||||
- name: Get torch-mlir
|
||||
uses: actions/checkout@v3
|
||||
|
@ -142,6 +165,14 @@ jobs:
|
|||
- name: Build Python wheels and smoke test.
|
||||
shell: pwsh
|
||||
run: |
|
||||
if ( "${{ matrix.package }}" -eq "torch-mlir-core" )
|
||||
{
|
||||
$env:TORCH_MLIR_ENABLE_JIT_IR_IMPORTER='0'
|
||||
$env:TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS='1'
|
||||
} else {
|
||||
$env:TORCH_MLIR_ENABLE_JIT_IR_IMPORTER='1'
|
||||
$env:TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS='0'
|
||||
}
|
||||
$env:TORCH_MLIR_PYTHON_PACKAGE_VERSION = '${{ github.event.inputs.python_package_version }}'
|
||||
./build_tools/python_deploy/build_windows.ps1
|
||||
|
||||
|
@ -172,7 +203,7 @@ jobs:
|
|||
continue-on-error: true
|
||||
- name: Copy releases to publish to dist directory
|
||||
if: github.event.inputs.release_id != ''
|
||||
run: cp ./wheelhouse/torch_mlir-*.whl dist/
|
||||
run: cp ./wheelhouse/torch_mlir*.whl dist/
|
||||
|
||||
# Wheels must be published from a linux environment.
|
||||
#
|
||||
|
@ -216,4 +247,4 @@ jobs:
|
|||
# if: github.event.inputs.release_id != ''
|
||||
# uses: pypa/gh-action-pypi-publish@v1.5.1
|
||||
# with:
|
||||
# password: ${{ secrets.PYPI_API_TOKEN }}
|
||||
# password: ${{ secrets.PYPI_API_TOKEN }}
|
||||
|
|
|
@ -13,8 +13,13 @@ jobs:
|
|||
if: github.repository == 'llvm/torch-mlir'
|
||||
|
||||
steps:
|
||||
- name: Prepare workspace
|
||||
run: |
|
||||
# Clear the workspace directory so that we don't run into errors about
|
||||
# existing lock files.
|
||||
sudo rm -rf $GITHUB_WORKSPACE/*
|
||||
- name: Checking out repository
|
||||
uses: actions/checkout@v2
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }}
|
||||
- name: Run scrape releases script
|
||||
|
|
|
@ -10,8 +10,14 @@ jobs:
|
|||
# Don't run this in everyone's forks.
|
||||
if: github.repository == 'llvm/torch-mlir'
|
||||
steps:
|
||||
- name: Prepare workspace
|
||||
run: |
|
||||
# Clear the workspace directory so that we don't run into errors about
|
||||
# existing lock files.
|
||||
sudo rm -rf $GITHUB_WORKSPACE/*
|
||||
|
||||
- name: Checking out repository
|
||||
uses: actions/checkout@v2
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }}
|
||||
|
||||
|
|
|
@ -13,8 +13,15 @@ jobs:
|
|||
# Don't run this in everyone's forks.
|
||||
if: github.repository == 'llvm/torch-mlir'
|
||||
steps:
|
||||
|
||||
- name: Prepare workspace
|
||||
run: |
|
||||
# Clear the workspace directory so that we don't run into errors about
|
||||
# existing lock files.
|
||||
sudo rm -rf $GITHUB_WORKSPACE/*
|
||||
|
||||
- name: Checking out repository
|
||||
uses: actions/checkout@v2
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }}
|
||||
|
||||
|
|
|
@ -32,3 +32,6 @@ bazel-*
|
|||
build_oot/
|
||||
docker_venv/
|
||||
llvm-build/
|
||||
|
||||
# C++ build artifacts
|
||||
compile_commands.json
|
||||
|
|
|
@ -36,12 +36,18 @@ macro(torch_mlir_add_llvm_external_project name identifier location)
|
|||
set(LLVM_EXTERNAL_PROJECTS ${LLVM_EXTERNAL_PROJECTS} CACHE STRING "" FORCE)
|
||||
endmacro()
|
||||
|
||||
option(TORCH_MLIR_ENABLE_MHLO "Add mhlo dialect" ON)
|
||||
if(TORCH_MLIR_ENABLE_MHLO)
|
||||
add_definitions(-DTORCH_MLIR_ENABLE_MHLO)
|
||||
option(TORCH_MLIR_ENABLE_STABLEHLO "Add stablehlo dialect" ON)
|
||||
if(TORCH_MLIR_ENABLE_STABLEHLO)
|
||||
add_definitions(-DTORCH_MLIR_ENABLE_STABLEHLO)
|
||||
endif()
|
||||
|
||||
option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON)
|
||||
option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF)
|
||||
option(TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS "Build Torch dialect MLIR Python bindings but neither JIT IR Importer nor LTC backend" OFF)
|
||||
if(TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS)
|
||||
set(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER OFF)
|
||||
set(TORCH_MLIR_ENABLE_LTC OFF)
|
||||
endif()
|
||||
|
||||
if(TORCH_MLIR_ENABLE_LTC)
|
||||
set(ENV{TORCH_MLIR_ENABLE_LTC} 1)
|
||||
|
@ -109,7 +115,6 @@ if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR OR TORCH_MLIR_OUT_OF_TREE_
|
|||
# Don't try to compile the python extensions at the moment. We need
|
||||
# to import lots of dependencies from AddMLIRPython to make this work.
|
||||
set(MLIR_ENABLE_BINDINGS_PYTHON 1)
|
||||
option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON)
|
||||
|
||||
set(TORCH-MLIR_BUILT_STANDALONE 1)
|
||||
set(BACKEND_PACKAGE_STRING "LLVM ${LLVM_PACKAGE_VERSION}")
|
||||
|
@ -119,7 +124,6 @@ else()
|
|||
# In-tree build with LLVM_EXTERNAL_PROJECTS=torch-mlir
|
||||
|
||||
option(MLIR_ENABLE_BINDINGS_PYTHON "Enables MLIR Python Bindings" OFF)
|
||||
option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON)
|
||||
|
||||
# TODO: Fix this upstream so that global include directories are not needed.
|
||||
set(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir)
|
||||
|
@ -128,8 +132,8 @@ else()
|
|||
set(MLIR_INCLUDE_DIRS "${MLIR_INCLUDE_DIR};${MLIR_GENERATED_INCLUDE_DIR}")
|
||||
endif()
|
||||
|
||||
if (TORCH_MLIR_ENABLE_MHLO)
|
||||
set(MHLO_BUILD_EMBEDDED ON)
|
||||
if (TORCH_MLIR_ENABLE_STABLEHLO)
|
||||
set(STABLEHLO_BUILD_EMBEDDED ON)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo
|
||||
${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo
|
||||
EXCLUDE_FROM_ALL)
|
||||
|
|
28
README.md
28
README.md
|
@ -8,13 +8,12 @@ necessarily a reflection of the completeness or stability of the code, it
|
|||
does indicate that the project is not yet endorsed as a component of LLVM.
|
||||
|
||||
[PyTorch](https://pytorch.org)
|
||||
An open source machine learning framework that accelerates the path from research prototyping to production deployment.
|
||||
PyTorch is an open source machine learning framework that facilitates the seamless transition from research and prototyping to production-level deployment.
|
||||
|
||||
[MLIR](https://mlir.llvm.org)
|
||||
The MLIR project is a novel approach to building reusable and extensible compiler infrastructure. MLIR aims to address software fragmentation, improve compilation for heterogeneous hardware, significantly reduce the cost of building domain specific compilers, and aid in connecting existing compilers together.
|
||||
|
||||
The MLIR project offers a novel approach for building extensible and reusable compiler architectures, which address the issue of software fragmentation, reduce the cost of developing domain-specific compilers, improve compilation for heterogeneous hardware, and promote compatibility between existing compilers.
|
||||
[Torch-MLIR](https://github.com/llvm/torch-mlir)
|
||||
Multiple Vendors use MLIR as the middle layer, mapping from platform frameworks like PyTorch, JAX, and TensorFlow into MLIR and then progressively lowering down to their target hardware. We have seen half a dozen custom lowerings from PyTorch to MLIR. Having canonical lowerings from the PyTorch ecosystem to the MLIR ecosystem provides much needed relief to hardware vendors to focus on their unique value rather than implementing yet another PyTorch frontend for MLIR. The goal is to be similar to current hardware vendors adding LLVM target support instead of each one also implementing Clang / a C++ frontend.
|
||||
Several vendors have adopted MLIR as the middle layer in their systems, enabling them to map frameworks such as PyTorch, JAX, and TensorFlow into MLIR and subsequently lower them to their target hardware. We have observed half a dozen custom lowerings from PyTorch to MLIR, making it easier for hardware vendors to focus on their unique value, rather than needing to implement yet another PyTorch frontend for MLIR. The ultimate aim is to be similar to the current hardware vendors adding LLVM target support, rather than each one implementing Clang or a C++ frontend.
|
||||
|
||||
[![Release Build](https://github.com/llvm/torch-mlir/actions/workflows/buildRelease.yml/badge.svg)](https://github.com/llvm/torch-mlir/actions/workflows/buildRelease.yml)
|
||||
|
||||
|
@ -43,15 +42,26 @@ We have few paths to lower down to the Torch MLIR Dialect.
|
|||
|
||||
## Install torch-mlir snapshot
|
||||
|
||||
This installs a pre-built snapshot of torch-mlir for Python 3.7/3.8/3.9/3.10 on Linux and macOS.
|
||||
At the time of writing, we release pre-built snapshot of torch-mlir for Python 3.10 on Linux and macOS.
|
||||
|
||||
If you have Python 3.10, the following commands initialize a virtual environment.
|
||||
```shell
|
||||
python -m venv mlir_venv
|
||||
python3.10 -m venv mlir_venv
|
||||
source mlir_venv/bin/activate
|
||||
# Some older pip installs may not be able to handle the recent PyTorch deps
|
||||
```
|
||||
|
||||
Or, if you want to switch over multiple versions of Python using conda, you can create a conda environment with Python 3.10.
|
||||
```shell
|
||||
conda create -n torch-mlir python=3.10
|
||||
conda activate torch-mlir
|
||||
python -m pip install --upgrade pip
|
||||
pip install --pre torch-mlir torchvision -f https://llvm.github.io/torch-mlir/package-index/ --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
# This will install the corresponding torch and torchvision nightlies
|
||||
```
|
||||
|
||||
Then, we can install torch-mlir with the corresponding torch and torchvision nightlies.
|
||||
```
|
||||
pip install --pre torch-mlir torchvision \
|
||||
-f https://llvm.github.io/torch-mlir/package-index/
|
||||
--extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
```
|
||||
|
||||
## Demos
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
-r pytorch-requirements.txt
|
||||
|
||||
numpy
|
||||
pybind11
|
||||
wheel
|
||||
|
|
|
@ -39,16 +39,16 @@ set -eu -o errtrace
|
|||
this_dir="$(cd "$(dirname "$0")" && pwd)"
|
||||
repo_root="$(cd "$this_dir"/../../ && pwd)"
|
||||
# This needs to be a manylinux image so we can ship pip packages
|
||||
TM_RELEASE_DOCKER_IMAGE="${TM_RELEASE_DOCKER_IMAGE:-stellaraccident/manylinux2014_x86_64-bazel-5.1.0:latest}"
|
||||
TM_RELEASE_DOCKER_IMAGE="${TM_RELEASE_DOCKER_IMAGE:-gcr.io/iree-oss/manylinux2014_x86_64-release@sha256:d8994b87b45b7b2e6055fccc32db018ec73aeb05a4e43a9daa61b77cc34f846e}"
|
||||
# This assumes an Ubuntu LTS like image. You can build your own with
|
||||
# ./build_tools/docker/Dockerfile
|
||||
TM_CI_DOCKER_IMAGE="${TM_CI_DOCKER_IMAGE:-powderluv/torch-mlir-ci:latest}"
|
||||
# Version of Python to use in Release builds. Ignored in CIs.
|
||||
TM_PYTHON_VERSIONS="${TM_PYTHON_VERSIONS:-cp310-cp310}"
|
||||
TM_PYTHON_VERSIONS="${TM_PYTHON_VERSIONS:-cp38-cp38 cp310-cp310 cp311-cp311}"
|
||||
# Location to store Release wheels
|
||||
TM_OUTPUT_DIR="${TM_OUTPUT_DIR:-${this_dir}/wheelhouse}"
|
||||
# What "packages to build"
|
||||
TM_PACKAGES="${TM_PACKAGES:-torch-mlir}"
|
||||
TM_PACKAGES="${TM_PACKAGES:-torch-mlir torch-mlir-core}"
|
||||
# Use pre-built Pytorch
|
||||
TM_USE_PYTORCH_BINARY="${TM_USE_PYTORCH_BINARY:-ON}"
|
||||
# Skip running tests if you want quick iteration
|
||||
|
@ -84,6 +84,11 @@ function run_on_host() {
|
|||
export USERID=0
|
||||
export GROUPID=0
|
||||
;;
|
||||
torch-mlir-core)
|
||||
TM_CURRENT_DOCKER_IMAGE=${TM_RELEASE_DOCKER_IMAGE}
|
||||
export USERID=0
|
||||
export GROUPID=0
|
||||
;;
|
||||
out-of-tree)
|
||||
TM_CURRENT_DOCKER_IMAGE=${TM_CI_DOCKER_IMAGE}
|
||||
# CI uses only Python3.10
|
||||
|
@ -159,6 +164,12 @@ function run_in_docker() {
|
|||
|
||||
clean_build torch_mlir "$python_version"
|
||||
;;
|
||||
torch-mlir-core)
|
||||
clean_wheels torch_mlir_core "$python_version"
|
||||
build_torch_mlir_core
|
||||
run_audit_wheel torch_mlir_core "$python_version"
|
||||
clean_build torch_mlir_core "$python_version"
|
||||
;;
|
||||
out-of-tree)
|
||||
setup_venv "$python_version"
|
||||
build_out_of_tree "$TM_USE_PYTORCH_BINARY" "$python_version"
|
||||
|
@ -267,8 +278,8 @@ function test_in_tree() {
|
|||
echo ":::: Run Linalg e2e integration tests"
|
||||
python -m e2e_testing.main --config=linalg -v
|
||||
|
||||
echo ":::: Run MHLO e2e integration tests"
|
||||
python -m e2e_testing.main --config=mhlo -v
|
||||
echo ":::: Run StableHLO e2e integration tests"
|
||||
python -m e2e_testing.main --config=stablehlo -v
|
||||
|
||||
echo ":::: Run TOSA e2e integration tests"
|
||||
python -m e2e_testing.main --config=tosa -v
|
||||
|
@ -277,7 +288,7 @@ function test_in_tree() {
|
|||
python -m e2e_testing.main --config=lazy_tensor_core -v
|
||||
|
||||
echo ":::: Run TorchDynamo e2e integration tests"
|
||||
python -m e2e_testing.main --config=torchdynamo -v --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed RandnDtypeDeviceModule_basic
|
||||
python -m e2e_testing.main --config=torchdynamo -v --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed RandnDtypeDeviceModule_basic RandnLikeModule_basic Matmul_dot
|
||||
}
|
||||
|
||||
function setup_venv() {
|
||||
|
@ -373,6 +384,15 @@ function run_audit_wheel() {
|
|||
rm "$generic_wheel"
|
||||
}
|
||||
|
||||
function build_torch_mlir_core() {
|
||||
python -m pip install --no-cache-dir -r /main_checkout/torch-mlir/build-requirements.txt
|
||||
CMAKE_GENERATOR=Ninja \
|
||||
TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \
|
||||
TORCH_MLIR_ENABLE_JIT_IR_IMPORTER=0 \
|
||||
TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS=1 \
|
||||
python -m pip wheel -v -w /wheelhouse /main_checkout/torch-mlir
|
||||
}
|
||||
|
||||
function clean_wheels() {
|
||||
local wheel_basename="$1"
|
||||
local python_version="$2"
|
||||
|
|
|
@ -20,7 +20,7 @@ set -eu -o errtrace
|
|||
|
||||
this_dir="$(cd "$(dirname "$0")" && pwd)"
|
||||
repo_root="$(cd "$this_dir"/../../ && pwd)"
|
||||
python_versions="${TORCH_MLIR_PYTHON_VERSIONS:-3.9 3.10}"
|
||||
python_versions="${TORCH_MLIR_PYTHON_VERSIONS:-3.9 3.10 3.11}"
|
||||
output_dir="${output_dir:-${this_dir}/wheelhouse}"
|
||||
packages="${packages:-torch-mlir}"
|
||||
|
||||
|
@ -61,6 +61,11 @@ function run() {
|
|||
build_torch_mlir torch_mlir "$python_version"
|
||||
run_audit_wheel torch_mlir "$python_version"
|
||||
;;
|
||||
torch-mlir-core)
|
||||
clean_wheels torch_mlir_core "$python_version"
|
||||
build_torch_mlir_core torch_mlir_core "$python_version"
|
||||
run_audit_wheel torch_mlir_core "$python_version"
|
||||
;;
|
||||
*)
|
||||
echo "Unrecognized package '$package'"
|
||||
exit 1
|
||||
|
@ -77,7 +82,8 @@ function build_torch_mlir() {
|
|||
python"${python_version}" -m venv "$output_dir"/build_venv
|
||||
source "$output_dir"/build_venv/bin/activate
|
||||
python"${python_version}" -m pip install -U pip
|
||||
python"${python_version}" -m pip install -r "$repo_root"/requirements.txt --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
python"${python_version}" -m pip install -r "$repo_root"/pytorch-requirements.txt --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
python"${python_version}" -m pip install -r "$repo_root"/build-requirements.txt
|
||||
CMAKE_GENERATOR=Ninja \
|
||||
TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \
|
||||
MACOSX_DEPLOYMENT_TARGET=$MACOSX_DEPLOYMENT_TARGET \
|
||||
|
@ -87,6 +93,25 @@ function build_torch_mlir() {
|
|||
rm -rf "$output_dir"/build_venv
|
||||
}
|
||||
|
||||
function build_torch_mlir_core() {
|
||||
local wheel_basename="$1"
|
||||
local python_version="$2"
|
||||
rm -rf "$output_dir"/build_venv
|
||||
python"${python_version}" -m venv "$output_dir"/build_venv
|
||||
source "$output_dir"/build_venv/bin/activate
|
||||
python"${python_version}" -m pip install -U pip delocate
|
||||
python"${python_version}" -m pip install -r "$repo_root"/build-requirements.txt
|
||||
CMAKE_GENERATOR=Ninja \
|
||||
TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \
|
||||
MACOSX_DEPLOYMENT_TARGET=$MACOSX_DEPLOYMENT_TARGET \
|
||||
CMAKE_OSX_ARCHITECTURES=$CMAKE_OSX_ARCHITECTURES \
|
||||
TORCH_MLIR_ENABLE_JIT_IR_IMPORTER=0 \
|
||||
TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS=1 \
|
||||
python"${python_version}" -m pip wheel -v -w "$output_dir" "$repo_root"
|
||||
deactivate
|
||||
rm -rf "$output_dir"/build_venv
|
||||
}
|
||||
|
||||
function clean_wheels() {
|
||||
local wheel_basename="$1"
|
||||
local python_version="$2"
|
||||
|
@ -107,7 +132,8 @@ function run_audit_wheel() {
|
|||
python"${python_version}" -m venv "$output_dir"/test_venv
|
||||
source "$output_dir"/test_venv/bin/activate
|
||||
python"${python_version}" -m pip install -U pip
|
||||
python"${python_version}" -m pip install -r "$repo_root"/requirements.txt --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
python"${python_version}" -m pip install -r "$repo_root"/pytorch-requirements.txt --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
python"${python_version}" -m pip install -r "$repo_root"/build-requirements.txt
|
||||
python"${python_version}" -m pip install "$generic_wheel" --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
DYLD_LIBRARY_PATH="$output_dir"/test_venv/lib/python"${python_version}"/site-packages/torch/lib delocate-wheel -v "$generic_wheel"
|
||||
deactivate
|
||||
|
|
|
@ -13,7 +13,9 @@
|
|||
Write-Host "Installing Build Dependencies"
|
||||
python -m venv .\mlir_venv\
|
||||
.\mlir_venv\Scripts\Activate.PS1
|
||||
pip install -r .\requirements.txt
|
||||
pip install -r .\pytorch-requirements.txt
|
||||
pip install -r .\build-requirements.txt
|
||||
pip install delvewheel
|
||||
Write-Host "Build Deps installation completed successfully"
|
||||
|
||||
Write-Host "Building torch-mlir"
|
||||
|
@ -22,3 +24,7 @@ $env:TORCH_MLIR_ENABLE_LTC='0'
|
|||
python -m pip wheel -v -w wheelhouse ./ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html -r whl-requirements.txt
|
||||
|
||||
Write-Host "Build completed successfully"
|
||||
|
||||
Write-Host "Fixing up wheel dependencies"
|
||||
delvewheel repair --add-path .\build\cmake_build\tools\torch-mlir\python_packages\torch_mlir\torch_mlir\_mlir_libs --add-dll TorchMLIRAggregateCAPI.dll --no-dll 'c10.dll;torch_python.dll;torch_cpu.dll' -v (get-item .\wheelhouse\torch_mlir*.whl).FullName
|
||||
Write-Host "All Done."
|
||||
|
|
|
@ -19,11 +19,13 @@ if [[ "$(whoami)" != "root" ]]; then
|
|||
fi
|
||||
|
||||
PYTHON_INSTALLER_URLS=(
|
||||
"https://www.python.org/ftp/python/3.10.5/python-3.10.5-macos11.pkg"
|
||||
"https://www.python.org/ftp/python/3.11.2/python-3.11.2-macos11.pkg"
|
||||
"https://www.python.org/ftp/python/3.10.10/python-3.10.10-macos11.pkg"
|
||||
"https://www.python.org/ftp/python/3.9.13/python-3.9.13-macos11.pkg"
|
||||
)
|
||||
|
||||
PYTHON_SPECS=(
|
||||
3.11@https://www.python.org/ftp/python/3.11.2/python-3.11.2-macos11.pkg
|
||||
3.10@https://www.python.org/ftp/python/3.10.5/python-3.10.5-macos11.pkg
|
||||
3.9@https://www.python.org/ftp/python/3.9.13/python-3.9.13-macos11.pkg
|
||||
)
|
||||
|
|
|
@ -30,14 +30,14 @@ it to various target dialects of interest to the MLIR ecosystem (various
|
|||
|
||||
- Linalg-on-Tensors (+ `arith`, `tensor`, etc.)
|
||||
- [TOSA](https://mlir.llvm.org/docs/Dialects/TOSA/)
|
||||
- [MHLO](https://github.com/tensorflow/mlir-hlo)
|
||||
- [StableHLO](https://github.com/openxla/stablehlo)
|
||||
|
||||
The terms "frontend" and "backend" are highly overloaded in any compiler
|
||||
project, but frequently in Torch-MLIR this is the meaning that they have.
|
||||
Sometimes "frontend" can mean something even further up the stack, such as
|
||||
something in PyTorch itself. When there is ambiguity we will refer to this as
|
||||
"at the PyTorch level". Similarly, "backend" can sometimes refer to something
|
||||
sitting below Linalg-on-Tensors, TOSA, or MHLO.
|
||||
sitting below Linalg-on-Tensors, TOSA, or StableHLO.
|
||||
|
||||
## The `torch` dialect
|
||||
|
||||
|
@ -118,8 +118,8 @@ See [satisfiesBackendContract](https://github.com/llvm/torch-mlir/blob/114f48e96
|
|||
|
||||
The backend contract is a normalized form of the `torch` dialect with a set of
|
||||
properties that make it easy to lower into various forms such as
|
||||
Linalg-on-Tensors, TOSA, MHLO, or other forms that we don't provide out of the
|
||||
box. The primary guarantees that we provide Torch-MLIR's backends are:
|
||||
Linalg-on-Tensors, TOSA, StableHLO, or other forms that we don't provide out of
|
||||
the box. The primary guarantees that we provide Torch-MLIR's backends are:
|
||||
|
||||
- All tensors have been converted to value semantics.
|
||||
- All tensors have at least a known number of dimensions (i.e. rank), and
|
||||
|
@ -270,7 +270,7 @@ lower it to the requirements of each backend. The 3 backends are:
|
|||
- [`linalg`](https://mlir.llvm.org/docs/Dialects/Linalg/) on tensors (+ `arith`,
|
||||
`tensor`, etc.)
|
||||
- [TOSA](https://mlir.llvm.org/docs/Dialects/TOSA/)
|
||||
- [MHLO](https://github.com/tensorflow/mlir-hlo)
|
||||
- [StableHLO](https://github.com/openxla/stablehlo)
|
||||
|
||||
### The Linalg Backend (Linalg-on-Tensors)
|
||||
|
||||
|
@ -297,15 +297,15 @@ many users (especially "hardware" or "hardware-adjacent" folks). Some of its cha
|
|||
- It is extremely solid with static shapes (and many of its users only care
|
||||
about static shapes, so that's fine).
|
||||
|
||||
### The MHLO Backend
|
||||
### The StableHLO Backend
|
||||
|
||||
Code: https://github.com/llvm/torch-mlir/tree/main/lib/Conversion/TorchToMhlo
|
||||
Code: https://github.com/llvm/torch-mlir/tree/main/lib/Conversion/TorchToStablehlo
|
||||
|
||||
The MHLO backend was the third backend that we added, and it offers a reasonable
|
||||
blend of the benefits of the other two.
|
||||
The StableHLO backend was the third backend that we added, and it offers a
|
||||
reasonable blend of the benefits of the other two.
|
||||
- It is a coarse-grained named-op approach.
|
||||
- It has a pretty clear spec for most of the ops (with a bit of mental
|
||||
translation and hoping that MHLO is the same as HLO):
|
||||
translation and hoping that StableHLO is the same as HLO):
|
||||
https://www.tensorflow.org/xla/operation_semantics
|
||||
- It functionally supports dynamic shapes (though not as coherent and consistent
|
||||
as Linalg-on-Tensors, and the dynamic shape support falls outside the
|
||||
|
@ -317,7 +317,7 @@ blend of the benefits of the other two.
|
|||
example, TOSA limits (for highly considered reasons) the number of dimensions
|
||||
that certain operators can handle to 1D-4D, when from a purely algebraic
|
||||
perspective there isn't a good reason to not be more general. Similarly, more
|
||||
general forms of reduction and scatter also fall into MHLO nicely while
|
||||
general forms of reduction and scatter also fall into StableHLO nicely while
|
||||
TOSA's principles tend to bias it away from that.
|
||||
|
||||
### Backend Implementation
|
||||
|
@ -433,8 +433,9 @@ filling in some corners missing upstream and
|
|||
to pull together upstream functionality into a working system.
|
||||
|
||||
The RefBackend accepts Linalg-on-Tensors as input. It mainly just bufferizes the
|
||||
ops and lowers them to loops. Note that TOSA and MHLO support lowering to
|
||||
Linalg-on-Tensors, so all our end-to-end testing bottoms out on RefBackend.
|
||||
ops and lowers them to loops. Note that TOSA and StableHLO (via MHLO) support
|
||||
lowering to Linalg-on-Tensors, so all our end-to-end testing bottoms out on
|
||||
RefBackend.
|
||||
|
||||
The RefBackend is absolutely not suitable for any production use case. It leaks
|
||||
memory, doesn't support any error handling, performs no optimizations, and
|
||||
|
|
|
@ -34,7 +34,7 @@ and Clang's
|
|||
- Eric Kunze (@eric-k256)
|
||||
- Suraj Sudhir (@sjarus)
|
||||
|
||||
### TorchToMHLO
|
||||
### TorchToStablehlo
|
||||
|
||||
- Tianyo Kwok (@tanyokwok)
|
||||
- Ziheng Jiang (@ZihengJiang)
|
||||
|
|
|
@ -139,7 +139,7 @@ Ex:
|
|||
module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
|
||||
```
|
||||
|
||||
Currently, `output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `MHLO`.
|
||||
Currently, `output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `STABLEHLO`.
|
||||
|
||||
## Jupyter
|
||||
|
||||
|
|
|
@ -46,7 +46,7 @@ the ecosystem are:
|
|||
|
||||
- The frontend work required to lower TorchScript to the backend contract.
|
||||
- The irregular support surface area of the large number of PyTorch ops across
|
||||
the Linalg, TOSA, and MHLO backends.
|
||||
the Linalg, TOSA, and StableHLO backends.
|
||||
|
||||
Most of this document describes long-term ecosystem changes that will address
|
||||
these, drastically improving Torch-MLIR's ability to meet its goals.
|
||||
|
@ -108,7 +108,7 @@ more advanced).
|
|||
### Refactoring the backend
|
||||
|
||||
Today in Torch-MLIR, we support 3 backends out of the box: Linalg-on-Tensors,
|
||||
TOSA, and MHLO. These backends take IR in the backend contract form (see
|
||||
TOSA, and StableHLO. These backends take IR in the backend contract form (see
|
||||
[architecture.md](architecture.md)) and lowers them to the respective dialects.
|
||||
Today, each backend is implemented completely independently. This leads to
|
||||
duplication and irregularity across the backends.
|
||||
|
@ -120,12 +120,10 @@ lowering of so many ops across backends. Additionally, there are 3
|
|||
forward-looking efforts that intersect with this effort:
|
||||
|
||||
- [StableHLO](https://github.com/openxla/stablehlo) - this is a dialect
|
||||
initially forked from MHLO which intends to create a stable support surface
|
||||
area for what today is our "at head" dependency on MHLO. MHLO is a fairly
|
||||
complete op set, so it is very attractive to have "almost all" models
|
||||
bottleneck through a stable interface like StableHLO. StableHLO is currently
|
||||
under relatively early development, but already delivers on many of the goals
|
||||
of stability.
|
||||
initially forked from MHLO. MHLO is a fairly complete op set, so it is very
|
||||
attractive to have "almost all" models bottleneck through a stable interface
|
||||
like StableHLO. StableHLO is currently under relatively early development,
|
||||
but already delivers on many of the goals of stability.
|
||||
- [TCP](https://github.com/llvm/torch-mlir/issues/1366) - this is a dialect
|
||||
which could serve a role very similar to MHLO, while providing community
|
||||
ownership. TCP is still in early planning phases, but there is strong
|
||||
|
|
|
@ -16,7 +16,7 @@ from torch_mlir_e2e_test.registry import GLOBAL_TEST_REGISTRY
|
|||
from torch_mlir_e2e_test.configs import (
|
||||
LazyTensorCoreTestConfig,
|
||||
LinalgOnTensorsBackendTestConfig,
|
||||
MhloBackendTestConfig,
|
||||
StablehloBackendTestConfig,
|
||||
NativeTorchTestConfig,
|
||||
TorchScriptTestConfig,
|
||||
TosaBackendTestConfig,
|
||||
|
@ -24,17 +24,17 @@ from torch_mlir_e2e_test.configs import (
|
|||
)
|
||||
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
||||
from torch_mlir_e2e_test.mhlo_backends.linalg_on_tensors import LinalgOnTensorsMhloBackend
|
||||
from torch_mlir_e2e_test.stablehlo_backends.linalg_on_tensors import LinalgOnTensorsStablehloBackend
|
||||
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend
|
||||
|
||||
from .xfail_sets import LINALG_XFAIL_SET, MHLO_PASS_SET, TOSA_PASS_SET, LTC_XFAIL_SET, TORCHDYNAMO_XFAIL_SET
|
||||
from .xfail_sets import LINALG_XFAIL_SET, STABLEHLO_PASS_SET, TOSA_PASS_SET, LTC_XFAIL_SET, TORCHDYNAMO_XFAIL_SET
|
||||
|
||||
# Import tests to register them in the global registry.
|
||||
from torch_mlir_e2e_test.test_suite import register_all_tests
|
||||
register_all_tests()
|
||||
|
||||
def _get_argparse():
|
||||
config_choices = ["native_torch", "torchscript", "linalg", "mhlo", "tosa", "lazy_tensor_core", "torchdynamo"]
|
||||
config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "tosa", "lazy_tensor_core", "torchdynamo"]
|
||||
parser = argparse.ArgumentParser(description="Run torchscript e2e tests.")
|
||||
parser.add_argument("-c", "--config",
|
||||
choices=config_choices,
|
||||
|
@ -42,7 +42,7 @@ def _get_argparse():
|
|||
help=f"""
|
||||
Meaning of options:
|
||||
"linalg": run through torch-mlir"s default Linalg-on-Tensors backend.
|
||||
"mhlo": run through torch-mlir"s default MHLO backend.
|
||||
"stablehlo": run through torch-mlir"s default StableHLO backend.
|
||||
"tosa": run through torch-mlir"s default TOSA backend.
|
||||
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
|
||||
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
|
||||
|
@ -80,9 +80,9 @@ def main():
|
|||
if args.config == "tosa":
|
||||
config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend())
|
||||
xfail_set = all_test_unique_names - TOSA_PASS_SET
|
||||
if args.config == "mhlo":
|
||||
config = MhloBackendTestConfig(LinalgOnTensorsMhloBackend())
|
||||
xfail_set = all_test_unique_names - MHLO_PASS_SET
|
||||
if args.config == "stablehlo":
|
||||
config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend())
|
||||
xfail_set = all_test_unique_names - STABLEHLO_PASS_SET
|
||||
elif args.config == "native_torch":
|
||||
config = NativeTorchTestConfig()
|
||||
xfail_set = {}
|
||||
|
|
|
@ -26,6 +26,11 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
# https://github.com/pytorch/pytorch/issues/89629
|
||||
"ConvolutionBackwardModule2DPadded_basic",
|
||||
"ConvolutionBackwardModule2D_basic",
|
||||
|
||||
# error: 'tensor.expand_shape' op expected dimension 0 of collapsed type to be dynamic since one or more of the corresponding dimensions in the expanded type is dynamic
|
||||
# https://github.com/llvm/torch-mlir/issues/1859
|
||||
"ConvolutionModule2DGroups_basic",
|
||||
|
||||
# RuntimeError: Index tensor must have the same number of dimensions as self tensor
|
||||
# RuntimeError: Failed running call_function aten.nll_loss_backward(...
|
||||
# https://github.com/pytorch/pytorch/issues/89630
|
||||
|
@ -39,10 +44,6 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
# RuntimeError: Failed running call_function aten.uniform(...
|
||||
# https://github.com/pytorch/torchdynamo/issues/1954
|
||||
"UniformNoCorrelationModule_basic",
|
||||
# TypeError: expected np.ndarray (got float)
|
||||
# TODO: This is due to returning a scalar float as output from the test.
|
||||
# We should probably just standardize all tests to return tensors.
|
||||
"DivIntModule_basic",
|
||||
|
||||
#### Torch-MLIR internal compiler errors
|
||||
|
||||
|
@ -66,14 +67,13 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
"IndexPutImpl2DFloatNonAccumulateModule_basic",
|
||||
"IndexPutImpl3DFloatAccumulateModule_basic",
|
||||
"IndexPutImpl3DFloatNonAccumulateModule_basic",
|
||||
# %4 = torch.operator "aten.squeeze_.dim"(%3, %int0) : (!torch.tensor<*,f32>, !torch.int) -> !torch.tensor
|
||||
"Matmul_vecmat",
|
||||
|
||||
# https://github.com/llvm/torch-mlir/issues/1611
|
||||
# error: 'tensor.cast' op operand type 'tensor<0xi64>' and result type 'tensor<18xi64>' are cast incompatible
|
||||
"Aten_EmbeddingBagExample_basic",
|
||||
# error: failed to legalize operation 'torch.valsem.aten.bernoulli.float' that was explicitly marked illegal
|
||||
"BernoulliFloatModule_basic",
|
||||
"BernoulliPModule_basic",
|
||||
# error: failed to legalize operation 'torch.aten.view' that was explicitly marked illegal
|
||||
"ElementwiseFlattenBroadcastModule_basic",
|
||||
"FlattenRank0Module_basic",
|
||||
|
@ -83,8 +83,16 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
# error: unsupported by backend contract: tensor with unknown rank
|
||||
# note: see current operation: %1 = "torch.tensor_static_info_cast"(%arg0) : (!torch.vtensor<[5,4,3,2,1],f32>) -> !torch.vtensor<*,f32>
|
||||
"ElementwisePreluModule_basic",
|
||||
# error: op lowering missing. Issue: https://github.com/llvm/torch-mlir/issues/1792
|
||||
"StdCorrectionKeepDimModule_basic",
|
||||
|
||||
#ERROR: value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=-1.336e-32, max=+0.9152, mean=+0.4837) is not close to golden value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=+0.02233, max=+0.9152, mean=+0.4777)
|
||||
"UpSampleNearest2dDynamicFactor_basic",
|
||||
"ReduceMaxAlongDimUnsignedInt_basic",
|
||||
#ERROR: value (-56) is not equal to golden value (200)
|
||||
"AtenIntTensorByteDtypeModule_basic",
|
||||
# ERROR: assert isinstance(e, FakeTensor)
|
||||
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
|
||||
# ERROR: assert isinstance(e, FakeTensor)
|
||||
"RsubInt0d_NumToTensor_Module_basic",
|
||||
|
||||
# Dtype function transition failures
|
||||
"MobilenetV3Module_basic",
|
||||
|
@ -92,8 +100,12 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
"ResNet18StaticModule_basic",
|
||||
}
|
||||
|
||||
MHLO_PASS_SET = {
|
||||
STABLEHLO_PASS_SET = {
|
||||
"MaskedFillScalarIntValueStaticModule_basic",
|
||||
"MaskedFillScalarFloatValueStaticModule_basic",
|
||||
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
|
||||
"AddSizeIntModule_basic",
|
||||
"AddSizeIntNegDimModule_basic",
|
||||
"ArangeDtypeFloatModule_basic",
|
||||
"ArangeDtypeIntModule_basic",
|
||||
"ArangeFalsePinMemoryModule_basic",
|
||||
|
@ -108,10 +120,15 @@ MHLO_PASS_SET = {
|
|||
"ArangeStartStepFloatModule_basic",
|
||||
"ArangeStartStepIntModule_basic",
|
||||
"ArangeZeroElementOutputModule_basic",
|
||||
"BatchMlpLayerModule_basic",
|
||||
"BmmModule_basic",
|
||||
"BroadcastToModule_basic",
|
||||
"BroadcastToSameRankStaticModule_basic",
|
||||
"BroadcastZeroRankInputStaticModule_basic",
|
||||
"BucketizeTensorStaticFloatModule_basic",
|
||||
"BucketizeTensorStaticModule_basic",
|
||||
"CumsumStaticModule_basic",
|
||||
"CumsumStaticNegativeDimModule_basic",
|
||||
"ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic",
|
||||
"ElementwiseAtenLogicalNotOpModule_basic",
|
||||
"ElementwiseAtenLogicalNotOpPromoteModule_basic",
|
||||
|
@ -126,19 +143,29 @@ MHLO_PASS_SET = {
|
|||
"ElementwiseClampModule_basic",
|
||||
"ElementwiseClampMinModule_basic",
|
||||
"ElementwiseClampMaxModule_basic",
|
||||
"ElementwisePowModule_basic",
|
||||
"ElementwiseExpModule_basic",
|
||||
"ElementwiseFlattenBroadcastModule_basic",
|
||||
"ElementwiseLeakyReluModule_basic",
|
||||
"ElementwiseLogModule_basic",
|
||||
"ElementwiseNegModule_basic",
|
||||
"ElementwiseRsqrtModule_basic",
|
||||
"ElementwiseSigmoidModule_basic",
|
||||
"ElementwiseSqrtModule_basic",
|
||||
"ElementwiseSinModule_basic",
|
||||
"ElementwiseCosModule_basic",
|
||||
"ElementwiseCeilModule_basic",
|
||||
"ElementwiseFloorModule_basic",
|
||||
"ElementwiseUnaryModule_basic",
|
||||
"ElementwiseUnsqueezeBroadcastModule_basic",
|
||||
"ElementwiseUnsqueezeNegDimsModule_basic",
|
||||
"ElementwiseToDtypeF32ToI64Module_basic",
|
||||
"ElementwiseAddModule_basic",
|
||||
"ElementwiseAddScalarFloatModule_basic",
|
||||
"ElementwiseAddScalarInt64Module_basic",
|
||||
"ElementwiseAddScalarIntModule_basic",
|
||||
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
|
||||
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
|
||||
"ElementwiseDivScalarModule_basic",
|
||||
"ElementwiseEqDiffWidthScalarModule_basic",
|
||||
"ElementwiseEqFloatScalarModule_basic",
|
||||
|
@ -201,6 +228,8 @@ MHLO_PASS_SET = {
|
|||
"Gather2DInputModdule_basic",
|
||||
"GatherRandomIndexModule_basic",
|
||||
"GeluBackwardModule_basic",
|
||||
"HardswishModule_basic",
|
||||
"HardswishRandomModule_basic",
|
||||
"HardTanhIntModule_basic",
|
||||
"HardTanhModule_basic",
|
||||
"HardsigmoidModule_basic",
|
||||
|
@ -223,6 +252,8 @@ MHLO_PASS_SET = {
|
|||
"MeanDynamicSizesModule_basic",
|
||||
"MeanLargeInputModule_basic",
|
||||
"MeanModule_basic",
|
||||
"Mlp1LayerModule_basic",
|
||||
"Mlp2LayerModule_basic",
|
||||
"MmTanhModule_basic",
|
||||
"Mv_basic",
|
||||
"NativeLayerNormModule4D_basic",
|
||||
|
@ -239,6 +270,7 @@ MHLO_PASS_SET = {
|
|||
"ReduceSumDtypeFloatModule_basic",
|
||||
"ReduceSumDtypeIntModule_basic",
|
||||
"SelectIntModule_basic",
|
||||
"SelectIntNegativeDimAndIndexStaticModule_basic",
|
||||
"SliceSingleIdxModule_basic",
|
||||
"SqueezeDimModule_dynamic",
|
||||
"SqueezeDimModule_negDim",
|
||||
|
@ -250,9 +282,15 @@ MHLO_PASS_SET = {
|
|||
"FlattenStaticModule_basic",
|
||||
"FlattenRank0Module_basic",
|
||||
"TensorsConcatNegativeDimModule_basic",
|
||||
"TensorsConcatPromoteDTypeModule_basic",
|
||||
"TensorsStackModule_basic",
|
||||
"TensorsStackNegativeDimModule_basic",
|
||||
"TensorsStackPromoteDTypeModule_basic",
|
||||
"LiftFreshCopyModule_basic",
|
||||
"Mlp2LayerModuleNoBias_basic",
|
||||
"NumelModule_basic",
|
||||
"SiluModule_basic",
|
||||
"SquareModule_basic",
|
||||
"SqueezeModule_allUnitDim",
|
||||
"SqueezeDimModule_unitDim",
|
||||
"ViewCollapseOnesMiddleModule_basic",
|
||||
|
@ -272,6 +310,7 @@ MHLO_PASS_SET = {
|
|||
"Convolution2DStaticModule_basic",
|
||||
"ConvolutionModule2DTransposeStridedStatic_basic",
|
||||
"ElementwiseCloneContiguousModule_basic",
|
||||
"ElementwiseCloneChannelsLastMemoryFormatModule_basic",
|
||||
"ElementwiseCloneModule_basic",
|
||||
"ElementwiseBinaryStaticShapeModule_basic",
|
||||
"ReturnThreeTensorFloat32_basic",
|
||||
|
@ -288,6 +327,7 @@ MHLO_PASS_SET = {
|
|||
"RsubFloatModule_noalpha_basic",
|
||||
"RsubIntModule_basic",
|
||||
"RsubIntModule_noalpha_basic",
|
||||
"RsubInt0d_NumToTensor_Module_basic",
|
||||
"SliceStaticModule_basic",
|
||||
"SliceModule_basic",
|
||||
"SliceNegIdxModule_basic",
|
||||
|
@ -358,6 +398,7 @@ MHLO_PASS_SET = {
|
|||
"ViewExpandCollapseModule_basic",
|
||||
"ViewExpandCollapseWithOnesModule_basic",
|
||||
"ViewExpandInferredDimModule_basic",
|
||||
"ViewNegativeStaticModule_basic",
|
||||
"ViewNoChangeStaticModule_basic",
|
||||
"ViewNoChange1dModule_basic",
|
||||
"ViewNoChange2dModule_basic",
|
||||
|
@ -420,12 +461,14 @@ MHLO_PASS_SET = {
|
|||
"UnsafeViewDynamicExpandModule_basic",
|
||||
"AtenRoundIntModule_basic",
|
||||
"TestF16Return_basic",
|
||||
"_LogSoftmaxModuleStable_basic",
|
||||
}
|
||||
|
||||
# Write the TOSA set as a "passing" set as it is very early in development
|
||||
# and very few tests work yet.
|
||||
TOSA_PASS_SET = {
|
||||
"ElementwiseCloneContiguousModule_basic",
|
||||
"ElementwiseCloneChannelsLastMemoryFormatModule_basic",
|
||||
"ElementwiseCloneModule_basic",
|
||||
"ElementwiseUnaryModule_basic",
|
||||
"ElementwiseBinaryModule_basic",
|
||||
|
@ -449,6 +492,7 @@ TOSA_PASS_SET = {
|
|||
"ViewExpandOnesMiddleOppModule_basic",
|
||||
"ViewOffsetBackwardTestStaticModule_basic",
|
||||
"TanhBackward_basic",
|
||||
"HardtanhBackward_basic",
|
||||
"ElementwiseAddModule_basic",
|
||||
"ReturnThreeTensorFloat32_basic",
|
||||
"AddCMulModule_basic",
|
||||
|
@ -459,6 +503,7 @@ TOSA_PASS_SET = {
|
|||
"BoolTensorReturnMixedModule_basic",
|
||||
"BoolTensorHandleSignless_basic",
|
||||
"ElementwiseRsqrtModule_basic",
|
||||
"SelectIntNegativeDimAndIndexStaticModule_basic",
|
||||
"SqueezeModule_static",
|
||||
"SqueezeModule_noUnitDim",
|
||||
"SqueezeModule_allUnitDim",
|
||||
|
@ -480,6 +525,7 @@ TOSA_PASS_SET = {
|
|||
"Matmul_3d",
|
||||
"RsubFloatModule_basic",
|
||||
"RsubFloatModule_noalpha_basic",
|
||||
"RsubInt0d_NumToTensor_Module_basic",
|
||||
"ElementwiseBitwiseAndModule_basic",
|
||||
"ElementwiseBitwiseAndStaticShapeModule_basic",
|
||||
"ElementwiseBitwiseNotInt32Module_basic",
|
||||
|
@ -509,6 +555,7 @@ TOSA_PASS_SET = {
|
|||
"ElementwiseDivScalarModule_basic",
|
||||
"ElementwiseSubScalarFloatModule_basic",
|
||||
"ElementwiseAddScalarFloatModule_basic",
|
||||
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
|
||||
"ElementwiseMulScalarModule_float",
|
||||
"ElementwiseCeilModule_basic",
|
||||
"ElementwiseReciprocalModule_basic",
|
||||
|
@ -572,6 +619,7 @@ TOSA_PASS_SET = {
|
|||
"ViewExpandCollapseWithOnesModule_basic",
|
||||
"ViewCollapseInferredDimModule_basic",
|
||||
"ViewExpandInferredDimModule_basic",
|
||||
"ViewNegativeStaticModule_basic",
|
||||
"ViewNoChangeStaticModule_basic",
|
||||
"UnsafeViewExpandModule_basic",
|
||||
"ReshapeCollapseModule_basic",
|
||||
|
@ -604,6 +652,7 @@ TOSA_PASS_SET = {
|
|||
"_LogSoftmaxModuleStable_basic",
|
||||
"ElementwiseAtenWhereSelfModule_basic",
|
||||
"ElementwiseUnsqueezeBroadcastModule_basic",
|
||||
"MaskedFillScalarIntValueModule_basic",
|
||||
"MaskedFillScalarIntValueStaticModule_basic",
|
||||
"MaskedFillTensorIntValueStaticModule_basic",
|
||||
"ElementwiseAddScalarInt64Module_basic",
|
||||
|
@ -611,8 +660,11 @@ TOSA_PASS_SET = {
|
|||
"TensorOpaqueLiteralModule_basic",
|
||||
"TypePromotionDifferentCategoryModule_basic",
|
||||
"TypePromotionSameCategoryDifferentWidthModule_basic",
|
||||
"TypePromotionSameCategoryZeroRankWider_basic",
|
||||
"TypePromotionZeroRankHigherCategoryModule_basic",
|
||||
"GatherStaticModule_basic",
|
||||
"IndexTensorStaticModule_basic",
|
||||
"IndexTensorMultiIndexStaticModule_basic",
|
||||
"LiftFreshCopyModule_basic",
|
||||
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
|
||||
"ReduceSumDimIntListFloatModule_basic",
|
||||
|
@ -651,6 +703,10 @@ TOSA_PASS_SET = {
|
|||
"HardsigmoidRandomModule_basic",
|
||||
"HardswishModule_basic",
|
||||
"HardswishRandomModule_basic",
|
||||
"FullLikeModuleInt2DStatic_basic",
|
||||
"FullModuleInt3D_basic",
|
||||
"FullModuleFloat2D_basic",
|
||||
"RepeatModule_basic"
|
||||
}
|
||||
|
||||
LTC_XFAIL_SET = {
|
||||
|
@ -666,7 +722,7 @@ LTC_XFAIL_SET = {
|
|||
"AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
|
||||
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
|
||||
"AddIntModule_basic",
|
||||
"BernoulliFloatModule_basic",
|
||||
"AtenIntBoolOpModule_basic",
|
||||
"BernoulliTensorModule_basic",
|
||||
"BincountMinlengthModule_basic",
|
||||
"BincountModule_basic",
|
||||
|
@ -686,6 +742,7 @@ LTC_XFAIL_SET = {
|
|||
"GtFloatIntModule_basic",
|
||||
"GtIntModule_basic",
|
||||
"HBC_basic",
|
||||
"HardtanhBackward_basic",
|
||||
"IndexPut1DFloatAccumulateModule_basic",
|
||||
"IndexPut1DFloatNonAccumulateModule_basic",
|
||||
"IndexPut1DIntAccumulateModule_basic",
|
||||
|
@ -720,6 +777,8 @@ LTC_XFAIL_SET = {
|
|||
"IndexPutImpl3DFloatNonAccumulateModule_basic",
|
||||
"IndexTensorModule3dInput_basic",
|
||||
"IndexTensorModule_basic",
|
||||
"IndexTensorStaticModule_basic",
|
||||
"IndexTensorMultiIndexStaticModule_basic",
|
||||
"IndexTensorMultiInputContiguousCenter_basic",
|
||||
"IndexTensorMultiInputNonContiguous_basic",
|
||||
"IndexTensorMultiInputOneDim_basic",
|
||||
|
@ -752,6 +811,8 @@ LTC_XFAIL_SET = {
|
|||
"SubFloatModule_basic",
|
||||
"SubIntModule_basic",
|
||||
"TensorsConcatNegativeDimModule_basic",
|
||||
"TensorsConcatPromoteDTypeModule_basic",
|
||||
"TensorsStackPromoteDTypeModule_basic",
|
||||
"TensorToBoolZeroRank_basic",
|
||||
"TensorToBool_basic",
|
||||
"TensorToFloatZeroRank_basic",
|
||||
|
@ -788,4 +849,34 @@ LTC_XFAIL_SET = {
|
|||
"ElementwisePreluModule_basic",
|
||||
"VarMeanBiasedModule_basic",
|
||||
"VarMeanUnbiasedModule_basic",
|
||||
"RandnLikeModule_basic",
|
||||
"RandnLikeDtypeModule_basic",
|
||||
"NewEmptyStridedModuleDefaultDtype_basic",
|
||||
"BernoulliFloatModule_basic",
|
||||
"BernoulliModule_basic",
|
||||
"BernoulliPModule_basic",
|
||||
"DropoutTrainModule_basic",
|
||||
"StdCorrectionKeepDimModule_basic",
|
||||
"StdCorrectionNoneModule_basic",
|
||||
"SliceCopy_Module_basic",
|
||||
"SliceCopyNegative_Module_basic",
|
||||
"VarBiasedModule_basic",
|
||||
"VarCorrectionAllDimReduceModule_basic",
|
||||
"VarCorrectionEmptyDimModule_basic",
|
||||
"VarCorrectionKeepDimModule_basic",
|
||||
"VarCorrectionLargeInputModule_basic",
|
||||
"VarCorrectionModule_basic",
|
||||
"VarCorrectionNoneModule_basic",
|
||||
"VarCorrectionSingleDimReduceModule_basic",
|
||||
"VarDimAllDimReduceModule_basic",
|
||||
"VarDimBiasedModule_basic",
|
||||
"VarDimEmptyDimModule_basic",
|
||||
"VarDimModule_basic",
|
||||
"VarDimMultiDimModule_basic",
|
||||
"VarDimNegativeModule_basic",
|
||||
"VarDimNoneDimModule_basic",
|
||||
"VarDimSingleDimModule_basic",
|
||||
"VarDimUnbiasedModule_basic",
|
||||
"VarUnbiasedModule_basic",
|
||||
"AtenFloatScalarModule_basic"
|
||||
}
|
||||
|
|
|
@ -91,4 +91,4 @@ resnet18 = models.resnet18(pretrained=True)
|
|||
resnet18.train(False)
|
||||
dynamo_callable = dynamo.optimize(refbackend_torchdynamo_backend)(resnet18)
|
||||
|
||||
predictions(resnet18.forward, lambda x: dynamo_callable(torch.from_numpy(x)).numpy(), img, labels)
|
||||
predictions(resnet18.forward, lambda x: dynamo_callable(torch.from_numpy(x)).detach().numpy(), img, labels)
|
||||
|
|
|
@ -1,14 +0,0 @@
|
|||
import torch
|
||||
import torchvision.models as models
|
||||
import torch_mlir
|
||||
|
||||
model = models.resnet18(pretrained=True)
|
||||
model.eval()
|
||||
data = torch.randn(2,3,200,200)
|
||||
out_mhlo_mlir_path = "./resnet18_mhlo.mlir"
|
||||
|
||||
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=False)
|
||||
with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf:
|
||||
outf.write(str(module))
|
||||
|
||||
print(f"MHLO IR of resent18 successfully written into {out_mhlo_mlir_path}")
|
|
@ -0,0 +1,14 @@
|
|||
import torch
|
||||
import torchvision.models as models
|
||||
import torch_mlir
|
||||
|
||||
model = models.resnet18(pretrained=True)
|
||||
model.eval()
|
||||
data = torch.randn(2,3,200,200)
|
||||
out_stablehlo_mlir_path = "./resnet18_stablehlo.mlir"
|
||||
|
||||
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=False)
|
||||
with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
|
||||
outf.write(str(module))
|
||||
|
||||
print(f"StableHLO IR of resent18 successfully written into {out_stablehlo_mlir_path}")
|
|
@ -15,10 +15,10 @@ class BertTinyWrapper(torch.nn.Module):
|
|||
model = BertTinyWrapper()
|
||||
model.eval()
|
||||
data = torch.randint(30522, (2, 128))
|
||||
out_mhlo_mlir_path = "./bert_tiny_mhlo.mlir"
|
||||
out_stablehlo_mlir_path = "./bert_tiny_stablehlo.mlir"
|
||||
|
||||
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=True)
|
||||
with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf:
|
||||
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=True)
|
||||
with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
|
||||
outf.write(str(module))
|
||||
|
||||
print(f"MHLO IR of tiny bert successfully written into {out_mhlo_mlir_path}")
|
||||
print(f"StableHLO IR of tiny bert successfully written into {out_stablehlo_mlir_path}")
|
|
@ -10,7 +10,7 @@
|
|||
#ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_
|
||||
#define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_
|
||||
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
|
|
@ -457,7 +457,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
|
|||
(ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
|
||||
"ValueRange":$operands),
|
||||
[{
|
||||
BlockAndValueMapping bvm;
|
||||
IRMapping bvm;
|
||||
OperationState state(
|
||||
loc, ConcreteOp::getOperationName(), operands, resultTypes,
|
||||
$_op->getAttrs());
|
||||
|
|
|
@ -204,7 +204,7 @@ LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc,
|
|||
}
|
||||
|
||||
auto scfIf = b.create<scf::IfOp>(
|
||||
loc, TypeRange{}, cond,
|
||||
loc, cond,
|
||||
[&](OpBuilder &b, Location loc) {
|
||||
if (isInclusive) {
|
||||
auto value = b.create<memref::LoadOp>(loc, input(), indices);
|
||||
|
@ -232,7 +232,7 @@ LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc,
|
|||
|
||||
auto &srcBlock = getRegion().front();
|
||||
Region &thisRegion = scfIf.getElseRegion();
|
||||
BlockAndValueMapping bvm;
|
||||
IRMapping bvm;
|
||||
{
|
||||
OpBuilder::InsertionGuard guard(b);
|
||||
auto &block = thisRegion.front();
|
||||
|
@ -266,7 +266,7 @@ static LogicalResult foldMemRefCast(Operation *op) {
|
|||
return success(folded);
|
||||
}
|
||||
|
||||
LogicalResult ScanOp::fold(ArrayRef<Attribute>,
|
||||
LogicalResult ScanOp::fold(FoldAdaptor adaptor,
|
||||
SmallVectorImpl<OpFoldResult> &) {
|
||||
return foldMemRefCast(*this);
|
||||
}
|
||||
|
@ -461,7 +461,7 @@ LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b,
|
|||
|
||||
Value init = b.create<memref::LoadOp>(loc, original(), starts);
|
||||
|
||||
BlockAndValueMapping bvm;
|
||||
IRMapping bvm;
|
||||
Block &block = getRegion().front();
|
||||
bvm.map(block.getArgument(0), update);
|
||||
bvm.map(block.getArgument(1), init);
|
||||
|
|
|
@ -1 +1 @@
|
|||
Subproject commit de3f0f7fa0c7b902dde840913db7e773a02c4173
|
||||
Subproject commit 21f4b84c456b471cc52016cf360e14d45f7f2960
|
|
@ -1 +1 @@
|
|||
Subproject commit 2c8823d255a777d3053ef891f4dbeea1c32819f4
|
||||
Subproject commit b1ac0403ee2a40fc648ada6b9f11096f3d50fd19
|
|
@ -1,6 +1,6 @@
|
|||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
if(TORCH_MLIR_ENABLE_MHLO)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_MHLO)
|
||||
if(TORCH_MLIR_ENABLE_STABLEHLO)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO)
|
||||
else()
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls)
|
||||
endif()
|
||||
|
|
|
@ -133,13 +133,13 @@ def ConvertTorchConversionToMLProgram : Pass<"convert-torch-conversion-to-mlprog
|
|||
let constructor = "mlir::torch::createConvertTorchConversionToMLProgramPass()";
|
||||
}
|
||||
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO
|
||||
def ConvertTorchToMhlo : Pass<"convert-torch-to-mhlo", "func::FuncOp"> {
|
||||
let summary = "Convert Torch ops to MHLO ops";
|
||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||
def ConvertTorchToStablehlo : Pass<"convert-torch-to-stablehlo", "func::FuncOp"> {
|
||||
let summary = "Convert Torch ops to Stablehlo ops";
|
||||
let description = [{
|
||||
Convert Torch ops to mhlo ops.
|
||||
Convert Torch ops to Stablehlo ops.
|
||||
}];
|
||||
let constructor = "mlir::torch::createConvertTorchToMhloPass()";
|
||||
let constructor = "mlir::torch::createConvertTorchToStablehloPass()";
|
||||
|
||||
// Specify any options.
|
||||
let options = [
|
||||
|
|
|
@ -7,8 +7,8 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H
|
||||
#define TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H
|
||||
#ifndef TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H
|
||||
#define TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
@ -16,10 +16,11 @@
|
|||
|
||||
namespace mlir {
|
||||
namespace torch {
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToMhloPass();
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
createConvertTorchToMhloPass(bool enableStaticShape, bool enableI32Index);
|
||||
createConvertTorchToStablehloPass();
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
createConvertTorchToStablehloPass(bool enableStaticShape, bool enableI32Index);
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H
|
||||
#endif // TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H
|
|
@ -2014,6 +2014,55 @@ def Torch_AtenClamp_Op : Torch_Op<"aten.clamp_", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenClampTensorOp : Torch_Op<"aten.clamp.Tensor", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::clamp.Tensor : (Tensor, Tensor?, Tensor?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchOptionalTensorType:$min,
|
||||
AnyTorchOptionalTensorType:$max
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenClampTensorOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||
}
|
||||
void AtenClampTensorOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenClamp_TensorOp : Torch_Op<"aten.clamp_.Tensor", [
|
||||
IsTrailingUnderscoreInplaceVariant,
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::clamp_.Tensor : (Tensor, Tensor?, Tensor?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchOptionalTensorType:$min,
|
||||
AnyTorchOptionalTensorType:$max
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenClamp_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||
}
|
||||
void AtenClamp_TensorOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenClampMinOp : Torch_Op<"aten.clamp_min", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
@ -3340,6 +3389,7 @@ def Torch_AtenRsubScalarOp : Torch_Op<"aten.rsub.Scalar", [
|
|||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenGeluOp : Torch_Op<"aten.gelu", [
|
||||
|
@ -3638,6 +3688,31 @@ def Torch_AtenBernoulli_FloatOp : Torch_Op<"aten.bernoulli_.float", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenBernoulliPOp : Torch_Op<"aten.bernoulli.p", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::bernoulli.p : (Tensor, float, Generator?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_FloatType:$p,
|
||||
AnyTorchOptionalGeneratorType:$generator
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenBernoulliPOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||
}
|
||||
void AtenBernoulliPOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenRandintLowOp : Torch_Op<"aten.randint.low", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
@ -3771,6 +3846,34 @@ def Torch_AtenRandnGeneratorOp : Torch_Op<"aten.randn.generator", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenRandnLikeOp : Torch_Op<"aten.randn_like", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::randn_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchOptionalIntType:$dtype,
|
||||
AnyTorchOptionalIntType:$layout,
|
||||
AnyTorchOptionalDeviceType:$device,
|
||||
AnyTorchOptionalBoolType:$pin_memory,
|
||||
AnyTorchOptionalIntType:$memory_format
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenRandnLikeOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 6, 1);
|
||||
}
|
||||
void AtenRandnLikeOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 6, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenTriuOp : Torch_Op<"aten.triu", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
@ -5146,11 +5249,11 @@ def Torch_AtenStdCorrectionOp : Torch_Op<"aten.std.correction", [
|
|||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::std.correction : (Tensor, int[]?, int?, bool) -> (Tensor)`";
|
||||
let summary = "Generated op for `aten::std.correction : (Tensor, int[]?, Scalar?, bool) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchOptionalListOfTorchIntType:$dim,
|
||||
AnyTorchOptionalIntType:$correction,
|
||||
AnyTorchOptionalScalarType:$correction,
|
||||
Torch_BoolType:$keepdim
|
||||
);
|
||||
let results = (outs
|
||||
|
@ -5222,11 +5325,11 @@ def Torch_AtenVarCorrectionOp : Torch_Op<"aten.var.correction", [
|
|||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::var.correction : (Tensor, int[]?, int?, bool) -> (Tensor)`";
|
||||
let summary = "Generated op for `aten::var.correction : (Tensor, int[]?, Scalar?, bool) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchOptionalListOfTorchIntType:$dim,
|
||||
AnyTorchOptionalIntType:$correction,
|
||||
AnyTorchOptionalScalarType:$correction,
|
||||
Torch_BoolType:$keepdim
|
||||
);
|
||||
let results = (outs
|
||||
|
@ -5248,11 +5351,11 @@ def Torch_AtenVarMeanCorrectionOp : Torch_Op<"aten.var_mean.correction", [
|
|||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::var_mean.correction : (Tensor, int[]?, int?, bool) -> (Tensor, Tensor)`";
|
||||
let summary = "Generated op for `aten::var_mean.correction : (Tensor, int[]?, Scalar?, bool) -> (Tensor, Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchOptionalListOfTorchIntType:$dim,
|
||||
AnyTorchOptionalIntType:$correction,
|
||||
AnyTorchOptionalScalarType:$correction,
|
||||
Torch_BoolType:$keepdim
|
||||
);
|
||||
let results = (outs
|
||||
|
@ -6482,6 +6585,35 @@ def Torch_AtenNewEmptyOp : Torch_Op<"aten.new_empty", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenNewEmptyStridedOp : Torch_Op<"aten.new_empty_strided", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::new_empty_strided : (Tensor, int[], int[], int?, int?, Device?, bool?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchListOfTorchIntType:$size,
|
||||
AnyTorchListOfTorchIntType:$stride,
|
||||
AnyTorchOptionalIntType:$dtype,
|
||||
AnyTorchOptionalIntType:$layout,
|
||||
AnyTorchOptionalDeviceType:$device,
|
||||
AnyTorchOptionalBoolType:$pin_memory
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenNewEmptyStridedOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 7, 1);
|
||||
}
|
||||
void AtenNewEmptyStridedOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 7, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenZerosLikeOp : Torch_Op<"aten.zeros_like", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
@ -6975,30 +7107,6 @@ def Torch_AtenSizeIntOp : Torch_Op<"aten.size.int", [
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenStackOp : Torch_Op<"aten.stack", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::stack : (Tensor[], int) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchListOfTensorType:$tensors,
|
||||
Torch_IntType:$dim
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenStackOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||
}
|
||||
void AtenStackOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 2, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenSumOp : Torch_Op<"aten.sum", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
@ -7556,6 +7664,86 @@ def Torch_AtenScatterAddOp : Torch_Op<"aten.scatter_add", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenScatterAdd_Op : Torch_Op<"aten.scatter_add_", [
|
||||
IsTrailingUnderscoreInplaceVariant,
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::scatter_add_ : (Tensor, int, Tensor, Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_IntType:$dim,
|
||||
AnyTorchTensorType:$index,
|
||||
AnyTorchTensorType:$src
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenScatterAdd_Op::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||
}
|
||||
void AtenScatterAdd_Op::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 4, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenScatterReduceTwoOp : Torch_Op<"aten.scatter_reduce.two", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::scatter_reduce.two : (Tensor, int, Tensor, Tensor, str, bool) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_IntType:$dim,
|
||||
AnyTorchTensorType:$index,
|
||||
AnyTorchTensorType:$src,
|
||||
Torch_StringType:$reduce,
|
||||
Torch_BoolType:$include_self
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenScatterReduceTwoOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 6, 1);
|
||||
}
|
||||
void AtenScatterReduceTwoOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 6, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenScatterReduce_TwoOp : Torch_Op<"aten.scatter_reduce_.two", [
|
||||
IsTrailingUnderscoreInplaceVariant,
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::scatter_reduce_.two : (Tensor, int, Tensor, Tensor, str, bool) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_IntType:$dim,
|
||||
AnyTorchTensorType:$index,
|
||||
AnyTorchTensorType:$src,
|
||||
Torch_StringType:$reduce,
|
||||
Torch_BoolType:$include_self
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenScatterReduce_TwoOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 6, 1);
|
||||
}
|
||||
void AtenScatterReduce_TwoOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 6, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenIntImplicitOp : Torch_Op<"aten.IntImplicit", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
@ -8670,6 +8858,31 @@ def Torch_AtenCatOp : Torch_Op<"aten.cat", [
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenStackOp : Torch_Op<"aten.stack", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::stack : (Tensor[], int) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchListOfTensorType:$tensors,
|
||||
Torch_IntType:$dim
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenStackOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||
}
|
||||
void AtenStackOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 2, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenAppendTOp : Torch_Op<"aten.append.t", [
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
|
@ -9085,6 +9298,7 @@ def Torch_AtenIntFloatOp : Torch_Op<"aten.Int.float", [
|
|||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenIntScalarOp : Torch_Op<"aten.Int.Scalar", [
|
||||
|
@ -9111,6 +9325,30 @@ def Torch_AtenIntScalarOp : Torch_Op<"aten.Int.Scalar", [
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenIntBoolOp : Torch_Op<"aten.Int.bool", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::Int.bool : (bool) -> (int)`";
|
||||
let arguments = (ins
|
||||
Torch_BoolType:$a
|
||||
);
|
||||
let results = (outs
|
||||
Torch_IntType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenIntBoolOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||
}
|
||||
void AtenIntBoolOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_Aten__RangeLengthOp : Torch_Op<"aten.__range_length", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
@ -9580,6 +9818,7 @@ def Torch_AtenSubFloatOp : Torch_Op<"aten.sub.float", [
|
|||
printDefaultTorchOp(printer, *this, 2, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenMulFloatOp : Torch_Op<"aten.mul.float", [
|
||||
|
@ -9850,6 +10089,31 @@ def Torch_AtenGtFloatIntOp : Torch_Op<"aten.gt.float_int", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenPowIntFloatOp : Torch_Op<"aten.pow.int_float", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::pow.int_float : (int, float) -> (float)`";
|
||||
let arguments = (ins
|
||||
Torch_IntType:$a,
|
||||
Torch_FloatType:$b
|
||||
);
|
||||
let results = (outs
|
||||
Torch_FloatType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenPowIntFloatOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||
}
|
||||
void AtenPowIntFloatOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 2, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_Aten__And__BoolOp : Torch_Op<"aten.__and__.bool", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
@ -10307,6 +10571,7 @@ def Torch_AtenScalarImplicitOp : Torch_Op<"aten.ScalarImplicit", [
|
|||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [
|
||||
|
@ -10359,6 +10624,32 @@ def Torch_AtenTanhBackwardOp : Torch_Op<"aten.tanh_backward", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenHardtanhBackwardOp : Torch_Op<"aten.hardtanh_backward", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::hardtanh_backward : (Tensor, Tensor, Scalar, Scalar) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$grad_output,
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchScalarType:$min_val,
|
||||
AnyTorchScalarType:$max_val
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenHardtanhBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||
}
|
||||
void AtenHardtanhBackwardOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 4, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenGeluBackwardOp : Torch_Op<"aten.gelu_backward", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
@ -10733,6 +11024,7 @@ def Torch_PrimMinIntOp : Torch_Op<"prim.min.int", [
|
|||
printDefaultTorchOp(printer, *this, 2, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_PrimMaxSelfIntOp : Torch_Op<"prim.max.self_int", [
|
||||
|
@ -10933,11 +11225,11 @@ def Torch_PrimsVarOp : Torch_Op<"prims.var", [
|
|||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `prims::var : (Tensor, int[]?, int, int?) -> (Tensor)`";
|
||||
let summary = "Generated op for `prims::var : (Tensor, int[]?, float, int?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$inp,
|
||||
AnyTorchOptionalListOfTorchIntType:$dims,
|
||||
Torch_IntType:$correction,
|
||||
Torch_FloatType:$correction,
|
||||
AnyTorchOptionalIntType:$output_dtype
|
||||
);
|
||||
let results = (outs
|
||||
|
|
|
@ -376,9 +376,6 @@ def Torch_PrimListUnpackOp: Torch_Op<"prim.ListUnpack", [
|
|||
|
||||
def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [
|
||||
Pure,
|
||||
TypesMatchWith<"contained types correspond to operand types",
|
||||
"elements", "result", "Torch::TupleType::get($_ctxt, llvm::to_vector<6>($_self))",
|
||||
"isValidSubtype">,
|
||||
AllowedInModuleInitializer,
|
||||
]> {
|
||||
let summary = "TorchScript prim::TupleConstruct op";
|
||||
|
@ -397,6 +394,8 @@ def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [
|
|||
let assemblyFormat = [{
|
||||
$elements attr-dict `:` qualified(type($elements)) `->` qualified(type($result))
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def Torch_PrimListConstructOp: Torch_Op<"prim.ListConstruct", [
|
||||
|
|
|
@ -98,6 +98,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();
|
|||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
createDecomposeComplexOpsPass(ArrayRef<std::string> legalOps);
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createRecomposeComplexOps();
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createPreprocessShapeLibraryPass();
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createReifyShapeCalculationsPass();
|
||||
|
@ -121,8 +123,7 @@ createLowerToBackendContractPass(int maxIterations, bool decompose,
|
|||
ArrayRef<std::string> backendLegalOps);
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createVerifyBackendContractPass(bool decompose,
|
||||
ArrayRef<std::string> backendLegalOps);
|
||||
createVerifyBackendContractNoDecompositionsPass();
|
||||
|
||||
StringRef getAbstractInterpLibrary();
|
||||
|
||||
|
|
|
@ -343,24 +343,17 @@ def LowerToBackendContract
|
|||
let dependentDialects = ["func::FuncDialect"];
|
||||
}
|
||||
|
||||
def VerifyBackendContract
|
||||
: Pass<"torch-verify-backend-contract", "ModuleOp"> {
|
||||
def VerifyBackendContractNoDecompositions
|
||||
: Pass<"torch-verify-backend-contract-no-decompositions", "ModuleOp"> {
|
||||
let summary = "Check that program satisfies backend contract.";
|
||||
let constructor = [{
|
||||
mlir::torch::Torch::createVerifyBackendContractPass(
|
||||
/*decompose=*/true, /*backendLegalOps=*/{})
|
||||
mlir::torch::Torch::createVerifyBackendContractNoDecompositionsPass()
|
||||
}];
|
||||
let description = [{
|
||||
This pass performs a set of inspections to check that program satisfies backend
|
||||
contract. In case of check failure it prints out the error message and returns
|
||||
`signalPassFailure()` status.
|
||||
contract assuming that no decompositions were applied. In case of check failure
|
||||
it prints out the error message and returns `signalPassFailure()` status.
|
||||
}];
|
||||
let options = [
|
||||
Option<"decompose", "decompose", "bool", /*default=*/"true",
|
||||
"Decompose ops.">,
|
||||
ListOption<"backendLegalOps", "backend-legal-ops", "std::string",
|
||||
"List of ops to be considered legal for the backend.">
|
||||
];
|
||||
}
|
||||
|
||||
#endif // TORCHMLIR_TORCH_PASSES
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
#define TORCHMLIR_DIALECT_TORCH_UPSTREAM_H
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
// For layering reasons, the parts of the core MLIR compiler code written in C++
|
||||
// never take a C++ dependency on Torch itself (any code depending on Torch C++
|
||||
|
@ -160,6 +161,15 @@ enum Layout { Strided, Sparse, SparseCsr, Mkldnn, NumOptions };
|
|||
//===-----------------------------------------------------------------------===//
|
||||
enum EmbeddingBagMode { MODE_SUM, MODE_MEAN, MODE_MAX };
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Possible value for `reduce` argument for Scatter reduce ops.
|
||||
// Source:
|
||||
// https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
|
||||
//===-----------------------------------------------------------------------===//
|
||||
enum ReductionType { MAX, MEAN, MIN, SUM, PROD };
|
||||
|
||||
ReductionType get_reduction_enum(const llvm::StringRef &reduce);
|
||||
|
||||
} // namespace torch_upstream
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
|
|
@ -26,7 +26,7 @@ bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems);
|
|||
std::optional<int64_t> matchLegalConstantIndexIntoListOfSize(Value v,
|
||||
int64_t length);
|
||||
torch_upstream::ScalarType getScalarTypeForType(Type type);
|
||||
Type getTypeForScalarType(
|
||||
FailureOr<Type> getTypeForScalarType(
|
||||
MLIRContext *context, torch_upstream::ScalarType dtypeInt,
|
||||
mlir::IntegerType::SignednessSemantics signedness = IntegerType::Signed);
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
if(TORCH_MLIR_ENABLE_MHLO)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_MHLO)
|
||||
if(TORCH_MLIR_ENABLE_STABLEHLO)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO)
|
||||
else()
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls)
|
||||
endif()
|
||||
|
|
|
@ -30,10 +30,10 @@ void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm);
|
|||
/// TOSA backend contract.
|
||||
void createTorchBackendToTosaBackendPipeline(OpPassManager &pm);
|
||||
|
||||
// Do not register the torch-to-mhlo pipeline if mhlo target is disabled
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO
|
||||
struct MhloBackendPipelineOptions
|
||||
: public PassPipelineOptions<MhloBackendPipelineOptions> {
|
||||
// Do not register the stablehlo options if the stablehlo target is disabled
|
||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||
struct StablehloBackendPipelineOptions
|
||||
: public PassPipelineOptions<StablehloBackendPipelineOptions> {
|
||||
Option<bool> enableStaticShape{
|
||||
*this, "enable-static-shape",
|
||||
llvm::cl::desc("Enable static shape conversion."), llvm::cl::init(false)};
|
||||
|
@ -46,9 +46,10 @@ struct MhloBackendPipelineOptions
|
|||
llvm::cl::init(false)};
|
||||
};
|
||||
|
||||
void createTorchBackendToMhloBackendPipeline(
|
||||
OpPassManager &pm, const MhloBackendPipelineOptions &options);
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createVerifyMhloBackendContractPass();
|
||||
void createTorchBackendToStablehloBackendPipeline(
|
||||
OpPassManager &pm, const StablehloBackendPipelineOptions &options);
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createVerifyStablehloBackendContractPass();
|
||||
#endif
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass();
|
||||
|
|
|
@ -42,10 +42,10 @@ def VerifyTosaBackendContract : Pass<"torch-verify-tosa-backend-contract", "Modu
|
|||
let constructor = "mlir::torch::TorchConversion::createVerifyTosaBackendContractPass()";
|
||||
}
|
||||
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO
|
||||
def VerifyMhloBackendContract : Pass<"torch-verify-mhlo-backend-contract", "ModuleOp"> {
|
||||
let summary = "Verifies conformity to the mhlo backend contract";
|
||||
let constructor = "mlir::torch::TorchConversion::createVerifyMhloBackendContractPass()";
|
||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||
def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contract", "ModuleOp"> {
|
||||
let summary = "Verifies conformity to the stablehlo backend contract";
|
||||
let constructor = "mlir::torch::TorchConversion::createVerifyStablehloBackendContractPass()";
|
||||
}
|
||||
#endif // TORCH_MLIR_ENABLE_MHLO
|
||||
#endif // TORCH_MLIR_ENABLE_STABLEHLO
|
||||
#endif // TORCHMLIR_TORCHCONVERSION_PASSES
|
||||
|
|
|
@ -61,7 +61,7 @@ MlirType torchMlirTorchTupleTypeGet(MlirContext context,
|
|||
return wrap(Torch::TupleType::get(
|
||||
unwrap(context),
|
||||
llvm::to_vector<6>(
|
||||
llvm::map_range(llvm::makeArrayRef(containedTypes, numContainedTypes),
|
||||
llvm::map_range(llvm::ArrayRef(containedTypes, numContainedTypes),
|
||||
[](MlirType t) { return unwrap(t); }))));
|
||||
}
|
||||
|
||||
|
@ -89,7 +89,7 @@ MlirType torchMlirTorchUnionTypeGet(MlirContext context,
|
|||
return wrap(Torch::UnionType::get(
|
||||
unwrap(context),
|
||||
llvm::to_vector<6>(
|
||||
llvm::map_range(llvm::makeArrayRef(containedTypes, numContainedTypes),
|
||||
llvm::map_range(llvm::ArrayRef(containedTypes, numContainedTypes),
|
||||
[](MlirType t) { return unwrap(t); }))));
|
||||
}
|
||||
|
||||
|
@ -230,7 +230,7 @@ MlirType torchMlirTorchNonValueTensorTypeGet(MlirContext context,
|
|||
std::optional<ArrayRef<int64_t>> optionalSizesArrayRef = std::nullopt;
|
||||
// if numSizes == -1, then it is unranked.
|
||||
if (numSizes > -1)
|
||||
optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
|
||||
optionalSizesArrayRef = llvm::ArrayRef(optionalSizes, numSizes);
|
||||
return wrap(Torch::NonValueTensorType::get(
|
||||
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
|
||||
}
|
||||
|
@ -293,7 +293,7 @@ MlirType torchMlirTorchValueTensorTypeGet(MlirContext context,
|
|||
std::optional<ArrayRef<int64_t>> optionalSizesArrayRef = std::nullopt;
|
||||
// if numSizes == -1, then it is unranked.
|
||||
if (numSizes > -1)
|
||||
optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
|
||||
optionalSizesArrayRef = llvm::ArrayRef(optionalSizes, numSizes);
|
||||
return wrap(Torch::ValueTensorType::get(
|
||||
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
|
||||
}
|
||||
|
|
|
@ -3,13 +3,7 @@ add_subdirectory(Conversion)
|
|||
add_subdirectory(Dialect)
|
||||
add_subdirectory(RefBackend)
|
||||
|
||||
add_mlir_library(TorchMLIRInitAll
|
||||
InitAll.cpp
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
set(LinkedLibs
|
||||
MLIRFuncDialect
|
||||
MLIRIR
|
||||
MLIRSupport
|
||||
|
@ -27,4 +21,22 @@ add_mlir_library(TorchMLIRInitAll
|
|||
TorchMLIRRefBackend
|
||||
)
|
||||
|
||||
if(TORCH_MLIR_ENABLE_STABLEHLO)
|
||||
list(APPEND LinkedLibs
|
||||
MhloPasses
|
||||
MhloToLinalg
|
||||
StablehloToMhlo
|
||||
)
|
||||
endif()
|
||||
|
||||
add_mlir_library(TorchMLIRInitAll
|
||||
InitAll.cpp
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
${LinkedLibs}
|
||||
)
|
||||
|
||||
torch_mlir_target_includes(TorchMLIRInitAll)
|
||||
|
|
|
@ -2,8 +2,8 @@ add_subdirectory(TorchToLinalg)
|
|||
add_subdirectory(TorchToSCF)
|
||||
add_subdirectory(TorchToArith)
|
||||
add_subdirectory(TorchToTosa)
|
||||
if(TORCH_MLIR_ENABLE_MHLO)
|
||||
add_subdirectory(TorchToMhlo)
|
||||
if(TORCH_MLIR_ENABLE_STABLEHLO)
|
||||
add_subdirectory(TorchToStablehlo)
|
||||
endif()
|
||||
add_subdirectory(TorchToTMTensor)
|
||||
add_subdirectory(TorchConversionToMLProgram)
|
||||
|
@ -17,10 +17,8 @@ set(linked_libs TorchMLIRTorchToLinalg
|
|||
TorchMLIRTorchToTMTensor
|
||||
TorchMLIRTorchConversionToMLProgram
|
||||
TorchMLIRConversionUtils)
|
||||
if(TORCH_MLIR_ENABLE_MHLO)
|
||||
list(APPEND linked_libs
|
||||
MhloPasses
|
||||
TorchMLIRTorchToMhlo)
|
||||
if(TORCH_MLIR_ENABLE_STABLEHLO)
|
||||
list(APPEND linked_libs TorchMLIRTorchToStablehlo)
|
||||
endif()
|
||||
|
||||
add_mlir_library(TorchMLIRConversionPasses
|
||||
|
|
|
@ -9,15 +9,15 @@
|
|||
|
||||
#include "torch-mlir/Conversion/Passes.h"
|
||||
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO
|
||||
#include "mhlo/transforms/passes.h"
|
||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
||||
#include "transforms/passes.h"
|
||||
#endif // TORCH_MLIR_ENABLE_MHLO
|
||||
#endif // TORCH_MLIR_ENABLE_STABLEHLO
|
||||
|
||||
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
||||
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
|
||||
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
|
||||
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
|
||||
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
||||
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
|
||||
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
|
||||
|
||||
|
@ -32,12 +32,4 @@ namespace {
|
|||
|
||||
void mlir::torch::registerConversionPasses() {
|
||||
::registerPasses();
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO
|
||||
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
|
||||
return mlir::mhlo::createLegalizeHloToLinalgPass();
|
||||
});
|
||||
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
|
||||
return mlir::mhlo::createSymbolicShapeOptimizationPass();
|
||||
});
|
||||
#endif // TORCH_MLIR_ENABLE_MHLO
|
||||
}
|
||||
|
|
|
@ -68,7 +68,7 @@ public:
|
|||
// temp = multiplier * currentSeed + incrementStep
|
||||
Value mul = rewriter.create<arith::MulIOp>(loc, currentSeed, multiplier);
|
||||
Value seed = rewriter.create<arith::AddIOp>(loc, mul, incrementStep);
|
||||
globalVar = rewriter.create<tensor::InsertOp>(loc, seed, globalVar);
|
||||
globalVar = rewriter.create<tensor::InsertOp>(loc, seed, globalVar, ValueRange());
|
||||
rewriter.create<ml_program::GlobalStoreOp>(
|
||||
loc, SymbolRefAttr::get(op->getContext(), getSeedGobalVarName()),
|
||||
globalVar);
|
||||
|
|
|
@ -232,6 +232,67 @@ public:
|
|||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class ConvertTorchConstantIntOp
|
||||
: public OpConversionPattern<Torch::ConstantIntOp> {
|
||||
public:
|
||||
using OpConversionPattern<Torch::ConstantIntOp>::OpConversionPattern;
|
||||
using OpAdaptor = Torch::ConstantIntOp::Adaptor;
|
||||
LogicalResult
|
||||
matchAndRewrite(Torch::ConstantIntOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// note: arith.constant only accept singless integer, so convert singed to
|
||||
// singless
|
||||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
|
||||
op, rewriter.getIntegerAttr(rewriter.getI64Type(),
|
||||
op.getValueAttr().getValue()));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenFloatScalarOp : public OpConversionPattern<AtenFloatScalarOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenFloatScalarOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type resultType =
|
||||
this->getTypeConverter()->convertType(op->getResult(0).getType());
|
||||
Value result =
|
||||
convertScalarToDtype(rewriter, op.getLoc(), adaptor.getA(), resultType);
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenAddOp : public OpConversionPattern<AtenAddOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenAddOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Type resultType =
|
||||
this->getTypeConverter()->convertType(op->getResult(0).getType());
|
||||
Value operandA =
|
||||
convertScalarToDtype(rewriter, loc, adaptor.getA(), resultType);
|
||||
Value operandB =
|
||||
convertScalarToDtype(rewriter, loc, adaptor.getB(), resultType);
|
||||
if (resultType.isa<mlir::FloatType>()) {
|
||||
rewriter.replaceOpWithNewOp<arith::AddFOp>(op, operandA, operandB);
|
||||
} else if (resultType.isa<mlir::IntegerType>()) {
|
||||
rewriter.replaceOpWithNewOp<arith::AddIOp>(op, operandA, operandB);
|
||||
} else {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: only support integer or float result type");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
|
@ -381,8 +442,14 @@ public:
|
|||
patterns.add<ConvertTorchConstantOp<Torch::ConstantFloatOp>>(typeConverter,
|
||||
context);
|
||||
target.addIllegalOp<Torch::ConstantIntOp>();
|
||||
patterns.add<ConvertTorchConstantOp<Torch::ConstantIntOp>>(typeConverter,
|
||||
context);
|
||||
patterns.add<ConvertTorchConstantIntOp>(typeConverter, context);
|
||||
|
||||
target.addIllegalOp<AtenFloatScalarOp>();
|
||||
patterns.add<ConvertAtenFloatScalarOp>(typeConverter, context);
|
||||
|
||||
target.addIllegalOp<AtenAddOp>();
|
||||
patterns.add<ConvertAtenAddOp>(typeConverter, context);
|
||||
|
||||
target.addIllegalOp<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp>();
|
||||
patterns.add<ConvertAtenBinaryOp<AtenAddIntOp, arith::AddIOp>>(
|
||||
typeConverter, context);
|
||||
|
|
|
@ -463,8 +463,8 @@ public:
|
|||
}
|
||||
|
||||
SmallVector<Value> inputSize = getTensorSizes(rewriter, loc, input);
|
||||
ArrayRef<Value> outputShapeInt = llvm::makeArrayRef(outputSizeInt);
|
||||
ArrayRef<Value> inputShapeInt = llvm::makeArrayRef(inputSize);
|
||||
ArrayRef<Value> outputShapeInt = llvm::ArrayRef(outputSizeInt);
|
||||
ArrayRef<Value> inputShapeInt = llvm::ArrayRef(inputSize);
|
||||
|
||||
// Association indices for expand/collapse ops. These two vectors
|
||||
// are populated such that two entries at the same index corresponds
|
||||
|
@ -1117,6 +1117,18 @@ public:
|
|||
|
||||
RankedTensorType newResultType =
|
||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
||||
|
||||
auto outElemType = newResultType.getElementType();
|
||||
auto dtypePromoteBody = [&](OpBuilder &builder, Location loc,
|
||||
ValueRange payloadArgs) {
|
||||
Value elem = convertScalarToDtype(builder, loc, payloadArgs[0], outElemType);
|
||||
builder.create<linalg::YieldOp>(loc, elem);
|
||||
};
|
||||
for (size_t i = 0; i < tensors.size(); ++i) {
|
||||
tensors[i] = torch_to_linalg::createElementwiseLinalgGeneric(
|
||||
rewriter, loc, {tensors[i]}, outElemType, dtypePromoteBody);
|
||||
}
|
||||
|
||||
int rank = newResultType.getRank();
|
||||
SmallVector<Value> offsets, sizes, strides;
|
||||
sizes.reserve(rank);
|
||||
|
@ -1136,7 +1148,7 @@ public:
|
|||
|
||||
Value dimIndex = rewriter.createOrFold<arith::ConstantOp>(
|
||||
loc, rewriter.getIndexAttr(dim));
|
||||
for (auto tensor : makeArrayRef(tensors).drop_front()) {
|
||||
for (auto tensor : ArrayRef(tensors).drop_front()) {
|
||||
auto size = rewriter.createOrFold<tensor::DimOp>(loc, tensor, dimIndex);
|
||||
resultDimSize =
|
||||
rewriter.createOrFold<arith::AddIOp>(loc, resultDimSize, size);
|
||||
|
@ -1270,7 +1282,7 @@ public:
|
|||
/*resultType=*/selfType,
|
||||
/*inputs=*/broadcastedSrc,
|
||||
/*outputs=*/self,
|
||||
/*indexingMaps=*/llvm::makeArrayRef({id, id}),
|
||||
/*indexingMaps=*/llvm::ArrayRef({id, id}),
|
||||
/*iteratorTypes=*/iteratorTypes,
|
||||
[](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value result = args[0];
|
||||
|
|
|
@ -81,9 +81,21 @@ public:
|
|||
|
||||
Type inElementType = inputType.getElementType();
|
||||
if (!inElementType.isa<mlir::FloatType>()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
maxDimOp,
|
||||
"aten.max_dim to linalg.* requires Float input element type");
|
||||
if (inElementType.isa<mlir::IntegerType>()) {
|
||||
auto integerTy = maxDimOp.getSelf()
|
||||
.getType()
|
||||
.cast<BaseTensorType>()
|
||||
.getDtype()
|
||||
.dyn_cast<mlir::IntegerType>();
|
||||
if (integerTy.isUnsigned())
|
||||
return rewriter.notifyMatchFailure(
|
||||
maxDimOp, "aten.max_dim to linalg.* requires input element type "
|
||||
"to be signed in case of integer");
|
||||
} else {
|
||||
return rewriter.notifyMatchFailure(
|
||||
maxDimOp, "aten.max_dim to linalg.* requires Float or Integer "
|
||||
"input element type");
|
||||
}
|
||||
}
|
||||
|
||||
// Constant op to account for the reduction along dim.
|
||||
|
@ -104,13 +116,23 @@ public:
|
|||
Value initTensorMax = rewriter.create<tensor::EmptyOp>(
|
||||
loc, getAsOpFoldResult(resultShape), inElementType);
|
||||
|
||||
FloatAttr fillValueMaxAttr = rewriter.getFloatAttr(
|
||||
inElementType,
|
||||
APFloat::getLargest(
|
||||
inElementType.cast<mlir::FloatType>().getFloatSemantics(), true));
|
||||
Value fillValueMax;
|
||||
if (inElementType.isa<mlir::FloatType>()) {
|
||||
fillValueMax = rewriter.create<arith::ConstantOp>(
|
||||
loc,
|
||||
rewriter.getFloatAttr(
|
||||
inElementType,
|
||||
APFloat::getLargest(
|
||||
inElementType.cast<mlir::FloatType>().getFloatSemantics(),
|
||||
true)));
|
||||
} else {
|
||||
fillValueMax = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(
|
||||
inElementType,
|
||||
APSInt::getSignedMinValue(
|
||||
inElementType.cast<mlir::IntegerType>().getWidth())));
|
||||
}
|
||||
|
||||
Value fillValueMax =
|
||||
rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
|
||||
Value filledTensorMax =
|
||||
rewriter.create<linalg::FillOp>(loc, fillValueMax, initTensorMax)
|
||||
.result();
|
||||
|
@ -152,10 +174,18 @@ public:
|
|||
nestedLoc, oldIndex.getType(),
|
||||
rewriter.create<linalg::IndexOp>(loc, dim));
|
||||
|
||||
auto resultMax = rewriter.create<arith::MaxFOp>(
|
||||
nestedLoc, newValue, oldValue);
|
||||
Value predicate = rewriter.create<arith::CmpFOp>(
|
||||
nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
|
||||
Value resultMax, predicate;
|
||||
if (inElementType.isa<mlir::FloatType>()) {
|
||||
resultMax =
|
||||
rewriter.create<arith::MaxFOp>(nestedLoc, newValue, oldValue);
|
||||
predicate = rewriter.create<arith::CmpFOp>(
|
||||
nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
|
||||
} else {
|
||||
resultMax =
|
||||
rewriter.create<arith::MaxSIOp>(nestedLoc, newValue, oldValue);
|
||||
predicate = rewriter.create<arith::CmpIOp>(
|
||||
nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
|
||||
}
|
||||
auto resultIndex = rewriter.create<arith::SelectOp>(
|
||||
nestedLoc, predicate, newIndex, oldIndex);
|
||||
nestedBuilder.create<linalg::YieldOp>(
|
||||
|
|
|
@ -127,9 +127,14 @@ public:
|
|||
if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: dtype must be a constant integer or none");
|
||||
resultElementType = getTypeForScalarType(
|
||||
FailureOr<Type> maybeResultElementType = getTypeForScalarType(
|
||||
op->getContext(), (torch_upstream::ScalarType)dtypeInt,
|
||||
IntegerType::Signless);
|
||||
if (failed(maybeResultElementType)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unable to convert `dtypeInt` to builtin type");
|
||||
}
|
||||
resultElementType = *maybeResultElementType;
|
||||
}
|
||||
|
||||
// Create an uninitialized tensor of `resultSize` shape and fill it with
|
||||
|
@ -227,9 +232,14 @@ public:
|
|||
if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: dtype must be a constant integer or none");
|
||||
resultElementType = getTypeForScalarType(
|
||||
FailureOr<Type> maybeResultElementType = getTypeForScalarType(
|
||||
op->getContext(), (torch_upstream::ScalarType)dtypeInt,
|
||||
IntegerType::Signless);
|
||||
if (failed(maybeResultElementType)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unable to convert `dtypeInt` to builtin type");
|
||||
}
|
||||
resultElementType = *maybeResultElementType;
|
||||
}
|
||||
|
||||
// Create an uninitialized tensor of `resultSize` shape.
|
||||
|
|
|
@ -59,6 +59,15 @@ static Value createGreaterThan(OpBuilder &b, Location loc, Type elementalType,
|
|||
b, loc, elementalType, lhs, rhs);
|
||||
}
|
||||
|
||||
static Value createGreaterThanOrEqual(OpBuilder &b, Location loc,
|
||||
Type elementalType, Value lhs,
|
||||
Value rhs) {
|
||||
return createComparisonTemplate<arith::CmpFPredicate::UGE,
|
||||
arith::CmpIPredicate::uge,
|
||||
arith::CmpIPredicate::sge>(
|
||||
b, loc, elementalType, lhs, rhs);
|
||||
}
|
||||
|
||||
static Value createLessThan(OpBuilder &b, Location loc, Type elementalType,
|
||||
Value lhs, Value rhs) {
|
||||
return createComparisonTemplate<arith::CmpFPredicate::ULT,
|
||||
|
@ -67,6 +76,14 @@ static Value createLessThan(OpBuilder &b, Location loc, Type elementalType,
|
|||
b, loc, elementalType, lhs, rhs);
|
||||
}
|
||||
|
||||
static Value createLessThanOrEqual(OpBuilder &b, Location loc,
|
||||
Type elementalType, Value lhs, Value rhs) {
|
||||
return createComparisonTemplate<arith::CmpFPredicate::ULE,
|
||||
arith::CmpIPredicate::ule,
|
||||
arith::CmpIPredicate::sle>(
|
||||
b, loc, elementalType, lhs, rhs);
|
||||
}
|
||||
|
||||
static Value createEqual(OpBuilder &b, Location loc, Type elementalType,
|
||||
Value lhs, Value rhs) {
|
||||
return createComparisonTemplate<arith::CmpFPredicate::UEQ,
|
||||
|
@ -117,6 +134,46 @@ static Value createCalculationForMathOpWithDtypeConversion(
|
|||
return b.create<MathOpTy>(loc, arg);
|
||||
}
|
||||
|
||||
template <typename OpTy>
|
||||
static Value createCompareTensorOp(OpBuilder &b, Location loc, OpTy op,
|
||||
Value lhs, Value rhs) {
|
||||
static_assert(std::is_same<OpTy, AtenLtTensorOp>() ||
|
||||
std::is_same<OpTy, AtenLeTensorOp>() ||
|
||||
std::is_same<OpTy, AtenGtTensorOp>() ||
|
||||
std::is_same<OpTy, AtenGeTensorOp>() ||
|
||||
std::is_same<OpTy, AtenEqTensorOp>(),
|
||||
"unimplemented: op type not supported");
|
||||
|
||||
Type lhsDtype = lhs.getType();
|
||||
Type rhsDtype = rhs.getType();
|
||||
|
||||
// TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
|
||||
// to be handled.
|
||||
if (lhsDtype != rhsDtype) {
|
||||
op.emitError("unimplemented: lhs and rhs dtype must be same");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Type elementalType =
|
||||
op.getSelf().getType().template cast<BaseTensorType>().getDtype();
|
||||
if constexpr (std::is_same<OpTy, AtenLtTensorOp>()) {
|
||||
return createLessThan(b, loc, elementalType, lhs, rhs);
|
||||
}
|
||||
if constexpr (std::is_same<OpTy, AtenLeTensorOp>()) {
|
||||
return createLessThanOrEqual(b, loc, elementalType, lhs, rhs);
|
||||
}
|
||||
if constexpr (std::is_same<OpTy, AtenGtTensorOp>()) {
|
||||
return createGreaterThan(b, loc, elementalType, lhs, rhs);
|
||||
}
|
||||
if constexpr (std::is_same<OpTy, AtenGeTensorOp>()) {
|
||||
return createGreaterThanOrEqual(b, loc, elementalType, lhs, rhs);
|
||||
}
|
||||
if constexpr (std::is_same<OpTy, AtenEqTensorOp>()) {
|
||||
return createEqual(b, loc, elementalType, lhs, rhs);
|
||||
}
|
||||
llvm_unreachable("unimplemented: op type not supported");
|
||||
}
|
||||
|
||||
static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||
OpBuilder &b, Location loc, TypeConverter *converter,
|
||||
ValueRange payloadArgs, Operation *op, ArrayRef<Value> operands) {
|
||||
|
@ -177,8 +234,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
if (!clone.getMemoryFormat().getType().isa<Torch::NoneType>() &&
|
||||
(!matchPattern(clone.getMemoryFormat(),
|
||||
m_TorchConstantInt(&memoryFormat)) ||
|
||||
memoryFormat != torch_upstream::MemoryFormat::Contiguous)) {
|
||||
clone.emitError("unimplemented: only default memory format is supported");
|
||||
(memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
|
||||
memoryFormat != torch_upstream::MemoryFormat::ChannelsLast))) {
|
||||
clone.emitError("unimplemented: only contiguous and channels last memory "
|
||||
"format is supported");
|
||||
return nullptr;
|
||||
}
|
||||
return payloadArgs[0];
|
||||
|
@ -293,7 +352,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
round.emitError("unimplemented: non-floating point dtype");
|
||||
return nullptr;
|
||||
}
|
||||
return b.create<math::RoundOp>(loc, payloadArgs[0]);
|
||||
return b.create<math::RoundEvenOp>(loc, payloadArgs[0]);
|
||||
}
|
||||
if (auto prelu = dyn_cast<AtenPreluOp>(op)) {
|
||||
if (!prelu.getType()
|
||||
|
@ -370,6 +429,29 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
Value cdfExt = b.create<arith::AddFOp>(loc, dinputInputAlpha, cdf);
|
||||
return b.create<arith::MulFOp>(loc, payloadArgs[0], cdfExt);
|
||||
}
|
||||
if (auto hardtanhBackward = dyn_cast<AtenHardtanhBackwardOp>(op)) {
|
||||
AtenHardtanhBackwardOp::Adaptor adaptor(operands);
|
||||
if (!hardtanhBackward.getType()
|
||||
.cast<ValueTensorType>()
|
||||
.getDtype()
|
||||
.isa<mlir::FloatType>()) {
|
||||
hardtanhBackward.emitError("unimplemented: non-floating point dtype");
|
||||
return nullptr;
|
||||
}
|
||||
Value gradOutput = payloadArgs[0];
|
||||
Type elementType = gradOutput.getType();
|
||||
Value self = convertScalarToDtype(b, loc, payloadArgs[1], elementType);
|
||||
Value constantZero =
|
||||
b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.0));
|
||||
Value min = convertScalarToDtype(b, loc, adaptor.getMinVal(), elementType);
|
||||
Value max = convertScalarToDtype(b, loc, adaptor.getMaxVal(), elementType);
|
||||
Value lesser =
|
||||
b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT, self, min);
|
||||
Value greater =
|
||||
b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, self, max);
|
||||
Value cmp = b.create<arith::OrIOp>(loc, lesser, greater);
|
||||
return b.create<arith::SelectOp>(loc, cmp, constantZero, gradOutput);
|
||||
}
|
||||
if (auto add = dyn_cast<AtenAddTensorOp>(op)) {
|
||||
AtenAddTensorOp::Adaptor adaptor(operands);
|
||||
Type dtype = converter->convertType(add.getType())
|
||||
|
@ -463,64 +545,25 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||
return b.create<math::Atan2Op>(loc, lhs, rhs);
|
||||
}
|
||||
if (auto ltTensor = dyn_cast<AtenLtTensorOp>(op)) {
|
||||
return createCompareTensorOp(b, loc, ltTensor, payloadArgs[0],
|
||||
payloadArgs[1]);
|
||||
}
|
||||
if (auto leTensor = dyn_cast<AtenLeTensorOp>(op)) {
|
||||
return createCompareTensorOp(b, loc, leTensor, payloadArgs[0],
|
||||
payloadArgs[1]);
|
||||
}
|
||||
if (auto gtTensor = dyn_cast<AtenGtTensorOp>(op)) {
|
||||
AtenGtTensorOp::Adaptor adaptor(operands);
|
||||
Type lhsDtype = payloadArgs[0].getType();
|
||||
Type rhsDtype = payloadArgs[1].getType();
|
||||
|
||||
// TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
|
||||
// to be handled.
|
||||
if (lhsDtype != rhsDtype) {
|
||||
gtTensor.emitError("unimplemented: different lhs and rhs dtype");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Type elementalType =
|
||||
gtTensor.getSelf().getType().cast<BaseTensorType>().getDtype();
|
||||
return createGreaterThan(b, loc, elementalType, payloadArgs[0],
|
||||
payloadArgs[1]);
|
||||
return createCompareTensorOp(b, loc, gtTensor, payloadArgs[0],
|
||||
payloadArgs[1]);
|
||||
}
|
||||
if (auto geTensor = dyn_cast<AtenGeTensorOp>(op)) {
|
||||
return createCompareTensorOp(b, loc, geTensor, payloadArgs[0],
|
||||
payloadArgs[1]);
|
||||
}
|
||||
if (auto eqTensor = dyn_cast<AtenEqTensorOp>(op)) {
|
||||
AtenEqTensorOp::Adaptor adaptor(operands);
|
||||
Type lhsDtype = payloadArgs[0].getType();
|
||||
Type rhsDtype = payloadArgs[1].getType();
|
||||
|
||||
// TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
|
||||
// to be handled.
|
||||
if (lhsDtype != rhsDtype) {
|
||||
eqTensor.emitError("unimplemented: lhs and rhs dtype must be same");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Type elementalType =
|
||||
eqTensor.getSelf().getType().cast<BaseTensorType>().getDtype();
|
||||
|
||||
if (elementalType.isa<mlir::FloatType>())
|
||||
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UEQ,
|
||||
payloadArgs[0], payloadArgs[1]);
|
||||
if (elementalType.isa<mlir::IntegerType>()) {
|
||||
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
|
||||
payloadArgs[0], payloadArgs[1]);
|
||||
}
|
||||
eqTensor.emitError("unimplemented: dtype isn't supported.");
|
||||
return nullptr;
|
||||
}
|
||||
if (auto ltTensor = dyn_cast<AtenLtTensorOp>(op)) {
|
||||
AtenLtTensorOp::Adaptor adaptor(operands);
|
||||
Type lhsDtype = payloadArgs[0].getType();
|
||||
Type rhsDtype = payloadArgs[1].getType();
|
||||
|
||||
// TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
|
||||
// to be handled.
|
||||
if (lhsDtype != rhsDtype) {
|
||||
ltTensor.emitError("unimplemented: lhs and rhs dtype must be same");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Type elementalType =
|
||||
ltTensor.getSelf().getType().cast<BaseTensorType>().getDtype();
|
||||
return createLessThan(b, loc, elementalType, payloadArgs[0],
|
||||
payloadArgs[1]);
|
||||
return createCompareTensorOp(b, loc, eqTensor, payloadArgs[0],
|
||||
payloadArgs[1]);
|
||||
}
|
||||
if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
|
||||
AtenDivTensorOp::Adaptor adaptor(operands);
|
||||
|
@ -964,18 +1007,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
.getElementType();
|
||||
return convertScalarToDtype(b, loc, adaptor.getValue(), dtype);
|
||||
}
|
||||
if (auto maskedFillScalar = dyn_cast<AtenMaskedFillScalarOp>(op)) {
|
||||
AtenMaskedFillScalarOp::Adaptor adaptor(operands);
|
||||
Type dtype = converter->convertType(maskedFillScalar.getType())
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
|
||||
Value input = payloadArgs[0];
|
||||
Value mask = payloadArgs[1];
|
||||
Value fillValue = convertScalarToDtype(b, loc, adaptor.getValue(), dtype);
|
||||
|
||||
return b.create<arith::SelectOp>(loc, mask, fillValue, input);
|
||||
}
|
||||
if (auto maskedFillTensor = dyn_cast<AtenMaskedFillTensorOp>(op)) {
|
||||
AtenMaskedFillScalarOp::Adaptor adaptor(operands);
|
||||
Type dtype = converter->convertType(maskedFillTensor.getType())
|
||||
|
@ -1034,7 +1065,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
Value allOnesVal = b.create<arith::ConstantOp>(
|
||||
loc, b.getIntegerAttr(
|
||||
elementType,
|
||||
APSInt::getAllOnesValue(elementType.getIntOrFloatBitWidth())));
|
||||
APSInt::getAllOnes(elementType.getIntOrFloatBitWidth())));
|
||||
return b.create<arith::XOrIOp>(loc, payloadArgs[0], allOnesVal);
|
||||
}
|
||||
|
||||
|
@ -1082,10 +1113,10 @@ public:
|
|||
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp,
|
||||
AtenBitwiseXorTensorOp, AtenGtScalarOp, AtenGeScalarOp,
|
||||
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
|
||||
AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp,
|
||||
AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
|
||||
AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
|
||||
AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
|
||||
AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
|
||||
AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp, AtenAddScalarOp,
|
||||
AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp,
|
||||
AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp,
|
||||
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
|
||||
AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenBitwiseNotOp,
|
||||
AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>(op))
|
||||
|
@ -1561,12 +1592,12 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
|||
AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
|
||||
AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, AtenGtScalarOp,
|
||||
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
|
||||
AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp,
|
||||
AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
|
||||
AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp, AtenMaskedFillTensorOp,
|
||||
AtenLogicalOrOp, AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp,
|
||||
AtenTriuOp, AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp,
|
||||
AtenFillScalarOp, AtenFillTensorOp>();
|
||||
AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
|
||||
AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
|
||||
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
|
||||
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
|
||||
AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenRemainderScalarOp,
|
||||
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>();
|
||||
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenNllLossForwardOp>();
|
||||
patterns.add<ConvertAtenDetachOp>(typeConverter, context);
|
||||
|
|
|
@ -1,35 +0,0 @@
|
|||
add_mlir_conversion_library(TorchMLIRTorchToMhlo
|
||||
TorchToMhlo.cpp
|
||||
MhloLegalizeUtils.cpp
|
||||
Basic.cpp
|
||||
Gather.cpp
|
||||
Linear.cpp
|
||||
ViewLike.cpp
|
||||
Reduction.cpp
|
||||
Pooling.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToMhlo
|
||||
|
||||
DEPENDS
|
||||
MhloDialect
|
||||
MhloToLinalg
|
||||
MLIRMhloPassIncGen
|
||||
LMHLOTransformsPassIncGen
|
||||
TorchMLIRConversionPassIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRPass
|
||||
MhloDialect
|
||||
MhloToLinalg
|
||||
MLIRBufferTransforms
|
||||
StablehloOps
|
||||
TorchMLIRTorchDialect
|
||||
TorchMLIRConversionUtils
|
||||
)
|
||||
|
||||
torch_mlir_target_includes(TorchMLIRTorchToMhlo)
|
|
@ -1,74 +0,0 @@
|
|||
//===------------------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H
|
||||
#define TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H
|
||||
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace torch {
|
||||
namespace torch_to_mhlo {
|
||||
|
||||
struct TorchToMhloOptions {
|
||||
bool enableStaticShape = false;
|
||||
size_t dimSizeIndexBits = 64;
|
||||
};
|
||||
|
||||
template <typename AtenOpT>
|
||||
class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||
ConvertAtenOp(TypeConverter &typeConverter, MLIRContext *context,
|
||||
const TorchToMhloOptions &options)
|
||||
: OpConversionPattern<AtenOpT>(typeConverter, context) {
|
||||
this->options = options;
|
||||
}
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
return rewriter.notifyMatchFailure(op, "haven't been implemented");
|
||||
}
|
||||
const TorchToMhloOptions &getOptions() const { return options; }
|
||||
|
||||
private:
|
||||
TorchToMhloOptions options;
|
||||
};
|
||||
|
||||
void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
ConversionTarget &target,
|
||||
const TorchToMhloOptions &options);
|
||||
void populateViewLikeOpPatternsAndLegality(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
ConversionTarget &target,
|
||||
const TorchToMhloOptions &options);
|
||||
void populateGatherOpPatternsAndLegality(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
ConversionTarget &target,
|
||||
const TorchToMhloOptions &options);
|
||||
void populateReductionOpPatternsAndLegality(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
ConversionTarget &target,
|
||||
const TorchToMhloOptions &options);
|
||||
void populateLinearOpPatternsAndLegality(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
ConversionTarget &target,
|
||||
const TorchToMhloOptions &options);
|
||||
|
||||
void populatePoolingOpPatternsAndLegality(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
ConversionTarget &target,
|
||||
const TorchToMhloOptions &options);
|
||||
|
||||
} // namespace torch_to_mhlo
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H
|
|
@ -7,15 +7,16 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "./MhloLegalizeUtils.h"
|
||||
#include "./PopulatePatterns.h"
|
||||
#include "mhlo/IR/hlo_ops.h"
|
||||
#include "PopulatePatterns.h"
|
||||
#include "StablehloLegalizeUtils.h"
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "stablehlo/dialect/ChloOps.h"
|
||||
#include "stablehlo/dialect/StablehloOps.h"
|
||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
|
@ -29,7 +30,7 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
using namespace mlir::torch::torch_to_mhlo;
|
||||
using namespace mlir::torch::torch_to_stablehlo;
|
||||
|
||||
LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
|
||||
mlir::Value &self, mlir::Value &other,
|
||||
|
@ -43,16 +44,16 @@ LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
|
|||
if (selfRank > otherRank) {
|
||||
auto unsqueezeDims =
|
||||
llvm::to_vector<4>(llvm::seq<int64_t>(0, selfRank - otherRank));
|
||||
auto unsqueezeInfo = mhlo::unsqueezeTensor(rewriter, op, other,
|
||||
unsqueezeDims, dimSizeIndexBits);
|
||||
auto unsqueezeInfo = hlo::unsqueezeTensor(rewriter, op, other,
|
||||
unsqueezeDims, dimSizeIndexBits);
|
||||
if (failed(unsqueezeInfo))
|
||||
return failure();
|
||||
other = *unsqueezeInfo;
|
||||
} else if (otherRank > selfRank) {
|
||||
auto unsqueezeDims =
|
||||
llvm::to_vector<4>(llvm::seq<int64_t>(0, otherRank - selfRank));
|
||||
auto unsqueezeInfo = mhlo::unsqueezeTensor(rewriter, op, self,
|
||||
unsqueezeDims, dimSizeIndexBits);
|
||||
auto unsqueezeInfo = hlo::unsqueezeTensor(rewriter, op, self, unsqueezeDims,
|
||||
dimSizeIndexBits);
|
||||
if (failed(unsqueezeInfo))
|
||||
return failure();
|
||||
self = *unsqueezeInfo;
|
||||
|
@ -78,7 +79,8 @@ static FailureOr<Value> getMaxValueOfDtype(Operation *op, Type elementType,
|
|||
constType,
|
||||
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
|
||||
/*negative=*/false));
|
||||
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
|
||||
return rewriter
|
||||
.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
|
||||
.getResult();
|
||||
}
|
||||
if (elementType.isa<mlir::IntegerType>()) {
|
||||
|
@ -91,7 +93,8 @@ static FailureOr<Value> getMaxValueOfDtype(Operation *op, Type elementType,
|
|||
constAttr = SplatElementsAttr::get(
|
||||
constType, APInt::getSignedMaxValue(integerType.getWidth()));
|
||||
}
|
||||
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
|
||||
return rewriter
|
||||
.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
|
||||
.getResult();
|
||||
}
|
||||
return failure();
|
||||
|
@ -105,7 +108,8 @@ static FailureOr<Value> getMinValueOfDtype(Operation *op, Type elementType,
|
|||
constType,
|
||||
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
|
||||
/*negative=*/true));
|
||||
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
|
||||
return rewriter
|
||||
.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
|
||||
.getResult();
|
||||
}
|
||||
if (elementType.isa<mlir::IntegerType>()) {
|
||||
|
@ -118,7 +122,8 @@ static FailureOr<Value> getMinValueOfDtype(Operation *op, Type elementType,
|
|||
constAttr = SplatElementsAttr::get(
|
||||
constType, APInt::getSignedMinValue(integerType.getWidth()));
|
||||
}
|
||||
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
|
||||
return rewriter
|
||||
.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
|
||||
.getResult();
|
||||
}
|
||||
return failure();
|
||||
|
@ -126,7 +131,7 @@ static FailureOr<Value> getMinValueOfDtype(Operation *op, Type elementType,
|
|||
|
||||
// These legalizations are for unary ops.
|
||||
namespace {
|
||||
template <typename AtenOpT, typename MhloOpT>
|
||||
template <typename AtenOpT, typename StablehloOpT>
|
||||
class ConvertAtenUnaryOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||
|
@ -137,13 +142,13 @@ public:
|
|||
Value self = adaptor.getSelf();
|
||||
auto selfType = self.getType().cast<TensorType>();
|
||||
if (!selfType) {
|
||||
return op.emitError("only Tensor types supported in MHLO");
|
||||
return op.emitError("only Tensor types supported in StableHLO");
|
||||
}
|
||||
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<TensorType>();
|
||||
self = mhlo::promoteType(rewriter, self, outType);
|
||||
rewriter.replaceOpWithNewOp<MhloOpT>(op, outType, self);
|
||||
self = hlo::promoteType(rewriter, self, outType);
|
||||
rewriter.replaceOpWithNewOp<StablehloOpT>(op, outType, self);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -152,7 +157,7 @@ public:
|
|||
// These legalizations are for unary ops with only for floating point datatypes.
|
||||
// There is no supported quantized integer mode for these.
|
||||
namespace {
|
||||
template <typename AtenOpT, typename MhloOpT>
|
||||
template <typename AtenOpT, typename StablehloOpT>
|
||||
class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||
|
@ -164,10 +169,10 @@ public:
|
|||
auto selfTy = self.getType().cast<TensorType>();
|
||||
|
||||
if (!selfTy)
|
||||
return op.emitError("only Tensor types supported in MHLO");
|
||||
return op.emitError("only Tensor types supported in StableHLO");
|
||||
|
||||
if (selfTy.getElementType().isa<mlir::FloatType>()) {
|
||||
rewriter.replaceOpWithNewOp<MhloOpT>(
|
||||
rewriter.replaceOpWithNewOp<StablehloOpT>(
|
||||
op,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()),
|
||||
|
@ -198,7 +203,7 @@ public:
|
|||
.template dyn_cast<TensorType>();
|
||||
|
||||
if (!outType)
|
||||
return op.emitError("only Tensor types supported in MHLO");
|
||||
return op.emitError("only Tensor types supported in StableHLO");
|
||||
|
||||
Type outElemTy = outType.getElementType();
|
||||
if (!outElemTy.isIntOrFloat())
|
||||
|
@ -216,9 +221,9 @@ public:
|
|||
|
||||
SmallVector<int32_t> values(size, fillVal);
|
||||
auto constOp =
|
||||
mhlo::getConstTensor<int32_t>(rewriter, op, values, shape).value();
|
||||
hlo::getConstTensor<int32_t>(rewriter, op, values, shape).value();
|
||||
|
||||
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(op, outType, constOp);
|
||||
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(op, outType, constOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -247,8 +252,8 @@ public:
|
|||
->convertType(op.getType())
|
||||
.template cast<TensorType>();
|
||||
|
||||
lhs = mhlo::promoteType(rewriter, lhs, outTy);
|
||||
rhs = mhlo::promoteType(rewriter, rhs, outTy);
|
||||
lhs = hlo::promoteType(rewriter, lhs, outTy);
|
||||
rhs = hlo::promoteType(rewriter, rhs, outTy);
|
||||
|
||||
rewriter.replaceOpWithNewOp<ChloOpT>(op, outTy, lhs, rhs,
|
||||
/*broadcast_attr*/ nullptr);
|
||||
|
@ -274,7 +279,7 @@ public:
|
|||
RankedTensorType rhsType = rhs.getType().dyn_cast<RankedTensorType>();
|
||||
|
||||
if (!lhsType)
|
||||
return op.emitError("only Tensor types supported in MHLO");
|
||||
return op.emitError("only Tensor types supported in StableHLO");
|
||||
|
||||
TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
|
@ -287,18 +292,19 @@ public:
|
|||
}
|
||||
|
||||
if (!rhsType) {
|
||||
rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getOther(), outElemTy);
|
||||
rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
|
||||
outElemTy);
|
||||
if (isa<AtenRsubScalarOp>(op)) {
|
||||
std::swap(lhs, rhs);
|
||||
}
|
||||
}
|
||||
|
||||
lhs = mhlo::promoteType(rewriter, lhs, outType);
|
||||
rhs = mhlo::promoteType(rewriter, rhs, outType);
|
||||
lhs = hlo::promoteType(rewriter, lhs, outType);
|
||||
rhs = hlo::promoteType(rewriter, rhs, outType);
|
||||
|
||||
if (!skipMultiplyAlpha(op.getAlpha())) {
|
||||
Value alpha =
|
||||
mhlo::scalarToMhloTensor(rewriter, op, adaptor.getAlpha(), outElemTy);
|
||||
Value alpha = hlo::scalarToStablehloTensor(rewriter, op,
|
||||
adaptor.getAlpha(), outElemTy);
|
||||
DenseIntElementsAttr bcastDimensions;
|
||||
rhs = rewriter.create<chlo::BroadcastMulOp>(op->getLoc(), rhs, alpha,
|
||||
bcastDimensions);
|
||||
|
@ -328,7 +334,7 @@ public:
|
|||
TensorType rhsType = rhs.getType().dyn_cast<TensorType>();
|
||||
|
||||
if (!lhsType)
|
||||
return op.emitError("only Tensor types supported in MHLO");
|
||||
return op.emitError("only Tensor types supported in StableHLO");
|
||||
|
||||
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
|
@ -343,11 +349,12 @@ public:
|
|||
if (std::is_same<AtenOpT, AtenSquareOp>()) {
|
||||
rhs = lhs;
|
||||
} else if (!rhsType) {
|
||||
rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getOther(), outElemTy);
|
||||
rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
|
||||
outElemTy);
|
||||
}
|
||||
DenseIntElementsAttr bcastDimensions;
|
||||
lhs = mhlo::promoteType(rewriter, lhs, outType);
|
||||
rhs = mhlo::promoteType(rewriter, rhs, outType);
|
||||
lhs = hlo::promoteType(rewriter, lhs, outType);
|
||||
rhs = hlo::promoteType(rewriter, rhs, outType);
|
||||
auto loc = op.getLoc();
|
||||
Value result =
|
||||
rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);
|
||||
|
@ -368,15 +375,15 @@ public:
|
|||
if (roundingMode == "trunc") {
|
||||
// "trunc" - rounds the results of the division towards zero. Equivalent
|
||||
// to C-style integer division.
|
||||
auto sign = rewriter.create<mhlo::SignOp>(loc, result);
|
||||
auto abs = rewriter.create<mhlo::AbsOp>(loc, result);
|
||||
auto floor = rewriter.create<mhlo::FloorOp>(loc, abs);
|
||||
result = rewriter.create<mhlo::MulOp>(loc, sign, floor).getResult();
|
||||
auto sign = rewriter.create<stablehlo::SignOp>(loc, result);
|
||||
auto abs = rewriter.create<stablehlo::AbsOp>(loc, result);
|
||||
auto floor = rewriter.create<stablehlo::FloorOp>(loc, abs);
|
||||
result = rewriter.create<stablehlo::MulOp>(loc, sign, floor).getResult();
|
||||
}
|
||||
if (roundingMode == "floor") {
|
||||
// "floor" - rounds the results of the division down. Equivalent to
|
||||
// floor division in Python (the // operator)
|
||||
result = rewriter.create<mhlo::FloorOp>(loc, result).getResult();
|
||||
result = rewriter.create<stablehlo::FloorOp>(loc, result).getResult();
|
||||
}
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
|
@ -401,7 +408,7 @@ public:
|
|||
RankedTensorType rhsTy = rhs.getType().dyn_cast<RankedTensorType>();
|
||||
|
||||
if (!lhsTy)
|
||||
return op.emitError("only Tensor types supported in MHLO");
|
||||
return op.emitError("only Tensor types supported in StableHLO");
|
||||
|
||||
RankedTensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
|
@ -414,11 +421,12 @@ public:
|
|||
}
|
||||
|
||||
if (!rhsTy) {
|
||||
rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getOther(), lhsElemTy);
|
||||
rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
|
||||
lhsElemTy);
|
||||
}
|
||||
|
||||
// TODO: what is the PyTorch default type promotion?
|
||||
rhs = mhlo::promoteType(rewriter, rhs, lhsTy);
|
||||
rhs = hlo::promoteType(rewriter, rhs, lhsTy);
|
||||
|
||||
chlo::ComparisonTypeAttr compareTypeAttr;
|
||||
chlo::ComparisonDirectionAttr compareDirectionAttr;
|
||||
|
@ -485,8 +493,8 @@ public:
|
|||
TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<TensorType>();
|
||||
Value lhs = mhlo::promoteType(rewriter, adaptor.getSelf(), outType);
|
||||
Value rhs = mhlo::promoteType(rewriter, adaptor.getOther(), outType);
|
||||
Value lhs = hlo::promoteType(rewriter, adaptor.getSelf(), outType);
|
||||
Value rhs = hlo::promoteType(rewriter, adaptor.getOther(), outType);
|
||||
|
||||
DenseIntElementsAttr bcastDimensions;
|
||||
rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
|
||||
|
@ -537,8 +545,8 @@ public:
|
|||
RankedTensorType::get({static_cast<long int>(permValues.size())},
|
||||
rewriter.getI64Type()),
|
||||
permValues);
|
||||
rewriter.replaceOpWithNewOp<mhlo::TransposeOp>(op, outType, self,
|
||||
permutation);
|
||||
rewriter.replaceOpWithNewOp<stablehlo::TransposeOp>(op, outType, self,
|
||||
permutation);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -552,7 +560,7 @@ LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
|
|||
Value self = adaptor.getSelf();
|
||||
auto outType =
|
||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(op, outType, self);
|
||||
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(op, outType, self);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -573,7 +581,8 @@ LogicalResult ConvertAtenOp<AtenSizeIntOp>::matchAndRewrite(
|
|||
} else {
|
||||
Value inputRank = rewriter.create<arith::ConstantOp>(
|
||||
op.getLoc(), rewriter.getI64IntegerAttr(selfType.getRank()));
|
||||
dim = toPositiveDimDynamic(rewriter, op.getLoc(), adaptor.getDim(), inputRank);
|
||||
dim = toPositiveDimDynamic(rewriter, op.getLoc(), adaptor.getDim(),
|
||||
inputRank);
|
||||
dim = rewriter.create<arith::IndexCastOp>(op.getLoc(),
|
||||
rewriter.getIndexType(), dim);
|
||||
}
|
||||
|
@ -589,9 +598,8 @@ LogicalResult ConvertAtenOp<AtenSizeIntOp>::matchAndRewrite(
|
|||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
|
||||
AtenWhereSelfOp op,
|
||||
OpAdaptor adaptor,
|
||||
ConversionPatternRewriter& rewriter) const {
|
||||
AtenWhereSelfOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value self = adaptor.getSelf();
|
||||
Value cond = adaptor.getCondition();
|
||||
Value other = adaptor.getOther();
|
||||
|
@ -605,8 +613,7 @@ LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
|
|||
return op.emitError("failed broadcast other and condition ranks");
|
||||
|
||||
rewriter.replaceOpWithNewOp<chlo::BroadcastSelectOp>(
|
||||
op,
|
||||
getTypeConverter()->convertType(op.getType()),
|
||||
op, getTypeConverter()->convertType(op.getType()),
|
||||
ArrayRef<Value>{cond, self, other});
|
||||
return success();
|
||||
}
|
||||
|
@ -623,7 +630,7 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
|
|||
.cast<RankedTensorType>();
|
||||
|
||||
if (options.enableStaticShape && selfTy.hasStaticShape()) {
|
||||
Value bcastOp = mhlo::promoteAndBroadcast(rewriter, self, outType);
|
||||
Value bcastOp = hlo::promoteAndBroadcast(rewriter, self, outType);
|
||||
rewriter.replaceOp(op, bcastOp);
|
||||
return success();
|
||||
}
|
||||
|
@ -670,7 +677,7 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
|
|||
op->getLoc(), ValueRange{bcastShapeVec});
|
||||
auto dimensionNumbers =
|
||||
llvm::to_vector<4>(llvm::seq<int64_t>(leadingRank, totalRank));
|
||||
rewriter.replaceOpWithNewOp<mhlo::DynamicBroadcastInDimOp>(
|
||||
rewriter.replaceOpWithNewOp<stablehlo::DynamicBroadcastInDimOp>(
|
||||
op, outType, self, bcastShapeTensor,
|
||||
rewriter.getI64TensorAttr(dimensionNumbers));
|
||||
}
|
||||
|
@ -708,28 +715,11 @@ LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
|
|||
RankedTensorType::get({static_cast<long int>(permValues.size())},
|
||||
rewriter.getI64Type()),
|
||||
permValues);
|
||||
rewriter.replaceOpWithNewOp<mhlo::TransposeOp>(op, outType, self,
|
||||
permutation);
|
||||
rewriter.replaceOpWithNewOp<stablehlo::TransposeOp>(op, outType, self,
|
||||
permutation);
|
||||
return success();
|
||||
}
|
||||
|
||||
// AtenTanhOp
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
|
||||
AtenTanhOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value self = adaptor.getSelf();
|
||||
auto selfTy = self.getType().cast<TensorType>();
|
||||
if (selfTy && selfTy.getElementType().isa<mlir::FloatType>()) {
|
||||
rewriter.replaceOpWithNewOp<mhlo::TanhOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), self);
|
||||
return success();
|
||||
} else {
|
||||
return op.emitError(
|
||||
"only floating-point datatype legalization currently supported");
|
||||
}
|
||||
}
|
||||
|
||||
// ValueTensorLiteralOp
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
|
||||
|
@ -751,16 +741,16 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
|
|||
elements.mapValues(builtinTensorElemTy, [&](const APInt &v) {
|
||||
return APInt(bitWidth, v.getSExtValue());
|
||||
});
|
||||
rewriter.replaceOpWithNewOp<mhlo::ConstantOp>(op, resultType, valueAttr);
|
||||
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(op, resultType,
|
||||
valueAttr);
|
||||
return success();
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<mhlo::ConstantOp>(op, resultType,
|
||||
adaptor.getValue());
|
||||
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(op, resultType,
|
||||
adaptor.getValue());
|
||||
return success();
|
||||
}
|
||||
|
||||
|
||||
// AtenReciprocalOp
|
||||
// Reciprocal(x) = Div(1, x)
|
||||
template <>
|
||||
|
@ -777,7 +767,45 @@ LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(
|
|||
}
|
||||
|
||||
Value oneTensor = chlo::getConstantLike(rewriter, op->getLoc(), 1, input);
|
||||
rewriter.replaceOpWithNewOp<mhlo::DivOp>(op, outTy, oneTensor, input);
|
||||
rewriter.replaceOpWithNewOp<stablehlo::DivOp>(op, outTy, oneTensor, input);
|
||||
return success();
|
||||
}
|
||||
|
||||
// AtenPowTensorScalarOp
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
|
||||
AtenPowTensorScalarOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value lhs = adaptor.getSelf();
|
||||
auto lhsType = lhs.getType().dyn_cast<TensorType>();
|
||||
Value rhs = adaptor.getExponent();
|
||||
TensorType rhsType = rhs.getType().dyn_cast<TensorType>();
|
||||
|
||||
if (!lhsType)
|
||||
return op.emitError("only Tensor types supported in StableHLO");
|
||||
|
||||
auto outType = OpConversionPattern<AtenPowTensorScalarOp>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<TensorType>();
|
||||
|
||||
Type outElemTy = outType.getElementType();
|
||||
if (!outElemTy.isIntOrFloat()) {
|
||||
return op.emitError(
|
||||
"only floating-point or integer datatype legalization supported");
|
||||
}
|
||||
|
||||
if (!rhsType) {
|
||||
rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs,
|
||||
outElemTy);
|
||||
}
|
||||
DenseIntElementsAttr bcastDimensions;
|
||||
lhs = hlo::promoteType(rewriter, lhs, outType);
|
||||
rhs = hlo::promoteType(rewriter, rhs, outType);
|
||||
auto loc = op.getLoc();
|
||||
Value result =
|
||||
rewriter.create<chlo::BroadcastPowOp>(loc, outType, lhs, rhs, bcastDimensions);
|
||||
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -790,9 +818,9 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
|
|||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
auto outputElemType = outputType.getElementType();
|
||||
Value mhloTensor =
|
||||
mhlo::scalarToMhloTensor(rewriter, op, adaptor.getA(), outputElemType);
|
||||
rewriter.replaceOp(op, mhloTensor);
|
||||
Value stablehloTensor = hlo::scalarToStablehloTensor(
|
||||
rewriter, op, adaptor.getA(), outputElemType);
|
||||
rewriter.replaceOp(op, stablehloTensor);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -815,7 +843,6 @@ LogicalResult ConvertAtenOp<AtenContiguousOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
|
||||
// AtenReluOp
|
||||
// Relu(x) = Max(0, x)
|
||||
template <>
|
||||
|
@ -836,11 +863,10 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
|
|||
APFloat::getZero(lhsElemTy.cast<mlir::FloatType>().getFloatSemantics(),
|
||||
false),
|
||||
lhs);
|
||||
rewriter.replaceOpWithNewOp<mhlo::MaxOp>(op, lhs, zeroTensor);
|
||||
rewriter.replaceOpWithNewOp<stablehlo::MaxOp>(op, lhs, zeroTensor);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
||||
// Convert a Aten::GELU to HLO
|
||||
// Gelu(x) = x * 1/2 * [1 + erf(x/(sqrt(2)))]
|
||||
template <>
|
||||
|
@ -857,12 +883,12 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
|
|||
Value one = chlo::getConstantLike(rewriter, loc, 1.0, input);
|
||||
Value two = chlo::getConstantLike(rewriter, loc, 2.0, input);
|
||||
Value half = chlo::getConstantLike(rewriter, loc, 0.5, input);
|
||||
auto rsqrtTwo = rewriter.create<mlir::mhlo::RsqrtOp>(loc, two);
|
||||
auto erfElement = rewriter.create<mhlo::MulOp>(loc, input, rsqrtTwo);
|
||||
auto rsqrtTwo = rewriter.create<mlir::stablehlo::RsqrtOp>(loc, two);
|
||||
auto erfElement = rewriter.create<stablehlo::MulOp>(loc, input, rsqrtTwo);
|
||||
auto erf = rewriter.create<mlir::chlo::ErfOp>(loc, erfElement);
|
||||
auto erfAdd = rewriter.create<mhlo::AddOp>(loc, erf, one);
|
||||
auto halfMul = rewriter.create<mhlo::MulOp>(loc, erfAdd, half);
|
||||
rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, input, halfMul);
|
||||
auto erfAdd = rewriter.create<stablehlo::AddOp>(loc, erf, one);
|
||||
auto halfMul = rewriter.create<stablehlo::MulOp>(loc, erfAdd, half);
|
||||
rewriter.replaceOpWithNewOp<stablehlo::MulOp>(op, input, halfMul);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -881,7 +907,6 @@ LogicalResult ConvertAtenOp<AtenErfOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
|
||||
// AtenBatchNormOp
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
||||
|
@ -919,28 +944,28 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
|||
Value channelShape = rewriter.create<tensor::FromElementsOp>(
|
||||
op->getLoc(), ValueRange{channelDim});
|
||||
if (failed(checkNotNone(rewriter, op, weight))) {
|
||||
weight = mhlo::getConstantOfShape(
|
||||
weight = hlo::getConstantOfShape(
|
||||
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1),
|
||||
channelShape,
|
||||
RankedTensorType::get({inputTy.getShape()[1]},
|
||||
inputTy.getElementType()));
|
||||
}
|
||||
if (failed(checkNotNone(rewriter, op, bias))) {
|
||||
bias = mhlo::getConstantOfShape(
|
||||
bias = hlo::getConstantOfShape(
|
||||
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0),
|
||||
channelShape,
|
||||
RankedTensorType::get({inputTy.getShape()[1]},
|
||||
inputTy.getElementType()));
|
||||
}
|
||||
if (failed(checkNotNone(rewriter, op, runningVar))) {
|
||||
runningVar = mhlo::getConstantOfShape(
|
||||
runningVar = hlo::getConstantOfShape(
|
||||
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1),
|
||||
channelShape,
|
||||
RankedTensorType::get({inputTy.getShape()[1]},
|
||||
inputTy.getElementType()));
|
||||
}
|
||||
if (failed(checkNotNone(rewriter, op, runningMean))) {
|
||||
runningMean = mhlo::getConstantOfShape(
|
||||
runningMean = hlo::getConstantOfShape(
|
||||
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0),
|
||||
channelShape,
|
||||
RankedTensorType::get({inputTy.getShape()[1]},
|
||||
|
@ -983,10 +1008,11 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
|||
Type outputTy = getTypeConverter()->convertType(op.getType());
|
||||
Type batchMeanOrVarTy =
|
||||
RankedTensorType::get(weightTy.getShape(), inputTy.getElementType());
|
||||
auto batchNormTrainingResult = rewriter.create<mhlo::BatchNormTrainingOp>(
|
||||
op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input,
|
||||
weight, bias, rewriter.getF32FloatAttr(eps),
|
||||
rewriter.getI64IntegerAttr(1));
|
||||
auto batchNormTrainingResult =
|
||||
rewriter.create<stablehlo::BatchNormTrainingOp>(
|
||||
op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input,
|
||||
weight, bias, rewriter.getF32FloatAttr(eps),
|
||||
rewriter.getI64IntegerAttr(1));
|
||||
rewriter.replaceOp(op, batchNormTrainingResult.getResult(0));
|
||||
return success();
|
||||
} else {
|
||||
|
@ -995,10 +1021,11 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
|||
inputTy.getShape().end()};
|
||||
castShape[1] = weightTy.getShape()[0];
|
||||
auto castTy = RankedTensorType::get(castShape, inputTy.getElementType());
|
||||
// Feature counts must match among operands of mhlo::BatchNormInferenceOp.
|
||||
// Feature counts must match among operands of
|
||||
// stablehlo::BatchNormInferenceOp.
|
||||
Value inputCasted =
|
||||
rewriter.create<tensor::CastOp>(op.getLoc(), castTy, input);
|
||||
Value output = rewriter.create<mhlo::BatchNormInferenceOp>(
|
||||
Value output = rewriter.create<stablehlo::BatchNormInferenceOp>(
|
||||
op.getLoc(), inputCasted.getType(), inputCasted, weight, bias,
|
||||
runningMean, runningVar,
|
||||
// 'epsilon' must satisfy constraint: 32-bit float attribute.
|
||||
|
@ -1008,7 +1035,6 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
// AtenNativeLayerNormOp
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
||||
|
@ -1076,21 +1102,21 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
|||
}
|
||||
SmallVector<int64_t> inputFlattenShape{1, numFeatureDimSize,
|
||||
numEmbeddingDimSize};
|
||||
SmallVector<int64_t> meanOrVarMhloOutShape{numFeatureDimSize};
|
||||
SmallVector<int64_t> meanOrVarStablehloOutShape{numFeatureDimSize};
|
||||
|
||||
auto mhloBatchNormOutTy =
|
||||
auto stablehloBatchNormOutTy =
|
||||
RankedTensorType::get(inputFlattenShape, inputTy.getElementType());
|
||||
auto mhloBathNormOutMeanOrVarTy =
|
||||
RankedTensorType::get(meanOrVarMhloOutShape, inputTy.getElementType());
|
||||
auto stablehloBathNormOutMeanOrVarTy = RankedTensorType::get(
|
||||
meanOrVarStablehloOutShape, inputTy.getElementType());
|
||||
|
||||
// Reshape input
|
||||
auto mhloInput = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
op->getLoc(), mhloBatchNormOutTy, input,
|
||||
mhlo::getConstTensor(rewriter, op, llvm::makeArrayRef(inputFlattenShape),
|
||||
{static_cast<int64_t>(inputFlattenShape.size())})
|
||||
auto stablehloInput = rewriter.create<stablehlo::DynamicReshapeOp>(
|
||||
op->getLoc(), stablehloBatchNormOutTy, input,
|
||||
hlo::getConstTensor(rewriter, op, llvm::ArrayRef(inputFlattenShape),
|
||||
{static_cast<int64_t>(inputFlattenShape.size())})
|
||||
.value());
|
||||
|
||||
// Generate "scale" and "offset" Value for mhlo.BatchNormTrainingOp.
|
||||
// Generate "scale" and "offset" Value for stablehlo.BatchNormTrainingOp.
|
||||
SmallVector<APFloat> zeroConstVec(
|
||||
numFeatureDimSize, APFloat::getZero(inputTy.getElementType()
|
||||
.cast<mlir::FloatType>()
|
||||
|
@ -1103,16 +1129,18 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
|||
auto oneOrZeroConstType =
|
||||
RankedTensorType::get({numFeatureDimSize}, inputTy.getElementType());
|
||||
|
||||
Value scale = rewriter.create<mhlo::ConstantOp>(
|
||||
Value scale = rewriter.create<stablehlo::ConstantOp>(
|
||||
op->getLoc(), oneOrZeroConstType,
|
||||
DenseElementsAttr::get(oneOrZeroConstType, oneConstVec));
|
||||
Value offset = rewriter.create<mhlo::ConstantOp>(
|
||||
Value offset = rewriter.create<stablehlo::ConstantOp>(
|
||||
op->getLoc(), oneOrZeroConstType,
|
||||
DenseElementsAttr::get(oneOrZeroConstType, zeroConstVec));
|
||||
auto batchNormTrainingResult = rewriter.create<mhlo::BatchNormTrainingOp>(
|
||||
op->getLoc(), mhloBatchNormOutTy, mhloBathNormOutMeanOrVarTy,
|
||||
mhloBathNormOutMeanOrVarTy, mhloInput, scale, offset,
|
||||
rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(1));
|
||||
auto batchNormTrainingResult =
|
||||
rewriter.create<stablehlo::BatchNormTrainingOp>(
|
||||
op->getLoc(), stablehloBatchNormOutTy,
|
||||
stablehloBathNormOutMeanOrVarTy, stablehloBathNormOutMeanOrVarTy,
|
||||
stablehloInput, scale, offset, rewriter.getF32FloatAttr(eps),
|
||||
rewriter.getI64IntegerAttr(1));
|
||||
|
||||
// Reshape back
|
||||
auto outputTy =
|
||||
|
@ -1120,36 +1148,35 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
|||
auto outputMeanOrVarTy =
|
||||
getTypeConverter()->convertType(op.getType(1)).cast<RankedTensorType>();
|
||||
|
||||
auto output = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
auto output = rewriter.create<stablehlo::DynamicReshapeOp>(
|
||||
op->getLoc(), outputTy, batchNormTrainingResult.getResult(0),
|
||||
mhlo::getConstTensor(rewriter, op, outputTy.getShape(),
|
||||
{static_cast<int64_t>(outputTy.getShape().size())})
|
||||
hlo::getConstTensor(rewriter, op, outputTy.getShape(),
|
||||
{static_cast<int64_t>(outputTy.getShape().size())})
|
||||
.value());
|
||||
auto mean = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
auto mean = rewriter.create<stablehlo::DynamicReshapeOp>(
|
||||
op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(1),
|
||||
mhlo::getConstTensor(
|
||||
hlo::getConstTensor(
|
||||
rewriter, op, outputMeanOrVarTy.getShape(),
|
||||
{static_cast<int64_t>(outputMeanOrVarTy.getShape().size())})
|
||||
.value());
|
||||
auto var = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
auto var = rewriter.create<stablehlo::DynamicReshapeOp>(
|
||||
op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(2),
|
||||
mhlo::getConstTensor(
|
||||
hlo::getConstTensor(
|
||||
rewriter, op, outputMeanOrVarTy.getShape(),
|
||||
{static_cast<int64_t>(outputMeanOrVarTy.getShape().size())})
|
||||
.value());
|
||||
|
||||
// Apply affine transform: output x weight + bias [element-wise]
|
||||
auto bcastedWeight = mhlo::promoteAndBroadcast(rewriter, weight, outputTy);
|
||||
auto bcastedBias = mhlo::promoteAndBroadcast(rewriter, bias, outputTy);
|
||||
auto bcastedWeight = hlo::promoteAndBroadcast(rewriter, weight, outputTy);
|
||||
auto bcastedBias = hlo::promoteAndBroadcast(rewriter, bias, outputTy);
|
||||
auto outputMulWeight =
|
||||
rewriter.create<mhlo::MulOp>(op->getLoc(), output, bcastedWeight);
|
||||
auto finalOuput =
|
||||
rewriter.create<mhlo::AddOp>(op->getLoc(), outputMulWeight, bcastedBias);
|
||||
rewriter.create<stablehlo::MulOp>(op->getLoc(), output, bcastedWeight);
|
||||
auto finalOuput = rewriter.create<stablehlo::AddOp>(
|
||||
op->getLoc(), outputMulWeight, bcastedBias);
|
||||
rewriter.replaceOp(op, {finalOuput, mean, var});
|
||||
return success();
|
||||
}
|
||||
|
||||
|
||||
// AtenCatOp
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
|
||||
|
@ -1173,11 +1200,11 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
|
|||
|
||||
// Promote type
|
||||
for (auto &v : builtinTensors) {
|
||||
v = mhlo::promoteType(rewriter, v, outType);
|
||||
v = hlo::promoteType(rewriter, v, outType);
|
||||
}
|
||||
|
||||
size_t posDim = toPositiveDim(dim, outType.getRank());
|
||||
rewriter.replaceOpWithNewOp<mhlo::ConcatenateOp>(
|
||||
rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(
|
||||
op, outType, ValueRange(builtinTensors), posDim);
|
||||
return success();
|
||||
}
|
||||
|
@ -1225,7 +1252,8 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "this op should be folded as its `min` and `max` both are none");
|
||||
} else if (failed(checkNotNone(rewriter, op, minValue))) {
|
||||
maxValue = mhlo::scalarToMhloTensor(rewriter, op, maxValue, inputElemType);
|
||||
maxValue =
|
||||
hlo::scalarToStablehloTensor(rewriter, op, maxValue, inputElemType);
|
||||
auto minInfo = getMinValueOfDtype(op, inputElemType, rewriter);
|
||||
if (failed(minInfo)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -1233,7 +1261,8 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
|
|||
}
|
||||
minValue = *minInfo;
|
||||
} else if (failed(checkNotNone(rewriter, op, maxValue))) {
|
||||
minValue = mhlo::scalarToMhloTensor(rewriter, op, minValue, inputElemType);
|
||||
minValue =
|
||||
hlo::scalarToStablehloTensor(rewriter, op, minValue, inputElemType);
|
||||
auto maxInfo = getMaxValueOfDtype(op, inputElemType, rewriter);
|
||||
if (failed(maxInfo)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -1241,10 +1270,13 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
|
|||
}
|
||||
maxValue = *maxInfo;
|
||||
} else {
|
||||
minValue = mhlo::scalarToMhloTensor(rewriter, op, minValue, inputElemType);
|
||||
maxValue = mhlo::scalarToMhloTensor(rewriter, op, maxValue, inputElemType);
|
||||
minValue =
|
||||
hlo::scalarToStablehloTensor(rewriter, op, minValue, inputElemType);
|
||||
maxValue =
|
||||
hlo::scalarToStablehloTensor(rewriter, op, maxValue, inputElemType);
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<mhlo::ClampOp>(op, minValue, input, maxValue);
|
||||
rewriter.replaceOpWithNewOp<stablehlo::ClampOp>(op, minValue, input,
|
||||
maxValue);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -1266,24 +1298,27 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
|
|||
op, "unimplemented: only int or float dtype supported");
|
||||
}
|
||||
|
||||
Value start = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getStart(), dtype);
|
||||
Value end = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getEnd(), dtype);
|
||||
Value step = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getStep(), dtype);
|
||||
Value start =
|
||||
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getStart(), dtype);
|
||||
Value end =
|
||||
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getEnd(), dtype);
|
||||
Value step =
|
||||
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getStep(), dtype);
|
||||
|
||||
// Get length of the 1-d output tensor
|
||||
Value subOut = rewriter.create<mhlo::SubtractOp>(loc, end, start);
|
||||
Value divOut = rewriter.create<mhlo::DivOp>(loc, subOut, step);
|
||||
Value subOut = rewriter.create<stablehlo::SubtractOp>(loc, end, start);
|
||||
Value divOut = rewriter.create<stablehlo::DivOp>(loc, subOut, step);
|
||||
|
||||
Value resultLength = rewriter.create<mhlo::ReshapeOp>(
|
||||
Value resultLength = rewriter.create<stablehlo::ReshapeOp>(
|
||||
loc, RankedTensorType::get({1}, dtype), divOut);
|
||||
if (dtype.isa<mlir::FloatType>()) {
|
||||
resultLength = rewriter.create<mhlo::CeilOp>(loc, resultLength);
|
||||
resultLength = rewriter.create<mhlo::ConvertOp>(
|
||||
resultLength = rewriter.create<stablehlo::CeilOp>(loc, resultLength);
|
||||
resultLength = rewriter.create<stablehlo::ConvertOp>(
|
||||
loc, RankedTensorType::get({1}, rewriter.getI64Type()), resultLength);
|
||||
}
|
||||
|
||||
Value window =
|
||||
rewriter.create<mhlo::DynamicIotaOp>(loc, outType, resultLength, 0);
|
||||
rewriter.create<stablehlo::DynamicIotaOp>(loc, outType, resultLength, 0);
|
||||
DenseIntElementsAttr broadcastDimensions;
|
||||
Value mulOut = rewriter.create<chlo::BroadcastMulOp>(loc, window, step,
|
||||
broadcastDimensions);
|
||||
|
@ -1298,9 +1333,8 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
Location loc = op.getLoc();
|
||||
Value input = adaptor.getSelf();
|
||||
auto outType = this->getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.cast<TensorType>();
|
||||
auto outType =
|
||||
this->getTypeConverter()->convertType(op.getType()).cast<TensorType>();
|
||||
if (!outType) {
|
||||
return op.emitError("only tensor type is supported");
|
||||
}
|
||||
|
@ -1320,26 +1354,27 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
|
|||
Value negHalf = chlo::getConstantLike(rewriter, loc, -0.5, input);
|
||||
|
||||
// Compute
|
||||
Value kBeta0 = rewriter.create<mhlo::MulOp>(loc, outType, kAlpha, cstAlpha0);
|
||||
Value kBeta = rewriter.create<mhlo::MulOp>(loc, outType, kBeta0, half);
|
||||
Value erfArg =
|
||||
rewriter.create<mhlo::MulOp>(loc, outType, kAlpha, adaptor.getSelf());
|
||||
Value kBeta0 =
|
||||
rewriter.create<stablehlo::MulOp>(loc, outType, kAlpha, cstAlpha0);
|
||||
Value kBeta = rewriter.create<stablehlo::MulOp>(loc, outType, kBeta0, half);
|
||||
Value erfArg = rewriter.create<stablehlo::MulOp>(loc, outType, kAlpha,
|
||||
adaptor.getSelf());
|
||||
Value erf = rewriter.create<mlir::chlo::ErfOp>(loc, outType, erfArg);
|
||||
Value erfAdd = rewriter.create<mhlo::AddOp>(loc, outType, erf, one);
|
||||
Value cdf = rewriter.create<mhlo::MulOp>(loc, outType, erfAdd, half);
|
||||
Value inputSquared = rewriter.create<mhlo::MulOp>(
|
||||
Value erfAdd = rewriter.create<stablehlo::AddOp>(loc, outType, erf, one);
|
||||
Value cdf = rewriter.create<stablehlo::MulOp>(loc, outType, erfAdd, half);
|
||||
Value inputSquared = rewriter.create<stablehlo::MulOp>(
|
||||
loc, outType, adaptor.getSelf(), adaptor.getSelf());
|
||||
Value negHalfInputSquared =
|
||||
rewriter.create<mhlo::MulOp>(loc, outType, inputSquared, negHalf);
|
||||
rewriter.create<stablehlo::MulOp>(loc, outType, inputSquared, negHalf);
|
||||
Value expRes =
|
||||
rewriter.create<mhlo::ExpOp>(loc, outType, negHalfInputSquared);
|
||||
Value pdf = rewriter.create<mhlo::MulOp>(loc, outType, kBeta, expRes);
|
||||
rewriter.create<stablehlo::ExpOp>(loc, outType, negHalfInputSquared);
|
||||
Value pdf = rewriter.create<stablehlo::MulOp>(loc, outType, kBeta, expRes);
|
||||
Value pdfTimesInput =
|
||||
rewriter.create<mhlo::MulOp>(loc, outType, pdf, adaptor.getSelf());
|
||||
rewriter.create<stablehlo::MulOp>(loc, outType, pdf, adaptor.getSelf());
|
||||
Value pdfTimesInputAddCdf =
|
||||
rewriter.create<mhlo::AddOp>(loc, outType, pdfTimesInput, cdf);
|
||||
rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, outType, adaptor.getGradOutput(),
|
||||
pdfTimesInputAddCdf);
|
||||
rewriter.create<stablehlo::AddOp>(loc, outType, pdfTimesInput, cdf);
|
||||
rewriter.replaceOpWithNewOp<stablehlo::MulOp>(
|
||||
op, outType, adaptor.getGradOutput(), pdfTimesInputAddCdf);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -1366,9 +1401,9 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
|
||||
void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToMhloOptions &options) {
|
||||
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
|
||||
target.addIllegalOp<AtenTransposeIntOp>();
|
||||
|
@ -1376,23 +1411,29 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
|
|||
target.addIllegalOp<RuntimeAssertOp>();
|
||||
patterns.add<ConvertRuntimeAssertOp>(typeConverter, context);
|
||||
|
||||
#define INSERT_UNARY_PATTERN(AtenOp, MhloOp) \
|
||||
#define INSERT_UNARY_PATTERN(AtenOp, StablehloOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenUnaryOp<AtenOp, MhloOp>>(typeConverter, context)
|
||||
INSERT_UNARY_PATTERN(AtenCloneOp, mhlo::CopyOp);
|
||||
INSERT_UNARY_PATTERN(AtenNegOp, mhlo::NegOp);
|
||||
INSERT_UNARY_PATTERN(AtenLogicalNotOp, mhlo::NotOp);
|
||||
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, mhlo::NotOp);
|
||||
patterns.add<ConvertAtenUnaryOp<AtenOp, StablehloOp>>(typeConverter, context)
|
||||
INSERT_UNARY_PATTERN(AtenCloneOp, stablehlo::ConvertOp);
|
||||
INSERT_UNARY_PATTERN(AtenNegOp, stablehlo::NegOp);
|
||||
INSERT_UNARY_PATTERN(AtenLogicalNotOp, stablehlo::NotOp);
|
||||
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, stablehlo::NotOp);
|
||||
#undef INSERT_UNARY_PATTERN
|
||||
|
||||
#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, MhloOp) \
|
||||
#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, StablehloOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, MhloOp>>(typeConverter, context)
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, mhlo::LogOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, mhlo::ExpOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, mhlo::SqrtOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, mhlo::RsqrtOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenSigmoidOp, mhlo::LogisticOp);
|
||||
patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, StablehloOp>>(typeConverter, \
|
||||
context)
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, stablehlo::LogOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, stablehlo::ExpOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, stablehlo::SqrtOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, stablehlo::RsqrtOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenSigmoidOp, stablehlo::LogisticOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenTanhOp, stablehlo::TanhOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenSinOp, stablehlo::SineOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenCosOp, stablehlo::CosineOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenCeilOp, stablehlo::CeilOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenFloorOp, stablehlo::FloorOp);
|
||||
#undef INSERT_UNARY_FPONLY_PATTERN
|
||||
|
||||
#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
|
||||
|
@ -1459,9 +1500,9 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
|
|||
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
|
||||
INSERT_ATENOP_PATTERN(AtenPermuteOp);
|
||||
|
||||
INSERT_ATENOP_PATTERN(AtenTanhOp);
|
||||
INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
|
||||
INSERT_ATENOP_PATTERN(AtenReciprocalOp);
|
||||
INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
|
||||
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
|
||||
INSERT_ATENOP_PATTERN(AtenContiguousOp);
|
||||
|
||||
|
@ -1482,10 +1523,10 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
|
|||
INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
|
||||
#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, MhloOp) \
|
||||
#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenBinaryBroadcastOp<AtenOp, MhloOp>>(typeConverter, \
|
||||
context)
|
||||
patterns.add<ConvertAtenBinaryBroadcastOp<AtenOp, StablehloOp>>( \
|
||||
typeConverter, context)
|
||||
INSERT_BINARY_BROADCAST_PATTERN(AtenMaximumOp, chlo::BroadcastMaxOp);
|
||||
INSERT_BINARY_BROADCAST_PATTERN(AtenMinimumOp, chlo::BroadcastMinOp);
|
||||
INSERT_BINARY_BROADCAST_PATTERN(Aten__And__TensorOp, chlo::BroadcastAndOp);
|
|
@ -0,0 +1,29 @@
|
|||
add_mlir_conversion_library(TorchMLIRTorchToStablehlo
|
||||
TorchToStablehlo.cpp
|
||||
StablehloLegalizeUtils.cpp
|
||||
Basic.cpp
|
||||
Gather.cpp
|
||||
Linear.cpp
|
||||
ViewLike.cpp
|
||||
Reduction.cpp
|
||||
Pooling.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToStablehlo
|
||||
|
||||
DEPENDS
|
||||
TorchMLIRConversionPassIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRPass
|
||||
MLIRBufferTransforms
|
||||
StablehloOps
|
||||
TorchMLIRTorchDialect
|
||||
TorchMLIRConversionUtils
|
||||
)
|
||||
|
||||
torch_mlir_target_includes(TorchMLIRTorchToStablehlo)
|
|
@ -7,14 +7,15 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "./MhloLegalizeUtils.h"
|
||||
#include "./PopulatePatterns.h"
|
||||
#include "mhlo/IR/hlo_ops.h"
|
||||
#include "PopulatePatterns.h"
|
||||
#include "StablehloLegalizeUtils.h"
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "stablehlo/dialect/StablehloOps.h"
|
||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
|
@ -24,7 +25,7 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
using namespace mlir::torch::torch_to_mhlo;
|
||||
using namespace mlir::torch::torch_to_stablehlo;
|
||||
|
||||
namespace {
|
||||
Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
|
||||
|
@ -69,7 +70,7 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
|
|||
SmallVector<int64_t, 4> startIndexMap(1, axis);
|
||||
// indexVecDim
|
||||
int64_t indexVecDim = indicesRank;
|
||||
auto dimsAttr = mhlo::GatherDimensionNumbersAttr::get(
|
||||
auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get(
|
||||
rewriter.getContext(),
|
||||
/*offsetDims=*/offsetDims,
|
||||
/*collapsedSliceDims=*/collapsedSliceDims,
|
||||
|
@ -91,17 +92,18 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
|
|||
auto outputTy =
|
||||
RankedTensorType::get(outputShape, inputRankTy.getElementType());
|
||||
return rewriter
|
||||
.create<mhlo::DynamicGatherOp>(loc, outputTy, input, indices,
|
||||
sliceSizesTensor, dimsAttr)
|
||||
.create<stablehlo::DynamicGatherOp>(loc, outputTy, input, indices,
|
||||
sliceSizesTensor, dimsAttr)
|
||||
.getResult();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html
|
||||
// Ref:
|
||||
// https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html
|
||||
// padding_idx (int, optional)
|
||||
// – If specified, the entries at padding_idx do not contribute to the gradient;
|
||||
// therefore, the embedding vector at padding_idx is not updated during training,
|
||||
// i.e. it remains as a fixed “pad”.
|
||||
// – If specified, the entries at padding_idx do not contribute to the
|
||||
// gradient; therefore, the embedding vector at padding_idx is not updated
|
||||
// during training, i.e. it remains as a fixed “pad”.
|
||||
// scale_grad_by_freq (boolean, optional)
|
||||
// – If given, this will scale gradients by the inverse of frequency of the
|
||||
// words in the mini-batch. Default False.
|
||||
|
@ -139,7 +141,7 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
|
|||
|
||||
Value output = gatherTensorAlongSingleAxis(
|
||||
rewriter, op, weight, adaptor.getIndices(), 0, options.dimSizeIndexBits);
|
||||
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
|
||||
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), output);
|
||||
|
||||
return success();
|
||||
|
@ -161,7 +163,7 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
|
|||
Value output = gatherTensorAlongSingleAxis(
|
||||
rewriter, op, self, adaptor.getIndex(), dim, options.dimSizeIndexBits);
|
||||
|
||||
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
|
||||
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), output);
|
||||
|
||||
return success();
|
||||
|
@ -200,7 +202,7 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
|
|||
|
||||
auto options = getOptions();
|
||||
auto indexShapeInfo =
|
||||
mhlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits);
|
||||
hlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits);
|
||||
if (failed(indexShapeInfo)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to get dim sizes of `index` param");
|
||||
|
@ -223,15 +225,15 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
|
|||
SmallVector<Value> toConcat;
|
||||
for (int64_t i = 0; i < inputType.getRank(); ++i) {
|
||||
if (i == dim) {
|
||||
toConcat.push_back(rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
toConcat.push_back(rewriter.create<stablehlo::DynamicReshapeOp>(
|
||||
loc, toConcatIndexType, index, toConcatIndexShape));
|
||||
} else {
|
||||
toConcat.push_back(rewriter.create<mhlo::DynamicIotaOp>(
|
||||
toConcat.push_back(rewriter.create<stablehlo::DynamicIotaOp>(
|
||||
loc, toConcatIndexType, toConcatIndexShape,
|
||||
rewriter.getI64IntegerAttr(i)));
|
||||
}
|
||||
}
|
||||
auto gatherIndicies = rewriter.create<mhlo::ConcatenateOp>(
|
||||
auto gatherIndicies = rewriter.create<stablehlo::ConcatenateOp>(
|
||||
loc, toConcat, static_cast<uint64_t>(inputType.getRank()));
|
||||
SmallVector<int64_t> sliceSizes(inputType.getRank(), 1);
|
||||
|
||||
|
@ -243,22 +245,22 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
|
|||
startIndexMap.push_back(i);
|
||||
}
|
||||
|
||||
auto dimsAttr = mhlo::GatherDimensionNumbersAttr::get(
|
||||
auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get(
|
||||
rewriter.getContext(),
|
||||
/*offsetDims=*/{},
|
||||
/*collapsedSliceDims=*/collapsedDims,
|
||||
/*startIndexMap=*/startIndexMap,
|
||||
/*indexVecDim=*/indexVecDim);
|
||||
|
||||
rewriter.replaceOpWithNewOp<mhlo::GatherOp>(
|
||||
rewriter.replaceOpWithNewOp<stablehlo::GatherOp>(
|
||||
op, input, gatherIndicies, dimsAttr,
|
||||
rewriter.getI64TensorAttr(sliceSizes));
|
||||
return success();
|
||||
}
|
||||
|
||||
void mlir::torch::torch_to_mhlo::populateGatherOpPatternsAndLegality(
|
||||
void mlir::torch::torch_to_stablehlo::populateGatherOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToMhloOptions &options) {
|
||||
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
|
||||
#define INSERT_ATENOP_PATTERN(AtenOp) \
|
|
@ -7,15 +7,16 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "./MhloLegalizeUtils.h"
|
||||
#include "./PopulatePatterns.h"
|
||||
#include "mhlo/IR/hlo_ops.h"
|
||||
#include "PopulatePatterns.h"
|
||||
#include "StablehloLegalizeUtils.h"
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "stablehlo/dialect/ChloOps.h"
|
||||
#include "stablehlo/dialect/StablehloOps.h"
|
||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
|
@ -25,7 +26,7 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
using namespace mlir::torch::torch_to_mhlo;
|
||||
using namespace mlir::torch::torch_to_stablehlo;
|
||||
|
||||
namespace {
|
||||
Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
|
||||
|
@ -33,7 +34,7 @@ Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
|
|||
ArrayRef<int64_t> broadcastDims) {
|
||||
auto tensorTy = tensor.getType().dyn_cast<RankedTensorType>();
|
||||
auto loc = op->getLoc();
|
||||
Value mhloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
|
||||
Value stablehloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
|
||||
|
||||
RankedTensorType outTy =
|
||||
RankedTensorType::get(shape, tensorTy.getElementType());
|
||||
|
@ -43,8 +44,8 @@ Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
|
|||
rewriter.getIntegerType(64));
|
||||
auto broadcastAttr = DenseIntElementsAttr::get(attrTy, broadcastDims);
|
||||
|
||||
auto broadcast = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
|
||||
loc, outTy, tensor, mhloShape, broadcastAttr);
|
||||
auto broadcast = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
|
||||
loc, outTy, tensor, stablehloShape, broadcastAttr);
|
||||
return broadcast;
|
||||
}
|
||||
|
||||
|
@ -52,7 +53,7 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input,
|
|||
ArrayRef<int64_t> inpTransDims) {
|
||||
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||
auto rank = inputTy.getRank();
|
||||
auto transDims = mhlo::toPositiveDims(inpTransDims, rank);
|
||||
auto transDims = hlo::toPositiveDims(inpTransDims, rank);
|
||||
auto inpShape = inputTy.getShape();
|
||||
std::vector<int64_t> newShape;
|
||||
newShape.reserve(rank);
|
||||
|
@ -66,8 +67,8 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input,
|
|||
auto permuteAttr = DenseIntElementsAttr::get(attrTy, transDims);
|
||||
|
||||
auto outTy = RankedTensorType::get(newShape, inputTy.getElementType());
|
||||
auto result = rewriter.create<mhlo::TransposeOp>(op->getLoc(), outTy, input,
|
||||
permuteAttr);
|
||||
auto result = rewriter.create<stablehlo::TransposeOp>(op->getLoc(), outTy,
|
||||
input, permuteAttr);
|
||||
return result.getResult();
|
||||
}
|
||||
|
||||
|
@ -119,10 +120,12 @@ RankedTensorType castContractingDim(PatternRewriter &rewriter, Operation *op,
|
|||
}
|
||||
|
||||
// set result dimensions
|
||||
if (lhsResultDim < static_cast<int64_t>(lhsShape.size()) && lhsResultDim >= 0) {
|
||||
if (lhsResultDim < static_cast<int64_t>(lhsShape.size()) &&
|
||||
lhsResultDim >= 0) {
|
||||
outShape.push_back(lhsShape[lhsResultDim]);
|
||||
}
|
||||
if (rhsResultDim < static_cast<int64_t>(rhsShape.size()) && rhsResultDim >= 0) {
|
||||
if (rhsResultDim < static_cast<int64_t>(rhsShape.size()) &&
|
||||
rhsResultDim >= 0) {
|
||||
outShape.push_back(rhsShape[rhsResultDim]);
|
||||
}
|
||||
return RankedTensorType::get(outShape, lhsTy.getElementType());
|
||||
|
@ -151,10 +154,10 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
|
|||
std::vector<int64_t> newShape(rhsShape.begin(),
|
||||
rhsShape.begin() + leadingRank);
|
||||
newShape.insert(newShape.end(), lhsShape.begin(), lhsShape.end());
|
||||
auto newDimSizes = *mhlo::getDimSizesOfTensor(
|
||||
rewriter, op, rhs, leadingDims, dimSizeIndexBits);
|
||||
auto newDimSizes = *hlo::getDimSizesOfTensor(rewriter, op, rhs, leadingDims,
|
||||
dimSizeIndexBits);
|
||||
auto lhsDimSizes =
|
||||
*mhlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits);
|
||||
*hlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits);
|
||||
newDimSizes.insert(newDimSizes.end(), lhsDimSizes.begin(),
|
||||
lhsDimSizes.end());
|
||||
lhs = getBroadcastTensor(rewriter, op, lhs, newShape, newDimSizes,
|
||||
|
@ -163,10 +166,10 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
|
|||
std::vector<int64_t> newShape(lhsShape.begin(),
|
||||
lhsShape.begin() + leadingRank);
|
||||
newShape.insert(newShape.end(), rhsShape.begin(), rhsShape.end());
|
||||
auto newDimSizes = *mhlo::getDimSizesOfTensor(
|
||||
rewriter, op, lhs, leadingDims, dimSizeIndexBits);
|
||||
auto newDimSizes = *hlo::getDimSizesOfTensor(rewriter, op, lhs, leadingDims,
|
||||
dimSizeIndexBits);
|
||||
auto rhsDimSizes =
|
||||
*mhlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits);
|
||||
*hlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits);
|
||||
newDimSizes.insert(newDimSizes.end(), rhsDimSizes.begin(),
|
||||
rhsDimSizes.end());
|
||||
rhs = getBroadcastTensor(rewriter, op, rhs, newShape, newDimSizes,
|
||||
|
@ -218,8 +221,8 @@ public:
|
|||
if (lhsRank <= 2 && rhsRank <= 2) {
|
||||
auto tensorType =
|
||||
ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType());
|
||||
output = rewriter.create<mhlo::DotOp>(op->getLoc(), tensorType, lhs, rhs,
|
||||
nullptr);
|
||||
output = rewriter.create<stablehlo::DotOp>(op->getLoc(), tensorType, lhs,
|
||||
rhs, nullptr);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -253,8 +256,8 @@ public:
|
|||
lhsContractingDim = nBatchDims;
|
||||
}
|
||||
|
||||
mhlo::DotDimensionNumbersAttr dotDimensionNumbers =
|
||||
mhlo::DotDimensionNumbersAttr::get(
|
||||
stablehlo::DotDimensionNumbersAttr dotDimensionNumbers =
|
||||
stablehlo::DotDimensionNumbersAttr::get(
|
||||
rewriter.getContext(),
|
||||
/*lhsBatchingDimensions=*/batchDims,
|
||||
/*rhsBatchingDimensions=*/batchDims,
|
||||
|
@ -264,8 +267,8 @@ public:
|
|||
castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim,
|
||||
lhsContractingDim, rhsContractingDim);
|
||||
output = rewriter
|
||||
.create<mhlo::DotGeneralOp>(op->getLoc(), outTy, lhs, rhs,
|
||||
dotDimensionNumbers, nullptr)
|
||||
.create<stablehlo::DotGeneralOp>(op->getLoc(), outTy, lhs, rhs,
|
||||
dotDimensionNumbers, nullptr)
|
||||
.getResult();
|
||||
return success();
|
||||
}
|
||||
|
@ -312,7 +315,7 @@ public:
|
|||
|
||||
if (!lhsTy || !rhsTy)
|
||||
return op.emitError(
|
||||
"only ranked tensor types are supported in MHLO matmul");
|
||||
"only ranked tensor types are supported in StableHLO matmul");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -335,7 +338,7 @@ public:
|
|||
|
||||
if (!lhsTy || !rhsTy)
|
||||
return op.emitError(
|
||||
"only ranked tensor types are supported in MHLO matmul");
|
||||
"only ranked tensor types are supported in StableHLO matmul");
|
||||
|
||||
auto lhsRank = lhsTy.getRank();
|
||||
auto rhsRank = rhsTy.getRank();
|
||||
|
@ -371,7 +374,7 @@ public:
|
|||
|
||||
if (!lhsTy || !rhsTy)
|
||||
return op.emitError(
|
||||
"only ranked tensor types are supported in MHLO matmul");
|
||||
"only ranked tensor types are supported in StableHLO matmul");
|
||||
|
||||
auto lhsRank = lhsTy.getRank();
|
||||
auto rhsRank = rhsTy.getRank();
|
||||
|
@ -398,10 +401,10 @@ public:
|
|||
auto bias = adaptor.getBias();
|
||||
auto biasTy = bias.getType();
|
||||
|
||||
// MHLO does not mandate that elementwise op tensors need to be ranked.
|
||||
// StableHLO does not mandate that elementwise op tensors need to be ranked.
|
||||
if (!biasTy.template isa<Torch::NoneType>() &&
|
||||
!biasTy.template isa<RankedTensorType>())
|
||||
return op.emitError("only ranked tensor types are supported in MHLO "
|
||||
return op.emitError("only ranked tensor types are supported in StableHLO "
|
||||
"matmul for bias tensor");
|
||||
|
||||
// weight.T
|
||||
|
@ -427,14 +430,14 @@ public:
|
|||
auto outTy =
|
||||
castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim,
|
||||
lhsContractingDim, rhsContractingDim);
|
||||
mhlo::DotDimensionNumbersAttr dotDimensionNumbers =
|
||||
mhlo::DotDimensionNumbersAttr::get(
|
||||
stablehlo::DotDimensionNumbersAttr dotDimensionNumbers =
|
||||
stablehlo::DotDimensionNumbersAttr::get(
|
||||
rewriter.getContext(),
|
||||
/*lhsBatchingDimensions=*/batchDims,
|
||||
/*rhsBatchingDimensions=*/batchDims,
|
||||
/*lhsContractingDimensions=*/{lhsContractingDim},
|
||||
/*rhsContractingDimensions=*/{rhsContractingDim});
|
||||
Value matmulOutput = rewriter.create<mhlo::DotGeneralOp>(
|
||||
Value matmulOutput = rewriter.create<stablehlo::DotGeneralOp>(
|
||||
op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr);
|
||||
|
||||
Value matmulPlusBias = matmulOutput;
|
||||
|
@ -464,7 +467,7 @@ public:
|
|||
auto weightElemTy = weightTy.getElementType();
|
||||
auto rank = weightTy.getRank();
|
||||
const auto &options = getOptions();
|
||||
SmallVector<Value> weightShapeVec = *mhlo::getDimSizesOfTensor(
|
||||
SmallVector<Value> weightShapeVec = *hlo::getDimSizesOfTensor(
|
||||
rewriter, op, weight, options.dimSizeIndexBits);
|
||||
auto weightShape = weightTy.getShape();
|
||||
SmallVector<int64_t> weightShapeInt(rank);
|
||||
|
@ -488,7 +491,7 @@ public:
|
|||
}
|
||||
Value weightShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
||||
op->getLoc(), weightShapeVec);
|
||||
weight = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
weight = rewriter.create<stablehlo::DynamicReshapeOp>(
|
||||
op->getLoc(), RankedTensorType::get(weightShapeInt, weightElemTy),
|
||||
weight, weightShapeTensor);
|
||||
|
||||
|
@ -497,7 +500,7 @@ public:
|
|||
for (int64_t i = 0; i <= rank; i++)
|
||||
transposeDims[i] = i;
|
||||
std::swap(transposeDims[1], transposeDims[0]);
|
||||
weight = rewriter.create<mhlo::TransposeOp>(
|
||||
weight = rewriter.create<stablehlo::TransposeOp>(
|
||||
op->getLoc(), weight, rewriter.getI64TensorAttr(transposeDims));
|
||||
|
||||
// 3. [IC//G, G, OC, H, W, ...] => [IC//G, G*OC, H, W, ...]
|
||||
|
@ -509,7 +512,7 @@ public:
|
|||
weightShapeVec[1] = OCMulGValue;
|
||||
weightShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
||||
op->getLoc(), weightShapeVec);
|
||||
weight = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
weight = rewriter.create<stablehlo::DynamicReshapeOp>(
|
||||
op->getLoc(), RankedTensorType::get(weightShapeInt, weightElemTy),
|
||||
weight, weightShapeTensor);
|
||||
return weight;
|
||||
|
@ -544,25 +547,27 @@ public:
|
|||
}
|
||||
|
||||
// Prepare for transposed convolution
|
||||
SmallVector<int64_t> mhloStrideVec(nSpatialDims, 1);
|
||||
DenseIntElementsAttr mhloStride = rewriter.getI64TensorAttr(mhloStrideVec);
|
||||
SmallVector<int64_t> mhloPaddingVec(nSpatialDims * 2, 0);
|
||||
SmallVector<int64_t> stablehloStrideVec(nSpatialDims, 1);
|
||||
DenseIntElementsAttr stablehloStride =
|
||||
rewriter.getI64TensorAttr(stablehloStrideVec);
|
||||
SmallVector<int64_t> stablehloPaddingVec(nSpatialDims * 2, 0);
|
||||
for (int i = 0; i < nSpatialDims; ++i) {
|
||||
int64_t padInt = dilation[i] * (weightShape[i + 2] - 1) - padding[i];
|
||||
mhloPaddingVec[i * 2] = padInt;
|
||||
mhloPaddingVec[i * 2 + 1] = padInt;
|
||||
stablehloPaddingVec[i * 2] = padInt;
|
||||
stablehloPaddingVec[i * 2 + 1] = padInt;
|
||||
}
|
||||
DenseIntElementsAttr mhloPadding = DenseIntElementsAttr::get(
|
||||
DenseIntElementsAttr stablehloPadding = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({nSpatialDims, 2}, rewriter.getI64Type()),
|
||||
mhloPaddingVec);
|
||||
SmallVector<int64_t> mhloLhsDilationVec(nSpatialDims);
|
||||
std::copy(stride.begin(), stride.end(), mhloLhsDilationVec.begin());
|
||||
DenseIntElementsAttr mhloLhsDilation =
|
||||
rewriter.getI64TensorAttr(mhloLhsDilationVec);
|
||||
SmallVector<int64_t> mhloRhsDilationVec(nSpatialDims);
|
||||
std::copy(dilation.begin(), dilation.end(), mhloRhsDilationVec.begin());
|
||||
DenseIntElementsAttr mhloRhsDilation =
|
||||
rewriter.getI64TensorAttr(mhloRhsDilationVec);
|
||||
stablehloPaddingVec);
|
||||
SmallVector<int64_t> stablehloLhsDilationVec(nSpatialDims);
|
||||
std::copy(stride.begin(), stride.end(), stablehloLhsDilationVec.begin());
|
||||
DenseIntElementsAttr stablehloLhsDilation =
|
||||
rewriter.getI64TensorAttr(stablehloLhsDilationVec);
|
||||
SmallVector<int64_t> stablehloRhsDilationVec(nSpatialDims);
|
||||
std::copy(dilation.begin(), dilation.end(),
|
||||
stablehloRhsDilationVec.begin());
|
||||
DenseIntElementsAttr stablehloRhsDilation =
|
||||
rewriter.getI64TensorAttr(stablehloRhsDilationVec);
|
||||
|
||||
DenseElementsAttr windowReversal;
|
||||
ArrayAttr precisionConfig;
|
||||
|
@ -571,8 +576,8 @@ public:
|
|||
for (int i = 0; i < nSpatialDims; ++i) {
|
||||
spatialDims.push_back(i + 2);
|
||||
}
|
||||
mhlo::ConvDimensionNumbersAttr dimensionNumbers =
|
||||
mhlo::ConvDimensionNumbersAttr::get(
|
||||
stablehlo::ConvDimensionNumbersAttr dimensionNumbers =
|
||||
stablehlo::ConvDimensionNumbersAttr::get(
|
||||
/*context=*/rewriter.getContext(), /*inputBatchDimension=*/0,
|
||||
/*inputFeatureDimension=*/1,
|
||||
/*inputSpatialDimensions=*/spatialDims,
|
||||
|
@ -583,17 +588,18 @@ public:
|
|||
/*outputSpatialDimensions=*/spatialDims);
|
||||
|
||||
// Reverse and transpose weight
|
||||
weight = rewriter.create<mhlo::ReverseOp>(
|
||||
weight = rewriter.create<stablehlo::ReverseOp>(
|
||||
op->getLoc(), weight, rewriter.getI64TensorAttr(spatialDims));
|
||||
if (groups != 1) {
|
||||
weight = reshapeConvWeight(rewriter, op, weight, groups);
|
||||
}
|
||||
|
||||
// Create transposed convolution
|
||||
auto transposedConvOp = rewriter.create<mhlo::ConvolutionOp>(
|
||||
op->getLoc(), convOutTy, input, weight, mhloStride, mhloPadding,
|
||||
mhloLhsDilation, mhloRhsDilation, windowReversal, dimensionNumbers,
|
||||
static_cast<uint64_t>(groups), 1, precisionConfig);
|
||||
auto transposedConvOp = rewriter.create<stablehlo::ConvolutionOp>(
|
||||
op->getLoc(), convOutTy, input, weight, stablehloStride,
|
||||
stablehloPadding, stablehloLhsDilation, stablehloRhsDilation,
|
||||
windowReversal, dimensionNumbers, static_cast<uint64_t>(groups), 1,
|
||||
precisionConfig);
|
||||
|
||||
// Handle output padding
|
||||
if (!needHandleOutputPadding) {
|
||||
|
@ -605,8 +611,8 @@ public:
|
|||
std::copy(outputPadding.begin(), outputPadding.end(),
|
||||
edgePaddingHighVec.begin() + 2);
|
||||
Value paddingValue =
|
||||
mhlo::getConstTensor<float>(rewriter, op, {0.0}, {}).value();
|
||||
paddingValue = mhlo::promoteType(rewriter, paddingValue, inputTy);
|
||||
hlo::getConstTensor<float>(rewriter, op, {0.0}, {}).value();
|
||||
paddingValue = hlo::promoteType(rewriter, paddingValue, inputTy);
|
||||
mlir::DenseIntElementsAttr edgePaddingLow =
|
||||
rewriter.getI64VectorAttr(edgePaddingLowVec);
|
||||
mlir::DenseIntElementsAttr edgePaddingHigh =
|
||||
|
@ -614,7 +620,7 @@ public:
|
|||
mlir::DenseIntElementsAttr interiorPadding =
|
||||
rewriter.getI64VectorAttr(interiorPaddingVec);
|
||||
|
||||
auto paddedOutput = rewriter.create<mhlo::PadOp>(
|
||||
auto paddedOutput = rewriter.create<stablehlo::PadOp>(
|
||||
op->getLoc(), outType, transposedConvOp, paddingValue, edgePaddingLow,
|
||||
edgePaddingHigh, interiorPadding);
|
||||
|
||||
|
@ -628,22 +634,22 @@ public:
|
|||
ArrayRef<int64_t> dilation, int64_t groups) const {
|
||||
int64_t nDims = outType.getRank();
|
||||
|
||||
// Get mhlo::ConvolutionOp attributes
|
||||
DenseIntElementsAttr mhloWindowStride = DenseIntElementsAttr::get(
|
||||
// Get stablehlo::ConvolutionOp attributes
|
||||
DenseIntElementsAttr stablehloWindowStride = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<long int>(stride.size())},
|
||||
rewriter.getI64Type()),
|
||||
stride);
|
||||
std::vector<int64_t> mhloPaddingVec;
|
||||
std::vector<int64_t> stablehloPaddingVec;
|
||||
for (size_t i = 0; i < padding.size(); i++) {
|
||||
mhloPaddingVec.emplace_back(padding[i]);
|
||||
mhloPaddingVec.emplace_back(padding[i]);
|
||||
stablehloPaddingVec.emplace_back(padding[i]);
|
||||
stablehloPaddingVec.emplace_back(padding[i]);
|
||||
}
|
||||
DenseIntElementsAttr mhloPadding = DenseIntElementsAttr::get(
|
||||
DenseIntElementsAttr stablehloPadding = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(
|
||||
{static_cast<long int>(padding.size()), static_cast<long int>(2)},
|
||||
rewriter.getI64Type()),
|
||||
mhloPaddingVec);
|
||||
DenseIntElementsAttr mhloRhsDilation = DenseIntElementsAttr::get(
|
||||
stablehloPaddingVec);
|
||||
DenseIntElementsAttr stablehloRhsDilation = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<long int>(dilation.size())},
|
||||
rewriter.getI64Type()),
|
||||
dilation);
|
||||
|
@ -651,8 +657,8 @@ public:
|
|||
for (int64_t i = 2; i < nDims; i++) {
|
||||
spatialDimensions.emplace_back(i);
|
||||
}
|
||||
mhlo::ConvDimensionNumbersAttr dimensionNumbers =
|
||||
mhlo::ConvDimensionNumbersAttr::get(
|
||||
stablehlo::ConvDimensionNumbersAttr dimensionNumbers =
|
||||
stablehlo::ConvDimensionNumbersAttr::get(
|
||||
/*context=*/rewriter.getContext(), /*inputBatchDimension=*/0,
|
||||
/*inputFeatureDimension=*/1,
|
||||
/*inputSpatialDimensions=*/spatialDimensions,
|
||||
|
@ -662,17 +668,18 @@ public:
|
|||
/*outputBatchDimension=*/0, /*outputFeatureDimension=*/1,
|
||||
/*outputSpatialDimensions=*/spatialDimensions);
|
||||
|
||||
// mhlo::ConvolutionOp's optional attributes, leave them as default
|
||||
DenseIntElementsAttr mhloLhsDilation;
|
||||
// stablehlo::ConvolutionOp's optional attributes, leave them as default
|
||||
DenseIntElementsAttr stablehloLhsDilation;
|
||||
DenseElementsAttr windowReversal;
|
||||
ArrayAttr precisionConfig;
|
||||
|
||||
auto mhloConvOp = rewriter.create<mhlo::ConvolutionOp>(
|
||||
op->getLoc(), outType, input, weight, mhloWindowStride, mhloPadding,
|
||||
mhloLhsDilation, mhloRhsDilation, windowReversal, dimensionNumbers,
|
||||
static_cast<uint64_t>(groups), 1, precisionConfig);
|
||||
auto stablehloConvOp = rewriter.create<stablehlo::ConvolutionOp>(
|
||||
op->getLoc(), outType, input, weight, stablehloWindowStride,
|
||||
stablehloPadding, stablehloLhsDilation, stablehloRhsDilation,
|
||||
windowReversal, dimensionNumbers, static_cast<uint64_t>(groups), 1,
|
||||
precisionConfig);
|
||||
|
||||
return mhloConvOp.getResult();
|
||||
return stablehloConvOp.getResult();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
|
@ -754,21 +761,22 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
Value mhloConvResult;
|
||||
Value stablehloConvResult;
|
||||
if (transposed) {
|
||||
mhloConvResult = convertTransposedConv(
|
||||
stablehloConvResult = convertTransposedConv(
|
||||
op, rewriter, outTy, input, weight, stride, padding, dilation,
|
||||
outputPadding, groups, needHandleOutputPadding);
|
||||
} else {
|
||||
mhloConvResult = convertNormalConv(op, rewriter, outTy, input, weight,
|
||||
stride, padding, dilation, groups);
|
||||
stablehloConvResult =
|
||||
convertNormalConv(op, rewriter, outTy, input, weight, stride, padding,
|
||||
dilation, groups);
|
||||
}
|
||||
|
||||
auto bias = adaptor.getBias();
|
||||
|
||||
// No bias provided
|
||||
if (failed(checkNotNone(rewriter, op, op.getBias()))) {
|
||||
rewriter.replaceOp(op, mhloConvResult);
|
||||
rewriter.replaceOp(op, stablehloConvResult);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -790,21 +798,21 @@ public:
|
|||
llvm::to_vector<4>(llvm::seq<int64_t>(-nSpatialDims, 0));
|
||||
|
||||
const auto &options = getOptions();
|
||||
bias = *mhlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims,
|
||||
options.dimSizeIndexBits);
|
||||
bias = mhlo::promoteType(rewriter, bias, outTy);
|
||||
bias = *hlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims,
|
||||
options.dimSizeIndexBits);
|
||||
bias = hlo::promoteType(rewriter, bias, outTy);
|
||||
|
||||
DenseIntElementsAttr bcastDimensions;
|
||||
rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(op, outTy, mhloConvResult,
|
||||
bias, bcastDimensions);
|
||||
rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
|
||||
op, outTy, stablehloConvResult, bias, bcastDimensions);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::torch::torch_to_mhlo::populateLinearOpPatternsAndLegality(
|
||||
void mlir::torch::torch_to_stablehlo::populateLinearOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToMhloOptions &options) {
|
||||
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
|
||||
#define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \
|
|
@ -7,15 +7,16 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "./MhloLegalizeUtils.h"
|
||||
#include "./PopulatePatterns.h"
|
||||
#include "mhlo/IR/hlo_ops.h"
|
||||
#include "PopulatePatterns.h"
|
||||
#include "StablehloLegalizeUtils.h"
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "stablehlo/dialect/ChloOps.h"
|
||||
#include "stablehlo/dialect/StablehloOps.h"
|
||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
|
@ -28,26 +29,26 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
using namespace mlir::torch::torch_to_mhlo;
|
||||
using namespace mlir::torch::torch_to_stablehlo;
|
||||
|
||||
static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
|
||||
PatternRewriter &rewriter) {
|
||||
auto constType = RankedTensorType::get({}, elementTy);
|
||||
// Avg pooling
|
||||
if (isa<AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp>(op)) {
|
||||
if (isa<AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp, AtenCumsumOp>(op)) {
|
||||
if (elementTy.isa<mlir::FloatType>()) {
|
||||
auto constAttr = DenseElementsAttr::get(
|
||||
constType, {APFloat::getZero(
|
||||
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
|
||||
/*negative=*/false)});
|
||||
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
} else if (elementTy.isa<mlir::IntegerType>() &&
|
||||
elementTy.getIntOrFloatBitWidth() != 8) {
|
||||
auto constAttr = DenseElementsAttr::get(
|
||||
constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
|
||||
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -58,15 +59,15 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
|
|||
constType, {APFloat::getLargest(
|
||||
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
|
||||
/*negative=*/true)});
|
||||
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
} else if (elementTy.isa<mlir::IntegerType>() &&
|
||||
elementTy.getIntOrFloatBitWidth() != 8) {
|
||||
auto constAttr = DenseElementsAttr::get(
|
||||
constType,
|
||||
{APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())});
|
||||
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
}
|
||||
}
|
||||
op->emitError("unimplemented lowering in AtenPoolingOp");
|
||||
|
@ -116,42 +117,43 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dOp>::matchAndRewrite(
|
|||
|
||||
// prepend 1 to kernelSize, stride, dilation until they are of same rank as
|
||||
// input
|
||||
SmallVector<int64_t> mhloStride(inputRank, 1);
|
||||
SmallVector<int64_t> mhloDilation(inputRank, 1);
|
||||
SmallVector<int64_t> mhloKernelSize(inputRank, 1);
|
||||
SmallVector<int64_t> mhloPadding(inputRank * 2, 0);
|
||||
SmallVector<int64_t> stablehloStride(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloDilation(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
|
||||
std::copy(dilation.begin(), dilation.end(),
|
||||
mhloDilation.begin() + inputRank - 2);
|
||||
std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2);
|
||||
stablehloDilation.begin() + inputRank - 2);
|
||||
std::copy(stride.begin(), stride.end(),
|
||||
stablehloStride.begin() + inputRank - 2);
|
||||
std::copy(kernelSize.begin(), kernelSize.end(),
|
||||
mhloKernelSize.begin() + inputRank - 2);
|
||||
stablehloKernelSize.begin() + inputRank - 2);
|
||||
|
||||
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
|
||||
|
||||
mhloPadding[mhloPadding.size() - 4] = padding[0];
|
||||
mhloPadding[mhloPadding.size() - 3] = padding[0];
|
||||
mhloPadding[mhloPadding.size() - 2] = padding[1];
|
||||
mhloPadding[mhloPadding.size() - 1] = padding[1];
|
||||
stablehloPadding[stablehloPadding.size() - 4] = padding[0];
|
||||
stablehloPadding[stablehloPadding.size() - 3] = padding[0];
|
||||
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
|
||||
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
|
||||
|
||||
DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())},
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
|
||||
rewriter.getI64Type()),
|
||||
mhloKernelSize);
|
||||
stablehloKernelSize);
|
||||
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(mhloStride.size())},
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
|
||||
rewriter.getI64Type()),
|
||||
mhloStride);
|
||||
stablehloStride);
|
||||
DenseIntElementsAttr baseDilations;
|
||||
DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())},
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
|
||||
rewriter.getI64Type()),
|
||||
mhloDilation);
|
||||
stablehloDilation);
|
||||
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(
|
||||
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
|
||||
rewriter.getI64Type()),
|
||||
mhloPadding);
|
||||
auto reduceWindowOp = rewriter.create<mhlo::ReduceWindowOp>(
|
||||
stablehloPadding);
|
||||
auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
|
||||
op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
|
||||
baseDilations, windowDilations, pad);
|
||||
|
||||
|
@ -168,8 +170,8 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dOp>::matchAndRewrite(
|
|||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(&block);
|
||||
Value result =
|
||||
rewriter.create<mhlo::MaxOp>(op->getLoc(), *firstArg, *secondArg);
|
||||
rewriter.create<mhlo::ReturnOp>(op->getLoc(), result);
|
||||
rewriter.create<stablehlo::MaxOp>(op->getLoc(), *firstArg, *secondArg);
|
||||
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), result);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, reduceWindowOp.getResults());
|
||||
|
@ -221,45 +223,46 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
|||
|
||||
// prepend 1 to kernelSize, stride, dilation until they are of same rank as
|
||||
// input
|
||||
SmallVector<int64_t> mhloStride(inputRank, 1);
|
||||
SmallVector<int64_t> mhloDilation(inputRank, 1);
|
||||
SmallVector<int64_t> mhloKernelSize(inputRank, 1);
|
||||
SmallVector<int64_t> mhloPadding(inputRank * 2, 0);
|
||||
SmallVector<int64_t> stablehloStride(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloDilation(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
|
||||
std::copy(dilation.begin(), dilation.end(),
|
||||
mhloDilation.begin() + inputRank - 2);
|
||||
std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2);
|
||||
stablehloDilation.begin() + inputRank - 2);
|
||||
std::copy(stride.begin(), stride.end(),
|
||||
stablehloStride.begin() + inputRank - 2);
|
||||
std::copy(kernelSize.begin(), kernelSize.end(),
|
||||
mhloKernelSize.begin() + inputRank - 2);
|
||||
stablehloKernelSize.begin() + inputRank - 2);
|
||||
|
||||
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
|
||||
|
||||
mhloPadding[mhloPadding.size() - 4] = padding[0];
|
||||
mhloPadding[mhloPadding.size() - 3] = padding[0];
|
||||
mhloPadding[mhloPadding.size() - 2] = padding[1];
|
||||
mhloPadding[mhloPadding.size() - 1] = padding[1];
|
||||
stablehloPadding[stablehloPadding.size() - 4] = padding[0];
|
||||
stablehloPadding[stablehloPadding.size() - 3] = padding[0];
|
||||
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
|
||||
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
|
||||
|
||||
DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())},
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
|
||||
rewriter.getI64Type()),
|
||||
mhloKernelSize);
|
||||
stablehloKernelSize);
|
||||
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(mhloStride.size())},
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
|
||||
rewriter.getI64Type()),
|
||||
mhloStride);
|
||||
stablehloStride);
|
||||
DenseIntElementsAttr baseDilations;
|
||||
DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())},
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
|
||||
rewriter.getI64Type()),
|
||||
mhloDilation);
|
||||
stablehloDilation);
|
||||
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(
|
||||
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
|
||||
rewriter.getI64Type()),
|
||||
mhloPadding);
|
||||
stablehloPadding);
|
||||
|
||||
const auto &options = getOptions();
|
||||
auto inputShapeInfo =
|
||||
mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
||||
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
||||
if (failed(inputShapeInfo)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to get dimension sizes of the input");
|
||||
|
@ -289,7 +292,7 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
|||
|
||||
auto initIndexTensor =
|
||||
rewriter
|
||||
.create<mhlo::DynamicIotaOp>(
|
||||
.create<stablehlo::DynamicIotaOp>(
|
||||
op->getLoc(),
|
||||
RankedTensorType::get(initIndexShapeForType,
|
||||
rewriter.getI64Type()),
|
||||
|
@ -298,15 +301,15 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
|||
|
||||
auto indexTensor =
|
||||
rewriter
|
||||
.create<mhlo::DynamicReshapeOp>(
|
||||
.create<stablehlo::DynamicReshapeOp>(
|
||||
op->getLoc(),
|
||||
RankedTensorType::get(inputShape, rewriter.getI64Type()),
|
||||
initIndexTensor, inputShapeTensor)
|
||||
.getResult();
|
||||
|
||||
Value initIdx = mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
|
||||
Value initIdx = hlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
|
||||
|
||||
auto reduceWindowOp = rewriter.create<mhlo::ReduceWindowOp>(
|
||||
auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
|
||||
op->getLoc(), mlir::TypeRange{outValTy, outIdxTy},
|
||||
mlir::ValueRange{input, indexTensor}, mlir::ValueRange{initVal, initIdx},
|
||||
windowDimensions, windowStrides, baseDilations, windowDilations, pad);
|
||||
|
@ -326,43 +329,43 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
|||
auto *secondValArg = std::next(firstIdxArg);
|
||||
auto *secondIdxArg = std::next(secondValArg);
|
||||
|
||||
mhlo::ComparisonTypeAttr compareTypeAttr;
|
||||
stablehlo::ComparisonTypeAttr compareTypeAttr;
|
||||
if (inputTy.getElementType().isa<mlir::FloatType>()) {
|
||||
compareTypeAttr = mhlo::ComparisonTypeAttr::get(
|
||||
rewriter.getContext(), mhlo::ComparisonType::FLOAT);
|
||||
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
|
||||
rewriter.getContext(), stablehlo::ComparisonType::FLOAT);
|
||||
} else if (inputTy.getElementType().isa<mlir::IntegerType>()) {
|
||||
compareTypeAttr = mhlo::ComparisonTypeAttr::get(
|
||||
rewriter.getContext(), mhlo::ComparisonType::SIGNED);
|
||||
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
|
||||
rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
|
||||
}
|
||||
mhlo::ComparisonDirectionAttr compareGeDirectionAttr =
|
||||
mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
|
||||
mhlo::ComparisonDirection::GE);
|
||||
mhlo::ComparisonDirectionAttr compareEqDirectionAttr =
|
||||
mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
|
||||
mhlo::ComparisonDirection::EQ);
|
||||
stablehlo::ComparisonDirectionAttr compareGeDirectionAttr =
|
||||
stablehlo::ComparisonDirectionAttr::get(
|
||||
rewriter.getContext(), stablehlo::ComparisonDirection::GE);
|
||||
stablehlo::ComparisonDirectionAttr compareEqDirectionAttr =
|
||||
stablehlo::ComparisonDirectionAttr::get(
|
||||
rewriter.getContext(), stablehlo::ComparisonDirection::EQ);
|
||||
|
||||
{
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(&block);
|
||||
|
||||
Value compareGeResult = rewriter.create<mhlo::CompareOp>(
|
||||
Value compareGeResult = rewriter.create<stablehlo::CompareOp>(
|
||||
op->getLoc(), compareResultType, *firstValArg, *secondValArg,
|
||||
compareGeDirectionAttr, compareTypeAttr);
|
||||
Value retValResult = rewriter.create<mhlo::SelectOp>(
|
||||
Value retValResult = rewriter.create<stablehlo::SelectOp>(
|
||||
op->getLoc(), compareGeResult, *firstValArg, *secondValArg);
|
||||
|
||||
// Get smaller index if compared values are equal.
|
||||
Value compareEqResult = rewriter.create<mhlo::CompareOp>(
|
||||
Value compareEqResult = rewriter.create<stablehlo::CompareOp>(
|
||||
op->getLoc(), compareResultType, *firstValArg, *secondValArg,
|
||||
compareEqDirectionAttr, compareTypeAttr);
|
||||
Value minIdx =
|
||||
rewriter.create<mhlo::MinOp>(op->getLoc(), *firstIdxArg, *secondIdxArg);
|
||||
Value idxWithGeVal = rewriter.create<mhlo::SelectOp>(
|
||||
Value minIdx = rewriter.create<stablehlo::MinOp>(op->getLoc(), *firstIdxArg,
|
||||
*secondIdxArg);
|
||||
Value idxWithGeVal = rewriter.create<stablehlo::SelectOp>(
|
||||
op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg);
|
||||
Value retIdxResult = rewriter.create<mhlo::SelectOp>(
|
||||
Value retIdxResult = rewriter.create<stablehlo::SelectOp>(
|
||||
op->getLoc(), compareEqResult, minIdx, idxWithGeVal);
|
||||
|
||||
rewriter.create<mhlo::ReturnOp>(
|
||||
rewriter.create<stablehlo::ReturnOp>(
|
||||
op->getLoc(), mlir::ValueRange{retValResult, retIdxResult});
|
||||
}
|
||||
|
||||
|
@ -419,41 +422,42 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
|
|||
|
||||
// prepend 1 to kernelSize, stride, dilation until they are of same rank as
|
||||
// input
|
||||
SmallVector<int64_t> mhloStride(inputRank, 1);
|
||||
SmallVector<int64_t> mhloDilation(inputRank, 1);
|
||||
SmallVector<int64_t> mhloKernelSize(inputRank, 1);
|
||||
SmallVector<int64_t> mhloPadding(inputRank * 2, 0);
|
||||
SmallVector<int64_t> stablehloStride(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloDilation(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
|
||||
|
||||
std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2);
|
||||
std::copy(stride.begin(), stride.end(),
|
||||
stablehloStride.begin() + inputRank - 2);
|
||||
std::copy(kernelSize.begin(), kernelSize.end(),
|
||||
mhloKernelSize.begin() + inputRank - 2);
|
||||
mhloPadding[mhloPadding.size() - 4] = padding[0];
|
||||
mhloPadding[mhloPadding.size() - 3] = padding[0];
|
||||
mhloPadding[mhloPadding.size() - 2] = padding[1];
|
||||
mhloPadding[mhloPadding.size() - 1] = padding[1];
|
||||
stablehloKernelSize.begin() + inputRank - 2);
|
||||
stablehloPadding[stablehloPadding.size() - 4] = padding[0];
|
||||
stablehloPadding[stablehloPadding.size() - 3] = padding[0];
|
||||
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
|
||||
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
|
||||
|
||||
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
|
||||
|
||||
DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())},
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
|
||||
rewriter.getI64Type()),
|
||||
mhloKernelSize);
|
||||
stablehloKernelSize);
|
||||
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(mhloStride.size())},
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
|
||||
rewriter.getI64Type()),
|
||||
mhloStride);
|
||||
stablehloStride);
|
||||
DenseIntElementsAttr baseDilations;
|
||||
DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())},
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
|
||||
rewriter.getI64Type()),
|
||||
mhloDilation);
|
||||
stablehloDilation);
|
||||
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(
|
||||
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
|
||||
rewriter.getI64Type()),
|
||||
mhloPadding);
|
||||
stablehloPadding);
|
||||
|
||||
auto reduceWindowSum = rewriter.create<mhlo::ReduceWindowOp>(
|
||||
auto reduceWindowSum = rewriter.create<stablehlo::ReduceWindowOp>(
|
||||
op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
|
||||
baseDilations, windowDilations, pad);
|
||||
|
||||
|
@ -471,39 +475,39 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
|
|||
rewriter.setInsertionPointToStart(&sumBlock);
|
||||
|
||||
Value sumResult =
|
||||
rewriter.create<mhlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
|
||||
rewriter.create<mhlo::ReturnOp>(op->getLoc(), sumResult);
|
||||
rewriter.create<stablehlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
|
||||
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult);
|
||||
}
|
||||
|
||||
// Use kernel size as the divisor
|
||||
if (countIncludePad) {
|
||||
Value divisor = mhlo::getConstTensor<int64_t>(
|
||||
Value divisor = hlo::getConstTensor<int64_t>(
|
||||
rewriter, op, {kernelSize[0] * kernelSize[1]}, {})
|
||||
.value();
|
||||
divisor = mhlo::promoteType(rewriter, divisor, outTy);
|
||||
divisor = hlo::promoteType(rewriter, divisor, outTy);
|
||||
DenseIntElementsAttr bcastDimensions;
|
||||
rewriter.replaceOpWithNewOp<mlir::chlo::BroadcastDivOp>(
|
||||
op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Use another mhlo.ReduceWindowOp to get the divisor
|
||||
// Use another stablehlo.ReduceWindowOp to get the divisor
|
||||
Value windowSizeConst =
|
||||
mhlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
|
||||
windowSizeConst = mhlo::promoteType(rewriter, windowSizeConst, outTy);
|
||||
hlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
|
||||
windowSizeConst = hlo::promoteType(rewriter, windowSizeConst, outTy);
|
||||
const auto &options = getOptions();
|
||||
auto inputShapeVec =
|
||||
*mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
||||
*hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
||||
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
||||
op->getLoc(), inputShapeVec);
|
||||
|
||||
windowSizeConst = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
|
||||
windowSizeConst = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
|
||||
op->getLoc(),
|
||||
RankedTensorType::get(inputTy.getShape(), outTy.getElementType()),
|
||||
windowSizeConst, inputShapeTensor, rewriter.getI64TensorAttr({}));
|
||||
|
||||
Value zero = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
|
||||
auto reduceWindowSize = rewriter.create<mhlo::ReduceWindowOp>(
|
||||
auto reduceWindowSize = rewriter.create<stablehlo::ReduceWindowOp>(
|
||||
op->getLoc(), RankedTensorType::get(outShape, inputElemTy),
|
||||
windowSizeConst, zero, windowDimensions, windowStrides, baseDilations,
|
||||
windowDilations, pad);
|
||||
|
@ -522,18 +526,99 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
|
|||
rewriter.setInsertionPointToStart(&sizeBlock);
|
||||
|
||||
Value sumResult =
|
||||
rewriter.create<mhlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
|
||||
rewriter.create<mhlo::ReturnOp>(op->getLoc(), sumResult);
|
||||
rewriter.create<stablehlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
|
||||
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult);
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<mhlo::DivOp>(
|
||||
rewriter.replaceOpWithNewOp<stablehlo::DivOp>(
|
||||
op, outTy, reduceWindowSum.getResult(0), reduceWindowSize.getResult(0));
|
||||
return success();
|
||||
}
|
||||
|
||||
void mlir::torch::torch_to_mhlo::populatePoolingOpPatternsAndLegality(
|
||||
// AtenCumsumOp
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite(
|
||||
AtenCumsumOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputTy = input.getType().cast<RankedTensorType>();
|
||||
auto inputElemTy = inputTy.getElementType();
|
||||
auto inputRank = inputTy.getRank();
|
||||
auto inputShape = inputTy.getShape();
|
||||
auto outTy =
|
||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||
|
||||
int64_t dim;
|
||||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: dim must be a constant int");
|
||||
}
|
||||
dim = toPositiveDim(dim, inputRank);
|
||||
if (!isValidDim(dim, inputRank)) {
|
||||
return rewriter.notifyMatchFailure(op, "dim is out of range");
|
||||
}
|
||||
if (inputTy.isDynamicDim(dim)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: cumsum dim must be static");
|
||||
}
|
||||
|
||||
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
|
||||
|
||||
SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
|
||||
stablehloKernelSize[dim] = inputShape[dim];
|
||||
SmallVector<int64_t> stablehloStride(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloDilation(inputRank, 1);
|
||||
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
|
||||
stablehloPadding[dim * 2] = inputShape[dim] - 1;
|
||||
|
||||
DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
|
||||
rewriter.getI64Type()),
|
||||
stablehloKernelSize);
|
||||
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
|
||||
rewriter.getI64Type()),
|
||||
stablehloStride);
|
||||
DenseIntElementsAttr baseDilations;
|
||||
DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
|
||||
rewriter.getI64Type()),
|
||||
stablehloDilation);
|
||||
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(
|
||||
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
|
||||
rewriter.getI64Type()),
|
||||
stablehloPadding);
|
||||
|
||||
auto reduceWindowSum = rewriter.create<stablehlo::ReduceWindowOp>(
|
||||
op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
|
||||
baseDilations, windowDilations, pad);
|
||||
|
||||
Block &sumBlock = reduceWindowSum.getBody().emplaceBlock();
|
||||
|
||||
// Add bb argument
|
||||
auto blockArgumentType = RankedTensorType::get({}, inputElemTy);
|
||||
sumBlock.addArgument(blockArgumentType, op->getLoc());
|
||||
sumBlock.addArgument(blockArgumentType, op->getLoc());
|
||||
auto *firstArg = sumBlock.args_begin();
|
||||
auto *secondArg = std::next(firstArg);
|
||||
|
||||
{
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(&sumBlock);
|
||||
|
||||
Value sumResult =
|
||||
rewriter.create<stablehlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
|
||||
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, reduceWindowSum.getResults());
|
||||
return success();
|
||||
}
|
||||
|
||||
void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToMhloOptions &options) {
|
||||
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
target.addIllegalOp<AtenMaxPool2dOp>();
|
||||
patterns.add<ConvertAtenOp<AtenMaxPool2dOp>>(typeConverter, context, options);
|
||||
|
@ -542,4 +627,6 @@ void mlir::torch::torch_to_mhlo::populatePoolingOpPatternsAndLegality(
|
|||
target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
|
||||
patterns.add<ConvertAtenOp<AtenMaxPool2dWithIndicesOp>>(typeConverter,
|
||||
context, options);
|
||||
target.addIllegalOp<AtenCumsumOp>();
|
||||
patterns.add<ConvertAtenOp<AtenCumsumOp>>(typeConverter, context, options);
|
||||
}
|
|
@ -0,0 +1,69 @@
|
|||
//===------------------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef TORCHMLIR_LIB_CONVERSION_TORCHTOSTABLEHLO_POPULATEPATTERNS_H
|
||||
#define TORCHMLIR_LIB_CONVERSION_TORCHTOSTABLEHLO_POPULATEPATTERNS_H
|
||||
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace torch {
|
||||
namespace torch_to_stablehlo {
|
||||
|
||||
struct TorchToStablehloOptions {
|
||||
bool enableStaticShape = false;
|
||||
size_t dimSizeIndexBits = 64;
|
||||
};
|
||||
|
||||
template <typename AtenOpT>
|
||||
class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||
ConvertAtenOp(TypeConverter &typeConverter, MLIRContext *context,
|
||||
const TorchToStablehloOptions &options)
|
||||
: OpConversionPattern<AtenOpT>(typeConverter, context) {
|
||||
this->options = options;
|
||||
}
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
return rewriter.notifyMatchFailure(op, "haven't been implemented");
|
||||
}
|
||||
const TorchToStablehloOptions &getOptions() const { return options; }
|
||||
|
||||
private:
|
||||
TorchToStablehloOptions options;
|
||||
};
|
||||
|
||||
void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
ConversionTarget &target,
|
||||
const TorchToStablehloOptions &options);
|
||||
void populateViewLikeOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToStablehloOptions &options);
|
||||
void populateGatherOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToStablehloOptions &options);
|
||||
void populateReductionOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToStablehloOptions &options);
|
||||
void populateLinearOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToStablehloOptions &options);
|
||||
|
||||
void populatePoolingOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToStablehloOptions &options);
|
||||
|
||||
} // namespace torch_to_stablehlo
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TORCHMLIR_LIB_CONVERSION_TORCHTOSTABLEHLO_POPULATEPATTERNS_H
|
|
@ -7,14 +7,15 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "./MhloLegalizeUtils.h"
|
||||
#include "./PopulatePatterns.h"
|
||||
#include "mhlo/IR/hlo_ops.h"
|
||||
#include "PopulatePatterns.h"
|
||||
#include "StablehloLegalizeUtils.h"
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "stablehlo/dialect/StablehloOps.h"
|
||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
|
@ -25,7 +26,7 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
using namespace mlir::torch::torch_to_mhlo;
|
||||
using namespace mlir::torch::torch_to_stablehlo;
|
||||
|
||||
static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
|
||||
PatternRewriter &rewriter) {
|
||||
|
@ -36,14 +37,14 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
|
|||
constType, {APFloat::getZero(
|
||||
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
|
||||
/*negative=*/false)});
|
||||
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
} else if (elementTy.isa<mlir::IntegerType>() &&
|
||||
elementTy.getIntOrFloatBitWidth() != 8) {
|
||||
auto constAttr = DenseElementsAttr::get(
|
||||
constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
|
||||
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -53,15 +54,15 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
|
|||
constType, {APFloat::getLargest(
|
||||
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
|
||||
/*negative=*/true)});
|
||||
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
} else if (elementTy.isa<mlir::IntegerType>() &&
|
||||
elementTy.getIntOrFloatBitWidth() != 8) {
|
||||
auto constAttr = DenseElementsAttr::get(
|
||||
constType,
|
||||
{APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())});
|
||||
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
|
||||
constAttr);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -90,9 +91,9 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
|
|||
return std::nullopt;
|
||||
Value initIndex;
|
||||
if (dimSizeIndexBits == 32) {
|
||||
initIndex = mhlo::getConstTensor<int32_t>(rewriter, op, {0}, {}).value();
|
||||
initIndex = hlo::getConstTensor<int32_t>(rewriter, op, {0}, {}).value();
|
||||
} else {
|
||||
initIndex = mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
|
||||
initIndex = hlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
|
||||
}
|
||||
|
||||
DenseIntElementsAttr dimensions = DenseIntElementsAttr::get(
|
||||
|
@ -100,13 +101,13 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
|
|||
|
||||
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
||||
op->getLoc(), inputShapeVec);
|
||||
auto indexTensor = rewriter.create<mhlo::DynamicIotaOp>(
|
||||
auto indexTensor = rewriter.create<stablehlo::DynamicIotaOp>(
|
||||
op->getLoc(),
|
||||
RankedTensorType::get(inputShape,
|
||||
rewriter.getIntegerType(dimSizeIndexBits)),
|
||||
inputShapeTensor, static_cast<uint64_t>(dim));
|
||||
|
||||
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
|
||||
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
|
||||
op->getLoc(), ValueRange{input, indexTensor},
|
||||
ValueRange{
|
||||
initValue,
|
||||
|
@ -114,7 +115,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
|
|||
},
|
||||
dimensions);
|
||||
|
||||
Block &block = mhloReduceOp.getBody().emplaceBlock();
|
||||
Block &block = stablehloReduceOp.getBody().emplaceBlock();
|
||||
|
||||
// Add block arguments
|
||||
auto blockValArgumentType =
|
||||
|
@ -133,46 +134,46 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
|
|||
auto *secondValArg = std::next(firstIdxArg);
|
||||
auto *secondIdxArg = std::next(secondValArg);
|
||||
|
||||
mhlo::ComparisonTypeAttr compareTypeAttr;
|
||||
stablehlo::ComparisonTypeAttr compareTypeAttr;
|
||||
if (inputTy.getElementType().isa<mlir::FloatType>()) {
|
||||
compareTypeAttr = mhlo::ComparisonTypeAttr::get(
|
||||
rewriter.getContext(), mhlo::ComparisonType::FLOAT);
|
||||
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
|
||||
rewriter.getContext(), stablehlo::ComparisonType::FLOAT);
|
||||
} else if (inputTy.getElementType().isa<mlir::IntegerType>()) {
|
||||
compareTypeAttr = mhlo::ComparisonTypeAttr::get(
|
||||
rewriter.getContext(), mhlo::ComparisonType::SIGNED);
|
||||
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
|
||||
rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
|
||||
}
|
||||
mhlo::ComparisonDirectionAttr compareGeDirectionAttr =
|
||||
mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
|
||||
mhlo::ComparisonDirection::GE);
|
||||
mhlo::ComparisonDirectionAttr compareEqDirectionAttr =
|
||||
mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
|
||||
mhlo::ComparisonDirection::EQ);
|
||||
stablehlo::ComparisonDirectionAttr compareGeDirectionAttr =
|
||||
stablehlo::ComparisonDirectionAttr::get(
|
||||
rewriter.getContext(), stablehlo::ComparisonDirection::GE);
|
||||
stablehlo::ComparisonDirectionAttr compareEqDirectionAttr =
|
||||
stablehlo::ComparisonDirectionAttr::get(
|
||||
rewriter.getContext(), stablehlo::ComparisonDirection::EQ);
|
||||
|
||||
{
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(&block);
|
||||
|
||||
Value compareGeResult = rewriter.create<mhlo::CompareOp>(
|
||||
Value compareGeResult = rewriter.create<stablehlo::CompareOp>(
|
||||
op->getLoc(), compareResultType, *firstValArg, *secondValArg,
|
||||
compareGeDirectionAttr, compareTypeAttr);
|
||||
Value retValResult = rewriter.create<mhlo::SelectOp>(
|
||||
Value retValResult = rewriter.create<stablehlo::SelectOp>(
|
||||
op->getLoc(), compareGeResult, *firstValArg, *secondValArg);
|
||||
|
||||
// get smaller index value if compared nums are equal.
|
||||
Value compareEqResult = rewriter.create<mhlo::CompareOp>(
|
||||
Value compareEqResult = rewriter.create<stablehlo::CompareOp>(
|
||||
op->getLoc(), compareResultType, *firstValArg, *secondValArg,
|
||||
compareEqDirectionAttr, compareTypeAttr);
|
||||
Value minIdx =
|
||||
rewriter.create<mhlo::MinOp>(op->getLoc(), *firstIdxArg, *secondIdxArg);
|
||||
Value idxWithGeVal = rewriter.create<mhlo::SelectOp>(
|
||||
Value minIdx = rewriter.create<stablehlo::MinOp>(op->getLoc(), *firstIdxArg,
|
||||
*secondIdxArg);
|
||||
Value idxWithGeVal = rewriter.create<stablehlo::SelectOp>(
|
||||
op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg);
|
||||
Value retIdxResult = rewriter.create<mhlo::SelectOp>(
|
||||
Value retIdxResult = rewriter.create<stablehlo::SelectOp>(
|
||||
op->getLoc(), compareEqResult, minIdx, idxWithGeVal);
|
||||
|
||||
rewriter.create<mhlo::ReturnOp>(
|
||||
rewriter.create<stablehlo::ReturnOp>(
|
||||
op->getLoc(), mlir::ValueRange{retValResult, retIdxResult});
|
||||
}
|
||||
return mhloReduceOp.getResults();
|
||||
return stablehloReduceOp.getResults();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -196,7 +197,8 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
|
|||
Value input = adaptor.getSelf();
|
||||
auto inputTy = input.getType().template cast<RankedTensorType>();
|
||||
if (!inputTy) {
|
||||
return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only Tensor types supported in StableHLO");
|
||||
}
|
||||
|
||||
auto inputElemTy = inputTy.getElementType();
|
||||
|
@ -209,7 +211,7 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
|
|||
inputElemTy.getIntOrFloatBitWidth() == 8) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "IntegerType with bitwidth 8 unsupported in convertion from "
|
||||
"AtenArgmaxOp to MHLO");
|
||||
"AtenArgmaxOp to StableHLO");
|
||||
}
|
||||
|
||||
int64_t dim;
|
||||
|
@ -228,15 +230,15 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
|
|||
|
||||
const auto &options = getOptions();
|
||||
auto inputShapeInfo =
|
||||
mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
||||
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
||||
if (failed(inputShapeInfo)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to get dimension sizes of the input");
|
||||
}
|
||||
auto inputShapeVec = *inputShapeInfo;
|
||||
auto mhloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, dim,
|
||||
options.dimSizeIndexBits)
|
||||
.value();
|
||||
auto stablehloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec,
|
||||
dim, options.dimSizeIndexBits)
|
||||
.value();
|
||||
|
||||
if (keepDim) {
|
||||
auto outShapeVec = inputShapeVec;
|
||||
|
@ -247,13 +249,13 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
|
|||
|
||||
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
||||
op->getLoc(), outShapeVec);
|
||||
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
|
||||
op, typeConverter->convertType(op.getType()), mhloReduceResults[1],
|
||||
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
|
||||
op, typeConverter->convertType(op.getType()), stablehloReduceResults[1],
|
||||
outShapeTensor);
|
||||
return success();
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, mhloReduceResults[1]);
|
||||
rewriter.replaceOp(op, stablehloReduceResults[1]);
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
@ -267,7 +269,8 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
|
|||
Value input = adaptor.getSelf();
|
||||
auto inputTy = input.getType().template dyn_cast<RankedTensorType>();
|
||||
if (!inputTy) {
|
||||
return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only Tensor types supported in StableHLO");
|
||||
}
|
||||
auto inputElemTy = inputTy.getElementType();
|
||||
if (!inputElemTy.isIntOrFloat()) {
|
||||
|
@ -279,7 +282,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
|
|||
inputElemTy.getIntOrFloatBitWidth() == 8) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "IntegerType with bitwidth 8 unsupported in convertion from "
|
||||
"AtenMaxDimOp to MHLO");
|
||||
"AtenMaxDimOp to StableHLO");
|
||||
}
|
||||
|
||||
RankedTensorType valResultType = getTypeConverter()
|
||||
|
@ -308,15 +311,15 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
|
|||
|
||||
const auto &options = getOptions();
|
||||
auto inputShapeInfo =
|
||||
mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
||||
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
||||
if (failed(inputShapeInfo)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to get dimension sizes of the input");
|
||||
}
|
||||
auto inputShapeVec = *inputShapeInfo;
|
||||
auto mhloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, dim,
|
||||
options.dimSizeIndexBits)
|
||||
.value();
|
||||
auto stablehloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec,
|
||||
dim, options.dimSizeIndexBits)
|
||||
.value();
|
||||
|
||||
if (keepDim) {
|
||||
auto outShapeVec = inputShapeVec;
|
||||
|
@ -327,15 +330,21 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
|
|||
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
||||
op->getLoc(), outShapeVec);
|
||||
|
||||
auto mhloReduceValueResult = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
op->getLoc(), valResultType, mhloReduceResults[0], outShapeTensor);
|
||||
auto mhloReduceIndexResult = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
op->getLoc(), idxResultType, mhloReduceResults[1], outShapeTensor);
|
||||
rewriter.replaceOp(op, {mhloReduceValueResult, mhloReduceIndexResult});
|
||||
auto stablehloReduceValueResult =
|
||||
rewriter.create<stablehlo::DynamicReshapeOp>(
|
||||
op->getLoc(), valResultType, stablehloReduceResults[0],
|
||||
outShapeTensor);
|
||||
auto stablehloReduceIndexResult =
|
||||
rewriter.create<stablehlo::DynamicReshapeOp>(
|
||||
op->getLoc(), idxResultType, stablehloReduceResults[1],
|
||||
outShapeTensor);
|
||||
rewriter.replaceOp(
|
||||
op, {stablehloReduceValueResult, stablehloReduceIndexResult});
|
||||
return success();
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, {mhloReduceResults[0], mhloReduceResults[1]});
|
||||
rewriter.replaceOp(op,
|
||||
{stablehloReduceResults[0], stablehloReduceResults[1]});
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
@ -352,12 +361,14 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
|
|||
->convertType(op.getType())
|
||||
.template dyn_cast<RankedTensorType>();
|
||||
if (!inputTy) {
|
||||
return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only Tensor types supported in StableHLO");
|
||||
}
|
||||
if (inputTy.getElementType() != outTy.getElementType()) {
|
||||
// Use output element type as computation type.
|
||||
auto dstElemTy = outTy.getElementType();
|
||||
input = rewriter.create<mhlo::ConvertOp>(op->getLoc(), input, dstElemTy);
|
||||
input =
|
||||
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
|
||||
inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||
}
|
||||
auto inputElemTy = inputTy.getElementType();
|
||||
|
@ -370,7 +381,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
|
|||
inputElemTy.getIntOrFloatBitWidth() == 8) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "IntegerType with bitwidth 8 unsupported in convertion from "
|
||||
"AtenSumOp to MHLO");
|
||||
"AtenSumOp to StableHLO");
|
||||
}
|
||||
|
||||
SmallVector<int64_t> dims;
|
||||
|
@ -379,13 +390,14 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
|
|||
}
|
||||
Value initValue =
|
||||
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
|
||||
if (!initValue) return failure();
|
||||
if (!initValue)
|
||||
return failure();
|
||||
|
||||
llvm::sort(dims.begin(), dims.end());
|
||||
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
|
||||
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
|
||||
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
|
||||
|
||||
Block &block = mhloReduceOp.getBody().emplaceBlock();
|
||||
Block &block = stablehloReduceOp.getBody().emplaceBlock();
|
||||
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
|
||||
|
||||
block.addArgument(blockArgumentTy, op->getLoc());
|
||||
|
@ -397,13 +409,13 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
|
|||
{
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(&block);
|
||||
Value addResult = rewriter.create<mhlo::AddOp>(
|
||||
Value addResult = rewriter.create<stablehlo::AddOp>(
|
||||
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
|
||||
rewriter.create<mhlo::ReturnOp>(op->getLoc(), addResult);
|
||||
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult);
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy,
|
||||
mhloReduceOp.getResults());
|
||||
stablehloReduceOp.getResults());
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
@ -417,7 +429,8 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
|
|||
Value input = adaptor.getSelf();
|
||||
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||
if (!inputTy) {
|
||||
return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only Tensor types supported in StableHLO");
|
||||
}
|
||||
auto inputElemTy = inputTy.getElementType();
|
||||
if (!inputElemTy.isIntOrFloat()) {
|
||||
|
@ -429,7 +442,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
|
|||
inputElemTy.getIntOrFloatBitWidth() == 8) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "IntegerType with bitwidth 8 unsupported in convertion from "
|
||||
"AtenMaxOp to MHLO");
|
||||
"AtenMaxOp to StableHLO");
|
||||
}
|
||||
|
||||
SmallVector<int64_t> dims;
|
||||
|
@ -439,12 +452,13 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
|
|||
|
||||
Value initValue =
|
||||
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
|
||||
if (!initValue) return failure();
|
||||
if (!initValue)
|
||||
return failure();
|
||||
llvm::sort(dims.begin(), dims.end());
|
||||
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
|
||||
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
|
||||
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
|
||||
|
||||
Block &block = mhloReduceOp.getBody().emplaceBlock();
|
||||
Block &block = stablehloReduceOp.getBody().emplaceBlock();
|
||||
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
|
||||
|
||||
block.addArgument(blockArgumentTy, op->getLoc());
|
||||
|
@ -456,14 +470,14 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
|
|||
{
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(&block);
|
||||
Value maxResult = rewriter.create<mhlo::MaxOp>(
|
||||
Value maxResult = rewriter.create<stablehlo::MaxOp>(
|
||||
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
|
||||
rewriter.create<mhlo::ReturnOp>(op->getLoc(), maxResult);
|
||||
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), maxResult);
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()),
|
||||
mhloReduceOp.getResults());
|
||||
stablehloReduceOp.getResults());
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
@ -480,12 +494,14 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
|
|||
->convertType(op.getType())
|
||||
.template dyn_cast<RankedTensorType>();
|
||||
if (!inputTy) {
|
||||
return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only Tensor types supported in StableHLO");
|
||||
}
|
||||
if (inputTy.getElementType() != outTy.getElementType()) {
|
||||
// Use output element type as computation type.
|
||||
auto dstElemTy = outTy.getElementType();
|
||||
input = rewriter.create<mhlo::ConvertOp>(op->getLoc(), input, dstElemTy);
|
||||
input =
|
||||
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
|
||||
inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||
}
|
||||
auto inputElemTy = inputTy.getElementType();
|
||||
|
@ -499,7 +515,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
|
|||
inputElemTy.getIntOrFloatBitWidth() == 8) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "IntegerType with bitwidth 8 unsupported in convertion from "
|
||||
"AtenSumDimIntListOp to MHLO");
|
||||
"AtenSumDimIntListOp to StableHLO");
|
||||
}
|
||||
|
||||
SmallVector<int64_t> inputDims;
|
||||
|
@ -525,13 +541,14 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
|
|||
}
|
||||
Value initValue =
|
||||
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
|
||||
if (!initValue) return failure();
|
||||
if (!initValue)
|
||||
return failure();
|
||||
|
||||
llvm::sort(dims.begin(), dims.end());
|
||||
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
|
||||
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
|
||||
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
|
||||
|
||||
Region ®ion = mhloReduceOp.getBody();
|
||||
Region ®ion = stablehloReduceOp.getBody();
|
||||
Block &block = region.emplaceBlock();
|
||||
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
|
||||
|
||||
|
@ -544,15 +561,15 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
|
|||
{
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(&block);
|
||||
Value addResult = rewriter.create<mhlo::AddOp>(
|
||||
Value addResult = rewriter.create<stablehlo::AddOp>(
|
||||
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
|
||||
rewriter.create<mhlo::ReturnOp>(op->getLoc(), addResult);
|
||||
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult);
|
||||
}
|
||||
|
||||
if (keepDim) {
|
||||
const auto &options = getOptions();
|
||||
auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input,
|
||||
options.dimSizeIndexBits);
|
||||
auto outShapeInfo =
|
||||
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
||||
if (failed(outShapeInfo)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to get dimension sizes of the input");
|
||||
|
@ -567,26 +584,27 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
|
|||
}
|
||||
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
||||
op->getLoc(), outShapeVec);
|
||||
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
|
||||
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()),
|
||||
mhloReduceOp.getResult(0), outShapeTensor);
|
||||
stablehloReduceOp.getResult(0), outShapeTensor);
|
||||
return success();
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy,
|
||||
mhloReduceOp.getResults());
|
||||
stablehloReduceOp.getResults());
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// AtenFrobeniusNormDimOp
|
||||
// aten.frobenius_norm.dim => mhlo.reduce(calculate square sum along given dims)
|
||||
// + mhlo.sqrt
|
||||
// aten.frobenius_norm.dim => stablehlo.reduce(calculate square sum along given
|
||||
// dims)
|
||||
// + stablehlo.sqrt
|
||||
namespace {
|
||||
template <>
|
||||
LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
|
||||
AtenFrobeniusNormDimOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
const TorchToMhloOptions &options = getOptions();
|
||||
const TorchToStablehloOptions &options = getOptions();
|
||||
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputType = input.getType().dyn_cast<RankedTensorType>();
|
||||
|
@ -614,7 +632,7 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
|
|||
}
|
||||
}
|
||||
|
||||
// Sort the dims in ascending order, making the conversion
|
||||
// Sort the dims in ascending order, making the conversion
|
||||
// stable with unordered dims.
|
||||
std::sort(dims.begin(), dims.end());
|
||||
|
||||
|
@ -624,58 +642,57 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
|
|||
op, "non-const bool `keepdim` is not supported");
|
||||
}
|
||||
|
||||
auto squareOp = rewriter.create<stablehlo::MulOp>(op->getLoc(), input, input);
|
||||
|
||||
auto initValue = createInitialValueForReduceOp(op, inputElemType, rewriter);
|
||||
if (!initValue) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto squareSumReduceOp = rewriter.create<mhlo::ReduceOp>(
|
||||
op->getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
|
||||
auto reduceOp = rewriter.create<stablehlo::ReduceOp>(
|
||||
op->getLoc(), squareOp.getResult(), initValue,
|
||||
rewriter.getI64TensorAttr(dims));
|
||||
|
||||
Region ®ion = squareSumReduceOp.getBody();
|
||||
Region ®ion = reduceOp.getBody();
|
||||
Block &block = region.emplaceBlock();
|
||||
auto blockArgumentTy = RankedTensorType::get({}, inputElemType);
|
||||
|
||||
block.addArgument(blockArgumentTy, op->getLoc());
|
||||
block.addArgument(blockArgumentTy, op->getLoc());
|
||||
|
||||
auto *firstArgument = block.args_begin();
|
||||
auto secondArgument = block.args_rbegin();
|
||||
auto firstArgument = *block.args_begin();
|
||||
auto secondArgument = *block.args_rbegin();
|
||||
|
||||
{
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(&block);
|
||||
|
||||
auto constantOrd2 = rewriter.create<mhlo::ConstantOp>(
|
||||
op->getLoc(), blockArgumentTy,
|
||||
DenseElementsAttr::get(blockArgumentTy, llvm::ArrayRef<float>{2.0}));
|
||||
auto abs = rewriter.create<mhlo::AbsOp>(op->getLoc(), *secondArgument);
|
||||
auto squareResult = rewriter.create<mhlo::PowOp>(
|
||||
op->getLoc(), abs, constantOrd2);
|
||||
auto addResult = rewriter.create<mhlo::AddOp>(op->getLoc(), squareResult,
|
||||
*firstArgument);
|
||||
rewriter.create<mhlo::ReturnOp>(op->getLoc(), addResult.getResult());
|
||||
auto addResult = rewriter.create<stablehlo::AddOp>(
|
||||
op->getLoc(), firstArgument, secondArgument);
|
||||
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult.getResult());
|
||||
}
|
||||
|
||||
auto output = rewriter.create<mhlo::SqrtOp>(op->getLoc(),
|
||||
squareSumReduceOp.getResult(0));
|
||||
auto output =
|
||||
rewriter.create<stablehlo::SqrtOp>(op->getLoc(), reduceOp.getResult(0));
|
||||
|
||||
if (keepDim) {
|
||||
auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
||||
auto outShapeInfo =
|
||||
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
||||
if (failed(outShapeInfo)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to get dimension sizes of the input");
|
||||
}
|
||||
auto outShapeVec = *outShapeInfo;
|
||||
auto one = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op->getLoc(), rewriter.getIntegerAttr(
|
||||
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
|
||||
op->getLoc(),
|
||||
rewriter.getIntegerAttr(
|
||||
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
|
||||
for (int64_t i : dims) {
|
||||
outShapeVec[i] = one;
|
||||
}
|
||||
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
||||
op->getLoc(), outShapeVec);
|
||||
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
|
||||
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), output,
|
||||
outShapeTensor);
|
||||
return success();
|
||||
|
@ -685,9 +702,9 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
|
|||
}
|
||||
} // namespace
|
||||
|
||||
void mlir::torch::torch_to_mhlo::populateReductionOpPatternsAndLegality(
|
||||
void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToMhloOptions &options) {
|
||||
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
#define INSERT_ATEN_REDUCTION_OP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
|
@ -7,11 +7,12 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "./MhloLegalizeUtils.h"
|
||||
#include "mhlo/IR/hlo_ops.h"
|
||||
#include "StablehloLegalizeUtils.h"
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
||||
#include "stablehlo/dialect/StablehloOps.h"
|
||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
#include <numeric>
|
||||
|
@ -21,27 +22,27 @@ using namespace mlir::torch;
|
|||
using namespace mlir::torch::Torch;
|
||||
|
||||
namespace mlir {
|
||||
namespace mhlo {
|
||||
namespace hlo {
|
||||
|
||||
// Create a 32-bit float constant operator from a float
|
||||
Value getMhloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
|
||||
float val) {
|
||||
Value getStablehloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
|
||||
float val) {
|
||||
auto const_type = RankedTensorType::get({}, rewriter.getF32Type());
|
||||
auto const_attr = DenseElementsAttr::get(const_type, val);
|
||||
|
||||
auto const_op =
|
||||
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
|
||||
auto const_op = rewriter.create<stablehlo::ConstantOp>(
|
||||
op->getLoc(), const_type, const_attr);
|
||||
return const_op.getResult();
|
||||
}
|
||||
|
||||
// Create a 64-bit float constant operator from a double
|
||||
Value getMhloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op,
|
||||
double val) {
|
||||
Value getStablehloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op,
|
||||
double val) {
|
||||
auto const_type = RankedTensorType::get({}, rewriter.getF64Type());
|
||||
auto const_attr = DenseElementsAttr::get(const_type, val);
|
||||
|
||||
auto const_op =
|
||||
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
|
||||
auto const_op = rewriter.create<stablehlo::ConstantOp>(
|
||||
op->getLoc(), const_type, const_attr);
|
||||
return const_op.getResult();
|
||||
}
|
||||
|
||||
|
@ -65,8 +66,8 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
|
|||
RankedTensorType::get(shape, rewriter.getIntegerType(sizeof(T) * 8));
|
||||
auto const_attr = DenseElementsAttr::get(const_type, vec);
|
||||
|
||||
auto const_op =
|
||||
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
|
||||
auto const_op = rewriter.create<stablehlo::ConstantOp>(
|
||||
op->getLoc(), const_type, const_attr);
|
||||
return const_op.getResult();
|
||||
}
|
||||
|
||||
|
@ -88,8 +89,8 @@ std::optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,
|
|||
shape, rewriter.getIntegerType(vec[0].getBitWidth()));
|
||||
auto const_attr = DenseElementsAttr::get(const_type, vec);
|
||||
|
||||
auto const_op =
|
||||
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
|
||||
auto const_op = rewriter.create<stablehlo::ConstantOp>(
|
||||
op->getLoc(), const_type, const_attr);
|
||||
return const_op.getResult();
|
||||
}
|
||||
|
||||
|
@ -111,8 +112,8 @@ std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
|
|||
auto const_type = RankedTensorType::get(shape, rewriter.getF32Type());
|
||||
auto const_attr = DenseElementsAttr::get(const_type, vec);
|
||||
|
||||
auto const_op =
|
||||
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
|
||||
auto const_op = rewriter.create<stablehlo::ConstantOp>(
|
||||
op->getLoc(), const_type, const_attr);
|
||||
return const_op.getResult();
|
||||
}
|
||||
|
||||
|
@ -133,8 +134,8 @@ std::optional<Value> getConstTensor<double>(PatternRewriter &rewriter,
|
|||
auto const_type = RankedTensorType::get(shape, rewriter.getF64Type());
|
||||
auto const_attr = DenseElementsAttr::get(const_type, vec);
|
||||
|
||||
auto const_op =
|
||||
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
|
||||
auto const_op = rewriter.create<stablehlo::ConstantOp>(
|
||||
op->getLoc(), const_type, const_attr);
|
||||
return const_op.getResult();
|
||||
}
|
||||
|
||||
|
@ -169,18 +170,18 @@ Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
|
|||
T val, Type dtype, llvm::ArrayRef<int64_t> dshape) {
|
||||
auto const_type = RankedTensorType::get(dshape, dtype);
|
||||
auto const_attr = SplatElementsAttr::get(const_type, val);
|
||||
auto const_op =
|
||||
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
|
||||
auto const_op = rewriter.create<stablehlo::ConstantOp>(
|
||||
op->getLoc(), const_type, const_attr);
|
||||
return const_op.getResult();
|
||||
}
|
||||
|
||||
Value scalarToMhloTensor(ConversionPatternRewriter &rewriter, Operation *op,
|
||||
Value scalarValue, Type dtype) {
|
||||
Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
|
||||
Operation *op, Value scalarValue, Type dtype) {
|
||||
auto tensor = rewriter.create<tensor::FromElementsOp>(
|
||||
op->getLoc(), ArrayRef<Value>{scalarValue});
|
||||
auto dtype_tensor =
|
||||
rewriter.create<mhlo::ConvertOp>(op->getLoc(), tensor, dtype);
|
||||
return rewriter.create<mhlo::ReshapeOp>(
|
||||
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), tensor, dtype);
|
||||
return rewriter.create<stablehlo::ReshapeOp>(
|
||||
op->getLoc(), RankedTensorType::get(mlir::ArrayRef<int64_t>{}, dtype),
|
||||
dtype_tensor);
|
||||
}
|
||||
|
@ -192,7 +193,8 @@ Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) {
|
|||
if (in_type.getElementType() != outType.getElementType()) {
|
||||
TensorType promotedType =
|
||||
in_type.cloneWith(in_type.getShape(), outType.getElementType());
|
||||
return rewriter.create<mhlo::ConvertOp>(op->getLoc(), promotedType, input);
|
||||
return rewriter.create<stablehlo::ConvertOp>(op->getLoc(), promotedType,
|
||||
input);
|
||||
}
|
||||
return input;
|
||||
}
|
||||
|
@ -210,8 +212,8 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
|
|||
if (in_type.getElementType() != outType.getElementType()) {
|
||||
TensorType promoted_type =
|
||||
in_type.cloneWith(in_type.getShape(), outType.getElementType());
|
||||
input =
|
||||
rewriter.create<mhlo::ConvertOp>(op->getLoc(), promoted_type, input);
|
||||
input = rewriter.create<stablehlo::ConvertOp>(op->getLoc(), promoted_type,
|
||||
input);
|
||||
}
|
||||
|
||||
ArrayRef<int64_t> inShape = in_type.getShape();
|
||||
|
@ -245,8 +247,8 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
|
|||
RankedTensorType::get({static_cast<long int>(bcastDims.size())},
|
||||
rewriter.getI64Type()),
|
||||
bcastDims);
|
||||
auto bcast_op = rewriter.create<mhlo::BroadcastInDimOp>(op->getLoc(), outType,
|
||||
input, bcast_attr);
|
||||
auto bcast_op = rewriter.create<stablehlo::BroadcastInDimOp>(
|
||||
op->getLoc(), outType, input, bcast_attr);
|
||||
return bcast_op.getResult();
|
||||
}
|
||||
|
||||
|
@ -348,8 +350,8 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
|
|||
}
|
||||
|
||||
auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
|
||||
auto mhloShape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
|
||||
return rewriter.create<mhlo::DynamicReshapeOp>(loc, outTy, tensor, mhloShape)
|
||||
auto shape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
|
||||
return rewriter.create<stablehlo::DynamicReshapeOp>(loc, outTy, tensor, shape)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
|
@ -357,11 +359,11 @@ Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
|
|||
const APFloat &constant, Value shape,
|
||||
TensorType outType) {
|
||||
auto constAttr = rewriter.getFloatAttr(outType.getElementType(), constant);
|
||||
auto constTensor = rewriter.create<mhlo::ConstantOp>(loc, constAttr);
|
||||
auto constTensor = rewriter.create<stablehlo::ConstantOp>(loc, constAttr);
|
||||
return rewriter
|
||||
.create<mhlo::DynamicBroadcastInDimOp>(loc, outType, constTensor, shape,
|
||||
rewriter.getI64TensorAttr({}))
|
||||
.create<stablehlo::DynamicBroadcastInDimOp>(
|
||||
loc, outType, constTensor, shape, rewriter.getI64TensorAttr({}))
|
||||
.getResult();
|
||||
}
|
||||
} // namespace mhlo
|
||||
} // namespace hlo
|
||||
} // namespace mlir
|
|
@ -7,8 +7,8 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H
|
||||
#define TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H
|
||||
#ifndef TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_STABLEHLOLEGALIZEUTILS_H
|
||||
#define TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_STABLEHLOLEGALIZEUTILS_H
|
||||
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
@ -18,22 +18,22 @@
|
|||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace mhlo {
|
||||
namespace hlo {
|
||||
|
||||
using mlir::ConversionPatternRewriter;
|
||||
|
||||
// Create a 32-bit float constant operator from a float
|
||||
Value getMhloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
|
||||
float val);
|
||||
Value getStablehloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
|
||||
float val);
|
||||
|
||||
// Create a 64-bit float constant operator from a double
|
||||
Value getMhloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op,
|
||||
double val);
|
||||
Value getStablehloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op,
|
||||
double val);
|
||||
|
||||
// Templated function to create a constant op for given type and shape.
|
||||
// T: storage C type.
|
||||
// Default template creates a constant tensor in T.
|
||||
// To create INT48 MHLO constant, need to pass in llvm::APInt instead.
|
||||
// To create INT48 StableHLO constant, need to pass in llvm::APInt instead.
|
||||
template <typename T>
|
||||
std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
|
||||
ArrayRef<T> vec, ArrayRef<int64_t> shape);
|
||||
|
@ -42,8 +42,8 @@ template <typename T>
|
|||
Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
|
||||
T val, Type dtype, llvm::ArrayRef<int64_t> dshape);
|
||||
|
||||
Value scalarToMhloTensor(ConversionPatternRewriter &rewriter, Operation *op,
|
||||
Value scalarValue, Type dtype);
|
||||
Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
|
||||
Operation *op, Value scalarValue, Type dtype);
|
||||
|
||||
Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType);
|
||||
|
||||
|
@ -71,7 +71,7 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
|
|||
Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
|
||||
const APFloat &constant, Value shape,
|
||||
TensorType outType);
|
||||
} // namespace mhlo
|
||||
} // namespace hlo
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H
|
||||
#endif // TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_STABLEHLOLEGALIZEUTILS_H
|
|
@ -7,17 +7,18 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "./PopulatePatterns.h"
|
||||
#include "mhlo/IR/hlo_ops.h"
|
||||
#include "PopulatePatterns.h"
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Traits.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "stablehlo/dialect/ChloOps.h"
|
||||
#include "stablehlo/dialect/StablehloOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
|
@ -30,17 +31,18 @@ using namespace mlir::torch::Torch;
|
|||
|
||||
namespace {
|
||||
|
||||
class ConvertTorchToMhlo : public ConvertTorchToMhloBase<ConvertTorchToMhlo> {
|
||||
class ConvertTorchToStablehlo
|
||||
: public ConvertTorchToStablehloBase<ConvertTorchToStablehlo> {
|
||||
public:
|
||||
ConvertTorchToMhlo() = default;
|
||||
ConvertTorchToMhlo(bool enableStaticShape, bool enableI32Index) {
|
||||
ConvertTorchToStablehlo() = default;
|
||||
ConvertTorchToStablehlo(bool enableStaticShape, bool enableI32Index) {
|
||||
this->enableStaticShape = enableStaticShape;
|
||||
this->enableI32Index = enableI32Index;
|
||||
}
|
||||
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<chlo::ChloDialect>();
|
||||
registry.insert<mhlo::MhloDialect>();
|
||||
registry.insert<stablehlo::StablehloDialect>();
|
||||
registry.insert<tensor::TensorDialect>();
|
||||
registry.insert<arith::ArithDialect>();
|
||||
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
||||
|
@ -48,7 +50,7 @@ public:
|
|||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
ConversionTarget target(*context);
|
||||
target.addLegalDialect<chlo::ChloDialect, mhlo::MhloDialect,
|
||||
target.addLegalDialect<chlo::ChloDialect, stablehlo::StablehloDialect,
|
||||
tensor::TensorDialect, arith::ArithDialect>();
|
||||
|
||||
TypeConverter typeConverter;
|
||||
|
@ -57,20 +59,20 @@ public:
|
|||
|
||||
RewritePatternSet patterns(context);
|
||||
|
||||
torch_to_mhlo::TorchToMhloOptions options{enableStaticShape,
|
||||
enableI32Index ? 32u : 64u};
|
||||
torch_to_mhlo::populateBasicOpPatternsAndLegality(typeConverter, patterns,
|
||||
target, options);
|
||||
torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
|
||||
torch_to_stablehlo::TorchToStablehloOptions options{
|
||||
enableStaticShape, enableI32Index ? 32u : 64u};
|
||||
torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
||||
typeConverter, patterns, target, options);
|
||||
torch_to_mhlo::populateGatherOpPatternsAndLegality(typeConverter, patterns,
|
||||
target, options);
|
||||
torch_to_mhlo::populateReductionOpPatternsAndLegality(
|
||||
torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
|
||||
typeConverter, patterns, target, options);
|
||||
torch_to_stablehlo::populateGatherOpPatternsAndLegality(
|
||||
typeConverter, patterns, target, options);
|
||||
torch_to_stablehlo::populateReductionOpPatternsAndLegality(
|
||||
typeConverter, patterns, target, options);
|
||||
torch_to_stablehlo::populateLinearOpPatternsAndLegality(
|
||||
typeConverter, patterns, target, options);
|
||||
torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
|
||||
typeConverter, patterns, target, options);
|
||||
torch_to_mhlo::populateLinearOpPatternsAndLegality(typeConverter, patterns,
|
||||
target, options);
|
||||
torch_to_mhlo::populatePoolingOpPatternsAndLegality(typeConverter, patterns,
|
||||
target, options);
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns)))) {
|
||||
|
@ -82,13 +84,13 @@ public:
|
|||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
mlir::torch::createConvertTorchToMhloPass() {
|
||||
return std::make_unique<ConvertTorchToMhlo>(false, false);
|
||||
mlir::torch::createConvertTorchToStablehloPass() {
|
||||
return std::make_unique<ConvertTorchToStablehlo>(false, false);
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
mlir::torch::createConvertTorchToMhloPass(bool enableStaticShape,
|
||||
bool enableI32Index) {
|
||||
return std::make_unique<ConvertTorchToMhlo>(enableStaticShape,
|
||||
enableI32Index);
|
||||
mlir::torch::createConvertTorchToStablehloPass(bool enableStaticShape,
|
||||
bool enableI32Index) {
|
||||
return std::make_unique<ConvertTorchToStablehlo>(enableStaticShape,
|
||||
enableI32Index);
|
||||
}
|
|
@ -7,14 +7,15 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "./MhloLegalizeUtils.h"
|
||||
#include "./PopulatePatterns.h"
|
||||
#include "mhlo/IR/hlo_ops.h"
|
||||
#include "PopulatePatterns.h"
|
||||
#include "StablehloLegalizeUtils.h"
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "stablehlo/dialect/StablehloOps.h"
|
||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
|
@ -28,7 +29,7 @@ using namespace mlir;
|
|||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
using namespace mlir::torch::TorchConversion;
|
||||
using namespace mlir::torch::torch_to_mhlo;
|
||||
using namespace mlir::torch::torch_to_stablehlo;
|
||||
|
||||
namespace {
|
||||
// A dimension index from torch.dialect might outside the range [0, dimSize].
|
||||
|
@ -100,7 +101,7 @@ Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op,
|
|||
auto stridesTensor =
|
||||
rewriter.create<tensor::FromElementsOp>(loc, strides).getResult();
|
||||
|
||||
return rewriter.create<mhlo::RealDynamicSliceOp>(
|
||||
return rewriter.create<stablehlo::RealDynamicSliceOp>(
|
||||
loc, outTy, input, startTensor, endTensor, stridesTensor);
|
||||
}
|
||||
|
||||
|
@ -144,7 +145,7 @@ FailureOr<Value> getDynamicSlice(PatternRewriter &rewriter, Operation *op,
|
|||
step = rewriter.create<arith::TruncIOp>(loc, intType, step);
|
||||
}
|
||||
FailureOr<SmallVector<Value, 4>> dimSizesInfo =
|
||||
mhlo::getDimSizesOfTensor(rewriter, op, input, dimSizeIndexBits);
|
||||
hlo::getDimSizesOfTensor(rewriter, op, input, dimSizeIndexBits);
|
||||
if (failed(dimSizesInfo))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to get dimension sizes of the input");
|
||||
|
@ -179,7 +180,7 @@ public:
|
|||
auto loc = op.getLoc();
|
||||
auto newRank = dimSizes.size();
|
||||
if (newRank == 0 || rankType.getRank() == 0) {
|
||||
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
|
||||
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
|
||||
op,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()),
|
||||
|
@ -214,17 +215,18 @@ public:
|
|||
numel);
|
||||
|
||||
if (dimSizes.size() == 0) {
|
||||
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
|
||||
op,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()),
|
||||
adaptor.getSelf());
|
||||
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
|
||||
op,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()),
|
||||
adaptor.getSelf());
|
||||
return success();
|
||||
}
|
||||
Value mhloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
|
||||
Value computedShape = rewriter.create<mhlo::ComputeReshapeShapeOp>(
|
||||
loc, mhloShape.getType(), numel, mhloShape);
|
||||
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
|
||||
Value stablehloShape =
|
||||
rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
|
||||
Value computedShape = rewriter.create<stablehlo::ComputeReshapeShapeOp>(
|
||||
loc, stablehloShape.getType(), numel, stablehloShape);
|
||||
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
|
||||
op,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()),
|
||||
|
@ -315,21 +317,21 @@ LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
|
|||
dims.push_back(r);
|
||||
}
|
||||
if (dims.size() == 0) {
|
||||
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
|
||||
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), self);
|
||||
return success();
|
||||
}
|
||||
|
||||
auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims,
|
||||
options.dimSizeIndexBits);
|
||||
auto newDimSizesInfo = hlo::getDimSizesOfTensor(rewriter, op, self, dims,
|
||||
options.dimSizeIndexBits);
|
||||
if (failed(newDimSizesInfo))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to get dimension sizes of the input");
|
||||
auto newDimSizes = *newDimSizesInfo;
|
||||
auto mhloShape =
|
||||
auto stablehloShape =
|
||||
rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
|
||||
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), self, mhloShape);
|
||||
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), self, stablehloShape);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -365,20 +367,20 @@ LogicalResult ConvertAtenOp<AtenSqueezeDimOp>::matchAndRewrite(
|
|||
std::iota(dims.begin(), dims.end(), 0);
|
||||
dims.erase(dims.begin() + dim);
|
||||
if (dims.size() == 0) {
|
||||
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
|
||||
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), self);
|
||||
return success();
|
||||
}
|
||||
auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims,
|
||||
options.dimSizeIndexBits);
|
||||
auto newDimSizesInfo = hlo::getDimSizesOfTensor(rewriter, op, self, dims,
|
||||
options.dimSizeIndexBits);
|
||||
if (failed(newDimSizesInfo))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to get dimension sizes of the input");
|
||||
auto newDimSizes = *newDimSizesInfo;
|
||||
auto mhloShape =
|
||||
auto stablehloShape =
|
||||
rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
|
||||
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), self, mhloShape);
|
||||
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), self, stablehloShape);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -395,8 +397,8 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
|
|||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||
return op->emitError("dim must be a Scalar constant");
|
||||
|
||||
auto unsqzTensorInfo = mhlo::unsqueezeTensor(rewriter, op, adaptor.getSelf(),
|
||||
{dim}, options.dimSizeIndexBits);
|
||||
auto unsqzTensorInfo = hlo::unsqueezeTensor(rewriter, op, adaptor.getSelf(),
|
||||
{dim}, options.dimSizeIndexBits);
|
||||
if (failed(unsqzTensorInfo))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"failed to create unsqueezed tensor");
|
||||
|
@ -405,9 +407,9 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
|
||||
void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToMhloOptions &options) {
|
||||
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
|
||||
#define INSERT_ATENOP_PATTERN(AtenOp) \
|
|
@ -17,6 +17,7 @@
|
|||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
|
||||
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
|
||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||
|
@ -26,6 +27,9 @@
|
|||
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "llvm/ADT/APInt.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
|
@ -52,6 +56,147 @@ using namespace mlir::torch::TMTensor;
|
|||
// that these patterns become mostly mechanical associations of
|
||||
// "aten.foo -> linalg.foo".
|
||||
|
||||
static Attribute getNumericLimit(PatternRewriter &rewriter, Type elementType,
|
||||
bool getMin = true) {
|
||||
auto bitWidth = elementType.getIntOrFloatBitWidth();
|
||||
if (llvm::isa<mlir::IntegerType>(elementType)) {
|
||||
if (getMin) {
|
||||
return rewriter.getIntegerAttr(elementType,
|
||||
APInt::getSignedMinValue(bitWidth));
|
||||
} else {
|
||||
return rewriter.getIntegerAttr(elementType,
|
||||
APInt::getSignedMaxValue(bitWidth));
|
||||
}
|
||||
} else if (mlir::FloatType floatType =
|
||||
llvm::dyn_cast<mlir::FloatType>(elementType)) {
|
||||
return rewriter.getFloatAttr(
|
||||
elementType,
|
||||
APFloat::getLargest(floatType.getFloatSemantics(), getMin));
|
||||
} else {
|
||||
llvm_unreachable("Only float/integer types are supported!");
|
||||
}
|
||||
}
|
||||
|
||||
// This function will reformat the `index` and `src` from torch operations
|
||||
// like `torch.scatter` or `torch.scatter_reduce` to match the expected
|
||||
// input for the TMScatterOp. It will return the reformated `index` and `src`
|
||||
// as a pair of mlir::Value that can be used as inputs for the TMScatterOp.
|
||||
static std::pair<Value, Value>
|
||||
convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc(PatternRewriter &rewriter,
|
||||
Value indices, Value src,
|
||||
int64_t dim) {
|
||||
// Get information on types for inputs
|
||||
RankedTensorType indexType = indices.getType().cast<RankedTensorType>();
|
||||
RankedTensorType srcSelf = src.getType().cast<RankedTensorType>();
|
||||
|
||||
// Store location for insertions
|
||||
Location loc = src.getLoc();
|
||||
|
||||
Value indexSize = getTensorSize(rewriter, loc, indices);
|
||||
indexSize = castIntToIndex(rewriter, loc, indexSize);
|
||||
SmallVector<Value> indexShape = getTensorSizes(rewriter, loc, indices);
|
||||
Value cstOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
|
||||
|
||||
// We flatten the `src` values from (i, j, k, ...) -> (i * j * k * ...)
|
||||
SmallVector<Value> indSliceShape({indexSize, cstOne});
|
||||
Value indSlice =
|
||||
createZeroInitTensor(rewriter, loc, indSliceShape, rewriter.getI32Type());
|
||||
|
||||
// New output shape will be equal to the product of the dimensions of the
|
||||
// updates
|
||||
SmallVector<Value> outputs(indexType.getRank(), indSlice);
|
||||
outputs.push_back(createZeroInitTensor(rewriter, loc, {indexSize},
|
||||
srcSelf.getElementType()));
|
||||
SmallVector<Type> outputsType(indexType.getRank(), indSlice.getType());
|
||||
outputsType.push_back(outputs[indexType.getRank()].getType());
|
||||
|
||||
// Create mapping over flattened iteration space
|
||||
SmallVector<AffineExpr> indSliceExpr = {rewriter.getAffineDimExpr(0),
|
||||
rewriter.getAffineConstantExpr(0)};
|
||||
SmallVector<AffineMap> mapping(
|
||||
indexType.getRank(), AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
|
||||
indSliceExpr, src.getContext()));
|
||||
// Mapping for updates
|
||||
mapping.push_back(rewriter.getDimIdentityMap());
|
||||
SmallVector<utils::IteratorType> iteratorTypes(
|
||||
{utils::IteratorType::parallel});
|
||||
|
||||
// This function goes over the flattened iteration space of the `indices`
|
||||
// and `src`. It will reconstruct the original induction variables based
|
||||
// on the current flattened index. The flattened iteration space is required
|
||||
// because TMTensorScatterOp expects a list of single element updates.
|
||||
auto flattenedUpdates =
|
||||
rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, outputsType, ValueRange(), outputs, mapping, iteratorTypes,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
SmallVector<Value> indexValues(indexType.getRank());
|
||||
Value ind = b.create<linalg::IndexOp>(loc, 0);
|
||||
for (int i = indexType.getRank() - 1; i >= 0; i--) {
|
||||
indexValues[i] =
|
||||
b.create<arith::RemSIOp>(loc, ind, indexShape[i]);
|
||||
ind = b.create<arith::DivSIOp>(loc, ind, indexShape[i]);
|
||||
}
|
||||
// Extract the scatter index and update value
|
||||
Value extractIndexValue =
|
||||
b.create<tensor::ExtractOp>(loc, indices, indexValues);
|
||||
Value extractSrcValue =
|
||||
b.create<tensor::ExtractOp>(loc, src, indexValues);
|
||||
SmallVector<Value> yieldVals;
|
||||
for (Value v : indexValues) {
|
||||
Value scalar = castIndexToInt64(b, loc, v);
|
||||
yieldVals.push_back(b.create<arith::TruncIOp>(
|
||||
loc, rewriter.getI32Type(), scalar));
|
||||
}
|
||||
// Replace the original index with the index specified
|
||||
// by the scatter.
|
||||
yieldVals[dim] = b.create<arith::TruncIOp>(
|
||||
loc, rewriter.getI32Type(), extractIndexValue);
|
||||
yieldVals.push_back(extractSrcValue);
|
||||
b.create<linalg::YieldOp>(loc, yieldVals);
|
||||
})
|
||||
.getResultTensors();
|
||||
|
||||
auto toOpFoldResult = [](Value v) -> OpFoldResult {
|
||||
auto op = v.getDefiningOp<arith::ConstantIndexOp>();
|
||||
if (!op)
|
||||
return v;
|
||||
return op.getValue();
|
||||
};
|
||||
|
||||
// The result of the linalg::Generic operation gives us (rank(`src`) + 1)
|
||||
// 1D-tensors where each contains a number of elements equal to the total
|
||||
// number of elements in the `src` tensor. The indices must now be
|
||||
// constructed by concatanating the first rank(`src`) tensors together. The
|
||||
// new `src` tensor is the last tensor returned from the linalg::Generic
|
||||
// operation.
|
||||
SmallVector<Value> offsets = {
|
||||
rewriter.create<arith::ConstantIndexOp>(loc, 0),
|
||||
rewriter.create<arith::ConstantIndexOp>(loc, 0)};
|
||||
SmallVector<Value> strides = {
|
||||
rewriter.create<arith::ConstantIndexOp>(loc, 1),
|
||||
rewriter.create<arith::ConstantIndexOp>(loc, 1)};
|
||||
Value indicesRank =
|
||||
rewriter.create<arith::ConstantIndexOp>(loc, indexType.getRank());
|
||||
Value flattenedIndices = createZeroInitTensor(
|
||||
rewriter, loc, SmallVector<Value>({indexSize, indicesRank}),
|
||||
rewriter.getI32Type());
|
||||
SmallVector<Value> scatterInputsVector(flattenedUpdates);
|
||||
for (auto const slice : ArrayRef(scatterInputsVector).drop_back()) {
|
||||
SmallVector<Value> sizes = getTensorSizes(rewriter, loc, slice);
|
||||
flattenedIndices = rewriter.createOrFold<tensor::InsertSliceOp>(
|
||||
loc, slice, flattenedIndices,
|
||||
llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)),
|
||||
llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)),
|
||||
llvm::to_vector(llvm::map_range(strides, toOpFoldResult)));
|
||||
// Increment offset to insert into next column
|
||||
offsets[1] = rewriter.createOrFold<arith::AddIOp>(loc, offsets[1], cstOne);
|
||||
}
|
||||
|
||||
return std::make_pair(flattenedIndices,
|
||||
scatterInputsVector[indexType.getRank()]);
|
||||
}
|
||||
|
||||
static Value createTMTensorScatterOp(
|
||||
OpBuilder &b, Location loc, Value updates, Value indices, Value original,
|
||||
bool uniqueIndices,
|
||||
|
@ -142,7 +287,7 @@ public:
|
|||
// Finding the maximum value in the input tensor.
|
||||
SmallVector<int64_t> maxTensorSizes;
|
||||
ValueTensorType maxTensorType = ValueTensorType::get(
|
||||
context, llvm::makeArrayRef(maxTensorSizes),
|
||||
context, llvm::ArrayRef(maxTensorSizes),
|
||||
torchTypeInput.getType().cast<ValueTensorType>().getDtype());
|
||||
Value maxTensor =
|
||||
rewriter.create<AtenMaxOp>(loc, maxTensorType, torchTypeInput);
|
||||
|
@ -165,7 +310,7 @@ public:
|
|||
SmallVector<int64_t> expandedInputSizes{
|
||||
makeShapeTorchCompatible(inputType.getShape())[0], 1};
|
||||
ValueTensorType expandInputType = ValueTensorType::get(
|
||||
context, llvm::makeArrayRef(expandedInputSizes),
|
||||
context, llvm::ArrayRef(expandedInputSizes),
|
||||
torchTypeInput.getType().cast<ValueTensorType>().getDtype());
|
||||
Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(1));
|
||||
|
@ -286,9 +431,9 @@ public:
|
|||
auto indexTensorType = indexTensor.getType().cast<BaseTensorType>();
|
||||
int64_t indexTensorSize = indexTensorType.getSizes()[0];
|
||||
SmallVector<int64_t> expandedIndexTensorSizes{indexTensorSize, 1};
|
||||
ValueTensorType expandedIndexTensorType = ValueTensorType::get(
|
||||
context, llvm::makeArrayRef(expandedIndexTensorSizes),
|
||||
indexTensorType.getDtype());
|
||||
ValueTensorType expandedIndexTensorType =
|
||||
ValueTensorType::get(context, llvm::ArrayRef(expandedIndexTensorSizes),
|
||||
indexTensorType.getDtype());
|
||||
Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(1));
|
||||
Value expandedIndexTensor = rewriter.create<AtenUnsqueezeOp>(
|
||||
|
@ -552,6 +697,229 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenScatterReduceTwoOp
|
||||
: public OpConversionPattern<AtenScatterReduceTwoOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenScatterReduceTwoOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
|
||||
Location loc = op.getLoc();
|
||||
|
||||
RankedTensorType selfType =
|
||||
adaptor.getSelf().getType().cast<RankedTensorType>();
|
||||
RankedTensorType indexType =
|
||||
adaptor.getIndex().getType().cast<RankedTensorType>();
|
||||
RankedTensorType srcType =
|
||||
adaptor.getSrc().getType().cast<RankedTensorType>();
|
||||
|
||||
Value self = adaptor.getSelf();
|
||||
|
||||
if (selfType.getRank() != indexType.getRank() ||
|
||||
indexType.getRank() != srcType.getRank())
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"'self', 'index' and 'src' should all "
|
||||
"have the same number of dimensions.");
|
||||
|
||||
std::string reduceType;
|
||||
if (!matchPattern(op.getReduce(), m_TorchConstantStr(reduceType)))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"'reduce' must be a costant string");
|
||||
|
||||
int64_t dim;
|
||||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||
return rewriter.notifyMatchFailure(op, "'dim' is not constant");
|
||||
|
||||
bool includeSelf;
|
||||
if (!matchPattern(op.getIncludeSelf(), m_TorchConstantBool(&includeSelf)))
|
||||
return rewriter.notifyMatchFailure(op, "'include_self' is not constant");
|
||||
|
||||
// Get reduce string as the equivalent enum
|
||||
auto reduceEnum = torch_upstream::get_reduction_enum(reduceType);
|
||||
|
||||
// Get the inputs reformatted for the TMScatterOp
|
||||
auto [indices, updates] =
|
||||
convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc(
|
||||
rewriter, adaptor.getIndex(), adaptor.getSrc(), dim);
|
||||
|
||||
// Value 'counts' will be used to tally the number of reductions into
|
||||
// each unique index. The tally is used to calculate the average of the
|
||||
// values scattered per index.
|
||||
Value counts = nullptr;
|
||||
if (reduceEnum == torch_upstream::ReductionType::MEAN) {
|
||||
SmallVector<Value> selfShape =
|
||||
getTensorSizes(rewriter, loc, adaptor.getSelf());
|
||||
Attribute initAttr;
|
||||
if (llvm::isa<mlir::FloatType>(srcType.getElementType())) {
|
||||
initAttr = rewriter.getFloatAttr(srcType.getElementType(), 1);
|
||||
} else if (llvm::isa<mlir::IntegerType>(srcType.getElementType())) {
|
||||
initAttr = rewriter.getIntegerAttr(srcType.getElementType(), 1);
|
||||
} else {
|
||||
llvm_unreachable("Only integer/float types supported!");
|
||||
}
|
||||
Value initElement = rewriter.create<arith::ConstantOp>(loc, initAttr);
|
||||
counts = createInitTensor(rewriter, loc, selfShape,
|
||||
selfType.getElementType(), initElement);
|
||||
}
|
||||
|
||||
// If the original values shouldn't be included, normalize the
|
||||
// input tensor where the scatters take place.
|
||||
if (!includeSelf) {
|
||||
Value normalizationValue;
|
||||
if (reduceEnum == torch_upstream::ReductionType::SUM ||
|
||||
reduceEnum == torch_upstream::ReductionType::MEAN) {
|
||||
// Set the values in the input tensor to '0' so they are not included
|
||||
normalizationValue = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getZeroAttr(srcType.getElementType()));
|
||||
} else if (reduceEnum == torch_upstream::ReductionType::PROD) {
|
||||
// Set the values in the input tensor to '1' (multiplication identity)
|
||||
if (llvm::isa<mlir::FloatType>(srcType.getElementType())) {
|
||||
normalizationValue = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getFloatAttr(srcType.getElementType(), 1.0));
|
||||
} else if (llvm::isa<mlir::IntegerType>(srcType.getElementType())) {
|
||||
normalizationValue = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(srcType.getElementType(), 1));
|
||||
} else {
|
||||
llvm_unreachable("Only integer/float types supported!");
|
||||
}
|
||||
} else if (reduceEnum == torch_upstream::ReductionType::MAX) {
|
||||
// Set the values in the input tensor to the smallest element of that
|
||||
// type
|
||||
auto minAttr = getNumericLimit(rewriter, srcType.getElementType(),
|
||||
/*getMin=*/true);
|
||||
normalizationValue = rewriter.create<arith::ConstantOp>(loc, minAttr);
|
||||
} else if (reduceEnum == torch_upstream::ReductionType::MIN) {
|
||||
// Set the values in the input tensor to the largest element of that
|
||||
// type
|
||||
auto maxAttr = getNumericLimit(rewriter, srcType.getElementType(),
|
||||
/*getMin=*/false);
|
||||
normalizationValue = rewriter.create<arith::ConstantOp>(loc, maxAttr);
|
||||
}
|
||||
|
||||
// Scatter the normalizations into the input tensor
|
||||
Value indexSize = getTensorSize(rewriter, loc, adaptor.getIndex());
|
||||
indexSize = castIntToIndex(rewriter, loc, indexSize);
|
||||
Value normalizations = createInitTensor(
|
||||
rewriter, loc, SmallVector<Value>({indexSize}),
|
||||
srcType.getElementType(), /*init_element=*/normalizationValue);
|
||||
self = createTMTensorScatterOp(
|
||||
rewriter, loc, normalizations, indices, self,
|
||||
/*uniqueIndices=*/false,
|
||||
[&](OpBuilder &b, Location loc, Value update, Value current) {
|
||||
b.create<TMTensor::YieldOp>(loc, update);
|
||||
});
|
||||
if (reduceEnum == torch_upstream::ReductionType::MEAN) {
|
||||
counts = createTMTensorScatterOp(
|
||||
rewriter, loc, normalizations, indices, counts,
|
||||
/*uniqueIndices=*/false,
|
||||
[&](OpBuilder &b, Location loc, Value update, Value current) {
|
||||
b.create<TMTensor::YieldOp>(loc, update);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Create final operation
|
||||
Value scatterOp = createTMTensorScatterOp(
|
||||
rewriter, loc, updates, indices, self,
|
||||
/*uniqueIndices=*/false,
|
||||
[&](OpBuilder &b, Location loc, Value update, Value current) {
|
||||
Value result;
|
||||
if (reduceEnum == torch_upstream::ReductionType::SUM ||
|
||||
reduceEnum == torch_upstream::ReductionType::MEAN) {
|
||||
if (update.getType().isa<mlir::IntegerType>()) {
|
||||
result = b.create<arith::AddIOp>(loc, update, current);
|
||||
} else if (update.getType().isa<mlir::FloatType>()) {
|
||||
result = b.create<arith::AddFOp>(loc, update, current);
|
||||
} else {
|
||||
llvm_unreachable("Only integer/float types supported!");
|
||||
}
|
||||
} else if (reduceEnum == torch_upstream::ReductionType::PROD) {
|
||||
if (update.getType().isa<mlir::IntegerType>()) {
|
||||
result = b.create<arith::MulIOp>(loc, update, current);
|
||||
} else if (update.getType().isa<mlir::FloatType>()) {
|
||||
result = b.create<arith::MulFOp>(loc, update, current);
|
||||
} else {
|
||||
llvm_unreachable("Only integer/float types supported!");
|
||||
}
|
||||
} else if (reduceEnum == torch_upstream::ReductionType::MAX) {
|
||||
if (update.getType().isa<mlir::IntegerType>()) {
|
||||
result = b.create<arith::MaxSIOp>(loc, update, current);
|
||||
} else if (update.getType().isa<mlir::FloatType>()) {
|
||||
result = b.create<arith::MaxFOp>(loc, update, current);
|
||||
} else {
|
||||
llvm_unreachable("Only integer/float types supported!");
|
||||
}
|
||||
} else if (reduceEnum == torch_upstream::ReductionType::MIN) {
|
||||
if (update.getType().isa<mlir::IntegerType>()) {
|
||||
result = b.create<arith::MinSIOp>(loc, update, current);
|
||||
} else if (update.getType().isa<mlir::FloatType>()) {
|
||||
result = b.create<arith::MinFOp>(loc, update, current);
|
||||
} else {
|
||||
llvm_unreachable("Only integer/float types supported!");
|
||||
}
|
||||
}
|
||||
b.create<TMTensor::YieldOp>(loc, result);
|
||||
});
|
||||
|
||||
// Special case for the mean
|
||||
if (reduceEnum == torch_upstream::ReductionType::MEAN) {
|
||||
counts = createTMTensorScatterOp(
|
||||
rewriter, loc, updates, indices, counts,
|
||||
/*uniqueIndices=*/false,
|
||||
[&](OpBuilder &b, Location loc, Value update, Value current) {
|
||||
Value result;
|
||||
if (mlir::IntegerType intType =
|
||||
llvm::dyn_cast<mlir::IntegerType>(current.getType())) {
|
||||
Value constantUpdate = b.create<arith::ConstantOp>(
|
||||
loc, b.getIntegerAttr(intType, 1));
|
||||
result = b.create<arith::AddIOp>(loc, constantUpdate, current);
|
||||
} else if (mlir::FloatType floatType =
|
||||
llvm::dyn_cast<mlir::FloatType>(current.getType())) {
|
||||
Value constantUpdate = b.create<arith::ConstantOp>(
|
||||
loc, b.getFloatAttr(floatType, 1.0));
|
||||
result = b.create<arith::AddFOp>(loc, constantUpdate, current);
|
||||
} else {
|
||||
llvm_unreachable("Only integer/float types supported!");
|
||||
}
|
||||
b.create<TMTensor::YieldOp>(loc, result);
|
||||
});
|
||||
|
||||
Value output = rewriter.create<tensor::EmptyOp>(
|
||||
loc, tensor::getMixedSizes(rewriter, loc, self),
|
||||
selfType.getElementType());
|
||||
|
||||
// Finally divide the result
|
||||
scatterOp =
|
||||
rewriter
|
||||
.create<linalg::MapOp>(
|
||||
loc, ValueRange{scatterOp, counts}, output,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value result;
|
||||
if (llvm::isa<mlir::IntegerType>(args[0].getType())) {
|
||||
result = b.create<arith::DivSIOp>(loc, args[0], args[1]);
|
||||
} else if (llvm::isa<mlir::FloatType>(args[0].getType())) {
|
||||
result = b.create<arith::DivFOp>(loc, args[0], args[1]);
|
||||
} else {
|
||||
llvm_unreachable("Only integer/float types supported!");
|
||||
}
|
||||
b.create<linalg::YieldOp>(loc, result);
|
||||
})
|
||||
.getResult()[0];
|
||||
}
|
||||
auto resultType = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, scatterOp);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenCumsumOp : public OpConversionPattern<AtenCumsumOp> {
|
||||
public:
|
||||
|
@ -644,6 +1012,8 @@ public:
|
|||
target.addIllegalOp<AtenMaxPool2dWithIndicesBackwardOp>();
|
||||
patterns.add<ConvertAtenMaxPool2dWithIndicesBackwardOp>(typeConverter,
|
||||
context);
|
||||
target.addIllegalOp<AtenScatterReduceTwoOp>();
|
||||
patterns.add<ConvertAtenScatterReduceTwoOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenCumsumOp>();
|
||||
patterns.add<ConvertAtenCumsumOp>(typeConverter, context);
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
|
||||
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
|
||||
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
|
||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
|
@ -717,8 +718,8 @@ class ConvertAtenMultipleDimsReductionOp
|
|||
"non-const dim parameter unsupported");
|
||||
int64_t N = reduceDims.size();
|
||||
auto reduceDimsType = RankedTensorType::get({N}, rewriter.getI64Type());
|
||||
reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType,
|
||||
llvm::makeArrayRef(reduceDims));
|
||||
reduceDimsAttr =
|
||||
DenseIntElementsAttr::get(reduceDimsType, llvm::ArrayRef(reduceDims));
|
||||
|
||||
keepDims = false;
|
||||
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDims)))
|
||||
|
@ -747,8 +748,8 @@ class ConvertAtenOneDimReductionOp
|
|||
return rewriter.notifyMatchFailure(op,
|
||||
"non-const dim parameter unsupported");
|
||||
auto reduceDimsType = RankedTensorType::get({1}, rewriter.getI64Type());
|
||||
reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType,
|
||||
llvm::makeArrayRef({reduceDim}));
|
||||
reduceDimsAttr =
|
||||
DenseIntElementsAttr::get(reduceDimsType, llvm::ArrayRef({reduceDim}));
|
||||
|
||||
keepDims = false;
|
||||
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDims)))
|
||||
|
@ -781,8 +782,8 @@ public:
|
|||
reduceDims.push_back(i);
|
||||
int64_t N = selfTy.getRank();
|
||||
auto reduceDimsType = RankedTensorType::get({N}, rewriter.getI64Type());
|
||||
reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType,
|
||||
llvm::makeArrayRef(reduceDims));
|
||||
reduceDimsAttr =
|
||||
DenseIntElementsAttr::get(reduceDimsType, llvm::ArrayRef(reduceDims));
|
||||
keepDims = false;
|
||||
|
||||
return success();
|
||||
|
@ -2645,6 +2646,36 @@ LogicalResult ConvertAtenOp<AtenViewOp>::matchAndRewrite(
|
|||
return rewriter.notifyMatchFailure(op,
|
||||
"size must consist of Scalar constants");
|
||||
|
||||
// the shape -1 is inferred from other dimensions
|
||||
size_t countNegativeShape{0};
|
||||
// Check at most one -1 shape
|
||||
for (size_t i = 0; i < outShape.size(); i++) {
|
||||
if (outShape[i] < 0) {
|
||||
countNegativeShape++;
|
||||
if (countNegativeShape > 1)
|
||||
return rewriter.notifyMatchFailure(op, "At most one -1 shape");
|
||||
}
|
||||
}
|
||||
|
||||
auto inputShape = selfType.getShape();
|
||||
size_t totalSize = 1;
|
||||
for (size_t i = 0; i < inputShape.size(); i++) {
|
||||
totalSize *= inputShape[i];
|
||||
}
|
||||
|
||||
size_t otherSize = 1;
|
||||
for (size_t i = 0; i < outShape.size(); i++) {
|
||||
if (outShape[i] > 0) {
|
||||
otherSize *= outShape[i];
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < outShape.size(); i++) {
|
||||
if (outShape[i] < 0) {
|
||||
outShape[i] = totalSize / otherSize;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(),
|
||||
rewriter.getDenseI64ArrayAttr(outShape));
|
||||
|
@ -2816,6 +2847,79 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenHardtanhBackwardOp>::matchAndRewrite(
|
||||
AtenHardtanhBackwardOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// Not a tensor type.
|
||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
||||
if (!selfType) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types are currently supported");
|
||||
}
|
||||
|
||||
auto selfElemTy = selfType.getElementType();
|
||||
if (!selfElemTy.isIntOrFloat()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only floating-point or integer datatype legalization supported");
|
||||
}
|
||||
|
||||
// Integer types with width > 32 are not supported
|
||||
auto selfIntType = selfElemTy.dyn_cast<IntegerType>();
|
||||
if (selfIntType && selfIntType.getWidth() > 32) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Integer types with width greater than 32 are not supported");
|
||||
}
|
||||
|
||||
Value gradOutput = adaptor.getGradOutput();
|
||||
auto gradOutputType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
||||
|
||||
Type gradOutputElemType = gradOutputType.getElementType();
|
||||
|
||||
if (selfElemTy != gradOutputElemType) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op,
|
||||
"Input element type should be same as the grad_output element type.");
|
||||
}
|
||||
|
||||
SmallVector<int64_t> constTypeShape(selfType.getRank(), 1);
|
||||
Value maxVal, minVal;
|
||||
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op, op.getMinVal(), minVal,
|
||||
selfElemTy, constTypeShape))) {
|
||||
return rewriter.notifyMatchFailure(op, "Only scalar constant is supported");
|
||||
}
|
||||
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op, op.getMaxVal(), maxVal,
|
||||
selfElemTy, constTypeShape))) {
|
||||
return rewriter.notifyMatchFailure(op, "Only scalar constant is supported");
|
||||
}
|
||||
|
||||
Value replace = tosa::getConstTensor<float>(rewriter, op, 0, {}).value();
|
||||
Type outType = getTypeConverter()->convertType(op.getType());
|
||||
|
||||
Value lesser = rewriter.create<tosa::GreaterOp>(
|
||||
op.getLoc(),
|
||||
RankedTensorType::get(selfType.getShape(), rewriter.getIntegerType(1)),
|
||||
minVal, adaptor.getSelf());
|
||||
|
||||
Value greater = rewriter.create<tosa::GreaterOp>(
|
||||
op.getLoc(),
|
||||
RankedTensorType::get(selfType.getShape(), rewriter.getIntegerType(1)),
|
||||
adaptor.getSelf(), maxVal);
|
||||
|
||||
Value cmp = rewriter.create<tosa::LogicalOrOp>(
|
||||
op.getLoc(),
|
||||
RankedTensorType::get(selfType.getShape(), rewriter.getIntegerType(1)),
|
||||
lesser, greater);
|
||||
|
||||
rewriter.replaceOpWithNewOp<tosa::SelectOp>(op, outType, cmp, replace,
|
||||
gradOutput);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
|
||||
AtenEmbeddingOp op, OpAdaptor adaptor,
|
||||
|
@ -3113,31 +3217,70 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
|
|||
op, "Only floating-point or integer datatype legalization supported");
|
||||
}
|
||||
|
||||
SmallVector<int64_t> outShape;
|
||||
if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(outShape)))
|
||||
SmallVector<int64_t> resultShape;
|
||||
if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(resultShape)))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"size must consist of Scalar constants");
|
||||
// Get the result type
|
||||
auto resultType = getTypeConverter()->convertType(op.getType());
|
||||
|
||||
SmallVector<int64_t> inputShape(
|
||||
makeShapeTorchCompatible(selfType.getShape()));
|
||||
if (inputShape.size() == outShape.size() || inputShape.size() == 0) {
|
||||
// Check for identity case i.e, for ex: [a, b, c] -> [a, b, c]. If this is
|
||||
// true then we can replace the op result with the input operand
|
||||
// irrespective of the users of the op result.
|
||||
if (!llvm::equal(inputShape, outShape)) {
|
||||
for (auto user : op->getResult(0).getUsers()) {
|
||||
// This case is only supported if the result of the `broadcast_to` op is
|
||||
// not used by an op which is a view like.
|
||||
if (isViewLikeOp(user)) {
|
||||
// Check for identity case i.e, for ex: [a, b, c] -> [a, b, c]. If this is
|
||||
// true then we can replace the op result with the input operand directly.
|
||||
if (llvm::equal(inputShape, resultShape)) {
|
||||
// If we reach here, then it means that the broadcasting is not required
|
||||
// since the input and result are of same shape.
|
||||
op.replaceAllUsesWith(op.getSelf());
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
} else if (selfType.hasRank() &&
|
||||
(selfType.getRank() == (int64_t)resultShape.size() ||
|
||||
selfType.getRank() == 0)) {
|
||||
// Right now to support limited cases where input and result shape are not
|
||||
// equal, we can put a constraint that either the input should be of rank
|
||||
// 0 or the rank of input tensor and result should be equal. And then we
|
||||
// can check for broadcasting compatibility for the latter case. For
|
||||
// broadcasting compatibility, either the shape of input and result should
|
||||
// be equal at each dimenion or one of them should be 1.
|
||||
if (selfType.getRank() != 0) {
|
||||
for (unsigned i = 0; i < inputShape.size(); i++) {
|
||||
if (inputShape[i] != resultShape[i] && inputShape[i] != 1 &&
|
||||
resultShape[i] != 1) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: broadcast not supported for this case");
|
||||
op, "unimplemented: either the shape of input and result should "
|
||||
"be equal at each dimenion or one of them should be 1.");
|
||||
}
|
||||
}
|
||||
}
|
||||
// If we reach here, then it means the given case is handled by implicit
|
||||
// broadcasting done by tosa.
|
||||
op.replaceAllUsesWith(op.getSelf());
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
// If the above condition hold true then we can directly create a const
|
||||
// zero tensor of shape same as the result shape.
|
||||
SmallVector<int64_t> zeroTensorShape{resultShape};
|
||||
|
||||
// create the 0 constant tensor
|
||||
int64_t totalNumElements = 1;
|
||||
for (auto dimSize : zeroTensorShape) {
|
||||
totalNumElements = dimSize * totalNumElements;
|
||||
}
|
||||
// There is some danger here. For edge cases in floating point, x + 0 != x.
|
||||
// The cases are denormalized values, which may get flushed, and -0 + 0 =
|
||||
// +0. (sign bit flips). These are probably acceptable in the short term,
|
||||
// but we should put a comment acknowledging the danger, as there isn't an
|
||||
// op that avoids the denorm flushing.
|
||||
SmallVector<int64_t> intValues(totalNumElements, 0);
|
||||
SmallVector<float> floatValues(totalNumElements, 0.0);
|
||||
Value zeroTensor = selfType.getElementType().isa<mlir::FloatType>()
|
||||
? tosa::getConstTensor<float>(
|
||||
rewriter, op, floatValues, zeroTensorShape)
|
||||
.value()
|
||||
: tosa::getConstTensor<int64_t>(
|
||||
rewriter, op, intValues, zeroTensorShape)
|
||||
.value();
|
||||
|
||||
// Use add broadcast
|
||||
rewriter.replaceOpWithNewOp<tosa::AddOp>(op, resultType, adaptor.getSelf(),
|
||||
zeroTensor);
|
||||
return success();
|
||||
}
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -3232,6 +3375,171 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenIndexTensorOp>::matchAndRewrite(
|
||||
AtenIndexTensorOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// t = tf.constant([[1, 2, 3, 4, 5],[6,7,8,9,10],
|
||||
// [11,12,13,14,15],[16,17,18,19,20]]) # 4*5
|
||||
// i = tf.constant([[1,2,3], [3,2,1]]) # 2*3
|
||||
// i_expand = tf.expand_dims(i,axis=2) # 2*3*1
|
||||
// IndexTensorOutput = tf.gather_nd(t,tf.i_expand)
|
||||
// = torch.ops.aten.index(t, (i, )) = t[i] # 2*3*5
|
||||
// [[[ 6, 7, 8, 9, 10], [11, 12, 13, 14, 15], [16, 17, 18, 19, 20]],
|
||||
// [[16, 17, 18, 19, 20], [11, 12, 13, 14, 15], [ 6, 7, 8, 9, 10]]]
|
||||
auto input = adaptor.getSelf();
|
||||
auto inputTensorType =
|
||||
adaptor.getSelf().getType().dyn_cast<RankedTensorType>();
|
||||
// Check input is a tensor type.
|
||||
if (!inputTensorType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types input are currently supported");
|
||||
|
||||
// Deal with torch.prim.ListConstruct of non const value to get the index
|
||||
auto tensorList = op.getIndices();
|
||||
SmallVector<Value> tensorsTorchType;
|
||||
if (!getListConstructElements(tensorList, tensorsTorchType))
|
||||
return op.emitError(
|
||||
"unimplemented: the tensor list is not from list construct");
|
||||
auto indexTensors = getTypeConvertedValues(
|
||||
rewriter, op->getLoc(), getTypeConverter(), tensorsTorchType);
|
||||
|
||||
auto outType = getTypeConverter()->convertType(op.getType());
|
||||
|
||||
// Support for multiple indexes
|
||||
if (indexTensors.size() > 1) {
|
||||
// t[i, i]
|
||||
// = torch.ops.aten.index(t,(i,i))
|
||||
// = tensor([[ t[1,1], t[2,2], t[3,3]],
|
||||
// [ t[3,3], t[2,2], t[1,1]]])
|
||||
// = tensor([[ 7, 13, 19], [19, 13, 7]])
|
||||
// = tf.gather_nd(t,tf.ii_expand)
|
||||
// ii_expand
|
||||
// = tf.concat((i_expand,i_expand), dim=2)
|
||||
// = tf.constant([[[1,1],[2,2],[3,3]],
|
||||
// [[3,3],[2,2],[1,1]]]) # 2*3*2
|
||||
SmallVector<Value> indicesTfConcatTensors;
|
||||
SmallVector<int64_t> indexesRank;
|
||||
SmallVector<SmallVector<int64_t>> indexesShape;
|
||||
|
||||
// concat index tensor into to indices tensor for concat
|
||||
for (size_t i = 0; i < indexTensors.size(); i++) {
|
||||
auto index = indexTensors[i];
|
||||
auto indexTorch = tensorsTorchType[i];
|
||||
// TODO add support for none index input like torch.ops.aten.index(x,
|
||||
// (None, index1, index2, None))
|
||||
if (indexTorch.getType().isa<Torch::NoneType>())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only list ranked tensor types index are supported");
|
||||
|
||||
auto indexType = index.getType().dyn_cast<RankedTensorType>();
|
||||
auto indexShape = indexType.getShape();
|
||||
indexesShape.push_back(makeShapeTorchCompatible(indexShape));
|
||||
indexesRank.push_back(indexType.getRank());
|
||||
|
||||
// index i64 to i32 for tosa compatible
|
||||
if (indexType.getElementType() != rewriter.getIntegerType(32)) {
|
||||
index = rewriter.create<tosa::CastOp>(
|
||||
op->getLoc(),
|
||||
RankedTensorType::get(indexShape, rewriter.getIntegerType(32)),
|
||||
index);
|
||||
}
|
||||
|
||||
// Expand last dim of index to tf indices [2,3] -> [2,3,1]
|
||||
SmallVector<int64_t> indiceShapeOneDim;
|
||||
for (auto shape : indexShape) {
|
||||
indiceShapeOneDim.push_back(shape);
|
||||
}
|
||||
indiceShapeOneDim.push_back(1);
|
||||
auto indicesTfOneDim = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
|
||||
rewriter, op->getLoc(),
|
||||
RankedTensorType::get(indiceShapeOneDim, rewriter.getIntegerType(32)),
|
||||
index, rewriter.getDenseI64ArrayAttr(indiceShapeOneDim));
|
||||
|
||||
// create concat tensor for indicesTf
|
||||
indicesTfConcatTensors.push_back(indicesTfOneDim.getResult());
|
||||
}
|
||||
|
||||
// Right now only support multiple indexes with same shape
|
||||
// TODO for different shape multiple indexes, add broadcast_to for small
|
||||
// shape
|
||||
for (auto indexShapeOneDim : indexesShape) {
|
||||
if (!llvm::equal(indexesShape[0], indexShapeOneDim)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: Only support multi indexes with same shape");
|
||||
}
|
||||
}
|
||||
|
||||
// concat each indices into indicesTf: shape [2,3,1],[2,3,1] -> [2,3,2]
|
||||
auto indicesShapeConcat = indexesShape[0];
|
||||
uint64_t lastDim = indexesRank[0];
|
||||
indicesShapeConcat.push_back(indicesTfConcatTensors.size());
|
||||
auto indicesTf = tosa::CreateOpAndInfer<tosa::ConcatOp>(
|
||||
rewriter, op->getLoc(),
|
||||
GetTypeFromTensorShape(indicesShapeConcat, rewriter.getIntegerType(32)),
|
||||
indicesTfConcatTensors, lastDim);
|
||||
|
||||
if (!indicesTf) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Convert TorchIndex To TfIndices fail.");
|
||||
}
|
||||
// do the tf gathernp algorithm with tf style indices as input.
|
||||
auto result = tosa::convertGatherNdOp(rewriter, op, outType, input,
|
||||
indicesTf.getResult());
|
||||
|
||||
if (!result) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Convert GatherNdOp fail for index tensor.");
|
||||
}
|
||||
rewriter.replaceOp(op, {result.value()});
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
// Support for multiple index
|
||||
auto index = indexTensors[0];
|
||||
auto indexTorch = tensorsTorchType[0];
|
||||
// TODO add support for none index input like torch.ops.aten.index(x, (None, index1, index2, None))
|
||||
if (indexTorch.getType().isa<Torch::NoneType>())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only list ranked tensor types index are supported");
|
||||
auto indexType = index.getType().dyn_cast<RankedTensorType>();
|
||||
auto indexShape = indexType.getShape();
|
||||
// index i64 to i32 for tosa compatible
|
||||
if (indexType.getElementType() != rewriter.getIntegerType(32)) {
|
||||
index = rewriter.create<tosa::CastOp>(
|
||||
op->getLoc(),
|
||||
RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index);
|
||||
}
|
||||
|
||||
// Expand last dim of index to tf indices [2,3] -> [2,3,1]
|
||||
SmallVector<int64_t> indicesShape;
|
||||
for (auto shape : indexShape) {
|
||||
indicesShape.push_back(shape);
|
||||
}
|
||||
indicesShape.push_back(1);
|
||||
auto indicesTf = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
|
||||
rewriter, op->getLoc(),
|
||||
RankedTensorType::get(indicesShape, rewriter.getIntegerType(32)), index,
|
||||
rewriter.getDenseI64ArrayAttr(indicesShape));
|
||||
|
||||
if (!indicesTf) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"Convert TorchIndex To TfIndices fail.");
|
||||
}
|
||||
// do the tf gathernp algorithm with tf style indices as input.
|
||||
auto result = tosa::convertGatherNdOp(rewriter, op, outType, input,
|
||||
indicesTf.getResult());
|
||||
|
||||
if (!result) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Convert GatherNdOp fail for index tensor.");
|
||||
}
|
||||
rewriter.replaceOp(op, {result.value()});
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
|
||||
AtenWhereSelfOp op, OpAdaptor adaptor,
|
||||
|
@ -3968,9 +4276,11 @@ public:
|
|||
if (!op.getMemoryFormat().getType().template isa<Torch::NoneType>() &&
|
||||
(!matchPattern(op.getMemoryFormat(),
|
||||
m_TorchConstantInt(&memoryFormat)) ||
|
||||
memoryFormat != torch_upstream::MemoryFormat::Contiguous)) {
|
||||
(memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
|
||||
memoryFormat != torch_upstream::MemoryFormat::ChannelsLast))) {
|
||||
return op.emitError(
|
||||
"unimplemented: only default memory format is supported");
|
||||
"unimplemented: only contiguous and channels last memory "
|
||||
"format is supported");
|
||||
}
|
||||
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
|
@ -4169,6 +4479,7 @@ public:
|
|||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_ATENOP_PATTERN(AtenTanhOp);
|
||||
INSERT_ATENOP_PATTERN(AtenHardtanhBackwardOp);
|
||||
INSERT_ATENOP_PATTERN(AtenSigmoidOp);
|
||||
INSERT_ATENOP_PATTERN(AtenReluOp);
|
||||
INSERT_ATENOP_PATTERN(AtenLeakyReluOp);
|
||||
|
@ -4196,6 +4507,7 @@ public:
|
|||
INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
|
||||
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
|
||||
INSERT_ATENOP_PATTERN(AtenGatherOp);
|
||||
INSERT_ATENOP_PATTERN(AtenIndexTensorOp);
|
||||
INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
|
||||
INSERT_ATENOP_PATTERN(AtenClampOp);
|
||||
INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);
|
||||
|
|
|
@ -230,6 +230,10 @@ static LogicalResult checkValidityOfCast(Type src, Type dest) {
|
|||
(src.isInteger(32) && dest.isInteger(1)) ||
|
||||
(src.isInteger(32) && dest.isF32()) ||
|
||||
(src.isInteger(8) && dest.isInteger(1)) ||
|
||||
(src.isInteger(1) && dest.isInteger(64)) ||
|
||||
(src.isInteger(1) && dest.isF32()) ||
|
||||
(src.isF32() && dest.isF64()) ||
|
||||
(src.isF64() && dest.isF32()) ||
|
||||
(src.isF32() && dest.isInteger(8)) ||
|
||||
(src.isF32() && dest.isInteger(1))) {
|
||||
return success();
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/Transforms/InliningUtils.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
||||
|
@ -31,11 +32,11 @@ namespace {
|
|||
struct TorchInlinerInterface : public DialectInlinerInterface {
|
||||
using DialectInlinerInterface::DialectInlinerInterface;
|
||||
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
|
||||
BlockAndValueMapping &valueMapping) const final {
|
||||
IRMapping &valueMapping) const final {
|
||||
return true;
|
||||
}
|
||||
bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
|
||||
BlockAndValueMapping &) const final {
|
||||
IRMapping &) const final {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
|
|
@ -128,32 +128,36 @@ static FloatAttr getF64FloatAttr(MLIRContext *context, double value) {
|
|||
return FloatAttr::get(Float64Type::get(context), value);
|
||||
}
|
||||
|
||||
static Value getScalarValue(Value input, Location loc,
|
||||
PatternRewriter &rewriter) {
|
||||
static Value getScalarIntValue(Value input, Location loc,
|
||||
PatternRewriter &rewriter) {
|
||||
auto inputType = input.getType();
|
||||
if (inputType.isa<Torch::IntType>()) {
|
||||
return input;
|
||||
}
|
||||
Value scalar = nullptr;
|
||||
|
||||
auto inputTensorType = inputType.dyn_cast<BaseTensorType>();
|
||||
if (!inputTensorType)
|
||||
return nullptr;
|
||||
|
||||
Type inputDtype = inputTensorType.getOptionalDtype();
|
||||
if (!inputDtype || !inputDtype.isInteger(64))
|
||||
return nullptr;
|
||||
|
||||
std::optional<unsigned> inputRank = getTensorRank(input);
|
||||
if (!inputRank || *inputRank != 0)
|
||||
return nullptr;
|
||||
|
||||
if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
|
||||
std::optional<unsigned> tensorRank =
|
||||
getTensorRank(valueTensorLiteralOp.getResult());
|
||||
if (valueTensorLiteralOp && tensorRank && *tensorRank == 0) {
|
||||
auto tensorType =
|
||||
valueTensorLiteralOp.getValue().getType().cast<RankedTensorType>();
|
||||
if (tensorType.getElementType().isa<mlir::IntegerType>()) {
|
||||
auto val = valueTensorLiteralOp.getValue()
|
||||
.cast<DenseElementsAttr>()
|
||||
.getSplatValue<int64_t>();
|
||||
scalar = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(val));
|
||||
}
|
||||
}
|
||||
auto val = valueTensorLiteralOp.getValue()
|
||||
.cast<DenseElementsAttr>()
|
||||
.getSplatValue<int64_t>();
|
||||
return rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(val));
|
||||
} else if (auto primNumToTensorScalarOp =
|
||||
input.getDefiningOp<PrimNumToTensorScalarOp>()) {
|
||||
scalar = primNumToTensorScalarOp.getA();
|
||||
return primNumToTensorScalarOp.getA();
|
||||
}
|
||||
return scalar;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -386,7 +390,7 @@ void PrimIfOp::getSuccessorRegions(std::optional<unsigned> index,
|
|||
// If the condition is constant, we can give a more precise answer.
|
||||
if (auto condAttr = operands.front().dyn_cast_or_null<IntegerAttr>()) {
|
||||
Region *executedRegion =
|
||||
condAttr.getValue().isOneValue() ? &getThenRegion() : &getElseRegion();
|
||||
condAttr.getValue().isOne() ? &getThenRegion() : &getElseRegion();
|
||||
regions.push_back(RegionSuccessor(executedRegion));
|
||||
return;
|
||||
}
|
||||
|
@ -507,7 +511,7 @@ bool DerefineOp::areCastCompatible(mlir::TypeRange inputs,
|
|||
return isValidSubtype(inputs[0], outputs[0]);
|
||||
}
|
||||
|
||||
OpFoldResult DerefineOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult DerefineOp::fold(FoldAdaptor adaptor) {
|
||||
auto uncheckedCast = getOperand().getDefiningOp<PrimUncheckedCastOp>();
|
||||
if (!uncheckedCast)
|
||||
return nullptr;
|
||||
|
@ -570,10 +574,10 @@ static OpFoldResult atenIsOrIsNotFoldHelper(Operation *op, bool equalIsTrue) {
|
|||
// Aten__RangeLengthOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult Aten__RangeLengthOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto lo = operands[0];
|
||||
auto hi = operands[1];
|
||||
auto step = operands[2];
|
||||
OpFoldResult Aten__RangeLengthOp::fold(FoldAdaptor adaptor) {
|
||||
auto lo = adaptor.getLo();
|
||||
auto hi = adaptor.getHi();
|
||||
auto step = adaptor.getStep();
|
||||
if (!lo || !hi || !step)
|
||||
return nullptr;
|
||||
auto loInt = lo.dyn_cast_or_null<IntegerAttr>().getValue();
|
||||
|
@ -595,10 +599,10 @@ OpFoldResult Aten__RangeLengthOp::fold(ArrayRef<Attribute> operands) {
|
|||
// Aten__DeriveIndexOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult Aten__DeriveIndexOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto index = operands[0];
|
||||
auto start = operands[1];
|
||||
auto step = operands[2];
|
||||
OpFoldResult Aten__DeriveIndexOp::fold(FoldAdaptor adaptor) {
|
||||
auto index = adaptor.getIndex();
|
||||
auto start = adaptor.getStart();
|
||||
auto step = adaptor.getStep();
|
||||
if (!index || !start || !step)
|
||||
return nullptr;
|
||||
auto indexInt = index.dyn_cast_or_null<IntegerAttr>().getValue();
|
||||
|
@ -612,7 +616,7 @@ OpFoldResult Aten__DeriveIndexOp::fold(ArrayRef<Attribute> operands) {
|
|||
// Aten__Is__Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult Aten__Is__Op::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult Aten__Is__Op::fold(FoldAdaptor adaptor) {
|
||||
return atenIsOrIsNotFoldHelper(*this, /*equalIsTrue=*/true);
|
||||
}
|
||||
|
||||
|
@ -620,7 +624,7 @@ OpFoldResult Aten__Is__Op::fold(ArrayRef<Attribute> operands) {
|
|||
// Aten__Isnot__Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult Aten__Isnot__Op::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult Aten__Isnot__Op::fold(FoldAdaptor adaptor) {
|
||||
return atenIsOrIsNotFoldHelper(*this, /*equalIsTrue=*/false);
|
||||
}
|
||||
|
||||
|
@ -628,7 +632,7 @@ OpFoldResult Aten__Isnot__Op::fold(ArrayRef<Attribute> operands) {
|
|||
// Aten__Not__Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult Aten__Not__Op::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult Aten__Not__Op::fold(FoldAdaptor adaptor) {
|
||||
bool value;
|
||||
if (!matchPattern(getOperand(), m_TorchConstantBool(&value)))
|
||||
return nullptr;
|
||||
|
@ -639,7 +643,7 @@ OpFoldResult Aten__Not__Op::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenNeBoolOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenNeBoolOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenNeBoolOp::fold(FoldAdaptor adaptor) {
|
||||
if (getOperand(0) == getOperand(1))
|
||||
return IntegerAttr::get(IntegerType::get(getContext(), 1), false);
|
||||
|
||||
|
@ -655,7 +659,7 @@ OpFoldResult AtenNeBoolOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenSqueezeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenSqueezeOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) {
|
||||
if (auto tensorType = getOperand().getType().dyn_cast<BaseTensorType>()) {
|
||||
if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
|
||||
return getOperand();
|
||||
|
@ -667,7 +671,7 @@ OpFoldResult AtenSqueezeOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenSqueezeDimOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenSqueezeDimOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) {
|
||||
if (auto tensorType = getOperand(0).getType().dyn_cast<BaseTensorType>()) {
|
||||
if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
|
||||
return getOperand(0);
|
||||
|
@ -679,7 +683,7 @@ OpFoldResult AtenSqueezeDimOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenRoundOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenRoundOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
|
||||
if (auto selfType = getSelf().getType().dyn_cast<BaseTensorType>()) {
|
||||
if (selfType.hasDtype() && selfType.getDtype().isa<mlir::IntegerType>())
|
||||
return getSelf();
|
||||
|
@ -691,7 +695,7 @@ OpFoldResult AtenRoundOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenTypeAsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenTypeAsOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenTypeAsOp::fold(FoldAdaptor adaptor) {
|
||||
Type inType = getSelf().getType();
|
||||
Type newType = getOther().getType();
|
||||
|
||||
|
@ -705,7 +709,7 @@ OpFoldResult AtenTypeAsOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenToDtypeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenToDtypeOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenToDtypeOp::fold(FoldAdaptor adaptor) {
|
||||
bool nonBlocking, copyArg;
|
||||
// The non_blocking arg must be `False`.
|
||||
if (!matchPattern(getNonBlocking(), m_TorchConstantBool(&nonBlocking)) ||
|
||||
|
@ -736,7 +740,7 @@ OpFoldResult AtenToDtypeOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenToDtypeLayoutOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenToDtypeLayoutOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenToDtypeLayoutOp::fold(FoldAdaptor adaptor) {
|
||||
// The pin_memory arg should be either constant `False` or `none`.
|
||||
if (!getPinMemory().getType().isa<Torch::NoneType>()) {
|
||||
bool pinMemory;
|
||||
|
@ -797,7 +801,7 @@ OpFoldResult AtenToDtypeLayoutOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenViewOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenViewOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) {
|
||||
auto inputType = getOperand(0).getType().dyn_cast<BaseTensorType>();
|
||||
if (!inputType || !inputType.hasSizes() || inputType.getSizes().size() != 1)
|
||||
return nullptr;
|
||||
|
@ -812,7 +816,7 @@ OpFoldResult AtenViewOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenDimOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenDimOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenDimOp::fold(FoldAdaptor adaptor) {
|
||||
if (auto tensorType = getOperand().getType().dyn_cast<BaseTensorType>()) {
|
||||
if (tensorType.hasSizes())
|
||||
return IntegerAttr::get(IntegerType::get(getContext(), 64),
|
||||
|
@ -825,7 +829,7 @@ OpFoldResult AtenDimOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenLenTOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenLenTOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenLenTOp::fold(FoldAdaptor adaptor) {
|
||||
// `len([1,1,1])` -> `3`, if it is not mutated.
|
||||
if (auto listConstruct =
|
||||
getOperand().getDefiningOp<Torch::PrimListConstructOp>()) {
|
||||
|
@ -853,7 +857,7 @@ void AtenLenTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
// AtenLenStrOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenLenStrOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenLenStrOp::fold(FoldAdaptor adaptor) {
|
||||
if (auto stringConstruct = getS().getDefiningOp<ConstantStrOp>())
|
||||
return getI64IntegerAttr(getContext(),
|
||||
stringConstruct.getValueAttr().getValue().size());
|
||||
|
@ -869,22 +873,25 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
|
|||
if (op->getNumOperands() < 2) {
|
||||
return failure();
|
||||
}
|
||||
auto lhs = getScalarValue(op->getOperand(0), loc, rewriter);
|
||||
auto rhs = getScalarValue(op->getOperand(1), loc, rewriter);
|
||||
auto lhs = getScalarIntValue(op->getOperand(0), loc, rewriter);
|
||||
auto rhs = getScalarIntValue(op->getOperand(1), loc, rewriter);
|
||||
auto outType = op->getResult(0).getType();
|
||||
|
||||
if (!lhs || !rhs) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only int scalar lhs or rhs is supported");
|
||||
}
|
||||
if (isa<AtenSubTensorOp, AtenSubScalarOp, AtenAddTensorOp, AtenAddScalarOp>(
|
||||
op)) {
|
||||
Value alpha = getScalarValue(op->getOperand(2), loc, rewriter);
|
||||
if (isa<AtenSubTensorOp, AtenSubScalarOp, AtenRsubScalarOp, AtenAddTensorOp,
|
||||
AtenAddScalarOp>(op)) {
|
||||
Value alpha = getScalarIntValue(op->getOperand(2), loc, rewriter);
|
||||
if (!alpha) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"only int scalar alpha is supported");
|
||||
}
|
||||
rhs = rewriter.create<AtenMulIntOp>(loc, rhs, alpha);
|
||||
if (isa<AtenRsubScalarOp>(op))
|
||||
lhs = rewriter.create<AtenMulIntOp>(loc, lhs, alpha);
|
||||
else
|
||||
rhs = rewriter.create<AtenMulIntOp>(loc, rhs, alpha);
|
||||
}
|
||||
|
||||
if (isa<AtenDivTensorModeOp>(op)) {
|
||||
|
@ -937,6 +944,8 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
|
|||
result = rewriter.create<AtenAddIntOp>(loc, lhs, rhs);
|
||||
} else if (isa<AtenSubScalarOp, AtenSubTensorOp>(op)) {
|
||||
result = rewriter.create<AtenSubIntOp>(loc, lhs, rhs);
|
||||
} else if (isa<AtenRsubScalarOp>(op)) {
|
||||
result = rewriter.create<AtenSubIntOp>(loc, rhs, lhs);
|
||||
} else if (isa<AtenMulScalarOp, AtenMulTensorOp>(op)) {
|
||||
result = rewriter.create<AtenMulIntOp>(loc, lhs, rhs);
|
||||
}
|
||||
|
@ -984,6 +993,16 @@ void AtenSubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenRSubScalarOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
void AtenRsubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
patterns.add(+[](AtenRsubScalarOp op, PatternRewriter &rewriter) {
|
||||
return rewrite0DBinaryTensorOp(op, rewriter);
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenMulTensorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1014,6 +1033,23 @@ void AtenDivTensorModeOp::getCanonicalizationPatterns(
|
|||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenScalarImplicitOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
void AtenScalarImplicitOp::getCanonicalizationPatterns(
|
||||
RewritePatternSet &patterns, MLIRContext *context) {
|
||||
patterns.add(+[](AtenScalarImplicitOp op, PatternRewriter &rewriter) {
|
||||
Location loc = op.getLoc();
|
||||
Value a = op.getA();
|
||||
auto outType = op.getResult().getType();
|
||||
Value scalarValue = getScalarIntValue(a, loc, rewriter);
|
||||
if (!scalarValue)
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<Torch::DerefineOp>(op, outType, scalarValue);
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenSizeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1092,7 +1128,7 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
// AtenSizeIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenSizeIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenSizeIntOp::fold(FoldAdaptor adaptor) {
|
||||
int64_t dim;
|
||||
if (!matchPattern(this->getDim(), m_TorchConstantInt(&dim)))
|
||||
return nullptr;
|
||||
|
@ -1132,7 +1168,7 @@ floatComparatorFoldHelper(OpTy op, ConstantFloatComparator comparator) {
|
|||
// AtenLtFloatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenLtFloatOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenLtFloatOp::fold(FoldAdaptor adaptor) {
|
||||
return floatComparatorFoldHelper(*this,
|
||||
[](double a, double b) { return a < b; });
|
||||
}
|
||||
|
@ -1141,7 +1177,7 @@ OpFoldResult AtenLtFloatOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenGtFloatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenGtFloatOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenGtFloatOp::fold(FoldAdaptor adaptor) {
|
||||
return floatComparatorFoldHelper(*this,
|
||||
[](double a, double b) { return a > b; });
|
||||
}
|
||||
|
@ -1150,7 +1186,7 @@ OpFoldResult AtenGtFloatOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenGeFloatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenGeFloatOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenGeFloatOp::fold(FoldAdaptor adaptor) {
|
||||
return floatComparatorFoldHelper(*this,
|
||||
[](double a, double b) { return a >= b; });
|
||||
}
|
||||
|
@ -1159,7 +1195,7 @@ OpFoldResult AtenGeFloatOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenEqFloatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenEqFloatOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenEqFloatOp::fold(FoldAdaptor adaptor) {
|
||||
return floatComparatorFoldHelper(*this,
|
||||
[](double a, double b) { return a == b; });
|
||||
}
|
||||
|
@ -1225,7 +1261,7 @@ static OpFoldResult intComparatorFoldHelper(OpTy op,
|
|||
// AtenNeIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenNeIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenNeIntOp::fold(FoldAdaptor adaptor) {
|
||||
return intComparatorFoldHelper(*this,
|
||||
[](int64_t a, int64_t b) { return a != b; });
|
||||
}
|
||||
|
@ -1234,7 +1270,7 @@ OpFoldResult AtenNeIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenEqIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenEqIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenEqIntOp::fold(FoldAdaptor adaptor) {
|
||||
return intComparatorFoldHelper(*this,
|
||||
[](int64_t a, int64_t b) { return a == b; });
|
||||
}
|
||||
|
@ -1243,7 +1279,7 @@ OpFoldResult AtenEqIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenEqStrOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenEqStrOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenEqStrOp::fold(FoldAdaptor adaptor) {
|
||||
if (getOperand(0) == getOperand(1))
|
||||
return getI1IntegerAttr(getContext(), true);
|
||||
|
||||
|
@ -1259,7 +1295,7 @@ OpFoldResult AtenEqStrOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenLtIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenLtIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenLtIntOp::fold(FoldAdaptor adaptor) {
|
||||
return intComparatorFoldHelper(*this,
|
||||
[](int64_t a, int64_t b) { return a < b; });
|
||||
}
|
||||
|
@ -1268,7 +1304,7 @@ OpFoldResult AtenLtIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenLeIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenLeIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenLeIntOp::fold(FoldAdaptor adaptor) {
|
||||
return intComparatorFoldHelper(*this,
|
||||
[](int64_t a, int64_t b) { return a <= b; });
|
||||
}
|
||||
|
@ -1277,7 +1313,7 @@ OpFoldResult AtenLeIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenGtIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenGtIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenGtIntOp::fold(FoldAdaptor adaptor) {
|
||||
return intComparatorFoldHelper(*this,
|
||||
[](int64_t a, int64_t b) { return a > b; });
|
||||
}
|
||||
|
@ -1286,7 +1322,7 @@ OpFoldResult AtenGtIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenGeIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenGeIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenGeIntOp::fold(FoldAdaptor adaptor) {
|
||||
return intComparatorFoldHelper(*this,
|
||||
[](int64_t a, int64_t b) { return a >= b; });
|
||||
}
|
||||
|
@ -1295,7 +1331,7 @@ OpFoldResult AtenGeIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenBoolFloatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenBoolFloatOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenBoolFloatOp::fold(FoldAdaptor adaptor) {
|
||||
double c;
|
||||
if (matchPattern(getOperand(), m_TorchConstantFloat(&c)))
|
||||
return getI1IntegerAttr(getContext(), c != 0.0);
|
||||
|
@ -1306,7 +1342,7 @@ OpFoldResult AtenBoolFloatOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenBoolIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenBoolIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenBoolIntOp::fold(FoldAdaptor adaptor) {
|
||||
int64_t c;
|
||||
if (matchPattern(getOperand(), m_TorchConstantInt(&c)))
|
||||
return getI1IntegerAttr(getContext(), c != 0);
|
||||
|
@ -1317,9 +1353,9 @@ OpFoldResult AtenBoolIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenFloatScalarOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenFloatScalarOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenFloatScalarOp::fold(FoldAdaptor adaptor) {
|
||||
// Constant fold int -> float conversion.
|
||||
if (auto integerAttr = operands[0].dyn_cast_or_null<IntegerAttr>()) {
|
||||
if (auto integerAttr = adaptor.getA().dyn_cast_or_null<IntegerAttr>()) {
|
||||
return FloatAttr::get(
|
||||
mlir::Float64Type::get(getContext()),
|
||||
static_cast<double>(integerAttr.getValue().getSExtValue()));
|
||||
|
@ -1330,13 +1366,27 @@ OpFoldResult AtenFloatScalarOp::fold(ArrayRef<Attribute> operands) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenIntFloatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenIntFloatOp::fold(FoldAdaptor adaptor) {
|
||||
// Constant fold float -> int conversion.
|
||||
if (auto floatAttr = adaptor.getA().dyn_cast_or_null<FloatAttr>()) {
|
||||
return IntegerAttr::get(
|
||||
mlir::IntegerType::get(getContext(), 64, IntegerType::Signed),
|
||||
static_cast<int64_t>(floatAttr.getValue().convertToDouble()));
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenIntScalarOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenIntScalarOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenIntScalarOp::fold(FoldAdaptor adaptor) {
|
||||
// Constant fold float -> int conversion.
|
||||
if (auto floatAttr = operands[0].dyn_cast_or_null<FloatAttr>()) {
|
||||
if (auto floatAttr = adaptor.getA().dyn_cast_or_null<FloatAttr>()) {
|
||||
return IntegerAttr::get(
|
||||
mlir::IntegerType::get(getContext(), 64, IntegerType::Signed),
|
||||
static_cast<long>(floatAttr.getValue().convertToDouble()));
|
||||
|
@ -1347,6 +1397,18 @@ OpFoldResult AtenIntScalarOp::fold(ArrayRef<Attribute> operands) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenIntBoolOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenIntBoolOp::fold(FoldAdaptor adaptor) {
|
||||
bool b;
|
||||
if (matchPattern(getOperand(), m_TorchConstantBool(&b))) {
|
||||
return getI64IntegerAttr(getContext(), static_cast<long>(b));
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenSortIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1440,7 +1502,7 @@ LogicalResult ValueTensorLiteralOp::inferReturnTypes(
|
|||
return success();
|
||||
}
|
||||
|
||||
OpFoldResult ValueTensorLiteralOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult ValueTensorLiteralOp::fold(FoldAdaptor adaptor) {
|
||||
return getValueAttr();
|
||||
}
|
||||
|
||||
|
@ -1545,7 +1607,7 @@ void CopyToValueTensorOp::getEffects(
|
|||
// ConstantNoneOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult ConstantNoneOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult ConstantNoneOp::fold(FoldAdaptor adaptor) {
|
||||
return TypeAttr::get(Torch::NoneType::get(getContext()));
|
||||
}
|
||||
|
||||
|
@ -1558,9 +1620,7 @@ void ConstantNoneOp::getAsmResultNames(
|
|||
// ConstantStrOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult ConstantStrOp::fold(ArrayRef<Attribute> operands) {
|
||||
return getValueAttr();
|
||||
}
|
||||
OpFoldResult ConstantStrOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
|
||||
|
||||
void ConstantStrOp::getAsmResultNames(
|
||||
function_ref<void(Value, StringRef)> setNameFn) {
|
||||
|
@ -1598,7 +1658,7 @@ void ConstantIntOp::print(OpAsmPrinter &p) {
|
|||
p.printOptionalAttrDict((*this)->getAttrs(), {"value"});
|
||||
}
|
||||
|
||||
OpFoldResult Torch::ConstantIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult Torch::ConstantIntOp::fold(FoldAdaptor adaptor) {
|
||||
return getValueAttr();
|
||||
}
|
||||
|
||||
|
@ -1614,7 +1674,7 @@ void Torch::ConstantIntOp::getAsmResultNames(
|
|||
// ConstantFloatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult Torch::ConstantFloatOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult Torch::ConstantFloatOp::fold(FoldAdaptor adaptor) {
|
||||
return getValueAttr();
|
||||
}
|
||||
|
||||
|
@ -1644,7 +1704,7 @@ void Torch::ConstantFloatOp::getAsmResultNames(
|
|||
// ConstantNumberOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult Torch::ConstantNumberOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult Torch::ConstantNumberOp::fold(FoldAdaptor adaptor) {
|
||||
return getValueAttr();
|
||||
}
|
||||
|
||||
|
@ -1672,7 +1732,7 @@ void Torch::ConstantNumberOp::getCanonicalizationPatterns(
|
|||
// ConstantBoolOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult Torch::ConstantBoolOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult Torch::ConstantBoolOp::fold(FoldAdaptor adaptor) {
|
||||
return getValueAttr();
|
||||
}
|
||||
|
||||
|
@ -1690,7 +1750,7 @@ bool PrimUncheckedCastOp::areCastCompatible(mlir::TypeRange inputs,
|
|||
return isValidSubtype(outputs[0], inputs[0]);
|
||||
}
|
||||
|
||||
OpFoldResult PrimUncheckedCastOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult PrimUncheckedCastOp::fold(FoldAdaptor adaptor) {
|
||||
if (auto derefineOp = getX().getDefiningOp<Torch::DerefineOp>()) {
|
||||
if (derefineOp.getOperand().getType() == getType())
|
||||
return derefineOp.getOperand();
|
||||
|
@ -1824,7 +1884,7 @@ void AtenSliceTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
// AtenEqIntListOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenEqIntListOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenEqIntListOp::fold(FoldAdaptor adaptor) {
|
||||
auto lhsLiteral = getA().getDefiningOp<Torch::PrimListConstructOp>();
|
||||
if (!lhsLiteral)
|
||||
return nullptr;
|
||||
|
@ -1849,6 +1909,20 @@ OpFoldResult AtenEqIntListOp::fold(ArrayRef<Attribute> operands) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PrimTupleConstructOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult PrimTupleConstructOp::verify() {
|
||||
if (!(isValidSubtype(
|
||||
Torch::TupleType::get(getContext(),
|
||||
llvm::to_vector<6>(getElements().getType())),
|
||||
getResult().getType())))
|
||||
return emitOpError(
|
||||
"failed to verify that contained types correspond to operand types");
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PrimTupleIndexOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1950,7 +2024,7 @@ static PrimDictConstructOp getDictConstructIfNotModified(Value torchDict) {
|
|||
// Aten__Getitem__DictStrOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult Aten__Getitem__DictStrOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult Aten__Getitem__DictStrOp::fold(FoldAdaptor adaptor) {
|
||||
auto dictConstruct = getDictConstructIfNotModified(getSelf());
|
||||
if (!dictConstruct)
|
||||
return nullptr;
|
||||
|
@ -1968,7 +2042,7 @@ OpFoldResult Aten__Getitem__DictStrOp::fold(ArrayRef<Attribute> operands) {
|
|||
// Aten__Contains__StrOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult Aten__Contains__StrOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult Aten__Contains__StrOp::fold(FoldAdaptor adaptor) {
|
||||
auto dictConstruct = getDictConstructIfNotModified(getDict());
|
||||
if (!dictConstruct)
|
||||
return nullptr;
|
||||
|
@ -1991,7 +2065,7 @@ static bool isListConstructNotModified(Value torchList) {
|
|||
});
|
||||
}
|
||||
|
||||
OpFoldResult Aten__Contains__IntListOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult Aten__Contains__IntListOp::fold(FoldAdaptor adaptor) {
|
||||
auto itemConstruct = getItem();
|
||||
if (!isListConstructNotModified(getL()))
|
||||
return nullptr;
|
||||
|
@ -2052,43 +2126,55 @@ atenBinaryFloatOperatorFoldHelper(ArrayRef<Attribute> operands,
|
|||
// AtenFloordivIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenFloordivIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenFloordivIntOp::fold(FoldAdaptor adaptor) {
|
||||
return atenBinaryIntOperatorFoldHelper(
|
||||
operands, [](int64_t a, int64_t b) { return std::floor(a / (double)b); });
|
||||
adaptor.getOperands(),
|
||||
[](int64_t a, int64_t b) { return std::floor(a / (double)b); });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenRemainderIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenRemainderIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenRemainderIntOp::fold(FoldAdaptor adaptor) {
|
||||
return atenBinaryIntOperatorFoldHelper(
|
||||
operands, [](int64_t a, int64_t b) { return a % b; });
|
||||
adaptor.getOperands(), [](int64_t a, int64_t b) { return a % b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenAddIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenAddIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenAddIntOp::fold(FoldAdaptor adaptor) {
|
||||
return atenBinaryIntOperatorFoldHelper(
|
||||
operands, [](int64_t a, int64_t b) { return a + b; });
|
||||
adaptor.getOperands(), [](int64_t a, int64_t b) { return a + b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenSubIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenSubIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenSubIntOp::fold(FoldAdaptor adaptor) {
|
||||
return atenBinaryIntOperatorFoldHelper(
|
||||
operands, [](int64_t a, int64_t b) { return a - b; });
|
||||
adaptor.getOperands(), [](int64_t a, int64_t b) { return a - b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenCatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenCatOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
|
||||
OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) {
|
||||
auto list = getOperand(0).getDefiningOp<PrimListConstructOp>();
|
||||
if (!list || !list->hasOneUse() || list.getElements().size() != 1)
|
||||
return nullptr;
|
||||
return list.getElements()[0];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenStackOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenStackOp::fold(FoldAdaptor adaptor) {
|
||||
auto list = getOperand(0).getDefiningOp<PrimListConstructOp>();
|
||||
if (!list || !list->hasOneUse() || list.getElements().size() != 1)
|
||||
return nullptr;
|
||||
|
@ -2099,7 +2185,7 @@ OpFoldResult AtenCatOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
|
|||
// AtenSliceTensorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenSliceTensorOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
|
||||
OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
|
||||
auto inType = getOperand(0).getType().dyn_cast<ValueTensorType>();
|
||||
auto outType = getResult().getType().dyn_cast<ValueTensorType>();
|
||||
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes())
|
||||
|
@ -2118,7 +2204,7 @@ OpFoldResult AtenSliceTensorOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
|
|||
// AtenMulIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenMulIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) {
|
||||
int64_t lhs, rhs;
|
||||
bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs));
|
||||
bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs));
|
||||
|
@ -2129,46 +2215,70 @@ OpFoldResult AtenMulIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenSubFloatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenSubFloatOp::fold(FoldAdaptor adaptor) {
|
||||
return atenBinaryFloatOperatorFoldHelper(
|
||||
adaptor.getOperands(), [](double a, double b) { return a - b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenSubOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenSubOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (!operands[0] || !operands[1]) {
|
||||
OpFoldResult AtenSubOp::fold(FoldAdaptor adaptor) {
|
||||
if (!adaptor.getA() || !adaptor.getB()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (operands[0].isa<IntegerAttr>() && operands[1].isa<IntegerAttr>()) {
|
||||
if (adaptor.getA().isa<IntegerAttr>() && adaptor.getB().isa<IntegerAttr>()) {
|
||||
return atenBinaryIntOperatorFoldHelper(
|
||||
operands, [](int64_t a, int64_t b) -> int64_t { return a - b; });
|
||||
adaptor.getOperands(),
|
||||
[](int64_t a, int64_t b) -> int64_t { return a - b; });
|
||||
}
|
||||
return atenBinaryFloatOperatorFoldHelper(
|
||||
operands, [](double a, double b) -> double { return a - b; });
|
||||
adaptor.getOperands(),
|
||||
[](double a, double b) -> double { return a - b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenDivOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenDivOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (!operands[0] || !operands[1]) {
|
||||
OpFoldResult AtenDivOp::fold(FoldAdaptor adaptor) {
|
||||
if (!adaptor.getA() || !adaptor.getB()) {
|
||||
return nullptr;
|
||||
}
|
||||
// Since AtenDivOp always returns float value, we don't need to deal with the
|
||||
// case where the operands are both integers separately.
|
||||
return atenBinaryFloatOperatorFoldHelper(
|
||||
operands, [](double a, double b) -> double { return a / b; });
|
||||
adaptor.getOperands(),
|
||||
[](double a, double b) -> double { return a / b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenPowIntFloatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenPowIntFloatOp::fold(FoldAdaptor adaptor) {
|
||||
if (!adaptor.getA() || !adaptor.getB()) {
|
||||
return nullptr;
|
||||
}
|
||||
return atenBinaryFloatOperatorFoldHelper(
|
||||
adaptor.getOperands(), [](double a, double b) { return std::pow(a, b); });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenCeilScalarOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenCeilScalarOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (!operands[0]) {
|
||||
OpFoldResult AtenCeilScalarOp::fold(FoldAdaptor adaptor) {
|
||||
if (!adaptor.getA()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto floatValue = operands[0].dyn_cast_or_null<FloatAttr>();
|
||||
auto floatValue = adaptor.getA().dyn_cast_or_null<FloatAttr>();
|
||||
if (!floatValue) {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -2181,7 +2291,7 @@ OpFoldResult AtenCeilScalarOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenNegIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenNegIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenNegIntOp::fold(FoldAdaptor adaptor) {
|
||||
int64_t c;
|
||||
if (matchPattern(getOperand(), m_TorchConstantInt(&c)))
|
||||
return getI64IntegerAttr(getContext(), -c);
|
||||
|
@ -2192,7 +2302,7 @@ OpFoldResult AtenNegIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenSqrtIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenSqrtIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenSqrtIntOp::fold(FoldAdaptor adaptor) {
|
||||
int64_t c;
|
||||
if (matchPattern(getOperand(), m_TorchConstantInt(&c)))
|
||||
return getF64FloatAttr(getContext(), std::sqrt(c));
|
||||
|
@ -2203,7 +2313,7 @@ OpFoldResult AtenSqrtIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
// PrimDtypeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult PrimDtypeOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult PrimDtypeOp::fold(FoldAdaptor adaptor) {
|
||||
BaseTensorType tensorType = getA().getType().cast<BaseTensorType>();
|
||||
if (tensorType.hasDtype()) {
|
||||
torch_upstream::ScalarType scalarType =
|
||||
|
@ -2217,7 +2327,7 @@ OpFoldResult PrimDtypeOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenIntTensorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenIntTensorOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) {
|
||||
// If a scalar number is converted to a 0-d tensor and passed on to
|
||||
// aten.Int.Tensor, fold to the scalar number.
|
||||
if (auto numToTensorScalar = getA().getDefiningOp<PrimNumToTensorScalarOp>())
|
||||
|
@ -2229,7 +2339,7 @@ OpFoldResult AtenIntTensorOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenFloatTensorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenFloatTensorOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenFloatTensorOp::fold(FoldAdaptor adaptor) {
|
||||
// If a scalar number is converted to a 0-d tensor and passed on to
|
||||
// aten.Float.Tensor, fold to the scalar number.
|
||||
if (auto numToTensorScalar = getA().getDefiningOp<PrimNumToTensorScalarOp>())
|
||||
|
@ -2241,7 +2351,7 @@ OpFoldResult AtenFloatTensorOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenDivFloatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenDivFloatOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenDivFloatOp::fold(FoldAdaptor adaptor) {
|
||||
double lhs, rhs;
|
||||
bool lConstant = matchPattern(getOperand(0), m_TorchConstantFloat(&lhs));
|
||||
bool rConstant = matchPattern(getOperand(1), m_TorchConstantFloat(&rhs));
|
||||
|
@ -2258,7 +2368,7 @@ OpFoldResult AtenDivFloatOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenDivIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenDivIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenDivIntOp::fold(FoldAdaptor adaptor) {
|
||||
int64_t lhs, rhs;
|
||||
bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs));
|
||||
bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs));
|
||||
|
@ -2271,7 +2381,7 @@ OpFoldResult AtenDivIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenCeilFloatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenCeilFloatOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenCeilFloatOp::fold(FoldAdaptor adaptor) {
|
||||
double c;
|
||||
if (matchPattern(getOperand(), m_TorchConstantFloat(&c)))
|
||||
return getI64IntegerAttr(getContext(), std::ceil(c));
|
||||
|
@ -2282,13 +2392,13 @@ OpFoldResult AtenCeilFloatOp::fold(ArrayRef<Attribute> operands) {
|
|||
// PrimMaxIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult PrimMaxIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult PrimMaxIntOp::fold(FoldAdaptor adaptor) {
|
||||
// If both operands are the same, then the operation is an identity.
|
||||
if (getA() == getB())
|
||||
return getA();
|
||||
|
||||
auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
|
||||
auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
|
||||
auto lhs = adaptor.getA().dyn_cast_or_null<IntegerAttr>();
|
||||
auto rhs = adaptor.getB().dyn_cast_or_null<IntegerAttr>();
|
||||
if (!lhs || !rhs)
|
||||
return nullptr;
|
||||
// Torch semantics are that !torch.int is 64-bit signed.
|
||||
|
@ -2301,7 +2411,7 @@ OpFoldResult PrimMaxIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
// PrimMinSelfIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult PrimMinSelfIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult PrimMinSelfIntOp::fold(FoldAdaptor adaptor) {
|
||||
auto list = getOperand().getDefiningOp<PrimListConstructOp>();
|
||||
if (!list)
|
||||
return nullptr;
|
||||
|
@ -2320,6 +2430,25 @@ OpFoldResult PrimMinSelfIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
*std::min_element(values.begin(), values.end()));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PrimMinIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult PrimMinIntOp::fold(FoldAdaptor adaptor) {
|
||||
// If both operands are the same, then the operation is an identity.
|
||||
if (getA() == getB())
|
||||
return getA();
|
||||
|
||||
auto lhs = adaptor.getA().dyn_cast_or_null<IntegerAttr>();
|
||||
auto rhs = adaptor.getB().dyn_cast_or_null<IntegerAttr>();
|
||||
if (!lhs || !rhs)
|
||||
return nullptr;
|
||||
// Torch semantics are that !torch.int is 64-bit signed.
|
||||
return IntegerAttr::get(
|
||||
lhs.getType(),
|
||||
std::min(lhs.getValue().getSExtValue(), rhs.getValue().getSExtValue()));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ShapeCalculateOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -68,16 +68,32 @@ bool Torch::isValidSubtype(Type subtype, Type type) {
|
|||
return true;
|
||||
}
|
||||
|
||||
// TODO: This is not subtyping according to PEP 483. See description
|
||||
// of NonValueTensorType.
|
||||
if (subtype.isa<NonValueTensorType>() && type.isa<NonValueTensorType>() &&
|
||||
type ==
|
||||
NonValueTensorType::getWithLeastStaticInformation(type.getContext()))
|
||||
return true;
|
||||
auto subtypeTensorType = subtype.dyn_cast<BaseTensorType>();
|
||||
auto typeTensorType = type.dyn_cast<BaseTensorType>();
|
||||
if (subtypeTensorType && typeTensorType) {
|
||||
// Check that both tensors have the same `BaseTensorType` subtype.
|
||||
// TODO: This is not subtyping according to PEP 483. See description
|
||||
// of NonValueTensorType.
|
||||
if (subtypeTensorType.isa<ValueTensorType>() !=
|
||||
typeTensorType.isa<ValueTensorType>())
|
||||
return false;
|
||||
|
||||
// `type` must not have more static information than `subtype`, and `type`
|
||||
// must not disagree with `subtype`.
|
||||
if (typeTensorType.hasDtype() &&
|
||||
(!subtypeTensorType.hasDtype() ||
|
||||
typeTensorType.getDtype() != subtypeTensorType.getDtype())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (typeTensorType.hasSizes() &&
|
||||
(!subtypeTensorType.hasSizes() ||
|
||||
typeTensorType.getSizes() != subtypeTensorType.getSizes())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (subtype.isa<ValueTensorType>() && type.isa<ValueTensorType>() &&
|
||||
type == ValueTensorType::getWithLeastStaticInformation(type.getContext()))
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -463,7 +479,7 @@ Type Torch::meetTensorTypes(BaseTensorType lhs, BaseTensorType rhs) {
|
|||
}
|
||||
}
|
||||
|
||||
return lhs.getWithSizesAndDtype(makeArrayRef(newSizes), dtype);
|
||||
return lhs.getWithSizesAndDtype(ArrayRef(newSizes), dtype);
|
||||
}
|
||||
|
||||
////===----------------------------------------------------------------------===//
|
||||
|
@ -505,4 +521,4 @@ DictType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
|
|||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4088,6 +4088,259 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" } : (!torch.int, !torch.bool) -> ()\n"
|
||||
" return %none : !torch.none\n"
|
||||
" }\n"
|
||||
" func.func @__torch__.torch.jit._shape_functions.stack(%arg0: !torch.list<list<int>>, %arg1: !torch.int) -> !torch.list<int> {\n"
|
||||
" %false = torch.constant.bool false\n"
|
||||
" %str = torch.constant.str \"AssertionError: Tensors must have same number of dimensions\"\n"
|
||||
" %str_0 = torch.constant.str \"AssertionError: Sizes of tensors must match except in dimension\"\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %str_1 = torch.constant.str \"AssertionError: \"\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %0 = torch.prim.ListConstruct : () -> !torch.list<list<int>>\n"
|
||||
" %1 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int\n"
|
||||
" torch.prim.Loop %1, %true, init() {\n"
|
||||
" ^bb0(%arg2: !torch.int):\n"
|
||||
" %16 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
|
||||
" %17 = torch.aten.len.t %16 : !torch.list<int> -> !torch.int\n"
|
||||
" %18 = torch.aten.add.int %17, %int1 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %19 = torch.aten.le.int %18, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %20 = torch.prim.If %19 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %int1 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %18 : !torch.int\n"
|
||||
" }\n"
|
||||
" %21 = torch.aten.neg.int %20 : !torch.int -> !torch.int\n"
|
||||
" %22 = torch.aten.sub.int %20, %int1 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %23 = torch.aten.lt.int %arg1, %21 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %24 = torch.prim.If %23 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %31 = torch.aten.gt.int %arg1, %22 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %31 : !torch.bool\n"
|
||||
" }\n"
|
||||
" %25 = torch.aten.__not__ %24 : !torch.bool -> !torch.bool\n"
|
||||
" torch.prim.If %25 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %26 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %27 = torch.prim.If %26 -> (!torch.int) {\n"
|
||||
" %31 = torch.aten.add.int %arg1, %20 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" torch.prim.If.yield %31 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %arg1 : !torch.int\n"
|
||||
" }\n"
|
||||
" %28 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
||||
" %29 = torch.aten.len.t %16 : !torch.list<int> -> !torch.int\n"
|
||||
" torch.prim.Loop %29, %true, init() {\n"
|
||||
" ^bb0(%arg3: !torch.int):\n"
|
||||
" %31 = torch.aten.__getitem__.t %16, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %32 = torch.aten.append.t %28, %31 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
|
||||
" torch.prim.Loop.condition %true, iter()\n"
|
||||
" } : (!torch.int, !torch.bool) -> ()\n"
|
||||
" torch.aten.insert.t %28, %27, %int1 : !torch.list<int>, !torch.int, !torch.int\n"
|
||||
" %30 = torch.aten.append.t %0, %28 : !torch.list<list<int>>, !torch.list<int> -> !torch.list<list<int>>\n"
|
||||
" torch.prim.Loop.condition %true, iter()\n"
|
||||
" } : (!torch.int, !torch.bool) -> ()\n"
|
||||
" %2 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>\n"
|
||||
" %3 = torch.aten.len.t %0 : !torch.list<list<int>> -> !torch.int\n"
|
||||
" torch.prim.Loop %3, %true, init() {\n"
|
||||
" ^bb0(%arg2: !torch.int):\n"
|
||||
" %16 = torch.aten.__getitem__.t %0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
|
||||
" %17 = torch.aten.len.t %16 : !torch.list<int> -> !torch.int\n"
|
||||
" %18 = torch.aten.gt.int %17, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %18 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" torch.prim.Loop.condition %true, iter()\n"
|
||||
" } : (!torch.int, !torch.bool) -> ()\n"
|
||||
" %4 = torch.aten.len.t %0 : !torch.list<list<int>> -> !torch.int\n"
|
||||
" %5 = torch.derefine %none : !torch.none to !torch.optional<int>\n"
|
||||
" %6 = torch.prim.Loop %4, %true, init(%5) {\n"
|
||||
" ^bb0(%arg2: !torch.int, %arg3: !torch.optional<int>):\n"
|
||||
" %16 = torch.aten.__getitem__.t %0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
|
||||
" %17 = torch.aten.len.t %16 : !torch.list<int> -> !torch.int\n"
|
||||
" %18 = torch.aten.eq.int %17, %int1 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %19 = torch.prim.If %18 -> (!torch.bool) {\n"
|
||||
" %22 = torch.aten.__getitem__.t %16, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %23 = torch.aten.eq.int %22, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %23 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" %20 = torch.aten.__not__ %19 : !torch.bool -> !torch.bool\n"
|
||||
" %21 = torch.prim.If %20 -> (!torch.optional<int>) {\n"
|
||||
" %22 = torch.aten.__is__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
|
||||
" %23 = torch.prim.If %22 -> (!torch.int) {\n"
|
||||
" %25 = torch.aten.len.t %16 : !torch.list<int> -> !torch.int\n"
|
||||
" %26 = torch.aten.le.int %25, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %27 = torch.prim.If %26 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %int1 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %25 : !torch.int\n"
|
||||
" }\n"
|
||||
" %28 = torch.aten.neg.int %27 : !torch.int -> !torch.int\n"
|
||||
" %29 = torch.aten.sub.int %27, %int1 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %30 = torch.aten.lt.int %arg1, %28 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %31 = torch.prim.If %30 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %35 = torch.aten.gt.int %arg1, %29 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %35 : !torch.bool\n"
|
||||
" }\n"
|
||||
" %32 = torch.aten.__not__ %31 : !torch.bool -> !torch.bool\n"
|
||||
" torch.prim.If %32 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %33 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %34 = torch.prim.If %33 -> (!torch.int) {\n"
|
||||
" %35 = torch.aten.add.int %arg1, %27 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" torch.prim.If.yield %35 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %arg1 : !torch.int\n"
|
||||
" }\n"
|
||||
" torch.prim.If.yield %34 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" %25 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int\n"
|
||||
" torch.prim.If.yield %25 : !torch.int\n"
|
||||
" }\n"
|
||||
" %24 = torch.derefine %23 : !torch.int to !torch.optional<int>\n"
|
||||
" torch.prim.If.yield %24 : !torch.optional<int>\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %arg3 : !torch.optional<int>\n"
|
||||
" }\n"
|
||||
" torch.prim.Loop.condition %true, iter(%21 : !torch.optional<int>)\n"
|
||||
" } : (!torch.int, !torch.bool, !torch.optional<int>) -> !torch.optional<int>\n"
|
||||
" %7 = torch.aten.__is__ %6, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
|
||||
" %8 = torch.prim.If %7 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %arg1 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" %16 = torch.prim.unchecked_cast %6 : !torch.optional<int> -> !torch.int\n"
|
||||
" torch.prim.If.yield %16 : !torch.int\n"
|
||||
" }\n"
|
||||
" %9 = torch.aten.len.t %0 : !torch.list<list<int>> -> !torch.int\n"
|
||||
" %10 = torch.aten.gt.int %9, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %10 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %11 = torch.aten.len.t %0 : !torch.list<list<int>> -> !torch.int\n"
|
||||
" %12 = torch.derefine %none : !torch.none to !torch.optional<list<int>>\n"
|
||||
" %13 = torch.prim.Loop %11, %true, init(%12) {\n"
|
||||
" ^bb0(%arg2: !torch.int, %arg3: !torch.optional<list<int>>):\n"
|
||||
" %16 = torch.aten.__getitem__.t %0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
|
||||
" %17 = torch.aten.len.t %16 : !torch.list<int> -> !torch.int\n"
|
||||
" %18 = torch.prim.Loop %17, %true, init(%int1) {\n"
|
||||
" ^bb0(%arg4: !torch.int, %arg5: !torch.int):\n"
|
||||
" %23 = torch.aten.__getitem__.t %16, %arg4 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %24 = torch.aten.mul.int %arg5, %23 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" torch.prim.Loop.condition %true, iter(%24 : !torch.int)\n"
|
||||
" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n"
|
||||
" %19 = torch.aten.eq.int %18, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %20 = torch.prim.If %19 -> (!torch.bool) {\n"
|
||||
" %23 = torch.aten.len.t %16 : !torch.list<int> -> !torch.int\n"
|
||||
" %24 = torch.aten.eq.int %23, %int1 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %24 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" %21 = torch.aten.__not__ %20 : !torch.bool -> !torch.bool\n"
|
||||
" %22 = torch.prim.If %21 -> (!torch.optional<list<int>>) {\n"
|
||||
" %23 = torch.derefine %16 : !torch.list<int> to !torch.optional<list<int>>\n"
|
||||
" torch.prim.If.yield %23 : !torch.optional<list<int>>\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %arg3 : !torch.optional<list<int>>\n"
|
||||
" }\n"
|
||||
" torch.prim.Loop.condition %true, iter(%22 : !torch.optional<list<int>>)\n"
|
||||
" } : (!torch.int, !torch.bool, !torch.optional<list<int>>) -> !torch.optional<list<int>>\n"
|
||||
" %14 = torch.aten.__is__ %13, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
|
||||
" %15 = torch.prim.If %14 -> (!torch.list<int>) {\n"
|
||||
" torch.prim.If.yield %2 : !torch.list<int>\n"
|
||||
" } else {\n"
|
||||
" %16 = torch.prim.unchecked_cast %13 : !torch.optional<list<int>> -> !torch.list<int>\n"
|
||||
" %17 = torch.aten.len.t %0 : !torch.list<list<int>> -> !torch.int\n"
|
||||
" %18 = torch.prim.Loop %17, %true, init(%int0) {\n"
|
||||
" ^bb0(%arg2: !torch.int, %arg3: !torch.int):\n"
|
||||
" %22 = torch.aten.__getitem__.t %0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
|
||||
" %23 = torch.aten.len.t %22 : !torch.list<int> -> !torch.int\n"
|
||||
" %24 = torch.prim.Loop %23, %true, init(%int1) {\n"
|
||||
" ^bb0(%arg4: !torch.int, %arg5: !torch.int):\n"
|
||||
" %29 = torch.aten.__getitem__.t %22, %arg4 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %30 = torch.aten.mul.int %arg5, %29 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" torch.prim.Loop.condition %true, iter(%30 : !torch.int)\n"
|
||||
" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n"
|
||||
" %25 = torch.aten.eq.int %24, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %26 = torch.prim.If %25 -> (!torch.bool) {\n"
|
||||
" %29 = torch.aten.len.t %22 : !torch.list<int> -> !torch.int\n"
|
||||
" %30 = torch.aten.eq.int %29, %int1 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %30 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" %27 = torch.aten.__not__ %26 : !torch.bool -> !torch.bool\n"
|
||||
" %28 = torch.prim.If %27 -> (!torch.int) {\n"
|
||||
" %29 = torch.aten.len.t %16 : !torch.list<int> -> !torch.int\n"
|
||||
" %30 = torch.aten.len.t %22 : !torch.list<int> -> !torch.int\n"
|
||||
" %31 = torch.aten.eq.int %29, %30 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %31 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %32 = torch.aten.__range_length %int0, %29, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
|
||||
" torch.prim.Loop %32, %true, init() {\n"
|
||||
" ^bb0(%arg4: !torch.int):\n"
|
||||
" %35 = torch.aten.__derive_index %arg4, %int0, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
|
||||
" %36 = torch.aten.ne.int %35, %8 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %36 -> () {\n"
|
||||
" %37 = torch.aten.__getitem__.t %16, %35 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %38 = torch.aten.__getitem__.t %22, %35 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %39 = torch.aten.eq.int %37, %38 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %39 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" torch.prim.Loop.condition %true, iter()\n"
|
||||
" } : (!torch.int, !torch.bool) -> ()\n"
|
||||
" %33 = torch.aten.__getitem__.t %22, %8 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %34 = torch.aten.add.int %arg3, %33 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" torch.prim.If.yield %34 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %arg3 : !torch.int\n"
|
||||
" }\n"
|
||||
" torch.prim.Loop.condition %true, iter(%28 : !torch.int)\n"
|
||||
" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n"
|
||||
" %19 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
||||
" %20 = torch.aten.len.t %16 : !torch.list<int> -> !torch.int\n"
|
||||
" torch.prim.Loop %20, %true, init() {\n"
|
||||
" ^bb0(%arg2: !torch.int):\n"
|
||||
" %22 = torch.aten.__getitem__.t %16, %arg2 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %23 = torch.aten.append.t %19, %22 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
|
||||
" torch.prim.Loop.condition %true, iter()\n"
|
||||
" } : (!torch.int, !torch.bool) -> ()\n"
|
||||
" %21 = torch.aten._set_item.t %19, %8, %18 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield %19 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" return %15 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @__torch__.torch.jit._shape_functions.permute(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
|
@ -5790,6 +6043,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.hardtanh_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.float) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.ceil\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
@ -5877,6 +6134,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.bucketize.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.contiguous\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
@ -5993,7 +6254,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.prims.var\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.int, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.prims.var\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.float, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %false = torch.constant.bool false\n"
|
||||
" %0 = torch.derefine %none : !torch.none to !torch.any\n"
|
||||
|
@ -6006,13 +6267,13 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
|
||||
" return %1 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.var.correction\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<int>, %arg3: !torch.bool) -> !torch.list<int> {\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.var.correction\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<float>, %arg3: !torch.bool) -> !torch.list<int> {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0 = torch.derefine %none : !torch.none to !torch.any\n"
|
||||
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
|
||||
" return %1 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.var_mean.correction\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<int>, %arg3: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.var_mean.correction\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<float>, %arg3: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0 = torch.derefine %none : !torch.none to !torch.any\n"
|
||||
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
|
||||
|
@ -6035,7 +6296,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
|
||||
" return %1 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.std.correction\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<int>, %arg3: !torch.bool) -> !torch.list<int> {\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.std.correction\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<float>, %arg3: !torch.bool) -> !torch.list<int> {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0 = torch.derefine %none : !torch.none to !torch.any\n"
|
||||
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
|
||||
|
@ -6549,6 +6810,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" func.func @\"__torch_mlir_shape_fn.aten.new_empty\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.list<int> {\n"
|
||||
" return %arg1 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.new_empty_strided\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
|
||||
" return %arg1 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten._to_copy\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
@ -6583,6 +6847,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" func.func @\"__torch_mlir_shape_fn.aten.bernoulli.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.any) -> !torch.list<int> {\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.bernoulli.p\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.any) -> !torch.list<int> {\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten._index_put_impl\"(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>, %arg2: !torch.list<int>, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
@ -6596,6 +6863,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" func.func @\"__torch_mlir_shape_fn.aten.rand_like\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.randn_like\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.randint.low\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
|
||||
" return %arg2 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
|
@ -6881,6 +7151,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" func.func @\"__torch_mlir_shape_fn.aten.select_scatter\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.scatter_reduce.two\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.str, %arg5: !torch.bool) -> !torch.list<int> {\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.index_select\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.index_select(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
@ -7310,6 +7583,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = call @__torch__.torch.jit._shape_functions.cat(%arg0, %arg1) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.stack\"(%arg0: !torch.list<list<int>>, %arg1: !torch.int) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.stack(%arg0, %arg1) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.fft_fft\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.list<int> {\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
|
@ -7340,6 +7617,13 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
|
||||
" return %2 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.norm.ScalarOpt_dim\"(%arg0: !torch.list<int>, %arg1: !torch.optional<float>, %arg2: !torch.list<int>, %arg3: !torch.bool) -> !torch.list<int> {\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %0 = torch.derefine %arg2 : !torch.list<int> to !torch.optional<list<int>>\n"
|
||||
" %1 = torch.derefine %int0 : !torch.int to !torch.any\n"
|
||||
" %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg3, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
|
||||
" return %2 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<float>, %arg3: !torch.optional<float>) -> !torch.list<int> {\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
|
@ -7855,6 +8139,30 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %int11 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.ge.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %int11 = torch.constant.int 11\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: `self` cannot be complex\"\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
|
||||
" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n"
|
||||
" torch.prim.If %3 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
|
||||
" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n"
|
||||
" torch.prim.If %5 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" return %int11 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.le.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" %int11 = torch.constant.int 11\n"
|
||||
" %none = torch.constant.none\n"
|
||||
|
@ -7925,6 +8233,30 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %int11 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.le.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %int11 = torch.constant.int 11\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: `self` cannot be complex\"\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
|
||||
" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n"
|
||||
" torch.prim.If %3 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
|
||||
" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n"
|
||||
" torch.prim.If %5 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" return %int11 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.ne.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" %int11 = torch.constant.int 11\n"
|
||||
" return %int11 : !torch.int\n"
|
||||
|
@ -8644,7 +8976,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %false = torch.constant.bool false\n"
|
||||
" %int15 = torch.constant.int 15\n"
|
||||
" %int5 = torch.constant.int 5\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %int4 = torch.constant.int 4\n"
|
||||
|
@ -8659,7 +8990,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" torch.prim.If.yield %12 : !torch.bool\n"
|
||||
" }\n"
|
||||
" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
|
||||
" %11 = torch.prim.ListConstruct %int5, %int15 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %11 = torch.prim.ListConstruct %int5 : (!torch.int) -> !torch.list<int>\n"
|
||||
" %12 = torch.aten.__contains__.int_list %11, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
|
||||
" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n"
|
||||
" torch.prim.If.yield %13 : !torch.bool\n"
|
||||
|
@ -8681,7 +9012,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" torch.prim.If.yield %12 : !torch.bool\n"
|
||||
" }\n"
|
||||
" %7 = torch.prim.If %6 -> (!torch.bool) {\n"
|
||||
" %11 = torch.prim.ListConstruct %int5, %int15 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %11 = torch.prim.ListConstruct %int5 : (!torch.int) -> !torch.list<int>\n"
|
||||
" %12 = torch.aten.__contains__.int_list %11, %1#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
|
||||
" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n"
|
||||
" torch.prim.If.yield %13 : !torch.bool\n"
|
||||
|
|
|
@ -10,7 +10,6 @@
|
|||
#include "PassDetail.h"
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
|
|
@ -9,6 +9,7 @@ add_mlir_library(TorchMLIRTorchPasses
|
|||
LowerToBackendContract.cpp
|
||||
MaximizeValueSemantics.cpp
|
||||
PrepareForGlobalizeObjectGraph.cpp
|
||||
RecomposeComplexOps.cpp
|
||||
ReduceOpVariants.cpp
|
||||
RefinePublicReturn.cpp
|
||||
RefineTypes.cpp
|
||||
|
|
|
@ -33,9 +33,11 @@ static bool isNoneOrFloatDtype(MLIRContext *context, Value dtype) {
|
|||
int64_t dtypeInt;
|
||||
if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
|
||||
return false;
|
||||
Type resDtype =
|
||||
FailureOr<Type> resDtype =
|
||||
getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
|
||||
return resDtype.isa<mlir::FloatType>();
|
||||
if (failed(resDtype))
|
||||
return false;
|
||||
return resDtype->isa<mlir::FloatType>();
|
||||
}
|
||||
|
||||
// Helper function to compute the return type of the reduction function.
|
||||
|
@ -70,7 +72,7 @@ static Type computeReductionType(PatternRewriter &rewriter, Operation *op,
|
|||
|
||||
Type resultType = tensorType.getWithSizesAndDtype(
|
||||
sizes.size() == 0 ? std::optional<ArrayRef<int64_t>>()
|
||||
: llvm::makeArrayRef(sizes),
|
||||
: llvm::ArrayRef(sizes),
|
||||
tensorType.getOptionalDtype());
|
||||
return resultType;
|
||||
}
|
||||
|
@ -106,7 +108,7 @@ static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc,
|
|||
valueType
|
||||
.getWithSizesAndDtype(
|
||||
!valueType.hasSizes() ? std::optional<ArrayRef<int64_t>>()
|
||||
: llvm::makeArrayRef(valueType.getSizes()),
|
||||
: llvm::ArrayRef(valueType.getSizes()),
|
||||
IntegerType::get(op->getContext(), 64, IntegerType::Signed))
|
||||
.cast<BaseTensorType>();
|
||||
return rewriter
|
||||
|
@ -140,7 +142,7 @@ static Value createRank0Tensor(PatternRewriter &rewriter, Location loc,
|
|||
BaseTensorType inputType, Value scalar) {
|
||||
SmallVector<int64_t> sizes;
|
||||
Type rank0TensorTy = inputType.getWithSizesAndDtype(
|
||||
makeArrayRef(sizes), inputType.getOptionalDtype());
|
||||
ArrayRef(sizes), inputType.getOptionalDtype());
|
||||
Value dimList = rewriter.create<PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(Torch::IntType::get(inputType.getContext())),
|
||||
ValueRange{});
|
||||
|
@ -169,6 +171,37 @@ static Value createSoftmaxBackwardCommonKernel(PatternRewriter &rewriter,
|
|||
return sub;
|
||||
}
|
||||
|
||||
// Helper function to unsqueeze the input tensor at given dim.
|
||||
// Return the unsqueezed tensor or failure.
|
||||
static FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter,
|
||||
Operation *op, Value input, Value dim) {
|
||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
||||
if (!inputType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(op, "input tensor must have size");
|
||||
}
|
||||
|
||||
SmallVector<int64_t> unsqueezedShape;
|
||||
ArrayRef<int64_t> inputShape = inputType.getSizes();
|
||||
// `input` has a reduced rank. Hence add 1.
|
||||
int64_t unsqueezedRank = inputShape.size() + 1;
|
||||
int64_t dimInt = 0;
|
||||
if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
|
||||
dimInt = toPositiveDim(dimInt, unsqueezedRank);
|
||||
if (!isValidDim(dimInt, unsqueezedRank)) {
|
||||
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
|
||||
}
|
||||
unsqueezedShape.append(inputShape.begin(), inputShape.end());
|
||||
unsqueezedShape.insert(unsqueezedShape.begin() + dimInt, 1);
|
||||
} else {
|
||||
unsqueezedShape.resize(unsqueezedRank, kUnknownSize);
|
||||
}
|
||||
Type unsqueezedType = inputType.getWithSizesAndDtype(
|
||||
unsqueezedShape, inputType.getOptionalDtype());
|
||||
Value unsqueezed = rewriter.create<AtenUnsqueezeOp>(
|
||||
op->getLoc(), unsqueezedType, input, dim);
|
||||
return unsqueezed;
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// We decompose aten.amax into a set of aten.max.dim op(s) depending on the
|
||||
/// number of dimensions across which the max needs to be computed.
|
||||
|
@ -258,6 +291,15 @@ public:
|
|||
Value dim = op.getDim();
|
||||
Value self = op.getSelf();
|
||||
|
||||
// convert `start` to non-negative: start += int(start < 0) * dimSize
|
||||
Value zero =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
||||
Value isNegative = rewriter.create<AtenLtIntOp>(loc, start, zero);
|
||||
isNegative = rewriter.create<AtenIntBoolOp>(loc, isNegative);
|
||||
Value dimSize = rewriter.create<AtenSizeIntOp>(loc, self, dim);
|
||||
Value indexOffset = rewriter.create<AtenMulIntOp>(loc, isNegative, dimSize);
|
||||
start = rewriter.create<AtenAddIntOp>(loc, start, indexOffset);
|
||||
|
||||
Value one =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||
Value startPlusOne =
|
||||
|
@ -595,6 +637,128 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// Decompose `aten.bucketize` into the following op sequence:
|
||||
//
|
||||
// def aten_bucketize(input, boundaries, out_int32, right):
|
||||
// unsqz_input = input.unsqueeze(-1)
|
||||
// if not right:
|
||||
// comparison = unsqz_input <= boundaries
|
||||
// else:
|
||||
// comparison = unsqz_input < boundaries
|
||||
// indices = torch.argmax(comparison.float(), dim=-1)
|
||||
// within_bound = comparison[..., -1]
|
||||
// result = torch.where(within_bound, indices, boundaries.shape[0])
|
||||
// if out_int32:
|
||||
// result = result.int()
|
||||
// return result
|
||||
//
|
||||
namespace {
|
||||
class DecomposeAtenBucketizeTensorOp
|
||||
: public OpRewritePattern<AtenBucketizeTensorOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenBucketizeTensorOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
|
||||
Value input = op.getSelf();
|
||||
auto inputType = input.getType().cast<BaseTensorType>();
|
||||
if (!inputType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: input must have known sizes");
|
||||
}
|
||||
ArrayRef<int64_t> inputShape = inputType.getSizes();
|
||||
|
||||
Value boundaries = op.getBoundaries();
|
||||
auto boundariesType = boundaries.getType().cast<BaseTensorType>();
|
||||
if (!boundariesType.hasSizes() || boundariesType.getSizes().size() != 1) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"unimplemented: boundaries must have "
|
||||
"known sizes and must be a 1D array");
|
||||
}
|
||||
int64_t boundariesSize = boundariesType.getSizes()[0];
|
||||
|
||||
bool outInt32;
|
||||
if (!matchPattern(op.getOutInt32(), m_TorchConstantBool(&outInt32))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: out_int32 must be a constant bool");
|
||||
}
|
||||
|
||||
bool right;
|
||||
if (!matchPattern(op.getRight(), m_TorchConstantBool(&right))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: right must be a constant bool");
|
||||
}
|
||||
|
||||
// unsqueeze input at the last dim to make it broadcastable with boundaries
|
||||
Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(-1));
|
||||
auto unsqzTensorInfo =
|
||||
unsqueezeTensor(rewriter, op, input, /*dim=*/constMinusOne);
|
||||
if (failed(unsqzTensorInfo)) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"cannot generate unsqueeze tensor");
|
||||
}
|
||||
Value unsqzInput = *unsqzTensorInfo;
|
||||
|
||||
// compare unsqueezed input with boundaries
|
||||
SmallVector<int64_t> compareShape(inputShape);
|
||||
compareShape.push_back(boundariesSize);
|
||||
Type compareType =
|
||||
inputType.getWithSizesAndDtype(compareShape, rewriter.getI1Type());
|
||||
Value compare;
|
||||
if (!right) {
|
||||
compare = rewriter.create<AtenLeTensorOp>(loc, compareType, unsqzInput,
|
||||
boundaries);
|
||||
} else {
|
||||
compare = rewriter.create<AtenLtTensorOp>(loc, compareType, unsqzInput,
|
||||
boundaries);
|
||||
}
|
||||
|
||||
// convert the comparison results to float32 as the argmax op input,
|
||||
// which does not support integer dtype in LINALG backend
|
||||
Value compareF32 =
|
||||
convertTensorToDtype(rewriter, loc, compare, rewriter.getF32Type());
|
||||
|
||||
// get the first boundary index where the input element is less than (or
|
||||
// equal to) the boundary value
|
||||
Type indicesType = inputType.getWithSizesAndDtype(
|
||||
inputShape, rewriter.getIntegerType(64, IntegerType::Signed));
|
||||
Value constFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
|
||||
Value indices = rewriter.create<AtenArgmaxOp>(loc, indicesType, compareF32,
|
||||
/*dim=*/constMinusOne,
|
||||
/*keepdim=*/constFalse);
|
||||
|
||||
// get the comparison results between each input element and the rightmost
|
||||
// boundary value
|
||||
Type withinUpperBoundType =
|
||||
inputType.getWithSizesAndDtype(inputShape, rewriter.getI1Type());
|
||||
Value withinUpperBound = rewriter.create<AtenSelectIntOp>(
|
||||
loc, withinUpperBoundType, compare, /*dim=*/constMinusOne,
|
||||
/*index=*/constMinusOne);
|
||||
|
||||
// If the input element is less than (or equal to) the rightmost boundary,
|
||||
// take the max index as result. Otherwise, the element is beyond the
|
||||
// rightmost boundary, so take the boundary size.
|
||||
Value constZero = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(0));
|
||||
Value upperBound =
|
||||
rewriter.create<AtenSizeIntOp>(loc, boundaries, /*dim=*/constZero);
|
||||
Value result = rewriter.create<AtenWhereScalarOtherOp>(
|
||||
loc, indicesType, withinUpperBound, indices, upperBound);
|
||||
|
||||
if (outInt32) {
|
||||
result = convertTensorToDtype(
|
||||
rewriter, loc, result,
|
||||
rewriter.getIntegerType(32, IntegerType::Signed));
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// To avoid overflow we use the following decomposition rule:
|
||||
// x_max = aten.max(x, dim, keepdim=True)[0]
|
||||
// shifted = x - x_max
|
||||
|
@ -891,6 +1055,50 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// Decompose `aten.stack` into `aten.unsqueeze` and `aten.cat`.
|
||||
namespace {
|
||||
class DecomposeAtenStackOp : public OpRewritePattern<AtenStackOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenStackOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
SmallVector<Value> tensors;
|
||||
if (!getListConstructElements(op.getTensors(), tensors)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: the tensor list is not from list construct");
|
||||
}
|
||||
// Ensure all tensors have known sizes
|
||||
for (Value tensor : tensors) {
|
||||
BaseTensorType tensorType = tensor.getType().cast<BaseTensorType>();
|
||||
if (!tensorType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: one tensor does not have known sizes");
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<Value> unsqueezedTensors;
|
||||
for (Value tensor : tensors) {
|
||||
auto unsqueezedInfo = unsqueezeTensor(rewriter, op, tensor, op.getDim());
|
||||
if (failed(unsqueezedInfo)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "cannot generate unsqueeze tensor op");
|
||||
}
|
||||
unsqueezedTensors.push_back(*unsqueezedInfo);
|
||||
}
|
||||
|
||||
Type listElemType =
|
||||
op.getType().cast<BaseTensorType>().getWithSizesAndDtype(
|
||||
/*optionalSizes=*/std::nullopt, /*optionalDtype=*/nullptr);
|
||||
Type listType = Torch::ListType::get(listElemType);
|
||||
Value unsqueezedTensorList = rewriter.create<PrimListConstructOp>(
|
||||
op.getLoc(), listType, unsqueezedTensors);
|
||||
rewriter.replaceOpWithNewOp<AtenCatOp>(op, op.getType(),
|
||||
unsqueezedTensorList, op.getDim());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Decompose aten.roll into aten.slice and aten.cat ops.
|
||||
// https://pytorch.org/docs/stable/generated/torch.roll.html
|
||||
namespace {
|
||||
|
@ -929,7 +1137,7 @@ public:
|
|||
SmallVector<int64_t> sizes;
|
||||
sizes.append(inputShape.begin(), inputShape.end());
|
||||
sizes[cstDim] = kUnknownSize;
|
||||
Type sliceTy = selfTy.getWithSizesAndDtype(llvm::makeArrayRef(sizes),
|
||||
Type sliceTy = selfTy.getWithSizesAndDtype(llvm::ArrayRef(sizes),
|
||||
selfTy.getOptionalDtype());
|
||||
Value slice0 = rewriter.create<AtenSliceTensorOp>(
|
||||
loc, sliceTy, input, dim, negShift, constNone, constOne);
|
||||
|
@ -1066,9 +1274,9 @@ public:
|
|||
|
||||
Type dtype = self.getType().cast<ValueTensorType>().getOptionalDtype();
|
||||
Type unsqueezedType = ValueTensorType::get(
|
||||
context, llvm::makeArrayRef(unsqueezedIntSizes), dtype);
|
||||
Type expandedType = ValueTensorType::get(
|
||||
context, llvm::makeArrayRef(expandedIntSizes), dtype);
|
||||
context, llvm::ArrayRef(unsqueezedIntSizes), dtype);
|
||||
Type expandedType =
|
||||
ValueTensorType::get(context, llvm::ArrayRef(expandedIntSizes), dtype);
|
||||
|
||||
auto listType = Torch::ListType::get(Torch::IntType::get(op.getContext()));
|
||||
Value unsqueezedDims =
|
||||
|
@ -1226,6 +1434,25 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// Decompose aten.masked_fill.Scalar into aten.where.self op.
|
||||
namespace {
|
||||
class DecomposeAtenMaskedFillScalarOp
|
||||
: public OpRewritePattern<AtenMaskedFillScalarOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenMaskedFillScalarOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
auto resType = op.getType().cast<BaseTensorType>();
|
||||
Value mask = op.getMask();
|
||||
Value value = createRank0Tensor(rewriter, loc, resType, op.getValue());
|
||||
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, mask,
|
||||
value, op.getSelf());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
// Decompose aten.convolution_overrideable to aten.convolution op.
|
||||
namespace {
|
||||
class DecomposeAtenConvolutionOverrideableOp
|
||||
|
@ -1977,23 +2204,23 @@ public:
|
|||
// aten.bernoulli.float(x, p) = (randLike(float(x)) < tensor(p)).cast(type(x)).
|
||||
// Since the input x can be an integer tensor, it's important to cast it to
|
||||
// float type before passing it to the `aten.randLike` op.
|
||||
class DecomposeValsemVariantAtenBernoulliFloatOp
|
||||
: public OpRewritePattern<ValsemVariantAtenBernoulliFloatOp> {
|
||||
template <typename BernoulliLikeOp>
|
||||
class DecomposeAtenBernoulliLikeOp : public OpRewritePattern<BernoulliLikeOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(ValsemVariantAtenBernoulliFloatOp op,
|
||||
using OpRewritePattern<BernoulliLikeOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(BernoulliLikeOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Value input = op.getSelf();
|
||||
Value p = op.getP();
|
||||
if (!op.getGenerator().getType().isa<Torch::NoneType>())
|
||||
if (!op.getGenerator().getType().template isa<Torch::NoneType>())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "The generator has to ben None because only global default "
|
||||
"generator is supported");
|
||||
|
||||
auto inputType = input.getType().cast<BaseTensorType>();
|
||||
SmallVector<int64_t> empty;
|
||||
Type tensorType = inputType.getWithSizesAndDtype(llvm::makeArrayRef(empty),
|
||||
Type tensorType = inputType.getWithSizesAndDtype(llvm::ArrayRef(empty),
|
||||
rewriter.getF64Type());
|
||||
Value prob = rewriter.create<PrimNumToTensorScalarOp>(loc, tensorType, p);
|
||||
Value output;
|
||||
|
@ -2071,8 +2298,8 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
|
|||
std::vector<int64_t> meanVarSizes(inputRank, 1);
|
||||
for (int i = 0; i < axis; i++)
|
||||
meanVarSizes[i] = input.getSizes()[i];
|
||||
auto meanVarType = input.getWithSizesAndDtype(
|
||||
llvm::makeArrayRef(meanVarSizes), input.getOptionalDtype());
|
||||
auto meanVarType = input.getWithSizesAndDtype(llvm::ArrayRef(meanVarSizes),
|
||||
input.getOptionalDtype());
|
||||
auto nativeLayerNorm = rewriter.create<AtenNativeLayerNormOp>(
|
||||
loc, op.getType(), meanVarType, meanVarType, op.getInput(),
|
||||
op.getNormalizedShape(), op.getWeight(), op.getBias(), op.getEps());
|
||||
|
@ -2309,7 +2536,7 @@ class DecomposeAtenNativeBatchNormOp
|
|||
runningStatsShapeInt[1] = kUnknownSize;
|
||||
Type dtype = input.getType().cast<ValueTensorType>().getOptionalDtype();
|
||||
Type reshapeType = ValueTensorType::get(
|
||||
context, llvm::makeArrayRef(runningStatsShapeInt), dtype);
|
||||
context, llvm::ArrayRef(runningStatsShapeInt), dtype);
|
||||
|
||||
runningMean = rewriter.create<AtenViewOp>(loc, reshapeType, runningMean,
|
||||
runningStatsSizeList);
|
||||
|
@ -2455,8 +2682,7 @@ public:
|
|||
SmallVector<int64_t> empty;
|
||||
auto dtype =
|
||||
getTypeForTorchType(op.getContext(), op.getFillValue().getType());
|
||||
Type tensorType =
|
||||
outTy.getWithSizesAndDtype(llvm::makeArrayRef(empty), dtype);
|
||||
Type tensorType = outTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype);
|
||||
Value fillVal = rewriter.create<PrimNumToTensorScalarOp>(loc, tensorType,
|
||||
op.getFillValue());
|
||||
fillVal = convertTensorToDtype(rewriter, loc, fillVal, outTy.getDtype());
|
||||
|
@ -2492,7 +2718,7 @@ public:
|
|||
SmallVector<int64_t> transposeShape =
|
||||
llvm::to_vector(llvm::reverse(weightType.getSizes()));
|
||||
Type transposeType = weightType.getWithSizesAndDtype(
|
||||
llvm::makeArrayRef(transposeShape), weightType.getOptionalDtype());
|
||||
llvm::ArrayRef(transposeShape), weightType.getOptionalDtype());
|
||||
Value transposeWeight =
|
||||
rewriter.create<AtenTOp>(loc, transposeType, weight);
|
||||
|
||||
|
@ -2562,8 +2788,7 @@ public:
|
|||
SmallVector<int64_t> empty;
|
||||
auto dtype =
|
||||
getTypeForTorchType(op.getContext(), op.getFillValue().getType());
|
||||
Type tensorType =
|
||||
outTy.getWithSizesAndDtype(llvm::makeArrayRef(empty), dtype);
|
||||
Type tensorType = outTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype);
|
||||
Value fillVal = rewriter.create<PrimNumToTensorScalarOp>(
|
||||
op.getLoc(), tensorType, op.getFillValue());
|
||||
fillVal =
|
||||
|
@ -3003,7 +3228,7 @@ class DecomposeAtenNumpyTOp : public OpRewritePattern<AtenNumpyTOp> {
|
|||
|
||||
template <typename OpTy>
|
||||
static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
|
||||
bool unbiased, int64_t correction) {
|
||||
bool unbiased, double correction) {
|
||||
Location loc = op.getLoc();
|
||||
Value self = op.getSelf();
|
||||
Value dimList = op.getDim();
|
||||
|
@ -3089,19 +3314,22 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
|
|||
productDimSize =
|
||||
rewriter.create<AtenMulIntOp>(loc, productDimSize, dimSize);
|
||||
}
|
||||
Value cstCorrection = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(correction));
|
||||
productDimSize = rewriter.create<AtenFloatScalarOp>(loc, productDimSize);
|
||||
constantOne = rewriter.create<Torch::ConstantFloatOp>(
|
||||
loc, rewriter.getF64FloatAttr(1.0));
|
||||
Value cstCorrection = rewriter.create<Torch::ConstantFloatOp>(
|
||||
loc, rewriter.getF64FloatAttr(correction));
|
||||
// The `correction` value should be less than or equal to `productDimSize +
|
||||
// 1`.
|
||||
Value productDimSizePlusOne =
|
||||
rewriter.create<AtenAddIntOp>(loc, productDimSize, constantOne);
|
||||
Value productDimSizePlusOne = rewriter.create<AtenAddOp>(
|
||||
loc, productDimSize.getType(), productDimSize, constantOne);
|
||||
Value cond =
|
||||
rewriter.create<AtenGeIntOp>(loc, productDimSizePlusOne, cstCorrection);
|
||||
rewriter.create<AtenGeFloatOp>(loc, productDimSizePlusOne, cstCorrection);
|
||||
rewriter.create<RuntimeAssertOp>(
|
||||
loc, cond,
|
||||
"correction value should be less than or equal to productDimSize + 1");
|
||||
Value productDimSizeSubCorrection =
|
||||
rewriter.create<AtenSubIntOp>(loc, productDimSize, cstCorrection);
|
||||
rewriter.create<AtenSubFloatOp>(loc, productDimSize, cstCorrection);
|
||||
Value result = rewriter.create<AtenDivScalarOp>(loc, newOutputType, squareSum,
|
||||
productDimSizeSubCorrection);
|
||||
result =
|
||||
|
@ -3128,7 +3356,7 @@ public:
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "Only support constant unbiased for aten.var");
|
||||
}
|
||||
int64_t correction = unbiased ? 1 : 0;
|
||||
double correction = unbiased ? 1.0 : 0.0;
|
||||
if (failed(calculateVariance<AtenVarDimOp>(op, rewriter, unbiased,
|
||||
correction)))
|
||||
return rewriter.notifyMatchFailure(op, "invalid variance parameters");
|
||||
|
@ -3148,18 +3376,32 @@ public:
|
|||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenVarCorrectionOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
int64_t correction;
|
||||
int64_t correctionValInt;
|
||||
double correctionValFloat = 1.0;
|
||||
if (!op.getCorrection().getType().isa<Torch::NoneType>()) {
|
||||
if (!matchPattern(op.getCorrection(), m_TorchConstantInt(&correction)))
|
||||
if (op.getCorrection().getType().isa<Torch::FloatType>()) {
|
||||
if (!matchPattern(op.getCorrection(),
|
||||
m_TorchConstantFloat(&correctionValFloat)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only support constant int or float correction value for "
|
||||
"aten.var");
|
||||
} else if (op.getCorrection().getType().isa<Torch::IntType>()) {
|
||||
if (!matchPattern(op.getCorrection(),
|
||||
m_TorchConstantInt(&correctionValInt)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only support constant int or float correction value for "
|
||||
"aten.var");
|
||||
correctionValFloat = (double)correctionValInt;
|
||||
} else {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only support constant int correction for aten.var");
|
||||
} else {
|
||||
// The default value in case of `correction` being None is 1.
|
||||
correction = 1;
|
||||
op, "unimplemented: correction value should be only constant int "
|
||||
"or float for aten.var");
|
||||
}
|
||||
}
|
||||
bool unbiased = correction == 0 ? false : true;
|
||||
|
||||
bool unbiased = correctionValFloat == 0.0 ? false : true;
|
||||
if (failed(calculateVariance<AtenVarCorrectionOp>(op, rewriter, unbiased,
|
||||
correction)))
|
||||
correctionValFloat)))
|
||||
return rewriter.notifyMatchFailure(op, "invalid variance parameters");
|
||||
return success();
|
||||
}
|
||||
|
@ -3184,29 +3426,13 @@ public:
|
|||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||
Value startPlusOne =
|
||||
rewriter.create<AtenAddIntOp>(loc, one.getType(), start, one);
|
||||
BaseTensorType srcTensorType = src.getType().cast<BaseTensorType>();
|
||||
SmallVector<int64_t> sizes;
|
||||
if (!srcTensorType.hasSizes())
|
||||
return rewriter.notifyMatchFailure(op, "src tensor must have size");
|
||||
|
||||
ArrayRef<int64_t> srcShape = srcTensorType.getSizes();
|
||||
// `src` has a reduced rank. Hence add 1.
|
||||
int64_t srcRank = srcShape.size() + 1;
|
||||
int64_t dimInt = 0;
|
||||
if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
|
||||
dimInt = toPositiveDim(dimInt, srcRank);
|
||||
if (!isValidDim(dimInt, srcRank))
|
||||
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
|
||||
|
||||
sizes.append(srcShape.begin(), srcShape.end());
|
||||
sizes.insert(sizes.begin() + dimInt, 1);
|
||||
|
||||
} else {
|
||||
sizes.resize(srcShape.size() + 1, kUnknownSize);
|
||||
auto unsqueezedInfo = unsqueezeTensor(rewriter, op, src, dim);
|
||||
if (failed(unsqueezedInfo)) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"cannot generate unsqueeze tensor op");
|
||||
}
|
||||
Type srcType = srcTensorType.getWithSizesAndDtype(
|
||||
llvm::makeArrayRef(sizes), srcTensorType.getOptionalDtype());
|
||||
src = rewriter.create<AtenUnsqueezeOp>(loc, srcType, src, dim);
|
||||
src = *unsqueezedInfo;
|
||||
rewriter.replaceOpWithNewOp<AtenSliceScatterOp>(
|
||||
op, op.getSelf().getType(), self, src, dim, start, startPlusOne,
|
||||
/*step=*/one);
|
||||
|
@ -3303,7 +3529,7 @@ public:
|
|||
op, "Expected the input tensor to have sizes");
|
||||
BaseTensorType subType =
|
||||
inputType
|
||||
.getWithSizesAndDtype(llvm::makeArrayRef(inputType.getSizes()),
|
||||
.getWithSizesAndDtype(llvm::ArrayRef(inputType.getSizes()),
|
||||
resultType.getOptionalDtype())
|
||||
.cast<BaseTensorType>();
|
||||
|
||||
|
@ -3330,6 +3556,29 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose `aten.norm.ScalarOpt_dim` op to `aten.linalg_vector_norm` op
|
||||
class DecomposeAtenNormScalarOptDimOp
|
||||
: public OpRewritePattern<AtenNormScalarOptDimOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenNormScalarOptDimOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
|
||||
Value ord = op.getP();
|
||||
if (ord.getType().isa<Torch::NoneType>()) {
|
||||
ord = rewriter.create<Torch::ConstantFloatOp>(
|
||||
loc, rewriter.getF64FloatAttr(2.0));
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<AtenLinalgVectorNormOp>(
|
||||
op, op.getType(), op.getSelf(), ord, op.getDim(), op.getKeepdim(),
|
||||
/*dtype=*/none);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenRandintLowOp : public OpRewritePattern<AtenRandintLowOp> {
|
||||
public:
|
||||
|
@ -3526,6 +3775,40 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose `aten.randn_like` op into `aten.randn.generator` op.
|
||||
class DecomposeAtenRandnLikeOp : public OpRewritePattern<AtenRandnLikeOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenRandnLikeOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Only `none`, `contiguous` and `preserve` memory_format is supported.
|
||||
if (!op.getMemoryFormat().getType().isa<Torch::NoneType>()) {
|
||||
int64_t memoryFormat;
|
||||
if (!matchPattern(op.getMemoryFormat(),
|
||||
m_TorchConstantInt(&memoryFormat)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: the memory format should be specified in "
|
||||
"an integer constant");
|
||||
if (memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
|
||||
memoryFormat != torch_upstream::MemoryFormat::Preserve)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: only none, contiguous and preserve "
|
||||
"memory_format is supported");
|
||||
}
|
||||
Value none = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
|
||||
auto sizeListType =
|
||||
Torch::ListType::get(Torch::IntType::get(op.getContext()));
|
||||
Value sizeList =
|
||||
rewriter.create<AtenSizeOp>(op.getLoc(), sizeListType, op.getSelf());
|
||||
rewriter.replaceOpWithNewOp<AtenRandnGeneratorOp>(
|
||||
op, op.getType(), sizeList, /*generator=*/none, op.getDtype(),
|
||||
op.getLayout(), op.getDevice(), op.getPinMemory());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenVarMeanOp : public OpRewritePattern<AtenVarMeanOp> {
|
||||
public:
|
||||
|
@ -3546,6 +3829,49 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenNewEmptyStridedOp
|
||||
: public OpRewritePattern<AtenNewEmptyStridedOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenNewEmptyStridedOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
SmallVector<int64_t> sizeListInts, strideListInts;
|
||||
if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizeListInts)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "all size list elements must be constant ints");
|
||||
if (!matchPattern(op.getStride(),
|
||||
m_TorchListOfConstantInts(strideListInts)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "all stride list elements must be constant ints");
|
||||
|
||||
// We only support the cases with default stride values.
|
||||
// For ex: aten.new_empty_strided(self, size=[2, 3, 4], stride=[12, 4, 1])
|
||||
// Here the stride[0] == size[1] * size[2], stride[1] == size[2], and
|
||||
// stride[2] == 1.
|
||||
bool isDefaultStride = true;
|
||||
for (unsigned i = 0; i < strideListInts.size(); i++) {
|
||||
int64_t defaultStride = 1;
|
||||
for (unsigned j = i + 1; j < sizeListInts.size(); j++)
|
||||
defaultStride *= sizeListInts[j];
|
||||
if (defaultStride != strideListInts[i]) {
|
||||
isDefaultStride = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!isDefaultStride)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only default strides supported for new_empty_strided op");
|
||||
|
||||
rewriter.replaceOpWithNewOp<AtenNewEmptyOp>(
|
||||
op, op.getType(), op.getSelf(), op.getSize(), op.getDtype(),
|
||||
op.getLayout(), op.getDevice(), op.getPinMemory());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeComplexOpsPass
|
||||
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
||||
|
@ -3591,6 +3917,7 @@ public:
|
|||
DecomposeConstantTensorAllocLikeOp<AtenOnesLikeOp, 1>>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<
|
||||
DecomposeConstantTensorAllocLikeOp<AtenZerosLikeOp, 0>>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenStackOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRollOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRepeatOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenExpandOp>(patterns);
|
||||
|
@ -3598,6 +3925,7 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarOtherOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarSelfOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenMaskedFillScalarOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<
|
||||
DecomposeAtenConvolutionBackwardOverrideableOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSizeOp>(patterns);
|
||||
|
@ -3640,8 +3968,11 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAten_UnsafeViewOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAten_ReshapeAliasOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenBernoulliOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeValsemVariantAtenBernoulliFloatOp>(
|
||||
addPatternIfTargetOpIsIllegal<
|
||||
DecomposeAtenBernoulliLikeOp<ValsemVariantAtenBernoulliFloatOp>>(
|
||||
patterns);
|
||||
addPatternIfTargetOpIsIllegal<
|
||||
DecomposeAtenBernoulliLikeOp<AtenBernoulliPOp>>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenBernoulliTensorOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenZeroOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRandLikeOp>(patterns);
|
||||
|
@ -3688,6 +4019,7 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexTensorHackedTwinOp>(
|
||||
patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenMseLossOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenNormScalarOptDimOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRandintLowOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanCorrectionOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposePrimsConvertElementTypeOp>(patterns);
|
||||
|
@ -3695,9 +4027,12 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposePrimsSqrtOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnGeneratorOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnLikeOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluBackwardOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyStridedOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenBucketizeTensorOp>(patterns);
|
||||
|
||||
GreedyRewriteConfig config;
|
||||
config.useTopDownTraversal = true;
|
||||
|
@ -3715,4 +4050,4 @@ std::unique_ptr<OperationPass<func::FuncOp>>
|
|||
mlir::torch::Torch::createDecomposeComplexOpsPass(
|
||||
ArrayRef<std::string> legalOps) {
|
||||
return std::make_unique<DecomposeComplexOpsPass>(legalOps);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,9 +10,9 @@
|
|||
#include "PassDetail.h"
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
|
|
|
@ -10,9 +10,9 @@
|
|||
#include "PassDetail.h"
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||
|
@ -244,7 +244,7 @@ createGlobalSlotModuleInitializer(ModuleOp module, SymbolTable &symbolTable,
|
|||
continue;
|
||||
opsToMove.push_back(&op);
|
||||
}
|
||||
BlockAndValueMapping mapping;
|
||||
IRMapping mapping;
|
||||
for (Operation *op : opsToMove) {
|
||||
// The ops are used by `torch.slot` ops in the enclosing module.
|
||||
// Cloning avoids needing to handle those uses specially.
|
||||
|
@ -329,7 +329,7 @@ template <> struct llvm::DenseMapInfo<Monomorphization> {
|
|||
// currently only analyzes a subset of ops.
|
||||
static LogicalResult analyzeInstances(func::FuncOp func,
|
||||
ArrayRef<ArgInstance> argInstances,
|
||||
BlockAndValueMapping &mapping) {
|
||||
IRMapping &mapping) {
|
||||
for (auto &argInstance : argInstances)
|
||||
mapping.map(func.getArgument(argInstance.argIndex), argInstance.instance);
|
||||
auto walkResult = func.walk([&](PrimGetAttrOp op) {
|
||||
|
@ -349,7 +349,7 @@ static LogicalResult analyzeInstances(func::FuncOp func,
|
|||
}
|
||||
|
||||
static FailureOr<Monomorphization>
|
||||
createMonomorphizationForCall(func::CallOp op, BlockAndValueMapping &mapping,
|
||||
createMonomorphizationForCall(func::CallOp op, IRMapping &mapping,
|
||||
SymbolTable &symbolTable) {
|
||||
auto func = symbolTable.lookup<func::FuncOp>(op.getCallee());
|
||||
Monomorphization monomorphization;
|
||||
|
@ -410,7 +410,7 @@ public:
|
|||
private:
|
||||
LogicalResult generateNewMonomorphizations(const Monomorphization &m) {
|
||||
auto func = m.func;
|
||||
BlockAndValueMapping mapping;
|
||||
IRMapping mapping;
|
||||
if (failed(analyzeInstances(func, m.argInstances, mapping)))
|
||||
return failure();
|
||||
auto walkResult = func.walk([&](func::CallOp op) {
|
||||
|
@ -495,7 +495,7 @@ verifyPublicMonomorphizations(ModuleOp module, SymbolTable &symbolTable,
|
|||
// Rewrite `func`, given that all values of `NnModuleType` have been mapped in
|
||||
// `mapping` to corresponding global instances.
|
||||
static LogicalResult rewriteMonomorphizedFuncClone(
|
||||
func::FuncOp func, BlockAndValueMapping mapping, SymbolTable &symbolTable,
|
||||
func::FuncOp func, IRMapping mapping, SymbolTable &symbolTable,
|
||||
DenseMap<Monomorphization, func::FuncOp> &newFuncs,
|
||||
ObjectGraphInfo &objectGraphInfo) {
|
||||
|
||||
|
@ -662,7 +662,7 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) {
|
|||
}
|
||||
|
||||
for (auto &kv : newFuncs) {
|
||||
BlockAndValueMapping mapping;
|
||||
IRMapping mapping;
|
||||
if (failed(analyzeInstances(kv.second, kv.first.argInstances, mapping)))
|
||||
return failure();
|
||||
if (failed(rewriteMonomorphizedFuncClone(kv.second, mapping, symbolTable,
|
||||
|
|
|
@ -27,8 +27,8 @@
|
|||
|
||||
#include "mlir/Analysis/DataFlowFramework.h"
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||
|
@ -373,7 +373,7 @@ class InlineGlobalSlotsPass
|
|||
// big deal.
|
||||
SmallVector<Operation *> slice =
|
||||
getBackwardSliceIncludingRoot(initialValue);
|
||||
BlockAndValueMapping mapping;
|
||||
IRMapping mapping;
|
||||
OpBuilder builder(op);
|
||||
for (Operation *opInSlice : slice)
|
||||
builder.clone(*opInSlice, mapping);
|
||||
|
|
|
@ -285,19 +285,16 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
class VerifyBackendContractPass
|
||||
: public VerifyBackendContractBase<VerifyBackendContractPass> {
|
||||
class VerifyBackendContractNoDecompositionsPass
|
||||
: public VerifyBackendContractNoDecompositionsBase<VerifyBackendContractNoDecompositionsPass> {
|
||||
public:
|
||||
VerifyBackendContractPass() = default;
|
||||
VerifyBackendContractPass(bool decompose,
|
||||
ArrayRef<std::string> backendLegalOps) {
|
||||
this->decompose = decompose;
|
||||
this->backendLegalOps = backendLegalOps;
|
||||
}
|
||||
VerifyBackendContractNoDecompositionsPass() = default;
|
||||
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
ConversionTarget target =
|
||||
getBackendContractTarget(context, decompose, backendLegalOps);
|
||||
getBackendContractTarget(context, /*decompose*/false,
|
||||
/*backendLegalOps*/{});
|
||||
|
||||
if (!satisfiesBackendContract(getOperation(), target,
|
||||
/*actuallyEmitDiagnostics=*/true)) {
|
||||
|
@ -315,10 +312,8 @@ mlir::torch::Torch::createLowerToBackendContractPass(
|
|||
}
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::torch::Torch::createVerifyBackendContractPass(
|
||||
bool decompose, ArrayRef<std::string> backendLegalOps) {
|
||||
return std::make_unique<VerifyBackendContractPass>(decompose,
|
||||
backendLegalOps);
|
||||
mlir::torch::Torch::createVerifyBackendContractNoDecompositionsPass() {
|
||||
return std::make_unique<VerifyBackendContractNoDecompositionsPass>();
|
||||
}
|
||||
|
||||
// The backend contract guarantees that ops with decompositions available will
|
||||
|
@ -347,6 +342,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenEmptyLikeOp>();
|
||||
target.addIllegalOp<AtenOnesLikeOp>();
|
||||
target.addIllegalOp<AtenZerosLikeOp>();
|
||||
target.addIllegalOp<AtenStackOp>();
|
||||
target.addIllegalOp<AtenRollOp>();
|
||||
target.addIllegalOp<AtenRepeatOp>();
|
||||
target.addIllegalOp<AtenExpandOp>();
|
||||
|
@ -354,6 +350,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenWhereScalarOp>();
|
||||
target.addIllegalOp<AtenWhereScalarOtherOp>();
|
||||
target.addIllegalOp<AtenWhereScalarSelfOp>();
|
||||
target.addIllegalOp<AtenMaskedFillScalarOp>();
|
||||
target.addIllegalOp<AtenConvolutionBackwardOverrideableOp>();
|
||||
target.addIllegalOp<AtenSizeOp>();
|
||||
target.addIllegalOp<AtenReshapeOp>();
|
||||
|
@ -362,6 +359,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenAddmmOp>();
|
||||
target.addIllegalOp<AtenMeanOp>();
|
||||
target.addIllegalOp<AtenMeanDimOp>();
|
||||
target.addIllegalOp<AtenNormScalarOptDimOp>();
|
||||
target.addIllegalOp<AtenSelectIntOp>();
|
||||
target.addIllegalOp<AtenMvOp>();
|
||||
target.addIllegalOp<AtenTOp>();
|
||||
|
@ -394,6 +392,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<Aten_ReshapeAliasOp>();
|
||||
target.addIllegalOp<AtenBernoulliOp>();
|
||||
target.addIllegalOp<ValsemVariantAtenBernoulliFloatOp>();
|
||||
target.addIllegalOp<AtenBernoulliPOp>();
|
||||
target.addIllegalOp<AtenBernoulliTensorOp>();
|
||||
target.addIllegalOp<AtenZeroOp>();
|
||||
target.addIllegalOp<AtenRandLikeOp>();
|
||||
|
@ -442,7 +441,10 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<PrimsSqrtOp>();
|
||||
target.addIllegalOp<AtenRandnOp>();
|
||||
target.addIllegalOp<AtenRandnGeneratorOp>();
|
||||
target.addIllegalOp<AtenRandnLikeOp>();
|
||||
target.addIllegalOp<AtenVarMeanOp>();
|
||||
target.addIllegalOp<AtenNewEmptyStridedOp>();
|
||||
target.addIllegalOp<AtenBucketizeTensorOp>();
|
||||
for (std::string opName : backendLegalOps) {
|
||||
target.addLegalOp(OperationName(opName, context));
|
||||
}
|
||||
|
|
|
@ -106,6 +106,7 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
|
|||
// Clean up again to avoid needing to to back around the fixed-point
|
||||
// iteration.
|
||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||
pm.addNestedPass<func::FuncOp>(createRecomposeComplexOps());
|
||||
// Reduce variants of ops to a smaller set of primitives.
|
||||
pm.addNestedPass<func::FuncOp>(createReduceOpVariantsPass());
|
||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||
|
|
|
@ -10,7 +10,6 @@
|
|||
#include "PassDetail.h"
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
|
|
@ -0,0 +1,103 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PassDetail.h"
|
||||
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
namespace {
|
||||
class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenCopy_Op op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (!op.getSelf().getDefiningOp() ||
|
||||
!isa<AtenSliceTensorOp>(op.getSelf().getDefiningOp()))
|
||||
return failure();
|
||||
auto sliceOp = cast<AtenSliceTensorOp>(op.getSelf().getDefiningOp());
|
||||
|
||||
// Get indices
|
||||
int64_t dim;
|
||||
if (!matchPattern(sliceOp.getDim(), m_TorchConstantInt(&dim)))
|
||||
return failure();
|
||||
int64_t end;
|
||||
if (!matchPattern(sliceOp.getEnd(), m_TorchConstantInt(&end)))
|
||||
return failure();
|
||||
|
||||
Value newEnd = sliceOp.getEnd();
|
||||
if (end < 0) {
|
||||
Value dimSize = rewriter.create<AtenSizeIntOp>(
|
||||
op.getLoc(), sliceOp.getSelf(), sliceOp.getDim());
|
||||
newEnd =
|
||||
rewriter.create<AtenAddIntOp>(op.getLoc(), dimSize, sliceOp.getEnd());
|
||||
}
|
||||
|
||||
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
|
||||
Value falseVal = rewriter.create<ConstantBoolOp>(op.getLoc(), false);
|
||||
|
||||
// Create IndexPut_Op
|
||||
BaseTensorType tensorType = op->getResultTypes()[0].cast<BaseTensorType>();
|
||||
Value range = rewriter.create<AtenArangeStartStepOp>(
|
||||
op.getLoc(), tensorType, sliceOp.getStart(), newEnd, sliceOp.getStep(),
|
||||
/*dtype=*/noneVal, /*layout=*/noneVal, /*device=*/noneVal,
|
||||
/*pin_memory=*/noneVal);
|
||||
|
||||
SmallVector<Value> indicesVector;
|
||||
for (auto i = 0; i < dim - 1; i++)
|
||||
indicesVector.push_back(noneVal);
|
||||
indicesVector.push_back(range);
|
||||
Value indices = rewriter.create<PrimListConstructOp>(
|
||||
op.getLoc(),
|
||||
Torch::ListType::get(op->getContext(),
|
||||
Torch::OptionalType::get(tensorType)),
|
||||
indicesVector);
|
||||
|
||||
rewriter.replaceOpWithNewOp<Aten_IndexPutImpl_Op>(
|
||||
op, op->getResultTypes(), sliceOp.getSelf(), indices, op.getSrc(),
|
||||
/*accumulate=*/falseVal, /*unsafe=*/falseVal);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class RecomposeComplexOps
|
||||
: public DecomposeComplexOpsBase<RecomposeComplexOps> {
|
||||
public:
|
||||
RecomposeComplexOps() = default;
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
RewritePatternSet patterns(context);
|
||||
|
||||
// pattern.add calls go here
|
||||
patterns.add<RecomposeSliceCopy_>(context);
|
||||
|
||||
GreedyRewriteConfig config;
|
||||
config.useTopDownTraversal = true;
|
||||
config.maxIterations = GreedyRewriteConfig::kNoLimit;
|
||||
|
||||
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
|
||||
config))) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
mlir::torch::Torch::createRecomposeComplexOps() {
|
||||
return std::make_unique<RecomposeComplexOps>();
|
||||
}
|
|
@ -59,7 +59,6 @@
|
|||
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
|
||||
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
|
||||
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinDialect.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
|
@ -81,7 +80,9 @@ using namespace mlir::torch::Torch;
|
|||
// -----------------------------------------------------------------------------
|
||||
|
||||
static Type getTypeForDTypeInteger(MLIRContext *context, int64_t dtypeInt) {
|
||||
return getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
|
||||
FailureOr<Type> result =
|
||||
getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
|
||||
return failed(result) ? Type() : *result;
|
||||
}
|
||||
|
||||
static Type getDtypeOrDefault(MLIRContext *context, Value optionalDtype,
|
||||
|
@ -111,24 +112,6 @@ static torch_upstream::TypeKind getTypeKind(Type type) {
|
|||
return torch_upstream::TypeKind::AnyType;
|
||||
}
|
||||
|
||||
/// Returns the dtype that assumes information from both `lhs` and `rhs`.
|
||||
/// Returns `std::nullopt` if the types are contradictory. Note this can only
|
||||
/// be used on the `dtype` from tensors and can't be used on other types like
|
||||
/// scalar types.
|
||||
static std::optional<Type> meetElementTypes(Type lhs, Type rhs) {
|
||||
auto isNullOrBuiltIn = [](Type type) { return !type || isBuiltInType(type); };
|
||||
(void)isNullOrBuiltIn;
|
||||
assert(isNullOrBuiltIn(lhs) && "`lhs` must be a builtin type");
|
||||
assert(isNullOrBuiltIn(rhs) && "`rhs` must be a builtin type");
|
||||
|
||||
if (!lhs)
|
||||
return rhs;
|
||||
if (!rhs)
|
||||
return lhs;
|
||||
if (lhs == rhs)
|
||||
return lhs;
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
enum class OptionalKnowledge {
|
||||
unKnown,
|
||||
|
@ -475,7 +458,8 @@ private:
|
|||
void visitAtenToDtypeLikeOp(OpTy op, ArrayRef<const ValueState *> operands);
|
||||
template <typename OpTy>
|
||||
void visitTypeConversionOp(OpTy op, ArrayRef<const ValueState *> operands);
|
||||
void visitAtenCatOp(AtenCatOp op, ArrayRef<const ValueState *> operands);
|
||||
template <typename OpTy>
|
||||
void visitAtenCatLikeOp(OpTy op, ArrayRef<const ValueState *> operands);
|
||||
|
||||
template <typename OpTy>
|
||||
void visitAtenSoftmaxLikeOp(OpTy op, ArrayRef<const ValueState *> operands);
|
||||
|
@ -563,7 +547,9 @@ static Type getPromotedResultDType(ValueKnowledge *tensor, Type scalarType) {
|
|||
/*skipRankCheck=*/true);
|
||||
state =
|
||||
updateResultTypeState(getDefaultDtypeForTorchScalar(scalarType), state);
|
||||
return getTypeForScalarType(scalarType.getContext(), result_type(state));
|
||||
FailureOr<Type> result =
|
||||
getTypeForScalarType(scalarType.getContext(), result_type(state));
|
||||
return failed(result) ? Type() : *result;
|
||||
}
|
||||
|
||||
static SmallVector<std::optional<bool>>
|
||||
|
@ -600,7 +586,8 @@ static Type getPromotedResultType(MLIRContext *context,
|
|||
return Type();
|
||||
state = updateResultTypeState(tensor, rankIsNonZero, state, skipRankCheck);
|
||||
}
|
||||
return getTypeForScalarType(context, result_type(state));
|
||||
FailureOr<Type> result = getTypeForScalarType(context, result_type(state));
|
||||
return failed(result) ? Type() : *result;
|
||||
}
|
||||
|
||||
static Type getPromotedResultTypeAssumingNonZeroRank(
|
||||
|
@ -649,23 +636,26 @@ void TypeAnalysis::visitOperation(Operation *op,
|
|||
AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopyOp, AtenCumsumOp,
|
||||
AtenLayerNormOp, AtenClampOp, AtenClampMinOp, AtenClampMaxOp,
|
||||
AtenNegOp, AtenFloorOp, Aten_SoftmaxBackwardDataOp, AtenDropoutOp,
|
||||
AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
|
||||
AtenAbsOp, AtenThresholdOp, AtenSquareOp, AtenUniformOp,
|
||||
AtenBernoulliOp, AtenBernoulli_FloatOp, AtenBernoulliTensorOp,
|
||||
AtenTanhBackwardOp, AtenHardtanhBackwardOp,
|
||||
Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp, AtenAbsOp,
|
||||
AtenThresholdOp, AtenSquareOp, AtenUniformOp, AtenBernoulliOp,
|
||||
AtenBernoulli_FloatOp, AtenBernoulliTensorOp,
|
||||
ValsemVariantAtenBernoulliFloatOp, AtenBernoulliTensorOp,
|
||||
AtenFillScalarOp, AtenHardsigmoidOp, AtenCloneOp, AtenHardswishOp,
|
||||
AtenSiluOp, AtenHardtanhOp, AtenMaskedSelectOp, AtenMaxPool2dOp,
|
||||
AtenAvgPool2dOp, AtenAdaptiveAvgPool2dOp, AtenFlattenUsingIntsOp,
|
||||
AtenSqueezeOp, AtenSqueezeDimOp, AtenUnsqueezeOp, AtenViewOp,
|
||||
Aten_UnsafeViewOp, AtenReshapeOp, Aten_ReshapeAliasOp, AtenResize_Op,
|
||||
AtenTransposeIntOp, AtenTOp, AtenPermuteOp, AtenIndexSelectOp,
|
||||
AtenSelectIntOp, AtenSelectScatterOp, AtenNarrowOp, AtenSliceTensorOp,
|
||||
AtenSliceScatterOp, AtenGatherOp, AtenExpandOp, AtenExpandAsOp,
|
||||
AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp,
|
||||
AtenZero_Op, AtenIndexTensorOp, Aten_IndexPutImplOp, AtenIndexPutOp,
|
||||
AtenCopyOp, AtenZeroOp, AtenIndexPutHackedTwinOp, AtenPreluOp,
|
||||
AtenMaskedFillScalarOp, AtenFlipOp, PrimAbsScalarOp, AtenNumpyTOp,
|
||||
AtenTriuOp, AtenMaskedFillTensorOp, AtenRollOp, AtenPowTensorTensorOp,
|
||||
AtenBernoulliPOp, AtenFillScalarOp, AtenHardsigmoidOp, AtenCloneOp,
|
||||
AtenHardswishOp, AtenSiluOp, AtenHardtanhOp, AtenMaskedSelectOp,
|
||||
AtenMaxPool2dOp, AtenAvgPool2dOp, AtenAdaptiveAvgPool2dOp,
|
||||
AtenFlattenUsingIntsOp, AtenSqueezeOp, AtenSqueezeDimOp,
|
||||
AtenUnsqueezeOp, AtenViewOp, Aten_UnsafeViewOp, AtenReshapeOp,
|
||||
Aten_ReshapeAliasOp, AtenResize_Op, AtenTransposeIntOp, AtenTOp,
|
||||
AtenPermuteOp, AtenIndexSelectOp, AtenSelectIntOp,
|
||||
AtenSelectScatterOp, AtenNarrowOp, AtenSliceTensorOp,
|
||||
AtenScatterReduceTwoOp, AtenSliceScatterOp, AtenGatherOp,
|
||||
AtenExpandOp, AtenExpandAsOp, AtenBroadcastToOp, AtenRepeatOp,
|
||||
AtenConstantPadNdOp, AtenPadOp, AtenZero_Op, AtenIndexTensorOp,
|
||||
Aten_IndexPutImplOp, AtenIndexPutOp, AtenCopyOp, AtenZeroOp,
|
||||
AtenIndexPutHackedTwinOp, AtenPreluOp, AtenMaskedFillScalarOp,
|
||||
AtenFlipOp, PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp,
|
||||
AtenMaskedFillTensorOp, AtenRollOp, AtenPowTensorTensorOp,
|
||||
AtenLiftFreshCopyOp, AtenIndexTensorHackedTwinOp,
|
||||
AtenUpsampleNearest2dOp, AtenMishOp, AtenRoundOp, AtenFillTensorOp,
|
||||
AtenUpsampleNearest2dBackwardOp, AtenLeakyReluBackwardOp>(op)) {
|
||||
|
@ -970,9 +960,16 @@ void TypeAnalysis::visitOperation(Operation *op,
|
|||
} else if (auto newEmpty = dyn_cast<AtenNewEmptyOp>(op)) {
|
||||
visitConstantTensorNewLikeOp<AtenNewEmptyOp>(newEmpty, operands);
|
||||
return;
|
||||
} else if (auto newEmptyStrided = dyn_cast<AtenNewEmptyStridedOp>(op)) {
|
||||
visitConstantTensorNewLikeOp<AtenNewEmptyStridedOp>(newEmptyStrided,
|
||||
operands);
|
||||
return;
|
||||
} else if (auto randLike = dyn_cast<AtenRandLikeOp>(op)) {
|
||||
visitConstantTensorAllocLikeOp<AtenRandLikeOp>(randLike, operands);
|
||||
return;
|
||||
} else if (auto randLike = dyn_cast<AtenRandnLikeOp>(op)) {
|
||||
visitConstantTensorAllocLikeOp<AtenRandnLikeOp>(randLike, operands);
|
||||
return;
|
||||
} else if (auto toCopy = dyn_cast<Aten_ToCopyOp>(op)) {
|
||||
visitConstantTensorAllocLikeOp<Aten_ToCopyOp>(toCopy, operands);
|
||||
return;
|
||||
|
@ -1008,7 +1005,10 @@ void TypeAnalysis::visitOperation(Operation *op,
|
|||
}
|
||||
|
||||
if (auto cat = dyn_cast<AtenCatOp>(op)) {
|
||||
visitAtenCatOp(cat, operands);
|
||||
visitAtenCatLikeOp<AtenCatOp>(cat, operands);
|
||||
return;
|
||||
} else if (auto stack = dyn_cast<AtenStackOp>(op)) {
|
||||
visitAtenCatLikeOp<AtenStackOp>(stack, operands);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -1114,6 +1114,22 @@ void TypeAnalysis::visitOperation(Operation *op,
|
|||
return;
|
||||
}
|
||||
|
||||
if (auto bucketize = dyn_cast<AtenBucketizeTensorOp>(op)) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
bool outInt32;
|
||||
if (matchPattern(bucketize.getOutInt32(), m_TorchConstantBool(&outInt32)) &&
|
||||
outInt32) {
|
||||
knowledge.dtype =
|
||||
IntegerType::get(op->getContext(), 32, IntegerType::Signed);
|
||||
} else {
|
||||
knowledge.dtype =
|
||||
IntegerType::get(op->getContext(), 64, IntegerType::Signed);
|
||||
}
|
||||
incorporateKnowledge(bucketize.getResult(), knowledge);
|
||||
return;
|
||||
}
|
||||
|
||||
// Otherwise, this is an unknown operation, so reset the state.
|
||||
setAllToEntryStates(results);
|
||||
return;
|
||||
|
@ -1338,30 +1354,26 @@ void TypeAnalysis::visitTypeConversionOp(
|
|||
// `torch.aten.cat` concatenates the given sequence of seq tensors in the given
|
||||
// dimension. The output has the same sizes as the input for all dimensions
|
||||
// except the given dimension.
|
||||
void TypeAnalysis::visitAtenCatOp(AtenCatOp op,
|
||||
ArrayRef<const ValueState *> operands) {
|
||||
template <typename OpTy>
|
||||
void TypeAnalysis::visitAtenCatLikeOp(OpTy op,
|
||||
ArrayRef<const ValueState *> operands) {
|
||||
auto tensorList = op.getTensors();
|
||||
auto knowledge =
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
auto listConstruct = tensorList.getDefiningOp<PrimListConstructOp>();
|
||||
auto listConstruct = tensorList.template getDefiningOp<PrimListConstructOp>();
|
||||
if (!listConstruct) {
|
||||
incorporateKnowledge(op.getResult(), knowledge);
|
||||
return;
|
||||
}
|
||||
|
||||
auto tensors = llvm::to_vector<4>(
|
||||
llvm::map_range(listConstruct.getElements(), [&](Value v) -> ValueKnowledge {
|
||||
return getLatticeElement(v)->getValue();
|
||||
SmallVector<ValueKnowledge*> tensors = llvm::to_vector(
|
||||
llvm::map_range(listConstruct.getElements(), [&](Value v) -> ValueKnowledge* {
|
||||
return &getLatticeElement(v)->getValue();
|
||||
}));
|
||||
for (auto tensor : tensors) {
|
||||
auto newDtype = meetElementTypes(knowledge.dtype, tensor.dtype);
|
||||
if (!newDtype.has_value()) {
|
||||
incorporateKnowledge(op.getResult(), knowledge);
|
||||
return;
|
||||
}
|
||||
knowledge.dtype = newDtype.value();
|
||||
}
|
||||
incorporateKnowledge(op.getResult(), knowledge);
|
||||
|
||||
knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
|
||||
op->getContext(), tensors);
|
||||
incorporateKnowledge(op->getResult(0), knowledge);
|
||||
}
|
||||
|
||||
void TypeAnalysis::visitNumToTensorOp(PrimNumToTensorScalarOp op) {
|
||||
|
@ -1436,12 +1448,16 @@ static Type getMostRefinedStaticType(Value v, DataFlowSolver &solver) {
|
|||
if (!latticeElement)
|
||||
return nullptr;
|
||||
const ValueKnowledge &knowledge = latticeElement->getValue();
|
||||
if (!knowledge.isInitialized)
|
||||
return nullptr;
|
||||
return getRefinedTensorType(tensorType, knowledge);
|
||||
} else if (auto optionalType = v.getType().dyn_cast<OptionalType>()) {
|
||||
const ValueState *latticeElement = solver.lookupState<ValueState>(v);
|
||||
if (!latticeElement)
|
||||
return nullptr;
|
||||
const ValueKnowledge &knowledge = latticeElement->getValue();
|
||||
if (!knowledge.isInitialized)
|
||||
return nullptr;
|
||||
if (knowledge.optional == OptionalKnowledge::isNone)
|
||||
return Torch::NoneType::get(v.getContext());
|
||||
else if (knowledge.optional == OptionalKnowledge::notNone) {
|
||||
|
@ -1456,6 +1472,8 @@ static Type getMostRefinedStaticType(Value v, DataFlowSolver &solver) {
|
|||
if (!latticeElement)
|
||||
return nullptr;
|
||||
const ValueKnowledge &knowledge = latticeElement->getValue();
|
||||
if (!knowledge.isInitialized)
|
||||
return nullptr;
|
||||
if (knowledge.kind == torch_upstream::TypeKind::IntType)
|
||||
return Torch::IntType::get(v.getContext());
|
||||
if (knowledge.kind == torch_upstream::TypeKind::FloatType)
|
||||
|
|
|
@ -46,10 +46,15 @@ static LogicalResult refineDtypeCalculateResult(DtypeCalculateOp op,
|
|||
impliedTypeFromDtype = *torchType;
|
||||
} else if (auto originalResultType =
|
||||
result.getType().dyn_cast<BaseTensorType>()) {
|
||||
FailureOr<Type> builtinType =
|
||||
getTypeForScalarType(op->getContext(), dtypeScalarType);
|
||||
if (failed(builtinType)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Failed to convert `dtypeScalarType` to a builtin type");
|
||||
}
|
||||
impliedTypeFromDtype =
|
||||
originalResultType.cast<BaseTensorType>().getWithSizesAndDtype(
|
||||
originalResultType.getOptionalSizes(),
|
||||
getTypeForScalarType(op->getContext(), dtypeScalarType));
|
||||
originalResultType.getOptionalSizes(), *builtinType);
|
||||
} else {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"Unimplemented: Expected result type to "
|
||||
|
|
|
@ -10,7 +10,7 @@
|
|||
#include "PassDetail.h"
|
||||
|
||||
#include "SimplifyAbstractInterpCalculationsUtils.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
|
@ -47,7 +47,7 @@ public:
|
|||
Block *afterBlock = rewriter.splitBlock(op->getBlock(), op->getIterator());
|
||||
|
||||
SmallVector<Block *> blocksToMerge;
|
||||
BlockAndValueMapping bvm;
|
||||
IRMapping bvm;
|
||||
// TODO: Helper for region().front()
|
||||
auto condition =
|
||||
cast<PrimLoopConditionOp>(op.getRegion().front().getTerminator());
|
||||
|
@ -129,8 +129,7 @@ public:
|
|||
// Truncate the list of users to the number of users we're going to
|
||||
// interpret.
|
||||
allUsers.resize(numUsersToInterpret);
|
||||
auto usersToInterpret =
|
||||
makeArrayRef(allUsers).take_front(numUsersToInterpret);
|
||||
auto usersToInterpret = ArrayRef(allUsers).take_front(numUsersToInterpret);
|
||||
|
||||
// For each mutating op (which must be in the same block), we save the
|
||||
// current state of the list as a vector of Value's. These will then
|
||||
|
@ -336,7 +335,7 @@ static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op,
|
|||
auto originalResultType = result.getType().cast<BaseTensorType>();
|
||||
auto impliedTypesFromShape =
|
||||
originalResultType.cast<BaseTensorType>()
|
||||
.getWithSizesAndDtype(makeArrayRef(sizes),
|
||||
.getWithSizesAndDtype(ArrayRef(sizes),
|
||||
originalResultType.getOptionalDtype())
|
||||
.cast<BaseTensorType>();
|
||||
|
||||
|
|
|
@ -8,6 +8,8 @@
|
|||
|
||||
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
|
||||
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace torch {
|
||||
namespace torch_upstream {
|
||||
|
@ -126,6 +128,23 @@ ScalarType result_type(const ResultTypeState &in_state) {
|
|||
combine_categories(in_state.zeroResult, in_state.wrappedResult));
|
||||
}
|
||||
|
||||
ReductionType get_reduction_enum(const llvm::StringRef &reduce) {
|
||||
if (reduce == "max" || reduce == "amax") {
|
||||
return torch_upstream::ReductionType::MAX;
|
||||
} else if (reduce == "mean") {
|
||||
return torch_upstream::ReductionType::MEAN;
|
||||
} else if (reduce == "min" || reduce == "amin") {
|
||||
return torch_upstream::ReductionType::MIN;
|
||||
} else if (reduce == "sum") {
|
||||
return torch_upstream::ReductionType::SUM;
|
||||
} else if (reduce == "prod") {
|
||||
return torch_upstream::ReductionType::PROD;
|
||||
} else {
|
||||
llvm_unreachable(
|
||||
"'reduce' argument must be either sum, prod, mean, amax or amin");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace torch_upstream
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
|
|
@ -83,9 +83,10 @@ Type Torch::getTypeForTorchType(
|
|||
llvm::report_fatal_error("unhandled type for getTypeForTorchType");
|
||||
}
|
||||
|
||||
Type Torch::getTypeForScalarType(
|
||||
MLIRContext *context, torch_upstream::ScalarType dtypeInt,
|
||||
mlir::IntegerType::SignednessSemantics signedness) {
|
||||
FailureOr<Type>
|
||||
Torch::getTypeForScalarType(MLIRContext *context,
|
||||
torch_upstream::ScalarType dtypeInt,
|
||||
mlir::IntegerType::SignednessSemantics signedness) {
|
||||
switch (dtypeInt) {
|
||||
case torch_upstream::ScalarType::Float:
|
||||
return Float32Type::get(context);
|
||||
|
@ -110,6 +111,8 @@ Type Torch::getTypeForScalarType(
|
|||
return mlir::ComplexType::get(Float64Type::get(context));
|
||||
case torch_upstream::ScalarType::ComplexDouble:
|
||||
return mlir::ComplexType::get(Float128Type::get(context));
|
||||
case torch_upstream::ScalarType::Undefined:
|
||||
return failure();
|
||||
default:
|
||||
llvm::report_fatal_error("unhandled type for getTypeForScalarType");
|
||||
}
|
||||
|
@ -123,6 +126,7 @@ Torch::getTorchTypeForScalarType(MLIRContext *context,
|
|||
return Torch::FloatType::get(context);
|
||||
case torch_upstream::ScalarType::Long:
|
||||
return Torch::IntType::get(context);
|
||||
case torch_upstream::ScalarType::Undefined:
|
||||
default:
|
||||
return failure();
|
||||
}
|
||||
|
|
|
@ -32,11 +32,11 @@ namespace {
|
|||
struct TorchConversionInlinerInterface : public DialectInlinerInterface {
|
||||
using DialectInlinerInterface::DialectInlinerInterface;
|
||||
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
|
||||
BlockAndValueMapping &valueMapping) const final {
|
||||
IRMapping &valueMapping) const final {
|
||||
return true;
|
||||
}
|
||||
bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
|
||||
BlockAndValueMapping &) const final {
|
||||
IRMapping &) const final {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
|
|
@ -75,8 +75,8 @@ LogicalResult FromBuiltinTensorOp::verify() {
|
|||
// FromI64Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult FromI64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
|
||||
auto attr = operands[0].dyn_cast_or_null<mlir::IntegerAttr>();
|
||||
OpFoldResult FromI64Op::fold(FoldAdaptor adaptor) {
|
||||
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::IntegerAttr>();
|
||||
if (attr) {
|
||||
return attr;
|
||||
} else {
|
||||
|
@ -88,8 +88,8 @@ OpFoldResult FromI64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
|
|||
// ToI64Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult ToI64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
|
||||
auto attr = operands[0].dyn_cast_or_null<mlir::IntegerAttr>();
|
||||
OpFoldResult ToI64Op::fold(FoldAdaptor adaptor) {
|
||||
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::IntegerAttr>();
|
||||
if (attr) {
|
||||
return attr;
|
||||
} else {
|
||||
|
@ -101,8 +101,8 @@ OpFoldResult ToI64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
|
|||
// ToF64Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult ToF64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
|
||||
auto attr = operands[0].dyn_cast_or_null<mlir::FloatAttr>();
|
||||
OpFoldResult ToF64Op::fold(FoldAdaptor adaptor) {
|
||||
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::FloatAttr>();
|
||||
if (attr) {
|
||||
return attr;
|
||||
} else {
|
||||
|
@ -114,8 +114,8 @@ OpFoldResult ToF64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
|
|||
// FromF64Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult FromF64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
|
||||
auto attr = operands[0].dyn_cast_or_null<mlir::FloatAttr>();
|
||||
OpFoldResult FromF64Op::fold(FoldAdaptor adaptor) {
|
||||
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::FloatAttr>();
|
||||
if (attr) {
|
||||
return attr;
|
||||
} else {
|
||||
|
|
|
@ -11,7 +11,6 @@
|
|||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
|
|
@ -11,7 +11,7 @@ set(LinkedLibs MLIRIR
|
|||
TorchMLIRTorchConversionToMLProgram
|
||||
MLIRMemRefTransforms)
|
||||
|
||||
if(TORCH_MLIR_ENABLE_MHLO)
|
||||
if(TORCH_MLIR_ENABLE_STABLEHLO)
|
||||
list(APPEND LinkedLibs ChloPasses)
|
||||
endif()
|
||||
|
||||
|
@ -21,7 +21,7 @@ add_mlir_library(TorchMLIRTorchConversionPasses
|
|||
Passes.cpp
|
||||
VerifyLinalgOnTensorsBackendContract.cpp
|
||||
VerifyTosaBackendContract.cpp
|
||||
VerifyMhloBackendContract.cpp
|
||||
VerifyStablehloBackendContract.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/TorchConversion/Transforms
|
||||
|
|
|
@ -21,9 +21,8 @@
|
|||
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
|
||||
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
|
||||
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO
|
||||
#include "mhlo/transforms/passes.h"
|
||||
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
||||
#endif
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||
|
||||
|
@ -53,12 +52,13 @@ void mlir::torch::registerTorchConversionPasses() {
|
|||
"Pipeline lowering torch backend contract to TOSA backend "
|
||||
"contract.",
|
||||
TorchConversion::createTorchBackendToTosaBackendPipeline);
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO
|
||||
mlir::PassPipelineRegistration<TorchConversion::MhloBackendPipelineOptions>(
|
||||
"torch-backend-to-mhlo-backend-pipeline",
|
||||
"Pipeline lowering torch backend contract to MHLO backend "
|
||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||
mlir::PassPipelineRegistration<
|
||||
TorchConversion::StablehloBackendPipelineOptions>(
|
||||
"torch-backend-to-stablehlo-backend-pipeline",
|
||||
"Pipeline lowering torch backend contract to StableHLO backend "
|
||||
"contract.",
|
||||
TorchConversion::createTorchBackendToMhloBackendPipeline);
|
||||
TorchConversion::createTorchBackendToStablehloBackendPipeline);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
@ -121,11 +121,12 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
|
|||
pm.addPass(TorchConversion::createVerifyTosaBackendContractPass());
|
||||
}
|
||||
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO
|
||||
void TorchConversion::createTorchBackendToMhloBackendPipeline(
|
||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||
void TorchConversion::createTorchBackendToStablehloBackendPipeline(
|
||||
OpPassManager &pm,
|
||||
const TorchConversion::MhloBackendPipelineOptions &options) {
|
||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToMhloPass(
|
||||
const TorchConversion::StablehloBackendPipelineOptions &options) {
|
||||
// Generate Stablehlo ops.
|
||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToStablehloPass(
|
||||
options.enableStaticShape, options.enableI32Index));
|
||||
|
||||
// Clean up any non-canonical code introduced above..
|
||||
|
@ -133,21 +134,13 @@ void TorchConversion::createTorchBackendToMhloBackendPipeline(
|
|||
// The resolution of `dim` ops tends to create identical ops. CSE them.
|
||||
pm.addNestedPass<func::FuncOp>(createCSEPass());
|
||||
|
||||
// Convert CHLO ops to MHLO ops
|
||||
pm.addNestedPass<func::FuncOp>(mhlo::createChloLegalizeToHloPass());
|
||||
// Clean up any non-canonical code introduced above..
|
||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||
// The resolution of `dim` ops tends to create identical ops. CSE them.
|
||||
pm.addNestedPass<func::FuncOp>(createCSEPass());
|
||||
|
||||
// Finish the type conversion from `torch` types to the types of the
|
||||
// MHLO backend contract.
|
||||
// StableHLO backend contract.
|
||||
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
|
||||
pm.addNestedPass<func::FuncOp>(
|
||||
TorchConversion::createFinalizingBackendTypeConversionPass());
|
||||
// Verify that we have lowered to the form that MHLO backends
|
||||
// expect. This fails compilation (signalPassFailure) if the IR is not in the
|
||||
// correct form.
|
||||
pm.addPass(TorchConversion::createVerifyMhloBackendContractPass());
|
||||
|
||||
// Verify that we have lowered to Stablehlo and Chlo ops.
|
||||
pm.addPass(TorchConversion::createVerifyStablehloBackendContractPass());
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -6,10 +6,9 @@
|
|||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO
|
||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||
#include "PassDetail.h"
|
||||
|
||||
#include "mhlo/IR/hlo_ops.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
|
@ -18,6 +17,7 @@
|
|||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "stablehlo/dialect/ChloOps.h"
|
||||
#include "stablehlo/dialect/StablehloOps.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -25,17 +25,15 @@ using namespace mlir::torch;
|
|||
using namespace mlir::torch::TorchConversion;
|
||||
|
||||
namespace {
|
||||
class VerifyMhloBackendContractPass
|
||||
: public VerifyMhloBackendContractBase<VerifyMhloBackendContractPass> {
|
||||
class VerifyStablehloBackendContractPass
|
||||
: public VerifyStablehloBackendContractBase<
|
||||
VerifyStablehloBackendContractPass> {
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
auto module = getOperation();
|
||||
TypeConverter converter;
|
||||
converter.addConversion([](Type type) -> Type {
|
||||
auto elemTy = type;
|
||||
if (isa<TensorType>(type)) {
|
||||
if (isa<TensorType>(type))
|
||||
elemTy = type.cast<TensorType>().getElementType();
|
||||
}
|
||||
if (BaseMemRefType::isValidElementType(elemTy))
|
||||
return type;
|
||||
return nullptr;
|
||||
|
@ -43,6 +41,7 @@ class VerifyMhloBackendContractPass
|
|||
|
||||
auto opHasLegalTypes = [&](Operation *op) { return converter.isLegal(op); };
|
||||
|
||||
MLIRContext *context = &getContext();
|
||||
ConversionTarget target(*context);
|
||||
|
||||
// Structural operations.
|
||||
|
@ -50,26 +49,16 @@ class VerifyMhloBackendContractPass
|
|||
// Shape operations.
|
||||
target.addDynamicallyLegalOp<shape::ShapeOfOp>(opHasLegalTypes);
|
||||
|
||||
target.addLegalDialect<mhlo::MhloDialect>();
|
||||
target.addLegalDialect<chlo::ChloDialect>();
|
||||
target.addLegalDialect<stablehlo::StablehloDialect>();
|
||||
target.addLegalDialect<tensor::TensorDialect>();
|
||||
target.addLegalDialect<arith::ArithDialect>();
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
if (failed(applyFullConversion(module, target, std::move(patterns)))) {
|
||||
// We avoid `module.emitError()` so that mlir-print-op-on-diagnostics
|
||||
// doesn't unnecessarily spew out the entire module.
|
||||
emitError(module.getLoc())
|
||||
<< "Module does not conform to the MHLO backend contract. "
|
||||
"See dialect conversion legality information above.";
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::torch::TorchConversion::createVerifyMhloBackendContractPass() {
|
||||
return std::make_unique<VerifyMhloBackendContractPass>();
|
||||
mlir::torch::TorchConversion::createVerifyStablehloBackendContractPass() {
|
||||
return std::make_unique<VerifyStablehloBackendContractPass>();
|
||||
}
|
||||
#endif // TORCH_MLIR_ENABLE_MHLO
|
||||
#endif // TORCH_MLIR_ENABLE_STABLEHLO
|
|
@ -20,6 +20,10 @@
|
|||
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
|
||||
#include "torch-mlir/RefBackend/Passes.h"
|
||||
|
||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||
#include "mhlo/transforms/passes.h"
|
||||
#endif
|
||||
|
||||
void mlir::torch::registerAllDialects(mlir::DialectRegistry ®istry) {
|
||||
registry.insert<mlir::func::FuncDialect>();
|
||||
registry.insert<mlir::torch::Torch::TorchDialect>();
|
||||
|
@ -34,4 +38,11 @@ void mlir::torch::registerAllPasses() {
|
|||
mlir::torch::registerConversionPasses();
|
||||
mlir::torch::RefBackend::registerRefBackendPasses();
|
||||
mlir::torch::TMTensor::registerPasses();
|
||||
|
||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||
mlir::mhlo::registerSymbolicShapeOptimizationPass();
|
||||
mlir::mhlo::registerStablehloLegalizeToHloPass();
|
||||
mlir::mhlo::registerChloLegalizeToHloPass();
|
||||
mlir::mhlo::registerHloLegalizeToLinalgPass();
|
||||
#endif // TORCH_MLIR_ENABLE_STABLEHLO
|
||||
}
|
||||
|
|
|
@ -392,7 +392,7 @@ Operation *createLinalgCopyOp(OpBuilder &b, Location loc, Value from,
|
|||
loc,
|
||||
/*inputs=*/from,
|
||||
/*outputs=*/to,
|
||||
/*indexingMaps=*/llvm::makeArrayRef({id, id}),
|
||||
/*indexingMaps=*/llvm::ArrayRef({id, id}),
|
||||
/*iteratorTypes=*/iteratorTypes,
|
||||
[](OpBuilder &b, Location loc, ValueRange args) {
|
||||
b.create<linalg::YieldOp>(loc, args.front());
|
||||
|
|
|
@ -45,14 +45,16 @@ endif()
|
|||
declare_mlir_python_sources(TorchMLIRPythonSources)
|
||||
declare_mlir_python_sources(TorchMLIRPythonExtensions)
|
||||
|
||||
declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel
|
||||
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
|
||||
ADD_TO_PARENT TorchMLIRPythonSources
|
||||
SOURCES
|
||||
__init__.py
|
||||
compiler_utils.py
|
||||
dynamo.py
|
||||
)
|
||||
if (NOT TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS)
|
||||
declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel
|
||||
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
|
||||
ADD_TO_PARENT TorchMLIRPythonSources
|
||||
SOURCES
|
||||
__init__.py
|
||||
compiler_utils.py
|
||||
dynamo.py
|
||||
)
|
||||
endif()
|
||||
|
||||
declare_mlir_python_sources(TorchMLIRPythonSources.Dialects
|
||||
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
|
||||
|
@ -91,7 +93,9 @@ if(TORCH_MLIR_ENABLE_LTC)
|
|||
endif()
|
||||
# Reference backend has a separate check for TORCH_MLIR_ENABLE_LTC, since it
|
||||
# generates a dummy Python library when disabled.
|
||||
add_subdirectory(torch_mlir/csrc/reference_lazy_backend)
|
||||
if(NOT TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS)
|
||||
add_subdirectory(torch_mlir/csrc/reference_lazy_backend)
|
||||
endif()
|
||||
|
||||
################################################################################
|
||||
# Optionally handle JIT IR importer.
|
||||
|
|
|
@ -44,9 +44,9 @@ class OutputType(Enum):
|
|||
# as taking the `TORCH` output type and lowering it to TOSA.
|
||||
TOSA = "tosa"
|
||||
|
||||
# This output type consists of `mhlo` dialect ops. It can be thought of
|
||||
# as taking the `TORCH` output type and lowering it to MHLO.
|
||||
MHLO = "mhlo"
|
||||
# This output type consists of `stablehlo` dialect ops. It can be thought of
|
||||
# as taking the `TORCH` output type and lowering it to StableHLO.
|
||||
STABLEHLO = "stablehlo"
|
||||
|
||||
# Raw output of the JIT IR importer. This is not expected to be useful
|
||||
# for end-users, but can be convenient for development or reporting bugs.
|
||||
|
@ -242,7 +242,7 @@ class ExampleArgs:
|
|||
BACKEND_LEGAL_OPS = {
|
||||
OutputType.TOSA: ['torch.aten.flatten.using_ints', 'torch.aten.native_layer_norm', 'torch.aten.linear'],
|
||||
OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints', ],
|
||||
OutputType.MHLO: [],
|
||||
OutputType.STABLEHLO: [],
|
||||
}
|
||||
|
||||
|
||||
|
@ -290,7 +290,7 @@ def compile(model: torch.nn.Module,
|
|||
|
||||
# We only allow `backend_legal_ops` to be specified for the `"torch"`
|
||||
# output type because the other output types actually invoke their
|
||||
# respective backends (Linalg, TOSA, or MHLO), and those backends have
|
||||
# respective backends (Linalg, TOSA, or STABLEHLO), and those backends have
|
||||
# very specific requirements about the ops which are legal.
|
||||
# See `BACKEND_LEGAL_OPS` for more details.
|
||||
if backend_legal_ops is not None:
|
||||
|
@ -404,14 +404,14 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
|
|||
print(mb.module)
|
||||
return mb.module
|
||||
|
||||
elif output_type == OutputType.MHLO:
|
||||
elif output_type == OutputType.STABLEHLO:
|
||||
run_pipeline_with_repro_report(
|
||||
mb.module,
|
||||
"builtin.module(torch-backend-to-mhlo-backend-pipeline)",
|
||||
"Lowering Torch Backend IR -> MHLO Backend IR")
|
||||
"builtin.module(torch-backend-to-stablehlo-backend-pipeline)",
|
||||
"Lowering Torch Backend IR -> StableHLO Backend IR")
|
||||
if verbose:
|
||||
print("\n====================")
|
||||
print("MHLO Backend IR")
|
||||
print("StableHLO Backend IR")
|
||||
print(mb.module)
|
||||
return mb.module
|
||||
raise Exception(f"Unknown OutputType: {output_type}")
|
||||
|
|
|
@ -44,7 +44,7 @@ def run_pipeline_with_repro_report(module,
|
|||
# Lower module in place to make it ready for compiler backends.
|
||||
with module.context:
|
||||
pm = PassManager.parse(pipeline)
|
||||
pm.run(module)
|
||||
pm.run(module.operation)
|
||||
except Exception as e:
|
||||
# TODO: More robust.
|
||||
# - don't arbitrarily clutter up /tmp. When a test suite has many
|
||||
|
|
|
@ -71,6 +71,7 @@ add_library(torch_mlir_ltc_backend SHARED
|
|||
mlir_node.cpp
|
||||
ops/device_data.cpp
|
||||
ops/generic.cpp
|
||||
utils/jit_utils.cpp
|
||||
utils/tensor_utils.cpp
|
||||
)
|
||||
target_compile_features(torch_mlir_ltc_backend PRIVATE cxx_std_17)
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue