Revert "Merge main into dtype-functions-staging (#1935)"

This reverts commit 042d58b699.
revert-1935-merge-main
Ramiro Leal-Cavazos 2023-03-15 11:25:26 -07:00 committed by GitHub
parent 042d58b699
commit ca224bcf17
165 changed files with 1846 additions and 5802 deletions

View File

@ -17,7 +17,7 @@ runs:
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v4 uses: actions/setup-python@v4
with: with:
python-version: '3.11' python-version: '3.10'
- name: Install MLIR Python depends - name: Install MLIR Python depends
run: | run: |
@ -26,8 +26,7 @@ runs:
- name: Install PyTorch nightly depends - name: Install PyTorch nightly depends
run: | run: |
python -m pip install -r pytorch-requirements.txt python -m pip install -r requirements.txt
python -m pip install -r build-requirements.txt
shell: bash shell: bash
- name: Install prerequisites (Linux) - name: Install prerequisites (Linux)

View File

@ -8,19 +8,12 @@ on:
jobs: jobs:
build_linux: build_linux:
name: Manylinux Build name: Manylinux Build
runs-on: a100 runs-on: ubuntu-latest
# Don't run this in everyone's forks. # Don't run this in everyone's forks.
if: github.repository == 'llvm/torch-mlir' if: github.repository == 'llvm/torch-mlir'
steps: steps:
- name: Prepare workspace
run: |
# Clear the workspace directory so that we don't run into errors about
# existing lock files.
sudo rm -rf $GITHUB_WORKSPACE/*
- name: Get torch-mlir - name: Get torch-mlir
uses: actions/checkout@v3 uses: actions/checkout@v3
with: with:
@ -38,7 +31,6 @@ jobs:
cd ${GITHUB_WORKSPACE} cd ${GITHUB_WORKSPACE}
python -m pip install wheel python -m pip install wheel
sudo apt-get install unzip
# Fetch the most recent nightly torchvision release # Fetch the most recent nightly torchvision release
VISION_RELEASE=$(python -m pip index versions -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre torchvision | grep "Available versions" | tr ' ' '\n' | grep "^[0-9]" | sort --version-sort --reverse | head -n1 | tr -d ',' | sed 's/\([^+]*\).*/\1/') VISION_RELEASE=$(python -m pip index versions -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre torchvision | grep "Available versions" | tr ' ' '\n' | grep "^[0-9]" | sort --version-sort --reverse | head -n1 | tr -d ',' | sed 's/\([^+]*\).*/\1/')
@ -52,8 +44,7 @@ jobs:
# Read the version from the downloaded whl file without extracting it # Read the version from the downloaded whl file without extracting it
PT_RELEASE=$(unzip -p torch-*.whl 'torch-*/METADATA' | grep "^Version:" | awk '{ print $2 }' | sed 's/\([^+]*\).*/\1/') PT_RELEASE=$(unzip -p torch-*.whl 'torch-*/METADATA' | grep "^Version:" | awk '{ print $2 }' | sed 's/\([^+]*\).*/\1/')
echo "Found torch release ${PT_RELEASE}" echo "Found torch release ${PT_RELEASE}"
printf -- "-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html\n--pre\ntorch==%s\n" "${PT_RELEASE}" > pytorch-requirements.txt printf -- "-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html\n--pre\ntorch==%s\ntorchvision==%s\n" "${PT_RELEASE}" "${VISION_RELEASE}" > pytorch-requirements.txt
printf -- "-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html\n--pre\ntorchvision==%s\n" "${VISION_RELEASE}" > torchvision-requirements.txt
# Read the commit hash from the downloaded whl file without extracting it # Read the commit hash from the downloaded whl file without extracting it
PT_HASH=$(unzip -p torch-"${PT_RELEASE}"*.whl torch/version.py | grep git_version | awk '{ print $3 }' | tr -d "'") PT_HASH=$(unzip -p torch-"${PT_RELEASE}"*.whl torch/version.py | grep git_version | awk '{ print $3 }' | tr -d "'")
@ -105,7 +96,7 @@ jobs:
git fetch --recurse-submodules=no git fetch --recurse-submodules=no
git checkout main git checkout main
git pull origin main git pull origin main
git add pytorch-hash.txt pytorch-requirements.txt torchvision-requirements.txt lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td git add pytorch-hash.txt pytorch-requirements.txt lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
git diff --cached --exit-code || (git commit -m "update PyTorch version to ${{ env.PT_RELEASE }}" && git push --set-upstream origin main) git diff --cached --exit-code || (git commit -m "update PyTorch version to ${{ env.PT_RELEASE }}" && git push --set-upstream origin main)
- name: Update PyTorch Build Cache (if running on main branch) - name: Update PyTorch Build Cache (if running on main branch)

View File

@ -20,12 +20,6 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Prepare workspace
run: |
# Clear the workspace directory so that we don't run into errors about
# existing lock files.
sudo rm -rf $GITHUB_WORKSPACE/*
- name: Checkout torch-mlir - name: Checkout torch-mlir
uses: actions/checkout@v3 uses: actions/checkout@v3
with: with:

View File

@ -51,14 +51,6 @@ jobs:
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
steps: steps:
- name: Prepare workspace
if: ${{ matrix.os-arch == 'ubuntu-x86_64' }}
run: |
# Clear the workspace directory so that we don't run into errors about
# existing lock files.
sudo rm -rf $GITHUB_WORKSPACE/*
- name: Checkout torch-mlir - name: Checkout torch-mlir
uses: actions/checkout@v3 uses: actions/checkout@v3
with: with:
@ -121,7 +113,7 @@ jobs:
-DLLVM_USE_HOST_TOOLS=ON \ -DLLVM_USE_HOST_TOOLS=ON \
-DLLVM_ENABLE_ZSTD=OFF \ -DLLVM_ENABLE_ZSTD=OFF \
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \
-DTORCH_MLIR_ENABLE_STABLEHLO=OFF \ -DTORCH_MLIR_ENABLE_MHLO=OFF \
-DTORCH_MLIR_ENABLE_LTC=OFF \ -DTORCH_MLIR_ENABLE_LTC=OFF \
-DTORCH_MLIR_USE_INSTALLED_PYTORCH="${{ matrix.torch-binary }}" \ -DTORCH_MLIR_USE_INSTALLED_PYTORCH="${{ matrix.torch-binary }}" \
-DMACOSX_DEPLOYMENT_TARGET=12.0 \ -DMACOSX_DEPLOYMENT_TARGET=12.0 \

View File

@ -13,25 +13,8 @@ on:
jobs: jobs:
build_linux: build_linux:
name: Manylinux Build name: Manylinux Build
runs-on: a100 runs-on: ubuntu-latest
strategy:
matrix:
package: [ torch-mlir, torch-mlir-core ]
py_version: [ cp38-cp38, cp310-cp310, cp311-cp311 ]
exclude:
- package: torch-mlir-core
py_version: cp38-cp38
- package: torch-mlir-core
py_version: cp310-cp310
steps: steps:
- name: Prepare workspace
run: |
# Clear the workspace directory so that we don't run into errors about
# existing lock files.
sudo rm -rf $GITHUB_WORKSPACE/*
- name: Get torch-mlir - name: Get torch-mlir
uses: actions/checkout@v3 uses: actions/checkout@v3
with: with:
@ -45,7 +28,7 @@ jobs:
python -m pip install wheel python -m pip install wheel
TM_PACKAGE_VERSION=${{ github.event.inputs.python_package_version }} TM_PACKAGE_VERSION=${{ github.event.inputs.python_package_version }}
printf "TORCH_MLIR_PYTHON_PACKAGE_VERSION=%s\n" $TM_PACKAGE_VERSION > ./torch_mlir_package_version printf "TORCH_MLIR_PYTHON_PACKAGE_VERSION=%s\n" $TM_PACKAGE_VERSION > ./torch_mlir_package_version
TM_PYTHON_VERSIONS=${{ matrix.py_version }} TM_PACKAGES=${{ matrix.package }} ./build_tools/python_deploy/build_linux_packages.sh ./build_tools/python_deploy/build_linux_packages.sh
# If we were given a release_id, then upload the package we just built # If we were given a release_id, then upload the package we just built
# to the github releases page. # to the github releases page.
@ -73,7 +56,7 @@ jobs:
run: mkdir dist run: mkdir dist
- name: Copy releases to publish to dist directory - name: Copy releases to publish to dist directory
if: github.event.inputs.release_id != '' if: github.event.inputs.release_id != ''
run: cp build_tools/python_deploy/wheelhouse/torch_mlir*.whl dist/ run: cp build_tools/python_deploy/wheelhouse/torch_mlir-*.whl dist/
# Wheels must be published from a linux environment. # Wheels must be published from a linux environment.
# #
@ -87,9 +70,6 @@ jobs:
build_macos: build_macos:
name: MacOS Build name: MacOS Build
runs-on: macos-latest runs-on: macos-latest
strategy:
matrix:
package: [ torch-mlir, torch-mlir-core ]
steps: steps:
- name: Get torch-mlir - name: Get torch-mlir
uses: actions/checkout@v3 uses: actions/checkout@v3
@ -105,7 +85,7 @@ jobs:
TM_PACKAGE_VERSION=${{ github.event.inputs.python_package_version }} TM_PACKAGE_VERSION=${{ github.event.inputs.python_package_version }}
printf "TORCH_MLIR_PYTHON_PACKAGE_VERSION=%s\n" $TM_PACKAGE_VERSION > ./torch_mlir_package_version printf "TORCH_MLIR_PYTHON_PACKAGE_VERSION=%s\n" $TM_PACKAGE_VERSION > ./torch_mlir_package_version
sudo ./build_tools/python_deploy/install_macos_deps.sh sudo ./build_tools/python_deploy/install_macos_deps.sh
packages=${{ matrix.package }} TORCH_MLIR_PYTHON_VERSIONS="3.11" ./build_tools/python_deploy/build_macos_packages.sh TORCH_MLIR_PYTHON_VERSIONS="3.10" ./build_tools/python_deploy/build_macos_packages.sh
# If we were given a release_id, then upload the package we just built # If we were given a release_id, then upload the package we just built
# to the github releases page. # to the github releases page.
@ -133,7 +113,7 @@ jobs:
run: mkdir dist run: mkdir dist
- name: Copy releases to publish to dist directory - name: Copy releases to publish to dist directory
if: github.event.inputs.release_id != '' if: github.event.inputs.release_id != ''
run: cp build_tools/python_deploy/wheelhouse/torch_mlir*.whl dist/ run: cp build_tools/python_deploy/wheelhouse/torch_mlir-*.whl dist/
# Wheels must be published from a linux environment. # Wheels must be published from a linux environment.
# #
@ -147,9 +127,6 @@ jobs:
build_windows: build_windows:
name: Windows Build name: Windows Build
runs-on: windows-latest runs-on: windows-latest
strategy:
matrix:
package: [ torch-mlir, torch-mlir-core ]
steps: steps:
- name: Get torch-mlir - name: Get torch-mlir
uses: actions/checkout@v3 uses: actions/checkout@v3
@ -165,14 +142,6 @@ jobs:
- name: Build Python wheels and smoke test. - name: Build Python wheels and smoke test.
shell: pwsh shell: pwsh
run: | run: |
if ( "${{ matrix.package }}" -eq "torch-mlir-core" )
{
$env:TORCH_MLIR_ENABLE_JIT_IR_IMPORTER='0'
$env:TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS='1'
} else {
$env:TORCH_MLIR_ENABLE_JIT_IR_IMPORTER='1'
$env:TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS='0'
}
$env:TORCH_MLIR_PYTHON_PACKAGE_VERSION = '${{ github.event.inputs.python_package_version }}' $env:TORCH_MLIR_PYTHON_PACKAGE_VERSION = '${{ github.event.inputs.python_package_version }}'
./build_tools/python_deploy/build_windows.ps1 ./build_tools/python_deploy/build_windows.ps1
@ -203,7 +172,7 @@ jobs:
continue-on-error: true continue-on-error: true
- name: Copy releases to publish to dist directory - name: Copy releases to publish to dist directory
if: github.event.inputs.release_id != '' if: github.event.inputs.release_id != ''
run: cp ./wheelhouse/torch_mlir*.whl dist/ run: cp ./wheelhouse/torch_mlir-*.whl dist/
# Wheels must be published from a linux environment. # Wheels must be published from a linux environment.
# #

View File

@ -13,13 +13,8 @@ jobs:
if: github.repository == 'llvm/torch-mlir' if: github.repository == 'llvm/torch-mlir'
steps: steps:
- name: Prepare workspace
run: |
# Clear the workspace directory so that we don't run into errors about
# existing lock files.
sudo rm -rf $GITHUB_WORKSPACE/*
- name: Checking out repository - name: Checking out repository
uses: actions/checkout@v3 uses: actions/checkout@v2
with: with:
token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }}
- name: Run scrape releases script - name: Run scrape releases script

View File

@ -10,14 +10,8 @@ jobs:
# Don't run this in everyone's forks. # Don't run this in everyone's forks.
if: github.repository == 'llvm/torch-mlir' if: github.repository == 'llvm/torch-mlir'
steps: steps:
- name: Prepare workspace
run: |
# Clear the workspace directory so that we don't run into errors about
# existing lock files.
sudo rm -rf $GITHUB_WORKSPACE/*
- name: Checking out repository - name: Checking out repository
uses: actions/checkout@v3 uses: actions/checkout@v2
with: with:
token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }}

View File

@ -13,15 +13,8 @@ jobs:
# Don't run this in everyone's forks. # Don't run this in everyone's forks.
if: github.repository == 'llvm/torch-mlir' if: github.repository == 'llvm/torch-mlir'
steps: steps:
- name: Prepare workspace
run: |
# Clear the workspace directory so that we don't run into errors about
# existing lock files.
sudo rm -rf $GITHUB_WORKSPACE/*
- name: Checking out repository - name: Checking out repository
uses: actions/checkout@v3 uses: actions/checkout@v2
with: with:
token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }}

3
.gitignore vendored
View File

@ -32,6 +32,3 @@ bazel-*
build_oot/ build_oot/
docker_venv/ docker_venv/
llvm-build/ llvm-build/
# C++ build artifacts
compile_commands.json

View File

@ -36,18 +36,12 @@ macro(torch_mlir_add_llvm_external_project name identifier location)
set(LLVM_EXTERNAL_PROJECTS ${LLVM_EXTERNAL_PROJECTS} CACHE STRING "" FORCE) set(LLVM_EXTERNAL_PROJECTS ${LLVM_EXTERNAL_PROJECTS} CACHE STRING "" FORCE)
endmacro() endmacro()
option(TORCH_MLIR_ENABLE_STABLEHLO "Add stablehlo dialect" ON) option(TORCH_MLIR_ENABLE_MHLO "Add mhlo dialect" ON)
if(TORCH_MLIR_ENABLE_STABLEHLO) if(TORCH_MLIR_ENABLE_MHLO)
add_definitions(-DTORCH_MLIR_ENABLE_STABLEHLO) add_definitions(-DTORCH_MLIR_ENABLE_MHLO)
endif() endif()
option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON)
option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF) option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF)
option(TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS "Build Torch dialect MLIR Python bindings but neither JIT IR Importer nor LTC backend" OFF)
if(TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS)
set(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER OFF)
set(TORCH_MLIR_ENABLE_LTC OFF)
endif()
if(TORCH_MLIR_ENABLE_LTC) if(TORCH_MLIR_ENABLE_LTC)
set(ENV{TORCH_MLIR_ENABLE_LTC} 1) set(ENV{TORCH_MLIR_ENABLE_LTC} 1)
@ -115,6 +109,7 @@ if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR OR TORCH_MLIR_OUT_OF_TREE_
# Don't try to compile the python extensions at the moment. We need # Don't try to compile the python extensions at the moment. We need
# to import lots of dependencies from AddMLIRPython to make this work. # to import lots of dependencies from AddMLIRPython to make this work.
set(MLIR_ENABLE_BINDINGS_PYTHON 1) set(MLIR_ENABLE_BINDINGS_PYTHON 1)
option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON)
set(TORCH-MLIR_BUILT_STANDALONE 1) set(TORCH-MLIR_BUILT_STANDALONE 1)
set(BACKEND_PACKAGE_STRING "LLVM ${LLVM_PACKAGE_VERSION}") set(BACKEND_PACKAGE_STRING "LLVM ${LLVM_PACKAGE_VERSION}")
@ -124,6 +119,7 @@ else()
# In-tree build with LLVM_EXTERNAL_PROJECTS=torch-mlir # In-tree build with LLVM_EXTERNAL_PROJECTS=torch-mlir
option(MLIR_ENABLE_BINDINGS_PYTHON "Enables MLIR Python Bindings" OFF) option(MLIR_ENABLE_BINDINGS_PYTHON "Enables MLIR Python Bindings" OFF)
option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON)
# TODO: Fix this upstream so that global include directories are not needed. # TODO: Fix this upstream so that global include directories are not needed.
set(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir) set(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir)
@ -132,8 +128,8 @@ else()
set(MLIR_INCLUDE_DIRS "${MLIR_INCLUDE_DIR};${MLIR_GENERATED_INCLUDE_DIR}") set(MLIR_INCLUDE_DIRS "${MLIR_INCLUDE_DIR};${MLIR_GENERATED_INCLUDE_DIR}")
endif() endif()
if (TORCH_MLIR_ENABLE_STABLEHLO) if (TORCH_MLIR_ENABLE_MHLO)
set(STABLEHLO_BUILD_EMBEDDED ON) set(MHLO_BUILD_EMBEDDED ON)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo
${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo ${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo
EXCLUDE_FROM_ALL) EXCLUDE_FROM_ALL)

View File

@ -8,12 +8,13 @@ necessarily a reflection of the completeness or stability of the code, it
does indicate that the project is not yet endorsed as a component of LLVM. does indicate that the project is not yet endorsed as a component of LLVM.
[PyTorch](https://pytorch.org) [PyTorch](https://pytorch.org)
PyTorch is an open source machine learning framework that facilitates the seamless transition from research and prototyping to production-level deployment. An open source machine learning framework that accelerates the path from research prototyping to production deployment.
[MLIR](https://mlir.llvm.org) [MLIR](https://mlir.llvm.org)
The MLIR project offers a novel approach for building extensible and reusable compiler architectures, which address the issue of software fragmentation, reduce the cost of developing domain-specific compilers, improve compilation for heterogeneous hardware, and promote compatibility between existing compilers. The MLIR project is a novel approach to building reusable and extensible compiler infrastructure. MLIR aims to address software fragmentation, improve compilation for heterogeneous hardware, significantly reduce the cost of building domain specific compilers, and aid in connecting existing compilers together.
[Torch-MLIR](https://github.com/llvm/torch-mlir) [Torch-MLIR](https://github.com/llvm/torch-mlir)
Several vendors have adopted MLIR as the middle layer in their systems, enabling them to map frameworks such as PyTorch, JAX, and TensorFlow into MLIR and subsequently lower them to their target hardware. We have observed half a dozen custom lowerings from PyTorch to MLIR, making it easier for hardware vendors to focus on their unique value, rather than needing to implement yet another PyTorch frontend for MLIR. The ultimate aim is to be similar to the current hardware vendors adding LLVM target support, rather than each one implementing Clang or a C++ frontend. Multiple Vendors use MLIR as the middle layer, mapping from platform frameworks like PyTorch, JAX, and TensorFlow into MLIR and then progressively lowering down to their target hardware. We have seen half a dozen custom lowerings from PyTorch to MLIR. Having canonical lowerings from the PyTorch ecosystem to the MLIR ecosystem provides much needed relief to hardware vendors to focus on their unique value rather than implementing yet another PyTorch frontend for MLIR. The goal is to be similar to current hardware vendors adding LLVM target support instead of each one also implementing Clang / a C++ frontend.
[![Release Build](https://github.com/llvm/torch-mlir/actions/workflows/buildRelease.yml/badge.svg)](https://github.com/llvm/torch-mlir/actions/workflows/buildRelease.yml) [![Release Build](https://github.com/llvm/torch-mlir/actions/workflows/buildRelease.yml/badge.svg)](https://github.com/llvm/torch-mlir/actions/workflows/buildRelease.yml)
@ -42,26 +43,15 @@ We have few paths to lower down to the Torch MLIR Dialect.
## Install torch-mlir snapshot ## Install torch-mlir snapshot
At the time of writing, we release pre-built snapshot of torch-mlir for Python 3.10 on Linux and macOS. This installs a pre-built snapshot of torch-mlir for Python 3.7/3.8/3.9/3.10 on Linux and macOS.
If you have Python 3.10, the following commands initialize a virtual environment.
```shell ```shell
python3.10 -m venv mlir_venv python -m venv mlir_venv
source mlir_venv/bin/activate source mlir_venv/bin/activate
``` # Some older pip installs may not be able to handle the recent PyTorch deps
Or, if you want to switch over multiple versions of Python using conda, you can create a conda environment with Python 3.10.
```shell
conda create -n torch-mlir python=3.10
conda activate torch-mlir
python -m pip install --upgrade pip python -m pip install --upgrade pip
``` pip install --pre torch-mlir torchvision -f https://llvm.github.io/torch-mlir/package-index/ --extra-index-url https://download.pytorch.org/whl/nightly/cpu
# This will install the corresponding torch and torchvision nightlies
Then, we can install torch-mlir with the corresponding torch and torchvision nightlies.
```
pip install --pre torch-mlir torchvision \
-f https://llvm.github.io/torch-mlir/package-index/
--extra-index-url https://download.pytorch.org/whl/nightly/cpu
``` ```
## Demos ## Demos

View File

@ -1,3 +1,5 @@
-r pytorch-requirements.txt
numpy numpy
pybind11 pybind11
wheel wheel

View File

@ -39,16 +39,16 @@ set -eu -o errtrace
this_dir="$(cd "$(dirname "$0")" && pwd)" this_dir="$(cd "$(dirname "$0")" && pwd)"
repo_root="$(cd "$this_dir"/../../ && pwd)" repo_root="$(cd "$this_dir"/../../ && pwd)"
# This needs to be a manylinux image so we can ship pip packages # This needs to be a manylinux image so we can ship pip packages
TM_RELEASE_DOCKER_IMAGE="${TM_RELEASE_DOCKER_IMAGE:-gcr.io/iree-oss/manylinux2014_x86_64-release@sha256:d8994b87b45b7b2e6055fccc32db018ec73aeb05a4e43a9daa61b77cc34f846e}" TM_RELEASE_DOCKER_IMAGE="${TM_RELEASE_DOCKER_IMAGE:-stellaraccident/manylinux2014_x86_64-bazel-5.1.0:latest}"
# This assumes an Ubuntu LTS like image. You can build your own with # This assumes an Ubuntu LTS like image. You can build your own with
# ./build_tools/docker/Dockerfile # ./build_tools/docker/Dockerfile
TM_CI_DOCKER_IMAGE="${TM_CI_DOCKER_IMAGE:-powderluv/torch-mlir-ci:latest}" TM_CI_DOCKER_IMAGE="${TM_CI_DOCKER_IMAGE:-powderluv/torch-mlir-ci:latest}"
# Version of Python to use in Release builds. Ignored in CIs. # Version of Python to use in Release builds. Ignored in CIs.
TM_PYTHON_VERSIONS="${TM_PYTHON_VERSIONS:-cp38-cp38 cp310-cp310 cp311-cp311}" TM_PYTHON_VERSIONS="${TM_PYTHON_VERSIONS:-cp310-cp310}"
# Location to store Release wheels # Location to store Release wheels
TM_OUTPUT_DIR="${TM_OUTPUT_DIR:-${this_dir}/wheelhouse}" TM_OUTPUT_DIR="${TM_OUTPUT_DIR:-${this_dir}/wheelhouse}"
# What "packages to build" # What "packages to build"
TM_PACKAGES="${TM_PACKAGES:-torch-mlir torch-mlir-core}" TM_PACKAGES="${TM_PACKAGES:-torch-mlir}"
# Use pre-built Pytorch # Use pre-built Pytorch
TM_USE_PYTORCH_BINARY="${TM_USE_PYTORCH_BINARY:-ON}" TM_USE_PYTORCH_BINARY="${TM_USE_PYTORCH_BINARY:-ON}"
# Skip running tests if you want quick iteration # Skip running tests if you want quick iteration
@ -84,11 +84,6 @@ function run_on_host() {
export USERID=0 export USERID=0
export GROUPID=0 export GROUPID=0
;; ;;
torch-mlir-core)
TM_CURRENT_DOCKER_IMAGE=${TM_RELEASE_DOCKER_IMAGE}
export USERID=0
export GROUPID=0
;;
out-of-tree) out-of-tree)
TM_CURRENT_DOCKER_IMAGE=${TM_CI_DOCKER_IMAGE} TM_CURRENT_DOCKER_IMAGE=${TM_CI_DOCKER_IMAGE}
# CI uses only Python3.10 # CI uses only Python3.10
@ -164,12 +159,6 @@ function run_in_docker() {
clean_build torch_mlir "$python_version" clean_build torch_mlir "$python_version"
;; ;;
torch-mlir-core)
clean_wheels torch_mlir_core "$python_version"
build_torch_mlir_core
run_audit_wheel torch_mlir_core "$python_version"
clean_build torch_mlir_core "$python_version"
;;
out-of-tree) out-of-tree)
setup_venv "$python_version" setup_venv "$python_version"
build_out_of_tree "$TM_USE_PYTORCH_BINARY" "$python_version" build_out_of_tree "$TM_USE_PYTORCH_BINARY" "$python_version"
@ -278,8 +267,8 @@ function test_in_tree() {
echo ":::: Run Linalg e2e integration tests" echo ":::: Run Linalg e2e integration tests"
python -m e2e_testing.main --config=linalg -v python -m e2e_testing.main --config=linalg -v
echo ":::: Run StableHLO e2e integration tests" echo ":::: Run MHLO e2e integration tests"
python -m e2e_testing.main --config=stablehlo -v python -m e2e_testing.main --config=mhlo -v
echo ":::: Run TOSA e2e integration tests" echo ":::: Run TOSA e2e integration tests"
python -m e2e_testing.main --config=tosa -v python -m e2e_testing.main --config=tosa -v
@ -288,7 +277,7 @@ function test_in_tree() {
python -m e2e_testing.main --config=lazy_tensor_core -v python -m e2e_testing.main --config=lazy_tensor_core -v
echo ":::: Run TorchDynamo e2e integration tests" echo ":::: Run TorchDynamo e2e integration tests"
python -m e2e_testing.main --config=torchdynamo -v --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed RandnDtypeDeviceModule_basic RandnLikeModule_basic Matmul_dot python -m e2e_testing.main --config=torchdynamo -v --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed RandnDtypeDeviceModule_basic
} }
function setup_venv() { function setup_venv() {
@ -384,15 +373,6 @@ function run_audit_wheel() {
rm "$generic_wheel" rm "$generic_wheel"
} }
function build_torch_mlir_core() {
python -m pip install --no-cache-dir -r /main_checkout/torch-mlir/build-requirements.txt
CMAKE_GENERATOR=Ninja \
TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \
TORCH_MLIR_ENABLE_JIT_IR_IMPORTER=0 \
TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS=1 \
python -m pip wheel -v -w /wheelhouse /main_checkout/torch-mlir
}
function clean_wheels() { function clean_wheels() {
local wheel_basename="$1" local wheel_basename="$1"
local python_version="$2" local python_version="$2"

View File

@ -20,7 +20,7 @@ set -eu -o errtrace
this_dir="$(cd "$(dirname "$0")" && pwd)" this_dir="$(cd "$(dirname "$0")" && pwd)"
repo_root="$(cd "$this_dir"/../../ && pwd)" repo_root="$(cd "$this_dir"/../../ && pwd)"
python_versions="${TORCH_MLIR_PYTHON_VERSIONS:-3.9 3.10 3.11}" python_versions="${TORCH_MLIR_PYTHON_VERSIONS:-3.9 3.10}"
output_dir="${output_dir:-${this_dir}/wheelhouse}" output_dir="${output_dir:-${this_dir}/wheelhouse}"
packages="${packages:-torch-mlir}" packages="${packages:-torch-mlir}"
@ -61,11 +61,6 @@ function run() {
build_torch_mlir torch_mlir "$python_version" build_torch_mlir torch_mlir "$python_version"
run_audit_wheel torch_mlir "$python_version" run_audit_wheel torch_mlir "$python_version"
;; ;;
torch-mlir-core)
clean_wheels torch_mlir_core "$python_version"
build_torch_mlir_core torch_mlir_core "$python_version"
run_audit_wheel torch_mlir_core "$python_version"
;;
*) *)
echo "Unrecognized package '$package'" echo "Unrecognized package '$package'"
exit 1 exit 1
@ -82,8 +77,7 @@ function build_torch_mlir() {
python"${python_version}" -m venv "$output_dir"/build_venv python"${python_version}" -m venv "$output_dir"/build_venv
source "$output_dir"/build_venv/bin/activate source "$output_dir"/build_venv/bin/activate
python"${python_version}" -m pip install -U pip python"${python_version}" -m pip install -U pip
python"${python_version}" -m pip install -r "$repo_root"/pytorch-requirements.txt --extra-index-url https://download.pytorch.org/whl/nightly/cpu python"${python_version}" -m pip install -r "$repo_root"/requirements.txt --extra-index-url https://download.pytorch.org/whl/nightly/cpu
python"${python_version}" -m pip install -r "$repo_root"/build-requirements.txt
CMAKE_GENERATOR=Ninja \ CMAKE_GENERATOR=Ninja \
TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \ TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \
MACOSX_DEPLOYMENT_TARGET=$MACOSX_DEPLOYMENT_TARGET \ MACOSX_DEPLOYMENT_TARGET=$MACOSX_DEPLOYMENT_TARGET \
@ -93,25 +87,6 @@ function build_torch_mlir() {
rm -rf "$output_dir"/build_venv rm -rf "$output_dir"/build_venv
} }
function build_torch_mlir_core() {
local wheel_basename="$1"
local python_version="$2"
rm -rf "$output_dir"/build_venv
python"${python_version}" -m venv "$output_dir"/build_venv
source "$output_dir"/build_venv/bin/activate
python"${python_version}" -m pip install -U pip delocate
python"${python_version}" -m pip install -r "$repo_root"/build-requirements.txt
CMAKE_GENERATOR=Ninja \
TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \
MACOSX_DEPLOYMENT_TARGET=$MACOSX_DEPLOYMENT_TARGET \
CMAKE_OSX_ARCHITECTURES=$CMAKE_OSX_ARCHITECTURES \
TORCH_MLIR_ENABLE_JIT_IR_IMPORTER=0 \
TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS=1 \
python"${python_version}" -m pip wheel -v -w "$output_dir" "$repo_root"
deactivate
rm -rf "$output_dir"/build_venv
}
function clean_wheels() { function clean_wheels() {
local wheel_basename="$1" local wheel_basename="$1"
local python_version="$2" local python_version="$2"
@ -132,8 +107,7 @@ function run_audit_wheel() {
python"${python_version}" -m venv "$output_dir"/test_venv python"${python_version}" -m venv "$output_dir"/test_venv
source "$output_dir"/test_venv/bin/activate source "$output_dir"/test_venv/bin/activate
python"${python_version}" -m pip install -U pip python"${python_version}" -m pip install -U pip
python"${python_version}" -m pip install -r "$repo_root"/pytorch-requirements.txt --extra-index-url https://download.pytorch.org/whl/nightly/cpu python"${python_version}" -m pip install -r "$repo_root"/requirements.txt --extra-index-url https://download.pytorch.org/whl/nightly/cpu
python"${python_version}" -m pip install -r "$repo_root"/build-requirements.txt
python"${python_version}" -m pip install "$generic_wheel" --extra-index-url https://download.pytorch.org/whl/nightly/cpu python"${python_version}" -m pip install "$generic_wheel" --extra-index-url https://download.pytorch.org/whl/nightly/cpu
DYLD_LIBRARY_PATH="$output_dir"/test_venv/lib/python"${python_version}"/site-packages/torch/lib delocate-wheel -v "$generic_wheel" DYLD_LIBRARY_PATH="$output_dir"/test_venv/lib/python"${python_version}"/site-packages/torch/lib delocate-wheel -v "$generic_wheel"
deactivate deactivate

View File

@ -13,9 +13,7 @@
Write-Host "Installing Build Dependencies" Write-Host "Installing Build Dependencies"
python -m venv .\mlir_venv\ python -m venv .\mlir_venv\
.\mlir_venv\Scripts\Activate.PS1 .\mlir_venv\Scripts\Activate.PS1
pip install -r .\pytorch-requirements.txt pip install -r .\requirements.txt
pip install -r .\build-requirements.txt
pip install delvewheel
Write-Host "Build Deps installation completed successfully" Write-Host "Build Deps installation completed successfully"
Write-Host "Building torch-mlir" Write-Host "Building torch-mlir"
@ -24,7 +22,3 @@ $env:TORCH_MLIR_ENABLE_LTC='0'
python -m pip wheel -v -w wheelhouse ./ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html -r whl-requirements.txt python -m pip wheel -v -w wheelhouse ./ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html -r whl-requirements.txt
Write-Host "Build completed successfully" Write-Host "Build completed successfully"
Write-Host "Fixing up wheel dependencies"
delvewheel repair --add-path .\build\cmake_build\tools\torch-mlir\python_packages\torch_mlir\torch_mlir\_mlir_libs --add-dll TorchMLIRAggregateCAPI.dll --no-dll 'c10.dll;torch_python.dll;torch_cpu.dll' -v (get-item .\wheelhouse\torch_mlir*.whl).FullName
Write-Host "All Done."

View File

@ -19,13 +19,11 @@ if [[ "$(whoami)" != "root" ]]; then
fi fi
PYTHON_INSTALLER_URLS=( PYTHON_INSTALLER_URLS=(
"https://www.python.org/ftp/python/3.11.2/python-3.11.2-macos11.pkg" "https://www.python.org/ftp/python/3.10.5/python-3.10.5-macos11.pkg"
"https://www.python.org/ftp/python/3.10.10/python-3.10.10-macos11.pkg"
"https://www.python.org/ftp/python/3.9.13/python-3.9.13-macos11.pkg" "https://www.python.org/ftp/python/3.9.13/python-3.9.13-macos11.pkg"
) )
PYTHON_SPECS=( PYTHON_SPECS=(
3.11@https://www.python.org/ftp/python/3.11.2/python-3.11.2-macos11.pkg
3.10@https://www.python.org/ftp/python/3.10.5/python-3.10.5-macos11.pkg 3.10@https://www.python.org/ftp/python/3.10.5/python-3.10.5-macos11.pkg
3.9@https://www.python.org/ftp/python/3.9.13/python-3.9.13-macos11.pkg 3.9@https://www.python.org/ftp/python/3.9.13/python-3.9.13-macos11.pkg
) )

View File

@ -30,14 +30,14 @@ it to various target dialects of interest to the MLIR ecosystem (various
- Linalg-on-Tensors (+ `arith`, `tensor`, etc.) - Linalg-on-Tensors (+ `arith`, `tensor`, etc.)
- [TOSA](https://mlir.llvm.org/docs/Dialects/TOSA/) - [TOSA](https://mlir.llvm.org/docs/Dialects/TOSA/)
- [StableHLO](https://github.com/openxla/stablehlo) - [MHLO](https://github.com/tensorflow/mlir-hlo)
The terms "frontend" and "backend" are highly overloaded in any compiler The terms "frontend" and "backend" are highly overloaded in any compiler
project, but frequently in Torch-MLIR this is the meaning that they have. project, but frequently in Torch-MLIR this is the meaning that they have.
Sometimes "frontend" can mean something even further up the stack, such as Sometimes "frontend" can mean something even further up the stack, such as
something in PyTorch itself. When there is ambiguity we will refer to this as something in PyTorch itself. When there is ambiguity we will refer to this as
"at the PyTorch level". Similarly, "backend" can sometimes refer to something "at the PyTorch level". Similarly, "backend" can sometimes refer to something
sitting below Linalg-on-Tensors, TOSA, or StableHLO. sitting below Linalg-on-Tensors, TOSA, or MHLO.
## The `torch` dialect ## The `torch` dialect
@ -118,8 +118,8 @@ See [satisfiesBackendContract](https://github.com/llvm/torch-mlir/blob/114f48e96
The backend contract is a normalized form of the `torch` dialect with a set of The backend contract is a normalized form of the `torch` dialect with a set of
properties that make it easy to lower into various forms such as properties that make it easy to lower into various forms such as
Linalg-on-Tensors, TOSA, StableHLO, or other forms that we don't provide out of Linalg-on-Tensors, TOSA, MHLO, or other forms that we don't provide out of the
the box. The primary guarantees that we provide Torch-MLIR's backends are: box. The primary guarantees that we provide Torch-MLIR's backends are:
- All tensors have been converted to value semantics. - All tensors have been converted to value semantics.
- All tensors have at least a known number of dimensions (i.e. rank), and - All tensors have at least a known number of dimensions (i.e. rank), and
@ -270,7 +270,7 @@ lower it to the requirements of each backend. The 3 backends are:
- [`linalg`](https://mlir.llvm.org/docs/Dialects/Linalg/) on tensors (+ `arith`, - [`linalg`](https://mlir.llvm.org/docs/Dialects/Linalg/) on tensors (+ `arith`,
`tensor`, etc.) `tensor`, etc.)
- [TOSA](https://mlir.llvm.org/docs/Dialects/TOSA/) - [TOSA](https://mlir.llvm.org/docs/Dialects/TOSA/)
- [StableHLO](https://github.com/openxla/stablehlo) - [MHLO](https://github.com/tensorflow/mlir-hlo)
### The Linalg Backend (Linalg-on-Tensors) ### The Linalg Backend (Linalg-on-Tensors)
@ -297,15 +297,15 @@ many users (especially "hardware" or "hardware-adjacent" folks). Some of its cha
- It is extremely solid with static shapes (and many of its users only care - It is extremely solid with static shapes (and many of its users only care
about static shapes, so that's fine). about static shapes, so that's fine).
### The StableHLO Backend ### The MHLO Backend
Code: https://github.com/llvm/torch-mlir/tree/main/lib/Conversion/TorchToStablehlo Code: https://github.com/llvm/torch-mlir/tree/main/lib/Conversion/TorchToMhlo
The StableHLO backend was the third backend that we added, and it offers a The MHLO backend was the third backend that we added, and it offers a reasonable
reasonable blend of the benefits of the other two. blend of the benefits of the other two.
- It is a coarse-grained named-op approach. - It is a coarse-grained named-op approach.
- It has a pretty clear spec for most of the ops (with a bit of mental - It has a pretty clear spec for most of the ops (with a bit of mental
translation and hoping that StableHLO is the same as HLO): translation and hoping that MHLO is the same as HLO):
https://www.tensorflow.org/xla/operation_semantics https://www.tensorflow.org/xla/operation_semantics
- It functionally supports dynamic shapes (though not as coherent and consistent - It functionally supports dynamic shapes (though not as coherent and consistent
as Linalg-on-Tensors, and the dynamic shape support falls outside the as Linalg-on-Tensors, and the dynamic shape support falls outside the
@ -317,7 +317,7 @@ reasonable blend of the benefits of the other two.
example, TOSA limits (for highly considered reasons) the number of dimensions example, TOSA limits (for highly considered reasons) the number of dimensions
that certain operators can handle to 1D-4D, when from a purely algebraic that certain operators can handle to 1D-4D, when from a purely algebraic
perspective there isn't a good reason to not be more general. Similarly, more perspective there isn't a good reason to not be more general. Similarly, more
general forms of reduction and scatter also fall into StableHLO nicely while general forms of reduction and scatter also fall into MHLO nicely while
TOSA's principles tend to bias it away from that. TOSA's principles tend to bias it away from that.
### Backend Implementation ### Backend Implementation
@ -433,9 +433,8 @@ filling in some corners missing upstream and
to pull together upstream functionality into a working system. to pull together upstream functionality into a working system.
The RefBackend accepts Linalg-on-Tensors as input. It mainly just bufferizes the The RefBackend accepts Linalg-on-Tensors as input. It mainly just bufferizes the
ops and lowers them to loops. Note that TOSA and StableHLO (via MHLO) support ops and lowers them to loops. Note that TOSA and MHLO support lowering to
lowering to Linalg-on-Tensors, so all our end-to-end testing bottoms out on Linalg-on-Tensors, so all our end-to-end testing bottoms out on RefBackend.
RefBackend.
The RefBackend is absolutely not suitable for any production use case. It leaks The RefBackend is absolutely not suitable for any production use case. It leaks
memory, doesn't support any error handling, performs no optimizations, and memory, doesn't support any error handling, performs no optimizations, and

View File

@ -34,7 +34,7 @@ and Clang's
- Eric Kunze (@eric-k256) - Eric Kunze (@eric-k256)
- Suraj Sudhir (@sjarus) - Suraj Sudhir (@sjarus)
### TorchToStablehlo ### TorchToMHLO
- Tianyo Kwok (@tanyokwok) - Tianyo Kwok (@tanyokwok)
- Ziheng Jiang (@ZihengJiang) - Ziheng Jiang (@ZihengJiang)

View File

@ -139,7 +139,7 @@ Ex:
module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch") module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
``` ```
Currently, `output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `STABLEHLO`. Currently, `output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `MHLO`.
## Jupyter ## Jupyter

View File

@ -46,7 +46,7 @@ the ecosystem are:
- The frontend work required to lower TorchScript to the backend contract. - The frontend work required to lower TorchScript to the backend contract.
- The irregular support surface area of the large number of PyTorch ops across - The irregular support surface area of the large number of PyTorch ops across
the Linalg, TOSA, and StableHLO backends. the Linalg, TOSA, and MHLO backends.
Most of this document describes long-term ecosystem changes that will address Most of this document describes long-term ecosystem changes that will address
these, drastically improving Torch-MLIR's ability to meet its goals. these, drastically improving Torch-MLIR's ability to meet its goals.
@ -108,7 +108,7 @@ more advanced).
### Refactoring the backend ### Refactoring the backend
Today in Torch-MLIR, we support 3 backends out of the box: Linalg-on-Tensors, Today in Torch-MLIR, we support 3 backends out of the box: Linalg-on-Tensors,
TOSA, and StableHLO. These backends take IR in the backend contract form (see TOSA, and MHLO. These backends take IR in the backend contract form (see
[architecture.md](architecture.md)) and lowers them to the respective dialects. [architecture.md](architecture.md)) and lowers them to the respective dialects.
Today, each backend is implemented completely independently. This leads to Today, each backend is implemented completely independently. This leads to
duplication and irregularity across the backends. duplication and irregularity across the backends.
@ -120,10 +120,12 @@ lowering of so many ops across backends. Additionally, there are 3
forward-looking efforts that intersect with this effort: forward-looking efforts that intersect with this effort:
- [StableHLO](https://github.com/openxla/stablehlo) - this is a dialect - [StableHLO](https://github.com/openxla/stablehlo) - this is a dialect
initially forked from MHLO. MHLO is a fairly complete op set, so it is very initially forked from MHLO which intends to create a stable support surface
attractive to have "almost all" models bottleneck through a stable interface area for what today is our "at head" dependency on MHLO. MHLO is a fairly
like StableHLO. StableHLO is currently under relatively early development, complete op set, so it is very attractive to have "almost all" models
but already delivers on many of the goals of stability. bottleneck through a stable interface like StableHLO. StableHLO is currently
under relatively early development, but already delivers on many of the goals
of stability.
- [TCP](https://github.com/llvm/torch-mlir/issues/1366) - this is a dialect - [TCP](https://github.com/llvm/torch-mlir/issues/1366) - this is a dialect
which could serve a role very similar to MHLO, while providing community which could serve a role very similar to MHLO, while providing community
ownership. TCP is still in early planning phases, but there is strong ownership. TCP is still in early planning phases, but there is strong

View File

@ -16,7 +16,7 @@ from torch_mlir_e2e_test.registry import GLOBAL_TEST_REGISTRY
from torch_mlir_e2e_test.configs import ( from torch_mlir_e2e_test.configs import (
LazyTensorCoreTestConfig, LazyTensorCoreTestConfig,
LinalgOnTensorsBackendTestConfig, LinalgOnTensorsBackendTestConfig,
StablehloBackendTestConfig, MhloBackendTestConfig,
NativeTorchTestConfig, NativeTorchTestConfig,
TorchScriptTestConfig, TorchScriptTestConfig,
TosaBackendTestConfig, TosaBackendTestConfig,
@ -24,17 +24,17 @@ from torch_mlir_e2e_test.configs import (
) )
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
from torch_mlir_e2e_test.stablehlo_backends.linalg_on_tensors import LinalgOnTensorsStablehloBackend from torch_mlir_e2e_test.mhlo_backends.linalg_on_tensors import LinalgOnTensorsMhloBackend
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend
from .xfail_sets import LINALG_XFAIL_SET, STABLEHLO_PASS_SET, TOSA_PASS_SET, LTC_XFAIL_SET, TORCHDYNAMO_XFAIL_SET from .xfail_sets import LINALG_XFAIL_SET, MHLO_PASS_SET, TOSA_PASS_SET, LTC_XFAIL_SET, TORCHDYNAMO_XFAIL_SET
# Import tests to register them in the global registry. # Import tests to register them in the global registry.
from torch_mlir_e2e_test.test_suite import register_all_tests from torch_mlir_e2e_test.test_suite import register_all_tests
register_all_tests() register_all_tests()
def _get_argparse(): def _get_argparse():
config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "tosa", "lazy_tensor_core", "torchdynamo"] config_choices = ["native_torch", "torchscript", "linalg", "mhlo", "tosa", "lazy_tensor_core", "torchdynamo"]
parser = argparse.ArgumentParser(description="Run torchscript e2e tests.") parser = argparse.ArgumentParser(description="Run torchscript e2e tests.")
parser.add_argument("-c", "--config", parser.add_argument("-c", "--config",
choices=config_choices, choices=config_choices,
@ -42,7 +42,7 @@ def _get_argparse():
help=f""" help=f"""
Meaning of options: Meaning of options:
"linalg": run through torch-mlir"s default Linalg-on-Tensors backend. "linalg": run through torch-mlir"s default Linalg-on-Tensors backend.
"stablehlo": run through torch-mlir"s default StableHLO backend. "mhlo": run through torch-mlir"s default MHLO backend.
"tosa": run through torch-mlir"s default TOSA backend. "tosa": run through torch-mlir"s default TOSA backend.
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration). "native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly). "torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
@ -80,9 +80,9 @@ def main():
if args.config == "tosa": if args.config == "tosa":
config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend()) config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend())
xfail_set = all_test_unique_names - TOSA_PASS_SET xfail_set = all_test_unique_names - TOSA_PASS_SET
if args.config == "stablehlo": if args.config == "mhlo":
config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend()) config = MhloBackendTestConfig(LinalgOnTensorsMhloBackend())
xfail_set = all_test_unique_names - STABLEHLO_PASS_SET xfail_set = all_test_unique_names - MHLO_PASS_SET
elif args.config == "native_torch": elif args.config == "native_torch":
config = NativeTorchTestConfig() config = NativeTorchTestConfig()
xfail_set = {} xfail_set = {}

View File

@ -26,11 +26,6 @@ TORCHDYNAMO_XFAIL_SET = {
# https://github.com/pytorch/pytorch/issues/89629 # https://github.com/pytorch/pytorch/issues/89629
"ConvolutionBackwardModule2DPadded_basic", "ConvolutionBackwardModule2DPadded_basic",
"ConvolutionBackwardModule2D_basic", "ConvolutionBackwardModule2D_basic",
# error: 'tensor.expand_shape' op expected dimension 0 of collapsed type to be dynamic since one or more of the corresponding dimensions in the expanded type is dynamic
# https://github.com/llvm/torch-mlir/issues/1859
"ConvolutionModule2DGroups_basic",
# RuntimeError: Index tensor must have the same number of dimensions as self tensor # RuntimeError: Index tensor must have the same number of dimensions as self tensor
# RuntimeError: Failed running call_function aten.nll_loss_backward(... # RuntimeError: Failed running call_function aten.nll_loss_backward(...
# https://github.com/pytorch/pytorch/issues/89630 # https://github.com/pytorch/pytorch/issues/89630
@ -44,6 +39,10 @@ TORCHDYNAMO_XFAIL_SET = {
# RuntimeError: Failed running call_function aten.uniform(... # RuntimeError: Failed running call_function aten.uniform(...
# https://github.com/pytorch/torchdynamo/issues/1954 # https://github.com/pytorch/torchdynamo/issues/1954
"UniformNoCorrelationModule_basic", "UniformNoCorrelationModule_basic",
# TypeError: expected np.ndarray (got float)
# TODO: This is due to returning a scalar float as output from the test.
# We should probably just standardize all tests to return tensors.
"DivIntModule_basic",
#### Torch-MLIR internal compiler errors #### Torch-MLIR internal compiler errors
@ -67,13 +66,14 @@ TORCHDYNAMO_XFAIL_SET = {
"IndexPutImpl2DFloatNonAccumulateModule_basic", "IndexPutImpl2DFloatNonAccumulateModule_basic",
"IndexPutImpl3DFloatAccumulateModule_basic", "IndexPutImpl3DFloatAccumulateModule_basic",
"IndexPutImpl3DFloatNonAccumulateModule_basic", "IndexPutImpl3DFloatNonAccumulateModule_basic",
# %4 = torch.operator "aten.squeeze_.dim"(%3, %int0) : (!torch.tensor<*,f32>, !torch.int) -> !torch.tensor
"Matmul_vecmat",
# https://github.com/llvm/torch-mlir/issues/1611 # https://github.com/llvm/torch-mlir/issues/1611
# error: 'tensor.cast' op operand type 'tensor<0xi64>' and result type 'tensor<18xi64>' are cast incompatible # error: 'tensor.cast' op operand type 'tensor<0xi64>' and result type 'tensor<18xi64>' are cast incompatible
"Aten_EmbeddingBagExample_basic", "Aten_EmbeddingBagExample_basic",
# error: failed to legalize operation 'torch.valsem.aten.bernoulli.float' that was explicitly marked illegal # error: failed to legalize operation 'torch.valsem.aten.bernoulli.float' that was explicitly marked illegal
"BernoulliFloatModule_basic", "BernoulliFloatModule_basic",
"BernoulliPModule_basic",
# error: failed to legalize operation 'torch.aten.view' that was explicitly marked illegal # error: failed to legalize operation 'torch.aten.view' that was explicitly marked illegal
"ElementwiseFlattenBroadcastModule_basic", "ElementwiseFlattenBroadcastModule_basic",
"FlattenRank0Module_basic", "FlattenRank0Module_basic",
@ -83,16 +83,8 @@ TORCHDYNAMO_XFAIL_SET = {
# error: unsupported by backend contract: tensor with unknown rank # error: unsupported by backend contract: tensor with unknown rank
# note: see current operation: %1 = "torch.tensor_static_info_cast"(%arg0) : (!torch.vtensor<[5,4,3,2,1],f32>) -> !torch.vtensor<*,f32> # note: see current operation: %1 = "torch.tensor_static_info_cast"(%arg0) : (!torch.vtensor<[5,4,3,2,1],f32>) -> !torch.vtensor<*,f32>
"ElementwisePreluModule_basic", "ElementwisePreluModule_basic",
# error: op lowering missing. Issue: https://github.com/llvm/torch-mlir/issues/1792
#ERROR: value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=-1.336e-32, max=+0.9152, mean=+0.4837) is not close to golden value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=+0.02233, max=+0.9152, mean=+0.4777) "StdCorrectionKeepDimModule_basic",
"UpSampleNearest2dDynamicFactor_basic",
"ReduceMaxAlongDimUnsignedInt_basic",
#ERROR: value (-56) is not equal to golden value (200)
"AtenIntTensorByteDtypeModule_basic",
# ERROR: assert isinstance(e, FakeTensor)
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
# ERROR: assert isinstance(e, FakeTensor)
"RsubInt0d_NumToTensor_Module_basic",
# Dtype function transition failures # Dtype function transition failures
"MobilenetV3Module_basic", "MobilenetV3Module_basic",
@ -100,12 +92,8 @@ TORCHDYNAMO_XFAIL_SET = {
"ResNet18StaticModule_basic", "ResNet18StaticModule_basic",
} }
STABLEHLO_PASS_SET = { MHLO_PASS_SET = {
"MaskedFillScalarIntValueStaticModule_basic",
"MaskedFillScalarFloatValueStaticModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AddSizeIntModule_basic",
"AddSizeIntNegDimModule_basic",
"ArangeDtypeFloatModule_basic", "ArangeDtypeFloatModule_basic",
"ArangeDtypeIntModule_basic", "ArangeDtypeIntModule_basic",
"ArangeFalsePinMemoryModule_basic", "ArangeFalsePinMemoryModule_basic",
@ -120,15 +108,10 @@ STABLEHLO_PASS_SET = {
"ArangeStartStepFloatModule_basic", "ArangeStartStepFloatModule_basic",
"ArangeStartStepIntModule_basic", "ArangeStartStepIntModule_basic",
"ArangeZeroElementOutputModule_basic", "ArangeZeroElementOutputModule_basic",
"BatchMlpLayerModule_basic",
"BmmModule_basic", "BmmModule_basic",
"BroadcastToModule_basic", "BroadcastToModule_basic",
"BroadcastToSameRankStaticModule_basic", "BroadcastToSameRankStaticModule_basic",
"BroadcastZeroRankInputStaticModule_basic", "BroadcastZeroRankInputStaticModule_basic",
"BucketizeTensorStaticFloatModule_basic",
"BucketizeTensorStaticModule_basic",
"CumsumStaticModule_basic",
"CumsumStaticNegativeDimModule_basic",
"ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic", "ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic",
"ElementwiseAtenLogicalNotOpModule_basic", "ElementwiseAtenLogicalNotOpModule_basic",
"ElementwiseAtenLogicalNotOpPromoteModule_basic", "ElementwiseAtenLogicalNotOpPromoteModule_basic",
@ -143,29 +126,19 @@ STABLEHLO_PASS_SET = {
"ElementwiseClampModule_basic", "ElementwiseClampModule_basic",
"ElementwiseClampMinModule_basic", "ElementwiseClampMinModule_basic",
"ElementwiseClampMaxModule_basic", "ElementwiseClampMaxModule_basic",
"ElementwisePowModule_basic",
"ElementwiseExpModule_basic", "ElementwiseExpModule_basic",
"ElementwiseFlattenBroadcastModule_basic",
"ElementwiseLeakyReluModule_basic",
"ElementwiseLogModule_basic", "ElementwiseLogModule_basic",
"ElementwiseNegModule_basic", "ElementwiseNegModule_basic",
"ElementwiseRsqrtModule_basic", "ElementwiseRsqrtModule_basic",
"ElementwiseSigmoidModule_basic", "ElementwiseSigmoidModule_basic",
"ElementwiseSqrtModule_basic", "ElementwiseSqrtModule_basic",
"ElementwiseSinModule_basic",
"ElementwiseCosModule_basic",
"ElementwiseCeilModule_basic",
"ElementwiseFloorModule_basic",
"ElementwiseUnaryModule_basic", "ElementwiseUnaryModule_basic",
"ElementwiseUnsqueezeBroadcastModule_basic",
"ElementwiseUnsqueezeNegDimsModule_basic", "ElementwiseUnsqueezeNegDimsModule_basic",
"ElementwiseToDtypeF32ToI64Module_basic", "ElementwiseToDtypeF32ToI64Module_basic",
"ElementwiseAddModule_basic", "ElementwiseAddModule_basic",
"ElementwiseAddScalarFloatModule_basic", "ElementwiseAddScalarFloatModule_basic",
"ElementwiseAddScalarInt64Module_basic", "ElementwiseAddScalarInt64Module_basic",
"ElementwiseAddScalarIntModule_basic", "ElementwiseAddScalarIntModule_basic",
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
"ElementwiseDivScalarModule_basic", "ElementwiseDivScalarModule_basic",
"ElementwiseEqDiffWidthScalarModule_basic", "ElementwiseEqDiffWidthScalarModule_basic",
"ElementwiseEqFloatScalarModule_basic", "ElementwiseEqFloatScalarModule_basic",
@ -228,8 +201,6 @@ STABLEHLO_PASS_SET = {
"Gather2DInputModdule_basic", "Gather2DInputModdule_basic",
"GatherRandomIndexModule_basic", "GatherRandomIndexModule_basic",
"GeluBackwardModule_basic", "GeluBackwardModule_basic",
"HardswishModule_basic",
"HardswishRandomModule_basic",
"HardTanhIntModule_basic", "HardTanhIntModule_basic",
"HardTanhModule_basic", "HardTanhModule_basic",
"HardsigmoidModule_basic", "HardsigmoidModule_basic",
@ -252,8 +223,6 @@ STABLEHLO_PASS_SET = {
"MeanDynamicSizesModule_basic", "MeanDynamicSizesModule_basic",
"MeanLargeInputModule_basic", "MeanLargeInputModule_basic",
"MeanModule_basic", "MeanModule_basic",
"Mlp1LayerModule_basic",
"Mlp2LayerModule_basic",
"MmTanhModule_basic", "MmTanhModule_basic",
"Mv_basic", "Mv_basic",
"NativeLayerNormModule4D_basic", "NativeLayerNormModule4D_basic",
@ -270,7 +239,6 @@ STABLEHLO_PASS_SET = {
"ReduceSumDtypeFloatModule_basic", "ReduceSumDtypeFloatModule_basic",
"ReduceSumDtypeIntModule_basic", "ReduceSumDtypeIntModule_basic",
"SelectIntModule_basic", "SelectIntModule_basic",
"SelectIntNegativeDimAndIndexStaticModule_basic",
"SliceSingleIdxModule_basic", "SliceSingleIdxModule_basic",
"SqueezeDimModule_dynamic", "SqueezeDimModule_dynamic",
"SqueezeDimModule_negDim", "SqueezeDimModule_negDim",
@ -282,15 +250,9 @@ STABLEHLO_PASS_SET = {
"FlattenStaticModule_basic", "FlattenStaticModule_basic",
"FlattenRank0Module_basic", "FlattenRank0Module_basic",
"TensorsConcatNegativeDimModule_basic", "TensorsConcatNegativeDimModule_basic",
"TensorsConcatPromoteDTypeModule_basic",
"TensorsStackModule_basic",
"TensorsStackNegativeDimModule_basic",
"TensorsStackPromoteDTypeModule_basic",
"LiftFreshCopyModule_basic", "LiftFreshCopyModule_basic",
"Mlp2LayerModuleNoBias_basic", "Mlp2LayerModuleNoBias_basic",
"NumelModule_basic", "NumelModule_basic",
"SiluModule_basic",
"SquareModule_basic",
"SqueezeModule_allUnitDim", "SqueezeModule_allUnitDim",
"SqueezeDimModule_unitDim", "SqueezeDimModule_unitDim",
"ViewCollapseOnesMiddleModule_basic", "ViewCollapseOnesMiddleModule_basic",
@ -310,7 +272,6 @@ STABLEHLO_PASS_SET = {
"Convolution2DStaticModule_basic", "Convolution2DStaticModule_basic",
"ConvolutionModule2DTransposeStridedStatic_basic", "ConvolutionModule2DTransposeStridedStatic_basic",
"ElementwiseCloneContiguousModule_basic", "ElementwiseCloneContiguousModule_basic",
"ElementwiseCloneChannelsLastMemoryFormatModule_basic",
"ElementwiseCloneModule_basic", "ElementwiseCloneModule_basic",
"ElementwiseBinaryStaticShapeModule_basic", "ElementwiseBinaryStaticShapeModule_basic",
"ReturnThreeTensorFloat32_basic", "ReturnThreeTensorFloat32_basic",
@ -327,7 +288,6 @@ STABLEHLO_PASS_SET = {
"RsubFloatModule_noalpha_basic", "RsubFloatModule_noalpha_basic",
"RsubIntModule_basic", "RsubIntModule_basic",
"RsubIntModule_noalpha_basic", "RsubIntModule_noalpha_basic",
"RsubInt0d_NumToTensor_Module_basic",
"SliceStaticModule_basic", "SliceStaticModule_basic",
"SliceModule_basic", "SliceModule_basic",
"SliceNegIdxModule_basic", "SliceNegIdxModule_basic",
@ -398,7 +358,6 @@ STABLEHLO_PASS_SET = {
"ViewExpandCollapseModule_basic", "ViewExpandCollapseModule_basic",
"ViewExpandCollapseWithOnesModule_basic", "ViewExpandCollapseWithOnesModule_basic",
"ViewExpandInferredDimModule_basic", "ViewExpandInferredDimModule_basic",
"ViewNegativeStaticModule_basic",
"ViewNoChangeStaticModule_basic", "ViewNoChangeStaticModule_basic",
"ViewNoChange1dModule_basic", "ViewNoChange1dModule_basic",
"ViewNoChange2dModule_basic", "ViewNoChange2dModule_basic",
@ -461,14 +420,12 @@ STABLEHLO_PASS_SET = {
"UnsafeViewDynamicExpandModule_basic", "UnsafeViewDynamicExpandModule_basic",
"AtenRoundIntModule_basic", "AtenRoundIntModule_basic",
"TestF16Return_basic", "TestF16Return_basic",
"_LogSoftmaxModuleStable_basic",
} }
# Write the TOSA set as a "passing" set as it is very early in development # Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet. # and very few tests work yet.
TOSA_PASS_SET = { TOSA_PASS_SET = {
"ElementwiseCloneContiguousModule_basic", "ElementwiseCloneContiguousModule_basic",
"ElementwiseCloneChannelsLastMemoryFormatModule_basic",
"ElementwiseCloneModule_basic", "ElementwiseCloneModule_basic",
"ElementwiseUnaryModule_basic", "ElementwiseUnaryModule_basic",
"ElementwiseBinaryModule_basic", "ElementwiseBinaryModule_basic",
@ -492,7 +449,6 @@ TOSA_PASS_SET = {
"ViewExpandOnesMiddleOppModule_basic", "ViewExpandOnesMiddleOppModule_basic",
"ViewOffsetBackwardTestStaticModule_basic", "ViewOffsetBackwardTestStaticModule_basic",
"TanhBackward_basic", "TanhBackward_basic",
"HardtanhBackward_basic",
"ElementwiseAddModule_basic", "ElementwiseAddModule_basic",
"ReturnThreeTensorFloat32_basic", "ReturnThreeTensorFloat32_basic",
"AddCMulModule_basic", "AddCMulModule_basic",
@ -503,7 +459,6 @@ TOSA_PASS_SET = {
"BoolTensorReturnMixedModule_basic", "BoolTensorReturnMixedModule_basic",
"BoolTensorHandleSignless_basic", "BoolTensorHandleSignless_basic",
"ElementwiseRsqrtModule_basic", "ElementwiseRsqrtModule_basic",
"SelectIntNegativeDimAndIndexStaticModule_basic",
"SqueezeModule_static", "SqueezeModule_static",
"SqueezeModule_noUnitDim", "SqueezeModule_noUnitDim",
"SqueezeModule_allUnitDim", "SqueezeModule_allUnitDim",
@ -525,7 +480,6 @@ TOSA_PASS_SET = {
"Matmul_3d", "Matmul_3d",
"RsubFloatModule_basic", "RsubFloatModule_basic",
"RsubFloatModule_noalpha_basic", "RsubFloatModule_noalpha_basic",
"RsubInt0d_NumToTensor_Module_basic",
"ElementwiseBitwiseAndModule_basic", "ElementwiseBitwiseAndModule_basic",
"ElementwiseBitwiseAndStaticShapeModule_basic", "ElementwiseBitwiseAndStaticShapeModule_basic",
"ElementwiseBitwiseNotInt32Module_basic", "ElementwiseBitwiseNotInt32Module_basic",
@ -555,7 +509,6 @@ TOSA_PASS_SET = {
"ElementwiseDivScalarModule_basic", "ElementwiseDivScalarModule_basic",
"ElementwiseSubScalarFloatModule_basic", "ElementwiseSubScalarFloatModule_basic",
"ElementwiseAddScalarFloatModule_basic", "ElementwiseAddScalarFloatModule_basic",
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
"ElementwiseMulScalarModule_float", "ElementwiseMulScalarModule_float",
"ElementwiseCeilModule_basic", "ElementwiseCeilModule_basic",
"ElementwiseReciprocalModule_basic", "ElementwiseReciprocalModule_basic",
@ -619,7 +572,6 @@ TOSA_PASS_SET = {
"ViewExpandCollapseWithOnesModule_basic", "ViewExpandCollapseWithOnesModule_basic",
"ViewCollapseInferredDimModule_basic", "ViewCollapseInferredDimModule_basic",
"ViewExpandInferredDimModule_basic", "ViewExpandInferredDimModule_basic",
"ViewNegativeStaticModule_basic",
"ViewNoChangeStaticModule_basic", "ViewNoChangeStaticModule_basic",
"UnsafeViewExpandModule_basic", "UnsafeViewExpandModule_basic",
"ReshapeCollapseModule_basic", "ReshapeCollapseModule_basic",
@ -652,7 +604,6 @@ TOSA_PASS_SET = {
"_LogSoftmaxModuleStable_basic", "_LogSoftmaxModuleStable_basic",
"ElementwiseAtenWhereSelfModule_basic", "ElementwiseAtenWhereSelfModule_basic",
"ElementwiseUnsqueezeBroadcastModule_basic", "ElementwiseUnsqueezeBroadcastModule_basic",
"MaskedFillScalarIntValueModule_basic",
"MaskedFillScalarIntValueStaticModule_basic", "MaskedFillScalarIntValueStaticModule_basic",
"MaskedFillTensorIntValueStaticModule_basic", "MaskedFillTensorIntValueStaticModule_basic",
"ElementwiseAddScalarInt64Module_basic", "ElementwiseAddScalarInt64Module_basic",
@ -660,11 +611,8 @@ TOSA_PASS_SET = {
"TensorOpaqueLiteralModule_basic", "TensorOpaqueLiteralModule_basic",
"TypePromotionDifferentCategoryModule_basic", "TypePromotionDifferentCategoryModule_basic",
"TypePromotionSameCategoryDifferentWidthModule_basic", "TypePromotionSameCategoryDifferentWidthModule_basic",
"TypePromotionSameCategoryZeroRankWider_basic",
"TypePromotionZeroRankHigherCategoryModule_basic", "TypePromotionZeroRankHigherCategoryModule_basic",
"GatherStaticModule_basic", "GatherStaticModule_basic",
"IndexTensorStaticModule_basic",
"IndexTensorMultiIndexStaticModule_basic",
"LiftFreshCopyModule_basic", "LiftFreshCopyModule_basic",
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic", "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
"ReduceSumDimIntListFloatModule_basic", "ReduceSumDimIntListFloatModule_basic",
@ -703,10 +651,6 @@ TOSA_PASS_SET = {
"HardsigmoidRandomModule_basic", "HardsigmoidRandomModule_basic",
"HardswishModule_basic", "HardswishModule_basic",
"HardswishRandomModule_basic", "HardswishRandomModule_basic",
"FullLikeModuleInt2DStatic_basic",
"FullModuleInt3D_basic",
"FullModuleFloat2D_basic",
"RepeatModule_basic"
} }
LTC_XFAIL_SET = { LTC_XFAIL_SET = {
@ -722,7 +666,7 @@ LTC_XFAIL_SET = {
"AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic", "AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AddIntModule_basic", "AddIntModule_basic",
"AtenIntBoolOpModule_basic", "BernoulliFloatModule_basic",
"BernoulliTensorModule_basic", "BernoulliTensorModule_basic",
"BincountMinlengthModule_basic", "BincountMinlengthModule_basic",
"BincountModule_basic", "BincountModule_basic",
@ -742,7 +686,6 @@ LTC_XFAIL_SET = {
"GtFloatIntModule_basic", "GtFloatIntModule_basic",
"GtIntModule_basic", "GtIntModule_basic",
"HBC_basic", "HBC_basic",
"HardtanhBackward_basic",
"IndexPut1DFloatAccumulateModule_basic", "IndexPut1DFloatAccumulateModule_basic",
"IndexPut1DFloatNonAccumulateModule_basic", "IndexPut1DFloatNonAccumulateModule_basic",
"IndexPut1DIntAccumulateModule_basic", "IndexPut1DIntAccumulateModule_basic",
@ -777,8 +720,6 @@ LTC_XFAIL_SET = {
"IndexPutImpl3DFloatNonAccumulateModule_basic", "IndexPutImpl3DFloatNonAccumulateModule_basic",
"IndexTensorModule3dInput_basic", "IndexTensorModule3dInput_basic",
"IndexTensorModule_basic", "IndexTensorModule_basic",
"IndexTensorStaticModule_basic",
"IndexTensorMultiIndexStaticModule_basic",
"IndexTensorMultiInputContiguousCenter_basic", "IndexTensorMultiInputContiguousCenter_basic",
"IndexTensorMultiInputNonContiguous_basic", "IndexTensorMultiInputNonContiguous_basic",
"IndexTensorMultiInputOneDim_basic", "IndexTensorMultiInputOneDim_basic",
@ -811,8 +752,6 @@ LTC_XFAIL_SET = {
"SubFloatModule_basic", "SubFloatModule_basic",
"SubIntModule_basic", "SubIntModule_basic",
"TensorsConcatNegativeDimModule_basic", "TensorsConcatNegativeDimModule_basic",
"TensorsConcatPromoteDTypeModule_basic",
"TensorsStackPromoteDTypeModule_basic",
"TensorToBoolZeroRank_basic", "TensorToBoolZeroRank_basic",
"TensorToBool_basic", "TensorToBool_basic",
"TensorToFloatZeroRank_basic", "TensorToFloatZeroRank_basic",
@ -849,34 +788,4 @@ LTC_XFAIL_SET = {
"ElementwisePreluModule_basic", "ElementwisePreluModule_basic",
"VarMeanBiasedModule_basic", "VarMeanBiasedModule_basic",
"VarMeanUnbiasedModule_basic", "VarMeanUnbiasedModule_basic",
"RandnLikeModule_basic",
"RandnLikeDtypeModule_basic",
"NewEmptyStridedModuleDefaultDtype_basic",
"BernoulliFloatModule_basic",
"BernoulliModule_basic",
"BernoulliPModule_basic",
"DropoutTrainModule_basic",
"StdCorrectionKeepDimModule_basic",
"StdCorrectionNoneModule_basic",
"SliceCopy_Module_basic",
"SliceCopyNegative_Module_basic",
"VarBiasedModule_basic",
"VarCorrectionAllDimReduceModule_basic",
"VarCorrectionEmptyDimModule_basic",
"VarCorrectionKeepDimModule_basic",
"VarCorrectionLargeInputModule_basic",
"VarCorrectionModule_basic",
"VarCorrectionNoneModule_basic",
"VarCorrectionSingleDimReduceModule_basic",
"VarDimAllDimReduceModule_basic",
"VarDimBiasedModule_basic",
"VarDimEmptyDimModule_basic",
"VarDimModule_basic",
"VarDimMultiDimModule_basic",
"VarDimNegativeModule_basic",
"VarDimNoneDimModule_basic",
"VarDimSingleDimModule_basic",
"VarDimUnbiasedModule_basic",
"VarUnbiasedModule_basic",
"AtenFloatScalarModule_basic"
} }

View File

@ -91,4 +91,4 @@ resnet18 = models.resnet18(pretrained=True)
resnet18.train(False) resnet18.train(False)
dynamo_callable = dynamo.optimize(refbackend_torchdynamo_backend)(resnet18) dynamo_callable = dynamo.optimize(refbackend_torchdynamo_backend)(resnet18)
predictions(resnet18.forward, lambda x: dynamo_callable(torch.from_numpy(x)).detach().numpy(), img, labels) predictions(resnet18.forward, lambda x: dynamo_callable(torch.from_numpy(x)).numpy(), img, labels)

View File

@ -0,0 +1,14 @@
import torch
import torchvision.models as models
import torch_mlir
model = models.resnet18(pretrained=True)
model.eval()
data = torch.randn(2,3,200,200)
out_mhlo_mlir_path = "./resnet18_mhlo.mlir"
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=False)
with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf:
outf.write(str(module))
print(f"MHLO IR of resent18 successfully written into {out_mhlo_mlir_path}")

View File

@ -15,10 +15,10 @@ class BertTinyWrapper(torch.nn.Module):
model = BertTinyWrapper() model = BertTinyWrapper()
model.eval() model.eval()
data = torch.randint(30522, (2, 128)) data = torch.randint(30522, (2, 128))
out_stablehlo_mlir_path = "./bert_tiny_stablehlo.mlir" out_mhlo_mlir_path = "./bert_tiny_mhlo.mlir"
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=True) module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=True)
with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf: with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf:
outf.write(str(module)) outf.write(str(module))
print(f"StableHLO IR of tiny bert successfully written into {out_stablehlo_mlir_path}") print(f"MHLO IR of tiny bert successfully written into {out_mhlo_mlir_path}")

View File

@ -1,14 +0,0 @@
import torch
import torchvision.models as models
import torch_mlir
model = models.resnet18(pretrained=True)
model.eval()
data = torch.randn(2,3,200,200)
out_stablehlo_mlir_path = "./resnet18_stablehlo.mlir"
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=False)
with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
outf.write(str(module))
print(f"StableHLO IR of resent18 successfully written into {out_stablehlo_mlir_path}")

View File

@ -10,7 +10,7 @@
#ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_ #ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_
#define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_ #define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_
#include "mlir/IR/IRMapping.h" #include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"

View File

@ -457,7 +457,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
(ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes, (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
"ValueRange":$operands), "ValueRange":$operands),
[{ [{
IRMapping bvm; BlockAndValueMapping bvm;
OperationState state( OperationState state(
loc, ConcreteOp::getOperationName(), operands, resultTypes, loc, ConcreteOp::getOperationName(), operands, resultTypes,
$_op->getAttrs()); $_op->getAttrs());

View File

@ -204,7 +204,7 @@ LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc,
} }
auto scfIf = b.create<scf::IfOp>( auto scfIf = b.create<scf::IfOp>(
loc, cond, loc, TypeRange{}, cond,
[&](OpBuilder &b, Location loc) { [&](OpBuilder &b, Location loc) {
if (isInclusive) { if (isInclusive) {
auto value = b.create<memref::LoadOp>(loc, input(), indices); auto value = b.create<memref::LoadOp>(loc, input(), indices);
@ -232,7 +232,7 @@ LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc,
auto &srcBlock = getRegion().front(); auto &srcBlock = getRegion().front();
Region &thisRegion = scfIf.getElseRegion(); Region &thisRegion = scfIf.getElseRegion();
IRMapping bvm; BlockAndValueMapping bvm;
{ {
OpBuilder::InsertionGuard guard(b); OpBuilder::InsertionGuard guard(b);
auto &block = thisRegion.front(); auto &block = thisRegion.front();
@ -266,7 +266,7 @@ static LogicalResult foldMemRefCast(Operation *op) {
return success(folded); return success(folded);
} }
LogicalResult ScanOp::fold(FoldAdaptor adaptor, LogicalResult ScanOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) { SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this); return foldMemRefCast(*this);
} }
@ -461,7 +461,7 @@ LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b,
Value init = b.create<memref::LoadOp>(loc, original(), starts); Value init = b.create<memref::LoadOp>(loc, original(), starts);
IRMapping bvm; BlockAndValueMapping bvm;
Block &block = getRegion().front(); Block &block = getRegion().front();
bvm.map(block.getArgument(0), update); bvm.map(block.getArgument(0), update);
bvm.map(block.getArgument(1), init); bvm.map(block.getArgument(1), init);

@ -1 +1 @@
Subproject commit 21f4b84c456b471cc52016cf360e14d45f7f2960 Subproject commit de3f0f7fa0c7b902dde840913db7e773a02c4173

2
externals/mlir-hlo vendored

@ -1 +1 @@
Subproject commit b1ac0403ee2a40fc648ada6b9f11096f3d50fd19 Subproject commit 2c8823d255a777d3053ef891f4dbeea1c32819f4

View File

@ -1,6 +1,6 @@
set(LLVM_TARGET_DEFINITIONS Passes.td) set(LLVM_TARGET_DEFINITIONS Passes.td)
if(TORCH_MLIR_ENABLE_STABLEHLO) if(TORCH_MLIR_ENABLE_MHLO)
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO) mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_MHLO)
else() else()
mlir_tablegen(Passes.h.inc -gen-pass-decls) mlir_tablegen(Passes.h.inc -gen-pass-decls)
endif() endif()

View File

@ -133,13 +133,13 @@ def ConvertTorchConversionToMLProgram : Pass<"convert-torch-conversion-to-mlprog
let constructor = "mlir::torch::createConvertTorchConversionToMLProgramPass()"; let constructor = "mlir::torch::createConvertTorchConversionToMLProgramPass()";
} }
#ifdef TORCH_MLIR_ENABLE_STABLEHLO #ifdef TORCH_MLIR_ENABLE_MHLO
def ConvertTorchToStablehlo : Pass<"convert-torch-to-stablehlo", "func::FuncOp"> { def ConvertTorchToMhlo : Pass<"convert-torch-to-mhlo", "func::FuncOp"> {
let summary = "Convert Torch ops to Stablehlo ops"; let summary = "Convert Torch ops to MHLO ops";
let description = [{ let description = [{
Convert Torch ops to Stablehlo ops. Convert Torch ops to mhlo ops.
}]; }];
let constructor = "mlir::torch::createConvertTorchToStablehloPass()"; let constructor = "mlir::torch::createConvertTorchToMhloPass()";
// Specify any options. // Specify any options.
let options = [ let options = [

View File

@ -7,8 +7,8 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H #ifndef TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H
#define TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H #define TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
@ -16,11 +16,10 @@
namespace mlir { namespace mlir {
namespace torch { namespace torch {
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToMhloPass();
std::unique_ptr<OperationPass<func::FuncOp>> std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchToStablehloPass(); createConvertTorchToMhloPass(bool enableStaticShape, bool enableI32Index);
std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchToStablehloPass(bool enableStaticShape, bool enableI32Index);
} // namespace torch } // namespace torch
} // namespace mlir } // namespace mlir
#endif // TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H #endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H

View File

@ -2014,55 +2014,6 @@ def Torch_AtenClamp_Op : Torch_Op<"aten.clamp_", [
}]; }];
} }
def Torch_AtenClampTensorOp : Torch_Op<"aten.clamp.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::clamp.Tensor : (Tensor, Tensor?, Tensor?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchOptionalTensorType:$min,
AnyTorchOptionalTensorType:$max
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenClampTensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenClampTensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}
def Torch_AtenClamp_TensorOp : Torch_Op<"aten.clamp_.Tensor", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::clamp_.Tensor : (Tensor, Tensor?, Tensor?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchOptionalTensorType:$min,
AnyTorchOptionalTensorType:$max
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenClamp_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenClamp_TensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}
def Torch_AtenClampMinOp : Torch_Op<"aten.clamp_min", [ def Torch_AtenClampMinOp : Torch_Op<"aten.clamp_min", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,
@ -3389,7 +3340,6 @@ def Torch_AtenRsubScalarOp : Torch_Op<"aten.rsub.Scalar", [
printDefaultTorchOp(printer, *this, 3, 1); printDefaultTorchOp(printer, *this, 3, 1);
} }
}]; }];
let hasCanonicalizer = 1;
} }
def Torch_AtenGeluOp : Torch_Op<"aten.gelu", [ def Torch_AtenGeluOp : Torch_Op<"aten.gelu", [
@ -3688,31 +3638,6 @@ def Torch_AtenBernoulli_FloatOp : Torch_Op<"aten.bernoulli_.float", [
}]; }];
} }
def Torch_AtenBernoulliPOp : Torch_Op<"aten.bernoulli.p", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::bernoulli.p : (Tensor, float, Generator?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_FloatType:$p,
AnyTorchOptionalGeneratorType:$generator
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenBernoulliPOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenBernoulliPOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}
def Torch_AtenRandintLowOp : Torch_Op<"aten.randint.low", [ def Torch_AtenRandintLowOp : Torch_Op<"aten.randint.low", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,
@ -3846,34 +3771,6 @@ def Torch_AtenRandnGeneratorOp : Torch_Op<"aten.randn.generator", [
}]; }];
} }
def Torch_AtenRandnLikeOp : Torch_Op<"aten.randn_like", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::randn_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchOptionalIntType:$dtype,
AnyTorchOptionalIntType:$layout,
AnyTorchOptionalDeviceType:$device,
AnyTorchOptionalBoolType:$pin_memory,
AnyTorchOptionalIntType:$memory_format
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenRandnLikeOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenRandnLikeOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}
def Torch_AtenTriuOp : Torch_Op<"aten.triu", [ def Torch_AtenTriuOp : Torch_Op<"aten.triu", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,
@ -5249,11 +5146,11 @@ def Torch_AtenStdCorrectionOp : Torch_Op<"aten.std.correction", [
HasValueSemantics, HasValueSemantics,
ReadOnly ReadOnly
]> { ]> {
let summary = "Generated op for `aten::std.correction : (Tensor, int[]?, Scalar?, bool) -> (Tensor)`"; let summary = "Generated op for `aten::std.correction : (Tensor, int[]?, int?, bool) -> (Tensor)`";
let arguments = (ins let arguments = (ins
AnyTorchTensorType:$self, AnyTorchTensorType:$self,
AnyTorchOptionalListOfTorchIntType:$dim, AnyTorchOptionalListOfTorchIntType:$dim,
AnyTorchOptionalScalarType:$correction, AnyTorchOptionalIntType:$correction,
Torch_BoolType:$keepdim Torch_BoolType:$keepdim
); );
let results = (outs let results = (outs
@ -5325,11 +5222,11 @@ def Torch_AtenVarCorrectionOp : Torch_Op<"aten.var.correction", [
HasValueSemantics, HasValueSemantics,
ReadOnly ReadOnly
]> { ]> {
let summary = "Generated op for `aten::var.correction : (Tensor, int[]?, Scalar?, bool) -> (Tensor)`"; let summary = "Generated op for `aten::var.correction : (Tensor, int[]?, int?, bool) -> (Tensor)`";
let arguments = (ins let arguments = (ins
AnyTorchTensorType:$self, AnyTorchTensorType:$self,
AnyTorchOptionalListOfTorchIntType:$dim, AnyTorchOptionalListOfTorchIntType:$dim,
AnyTorchOptionalScalarType:$correction, AnyTorchOptionalIntType:$correction,
Torch_BoolType:$keepdim Torch_BoolType:$keepdim
); );
let results = (outs let results = (outs
@ -5351,11 +5248,11 @@ def Torch_AtenVarMeanCorrectionOp : Torch_Op<"aten.var_mean.correction", [
HasValueSemantics, HasValueSemantics,
ReadOnly ReadOnly
]> { ]> {
let summary = "Generated op for `aten::var_mean.correction : (Tensor, int[]?, Scalar?, bool) -> (Tensor, Tensor)`"; let summary = "Generated op for `aten::var_mean.correction : (Tensor, int[]?, int?, bool) -> (Tensor, Tensor)`";
let arguments = (ins let arguments = (ins
AnyTorchTensorType:$self, AnyTorchTensorType:$self,
AnyTorchOptionalListOfTorchIntType:$dim, AnyTorchOptionalListOfTorchIntType:$dim,
AnyTorchOptionalScalarType:$correction, AnyTorchOptionalIntType:$correction,
Torch_BoolType:$keepdim Torch_BoolType:$keepdim
); );
let results = (outs let results = (outs
@ -6585,35 +6482,6 @@ def Torch_AtenNewEmptyOp : Torch_Op<"aten.new_empty", [
}]; }];
} }
def Torch_AtenNewEmptyStridedOp : Torch_Op<"aten.new_empty_strided", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::new_empty_strided : (Tensor, int[], int[], int?, int?, Device?, bool?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$size,
AnyTorchListOfTorchIntType:$stride,
AnyTorchOptionalIntType:$dtype,
AnyTorchOptionalIntType:$layout,
AnyTorchOptionalDeviceType:$device,
AnyTorchOptionalBoolType:$pin_memory
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenNewEmptyStridedOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 7, 1);
}
void AtenNewEmptyStridedOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 7, 1);
}
}];
}
def Torch_AtenZerosLikeOp : Torch_Op<"aten.zeros_like", [ def Torch_AtenZerosLikeOp : Torch_Op<"aten.zeros_like", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,
@ -7107,6 +6975,30 @@ def Torch_AtenSizeIntOp : Torch_Op<"aten.size.int", [
let hasFolder = 1; let hasFolder = 1;
} }
def Torch_AtenStackOp : Torch_Op<"aten.stack", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::stack : (Tensor[], int) -> (Tensor)`";
let arguments = (ins
AnyTorchListOfTensorType:$tensors,
Torch_IntType:$dim
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenStackOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenStackOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenSumOp : Torch_Op<"aten.sum", [ def Torch_AtenSumOp : Torch_Op<"aten.sum", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,
@ -7664,86 +7556,6 @@ def Torch_AtenScatterAddOp : Torch_Op<"aten.scatter_add", [
}]; }];
} }
def Torch_AtenScatterAdd_Op : Torch_Op<"aten.scatter_add_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::scatter_add_ : (Tensor, int, Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
AnyTorchTensorType:$index,
AnyTorchTensorType:$src
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenScatterAdd_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenScatterAdd_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
def Torch_AtenScatterReduceTwoOp : Torch_Op<"aten.scatter_reduce.two", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::scatter_reduce.two : (Tensor, int, Tensor, Tensor, str, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
AnyTorchTensorType:$index,
AnyTorchTensorType:$src,
Torch_StringType:$reduce,
Torch_BoolType:$include_self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenScatterReduceTwoOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenScatterReduceTwoOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}
def Torch_AtenScatterReduce_TwoOp : Torch_Op<"aten.scatter_reduce_.two", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::scatter_reduce_.two : (Tensor, int, Tensor, Tensor, str, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
AnyTorchTensorType:$index,
AnyTorchTensorType:$src,
Torch_StringType:$reduce,
Torch_BoolType:$include_self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenScatterReduce_TwoOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenScatterReduce_TwoOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}
def Torch_AtenIntImplicitOp : Torch_Op<"aten.IntImplicit", [ def Torch_AtenIntImplicitOp : Torch_Op<"aten.IntImplicit", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,
@ -8858,31 +8670,6 @@ def Torch_AtenCatOp : Torch_Op<"aten.cat", [
let hasFolder = 1; let hasFolder = 1;
} }
def Torch_AtenStackOp : Torch_Op<"aten.stack", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::stack : (Tensor[], int) -> (Tensor)`";
let arguments = (ins
AnyTorchListOfTensorType:$tensors,
Torch_IntType:$dim
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenStackOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenStackOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasFolder = 1;
}
def Torch_AtenAppendTOp : Torch_Op<"aten.append.t", [ def Torch_AtenAppendTOp : Torch_Op<"aten.append.t", [
AllowsTypeRefinement AllowsTypeRefinement
]> { ]> {
@ -9298,7 +9085,6 @@ def Torch_AtenIntFloatOp : Torch_Op<"aten.Int.float", [
printDefaultTorchOp(printer, *this, 1, 1); printDefaultTorchOp(printer, *this, 1, 1);
} }
}]; }];
let hasFolder = 1;
} }
def Torch_AtenIntScalarOp : Torch_Op<"aten.Int.Scalar", [ def Torch_AtenIntScalarOp : Torch_Op<"aten.Int.Scalar", [
@ -9325,30 +9111,6 @@ def Torch_AtenIntScalarOp : Torch_Op<"aten.Int.Scalar", [
let hasFolder = 1; let hasFolder = 1;
} }
def Torch_AtenIntBoolOp : Torch_Op<"aten.Int.bool", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::Int.bool : (bool) -> (int)`";
let arguments = (ins
Torch_BoolType:$a
);
let results = (outs
Torch_IntType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenIntBoolOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenIntBoolOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
let hasFolder = 1;
}
def Torch_Aten__RangeLengthOp : Torch_Op<"aten.__range_length", [ def Torch_Aten__RangeLengthOp : Torch_Op<"aten.__range_length", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,
@ -9818,7 +9580,6 @@ def Torch_AtenSubFloatOp : Torch_Op<"aten.sub.float", [
printDefaultTorchOp(printer, *this, 2, 1); printDefaultTorchOp(printer, *this, 2, 1);
} }
}]; }];
let hasFolder = 1;
} }
def Torch_AtenMulFloatOp : Torch_Op<"aten.mul.float", [ def Torch_AtenMulFloatOp : Torch_Op<"aten.mul.float", [
@ -10089,31 +9850,6 @@ def Torch_AtenGtFloatIntOp : Torch_Op<"aten.gt.float_int", [
}]; }];
} }
def Torch_AtenPowIntFloatOp : Torch_Op<"aten.pow.int_float", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::pow.int_float : (int, float) -> (float)`";
let arguments = (ins
Torch_IntType:$a,
Torch_FloatType:$b
);
let results = (outs
Torch_FloatType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenPowIntFloatOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenPowIntFloatOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasFolder = 1;
}
def Torch_Aten__And__BoolOp : Torch_Op<"aten.__and__.bool", [ def Torch_Aten__And__BoolOp : Torch_Op<"aten.__and__.bool", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,
@ -10571,7 +10307,6 @@ def Torch_AtenScalarImplicitOp : Torch_Op<"aten.ScalarImplicit", [
printDefaultTorchOp(printer, *this, 1, 1); printDefaultTorchOp(printer, *this, 1, 1);
} }
}]; }];
let hasCanonicalizer = 1;
} }
def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [ def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [
@ -10624,32 +10359,6 @@ def Torch_AtenTanhBackwardOp : Torch_Op<"aten.tanh_backward", [
}]; }];
} }
def Torch_AtenHardtanhBackwardOp : Torch_Op<"aten.hardtanh_backward", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::hardtanh_backward : (Tensor, Tensor, Scalar, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$grad_output,
AnyTorchTensorType:$self,
AnyTorchScalarType:$min_val,
AnyTorchScalarType:$max_val
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenHardtanhBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenHardtanhBackwardOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
def Torch_AtenGeluBackwardOp : Torch_Op<"aten.gelu_backward", [ def Torch_AtenGeluBackwardOp : Torch_Op<"aten.gelu_backward", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,
@ -11024,7 +10733,6 @@ def Torch_PrimMinIntOp : Torch_Op<"prim.min.int", [
printDefaultTorchOp(printer, *this, 2, 1); printDefaultTorchOp(printer, *this, 2, 1);
} }
}]; }];
let hasFolder = 1;
} }
def Torch_PrimMaxSelfIntOp : Torch_Op<"prim.max.self_int", [ def Torch_PrimMaxSelfIntOp : Torch_Op<"prim.max.self_int", [
@ -11225,11 +10933,11 @@ def Torch_PrimsVarOp : Torch_Op<"prims.var", [
HasValueSemantics, HasValueSemantics,
ReadOnly ReadOnly
]> { ]> {
let summary = "Generated op for `prims::var : (Tensor, int[]?, float, int?) -> (Tensor)`"; let summary = "Generated op for `prims::var : (Tensor, int[]?, int, int?) -> (Tensor)`";
let arguments = (ins let arguments = (ins
AnyTorchTensorType:$inp, AnyTorchTensorType:$inp,
AnyTorchOptionalListOfTorchIntType:$dims, AnyTorchOptionalListOfTorchIntType:$dims,
Torch_FloatType:$correction, Torch_IntType:$correction,
AnyTorchOptionalIntType:$output_dtype AnyTorchOptionalIntType:$output_dtype
); );
let results = (outs let results = (outs

View File

@ -376,6 +376,9 @@ def Torch_PrimListUnpackOp: Torch_Op<"prim.ListUnpack", [
def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [ def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [
Pure, Pure,
TypesMatchWith<"contained types correspond to operand types",
"elements", "result", "Torch::TupleType::get($_ctxt, llvm::to_vector<6>($_self))",
"isValidSubtype">,
AllowedInModuleInitializer, AllowedInModuleInitializer,
]> { ]> {
let summary = "TorchScript prim::TupleConstruct op"; let summary = "TorchScript prim::TupleConstruct op";
@ -394,8 +397,6 @@ def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [
let assemblyFormat = [{ let assemblyFormat = [{
$elements attr-dict `:` qualified(type($elements)) `->` qualified(type($result)) $elements attr-dict `:` qualified(type($elements)) `->` qualified(type($result))
}]; }];
let hasVerifier = 1;
} }
def Torch_PrimListConstructOp: Torch_Op<"prim.ListConstruct", [ def Torch_PrimListConstructOp: Torch_Op<"prim.ListConstruct", [

View File

@ -98,8 +98,6 @@ std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();
std::unique_ptr<OperationPass<func::FuncOp>> std::unique_ptr<OperationPass<func::FuncOp>>
createDecomposeComplexOpsPass(ArrayRef<std::string> legalOps); createDecomposeComplexOpsPass(ArrayRef<std::string> legalOps);
std::unique_ptr<OperationPass<func::FuncOp>> createRecomposeComplexOps();
std::unique_ptr<OperationPass<ModuleOp>> createPreprocessShapeLibraryPass(); std::unique_ptr<OperationPass<ModuleOp>> createPreprocessShapeLibraryPass();
std::unique_ptr<OperationPass<ModuleOp>> createReifyShapeCalculationsPass(); std::unique_ptr<OperationPass<ModuleOp>> createReifyShapeCalculationsPass();
@ -123,7 +121,8 @@ createLowerToBackendContractPass(int maxIterations, bool decompose,
ArrayRef<std::string> backendLegalOps); ArrayRef<std::string> backendLegalOps);
std::unique_ptr<OperationPass<ModuleOp>> std::unique_ptr<OperationPass<ModuleOp>>
createVerifyBackendContractNoDecompositionsPass(); createVerifyBackendContractPass(bool decompose,
ArrayRef<std::string> backendLegalOps);
StringRef getAbstractInterpLibrary(); StringRef getAbstractInterpLibrary();

View File

@ -343,17 +343,24 @@ def LowerToBackendContract
let dependentDialects = ["func::FuncDialect"]; let dependentDialects = ["func::FuncDialect"];
} }
def VerifyBackendContractNoDecompositions def VerifyBackendContract
: Pass<"torch-verify-backend-contract-no-decompositions", "ModuleOp"> { : Pass<"torch-verify-backend-contract", "ModuleOp"> {
let summary = "Check that program satisfies backend contract."; let summary = "Check that program satisfies backend contract.";
let constructor = [{ let constructor = [{
mlir::torch::Torch::createVerifyBackendContractNoDecompositionsPass() mlir::torch::Torch::createVerifyBackendContractPass(
/*decompose=*/true, /*backendLegalOps=*/{})
}]; }];
let description = [{ let description = [{
This pass performs a set of inspections to check that program satisfies backend This pass performs a set of inspections to check that program satisfies backend
contract assuming that no decompositions were applied. In case of check failure contract. In case of check failure it prints out the error message and returns
it prints out the error message and returns `signalPassFailure()` status. `signalPassFailure()` status.
}]; }];
let options = [
Option<"decompose", "decompose", "bool", /*default=*/"true",
"Decompose ops.">,
ListOption<"backendLegalOps", "backend-legal-ops", "std::string",
"List of ops to be considered legal for the backend.">
];
} }
#endif // TORCHMLIR_TORCH_PASSES #endif // TORCHMLIR_TORCH_PASSES

View File

@ -9,7 +9,6 @@
#define TORCHMLIR_DIALECT_TORCH_UPSTREAM_H #define TORCHMLIR_DIALECT_TORCH_UPSTREAM_H
#include "mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringRef.h"
// For layering reasons, the parts of the core MLIR compiler code written in C++ // For layering reasons, the parts of the core MLIR compiler code written in C++
// never take a C++ dependency on Torch itself (any code depending on Torch C++ // never take a C++ dependency on Torch itself (any code depending on Torch C++
@ -161,15 +160,6 @@ enum Layout { Strided, Sparse, SparseCsr, Mkldnn, NumOptions };
//===-----------------------------------------------------------------------===// //===-----------------------------------------------------------------------===//
enum EmbeddingBagMode { MODE_SUM, MODE_MEAN, MODE_MAX }; enum EmbeddingBagMode { MODE_SUM, MODE_MEAN, MODE_MAX };
//===----------------------------------------------------------------------===//
// Possible value for `reduce` argument for Scatter reduce ops.
// Source:
// https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
//===-----------------------------------------------------------------------===//
enum ReductionType { MAX, MEAN, MIN, SUM, PROD };
ReductionType get_reduction_enum(const llvm::StringRef &reduce);
} // namespace torch_upstream } // namespace torch_upstream
} // namespace torch } // namespace torch
} // namespace mlir } // namespace mlir

View File

@ -26,7 +26,7 @@ bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems);
std::optional<int64_t> matchLegalConstantIndexIntoListOfSize(Value v, std::optional<int64_t> matchLegalConstantIndexIntoListOfSize(Value v,
int64_t length); int64_t length);
torch_upstream::ScalarType getScalarTypeForType(Type type); torch_upstream::ScalarType getScalarTypeForType(Type type);
FailureOr<Type> getTypeForScalarType( Type getTypeForScalarType(
MLIRContext *context, torch_upstream::ScalarType dtypeInt, MLIRContext *context, torch_upstream::ScalarType dtypeInt,
mlir::IntegerType::SignednessSemantics signedness = IntegerType::Signed); mlir::IntegerType::SignednessSemantics signedness = IntegerType::Signed);

View File

@ -1,6 +1,6 @@
set(LLVM_TARGET_DEFINITIONS Passes.td) set(LLVM_TARGET_DEFINITIONS Passes.td)
if(TORCH_MLIR_ENABLE_STABLEHLO) if(TORCH_MLIR_ENABLE_MHLO)
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO) mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_MHLO)
else() else()
mlir_tablegen(Passes.h.inc -gen-pass-decls) mlir_tablegen(Passes.h.inc -gen-pass-decls)
endif() endif()

View File

@ -30,10 +30,10 @@ void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm);
/// TOSA backend contract. /// TOSA backend contract.
void createTorchBackendToTosaBackendPipeline(OpPassManager &pm); void createTorchBackendToTosaBackendPipeline(OpPassManager &pm);
// Do not register the stablehlo options if the stablehlo target is disabled // Do not register the torch-to-mhlo pipeline if mhlo target is disabled
#ifdef TORCH_MLIR_ENABLE_STABLEHLO #ifdef TORCH_MLIR_ENABLE_MHLO
struct StablehloBackendPipelineOptions struct MhloBackendPipelineOptions
: public PassPipelineOptions<StablehloBackendPipelineOptions> { : public PassPipelineOptions<MhloBackendPipelineOptions> {
Option<bool> enableStaticShape{ Option<bool> enableStaticShape{
*this, "enable-static-shape", *this, "enable-static-shape",
llvm::cl::desc("Enable static shape conversion."), llvm::cl::init(false)}; llvm::cl::desc("Enable static shape conversion."), llvm::cl::init(false)};
@ -46,10 +46,9 @@ struct StablehloBackendPipelineOptions
llvm::cl::init(false)}; llvm::cl::init(false)};
}; };
void createTorchBackendToStablehloBackendPipeline( void createTorchBackendToMhloBackendPipeline(
OpPassManager &pm, const StablehloBackendPipelineOptions &options); OpPassManager &pm, const MhloBackendPipelineOptions &options);
std::unique_ptr<OperationPass<ModuleOp>> std::unique_ptr<OperationPass<ModuleOp>> createVerifyMhloBackendContractPass();
createVerifyStablehloBackendContractPass();
#endif #endif
std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass(); std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass();

View File

@ -42,10 +42,10 @@ def VerifyTosaBackendContract : Pass<"torch-verify-tosa-backend-contract", "Modu
let constructor = "mlir::torch::TorchConversion::createVerifyTosaBackendContractPass()"; let constructor = "mlir::torch::TorchConversion::createVerifyTosaBackendContractPass()";
} }
#ifdef TORCH_MLIR_ENABLE_STABLEHLO #ifdef TORCH_MLIR_ENABLE_MHLO
def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contract", "ModuleOp"> { def VerifyMhloBackendContract : Pass<"torch-verify-mhlo-backend-contract", "ModuleOp"> {
let summary = "Verifies conformity to the stablehlo backend contract"; let summary = "Verifies conformity to the mhlo backend contract";
let constructor = "mlir::torch::TorchConversion::createVerifyStablehloBackendContractPass()"; let constructor = "mlir::torch::TorchConversion::createVerifyMhloBackendContractPass()";
} }
#endif // TORCH_MLIR_ENABLE_STABLEHLO #endif // TORCH_MLIR_ENABLE_MHLO
#endif // TORCHMLIR_TORCHCONVERSION_PASSES #endif // TORCHMLIR_TORCHCONVERSION_PASSES

View File

@ -61,7 +61,7 @@ MlirType torchMlirTorchTupleTypeGet(MlirContext context,
return wrap(Torch::TupleType::get( return wrap(Torch::TupleType::get(
unwrap(context), unwrap(context),
llvm::to_vector<6>( llvm::to_vector<6>(
llvm::map_range(llvm::ArrayRef(containedTypes, numContainedTypes), llvm::map_range(llvm::makeArrayRef(containedTypes, numContainedTypes),
[](MlirType t) { return unwrap(t); })))); [](MlirType t) { return unwrap(t); }))));
} }
@ -89,7 +89,7 @@ MlirType torchMlirTorchUnionTypeGet(MlirContext context,
return wrap(Torch::UnionType::get( return wrap(Torch::UnionType::get(
unwrap(context), unwrap(context),
llvm::to_vector<6>( llvm::to_vector<6>(
llvm::map_range(llvm::ArrayRef(containedTypes, numContainedTypes), llvm::map_range(llvm::makeArrayRef(containedTypes, numContainedTypes),
[](MlirType t) { return unwrap(t); })))); [](MlirType t) { return unwrap(t); }))));
} }
@ -230,7 +230,7 @@ MlirType torchMlirTorchNonValueTensorTypeGet(MlirContext context,
std::optional<ArrayRef<int64_t>> optionalSizesArrayRef = std::nullopt; std::optional<ArrayRef<int64_t>> optionalSizesArrayRef = std::nullopt;
// if numSizes == -1, then it is unranked. // if numSizes == -1, then it is unranked.
if (numSizes > -1) if (numSizes > -1)
optionalSizesArrayRef = llvm::ArrayRef(optionalSizes, numSizes); optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
return wrap(Torch::NonValueTensorType::get( return wrap(Torch::NonValueTensorType::get(
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype))); unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
} }
@ -293,7 +293,7 @@ MlirType torchMlirTorchValueTensorTypeGet(MlirContext context,
std::optional<ArrayRef<int64_t>> optionalSizesArrayRef = std::nullopt; std::optional<ArrayRef<int64_t>> optionalSizesArrayRef = std::nullopt;
// if numSizes == -1, then it is unranked. // if numSizes == -1, then it is unranked.
if (numSizes > -1) if (numSizes > -1)
optionalSizesArrayRef = llvm::ArrayRef(optionalSizes, numSizes); optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
return wrap(Torch::ValueTensorType::get( return wrap(Torch::ValueTensorType::get(
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype))); unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
} }

View File

@ -3,7 +3,13 @@ add_subdirectory(Conversion)
add_subdirectory(Dialect) add_subdirectory(Dialect)
add_subdirectory(RefBackend) add_subdirectory(RefBackend)
set(LinkedLibs add_mlir_library(TorchMLIRInitAll
InitAll.cpp
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRFuncDialect MLIRFuncDialect
MLIRIR MLIRIR
MLIRSupport MLIRSupport
@ -21,22 +27,4 @@ set(LinkedLibs
TorchMLIRRefBackend TorchMLIRRefBackend
) )
if(TORCH_MLIR_ENABLE_STABLEHLO)
list(APPEND LinkedLibs
MhloPasses
MhloToLinalg
StablehloToMhlo
)
endif()
add_mlir_library(TorchMLIRInitAll
InitAll.cpp
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
${LinkedLibs}
)
torch_mlir_target_includes(TorchMLIRInitAll) torch_mlir_target_includes(TorchMLIRInitAll)

View File

@ -2,8 +2,8 @@ add_subdirectory(TorchToLinalg)
add_subdirectory(TorchToSCF) add_subdirectory(TorchToSCF)
add_subdirectory(TorchToArith) add_subdirectory(TorchToArith)
add_subdirectory(TorchToTosa) add_subdirectory(TorchToTosa)
if(TORCH_MLIR_ENABLE_STABLEHLO) if(TORCH_MLIR_ENABLE_MHLO)
add_subdirectory(TorchToStablehlo) add_subdirectory(TorchToMhlo)
endif() endif()
add_subdirectory(TorchToTMTensor) add_subdirectory(TorchToTMTensor)
add_subdirectory(TorchConversionToMLProgram) add_subdirectory(TorchConversionToMLProgram)
@ -17,8 +17,10 @@ set(linked_libs TorchMLIRTorchToLinalg
TorchMLIRTorchToTMTensor TorchMLIRTorchToTMTensor
TorchMLIRTorchConversionToMLProgram TorchMLIRTorchConversionToMLProgram
TorchMLIRConversionUtils) TorchMLIRConversionUtils)
if(TORCH_MLIR_ENABLE_STABLEHLO) if(TORCH_MLIR_ENABLE_MHLO)
list(APPEND linked_libs TorchMLIRTorchToStablehlo) list(APPEND linked_libs
MhloPasses
TorchMLIRTorchToMhlo)
endif() endif()
add_mlir_library(TorchMLIRConversionPasses add_mlir_library(TorchMLIRConversionPasses

View File

@ -9,15 +9,15 @@
#include "torch-mlir/Conversion/Passes.h" #include "torch-mlir/Conversion/Passes.h"
#ifdef TORCH_MLIR_ENABLE_STABLEHLO #ifdef TORCH_MLIR_ENABLE_MHLO
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "mhlo/transforms/passes.h"
#include "transforms/passes.h" #include "transforms/passes.h"
#endif // TORCH_MLIR_ENABLE_STABLEHLO #endif // TORCH_MLIR_ENABLE_MHLO
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h" #include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" #include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
@ -32,4 +32,12 @@ namespace {
void mlir::torch::registerConversionPasses() { void mlir::torch::registerConversionPasses() {
::registerPasses(); ::registerPasses();
#ifdef TORCH_MLIR_ENABLE_MHLO
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
return mlir::mhlo::createLegalizeHloToLinalgPass();
});
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
return mlir::mhlo::createSymbolicShapeOptimizationPass();
});
#endif // TORCH_MLIR_ENABLE_MHLO
} }

View File

@ -68,7 +68,7 @@ public:
// temp = multiplier * currentSeed + incrementStep // temp = multiplier * currentSeed + incrementStep
Value mul = rewriter.create<arith::MulIOp>(loc, currentSeed, multiplier); Value mul = rewriter.create<arith::MulIOp>(loc, currentSeed, multiplier);
Value seed = rewriter.create<arith::AddIOp>(loc, mul, incrementStep); Value seed = rewriter.create<arith::AddIOp>(loc, mul, incrementStep);
globalVar = rewriter.create<tensor::InsertOp>(loc, seed, globalVar, ValueRange()); globalVar = rewriter.create<tensor::InsertOp>(loc, seed, globalVar);
rewriter.create<ml_program::GlobalStoreOp>( rewriter.create<ml_program::GlobalStoreOp>(
loc, SymbolRefAttr::get(op->getContext(), getSeedGobalVarName()), loc, SymbolRefAttr::get(op->getContext(), getSeedGobalVarName()),
globalVar); globalVar);

View File

@ -232,67 +232,6 @@ public:
return success(); return success();
} }
}; };
class ConvertTorchConstantIntOp
: public OpConversionPattern<Torch::ConstantIntOp> {
public:
using OpConversionPattern<Torch::ConstantIntOp>::OpConversionPattern;
using OpAdaptor = Torch::ConstantIntOp::Adaptor;
LogicalResult
matchAndRewrite(Torch::ConstantIntOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// note: arith.constant only accept singless integer, so convert singed to
// singless
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, rewriter.getIntegerAttr(rewriter.getI64Type(),
op.getValueAttr().getValue()));
return success();
}
};
} // namespace
namespace {
class ConvertAtenFloatScalarOp : public OpConversionPattern<AtenFloatScalarOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenFloatScalarOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type resultType =
this->getTypeConverter()->convertType(op->getResult(0).getType());
Value result =
convertScalarToDtype(rewriter, op.getLoc(), adaptor.getA(), resultType);
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace
namespace {
class ConvertAtenAddOp : public OpConversionPattern<AtenAddOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenAddOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Type resultType =
this->getTypeConverter()->convertType(op->getResult(0).getType());
Value operandA =
convertScalarToDtype(rewriter, loc, adaptor.getA(), resultType);
Value operandB =
convertScalarToDtype(rewriter, loc, adaptor.getB(), resultType);
if (resultType.isa<mlir::FloatType>()) {
rewriter.replaceOpWithNewOp<arith::AddFOp>(op, operandA, operandB);
} else if (resultType.isa<mlir::IntegerType>()) {
rewriter.replaceOpWithNewOp<arith::AddIOp>(op, operandA, operandB);
} else {
return rewriter.notifyMatchFailure(
op, "unimplemented: only support integer or float result type");
}
return success();
}
};
} // namespace } // namespace
namespace { namespace {
@ -442,14 +381,8 @@ public:
patterns.add<ConvertTorchConstantOp<Torch::ConstantFloatOp>>(typeConverter, patterns.add<ConvertTorchConstantOp<Torch::ConstantFloatOp>>(typeConverter,
context); context);
target.addIllegalOp<Torch::ConstantIntOp>(); target.addIllegalOp<Torch::ConstantIntOp>();
patterns.add<ConvertTorchConstantIntOp>(typeConverter, context); patterns.add<ConvertTorchConstantOp<Torch::ConstantIntOp>>(typeConverter,
context);
target.addIllegalOp<AtenFloatScalarOp>();
patterns.add<ConvertAtenFloatScalarOp>(typeConverter, context);
target.addIllegalOp<AtenAddOp>();
patterns.add<ConvertAtenAddOp>(typeConverter, context);
target.addIllegalOp<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp>(); target.addIllegalOp<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp>();
patterns.add<ConvertAtenBinaryOp<AtenAddIntOp, arith::AddIOp>>( patterns.add<ConvertAtenBinaryOp<AtenAddIntOp, arith::AddIOp>>(
typeConverter, context); typeConverter, context);

View File

@ -463,8 +463,8 @@ public:
} }
SmallVector<Value> inputSize = getTensorSizes(rewriter, loc, input); SmallVector<Value> inputSize = getTensorSizes(rewriter, loc, input);
ArrayRef<Value> outputShapeInt = llvm::ArrayRef(outputSizeInt); ArrayRef<Value> outputShapeInt = llvm::makeArrayRef(outputSizeInt);
ArrayRef<Value> inputShapeInt = llvm::ArrayRef(inputSize); ArrayRef<Value> inputShapeInt = llvm::makeArrayRef(inputSize);
// Association indices for expand/collapse ops. These two vectors // Association indices for expand/collapse ops. These two vectors
// are populated such that two entries at the same index corresponds // are populated such that two entries at the same index corresponds
@ -1117,18 +1117,6 @@ public:
RankedTensorType newResultType = RankedTensorType newResultType =
typeConverter->convertType(op.getType()).cast<RankedTensorType>(); typeConverter->convertType(op.getType()).cast<RankedTensorType>();
auto outElemType = newResultType.getElementType();
auto dtypePromoteBody = [&](OpBuilder &builder, Location loc,
ValueRange payloadArgs) {
Value elem = convertScalarToDtype(builder, loc, payloadArgs[0], outElemType);
builder.create<linalg::YieldOp>(loc, elem);
};
for (size_t i = 0; i < tensors.size(); ++i) {
tensors[i] = torch_to_linalg::createElementwiseLinalgGeneric(
rewriter, loc, {tensors[i]}, outElemType, dtypePromoteBody);
}
int rank = newResultType.getRank(); int rank = newResultType.getRank();
SmallVector<Value> offsets, sizes, strides; SmallVector<Value> offsets, sizes, strides;
sizes.reserve(rank); sizes.reserve(rank);
@ -1148,7 +1136,7 @@ public:
Value dimIndex = rewriter.createOrFold<arith::ConstantOp>( Value dimIndex = rewriter.createOrFold<arith::ConstantOp>(
loc, rewriter.getIndexAttr(dim)); loc, rewriter.getIndexAttr(dim));
for (auto tensor : ArrayRef(tensors).drop_front()) { for (auto tensor : makeArrayRef(tensors).drop_front()) {
auto size = rewriter.createOrFold<tensor::DimOp>(loc, tensor, dimIndex); auto size = rewriter.createOrFold<tensor::DimOp>(loc, tensor, dimIndex);
resultDimSize = resultDimSize =
rewriter.createOrFold<arith::AddIOp>(loc, resultDimSize, size); rewriter.createOrFold<arith::AddIOp>(loc, resultDimSize, size);
@ -1282,7 +1270,7 @@ public:
/*resultType=*/selfType, /*resultType=*/selfType,
/*inputs=*/broadcastedSrc, /*inputs=*/broadcastedSrc,
/*outputs=*/self, /*outputs=*/self,
/*indexingMaps=*/llvm::ArrayRef({id, id}), /*indexingMaps=*/llvm::makeArrayRef({id, id}),
/*iteratorTypes=*/iteratorTypes, /*iteratorTypes=*/iteratorTypes,
[](OpBuilder &b, Location loc, ValueRange args) { [](OpBuilder &b, Location loc, ValueRange args) {
Value result = args[0]; Value result = args[0];

View File

@ -81,21 +81,9 @@ public:
Type inElementType = inputType.getElementType(); Type inElementType = inputType.getElementType();
if (!inElementType.isa<mlir::FloatType>()) { if (!inElementType.isa<mlir::FloatType>()) {
if (inElementType.isa<mlir::IntegerType>()) {
auto integerTy = maxDimOp.getSelf()
.getType()
.cast<BaseTensorType>()
.getDtype()
.dyn_cast<mlir::IntegerType>();
if (integerTy.isUnsigned())
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
maxDimOp, "aten.max_dim to linalg.* requires input element type " maxDimOp,
"to be signed in case of integer"); "aten.max_dim to linalg.* requires Float input element type");
} else {
return rewriter.notifyMatchFailure(
maxDimOp, "aten.max_dim to linalg.* requires Float or Integer "
"input element type");
}
} }
// Constant op to account for the reduction along dim. // Constant op to account for the reduction along dim.
@ -116,23 +104,13 @@ public:
Value initTensorMax = rewriter.create<tensor::EmptyOp>( Value initTensorMax = rewriter.create<tensor::EmptyOp>(
loc, getAsOpFoldResult(resultShape), inElementType); loc, getAsOpFoldResult(resultShape), inElementType);
Value fillValueMax; FloatAttr fillValueMaxAttr = rewriter.getFloatAttr(
if (inElementType.isa<mlir::FloatType>()) {
fillValueMax = rewriter.create<arith::ConstantOp>(
loc,
rewriter.getFloatAttr(
inElementType, inElementType,
APFloat::getLargest( APFloat::getLargest(
inElementType.cast<mlir::FloatType>().getFloatSemantics(), inElementType.cast<mlir::FloatType>().getFloatSemantics(), true));
true)));
} else {
fillValueMax = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(
inElementType,
APSInt::getSignedMinValue(
inElementType.cast<mlir::IntegerType>().getWidth())));
}
Value fillValueMax =
rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
Value filledTensorMax = Value filledTensorMax =
rewriter.create<linalg::FillOp>(loc, fillValueMax, initTensorMax) rewriter.create<linalg::FillOp>(loc, fillValueMax, initTensorMax)
.result(); .result();
@ -174,18 +152,10 @@ public:
nestedLoc, oldIndex.getType(), nestedLoc, oldIndex.getType(),
rewriter.create<linalg::IndexOp>(loc, dim)); rewriter.create<linalg::IndexOp>(loc, dim));
Value resultMax, predicate; auto resultMax = rewriter.create<arith::MaxFOp>(
if (inElementType.isa<mlir::FloatType>()) { nestedLoc, newValue, oldValue);
resultMax = Value predicate = rewriter.create<arith::CmpFOp>(
rewriter.create<arith::MaxFOp>(nestedLoc, newValue, oldValue);
predicate = rewriter.create<arith::CmpFOp>(
nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue); nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
} else {
resultMax =
rewriter.create<arith::MaxSIOp>(nestedLoc, newValue, oldValue);
predicate = rewriter.create<arith::CmpIOp>(
nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
}
auto resultIndex = rewriter.create<arith::SelectOp>( auto resultIndex = rewriter.create<arith::SelectOp>(
nestedLoc, predicate, newIndex, oldIndex); nestedLoc, predicate, newIndex, oldIndex);
nestedBuilder.create<linalg::YieldOp>( nestedBuilder.create<linalg::YieldOp>(

View File

@ -127,14 +127,9 @@ public:
if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt))) if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt)))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "unimplemented: dtype must be a constant integer or none"); op, "unimplemented: dtype must be a constant integer or none");
FailureOr<Type> maybeResultElementType = getTypeForScalarType( resultElementType = getTypeForScalarType(
op->getContext(), (torch_upstream::ScalarType)dtypeInt, op->getContext(), (torch_upstream::ScalarType)dtypeInt,
IntegerType::Signless); IntegerType::Signless);
if (failed(maybeResultElementType)) {
return rewriter.notifyMatchFailure(
op, "unable to convert `dtypeInt` to builtin type");
}
resultElementType = *maybeResultElementType;
} }
// Create an uninitialized tensor of `resultSize` shape and fill it with // Create an uninitialized tensor of `resultSize` shape and fill it with
@ -232,14 +227,9 @@ public:
if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt))) if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt)))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "unimplemented: dtype must be a constant integer or none"); op, "unimplemented: dtype must be a constant integer or none");
FailureOr<Type> maybeResultElementType = getTypeForScalarType( resultElementType = getTypeForScalarType(
op->getContext(), (torch_upstream::ScalarType)dtypeInt, op->getContext(), (torch_upstream::ScalarType)dtypeInt,
IntegerType::Signless); IntegerType::Signless);
if (failed(maybeResultElementType)) {
return rewriter.notifyMatchFailure(
op, "unable to convert `dtypeInt` to builtin type");
}
resultElementType = *maybeResultElementType;
} }
// Create an uninitialized tensor of `resultSize` shape. // Create an uninitialized tensor of `resultSize` shape.

View File

@ -59,15 +59,6 @@ static Value createGreaterThan(OpBuilder &b, Location loc, Type elementalType,
b, loc, elementalType, lhs, rhs); b, loc, elementalType, lhs, rhs);
} }
static Value createGreaterThanOrEqual(OpBuilder &b, Location loc,
Type elementalType, Value lhs,
Value rhs) {
return createComparisonTemplate<arith::CmpFPredicate::UGE,
arith::CmpIPredicate::uge,
arith::CmpIPredicate::sge>(
b, loc, elementalType, lhs, rhs);
}
static Value createLessThan(OpBuilder &b, Location loc, Type elementalType, static Value createLessThan(OpBuilder &b, Location loc, Type elementalType,
Value lhs, Value rhs) { Value lhs, Value rhs) {
return createComparisonTemplate<arith::CmpFPredicate::ULT, return createComparisonTemplate<arith::CmpFPredicate::ULT,
@ -76,14 +67,6 @@ static Value createLessThan(OpBuilder &b, Location loc, Type elementalType,
b, loc, elementalType, lhs, rhs); b, loc, elementalType, lhs, rhs);
} }
static Value createLessThanOrEqual(OpBuilder &b, Location loc,
Type elementalType, Value lhs, Value rhs) {
return createComparisonTemplate<arith::CmpFPredicate::ULE,
arith::CmpIPredicate::ule,
arith::CmpIPredicate::sle>(
b, loc, elementalType, lhs, rhs);
}
static Value createEqual(OpBuilder &b, Location loc, Type elementalType, static Value createEqual(OpBuilder &b, Location loc, Type elementalType,
Value lhs, Value rhs) { Value lhs, Value rhs) {
return createComparisonTemplate<arith::CmpFPredicate::UEQ, return createComparisonTemplate<arith::CmpFPredicate::UEQ,
@ -134,46 +117,6 @@ static Value createCalculationForMathOpWithDtypeConversion(
return b.create<MathOpTy>(loc, arg); return b.create<MathOpTy>(loc, arg);
} }
template <typename OpTy>
static Value createCompareTensorOp(OpBuilder &b, Location loc, OpTy op,
Value lhs, Value rhs) {
static_assert(std::is_same<OpTy, AtenLtTensorOp>() ||
std::is_same<OpTy, AtenLeTensorOp>() ||
std::is_same<OpTy, AtenGtTensorOp>() ||
std::is_same<OpTy, AtenGeTensorOp>() ||
std::is_same<OpTy, AtenEqTensorOp>(),
"unimplemented: op type not supported");
Type lhsDtype = lhs.getType();
Type rhsDtype = rhs.getType();
// TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
// to be handled.
if (lhsDtype != rhsDtype) {
op.emitError("unimplemented: lhs and rhs dtype must be same");
return nullptr;
}
Type elementalType =
op.getSelf().getType().template cast<BaseTensorType>().getDtype();
if constexpr (std::is_same<OpTy, AtenLtTensorOp>()) {
return createLessThan(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenLeTensorOp>()) {
return createLessThanOrEqual(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenGtTensorOp>()) {
return createGreaterThan(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenGeTensorOp>()) {
return createGreaterThanOrEqual(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenEqTensorOp>()) {
return createEqual(b, loc, elementalType, lhs, rhs);
}
llvm_unreachable("unimplemented: op type not supported");
}
static Value createLinalgPayloadCalculationForElementwiseOp( static Value createLinalgPayloadCalculationForElementwiseOp(
OpBuilder &b, Location loc, TypeConverter *converter, OpBuilder &b, Location loc, TypeConverter *converter,
ValueRange payloadArgs, Operation *op, ArrayRef<Value> operands) { ValueRange payloadArgs, Operation *op, ArrayRef<Value> operands) {
@ -234,10 +177,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
if (!clone.getMemoryFormat().getType().isa<Torch::NoneType>() && if (!clone.getMemoryFormat().getType().isa<Torch::NoneType>() &&
(!matchPattern(clone.getMemoryFormat(), (!matchPattern(clone.getMemoryFormat(),
m_TorchConstantInt(&memoryFormat)) || m_TorchConstantInt(&memoryFormat)) ||
(memoryFormat != torch_upstream::MemoryFormat::Contiguous && memoryFormat != torch_upstream::MemoryFormat::Contiguous)) {
memoryFormat != torch_upstream::MemoryFormat::ChannelsLast))) { clone.emitError("unimplemented: only default memory format is supported");
clone.emitError("unimplemented: only contiguous and channels last memory "
"format is supported");
return nullptr; return nullptr;
} }
return payloadArgs[0]; return payloadArgs[0];
@ -352,7 +293,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
round.emitError("unimplemented: non-floating point dtype"); round.emitError("unimplemented: non-floating point dtype");
return nullptr; return nullptr;
} }
return b.create<math::RoundEvenOp>(loc, payloadArgs[0]); return b.create<math::RoundOp>(loc, payloadArgs[0]);
} }
if (auto prelu = dyn_cast<AtenPreluOp>(op)) { if (auto prelu = dyn_cast<AtenPreluOp>(op)) {
if (!prelu.getType() if (!prelu.getType()
@ -429,29 +370,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value cdfExt = b.create<arith::AddFOp>(loc, dinputInputAlpha, cdf); Value cdfExt = b.create<arith::AddFOp>(loc, dinputInputAlpha, cdf);
return b.create<arith::MulFOp>(loc, payloadArgs[0], cdfExt); return b.create<arith::MulFOp>(loc, payloadArgs[0], cdfExt);
} }
if (auto hardtanhBackward = dyn_cast<AtenHardtanhBackwardOp>(op)) {
AtenHardtanhBackwardOp::Adaptor adaptor(operands);
if (!hardtanhBackward.getType()
.cast<ValueTensorType>()
.getDtype()
.isa<mlir::FloatType>()) {
hardtanhBackward.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
Value gradOutput = payloadArgs[0];
Type elementType = gradOutput.getType();
Value self = convertScalarToDtype(b, loc, payloadArgs[1], elementType);
Value constantZero =
b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.0));
Value min = convertScalarToDtype(b, loc, adaptor.getMinVal(), elementType);
Value max = convertScalarToDtype(b, loc, adaptor.getMaxVal(), elementType);
Value lesser =
b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT, self, min);
Value greater =
b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, self, max);
Value cmp = b.create<arith::OrIOp>(loc, lesser, greater);
return b.create<arith::SelectOp>(loc, cmp, constantZero, gradOutput);
}
if (auto add = dyn_cast<AtenAddTensorOp>(op)) { if (auto add = dyn_cast<AtenAddTensorOp>(op)) {
AtenAddTensorOp::Adaptor adaptor(operands); AtenAddTensorOp::Adaptor adaptor(operands);
Type dtype = converter->convertType(add.getType()) Type dtype = converter->convertType(add.getType())
@ -545,24 +463,63 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
return b.create<math::Atan2Op>(loc, lhs, rhs); return b.create<math::Atan2Op>(loc, lhs, rhs);
} }
if (auto ltTensor = dyn_cast<AtenLtTensorOp>(op)) {
return createCompareTensorOp(b, loc, ltTensor, payloadArgs[0],
payloadArgs[1]);
}
if (auto leTensor = dyn_cast<AtenLeTensorOp>(op)) {
return createCompareTensorOp(b, loc, leTensor, payloadArgs[0],
payloadArgs[1]);
}
if (auto gtTensor = dyn_cast<AtenGtTensorOp>(op)) { if (auto gtTensor = dyn_cast<AtenGtTensorOp>(op)) {
return createCompareTensorOp(b, loc, gtTensor, payloadArgs[0], AtenGtTensorOp::Adaptor adaptor(operands);
payloadArgs[1]); Type lhsDtype = payloadArgs[0].getType();
Type rhsDtype = payloadArgs[1].getType();
// TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
// to be handled.
if (lhsDtype != rhsDtype) {
gtTensor.emitError("unimplemented: different lhs and rhs dtype");
return nullptr;
} }
if (auto geTensor = dyn_cast<AtenGeTensorOp>(op)) {
return createCompareTensorOp(b, loc, geTensor, payloadArgs[0], Type elementalType =
gtTensor.getSelf().getType().cast<BaseTensorType>().getDtype();
return createGreaterThan(b, loc, elementalType, payloadArgs[0],
payloadArgs[1]); payloadArgs[1]);
} }
if (auto eqTensor = dyn_cast<AtenEqTensorOp>(op)) { if (auto eqTensor = dyn_cast<AtenEqTensorOp>(op)) {
return createCompareTensorOp(b, loc, eqTensor, payloadArgs[0], AtenEqTensorOp::Adaptor adaptor(operands);
Type lhsDtype = payloadArgs[0].getType();
Type rhsDtype = payloadArgs[1].getType();
// TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
// to be handled.
if (lhsDtype != rhsDtype) {
eqTensor.emitError("unimplemented: lhs and rhs dtype must be same");
return nullptr;
}
Type elementalType =
eqTensor.getSelf().getType().cast<BaseTensorType>().getDtype();
if (elementalType.isa<mlir::FloatType>())
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UEQ,
payloadArgs[0], payloadArgs[1]);
if (elementalType.isa<mlir::IntegerType>()) {
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
payloadArgs[0], payloadArgs[1]);
}
eqTensor.emitError("unimplemented: dtype isn't supported.");
return nullptr;
}
if (auto ltTensor = dyn_cast<AtenLtTensorOp>(op)) {
AtenLtTensorOp::Adaptor adaptor(operands);
Type lhsDtype = payloadArgs[0].getType();
Type rhsDtype = payloadArgs[1].getType();
// TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
// to be handled.
if (lhsDtype != rhsDtype) {
ltTensor.emitError("unimplemented: lhs and rhs dtype must be same");
return nullptr;
}
Type elementalType =
ltTensor.getSelf().getType().cast<BaseTensorType>().getDtype();
return createLessThan(b, loc, elementalType, payloadArgs[0],
payloadArgs[1]); payloadArgs[1]);
} }
if (auto div = dyn_cast<AtenDivTensorOp>(op)) { if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
@ -1007,6 +964,18 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
.getElementType(); .getElementType();
return convertScalarToDtype(b, loc, adaptor.getValue(), dtype); return convertScalarToDtype(b, loc, adaptor.getValue(), dtype);
} }
if (auto maskedFillScalar = dyn_cast<AtenMaskedFillScalarOp>(op)) {
AtenMaskedFillScalarOp::Adaptor adaptor(operands);
Type dtype = converter->convertType(maskedFillScalar.getType())
.cast<RankedTensorType>()
.getElementType();
Value input = payloadArgs[0];
Value mask = payloadArgs[1];
Value fillValue = convertScalarToDtype(b, loc, adaptor.getValue(), dtype);
return b.create<arith::SelectOp>(loc, mask, fillValue, input);
}
if (auto maskedFillTensor = dyn_cast<AtenMaskedFillTensorOp>(op)) { if (auto maskedFillTensor = dyn_cast<AtenMaskedFillTensorOp>(op)) {
AtenMaskedFillScalarOp::Adaptor adaptor(operands); AtenMaskedFillScalarOp::Adaptor adaptor(operands);
Type dtype = converter->convertType(maskedFillTensor.getType()) Type dtype = converter->convertType(maskedFillTensor.getType())
@ -1065,7 +1034,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value allOnesVal = b.create<arith::ConstantOp>( Value allOnesVal = b.create<arith::ConstantOp>(
loc, b.getIntegerAttr( loc, b.getIntegerAttr(
elementType, elementType,
APSInt::getAllOnes(elementType.getIntOrFloatBitWidth()))); APSInt::getAllOnesValue(elementType.getIntOrFloatBitWidth())));
return b.create<arith::XOrIOp>(loc, payloadArgs[0], allOnesVal); return b.create<arith::XOrIOp>(loc, payloadArgs[0], allOnesVal);
} }
@ -1113,10 +1082,10 @@ public:
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp,
AtenBitwiseXorTensorOp, AtenGtScalarOp, AtenGeScalarOp, AtenBitwiseXorTensorOp, AtenGtScalarOp, AtenGeScalarOp,
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp,
AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenBitwiseNotOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenBitwiseNotOp,
AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>(op)) AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>(op))
@ -1592,12 +1561,12 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, AtenGtScalarOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, AtenGtScalarOp,
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp,
AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp, AtenMaskedFillTensorOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp,
AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenRemainderScalarOp, AtenTriuOp, AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp,
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>(); AtenFillScalarOp, AtenFillTensorOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context); patterns.add<ConvertElementwiseOp>(typeConverter, context);
target.addIllegalOp<AtenNllLossForwardOp>(); target.addIllegalOp<AtenNllLossForwardOp>();
patterns.add<ConvertAtenDetachOp>(typeConverter, context); patterns.add<ConvertAtenDetachOp>(typeConverter, context);

View File

@ -7,16 +7,15 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "../PassDetail.h" #include "../PassDetail.h"
#include "PopulatePatterns.h" #include "./MhloLegalizeUtils.h"
#include "StablehloLegalizeUtils.h" #include "./PopulatePatterns.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@ -30,7 +29,7 @@
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
using namespace mlir::torch::Torch; using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_stablehlo; using namespace mlir::torch::torch_to_mhlo;
LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op, LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
mlir::Value &self, mlir::Value &other, mlir::Value &self, mlir::Value &other,
@ -44,7 +43,7 @@ LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
if (selfRank > otherRank) { if (selfRank > otherRank) {
auto unsqueezeDims = auto unsqueezeDims =
llvm::to_vector<4>(llvm::seq<int64_t>(0, selfRank - otherRank)); llvm::to_vector<4>(llvm::seq<int64_t>(0, selfRank - otherRank));
auto unsqueezeInfo = hlo::unsqueezeTensor(rewriter, op, other, auto unsqueezeInfo = mhlo::unsqueezeTensor(rewriter, op, other,
unsqueezeDims, dimSizeIndexBits); unsqueezeDims, dimSizeIndexBits);
if (failed(unsqueezeInfo)) if (failed(unsqueezeInfo))
return failure(); return failure();
@ -52,8 +51,8 @@ LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
} else if (otherRank > selfRank) { } else if (otherRank > selfRank) {
auto unsqueezeDims = auto unsqueezeDims =
llvm::to_vector<4>(llvm::seq<int64_t>(0, otherRank - selfRank)); llvm::to_vector<4>(llvm::seq<int64_t>(0, otherRank - selfRank));
auto unsqueezeInfo = hlo::unsqueezeTensor(rewriter, op, self, unsqueezeDims, auto unsqueezeInfo = mhlo::unsqueezeTensor(rewriter, op, self,
dimSizeIndexBits); unsqueezeDims, dimSizeIndexBits);
if (failed(unsqueezeInfo)) if (failed(unsqueezeInfo))
return failure(); return failure();
self = *unsqueezeInfo; self = *unsqueezeInfo;
@ -79,8 +78,7 @@ static FailureOr<Value> getMaxValueOfDtype(Operation *op, Type elementType,
constType, constType,
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(), APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
/*negative=*/false)); /*negative=*/false));
return rewriter return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
.getResult(); .getResult();
} }
if (elementType.isa<mlir::IntegerType>()) { if (elementType.isa<mlir::IntegerType>()) {
@ -93,8 +91,7 @@ static FailureOr<Value> getMaxValueOfDtype(Operation *op, Type elementType,
constAttr = SplatElementsAttr::get( constAttr = SplatElementsAttr::get(
constType, APInt::getSignedMaxValue(integerType.getWidth())); constType, APInt::getSignedMaxValue(integerType.getWidth()));
} }
return rewriter return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
.getResult(); .getResult();
} }
return failure(); return failure();
@ -108,8 +105,7 @@ static FailureOr<Value> getMinValueOfDtype(Operation *op, Type elementType,
constType, constType,
APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(), APFloat::getInf(elementType.cast<mlir::FloatType>().getFloatSemantics(),
/*negative=*/true)); /*negative=*/true));
return rewriter return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
.getResult(); .getResult();
} }
if (elementType.isa<mlir::IntegerType>()) { if (elementType.isa<mlir::IntegerType>()) {
@ -122,8 +118,7 @@ static FailureOr<Value> getMinValueOfDtype(Operation *op, Type elementType,
constAttr = SplatElementsAttr::get( constAttr = SplatElementsAttr::get(
constType, APInt::getSignedMinValue(integerType.getWidth())); constType, APInt::getSignedMinValue(integerType.getWidth()));
} }
return rewriter return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType, constAttr)
.create<stablehlo::ConstantOp>(op->getLoc(), constType, constAttr)
.getResult(); .getResult();
} }
return failure(); return failure();
@ -131,7 +126,7 @@ static FailureOr<Value> getMinValueOfDtype(Operation *op, Type elementType,
// These legalizations are for unary ops. // These legalizations are for unary ops.
namespace { namespace {
template <typename AtenOpT, typename StablehloOpT> template <typename AtenOpT, typename MhloOpT>
class ConvertAtenUnaryOp : public OpConversionPattern<AtenOpT> { class ConvertAtenUnaryOp : public OpConversionPattern<AtenOpT> {
public: public:
using OpConversionPattern<AtenOpT>::OpConversionPattern; using OpConversionPattern<AtenOpT>::OpConversionPattern;
@ -142,13 +137,13 @@ public:
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
auto selfType = self.getType().cast<TensorType>(); auto selfType = self.getType().cast<TensorType>();
if (!selfType) { if (!selfType) {
return op.emitError("only Tensor types supported in StableHLO"); return op.emitError("only Tensor types supported in MHLO");
} }
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter() auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType()) ->convertType(op.getType())
.template cast<TensorType>(); .template cast<TensorType>();
self = hlo::promoteType(rewriter, self, outType); self = mhlo::promoteType(rewriter, self, outType);
rewriter.replaceOpWithNewOp<StablehloOpT>(op, outType, self); rewriter.replaceOpWithNewOp<MhloOpT>(op, outType, self);
return success(); return success();
} }
}; };
@ -157,7 +152,7 @@ public:
// These legalizations are for unary ops with only for floating point datatypes. // These legalizations are for unary ops with only for floating point datatypes.
// There is no supported quantized integer mode for these. // There is no supported quantized integer mode for these.
namespace { namespace {
template <typename AtenOpT, typename StablehloOpT> template <typename AtenOpT, typename MhloOpT>
class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern<AtenOpT> { class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern<AtenOpT> {
public: public:
using OpConversionPattern<AtenOpT>::OpConversionPattern; using OpConversionPattern<AtenOpT>::OpConversionPattern;
@ -169,10 +164,10 @@ public:
auto selfTy = self.getType().cast<TensorType>(); auto selfTy = self.getType().cast<TensorType>();
if (!selfTy) if (!selfTy)
return op.emitError("only Tensor types supported in StableHLO"); return op.emitError("only Tensor types supported in MHLO");
if (selfTy.getElementType().isa<mlir::FloatType>()) { if (selfTy.getElementType().isa<mlir::FloatType>()) {
rewriter.replaceOpWithNewOp<StablehloOpT>( rewriter.replaceOpWithNewOp<MhloOpT>(
op, op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType( OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()), op.getType()),
@ -203,7 +198,7 @@ public:
.template dyn_cast<TensorType>(); .template dyn_cast<TensorType>();
if (!outType) if (!outType)
return op.emitError("only Tensor types supported in StableHLO"); return op.emitError("only Tensor types supported in MHLO");
Type outElemTy = outType.getElementType(); Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat()) if (!outElemTy.isIntOrFloat())
@ -221,9 +216,9 @@ public:
SmallVector<int32_t> values(size, fillVal); SmallVector<int32_t> values(size, fillVal);
auto constOp = auto constOp =
hlo::getConstTensor<int32_t>(rewriter, op, values, shape).value(); mhlo::getConstTensor<int32_t>(rewriter, op, values, shape).value();
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(op, outType, constOp); rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(op, outType, constOp);
return success(); return success();
} }
}; };
@ -252,8 +247,8 @@ public:
->convertType(op.getType()) ->convertType(op.getType())
.template cast<TensorType>(); .template cast<TensorType>();
lhs = hlo::promoteType(rewriter, lhs, outTy); lhs = mhlo::promoteType(rewriter, lhs, outTy);
rhs = hlo::promoteType(rewriter, rhs, outTy); rhs = mhlo::promoteType(rewriter, rhs, outTy);
rewriter.replaceOpWithNewOp<ChloOpT>(op, outTy, lhs, rhs, rewriter.replaceOpWithNewOp<ChloOpT>(op, outTy, lhs, rhs,
/*broadcast_attr*/ nullptr); /*broadcast_attr*/ nullptr);
@ -279,7 +274,7 @@ public:
RankedTensorType rhsType = rhs.getType().dyn_cast<RankedTensorType>(); RankedTensorType rhsType = rhs.getType().dyn_cast<RankedTensorType>();
if (!lhsType) if (!lhsType)
return op.emitError("only Tensor types supported in StableHLO"); return op.emitError("only Tensor types supported in MHLO");
TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter() TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType()) ->convertType(op.getType())
@ -292,19 +287,18 @@ public:
} }
if (!rhsType) { if (!rhsType) {
rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(), rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getOther(), outElemTy);
outElemTy);
if (isa<AtenRsubScalarOp>(op)) { if (isa<AtenRsubScalarOp>(op)) {
std::swap(lhs, rhs); std::swap(lhs, rhs);
} }
} }
lhs = hlo::promoteType(rewriter, lhs, outType); lhs = mhlo::promoteType(rewriter, lhs, outType);
rhs = hlo::promoteType(rewriter, rhs, outType); rhs = mhlo::promoteType(rewriter, rhs, outType);
if (!skipMultiplyAlpha(op.getAlpha())) { if (!skipMultiplyAlpha(op.getAlpha())) {
Value alpha = hlo::scalarToStablehloTensor(rewriter, op, Value alpha =
adaptor.getAlpha(), outElemTy); mhlo::scalarToMhloTensor(rewriter, op, adaptor.getAlpha(), outElemTy);
DenseIntElementsAttr bcastDimensions; DenseIntElementsAttr bcastDimensions;
rhs = rewriter.create<chlo::BroadcastMulOp>(op->getLoc(), rhs, alpha, rhs = rewriter.create<chlo::BroadcastMulOp>(op->getLoc(), rhs, alpha,
bcastDimensions); bcastDimensions);
@ -334,7 +328,7 @@ public:
TensorType rhsType = rhs.getType().dyn_cast<TensorType>(); TensorType rhsType = rhs.getType().dyn_cast<TensorType>();
if (!lhsType) if (!lhsType)
return op.emitError("only Tensor types supported in StableHLO"); return op.emitError("only Tensor types supported in MHLO");
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter() auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType()) ->convertType(op.getType())
@ -349,12 +343,11 @@ public:
if (std::is_same<AtenOpT, AtenSquareOp>()) { if (std::is_same<AtenOpT, AtenSquareOp>()) {
rhs = lhs; rhs = lhs;
} else if (!rhsType) { } else if (!rhsType) {
rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(), rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getOther(), outElemTy);
outElemTy);
} }
DenseIntElementsAttr bcastDimensions; DenseIntElementsAttr bcastDimensions;
lhs = hlo::promoteType(rewriter, lhs, outType); lhs = mhlo::promoteType(rewriter, lhs, outType);
rhs = hlo::promoteType(rewriter, rhs, outType); rhs = mhlo::promoteType(rewriter, rhs, outType);
auto loc = op.getLoc(); auto loc = op.getLoc();
Value result = Value result =
rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions); rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);
@ -375,15 +368,15 @@ public:
if (roundingMode == "trunc") { if (roundingMode == "trunc") {
// "trunc" - rounds the results of the division towards zero. Equivalent // "trunc" - rounds the results of the division towards zero. Equivalent
// to C-style integer division. // to C-style integer division.
auto sign = rewriter.create<stablehlo::SignOp>(loc, result); auto sign = rewriter.create<mhlo::SignOp>(loc, result);
auto abs = rewriter.create<stablehlo::AbsOp>(loc, result); auto abs = rewriter.create<mhlo::AbsOp>(loc, result);
auto floor = rewriter.create<stablehlo::FloorOp>(loc, abs); auto floor = rewriter.create<mhlo::FloorOp>(loc, abs);
result = rewriter.create<stablehlo::MulOp>(loc, sign, floor).getResult(); result = rewriter.create<mhlo::MulOp>(loc, sign, floor).getResult();
} }
if (roundingMode == "floor") { if (roundingMode == "floor") {
// "floor" - rounds the results of the division down. Equivalent to // "floor" - rounds the results of the division down. Equivalent to
// floor division in Python (the // operator) // floor division in Python (the // operator)
result = rewriter.create<stablehlo::FloorOp>(loc, result).getResult(); result = rewriter.create<mhlo::FloorOp>(loc, result).getResult();
} }
rewriter.replaceOp(op, result); rewriter.replaceOp(op, result);
return success(); return success();
@ -408,7 +401,7 @@ public:
RankedTensorType rhsTy = rhs.getType().dyn_cast<RankedTensorType>(); RankedTensorType rhsTy = rhs.getType().dyn_cast<RankedTensorType>();
if (!lhsTy) if (!lhsTy)
return op.emitError("only Tensor types supported in StableHLO"); return op.emitError("only Tensor types supported in MHLO");
RankedTensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter() RankedTensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType()) ->convertType(op.getType())
@ -421,12 +414,11 @@ public:
} }
if (!rhsTy) { if (!rhsTy) {
rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(), rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getOther(), lhsElemTy);
lhsElemTy);
} }
// TODO: what is the PyTorch default type promotion? // TODO: what is the PyTorch default type promotion?
rhs = hlo::promoteType(rewriter, rhs, lhsTy); rhs = mhlo::promoteType(rewriter, rhs, lhsTy);
chlo::ComparisonTypeAttr compareTypeAttr; chlo::ComparisonTypeAttr compareTypeAttr;
chlo::ComparisonDirectionAttr compareDirectionAttr; chlo::ComparisonDirectionAttr compareDirectionAttr;
@ -493,8 +485,8 @@ public:
TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter() TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType()) ->convertType(op.getType())
.template cast<TensorType>(); .template cast<TensorType>();
Value lhs = hlo::promoteType(rewriter, adaptor.getSelf(), outType); Value lhs = mhlo::promoteType(rewriter, adaptor.getSelf(), outType);
Value rhs = hlo::promoteType(rewriter, adaptor.getOther(), outType); Value rhs = mhlo::promoteType(rewriter, adaptor.getOther(), outType);
DenseIntElementsAttr bcastDimensions; DenseIntElementsAttr bcastDimensions;
rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs, rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
@ -545,7 +537,7 @@ public:
RankedTensorType::get({static_cast<long int>(permValues.size())}, RankedTensorType::get({static_cast<long int>(permValues.size())},
rewriter.getI64Type()), rewriter.getI64Type()),
permValues); permValues);
rewriter.replaceOpWithNewOp<stablehlo::TransposeOp>(op, outType, self, rewriter.replaceOpWithNewOp<mhlo::TransposeOp>(op, outType, self,
permutation); permutation);
return success(); return success();
} }
@ -560,7 +552,7 @@ LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
auto outType = auto outType =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>(); getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(op, outType, self); rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(op, outType, self);
return success(); return success();
} }
@ -581,8 +573,7 @@ LogicalResult ConvertAtenOp<AtenSizeIntOp>::matchAndRewrite(
} else { } else {
Value inputRank = rewriter.create<arith::ConstantOp>( Value inputRank = rewriter.create<arith::ConstantOp>(
op.getLoc(), rewriter.getI64IntegerAttr(selfType.getRank())); op.getLoc(), rewriter.getI64IntegerAttr(selfType.getRank()));
dim = toPositiveDimDynamic(rewriter, op.getLoc(), adaptor.getDim(), dim = toPositiveDimDynamic(rewriter, op.getLoc(), adaptor.getDim(), inputRank);
inputRank);
dim = rewriter.create<arith::IndexCastOp>(op.getLoc(), dim = rewriter.create<arith::IndexCastOp>(op.getLoc(),
rewriter.getIndexType(), dim); rewriter.getIndexType(), dim);
} }
@ -598,8 +589,9 @@ LogicalResult ConvertAtenOp<AtenSizeIntOp>::matchAndRewrite(
template <> template <>
LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
AtenWhereSelfOp op, OpAdaptor adaptor, AtenWhereSelfOp op,
ConversionPatternRewriter &rewriter) const { OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const {
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
Value cond = adaptor.getCondition(); Value cond = adaptor.getCondition();
Value other = adaptor.getOther(); Value other = adaptor.getOther();
@ -613,7 +605,8 @@ LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
return op.emitError("failed broadcast other and condition ranks"); return op.emitError("failed broadcast other and condition ranks");
rewriter.replaceOpWithNewOp<chlo::BroadcastSelectOp>( rewriter.replaceOpWithNewOp<chlo::BroadcastSelectOp>(
op, getTypeConverter()->convertType(op.getType()), op,
getTypeConverter()->convertType(op.getType()),
ArrayRef<Value>{cond, self, other}); ArrayRef<Value>{cond, self, other});
return success(); return success();
} }
@ -630,7 +623,7 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
.cast<RankedTensorType>(); .cast<RankedTensorType>();
if (options.enableStaticShape && selfTy.hasStaticShape()) { if (options.enableStaticShape && selfTy.hasStaticShape()) {
Value bcastOp = hlo::promoteAndBroadcast(rewriter, self, outType); Value bcastOp = mhlo::promoteAndBroadcast(rewriter, self, outType);
rewriter.replaceOp(op, bcastOp); rewriter.replaceOp(op, bcastOp);
return success(); return success();
} }
@ -677,7 +670,7 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
op->getLoc(), ValueRange{bcastShapeVec}); op->getLoc(), ValueRange{bcastShapeVec});
auto dimensionNumbers = auto dimensionNumbers =
llvm::to_vector<4>(llvm::seq<int64_t>(leadingRank, totalRank)); llvm::to_vector<4>(llvm::seq<int64_t>(leadingRank, totalRank));
rewriter.replaceOpWithNewOp<stablehlo::DynamicBroadcastInDimOp>( rewriter.replaceOpWithNewOp<mhlo::DynamicBroadcastInDimOp>(
op, outType, self, bcastShapeTensor, op, outType, self, bcastShapeTensor,
rewriter.getI64TensorAttr(dimensionNumbers)); rewriter.getI64TensorAttr(dimensionNumbers));
} }
@ -715,11 +708,28 @@ LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
RankedTensorType::get({static_cast<long int>(permValues.size())}, RankedTensorType::get({static_cast<long int>(permValues.size())},
rewriter.getI64Type()), rewriter.getI64Type()),
permValues); permValues);
rewriter.replaceOpWithNewOp<stablehlo::TransposeOp>(op, outType, self, rewriter.replaceOpWithNewOp<mhlo::TransposeOp>(op, outType, self,
permutation); permutation);
return success(); return success();
} }
// AtenTanhOp
template <>
LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
AtenTanhOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value self = adaptor.getSelf();
auto selfTy = self.getType().cast<TensorType>();
if (selfTy && selfTy.getElementType().isa<mlir::FloatType>()) {
rewriter.replaceOpWithNewOp<mhlo::TanhOp>(
op, getTypeConverter()->convertType(op.getType()), self);
return success();
} else {
return op.emitError(
"only floating-point datatype legalization currently supported");
}
}
// ValueTensorLiteralOp // ValueTensorLiteralOp
template <> template <>
LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite( LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
@ -741,16 +751,16 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
elements.mapValues(builtinTensorElemTy, [&](const APInt &v) { elements.mapValues(builtinTensorElemTy, [&](const APInt &v) {
return APInt(bitWidth, v.getSExtValue()); return APInt(bitWidth, v.getSExtValue());
}); });
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(op, resultType, rewriter.replaceOpWithNewOp<mhlo::ConstantOp>(op, resultType, valueAttr);
valueAttr);
return success(); return success();
} }
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(op, resultType, rewriter.replaceOpWithNewOp<mhlo::ConstantOp>(op, resultType,
adaptor.getValue()); adaptor.getValue());
return success(); return success();
} }
// AtenReciprocalOp // AtenReciprocalOp
// Reciprocal(x) = Div(1, x) // Reciprocal(x) = Div(1, x)
template <> template <>
@ -767,45 +777,7 @@ LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(
} }
Value oneTensor = chlo::getConstantLike(rewriter, op->getLoc(), 1, input); Value oneTensor = chlo::getConstantLike(rewriter, op->getLoc(), 1, input);
rewriter.replaceOpWithNewOp<stablehlo::DivOp>(op, outTy, oneTensor, input); rewriter.replaceOpWithNewOp<mhlo::DivOp>(op, outTy, oneTensor, input);
return success();
}
// AtenPowTensorScalarOp
template <>
LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
AtenPowTensorScalarOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value lhs = adaptor.getSelf();
auto lhsType = lhs.getType().dyn_cast<TensorType>();
Value rhs = adaptor.getExponent();
TensorType rhsType = rhs.getType().dyn_cast<TensorType>();
if (!lhsType)
return op.emitError("only Tensor types supported in StableHLO");
auto outType = OpConversionPattern<AtenPowTensorScalarOp>::getTypeConverter()
->convertType(op.getType())
.template cast<TensorType>();
Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat()) {
return op.emitError(
"only floating-point or integer datatype legalization supported");
}
if (!rhsType) {
rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs,
outElemTy);
}
DenseIntElementsAttr bcastDimensions;
lhs = hlo::promoteType(rewriter, lhs, outType);
rhs = hlo::promoteType(rewriter, rhs, outType);
auto loc = op.getLoc();
Value result =
rewriter.create<chlo::BroadcastPowOp>(loc, outType, lhs, rhs, bcastDimensions);
rewriter.replaceOp(op, result);
return success(); return success();
} }
@ -818,9 +790,9 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
->convertType(op->getResult(0).getType()) ->convertType(op->getResult(0).getType())
.cast<RankedTensorType>(); .cast<RankedTensorType>();
auto outputElemType = outputType.getElementType(); auto outputElemType = outputType.getElementType();
Value stablehloTensor = hlo::scalarToStablehloTensor( Value mhloTensor =
rewriter, op, adaptor.getA(), outputElemType); mhlo::scalarToMhloTensor(rewriter, op, adaptor.getA(), outputElemType);
rewriter.replaceOp(op, stablehloTensor); rewriter.replaceOp(op, mhloTensor);
return success(); return success();
} }
@ -843,6 +815,7 @@ LogicalResult ConvertAtenOp<AtenContiguousOp>::matchAndRewrite(
return success(); return success();
} }
// AtenReluOp // AtenReluOp
// Relu(x) = Max(0, x) // Relu(x) = Max(0, x)
template <> template <>
@ -863,10 +836,11 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
APFloat::getZero(lhsElemTy.cast<mlir::FloatType>().getFloatSemantics(), APFloat::getZero(lhsElemTy.cast<mlir::FloatType>().getFloatSemantics(),
false), false),
lhs); lhs);
rewriter.replaceOpWithNewOp<stablehlo::MaxOp>(op, lhs, zeroTensor); rewriter.replaceOpWithNewOp<mhlo::MaxOp>(op, lhs, zeroTensor);
return success(); return success();
} }
// Convert a Aten::GELU to HLO // Convert a Aten::GELU to HLO
// Gelu(x) = x * 1/2 * [1 + erf(x/(sqrt(2)))] // Gelu(x) = x * 1/2 * [1 + erf(x/(sqrt(2)))]
template <> template <>
@ -883,12 +857,12 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
Value one = chlo::getConstantLike(rewriter, loc, 1.0, input); Value one = chlo::getConstantLike(rewriter, loc, 1.0, input);
Value two = chlo::getConstantLike(rewriter, loc, 2.0, input); Value two = chlo::getConstantLike(rewriter, loc, 2.0, input);
Value half = chlo::getConstantLike(rewriter, loc, 0.5, input); Value half = chlo::getConstantLike(rewriter, loc, 0.5, input);
auto rsqrtTwo = rewriter.create<mlir::stablehlo::RsqrtOp>(loc, two); auto rsqrtTwo = rewriter.create<mlir::mhlo::RsqrtOp>(loc, two);
auto erfElement = rewriter.create<stablehlo::MulOp>(loc, input, rsqrtTwo); auto erfElement = rewriter.create<mhlo::MulOp>(loc, input, rsqrtTwo);
auto erf = rewriter.create<mlir::chlo::ErfOp>(loc, erfElement); auto erf = rewriter.create<mlir::chlo::ErfOp>(loc, erfElement);
auto erfAdd = rewriter.create<stablehlo::AddOp>(loc, erf, one); auto erfAdd = rewriter.create<mhlo::AddOp>(loc, erf, one);
auto halfMul = rewriter.create<stablehlo::MulOp>(loc, erfAdd, half); auto halfMul = rewriter.create<mhlo::MulOp>(loc, erfAdd, half);
rewriter.replaceOpWithNewOp<stablehlo::MulOp>(op, input, halfMul); rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, input, halfMul);
return success(); return success();
} }
@ -907,6 +881,7 @@ LogicalResult ConvertAtenOp<AtenErfOp>::matchAndRewrite(
return success(); return success();
} }
// AtenBatchNormOp // AtenBatchNormOp
template <> template <>
LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
@ -944,28 +919,28 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
Value channelShape = rewriter.create<tensor::FromElementsOp>( Value channelShape = rewriter.create<tensor::FromElementsOp>(
op->getLoc(), ValueRange{channelDim}); op->getLoc(), ValueRange{channelDim});
if (failed(checkNotNone(rewriter, op, weight))) { if (failed(checkNotNone(rewriter, op, weight))) {
weight = hlo::getConstantOfShape( weight = mhlo::getConstantOfShape(
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1), rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1),
channelShape, channelShape,
RankedTensorType::get({inputTy.getShape()[1]}, RankedTensorType::get({inputTy.getShape()[1]},
inputTy.getElementType())); inputTy.getElementType()));
} }
if (failed(checkNotNone(rewriter, op, bias))) { if (failed(checkNotNone(rewriter, op, bias))) {
bias = hlo::getConstantOfShape( bias = mhlo::getConstantOfShape(
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0), rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0),
channelShape, channelShape,
RankedTensorType::get({inputTy.getShape()[1]}, RankedTensorType::get({inputTy.getShape()[1]},
inputTy.getElementType())); inputTy.getElementType()));
} }
if (failed(checkNotNone(rewriter, op, runningVar))) { if (failed(checkNotNone(rewriter, op, runningVar))) {
runningVar = hlo::getConstantOfShape( runningVar = mhlo::getConstantOfShape(
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1), rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1),
channelShape, channelShape,
RankedTensorType::get({inputTy.getShape()[1]}, RankedTensorType::get({inputTy.getShape()[1]},
inputTy.getElementType())); inputTy.getElementType()));
} }
if (failed(checkNotNone(rewriter, op, runningMean))) { if (failed(checkNotNone(rewriter, op, runningMean))) {
runningMean = hlo::getConstantOfShape( runningMean = mhlo::getConstantOfShape(
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0), rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0),
channelShape, channelShape,
RankedTensorType::get({inputTy.getShape()[1]}, RankedTensorType::get({inputTy.getShape()[1]},
@ -1008,8 +983,7 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
Type outputTy = getTypeConverter()->convertType(op.getType()); Type outputTy = getTypeConverter()->convertType(op.getType());
Type batchMeanOrVarTy = Type batchMeanOrVarTy =
RankedTensorType::get(weightTy.getShape(), inputTy.getElementType()); RankedTensorType::get(weightTy.getShape(), inputTy.getElementType());
auto batchNormTrainingResult = auto batchNormTrainingResult = rewriter.create<mhlo::BatchNormTrainingOp>(
rewriter.create<stablehlo::BatchNormTrainingOp>(
op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input, op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input,
weight, bias, rewriter.getF32FloatAttr(eps), weight, bias, rewriter.getF32FloatAttr(eps),
rewriter.getI64IntegerAttr(1)); rewriter.getI64IntegerAttr(1));
@ -1021,11 +995,10 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
inputTy.getShape().end()}; inputTy.getShape().end()};
castShape[1] = weightTy.getShape()[0]; castShape[1] = weightTy.getShape()[0];
auto castTy = RankedTensorType::get(castShape, inputTy.getElementType()); auto castTy = RankedTensorType::get(castShape, inputTy.getElementType());
// Feature counts must match among operands of // Feature counts must match among operands of mhlo::BatchNormInferenceOp.
// stablehlo::BatchNormInferenceOp.
Value inputCasted = Value inputCasted =
rewriter.create<tensor::CastOp>(op.getLoc(), castTy, input); rewriter.create<tensor::CastOp>(op.getLoc(), castTy, input);
Value output = rewriter.create<stablehlo::BatchNormInferenceOp>( Value output = rewriter.create<mhlo::BatchNormInferenceOp>(
op.getLoc(), inputCasted.getType(), inputCasted, weight, bias, op.getLoc(), inputCasted.getType(), inputCasted, weight, bias,
runningMean, runningVar, runningMean, runningVar,
// 'epsilon' must satisfy constraint: 32-bit float attribute. // 'epsilon' must satisfy constraint: 32-bit float attribute.
@ -1035,6 +1008,7 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
} }
} }
// AtenNativeLayerNormOp // AtenNativeLayerNormOp
template <> template <>
LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
@ -1102,21 +1076,21 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
} }
SmallVector<int64_t> inputFlattenShape{1, numFeatureDimSize, SmallVector<int64_t> inputFlattenShape{1, numFeatureDimSize,
numEmbeddingDimSize}; numEmbeddingDimSize};
SmallVector<int64_t> meanOrVarStablehloOutShape{numFeatureDimSize}; SmallVector<int64_t> meanOrVarMhloOutShape{numFeatureDimSize};
auto stablehloBatchNormOutTy = auto mhloBatchNormOutTy =
RankedTensorType::get(inputFlattenShape, inputTy.getElementType()); RankedTensorType::get(inputFlattenShape, inputTy.getElementType());
auto stablehloBathNormOutMeanOrVarTy = RankedTensorType::get( auto mhloBathNormOutMeanOrVarTy =
meanOrVarStablehloOutShape, inputTy.getElementType()); RankedTensorType::get(meanOrVarMhloOutShape, inputTy.getElementType());
// Reshape input // Reshape input
auto stablehloInput = rewriter.create<stablehlo::DynamicReshapeOp>( auto mhloInput = rewriter.create<mhlo::DynamicReshapeOp>(
op->getLoc(), stablehloBatchNormOutTy, input, op->getLoc(), mhloBatchNormOutTy, input,
hlo::getConstTensor(rewriter, op, llvm::ArrayRef(inputFlattenShape), mhlo::getConstTensor(rewriter, op, llvm::makeArrayRef(inputFlattenShape),
{static_cast<int64_t>(inputFlattenShape.size())}) {static_cast<int64_t>(inputFlattenShape.size())})
.value()); .value());
// Generate "scale" and "offset" Value for stablehlo.BatchNormTrainingOp. // Generate "scale" and "offset" Value for mhlo.BatchNormTrainingOp.
SmallVector<APFloat> zeroConstVec( SmallVector<APFloat> zeroConstVec(
numFeatureDimSize, APFloat::getZero(inputTy.getElementType() numFeatureDimSize, APFloat::getZero(inputTy.getElementType()
.cast<mlir::FloatType>() .cast<mlir::FloatType>()
@ -1129,18 +1103,16 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
auto oneOrZeroConstType = auto oneOrZeroConstType =
RankedTensorType::get({numFeatureDimSize}, inputTy.getElementType()); RankedTensorType::get({numFeatureDimSize}, inputTy.getElementType());
Value scale = rewriter.create<stablehlo::ConstantOp>( Value scale = rewriter.create<mhlo::ConstantOp>(
op->getLoc(), oneOrZeroConstType, op->getLoc(), oneOrZeroConstType,
DenseElementsAttr::get(oneOrZeroConstType, oneConstVec)); DenseElementsAttr::get(oneOrZeroConstType, oneConstVec));
Value offset = rewriter.create<stablehlo::ConstantOp>( Value offset = rewriter.create<mhlo::ConstantOp>(
op->getLoc(), oneOrZeroConstType, op->getLoc(), oneOrZeroConstType,
DenseElementsAttr::get(oneOrZeroConstType, zeroConstVec)); DenseElementsAttr::get(oneOrZeroConstType, zeroConstVec));
auto batchNormTrainingResult = auto batchNormTrainingResult = rewriter.create<mhlo::BatchNormTrainingOp>(
rewriter.create<stablehlo::BatchNormTrainingOp>( op->getLoc(), mhloBatchNormOutTy, mhloBathNormOutMeanOrVarTy,
op->getLoc(), stablehloBatchNormOutTy, mhloBathNormOutMeanOrVarTy, mhloInput, scale, offset,
stablehloBathNormOutMeanOrVarTy, stablehloBathNormOutMeanOrVarTy, rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(1));
stablehloInput, scale, offset, rewriter.getF32FloatAttr(eps),
rewriter.getI64IntegerAttr(1));
// Reshape back // Reshape back
auto outputTy = auto outputTy =
@ -1148,35 +1120,36 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
auto outputMeanOrVarTy = auto outputMeanOrVarTy =
getTypeConverter()->convertType(op.getType(1)).cast<RankedTensorType>(); getTypeConverter()->convertType(op.getType(1)).cast<RankedTensorType>();
auto output = rewriter.create<stablehlo::DynamicReshapeOp>( auto output = rewriter.create<mhlo::DynamicReshapeOp>(
op->getLoc(), outputTy, batchNormTrainingResult.getResult(0), op->getLoc(), outputTy, batchNormTrainingResult.getResult(0),
hlo::getConstTensor(rewriter, op, outputTy.getShape(), mhlo::getConstTensor(rewriter, op, outputTy.getShape(),
{static_cast<int64_t>(outputTy.getShape().size())}) {static_cast<int64_t>(outputTy.getShape().size())})
.value()); .value());
auto mean = rewriter.create<stablehlo::DynamicReshapeOp>( auto mean = rewriter.create<mhlo::DynamicReshapeOp>(
op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(1), op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(1),
hlo::getConstTensor( mhlo::getConstTensor(
rewriter, op, outputMeanOrVarTy.getShape(), rewriter, op, outputMeanOrVarTy.getShape(),
{static_cast<int64_t>(outputMeanOrVarTy.getShape().size())}) {static_cast<int64_t>(outputMeanOrVarTy.getShape().size())})
.value()); .value());
auto var = rewriter.create<stablehlo::DynamicReshapeOp>( auto var = rewriter.create<mhlo::DynamicReshapeOp>(
op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(2), op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(2),
hlo::getConstTensor( mhlo::getConstTensor(
rewriter, op, outputMeanOrVarTy.getShape(), rewriter, op, outputMeanOrVarTy.getShape(),
{static_cast<int64_t>(outputMeanOrVarTy.getShape().size())}) {static_cast<int64_t>(outputMeanOrVarTy.getShape().size())})
.value()); .value());
// Apply affine transform: output x weight + bias [element-wise] // Apply affine transform: output x weight + bias [element-wise]
auto bcastedWeight = hlo::promoteAndBroadcast(rewriter, weight, outputTy); auto bcastedWeight = mhlo::promoteAndBroadcast(rewriter, weight, outputTy);
auto bcastedBias = hlo::promoteAndBroadcast(rewriter, bias, outputTy); auto bcastedBias = mhlo::promoteAndBroadcast(rewriter, bias, outputTy);
auto outputMulWeight = auto outputMulWeight =
rewriter.create<stablehlo::MulOp>(op->getLoc(), output, bcastedWeight); rewriter.create<mhlo::MulOp>(op->getLoc(), output, bcastedWeight);
auto finalOuput = rewriter.create<stablehlo::AddOp>( auto finalOuput =
op->getLoc(), outputMulWeight, bcastedBias); rewriter.create<mhlo::AddOp>(op->getLoc(), outputMulWeight, bcastedBias);
rewriter.replaceOp(op, {finalOuput, mean, var}); rewriter.replaceOp(op, {finalOuput, mean, var});
return success(); return success();
} }
// AtenCatOp // AtenCatOp
template <> template <>
LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
@ -1200,11 +1173,11 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
// Promote type // Promote type
for (auto &v : builtinTensors) { for (auto &v : builtinTensors) {
v = hlo::promoteType(rewriter, v, outType); v = mhlo::promoteType(rewriter, v, outType);
} }
size_t posDim = toPositiveDim(dim, outType.getRank()); size_t posDim = toPositiveDim(dim, outType.getRank());
rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>( rewriter.replaceOpWithNewOp<mhlo::ConcatenateOp>(
op, outType, ValueRange(builtinTensors), posDim); op, outType, ValueRange(builtinTensors), posDim);
return success(); return success();
} }
@ -1252,8 +1225,7 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "this op should be folded as its `min` and `max` both are none"); op, "this op should be folded as its `min` and `max` both are none");
} else if (failed(checkNotNone(rewriter, op, minValue))) { } else if (failed(checkNotNone(rewriter, op, minValue))) {
maxValue = maxValue = mhlo::scalarToMhloTensor(rewriter, op, maxValue, inputElemType);
hlo::scalarToStablehloTensor(rewriter, op, maxValue, inputElemType);
auto minInfo = getMinValueOfDtype(op, inputElemType, rewriter); auto minInfo = getMinValueOfDtype(op, inputElemType, rewriter);
if (failed(minInfo)) { if (failed(minInfo)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -1261,8 +1233,7 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
} }
minValue = *minInfo; minValue = *minInfo;
} else if (failed(checkNotNone(rewriter, op, maxValue))) { } else if (failed(checkNotNone(rewriter, op, maxValue))) {
minValue = minValue = mhlo::scalarToMhloTensor(rewriter, op, minValue, inputElemType);
hlo::scalarToStablehloTensor(rewriter, op, minValue, inputElemType);
auto maxInfo = getMaxValueOfDtype(op, inputElemType, rewriter); auto maxInfo = getMaxValueOfDtype(op, inputElemType, rewriter);
if (failed(maxInfo)) { if (failed(maxInfo)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -1270,13 +1241,10 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
} }
maxValue = *maxInfo; maxValue = *maxInfo;
} else { } else {
minValue = minValue = mhlo::scalarToMhloTensor(rewriter, op, minValue, inputElemType);
hlo::scalarToStablehloTensor(rewriter, op, minValue, inputElemType); maxValue = mhlo::scalarToMhloTensor(rewriter, op, maxValue, inputElemType);
maxValue =
hlo::scalarToStablehloTensor(rewriter, op, maxValue, inputElemType);
} }
rewriter.replaceOpWithNewOp<stablehlo::ClampOp>(op, minValue, input, rewriter.replaceOpWithNewOp<mhlo::ClampOp>(op, minValue, input, maxValue);
maxValue);
return success(); return success();
} }
@ -1298,27 +1266,24 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
op, "unimplemented: only int or float dtype supported"); op, "unimplemented: only int or float dtype supported");
} }
Value start = Value start = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getStart(), dtype);
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getStart(), dtype); Value end = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getEnd(), dtype);
Value end = Value step = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getStep(), dtype);
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getEnd(), dtype);
Value step =
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getStep(), dtype);
// Get length of the 1-d output tensor // Get length of the 1-d output tensor
Value subOut = rewriter.create<stablehlo::SubtractOp>(loc, end, start); Value subOut = rewriter.create<mhlo::SubtractOp>(loc, end, start);
Value divOut = rewriter.create<stablehlo::DivOp>(loc, subOut, step); Value divOut = rewriter.create<mhlo::DivOp>(loc, subOut, step);
Value resultLength = rewriter.create<stablehlo::ReshapeOp>( Value resultLength = rewriter.create<mhlo::ReshapeOp>(
loc, RankedTensorType::get({1}, dtype), divOut); loc, RankedTensorType::get({1}, dtype), divOut);
if (dtype.isa<mlir::FloatType>()) { if (dtype.isa<mlir::FloatType>()) {
resultLength = rewriter.create<stablehlo::CeilOp>(loc, resultLength); resultLength = rewriter.create<mhlo::CeilOp>(loc, resultLength);
resultLength = rewriter.create<stablehlo::ConvertOp>( resultLength = rewriter.create<mhlo::ConvertOp>(
loc, RankedTensorType::get({1}, rewriter.getI64Type()), resultLength); loc, RankedTensorType::get({1}, rewriter.getI64Type()), resultLength);
} }
Value window = Value window =
rewriter.create<stablehlo::DynamicIotaOp>(loc, outType, resultLength, 0); rewriter.create<mhlo::DynamicIotaOp>(loc, outType, resultLength, 0);
DenseIntElementsAttr broadcastDimensions; DenseIntElementsAttr broadcastDimensions;
Value mulOut = rewriter.create<chlo::BroadcastMulOp>(loc, window, step, Value mulOut = rewriter.create<chlo::BroadcastMulOp>(loc, window, step,
broadcastDimensions); broadcastDimensions);
@ -1333,8 +1298,9 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc(); Location loc = op.getLoc();
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
auto outType = auto outType = this->getTypeConverter()
this->getTypeConverter()->convertType(op.getType()).cast<TensorType>(); ->convertType(op.getType())
.cast<TensorType>();
if (!outType) { if (!outType) {
return op.emitError("only tensor type is supported"); return op.emitError("only tensor type is supported");
} }
@ -1354,27 +1320,26 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
Value negHalf = chlo::getConstantLike(rewriter, loc, -0.5, input); Value negHalf = chlo::getConstantLike(rewriter, loc, -0.5, input);
// Compute // Compute
Value kBeta0 = Value kBeta0 = rewriter.create<mhlo::MulOp>(loc, outType, kAlpha, cstAlpha0);
rewriter.create<stablehlo::MulOp>(loc, outType, kAlpha, cstAlpha0); Value kBeta = rewriter.create<mhlo::MulOp>(loc, outType, kBeta0, half);
Value kBeta = rewriter.create<stablehlo::MulOp>(loc, outType, kBeta0, half); Value erfArg =
Value erfArg = rewriter.create<stablehlo::MulOp>(loc, outType, kAlpha, rewriter.create<mhlo::MulOp>(loc, outType, kAlpha, adaptor.getSelf());
adaptor.getSelf());
Value erf = rewriter.create<mlir::chlo::ErfOp>(loc, outType, erfArg); Value erf = rewriter.create<mlir::chlo::ErfOp>(loc, outType, erfArg);
Value erfAdd = rewriter.create<stablehlo::AddOp>(loc, outType, erf, one); Value erfAdd = rewriter.create<mhlo::AddOp>(loc, outType, erf, one);
Value cdf = rewriter.create<stablehlo::MulOp>(loc, outType, erfAdd, half); Value cdf = rewriter.create<mhlo::MulOp>(loc, outType, erfAdd, half);
Value inputSquared = rewriter.create<stablehlo::MulOp>( Value inputSquared = rewriter.create<mhlo::MulOp>(
loc, outType, adaptor.getSelf(), adaptor.getSelf()); loc, outType, adaptor.getSelf(), adaptor.getSelf());
Value negHalfInputSquared = Value negHalfInputSquared =
rewriter.create<stablehlo::MulOp>(loc, outType, inputSquared, negHalf); rewriter.create<mhlo::MulOp>(loc, outType, inputSquared, negHalf);
Value expRes = Value expRes =
rewriter.create<stablehlo::ExpOp>(loc, outType, negHalfInputSquared); rewriter.create<mhlo::ExpOp>(loc, outType, negHalfInputSquared);
Value pdf = rewriter.create<stablehlo::MulOp>(loc, outType, kBeta, expRes); Value pdf = rewriter.create<mhlo::MulOp>(loc, outType, kBeta, expRes);
Value pdfTimesInput = Value pdfTimesInput =
rewriter.create<stablehlo::MulOp>(loc, outType, pdf, adaptor.getSelf()); rewriter.create<mhlo::MulOp>(loc, outType, pdf, adaptor.getSelf());
Value pdfTimesInputAddCdf = Value pdfTimesInputAddCdf =
rewriter.create<stablehlo::AddOp>(loc, outType, pdfTimesInput, cdf); rewriter.create<mhlo::AddOp>(loc, outType, pdfTimesInput, cdf);
rewriter.replaceOpWithNewOp<stablehlo::MulOp>( rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, outType, adaptor.getGradOutput(),
op, outType, adaptor.getGradOutput(), pdfTimesInputAddCdf); pdfTimesInputAddCdf);
return success(); return success();
} }
@ -1401,9 +1366,9 @@ public:
}; };
} // namespace } // namespace
void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns, TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options) { ConversionTarget &target, const TorchToMhloOptions &options) {
MLIRContext *context = patterns.getContext(); MLIRContext *context = patterns.getContext();
target.addIllegalOp<AtenTransposeIntOp>(); target.addIllegalOp<AtenTransposeIntOp>();
@ -1411,29 +1376,23 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
target.addIllegalOp<RuntimeAssertOp>(); target.addIllegalOp<RuntimeAssertOp>();
patterns.add<ConvertRuntimeAssertOp>(typeConverter, context); patterns.add<ConvertRuntimeAssertOp>(typeConverter, context);
#define INSERT_UNARY_PATTERN(AtenOp, StablehloOp) \ #define INSERT_UNARY_PATTERN(AtenOp, MhloOp) \
target.addIllegalOp<AtenOp>(); \ target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenUnaryOp<AtenOp, StablehloOp>>(typeConverter, context) patterns.add<ConvertAtenUnaryOp<AtenOp, MhloOp>>(typeConverter, context)
INSERT_UNARY_PATTERN(AtenCloneOp, stablehlo::ConvertOp); INSERT_UNARY_PATTERN(AtenCloneOp, mhlo::CopyOp);
INSERT_UNARY_PATTERN(AtenNegOp, stablehlo::NegOp); INSERT_UNARY_PATTERN(AtenNegOp, mhlo::NegOp);
INSERT_UNARY_PATTERN(AtenLogicalNotOp, stablehlo::NotOp); INSERT_UNARY_PATTERN(AtenLogicalNotOp, mhlo::NotOp);
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, stablehlo::NotOp); INSERT_UNARY_PATTERN(AtenBitwiseNotOp, mhlo::NotOp);
#undef INSERT_UNARY_PATTERN #undef INSERT_UNARY_PATTERN
#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, StablehloOp) \ #define INSERT_UNARY_FPONLY_PATTERN(AtenOp, MhloOp) \
target.addIllegalOp<AtenOp>(); \ target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, StablehloOp>>(typeConverter, \ patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, MhloOp>>(typeConverter, context)
context) INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, mhlo::LogOp);
INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, stablehlo::LogOp); INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, mhlo::ExpOp);
INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, stablehlo::ExpOp); INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, mhlo::SqrtOp);
INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, stablehlo::SqrtOp); INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, mhlo::RsqrtOp);
INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, stablehlo::RsqrtOp); INSERT_UNARY_FPONLY_PATTERN(AtenSigmoidOp, mhlo::LogisticOp);
INSERT_UNARY_FPONLY_PATTERN(AtenSigmoidOp, stablehlo::LogisticOp);
INSERT_UNARY_FPONLY_PATTERN(AtenTanhOp, stablehlo::TanhOp);
INSERT_UNARY_FPONLY_PATTERN(AtenSinOp, stablehlo::SineOp);
INSERT_UNARY_FPONLY_PATTERN(AtenCosOp, stablehlo::CosineOp);
INSERT_UNARY_FPONLY_PATTERN(AtenCeilOp, stablehlo::CeilOp);
INSERT_UNARY_FPONLY_PATTERN(AtenFloorOp, stablehlo::FloorOp);
#undef INSERT_UNARY_FPONLY_PATTERN #undef INSERT_UNARY_FPONLY_PATTERN
#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \ #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
@ -1500,9 +1459,9 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenBroadcastToOp); INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
INSERT_ATENOP_PATTERN(AtenPermuteOp); INSERT_ATENOP_PATTERN(AtenPermuteOp);
INSERT_ATENOP_PATTERN(AtenTanhOp);
INSERT_ATENOP_PATTERN(ValueTensorLiteralOp); INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
INSERT_ATENOP_PATTERN(AtenReciprocalOp); INSERT_ATENOP_PATTERN(AtenReciprocalOp);
INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp); INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
INSERT_ATENOP_PATTERN(AtenContiguousOp); INSERT_ATENOP_PATTERN(AtenContiguousOp);
@ -1523,10 +1482,10 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenWhereSelfOp); INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
#undef INSERT_ATENOP_PATTERN #undef INSERT_ATENOP_PATTERN
#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \ #define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, MhloOp) \
target.addIllegalOp<AtenOp>(); \ target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenBinaryBroadcastOp<AtenOp, StablehloOp>>( \ patterns.add<ConvertAtenBinaryBroadcastOp<AtenOp, MhloOp>>(typeConverter, \
typeConverter, context) context)
INSERT_BINARY_BROADCAST_PATTERN(AtenMaximumOp, chlo::BroadcastMaxOp); INSERT_BINARY_BROADCAST_PATTERN(AtenMaximumOp, chlo::BroadcastMaxOp);
INSERT_BINARY_BROADCAST_PATTERN(AtenMinimumOp, chlo::BroadcastMinOp); INSERT_BINARY_BROADCAST_PATTERN(AtenMinimumOp, chlo::BroadcastMinOp);
INSERT_BINARY_BROADCAST_PATTERN(Aten__And__TensorOp, chlo::BroadcastAndOp); INSERT_BINARY_BROADCAST_PATTERN(Aten__And__TensorOp, chlo::BroadcastAndOp);

View File

@ -0,0 +1,35 @@
add_mlir_conversion_library(TorchMLIRTorchToMhlo
TorchToMhlo.cpp
MhloLegalizeUtils.cpp
Basic.cpp
Gather.cpp
Linear.cpp
ViewLike.cpp
Reduction.cpp
Pooling.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToMhlo
DEPENDS
MhloDialect
MhloToLinalg
MLIRMhloPassIncGen
LMHLOTransformsPassIncGen
TorchMLIRConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MhloDialect
MhloToLinalg
MLIRBufferTransforms
StablehloOps
TorchMLIRTorchDialect
TorchMLIRConversionUtils
)
torch_mlir_target_includes(TorchMLIRTorchToMhlo)

View File

@ -7,15 +7,14 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "../PassDetail.h" #include "../PassDetail.h"
#include "PopulatePatterns.h" #include "./MhloLegalizeUtils.h"
#include "StablehloLegalizeUtils.h" #include "./PopulatePatterns.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@ -25,7 +24,7 @@
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
using namespace mlir::torch::Torch; using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_stablehlo; using namespace mlir::torch::torch_to_mhlo;
namespace { namespace {
Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op, Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
@ -70,7 +69,7 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
SmallVector<int64_t, 4> startIndexMap(1, axis); SmallVector<int64_t, 4> startIndexMap(1, axis);
// indexVecDim // indexVecDim
int64_t indexVecDim = indicesRank; int64_t indexVecDim = indicesRank;
auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get( auto dimsAttr = mhlo::GatherDimensionNumbersAttr::get(
rewriter.getContext(), rewriter.getContext(),
/*offsetDims=*/offsetDims, /*offsetDims=*/offsetDims,
/*collapsedSliceDims=*/collapsedSliceDims, /*collapsedSliceDims=*/collapsedSliceDims,
@ -92,18 +91,17 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
auto outputTy = auto outputTy =
RankedTensorType::get(outputShape, inputRankTy.getElementType()); RankedTensorType::get(outputShape, inputRankTy.getElementType());
return rewriter return rewriter
.create<stablehlo::DynamicGatherOp>(loc, outputTy, input, indices, .create<mhlo::DynamicGatherOp>(loc, outputTy, input, indices,
sliceSizesTensor, dimsAttr) sliceSizesTensor, dimsAttr)
.getResult(); .getResult();
} }
} // namespace } // namespace
// Ref: // Ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html
// https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html
// padding_idx (int, optional) // padding_idx (int, optional)
// If specified, the entries at padding_idx do not contribute to the // If specified, the entries at padding_idx do not contribute to the gradient;
// gradient; therefore, the embedding vector at padding_idx is not updated // therefore, the embedding vector at padding_idx is not updated during training,
// during training, i.e. it remains as a fixed “pad”. // i.e. it remains as a fixed “pad”.
// scale_grad_by_freq (boolean, optional) // scale_grad_by_freq (boolean, optional)
// If given, this will scale gradients by the inverse of frequency of the // If given, this will scale gradients by the inverse of frequency of the
// words in the mini-batch. Default False. // words in the mini-batch. Default False.
@ -141,7 +139,7 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
Value output = gatherTensorAlongSingleAxis( Value output = gatherTensorAlongSingleAxis(
rewriter, op, weight, adaptor.getIndices(), 0, options.dimSizeIndexBits); rewriter, op, weight, adaptor.getIndices(), 0, options.dimSizeIndexBits);
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>( rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
op, getTypeConverter()->convertType(op.getType()), output); op, getTypeConverter()->convertType(op.getType()), output);
return success(); return success();
@ -163,7 +161,7 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
Value output = gatherTensorAlongSingleAxis( Value output = gatherTensorAlongSingleAxis(
rewriter, op, self, adaptor.getIndex(), dim, options.dimSizeIndexBits); rewriter, op, self, adaptor.getIndex(), dim, options.dimSizeIndexBits);
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>( rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
op, getTypeConverter()->convertType(op.getType()), output); op, getTypeConverter()->convertType(op.getType()), output);
return success(); return success();
@ -202,7 +200,7 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
auto options = getOptions(); auto options = getOptions();
auto indexShapeInfo = auto indexShapeInfo =
hlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits); mhlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits);
if (failed(indexShapeInfo)) { if (failed(indexShapeInfo)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "failed to get dim sizes of `index` param"); op, "failed to get dim sizes of `index` param");
@ -225,15 +223,15 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
SmallVector<Value> toConcat; SmallVector<Value> toConcat;
for (int64_t i = 0; i < inputType.getRank(); ++i) { for (int64_t i = 0; i < inputType.getRank(); ++i) {
if (i == dim) { if (i == dim) {
toConcat.push_back(rewriter.create<stablehlo::DynamicReshapeOp>( toConcat.push_back(rewriter.create<mhlo::DynamicReshapeOp>(
loc, toConcatIndexType, index, toConcatIndexShape)); loc, toConcatIndexType, index, toConcatIndexShape));
} else { } else {
toConcat.push_back(rewriter.create<stablehlo::DynamicIotaOp>( toConcat.push_back(rewriter.create<mhlo::DynamicIotaOp>(
loc, toConcatIndexType, toConcatIndexShape, loc, toConcatIndexType, toConcatIndexShape,
rewriter.getI64IntegerAttr(i))); rewriter.getI64IntegerAttr(i)));
} }
} }
auto gatherIndicies = rewriter.create<stablehlo::ConcatenateOp>( auto gatherIndicies = rewriter.create<mhlo::ConcatenateOp>(
loc, toConcat, static_cast<uint64_t>(inputType.getRank())); loc, toConcat, static_cast<uint64_t>(inputType.getRank()));
SmallVector<int64_t> sliceSizes(inputType.getRank(), 1); SmallVector<int64_t> sliceSizes(inputType.getRank(), 1);
@ -245,22 +243,22 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
startIndexMap.push_back(i); startIndexMap.push_back(i);
} }
auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get( auto dimsAttr = mhlo::GatherDimensionNumbersAttr::get(
rewriter.getContext(), rewriter.getContext(),
/*offsetDims=*/{}, /*offsetDims=*/{},
/*collapsedSliceDims=*/collapsedDims, /*collapsedSliceDims=*/collapsedDims,
/*startIndexMap=*/startIndexMap, /*startIndexMap=*/startIndexMap,
/*indexVecDim=*/indexVecDim); /*indexVecDim=*/indexVecDim);
rewriter.replaceOpWithNewOp<stablehlo::GatherOp>( rewriter.replaceOpWithNewOp<mhlo::GatherOp>(
op, input, gatherIndicies, dimsAttr, op, input, gatherIndicies, dimsAttr,
rewriter.getI64TensorAttr(sliceSizes)); rewriter.getI64TensorAttr(sliceSizes));
return success(); return success();
} }
void mlir::torch::torch_to_stablehlo::populateGatherOpPatternsAndLegality( void mlir::torch::torch_to_mhlo::populateGatherOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns, TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options) { ConversionTarget &target, const TorchToMhloOptions &options) {
MLIRContext *context = patterns.getContext(); MLIRContext *context = patterns.getContext();
#define INSERT_ATENOP_PATTERN(AtenOp) \ #define INSERT_ATENOP_PATTERN(AtenOp) \

View File

@ -7,16 +7,15 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "../PassDetail.h" #include "../PassDetail.h"
#include "PopulatePatterns.h" #include "./MhloLegalizeUtils.h"
#include "StablehloLegalizeUtils.h" #include "./PopulatePatterns.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@ -26,7 +25,7 @@
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
using namespace mlir::torch::Torch; using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_stablehlo; using namespace mlir::torch::torch_to_mhlo;
namespace { namespace {
Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor, Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
@ -34,7 +33,7 @@ Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
ArrayRef<int64_t> broadcastDims) { ArrayRef<int64_t> broadcastDims) {
auto tensorTy = tensor.getType().dyn_cast<RankedTensorType>(); auto tensorTy = tensor.getType().dyn_cast<RankedTensorType>();
auto loc = op->getLoc(); auto loc = op->getLoc();
Value stablehloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes); Value mhloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
RankedTensorType outTy = RankedTensorType outTy =
RankedTensorType::get(shape, tensorTy.getElementType()); RankedTensorType::get(shape, tensorTy.getElementType());
@ -44,8 +43,8 @@ Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
rewriter.getIntegerType(64)); rewriter.getIntegerType(64));
auto broadcastAttr = DenseIntElementsAttr::get(attrTy, broadcastDims); auto broadcastAttr = DenseIntElementsAttr::get(attrTy, broadcastDims);
auto broadcast = rewriter.create<stablehlo::DynamicBroadcastInDimOp>( auto broadcast = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
loc, outTy, tensor, stablehloShape, broadcastAttr); loc, outTy, tensor, mhloShape, broadcastAttr);
return broadcast; return broadcast;
} }
@ -53,7 +52,7 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input,
ArrayRef<int64_t> inpTransDims) { ArrayRef<int64_t> inpTransDims) {
auto inputTy = input.getType().dyn_cast<RankedTensorType>(); auto inputTy = input.getType().dyn_cast<RankedTensorType>();
auto rank = inputTy.getRank(); auto rank = inputTy.getRank();
auto transDims = hlo::toPositiveDims(inpTransDims, rank); auto transDims = mhlo::toPositiveDims(inpTransDims, rank);
auto inpShape = inputTy.getShape(); auto inpShape = inputTy.getShape();
std::vector<int64_t> newShape; std::vector<int64_t> newShape;
newShape.reserve(rank); newShape.reserve(rank);
@ -67,8 +66,8 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input,
auto permuteAttr = DenseIntElementsAttr::get(attrTy, transDims); auto permuteAttr = DenseIntElementsAttr::get(attrTy, transDims);
auto outTy = RankedTensorType::get(newShape, inputTy.getElementType()); auto outTy = RankedTensorType::get(newShape, inputTy.getElementType());
auto result = rewriter.create<stablehlo::TransposeOp>(op->getLoc(), outTy, auto result = rewriter.create<mhlo::TransposeOp>(op->getLoc(), outTy, input,
input, permuteAttr); permuteAttr);
return result.getResult(); return result.getResult();
} }
@ -120,12 +119,10 @@ RankedTensorType castContractingDim(PatternRewriter &rewriter, Operation *op,
} }
// set result dimensions // set result dimensions
if (lhsResultDim < static_cast<int64_t>(lhsShape.size()) && if (lhsResultDim < static_cast<int64_t>(lhsShape.size()) && lhsResultDim >= 0) {
lhsResultDim >= 0) {
outShape.push_back(lhsShape[lhsResultDim]); outShape.push_back(lhsShape[lhsResultDim]);
} }
if (rhsResultDim < static_cast<int64_t>(rhsShape.size()) && if (rhsResultDim < static_cast<int64_t>(rhsShape.size()) && rhsResultDim >= 0) {
rhsResultDim >= 0) {
outShape.push_back(rhsShape[rhsResultDim]); outShape.push_back(rhsShape[rhsResultDim]);
} }
return RankedTensorType::get(outShape, lhsTy.getElementType()); return RankedTensorType::get(outShape, lhsTy.getElementType());
@ -154,10 +151,10 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
std::vector<int64_t> newShape(rhsShape.begin(), std::vector<int64_t> newShape(rhsShape.begin(),
rhsShape.begin() + leadingRank); rhsShape.begin() + leadingRank);
newShape.insert(newShape.end(), lhsShape.begin(), lhsShape.end()); newShape.insert(newShape.end(), lhsShape.begin(), lhsShape.end());
auto newDimSizes = *hlo::getDimSizesOfTensor(rewriter, op, rhs, leadingDims, auto newDimSizes = *mhlo::getDimSizesOfTensor(
dimSizeIndexBits); rewriter, op, rhs, leadingDims, dimSizeIndexBits);
auto lhsDimSizes = auto lhsDimSizes =
*hlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits); *mhlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits);
newDimSizes.insert(newDimSizes.end(), lhsDimSizes.begin(), newDimSizes.insert(newDimSizes.end(), lhsDimSizes.begin(),
lhsDimSizes.end()); lhsDimSizes.end());
lhs = getBroadcastTensor(rewriter, op, lhs, newShape, newDimSizes, lhs = getBroadcastTensor(rewriter, op, lhs, newShape, newDimSizes,
@ -166,10 +163,10 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
std::vector<int64_t> newShape(lhsShape.begin(), std::vector<int64_t> newShape(lhsShape.begin(),
lhsShape.begin() + leadingRank); lhsShape.begin() + leadingRank);
newShape.insert(newShape.end(), rhsShape.begin(), rhsShape.end()); newShape.insert(newShape.end(), rhsShape.begin(), rhsShape.end());
auto newDimSizes = *hlo::getDimSizesOfTensor(rewriter, op, lhs, leadingDims, auto newDimSizes = *mhlo::getDimSizesOfTensor(
dimSizeIndexBits); rewriter, op, lhs, leadingDims, dimSizeIndexBits);
auto rhsDimSizes = auto rhsDimSizes =
*hlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits); *mhlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits);
newDimSizes.insert(newDimSizes.end(), rhsDimSizes.begin(), newDimSizes.insert(newDimSizes.end(), rhsDimSizes.begin(),
rhsDimSizes.end()); rhsDimSizes.end());
rhs = getBroadcastTensor(rewriter, op, rhs, newShape, newDimSizes, rhs = getBroadcastTensor(rewriter, op, rhs, newShape, newDimSizes,
@ -221,8 +218,8 @@ public:
if (lhsRank <= 2 && rhsRank <= 2) { if (lhsRank <= 2 && rhsRank <= 2) {
auto tensorType = auto tensorType =
ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType()); ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType());
output = rewriter.create<stablehlo::DotOp>(op->getLoc(), tensorType, lhs, output = rewriter.create<mhlo::DotOp>(op->getLoc(), tensorType, lhs, rhs,
rhs, nullptr); nullptr);
return success(); return success();
} }
@ -256,8 +253,8 @@ public:
lhsContractingDim = nBatchDims; lhsContractingDim = nBatchDims;
} }
stablehlo::DotDimensionNumbersAttr dotDimensionNumbers = mhlo::DotDimensionNumbersAttr dotDimensionNumbers =
stablehlo::DotDimensionNumbersAttr::get( mhlo::DotDimensionNumbersAttr::get(
rewriter.getContext(), rewriter.getContext(),
/*lhsBatchingDimensions=*/batchDims, /*lhsBatchingDimensions=*/batchDims,
/*rhsBatchingDimensions=*/batchDims, /*rhsBatchingDimensions=*/batchDims,
@ -267,7 +264,7 @@ public:
castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim, castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim,
lhsContractingDim, rhsContractingDim); lhsContractingDim, rhsContractingDim);
output = rewriter output = rewriter
.create<stablehlo::DotGeneralOp>(op->getLoc(), outTy, lhs, rhs, .create<mhlo::DotGeneralOp>(op->getLoc(), outTy, lhs, rhs,
dotDimensionNumbers, nullptr) dotDimensionNumbers, nullptr)
.getResult(); .getResult();
return success(); return success();
@ -315,7 +312,7 @@ public:
if (!lhsTy || !rhsTy) if (!lhsTy || !rhsTy)
return op.emitError( return op.emitError(
"only ranked tensor types are supported in StableHLO matmul"); "only ranked tensor types are supported in MHLO matmul");
return success(); return success();
} }
@ -338,7 +335,7 @@ public:
if (!lhsTy || !rhsTy) if (!lhsTy || !rhsTy)
return op.emitError( return op.emitError(
"only ranked tensor types are supported in StableHLO matmul"); "only ranked tensor types are supported in MHLO matmul");
auto lhsRank = lhsTy.getRank(); auto lhsRank = lhsTy.getRank();
auto rhsRank = rhsTy.getRank(); auto rhsRank = rhsTy.getRank();
@ -374,7 +371,7 @@ public:
if (!lhsTy || !rhsTy) if (!lhsTy || !rhsTy)
return op.emitError( return op.emitError(
"only ranked tensor types are supported in StableHLO matmul"); "only ranked tensor types are supported in MHLO matmul");
auto lhsRank = lhsTy.getRank(); auto lhsRank = lhsTy.getRank();
auto rhsRank = rhsTy.getRank(); auto rhsRank = rhsTy.getRank();
@ -401,10 +398,10 @@ public:
auto bias = adaptor.getBias(); auto bias = adaptor.getBias();
auto biasTy = bias.getType(); auto biasTy = bias.getType();
// StableHLO does not mandate that elementwise op tensors need to be ranked. // MHLO does not mandate that elementwise op tensors need to be ranked.
if (!biasTy.template isa<Torch::NoneType>() && if (!biasTy.template isa<Torch::NoneType>() &&
!biasTy.template isa<RankedTensorType>()) !biasTy.template isa<RankedTensorType>())
return op.emitError("only ranked tensor types are supported in StableHLO " return op.emitError("only ranked tensor types are supported in MHLO "
"matmul for bias tensor"); "matmul for bias tensor");
// weight.T // weight.T
@ -430,14 +427,14 @@ public:
auto outTy = auto outTy =
castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim, castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim,
lhsContractingDim, rhsContractingDim); lhsContractingDim, rhsContractingDim);
stablehlo::DotDimensionNumbersAttr dotDimensionNumbers = mhlo::DotDimensionNumbersAttr dotDimensionNumbers =
stablehlo::DotDimensionNumbersAttr::get( mhlo::DotDimensionNumbersAttr::get(
rewriter.getContext(), rewriter.getContext(),
/*lhsBatchingDimensions=*/batchDims, /*lhsBatchingDimensions=*/batchDims,
/*rhsBatchingDimensions=*/batchDims, /*rhsBatchingDimensions=*/batchDims,
/*lhsContractingDimensions=*/{lhsContractingDim}, /*lhsContractingDimensions=*/{lhsContractingDim},
/*rhsContractingDimensions=*/{rhsContractingDim}); /*rhsContractingDimensions=*/{rhsContractingDim});
Value matmulOutput = rewriter.create<stablehlo::DotGeneralOp>( Value matmulOutput = rewriter.create<mhlo::DotGeneralOp>(
op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr); op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr);
Value matmulPlusBias = matmulOutput; Value matmulPlusBias = matmulOutput;
@ -467,7 +464,7 @@ public:
auto weightElemTy = weightTy.getElementType(); auto weightElemTy = weightTy.getElementType();
auto rank = weightTy.getRank(); auto rank = weightTy.getRank();
const auto &options = getOptions(); const auto &options = getOptions();
SmallVector<Value> weightShapeVec = *hlo::getDimSizesOfTensor( SmallVector<Value> weightShapeVec = *mhlo::getDimSizesOfTensor(
rewriter, op, weight, options.dimSizeIndexBits); rewriter, op, weight, options.dimSizeIndexBits);
auto weightShape = weightTy.getShape(); auto weightShape = weightTy.getShape();
SmallVector<int64_t> weightShapeInt(rank); SmallVector<int64_t> weightShapeInt(rank);
@ -491,7 +488,7 @@ public:
} }
Value weightShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>( Value weightShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), weightShapeVec); op->getLoc(), weightShapeVec);
weight = rewriter.create<stablehlo::DynamicReshapeOp>( weight = rewriter.create<mhlo::DynamicReshapeOp>(
op->getLoc(), RankedTensorType::get(weightShapeInt, weightElemTy), op->getLoc(), RankedTensorType::get(weightShapeInt, weightElemTy),
weight, weightShapeTensor); weight, weightShapeTensor);
@ -500,7 +497,7 @@ public:
for (int64_t i = 0; i <= rank; i++) for (int64_t i = 0; i <= rank; i++)
transposeDims[i] = i; transposeDims[i] = i;
std::swap(transposeDims[1], transposeDims[0]); std::swap(transposeDims[1], transposeDims[0]);
weight = rewriter.create<stablehlo::TransposeOp>( weight = rewriter.create<mhlo::TransposeOp>(
op->getLoc(), weight, rewriter.getI64TensorAttr(transposeDims)); op->getLoc(), weight, rewriter.getI64TensorAttr(transposeDims));
// 3. [IC//G, G, OC, H, W, ...] => [IC//G, G*OC, H, W, ...] // 3. [IC//G, G, OC, H, W, ...] => [IC//G, G*OC, H, W, ...]
@ -512,7 +509,7 @@ public:
weightShapeVec[1] = OCMulGValue; weightShapeVec[1] = OCMulGValue;
weightShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>( weightShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), weightShapeVec); op->getLoc(), weightShapeVec);
weight = rewriter.create<stablehlo::DynamicReshapeOp>( weight = rewriter.create<mhlo::DynamicReshapeOp>(
op->getLoc(), RankedTensorType::get(weightShapeInt, weightElemTy), op->getLoc(), RankedTensorType::get(weightShapeInt, weightElemTy),
weight, weightShapeTensor); weight, weightShapeTensor);
return weight; return weight;
@ -547,27 +544,25 @@ public:
} }
// Prepare for transposed convolution // Prepare for transposed convolution
SmallVector<int64_t> stablehloStrideVec(nSpatialDims, 1); SmallVector<int64_t> mhloStrideVec(nSpatialDims, 1);
DenseIntElementsAttr stablehloStride = DenseIntElementsAttr mhloStride = rewriter.getI64TensorAttr(mhloStrideVec);
rewriter.getI64TensorAttr(stablehloStrideVec); SmallVector<int64_t> mhloPaddingVec(nSpatialDims * 2, 0);
SmallVector<int64_t> stablehloPaddingVec(nSpatialDims * 2, 0);
for (int i = 0; i < nSpatialDims; ++i) { for (int i = 0; i < nSpatialDims; ++i) {
int64_t padInt = dilation[i] * (weightShape[i + 2] - 1) - padding[i]; int64_t padInt = dilation[i] * (weightShape[i + 2] - 1) - padding[i];
stablehloPaddingVec[i * 2] = padInt; mhloPaddingVec[i * 2] = padInt;
stablehloPaddingVec[i * 2 + 1] = padInt; mhloPaddingVec[i * 2 + 1] = padInt;
} }
DenseIntElementsAttr stablehloPadding = DenseIntElementsAttr::get( DenseIntElementsAttr mhloPadding = DenseIntElementsAttr::get(
RankedTensorType::get({nSpatialDims, 2}, rewriter.getI64Type()), RankedTensorType::get({nSpatialDims, 2}, rewriter.getI64Type()),
stablehloPaddingVec); mhloPaddingVec);
SmallVector<int64_t> stablehloLhsDilationVec(nSpatialDims); SmallVector<int64_t> mhloLhsDilationVec(nSpatialDims);
std::copy(stride.begin(), stride.end(), stablehloLhsDilationVec.begin()); std::copy(stride.begin(), stride.end(), mhloLhsDilationVec.begin());
DenseIntElementsAttr stablehloLhsDilation = DenseIntElementsAttr mhloLhsDilation =
rewriter.getI64TensorAttr(stablehloLhsDilationVec); rewriter.getI64TensorAttr(mhloLhsDilationVec);
SmallVector<int64_t> stablehloRhsDilationVec(nSpatialDims); SmallVector<int64_t> mhloRhsDilationVec(nSpatialDims);
std::copy(dilation.begin(), dilation.end(), std::copy(dilation.begin(), dilation.end(), mhloRhsDilationVec.begin());
stablehloRhsDilationVec.begin()); DenseIntElementsAttr mhloRhsDilation =
DenseIntElementsAttr stablehloRhsDilation = rewriter.getI64TensorAttr(mhloRhsDilationVec);
rewriter.getI64TensorAttr(stablehloRhsDilationVec);
DenseElementsAttr windowReversal; DenseElementsAttr windowReversal;
ArrayAttr precisionConfig; ArrayAttr precisionConfig;
@ -576,8 +571,8 @@ public:
for (int i = 0; i < nSpatialDims; ++i) { for (int i = 0; i < nSpatialDims; ++i) {
spatialDims.push_back(i + 2); spatialDims.push_back(i + 2);
} }
stablehlo::ConvDimensionNumbersAttr dimensionNumbers = mhlo::ConvDimensionNumbersAttr dimensionNumbers =
stablehlo::ConvDimensionNumbersAttr::get( mhlo::ConvDimensionNumbersAttr::get(
/*context=*/rewriter.getContext(), /*inputBatchDimension=*/0, /*context=*/rewriter.getContext(), /*inputBatchDimension=*/0,
/*inputFeatureDimension=*/1, /*inputFeatureDimension=*/1,
/*inputSpatialDimensions=*/spatialDims, /*inputSpatialDimensions=*/spatialDims,
@ -588,18 +583,17 @@ public:
/*outputSpatialDimensions=*/spatialDims); /*outputSpatialDimensions=*/spatialDims);
// Reverse and transpose weight // Reverse and transpose weight
weight = rewriter.create<stablehlo::ReverseOp>( weight = rewriter.create<mhlo::ReverseOp>(
op->getLoc(), weight, rewriter.getI64TensorAttr(spatialDims)); op->getLoc(), weight, rewriter.getI64TensorAttr(spatialDims));
if (groups != 1) { if (groups != 1) {
weight = reshapeConvWeight(rewriter, op, weight, groups); weight = reshapeConvWeight(rewriter, op, weight, groups);
} }
// Create transposed convolution // Create transposed convolution
auto transposedConvOp = rewriter.create<stablehlo::ConvolutionOp>( auto transposedConvOp = rewriter.create<mhlo::ConvolutionOp>(
op->getLoc(), convOutTy, input, weight, stablehloStride, op->getLoc(), convOutTy, input, weight, mhloStride, mhloPadding,
stablehloPadding, stablehloLhsDilation, stablehloRhsDilation, mhloLhsDilation, mhloRhsDilation, windowReversal, dimensionNumbers,
windowReversal, dimensionNumbers, static_cast<uint64_t>(groups), 1, static_cast<uint64_t>(groups), 1, precisionConfig);
precisionConfig);
// Handle output padding // Handle output padding
if (!needHandleOutputPadding) { if (!needHandleOutputPadding) {
@ -611,8 +605,8 @@ public:
std::copy(outputPadding.begin(), outputPadding.end(), std::copy(outputPadding.begin(), outputPadding.end(),
edgePaddingHighVec.begin() + 2); edgePaddingHighVec.begin() + 2);
Value paddingValue = Value paddingValue =
hlo::getConstTensor<float>(rewriter, op, {0.0}, {}).value(); mhlo::getConstTensor<float>(rewriter, op, {0.0}, {}).value();
paddingValue = hlo::promoteType(rewriter, paddingValue, inputTy); paddingValue = mhlo::promoteType(rewriter, paddingValue, inputTy);
mlir::DenseIntElementsAttr edgePaddingLow = mlir::DenseIntElementsAttr edgePaddingLow =
rewriter.getI64VectorAttr(edgePaddingLowVec); rewriter.getI64VectorAttr(edgePaddingLowVec);
mlir::DenseIntElementsAttr edgePaddingHigh = mlir::DenseIntElementsAttr edgePaddingHigh =
@ -620,7 +614,7 @@ public:
mlir::DenseIntElementsAttr interiorPadding = mlir::DenseIntElementsAttr interiorPadding =
rewriter.getI64VectorAttr(interiorPaddingVec); rewriter.getI64VectorAttr(interiorPaddingVec);
auto paddedOutput = rewriter.create<stablehlo::PadOp>( auto paddedOutput = rewriter.create<mhlo::PadOp>(
op->getLoc(), outType, transposedConvOp, paddingValue, edgePaddingLow, op->getLoc(), outType, transposedConvOp, paddingValue, edgePaddingLow,
edgePaddingHigh, interiorPadding); edgePaddingHigh, interiorPadding);
@ -634,22 +628,22 @@ public:
ArrayRef<int64_t> dilation, int64_t groups) const { ArrayRef<int64_t> dilation, int64_t groups) const {
int64_t nDims = outType.getRank(); int64_t nDims = outType.getRank();
// Get stablehlo::ConvolutionOp attributes // Get mhlo::ConvolutionOp attributes
DenseIntElementsAttr stablehloWindowStride = DenseIntElementsAttr::get( DenseIntElementsAttr mhloWindowStride = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<long int>(stride.size())}, RankedTensorType::get({static_cast<long int>(stride.size())},
rewriter.getI64Type()), rewriter.getI64Type()),
stride); stride);
std::vector<int64_t> stablehloPaddingVec; std::vector<int64_t> mhloPaddingVec;
for (size_t i = 0; i < padding.size(); i++) { for (size_t i = 0; i < padding.size(); i++) {
stablehloPaddingVec.emplace_back(padding[i]); mhloPaddingVec.emplace_back(padding[i]);
stablehloPaddingVec.emplace_back(padding[i]); mhloPaddingVec.emplace_back(padding[i]);
} }
DenseIntElementsAttr stablehloPadding = DenseIntElementsAttr::get( DenseIntElementsAttr mhloPadding = DenseIntElementsAttr::get(
RankedTensorType::get( RankedTensorType::get(
{static_cast<long int>(padding.size()), static_cast<long int>(2)}, {static_cast<long int>(padding.size()), static_cast<long int>(2)},
rewriter.getI64Type()), rewriter.getI64Type()),
stablehloPaddingVec); mhloPaddingVec);
DenseIntElementsAttr stablehloRhsDilation = DenseIntElementsAttr::get( DenseIntElementsAttr mhloRhsDilation = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<long int>(dilation.size())}, RankedTensorType::get({static_cast<long int>(dilation.size())},
rewriter.getI64Type()), rewriter.getI64Type()),
dilation); dilation);
@ -657,8 +651,8 @@ public:
for (int64_t i = 2; i < nDims; i++) { for (int64_t i = 2; i < nDims; i++) {
spatialDimensions.emplace_back(i); spatialDimensions.emplace_back(i);
} }
stablehlo::ConvDimensionNumbersAttr dimensionNumbers = mhlo::ConvDimensionNumbersAttr dimensionNumbers =
stablehlo::ConvDimensionNumbersAttr::get( mhlo::ConvDimensionNumbersAttr::get(
/*context=*/rewriter.getContext(), /*inputBatchDimension=*/0, /*context=*/rewriter.getContext(), /*inputBatchDimension=*/0,
/*inputFeatureDimension=*/1, /*inputFeatureDimension=*/1,
/*inputSpatialDimensions=*/spatialDimensions, /*inputSpatialDimensions=*/spatialDimensions,
@ -668,18 +662,17 @@ public:
/*outputBatchDimension=*/0, /*outputFeatureDimension=*/1, /*outputBatchDimension=*/0, /*outputFeatureDimension=*/1,
/*outputSpatialDimensions=*/spatialDimensions); /*outputSpatialDimensions=*/spatialDimensions);
// stablehlo::ConvolutionOp's optional attributes, leave them as default // mhlo::ConvolutionOp's optional attributes, leave them as default
DenseIntElementsAttr stablehloLhsDilation; DenseIntElementsAttr mhloLhsDilation;
DenseElementsAttr windowReversal; DenseElementsAttr windowReversal;
ArrayAttr precisionConfig; ArrayAttr precisionConfig;
auto stablehloConvOp = rewriter.create<stablehlo::ConvolutionOp>( auto mhloConvOp = rewriter.create<mhlo::ConvolutionOp>(
op->getLoc(), outType, input, weight, stablehloWindowStride, op->getLoc(), outType, input, weight, mhloWindowStride, mhloPadding,
stablehloPadding, stablehloLhsDilation, stablehloRhsDilation, mhloLhsDilation, mhloRhsDilation, windowReversal, dimensionNumbers,
windowReversal, dimensionNumbers, static_cast<uint64_t>(groups), 1, static_cast<uint64_t>(groups), 1, precisionConfig);
precisionConfig);
return stablehloConvOp.getResult(); return mhloConvOp.getResult();
} }
LogicalResult LogicalResult
@ -761,22 +754,21 @@ public:
} }
} }
Value stablehloConvResult; Value mhloConvResult;
if (transposed) { if (transposed) {
stablehloConvResult = convertTransposedConv( mhloConvResult = convertTransposedConv(
op, rewriter, outTy, input, weight, stride, padding, dilation, op, rewriter, outTy, input, weight, stride, padding, dilation,
outputPadding, groups, needHandleOutputPadding); outputPadding, groups, needHandleOutputPadding);
} else { } else {
stablehloConvResult = mhloConvResult = convertNormalConv(op, rewriter, outTy, input, weight,
convertNormalConv(op, rewriter, outTy, input, weight, stride, padding, stride, padding, dilation, groups);
dilation, groups);
} }
auto bias = adaptor.getBias(); auto bias = adaptor.getBias();
// No bias provided // No bias provided
if (failed(checkNotNone(rewriter, op, op.getBias()))) { if (failed(checkNotNone(rewriter, op, op.getBias()))) {
rewriter.replaceOp(op, stablehloConvResult); rewriter.replaceOp(op, mhloConvResult);
return success(); return success();
} }
@ -798,21 +790,21 @@ public:
llvm::to_vector<4>(llvm::seq<int64_t>(-nSpatialDims, 0)); llvm::to_vector<4>(llvm::seq<int64_t>(-nSpatialDims, 0));
const auto &options = getOptions(); const auto &options = getOptions();
bias = *hlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims, bias = *mhlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims,
options.dimSizeIndexBits); options.dimSizeIndexBits);
bias = hlo::promoteType(rewriter, bias, outTy); bias = mhlo::promoteType(rewriter, bias, outTy);
DenseIntElementsAttr bcastDimensions; DenseIntElementsAttr bcastDimensions;
rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>( rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(op, outTy, mhloConvResult,
op, outTy, stablehloConvResult, bias, bcastDimensions); bias, bcastDimensions);
return success(); return success();
} }
}; };
} // namespace } // namespace
void mlir::torch::torch_to_stablehlo::populateLinearOpPatternsAndLegality( void mlir::torch::torch_to_mhlo::populateLinearOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns, TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options) { ConversionTarget &target, const TorchToMhloOptions &options) {
MLIRContext *context = patterns.getContext(); MLIRContext *context = patterns.getContext();
#define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \ #define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \

View File

@ -7,12 +7,11 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "StablehloLegalizeUtils.h" #include "./MhloLegalizeUtils.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include <numeric> #include <numeric>
@ -22,27 +21,27 @@ using namespace mlir::torch;
using namespace mlir::torch::Torch; using namespace mlir::torch::Torch;
namespace mlir { namespace mlir {
namespace hlo { namespace mhlo {
// Create a 32-bit float constant operator from a float // Create a 32-bit float constant operator from a float
Value getStablehloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op, Value getMhloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
float val) { float val) {
auto const_type = RankedTensorType::get({}, rewriter.getF32Type()); auto const_type = RankedTensorType::get({}, rewriter.getF32Type());
auto const_attr = DenseElementsAttr::get(const_type, val); auto const_attr = DenseElementsAttr::get(const_type, val);
auto const_op = rewriter.create<stablehlo::ConstantOp>( auto const_op =
op->getLoc(), const_type, const_attr); rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
return const_op.getResult(); return const_op.getResult();
} }
// Create a 64-bit float constant operator from a double // Create a 64-bit float constant operator from a double
Value getStablehloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op, Value getMhloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op,
double val) { double val) {
auto const_type = RankedTensorType::get({}, rewriter.getF64Type()); auto const_type = RankedTensorType::get({}, rewriter.getF64Type());
auto const_attr = DenseElementsAttr::get(const_type, val); auto const_attr = DenseElementsAttr::get(const_type, val);
auto const_op = rewriter.create<stablehlo::ConstantOp>( auto const_op =
op->getLoc(), const_type, const_attr); rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
return const_op.getResult(); return const_op.getResult();
} }
@ -66,8 +65,8 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
RankedTensorType::get(shape, rewriter.getIntegerType(sizeof(T) * 8)); RankedTensorType::get(shape, rewriter.getIntegerType(sizeof(T) * 8));
auto const_attr = DenseElementsAttr::get(const_type, vec); auto const_attr = DenseElementsAttr::get(const_type, vec);
auto const_op = rewriter.create<stablehlo::ConstantOp>( auto const_op =
op->getLoc(), const_type, const_attr); rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
return const_op.getResult(); return const_op.getResult();
} }
@ -89,8 +88,8 @@ std::optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,
shape, rewriter.getIntegerType(vec[0].getBitWidth())); shape, rewriter.getIntegerType(vec[0].getBitWidth()));
auto const_attr = DenseElementsAttr::get(const_type, vec); auto const_attr = DenseElementsAttr::get(const_type, vec);
auto const_op = rewriter.create<stablehlo::ConstantOp>( auto const_op =
op->getLoc(), const_type, const_attr); rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
return const_op.getResult(); return const_op.getResult();
} }
@ -112,8 +111,8 @@ std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
auto const_type = RankedTensorType::get(shape, rewriter.getF32Type()); auto const_type = RankedTensorType::get(shape, rewriter.getF32Type());
auto const_attr = DenseElementsAttr::get(const_type, vec); auto const_attr = DenseElementsAttr::get(const_type, vec);
auto const_op = rewriter.create<stablehlo::ConstantOp>( auto const_op =
op->getLoc(), const_type, const_attr); rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
return const_op.getResult(); return const_op.getResult();
} }
@ -134,8 +133,8 @@ std::optional<Value> getConstTensor<double>(PatternRewriter &rewriter,
auto const_type = RankedTensorType::get(shape, rewriter.getF64Type()); auto const_type = RankedTensorType::get(shape, rewriter.getF64Type());
auto const_attr = DenseElementsAttr::get(const_type, vec); auto const_attr = DenseElementsAttr::get(const_type, vec);
auto const_op = rewriter.create<stablehlo::ConstantOp>( auto const_op =
op->getLoc(), const_type, const_attr); rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
return const_op.getResult(); return const_op.getResult();
} }
@ -170,18 +169,18 @@ Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
T val, Type dtype, llvm::ArrayRef<int64_t> dshape) { T val, Type dtype, llvm::ArrayRef<int64_t> dshape) {
auto const_type = RankedTensorType::get(dshape, dtype); auto const_type = RankedTensorType::get(dshape, dtype);
auto const_attr = SplatElementsAttr::get(const_type, val); auto const_attr = SplatElementsAttr::get(const_type, val);
auto const_op = rewriter.create<stablehlo::ConstantOp>( auto const_op =
op->getLoc(), const_type, const_attr); rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
return const_op.getResult(); return const_op.getResult();
} }
Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter, Value scalarToMhloTensor(ConversionPatternRewriter &rewriter, Operation *op,
Operation *op, Value scalarValue, Type dtype) { Value scalarValue, Type dtype) {
auto tensor = rewriter.create<tensor::FromElementsOp>( auto tensor = rewriter.create<tensor::FromElementsOp>(
op->getLoc(), ArrayRef<Value>{scalarValue}); op->getLoc(), ArrayRef<Value>{scalarValue});
auto dtype_tensor = auto dtype_tensor =
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), tensor, dtype); rewriter.create<mhlo::ConvertOp>(op->getLoc(), tensor, dtype);
return rewriter.create<stablehlo::ReshapeOp>( return rewriter.create<mhlo::ReshapeOp>(
op->getLoc(), RankedTensorType::get(mlir::ArrayRef<int64_t>{}, dtype), op->getLoc(), RankedTensorType::get(mlir::ArrayRef<int64_t>{}, dtype),
dtype_tensor); dtype_tensor);
} }
@ -193,8 +192,7 @@ Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) {
if (in_type.getElementType() != outType.getElementType()) { if (in_type.getElementType() != outType.getElementType()) {
TensorType promotedType = TensorType promotedType =
in_type.cloneWith(in_type.getShape(), outType.getElementType()); in_type.cloneWith(in_type.getShape(), outType.getElementType());
return rewriter.create<stablehlo::ConvertOp>(op->getLoc(), promotedType, return rewriter.create<mhlo::ConvertOp>(op->getLoc(), promotedType, input);
input);
} }
return input; return input;
} }
@ -212,8 +210,8 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
if (in_type.getElementType() != outType.getElementType()) { if (in_type.getElementType() != outType.getElementType()) {
TensorType promoted_type = TensorType promoted_type =
in_type.cloneWith(in_type.getShape(), outType.getElementType()); in_type.cloneWith(in_type.getShape(), outType.getElementType());
input = rewriter.create<stablehlo::ConvertOp>(op->getLoc(), promoted_type, input =
input); rewriter.create<mhlo::ConvertOp>(op->getLoc(), promoted_type, input);
} }
ArrayRef<int64_t> inShape = in_type.getShape(); ArrayRef<int64_t> inShape = in_type.getShape();
@ -247,8 +245,8 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
RankedTensorType::get({static_cast<long int>(bcastDims.size())}, RankedTensorType::get({static_cast<long int>(bcastDims.size())},
rewriter.getI64Type()), rewriter.getI64Type()),
bcastDims); bcastDims);
auto bcast_op = rewriter.create<stablehlo::BroadcastInDimOp>( auto bcast_op = rewriter.create<mhlo::BroadcastInDimOp>(op->getLoc(), outType,
op->getLoc(), outType, input, bcast_attr); input, bcast_attr);
return bcast_op.getResult(); return bcast_op.getResult();
} }
@ -350,8 +348,8 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
} }
auto outTy = RankedTensorType::get(newShape, rankTy.getElementType()); auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
auto shape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes); auto mhloShape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
return rewriter.create<stablehlo::DynamicReshapeOp>(loc, outTy, tensor, shape) return rewriter.create<mhlo::DynamicReshapeOp>(loc, outTy, tensor, mhloShape)
.getResult(); .getResult();
} }
@ -359,11 +357,11 @@ Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
const APFloat &constant, Value shape, const APFloat &constant, Value shape,
TensorType outType) { TensorType outType) {
auto constAttr = rewriter.getFloatAttr(outType.getElementType(), constant); auto constAttr = rewriter.getFloatAttr(outType.getElementType(), constant);
auto constTensor = rewriter.create<stablehlo::ConstantOp>(loc, constAttr); auto constTensor = rewriter.create<mhlo::ConstantOp>(loc, constAttr);
return rewriter return rewriter
.create<stablehlo::DynamicBroadcastInDimOp>( .create<mhlo::DynamicBroadcastInDimOp>(loc, outType, constTensor, shape,
loc, outType, constTensor, shape, rewriter.getI64TensorAttr({})) rewriter.getI64TensorAttr({}))
.getResult(); .getResult();
} }
} // namespace hlo } // namespace mhlo
} // namespace mlir } // namespace mlir

View File

@ -7,8 +7,8 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_STABLEHLOLEGALIZEUTILS_H #ifndef TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H
#define TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_STABLEHLOLEGALIZEUTILS_H #define TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H
#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
@ -18,22 +18,22 @@
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
namespace mlir { namespace mlir {
namespace hlo { namespace mhlo {
using mlir::ConversionPatternRewriter; using mlir::ConversionPatternRewriter;
// Create a 32-bit float constant operator from a float // Create a 32-bit float constant operator from a float
Value getStablehloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op, Value getMhloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
float val); float val);
// Create a 64-bit float constant operator from a double // Create a 64-bit float constant operator from a double
Value getStablehloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op, Value getMhloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op,
double val); double val);
// Templated function to create a constant op for given type and shape. // Templated function to create a constant op for given type and shape.
// T: storage C type. // T: storage C type.
// Default template creates a constant tensor in T. // Default template creates a constant tensor in T.
// To create INT48 StableHLO constant, need to pass in llvm::APInt instead. // To create INT48 MHLO constant, need to pass in llvm::APInt instead.
template <typename T> template <typename T>
std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op, std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
ArrayRef<T> vec, ArrayRef<int64_t> shape); ArrayRef<T> vec, ArrayRef<int64_t> shape);
@ -42,8 +42,8 @@ template <typename T>
Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op, Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
T val, Type dtype, llvm::ArrayRef<int64_t> dshape); T val, Type dtype, llvm::ArrayRef<int64_t> dshape);
Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter, Value scalarToMhloTensor(ConversionPatternRewriter &rewriter, Operation *op,
Operation *op, Value scalarValue, Type dtype); Value scalarValue, Type dtype);
Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType); Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType);
@ -71,7 +71,7 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
Value getConstantOfShape(PatternRewriter &rewriter, Location loc, Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
const APFloat &constant, Value shape, const APFloat &constant, Value shape,
TensorType outType); TensorType outType);
} // namespace hlo } // namespace mhlo
} // namespace mlir } // namespace mlir
#endif // TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_STABLEHLOLEGALIZEUTILS_H #endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H

View File

@ -7,16 +7,15 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "../PassDetail.h" #include "../PassDetail.h"
#include "PopulatePatterns.h" #include "./MhloLegalizeUtils.h"
#include "StablehloLegalizeUtils.h" #include "./PopulatePatterns.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@ -29,25 +28,25 @@
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
using namespace mlir::torch::Torch; using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_stablehlo; using namespace mlir::torch::torch_to_mhlo;
static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy, static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
PatternRewriter &rewriter) { PatternRewriter &rewriter) {
auto constType = RankedTensorType::get({}, elementTy); auto constType = RankedTensorType::get({}, elementTy);
// Avg pooling // Avg pooling
if (isa<AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp, AtenCumsumOp>(op)) { if (isa<AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp>(op)) {
if (elementTy.isa<mlir::FloatType>()) { if (elementTy.isa<mlir::FloatType>()) {
auto constAttr = DenseElementsAttr::get( auto constAttr = DenseElementsAttr::get(
constType, {APFloat::getZero( constType, {APFloat::getZero(
elementTy.cast<mlir::FloatType>().getFloatSemantics(), elementTy.cast<mlir::FloatType>().getFloatSemantics(),
/*negative=*/false)}); /*negative=*/false)});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType, return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
constAttr); constAttr);
} else if (elementTy.isa<mlir::IntegerType>() && } else if (elementTy.isa<mlir::IntegerType>() &&
elementTy.getIntOrFloatBitWidth() != 8) { elementTy.getIntOrFloatBitWidth() != 8) {
auto constAttr = DenseElementsAttr::get( auto constAttr = DenseElementsAttr::get(
constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())}); constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType, return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
constAttr); constAttr);
} }
} }
@ -59,14 +58,14 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
constType, {APFloat::getLargest( constType, {APFloat::getLargest(
elementTy.cast<mlir::FloatType>().getFloatSemantics(), elementTy.cast<mlir::FloatType>().getFloatSemantics(),
/*negative=*/true)}); /*negative=*/true)});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType, return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
constAttr); constAttr);
} else if (elementTy.isa<mlir::IntegerType>() && } else if (elementTy.isa<mlir::IntegerType>() &&
elementTy.getIntOrFloatBitWidth() != 8) { elementTy.getIntOrFloatBitWidth() != 8) {
auto constAttr = DenseElementsAttr::get( auto constAttr = DenseElementsAttr::get(
constType, constType,
{APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())}); {APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType, return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
constAttr); constAttr);
} }
} }
@ -117,43 +116,42 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dOp>::matchAndRewrite(
// prepend 1 to kernelSize, stride, dilation until they are of same rank as // prepend 1 to kernelSize, stride, dilation until they are of same rank as
// input // input
SmallVector<int64_t> stablehloStride(inputRank, 1); SmallVector<int64_t> mhloStride(inputRank, 1);
SmallVector<int64_t> stablehloDilation(inputRank, 1); SmallVector<int64_t> mhloDilation(inputRank, 1);
SmallVector<int64_t> stablehloKernelSize(inputRank, 1); SmallVector<int64_t> mhloKernelSize(inputRank, 1);
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0); SmallVector<int64_t> mhloPadding(inputRank * 2, 0);
std::copy(dilation.begin(), dilation.end(), std::copy(dilation.begin(), dilation.end(),
stablehloDilation.begin() + inputRank - 2); mhloDilation.begin() + inputRank - 2);
std::copy(stride.begin(), stride.end(), std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2);
stablehloStride.begin() + inputRank - 2);
std::copy(kernelSize.begin(), kernelSize.end(), std::copy(kernelSize.begin(), kernelSize.end(),
stablehloKernelSize.begin() + inputRank - 2); mhloKernelSize.begin() + inputRank - 2);
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
stablehloPadding[stablehloPadding.size() - 4] = padding[0]; mhloPadding[mhloPadding.size() - 4] = padding[0];
stablehloPadding[stablehloPadding.size() - 3] = padding[0]; mhloPadding[mhloPadding.size() - 3] = padding[0];
stablehloPadding[stablehloPadding.size() - 2] = padding[1]; mhloPadding[mhloPadding.size() - 2] = padding[1];
stablehloPadding[stablehloPadding.size() - 1] = padding[1]; mhloPadding[mhloPadding.size() - 1] = padding[1];
DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())}, RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())},
rewriter.getI64Type()), rewriter.getI64Type()),
stablehloKernelSize); mhloKernelSize);
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())}, RankedTensorType::get({static_cast<int64_t>(mhloStride.size())},
rewriter.getI64Type()), rewriter.getI64Type()),
stablehloStride); mhloStride);
DenseIntElementsAttr baseDilations; DenseIntElementsAttr baseDilations;
DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())}, RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())},
rewriter.getI64Type()), rewriter.getI64Type()),
stablehloDilation); mhloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get( DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get( RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)}, {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
rewriter.getI64Type()), rewriter.getI64Type()),
stablehloPadding); mhloPadding);
auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>( auto reduceWindowOp = rewriter.create<mhlo::ReduceWindowOp>(
op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides, op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
baseDilations, windowDilations, pad); baseDilations, windowDilations, pad);
@ -170,8 +168,8 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dOp>::matchAndRewrite(
OpBuilder::InsertionGuard guard(rewriter); OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block); rewriter.setInsertionPointToStart(&block);
Value result = Value result =
rewriter.create<stablehlo::MaxOp>(op->getLoc(), *firstArg, *secondArg); rewriter.create<mhlo::MaxOp>(op->getLoc(), *firstArg, *secondArg);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), result); rewriter.create<mhlo::ReturnOp>(op->getLoc(), result);
} }
rewriter.replaceOp(op, reduceWindowOp.getResults()); rewriter.replaceOp(op, reduceWindowOp.getResults());
@ -223,46 +221,45 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
// prepend 1 to kernelSize, stride, dilation until they are of same rank as // prepend 1 to kernelSize, stride, dilation until they are of same rank as
// input // input
SmallVector<int64_t> stablehloStride(inputRank, 1); SmallVector<int64_t> mhloStride(inputRank, 1);
SmallVector<int64_t> stablehloDilation(inputRank, 1); SmallVector<int64_t> mhloDilation(inputRank, 1);
SmallVector<int64_t> stablehloKernelSize(inputRank, 1); SmallVector<int64_t> mhloKernelSize(inputRank, 1);
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0); SmallVector<int64_t> mhloPadding(inputRank * 2, 0);
std::copy(dilation.begin(), dilation.end(), std::copy(dilation.begin(), dilation.end(),
stablehloDilation.begin() + inputRank - 2); mhloDilation.begin() + inputRank - 2);
std::copy(stride.begin(), stride.end(), std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2);
stablehloStride.begin() + inputRank - 2);
std::copy(kernelSize.begin(), kernelSize.end(), std::copy(kernelSize.begin(), kernelSize.end(),
stablehloKernelSize.begin() + inputRank - 2); mhloKernelSize.begin() + inputRank - 2);
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
stablehloPadding[stablehloPadding.size() - 4] = padding[0]; mhloPadding[mhloPadding.size() - 4] = padding[0];
stablehloPadding[stablehloPadding.size() - 3] = padding[0]; mhloPadding[mhloPadding.size() - 3] = padding[0];
stablehloPadding[stablehloPadding.size() - 2] = padding[1]; mhloPadding[mhloPadding.size() - 2] = padding[1];
stablehloPadding[stablehloPadding.size() - 1] = padding[1]; mhloPadding[mhloPadding.size() - 1] = padding[1];
DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())}, RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())},
rewriter.getI64Type()), rewriter.getI64Type()),
stablehloKernelSize); mhloKernelSize);
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())}, RankedTensorType::get({static_cast<int64_t>(mhloStride.size())},
rewriter.getI64Type()), rewriter.getI64Type()),
stablehloStride); mhloStride);
DenseIntElementsAttr baseDilations; DenseIntElementsAttr baseDilations;
DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())}, RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())},
rewriter.getI64Type()), rewriter.getI64Type()),
stablehloDilation); mhloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get( DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get( RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)}, {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
rewriter.getI64Type()), rewriter.getI64Type()),
stablehloPadding); mhloPadding);
const auto &options = getOptions(); const auto &options = getOptions();
auto inputShapeInfo = auto inputShapeInfo =
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
if (failed(inputShapeInfo)) { if (failed(inputShapeInfo)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input"); op, "failed to get dimension sizes of the input");
@ -292,7 +289,7 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
auto initIndexTensor = auto initIndexTensor =
rewriter rewriter
.create<stablehlo::DynamicIotaOp>( .create<mhlo::DynamicIotaOp>(
op->getLoc(), op->getLoc(),
RankedTensorType::get(initIndexShapeForType, RankedTensorType::get(initIndexShapeForType,
rewriter.getI64Type()), rewriter.getI64Type()),
@ -301,15 +298,15 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
auto indexTensor = auto indexTensor =
rewriter rewriter
.create<stablehlo::DynamicReshapeOp>( .create<mhlo::DynamicReshapeOp>(
op->getLoc(), op->getLoc(),
RankedTensorType::get(inputShape, rewriter.getI64Type()), RankedTensorType::get(inputShape, rewriter.getI64Type()),
initIndexTensor, inputShapeTensor) initIndexTensor, inputShapeTensor)
.getResult(); .getResult();
Value initIdx = hlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value(); Value initIdx = mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>( auto reduceWindowOp = rewriter.create<mhlo::ReduceWindowOp>(
op->getLoc(), mlir::TypeRange{outValTy, outIdxTy}, op->getLoc(), mlir::TypeRange{outValTy, outIdxTy},
mlir::ValueRange{input, indexTensor}, mlir::ValueRange{initVal, initIdx}, mlir::ValueRange{input, indexTensor}, mlir::ValueRange{initVal, initIdx},
windowDimensions, windowStrides, baseDilations, windowDilations, pad); windowDimensions, windowStrides, baseDilations, windowDilations, pad);
@ -329,43 +326,43 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
auto *secondValArg = std::next(firstIdxArg); auto *secondValArg = std::next(firstIdxArg);
auto *secondIdxArg = std::next(secondValArg); auto *secondIdxArg = std::next(secondValArg);
stablehlo::ComparisonTypeAttr compareTypeAttr; mhlo::ComparisonTypeAttr compareTypeAttr;
if (inputTy.getElementType().isa<mlir::FloatType>()) { if (inputTy.getElementType().isa<mlir::FloatType>()) {
compareTypeAttr = stablehlo::ComparisonTypeAttr::get( compareTypeAttr = mhlo::ComparisonTypeAttr::get(
rewriter.getContext(), stablehlo::ComparisonType::FLOAT); rewriter.getContext(), mhlo::ComparisonType::FLOAT);
} else if (inputTy.getElementType().isa<mlir::IntegerType>()) { } else if (inputTy.getElementType().isa<mlir::IntegerType>()) {
compareTypeAttr = stablehlo::ComparisonTypeAttr::get( compareTypeAttr = mhlo::ComparisonTypeAttr::get(
rewriter.getContext(), stablehlo::ComparisonType::SIGNED); rewriter.getContext(), mhlo::ComparisonType::SIGNED);
} }
stablehlo::ComparisonDirectionAttr compareGeDirectionAttr = mhlo::ComparisonDirectionAttr compareGeDirectionAttr =
stablehlo::ComparisonDirectionAttr::get( mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
rewriter.getContext(), stablehlo::ComparisonDirection::GE); mhlo::ComparisonDirection::GE);
stablehlo::ComparisonDirectionAttr compareEqDirectionAttr = mhlo::ComparisonDirectionAttr compareEqDirectionAttr =
stablehlo::ComparisonDirectionAttr::get( mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
rewriter.getContext(), stablehlo::ComparisonDirection::EQ); mhlo::ComparisonDirection::EQ);
{ {
OpBuilder::InsertionGuard guard(rewriter); OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block); rewriter.setInsertionPointToStart(&block);
Value compareGeResult = rewriter.create<stablehlo::CompareOp>( Value compareGeResult = rewriter.create<mhlo::CompareOp>(
op->getLoc(), compareResultType, *firstValArg, *secondValArg, op->getLoc(), compareResultType, *firstValArg, *secondValArg,
compareGeDirectionAttr, compareTypeAttr); compareGeDirectionAttr, compareTypeAttr);
Value retValResult = rewriter.create<stablehlo::SelectOp>( Value retValResult = rewriter.create<mhlo::SelectOp>(
op->getLoc(), compareGeResult, *firstValArg, *secondValArg); op->getLoc(), compareGeResult, *firstValArg, *secondValArg);
// Get smaller index if compared values are equal. // Get smaller index if compared values are equal.
Value compareEqResult = rewriter.create<stablehlo::CompareOp>( Value compareEqResult = rewriter.create<mhlo::CompareOp>(
op->getLoc(), compareResultType, *firstValArg, *secondValArg, op->getLoc(), compareResultType, *firstValArg, *secondValArg,
compareEqDirectionAttr, compareTypeAttr); compareEqDirectionAttr, compareTypeAttr);
Value minIdx = rewriter.create<stablehlo::MinOp>(op->getLoc(), *firstIdxArg, Value minIdx =
*secondIdxArg); rewriter.create<mhlo::MinOp>(op->getLoc(), *firstIdxArg, *secondIdxArg);
Value idxWithGeVal = rewriter.create<stablehlo::SelectOp>( Value idxWithGeVal = rewriter.create<mhlo::SelectOp>(
op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg); op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg);
Value retIdxResult = rewriter.create<stablehlo::SelectOp>( Value retIdxResult = rewriter.create<mhlo::SelectOp>(
op->getLoc(), compareEqResult, minIdx, idxWithGeVal); op->getLoc(), compareEqResult, minIdx, idxWithGeVal);
rewriter.create<stablehlo::ReturnOp>( rewriter.create<mhlo::ReturnOp>(
op->getLoc(), mlir::ValueRange{retValResult, retIdxResult}); op->getLoc(), mlir::ValueRange{retValResult, retIdxResult});
} }
@ -422,42 +419,41 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
// prepend 1 to kernelSize, stride, dilation until they are of same rank as // prepend 1 to kernelSize, stride, dilation until they are of same rank as
// input // input
SmallVector<int64_t> stablehloStride(inputRank, 1); SmallVector<int64_t> mhloStride(inputRank, 1);
SmallVector<int64_t> stablehloDilation(inputRank, 1); SmallVector<int64_t> mhloDilation(inputRank, 1);
SmallVector<int64_t> stablehloKernelSize(inputRank, 1); SmallVector<int64_t> mhloKernelSize(inputRank, 1);
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0); SmallVector<int64_t> mhloPadding(inputRank * 2, 0);
std::copy(stride.begin(), stride.end(), std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2);
stablehloStride.begin() + inputRank - 2);
std::copy(kernelSize.begin(), kernelSize.end(), std::copy(kernelSize.begin(), kernelSize.end(),
stablehloKernelSize.begin() + inputRank - 2); mhloKernelSize.begin() + inputRank - 2);
stablehloPadding[stablehloPadding.size() - 4] = padding[0]; mhloPadding[mhloPadding.size() - 4] = padding[0];
stablehloPadding[stablehloPadding.size() - 3] = padding[0]; mhloPadding[mhloPadding.size() - 3] = padding[0];
stablehloPadding[stablehloPadding.size() - 2] = padding[1]; mhloPadding[mhloPadding.size() - 2] = padding[1];
stablehloPadding[stablehloPadding.size() - 1] = padding[1]; mhloPadding[mhloPadding.size() - 1] = padding[1];
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())}, RankedTensorType::get({static_cast<int64_t>(mhloKernelSize.size())},
rewriter.getI64Type()), rewriter.getI64Type()),
stablehloKernelSize); mhloKernelSize);
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())}, RankedTensorType::get({static_cast<int64_t>(mhloStride.size())},
rewriter.getI64Type()), rewriter.getI64Type()),
stablehloStride); mhloStride);
DenseIntElementsAttr baseDilations; DenseIntElementsAttr baseDilations;
DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())}, RankedTensorType::get({static_cast<int64_t>(mhloDilation.size())},
rewriter.getI64Type()), rewriter.getI64Type()),
stablehloDilation); mhloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get( DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get( RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)}, {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
rewriter.getI64Type()), rewriter.getI64Type()),
stablehloPadding); mhloPadding);
auto reduceWindowSum = rewriter.create<stablehlo::ReduceWindowOp>( auto reduceWindowSum = rewriter.create<mhlo::ReduceWindowOp>(
op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides, op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
baseDilations, windowDilations, pad); baseDilations, windowDilations, pad);
@ -475,39 +471,39 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
rewriter.setInsertionPointToStart(&sumBlock); rewriter.setInsertionPointToStart(&sumBlock);
Value sumResult = Value sumResult =
rewriter.create<stablehlo::AddOp>(op->getLoc(), *firstArg, *secondArg); rewriter.create<mhlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult); rewriter.create<mhlo::ReturnOp>(op->getLoc(), sumResult);
} }
// Use kernel size as the divisor // Use kernel size as the divisor
if (countIncludePad) { if (countIncludePad) {
Value divisor = hlo::getConstTensor<int64_t>( Value divisor = mhlo::getConstTensor<int64_t>(
rewriter, op, {kernelSize[0] * kernelSize[1]}, {}) rewriter, op, {kernelSize[0] * kernelSize[1]}, {})
.value(); .value();
divisor = hlo::promoteType(rewriter, divisor, outTy); divisor = mhlo::promoteType(rewriter, divisor, outTy);
DenseIntElementsAttr bcastDimensions; DenseIntElementsAttr bcastDimensions;
rewriter.replaceOpWithNewOp<mlir::chlo::BroadcastDivOp>( rewriter.replaceOpWithNewOp<mlir::chlo::BroadcastDivOp>(
op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions); op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions);
return success(); return success();
} }
// Use another stablehlo.ReduceWindowOp to get the divisor // Use another mhlo.ReduceWindowOp to get the divisor
Value windowSizeConst = Value windowSizeConst =
hlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value(); mhlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
windowSizeConst = hlo::promoteType(rewriter, windowSizeConst, outTy); windowSizeConst = mhlo::promoteType(rewriter, windowSizeConst, outTy);
const auto &options = getOptions(); const auto &options = getOptions();
auto inputShapeVec = auto inputShapeVec =
*hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); *mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>( auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), inputShapeVec); op->getLoc(), inputShapeVec);
windowSizeConst = rewriter.create<stablehlo::DynamicBroadcastInDimOp>( windowSizeConst = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
op->getLoc(), op->getLoc(),
RankedTensorType::get(inputTy.getShape(), outTy.getElementType()), RankedTensorType::get(inputTy.getShape(), outTy.getElementType()),
windowSizeConst, inputShapeTensor, rewriter.getI64TensorAttr({})); windowSizeConst, inputShapeTensor, rewriter.getI64TensorAttr({}));
Value zero = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); Value zero = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
auto reduceWindowSize = rewriter.create<stablehlo::ReduceWindowOp>( auto reduceWindowSize = rewriter.create<mhlo::ReduceWindowOp>(
op->getLoc(), RankedTensorType::get(outShape, inputElemTy), op->getLoc(), RankedTensorType::get(outShape, inputElemTy),
windowSizeConst, zero, windowDimensions, windowStrides, baseDilations, windowSizeConst, zero, windowDimensions, windowStrides, baseDilations,
windowDilations, pad); windowDilations, pad);
@ -526,99 +522,18 @@ LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
rewriter.setInsertionPointToStart(&sizeBlock); rewriter.setInsertionPointToStart(&sizeBlock);
Value sumResult = Value sumResult =
rewriter.create<stablehlo::AddOp>(op->getLoc(), *firstArg, *secondArg); rewriter.create<mhlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult); rewriter.create<mhlo::ReturnOp>(op->getLoc(), sumResult);
} }
rewriter.replaceOpWithNewOp<stablehlo::DivOp>( rewriter.replaceOpWithNewOp<mhlo::DivOp>(
op, outTy, reduceWindowSum.getResult(0), reduceWindowSize.getResult(0)); op, outTy, reduceWindowSum.getResult(0), reduceWindowSize.getResult(0));
return success(); return success();
} }
// AtenCumsumOp void mlir::torch::torch_to_mhlo::populatePoolingOpPatternsAndLegality(
template <>
LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite(
AtenCumsumOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf();
auto inputTy = input.getType().cast<RankedTensorType>();
auto inputElemTy = inputTy.getElementType();
auto inputRank = inputTy.getRank();
auto inputShape = inputTy.getShape();
auto outTy =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
return rewriter.notifyMatchFailure(
op, "unimplemented: dim must be a constant int");
}
dim = toPositiveDim(dim, inputRank);
if (!isValidDim(dim, inputRank)) {
return rewriter.notifyMatchFailure(op, "dim is out of range");
}
if (inputTy.isDynamicDim(dim)) {
return rewriter.notifyMatchFailure(
op, "unimplemented: cumsum dim must be static");
}
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
stablehloKernelSize[dim] = inputShape[dim];
SmallVector<int64_t> stablehloStride(inputRank, 1);
SmallVector<int64_t> stablehloDilation(inputRank, 1);
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
stablehloPadding[dim * 2] = inputShape[dim] - 1;
DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
rewriter.getI64Type()),
stablehloKernelSize);
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
rewriter.getI64Type()),
stablehloStride);
DenseIntElementsAttr baseDilations;
DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
rewriter.getI64Type()),
stablehloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
rewriter.getI64Type()),
stablehloPadding);
auto reduceWindowSum = rewriter.create<stablehlo::ReduceWindowOp>(
op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
baseDilations, windowDilations, pad);
Block &sumBlock = reduceWindowSum.getBody().emplaceBlock();
// Add bb argument
auto blockArgumentType = RankedTensorType::get({}, inputElemTy);
sumBlock.addArgument(blockArgumentType, op->getLoc());
sumBlock.addArgument(blockArgumentType, op->getLoc());
auto *firstArg = sumBlock.args_begin();
auto *secondArg = std::next(firstArg);
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&sumBlock);
Value sumResult =
rewriter.create<stablehlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult);
}
rewriter.replaceOp(op, reduceWindowSum.getResults());
return success();
}
void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns, TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options) { ConversionTarget &target, const TorchToMhloOptions &options) {
MLIRContext *context = patterns.getContext(); MLIRContext *context = patterns.getContext();
target.addIllegalOp<AtenMaxPool2dOp>(); target.addIllegalOp<AtenMaxPool2dOp>();
patterns.add<ConvertAtenOp<AtenMaxPool2dOp>>(typeConverter, context, options); patterns.add<ConvertAtenOp<AtenMaxPool2dOp>>(typeConverter, context, options);
@ -627,6 +542,4 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
target.addIllegalOp<AtenMaxPool2dWithIndicesOp>(); target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
patterns.add<ConvertAtenOp<AtenMaxPool2dWithIndicesOp>>(typeConverter, patterns.add<ConvertAtenOp<AtenMaxPool2dWithIndicesOp>>(typeConverter,
context, options); context, options);
target.addIllegalOp<AtenCumsumOp>();
patterns.add<ConvertAtenOp<AtenCumsumOp>>(typeConverter, context, options);
} }

View File

@ -0,0 +1,74 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H
#define TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace torch {
namespace torch_to_mhlo {
struct TorchToMhloOptions {
bool enableStaticShape = false;
size_t dimSizeIndexBits = 64;
};
template <typename AtenOpT>
class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
public:
using OpAdaptor = typename AtenOpT::Adaptor;
ConvertAtenOp(TypeConverter &typeConverter, MLIRContext *context,
const TorchToMhloOptions &options)
: OpConversionPattern<AtenOpT>(typeConverter, context) {
this->options = options;
}
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
return rewriter.notifyMatchFailure(op, "haven't been implemented");
}
const TorchToMhloOptions &getOptions() const { return options; }
private:
TorchToMhloOptions options;
};
// Registration entry points, one per op family. Each function adds its
// conversion patterns to `patterns` and marks the corresponding Torch ops
// illegal on `target`, honoring the given conversion `options`.
// Definitions live in the per-family .cpp files of this conversion.

// Elementwise / miscellaneous basic ops.
void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target,
const TorchToMhloOptions &options);
// View-like ops (reshape, slice, etc.).
void populateViewLikeOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target,
const TorchToMhloOptions &options);
// Gather/scatter-style indexing ops.
void populateGatherOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target,
const TorchToMhloOptions &options);
// Reduction ops (sum, max, argmax, norms, ...).
void populateReductionOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target,
const TorchToMhloOptions &options);
// Linear-algebra ops (matmul, linear, ...).
void populateLinearOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target,
const TorchToMhloOptions &options);
// Pooling ops (max/avg pool and variants).
void populatePoolingOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target,
const TorchToMhloOptions &options);
} // namespace torch_to_mhlo
} // namespace torch
} // namespace mlir
#endif // TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H

View File

@ -7,15 +7,14 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "../PassDetail.h" #include "../PassDetail.h"
#include "PopulatePatterns.h" #include "./MhloLegalizeUtils.h"
#include "StablehloLegalizeUtils.h" #include "./PopulatePatterns.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@ -26,7 +25,7 @@
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
using namespace mlir::torch::Torch; using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_stablehlo; using namespace mlir::torch::torch_to_mhlo;
static Value createInitialValueForReduceOp(Operation *op, Type elementTy, static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
PatternRewriter &rewriter) { PatternRewriter &rewriter) {
@ -37,13 +36,13 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
constType, {APFloat::getZero( constType, {APFloat::getZero(
elementTy.cast<mlir::FloatType>().getFloatSemantics(), elementTy.cast<mlir::FloatType>().getFloatSemantics(),
/*negative=*/false)}); /*negative=*/false)});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType, return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
constAttr); constAttr);
} else if (elementTy.isa<mlir::IntegerType>() && } else if (elementTy.isa<mlir::IntegerType>() &&
elementTy.getIntOrFloatBitWidth() != 8) { elementTy.getIntOrFloatBitWidth() != 8) {
auto constAttr = DenseElementsAttr::get( auto constAttr = DenseElementsAttr::get(
constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())}); constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType, return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
constAttr); constAttr);
} }
} }
@ -54,14 +53,14 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
constType, {APFloat::getLargest( constType, {APFloat::getLargest(
elementTy.cast<mlir::FloatType>().getFloatSemantics(), elementTy.cast<mlir::FloatType>().getFloatSemantics(),
/*negative=*/true)}); /*negative=*/true)});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType, return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
constAttr); constAttr);
} else if (elementTy.isa<mlir::IntegerType>() && } else if (elementTy.isa<mlir::IntegerType>() &&
elementTy.getIntOrFloatBitWidth() != 8) { elementTy.getIntOrFloatBitWidth() != 8) {
auto constAttr = DenseElementsAttr::get( auto constAttr = DenseElementsAttr::get(
constType, constType,
{APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())}); {APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType, return rewriter.create<mhlo::ConstantOp>(op->getLoc(), constType,
constAttr); constAttr);
} }
} }
@ -91,9 +90,9 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
return std::nullopt; return std::nullopt;
Value initIndex; Value initIndex;
if (dimSizeIndexBits == 32) { if (dimSizeIndexBits == 32) {
initIndex = hlo::getConstTensor<int32_t>(rewriter, op, {0}, {}).value(); initIndex = mhlo::getConstTensor<int32_t>(rewriter, op, {0}, {}).value();
} else { } else {
initIndex = hlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value(); initIndex = mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
} }
DenseIntElementsAttr dimensions = DenseIntElementsAttr::get( DenseIntElementsAttr dimensions = DenseIntElementsAttr::get(
@ -101,13 +100,13 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>( auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), inputShapeVec); op->getLoc(), inputShapeVec);
auto indexTensor = rewriter.create<stablehlo::DynamicIotaOp>( auto indexTensor = rewriter.create<mhlo::DynamicIotaOp>(
op->getLoc(), op->getLoc(),
RankedTensorType::get(inputShape, RankedTensorType::get(inputShape,
rewriter.getIntegerType(dimSizeIndexBits)), rewriter.getIntegerType(dimSizeIndexBits)),
inputShapeTensor, static_cast<uint64_t>(dim)); inputShapeTensor, static_cast<uint64_t>(dim));
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>( auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
op->getLoc(), ValueRange{input, indexTensor}, op->getLoc(), ValueRange{input, indexTensor},
ValueRange{ ValueRange{
initValue, initValue,
@ -115,7 +114,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
}, },
dimensions); dimensions);
Block &block = stablehloReduceOp.getBody().emplaceBlock(); Block &block = mhloReduceOp.getBody().emplaceBlock();
// Add block arguments // Add block arguments
auto blockValArgumentType = auto blockValArgumentType =
@ -134,46 +133,46 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
auto *secondValArg = std::next(firstIdxArg); auto *secondValArg = std::next(firstIdxArg);
auto *secondIdxArg = std::next(secondValArg); auto *secondIdxArg = std::next(secondValArg);
stablehlo::ComparisonTypeAttr compareTypeAttr; mhlo::ComparisonTypeAttr compareTypeAttr;
if (inputTy.getElementType().isa<mlir::FloatType>()) { if (inputTy.getElementType().isa<mlir::FloatType>()) {
compareTypeAttr = stablehlo::ComparisonTypeAttr::get( compareTypeAttr = mhlo::ComparisonTypeAttr::get(
rewriter.getContext(), stablehlo::ComparisonType::FLOAT); rewriter.getContext(), mhlo::ComparisonType::FLOAT);
} else if (inputTy.getElementType().isa<mlir::IntegerType>()) { } else if (inputTy.getElementType().isa<mlir::IntegerType>()) {
compareTypeAttr = stablehlo::ComparisonTypeAttr::get( compareTypeAttr = mhlo::ComparisonTypeAttr::get(
rewriter.getContext(), stablehlo::ComparisonType::SIGNED); rewriter.getContext(), mhlo::ComparisonType::SIGNED);
} }
stablehlo::ComparisonDirectionAttr compareGeDirectionAttr = mhlo::ComparisonDirectionAttr compareGeDirectionAttr =
stablehlo::ComparisonDirectionAttr::get( mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
rewriter.getContext(), stablehlo::ComparisonDirection::GE); mhlo::ComparisonDirection::GE);
stablehlo::ComparisonDirectionAttr compareEqDirectionAttr = mhlo::ComparisonDirectionAttr compareEqDirectionAttr =
stablehlo::ComparisonDirectionAttr::get( mhlo::ComparisonDirectionAttr::get(rewriter.getContext(),
rewriter.getContext(), stablehlo::ComparisonDirection::EQ); mhlo::ComparisonDirection::EQ);
{ {
OpBuilder::InsertionGuard guard(rewriter); OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block); rewriter.setInsertionPointToStart(&block);
Value compareGeResult = rewriter.create<stablehlo::CompareOp>( Value compareGeResult = rewriter.create<mhlo::CompareOp>(
op->getLoc(), compareResultType, *firstValArg, *secondValArg, op->getLoc(), compareResultType, *firstValArg, *secondValArg,
compareGeDirectionAttr, compareTypeAttr); compareGeDirectionAttr, compareTypeAttr);
Value retValResult = rewriter.create<stablehlo::SelectOp>( Value retValResult = rewriter.create<mhlo::SelectOp>(
op->getLoc(), compareGeResult, *firstValArg, *secondValArg); op->getLoc(), compareGeResult, *firstValArg, *secondValArg);
// get smaller index value if compared nums are equal. // get smaller index value if compared nums are equal.
Value compareEqResult = rewriter.create<stablehlo::CompareOp>( Value compareEqResult = rewriter.create<mhlo::CompareOp>(
op->getLoc(), compareResultType, *firstValArg, *secondValArg, op->getLoc(), compareResultType, *firstValArg, *secondValArg,
compareEqDirectionAttr, compareTypeAttr); compareEqDirectionAttr, compareTypeAttr);
Value minIdx = rewriter.create<stablehlo::MinOp>(op->getLoc(), *firstIdxArg, Value minIdx =
*secondIdxArg); rewriter.create<mhlo::MinOp>(op->getLoc(), *firstIdxArg, *secondIdxArg);
Value idxWithGeVal = rewriter.create<stablehlo::SelectOp>( Value idxWithGeVal = rewriter.create<mhlo::SelectOp>(
op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg); op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg);
Value retIdxResult = rewriter.create<stablehlo::SelectOp>( Value retIdxResult = rewriter.create<mhlo::SelectOp>(
op->getLoc(), compareEqResult, minIdx, idxWithGeVal); op->getLoc(), compareEqResult, minIdx, idxWithGeVal);
rewriter.create<stablehlo::ReturnOp>( rewriter.create<mhlo::ReturnOp>(
op->getLoc(), mlir::ValueRange{retValResult, retIdxResult}); op->getLoc(), mlir::ValueRange{retValResult, retIdxResult});
} }
return stablehloReduceOp.getResults(); return mhloReduceOp.getResults();
} }
namespace { namespace {
@ -197,8 +196,7 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
auto inputTy = input.getType().template cast<RankedTensorType>(); auto inputTy = input.getType().template cast<RankedTensorType>();
if (!inputTy) { if (!inputTy) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
op, "only Tensor types supported in StableHLO");
} }
auto inputElemTy = inputTy.getElementType(); auto inputElemTy = inputTy.getElementType();
@ -211,7 +209,7 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
inputElemTy.getIntOrFloatBitWidth() == 8) { inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from " op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenArgmaxOp to StableHLO"); "AtenArgmaxOp to MHLO");
} }
int64_t dim; int64_t dim;
@ -230,14 +228,14 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
const auto &options = getOptions(); const auto &options = getOptions();
auto inputShapeInfo = auto inputShapeInfo =
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
if (failed(inputShapeInfo)) { if (failed(inputShapeInfo)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input"); op, "failed to get dimension sizes of the input");
} }
auto inputShapeVec = *inputShapeInfo; auto inputShapeVec = *inputShapeInfo;
auto stablehloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, auto mhloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, dim,
dim, options.dimSizeIndexBits) options.dimSizeIndexBits)
.value(); .value();
if (keepDim) { if (keepDim) {
@ -249,13 +247,13 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>( auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec); op->getLoc(), outShapeVec);
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>( rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
op, typeConverter->convertType(op.getType()), stablehloReduceResults[1], op, typeConverter->convertType(op.getType()), mhloReduceResults[1],
outShapeTensor); outShapeTensor);
return success(); return success();
} }
rewriter.replaceOp(op, stablehloReduceResults[1]); rewriter.replaceOp(op, mhloReduceResults[1]);
return success(); return success();
} }
} // namespace } // namespace
@ -269,8 +267,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
auto inputTy = input.getType().template dyn_cast<RankedTensorType>(); auto inputTy = input.getType().template dyn_cast<RankedTensorType>();
if (!inputTy) { if (!inputTy) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
op, "only Tensor types supported in StableHLO");
} }
auto inputElemTy = inputTy.getElementType(); auto inputElemTy = inputTy.getElementType();
if (!inputElemTy.isIntOrFloat()) { if (!inputElemTy.isIntOrFloat()) {
@ -282,7 +279,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
inputElemTy.getIntOrFloatBitWidth() == 8) { inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from " op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenMaxDimOp to StableHLO"); "AtenMaxDimOp to MHLO");
} }
RankedTensorType valResultType = getTypeConverter() RankedTensorType valResultType = getTypeConverter()
@ -311,14 +308,14 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
const auto &options = getOptions(); const auto &options = getOptions();
auto inputShapeInfo = auto inputShapeInfo =
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
if (failed(inputShapeInfo)) { if (failed(inputShapeInfo)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input"); op, "failed to get dimension sizes of the input");
} }
auto inputShapeVec = *inputShapeInfo; auto inputShapeVec = *inputShapeInfo;
auto stablehloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, auto mhloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, dim,
dim, options.dimSizeIndexBits) options.dimSizeIndexBits)
.value(); .value();
if (keepDim) { if (keepDim) {
@ -330,21 +327,15 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>( auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec); op->getLoc(), outShapeVec);
auto stablehloReduceValueResult = auto mhloReduceValueResult = rewriter.create<mhlo::DynamicReshapeOp>(
rewriter.create<stablehlo::DynamicReshapeOp>( op->getLoc(), valResultType, mhloReduceResults[0], outShapeTensor);
op->getLoc(), valResultType, stablehloReduceResults[0], auto mhloReduceIndexResult = rewriter.create<mhlo::DynamicReshapeOp>(
outShapeTensor); op->getLoc(), idxResultType, mhloReduceResults[1], outShapeTensor);
auto stablehloReduceIndexResult = rewriter.replaceOp(op, {mhloReduceValueResult, mhloReduceIndexResult});
rewriter.create<stablehlo::DynamicReshapeOp>(
op->getLoc(), idxResultType, stablehloReduceResults[1],
outShapeTensor);
rewriter.replaceOp(
op, {stablehloReduceValueResult, stablehloReduceIndexResult});
return success(); return success();
} }
rewriter.replaceOp(op, rewriter.replaceOp(op, {mhloReduceResults[0], mhloReduceResults[1]});
{stablehloReduceResults[0], stablehloReduceResults[1]});
return success(); return success();
} }
} // namespace } // namespace
@ -361,14 +352,12 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
->convertType(op.getType()) ->convertType(op.getType())
.template dyn_cast<RankedTensorType>(); .template dyn_cast<RankedTensorType>();
if (!inputTy) { if (!inputTy) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
op, "only Tensor types supported in StableHLO");
} }
if (inputTy.getElementType() != outTy.getElementType()) { if (inputTy.getElementType() != outTy.getElementType()) {
// Use output element type as computation type. // Use output element type as computation type.
auto dstElemTy = outTy.getElementType(); auto dstElemTy = outTy.getElementType();
input = input = rewriter.create<mhlo::ConvertOp>(op->getLoc(), input, dstElemTy);
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
inputTy = input.getType().dyn_cast<RankedTensorType>(); inputTy = input.getType().dyn_cast<RankedTensorType>();
} }
auto inputElemTy = inputTy.getElementType(); auto inputElemTy = inputTy.getElementType();
@ -381,7 +370,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
inputElemTy.getIntOrFloatBitWidth() == 8) { inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from " op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenSumOp to StableHLO"); "AtenSumOp to MHLO");
} }
SmallVector<int64_t> dims; SmallVector<int64_t> dims;
@ -390,14 +379,13 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
} }
Value initValue = Value initValue =
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue) if (!initValue) return failure();
return failure();
llvm::sort(dims.begin(), dims.end()); llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>( auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims)); op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
Block &block = stablehloReduceOp.getBody().emplaceBlock(); Block &block = mhloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
block.addArgument(blockArgumentTy, op->getLoc()); block.addArgument(blockArgumentTy, op->getLoc());
@ -409,13 +397,13 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
{ {
OpBuilder::InsertionGuard guard(rewriter); OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block); rewriter.setInsertionPointToStart(&block);
Value addResult = rewriter.create<stablehlo::AddOp>( Value addResult = rewriter.create<mhlo::AddOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult); rewriter.create<mhlo::ReturnOp>(op->getLoc(), addResult);
} }
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy, rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy,
stablehloReduceOp.getResults()); mhloReduceOp.getResults());
return success(); return success();
} }
} // namespace } // namespace
@ -429,8 +417,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
auto inputTy = input.getType().dyn_cast<RankedTensorType>(); auto inputTy = input.getType().dyn_cast<RankedTensorType>();
if (!inputTy) { if (!inputTy) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
op, "only Tensor types supported in StableHLO");
} }
auto inputElemTy = inputTy.getElementType(); auto inputElemTy = inputTy.getElementType();
if (!inputElemTy.isIntOrFloat()) { if (!inputElemTy.isIntOrFloat()) {
@ -442,7 +429,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
inputElemTy.getIntOrFloatBitWidth() == 8) { inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from " op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenMaxOp to StableHLO"); "AtenMaxOp to MHLO");
} }
SmallVector<int64_t> dims; SmallVector<int64_t> dims;
@ -452,13 +439,12 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
Value initValue = Value initValue =
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue) if (!initValue) return failure();
return failure();
llvm::sort(dims.begin(), dims.end()); llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>( auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims)); op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
Block &block = stablehloReduceOp.getBody().emplaceBlock(); Block &block = mhloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
block.addArgument(blockArgumentTy, op->getLoc()); block.addArgument(blockArgumentTy, op->getLoc());
@ -470,14 +456,14 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
{ {
OpBuilder::InsertionGuard guard(rewriter); OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block); rewriter.setInsertionPointToStart(&block);
Value maxResult = rewriter.create<stablehlo::MaxOp>( Value maxResult = rewriter.create<mhlo::MaxOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), maxResult); rewriter.create<mhlo::ReturnOp>(op->getLoc(), maxResult);
} }
rewriter.replaceOpWithNewOp<tensor::CastOp>( rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, getTypeConverter()->convertType(op.getType()), op, getTypeConverter()->convertType(op.getType()),
stablehloReduceOp.getResults()); mhloReduceOp.getResults());
return success(); return success();
} }
} // namespace } // namespace
@ -494,14 +480,12 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
->convertType(op.getType()) ->convertType(op.getType())
.template dyn_cast<RankedTensorType>(); .template dyn_cast<RankedTensorType>();
if (!inputTy) { if (!inputTy) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO");
op, "only Tensor types supported in StableHLO");
} }
if (inputTy.getElementType() != outTy.getElementType()) { if (inputTy.getElementType() != outTy.getElementType()) {
// Use output element type as computation type. // Use output element type as computation type.
auto dstElemTy = outTy.getElementType(); auto dstElemTy = outTy.getElementType();
input = input = rewriter.create<mhlo::ConvertOp>(op->getLoc(), input, dstElemTy);
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
inputTy = input.getType().dyn_cast<RankedTensorType>(); inputTy = input.getType().dyn_cast<RankedTensorType>();
} }
auto inputElemTy = inputTy.getElementType(); auto inputElemTy = inputTy.getElementType();
@ -515,7 +499,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
inputElemTy.getIntOrFloatBitWidth() == 8) { inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from " op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenSumDimIntListOp to StableHLO"); "AtenSumDimIntListOp to MHLO");
} }
SmallVector<int64_t> inputDims; SmallVector<int64_t> inputDims;
@ -541,14 +525,13 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
} }
Value initValue = Value initValue =
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue) if (!initValue) return failure();
return failure();
llvm::sort(dims.begin(), dims.end()); llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>( auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims)); op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
Region &region = stablehloReduceOp.getBody(); Region &region = mhloReduceOp.getBody();
Block &block = region.emplaceBlock(); Block &block = region.emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
@ -561,15 +544,15 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
{ {
OpBuilder::InsertionGuard guard(rewriter); OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block); rewriter.setInsertionPointToStart(&block);
Value addResult = rewriter.create<stablehlo::AddOp>( Value addResult = rewriter.create<mhlo::AddOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult); rewriter.create<mhlo::ReturnOp>(op->getLoc(), addResult);
} }
if (keepDim) { if (keepDim) {
const auto &options = getOptions(); const auto &options = getOptions();
auto outShapeInfo = auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input,
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); options.dimSizeIndexBits);
if (failed(outShapeInfo)) { if (failed(outShapeInfo)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input"); op, "failed to get dimension sizes of the input");
@ -584,27 +567,26 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
} }
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>( auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec); op->getLoc(), outShapeVec);
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>( rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), op, getTypeConverter()->convertType(op.getType()),
stablehloReduceOp.getResult(0), outShapeTensor); mhloReduceOp.getResult(0), outShapeTensor);
return success(); return success();
} }
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy, rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy,
stablehloReduceOp.getResults()); mhloReduceOp.getResults());
return success(); return success();
} }
} // namespace } // namespace
// AtenFrobeniusNormDimOp // AtenFrobeniusNormDimOp
// aten.frobenius_norm.dim => stablehlo.reduce(calculate square sum along given // aten.frobenius_norm.dim => mhlo.reduce(calculate square sum along given dims)
// dims) // + mhlo.sqrt
// + stablehlo.sqrt
namespace { namespace {
template <> template <>
LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite( LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
AtenFrobeniusNormDimOp op, OpAdaptor adaptor, AtenFrobeniusNormDimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
const TorchToStablehloOptions &options = getOptions(); const TorchToMhloOptions &options = getOptions();
Value input = adaptor.getSelf(); Value input = adaptor.getSelf();
auto inputType = input.getType().dyn_cast<RankedTensorType>(); auto inputType = input.getType().dyn_cast<RankedTensorType>();
@ -642,57 +624,58 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
op, "non-const bool `keepdim` is not supported"); op, "non-const bool `keepdim` is not supported");
} }
auto squareOp = rewriter.create<stablehlo::MulOp>(op->getLoc(), input, input);
auto initValue = createInitialValueForReduceOp(op, inputElemType, rewriter); auto initValue = createInitialValueForReduceOp(op, inputElemType, rewriter);
if (!initValue) { if (!initValue) {
return failure(); return failure();
} }
auto reduceOp = rewriter.create<stablehlo::ReduceOp>( auto squareSumReduceOp = rewriter.create<mhlo::ReduceOp>(
op->getLoc(), squareOp.getResult(), initValue, op->getLoc(), input, initValue, rewriter.getI64TensorAttr(dims));
rewriter.getI64TensorAttr(dims));
Region &region = reduceOp.getBody(); Region &region = squareSumReduceOp.getBody();
Block &block = region.emplaceBlock(); Block &block = region.emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputElemType); auto blockArgumentTy = RankedTensorType::get({}, inputElemType);
block.addArgument(blockArgumentTy, op->getLoc()); block.addArgument(blockArgumentTy, op->getLoc());
block.addArgument(blockArgumentTy, op->getLoc()); block.addArgument(blockArgumentTy, op->getLoc());
auto firstArgument = *block.args_begin(); auto *firstArgument = block.args_begin();
auto secondArgument = *block.args_rbegin(); auto secondArgument = block.args_rbegin();
{ {
OpBuilder::InsertionGuard guard(rewriter); OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block); rewriter.setInsertionPointToStart(&block);
auto addResult = rewriter.create<stablehlo::AddOp>( auto constantOrd2 = rewriter.create<mhlo::ConstantOp>(
op->getLoc(), firstArgument, secondArgument); op->getLoc(), blockArgumentTy,
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), addResult.getResult()); DenseElementsAttr::get(blockArgumentTy, llvm::ArrayRef<float>{2.0}));
auto abs = rewriter.create<mhlo::AbsOp>(op->getLoc(), *secondArgument);
auto squareResult = rewriter.create<mhlo::PowOp>(
op->getLoc(), abs, constantOrd2);
auto addResult = rewriter.create<mhlo::AddOp>(op->getLoc(), squareResult,
*firstArgument);
rewriter.create<mhlo::ReturnOp>(op->getLoc(), addResult.getResult());
} }
auto output = auto output = rewriter.create<mhlo::SqrtOp>(op->getLoc(),
rewriter.create<stablehlo::SqrtOp>(op->getLoc(), reduceOp.getResult(0)); squareSumReduceOp.getResult(0));
if (keepDim) { if (keepDim) {
auto outShapeInfo = auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
if (failed(outShapeInfo)) { if (failed(outShapeInfo)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input"); op, "failed to get dimension sizes of the input");
} }
auto outShapeVec = *outShapeInfo; auto outShapeVec = *outShapeInfo;
auto one = rewriter.create<mlir::arith::ConstantOp>( auto one = rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(), op->getLoc(), rewriter.getIntegerAttr(
rewriter.getIntegerAttr(
rewriter.getIntegerType(options.dimSizeIndexBits), 1)); rewriter.getIntegerType(options.dimSizeIndexBits), 1));
for (int64_t i : dims) { for (int64_t i : dims) {
outShapeVec[i] = one; outShapeVec[i] = one;
} }
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>( auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec); op->getLoc(), outShapeVec);
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>( rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), output, op, getTypeConverter()->convertType(op.getType()), output,
outShapeTensor); outShapeTensor);
return success(); return success();
@ -702,9 +685,9 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
} }
} // namespace } // namespace
void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality( void mlir::torch::torch_to_mhlo::populateReductionOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns, TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options) { ConversionTarget &target, const TorchToMhloOptions &options) {
MLIRContext *context = patterns.getContext(); MLIRContext *context = patterns.getContext();
#define INSERT_ATEN_REDUCTION_OP_PATTERN(AtenOp) \ #define INSERT_ATEN_REDUCTION_OP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \ target.addIllegalOp<AtenOp>(); \

View File

@ -7,18 +7,17 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "../PassDetail.h" #include "../PassDetail.h"
#include "PopulatePatterns.h" #include "./PopulatePatterns.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h" #include "mlir/Dialect/Traits.h"
#include "mlir/IR/Matchers.h" #include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
@ -31,18 +30,17 @@ using namespace mlir::torch::Torch;
namespace { namespace {
class ConvertTorchToStablehlo class ConvertTorchToMhlo : public ConvertTorchToMhloBase<ConvertTorchToMhlo> {
: public ConvertTorchToStablehloBase<ConvertTorchToStablehlo> {
public: public:
ConvertTorchToStablehlo() = default; ConvertTorchToMhlo() = default;
ConvertTorchToStablehlo(bool enableStaticShape, bool enableI32Index) { ConvertTorchToMhlo(bool enableStaticShape, bool enableI32Index) {
this->enableStaticShape = enableStaticShape; this->enableStaticShape = enableStaticShape;
this->enableI32Index = enableI32Index; this->enableI32Index = enableI32Index;
} }
void getDependentDialects(DialectRegistry &registry) const override { void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<chlo::ChloDialect>(); registry.insert<chlo::ChloDialect>();
registry.insert<stablehlo::StablehloDialect>(); registry.insert<mhlo::MhloDialect>();
registry.insert<tensor::TensorDialect>(); registry.insert<tensor::TensorDialect>();
registry.insert<arith::ArithDialect>(); registry.insert<arith::ArithDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry); TorchConversion::getBackendTypeConversionDependentDialects(registry);
@ -50,7 +48,7 @@ public:
void runOnOperation() override { void runOnOperation() override {
MLIRContext *context = &getContext(); MLIRContext *context = &getContext();
ConversionTarget target(*context); ConversionTarget target(*context);
target.addLegalDialect<chlo::ChloDialect, stablehlo::StablehloDialect, target.addLegalDialect<chlo::ChloDialect, mhlo::MhloDialect,
tensor::TensorDialect, arith::ArithDialect>(); tensor::TensorDialect, arith::ArithDialect>();
TypeConverter typeConverter; TypeConverter typeConverter;
@ -59,20 +57,20 @@ public:
RewritePatternSet patterns(context); RewritePatternSet patterns(context);
torch_to_stablehlo::TorchToStablehloOptions options{ torch_to_mhlo::TorchToMhloOptions options{enableStaticShape,
enableStaticShape, enableI32Index ? 32u : 64u}; enableI32Index ? 32u : 64u};
torch_to_stablehlo::populateBasicOpPatternsAndLegality( torch_to_mhlo::populateBasicOpPatternsAndLegality(typeConverter, patterns,
target, options);
torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
typeConverter, patterns, target, options); typeConverter, patterns, target, options);
torch_to_stablehlo::populateViewLikeOpPatternsAndLegality( torch_to_mhlo::populateGatherOpPatternsAndLegality(typeConverter, patterns,
typeConverter, patterns, target, options); target, options);
torch_to_stablehlo::populateGatherOpPatternsAndLegality( torch_to_mhlo::populateReductionOpPatternsAndLegality(
typeConverter, patterns, target, options);
torch_to_stablehlo::populateReductionOpPatternsAndLegality(
typeConverter, patterns, target, options);
torch_to_stablehlo::populateLinearOpPatternsAndLegality(
typeConverter, patterns, target, options);
torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
typeConverter, patterns, target, options); typeConverter, patterns, target, options);
torch_to_mhlo::populateLinearOpPatternsAndLegality(typeConverter, patterns,
target, options);
torch_to_mhlo::populatePoolingOpPatternsAndLegality(typeConverter, patterns,
target, options);
if (failed(applyPartialConversion(getOperation(), target, if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) { std::move(patterns)))) {
@ -84,13 +82,13 @@ public:
} // namespace } // namespace
std::unique_ptr<OperationPass<func::FuncOp>> std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToStablehloPass() { mlir::torch::createConvertTorchToMhloPass() {
return std::make_unique<ConvertTorchToStablehlo>(false, false); return std::make_unique<ConvertTorchToMhlo>(false, false);
} }
std::unique_ptr<OperationPass<func::FuncOp>> std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToStablehloPass(bool enableStaticShape, mlir::torch::createConvertTorchToMhloPass(bool enableStaticShape,
bool enableI32Index) { bool enableI32Index) {
return std::make_unique<ConvertTorchToStablehlo>(enableStaticShape, return std::make_unique<ConvertTorchToMhlo>(enableStaticShape,
enableI32Index); enableI32Index);
} }

View File

@ -7,15 +7,14 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "../PassDetail.h" #include "../PassDetail.h"
#include "PopulatePatterns.h" #include "./MhloLegalizeUtils.h"
#include "StablehloLegalizeUtils.h" #include "./PopulatePatterns.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@ -29,7 +28,7 @@ using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
using namespace mlir::torch::Torch; using namespace mlir::torch::Torch;
using namespace mlir::torch::TorchConversion; using namespace mlir::torch::TorchConversion;
using namespace mlir::torch::torch_to_stablehlo; using namespace mlir::torch::torch_to_mhlo;
namespace { namespace {
// A dimension index from torch.dialect might outside the range [0, dimSize]. // A dimension index from torch.dialect might outside the range [0, dimSize].
@ -101,7 +100,7 @@ Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op,
auto stridesTensor = auto stridesTensor =
rewriter.create<tensor::FromElementsOp>(loc, strides).getResult(); rewriter.create<tensor::FromElementsOp>(loc, strides).getResult();
return rewriter.create<stablehlo::RealDynamicSliceOp>( return rewriter.create<mhlo::RealDynamicSliceOp>(
loc, outTy, input, startTensor, endTensor, stridesTensor); loc, outTy, input, startTensor, endTensor, stridesTensor);
} }
@ -145,7 +144,7 @@ FailureOr<Value> getDynamicSlice(PatternRewriter &rewriter, Operation *op,
step = rewriter.create<arith::TruncIOp>(loc, intType, step); step = rewriter.create<arith::TruncIOp>(loc, intType, step);
} }
FailureOr<SmallVector<Value, 4>> dimSizesInfo = FailureOr<SmallVector<Value, 4>> dimSizesInfo =
hlo::getDimSizesOfTensor(rewriter, op, input, dimSizeIndexBits); mhlo::getDimSizesOfTensor(rewriter, op, input, dimSizeIndexBits);
if (failed(dimSizesInfo)) if (failed(dimSizesInfo))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input"); op, "failed to get dimension sizes of the input");
@ -180,7 +179,7 @@ public:
auto loc = op.getLoc(); auto loc = op.getLoc();
auto newRank = dimSizes.size(); auto newRank = dimSizes.size();
if (newRank == 0 || rankType.getRank() == 0) { if (newRank == 0 || rankType.getRank() == 0) {
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>( rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
op, op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType( OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()), op.getType()),
@ -215,18 +214,17 @@ public:
numel); numel);
if (dimSizes.size() == 0) { if (dimSizes.size() == 0) {
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>( rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
op, op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType( OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()), op.getType()),
adaptor.getSelf()); adaptor.getSelf());
return success(); return success();
} }
Value stablehloShape = Value mhloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
rewriter.create<tensor::FromElementsOp>(loc, dimSizes); Value computedShape = rewriter.create<mhlo::ComputeReshapeShapeOp>(
Value computedShape = rewriter.create<stablehlo::ComputeReshapeShapeOp>( loc, mhloShape.getType(), numel, mhloShape);
loc, stablehloShape.getType(), numel, stablehloShape); rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
op, op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType( OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()), op.getType()),
@ -317,21 +315,21 @@ LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
dims.push_back(r); dims.push_back(r);
} }
if (dims.size() == 0) { if (dims.size() == 0) {
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>( rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), self); op, getTypeConverter()->convertType(op.getType()), self);
return success(); return success();
} }
auto newDimSizesInfo = hlo::getDimSizesOfTensor(rewriter, op, self, dims, auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims,
options.dimSizeIndexBits); options.dimSizeIndexBits);
if (failed(newDimSizesInfo)) if (failed(newDimSizesInfo))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input"); op, "failed to get dimension sizes of the input");
auto newDimSizes = *newDimSizesInfo; auto newDimSizes = *newDimSizesInfo;
auto stablehloShape = auto mhloShape =
rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes); rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>( rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), self, stablehloShape); op, getTypeConverter()->convertType(op.getType()), self, mhloShape);
return success(); return success();
} }
@ -367,20 +365,20 @@ LogicalResult ConvertAtenOp<AtenSqueezeDimOp>::matchAndRewrite(
std::iota(dims.begin(), dims.end(), 0); std::iota(dims.begin(), dims.end(), 0);
dims.erase(dims.begin() + dim); dims.erase(dims.begin() + dim);
if (dims.size() == 0) { if (dims.size() == 0) {
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>( rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), self); op, getTypeConverter()->convertType(op.getType()), self);
return success(); return success();
} }
auto newDimSizesInfo = hlo::getDimSizesOfTensor(rewriter, op, self, dims, auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims,
options.dimSizeIndexBits); options.dimSizeIndexBits);
if (failed(newDimSizesInfo)) if (failed(newDimSizesInfo))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input"); op, "failed to get dimension sizes of the input");
auto newDimSizes = *newDimSizesInfo; auto newDimSizes = *newDimSizesInfo;
auto stablehloShape = auto mhloShape =
rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes); rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>( rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), self, stablehloShape); op, getTypeConverter()->convertType(op.getType()), self, mhloShape);
return success(); return success();
} }
@ -397,7 +395,7 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return op->emitError("dim must be a Scalar constant"); return op->emitError("dim must be a Scalar constant");
auto unsqzTensorInfo = hlo::unsqueezeTensor(rewriter, op, adaptor.getSelf(), auto unsqzTensorInfo = mhlo::unsqueezeTensor(rewriter, op, adaptor.getSelf(),
{dim}, options.dimSizeIndexBits); {dim}, options.dimSizeIndexBits);
if (failed(unsqzTensorInfo)) if (failed(unsqzTensorInfo))
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
@ -407,9 +405,9 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
return success(); return success();
} }
void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality( void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns, TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options) { ConversionTarget &target, const TorchToMhloOptions &options) {
MLIRContext *context = patterns.getContext(); MLIRContext *context = patterns.getContext();
#define INSERT_ATENOP_PATTERN(AtenOp) \ #define INSERT_ATENOP_PATTERN(AtenOp) \

View File

@ -1,29 +0,0 @@
add_mlir_conversion_library(TorchMLIRTorchToStablehlo
TorchToStablehlo.cpp
StablehloLegalizeUtils.cpp
Basic.cpp
Gather.cpp
Linear.cpp
ViewLike.cpp
Reduction.cpp
Pooling.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToStablehlo
DEPENDS
TorchMLIRConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRBufferTransforms
StablehloOps
TorchMLIRTorchDialect
TorchMLIRConversionUtils
)
torch_mlir_target_includes(TorchMLIRTorchToStablehlo)

View File

@ -1,69 +0,0 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_LIB_CONVERSION_TORCHTOSTABLEHLO_POPULATEPATTERNS_H
#define TORCHMLIR_LIB_CONVERSION_TORCHTOSTABLEHLO_POPULATEPATTERNS_H
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace torch {
namespace torch_to_stablehlo {
struct TorchToStablehloOptions {
bool enableStaticShape = false;
size_t dimSizeIndexBits = 64;
};
template <typename AtenOpT>
class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
public:
using OpAdaptor = typename AtenOpT::Adaptor;
ConvertAtenOp(TypeConverter &typeConverter, MLIRContext *context,
const TorchToStablehloOptions &options)
: OpConversionPattern<AtenOpT>(typeConverter, context) {
this->options = options;
}
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
return rewriter.notifyMatchFailure(op, "haven't been implemented");
}
const TorchToStablehloOptions &getOptions() const { return options; }
private:
TorchToStablehloOptions options;
};
void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target,
const TorchToStablehloOptions &options);
void populateViewLikeOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options);
void populateGatherOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options);
void populateReductionOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options);
void populateLinearOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options);
void populatePoolingOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options);
} // namespace torch_to_stablehlo
} // namespace torch
} // namespace mlir
#endif // TORCHMLIR_LIB_CONVERSION_TORCHTOSTABLEHLO_POPULATEPATTERNS_H

View File

@ -17,7 +17,6 @@
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
#include "mlir/IR/ValueRange.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
#include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h"
@ -27,9 +26,6 @@
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/Support/ErrorHandling.h"
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
@ -56,147 +52,6 @@ using namespace mlir::torch::TMTensor;
// that these patterns become mostly mechanical associations of // that these patterns become mostly mechanical associations of
// "aten.foo -> linalg.foo". // "aten.foo -> linalg.foo".
static Attribute getNumericLimit(PatternRewriter &rewriter, Type elementType,
bool getMin = true) {
auto bitWidth = elementType.getIntOrFloatBitWidth();
if (llvm::isa<mlir::IntegerType>(elementType)) {
if (getMin) {
return rewriter.getIntegerAttr(elementType,
APInt::getSignedMinValue(bitWidth));
} else {
return rewriter.getIntegerAttr(elementType,
APInt::getSignedMaxValue(bitWidth));
}
} else if (mlir::FloatType floatType =
llvm::dyn_cast<mlir::FloatType>(elementType)) {
return rewriter.getFloatAttr(
elementType,
APFloat::getLargest(floatType.getFloatSemantics(), getMin));
} else {
llvm_unreachable("Only float/integer types are supported!");
}
}
// This function will reformat the `index` and `src` from torch operations
// like `torch.scatter` or `torch.scatter_reduce` to match the expected
// input for the TMScatterOp. It will return the reformated `index` and `src`
// as a pair of mlir::Value that can be used as inputs for the TMScatterOp.
static std::pair<Value, Value>
convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc(PatternRewriter &rewriter,
Value indices, Value src,
int64_t dim) {
// Get information on types for inputs
RankedTensorType indexType = indices.getType().cast<RankedTensorType>();
RankedTensorType srcSelf = src.getType().cast<RankedTensorType>();
// Store location for insertions
Location loc = src.getLoc();
Value indexSize = getTensorSize(rewriter, loc, indices);
indexSize = castIntToIndex(rewriter, loc, indexSize);
SmallVector<Value> indexShape = getTensorSizes(rewriter, loc, indices);
Value cstOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
// We flatten the `src` values from (i, j, k, ...) -> (i * j * k * ...)
SmallVector<Value> indSliceShape({indexSize, cstOne});
Value indSlice =
createZeroInitTensor(rewriter, loc, indSliceShape, rewriter.getI32Type());
// New output shape will be equal to the product of the dimensions of the
// updates
SmallVector<Value> outputs(indexType.getRank(), indSlice);
outputs.push_back(createZeroInitTensor(rewriter, loc, {indexSize},
srcSelf.getElementType()));
SmallVector<Type> outputsType(indexType.getRank(), indSlice.getType());
outputsType.push_back(outputs[indexType.getRank()].getType());
// Create mapping over flattened iteration space
SmallVector<AffineExpr> indSliceExpr = {rewriter.getAffineDimExpr(0),
rewriter.getAffineConstantExpr(0)};
SmallVector<AffineMap> mapping(
indexType.getRank(), AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
indSliceExpr, src.getContext()));
// Mapping for updates
mapping.push_back(rewriter.getDimIdentityMap());
SmallVector<utils::IteratorType> iteratorTypes(
{utils::IteratorType::parallel});
// This function goes over the flattened iteration space of the `indices`
// and `src`. It will reconstruct the original induction variables based
// on the current flattened index. The flattened iteration space is required
// because TMTensorScatterOp expects a list of single element updates.
auto flattenedUpdates =
rewriter
.create<linalg::GenericOp>(
loc, outputsType, ValueRange(), outputs, mapping, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
SmallVector<Value> indexValues(indexType.getRank());
Value ind = b.create<linalg::IndexOp>(loc, 0);
for (int i = indexType.getRank() - 1; i >= 0; i--) {
indexValues[i] =
b.create<arith::RemSIOp>(loc, ind, indexShape[i]);
ind = b.create<arith::DivSIOp>(loc, ind, indexShape[i]);
}
// Extract the scatter index and update value
Value extractIndexValue =
b.create<tensor::ExtractOp>(loc, indices, indexValues);
Value extractSrcValue =
b.create<tensor::ExtractOp>(loc, src, indexValues);
SmallVector<Value> yieldVals;
for (Value v : indexValues) {
Value scalar = castIndexToInt64(b, loc, v);
yieldVals.push_back(b.create<arith::TruncIOp>(
loc, rewriter.getI32Type(), scalar));
}
// Replace the original index with the index specified
// by the scatter.
yieldVals[dim] = b.create<arith::TruncIOp>(
loc, rewriter.getI32Type(), extractIndexValue);
yieldVals.push_back(extractSrcValue);
b.create<linalg::YieldOp>(loc, yieldVals);
})
.getResultTensors();
auto toOpFoldResult = [](Value v) -> OpFoldResult {
auto op = v.getDefiningOp<arith::ConstantIndexOp>();
if (!op)
return v;
return op.getValue();
};
// The result of the linalg::Generic operation gives us (rank(`src`) + 1)
// 1D-tensors where each contains a number of elements equal to the total
// number of elements in the `src` tensor. The indices must now be
// constructed by concatanating the first rank(`src`) tensors together. The
// new `src` tensor is the last tensor returned from the linalg::Generic
// operation.
SmallVector<Value> offsets = {
rewriter.create<arith::ConstantIndexOp>(loc, 0),
rewriter.create<arith::ConstantIndexOp>(loc, 0)};
SmallVector<Value> strides = {
rewriter.create<arith::ConstantIndexOp>(loc, 1),
rewriter.create<arith::ConstantIndexOp>(loc, 1)};
Value indicesRank =
rewriter.create<arith::ConstantIndexOp>(loc, indexType.getRank());
Value flattenedIndices = createZeroInitTensor(
rewriter, loc, SmallVector<Value>({indexSize, indicesRank}),
rewriter.getI32Type());
SmallVector<Value> scatterInputsVector(flattenedUpdates);
for (auto const slice : ArrayRef(scatterInputsVector).drop_back()) {
SmallVector<Value> sizes = getTensorSizes(rewriter, loc, slice);
flattenedIndices = rewriter.createOrFold<tensor::InsertSliceOp>(
loc, slice, flattenedIndices,
llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)),
llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)),
llvm::to_vector(llvm::map_range(strides, toOpFoldResult)));
// Increment offset to insert into next column
offsets[1] = rewriter.createOrFold<arith::AddIOp>(loc, offsets[1], cstOne);
}
return std::make_pair(flattenedIndices,
scatterInputsVector[indexType.getRank()]);
}
static Value createTMTensorScatterOp( static Value createTMTensorScatterOp(
OpBuilder &b, Location loc, Value updates, Value indices, Value original, OpBuilder &b, Location loc, Value updates, Value indices, Value original,
bool uniqueIndices, bool uniqueIndices,
@ -287,7 +142,7 @@ public:
// Finding the maximum value in the input tensor. // Finding the maximum value in the input tensor.
SmallVector<int64_t> maxTensorSizes; SmallVector<int64_t> maxTensorSizes;
ValueTensorType maxTensorType = ValueTensorType::get( ValueTensorType maxTensorType = ValueTensorType::get(
context, llvm::ArrayRef(maxTensorSizes), context, llvm::makeArrayRef(maxTensorSizes),
torchTypeInput.getType().cast<ValueTensorType>().getDtype()); torchTypeInput.getType().cast<ValueTensorType>().getDtype());
Value maxTensor = Value maxTensor =
rewriter.create<AtenMaxOp>(loc, maxTensorType, torchTypeInput); rewriter.create<AtenMaxOp>(loc, maxTensorType, torchTypeInput);
@ -310,7 +165,7 @@ public:
SmallVector<int64_t> expandedInputSizes{ SmallVector<int64_t> expandedInputSizes{
makeShapeTorchCompatible(inputType.getShape())[0], 1}; makeShapeTorchCompatible(inputType.getShape())[0], 1};
ValueTensorType expandInputType = ValueTensorType::get( ValueTensorType expandInputType = ValueTensorType::get(
context, llvm::ArrayRef(expandedInputSizes), context, llvm::makeArrayRef(expandedInputSizes),
torchTypeInput.getType().cast<ValueTensorType>().getDtype()); torchTypeInput.getType().cast<ValueTensorType>().getDtype());
Value torchCstOne = rewriter.create<Torch::ConstantIntOp>( Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1)); loc, rewriter.getI64IntegerAttr(1));
@ -431,8 +286,8 @@ public:
auto indexTensorType = indexTensor.getType().cast<BaseTensorType>(); auto indexTensorType = indexTensor.getType().cast<BaseTensorType>();
int64_t indexTensorSize = indexTensorType.getSizes()[0]; int64_t indexTensorSize = indexTensorType.getSizes()[0];
SmallVector<int64_t> expandedIndexTensorSizes{indexTensorSize, 1}; SmallVector<int64_t> expandedIndexTensorSizes{indexTensorSize, 1};
ValueTensorType expandedIndexTensorType = ValueTensorType expandedIndexTensorType = ValueTensorType::get(
ValueTensorType::get(context, llvm::ArrayRef(expandedIndexTensorSizes), context, llvm::makeArrayRef(expandedIndexTensorSizes),
indexTensorType.getDtype()); indexTensorType.getDtype());
Value torchCstOne = rewriter.create<Torch::ConstantIntOp>( Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1)); loc, rewriter.getI64IntegerAttr(1));
@ -697,229 +552,6 @@ public:
}; };
} // namespace } // namespace
namespace {
class ConvertAtenScatterReduceTwoOp
: public OpConversionPattern<AtenScatterReduceTwoOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenScatterReduceTwoOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
RankedTensorType selfType =
adaptor.getSelf().getType().cast<RankedTensorType>();
RankedTensorType indexType =
adaptor.getIndex().getType().cast<RankedTensorType>();
RankedTensorType srcType =
adaptor.getSrc().getType().cast<RankedTensorType>();
Value self = adaptor.getSelf();
if (selfType.getRank() != indexType.getRank() ||
indexType.getRank() != srcType.getRank())
return rewriter.notifyMatchFailure(op,
"'self', 'index' and 'src' should all "
"have the same number of dimensions.");
std::string reduceType;
if (!matchPattern(op.getReduce(), m_TorchConstantStr(reduceType)))
return rewriter.notifyMatchFailure(op,
"'reduce' must be a costant string");
int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(op, "'dim' is not constant");
bool includeSelf;
if (!matchPattern(op.getIncludeSelf(), m_TorchConstantBool(&includeSelf)))
return rewriter.notifyMatchFailure(op, "'include_self' is not constant");
// Get reduce string as the equivalent enum
auto reduceEnum = torch_upstream::get_reduction_enum(reduceType);
// Get the inputs reformatted for the TMScatterOp
auto [indices, updates] =
convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc(
rewriter, adaptor.getIndex(), adaptor.getSrc(), dim);
// Value 'counts' will be used to tally the number of reductions into
// each unique index. The tally is used to calculate the average of the
// values scattered per index.
Value counts = nullptr;
if (reduceEnum == torch_upstream::ReductionType::MEAN) {
SmallVector<Value> selfShape =
getTensorSizes(rewriter, loc, adaptor.getSelf());
Attribute initAttr;
if (llvm::isa<mlir::FloatType>(srcType.getElementType())) {
initAttr = rewriter.getFloatAttr(srcType.getElementType(), 1);
} else if (llvm::isa<mlir::IntegerType>(srcType.getElementType())) {
initAttr = rewriter.getIntegerAttr(srcType.getElementType(), 1);
} else {
llvm_unreachable("Only integer/float types supported!");
}
Value initElement = rewriter.create<arith::ConstantOp>(loc, initAttr);
counts = createInitTensor(rewriter, loc, selfShape,
selfType.getElementType(), initElement);
}
// If the original values shouldn't be included, normalize the
// input tensor where the scatters take place.
if (!includeSelf) {
Value normalizationValue;
if (reduceEnum == torch_upstream::ReductionType::SUM ||
reduceEnum == torch_upstream::ReductionType::MEAN) {
// Set the values in the input tensor to '0' so they are not included
normalizationValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(srcType.getElementType()));
} else if (reduceEnum == torch_upstream::ReductionType::PROD) {
// Set the values in the input tensor to '1' (multiplication identity)
if (llvm::isa<mlir::FloatType>(srcType.getElementType())) {
normalizationValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getFloatAttr(srcType.getElementType(), 1.0));
} else if (llvm::isa<mlir::IntegerType>(srcType.getElementType())) {
normalizationValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(srcType.getElementType(), 1));
} else {
llvm_unreachable("Only integer/float types supported!");
}
} else if (reduceEnum == torch_upstream::ReductionType::MAX) {
// Set the values in the input tensor to the smallest element of that
// type
auto minAttr = getNumericLimit(rewriter, srcType.getElementType(),
/*getMin=*/true);
normalizationValue = rewriter.create<arith::ConstantOp>(loc, minAttr);
} else if (reduceEnum == torch_upstream::ReductionType::MIN) {
// Set the values in the input tensor to the largest element of that
// type
auto maxAttr = getNumericLimit(rewriter, srcType.getElementType(),
/*getMin=*/false);
normalizationValue = rewriter.create<arith::ConstantOp>(loc, maxAttr);
}
// Scatter the normalizations into the input tensor
Value indexSize = getTensorSize(rewriter, loc, adaptor.getIndex());
indexSize = castIntToIndex(rewriter, loc, indexSize);
Value normalizations = createInitTensor(
rewriter, loc, SmallVector<Value>({indexSize}),
srcType.getElementType(), /*init_element=*/normalizationValue);
self = createTMTensorScatterOp(
rewriter, loc, normalizations, indices, self,
/*uniqueIndices=*/false,
[&](OpBuilder &b, Location loc, Value update, Value current) {
b.create<TMTensor::YieldOp>(loc, update);
});
if (reduceEnum == torch_upstream::ReductionType::MEAN) {
counts = createTMTensorScatterOp(
rewriter, loc, normalizations, indices, counts,
/*uniqueIndices=*/false,
[&](OpBuilder &b, Location loc, Value update, Value current) {
b.create<TMTensor::YieldOp>(loc, update);
});
}
}
// Create final operation
Value scatterOp = createTMTensorScatterOp(
rewriter, loc, updates, indices, self,
/*uniqueIndices=*/false,
[&](OpBuilder &b, Location loc, Value update, Value current) {
Value result;
if (reduceEnum == torch_upstream::ReductionType::SUM ||
reduceEnum == torch_upstream::ReductionType::MEAN) {
if (update.getType().isa<mlir::IntegerType>()) {
result = b.create<arith::AddIOp>(loc, update, current);
} else if (update.getType().isa<mlir::FloatType>()) {
result = b.create<arith::AddFOp>(loc, update, current);
} else {
llvm_unreachable("Only integer/float types supported!");
}
} else if (reduceEnum == torch_upstream::ReductionType::PROD) {
if (update.getType().isa<mlir::IntegerType>()) {
result = b.create<arith::MulIOp>(loc, update, current);
} else if (update.getType().isa<mlir::FloatType>()) {
result = b.create<arith::MulFOp>(loc, update, current);
} else {
llvm_unreachable("Only integer/float types supported!");
}
} else if (reduceEnum == torch_upstream::ReductionType::MAX) {
if (update.getType().isa<mlir::IntegerType>()) {
result = b.create<arith::MaxSIOp>(loc, update, current);
} else if (update.getType().isa<mlir::FloatType>()) {
result = b.create<arith::MaxFOp>(loc, update, current);
} else {
llvm_unreachable("Only integer/float types supported!");
}
} else if (reduceEnum == torch_upstream::ReductionType::MIN) {
if (update.getType().isa<mlir::IntegerType>()) {
result = b.create<arith::MinSIOp>(loc, update, current);
} else if (update.getType().isa<mlir::FloatType>()) {
result = b.create<arith::MinFOp>(loc, update, current);
} else {
llvm_unreachable("Only integer/float types supported!");
}
}
b.create<TMTensor::YieldOp>(loc, result);
});
// Special case for the mean
if (reduceEnum == torch_upstream::ReductionType::MEAN) {
counts = createTMTensorScatterOp(
rewriter, loc, updates, indices, counts,
/*uniqueIndices=*/false,
[&](OpBuilder &b, Location loc, Value update, Value current) {
Value result;
if (mlir::IntegerType intType =
llvm::dyn_cast<mlir::IntegerType>(current.getType())) {
Value constantUpdate = b.create<arith::ConstantOp>(
loc, b.getIntegerAttr(intType, 1));
result = b.create<arith::AddIOp>(loc, constantUpdate, current);
} else if (mlir::FloatType floatType =
llvm::dyn_cast<mlir::FloatType>(current.getType())) {
Value constantUpdate = b.create<arith::ConstantOp>(
loc, b.getFloatAttr(floatType, 1.0));
result = b.create<arith::AddFOp>(loc, constantUpdate, current);
} else {
llvm_unreachable("Only integer/float types supported!");
}
b.create<TMTensor::YieldOp>(loc, result);
});
Value output = rewriter.create<tensor::EmptyOp>(
loc, tensor::getMixedSizes(rewriter, loc, self),
selfType.getElementType());
// Finally divide the result
scatterOp =
rewriter
.create<linalg::MapOp>(
loc, ValueRange{scatterOp, counts}, output,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value result;
if (llvm::isa<mlir::IntegerType>(args[0].getType())) {
result = b.create<arith::DivSIOp>(loc, args[0], args[1]);
} else if (llvm::isa<mlir::FloatType>(args[0].getType())) {
result = b.create<arith::DivFOp>(loc, args[0], args[1]);
} else {
llvm_unreachable("Only integer/float types supported!");
}
b.create<linalg::YieldOp>(loc, result);
})
.getResult()[0];
}
auto resultType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, scatterOp);
return success();
}
};
} // namespace
namespace { namespace {
class ConvertAtenCumsumOp : public OpConversionPattern<AtenCumsumOp> { class ConvertAtenCumsumOp : public OpConversionPattern<AtenCumsumOp> {
public: public:
@ -1012,8 +644,6 @@ public:
target.addIllegalOp<AtenMaxPool2dWithIndicesBackwardOp>(); target.addIllegalOp<AtenMaxPool2dWithIndicesBackwardOp>();
patterns.add<ConvertAtenMaxPool2dWithIndicesBackwardOp>(typeConverter, patterns.add<ConvertAtenMaxPool2dWithIndicesBackwardOp>(typeConverter,
context); context);
target.addIllegalOp<AtenScatterReduceTwoOp>();
patterns.add<ConvertAtenScatterReduceTwoOp>(typeConverter, context);
target.addIllegalOp<AtenCumsumOp>(); target.addIllegalOp<AtenCumsumOp>();
patterns.add<ConvertAtenCumsumOp>(typeConverter, context); patterns.add<ConvertAtenCumsumOp>(typeConverter, context);

View File

@ -10,7 +10,6 @@
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h" #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h" #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "../PassDetail.h" #include "../PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
@ -718,8 +717,8 @@ class ConvertAtenMultipleDimsReductionOp
"non-const dim parameter unsupported"); "non-const dim parameter unsupported");
int64_t N = reduceDims.size(); int64_t N = reduceDims.size();
auto reduceDimsType = RankedTensorType::get({N}, rewriter.getI64Type()); auto reduceDimsType = RankedTensorType::get({N}, rewriter.getI64Type());
reduceDimsAttr = reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType,
DenseIntElementsAttr::get(reduceDimsType, llvm::ArrayRef(reduceDims)); llvm::makeArrayRef(reduceDims));
keepDims = false; keepDims = false;
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDims))) if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDims)))
@ -748,8 +747,8 @@ class ConvertAtenOneDimReductionOp
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"non-const dim parameter unsupported"); "non-const dim parameter unsupported");
auto reduceDimsType = RankedTensorType::get({1}, rewriter.getI64Type()); auto reduceDimsType = RankedTensorType::get({1}, rewriter.getI64Type());
reduceDimsAttr = reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType,
DenseIntElementsAttr::get(reduceDimsType, llvm::ArrayRef({reduceDim})); llvm::makeArrayRef({reduceDim}));
keepDims = false; keepDims = false;
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDims))) if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDims)))
@ -782,8 +781,8 @@ public:
reduceDims.push_back(i); reduceDims.push_back(i);
int64_t N = selfTy.getRank(); int64_t N = selfTy.getRank();
auto reduceDimsType = RankedTensorType::get({N}, rewriter.getI64Type()); auto reduceDimsType = RankedTensorType::get({N}, rewriter.getI64Type());
reduceDimsAttr = reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType,
DenseIntElementsAttr::get(reduceDimsType, llvm::ArrayRef(reduceDims)); llvm::makeArrayRef(reduceDims));
keepDims = false; keepDims = false;
return success(); return success();
@ -2646,36 +2645,6 @@ LogicalResult ConvertAtenOp<AtenViewOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"size must consist of Scalar constants"); "size must consist of Scalar constants");
// the shape -1 is inferred from other dimensions
size_t countNegativeShape{0};
// Check at most one -1 shape
for (size_t i = 0; i < outShape.size(); i++) {
if (outShape[i] < 0) {
countNegativeShape++;
if (countNegativeShape > 1)
return rewriter.notifyMatchFailure(op, "At most one -1 shape");
}
}
auto inputShape = selfType.getShape();
size_t totalSize = 1;
for (size_t i = 0; i < inputShape.size(); i++) {
totalSize *= inputShape[i];
}
size_t otherSize = 1;
for (size_t i = 0; i < outShape.size(); i++) {
if (outShape[i] > 0) {
otherSize *= outShape[i];
}
}
for (size_t i = 0; i < outShape.size(); i++) {
if (outShape[i] < 0) {
outShape[i] = totalSize / otherSize;
break;
}
}
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>( rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(),
rewriter.getDenseI64ArrayAttr(outShape)); rewriter.getDenseI64ArrayAttr(outShape));
@ -2847,79 +2816,6 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
return success(); return success();
} }
template <>
LogicalResult ConvertAtenOp<AtenHardtanhBackwardOp>::matchAndRewrite(
AtenHardtanhBackwardOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Not a tensor type.
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
if (!selfType) {
return rewriter.notifyMatchFailure(
op, "Only tensor types are currently supported");
}
auto selfElemTy = selfType.getElementType();
if (!selfElemTy.isIntOrFloat()) {
return rewriter.notifyMatchFailure(
op, "Only floating-point or integer datatype legalization supported");
}
// Integer types with width > 32 are not supported
auto selfIntType = selfElemTy.dyn_cast<IntegerType>();
if (selfIntType && selfIntType.getWidth() > 32) {
return rewriter.notifyMatchFailure(
op, "Integer types with width greater than 32 are not supported");
}
Value gradOutput = adaptor.getGradOutput();
auto gradOutputType = adaptor.getSelf().getType().dyn_cast<TensorType>();
Type gradOutputElemType = gradOutputType.getElementType();
if (selfElemTy != gradOutputElemType) {
return rewriter.notifyMatchFailure(
op,
"Input element type should be same as the grad_output element type.");
}
SmallVector<int64_t> constTypeShape(selfType.getRank(), 1);
Value maxVal, minVal;
if (failed(torchScalarToTosaTensor(rewriter, op, op.getMinVal(), minVal,
selfElemTy, constTypeShape))) {
return rewriter.notifyMatchFailure(op, "Only scalar constant is supported");
}
if (failed(torchScalarToTosaTensor(rewriter, op, op.getMaxVal(), maxVal,
selfElemTy, constTypeShape))) {
return rewriter.notifyMatchFailure(op, "Only scalar constant is supported");
}
Value replace = tosa::getConstTensor<float>(rewriter, op, 0, {}).value();
Type outType = getTypeConverter()->convertType(op.getType());
Value lesser = rewriter.create<tosa::GreaterOp>(
op.getLoc(),
RankedTensorType::get(selfType.getShape(), rewriter.getIntegerType(1)),
minVal, adaptor.getSelf());
Value greater = rewriter.create<tosa::GreaterOp>(
op.getLoc(),
RankedTensorType::get(selfType.getShape(), rewriter.getIntegerType(1)),
adaptor.getSelf(), maxVal);
Value cmp = rewriter.create<tosa::LogicalOrOp>(
op.getLoc(),
RankedTensorType::get(selfType.getShape(), rewriter.getIntegerType(1)),
lesser, greater);
rewriter.replaceOpWithNewOp<tosa::SelectOp>(op, outType, cmp, replace,
gradOutput);
return success();
}
template <> template <>
LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
AtenEmbeddingOp op, OpAdaptor adaptor, AtenEmbeddingOp op, OpAdaptor adaptor,
@ -3217,71 +3113,32 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
op, "Only floating-point or integer datatype legalization supported"); op, "Only floating-point or integer datatype legalization supported");
} }
SmallVector<int64_t> resultShape; SmallVector<int64_t> outShape;
if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(resultShape))) if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(outShape)))
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"size must consist of Scalar constants"); "size must consist of Scalar constants");
// Get the result type
auto resultType = getTypeConverter()->convertType(op.getType());
SmallVector<int64_t> inputShape( SmallVector<int64_t> inputShape(
makeShapeTorchCompatible(selfType.getShape())); makeShapeTorchCompatible(selfType.getShape()));
if (inputShape.size() == outShape.size() || inputShape.size() == 0) {
// Check for identity case i.e, for ex: [a, b, c] -> [a, b, c]. If this is // Check for identity case i.e, for ex: [a, b, c] -> [a, b, c]. If this is
// true then we can replace the op result with the input operand directly. // true then we can replace the op result with the input operand
if (llvm::equal(inputShape, resultShape)) { // irrespective of the users of the op result.
// If we reach here, then it means that the broadcasting is not required if (!llvm::equal(inputShape, outShape)) {
// since the input and result are of same shape. for (auto user : op->getResult(0).getUsers()) {
// This case is only supported if the result of the `broadcast_to` op is
// not used by an op which is a view like.
if (isViewLikeOp(user)) {
return rewriter.notifyMatchFailure(
op, "unimplemented: broadcast not supported for this case");
}
}
}
// If we reach here, then it means the given case is handled by implicit
// broadcasting done by tosa.
op.replaceAllUsesWith(op.getSelf()); op.replaceAllUsesWith(op.getSelf());
rewriter.eraseOp(op); rewriter.eraseOp(op);
return success(); return success();
} else if (selfType.hasRank() &&
(selfType.getRank() == (int64_t)resultShape.size() ||
selfType.getRank() == 0)) {
// Right now to support limited cases where input and result shape are not
// equal, we can put a constraint that either the input should be of rank
// 0 or the rank of input tensor and result should be equal. And then we
// can check for broadcasting compatibility for the latter case. For
// broadcasting compatibility, either the shape of input and result should
// be equal at each dimenion or one of them should be 1.
if (selfType.getRank() != 0) {
for (unsigned i = 0; i < inputShape.size(); i++) {
if (inputShape[i] != resultShape[i] && inputShape[i] != 1 &&
resultShape[i] != 1) {
return rewriter.notifyMatchFailure(
op, "unimplemented: either the shape of input and result should "
"be equal at each dimenion or one of them should be 1.");
}
}
}
// If the above condition hold true then we can directly create a const
// zero tensor of shape same as the result shape.
SmallVector<int64_t> zeroTensorShape{resultShape};
// create the 0 constant tensor
int64_t totalNumElements = 1;
for (auto dimSize : zeroTensorShape) {
totalNumElements = dimSize * totalNumElements;
}
// There is some danger here. For edge cases in floating point, x + 0 != x.
// The cases are denormalized values, which may get flushed, and -0 + 0 =
// +0. (sign bit flips). These are probably acceptable in the short term,
// but we should put a comment acknowledging the danger, as there isn't an
// op that avoids the denorm flushing.
SmallVector<int64_t> intValues(totalNumElements, 0);
SmallVector<float> floatValues(totalNumElements, 0.0);
Value zeroTensor = selfType.getElementType().isa<mlir::FloatType>()
? tosa::getConstTensor<float>(
rewriter, op, floatValues, zeroTensorShape)
.value()
: tosa::getConstTensor<int64_t>(
rewriter, op, intValues, zeroTensorShape)
.value();
// Use add broadcast
rewriter.replaceOpWithNewOp<tosa::AddOp>(op, resultType, adaptor.getSelf(),
zeroTensor);
return success();
} }
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, op,
@ -3375,171 +3232,6 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
return success(); return success();
} }
template <>
LogicalResult ConvertAtenOp<AtenIndexTensorOp>::matchAndRewrite(
AtenIndexTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// t = tf.constant([[1, 2, 3, 4, 5],[6,7,8,9,10],
// [11,12,13,14,15],[16,17,18,19,20]]) # 4*5
// i = tf.constant([[1,2,3], [3,2,1]]) # 2*3
// i_expand = tf.expand_dims(i,axis=2) # 2*3*1
// IndexTensorOutput = tf.gather_nd(t,tf.i_expand)
// = torch.ops.aten.index(t, (i, )) = t[i] # 2*3*5
// [[[ 6, 7, 8, 9, 10], [11, 12, 13, 14, 15], [16, 17, 18, 19, 20]],
// [[16, 17, 18, 19, 20], [11, 12, 13, 14, 15], [ 6, 7, 8, 9, 10]]]
auto input = adaptor.getSelf();
auto inputTensorType =
adaptor.getSelf().getType().dyn_cast<RankedTensorType>();
// Check input is a tensor type.
if (!inputTensorType)
return rewriter.notifyMatchFailure(
op, "Only tensor types input are currently supported");
// Deal with torch.prim.ListConstruct of non const value to get the index
auto tensorList = op.getIndices();
SmallVector<Value> tensorsTorchType;
if (!getListConstructElements(tensorList, tensorsTorchType))
return op.emitError(
"unimplemented: the tensor list is not from list construct");
auto indexTensors = getTypeConvertedValues(
rewriter, op->getLoc(), getTypeConverter(), tensorsTorchType);
auto outType = getTypeConverter()->convertType(op.getType());
// Support for multiple indexes
if (indexTensors.size() > 1) {
// t[i, i]
// = torch.ops.aten.index(t,(i,i))
// = tensor([[ t[1,1], t[2,2], t[3,3]],
// [ t[3,3], t[2,2], t[1,1]]])
// = tensor([[ 7, 13, 19], [19, 13, 7]])
// = tf.gather_nd(t,tf.ii_expand)
// ii_expand
// = tf.concat((i_expand,i_expand), dim=2)
// = tf.constant([[[1,1],[2,2],[3,3]],
// [[3,3],[2,2],[1,1]]]) # 2*3*2
SmallVector<Value> indicesTfConcatTensors;
SmallVector<int64_t> indexesRank;
SmallVector<SmallVector<int64_t>> indexesShape;
// concat index tensor into to indices tensor for concat
for (size_t i = 0; i < indexTensors.size(); i++) {
auto index = indexTensors[i];
auto indexTorch = tensorsTorchType[i];
// TODO add support for none index input like torch.ops.aten.index(x,
// (None, index1, index2, None))
if (indexTorch.getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "Only list ranked tensor types index are supported");
auto indexType = index.getType().dyn_cast<RankedTensorType>();
auto indexShape = indexType.getShape();
indexesShape.push_back(makeShapeTorchCompatible(indexShape));
indexesRank.push_back(indexType.getRank());
// index i64 to i32 for tosa compatible
if (indexType.getElementType() != rewriter.getIntegerType(32)) {
index = rewriter.create<tosa::CastOp>(
op->getLoc(),
RankedTensorType::get(indexShape, rewriter.getIntegerType(32)),
index);
}
// Expand last dim of index to tf indices [2,3] -> [2,3,1]
SmallVector<int64_t> indiceShapeOneDim;
for (auto shape : indexShape) {
indiceShapeOneDim.push_back(shape);
}
indiceShapeOneDim.push_back(1);
auto indicesTfOneDim = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
rewriter, op->getLoc(),
RankedTensorType::get(indiceShapeOneDim, rewriter.getIntegerType(32)),
index, rewriter.getDenseI64ArrayAttr(indiceShapeOneDim));
// create concat tensor for indicesTf
indicesTfConcatTensors.push_back(indicesTfOneDim.getResult());
}
// Right now only support multiple indexes with same shape
// TODO for different shape multiple indexes, add broadcast_to for small
// shape
for (auto indexShapeOneDim : indexesShape) {
if (!llvm::equal(indexesShape[0], indexShapeOneDim)) {
return rewriter.notifyMatchFailure(
op, "unimplemented: Only support multi indexes with same shape");
}
}
// concat each indices into indicesTf: shape [2,3,1],[2,3,1] -> [2,3,2]
auto indicesShapeConcat = indexesShape[0];
uint64_t lastDim = indexesRank[0];
indicesShapeConcat.push_back(indicesTfConcatTensors.size());
auto indicesTf = tosa::CreateOpAndInfer<tosa::ConcatOp>(
rewriter, op->getLoc(),
GetTypeFromTensorShape(indicesShapeConcat, rewriter.getIntegerType(32)),
indicesTfConcatTensors, lastDim);
if (!indicesTf) {
return rewriter.notifyMatchFailure(
op, "Convert TorchIndex To TfIndices fail.");
}
// do the tf gathernp algorithm with tf style indices as input.
auto result = tosa::convertGatherNdOp(rewriter, op, outType, input,
indicesTf.getResult());
if (!result) {
return rewriter.notifyMatchFailure(
op, "Convert GatherNdOp fail for index tensor.");
}
rewriter.replaceOp(op, {result.value()});
return success();
}
// Support for multiple index
auto index = indexTensors[0];
auto indexTorch = tensorsTorchType[0];
// TODO add support for none index input like torch.ops.aten.index(x, (None, index1, index2, None))
if (indexTorch.getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "Only list ranked tensor types index are supported");
auto indexType = index.getType().dyn_cast<RankedTensorType>();
auto indexShape = indexType.getShape();
// index i64 to i32 for tosa compatible
if (indexType.getElementType() != rewriter.getIntegerType(32)) {
index = rewriter.create<tosa::CastOp>(
op->getLoc(),
RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index);
}
// Expand last dim of index to tf indices [2,3] -> [2,3,1]
SmallVector<int64_t> indicesShape;
for (auto shape : indexShape) {
indicesShape.push_back(shape);
}
indicesShape.push_back(1);
auto indicesTf = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
rewriter, op->getLoc(),
RankedTensorType::get(indicesShape, rewriter.getIntegerType(32)), index,
rewriter.getDenseI64ArrayAttr(indicesShape));
if (!indicesTf) {
return rewriter.notifyMatchFailure(op,
"Convert TorchIndex To TfIndices fail.");
}
// do the tf gathernp algorithm with tf style indices as input.
auto result = tosa::convertGatherNdOp(rewriter, op, outType, input,
indicesTf.getResult());
if (!result) {
return rewriter.notifyMatchFailure(
op, "Convert GatherNdOp fail for index tensor.");
}
rewriter.replaceOp(op, {result.value()});
return success();
}
template <> template <>
LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
AtenWhereSelfOp op, OpAdaptor adaptor, AtenWhereSelfOp op, OpAdaptor adaptor,
@ -4276,11 +3968,9 @@ public:
if (!op.getMemoryFormat().getType().template isa<Torch::NoneType>() && if (!op.getMemoryFormat().getType().template isa<Torch::NoneType>() &&
(!matchPattern(op.getMemoryFormat(), (!matchPattern(op.getMemoryFormat(),
m_TorchConstantInt(&memoryFormat)) || m_TorchConstantInt(&memoryFormat)) ||
(memoryFormat != torch_upstream::MemoryFormat::Contiguous && memoryFormat != torch_upstream::MemoryFormat::Contiguous)) {
memoryFormat != torch_upstream::MemoryFormat::ChannelsLast))) {
return op.emitError( return op.emitError(
"unimplemented: only contiguous and channels last memory " "unimplemented: only default memory format is supported");
"format is supported");
} }
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter() auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType()) ->convertType(op.getType())
@ -4479,7 +4169,6 @@ public:
target.addIllegalOp<AtenOp>(); \ target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context); patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
INSERT_ATENOP_PATTERN(AtenTanhOp); INSERT_ATENOP_PATTERN(AtenTanhOp);
INSERT_ATENOP_PATTERN(AtenHardtanhBackwardOp);
INSERT_ATENOP_PATTERN(AtenSigmoidOp); INSERT_ATENOP_PATTERN(AtenSigmoidOp);
INSERT_ATENOP_PATTERN(AtenReluOp); INSERT_ATENOP_PATTERN(AtenReluOp);
INSERT_ATENOP_PATTERN(AtenLeakyReluOp); INSERT_ATENOP_PATTERN(AtenLeakyReluOp);
@ -4507,7 +4196,6 @@ public:
INSERT_ATENOP_PATTERN(AtenSliceTensorOp); INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
INSERT_ATENOP_PATTERN(AtenBroadcastToOp); INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
INSERT_ATENOP_PATTERN(AtenGatherOp); INSERT_ATENOP_PATTERN(AtenGatherOp);
INSERT_ATENOP_PATTERN(AtenIndexTensorOp);
INSERT_ATENOP_PATTERN(AtenWhereSelfOp); INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
INSERT_ATENOP_PATTERN(AtenClampOp); INSERT_ATENOP_PATTERN(AtenClampOp);
INSERT_ATENOP_PATTERN(AtenArangeStartStepOp); INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);

View File

@ -230,10 +230,6 @@ static LogicalResult checkValidityOfCast(Type src, Type dest) {
(src.isInteger(32) && dest.isInteger(1)) || (src.isInteger(32) && dest.isInteger(1)) ||
(src.isInteger(32) && dest.isF32()) || (src.isInteger(32) && dest.isF32()) ||
(src.isInteger(8) && dest.isInteger(1)) || (src.isInteger(8) && dest.isInteger(1)) ||
(src.isInteger(1) && dest.isInteger(64)) ||
(src.isInteger(1) && dest.isF32()) ||
(src.isF32() && dest.isF64()) ||
(src.isF64() && dest.isF32()) ||
(src.isF32() && dest.isInteger(8)) || (src.isF32() && dest.isInteger(8)) ||
(src.isF32() && dest.isInteger(1))) { (src.isF32() && dest.isInteger(1))) {
return success(); return success();

View File

@ -11,7 +11,6 @@
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectImplementation.h" #include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Transforms/InliningUtils.h" #include "mlir/Transforms/InliningUtils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
@ -32,11 +31,11 @@ namespace {
struct TorchInlinerInterface : public DialectInlinerInterface { struct TorchInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface; using DialectInlinerInterface::DialectInlinerInterface;
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
IRMapping &valueMapping) const final { BlockAndValueMapping &valueMapping) const final {
return true; return true;
} }
bool isLegalToInline(Operation *, Region *, bool wouldBeCloned, bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
IRMapping &) const final { BlockAndValueMapping &) const final {
return true; return true;
} }
}; };

View File

@ -128,36 +128,32 @@ static FloatAttr getF64FloatAttr(MLIRContext *context, double value) {
return FloatAttr::get(Float64Type::get(context), value); return FloatAttr::get(Float64Type::get(context), value);
} }
static Value getScalarIntValue(Value input, Location loc, static Value getScalarValue(Value input, Location loc,
PatternRewriter &rewriter) { PatternRewriter &rewriter) {
auto inputType = input.getType(); auto inputType = input.getType();
if (inputType.isa<Torch::IntType>()) { if (inputType.isa<Torch::IntType>()) {
return input; return input;
} }
Value scalar = nullptr;
auto inputTensorType = inputType.dyn_cast<BaseTensorType>();
if (!inputTensorType)
return nullptr;
Type inputDtype = inputTensorType.getOptionalDtype();
if (!inputDtype || !inputDtype.isInteger(64))
return nullptr;
std::optional<unsigned> inputRank = getTensorRank(input);
if (!inputRank || *inputRank != 0)
return nullptr;
if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) { if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
std::optional<unsigned> tensorRank =
getTensorRank(valueTensorLiteralOp.getResult());
if (valueTensorLiteralOp && tensorRank && *tensorRank == 0) {
auto tensorType =
valueTensorLiteralOp.getValue().getType().cast<RankedTensorType>();
if (tensorType.getElementType().isa<mlir::IntegerType>()) {
auto val = valueTensorLiteralOp.getValue() auto val = valueTensorLiteralOp.getValue()
.cast<DenseElementsAttr>() .cast<DenseElementsAttr>()
.getSplatValue<int64_t>(); .getSplatValue<int64_t>();
return rewriter.create<Torch::ConstantIntOp>( scalar = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(val)); loc, rewriter.getI64IntegerAttr(val));
}
}
} else if (auto primNumToTensorScalarOp = } else if (auto primNumToTensorScalarOp =
input.getDefiningOp<PrimNumToTensorScalarOp>()) { input.getDefiningOp<PrimNumToTensorScalarOp>()) {
return primNumToTensorScalarOp.getA(); scalar = primNumToTensorScalarOp.getA();
} }
return nullptr; return scalar;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -390,7 +386,7 @@ void PrimIfOp::getSuccessorRegions(std::optional<unsigned> index,
// If the condition is constant, we can give a more precise answer. // If the condition is constant, we can give a more precise answer.
if (auto condAttr = operands.front().dyn_cast_or_null<IntegerAttr>()) { if (auto condAttr = operands.front().dyn_cast_or_null<IntegerAttr>()) {
Region *executedRegion = Region *executedRegion =
condAttr.getValue().isOne() ? &getThenRegion() : &getElseRegion(); condAttr.getValue().isOneValue() ? &getThenRegion() : &getElseRegion();
regions.push_back(RegionSuccessor(executedRegion)); regions.push_back(RegionSuccessor(executedRegion));
return; return;
} }
@ -511,7 +507,7 @@ bool DerefineOp::areCastCompatible(mlir::TypeRange inputs,
return isValidSubtype(inputs[0], outputs[0]); return isValidSubtype(inputs[0], outputs[0]);
} }
OpFoldResult DerefineOp::fold(FoldAdaptor adaptor) { OpFoldResult DerefineOp::fold(ArrayRef<Attribute> operands) {
auto uncheckedCast = getOperand().getDefiningOp<PrimUncheckedCastOp>(); auto uncheckedCast = getOperand().getDefiningOp<PrimUncheckedCastOp>();
if (!uncheckedCast) if (!uncheckedCast)
return nullptr; return nullptr;
@ -574,10 +570,10 @@ static OpFoldResult atenIsOrIsNotFoldHelper(Operation *op, bool equalIsTrue) {
// Aten__RangeLengthOp // Aten__RangeLengthOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult Aten__RangeLengthOp::fold(FoldAdaptor adaptor) { OpFoldResult Aten__RangeLengthOp::fold(ArrayRef<Attribute> operands) {
auto lo = adaptor.getLo(); auto lo = operands[0];
auto hi = adaptor.getHi(); auto hi = operands[1];
auto step = adaptor.getStep(); auto step = operands[2];
if (!lo || !hi || !step) if (!lo || !hi || !step)
return nullptr; return nullptr;
auto loInt = lo.dyn_cast_or_null<IntegerAttr>().getValue(); auto loInt = lo.dyn_cast_or_null<IntegerAttr>().getValue();
@ -599,10 +595,10 @@ OpFoldResult Aten__RangeLengthOp::fold(FoldAdaptor adaptor) {
// Aten__DeriveIndexOp // Aten__DeriveIndexOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult Aten__DeriveIndexOp::fold(FoldAdaptor adaptor) { OpFoldResult Aten__DeriveIndexOp::fold(ArrayRef<Attribute> operands) {
auto index = adaptor.getIndex(); auto index = operands[0];
auto start = adaptor.getStart(); auto start = operands[1];
auto step = adaptor.getStep(); auto step = operands[2];
if (!index || !start || !step) if (!index || !start || !step)
return nullptr; return nullptr;
auto indexInt = index.dyn_cast_or_null<IntegerAttr>().getValue(); auto indexInt = index.dyn_cast_or_null<IntegerAttr>().getValue();
@ -616,7 +612,7 @@ OpFoldResult Aten__DeriveIndexOp::fold(FoldAdaptor adaptor) {
// Aten__Is__Op // Aten__Is__Op
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult Aten__Is__Op::fold(FoldAdaptor adaptor) { OpFoldResult Aten__Is__Op::fold(ArrayRef<Attribute> operands) {
return atenIsOrIsNotFoldHelper(*this, /*equalIsTrue=*/true); return atenIsOrIsNotFoldHelper(*this, /*equalIsTrue=*/true);
} }
@ -624,7 +620,7 @@ OpFoldResult Aten__Is__Op::fold(FoldAdaptor adaptor) {
// Aten__Isnot__Op // Aten__Isnot__Op
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult Aten__Isnot__Op::fold(FoldAdaptor adaptor) { OpFoldResult Aten__Isnot__Op::fold(ArrayRef<Attribute> operands) {
return atenIsOrIsNotFoldHelper(*this, /*equalIsTrue=*/false); return atenIsOrIsNotFoldHelper(*this, /*equalIsTrue=*/false);
} }
@ -632,7 +628,7 @@ OpFoldResult Aten__Isnot__Op::fold(FoldAdaptor adaptor) {
// Aten__Not__Op // Aten__Not__Op
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult Aten__Not__Op::fold(FoldAdaptor adaptor) { OpFoldResult Aten__Not__Op::fold(ArrayRef<Attribute> operands) {
bool value; bool value;
if (!matchPattern(getOperand(), m_TorchConstantBool(&value))) if (!matchPattern(getOperand(), m_TorchConstantBool(&value)))
return nullptr; return nullptr;
@ -643,7 +639,7 @@ OpFoldResult Aten__Not__Op::fold(FoldAdaptor adaptor) {
// AtenNeBoolOp // AtenNeBoolOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenNeBoolOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenNeBoolOp::fold(ArrayRef<Attribute> operands) {
if (getOperand(0) == getOperand(1)) if (getOperand(0) == getOperand(1))
return IntegerAttr::get(IntegerType::get(getContext(), 1), false); return IntegerAttr::get(IntegerType::get(getContext(), 1), false);
@ -659,7 +655,7 @@ OpFoldResult AtenNeBoolOp::fold(FoldAdaptor adaptor) {
// AtenSqueezeOp // AtenSqueezeOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenSqueezeOp::fold(ArrayRef<Attribute> operands) {
if (auto tensorType = getOperand().getType().dyn_cast<BaseTensorType>()) { if (auto tensorType = getOperand().getType().dyn_cast<BaseTensorType>()) {
if (tensorType.hasSizes() && tensorType.getSizes().size() == 0) if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
return getOperand(); return getOperand();
@ -671,7 +667,7 @@ OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) {
// AtenSqueezeDimOp // AtenSqueezeDimOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenSqueezeDimOp::fold(ArrayRef<Attribute> operands) {
if (auto tensorType = getOperand(0).getType().dyn_cast<BaseTensorType>()) { if (auto tensorType = getOperand(0).getType().dyn_cast<BaseTensorType>()) {
if (tensorType.hasSizes() && tensorType.getSizes().size() == 0) if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
return getOperand(0); return getOperand(0);
@ -683,7 +679,7 @@ OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) {
// AtenRoundOp // AtenRoundOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenRoundOp::fold(ArrayRef<Attribute> operands) {
if (auto selfType = getSelf().getType().dyn_cast<BaseTensorType>()) { if (auto selfType = getSelf().getType().dyn_cast<BaseTensorType>()) {
if (selfType.hasDtype() && selfType.getDtype().isa<mlir::IntegerType>()) if (selfType.hasDtype() && selfType.getDtype().isa<mlir::IntegerType>())
return getSelf(); return getSelf();
@ -695,7 +691,7 @@ OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
// AtenTypeAsOp // AtenTypeAsOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenTypeAsOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenTypeAsOp::fold(ArrayRef<Attribute> operands) {
Type inType = getSelf().getType(); Type inType = getSelf().getType();
Type newType = getOther().getType(); Type newType = getOther().getType();
@ -709,7 +705,7 @@ OpFoldResult AtenTypeAsOp::fold(FoldAdaptor adaptor) {
// AtenToDtypeOp // AtenToDtypeOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenToDtypeOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenToDtypeOp::fold(ArrayRef<Attribute> operands) {
bool nonBlocking, copyArg; bool nonBlocking, copyArg;
// The non_blocking arg must be `False`. // The non_blocking arg must be `False`.
if (!matchPattern(getNonBlocking(), m_TorchConstantBool(&nonBlocking)) || if (!matchPattern(getNonBlocking(), m_TorchConstantBool(&nonBlocking)) ||
@ -740,7 +736,7 @@ OpFoldResult AtenToDtypeOp::fold(FoldAdaptor adaptor) {
// AtenToDtypeLayoutOp // AtenToDtypeLayoutOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenToDtypeLayoutOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenToDtypeLayoutOp::fold(ArrayRef<Attribute> operands) {
// The pin_memory arg should be either constant `False` or `none`. // The pin_memory arg should be either constant `False` or `none`.
if (!getPinMemory().getType().isa<Torch::NoneType>()) { if (!getPinMemory().getType().isa<Torch::NoneType>()) {
bool pinMemory; bool pinMemory;
@ -801,7 +797,7 @@ OpFoldResult AtenToDtypeLayoutOp::fold(FoldAdaptor adaptor) {
// AtenViewOp // AtenViewOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenViewOp::fold(ArrayRef<Attribute> operands) {
auto inputType = getOperand(0).getType().dyn_cast<BaseTensorType>(); auto inputType = getOperand(0).getType().dyn_cast<BaseTensorType>();
if (!inputType || !inputType.hasSizes() || inputType.getSizes().size() != 1) if (!inputType || !inputType.hasSizes() || inputType.getSizes().size() != 1)
return nullptr; return nullptr;
@ -816,7 +812,7 @@ OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) {
// AtenDimOp // AtenDimOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenDimOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenDimOp::fold(ArrayRef<Attribute> operands) {
if (auto tensorType = getOperand().getType().dyn_cast<BaseTensorType>()) { if (auto tensorType = getOperand().getType().dyn_cast<BaseTensorType>()) {
if (tensorType.hasSizes()) if (tensorType.hasSizes())
return IntegerAttr::get(IntegerType::get(getContext(), 64), return IntegerAttr::get(IntegerType::get(getContext(), 64),
@ -829,7 +825,7 @@ OpFoldResult AtenDimOp::fold(FoldAdaptor adaptor) {
// AtenLenTOp // AtenLenTOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenLenTOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenLenTOp::fold(ArrayRef<Attribute> operands) {
// `len([1,1,1])` -> `3`, if it is not mutated. // `len([1,1,1])` -> `3`, if it is not mutated.
if (auto listConstruct = if (auto listConstruct =
getOperand().getDefiningOp<Torch::PrimListConstructOp>()) { getOperand().getDefiningOp<Torch::PrimListConstructOp>()) {
@ -857,7 +853,7 @@ void AtenLenTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// AtenLenStrOp // AtenLenStrOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenLenStrOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenLenStrOp::fold(ArrayRef<Attribute> operands) {
if (auto stringConstruct = getS().getDefiningOp<ConstantStrOp>()) if (auto stringConstruct = getS().getDefiningOp<ConstantStrOp>())
return getI64IntegerAttr(getContext(), return getI64IntegerAttr(getContext(),
stringConstruct.getValueAttr().getValue().size()); stringConstruct.getValueAttr().getValue().size());
@ -873,24 +869,21 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
if (op->getNumOperands() < 2) { if (op->getNumOperands() < 2) {
return failure(); return failure();
} }
auto lhs = getScalarIntValue(op->getOperand(0), loc, rewriter); auto lhs = getScalarValue(op->getOperand(0), loc, rewriter);
auto rhs = getScalarIntValue(op->getOperand(1), loc, rewriter); auto rhs = getScalarValue(op->getOperand(1), loc, rewriter);
auto outType = op->getResult(0).getType(); auto outType = op->getResult(0).getType();
if (!lhs || !rhs) { if (!lhs || !rhs) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "only int scalar lhs or rhs is supported"); op, "only int scalar lhs or rhs is supported");
} }
if (isa<AtenSubTensorOp, AtenSubScalarOp, AtenRsubScalarOp, AtenAddTensorOp, if (isa<AtenSubTensorOp, AtenSubScalarOp, AtenAddTensorOp, AtenAddScalarOp>(
AtenAddScalarOp>(op)) { op)) {
Value alpha = getScalarIntValue(op->getOperand(2), loc, rewriter); Value alpha = getScalarValue(op->getOperand(2), loc, rewriter);
if (!alpha) { if (!alpha) {
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"only int scalar alpha is supported"); "only int scalar alpha is supported");
} }
if (isa<AtenRsubScalarOp>(op))
lhs = rewriter.create<AtenMulIntOp>(loc, lhs, alpha);
else
rhs = rewriter.create<AtenMulIntOp>(loc, rhs, alpha); rhs = rewriter.create<AtenMulIntOp>(loc, rhs, alpha);
} }
@ -944,8 +937,6 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
result = rewriter.create<AtenAddIntOp>(loc, lhs, rhs); result = rewriter.create<AtenAddIntOp>(loc, lhs, rhs);
} else if (isa<AtenSubScalarOp, AtenSubTensorOp>(op)) { } else if (isa<AtenSubScalarOp, AtenSubTensorOp>(op)) {
result = rewriter.create<AtenSubIntOp>(loc, lhs, rhs); result = rewriter.create<AtenSubIntOp>(loc, lhs, rhs);
} else if (isa<AtenRsubScalarOp>(op)) {
result = rewriter.create<AtenSubIntOp>(loc, rhs, lhs);
} else if (isa<AtenMulScalarOp, AtenMulTensorOp>(op)) { } else if (isa<AtenMulScalarOp, AtenMulTensorOp>(op)) {
result = rewriter.create<AtenMulIntOp>(loc, lhs, rhs); result = rewriter.create<AtenMulIntOp>(loc, lhs, rhs);
} }
@ -993,16 +984,6 @@ void AtenSubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
}); });
} }
//===----------------------------------------------------------------------===//
// AtenRSubScalarOp
//===----------------------------------------------------------------------===//
void AtenRsubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(+[](AtenRsubScalarOp op, PatternRewriter &rewriter) {
return rewrite0DBinaryTensorOp(op, rewriter);
});
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenMulTensorOp // AtenMulTensorOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1033,23 +1014,6 @@ void AtenDivTensorModeOp::getCanonicalizationPatterns(
}); });
} }
//===----------------------------------------------------------------------===//
// AtenScalarImplicitOp
//===----------------------------------------------------------------------===//
void AtenScalarImplicitOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add(+[](AtenScalarImplicitOp op, PatternRewriter &rewriter) {
Location loc = op.getLoc();
Value a = op.getA();
auto outType = op.getResult().getType();
Value scalarValue = getScalarIntValue(a, loc, rewriter);
if (!scalarValue)
return failure();
rewriter.replaceOpWithNewOp<Torch::DerefineOp>(op, outType, scalarValue);
return success();
});
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenSizeOp // AtenSizeOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1128,7 +1092,7 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// AtenSizeIntOp // AtenSizeIntOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenSizeIntOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenSizeIntOp::fold(ArrayRef<Attribute> operands) {
int64_t dim; int64_t dim;
if (!matchPattern(this->getDim(), m_TorchConstantInt(&dim))) if (!matchPattern(this->getDim(), m_TorchConstantInt(&dim)))
return nullptr; return nullptr;
@ -1168,7 +1132,7 @@ floatComparatorFoldHelper(OpTy op, ConstantFloatComparator comparator) {
// AtenLtFloatOp // AtenLtFloatOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenLtFloatOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenLtFloatOp::fold(ArrayRef<Attribute> operands) {
return floatComparatorFoldHelper(*this, return floatComparatorFoldHelper(*this,
[](double a, double b) { return a < b; }); [](double a, double b) { return a < b; });
} }
@ -1177,7 +1141,7 @@ OpFoldResult AtenLtFloatOp::fold(FoldAdaptor adaptor) {
// AtenGtFloatOp // AtenGtFloatOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenGtFloatOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenGtFloatOp::fold(ArrayRef<Attribute> operands) {
return floatComparatorFoldHelper(*this, return floatComparatorFoldHelper(*this,
[](double a, double b) { return a > b; }); [](double a, double b) { return a > b; });
} }
@ -1186,7 +1150,7 @@ OpFoldResult AtenGtFloatOp::fold(FoldAdaptor adaptor) {
// AtenGeFloatOp // AtenGeFloatOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenGeFloatOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenGeFloatOp::fold(ArrayRef<Attribute> operands) {
return floatComparatorFoldHelper(*this, return floatComparatorFoldHelper(*this,
[](double a, double b) { return a >= b; }); [](double a, double b) { return a >= b; });
} }
@ -1195,7 +1159,7 @@ OpFoldResult AtenGeFloatOp::fold(FoldAdaptor adaptor) {
// AtenEqFloatOp // AtenEqFloatOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenEqFloatOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenEqFloatOp::fold(ArrayRef<Attribute> operands) {
return floatComparatorFoldHelper(*this, return floatComparatorFoldHelper(*this,
[](double a, double b) { return a == b; }); [](double a, double b) { return a == b; });
} }
@ -1261,7 +1225,7 @@ static OpFoldResult intComparatorFoldHelper(OpTy op,
// AtenNeIntOp // AtenNeIntOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenNeIntOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenNeIntOp::fold(ArrayRef<Attribute> operands) {
return intComparatorFoldHelper(*this, return intComparatorFoldHelper(*this,
[](int64_t a, int64_t b) { return a != b; }); [](int64_t a, int64_t b) { return a != b; });
} }
@ -1270,7 +1234,7 @@ OpFoldResult AtenNeIntOp::fold(FoldAdaptor adaptor) {
// AtenEqIntOp // AtenEqIntOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenEqIntOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenEqIntOp::fold(ArrayRef<Attribute> operands) {
return intComparatorFoldHelper(*this, return intComparatorFoldHelper(*this,
[](int64_t a, int64_t b) { return a == b; }); [](int64_t a, int64_t b) { return a == b; });
} }
@ -1279,7 +1243,7 @@ OpFoldResult AtenEqIntOp::fold(FoldAdaptor adaptor) {
// AtenEqStrOp // AtenEqStrOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenEqStrOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenEqStrOp::fold(ArrayRef<Attribute> operands) {
if (getOperand(0) == getOperand(1)) if (getOperand(0) == getOperand(1))
return getI1IntegerAttr(getContext(), true); return getI1IntegerAttr(getContext(), true);
@ -1295,7 +1259,7 @@ OpFoldResult AtenEqStrOp::fold(FoldAdaptor adaptor) {
// AtenLtIntOp // AtenLtIntOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenLtIntOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenLtIntOp::fold(ArrayRef<Attribute> operands) {
return intComparatorFoldHelper(*this, return intComparatorFoldHelper(*this,
[](int64_t a, int64_t b) { return a < b; }); [](int64_t a, int64_t b) { return a < b; });
} }
@ -1304,7 +1268,7 @@ OpFoldResult AtenLtIntOp::fold(FoldAdaptor adaptor) {
// AtenLeIntOp // AtenLeIntOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenLeIntOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenLeIntOp::fold(ArrayRef<Attribute> operands) {
return intComparatorFoldHelper(*this, return intComparatorFoldHelper(*this,
[](int64_t a, int64_t b) { return a <= b; }); [](int64_t a, int64_t b) { return a <= b; });
} }
@ -1313,7 +1277,7 @@ OpFoldResult AtenLeIntOp::fold(FoldAdaptor adaptor) {
// AtenGtIntOp // AtenGtIntOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenGtIntOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenGtIntOp::fold(ArrayRef<Attribute> operands) {
return intComparatorFoldHelper(*this, return intComparatorFoldHelper(*this,
[](int64_t a, int64_t b) { return a > b; }); [](int64_t a, int64_t b) { return a > b; });
} }
@ -1322,7 +1286,7 @@ OpFoldResult AtenGtIntOp::fold(FoldAdaptor adaptor) {
// AtenGeIntOp // AtenGeIntOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenGeIntOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenGeIntOp::fold(ArrayRef<Attribute> operands) {
return intComparatorFoldHelper(*this, return intComparatorFoldHelper(*this,
[](int64_t a, int64_t b) { return a >= b; }); [](int64_t a, int64_t b) { return a >= b; });
} }
@ -1331,7 +1295,7 @@ OpFoldResult AtenGeIntOp::fold(FoldAdaptor adaptor) {
// AtenBoolFloatOp // AtenBoolFloatOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenBoolFloatOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenBoolFloatOp::fold(ArrayRef<Attribute> operands) {
double c; double c;
if (matchPattern(getOperand(), m_TorchConstantFloat(&c))) if (matchPattern(getOperand(), m_TorchConstantFloat(&c)))
return getI1IntegerAttr(getContext(), c != 0.0); return getI1IntegerAttr(getContext(), c != 0.0);
@ -1342,7 +1306,7 @@ OpFoldResult AtenBoolFloatOp::fold(FoldAdaptor adaptor) {
// AtenBoolIntOp // AtenBoolIntOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenBoolIntOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenBoolIntOp::fold(ArrayRef<Attribute> operands) {
int64_t c; int64_t c;
if (matchPattern(getOperand(), m_TorchConstantInt(&c))) if (matchPattern(getOperand(), m_TorchConstantInt(&c)))
return getI1IntegerAttr(getContext(), c != 0); return getI1IntegerAttr(getContext(), c != 0);
@ -1353,9 +1317,9 @@ OpFoldResult AtenBoolIntOp::fold(FoldAdaptor adaptor) {
// AtenFloatScalarOp // AtenFloatScalarOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenFloatScalarOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenFloatScalarOp::fold(ArrayRef<Attribute> operands) {
// Constant fold int -> float conversion. // Constant fold int -> float conversion.
if (auto integerAttr = adaptor.getA().dyn_cast_or_null<IntegerAttr>()) { if (auto integerAttr = operands[0].dyn_cast_or_null<IntegerAttr>()) {
return FloatAttr::get( return FloatAttr::get(
mlir::Float64Type::get(getContext()), mlir::Float64Type::get(getContext()),
static_cast<double>(integerAttr.getValue().getSExtValue())); static_cast<double>(integerAttr.getValue().getSExtValue()));
@ -1366,27 +1330,13 @@ OpFoldResult AtenFloatScalarOp::fold(FoldAdaptor adaptor) {
return nullptr; return nullptr;
} }
//===----------------------------------------------------------------------===//
// AtenIntFloatOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenIntFloatOp::fold(FoldAdaptor adaptor) {
// Constant fold float -> int conversion.
if (auto floatAttr = adaptor.getA().dyn_cast_or_null<FloatAttr>()) {
return IntegerAttr::get(
mlir::IntegerType::get(getContext(), 64, IntegerType::Signed),
static_cast<int64_t>(floatAttr.getValue().convertToDouble()));
}
return nullptr;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenIntScalarOp // AtenIntScalarOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenIntScalarOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenIntScalarOp::fold(ArrayRef<Attribute> operands) {
// Constant fold float -> int conversion. // Constant fold float -> int conversion.
if (auto floatAttr = adaptor.getA().dyn_cast_or_null<FloatAttr>()) { if (auto floatAttr = operands[0].dyn_cast_or_null<FloatAttr>()) {
return IntegerAttr::get( return IntegerAttr::get(
mlir::IntegerType::get(getContext(), 64, IntegerType::Signed), mlir::IntegerType::get(getContext(), 64, IntegerType::Signed),
static_cast<long>(floatAttr.getValue().convertToDouble())); static_cast<long>(floatAttr.getValue().convertToDouble()));
@ -1397,18 +1347,6 @@ OpFoldResult AtenIntScalarOp::fold(FoldAdaptor adaptor) {
return nullptr; return nullptr;
} }
//===----------------------------------------------------------------------===//
// AtenIntBoolOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenIntBoolOp::fold(FoldAdaptor adaptor) {
bool b;
if (matchPattern(getOperand(), m_TorchConstantBool(&b))) {
return getI64IntegerAttr(getContext(), static_cast<long>(b));
}
return nullptr;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenSortIntOp // AtenSortIntOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1502,7 +1440,7 @@ LogicalResult ValueTensorLiteralOp::inferReturnTypes(
return success(); return success();
} }
OpFoldResult ValueTensorLiteralOp::fold(FoldAdaptor adaptor) { OpFoldResult ValueTensorLiteralOp::fold(ArrayRef<Attribute> operands) {
return getValueAttr(); return getValueAttr();
} }
@ -1607,7 +1545,7 @@ void CopyToValueTensorOp::getEffects(
// ConstantNoneOp // ConstantNoneOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult ConstantNoneOp::fold(FoldAdaptor adaptor) { OpFoldResult ConstantNoneOp::fold(ArrayRef<Attribute> operands) {
return TypeAttr::get(Torch::NoneType::get(getContext())); return TypeAttr::get(Torch::NoneType::get(getContext()));
} }
@ -1620,7 +1558,9 @@ void ConstantNoneOp::getAsmResultNames(
// ConstantStrOp // ConstantStrOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult ConstantStrOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } OpFoldResult ConstantStrOp::fold(ArrayRef<Attribute> operands) {
return getValueAttr();
}
void ConstantStrOp::getAsmResultNames( void ConstantStrOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) { function_ref<void(Value, StringRef)> setNameFn) {
@ -1658,7 +1598,7 @@ void ConstantIntOp::print(OpAsmPrinter &p) {
p.printOptionalAttrDict((*this)->getAttrs(), {"value"}); p.printOptionalAttrDict((*this)->getAttrs(), {"value"});
} }
OpFoldResult Torch::ConstantIntOp::fold(FoldAdaptor adaptor) { OpFoldResult Torch::ConstantIntOp::fold(ArrayRef<Attribute> operands) {
return getValueAttr(); return getValueAttr();
} }
@ -1674,7 +1614,7 @@ void Torch::ConstantIntOp::getAsmResultNames(
// ConstantFloatOp // ConstantFloatOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult Torch::ConstantFloatOp::fold(FoldAdaptor adaptor) { OpFoldResult Torch::ConstantFloatOp::fold(ArrayRef<Attribute> operands) {
return getValueAttr(); return getValueAttr();
} }
@ -1704,7 +1644,7 @@ void Torch::ConstantFloatOp::getAsmResultNames(
// ConstantNumberOp // ConstantNumberOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult Torch::ConstantNumberOp::fold(FoldAdaptor adaptor) { OpFoldResult Torch::ConstantNumberOp::fold(ArrayRef<Attribute> operands) {
return getValueAttr(); return getValueAttr();
} }
@ -1732,7 +1672,7 @@ void Torch::ConstantNumberOp::getCanonicalizationPatterns(
// ConstantBoolOp // ConstantBoolOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult Torch::ConstantBoolOp::fold(FoldAdaptor adaptor) { OpFoldResult Torch::ConstantBoolOp::fold(ArrayRef<Attribute> operands) {
return getValueAttr(); return getValueAttr();
} }
@ -1750,7 +1690,7 @@ bool PrimUncheckedCastOp::areCastCompatible(mlir::TypeRange inputs,
return isValidSubtype(outputs[0], inputs[0]); return isValidSubtype(outputs[0], inputs[0]);
} }
OpFoldResult PrimUncheckedCastOp::fold(FoldAdaptor adaptor) { OpFoldResult PrimUncheckedCastOp::fold(ArrayRef<Attribute> operands) {
if (auto derefineOp = getX().getDefiningOp<Torch::DerefineOp>()) { if (auto derefineOp = getX().getDefiningOp<Torch::DerefineOp>()) {
if (derefineOp.getOperand().getType() == getType()) if (derefineOp.getOperand().getType() == getType())
return derefineOp.getOperand(); return derefineOp.getOperand();
@ -1884,7 +1824,7 @@ void AtenSliceTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// AtenEqIntListOp // AtenEqIntListOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenEqIntListOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenEqIntListOp::fold(ArrayRef<Attribute> operands) {
auto lhsLiteral = getA().getDefiningOp<Torch::PrimListConstructOp>(); auto lhsLiteral = getA().getDefiningOp<Torch::PrimListConstructOp>();
if (!lhsLiteral) if (!lhsLiteral)
return nullptr; return nullptr;
@ -1909,20 +1849,6 @@ OpFoldResult AtenEqIntListOp::fold(FoldAdaptor adaptor) {
return nullptr; return nullptr;
} }
//===----------------------------------------------------------------------===//
// PrimTupleConstructOp
//===----------------------------------------------------------------------===//
LogicalResult PrimTupleConstructOp::verify() {
if (!(isValidSubtype(
Torch::TupleType::get(getContext(),
llvm::to_vector<6>(getElements().getType())),
getResult().getType())))
return emitOpError(
"failed to verify that contained types correspond to operand types");
return success();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// PrimTupleIndexOp // PrimTupleIndexOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -2024,7 +1950,7 @@ static PrimDictConstructOp getDictConstructIfNotModified(Value torchDict) {
// Aten__Getitem__DictStrOp // Aten__Getitem__DictStrOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult Aten__Getitem__DictStrOp::fold(FoldAdaptor adaptor) { OpFoldResult Aten__Getitem__DictStrOp::fold(ArrayRef<Attribute> operands) {
auto dictConstruct = getDictConstructIfNotModified(getSelf()); auto dictConstruct = getDictConstructIfNotModified(getSelf());
if (!dictConstruct) if (!dictConstruct)
return nullptr; return nullptr;
@ -2042,7 +1968,7 @@ OpFoldResult Aten__Getitem__DictStrOp::fold(FoldAdaptor adaptor) {
// Aten__Contains__StrOp // Aten__Contains__StrOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult Aten__Contains__StrOp::fold(FoldAdaptor adaptor) { OpFoldResult Aten__Contains__StrOp::fold(ArrayRef<Attribute> operands) {
auto dictConstruct = getDictConstructIfNotModified(getDict()); auto dictConstruct = getDictConstructIfNotModified(getDict());
if (!dictConstruct) if (!dictConstruct)
return nullptr; return nullptr;
@ -2065,7 +1991,7 @@ static bool isListConstructNotModified(Value torchList) {
}); });
} }
OpFoldResult Aten__Contains__IntListOp::fold(FoldAdaptor adaptor) { OpFoldResult Aten__Contains__IntListOp::fold(ArrayRef<Attribute> operands) {
auto itemConstruct = getItem(); auto itemConstruct = getItem();
if (!isListConstructNotModified(getL())) if (!isListConstructNotModified(getL()))
return nullptr; return nullptr;
@ -2126,55 +2052,43 @@ atenBinaryFloatOperatorFoldHelper(ArrayRef<Attribute> operands,
// AtenFloordivIntOp // AtenFloordivIntOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenFloordivIntOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenFloordivIntOp::fold(ArrayRef<Attribute> operands) {
return atenBinaryIntOperatorFoldHelper( return atenBinaryIntOperatorFoldHelper(
adaptor.getOperands(), operands, [](int64_t a, int64_t b) { return std::floor(a / (double)b); });
[](int64_t a, int64_t b) { return std::floor(a / (double)b); });
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenRemainderIntOp // AtenRemainderIntOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenRemainderIntOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenRemainderIntOp::fold(ArrayRef<Attribute> operands) {
return atenBinaryIntOperatorFoldHelper( return atenBinaryIntOperatorFoldHelper(
adaptor.getOperands(), [](int64_t a, int64_t b) { return a % b; }); operands, [](int64_t a, int64_t b) { return a % b; });
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenAddIntOp // AtenAddIntOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenAddIntOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenAddIntOp::fold(ArrayRef<Attribute> operands) {
return atenBinaryIntOperatorFoldHelper( return atenBinaryIntOperatorFoldHelper(
adaptor.getOperands(), [](int64_t a, int64_t b) { return a + b; }); operands, [](int64_t a, int64_t b) { return a + b; });
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenSubIntOp // AtenSubIntOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenSubIntOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenSubIntOp::fold(ArrayRef<Attribute> operands) {
return atenBinaryIntOperatorFoldHelper( return atenBinaryIntOperatorFoldHelper(
adaptor.getOperands(), [](int64_t a, int64_t b) { return a - b; }); operands, [](int64_t a, int64_t b) { return a - b; });
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenCatOp // AtenCatOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenCatOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
auto list = getOperand(0).getDefiningOp<PrimListConstructOp>();
if (!list || !list->hasOneUse() || list.getElements().size() != 1)
return nullptr;
return list.getElements()[0];
}
//===----------------------------------------------------------------------===//
// AtenStackOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenStackOp::fold(FoldAdaptor adaptor) {
auto list = getOperand(0).getDefiningOp<PrimListConstructOp>(); auto list = getOperand(0).getDefiningOp<PrimListConstructOp>();
if (!list || !list->hasOneUse() || list.getElements().size() != 1) if (!list || !list->hasOneUse() || list.getElements().size() != 1)
return nullptr; return nullptr;
@ -2185,7 +2099,7 @@ OpFoldResult AtenStackOp::fold(FoldAdaptor adaptor) {
// AtenSliceTensorOp // AtenSliceTensorOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenSliceTensorOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
auto inType = getOperand(0).getType().dyn_cast<ValueTensorType>(); auto inType = getOperand(0).getType().dyn_cast<ValueTensorType>();
auto outType = getResult().getType().dyn_cast<ValueTensorType>(); auto outType = getResult().getType().dyn_cast<ValueTensorType>();
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes()) if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes())
@ -2204,7 +2118,7 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
// AtenMulIntOp // AtenMulIntOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenMulIntOp::fold(ArrayRef<Attribute> operands) {
int64_t lhs, rhs; int64_t lhs, rhs;
bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs)); bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs));
bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs)); bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs));
@ -2215,70 +2129,46 @@ OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) {
return nullptr; return nullptr;
} }
//===----------------------------------------------------------------------===//
// AtenSubFloatOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenSubFloatOp::fold(FoldAdaptor adaptor) {
return atenBinaryFloatOperatorFoldHelper(
adaptor.getOperands(), [](double a, double b) { return a - b; });
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenSubOp // AtenSubOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenSubOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenSubOp::fold(ArrayRef<Attribute> operands) {
if (!adaptor.getA() || !adaptor.getB()) { if (!operands[0] || !operands[1]) {
return nullptr; return nullptr;
} }
if (adaptor.getA().isa<IntegerAttr>() && adaptor.getB().isa<IntegerAttr>()) { if (operands[0].isa<IntegerAttr>() && operands[1].isa<IntegerAttr>()) {
return atenBinaryIntOperatorFoldHelper( return atenBinaryIntOperatorFoldHelper(
adaptor.getOperands(), operands, [](int64_t a, int64_t b) -> int64_t { return a - b; });
[](int64_t a, int64_t b) -> int64_t { return a - b; });
} }
return atenBinaryFloatOperatorFoldHelper( return atenBinaryFloatOperatorFoldHelper(
adaptor.getOperands(), operands, [](double a, double b) -> double { return a - b; });
[](double a, double b) -> double { return a - b; });
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenDivOp // AtenDivOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenDivOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenDivOp::fold(ArrayRef<Attribute> operands) {
if (!adaptor.getA() || !adaptor.getB()) { if (!operands[0] || !operands[1]) {
return nullptr; return nullptr;
} }
// Since AtenDivOp always returns float value, we don't need to deal with the // Since AtenDivOp always returns float value, we don't need to deal with the
// case where the operands are both integers separately. // case where the operands are both integers separately.
return atenBinaryFloatOperatorFoldHelper( return atenBinaryFloatOperatorFoldHelper(
adaptor.getOperands(), operands, [](double a, double b) -> double { return a / b; });
[](double a, double b) -> double { return a / b; });
}
//===----------------------------------------------------------------------===//
// AtenPowIntFloatOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenPowIntFloatOp::fold(FoldAdaptor adaptor) {
if (!adaptor.getA() || !adaptor.getB()) {
return nullptr;
}
return atenBinaryFloatOperatorFoldHelper(
adaptor.getOperands(), [](double a, double b) { return std::pow(a, b); });
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenCeilScalarOp // AtenCeilScalarOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenCeilScalarOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenCeilScalarOp::fold(ArrayRef<Attribute> operands) {
if (!adaptor.getA()) { if (!operands[0]) {
return nullptr; return nullptr;
} }
auto floatValue = adaptor.getA().dyn_cast_or_null<FloatAttr>(); auto floatValue = operands[0].dyn_cast_or_null<FloatAttr>();
if (!floatValue) { if (!floatValue) {
return nullptr; return nullptr;
} }
@ -2291,7 +2181,7 @@ OpFoldResult AtenCeilScalarOp::fold(FoldAdaptor adaptor) {
// AtenNegIntOp // AtenNegIntOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenNegIntOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenNegIntOp::fold(ArrayRef<Attribute> operands) {
int64_t c; int64_t c;
if (matchPattern(getOperand(), m_TorchConstantInt(&c))) if (matchPattern(getOperand(), m_TorchConstantInt(&c)))
return getI64IntegerAttr(getContext(), -c); return getI64IntegerAttr(getContext(), -c);
@ -2302,7 +2192,7 @@ OpFoldResult AtenNegIntOp::fold(FoldAdaptor adaptor) {
// AtenSqrtIntOp // AtenSqrtIntOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenSqrtIntOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenSqrtIntOp::fold(ArrayRef<Attribute> operands) {
int64_t c; int64_t c;
if (matchPattern(getOperand(), m_TorchConstantInt(&c))) if (matchPattern(getOperand(), m_TorchConstantInt(&c)))
return getF64FloatAttr(getContext(), std::sqrt(c)); return getF64FloatAttr(getContext(), std::sqrt(c));
@ -2313,7 +2203,7 @@ OpFoldResult AtenSqrtIntOp::fold(FoldAdaptor adaptor) {
// PrimDtypeOp // PrimDtypeOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult PrimDtypeOp::fold(FoldAdaptor adaptor) { OpFoldResult PrimDtypeOp::fold(ArrayRef<Attribute> operands) {
BaseTensorType tensorType = getA().getType().cast<BaseTensorType>(); BaseTensorType tensorType = getA().getType().cast<BaseTensorType>();
if (tensorType.hasDtype()) { if (tensorType.hasDtype()) {
torch_upstream::ScalarType scalarType = torch_upstream::ScalarType scalarType =
@ -2327,7 +2217,7 @@ OpFoldResult PrimDtypeOp::fold(FoldAdaptor adaptor) {
// AtenIntTensorOp // AtenIntTensorOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenIntTensorOp::fold(ArrayRef<Attribute> operands) {
// If a scalar number is converted to a 0-d tensor and passed on to // If a scalar number is converted to a 0-d tensor and passed on to
// aten.Int.Tensor, fold to the scalar number. // aten.Int.Tensor, fold to the scalar number.
if (auto numToTensorScalar = getA().getDefiningOp<PrimNumToTensorScalarOp>()) if (auto numToTensorScalar = getA().getDefiningOp<PrimNumToTensorScalarOp>())
@ -2339,7 +2229,7 @@ OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) {
// AtenFloatTensorOp // AtenFloatTensorOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenFloatTensorOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenFloatTensorOp::fold(ArrayRef<Attribute> operands) {
// If a scalar number is converted to a 0-d tensor and passed on to // If a scalar number is converted to a 0-d tensor and passed on to
// aten.Float.Tensor, fold to the scalar number. // aten.Float.Tensor, fold to the scalar number.
if (auto numToTensorScalar = getA().getDefiningOp<PrimNumToTensorScalarOp>()) if (auto numToTensorScalar = getA().getDefiningOp<PrimNumToTensorScalarOp>())
@ -2351,7 +2241,7 @@ OpFoldResult AtenFloatTensorOp::fold(FoldAdaptor adaptor) {
// AtenDivFloatOp // AtenDivFloatOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenDivFloatOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenDivFloatOp::fold(ArrayRef<Attribute> operands) {
double lhs, rhs; double lhs, rhs;
bool lConstant = matchPattern(getOperand(0), m_TorchConstantFloat(&lhs)); bool lConstant = matchPattern(getOperand(0), m_TorchConstantFloat(&lhs));
bool rConstant = matchPattern(getOperand(1), m_TorchConstantFloat(&rhs)); bool rConstant = matchPattern(getOperand(1), m_TorchConstantFloat(&rhs));
@ -2368,7 +2258,7 @@ OpFoldResult AtenDivFloatOp::fold(FoldAdaptor adaptor) {
// AtenDivIntOp // AtenDivIntOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenDivIntOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenDivIntOp::fold(ArrayRef<Attribute> operands) {
int64_t lhs, rhs; int64_t lhs, rhs;
bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs)); bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs));
bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs)); bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs));
@ -2381,7 +2271,7 @@ OpFoldResult AtenDivIntOp::fold(FoldAdaptor adaptor) {
// AtenCeilFloatOp // AtenCeilFloatOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenCeilFloatOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenCeilFloatOp::fold(ArrayRef<Attribute> operands) {
double c; double c;
if (matchPattern(getOperand(), m_TorchConstantFloat(&c))) if (matchPattern(getOperand(), m_TorchConstantFloat(&c)))
return getI64IntegerAttr(getContext(), std::ceil(c)); return getI64IntegerAttr(getContext(), std::ceil(c));
@ -2392,13 +2282,13 @@ OpFoldResult AtenCeilFloatOp::fold(FoldAdaptor adaptor) {
// PrimMaxIntOp // PrimMaxIntOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult PrimMaxIntOp::fold(FoldAdaptor adaptor) { OpFoldResult PrimMaxIntOp::fold(ArrayRef<Attribute> operands) {
// If both operands are the same, then the operation is an identity. // If both operands are the same, then the operation is an identity.
if (getA() == getB()) if (getA() == getB())
return getA(); return getA();
auto lhs = adaptor.getA().dyn_cast_or_null<IntegerAttr>(); auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
auto rhs = adaptor.getB().dyn_cast_or_null<IntegerAttr>(); auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
if (!lhs || !rhs) if (!lhs || !rhs)
return nullptr; return nullptr;
// Torch semantics are that !torch.int is 64-bit signed. // Torch semantics are that !torch.int is 64-bit signed.
@ -2411,7 +2301,7 @@ OpFoldResult PrimMaxIntOp::fold(FoldAdaptor adaptor) {
// PrimMinSelfIntOp // PrimMinSelfIntOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult PrimMinSelfIntOp::fold(FoldAdaptor adaptor) { OpFoldResult PrimMinSelfIntOp::fold(ArrayRef<Attribute> operands) {
auto list = getOperand().getDefiningOp<PrimListConstructOp>(); auto list = getOperand().getDefiningOp<PrimListConstructOp>();
if (!list) if (!list)
return nullptr; return nullptr;
@ -2430,25 +2320,6 @@ OpFoldResult PrimMinSelfIntOp::fold(FoldAdaptor adaptor) {
*std::min_element(values.begin(), values.end())); *std::min_element(values.begin(), values.end()));
} }
//===----------------------------------------------------------------------===//
// PrimMinIntOp
//===----------------------------------------------------------------------===//
OpFoldResult PrimMinIntOp::fold(FoldAdaptor adaptor) {
// If both operands are the same, then the operation is an identity.
if (getA() == getB())
return getA();
auto lhs = adaptor.getA().dyn_cast_or_null<IntegerAttr>();
auto rhs = adaptor.getB().dyn_cast_or_null<IntegerAttr>();
if (!lhs || !rhs)
return nullptr;
// Torch semantics are that !torch.int is 64-bit signed.
return IntegerAttr::get(
lhs.getType(),
std::min(lhs.getValue().getSExtValue(), rhs.getValue().getSExtValue()));
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ShapeCalculateOp // ShapeCalculateOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -68,32 +68,16 @@ bool Torch::isValidSubtype(Type subtype, Type type) {
return true; return true;
} }
auto subtypeTensorType = subtype.dyn_cast<BaseTensorType>();
auto typeTensorType = type.dyn_cast<BaseTensorType>();
if (subtypeTensorType && typeTensorType) {
// Check that both tensors have the same `BaseTensorType` subtype.
// TODO: This is not subtyping according to PEP 483. See description // TODO: This is not subtyping according to PEP 483. See description
// of NonValueTensorType. // of NonValueTensorType.
if (subtypeTensorType.isa<ValueTensorType>() != if (subtype.isa<NonValueTensorType>() && type.isa<NonValueTensorType>() &&
typeTensorType.isa<ValueTensorType>()) type ==
return false; NonValueTensorType::getWithLeastStaticInformation(type.getContext()))
return true;
// `type` must not have more static information than `subtype`, and `type`
// must not disagree with `subtype`. if (subtype.isa<ValueTensorType>() && type.isa<ValueTensorType>() &&
if (typeTensorType.hasDtype() && type == ValueTensorType::getWithLeastStaticInformation(type.getContext()))
(!subtypeTensorType.hasDtype() ||
typeTensorType.getDtype() != subtypeTensorType.getDtype())) {
return false;
}
if (typeTensorType.hasSizes() &&
(!subtypeTensorType.hasSizes() ||
typeTensorType.getSizes() != subtypeTensorType.getSizes())) {
return false;
}
return true; return true;
}
return false; return false;
} }
@ -479,7 +463,7 @@ Type Torch::meetTensorTypes(BaseTensorType lhs, BaseTensorType rhs) {
} }
} }
return lhs.getWithSizesAndDtype(ArrayRef(newSizes), dtype); return lhs.getWithSizesAndDtype(makeArrayRef(newSizes), dtype);
} }
////===----------------------------------------------------------------------===// ////===----------------------------------------------------------------------===//

View File

@ -4088,259 +4088,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" } : (!torch.int, !torch.bool) -> ()\n" " } : (!torch.int, !torch.bool) -> ()\n"
" return %none : !torch.none\n" " return %none : !torch.none\n"
" }\n" " }\n"
" func.func @__torch__.torch.jit._shape_functions.stack(%arg0: !torch.list<list<int>>, %arg1: !torch.int) -> !torch.list<int> {\n"
" %false = torch.constant.bool false\n"
" %str = torch.constant.str \"AssertionError: Tensors must have same number of dimensions\"\n"
" %str_0 = torch.constant.str \"AssertionError: Sizes of tensors must match except in dimension\"\n"
" %int1 = torch.constant.int 1\n"
" %int0 = torch.constant.int 0\n"
" %str_1 = torch.constant.str \"AssertionError: \"\n"
" %none = torch.constant.none\n"
" %true = torch.constant.bool true\n"
" %0 = torch.prim.ListConstruct : () -> !torch.list<list<int>>\n"
" %1 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int\n"
" torch.prim.Loop %1, %true, init() {\n"
" ^bb0(%arg2: !torch.int):\n"
" %16 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
" %17 = torch.aten.len.t %16 : !torch.list<int> -> !torch.int\n"
" %18 = torch.aten.add.int %17, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %19 = torch.aten.le.int %18, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %20 = torch.prim.If %19 -> (!torch.int) {\n"
" torch.prim.If.yield %int1 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %18 : !torch.int\n"
" }\n"
" %21 = torch.aten.neg.int %20 : !torch.int -> !torch.int\n"
" %22 = torch.aten.sub.int %20, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %23 = torch.aten.lt.int %arg1, %21 : !torch.int, !torch.int -> !torch.bool\n"
" %24 = torch.prim.If %23 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %31 = torch.aten.gt.int %arg1, %22 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %31 : !torch.bool\n"
" }\n"
" %25 = torch.aten.__not__ %24 : !torch.bool -> !torch.bool\n"
" torch.prim.If %25 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %26 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %27 = torch.prim.If %26 -> (!torch.int) {\n"
" %31 = torch.aten.add.int %arg1, %20 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.If.yield %31 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %arg1 : !torch.int\n"
" }\n"
" %28 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %29 = torch.aten.len.t %16 : !torch.list<int> -> !torch.int\n"
" torch.prim.Loop %29, %true, init() {\n"
" ^bb0(%arg3: !torch.int):\n"
" %31 = torch.aten.__getitem__.t %16, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %32 = torch.aten.append.t %28, %31 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" torch.aten.insert.t %28, %27, %int1 : !torch.list<int>, !torch.int, !torch.int\n"
" %30 = torch.aten.append.t %0, %28 : !torch.list<list<int>>, !torch.list<int> -> !torch.list<list<int>>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %2 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>\n"
" %3 = torch.aten.len.t %0 : !torch.list<list<int>> -> !torch.int\n"
" torch.prim.Loop %3, %true, init() {\n"
" ^bb0(%arg2: !torch.int):\n"
" %16 = torch.aten.__getitem__.t %0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
" %17 = torch.aten.len.t %16 : !torch.list<int> -> !torch.int\n"
" %18 = torch.aten.gt.int %17, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %18 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %4 = torch.aten.len.t %0 : !torch.list<list<int>> -> !torch.int\n"
" %5 = torch.derefine %none : !torch.none to !torch.optional<int>\n"
" %6 = torch.prim.Loop %4, %true, init(%5) {\n"
" ^bb0(%arg2: !torch.int, %arg3: !torch.optional<int>):\n"
" %16 = torch.aten.__getitem__.t %0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
" %17 = torch.aten.len.t %16 : !torch.list<int> -> !torch.int\n"
" %18 = torch.aten.eq.int %17, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" %19 = torch.prim.If %18 -> (!torch.bool) {\n"
" %22 = torch.aten.__getitem__.t %16, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %23 = torch.aten.eq.int %22, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %23 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %20 = torch.aten.__not__ %19 : !torch.bool -> !torch.bool\n"
" %21 = torch.prim.If %20 -> (!torch.optional<int>) {\n"
" %22 = torch.aten.__is__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %23 = torch.prim.If %22 -> (!torch.int) {\n"
" %25 = torch.aten.len.t %16 : !torch.list<int> -> !torch.int\n"
" %26 = torch.aten.le.int %25, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %27 = torch.prim.If %26 -> (!torch.int) {\n"
" torch.prim.If.yield %int1 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %25 : !torch.int\n"
" }\n"
" %28 = torch.aten.neg.int %27 : !torch.int -> !torch.int\n"
" %29 = torch.aten.sub.int %27, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %30 = torch.aten.lt.int %arg1, %28 : !torch.int, !torch.int -> !torch.bool\n"
" %31 = torch.prim.If %30 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %35 = torch.aten.gt.int %arg1, %29 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %35 : !torch.bool\n"
" }\n"
" %32 = torch.aten.__not__ %31 : !torch.bool -> !torch.bool\n"
" torch.prim.If %32 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %33 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %34 = torch.prim.If %33 -> (!torch.int) {\n"
" %35 = torch.aten.add.int %arg1, %27 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.If.yield %35 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %arg1 : !torch.int\n"
" }\n"
" torch.prim.If.yield %34 : !torch.int\n"
" } else {\n"
" %25 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int\n"
" torch.prim.If.yield %25 : !torch.int\n"
" }\n"
" %24 = torch.derefine %23 : !torch.int to !torch.optional<int>\n"
" torch.prim.If.yield %24 : !torch.optional<int>\n"
" } else {\n"
" torch.prim.If.yield %arg3 : !torch.optional<int>\n"
" }\n"
" torch.prim.Loop.condition %true, iter(%21 : !torch.optional<int>)\n"
" } : (!torch.int, !torch.bool, !torch.optional<int>) -> !torch.optional<int>\n"
" %7 = torch.aten.__is__ %6, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %8 = torch.prim.If %7 -> (!torch.int) {\n"
" torch.prim.If.yield %arg1 : !torch.int\n"
" } else {\n"
" %16 = torch.prim.unchecked_cast %6 : !torch.optional<int> -> !torch.int\n"
" torch.prim.If.yield %16 : !torch.int\n"
" }\n"
" %9 = torch.aten.len.t %0 : !torch.list<list<int>> -> !torch.int\n"
" %10 = torch.aten.gt.int %9, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %10 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %11 = torch.aten.len.t %0 : !torch.list<list<int>> -> !torch.int\n"
" %12 = torch.derefine %none : !torch.none to !torch.optional<list<int>>\n"
" %13 = torch.prim.Loop %11, %true, init(%12) {\n"
" ^bb0(%arg2: !torch.int, %arg3: !torch.optional<list<int>>):\n"
" %16 = torch.aten.__getitem__.t %0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
" %17 = torch.aten.len.t %16 : !torch.list<int> -> !torch.int\n"
" %18 = torch.prim.Loop %17, %true, init(%int1) {\n"
" ^bb0(%arg4: !torch.int, %arg5: !torch.int):\n"
" %23 = torch.aten.__getitem__.t %16, %arg4 : !torch.list<int>, !torch.int -> !torch.int\n"
" %24 = torch.aten.mul.int %arg5, %23 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.Loop.condition %true, iter(%24 : !torch.int)\n"
" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n"
" %19 = torch.aten.eq.int %18, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %20 = torch.prim.If %19 -> (!torch.bool) {\n"
" %23 = torch.aten.len.t %16 : !torch.list<int> -> !torch.int\n"
" %24 = torch.aten.eq.int %23, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %24 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %21 = torch.aten.__not__ %20 : !torch.bool -> !torch.bool\n"
" %22 = torch.prim.If %21 -> (!torch.optional<list<int>>) {\n"
" %23 = torch.derefine %16 : !torch.list<int> to !torch.optional<list<int>>\n"
" torch.prim.If.yield %23 : !torch.optional<list<int>>\n"
" } else {\n"
" torch.prim.If.yield %arg3 : !torch.optional<list<int>>\n"
" }\n"
" torch.prim.Loop.condition %true, iter(%22 : !torch.optional<list<int>>)\n"
" } : (!torch.int, !torch.bool, !torch.optional<list<int>>) -> !torch.optional<list<int>>\n"
" %14 = torch.aten.__is__ %13, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
" %15 = torch.prim.If %14 -> (!torch.list<int>) {\n"
" torch.prim.If.yield %2 : !torch.list<int>\n"
" } else {\n"
" %16 = torch.prim.unchecked_cast %13 : !torch.optional<list<int>> -> !torch.list<int>\n"
" %17 = torch.aten.len.t %0 : !torch.list<list<int>> -> !torch.int\n"
" %18 = torch.prim.Loop %17, %true, init(%int0) {\n"
" ^bb0(%arg2: !torch.int, %arg3: !torch.int):\n"
" %22 = torch.aten.__getitem__.t %0, %arg2 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
" %23 = torch.aten.len.t %22 : !torch.list<int> -> !torch.int\n"
" %24 = torch.prim.Loop %23, %true, init(%int1) {\n"
" ^bb0(%arg4: !torch.int, %arg5: !torch.int):\n"
" %29 = torch.aten.__getitem__.t %22, %arg4 : !torch.list<int>, !torch.int -> !torch.int\n"
" %30 = torch.aten.mul.int %arg5, %29 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.Loop.condition %true, iter(%30 : !torch.int)\n"
" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n"
" %25 = torch.aten.eq.int %24, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %26 = torch.prim.If %25 -> (!torch.bool) {\n"
" %29 = torch.aten.len.t %22 : !torch.list<int> -> !torch.int\n"
" %30 = torch.aten.eq.int %29, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %30 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %27 = torch.aten.__not__ %26 : !torch.bool -> !torch.bool\n"
" %28 = torch.prim.If %27 -> (!torch.int) {\n"
" %29 = torch.aten.len.t %16 : !torch.list<int> -> !torch.int\n"
" %30 = torch.aten.len.t %22 : !torch.list<int> -> !torch.int\n"
" %31 = torch.aten.eq.int %29, %30 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %31 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %32 = torch.aten.__range_length %int0, %29, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
" torch.prim.Loop %32, %true, init() {\n"
" ^bb0(%arg4: !torch.int):\n"
" %35 = torch.aten.__derive_index %arg4, %int0, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
" %36 = torch.aten.ne.int %35, %8 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %36 -> () {\n"
" %37 = torch.aten.__getitem__.t %16, %35 : !torch.list<int>, !torch.int -> !torch.int\n"
" %38 = torch.aten.__getitem__.t %22, %35 : !torch.list<int>, !torch.int -> !torch.int\n"
" %39 = torch.aten.eq.int %37, %38 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %39 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.If.yield\n"
" }\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %33 = torch.aten.__getitem__.t %22, %8 : !torch.list<int>, !torch.int -> !torch.int\n"
" %34 = torch.aten.add.int %arg3, %33 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.If.yield %34 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %arg3 : !torch.int\n"
" }\n"
" torch.prim.Loop.condition %true, iter(%28 : !torch.int)\n"
" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n"
" %19 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %20 = torch.aten.len.t %16 : !torch.list<int> -> !torch.int\n"
" torch.prim.Loop %20, %true, init() {\n"
" ^bb0(%arg2: !torch.int):\n"
" %22 = torch.aten.__getitem__.t %16, %arg2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %23 = torch.aten.append.t %19, %22 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %21 = torch.aten._set_item.t %19, %8, %18 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
" torch.prim.If.yield %19 : !torch.list<int>\n"
" }\n"
" return %15 : !torch.list<int>\n"
" }\n"
" func.func @__torch__.torch.jit._shape_functions.permute(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n" " func.func @__torch__.torch.jit._shape_functions.permute(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %int0 = torch.constant.int 0\n" " %int0 = torch.constant.int 0\n"
" %true = torch.constant.bool true\n" " %true = torch.constant.bool true\n"
@ -6043,10 +5790,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.hardtanh_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.float) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.ceil\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.ceil\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
@ -6134,10 +5877,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.bucketize.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.contiguous\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.contiguous\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
@ -6254,7 +5993,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n" " %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.prims.var\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.float, %arg3: !torch.optional<int>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.prims.var\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.int, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
" %none = torch.constant.none\n" " %none = torch.constant.none\n"
" %false = torch.constant.bool false\n" " %false = torch.constant.bool false\n"
" %0 = torch.derefine %none : !torch.none to !torch.any\n" " %0 = torch.derefine %none : !torch.none to !torch.any\n"
@ -6267,13 +6006,13 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n" " %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
" return %1 : !torch.list<int>\n" " return %1 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.var.correction\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<float>, %arg3: !torch.bool) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.var.correction\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<int>, %arg3: !torch.bool) -> !torch.list<int> {\n"
" %none = torch.constant.none\n" " %none = torch.constant.none\n"
" %0 = torch.derefine %none : !torch.none to !torch.any\n" " %0 = torch.derefine %none : !torch.none to !torch.any\n"
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n" " %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
" return %1 : !torch.list<int>\n" " return %1 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.var_mean.correction\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<float>, %arg3: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n" " func.func @\"__torch_mlir_shape_fn.aten.var_mean.correction\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<int>, %arg3: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
" %none = torch.constant.none\n" " %none = torch.constant.none\n"
" %0 = torch.derefine %none : !torch.none to !torch.any\n" " %0 = torch.derefine %none : !torch.none to !torch.any\n"
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n" " %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
@ -6296,7 +6035,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n" " %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
" return %1 : !torch.list<int>\n" " return %1 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.std.correction\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<float>, %arg3: !torch.bool) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.std.correction\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<int>, %arg3: !torch.bool) -> !torch.list<int> {\n"
" %none = torch.constant.none\n" " %none = torch.constant.none\n"
" %0 = torch.derefine %none : !torch.none to !torch.any\n" " %0 = torch.derefine %none : !torch.none to !torch.any\n"
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n" " %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
@ -6810,9 +6549,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.new_empty\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.new_empty\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.list<int> {\n"
" return %arg1 : !torch.list<int>\n" " return %arg1 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.new_empty_strided\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
" return %arg1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten._to_copy\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten._to_copy\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
@ -6847,9 +6583,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.bernoulli.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.any) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.bernoulli.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.any) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n" " return %arg0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.bernoulli.p\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.any) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten._index_put_impl\"(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>, %arg2: !torch.list<int>, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten._index_put_impl\"(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>, %arg2: !torch.list<int>, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
@ -6863,9 +6596,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.rand_like\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.rand_like\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n" " return %arg0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.randn_like\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.randint.low\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.randint.low\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
" return %arg2 : !torch.list<int>\n" " return %arg2 : !torch.list<int>\n"
" }\n" " }\n"
@ -7151,9 +6881,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.select_scatter\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.select_scatter\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n" " return %arg0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.scatter_reduce.two\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.str, %arg5: !torch.bool) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.index_select\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.index_select\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.index_select(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.index_select(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
@ -7583,10 +7310,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.cat(%arg0, %arg1) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.cat(%arg0, %arg1) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.stack\"(%arg0: !torch.list<list<int>>, %arg1: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.stack(%arg0, %arg1) : (!torch.list<list<int>>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.fft_fft\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.fft_fft\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.int, %arg3: !torch.optional<str>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n" " return %arg0 : !torch.list<int>\n"
" }\n" " }\n"
@ -7617,13 +7340,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n" " %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
" return %2 : !torch.list<int>\n" " return %2 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.norm.ScalarOpt_dim\"(%arg0: !torch.list<int>, %arg1: !torch.optional<float>, %arg2: !torch.list<int>, %arg3: !torch.bool) -> !torch.list<int> {\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.derefine %arg2 : !torch.list<int> to !torch.optional<list<int>>\n"
" %1 = torch.derefine %int0 : !torch.int to !torch.any\n"
" %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg3, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<float>, %arg3: !torch.optional<float>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<float>, %arg3: !torch.optional<float>) -> !torch.list<int> {\n"
" %int0 = torch.constant.int 0\n" " %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n" " %int1 = torch.constant.int 1\n"
@ -8139,30 +7855,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n" " }\n"
" return %int11 : !torch.int\n" " return %int11 : !torch.int\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.ge.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: `self` cannot be complex\"\n"
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n"
" torch.prim.If %5 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.le.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.le.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n" " %int11 = torch.constant.int 11\n"
" %none = torch.constant.none\n" " %none = torch.constant.none\n"
@ -8233,30 +7925,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n" " }\n"
" return %int11 : !torch.int\n" " return %int11 : !torch.int\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.le.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: `self` cannot be complex\"\n"
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n"
" torch.prim.If %5 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.ne.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.ne.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n" " %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n" " return %int11 : !torch.int\n"
@ -8976,6 +8644,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %none = torch.constant.none\n" " %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n" " %str = torch.constant.str \"AssertionError: \"\n"
" %false = torch.constant.bool false\n" " %false = torch.constant.bool false\n"
" %int15 = torch.constant.int 15\n"
" %int5 = torch.constant.int 5\n" " %int5 = torch.constant.int 5\n"
" %true = torch.constant.bool true\n" " %true = torch.constant.bool true\n"
" %int4 = torch.constant.int 4\n" " %int4 = torch.constant.int 4\n"
@ -8990,7 +8659,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" torch.prim.If.yield %12 : !torch.bool\n" " torch.prim.If.yield %12 : !torch.bool\n"
" }\n" " }\n"
" %4 = torch.prim.If %3 -> (!torch.bool) {\n" " %4 = torch.prim.If %3 -> (!torch.bool) {\n"
" %11 = torch.prim.ListConstruct %int5 : (!torch.int) -> !torch.list<int>\n" " %11 = torch.prim.ListConstruct %int5, %int15 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %12 = torch.aten.__contains__.int_list %11, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n" " %12 = torch.aten.__contains__.int_list %11, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" " %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n"
" torch.prim.If.yield %13 : !torch.bool\n" " torch.prim.If.yield %13 : !torch.bool\n"
@ -9012,7 +8681,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" torch.prim.If.yield %12 : !torch.bool\n" " torch.prim.If.yield %12 : !torch.bool\n"
" }\n" " }\n"
" %7 = torch.prim.If %6 -> (!torch.bool) {\n" " %7 = torch.prim.If %6 -> (!torch.bool) {\n"
" %11 = torch.prim.ListConstruct %int5 : (!torch.int) -> !torch.list<int>\n" " %11 = torch.prim.ListConstruct %int5, %int15 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %12 = torch.aten.__contains__.int_list %11, %1#1 : !torch.list<int>, !torch.int -> !torch.bool\n" " %12 = torch.aten.__contains__.int_list %11, %1#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" " %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n"
" torch.prim.If.yield %13 : !torch.bool\n" " torch.prim.If.yield %13 : !torch.bool\n"

View File

@ -10,6 +10,7 @@
#include "PassDetail.h" #include "PassDetail.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"

View File

@ -9,7 +9,6 @@ add_mlir_library(TorchMLIRTorchPasses
LowerToBackendContract.cpp LowerToBackendContract.cpp
MaximizeValueSemantics.cpp MaximizeValueSemantics.cpp
PrepareForGlobalizeObjectGraph.cpp PrepareForGlobalizeObjectGraph.cpp
RecomposeComplexOps.cpp
ReduceOpVariants.cpp ReduceOpVariants.cpp
RefinePublicReturn.cpp RefinePublicReturn.cpp
RefineTypes.cpp RefineTypes.cpp

View File

@ -33,11 +33,9 @@ static bool isNoneOrFloatDtype(MLIRContext *context, Value dtype) {
int64_t dtypeInt; int64_t dtypeInt;
if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt))) if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
return false; return false;
FailureOr<Type> resDtype = Type resDtype =
getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt); getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
if (failed(resDtype)) return resDtype.isa<mlir::FloatType>();
return false;
return resDtype->isa<mlir::FloatType>();
} }
// Helper function to compute the return type of the reduction function. // Helper function to compute the return type of the reduction function.
@ -72,7 +70,7 @@ static Type computeReductionType(PatternRewriter &rewriter, Operation *op,
Type resultType = tensorType.getWithSizesAndDtype( Type resultType = tensorType.getWithSizesAndDtype(
sizes.size() == 0 ? std::optional<ArrayRef<int64_t>>() sizes.size() == 0 ? std::optional<ArrayRef<int64_t>>()
: llvm::ArrayRef(sizes), : llvm::makeArrayRef(sizes),
tensorType.getOptionalDtype()); tensorType.getOptionalDtype());
return resultType; return resultType;
} }
@ -108,7 +106,7 @@ static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc,
valueType valueType
.getWithSizesAndDtype( .getWithSizesAndDtype(
!valueType.hasSizes() ? std::optional<ArrayRef<int64_t>>() !valueType.hasSizes() ? std::optional<ArrayRef<int64_t>>()
: llvm::ArrayRef(valueType.getSizes()), : llvm::makeArrayRef(valueType.getSizes()),
IntegerType::get(op->getContext(), 64, IntegerType::Signed)) IntegerType::get(op->getContext(), 64, IntegerType::Signed))
.cast<BaseTensorType>(); .cast<BaseTensorType>();
return rewriter return rewriter
@ -142,7 +140,7 @@ static Value createRank0Tensor(PatternRewriter &rewriter, Location loc,
BaseTensorType inputType, Value scalar) { BaseTensorType inputType, Value scalar) {
SmallVector<int64_t> sizes; SmallVector<int64_t> sizes;
Type rank0TensorTy = inputType.getWithSizesAndDtype( Type rank0TensorTy = inputType.getWithSizesAndDtype(
ArrayRef(sizes), inputType.getOptionalDtype()); makeArrayRef(sizes), inputType.getOptionalDtype());
Value dimList = rewriter.create<PrimListConstructOp>( Value dimList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(inputType.getContext())), loc, Torch::ListType::get(Torch::IntType::get(inputType.getContext())),
ValueRange{}); ValueRange{});
@ -171,37 +169,6 @@ static Value createSoftmaxBackwardCommonKernel(PatternRewriter &rewriter,
return sub; return sub;
} }
// Helper function to unsqueeze the input tensor at given dim.
// Return the unsqueezed tensor or failure.
static FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter,
Operation *op, Value input, Value dim) {
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
if (!inputType.hasSizes()) {
return rewriter.notifyMatchFailure(op, "input tensor must have size");
}
SmallVector<int64_t> unsqueezedShape;
ArrayRef<int64_t> inputShape = inputType.getSizes();
// `input` has a reduced rank. Hence add 1.
int64_t unsqueezedRank = inputShape.size() + 1;
int64_t dimInt = 0;
if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
dimInt = toPositiveDim(dimInt, unsqueezedRank);
if (!isValidDim(dimInt, unsqueezedRank)) {
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
}
unsqueezedShape.append(inputShape.begin(), inputShape.end());
unsqueezedShape.insert(unsqueezedShape.begin() + dimInt, 1);
} else {
unsqueezedShape.resize(unsqueezedRank, kUnknownSize);
}
Type unsqueezedType = inputType.getWithSizesAndDtype(
unsqueezedShape, inputType.getOptionalDtype());
Value unsqueezed = rewriter.create<AtenUnsqueezeOp>(
op->getLoc(), unsqueezedType, input, dim);
return unsqueezed;
}
namespace { namespace {
/// We decompose aten.amax into a set of aten.max.dim op(s) depending on the /// We decompose aten.amax into a set of aten.max.dim op(s) depending on the
/// number of dimensions across which the max needs to be computed. /// number of dimensions across which the max needs to be computed.
@ -291,15 +258,6 @@ public:
Value dim = op.getDim(); Value dim = op.getDim();
Value self = op.getSelf(); Value self = op.getSelf();
// convert `start` to non-negative: start += int(start < 0) * dimSize
Value zero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value isNegative = rewriter.create<AtenLtIntOp>(loc, start, zero);
isNegative = rewriter.create<AtenIntBoolOp>(loc, isNegative);
Value dimSize = rewriter.create<AtenSizeIntOp>(loc, self, dim);
Value indexOffset = rewriter.create<AtenMulIntOp>(loc, isNegative, dimSize);
start = rewriter.create<AtenAddIntOp>(loc, start, indexOffset);
Value one = Value one =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1)); rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value startPlusOne = Value startPlusOne =
@ -637,128 +595,6 @@ public:
}; };
} // namespace } // namespace
// Decompose `aten.bucketize` into the following op sequence:
//
// def aten_bucketize(input, boundaries, out_int32, right):
// unsqz_input = input.unsqueeze(-1)
// if not right:
// comparison = unsqz_input <= boundaries
// else:
// comparison = unsqz_input < boundaries
// indices = torch.argmax(comparison.float(), dim=-1)
// within_bound = comparison[..., -1]
// result = torch.where(within_bound, indices, boundaries.shape[0])
// if out_int32:
// result = result.int()
// return result
//
namespace {
class DecomposeAtenBucketizeTensorOp
: public OpRewritePattern<AtenBucketizeTensorOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenBucketizeTensorOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = op.getSelf();
auto inputType = input.getType().cast<BaseTensorType>();
if (!inputType.hasSizes()) {
return rewriter.notifyMatchFailure(
op, "unimplemented: input must have known sizes");
}
ArrayRef<int64_t> inputShape = inputType.getSizes();
Value boundaries = op.getBoundaries();
auto boundariesType = boundaries.getType().cast<BaseTensorType>();
if (!boundariesType.hasSizes() || boundariesType.getSizes().size() != 1) {
return rewriter.notifyMatchFailure(op,
"unimplemented: boundaries must have "
"known sizes and must be a 1D array");
}
int64_t boundariesSize = boundariesType.getSizes()[0];
bool outInt32;
if (!matchPattern(op.getOutInt32(), m_TorchConstantBool(&outInt32))) {
return rewriter.notifyMatchFailure(
op, "unimplemented: out_int32 must be a constant bool");
}
bool right;
if (!matchPattern(op.getRight(), m_TorchConstantBool(&right))) {
return rewriter.notifyMatchFailure(
op, "unimplemented: right must be a constant bool");
}
// unsqueeze input at the last dim to make it broadcastable with boundaries
Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(-1));
auto unsqzTensorInfo =
unsqueezeTensor(rewriter, op, input, /*dim=*/constMinusOne);
if (failed(unsqzTensorInfo)) {
return rewriter.notifyMatchFailure(op,
"cannot generate unsqueeze tensor");
}
Value unsqzInput = *unsqzTensorInfo;
// compare unsqueezed input with boundaries
SmallVector<int64_t> compareShape(inputShape);
compareShape.push_back(boundariesSize);
Type compareType =
inputType.getWithSizesAndDtype(compareShape, rewriter.getI1Type());
Value compare;
if (!right) {
compare = rewriter.create<AtenLeTensorOp>(loc, compareType, unsqzInput,
boundaries);
} else {
compare = rewriter.create<AtenLtTensorOp>(loc, compareType, unsqzInput,
boundaries);
}
// convert the comparison results to float32 as the argmax op input,
// which does not support integer dtype in LINALG backend
Value compareF32 =
convertTensorToDtype(rewriter, loc, compare, rewriter.getF32Type());
// get the first boundary index where the input element is less than (or
// equal to) the boundary value
Type indicesType = inputType.getWithSizesAndDtype(
inputShape, rewriter.getIntegerType(64, IntegerType::Signed));
Value constFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
Value indices = rewriter.create<AtenArgmaxOp>(loc, indicesType, compareF32,
/*dim=*/constMinusOne,
/*keepdim=*/constFalse);
// get the comparison results between each input element and the rightmost
// boundary value
Type withinUpperBoundType =
inputType.getWithSizesAndDtype(inputShape, rewriter.getI1Type());
Value withinUpperBound = rewriter.create<AtenSelectIntOp>(
loc, withinUpperBoundType, compare, /*dim=*/constMinusOne,
/*index=*/constMinusOne);
// If the input element is less than (or equal to) the rightmost boundary,
// take the max index as result. Otherwise, the element is beyond the
// rightmost boundary, so take the boundary size.
Value constZero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value upperBound =
rewriter.create<AtenSizeIntOp>(loc, boundaries, /*dim=*/constZero);
Value result = rewriter.create<AtenWhereScalarOtherOp>(
loc, indicesType, withinUpperBound, indices, upperBound);
if (outInt32) {
result = convertTensorToDtype(
rewriter, loc, result,
rewriter.getIntegerType(32, IntegerType::Signed));
}
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace
// To avoid overflow we use the following decomposition rule: // To avoid overflow we use the following decomposition rule:
// x_max = aten.max(x, dim, keepdim=True)[0] // x_max = aten.max(x, dim, keepdim=True)[0]
// shifted = x - x_max // shifted = x - x_max
@ -1055,50 +891,6 @@ public:
}; };
} // namespace } // namespace
// Decompose `aten.stack` into `aten.unsqueeze` and `aten.cat`.
namespace {
class DecomposeAtenStackOp : public OpRewritePattern<AtenStackOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenStackOp op,
                                PatternRewriter &rewriter) const override {
    // The tensor list operand must come from a `prim.ListConstruct` so the
    // individual tensors are visible for rewriting.
    SmallVector<Value> inputs;
    if (!getListConstructElements(op.getTensors(), inputs))
      return rewriter.notifyMatchFailure(
          op, "unimplemented: the tensor list is not from list construct");
    // Each stacked tensor needs a known sizes list so it can be unsqueezed.
    for (Value input : inputs) {
      if (!input.getType().cast<BaseTensorType>().hasSizes())
        return rewriter.notifyMatchFailure(
            op, "unimplemented: one tensor does not have known sizes");
    }
    // Insert a size-1 dimension at `dim` on every tensor; concatenating
    // along that new dimension reproduces the stack semantics.
    SmallVector<Value> expanded;
    expanded.reserve(inputs.size());
    for (Value input : inputs) {
      auto maybeUnsqueezed = unsqueezeTensor(rewriter, op, input, op.getDim());
      if (failed(maybeUnsqueezed))
        return rewriter.notifyMatchFailure(
            op, "cannot generate unsqueeze tensor op");
      expanded.push_back(*maybeUnsqueezed);
    }
    // The cat operand list element type carries neither sizes nor dtype.
    Type elemTy = op.getType().cast<BaseTensorType>().getWithSizesAndDtype(
        /*optionalSizes=*/std::nullopt, /*optionalDtype=*/nullptr);
    Value catList = rewriter.create<PrimListConstructOp>(
        op.getLoc(), Torch::ListType::get(elemTy), expanded);
    rewriter.replaceOpWithNewOp<AtenCatOp>(op, op.getType(), catList,
                                           op.getDim());
    return success();
  }
};
} // namespace
// Decompose aten.roll into aten.slice and aten.cat ops. // Decompose aten.roll into aten.slice and aten.cat ops.
// https://pytorch.org/docs/stable/generated/torch.roll.html // https://pytorch.org/docs/stable/generated/torch.roll.html
namespace { namespace {
@ -1137,7 +929,7 @@ public:
SmallVector<int64_t> sizes; SmallVector<int64_t> sizes;
sizes.append(inputShape.begin(), inputShape.end()); sizes.append(inputShape.begin(), inputShape.end());
sizes[cstDim] = kUnknownSize; sizes[cstDim] = kUnknownSize;
Type sliceTy = selfTy.getWithSizesAndDtype(llvm::ArrayRef(sizes), Type sliceTy = selfTy.getWithSizesAndDtype(llvm::makeArrayRef(sizes),
selfTy.getOptionalDtype()); selfTy.getOptionalDtype());
Value slice0 = rewriter.create<AtenSliceTensorOp>( Value slice0 = rewriter.create<AtenSliceTensorOp>(
loc, sliceTy, input, dim, negShift, constNone, constOne); loc, sliceTy, input, dim, negShift, constNone, constOne);
@ -1274,9 +1066,9 @@ public:
Type dtype = self.getType().cast<ValueTensorType>().getOptionalDtype(); Type dtype = self.getType().cast<ValueTensorType>().getOptionalDtype();
Type unsqueezedType = ValueTensorType::get( Type unsqueezedType = ValueTensorType::get(
context, llvm::ArrayRef(unsqueezedIntSizes), dtype); context, llvm::makeArrayRef(unsqueezedIntSizes), dtype);
Type expandedType = Type expandedType = ValueTensorType::get(
ValueTensorType::get(context, llvm::ArrayRef(expandedIntSizes), dtype); context, llvm::makeArrayRef(expandedIntSizes), dtype);
auto listType = Torch::ListType::get(Torch::IntType::get(op.getContext())); auto listType = Torch::ListType::get(Torch::IntType::get(op.getContext()));
Value unsqueezedDims = Value unsqueezedDims =
@ -1434,25 +1226,6 @@ public:
}; };
} // namespace } // namespace
// Decompose aten.masked_fill.Scalar into aten.where.self op.
namespace {
class DecomposeAtenMaskedFillScalarOp
    : public OpRewritePattern<AtenMaskedFillScalarOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenMaskedFillScalarOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto resultTy = op.getType().cast<BaseTensorType>();
    // Materialize the scalar fill value as a rank-0 tensor so it can serve
    // as the "true" branch of a where op: masked positions take the fill
    // value, the rest keep the original elements of `self`.
    Value fillTensor = createRank0Tensor(rewriter, loc, resultTy, op.getValue());
    rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resultTy, op.getMask(),
                                                 fillTensor, op.getSelf());
    return success();
  }
};
} // namespace
// Decompose aten.convolution_overrideable to aten.convolution op. // Decompose aten.convolution_overrideable to aten.convolution op.
namespace { namespace {
class DecomposeAtenConvolutionOverrideableOp class DecomposeAtenConvolutionOverrideableOp
@ -2204,23 +1977,23 @@ public:
// aten.bernoulli.float(x, p) = (randLike(float(x)) < tensor(p)).cast(type(x)). // aten.bernoulli.float(x, p) = (randLike(float(x)) < tensor(p)).cast(type(x)).
// Since the input x can be an integer tensor, it's important to cast it to // Since the input x can be an integer tensor, it's important to cast it to
// float type before passing it to the `aten.randLike` op. // float type before passing it to the `aten.randLike` op.
template <typename BernoulliLikeOp> class DecomposeValsemVariantAtenBernoulliFloatOp
class DecomposeAtenBernoulliLikeOp : public OpRewritePattern<BernoulliLikeOp> { : public OpRewritePattern<ValsemVariantAtenBernoulliFloatOp> {
public: public:
using OpRewritePattern<BernoulliLikeOp>::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(BernoulliLikeOp op, LogicalResult matchAndRewrite(ValsemVariantAtenBernoulliFloatOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
Location loc = op.getLoc(); Location loc = op.getLoc();
Value input = op.getSelf(); Value input = op.getSelf();
Value p = op.getP(); Value p = op.getP();
if (!op.getGenerator().getType().template isa<Torch::NoneType>()) if (!op.getGenerator().getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "The generator has to ben None because only global default " op, "The generator has to ben None because only global default "
"generator is supported"); "generator is supported");
auto inputType = input.getType().cast<BaseTensorType>(); auto inputType = input.getType().cast<BaseTensorType>();
SmallVector<int64_t> empty; SmallVector<int64_t> empty;
Type tensorType = inputType.getWithSizesAndDtype(llvm::ArrayRef(empty), Type tensorType = inputType.getWithSizesAndDtype(llvm::makeArrayRef(empty),
rewriter.getF64Type()); rewriter.getF64Type());
Value prob = rewriter.create<PrimNumToTensorScalarOp>(loc, tensorType, p); Value prob = rewriter.create<PrimNumToTensorScalarOp>(loc, tensorType, p);
Value output; Value output;
@ -2298,8 +2071,8 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
std::vector<int64_t> meanVarSizes(inputRank, 1); std::vector<int64_t> meanVarSizes(inputRank, 1);
for (int i = 0; i < axis; i++) for (int i = 0; i < axis; i++)
meanVarSizes[i] = input.getSizes()[i]; meanVarSizes[i] = input.getSizes()[i];
auto meanVarType = input.getWithSizesAndDtype(llvm::ArrayRef(meanVarSizes), auto meanVarType = input.getWithSizesAndDtype(
input.getOptionalDtype()); llvm::makeArrayRef(meanVarSizes), input.getOptionalDtype());
auto nativeLayerNorm = rewriter.create<AtenNativeLayerNormOp>( auto nativeLayerNorm = rewriter.create<AtenNativeLayerNormOp>(
loc, op.getType(), meanVarType, meanVarType, op.getInput(), loc, op.getType(), meanVarType, meanVarType, op.getInput(),
op.getNormalizedShape(), op.getWeight(), op.getBias(), op.getEps()); op.getNormalizedShape(), op.getWeight(), op.getBias(), op.getEps());
@ -2536,7 +2309,7 @@ class DecomposeAtenNativeBatchNormOp
runningStatsShapeInt[1] = kUnknownSize; runningStatsShapeInt[1] = kUnknownSize;
Type dtype = input.getType().cast<ValueTensorType>().getOptionalDtype(); Type dtype = input.getType().cast<ValueTensorType>().getOptionalDtype();
Type reshapeType = ValueTensorType::get( Type reshapeType = ValueTensorType::get(
context, llvm::ArrayRef(runningStatsShapeInt), dtype); context, llvm::makeArrayRef(runningStatsShapeInt), dtype);
runningMean = rewriter.create<AtenViewOp>(loc, reshapeType, runningMean, runningMean = rewriter.create<AtenViewOp>(loc, reshapeType, runningMean,
runningStatsSizeList); runningStatsSizeList);
@ -2682,7 +2455,8 @@ public:
SmallVector<int64_t> empty; SmallVector<int64_t> empty;
auto dtype = auto dtype =
getTypeForTorchType(op.getContext(), op.getFillValue().getType()); getTypeForTorchType(op.getContext(), op.getFillValue().getType());
Type tensorType = outTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype); Type tensorType =
outTy.getWithSizesAndDtype(llvm::makeArrayRef(empty), dtype);
Value fillVal = rewriter.create<PrimNumToTensorScalarOp>(loc, tensorType, Value fillVal = rewriter.create<PrimNumToTensorScalarOp>(loc, tensorType,
op.getFillValue()); op.getFillValue());
fillVal = convertTensorToDtype(rewriter, loc, fillVal, outTy.getDtype()); fillVal = convertTensorToDtype(rewriter, loc, fillVal, outTy.getDtype());
@ -2718,7 +2492,7 @@ public:
SmallVector<int64_t> transposeShape = SmallVector<int64_t> transposeShape =
llvm::to_vector(llvm::reverse(weightType.getSizes())); llvm::to_vector(llvm::reverse(weightType.getSizes()));
Type transposeType = weightType.getWithSizesAndDtype( Type transposeType = weightType.getWithSizesAndDtype(
llvm::ArrayRef(transposeShape), weightType.getOptionalDtype()); llvm::makeArrayRef(transposeShape), weightType.getOptionalDtype());
Value transposeWeight = Value transposeWeight =
rewriter.create<AtenTOp>(loc, transposeType, weight); rewriter.create<AtenTOp>(loc, transposeType, weight);
@ -2788,7 +2562,8 @@ public:
SmallVector<int64_t> empty; SmallVector<int64_t> empty;
auto dtype = auto dtype =
getTypeForTorchType(op.getContext(), op.getFillValue().getType()); getTypeForTorchType(op.getContext(), op.getFillValue().getType());
Type tensorType = outTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype); Type tensorType =
outTy.getWithSizesAndDtype(llvm::makeArrayRef(empty), dtype);
Value fillVal = rewriter.create<PrimNumToTensorScalarOp>( Value fillVal = rewriter.create<PrimNumToTensorScalarOp>(
op.getLoc(), tensorType, op.getFillValue()); op.getLoc(), tensorType, op.getFillValue());
fillVal = fillVal =
@ -3228,7 +3003,7 @@ class DecomposeAtenNumpyTOp : public OpRewritePattern<AtenNumpyTOp> {
template <typename OpTy> template <typename OpTy>
static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter, static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
bool unbiased, double correction) { bool unbiased, int64_t correction) {
Location loc = op.getLoc(); Location loc = op.getLoc();
Value self = op.getSelf(); Value self = op.getSelf();
Value dimList = op.getDim(); Value dimList = op.getDim();
@ -3314,22 +3089,19 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
productDimSize = productDimSize =
rewriter.create<AtenMulIntOp>(loc, productDimSize, dimSize); rewriter.create<AtenMulIntOp>(loc, productDimSize, dimSize);
} }
productDimSize = rewriter.create<AtenFloatScalarOp>(loc, productDimSize); Value cstCorrection = rewriter.create<Torch::ConstantIntOp>(
constantOne = rewriter.create<Torch::ConstantFloatOp>( loc, rewriter.getI64IntegerAttr(correction));
loc, rewriter.getF64FloatAttr(1.0));
Value cstCorrection = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(correction));
// The `correction` value should be less than or equal to `productDimSize + // The `correction` value should be less than or equal to `productDimSize +
// 1`. // 1`.
Value productDimSizePlusOne = rewriter.create<AtenAddOp>( Value productDimSizePlusOne =
loc, productDimSize.getType(), productDimSize, constantOne); rewriter.create<AtenAddIntOp>(loc, productDimSize, constantOne);
Value cond = Value cond =
rewriter.create<AtenGeFloatOp>(loc, productDimSizePlusOne, cstCorrection); rewriter.create<AtenGeIntOp>(loc, productDimSizePlusOne, cstCorrection);
rewriter.create<RuntimeAssertOp>( rewriter.create<RuntimeAssertOp>(
loc, cond, loc, cond,
"correction value should be less than or equal to productDimSize + 1"); "correction value should be less than or equal to productDimSize + 1");
Value productDimSizeSubCorrection = Value productDimSizeSubCorrection =
rewriter.create<AtenSubFloatOp>(loc, productDimSize, cstCorrection); rewriter.create<AtenSubIntOp>(loc, productDimSize, cstCorrection);
Value result = rewriter.create<AtenDivScalarOp>(loc, newOutputType, squareSum, Value result = rewriter.create<AtenDivScalarOp>(loc, newOutputType, squareSum,
productDimSizeSubCorrection); productDimSizeSubCorrection);
result = result =
@ -3356,7 +3128,7 @@ public:
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Only support constant unbiased for aten.var"); op, "Only support constant unbiased for aten.var");
} }
double correction = unbiased ? 1.0 : 0.0; int64_t correction = unbiased ? 1 : 0;
if (failed(calculateVariance<AtenVarDimOp>(op, rewriter, unbiased, if (failed(calculateVariance<AtenVarDimOp>(op, rewriter, unbiased,
correction))) correction)))
return rewriter.notifyMatchFailure(op, "invalid variance parameters"); return rewriter.notifyMatchFailure(op, "invalid variance parameters");
@ -3376,32 +3148,18 @@ public:
using OpRewritePattern::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenVarCorrectionOp op, LogicalResult matchAndRewrite(AtenVarCorrectionOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
int64_t correctionValInt; int64_t correction;
double correctionValFloat = 1.0;
if (!op.getCorrection().getType().isa<Torch::NoneType>()) { if (!op.getCorrection().getType().isa<Torch::NoneType>()) {
if (op.getCorrection().getType().isa<Torch::FloatType>()) { if (!matchPattern(op.getCorrection(), m_TorchConstantInt(&correction)))
if (!matchPattern(op.getCorrection(),
m_TorchConstantFloat(&correctionValFloat)))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Only support constant int or float correction value for " op, "Only support constant int correction for aten.var");
"aten.var");
} else if (op.getCorrection().getType().isa<Torch::IntType>()) {
if (!matchPattern(op.getCorrection(),
m_TorchConstantInt(&correctionValInt)))
return rewriter.notifyMatchFailure(
op, "Only support constant int or float correction value for "
"aten.var");
correctionValFloat = (double)correctionValInt;
} else { } else {
return rewriter.notifyMatchFailure( // The default value in case of `correction` being None is 1.
op, "unimplemented: correction value should be only constant int " correction = 1;
"or float for aten.var");
} }
} bool unbiased = correction == 0 ? false : true;
bool unbiased = correctionValFloat == 0.0 ? false : true;
if (failed(calculateVariance<AtenVarCorrectionOp>(op, rewriter, unbiased, if (failed(calculateVariance<AtenVarCorrectionOp>(op, rewriter, unbiased,
correctionValFloat))) correction)))
return rewriter.notifyMatchFailure(op, "invalid variance parameters"); return rewriter.notifyMatchFailure(op, "invalid variance parameters");
return success(); return success();
} }
@ -3426,13 +3184,29 @@ public:
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1)); rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value startPlusOne = Value startPlusOne =
rewriter.create<AtenAddIntOp>(loc, one.getType(), start, one); rewriter.create<AtenAddIntOp>(loc, one.getType(), start, one);
BaseTensorType srcTensorType = src.getType().cast<BaseTensorType>();
SmallVector<int64_t> sizes;
if (!srcTensorType.hasSizes())
return rewriter.notifyMatchFailure(op, "src tensor must have size");
auto unsqueezedInfo = unsqueezeTensor(rewriter, op, src, dim); ArrayRef<int64_t> srcShape = srcTensorType.getSizes();
if (failed(unsqueezedInfo)) { // `src` has a reduced rank. Hence add 1.
return rewriter.notifyMatchFailure(op, int64_t srcRank = srcShape.size() + 1;
"cannot generate unsqueeze tensor op"); int64_t dimInt = 0;
if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
dimInt = toPositiveDim(dimInt, srcRank);
if (!isValidDim(dimInt, srcRank))
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
sizes.append(srcShape.begin(), srcShape.end());
sizes.insert(sizes.begin() + dimInt, 1);
} else {
sizes.resize(srcShape.size() + 1, kUnknownSize);
} }
src = *unsqueezedInfo; Type srcType = srcTensorType.getWithSizesAndDtype(
llvm::makeArrayRef(sizes), srcTensorType.getOptionalDtype());
src = rewriter.create<AtenUnsqueezeOp>(loc, srcType, src, dim);
rewriter.replaceOpWithNewOp<AtenSliceScatterOp>( rewriter.replaceOpWithNewOp<AtenSliceScatterOp>(
op, op.getSelf().getType(), self, src, dim, start, startPlusOne, op, op.getSelf().getType(), self, src, dim, start, startPlusOne,
/*step=*/one); /*step=*/one);
@ -3529,7 +3303,7 @@ public:
op, "Expected the input tensor to have sizes"); op, "Expected the input tensor to have sizes");
BaseTensorType subType = BaseTensorType subType =
inputType inputType
.getWithSizesAndDtype(llvm::ArrayRef(inputType.getSizes()), .getWithSizesAndDtype(llvm::makeArrayRef(inputType.getSizes()),
resultType.getOptionalDtype()) resultType.getOptionalDtype())
.cast<BaseTensorType>(); .cast<BaseTensorType>();
@ -3556,29 +3330,6 @@ public:
}; };
} // namespace } // namespace
namespace {
// Decompose `aten.norm.ScalarOpt_dim` op to `aten.linalg_vector_norm` op
class DecomposeAtenNormScalarOptDimOp
    : public OpRewritePattern<AtenNormScalarOptDimOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenNormScalarOptDimOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // A missing `p` operand means the default L2 norm, i.e. ord == 2.0.
    Value ord = op.getP();
    if (ord.getType().isa<Torch::NoneType>())
      ord = rewriter.create<Torch::ConstantFloatOp>(
          loc, rewriter.getF64FloatAttr(2.0));
    Value noneDtype = rewriter.create<Torch::ConstantNoneOp>(loc);
    rewriter.replaceOpWithNewOp<AtenLinalgVectorNormOp>(
        op, op.getType(), op.getSelf(), ord, op.getDim(), op.getKeepdim(),
        /*dtype=*/noneDtype);
    return success();
  }
};
} // namespace
namespace { namespace {
class DecomposeAtenRandintLowOp : public OpRewritePattern<AtenRandintLowOp> { class DecomposeAtenRandintLowOp : public OpRewritePattern<AtenRandintLowOp> {
public: public:
@ -3775,40 +3526,6 @@ public:
}; };
} // namespace } // namespace
namespace {
// Decompose `aten.randn_like` op into `aten.randn.generator` op.
class DecomposeAtenRandnLikeOp : public OpRewritePattern<AtenRandnLikeOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenRandnLikeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Only `none`, `contiguous` and `preserve` memory_format is supported.
    Value memoryFormatValue = op.getMemoryFormat();
    if (!memoryFormatValue.getType().isa<Torch::NoneType>()) {
      int64_t memoryFormat;
      if (!matchPattern(memoryFormatValue, m_TorchConstantInt(&memoryFormat)))
        return rewriter.notifyMatchFailure(
            op, "unimplemented: the memory format should be specified in "
                "an integer constant");
      bool isSupportedFormat =
          memoryFormat == torch_upstream::MemoryFormat::Contiguous ||
          memoryFormat == torch_upstream::MemoryFormat::Preserve;
      if (!isSupportedFormat)
        return rewriter.notifyMatchFailure(
            op, "unimplemented: only none, contiguous and preserve "
                "memory_format is supported");
    }
    // The result shape follows the input tensor's shape.
    auto intListTy = Torch::ListType::get(Torch::IntType::get(op.getContext()));
    Value shape = rewriter.create<AtenSizeOp>(loc, intListTy, op.getSelf());
    Value noGenerator = rewriter.create<Torch::ConstantNoneOp>(loc);
    rewriter.replaceOpWithNewOp<AtenRandnGeneratorOp>(
        op, op.getType(), shape, /*generator=*/noGenerator, op.getDtype(),
        op.getLayout(), op.getDevice(), op.getPinMemory());
    return success();
  }
};
} // namespace
namespace { namespace {
class DecomposeAtenVarMeanOp : public OpRewritePattern<AtenVarMeanOp> { class DecomposeAtenVarMeanOp : public OpRewritePattern<AtenVarMeanOp> {
public: public:
@ -3829,49 +3546,6 @@ public:
}; };
} // namespace } // namespace
namespace {
// Decompose `aten.new_empty_strided` into `aten.new_empty` when the
// requested strides are the default (row-major contiguous) strides for the
// requested sizes. Arbitrary strides are not representable with
// `aten.new_empty`, so those cases are rejected.
class DecomposeAtenNewEmptyStridedOp
    : public OpRewritePattern<AtenNewEmptyStridedOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenNewEmptyStridedOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<int64_t> sizeListInts, strideListInts;
    if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizeListInts)))
      return rewriter.notifyMatchFailure(
          op, "all size list elements must be constant ints");
    if (!matchPattern(op.getStride(),
                      m_TorchListOfConstantInts(strideListInts)))
      return rewriter.notifyMatchFailure(
          op, "all stride list elements must be constant ints");
    // We only support the cases with default stride values.
    // For ex: aten.new_empty_strided(self, size=[2, 3, 4], stride=[12, 4, 1])
    // Here the stride[0] == size[1] * size[2], stride[1] == size[2], and
    // stride[2] == 1.
    //
    // Precompute suffix products of the sizes so each stride is checked in
    // O(1): suffixProducts[i] == product of sizeListInts[i..end], and
    // suffixProducts[numSizes] == 1. This replaces the quadratic pairwise
    // product with a single linear pass over the dimensions.
    size_t numSizes = sizeListInts.size();
    SmallVector<int64_t> suffixProducts(numSizes + 1, 1);
    for (size_t i = numSizes; i > 0; i--)
      suffixProducts[i - 1] = suffixProducts[i] * sizeListInts[i - 1];
    bool isDefaultStride = true;
    for (size_t i = 0; i < strideListInts.size(); i++) {
      // The default stride of dim i is the product of all later sizes; a
      // stride index beyond the size list corresponds to an empty product,
      // i.e. 1 (matching the original pairwise computation).
      int64_t defaultStride = i + 1 < numSizes ? suffixProducts[i + 1] : 1;
      if (defaultStride != strideListInts[i]) {
        isDefaultStride = false;
        break;
      }
    }
    if (!isDefaultStride)
      return rewriter.notifyMatchFailure(
          op, "only default strides supported for new_empty_strided op");
    rewriter.replaceOpWithNewOp<AtenNewEmptyOp>(
        op, op.getType(), op.getSelf(), op.getSize(), op.getDtype(),
        op.getLayout(), op.getDevice(), op.getPinMemory());
    return success();
  }
};
} // namespace
namespace { namespace {
class DecomposeComplexOpsPass class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> { : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@ -3917,7 +3591,6 @@ public:
DecomposeConstantTensorAllocLikeOp<AtenOnesLikeOp, 1>>(patterns); DecomposeConstantTensorAllocLikeOp<AtenOnesLikeOp, 1>>(patterns);
addPatternIfTargetOpIsIllegal< addPatternIfTargetOpIsIllegal<
DecomposeConstantTensorAllocLikeOp<AtenZerosLikeOp, 0>>(patterns); DecomposeConstantTensorAllocLikeOp<AtenZerosLikeOp, 0>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenStackOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRollOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenRollOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRepeatOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenRepeatOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenExpandOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenExpandOp>(patterns);
@ -3925,7 +3598,6 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarOtherOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarOtherOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarSelfOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarSelfOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMaskedFillScalarOp>(patterns);
addPatternIfTargetOpIsIllegal< addPatternIfTargetOpIsIllegal<
DecomposeAtenConvolutionBackwardOverrideableOp>(patterns); DecomposeAtenConvolutionBackwardOverrideableOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSizeOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenSizeOp>(patterns);
@ -3968,11 +3640,8 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAten_UnsafeViewOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAten_UnsafeViewOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_ReshapeAliasOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAten_ReshapeAliasOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBernoulliOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenBernoulliOp>(patterns);
addPatternIfTargetOpIsIllegal< addPatternIfTargetOpIsIllegal<DecomposeValsemVariantAtenBernoulliFloatOp>(
DecomposeAtenBernoulliLikeOp<ValsemVariantAtenBernoulliFloatOp>>(
patterns); patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenBernoulliLikeOp<AtenBernoulliPOp>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBernoulliTensorOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenBernoulliTensorOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenZeroOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenZeroOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRandLikeOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenRandLikeOp>(patterns);
@ -4019,7 +3688,6 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexTensorHackedTwinOp>( addPatternIfTargetOpIsIllegal<DecomposeAtenIndexTensorHackedTwinOp>(
patterns); patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMseLossOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenMseLossOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNormScalarOptDimOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRandintLowOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenRandintLowOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanCorrectionOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanCorrectionOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposePrimsConvertElementTypeOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposePrimsConvertElementTypeOp>(patterns);
@ -4027,12 +3695,9 @@ public:
addPatternIfTargetOpIsIllegal<DecomposePrimsSqrtOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposePrimsSqrtOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenRandnOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnGeneratorOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenRandnGeneratorOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnLikeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluBackwardOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluBackwardOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyStridedOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBucketizeTensorOp>(patterns);
GreedyRewriteConfig config; GreedyRewriteConfig config;
config.useTopDownTraversal = true; config.useTopDownTraversal = true;

View File

@ -10,9 +10,9 @@
#include "PassDetail.h" #include "PassDetail.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"

View File

@ -10,9 +10,9 @@
#include "PassDetail.h" #include "PassDetail.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
@ -244,7 +244,7 @@ createGlobalSlotModuleInitializer(ModuleOp module, SymbolTable &symbolTable,
continue; continue;
opsToMove.push_back(&op); opsToMove.push_back(&op);
} }
IRMapping mapping; BlockAndValueMapping mapping;
for (Operation *op : opsToMove) { for (Operation *op : opsToMove) {
// The ops are used by `torch.slot` ops in the enclosing module. // The ops are used by `torch.slot` ops in the enclosing module.
// Cloning avoids needing to handle those uses specially. // Cloning avoids needing to handle those uses specially.
@ -329,7 +329,7 @@ template <> struct llvm::DenseMapInfo<Monomorphization> {
// currently only analyzes a subset of ops. // currently only analyzes a subset of ops.
static LogicalResult analyzeInstances(func::FuncOp func, static LogicalResult analyzeInstances(func::FuncOp func,
ArrayRef<ArgInstance> argInstances, ArrayRef<ArgInstance> argInstances,
IRMapping &mapping) { BlockAndValueMapping &mapping) {
for (auto &argInstance : argInstances) for (auto &argInstance : argInstances)
mapping.map(func.getArgument(argInstance.argIndex), argInstance.instance); mapping.map(func.getArgument(argInstance.argIndex), argInstance.instance);
auto walkResult = func.walk([&](PrimGetAttrOp op) { auto walkResult = func.walk([&](PrimGetAttrOp op) {
@ -349,7 +349,7 @@ static LogicalResult analyzeInstances(func::FuncOp func,
} }
static FailureOr<Monomorphization> static FailureOr<Monomorphization>
createMonomorphizationForCall(func::CallOp op, IRMapping &mapping, createMonomorphizationForCall(func::CallOp op, BlockAndValueMapping &mapping,
SymbolTable &symbolTable) { SymbolTable &symbolTable) {
auto func = symbolTable.lookup<func::FuncOp>(op.getCallee()); auto func = symbolTable.lookup<func::FuncOp>(op.getCallee());
Monomorphization monomorphization; Monomorphization monomorphization;
@ -410,7 +410,7 @@ public:
private: private:
LogicalResult generateNewMonomorphizations(const Monomorphization &m) { LogicalResult generateNewMonomorphizations(const Monomorphization &m) {
auto func = m.func; auto func = m.func;
IRMapping mapping; BlockAndValueMapping mapping;
if (failed(analyzeInstances(func, m.argInstances, mapping))) if (failed(analyzeInstances(func, m.argInstances, mapping)))
return failure(); return failure();
auto walkResult = func.walk([&](func::CallOp op) { auto walkResult = func.walk([&](func::CallOp op) {
@ -495,7 +495,7 @@ verifyPublicMonomorphizations(ModuleOp module, SymbolTable &symbolTable,
// Rewrite `func`, given that all values of `NnModuleType` have been mapped in // Rewrite `func`, given that all values of `NnModuleType` have been mapped in
// `mapping` to corresponding global instances. // `mapping` to corresponding global instances.
static LogicalResult rewriteMonomorphizedFuncClone( static LogicalResult rewriteMonomorphizedFuncClone(
func::FuncOp func, IRMapping mapping, SymbolTable &symbolTable, func::FuncOp func, BlockAndValueMapping mapping, SymbolTable &symbolTable,
DenseMap<Monomorphization, func::FuncOp> &newFuncs, DenseMap<Monomorphization, func::FuncOp> &newFuncs,
ObjectGraphInfo &objectGraphInfo) { ObjectGraphInfo &objectGraphInfo) {
@ -662,7 +662,7 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) {
} }
for (auto &kv : newFuncs) { for (auto &kv : newFuncs) {
IRMapping mapping; BlockAndValueMapping mapping;
if (failed(analyzeInstances(kv.second, kv.first.argInstances, mapping))) if (failed(analyzeInstances(kv.second, kv.first.argInstances, mapping)))
return failure(); return failure();
if (failed(rewriteMonomorphizedFuncClone(kv.second, mapping, symbolTable, if (failed(rewriteMonomorphizedFuncClone(kv.second, mapping, symbolTable,

View File

@ -27,8 +27,8 @@
#include "mlir/Analysis/DataFlowFramework.h" #include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
@ -373,7 +373,7 @@ class InlineGlobalSlotsPass
// big deal. // big deal.
SmallVector<Operation *> slice = SmallVector<Operation *> slice =
getBackwardSliceIncludingRoot(initialValue); getBackwardSliceIncludingRoot(initialValue);
IRMapping mapping; BlockAndValueMapping mapping;
OpBuilder builder(op); OpBuilder builder(op);
for (Operation *opInSlice : slice) for (Operation *opInSlice : slice)
builder.clone(*opInSlice, mapping); builder.clone(*opInSlice, mapping);

View File

@ -285,16 +285,19 @@ public:
} }
}; };
class VerifyBackendContractNoDecompositionsPass class VerifyBackendContractPass
: public VerifyBackendContractNoDecompositionsBase<VerifyBackendContractNoDecompositionsPass> { : public VerifyBackendContractBase<VerifyBackendContractPass> {
public: public:
VerifyBackendContractNoDecompositionsPass() = default; VerifyBackendContractPass() = default;
VerifyBackendContractPass(bool decompose,
ArrayRef<std::string> backendLegalOps) {
this->decompose = decompose;
this->backendLegalOps = backendLegalOps;
}
void runOnOperation() override { void runOnOperation() override {
MLIRContext *context = &getContext(); MLIRContext *context = &getContext();
ConversionTarget target = ConversionTarget target =
getBackendContractTarget(context, /*decompose*/false, getBackendContractTarget(context, decompose, backendLegalOps);
/*backendLegalOps*/{});
if (!satisfiesBackendContract(getOperation(), target, if (!satisfiesBackendContract(getOperation(), target,
/*actuallyEmitDiagnostics=*/true)) { /*actuallyEmitDiagnostics=*/true)) {
@ -312,8 +315,10 @@ mlir::torch::Torch::createLowerToBackendContractPass(
} }
std::unique_ptr<OperationPass<ModuleOp>> std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::Torch::createVerifyBackendContractNoDecompositionsPass() { mlir::torch::Torch::createVerifyBackendContractPass(
return std::make_unique<VerifyBackendContractNoDecompositionsPass>(); bool decompose, ArrayRef<std::string> backendLegalOps) {
return std::make_unique<VerifyBackendContractPass>(decompose,
backendLegalOps);
} }
// The backend contract guarantees that ops with decompositions available will // The backend contract guarantees that ops with decompositions available will
@ -342,7 +347,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenEmptyLikeOp>(); target.addIllegalOp<AtenEmptyLikeOp>();
target.addIllegalOp<AtenOnesLikeOp>(); target.addIllegalOp<AtenOnesLikeOp>();
target.addIllegalOp<AtenZerosLikeOp>(); target.addIllegalOp<AtenZerosLikeOp>();
target.addIllegalOp<AtenStackOp>();
target.addIllegalOp<AtenRollOp>(); target.addIllegalOp<AtenRollOp>();
target.addIllegalOp<AtenRepeatOp>(); target.addIllegalOp<AtenRepeatOp>();
target.addIllegalOp<AtenExpandOp>(); target.addIllegalOp<AtenExpandOp>();
@ -350,7 +354,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenWhereScalarOp>(); target.addIllegalOp<AtenWhereScalarOp>();
target.addIllegalOp<AtenWhereScalarOtherOp>(); target.addIllegalOp<AtenWhereScalarOtherOp>();
target.addIllegalOp<AtenWhereScalarSelfOp>(); target.addIllegalOp<AtenWhereScalarSelfOp>();
target.addIllegalOp<AtenMaskedFillScalarOp>();
target.addIllegalOp<AtenConvolutionBackwardOverrideableOp>(); target.addIllegalOp<AtenConvolutionBackwardOverrideableOp>();
target.addIllegalOp<AtenSizeOp>(); target.addIllegalOp<AtenSizeOp>();
target.addIllegalOp<AtenReshapeOp>(); target.addIllegalOp<AtenReshapeOp>();
@ -359,7 +362,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenAddmmOp>(); target.addIllegalOp<AtenAddmmOp>();
target.addIllegalOp<AtenMeanOp>(); target.addIllegalOp<AtenMeanOp>();
target.addIllegalOp<AtenMeanDimOp>(); target.addIllegalOp<AtenMeanDimOp>();
target.addIllegalOp<AtenNormScalarOptDimOp>();
target.addIllegalOp<AtenSelectIntOp>(); target.addIllegalOp<AtenSelectIntOp>();
target.addIllegalOp<AtenMvOp>(); target.addIllegalOp<AtenMvOp>();
target.addIllegalOp<AtenTOp>(); target.addIllegalOp<AtenTOp>();
@ -392,7 +394,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<Aten_ReshapeAliasOp>(); target.addIllegalOp<Aten_ReshapeAliasOp>();
target.addIllegalOp<AtenBernoulliOp>(); target.addIllegalOp<AtenBernoulliOp>();
target.addIllegalOp<ValsemVariantAtenBernoulliFloatOp>(); target.addIllegalOp<ValsemVariantAtenBernoulliFloatOp>();
target.addIllegalOp<AtenBernoulliPOp>();
target.addIllegalOp<AtenBernoulliTensorOp>(); target.addIllegalOp<AtenBernoulliTensorOp>();
target.addIllegalOp<AtenZeroOp>(); target.addIllegalOp<AtenZeroOp>();
target.addIllegalOp<AtenRandLikeOp>(); target.addIllegalOp<AtenRandLikeOp>();
@ -441,10 +442,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<PrimsSqrtOp>(); target.addIllegalOp<PrimsSqrtOp>();
target.addIllegalOp<AtenRandnOp>(); target.addIllegalOp<AtenRandnOp>();
target.addIllegalOp<AtenRandnGeneratorOp>(); target.addIllegalOp<AtenRandnGeneratorOp>();
target.addIllegalOp<AtenRandnLikeOp>();
target.addIllegalOp<AtenVarMeanOp>(); target.addIllegalOp<AtenVarMeanOp>();
target.addIllegalOp<AtenNewEmptyStridedOp>();
target.addIllegalOp<AtenBucketizeTensorOp>();
for (std::string opName : backendLegalOps) { for (std::string opName : backendLegalOps) {
target.addLegalOp(OperationName(opName, context)); target.addLegalOp(OperationName(opName, context));
} }

View File

@ -106,7 +106,6 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
// Clean up again to avoid needing to to back around the fixed-point // Clean up again to avoid needing to to back around the fixed-point
// iteration. // iteration.
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass()); pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createRecomposeComplexOps());
// Reduce variants of ops to a smaller set of primitives. // Reduce variants of ops to a smaller set of primitives.
pm.addNestedPass<func::FuncOp>(createReduceOpVariantsPass()); pm.addNestedPass<func::FuncOp>(createReduceOpVariantsPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass()); pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());

View File

@ -10,6 +10,7 @@
#include "PassDetail.h" #include "PassDetail.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"

View File

@ -1,103 +0,0 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
namespace {
class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenCopy_Op op,
PatternRewriter &rewriter) const override {
if (!op.getSelf().getDefiningOp() ||
!isa<AtenSliceTensorOp>(op.getSelf().getDefiningOp()))
return failure();
auto sliceOp = cast<AtenSliceTensorOp>(op.getSelf().getDefiningOp());
// Get indices
int64_t dim;
if (!matchPattern(sliceOp.getDim(), m_TorchConstantInt(&dim)))
return failure();
int64_t end;
if (!matchPattern(sliceOp.getEnd(), m_TorchConstantInt(&end)))
return failure();
Value newEnd = sliceOp.getEnd();
if (end < 0) {
Value dimSize = rewriter.create<AtenSizeIntOp>(
op.getLoc(), sliceOp.getSelf(), sliceOp.getDim());
newEnd =
rewriter.create<AtenAddIntOp>(op.getLoc(), dimSize, sliceOp.getEnd());
}
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
Value falseVal = rewriter.create<ConstantBoolOp>(op.getLoc(), false);
// Create IndexPut_Op
BaseTensorType tensorType = op->getResultTypes()[0].cast<BaseTensorType>();
Value range = rewriter.create<AtenArangeStartStepOp>(
op.getLoc(), tensorType, sliceOp.getStart(), newEnd, sliceOp.getStep(),
/*dtype=*/noneVal, /*layout=*/noneVal, /*device=*/noneVal,
/*pin_memory=*/noneVal);
SmallVector<Value> indicesVector;
for (auto i = 0; i < dim - 1; i++)
indicesVector.push_back(noneVal);
indicesVector.push_back(range);
Value indices = rewriter.create<PrimListConstructOp>(
op.getLoc(),
Torch::ListType::get(op->getContext(),
Torch::OptionalType::get(tensorType)),
indicesVector);
rewriter.replaceOpWithNewOp<Aten_IndexPutImpl_Op>(
op, op->getResultTypes(), sliceOp.getSelf(), indices, op.getSrc(),
/*accumulate=*/falseVal, /*unsafe=*/falseVal);
return success();
}
};
} // namespace
namespace {
class RecomposeComplexOps
: public DecomposeComplexOpsBase<RecomposeComplexOps> {
public:
RecomposeComplexOps() = default;
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
// pattern.add calls go here
patterns.add<RecomposeSliceCopy_>(context);
GreedyRewriteConfig config;
config.useTopDownTraversal = true;
config.maxIterations = GreedyRewriteConfig::kNoLimit;
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config))) {
return signalPassFailure();
}
}
};
} // namespace
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createRecomposeComplexOps() {
return std::make_unique<RecomposeComplexOps>();
}

View File

@ -59,6 +59,7 @@
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/SparseAnalysis.h" #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
@ -80,9 +81,7 @@ using namespace mlir::torch::Torch;
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
static Type getTypeForDTypeInteger(MLIRContext *context, int64_t dtypeInt) { static Type getTypeForDTypeInteger(MLIRContext *context, int64_t dtypeInt) {
FailureOr<Type> result = return getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
return failed(result) ? Type() : *result;
} }
static Type getDtypeOrDefault(MLIRContext *context, Value optionalDtype, static Type getDtypeOrDefault(MLIRContext *context, Value optionalDtype,
@ -112,6 +111,24 @@ static torch_upstream::TypeKind getTypeKind(Type type) {
return torch_upstream::TypeKind::AnyType; return torch_upstream::TypeKind::AnyType;
} }
/// Returns the dtype that assumes information from both `lhs` and `rhs`.
/// Returns `std::nullopt` if the types are contradictory. Note this can only
/// be used on the `dtype` from tensors and can't be used on other types like
/// scalar types.
static std::optional<Type> meetElementTypes(Type lhs, Type rhs) {
auto isNullOrBuiltIn = [](Type type) { return !type || isBuiltInType(type); };
(void)isNullOrBuiltIn;
assert(isNullOrBuiltIn(lhs) && "`lhs` must be a builtin type");
assert(isNullOrBuiltIn(rhs) && "`rhs` must be a builtin type");
if (!lhs)
return rhs;
if (!rhs)
return lhs;
if (lhs == rhs)
return lhs;
return std::nullopt;
}
enum class OptionalKnowledge { enum class OptionalKnowledge {
unKnown, unKnown,
@ -458,8 +475,7 @@ private:
void visitAtenToDtypeLikeOp(OpTy op, ArrayRef<const ValueState *> operands); void visitAtenToDtypeLikeOp(OpTy op, ArrayRef<const ValueState *> operands);
template <typename OpTy> template <typename OpTy>
void visitTypeConversionOp(OpTy op, ArrayRef<const ValueState *> operands); void visitTypeConversionOp(OpTy op, ArrayRef<const ValueState *> operands);
template <typename OpTy> void visitAtenCatOp(AtenCatOp op, ArrayRef<const ValueState *> operands);
void visitAtenCatLikeOp(OpTy op, ArrayRef<const ValueState *> operands);
template <typename OpTy> template <typename OpTy>
void visitAtenSoftmaxLikeOp(OpTy op, ArrayRef<const ValueState *> operands); void visitAtenSoftmaxLikeOp(OpTy op, ArrayRef<const ValueState *> operands);
@ -547,9 +563,7 @@ static Type getPromotedResultDType(ValueKnowledge *tensor, Type scalarType) {
/*skipRankCheck=*/true); /*skipRankCheck=*/true);
state = state =
updateResultTypeState(getDefaultDtypeForTorchScalar(scalarType), state); updateResultTypeState(getDefaultDtypeForTorchScalar(scalarType), state);
FailureOr<Type> result = return getTypeForScalarType(scalarType.getContext(), result_type(state));
getTypeForScalarType(scalarType.getContext(), result_type(state));
return failed(result) ? Type() : *result;
} }
static SmallVector<std::optional<bool>> static SmallVector<std::optional<bool>>
@ -586,8 +600,7 @@ static Type getPromotedResultType(MLIRContext *context,
return Type(); return Type();
state = updateResultTypeState(tensor, rankIsNonZero, state, skipRankCheck); state = updateResultTypeState(tensor, rankIsNonZero, state, skipRankCheck);
} }
FailureOr<Type> result = getTypeForScalarType(context, result_type(state)); return getTypeForScalarType(context, result_type(state));
return failed(result) ? Type() : *result;
} }
static Type getPromotedResultTypeAssumingNonZeroRank( static Type getPromotedResultTypeAssumingNonZeroRank(
@ -636,26 +649,23 @@ void TypeAnalysis::visitOperation(Operation *op,
AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopyOp, AtenCumsumOp, AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopyOp, AtenCumsumOp,
AtenLayerNormOp, AtenClampOp, AtenClampMinOp, AtenClampMaxOp, AtenLayerNormOp, AtenClampOp, AtenClampMinOp, AtenClampMaxOp,
AtenNegOp, AtenFloorOp, Aten_SoftmaxBackwardDataOp, AtenDropoutOp, AtenNegOp, AtenFloorOp, Aten_SoftmaxBackwardDataOp, AtenDropoutOp,
AtenTanhBackwardOp, AtenHardtanhBackwardOp, AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp, AtenAbsOp, AtenAbsOp, AtenThresholdOp, AtenSquareOp, AtenUniformOp,
AtenThresholdOp, AtenSquareOp, AtenUniformOp, AtenBernoulliOp, AtenBernoulliOp, AtenBernoulli_FloatOp, AtenBernoulliTensorOp,
AtenBernoulli_FloatOp, AtenBernoulliTensorOp,
ValsemVariantAtenBernoulliFloatOp, AtenBernoulliTensorOp, ValsemVariantAtenBernoulliFloatOp, AtenBernoulliTensorOp,
AtenBernoulliPOp, AtenFillScalarOp, AtenHardsigmoidOp, AtenCloneOp, AtenFillScalarOp, AtenHardsigmoidOp, AtenCloneOp, AtenHardswishOp,
AtenHardswishOp, AtenSiluOp, AtenHardtanhOp, AtenMaskedSelectOp, AtenSiluOp, AtenHardtanhOp, AtenMaskedSelectOp, AtenMaxPool2dOp,
AtenMaxPool2dOp, AtenAvgPool2dOp, AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp, AtenAdaptiveAvgPool2dOp, AtenFlattenUsingIntsOp,
AtenFlattenUsingIntsOp, AtenSqueezeOp, AtenSqueezeDimOp, AtenSqueezeOp, AtenSqueezeDimOp, AtenUnsqueezeOp, AtenViewOp,
AtenUnsqueezeOp, AtenViewOp, Aten_UnsafeViewOp, AtenReshapeOp, Aten_UnsafeViewOp, AtenReshapeOp, Aten_ReshapeAliasOp, AtenResize_Op,
Aten_ReshapeAliasOp, AtenResize_Op, AtenTransposeIntOp, AtenTOp, AtenTransposeIntOp, AtenTOp, AtenPermuteOp, AtenIndexSelectOp,
AtenPermuteOp, AtenIndexSelectOp, AtenSelectIntOp, AtenSelectIntOp, AtenSelectScatterOp, AtenNarrowOp, AtenSliceTensorOp,
AtenSelectScatterOp, AtenNarrowOp, AtenSliceTensorOp, AtenSliceScatterOp, AtenGatherOp, AtenExpandOp, AtenExpandAsOp,
AtenScatterReduceTwoOp, AtenSliceScatterOp, AtenGatherOp, AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp,
AtenExpandOp, AtenExpandAsOp, AtenBroadcastToOp, AtenRepeatOp, AtenZero_Op, AtenIndexTensorOp, Aten_IndexPutImplOp, AtenIndexPutOp,
AtenConstantPadNdOp, AtenPadOp, AtenZero_Op, AtenIndexTensorOp, AtenCopyOp, AtenZeroOp, AtenIndexPutHackedTwinOp, AtenPreluOp,
Aten_IndexPutImplOp, AtenIndexPutOp, AtenCopyOp, AtenZeroOp, AtenMaskedFillScalarOp, AtenFlipOp, PrimAbsScalarOp, AtenNumpyTOp,
AtenIndexPutHackedTwinOp, AtenPreluOp, AtenMaskedFillScalarOp, AtenTriuOp, AtenMaskedFillTensorOp, AtenRollOp, AtenPowTensorTensorOp,
AtenFlipOp, PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp,
AtenMaskedFillTensorOp, AtenRollOp, AtenPowTensorTensorOp,
AtenLiftFreshCopyOp, AtenIndexTensorHackedTwinOp, AtenLiftFreshCopyOp, AtenIndexTensorHackedTwinOp,
AtenUpsampleNearest2dOp, AtenMishOp, AtenRoundOp, AtenFillTensorOp, AtenUpsampleNearest2dOp, AtenMishOp, AtenRoundOp, AtenFillTensorOp,
AtenUpsampleNearest2dBackwardOp, AtenLeakyReluBackwardOp>(op)) { AtenUpsampleNearest2dBackwardOp, AtenLeakyReluBackwardOp>(op)) {
@ -960,16 +970,9 @@ void TypeAnalysis::visitOperation(Operation *op,
} else if (auto newEmpty = dyn_cast<AtenNewEmptyOp>(op)) { } else if (auto newEmpty = dyn_cast<AtenNewEmptyOp>(op)) {
visitConstantTensorNewLikeOp<AtenNewEmptyOp>(newEmpty, operands); visitConstantTensorNewLikeOp<AtenNewEmptyOp>(newEmpty, operands);
return; return;
} else if (auto newEmptyStrided = dyn_cast<AtenNewEmptyStridedOp>(op)) {
visitConstantTensorNewLikeOp<AtenNewEmptyStridedOp>(newEmptyStrided,
operands);
return;
} else if (auto randLike = dyn_cast<AtenRandLikeOp>(op)) { } else if (auto randLike = dyn_cast<AtenRandLikeOp>(op)) {
visitConstantTensorAllocLikeOp<AtenRandLikeOp>(randLike, operands); visitConstantTensorAllocLikeOp<AtenRandLikeOp>(randLike, operands);
return; return;
} else if (auto randLike = dyn_cast<AtenRandnLikeOp>(op)) {
visitConstantTensorAllocLikeOp<AtenRandnLikeOp>(randLike, operands);
return;
} else if (auto toCopy = dyn_cast<Aten_ToCopyOp>(op)) { } else if (auto toCopy = dyn_cast<Aten_ToCopyOp>(op)) {
visitConstantTensorAllocLikeOp<Aten_ToCopyOp>(toCopy, operands); visitConstantTensorAllocLikeOp<Aten_ToCopyOp>(toCopy, operands);
return; return;
@ -1005,10 +1008,7 @@ void TypeAnalysis::visitOperation(Operation *op,
} }
if (auto cat = dyn_cast<AtenCatOp>(op)) { if (auto cat = dyn_cast<AtenCatOp>(op)) {
visitAtenCatLikeOp<AtenCatOp>(cat, operands); visitAtenCatOp(cat, operands);
return;
} else if (auto stack = dyn_cast<AtenStackOp>(op)) {
visitAtenCatLikeOp<AtenStackOp>(stack, operands);
return; return;
} }
@ -1114,22 +1114,6 @@ void TypeAnalysis::visitOperation(Operation *op,
return; return;
} }
if (auto bucketize = dyn_cast<AtenBucketizeTensorOp>(op)) {
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
bool outInt32;
if (matchPattern(bucketize.getOutInt32(), m_TorchConstantBool(&outInt32)) &&
outInt32) {
knowledge.dtype =
IntegerType::get(op->getContext(), 32, IntegerType::Signed);
} else {
knowledge.dtype =
IntegerType::get(op->getContext(), 64, IntegerType::Signed);
}
incorporateKnowledge(bucketize.getResult(), knowledge);
return;
}
// Otherwise, this is an unknown operation, so reset the state. // Otherwise, this is an unknown operation, so reset the state.
setAllToEntryStates(results); setAllToEntryStates(results);
return; return;
@ -1354,26 +1338,30 @@ void TypeAnalysis::visitTypeConversionOp(
// `torch.aten.cat` concatenates the given sequence of seq tensors in the given // `torch.aten.cat` concatenates the given sequence of seq tensors in the given
// dimension. The output has the same sizes as the input for all dimensions // dimension. The output has the same sizes as the input for all dimensions
// except the given dimension. // except the given dimension.
template <typename OpTy> void TypeAnalysis::visitAtenCatOp(AtenCatOp op,
void TypeAnalysis::visitAtenCatLikeOp(OpTy op,
ArrayRef<const ValueState *> operands) { ArrayRef<const ValueState *> operands) {
auto tensorList = op.getTensors(); auto tensorList = op.getTensors();
auto knowledge = auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());
auto listConstruct = tensorList.template getDefiningOp<PrimListConstructOp>(); auto listConstruct = tensorList.getDefiningOp<PrimListConstructOp>();
if (!listConstruct) { if (!listConstruct) {
incorporateKnowledge(op.getResult(), knowledge); incorporateKnowledge(op.getResult(), knowledge);
return; return;
} }
SmallVector<ValueKnowledge*> tensors = llvm::to_vector( auto tensors = llvm::to_vector<4>(
llvm::map_range(listConstruct.getElements(), [&](Value v) -> ValueKnowledge* { llvm::map_range(listConstruct.getElements(), [&](Value v) -> ValueKnowledge {
return &getLatticeElement(v)->getValue(); return getLatticeElement(v)->getValue();
})); }));
for (auto tensor : tensors) {
knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank( auto newDtype = meetElementTypes(knowledge.dtype, tensor.dtype);
op->getContext(), tensors); if (!newDtype.has_value()) {
incorporateKnowledge(op->getResult(0), knowledge); incorporateKnowledge(op.getResult(), knowledge);
return;
}
knowledge.dtype = newDtype.value();
}
incorporateKnowledge(op.getResult(), knowledge);
} }
void TypeAnalysis::visitNumToTensorOp(PrimNumToTensorScalarOp op) { void TypeAnalysis::visitNumToTensorOp(PrimNumToTensorScalarOp op) {
@ -1448,16 +1436,12 @@ static Type getMostRefinedStaticType(Value v, DataFlowSolver &solver) {
if (!latticeElement) if (!latticeElement)
return nullptr; return nullptr;
const ValueKnowledge &knowledge = latticeElement->getValue(); const ValueKnowledge &knowledge = latticeElement->getValue();
if (!knowledge.isInitialized)
return nullptr;
return getRefinedTensorType(tensorType, knowledge); return getRefinedTensorType(tensorType, knowledge);
} else if (auto optionalType = v.getType().dyn_cast<OptionalType>()) { } else if (auto optionalType = v.getType().dyn_cast<OptionalType>()) {
const ValueState *latticeElement = solver.lookupState<ValueState>(v); const ValueState *latticeElement = solver.lookupState<ValueState>(v);
if (!latticeElement) if (!latticeElement)
return nullptr; return nullptr;
const ValueKnowledge &knowledge = latticeElement->getValue(); const ValueKnowledge &knowledge = latticeElement->getValue();
if (!knowledge.isInitialized)
return nullptr;
if (knowledge.optional == OptionalKnowledge::isNone) if (knowledge.optional == OptionalKnowledge::isNone)
return Torch::NoneType::get(v.getContext()); return Torch::NoneType::get(v.getContext());
else if (knowledge.optional == OptionalKnowledge::notNone) { else if (knowledge.optional == OptionalKnowledge::notNone) {
@ -1472,8 +1456,6 @@ static Type getMostRefinedStaticType(Value v, DataFlowSolver &solver) {
if (!latticeElement) if (!latticeElement)
return nullptr; return nullptr;
const ValueKnowledge &knowledge = latticeElement->getValue(); const ValueKnowledge &knowledge = latticeElement->getValue();
if (!knowledge.isInitialized)
return nullptr;
if (knowledge.kind == torch_upstream::TypeKind::IntType) if (knowledge.kind == torch_upstream::TypeKind::IntType)
return Torch::IntType::get(v.getContext()); return Torch::IntType::get(v.getContext());
if (knowledge.kind == torch_upstream::TypeKind::FloatType) if (knowledge.kind == torch_upstream::TypeKind::FloatType)

View File

@ -46,15 +46,10 @@ static LogicalResult refineDtypeCalculateResult(DtypeCalculateOp op,
impliedTypeFromDtype = *torchType; impliedTypeFromDtype = *torchType;
} else if (auto originalResultType = } else if (auto originalResultType =
result.getType().dyn_cast<BaseTensorType>()) { result.getType().dyn_cast<BaseTensorType>()) {
FailureOr<Type> builtinType =
getTypeForScalarType(op->getContext(), dtypeScalarType);
if (failed(builtinType)) {
return rewriter.notifyMatchFailure(
op, "Failed to convert `dtypeScalarType` to a builtin type");
}
impliedTypeFromDtype = impliedTypeFromDtype =
originalResultType.cast<BaseTensorType>().getWithSizesAndDtype( originalResultType.cast<BaseTensorType>().getWithSizesAndDtype(
originalResultType.getOptionalSizes(), *builtinType); originalResultType.getOptionalSizes(),
getTypeForScalarType(op->getContext(), dtypeScalarType));
} else { } else {
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"Unimplemented: Expected result type to " "Unimplemented: Expected result type to "

View File

@ -10,7 +10,7 @@
#include "PassDetail.h" #include "PassDetail.h"
#include "SimplifyAbstractInterpCalculationsUtils.h" #include "SimplifyAbstractInterpCalculationsUtils.h"
#include "mlir/IR/IRMapping.h" #include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
@ -47,7 +47,7 @@ public:
Block *afterBlock = rewriter.splitBlock(op->getBlock(), op->getIterator()); Block *afterBlock = rewriter.splitBlock(op->getBlock(), op->getIterator());
SmallVector<Block *> blocksToMerge; SmallVector<Block *> blocksToMerge;
IRMapping bvm; BlockAndValueMapping bvm;
// TODO: Helper for region().front() // TODO: Helper for region().front()
auto condition = auto condition =
cast<PrimLoopConditionOp>(op.getRegion().front().getTerminator()); cast<PrimLoopConditionOp>(op.getRegion().front().getTerminator());
@ -129,7 +129,8 @@ public:
// Truncate the list of users to the number of users we're going to // Truncate the list of users to the number of users we're going to
// interpret. // interpret.
allUsers.resize(numUsersToInterpret); allUsers.resize(numUsersToInterpret);
auto usersToInterpret = ArrayRef(allUsers).take_front(numUsersToInterpret); auto usersToInterpret =
makeArrayRef(allUsers).take_front(numUsersToInterpret);
// For each mutating op (which must be in the same block), we save the // For each mutating op (which must be in the same block), we save the
// current state of the list as a vector of Value's. These will then // current state of the list as a vector of Value's. These will then
@ -335,7 +336,7 @@ static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op,
auto originalResultType = result.getType().cast<BaseTensorType>(); auto originalResultType = result.getType().cast<BaseTensorType>();
auto impliedTypesFromShape = auto impliedTypesFromShape =
originalResultType.cast<BaseTensorType>() originalResultType.cast<BaseTensorType>()
.getWithSizesAndDtype(ArrayRef(sizes), .getWithSizesAndDtype(makeArrayRef(sizes),
originalResultType.getOptionalDtype()) originalResultType.getOptionalDtype())
.cast<BaseTensorType>(); .cast<BaseTensorType>();

View File

@ -8,8 +8,6 @@
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "llvm/Support/ErrorHandling.h"
namespace mlir { namespace mlir {
namespace torch { namespace torch {
namespace torch_upstream { namespace torch_upstream {
@ -128,23 +126,6 @@ ScalarType result_type(const ResultTypeState &in_state) {
combine_categories(in_state.zeroResult, in_state.wrappedResult)); combine_categories(in_state.zeroResult, in_state.wrappedResult));
} }
ReductionType get_reduction_enum(const llvm::StringRef &reduce) {
if (reduce == "max" || reduce == "amax") {
return torch_upstream::ReductionType::MAX;
} else if (reduce == "mean") {
return torch_upstream::ReductionType::MEAN;
} else if (reduce == "min" || reduce == "amin") {
return torch_upstream::ReductionType::MIN;
} else if (reduce == "sum") {
return torch_upstream::ReductionType::SUM;
} else if (reduce == "prod") {
return torch_upstream::ReductionType::PROD;
} else {
llvm_unreachable(
"'reduce' argument must be either sum, prod, mean, amax or amin");
}
}
} // namespace torch_upstream } // namespace torch_upstream
} // namespace torch } // namespace torch
} // namespace mlir } // namespace mlir

View File

@ -83,9 +83,8 @@ Type Torch::getTypeForTorchType(
llvm::report_fatal_error("unhandled type for getTypeForTorchType"); llvm::report_fatal_error("unhandled type for getTypeForTorchType");
} }
FailureOr<Type> Type Torch::getTypeForScalarType(
Torch::getTypeForScalarType(MLIRContext *context, MLIRContext *context, torch_upstream::ScalarType dtypeInt,
torch_upstream::ScalarType dtypeInt,
mlir::IntegerType::SignednessSemantics signedness) { mlir::IntegerType::SignednessSemantics signedness) {
switch (dtypeInt) { switch (dtypeInt) {
case torch_upstream::ScalarType::Float: case torch_upstream::ScalarType::Float:
@ -111,8 +110,6 @@ Torch::getTypeForScalarType(MLIRContext *context,
return mlir::ComplexType::get(Float64Type::get(context)); return mlir::ComplexType::get(Float64Type::get(context));
case torch_upstream::ScalarType::ComplexDouble: case torch_upstream::ScalarType::ComplexDouble:
return mlir::ComplexType::get(Float128Type::get(context)); return mlir::ComplexType::get(Float128Type::get(context));
case torch_upstream::ScalarType::Undefined:
return failure();
default: default:
llvm::report_fatal_error("unhandled type for getTypeForScalarType"); llvm::report_fatal_error("unhandled type for getTypeForScalarType");
} }
@ -126,7 +123,6 @@ Torch::getTorchTypeForScalarType(MLIRContext *context,
return Torch::FloatType::get(context); return Torch::FloatType::get(context);
case torch_upstream::ScalarType::Long: case torch_upstream::ScalarType::Long:
return Torch::IntType::get(context); return Torch::IntType::get(context);
case torch_upstream::ScalarType::Undefined:
default: default:
return failure(); return failure();
} }

View File

@ -32,11 +32,11 @@ namespace {
struct TorchConversionInlinerInterface : public DialectInlinerInterface { struct TorchConversionInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface; using DialectInlinerInterface::DialectInlinerInterface;
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
IRMapping &valueMapping) const final { BlockAndValueMapping &valueMapping) const final {
return true; return true;
} }
bool isLegalToInline(Operation *, Region *, bool wouldBeCloned, bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
IRMapping &) const final { BlockAndValueMapping &) const final {
return true; return true;
} }
}; };

View File

@ -75,8 +75,8 @@ LogicalResult FromBuiltinTensorOp::verify() {
// FromI64Op // FromI64Op
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult FromI64Op::fold(FoldAdaptor adaptor) { OpFoldResult FromI64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::IntegerAttr>(); auto attr = operands[0].dyn_cast_or_null<mlir::IntegerAttr>();
if (attr) { if (attr) {
return attr; return attr;
} else { } else {
@ -88,8 +88,8 @@ OpFoldResult FromI64Op::fold(FoldAdaptor adaptor) {
// ToI64Op // ToI64Op
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult ToI64Op::fold(FoldAdaptor adaptor) { OpFoldResult ToI64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::IntegerAttr>(); auto attr = operands[0].dyn_cast_or_null<mlir::IntegerAttr>();
if (attr) { if (attr) {
return attr; return attr;
} else { } else {
@ -101,8 +101,8 @@ OpFoldResult ToI64Op::fold(FoldAdaptor adaptor) {
// ToF64Op // ToF64Op
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult ToF64Op::fold(FoldAdaptor adaptor) { OpFoldResult ToF64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::FloatAttr>(); auto attr = operands[0].dyn_cast_or_null<mlir::FloatAttr>();
if (attr) { if (attr) {
return attr; return attr;
} else { } else {
@ -114,8 +114,8 @@ OpFoldResult ToF64Op::fold(FoldAdaptor adaptor) {
// FromF64Op // FromF64Op
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult FromF64Op::fold(FoldAdaptor adaptor) { OpFoldResult FromF64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::FloatAttr>(); auto attr = operands[0].dyn_cast_or_null<mlir::FloatAttr>();
if (attr) { if (attr) {
return attr; return attr;
} else { } else {

View File

@ -11,6 +11,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h" #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"

View File

@ -11,7 +11,7 @@ set(LinkedLibs MLIRIR
TorchMLIRTorchConversionToMLProgram TorchMLIRTorchConversionToMLProgram
MLIRMemRefTransforms) MLIRMemRefTransforms)
if(TORCH_MLIR_ENABLE_STABLEHLO) if(TORCH_MLIR_ENABLE_MHLO)
list(APPEND LinkedLibs ChloPasses) list(APPEND LinkedLibs ChloPasses)
endif() endif()
@ -21,7 +21,7 @@ add_mlir_library(TorchMLIRTorchConversionPasses
Passes.cpp Passes.cpp
VerifyLinalgOnTensorsBackendContract.cpp VerifyLinalgOnTensorsBackendContract.cpp
VerifyTosaBackendContract.cpp VerifyTosaBackendContract.cpp
VerifyStablehloBackendContract.cpp VerifyMhloBackendContract.cpp
ADDITIONAL_HEADER_DIRS ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/TorchConversion/Transforms ${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/TorchConversion/Transforms

View File

@ -21,8 +21,9 @@
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" #include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
#ifdef TORCH_MLIR_ENABLE_STABLEHLO #ifdef TORCH_MLIR_ENABLE_MHLO
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "mhlo/transforms/passes.h"
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#endif #endif
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
@ -52,13 +53,12 @@ void mlir::torch::registerTorchConversionPasses() {
"Pipeline lowering torch backend contract to TOSA backend " "Pipeline lowering torch backend contract to TOSA backend "
"contract.", "contract.",
TorchConversion::createTorchBackendToTosaBackendPipeline); TorchConversion::createTorchBackendToTosaBackendPipeline);
#ifdef TORCH_MLIR_ENABLE_STABLEHLO #ifdef TORCH_MLIR_ENABLE_MHLO
mlir::PassPipelineRegistration< mlir::PassPipelineRegistration<TorchConversion::MhloBackendPipelineOptions>(
TorchConversion::StablehloBackendPipelineOptions>( "torch-backend-to-mhlo-backend-pipeline",
"torch-backend-to-stablehlo-backend-pipeline", "Pipeline lowering torch backend contract to MHLO backend "
"Pipeline lowering torch backend contract to StableHLO backend "
"contract.", "contract.",
TorchConversion::createTorchBackendToStablehloBackendPipeline); TorchConversion::createTorchBackendToMhloBackendPipeline);
#endif #endif
} }
@ -121,12 +121,11 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
pm.addPass(TorchConversion::createVerifyTosaBackendContractPass()); pm.addPass(TorchConversion::createVerifyTosaBackendContractPass());
} }
#ifdef TORCH_MLIR_ENABLE_STABLEHLO #ifdef TORCH_MLIR_ENABLE_MHLO
void TorchConversion::createTorchBackendToStablehloBackendPipeline( void TorchConversion::createTorchBackendToMhloBackendPipeline(
OpPassManager &pm, OpPassManager &pm,
const TorchConversion::StablehloBackendPipelineOptions &options) { const TorchConversion::MhloBackendPipelineOptions &options) {
// Generate Stablehlo ops. pm.addNestedPass<func::FuncOp>(createConvertTorchToMhloPass(
pm.addNestedPass<func::FuncOp>(createConvertTorchToStablehloPass(
options.enableStaticShape, options.enableI32Index)); options.enableStaticShape, options.enableI32Index));
// Clean up any non-canonical code introduced above.. // Clean up any non-canonical code introduced above..
@ -134,13 +133,21 @@ void TorchConversion::createTorchBackendToStablehloBackendPipeline(
// The resolution of `dim` ops tends to create identical ops. CSE them. // The resolution of `dim` ops tends to create identical ops. CSE them.
pm.addNestedPass<func::FuncOp>(createCSEPass()); pm.addNestedPass<func::FuncOp>(createCSEPass());
// Convert CHLO ops to MHLO ops
pm.addNestedPass<func::FuncOp>(mhlo::createChloLegalizeToHloPass());
// Clean up any non-canonical code introduced above..
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// The resolution of `dim` ops tends to create identical ops. CSE them.
pm.addNestedPass<func::FuncOp>(createCSEPass());
// Finish the type conversion from `torch` types to the types of the // Finish the type conversion from `torch` types to the types of the
// StableHLO backend contract. // MHLO backend contract.
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass()); pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
pm.addNestedPass<func::FuncOp>( pm.addNestedPass<func::FuncOp>(
TorchConversion::createFinalizingBackendTypeConversionPass()); TorchConversion::createFinalizingBackendTypeConversionPass());
// Verify that we have lowered to the form that MHLO backends
// Verify that we have lowered to Stablehlo and Chlo ops. // expect. This fails compilation (signalPassFailure) if the IR is not in the
pm.addPass(TorchConversion::createVerifyStablehloBackendContractPass()); // correct form.
pm.addPass(TorchConversion::createVerifyMhloBackendContractPass());
} }
#endif #endif

View File

@ -6,9 +6,10 @@
// Also available under a BSD-style license. See LICENSE. // Also available under a BSD-style license. See LICENSE.
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#ifdef TORCH_MLIR_ENABLE_STABLEHLO #ifdef TORCH_MLIR_ENABLE_MHLO
#include "PassDetail.h" #include "PassDetail.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Shape/IR/Shape.h"
@ -17,7 +18,6 @@
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
using namespace mlir; using namespace mlir;
@ -25,15 +25,17 @@ using namespace mlir::torch;
using namespace mlir::torch::TorchConversion; using namespace mlir::torch::TorchConversion;
namespace { namespace {
class VerifyStablehloBackendContractPass class VerifyMhloBackendContractPass
: public VerifyStablehloBackendContractBase< : public VerifyMhloBackendContractBase<VerifyMhloBackendContractPass> {
VerifyStablehloBackendContractPass> {
void runOnOperation() override { void runOnOperation() override {
MLIRContext *context = &getContext();
auto module = getOperation();
TypeConverter converter; TypeConverter converter;
converter.addConversion([](Type type) -> Type { converter.addConversion([](Type type) -> Type {
auto elemTy = type; auto elemTy = type;
if (isa<TensorType>(type)) if (isa<TensorType>(type)) {
elemTy = type.cast<TensorType>().getElementType(); elemTy = type.cast<TensorType>().getElementType();
}
if (BaseMemRefType::isValidElementType(elemTy)) if (BaseMemRefType::isValidElementType(elemTy))
return type; return type;
return nullptr; return nullptr;
@ -41,7 +43,6 @@ class VerifyStablehloBackendContractPass
auto opHasLegalTypes = [&](Operation *op) { return converter.isLegal(op); }; auto opHasLegalTypes = [&](Operation *op) { return converter.isLegal(op); };
MLIRContext *context = &getContext();
ConversionTarget target(*context); ConversionTarget target(*context);
// Structural operations. // Structural operations.
@ -49,16 +50,26 @@ class VerifyStablehloBackendContractPass
// Shape operations. // Shape operations.
target.addDynamicallyLegalOp<shape::ShapeOfOp>(opHasLegalTypes); target.addDynamicallyLegalOp<shape::ShapeOfOp>(opHasLegalTypes);
target.addLegalDialect<mhlo::MhloDialect>();
target.addLegalDialect<chlo::ChloDialect>(); target.addLegalDialect<chlo::ChloDialect>();
target.addLegalDialect<stablehlo::StablehloDialect>();
target.addLegalDialect<tensor::TensorDialect>(); target.addLegalDialect<tensor::TensorDialect>();
target.addLegalDialect<arith::ArithDialect>(); target.addLegalDialect<arith::ArithDialect>();
RewritePatternSet patterns(context);
if (failed(applyFullConversion(module, target, std::move(patterns)))) {
// We avoid `module.emitError()` so that mlir-print-op-on-diagnostics
// doesn't unnecessarily spew out the entire module.
emitError(module.getLoc())
<< "Module does not conform to the MHLO backend contract. "
"See dialect conversion legality information above.";
return signalPassFailure();
}
} }
}; };
} // namespace } // namespace
std::unique_ptr<OperationPass<ModuleOp>> std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::TorchConversion::createVerifyStablehloBackendContractPass() { mlir::torch::TorchConversion::createVerifyMhloBackendContractPass() {
return std::make_unique<VerifyStablehloBackendContractPass>(); return std::make_unique<VerifyMhloBackendContractPass>();
} }
#endif // TORCH_MLIR_ENABLE_STABLEHLO #endif // TORCH_MLIR_ENABLE_MHLO

View File

@ -20,10 +20,6 @@
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
#include "torch-mlir/RefBackend/Passes.h" #include "torch-mlir/RefBackend/Passes.h"
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
#include "mhlo/transforms/passes.h"
#endif
void mlir::torch::registerAllDialects(mlir::DialectRegistry &registry) { void mlir::torch::registerAllDialects(mlir::DialectRegistry &registry) {
registry.insert<mlir::func::FuncDialect>(); registry.insert<mlir::func::FuncDialect>();
registry.insert<mlir::torch::Torch::TorchDialect>(); registry.insert<mlir::torch::Torch::TorchDialect>();
@ -38,11 +34,4 @@ void mlir::torch::registerAllPasses() {
mlir::torch::registerConversionPasses(); mlir::torch::registerConversionPasses();
mlir::torch::RefBackend::registerRefBackendPasses(); mlir::torch::RefBackend::registerRefBackendPasses();
mlir::torch::TMTensor::registerPasses(); mlir::torch::TMTensor::registerPasses();
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
mlir::mhlo::registerSymbolicShapeOptimizationPass();
mlir::mhlo::registerStablehloLegalizeToHloPass();
mlir::mhlo::registerChloLegalizeToHloPass();
mlir::mhlo::registerHloLegalizeToLinalgPass();
#endif // TORCH_MLIR_ENABLE_STABLEHLO
} }

View File

@ -392,7 +392,7 @@ Operation *createLinalgCopyOp(OpBuilder &b, Location loc, Value from,
loc, loc,
/*inputs=*/from, /*inputs=*/from,
/*outputs=*/to, /*outputs=*/to,
/*indexingMaps=*/llvm::ArrayRef({id, id}), /*indexingMaps=*/llvm::makeArrayRef({id, id}),
/*iteratorTypes=*/iteratorTypes, /*iteratorTypes=*/iteratorTypes,
[](OpBuilder &b, Location loc, ValueRange args) { [](OpBuilder &b, Location loc, ValueRange args) {
b.create<linalg::YieldOp>(loc, args.front()); b.create<linalg::YieldOp>(loc, args.front());

View File

@ -45,16 +45,14 @@ endif()
declare_mlir_python_sources(TorchMLIRPythonSources) declare_mlir_python_sources(TorchMLIRPythonSources)
declare_mlir_python_sources(TorchMLIRPythonExtensions) declare_mlir_python_sources(TorchMLIRPythonExtensions)
if (NOT TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS) declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel
declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
ADD_TO_PARENT TorchMLIRPythonSources ADD_TO_PARENT TorchMLIRPythonSources
SOURCES SOURCES
__init__.py __init__.py
compiler_utils.py compiler_utils.py
dynamo.py dynamo.py
) )
endif()
declare_mlir_python_sources(TorchMLIRPythonSources.Dialects declare_mlir_python_sources(TorchMLIRPythonSources.Dialects
ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
@ -93,9 +91,7 @@ if(TORCH_MLIR_ENABLE_LTC)
endif() endif()
# Reference backend has a separate check for TORCH_MLIR_ENABLE_LTC, since it # Reference backend has a separate check for TORCH_MLIR_ENABLE_LTC, since it
# generates a dummy Python library when disabled. # generates a dummy Python library when disabled.
if(NOT TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS) add_subdirectory(torch_mlir/csrc/reference_lazy_backend)
add_subdirectory(torch_mlir/csrc/reference_lazy_backend)
endif()
################################################################################ ################################################################################
# Optionally handle JIT IR importer. # Optionally handle JIT IR importer.

View File

@ -44,9 +44,9 @@ class OutputType(Enum):
# as taking the `TORCH` output type and lowering it to TOSA. # as taking the `TORCH` output type and lowering it to TOSA.
TOSA = "tosa" TOSA = "tosa"
# This output type consists of `stablehlo` dialect ops. It can be thought of # This output type consists of `mhlo` dialect ops. It can be thought of
# as taking the `TORCH` output type and lowering it to StableHLO. # as taking the `TORCH` output type and lowering it to MHLO.
STABLEHLO = "stablehlo" MHLO = "mhlo"
# Raw output of the JIT IR importer. This is not expected to be useful # Raw output of the JIT IR importer. This is not expected to be useful
# for end-users, but can be convenient for development or reporting bugs. # for end-users, but can be convenient for development or reporting bugs.
@ -242,7 +242,7 @@ class ExampleArgs:
BACKEND_LEGAL_OPS = { BACKEND_LEGAL_OPS = {
OutputType.TOSA: ['torch.aten.flatten.using_ints', 'torch.aten.native_layer_norm', 'torch.aten.linear'], OutputType.TOSA: ['torch.aten.flatten.using_ints', 'torch.aten.native_layer_norm', 'torch.aten.linear'],
OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints', ], OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints', ],
OutputType.STABLEHLO: [], OutputType.MHLO: [],
} }
@ -290,7 +290,7 @@ def compile(model: torch.nn.Module,
# We only allow `backend_legal_ops` to be specified for the `"torch"` # We only allow `backend_legal_ops` to be specified for the `"torch"`
# output type because the other output types actually invoke their # output type because the other output types actually invoke their
# respective backends (Linalg, TOSA, or STABLEHLO), and those backends have # respective backends (Linalg, TOSA, or MHLO), and those backends have
# very specific requirements about the ops which are legal. # very specific requirements about the ops which are legal.
# See `BACKEND_LEGAL_OPS` for more details. # See `BACKEND_LEGAL_OPS` for more details.
if backend_legal_ops is not None: if backend_legal_ops is not None:
@ -404,14 +404,14 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
print(mb.module) print(mb.module)
return mb.module return mb.module
elif output_type == OutputType.STABLEHLO: elif output_type == OutputType.MHLO:
run_pipeline_with_repro_report( run_pipeline_with_repro_report(
mb.module, mb.module,
"builtin.module(torch-backend-to-stablehlo-backend-pipeline)", "builtin.module(torch-backend-to-mhlo-backend-pipeline)",
"Lowering Torch Backend IR -> StableHLO Backend IR") "Lowering Torch Backend IR -> MHLO Backend IR")
if verbose: if verbose:
print("\n====================") print("\n====================")
print("StableHLO Backend IR") print("MHLO Backend IR")
print(mb.module) print(mb.module)
return mb.module return mb.module
raise Exception(f"Unknown OutputType: {output_type}") raise Exception(f"Unknown OutputType: {output_type}")

View File

@ -44,7 +44,7 @@ def run_pipeline_with_repro_report(module,
# Lower module in place to make it ready for compiler backends. # Lower module in place to make it ready for compiler backends.
with module.context: with module.context:
pm = PassManager.parse(pipeline) pm = PassManager.parse(pipeline)
pm.run(module.operation) pm.run(module)
except Exception as e: except Exception as e:
# TODO: More robust. # TODO: More robust.
# - don't arbitrarily clutter up /tmp. When a test suite has many # - don't arbitrarily clutter up /tmp. When a test suite has many

View File

@ -71,7 +71,6 @@ add_library(torch_mlir_ltc_backend SHARED
mlir_node.cpp mlir_node.cpp
ops/device_data.cpp ops/device_data.cpp
ops/generic.cpp ops/generic.cpp
utils/jit_utils.cpp
utils/tensor_utils.cpp utils/tensor_utils.cpp
) )
target_compile_features(torch_mlir_ltc_backend PRIVATE cxx_std_17) target_compile_features(torch_mlir_ltc_backend PRIVATE cxx_std_17)

Some files were not shown because too many files have changed in this diff Show More