Merge main into dtype-functions-staging (#1935)

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
Signed-off-by: Prateek Gupta <prateek.gupta2@cerebras.net>
Co-authored-by: Jiahao Li <liplus17@163.com>
Co-authored-by: Yuanqiang Liu <liuyuanqiang.yqliu@bytedance.com>
Co-authored-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
Co-authored-by: Chi_Liu <chi@nod.ai>
Co-authored-by: Victor Guerra <vguerra@gmail.com>
Co-authored-by: Victor Guerra <vm.guerramoran@criteo.com>
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
Co-authored-by: Ashay Rane <ashay@users.noreply.github.com>
Co-authored-by: Eric Kunze <eric.kunze@arm.com>
Co-authored-by: Gleb Kazantaev <gleb.nnstu@gmail.com>
Co-authored-by: Matthias Gehre <93204396+mgehre-amd@users.noreply.github.com>
Co-authored-by: Yi Wang <yi.wang.2005@gmail.com>
Co-authored-by: Sean Silva <silvasean@google.com>
Co-authored-by: Zachary Cetinic <zachattack242@Hotmail.com>
Co-authored-by: Tanyo Kwok <tianyou.gty@alibaba-inc.com>
Co-authored-by: Zachary Cetinic <zacharycetinic@gmail.com>
Co-authored-by: Kunwar Grover <51270680+Groverkss@users.noreply.github.com>
Co-authored-by: Ziheng Jiang <ziheng@apache.org>
Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>
Co-authored-by: Maksim Levental <maksim.levental@gmail.com>
Co-authored-by: Gaurav Shukla <gaurav@nod-labs.com>
Co-authored-by: Prateek Gupta <108802984+prateekgu-cerebras@users.noreply.github.com>
Co-authored-by: nvda <nvda@stanford.edu>
Co-authored-by: Ahmed S. Taei <asaadaldien@users.noreply.github.com>
Co-authored-by: Priya Savithiri <104089347+PriyaBSavithiri@users.noreply.github.com>
Co-authored-by: Zhekun Zhang <32320144+zhekunz2@users.noreply.github.com>
Co-authored-by: zhekun.zhang <zhekun.zhang@bytedance.com>
Co-authored-by: Kan Chen <chenkanhw@163.com>
Co-authored-by: gpetters94 <gpetters@protonmail.com>
pull/1943/head
Ramiro Leal-Cavazos 2023-03-15 07:48:41 -07:00 committed by GitHub
parent ce7abf4911
commit 042d58b699
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
165 changed files with 5801 additions and 1845 deletions

View File

@ -17,7 +17,7 @@ runs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.10'
python-version: '3.11'
- name: Install MLIR Python depends
run: |
@ -26,7 +26,8 @@ runs:
- name: Install PyTorch nightly depends
run: |
python -m pip install -r requirements.txt
python -m pip install -r pytorch-requirements.txt
python -m pip install -r build-requirements.txt
shell: bash
- name: Install prerequisites (Linux)

View File

@ -8,12 +8,19 @@ on:
jobs:
build_linux:
name: Manylinux Build
runs-on: ubuntu-latest
runs-on: a100
# Don't run this in everyone's forks.
if: github.repository == 'llvm/torch-mlir'
steps:
- name: Prepare workspace
run: |
# Clear the workspace directory so that we don't run into errors about
# existing lock files.
sudo rm -rf $GITHUB_WORKSPACE/*
- name: Get torch-mlir
uses: actions/checkout@v3
with:
@ -31,6 +38,7 @@ jobs:
cd ${GITHUB_WORKSPACE}
python -m pip install wheel
sudo apt-get install unzip
# Fetch the most recent nightly torchvision release
VISION_RELEASE=$(python -m pip index versions -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre torchvision | grep "Available versions" | tr ' ' '\n' | grep "^[0-9]" | sort --version-sort --reverse | head -n1 | tr -d ',' | sed 's/\([^+]*\).*/\1/')
@ -44,7 +52,8 @@ jobs:
# Read the version from the downloaded whl file without extracting it
PT_RELEASE=$(unzip -p torch-*.whl 'torch-*/METADATA' | grep "^Version:" | awk '{ print $2 }' | sed 's/\([^+]*\).*/\1/')
echo "Found torch release ${PT_RELEASE}"
printf -- "-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html\n--pre\ntorch==%s\ntorchvision==%s\n" "${PT_RELEASE}" "${VISION_RELEASE}" > pytorch-requirements.txt
printf -- "-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html\n--pre\ntorch==%s\n" "${PT_RELEASE}" > pytorch-requirements.txt
printf -- "-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html\n--pre\ntorchvision==%s\n" "${VISION_RELEASE}" > torchvision-requirements.txt
# Read the commit hash from the downloaded whl file without extracting it
PT_HASH=$(unzip -p torch-"${PT_RELEASE}"*.whl torch/version.py | grep git_version | awk '{ print $3 }' | tr -d "'")
@ -96,7 +105,7 @@ jobs:
git fetch --recurse-submodules=no
git checkout main
git pull origin main
git add pytorch-hash.txt pytorch-requirements.txt lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
git add pytorch-hash.txt pytorch-requirements.txt torchvision-requirements.txt lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
git diff --cached --exit-code || (git commit -m "update PyTorch version to ${{ env.PT_RELEASE }}" && git push --set-upstream origin main)
- name: Update PyTorch Build Cache (if running on main branch)

View File

@ -20,6 +20,12 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Prepare workspace
run: |
# Clear the workspace directory so that we don't run into errors about
# existing lock files.
sudo rm -rf $GITHUB_WORKSPACE/*
- name: Checkout torch-mlir
uses: actions/checkout@v3
with:

View File

@ -51,6 +51,14 @@ jobs:
runs-on: ${{ matrix.os }}
steps:
- name: Prepare workspace
if: ${{ matrix.os-arch == 'ubuntu-x86_64' }}
run: |
# Clear the workspace directory so that we don't run into errors about
# existing lock files.
sudo rm -rf $GITHUB_WORKSPACE/*
- name: Checkout torch-mlir
uses: actions/checkout@v3
with:
@ -113,7 +121,7 @@ jobs:
-DLLVM_USE_HOST_TOOLS=ON \
-DLLVM_ENABLE_ZSTD=OFF \
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
-DTORCH_MLIR_ENABLE_MHLO=OFF \
-DTORCH_MLIR_ENABLE_STABLEHLO=OFF \
-DTORCH_MLIR_ENABLE_LTC=OFF \
-DTORCH_MLIR_USE_INSTALLED_PYTORCH="${{ matrix.torch-binary }}" \
-DMACOSX_DEPLOYMENT_TARGET=12.0 \

View File

@ -13,8 +13,25 @@ on:
jobs:
build_linux:
name: Manylinux Build
runs-on: ubuntu-latest
runs-on: a100
strategy:
matrix:
package: [ torch-mlir, torch-mlir-core ]
py_version: [ cp38-cp38, cp310-cp310, cp311-cp311 ]
exclude:
- package: torch-mlir-core
py_version: cp38-cp38
- package: torch-mlir-core
py_version: cp310-cp310
steps:
- name: Prepare workspace
run: |
# Clear the workspace directory so that we don't run into errors about
# existing lock files.
sudo rm -rf $GITHUB_WORKSPACE/*
- name: Get torch-mlir
uses: actions/checkout@v3
with:
@ -28,7 +45,7 @@ jobs:
python -m pip install wheel
TM_PACKAGE_VERSION=${{ github.event.inputs.python_package_version }}
printf "TORCH_MLIR_PYTHON_PACKAGE_VERSION=%s\n" $TM_PACKAGE_VERSION > ./torch_mlir_package_version
./build_tools/python_deploy/build_linux_packages.sh
TM_PYTHON_VERSIONS=${{ matrix.py_version }} TM_PACKAGES=${{ matrix.package }} ./build_tools/python_deploy/build_linux_packages.sh
# If we were given a release_id, then upload the package we just built
# to the github releases page.
@ -56,7 +73,7 @@ jobs:
run: mkdir dist
- name: Copy releases to publish to dist directory
if: github.event.inputs.release_id != ''
run: cp build_tools/python_deploy/wheelhouse/torch_mlir-*.whl dist/
run: cp build_tools/python_deploy/wheelhouse/torch_mlir*.whl dist/
# Wheels must be published from a linux environment.
#
@ -70,6 +87,9 @@ jobs:
build_macos:
name: MacOS Build
runs-on: macos-latest
strategy:
matrix:
package: [ torch-mlir, torch-mlir-core ]
steps:
- name: Get torch-mlir
uses: actions/checkout@v3
@ -85,7 +105,7 @@ jobs:
TM_PACKAGE_VERSION=${{ github.event.inputs.python_package_version }}
printf "TORCH_MLIR_PYTHON_PACKAGE_VERSION=%s\n" $TM_PACKAGE_VERSION > ./torch_mlir_package_version
sudo ./build_tools/python_deploy/install_macos_deps.sh
TORCH_MLIR_PYTHON_VERSIONS="3.10" ./build_tools/python_deploy/build_macos_packages.sh
packages=${{ matrix.package }} TORCH_MLIR_PYTHON_VERSIONS="3.11" ./build_tools/python_deploy/build_macos_packages.sh
# If we were given a release_id, then upload the package we just built
# to the github releases page.
@ -113,7 +133,7 @@ jobs:
run: mkdir dist
- name: Copy releases to publish to dist directory
if: github.event.inputs.release_id != ''
run: cp build_tools/python_deploy/wheelhouse/torch_mlir-*.whl dist/
run: cp build_tools/python_deploy/wheelhouse/torch_mlir*.whl dist/
# Wheels must be published from a linux environment.
#
@ -127,6 +147,9 @@ jobs:
build_windows:
name: Windows Build
runs-on: windows-latest
strategy:
matrix:
package: [ torch-mlir, torch-mlir-core ]
steps:
- name: Get torch-mlir
uses: actions/checkout@v3
@ -142,6 +165,14 @@ jobs:
- name: Build Python wheels and smoke test.
shell: pwsh
run: |
if ( "${{ matrix.package }}" -eq "torch-mlir-core" )
{
$env:TORCH_MLIR_ENABLE_JIT_IR_IMPORTER='0'
$env:TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS='1'
} else {
$env:TORCH_MLIR_ENABLE_JIT_IR_IMPORTER='1'
$env:TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS='0'
}
$env:TORCH_MLIR_PYTHON_PACKAGE_VERSION = '${{ github.event.inputs.python_package_version }}'
./build_tools/python_deploy/build_windows.ps1
@ -172,7 +203,7 @@ jobs:
continue-on-error: true
- name: Copy releases to publish to dist directory
if: github.event.inputs.release_id != ''
run: cp ./wheelhouse/torch_mlir-*.whl dist/
run: cp ./wheelhouse/torch_mlir*.whl dist/
# Wheels must be published from a linux environment.
#
@ -216,4 +247,4 @@ jobs:
# if: github.event.inputs.release_id != ''
# uses: pypa/gh-action-pypi-publish@v1.5.1
# with:
# password: ${{ secrets.PYPI_API_TOKEN }}
# password: ${{ secrets.PYPI_API_TOKEN }}

View File

@ -13,8 +13,13 @@ jobs:
if: github.repository == 'llvm/torch-mlir'
steps:
- name: Prepare workspace
run: |
# Clear the workspace directory so that we don't run into errors about
# existing lock files.
sudo rm -rf $GITHUB_WORKSPACE/*
- name: Checking out repository
uses: actions/checkout@v2
uses: actions/checkout@v3
with:
token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }}
- name: Run scrape releases script

View File

@ -10,8 +10,14 @@ jobs:
# Don't run this in everyone's forks.
if: github.repository == 'llvm/torch-mlir'
steps:
- name: Prepare workspace
run: |
# Clear the workspace directory so that we don't run into errors about
# existing lock files.
sudo rm -rf $GITHUB_WORKSPACE/*
- name: Checking out repository
uses: actions/checkout@v2
uses: actions/checkout@v3
with:
token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }}

View File

@ -13,8 +13,15 @@ jobs:
# Don't run this in everyone's forks.
if: github.repository == 'llvm/torch-mlir'
steps:
- name: Prepare workspace
run: |
# Clear the workspace directory so that we don't run into errors about
# existing lock files.
sudo rm -rf $GITHUB_WORKSPACE/*
- name: Checking out repository
uses: actions/checkout@v2
uses: actions/checkout@v3
with:
token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }}

3
.gitignore vendored
View File

@ -32,3 +32,6 @@ bazel-*
build_oot/
docker_venv/
llvm-build/
# C++ build artifacts
compile_commands.json

View File

@ -36,12 +36,18 @@ macro(torch_mlir_add_llvm_external_project name identifier location)
set(LLVM_EXTERNAL_PROJECTS ${LLVM_EXTERNAL_PROJECTS} CACHE STRING "" FORCE)
endmacro()
option(TORCH_MLIR_ENABLE_MHLO "Add mhlo dialect" ON)
if(TORCH_MLIR_ENABLE_MHLO)
add_definitions(-DTORCH_MLIR_ENABLE_MHLO)
option(TORCH_MLIR_ENABLE_STABLEHLO "Add stablehlo dialect" ON)
if(TORCH_MLIR_ENABLE_STABLEHLO)
add_definitions(-DTORCH_MLIR_ENABLE_STABLEHLO)
endif()
option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON)
option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF)
option(TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS "Build Torch dialect MLIR Python bindings but neither JIT IR Importer nor LTC backend" OFF)
if(TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS)
set(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER OFF)
set(TORCH_MLIR_ENABLE_LTC OFF)
endif()
if(TORCH_MLIR_ENABLE_LTC)
set(ENV{TORCH_MLIR_ENABLE_LTC} 1)
@ -109,7 +115,6 @@ if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR OR TORCH_MLIR_OUT_OF_TREE_
# Don't try to compile the python extensions at the moment. We need
# to import lots of dependencies from AddMLIRPython to make this work.
set(MLIR_ENABLE_BINDINGS_PYTHON 1)
option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON)
set(TORCH-MLIR_BUILT_STANDALONE 1)
set(BACKEND_PACKAGE_STRING "LLVM ${LLVM_PACKAGE_VERSION}")
@ -119,7 +124,6 @@ else()
# In-tree build with LLVM_EXTERNAL_PROJECTS=torch-mlir
option(MLIR_ENABLE_BINDINGS_PYTHON "Enables MLIR Python Bindings" OFF)
option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON)
# TODO: Fix this upstream so that global include directories are not needed.
set(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir)
@ -128,8 +132,8 @@ else()
set(MLIR_INCLUDE_DIRS "${MLIR_INCLUDE_DIR};${MLIR_GENERATED_INCLUDE_DIR}")
endif()
if (TORCH_MLIR_ENABLE_MHLO)
set(MHLO_BUILD_EMBEDDED ON)
if (TORCH_MLIR_ENABLE_STABLEHLO)
set(STABLEHLO_BUILD_EMBEDDED ON)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo
${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo
EXCLUDE_FROM_ALL)

View File

@ -8,13 +8,12 @@ necessarily a reflection of the completeness or stability of the code, it
does indicate that the project is not yet endorsed as a component of LLVM.
[PyTorch](https://pytorch.org)
An open source machine learning framework that accelerates the path from research prototyping to production deployment.
PyTorch is an open source machine learning framework that facilitates the seamless transition from research and prototyping to production-level deployment.
[MLIR](https://mlir.llvm.org)
The MLIR project is a novel approach to building reusable and extensible compiler infrastructure. MLIR aims to address software fragmentation, improve compilation for heterogeneous hardware, significantly reduce the cost of building domain specific compilers, and aid in connecting existing compilers together.
The MLIR project offers a novel approach for building extensible and reusable compiler architectures, which address the issue of software fragmentation, reduce the cost of developing domain-specific compilers, improve compilation for heterogeneous hardware, and promote compatibility between existing compilers.
[Torch-MLIR](https://github.com/llvm/torch-mlir)
Multiple Vendors use MLIR as the middle layer, mapping from platform frameworks like PyTorch, JAX, and TensorFlow into MLIR and then progressively lowering down to their target hardware. We have seen half a dozen custom lowerings from PyTorch to MLIR. Having canonical lowerings from the PyTorch ecosystem to the MLIR ecosystem provides much needed relief to hardware vendors to focus on their unique value rather than implementing yet another PyTorch frontend for MLIR. The goal is to be similar to current hardware vendors adding LLVM target support instead of each one also implementing Clang / a C++ frontend.
Several vendors have adopted MLIR as the middle layer in their systems, enabling them to map frameworks such as PyTorch, JAX, and TensorFlow into MLIR and subsequently lower them to their target hardware. We have observed half a dozen custom lowerings from PyTorch to MLIR. Having canonical lowerings from the PyTorch ecosystem to the MLIR ecosystem makes it easier for hardware vendors to focus on their unique value, rather than needing to implement yet another PyTorch frontend for MLIR. The ultimate aim is to be similar to the current hardware vendors adding LLVM target support, rather than each one implementing Clang or a C++ frontend.
[![Release Build](https://github.com/llvm/torch-mlir/actions/workflows/buildRelease.yml/badge.svg)](https://github.com/llvm/torch-mlir/actions/workflows/buildRelease.yml)
@ -43,15 +42,26 @@ We have few paths to lower down to the Torch MLIR Dialect.
## Install torch-mlir snapshot
This installs a pre-built snapshot of torch-mlir for Python 3.7/3.8/3.9/3.10 on Linux and macOS.
At the time of writing, we release pre-built snapshots of torch-mlir for Python 3.10 on Linux and macOS.
If you have Python 3.10, the following commands initialize a virtual environment.
```shell
python -m venv mlir_venv
python3.10 -m venv mlir_venv
source mlir_venv/bin/activate
# Some older pip installs may not be able to handle the recent PyTorch deps
```
Or, if you want to switch between multiple versions of Python using conda, you can create a conda environment with Python 3.10.
```shell
conda create -n torch-mlir python=3.10
conda activate torch-mlir
python -m pip install --upgrade pip
pip install --pre torch-mlir torchvision -f https://llvm.github.io/torch-mlir/package-index/ --extra-index-url https://download.pytorch.org/whl/nightly/cpu
# This will install the corresponding torch and torchvision nightlies
```
Then, we can install torch-mlir with the corresponding torch and torchvision nightlies.
```
pip install --pre torch-mlir torchvision \
-f https://llvm.github.io/torch-mlir/package-index/ \
--extra-index-url https://download.pytorch.org/whl/nightly/cpu
```
## Demos

View File

@ -1,5 +1,3 @@
-r pytorch-requirements.txt
numpy
pybind11
wheel

View File

@ -39,16 +39,16 @@ set -eu -o errtrace
this_dir="$(cd "$(dirname "$0")" && pwd)"
repo_root="$(cd "$this_dir"/../../ && pwd)"
# This needs to be a manylinux image so we can ship pip packages
TM_RELEASE_DOCKER_IMAGE="${TM_RELEASE_DOCKER_IMAGE:-stellaraccident/manylinux2014_x86_64-bazel-5.1.0:latest}"
TM_RELEASE_DOCKER_IMAGE="${TM_RELEASE_DOCKER_IMAGE:-gcr.io/iree-oss/manylinux2014_x86_64-release@sha256:d8994b87b45b7b2e6055fccc32db018ec73aeb05a4e43a9daa61b77cc34f846e}"
# This assumes an Ubuntu LTS like image. You can build your own with
# ./build_tools/docker/Dockerfile
TM_CI_DOCKER_IMAGE="${TM_CI_DOCKER_IMAGE:-powderluv/torch-mlir-ci:latest}"
# Version of Python to use in Release builds. Ignored in CIs.
TM_PYTHON_VERSIONS="${TM_PYTHON_VERSIONS:-cp310-cp310}"
TM_PYTHON_VERSIONS="${TM_PYTHON_VERSIONS:-cp38-cp38 cp310-cp310 cp311-cp311}"
# Location to store Release wheels
TM_OUTPUT_DIR="${TM_OUTPUT_DIR:-${this_dir}/wheelhouse}"
# What "packages to build"
TM_PACKAGES="${TM_PACKAGES:-torch-mlir}"
TM_PACKAGES="${TM_PACKAGES:-torch-mlir torch-mlir-core}"
# Use pre-built Pytorch
TM_USE_PYTORCH_BINARY="${TM_USE_PYTORCH_BINARY:-ON}"
# Skip running tests if you want quick iteration
@ -84,6 +84,11 @@ function run_on_host() {
export USERID=0
export GROUPID=0
;;
torch-mlir-core)
TM_CURRENT_DOCKER_IMAGE=${TM_RELEASE_DOCKER_IMAGE}
export USERID=0
export GROUPID=0
;;
out-of-tree)
TM_CURRENT_DOCKER_IMAGE=${TM_CI_DOCKER_IMAGE}
# CI uses only Python3.10
@ -159,6 +164,12 @@ function run_in_docker() {
clean_build torch_mlir "$python_version"
;;
torch-mlir-core)
clean_wheels torch_mlir_core "$python_version"
build_torch_mlir_core
run_audit_wheel torch_mlir_core "$python_version"
clean_build torch_mlir_core "$python_version"
;;
out-of-tree)
setup_venv "$python_version"
build_out_of_tree "$TM_USE_PYTORCH_BINARY" "$python_version"
@ -267,8 +278,8 @@ function test_in_tree() {
echo ":::: Run Linalg e2e integration tests"
python -m e2e_testing.main --config=linalg -v
echo ":::: Run MHLO e2e integration tests"
python -m e2e_testing.main --config=mhlo -v
echo ":::: Run StableHLO e2e integration tests"
python -m e2e_testing.main --config=stablehlo -v
echo ":::: Run TOSA e2e integration tests"
python -m e2e_testing.main --config=tosa -v
@ -277,7 +288,7 @@ function test_in_tree() {
python -m e2e_testing.main --config=lazy_tensor_core -v
echo ":::: Run TorchDynamo e2e integration tests"
python -m e2e_testing.main --config=torchdynamo -v --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed RandnDtypeDeviceModule_basic
python -m e2e_testing.main --config=torchdynamo -v --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed RandnDtypeDeviceModule_basic RandnLikeModule_basic Matmul_dot
}
function setup_venv() {
@ -373,6 +384,15 @@ function run_audit_wheel() {
rm "$generic_wheel"
}
# Build the torch-mlir-core wheel from /main_checkout/torch-mlir and write it
# to /wheelhouse. Takes no arguments.
function build_torch_mlir_core() {
# Only build-requirements.txt is installed here; no PyTorch requirements file
# is installed by this function.
python -m pip install --no-cache-dir -r /main_checkout/torch-mlir/build-requirements.txt
# The assignments on the continued line below apply only to the single
# `pip wheel` invocation that ends it. They set the JIT IR importer flag to 0
# and the MLIR-Python-bindings-only flag to 1.
# NOTE(review): presumably these are read by the project's setup.py / CMake
# configuration -- confirm against the build scripts.
CMAKE_GENERATOR=Ninja \
TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \
TORCH_MLIR_ENABLE_JIT_IR_IMPORTER=0 \
TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS=1 \
python -m pip wheel -v -w /wheelhouse /main_checkout/torch-mlir
}
function clean_wheels() {
local wheel_basename="$1"
local python_version="$2"

View File

@ -20,7 +20,7 @@ set -eu -o errtrace
this_dir="$(cd "$(dirname "$0")" && pwd)"
repo_root="$(cd "$this_dir"/../../ && pwd)"
python_versions="${TORCH_MLIR_PYTHON_VERSIONS:-3.9 3.10}"
python_versions="${TORCH_MLIR_PYTHON_VERSIONS:-3.9 3.10 3.11}"
output_dir="${output_dir:-${this_dir}/wheelhouse}"
packages="${packages:-torch-mlir}"
@ -61,6 +61,11 @@ function run() {
build_torch_mlir torch_mlir "$python_version"
run_audit_wheel torch_mlir "$python_version"
;;
torch-mlir-core)
clean_wheels torch_mlir_core "$python_version"
build_torch_mlir_core torch_mlir_core "$python_version"
run_audit_wheel torch_mlir_core "$python_version"
;;
*)
echo "Unrecognized package '$package'"
exit 1
@ -77,7 +82,8 @@ function build_torch_mlir() {
python"${python_version}" -m venv "$output_dir"/build_venv
source "$output_dir"/build_venv/bin/activate
python"${python_version}" -m pip install -U pip
python"${python_version}" -m pip install -r "$repo_root"/requirements.txt --extra-index-url https://download.pytorch.org/whl/nightly/cpu
python"${python_version}" -m pip install -r "$repo_root"/pytorch-requirements.txt --extra-index-url https://download.pytorch.org/whl/nightly/cpu
python"${python_version}" -m pip install -r "$repo_root"/build-requirements.txt
CMAKE_GENERATOR=Ninja \
TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \
MACOSX_DEPLOYMENT_TARGET=$MACOSX_DEPLOYMENT_TARGET \
@ -87,6 +93,25 @@ function build_torch_mlir() {
rm -rf "$output_dir"/build_venv
}
# Build the torch-mlir-core wheel for one Python version into "$output_dir",
# using a throwaway virtual environment at "$output_dir"/build_venv.
#
# Arguments:
#   $1 - wheel base name (stored in wheel_basename; not referenced again in
#        this function)
#   $2 - Python version suffix used to pick the interpreter, i.e. the
#        command run is python"$2"
function build_torch_mlir_core() {
local wheel_basename="$1"
local python_version="$2"
# Start from a clean venv so nothing from a previous build is reused.
rm -rf "$output_dir"/build_venv
python"${python_version}" -m venv "$output_dir"/build_venv
source "$output_dir"/build_venv/bin/activate
# delocate is installed alongside pip; only build-requirements.txt is
# installed afterwards (no PyTorch requirements file).
python"${python_version}" -m pip install -U pip delocate
python"${python_version}" -m pip install -r "$repo_root"/build-requirements.txt
# The assignments on the continued line below apply only to the single
# `pip wheel` invocation that ends it. They set the JIT IR importer flag to 0
# and the MLIR-Python-bindings-only flag to 1.
# NOTE(review): presumably these are read by the project's setup.py / CMake
# configuration -- confirm against the build scripts.
CMAKE_GENERATOR=Ninja \
TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \
MACOSX_DEPLOYMENT_TARGET=$MACOSX_DEPLOYMENT_TARGET \
CMAKE_OSX_ARCHITECTURES=$CMAKE_OSX_ARCHITECTURES \
TORCH_MLIR_ENABLE_JIT_IR_IMPORTER=0 \
TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS=1 \
python"${python_version}" -m pip wheel -v -w "$output_dir" "$repo_root"
# Leave and delete the build venv so the next build starts clean.
deactivate
rm -rf "$output_dir"/build_venv
}
function clean_wheels() {
local wheel_basename="$1"
local python_version="$2"
@ -107,7 +132,8 @@ function run_audit_wheel() {
python"${python_version}" -m venv "$output_dir"/test_venv
source "$output_dir"/test_venv/bin/activate
python"${python_version}" -m pip install -U pip
python"${python_version}" -m pip install -r "$repo_root"/requirements.txt --extra-index-url https://download.pytorch.org/whl/nightly/cpu
python"${python_version}" -m pip install -r "$repo_root"/pytorch-requirements.txt --extra-index-url https://download.pytorch.org/whl/nightly/cpu
python"${python_version}" -m pip install -r "$repo_root"/build-requirements.txt
python"${python_version}" -m pip install "$generic_wheel" --extra-index-url https://download.pytorch.org/whl/nightly/cpu
DYLD_LIBRARY_PATH="$output_dir"/test_venv/lib/python"${python_version}"/site-packages/torch/lib delocate-wheel -v "$generic_wheel"
deactivate

View File

@ -13,7 +13,9 @@
Write-Host "Installing Build Dependencies"
python -m venv .\mlir_venv\
.\mlir_venv\Scripts\Activate.PS1
pip install -r .\requirements.txt
pip install -r .\pytorch-requirements.txt
pip install -r .\build-requirements.txt
pip install delvewheel
Write-Host "Build Deps installation completed successfully"
Write-Host "Building torch-mlir"
@ -22,3 +24,7 @@ $env:TORCH_MLIR_ENABLE_LTC='0'
python -m pip wheel -v -w wheelhouse ./ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html -r whl-requirements.txt
Write-Host "Build completed successfully"
Write-Host "Fixing up wheel dependencies"
delvewheel repair --add-path .\build\cmake_build\tools\torch-mlir\python_packages\torch_mlir\torch_mlir\_mlir_libs --add-dll TorchMLIRAggregateCAPI.dll --no-dll 'c10.dll;torch_python.dll;torch_cpu.dll' -v (get-item .\wheelhouse\torch_mlir*.whl).FullName
Write-Host "All Done."

View File

@ -19,11 +19,13 @@ if [[ "$(whoami)" != "root" ]]; then
fi
PYTHON_INSTALLER_URLS=(
"https://www.python.org/ftp/python/3.10.5/python-3.10.5-macos11.pkg"
"https://www.python.org/ftp/python/3.11.2/python-3.11.2-macos11.pkg"
"https://www.python.org/ftp/python/3.10.10/python-3.10.10-macos11.pkg"
"https://www.python.org/ftp/python/3.9.13/python-3.9.13-macos11.pkg"
)
# Each entry is "<major.minor>@<installer URL>". Keep the URLs in step with
# PYTHON_INSTALLER_URLS so both lists refer to the same installer packages
# (3.10 is pinned to 3.10.10 in both).
PYTHON_SPECS=(
3.11@https://www.python.org/ftp/python/3.11.2/python-3.11.2-macos11.pkg
3.10@https://www.python.org/ftp/python/3.10.10/python-3.10.10-macos11.pkg
3.9@https://www.python.org/ftp/python/3.9.13/python-3.9.13-macos11.pkg
)

View File

@ -30,14 +30,14 @@ it to various target dialects of interest to the MLIR ecosystem (various
- Linalg-on-Tensors (+ `arith`, `tensor`, etc.)
- [TOSA](https://mlir.llvm.org/docs/Dialects/TOSA/)
- [MHLO](https://github.com/tensorflow/mlir-hlo)
- [StableHLO](https://github.com/openxla/stablehlo)
The terms "frontend" and "backend" are highly overloaded in any compiler
project, but frequently in Torch-MLIR this is the meaning that they have.
Sometimes "frontend" can mean something even further up the stack, such as
something in PyTorch itself. When there is ambiguity we will refer to this as
"at the PyTorch level". Similarly, "backend" can sometimes refer to something
sitting below Linalg-on-Tensors, TOSA, or MHLO.
sitting below Linalg-on-Tensors, TOSA, or StableHLO.
## The `torch` dialect
@ -118,8 +118,8 @@ See [satisfiesBackendContract](https://github.com/llvm/torch-mlir/blob/114f48e96
The backend contract is a normalized form of the `torch` dialect with a set of
properties that make it easy to lower into various forms such as
Linalg-on-Tensors, TOSA, MHLO, or other forms that we don't provide out of the
box. The primary guarantees that we provide Torch-MLIR's backends are:
Linalg-on-Tensors, TOSA, StableHLO, or other forms that we don't provide out of
the box. The primary guarantees that we provide Torch-MLIR's backends are:
- All tensors have been converted to value semantics.
- All tensors have at least a known number of dimensions (i.e. rank), and
@ -270,7 +270,7 @@ lower it to the requirements of each backend. The 3 backends are:
- [`linalg`](https://mlir.llvm.org/docs/Dialects/Linalg/) on tensors (+ `arith`,
`tensor`, etc.)
- [TOSA](https://mlir.llvm.org/docs/Dialects/TOSA/)
- [MHLO](https://github.com/tensorflow/mlir-hlo)
- [StableHLO](https://github.com/openxla/stablehlo)
### The Linalg Backend (Linalg-on-Tensors)
@ -297,15 +297,15 @@ many users (especially "hardware" or "hardware-adjacent" folks). Some of its cha
- It is extremely solid with static shapes (and many of its users only care
about static shapes, so that's fine).
### The MHLO Backend
### The StableHLO Backend
Code: https://github.com/llvm/torch-mlir/tree/main/lib/Conversion/TorchToMhlo
Code: https://github.com/llvm/torch-mlir/tree/main/lib/Conversion/TorchToStablehlo
The MHLO backend was the third backend that we added, and it offers a reasonable
blend of the benefits of the other two.
The StableHLO backend was the third backend that we added, and it offers a
reasonable blend of the benefits of the other two.
- It is a coarse-grained named-op approach.
- It has a pretty clear spec for most of the ops (with a bit of mental
translation and hoping that MHLO is the same as HLO):
translation and hoping that StableHLO is the same as HLO):
https://www.tensorflow.org/xla/operation_semantics
- It functionally supports dynamic shapes (though not as coherent and consistent
as Linalg-on-Tensors, and the dynamic shape support falls outside the
@ -317,7 +317,7 @@ blend of the benefits of the other two.
example, TOSA limits (for highly considered reasons) the number of dimensions
that certain operators can handle to 1D-4D, when from a purely algebraic
perspective there isn't a good reason to not be more general. Similarly, more
general forms of reduction and scatter also fall into MHLO nicely while
general forms of reduction and scatter also fall into StableHLO nicely while
TOSA's principles tend to bias it away from that.
### Backend Implementation
@ -433,8 +433,9 @@ filling in some corners missing upstream and
to pull together upstream functionality into a working system.
The RefBackend accepts Linalg-on-Tensors as input. It mainly just bufferizes the
ops and lowers them to loops. Note that TOSA and MHLO support lowering to
Linalg-on-Tensors, so all our end-to-end testing bottoms out on RefBackend.
ops and lowers them to loops. Note that TOSA and StableHLO (via MHLO) support
lowering to Linalg-on-Tensors, so all our end-to-end testing bottoms out on
RefBackend.
The RefBackend is absolutely not suitable for any production use case. It leaks
memory, doesn't support any error handling, performs no optimizations, and

View File

@ -34,7 +34,7 @@ and Clang's
- Eric Kunze (@eric-k256)
- Suraj Sudhir (@sjarus)
### TorchToMHLO
### TorchToStablehlo
- Tianyo Kwok (@tanyokwok)
- Ziheng Jiang (@ZihengJiang)

View File

@ -139,7 +139,7 @@ Ex:
module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
```
Currently, `output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `MHLO`.
Currently, `output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `STABLEHLO`.
## Jupyter

View File

@ -46,7 +46,7 @@ the ecosystem are:
- The frontend work required to lower TorchScript to the backend contract.
- The irregular support surface area of the large number of PyTorch ops across
the Linalg, TOSA, and MHLO backends.
the Linalg, TOSA, and StableHLO backends.
Most of this document describes long-term ecosystem changes that will address
these, drastically improving Torch-MLIR's ability to meet its goals.
@ -108,7 +108,7 @@ more advanced).
### Refactoring the backend
Today in Torch-MLIR, we support 3 backends out of the box: Linalg-on-Tensors,
TOSA, and MHLO. These backends take IR in the backend contract form (see
TOSA, and StableHLO. These backends take IR in the backend contract form (see
[architecture.md](architecture.md)) and lowers them to the respective dialects.
Today, each backend is implemented completely independently. This leads to
duplication and irregularity across the backends.
@ -120,12 +120,10 @@ lowering of so many ops across backends. Additionally, there are 3
forward-looking efforts that intersect with this effort:
- [StableHLO](https://github.com/openxla/stablehlo) - this is a dialect
initially forked from MHLO which intends to create a stable support surface
area for what today is our "at head" dependency on MHLO. MHLO is a fairly
complete op set, so it is very attractive to have "almost all" models
bottleneck through a stable interface like StableHLO. StableHLO is currently
under relatively early development, but already delivers on many of the goals
of stability.
initially forked from MHLO. MHLO is a fairly complete op set, so it is very
attractive to have "almost all" models bottleneck through a stable interface
like StableHLO. StableHLO is currently under relatively early development,
but already delivers on many of the goals of stability.
- [TCP](https://github.com/llvm/torch-mlir/issues/1366) - this is a dialect
which could serve a role very similar to MHLO, while providing community
ownership. TCP is still in early planning phases, but there is strong

View File

@ -16,7 +16,7 @@ from torch_mlir_e2e_test.registry import GLOBAL_TEST_REGISTRY
from torch_mlir_e2e_test.configs import (
LazyTensorCoreTestConfig,
LinalgOnTensorsBackendTestConfig,
MhloBackendTestConfig,
StablehloBackendTestConfig,
NativeTorchTestConfig,
TorchScriptTestConfig,
TosaBackendTestConfig,
@ -24,17 +24,17 @@ from torch_mlir_e2e_test.configs import (
)
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
from torch_mlir_e2e_test.mhlo_backends.linalg_on_tensors import LinalgOnTensorsMhloBackend
from torch_mlir_e2e_test.stablehlo_backends.linalg_on_tensors import LinalgOnTensorsStablehloBackend
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend
from .xfail_sets import LINALG_XFAIL_SET, MHLO_PASS_SET, TOSA_PASS_SET, LTC_XFAIL_SET, TORCHDYNAMO_XFAIL_SET
from .xfail_sets import LINALG_XFAIL_SET, STABLEHLO_PASS_SET, TOSA_PASS_SET, LTC_XFAIL_SET, TORCHDYNAMO_XFAIL_SET
# Import tests to register them in the global registry.
from torch_mlir_e2e_test.test_suite import register_all_tests
register_all_tests()
def _get_argparse():
config_choices = ["native_torch", "torchscript", "linalg", "mhlo", "tosa", "lazy_tensor_core", "torchdynamo"]
config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "tosa", "lazy_tensor_core", "torchdynamo"]
parser = argparse.ArgumentParser(description="Run torchscript e2e tests.")
parser.add_argument("-c", "--config",
choices=config_choices,
@ -42,7 +42,7 @@ def _get_argparse():
help=f"""
Meaning of options:
"linalg": run through torch-mlir"s default Linalg-on-Tensors backend.
"mhlo": run through torch-mlir"s default MHLO backend.
"stablehlo": run through torch-mlir"s default StableHLO backend.
"tosa": run through torch-mlir"s default TOSA backend.
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
@ -80,9 +80,9 @@ def main():
if args.config == "tosa":
config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend())
xfail_set = all_test_unique_names - TOSA_PASS_SET
if args.config == "mhlo":
config = MhloBackendTestConfig(LinalgOnTensorsMhloBackend())
xfail_set = all_test_unique_names - MHLO_PASS_SET
if args.config == "stablehlo":
config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend())
xfail_set = all_test_unique_names - STABLEHLO_PASS_SET
elif args.config == "native_torch":
config = NativeTorchTestConfig()
xfail_set = {}

View File

@ -26,6 +26,11 @@ TORCHDYNAMO_XFAIL_SET = {
# https://github.com/pytorch/pytorch/issues/89629
"ConvolutionBackwardModule2DPadded_basic",
"ConvolutionBackwardModule2D_basic",
# error: 'tensor.expand_shape' op expected dimension 0 of collapsed type to be dynamic since one or more of the corresponding dimensions in the expanded type is dynamic
# https://github.com/llvm/torch-mlir/issues/1859
"ConvolutionModule2DGroups_basic",
# RuntimeError: Index tensor must have the same number of dimensions as self tensor
# RuntimeError: Failed running call_function aten.nll_loss_backward(...
# https://github.com/pytorch/pytorch/issues/89630
@ -39,10 +44,6 @@ TORCHDYNAMO_XFAIL_SET = {
# RuntimeError: Failed running call_function aten.uniform(...
# https://github.com/pytorch/torchdynamo/issues/1954
"UniformNoCorrelationModule_basic",
# TypeError: expected np.ndarray (got float)
# TODO: This is due to returning a scalar float as output from the test.
# We should probably just standardize all tests to return tensors.
"DivIntModule_basic",
#### Torch-MLIR internal compiler errors
@ -66,14 +67,13 @@ TORCHDYNAMO_XFAIL_SET = {
"IndexPutImpl2DFloatNonAccumulateModule_basic",
"IndexPutImpl3DFloatAccumulateModule_basic",
"IndexPutImpl3DFloatNonAccumulateModule_basic",
# %4 = torch.operator "aten.squeeze_.dim"(%3, %int0) : (!torch.tensor<*,f32>, !torch.int) -> !torch.tensor
"Matmul_vecmat",
# https://github.com/llvm/torch-mlir/issues/1611
# error: 'tensor.cast' op operand type 'tensor<0xi64>' and result type 'tensor<18xi64>' are cast incompatible
"Aten_EmbeddingBagExample_basic",
# error: failed to legalize operation 'torch.valsem.aten.bernoulli.float' that was explicitly marked illegal
"BernoulliFloatModule_basic",
"BernoulliPModule_basic",
# error: failed to legalize operation 'torch.aten.view' that was explicitly marked illegal
"ElementwiseFlattenBroadcastModule_basic",
"FlattenRank0Module_basic",
@ -83,8 +83,16 @@ TORCHDYNAMO_XFAIL_SET = {
# error: unsupported by backend contract: tensor with unknown rank
# note: see current operation: %1 = "torch.tensor_static_info_cast"(%arg0) : (!torch.vtensor<[5,4,3,2,1],f32>) -> !torch.vtensor<*,f32>
"ElementwisePreluModule_basic",
# error: op lowering missing. Issue: https://github.com/llvm/torch-mlir/issues/1792
"StdCorrectionKeepDimModule_basic",
#ERROR: value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=-1.336e-32, max=+0.9152, mean=+0.4837) is not close to golden value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=+0.02233, max=+0.9152, mean=+0.4777)
"UpSampleNearest2dDynamicFactor_basic",
"ReduceMaxAlongDimUnsignedInt_basic",
#ERROR: value (-56) is not equal to golden value (200)
"AtenIntTensorByteDtypeModule_basic",
# ERROR: assert isinstance(e, FakeTensor)
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
# ERROR: assert isinstance(e, FakeTensor)
"RsubInt0d_NumToTensor_Module_basic",
# Dtype function transition failures
"MobilenetV3Module_basic",
@ -92,8 +100,12 @@ TORCHDYNAMO_XFAIL_SET = {
"ResNet18StaticModule_basic",
}
MHLO_PASS_SET = {
STABLEHLO_PASS_SET = {
"MaskedFillScalarIntValueStaticModule_basic",
"MaskedFillScalarFloatValueStaticModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AddSizeIntModule_basic",
"AddSizeIntNegDimModule_basic",
"ArangeDtypeFloatModule_basic",
"ArangeDtypeIntModule_basic",
"ArangeFalsePinMemoryModule_basic",
@ -108,10 +120,15 @@ MHLO_PASS_SET = {
"ArangeStartStepFloatModule_basic",
"ArangeStartStepIntModule_basic",
"ArangeZeroElementOutputModule_basic",
"BatchMlpLayerModule_basic",
"BmmModule_basic",
"BroadcastToModule_basic",
"BroadcastToSameRankStaticModule_basic",
"BroadcastZeroRankInputStaticModule_basic",
"BucketizeTensorStaticFloatModule_basic",
"BucketizeTensorStaticModule_basic",
"CumsumStaticModule_basic",
"CumsumStaticNegativeDimModule_basic",
"ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic",
"ElementwiseAtenLogicalNotOpModule_basic",
"ElementwiseAtenLogicalNotOpPromoteModule_basic",
@ -126,19 +143,29 @@ MHLO_PASS_SET = {
"ElementwiseClampModule_basic",
"ElementwiseClampMinModule_basic",
"ElementwiseClampMaxModule_basic",
"ElementwisePowModule_basic",
"ElementwiseExpModule_basic",
"ElementwiseFlattenBroadcastModule_basic",
"ElementwiseLeakyReluModule_basic",
"ElementwiseLogModule_basic",
"ElementwiseNegModule_basic",
"ElementwiseRsqrtModule_basic",
"ElementwiseSigmoidModule_basic",
"ElementwiseSqrtModule_basic",
"ElementwiseSinModule_basic",
"ElementwiseCosModule_basic",
"ElementwiseCeilModule_basic",
"ElementwiseFloorModule_basic",
"ElementwiseUnaryModule_basic",
"ElementwiseUnsqueezeBroadcastModule_basic",
"ElementwiseUnsqueezeNegDimsModule_basic",
"ElementwiseToDtypeF32ToI64Module_basic",
"ElementwiseAddModule_basic",
"ElementwiseAddScalarFloatModule_basic",
"ElementwiseAddScalarInt64Module_basic",
"ElementwiseAddScalarIntModule_basic",
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
"ElementwiseDivScalarModule_basic",
"ElementwiseEqDiffWidthScalarModule_basic",
"ElementwiseEqFloatScalarModule_basic",
@ -201,6 +228,8 @@ MHLO_PASS_SET = {
"Gather2DInputModdule_basic",
"GatherRandomIndexModule_basic",
"GeluBackwardModule_basic",
"HardswishModule_basic",
"HardswishRandomModule_basic",
"HardTanhIntModule_basic",
"HardTanhModule_basic",
"HardsigmoidModule_basic",
@ -223,6 +252,8 @@ MHLO_PASS_SET = {
"MeanDynamicSizesModule_basic",
"MeanLargeInputModule_basic",
"MeanModule_basic",
"Mlp1LayerModule_basic",
"Mlp2LayerModule_basic",
"MmTanhModule_basic",
"Mv_basic",
"NativeLayerNormModule4D_basic",
@ -239,6 +270,7 @@ MHLO_PASS_SET = {
"ReduceSumDtypeFloatModule_basic",
"ReduceSumDtypeIntModule_basic",
"SelectIntModule_basic",
"SelectIntNegativeDimAndIndexStaticModule_basic",
"SliceSingleIdxModule_basic",
"SqueezeDimModule_dynamic",
"SqueezeDimModule_negDim",
@ -250,9 +282,15 @@ MHLO_PASS_SET = {
"FlattenStaticModule_basic",
"FlattenRank0Module_basic",
"TensorsConcatNegativeDimModule_basic",
"TensorsConcatPromoteDTypeModule_basic",
"TensorsStackModule_basic",
"TensorsStackNegativeDimModule_basic",
"TensorsStackPromoteDTypeModule_basic",
"LiftFreshCopyModule_basic",
"Mlp2LayerModuleNoBias_basic",
"NumelModule_basic",
"SiluModule_basic",
"SquareModule_basic",
"SqueezeModule_allUnitDim",
"SqueezeDimModule_unitDim",
"ViewCollapseOnesMiddleModule_basic",
@ -272,6 +310,7 @@ MHLO_PASS_SET = {
"Convolution2DStaticModule_basic",
"ConvolutionModule2DTransposeStridedStatic_basic",
"ElementwiseCloneContiguousModule_basic",
"ElementwiseCloneChannelsLastMemoryFormatModule_basic",
"ElementwiseCloneModule_basic",
"ElementwiseBinaryStaticShapeModule_basic",
"ReturnThreeTensorFloat32_basic",
@ -288,6 +327,7 @@ MHLO_PASS_SET = {
"RsubFloatModule_noalpha_basic",
"RsubIntModule_basic",
"RsubIntModule_noalpha_basic",
"RsubInt0d_NumToTensor_Module_basic",
"SliceStaticModule_basic",
"SliceModule_basic",
"SliceNegIdxModule_basic",
@ -358,6 +398,7 @@ MHLO_PASS_SET = {
"ViewExpandCollapseModule_basic",
"ViewExpandCollapseWithOnesModule_basic",
"ViewExpandInferredDimModule_basic",
"ViewNegativeStaticModule_basic",
"ViewNoChangeStaticModule_basic",
"ViewNoChange1dModule_basic",
"ViewNoChange2dModule_basic",
@ -420,12 +461,14 @@ MHLO_PASS_SET = {
"UnsafeViewDynamicExpandModule_basic",
"AtenRoundIntModule_basic",
"TestF16Return_basic",
"_LogSoftmaxModuleStable_basic",
}
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"ElementwiseCloneContiguousModule_basic",
"ElementwiseCloneChannelsLastMemoryFormatModule_basic",
"ElementwiseCloneModule_basic",
"ElementwiseUnaryModule_basic",
"ElementwiseBinaryModule_basic",
@ -449,6 +492,7 @@ TOSA_PASS_SET = {
"ViewExpandOnesMiddleOppModule_basic",
"ViewOffsetBackwardTestStaticModule_basic",
"TanhBackward_basic",
"HardtanhBackward_basic",
"ElementwiseAddModule_basic",
"ReturnThreeTensorFloat32_basic",
"AddCMulModule_basic",
@ -459,6 +503,7 @@ TOSA_PASS_SET = {
"BoolTensorReturnMixedModule_basic",
"BoolTensorHandleSignless_basic",
"ElementwiseRsqrtModule_basic",
"SelectIntNegativeDimAndIndexStaticModule_basic",
"SqueezeModule_static",
"SqueezeModule_noUnitDim",
"SqueezeModule_allUnitDim",
@ -480,6 +525,7 @@ TOSA_PASS_SET = {
"Matmul_3d",
"RsubFloatModule_basic",
"RsubFloatModule_noalpha_basic",
"RsubInt0d_NumToTensor_Module_basic",
"ElementwiseBitwiseAndModule_basic",
"ElementwiseBitwiseAndStaticShapeModule_basic",
"ElementwiseBitwiseNotInt32Module_basic",
@ -509,6 +555,7 @@ TOSA_PASS_SET = {
"ElementwiseDivScalarModule_basic",
"ElementwiseSubScalarFloatModule_basic",
"ElementwiseAddScalarFloatModule_basic",
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
"ElementwiseMulScalarModule_float",
"ElementwiseCeilModule_basic",
"ElementwiseReciprocalModule_basic",
@ -572,6 +619,7 @@ TOSA_PASS_SET = {
"ViewExpandCollapseWithOnesModule_basic",
"ViewCollapseInferredDimModule_basic",
"ViewExpandInferredDimModule_basic",
"ViewNegativeStaticModule_basic",
"ViewNoChangeStaticModule_basic",
"UnsafeViewExpandModule_basic",
"ReshapeCollapseModule_basic",
@ -604,6 +652,7 @@ TOSA_PASS_SET = {
"_LogSoftmaxModuleStable_basic",
"ElementwiseAtenWhereSelfModule_basic",
"ElementwiseUnsqueezeBroadcastModule_basic",
"MaskedFillScalarIntValueModule_basic",
"MaskedFillScalarIntValueStaticModule_basic",
"MaskedFillTensorIntValueStaticModule_basic",
"ElementwiseAddScalarInt64Module_basic",
@ -611,8 +660,11 @@ TOSA_PASS_SET = {
"TensorOpaqueLiteralModule_basic",
"TypePromotionDifferentCategoryModule_basic",
"TypePromotionSameCategoryDifferentWidthModule_basic",
"TypePromotionSameCategoryZeroRankWider_basic",
"TypePromotionZeroRankHigherCategoryModule_basic",
"GatherStaticModule_basic",
"IndexTensorStaticModule_basic",
"IndexTensorMultiIndexStaticModule_basic",
"LiftFreshCopyModule_basic",
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
"ReduceSumDimIntListFloatModule_basic",
@ -651,6 +703,10 @@ TOSA_PASS_SET = {
"HardsigmoidRandomModule_basic",
"HardswishModule_basic",
"HardswishRandomModule_basic",
"FullLikeModuleInt2DStatic_basic",
"FullModuleInt3D_basic",
"FullModuleFloat2D_basic",
"RepeatModule_basic"
}
LTC_XFAIL_SET = {
@ -666,7 +722,7 @@ LTC_XFAIL_SET = {
"AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AddIntModule_basic",
"BernoulliFloatModule_basic",
"AtenIntBoolOpModule_basic",
"BernoulliTensorModule_basic",
"BincountMinlengthModule_basic",
"BincountModule_basic",
@ -686,6 +742,7 @@ LTC_XFAIL_SET = {
"GtFloatIntModule_basic",
"GtIntModule_basic",
"HBC_basic",
"HardtanhBackward_basic",
"IndexPut1DFloatAccumulateModule_basic",
"IndexPut1DFloatNonAccumulateModule_basic",
"IndexPut1DIntAccumulateModule_basic",
@ -720,6 +777,8 @@ LTC_XFAIL_SET = {
"IndexPutImpl3DFloatNonAccumulateModule_basic",
"IndexTensorModule3dInput_basic",
"IndexTensorModule_basic",
"IndexTensorStaticModule_basic",
"IndexTensorMultiIndexStaticModule_basic",
"IndexTensorMultiInputContiguousCenter_basic",
"IndexTensorMultiInputNonContiguous_basic",
"IndexTensorMultiInputOneDim_basic",
@ -752,6 +811,8 @@ LTC_XFAIL_SET = {
"SubFloatModule_basic",
"SubIntModule_basic",
"TensorsConcatNegativeDimModule_basic",
"TensorsConcatPromoteDTypeModule_basic",
"TensorsStackPromoteDTypeModule_basic",
"TensorToBoolZeroRank_basic",
"TensorToBool_basic",
"TensorToFloatZeroRank_basic",
@ -788,4 +849,34 @@ LTC_XFAIL_SET = {
"ElementwisePreluModule_basic",
"VarMeanBiasedModule_basic",
"VarMeanUnbiasedModule_basic",
"RandnLikeModule_basic",
"RandnLikeDtypeModule_basic",
"NewEmptyStridedModuleDefaultDtype_basic",
"BernoulliFloatModule_basic",
"BernoulliModule_basic",
"BernoulliPModule_basic",
"DropoutTrainModule_basic",
"StdCorrectionKeepDimModule_basic",
"StdCorrectionNoneModule_basic",
"SliceCopy_Module_basic",
"SliceCopyNegative_Module_basic",
"VarBiasedModule_basic",
"VarCorrectionAllDimReduceModule_basic",
"VarCorrectionEmptyDimModule_basic",
"VarCorrectionKeepDimModule_basic",
"VarCorrectionLargeInputModule_basic",
"VarCorrectionModule_basic",
"VarCorrectionNoneModule_basic",
"VarCorrectionSingleDimReduceModule_basic",
"VarDimAllDimReduceModule_basic",
"VarDimBiasedModule_basic",
"VarDimEmptyDimModule_basic",
"VarDimModule_basic",
"VarDimMultiDimModule_basic",
"VarDimNegativeModule_basic",
"VarDimNoneDimModule_basic",
"VarDimSingleDimModule_basic",
"VarDimUnbiasedModule_basic",
"VarUnbiasedModule_basic",
"AtenFloatScalarModule_basic"
}

View File

@ -91,4 +91,4 @@ resnet18 = models.resnet18(pretrained=True)
resnet18.train(False)
dynamo_callable = dynamo.optimize(refbackend_torchdynamo_backend)(resnet18)
predictions(resnet18.forward, lambda x: dynamo_callable(torch.from_numpy(x)).numpy(), img, labels)
predictions(resnet18.forward, lambda x: dynamo_callable(torch.from_numpy(x)).detach().numpy(), img, labels)

View File

@ -1,14 +0,0 @@
import torch
import torchvision.models as models
import torch_mlir
# Example: compile torchvision's ResNet-18 with torch_mlir and dump the
# resulting MHLO IR to a text file.

# Load the pretrained model and put it in eval mode before compiling.
model = models.resnet18(pretrained=True)
model.eval()
# Example input passed to torch_mlir.compile together with the model.
data = torch.randn(2,3,200,200)
out_mhlo_mlir_path = "./resnet18_mhlo.mlir"

# Compile with tracing disabled (use_tracing=False), requesting MHLO output.
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=False)
# The module is written as its textual (str) form.
with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf:
    outf.write(str(module))

# Fixed the model name in the status message ("resent18" -> "resnet18").
print(f"MHLO IR of resnet18 successfully written into {out_mhlo_mlir_path}")

View File

@ -0,0 +1,14 @@
import torch
import torchvision.models as models
import torch_mlir
# Example: compile torchvision's ResNet-18 with torch_mlir and dump the
# resulting StableHLO IR to a text file.

# Load the pretrained model and put it in eval mode before compiling.
model = models.resnet18(pretrained=True)
model.eval()
# Example input passed to torch_mlir.compile together with the model.
data = torch.randn(2,3,200,200)
out_stablehlo_mlir_path = "./resnet18_stablehlo.mlir"

# Compile with tracing disabled (use_tracing=False), requesting StableHLO output.
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=False)
# The module is written as its textual (str) form.
with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
    outf.write(str(module))

# Fixed the model name in the status message ("resent18" -> "resnet18").
print(f"StableHLO IR of resnet18 successfully written into {out_stablehlo_mlir_path}")

View File

@ -15,10 +15,10 @@ class BertTinyWrapper(torch.nn.Module):
model = BertTinyWrapper()
model.eval()
data = torch.randint(30522, (2, 128))
out_mhlo_mlir_path = "./bert_tiny_mhlo.mlir"
out_stablehlo_mlir_path = "./bert_tiny_stablehlo.mlir"
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=True)
with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf:
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=True)
with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
outf.write(str(module))
print(f"MHLO IR of tiny bert successfully written into {out_mhlo_mlir_path}")
print(f"StableHLO IR of tiny bert successfully written into {out_stablehlo_mlir_path}")

View File

@ -10,7 +10,7 @@
#ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_
#define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"

View File

@ -457,7 +457,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
(ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
"ValueRange":$operands),
[{
BlockAndValueMapping bvm;
IRMapping bvm;
OperationState state(
loc, ConcreteOp::getOperationName(), operands, resultTypes,
$_op->getAttrs());

View File

@ -204,7 +204,7 @@ LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc,
}
auto scfIf = b.create<scf::IfOp>(
loc, TypeRange{}, cond,
loc, cond,
[&](OpBuilder &b, Location loc) {
if (isInclusive) {
auto value = b.create<memref::LoadOp>(loc, input(), indices);
@ -232,7 +232,7 @@ LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc,
auto &srcBlock = getRegion().front();
Region &thisRegion = scfIf.getElseRegion();
BlockAndValueMapping bvm;
IRMapping bvm;
{
OpBuilder::InsertionGuard guard(b);
auto &block = thisRegion.front();
@ -266,7 +266,7 @@ static LogicalResult foldMemRefCast(Operation *op) {
return success(folded);
}
LogicalResult ScanOp::fold(ArrayRef<Attribute>,
LogicalResult ScanOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);
}
@ -461,7 +461,7 @@ LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b,
Value init = b.create<memref::LoadOp>(loc, original(), starts);
BlockAndValueMapping bvm;
IRMapping bvm;
Block &block = getRegion().front();
bvm.map(block.getArgument(0), update);
bvm.map(block.getArgument(1), init);

@ -1 +1 @@
Subproject commit de3f0f7fa0c7b902dde840913db7e773a02c4173
Subproject commit 21f4b84c456b471cc52016cf360e14d45f7f2960

2
externals/mlir-hlo vendored

@ -1 +1 @@
Subproject commit 2c8823d255a777d3053ef891f4dbeea1c32819f4
Subproject commit b1ac0403ee2a40fc648ada6b9f11096f3d50fd19

View File

@ -1,6 +1,6 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
if(TORCH_MLIR_ENABLE_MHLO)
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_MHLO)
if(TORCH_MLIR_ENABLE_STABLEHLO)
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO)
else()
mlir_tablegen(Passes.h.inc -gen-pass-decls)
endif()

View File

@ -133,13 +133,13 @@ def ConvertTorchConversionToMLProgram : Pass<"convert-torch-conversion-to-mlprog
let constructor = "mlir::torch::createConvertTorchConversionToMLProgramPass()";
}
#ifdef TORCH_MLIR_ENABLE_MHLO
def ConvertTorchToMhlo : Pass<"convert-torch-to-mhlo", "func::FuncOp"> {
let summary = "Convert Torch ops to MHLO ops";
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
def ConvertTorchToStablehlo : Pass<"convert-torch-to-stablehlo", "func::FuncOp"> {
let summary = "Convert Torch ops to Stablehlo ops";
let description = [{
Convert Torch ops to mhlo ops.
Convert Torch ops to Stablehlo ops.
}];
let constructor = "mlir::torch::createConvertTorchToMhloPass()";
let constructor = "mlir::torch::createConvertTorchToStablehloPass()";
// Specify any options.
let options = [

View File

@ -7,8 +7,8 @@
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H
#define TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H
#ifndef TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H
#define TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
@ -16,10 +16,11 @@
namespace mlir {
namespace torch {
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToMhloPass();
std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchToMhloPass(bool enableStaticShape, bool enableI32Index);
createConvertTorchToStablehloPass();
std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchToStablehloPass(bool enableStaticShape, bool enableI32Index);
} // namespace torch
} // namespace mlir
#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H
#endif // TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H

View File

@ -2014,6 +2014,55 @@ def Torch_AtenClamp_Op : Torch_Op<"aten.clamp_", [
}];
}
// ODS definition of the `aten.clamp.Tensor` op: takes a tensor `self` plus
// optional `min` and `max` tensors and produces a single tensor result.
// Declared as a value-semantic, read-only op that allows type refinement.
def Torch_AtenClampTensorOp : Torch_Op<"aten.clamp.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::clamp.Tensor : (Tensor, Tensor?, Tensor?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchOptionalTensorType:$min,
AnyTorchOptionalTensorType:$max
);
let results = (outs
AnyTorchTensorType:$result
);
// Parsing/printing is delegated to the shared default Torch op helpers,
// called with 3 (operand count) and 1 (result count).
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenClampTensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenClampTensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}
// ODS definition of the `aten.clamp_.Tensor` op, the trailing-underscore
// (in-place) variant of `aten.clamp.Tensor`. Same operands/result as the
// value-semantic form, but it carries IsTrailingUnderscoreInplaceVariant and
// is not marked HasValueSemantics/ReadOnly.
def Torch_AtenClamp_TensorOp : Torch_Op<"aten.clamp_.Tensor", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::clamp_.Tensor : (Tensor, Tensor?, Tensor?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchOptionalTensorType:$min,
AnyTorchOptionalTensorType:$max
);
let results = (outs
AnyTorchTensorType:$result
);
// Parsing/printing is delegated to the shared default Torch op helpers,
// called with 3 (operand count) and 1 (result count).
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenClamp_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenClamp_TensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}
def Torch_AtenClampMinOp : Torch_Op<"aten.clamp_min", [
AllowsTypeRefinement,
HasValueSemantics,
@ -3340,6 +3389,7 @@ def Torch_AtenRsubScalarOp : Torch_Op<"aten.rsub.Scalar", [
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
let hasCanonicalizer = 1;
}
def Torch_AtenGeluOp : Torch_Op<"aten.gelu", [
@ -3638,6 +3688,31 @@ def Torch_AtenBernoulli_FloatOp : Torch_Op<"aten.bernoulli_.float", [
}];
}
// ODS definition of the `aten.bernoulli.p` op: takes a tensor `self`, a float
// `p`, and an optional generator, and produces a single tensor result.
// Declared as a value-semantic, read-only op that allows type refinement.
def Torch_AtenBernoulliPOp : Torch_Op<"aten.bernoulli.p", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::bernoulli.p : (Tensor, float, Generator?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_FloatType:$p,
AnyTorchOptionalGeneratorType:$generator
);
let results = (outs
AnyTorchTensorType:$result
);
// Parsing/printing is delegated to the shared default Torch op helpers,
// called with 3 (operand count) and 1 (result count).
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenBernoulliPOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenBernoulliPOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}
def Torch_AtenRandintLowOp : Torch_Op<"aten.randint.low", [
AllowsTypeRefinement,
HasValueSemantics,
@ -3771,6 +3846,34 @@ def Torch_AtenRandnGeneratorOp : Torch_Op<"aten.randn.generator", [
}];
}
// ODS definition of the `aten.randn_like` op: takes a tensor `self` followed
// by optional dtype, layout, device, pin_memory and memory_format operands,
// and produces a single tensor result.
// Declared as a value-semantic, read-only op that allows type refinement.
def Torch_AtenRandnLikeOp : Torch_Op<"aten.randn_like", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::randn_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchOptionalIntType:$dtype,
AnyTorchOptionalIntType:$layout,
AnyTorchOptionalDeviceType:$device,
AnyTorchOptionalBoolType:$pin_memory,
AnyTorchOptionalIntType:$memory_format
);
let results = (outs
AnyTorchTensorType:$result
);
// Parsing/printing is delegated to the shared default Torch op helpers,
// called with 6 (operand count) and 1 (result count).
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenRandnLikeOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenRandnLikeOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}
def Torch_AtenTriuOp : Torch_Op<"aten.triu", [
AllowsTypeRefinement,
HasValueSemantics,
@ -5146,11 +5249,11 @@ def Torch_AtenStdCorrectionOp : Torch_Op<"aten.std.correction", [
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::std.correction : (Tensor, int[]?, int?, bool) -> (Tensor)`";
let summary = "Generated op for `aten::std.correction : (Tensor, int[]?, Scalar?, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchOptionalListOfTorchIntType:$dim,
AnyTorchOptionalIntType:$correction,
AnyTorchOptionalScalarType:$correction,
Torch_BoolType:$keepdim
);
let results = (outs
@ -5222,11 +5325,11 @@ def Torch_AtenVarCorrectionOp : Torch_Op<"aten.var.correction", [
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::var.correction : (Tensor, int[]?, int?, bool) -> (Tensor)`";
let summary = "Generated op for `aten::var.correction : (Tensor, int[]?, Scalar?, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchOptionalListOfTorchIntType:$dim,
AnyTorchOptionalIntType:$correction,
AnyTorchOptionalScalarType:$correction,
Torch_BoolType:$keepdim
);
let results = (outs
@ -5248,11 +5351,11 @@ def Torch_AtenVarMeanCorrectionOp : Torch_Op<"aten.var_mean.correction", [
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::var_mean.correction : (Tensor, int[]?, int?, bool) -> (Tensor, Tensor)`";
let summary = "Generated op for `aten::var_mean.correction : (Tensor, int[]?, Scalar?, bool) -> (Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchOptionalListOfTorchIntType:$dim,
AnyTorchOptionalIntType:$correction,
AnyTorchOptionalScalarType:$correction,
Torch_BoolType:$keepdim
);
let results = (outs
@ -6482,6 +6585,35 @@ def Torch_AtenNewEmptyOp : Torch_Op<"aten.new_empty", [
}];
}
// ODS definition of the `aten.new_empty_strided` op: takes a tensor `self`,
// integer lists `size` and `stride`, and optional dtype, layout, device and
// pin_memory operands, and produces a single tensor result.
// Declared as a value-semantic, read-only op that allows type refinement.
def Torch_AtenNewEmptyStridedOp : Torch_Op<"aten.new_empty_strided", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::new_empty_strided : (Tensor, int[], int[], int?, int?, Device?, bool?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$size,
AnyTorchListOfTorchIntType:$stride,
AnyTorchOptionalIntType:$dtype,
AnyTorchOptionalIntType:$layout,
AnyTorchOptionalDeviceType:$device,
AnyTorchOptionalBoolType:$pin_memory
);
let results = (outs
AnyTorchTensorType:$result
);
// Parsing/printing is delegated to the shared default Torch op helpers,
// called with 7 (operand count) and 1 (result count).
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenNewEmptyStridedOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 7, 1);
}
void AtenNewEmptyStridedOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 7, 1);
}
}];
}
def Torch_AtenZerosLikeOp : Torch_Op<"aten.zeros_like", [
AllowsTypeRefinement,
HasValueSemantics,
@ -6975,30 +7107,6 @@ def Torch_AtenSizeIntOp : Torch_Op<"aten.size.int", [
let hasFolder = 1;
}
// ODS definition of the `aten.stack` op: takes a list of tensors and an
// integer `dim`, and produces a single tensor result.
// Declared as a value-semantic, read-only op that allows type refinement.
// This definition declares no folder.
def Torch_AtenStackOp : Torch_Op<"aten.stack", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::stack : (Tensor[], int) -> (Tensor)`";
let arguments = (ins
AnyTorchListOfTensorType:$tensors,
Torch_IntType:$dim
);
let results = (outs
AnyTorchTensorType:$result
);
// Parsing/printing is delegated to the shared default Torch op helpers,
// called with 2 (operand count) and 1 (result count).
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenStackOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenStackOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenSumOp : Torch_Op<"aten.sum", [
AllowsTypeRefinement,
HasValueSemantics,
@ -7556,6 +7664,86 @@ def Torch_AtenScatterAddOp : Torch_Op<"aten.scatter_add", [
}];
}
// ODS definition of the `aten.scatter_add_` op, the trailing-underscore
// (in-place) variant of scatter_add: takes a tensor `self`, an integer `dim`,
// and `index`/`src` tensors, and produces a single tensor result.
// It carries IsTrailingUnderscoreInplaceVariant and is not marked
// HasValueSemantics/ReadOnly.
def Torch_AtenScatterAdd_Op : Torch_Op<"aten.scatter_add_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::scatter_add_ : (Tensor, int, Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
AnyTorchTensorType:$index,
AnyTorchTensorType:$src
);
let results = (outs
AnyTorchTensorType:$result
);
// Parsing/printing is delegated to the shared default Torch op helpers,
// called with 4 (operand count) and 1 (result count).
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenScatterAdd_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenScatterAdd_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
// ODS definition of the `aten.scatter_reduce.two` op: takes a tensor `self`,
// an integer `dim`, `index`/`src` tensors, a string `reduce`, and a bool
// `include_self`, and produces a single tensor result.
// Declared as a value-semantic, read-only op that allows type refinement.
def Torch_AtenScatterReduceTwoOp : Torch_Op<"aten.scatter_reduce.two", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::scatter_reduce.two : (Tensor, int, Tensor, Tensor, str, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
AnyTorchTensorType:$index,
AnyTorchTensorType:$src,
Torch_StringType:$reduce,
Torch_BoolType:$include_self
);
let results = (outs
AnyTorchTensorType:$result
);
// Parsing/printing is delegated to the shared default Torch op helpers,
// called with 6 (operand count) and 1 (result count).
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenScatterReduceTwoOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenScatterReduceTwoOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}
// ODS definition of the `aten.scatter_reduce_.two` op, the trailing-underscore
// (in-place) variant of `aten.scatter_reduce.two`. Same operands/result as the
// value-semantic form, but it carries IsTrailingUnderscoreInplaceVariant and
// is not marked HasValueSemantics/ReadOnly.
def Torch_AtenScatterReduce_TwoOp : Torch_Op<"aten.scatter_reduce_.two", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::scatter_reduce_.two : (Tensor, int, Tensor, Tensor, str, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
AnyTorchTensorType:$index,
AnyTorchTensorType:$src,
Torch_StringType:$reduce,
Torch_BoolType:$include_self
);
let results = (outs
AnyTorchTensorType:$result
);
// Parsing/printing is delegated to the shared default Torch op helpers,
// called with 6 (operand count) and 1 (result count).
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenScatterReduce_TwoOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenScatterReduce_TwoOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}
def Torch_AtenIntImplicitOp : Torch_Op<"aten.IntImplicit", [
AllowsTypeRefinement,
HasValueSemantics,
@ -8670,6 +8858,31 @@ def Torch_AtenCatOp : Torch_Op<"aten.cat", [
let hasFolder = 1;
}
// ODS definition of the `aten.stack` op: takes a list of tensors and an
// integer `dim`, and produces a single tensor result.
// Declared as a value-semantic, read-only op that allows type refinement.
def Torch_AtenStackOp : Torch_Op<"aten.stack", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::stack : (Tensor[], int) -> (Tensor)`";
let arguments = (ins
AnyTorchListOfTensorType:$tensors,
Torch_IntType:$dim
);
let results = (outs
AnyTorchTensorType:$result
);
// Parsing/printing is delegated to the shared default Torch op helpers,
// called with 2 (operand count) and 1 (result count).
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenStackOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenStackOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
// Declares a fold hook for this op; its implementation is not in this file.
let hasFolder = 1;
}
def Torch_AtenAppendTOp : Torch_Op<"aten.append.t", [
AllowsTypeRefinement
]> {
@ -9085,6 +9298,7 @@ def Torch_AtenIntFloatOp : Torch_Op<"aten.Int.float", [
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
let hasFolder = 1;
}
def Torch_AtenIntScalarOp : Torch_Op<"aten.Int.Scalar", [
@ -9111,6 +9325,30 @@ def Torch_AtenIntScalarOp : Torch_Op<"aten.Int.Scalar", [
let hasFolder = 1;
}
// ODS definition of the `aten.Int.bool` op: takes a single bool operand `a`
// and produces an int result.
// Declared as a value-semantic, read-only op that allows type refinement.
def Torch_AtenIntBoolOp : Torch_Op<"aten.Int.bool", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::Int.bool : (bool) -> (int)`";
let arguments = (ins
Torch_BoolType:$a
);
let results = (outs
Torch_IntType:$result
);
// Parsing/printing is delegated to the shared default Torch op helpers,
// called with 1 (operand count) and 1 (result count).
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenIntBoolOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenIntBoolOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
// Declares a fold hook for this op; its implementation is not in this file.
let hasFolder = 1;
}
def Torch_Aten__RangeLengthOp : Torch_Op<"aten.__range_length", [
AllowsTypeRefinement,
HasValueSemantics,
@ -9580,6 +9818,7 @@ def Torch_AtenSubFloatOp : Torch_Op<"aten.sub.float", [
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasFolder = 1;
}
def Torch_AtenMulFloatOp : Torch_Op<"aten.mul.float", [
@ -9850,6 +10089,31 @@ def Torch_AtenGtFloatIntOp : Torch_Op<"aten.gt.float_int", [
}];
}
// ODS record for the `aten.pow.int_float` op: takes a `!torch.int` operand
// (`$a`) and a `!torch.float` operand (`$b`), and produces a `!torch.float`
// result.
// The custom assembly parse/print hooks forward to the shared
// `parseDefaultTorchOp` / `printDefaultTorchOp` helpers with 2 operands and
// 1 result. `hasFolder = 1` asks ODS to declare a `fold` hook; its
// implementation is not part of this record.
def Torch_AtenPowIntFloatOp : Torch_Op<"aten.pow.int_float", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::pow.int_float : (int, float) -> (float)`";
let arguments = (ins
Torch_IntType:$a,
Torch_FloatType:$b
);
let results = (outs
Torch_FloatType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenPowIntFloatOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenPowIntFloatOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasFolder = 1;
}
def Torch_Aten__And__BoolOp : Torch_Op<"aten.__and__.bool", [
AllowsTypeRefinement,
HasValueSemantics,
@ -10307,6 +10571,7 @@ def Torch_AtenScalarImplicitOp : Torch_Op<"aten.ScalarImplicit", [
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
let hasCanonicalizer = 1;
}
def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [
@ -10359,6 +10624,32 @@ def Torch_AtenTanhBackwardOp : Torch_Op<"aten.tanh_backward", [
}];
}
// ODS record for the `aten.hardtanh_backward` op: takes two tensors
// (`$grad_output`, `$self`) and two scalars (`$min_val`, `$max_val`), and
// produces a single tensor result.
// The custom assembly parse/print hooks forward to the shared
// `parseDefaultTorchOp` / `printDefaultTorchOp` helpers with 4 operands and
// 1 result. This record requests no folder or canonicalizer.
def Torch_AtenHardtanhBackwardOp : Torch_Op<"aten.hardtanh_backward", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::hardtanh_backward : (Tensor, Tensor, Scalar, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$grad_output,
AnyTorchTensorType:$self,
AnyTorchScalarType:$min_val,
AnyTorchScalarType:$max_val
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenHardtanhBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenHardtanhBackwardOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
def Torch_AtenGeluBackwardOp : Torch_Op<"aten.gelu_backward", [
AllowsTypeRefinement,
HasValueSemantics,
@ -10733,6 +11024,7 @@ def Torch_PrimMinIntOp : Torch_Op<"prim.min.int", [
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasFolder = 1;
}
def Torch_PrimMaxSelfIntOp : Torch_Op<"prim.max.self_int", [
@ -10933,11 +11225,11 @@ def Torch_PrimsVarOp : Torch_Op<"prims.var", [
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `prims::var : (Tensor, int[]?, int, int?) -> (Tensor)`";
let summary = "Generated op for `prims::var : (Tensor, int[]?, float, int?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$inp,
AnyTorchOptionalListOfTorchIntType:$dims,
Torch_IntType:$correction,
Torch_FloatType:$correction,
AnyTorchOptionalIntType:$output_dtype
);
let results = (outs

View File

@ -376,9 +376,6 @@ def Torch_PrimListUnpackOp: Torch_Op<"prim.ListUnpack", [
def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [
Pure,
TypesMatchWith<"contained types correspond to operand types",
"elements", "result", "Torch::TupleType::get($_ctxt, llvm::to_vector<6>($_self))",
"isValidSubtype">,
AllowedInModuleInitializer,
]> {
let summary = "TorchScript prim::TupleConstruct op";
@ -397,6 +394,8 @@ def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [
let assemblyFormat = [{
$elements attr-dict `:` qualified(type($elements)) `->` qualified(type($result))
}];
let hasVerifier = 1;
}
def Torch_PrimListConstructOp: Torch_Op<"prim.ListConstruct", [

View File

@ -98,6 +98,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();
std::unique_ptr<OperationPass<func::FuncOp>>
createDecomposeComplexOpsPass(ArrayRef<std::string> legalOps);
std::unique_ptr<OperationPass<func::FuncOp>> createRecomposeComplexOps();
std::unique_ptr<OperationPass<ModuleOp>> createPreprocessShapeLibraryPass();
std::unique_ptr<OperationPass<ModuleOp>> createReifyShapeCalculationsPass();
@ -121,8 +123,7 @@ createLowerToBackendContractPass(int maxIterations, bool decompose,
ArrayRef<std::string> backendLegalOps);
std::unique_ptr<OperationPass<ModuleOp>>
createVerifyBackendContractPass(bool decompose,
ArrayRef<std::string> backendLegalOps);
createVerifyBackendContractNoDecompositionsPass();
StringRef getAbstractInterpLibrary();

View File

@ -343,24 +343,17 @@ def LowerToBackendContract
let dependentDialects = ["func::FuncDialect"];
}
def VerifyBackendContract
: Pass<"torch-verify-backend-contract", "ModuleOp"> {
def VerifyBackendContractNoDecompositions
: Pass<"torch-verify-backend-contract-no-decompositions", "ModuleOp"> {
let summary = "Check that program satisfies backend contract.";
let constructor = [{
mlir::torch::Torch::createVerifyBackendContractPass(
/*decompose=*/true, /*backendLegalOps=*/{})
mlir::torch::Torch::createVerifyBackendContractNoDecompositionsPass()
}];
let description = [{
This pass performs a set of inspections to check that program satisfies backend
contract. In case of check failure it prints out the error message and returns
`signalPassFailure()` status.
contract assuming that no decompositions were applied. In case of check failure
it prints out the error message and returns `signalPassFailure()` status.
}];
let options = [
Option<"decompose", "decompose", "bool", /*default=*/"true",
"Decompose ops.">,
ListOption<"backendLegalOps", "backend-legal-ops", "std::string",
"List of ops to be considered legal for the backend.">
];
}
#endif // TORCHMLIR_TORCH_PASSES

View File

@ -9,6 +9,7 @@
#define TORCHMLIR_DIALECT_TORCH_UPSTREAM_H
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringRef.h"
// For layering reasons, the parts of the core MLIR compiler code written in C++
// never take a C++ dependency on Torch itself (any code depending on Torch C++
@ -160,6 +161,15 @@ enum Layout { Strided, Sparse, SparseCsr, Mkldnn, NumOptions };
//===-----------------------------------------------------------------------===//
enum EmbeddingBagMode { MODE_SUM, MODE_MEAN, MODE_MAX };
//===----------------------------------------------------------------------===//
// Possible values for the `reduce` argument of scatter-reduce ops.
// Source:
// https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
//===-----------------------------------------------------------------------===//
enum ReductionType { MAX, MEAN, MIN, SUM, PROD };
ReductionType get_reduction_enum(const llvm::StringRef &reduce);
} // namespace torch_upstream
} // namespace torch
} // namespace mlir

View File

@ -26,7 +26,7 @@ bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems);
std::optional<int64_t> matchLegalConstantIndexIntoListOfSize(Value v,
int64_t length);
torch_upstream::ScalarType getScalarTypeForType(Type type);
Type getTypeForScalarType(
FailureOr<Type> getTypeForScalarType(
MLIRContext *context, torch_upstream::ScalarType dtypeInt,
mlir::IntegerType::SignednessSemantics signedness = IntegerType::Signed);

View File

@ -1,6 +1,6 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
if(TORCH_MLIR_ENABLE_MHLO)
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_MHLO)
if(TORCH_MLIR_ENABLE_STABLEHLO)
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO)
else()
mlir_tablegen(Passes.h.inc -gen-pass-decls)
endif()

View File

@ -30,10 +30,10 @@ void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm);
/// TOSA backend contract.
void createTorchBackendToTosaBackendPipeline(OpPassManager &pm);
// Do not register the torch-to-mhlo pipeline if mhlo target is disabled
#ifdef TORCH_MLIR_ENABLE_MHLO
struct MhloBackendPipelineOptions
: public PassPipelineOptions<MhloBackendPipelineOptions> {
// Do not register the stablehlo options if the stablehlo target is disabled
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
struct StablehloBackendPipelineOptions
: public PassPipelineOptions<StablehloBackendPipelineOptions> {
Option<bool> enableStaticShape{
*this, "enable-static-shape",
llvm::cl::desc("Enable static shape conversion."), llvm::cl::init(false)};
@ -46,9 +46,10 @@ struct MhloBackendPipelineOptions
llvm::cl::init(false)};
};
void createTorchBackendToMhloBackendPipeline(
OpPassManager &pm, const MhloBackendPipelineOptions &options);
std::unique_ptr<OperationPass<ModuleOp>> createVerifyMhloBackendContractPass();
void createTorchBackendToStablehloBackendPipeline(
OpPassManager &pm, const StablehloBackendPipelineOptions &options);
std::unique_ptr<OperationPass<ModuleOp>>
createVerifyStablehloBackendContractPass();
#endif
std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass();

View File

@ -42,10 +42,10 @@ def VerifyTosaBackendContract : Pass<"torch-verify-tosa-backend-contract", "Modu
let constructor = "mlir::torch::TorchConversion::createVerifyTosaBackendContractPass()";
}
#ifdef TORCH_MLIR_ENABLE_MHLO
def VerifyMhloBackendContract : Pass<"torch-verify-mhlo-backend-contract", "ModuleOp"> {
let summary = "Verifies conformity to the mhlo backend contract";
let constructor = "mlir::torch::TorchConversion::createVerifyMhloBackendContractPass()";
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contract", "ModuleOp"> {
let summary = "Verifies conformity to the stablehlo backend contract";
let constructor = "mlir::torch::TorchConversion::createVerifyStablehloBackendContractPass()";
}
#endif // TORCH_MLIR_ENABLE_MHLO
#endif // TORCH_MLIR_ENABLE_STABLEHLO
#endif // TORCHMLIR_TORCHCONVERSION_PASSES

View File

@ -61,7 +61,7 @@ MlirType torchMlirTorchTupleTypeGet(MlirContext context,
return wrap(Torch::TupleType::get(
unwrap(context),
llvm::to_vector<6>(
llvm::map_range(llvm::makeArrayRef(containedTypes, numContainedTypes),
llvm::map_range(llvm::ArrayRef(containedTypes, numContainedTypes),
[](MlirType t) { return unwrap(t); }))));
}
@ -89,7 +89,7 @@ MlirType torchMlirTorchUnionTypeGet(MlirContext context,
return wrap(Torch::UnionType::get(
unwrap(context),
llvm::to_vector<6>(
llvm::map_range(llvm::makeArrayRef(containedTypes, numContainedTypes),
llvm::map_range(llvm::ArrayRef(containedTypes, numContainedTypes),
[](MlirType t) { return unwrap(t); }))));
}
@ -230,7 +230,7 @@ MlirType torchMlirTorchNonValueTensorTypeGet(MlirContext context,
std::optional<ArrayRef<int64_t>> optionalSizesArrayRef = std::nullopt;
// if numSizes == -1, then it is unranked.
if (numSizes > -1)
optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
optionalSizesArrayRef = llvm::ArrayRef(optionalSizes, numSizes);
return wrap(Torch::NonValueTensorType::get(
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
}
@ -293,7 +293,7 @@ MlirType torchMlirTorchValueTensorTypeGet(MlirContext context,
std::optional<ArrayRef<int64_t>> optionalSizesArrayRef = std::nullopt;
// if numSizes == -1, then it is unranked.
if (numSizes > -1)
optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
optionalSizesArrayRef = llvm::ArrayRef(optionalSizes, numSizes);
return wrap(Torch::ValueTensorType::get(
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
}

View File

@ -3,13 +3,7 @@ add_subdirectory(Conversion)
add_subdirectory(Dialect)
add_subdirectory(RefBackend)
add_mlir_library(TorchMLIRInitAll
InitAll.cpp
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
set(LinkedLibs
MLIRFuncDialect
MLIRIR
MLIRSupport
@ -27,4 +21,22 @@ add_mlir_library(TorchMLIRInitAll
TorchMLIRRefBackend
)
if(TORCH_MLIR_ENABLE_STABLEHLO)
list(APPEND LinkedLibs
MhloPasses
MhloToLinalg
StablehloToMhlo
)
endif()
add_mlir_library(TorchMLIRInitAll
InitAll.cpp
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
${LinkedLibs}
)
torch_mlir_target_includes(TorchMLIRInitAll)

View File

@ -2,8 +2,8 @@ add_subdirectory(TorchToLinalg)
add_subdirectory(TorchToSCF)
add_subdirectory(TorchToArith)
add_subdirectory(TorchToTosa)
if(TORCH_MLIR_ENABLE_MHLO)
add_subdirectory(TorchToMhlo)
if(TORCH_MLIR_ENABLE_STABLEHLO)
add_subdirectory(TorchToStablehlo)
endif()
add_subdirectory(TorchToTMTensor)
add_subdirectory(TorchConversionToMLProgram)
@ -17,10 +17,8 @@ set(linked_libs TorchMLIRTorchToLinalg
TorchMLIRTorchToTMTensor
TorchMLIRTorchConversionToMLProgram
TorchMLIRConversionUtils)
if(TORCH_MLIR_ENABLE_MHLO)
list(APPEND linked_libs
MhloPasses
TorchMLIRTorchToMhlo)
if(TORCH_MLIR_ENABLE_STABLEHLO)
list(APPEND linked_libs TorchMLIRTorchToStablehlo)
endif()
add_mlir_library(TorchMLIRConversionPasses

View File

@ -9,15 +9,15 @@
#include "torch-mlir/Conversion/Passes.h"
#ifdef TORCH_MLIR_ENABLE_MHLO
#include "mhlo/transforms/passes.h"
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "transforms/passes.h"
#endif // TORCH_MLIR_ENABLE_MHLO
#endif // TORCH_MLIR_ENABLE_STABLEHLO
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
@ -32,12 +32,4 @@ namespace {
void mlir::torch::registerConversionPasses() {
::registerPasses();
#ifdef TORCH_MLIR_ENABLE_MHLO
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
return mlir::mhlo::createLegalizeHloToLinalgPass();
});
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
return mlir::mhlo::createSymbolicShapeOptimizationPass();
});
#endif // TORCH_MLIR_ENABLE_MHLO
}

View File

@ -68,7 +68,7 @@ public:
// temp = multiplier * currentSeed + incrementStep
Value mul = rewriter.create<arith::MulIOp>(loc, currentSeed, multiplier);
Value seed = rewriter.create<arith::AddIOp>(loc, mul, incrementStep);
globalVar = rewriter.create<tensor::InsertOp>(loc, seed, globalVar);
globalVar = rewriter.create<tensor::InsertOp>(loc, seed, globalVar, ValueRange());
rewriter.create<ml_program::GlobalStoreOp>(
loc, SymbolRefAttr::get(op->getContext(), getSeedGobalVarName()),
globalVar);

View File

@ -232,6 +232,67 @@ public:
return success();
}
};
// Conversion pattern that replaces a `Torch::ConstantIntOp` with an
// `arith::ConstantOp` holding the same value as an i64 integer attribute.
class ConvertTorchConstantIntOp
: public OpConversionPattern<Torch::ConstantIntOp> {
public:
using OpConversionPattern<Torch::ConstantIntOp>::OpConversionPattern;
using OpAdaptor = Torch::ConstantIntOp::Adaptor;
LogicalResult
matchAndRewrite(Torch::ConstantIntOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Note: arith.constant only accepts signless integers, so the signed
// constant is re-created here with the builder's (signless) i64 type.
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, rewriter.getIntegerAttr(rewriter.getI64Type(),
op.getValueAttr().getValue()));
return success();
}
};
} // namespace
namespace {
// Conversion pattern for `AtenFloatScalarOp`: the op is replaced by its
// (already converted) scalar operand, cast to the type the type converter
// assigns to the op's result.
class ConvertAtenFloatScalarOp : public OpConversionPattern<AtenFloatScalarOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AtenFloatScalarOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Target type of the lowered value, derived from the op's single result.
    Type convertedTy =
        this->getTypeConverter()->convertType(op->getResult(0).getType());
    Value converted = convertScalarToDtype(rewriter, op.getLoc(),
                                           adaptor.getA(), convertedTy);
    rewriter.replaceOp(op, converted);
    return success();
  }
};
} // namespace
namespace {
// Conversion pattern for the scalar `AtenAddOp`.
//
// Both operands are converted to the type-converted result type and then
// added with `arith.addf` (float result) or `arith.addi` (integer result).
// Any other result type is reported as a match failure.
class ConvertAtenAddOp : public OpConversionPattern<AtenAddOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AtenAddOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type resultType =
        this->getTypeConverter()->convertType(op->getResult(0).getType());

    // Validate the result type *before* emitting any IR: a failing pattern
    // should not have created ops, and `convertScalarToDtype` is only
    // meaningful for a supported (non-null, int or float) destination type.
    if (!resultType)
      return rewriter.notifyMatchFailure(op, "unable to convert result type");
    const bool isFloatResult = resultType.isa<mlir::FloatType>();
    if (!isFloatResult && !resultType.isa<mlir::IntegerType>()) {
      return rewriter.notifyMatchFailure(
          op, "unimplemented: only support integer or float result type");
    }

    Value operandA =
        convertScalarToDtype(rewriter, loc, adaptor.getA(), resultType);
    Value operandB =
        convertScalarToDtype(rewriter, loc, adaptor.getB(), resultType);

    if (isFloatResult)
      rewriter.replaceOpWithNewOp<arith::AddFOp>(op, operandA, operandB);
    else
      rewriter.replaceOpWithNewOp<arith::AddIOp>(op, operandA, operandB);
    return success();
  }
};
} // namespace
namespace {
@ -381,8 +442,14 @@ public:
patterns.add<ConvertTorchConstantOp<Torch::ConstantFloatOp>>(typeConverter,
context);
target.addIllegalOp<Torch::ConstantIntOp>();
patterns.add<ConvertTorchConstantOp<Torch::ConstantIntOp>>(typeConverter,
context);
patterns.add<ConvertTorchConstantIntOp>(typeConverter, context);
target.addIllegalOp<AtenFloatScalarOp>();
patterns.add<ConvertAtenFloatScalarOp>(typeConverter, context);
target.addIllegalOp<AtenAddOp>();
patterns.add<ConvertAtenAddOp>(typeConverter, context);
target.addIllegalOp<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp>();
patterns.add<ConvertAtenBinaryOp<AtenAddIntOp, arith::AddIOp>>(
typeConverter, context);

View File

@ -463,8 +463,8 @@ public:
}
SmallVector<Value> inputSize = getTensorSizes(rewriter, loc, input);
ArrayRef<Value> outputShapeInt = llvm::makeArrayRef(outputSizeInt);
ArrayRef<Value> inputShapeInt = llvm::makeArrayRef(inputSize);
ArrayRef<Value> outputShapeInt = llvm::ArrayRef(outputSizeInt);
ArrayRef<Value> inputShapeInt = llvm::ArrayRef(inputSize);
// Association indices for expand/collapse ops. These two vectors
// are populated such that two entries at the same index corresponds
@ -1117,6 +1117,18 @@ public:
RankedTensorType newResultType =
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
auto outElemType = newResultType.getElementType();
auto dtypePromoteBody = [&](OpBuilder &builder, Location loc,
ValueRange payloadArgs) {
Value elem = convertScalarToDtype(builder, loc, payloadArgs[0], outElemType);
builder.create<linalg::YieldOp>(loc, elem);
};
for (size_t i = 0; i < tensors.size(); ++i) {
tensors[i] = torch_to_linalg::createElementwiseLinalgGeneric(
rewriter, loc, {tensors[i]}, outElemType, dtypePromoteBody);
}
int rank = newResultType.getRank();
SmallVector<Value> offsets, sizes, strides;
sizes.reserve(rank);
@ -1136,7 +1148,7 @@ public:
Value dimIndex = rewriter.createOrFold<arith::ConstantOp>(
loc, rewriter.getIndexAttr(dim));
for (auto tensor : makeArrayRef(tensors).drop_front()) {
for (auto tensor : ArrayRef(tensors).drop_front()) {
auto size = rewriter.createOrFold<tensor::DimOp>(loc, tensor, dimIndex);
resultDimSize =
rewriter.createOrFold<arith::AddIOp>(loc, resultDimSize, size);
@ -1270,7 +1282,7 @@ public:
/*resultType=*/selfType,
/*inputs=*/broadcastedSrc,
/*outputs=*/self,
/*indexingMaps=*/llvm::makeArrayRef({id, id}),
/*indexingMaps=*/llvm::ArrayRef({id, id}),
/*iteratorTypes=*/iteratorTypes,
[](OpBuilder &b, Location loc, ValueRange args) {
Value result = args[0];

View File

@ -81,9 +81,21 @@ public:
Type inElementType = inputType.getElementType();
if (!inElementType.isa<mlir::FloatType>()) {
return rewriter.notifyMatchFailure(
maxDimOp,
"aten.max_dim to linalg.* requires Float input element type");
if (inElementType.isa<mlir::IntegerType>()) {
auto integerTy = maxDimOp.getSelf()
.getType()
.cast<BaseTensorType>()
.getDtype()
.dyn_cast<mlir::IntegerType>();
if (integerTy.isUnsigned())
return rewriter.notifyMatchFailure(
maxDimOp, "aten.max_dim to linalg.* requires input element type "
"to be signed in case of integer");
} else {
return rewriter.notifyMatchFailure(
maxDimOp, "aten.max_dim to linalg.* requires Float or Integer "
"input element type");
}
}
// Constant op to account for the reduction along dim.
@ -104,13 +116,23 @@ public:
Value initTensorMax = rewriter.create<tensor::EmptyOp>(
loc, getAsOpFoldResult(resultShape), inElementType);
FloatAttr fillValueMaxAttr = rewriter.getFloatAttr(
inElementType,
APFloat::getLargest(
inElementType.cast<mlir::FloatType>().getFloatSemantics(), true));
Value fillValueMax;
if (inElementType.isa<mlir::FloatType>()) {
fillValueMax = rewriter.create<arith::ConstantOp>(
loc,
rewriter.getFloatAttr(
inElementType,
APFloat::getLargest(
inElementType.cast<mlir::FloatType>().getFloatSemantics(),
true)));
} else {
fillValueMax = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(
inElementType,
APSInt::getSignedMinValue(
inElementType.cast<mlir::IntegerType>().getWidth())));
}
Value fillValueMax =
rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
Value filledTensorMax =
rewriter.create<linalg::FillOp>(loc, fillValueMax, initTensorMax)
.result();
@ -152,10 +174,18 @@ public:
nestedLoc, oldIndex.getType(),
rewriter.create<linalg::IndexOp>(loc, dim));
auto resultMax = rewriter.create<arith::MaxFOp>(
nestedLoc, newValue, oldValue);
Value predicate = rewriter.create<arith::CmpFOp>(
nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
Value resultMax, predicate;
if (inElementType.isa<mlir::FloatType>()) {
resultMax =
rewriter.create<arith::MaxFOp>(nestedLoc, newValue, oldValue);
predicate = rewriter.create<arith::CmpFOp>(
nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
} else {
resultMax =
rewriter.create<arith::MaxSIOp>(nestedLoc, newValue, oldValue);
predicate = rewriter.create<arith::CmpIOp>(
nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
}
auto resultIndex = rewriter.create<arith::SelectOp>(
nestedLoc, predicate, newIndex, oldIndex);
nestedBuilder.create<linalg::YieldOp>(

View File

@ -127,9 +127,14 @@ public:
if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt)))
return rewriter.notifyMatchFailure(
op, "unimplemented: dtype must be a constant integer or none");
resultElementType = getTypeForScalarType(
FailureOr<Type> maybeResultElementType = getTypeForScalarType(
op->getContext(), (torch_upstream::ScalarType)dtypeInt,
IntegerType::Signless);
if (failed(maybeResultElementType)) {
return rewriter.notifyMatchFailure(
op, "unable to convert `dtypeInt` to builtin type");
}
resultElementType = *maybeResultElementType;
}
// Create an uninitialized tensor of `resultSize` shape and fill it with
@ -227,9 +232,14 @@ public:
if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt)))
return rewriter.notifyMatchFailure(
op, "unimplemented: dtype must be a constant integer or none");
resultElementType = getTypeForScalarType(
FailureOr<Type> maybeResultElementType = getTypeForScalarType(
op->getContext(), (torch_upstream::ScalarType)dtypeInt,
IntegerType::Signless);
if (failed(maybeResultElementType)) {
return rewriter.notifyMatchFailure(
op, "unable to convert `dtypeInt` to builtin type");
}
resultElementType = *maybeResultElementType;
}
// Create an uninitialized tensor of `resultSize` shape.

View File

@ -59,6 +59,15 @@ static Value createGreaterThan(OpBuilder &b, Location loc, Type elementalType,
b, loc, elementalType, lhs, rhs);
}
// Builds a `lhs >= rhs` comparison by delegating to
// `createComparisonTemplate`, supplying the float predicate UGE and the
// integer predicates uge / sge. How the template chooses between them based
// on `elementalType` is defined elsewhere (presumably float / unsigned /
// signed — confirm against `createComparisonTemplate`).
// NOTE(review): UGE is the *unordered* float predicate, i.e. it also holds
// when either operand is NaN — confirm this is the intended NaN behaviour.
static Value createGreaterThanOrEqual(OpBuilder &b, Location loc,
Type elementalType, Value lhs,
Value rhs) {
return createComparisonTemplate<arith::CmpFPredicate::UGE,
arith::CmpIPredicate::uge,
arith::CmpIPredicate::sge>(
b, loc, elementalType, lhs, rhs);
}
}
static Value createLessThan(OpBuilder &b, Location loc, Type elementalType,
Value lhs, Value rhs) {
return createComparisonTemplate<arith::CmpFPredicate::ULT,
@ -67,6 +76,14 @@ static Value createLessThan(OpBuilder &b, Location loc, Type elementalType,
b, loc, elementalType, lhs, rhs);
}
// Builds a `lhs <= rhs` comparison by delegating to
// `createComparisonTemplate`, supplying the float predicate ULE and the
// integer predicates ule / sle. How the template chooses between them based
// on `elementalType` is defined elsewhere (presumably float / unsigned /
// signed — confirm against `createComparisonTemplate`).
// NOTE(review): ULE is the *unordered* float predicate, i.e. it also holds
// when either operand is NaN — confirm this is the intended NaN behaviour.
static Value createLessThanOrEqual(OpBuilder &b, Location loc,
Type elementalType, Value lhs, Value rhs) {
return createComparisonTemplate<arith::CmpFPredicate::ULE,
arith::CmpIPredicate::ule,
arith::CmpIPredicate::sle>(
b, loc, elementalType, lhs, rhs);
}
}
static Value createEqual(OpBuilder &b, Location loc, Type elementalType,
Value lhs, Value rhs) {
return createComparisonTemplate<arith::CmpFPredicate::UEQ,
@ -117,6 +134,46 @@ static Value createCalculationForMathOpWithDtypeConversion(
return b.create<MathOpTy>(loc, arg);
}
// Emits the scalar comparison for one element pair of a tensor-vs-tensor
// comparison op (`lt`, `le`, `gt`, `ge`, `eq`).
//
// `lhs` / `rhs` are the payload scalars; they must currently have identical
// types, otherwise an error is emitted on `op` and a null Value is returned.
// The comparison helper is selected at compile time from `OpTy`, and the
// dtype of `op`'s `self` operand is forwarded to it as the elemental type.
template <typename OpTy>
static Value createCompareTensorOp(OpBuilder &b, Location loc, OpTy op,
                                   Value lhs, Value rhs) {
  constexpr bool isLt = std::is_same<OpTy, AtenLtTensorOp>();
  constexpr bool isLe = std::is_same<OpTy, AtenLeTensorOp>();
  constexpr bool isGt = std::is_same<OpTy, AtenGtTensorOp>();
  constexpr bool isGe = std::is_same<OpTy, AtenGeTensorOp>();
  constexpr bool isEq = std::is_same<OpTy, AtenEqTensorOp>();
  static_assert(isLt || isLe || isGt || isGe || isEq,
                "unimplemented: op type not supported");

  // TODO: Type promotion in case of different lhs and rhs dtypes needs to be
  // handled.
  if (lhs.getType() != rhs.getType()) {
    op.emitError("unimplemented: lhs and rhs dtype must be same");
    return nullptr;
  }

  BaseTensorType selfType =
      op.getSelf().getType().template cast<BaseTensorType>();
  Type elementalType = selfType.getDtype();

  if constexpr (isLt)
    return createLessThan(b, loc, elementalType, lhs, rhs);
  if constexpr (isLe)
    return createLessThanOrEqual(b, loc, elementalType, lhs, rhs);
  if constexpr (isGt)
    return createGreaterThan(b, loc, elementalType, lhs, rhs);
  if constexpr (isGe)
    return createGreaterThanOrEqual(b, loc, elementalType, lhs, rhs);
  if constexpr (isEq)
    return createEqual(b, loc, elementalType, lhs, rhs);
  llvm_unreachable("unimplemented: op type not supported");
}
static Value createLinalgPayloadCalculationForElementwiseOp(
OpBuilder &b, Location loc, TypeConverter *converter,
ValueRange payloadArgs, Operation *op, ArrayRef<Value> operands) {
@ -177,8 +234,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
if (!clone.getMemoryFormat().getType().isa<Torch::NoneType>() &&
(!matchPattern(clone.getMemoryFormat(),
m_TorchConstantInt(&memoryFormat)) ||
memoryFormat != torch_upstream::MemoryFormat::Contiguous)) {
clone.emitError("unimplemented: only default memory format is supported");
(memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
memoryFormat != torch_upstream::MemoryFormat::ChannelsLast))) {
clone.emitError("unimplemented: only contiguous and channels last memory "
"format is supported");
return nullptr;
}
return payloadArgs[0];
@ -293,7 +352,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
round.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
return b.create<math::RoundOp>(loc, payloadArgs[0]);
return b.create<math::RoundEvenOp>(loc, payloadArgs[0]);
}
if (auto prelu = dyn_cast<AtenPreluOp>(op)) {
if (!prelu.getType()
@ -370,6 +429,29 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value cdfExt = b.create<arith::AddFOp>(loc, dinputInputAlpha, cdf);
return b.create<arith::MulFOp>(loc, payloadArgs[0], cdfExt);
}
if (auto hardtanhBackward = dyn_cast<AtenHardtanhBackwardOp>(op)) {
AtenHardtanhBackwardOp::Adaptor adaptor(operands);
if (!hardtanhBackward.getType()
.cast<ValueTensorType>()
.getDtype()
.isa<mlir::FloatType>()) {
hardtanhBackward.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
Value gradOutput = payloadArgs[0];
Type elementType = gradOutput.getType();
Value self = convertScalarToDtype(b, loc, payloadArgs[1], elementType);
Value constantZero =
b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.0));
Value min = convertScalarToDtype(b, loc, adaptor.getMinVal(), elementType);
Value max = convertScalarToDtype(b, loc, adaptor.getMaxVal(), elementType);
Value lesser =
b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT, self, min);
Value greater =
b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, self, max);
Value cmp = b.create<arith::OrIOp>(loc, lesser, greater);
return b.create<arith::SelectOp>(loc, cmp, constantZero, gradOutput);
}
if (auto add = dyn_cast<AtenAddTensorOp>(op)) {
AtenAddTensorOp::Adaptor adaptor(operands);
Type dtype = converter->convertType(add.getType())
@ -463,64 +545,25 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
return b.create<math::Atan2Op>(loc, lhs, rhs);
}
if (auto ltTensor = dyn_cast<AtenLtTensorOp>(op)) {
return createCompareTensorOp(b, loc, ltTensor, payloadArgs[0],
payloadArgs[1]);
}
if (auto leTensor = dyn_cast<AtenLeTensorOp>(op)) {
return createCompareTensorOp(b, loc, leTensor, payloadArgs[0],
payloadArgs[1]);
}
if (auto gtTensor = dyn_cast<AtenGtTensorOp>(op)) {
AtenGtTensorOp::Adaptor adaptor(operands);
Type lhsDtype = payloadArgs[0].getType();
Type rhsDtype = payloadArgs[1].getType();
// TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
// to be handled.
if (lhsDtype != rhsDtype) {
gtTensor.emitError("unimplemented: different lhs and rhs dtype");
return nullptr;
}
Type elementalType =
gtTensor.getSelf().getType().cast<BaseTensorType>().getDtype();
return createGreaterThan(b, loc, elementalType, payloadArgs[0],
payloadArgs[1]);
return createCompareTensorOp(b, loc, gtTensor, payloadArgs[0],
payloadArgs[1]);
}
if (auto geTensor = dyn_cast<AtenGeTensorOp>(op)) {
return createCompareTensorOp(b, loc, geTensor, payloadArgs[0],
payloadArgs[1]);
}
if (auto eqTensor = dyn_cast<AtenEqTensorOp>(op)) {
AtenEqTensorOp::Adaptor adaptor(operands);
Type lhsDtype = payloadArgs[0].getType();
Type rhsDtype = payloadArgs[1].getType();
// TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
// to be handled.
if (lhsDtype != rhsDtype) {
eqTensor.emitError("unimplemented: lhs and rhs dtype must be same");
return nullptr;
}
Type elementalType =
eqTensor.getSelf().getType().cast<BaseTensorType>().getDtype();
if (elementalType.isa<mlir::FloatType>())
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UEQ,
payloadArgs[0], payloadArgs[1]);
if (elementalType.isa<mlir::IntegerType>()) {
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
payloadArgs[0], payloadArgs[1]);
}
eqTensor.emitError("unimplemented: dtype isn't supported.");
return nullptr;
}
if (auto ltTensor = dyn_cast<AtenLtTensorOp>(op)) {
AtenLtTensorOp::Adaptor adaptor(operands);
Type lhsDtype = payloadArgs[0].getType();
Type rhsDtype = payloadArgs[1].getType();
// TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
// to be handled.
if (lhsDtype != rhsDtype) {
ltTensor.emitError("unimplemented: lhs and rhs dtype must be same");
return nullptr;
}
Type elementalType =
ltTensor.getSelf().getType().cast<BaseTensorType>().getDtype();
return createLessThan(b, loc, elementalType, payloadArgs[0],
payloadArgs[1]);
return createCompareTensorOp(b, loc, eqTensor, payloadArgs[0],
payloadArgs[1]);
}
if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
AtenDivTensorOp::Adaptor adaptor(operands);
@ -964,18 +1007,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
.getElementType();
return convertScalarToDtype(b, loc, adaptor.getValue(), dtype);
}
if (auto maskedFillScalar = dyn_cast<AtenMaskedFillScalarOp>(op)) {
AtenMaskedFillScalarOp::Adaptor adaptor(operands);
Type dtype = converter->convertType(maskedFillScalar.getType())
.cast<RankedTensorType>()
.getElementType();
Value input = payloadArgs[0];
Value mask = payloadArgs[1];
Value fillValue = convertScalarToDtype(b, loc, adaptor.getValue(), dtype);
return b.create<arith::SelectOp>(loc, mask, fillValue, input);
}
if (auto maskedFillTensor = dyn_cast<AtenMaskedFillTensorOp>(op)) {
AtenMaskedFillScalarOp::Adaptor adaptor(operands);
Type dtype = converter->convertType(maskedFillTensor.getType())
@ -1034,7 +1065,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value allOnesVal = b.create<arith::ConstantOp>(
loc, b.getIntegerAttr(
elementType,
APSInt::getAllOnesValue(elementType.getIntOrFloatBitWidth())));
APSInt::getAllOnes(elementType.getIntOrFloatBitWidth())));
return b.create<arith::XOrIOp>(loc, payloadArgs[0], allOnesVal);
}
@ -1082,10 +1113,10 @@ public:
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp,
AtenBitwiseXorTensorOp, AtenGtScalarOp, AtenGeScalarOp,
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp,
AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp, AtenAddScalarOp,
AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp,
AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenBitwiseNotOp,
AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>(op))
@ -1561,12 +1592,12 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, AtenGtScalarOp,
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp,
AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp, AtenMaskedFillTensorOp,
AtenLogicalOrOp, AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp,
AtenTriuOp, AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp,
AtenFillScalarOp, AtenFillTensorOp>();
AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenRemainderScalarOp,
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context);
target.addIllegalOp<AtenNllLossForwardOp>();
patterns.add<ConvertAtenDetachOp>(typeConverter, context);

View File

@ -1,35 +0,0 @@
# Defines the TorchMLIRTorchToMhlo conversion library from the sources listed
# directly below, then applies the project's include-directory setup to it.
add_mlir_conversion_library(TorchMLIRTorchToMhlo
TorchToMhlo.cpp
MhloLegalizeUtils.cpp
Basic.cpp
Gather.cpp
Linear.cpp
ViewLike.cpp
Reduction.cpp
Pooling.cpp
# Public header directory associated with this library.
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToMhlo
# Targets named as dependencies of this library.
# NOTE(review): the *IncGen entries look like generated-header targets that
# must exist before these sources compile -- confirm.
DEPENDS
MhloDialect
MhloToLinalg
MLIRMhloPassIncGen
LMHLOTransformsPassIncGen
TorchMLIRConversionPassIncGen
LINK_COMPONENTS
Core
# Libraries linked with PUBLIC visibility.
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MhloDialect
MhloToLinalg
MLIRBufferTransforms
StablehloOps
TorchMLIRTorchDialect
TorchMLIRConversionUtils
)
torch_mlir_target_includes(TorchMLIRTorchToMhlo)

View File

@ -1,74 +0,0 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H
#define TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace torch {
namespace torch_to_mhlo {
// Option bundle passed to ConvertAtenOp and to every populate*PatternsAndLegality
// entry point in this header.
struct TorchToMhloOptions {
  // Off by default.
  // NOTE(review): presumably enables lowerings that rely on statically known
  // shapes -- confirm against the patterns that read it.
  bool enableStaticShape{false};
  // 64 by default.
  // NOTE(review): presumably the bit width used for dimension-size/index
  // values in the generated IR -- confirm.
  size_t dimSizeIndexBits{64};
};
// Conversion pattern for a Torch op of type `AtenOpT` that also carries a
// copy of the TorchToMhloOptions it was constructed with.
//
// The generic `matchAndRewrite` defined here never succeeds: it only reports
// a match failure. NOTE(review): op-specific behavior is presumably provided
// by explicit specializations of `ConvertAtenOp<Op>::matchAndRewrite` in the
// .cpp files -- confirm.
template <typename AtenOpT>
class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
public:
  using OpAdaptor = typename AtenOpT::Adaptor;

  // `typeConverter` and `context` are forwarded to the OpConversionPattern
  // base; `options` is copied into this pattern. The member is initialized
  // directly in the initializer list rather than default-constructed and
  // then assigned in the constructor body.
  ConvertAtenOp(TypeConverter &typeConverter, MLIRContext *context,
                const TorchToMhloOptions &options)
      : OpConversionPattern<AtenOpT>(typeConverter, context),
        options(options) {}

  // Fallback used when no specialization exists for `AtenOpT`: always returns
  // a match failure with a fixed message.
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    return rewriter.notifyMatchFailure(op, "haven't been implemented");
  }

  // Returns the options stored at construction time.
  const TorchToMhloOptions &getOptions() const { return options; }

private:
  // Per-pattern copy of the conversion options.
  TorchToMhloOptions options;
};
// Entry points for setting up the Torch-to-MHLO conversion, one per op group
// (basic, view-like, gather, reduction, linear, pooling). Only declarations
// are visible in this header. All six share the same parameter list:
//   typeConverter - type converter handed to the patterns.
//   patterns      - pattern set to be extended.
//   target        - conversion target to be updated.
//   options       - TorchToMhloOptions forwarded to the patterns.
// NOTE(review): presumably each function adds ConvertAtenOp-style patterns to
// `patterns` and marks the handled ops illegal on `target`, and is defined in
// the correspondingly named source file -- confirm against the definitions.
void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target,
const TorchToMhloOptions &options);
void populateViewLikeOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target,
const TorchToMhloOptions &options);
void populateGatherOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target,
const TorchToMhloOptions &options);
void populateReductionOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target,
const TorchToMhloOptions &options);
void populateLinearOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target,
const TorchToMhloOptions &options);
void populatePoolingOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target,
const TorchToMhloOptions &options);
} // namespace torch_to_mhlo
} // namespace torch
} // namespace mlir
#endif // TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H

View File

@ -7,15 +7,16 @@
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "../PassDetail.h"
#include "./MhloLegalizeUtils.h"
#include "./PopulatePatterns.h"
#include "mhlo/IR/hlo_ops.h"
#include "PopulatePatterns.h"
#include "StablehloLegalizeUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@ -29,7 +30,7 @@
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_mhlo;
using namespace mlir::torch::torch_to_stablehlo;
LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
mlir::Value &self, mlir::Value &other,
@ -43,16 +44,16 @@ LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
if (selfRank > otherRank) {
auto unsqueezeDims =
llvm::to_vector<4>(llvm::seq<int64_t>(0, selfRank - otherRank));
auto unsqueezeInfo = mhlo::unsqueezeTensor(rewriter, op, other,
unsqueezeDims, dimSizeIndexBits);
auto unsqueezeInfo = hlo::unsqueezeTensor(rewriter, op, other,
unsqueezeDims, dimSizeIndexBits);
if (failed(unsqueezeInfo))
return failure();
other = *unsqueezeInfo;
} else if (otherRank > selfRank) {
auto unsqueezeDims =
llvm::to_vector<4>(llvm::seq<int64_t>(0, otherRank - selfRank));
auto unsqueezeInfo = mhlo::unsqueezeTensor(rewriter, op, self,
unsqueezeDims, dimSizeIndexBits);
auto unsqueezeInfo = hlo::unsqueezeTensor(rewriter, op, self, unsqueezeDims,
dimSizeIndexBits);
if (failed(unsqueezeInfo))
return failure();
self = *unsqueezeInfo;
@ -78,7 +79,8 @@ static FailureOr<Value> getMaxValueOfDtype(Operation *op, Type elementType,
constType,
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
/*negative=*/false));
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
return rewriter
.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
.getResult();
}
if (elementType.isa<mlir::IntegerType>()) {
@ -91,7 +93,8 @@ static FailureOr<Value> getMaxValueOfDtype(Operation *op, Type elementType,
constAttr = SplatElementsAttr::get(
constType, APInt::getSignedMaxValue(integerType.getWidth()));
}
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
return rewriter
.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
.getResult();
}
return failure();
@ -105,7 +108,8 @@ static FailureOr<Value> getMinValueOfDtype(Operation *op, Type elementType,
constType,
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
/*negative=*/true));
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
return rewriter
.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
.getResult();
}
if (elementType.isa<mlir::IntegerType>()) {
@ -118,7 +122,8 @@ static FailureOr<Value> getMinValueOfDtype(Operation *op, Type elementType,
constAttr = SplatElementsAttr::get(
constType, APInt::getSignedMinValue(integerType.getWidth()));
}
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
return rewriter
.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
.getResult();
}
return failure();
@ -126,7 +131,7 @@ static FailureOr<Value> getMinValueOfDtype(Operation *op, Type elementType,
// These legalizations are for unary ops.
namespace {
template <typename AtenOpT, typename MhloOpT>
template <typename AtenOpT, typename StablehloOpT>
class ConvertAtenUnaryOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
@ -137,13 +142,13 @@ public:
Value self = adaptor.getSelf();
auto selfType = self.getType().cast<TensorType>();
if (!selfType) {
return op.emitError("only Tensor types supported in MHLO");
return op.emitError("only Tensor types supported in StableHLO");
}
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<TensorType>();
self = mhlo::promoteType(rewriter, self, outType);
rewriter.replaceOpWithNewOp<MhloOpT>(op, outType, self);
self = hlo::promoteType(rewriter, self, outType);
rewriter.replaceOpWithNewOp<StablehloOpT>(op, outType, self);
return success();
}
};
@ -152,7 +157,7 @@ public:
// These legalizations are for unary ops that support only floating-point datatypes.
// There is no supported quantized integer mode for these.
namespace {
template <typename AtenOpT, typename MhloOpT>
template <typename AtenOpT, typename StablehloOpT>
class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
@ -164,10 +169,10 @@ public:
auto selfTy = self.getType().cast<TensorType>();
if (!selfTy)
return op.emitError("only Tensor types supported in MHLO");
return op.emitError("only Tensor types supported in StableHLO");
if (selfTy.getElementType().isa<mlir::FloatType>()) {
rewriter.replaceOpWithNewOp<MhloOpT>(
rewriter.replaceOpWithNewOp<StablehloOpT>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
@ -198,7 +203,7 @@ public:
.template dyn_cast<TensorType>();
if (!outType)
return op.emitError("only Tensor types supported in MHLO");
return op.emitError("only Tensor types supported in StableHLO");
Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat())
@ -216,9 +221,9 @@ public:
SmallVector<int32_t> values(size, fillVal);
auto constOp =
mhlo::getConstTensor<int32_t>(rewriter, op, values, shape).value();
hlo::getConstTensor<int32_t>(rewriter, op, values, shape).value();
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(op, outType, constOp);
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(op, outType, constOp);
return success();
}
};
@ -247,8 +252,8 @@ public:
->convertType(op.getType())
.template cast<TensorType>();
lhs = mhlo::promoteType(rewriter, lhs, outTy);
rhs = mhlo::promoteType(rewriter, rhs, outTy);
lhs = hlo::promoteType(rewriter, lhs, outTy);
rhs = hlo::promoteType(rewriter, rhs, outTy);
rewriter.replaceOpWithNewOp<ChloOpT>(op, outTy, lhs, rhs,
/*broadcast_attr*/ nullptr);
@ -274,7 +279,7 @@ public:
RankedTensorType rhsType = rhs.getType().dyn_cast<RankedTensorType>();
if (!lhsType)
return op.emitError("only Tensor types supported in MHLO");
return op.emitError("only Tensor types supported in StableHLO");
TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
@ -287,18 +292,19 @@ public:
}
if (!rhsType) {
rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getOther(), outElemTy);
rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
outElemTy);
if (isa<AtenRsubScalarOp>(op)) {
std::swap(lhs, rhs);
}
}
lhs = mhlo::promoteType(rewriter, lhs, outType);
rhs = mhlo::promoteType(rewriter, rhs, outType);
lhs = hlo::promoteType(rewriter, lhs, outType);
rhs = hlo::promoteType(rewriter, rhs, outType);
if (!skipMultiplyAlpha(op.getAlpha())) {
Value alpha =
mhlo::scalarToMhloTensor(rewriter, op, adaptor.getAlpha(), outElemTy);
Value alpha = hlo::scalarToStablehloTensor(rewriter, op,
adaptor.getAlpha(), outElemTy);
DenseIntElementsAttr bcastDimensions;
rhs = rewriter.create<chlo::BroadcastMulOp>(op->getLoc(), rhs, alpha,
bcastDimensions);
@ -328,7 +334,7 @@ public:
TensorType rhsType = rhs.getType().dyn_cast<TensorType>();
if (!lhsType)
return op.emitError("only Tensor types supported in MHLO");
return op.emitError("only Tensor types supported in StableHLO");
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
@ -343,11 +349,12 @@ public:
if (std::is_same<AtenOpT, AtenSquareOp>()) {
rhs = lhs;
} else if (!rhsType) {
rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getOther(), outElemTy);
rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
outElemTy);
}
DenseIntElementsAttr bcastDimensions;
lhs = mhlo::promoteType(rewriter, lhs, outType);
rhs = mhlo::promoteType(rewriter, rhs, outType);
lhs = hlo::promoteType(rewriter, lhs, outType);
rhs = hlo::promoteType(rewriter, rhs, outType);
auto loc = op.getLoc();
Value result =
rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);
@ -368,15 +375,15 @@ public:
if (roundingMode == "trunc") {
// "trunc" - rounds the results of the division towards zero. Equivalent
// to C-style integer division.
auto sign = rewriter.create<mhlo::SignOp>(loc, result);
auto abs = rewriter.create<mhlo::AbsOp>(loc, result);
auto floor = rewriter.create<mhlo::FloorOp>(loc, abs);
result = rewriter.create<mhlo::MulOp>(loc, sign, floor).getResult();
auto sign = rewriter.create<stablehlo::SignOp>(loc, result);
auto abs = rewriter.create<stablehlo::AbsOp>(loc, result);
auto floor = rewriter.create<stablehlo::FloorOp>(loc, abs);
result = rewriter.create<stablehlo::MulOp>(loc, sign, floor).getResult();
}
if (roundingMode == "floor") {
// "floor" - rounds the results of the division down. Equivalent to
// floor division in Python (the // operator)
result = rewriter.create<mhlo::FloorOp>(loc, result).getResult();
result = rewriter.create<stablehlo::FloorOp>(loc, result).getResult();
}
rewriter.replaceOp(op, result);
return success();
@ -401,7 +408,7 @@ public:
RankedTensorType rhsTy = rhs.getType().dyn_cast<RankedTensorType>();
if (!lhsTy)
return op.emitError("only Tensor types supported in MHLO");
return op.emitError("only Tensor types supported in StableHLO");
RankedTensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
@ -414,11 +421,12 @@ public:
}
if (!rhsTy) {
rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getOther(), lhsElemTy);
rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
lhsElemTy);
}
// TODO: what is the PyTorch default type promotion?
rhs = mhlo::promoteType(rewriter, rhs, lhsTy);
rhs = hlo::promoteType(rewriter, rhs, lhsTy);
chlo::ComparisonTypeAttr compareTypeAttr;
chlo::ComparisonDirectionAttr compareDirectionAttr;
@ -485,8 +493,8 @@ public:
TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<TensorType>();
Value lhs = mhlo::promoteType(rewriter, adaptor.getSelf(), outType);
Value rhs = mhlo::promoteType(rewriter, adaptor.getOther(), outType);
Value lhs = hlo::promoteType(rewriter, adaptor.getSelf(), outType);
Value rhs = hlo::promoteType(rewriter, adaptor.getOther(), outType);
DenseIntElementsAttr bcastDimensions;
rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
@ -537,8 +545,8 @@ public:
RankedTensorType::get({static_cast<long int>(permValues.size())},
rewriter.getI64Type()),
permValues);
rewriter.replaceOpWithNewOp<mhlo::TransposeOp>(op, outType, self,
permutation);
rewriter.replaceOpWithNewOp<stablehlo::TransposeOp>(op, outType, self,
permutation);
return success();
}
};
@ -552,7 +560,7 @@ LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
Value self = adaptor.getSelf();
auto outType =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(op, outType, self);
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(op, outType, self);
return success();
}
@ -573,7 +581,8 @@ LogicalResult ConvertAtenOp<AtenSizeIntOp>::matchAndRewrite(
} else {
Value inputRank = rewriter.create<arith::ConstantOp>(
op.getLoc(), rewriter.getI64IntegerAttr(selfType.getRank()));
dim = toPositiveDimDynamic(rewriter, op.getLoc(), adaptor.getDim(), inputRank);
dim = toPositiveDimDynamic(rewriter, op.getLoc(), adaptor.getDim(),
inputRank);
dim = rewriter.create<arith::IndexCastOp>(op.getLoc(),
rewriter.getIndexType(), dim);
}
@ -589,9 +598,8 @@ LogicalResult ConvertAtenOp<AtenSizeIntOp>::matchAndRewrite(
template <>
LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
AtenWhereSelfOp op,
OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const {
AtenWhereSelfOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value self = adaptor.getSelf();
Value cond = adaptor.getCondition();
Value other = adaptor.getOther();
@ -605,8 +613,7 @@ LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
return op.emitError("failed broadcast other and condition ranks");
rewriter.replaceOpWithNewOp<chlo::BroadcastSelectOp>(
op,
getTypeConverter()->convertType(op.getType()),
op, getTypeConverter()->convertType(op.getType()),
ArrayRef<Value>{cond, self, other});
return success();
}
@ -623,7 +630,7 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
.cast<RankedTensorType>();
if (options.enableStaticShape && selfTy.hasStaticShape()) {
Value bcastOp = mhlo::promoteAndBroadcast(rewriter, self, outType);
Value bcastOp = hlo::promoteAndBroadcast(rewriter, self, outType);
rewriter.replaceOp(op, bcastOp);
return success();
}
@ -670,7 +677,7 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
op->getLoc(), ValueRange{bcastShapeVec});
auto dimensionNumbers =
llvm::to_vector<4>(llvm::seq<int64_t>(leadingRank, totalRank));
rewriter.replaceOpWithNewOp<mhlo::DynamicBroadcastInDimOp>(
rewriter.replaceOpWithNewOp<stablehlo::DynamicBroadcastInDimOp>(
op, outType, self, bcastShapeTensor,
rewriter.getI64TensorAttr(dimensionNumbers));
}
@ -708,28 +715,11 @@ LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
RankedTensorType::get({static_cast<long int>(permValues.size())},
rewriter.getI64Type()),
permValues);
rewriter.replaceOpWithNewOp<mhlo::TransposeOp>(op, outType, self,
permutation);
rewriter.replaceOpWithNewOp<stablehlo::TransposeOp>(op, outType, self,
permutation);
return success();
}
// AtenTanhOp
// Replaces the op with mhlo::TanhOp when the converted operand is a tensor
// with a floating-point element type; otherwise emits an error on the op.
template <>
LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
    AtenTanhOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value self = adaptor.getSelf();
  // dyn_cast yields a null type for a non-tensor operand so the check below
  // can reject it with a diagnostic. cast<> would assert on a mismatch, which
  // left the null check unreachable.
  auto selfTy = self.getType().dyn_cast<TensorType>();
  if (!selfTy || !selfTy.getElementType().isa<mlir::FloatType>())
    return op.emitError(
        "only floating-point datatype legalization currently supported");

  rewriter.replaceOpWithNewOp<mhlo::TanhOp>(
      op, getTypeConverter()->convertType(op.getType()), self);
  return success();
}
// ValueTensorLiteralOp
template <>
LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
@ -751,16 +741,16 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
elements.mapValues(builtinTensorElemTy, [&](const APInt &v) {
return APInt(bitWidth, v.getSExtValue());
});
rewriter.replaceOpWithNewOp<mhlo::ConstantOp>(op, resultType, valueAttr);
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(op, resultType,
valueAttr);
return success();
}
rewriter.replaceOpWithNewOp<mhlo::ConstantOp>(op, resultType,
adaptor.getValue());
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(op, resultType,
adaptor.getValue());
return success();
}
// AtenReciprocalOp
// Reciprocal(x) = Div(1, x)
template <>
@ -777,7 +767,45 @@ LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(
}
Value oneTensor = chlo::getConstantLike(rewriter, op->getLoc(), 1, input);
rewriter.replaceOpWithNewOp<mhlo::DivOp>(op, outTy, oneTensor, input);
rewriter.replaceOpWithNewOp<stablehlo::DivOp>(op, outTy, oneTensor, input);
return success();
}
// AtenPowTensorScalarOp
// Rewrites the op as chlo::BroadcastPowOp. Both operands are promoted to the
// converted result type first; an exponent that is not already a tensor is
// turned into one with the result element type.
template <>
LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
    AtenPowTensorScalarOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value base = adaptor.getSelf();
  Value exponent = adaptor.getExponent();

  // The base must already be a builtin tensor.
  if (!base.getType().isa<TensorType>())
    return op.emitError("only Tensor types supported in StableHLO");

  auto resultTy =
      getTypeConverter()->convertType(op.getType()).cast<TensorType>();
  Type resultElemTy = resultTy.getElementType();
  if (!resultElemTy.isIntOrFloat())
    return op.emitError(
        "only floating-point or integer datatype legalization supported");

  // Materialize a non-tensor exponent as a tensor before type promotion.
  const bool exponentIsTensor = exponent.getType().isa<TensorType>();
  if (!exponentIsTensor)
    exponent =
        hlo::scalarToStablehloTensor(rewriter, op, exponent, resultElemTy);

  base = hlo::promoteType(rewriter, base, resultTy);
  exponent = hlo::promoteType(rewriter, exponent, resultTy);

  // Default-constructed: no explicit broadcast dimensions are supplied.
  DenseIntElementsAttr noBroadcastDims;
  Value powResult = rewriter.create<chlo::BroadcastPowOp>(
      op.getLoc(), resultTy, base, exponent, noBroadcastDims);
  rewriter.replaceOp(op, powResult);
  return success();
}
@ -790,9 +818,9 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
auto outputElemType = outputType.getElementType();
Value mhloTensor =
mhlo::scalarToMhloTensor(rewriter, op, adaptor.getA(), outputElemType);
rewriter.replaceOp(op, mhloTensor);
Value stablehloTensor = hlo::scalarToStablehloTensor(
rewriter, op, adaptor.getA(), outputElemType);
rewriter.replaceOp(op, stablehloTensor);
return success();
}
@ -815,7 +843,6 @@ LogicalResult ConvertAtenOp<AtenContiguousOp>::matchAndRewrite(
return success();
}
// AtenReluOp
// Relu(x) = Max(0, x)
template <>
@ -836,11 +863,10 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
APFloat::getZero(lhsElemTy.cast<mlir::FloatType>().getFloatSemantics(),
false),
lhs);
rewriter.replaceOpWithNewOp<mhlo::MaxOp>(op, lhs, zeroTensor);
rewriter.replaceOpWithNewOp<stablehlo::MaxOp>(op, lhs, zeroTensor);
return success();
}
// Convert an Aten::GELU to HLO
// Gelu(x) = x * 1/2 * [1 + erf(x/(sqrt(2)))]
template <>
@ -857,12 +883,12 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
Value one = chlo::getConstantLike(rewriter, loc, 1.0, input);
Value two = chlo::getConstantLike(rewriter, loc, 2.0, input);
Value half = chlo::getConstantLike(rewriter, loc, 0.5, input);
auto rsqrtTwo = rewriter.create<mlir::mhlo::RsqrtOp>(loc, two);
auto erfElement = rewriter.create<mhlo::MulOp>(loc, input, rsqrtTwo);
auto rsqrtTwo = rewriter.create<mlir::stablehlo::RsqrtOp>(loc, two);
auto erfElement = rewriter.create<stablehlo::MulOp>(loc, input, rsqrtTwo);
auto erf = rewriter.create<mlir::chlo::ErfOp>(loc, erfElement);
auto erfAdd = rewriter.create<mhlo::AddOp>(loc, erf, one);
auto halfMul = rewriter.create<mhlo::MulOp>(loc, erfAdd, half);
rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, input, halfMul);
auto erfAdd = rewriter.create<stablehlo::AddOp>(loc, erf, one);
auto halfMul = rewriter.create<stablehlo::MulOp>(loc, erfAdd, half);
rewriter.replaceOpWithNewOp<stablehlo::MulOp>(op, input, halfMul);
return success();
}
@ -881,7 +907,6 @@ LogicalResult ConvertAtenOp<AtenErfOp>::matchAndRewrite(
return success();
}
// AtenBatchNormOp
template <>
LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
@ -919,28 +944,28 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
Value channelShape = rewriter.create<tensor::FromElementsOp>(
op->getLoc(), ValueRange{channelDim});
if (failed(checkNotNone(rewriter, op, weight))) {
weight = mhlo::getConstantOfShape(
weight = hlo::getConstantOfShape(
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1),
channelShape,
RankedTensorType::get({inputTy.getShape()[1]},
inputTy.getElementType()));
}
if (failed(checkNotNone(rewriter, op, bias))) {
bias = mhlo::getConstantOfShape(
bias = hlo::getConstantOfShape(
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0),
channelShape,
RankedTensorType::get({inputTy.getShape()[1]},
inputTy.getElementType()));
}
if (failed(checkNotNone(rewriter, op, runningVar))) {
runningVar = mhlo::getConstantOfShape(
runningVar = hlo::getConstantOfShape(
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1),
channelShape,
RankedTensorType::get({inputTy.getShape()[1]},
inputTy.getElementType()));
}
if (failed(checkNotNone(rewriter, op, runningMean))) {
runningMean = mhlo::getConstantOfShape(
runningMean = hlo::getConstantOfShape(
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0),
channelShape,
RankedTensorType::get({inputTy.getShape()[1]},
@ -983,10 +1008,11 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
Type outputTy = getTypeConverter()->convertType(op.getType());
Type batchMeanOrVarTy =
RankedTensorType::get(weightTy.getShape(), inputTy.getElementType());
auto batchNormTrainingResult = rewriter.create<mhlo::BatchNormTrainingOp>(
op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input,
weight, bias, rewriter.getF32FloatAttr(eps),
rewriter.getI64IntegerAttr(1));
auto batchNormTrainingResult =
rewriter.create<stablehlo::BatchNormTrainingOp>(
op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input,
weight, bias, rewriter.getF32FloatAttr(eps),
rewriter.getI64IntegerAttr(1));
rewriter.replaceOp(op, batchNormTrainingResult.getResult(0));
return success();
} else {
@ -995,10 +1021,11 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
inputTy.getShape().end()};
castShape[1] = weightTy.getShape()[0];
auto castTy = RankedTensorType::get(castShape, inputTy.getElementType());
// Feature counts must match among operands of mhlo::BatchNormInferenceOp.
// Feature counts must match among operands of
// stablehlo::BatchNormInferenceOp.
Value inputCasted =
rewriter.create<tensor::CastOp>(op.getLoc(), castTy, input);
Value output = rewriter.create<mhlo::BatchNormInferenceOp>(
Value output = rewriter.create<stablehlo::BatchNormInferenceOp>(
op.getLoc(), inputCasted.getType(), inputCasted, weight, bias,
runningMean, runningVar,
// 'epsilon' must satisfy constraint: 32-bit float attribute.
@ -1008,7 +1035,6 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
}
}
// AtenNativeLayerNormOp
template <>
LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
@ -1076,21 +1102,21 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
}
SmallVector<int64_t> inputFlattenShape{1, numFeatureDimSize,
numEmbeddingDimSize};
SmallVector<int64_t> meanOrVarMhloOutShape{numFeatureDimSize};
SmallVector<int64_t> meanOrVarStablehloOutShape{numFeatureDimSize};
auto mhloBatchNormOutTy =
auto stablehloBatchNormOutTy =
RankedTensorType::get(inputFlattenShape, inputTy.getElementType());
auto mhloBathNormOutMeanOrVarTy =
RankedTensorType::get(meanOrVarMhloOutShape, inputTy.getElementType());
auto stablehloBathNormOutMeanOrVarTy = RankedTensorType::get(
meanOrVarStablehloOutShape, inputTy.getElementType());
// Reshape input
auto mhloInput = rewriter.create<mhlo::DynamicReshapeOp>(
op->getLoc(), mhloBatchNormOutTy, input,
mhlo::getConstTensor(rewriter, op, llvm::makeArrayRef(inputFlattenShape),
{static_cast<int64_t>(inputFlattenShape.size())})
auto stablehloInput = rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), stablehloBatchNormOutTy, input,
hlo::getConstTensor(rewriter, op, llvm::ArrayRef(inputFlattenShape),
{static_cast<int64_t>(inputFlattenShape.size())})
.value());
// Generate "scale" and "offset" Value for mhlo.BatchNormTrainingOp.
// Generate "scale" and "offset" Value for stablehlo.BatchNormTrainingOp.
SmallVector<APFloat> zeroConstVec(
numFeatureDimSize, APFloat::getZero(inputTy.getElementType()
.cast<mlir::FloatType>()
@ -1103,16 +1129,18 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
auto oneOrZeroConstType =
RankedTensorType::get({numFeatureDimSize}, inputTy.getElementType());
Value scale = rewriter.create<mhlo::ConstantOp>(
Value scale = rewriter.create<stablehlo::ConstantOp>(
op->getLoc(), oneOrZeroConstType,
DenseElementsAttr::get(oneOrZeroConstType, oneConstVec));
Value offset = rewriter.create<mhlo::ConstantOp>(
Value offset = rewriter.create<stablehlo::ConstantOp>(
op->getLoc(), oneOrZeroConstType,
DenseElementsAttr::get(oneOrZeroConstType, zeroConstVec));
auto batchNormTrainingResult = rewriter.create<mhlo::BatchNormTrainingOp>(
op->getLoc(), mhloBatchNormOutTy, mhloBathNormOutMeanOrVarTy,
mhloBathNormOutMeanOrVarTy, mhloInput, scale, offset,
rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(1));
auto batchNormTrainingResult =
rewriter.create<stablehlo::BatchNormTrainingOp>(
op->getLoc(), stablehloBatchNormOutTy,
stablehloBathNormOutMeanOrVarTy, stablehloBathNormOutMeanOrVarTy,
stablehloInput, scale, offset, rewriter.getF32FloatAttr(eps),
rewriter.getI64IntegerAttr(1));
// Reshape back
auto outputTy =
@ -1120,36 +1148,35 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
auto outputMeanOrVarTy =
getTypeConverter()->convertType(op.getType(1)).cast<RankedTensorType>();
auto output = rewriter.create<mhlo::DynamicReshapeOp>(
auto output = rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), outputTy, batchNormTrainingResult.getResult(0),
mhlo::getConstTensor(rewriter, op, outputTy.getShape(),
{static_cast<int64_t>(outputTy.getShape().size())})
hlo::getConstTensor(rewriter, op, outputTy.getShape(),
{static_cast<int64_t>(outputTy.getShape().size())})
.value());
auto mean = rewriter.create<mhlo::DynamicReshapeOp>(
auto mean = rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(1),
mhlo::getConstTensor(
hlo::getConstTensor(
rewriter, op, outputMeanOrVarTy.getShape(),
{static_cast<int64_t>(outputMeanOrVarTy.getShape().size())})
.value());
auto var = rewriter.create<mhlo::DynamicReshapeOp>(
auto var = rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(2),
mhlo::getConstTensor(
hlo::getConstTensor(
rewriter, op, outputMeanOrVarTy.getShape(),
{static_cast<int64_t>(outputMeanOrVarTy.getShape().size())})
.value());
// Apply affine transform: output x weight + bias [element-wise]
auto bcastedWeight = mhlo::promoteAndBroadcast(rewriter, weight, outputTy);
auto bcastedBias = mhlo::promoteAndBroadcast(rewriter, bias, outputTy);
auto bcastedWeight = hlo::promoteAndBroadcast(rewriter, weight, outputTy);
auto bcastedBias = hlo::promoteAndBroadcast(rewriter, bias, outputTy);
auto outputMulWeight =
rewriter.create<mhlo::MulOp>(op->getLoc(), output, bcastedWeight);
auto finalOuput =
rewriter.create<mhlo::AddOp>(op->getLoc(), outputMulWeight, bcastedBias);
rewriter.create<stablehlo::MulOp>(op->getLoc(), output, bcastedWeight);
auto finalOuput = rewriter.create<stablehlo::AddOp>(
op->getLoc(), outputMulWeight, bcastedBias);
rewriter.replaceOp(op, {finalOuput, mean, var});
return success();
}
// AtenCatOp
template <>
LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
@ -1173,11 +1200,11 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
// Promote type
for (auto &v : builtinTensors) {
v = mhlo::promoteType(rewriter, v, outType);
v = hlo::promoteType(rewriter, v, outType);
}
size_t posDim = toPositiveDim(dim, outType.getRank());
rewriter.replaceOpWithNewOp<mhlo::ConcatenateOp>(
rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(
op, outType, ValueRange(builtinTensors), posDim);
return success();
}
@ -1225,7 +1252,8 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(
op, "this op should be folded as its `min` and `max` both are none");
} else if (failed(checkNotNone(rewriter, op, minValue))) {
maxValue = mhlo::scalarToMhloTensor(rewriter, op, maxValue, inputElemType);
maxValue =
hlo::scalarToStablehloTensor(rewriter, op, maxValue, inputElemType);
auto minInfo = getMinValueOfDtype(op, inputElemType, rewriter);
if (failed(minInfo)) {
return rewriter.notifyMatchFailure(
@ -1233,7 +1261,8 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
}
minValue = *minInfo;
} else if (failed(checkNotNone(rewriter, op, maxValue))) {
minValue = mhlo::scalarToMhloTensor(rewriter, op, minValue, inputElemType);
minValue =
hlo::scalarToStablehloTensor(rewriter, op, minValue, inputElemType);
auto maxInfo = getMaxValueOfDtype(op, inputElemType, rewriter);
if (failed(maxInfo)) {
return rewriter.notifyMatchFailure(
@ -1241,10 +1270,13 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
}
maxValue = *maxInfo;
} else {
minValue = mhlo::scalarToMhloTensor(rewriter, op, minValue, inputElemType);
maxValue = mhlo::scalarToMhloTensor(rewriter, op, maxValue, inputElemType);
minValue =
hlo::scalarToStablehloTensor(rewriter, op, minValue, inputElemType);
maxValue =
hlo::scalarToStablehloTensor(rewriter, op, maxValue, inputElemType);
}
rewriter.replaceOpWithNewOp<mhlo::ClampOp>(op, minValue, input, maxValue);
rewriter.replaceOpWithNewOp<stablehlo::ClampOp>(op, minValue, input,
maxValue);
return success();
}
@ -1266,24 +1298,27 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
op, "unimplemented: only int or float dtype supported");
}
Value start = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getStart(), dtype);
Value end = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getEnd(), dtype);
Value step = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getStep(), dtype);
Value start =
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getStart(), dtype);
Value end =
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getEnd(), dtype);
Value step =
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getStep(), dtype);
// Get length of the 1-d output tensor
Value subOut = rewriter.create<mhlo::SubtractOp>(loc, end, start);
Value divOut = rewriter.create<mhlo::DivOp>(loc, subOut, step);
Value subOut = rewriter.create<stablehlo::SubtractOp>(loc, end, start);
Value divOut = rewriter.create<stablehlo::DivOp>(loc, subOut, step);
Value resultLength = rewriter.create<mhlo::ReshapeOp>(
Value resultLength = rewriter.create<stablehlo::ReshapeOp>(
loc, RankedTensorType::get({1}, dtype), divOut);
if (dtype.isa<mlir::FloatType>()) {
resultLength = rewriter.create<mhlo::CeilOp>(loc, resultLength);
resultLength = rewriter.create<mhlo::ConvertOp>(
resultLength = rewriter.create<stablehlo::CeilOp>(loc, resultLength);
resultLength = rewriter.create<stablehlo::ConvertOp>(
loc, RankedTensorType::get({1}, rewriter.getI64Type()), resultLength);
}
Value window =
rewriter.create<mhlo::DynamicIotaOp>(loc, outType, resultLength, 0);
rewriter.create<stablehlo::DynamicIotaOp>(loc, outType, resultLength, 0);
DenseIntElementsAttr broadcastDimensions;
Value mulOut = rewriter.create<chlo::BroadcastMulOp>(loc, window, step,
broadcastDimensions);
@ -1298,9 +1333,8 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
Value input = adaptor.getSelf();
auto outType = this->getTypeConverter()
->convertType(op.getType())
.cast<TensorType>();
auto outType =
this->getTypeConverter()->convertType(op.getType()).cast<TensorType>();
if (!outType) {
return op.emitError("only tensor type is supported");
}
@ -1320,26 +1354,27 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
Value negHalf = chlo::getConstantLike(rewriter, loc, -0.5, input);
// Compute
Value kBeta0 = rewriter.create<mhlo::MulOp>(loc, outType, kAlpha, cstAlpha0);
Value kBeta = rewriter.create<mhlo::MulOp>(loc, outType, kBeta0, half);
Value erfArg =
rewriter.create<mhlo::MulOp>(loc, outType, kAlpha, adaptor.getSelf());
Value kBeta0 =
rewriter.create<stablehlo::MulOp>(loc, outType, kAlpha, cstAlpha0);
Value kBeta = rewriter.create<stablehlo::MulOp>(loc, outType, kBeta0, half);
Value erfArg = rewriter.create<stablehlo::MulOp>(loc, outType, kAlpha,
adaptor.getSelf());
Value erf = rewriter.create<mlir::chlo::ErfOp>(loc, outType, erfArg);
Value erfAdd = rewriter.create<mhlo::AddOp>(loc, outType, erf, one);
Value cdf = rewriter.create<mhlo::MulOp>(loc, outType, erfAdd, half);
Value inputSquared = rewriter.create<mhlo::MulOp>(
Value erfAdd = rewriter.create<stablehlo::AddOp>(loc, outType, erf, one);
Value cdf = rewriter.create<stablehlo::MulOp>(loc, outType, erfAdd, half);
Value inputSquared = rewriter.create<stablehlo::MulOp>(
loc, outType, adaptor.getSelf(), adaptor.getSelf());
Value negHalfInputSquared =
rewriter.create<mhlo::MulOp>(loc, outType, inputSquared, negHalf);
rewriter.create<stablehlo::MulOp>(loc, outType, inputSquared, negHalf);
Value expRes =
rewriter.create<mhlo::ExpOp>(loc, outType, negHalfInputSquared);
Value pdf = rewriter.create<mhlo::MulOp>(loc, outType, kBeta, expRes);
rewriter.create<stablehlo::ExpOp>(loc, outType, negHalfInputSquared);
Value pdf = rewriter.create<stablehlo::MulOp>(loc, outType, kBeta, expRes);
Value pdfTimesInput =
rewriter.create<mhlo::MulOp>(loc, outType, pdf, adaptor.getSelf());
rewriter.create<stablehlo::MulOp>(loc, outType, pdf, adaptor.getSelf());
Value pdfTimesInputAddCdf =
rewriter.create<mhlo::AddOp>(loc, outType, pdfTimesInput, cdf);
rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, outType, adaptor.getGradOutput(),
pdfTimesInputAddCdf);
rewriter.create<stablehlo::AddOp>(loc, outType, pdfTimesInput, cdf);
rewriter.replaceOpWithNewOp<stablehlo::MulOp>(
op, outType, adaptor.getGradOutput(), pdfTimesInputAddCdf);
return success();
}
@ -1366,9 +1401,9 @@ public:
};
} // namespace
void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToMhloOptions &options) {
ConversionTarget &target, const TorchToStablehloOptions &options) {
MLIRContext *context = patterns.getContext();
target.addIllegalOp<AtenTransposeIntOp>();
@ -1376,23 +1411,29 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
target.addIllegalOp<RuntimeAssertOp>();
patterns.add<ConvertRuntimeAssertOp>(typeConverter, context);
#define INSERT_UNARY_PATTERN(AtenOp, MhloOp) \
#define INSERT_UNARY_PATTERN(AtenOp, StablehloOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenUnaryOp<AtenOp, MhloOp>>(typeConverter, context)
INSERT_UNARY_PATTERN(AtenCloneOp, mhlo::CopyOp);
INSERT_UNARY_PATTERN(AtenNegOp, mhlo::NegOp);
INSERT_UNARY_PATTERN(AtenLogicalNotOp, mhlo::NotOp);
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, mhlo::NotOp);
patterns.add<ConvertAtenUnaryOp<AtenOp, StablehloOp>>(typeConverter, context)
INSERT_UNARY_PATTERN(AtenCloneOp, stablehlo::ConvertOp);
INSERT_UNARY_PATTERN(AtenNegOp, stablehlo::NegOp);
INSERT_UNARY_PATTERN(AtenLogicalNotOp, stablehlo::NotOp);
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, stablehlo::NotOp);
#undef INSERT_UNARY_PATTERN
#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, MhloOp) \
#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, StablehloOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, MhloOp>>(typeConverter, context)
INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, mhlo::LogOp);
INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, mhlo::ExpOp);
INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, mhlo::SqrtOp);
INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, mhlo::RsqrtOp);
INSERT_UNARY_FPONLY_PATTERN(AtenSigmoidOp, mhlo::LogisticOp);
patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, StablehloOp>>(typeConverter, \
context)
INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, stablehlo::LogOp);
INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, stablehlo::ExpOp);
INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, stablehlo::SqrtOp);
INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, stablehlo::RsqrtOp);
INSERT_UNARY_FPONLY_PATTERN(AtenSigmoidOp, stablehlo::LogisticOp);
INSERT_UNARY_FPONLY_PATTERN(AtenTanhOp, stablehlo::TanhOp);
INSERT_UNARY_FPONLY_PATTERN(AtenSinOp, stablehlo::SineOp);
INSERT_UNARY_FPONLY_PATTERN(AtenCosOp, stablehlo::CosineOp);
INSERT_UNARY_FPONLY_PATTERN(AtenCeilOp, stablehlo::CeilOp);
INSERT_UNARY_FPONLY_PATTERN(AtenFloorOp, stablehlo::FloorOp);
#undef INSERT_UNARY_FPONLY_PATTERN
#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
@ -1459,9 +1500,9 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
INSERT_ATENOP_PATTERN(AtenPermuteOp);
INSERT_ATENOP_PATTERN(AtenTanhOp);
INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
INSERT_ATENOP_PATTERN(AtenReciprocalOp);
INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
INSERT_ATENOP_PATTERN(AtenContiguousOp);
@ -1482,10 +1523,10 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
#undef INSERT_ATENOP_PATTERN
#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, MhloOp) \
#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenBinaryBroadcastOp<AtenOp, MhloOp>>(typeConverter, \
context)
patterns.add<ConvertAtenBinaryBroadcastOp<AtenOp, StablehloOp>>( \
typeConverter, context)
INSERT_BINARY_BROADCAST_PATTERN(AtenMaximumOp, chlo::BroadcastMaxOp);
INSERT_BINARY_BROADCAST_PATTERN(AtenMinimumOp, chlo::BroadcastMinOp);
INSERT_BINARY_BROADCAST_PATTERN(Aten__And__TensorOp, chlo::BroadcastAndOp);

View File

@ -0,0 +1,29 @@
# Declares the TorchMLIRTorchToStablehlo conversion library, built from the
# Torch -> StableHLO lowering sources listed below.
add_mlir_conversion_library(TorchMLIRTorchToStablehlo
# Source files compiled into this library.
TorchToStablehlo.cpp
StablehloLegalizeUtils.cpp
Basic.cpp
Gather.cpp
Linear.cpp
ViewLike.cpp
Reduction.cpp
Pooling.cpp
# Public header directory associated with this library.
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToStablehlo
# Targets that must be built before this library.
# NOTE(review): presumably the generated conversion-pass declarations - confirm.
DEPENDS
TorchMLIRConversionPassIncGen
# LLVM components to link against.
LINK_COMPONENTS
Core
# Libraries linked with PUBLIC visibility (propagated to dependents).
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRBufferTransforms
StablehloOps
TorchMLIRTorchDialect
TorchMLIRConversionUtils
)
# Project-defined helper applied to the target declared above.
# NOTE(review): presumably configures the target's include directories - confirm
# against the project's top-level CMake helpers.
torch_mlir_target_includes(TorchMLIRTorchToStablehlo)

View File

@ -7,14 +7,15 @@
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "../PassDetail.h"
#include "./MhloLegalizeUtils.h"
#include "./PopulatePatterns.h"
#include "mhlo/IR/hlo_ops.h"
#include "PopulatePatterns.h"
#include "StablehloLegalizeUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@ -24,7 +25,7 @@
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_mhlo;
using namespace mlir::torch::torch_to_stablehlo;
namespace {
Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
@ -69,7 +70,7 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
SmallVector<int64_t, 4> startIndexMap(1, axis);
// indexVecDim
int64_t indexVecDim = indicesRank;
auto dimsAttr = mhlo::GatherDimensionNumbersAttr::get(
auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get(
rewriter.getContext(),
/*offsetDims=*/offsetDims,
/*collapsedSliceDims=*/collapsedSliceDims,
@ -91,17 +92,18 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
auto outputTy =
RankedTensorType::get(outputShape, inputRankTy.getElementType());
return rewriter
.create<mhlo::DynamicGatherOp>(loc, outputTy, input, indices,
sliceSizesTensor, dimsAttr)
.create<stablehlo::DynamicGatherOp>(loc, outputTy, input, indices,
sliceSizesTensor, dimsAttr)
.getResult();
}
} // namespace
// Ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html
// Ref:
// https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html
// padding_idx (int, optional)
// If specified, the entries at padding_idx do not contribute to the gradient;
// therefore, the embedding vector at padding_idx is not updated during training,
// i.e. it remains as a fixed “pad”.
// If specified, the entries at padding_idx do not contribute to the
// gradient; therefore, the embedding vector at padding_idx is not updated
// during training, i.e. it remains as a fixed “pad”.
// scale_grad_by_freq (boolean, optional)
// If given, this will scale gradients by the inverse of frequency of the
// words in the mini-batch. Default False.
@ -139,7 +141,7 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
Value output = gatherTensorAlongSingleAxis(
rewriter, op, weight, adaptor.getIndices(), 0, options.dimSizeIndexBits);
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(
op, getTypeConverter()->convertType(op.getType()), output);
return success();
@ -161,7 +163,7 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
Value output = gatherTensorAlongSingleAxis(
rewriter, op, self, adaptor.getIndex(), dim, options.dimSizeIndexBits);
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(
op, getTypeConverter()->convertType(op.getType()), output);
return success();
@ -200,7 +202,7 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
auto options = getOptions();
auto indexShapeInfo =
mhlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits);
hlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits);
if (failed(indexShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dim sizes of `index` param");
@ -223,15 +225,15 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
SmallVector<Value> toConcat;
for (int64_t i = 0; i < inputType.getRank(); ++i) {
if (i == dim) {
toConcat.push_back(rewriter.create<mhlo::DynamicReshapeOp>(
toConcat.push_back(rewriter.create<stablehlo::DynamicReshapeOp>(
loc, toConcatIndexType, index, toConcatIndexShape));
} else {
toConcat.push_back(rewriter.create<mhlo::DynamicIotaOp>(
toConcat.push_back(rewriter.create<stablehlo::DynamicIotaOp>(
loc, toConcatIndexType, toConcatIndexShape,
rewriter.getI64IntegerAttr(i)));
}
}
auto gatherIndicies = rewriter.create<mhlo::ConcatenateOp>(
auto gatherIndicies = rewriter.create<stablehlo::ConcatenateOp>(
loc, toConcat, static_cast<uint64_t>(inputType.getRank()));
SmallVector<int64_t> sliceSizes(inputType.getRank(), 1);
@ -243,22 +245,22 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
startIndexMap.push_back(i);
}
auto dimsAttr = mhlo::GatherDimensionNumbersAttr::get(
auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get(
rewriter.getContext(),
/*offsetDims=*/{},
/*collapsedSliceDims=*/collapsedDims,
/*startIndexMap=*/startIndexMap,
/*indexVecDim=*/indexVecDim);
rewriter.replaceOpWithNewOp<mhlo::GatherOp>(
rewriter.replaceOpWithNewOp<stablehlo::GatherOp>(
op, input, gatherIndicies, dimsAttr,
rewriter.getI64TensorAttr(sliceSizes));
return success();
}
void mlir::torch::torch_to_mhlo::populateGatherOpPatternsAndLegality(
void mlir::torch::torch_to_stablehlo::populateGatherOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToMhloOptions &options) {
ConversionTarget &target, const TorchToStablehloOptions &options) {
MLIRContext *context = patterns.getContext();
#define INSERT_ATENOP_PATTERN(AtenOp) \

View File

@ -7,15 +7,16 @@
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "../PassDetail.h"
#include "./MhloLegalizeUtils.h"
#include "./PopulatePatterns.h"
#include "mhlo/IR/hlo_ops.h"
#include "PopulatePatterns.h"
#include "StablehloLegalizeUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@ -25,7 +26,7 @@
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_mhlo;
using namespace mlir::torch::torch_to_stablehlo;
namespace {
Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
@ -33,7 +34,7 @@ Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
ArrayRef<int64_t> broadcastDims) {
auto tensorTy = tensor.getType().dyn_cast<RankedTensorType>();
auto loc = op->getLoc();
Value mhloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
Value stablehloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
RankedTensorType outTy =
RankedTensorType::get(shape, tensorTy.getElementType());
@ -43,8 +44,8 @@ Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
rewriter.getIntegerType(64));
auto broadcastAttr = DenseIntElementsAttr::get(attrTy, broadcastDims);
auto broadcast = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
loc, outTy, tensor, mhloShape, broadcastAttr);
auto broadcast = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
loc, outTy, tensor, stablehloShape, broadcastAttr);
return broadcast;
}
@ -52,7 +53,7 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input,
ArrayRef<int64_t> inpTransDims) {
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
auto rank = inputTy.getRank();
auto transDims = mhlo::toPositiveDims(inpTransDims, rank);
auto transDims = hlo::toPositiveDims(inpTransDims, rank);
auto inpShape = inputTy.getShape();
std::vector<int64_t> newShape;
newShape.reserve(rank);
@ -66,8 +67,8 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input,
auto permuteAttr = DenseIntElementsAttr::get(attrTy, transDims);
auto outTy = RankedTensorType::get(newShape, inputTy.getElementType());
auto result = rewriter.create<mhlo::TransposeOp>(op->getLoc(), outTy, input,
permuteAttr);
auto result = rewriter.create<stablehlo::TransposeOp>(op->getLoc(), outTy,
input, permuteAttr);
return result.getResult();
}
@ -119,10 +120,12 @@ RankedTensorType castContractingDim(PatternRewriter &rewriter, Operation *op,
}
// set result dimensions
if (lhsResultDim < static_cast<int64_t>(lhsShape.size()) && lhsResultDim >= 0) {
if (lhsResultDim < static_cast<int64_t>(lhsShape.size()) &&
lhsResultDim >= 0) {
outShape.push_back(lhsShape[lhsResultDim]);
}
if (rhsResultDim < static_cast<int64_t>(rhsShape.size()) && rhsResultDim >= 0) {
if (rhsResultDim < static_cast<int64_t>(rhsShape.size()) &&
rhsResultDim >= 0) {
outShape.push_back(rhsShape[rhsResultDim]);
}
return RankedTensorType::get(outShape, lhsTy.getElementType());
@ -151,10 +154,10 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
std::vector<int64_t> newShape(rhsShape.begin(),
rhsShape.begin() + leadingRank);
newShape.insert(newShape.end(), lhsShape.begin(), lhsShape.end());
auto newDimSizes = *mhlo::getDimSizesOfTensor(
rewriter, op, rhs, leadingDims, dimSizeIndexBits);
auto newDimSizes = *hlo::getDimSizesOfTensor(rewriter, op, rhs, leadingDims,
dimSizeIndexBits);
auto lhsDimSizes =
*mhlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits);
*hlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits);
newDimSizes.insert(newDimSizes.end(), lhsDimSizes.begin(),
lhsDimSizes.end());
lhs = getBroadcastTensor(rewriter, op, lhs, newShape, newDimSizes,
@ -163,10 +166,10 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
std::vector<int64_t> newShape(lhsShape.begin(),
lhsShape.begin() + leadingRank);
newShape.insert(newShape.end(), rhsShape.begin(), rhsShape.end());
auto newDimSizes = *mhlo::getDimSizesOfTensor(
rewriter, op, lhs, leadingDims, dimSizeIndexBits);
auto newDimSizes = *hlo::getDimSizesOfTensor(rewriter, op, lhs, leadingDims,
dimSizeIndexBits);
auto rhsDimSizes =
*mhlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits);
*hlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits);
newDimSizes.insert(newDimSizes.end(), rhsDimSizes.begin(),
rhsDimSizes.end());
rhs = getBroadcastTensor(rewriter, op, rhs, newShape, newDimSizes,
@ -218,8 +221,8 @@ public:
if (lhsRank <= 2 && rhsRank <= 2) {
auto tensorType =
ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType());
output = rewriter.create<mhlo::DotOp>(op->getLoc(), tensorType, lhs, rhs,
nullptr);
output = rewriter.create<stablehlo::DotOp>(op->getLoc(), tensorType, lhs,
rhs, nullptr);
return success();
}
@ -253,8 +256,8 @@ public:
lhsContractingDim = nBatchDims;
}
mhlo::DotDimensionNumbersAttr dotDimensionNumbers =
mhlo::DotDimensionNumbersAttr::get(
stablehlo::DotDimensionNumbersAttr dotDimensionNumbers =
stablehlo::DotDimensionNumbersAttr::get(
rewriter.getContext(),
/*lhsBatchingDimensions=*/batchDims,
/*rhsBatchingDimensions=*/batchDims,
@ -264,8 +267,8 @@ public:
castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim,
lhsContractingDim, rhsContractingDim);
output = rewriter
.create<mhlo::DotGeneralOp>(op->getLoc(), outTy, lhs, rhs,
dotDimensionNumbers, nullptr)
.create<stablehlo::DotGeneralOp>(op->getLoc(), outTy, lhs, rhs,
dotDimensionNumbers, nullptr)
.getResult();
return success();
}
@ -312,7 +315,7 @@ public:
if (!lhsTy || !rhsTy)
return op.emitError(
"only ranked tensor types are supported in MHLO matmul");
"only ranked tensor types are supported in StableHLO matmul");
return success();
}
@ -335,7 +338,7 @@ public:
if (!lhsTy || !rhsTy)
return op.emitError(
"only ranked tensor types are supported in MHLO matmul");
"only ranked tensor types are supported in StableHLO matmul");
auto lhsRank = lhsTy.getRank();
auto rhsRank = rhsTy.getRank();
@ -371,7 +374,7 @@ public:
if (!lhsTy || !rhsTy)
return op.emitError(
"only ranked tensor types are supported in MHLO matmul");
"only ranked tensor types are supported in StableHLO matmul");
auto lhsRank = lhsTy.getRank();
auto rhsRank = rhsTy.getRank();
@ -398,10 +401,10 @@ public:
auto bias = adaptor.getBias();
auto biasTy = bias.getType();
// MHLO does not mandate that elementwise op tensors need to be ranked.
// StableHLO does not mandate that elementwise op tensors need to be ranked.
if (!biasTy.template isa<Torch::NoneType>() &&
!biasTy.template isa<RankedTensorType>())
return op.emitError("only ranked tensor types are supported in MHLO "
return op.emitError("only ranked tensor types are supported in StableHLO "
"matmul for bias tensor");
// weight.T
@ -427,14 +430,14 @@ public:
auto outTy =
castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim,
lhsContractingDim, rhsContractingDim);
mhlo::DotDimensionNumbersAttr dotDimensionNumbers =
mhlo::DotDimensionNumbersAttr::get(
stablehlo::DotDimensionNumbersAttr dotDimensionNumbers =
stablehlo::DotDimensionNumbersAttr::get(
rewriter.getContext(),
/*lhsBatchingDimensions=*/batchDims,
/*rhsBatchingDimensions=*/batchDims,
/*lhsContractingDimensions=*/{lhsContractingDim},
/*rhsContractingDimensions=*/{rhsContractingDim});
Value matmulOutput = rewriter.create<mhlo::DotGeneralOp>(
Value matmulOutput = rewriter.create<stablehlo::DotGeneralOp>(
op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr);
Value matmulPlusBias = matmulOutput;
@ -464,7 +467,7 @@ public:
auto weightElemTy = weightTy.getElementType();
auto rank = weightTy.getRank();
const auto &options = getOptions();
SmallVector<Value> weightShapeVec = *mhlo::getDimSizesOfTensor(
SmallVector<Value> weightShapeVec = *hlo::getDimSizesOfTensor(
rewriter, op, weight, options.dimSizeIndexBits);
auto weightShape = weightTy.getShape();
SmallVector<int64_t> weightShapeInt(rank);
@ -488,7 +491,7 @@ public:
}
Value weightShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), weightShapeVec);
weight = rewriter.create<mhlo::DynamicReshapeOp>(
weight = rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), RankedTensorType::get(weightShapeInt, weightElemTy),
weight, weightShapeTensor);
@ -497,7 +500,7 @@ public:
for (int64_t i = 0; i <= rank; i++)
transposeDims[i] = i;
std::swap(transposeDims[1], transposeDims[0]);
weight = rewriter.create<mhlo::TransposeOp>(
weight = rewriter.create<stablehlo::TransposeOp>(
op->getLoc(), weight, rewriter.getI64TensorAttr(transposeDims));
// 3. [IC//G, G, OC, H, W, ...] => [IC//G, G*OC, H, W, ...]
@ -509,7 +512,7 @@ public:
weightShapeVec[1] = OCMulGValue;
weightShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), weightShapeVec);
weight = rewriter.create<mhlo::DynamicReshapeOp>(
weight = rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), RankedTensorType::get(weightShapeInt, weightElemTy),
weight, weightShapeTensor);
return weight;
@ -544,25 +547,27 @@ public:
}
// Prepare for transposed convolution
SmallVector<int64_t> mhloStrideVec(nSpatialDims, 1);
DenseIntElementsAttr mhloStride = rewriter.getI64TensorAttr(mhloStrideVec);
SmallVector<int64_t> mhloPaddingVec(nSpatialDims * 2, 0);
SmallVector<int64_t> stablehloStrideVec(nSpatialDims, 1);
DenseIntElementsAttr stablehloStride =
rewriter.getI64TensorAttr(stablehloStrideVec);
SmallVector<int64_t> stablehloPaddingVec(nSpatialDims * 2, 0);
for (int i = 0; i < nSpatialDims; ++i) {
int64_t padInt = dilation[i] * (weightShape[i + 2] - 1) - padding[i];
mhloPaddingVec[i * 2] = padInt;
mhloPaddingVec[i * 2 + 1] = padInt;
stablehloPaddingVec[i * 2] = padInt;
stablehloPaddingVec[i * 2 + 1] = padInt;
}
DenseIntElementsAttr mhloPadding = DenseIntElementsAttr::get(
DenseIntElementsAttr stablehloPadding = DenseIntElementsAttr::get(
RankedTensorType::get({nSpatialDims, 2}, rewriter.getI64Type()),
mhloPaddingVec);
SmallVector<int64_t> mhloLhsDilationVec(nSpatialDims);
std::copy(stride.begin(), stride.end(), mhloLhsDilationVec.begin());
DenseIntElementsAttr mhloLhsDilation =
rewriter.getI64TensorAttr(mhloLhsDilationVec);
SmallVector<int64_t> mhloRhsDilationVec(nSpatialDims);
std::copy(dilation.begin(), dilation.end(), mhloRhsDilationVec.begin());
DenseIntElementsAttr mhloRhsDilation =
rewriter.getI64TensorAttr(mhloRhsDilationVec);
stablehloPaddingVec);
SmallVector<int64_t> stablehloLhsDilationVec(nSpatialDims);
std::copy(stride.begin(), stride.end(), stablehloLhsDilationVec.begin());
DenseIntElementsAttr stablehloLhsDilation =
rewriter.getI64TensorAttr(stablehloLhsDilationVec);
SmallVector<int64_t> stablehloRhsDilationVec(nSpatialDims);
std::copy(dilation.begin(), dilation.end(),
stablehloRhsDilationVec.begin());
DenseIntElementsAttr stablehloRhsDilation =
rewriter.getI64TensorAttr(stablehloRhsDilationVec);
DenseElementsAttr windowReversal;
ArrayAttr precisionConfig;
@ -571,8 +576,8 @@ public:
for (int i = 0; i < nSpatialDims; ++i) {
spatialDims.push_back(i + 2);
}
mhlo::ConvDimensionNumbersAttr dimensionNumbers =
mhlo::ConvDimensionNumbersAttr::get(
stablehlo::ConvDimensionNumbersAttr dimensionNumbers =
stablehlo::ConvDimensionNumbersAttr::get(
/*context=*/rewriter.getContext(), /*inputBatchDimension=*/0,
/*inputFeatureDimension=*/1,
/*inputSpatialDimensions=*/spatialDims,
@ -583,17 +588,18 @@ public:
/*outputSpatialDimensions=*/spatialDims);
// Reverse and transpose weight
weight = rewriter.create<mhlo::ReverseOp>(
weight = rewriter.create<stablehlo::ReverseOp>(
op->getLoc(), weight, rewriter.getI64TensorAttr(spatialDims));
if (groups != 1) {
weight = reshapeConvWeight(rewriter, op, weight, groups);
}
// Create transposed convolution
auto transposedConvOp = rewriter.create<mhlo::ConvolutionOp>(
op->getLoc(), convOutTy, input, weight, mhloStride, mhloPadding,
mhloLhsDilation, mhloRhsDilation, windowReversal, dimensionNumbers,
static_cast<uint64_t>(groups), 1, precisionConfig);
auto transposedConvOp = rewriter.create<stablehlo::ConvolutionOp>(
op->getLoc(), convOutTy, input, weight, stablehloStride,
stablehloPadding, stablehloLhsDilation, stablehloRhsDilation,
windowReversal, dimensionNumbers, static_cast<uint64_t>(groups), 1,
precisionConfig);
// Handle output padding
if (!needHandleOutputPadding) {
@ -605,8 +611,8 @@ public:
std::copy(outputPadding.begin(), outputPadding.end(),
edgePaddingHighVec.begin() + 2);
Value paddingValue =
mhlo::getConstTensor<float>(rewriter, op, {0.0}, {}).value();
paddingValue = mhlo::promoteType(rewriter, paddingValue, inputTy);
hlo::getConstTensor<float>(rewriter, op, {0.0}, {}).value();
paddingValue = hlo::promoteType(rewriter, paddingValue, inputTy);
mlir::DenseIntElementsAttr edgePaddingLow =
rewriter.getI64VectorAttr(edgePaddingLowVec);
mlir::DenseIntElementsAttr edgePaddingHigh =
@ -614,7 +620,7 @@ public:
mlir::DenseIntElementsAttr interiorPadding =
rewriter.getI64VectorAttr(interiorPaddingVec);
auto paddedOutput = rewriter.create<mhlo::PadOp>(
auto paddedOutput = rewriter.create<stablehlo::PadOp>(
op->getLoc(), outType, transposedConvOp, paddingValue, edgePaddingLow,
edgePaddingHigh, interiorPadding);
@ -628,22 +634,22 @@ public:
ArrayRef<int64_t> dilation, int64_t groups) const {
int64_t nDims = outType.getRank();
// Get mhlo::ConvolutionOp attributes
DenseIntElementsAttr mhloWindowStride = DenseIntElementsAttr::get(
// Get stablehlo::ConvolutionOp attributes
DenseIntElementsAttr stablehloWindowStride = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<long int>(stride.size())},
rewriter.getI64Type()),
stride);
std::vector<int64_t> mhloPaddingVec;
std::vector<int64_t> stablehloPaddingVec;
for (size_t i = 0; i < padding.size(); i++) {
mhloPaddingVec.emplace_back(padding[i]);
mhloPaddingVec.emplace_back(padding[i]);
stablehloPaddingVec.emplace_back(padding[i]);
stablehloPaddingVec.emplace_back(padding[i]);
}
DenseIntElementsAttr mhloPadding = DenseIntElementsAttr::get(
DenseIntElementsAttr stablehloPadding = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<long int>(padding.size()), static_cast<long int>(2)},
rewriter.getI64Type()),
mhloPaddingVec);
DenseIntElementsAttr mhloRhsDilation = DenseIntElementsAttr::get(
stablehloPaddingVec);
DenseIntElementsAttr stablehloRhsDilation = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<long int>(dilation.size())},
rewriter.getI64Type()),
dilation);
@ -651,8 +657,8 @@ public:
for (int64_t i = 2; i < nDims; i++) {
spatialDimensions.emplace_back(i);
}
mhlo::ConvDimensionNumbersAttr dimensionNumbers =
mhlo::ConvDimensionNumbersAttr::get(
stablehlo::ConvDimensionNumbersAttr dimensionNumbers =
stablehlo::ConvDimensionNumbersAttr::get(
/*context=*/rewriter.getContext(), /*inputBatchDimension=*/0,
/*inputFeatureDimension=*/1,
/*inputSpatialDimensions=*/spatialDimensions,
@ -662,17 +668,18 @@ public:
/*outputBatchDimension=*/0, /*outputFeatureDimension=*/1,
/*outputSpatialDimensions=*/spatialDimensions);
// mhlo::ConvolutionOp's optional attributes, leave them as default
DenseIntElementsAttr mhloLhsDilation;
// stablehlo::ConvolutionOp's optional attributes, leave them as default
DenseIntElementsAttr stablehloLhsDilation;
DenseElementsAttr windowReversal;
ArrayAttr precisionConfig;
auto mhloConvOp = rewriter.create<mhlo::ConvolutionOp>(
op->getLoc(), outType, input, weight, mhloWindowStride, mhloPadding,
mhloLhsDilation, mhloRhsDilation, windowReversal, dimensionNumbers,
static_cast<uint64_t>(groups), 1, precisionConfig);
auto stablehloConvOp = rewriter.create<stablehlo::ConvolutionOp>(
op->getLoc(), outType, input, weight, stablehloWindowStride,
stablehloPadding, stablehloLhsDilation, stablehloRhsDilation,
windowReversal, dimensionNumbers, static_cast<uint64_t>(groups), 1,
precisionConfig);
return mhloConvOp.getResult();
return stablehloConvOp.getResult();
}
LogicalResult
@ -754,21 +761,22 @@ public:
}
}
Value mhloConvResult;
Value stablehloConvResult;
if (transposed) {
mhloConvResult = convertTransposedConv(
stablehloConvResult = convertTransposedConv(
op, rewriter, outTy, input, weight, stride, padding, dilation,
outputPadding, groups, needHandleOutputPadding);
} else {
mhloConvResult = convertNormalConv(op, rewriter, outTy, input, weight,
stride, padding, dilation, groups);
stablehloConvResult =
convertNormalConv(op, rewriter, outTy, input, weight, stride, padding,
dilation, groups);
}
auto bias = adaptor.getBias();
// No bias provided
if (failed(checkNotNone(rewriter, op, op.getBias()))) {
rewriter.replaceOp(op, mhloConvResult);
rewriter.replaceOp(op, stablehloConvResult);
return success();
}
@ -790,21 +798,21 @@ public:
llvm::to_vector<4>(llvm::seq<int64_t>(-nSpatialDims, 0));
const auto &options = getOptions();
bias = *mhlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims,
options.dimSizeIndexBits);
bias = mhlo::promoteType(rewriter, bias, outTy);
bias = *hlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims,
options.dimSizeIndexBits);
bias = hlo::promoteType(rewriter, bias, outTy);
DenseIntElementsAttr bcastDimensions;
rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(op, outTy, mhloConvResult,
bias, bcastDimensions);
rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
op, outTy, stablehloConvResult, bias, bcastDimensions);
return success();
}
};
} // namespace
void mlir::torch::torch_to_mhlo::populateLinearOpPatternsAndLegality(
void mlir::torch::torch_to_stablehlo::populateLinearOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToMhloOptions &options) {
ConversionTarget &target, const TorchToStablehloOptions &options) {
MLIRContext *context = patterns.getContext();
#define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \

View File

@ -7,15 +7,16 @@
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "../PassDetail.h"
#include "./MhloLegalizeUtils.h"
#include "./PopulatePatterns.h"
#include "mhlo/IR/hlo_ops.h"
#include "PopulatePatterns.h"
#include "StablehloLegalizeUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@ -28,26 +29,26 @@
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_mhlo;
using namespace mlir::torch::torch_to_stablehlo;
static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
PatternRewriter &rewriter) {
auto constType = RankedTensorType::get({}, elementTy);
// Avg pooling
if (isa<AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp>(op)) {
if (isa<AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp, AtenCumsumOp>(op)) {
if (elementTy.isa<mlir::FloatType>()) {
auto constAttr = DenseElementsAttr::get(
constType, {APFloat::getZero(
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
/*negative=*/false)});
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
constAttr);
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
} else if (elementTy.isa<mlir::IntegerType>() &&
elementTy.getIntOrFloatBitWidth() != 8) {
auto constAttr = DenseElementsAttr::get(
constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
constAttr);
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
}
}
@ -58,15 +59,15 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
constType, {APFloat::getLargest(
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
/*negative=*/true)});
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
constAttr);
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
} else if (elementTy.isa<mlir::IntegerType>() &&
elementTy.getIntOrFloatBitWidth() != 8) {
auto constAttr = DenseElementsAttr::get(
constType,
{APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())});
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
constAttr);
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
}
}
op->emitError("unimplemented lowering in AtenPoolingOp");
@ -116,42 +117,43 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dOp>::matchAndRewrite(
// prepend 1 to kernelSize, stride, dilation until they are of same rank as
// input
SmallVector<int64_t> mhloStride(inputRank, 1);
SmallVector<int64_t> mhloDilation(inputRank, 1);
SmallVector<int64_t> mhloKernelSize(inputRank, 1);
SmallVector<int64_t> mhloPadding(inputRank * 2, 0);
SmallVector<int64_t> stablehloStride(inputRank, 1);
SmallVector<int64_t> stablehloDilation(inputRank, 1);
SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
std::copy(dilation.begin(), dilation.end(),
mhloDilation.begin() + inputRank - 2);
std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2);
stablehloDilation.begin() + inputRank - 2);
std::copy(stride.begin(), stride.end(),
stablehloStride.begin() + inputRank - 2);
std::copy(kernelSize.begin(), kernelSize.end(),
mhloKernelSize.begin() + inputRank - 2);
stablehloKernelSize.begin() + inputRank - 2);
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
mhloPadding[mhloPadding.size() - 4] = padding[0];
mhloPadding[mhloPadding.size() - 3] = padding[0];
mhloPadding[mhloPadding.size() - 2] = padding[1];
mhloPadding[mhloPadding.size() - 1] = padding[1];
stablehloPadding[stablehloPadding.size() - 4] = padding[0];
stablehloPadding[stablehloPadding.size() - 3] = padding[0];
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())},
RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
rewriter.getI64Type()),
mhloKernelSize);
stablehloKernelSize);
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(mhloStride.size())},
RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
rewriter.getI64Type()),
mhloStride);
stablehloStride);
DenseIntElementsAttr baseDilations;
DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())},
RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
rewriter.getI64Type()),
mhloDilation);
stablehloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
rewriter.getI64Type()),
mhloPadding);
auto reduceWindowOp = rewriter.create<mhlo::ReduceWindowOp>(
stablehloPadding);
auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
baseDilations, windowDilations, pad);
@ -168,8 +170,8 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dOp>::matchAndRewrite(
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value result =
rewriter.create<mhlo::MaxOp>(op->getLoc(), *firstArg, *secondArg);
rewriter.create<mhlo::ReturnOp>(op->getLoc(), result);
rewriter.create<stablehlo::MaxOp>(op->getLoc(), *firstArg, *secondArg);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), result);
}
rewriter.replaceOp(op, reduceWindowOp.getResults());
@ -221,45 +223,46 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
// prepend 1 to kernelSize, stride, dilation until they are of same rank as
// input
SmallVector<int64_t> mhloStride(inputRank, 1);
SmallVector<int64_t> mhloDilation(inputRank, 1);
SmallVector<int64_t> mhloKernelSize(inputRank, 1);
SmallVector<int64_t> mhloPadding(inputRank * 2, 0);
SmallVector<int64_t> stablehloStride(inputRank, 1);
SmallVector<int64_t> stablehloDilation(inputRank, 1);
SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
std::copy(dilation.begin(), dilation.end(),
mhloDilation.begin() + inputRank - 2);
std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2);
stablehloDilation.begin() + inputRank - 2);
std::copy(stride.begin(), stride.end(),
stablehloStride.begin() + inputRank - 2);
std::copy(kernelSize.begin(), kernelSize.end(),
mhloKernelSize.begin() + inputRank - 2);
stablehloKernelSize.begin() + inputRank - 2);
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
mhloPadding[mhloPadding.size() - 4] = padding[0];
mhloPadding[mhloPadding.size() - 3] = padding[0];
mhloPadding[mhloPadding.size() - 2] = padding[1];
mhloPadding[mhloPadding.size() - 1] = padding[1];
stablehloPadding[stablehloPadding.size() - 4] = padding[0];
stablehloPadding[stablehloPadding.size() - 3] = padding[0];
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())},
RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
rewriter.getI64Type()),
mhloKernelSize);
stablehloKernelSize);
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(mhloStride.size())},
RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
rewriter.getI64Type()),
mhloStride);
stablehloStride);
DenseIntElementsAttr baseDilations;
DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())},
RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
rewriter.getI64Type()),
mhloDilation);
stablehloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
rewriter.getI64Type()),
mhloPadding);
stablehloPadding);
const auto &options = getOptions();
auto inputShapeInfo =
mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
if (failed(inputShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
@ -289,7 +292,7 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
auto initIndexTensor =
rewriter
.create<mhlo::DynamicIotaOp>(
.create<stablehlo::DynamicIotaOp>(
op->getLoc(),
RankedTensorType::get(initIndexShapeForType,
rewriter.getI64Type()),
@ -298,15 +301,15 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
auto indexTensor =
rewriter
.create<mhlo::DynamicReshapeOp>(
.create<stablehlo::DynamicReshapeOp>(
op->getLoc(),
RankedTensorType::get(inputShape, rewriter.getI64Type()),
initIndexTensor, inputShapeTensor)
.getResult();
Value initIdx = mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
Value initIdx = hlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
auto reduceWindowOp = rewriter.create<mhlo::ReduceWindowOp>(
auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
op->getLoc(), mlir::TypeRange{outValTy, outIdxTy},
mlir::ValueRange{input, indexTensor}, mlir::ValueRange{initVal, initIdx},
windowDimensions, windowStrides, baseDilations, windowDilations, pad);
@ -326,43 +329,43 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
auto *secondValArg = std::next(firstIdxArg);
auto *secondIdxArg = std::next(secondValArg);
mhlo::ComparisonTypeAttr compareTypeAttr;
stablehlo::ComparisonTypeAttr compareTypeAttr;
if (inputTy.getElementType().isa<mlir::FloatType>()) {
compareTypeAttr = mhlo::ComparisonTypeAttr::get(
rewriter.getContext(), mhlo::ComparisonType::FLOAT);
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
rewriter.getContext(), stablehlo::ComparisonType::FLOAT);
} else if (inputTy.getElementType().isa<mlir::IntegerType>()) {
compareTypeAttr = mhlo::ComparisonTypeAttr::get(
rewriter.getContext(), mhlo::ComparisonType::SIGNED);
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
}
mhlo::ComparisonDirectionAttr compareGeDirectionAttr =
mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
mhlo::ComparisonDirection::GE);
mhlo::ComparisonDirectionAttr compareEqDirectionAttr =
mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
mhlo::ComparisonDirection::EQ);
stablehlo::ComparisonDirectionAttr compareGeDirectionAttr =
stablehlo::ComparisonDirectionAttr::get(
rewriter.getContext(), stablehlo::ComparisonDirection::GE);
stablehlo::ComparisonDirectionAttr compareEqDirectionAttr =
stablehlo::ComparisonDirectionAttr::get(
rewriter.getContext(), stablehlo::ComparisonDirection::EQ);
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value compareGeResult = rewriter.create<mhlo::CompareOp>(
Value compareGeResult = rewriter.create<stablehlo::CompareOp>(
op->getLoc(), compareResultType, *firstValArg, *secondValArg,
compareGeDirectionAttr, compareTypeAttr);
Value retValResult = rewriter.create<mhlo::SelectOp>(
Value retValResult = rewriter.create<stablehlo::SelectOp>(
op->getLoc(), compareGeResult, *firstValArg, *secondValArg);
// Get smaller index if compared values are equal.
Value compareEqResult = rewriter.create<mhlo::CompareOp>(
Value compareEqResult = rewriter.create<stablehlo::CompareOp>(
op->getLoc(), compareResultType, *firstValArg, *secondValArg,
compareEqDirectionAttr, compareTypeAttr);
Value minIdx =
rewriter.create<mhlo::MinOp>(op->getLoc(), *firstIdxArg, *secondIdxArg);
Value idxWithGeVal = rewriter.create<mhlo::SelectOp>(
Value minIdx = rewriter.create<stablehlo::MinOp>(op->getLoc(), *firstIdxArg,
*secondIdxArg);
Value idxWithGeVal = rewriter.create<stablehlo::SelectOp>(
op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg);
Value retIdxResult = rewriter.create<mhlo::SelectOp>(
Value retIdxResult = rewriter.create<stablehlo::SelectOp>(
op->getLoc(), compareEqResult, minIdx, idxWithGeVal);
rewriter.create<mhlo::ReturnOp>(
rewriter.create<stablehlo::ReturnOp>(
op->getLoc(), mlir::ValueRange{retValResult, retIdxResult});
}
@ -419,41 +422,42 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
// prepend 1 to kernelSize, stride, dilation until they are of same rank as
// input
SmallVector<int64_t> mhloStride(inputRank, 1);
SmallVector<int64_t> mhloDilation(inputRank, 1);
SmallVector<int64_t> mhloKernelSize(inputRank, 1);
SmallVector<int64_t> mhloPadding(inputRank * 2, 0);
SmallVector<int64_t> stablehloStride(inputRank, 1);
SmallVector<int64_t> stablehloDilation(inputRank, 1);
SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2);
std::copy(stride.begin(), stride.end(),
stablehloStride.begin() + inputRank - 2);
std::copy(kernelSize.begin(), kernelSize.end(),
mhloKernelSize.begin() + inputRank - 2);
mhloPadding[mhloPadding.size() - 4] = padding[0];
mhloPadding[mhloPadding.size() - 3] = padding[0];
mhloPadding[mhloPadding.size() - 2] = padding[1];
mhloPadding[mhloPadding.size() - 1] = padding[1];
stablehloKernelSize.begin() + inputRank - 2);
stablehloPadding[stablehloPadding.size() - 4] = padding[0];
stablehloPadding[stablehloPadding.size() - 3] = padding[0];
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())},
RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
rewriter.getI64Type()),
mhloKernelSize);
stablehloKernelSize);
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(mhloStride.size())},
RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
rewriter.getI64Type()),
mhloStride);
stablehloStride);
DenseIntElementsAttr baseDilations;
DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())},
RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
rewriter.getI64Type()),
mhloDilation);
stablehloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
rewriter.getI64Type()),
mhloPadding);
stablehloPadding);
auto reduceWindowSum = rewriter.create<mhlo::ReduceWindowOp>(
auto reduceWindowSum = rewriter.create<stablehlo::ReduceWindowOp>(
op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
baseDilations, windowDilations, pad);
@ -471,39 +475,39 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
rewriter.setInsertionPointToStart(&sumBlock);
Value sumResult =
rewriter.create<mhlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
rewriter.create<mhlo::ReturnOp>(op->getLoc(), sumResult);
rewriter.create<stablehlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult);
}
// Use kernel size as the divisor
if (countIncludePad) {
Value divisor = mhlo::getConstTensor<int64_t>(
Value divisor = hlo::getConstTensor<int64_t>(
rewriter, op, {kernelSize[0] * kernelSize[1]}, {})
.value();
divisor = mhlo::promoteType(rewriter, divisor, outTy);
divisor = hlo::promoteType(rewriter, divisor, outTy);
DenseIntElementsAttr bcastDimensions;
rewriter.replaceOpWithNewOp<mlir::chlo::BroadcastDivOp>(
op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions);
return success();
}
// Use another mhlo.ReduceWindowOp to get the divisor
// Use another stablehlo.ReduceWindowOp to get the divisor
Value windowSizeConst =
mhlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
windowSizeConst = mhlo::promoteType(rewriter, windowSizeConst, outTy);
hlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
windowSizeConst = hlo::promoteType(rewriter, windowSizeConst, outTy);
const auto &options = getOptions();
auto inputShapeVec =
*mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
*hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), inputShapeVec);
windowSizeConst = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
windowSizeConst = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
op->getLoc(),
RankedTensorType::get(inputTy.getShape(), outTy.getElementType()),
windowSizeConst, inputShapeTensor, rewriter.getI64TensorAttr({}));
Value zero = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
auto reduceWindowSize = rewriter.create<mhlo::ReduceWindowOp>(
auto reduceWindowSize = rewriter.create<stablehlo::ReduceWindowOp>(
op->getLoc(), RankedTensorType::get(outShape, inputElemTy),
windowSizeConst, zero, windowDimensions, windowStrides, baseDilations,
windowDilations, pad);
@ -522,18 +526,99 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
rewriter.setInsertionPointToStart(&sizeBlock);
Value sumResult =
rewriter.create<mhlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
rewriter.create<mhlo::ReturnOp>(op->getLoc(), sumResult);
rewriter.create<stablehlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult);
}
rewriter.replaceOpWithNewOp<mhlo::DivOp>(
rewriter.replaceOpWithNewOp<stablehlo::DivOp>(
op, outTy, reduceWindowSum.getResult(0), reduceWindowSize.getResult(0));
return success();
}
void mlir::torch::torch_to_mhlo::populatePoolingOpPatternsAndLegality(
// AtenCumsumOp
//
// Lowers aten.cumsum over a static dimension `dim` of extent N to a
// stablehlo.reduce_window whose body is an add:
//   * window size is N along `dim` and 1 elsewhere,
//   * stride and window dilation are 1 everywhere,
//   * N - 1 elements of low padding are added along `dim` only.
// With that padding the i-th window along `dim` covers N - 1 - i padded
// (init-valued) positions followed by elements [0, i] of the input, so its
// sum is the inclusive prefix sum at position i.
//
// Fails to match when `dim` is not a constant int, is out of range, or names
// a dynamic dimension.
template <>
LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite(
    AtenCumsumOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value input = adaptor.getSelf();
  auto inputTy = input.getType().cast<RankedTensorType>();
  auto inputRank = inputTy.getRank();
  auto inputShape = inputTy.getShape();
  auto outTy =
      getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();

  int64_t dim;
  if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
    return rewriter.notifyMatchFailure(
        op, "unimplemented: dim must be a constant int");
  }
  dim = toPositiveDim(dim, inputRank);
  if (!isValidDim(dim, inputRank)) {
    return rewriter.notifyMatchFailure(op, "dim is out of range");
  }
  // The window size and padding below are compile-time attributes, so the
  // extent of the scanned dimension has to be known statically.
  if (inputTy.isDynamicDim(dim)) {
    return rewriter.notifyMatchFailure(
        op, "unimplemented: cumsum dim must be static");
  }

  // The result element type may differ from the input element type. Convert
  // the operand first so that the operand, the init value, the reduction body
  // and the result of the reduce_window all share one element type.
  input = hlo::promoteType(rewriter, input, outTy);
  auto inputElemTy = input.getType().cast<RankedTensorType>().getElementType();

  // Additive identity of the (promoted) element type.
  Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);

  SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
  stablehloKernelSize[dim] = inputShape[dim];
  SmallVector<int64_t> stablehloStride(inputRank, 1);
  SmallVector<int64_t> stablehloDilation(inputRank, 1);
  // Padding is stored as (low, high) pairs per dimension; only the low side of
  // `dim` is padded.
  SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
  stablehloPadding[dim * 2] = inputShape[dim] - 1;

  DenseIntElementsAttr windowDimensions =
      rewriter.getI64TensorAttr(stablehloKernelSize);
  DenseIntElementsAttr windowStrides =
      rewriter.getI64TensorAttr(stablehloStride);
  // Left null: base dilation is an optional attribute.
  DenseIntElementsAttr baseDilations;
  DenseIntElementsAttr windowDilations =
      rewriter.getI64TensorAttr(stablehloDilation);
  DenseIntElementsAttr pad = DenseIntElementsAttr::get(
      RankedTensorType::get(
          {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
          rewriter.getI64Type()),
      stablehloPadding);

  auto reduceWindowSum = rewriter.create<stablehlo::ReduceWindowOp>(
      op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
      baseDilations, windowDilations, pad);

  // Build the reduction body: two rank-0 tensor arguments that are added.
  Block &sumBlock = reduceWindowSum.getBody().emplaceBlock();
  auto blockArgumentType = RankedTensorType::get({}, inputElemTy);
  sumBlock.addArgument(blockArgumentType, op->getLoc());
  sumBlock.addArgument(blockArgumentType, op->getLoc());
  auto *firstArg = sumBlock.args_begin();
  auto *secondArg = std::next(firstArg);

  {
    // Restore the caller's insertion point once the body has been emitted.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&sumBlock);

    Value sumResult =
        rewriter.create<stablehlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult);
  }

  rewriter.replaceOp(op, reduceWindowSum.getResults());
  return success();
}
void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToMhloOptions &options) {
ConversionTarget &target, const TorchToStablehloOptions &options) {
MLIRContext *context = patterns.getContext();
target.addIllegalOp<AtenMaxPool2dOp>();
patterns.add<ConvertAtenOp<AtenMaxPool2dOp>>(typeConverter, context, options);
@ -542,4 +627,6 @@ void mlir::torch::torch_to_mhlo::populatePoolingOpPatternsAndLegality(
target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
patterns.add<ConvertAtenOp<AtenMaxPool2dWithIndicesOp>>(typeConverter,
context, options);
target.addIllegalOp<AtenCumsumOp>();
patterns.add<ConvertAtenOp<AtenCumsumOp>>(typeConverter, context, options);
}

View File

@ -0,0 +1,69 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_LIB_CONVERSION_TORCHTOSTABLEHLO_POPULATEPATTERNS_H
#define TORCHMLIR_LIB_CONVERSION_TORCHTOSTABLEHLO_POPULATEPATTERNS_H
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace torch {
namespace torch_to_stablehlo {
// Options shared by every Torch -> StableHLO conversion pattern.
struct TorchToStablehloOptions {
  // NOTE(review): presumably restricts the lowering to statically shaped
  // code paths; no use of this flag is visible here -- confirm at the users.
  bool enableStaticShape{false};
  // Bit width of the integer type used for dimension sizes and indices that
  // the lowerings materialize; patterns forward it to the hlo:: shape helpers.
  size_t dimSizeIndexBits{64};
};
// Generic conversion pattern for a single aten op type.
//
// The primary template rejects every op; supported ops provide an explicit
// specialization of matchAndRewrite. Each pattern instance keeps its own copy
// of the conversion options, readable through getOptions().
template <typename AtenOpT>
class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
public:
  using OpAdaptor = typename AtenOpT::Adaptor;

  // Stores a copy of `options`; the caller's object need not outlive the
  // pattern.
  ConvertAtenOp(TypeConverter &typeConverter, MLIRContext *context,
                const TorchToStablehloOptions &options)
      : OpConversionPattern<AtenOpT>(typeConverter, context),
        options(options) {}

  // Fallback used when no specialization exists for AtenOpT: always reports a
  // match failure.
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    return rewriter.notifyMatchFailure(op, "haven't been implemented");
  }

  // Options this pattern was constructed with.
  const TorchToStablehloOptions &getOptions() const { return options; }

private:
  TorchToStablehloOptions options;
};
// Registration entry points, one per family of aten ops. All share the same
// signature:
//   typeConverter - type converter handed to every pattern that is added.
//   patterns      - pattern set that receives the conversion patterns.
//   target        - conversion target whose legality is updated.
//   options       - lowering options copied into each pattern.
// The pooling implementation marks each handled aten op illegal on `target`
// and adds the matching ConvertAtenOp pattern to `patterns`.
// NOTE(review): the other families presumably follow the same scheme; their
// definitions are not visible from this header -- confirm.
void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target,
const TorchToStablehloOptions &options);
void populateViewLikeOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options);
void populateGatherOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options);
void populateReductionOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options);
void populateLinearOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options);
void populatePoolingOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options);
} // namespace torch_to_stablehlo
} // namespace torch
} // namespace mlir
#endif // TORCHMLIR_LIB_CONVERSION_TORCHTOSTABLEHLO_POPULATEPATTERNS_H

View File

@ -7,14 +7,15 @@
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "../PassDetail.h"
#include "./MhloLegalizeUtils.h"
#include "./PopulatePatterns.h"
#include "mhlo/IR/hlo_ops.h"
#include "PopulatePatterns.h"
#include "StablehloLegalizeUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@ -25,7 +26,7 @@
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_mhlo;
using namespace mlir::torch::torch_to_stablehlo;
static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
PatternRewriter &rewriter) {
@ -36,14 +37,14 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
constType, {APFloat::getZero(
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
/*negative=*/false)});
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
constAttr);
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
} else if (elementTy.isa<mlir::IntegerType>() &&
elementTy.getIntOrFloatBitWidth() != 8) {
auto constAttr = DenseElementsAttr::get(
constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
constAttr);
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
}
}
@ -53,15 +54,15 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
constType, {APFloat::getLargest(
elementTy.cast<mlir::FloatType>().getFloatSemantics(),
/*negative=*/true)});
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
constAttr);
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
} else if (elementTy.isa<mlir::IntegerType>() &&
elementTy.getIntOrFloatBitWidth() != 8) {
auto constAttr = DenseElementsAttr::get(
constType,
{APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())});
return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
constAttr);
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
}
}
@ -90,9 +91,9 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
return std::nullopt;
Value initIndex;
if (dimSizeIndexBits == 32) {
initIndex = mhlo::getConstTensor<int32_t>(rewriter, op, {0}, {}).value();
initIndex = hlo::getConstTensor<int32_t>(rewriter, op, {0}, {}).value();
} else {
initIndex = mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
initIndex = hlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
}
DenseIntElementsAttr dimensions = DenseIntElementsAttr::get(
@ -100,13 +101,13 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), inputShapeVec);
auto indexTensor = rewriter.create<mhlo::DynamicIotaOp>(
auto indexTensor = rewriter.create<stablehlo::DynamicIotaOp>(
op->getLoc(),
RankedTensorType::get(inputShape,
rewriter.getIntegerType(dimSizeIndexBits)),
inputShapeTensor, static_cast<uint64_t>(dim));
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op->getLoc(), ValueRange{input, indexTensor},
ValueRange{
initValue,
@ -114,7 +115,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
},
dimensions);
Block &block = mhloReduceOp.getBody().emplaceBlock();
Block &block = stablehloReduceOp.getBody().emplaceBlock();
// Add block arguments
auto blockValArgumentType =
@ -133,46 +134,46 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
auto *secondValArg = std::next(firstIdxArg);
auto *secondIdxArg = std::next(secondValArg);
mhlo::ComparisonTypeAttr compareTypeAttr;
stablehlo::ComparisonTypeAttr compareTypeAttr;
if (inputTy.getElementType().isa<mlir::FloatType>()) {
compareTypeAttr = mhlo::ComparisonTypeAttr::get(
rewriter.getContext(), mhlo::ComparisonType::FLOAT);
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
rewriter.getContext(), stablehlo::ComparisonType::FLOAT);
} else if (inputTy.getElementType().isa<mlir::IntegerType>()) {
compareTypeAttr = mhlo::ComparisonTypeAttr::get(
rewriter.getContext(), mhlo::ComparisonType::SIGNED);
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
}
mhlo::ComparisonDirectionAttr compareGeDirectionAttr =
mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
mhlo::ComparisonDirection::GE);
mhlo::ComparisonDirectionAttr compareEqDirectionAttr =
mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
mhlo::ComparisonDirection::EQ);
stablehlo::ComparisonDirectionAttr compareGeDirectionAttr =
stablehlo::ComparisonDirectionAttr::get(
rewriter.getContext(), stablehlo::ComparisonDirection::GE);
stablehlo::ComparisonDirectionAttr compareEqDirectionAttr =
stablehlo::ComparisonDirectionAttr::get(
rewriter.getContext(), stablehlo::ComparisonDirection::EQ);
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value compareGeResult = rewriter.create<mhlo::CompareOp>(
Value compareGeResult = rewriter.create<stablehlo::CompareOp>(
op->getLoc(), compareResultType, *firstValArg, *secondValArg,
compareGeDirectionAttr, compareTypeAttr);
Value retValResult = rewriter.create<mhlo::SelectOp>(
Value retValResult = rewriter.create<stablehlo::SelectOp>(
op->getLoc(), compareGeResult, *firstValArg, *secondValArg);
// get smaller index value if compared nums are equal.
Value compareEqResult = rewriter.create<mhlo::CompareOp>(
Value compareEqResult = rewriter.create<stablehlo::CompareOp>(
op->getLoc(), compareResultType, *firstValArg, *secondValArg,
compareEqDirectionAttr, compareTypeAttr);
Value minIdx =
rewriter.create<mhlo::MinOp>(op->getLoc(), *firstIdxArg, *secondIdxArg);
Value idxWithGeVal = rewriter.create<mhlo::SelectOp>(
Value minIdx = rewriter.create<stablehlo::MinOp>(op->getLoc(), *firstIdxArg,
*secondIdxArg);
Value idxWithGeVal = rewriter.create<stablehlo::SelectOp>(
op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg);
Value retIdxResult = rewriter.create<mhlo::SelectOp>(
Value retIdxResult = rewriter.create<stablehlo::SelectOp>(
op->getLoc(), compareEqResult, minIdx, idxWithGeVal);
rewriter.create<mhlo::ReturnOp>(
rewriter.create<stablehlo::ReturnOp>(
op->getLoc(), mlir::ValueRange{retValResult, retIdxResult});
}
return mhloReduceOp.getResults();
return stablehloReduceOp.getResults();
}
namespace {
@ -196,7 +197,8 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
Value input = adaptor.getSelf();
auto inputTy = input.getType().template cast<RankedTensorType>();
if (!inputTy) {
return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
}
auto inputElemTy = inputTy.getElementType();
@ -209,7 +211,7 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenArgmaxOp to MHLO");
"AtenArgmaxOp to StableHLO");
}
int64_t dim;
@ -228,15 +230,15 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
const auto &options = getOptions();
auto inputShapeInfo =
mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
if (failed(inputShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
}
auto inputShapeVec = *inputShapeInfo;
auto mhloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, dim,
options.dimSizeIndexBits)
.value();
auto stablehloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec,
dim, options.dimSizeIndexBits)
.value();
if (keepDim) {
auto outShapeVec = inputShapeVec;
@ -247,13 +249,13 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec);
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
op, typeConverter->convertType(op.getType()), mhloReduceResults[1],
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
op, typeConverter->convertType(op.getType()), stablehloReduceResults[1],
outShapeTensor);
return success();
}
rewriter.replaceOp(op, mhloReduceResults[1]);
rewriter.replaceOp(op, stablehloReduceResults[1]);
return success();
}
} // namespace
@ -267,7 +269,8 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
Value input = adaptor.getSelf();
auto inputTy = input.getType().template dyn_cast<RankedTensorType>();
if (!inputTy) {
return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
}
auto inputElemTy = inputTy.getElementType();
if (!inputElemTy.isIntOrFloat()) {
@ -279,7 +282,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenMaxDimOp to MHLO");
"AtenMaxDimOp to StableHLO");
}
RankedTensorType valResultType = getTypeConverter()
@ -308,15 +311,15 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
const auto &options = getOptions();
auto inputShapeInfo =
mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
if (failed(inputShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
}
auto inputShapeVec = *inputShapeInfo;
auto mhloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, dim,
options.dimSizeIndexBits)
.value();
auto stablehloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec,
dim, options.dimSizeIndexBits)
.value();
if (keepDim) {
auto outShapeVec = inputShapeVec;
@ -327,15 +330,21 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec);
auto mhloReduceValueResult = rewriter.create<mhlo::DynamicReshapeOp>(
op->getLoc(), valResultType, mhloReduceResults[0], outShapeTensor);
auto mhloReduceIndexResult = rewriter.create<mhlo::DynamicReshapeOp>(
op->getLoc(), idxResultType, mhloReduceResults[1], outShapeTensor);
rewriter.replaceOp(op, {mhloReduceValueResult, mhloReduceIndexResult});
auto stablehloReduceValueResult =
rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), valResultType, stablehloReduceResults[0],
outShapeTensor);
auto stablehloReduceIndexResult =
rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), idxResultType, stablehloReduceResults[1],
outShapeTensor);
rewriter.replaceOp(
op, {stablehloReduceValueResult, stablehloReduceIndexResult});
return success();
}
rewriter.replaceOp(op, {mhloReduceResults[0], mhloReduceResults[1]});
rewriter.replaceOp(op,
{stablehloReduceResults[0], stablehloReduceResults[1]});
return success();
}
} // namespace
@ -352,12 +361,14 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
->convertType(op.getType())
.template dyn_cast<RankedTensorType>();
if (!inputTy) {
return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
}
if (inputTy.getElementType() != outTy.getElementType()) {
// Use output element type as computation type.
auto dstElemTy = outTy.getElementType();
input = rewriter.create<mhlo::ConvertOp>(op->getLoc(), input, dstElemTy);
input =
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
inputTy = input.getType().dyn_cast<RankedTensorType>();
}
auto inputElemTy = inputTy.getElementType();
@ -370,7 +381,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenSumOp to MHLO");
"AtenSumOp to StableHLO");
}
SmallVector<int64_t> dims;
@ -379,13 +390,14 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
}
Value initValue =
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue) return failure();
if (!initValue)
return failure();
llvm::sort(dims.begin(), dims.end());
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
Block &block = mhloReduceOp.getBody().emplaceBlock();
Block &block = stablehloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
block.addArgument(blockArgumentTy, op->getLoc());
@ -397,13 +409,13 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value addResult = rewriter.create<mhlo::AddOp>(
Value addResult = rewriter.create<stablehlo::AddOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
rewriter.create<mhlo::ReturnOp>(op->getLoc(), addResult);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult);
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy,
mhloReduceOp.getResults());
stablehloReduceOp.getResults());
return success();
}
} // namespace
@ -417,7 +429,8 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
Value input = adaptor.getSelf();
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
if (!inputTy) {
return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
}
auto inputElemTy = inputTy.getElementType();
if (!inputElemTy.isIntOrFloat()) {
@ -429,7 +442,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenMaxOp to MHLO");
"AtenMaxOp to StableHLO");
}
SmallVector<int64_t> dims;
@ -439,12 +452,13 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
Value initValue =
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue) return failure();
if (!initValue)
return failure();
llvm::sort(dims.begin(), dims.end());
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
Block &block = mhloReduceOp.getBody().emplaceBlock();
Block &block = stablehloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
block.addArgument(blockArgumentTy, op->getLoc());
@ -456,14 +470,14 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value maxResult = rewriter.create<mhlo::MaxOp>(
Value maxResult = rewriter.create<stablehlo::MaxOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
rewriter.create<mhlo::ReturnOp>(op->getLoc(), maxResult);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), maxResult);
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, getTypeConverter()->convertType(op.getType()),
mhloReduceOp.getResults());
stablehloReduceOp.getResults());
return success();
}
} // namespace
@ -480,12 +494,14 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
->convertType(op.getType())
.template dyn_cast<RankedTensorType>();
if (!inputTy) {
return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
}
if (inputTy.getElementType() != outTy.getElementType()) {
// Use output element type as computation type.
auto dstElemTy = outTy.getElementType();
input = rewriter.create<mhlo::ConvertOp>(op->getLoc(), input, dstElemTy);
input =
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
inputTy = input.getType().dyn_cast<RankedTensorType>();
}
auto inputElemTy = inputTy.getElementType();
@ -499,7 +515,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenSumDimIntListOp to MHLO");
"AtenSumDimIntListOp to StableHLO");
}
SmallVector<int64_t> inputDims;
@ -525,13 +541,14 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
}
Value initValue =
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue) return failure();
if (!initValue)
return failure();
llvm::sort(dims.begin(), dims.end());
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
Region &region = mhloReduceOp.getBody();
Region &region = stablehloReduceOp.getBody();
Block &block = region.emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
@ -544,15 +561,15 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value addResult = rewriter.create<mhlo::AddOp>(
Value addResult = rewriter.create<stablehlo::AddOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
rewriter.create<mhlo::ReturnOp>(op->getLoc(), addResult);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult);
}
if (keepDim) {
const auto &options = getOptions();
auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input,
options.dimSizeIndexBits);
auto outShapeInfo =
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
if (failed(outShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
@ -567,26 +584,27 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
}
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec);
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
op, getTypeConverter()->convertType(op.getType()),
mhloReduceOp.getResult(0), outShapeTensor);
stablehloReduceOp.getResult(0), outShapeTensor);
return success();
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy,
mhloReduceOp.getResults());
stablehloReduceOp.getResults());
return success();
}
} // namespace
// AtenFrobeniusNormDimOp
// aten.frobenius_norm.dim => mhlo.reduce(calculate square sum along given dims)
// + mhlo.sqrt
// aten.frobenius_norm.dim => stablehlo.reduce(calculate square sum along given
// dims)
// + stablehlo.sqrt
namespace {
template <>
LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
AtenFrobeniusNormDimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
const TorchToMhloOptions &options = getOptions();
const TorchToStablehloOptions &options = getOptions();
Value input = adaptor.getSelf();
auto inputType = input.getType().dyn_cast<RankedTensorType>();
@ -614,7 +632,7 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
}
}
// Sort the dims in ascending order, making the conversion
// Sort the dims in ascending order, making the conversion
// stable with unordered dims.
std::sort(dims.begin(), dims.end());
@ -624,58 +642,57 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
op, "non-const bool `keepdim` is not supported");
}
auto squareOp = rewriter.create<stablehlo::MulOp>(op->getLoc(), input, input);
auto initValue = createInitialValueForReduceOp(op, inputElemType, rewriter);
if (!initValue) {
return failure();
}
auto squareSumReduceOp = rewriter.create<mhlo::ReduceOp>(
op->getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
auto reduceOp = rewriter.create<stablehlo::ReduceOp>(
op->getLoc(), squareOp.getResult(), initValue,
rewriter.getI64TensorAttr(dims));
Region &region = squareSumReduceOp.getBody();
Region &region = reduceOp.getBody();
Block &block = region.emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputElemType);
block.addArgument(blockArgumentTy, op->getLoc());
block.addArgument(blockArgumentTy, op->getLoc());
auto *firstArgument = block.args_begin();
auto secondArgument = block.args_rbegin();
auto firstArgument = *block.args_begin();
auto secondArgument = *block.args_rbegin();
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
auto constantOrd2 = rewriter.create<mhlo::ConstantOp>(
op->getLoc(), blockArgumentTy,
DenseElementsAttr::get(blockArgumentTy, llvm::ArrayRef<float>{2.0}));
auto abs = rewriter.create<mhlo::AbsOp>(op->getLoc(), *secondArgument);
auto squareResult = rewriter.create<mhlo::PowOp>(
op->getLoc(), abs, constantOrd2);
auto addResult = rewriter.create<mhlo::AddOp>(op->getLoc(), squareResult,
*firstArgument);
rewriter.create<mhlo::ReturnOp>(op->getLoc(), addResult.getResult());
auto addResult = rewriter.create<stablehlo::AddOp>(
op->getLoc(), firstArgument, secondArgument);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult.getResult());
}
auto output = rewriter.create<mhlo::SqrtOp>(op->getLoc(),
squareSumReduceOp.getResult(0));
auto output =
rewriter.create<stablehlo::SqrtOp>(op->getLoc(), reduceOp.getResult(0));
if (keepDim) {
auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
auto outShapeInfo =
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
if (failed(outShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
}
auto outShapeVec = *outShapeInfo;
auto one = rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(), rewriter.getIntegerAttr(
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
op->getLoc(),
rewriter.getIntegerAttr(
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
for (int64_t i : dims) {
outShapeVec[i] = one;
}
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec);
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), output,
outShapeTensor);
return success();
@ -685,9 +702,9 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
}
} // namespace
void mlir::torch::torch_to_mhlo::populateReductionOpPatternsAndLegality(
void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToMhloOptions &options) {
ConversionTarget &target, const TorchToStablehloOptions &options) {
MLIRContext *context = patterns.getContext();
#define INSERT_ATEN_REDUCTION_OP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \

View File

@ -7,11 +7,12 @@
//
//===----------------------------------------------------------------------===//
#include "./MhloLegalizeUtils.h"
#include "mhlo/IR/hlo_ops.h"
#include "StablehloLegalizeUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include <numeric>
@ -21,27 +22,27 @@ using namespace mlir::torch;
using namespace mlir::torch::Torch;
namespace mlir {
namespace mhlo {
namespace hlo {
// Create a 32-bit float constant operator from a float
Value getMhloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
float val) {
Value getStablehloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
float val) {
auto const_type = RankedTensorType::get({}, rewriter.getF32Type());
auto const_attr = DenseElementsAttr::get(const_type, val);
auto const_op =
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
auto const_op = rewriter.create<stablehlo::ConstantOp>(
op->getLoc(), const_type, const_attr);
return const_op.getResult();
}
// Create a 64-bit float constant operator from a double
Value getMhloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op,
double val) {
Value getStablehloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op,
double val) {
auto const_type = RankedTensorType::get({}, rewriter.getF64Type());
auto const_attr = DenseElementsAttr::get(const_type, val);
auto const_op =
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
auto const_op = rewriter.create<stablehlo::ConstantOp>(
op->getLoc(), const_type, const_attr);
return const_op.getResult();
}
@ -65,8 +66,8 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
RankedTensorType::get(shape, rewriter.getIntegerType(sizeof(T) * 8));
auto const_attr = DenseElementsAttr::get(const_type, vec);
auto const_op =
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
auto const_op = rewriter.create<stablehlo::ConstantOp>(
op->getLoc(), const_type, const_attr);
return const_op.getResult();
}
@ -88,8 +89,8 @@ std::optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,
shape, rewriter.getIntegerType(vec[0].getBitWidth()));
auto const_attr = DenseElementsAttr::get(const_type, vec);
auto const_op =
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
auto const_op = rewriter.create<stablehlo::ConstantOp>(
op->getLoc(), const_type, const_attr);
return const_op.getResult();
}
@ -111,8 +112,8 @@ std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
auto const_type = RankedTensorType::get(shape, rewriter.getF32Type());
auto const_attr = DenseElementsAttr::get(const_type, vec);
auto const_op =
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
auto const_op = rewriter.create<stablehlo::ConstantOp>(
op->getLoc(), const_type, const_attr);
return const_op.getResult();
}
@ -133,8 +134,8 @@ std::optional<Value> getConstTensor<double>(PatternRewriter &rewriter,
auto const_type = RankedTensorType::get(shape, rewriter.getF64Type());
auto const_attr = DenseElementsAttr::get(const_type, vec);
auto const_op =
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
auto const_op = rewriter.create<stablehlo::ConstantOp>(
op->getLoc(), const_type, const_attr);
return const_op.getResult();
}
@ -169,18 +170,18 @@ Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
T val, Type dtype, llvm::ArrayRef<int64_t> dshape) {
auto const_type = RankedTensorType::get(dshape, dtype);
auto const_attr = SplatElementsAttr::get(const_type, val);
auto const_op =
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
auto const_op = rewriter.create<stablehlo::ConstantOp>(
op->getLoc(), const_type, const_attr);
return const_op.getResult();
}
Value scalarToMhloTensor(ConversionPatternRewriter &rewriter, Operation *op,
Value scalarValue, Type dtype) {
Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
Operation *op, Value scalarValue, Type dtype) {
auto tensor = rewriter.create<tensor::FromElementsOp>(
op->getLoc(), ArrayRef<Value>{scalarValue});
auto dtype_tensor =
rewriter.create<mhlo::ConvertOp>(op->getLoc(), tensor, dtype);
return rewriter.create<mhlo::ReshapeOp>(
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), tensor, dtype);
return rewriter.create<stablehlo::ReshapeOp>(
op->getLoc(), RankedTensorType::get(mlir::ArrayRef<int64_t>{}, dtype),
dtype_tensor);
}
@ -192,7 +193,8 @@ Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) {
if (in_type.getElementType() != outType.getElementType()) {
TensorType promotedType =
in_type.cloneWith(in_type.getShape(), outType.getElementType());
return rewriter.create<mhlo::ConvertOp>(op->getLoc(), promotedType, input);
return rewriter.create<stablehlo::ConvertOp>(op->getLoc(), promotedType,
input);
}
return input;
}
@ -210,8 +212,8 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
if (in_type.getElementType() != outType.getElementType()) {
TensorType promoted_type =
in_type.cloneWith(in_type.getShape(), outType.getElementType());
input =
rewriter.create<mhlo::ConvertOp>(op->getLoc(), promoted_type, input);
input = rewriter.create<stablehlo::ConvertOp>(op->getLoc(), promoted_type,
input);
}
ArrayRef<int64_t> inShape = in_type.getShape();
@ -245,8 +247,8 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
RankedTensorType::get({static_cast<long int>(bcastDims.size())},
rewriter.getI64Type()),
bcastDims);
auto bcast_op = rewriter.create<mhlo::BroadcastInDimOp>(op->getLoc(), outType,
input, bcast_attr);
auto bcast_op = rewriter.create<stablehlo::BroadcastInDimOp>(
op->getLoc(), outType, input, bcast_attr);
return bcast_op.getResult();
}
@ -348,8 +350,8 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
}
auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
auto mhloShape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
return rewriter.create<mhlo::DynamicReshapeOp>(loc, outTy, tensor, mhloShape)
auto shape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
return rewriter.create<stablehlo::DynamicReshapeOp>(loc, outTy, tensor, shape)
.getResult();
}
@ -357,11 +359,11 @@ Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
const APFloat &constant, Value shape,
TensorType outType) {
auto constAttr = rewriter.getFloatAttr(outType.getElementType(), constant);
auto constTensor = rewriter.create<mhlo::ConstantOp>(loc, constAttr);
auto constTensor = rewriter.create<stablehlo::ConstantOp>(loc, constAttr);
return rewriter
.create<mhlo::DynamicBroadcastInDimOp>(loc, outType, constTensor, shape,
rewriter.getI64TensorAttr({}))
.create<stablehlo::DynamicBroadcastInDimOp>(
loc, outType, constTensor, shape, rewriter.getI64TensorAttr({}))
.getResult();
}
} // namespace mhlo
} // namespace hlo
} // namespace mlir

View File

@ -7,8 +7,8 @@
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H
#define TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H
#ifndef TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_STABLEHLOLEGALIZEUTILS_H
#define TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_STABLEHLOLEGALIZEUTILS_H
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
@ -18,22 +18,22 @@
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace mhlo {
namespace hlo {
using mlir::ConversionPatternRewriter;
// Create a 32-bit float constant operator from a float
Value getMhloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
float val);
Value getStablehloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
float val);
// Create a 64-bit float constant operator from a double
Value getMhloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op,
double val);
Value getStablehloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op,
double val);
// Templated function to create a constant op for given type and shape.
// T: storage C type.
// Default template creates a constant tensor in T.
// To create INT48 MHLO constant, need to pass in llvm::APInt instead.
// To create INT48 StableHLO constant, need to pass in llvm::APInt instead.
template <typename T>
std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
ArrayRef<T> vec, ArrayRef<int64_t> shape);
@ -42,8 +42,8 @@ template <typename T>
Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
T val, Type dtype, llvm::ArrayRef<int64_t> dshape);
Value scalarToMhloTensor(ConversionPatternRewriter &rewriter, Operation *op,
Value scalarValue, Type dtype);
Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
Operation *op, Value scalarValue, Type dtype);
Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType);
@ -71,7 +71,7 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
const APFloat &constant, Value shape,
TensorType outType);
} // namespace mhlo
} // namespace hlo
} // namespace mlir
#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H
#endif // TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_STABLEHLOLEGALIZEUTILS_H

View File

@ -7,17 +7,18 @@
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "../PassDetail.h"
#include "./PopulatePatterns.h"
#include "mhlo/IR/hlo_ops.h"
#include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
@ -30,17 +31,18 @@ using namespace mlir::torch::Torch;
namespace {
class ConvertTorchToMhlo : public ConvertTorchToMhloBase<ConvertTorchToMhlo> {
class ConvertTorchToStablehlo
: public ConvertTorchToStablehloBase<ConvertTorchToStablehlo> {
public:
ConvertTorchToMhlo() = default;
ConvertTorchToMhlo(bool enableStaticShape, bool enableI32Index) {
ConvertTorchToStablehlo() = default;
ConvertTorchToStablehlo(bool enableStaticShape, bool enableI32Index) {
this->enableStaticShape = enableStaticShape;
this->enableI32Index = enableI32Index;
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<chlo::ChloDialect>();
registry.insert<mhlo::MhloDialect>();
registry.insert<stablehlo::StablehloDialect>();
registry.insert<tensor::TensorDialect>();
registry.insert<arith::ArithDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry);
@ -48,7 +50,7 @@ public:
void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<chlo::ChloDialect, mhlo::MhloDialect,
target.addLegalDialect<chlo::ChloDialect, stablehlo::StablehloDialect,
tensor::TensorDialect, arith::ArithDialect>();
TypeConverter typeConverter;
@ -57,20 +59,20 @@ public:
RewritePatternSet patterns(context);
torch_to_mhlo::TorchToMhloOptions options{enableStaticShape,
enableI32Index ? 32u : 64u};
torch_to_mhlo::populateBasicOpPatternsAndLegality(typeConverter, patterns,
target, options);
torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
torch_to_stablehlo::TorchToStablehloOptions options{
enableStaticShape, enableI32Index ? 32u : 64u};
torch_to_stablehlo::populateBasicOpPatternsAndLegality(
typeConverter, patterns, target, options);
torch_to_mhlo::populateGatherOpPatternsAndLegality(typeConverter, patterns,
target, options);
torch_to_mhlo::populateReductionOpPatternsAndLegality(
torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
typeConverter, patterns, target, options);
torch_to_stablehlo::populateGatherOpPatternsAndLegality(
typeConverter, patterns, target, options);
torch_to_stablehlo::populateReductionOpPatternsAndLegality(
typeConverter, patterns, target, options);
torch_to_stablehlo::populateLinearOpPatternsAndLegality(
typeConverter, patterns, target, options);
torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
typeConverter, patterns, target, options);
torch_to_mhlo::populateLinearOpPatternsAndLegality(typeConverter, patterns,
target, options);
torch_to_mhlo::populatePoolingOpPatternsAndLegality(typeConverter, patterns,
target, options);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
@ -82,13 +84,13 @@ public:
} // namespace
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToMhloPass() {
return std::make_unique<ConvertTorchToMhlo>(false, false);
mlir::torch::createConvertTorchToStablehloPass() {
return std::make_unique<ConvertTorchToStablehlo>(false, false);
}
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToMhloPass(bool enableStaticShape,
bool enableI32Index) {
return std::make_unique<ConvertTorchToMhlo>(enableStaticShape,
enableI32Index);
mlir::torch::createConvertTorchToStablehloPass(bool enableStaticShape,
bool enableI32Index) {
return std::make_unique<ConvertTorchToStablehlo>(enableStaticShape,
enableI32Index);
}

View File

@ -7,14 +7,15 @@
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "../PassDetail.h"
#include "./MhloLegalizeUtils.h"
#include "./PopulatePatterns.h"
#include "mhlo/IR/hlo_ops.h"
#include "PopulatePatterns.h"
#include "StablehloLegalizeUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@ -28,7 +29,7 @@ using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::TorchConversion;
using namespace mlir::torch::torch_to_mhlo;
using namespace mlir::torch::torch_to_stablehlo;
namespace {
// A dimension index from torch.dialect might be outside the range [0, dimSize].
@ -100,7 +101,7 @@ Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op,
auto stridesTensor =
rewriter.create<tensor::FromElementsOp>(loc, strides).getResult();
return rewriter.create<mhlo::RealDynamicSliceOp>(
return rewriter.create<stablehlo::RealDynamicSliceOp>(
loc, outTy, input, startTensor, endTensor, stridesTensor);
}
@ -144,7 +145,7 @@ FailureOr<Value> getDynamicSlice(PatternRewriter &rewriter, Operation *op,
step = rewriter.create<arith::TruncIOp>(loc, intType, step);
}
FailureOr<SmallVector<Value, 4>> dimSizesInfo =
mhlo::getDimSizesOfTensor(rewriter, op, input, dimSizeIndexBits);
hlo::getDimSizesOfTensor(rewriter, op, input, dimSizeIndexBits);
if (failed(dimSizesInfo))
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
@ -179,7 +180,7 @@ public:
auto loc = op.getLoc();
auto newRank = dimSizes.size();
if (newRank == 0 || rankType.getRank() == 0) {
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
@ -214,17 +215,18 @@ public:
numel);
if (dimSizes.size() == 0) {
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
adaptor.getSelf());
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
adaptor.getSelf());
return success();
}
Value mhloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
Value computedShape = rewriter.create<mhlo::ComputeReshapeShapeOp>(
loc, mhloShape.getType(), numel, mhloShape);
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
Value stablehloShape =
rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
Value computedShape = rewriter.create<stablehlo::ComputeReshapeShapeOp>(
loc, stablehloShape.getType(), numel, stablehloShape);
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
@ -315,21 +317,21 @@ LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
dims.push_back(r);
}
if (dims.size() == 0) {
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), self);
return success();
}
auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims,
options.dimSizeIndexBits);
auto newDimSizesInfo = hlo::getDimSizesOfTensor(rewriter, op, self, dims,
options.dimSizeIndexBits);
if (failed(newDimSizesInfo))
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
auto newDimSizes = *newDimSizesInfo;
auto mhloShape =
auto stablehloShape =
rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), self, mhloShape);
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), self, stablehloShape);
return success();
}
@ -365,20 +367,20 @@ LogicalResult ConvertAtenOp<AtenSqueezeDimOp>::matchAndRewrite(
std::iota(dims.begin(), dims.end(), 0);
dims.erase(dims.begin() + dim);
if (dims.size() == 0) {
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), self);
return success();
}
auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims,
options.dimSizeIndexBits);
auto newDimSizesInfo = hlo::getDimSizesOfTensor(rewriter, op, self, dims,
options.dimSizeIndexBits);
if (failed(newDimSizesInfo))
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
auto newDimSizes = *newDimSizesInfo;
auto mhloShape =
auto stablehloShape =
rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), self, mhloShape);
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), self, stablehloShape);
return success();
}
@ -395,8 +397,8 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return op->emitError("dim must be a Scalar constant");
auto unsqzTensorInfo = mhlo::unsqueezeTensor(rewriter, op, adaptor.getSelf(),
{dim}, options.dimSizeIndexBits);
auto unsqzTensorInfo = hlo::unsqueezeTensor(rewriter, op, adaptor.getSelf(),
{dim}, options.dimSizeIndexBits);
if (failed(unsqzTensorInfo))
return rewriter.notifyMatchFailure(op,
"failed to create unsqueezed tensor");
@ -405,9 +407,9 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
return success();
}
void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToMhloOptions &options) {
ConversionTarget &target, const TorchToStablehloOptions &options) {
MLIRContext *context = patterns.getContext();
#define INSERT_ATENOP_PATTERN(AtenOp) \

View File

@ -17,6 +17,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/ValueRange.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
@ -26,6 +27,9 @@
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/Support/ErrorHandling.h"
using namespace mlir;
using namespace mlir::torch;
@ -52,6 +56,147 @@ using namespace mlir::torch::TMTensor;
// that these patterns become mostly mechanical associations of
// "aten.foo -> linalg.foo".
// Returns an attribute of type `elementType` holding an extreme value of that
// type.
//
// - Integer types: the signed minimum (`getMin == true`) or signed maximum
//   (`getMin == false`) for the type's bit width.
// - Float types: the value produced by `APFloat::getLargest` for the type's
//   semantics, negated when `getMin == true`.
//
// Any other element type is a programming error and hits `llvm_unreachable`.
static Attribute getNumericLimit(PatternRewriter &rewriter, Type elementType,
                                 bool getMin = true) {
  auto bitWidth = elementType.getIntOrFloatBitWidth();

  if (llvm::isa<mlir::IntegerType>(elementType)) {
    APInt limit = getMin ? APInt::getSignedMinValue(bitWidth)
                         : APInt::getSignedMaxValue(bitWidth);
    return rewriter.getIntegerAttr(elementType, limit);
  }

  if (mlir::FloatType floatType =
          llvm::dyn_cast<mlir::FloatType>(elementType)) {
    // The second argument of `getLargest` selects the sign of the result.
    APFloat limit = APFloat::getLargest(floatType.getFloatSemantics(),
                                        /*Negative=*/getMin);
    return rewriter.getFloatAttr(elementType, limit);
  }

  llvm_unreachable("Only float/integer types are supported!");
}
// This function will reformat the `index` and `src` from torch operations
// like `torch.scatter` or `torch.scatter_reduce` to match the expected
// input for the TMScatterOp. It will return the reformatted `index` and `src`
// as a pair of mlir::Value that can be used as inputs for the TMScatterOp.
//
// Parameters:
//   rewriter - used to create all the new operations.
//   indices  - ranked tensor holding the scatter indices.
//   src      - ranked tensor holding the values to scatter.
//   dim      - dimension along which `indices` index. Negative values are
//              interpreted relative to the rank of `indices`; a value that is
//              still out of range afterwards is a fatal error, because it
//              would otherwise index past the end of the per-element
//              coordinate vector built below.
//
// Returns {flattenedIndices, flattenedSrc}: an i32 tensor with one row of
// coordinates per element of `indices`, and a 1-D tensor with the matching
// element of `src` for each row.
static std::pair<Value, Value>
convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc(PatternRewriter &rewriter,
                                                     Value indices, Value src,
                                                     int64_t dim) {
  // Get information on types for inputs
  RankedTensorType indexType = indices.getType().cast<RankedTensorType>();
  RankedTensorType srcSelf = src.getType().cast<RankedTensorType>();

  // Normalize `dim` before it is used as an index into `yieldVals` in the
  // region builder below; a raw negative or too-large value would be an
  // out-of-bounds access.
  int64_t indexRank = indexType.getRank();
  if (dim < 0)
    dim += indexRank;
  if (dim < 0 || dim >= indexRank)
    llvm::report_fatal_error(
        "scatter 'dim' is out of range for the rank of 'index'");

  // Store location for insertions
  Location loc = src.getLoc();

  Value indexSize = getTensorSize(rewriter, loc, indices);
  indexSize = castIntToIndex(rewriter, loc, indexSize);
  SmallVector<Value> indexShape = getTensorSizes(rewriter, loc, indices);
  Value cstOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);

  // We flatten the `src` values from (i, j, k, ...) -> (i * j * k * ...)
  SmallVector<Value> indSliceShape({indexSize, cstOne});
  Value indSlice =
      createZeroInitTensor(rewriter, loc, indSliceShape, rewriter.getI32Type());

  // New output shape will be equal to the product of the dimensions of the
  // updates
  SmallVector<Value> outputs(indexType.getRank(), indSlice);
  outputs.push_back(createZeroInitTensor(rewriter, loc, {indexSize},
                                         srcSelf.getElementType()));
  SmallVector<Type> outputsType(indexType.getRank(), indSlice.getType());
  outputsType.push_back(outputs[indexType.getRank()].getType());

  // Create mapping over flattened iteration space
  SmallVector<AffineExpr> indSliceExpr = {rewriter.getAffineDimExpr(0),
                                          rewriter.getAffineConstantExpr(0)};
  SmallVector<AffineMap> mapping(
      indexType.getRank(), AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
                                          indSliceExpr, src.getContext()));
  // Mapping for updates
  mapping.push_back(rewriter.getDimIdentityMap());
  SmallVector<utils::IteratorType> iteratorTypes(
      {utils::IteratorType::parallel});

  // This function goes over the flattened iteration space of the `indices`
  // and `src`. It will reconstruct the original induction variables based
  // on the current flattened index. The flattened iteration space is required
  // because TMTensorScatterOp expects a list of single element updates.
  auto flattenedUpdates =
      rewriter
          .create<linalg::GenericOp>(
              loc, outputsType, ValueRange(), outputs, mapping, iteratorTypes,
              [&](OpBuilder &b, Location loc, ValueRange args) {
                SmallVector<Value> indexValues(indexType.getRank());
                Value ind = b.create<linalg::IndexOp>(loc, 0);
                // Peel coordinates off the flat index, innermost dimension
                // first.
                for (int i = indexType.getRank() - 1; i >= 0; i--) {
                  indexValues[i] =
                      b.create<arith::RemSIOp>(loc, ind, indexShape[i]);
                  ind = b.create<arith::DivSIOp>(loc, ind, indexShape[i]);
                }
                // Extract the scatter index and update value
                Value extractIndexValue =
                    b.create<tensor::ExtractOp>(loc, indices, indexValues);
                Value extractSrcValue =
                    b.create<tensor::ExtractOp>(loc, src, indexValues);
                SmallVector<Value> yieldVals;
                for (Value v : indexValues) {
                  Value scalar = castIndexToInt64(b, loc, v);
                  yieldVals.push_back(b.create<arith::TruncIOp>(
                      loc, rewriter.getI32Type(), scalar));
                }
                // Replace the original index with the index specified
                // by the scatter. `dim` was normalized above, so this access
                // is in bounds.
                yieldVals[dim] = b.create<arith::TruncIOp>(
                    loc, rewriter.getI32Type(), extractIndexValue);
                yieldVals.push_back(extractSrcValue);
                b.create<linalg::YieldOp>(loc, yieldVals);
              })
          .getResultTensors();

  // Turns a Value defined by a constant-index op into a static attribute so
  // that the insert_slice below gets static offsets/sizes/strides when it can.
  auto toOpFoldResult = [](Value v) -> OpFoldResult {
    auto op = v.getDefiningOp<arith::ConstantIndexOp>();
    if (!op)
      return v;
    return op.getValue();
  };

  // The result of the linalg::Generic operation gives us (rank(`src`) + 1)
  // 1D-tensors where each contains a number of elements equal to the total
  // number of elements in the `src` tensor. The indices must now be
  // constructed by concatenating the first rank(`src`) tensors together. The
  // new `src` tensor is the last tensor returned from the linalg::Generic
  // operation.
  SmallVector<Value> offsets = {
      rewriter.create<arith::ConstantIndexOp>(loc, 0),
      rewriter.create<arith::ConstantIndexOp>(loc, 0)};
  SmallVector<Value> strides = {
      rewriter.create<arith::ConstantIndexOp>(loc, 1),
      rewriter.create<arith::ConstantIndexOp>(loc, 1)};
  Value indicesRank =
      rewriter.create<arith::ConstantIndexOp>(loc, indexType.getRank());
  Value flattenedIndices = createZeroInitTensor(
      rewriter, loc, SmallVector<Value>({indexSize, indicesRank}),
      rewriter.getI32Type());
  SmallVector<Value> scatterInputsVector(flattenedUpdates);
  // Every result except the last one is a column of coordinates.
  for (auto const slice : ArrayRef(scatterInputsVector).drop_back()) {
    SmallVector<Value> sizes = getTensorSizes(rewriter, loc, slice);
    flattenedIndices = rewriter.createOrFold<tensor::InsertSliceOp>(
        loc, slice, flattenedIndices,
        llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)),
        llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)),
        llvm::to_vector(llvm::map_range(strides, toOpFoldResult)));
    // Increment offset to insert into next column
    offsets[1] = rewriter.createOrFold<arith::AddIOp>(loc, offsets[1], cstOne);
  }

  return std::make_pair(flattenedIndices,
                        scatterInputsVector[indexType.getRank()]);
}
static Value createTMTensorScatterOp(
OpBuilder &b, Location loc, Value updates, Value indices, Value original,
bool uniqueIndices,
@ -142,7 +287,7 @@ public:
// Finding the maximum value in the input tensor.
SmallVector<int64_t> maxTensorSizes;
ValueTensorType maxTensorType = ValueTensorType::get(
context, llvm::makeArrayRef(maxTensorSizes),
context, llvm::ArrayRef(maxTensorSizes),
torchTypeInput.getType().cast<ValueTensorType>().getDtype());
Value maxTensor =
rewriter.create<AtenMaxOp>(loc, maxTensorType, torchTypeInput);
@ -165,7 +310,7 @@ public:
SmallVector<int64_t> expandedInputSizes{
makeShapeTorchCompatible(inputType.getShape())[0], 1};
ValueTensorType expandInputType = ValueTensorType::get(
context, llvm::makeArrayRef(expandedInputSizes),
context, llvm::ArrayRef(expandedInputSizes),
torchTypeInput.getType().cast<ValueTensorType>().getDtype());
Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
@ -286,9 +431,9 @@ public:
auto indexTensorType = indexTensor.getType().cast<BaseTensorType>();
int64_t indexTensorSize = indexTensorType.getSizes()[0];
SmallVector<int64_t> expandedIndexTensorSizes{indexTensorSize, 1};
ValueTensorType expandedIndexTensorType = ValueTensorType::get(
context, llvm::makeArrayRef(expandedIndexTensorSizes),
indexTensorType.getDtype());
ValueTensorType expandedIndexTensorType =
ValueTensorType::get(context, llvm::ArrayRef(expandedIndexTensorSizes),
indexTensorType.getDtype());
Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
Value expandedIndexTensor = rewriter.create<AtenUnsqueezeOp>(
@ -552,6 +697,229 @@ public:
};
} // namespace
namespace {
// Converts `AtenScatterReduceTwoOp` into a sequence of TMTensor scatter ops.
//
// The pattern requires `reduce`, `dim` and `include_self` to be constants and
// `self`, `index` and `src` to have the same rank. It handles the SUM, PROD,
// MEAN, MAX and MIN members of `torch_upstream::ReductionType`:
//   1. `index`/`src` are flattened into the (indices, updates) form expected
//      by the scatter op.
//   2. When `include_self` is false, the positions of `self` that will be
//      scattered into are first overwritten with the identity value of the
//      reduction so the original contents do not take part in it.
//   3. The updates are scattered with the reduction as the combiner.
//   4. For MEAN, a parallel `counts` tensor tallies the contributions per
//      position and the scattered sum is divided by it element-wise.
class ConvertAtenScatterReduceTwoOp
    : public OpConversionPattern<AtenScatterReduceTwoOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenScatterReduceTwoOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();

    Location loc = op.getLoc();

    RankedTensorType selfType =
        adaptor.getSelf().getType().cast<RankedTensorType>();
    RankedTensorType indexType =
        adaptor.getIndex().getType().cast<RankedTensorType>();
    RankedTensorType srcType =
        adaptor.getSrc().getType().cast<RankedTensorType>();

    Value self = adaptor.getSelf();

    if (selfType.getRank() != indexType.getRank() ||
        indexType.getRank() != srcType.getRank())
      return rewriter.notifyMatchFailure(op,
                                         "'self', 'index' and 'src' should all "
                                         "have the same number of dimensions.");

    std::string reduceType;
    if (!matchPattern(op.getReduce(), m_TorchConstantStr(reduceType)))
      return rewriter.notifyMatchFailure(op,
                                         "'reduce' must be a constant string");

    int64_t dim;
    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
      return rewriter.notifyMatchFailure(op, "'dim' is not constant");

    // `dim` is later used as a position inside a per-element coordinate
    // vector of length rank(`index`). Wrap negative values and reject anything
    // that is still out of range so the lowering fails gracefully instead of
    // indexing out of bounds.
    int64_t indexRank = indexType.getRank();
    if (dim < 0)
      dim += indexRank;
    if (dim < 0 || dim >= indexRank)
      return rewriter.notifyMatchFailure(op, "'dim' is out of range");

    bool includeSelf;
    if (!matchPattern(op.getIncludeSelf(), m_TorchConstantBool(&includeSelf)))
      return rewriter.notifyMatchFailure(op, "'include_self' is not constant");

    // Get reduce string as the equivalent enum
    auto reduceEnum = torch_upstream::get_reduction_enum(reduceType);

    // Get the inputs reformatted for the TMScatterOp
    auto [indices, updates] =
        convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc(
            rewriter, adaptor.getIndex(), adaptor.getSrc(), dim);

    // Value 'counts' will be used to tally the number of reductions into
    // each unique index. The tally is used to calculate the average of the
    // values scattered per index.
    Value counts = nullptr;
    if (reduceEnum == torch_upstream::ReductionType::MEAN) {
      SmallVector<Value> selfShape =
          getTensorSizes(rewriter, loc, adaptor.getSelf());
      // Every position starts with a tally of one.
      Attribute initAttr;
      if (llvm::isa<mlir::FloatType>(srcType.getElementType())) {
        initAttr = rewriter.getFloatAttr(srcType.getElementType(), 1);
      } else if (llvm::isa<mlir::IntegerType>(srcType.getElementType())) {
        initAttr = rewriter.getIntegerAttr(srcType.getElementType(), 1);
      } else {
        llvm_unreachable("Only integer/float types supported!");
      }
      Value initElement = rewriter.create<arith::ConstantOp>(loc, initAttr);
      counts = createInitTensor(rewriter, loc, selfShape,
                                selfType.getElementType(), initElement);
    }

    // If the original values shouldn't be included, normalize the
    // input tensor where the scatters take place.
    if (!includeSelf) {
      Value normalizationValue;
      if (reduceEnum == torch_upstream::ReductionType::SUM ||
          reduceEnum == torch_upstream::ReductionType::MEAN) {
        // Set the values in the input tensor to '0' so they are not included
        normalizationValue = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getZeroAttr(srcType.getElementType()));
      } else if (reduceEnum == torch_upstream::ReductionType::PROD) {
        // Set the values in the input tensor to '1' (multiplication identity)
        if (llvm::isa<mlir::FloatType>(srcType.getElementType())) {
          normalizationValue = rewriter.create<arith::ConstantOp>(
              loc, rewriter.getFloatAttr(srcType.getElementType(), 1.0));
        } else if (llvm::isa<mlir::IntegerType>(srcType.getElementType())) {
          normalizationValue = rewriter.create<arith::ConstantOp>(
              loc, rewriter.getIntegerAttr(srcType.getElementType(), 1));
        } else {
          llvm_unreachable("Only integer/float types supported!");
        }
      } else if (reduceEnum == torch_upstream::ReductionType::MAX) {
        // Set the values in the input tensor to the smallest element of that
        // type
        auto minAttr = getNumericLimit(rewriter, srcType.getElementType(),
                                       /*getMin=*/true);
        normalizationValue = rewriter.create<arith::ConstantOp>(loc, minAttr);
      } else if (reduceEnum == torch_upstream::ReductionType::MIN) {
        // Set the values in the input tensor to the largest element of that
        // type
        auto maxAttr = getNumericLimit(rewriter, srcType.getElementType(),
                                       /*getMin=*/false);
        normalizationValue = rewriter.create<arith::ConstantOp>(loc, maxAttr);
      }

      // Scatter the normalizations into the input tensor
      Value indexSize = getTensorSize(rewriter, loc, adaptor.getIndex());
      indexSize = castIntToIndex(rewriter, loc, indexSize);
      Value normalizations = createInitTensor(
          rewriter, loc, SmallVector<Value>({indexSize}),
          srcType.getElementType(), /*init_element=*/normalizationValue);
      // The combiner ignores the current value: scattered positions are simply
      // overwritten with the normalization value.
      self = createTMTensorScatterOp(
          rewriter, loc, normalizations, indices, self,
          /*uniqueIndices=*/false,
          [&](OpBuilder &b, Location loc, Value update, Value current) {
            b.create<TMTensor::YieldOp>(loc, update);
          });
      if (reduceEnum == torch_upstream::ReductionType::MEAN) {
        // For MEAN the normalization value is zero, so this resets the tally
        // at every scattered position.
        counts = createTMTensorScatterOp(
            rewriter, loc, normalizations, indices, counts,
            /*uniqueIndices=*/false,
            [&](OpBuilder &b, Location loc, Value update, Value current) {
              b.create<TMTensor::YieldOp>(loc, update);
            });
      }
    }

    // Create final operation
    Value scatterOp = createTMTensorScatterOp(
        rewriter, loc, updates, indices, self,
        /*uniqueIndices=*/false,
        [&](OpBuilder &b, Location loc, Value update, Value current) {
          Value result;
          if (reduceEnum == torch_upstream::ReductionType::SUM ||
              reduceEnum == torch_upstream::ReductionType::MEAN) {
            if (update.getType().isa<mlir::IntegerType>()) {
              result = b.create<arith::AddIOp>(loc, update, current);
            } else if (update.getType().isa<mlir::FloatType>()) {
              result = b.create<arith::AddFOp>(loc, update, current);
            } else {
              llvm_unreachable("Only integer/float types supported!");
            }
          } else if (reduceEnum == torch_upstream::ReductionType::PROD) {
            if (update.getType().isa<mlir::IntegerType>()) {
              result = b.create<arith::MulIOp>(loc, update, current);
            } else if (update.getType().isa<mlir::FloatType>()) {
              result = b.create<arith::MulFOp>(loc, update, current);
            } else {
              llvm_unreachable("Only integer/float types supported!");
            }
          } else if (reduceEnum == torch_upstream::ReductionType::MAX) {
            if (update.getType().isa<mlir::IntegerType>()) {
              result = b.create<arith::MaxSIOp>(loc, update, current);
            } else if (update.getType().isa<mlir::FloatType>()) {
              result = b.create<arith::MaxFOp>(loc, update, current);
            } else {
              llvm_unreachable("Only integer/float types supported!");
            }
          } else if (reduceEnum == torch_upstream::ReductionType::MIN) {
            if (update.getType().isa<mlir::IntegerType>()) {
              result = b.create<arith::MinSIOp>(loc, update, current);
            } else if (update.getType().isa<mlir::FloatType>()) {
              result = b.create<arith::MinFOp>(loc, update, current);
            } else {
              llvm_unreachable("Only integer/float types supported!");
            }
          }
          b.create<TMTensor::YieldOp>(loc, result);
        });

    // Special case for the mean
    if (reduceEnum == torch_upstream::ReductionType::MEAN) {
      // Add one to the tally for every update that lands on a position; the
      // update value itself is not used here.
      counts = createTMTensorScatterOp(
          rewriter, loc, updates, indices, counts,
          /*uniqueIndices=*/false,
          [&](OpBuilder &b, Location loc, Value update, Value current) {
            Value result;
            if (mlir::IntegerType intType =
                    llvm::dyn_cast<mlir::IntegerType>(current.getType())) {
              Value constantUpdate = b.create<arith::ConstantOp>(
                  loc, b.getIntegerAttr(intType, 1));
              result = b.create<arith::AddIOp>(loc, constantUpdate, current);
            } else if (mlir::FloatType floatType =
                           llvm::dyn_cast<mlir::FloatType>(current.getType())) {
              Value constantUpdate = b.create<arith::ConstantOp>(
                  loc, b.getFloatAttr(floatType, 1.0));
              result = b.create<arith::AddFOp>(loc, constantUpdate, current);
            } else {
              llvm_unreachable("Only integer/float types supported!");
            }
            b.create<TMTensor::YieldOp>(loc, result);
          });

      Value output = rewriter.create<tensor::EmptyOp>(
          loc, tensor::getMixedSizes(rewriter, loc, self),
          selfType.getElementType());

      // Finally divide the result
      scatterOp =
          rewriter
              .create<linalg::MapOp>(
                  loc, ValueRange{scatterOp, counts}, output,
                  [&](OpBuilder &b, Location loc, ValueRange args) {
                    Value result;
                    if (llvm::isa<mlir::IntegerType>(args[0].getType())) {
                      result = b.create<arith::DivSIOp>(loc, args[0], args[1]);
                    } else if (llvm::isa<mlir::FloatType>(args[0].getType())) {
                      result = b.create<arith::DivFOp>(loc, args[0], args[1]);
                    } else {
                      llvm_unreachable("Only integer/float types supported!");
                    }
                    b.create<linalg::YieldOp>(loc, result);
                  })
              .getResult()[0];
    }

    // Cast to the converted result type of the original op.
    auto resultType = getTypeConverter()
                          ->convertType(op->getResult(0).getType())
                          .cast<RankedTensorType>();
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, scatterOp);

    return success();
  }
};
} // namespace
namespace {
class ConvertAtenCumsumOp : public OpConversionPattern<AtenCumsumOp> {
public:
@ -644,6 +1012,8 @@ public:
target.addIllegalOp<AtenMaxPool2dWithIndicesBackwardOp>();
patterns.add<ConvertAtenMaxPool2dWithIndicesBackwardOp>(typeConverter,
context);
target.addIllegalOp<AtenScatterReduceTwoOp>();
patterns.add<ConvertAtenScatterReduceTwoOp>(typeConverter, context);
target.addIllegalOp<AtenCumsumOp>();
patterns.add<ConvertAtenCumsumOp>(typeConverter, context);

View File

@ -10,6 +10,7 @@
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "../PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@ -717,8 +718,8 @@ class ConvertAtenMultipleDimsReductionOp
"non-const dim parameter unsupported");
int64_t N = reduceDims.size();
auto reduceDimsType = RankedTensorType::get({N}, rewriter.getI64Type());
reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType,
llvm::makeArrayRef(reduceDims));
reduceDimsAttr =
DenseIntElementsAttr::get(reduceDimsType, llvm::ArrayRef(reduceDims));
keepDims = false;
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDims)))
@ -747,8 +748,8 @@ class ConvertAtenOneDimReductionOp
return rewriter.notifyMatchFailure(op,
"non-const dim parameter unsupported");
auto reduceDimsType = RankedTensorType::get({1}, rewriter.getI64Type());
reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType,
llvm::makeArrayRef({reduceDim}));
reduceDimsAttr =
DenseIntElementsAttr::get(reduceDimsType, llvm::ArrayRef({reduceDim}));
keepDims = false;
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDims)))
@ -781,8 +782,8 @@ public:
reduceDims.push_back(i);
int64_t N = selfTy.getRank();
auto reduceDimsType = RankedTensorType::get({N}, rewriter.getI64Type());
reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType,
llvm::makeArrayRef(reduceDims));
reduceDimsAttr =
DenseIntElementsAttr::get(reduceDimsType, llvm::ArrayRef(reduceDims));
keepDims = false;
return success();
@ -2645,6 +2646,36 @@ LogicalResult ConvertAtenOp<AtenViewOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(op,
"size must consist of Scalar constants");
// the shape -1 is inferred from other dimensions
size_t countNegativeShape{0};
// Check at most one -1 shape
for (size_t i = 0; i < outShape.size(); i++) {
if (outShape[i] < 0) {
countNegativeShape++;
if (countNegativeShape > 1)
return rewriter.notifyMatchFailure(op, "At most one -1 shape");
}
}
auto inputShape = selfType.getShape();
size_t totalSize = 1;
for (size_t i = 0; i < inputShape.size(); i++) {
totalSize *= inputShape[i];
}
size_t otherSize = 1;
for (size_t i = 0; i < outShape.size(); i++) {
if (outShape[i] > 0) {
otherSize *= outShape[i];
}
}
for (size_t i = 0; i < outShape.size(); i++) {
if (outShape[i] < 0) {
outShape[i] = totalSize / otherSize;
break;
}
}
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(),
rewriter.getDenseI64ArrayAttr(outShape));
@ -2816,6 +2847,79 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
return success();
}
// Legalizes `AtenHardtanhBackwardOp` to TOSA.
//
// The result is `grad_output` everywhere except where `self` lies strictly
// below `min_val` or strictly above `max_val`; those positions are replaced
// by a zero constant. The lowering is built from two `tosa.greater`
// comparisons, a `tosa.logical_or` and a `tosa.select`.
//
// Match failures are reported when `self` or `grad_output` is not a tensor,
// when the element type is not int/float, when it is an integer wider than 32
// bits, when the two element types differ, or when `min_val`/`max_val` cannot
// be turned into TOSA constant tensors.
template <>
LogicalResult ConvertAtenOp<AtenHardtanhBackwardOp>::matchAndRewrite(
    AtenHardtanhBackwardOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  // Not a tensor type.
  auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
  if (!selfType) {
    return rewriter.notifyMatchFailure(
        op, "Only tensor types are currently supported");
  }

  auto selfElemTy = selfType.getElementType();
  if (!selfElemTy.isIntOrFloat()) {
    return rewriter.notifyMatchFailure(
        op, "Only floating-point or integer datatype legalization supported");
  }

  // Integer types with width > 32 are not supported
  auto selfIntType = selfElemTy.dyn_cast<IntegerType>();
  if (selfIntType && selfIntType.getWidth() > 32) {
    return rewriter.notifyMatchFailure(
        op, "Integer types with width greater than 32 are not supported");
  }

  // Take the type from the grad_output operand itself (not from `self`) so
  // that the element-type comparison below actually checks something, and
  // bail out instead of dereferencing a null type if it is not a tensor.
  Value gradOutput = adaptor.getGradOutput();
  auto gradOutputType = gradOutput.getType().dyn_cast<TensorType>();
  if (!gradOutputType) {
    return rewriter.notifyMatchFailure(
        op, "Only tensor types are currently supported");
  }

  Type gradOutputElemType = gradOutputType.getElementType();

  if (selfElemTy != gradOutputElemType) {
    return rewriter.notifyMatchFailure(
        op,
        "Input element type should be same as the grad_output element type.");
  }

  // The scalar bounds are materialized as tensors with the same rank as
  // `self` and every dimension equal to 1.
  SmallVector<int64_t> constTypeShape(selfType.getRank(), 1);
  Value maxVal, minVal;

  if (failed(torchScalarToTosaTensor(rewriter, op, op.getMinVal(), minVal,
                                     selfElemTy, constTypeShape))) {
    return rewriter.notifyMatchFailure(op, "Only scalar constant is supported");
  }

  if (failed(torchScalarToTosaTensor(rewriter, op, op.getMaxVal(), maxVal,
                                     selfElemTy, constTypeShape))) {
    return rewriter.notifyMatchFailure(op, "Only scalar constant is supported");
  }

  // NOTE(review): the replacement constant is always built as `float`, even
  // though integer element types pass the checks above — confirm this is
  // intended for integer inputs.
  Value replace = tosa::getConstTensor<float>(rewriter, op, 0, {}).value();
  Type outType = getTypeConverter()->convertType(op.getType());

  // lesser: min_val > self
  Value lesser = rewriter.create<tosa::GreaterOp>(
      op.getLoc(),
      RankedTensorType::get(selfType.getShape(), rewriter.getIntegerType(1)),
      minVal, adaptor.getSelf());

  // greater: self > max_val
  Value greater = rewriter.create<tosa::GreaterOp>(
      op.getLoc(),
      RankedTensorType::get(selfType.getShape(), rewriter.getIntegerType(1)),
      adaptor.getSelf(), maxVal);

  // cmp: self is outside [min_val, max_val]
  Value cmp = rewriter.create<tosa::LogicalOrOp>(
      op.getLoc(),
      RankedTensorType::get(selfType.getShape(), rewriter.getIntegerType(1)),
      lesser, greater);

  // Outside the range pick the zero constant, otherwise pass grad_output
  // through.
  rewriter.replaceOpWithNewOp<tosa::SelectOp>(op, outType, cmp, replace,
                                              gradOutput);

  return success();
}
template <>
LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
AtenEmbeddingOp op, OpAdaptor adaptor,
@ -3113,31 +3217,70 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
op, "Only floating-point or integer datatype legalization supported");
}
SmallVector<int64_t> outShape;
if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(outShape)))
SmallVector<int64_t> resultShape;
if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(resultShape)))
return rewriter.notifyMatchFailure(op,
"size must consist of Scalar constants");
// Get the result type
auto resultType = getTypeConverter()->convertType(op.getType());
SmallVector<int64_t> inputShape(
makeShapeTorchCompatible(selfType.getShape()));
if (inputShape.size() == outShape.size() || inputShape.size() == 0) {
// Check for identity case i.e, for ex: [a, b, c] -> [a, b, c]. If this is
// true then we can replace the op result with the input operand
// irrespective of the users of the op result.
if (!llvm::equal(inputShape, outShape)) {
for (auto user : op->getResult(0).getUsers()) {
// This case is only supported if the result of the `broadcast_to` op is
// not used by an op which is a view like.
if (isViewLikeOp(user)) {
// Check for identity case i.e, for ex: [a, b, c] -> [a, b, c]. If this is
// true then we can replace the op result with the input operand directly.
if (llvm::equal(inputShape, resultShape)) {
// If we reach here, then it means that the broadcasting is not required
// since the input and result are of same shape.
op.replaceAllUsesWith(op.getSelf());
rewriter.eraseOp(op);
return success();
} else if (selfType.hasRank() &&
(selfType.getRank() == (int64_t)resultShape.size() ||
selfType.getRank() == 0)) {
// Right now to support limited cases where input and result shape are not
// equal, we can put a constraint that either the input should be of rank
// 0 or the rank of input tensor and result should be equal. And then we
// can check for broadcasting compatibility for the latter case. For
// broadcasting compatibility, either the shape of input and result should
// be equal at each dimension or one of them should be 1.
if (selfType.getRank() != 0) {
for (unsigned i = 0; i < inputShape.size(); i++) {
if (inputShape[i] != resultShape[i] && inputShape[i] != 1 &&
resultShape[i] != 1) {
return rewriter.notifyMatchFailure(
op, "unimplemented: broadcast not supported for this case");
op, "unimplemented: either the shape of input and result should "
"be equal at each dimenion or one of them should be 1.");
}
}
}
// If we reach here, then it means the given case is handled by implicit
// broadcasting done by tosa.
op.replaceAllUsesWith(op.getSelf());
rewriter.eraseOp(op);
// If the above condition hold true then we can directly create a const
// zero tensor of shape same as the result shape.
SmallVector<int64_t> zeroTensorShape{resultShape};
// create the 0 constant tensor
int64_t totalNumElements = 1;
for (auto dimSize : zeroTensorShape) {
totalNumElements = dimSize * totalNumElements;
}
// There is some danger here. For edge cases in floating point, x + 0 != x.
// The cases are denormalized values, which may get flushed, and -0 + 0 =
// +0. (sign bit flips). These are probably acceptable in the short term,
// but we should put a comment acknowledging the danger, as there isn't an
// op that avoids the denorm flushing.
SmallVector<int64_t> intValues(totalNumElements, 0);
SmallVector<float> floatValues(totalNumElements, 0.0);
Value zeroTensor = selfType.getElementType().isa<mlir::FloatType>()
? tosa::getConstTensor<float>(
rewriter, op, floatValues, zeroTensorShape)
.value()
: tosa::getConstTensor<int64_t>(
rewriter, op, intValues, zeroTensorShape)
.value();
// Use add broadcast
rewriter.replaceOpWithNewOp<tosa::AddOp>(op, resultType, adaptor.getSelf(),
zeroTensor);
return success();
}
return rewriter.notifyMatchFailure(
@ -3232,6 +3375,171 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
return success();
}
// Lowers `aten.index.Tensor` to TOSA.
//
// Every index tensor is cast to i32 (if it is not already), reshaped so that it
// gains a trailing dimension of size 1, and -- when more than one index tensor
// is given -- concatenated along that trailing dimension. The resulting
// "tf.gather_nd style" indices tensor is handed to tosa::convertGatherNdOp,
// whose result replaces the original op.
//
// The match fails (without crashing) when:
//   * the input is not a ranked tensor,
//   * the index list does not come from a list construct or is empty,
//   * any index is `None` or does not convert to a ranked tensor,
//   * multiple index tensors do not all share the same shape,
//   * building the indices tensor or the gather itself fails.
template <>
LogicalResult ConvertAtenOp<AtenIndexTensorOp>::matchAndRewrite(
    AtenIndexTensorOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // t        = tf.constant([[1, 2, 3, 4, 5],[6,7,8,9,10],
  //                         [11,12,13,14,15],[16,17,18,19,20]]) # 4*5
  // i        = tf.constant([[1,2,3], [3,2,1]]) # 2*3
  // i_expand = tf.expand_dims(i,axis=2) # 2*3*1
  // IndexTensorOutput = tf.gather_nd(t,tf.i_expand)
  //                   = torch.ops.aten.index(t, (i, )) = t[i] # 2*3*5
  // [[[ 6,  7,  8,  9, 10], [11, 12, 13, 14, 15], [16, 17, 18, 19, 20]],
  //  [[16, 17, 18, 19, 20], [11, 12, 13, 14, 15], [ 6,  7,  8,  9, 10]]]
  auto input = adaptor.getSelf();
  auto inputTensorType =
      adaptor.getSelf().getType().dyn_cast<RankedTensorType>();
  // Check input is a tensor type.
  if (!inputTensorType)
    return rewriter.notifyMatchFailure(
        op, "Only tensor types input are currently supported");

  // Deal with torch.prim.ListConstruct of non const value to get the index
  auto tensorList = op.getIndices();
  SmallVector<Value> tensorsTorchType;
  if (!getListConstructElements(tensorList, tensorsTorchType))
    return op.emitError(
        "unimplemented: the tensor list is not from list construct");

  auto indexTensors = getTypeConvertedValues(
      rewriter, op->getLoc(), getTypeConverter(), tensorsTorchType);

  // The code below unconditionally reads element 0 of the index list, so an
  // empty list has to be rejected up front.
  if (indexTensors.empty())
    return rewriter.notifyMatchFailure(
        op, "unimplemented: empty index list is not supported");

  auto outType = getTypeConverter()->convertType(op.getType());

  // Support for multiple indexes
  if (indexTensors.size() > 1) {
    // t[i, i]
    // = torch.ops.aten.index(t,(i,i))
    // = tensor([[ t[1,1], t[2,2], t[3,3]],
    //           [ t[3,3], t[2,2], t[1,1]]])
    // = tensor([[ 7, 13, 19], [19, 13, 7]])
    // = tf.gather_nd(t,tf.ii_expand)
    // ii_expand
    // = tf.concat((i_expand,i_expand), dim=2)
    // = tf.constant([[[1,1],[2,2],[3,3]],
    //                [[3,3],[2,2],[1,1]]]) # 2*3*2
    SmallVector<Value> indicesTfConcatTensors;
    SmallVector<int64_t> indexesRank;
    SmallVector<SmallVector<int64_t>> indexesShape;

    // Turn every index tensor into an i32 tensor with a trailing unit dim so
    // that they can be concatenated into a single indices tensor.
    for (size_t i = 0; i < indexTensors.size(); i++) {
      auto index = indexTensors[i];
      auto indexTorch = tensorsTorchType[i];
      // TODO add support for none index input like torch.ops.aten.index(x,
      // (None, index1, index2, None))
      if (indexTorch.getType().isa<Torch::NoneType>())
        return rewriter.notifyMatchFailure(
            op, "Only list ranked tensor types index are supported");

      auto indexType = index.getType().dyn_cast<RankedTensorType>();
      // dyn_cast yields a null type for anything that is not a ranked tensor;
      // bail out instead of dereferencing it.
      if (!indexType)
        return rewriter.notifyMatchFailure(
            op, "Only list ranked tensor types index are supported");
      auto indexShape = indexType.getShape();
      indexesShape.push_back(makeShapeTorchCompatible(indexShape));
      indexesRank.push_back(indexType.getRank());

      // index i64 to i32 for tosa compatible
      if (indexType.getElementType() != rewriter.getIntegerType(32)) {
        index = rewriter.create<tosa::CastOp>(
            op->getLoc(),
            RankedTensorType::get(indexShape, rewriter.getIntegerType(32)),
            index);
      }

      // Expand last dim of index to tf indices [2,3] -> [2,3,1]
      SmallVector<int64_t> indiceShapeOneDim;
      for (auto shape : indexShape) {
        indiceShapeOneDim.push_back(shape);
      }
      indiceShapeOneDim.push_back(1);
      auto indicesTfOneDim = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
          rewriter, op->getLoc(),
          RankedTensorType::get(indiceShapeOneDim, rewriter.getIntegerType(32)),
          index, rewriter.getDenseI64ArrayAttr(indiceShapeOneDim));

      // create concat tensor for indicesTf
      indicesTfConcatTensors.push_back(indicesTfOneDim.getResult());
    }

    // Right now only support multiple indexes with same shape
    // TODO for different shape multiple indexes, add broadcast_to for small
    // shape
    for (auto indexShapeOneDim : indexesShape) {
      if (!llvm::equal(indexesShape[0], indexShapeOneDim)) {
        return rewriter.notifyMatchFailure(
            op, "unimplemented: Only support multi indexes with same shape");
      }
    }

    // concat each indices into indicesTf: shape [2,3,1],[2,3,1] -> [2,3,2]
    auto indicesShapeConcat = indexesShape[0];
    // The concat axis is the trailing unit dim added above, whose position
    // equals the rank of the original index tensors.
    uint64_t lastDim = indexesRank[0];
    indicesShapeConcat.push_back(indicesTfConcatTensors.size());
    auto indicesTf = tosa::CreateOpAndInfer<tosa::ConcatOp>(
        rewriter, op->getLoc(),
        GetTypeFromTensorShape(indicesShapeConcat, rewriter.getIntegerType(32)),
        indicesTfConcatTensors, lastDim);

    if (!indicesTf) {
      return rewriter.notifyMatchFailure(
          op, "Convert TorchIndex To TfIndices fail.");
    }
    // Run the tf.gather_nd style lowering with the tf style indices as input.
    auto result = tosa::convertGatherNdOp(rewriter, op, outType, input,
                                          indicesTf.getResult());

    if (!result) {
      return rewriter.notifyMatchFailure(
          op, "Convert GatherNdOp fail for index tensor.");
    }
    rewriter.replaceOp(op, {result.value()});

    return success();
  }

  // Support for a single index tensor
  auto index = indexTensors[0];
  auto indexTorch = tensorsTorchType[0];
  // TODO add support for none index input like torch.ops.aten.index(x,
  // (None, index1, index2, None))
  if (indexTorch.getType().isa<Torch::NoneType>())
    return rewriter.notifyMatchFailure(
        op, "Only list ranked tensor types index are supported");

  auto indexType = index.getType().dyn_cast<RankedTensorType>();
  // Same null guard as on the multi-index path.
  if (!indexType)
    return rewriter.notifyMatchFailure(
        op, "Only list ranked tensor types index are supported");
  auto indexShape = indexType.getShape();

  // index i64 to i32 for tosa compatible
  if (indexType.getElementType() != rewriter.getIntegerType(32)) {
    index = rewriter.create<tosa::CastOp>(
        op->getLoc(),
        RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index);
  }

  // Expand last dim of index to tf indices [2,3] -> [2,3,1]
  SmallVector<int64_t> indicesShape;
  for (auto shape : indexShape) {
    indicesShape.push_back(shape);
  }
  indicesShape.push_back(1);
  auto indicesTf = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
      rewriter, op->getLoc(),
      RankedTensorType::get(indicesShape, rewriter.getIntegerType(32)), index,
      rewriter.getDenseI64ArrayAttr(indicesShape));

  if (!indicesTf) {
    return rewriter.notifyMatchFailure(op,
                                       "Convert TorchIndex To TfIndices fail.");
  }
  // Run the tf.gather_nd style lowering with the tf style indices as input.
  auto result = tosa::convertGatherNdOp(rewriter, op, outType, input,
                                        indicesTf.getResult());

  if (!result) {
    return rewriter.notifyMatchFailure(
        op, "Convert GatherNdOp fail for index tensor.");
  }
  rewriter.replaceOp(op, {result.value()});

  return success();
}
template <>
LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
AtenWhereSelfOp op, OpAdaptor adaptor,
@ -3968,9 +4276,11 @@ public:
if (!op.getMemoryFormat().getType().template isa<Torch::NoneType>() &&
(!matchPattern(op.getMemoryFormat(),
m_TorchConstantInt(&memoryFormat)) ||
memoryFormat != torch_upstream::MemoryFormat::Contiguous)) {
(memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
memoryFormat != torch_upstream::MemoryFormat::ChannelsLast))) {
return op.emitError(
"unimplemented: only default memory format is supported");
"unimplemented: only contiguous and channels last memory "
"format is supported");
}
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
@ -4169,6 +4479,7 @@ public:
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
INSERT_ATENOP_PATTERN(AtenTanhOp);
INSERT_ATENOP_PATTERN(AtenHardtanhBackwardOp);
INSERT_ATENOP_PATTERN(AtenSigmoidOp);
INSERT_ATENOP_PATTERN(AtenReluOp);
INSERT_ATENOP_PATTERN(AtenLeakyReluOp);
@ -4196,6 +4507,7 @@ public:
INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
INSERT_ATENOP_PATTERN(AtenGatherOp);
INSERT_ATENOP_PATTERN(AtenIndexTensorOp);
INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
INSERT_ATENOP_PATTERN(AtenClampOp);
INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);

View File

@ -230,6 +230,10 @@ static LogicalResult checkValidityOfCast(Type src, Type dest) {
(src.isInteger(32) && dest.isInteger(1)) ||
(src.isInteger(32) && dest.isF32()) ||
(src.isInteger(8) && dest.isInteger(1)) ||
(src.isInteger(1) && dest.isInteger(64)) ||
(src.isInteger(1) && dest.isF32()) ||
(src.isF32() && dest.isF64()) ||
(src.isF64() && dest.isF32()) ||
(src.isF32() && dest.isInteger(8)) ||
(src.isF32() && dest.isInteger(1))) {
return success();

View File

@ -11,6 +11,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Transforms/InliningUtils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
@ -31,11 +32,11 @@ namespace {
struct TorchInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
BlockAndValueMapping &valueMapping) const final {
IRMapping &valueMapping) const final {
return true;
}
bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
BlockAndValueMapping &) const final {
IRMapping &) const final {
return true;
}
};

View File

@ -128,32 +128,36 @@ static FloatAttr getF64FloatAttr(MLIRContext *context, double value) {
return FloatAttr::get(Float64Type::get(context), value);
}
static Value getScalarValue(Value input, Location loc,
PatternRewriter &rewriter) {
static Value getScalarIntValue(Value input, Location loc,
PatternRewriter &rewriter) {
auto inputType = input.getType();
if (inputType.isa<Torch::IntType>()) {
return input;
}
Value scalar = nullptr;
auto inputTensorType = inputType.dyn_cast<BaseTensorType>();
if (!inputTensorType)
return nullptr;
Type inputDtype = inputTensorType.getOptionalDtype();
if (!inputDtype || !inputDtype.isInteger(64))
return nullptr;
std::optional<unsigned> inputRank = getTensorRank(input);
if (!inputRank || *inputRank != 0)
return nullptr;
if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
std::optional<unsigned> tensorRank =
getTensorRank(valueTensorLiteralOp.getResult());
if (valueTensorLiteralOp && tensorRank && *tensorRank == 0) {
auto tensorType =
valueTensorLiteralOp.getValue().getType().cast<RankedTensorType>();
if (tensorType.getElementType().isa<mlir::IntegerType>()) {
auto val = valueTensorLiteralOp.getValue()
.cast<DenseElementsAttr>()
.getSplatValue<int64_t>();
scalar = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(val));
}
}
auto val = valueTensorLiteralOp.getValue()
.cast<DenseElementsAttr>()
.getSplatValue<int64_t>();
return rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(val));
} else if (auto primNumToTensorScalarOp =
input.getDefiningOp<PrimNumToTensorScalarOp>()) {
scalar = primNumToTensorScalarOp.getA();
return primNumToTensorScalarOp.getA();
}
return scalar;
return nullptr;
}
//===----------------------------------------------------------------------===//
@ -386,7 +390,7 @@ void PrimIfOp::getSuccessorRegions(std::optional<unsigned> index,
// If the condition is constant, we can give a more precise answer.
if (auto condAttr = operands.front().dyn_cast_or_null<IntegerAttr>()) {
Region *executedRegion =
condAttr.getValue().isOneValue() ? &getThenRegion() : &getElseRegion();
condAttr.getValue().isOne() ? &getThenRegion() : &getElseRegion();
regions.push_back(RegionSuccessor(executedRegion));
return;
}
@ -507,7 +511,7 @@ bool DerefineOp::areCastCompatible(mlir::TypeRange inputs,
return isValidSubtype(inputs[0], outputs[0]);
}
OpFoldResult DerefineOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult DerefineOp::fold(FoldAdaptor adaptor) {
auto uncheckedCast = getOperand().getDefiningOp<PrimUncheckedCastOp>();
if (!uncheckedCast)
return nullptr;
@ -570,10 +574,10 @@ static OpFoldResult atenIsOrIsNotFoldHelper(Operation *op, bool equalIsTrue) {
// Aten__RangeLengthOp
//===----------------------------------------------------------------------===//
OpFoldResult Aten__RangeLengthOp::fold(ArrayRef<Attribute> operands) {
auto lo = operands[0];
auto hi = operands[1];
auto step = operands[2];
OpFoldResult Aten__RangeLengthOp::fold(FoldAdaptor adaptor) {
auto lo = adaptor.getLo();
auto hi = adaptor.getHi();
auto step = adaptor.getStep();
if (!lo || !hi || !step)
return nullptr;
auto loInt = lo.dyn_cast_or_null<IntegerAttr>().getValue();
@ -595,10 +599,10 @@ OpFoldResult Aten__RangeLengthOp::fold(ArrayRef<Attribute> operands) {
// Aten__DeriveIndexOp
//===----------------------------------------------------------------------===//
OpFoldResult Aten__DeriveIndexOp::fold(ArrayRef<Attribute> operands) {
auto index = operands[0];
auto start = operands[1];
auto step = operands[2];
OpFoldResult Aten__DeriveIndexOp::fold(FoldAdaptor adaptor) {
auto index = adaptor.getIndex();
auto start = adaptor.getStart();
auto step = adaptor.getStep();
if (!index || !start || !step)
return nullptr;
auto indexInt = index.dyn_cast_or_null<IntegerAttr>().getValue();
@ -612,7 +616,7 @@ OpFoldResult Aten__DeriveIndexOp::fold(ArrayRef<Attribute> operands) {
// Aten__Is__Op
//===----------------------------------------------------------------------===//
OpFoldResult Aten__Is__Op::fold(ArrayRef<Attribute> operands) {
OpFoldResult Aten__Is__Op::fold(FoldAdaptor adaptor) {
return atenIsOrIsNotFoldHelper(*this, /*equalIsTrue=*/true);
}
@ -620,7 +624,7 @@ OpFoldResult Aten__Is__Op::fold(ArrayRef<Attribute> operands) {
// Aten__Isnot__Op
//===----------------------------------------------------------------------===//
OpFoldResult Aten__Isnot__Op::fold(ArrayRef<Attribute> operands) {
OpFoldResult Aten__Isnot__Op::fold(FoldAdaptor adaptor) {
return atenIsOrIsNotFoldHelper(*this, /*equalIsTrue=*/false);
}
@ -628,7 +632,7 @@ OpFoldResult Aten__Isnot__Op::fold(ArrayRef<Attribute> operands) {
// Aten__Not__Op
//===----------------------------------------------------------------------===//
OpFoldResult Aten__Not__Op::fold(ArrayRef<Attribute> operands) {
OpFoldResult Aten__Not__Op::fold(FoldAdaptor adaptor) {
bool value;
if (!matchPattern(getOperand(), m_TorchConstantBool(&value)))
return nullptr;
@ -639,7 +643,7 @@ OpFoldResult Aten__Not__Op::fold(ArrayRef<Attribute> operands) {
// AtenNeBoolOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenNeBoolOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenNeBoolOp::fold(FoldAdaptor adaptor) {
if (getOperand(0) == getOperand(1))
return IntegerAttr::get(IntegerType::get(getContext(), 1), false);
@ -655,7 +659,7 @@ OpFoldResult AtenNeBoolOp::fold(ArrayRef<Attribute> operands) {
// AtenSqueezeOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenSqueezeOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) {
if (auto tensorType = getOperand().getType().dyn_cast<BaseTensorType>()) {
if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
return getOperand();
@ -667,7 +671,7 @@ OpFoldResult AtenSqueezeOp::fold(ArrayRef<Attribute> operands) {
// AtenSqueezeDimOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenSqueezeDimOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) {
if (auto tensorType = getOperand(0).getType().dyn_cast<BaseTensorType>()) {
if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
return getOperand(0);
@ -679,7 +683,7 @@ OpFoldResult AtenSqueezeDimOp::fold(ArrayRef<Attribute> operands) {
// AtenRoundOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenRoundOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
if (auto selfType = getSelf().getType().dyn_cast<BaseTensorType>()) {
if (selfType.hasDtype() && selfType.getDtype().isa<mlir::IntegerType>())
return getSelf();
@ -691,7 +695,7 @@ OpFoldResult AtenRoundOp::fold(ArrayRef<Attribute> operands) {
// AtenTypeAsOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenTypeAsOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenTypeAsOp::fold(FoldAdaptor adaptor) {
Type inType = getSelf().getType();
Type newType = getOther().getType();
@ -705,7 +709,7 @@ OpFoldResult AtenTypeAsOp::fold(ArrayRef<Attribute> operands) {
// AtenToDtypeOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenToDtypeOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenToDtypeOp::fold(FoldAdaptor adaptor) {
bool nonBlocking, copyArg;
// The non_blocking arg must be `False`.
if (!matchPattern(getNonBlocking(), m_TorchConstantBool(&nonBlocking)) ||
@ -736,7 +740,7 @@ OpFoldResult AtenToDtypeOp::fold(ArrayRef<Attribute> operands) {
// AtenToDtypeLayoutOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenToDtypeLayoutOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenToDtypeLayoutOp::fold(FoldAdaptor adaptor) {
// The pin_memory arg should be either constant `False` or `none`.
if (!getPinMemory().getType().isa<Torch::NoneType>()) {
bool pinMemory;
@ -797,7 +801,7 @@ OpFoldResult AtenToDtypeLayoutOp::fold(ArrayRef<Attribute> operands) {
// AtenViewOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenViewOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) {
auto inputType = getOperand(0).getType().dyn_cast<BaseTensorType>();
if (!inputType || !inputType.hasSizes() || inputType.getSizes().size() != 1)
return nullptr;
@ -812,7 +816,7 @@ OpFoldResult AtenViewOp::fold(ArrayRef<Attribute> operands) {
// AtenDimOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenDimOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenDimOp::fold(FoldAdaptor adaptor) {
if (auto tensorType = getOperand().getType().dyn_cast<BaseTensorType>()) {
if (tensorType.hasSizes())
return IntegerAttr::get(IntegerType::get(getContext(), 64),
@ -825,7 +829,7 @@ OpFoldResult AtenDimOp::fold(ArrayRef<Attribute> operands) {
// AtenLenTOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenLenTOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenLenTOp::fold(FoldAdaptor adaptor) {
// `len([1,1,1])` -> `3`, if it is not mutated.
if (auto listConstruct =
getOperand().getDefiningOp<Torch::PrimListConstructOp>()) {
@ -853,7 +857,7 @@ void AtenLenTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// AtenLenStrOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenLenStrOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenLenStrOp::fold(FoldAdaptor adaptor) {
if (auto stringConstruct = getS().getDefiningOp<ConstantStrOp>())
return getI64IntegerAttr(getContext(),
stringConstruct.getValueAttr().getValue().size());
@ -869,22 +873,25 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
if (op->getNumOperands() < 2) {
return failure();
}
auto lhs = getScalarValue(op->getOperand(0), loc, rewriter);
auto rhs = getScalarValue(op->getOperand(1), loc, rewriter);
auto lhs = getScalarIntValue(op->getOperand(0), loc, rewriter);
auto rhs = getScalarIntValue(op->getOperand(1), loc, rewriter);
auto outType = op->getResult(0).getType();
if (!lhs || !rhs) {
return rewriter.notifyMatchFailure(
op, "only int scalar lhs or rhs is supported");
}
if (isa<AtenSubTensorOp, AtenSubScalarOp, AtenAddTensorOp, AtenAddScalarOp>(
op)) {
Value alpha = getScalarValue(op->getOperand(2), loc, rewriter);
if (isa<AtenSubTensorOp, AtenSubScalarOp, AtenRsubScalarOp, AtenAddTensorOp,
AtenAddScalarOp>(op)) {
Value alpha = getScalarIntValue(op->getOperand(2), loc, rewriter);
if (!alpha) {
return rewriter.notifyMatchFailure(op,
"only int scalar alpha is supported");
}
rhs = rewriter.create<AtenMulIntOp>(loc, rhs, alpha);
if (isa<AtenRsubScalarOp>(op))
lhs = rewriter.create<AtenMulIntOp>(loc, lhs, alpha);
else
rhs = rewriter.create<AtenMulIntOp>(loc, rhs, alpha);
}
if (isa<AtenDivTensorModeOp>(op)) {
@ -937,6 +944,8 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
result = rewriter.create<AtenAddIntOp>(loc, lhs, rhs);
} else if (isa<AtenSubScalarOp, AtenSubTensorOp>(op)) {
result = rewriter.create<AtenSubIntOp>(loc, lhs, rhs);
} else if (isa<AtenRsubScalarOp>(op)) {
result = rewriter.create<AtenSubIntOp>(loc, rhs, lhs);
} else if (isa<AtenMulScalarOp, AtenMulTensorOp>(op)) {
result = rewriter.create<AtenMulIntOp>(loc, lhs, rhs);
}
@ -984,6 +993,16 @@ void AtenSubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
});
}
//===----------------------------------------------------------------------===//
// AtenRsubScalarOp
//===----------------------------------------------------------------------===//
// Registers the canonicalization pattern for AtenRsubScalarOp. The pattern
// simply forwards the matched op to rewrite0DBinaryTensorOp and reports that
// helper's result.
void AtenRsubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                   MLIRContext *context) {
  patterns.add(
      +[](AtenRsubScalarOp rsubOp, PatternRewriter &rewriter) -> LogicalResult {
        return rewrite0DBinaryTensorOp(rsubOp, rewriter);
      });
}
//===----------------------------------------------------------------------===//
// AtenMulTensorOp
//===----------------------------------------------------------------------===//
@ -1014,6 +1033,23 @@ void AtenDivTensorModeOp::getCanonicalizationPatterns(
});
}
//===----------------------------------------------------------------------===//
// AtenScalarImplicitOp
//===----------------------------------------------------------------------===//
// Registers the canonicalization pattern for AtenScalarImplicitOp. When
// getScalarIntValue can produce a scalar for the op's operand, the op is
// replaced by a Torch::DerefineOp of that scalar carrying the op's original
// result type; otherwise the pattern does not apply.
void AtenScalarImplicitOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add(+[](AtenScalarImplicitOp op, PatternRewriter &rewriter) {
    Value intScalar = getScalarIntValue(op.getA(), op.getLoc(), rewriter);
    if (intScalar) {
      rewriter.replaceOpWithNewOp<Torch::DerefineOp>(
          op, op.getResult().getType(), intScalar);
      return success();
    }
    // No scalar could be extracted from the operand.
    return failure();
  });
}
//===----------------------------------------------------------------------===//
// AtenSizeOp
//===----------------------------------------------------------------------===//
@ -1092,7 +1128,7 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// AtenSizeIntOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenSizeIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenSizeIntOp::fold(FoldAdaptor adaptor) {
int64_t dim;
if (!matchPattern(this->getDim(), m_TorchConstantInt(&dim)))
return nullptr;
@ -1132,7 +1168,7 @@ floatComparatorFoldHelper(OpTy op, ConstantFloatComparator comparator) {
// AtenLtFloatOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenLtFloatOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenLtFloatOp::fold(FoldAdaptor adaptor) {
return floatComparatorFoldHelper(*this,
[](double a, double b) { return a < b; });
}
@ -1141,7 +1177,7 @@ OpFoldResult AtenLtFloatOp::fold(ArrayRef<Attribute> operands) {
// AtenGtFloatOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenGtFloatOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenGtFloatOp::fold(FoldAdaptor adaptor) {
return floatComparatorFoldHelper(*this,
[](double a, double b) { return a > b; });
}
@ -1150,7 +1186,7 @@ OpFoldResult AtenGtFloatOp::fold(ArrayRef<Attribute> operands) {
// AtenGeFloatOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenGeFloatOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenGeFloatOp::fold(FoldAdaptor adaptor) {
return floatComparatorFoldHelper(*this,
[](double a, double b) { return a >= b; });
}
@ -1159,7 +1195,7 @@ OpFoldResult AtenGeFloatOp::fold(ArrayRef<Attribute> operands) {
// AtenEqFloatOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenEqFloatOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenEqFloatOp::fold(FoldAdaptor adaptor) {
return floatComparatorFoldHelper(*this,
[](double a, double b) { return a == b; });
}
@ -1225,7 +1261,7 @@ static OpFoldResult intComparatorFoldHelper(OpTy op,
// AtenNeIntOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenNeIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenNeIntOp::fold(FoldAdaptor adaptor) {
return intComparatorFoldHelper(*this,
[](int64_t a, int64_t b) { return a != b; });
}
@ -1234,7 +1270,7 @@ OpFoldResult AtenNeIntOp::fold(ArrayRef<Attribute> operands) {
// AtenEqIntOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenEqIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenEqIntOp::fold(FoldAdaptor adaptor) {
return intComparatorFoldHelper(*this,
[](int64_t a, int64_t b) { return a == b; });
}
@ -1243,7 +1279,7 @@ OpFoldResult AtenEqIntOp::fold(ArrayRef<Attribute> operands) {
// AtenEqStrOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenEqStrOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenEqStrOp::fold(FoldAdaptor adaptor) {
if (getOperand(0) == getOperand(1))
return getI1IntegerAttr(getContext(), true);
@ -1259,7 +1295,7 @@ OpFoldResult AtenEqStrOp::fold(ArrayRef<Attribute> operands) {
// AtenLtIntOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenLtIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenLtIntOp::fold(FoldAdaptor adaptor) {
return intComparatorFoldHelper(*this,
[](int64_t a, int64_t b) { return a < b; });
}
@ -1268,7 +1304,7 @@ OpFoldResult AtenLtIntOp::fold(ArrayRef<Attribute> operands) {
// AtenLeIntOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenLeIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenLeIntOp::fold(FoldAdaptor adaptor) {
return intComparatorFoldHelper(*this,
[](int64_t a, int64_t b) { return a <= b; });
}
@ -1277,7 +1313,7 @@ OpFoldResult AtenLeIntOp::fold(ArrayRef<Attribute> operands) {
// AtenGtIntOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenGtIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenGtIntOp::fold(FoldAdaptor adaptor) {
return intComparatorFoldHelper(*this,
[](int64_t a, int64_t b) { return a > b; });
}
@ -1286,7 +1322,7 @@ OpFoldResult AtenGtIntOp::fold(ArrayRef<Attribute> operands) {
// AtenGeIntOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenGeIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenGeIntOp::fold(FoldAdaptor adaptor) {
return intComparatorFoldHelper(*this,
[](int64_t a, int64_t b) { return a >= b; });
}
@ -1295,7 +1331,7 @@ OpFoldResult AtenGeIntOp::fold(ArrayRef<Attribute> operands) {
// AtenBoolFloatOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenBoolFloatOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenBoolFloatOp::fold(FoldAdaptor adaptor) {
double c;
if (matchPattern(getOperand(), m_TorchConstantFloat(&c)))
return getI1IntegerAttr(getContext(), c != 0.0);
@ -1306,7 +1342,7 @@ OpFoldResult AtenBoolFloatOp::fold(ArrayRef<Attribute> operands) {
// AtenBoolIntOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenBoolIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenBoolIntOp::fold(FoldAdaptor adaptor) {
int64_t c;
if (matchPattern(getOperand(), m_TorchConstantInt(&c)))
return getI1IntegerAttr(getContext(), c != 0);
@ -1317,9 +1353,9 @@ OpFoldResult AtenBoolIntOp::fold(ArrayRef<Attribute> operands) {
// AtenFloatScalarOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenFloatScalarOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenFloatScalarOp::fold(FoldAdaptor adaptor) {
// Constant fold int -> float conversion.
if (auto integerAttr = operands[0].dyn_cast_or_null<IntegerAttr>()) {
if (auto integerAttr = adaptor.getA().dyn_cast_or_null<IntegerAttr>()) {
return FloatAttr::get(
mlir::Float64Type::get(getContext()),
static_cast<double>(integerAttr.getValue().getSExtValue()));
@ -1330,13 +1366,27 @@ OpFoldResult AtenFloatScalarOp::fold(ArrayRef<Attribute> operands) {
return nullptr;
}
//===----------------------------------------------------------------------===//
// AtenIntFloatOp
//===----------------------------------------------------------------------===//
// Constant folds a float -> int conversion.
//
// Returns a signed 64-bit IntegerAttr holding the truncated value when the
// operand is a constant FloatAttr whose value is finite and lies inside the
// int64_t range. Returns nullptr (no fold) otherwise.
OpFoldResult AtenIntFloatOp::fold(FoldAdaptor adaptor) {
  auto floatAttr = adaptor.getA().dyn_cast_or_null<FloatAttr>();
  if (!floatAttr)
    return nullptr;

  // Converting NaN, +/-inf, or a value outside the destination range to an
  // integer is undefined behaviour in C++, so such constants are not folded.
  auto apValue = floatAttr.getValue();
  if (!apValue.isFinite())
    return nullptr;
  double value = apValue.convertToDouble();
  // 2^63 is exactly representable as a double; valid inputs are [-2^63, 2^63).
  constexpr double kInt64Bound = 9223372036854775808.0;
  if (value < -kInt64Bound || value >= kInt64Bound)
    return nullptr;

  return IntegerAttr::get(
      mlir::IntegerType::get(getContext(), 64, IntegerType::Signed),
      static_cast<int64_t>(value));
}
//===----------------------------------------------------------------------===//
// AtenIntScalarOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenIntScalarOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenIntScalarOp::fold(FoldAdaptor adaptor) {
// Constant fold float -> int conversion.
if (auto floatAttr = operands[0].dyn_cast_or_null<FloatAttr>()) {
if (auto floatAttr = adaptor.getA().dyn_cast_or_null<FloatAttr>()) {
return IntegerAttr::get(
mlir::IntegerType::get(getContext(), 64, IntegerType::Signed),
static_cast<long>(floatAttr.getValue().convertToDouble()));
@ -1347,6 +1397,18 @@ OpFoldResult AtenIntScalarOp::fold(ArrayRef<Attribute> operands) {
return nullptr;
}
//===----------------------------------------------------------------------===//
// AtenIntBoolOp
//===----------------------------------------------------------------------===//
// Constant folds a bool -> int conversion: when the operand matches a constant
// Torch bool, the result is that bool converted to an i64 integer attribute.
// Returns nullptr when the operand is not such a constant.
OpFoldResult AtenIntBoolOp::fold(FoldAdaptor adaptor) {
  bool constantValue;
  if (!matchPattern(getOperand(), m_TorchConstantBool(&constantValue)))
    return nullptr;
  return getI64IntegerAttr(getContext(), static_cast<long>(constantValue));
}
//===----------------------------------------------------------------------===//
// AtenSortIntOp
//===----------------------------------------------------------------------===//
@ -1440,7 +1502,7 @@ LogicalResult ValueTensorLiteralOp::inferReturnTypes(
return success();
}
OpFoldResult ValueTensorLiteralOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult ValueTensorLiteralOp::fold(FoldAdaptor adaptor) {
return getValueAttr();
}
@ -1545,7 +1607,7 @@ void CopyToValueTensorOp::getEffects(
// ConstantNoneOp
//===----------------------------------------------------------------------===//
OpFoldResult ConstantNoneOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult ConstantNoneOp::fold(FoldAdaptor adaptor) {
return TypeAttr::get(Torch::NoneType::get(getContext()));
}
@ -1558,9 +1620,7 @@ void ConstantNoneOp::getAsmResultNames(
// ConstantStrOp
//===----------------------------------------------------------------------===//
OpFoldResult ConstantStrOp::fold(ArrayRef<Attribute> operands) {
return getValueAttr();
}
OpFoldResult ConstantStrOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
void ConstantStrOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
@ -1598,7 +1658,7 @@ void ConstantIntOp::print(OpAsmPrinter &p) {
p.printOptionalAttrDict((*this)->getAttrs(), {"value"});
}
OpFoldResult Torch::ConstantIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult Torch::ConstantIntOp::fold(FoldAdaptor adaptor) {
return getValueAttr();
}
@ -1614,7 +1674,7 @@ void Torch::ConstantIntOp::getAsmResultNames(
// ConstantFloatOp
//===----------------------------------------------------------------------===//
OpFoldResult Torch::ConstantFloatOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult Torch::ConstantFloatOp::fold(FoldAdaptor adaptor) {
return getValueAttr();
}
@ -1644,7 +1704,7 @@ void Torch::ConstantFloatOp::getAsmResultNames(
// ConstantNumberOp
//===----------------------------------------------------------------------===//
OpFoldResult Torch::ConstantNumberOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult Torch::ConstantNumberOp::fold(FoldAdaptor adaptor) {
return getValueAttr();
}
@ -1672,7 +1732,7 @@ void Torch::ConstantNumberOp::getCanonicalizationPatterns(
// ConstantBoolOp
//===----------------------------------------------------------------------===//
OpFoldResult Torch::ConstantBoolOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult Torch::ConstantBoolOp::fold(FoldAdaptor adaptor) {
return getValueAttr();
}
@ -1690,7 +1750,7 @@ bool PrimUncheckedCastOp::areCastCompatible(mlir::TypeRange inputs,
return isValidSubtype(outputs[0], inputs[0]);
}
OpFoldResult PrimUncheckedCastOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult PrimUncheckedCastOp::fold(FoldAdaptor adaptor) {
if (auto derefineOp = getX().getDefiningOp<Torch::DerefineOp>()) {
if (derefineOp.getOperand().getType() == getType())
return derefineOp.getOperand();
@ -1824,7 +1884,7 @@ void AtenSliceTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// AtenEqIntListOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenEqIntListOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenEqIntListOp::fold(FoldAdaptor adaptor) {
auto lhsLiteral = getA().getDefiningOp<Torch::PrimListConstructOp>();
if (!lhsLiteral)
return nullptr;
@ -1849,6 +1909,20 @@ OpFoldResult AtenEqIntListOp::fold(ArrayRef<Attribute> operands) {
return nullptr;
}
//===----------------------------------------------------------------------===//
// PrimTupleConstructOp
//===----------------------------------------------------------------------===//
// Verifies that the tuple type implied by the operand types is a valid
// subtype of the op's declared result type.
//
// Returns success() when isValidSubtype accepts the pair; otherwise emits an
// op error and returns its failure result.
LogicalResult PrimTupleConstructOp::verify() {
  // Collect the operand types and build the tuple type they would form.
  auto elementTypes = llvm::to_vector<6>(getElements().getType());
  auto impliedTupleType = Torch::TupleType::get(getContext(), elementTypes);

  if (isValidSubtype(impliedTupleType, getResult().getType()))
    return success();

  return emitOpError(
      "failed to verify that contained types correspond to operand types");
}
//===----------------------------------------------------------------------===//
// PrimTupleIndexOp
//===----------------------------------------------------------------------===//
@ -1950,7 +2024,7 @@ static PrimDictConstructOp getDictConstructIfNotModified(Value torchDict) {
// Aten__Getitem__DictStrOp
//===----------------------------------------------------------------------===//
OpFoldResult Aten__Getitem__DictStrOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult Aten__Getitem__DictStrOp::fold(FoldAdaptor adaptor) {
auto dictConstruct = getDictConstructIfNotModified(getSelf());
if (!dictConstruct)
return nullptr;
@ -1968,7 +2042,7 @@ OpFoldResult Aten__Getitem__DictStrOp::fold(ArrayRef<Attribute> operands) {
// Aten__Contains__StrOp
//===----------------------------------------------------------------------===//
OpFoldResult Aten__Contains__StrOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult Aten__Contains__StrOp::fold(FoldAdaptor adaptor) {
auto dictConstruct = getDictConstructIfNotModified(getDict());
if (!dictConstruct)
return nullptr;
@ -1991,7 +2065,7 @@ static bool isListConstructNotModified(Value torchList) {
});
}
OpFoldResult Aten__Contains__IntListOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult Aten__Contains__IntListOp::fold(FoldAdaptor adaptor) {
auto itemConstruct = getItem();
if (!isListConstructNotModified(getL()))
return nullptr;
@ -2052,43 +2126,55 @@ atenBinaryFloatOperatorFoldHelper(ArrayRef<Attribute> operands,
// AtenFloordivIntOp
//===----------------------------------------------------------------------===//
// Constant-folds the op by delegating to atenBinaryIntOperatorFoldHelper with
// a folding function that computes floor(a / (double)b).
// NOTE(review): the quotient is computed in double precision, so operands
// beyond 2^53 in magnitude may not round-trip exactly, and b == 0 is not
// guarded in this lambda — confirm whether the helper or callers exclude
// these cases.
// NOTE(review): the paired signature and helper-call lines are the
// before/after sides of a rendered diff (ArrayRef<Attribute> `operands`
// before, FoldAdaptor `adaptor.getOperands()` after) — confirm.
OpFoldResult AtenFloordivIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenFloordivIntOp::fold(FoldAdaptor adaptor) {
return atenBinaryIntOperatorFoldHelper(
operands, [](int64_t a, int64_t b) { return std::floor(a / (double)b); });
adaptor.getOperands(),
[](int64_t a, int64_t b) { return std::floor(a / (double)b); });
}
//===----------------------------------------------------------------------===//
// AtenRemainderIntOp
//===----------------------------------------------------------------------===//
// Constant-folds the op by delegating to atenBinaryIntOperatorFoldHelper with
// a folding function that computes `a % b` on int64_t.
// NOTE(review): C++ `%` truncates toward zero (-7 % 3 == -1) and is undefined
// for b == 0, whereas Python's `%` takes the sign of the divisor
// (-7 % 3 == 2). Confirm which semantics this fold is meant to have and
// whether a zero divisor can reach the lambda.
// NOTE(review): the paired signature and helper-call lines are the
// before/after sides of a rendered diff — confirm.
OpFoldResult AtenRemainderIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenRemainderIntOp::fold(FoldAdaptor adaptor) {
return atenBinaryIntOperatorFoldHelper(
operands, [](int64_t a, int64_t b) { return a % b; });
adaptor.getOperands(), [](int64_t a, int64_t b) { return a % b; });
}
//===----------------------------------------------------------------------===//
// AtenAddIntOp
//===----------------------------------------------------------------------===//
// Constant-folds the op by delegating to atenBinaryIntOperatorFoldHelper with
// a folding function that computes `a + b` on int64_t.
// NOTE(review): the paired signature and helper-call lines are the
// before/after sides of a rendered diff — confirm.
OpFoldResult AtenAddIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenAddIntOp::fold(FoldAdaptor adaptor) {
return atenBinaryIntOperatorFoldHelper(
operands, [](int64_t a, int64_t b) { return a + b; });
adaptor.getOperands(), [](int64_t a, int64_t b) { return a + b; });
}
//===----------------------------------------------------------------------===//
// AtenSubIntOp
//===----------------------------------------------------------------------===//
// Constant-folds the op by delegating to atenBinaryIntOperatorFoldHelper with
// a folding function that computes `a - b` on int64_t.
// NOTE(review): the paired signature and helper-call lines are the
// before/after sides of a rendered diff — confirm.
OpFoldResult AtenSubIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenSubIntOp::fold(FoldAdaptor adaptor) {
return atenBinaryIntOperatorFoldHelper(
operands, [](int64_t a, int64_t b) { return a - b; });
adaptor.getOperands(), [](int64_t a, int64_t b) { return a - b; });
}
//===----------------------------------------------------------------------===//
// AtenCatOp
//===----------------------------------------------------------------------===//
// Folds the op to the sole list element when operand 0 is produced by a
// PrimListConstructOp that has exactly one use and exactly one element.
// Returns nullptr (no fold) in every other case.
// NOTE(review): the duplicated signature is a rendered-diff artifact
// (llvm::ArrayRef<mlir::Attribute> before, FoldAdaptor after) — confirm.
OpFoldResult AtenCatOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) {
auto list = getOperand(0).getDefiningOp<PrimListConstructOp>();
// Only a single-use list with a single element folds.
if (!list || !list->hasOneUse() || list.getElements().size() != 1)
return nullptr;
return list.getElements()[0];
}
//===----------------------------------------------------------------------===//
// AtenStackOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenStackOp::fold(FoldAdaptor adaptor) {
auto list = getOperand(0).getDefiningOp<PrimListConstructOp>();
if (!list || !list->hasOneUse() || list.getElements().size() != 1)
return nullptr;
@ -2099,7 +2185,7 @@ OpFoldResult AtenCatOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
// AtenSliceTensorOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenSliceTensorOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
auto inType = getOperand(0).getType().dyn_cast<ValueTensorType>();
auto outType = getResult().getType().dyn_cast<ValueTensorType>();
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes())
@ -2118,7 +2204,7 @@ OpFoldResult AtenSliceTensorOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
// AtenMulIntOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenMulIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) {
int64_t lhs, rhs;
bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs));
bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs));
@ -2129,46 +2215,70 @@ OpFoldResult AtenMulIntOp::fold(ArrayRef<Attribute> operands) {
return nullptr;
}
//===----------------------------------------------------------------------===//
// AtenSubFloatOp
//===----------------------------------------------------------------------===//
// Constant-folds the op by handing the adaptor's operand attributes to
// atenBinaryFloatOperatorFoldHelper together with a double subtraction.
// Whatever the helper returns is returned unchanged.
OpFoldResult AtenSubFloatOp::fold(FoldAdaptor adaptor) {
  auto subtract = [](double lhs, double rhs) { return lhs - rhs; };
  return atenBinaryFloatOperatorFoldHelper(adaptor.getOperands(), subtract);
}
//===----------------------------------------------------------------------===//
// AtenSubOp
//===----------------------------------------------------------------------===//
// Constant-folds the op:
//   * if either operand attribute is null, no fold (nullptr);
//   * if both are IntegerAttr, fold through atenBinaryIntOperatorFoldHelper
//     with int64_t subtraction;
//   * otherwise fold through atenBinaryFloatOperatorFoldHelper with double
//     subtraction.
// NOTE(review): each statement appears twice because this text is a rendered
// diff — the `operands[...]` lines look like the pre-change side and the
// `adaptor.getA()/getB()/getOperands()` lines the post-change side. Confirm
// against the real file.
OpFoldResult AtenSubOp::fold(ArrayRef<Attribute> operands) {
if (!operands[0] || !operands[1]) {
OpFoldResult AtenSubOp::fold(FoldAdaptor adaptor) {
if (!adaptor.getA() || !adaptor.getB()) {
return nullptr;
}
if (operands[0].isa<IntegerAttr>() && operands[1].isa<IntegerAttr>()) {
if (adaptor.getA().isa<IntegerAttr>() && adaptor.getB().isa<IntegerAttr>()) {
return atenBinaryIntOperatorFoldHelper(
operands, [](int64_t a, int64_t b) -> int64_t { return a - b; });
adaptor.getOperands(),
[](int64_t a, int64_t b) -> int64_t { return a - b; });
}
return atenBinaryFloatOperatorFoldHelper(
operands, [](double a, double b) -> double { return a - b; });
adaptor.getOperands(),
[](double a, double b) -> double { return a - b; });
}
//===----------------------------------------------------------------------===//
// AtenDivOp
//===----------------------------------------------------------------------===//
// Constant-folds the op: no fold (nullptr) when either operand attribute is
// null; otherwise delegates to atenBinaryFloatOperatorFoldHelper with double
// division `a / b`.
// NOTE(review): a zero divisor is not special-cased in the lambda; under IEEE
// double arithmetic that yields inf/NaN rather than an error — confirm this
// is the intended folded value.
// NOTE(review): each statement appears twice because this text is a rendered
// diff (`operands` before, `adaptor` after) — confirm against the real file.
OpFoldResult AtenDivOp::fold(ArrayRef<Attribute> operands) {
if (!operands[0] || !operands[1]) {
OpFoldResult AtenDivOp::fold(FoldAdaptor adaptor) {
if (!adaptor.getA() || !adaptor.getB()) {
return nullptr;
}
// Since AtenDivOp always returns float value, we don't need to deal with the
// case where the operands are both integers separately.
return atenBinaryFloatOperatorFoldHelper(
operands, [](double a, double b) -> double { return a / b; });
adaptor.getOperands(),
[](double a, double b) -> double { return a / b; });
}
//===----------------------------------------------------------------------===//
// AtenPowIntFloatOp
//===----------------------------------------------------------------------===//
// Constant-folds the op when both operand attributes are present.
//
// Returns nullptr (no fold) if either `a` or `b` is not a known constant;
// otherwise returns whatever atenBinaryFloatOperatorFoldHelper produces for
// std::pow(a, b) evaluated in double precision.
OpFoldResult AtenPowIntFloatOp::fold(FoldAdaptor adaptor) {
  const bool bothConstant = adaptor.getA() && adaptor.getB();
  if (!bothConstant)
    return nullptr;

  auto power = [](double base, double exponent) {
    return std::pow(base, exponent);
  };
  return atenBinaryFloatOperatorFoldHelper(adaptor.getOperands(), power);
}
//===----------------------------------------------------------------------===//
// AtenCeilScalarOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenCeilScalarOp::fold(ArrayRef<Attribute> operands) {
if (!operands[0]) {
OpFoldResult AtenCeilScalarOp::fold(FoldAdaptor adaptor) {
if (!adaptor.getA()) {
return nullptr;
}
auto floatValue = operands[0].dyn_cast_or_null<FloatAttr>();
auto floatValue = adaptor.getA().dyn_cast_or_null<FloatAttr>();
if (!floatValue) {
return nullptr;
}
@ -2181,7 +2291,7 @@ OpFoldResult AtenCeilScalarOp::fold(ArrayRef<Attribute> operands) {
// AtenNegIntOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenNegIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenNegIntOp::fold(FoldAdaptor adaptor) {
int64_t c;
if (matchPattern(getOperand(), m_TorchConstantInt(&c)))
return getI64IntegerAttr(getContext(), -c);
@ -2192,7 +2302,7 @@ OpFoldResult AtenNegIntOp::fold(ArrayRef<Attribute> operands) {
// AtenSqrtIntOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenSqrtIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenSqrtIntOp::fold(FoldAdaptor adaptor) {
int64_t c;
if (matchPattern(getOperand(), m_TorchConstantInt(&c)))
return getF64FloatAttr(getContext(), std::sqrt(c));
@ -2203,7 +2313,7 @@ OpFoldResult AtenSqrtIntOp::fold(ArrayRef<Attribute> operands) {
// PrimDtypeOp
//===----------------------------------------------------------------------===//
OpFoldResult PrimDtypeOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult PrimDtypeOp::fold(FoldAdaptor adaptor) {
BaseTensorType tensorType = getA().getType().cast<BaseTensorType>();
if (tensorType.hasDtype()) {
torch_upstream::ScalarType scalarType =
@ -2217,7 +2327,7 @@ OpFoldResult PrimDtypeOp::fold(ArrayRef<Attribute> operands) {
// AtenIntTensorOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenIntTensorOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) {
// If a scalar number is converted to a 0-d tensor and passed on to
// aten.Int.Tensor, fold to the scalar number.
if (auto numToTensorScalar = getA().getDefiningOp<PrimNumToTensorScalarOp>())
@ -2229,7 +2339,7 @@ OpFoldResult AtenIntTensorOp::fold(ArrayRef<Attribute> operands) {
// AtenFloatTensorOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenFloatTensorOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenFloatTensorOp::fold(FoldAdaptor adaptor) {
// If a scalar number is converted to a 0-d tensor and passed on to
// aten.Float.Tensor, fold to the scalar number.
if (auto numToTensorScalar = getA().getDefiningOp<PrimNumToTensorScalarOp>())
@ -2241,7 +2351,7 @@ OpFoldResult AtenFloatTensorOp::fold(ArrayRef<Attribute> operands) {
// AtenDivFloatOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenDivFloatOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenDivFloatOp::fold(FoldAdaptor adaptor) {
double lhs, rhs;
bool lConstant = matchPattern(getOperand(0), m_TorchConstantFloat(&lhs));
bool rConstant = matchPattern(getOperand(1), m_TorchConstantFloat(&rhs));
@ -2258,7 +2368,7 @@ OpFoldResult AtenDivFloatOp::fold(ArrayRef<Attribute> operands) {
// AtenDivIntOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenDivIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenDivIntOp::fold(FoldAdaptor adaptor) {
int64_t lhs, rhs;
bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs));
bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs));
@ -2271,7 +2381,7 @@ OpFoldResult AtenDivIntOp::fold(ArrayRef<Attribute> operands) {
// AtenCeilFloatOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenCeilFloatOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenCeilFloatOp::fold(FoldAdaptor adaptor) {
double c;
if (matchPattern(getOperand(), m_TorchConstantFloat(&c)))
return getI64IntegerAttr(getContext(), std::ceil(c));
@ -2282,13 +2392,13 @@ OpFoldResult AtenCeilFloatOp::fold(ArrayRef<Attribute> operands) {
// PrimMaxIntOp
//===----------------------------------------------------------------------===//
OpFoldResult PrimMaxIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult PrimMaxIntOp::fold(FoldAdaptor adaptor) {
// If both operands are the same, then the operation is an identity.
if (getA() == getB())
return getA();
auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
auto lhs = adaptor.getA().dyn_cast_or_null<IntegerAttr>();
auto rhs = adaptor.getB().dyn_cast_or_null<IntegerAttr>();
if (!lhs || !rhs)
return nullptr;
// Torch semantics are that !torch.int is 64-bit signed.
@ -2301,7 +2411,7 @@ OpFoldResult PrimMaxIntOp::fold(ArrayRef<Attribute> operands) {
// PrimMinSelfIntOp
//===----------------------------------------------------------------------===//
OpFoldResult PrimMinSelfIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult PrimMinSelfIntOp::fold(FoldAdaptor adaptor) {
auto list = getOperand().getDefiningOp<PrimListConstructOp>();
if (!list)
return nullptr;
@ -2320,6 +2430,25 @@ OpFoldResult PrimMinSelfIntOp::fold(ArrayRef<Attribute> operands) {
*std::min_element(values.begin(), values.end()));
}
//===----------------------------------------------------------------------===//
// PrimMinIntOp
//===----------------------------------------------------------------------===//
// Folds the op in two situations:
//   * both operands are the same SSA value -> that value (min(x, x) == x);
//   * both operands are IntegerAttr constants -> an IntegerAttr of the
//     left-hand attribute's type holding the smaller signed value.
// Returns nullptr (no fold) otherwise.
OpFoldResult PrimMinIntOp::fold(FoldAdaptor adaptor) {
  // Identity case: no constants are needed when the operands coincide.
  if (getA() == getB())
    return getA();

  auto lhsAttr = adaptor.getA().dyn_cast_or_null<IntegerAttr>();
  auto rhsAttr = adaptor.getB().dyn_cast_or_null<IntegerAttr>();
  if (lhsAttr && rhsAttr) {
    // Torch semantics are that !torch.int is 64-bit signed.
    int64_t lhsValue = lhsAttr.getValue().getSExtValue();
    int64_t rhsValue = rhsAttr.getValue().getSExtValue();
    return IntegerAttr::get(lhsAttr.getType(), std::min(lhsValue, rhsValue));
  }
  return nullptr;
}
//===----------------------------------------------------------------------===//
// ShapeCalculateOp
//===----------------------------------------------------------------------===//

View File

@ -68,16 +68,32 @@ bool Torch::isValidSubtype(Type subtype, Type type) {
return true;
}
// TODO: This is not subtyping according to PEP 483. See description
// of NonValueTensorType.
if (subtype.isa<NonValueTensorType>() && type.isa<NonValueTensorType>() &&
type ==
NonValueTensorType::getWithLeastStaticInformation(type.getContext()))
return true;
auto subtypeTensorType = subtype.dyn_cast<BaseTensorType>();
auto typeTensorType = type.dyn_cast<BaseTensorType>();
if (subtypeTensorType && typeTensorType) {
// Check that both tensors have the same `BaseTensorType` subtype.
// TODO: This is not subtyping according to PEP 483. See description
// of NonValueTensorType.
if (subtypeTensorType.isa<ValueTensorType>() !=
typeTensorType.isa<ValueTensorType>())
return false;
// `type` must not have more static information than `subtype`, and `type`
// must not disagree with `subtype`.
if (typeTensorType.hasDtype() &&
(!subtypeTensorType.hasDtype() ||
typeTensorType.getDtype() != subtypeTensorType.getDtype())) {
return false;
}
if (typeTensorType.hasSizes() &&
(!subtypeTensorType.hasSizes() ||
typeTensorType.getSizes() != subtypeTensorType.getSizes())) {
return false;
}
if (subtype.isa<ValueTensorType>() && type.isa<ValueTensorType>() &&
type == ValueTensorType::getWithLeastStaticInformation(type.getContext()))
return true;
}
return false;
}
@ -463,7 +479,7 @@ Type Torch::meetTensorTypes(BaseTensorType lhs, BaseTensorType rhs) {
}
}
return lhs.getWithSizesAndDtype(makeArrayRef(newSizes), dtype);
return lhs.getWithSizesAndDtype(ArrayRef(newSizes), dtype);
}
////===----------------------------------------------------------------------===//
@ -505,4 +521,4 @@ DictType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
return failure();
}
return success();
}
}

View File

@ -4088,6 +4088,259 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" } : (!torch.int, !torch.bool) -> ()\n"
" return %none : !torch.none\n"
" }\n"
" func.func @__torch__.torch.jit._shape_functions.stack(%arg0: !torch.list<list<int>>, %arg1: !torch.int) -> !torch.list<int> {\n"
" %false = torch.constant.bool false\n"
" %str = torch.constant.str \"AssertionError: Tensors must have same number of dimensions\"\n"
" %str_0 = torch.constant.str \"AssertionError: Sizes of tensors must match except in dimension\"\n"
" %int1 = torch.constant.int 1\n"
" %int0 = torch.constant.int 0\n"
" %str_1 = torch.constant.str \"AssertionError: \"\n"
" %none = torch.constant.none\n"
" %true = torch.constant.bool true\n"
" %0 = torch.prim.ListConstruct : () -> !torch.list<list<int>>\n"
" %1 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int\n"
" torch.prim.Loop %1, %true, init() {\n"
" ^bb0(%arg2: !torch.int):\n"
" %16 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
" %17 = torch.aten.len.t %16 : !torch.list<int> -> !torch.int\n"
" %18 = torch.aten.add.int %17, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %19 = torch.aten.le.int %18, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %20 = torch.prim.If %19 -> (!torch.int) {\n"
" torch.prim.If.yield %int1 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %18 : !torch.int\n"
" }\n"
" %21 = torch.aten.neg.int %20 : !torch.int -> !torch.int\n"
" %22 = torch.aten.sub.int %20, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %23 = torch.aten.lt.int %arg1, %21 : !torch.int, !torch.int -> !torch.bool\n"
" %24 = torch.prim.If %23 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %31 = torch.aten.gt.int %arg1, %22 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %31 : !torch.bool\n"
" }\n"
" %25 = torch.aten.__not__ %24 : !torch.bool -> !torch.bool\n"
" torch.prim.If %25 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %26 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %27 = torch.prim.If %26 -> (!torch.int) {\n"
" %31 = torch.aten.add.int %arg1, %20 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.If.yield %31 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %arg1 : !torch.int\n"
" }\n"
" %28 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %29 = torch.aten.len.t %16 : !torch.list<int> -> !torch.int\n"
" torch.prim.Loop %29, %true, init() {\n"
" ^bb0(%arg3: !torch.int):\n"
" %31 = torch.aten.__getitem__.t %16, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %32 = torch.aten.append.t %28, %31 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" torch.aten.insert.t %28, %27, %int1 : !torch.list<int>, !torch.int, !torch.int\n"
" %30 = torch.aten.append.t %0, %28 : !torch.list<list<int>>, !torch.list<int> -> !torch.list<list<int>>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %2 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>\n"
" %3 = torch.aten.len.t %0 : !torch.list<list<int>> -> !torch.int\n"
" torch.prim.Loop %3, %true, init() {\n"
" ^bb0(%arg2: !torch.int):\n"
" %16 = torch.aten.__getitem__.t %0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
" %17 = torch.aten.len.t %16 : !torch.list<int> -> !torch.int\n"
" %18 = torch.aten.gt.int %17, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %18 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %4 = torch.aten.len.t %0 : !torch.list<list<int>> -> !torch.int\n"
" %5 = torch.derefine %none : !torch.none to !torch.optional<int>\n"
" %6 = torch.prim.Loop %4, %true, init(%5) {\n"
" ^bb0(%arg2: !torch.int, %arg3: !torch.optional<int>):\n"
" %16 = torch.aten.__getitem__.t %0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
" %17 = torch.aten.len.t %16 : !torch.list<int> -> !torch.int\n"
" %18 = torch.aten.eq.int %17, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" %19 = torch.prim.If %18 -> (!torch.bool) {\n"
" %22 = torch.aten.__getitem__.t %16, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %23 = torch.aten.eq.int %22, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %23 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %20 = torch.aten.__not__ %19 : !torch.bool -> !torch.bool\n"
" %21 = torch.prim.If %20 -> (!torch.optional<int>) {\n"
" %22 = torch.aten.__is__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %23 = torch.prim.If %22 -> (!torch.int) {\n"
" %25 = torch.aten.len.t %16 : !torch.list<int> -> !torch.int\n"
" %26 = torch.aten.le.int %25, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %27 = torch.prim.If %26 -> (!torch.int) {\n"
" torch.prim.If.yield %int1 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %25 : !torch.int\n"
" }\n"
" %28 = torch.aten.neg.int %27 : !torch.int -> !torch.int\n"
" %29 = torch.aten.sub.int %27, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %30 = torch.aten.lt.int %arg1, %28 : !torch.int, !torch.int -> !torch.bool\n"
" %31 = torch.prim.If %30 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %35 = torch.aten.gt.int %arg1, %29 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %35 : !torch.bool\n"
" }\n"
" %32 = torch.aten.__not__ %31 : !torch.bool -> !torch.bool\n"
" torch.prim.If %32 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %33 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %34 = torch.prim.If %33 -> (!torch.int) {\n"
" %35 = torch.aten.add.int %arg1, %27 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.If.yield %35 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %arg1 : !torch.int\n"
" }\n"
" torch.prim.If.yield %34 : !torch.int\n"
" } else {\n"
" %25 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int\n"
" torch.prim.If.yield %25 : !torch.int\n"
" }\n"
" %24 = torch.derefine %23 : !torch.int to !torch.optional<int>\n"
" torch.prim.If.yield %24 : !torch.optional<int>\n"
" } else {\n"
" torch.prim.If.yield %arg3 : !torch.optional<int>\n"
" }\n"
" torch.prim.Loop.condition %true, iter(%21 : !torch.optional<int>)\n"
" } : (!torch.int, !torch.bool, !torch.optional<int>) -> !torch.optional<int>\n"
" %7 = torch.aten.__is__ %6, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %8 = torch.prim.If %7 -> (!torch.int) {\n"
" torch.prim.If.yield %arg1 : !torch.int\n"
" } else {\n"
" %16 = torch.prim.unchecked_cast %6 : !torch.optional<int> -> !torch.int\n"
" torch.prim.If.yield %16 : !torch.int\n"
" }\n"
" %9 = torch.aten.len.t %0 : !torch.list<list<int>> -> !torch.int\n"
" %10 = torch.aten.gt.int %9, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %10 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %11 = torch.aten.len.t %0 : !torch.list<list<int>> -> !torch.int\n"
" %12 = torch.derefine %none : !torch.none to !torch.optional<list<int>>\n"
" %13 = torch.prim.Loop %11, %true, init(%12) {\n"
" ^bb0(%arg2: !torch.int, %arg3: !torch.optional<list<int>>):\n"
" %16 = torch.aten.__getitem__.t %0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
" %17 = torch.aten.len.t %16 : !torch.list<int> -> !torch.int\n"
" %18 = torch.prim.Loop %17, %true, init(%int1) {\n"
" ^bb0(%arg4: !torch.int, %arg5: !torch.int):\n"
" %23 = torch.aten.__getitem__.t %16, %arg4 : !torch.list<int>, !torch.int -> !torch.int\n"
" %24 = torch.aten.mul.int %arg5, %23 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.Loop.condition %true, iter(%24 : !torch.int)\n"
" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n"
" %19 = torch.aten.eq.int %18, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %20 = torch.prim.If %19 -> (!torch.bool) {\n"
" %23 = torch.aten.len.t %16 : !torch.list<int> -> !torch.int\n"
" %24 = torch.aten.eq.int %23, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %24 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %21 = torch.aten.__not__ %20 : !torch.bool -> !torch.bool\n"
" %22 = torch.prim.If %21 -> (!torch.optional<list<int>>) {\n"
" %23 = torch.derefine %16 : !torch.list<int> to !torch.optional<list<int>>\n"
" torch.prim.If.yield %23 : !torch.optional<list<int>>\n"
" } else {\n"
" torch.prim.If.yield %arg3 : !torch.optional<list<int>>\n"
" }\n"
" torch.prim.Loop.condition %true, iter(%22 : !torch.optional<list<int>>)\n"
" } : (!torch.int, !torch.bool, !torch.optional<list<int>>) -> !torch.optional<list<int>>\n"
" %14 = torch.aten.__is__ %13, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
" %15 = torch.prim.If %14 -> (!torch.list<int>) {\n"
" torch.prim.If.yield %2 : !torch.list<int>\n"
" } else {\n"
" %16 = torch.prim.unchecked_cast %13 : !torch.optional<list<int>> -> !torch.list<int>\n"
" %17 = torch.aten.len.t %0 : !torch.list<list<int>> -> !torch.int\n"
" %18 = torch.prim.Loop %17, %true, init(%int0) {\n"
" ^bb0(%arg2: !torch.int, %arg3: !torch.int):\n"
" %22 = torch.aten.__getitem__.t %0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
" %23 = torch.aten.len.t %22 : !torch.list<int> -> !torch.int\n"
" %24 = torch.prim.Loop %23, %true, init(%int1) {\n"
" ^bb0(%arg4: !torch.int, %arg5: !torch.int):\n"
" %29 = torch.aten.__getitem__.t %22, %arg4 : !torch.list<int>, !torch.int -> !torch.int\n"
" %30 = torch.aten.mul.int %arg5, %29 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.Loop.condition %true, iter(%30 : !torch.int)\n"
" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n"
" %25 = torch.aten.eq.int %24, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %26 = torch.prim.If %25 -> (!torch.bool) {\n"
" %29 = torch.aten.len.t %22 : !torch.list<int> -> !torch.int\n"
" %30 = torch.aten.eq.int %29, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %30 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %27 = torch.aten.__not__ %26 : !torch.bool -> !torch.bool\n"
" %28 = torch.prim.If %27 -> (!torch.int) {\n"
" %29 = torch.aten.len.t %16 : !torch.list<int> -> !torch.int\n"
" %30 = torch.aten.len.t %22 : !torch.list<int> -> !torch.int\n"
" %31 = torch.aten.eq.int %29, %30 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %31 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %32 = torch.aten.__range_length %int0, %29, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
" torch.prim.Loop %32, %true, init() {\n"
" ^bb0(%arg4: !torch.int):\n"
" %35 = torch.aten.__derive_index %arg4, %int0, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
" %36 = torch.aten.ne.int %35, %8 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %36 -> () {\n"
" %37 = torch.aten.__getitem__.t %16, %35 : !torch.list<int>, !torch.int -> !torch.int\n"
" %38 = torch.aten.__getitem__.t %22, %35 : !torch.list<int>, !torch.int -> !torch.int\n"
" %39 = torch.aten.eq.int %37, %38 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %39 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.If.yield\n"
" }\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %33 = torch.aten.__getitem__.t %22, %8 : !torch.list<int>, !torch.int -> !torch.int\n"
" %34 = torch.aten.add.int %arg3, %33 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.If.yield %34 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %arg3 : !torch.int\n"
" }\n"
" torch.prim.Loop.condition %true, iter(%28 : !torch.int)\n"
" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n"
" %19 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %20 = torch.aten.len.t %16 : !torch.list<int> -> !torch.int\n"
" torch.prim.Loop %20, %true, init() {\n"
" ^bb0(%arg2: !torch.int):\n"
" %22 = torch.aten.__getitem__.t %16, %arg2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %23 = torch.aten.append.t %19, %22 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %21 = torch.aten._set_item.t %19, %8, %18 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
" torch.prim.If.yield %19 : !torch.list<int>\n"
" }\n"
" return %15 : !torch.list<int>\n"
" }\n"
" func.func @__torch__.torch.jit._shape_functions.permute(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %int0 = torch.constant.int 0\n"
" %true = torch.constant.bool true\n"
@ -5790,6 +6043,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.hardtanh_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.float) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.ceil\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@ -5877,6 +6134,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.bucketize.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.contiguous\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@ -5993,7 +6254,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.prims.var\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.int, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
" func.func @\"__torch_mlir_shape_fn.prims.var\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.float, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
" %none = torch.constant.none\n"
" %false = torch.constant.bool false\n"
" %0 = torch.derefine %none : !torch.none to !torch.any\n"
@ -6006,13 +6267,13 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
" return %1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.var.correction\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<int>, %arg3: !torch.bool) -> !torch.list<int> {\n"
" func.func @\"__torch_mlir_shape_fn.aten.var.correction\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<float>, %arg3: !torch.bool) -> !torch.list<int> {\n"
" %none = torch.constant.none\n"
" %0 = torch.derefine %none : !torch.none to !torch.any\n"
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
" return %1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.var_mean.correction\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<int>, %arg3: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
" func.func @\"__torch_mlir_shape_fn.aten.var_mean.correction\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<float>, %arg3: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
" %none = torch.constant.none\n"
" %0 = torch.derefine %none : !torch.none to !torch.any\n"
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
@ -6035,7 +6296,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
" return %1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.std.correction\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<int>, %arg3: !torch.bool) -> !torch.list<int> {\n"
" func.func @\"__torch_mlir_shape_fn.aten.std.correction\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<float>, %arg3: !torch.bool) -> !torch.list<int> {\n"
" %none = torch.constant.none\n"
" %0 = torch.derefine %none : !torch.none to !torch.any\n"
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
@ -6549,6 +6810,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.new_empty\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.list<int> {\n"
" return %arg1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.new_empty_strided\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
" return %arg1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten._to_copy\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@ -6583,6 +6847,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.bernoulli.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.any) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.bernoulli.p\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.any) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten._index_put_impl\"(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>, %arg2: !torch.list<int>, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@ -6596,6 +6863,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.rand_like\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.randn_like\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.randint.low\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
" return %arg2 : !torch.list<int>\n"
" }\n"
@ -6881,6 +7151,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.select_scatter\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.scatter_reduce.two\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.str, %arg5: !torch.bool) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.index_select\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.index_select(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@ -7310,6 +7583,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.cat(%arg0, %arg1) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.stack\"(%arg0: !torch.list<list<int>>, %arg1: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.stack(%arg0, %arg1) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.fft_fft\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
@ -7340,6 +7617,13 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.norm.ScalarOpt_dim\"(%arg0: !torch.list<int>, %arg1: !torch.optional<float>, %arg2: !torch.list<int>, %arg3: !torch.bool) -> !torch.list<int> {\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.derefine %arg2 : !torch.list<int> to !torch.optional<list<int>>\n"
" %1 = torch.derefine %int0 : !torch.int to !torch.any\n"
" %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg3, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<float>, %arg3: !torch.optional<float>) -> !torch.list<int> {\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
@ -7855,6 +8139,30 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.ge.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: `self` cannot be complex\"\n"
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n"
" torch.prim.If %5 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.le.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" %none = torch.constant.none\n"
@ -7925,6 +8233,30 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.le.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: `self` cannot be complex\"\n"
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n"
" torch.prim.If %5 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.ne.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
@ -8644,7 +8976,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %false = torch.constant.bool false\n"
" %int15 = torch.constant.int 15\n"
" %int5 = torch.constant.int 5\n"
" %true = torch.constant.bool true\n"
" %int4 = torch.constant.int 4\n"
@ -8659,7 +8990,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" torch.prim.If.yield %12 : !torch.bool\n"
" }\n"
" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
" %11 = torch.prim.ListConstruct %int5, %int15 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %11 = torch.prim.ListConstruct %int5 : (!torch.int) -> !torch.list<int>\n"
" %12 = torch.aten.__contains__.int_list %11, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n"
" torch.prim.If.yield %13 : !torch.bool\n"
@ -8681,7 +9012,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" torch.prim.If.yield %12 : !torch.bool\n"
" }\n"
" %7 = torch.prim.If %6 -> (!torch.bool) {\n"
" %11 = torch.prim.ListConstruct %int5, %int15 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %11 = torch.prim.ListConstruct %int5 : (!torch.int) -> !torch.list<int>\n"
" %12 = torch.aten.__contains__.int_list %11, %1#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n"
" torch.prim.If.yield %13 : !torch.bool\n"

View File

@ -10,7 +10,6 @@
#include "PassDetail.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"

View File

@ -9,6 +9,7 @@ add_mlir_library(TorchMLIRTorchPasses
LowerToBackendContract.cpp
MaximizeValueSemantics.cpp
PrepareForGlobalizeObjectGraph.cpp
RecomposeComplexOps.cpp
ReduceOpVariants.cpp
RefinePublicReturn.cpp
RefineTypes.cpp

View File

@ -33,9 +33,11 @@ static bool isNoneOrFloatDtype(MLIRContext *context, Value dtype) {
int64_t dtypeInt;
if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
return false;
Type resDtype =
FailureOr<Type> resDtype =
getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
return resDtype.isa<mlir::FloatType>();
if (failed(resDtype))
return false;
return resDtype->isa<mlir::FloatType>();
}
// Helper function to compute the return type of the reduction function.
@ -70,7 +72,7 @@ static Type computeReductionType(PatternRewriter &rewriter, Operation *op,
Type resultType = tensorType.getWithSizesAndDtype(
sizes.size() == 0 ? std::optional<ArrayRef<int64_t>>()
: llvm::makeArrayRef(sizes),
: llvm::ArrayRef(sizes),
tensorType.getOptionalDtype());
return resultType;
}
@ -106,7 +108,7 @@ static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc,
valueType
.getWithSizesAndDtype(
!valueType.hasSizes() ? std::optional<ArrayRef<int64_t>>()
: llvm::makeArrayRef(valueType.getSizes()),
: llvm::ArrayRef(valueType.getSizes()),
IntegerType::get(op->getContext(), 64, IntegerType::Signed))
.cast<BaseTensorType>();
return rewriter
@ -140,7 +142,7 @@ static Value createRank0Tensor(PatternRewriter &rewriter, Location loc,
BaseTensorType inputType, Value scalar) {
SmallVector<int64_t> sizes;
Type rank0TensorTy = inputType.getWithSizesAndDtype(
makeArrayRef(sizes), inputType.getOptionalDtype());
ArrayRef(sizes), inputType.getOptionalDtype());
Value dimList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(inputType.getContext())),
ValueRange{});
@ -169,6 +171,37 @@ static Value createSoftmaxBackwardCommonKernel(PatternRewriter &rewriter,
return sub;
}
// Creates an `aten.unsqueeze` of `input` at `dim` and returns its result.
//
// The result type is computed from the static sizes of `input`:
//   * if `dim` is a constant int, it is normalized against the result rank
//     (input rank + 1) and a unit dimension is inserted at that position;
//   * otherwise every dimension of the result is marked `kUnknownSize`.
//
// Returns failure (after notifying a match failure on `op`) when `input` has
// no sizes, or when a constant `dim` is not valid for the result rank.
static FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter,
                                        Operation *op, Value input, Value dim) {
  auto tensorTy = input.getType().cast<BaseTensorType>();
  if (!tensorTy.hasSizes())
    return rewriter.notifyMatchFailure(op, "input tensor must have size");

  ArrayRef<int64_t> srcSizes = tensorTy.getSizes();
  // The unsqueezed tensor has exactly one more dimension than `input`.
  int64_t resultRank = srcSizes.size() + 1;

  SmallVector<int64_t> resultSizes;
  int64_t constDim = 0;
  if (!matchPattern(dim, m_TorchConstantInt(&constDim))) {
    // `dim` is not a constant: the position of the new unit dimension is not
    // known here, so only the rank is recorded in the result type.
    resultSizes.resize(resultRank, kUnknownSize);
  } else {
    constDim = toPositiveDim(constDim, resultRank);
    if (!isValidDim(constDim, resultRank))
      return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
    resultSizes.append(srcSizes.begin(), srcSizes.end());
    resultSizes.insert(resultSizes.begin() + constDim, 1);
  }

  Type resultTy =
      tensorTy.getWithSizesAndDtype(resultSizes, tensorTy.getOptionalDtype());
  Value result =
      rewriter.create<AtenUnsqueezeOp>(op->getLoc(), resultTy, input, dim);
  return result;
}
namespace {
/// We decompose aten.amax into a set of aten.max.dim op(s) depending on the
/// number of dimensions across which the max needs to be computed.
@ -258,6 +291,15 @@ public:
Value dim = op.getDim();
Value self = op.getSelf();
// convert `start` to non-negative: start += int(start < 0) * dimSize
Value zero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value isNegative = rewriter.create<AtenLtIntOp>(loc, start, zero);
isNegative = rewriter.create<AtenIntBoolOp>(loc, isNegative);
Value dimSize = rewriter.create<AtenSizeIntOp>(loc, self, dim);
Value indexOffset = rewriter.create<AtenMulIntOp>(loc, isNegative, dimSize);
start = rewriter.create<AtenAddIntOp>(loc, start, indexOffset);
Value one =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value startPlusOne =
@ -595,6 +637,128 @@ public:
};
} // namespace
// Decomposition of `aten.bucketize.Tensor`. The emitted op sequence
// corresponds to the following Python:
//
// def aten_bucketize(input, boundaries, out_int32, right):
//     unsqz_input = input.unsqueeze(-1)
//     if not right:
//         comparison = unsqz_input <= boundaries
//     else:
//         comparison = unsqz_input < boundaries
//     indices = torch.argmax(comparison.float(), dim=-1)
//     within_bound = comparison[..., -1]
//     result = torch.where(within_bound, indices, boundaries.shape[0])
//     if out_int32:
//         result = result.int()
//     return result
//
// Only matches when `input` has sizes, `boundaries` is a sized rank-1 tensor,
// and both `out_int32` and `right` are constant bools.
namespace {
class DecomposeAtenBucketizeTensorOp
    : public OpRewritePattern<AtenBucketizeTensorOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenBucketizeTensorOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // ---- Match preconditions (no IR is created in this section). ----
    Value self = op.getSelf();
    auto selfTy = self.getType().cast<BaseTensorType>();
    if (!selfTy.hasSizes())
      return rewriter.notifyMatchFailure(
          op, "unimplemented: input must have known sizes");
    ArrayRef<int64_t> selfSizes = selfTy.getSizes();

    Value bounds = op.getBoundaries();
    auto boundsTy = bounds.getType().cast<BaseTensorType>();
    bool boundsIs1D = boundsTy.hasSizes() && boundsTy.getSizes().size() == 1;
    if (!boundsIs1D)
      return rewriter.notifyMatchFailure(op,
                                         "unimplemented: boundaries must have "
                                         "known sizes and must be a 1D array");
    int64_t numBounds = boundsTy.getSizes()[0];

    bool wantInt32;
    if (!matchPattern(op.getOutInt32(), m_TorchConstantBool(&wantInt32)))
      return rewriter.notifyMatchFailure(
          op, "unimplemented: out_int32 must be a constant bool");

    bool rightFlag;
    if (!matchPattern(op.getRight(), m_TorchConstantBool(&rightFlag)))
      return rewriter.notifyMatchFailure(
          op, "unimplemented: right must be a constant bool");

    // ---- Rewrite. ----
    // Add a trailing unit dimension to the input so that it broadcasts
    // against the 1-D boundaries tensor.
    Value cstNegOne = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(-1));
    FailureOr<Value> maybeUnsqueezed =
        unsqueezeTensor(rewriter, op, self, /*dim=*/cstNegOne);
    if (failed(maybeUnsqueezed))
      return rewriter.notifyMatchFailure(op,
                                         "cannot generate unsqueeze tensor");
    Value selfUnsqueezed = *maybeUnsqueezed;

    // Element-wise comparison of every input element with every boundary;
    // the result shape is the input shape with the boundary count appended.
    SmallVector<int64_t> cmpSizes(selfSizes.begin(), selfSizes.end());
    cmpSizes.push_back(numBounds);
    Type cmpTy = selfTy.getWithSizesAndDtype(cmpSizes, rewriter.getI1Type());
    Value cmp;
    if (rightFlag)
      cmp = rewriter.create<AtenLtTensorOp>(loc, cmpTy, selfUnsqueezed, bounds);
    else
      cmp = rewriter.create<AtenLeTensorOp>(loc, cmpTy, selfUnsqueezed, bounds);

    // argmax is fed float32 because it does not support integer dtype in
    // the LINALG backend.
    Value cmpAsF32 =
        convertTensorToDtype(rewriter, loc, cmp, rewriter.getF32Type());

    // Index of the first boundary that the input element is less than (or
    // equal to).
    Type idxTy = selfTy.getWithSizesAndDtype(
        selfSizes, rewriter.getIntegerType(64, IntegerType::Signed));
    Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
    Value firstHit = rewriter.create<AtenArgmaxOp>(loc, idxTy, cmpAsF32,
                                                   /*dim=*/cstNegOne,
                                                   /*keepdim=*/cstFalse);

    // Comparison result against the rightmost boundary only.
    Type inRangeTy =
        selfTy.getWithSizesAndDtype(selfSizes, rewriter.getI1Type());
    Value inRange = rewriter.create<AtenSelectIntOp>(loc, inRangeTy, cmp,
                                                     /*dim=*/cstNegOne,
                                                     /*index=*/cstNegOne);

    // Elements within the rightmost boundary keep the argmax index; elements
    // beyond it get the number of boundaries instead.
    Value cstZero = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(0));
    Value boundsLen =
        rewriter.create<AtenSizeIntOp>(loc, bounds, /*dim=*/cstZero);
    Value bucketIdx = rewriter.create<AtenWhereScalarOtherOp>(
        loc, idxTy, inRange, firstHit, boundsLen);

    if (wantInt32)
      bucketIdx = convertTensorToDtype(
          rewriter, loc, bucketIdx,
          rewriter.getIntegerType(32, IntegerType::Signed));

    rewriter.replaceOp(op, bucketIdx);
    return success();
  }
};
} // namespace
// To avoid overflow we use the following decomposition rule:
// x_max = aten.max(x, dim, keepdim=True)[0]
// shifted = x - x_max
@ -891,6 +1055,50 @@ public:
};
} // namespace
// Rewrites `aten.stack(tensors, dim)` as an `aten.cat` along `dim` of the
// operands after each of them has been `aten.unsqueeze`d at `dim`.
//
// Only matches when the tensor list comes from a list construct and every
// element of it has sizes.
namespace {
class DecomposeAtenStackOp : public OpRewritePattern<AtenStackOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenStackOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> operands;
    if (!getListConstructElements(op.getTensors(), operands))
      return rewriter.notifyMatchFailure(
          op, "unimplemented: the tensor list is not from list construct");

    // Check every operand up front, before any unsqueeze op is created.
    bool anyUnsized = llvm::any_of(operands, [](Value operand) {
      return !operand.getType().cast<BaseTensorType>().hasSizes();
    });
    if (anyUnsized)
      return rewriter.notifyMatchFailure(
          op, "unimplemented: one tensor does not have known sizes");

    Value dim = op.getDim();
    SmallVector<Value> expanded;
    for (Value operand : operands) {
      FailureOr<Value> maybeExpanded =
          unsqueezeTensor(rewriter, op, operand, dim);
      if (failed(maybeExpanded))
        return rewriter.notifyMatchFailure(
            op, "cannot generate unsqueeze tensor op");
      expanded.push_back(*maybeExpanded);
    }

    // The list element type is the result tensor type with both sizes and
    // dtype dropped.
    auto resultTy = op.getType().cast<BaseTensorType>();
    Type elemTy = resultTy.getWithSizesAndDtype(
        /*optionalSizes=*/std::nullopt, /*optionalDtype=*/nullptr);
    Value expandedList = rewriter.create<PrimListConstructOp>(
        op.getLoc(), Torch::ListType::get(elemTy), expanded);

    rewriter.replaceOpWithNewOp<AtenCatOp>(op, op.getType(), expandedList,
                                           dim);
    return success();
  }
};
} // namespace
// Decompose aten.roll into aten.slice and aten.cat ops.
// https://pytorch.org/docs/stable/generated/torch.roll.html
namespace {
@ -929,7 +1137,7 @@ public:
SmallVector<int64_t> sizes;
sizes.append(inputShape.begin(), inputShape.end());
sizes[cstDim] = kUnknownSize;
Type sliceTy = selfTy.getWithSizesAndDtype(llvm::makeArrayRef(sizes),
Type sliceTy = selfTy.getWithSizesAndDtype(llvm::ArrayRef(sizes),
selfTy.getOptionalDtype());
Value slice0 = rewriter.create<AtenSliceTensorOp>(
loc, sliceTy, input, dim, negShift, constNone, constOne);
@ -1066,9 +1274,9 @@ public:
Type dtype = self.getType().cast<ValueTensorType>().getOptionalDtype();
Type unsqueezedType = ValueTensorType::get(
context, llvm::makeArrayRef(unsqueezedIntSizes), dtype);
Type expandedType = ValueTensorType::get(
context, llvm::makeArrayRef(expandedIntSizes), dtype);
context, llvm::ArrayRef(unsqueezedIntSizes), dtype);
Type expandedType =
ValueTensorType::get(context, llvm::ArrayRef(expandedIntSizes), dtype);
auto listType = Torch::ListType::get(Torch::IntType::get(op.getContext()));
Value unsqueezedDims =
@ -1226,6 +1434,25 @@ public:
};
} // namespace
// Rewrites `aten.masked_fill.Scalar(self, mask, value)` as
// `aten.where.self(mask, valueTensor, self)`, where `valueTensor` is the
// scalar `value` wrapped into a tensor by `createRank0Tensor`.
namespace {
class DecomposeAtenMaskedFillScalarOp
    : public OpRewritePattern<AtenMaskedFillScalarOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenMaskedFillScalarOp op,
                                PatternRewriter &rewriter) const override {
    auto resultTy = op.getType().cast<BaseTensorType>();
    // The scalar has to become a tensor operand for `aten.where.self`; the
    // op's result type is what is handed to the helper as the tensor type.
    Value fillTensor =
        createRank0Tensor(rewriter, op.getLoc(), resultTy, op.getValue());
    rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resultTy, op.getMask(),
                                                 fillTensor, op.getSelf());
    return success();
  }
};
} // namespace
// Decompose aten.convolution_overrideable to aten.convolution op.
namespace {
class DecomposeAtenConvolutionOverrideableOp
@ -1977,23 +2204,23 @@ public:
// aten.bernoulli.float(x, p) = (randLike(float(x)) < tensor(p)).cast(type(x)).
// Since the input x can be an integer tensor, it's important to cast it to
// float type before passing it to the `aten.randLike` op.
class DecomposeValsemVariantAtenBernoulliFloatOp
: public OpRewritePattern<ValsemVariantAtenBernoulliFloatOp> {
template <typename BernoulliLikeOp>
class DecomposeAtenBernoulliLikeOp : public OpRewritePattern<BernoulliLikeOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ValsemVariantAtenBernoulliFloatOp op,
using OpRewritePattern<BernoulliLikeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(BernoulliLikeOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = op.getSelf();
Value p = op.getP();
if (!op.getGenerator().getType().isa<Torch::NoneType>())
if (!op.getGenerator().getType().template isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "The generator has to ben None because only global default "
"generator is supported");
auto inputType = input.getType().cast<BaseTensorType>();
SmallVector<int64_t> empty;
Type tensorType = inputType.getWithSizesAndDtype(llvm::makeArrayRef(empty),
Type tensorType = inputType.getWithSizesAndDtype(llvm::ArrayRef(empty),
rewriter.getF64Type());
Value prob = rewriter.create<PrimNumToTensorScalarOp>(loc, tensorType, p);
Value output;
@ -2071,8 +2298,8 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
std::vector<int64_t> meanVarSizes(inputRank, 1);
for (int i = 0; i < axis; i++)
meanVarSizes[i] = input.getSizes()[i];
auto meanVarType = input.getWithSizesAndDtype(
llvm::makeArrayRef(meanVarSizes), input.getOptionalDtype());
auto meanVarType = input.getWithSizesAndDtype(llvm::ArrayRef(meanVarSizes),
input.getOptionalDtype());
auto nativeLayerNorm = rewriter.create<AtenNativeLayerNormOp>(
loc, op.getType(), meanVarType, meanVarType, op.getInput(),
op.getNormalizedShape(), op.getWeight(), op.getBias(), op.getEps());
@ -2309,7 +2536,7 @@ class DecomposeAtenNativeBatchNormOp
runningStatsShapeInt[1] = kUnknownSize;
Type dtype = input.getType().cast<ValueTensorType>().getOptionalDtype();
Type reshapeType = ValueTensorType::get(
context, llvm::makeArrayRef(runningStatsShapeInt), dtype);
context, llvm::ArrayRef(runningStatsShapeInt), dtype);
runningMean = rewriter.create<AtenViewOp>(loc, reshapeType, runningMean,
runningStatsSizeList);
@ -2455,8 +2682,7 @@ public:
SmallVector<int64_t> empty;
auto dtype =
getTypeForTorchType(op.getContext(), op.getFillValue().getType());
Type tensorType =
outTy.getWithSizesAndDtype(llvm::makeArrayRef(empty), dtype);
Type tensorType = outTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype);
Value fillVal = rewriter.create<PrimNumToTensorScalarOp>(loc, tensorType,
op.getFillValue());
fillVal = convertTensorToDtype(rewriter, loc, fillVal, outTy.getDtype());
@ -2492,7 +2718,7 @@ public:
SmallVector<int64_t> transposeShape =
llvm::to_vector(llvm::reverse(weightType.getSizes()));
Type transposeType = weightType.getWithSizesAndDtype(
llvm::makeArrayRef(transposeShape), weightType.getOptionalDtype());
llvm::ArrayRef(transposeShape), weightType.getOptionalDtype());
Value transposeWeight =
rewriter.create<AtenTOp>(loc, transposeType, weight);
@ -2562,8 +2788,7 @@ public:
SmallVector<int64_t> empty;
auto dtype =
getTypeForTorchType(op.getContext(), op.getFillValue().getType());
Type tensorType =
outTy.getWithSizesAndDtype(llvm::makeArrayRef(empty), dtype);
Type tensorType = outTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype);
Value fillVal = rewriter.create<PrimNumToTensorScalarOp>(
op.getLoc(), tensorType, op.getFillValue());
fillVal =
@ -3003,7 +3228,7 @@ class DecomposeAtenNumpyTOp : public OpRewritePattern<AtenNumpyTOp> {
template <typename OpTy>
static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
bool unbiased, int64_t correction) {
bool unbiased, double correction) {
Location loc = op.getLoc();
Value self = op.getSelf();
Value dimList = op.getDim();
@ -3089,19 +3314,22 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
productDimSize =
rewriter.create<AtenMulIntOp>(loc, productDimSize, dimSize);
}
Value cstCorrection = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(correction));
productDimSize = rewriter.create<AtenFloatScalarOp>(loc, productDimSize);
constantOne = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(1.0));
Value cstCorrection = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(correction));
// The `correction` value should be less than or equal to `productDimSize +
// 1`.
Value productDimSizePlusOne =
rewriter.create<AtenAddIntOp>(loc, productDimSize, constantOne);
Value productDimSizePlusOne = rewriter.create<AtenAddOp>(
loc, productDimSize.getType(), productDimSize, constantOne);
Value cond =
rewriter.create<AtenGeIntOp>(loc, productDimSizePlusOne, cstCorrection);
rewriter.create<AtenGeFloatOp>(loc, productDimSizePlusOne, cstCorrection);
rewriter.create<RuntimeAssertOp>(
loc, cond,
"correction value should be less than or equal to productDimSize + 1");
Value productDimSizeSubCorrection =
rewriter.create<AtenSubIntOp>(loc, productDimSize, cstCorrection);
rewriter.create<AtenSubFloatOp>(loc, productDimSize, cstCorrection);
Value result = rewriter.create<AtenDivScalarOp>(loc, newOutputType, squareSum,
productDimSizeSubCorrection);
result =
@ -3128,7 +3356,7 @@ public:
return rewriter.notifyMatchFailure(
op, "Only support constant unbiased for aten.var");
}
int64_t correction = unbiased ? 1 : 0;
double correction = unbiased ? 1.0 : 0.0;
if (failed(calculateVariance<AtenVarDimOp>(op, rewriter, unbiased,
correction)))
return rewriter.notifyMatchFailure(op, "invalid variance parameters");
@ -3148,18 +3376,32 @@ public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenVarCorrectionOp op,
PatternRewriter &rewriter) const override {
int64_t correction;
int64_t correctionValInt;
double correctionValFloat = 1.0;
if (!op.getCorrection().getType().isa<Torch::NoneType>()) {
if (!matchPattern(op.getCorrection(), m_TorchConstantInt(&correction)))
if (op.getCorrection().getType().isa<Torch::FloatType>()) {
if (!matchPattern(op.getCorrection(),
m_TorchConstantFloat(&correctionValFloat)))
return rewriter.notifyMatchFailure(
op, "Only support constant int or float correction value for "
"aten.var");
} else if (op.getCorrection().getType().isa<Torch::IntType>()) {
if (!matchPattern(op.getCorrection(),
m_TorchConstantInt(&correctionValInt)))
return rewriter.notifyMatchFailure(
op, "Only support constant int or float correction value for "
"aten.var");
correctionValFloat = (double)correctionValInt;
} else {
return rewriter.notifyMatchFailure(
op, "Only support constant int correction for aten.var");
} else {
// The default value in case of `correction` being None is 1.
correction = 1;
op, "unimplemented: correction value should be only constant int "
"or float for aten.var");
}
}
bool unbiased = correction == 0 ? false : true;
bool unbiased = correctionValFloat == 0.0 ? false : true;
if (failed(calculateVariance<AtenVarCorrectionOp>(op, rewriter, unbiased,
correction)))
correctionValFloat)))
return rewriter.notifyMatchFailure(op, "invalid variance parameters");
return success();
}
@ -3184,29 +3426,13 @@ public:
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value startPlusOne =
rewriter.create<AtenAddIntOp>(loc, one.getType(), start, one);
BaseTensorType srcTensorType = src.getType().cast<BaseTensorType>();
SmallVector<int64_t> sizes;
if (!srcTensorType.hasSizes())
return rewriter.notifyMatchFailure(op, "src tensor must have size");
ArrayRef<int64_t> srcShape = srcTensorType.getSizes();
// `src` has a reduced rank. Hence add 1.
int64_t srcRank = srcShape.size() + 1;
int64_t dimInt = 0;
if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
dimInt = toPositiveDim(dimInt, srcRank);
if (!isValidDim(dimInt, srcRank))
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
sizes.append(srcShape.begin(), srcShape.end());
sizes.insert(sizes.begin() + dimInt, 1);
} else {
sizes.resize(srcShape.size() + 1, kUnknownSize);
auto unsqueezedInfo = unsqueezeTensor(rewriter, op, src, dim);
if (failed(unsqueezedInfo)) {
return rewriter.notifyMatchFailure(op,
"cannot generate unsqueeze tensor op");
}
Type srcType = srcTensorType.getWithSizesAndDtype(
llvm::makeArrayRef(sizes), srcTensorType.getOptionalDtype());
src = rewriter.create<AtenUnsqueezeOp>(loc, srcType, src, dim);
src = *unsqueezedInfo;
rewriter.replaceOpWithNewOp<AtenSliceScatterOp>(
op, op.getSelf().getType(), self, src, dim, start, startPlusOne,
/*step=*/one);
@ -3303,7 +3529,7 @@ public:
op, "Expected the input tensor to have sizes");
BaseTensorType subType =
inputType
.getWithSizesAndDtype(llvm::makeArrayRef(inputType.getSizes()),
.getWithSizesAndDtype(llvm::ArrayRef(inputType.getSizes()),
resultType.getOptionalDtype())
.cast<BaseTensorType>();
@ -3330,6 +3556,29 @@ public:
};
} // namespace
namespace {
// Rewrites `aten.norm.ScalarOpt_dim` as an `aten.linalg_vector_norm` op.
//
// The `self`, `dim` and `keepdim` operands are forwarded unchanged. When the
// norm order `p` has `None` type, it is replaced by the float constant 2.0.
// The `dtype` operand of the new op is always a `None` constant.
class DecomposeAtenNormScalarOptDimOp
    : public OpRewritePattern<AtenNormScalarOptDimOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenNormScalarOptDimOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Value noDtype = rewriter.create<Torch::ConstantNoneOp>(loc);

    // Substitute 2.0 for an order operand of `None` type.
    Value order = op.getP();
    const bool orderIsNone = order.getType().isa<Torch::NoneType>();
    if (orderIsNone)
      order = rewriter.create<Torch::ConstantFloatOp>(
          loc, rewriter.getF64FloatAttr(2.0));

    rewriter.replaceOpWithNewOp<AtenLinalgVectorNormOp>(
        op, op.getType(), op.getSelf(), order, op.getDim(), op.getKeepdim(),
        /*dtype=*/noDtype);
    return success();
  }
};
} // namespace
namespace {
class DecomposeAtenRandintLowOp : public OpRewritePattern<AtenRandintLowOp> {
public:
@ -3526,6 +3775,40 @@ public:
};
} // namespace
namespace {
// Rewrites `aten.randn_like` as `aten.randn.generator`.
//
// The size list for the new op is obtained with `aten.size` on `self`, the
// generator operand is a `None` constant, and the dtype/layout/device/
// pin_memory operands are forwarded from the original op.
class DecomposeAtenRandnLikeOp : public OpRewritePattern<AtenRandnLikeOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenRandnLikeOp op,
                                PatternRewriter &rewriter) const override {
    // A memory_format operand is accepted only if it has `None` type or is a
    // constant integer equal to `Contiguous` or `Preserve`.
    Value memoryFormatOperand = op.getMemoryFormat();
    if (!memoryFormatOperand.getType().isa<Torch::NoneType>()) {
      int64_t memoryFormat;
      if (!matchPattern(memoryFormatOperand,
                        m_TorchConstantInt(&memoryFormat)))
        return rewriter.notifyMatchFailure(
            op, "unimplemented: the memory format should be specified in "
                "an integer constant");
      const bool isSupportedFormat =
          memoryFormat == torch_upstream::MemoryFormat::Contiguous ||
          memoryFormat == torch_upstream::MemoryFormat::Preserve;
      if (!isSupportedFormat)
        return rewriter.notifyMatchFailure(
            op, "unimplemented: only none, contiguous and preserve "
                "memory_format is supported");
    }

    Location loc = op.getLoc();
    Value noGenerator = rewriter.create<Torch::ConstantNoneOp>(loc);
    auto intListType =
        Torch::ListType::get(Torch::IntType::get(op.getContext()));
    Value selfSizes =
        rewriter.create<AtenSizeOp>(loc, intListType, op.getSelf());
    rewriter.replaceOpWithNewOp<AtenRandnGeneratorOp>(
        op, op.getType(), selfSizes, /*generator=*/noGenerator, op.getDtype(),
        op.getLayout(), op.getDevice(), op.getPinMemory());
    return success();
  }
};
} // namespace
namespace {
class DecomposeAtenVarMeanOp : public OpRewritePattern<AtenVarMeanOp> {
public:
@ -3546,6 +3829,49 @@ public:
};
} // namespace
namespace {
// Decomposes `aten.new_empty_strided` into `aten.new_empty`.
//
// The rewrite only applies when both the size and stride lists are lists of
// constant ints, have the same number of elements, and the strides are the
// default ones implied by the sizes (each stride equals the product of all
// sizes to its right). Any other input is reported as a match failure.
class DecomposeAtenNewEmptyStridedOp
    : public OpRewritePattern<AtenNewEmptyStridedOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenNewEmptyStridedOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<int64_t> sizeListInts, strideListInts;
    if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizeListInts)))
      return rewriter.notifyMatchFailure(
          op, "all size list elements must be constant ints");
    if (!matchPattern(op.getStride(),
                      m_TorchListOfConstantInts(strideListInts)))
      return rewriter.notifyMatchFailure(
          op, "all stride list elements must be constant ints");

    // Without this check a stride list that is shorter or longer than the
    // size list could still pass the per-element comparison below.
    if (sizeListInts.size() != strideListInts.size())
      return rewriter.notifyMatchFailure(
          op, "size and stride lists must have the same number of elements");

    // We only support the cases with default stride values.
    // For ex: aten.new_empty_strided(self, size=[2, 3, 4], stride=[12, 4, 1])
    // Here the stride[0] == size[1] * size[2], stride[1] == size[2], and
    // stride[2] == 1.
    for (unsigned i = 0; i < strideListInts.size(); i++) {
      int64_t defaultStride = 1;
      for (unsigned j = i + 1; j < sizeListInts.size(); j++)
        defaultStride *= sizeListInts[j];
      if (defaultStride != strideListInts[i])
        return rewriter.notifyMatchFailure(
            op, "only default strides supported for new_empty_strided op");
    }

    // The stride operand is dropped; every other operand is forwarded.
    rewriter.replaceOpWithNewOp<AtenNewEmptyOp>(
        op, op.getType(), op.getSelf(), op.getSize(), op.getDtype(),
        op.getLayout(), op.getDevice(), op.getPinMemory());
    return success();
  }
};
} // namespace
namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@ -3591,6 +3917,7 @@ public:
DecomposeConstantTensorAllocLikeOp<AtenOnesLikeOp, 1>>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeConstantTensorAllocLikeOp<AtenZerosLikeOp, 0>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenStackOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRollOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRepeatOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenExpandOp>(patterns);
@ -3598,6 +3925,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarOtherOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarSelfOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMaskedFillScalarOp>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenConvolutionBackwardOverrideableOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSizeOp>(patterns);
@ -3640,8 +3968,11 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAten_UnsafeViewOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_ReshapeAliasOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBernoulliOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeValsemVariantAtenBernoulliFloatOp>(
addPatternIfTargetOpIsIllegal<
DecomposeAtenBernoulliLikeOp<ValsemVariantAtenBernoulliFloatOp>>(
patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenBernoulliLikeOp<AtenBernoulliPOp>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBernoulliTensorOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenZeroOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRandLikeOp>(patterns);
@ -3688,6 +4019,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexTensorHackedTwinOp>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMseLossOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNormScalarOptDimOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRandintLowOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanCorrectionOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposePrimsConvertElementTypeOp>(patterns);
@ -3695,9 +4027,12 @@ public:
addPatternIfTargetOpIsIllegal<DecomposePrimsSqrtOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnGeneratorOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnLikeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluBackwardOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyStridedOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBucketizeTensorOp>(patterns);
GreedyRewriteConfig config;
config.useTopDownTraversal = true;
@ -3715,4 +4050,4 @@ std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createDecomposeComplexOpsPass(
ArrayRef<std::string> legalOps) {
return std::make_unique<DecomposeComplexOpsPass>(legalOps);
}
}

View File

@ -10,9 +10,9 @@
#include "PassDetail.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"

View File

@ -10,9 +10,9 @@
#include "PassDetail.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
@ -244,7 +244,7 @@ createGlobalSlotModuleInitializer(ModuleOp module, SymbolTable &symbolTable,
continue;
opsToMove.push_back(&op);
}
BlockAndValueMapping mapping;
IRMapping mapping;
for (Operation *op : opsToMove) {
// The ops are used by `torch.slot` ops in the enclosing module.
// Cloning avoids needing to handle those uses specially.
@ -329,7 +329,7 @@ template <> struct llvm::DenseMapInfo<Monomorphization> {
// currently only analyzes a subset of ops.
static LogicalResult analyzeInstances(func::FuncOp func,
ArrayRef<ArgInstance> argInstances,
BlockAndValueMapping &mapping) {
IRMapping &mapping) {
for (auto &argInstance : argInstances)
mapping.map(func.getArgument(argInstance.argIndex), argInstance.instance);
auto walkResult = func.walk([&](PrimGetAttrOp op) {
@ -349,7 +349,7 @@ static LogicalResult analyzeInstances(func::FuncOp func,
}
static FailureOr<Monomorphization>
createMonomorphizationForCall(func::CallOp op, BlockAndValueMapping &mapping,
createMonomorphizationForCall(func::CallOp op, IRMapping &mapping,
SymbolTable &symbolTable) {
auto func = symbolTable.lookup<func::FuncOp>(op.getCallee());
Monomorphization monomorphization;
@ -410,7 +410,7 @@ public:
private:
LogicalResult generateNewMonomorphizations(const Monomorphization &m) {
auto func = m.func;
BlockAndValueMapping mapping;
IRMapping mapping;
if (failed(analyzeInstances(func, m.argInstances, mapping)))
return failure();
auto walkResult = func.walk([&](func::CallOp op) {
@ -495,7 +495,7 @@ verifyPublicMonomorphizations(ModuleOp module, SymbolTable &symbolTable,
// Rewrite `func`, given that all values of `NnModuleType` have been mapped in
// `mapping` to corresponding global instances.
static LogicalResult rewriteMonomorphizedFuncClone(
func::FuncOp func, BlockAndValueMapping mapping, SymbolTable &symbolTable,
func::FuncOp func, IRMapping mapping, SymbolTable &symbolTable,
DenseMap<Monomorphization, func::FuncOp> &newFuncs,
ObjectGraphInfo &objectGraphInfo) {
@ -662,7 +662,7 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) {
}
for (auto &kv : newFuncs) {
BlockAndValueMapping mapping;
IRMapping mapping;
if (failed(analyzeInstances(kv.second, kv.first.argInstances, mapping)))
return failure();
if (failed(rewriteMonomorphizedFuncClone(kv.second, mapping, symbolTable,

View File

@ -27,8 +27,8 @@
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
@ -373,7 +373,7 @@ class InlineGlobalSlotsPass
// big deal.
SmallVector<Operation *> slice =
getBackwardSliceIncludingRoot(initialValue);
BlockAndValueMapping mapping;
IRMapping mapping;
OpBuilder builder(op);
for (Operation *opInSlice : slice)
builder.clone(*opInSlice, mapping);

View File

@ -285,19 +285,16 @@ public:
}
};
class VerifyBackendContractPass
: public VerifyBackendContractBase<VerifyBackendContractPass> {
class VerifyBackendContractNoDecompositionsPass
: public VerifyBackendContractNoDecompositionsBase<VerifyBackendContractNoDecompositionsPass> {
public:
VerifyBackendContractPass() = default;
VerifyBackendContractPass(bool decompose,
ArrayRef<std::string> backendLegalOps) {
this->decompose = decompose;
this->backendLegalOps = backendLegalOps;
}
VerifyBackendContractNoDecompositionsPass() = default;
void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target =
getBackendContractTarget(context, decompose, backendLegalOps);
getBackendContractTarget(context, /*decompose*/false,
/*backendLegalOps*/{});
if (!satisfiesBackendContract(getOperation(), target,
/*actuallyEmitDiagnostics=*/true)) {
@ -315,10 +312,8 @@ mlir::torch::Torch::createLowerToBackendContractPass(
}
std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::Torch::createVerifyBackendContractPass(
bool decompose, ArrayRef<std::string> backendLegalOps) {
return std::make_unique<VerifyBackendContractPass>(decompose,
backendLegalOps);
mlir::torch::Torch::createVerifyBackendContractNoDecompositionsPass() {
return std::make_unique<VerifyBackendContractNoDecompositionsPass>();
}
// The backend contract guarantees that ops with decompositions available will
@ -347,6 +342,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenEmptyLikeOp>();
target.addIllegalOp<AtenOnesLikeOp>();
target.addIllegalOp<AtenZerosLikeOp>();
target.addIllegalOp<AtenStackOp>();
target.addIllegalOp<AtenRollOp>();
target.addIllegalOp<AtenRepeatOp>();
target.addIllegalOp<AtenExpandOp>();
@ -354,6 +350,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenWhereScalarOp>();
target.addIllegalOp<AtenWhereScalarOtherOp>();
target.addIllegalOp<AtenWhereScalarSelfOp>();
target.addIllegalOp<AtenMaskedFillScalarOp>();
target.addIllegalOp<AtenConvolutionBackwardOverrideableOp>();
target.addIllegalOp<AtenSizeOp>();
target.addIllegalOp<AtenReshapeOp>();
@ -362,6 +359,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenAddmmOp>();
target.addIllegalOp<AtenMeanOp>();
target.addIllegalOp<AtenMeanDimOp>();
target.addIllegalOp<AtenNormScalarOptDimOp>();
target.addIllegalOp<AtenSelectIntOp>();
target.addIllegalOp<AtenMvOp>();
target.addIllegalOp<AtenTOp>();
@ -394,6 +392,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<Aten_ReshapeAliasOp>();
target.addIllegalOp<AtenBernoulliOp>();
target.addIllegalOp<ValsemVariantAtenBernoulliFloatOp>();
target.addIllegalOp<AtenBernoulliPOp>();
target.addIllegalOp<AtenBernoulliTensorOp>();
target.addIllegalOp<AtenZeroOp>();
target.addIllegalOp<AtenRandLikeOp>();
@ -442,7 +441,10 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<PrimsSqrtOp>();
target.addIllegalOp<AtenRandnOp>();
target.addIllegalOp<AtenRandnGeneratorOp>();
target.addIllegalOp<AtenRandnLikeOp>();
target.addIllegalOp<AtenVarMeanOp>();
target.addIllegalOp<AtenNewEmptyStridedOp>();
target.addIllegalOp<AtenBucketizeTensorOp>();
for (std::string opName : backendLegalOps) {
target.addLegalOp(OperationName(opName, context));
}

View File

@ -106,6 +106,7 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
// Clean up again to avoid needing to go back around the fixed-point
// iteration.
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createRecomposeComplexOps());
// Reduce variants of ops to a smaller set of primitives.
pm.addNestedPass<func::FuncOp>(createReduceOpVariantsPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());

View File

@ -10,7 +10,6 @@
#include "PassDetail.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"

View File

@ -0,0 +1,103 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
namespace {
// Recomposes `aten.copy_` whose destination is the result of an
// `aten.slice.Tensor` into an `aten._index_put_impl_` on the sliced tensor.
//
// The pattern only matches when the slice's `dim` and `end` are constant
// ints. A negative `end` is rewritten as `size(self, dim) + end`. The index
// list handed to `_index_put_impl_` is `dim` `None` entries followed by an
// `aten.arange.start_step` over `[start, end)` with the slice's step, so the
// range indexes dimension `dim` of the sliced tensor.
class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenCopy_Op op,
                                PatternRewriter &rewriter) const override {
    auto sliceOp = op.getSelf().getDefiningOp<AtenSliceTensorOp>();
    if (!sliceOp)
      return failure();

    // Get indices
    int64_t dim;
    if (!matchPattern(sliceOp.getDim(), m_TorchConstantInt(&dim)))
      return failure();
    // A negative `dim` counts from the back. Normalize it using the rank of
    // the sliced tensor so the number of leading `None` indices is right; if
    // the rank is not known, do not rewrite.
    if (dim < 0) {
      auto slicedType = sliceOp.getSelf().getType().dyn_cast<BaseTensorType>();
      if (!slicedType || !slicedType.hasSizes())
        return failure();
      dim += static_cast<int64_t>(slicedType.getSizes().size());
      if (dim < 0)
        return failure();
    }

    int64_t end;
    if (!matchPattern(sliceOp.getEnd(), m_TorchConstantInt(&end)))
      return failure();

    // All matching is done; IR is only created from this point on.
    Value newEnd = sliceOp.getEnd();
    if (end < 0) {
      Value dimSize = rewriter.create<AtenSizeIntOp>(
          op.getLoc(), sliceOp.getSelf(), sliceOp.getDim());
      newEnd =
          rewriter.create<AtenAddIntOp>(op.getLoc(), dimSize, sliceOp.getEnd());
    }

    Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
    Value falseVal = rewriter.create<ConstantBoolOp>(op.getLoc(), false);

    // Create IndexPut_Op
    // NOTE(review): the arange result and the list element type both reuse
    // the result type of the `copy_` op — confirm this is the intended type.
    BaseTensorType tensorType = op->getResultTypes()[0].cast<BaseTensorType>();
    Value range = rewriter.create<AtenArangeStartStepOp>(
        op.getLoc(), tensorType, sliceOp.getStart(), newEnd, sliceOp.getStep(),
        /*dtype=*/noneVal, /*layout=*/noneVal, /*device=*/noneVal,
        /*pin_memory=*/noneVal);

    // One `None` per dimension preceding `dim`, then the range for `dim`.
    // (The bound was previously `dim - 1`, which placed the range one
    // dimension too early for every dim >= 1.)
    SmallVector<Value> indicesVector;
    for (int64_t i = 0; i < dim; i++)
      indicesVector.push_back(noneVal);
    indicesVector.push_back(range);
    Value indices = rewriter.create<PrimListConstructOp>(
        op.getLoc(),
        Torch::ListType::get(op->getContext(),
                             Torch::OptionalType::get(tensorType)),
        indicesVector);

    rewriter.replaceOpWithNewOp<Aten_IndexPutImpl_Op>(
        op, op->getResultTypes(), sliceOp.getSelf(), indices, op.getSrc(),
        /*accumulate=*/falseVal, /*unsafe=*/falseVal);
    return success();
  }
};
} // namespace
namespace {
// Pass that greedily applies the recomposition patterns (currently only
// `RecomposeSliceCopy_`) to the operation it runs on.
//
// NOTE(review): this class derives from `DecomposeComplexOpsBase` rather than
// a base of its own — confirm that reusing the decompose pass base is
// intended.
class RecomposeComplexOps
    : public DecomposeComplexOpsBase<RecomposeComplexOps> {
public:
  RecomposeComplexOps() = default;
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();

    // Top-down traversal with no cap on the number of iterations.
    GreedyRewriteConfig rewriteConfig;
    rewriteConfig.useTopDownTraversal = true;
    rewriteConfig.maxIterations = GreedyRewriteConfig::kNoLimit;

    // pattern.add calls go here
    RewritePatternSet recomposePatterns(ctx);
    recomposePatterns.add<RecomposeSliceCopy_>(ctx);

    LogicalResult status = applyPatternsAndFoldGreedily(
        getOperation(), std::move(recomposePatterns), rewriteConfig);
    if (failed(status))
      signalPassFailure();
  }
};
} // namespace
// Factory for the recompose pass: returns a newly allocated
// `RecomposeComplexOps` instance as a `func::FuncOp` operation pass.
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createRecomposeComplexOps() {
return std::make_unique<RecomposeComplexOps>();
}

View File

@ -59,7 +59,6 @@
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
@ -81,7 +80,9 @@ using namespace mlir::torch::Torch;
// -----------------------------------------------------------------------------
static Type getTypeForDTypeInteger(MLIRContext *context, int64_t dtypeInt) {
return getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
FailureOr<Type> result =
getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
return failed(result) ? Type() : *result;
}
static Type getDtypeOrDefault(MLIRContext *context, Value optionalDtype,
@ -111,24 +112,6 @@ static torch_upstream::TypeKind getTypeKind(Type type) {
return torch_upstream::TypeKind::AnyType;
}
/// Returns the dtype that assumes information from both `lhs` and `rhs`.
/// Returns `std::nullopt` if the types are contradictory. Note this can only
/// be used on the `dtype` from tensors and can't be used on other types like
/// scalar types.
static std::optional<Type> meetElementTypes(Type lhs, Type rhs) {
auto isNullOrBuiltIn = [](Type type) { return !type || isBuiltInType(type); };
(void)isNullOrBuiltIn;
assert(isNullOrBuiltIn(lhs) && "`lhs` must be a builtin type");
assert(isNullOrBuiltIn(rhs) && "`rhs` must be a builtin type");
if (!lhs)
return rhs;
if (!rhs)
return lhs;
if (lhs == rhs)
return lhs;
return std::nullopt;
}
enum class OptionalKnowledge {
unKnown,
@ -475,7 +458,8 @@ private:
void visitAtenToDtypeLikeOp(OpTy op, ArrayRef<const ValueState *> operands);
template <typename OpTy>
void visitTypeConversionOp(OpTy op, ArrayRef<const ValueState *> operands);
void visitAtenCatOp(AtenCatOp op, ArrayRef<const ValueState *> operands);
template <typename OpTy>
void visitAtenCatLikeOp(OpTy op, ArrayRef<const ValueState *> operands);
template <typename OpTy>
void visitAtenSoftmaxLikeOp(OpTy op, ArrayRef<const ValueState *> operands);
@ -563,7 +547,9 @@ static Type getPromotedResultDType(ValueKnowledge *tensor, Type scalarType) {
/*skipRankCheck=*/true);
state =
updateResultTypeState(getDefaultDtypeForTorchScalar(scalarType), state);
return getTypeForScalarType(scalarType.getContext(), result_type(state));
FailureOr<Type> result =
getTypeForScalarType(scalarType.getContext(), result_type(state));
return failed(result) ? Type() : *result;
}
static SmallVector<std::optional<bool>>
@ -600,7 +586,8 @@ static Type getPromotedResultType(MLIRContext *context,
return Type();
state = updateResultTypeState(tensor, rankIsNonZero, state, skipRankCheck);
}
return getTypeForScalarType(context, result_type(state));
FailureOr<Type> result = getTypeForScalarType(context, result_type(state));
return failed(result) ? Type() : *result;
}
static Type getPromotedResultTypeAssumingNonZeroRank(
@ -649,23 +636,26 @@ void TypeAnalysis::visitOperation(Operation *op,
AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopyOp, AtenCumsumOp,
AtenLayerNormOp, AtenClampOp, AtenClampMinOp, AtenClampMaxOp,
AtenNegOp, AtenFloorOp, Aten_SoftmaxBackwardDataOp, AtenDropoutOp,
AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
AtenAbsOp, AtenThresholdOp, AtenSquareOp, AtenUniformOp,
AtenBernoulliOp, AtenBernoulli_FloatOp, AtenBernoulliTensorOp,
AtenTanhBackwardOp, AtenHardtanhBackwardOp,
Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp, AtenAbsOp,
AtenThresholdOp, AtenSquareOp, AtenUniformOp, AtenBernoulliOp,
AtenBernoulli_FloatOp, AtenBernoulliTensorOp,
ValsemVariantAtenBernoulliFloatOp, AtenBernoulliTensorOp,
AtenFillScalarOp, AtenHardsigmoidOp, AtenCloneOp, AtenHardswishOp,
AtenSiluOp, AtenHardtanhOp, AtenMaskedSelectOp, AtenMaxPool2dOp,
AtenAvgPool2dOp, AtenAdaptiveAvgPool2dOp, AtenFlattenUsingIntsOp,
AtenSqueezeOp, AtenSqueezeDimOp, AtenUnsqueezeOp, AtenViewOp,
Aten_UnsafeViewOp, AtenReshapeOp, Aten_ReshapeAliasOp, AtenResize_Op,
AtenTransposeIntOp, AtenTOp, AtenPermuteOp, AtenIndexSelectOp,
AtenSelectIntOp, AtenSelectScatterOp, AtenNarrowOp, AtenSliceTensorOp,
AtenSliceScatterOp, AtenGatherOp, AtenExpandOp, AtenExpandAsOp,
AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp,
AtenZero_Op, AtenIndexTensorOp, Aten_IndexPutImplOp, AtenIndexPutOp,
AtenCopyOp, AtenZeroOp, AtenIndexPutHackedTwinOp, AtenPreluOp,
AtenMaskedFillScalarOp, AtenFlipOp, PrimAbsScalarOp, AtenNumpyTOp,
AtenTriuOp, AtenMaskedFillTensorOp, AtenRollOp, AtenPowTensorTensorOp,
AtenBernoulliPOp, AtenFillScalarOp, AtenHardsigmoidOp, AtenCloneOp,
AtenHardswishOp, AtenSiluOp, AtenHardtanhOp, AtenMaskedSelectOp,
AtenMaxPool2dOp, AtenAvgPool2dOp, AtenAdaptiveAvgPool2dOp,
AtenFlattenUsingIntsOp, AtenSqueezeOp, AtenSqueezeDimOp,
AtenUnsqueezeOp, AtenViewOp, Aten_UnsafeViewOp, AtenReshapeOp,
Aten_ReshapeAliasOp, AtenResize_Op, AtenTransposeIntOp, AtenTOp,
AtenPermuteOp, AtenIndexSelectOp, AtenSelectIntOp,
AtenSelectScatterOp, AtenNarrowOp, AtenSliceTensorOp,
AtenScatterReduceTwoOp, AtenSliceScatterOp, AtenGatherOp,
AtenExpandOp, AtenExpandAsOp, AtenBroadcastToOp, AtenRepeatOp,
AtenConstantPadNdOp, AtenPadOp, AtenZero_Op, AtenIndexTensorOp,
Aten_IndexPutImplOp, AtenIndexPutOp, AtenCopyOp, AtenZeroOp,
AtenIndexPutHackedTwinOp, AtenPreluOp, AtenMaskedFillScalarOp,
AtenFlipOp, PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp,
AtenMaskedFillTensorOp, AtenRollOp, AtenPowTensorTensorOp,
AtenLiftFreshCopyOp, AtenIndexTensorHackedTwinOp,
AtenUpsampleNearest2dOp, AtenMishOp, AtenRoundOp, AtenFillTensorOp,
AtenUpsampleNearest2dBackwardOp, AtenLeakyReluBackwardOp>(op)) {
@ -970,9 +960,16 @@ void TypeAnalysis::visitOperation(Operation *op,
} else if (auto newEmpty = dyn_cast<AtenNewEmptyOp>(op)) {
visitConstantTensorNewLikeOp<AtenNewEmptyOp>(newEmpty, operands);
return;
} else if (auto newEmptyStrided = dyn_cast<AtenNewEmptyStridedOp>(op)) {
visitConstantTensorNewLikeOp<AtenNewEmptyStridedOp>(newEmptyStrided,
operands);
return;
} else if (auto randLike = dyn_cast<AtenRandLikeOp>(op)) {
visitConstantTensorAllocLikeOp<AtenRandLikeOp>(randLike, operands);
return;
} else if (auto randLike = dyn_cast<AtenRandnLikeOp>(op)) {
visitConstantTensorAllocLikeOp<AtenRandnLikeOp>(randLike, operands);
return;
} else if (auto toCopy = dyn_cast<Aten_ToCopyOp>(op)) {
visitConstantTensorAllocLikeOp<Aten_ToCopyOp>(toCopy, operands);
return;
@ -1008,7 +1005,10 @@ void TypeAnalysis::visitOperation(Operation *op,
}
if (auto cat = dyn_cast<AtenCatOp>(op)) {
visitAtenCatOp(cat, operands);
visitAtenCatLikeOp<AtenCatOp>(cat, operands);
return;
} else if (auto stack = dyn_cast<AtenStackOp>(op)) {
visitAtenCatLikeOp<AtenStackOp>(stack, operands);
return;
}
@ -1114,6 +1114,22 @@ void TypeAnalysis::visitOperation(Operation *op,
return;
}
if (auto bucketize = dyn_cast<AtenBucketizeTensorOp>(op)) {
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
bool outInt32;
if (matchPattern(bucketize.getOutInt32(), m_TorchConstantBool(&outInt32)) &&
outInt32) {
knowledge.dtype =
IntegerType::get(op->getContext(), 32, IntegerType::Signed);
} else {
knowledge.dtype =
IntegerType::get(op->getContext(), 64, IntegerType::Signed);
}
incorporateKnowledge(bucketize.getResult(), knowledge);
return;
}
// Otherwise, this is an unknown operation, so reset the state.
setAllToEntryStates(results);
return;
@ -1338,30 +1354,26 @@ void TypeAnalysis::visitTypeConversionOp(
// `torch.aten.cat` concatenates the given sequence of seq tensors in the given
// dimension. The output has the same sizes as the input for all dimensions
// except the given dimension.
void TypeAnalysis::visitAtenCatOp(AtenCatOp op,
ArrayRef<const ValueState *> operands) {
template <typename OpTy>
void TypeAnalysis::visitAtenCatLikeOp(OpTy op,
ArrayRef<const ValueState *> operands) {
auto tensorList = op.getTensors();
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
auto listConstruct = tensorList.getDefiningOp<PrimListConstructOp>();
auto listConstruct = tensorList.template getDefiningOp<PrimListConstructOp>();
if (!listConstruct) {
incorporateKnowledge(op.getResult(), knowledge);
return;
}
auto tensors = llvm::to_vector<4>(
llvm::map_range(listConstruct.getElements(), [&](Value v) -> ValueKnowledge {
return getLatticeElement(v)->getValue();
SmallVector<ValueKnowledge*> tensors = llvm::to_vector(
llvm::map_range(listConstruct.getElements(), [&](Value v) -> ValueKnowledge* {
return &getLatticeElement(v)->getValue();
}));
for (auto tensor : tensors) {
auto newDtype = meetElementTypes(knowledge.dtype, tensor.dtype);
if (!newDtype.has_value()) {
incorporateKnowledge(op.getResult(), knowledge);
return;
}
knowledge.dtype = newDtype.value();
}
incorporateKnowledge(op.getResult(), knowledge);
knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
op->getContext(), tensors);
incorporateKnowledge(op->getResult(0), knowledge);
}
void TypeAnalysis::visitNumToTensorOp(PrimNumToTensorScalarOp op) {
@ -1436,12 +1448,16 @@ static Type getMostRefinedStaticType(Value v, DataFlowSolver &solver) {
if (!latticeElement)
return nullptr;
const ValueKnowledge &knowledge = latticeElement->getValue();
if (!knowledge.isInitialized)
return nullptr;
return getRefinedTensorType(tensorType, knowledge);
} else if (auto optionalType = v.getType().dyn_cast<OptionalType>()) {
const ValueState *latticeElement = solver.lookupState<ValueState>(v);
if (!latticeElement)
return nullptr;
const ValueKnowledge &knowledge = latticeElement->getValue();
if (!knowledge.isInitialized)
return nullptr;
if (knowledge.optional == OptionalKnowledge::isNone)
return Torch::NoneType::get(v.getContext());
else if (knowledge.optional == OptionalKnowledge::notNone) {
@ -1456,6 +1472,8 @@ static Type getMostRefinedStaticType(Value v, DataFlowSolver &solver) {
if (!latticeElement)
return nullptr;
const ValueKnowledge &knowledge = latticeElement->getValue();
if (!knowledge.isInitialized)
return nullptr;
if (knowledge.kind == torch_upstream::TypeKind::IntType)
return Torch::IntType::get(v.getContext());
if (knowledge.kind == torch_upstream::TypeKind::FloatType)

View File

@ -46,10 +46,15 @@ static LogicalResult refineDtypeCalculateResult(DtypeCalculateOp op,
impliedTypeFromDtype = *torchType;
} else if (auto originalResultType =
result.getType().dyn_cast<BaseTensorType>()) {
FailureOr<Type> builtinType =
getTypeForScalarType(op->getContext(), dtypeScalarType);
if (failed(builtinType)) {
return rewriter.notifyMatchFailure(
op, "Failed to convert `dtypeScalarType` to a builtin type");
}
impliedTypeFromDtype =
originalResultType.cast<BaseTensorType>().getWithSizesAndDtype(
originalResultType.getOptionalSizes(),
getTypeForScalarType(op->getContext(), dtypeScalarType));
originalResultType.getOptionalSizes(), *builtinType);
} else {
return rewriter.notifyMatchFailure(op,
"Unimplemented: Expected result type to "

View File

@ -10,7 +10,7 @@
#include "PassDetail.h"
#include "SimplifyAbstractInterpCalculationsUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
@ -47,7 +47,7 @@ public:
Block *afterBlock = rewriter.splitBlock(op->getBlock(), op->getIterator());
SmallVector<Block *> blocksToMerge;
BlockAndValueMapping bvm;
IRMapping bvm;
// TODO: Helper for region().front()
auto condition =
cast<PrimLoopConditionOp>(op.getRegion().front().getTerminator());
@ -129,8 +129,7 @@ public:
// Truncate the list of users to the number of users we're going to
// interpret.
allUsers.resize(numUsersToInterpret);
auto usersToInterpret =
makeArrayRef(allUsers).take_front(numUsersToInterpret);
auto usersToInterpret = ArrayRef(allUsers).take_front(numUsersToInterpret);
// For each mutating op (which must be in the same block), we save the
// current state of the list as a vector of Value's. These will then
@ -336,7 +335,7 @@ static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op,
auto originalResultType = result.getType().cast<BaseTensorType>();
auto impliedTypesFromShape =
originalResultType.cast<BaseTensorType>()
.getWithSizesAndDtype(makeArrayRef(sizes),
.getWithSizesAndDtype(ArrayRef(sizes),
originalResultType.getOptionalDtype())
.cast<BaseTensorType>();

View File

@ -8,6 +8,8 @@
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "llvm/Support/ErrorHandling.h"
namespace mlir {
namespace torch {
namespace torch_upstream {
@ -126,6 +128,23 @@ ScalarType result_type(const ResultTypeState &in_state) {
combine_categories(in_state.zeroResult, in_state.wrappedResult));
}
// Translates the textual `reduce` argument into the matching `ReductionType`
// enumerator.
//
// Accepted spellings:
//   "max" / "amax" -> ReductionType::MAX
//   "mean"         -> ReductionType::MEAN
//   "min" / "amin" -> ReductionType::MIN
//   "sum"          -> ReductionType::SUM
//   "prod"         -> ReductionType::PROD
//
// Any other string is a fatal error. `llvm::report_fatal_error` is used
// instead of `llvm_unreachable` because the argument is an arbitrary string,
// so this path is reachable for invalid input; `llvm_unreachable` may be
// lowered to an optimizer hint (undefined behavior) in builds without
// assertions, whereas `report_fatal_error` always emits the diagnostic and
// terminates.
ReductionType get_reduction_enum(const llvm::StringRef &reduce) {
  if (reduce == "max" || reduce == "amax")
    return torch_upstream::ReductionType::MAX;
  if (reduce == "mean")
    return torch_upstream::ReductionType::MEAN;
  if (reduce == "min" || reduce == "amin")
    return torch_upstream::ReductionType::MIN;
  if (reduce == "sum")
    return torch_upstream::ReductionType::SUM;
  if (reduce == "prod")
    return torch_upstream::ReductionType::PROD;

  // No spelling matched: report the problem deterministically in every build
  // configuration. The call does not return.
  llvm::report_fatal_error(
      "'reduce' argument must be either sum, prod, mean, amax or amin");
}
} // namespace torch_upstream
} // namespace torch
} // namespace mlir

View File

@ -83,9 +83,10 @@ Type Torch::getTypeForTorchType(
llvm::report_fatal_error("unhandled type for getTypeForTorchType");
}
Type Torch::getTypeForScalarType(
MLIRContext *context, torch_upstream::ScalarType dtypeInt,
mlir::IntegerType::SignednessSemantics signedness) {
FailureOr<Type>
Torch::getTypeForScalarType(MLIRContext *context,
torch_upstream::ScalarType dtypeInt,
mlir::IntegerType::SignednessSemantics signedness) {
switch (dtypeInt) {
case torch_upstream::ScalarType::Float:
return Float32Type::get(context);
@ -110,6 +111,8 @@ Type Torch::getTypeForScalarType(
return mlir::ComplexType::get(Float64Type::get(context));
case torch_upstream::ScalarType::ComplexDouble:
return mlir::ComplexType::get(Float128Type::get(context));
case torch_upstream::ScalarType::Undefined:
return failure();
default:
llvm::report_fatal_error("unhandled type for getTypeForScalarType");
}
@ -123,6 +126,7 @@ Torch::getTorchTypeForScalarType(MLIRContext *context,
return Torch::FloatType::get(context);
case torch_upstream::ScalarType::Long:
return Torch::IntType::get(context);
case torch_upstream::ScalarType::Undefined:
default:
return failure();
}

View File

@ -32,11 +32,11 @@ namespace {
struct TorchConversionInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
BlockAndValueMapping &valueMapping) const final {
IRMapping &valueMapping) const final {
return true;
}
bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
BlockAndValueMapping &) const final {
IRMapping &) const final {
return true;
}
};

View File

@ -75,8 +75,8 @@ LogicalResult FromBuiltinTensorOp::verify() {
// FromI64Op
//===----------------------------------------------------------------------===//
OpFoldResult FromI64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
auto attr = operands[0].dyn_cast_or_null<mlir::IntegerAttr>();
OpFoldResult FromI64Op::fold(FoldAdaptor adaptor) {
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::IntegerAttr>();
if (attr) {
return attr;
} else {
@ -88,8 +88,8 @@ OpFoldResult FromI64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
// ToI64Op
//===----------------------------------------------------------------------===//
OpFoldResult ToI64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
auto attr = operands[0].dyn_cast_or_null<mlir::IntegerAttr>();
OpFoldResult ToI64Op::fold(FoldAdaptor adaptor) {
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::IntegerAttr>();
if (attr) {
return attr;
} else {
@ -101,8 +101,8 @@ OpFoldResult ToI64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
// ToF64Op
//===----------------------------------------------------------------------===//
OpFoldResult ToF64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
auto attr = operands[0].dyn_cast_or_null<mlir::FloatAttr>();
OpFoldResult ToF64Op::fold(FoldAdaptor adaptor) {
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::FloatAttr>();
if (attr) {
return attr;
} else {
@ -114,8 +114,8 @@ OpFoldResult ToF64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
// FromF64Op
//===----------------------------------------------------------------------===//
OpFoldResult FromF64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
auto attr = operands[0].dyn_cast_or_null<mlir::FloatAttr>();
OpFoldResult FromF64Op::fold(FoldAdaptor adaptor) {
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::FloatAttr>();
if (attr) {
return attr;
} else {

View File

@ -11,7 +11,6 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"

View File

@ -11,7 +11,7 @@ set(LinkedLibs MLIRIR
TorchMLIRTorchConversionToMLProgram
MLIRMemRefTransforms)
if(TORCH_MLIR_ENABLE_MHLO)
if(TORCH_MLIR_ENABLE_STABLEHLO)
list(APPEND LinkedLibs ChloPasses)
endif()
@ -21,7 +21,7 @@ add_mlir_library(TorchMLIRTorchConversionPasses
Passes.cpp
VerifyLinalgOnTensorsBackendContract.cpp
VerifyTosaBackendContract.cpp
VerifyMhloBackendContract.cpp
VerifyStablehloBackendContract.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/TorchConversion/Transforms

View File

@ -21,9 +21,8 @@
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
#ifdef TORCH_MLIR_ENABLE_MHLO
#include "mhlo/transforms/passes.h"
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#endif
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
@ -53,12 +52,13 @@ void mlir::torch::registerTorchConversionPasses() {
"Pipeline lowering torch backend contract to TOSA backend "
"contract.",
TorchConversion::createTorchBackendToTosaBackendPipeline);
#ifdef TORCH_MLIR_ENABLE_MHLO
mlir::PassPipelineRegistration<TorchConversion::MhloBackendPipelineOptions>(
"torch-backend-to-mhlo-backend-pipeline",
"Pipeline lowering torch backend contract to MHLO backend "
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
mlir::PassPipelineRegistration<
TorchConversion::StablehloBackendPipelineOptions>(
"torch-backend-to-stablehlo-backend-pipeline",
"Pipeline lowering torch backend contract to StableHLO backend "
"contract.",
TorchConversion::createTorchBackendToMhloBackendPipeline);
TorchConversion::createTorchBackendToStablehloBackendPipeline);
#endif
}
@ -121,11 +121,12 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
pm.addPass(TorchConversion::createVerifyTosaBackendContractPass());
}
#ifdef TORCH_MLIR_ENABLE_MHLO
void TorchConversion::createTorchBackendToMhloBackendPipeline(
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
void TorchConversion::createTorchBackendToStablehloBackendPipeline(
OpPassManager &pm,
const TorchConversion::MhloBackendPipelineOptions &options) {
pm.addNestedPass<func::FuncOp>(createConvertTorchToMhloPass(
const TorchConversion::StablehloBackendPipelineOptions &options) {
// Generate Stablehlo ops.
pm.addNestedPass<func::FuncOp>(createConvertTorchToStablehloPass(
options.enableStaticShape, options.enableI32Index));
// Clean up any non-canonical code introduced above.
@ -133,21 +134,13 @@ void TorchConversion::createTorchBackendToMhloBackendPipeline(
// The resolution of `dim` ops tends to create identical ops. CSE them.
pm.addNestedPass<func::FuncOp>(createCSEPass());
// Convert CHLO ops to MHLO ops
pm.addNestedPass<func::FuncOp>(mhlo::createChloLegalizeToHloPass());
// Clean up any non-canonical code introduced above..
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// The resolution of `dim` ops tends to create identical ops. CSE them.
pm.addNestedPass<func::FuncOp>(createCSEPass());
// Finish the type conversion from `torch` types to the types of the
// MHLO backend contract.
// StableHLO backend contract.
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
pm.addNestedPass<func::FuncOp>(
TorchConversion::createFinalizingBackendTypeConversionPass());
// Verify that we have lowered to the form that MHLO backends
// expect. This fails compilation (signalPassFailure) if the IR is not in the
// correct form.
pm.addPass(TorchConversion::createVerifyMhloBackendContractPass());
// Verify that we have lowered to Stablehlo and Chlo ops.
pm.addPass(TorchConversion::createVerifyStablehloBackendContractPass());
}
#endif

View File

@ -6,10 +6,9 @@
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifdef TORCH_MLIR_ENABLE_MHLO
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
#include "PassDetail.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
@ -18,6 +17,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
using namespace mlir;
@ -25,17 +25,15 @@ using namespace mlir::torch;
using namespace mlir::torch::TorchConversion;
namespace {
class VerifyMhloBackendContractPass
: public VerifyMhloBackendContractBase<VerifyMhloBackendContractPass> {
class VerifyStablehloBackendContractPass
: public VerifyStablehloBackendContractBase<
VerifyStablehloBackendContractPass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
auto module = getOperation();
TypeConverter converter;
converter.addConversion([](Type type) -> Type {
auto elemTy = type;
if (isa<TensorType>(type)) {
if (isa<TensorType>(type))
elemTy = type.cast<TensorType>().getElementType();
}
if (BaseMemRefType::isValidElementType(elemTy))
return type;
return nullptr;
@ -43,6 +41,7 @@ class VerifyMhloBackendContractPass
auto opHasLegalTypes = [&](Operation *op) { return converter.isLegal(op); };
MLIRContext *context = &getContext();
ConversionTarget target(*context);
// Structural operations.
@ -50,26 +49,16 @@ class VerifyMhloBackendContractPass
// Shape operations.
target.addDynamicallyLegalOp<shape::ShapeOfOp>(opHasLegalTypes);
target.addLegalDialect<mhlo::MhloDialect>();
target.addLegalDialect<chlo::ChloDialect>();
target.addLegalDialect<stablehlo::StablehloDialect>();
target.addLegalDialect<tensor::TensorDialect>();
target.addLegalDialect<arith::ArithDialect>();
RewritePatternSet patterns(context);
if (failed(applyFullConversion(module, target, std::move(patterns)))) {
// We avoid `module.emitError()` so that mlir-print-op-on-diagnostics
// doesn't unnecessarily spew out the entire module.
emitError(module.getLoc())
<< "Module does not conform to the MHLO backend contract. "
"See dialect conversion legality information above.";
return signalPassFailure();
}
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::TorchConversion::createVerifyMhloBackendContractPass() {
return std::make_unique<VerifyMhloBackendContractPass>();
mlir::torch::TorchConversion::createVerifyStablehloBackendContractPass() {
return std::make_unique<VerifyStablehloBackendContractPass>();
}
#endif // TORCH_MLIR_ENABLE_MHLO
#endif // TORCH_MLIR_ENABLE_STABLEHLO

View File

@ -20,6 +20,10 @@
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
#include "torch-mlir/RefBackend/Passes.h"
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
#include "mhlo/transforms/passes.h"
#endif
void mlir::torch::registerAllDialects(mlir::DialectRegistry &registry) {
registry.insert<mlir::func::FuncDialect>();
registry.insert<mlir::torch::Torch::TorchDialect>();
@ -34,4 +38,11 @@ void mlir::torch::registerAllPasses() {
mlir::torch::registerConversionPasses();
mlir::torch::RefBackend::registerRefBackendPasses();
mlir::torch::TMTensor::registerPasses();
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
mlir::mhlo::registerSymbolicShapeOptimizationPass();
mlir::mhlo::registerStablehloLegalizeToHloPass();
mlir::mhlo::registerChloLegalizeToHloPass();
mlir::mhlo::registerHloLegalizeToLinalgPass();
#endif // TORCH_MLIR_ENABLE_STABLEHLO
}

View File

@ -392,7 +392,7 @@ Operation *createLinalgCopyOp(OpBuilder &b, Location loc, Value from,
loc,
/*inputs=*/from,
/*outputs=*/to,
/*indexingMaps=*/llvm::makeArrayRef({id, id}),
/*indexingMaps=*/llvm::ArrayRef({id, id}),
/*iteratorTypes=*/iteratorTypes,
[](OpBuilder &b, Location loc, ValueRange args) {
b.create<linalg::YieldOp>(loc, args.front());

View File

@ -45,14 +45,16 @@ endif()
declare_mlir_python_sources(TorchMLIRPythonSources)
declare_mlir_python_sources(TorchMLIRPythonExtensions)
declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
ADD_TO_PARENT TorchMLIRPythonSources
SOURCES
__init__.py
compiler_utils.py
dynamo.py
)
# Declare the top-level Python sources (__init__.py, compiler_utils.py,
# dynamo.py) as part of TorchMLIRPythonSources, but only when the build is not
# restricted by TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS.
if (NOT TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS)
declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
ADD_TO_PARENT TorchMLIRPythonSources
SOURCES
__init__.py
compiler_utils.py
dynamo.py
)
endif()
declare_mlir_python_sources(TorchMLIRPythonSources.Dialects
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
@ -91,7 +93,9 @@ if(TORCH_MLIR_ENABLE_LTC)
endif()
# Reference backend has a separate check for TORCH_MLIR_ENABLE_LTC, since it
# generates a dummy Python library when disabled.
add_subdirectory(torch_mlir/csrc/reference_lazy_backend)
# The reference lazy backend subdirectory is added only when the build is not
# restricted by TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS.
if(NOT TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS)
add_subdirectory(torch_mlir/csrc/reference_lazy_backend)
endif()
################################################################################
# Optionally handle JIT IR importer.

View File

@ -44,9 +44,9 @@ class OutputType(Enum):
# as taking the `TORCH` output type and lowering it to TOSA.
TOSA = "tosa"
# This output type consists of `mhlo` dialect ops. It can be thought of
# as taking the `TORCH` output type and lowering it to MHLO.
MHLO = "mhlo"
# This output type consists of `stablehlo` dialect ops. It can be thought of
# as taking the `TORCH` output type and lowering it to StableHLO.
STABLEHLO = "stablehlo"
# Raw output of the JIT IR importer. This is not expected to be useful
# for end-users, but can be convenient for development or reporting bugs.
@ -242,7 +242,7 @@ class ExampleArgs:
BACKEND_LEGAL_OPS = {
OutputType.TOSA: ['torch.aten.flatten.using_ints', 'torch.aten.native_layer_norm', 'torch.aten.linear'],
OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints', ],
OutputType.MHLO: [],
OutputType.STABLEHLO: [],
}
@ -290,7 +290,7 @@ def compile(model: torch.nn.Module,
# We only allow `backend_legal_ops` to be specified for the `"torch"`
# output type because the other output types actually invoke their
# respective backends (Linalg, TOSA, or MHLO), and those backends have
# respective backends (Linalg, TOSA, or StableHLO), and those backends have
# very specific requirements about the ops which are legal.
# See `BACKEND_LEGAL_OPS` for more details.
if backend_legal_ops is not None:
@ -404,14 +404,14 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
print(mb.module)
return mb.module
elif output_type == OutputType.MHLO:
elif output_type == OutputType.STABLEHLO:
run_pipeline_with_repro_report(
mb.module,
"builtin.module(torch-backend-to-mhlo-backend-pipeline)",
"Lowering Torch Backend IR -> MHLO Backend IR")
"builtin.module(torch-backend-to-stablehlo-backend-pipeline)",
"Lowering Torch Backend IR -> StableHLO Backend IR")
if verbose:
print("\n====================")
print("MHLO Backend IR")
print("StableHLO Backend IR")
print(mb.module)
return mb.module
raise Exception(f"Unknown OutputType: {output_type}")

View File

@ -44,7 +44,7 @@ def run_pipeline_with_repro_report(module,
# Lower module in place to make it ready for compiler backends.
with module.context:
pm = PassManager.parse(pipeline)
pm.run(module)
pm.run(module.operation)
except Exception as e:
# TODO: More robust.
# - don't arbitrarily clutter up /tmp. When a test suite has many

View File

@ -71,6 +71,7 @@ add_library(torch_mlir_ltc_backend SHARED
mlir_node.cpp
ops/device_data.cpp
ops/generic.cpp
utils/jit_utils.cpp
utils/tensor_utils.cpp
)
target_compile_features(torch_mlir_ltc_backend PRIVATE cxx_std_17)

Some files were not shown because too many files have changed in this diff Show More