From ca224bcf17b152e7fb69700f96af5e55406c3be9 Mon Sep 17 00:00:00 2001 From: Ramiro Leal-Cavazos Date: Wed, 15 Mar 2023 11:25:26 -0700 Subject: [PATCH] Revert "Merge main into dtype-functions-staging (#1935)" This reverts commit 042d58b699a4e9f5eddf55f249811e0ef03b4f13. --- .github/actions/setup-build/action.yml | 5 +- .github/workflows/RollPyTorch.yml | 15 +- .github/workflows/bazelBuildAndTest.yml | 6 - .github/workflows/buildAndTest.yml | 10 +- .github/workflows/buildRelease.yml | 45 +- .github/workflows/gh-pages-releases.yml | 7 +- .github/workflows/oneshotSnapshotPackage.yml | 8 +- .github/workflows/releaseSnapshotPackage.yml | 9 +- .gitignore | 3 - CMakeLists.txt | 18 +- README.md | 28 +- build-requirements.txt | 2 + .../python_deploy/build_linux_packages.sh | 32 +- .../python_deploy/build_macos_packages.sh | 32 +- build_tools/python_deploy/build_windows.ps1 | 8 +- .../python_deploy/install_macos_deps.sh | 4 +- docs/architecture.md | 27 +- docs/code_owners.md | 2 +- docs/development.md | 2 +- docs/long_term_roadmap.md | 14 +- e2e_testing/main.py | 16 +- e2e_testing/xfail_sets.py | 111 +---- examples/torchdynamo_resnet18.py | 2 +- examples/torchscript_mhlo_backend_resnet.py | 14 + ...y => torchscript_mhlo_backend_tinybert.py} | 8 +- .../torchscript_stablehlo_backend_resnet.py | 14 - .../Dialect/TMTensor/IR/TMTensorInterfaces.h | 2 +- .../Dialect/TMTensor/IR/TMTensorInterfaces.td | 2 +- .../lib/Dialect/TMTensor/IR/TMTensorOps.cpp | 8 +- externals/llvm-project | 2 +- externals/mlir-hlo | 2 +- include/torch-mlir/Conversion/CMakeLists.txt | 4 +- include/torch-mlir/Conversion/Passes.td | 10 +- .../TorchToMhlo.h} | 11 +- .../Dialect/Torch/IR/GeneratedTorchOps.td | 356 ++------------ .../torch-mlir/Dialect/Torch/IR/TorchOps.td | 5 +- .../Dialect/Torch/Transforms/Passes.h | 5 +- .../Dialect/Torch/Transforms/Passes.td | 17 +- .../Dialect/Torch/Utils/TorchUpstream.h | 10 - .../torch-mlir/Dialect/Torch/Utils/Utils.h | 2 +- .../TorchConversion/Transforms/CMakeLists.txt | 4 +- .../TorchConversion/Transforms/Passes.h | 15 +- .../TorchConversion/Transforms/Passes.td | 10 +- lib/CAPI/TorchTypes.cpp | 8 +- lib/CMakeLists.txt | 26 +- lib/Conversion/CMakeLists.txt | 10 +- lib/Conversion/Passes.cpp | 16 +- .../TorchConversionToMLProgram.cpp | 2 +- lib/Conversion/TorchToArith/TorchToArith.cpp | 71 +-- lib/Conversion/TorchToLinalg/DataMovement.cpp | 20 +- lib/Conversion/TorchToLinalg/Reduction.cpp | 56 +-- .../TorchToLinalg/TensorConstructors.cpp | 14 +- .../TorchToLinalg/Uncategorized.cpp | 193 +++----- .../Basic.cpp | 407 +++++++--------- lib/Conversion/TorchToMhlo/CMakeLists.txt | 35 ++ .../Gather.cpp | 46 +- .../Linear.cpp | 190 ++++---- .../MhloLegalizeUtils.cpp} | 76 ++- .../MhloLegalizeUtils.h} | 24 +- .../Pooling.cpp | 307 +++++------- lib/Conversion/TorchToMhlo/PopulatePatterns.h | 74 +++ .../Reduction.cpp | 247 +++++----- .../TorchToMhlo.cpp} | 54 +- .../ViewLike.cpp | 68 ++- .../TorchToStablehlo/CMakeLists.txt | 29 -- .../TorchToStablehlo/PopulatePatterns.h | 69 --- .../TorchToTMTensor/TorchToTMTensor.cpp | 380 +-------------- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 360 +------------- .../TorchToTosa/TosaLegalizeUtils.cpp | 4 - lib/Dialect/Torch/IR/TorchDialect.cpp | 5 +- lib/Dialect/Torch/IR/TorchOps.cpp | 357 +++++--------- lib/Dialect/Torch/IR/TorchTypes.cpp | 38 +- .../Transforms/AbstractInterpLibrary.cpp | 345 +------------ .../Transforms/AdjustCallingConventions.cpp | 1 + lib/Dialect/Torch/Transforms/CMakeLists.txt | 1 - .../Torch/Transforms/DecomposeComplexOps.cpp | 461 +++--------------- .../Transforms/EraseModuleInitializer.cpp | 2 +- .../Torch/Transforms/GlobalizeObjectGraph.cpp | 14 +- .../Torch/Transforms/InlineGlobalSlots.cpp | 4 +- .../Transforms/LowerToBackendContract.cpp | 28 +- lib/Dialect/Torch/Transforms/Passes.cpp | 1 - .../PrepareForGlobalizeObjectGraph.cpp | 1 + .../Torch/Transforms/RecomposeComplexOps.cpp | 103 ---- lib/Dialect/Torch/Transforms/RefineTypes.cpp | 128 +++-- .../Transforms/SimplifyDtypeCalculations.cpp | 9 +- .../Transforms/SimplifyShapeCalculations.cpp | 9 +- lib/Dialect/Torch/Utils/TorchUpstream.cpp | 19 - lib/Dialect/Torch/Utils/Utils.cpp | 10 +- .../IR/TorchConversionDialect.cpp | 4 +- .../TorchConversion/IR/TorchConversionOps.cpp | 16 +- .../BackendTypeConversionPasses.cpp | 1 + .../TorchConversion/Transforms/CMakeLists.txt | 4 +- .../TorchConversion/Transforms/Passes.cpp | 41 +- ...ract.cpp => VerifyMhloBackendContract.cpp} | 33 +- lib/InitAll.cpp | 11 - lib/RefBackend/RefBackend.cpp | 2 +- python/CMakeLists.txt | 22 +- python/torch_mlir/__init__.py | 18 +- python/torch_mlir/compiler_utils.py | 2 +- .../csrc/base_lazy_backend/CMakeLists.txt | 1 - .../mlir_lowering_context.cpp | 53 +- .../base_lazy_backend/mlir_lowering_context.h | 8 +- .../base_lazy_backend/mlir_node_lowering.cpp | 11 - .../base_lazy_backend/shape_inference.cpp | 7 - .../base_lazy_backend/utils/jit_utils.cpp | 45 -- .../csrc/base_lazy_backend/utils/jit_utils.h | 10 - .../build_tools/abstract_interp_lib_gen.py | 56 +-- .../jit_ir/build_tools/library_generator.py | 2 +- .../jit_ir/build_tools/torch_ods_gen.py | 30 +- python/torch_mlir/dynamo.py | 16 +- .../torch_mlir_e2e_test/configs/__init__.py | 2 +- .../{stablehlo_backend.py => mhlo_backend.py} | 21 +- .../configs/torchdynamo.py | 40 +- python/torch_mlir_e2e_test/framework.py | 4 +- .../linalg_on_tensors_backends/refbackend.py | 2 +- .../__init__.py | 0 .../abc.py | 17 +- .../linalg_on_tensors.py | 23 +- .../test_suite/__init__.py | 4 +- .../test_suite/backprop.py | 22 - .../torch_mlir_e2e_test/test_suite/basic.py | 370 +------------- python/torch_mlir_e2e_test/test_suite/cast.py | 4 +- .../test_suite/constant_alloc.py | 47 +- python/torch_mlir_e2e_test/test_suite/conv.py | 34 +- .../test_suite/elementwise.py | 122 +---- .../test_suite/elementwise_comparison.py | 80 --- .../histogram_binning_calibration.py | 2 +- .../test_suite/{scatter.py => index_put.py} | 148 ------ .../test_suite/index_select.py | 14 +- .../test_suite/nll_loss.py | 12 +- .../test_suite/reduction.py | 92 +--- .../test_suite/reshape_like.py | 17 - python/torch_mlir_e2e_test/test_suite/rng.py | 75 --- .../torch_mlir_e2e_test/test_suite/scalar.py | 59 +-- .../test_suite/scalar_comparison.py | 8 +- .../test_suite/slice_like.py | 70 +-- .../test_suite/threshold.py | 18 +- .../test_suite/type_promotion.py | 2 +- pytorch-hash.txt | 2 +- pytorch-requirements.txt | 3 +- requirements.txt | 7 +- setup.py | 15 +- test-requirements.txt | 5 - test/CMakeLists.txt | 2 +- test/Conversion/TorchToMhlo/basic.mlir | 84 ++-- test/Conversion/TorchToMhlo/elementwise.mlir | 114 ++--- test/Conversion/TorchToMhlo/gather.mlir | 14 +- test/Conversion/TorchToMhlo/linear.mlir | 78 +-- test/Conversion/TorchToMhlo/lit.local.cfg | 2 +- test/Conversion/TorchToMhlo/pooling.mlir | 76 +-- test/Conversion/TorchToMhlo/view_like.mlir | 38 +- test/Conversion/TorchToTosa/basic.mlir | 19 - test/Dialect/Torch/canonicalize.mlir | 138 ------ test/Dialect/Torch/invalid.mlir | 22 +- test/Dialect/Torch/refine-types-ops.mlir | 16 - test/Dialect/Torch/refine-types.mlir | 24 - .../Torch/verify-backend-contract-error.mlir | 12 +- test/lit.site.cfg.py.in | 2 +- test/python/smoketest.py | 5 - tools/torch-mlir-opt/torch-mlir-opt.cpp | 9 +- torchvision-requirements.txt | 3 - utils/bazel/docker/Dockerfile | 6 +- utils/bazel/torch-mlir-overlay/BUILD.bazel | 18 +- .../bazel/torch-mlir-overlay/test/BUILD.bazel | 2 +- whl-requirements.txt | 1 - 165 files changed, 1846 insertions(+), 5802 deletions(-) create mode 100644 examples/torchscript_mhlo_backend_resnet.py rename examples/{torchscript_stablehlo_backend_tinybert.py => torchscript_mhlo_backend_tinybert.py} (69%) delete mode 100644 examples/torchscript_stablehlo_backend_resnet.py rename include/torch-mlir/Conversion/{TorchToStablehlo/TorchToStablehlo.h => TorchToMhlo/TorchToMhlo.h} (64%) rename lib/Conversion/{TorchToStablehlo => TorchToMhlo}/Basic.cpp (80%) create mode 100644 lib/Conversion/TorchToMhlo/CMakeLists.txt rename lib/Conversion/{TorchToStablehlo => TorchToMhlo}/Gather.cpp (87%) rename lib/Conversion/{TorchToStablehlo => TorchToMhlo}/Linear.cpp (83%) rename lib/Conversion/{TorchToStablehlo/StablehloLegalizeUtils.cpp => TorchToMhlo/MhloLegalizeUtils.cpp} (84%) rename lib/Conversion/{TorchToStablehlo/StablehloLegalizeUtils.h => TorchToMhlo/MhloLegalizeUtils.h} (79%) rename lib/Conversion/{TorchToStablehlo => TorchToMhlo}/Pooling.cpp (63%) create mode 100644 lib/Conversion/TorchToMhlo/PopulatePatterns.h rename lib/Conversion/{TorchToStablehlo => TorchToMhlo}/Reduction.cpp (73%) rename lib/Conversion/{TorchToStablehlo/TorchToStablehlo.cpp => TorchToMhlo/TorchToMhlo.cpp} (58%) rename lib/Conversion/{TorchToStablehlo => TorchToMhlo}/ViewLike.cpp (88%) delete mode 100644 lib/Conversion/TorchToStablehlo/CMakeLists.txt delete mode 100644 lib/Conversion/TorchToStablehlo/PopulatePatterns.h delete mode 100644 lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp rename lib/Dialect/TorchConversion/Transforms/{VerifyStablehloBackendContract.cpp => VerifyMhloBackendContract.cpp} (66%) delete mode 100644 python/torch_mlir/csrc/base_lazy_backend/utils/jit_utils.cpp delete mode 100644 python/torch_mlir/csrc/base_lazy_backend/utils/jit_utils.h rename python/torch_mlir_e2e_test/configs/{stablehlo_backend.py => mhlo_backend.py} (74%) rename python/torch_mlir_e2e_test/{stablehlo_backends => mhlo_backends}/__init__.py (100%) rename python/torch_mlir_e2e_test/{stablehlo_backends => mhlo_backends}/abc.py (76%) rename python/torch_mlir_e2e_test/{stablehlo_backends => mhlo_backends}/linalg_on_tensors.py (66%) rename python/torch_mlir_e2e_test/test_suite/{scatter.py => index_put.py} (78%) delete mode 100644 test-requirements.txt delete mode 100644 torchvision-requirements.txt diff --git a/.github/actions/setup-build/action.yml b/.github/actions/setup-build/action.yml index 85c3f7516..4cb234aa6 100644 --- a/.github/actions/setup-build/action.yml +++ b/.github/actions/setup-build/action.yml @@ -17,7 +17,7 @@ runs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: '3.11' + python-version: '3.10' - name: Install MLIR Python depends run: | @@ -26,8 +26,7 @@ runs: - name: Install PyTorch nightly depends run: | - python -m pip install -r pytorch-requirements.txt - python -m pip install -r build-requirements.txt + python -m pip install -r requirements.txt shell: bash - name: Install prerequisites (Linux) diff --git a/.github/workflows/RollPyTorch.yml b/.github/workflows/RollPyTorch.yml index 373e9618e..24e64922f 100644 --- a/.github/workflows/RollPyTorch.yml +++ b/.github/workflows/RollPyTorch.yml @@ -8,19 +8,12 @@ on: jobs: build_linux: name: Manylinux Build - runs-on: a100 + runs-on: ubuntu-latest # Don't run this in everyone's forks. if: github.repository == 'llvm/torch-mlir' steps: - - - name: Prepare workspace - run: | - # Clear the workspace directory so that we don't run into errors about - # existing lock files. - sudo rm -rf $GITHUB_WORKSPACE/* - - name: Get torch-mlir uses: actions/checkout@v3 with: @@ -38,7 +31,6 @@ jobs: cd ${GITHUB_WORKSPACE} python -m pip install wheel - sudo apt-get install unzip # Fetch the most recent nightly torchvision release VISION_RELEASE=$(python -m pip index versions -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre torchvision | grep "Available versions" | tr ' ' '\n' | grep "^[0-9]" | sort --version-sort --reverse | head -n1 | tr -d ',' | sed 's/\([^+]*\).*/\1/') @@ -52,8 +44,7 @@ jobs: # Read the version from the downloaded whl file without extracting it PT_RELEASE=$(unzip -p torch-*.whl 'torch-*/METADATA' | grep "^Version:" | awk '{ print $2 }' | sed 's/\([^+]*\).*/\1/') echo "Found torch release ${PT_RELEASE}" - printf -- "-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html\n--pre\ntorch==%s\n" "${PT_RELEASE}" > pytorch-requirements.txt - printf -- "-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html\n--pre\ntorchvision==%s\n" "${VISION_RELEASE}" > torchvision-requirements.txt + printf -- "-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html\n--pre\ntorch==%s\ntorchvision==%s\n" "${PT_RELEASE}" "${VISION_RELEASE}" > pytorch-requirements.txt # Read the commit hash from the downloaded whl file without extracting it PT_HASH=$(unzip -p torch-"${PT_RELEASE}"*.whl torch/version.py | grep git_version | awk '{ print $3 }' | tr -d "'") @@ -105,7 +96,7 @@ jobs: git fetch --recurse-submodules=no git checkout main git pull origin main - git add pytorch-hash.txt pytorch-requirements.txt torchvision-requirements.txt lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td + git add pytorch-hash.txt pytorch-requirements.txt lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td git diff --cached --exit-code || (git commit -m "update PyTorch version to ${{ env.PT_RELEASE }}" && git push --set-upstream origin main) - name: Update PyTorch Build Cache (if running on main branch) diff --git a/.github/workflows/bazelBuildAndTest.yml b/.github/workflows/bazelBuildAndTest.yml index 43630adcb..dd053ae7f 100644 --- a/.github/workflows/bazelBuildAndTest.yml +++ b/.github/workflows/bazelBuildAndTest.yml @@ -20,12 +20,6 @@ jobs: runs-on: ubuntu-latest steps: - - name: Prepare workspace - run: | - # Clear the workspace directory so that we don't run into errors about - # existing lock files. - sudo rm -rf $GITHUB_WORKSPACE/* - - name: Checkout torch-mlir uses: actions/checkout@v3 with: diff --git a/.github/workflows/buildAndTest.yml b/.github/workflows/buildAndTest.yml index 5d05f0e51..7dd42ada7 100644 --- a/.github/workflows/buildAndTest.yml +++ b/.github/workflows/buildAndTest.yml @@ -51,14 +51,6 @@ jobs: runs-on: ${{ matrix.os }} steps: - - - name: Prepare workspace - if: ${{ matrix.os-arch == 'ubuntu-x86_64' }} - run: | - # Clear the workspace directory so that we don't run into errors about - # existing lock files. - sudo rm -rf $GITHUB_WORKSPACE/* - - name: Checkout torch-mlir uses: actions/checkout@v3 with: @@ -121,7 +113,7 @@ jobs: -DLLVM_USE_HOST_TOOLS=ON \ -DLLVM_ENABLE_ZSTD=OFF \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ - -DTORCH_MLIR_ENABLE_STABLEHLO=OFF \ + -DTORCH_MLIR_ENABLE_MHLO=OFF \ -DTORCH_MLIR_ENABLE_LTC=OFF \ -DTORCH_MLIR_USE_INSTALLED_PYTORCH="${{ matrix.torch-binary }}" \ -DMACOSX_DEPLOYMENT_TARGET=12.0 \ diff --git a/.github/workflows/buildRelease.yml b/.github/workflows/buildRelease.yml index 9bd30e243..9c8cc488b 100644 --- a/.github/workflows/buildRelease.yml +++ b/.github/workflows/buildRelease.yml @@ -13,25 +13,8 @@ on: jobs: build_linux: name: Manylinux Build - runs-on: a100 - strategy: - matrix: - package: [ torch-mlir, torch-mlir-core ] - py_version: [ cp38-cp38, cp310-cp310, cp311-cp311 ] - exclude: - - package: torch-mlir-core - py_version: cp38-cp38 - - package: torch-mlir-core - py_version: cp310-cp310 - + runs-on: ubuntu-latest steps: - - - name: Prepare workspace - run: | - # Clear the workspace directory so that we don't run into errors about - # existing lock files. - sudo rm -rf $GITHUB_WORKSPACE/* - - name: Get torch-mlir uses: actions/checkout@v3 with: @@ -45,7 +28,7 @@ jobs: python -m pip install wheel TM_PACKAGE_VERSION=${{ github.event.inputs.python_package_version }} printf "TORCH_MLIR_PYTHON_PACKAGE_VERSION=%s\n" $TM_PACKAGE_VERSION > ./torch_mlir_package_version - TM_PYTHON_VERSIONS=${{ matrix.py_version }} TM_PACKAGES=${{ matrix.package }} ./build_tools/python_deploy/build_linux_packages.sh + ./build_tools/python_deploy/build_linux_packages.sh # If we were given a release_id, then upload the package we just built # to the github releases page. @@ -73,7 +56,7 @@ jobs: run: mkdir dist - name: Copy releases to publish to dist directory if: github.event.inputs.release_id != '' - run: cp build_tools/python_deploy/wheelhouse/torch_mlir*.whl dist/ + run: cp build_tools/python_deploy/wheelhouse/torch_mlir-*.whl dist/ # Wheels must be published from a linux environment. # @@ -87,9 +70,6 @@ jobs: build_macos: name: MacOS Build runs-on: macos-latest - strategy: - matrix: - package: [ torch-mlir, torch-mlir-core ] steps: - name: Get torch-mlir uses: actions/checkout@v3 @@ -105,7 +85,7 @@ jobs: TM_PACKAGE_VERSION=${{ github.event.inputs.python_package_version }} printf "TORCH_MLIR_PYTHON_PACKAGE_VERSION=%s\n" $TM_PACKAGE_VERSION > ./torch_mlir_package_version sudo ./build_tools/python_deploy/install_macos_deps.sh - packages=${{ matrix.package }} TORCH_MLIR_PYTHON_VERSIONS="3.11" ./build_tools/python_deploy/build_macos_packages.sh + TORCH_MLIR_PYTHON_VERSIONS="3.10" ./build_tools/python_deploy/build_macos_packages.sh # If we were given a release_id, then upload the package we just built # to the github releases page. @@ -133,7 +113,7 @@ jobs: run: mkdir dist - name: Copy releases to publish to dist directory if: github.event.inputs.release_id != '' - run: cp build_tools/python_deploy/wheelhouse/torch_mlir*.whl dist/ + run: cp build_tools/python_deploy/wheelhouse/torch_mlir-*.whl dist/ # Wheels must be published from a linux environment. # @@ -147,9 +127,6 @@ jobs: build_windows: name: Windows Build runs-on: windows-latest - strategy: - matrix: - package: [ torch-mlir, torch-mlir-core ] steps: - name: Get torch-mlir uses: actions/checkout@v3 @@ -165,14 +142,6 @@ jobs: - name: Build Python wheels and smoke test. shell: pwsh run: | - if ( "${{ matrix.package }}" -eq "torch-mlir-core" ) - { - $env:TORCH_MLIR_ENABLE_JIT_IR_IMPORTER='0' - $env:TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS='1' - } else { - $env:TORCH_MLIR_ENABLE_JIT_IR_IMPORTER='1' - $env:TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS='0' - } $env:TORCH_MLIR_PYTHON_PACKAGE_VERSION = '${{ github.event.inputs.python_package_version }}' ./build_tools/python_deploy/build_windows.ps1 @@ -203,7 +172,7 @@ jobs: continue-on-error: true - name: Copy releases to publish to dist directory if: github.event.inputs.release_id != '' - run: cp ./wheelhouse/torch_mlir*.whl dist/ + run: cp ./wheelhouse/torch_mlir-*.whl dist/ # Wheels must be published from a linux environment. # @@ -247,4 +216,4 @@ jobs: # if: github.event.inputs.release_id != '' # uses: pypa/gh-action-pypi-publish@v1.5.1 # with: - # password: ${{ secrets.PYPI_API_TOKEN }} + # password: ${{ secrets.PYPI_API_TOKEN }} \ No newline at end of file diff --git a/.github/workflows/gh-pages-releases.yml b/.github/workflows/gh-pages-releases.yml index c6df475cc..ecac146b9 100644 --- a/.github/workflows/gh-pages-releases.yml +++ b/.github/workflows/gh-pages-releases.yml @@ -13,13 +13,8 @@ jobs: if: github.repository == 'llvm/torch-mlir' steps: - - name: Prepare workspace - run: | - # Clear the workspace directory so that we don't run into errors about - # existing lock files. - sudo rm -rf $GITHUB_WORKSPACE/* - name: Checking out repository - uses: actions/checkout@v3 + uses: actions/checkout@v2 with: token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} - name: Run scrape releases script diff --git a/.github/workflows/oneshotSnapshotPackage.yml b/.github/workflows/oneshotSnapshotPackage.yml index 46832ce9c..c21ed8f9e 100644 --- a/.github/workflows/oneshotSnapshotPackage.yml +++ b/.github/workflows/oneshotSnapshotPackage.yml @@ -10,14 +10,8 @@ jobs: # Don't run this in everyone's forks. if: github.repository == 'llvm/torch-mlir' steps: - - name: Prepare workspace - run: | - # Clear the workspace directory so that we don't run into errors about - # existing lock files. - sudo rm -rf $GITHUB_WORKSPACE/* - - name: Checking out repository - uses: actions/checkout@v3 + uses: actions/checkout@v2 with: token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} diff --git a/.github/workflows/releaseSnapshotPackage.yml b/.github/workflows/releaseSnapshotPackage.yml index c18eff88d..1037abca0 100644 --- a/.github/workflows/releaseSnapshotPackage.yml +++ b/.github/workflows/releaseSnapshotPackage.yml @@ -13,15 +13,8 @@ jobs: # Don't run this in everyone's forks. if: github.repository == 'llvm/torch-mlir' steps: - - - name: Prepare workspace - run: | - # Clear the workspace directory so that we don't run into errors about - # existing lock files. - sudo rm -rf $GITHUB_WORKSPACE/* - - name: Checking out repository - uses: actions/checkout@v3 + uses: actions/checkout@v2 with: token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} diff --git a/.gitignore b/.gitignore index 676cd3653..330a871b0 100644 --- a/.gitignore +++ b/.gitignore @@ -32,6 +32,3 @@ bazel-* build_oot/ docker_venv/ llvm-build/ - -# C++ build artifacts -compile_commands.json diff --git a/CMakeLists.txt b/CMakeLists.txt index 790fcfeb5..c20627065 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -36,18 +36,12 @@ macro(torch_mlir_add_llvm_external_project name identifier location) set(LLVM_EXTERNAL_PROJECTS ${LLVM_EXTERNAL_PROJECTS} CACHE STRING "" FORCE) endmacro() -option(TORCH_MLIR_ENABLE_STABLEHLO "Add stablehlo dialect" ON) -if(TORCH_MLIR_ENABLE_STABLEHLO) - add_definitions(-DTORCH_MLIR_ENABLE_STABLEHLO) +option(TORCH_MLIR_ENABLE_MHLO "Add mhlo dialect" ON) +if(TORCH_MLIR_ENABLE_MHLO) + add_definitions(-DTORCH_MLIR_ENABLE_MHLO) endif() -option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON) option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF) -option(TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS "Build Torch dialect MLIR Python bindings but neither JIT IR Importer nor LTC backend" OFF) -if(TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS) - set(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER OFF) - set(TORCH_MLIR_ENABLE_LTC OFF) -endif() if(TORCH_MLIR_ENABLE_LTC) set(ENV{TORCH_MLIR_ENABLE_LTC} 1) @@ -115,6 +109,7 @@ if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR OR TORCH_MLIR_OUT_OF_TREE_ # Don't try to compile the python extensions at the moment. We need # to import lots of dependencies from AddMLIRPython to make this work. set(MLIR_ENABLE_BINDINGS_PYTHON 1) + option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON) set(TORCH-MLIR_BUILT_STANDALONE 1) set(BACKEND_PACKAGE_STRING "LLVM ${LLVM_PACKAGE_VERSION}") @@ -124,6 +119,7 @@ else() # In-tree build with LLVM_EXTERNAL_PROJECTS=torch-mlir option(MLIR_ENABLE_BINDINGS_PYTHON "Enables MLIR Python Bindings" OFF) + option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON) # TODO: Fix this upstream so that global include directories are not needed. set(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir) @@ -132,8 +128,8 @@ else() set(MLIR_INCLUDE_DIRS "${MLIR_INCLUDE_DIR};${MLIR_GENERATED_INCLUDE_DIR}") endif() -if (TORCH_MLIR_ENABLE_STABLEHLO) - set(STABLEHLO_BUILD_EMBEDDED ON) +if (TORCH_MLIR_ENABLE_MHLO) + set(MHLO_BUILD_EMBEDDED ON) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo ${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo EXCLUDE_FROM_ALL) diff --git a/README.md b/README.md index bc8a9748b..3070d3c19 100644 --- a/README.md +++ b/README.md @@ -8,12 +8,13 @@ necessarily a reflection of the completeness or stability of the code, it does indicate that the project is not yet endorsed as a component of LLVM. [PyTorch](https://pytorch.org) -PyTorch is an open source machine learning framework that facilitates the seamless transition from research and prototyping to production-level deployment. +An open source machine learning framework that accelerates the path from research prototyping to production deployment. [MLIR](https://mlir.llvm.org) -The MLIR project offers a novel approach for building extensible and reusable compiler architectures, which address the issue of software fragmentation, reduce the cost of developing domain-specific compilers, improve compilation for heterogeneous hardware, and promote compatibility between existing compilers. +The MLIR project is a novel approach to building reusable and extensible compiler infrastructure. MLIR aims to address software fragmentation, improve compilation for heterogeneous hardware, significantly reduce the cost of building domain specific compilers, and aid in connecting existing compilers together. + [Torch-MLIR](https://github.com/llvm/torch-mlir) -Several vendors have adopted MLIR as the middle layer in their systems, enabling them to map frameworks such as PyTorch, JAX, and TensorFlow into MLIR and subsequently lower them to their target hardware. We have observed half a dozen custom lowerings from PyTorch to MLIR, making it easier for hardware vendors to focus on their unique value, rather than needing to implement yet another PyTorch frontend for MLIR. The ultimate aim is to be similar to the current hardware vendors adding LLVM target support, rather than each one implementing Clang or a C++ frontend. +Multiple Vendors use MLIR as the middle layer, mapping from platform frameworks like PyTorch, JAX, and TensorFlow into MLIR and then progressively lowering down to their target hardware. We have seen half a dozen custom lowerings from PyTorch to MLIR. Having canonical lowerings from the PyTorch ecosystem to the MLIR ecosystem provides much needed relief to hardware vendors to focus on their unique value rather than implementing yet another PyTorch frontend for MLIR. The goal is to be similar to current hardware vendors adding LLVM target support instead of each one also implementing Clang / a C++ frontend. [![Release Build](https://github.com/llvm/torch-mlir/actions/workflows/buildRelease.yml/badge.svg)](https://github.com/llvm/torch-mlir/actions/workflows/buildRelease.yml) @@ -42,26 +43,15 @@ We have few paths to lower down to the Torch MLIR Dialect. ## Install torch-mlir snapshot -At the time of writing, we release pre-built snapshot of torch-mlir for Python 3.10 on Linux and macOS. +This installs a pre-built snapshot of torch-mlir for Python 3.7/3.8/3.9/3.10 on Linux and macOS. -If you have Python 3.10, the following commands initialize a virtual environment. ```shell -python3.10 -m venv mlir_venv +python -m venv mlir_venv source mlir_venv/bin/activate -``` - -Or, if you want to switch over multiple versions of Python using conda, you can create a conda environment with Python 3.10. -```shell -conda create -n torch-mlir python=3.10 -conda activate torch-mlir +# Some older pip installs may not be able to handle the recent PyTorch deps python -m pip install --upgrade pip -``` - -Then, we can install torch-mlir with the corresponding torch and torchvision nightlies. -``` -pip install --pre torch-mlir torchvision \ - -f https://llvm.github.io/torch-mlir/package-index/ - --extra-index-url https://download.pytorch.org/whl/nightly/cpu +pip install --pre torch-mlir torchvision -f https://llvm.github.io/torch-mlir/package-index/ --extra-index-url https://download.pytorch.org/whl/nightly/cpu +# This will install the corresponding torch and torchvision nightlies ``` ## Demos diff --git a/build-requirements.txt b/build-requirements.txt index db0c32051..16f7da636 100644 --- a/build-requirements.txt +++ b/build-requirements.txt @@ -1,3 +1,5 @@ +-r pytorch-requirements.txt + numpy pybind11 wheel diff --git a/build_tools/python_deploy/build_linux_packages.sh b/build_tools/python_deploy/build_linux_packages.sh index fd85f6c8d..cc8bf9c52 100755 --- a/build_tools/python_deploy/build_linux_packages.sh +++ b/build_tools/python_deploy/build_linux_packages.sh @@ -39,16 +39,16 @@ set -eu -o errtrace this_dir="$(cd "$(dirname "$0")" && pwd)" repo_root="$(cd "$this_dir"/../../ && pwd)" # This needs to be a manylinux image so we can ship pip packages -TM_RELEASE_DOCKER_IMAGE="${TM_RELEASE_DOCKER_IMAGE:-gcr.io/iree-oss/manylinux2014_x86_64-release@sha256:d8994b87b45b7b2e6055fccc32db018ec73aeb05a4e43a9daa61b77cc34f846e}" +TM_RELEASE_DOCKER_IMAGE="${TM_RELEASE_DOCKER_IMAGE:-stellaraccident/manylinux2014_x86_64-bazel-5.1.0:latest}" # This assumes an Ubuntu LTS like image. You can build your own with # ./build_tools/docker/Dockerfile TM_CI_DOCKER_IMAGE="${TM_CI_DOCKER_IMAGE:-powderluv/torch-mlir-ci:latest}" # Version of Python to use in Release builds. Ignored in CIs. -TM_PYTHON_VERSIONS="${TM_PYTHON_VERSIONS:-cp38-cp38 cp310-cp310 cp311-cp311}" +TM_PYTHON_VERSIONS="${TM_PYTHON_VERSIONS:-cp310-cp310}" # Location to store Release wheels TM_OUTPUT_DIR="${TM_OUTPUT_DIR:-${this_dir}/wheelhouse}" # What "packages to build" -TM_PACKAGES="${TM_PACKAGES:-torch-mlir torch-mlir-core}" +TM_PACKAGES="${TM_PACKAGES:-torch-mlir}" # Use pre-built Pytorch TM_USE_PYTORCH_BINARY="${TM_USE_PYTORCH_BINARY:-ON}" # Skip running tests if you want quick iteration @@ -84,11 +84,6 @@ function run_on_host() { export USERID=0 export GROUPID=0 ;; - torch-mlir-core) - TM_CURRENT_DOCKER_IMAGE=${TM_RELEASE_DOCKER_IMAGE} - export USERID=0 - export GROUPID=0 - ;; out-of-tree) TM_CURRENT_DOCKER_IMAGE=${TM_CI_DOCKER_IMAGE} # CI uses only Python3.10 @@ -164,12 +159,6 @@ function run_in_docker() { clean_build torch_mlir "$python_version" ;; - torch-mlir-core) - clean_wheels torch_mlir_core "$python_version" - build_torch_mlir_core - run_audit_wheel torch_mlir_core "$python_version" - clean_build torch_mlir_core "$python_version" - ;; out-of-tree) setup_venv "$python_version" build_out_of_tree "$TM_USE_PYTORCH_BINARY" "$python_version" @@ -278,8 +267,8 @@ function test_in_tree() { echo ":::: Run Linalg e2e integration tests" python -m e2e_testing.main --config=linalg -v - echo ":::: Run StableHLO e2e integration tests" - python -m e2e_testing.main --config=stablehlo -v + echo ":::: Run MHLO e2e integration tests" + python -m e2e_testing.main --config=mhlo -v echo ":::: Run TOSA e2e integration tests" python -m e2e_testing.main --config=tosa -v @@ -288,7 +277,7 @@ function test_in_tree() { python -m e2e_testing.main --config=lazy_tensor_core -v echo ":::: Run TorchDynamo e2e integration tests" - python -m e2e_testing.main --config=torchdynamo -v --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed RandnDtypeDeviceModule_basic RandnLikeModule_basic Matmul_dot + python -m e2e_testing.main --config=torchdynamo -v --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed RandnDtypeDeviceModule_basic } function setup_venv() { @@ -384,15 +373,6 @@ function run_audit_wheel() { rm "$generic_wheel" } -function build_torch_mlir_core() { - python -m pip install --no-cache-dir -r /main_checkout/torch-mlir/build-requirements.txt - CMAKE_GENERATOR=Ninja \ - TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \ - TORCH_MLIR_ENABLE_JIT_IR_IMPORTER=0 \ - TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS=1 \ - python -m pip wheel -v -w /wheelhouse /main_checkout/torch-mlir -} - function clean_wheels() { local wheel_basename="$1" local python_version="$2" diff --git a/build_tools/python_deploy/build_macos_packages.sh b/build_tools/python_deploy/build_macos_packages.sh index b928c1e48..18606a0c2 100755 --- a/build_tools/python_deploy/build_macos_packages.sh +++ b/build_tools/python_deploy/build_macos_packages.sh @@ -20,7 +20,7 @@ set -eu -o errtrace this_dir="$(cd "$(dirname "$0")" && pwd)" repo_root="$(cd "$this_dir"/../../ && pwd)" -python_versions="${TORCH_MLIR_PYTHON_VERSIONS:-3.9 3.10 3.11}" +python_versions="${TORCH_MLIR_PYTHON_VERSIONS:-3.9 3.10}" output_dir="${output_dir:-${this_dir}/wheelhouse}" packages="${packages:-torch-mlir}" @@ -61,11 +61,6 @@ function run() { build_torch_mlir torch_mlir "$python_version" run_audit_wheel torch_mlir "$python_version" ;; - torch-mlir-core) - clean_wheels torch_mlir_core "$python_version" - build_torch_mlir_core torch_mlir_core "$python_version" - run_audit_wheel torch_mlir_core "$python_version" - ;; *) echo "Unrecognized package '$package'" exit 1 @@ -82,8 +77,7 @@ function build_torch_mlir() { python"${python_version}" -m venv "$output_dir"/build_venv source "$output_dir"/build_venv/bin/activate python"${python_version}" -m pip install -U pip - python"${python_version}" -m pip install -r "$repo_root"/pytorch-requirements.txt --extra-index-url https://download.pytorch.org/whl/nightly/cpu - python"${python_version}" -m pip install -r "$repo_root"/build-requirements.txt + python"${python_version}" -m pip install -r "$repo_root"/requirements.txt --extra-index-url https://download.pytorch.org/whl/nightly/cpu CMAKE_GENERATOR=Ninja \ TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \ MACOSX_DEPLOYMENT_TARGET=$MACOSX_DEPLOYMENT_TARGET \ @@ -93,25 +87,6 @@ function build_torch_mlir() { rm -rf "$output_dir"/build_venv } -function build_torch_mlir_core() { - local wheel_basename="$1" - local python_version="$2" - rm -rf "$output_dir"/build_venv - python"${python_version}" -m venv "$output_dir"/build_venv - source "$output_dir"/build_venv/bin/activate - python"${python_version}" -m pip install -U pip delocate - python"${python_version}" -m pip install -r "$repo_root"/build-requirements.txt - CMAKE_GENERATOR=Ninja \ - TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \ - MACOSX_DEPLOYMENT_TARGET=$MACOSX_DEPLOYMENT_TARGET \ - CMAKE_OSX_ARCHITECTURES=$CMAKE_OSX_ARCHITECTURES \ - TORCH_MLIR_ENABLE_JIT_IR_IMPORTER=0 \ - TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS=1 \ - python"${python_version}" -m pip wheel -v -w "$output_dir" "$repo_root" - deactivate - rm -rf "$output_dir"/build_venv -} - function clean_wheels() { local wheel_basename="$1" local python_version="$2" @@ -132,8 +107,7 @@ function run_audit_wheel() { python"${python_version}" -m venv "$output_dir"/test_venv source "$output_dir"/test_venv/bin/activate python"${python_version}" -m pip install -U pip - python"${python_version}" -m pip install -r "$repo_root"/pytorch-requirements.txt --extra-index-url https://download.pytorch.org/whl/nightly/cpu - python"${python_version}" -m pip install -r "$repo_root"/build-requirements.txt + python"${python_version}" -m pip install -r "$repo_root"/requirements.txt --extra-index-url https://download.pytorch.org/whl/nightly/cpu python"${python_version}" -m pip install "$generic_wheel" --extra-index-url https://download.pytorch.org/whl/nightly/cpu DYLD_LIBRARY_PATH="$output_dir"/test_venv/lib/python"${python_version}"/site-packages/torch/lib delocate-wheel -v "$generic_wheel" deactivate diff --git a/build_tools/python_deploy/build_windows.ps1 b/build_tools/python_deploy/build_windows.ps1 index 808a16cb1..2c934ccf3 100644 --- a/build_tools/python_deploy/build_windows.ps1 +++ b/build_tools/python_deploy/build_windows.ps1 @@ -13,9 +13,7 @@ Write-Host "Installing Build Dependencies" python -m venv .\mlir_venv\ .\mlir_venv\Scripts\Activate.PS1 -pip install -r .\pytorch-requirements.txt -pip install -r .\build-requirements.txt -pip install delvewheel +pip install -r .\requirements.txt Write-Host "Build Deps installation completed successfully" Write-Host "Building torch-mlir" @@ -24,7 +22,3 @@ $env:TORCH_MLIR_ENABLE_LTC='0' python -m pip wheel -v -w wheelhouse ./ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html -r whl-requirements.txt Write-Host "Build completed successfully" - -Write-Host "Fixing up wheel dependencies" -delvewheel repair --add-path .\build\cmake_build\tools\torch-mlir\python_packages\torch_mlir\torch_mlir\_mlir_libs --add-dll TorchMLIRAggregateCAPI.dll --no-dll 'c10.dll;torch_python.dll;torch_cpu.dll' -v (get-item .\wheelhouse\torch_mlir*.whl).FullName -Write-Host "All Done." diff --git a/build_tools/python_deploy/install_macos_deps.sh b/build_tools/python_deploy/install_macos_deps.sh index 4d91a244c..4b413f01d 100755 --- a/build_tools/python_deploy/install_macos_deps.sh +++ b/build_tools/python_deploy/install_macos_deps.sh @@ -19,13 +19,11 @@ if [[ "$(whoami)" != "root" ]]; then fi PYTHON_INSTALLER_URLS=( - "https://www.python.org/ftp/python/3.11.2/python-3.11.2-macos11.pkg" - "https://www.python.org/ftp/python/3.10.10/python-3.10.10-macos11.pkg" + "https://www.python.org/ftp/python/3.10.5/python-3.10.5-macos11.pkg" "https://www.python.org/ftp/python/3.9.13/python-3.9.13-macos11.pkg" ) PYTHON_SPECS=( - 3.11@https://www.python.org/ftp/python/3.11.2/python-3.11.2-macos11.pkg 3.10@https://www.python.org/ftp/python/3.10.5/python-3.10.5-macos11.pkg 3.9@https://www.python.org/ftp/python/3.9.13/python-3.9.13-macos11.pkg ) diff --git a/docs/architecture.md b/docs/architecture.md index 043bb74ec..1619f81a8 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -30,14 +30,14 @@ it to various target dialects of interest to the MLIR ecosystem (various - Linalg-on-Tensors (+ `arith`, `tensor`, etc.) - [TOSA](https://mlir.llvm.org/docs/Dialects/TOSA/) -- [StableHLO](https://github.com/openxla/stablehlo) +- [MHLO](https://github.com/tensorflow/mlir-hlo) The terms "frontend" and "backend" are highly overloaded in any compiler project, but frequently in Torch-MLIR this is the meaning that they have. Sometimes "frontend" can mean something even further up the stack, such as something in PyTorch itself. When there is ambiguity we will refer to this as "at the PyTorch level". Similarly, "backend" can sometimes refer to something -sitting below Linalg-on-Tensors, TOSA, or StableHLO. +sitting below Linalg-on-Tensors, TOSA, or MHLO. ## The `torch` dialect @@ -118,8 +118,8 @@ See [satisfiesBackendContract](https://github.com/llvm/torch-mlir/blob/114f48e96 The backend contract is a normalized form of the `torch` dialect with a set of properties that make it easy to lower into various forms such as -Linalg-on-Tensors, TOSA, StableHLO, or other forms that we don't provide out of -the box. The primary guarantees that we provide Torch-MLIR's backends are: +Linalg-on-Tensors, TOSA, MHLO, or other forms that we don't provide out of the +box. The primary guarantees that we provide Torch-MLIR's backends are: - All tensors have been converted to value semantics. - All tensors have at least a known number of dimensions (i.e. rank), and @@ -270,7 +270,7 @@ lower it to the requirements of each backend. The 3 backends are: - [`linalg`](https://mlir.llvm.org/docs/Dialects/Linalg/) on tensors (+ `arith`, `tensor`, etc.) - [TOSA](https://mlir.llvm.org/docs/Dialects/TOSA/) -- [StableHLO](https://github.com/openxla/stablehlo) +- [MHLO](https://github.com/tensorflow/mlir-hlo) ### The Linalg Backend (Linalg-on-Tensors) @@ -297,15 +297,15 @@ many users (especially "hardware" or "hardware-adjacent" folks). Some of its cha - It is extremely solid with static shapes (and many of its users only care about static shapes, so that's fine). -### The StableHLO Backend +### The MHLO Backend -Code: https://github.com/llvm/torch-mlir/tree/main/lib/Conversion/TorchToStablehlo +Code: https://github.com/llvm/torch-mlir/tree/main/lib/Conversion/TorchToMhlo -The StableHLO backend was the third backend that we added, and it offers a -reasonable blend of the benefits of the other two. +The MHLO backend was the third backend that we added, and it offers a reasonable +blend of the benefits of the other two. - It is a coarse-grained named-op approach. - It has a pretty clear spec for most of the ops (with a bit of mental - translation and hoping that StableHLO is the same as HLO): + translation and hoping that MHLO is the same as HLO): https://www.tensorflow.org/xla/operation_semantics - It functionally supports dynamic shapes (though not as coherent and consistent as Linalg-on-Tensors, and the dynamic shape support falls outside the @@ -317,7 +317,7 @@ reasonable blend of the benefits of the other two. example, TOSA limits (for highly considered reasons) the number of dimensions that certain operators can handle to 1D-4D, when from a purely algebraic perspective there isn't a good reason to not be more general. Similarly, more - general forms of reduction and scatter also fall into StableHLO nicely while + general forms of reduction and scatter also fall into MHLO nicely while TOSA's principles tend to bias it away from that. ### Backend Implementation @@ -433,9 +433,8 @@ filling in some corners missing upstream and to pull together upstream functionality into a working system. The RefBackend accepts Linalg-on-Tensors as input. It mainly just bufferizes the -ops and lowers them to loops. Note that TOSA and StableHLO (via MHLO) support -lowering to Linalg-on-Tensors, so all our end-to-end testing bottoms out on -RefBackend. +ops and lowers them to loops. Note that TOSA and MHLO support lowering to +Linalg-on-Tensors, so all our end-to-end testing bottoms out on RefBackend. The RefBackend is absolutely not suitable for any production use case. It leaks memory, doesn't support any error handling, performs no optimizations, and diff --git a/docs/code_owners.md b/docs/code_owners.md index d70b236b9..299f19656 100644 --- a/docs/code_owners.md +++ b/docs/code_owners.md @@ -34,7 +34,7 @@ and Clang's - Eric Kunze (@eric-k256) - Suraj Sudhir (@sjarus) -### TorchToStablehlo +### TorchToMHLO - Tianyo Kwok (@tanyokwok) - Ziheng Jiang (@ZihengJiang) diff --git a/docs/development.md b/docs/development.md index 048f363c0..f6e976769 100644 --- a/docs/development.md +++ b/docs/development.md @@ -139,7 +139,7 @@ Ex: module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch") ``` -Currently, `output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `STABLEHLO`. +Currently, `output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `MHLO`. ## Jupyter diff --git a/docs/long_term_roadmap.md b/docs/long_term_roadmap.md index 62c3b6f94..1e8981da1 100644 --- a/docs/long_term_roadmap.md +++ b/docs/long_term_roadmap.md @@ -46,7 +46,7 @@ the ecosystem are: - The frontend work required to lower TorchScript to the backend contract. - The irregular support surface area of the large number of PyTorch ops across - the Linalg, TOSA, and StableHLO backends. + the Linalg, TOSA, and MHLO backends. Most of this document describes long-term ecosystem changes that will address these, drastically improving Torch-MLIR's ability to meet its goals. @@ -108,7 +108,7 @@ more advanced). ### Refactoring the backend Today in Torch-MLIR, we support 3 backends out of the box: Linalg-on-Tensors, -TOSA, and StableHLO. These backends take IR in the backend contract form (see +TOSA, and MHLO. These backends take IR in the backend contract form (see [architecture.md](architecture.md)) and lowers them to the respective dialects. Today, each backend is implemented completely independently. This leads to duplication and irregularity across the backends. @@ -120,10 +120,12 @@ lowering of so many ops across backends. Additionally, there are 3 forward-looking efforts that intersect with this effort: - [StableHLO](https://github.com/openxla/stablehlo) - this is a dialect - initially forked from MHLO. MHLO is a fairly complete op set, so it is very - attractive to have "almost all" models bottleneck through a stable interface - like StableHLO. StableHLO is currently under relatively early development, - but already delivers on many of the goals of stability. + initially forked from MHLO which intends to create a stable support surface + area for what today is our "at head" dependency on MHLO. MHLO is a fairly + complete op set, so it is very attractive to have "almost all" models + bottleneck through a stable interface like StableHLO. StableHLO is currently + under relatively early development, but already delivers on many of the goals + of stability. - [TCP](https://github.com/llvm/torch-mlir/issues/1366) - this is a dialect which could serve a role very similar to MHLO, while providing community ownership. TCP is still in early planning phases, but there is strong diff --git a/e2e_testing/main.py b/e2e_testing/main.py index 770d32ca5..d48223ad4 100644 --- a/e2e_testing/main.py +++ b/e2e_testing/main.py @@ -16,7 +16,7 @@ from torch_mlir_e2e_test.registry import GLOBAL_TEST_REGISTRY from torch_mlir_e2e_test.configs import ( LazyTensorCoreTestConfig, LinalgOnTensorsBackendTestConfig, - StablehloBackendTestConfig, + MhloBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig, @@ -24,17 +24,17 @@ from torch_mlir_e2e_test.configs import ( ) from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend -from torch_mlir_e2e_test.stablehlo_backends.linalg_on_tensors import LinalgOnTensorsStablehloBackend +from torch_mlir_e2e_test.mhlo_backends.linalg_on_tensors import LinalgOnTensorsMhloBackend from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend -from .xfail_sets import LINALG_XFAIL_SET, STABLEHLO_PASS_SET, TOSA_PASS_SET, LTC_XFAIL_SET, TORCHDYNAMO_XFAIL_SET +from .xfail_sets import LINALG_XFAIL_SET, MHLO_PASS_SET, TOSA_PASS_SET, LTC_XFAIL_SET, TORCHDYNAMO_XFAIL_SET # Import tests to register them in the global registry. from torch_mlir_e2e_test.test_suite import register_all_tests register_all_tests() def _get_argparse(): - config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "tosa", "lazy_tensor_core", "torchdynamo"] + config_choices = ["native_torch", "torchscript", "linalg", "mhlo", "tosa", "lazy_tensor_core", "torchdynamo"] parser = argparse.ArgumentParser(description="Run torchscript e2e tests.") parser.add_argument("-c", "--config", choices=config_choices, @@ -42,7 +42,7 @@ def _get_argparse(): help=f""" Meaning of options: "linalg": run through torch-mlir"s default Linalg-on-Tensors backend. -"stablehlo": run through torch-mlir"s default StableHLO backend. +"mhlo": run through torch-mlir"s default MHLO backend. "tosa": run through torch-mlir"s default TOSA backend. "native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration). "torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly). @@ -80,9 +80,9 @@ def main(): if args.config == "tosa": config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend()) xfail_set = all_test_unique_names - TOSA_PASS_SET - if args.config == "stablehlo": - config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend()) - xfail_set = all_test_unique_names - STABLEHLO_PASS_SET + if args.config == "mhlo": + config = MhloBackendTestConfig(LinalgOnTensorsMhloBackend()) + xfail_set = all_test_unique_names - MHLO_PASS_SET elif args.config == "native_torch": config = NativeTorchTestConfig() xfail_set = {} diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 5be3b6c11..df51514df 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -26,11 +26,6 @@ TORCHDYNAMO_XFAIL_SET = { # https://github.com/pytorch/pytorch/issues/89629 "ConvolutionBackwardModule2DPadded_basic", "ConvolutionBackwardModule2D_basic", - - # error: 'tensor.expand_shape' op expected dimension 0 of collapsed type to be dynamic since one or more of the corresponding dimensions in the expanded type is dynamic - # https://github.com/llvm/torch-mlir/issues/1859 - "ConvolutionModule2DGroups_basic", - # RuntimeError: Index tensor must have the same number of dimensions as self tensor # RuntimeError: Failed running call_function aten.nll_loss_backward(... # https://github.com/pytorch/pytorch/issues/89630 @@ -44,6 +39,10 @@ TORCHDYNAMO_XFAIL_SET = { # RuntimeError: Failed running call_function aten.uniform(... # https://github.com/pytorch/torchdynamo/issues/1954 "UniformNoCorrelationModule_basic", + # TypeError: expected np.ndarray (got float) + # TODO: This is due to returning a scalar float as output from the test. + # We should probably just standardize all tests to return tensors. + "DivIntModule_basic", #### Torch-MLIR internal compiler errors @@ -67,13 +66,14 @@ TORCHDYNAMO_XFAIL_SET = { "IndexPutImpl2DFloatNonAccumulateModule_basic", "IndexPutImpl3DFloatAccumulateModule_basic", "IndexPutImpl3DFloatNonAccumulateModule_basic", + # %4 = torch.operator "aten.squeeze_.dim"(%3, %int0) : (!torch.tensor<*,f32>, !torch.int) -> !torch.tensor + "Matmul_vecmat", # https://github.com/llvm/torch-mlir/issues/1611 # error: 'tensor.cast' op operand type 'tensor<0xi64>' and result type 'tensor<18xi64>' are cast incompatible "Aten_EmbeddingBagExample_basic", # error: failed to legalize operation 'torch.valsem.aten.bernoulli.float' that was explicitly marked illegal "BernoulliFloatModule_basic", - "BernoulliPModule_basic", # error: failed to legalize operation 'torch.aten.view' that was explicitly marked illegal "ElementwiseFlattenBroadcastModule_basic", "FlattenRank0Module_basic", @@ -83,16 +83,8 @@ TORCHDYNAMO_XFAIL_SET = { # error: unsupported by backend contract: tensor with unknown rank # note: see current operation: %1 = "torch.tensor_static_info_cast"(%arg0) : (!torch.vtensor<[5,4,3,2,1],f32>) -> !torch.vtensor<*,f32> "ElementwisePreluModule_basic", - - #ERROR: value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=-1.336e-32, max=+0.9152, mean=+0.4837) is not close to golden value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=+0.02233, max=+0.9152, mean=+0.4777) - "UpSampleNearest2dDynamicFactor_basic", - "ReduceMaxAlongDimUnsignedInt_basic", - #ERROR: value (-56) is not equal to golden value (200) - "AtenIntTensorByteDtypeModule_basic", - # ERROR: assert isinstance(e, FakeTensor) - "ElementwiseAddScalar_NumToTensorFloat_Module_basic", - # ERROR: assert isinstance(e, FakeTensor) - "RsubInt0d_NumToTensor_Module_basic", + # error: op lowering missing. Issue: https://github.com/llvm/torch-mlir/issues/1792 + "StdCorrectionKeepDimModule_basic", # Dtype function transition failures "MobilenetV3Module_basic", @@ -100,12 +92,8 @@ TORCHDYNAMO_XFAIL_SET = { "ResNet18StaticModule_basic", } -STABLEHLO_PASS_SET = { - "MaskedFillScalarIntValueStaticModule_basic", - "MaskedFillScalarFloatValueStaticModule_basic", +MHLO_PASS_SET = { "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", - "AddSizeIntModule_basic", - "AddSizeIntNegDimModule_basic", "ArangeDtypeFloatModule_basic", "ArangeDtypeIntModule_basic", "ArangeFalsePinMemoryModule_basic", @@ -120,15 +108,10 @@ STABLEHLO_PASS_SET = { "ArangeStartStepFloatModule_basic", "ArangeStartStepIntModule_basic", "ArangeZeroElementOutputModule_basic", - "BatchMlpLayerModule_basic", "BmmModule_basic", "BroadcastToModule_basic", "BroadcastToSameRankStaticModule_basic", "BroadcastZeroRankInputStaticModule_basic", - "BucketizeTensorStaticFloatModule_basic", - "BucketizeTensorStaticModule_basic", - "CumsumStaticModule_basic", - "CumsumStaticNegativeDimModule_basic", "ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic", "ElementwiseAtenLogicalNotOpModule_basic", "ElementwiseAtenLogicalNotOpPromoteModule_basic", @@ -143,29 +126,19 @@ STABLEHLO_PASS_SET = { "ElementwiseClampModule_basic", "ElementwiseClampMinModule_basic", "ElementwiseClampMaxModule_basic", - "ElementwisePowModule_basic", "ElementwiseExpModule_basic", - "ElementwiseFlattenBroadcastModule_basic", - "ElementwiseLeakyReluModule_basic", "ElementwiseLogModule_basic", "ElementwiseNegModule_basic", "ElementwiseRsqrtModule_basic", "ElementwiseSigmoidModule_basic", "ElementwiseSqrtModule_basic", - "ElementwiseSinModule_basic", - "ElementwiseCosModule_basic", - "ElementwiseCeilModule_basic", - "ElementwiseFloorModule_basic", "ElementwiseUnaryModule_basic", - "ElementwiseUnsqueezeBroadcastModule_basic", "ElementwiseUnsqueezeNegDimsModule_basic", "ElementwiseToDtypeF32ToI64Module_basic", "ElementwiseAddModule_basic", "ElementwiseAddScalarFloatModule_basic", "ElementwiseAddScalarInt64Module_basic", "ElementwiseAddScalarIntModule_basic", - "ElementwiseAddScalar_NumToTensorFloat_Module_basic", - "ElementwiseAddScalar_TensorLiteralInt32_Module_basic", "ElementwiseDivScalarModule_basic", "ElementwiseEqDiffWidthScalarModule_basic", "ElementwiseEqFloatScalarModule_basic", @@ -228,8 +201,6 @@ STABLEHLO_PASS_SET = { "Gather2DInputModdule_basic", "GatherRandomIndexModule_basic", "GeluBackwardModule_basic", - "HardswishModule_basic", - "HardswishRandomModule_basic", "HardTanhIntModule_basic", "HardTanhModule_basic", "HardsigmoidModule_basic", @@ -252,8 +223,6 @@ STABLEHLO_PASS_SET = { "MeanDynamicSizesModule_basic", "MeanLargeInputModule_basic", "MeanModule_basic", - "Mlp1LayerModule_basic", - "Mlp2LayerModule_basic", "MmTanhModule_basic", "Mv_basic", "NativeLayerNormModule4D_basic", @@ -270,7 +239,6 @@ STABLEHLO_PASS_SET = { "ReduceSumDtypeFloatModule_basic", "ReduceSumDtypeIntModule_basic", "SelectIntModule_basic", - "SelectIntNegativeDimAndIndexStaticModule_basic", "SliceSingleIdxModule_basic", "SqueezeDimModule_dynamic", "SqueezeDimModule_negDim", @@ -282,15 +250,9 @@ STABLEHLO_PASS_SET = { "FlattenStaticModule_basic", "FlattenRank0Module_basic", "TensorsConcatNegativeDimModule_basic", - "TensorsConcatPromoteDTypeModule_basic", - "TensorsStackModule_basic", - "TensorsStackNegativeDimModule_basic", - "TensorsStackPromoteDTypeModule_basic", "LiftFreshCopyModule_basic", "Mlp2LayerModuleNoBias_basic", "NumelModule_basic", - "SiluModule_basic", - "SquareModule_basic", "SqueezeModule_allUnitDim", "SqueezeDimModule_unitDim", "ViewCollapseOnesMiddleModule_basic", @@ -310,7 +272,6 @@ STABLEHLO_PASS_SET = { "Convolution2DStaticModule_basic", "ConvolutionModule2DTransposeStridedStatic_basic", "ElementwiseCloneContiguousModule_basic", - "ElementwiseCloneChannelsLastMemoryFormatModule_basic", "ElementwiseCloneModule_basic", "ElementwiseBinaryStaticShapeModule_basic", "ReturnThreeTensorFloat32_basic", @@ -327,7 +288,6 @@ STABLEHLO_PASS_SET = { "RsubFloatModule_noalpha_basic", "RsubIntModule_basic", "RsubIntModule_noalpha_basic", - "RsubInt0d_NumToTensor_Module_basic", "SliceStaticModule_basic", "SliceModule_basic", "SliceNegIdxModule_basic", @@ -398,7 +358,6 @@ STABLEHLO_PASS_SET = { "ViewExpandCollapseModule_basic", "ViewExpandCollapseWithOnesModule_basic", "ViewExpandInferredDimModule_basic", - "ViewNegativeStaticModule_basic", "ViewNoChangeStaticModule_basic", "ViewNoChange1dModule_basic", "ViewNoChange2dModule_basic", @@ -461,14 +420,12 @@ STABLEHLO_PASS_SET = { "UnsafeViewDynamicExpandModule_basic", "AtenRoundIntModule_basic", "TestF16Return_basic", - "_LogSoftmaxModuleStable_basic", } # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { "ElementwiseCloneContiguousModule_basic", - "ElementwiseCloneChannelsLastMemoryFormatModule_basic", "ElementwiseCloneModule_basic", "ElementwiseUnaryModule_basic", "ElementwiseBinaryModule_basic", @@ -492,7 +449,6 @@ TOSA_PASS_SET = { "ViewExpandOnesMiddleOppModule_basic", "ViewOffsetBackwardTestStaticModule_basic", "TanhBackward_basic", - "HardtanhBackward_basic", "ElementwiseAddModule_basic", "ReturnThreeTensorFloat32_basic", "AddCMulModule_basic", @@ -503,7 +459,6 @@ TOSA_PASS_SET = { "BoolTensorReturnMixedModule_basic", "BoolTensorHandleSignless_basic", "ElementwiseRsqrtModule_basic", - "SelectIntNegativeDimAndIndexStaticModule_basic", "SqueezeModule_static", "SqueezeModule_noUnitDim", "SqueezeModule_allUnitDim", @@ -525,7 +480,6 @@ TOSA_PASS_SET = { "Matmul_3d", "RsubFloatModule_basic", "RsubFloatModule_noalpha_basic", - "RsubInt0d_NumToTensor_Module_basic", "ElementwiseBitwiseAndModule_basic", "ElementwiseBitwiseAndStaticShapeModule_basic", "ElementwiseBitwiseNotInt32Module_basic", @@ -555,7 +509,6 @@ TOSA_PASS_SET = { "ElementwiseDivScalarModule_basic", "ElementwiseSubScalarFloatModule_basic", "ElementwiseAddScalarFloatModule_basic", - "ElementwiseAddScalar_TensorLiteralInt32_Module_basic", "ElementwiseMulScalarModule_float", "ElementwiseCeilModule_basic", "ElementwiseReciprocalModule_basic", @@ -619,7 +572,6 @@ TOSA_PASS_SET = { "ViewExpandCollapseWithOnesModule_basic", "ViewCollapseInferredDimModule_basic", "ViewExpandInferredDimModule_basic", - "ViewNegativeStaticModule_basic", "ViewNoChangeStaticModule_basic", "UnsafeViewExpandModule_basic", "ReshapeCollapseModule_basic", @@ -652,7 +604,6 @@ TOSA_PASS_SET = { "_LogSoftmaxModuleStable_basic", "ElementwiseAtenWhereSelfModule_basic", "ElementwiseUnsqueezeBroadcastModule_basic", - "MaskedFillScalarIntValueModule_basic", "MaskedFillScalarIntValueStaticModule_basic", "MaskedFillTensorIntValueStaticModule_basic", "ElementwiseAddScalarInt64Module_basic", @@ -660,11 +611,8 @@ TOSA_PASS_SET = { "TensorOpaqueLiteralModule_basic", "TypePromotionDifferentCategoryModule_basic", "TypePromotionSameCategoryDifferentWidthModule_basic", - "TypePromotionSameCategoryZeroRankWider_basic", "TypePromotionZeroRankHigherCategoryModule_basic", "GatherStaticModule_basic", - "IndexTensorStaticModule_basic", - "IndexTensorMultiIndexStaticModule_basic", "LiftFreshCopyModule_basic", "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic", "ReduceSumDimIntListFloatModule_basic", @@ -703,10 +651,6 @@ TOSA_PASS_SET = { "HardsigmoidRandomModule_basic", "HardswishModule_basic", "HardswishRandomModule_basic", - "FullLikeModuleInt2DStatic_basic", - "FullModuleInt3D_basic", - "FullModuleFloat2D_basic", - "RepeatModule_basic" } LTC_XFAIL_SET = { @@ -722,7 +666,7 @@ LTC_XFAIL_SET = { "AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic", "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", "AddIntModule_basic", - "AtenIntBoolOpModule_basic", + "BernoulliFloatModule_basic", "BernoulliTensorModule_basic", "BincountMinlengthModule_basic", "BincountModule_basic", @@ -742,7 +686,6 @@ LTC_XFAIL_SET = { "GtFloatIntModule_basic", "GtIntModule_basic", "HBC_basic", - "HardtanhBackward_basic", "IndexPut1DFloatAccumulateModule_basic", "IndexPut1DFloatNonAccumulateModule_basic", "IndexPut1DIntAccumulateModule_basic", @@ -777,8 +720,6 @@ LTC_XFAIL_SET = { "IndexPutImpl3DFloatNonAccumulateModule_basic", "IndexTensorModule3dInput_basic", "IndexTensorModule_basic", - "IndexTensorStaticModule_basic", - "IndexTensorMultiIndexStaticModule_basic", "IndexTensorMultiInputContiguousCenter_basic", "IndexTensorMultiInputNonContiguous_basic", "IndexTensorMultiInputOneDim_basic", @@ -811,8 +752,6 @@ LTC_XFAIL_SET = { "SubFloatModule_basic", "SubIntModule_basic", "TensorsConcatNegativeDimModule_basic", - "TensorsConcatPromoteDTypeModule_basic", - "TensorsStackPromoteDTypeModule_basic", "TensorToBoolZeroRank_basic", "TensorToBool_basic", "TensorToFloatZeroRank_basic", @@ -849,34 +788,4 @@ LTC_XFAIL_SET = { "ElementwisePreluModule_basic", "VarMeanBiasedModule_basic", "VarMeanUnbiasedModule_basic", - "RandnLikeModule_basic", - "RandnLikeDtypeModule_basic", - "NewEmptyStridedModuleDefaultDtype_basic", - "BernoulliFloatModule_basic", - "BernoulliModule_basic", - "BernoulliPModule_basic", - "DropoutTrainModule_basic", - "StdCorrectionKeepDimModule_basic", - "StdCorrectionNoneModule_basic", - "SliceCopy_Module_basic", - "SliceCopyNegative_Module_basic", - "VarBiasedModule_basic", - "VarCorrectionAllDimReduceModule_basic", - "VarCorrectionEmptyDimModule_basic", - "VarCorrectionKeepDimModule_basic", - "VarCorrectionLargeInputModule_basic", - "VarCorrectionModule_basic", - "VarCorrectionNoneModule_basic", - "VarCorrectionSingleDimReduceModule_basic", - "VarDimAllDimReduceModule_basic", - "VarDimBiasedModule_basic", - "VarDimEmptyDimModule_basic", - "VarDimModule_basic", - "VarDimMultiDimModule_basic", - "VarDimNegativeModule_basic", - "VarDimNoneDimModule_basic", - "VarDimSingleDimModule_basic", - "VarDimUnbiasedModule_basic", - "VarUnbiasedModule_basic", - "AtenFloatScalarModule_basic" } diff --git a/examples/torchdynamo_resnet18.py b/examples/torchdynamo_resnet18.py index d7abd80da..44d155b5d 100644 --- a/examples/torchdynamo_resnet18.py +++ b/examples/torchdynamo_resnet18.py @@ -91,4 +91,4 @@ resnet18 = models.resnet18(pretrained=True) resnet18.train(False) dynamo_callable = dynamo.optimize(refbackend_torchdynamo_backend)(resnet18) -predictions(resnet18.forward, lambda x: dynamo_callable(torch.from_numpy(x)).detach().numpy(), img, labels) +predictions(resnet18.forward, lambda x: dynamo_callable(torch.from_numpy(x)).numpy(), img, labels) diff --git a/examples/torchscript_mhlo_backend_resnet.py b/examples/torchscript_mhlo_backend_resnet.py new file mode 100644 index 000000000..bb481f6c3 --- /dev/null +++ b/examples/torchscript_mhlo_backend_resnet.py @@ -0,0 +1,14 @@ +import torch +import torchvision.models as models +import torch_mlir + +model = models.resnet18(pretrained=True) +model.eval() +data = torch.randn(2,3,200,200) +out_mhlo_mlir_path = "./resnet18_mhlo.mlir" + +module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=False) +with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf: + outf.write(str(module)) + +print(f"MHLO IR of resent18 successfully written into {out_mhlo_mlir_path}") diff --git a/examples/torchscript_stablehlo_backend_tinybert.py b/examples/torchscript_mhlo_backend_tinybert.py similarity index 69% rename from examples/torchscript_stablehlo_backend_tinybert.py rename to examples/torchscript_mhlo_backend_tinybert.py index c035be3a5..62827361e 100644 --- a/examples/torchscript_stablehlo_backend_tinybert.py +++ b/examples/torchscript_mhlo_backend_tinybert.py @@ -15,10 +15,10 @@ class BertTinyWrapper(torch.nn.Module): model = BertTinyWrapper() model.eval() data = torch.randint(30522, (2, 128)) -out_stablehlo_mlir_path = "./bert_tiny_stablehlo.mlir" +out_mhlo_mlir_path = "./bert_tiny_mhlo.mlir" -module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=True) -with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf: +module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=True) +with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf: outf.write(str(module)) -print(f"StableHLO IR of tiny bert successfully written into {out_stablehlo_mlir_path}") +print(f"MHLO IR of tiny bert successfully written into {out_mhlo_mlir_path}") diff --git a/examples/torchscript_stablehlo_backend_resnet.py b/examples/torchscript_stablehlo_backend_resnet.py deleted file mode 100644 index 7a97359cf..000000000 --- a/examples/torchscript_stablehlo_backend_resnet.py +++ /dev/null @@ -1,14 +0,0 @@ -import torch -import torchvision.models as models -import torch_mlir - -model = models.resnet18(pretrained=True) -model.eval() -data = torch.randn(2,3,200,200) -out_stablehlo_mlir_path = "./resnet18_stablehlo.mlir" - -module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=False) -with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf: - outf.write(str(module)) - -print(f"StableHLO IR of resent18 successfully written into {out_stablehlo_mlir_path}") diff --git a/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.h b/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.h index f16b436c8..7c9d884a6 100644 --- a/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.h +++ b/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.h @@ -10,7 +10,7 @@ #ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_ #define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_ -#include "mlir/IR/IRMapping.h" +#include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" diff --git a/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.td b/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.td index 8e7be05e1..1f23a190c 100644 --- a/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.td +++ b/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.td @@ -457,7 +457,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> { (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes, "ValueRange":$operands), [{ - IRMapping bvm; + BlockAndValueMapping bvm; OperationState state( loc, ConcreteOp::getOperationName(), operands, resultTypes, $_op->getAttrs()); diff --git a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp index 6ce9b502f..71473a8b1 100644 --- a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp +++ b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp @@ -204,7 +204,7 @@ LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc, } auto scfIf = b.create( - loc, cond, + loc, TypeRange{}, cond, [&](OpBuilder &b, Location loc) { if (isInclusive) { auto value = b.create(loc, input(), indices); @@ -232,7 +232,7 @@ LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc, auto &srcBlock = getRegion().front(); Region &thisRegion = scfIf.getElseRegion(); - IRMapping bvm; + BlockAndValueMapping bvm; { OpBuilder::InsertionGuard guard(b); auto &block = thisRegion.front(); @@ -266,7 +266,7 @@ static LogicalResult foldMemRefCast(Operation *op) { return success(folded); } -LogicalResult ScanOp::fold(FoldAdaptor adaptor, +LogicalResult ScanOp::fold(ArrayRef, SmallVectorImpl &) { return foldMemRefCast(*this); } @@ -461,7 +461,7 @@ LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b, Value init = b.create(loc, original(), starts); - IRMapping bvm; + BlockAndValueMapping bvm; Block &block = getRegion().front(); bvm.map(block.getArgument(0), update); bvm.map(block.getArgument(1), init); diff --git a/externals/llvm-project b/externals/llvm-project index 21f4b84c4..de3f0f7fa 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit 21f4b84c456b471cc52016cf360e14d45f7f2960 +Subproject commit de3f0f7fa0c7b902dde840913db7e773a02c4173 diff --git a/externals/mlir-hlo b/externals/mlir-hlo index b1ac0403e..2c8823d25 160000 --- a/externals/mlir-hlo +++ b/externals/mlir-hlo @@ -1 +1 @@ -Subproject commit b1ac0403ee2a40fc648ada6b9f11096f3d50fd19 +Subproject commit 2c8823d255a777d3053ef891f4dbeea1c32819f4 diff --git a/include/torch-mlir/Conversion/CMakeLists.txt b/include/torch-mlir/Conversion/CMakeLists.txt index d65523149..9ee80b304 100644 --- a/include/torch-mlir/Conversion/CMakeLists.txt +++ b/include/torch-mlir/Conversion/CMakeLists.txt @@ -1,6 +1,6 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) -if(TORCH_MLIR_ENABLE_STABLEHLO) - mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO) +if(TORCH_MLIR_ENABLE_MHLO) + mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_MHLO) else() mlir_tablegen(Passes.h.inc -gen-pass-decls) endif() diff --git a/include/torch-mlir/Conversion/Passes.td b/include/torch-mlir/Conversion/Passes.td index b5f30bfbe..7072b8d5f 100644 --- a/include/torch-mlir/Conversion/Passes.td +++ b/include/torch-mlir/Conversion/Passes.td @@ -133,13 +133,13 @@ def ConvertTorchConversionToMLProgram : Pass<"convert-torch-conversion-to-mlprog let constructor = "mlir::torch::createConvertTorchConversionToMLProgramPass()"; } -#ifdef TORCH_MLIR_ENABLE_STABLEHLO -def ConvertTorchToStablehlo : Pass<"convert-torch-to-stablehlo", "func::FuncOp"> { - let summary = "Convert Torch ops to Stablehlo ops"; +#ifdef TORCH_MLIR_ENABLE_MHLO +def ConvertTorchToMhlo : Pass<"convert-torch-to-mhlo", "func::FuncOp"> { + let summary = "Convert Torch ops to MHLO ops"; let description = [{ - Convert Torch ops to Stablehlo ops. + Convert Torch ops to mhlo ops. }]; - let constructor = "mlir::torch::createConvertTorchToStablehloPass()"; + let constructor = "mlir::torch::createConvertTorchToMhloPass()"; // Specify any options. let options = [ diff --git a/include/torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h b/include/torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h similarity index 64% rename from include/torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h rename to include/torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h index c19260159..8e2f5fc86 100644 --- a/include/torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h +++ b/include/torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h @@ -7,8 +7,8 @@ // //===----------------------------------------------------------------------===// -#ifndef TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H -#define TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H +#ifndef TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H +#define TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" @@ -16,11 +16,10 @@ namespace mlir { namespace torch { +std::unique_ptr> createConvertTorchToMhloPass(); std::unique_ptr> -createConvertTorchToStablehloPass(); -std::unique_ptr> -createConvertTorchToStablehloPass(bool enableStaticShape, bool enableI32Index); +createConvertTorchToMhloPass(bool enableStaticShape, bool enableI32Index); } // namespace torch } // namespace mlir -#endif // TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H +#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 2a43ec66c..82038cb58 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -2014,55 +2014,6 @@ def Torch_AtenClamp_Op : Torch_Op<"aten.clamp_", [ }]; } -def Torch_AtenClampTensorOp : Torch_Op<"aten.clamp.Tensor", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::clamp.Tensor : (Tensor, Tensor?, Tensor?) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchOptionalTensorType:$min, - AnyTorchOptionalTensorType:$max - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenClampTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); - } - void AtenClampTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); - } - }]; -} - -def Torch_AtenClamp_TensorOp : Torch_Op<"aten.clamp_.Tensor", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::clamp_.Tensor : (Tensor, Tensor?, Tensor?) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchOptionalTensorType:$min, - AnyTorchOptionalTensorType:$max - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenClamp_TensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); - } - void AtenClamp_TensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); - } - }]; -} - def Torch_AtenClampMinOp : Torch_Op<"aten.clamp_min", [ AllowsTypeRefinement, HasValueSemantics, @@ -3389,7 +3340,6 @@ def Torch_AtenRsubScalarOp : Torch_Op<"aten.rsub.Scalar", [ printDefaultTorchOp(printer, *this, 3, 1); } }]; - let hasCanonicalizer = 1; } def Torch_AtenGeluOp : Torch_Op<"aten.gelu", [ @@ -3688,31 +3638,6 @@ def Torch_AtenBernoulli_FloatOp : Torch_Op<"aten.bernoulli_.float", [ }]; } -def Torch_AtenBernoulliPOp : Torch_Op<"aten.bernoulli.p", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::bernoulli.p : (Tensor, float, Generator?) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - Torch_FloatType:$p, - AnyTorchOptionalGeneratorType:$generator - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenBernoulliPOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); - } - void AtenBernoulliPOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); - } - }]; -} - def Torch_AtenRandintLowOp : Torch_Op<"aten.randint.low", [ AllowsTypeRefinement, HasValueSemantics, @@ -3846,34 +3771,6 @@ def Torch_AtenRandnGeneratorOp : Torch_Op<"aten.randn.generator", [ }]; } -def Torch_AtenRandnLikeOp : Torch_Op<"aten.randn_like", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::randn_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchOptionalIntType:$dtype, - AnyTorchOptionalIntType:$layout, - AnyTorchOptionalDeviceType:$device, - AnyTorchOptionalBoolType:$pin_memory, - AnyTorchOptionalIntType:$memory_format - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenRandnLikeOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 6, 1); - } - void AtenRandnLikeOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 6, 1); - } - }]; -} - def Torch_AtenTriuOp : Torch_Op<"aten.triu", [ AllowsTypeRefinement, HasValueSemantics, @@ -5249,11 +5146,11 @@ def Torch_AtenStdCorrectionOp : Torch_Op<"aten.std.correction", [ HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::std.correction : (Tensor, int[]?, Scalar?, bool) -> (Tensor)`"; + let summary = "Generated op for `aten::std.correction : (Tensor, int[]?, int?, bool) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, AnyTorchOptionalListOfTorchIntType:$dim, - AnyTorchOptionalScalarType:$correction, + AnyTorchOptionalIntType:$correction, Torch_BoolType:$keepdim ); let results = (outs @@ -5325,11 +5222,11 @@ def Torch_AtenVarCorrectionOp : Torch_Op<"aten.var.correction", [ HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::var.correction : (Tensor, int[]?, Scalar?, bool) -> (Tensor)`"; + let summary = "Generated op for `aten::var.correction : (Tensor, int[]?, int?, bool) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, AnyTorchOptionalListOfTorchIntType:$dim, - AnyTorchOptionalScalarType:$correction, + AnyTorchOptionalIntType:$correction, Torch_BoolType:$keepdim ); let results = (outs @@ -5351,11 +5248,11 @@ def Torch_AtenVarMeanCorrectionOp : Torch_Op<"aten.var_mean.correction", [ HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::var_mean.correction : (Tensor, int[]?, Scalar?, bool) -> (Tensor, Tensor)`"; + let summary = "Generated op for `aten::var_mean.correction : (Tensor, int[]?, int?, bool) -> (Tensor, Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, AnyTorchOptionalListOfTorchIntType:$dim, - AnyTorchOptionalScalarType:$correction, + AnyTorchOptionalIntType:$correction, Torch_BoolType:$keepdim ); let results = (outs @@ -6585,35 +6482,6 @@ def Torch_AtenNewEmptyOp : Torch_Op<"aten.new_empty", [ }]; } -def Torch_AtenNewEmptyStridedOp : Torch_Op<"aten.new_empty_strided", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::new_empty_strided : (Tensor, int[], int[], int?, int?, Device?, bool?) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchListOfTorchIntType:$size, - AnyTorchListOfTorchIntType:$stride, - AnyTorchOptionalIntType:$dtype, - AnyTorchOptionalIntType:$layout, - AnyTorchOptionalDeviceType:$device, - AnyTorchOptionalBoolType:$pin_memory - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenNewEmptyStridedOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 7, 1); - } - void AtenNewEmptyStridedOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 7, 1); - } - }]; -} - def Torch_AtenZerosLikeOp : Torch_Op<"aten.zeros_like", [ AllowsTypeRefinement, HasValueSemantics, @@ -7107,6 +6975,30 @@ def Torch_AtenSizeIntOp : Torch_Op<"aten.size.int", [ let hasFolder = 1; } +def Torch_AtenStackOp : Torch_Op<"aten.stack", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::stack : (Tensor[], int) -> (Tensor)`"; + let arguments = (ins + AnyTorchListOfTensorType:$tensors, + Torch_IntType:$dim + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenStackOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenStackOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenSumOp : Torch_Op<"aten.sum", [ AllowsTypeRefinement, HasValueSemantics, @@ -7664,86 +7556,6 @@ def Torch_AtenScatterAddOp : Torch_Op<"aten.scatter_add", [ }]; } -def Torch_AtenScatterAdd_Op : Torch_Op<"aten.scatter_add_", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::scatter_add_ : (Tensor, int, Tensor, Tensor) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - Torch_IntType:$dim, - AnyTorchTensorType:$index, - AnyTorchTensorType:$src - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenScatterAdd_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 4, 1); - } - void AtenScatterAdd_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 4, 1); - } - }]; -} - -def Torch_AtenScatterReduceTwoOp : Torch_Op<"aten.scatter_reduce.two", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::scatter_reduce.two : (Tensor, int, Tensor, Tensor, str, bool) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - Torch_IntType:$dim, - AnyTorchTensorType:$index, - AnyTorchTensorType:$src, - Torch_StringType:$reduce, - Torch_BoolType:$include_self - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenScatterReduceTwoOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 6, 1); - } - void AtenScatterReduceTwoOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 6, 1); - } - }]; -} - -def Torch_AtenScatterReduce_TwoOp : Torch_Op<"aten.scatter_reduce_.two", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::scatter_reduce_.two : (Tensor, int, Tensor, Tensor, str, bool) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - Torch_IntType:$dim, - AnyTorchTensorType:$index, - AnyTorchTensorType:$src, - Torch_StringType:$reduce, - Torch_BoolType:$include_self - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenScatterReduce_TwoOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 6, 1); - } - void AtenScatterReduce_TwoOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 6, 1); - } - }]; -} - def Torch_AtenIntImplicitOp : Torch_Op<"aten.IntImplicit", [ AllowsTypeRefinement, HasValueSemantics, @@ -8858,31 +8670,6 @@ def Torch_AtenCatOp : Torch_Op<"aten.cat", [ let hasFolder = 1; } -def Torch_AtenStackOp : Torch_Op<"aten.stack", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::stack : (Tensor[], int) -> (Tensor)`"; - let arguments = (ins - AnyTorchListOfTensorType:$tensors, - Torch_IntType:$dim - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenStackOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); - } - void AtenStackOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); - } - }]; - let hasFolder = 1; -} - def Torch_AtenAppendTOp : Torch_Op<"aten.append.t", [ AllowsTypeRefinement ]> { @@ -9298,7 +9085,6 @@ def Torch_AtenIntFloatOp : Torch_Op<"aten.Int.float", [ printDefaultTorchOp(printer, *this, 1, 1); } }]; - let hasFolder = 1; } def Torch_AtenIntScalarOp : Torch_Op<"aten.Int.Scalar", [ @@ -9325,30 +9111,6 @@ def Torch_AtenIntScalarOp : Torch_Op<"aten.Int.Scalar", [ let hasFolder = 1; } -def Torch_AtenIntBoolOp : Torch_Op<"aten.Int.bool", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::Int.bool : (bool) -> (int)`"; - let arguments = (ins - Torch_BoolType:$a - ); - let results = (outs - Torch_IntType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenIntBoolOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); - } - void AtenIntBoolOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); - } - }]; - let hasFolder = 1; -} - def Torch_Aten__RangeLengthOp : Torch_Op<"aten.__range_length", [ AllowsTypeRefinement, HasValueSemantics, @@ -9818,7 +9580,6 @@ def Torch_AtenSubFloatOp : Torch_Op<"aten.sub.float", [ printDefaultTorchOp(printer, *this, 2, 1); } }]; - let hasFolder = 1; } def Torch_AtenMulFloatOp : Torch_Op<"aten.mul.float", [ @@ -10089,31 +9850,6 @@ def Torch_AtenGtFloatIntOp : Torch_Op<"aten.gt.float_int", [ }]; } -def Torch_AtenPowIntFloatOp : Torch_Op<"aten.pow.int_float", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::pow.int_float : (int, float) -> (float)`"; - let arguments = (ins - Torch_IntType:$a, - Torch_FloatType:$b - ); - let results = (outs - Torch_FloatType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenPowIntFloatOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); - } - void AtenPowIntFloatOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); - } - }]; - let hasFolder = 1; -} - def Torch_Aten__And__BoolOp : Torch_Op<"aten.__and__.bool", [ AllowsTypeRefinement, HasValueSemantics, @@ -10571,7 +10307,6 @@ def Torch_AtenScalarImplicitOp : Torch_Op<"aten.ScalarImplicit", [ printDefaultTorchOp(printer, *this, 1, 1); } }]; - let hasCanonicalizer = 1; } def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [ @@ -10624,32 +10359,6 @@ def Torch_AtenTanhBackwardOp : Torch_Op<"aten.tanh_backward", [ }]; } -def Torch_AtenHardtanhBackwardOp : Torch_Op<"aten.hardtanh_backward", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::hardtanh_backward : (Tensor, Tensor, Scalar, Scalar) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$grad_output, - AnyTorchTensorType:$self, - AnyTorchScalarType:$min_val, - AnyTorchScalarType:$max_val - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenHardtanhBackwardOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 4, 1); - } - void AtenHardtanhBackwardOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 4, 1); - } - }]; -} - def Torch_AtenGeluBackwardOp : Torch_Op<"aten.gelu_backward", [ AllowsTypeRefinement, HasValueSemantics, @@ -11024,7 +10733,6 @@ def Torch_PrimMinIntOp : Torch_Op<"prim.min.int", [ printDefaultTorchOp(printer, *this, 2, 1); } }]; - let hasFolder = 1; } def Torch_PrimMaxSelfIntOp : Torch_Op<"prim.max.self_int", [ @@ -11225,11 +10933,11 @@ def Torch_PrimsVarOp : Torch_Op<"prims.var", [ HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `prims::var : (Tensor, int[]?, float, int?) -> (Tensor)`"; + let summary = "Generated op for `prims::var : (Tensor, int[]?, int, int?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$inp, AnyTorchOptionalListOfTorchIntType:$dims, - Torch_FloatType:$correction, + Torch_IntType:$correction, AnyTorchOptionalIntType:$output_dtype ); let results = (outs diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td index 889a29908..d4c1a5a3f 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td @@ -376,6 +376,9 @@ def Torch_PrimListUnpackOp: Torch_Op<"prim.ListUnpack", [ def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [ Pure, + TypesMatchWith<"contained types correspond to operand types", + "elements", "result", "Torch::TupleType::get($_ctxt, llvm::to_vector<6>($_self))", + "isValidSubtype">, AllowedInModuleInitializer, ]> { let summary = "TorchScript prim::TupleConstruct op"; @@ -394,8 +397,6 @@ def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [ let assemblyFormat = [{ $elements attr-dict `:` qualified(type($elements)) `->` qualified(type($result)) }]; - - let hasVerifier = 1; } def Torch_PrimListConstructOp: Torch_Op<"prim.ListConstruct", [ diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h index 4cf27639a..930e6fac1 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h @@ -98,8 +98,6 @@ std::unique_ptr> createRefinePublicReturnPass(); std::unique_ptr> createDecomposeComplexOpsPass(ArrayRef legalOps); -std::unique_ptr> createRecomposeComplexOps(); - std::unique_ptr> createPreprocessShapeLibraryPass(); std::unique_ptr> createReifyShapeCalculationsPass(); @@ -123,7 +121,8 @@ createLowerToBackendContractPass(int maxIterations, bool decompose, ArrayRef backendLegalOps); std::unique_ptr> -createVerifyBackendContractNoDecompositionsPass(); +createVerifyBackendContractPass(bool decompose, + ArrayRef backendLegalOps); StringRef getAbstractInterpLibrary(); diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td index 1ee87b36e..5dcf2286b 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td @@ -343,17 +343,24 @@ def LowerToBackendContract let dependentDialects = ["func::FuncDialect"]; } -def VerifyBackendContractNoDecompositions - : Pass<"torch-verify-backend-contract-no-decompositions", "ModuleOp"> { +def VerifyBackendContract + : Pass<"torch-verify-backend-contract", "ModuleOp"> { let summary = "Check that program satisfies backend contract."; let constructor = [{ - mlir::torch::Torch::createVerifyBackendContractNoDecompositionsPass() + mlir::torch::Torch::createVerifyBackendContractPass( + /*decompose=*/true, /*backendLegalOps=*/{}) }]; let description = [{ This pass performs a set of inspections to check that program satisfies backend - contract assuming that no decompositions were applied. In case of check failure - it prints out the error message and returns `signalPassFailure()` status. + contract. In case of check failure it prints out the error message and returns + `signalPassFailure()` status. }]; + let options = [ + Option<"decompose", "decompose", "bool", /*default=*/"true", + "Decompose ops.">, + ListOption<"backendLegalOps", "backend-legal-ops", "std::string", + "List of ops to be considered legal for the backend."> + ]; } #endif // TORCHMLIR_TORCH_PASSES diff --git a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h index efb114fbf..d6bc0a699 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h +++ b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h @@ -9,7 +9,6 @@ #define TORCHMLIR_DIALECT_TORCH_UPSTREAM_H #include "mlir/Support/LLVM.h" -#include "llvm/ADT/StringRef.h" // For layering reasons, the parts of the core MLIR compiler code written in C++ // never take a C++ dependency on Torch itself (any code depending on Torch C++ @@ -161,15 +160,6 @@ enum Layout { Strided, Sparse, SparseCsr, Mkldnn, NumOptions }; //===-----------------------------------------------------------------------===// enum EmbeddingBagMode { MODE_SUM, MODE_MEAN, MODE_MAX }; -//===----------------------------------------------------------------------===// -// Possible value for `reduce` argument for Scatter reduce ops. -// Source: -// https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h -//===-----------------------------------------------------------------------===// -enum ReductionType { MAX, MEAN, MIN, SUM, PROD }; - -ReductionType get_reduction_enum(const llvm::StringRef &reduce); - } // namespace torch_upstream } // namespace torch } // namespace mlir diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h index 4e3f3cecc..a5cbcf52c 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h +++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h @@ -26,7 +26,7 @@ bool getListConstructElements(Value v, SmallVectorImpl &elems); std::optional matchLegalConstantIndexIntoListOfSize(Value v, int64_t length); torch_upstream::ScalarType getScalarTypeForType(Type type); -FailureOr getTypeForScalarType( +Type getTypeForScalarType( MLIRContext *context, torch_upstream::ScalarType dtypeInt, mlir::IntegerType::SignednessSemantics signedness = IntegerType::Signed); diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/CMakeLists.txt b/include/torch-mlir/Dialect/TorchConversion/Transforms/CMakeLists.txt index 77e46eb4b..00818899f 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/CMakeLists.txt +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/CMakeLists.txt @@ -1,6 +1,6 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) -if(TORCH_MLIR_ENABLE_STABLEHLO) - mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO) +if(TORCH_MLIR_ENABLE_MHLO) + mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_MHLO) else() mlir_tablegen(Passes.h.inc -gen-pass-decls) endif() diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h index e6493a154..fd350da1d 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h @@ -30,10 +30,10 @@ void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm); /// TOSA backend contract. void createTorchBackendToTosaBackendPipeline(OpPassManager &pm); -// Do not register the stablehlo options if the stablehlo target is disabled -#ifdef TORCH_MLIR_ENABLE_STABLEHLO -struct StablehloBackendPipelineOptions - : public PassPipelineOptions { +// Do not register the torch-to-mhlo pipeline if mhlo target is disabled +#ifdef TORCH_MLIR_ENABLE_MHLO +struct MhloBackendPipelineOptions + : public PassPipelineOptions { Option enableStaticShape{ *this, "enable-static-shape", llvm::cl::desc("Enable static shape conversion."), llvm::cl::init(false)}; @@ -46,10 +46,9 @@ struct StablehloBackendPipelineOptions llvm::cl::init(false)}; }; -void createTorchBackendToStablehloBackendPipeline( - OpPassManager &pm, const StablehloBackendPipelineOptions &options); -std::unique_ptr> -createVerifyStablehloBackendContractPass(); +void createTorchBackendToMhloBackendPipeline( + OpPassManager &pm, const MhloBackendPipelineOptions &options); +std::unique_ptr> createVerifyMhloBackendContractPass(); #endif std::unique_ptr> createFuncBackendTypeConversionPass(); diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td index cb58dbbd9..4ce7cdadb 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td @@ -42,10 +42,10 @@ def VerifyTosaBackendContract : Pass<"torch-verify-tosa-backend-contract", "Modu let constructor = "mlir::torch::TorchConversion::createVerifyTosaBackendContractPass()"; } -#ifdef TORCH_MLIR_ENABLE_STABLEHLO -def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contract", "ModuleOp"> { - let summary = "Verifies conformity to the stablehlo backend contract"; - let constructor = "mlir::torch::TorchConversion::createVerifyStablehloBackendContractPass()"; +#ifdef TORCH_MLIR_ENABLE_MHLO +def VerifyMhloBackendContract : Pass<"torch-verify-mhlo-backend-contract", "ModuleOp"> { + let summary = "Verifies conformity to the mhlo backend contract"; + let constructor = "mlir::torch::TorchConversion::createVerifyMhloBackendContractPass()"; } -#endif // TORCH_MLIR_ENABLE_STABLEHLO +#endif // TORCH_MLIR_ENABLE_MHLO #endif // TORCHMLIR_TORCHCONVERSION_PASSES diff --git a/lib/CAPI/TorchTypes.cpp b/lib/CAPI/TorchTypes.cpp index 04384b98b..7609d89b4 100644 --- a/lib/CAPI/TorchTypes.cpp +++ b/lib/CAPI/TorchTypes.cpp @@ -61,7 +61,7 @@ MlirType torchMlirTorchTupleTypeGet(MlirContext context, return wrap(Torch::TupleType::get( unwrap(context), llvm::to_vector<6>( - llvm::map_range(llvm::ArrayRef(containedTypes, numContainedTypes), + llvm::map_range(llvm::makeArrayRef(containedTypes, numContainedTypes), [](MlirType t) { return unwrap(t); })))); } @@ -89,7 +89,7 @@ MlirType torchMlirTorchUnionTypeGet(MlirContext context, return wrap(Torch::UnionType::get( unwrap(context), llvm::to_vector<6>( - llvm::map_range(llvm::ArrayRef(containedTypes, numContainedTypes), + llvm::map_range(llvm::makeArrayRef(containedTypes, numContainedTypes), [](MlirType t) { return unwrap(t); })))); } @@ -230,7 +230,7 @@ MlirType torchMlirTorchNonValueTensorTypeGet(MlirContext context, std::optional> optionalSizesArrayRef = std::nullopt; // if numSizes == -1, then it is unranked. if (numSizes > -1) - optionalSizesArrayRef = llvm::ArrayRef(optionalSizes, numSizes); + optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes); return wrap(Torch::NonValueTensorType::get( unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype))); } @@ -293,7 +293,7 @@ MlirType torchMlirTorchValueTensorTypeGet(MlirContext context, std::optional> optionalSizesArrayRef = std::nullopt; // if numSizes == -1, then it is unranked. if (numSizes > -1) - optionalSizesArrayRef = llvm::ArrayRef(optionalSizes, numSizes); + optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes); return wrap(Torch::ValueTensorType::get( unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype))); } diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 4c37cca5e..ec6ee8cee 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -3,7 +3,13 @@ add_subdirectory(Conversion) add_subdirectory(Dialect) add_subdirectory(RefBackend) -set(LinkedLibs +add_mlir_library(TorchMLIRInitAll + InitAll.cpp + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC MLIRFuncDialect MLIRIR MLIRSupport @@ -21,22 +27,4 @@ set(LinkedLibs TorchMLIRRefBackend ) -if(TORCH_MLIR_ENABLE_STABLEHLO) - list(APPEND LinkedLibs - MhloPasses - MhloToLinalg - StablehloToMhlo - ) -endif() - -add_mlir_library(TorchMLIRInitAll - InitAll.cpp - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - ${LinkedLibs} -) - torch_mlir_target_includes(TorchMLIRInitAll) diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index d72563b1e..29812d1fe 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -2,8 +2,8 @@ add_subdirectory(TorchToLinalg) add_subdirectory(TorchToSCF) add_subdirectory(TorchToArith) add_subdirectory(TorchToTosa) -if(TORCH_MLIR_ENABLE_STABLEHLO) - add_subdirectory(TorchToStablehlo) +if(TORCH_MLIR_ENABLE_MHLO) + add_subdirectory(TorchToMhlo) endif() add_subdirectory(TorchToTMTensor) add_subdirectory(TorchConversionToMLProgram) @@ -17,8 +17,10 @@ set(linked_libs TorchMLIRTorchToLinalg TorchMLIRTorchToTMTensor TorchMLIRTorchConversionToMLProgram TorchMLIRConversionUtils) -if(TORCH_MLIR_ENABLE_STABLEHLO) - list(APPEND linked_libs TorchMLIRTorchToStablehlo) +if(TORCH_MLIR_ENABLE_MHLO) + list(APPEND linked_libs + MhloPasses + TorchMLIRTorchToMhlo) endif() add_mlir_library(TorchMLIRConversionPasses diff --git a/lib/Conversion/Passes.cpp b/lib/Conversion/Passes.cpp index 45714601d..f07a3afb3 100644 --- a/lib/Conversion/Passes.cpp +++ b/lib/Conversion/Passes.cpp @@ -9,15 +9,15 @@ #include "torch-mlir/Conversion/Passes.h" -#ifdef TORCH_MLIR_ENABLE_STABLEHLO -#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" +#ifdef TORCH_MLIR_ENABLE_MHLO +#include "mhlo/transforms/passes.h" #include "transforms/passes.h" -#endif // TORCH_MLIR_ENABLE_STABLEHLO - +#endif // TORCH_MLIR_ENABLE_MHLO #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" #include "torch-mlir/Conversion/TorchToArith/TorchToArith.h" #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" +#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" #include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" @@ -32,4 +32,12 @@ namespace { void mlir::torch::registerConversionPasses() { ::registerPasses(); +#ifdef TORCH_MLIR_ENABLE_MHLO + ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { + return mlir::mhlo::createLegalizeHloToLinalgPass(); + }); + ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { + return mlir::mhlo::createSymbolicShapeOptimizationPass(); + }); +#endif // TORCH_MLIR_ENABLE_MHLO } diff --git a/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp b/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp index 839bae364..60c126b06 100644 --- a/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp +++ b/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp @@ -68,7 +68,7 @@ public: // temp = multiplier * currentSeed + incrementStep Value mul = rewriter.create(loc, currentSeed, multiplier); Value seed = rewriter.create(loc, mul, incrementStep); - globalVar = rewriter.create(loc, seed, globalVar, ValueRange()); + globalVar = rewriter.create(loc, seed, globalVar); rewriter.create( loc, SymbolRefAttr::get(op->getContext(), getSeedGobalVarName()), globalVar); diff --git a/lib/Conversion/TorchToArith/TorchToArith.cpp b/lib/Conversion/TorchToArith/TorchToArith.cpp index 1f921dcaa..4ed1843f2 100644 --- a/lib/Conversion/TorchToArith/TorchToArith.cpp +++ b/lib/Conversion/TorchToArith/TorchToArith.cpp @@ -232,67 +232,6 @@ public: return success(); } }; - -class ConvertTorchConstantIntOp - : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - using OpAdaptor = Torch::ConstantIntOp::Adaptor; - LogicalResult - matchAndRewrite(Torch::ConstantIntOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // note: arith.constant only accept singless integer, so convert singed to - // singless - rewriter.replaceOpWithNewOp( - op, rewriter.getIntegerAttr(rewriter.getI64Type(), - op.getValueAttr().getValue())); - return success(); - } -}; -} // namespace - -namespace { -class ConvertAtenFloatScalarOp : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(AtenFloatScalarOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type resultType = - this->getTypeConverter()->convertType(op->getResult(0).getType()); - Value result = - convertScalarToDtype(rewriter, op.getLoc(), adaptor.getA(), resultType); - rewriter.replaceOp(op, result); - return success(); - } -}; -} // namespace - -namespace { -class ConvertAtenAddOp : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(AtenAddOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Type resultType = - this->getTypeConverter()->convertType(op->getResult(0).getType()); - Value operandA = - convertScalarToDtype(rewriter, loc, adaptor.getA(), resultType); - Value operandB = - convertScalarToDtype(rewriter, loc, adaptor.getB(), resultType); - if (resultType.isa()) { - rewriter.replaceOpWithNewOp(op, operandA, operandB); - } else if (resultType.isa()) { - rewriter.replaceOpWithNewOp(op, operandA, operandB); - } else { - return rewriter.notifyMatchFailure( - op, "unimplemented: only support integer or float result type"); - } - return success(); - } -}; } // namespace namespace { @@ -442,14 +381,8 @@ public: patterns.add>(typeConverter, context); target.addIllegalOp(); - patterns.add(typeConverter, context); - - target.addIllegalOp(); - patterns.add(typeConverter, context); - - target.addIllegalOp(); - patterns.add(typeConverter, context); - + patterns.add>(typeConverter, + context); target.addIllegalOp(); patterns.add>( typeConverter, context); diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 293649de5..6ace4926d 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -463,8 +463,8 @@ public: } SmallVector inputSize = getTensorSizes(rewriter, loc, input); - ArrayRef outputShapeInt = llvm::ArrayRef(outputSizeInt); - ArrayRef inputShapeInt = llvm::ArrayRef(inputSize); + ArrayRef outputShapeInt = llvm::makeArrayRef(outputSizeInt); + ArrayRef inputShapeInt = llvm::makeArrayRef(inputSize); // Association indices for expand/collapse ops. These two vectors // are populated such that two entries at the same index corresponds @@ -1117,18 +1117,6 @@ public: RankedTensorType newResultType = typeConverter->convertType(op.getType()).cast(); - - auto outElemType = newResultType.getElementType(); - auto dtypePromoteBody = [&](OpBuilder &builder, Location loc, - ValueRange payloadArgs) { - Value elem = convertScalarToDtype(builder, loc, payloadArgs[0], outElemType); - builder.create(loc, elem); - }; - for (size_t i = 0; i < tensors.size(); ++i) { - tensors[i] = torch_to_linalg::createElementwiseLinalgGeneric( - rewriter, loc, {tensors[i]}, outElemType, dtypePromoteBody); - } - int rank = newResultType.getRank(); SmallVector offsets, sizes, strides; sizes.reserve(rank); @@ -1148,7 +1136,7 @@ public: Value dimIndex = rewriter.createOrFold( loc, rewriter.getIndexAttr(dim)); - for (auto tensor : ArrayRef(tensors).drop_front()) { + for (auto tensor : makeArrayRef(tensors).drop_front()) { auto size = rewriter.createOrFold(loc, tensor, dimIndex); resultDimSize = rewriter.createOrFold(loc, resultDimSize, size); @@ -1282,7 +1270,7 @@ public: /*resultType=*/selfType, /*inputs=*/broadcastedSrc, /*outputs=*/self, - /*indexingMaps=*/llvm::ArrayRef({id, id}), + /*indexingMaps=*/llvm::makeArrayRef({id, id}), /*iteratorTypes=*/iteratorTypes, [](OpBuilder &b, Location loc, ValueRange args) { Value result = args[0]; diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp index da308ce53..1a48a0023 100644 --- a/lib/Conversion/TorchToLinalg/Reduction.cpp +++ b/lib/Conversion/TorchToLinalg/Reduction.cpp @@ -81,21 +81,9 @@ public: Type inElementType = inputType.getElementType(); if (!inElementType.isa()) { - if (inElementType.isa()) { - auto integerTy = maxDimOp.getSelf() - .getType() - .cast() - .getDtype() - .dyn_cast(); - if (integerTy.isUnsigned()) - return rewriter.notifyMatchFailure( - maxDimOp, "aten.max_dim to linalg.* requires input element type " - "to be signed in case of integer"); - } else { - return rewriter.notifyMatchFailure( - maxDimOp, "aten.max_dim to linalg.* requires Float or Integer " - "input element type"); - } + return rewriter.notifyMatchFailure( + maxDimOp, + "aten.max_dim to linalg.* requires Float input element type"); } // Constant op to account for the reduction along dim. @@ -116,23 +104,13 @@ public: Value initTensorMax = rewriter.create( loc, getAsOpFoldResult(resultShape), inElementType); - Value fillValueMax; - if (inElementType.isa()) { - fillValueMax = rewriter.create( - loc, - rewriter.getFloatAttr( - inElementType, - APFloat::getLargest( - inElementType.cast().getFloatSemantics(), - true))); - } else { - fillValueMax = rewriter.create( - loc, rewriter.getIntegerAttr( - inElementType, - APSInt::getSignedMinValue( - inElementType.cast().getWidth()))); - } + FloatAttr fillValueMaxAttr = rewriter.getFloatAttr( + inElementType, + APFloat::getLargest( + inElementType.cast().getFloatSemantics(), true)); + Value fillValueMax = + rewriter.create(loc, fillValueMaxAttr); Value filledTensorMax = rewriter.create(loc, fillValueMax, initTensorMax) .result(); @@ -174,18 +152,10 @@ public: nestedLoc, oldIndex.getType(), rewriter.create(loc, dim)); - Value resultMax, predicate; - if (inElementType.isa()) { - resultMax = - rewriter.create(nestedLoc, newValue, oldValue); - predicate = rewriter.create( - nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue); - } else { - resultMax = - rewriter.create(nestedLoc, newValue, oldValue); - predicate = rewriter.create( - nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue); - } + auto resultMax = rewriter.create( + nestedLoc, newValue, oldValue); + Value predicate = rewriter.create( + nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue); auto resultIndex = rewriter.create( nestedLoc, predicate, newIndex, oldIndex); nestedBuilder.create( diff --git a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp index e861a1877..42ec8657e 100644 --- a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp +++ b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp @@ -127,14 +127,9 @@ public: if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt))) return rewriter.notifyMatchFailure( op, "unimplemented: dtype must be a constant integer or none"); - FailureOr maybeResultElementType = getTypeForScalarType( + resultElementType = getTypeForScalarType( op->getContext(), (torch_upstream::ScalarType)dtypeInt, IntegerType::Signless); - if (failed(maybeResultElementType)) { - return rewriter.notifyMatchFailure( - op, "unable to convert `dtypeInt` to builtin type"); - } - resultElementType = *maybeResultElementType; } // Create an uninitialized tensor of `resultSize` shape and fill it with @@ -232,14 +227,9 @@ public: if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt))) return rewriter.notifyMatchFailure( op, "unimplemented: dtype must be a constant integer or none"); - FailureOr maybeResultElementType = getTypeForScalarType( + resultElementType = getTypeForScalarType( op->getContext(), (torch_upstream::ScalarType)dtypeInt, IntegerType::Signless); - if (failed(maybeResultElementType)) { - return rewriter.notifyMatchFailure( - op, "unable to convert `dtypeInt` to builtin type"); - } - resultElementType = *maybeResultElementType; } // Create an uninitialized tensor of `resultSize` shape. diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 49730d5bf..bc16c8c1e 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -59,15 +59,6 @@ static Value createGreaterThan(OpBuilder &b, Location loc, Type elementalType, b, loc, elementalType, lhs, rhs); } -static Value createGreaterThanOrEqual(OpBuilder &b, Location loc, - Type elementalType, Value lhs, - Value rhs) { - return createComparisonTemplate( - b, loc, elementalType, lhs, rhs); -} - static Value createLessThan(OpBuilder &b, Location loc, Type elementalType, Value lhs, Value rhs) { return createComparisonTemplate( - b, loc, elementalType, lhs, rhs); -} - static Value createEqual(OpBuilder &b, Location loc, Type elementalType, Value lhs, Value rhs) { return createComparisonTemplate(loc, arg); } -template -static Value createCompareTensorOp(OpBuilder &b, Location loc, OpTy op, - Value lhs, Value rhs) { - static_assert(std::is_same() || - std::is_same() || - std::is_same() || - std::is_same() || - std::is_same(), - "unimplemented: op type not supported"); - - Type lhsDtype = lhs.getType(); - Type rhsDtype = rhs.getType(); - - // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs - // to be handled. - if (lhsDtype != rhsDtype) { - op.emitError("unimplemented: lhs and rhs dtype must be same"); - return nullptr; - } - - Type elementalType = - op.getSelf().getType().template cast().getDtype(); - if constexpr (std::is_same()) { - return createLessThan(b, loc, elementalType, lhs, rhs); - } - if constexpr (std::is_same()) { - return createLessThanOrEqual(b, loc, elementalType, lhs, rhs); - } - if constexpr (std::is_same()) { - return createGreaterThan(b, loc, elementalType, lhs, rhs); - } - if constexpr (std::is_same()) { - return createGreaterThanOrEqual(b, loc, elementalType, lhs, rhs); - } - if constexpr (std::is_same()) { - return createEqual(b, loc, elementalType, lhs, rhs); - } - llvm_unreachable("unimplemented: op type not supported"); -} - static Value createLinalgPayloadCalculationForElementwiseOp( OpBuilder &b, Location loc, TypeConverter *converter, ValueRange payloadArgs, Operation *op, ArrayRef operands) { @@ -234,10 +177,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( if (!clone.getMemoryFormat().getType().isa() && (!matchPattern(clone.getMemoryFormat(), m_TorchConstantInt(&memoryFormat)) || - (memoryFormat != torch_upstream::MemoryFormat::Contiguous && - memoryFormat != torch_upstream::MemoryFormat::ChannelsLast))) { - clone.emitError("unimplemented: only contiguous and channels last memory " - "format is supported"); + memoryFormat != torch_upstream::MemoryFormat::Contiguous)) { + clone.emitError("unimplemented: only default memory format is supported"); return nullptr; } return payloadArgs[0]; @@ -352,7 +293,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( round.emitError("unimplemented: non-floating point dtype"); return nullptr; } - return b.create(loc, payloadArgs[0]); + return b.create(loc, payloadArgs[0]); } if (auto prelu = dyn_cast(op)) { if (!prelu.getType() @@ -429,29 +370,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value cdfExt = b.create(loc, dinputInputAlpha, cdf); return b.create(loc, payloadArgs[0], cdfExt); } - if (auto hardtanhBackward = dyn_cast(op)) { - AtenHardtanhBackwardOp::Adaptor adaptor(operands); - if (!hardtanhBackward.getType() - .cast() - .getDtype() - .isa()) { - hardtanhBackward.emitError("unimplemented: non-floating point dtype"); - return nullptr; - } - Value gradOutput = payloadArgs[0]; - Type elementType = gradOutput.getType(); - Value self = convertScalarToDtype(b, loc, payloadArgs[1], elementType); - Value constantZero = - b.create(loc, FloatAttr::get(elementType, 0.0)); - Value min = convertScalarToDtype(b, loc, adaptor.getMinVal(), elementType); - Value max = convertScalarToDtype(b, loc, adaptor.getMaxVal(), elementType); - Value lesser = - b.create(loc, arith::CmpFPredicate::ULT, self, min); - Value greater = - b.create(loc, arith::CmpFPredicate::UGT, self, max); - Value cmp = b.create(loc, lesser, greater); - return b.create(loc, cmp, constantZero, gradOutput); - } if (auto add = dyn_cast(op)) { AtenAddTensorOp::Adaptor adaptor(operands); Type dtype = converter->convertType(add.getType()) @@ -545,25 +463,64 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); return b.create(loc, lhs, rhs); } - if (auto ltTensor = dyn_cast(op)) { - return createCompareTensorOp(b, loc, ltTensor, payloadArgs[0], - payloadArgs[1]); - } - if (auto leTensor = dyn_cast(op)) { - return createCompareTensorOp(b, loc, leTensor, payloadArgs[0], - payloadArgs[1]); - } if (auto gtTensor = dyn_cast(op)) { - return createCompareTensorOp(b, loc, gtTensor, payloadArgs[0], - payloadArgs[1]); - } - if (auto geTensor = dyn_cast(op)) { - return createCompareTensorOp(b, loc, geTensor, payloadArgs[0], - payloadArgs[1]); + AtenGtTensorOp::Adaptor adaptor(operands); + Type lhsDtype = payloadArgs[0].getType(); + Type rhsDtype = payloadArgs[1].getType(); + + // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs + // to be handled. + if (lhsDtype != rhsDtype) { + gtTensor.emitError("unimplemented: different lhs and rhs dtype"); + return nullptr; + } + + Type elementalType = + gtTensor.getSelf().getType().cast().getDtype(); + return createGreaterThan(b, loc, elementalType, payloadArgs[0], + payloadArgs[1]); } if (auto eqTensor = dyn_cast(op)) { - return createCompareTensorOp(b, loc, eqTensor, payloadArgs[0], - payloadArgs[1]); + AtenEqTensorOp::Adaptor adaptor(operands); + Type lhsDtype = payloadArgs[0].getType(); + Type rhsDtype = payloadArgs[1].getType(); + + // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs + // to be handled. + if (lhsDtype != rhsDtype) { + eqTensor.emitError("unimplemented: lhs and rhs dtype must be same"); + return nullptr; + } + + Type elementalType = + eqTensor.getSelf().getType().cast().getDtype(); + + if (elementalType.isa()) + return b.create(loc, arith::CmpFPredicate::UEQ, + payloadArgs[0], payloadArgs[1]); + if (elementalType.isa()) { + return b.create(loc, arith::CmpIPredicate::eq, + payloadArgs[0], payloadArgs[1]); + } + eqTensor.emitError("unimplemented: dtype isn't supported."); + return nullptr; + } + if (auto ltTensor = dyn_cast(op)) { + AtenLtTensorOp::Adaptor adaptor(operands); + Type lhsDtype = payloadArgs[0].getType(); + Type rhsDtype = payloadArgs[1].getType(); + + // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs + // to be handled. + if (lhsDtype != rhsDtype) { + ltTensor.emitError("unimplemented: lhs and rhs dtype must be same"); + return nullptr; + } + + Type elementalType = + ltTensor.getSelf().getType().cast().getDtype(); + return createLessThan(b, loc, elementalType, payloadArgs[0], + payloadArgs[1]); } if (auto div = dyn_cast(op)) { AtenDivTensorOp::Adaptor adaptor(operands); @@ -1007,6 +964,18 @@ static Value createLinalgPayloadCalculationForElementwiseOp( .getElementType(); return convertScalarToDtype(b, loc, adaptor.getValue(), dtype); } + if (auto maskedFillScalar = dyn_cast(op)) { + AtenMaskedFillScalarOp::Adaptor adaptor(operands); + Type dtype = converter->convertType(maskedFillScalar.getType()) + .cast() + .getElementType(); + + Value input = payloadArgs[0]; + Value mask = payloadArgs[1]; + Value fillValue = convertScalarToDtype(b, loc, adaptor.getValue(), dtype); + + return b.create(loc, mask, fillValue, input); + } if (auto maskedFillTensor = dyn_cast(op)) { AtenMaskedFillScalarOp::Adaptor adaptor(operands); Type dtype = converter->convertType(maskedFillTensor.getType()) @@ -1065,7 +1034,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value allOnesVal = b.create( loc, b.getIntegerAttr( elementType, - APSInt::getAllOnes(elementType.getIntOrFloatBitWidth()))); + APSInt::getAllOnesValue(elementType.getIntOrFloatBitWidth()))); return b.create(loc, payloadArgs[0], allOnesVal); } @@ -1113,10 +1082,10 @@ public: AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, - AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, - AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp, AtenAddScalarOp, - AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp, - AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp, + AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp, + AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp, + AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, + AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp, AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>(op)) @@ -1592,12 +1561,12 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, - AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, - AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, - AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, - AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, - AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenRemainderScalarOp, - AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>(); + AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp, + AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, + AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp, AtenMaskedFillTensorOp, + AtenLogicalOrOp, AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp, + AtenTriuOp, AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, + AtenFillScalarOp, AtenFillTensorOp>(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToMhlo/Basic.cpp similarity index 80% rename from lib/Conversion/TorchToStablehlo/Basic.cpp rename to lib/Conversion/TorchToMhlo/Basic.cpp index d84fbaf9b..a773ae652 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToMhlo/Basic.cpp @@ -7,16 +7,15 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" +#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "../PassDetail.h" -#include "PopulatePatterns.h" -#include "StablehloLegalizeUtils.h" - +#include "./MhloLegalizeUtils.h" +#include "./PopulatePatterns.h" +#include "mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "stablehlo/dialect/ChloOps.h" -#include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -30,7 +29,7 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; -using namespace mlir::torch::torch_to_stablehlo; +using namespace mlir::torch::torch_to_mhlo; LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op, mlir::Value &self, mlir::Value &other, @@ -44,16 +43,16 @@ LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op, if (selfRank > otherRank) { auto unsqueezeDims = llvm::to_vector<4>(llvm::seq(0, selfRank - otherRank)); - auto unsqueezeInfo = hlo::unsqueezeTensor(rewriter, op, other, - unsqueezeDims, dimSizeIndexBits); + auto unsqueezeInfo = mhlo::unsqueezeTensor(rewriter, op, other, + unsqueezeDims, dimSizeIndexBits); if (failed(unsqueezeInfo)) return failure(); other = *unsqueezeInfo; } else if (otherRank > selfRank) { auto unsqueezeDims = llvm::to_vector<4>(llvm::seq(0, otherRank - selfRank)); - auto unsqueezeInfo = hlo::unsqueezeTensor(rewriter, op, self, unsqueezeDims, - dimSizeIndexBits); + auto unsqueezeInfo = mhlo::unsqueezeTensor(rewriter, op, self, + unsqueezeDims, dimSizeIndexBits); if (failed(unsqueezeInfo)) return failure(); self = *unsqueezeInfo; @@ -79,8 +78,7 @@ static FailureOr getMaxValueOfDtype(Operation *op, Type elementType, constType, APFloat::getInf(elementType.cast().getFloatSemantics(), /*negative=*/false)); - return rewriter - .create(op->getLoc(), constType, constAttr) + return rewriter.create(op->getLoc(), constType, constAttr) .getResult(); } if (elementType.isa()) { @@ -93,8 +91,7 @@ static FailureOr getMaxValueOfDtype(Operation *op, Type elementType, constAttr = SplatElementsAttr::get( constType, APInt::getSignedMaxValue(integerType.getWidth())); } - return rewriter - .create(op->getLoc(), constType, constAttr) + return rewriter.create(op->getLoc(), constType, constAttr) .getResult(); } return failure(); @@ -108,8 +105,7 @@ static FailureOr getMinValueOfDtype(Operation *op, Type elementType, constType, APFloat::getInf(elementType.cast().getFloatSemantics(), /*negative=*/true)); - return rewriter - .create(op->getLoc(), constType, constAttr) + return rewriter.create(op->getLoc(), constType, constAttr) .getResult(); } if (elementType.isa()) { @@ -122,8 +118,7 @@ static FailureOr getMinValueOfDtype(Operation *op, Type elementType, constAttr = SplatElementsAttr::get( constType, APInt::getSignedMinValue(integerType.getWidth())); } - return rewriter - .create(op->getLoc(), constType, constAttr) + return rewriter.create(op->getLoc(), constType, constAttr) .getResult(); } return failure(); @@ -131,7 +126,7 @@ static FailureOr getMinValueOfDtype(Operation *op, Type elementType, // These legalizations are for unary ops. namespace { -template +template class ConvertAtenUnaryOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -142,13 +137,13 @@ public: Value self = adaptor.getSelf(); auto selfType = self.getType().cast(); if (!selfType) { - return op.emitError("only Tensor types supported in StableHLO"); + return op.emitError("only Tensor types supported in MHLO"); } auto outType = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) .template cast(); - self = hlo::promoteType(rewriter, self, outType); - rewriter.replaceOpWithNewOp(op, outType, self); + self = mhlo::promoteType(rewriter, self, outType); + rewriter.replaceOpWithNewOp(op, outType, self); return success(); } }; @@ -157,7 +152,7 @@ public: // These legalizations are for unary ops with only for floating point datatypes. // There is no supported quantized integer mode for these. namespace { -template +template class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -169,10 +164,10 @@ public: auto selfTy = self.getType().cast(); if (!selfTy) - return op.emitError("only Tensor types supported in StableHLO"); + return op.emitError("only Tensor types supported in MHLO"); if (selfTy.getElementType().isa()) { - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, OpConversionPattern::getTypeConverter()->convertType( op.getType()), @@ -203,7 +198,7 @@ public: .template dyn_cast(); if (!outType) - return op.emitError("only Tensor types supported in StableHLO"); + return op.emitError("only Tensor types supported in MHLO"); Type outElemTy = outType.getElementType(); if (!outElemTy.isIntOrFloat()) @@ -221,9 +216,9 @@ public: SmallVector values(size, fillVal); auto constOp = - hlo::getConstTensor(rewriter, op, values, shape).value(); + mhlo::getConstTensor(rewriter, op, values, shape).value(); - rewriter.replaceOpWithNewOp(op, outType, constOp); + rewriter.replaceOpWithNewOp(op, outType, constOp); return success(); } }; @@ -252,8 +247,8 @@ public: ->convertType(op.getType()) .template cast(); - lhs = hlo::promoteType(rewriter, lhs, outTy); - rhs = hlo::promoteType(rewriter, rhs, outTy); + lhs = mhlo::promoteType(rewriter, lhs, outTy); + rhs = mhlo::promoteType(rewriter, rhs, outTy); rewriter.replaceOpWithNewOp(op, outTy, lhs, rhs, /*broadcast_attr*/ nullptr); @@ -279,7 +274,7 @@ public: RankedTensorType rhsType = rhs.getType().dyn_cast(); if (!lhsType) - return op.emitError("only Tensor types supported in StableHLO"); + return op.emitError("only Tensor types supported in MHLO"); TensorType outType = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) @@ -292,19 +287,18 @@ public: } if (!rhsType) { - rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(), - outElemTy); + rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getOther(), outElemTy); if (isa(op)) { std::swap(lhs, rhs); } } - lhs = hlo::promoteType(rewriter, lhs, outType); - rhs = hlo::promoteType(rewriter, rhs, outType); + lhs = mhlo::promoteType(rewriter, lhs, outType); + rhs = mhlo::promoteType(rewriter, rhs, outType); if (!skipMultiplyAlpha(op.getAlpha())) { - Value alpha = hlo::scalarToStablehloTensor(rewriter, op, - adaptor.getAlpha(), outElemTy); + Value alpha = + mhlo::scalarToMhloTensor(rewriter, op, adaptor.getAlpha(), outElemTy); DenseIntElementsAttr bcastDimensions; rhs = rewriter.create(op->getLoc(), rhs, alpha, bcastDimensions); @@ -334,7 +328,7 @@ public: TensorType rhsType = rhs.getType().dyn_cast(); if (!lhsType) - return op.emitError("only Tensor types supported in StableHLO"); + return op.emitError("only Tensor types supported in MHLO"); auto outType = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) @@ -349,12 +343,11 @@ public: if (std::is_same()) { rhs = lhs; } else if (!rhsType) { - rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(), - outElemTy); + rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getOther(), outElemTy); } DenseIntElementsAttr bcastDimensions; - lhs = hlo::promoteType(rewriter, lhs, outType); - rhs = hlo::promoteType(rewriter, rhs, outType); + lhs = mhlo::promoteType(rewriter, lhs, outType); + rhs = mhlo::promoteType(rewriter, rhs, outType); auto loc = op.getLoc(); Value result = rewriter.create(loc, outType, lhs, rhs, bcastDimensions); @@ -375,15 +368,15 @@ public: if (roundingMode == "trunc") { // "trunc" - rounds the results of the division towards zero. Equivalent // to C-style integer division. - auto sign = rewriter.create(loc, result); - auto abs = rewriter.create(loc, result); - auto floor = rewriter.create(loc, abs); - result = rewriter.create(loc, sign, floor).getResult(); + auto sign = rewriter.create(loc, result); + auto abs = rewriter.create(loc, result); + auto floor = rewriter.create(loc, abs); + result = rewriter.create(loc, sign, floor).getResult(); } if (roundingMode == "floor") { // "floor" - rounds the results of the division down. Equivalent to // floor division in Python (the // operator) - result = rewriter.create(loc, result).getResult(); + result = rewriter.create(loc, result).getResult(); } rewriter.replaceOp(op, result); return success(); @@ -408,7 +401,7 @@ public: RankedTensorType rhsTy = rhs.getType().dyn_cast(); if (!lhsTy) - return op.emitError("only Tensor types supported in StableHLO"); + return op.emitError("only Tensor types supported in MHLO"); RankedTensorType outType = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) @@ -421,12 +414,11 @@ public: } if (!rhsTy) { - rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(), - lhsElemTy); + rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getOther(), lhsElemTy); } // TODO: what is the PyTorch default type promotion? - rhs = hlo::promoteType(rewriter, rhs, lhsTy); + rhs = mhlo::promoteType(rewriter, rhs, lhsTy); chlo::ComparisonTypeAttr compareTypeAttr; chlo::ComparisonDirectionAttr compareDirectionAttr; @@ -493,8 +485,8 @@ public: TensorType outType = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) .template cast(); - Value lhs = hlo::promoteType(rewriter, adaptor.getSelf(), outType); - Value rhs = hlo::promoteType(rewriter, adaptor.getOther(), outType); + Value lhs = mhlo::promoteType(rewriter, adaptor.getSelf(), outType); + Value rhs = mhlo::promoteType(rewriter, adaptor.getOther(), outType); DenseIntElementsAttr bcastDimensions; rewriter.replaceOpWithNewOp(op, outType, lhs, rhs, @@ -545,8 +537,8 @@ public: RankedTensorType::get({static_cast(permValues.size())}, rewriter.getI64Type()), permValues); - rewriter.replaceOpWithNewOp(op, outType, self, - permutation); + rewriter.replaceOpWithNewOp(op, outType, self, + permutation); return success(); } }; @@ -560,7 +552,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value self = adaptor.getSelf(); auto outType = getTypeConverter()->convertType(op.getType()).cast(); - rewriter.replaceOpWithNewOp(op, outType, self); + rewriter.replaceOpWithNewOp(op, outType, self); return success(); } @@ -581,8 +573,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } else { Value inputRank = rewriter.create( op.getLoc(), rewriter.getI64IntegerAttr(selfType.getRank())); - dim = toPositiveDimDynamic(rewriter, op.getLoc(), adaptor.getDim(), - inputRank); + dim = toPositiveDimDynamic(rewriter, op.getLoc(), adaptor.getDim(), inputRank); dim = rewriter.create(op.getLoc(), rewriter.getIndexType(), dim); } @@ -598,8 +589,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( template <> LogicalResult ConvertAtenOp::matchAndRewrite( - AtenWhereSelfOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { + AtenWhereSelfOp op, + OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const { Value self = adaptor.getSelf(); Value cond = adaptor.getCondition(); Value other = adaptor.getOther(); @@ -613,7 +605,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return op.emitError("failed broadcast other and condition ranks"); rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), + op, + getTypeConverter()->convertType(op.getType()), ArrayRef{cond, self, other}); return success(); } @@ -630,7 +623,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( .cast(); if (options.enableStaticShape && selfTy.hasStaticShape()) { - Value bcastOp = hlo::promoteAndBroadcast(rewriter, self, outType); + Value bcastOp = mhlo::promoteAndBroadcast(rewriter, self, outType); rewriter.replaceOp(op, bcastOp); return success(); } @@ -677,7 +670,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op->getLoc(), ValueRange{bcastShapeVec}); auto dimensionNumbers = llvm::to_vector<4>(llvm::seq(leadingRank, totalRank)); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, outType, self, bcastShapeTensor, rewriter.getI64TensorAttr(dimensionNumbers)); } @@ -715,11 +708,28 @@ LogicalResult ConvertAtenOp::matchAndRewrite( RankedTensorType::get({static_cast(permValues.size())}, rewriter.getI64Type()), permValues); - rewriter.replaceOpWithNewOp(op, outType, self, - permutation); + rewriter.replaceOpWithNewOp(op, outType, self, + permutation); return success(); } +// AtenTanhOp +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenTanhOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value self = adaptor.getSelf(); + auto selfTy = self.getType().cast(); + if (selfTy && selfTy.getElementType().isa()) { + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), self); + return success(); + } else { + return op.emitError( + "only floating-point datatype legalization currently supported"); + } +} + // ValueTensorLiteralOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( @@ -741,16 +751,16 @@ LogicalResult ConvertAtenOp::matchAndRewrite( elements.mapValues(builtinTensorElemTy, [&](const APInt &v) { return APInt(bitWidth, v.getSExtValue()); }); - rewriter.replaceOpWithNewOp(op, resultType, - valueAttr); + rewriter.replaceOpWithNewOp(op, resultType, valueAttr); return success(); } - rewriter.replaceOpWithNewOp(op, resultType, - adaptor.getValue()); + rewriter.replaceOpWithNewOp(op, resultType, + adaptor.getValue()); return success(); } + // AtenReciprocalOp // Reciprocal(x) = Div(1, x) template <> @@ -767,45 +777,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } Value oneTensor = chlo::getConstantLike(rewriter, op->getLoc(), 1, input); - rewriter.replaceOpWithNewOp(op, outTy, oneTensor, input); - return success(); -} - -// AtenPowTensorScalarOp -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenPowTensorScalarOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value lhs = adaptor.getSelf(); - auto lhsType = lhs.getType().dyn_cast(); - Value rhs = adaptor.getExponent(); - TensorType rhsType = rhs.getType().dyn_cast(); - - if (!lhsType) - return op.emitError("only Tensor types supported in StableHLO"); - - auto outType = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template cast(); - - Type outElemTy = outType.getElementType(); - if (!outElemTy.isIntOrFloat()) { - return op.emitError( - "only floating-point or integer datatype legalization supported"); - } - - if (!rhsType) { - rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, - outElemTy); - } - DenseIntElementsAttr bcastDimensions; - lhs = hlo::promoteType(rewriter, lhs, outType); - rhs = hlo::promoteType(rewriter, rhs, outType); - auto loc = op.getLoc(); - Value result = - rewriter.create(loc, outType, lhs, rhs, bcastDimensions); - - rewriter.replaceOp(op, result); + rewriter.replaceOpWithNewOp(op, outTy, oneTensor, input); return success(); } @@ -818,9 +790,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ->convertType(op->getResult(0).getType()) .cast(); auto outputElemType = outputType.getElementType(); - Value stablehloTensor = hlo::scalarToStablehloTensor( - rewriter, op, adaptor.getA(), outputElemType); - rewriter.replaceOp(op, stablehloTensor); + Value mhloTensor = + mhlo::scalarToMhloTensor(rewriter, op, adaptor.getA(), outputElemType); + rewriter.replaceOp(op, mhloTensor); return success(); } @@ -843,6 +815,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } + // AtenReluOp // Relu(x) = Max(0, x) template <> @@ -863,10 +836,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( APFloat::getZero(lhsElemTy.cast().getFloatSemantics(), false), lhs); - rewriter.replaceOpWithNewOp(op, lhs, zeroTensor); + rewriter.replaceOpWithNewOp(op, lhs, zeroTensor); return success(); } + // Convert a Aten::GELU to HLO // Gelu(x) = x * 1/2 * [1 + erf(x/(sqrt(2)))] template <> @@ -883,12 +857,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value one = chlo::getConstantLike(rewriter, loc, 1.0, input); Value two = chlo::getConstantLike(rewriter, loc, 2.0, input); Value half = chlo::getConstantLike(rewriter, loc, 0.5, input); - auto rsqrtTwo = rewriter.create(loc, two); - auto erfElement = rewriter.create(loc, input, rsqrtTwo); + auto rsqrtTwo = rewriter.create(loc, two); + auto erfElement = rewriter.create(loc, input, rsqrtTwo); auto erf = rewriter.create(loc, erfElement); - auto erfAdd = rewriter.create(loc, erf, one); - auto halfMul = rewriter.create(loc, erfAdd, half); - rewriter.replaceOpWithNewOp(op, input, halfMul); + auto erfAdd = rewriter.create(loc, erf, one); + auto halfMul = rewriter.create(loc, erfAdd, half); + rewriter.replaceOpWithNewOp(op, input, halfMul); return success(); } @@ -907,6 +881,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } + // AtenBatchNormOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( @@ -944,28 +919,28 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value channelShape = rewriter.create( op->getLoc(), ValueRange{channelDim}); if (failed(checkNotNone(rewriter, op, weight))) { - weight = hlo::getConstantOfShape( + weight = mhlo::getConstantOfShape( rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1), channelShape, RankedTensorType::get({inputTy.getShape()[1]}, inputTy.getElementType())); } if (failed(checkNotNone(rewriter, op, bias))) { - bias = hlo::getConstantOfShape( + bias = mhlo::getConstantOfShape( rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0), channelShape, RankedTensorType::get({inputTy.getShape()[1]}, inputTy.getElementType())); } if (failed(checkNotNone(rewriter, op, runningVar))) { - runningVar = hlo::getConstantOfShape( + runningVar = mhlo::getConstantOfShape( rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1), channelShape, RankedTensorType::get({inputTy.getShape()[1]}, inputTy.getElementType())); } if (failed(checkNotNone(rewriter, op, runningMean))) { - runningMean = hlo::getConstantOfShape( + runningMean = mhlo::getConstantOfShape( rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0), channelShape, RankedTensorType::get({inputTy.getShape()[1]}, @@ -1008,11 +983,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Type outputTy = getTypeConverter()->convertType(op.getType()); Type batchMeanOrVarTy = RankedTensorType::get(weightTy.getShape(), inputTy.getElementType()); - auto batchNormTrainingResult = - rewriter.create( - op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input, - weight, bias, rewriter.getF32FloatAttr(eps), - rewriter.getI64IntegerAttr(1)); + auto batchNormTrainingResult = rewriter.create( + op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input, + weight, bias, rewriter.getF32FloatAttr(eps), + rewriter.getI64IntegerAttr(1)); rewriter.replaceOp(op, batchNormTrainingResult.getResult(0)); return success(); } else { @@ -1021,11 +995,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( inputTy.getShape().end()}; castShape[1] = weightTy.getShape()[0]; auto castTy = RankedTensorType::get(castShape, inputTy.getElementType()); - // Feature counts must match among operands of - // stablehlo::BatchNormInferenceOp. + // Feature counts must match among operands of mhlo::BatchNormInferenceOp. Value inputCasted = rewriter.create(op.getLoc(), castTy, input); - Value output = rewriter.create( + Value output = rewriter.create( op.getLoc(), inputCasted.getType(), inputCasted, weight, bias, runningMean, runningVar, // 'epsilon' must satisfy constraint: 32-bit float attribute. @@ -1035,6 +1008,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } } + // AtenNativeLayerNormOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( @@ -1102,21 +1076,21 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } SmallVector inputFlattenShape{1, numFeatureDimSize, numEmbeddingDimSize}; - SmallVector meanOrVarStablehloOutShape{numFeatureDimSize}; + SmallVector meanOrVarMhloOutShape{numFeatureDimSize}; - auto stablehloBatchNormOutTy = + auto mhloBatchNormOutTy = RankedTensorType::get(inputFlattenShape, inputTy.getElementType()); - auto stablehloBathNormOutMeanOrVarTy = RankedTensorType::get( - meanOrVarStablehloOutShape, inputTy.getElementType()); + auto mhloBathNormOutMeanOrVarTy = + RankedTensorType::get(meanOrVarMhloOutShape, inputTy.getElementType()); // Reshape input - auto stablehloInput = rewriter.create( - op->getLoc(), stablehloBatchNormOutTy, input, - hlo::getConstTensor(rewriter, op, llvm::ArrayRef(inputFlattenShape), - {static_cast(inputFlattenShape.size())}) + auto mhloInput = rewriter.create( + op->getLoc(), mhloBatchNormOutTy, input, + mhlo::getConstTensor(rewriter, op, llvm::makeArrayRef(inputFlattenShape), + {static_cast(inputFlattenShape.size())}) .value()); - // Generate "scale" and "offset" Value for stablehlo.BatchNormTrainingOp. + // Generate "scale" and "offset" Value for mhlo.BatchNormTrainingOp. SmallVector zeroConstVec( numFeatureDimSize, APFloat::getZero(inputTy.getElementType() .cast() @@ -1129,18 +1103,16 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto oneOrZeroConstType = RankedTensorType::get({numFeatureDimSize}, inputTy.getElementType()); - Value scale = rewriter.create( + Value scale = rewriter.create( op->getLoc(), oneOrZeroConstType, DenseElementsAttr::get(oneOrZeroConstType, oneConstVec)); - Value offset = rewriter.create( + Value offset = rewriter.create( op->getLoc(), oneOrZeroConstType, DenseElementsAttr::get(oneOrZeroConstType, zeroConstVec)); - auto batchNormTrainingResult = - rewriter.create( - op->getLoc(), stablehloBatchNormOutTy, - stablehloBathNormOutMeanOrVarTy, stablehloBathNormOutMeanOrVarTy, - stablehloInput, scale, offset, rewriter.getF32FloatAttr(eps), - rewriter.getI64IntegerAttr(1)); + auto batchNormTrainingResult = rewriter.create( + op->getLoc(), mhloBatchNormOutTy, mhloBathNormOutMeanOrVarTy, + mhloBathNormOutMeanOrVarTy, mhloInput, scale, offset, + rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(1)); // Reshape back auto outputTy = @@ -1148,35 +1120,36 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto outputMeanOrVarTy = getTypeConverter()->convertType(op.getType(1)).cast(); - auto output = rewriter.create( + auto output = rewriter.create( op->getLoc(), outputTy, batchNormTrainingResult.getResult(0), - hlo::getConstTensor(rewriter, op, outputTy.getShape(), - {static_cast(outputTy.getShape().size())}) + mhlo::getConstTensor(rewriter, op, outputTy.getShape(), + {static_cast(outputTy.getShape().size())}) .value()); - auto mean = rewriter.create( + auto mean = rewriter.create( op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(1), - hlo::getConstTensor( + mhlo::getConstTensor( rewriter, op, outputMeanOrVarTy.getShape(), {static_cast(outputMeanOrVarTy.getShape().size())}) .value()); - auto var = rewriter.create( + auto var = rewriter.create( op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(2), - hlo::getConstTensor( + mhlo::getConstTensor( rewriter, op, outputMeanOrVarTy.getShape(), {static_cast(outputMeanOrVarTy.getShape().size())}) .value()); // Apply affine transform: output x weight + bias [element-wise] - auto bcastedWeight = hlo::promoteAndBroadcast(rewriter, weight, outputTy); - auto bcastedBias = hlo::promoteAndBroadcast(rewriter, bias, outputTy); + auto bcastedWeight = mhlo::promoteAndBroadcast(rewriter, weight, outputTy); + auto bcastedBias = mhlo::promoteAndBroadcast(rewriter, bias, outputTy); auto outputMulWeight = - rewriter.create(op->getLoc(), output, bcastedWeight); - auto finalOuput = rewriter.create( - op->getLoc(), outputMulWeight, bcastedBias); + rewriter.create(op->getLoc(), output, bcastedWeight); + auto finalOuput = + rewriter.create(op->getLoc(), outputMulWeight, bcastedBias); rewriter.replaceOp(op, {finalOuput, mean, var}); return success(); } + // AtenCatOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( @@ -1200,11 +1173,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Promote type for (auto &v : builtinTensors) { - v = hlo::promoteType(rewriter, v, outType); + v = mhlo::promoteType(rewriter, v, outType); } size_t posDim = toPositiveDim(dim, outType.getRank()); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, outType, ValueRange(builtinTensors), posDim); return success(); } @@ -1252,8 +1225,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "this op should be folded as its `min` and `max` both are none"); } else if (failed(checkNotNone(rewriter, op, minValue))) { - maxValue = - hlo::scalarToStablehloTensor(rewriter, op, maxValue, inputElemType); + maxValue = mhlo::scalarToMhloTensor(rewriter, op, maxValue, inputElemType); auto minInfo = getMinValueOfDtype(op, inputElemType, rewriter); if (failed(minInfo)) { return rewriter.notifyMatchFailure( @@ -1261,8 +1233,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } minValue = *minInfo; } else if (failed(checkNotNone(rewriter, op, maxValue))) { - minValue = - hlo::scalarToStablehloTensor(rewriter, op, minValue, inputElemType); + minValue = mhlo::scalarToMhloTensor(rewriter, op, minValue, inputElemType); auto maxInfo = getMaxValueOfDtype(op, inputElemType, rewriter); if (failed(maxInfo)) { return rewriter.notifyMatchFailure( @@ -1270,13 +1241,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } maxValue = *maxInfo; } else { - minValue = - hlo::scalarToStablehloTensor(rewriter, op, minValue, inputElemType); - maxValue = - hlo::scalarToStablehloTensor(rewriter, op, maxValue, inputElemType); + minValue = mhlo::scalarToMhloTensor(rewriter, op, minValue, inputElemType); + maxValue = mhlo::scalarToMhloTensor(rewriter, op, maxValue, inputElemType); } - rewriter.replaceOpWithNewOp(op, minValue, input, - maxValue); + rewriter.replaceOpWithNewOp(op, minValue, input, maxValue); return success(); } @@ -1298,27 +1266,24 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "unimplemented: only int or float dtype supported"); } - Value start = - hlo::scalarToStablehloTensor(rewriter, op, adaptor.getStart(), dtype); - Value end = - hlo::scalarToStablehloTensor(rewriter, op, adaptor.getEnd(), dtype); - Value step = - hlo::scalarToStablehloTensor(rewriter, op, adaptor.getStep(), dtype); + Value start = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getStart(), dtype); + Value end = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getEnd(), dtype); + Value step = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getStep(), dtype); // Get length of the 1-d output tensor - Value subOut = rewriter.create(loc, end, start); - Value divOut = rewriter.create(loc, subOut, step); + Value subOut = rewriter.create(loc, end, start); + Value divOut = rewriter.create(loc, subOut, step); - Value resultLength = rewriter.create( + Value resultLength = rewriter.create( loc, RankedTensorType::get({1}, dtype), divOut); if (dtype.isa()) { - resultLength = rewriter.create(loc, resultLength); - resultLength = rewriter.create( + resultLength = rewriter.create(loc, resultLength); + resultLength = rewriter.create( loc, RankedTensorType::get({1}, rewriter.getI64Type()), resultLength); } Value window = - rewriter.create(loc, outType, resultLength, 0); + rewriter.create(loc, outType, resultLength, 0); DenseIntElementsAttr broadcastDimensions; Value mulOut = rewriter.create(loc, window, step, broadcastDimensions); @@ -1333,8 +1298,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); Value input = adaptor.getSelf(); - auto outType = - this->getTypeConverter()->convertType(op.getType()).cast(); + auto outType = this->getTypeConverter() + ->convertType(op.getType()) + .cast(); if (!outType) { return op.emitError("only tensor type is supported"); } @@ -1354,27 +1320,26 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value negHalf = chlo::getConstantLike(rewriter, loc, -0.5, input); // Compute - Value kBeta0 = - rewriter.create(loc, outType, kAlpha, cstAlpha0); - Value kBeta = rewriter.create(loc, outType, kBeta0, half); - Value erfArg = rewriter.create(loc, outType, kAlpha, - adaptor.getSelf()); + Value kBeta0 = rewriter.create(loc, outType, kAlpha, cstAlpha0); + Value kBeta = rewriter.create(loc, outType, kBeta0, half); + Value erfArg = + rewriter.create(loc, outType, kAlpha, adaptor.getSelf()); Value erf = rewriter.create(loc, outType, erfArg); - Value erfAdd = rewriter.create(loc, outType, erf, one); - Value cdf = rewriter.create(loc, outType, erfAdd, half); - Value inputSquared = rewriter.create( + Value erfAdd = rewriter.create(loc, outType, erf, one); + Value cdf = rewriter.create(loc, outType, erfAdd, half); + Value inputSquared = rewriter.create( loc, outType, adaptor.getSelf(), adaptor.getSelf()); Value negHalfInputSquared = - rewriter.create(loc, outType, inputSquared, negHalf); + rewriter.create(loc, outType, inputSquared, negHalf); Value expRes = - rewriter.create(loc, outType, negHalfInputSquared); - Value pdf = rewriter.create(loc, outType, kBeta, expRes); + rewriter.create(loc, outType, negHalfInputSquared); + Value pdf = rewriter.create(loc, outType, kBeta, expRes); Value pdfTimesInput = - rewriter.create(loc, outType, pdf, adaptor.getSelf()); + rewriter.create(loc, outType, pdf, adaptor.getSelf()); Value pdfTimesInputAddCdf = - rewriter.create(loc, outType, pdfTimesInput, cdf); - rewriter.replaceOpWithNewOp( - op, outType, adaptor.getGradOutput(), pdfTimesInputAddCdf); + rewriter.create(loc, outType, pdfTimesInput, cdf); + rewriter.replaceOpWithNewOp(op, outType, adaptor.getGradOutput(), + pdfTimesInputAddCdf); return success(); } @@ -1401,9 +1366,9 @@ public: }; } // namespace -void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( +void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target, const TorchToStablehloOptions &options) { + ConversionTarget &target, const TorchToMhloOptions &options) { MLIRContext *context = patterns.getContext(); target.addIllegalOp(); @@ -1411,29 +1376,23 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( target.addIllegalOp(); patterns.add(typeConverter, context); -#define INSERT_UNARY_PATTERN(AtenOp, StablehloOp) \ +#define INSERT_UNARY_PATTERN(AtenOp, MhloOp) \ target.addIllegalOp(); \ - patterns.add>(typeConverter, context) - INSERT_UNARY_PATTERN(AtenCloneOp, stablehlo::ConvertOp); - INSERT_UNARY_PATTERN(AtenNegOp, stablehlo::NegOp); - INSERT_UNARY_PATTERN(AtenLogicalNotOp, stablehlo::NotOp); - INSERT_UNARY_PATTERN(AtenBitwiseNotOp, stablehlo::NotOp); + patterns.add>(typeConverter, context) + INSERT_UNARY_PATTERN(AtenCloneOp, mhlo::CopyOp); + INSERT_UNARY_PATTERN(AtenNegOp, mhlo::NegOp); + INSERT_UNARY_PATTERN(AtenLogicalNotOp, mhlo::NotOp); + INSERT_UNARY_PATTERN(AtenBitwiseNotOp, mhlo::NotOp); #undef INSERT_UNARY_PATTERN -#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, StablehloOp) \ +#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, MhloOp) \ target.addIllegalOp(); \ - patterns.add>(typeConverter, \ - context) - INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, stablehlo::LogOp); - INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, stablehlo::ExpOp); - INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, stablehlo::SqrtOp); - INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, stablehlo::RsqrtOp); - INSERT_UNARY_FPONLY_PATTERN(AtenSigmoidOp, stablehlo::LogisticOp); - INSERT_UNARY_FPONLY_PATTERN(AtenTanhOp, stablehlo::TanhOp); - INSERT_UNARY_FPONLY_PATTERN(AtenSinOp, stablehlo::SineOp); - INSERT_UNARY_FPONLY_PATTERN(AtenCosOp, stablehlo::CosineOp); - INSERT_UNARY_FPONLY_PATTERN(AtenCeilOp, stablehlo::CeilOp); - INSERT_UNARY_FPONLY_PATTERN(AtenFloorOp, stablehlo::FloorOp); + patterns.add>(typeConverter, context) + INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, mhlo::LogOp); + INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, mhlo::ExpOp); + INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, mhlo::SqrtOp); + INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, mhlo::RsqrtOp); + INSERT_UNARY_FPONLY_PATTERN(AtenSigmoidOp, mhlo::LogisticOp); #undef INSERT_UNARY_FPONLY_PATTERN #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \ @@ -1500,9 +1459,9 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_ATENOP_PATTERN(AtenBroadcastToOp); INSERT_ATENOP_PATTERN(AtenPermuteOp); + INSERT_ATENOP_PATTERN(AtenTanhOp); INSERT_ATENOP_PATTERN(ValueTensorLiteralOp); INSERT_ATENOP_PATTERN(AtenReciprocalOp); - INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp); INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp); INSERT_ATENOP_PATTERN(AtenContiguousOp); @@ -1523,10 +1482,10 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_ATENOP_PATTERN(AtenWhereSelfOp); #undef INSERT_ATENOP_PATTERN -#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \ +#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, MhloOp) \ target.addIllegalOp(); \ - patterns.add>( \ - typeConverter, context) + patterns.add>(typeConverter, \ + context) INSERT_BINARY_BROADCAST_PATTERN(AtenMaximumOp, chlo::BroadcastMaxOp); INSERT_BINARY_BROADCAST_PATTERN(AtenMinimumOp, chlo::BroadcastMinOp); INSERT_BINARY_BROADCAST_PATTERN(Aten__And__TensorOp, chlo::BroadcastAndOp); diff --git a/lib/Conversion/TorchToMhlo/CMakeLists.txt b/lib/Conversion/TorchToMhlo/CMakeLists.txt new file mode 100644 index 000000000..4c0929268 --- /dev/null +++ b/lib/Conversion/TorchToMhlo/CMakeLists.txt @@ -0,0 +1,35 @@ +add_mlir_conversion_library(TorchMLIRTorchToMhlo + TorchToMhlo.cpp + MhloLegalizeUtils.cpp + Basic.cpp + Gather.cpp + Linear.cpp + ViewLike.cpp + Reduction.cpp + Pooling.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToMhlo + + DEPENDS + MhloDialect + MhloToLinalg + MLIRMhloPassIncGen + LMHLOTransformsPassIncGen + TorchMLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MhloDialect + MhloToLinalg + MLIRBufferTransforms + StablehloOps + TorchMLIRTorchDialect + TorchMLIRConversionUtils +) + +torch_mlir_target_includes(TorchMLIRTorchToMhlo) diff --git a/lib/Conversion/TorchToStablehlo/Gather.cpp b/lib/Conversion/TorchToMhlo/Gather.cpp similarity index 87% rename from lib/Conversion/TorchToStablehlo/Gather.cpp rename to lib/Conversion/TorchToMhlo/Gather.cpp index 437332703..8d7a3f5c0 100644 --- a/lib/Conversion/TorchToStablehlo/Gather.cpp +++ b/lib/Conversion/TorchToMhlo/Gather.cpp @@ -7,15 +7,14 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" +#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "../PassDetail.h" -#include "PopulatePatterns.h" -#include "StablehloLegalizeUtils.h" - +#include "./MhloLegalizeUtils.h" +#include "./PopulatePatterns.h" +#include "mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -25,7 +24,7 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; -using namespace mlir::torch::torch_to_stablehlo; +using namespace mlir::torch::torch_to_mhlo; namespace { Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op, @@ -70,7 +69,7 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op, SmallVector startIndexMap(1, axis); // indexVecDim int64_t indexVecDim = indicesRank; - auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get( + auto dimsAttr = mhlo::GatherDimensionNumbersAttr::get( rewriter.getContext(), /*offsetDims=*/offsetDims, /*collapsedSliceDims=*/collapsedSliceDims, @@ -92,18 +91,17 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op, auto outputTy = RankedTensorType::get(outputShape, inputRankTy.getElementType()); return rewriter - .create(loc, outputTy, input, indices, - sliceSizesTensor, dimsAttr) + .create(loc, outputTy, input, indices, + sliceSizesTensor, dimsAttr) .getResult(); } } // namespace -// Ref: -// https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html +// Ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html // padding_idx (int, optional) -// – If specified, the entries at padding_idx do not contribute to the -// gradient; therefore, the embedding vector at padding_idx is not updated -// during training, i.e. it remains as a fixed “pad”. +// – If specified, the entries at padding_idx do not contribute to the gradient; +// therefore, the embedding vector at padding_idx is not updated during training, +// i.e. it remains as a fixed “pad”. // scale_grad_by_freq (boolean, optional) // – If given, this will scale gradients by the inverse of frequency of the // words in the mini-batch. Default False. @@ -141,7 +139,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value output = gatherTensorAlongSingleAxis( rewriter, op, weight, adaptor.getIndices(), 0, options.dimSizeIndexBits); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), output); return success(); @@ -163,7 +161,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value output = gatherTensorAlongSingleAxis( rewriter, op, self, adaptor.getIndex(), dim, options.dimSizeIndexBits); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), output); return success(); @@ -202,7 +200,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto options = getOptions(); auto indexShapeInfo = - hlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits); + mhlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits); if (failed(indexShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dim sizes of `index` param"); @@ -225,15 +223,15 @@ LogicalResult ConvertAtenOp::matchAndRewrite( SmallVector toConcat; for (int64_t i = 0; i < inputType.getRank(); ++i) { if (i == dim) { - toConcat.push_back(rewriter.create( + toConcat.push_back(rewriter.create( loc, toConcatIndexType, index, toConcatIndexShape)); } else { - toConcat.push_back(rewriter.create( + toConcat.push_back(rewriter.create( loc, toConcatIndexType, toConcatIndexShape, rewriter.getI64IntegerAttr(i))); } } - auto gatherIndicies = rewriter.create( + auto gatherIndicies = rewriter.create( loc, toConcat, static_cast(inputType.getRank())); SmallVector sliceSizes(inputType.getRank(), 1); @@ -245,22 +243,22 @@ LogicalResult ConvertAtenOp::matchAndRewrite( startIndexMap.push_back(i); } - auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get( + auto dimsAttr = mhlo::GatherDimensionNumbersAttr::get( rewriter.getContext(), /*offsetDims=*/{}, /*collapsedSliceDims=*/collapsedDims, /*startIndexMap=*/startIndexMap, /*indexVecDim=*/indexVecDim); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, input, gatherIndicies, dimsAttr, rewriter.getI64TensorAttr(sliceSizes)); return success(); } -void mlir::torch::torch_to_stablehlo::populateGatherOpPatternsAndLegality( +void mlir::torch::torch_to_mhlo::populateGatherOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target, const TorchToStablehloOptions &options) { + ConversionTarget &target, const TorchToMhloOptions &options) { MLIRContext *context = patterns.getContext(); #define INSERT_ATENOP_PATTERN(AtenOp) \ diff --git a/lib/Conversion/TorchToStablehlo/Linear.cpp b/lib/Conversion/TorchToMhlo/Linear.cpp similarity index 83% rename from lib/Conversion/TorchToStablehlo/Linear.cpp rename to lib/Conversion/TorchToMhlo/Linear.cpp index fbc3d6ee4..8632af4ba 100644 --- a/lib/Conversion/TorchToStablehlo/Linear.cpp +++ b/lib/Conversion/TorchToMhlo/Linear.cpp @@ -7,16 +7,15 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" +#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "../PassDetail.h" -#include "PopulatePatterns.h" -#include "StablehloLegalizeUtils.h" - +#include "./MhloLegalizeUtils.h" +#include "./PopulatePatterns.h" +#include "mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "stablehlo/dialect/ChloOps.h" -#include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -26,7 +25,7 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; -using namespace mlir::torch::torch_to_stablehlo; +using namespace mlir::torch::torch_to_mhlo; namespace { Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor, @@ -34,7 +33,7 @@ Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor, ArrayRef broadcastDims) { auto tensorTy = tensor.getType().dyn_cast(); auto loc = op->getLoc(); - Value stablehloShape = rewriter.create(loc, dimSizes); + Value mhloShape = rewriter.create(loc, dimSizes); RankedTensorType outTy = RankedTensorType::get(shape, tensorTy.getElementType()); @@ -44,8 +43,8 @@ Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor, rewriter.getIntegerType(64)); auto broadcastAttr = DenseIntElementsAttr::get(attrTy, broadcastDims); - auto broadcast = rewriter.create( - loc, outTy, tensor, stablehloShape, broadcastAttr); + auto broadcast = rewriter.create( + loc, outTy, tensor, mhloShape, broadcastAttr); return broadcast; } @@ -53,7 +52,7 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input, ArrayRef inpTransDims) { auto inputTy = input.getType().dyn_cast(); auto rank = inputTy.getRank(); - auto transDims = hlo::toPositiveDims(inpTransDims, rank); + auto transDims = mhlo::toPositiveDims(inpTransDims, rank); auto inpShape = inputTy.getShape(); std::vector newShape; newShape.reserve(rank); @@ -67,8 +66,8 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input, auto permuteAttr = DenseIntElementsAttr::get(attrTy, transDims); auto outTy = RankedTensorType::get(newShape, inputTy.getElementType()); - auto result = rewriter.create(op->getLoc(), outTy, - input, permuteAttr); + auto result = rewriter.create(op->getLoc(), outTy, input, + permuteAttr); return result.getResult(); } @@ -120,12 +119,10 @@ RankedTensorType castContractingDim(PatternRewriter &rewriter, Operation *op, } // set result dimensions - if (lhsResultDim < static_cast(lhsShape.size()) && - lhsResultDim >= 0) { + if (lhsResultDim < static_cast(lhsShape.size()) && lhsResultDim >= 0) { outShape.push_back(lhsShape[lhsResultDim]); } - if (rhsResultDim < static_cast(rhsShape.size()) && - rhsResultDim >= 0) { + if (rhsResultDim < static_cast(rhsShape.size()) && rhsResultDim >= 0) { outShape.push_back(rhsShape[rhsResultDim]); } return RankedTensorType::get(outShape, lhsTy.getElementType()); @@ -154,10 +151,10 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs, std::vector newShape(rhsShape.begin(), rhsShape.begin() + leadingRank); newShape.insert(newShape.end(), lhsShape.begin(), lhsShape.end()); - auto newDimSizes = *hlo::getDimSizesOfTensor(rewriter, op, rhs, leadingDims, - dimSizeIndexBits); + auto newDimSizes = *mhlo::getDimSizesOfTensor( + rewriter, op, rhs, leadingDims, dimSizeIndexBits); auto lhsDimSizes = - *hlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits); + *mhlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits); newDimSizes.insert(newDimSizes.end(), lhsDimSizes.begin(), lhsDimSizes.end()); lhs = getBroadcastTensor(rewriter, op, lhs, newShape, newDimSizes, @@ -166,10 +163,10 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs, std::vector newShape(lhsShape.begin(), lhsShape.begin() + leadingRank); newShape.insert(newShape.end(), rhsShape.begin(), rhsShape.end()); - auto newDimSizes = *hlo::getDimSizesOfTensor(rewriter, op, lhs, leadingDims, - dimSizeIndexBits); + auto newDimSizes = *mhlo::getDimSizesOfTensor( + rewriter, op, lhs, leadingDims, dimSizeIndexBits); auto rhsDimSizes = - *hlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits); + *mhlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits); newDimSizes.insert(newDimSizes.end(), rhsDimSizes.begin(), rhsDimSizes.end()); rhs = getBroadcastTensor(rewriter, op, rhs, newShape, newDimSizes, @@ -221,8 +218,8 @@ public: if (lhsRank <= 2 && rhsRank <= 2) { auto tensorType = ConvertAtenOp::getTypeConverter()->convertType(op.getType()); - output = rewriter.create(op->getLoc(), tensorType, lhs, - rhs, nullptr); + output = rewriter.create(op->getLoc(), tensorType, lhs, rhs, + nullptr); return success(); } @@ -256,8 +253,8 @@ public: lhsContractingDim = nBatchDims; } - stablehlo::DotDimensionNumbersAttr dotDimensionNumbers = - stablehlo::DotDimensionNumbersAttr::get( + mhlo::DotDimensionNumbersAttr dotDimensionNumbers = + mhlo::DotDimensionNumbersAttr::get( rewriter.getContext(), /*lhsBatchingDimensions=*/batchDims, /*rhsBatchingDimensions=*/batchDims, @@ -267,8 +264,8 @@ public: castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim, lhsContractingDim, rhsContractingDim); output = rewriter - .create(op->getLoc(), outTy, lhs, rhs, - dotDimensionNumbers, nullptr) + .create(op->getLoc(), outTy, lhs, rhs, + dotDimensionNumbers, nullptr) .getResult(); return success(); } @@ -315,7 +312,7 @@ public: if (!lhsTy || !rhsTy) return op.emitError( - "only ranked tensor types are supported in StableHLO matmul"); + "only ranked tensor types are supported in MHLO matmul"); return success(); } @@ -338,7 +335,7 @@ public: if (!lhsTy || !rhsTy) return op.emitError( - "only ranked tensor types are supported in StableHLO matmul"); + "only ranked tensor types are supported in MHLO matmul"); auto lhsRank = lhsTy.getRank(); auto rhsRank = rhsTy.getRank(); @@ -374,7 +371,7 @@ public: if (!lhsTy || !rhsTy) return op.emitError( - "only ranked tensor types are supported in StableHLO matmul"); + "only ranked tensor types are supported in MHLO matmul"); auto lhsRank = lhsTy.getRank(); auto rhsRank = rhsTy.getRank(); @@ -401,10 +398,10 @@ public: auto bias = adaptor.getBias(); auto biasTy = bias.getType(); - // StableHLO does not mandate that elementwise op tensors need to be ranked. + // MHLO does not mandate that elementwise op tensors need to be ranked. if (!biasTy.template isa() && !biasTy.template isa()) - return op.emitError("only ranked tensor types are supported in StableHLO " + return op.emitError("only ranked tensor types are supported in MHLO " "matmul for bias tensor"); // weight.T @@ -430,14 +427,14 @@ public: auto outTy = castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim, lhsContractingDim, rhsContractingDim); - stablehlo::DotDimensionNumbersAttr dotDimensionNumbers = - stablehlo::DotDimensionNumbersAttr::get( + mhlo::DotDimensionNumbersAttr dotDimensionNumbers = + mhlo::DotDimensionNumbersAttr::get( rewriter.getContext(), /*lhsBatchingDimensions=*/batchDims, /*rhsBatchingDimensions=*/batchDims, /*lhsContractingDimensions=*/{lhsContractingDim}, /*rhsContractingDimensions=*/{rhsContractingDim}); - Value matmulOutput = rewriter.create( + Value matmulOutput = rewriter.create( op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr); Value matmulPlusBias = matmulOutput; @@ -467,7 +464,7 @@ public: auto weightElemTy = weightTy.getElementType(); auto rank = weightTy.getRank(); const auto &options = getOptions(); - SmallVector weightShapeVec = *hlo::getDimSizesOfTensor( + SmallVector weightShapeVec = *mhlo::getDimSizesOfTensor( rewriter, op, weight, options.dimSizeIndexBits); auto weightShape = weightTy.getShape(); SmallVector weightShapeInt(rank); @@ -491,7 +488,7 @@ public: } Value weightShapeTensor = rewriter.create( op->getLoc(), weightShapeVec); - weight = rewriter.create( + weight = rewriter.create( op->getLoc(), RankedTensorType::get(weightShapeInt, weightElemTy), weight, weightShapeTensor); @@ -500,7 +497,7 @@ public: for (int64_t i = 0; i <= rank; i++) transposeDims[i] = i; std::swap(transposeDims[1], transposeDims[0]); - weight = rewriter.create( + weight = rewriter.create( op->getLoc(), weight, rewriter.getI64TensorAttr(transposeDims)); // 3. [IC//G, G, OC, H, W, ...] => [IC//G, G*OC, H, W, ...] @@ -512,7 +509,7 @@ public: weightShapeVec[1] = OCMulGValue; weightShapeTensor = rewriter.create( op->getLoc(), weightShapeVec); - weight = rewriter.create( + weight = rewriter.create( op->getLoc(), RankedTensorType::get(weightShapeInt, weightElemTy), weight, weightShapeTensor); return weight; @@ -547,27 +544,25 @@ public: } // Prepare for transposed convolution - SmallVector stablehloStrideVec(nSpatialDims, 1); - DenseIntElementsAttr stablehloStride = - rewriter.getI64TensorAttr(stablehloStrideVec); - SmallVector stablehloPaddingVec(nSpatialDims * 2, 0); + SmallVector mhloStrideVec(nSpatialDims, 1); + DenseIntElementsAttr mhloStride = rewriter.getI64TensorAttr(mhloStrideVec); + SmallVector mhloPaddingVec(nSpatialDims * 2, 0); for (int i = 0; i < nSpatialDims; ++i) { int64_t padInt = dilation[i] * (weightShape[i + 2] - 1) - padding[i]; - stablehloPaddingVec[i * 2] = padInt; - stablehloPaddingVec[i * 2 + 1] = padInt; + mhloPaddingVec[i * 2] = padInt; + mhloPaddingVec[i * 2 + 1] = padInt; } - DenseIntElementsAttr stablehloPadding = DenseIntElementsAttr::get( + DenseIntElementsAttr mhloPadding = DenseIntElementsAttr::get( RankedTensorType::get({nSpatialDims, 2}, rewriter.getI64Type()), - stablehloPaddingVec); - SmallVector stablehloLhsDilationVec(nSpatialDims); - std::copy(stride.begin(), stride.end(), stablehloLhsDilationVec.begin()); - DenseIntElementsAttr stablehloLhsDilation = - rewriter.getI64TensorAttr(stablehloLhsDilationVec); - SmallVector stablehloRhsDilationVec(nSpatialDims); - std::copy(dilation.begin(), dilation.end(), - stablehloRhsDilationVec.begin()); - DenseIntElementsAttr stablehloRhsDilation = - rewriter.getI64TensorAttr(stablehloRhsDilationVec); + mhloPaddingVec); + SmallVector mhloLhsDilationVec(nSpatialDims); + std::copy(stride.begin(), stride.end(), mhloLhsDilationVec.begin()); + DenseIntElementsAttr mhloLhsDilation = + rewriter.getI64TensorAttr(mhloLhsDilationVec); + SmallVector mhloRhsDilationVec(nSpatialDims); + std::copy(dilation.begin(), dilation.end(), mhloRhsDilationVec.begin()); + DenseIntElementsAttr mhloRhsDilation = + rewriter.getI64TensorAttr(mhloRhsDilationVec); DenseElementsAttr windowReversal; ArrayAttr precisionConfig; @@ -576,8 +571,8 @@ public: for (int i = 0; i < nSpatialDims; ++i) { spatialDims.push_back(i + 2); } - stablehlo::ConvDimensionNumbersAttr dimensionNumbers = - stablehlo::ConvDimensionNumbersAttr::get( + mhlo::ConvDimensionNumbersAttr dimensionNumbers = + mhlo::ConvDimensionNumbersAttr::get( /*context=*/rewriter.getContext(), /*inputBatchDimension=*/0, /*inputFeatureDimension=*/1, /*inputSpatialDimensions=*/spatialDims, @@ -588,18 +583,17 @@ public: /*outputSpatialDimensions=*/spatialDims); // Reverse and transpose weight - weight = rewriter.create( + weight = rewriter.create( op->getLoc(), weight, rewriter.getI64TensorAttr(spatialDims)); if (groups != 1) { weight = reshapeConvWeight(rewriter, op, weight, groups); } // Create transposed convolution - auto transposedConvOp = rewriter.create( - op->getLoc(), convOutTy, input, weight, stablehloStride, - stablehloPadding, stablehloLhsDilation, stablehloRhsDilation, - windowReversal, dimensionNumbers, static_cast(groups), 1, - precisionConfig); + auto transposedConvOp = rewriter.create( + op->getLoc(), convOutTy, input, weight, mhloStride, mhloPadding, + mhloLhsDilation, mhloRhsDilation, windowReversal, dimensionNumbers, + static_cast(groups), 1, precisionConfig); // Handle output padding if (!needHandleOutputPadding) { @@ -611,8 +605,8 @@ public: std::copy(outputPadding.begin(), outputPadding.end(), edgePaddingHighVec.begin() + 2); Value paddingValue = - hlo::getConstTensor(rewriter, op, {0.0}, {}).value(); - paddingValue = hlo::promoteType(rewriter, paddingValue, inputTy); + mhlo::getConstTensor(rewriter, op, {0.0}, {}).value(); + paddingValue = mhlo::promoteType(rewriter, paddingValue, inputTy); mlir::DenseIntElementsAttr edgePaddingLow = rewriter.getI64VectorAttr(edgePaddingLowVec); mlir::DenseIntElementsAttr edgePaddingHigh = @@ -620,7 +614,7 @@ public: mlir::DenseIntElementsAttr interiorPadding = rewriter.getI64VectorAttr(interiorPaddingVec); - auto paddedOutput = rewriter.create( + auto paddedOutput = rewriter.create( op->getLoc(), outType, transposedConvOp, paddingValue, edgePaddingLow, edgePaddingHigh, interiorPadding); @@ -634,22 +628,22 @@ public: ArrayRef dilation, int64_t groups) const { int64_t nDims = outType.getRank(); - // Get stablehlo::ConvolutionOp attributes - DenseIntElementsAttr stablehloWindowStride = DenseIntElementsAttr::get( + // Get mhlo::ConvolutionOp attributes + DenseIntElementsAttr mhloWindowStride = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(stride.size())}, rewriter.getI64Type()), stride); - std::vector stablehloPaddingVec; + std::vector mhloPaddingVec; for (size_t i = 0; i < padding.size(); i++) { - stablehloPaddingVec.emplace_back(padding[i]); - stablehloPaddingVec.emplace_back(padding[i]); + mhloPaddingVec.emplace_back(padding[i]); + mhloPaddingVec.emplace_back(padding[i]); } - DenseIntElementsAttr stablehloPadding = DenseIntElementsAttr::get( + DenseIntElementsAttr mhloPadding = DenseIntElementsAttr::get( RankedTensorType::get( {static_cast(padding.size()), static_cast(2)}, rewriter.getI64Type()), - stablehloPaddingVec); - DenseIntElementsAttr stablehloRhsDilation = DenseIntElementsAttr::get( + mhloPaddingVec); + DenseIntElementsAttr mhloRhsDilation = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(dilation.size())}, rewriter.getI64Type()), dilation); @@ -657,8 +651,8 @@ public: for (int64_t i = 2; i < nDims; i++) { spatialDimensions.emplace_back(i); } - stablehlo::ConvDimensionNumbersAttr dimensionNumbers = - stablehlo::ConvDimensionNumbersAttr::get( + mhlo::ConvDimensionNumbersAttr dimensionNumbers = + mhlo::ConvDimensionNumbersAttr::get( /*context=*/rewriter.getContext(), /*inputBatchDimension=*/0, /*inputFeatureDimension=*/1, /*inputSpatialDimensions=*/spatialDimensions, @@ -668,18 +662,17 @@ public: /*outputBatchDimension=*/0, /*outputFeatureDimension=*/1, /*outputSpatialDimensions=*/spatialDimensions); - // stablehlo::ConvolutionOp's optional attributes, leave them as default - DenseIntElementsAttr stablehloLhsDilation; + // mhlo::ConvolutionOp's optional attributes, leave them as default + DenseIntElementsAttr mhloLhsDilation; DenseElementsAttr windowReversal; ArrayAttr precisionConfig; - auto stablehloConvOp = rewriter.create( - op->getLoc(), outType, input, weight, stablehloWindowStride, - stablehloPadding, stablehloLhsDilation, stablehloRhsDilation, - windowReversal, dimensionNumbers, static_cast(groups), 1, - precisionConfig); + auto mhloConvOp = rewriter.create( + op->getLoc(), outType, input, weight, mhloWindowStride, mhloPadding, + mhloLhsDilation, mhloRhsDilation, windowReversal, dimensionNumbers, + static_cast(groups), 1, precisionConfig); - return stablehloConvOp.getResult(); + return mhloConvOp.getResult(); } LogicalResult @@ -761,22 +754,21 @@ public: } } - Value stablehloConvResult; + Value mhloConvResult; if (transposed) { - stablehloConvResult = convertTransposedConv( + mhloConvResult = convertTransposedConv( op, rewriter, outTy, input, weight, stride, padding, dilation, outputPadding, groups, needHandleOutputPadding); } else { - stablehloConvResult = - convertNormalConv(op, rewriter, outTy, input, weight, stride, padding, - dilation, groups); + mhloConvResult = convertNormalConv(op, rewriter, outTy, input, weight, + stride, padding, dilation, groups); } auto bias = adaptor.getBias(); // No bias provided if (failed(checkNotNone(rewriter, op, op.getBias()))) { - rewriter.replaceOp(op, stablehloConvResult); + rewriter.replaceOp(op, mhloConvResult); return success(); } @@ -798,21 +790,21 @@ public: llvm::to_vector<4>(llvm::seq(-nSpatialDims, 0)); const auto &options = getOptions(); - bias = *hlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims, - options.dimSizeIndexBits); - bias = hlo::promoteType(rewriter, bias, outTy); + bias = *mhlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims, + options.dimSizeIndexBits); + bias = mhlo::promoteType(rewriter, bias, outTy); DenseIntElementsAttr bcastDimensions; - rewriter.replaceOpWithNewOp( - op, outTy, stablehloConvResult, bias, bcastDimensions); + rewriter.replaceOpWithNewOp(op, outTy, mhloConvResult, + bias, bcastDimensions); return success(); } }; } // namespace -void mlir::torch::torch_to_stablehlo::populateLinearOpPatternsAndLegality( +void mlir::torch::torch_to_mhlo::populateLinearOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target, const TorchToStablehloOptions &options) { + ConversionTarget &target, const TorchToMhloOptions &options) { MLIRContext *context = patterns.getContext(); #define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \ diff --git a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp b/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp similarity index 84% rename from lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp rename to lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp index dbcfba2ff..b9fb00aff 100644 --- a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp +++ b/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp @@ -7,12 +7,11 @@ // //===----------------------------------------------------------------------===// -#include "StablehloLegalizeUtils.h" - +#include "./MhloLegalizeUtils.h" +#include "mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "stablehlo/dialect/StablehloOps.h" -#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" +#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include @@ -22,27 +21,27 @@ using namespace mlir::torch; using namespace mlir::torch::Torch; namespace mlir { -namespace hlo { +namespace mhlo { // Create a 32-bit float constant operator from a float -Value getStablehloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op, - float val) { +Value getMhloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op, + float val) { auto const_type = RankedTensorType::get({}, rewriter.getF32Type()); auto const_attr = DenseElementsAttr::get(const_type, val); - auto const_op = rewriter.create( - op->getLoc(), const_type, const_attr); + auto const_op = + rewriter.create(op->getLoc(), const_type, const_attr); return const_op.getResult(); } // Create a 64-bit float constant operator from a double -Value getStablehloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op, - double val) { +Value getMhloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op, + double val) { auto const_type = RankedTensorType::get({}, rewriter.getF64Type()); auto const_attr = DenseElementsAttr::get(const_type, val); - auto const_op = rewriter.create( - op->getLoc(), const_type, const_attr); + auto const_op = + rewriter.create(op->getLoc(), const_type, const_attr); return const_op.getResult(); } @@ -66,8 +65,8 @@ std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, RankedTensorType::get(shape, rewriter.getIntegerType(sizeof(T) * 8)); auto const_attr = DenseElementsAttr::get(const_type, vec); - auto const_op = rewriter.create( - op->getLoc(), const_type, const_attr); + auto const_op = + rewriter.create(op->getLoc(), const_type, const_attr); return const_op.getResult(); } @@ -89,8 +88,8 @@ std::optional getConstTensor(PatternRewriter &rewriter, shape, rewriter.getIntegerType(vec[0].getBitWidth())); auto const_attr = DenseElementsAttr::get(const_type, vec); - auto const_op = rewriter.create( - op->getLoc(), const_type, const_attr); + auto const_op = + rewriter.create(op->getLoc(), const_type, const_attr); return const_op.getResult(); } @@ -112,8 +111,8 @@ std::optional getConstTensor(PatternRewriter &rewriter, auto const_type = RankedTensorType::get(shape, rewriter.getF32Type()); auto const_attr = DenseElementsAttr::get(const_type, vec); - auto const_op = rewriter.create( - op->getLoc(), const_type, const_attr); + auto const_op = + rewriter.create(op->getLoc(), const_type, const_attr); return const_op.getResult(); } @@ -134,8 +133,8 @@ std::optional getConstTensor(PatternRewriter &rewriter, auto const_type = RankedTensorType::get(shape, rewriter.getF64Type()); auto const_attr = DenseElementsAttr::get(const_type, vec); - auto const_op = rewriter.create( - op->getLoc(), const_type, const_attr); + auto const_op = + rewriter.create(op->getLoc(), const_type, const_attr); return const_op.getResult(); } @@ -170,18 +169,18 @@ Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op, T val, Type dtype, llvm::ArrayRef dshape) { auto const_type = RankedTensorType::get(dshape, dtype); auto const_attr = SplatElementsAttr::get(const_type, val); - auto const_op = rewriter.create( - op->getLoc(), const_type, const_attr); + auto const_op = + rewriter.create(op->getLoc(), const_type, const_attr); return const_op.getResult(); } -Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter, - Operation *op, Value scalarValue, Type dtype) { +Value scalarToMhloTensor(ConversionPatternRewriter &rewriter, Operation *op, + Value scalarValue, Type dtype) { auto tensor = rewriter.create( op->getLoc(), ArrayRef{scalarValue}); auto dtype_tensor = - rewriter.create(op->getLoc(), tensor, dtype); - return rewriter.create( + rewriter.create(op->getLoc(), tensor, dtype); + return rewriter.create( op->getLoc(), RankedTensorType::get(mlir::ArrayRef{}, dtype), dtype_tensor); } @@ -193,8 +192,7 @@ Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) { if (in_type.getElementType() != outType.getElementType()) { TensorType promotedType = in_type.cloneWith(in_type.getShape(), outType.getElementType()); - return rewriter.create(op->getLoc(), promotedType, - input); + return rewriter.create(op->getLoc(), promotedType, input); } return input; } @@ -212,8 +210,8 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input, if (in_type.getElementType() != outType.getElementType()) { TensorType promoted_type = in_type.cloneWith(in_type.getShape(), outType.getElementType()); - input = rewriter.create(op->getLoc(), promoted_type, - input); + input = + rewriter.create(op->getLoc(), promoted_type, input); } ArrayRef inShape = in_type.getShape(); @@ -247,8 +245,8 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input, RankedTensorType::get({static_cast(bcastDims.size())}, rewriter.getI64Type()), bcastDims); - auto bcast_op = rewriter.create( - op->getLoc(), outType, input, bcast_attr); + auto bcast_op = rewriter.create(op->getLoc(), outType, + input, bcast_attr); return bcast_op.getResult(); } @@ -350,8 +348,8 @@ FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, } auto outTy = RankedTensorType::get(newShape, rankTy.getElementType()); - auto shape = rewriter.create(loc, newDimSizes); - return rewriter.create(loc, outTy, tensor, shape) + auto mhloShape = rewriter.create(loc, newDimSizes); + return rewriter.create(loc, outTy, tensor, mhloShape) .getResult(); } @@ -359,11 +357,11 @@ Value getConstantOfShape(PatternRewriter &rewriter, Location loc, const APFloat &constant, Value shape, TensorType outType) { auto constAttr = rewriter.getFloatAttr(outType.getElementType(), constant); - auto constTensor = rewriter.create(loc, constAttr); + auto constTensor = rewriter.create(loc, constAttr); return rewriter - .create( - loc, outType, constTensor, shape, rewriter.getI64TensorAttr({})) + .create(loc, outType, constTensor, shape, + rewriter.getI64TensorAttr({})) .getResult(); } -} // namespace hlo +} // namespace mhlo } // namespace mlir diff --git a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h b/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h similarity index 79% rename from lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h rename to lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h index 6d31d267a..dc7daa42d 100644 --- a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h +++ b/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h @@ -7,8 +7,8 @@ // //===----------------------------------------------------------------------===// -#ifndef TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_STABLEHLOLEGALIZEUTILS_H -#define TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_STABLEHLOLEGALIZEUTILS_H +#ifndef TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H +#define TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" @@ -18,22 +18,22 @@ #include "mlir/Transforms/DialectConversion.h" namespace mlir { -namespace hlo { +namespace mhlo { using mlir::ConversionPatternRewriter; // Create a 32-bit float constant operator from a float -Value getStablehloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op, - float val); +Value getMhloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op, + float val); // Create a 64-bit float constant operator from a double -Value getStablehloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op, - double val); +Value getMhloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op, + double val); // Templated function to create a constant op for given type and shape. // T: storage C type. // Default template creates a constant tensor in T. -// To create INT48 StableHLO constant, need to pass in llvm::APInt instead. +// To create INT48 MHLO constant, need to pass in llvm::APInt instead. template std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, ArrayRef vec, ArrayRef shape); @@ -42,8 +42,8 @@ template Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op, T val, Type dtype, llvm::ArrayRef dshape); -Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter, - Operation *op, Value scalarValue, Type dtype); +Value scalarToMhloTensor(ConversionPatternRewriter &rewriter, Operation *op, + Value scalarValue, Type dtype); Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType); @@ -71,7 +71,7 @@ FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, Value getConstantOfShape(PatternRewriter &rewriter, Location loc, const APFloat &constant, Value shape, TensorType outType); -} // namespace hlo +} // namespace mhlo } // namespace mlir -#endif // TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_STABLEHLOLEGALIZEUTILS_H +#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H diff --git a/lib/Conversion/TorchToStablehlo/Pooling.cpp b/lib/Conversion/TorchToMhlo/Pooling.cpp similarity index 63% rename from lib/Conversion/TorchToStablehlo/Pooling.cpp rename to lib/Conversion/TorchToMhlo/Pooling.cpp index 90044cc8b..8262ca6d3 100644 --- a/lib/Conversion/TorchToStablehlo/Pooling.cpp +++ b/lib/Conversion/TorchToMhlo/Pooling.cpp @@ -7,16 +7,15 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" +#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "../PassDetail.h" -#include "PopulatePatterns.h" -#include "StablehloLegalizeUtils.h" - +#include "./MhloLegalizeUtils.h" +#include "./PopulatePatterns.h" +#include "mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "stablehlo/dialect/ChloOps.h" -#include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -29,26 +28,26 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; -using namespace mlir::torch::torch_to_stablehlo; +using namespace mlir::torch::torch_to_mhlo; static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy, PatternRewriter &rewriter) { auto constType = RankedTensorType::get({}, elementTy); // Avg pooling - if (isa(op)) { + if (isa(op)) { if (elementTy.isa()) { auto constAttr = DenseElementsAttr::get( constType, {APFloat::getZero( elementTy.cast().getFloatSemantics(), /*negative=*/false)}); - return rewriter.create(op->getLoc(), constType, - constAttr); + return rewriter.create(op->getLoc(), constType, + constAttr); } else if (elementTy.isa() && elementTy.getIntOrFloatBitWidth() != 8) { auto constAttr = DenseElementsAttr::get( constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())}); - return rewriter.create(op->getLoc(), constType, - constAttr); + return rewriter.create(op->getLoc(), constType, + constAttr); } } @@ -59,15 +58,15 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy, constType, {APFloat::getLargest( elementTy.cast().getFloatSemantics(), /*negative=*/true)}); - return rewriter.create(op->getLoc(), constType, - constAttr); + return rewriter.create(op->getLoc(), constType, + constAttr); } else if (elementTy.isa() && elementTy.getIntOrFloatBitWidth() != 8) { auto constAttr = DenseElementsAttr::get( constType, {APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())}); - return rewriter.create(op->getLoc(), constType, - constAttr); + return rewriter.create(op->getLoc(), constType, + constAttr); } } op->emitError("unimplemented lowering in AtenPoolingOp"); @@ -117,43 +116,42 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // prepend 1 to kernelSize, stride, dilation until they are of same rank as // input - SmallVector stablehloStride(inputRank, 1); - SmallVector stablehloDilation(inputRank, 1); - SmallVector stablehloKernelSize(inputRank, 1); - SmallVector stablehloPadding(inputRank * 2, 0); + SmallVector mhloStride(inputRank, 1); + SmallVector mhloDilation(inputRank, 1); + SmallVector mhloKernelSize(inputRank, 1); + SmallVector mhloPadding(inputRank * 2, 0); std::copy(dilation.begin(), dilation.end(), - stablehloDilation.begin() + inputRank - 2); - std::copy(stride.begin(), stride.end(), - stablehloStride.begin() + inputRank - 2); + mhloDilation.begin() + inputRank - 2); + std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2); std::copy(kernelSize.begin(), kernelSize.end(), - stablehloKernelSize.begin() + inputRank - 2); + mhloKernelSize.begin() + inputRank - 2); Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); - stablehloPadding[stablehloPadding.size() - 4] = padding[0]; - stablehloPadding[stablehloPadding.size() - 3] = padding[0]; - stablehloPadding[stablehloPadding.size() - 2] = padding[1]; - stablehloPadding[stablehloPadding.size() - 1] = padding[1]; + mhloPadding[mhloPadding.size() - 4] = padding[0]; + mhloPadding[mhloPadding.size() - 3] = padding[0]; + mhloPadding[mhloPadding.size() - 2] = padding[1]; + mhloPadding[mhloPadding.size() - 1] = padding[1]; DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloKernelSize.size())}, + RankedTensorType::get({static_cast(mhloKernelSize.size())}, rewriter.getI64Type()), - stablehloKernelSize); + mhloKernelSize); DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloStride.size())}, + RankedTensorType::get({static_cast(mhloStride.size())}, rewriter.getI64Type()), - stablehloStride); + mhloStride); DenseIntElementsAttr baseDilations; DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloDilation.size())}, + RankedTensorType::get({static_cast(mhloDilation.size())}, rewriter.getI64Type()), - stablehloDilation); + mhloDilation); DenseIntElementsAttr pad = DenseIntElementsAttr::get( RankedTensorType::get( {static_cast(inputRank), static_cast(2)}, rewriter.getI64Type()), - stablehloPadding); - auto reduceWindowOp = rewriter.create( + mhloPadding); + auto reduceWindowOp = rewriter.create( op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides, baseDilations, windowDilations, pad); @@ -170,8 +168,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); Value result = - rewriter.create(op->getLoc(), *firstArg, *secondArg); - rewriter.create(op->getLoc(), result); + rewriter.create(op->getLoc(), *firstArg, *secondArg); + rewriter.create(op->getLoc(), result); } rewriter.replaceOp(op, reduceWindowOp.getResults()); @@ -223,46 +221,45 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // prepend 1 to kernelSize, stride, dilation until they are of same rank as // input - SmallVector stablehloStride(inputRank, 1); - SmallVector stablehloDilation(inputRank, 1); - SmallVector stablehloKernelSize(inputRank, 1); - SmallVector stablehloPadding(inputRank * 2, 0); + SmallVector mhloStride(inputRank, 1); + SmallVector mhloDilation(inputRank, 1); + SmallVector mhloKernelSize(inputRank, 1); + SmallVector mhloPadding(inputRank * 2, 0); std::copy(dilation.begin(), dilation.end(), - stablehloDilation.begin() + inputRank - 2); - std::copy(stride.begin(), stride.end(), - stablehloStride.begin() + inputRank - 2); + mhloDilation.begin() + inputRank - 2); + std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2); std::copy(kernelSize.begin(), kernelSize.end(), - stablehloKernelSize.begin() + inputRank - 2); + mhloKernelSize.begin() + inputRank - 2); Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); - stablehloPadding[stablehloPadding.size() - 4] = padding[0]; - stablehloPadding[stablehloPadding.size() - 3] = padding[0]; - stablehloPadding[stablehloPadding.size() - 2] = padding[1]; - stablehloPadding[stablehloPadding.size() - 1] = padding[1]; + mhloPadding[mhloPadding.size() - 4] = padding[0]; + mhloPadding[mhloPadding.size() - 3] = padding[0]; + mhloPadding[mhloPadding.size() - 2] = padding[1]; + mhloPadding[mhloPadding.size() - 1] = padding[1]; DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloKernelSize.size())}, + RankedTensorType::get({static_cast(mhloKernelSize.size())}, rewriter.getI64Type()), - stablehloKernelSize); + mhloKernelSize); DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloStride.size())}, + RankedTensorType::get({static_cast(mhloStride.size())}, rewriter.getI64Type()), - stablehloStride); + mhloStride); DenseIntElementsAttr baseDilations; DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloDilation.size())}, + RankedTensorType::get({static_cast(mhloDilation.size())}, rewriter.getI64Type()), - stablehloDilation); + mhloDilation); DenseIntElementsAttr pad = DenseIntElementsAttr::get( RankedTensorType::get( {static_cast(inputRank), static_cast(2)}, rewriter.getI64Type()), - stablehloPadding); + mhloPadding); const auto &options = getOptions(); auto inputShapeInfo = - hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); + mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); if (failed(inputShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); @@ -292,7 +289,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto initIndexTensor = rewriter - .create( + .create( op->getLoc(), RankedTensorType::get(initIndexShapeForType, rewriter.getI64Type()), @@ -301,15 +298,15 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto indexTensor = rewriter - .create( + .create( op->getLoc(), RankedTensorType::get(inputShape, rewriter.getI64Type()), initIndexTensor, inputShapeTensor) .getResult(); - Value initIdx = hlo::getConstTensor(rewriter, op, {0}, {}).value(); + Value initIdx = mhlo::getConstTensor(rewriter, op, {0}, {}).value(); - auto reduceWindowOp = rewriter.create( + auto reduceWindowOp = rewriter.create( op->getLoc(), mlir::TypeRange{outValTy, outIdxTy}, mlir::ValueRange{input, indexTensor}, mlir::ValueRange{initVal, initIdx}, windowDimensions, windowStrides, baseDilations, windowDilations, pad); @@ -329,43 +326,43 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto *secondValArg = std::next(firstIdxArg); auto *secondIdxArg = std::next(secondValArg); - stablehlo::ComparisonTypeAttr compareTypeAttr; + mhlo::ComparisonTypeAttr compareTypeAttr; if (inputTy.getElementType().isa()) { - compareTypeAttr = stablehlo::ComparisonTypeAttr::get( - rewriter.getContext(), stablehlo::ComparisonType::FLOAT); + compareTypeAttr = mhlo::ComparisonTypeAttr::get( + rewriter.getContext(), mhlo::ComparisonType::FLOAT); } else if (inputTy.getElementType().isa()) { - compareTypeAttr = stablehlo::ComparisonTypeAttr::get( - rewriter.getContext(), stablehlo::ComparisonType::SIGNED); + compareTypeAttr = mhlo::ComparisonTypeAttr::get( + rewriter.getContext(), mhlo::ComparisonType::SIGNED); } - stablehlo::ComparisonDirectionAttr compareGeDirectionAttr = - stablehlo::ComparisonDirectionAttr::get( - rewriter.getContext(), stablehlo::ComparisonDirection::GE); - stablehlo::ComparisonDirectionAttr compareEqDirectionAttr = - stablehlo::ComparisonDirectionAttr::get( - rewriter.getContext(), stablehlo::ComparisonDirection::EQ); + mhlo::ComparisonDirectionAttr compareGeDirectionAttr = + mhlo::ComparisonDirectionAttr::get(rewriter.getContext(), + mhlo::ComparisonDirection::GE); + mhlo::ComparisonDirectionAttr compareEqDirectionAttr = + mhlo::ComparisonDirectionAttr::get(rewriter.getContext(), + mhlo::ComparisonDirection::EQ); { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); - Value compareGeResult = rewriter.create( + Value compareGeResult = rewriter.create( op->getLoc(), compareResultType, *firstValArg, *secondValArg, compareGeDirectionAttr, compareTypeAttr); - Value retValResult = rewriter.create( + Value retValResult = rewriter.create( op->getLoc(), compareGeResult, *firstValArg, *secondValArg); // Get smaller index if compared values are equal. - Value compareEqResult = rewriter.create( + Value compareEqResult = rewriter.create( op->getLoc(), compareResultType, *firstValArg, *secondValArg, compareEqDirectionAttr, compareTypeAttr); - Value minIdx = rewriter.create(op->getLoc(), *firstIdxArg, - *secondIdxArg); - Value idxWithGeVal = rewriter.create( + Value minIdx = + rewriter.create(op->getLoc(), *firstIdxArg, *secondIdxArg); + Value idxWithGeVal = rewriter.create( op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg); - Value retIdxResult = rewriter.create( + Value retIdxResult = rewriter.create( op->getLoc(), compareEqResult, minIdx, idxWithGeVal); - rewriter.create( + rewriter.create( op->getLoc(), mlir::ValueRange{retValResult, retIdxResult}); } @@ -422,42 +419,41 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // prepend 1 to kernelSize, stride, dilation until they are of same rank as // input - SmallVector stablehloStride(inputRank, 1); - SmallVector stablehloDilation(inputRank, 1); - SmallVector stablehloKernelSize(inputRank, 1); - SmallVector stablehloPadding(inputRank * 2, 0); + SmallVector mhloStride(inputRank, 1); + SmallVector mhloDilation(inputRank, 1); + SmallVector mhloKernelSize(inputRank, 1); + SmallVector mhloPadding(inputRank * 2, 0); - std::copy(stride.begin(), stride.end(), - stablehloStride.begin() + inputRank - 2); + std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2); std::copy(kernelSize.begin(), kernelSize.end(), - stablehloKernelSize.begin() + inputRank - 2); - stablehloPadding[stablehloPadding.size() - 4] = padding[0]; - stablehloPadding[stablehloPadding.size() - 3] = padding[0]; - stablehloPadding[stablehloPadding.size() - 2] = padding[1]; - stablehloPadding[stablehloPadding.size() - 1] = padding[1]; + mhloKernelSize.begin() + inputRank - 2); + mhloPadding[mhloPadding.size() - 4] = padding[0]; + mhloPadding[mhloPadding.size() - 3] = padding[0]; + mhloPadding[mhloPadding.size() - 2] = padding[1]; + mhloPadding[mhloPadding.size() - 1] = padding[1]; Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloKernelSize.size())}, + RankedTensorType::get({static_cast(mhloKernelSize.size())}, rewriter.getI64Type()), - stablehloKernelSize); + mhloKernelSize); DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloStride.size())}, + RankedTensorType::get({static_cast(mhloStride.size())}, rewriter.getI64Type()), - stablehloStride); + mhloStride); DenseIntElementsAttr baseDilations; DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloDilation.size())}, + RankedTensorType::get({static_cast(mhloDilation.size())}, rewriter.getI64Type()), - stablehloDilation); + mhloDilation); DenseIntElementsAttr pad = DenseIntElementsAttr::get( RankedTensorType::get( {static_cast(inputRank), static_cast(2)}, rewriter.getI64Type()), - stablehloPadding); + mhloPadding); - auto reduceWindowSum = rewriter.create( + auto reduceWindowSum = rewriter.create( op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides, baseDilations, windowDilations, pad); @@ -475,39 +471,39 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.setInsertionPointToStart(&sumBlock); Value sumResult = - rewriter.create(op->getLoc(), *firstArg, *secondArg); - rewriter.create(op->getLoc(), sumResult); + rewriter.create(op->getLoc(), *firstArg, *secondArg); + rewriter.create(op->getLoc(), sumResult); } // Use kernel size as the divisor if (countIncludePad) { - Value divisor = hlo::getConstTensor( + Value divisor = mhlo::getConstTensor( rewriter, op, {kernelSize[0] * kernelSize[1]}, {}) .value(); - divisor = hlo::promoteType(rewriter, divisor, outTy); + divisor = mhlo::promoteType(rewriter, divisor, outTy); DenseIntElementsAttr bcastDimensions; rewriter.replaceOpWithNewOp( op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions); return success(); } - // Use another stablehlo.ReduceWindowOp to get the divisor + // Use another mhlo.ReduceWindowOp to get the divisor Value windowSizeConst = - hlo::getConstTensor(rewriter, op, {1.0}, {}).value(); - windowSizeConst = hlo::promoteType(rewriter, windowSizeConst, outTy); + mhlo::getConstTensor(rewriter, op, {1.0}, {}).value(); + windowSizeConst = mhlo::promoteType(rewriter, windowSizeConst, outTy); const auto &options = getOptions(); auto inputShapeVec = - *hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); + *mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); auto inputShapeTensor = rewriter.create( op->getLoc(), inputShapeVec); - windowSizeConst = rewriter.create( + windowSizeConst = rewriter.create( op->getLoc(), RankedTensorType::get(inputTy.getShape(), outTy.getElementType()), windowSizeConst, inputShapeTensor, rewriter.getI64TensorAttr({})); Value zero = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); - auto reduceWindowSize = rewriter.create( + auto reduceWindowSize = rewriter.create( op->getLoc(), RankedTensorType::get(outShape, inputElemTy), windowSizeConst, zero, windowDimensions, windowStrides, baseDilations, windowDilations, pad); @@ -526,99 +522,18 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.setInsertionPointToStart(&sizeBlock); Value sumResult = - rewriter.create(op->getLoc(), *firstArg, *secondArg); - rewriter.create(op->getLoc(), sumResult); + rewriter.create(op->getLoc(), *firstArg, *secondArg); + rewriter.create(op->getLoc(), sumResult); } - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, outTy, reduceWindowSum.getResult(0), reduceWindowSize.getResult(0)); return success(); } -// AtenCumsumOp -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenCumsumOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value input = adaptor.getSelf(); - auto inputTy = input.getType().cast(); - auto inputElemTy = inputTy.getElementType(); - auto inputRank = inputTy.getRank(); - auto inputShape = inputTy.getShape(); - auto outTy = - getTypeConverter()->convertType(op.getType()).cast(); - - int64_t dim; - if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) { - return rewriter.notifyMatchFailure( - op, "unimplemented: dim must be a constant int"); - } - dim = toPositiveDim(dim, inputRank); - if (!isValidDim(dim, inputRank)) { - return rewriter.notifyMatchFailure(op, "dim is out of range"); - } - if (inputTy.isDynamicDim(dim)) { - return rewriter.notifyMatchFailure( - op, "unimplemented: cumsum dim must be static"); - } - - Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); - - SmallVector stablehloKernelSize(inputRank, 1); - stablehloKernelSize[dim] = inputShape[dim]; - SmallVector stablehloStride(inputRank, 1); - SmallVector stablehloDilation(inputRank, 1); - SmallVector stablehloPadding(inputRank * 2, 0); - stablehloPadding[dim * 2] = inputShape[dim] - 1; - - DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloKernelSize.size())}, - rewriter.getI64Type()), - stablehloKernelSize); - DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloStride.size())}, - rewriter.getI64Type()), - stablehloStride); - DenseIntElementsAttr baseDilations; - DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloDilation.size())}, - rewriter.getI64Type()), - stablehloDilation); - DenseIntElementsAttr pad = DenseIntElementsAttr::get( - RankedTensorType::get( - {static_cast(inputRank), static_cast(2)}, - rewriter.getI64Type()), - stablehloPadding); - - auto reduceWindowSum = rewriter.create( - op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides, - baseDilations, windowDilations, pad); - - Block &sumBlock = reduceWindowSum.getBody().emplaceBlock(); - - // Add bb argument - auto blockArgumentType = RankedTensorType::get({}, inputElemTy); - sumBlock.addArgument(blockArgumentType, op->getLoc()); - sumBlock.addArgument(blockArgumentType, op->getLoc()); - auto *firstArg = sumBlock.args_begin(); - auto *secondArg = std::next(firstArg); - - { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&sumBlock); - - Value sumResult = - rewriter.create(op->getLoc(), *firstArg, *secondArg); - rewriter.create(op->getLoc(), sumResult); - } - - rewriter.replaceOp(op, reduceWindowSum.getResults()); - return success(); -} - -void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality( +void mlir::torch::torch_to_mhlo::populatePoolingOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target, const TorchToStablehloOptions &options) { + ConversionTarget &target, const TorchToMhloOptions &options) { MLIRContext *context = patterns.getContext(); target.addIllegalOp(); patterns.add>(typeConverter, context, options); @@ -627,6 +542,4 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality( target.addIllegalOp(); patterns.add>(typeConverter, context, options); - target.addIllegalOp(); - patterns.add>(typeConverter, context, options); } diff --git a/lib/Conversion/TorchToMhlo/PopulatePatterns.h b/lib/Conversion/TorchToMhlo/PopulatePatterns.h new file mode 100644 index 000000000..2e195a87f --- /dev/null +++ b/lib/Conversion/TorchToMhlo/PopulatePatterns.h @@ -0,0 +1,74 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H +#define TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H + +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace torch { +namespace torch_to_mhlo { + +struct TorchToMhloOptions { + bool enableStaticShape = false; + size_t dimSizeIndexBits = 64; +}; + +template +class ConvertAtenOp : public OpConversionPattern { +public: + using OpAdaptor = typename AtenOpT::Adaptor; + ConvertAtenOp(TypeConverter &typeConverter, MLIRContext *context, + const TorchToMhloOptions &options) + : OpConversionPattern(typeConverter, context) { + this->options = options; + } + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + return rewriter.notifyMatchFailure(op, "haven't been implemented"); + } + const TorchToMhloOptions &getOptions() const { return options; } + +private: + TorchToMhloOptions options; +}; + +void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter, + RewritePatternSet &patterns, + ConversionTarget &target, + const TorchToMhloOptions &options); +void populateViewLikeOpPatternsAndLegality(TypeConverter &typeConverter, + RewritePatternSet &patterns, + ConversionTarget &target, + const TorchToMhloOptions &options); +void populateGatherOpPatternsAndLegality(TypeConverter &typeConverter, + RewritePatternSet &patterns, + ConversionTarget &target, + const TorchToMhloOptions &options); +void populateReductionOpPatternsAndLegality(TypeConverter &typeConverter, + RewritePatternSet &patterns, + ConversionTarget &target, + const TorchToMhloOptions &options); +void populateLinearOpPatternsAndLegality(TypeConverter &typeConverter, + RewritePatternSet &patterns, + ConversionTarget &target, + const TorchToMhloOptions &options); + +void populatePoolingOpPatternsAndLegality(TypeConverter &typeConverter, + RewritePatternSet &patterns, + ConversionTarget &target, + const TorchToMhloOptions &options); + +} // namespace torch_to_mhlo +} // namespace torch +} // namespace mlir + +#endif // TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToMhlo/Reduction.cpp similarity index 73% rename from lib/Conversion/TorchToStablehlo/Reduction.cpp rename to lib/Conversion/TorchToMhlo/Reduction.cpp index eb4e11116..20a6c377e 100644 --- a/lib/Conversion/TorchToStablehlo/Reduction.cpp +++ b/lib/Conversion/TorchToMhlo/Reduction.cpp @@ -7,15 +7,14 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" +#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "../PassDetail.h" -#include "PopulatePatterns.h" -#include "StablehloLegalizeUtils.h" - +#include "./MhloLegalizeUtils.h" +#include "./PopulatePatterns.h" +#include "mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -26,7 +25,7 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; -using namespace mlir::torch::torch_to_stablehlo; +using namespace mlir::torch::torch_to_mhlo; static Value createInitialValueForReduceOp(Operation *op, Type elementTy, PatternRewriter &rewriter) { @@ -37,14 +36,14 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy, constType, {APFloat::getZero( elementTy.cast().getFloatSemantics(), /*negative=*/false)}); - return rewriter.create(op->getLoc(), constType, - constAttr); + return rewriter.create(op->getLoc(), constType, + constAttr); } else if (elementTy.isa() && elementTy.getIntOrFloatBitWidth() != 8) { auto constAttr = DenseElementsAttr::get( constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())}); - return rewriter.create(op->getLoc(), constType, - constAttr); + return rewriter.create(op->getLoc(), constType, + constAttr); } } @@ -54,15 +53,15 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy, constType, {APFloat::getLargest( elementTy.cast().getFloatSemantics(), /*negative=*/true)}); - return rewriter.create(op->getLoc(), constType, - constAttr); + return rewriter.create(op->getLoc(), constType, + constAttr); } else if (elementTy.isa() && elementTy.getIntOrFloatBitWidth() != 8) { auto constAttr = DenseElementsAttr::get( constType, {APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())}); - return rewriter.create(op->getLoc(), constType, - constAttr); + return rewriter.create(op->getLoc(), constType, + constAttr); } } @@ -91,9 +90,9 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, return std::nullopt; Value initIndex; if (dimSizeIndexBits == 32) { - initIndex = hlo::getConstTensor(rewriter, op, {0}, {}).value(); + initIndex = mhlo::getConstTensor(rewriter, op, {0}, {}).value(); } else { - initIndex = hlo::getConstTensor(rewriter, op, {0}, {}).value(); + initIndex = mhlo::getConstTensor(rewriter, op, {0}, {}).value(); } DenseIntElementsAttr dimensions = DenseIntElementsAttr::get( @@ -101,13 +100,13 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, auto inputShapeTensor = rewriter.create( op->getLoc(), inputShapeVec); - auto indexTensor = rewriter.create( + auto indexTensor = rewriter.create( op->getLoc(), RankedTensorType::get(inputShape, rewriter.getIntegerType(dimSizeIndexBits)), inputShapeTensor, static_cast(dim)); - auto stablehloReduceOp = rewriter.create( + auto mhloReduceOp = rewriter.create( op->getLoc(), ValueRange{input, indexTensor}, ValueRange{ initValue, @@ -115,7 +114,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, }, dimensions); - Block &block = stablehloReduceOp.getBody().emplaceBlock(); + Block &block = mhloReduceOp.getBody().emplaceBlock(); // Add block arguments auto blockValArgumentType = @@ -134,46 +133,46 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, auto *secondValArg = std::next(firstIdxArg); auto *secondIdxArg = std::next(secondValArg); - stablehlo::ComparisonTypeAttr compareTypeAttr; + mhlo::ComparisonTypeAttr compareTypeAttr; if (inputTy.getElementType().isa()) { - compareTypeAttr = stablehlo::ComparisonTypeAttr::get( - rewriter.getContext(), stablehlo::ComparisonType::FLOAT); + compareTypeAttr = mhlo::ComparisonTypeAttr::get( + rewriter.getContext(), mhlo::ComparisonType::FLOAT); } else if (inputTy.getElementType().isa()) { - compareTypeAttr = stablehlo::ComparisonTypeAttr::get( - rewriter.getContext(), stablehlo::ComparisonType::SIGNED); + compareTypeAttr = mhlo::ComparisonTypeAttr::get( + rewriter.getContext(), mhlo::ComparisonType::SIGNED); } - stablehlo::ComparisonDirectionAttr compareGeDirectionAttr = - stablehlo::ComparisonDirectionAttr::get( - rewriter.getContext(), stablehlo::ComparisonDirection::GE); - stablehlo::ComparisonDirectionAttr compareEqDirectionAttr = - stablehlo::ComparisonDirectionAttr::get( - rewriter.getContext(), stablehlo::ComparisonDirection::EQ); + mhlo::ComparisonDirectionAttr compareGeDirectionAttr = + mhlo::ComparisonDirectionAttr::get(rewriter.getContext(), + mhlo::ComparisonDirection::GE); + mhlo::ComparisonDirectionAttr compareEqDirectionAttr = + mhlo::ComparisonDirectionAttr::get(rewriter.getContext(), + mhlo::ComparisonDirection::EQ); { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); - Value compareGeResult = rewriter.create( + Value compareGeResult = rewriter.create( op->getLoc(), compareResultType, *firstValArg, *secondValArg, compareGeDirectionAttr, compareTypeAttr); - Value retValResult = rewriter.create( + Value retValResult = rewriter.create( op->getLoc(), compareGeResult, *firstValArg, *secondValArg); // get smaller index value if compared nums are equal. - Value compareEqResult = rewriter.create( + Value compareEqResult = rewriter.create( op->getLoc(), compareResultType, *firstValArg, *secondValArg, compareEqDirectionAttr, compareTypeAttr); - Value minIdx = rewriter.create(op->getLoc(), *firstIdxArg, - *secondIdxArg); - Value idxWithGeVal = rewriter.create( + Value minIdx = + rewriter.create(op->getLoc(), *firstIdxArg, *secondIdxArg); + Value idxWithGeVal = rewriter.create( op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg); - Value retIdxResult = rewriter.create( + Value retIdxResult = rewriter.create( op->getLoc(), compareEqResult, minIdx, idxWithGeVal); - rewriter.create( + rewriter.create( op->getLoc(), mlir::ValueRange{retValResult, retIdxResult}); } - return stablehloReduceOp.getResults(); + return mhloReduceOp.getResults(); } namespace { @@ -197,8 +196,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( Value input = adaptor.getSelf(); auto inputTy = input.getType().template cast(); if (!inputTy) { - return rewriter.notifyMatchFailure( - op, "only Tensor types supported in StableHLO"); + return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO"); } auto inputElemTy = inputTy.getElementType(); @@ -211,7 +209,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( inputElemTy.getIntOrFloatBitWidth() == 8) { return rewriter.notifyMatchFailure( op, "IntegerType with bitwidth 8 unsupported in convertion from " - "AtenArgmaxOp to StableHLO"); + "AtenArgmaxOp to MHLO"); } int64_t dim; @@ -230,15 +228,15 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( const auto &options = getOptions(); auto inputShapeInfo = - hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); + mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); if (failed(inputShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); } auto inputShapeVec = *inputShapeInfo; - auto stablehloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, - dim, options.dimSizeIndexBits) - .value(); + auto mhloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, dim, + options.dimSizeIndexBits) + .value(); if (keepDim) { auto outShapeVec = inputShapeVec; @@ -249,13 +247,13 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( auto outShapeTensor = rewriter.create( op->getLoc(), outShapeVec); - rewriter.replaceOpWithNewOp( - op, typeConverter->convertType(op.getType()), stablehloReduceResults[1], + rewriter.replaceOpWithNewOp( + op, typeConverter->convertType(op.getType()), mhloReduceResults[1], outShapeTensor); return success(); } - rewriter.replaceOp(op, stablehloReduceResults[1]); + rewriter.replaceOp(op, mhloReduceResults[1]); return success(); } } // namespace @@ -269,8 +267,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( Value input = adaptor.getSelf(); auto inputTy = input.getType().template dyn_cast(); if (!inputTy) { - return rewriter.notifyMatchFailure( - op, "only Tensor types supported in StableHLO"); + return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO"); } auto inputElemTy = inputTy.getElementType(); if (!inputElemTy.isIntOrFloat()) { @@ -282,7 +279,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( inputElemTy.getIntOrFloatBitWidth() == 8) { return rewriter.notifyMatchFailure( op, "IntegerType with bitwidth 8 unsupported in convertion from " - "AtenMaxDimOp to StableHLO"); + "AtenMaxDimOp to MHLO"); } RankedTensorType valResultType = getTypeConverter() @@ -311,15 +308,15 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( const auto &options = getOptions(); auto inputShapeInfo = - hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); + mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); if (failed(inputShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); } auto inputShapeVec = *inputShapeInfo; - auto stablehloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, - dim, options.dimSizeIndexBits) - .value(); + auto mhloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, dim, + options.dimSizeIndexBits) + .value(); if (keepDim) { auto outShapeVec = inputShapeVec; @@ -330,21 +327,15 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( auto outShapeTensor = rewriter.create( op->getLoc(), outShapeVec); - auto stablehloReduceValueResult = - rewriter.create( - op->getLoc(), valResultType, stablehloReduceResults[0], - outShapeTensor); - auto stablehloReduceIndexResult = - rewriter.create( - op->getLoc(), idxResultType, stablehloReduceResults[1], - outShapeTensor); - rewriter.replaceOp( - op, {stablehloReduceValueResult, stablehloReduceIndexResult}); + auto mhloReduceValueResult = rewriter.create( + op->getLoc(), valResultType, mhloReduceResults[0], outShapeTensor); + auto mhloReduceIndexResult = rewriter.create( + op->getLoc(), idxResultType, mhloReduceResults[1], outShapeTensor); + rewriter.replaceOp(op, {mhloReduceValueResult, mhloReduceIndexResult}); return success(); } - rewriter.replaceOp(op, - {stablehloReduceResults[0], stablehloReduceResults[1]}); + rewriter.replaceOp(op, {mhloReduceResults[0], mhloReduceResults[1]}); return success(); } } // namespace @@ -361,14 +352,12 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( ->convertType(op.getType()) .template dyn_cast(); if (!inputTy) { - return rewriter.notifyMatchFailure( - op, "only Tensor types supported in StableHLO"); + return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO"); } if (inputTy.getElementType() != outTy.getElementType()) { // Use output element type as computation type. auto dstElemTy = outTy.getElementType(); - input = - rewriter.create(op->getLoc(), input, dstElemTy); + input = rewriter.create(op->getLoc(), input, dstElemTy); inputTy = input.getType().dyn_cast(); } auto inputElemTy = inputTy.getElementType(); @@ -381,7 +370,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( inputElemTy.getIntOrFloatBitWidth() == 8) { return rewriter.notifyMatchFailure( op, "IntegerType with bitwidth 8 unsupported in convertion from " - "AtenSumOp to StableHLO"); + "AtenSumOp to MHLO"); } SmallVector dims; @@ -390,14 +379,13 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( } Value initValue = createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); - if (!initValue) - return failure(); + if (!initValue) return failure(); llvm::sort(dims.begin(), dims.end()); - auto stablehloReduceOp = rewriter.create( + auto mhloReduceOp = rewriter.create( op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims)); - Block &block = stablehloReduceOp.getBody().emplaceBlock(); + Block &block = mhloReduceOp.getBody().emplaceBlock(); auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); block.addArgument(blockArgumentTy, op->getLoc()); @@ -409,13 +397,13 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); - Value addResult = rewriter.create( + Value addResult = rewriter.create( op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); - rewriter.create(op->getLoc(), addResult); + rewriter.create(op->getLoc(), addResult); } rewriter.replaceOpWithNewOp(op, outTy, - stablehloReduceOp.getResults()); + mhloReduceOp.getResults()); return success(); } } // namespace @@ -429,8 +417,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( Value input = adaptor.getSelf(); auto inputTy = input.getType().dyn_cast(); if (!inputTy) { - return rewriter.notifyMatchFailure( - op, "only Tensor types supported in StableHLO"); + return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO"); } auto inputElemTy = inputTy.getElementType(); if (!inputElemTy.isIntOrFloat()) { @@ -442,7 +429,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( inputElemTy.getIntOrFloatBitWidth() == 8) { return rewriter.notifyMatchFailure( op, "IntegerType with bitwidth 8 unsupported in convertion from " - "AtenMaxOp to StableHLO"); + "AtenMaxOp to MHLO"); } SmallVector dims; @@ -452,13 +439,12 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( Value initValue = createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); - if (!initValue) - return failure(); + if (!initValue) return failure(); llvm::sort(dims.begin(), dims.end()); - auto stablehloReduceOp = rewriter.create( + auto mhloReduceOp = rewriter.create( op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims)); - Block &block = stablehloReduceOp.getBody().emplaceBlock(); + Block &block = mhloReduceOp.getBody().emplaceBlock(); auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); block.addArgument(blockArgumentTy, op->getLoc()); @@ -470,14 +456,14 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); - Value maxResult = rewriter.create( + Value maxResult = rewriter.create( op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); - rewriter.create(op->getLoc(), maxResult); + rewriter.create(op->getLoc(), maxResult); } rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), - stablehloReduceOp.getResults()); + mhloReduceOp.getResults()); return success(); } } // namespace @@ -494,14 +480,12 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( ->convertType(op.getType()) .template dyn_cast(); if (!inputTy) { - return rewriter.notifyMatchFailure( - op, "only Tensor types supported in StableHLO"); + return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO"); } if (inputTy.getElementType() != outTy.getElementType()) { // Use output element type as computation type. auto dstElemTy = outTy.getElementType(); - input = - rewriter.create(op->getLoc(), input, dstElemTy); + input = rewriter.create(op->getLoc(), input, dstElemTy); inputTy = input.getType().dyn_cast(); } auto inputElemTy = inputTy.getElementType(); @@ -515,7 +499,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( inputElemTy.getIntOrFloatBitWidth() == 8) { return rewriter.notifyMatchFailure( op, "IntegerType with bitwidth 8 unsupported in convertion from " - "AtenSumDimIntListOp to StableHLO"); + "AtenSumDimIntListOp to MHLO"); } SmallVector inputDims; @@ -541,14 +525,13 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( } Value initValue = createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); - if (!initValue) - return failure(); + if (!initValue) return failure(); llvm::sort(dims.begin(), dims.end()); - auto stablehloReduceOp = rewriter.create( + auto mhloReduceOp = rewriter.create( op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims)); - Region ®ion = stablehloReduceOp.getBody(); + Region ®ion = mhloReduceOp.getBody(); Block &block = region.emplaceBlock(); auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); @@ -561,15 +544,15 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); - Value addResult = rewriter.create( + Value addResult = rewriter.create( op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); - rewriter.create(op->getLoc(), addResult); + rewriter.create(op->getLoc(), addResult); } if (keepDim) { const auto &options = getOptions(); - auto outShapeInfo = - hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); + auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input, + options.dimSizeIndexBits); if (failed(outShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); @@ -584,27 +567,26 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( } auto outShapeTensor = rewriter.create( op->getLoc(), outShapeVec); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), - stablehloReduceOp.getResult(0), outShapeTensor); + mhloReduceOp.getResult(0), outShapeTensor); return success(); } rewriter.replaceOpWithNewOp(op, outTy, - stablehloReduceOp.getResults()); + mhloReduceOp.getResults()); return success(); } } // namespace // AtenFrobeniusNormDimOp -// aten.frobenius_norm.dim => stablehlo.reduce(calculate square sum along given -// dims) -// + stablehlo.sqrt +// aten.frobenius_norm.dim => mhlo.reduce(calculate square sum along given dims) +// + mhlo.sqrt namespace { template <> LogicalResult ConvertAtenReductionOp::matchAndRewrite( AtenFrobeniusNormDimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - const TorchToStablehloOptions &options = getOptions(); + const TorchToMhloOptions &options = getOptions(); Value input = adaptor.getSelf(); auto inputType = input.getType().dyn_cast(); @@ -632,7 +614,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( } } - // Sort the dims in ascending order, making the conversion + // Sort the dims in ascending order, making the conversion // stable with unordered dims. std::sort(dims.begin(), dims.end()); @@ -642,57 +624,58 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( op, "non-const bool `keepdim` is not supported"); } - auto squareOp = rewriter.create(op->getLoc(), input, input); - auto initValue = createInitialValueForReduceOp(op, inputElemType, rewriter); if (!initValue) { return failure(); } - auto reduceOp = rewriter.create( - op->getLoc(), squareOp.getResult(), initValue, - rewriter.getI64TensorAttr(dims)); + auto squareSumReduceOp = rewriter.create( + op->getLoc(), input, initValue, rewriter.getI64TensorAttr(dims)); - Region ®ion = reduceOp.getBody(); + Region ®ion = squareSumReduceOp.getBody(); Block &block = region.emplaceBlock(); auto blockArgumentTy = RankedTensorType::get({}, inputElemType); block.addArgument(blockArgumentTy, op->getLoc()); block.addArgument(blockArgumentTy, op->getLoc()); - auto firstArgument = *block.args_begin(); - auto secondArgument = *block.args_rbegin(); + auto *firstArgument = block.args_begin(); + auto secondArgument = block.args_rbegin(); { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); - auto addResult = rewriter.create( - op->getLoc(), firstArgument, secondArgument); - rewriter.create(op->getLoc(), addResult.getResult()); + auto constantOrd2 = rewriter.create( + op->getLoc(), blockArgumentTy, + DenseElementsAttr::get(blockArgumentTy, llvm::ArrayRef{2.0})); + auto abs = rewriter.create(op->getLoc(), *secondArgument); + auto squareResult = rewriter.create( + op->getLoc(), abs, constantOrd2); + auto addResult = rewriter.create(op->getLoc(), squareResult, + *firstArgument); + rewriter.create(op->getLoc(), addResult.getResult()); } - auto output = - rewriter.create(op->getLoc(), reduceOp.getResult(0)); + auto output = rewriter.create(op->getLoc(), + squareSumReduceOp.getResult(0)); if (keepDim) { - auto outShapeInfo = - hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); + auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); if (failed(outShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); } auto outShapeVec = *outShapeInfo; auto one = rewriter.create( - op->getLoc(), - rewriter.getIntegerAttr( - rewriter.getIntegerType(options.dimSizeIndexBits), 1)); + op->getLoc(), rewriter.getIntegerAttr( + rewriter.getIntegerType(options.dimSizeIndexBits), 1)); for (int64_t i : dims) { outShapeVec[i] = one; } auto outShapeTensor = rewriter.create( op->getLoc(), outShapeVec); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), output, outShapeTensor); return success(); @@ -702,9 +685,9 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( } } // namespace -void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality( +void mlir::torch::torch_to_mhlo::populateReductionOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target, const TorchToStablehloOptions &options) { + ConversionTarget &target, const TorchToMhloOptions &options) { MLIRContext *context = patterns.getContext(); #define INSERT_ATEN_REDUCTION_OP_PATTERN(AtenOp) \ target.addIllegalOp(); \ diff --git a/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp b/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp similarity index 58% rename from lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp rename to lib/Conversion/TorchToMhlo/TorchToMhlo.cpp index ba0838484..f81afd9ca 100644 --- a/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp +++ b/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp @@ -7,18 +7,17 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" +#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "../PassDetail.h" -#include "PopulatePatterns.h" - +#include "./PopulatePatterns.h" +#include "mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Traits.h" #include "mlir/IR/Matchers.h" #include "mlir/Transforms/DialectConversion.h" #include "stablehlo/dialect/ChloOps.h" -#include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" @@ -31,18 +30,17 @@ using namespace mlir::torch::Torch; namespace { -class ConvertTorchToStablehlo - : public ConvertTorchToStablehloBase { +class ConvertTorchToMhlo : public ConvertTorchToMhloBase { public: - ConvertTorchToStablehlo() = default; - ConvertTorchToStablehlo(bool enableStaticShape, bool enableI32Index) { + ConvertTorchToMhlo() = default; + ConvertTorchToMhlo(bool enableStaticShape, bool enableI32Index) { this->enableStaticShape = enableStaticShape; this->enableI32Index = enableI32Index; } void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); - registry.insert(); + registry.insert(); registry.insert(); registry.insert(); TorchConversion::getBackendTypeConversionDependentDialects(registry); @@ -50,7 +48,7 @@ public: void runOnOperation() override { MLIRContext *context = &getContext(); ConversionTarget target(*context); - target.addLegalDialect(); TypeConverter typeConverter; @@ -59,20 +57,20 @@ public: RewritePatternSet patterns(context); - torch_to_stablehlo::TorchToStablehloOptions options{ - enableStaticShape, enableI32Index ? 32u : 64u}; - torch_to_stablehlo::populateBasicOpPatternsAndLegality( + torch_to_mhlo::TorchToMhloOptions options{enableStaticShape, + enableI32Index ? 32u : 64u}; + torch_to_mhlo::populateBasicOpPatternsAndLegality(typeConverter, patterns, + target, options); + torch_to_mhlo::populateViewLikeOpPatternsAndLegality( typeConverter, patterns, target, options); - torch_to_stablehlo::populateViewLikeOpPatternsAndLegality( - typeConverter, patterns, target, options); - torch_to_stablehlo::populateGatherOpPatternsAndLegality( - typeConverter, patterns, target, options); - torch_to_stablehlo::populateReductionOpPatternsAndLegality( - typeConverter, patterns, target, options); - torch_to_stablehlo::populateLinearOpPatternsAndLegality( - typeConverter, patterns, target, options); - torch_to_stablehlo::populatePoolingOpPatternsAndLegality( + torch_to_mhlo::populateGatherOpPatternsAndLegality(typeConverter, patterns, + target, options); + torch_to_mhlo::populateReductionOpPatternsAndLegality( typeConverter, patterns, target, options); + torch_to_mhlo::populateLinearOpPatternsAndLegality(typeConverter, patterns, + target, options); + torch_to_mhlo::populatePoolingOpPatternsAndLegality(typeConverter, patterns, + target, options); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { @@ -84,13 +82,13 @@ public: } // namespace std::unique_ptr> -mlir::torch::createConvertTorchToStablehloPass() { - return std::make_unique(false, false); +mlir::torch::createConvertTorchToMhloPass() { + return std::make_unique(false, false); } std::unique_ptr> -mlir::torch::createConvertTorchToStablehloPass(bool enableStaticShape, - bool enableI32Index) { - return std::make_unique(enableStaticShape, - enableI32Index); +mlir::torch::createConvertTorchToMhloPass(bool enableStaticShape, + bool enableI32Index) { + return std::make_unique(enableStaticShape, + enableI32Index); } diff --git a/lib/Conversion/TorchToStablehlo/ViewLike.cpp b/lib/Conversion/TorchToMhlo/ViewLike.cpp similarity index 88% rename from lib/Conversion/TorchToStablehlo/ViewLike.cpp rename to lib/Conversion/TorchToMhlo/ViewLike.cpp index b6511c384..29284d50e 100644 --- a/lib/Conversion/TorchToStablehlo/ViewLike.cpp +++ b/lib/Conversion/TorchToMhlo/ViewLike.cpp @@ -7,15 +7,14 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" +#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "../PassDetail.h" -#include "PopulatePatterns.h" -#include "StablehloLegalizeUtils.h" - +#include "./MhloLegalizeUtils.h" +#include "./PopulatePatterns.h" +#include "mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -29,7 +28,7 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; using namespace mlir::torch::TorchConversion; -using namespace mlir::torch::torch_to_stablehlo; +using namespace mlir::torch::torch_to_mhlo; namespace { // A dimension index from torch.dialect might outside the range [0, dimSize]. @@ -101,7 +100,7 @@ Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op, auto stridesTensor = rewriter.create(loc, strides).getResult(); - return rewriter.create( + return rewriter.create( loc, outTy, input, startTensor, endTensor, stridesTensor); } @@ -145,7 +144,7 @@ FailureOr getDynamicSlice(PatternRewriter &rewriter, Operation *op, step = rewriter.create(loc, intType, step); } FailureOr> dimSizesInfo = - hlo::getDimSizesOfTensor(rewriter, op, input, dimSizeIndexBits); + mhlo::getDimSizesOfTensor(rewriter, op, input, dimSizeIndexBits); if (failed(dimSizesInfo)) return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); @@ -180,7 +179,7 @@ public: auto loc = op.getLoc(); auto newRank = dimSizes.size(); if (newRank == 0 || rankType.getRank() == 0) { - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, OpConversionPattern::getTypeConverter()->convertType( op.getType()), @@ -215,18 +214,17 @@ public: numel); if (dimSizes.size() == 0) { - rewriter.replaceOpWithNewOp( - op, - OpConversionPattern::getTypeConverter()->convertType( - op.getType()), - adaptor.getSelf()); + rewriter.replaceOpWithNewOp( + op, + OpConversionPattern::getTypeConverter()->convertType( + op.getType()), + adaptor.getSelf()); return success(); } - Value stablehloShape = - rewriter.create(loc, dimSizes); - Value computedShape = rewriter.create( - loc, stablehloShape.getType(), numel, stablehloShape); - rewriter.replaceOpWithNewOp( + Value mhloShape = rewriter.create(loc, dimSizes); + Value computedShape = rewriter.create( + loc, mhloShape.getType(), numel, mhloShape); + rewriter.replaceOpWithNewOp( op, OpConversionPattern::getTypeConverter()->convertType( op.getType()), @@ -317,21 +315,21 @@ LogicalResult ConvertAtenOp::matchAndRewrite( dims.push_back(r); } if (dims.size() == 0) { - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), self); return success(); } - auto newDimSizesInfo = hlo::getDimSizesOfTensor(rewriter, op, self, dims, - options.dimSizeIndexBits); + auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims, + options.dimSizeIndexBits); if (failed(newDimSizesInfo)) return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); auto newDimSizes = *newDimSizesInfo; - auto stablehloShape = + auto mhloShape = rewriter.create(op.getLoc(), newDimSizes); - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), self, stablehloShape); + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), self, mhloShape); return success(); } @@ -367,20 +365,20 @@ LogicalResult ConvertAtenOp::matchAndRewrite( std::iota(dims.begin(), dims.end(), 0); dims.erase(dims.begin() + dim); if (dims.size() == 0) { - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), self); return success(); } - auto newDimSizesInfo = hlo::getDimSizesOfTensor(rewriter, op, self, dims, - options.dimSizeIndexBits); + auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims, + options.dimSizeIndexBits); if (failed(newDimSizesInfo)) return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); auto newDimSizes = *newDimSizesInfo; - auto stablehloShape = + auto mhloShape = rewriter.create(op.getLoc(), newDimSizes); - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), self, stablehloShape); + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), self, mhloShape); return success(); } @@ -397,8 +395,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return op->emitError("dim must be a Scalar constant"); - auto unsqzTensorInfo = hlo::unsqueezeTensor(rewriter, op, adaptor.getSelf(), - {dim}, options.dimSizeIndexBits); + auto unsqzTensorInfo = mhlo::unsqueezeTensor(rewriter, op, adaptor.getSelf(), + {dim}, options.dimSizeIndexBits); if (failed(unsqzTensorInfo)) return rewriter.notifyMatchFailure(op, "failed to create unsqueezed tensor"); @@ -407,9 +405,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } -void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality( +void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target, const TorchToStablehloOptions &options) { + ConversionTarget &target, const TorchToMhloOptions &options) { MLIRContext *context = patterns.getContext(); #define INSERT_ATENOP_PATTERN(AtenOp) \ diff --git a/lib/Conversion/TorchToStablehlo/CMakeLists.txt b/lib/Conversion/TorchToStablehlo/CMakeLists.txt deleted file mode 100644 index 237512980..000000000 --- a/lib/Conversion/TorchToStablehlo/CMakeLists.txt +++ /dev/null @@ -1,29 +0,0 @@ -add_mlir_conversion_library(TorchMLIRTorchToStablehlo - TorchToStablehlo.cpp - StablehloLegalizeUtils.cpp - Basic.cpp - Gather.cpp - Linear.cpp - ViewLike.cpp - Reduction.cpp - Pooling.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToStablehlo - - DEPENDS - TorchMLIRConversionPassIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - MLIRIR - MLIRPass - MLIRBufferTransforms - StablehloOps - TorchMLIRTorchDialect - TorchMLIRConversionUtils -) - -torch_mlir_target_includes(TorchMLIRTorchToStablehlo) diff --git a/lib/Conversion/TorchToStablehlo/PopulatePatterns.h b/lib/Conversion/TorchToStablehlo/PopulatePatterns.h deleted file mode 100644 index b6322efd6..000000000 --- a/lib/Conversion/TorchToStablehlo/PopulatePatterns.h +++ /dev/null @@ -1,69 +0,0 @@ -//===------------------------------------------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Also available under a BSD-style license. See LICENSE. -// -//===----------------------------------------------------------------------===// - -#ifndef TORCHMLIR_LIB_CONVERSION_TORCHTOSTABLEHLO_POPULATEPATTERNS_H -#define TORCHMLIR_LIB_CONVERSION_TORCHTOSTABLEHLO_POPULATEPATTERNS_H - -#include "mlir/Transforms/DialectConversion.h" - -namespace mlir { -namespace torch { -namespace torch_to_stablehlo { - -struct TorchToStablehloOptions { - bool enableStaticShape = false; - size_t dimSizeIndexBits = 64; -}; - -template -class ConvertAtenOp : public OpConversionPattern { -public: - using OpAdaptor = typename AtenOpT::Adaptor; - ConvertAtenOp(TypeConverter &typeConverter, MLIRContext *context, - const TorchToStablehloOptions &options) - : OpConversionPattern(typeConverter, context) { - this->options = options; - } - LogicalResult - matchAndRewrite(AtenOpT op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - return rewriter.notifyMatchFailure(op, "haven't been implemented"); - } - const TorchToStablehloOptions &getOptions() const { return options; } - -private: - TorchToStablehloOptions options; -}; - -void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter, - RewritePatternSet &patterns, - ConversionTarget &target, - const TorchToStablehloOptions &options); -void populateViewLikeOpPatternsAndLegality( - TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target, const TorchToStablehloOptions &options); -void populateGatherOpPatternsAndLegality( - TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target, const TorchToStablehloOptions &options); -void populateReductionOpPatternsAndLegality( - TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target, const TorchToStablehloOptions &options); -void populateLinearOpPatternsAndLegality( - TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target, const TorchToStablehloOptions &options); - -void populatePoolingOpPatternsAndLegality( - TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target, const TorchToStablehloOptions &options); - -} // namespace torch_to_stablehlo -} // namespace torch -} // namespace mlir - -#endif // TORCHMLIR_LIB_CONVERSION_TORCHTOSTABLEHLO_POPULATEPATTERNS_H diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp index c841afcdf..b80c35c14 100644 --- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp +++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp @@ -17,7 +17,6 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" -#include "mlir/IR/ValueRange.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h" #include "torch-mlir/Conversion/Utils/Utils.h" @@ -27,9 +26,6 @@ #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" -#include "llvm/ADT/APFloat.h" -#include "llvm/ADT/APInt.h" -#include "llvm/Support/ErrorHandling.h" using namespace mlir; using namespace mlir::torch; @@ -56,147 +52,6 @@ using namespace mlir::torch::TMTensor; // that these patterns become mostly mechanical associations of // "aten.foo -> linalg.foo". -static Attribute getNumericLimit(PatternRewriter &rewriter, Type elementType, - bool getMin = true) { - auto bitWidth = elementType.getIntOrFloatBitWidth(); - if (llvm::isa(elementType)) { - if (getMin) { - return rewriter.getIntegerAttr(elementType, - APInt::getSignedMinValue(bitWidth)); - } else { - return rewriter.getIntegerAttr(elementType, - APInt::getSignedMaxValue(bitWidth)); - } - } else if (mlir::FloatType floatType = - llvm::dyn_cast(elementType)) { - return rewriter.getFloatAttr( - elementType, - APFloat::getLargest(floatType.getFloatSemantics(), getMin)); - } else { - llvm_unreachable("Only float/integer types are supported!"); - } -} - -// This function will reformat the `index` and `src` from torch operations -// like `torch.scatter` or `torch.scatter_reduce` to match the expected -// input for the TMScatterOp. It will return the reformated `index` and `src` -// as a pair of mlir::Value that can be used as inputs for the TMScatterOp. -static std::pair -convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc(PatternRewriter &rewriter, - Value indices, Value src, - int64_t dim) { - // Get information on types for inputs - RankedTensorType indexType = indices.getType().cast(); - RankedTensorType srcSelf = src.getType().cast(); - - // Store location for insertions - Location loc = src.getLoc(); - - Value indexSize = getTensorSize(rewriter, loc, indices); - indexSize = castIntToIndex(rewriter, loc, indexSize); - SmallVector indexShape = getTensorSizes(rewriter, loc, indices); - Value cstOne = rewriter.create(loc, 1); - - // We flatten the `src` values from (i, j, k, ...) -> (i * j * k * ...) - SmallVector indSliceShape({indexSize, cstOne}); - Value indSlice = - createZeroInitTensor(rewriter, loc, indSliceShape, rewriter.getI32Type()); - - // New output shape will be equal to the product of the dimensions of the - // updates - SmallVector outputs(indexType.getRank(), indSlice); - outputs.push_back(createZeroInitTensor(rewriter, loc, {indexSize}, - srcSelf.getElementType())); - SmallVector outputsType(indexType.getRank(), indSlice.getType()); - outputsType.push_back(outputs[indexType.getRank()].getType()); - - // Create mapping over flattened iteration space - SmallVector indSliceExpr = {rewriter.getAffineDimExpr(0), - rewriter.getAffineConstantExpr(0)}; - SmallVector mapping( - indexType.getRank(), AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, - indSliceExpr, src.getContext())); - // Mapping for updates - mapping.push_back(rewriter.getDimIdentityMap()); - SmallVector iteratorTypes( - {utils::IteratorType::parallel}); - - // This function goes over the flattened iteration space of the `indices` - // and `src`. It will reconstruct the original induction variables based - // on the current flattened index. The flattened iteration space is required - // because TMTensorScatterOp expects a list of single element updates. - auto flattenedUpdates = - rewriter - .create( - loc, outputsType, ValueRange(), outputs, mapping, iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - SmallVector indexValues(indexType.getRank()); - Value ind = b.create(loc, 0); - for (int i = indexType.getRank() - 1; i >= 0; i--) { - indexValues[i] = - b.create(loc, ind, indexShape[i]); - ind = b.create(loc, ind, indexShape[i]); - } - // Extract the scatter index and update value - Value extractIndexValue = - b.create(loc, indices, indexValues); - Value extractSrcValue = - b.create(loc, src, indexValues); - SmallVector yieldVals; - for (Value v : indexValues) { - Value scalar = castIndexToInt64(b, loc, v); - yieldVals.push_back(b.create( - loc, rewriter.getI32Type(), scalar)); - } - // Replace the original index with the index specified - // by the scatter. - yieldVals[dim] = b.create( - loc, rewriter.getI32Type(), extractIndexValue); - yieldVals.push_back(extractSrcValue); - b.create(loc, yieldVals); - }) - .getResultTensors(); - - auto toOpFoldResult = [](Value v) -> OpFoldResult { - auto op = v.getDefiningOp(); - if (!op) - return v; - return op.getValue(); - }; - - // The result of the linalg::Generic operation gives us (rank(`src`) + 1) - // 1D-tensors where each contains a number of elements equal to the total - // number of elements in the `src` tensor. The indices must now be - // constructed by concatanating the first rank(`src`) tensors together. The - // new `src` tensor is the last tensor returned from the linalg::Generic - // operation. - SmallVector offsets = { - rewriter.create(loc, 0), - rewriter.create(loc, 0)}; - SmallVector strides = { - rewriter.create(loc, 1), - rewriter.create(loc, 1)}; - Value indicesRank = - rewriter.create(loc, indexType.getRank()); - Value flattenedIndices = createZeroInitTensor( - rewriter, loc, SmallVector({indexSize, indicesRank}), - rewriter.getI32Type()); - SmallVector scatterInputsVector(flattenedUpdates); - for (auto const slice : ArrayRef(scatterInputsVector).drop_back()) { - SmallVector sizes = getTensorSizes(rewriter, loc, slice); - flattenedIndices = rewriter.createOrFold( - loc, slice, flattenedIndices, - llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)), - llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)), - llvm::to_vector(llvm::map_range(strides, toOpFoldResult))); - // Increment offset to insert into next column - offsets[1] = rewriter.createOrFold(loc, offsets[1], cstOne); - } - - return std::make_pair(flattenedIndices, - scatterInputsVector[indexType.getRank()]); -} - static Value createTMTensorScatterOp( OpBuilder &b, Location loc, Value updates, Value indices, Value original, bool uniqueIndices, @@ -287,7 +142,7 @@ public: // Finding the maximum value in the input tensor. SmallVector maxTensorSizes; ValueTensorType maxTensorType = ValueTensorType::get( - context, llvm::ArrayRef(maxTensorSizes), + context, llvm::makeArrayRef(maxTensorSizes), torchTypeInput.getType().cast().getDtype()); Value maxTensor = rewriter.create(loc, maxTensorType, torchTypeInput); @@ -310,7 +165,7 @@ public: SmallVector expandedInputSizes{ makeShapeTorchCompatible(inputType.getShape())[0], 1}; ValueTensorType expandInputType = ValueTensorType::get( - context, llvm::ArrayRef(expandedInputSizes), + context, llvm::makeArrayRef(expandedInputSizes), torchTypeInput.getType().cast().getDtype()); Value torchCstOne = rewriter.create( loc, rewriter.getI64IntegerAttr(1)); @@ -431,9 +286,9 @@ public: auto indexTensorType = indexTensor.getType().cast(); int64_t indexTensorSize = indexTensorType.getSizes()[0]; SmallVector expandedIndexTensorSizes{indexTensorSize, 1}; - ValueTensorType expandedIndexTensorType = - ValueTensorType::get(context, llvm::ArrayRef(expandedIndexTensorSizes), - indexTensorType.getDtype()); + ValueTensorType expandedIndexTensorType = ValueTensorType::get( + context, llvm::makeArrayRef(expandedIndexTensorSizes), + indexTensorType.getDtype()); Value torchCstOne = rewriter.create( loc, rewriter.getI64IntegerAttr(1)); Value expandedIndexTensor = rewriter.create( @@ -697,229 +552,6 @@ public: }; } // namespace -namespace { -class ConvertAtenScatterReduceTwoOp - : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(AtenScatterReduceTwoOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (failed(verifyLinalgCompatibleTypes(op, rewriter))) - return failure(); - - Location loc = op.getLoc(); - - RankedTensorType selfType = - adaptor.getSelf().getType().cast(); - RankedTensorType indexType = - adaptor.getIndex().getType().cast(); - RankedTensorType srcType = - adaptor.getSrc().getType().cast(); - - Value self = adaptor.getSelf(); - - if (selfType.getRank() != indexType.getRank() || - indexType.getRank() != srcType.getRank()) - return rewriter.notifyMatchFailure(op, - "'self', 'index' and 'src' should all " - "have the same number of dimensions."); - - std::string reduceType; - if (!matchPattern(op.getReduce(), m_TorchConstantStr(reduceType))) - return rewriter.notifyMatchFailure(op, - "'reduce' must be a costant string"); - - int64_t dim; - if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) - return rewriter.notifyMatchFailure(op, "'dim' is not constant"); - - bool includeSelf; - if (!matchPattern(op.getIncludeSelf(), m_TorchConstantBool(&includeSelf))) - return rewriter.notifyMatchFailure(op, "'include_self' is not constant"); - - // Get reduce string as the equivalent enum - auto reduceEnum = torch_upstream::get_reduction_enum(reduceType); - - // Get the inputs reformatted for the TMScatterOp - auto [indices, updates] = - convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc( - rewriter, adaptor.getIndex(), adaptor.getSrc(), dim); - - // Value 'counts' will be used to tally the number of reductions into - // each unique index. The tally is used to calculate the average of the - // values scattered per index. - Value counts = nullptr; - if (reduceEnum == torch_upstream::ReductionType::MEAN) { - SmallVector selfShape = - getTensorSizes(rewriter, loc, adaptor.getSelf()); - Attribute initAttr; - if (llvm::isa(srcType.getElementType())) { - initAttr = rewriter.getFloatAttr(srcType.getElementType(), 1); - } else if (llvm::isa(srcType.getElementType())) { - initAttr = rewriter.getIntegerAttr(srcType.getElementType(), 1); - } else { - llvm_unreachable("Only integer/float types supported!"); - } - Value initElement = rewriter.create(loc, initAttr); - counts = createInitTensor(rewriter, loc, selfShape, - selfType.getElementType(), initElement); - } - - // If the original values shouldn't be included, normalize the - // input tensor where the scatters take place. - if (!includeSelf) { - Value normalizationValue; - if (reduceEnum == torch_upstream::ReductionType::SUM || - reduceEnum == torch_upstream::ReductionType::MEAN) { - // Set the values in the input tensor to '0' so they are not included - normalizationValue = rewriter.create( - loc, rewriter.getZeroAttr(srcType.getElementType())); - } else if (reduceEnum == torch_upstream::ReductionType::PROD) { - // Set the values in the input tensor to '1' (multiplication identity) - if (llvm::isa(srcType.getElementType())) { - normalizationValue = rewriter.create( - loc, rewriter.getFloatAttr(srcType.getElementType(), 1.0)); - } else if (llvm::isa(srcType.getElementType())) { - normalizationValue = rewriter.create( - loc, rewriter.getIntegerAttr(srcType.getElementType(), 1)); - } else { - llvm_unreachable("Only integer/float types supported!"); - } - } else if (reduceEnum == torch_upstream::ReductionType::MAX) { - // Set the values in the input tensor to the smallest element of that - // type - auto minAttr = getNumericLimit(rewriter, srcType.getElementType(), - /*getMin=*/true); - normalizationValue = rewriter.create(loc, minAttr); - } else if (reduceEnum == torch_upstream::ReductionType::MIN) { - // Set the values in the input tensor to the largest element of that - // type - auto maxAttr = getNumericLimit(rewriter, srcType.getElementType(), - /*getMin=*/false); - normalizationValue = rewriter.create(loc, maxAttr); - } - - // Scatter the normalizations into the input tensor - Value indexSize = getTensorSize(rewriter, loc, adaptor.getIndex()); - indexSize = castIntToIndex(rewriter, loc, indexSize); - Value normalizations = createInitTensor( - rewriter, loc, SmallVector({indexSize}), - srcType.getElementType(), /*init_element=*/normalizationValue); - self = createTMTensorScatterOp( - rewriter, loc, normalizations, indices, self, - /*uniqueIndices=*/false, - [&](OpBuilder &b, Location loc, Value update, Value current) { - b.create(loc, update); - }); - if (reduceEnum == torch_upstream::ReductionType::MEAN) { - counts = createTMTensorScatterOp( - rewriter, loc, normalizations, indices, counts, - /*uniqueIndices=*/false, - [&](OpBuilder &b, Location loc, Value update, Value current) { - b.create(loc, update); - }); - } - } - - // Create final operation - Value scatterOp = createTMTensorScatterOp( - rewriter, loc, updates, indices, self, - /*uniqueIndices=*/false, - [&](OpBuilder &b, Location loc, Value update, Value current) { - Value result; - if (reduceEnum == torch_upstream::ReductionType::SUM || - reduceEnum == torch_upstream::ReductionType::MEAN) { - if (update.getType().isa()) { - result = b.create(loc, update, current); - } else if (update.getType().isa()) { - result = b.create(loc, update, current); - } else { - llvm_unreachable("Only integer/float types supported!"); - } - } else if (reduceEnum == torch_upstream::ReductionType::PROD) { - if (update.getType().isa()) { - result = b.create(loc, update, current); - } else if (update.getType().isa()) { - result = b.create(loc, update, current); - } else { - llvm_unreachable("Only integer/float types supported!"); - } - } else if (reduceEnum == torch_upstream::ReductionType::MAX) { - if (update.getType().isa()) { - result = b.create(loc, update, current); - } else if (update.getType().isa()) { - result = b.create(loc, update, current); - } else { - llvm_unreachable("Only integer/float types supported!"); - } - } else if (reduceEnum == torch_upstream::ReductionType::MIN) { - if (update.getType().isa()) { - result = b.create(loc, update, current); - } else if (update.getType().isa()) { - result = b.create(loc, update, current); - } else { - llvm_unreachable("Only integer/float types supported!"); - } - } - b.create(loc, result); - }); - - // Special case for the mean - if (reduceEnum == torch_upstream::ReductionType::MEAN) { - counts = createTMTensorScatterOp( - rewriter, loc, updates, indices, counts, - /*uniqueIndices=*/false, - [&](OpBuilder &b, Location loc, Value update, Value current) { - Value result; - if (mlir::IntegerType intType = - llvm::dyn_cast(current.getType())) { - Value constantUpdate = b.create( - loc, b.getIntegerAttr(intType, 1)); - result = b.create(loc, constantUpdate, current); - } else if (mlir::FloatType floatType = - llvm::dyn_cast(current.getType())) { - Value constantUpdate = b.create( - loc, b.getFloatAttr(floatType, 1.0)); - result = b.create(loc, constantUpdate, current); - } else { - llvm_unreachable("Only integer/float types supported!"); - } - b.create(loc, result); - }); - - Value output = rewriter.create( - loc, tensor::getMixedSizes(rewriter, loc, self), - selfType.getElementType()); - - // Finally divide the result - scatterOp = - rewriter - .create( - loc, ValueRange{scatterOp, counts}, output, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value result; - if (llvm::isa(args[0].getType())) { - result = b.create(loc, args[0], args[1]); - } else if (llvm::isa(args[0].getType())) { - result = b.create(loc, args[0], args[1]); - } else { - llvm_unreachable("Only integer/float types supported!"); - } - b.create(loc, result); - }) - .getResult()[0]; - } - auto resultType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); - rewriter.replaceOpWithNewOp(op, resultType, scatterOp); - - return success(); - } -}; -} // namespace - namespace { class ConvertAtenCumsumOp : public OpConversionPattern { public: @@ -1012,8 +644,6 @@ public: target.addIllegalOp(); patterns.add(typeConverter, context); - target.addIllegalOp(); - patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index dfe655fc9..70987a9b0 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -10,7 +10,6 @@ #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h" #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h" -#include "torch-mlir/Conversion/Utils/Utils.h" #include "../PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -718,8 +717,8 @@ class ConvertAtenMultipleDimsReductionOp "non-const dim parameter unsupported"); int64_t N = reduceDims.size(); auto reduceDimsType = RankedTensorType::get({N}, rewriter.getI64Type()); - reduceDimsAttr = - DenseIntElementsAttr::get(reduceDimsType, llvm::ArrayRef(reduceDims)); + reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType, + llvm::makeArrayRef(reduceDims)); keepDims = false; if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDims))) @@ -748,8 +747,8 @@ class ConvertAtenOneDimReductionOp return rewriter.notifyMatchFailure(op, "non-const dim parameter unsupported"); auto reduceDimsType = RankedTensorType::get({1}, rewriter.getI64Type()); - reduceDimsAttr = - DenseIntElementsAttr::get(reduceDimsType, llvm::ArrayRef({reduceDim})); + reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType, + llvm::makeArrayRef({reduceDim})); keepDims = false; if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDims))) @@ -782,8 +781,8 @@ public: reduceDims.push_back(i); int64_t N = selfTy.getRank(); auto reduceDimsType = RankedTensorType::get({N}, rewriter.getI64Type()); - reduceDimsAttr = - DenseIntElementsAttr::get(reduceDimsType, llvm::ArrayRef(reduceDims)); + reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType, + llvm::makeArrayRef(reduceDims)); keepDims = false; return success(); @@ -2646,36 +2645,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "size must consist of Scalar constants"); - // the shape -1 is inferred from other dimensions - size_t countNegativeShape{0}; - // Check at most one -1 shape - for (size_t i = 0; i < outShape.size(); i++) { - if (outShape[i] < 0) { - countNegativeShape++; - if (countNegativeShape > 1) - return rewriter.notifyMatchFailure(op, "At most one -1 shape"); - } - } - - auto inputShape = selfType.getShape(); - size_t totalSize = 1; - for (size_t i = 0; i < inputShape.size(); i++) { - totalSize *= inputShape[i]; - } - - size_t otherSize = 1; - for (size_t i = 0; i < outShape.size(); i++) { - if (outShape[i] > 0) { - otherSize *= outShape[i]; - } - } - for (size_t i = 0; i < outShape.size(); i++) { - if (outShape[i] < 0) { - outShape[i] = totalSize / otherSize; - break; - } - } - rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), rewriter.getDenseI64ArrayAttr(outShape)); @@ -2847,79 +2816,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenHardtanhBackwardOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - - // Not a tensor type. - auto selfType = adaptor.getSelf().getType().dyn_cast(); - if (!selfType) { - return rewriter.notifyMatchFailure( - op, "Only tensor types are currently supported"); - } - - auto selfElemTy = selfType.getElementType(); - if (!selfElemTy.isIntOrFloat()) { - return rewriter.notifyMatchFailure( - op, "Only floating-point or integer datatype legalization supported"); - } - - // Integer types with width > 32 are not supported - auto selfIntType = selfElemTy.dyn_cast(); - if (selfIntType && selfIntType.getWidth() > 32) { - return rewriter.notifyMatchFailure( - op, "Integer types with width greater than 32 are not supported"); - } - - Value gradOutput = adaptor.getGradOutput(); - auto gradOutputType = adaptor.getSelf().getType().dyn_cast(); - - Type gradOutputElemType = gradOutputType.getElementType(); - - if (selfElemTy != gradOutputElemType) { - return rewriter.notifyMatchFailure( - op, - "Input element type should be same as the grad_output element type."); - } - - SmallVector constTypeShape(selfType.getRank(), 1); - Value maxVal, minVal; - - if (failed(torchScalarToTosaTensor(rewriter, op, op.getMinVal(), minVal, - selfElemTy, constTypeShape))) { - return rewriter.notifyMatchFailure(op, "Only scalar constant is supported"); - } - - if (failed(torchScalarToTosaTensor(rewriter, op, op.getMaxVal(), maxVal, - selfElemTy, constTypeShape))) { - return rewriter.notifyMatchFailure(op, "Only scalar constant is supported"); - } - - Value replace = tosa::getConstTensor(rewriter, op, 0, {}).value(); - Type outType = getTypeConverter()->convertType(op.getType()); - - Value lesser = rewriter.create( - op.getLoc(), - RankedTensorType::get(selfType.getShape(), rewriter.getIntegerType(1)), - minVal, adaptor.getSelf()); - - Value greater = rewriter.create( - op.getLoc(), - RankedTensorType::get(selfType.getShape(), rewriter.getIntegerType(1)), - adaptor.getSelf(), maxVal); - - Value cmp = rewriter.create( - op.getLoc(), - RankedTensorType::get(selfType.getShape(), rewriter.getIntegerType(1)), - lesser, greater); - - rewriter.replaceOpWithNewOp(op, outType, cmp, replace, - gradOutput); - - return success(); -} - template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenEmbeddingOp op, OpAdaptor adaptor, @@ -3217,70 +3113,31 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "Only floating-point or integer datatype legalization supported"); } - SmallVector resultShape; - if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(resultShape))) + SmallVector outShape; + if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(outShape))) return rewriter.notifyMatchFailure(op, "size must consist of Scalar constants"); - // Get the result type - auto resultType = getTypeConverter()->convertType(op.getType()); SmallVector inputShape( makeShapeTorchCompatible(selfType.getShape())); - // Check for identity case i.e, for ex: [a, b, c] -> [a, b, c]. If this is - // true then we can replace the op result with the input operand directly. - if (llvm::equal(inputShape, resultShape)) { - // If we reach here, then it means that the broadcasting is not required - // since the input and result are of same shape. - op.replaceAllUsesWith(op.getSelf()); - rewriter.eraseOp(op); - return success(); - } else if (selfType.hasRank() && - (selfType.getRank() == (int64_t)resultShape.size() || - selfType.getRank() == 0)) { - // Right now to support limited cases where input and result shape are not - // equal, we can put a constraint that either the input should be of rank - // 0 or the rank of input tensor and result should be equal. And then we - // can check for broadcasting compatibility for the latter case. For - // broadcasting compatibility, either the shape of input and result should - // be equal at each dimenion or one of them should be 1. - if (selfType.getRank() != 0) { - for (unsigned i = 0; i < inputShape.size(); i++) { - if (inputShape[i] != resultShape[i] && inputShape[i] != 1 && - resultShape[i] != 1) { + if (inputShape.size() == outShape.size() || inputShape.size() == 0) { + // Check for identity case i.e, for ex: [a, b, c] -> [a, b, c]. If this is + // true then we can replace the op result with the input operand + // irrespective of the users of the op result. + if (!llvm::equal(inputShape, outShape)) { + for (auto user : op->getResult(0).getUsers()) { + // This case is only supported if the result of the `broadcast_to` op is + // not used by an op which is a view like. + if (isViewLikeOp(user)) { return rewriter.notifyMatchFailure( - op, "unimplemented: either the shape of input and result should " - "be equal at each dimenion or one of them should be 1."); + op, "unimplemented: broadcast not supported for this case"); } } } - - // If the above condition hold true then we can directly create a const - // zero tensor of shape same as the result shape. - SmallVector zeroTensorShape{resultShape}; - - // create the 0 constant tensor - int64_t totalNumElements = 1; - for (auto dimSize : zeroTensorShape) { - totalNumElements = dimSize * totalNumElements; - } - // There is some danger here. For edge cases in floating point, x + 0 != x. - // The cases are denormalized values, which may get flushed, and -0 + 0 = - // +0. (sign bit flips). These are probably acceptable in the short term, - // but we should put a comment acknowledging the danger, as there isn't an - // op that avoids the denorm flushing. - SmallVector intValues(totalNumElements, 0); - SmallVector floatValues(totalNumElements, 0.0); - Value zeroTensor = selfType.getElementType().isa() - ? tosa::getConstTensor( - rewriter, op, floatValues, zeroTensorShape) - .value() - : tosa::getConstTensor( - rewriter, op, intValues, zeroTensorShape) - .value(); - - // Use add broadcast - rewriter.replaceOpWithNewOp(op, resultType, adaptor.getSelf(), - zeroTensor); + // If we reach here, then it means the given case is handled by implicit + // broadcasting done by tosa. + op.replaceAllUsesWith(op.getSelf()); + rewriter.eraseOp(op); return success(); } return rewriter.notifyMatchFailure( @@ -3375,171 +3232,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenIndexTensorOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - // t = tf.constant([[1, 2, 3, 4, 5],[6,7,8,9,10], - // [11,12,13,14,15],[16,17,18,19,20]]) # 4*5 - // i = tf.constant([[1,2,3], [3,2,1]]) # 2*3 - // i_expand = tf.expand_dims(i,axis=2) # 2*3*1 - // IndexTensorOutput = tf.gather_nd(t,tf.i_expand) - // = torch.ops.aten.index(t, (i, )) = t[i] # 2*3*5 - // [[[ 6, 7, 8, 9, 10], [11, 12, 13, 14, 15], [16, 17, 18, 19, 20]], - // [[16, 17, 18, 19, 20], [11, 12, 13, 14, 15], [ 6, 7, 8, 9, 10]]] - auto input = adaptor.getSelf(); - auto inputTensorType = - adaptor.getSelf().getType().dyn_cast(); - // Check input is a tensor type. - if (!inputTensorType) - return rewriter.notifyMatchFailure( - op, "Only tensor types input are currently supported"); - - // Deal with torch.prim.ListConstruct of non const value to get the index - auto tensorList = op.getIndices(); - SmallVector tensorsTorchType; - if (!getListConstructElements(tensorList, tensorsTorchType)) - return op.emitError( - "unimplemented: the tensor list is not from list construct"); - auto indexTensors = getTypeConvertedValues( - rewriter, op->getLoc(), getTypeConverter(), tensorsTorchType); - - auto outType = getTypeConverter()->convertType(op.getType()); - - // Support for multiple indexes - if (indexTensors.size() > 1) { - // t[i, i] - // = torch.ops.aten.index(t,(i,i)) - // = tensor([[ t[1,1], t[2,2], t[3,3]], - // [ t[3,3], t[2,2], t[1,1]]]) - // = tensor([[ 7, 13, 19], [19, 13, 7]]) - // = tf.gather_nd(t,tf.ii_expand) - // ii_expand - // = tf.concat((i_expand,i_expand), dim=2) - // = tf.constant([[[1,1],[2,2],[3,3]], - // [[3,3],[2,2],[1,1]]]) # 2*3*2 - SmallVector indicesTfConcatTensors; - SmallVector indexesRank; - SmallVector> indexesShape; - - // concat index tensor into to indices tensor for concat - for (size_t i = 0; i < indexTensors.size(); i++) { - auto index = indexTensors[i]; - auto indexTorch = tensorsTorchType[i]; - // TODO add support for none index input like torch.ops.aten.index(x, - // (None, index1, index2, None)) - if (indexTorch.getType().isa()) - return rewriter.notifyMatchFailure( - op, "Only list ranked tensor types index are supported"); - - auto indexType = index.getType().dyn_cast(); - auto indexShape = indexType.getShape(); - indexesShape.push_back(makeShapeTorchCompatible(indexShape)); - indexesRank.push_back(indexType.getRank()); - - // index i64 to i32 for tosa compatible - if (indexType.getElementType() != rewriter.getIntegerType(32)) { - index = rewriter.create( - op->getLoc(), - RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), - index); - } - - // Expand last dim of index to tf indices [2,3] -> [2,3,1] - SmallVector indiceShapeOneDim; - for (auto shape : indexShape) { - indiceShapeOneDim.push_back(shape); - } - indiceShapeOneDim.push_back(1); - auto indicesTfOneDim = tosa::CreateOpAndInfer( - rewriter, op->getLoc(), - RankedTensorType::get(indiceShapeOneDim, rewriter.getIntegerType(32)), - index, rewriter.getDenseI64ArrayAttr(indiceShapeOneDim)); - - // create concat tensor for indicesTf - indicesTfConcatTensors.push_back(indicesTfOneDim.getResult()); - } - - // Right now only support multiple indexes with same shape - // TODO for different shape multiple indexes, add broadcast_to for small - // shape - for (auto indexShapeOneDim : indexesShape) { - if (!llvm::equal(indexesShape[0], indexShapeOneDim)) { - return rewriter.notifyMatchFailure( - op, "unimplemented: Only support multi indexes with same shape"); - } - } - - // concat each indices into indicesTf: shape [2,3,1],[2,3,1] -> [2,3,2] - auto indicesShapeConcat = indexesShape[0]; - uint64_t lastDim = indexesRank[0]; - indicesShapeConcat.push_back(indicesTfConcatTensors.size()); - auto indicesTf = tosa::CreateOpAndInfer( - rewriter, op->getLoc(), - GetTypeFromTensorShape(indicesShapeConcat, rewriter.getIntegerType(32)), - indicesTfConcatTensors, lastDim); - - if (!indicesTf) { - return rewriter.notifyMatchFailure( - op, "Convert TorchIndex To TfIndices fail."); - } - // do the tf gathernp algorithm with tf style indices as input. - auto result = tosa::convertGatherNdOp(rewriter, op, outType, input, - indicesTf.getResult()); - - if (!result) { - return rewriter.notifyMatchFailure( - op, "Convert GatherNdOp fail for index tensor."); - } - rewriter.replaceOp(op, {result.value()}); - - return success(); - } - - // Support for multiple index - auto index = indexTensors[0]; - auto indexTorch = tensorsTorchType[0]; - // TODO add support for none index input like torch.ops.aten.index(x, (None, index1, index2, None)) - if (indexTorch.getType().isa()) - return rewriter.notifyMatchFailure( - op, "Only list ranked tensor types index are supported"); - auto indexType = index.getType().dyn_cast(); - auto indexShape = indexType.getShape(); - // index i64 to i32 for tosa compatible - if (indexType.getElementType() != rewriter.getIntegerType(32)) { - index = rewriter.create( - op->getLoc(), - RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index); - } - - // Expand last dim of index to tf indices [2,3] -> [2,3,1] - SmallVector indicesShape; - for (auto shape : indexShape) { - indicesShape.push_back(shape); - } - indicesShape.push_back(1); - auto indicesTf = tosa::CreateOpAndInfer( - rewriter, op->getLoc(), - RankedTensorType::get(indicesShape, rewriter.getIntegerType(32)), index, - rewriter.getDenseI64ArrayAttr(indicesShape)); - - if (!indicesTf) { - return rewriter.notifyMatchFailure(op, - "Convert TorchIndex To TfIndices fail."); - } - // do the tf gathernp algorithm with tf style indices as input. - auto result = tosa::convertGatherNdOp(rewriter, op, outType, input, - indicesTf.getResult()); - - if (!result) { - return rewriter.notifyMatchFailure( - op, "Convert GatherNdOp fail for index tensor."); - } - rewriter.replaceOp(op, {result.value()}); - - return success(); -} - template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenWhereSelfOp op, OpAdaptor adaptor, @@ -4276,11 +3968,9 @@ public: if (!op.getMemoryFormat().getType().template isa() && (!matchPattern(op.getMemoryFormat(), m_TorchConstantInt(&memoryFormat)) || - (memoryFormat != torch_upstream::MemoryFormat::Contiguous && - memoryFormat != torch_upstream::MemoryFormat::ChannelsLast))) { + memoryFormat != torch_upstream::MemoryFormat::Contiguous)) { return op.emitError( - "unimplemented: only contiguous and channels last memory " - "format is supported"); + "unimplemented: only default memory format is supported"); } auto outType = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) @@ -4479,7 +4169,6 @@ public: target.addIllegalOp(); \ patterns.add>(typeConverter, context); INSERT_ATENOP_PATTERN(AtenTanhOp); - INSERT_ATENOP_PATTERN(AtenHardtanhBackwardOp); INSERT_ATENOP_PATTERN(AtenSigmoidOp); INSERT_ATENOP_PATTERN(AtenReluOp); INSERT_ATENOP_PATTERN(AtenLeakyReluOp); @@ -4507,7 +4196,6 @@ public: INSERT_ATENOP_PATTERN(AtenSliceTensorOp); INSERT_ATENOP_PATTERN(AtenBroadcastToOp); INSERT_ATENOP_PATTERN(AtenGatherOp); - INSERT_ATENOP_PATTERN(AtenIndexTensorOp); INSERT_ATENOP_PATTERN(AtenWhereSelfOp); INSERT_ATENOP_PATTERN(AtenClampOp); INSERT_ATENOP_PATTERN(AtenArangeStartStepOp); diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp index 158350630..0ab8f0a4e 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp @@ -230,10 +230,6 @@ static LogicalResult checkValidityOfCast(Type src, Type dest) { (src.isInteger(32) && dest.isInteger(1)) || (src.isInteger(32) && dest.isF32()) || (src.isInteger(8) && dest.isInteger(1)) || - (src.isInteger(1) && dest.isInteger(64)) || - (src.isInteger(1) && dest.isF32()) || - (src.isF32() && dest.isF64()) || - (src.isF64() && dest.isF32()) || (src.isF32() && dest.isInteger(8)) || (src.isF32() && dest.isInteger(1))) { return success(); diff --git a/lib/Dialect/Torch/IR/TorchDialect.cpp b/lib/Dialect/Torch/IR/TorchDialect.cpp index 5c90df8e6..a29c2e16a 100644 --- a/lib/Dialect/Torch/IR/TorchDialect.cpp +++ b/lib/Dialect/Torch/IR/TorchDialect.cpp @@ -11,7 +11,6 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/DialectImplementation.h" -#include "mlir/IR/IRMapping.h" #include "mlir/Transforms/InliningUtils.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" @@ -32,11 +31,11 @@ namespace { struct TorchInlinerInterface : public DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, - IRMapping &valueMapping) const final { + BlockAndValueMapping &valueMapping) const final { return true; } bool isLegalToInline(Operation *, Region *, bool wouldBeCloned, - IRMapping &) const final { + BlockAndValueMapping &) const final { return true; } }; diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index df75d1b64..3386f4dd8 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -128,36 +128,32 @@ static FloatAttr getF64FloatAttr(MLIRContext *context, double value) { return FloatAttr::get(Float64Type::get(context), value); } -static Value getScalarIntValue(Value input, Location loc, - PatternRewriter &rewriter) { +static Value getScalarValue(Value input, Location loc, + PatternRewriter &rewriter) { auto inputType = input.getType(); if (inputType.isa()) { return input; } - - auto inputTensorType = inputType.dyn_cast(); - if (!inputTensorType) - return nullptr; - - Type inputDtype = inputTensorType.getOptionalDtype(); - if (!inputDtype || !inputDtype.isInteger(64)) - return nullptr; - - std::optional inputRank = getTensorRank(input); - if (!inputRank || *inputRank != 0) - return nullptr; - + Value scalar = nullptr; if (auto valueTensorLiteralOp = input.getDefiningOp()) { - auto val = valueTensorLiteralOp.getValue() - .cast() - .getSplatValue(); - return rewriter.create( - loc, rewriter.getI64IntegerAttr(val)); + std::optional tensorRank = + getTensorRank(valueTensorLiteralOp.getResult()); + if (valueTensorLiteralOp && tensorRank && *tensorRank == 0) { + auto tensorType = + valueTensorLiteralOp.getValue().getType().cast(); + if (tensorType.getElementType().isa()) { + auto val = valueTensorLiteralOp.getValue() + .cast() + .getSplatValue(); + scalar = rewriter.create( + loc, rewriter.getI64IntegerAttr(val)); + } + } } else if (auto primNumToTensorScalarOp = input.getDefiningOp()) { - return primNumToTensorScalarOp.getA(); + scalar = primNumToTensorScalarOp.getA(); } - return nullptr; + return scalar; } //===----------------------------------------------------------------------===// @@ -390,7 +386,7 @@ void PrimIfOp::getSuccessorRegions(std::optional index, // If the condition is constant, we can give a more precise answer. if (auto condAttr = operands.front().dyn_cast_or_null()) { Region *executedRegion = - condAttr.getValue().isOne() ? &getThenRegion() : &getElseRegion(); + condAttr.getValue().isOneValue() ? &getThenRegion() : &getElseRegion(); regions.push_back(RegionSuccessor(executedRegion)); return; } @@ -511,7 +507,7 @@ bool DerefineOp::areCastCompatible(mlir::TypeRange inputs, return isValidSubtype(inputs[0], outputs[0]); } -OpFoldResult DerefineOp::fold(FoldAdaptor adaptor) { +OpFoldResult DerefineOp::fold(ArrayRef operands) { auto uncheckedCast = getOperand().getDefiningOp(); if (!uncheckedCast) return nullptr; @@ -574,10 +570,10 @@ static OpFoldResult atenIsOrIsNotFoldHelper(Operation *op, bool equalIsTrue) { // Aten__RangeLengthOp //===----------------------------------------------------------------------===// -OpFoldResult Aten__RangeLengthOp::fold(FoldAdaptor adaptor) { - auto lo = adaptor.getLo(); - auto hi = adaptor.getHi(); - auto step = adaptor.getStep(); +OpFoldResult Aten__RangeLengthOp::fold(ArrayRef operands) { + auto lo = operands[0]; + auto hi = operands[1]; + auto step = operands[2]; if (!lo || !hi || !step) return nullptr; auto loInt = lo.dyn_cast_or_null().getValue(); @@ -599,10 +595,10 @@ OpFoldResult Aten__RangeLengthOp::fold(FoldAdaptor adaptor) { // Aten__DeriveIndexOp //===----------------------------------------------------------------------===// -OpFoldResult Aten__DeriveIndexOp::fold(FoldAdaptor adaptor) { - auto index = adaptor.getIndex(); - auto start = adaptor.getStart(); - auto step = adaptor.getStep(); +OpFoldResult Aten__DeriveIndexOp::fold(ArrayRef operands) { + auto index = operands[0]; + auto start = operands[1]; + auto step = operands[2]; if (!index || !start || !step) return nullptr; auto indexInt = index.dyn_cast_or_null().getValue(); @@ -616,7 +612,7 @@ OpFoldResult Aten__DeriveIndexOp::fold(FoldAdaptor adaptor) { // Aten__Is__Op //===----------------------------------------------------------------------===// -OpFoldResult Aten__Is__Op::fold(FoldAdaptor adaptor) { +OpFoldResult Aten__Is__Op::fold(ArrayRef operands) { return atenIsOrIsNotFoldHelper(*this, /*equalIsTrue=*/true); } @@ -624,7 +620,7 @@ OpFoldResult Aten__Is__Op::fold(FoldAdaptor adaptor) { // Aten__Isnot__Op //===----------------------------------------------------------------------===// -OpFoldResult Aten__Isnot__Op::fold(FoldAdaptor adaptor) { +OpFoldResult Aten__Isnot__Op::fold(ArrayRef operands) { return atenIsOrIsNotFoldHelper(*this, /*equalIsTrue=*/false); } @@ -632,7 +628,7 @@ OpFoldResult Aten__Isnot__Op::fold(FoldAdaptor adaptor) { // Aten__Not__Op //===----------------------------------------------------------------------===// -OpFoldResult Aten__Not__Op::fold(FoldAdaptor adaptor) { +OpFoldResult Aten__Not__Op::fold(ArrayRef operands) { bool value; if (!matchPattern(getOperand(), m_TorchConstantBool(&value))) return nullptr; @@ -643,7 +639,7 @@ OpFoldResult Aten__Not__Op::fold(FoldAdaptor adaptor) { // AtenNeBoolOp //===----------------------------------------------------------------------===// -OpFoldResult AtenNeBoolOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenNeBoolOp::fold(ArrayRef operands) { if (getOperand(0) == getOperand(1)) return IntegerAttr::get(IntegerType::get(getContext(), 1), false); @@ -659,7 +655,7 @@ OpFoldResult AtenNeBoolOp::fold(FoldAdaptor adaptor) { // AtenSqueezeOp //===----------------------------------------------------------------------===// -OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenSqueezeOp::fold(ArrayRef operands) { if (auto tensorType = getOperand().getType().dyn_cast()) { if (tensorType.hasSizes() && tensorType.getSizes().size() == 0) return getOperand(); @@ -671,7 +667,7 @@ OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) { // AtenSqueezeDimOp //===----------------------------------------------------------------------===// -OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenSqueezeDimOp::fold(ArrayRef operands) { if (auto tensorType = getOperand(0).getType().dyn_cast()) { if (tensorType.hasSizes() && tensorType.getSizes().size() == 0) return getOperand(0); @@ -683,7 +679,7 @@ OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) { // AtenRoundOp //===----------------------------------------------------------------------===// -OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenRoundOp::fold(ArrayRef operands) { if (auto selfType = getSelf().getType().dyn_cast()) { if (selfType.hasDtype() && selfType.getDtype().isa()) return getSelf(); @@ -695,7 +691,7 @@ OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) { // AtenTypeAsOp //===----------------------------------------------------------------------===// -OpFoldResult AtenTypeAsOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenTypeAsOp::fold(ArrayRef operands) { Type inType = getSelf().getType(); Type newType = getOther().getType(); @@ -709,7 +705,7 @@ OpFoldResult AtenTypeAsOp::fold(FoldAdaptor adaptor) { // AtenToDtypeOp //===----------------------------------------------------------------------===// -OpFoldResult AtenToDtypeOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenToDtypeOp::fold(ArrayRef operands) { bool nonBlocking, copyArg; // The non_blocking arg must be `False`. if (!matchPattern(getNonBlocking(), m_TorchConstantBool(&nonBlocking)) || @@ -740,7 +736,7 @@ OpFoldResult AtenToDtypeOp::fold(FoldAdaptor adaptor) { // AtenToDtypeLayoutOp //===----------------------------------------------------------------------===// -OpFoldResult AtenToDtypeLayoutOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenToDtypeLayoutOp::fold(ArrayRef operands) { // The pin_memory arg should be either constant `False` or `none`. if (!getPinMemory().getType().isa()) { bool pinMemory; @@ -801,7 +797,7 @@ OpFoldResult AtenToDtypeLayoutOp::fold(FoldAdaptor adaptor) { // AtenViewOp //===----------------------------------------------------------------------===// -OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenViewOp::fold(ArrayRef operands) { auto inputType = getOperand(0).getType().dyn_cast(); if (!inputType || !inputType.hasSizes() || inputType.getSizes().size() != 1) return nullptr; @@ -816,7 +812,7 @@ OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) { // AtenDimOp //===----------------------------------------------------------------------===// -OpFoldResult AtenDimOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenDimOp::fold(ArrayRef operands) { if (auto tensorType = getOperand().getType().dyn_cast()) { if (tensorType.hasSizes()) return IntegerAttr::get(IntegerType::get(getContext(), 64), @@ -829,7 +825,7 @@ OpFoldResult AtenDimOp::fold(FoldAdaptor adaptor) { // AtenLenTOp //===----------------------------------------------------------------------===// -OpFoldResult AtenLenTOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenLenTOp::fold(ArrayRef operands) { // `len([1,1,1])` -> `3`, if it is not mutated. if (auto listConstruct = getOperand().getDefiningOp()) { @@ -857,7 +853,7 @@ void AtenLenTOp::getCanonicalizationPatterns(RewritePatternSet &patterns, // AtenLenStrOp //===----------------------------------------------------------------------===// -OpFoldResult AtenLenStrOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenLenStrOp::fold(ArrayRef operands) { if (auto stringConstruct = getS().getDefiningOp()) return getI64IntegerAttr(getContext(), stringConstruct.getValueAttr().getValue().size()); @@ -873,25 +869,22 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op, if (op->getNumOperands() < 2) { return failure(); } - auto lhs = getScalarIntValue(op->getOperand(0), loc, rewriter); - auto rhs = getScalarIntValue(op->getOperand(1), loc, rewriter); + auto lhs = getScalarValue(op->getOperand(0), loc, rewriter); + auto rhs = getScalarValue(op->getOperand(1), loc, rewriter); auto outType = op->getResult(0).getType(); if (!lhs || !rhs) { return rewriter.notifyMatchFailure( op, "only int scalar lhs or rhs is supported"); } - if (isa(op)) { - Value alpha = getScalarIntValue(op->getOperand(2), loc, rewriter); + if (isa( + op)) { + Value alpha = getScalarValue(op->getOperand(2), loc, rewriter); if (!alpha) { return rewriter.notifyMatchFailure(op, "only int scalar alpha is supported"); } - if (isa(op)) - lhs = rewriter.create(loc, lhs, alpha); - else - rhs = rewriter.create(loc, rhs, alpha); + rhs = rewriter.create(loc, rhs, alpha); } if (isa(op)) { @@ -944,8 +937,6 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op, result = rewriter.create(loc, lhs, rhs); } else if (isa(op)) { result = rewriter.create(loc, lhs, rhs); - } else if (isa(op)) { - result = rewriter.create(loc, rhs, lhs); } else if (isa(op)) { result = rewriter.create(loc, lhs, rhs); } @@ -993,16 +984,6 @@ void AtenSubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns, }); } -//===----------------------------------------------------------------------===// -// AtenRSubScalarOp -//===----------------------------------------------------------------------===// -void AtenRsubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns, - MLIRContext *context) { - patterns.add(+[](AtenRsubScalarOp op, PatternRewriter &rewriter) { - return rewrite0DBinaryTensorOp(op, rewriter); - }); -} - //===----------------------------------------------------------------------===// // AtenMulTensorOp //===----------------------------------------------------------------------===// @@ -1033,23 +1014,6 @@ void AtenDivTensorModeOp::getCanonicalizationPatterns( }); } -//===----------------------------------------------------------------------===// -// AtenScalarImplicitOp -//===----------------------------------------------------------------------===// -void AtenScalarImplicitOp::getCanonicalizationPatterns( - RewritePatternSet &patterns, MLIRContext *context) { - patterns.add(+[](AtenScalarImplicitOp op, PatternRewriter &rewriter) { - Location loc = op.getLoc(); - Value a = op.getA(); - auto outType = op.getResult().getType(); - Value scalarValue = getScalarIntValue(a, loc, rewriter); - if (!scalarValue) - return failure(); - rewriter.replaceOpWithNewOp(op, outType, scalarValue); - return success(); - }); -} - //===----------------------------------------------------------------------===// // AtenSizeOp //===----------------------------------------------------------------------===// @@ -1128,7 +1092,7 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, // AtenSizeIntOp //===----------------------------------------------------------------------===// -OpFoldResult AtenSizeIntOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenSizeIntOp::fold(ArrayRef operands) { int64_t dim; if (!matchPattern(this->getDim(), m_TorchConstantInt(&dim))) return nullptr; @@ -1168,7 +1132,7 @@ floatComparatorFoldHelper(OpTy op, ConstantFloatComparator comparator) { // AtenLtFloatOp //===----------------------------------------------------------------------===// -OpFoldResult AtenLtFloatOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenLtFloatOp::fold(ArrayRef operands) { return floatComparatorFoldHelper(*this, [](double a, double b) { return a < b; }); } @@ -1177,7 +1141,7 @@ OpFoldResult AtenLtFloatOp::fold(FoldAdaptor adaptor) { // AtenGtFloatOp //===----------------------------------------------------------------------===// -OpFoldResult AtenGtFloatOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenGtFloatOp::fold(ArrayRef operands) { return floatComparatorFoldHelper(*this, [](double a, double b) { return a > b; }); } @@ -1186,7 +1150,7 @@ OpFoldResult AtenGtFloatOp::fold(FoldAdaptor adaptor) { // AtenGeFloatOp //===----------------------------------------------------------------------===// -OpFoldResult AtenGeFloatOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenGeFloatOp::fold(ArrayRef operands) { return floatComparatorFoldHelper(*this, [](double a, double b) { return a >= b; }); } @@ -1195,7 +1159,7 @@ OpFoldResult AtenGeFloatOp::fold(FoldAdaptor adaptor) { // AtenEqFloatOp //===----------------------------------------------------------------------===// -OpFoldResult AtenEqFloatOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenEqFloatOp::fold(ArrayRef operands) { return floatComparatorFoldHelper(*this, [](double a, double b) { return a == b; }); } @@ -1261,7 +1225,7 @@ static OpFoldResult intComparatorFoldHelper(OpTy op, // AtenNeIntOp //===----------------------------------------------------------------------===// -OpFoldResult AtenNeIntOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenNeIntOp::fold(ArrayRef operands) { return intComparatorFoldHelper(*this, [](int64_t a, int64_t b) { return a != b; }); } @@ -1270,7 +1234,7 @@ OpFoldResult AtenNeIntOp::fold(FoldAdaptor adaptor) { // AtenEqIntOp //===----------------------------------------------------------------------===// -OpFoldResult AtenEqIntOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenEqIntOp::fold(ArrayRef operands) { return intComparatorFoldHelper(*this, [](int64_t a, int64_t b) { return a == b; }); } @@ -1279,7 +1243,7 @@ OpFoldResult AtenEqIntOp::fold(FoldAdaptor adaptor) { // AtenEqStrOp //===----------------------------------------------------------------------===// -OpFoldResult AtenEqStrOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenEqStrOp::fold(ArrayRef operands) { if (getOperand(0) == getOperand(1)) return getI1IntegerAttr(getContext(), true); @@ -1295,7 +1259,7 @@ OpFoldResult AtenEqStrOp::fold(FoldAdaptor adaptor) { // AtenLtIntOp //===----------------------------------------------------------------------===// -OpFoldResult AtenLtIntOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenLtIntOp::fold(ArrayRef operands) { return intComparatorFoldHelper(*this, [](int64_t a, int64_t b) { return a < b; }); } @@ -1304,7 +1268,7 @@ OpFoldResult AtenLtIntOp::fold(FoldAdaptor adaptor) { // AtenLeIntOp //===----------------------------------------------------------------------===// -OpFoldResult AtenLeIntOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenLeIntOp::fold(ArrayRef operands) { return intComparatorFoldHelper(*this, [](int64_t a, int64_t b) { return a <= b; }); } @@ -1313,7 +1277,7 @@ OpFoldResult AtenLeIntOp::fold(FoldAdaptor adaptor) { // AtenGtIntOp //===----------------------------------------------------------------------===// -OpFoldResult AtenGtIntOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenGtIntOp::fold(ArrayRef operands) { return intComparatorFoldHelper(*this, [](int64_t a, int64_t b) { return a > b; }); } @@ -1322,7 +1286,7 @@ OpFoldResult AtenGtIntOp::fold(FoldAdaptor adaptor) { // AtenGeIntOp //===----------------------------------------------------------------------===// -OpFoldResult AtenGeIntOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenGeIntOp::fold(ArrayRef operands) { return intComparatorFoldHelper(*this, [](int64_t a, int64_t b) { return a >= b; }); } @@ -1331,7 +1295,7 @@ OpFoldResult AtenGeIntOp::fold(FoldAdaptor adaptor) { // AtenBoolFloatOp //===----------------------------------------------------------------------===// -OpFoldResult AtenBoolFloatOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenBoolFloatOp::fold(ArrayRef operands) { double c; if (matchPattern(getOperand(), m_TorchConstantFloat(&c))) return getI1IntegerAttr(getContext(), c != 0.0); @@ -1342,7 +1306,7 @@ OpFoldResult AtenBoolFloatOp::fold(FoldAdaptor adaptor) { // AtenBoolIntOp //===----------------------------------------------------------------------===// -OpFoldResult AtenBoolIntOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenBoolIntOp::fold(ArrayRef operands) { int64_t c; if (matchPattern(getOperand(), m_TorchConstantInt(&c))) return getI1IntegerAttr(getContext(), c != 0); @@ -1353,9 +1317,9 @@ OpFoldResult AtenBoolIntOp::fold(FoldAdaptor adaptor) { // AtenFloatScalarOp //===----------------------------------------------------------------------===// -OpFoldResult AtenFloatScalarOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenFloatScalarOp::fold(ArrayRef operands) { // Constant fold int -> float conversion. - if (auto integerAttr = adaptor.getA().dyn_cast_or_null()) { + if (auto integerAttr = operands[0].dyn_cast_or_null()) { return FloatAttr::get( mlir::Float64Type::get(getContext()), static_cast(integerAttr.getValue().getSExtValue())); @@ -1366,27 +1330,13 @@ OpFoldResult AtenFloatScalarOp::fold(FoldAdaptor adaptor) { return nullptr; } -//===----------------------------------------------------------------------===// -// AtenIntFloatOp -//===----------------------------------------------------------------------===// - -OpFoldResult AtenIntFloatOp::fold(FoldAdaptor adaptor) { - // Constant fold float -> int conversion. - if (auto floatAttr = adaptor.getA().dyn_cast_or_null()) { - return IntegerAttr::get( - mlir::IntegerType::get(getContext(), 64, IntegerType::Signed), - static_cast(floatAttr.getValue().convertToDouble())); - } - return nullptr; -} - //===----------------------------------------------------------------------===// // AtenIntScalarOp //===----------------------------------------------------------------------===// -OpFoldResult AtenIntScalarOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenIntScalarOp::fold(ArrayRef operands) { // Constant fold float -> int conversion. - if (auto floatAttr = adaptor.getA().dyn_cast_or_null()) { + if (auto floatAttr = operands[0].dyn_cast_or_null()) { return IntegerAttr::get( mlir::IntegerType::get(getContext(), 64, IntegerType::Signed), static_cast(floatAttr.getValue().convertToDouble())); @@ -1397,18 +1347,6 @@ OpFoldResult AtenIntScalarOp::fold(FoldAdaptor adaptor) { return nullptr; } -//===----------------------------------------------------------------------===// -// AtenIntBoolOp -//===----------------------------------------------------------------------===// - -OpFoldResult AtenIntBoolOp::fold(FoldAdaptor adaptor) { - bool b; - if (matchPattern(getOperand(), m_TorchConstantBool(&b))) { - return getI64IntegerAttr(getContext(), static_cast(b)); - } - return nullptr; -} - //===----------------------------------------------------------------------===// // AtenSortIntOp //===----------------------------------------------------------------------===// @@ -1502,7 +1440,7 @@ LogicalResult ValueTensorLiteralOp::inferReturnTypes( return success(); } -OpFoldResult ValueTensorLiteralOp::fold(FoldAdaptor adaptor) { +OpFoldResult ValueTensorLiteralOp::fold(ArrayRef operands) { return getValueAttr(); } @@ -1607,7 +1545,7 @@ void CopyToValueTensorOp::getEffects( // ConstantNoneOp //===----------------------------------------------------------------------===// -OpFoldResult ConstantNoneOp::fold(FoldAdaptor adaptor) { +OpFoldResult ConstantNoneOp::fold(ArrayRef operands) { return TypeAttr::get(Torch::NoneType::get(getContext())); } @@ -1620,7 +1558,9 @@ void ConstantNoneOp::getAsmResultNames( // ConstantStrOp //===----------------------------------------------------------------------===// -OpFoldResult ConstantStrOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } +OpFoldResult ConstantStrOp::fold(ArrayRef operands) { + return getValueAttr(); +} void ConstantStrOp::getAsmResultNames( function_ref setNameFn) { @@ -1658,7 +1598,7 @@ void ConstantIntOp::print(OpAsmPrinter &p) { p.printOptionalAttrDict((*this)->getAttrs(), {"value"}); } -OpFoldResult Torch::ConstantIntOp::fold(FoldAdaptor adaptor) { +OpFoldResult Torch::ConstantIntOp::fold(ArrayRef operands) { return getValueAttr(); } @@ -1674,7 +1614,7 @@ void Torch::ConstantIntOp::getAsmResultNames( // ConstantFloatOp //===----------------------------------------------------------------------===// -OpFoldResult Torch::ConstantFloatOp::fold(FoldAdaptor adaptor) { +OpFoldResult Torch::ConstantFloatOp::fold(ArrayRef operands) { return getValueAttr(); } @@ -1704,7 +1644,7 @@ void Torch::ConstantFloatOp::getAsmResultNames( // ConstantNumberOp //===----------------------------------------------------------------------===// -OpFoldResult Torch::ConstantNumberOp::fold(FoldAdaptor adaptor) { +OpFoldResult Torch::ConstantNumberOp::fold(ArrayRef operands) { return getValueAttr(); } @@ -1732,7 +1672,7 @@ void Torch::ConstantNumberOp::getCanonicalizationPatterns( // ConstantBoolOp //===----------------------------------------------------------------------===// -OpFoldResult Torch::ConstantBoolOp::fold(FoldAdaptor adaptor) { +OpFoldResult Torch::ConstantBoolOp::fold(ArrayRef operands) { return getValueAttr(); } @@ -1750,7 +1690,7 @@ bool PrimUncheckedCastOp::areCastCompatible(mlir::TypeRange inputs, return isValidSubtype(outputs[0], inputs[0]); } -OpFoldResult PrimUncheckedCastOp::fold(FoldAdaptor adaptor) { +OpFoldResult PrimUncheckedCastOp::fold(ArrayRef operands) { if (auto derefineOp = getX().getDefiningOp()) { if (derefineOp.getOperand().getType() == getType()) return derefineOp.getOperand(); @@ -1884,7 +1824,7 @@ void AtenSliceTOp::getCanonicalizationPatterns(RewritePatternSet &patterns, // AtenEqIntListOp //===----------------------------------------------------------------------===// -OpFoldResult AtenEqIntListOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenEqIntListOp::fold(ArrayRef operands) { auto lhsLiteral = getA().getDefiningOp(); if (!lhsLiteral) return nullptr; @@ -1909,20 +1849,6 @@ OpFoldResult AtenEqIntListOp::fold(FoldAdaptor adaptor) { return nullptr; } -//===----------------------------------------------------------------------===// -// PrimTupleConstructOp -//===----------------------------------------------------------------------===// - -LogicalResult PrimTupleConstructOp::verify() { - if (!(isValidSubtype( - Torch::TupleType::get(getContext(), - llvm::to_vector<6>(getElements().getType())), - getResult().getType()))) - return emitOpError( - "failed to verify that contained types correspond to operand types"); - return success(); -} - //===----------------------------------------------------------------------===// // PrimTupleIndexOp //===----------------------------------------------------------------------===// @@ -2024,7 +1950,7 @@ static PrimDictConstructOp getDictConstructIfNotModified(Value torchDict) { // Aten__Getitem__DictStrOp //===----------------------------------------------------------------------===// -OpFoldResult Aten__Getitem__DictStrOp::fold(FoldAdaptor adaptor) { +OpFoldResult Aten__Getitem__DictStrOp::fold(ArrayRef operands) { auto dictConstruct = getDictConstructIfNotModified(getSelf()); if (!dictConstruct) return nullptr; @@ -2042,7 +1968,7 @@ OpFoldResult Aten__Getitem__DictStrOp::fold(FoldAdaptor adaptor) { // Aten__Contains__StrOp //===----------------------------------------------------------------------===// -OpFoldResult Aten__Contains__StrOp::fold(FoldAdaptor adaptor) { +OpFoldResult Aten__Contains__StrOp::fold(ArrayRef operands) { auto dictConstruct = getDictConstructIfNotModified(getDict()); if (!dictConstruct) return nullptr; @@ -2065,7 +1991,7 @@ static bool isListConstructNotModified(Value torchList) { }); } -OpFoldResult Aten__Contains__IntListOp::fold(FoldAdaptor adaptor) { +OpFoldResult Aten__Contains__IntListOp::fold(ArrayRef operands) { auto itemConstruct = getItem(); if (!isListConstructNotModified(getL())) return nullptr; @@ -2126,55 +2052,43 @@ atenBinaryFloatOperatorFoldHelper(ArrayRef operands, // AtenFloordivIntOp //===----------------------------------------------------------------------===// -OpFoldResult AtenFloordivIntOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenFloordivIntOp::fold(ArrayRef operands) { return atenBinaryIntOperatorFoldHelper( - adaptor.getOperands(), - [](int64_t a, int64_t b) { return std::floor(a / (double)b); }); + operands, [](int64_t a, int64_t b) { return std::floor(a / (double)b); }); } //===----------------------------------------------------------------------===// // AtenRemainderIntOp //===----------------------------------------------------------------------===// -OpFoldResult AtenRemainderIntOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenRemainderIntOp::fold(ArrayRef operands) { return atenBinaryIntOperatorFoldHelper( - adaptor.getOperands(), [](int64_t a, int64_t b) { return a % b; }); + operands, [](int64_t a, int64_t b) { return a % b; }); } //===----------------------------------------------------------------------===// // AtenAddIntOp //===----------------------------------------------------------------------===// -OpFoldResult AtenAddIntOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenAddIntOp::fold(ArrayRef operands) { return atenBinaryIntOperatorFoldHelper( - adaptor.getOperands(), [](int64_t a, int64_t b) { return a + b; }); + operands, [](int64_t a, int64_t b) { return a + b; }); } //===----------------------------------------------------------------------===// // AtenSubIntOp //===----------------------------------------------------------------------===// -OpFoldResult AtenSubIntOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenSubIntOp::fold(ArrayRef operands) { return atenBinaryIntOperatorFoldHelper( - adaptor.getOperands(), [](int64_t a, int64_t b) { return a - b; }); + operands, [](int64_t a, int64_t b) { return a - b; }); } //===----------------------------------------------------------------------===// // AtenCatOp //===----------------------------------------------------------------------===// -OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) { - auto list = getOperand(0).getDefiningOp(); - if (!list || !list->hasOneUse() || list.getElements().size() != 1) - return nullptr; - return list.getElements()[0]; -} - -//===----------------------------------------------------------------------===// -// AtenStackOp -//===----------------------------------------------------------------------===// - -OpFoldResult AtenStackOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenCatOp::fold(llvm::ArrayRef operands) { auto list = getOperand(0).getDefiningOp(); if (!list || !list->hasOneUse() || list.getElements().size() != 1) return nullptr; @@ -2185,7 +2099,7 @@ OpFoldResult AtenStackOp::fold(FoldAdaptor adaptor) { // AtenSliceTensorOp //===----------------------------------------------------------------------===// -OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenSliceTensorOp::fold(llvm::ArrayRef operands) { auto inType = getOperand(0).getType().dyn_cast(); auto outType = getResult().getType().dyn_cast(); if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes()) @@ -2204,7 +2118,7 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { // AtenMulIntOp //===----------------------------------------------------------------------===// -OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenMulIntOp::fold(ArrayRef operands) { int64_t lhs, rhs; bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs)); bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs)); @@ -2215,70 +2129,46 @@ OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) { return nullptr; } -//===----------------------------------------------------------------------===// -// AtenSubFloatOp -//===----------------------------------------------------------------------===// - -OpFoldResult AtenSubFloatOp::fold(FoldAdaptor adaptor) { - return atenBinaryFloatOperatorFoldHelper( - adaptor.getOperands(), [](double a, double b) { return a - b; }); -} - //===----------------------------------------------------------------------===// // AtenSubOp //===----------------------------------------------------------------------===// -OpFoldResult AtenSubOp::fold(FoldAdaptor adaptor) { - if (!adaptor.getA() || !adaptor.getB()) { +OpFoldResult AtenSubOp::fold(ArrayRef operands) { + if (!operands[0] || !operands[1]) { return nullptr; } - if (adaptor.getA().isa() && adaptor.getB().isa()) { + if (operands[0].isa() && operands[1].isa()) { return atenBinaryIntOperatorFoldHelper( - adaptor.getOperands(), - [](int64_t a, int64_t b) -> int64_t { return a - b; }); + operands, [](int64_t a, int64_t b) -> int64_t { return a - b; }); } return atenBinaryFloatOperatorFoldHelper( - adaptor.getOperands(), - [](double a, double b) -> double { return a - b; }); + operands, [](double a, double b) -> double { return a - b; }); } //===----------------------------------------------------------------------===// // AtenDivOp //===----------------------------------------------------------------------===// -OpFoldResult AtenDivOp::fold(FoldAdaptor adaptor) { - if (!adaptor.getA() || !adaptor.getB()) { +OpFoldResult AtenDivOp::fold(ArrayRef operands) { + if (!operands[0] || !operands[1]) { return nullptr; } // Since AtenDivOp always returns float value, we don't need to deal with the // case where the operands are both integers separately. return atenBinaryFloatOperatorFoldHelper( - adaptor.getOperands(), - [](double a, double b) -> double { return a / b; }); -} - -//===----------------------------------------------------------------------===// -// AtenPowIntFloatOp -//===----------------------------------------------------------------------===// - -OpFoldResult AtenPowIntFloatOp::fold(FoldAdaptor adaptor) { - if (!adaptor.getA() || !adaptor.getB()) { - return nullptr; - } - return atenBinaryFloatOperatorFoldHelper( - adaptor.getOperands(), [](double a, double b) { return std::pow(a, b); }); + operands, [](double a, double b) -> double { return a / b; }); } //===----------------------------------------------------------------------===// // AtenCeilScalarOp //===----------------------------------------------------------------------===// -OpFoldResult AtenCeilScalarOp::fold(FoldAdaptor adaptor) { - if (!adaptor.getA()) { +OpFoldResult AtenCeilScalarOp::fold(ArrayRef operands) { + if (!operands[0]) { return nullptr; } - auto floatValue = adaptor.getA().dyn_cast_or_null(); + auto floatValue = operands[0].dyn_cast_or_null(); if (!floatValue) { return nullptr; } @@ -2291,7 +2181,7 @@ OpFoldResult AtenCeilScalarOp::fold(FoldAdaptor adaptor) { // AtenNegIntOp //===----------------------------------------------------------------------===// -OpFoldResult AtenNegIntOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenNegIntOp::fold(ArrayRef operands) { int64_t c; if (matchPattern(getOperand(), m_TorchConstantInt(&c))) return getI64IntegerAttr(getContext(), -c); @@ -2302,7 +2192,7 @@ OpFoldResult AtenNegIntOp::fold(FoldAdaptor adaptor) { // AtenSqrtIntOp //===----------------------------------------------------------------------===// -OpFoldResult AtenSqrtIntOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenSqrtIntOp::fold(ArrayRef operands) { int64_t c; if (matchPattern(getOperand(), m_TorchConstantInt(&c))) return getF64FloatAttr(getContext(), std::sqrt(c)); @@ -2313,7 +2203,7 @@ OpFoldResult AtenSqrtIntOp::fold(FoldAdaptor adaptor) { // PrimDtypeOp //===----------------------------------------------------------------------===// -OpFoldResult PrimDtypeOp::fold(FoldAdaptor adaptor) { +OpFoldResult PrimDtypeOp::fold(ArrayRef operands) { BaseTensorType tensorType = getA().getType().cast(); if (tensorType.hasDtype()) { torch_upstream::ScalarType scalarType = @@ -2327,7 +2217,7 @@ OpFoldResult PrimDtypeOp::fold(FoldAdaptor adaptor) { // AtenIntTensorOp //===----------------------------------------------------------------------===// -OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenIntTensorOp::fold(ArrayRef operands) { // If a scalar number is converted to a 0-d tensor and passed on to // aten.Int.Tensor, fold to the scalar number. if (auto numToTensorScalar = getA().getDefiningOp()) @@ -2339,7 +2229,7 @@ OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) { // AtenFloatTensorOp //===----------------------------------------------------------------------===// -OpFoldResult AtenFloatTensorOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenFloatTensorOp::fold(ArrayRef operands) { // If a scalar number is converted to a 0-d tensor and passed on to // aten.Float.Tensor, fold to the scalar number. if (auto numToTensorScalar = getA().getDefiningOp()) @@ -2351,7 +2241,7 @@ OpFoldResult AtenFloatTensorOp::fold(FoldAdaptor adaptor) { // AtenDivFloatOp //===----------------------------------------------------------------------===// -OpFoldResult AtenDivFloatOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenDivFloatOp::fold(ArrayRef operands) { double lhs, rhs; bool lConstant = matchPattern(getOperand(0), m_TorchConstantFloat(&lhs)); bool rConstant = matchPattern(getOperand(1), m_TorchConstantFloat(&rhs)); @@ -2368,7 +2258,7 @@ OpFoldResult AtenDivFloatOp::fold(FoldAdaptor adaptor) { // AtenDivIntOp //===----------------------------------------------------------------------===// -OpFoldResult AtenDivIntOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenDivIntOp::fold(ArrayRef operands) { int64_t lhs, rhs; bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs)); bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs)); @@ -2381,7 +2271,7 @@ OpFoldResult AtenDivIntOp::fold(FoldAdaptor adaptor) { // AtenCeilFloatOp //===----------------------------------------------------------------------===// -OpFoldResult AtenCeilFloatOp::fold(FoldAdaptor adaptor) { +OpFoldResult AtenCeilFloatOp::fold(ArrayRef operands) { double c; if (matchPattern(getOperand(), m_TorchConstantFloat(&c))) return getI64IntegerAttr(getContext(), std::ceil(c)); @@ -2392,13 +2282,13 @@ OpFoldResult AtenCeilFloatOp::fold(FoldAdaptor adaptor) { // PrimMaxIntOp //===----------------------------------------------------------------------===// -OpFoldResult PrimMaxIntOp::fold(FoldAdaptor adaptor) { +OpFoldResult PrimMaxIntOp::fold(ArrayRef operands) { // If both operands are the same, then the operation is an identity. if (getA() == getB()) return getA(); - auto lhs = adaptor.getA().dyn_cast_or_null(); - auto rhs = adaptor.getB().dyn_cast_or_null(); + auto lhs = operands[0].dyn_cast_or_null(); + auto rhs = operands[1].dyn_cast_or_null(); if (!lhs || !rhs) return nullptr; // Torch semantics are that !torch.int is 64-bit signed. @@ -2411,7 +2301,7 @@ OpFoldResult PrimMaxIntOp::fold(FoldAdaptor adaptor) { // PrimMinSelfIntOp //===----------------------------------------------------------------------===// -OpFoldResult PrimMinSelfIntOp::fold(FoldAdaptor adaptor) { +OpFoldResult PrimMinSelfIntOp::fold(ArrayRef operands) { auto list = getOperand().getDefiningOp(); if (!list) return nullptr; @@ -2430,25 +2320,6 @@ OpFoldResult PrimMinSelfIntOp::fold(FoldAdaptor adaptor) { *std::min_element(values.begin(), values.end())); } -//===----------------------------------------------------------------------===// -// PrimMinIntOp -//===----------------------------------------------------------------------===// - -OpFoldResult PrimMinIntOp::fold(FoldAdaptor adaptor) { - // If both operands are the same, then the operation is an identity. - if (getA() == getB()) - return getA(); - - auto lhs = adaptor.getA().dyn_cast_or_null(); - auto rhs = adaptor.getB().dyn_cast_or_null(); - if (!lhs || !rhs) - return nullptr; - // Torch semantics are that !torch.int is 64-bit signed. - return IntegerAttr::get( - lhs.getType(), - std::min(lhs.getValue().getSExtValue(), rhs.getValue().getSExtValue())); -} - //===----------------------------------------------------------------------===// // ShapeCalculateOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp index 5cd48b3a1..968f809a0 100644 --- a/lib/Dialect/Torch/IR/TorchTypes.cpp +++ b/lib/Dialect/Torch/IR/TorchTypes.cpp @@ -68,32 +68,16 @@ bool Torch::isValidSubtype(Type subtype, Type type) { return true; } - auto subtypeTensorType = subtype.dyn_cast(); - auto typeTensorType = type.dyn_cast(); - if (subtypeTensorType && typeTensorType) { - // Check that both tensors have the same `BaseTensorType` subtype. - // TODO: This is not subtyping according to PEP 483. See description - // of NonValueTensorType. - if (subtypeTensorType.isa() != - typeTensorType.isa()) - return false; - - // `type` must not have more static information than `subtype`, and `type` - // must not disagree with `subtype`. - if (typeTensorType.hasDtype() && - (!subtypeTensorType.hasDtype() || - typeTensorType.getDtype() != subtypeTensorType.getDtype())) { - return false; - } - - if (typeTensorType.hasSizes() && - (!subtypeTensorType.hasSizes() || - typeTensorType.getSizes() != subtypeTensorType.getSizes())) { - return false; - } - + // TODO: This is not subtyping according to PEP 483. See description + // of NonValueTensorType. + if (subtype.isa() && type.isa() && + type == + NonValueTensorType::getWithLeastStaticInformation(type.getContext())) + return true; + + if (subtype.isa() && type.isa() && + type == ValueTensorType::getWithLeastStaticInformation(type.getContext())) return true; - } return false; } @@ -479,7 +463,7 @@ Type Torch::meetTensorTypes(BaseTensorType lhs, BaseTensorType rhs) { } } - return lhs.getWithSizesAndDtype(ArrayRef(newSizes), dtype); + return lhs.getWithSizesAndDtype(makeArrayRef(newSizes), dtype); } ////===----------------------------------------------------------------------===// @@ -521,4 +505,4 @@ DictType::verify(llvm::function_ref emitError, return failure(); } return success(); -} +} \ No newline at end of file diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index c121877bb..082720ac3 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -4088,259 +4088,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " } : (!torch.int, !torch.bool) -> ()\n" " return %none : !torch.none\n" " }\n" -" func.func @__torch__.torch.jit._shape_functions.stack(%arg0: !torch.list>, %arg1: !torch.int) -> !torch.list {\n" -" %false = torch.constant.bool false\n" -" %str = torch.constant.str \"AssertionError: Tensors must have same number of dimensions\"\n" -" %str_0 = torch.constant.str \"AssertionError: Sizes of tensors must match except in dimension\"\n" -" %int1 = torch.constant.int 1\n" -" %int0 = torch.constant.int 0\n" -" %str_1 = torch.constant.str \"AssertionError: \"\n" -" %none = torch.constant.none\n" -" %true = torch.constant.bool true\n" -" %0 = torch.prim.ListConstruct : () -> !torch.list>\n" -" %1 = torch.aten.len.t %arg0 : !torch.list> -> !torch.int\n" -" torch.prim.Loop %1, %true, init() {\n" -" ^bb0(%arg2: !torch.int):\n" -" %16 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list>, !torch.int -> !torch.list\n" -" %17 = torch.aten.len.t %16 : !torch.list -> !torch.int\n" -" %18 = torch.aten.add.int %17, %int1 : !torch.int, !torch.int -> !torch.int\n" -" %19 = torch.aten.le.int %18, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" %20 = torch.prim.If %19 -> (!torch.int) {\n" -" torch.prim.If.yield %int1 : !torch.int\n" -" } else {\n" -" torch.prim.If.yield %18 : !torch.int\n" -" }\n" -" %21 = torch.aten.neg.int %20 : !torch.int -> !torch.int\n" -" %22 = torch.aten.sub.int %20, %int1 : !torch.int, !torch.int -> !torch.int\n" -" %23 = torch.aten.lt.int %arg1, %21 : !torch.int, !torch.int -> !torch.bool\n" -" %24 = torch.prim.If %23 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %31 = torch.aten.gt.int %arg1, %22 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %31 : !torch.bool\n" -" }\n" -" %25 = torch.aten.__not__ %24 : !torch.bool -> !torch.bool\n" -" torch.prim.If %25 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %26 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" %27 = torch.prim.If %26 -> (!torch.int) {\n" -" %31 = torch.aten.add.int %arg1, %20 : !torch.int, !torch.int -> !torch.int\n" -" torch.prim.If.yield %31 : !torch.int\n" -" } else {\n" -" torch.prim.If.yield %arg1 : !torch.int\n" -" }\n" -" %28 = torch.prim.ListConstruct : () -> !torch.list\n" -" %29 = torch.aten.len.t %16 : !torch.list -> !torch.int\n" -" torch.prim.Loop %29, %true, init() {\n" -" ^bb0(%arg3: !torch.int):\n" -" %31 = torch.aten.__getitem__.t %16, %arg3 : !torch.list, !torch.int -> !torch.int\n" -" %32 = torch.aten.append.t %28, %31 : !torch.list, !torch.int -> !torch.list\n" -" torch.prim.Loop.condition %true, iter()\n" -" } : (!torch.int, !torch.bool) -> ()\n" -" torch.aten.insert.t %28, %27, %int1 : !torch.list, !torch.int, !torch.int\n" -" %30 = torch.aten.append.t %0, %28 : !torch.list>, !torch.list -> !torch.list>\n" -" torch.prim.Loop.condition %true, iter()\n" -" } : (!torch.int, !torch.bool) -> ()\n" -" %2 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list\n" -" %3 = torch.aten.len.t %0 : !torch.list> -> !torch.int\n" -" torch.prim.Loop %3, %true, init() {\n" -" ^bb0(%arg2: !torch.int):\n" -" %16 = torch.aten.__getitem__.t %0, %arg2 : !torch.list>, !torch.int -> !torch.list\n" -" %17 = torch.aten.len.t %16 : !torch.list -> !torch.int\n" -" %18 = torch.aten.gt.int %17, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %18 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" torch.prim.Loop.condition %true, iter()\n" -" } : (!torch.int, !torch.bool) -> ()\n" -" %4 = torch.aten.len.t %0 : !torch.list> -> !torch.int\n" -" %5 = torch.derefine %none : !torch.none to !torch.optional\n" -" %6 = torch.prim.Loop %4, %true, init(%5) {\n" -" ^bb0(%arg2: !torch.int, %arg3: !torch.optional):\n" -" %16 = torch.aten.__getitem__.t %0, %arg2 : !torch.list>, !torch.int -> !torch.list\n" -" %17 = torch.aten.len.t %16 : !torch.list -> !torch.int\n" -" %18 = torch.aten.eq.int %17, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" %19 = torch.prim.If %18 -> (!torch.bool) {\n" -" %22 = torch.aten.__getitem__.t %16, %int0 : !torch.list, !torch.int -> !torch.int\n" -" %23 = torch.aten.eq.int %22, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %23 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" %20 = torch.aten.__not__ %19 : !torch.bool -> !torch.bool\n" -" %21 = torch.prim.If %20 -> (!torch.optional) {\n" -" %22 = torch.aten.__is__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" -" %23 = torch.prim.If %22 -> (!torch.int) {\n" -" %25 = torch.aten.len.t %16 : !torch.list -> !torch.int\n" -" %26 = torch.aten.le.int %25, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" %27 = torch.prim.If %26 -> (!torch.int) {\n" -" torch.prim.If.yield %int1 : !torch.int\n" -" } else {\n" -" torch.prim.If.yield %25 : !torch.int\n" -" }\n" -" %28 = torch.aten.neg.int %27 : !torch.int -> !torch.int\n" -" %29 = torch.aten.sub.int %27, %int1 : !torch.int, !torch.int -> !torch.int\n" -" %30 = torch.aten.lt.int %arg1, %28 : !torch.int, !torch.int -> !torch.bool\n" -" %31 = torch.prim.If %30 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %35 = torch.aten.gt.int %arg1, %29 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %35 : !torch.bool\n" -" }\n" -" %32 = torch.aten.__not__ %31 : !torch.bool -> !torch.bool\n" -" torch.prim.If %32 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %33 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" %34 = torch.prim.If %33 -> (!torch.int) {\n" -" %35 = torch.aten.add.int %arg1, %27 : !torch.int, !torch.int -> !torch.int\n" -" torch.prim.If.yield %35 : !torch.int\n" -" } else {\n" -" torch.prim.If.yield %arg1 : !torch.int\n" -" }\n" -" torch.prim.If.yield %34 : !torch.int\n" -" } else {\n" -" %25 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.int\n" -" torch.prim.If.yield %25 : !torch.int\n" -" }\n" -" %24 = torch.derefine %23 : !torch.int to !torch.optional\n" -" torch.prim.If.yield %24 : !torch.optional\n" -" } else {\n" -" torch.prim.If.yield %arg3 : !torch.optional\n" -" }\n" -" torch.prim.Loop.condition %true, iter(%21 : !torch.optional)\n" -" } : (!torch.int, !torch.bool, !torch.optional) -> !torch.optional\n" -" %7 = torch.aten.__is__ %6, %none : !torch.optional, !torch.none -> !torch.bool\n" -" %8 = torch.prim.If %7 -> (!torch.int) {\n" -" torch.prim.If.yield %arg1 : !torch.int\n" -" } else {\n" -" %16 = torch.prim.unchecked_cast %6 : !torch.optional -> !torch.int\n" -" torch.prim.If.yield %16 : !torch.int\n" -" }\n" -" %9 = torch.aten.len.t %0 : !torch.list> -> !torch.int\n" -" %10 = torch.aten.gt.int %9, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %10 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %11 = torch.aten.len.t %0 : !torch.list> -> !torch.int\n" -" %12 = torch.derefine %none : !torch.none to !torch.optional>\n" -" %13 = torch.prim.Loop %11, %true, init(%12) {\n" -" ^bb0(%arg2: !torch.int, %arg3: !torch.optional>):\n" -" %16 = torch.aten.__getitem__.t %0, %arg2 : !torch.list>, !torch.int -> !torch.list\n" -" %17 = torch.aten.len.t %16 : !torch.list -> !torch.int\n" -" %18 = torch.prim.Loop %17, %true, init(%int1) {\n" -" ^bb0(%arg4: !torch.int, %arg5: !torch.int):\n" -" %23 = torch.aten.__getitem__.t %16, %arg4 : !torch.list, !torch.int -> !torch.int\n" -" %24 = torch.aten.mul.int %arg5, %23 : !torch.int, !torch.int -> !torch.int\n" -" torch.prim.Loop.condition %true, iter(%24 : !torch.int)\n" -" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n" -" %19 = torch.aten.eq.int %18, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" %20 = torch.prim.If %19 -> (!torch.bool) {\n" -" %23 = torch.aten.len.t %16 : !torch.list -> !torch.int\n" -" %24 = torch.aten.eq.int %23, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %24 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" %21 = torch.aten.__not__ %20 : !torch.bool -> !torch.bool\n" -" %22 = torch.prim.If %21 -> (!torch.optional>) {\n" -" %23 = torch.derefine %16 : !torch.list to !torch.optional>\n" -" torch.prim.If.yield %23 : !torch.optional>\n" -" } else {\n" -" torch.prim.If.yield %arg3 : !torch.optional>\n" -" }\n" -" torch.prim.Loop.condition %true, iter(%22 : !torch.optional>)\n" -" } : (!torch.int, !torch.bool, !torch.optional>) -> !torch.optional>\n" -" %14 = torch.aten.__is__ %13, %none : !torch.optional>, !torch.none -> !torch.bool\n" -" %15 = torch.prim.If %14 -> (!torch.list) {\n" -" torch.prim.If.yield %2 : !torch.list\n" -" } else {\n" -" %16 = torch.prim.unchecked_cast %13 : !torch.optional> -> !torch.list\n" -" %17 = torch.aten.len.t %0 : !torch.list> -> !torch.int\n" -" %18 = torch.prim.Loop %17, %true, init(%int0) {\n" -" ^bb0(%arg2: !torch.int, %arg3: !torch.int):\n" -" %22 = torch.aten.__getitem__.t %0, %arg2 : !torch.list>, !torch.int -> !torch.list\n" -" %23 = torch.aten.len.t %22 : !torch.list -> !torch.int\n" -" %24 = torch.prim.Loop %23, %true, init(%int1) {\n" -" ^bb0(%arg4: !torch.int, %arg5: !torch.int):\n" -" %29 = torch.aten.__getitem__.t %22, %arg4 : !torch.list, !torch.int -> !torch.int\n" -" %30 = torch.aten.mul.int %arg5, %29 : !torch.int, !torch.int -> !torch.int\n" -" torch.prim.Loop.condition %true, iter(%30 : !torch.int)\n" -" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n" -" %25 = torch.aten.eq.int %24, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" %26 = torch.prim.If %25 -> (!torch.bool) {\n" -" %29 = torch.aten.len.t %22 : !torch.list -> !torch.int\n" -" %30 = torch.aten.eq.int %29, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %30 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" %27 = torch.aten.__not__ %26 : !torch.bool -> !torch.bool\n" -" %28 = torch.prim.If %27 -> (!torch.int) {\n" -" %29 = torch.aten.len.t %16 : !torch.list -> !torch.int\n" -" %30 = torch.aten.len.t %22 : !torch.list -> !torch.int\n" -" %31 = torch.aten.eq.int %29, %30 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %31 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %32 = torch.aten.__range_length %int0, %29, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" -" torch.prim.Loop %32, %true, init() {\n" -" ^bb0(%arg4: !torch.int):\n" -" %35 = torch.aten.__derive_index %arg4, %int0, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" -" %36 = torch.aten.ne.int %35, %8 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %36 -> () {\n" -" %37 = torch.aten.__getitem__.t %16, %35 : !torch.list, !torch.int -> !torch.int\n" -" %38 = torch.aten.__getitem__.t %22, %35 : !torch.list, !torch.int -> !torch.int\n" -" %39 = torch.aten.eq.int %37, %38 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %39 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.If.yield\n" -" }\n" -" torch.prim.Loop.condition %true, iter()\n" -" } : (!torch.int, !torch.bool) -> ()\n" -" %33 = torch.aten.__getitem__.t %22, %8 : !torch.list, !torch.int -> !torch.int\n" -" %34 = torch.aten.add.int %arg3, %33 : !torch.int, !torch.int -> !torch.int\n" -" torch.prim.If.yield %34 : !torch.int\n" -" } else {\n" -" torch.prim.If.yield %arg3 : !torch.int\n" -" }\n" -" torch.prim.Loop.condition %true, iter(%28 : !torch.int)\n" -" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n" -" %19 = torch.prim.ListConstruct : () -> !torch.list\n" -" %20 = torch.aten.len.t %16 : !torch.list -> !torch.int\n" -" torch.prim.Loop %20, %true, init() {\n" -" ^bb0(%arg2: !torch.int):\n" -" %22 = torch.aten.__getitem__.t %16, %arg2 : !torch.list, !torch.int -> !torch.int\n" -" %23 = torch.aten.append.t %19, %22 : !torch.list, !torch.int -> !torch.list\n" -" torch.prim.Loop.condition %true, iter()\n" -" } : (!torch.int, !torch.bool) -> ()\n" -" %21 = torch.aten._set_item.t %19, %8, %18 : !torch.list, !torch.int, !torch.int -> !torch.list\n" -" torch.prim.If.yield %19 : !torch.list\n" -" }\n" -" return %15 : !torch.list\n" -" }\n" " func.func @__torch__.torch.jit._shape_functions.permute(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %int0 = torch.constant.int 0\n" " %true = torch.constant.bool true\n" @@ -6043,10 +5790,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.hardtanh_backward\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float, %arg3: !torch.float) -> !torch.list {\n" -" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" -" return %0 : !torch.list\n" -" }\n" " func.func @\"__torch_mlir_shape_fn.aten.ceil\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -6134,10 +5877,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.bucketize.Tensor\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list {\n" -" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" -" return %0 : !torch.list\n" -" }\n" " func.func @\"__torch_mlir_shape_fn.aten.contiguous\"(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -6254,7 +5993,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = torch.prim.ListConstruct : () -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.prims.var\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.float, %arg3: !torch.optional) -> !torch.list {\n" +" func.func @\"__torch_mlir_shape_fn.prims.var\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.int, %arg3: !torch.optional) -> !torch.list {\n" " %none = torch.constant.none\n" " %false = torch.constant.bool false\n" " %0 = torch.derefine %none : !torch.none to !torch.any\n" @@ -6267,13 +6006,13 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" " return %1 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.var.correction\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.list {\n" +" func.func @\"__torch_mlir_shape_fn.aten.var.correction\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.list {\n" " %none = torch.constant.none\n" " %0 = torch.derefine %none : !torch.none to !torch.any\n" " %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" " return %1 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.var_mean.correction\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.tuple, list> {\n" +" func.func @\"__torch_mlir_shape_fn.aten.var_mean.correction\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.tuple, list> {\n" " %none = torch.constant.none\n" " %0 = torch.derefine %none : !torch.none to !torch.any\n" " %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" @@ -6296,7 +6035,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" " return %1 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.std.correction\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.list {\n" +" func.func @\"__torch_mlir_shape_fn.aten.std.correction\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.list {\n" " %none = torch.constant.none\n" " %0 = torch.derefine %none : !torch.none to !torch.any\n" " %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" @@ -6810,9 +6549,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_shape_fn.aten.new_empty\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.list {\n" " return %arg1 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.new_empty_strided\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.list {\n" -" return %arg1 : !torch.list\n" -" }\n" " func.func @\"__torch_mlir_shape_fn.aten._to_copy\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.bool, %arg6: !torch.optional) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -6847,9 +6583,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_shape_fn.aten.bernoulli.Tensor\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.any) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.bernoulli.p\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.any) -> !torch.list {\n" -" return %arg0 : !torch.list\n" -" }\n" " func.func @\"__torch_mlir_shape_fn.aten._index_put_impl\"(%arg0: !torch.list, %arg1: !torch.list>>, %arg2: !torch.list, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -6863,9 +6596,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_shape_fn.aten.rand_like\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.randn_like\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.list {\n" -" return %arg0 : !torch.list\n" -" }\n" " func.func @\"__torch_mlir_shape_fn.aten.randint.low\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.list, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.list {\n" " return %arg2 : !torch.list\n" " }\n" @@ -7151,9 +6881,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_shape_fn.aten.select_scatter\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.scatter_reduce.two\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.str, %arg5: !torch.bool) -> !torch.list {\n" -" return %arg0 : !torch.list\n" -" }\n" " func.func @\"__torch_mlir_shape_fn.aten.index_select\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.index_select(%arg0, %arg1, %arg2) : (!torch.list, !torch.int, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -7583,10 +7310,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.cat(%arg0, %arg1) : (!torch.list>, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.stack\"(%arg0: !torch.list>, %arg1: !torch.int) -> !torch.list {\n" -" %0 = call @__torch__.torch.jit._shape_functions.stack(%arg0, %arg1) : (!torch.list>, !torch.int) -> !torch.list\n" -" return %0 : !torch.list\n" -" }\n" " func.func @\"__torch_mlir_shape_fn.aten.fft_fft\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.int, %arg3: !torch.optional) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" @@ -7617,13 +7340,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" " return %2 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.norm.ScalarOpt_dim\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.list, %arg3: !torch.bool) -> !torch.list {\n" -" %int0 = torch.constant.int 0\n" -" %0 = torch.derefine %arg2 : !torch.list to !torch.optional>\n" -" %1 = torch.derefine %int0 : !torch.int to !torch.any\n" -" %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg3, %1) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" -" return %2 : !torch.list\n" -" }\n" " func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional, %arg3: !torch.optional) -> !torch.list {\n" " %int0 = torch.constant.int 0\n" " %int1 = torch.constant.int 1\n" @@ -8139,30 +7855,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.ge.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" -" %int11 = torch.constant.int 11\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" -" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" -" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" -" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" -" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" -" torch.prim.If %5 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" return %int11 : !torch.int\n" -" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.le.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " %none = torch.constant.none\n" @@ -8233,30 +7925,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.le.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" -" %int11 = torch.constant.int 11\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" -" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" -" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" -" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" -" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" -" torch.prim.If %5 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" return %int11 : !torch.int\n" -" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.ne.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" @@ -8976,6 +8644,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %false = torch.constant.bool false\n" +" %int15 = torch.constant.int 15\n" " %int5 = torch.constant.int 5\n" " %true = torch.constant.bool true\n" " %int4 = torch.constant.int 4\n" @@ -8990,7 +8659,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.If.yield %12 : !torch.bool\n" " }\n" " %4 = torch.prim.If %3 -> (!torch.bool) {\n" -" %11 = torch.prim.ListConstruct %int5 : (!torch.int) -> !torch.list\n" +" %11 = torch.prim.ListConstruct %int5, %int15 : (!torch.int, !torch.int) -> !torch.list\n" " %12 = torch.aten.__contains__.int_list %11, %0#1 : !torch.list, !torch.int -> !torch.bool\n" " %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" " torch.prim.If.yield %13 : !torch.bool\n" @@ -9012,7 +8681,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.If.yield %12 : !torch.bool\n" " }\n" " %7 = torch.prim.If %6 -> (!torch.bool) {\n" -" %11 = torch.prim.ListConstruct %int5 : (!torch.int) -> !torch.list\n" +" %11 = torch.prim.ListConstruct %int5, %int15 : (!torch.int, !torch.int) -> !torch.list\n" " %12 = torch.aten.__contains__.int_list %11, %1#1 : !torch.list, !torch.int -> !torch.bool\n" " %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" " torch.prim.If.yield %13 : !torch.bool\n" diff --git a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp index 7cc699c04..516758740 100644 --- a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp +++ b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp @@ -10,6 +10,7 @@ #include "PassDetail.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Transforms/DialectConversion.h" diff --git a/lib/Dialect/Torch/Transforms/CMakeLists.txt b/lib/Dialect/Torch/Transforms/CMakeLists.txt index ce577cf5b..77f504f08 100644 --- a/lib/Dialect/Torch/Transforms/CMakeLists.txt +++ b/lib/Dialect/Torch/Transforms/CMakeLists.txt @@ -9,7 +9,6 @@ add_mlir_library(TorchMLIRTorchPasses LowerToBackendContract.cpp MaximizeValueSemantics.cpp PrepareForGlobalizeObjectGraph.cpp - RecomposeComplexOps.cpp ReduceOpVariants.cpp RefinePublicReturn.cpp RefineTypes.cpp diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index f54198e50..0a98ce9bc 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -33,11 +33,9 @@ static bool isNoneOrFloatDtype(MLIRContext *context, Value dtype) { int64_t dtypeInt; if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt))) return false; - FailureOr resDtype = + Type resDtype = getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt); - if (failed(resDtype)) - return false; - return resDtype->isa(); + return resDtype.isa(); } // Helper function to compute the return type of the reduction function. @@ -72,7 +70,7 @@ static Type computeReductionType(PatternRewriter &rewriter, Operation *op, Type resultType = tensorType.getWithSizesAndDtype( sizes.size() == 0 ? std::optional>() - : llvm::ArrayRef(sizes), + : llvm::makeArrayRef(sizes), tensorType.getOptionalDtype()); return resultType; } @@ -108,7 +106,7 @@ static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc, valueType .getWithSizesAndDtype( !valueType.hasSizes() ? std::optional>() - : llvm::ArrayRef(valueType.getSizes()), + : llvm::makeArrayRef(valueType.getSizes()), IntegerType::get(op->getContext(), 64, IntegerType::Signed)) .cast(); return rewriter @@ -142,7 +140,7 @@ static Value createRank0Tensor(PatternRewriter &rewriter, Location loc, BaseTensorType inputType, Value scalar) { SmallVector sizes; Type rank0TensorTy = inputType.getWithSizesAndDtype( - ArrayRef(sizes), inputType.getOptionalDtype()); + makeArrayRef(sizes), inputType.getOptionalDtype()); Value dimList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(inputType.getContext())), ValueRange{}); @@ -171,37 +169,6 @@ static Value createSoftmaxBackwardCommonKernel(PatternRewriter &rewriter, return sub; } -// Helper function to unsqueeze the input tensor at given dim. -// Return the unsqueezed tensor or failure. -static FailureOr unsqueezeTensor(PatternRewriter &rewriter, - Operation *op, Value input, Value dim) { - BaseTensorType inputType = input.getType().cast(); - if (!inputType.hasSizes()) { - return rewriter.notifyMatchFailure(op, "input tensor must have size"); - } - - SmallVector unsqueezedShape; - ArrayRef inputShape = inputType.getSizes(); - // `input` has a reduced rank. Hence add 1. - int64_t unsqueezedRank = inputShape.size() + 1; - int64_t dimInt = 0; - if (matchPattern(dim, m_TorchConstantInt(&dimInt))) { - dimInt = toPositiveDim(dimInt, unsqueezedRank); - if (!isValidDim(dimInt, unsqueezedRank)) { - return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); - } - unsqueezedShape.append(inputShape.begin(), inputShape.end()); - unsqueezedShape.insert(unsqueezedShape.begin() + dimInt, 1); - } else { - unsqueezedShape.resize(unsqueezedRank, kUnknownSize); - } - Type unsqueezedType = inputType.getWithSizesAndDtype( - unsqueezedShape, inputType.getOptionalDtype()); - Value unsqueezed = rewriter.create( - op->getLoc(), unsqueezedType, input, dim); - return unsqueezed; -} - namespace { /// We decompose aten.amax into a set of aten.max.dim op(s) depending on the /// number of dimensions across which the max needs to be computed. @@ -291,15 +258,6 @@ public: Value dim = op.getDim(); Value self = op.getSelf(); - // convert `start` to non-negative: start += int(start < 0) * dimSize - Value zero = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); - Value isNegative = rewriter.create(loc, start, zero); - isNegative = rewriter.create(loc, isNegative); - Value dimSize = rewriter.create(loc, self, dim); - Value indexOffset = rewriter.create(loc, isNegative, dimSize); - start = rewriter.create(loc, start, indexOffset); - Value one = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); Value startPlusOne = @@ -637,128 +595,6 @@ public: }; } // namespace -// Decompose `aten.bucketize` into the following op sequence: -// -// def aten_bucketize(input, boundaries, out_int32, right): -// unsqz_input = input.unsqueeze(-1) -// if not right: -// comparison = unsqz_input <= boundaries -// else: -// comparison = unsqz_input < boundaries -// indices = torch.argmax(comparison.float(), dim=-1) -// within_bound = comparison[..., -1] -// result = torch.where(within_bound, indices, boundaries.shape[0]) -// if out_int32: -// result = result.int() -// return result -// -namespace { -class DecomposeAtenBucketizeTensorOp - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenBucketizeTensorOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - - Value input = op.getSelf(); - auto inputType = input.getType().cast(); - if (!inputType.hasSizes()) { - return rewriter.notifyMatchFailure( - op, "unimplemented: input must have known sizes"); - } - ArrayRef inputShape = inputType.getSizes(); - - Value boundaries = op.getBoundaries(); - auto boundariesType = boundaries.getType().cast(); - if (!boundariesType.hasSizes() || boundariesType.getSizes().size() != 1) { - return rewriter.notifyMatchFailure(op, - "unimplemented: boundaries must have " - "known sizes and must be a 1D array"); - } - int64_t boundariesSize = boundariesType.getSizes()[0]; - - bool outInt32; - if (!matchPattern(op.getOutInt32(), m_TorchConstantBool(&outInt32))) { - return rewriter.notifyMatchFailure( - op, "unimplemented: out_int32 must be a constant bool"); - } - - bool right; - if (!matchPattern(op.getRight(), m_TorchConstantBool(&right))) { - return rewriter.notifyMatchFailure( - op, "unimplemented: right must be a constant bool"); - } - - // unsqueeze input at the last dim to make it broadcastable with boundaries - Value constMinusOne = rewriter.create( - loc, rewriter.getI64IntegerAttr(-1)); - auto unsqzTensorInfo = - unsqueezeTensor(rewriter, op, input, /*dim=*/constMinusOne); - if (failed(unsqzTensorInfo)) { - return rewriter.notifyMatchFailure(op, - "cannot generate unsqueeze tensor"); - } - Value unsqzInput = *unsqzTensorInfo; - - // compare unsqueezed input with boundaries - SmallVector compareShape(inputShape); - compareShape.push_back(boundariesSize); - Type compareType = - inputType.getWithSizesAndDtype(compareShape, rewriter.getI1Type()); - Value compare; - if (!right) { - compare = rewriter.create(loc, compareType, unsqzInput, - boundaries); - } else { - compare = rewriter.create(loc, compareType, unsqzInput, - boundaries); - } - - // convert the comparison results to float32 as the argmax op input, - // which does not support integer dtype in LINALG backend - Value compareF32 = - convertTensorToDtype(rewriter, loc, compare, rewriter.getF32Type()); - - // get the first boundary index where the input element is less than (or - // equal to) the boundary value - Type indicesType = inputType.getWithSizesAndDtype( - inputShape, rewriter.getIntegerType(64, IntegerType::Signed)); - Value constFalse = rewriter.create(loc, false); - Value indices = rewriter.create(loc, indicesType, compareF32, - /*dim=*/constMinusOne, - /*keepdim=*/constFalse); - - // get the comparison results between each input element and the rightmost - // boundary value - Type withinUpperBoundType = - inputType.getWithSizesAndDtype(inputShape, rewriter.getI1Type()); - Value withinUpperBound = rewriter.create( - loc, withinUpperBoundType, compare, /*dim=*/constMinusOne, - /*index=*/constMinusOne); - - // If the input element is less than (or equal to) the rightmost boundary, - // take the max index as result. Otherwise, the element is beyond the - // rightmost boundary, so take the boundary size. - Value constZero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); - Value upperBound = - rewriter.create(loc, boundaries, /*dim=*/constZero); - Value result = rewriter.create( - loc, indicesType, withinUpperBound, indices, upperBound); - - if (outInt32) { - result = convertTensorToDtype( - rewriter, loc, result, - rewriter.getIntegerType(32, IntegerType::Signed)); - } - - rewriter.replaceOp(op, result); - return success(); - } -}; -} // namespace - // To avoid overflow we use the following decomposition rule: // x_max = aten.max(x, dim, keepdim=True)[0] // shifted = x - x_max @@ -1055,50 +891,6 @@ public: }; } // namespace -// Decompose `aten.stack` into `aten.unsqueeze` and `aten.cat`. -namespace { -class DecomposeAtenStackOp : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenStackOp op, - PatternRewriter &rewriter) const override { - SmallVector tensors; - if (!getListConstructElements(op.getTensors(), tensors)) { - return rewriter.notifyMatchFailure( - op, "unimplemented: the tensor list is not from list construct"); - } - // Ensure all tensors have known sizes - for (Value tensor : tensors) { - BaseTensorType tensorType = tensor.getType().cast(); - if (!tensorType.hasSizes()) { - return rewriter.notifyMatchFailure( - op, "unimplemented: one tensor does not have known sizes"); - } - } - - SmallVector unsqueezedTensors; - for (Value tensor : tensors) { - auto unsqueezedInfo = unsqueezeTensor(rewriter, op, tensor, op.getDim()); - if (failed(unsqueezedInfo)) { - return rewriter.notifyMatchFailure( - op, "cannot generate unsqueeze tensor op"); - } - unsqueezedTensors.push_back(*unsqueezedInfo); - } - - Type listElemType = - op.getType().cast().getWithSizesAndDtype( - /*optionalSizes=*/std::nullopt, /*optionalDtype=*/nullptr); - Type listType = Torch::ListType::get(listElemType); - Value unsqueezedTensorList = rewriter.create( - op.getLoc(), listType, unsqueezedTensors); - rewriter.replaceOpWithNewOp(op, op.getType(), - unsqueezedTensorList, op.getDim()); - return success(); - } -}; -} // namespace - // Decompose aten.roll into aten.slice and aten.cat ops. // https://pytorch.org/docs/stable/generated/torch.roll.html namespace { @@ -1137,7 +929,7 @@ public: SmallVector sizes; sizes.append(inputShape.begin(), inputShape.end()); sizes[cstDim] = kUnknownSize; - Type sliceTy = selfTy.getWithSizesAndDtype(llvm::ArrayRef(sizes), + Type sliceTy = selfTy.getWithSizesAndDtype(llvm::makeArrayRef(sizes), selfTy.getOptionalDtype()); Value slice0 = rewriter.create( loc, sliceTy, input, dim, negShift, constNone, constOne); @@ -1274,9 +1066,9 @@ public: Type dtype = self.getType().cast().getOptionalDtype(); Type unsqueezedType = ValueTensorType::get( - context, llvm::ArrayRef(unsqueezedIntSizes), dtype); - Type expandedType = - ValueTensorType::get(context, llvm::ArrayRef(expandedIntSizes), dtype); + context, llvm::makeArrayRef(unsqueezedIntSizes), dtype); + Type expandedType = ValueTensorType::get( + context, llvm::makeArrayRef(expandedIntSizes), dtype); auto listType = Torch::ListType::get(Torch::IntType::get(op.getContext())); Value unsqueezedDims = @@ -1434,25 +1226,6 @@ public: }; } // namespace -// Decompose aten.masked_fill.Scalar into aten.where.self op. -namespace { -class DecomposeAtenMaskedFillScalarOp - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenMaskedFillScalarOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - auto resType = op.getType().cast(); - Value mask = op.getMask(); - Value value = createRank0Tensor(rewriter, loc, resType, op.getValue()); - rewriter.replaceOpWithNewOp(op, resType, mask, - value, op.getSelf()); - return success(); - } -}; - -} // namespace // Decompose aten.convolution_overrideable to aten.convolution op. namespace { class DecomposeAtenConvolutionOverrideableOp @@ -2204,23 +1977,23 @@ public: // aten.bernoulli.float(x, p) = (randLike(float(x)) < tensor(p)).cast(type(x)). // Since the input x can be an integer tensor, it's important to cast it to // float type before passing it to the `aten.randLike` op. -template -class DecomposeAtenBernoulliLikeOp : public OpRewritePattern { +class DecomposeValsemVariantAtenBernoulliFloatOp + : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(BernoulliLikeOp op, + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ValsemVariantAtenBernoulliFloatOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getSelf(); Value p = op.getP(); - if (!op.getGenerator().getType().template isa()) + if (!op.getGenerator().getType().isa()) return rewriter.notifyMatchFailure( op, "The generator has to ben None because only global default " "generator is supported"); auto inputType = input.getType().cast(); SmallVector empty; - Type tensorType = inputType.getWithSizesAndDtype(llvm::ArrayRef(empty), + Type tensorType = inputType.getWithSizesAndDtype(llvm::makeArrayRef(empty), rewriter.getF64Type()); Value prob = rewriter.create(loc, tensorType, p); Value output; @@ -2298,8 +2071,8 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern { std::vector meanVarSizes(inputRank, 1); for (int i = 0; i < axis; i++) meanVarSizes[i] = input.getSizes()[i]; - auto meanVarType = input.getWithSizesAndDtype(llvm::ArrayRef(meanVarSizes), - input.getOptionalDtype()); + auto meanVarType = input.getWithSizesAndDtype( + llvm::makeArrayRef(meanVarSizes), input.getOptionalDtype()); auto nativeLayerNorm = rewriter.create( loc, op.getType(), meanVarType, meanVarType, op.getInput(), op.getNormalizedShape(), op.getWeight(), op.getBias(), op.getEps()); @@ -2536,7 +2309,7 @@ class DecomposeAtenNativeBatchNormOp runningStatsShapeInt[1] = kUnknownSize; Type dtype = input.getType().cast().getOptionalDtype(); Type reshapeType = ValueTensorType::get( - context, llvm::ArrayRef(runningStatsShapeInt), dtype); + context, llvm::makeArrayRef(runningStatsShapeInt), dtype); runningMean = rewriter.create(loc, reshapeType, runningMean, runningStatsSizeList); @@ -2682,7 +2455,8 @@ public: SmallVector empty; auto dtype = getTypeForTorchType(op.getContext(), op.getFillValue().getType()); - Type tensorType = outTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype); + Type tensorType = + outTy.getWithSizesAndDtype(llvm::makeArrayRef(empty), dtype); Value fillVal = rewriter.create(loc, tensorType, op.getFillValue()); fillVal = convertTensorToDtype(rewriter, loc, fillVal, outTy.getDtype()); @@ -2718,7 +2492,7 @@ public: SmallVector transposeShape = llvm::to_vector(llvm::reverse(weightType.getSizes())); Type transposeType = weightType.getWithSizesAndDtype( - llvm::ArrayRef(transposeShape), weightType.getOptionalDtype()); + llvm::makeArrayRef(transposeShape), weightType.getOptionalDtype()); Value transposeWeight = rewriter.create(loc, transposeType, weight); @@ -2788,7 +2562,8 @@ public: SmallVector empty; auto dtype = getTypeForTorchType(op.getContext(), op.getFillValue().getType()); - Type tensorType = outTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype); + Type tensorType = + outTy.getWithSizesAndDtype(llvm::makeArrayRef(empty), dtype); Value fillVal = rewriter.create( op.getLoc(), tensorType, op.getFillValue()); fillVal = @@ -3228,7 +3003,7 @@ class DecomposeAtenNumpyTOp : public OpRewritePattern { template static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter, - bool unbiased, double correction) { + bool unbiased, int64_t correction) { Location loc = op.getLoc(); Value self = op.getSelf(); Value dimList = op.getDim(); @@ -3314,22 +3089,19 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter, productDimSize = rewriter.create(loc, productDimSize, dimSize); } - productDimSize = rewriter.create(loc, productDimSize); - constantOne = rewriter.create( - loc, rewriter.getF64FloatAttr(1.0)); - Value cstCorrection = rewriter.create( - loc, rewriter.getF64FloatAttr(correction)); + Value cstCorrection = rewriter.create( + loc, rewriter.getI64IntegerAttr(correction)); // The `correction` value should be less than or equal to `productDimSize + // 1`. - Value productDimSizePlusOne = rewriter.create( - loc, productDimSize.getType(), productDimSize, constantOne); + Value productDimSizePlusOne = + rewriter.create(loc, productDimSize, constantOne); Value cond = - rewriter.create(loc, productDimSizePlusOne, cstCorrection); + rewriter.create(loc, productDimSizePlusOne, cstCorrection); rewriter.create( loc, cond, "correction value should be less than or equal to productDimSize + 1"); Value productDimSizeSubCorrection = - rewriter.create(loc, productDimSize, cstCorrection); + rewriter.create(loc, productDimSize, cstCorrection); Value result = rewriter.create(loc, newOutputType, squareSum, productDimSizeSubCorrection); result = @@ -3356,7 +3128,7 @@ public: return rewriter.notifyMatchFailure( op, "Only support constant unbiased for aten.var"); } - double correction = unbiased ? 1.0 : 0.0; + int64_t correction = unbiased ? 1 : 0; if (failed(calculateVariance(op, rewriter, unbiased, correction))) return rewriter.notifyMatchFailure(op, "invalid variance parameters"); @@ -3376,32 +3148,18 @@ public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenVarCorrectionOp op, PatternRewriter &rewriter) const override { - int64_t correctionValInt; - double correctionValFloat = 1.0; + int64_t correction; if (!op.getCorrection().getType().isa()) { - if (op.getCorrection().getType().isa()) { - if (!matchPattern(op.getCorrection(), - m_TorchConstantFloat(&correctionValFloat))) - return rewriter.notifyMatchFailure( - op, "Only support constant int or float correction value for " - "aten.var"); - } else if (op.getCorrection().getType().isa()) { - if (!matchPattern(op.getCorrection(), - m_TorchConstantInt(&correctionValInt))) - return rewriter.notifyMatchFailure( - op, "Only support constant int or float correction value for " - "aten.var"); - correctionValFloat = (double)correctionValInt; - } else { + if (!matchPattern(op.getCorrection(), m_TorchConstantInt(&correction))) return rewriter.notifyMatchFailure( - op, "unimplemented: correction value should be only constant int " - "or float for aten.var"); - } + op, "Only support constant int correction for aten.var"); + } else { + // The default value in case of `correction` being None is 1. + correction = 1; } - - bool unbiased = correctionValFloat == 0.0 ? false : true; + bool unbiased = correction == 0 ? false : true; if (failed(calculateVariance(op, rewriter, unbiased, - correctionValFloat))) + correction))) return rewriter.notifyMatchFailure(op, "invalid variance parameters"); return success(); } @@ -3426,13 +3184,29 @@ public: rewriter.create(loc, rewriter.getI64IntegerAttr(1)); Value startPlusOne = rewriter.create(loc, one.getType(), start, one); + BaseTensorType srcTensorType = src.getType().cast(); + SmallVector sizes; + if (!srcTensorType.hasSizes()) + return rewriter.notifyMatchFailure(op, "src tensor must have size"); - auto unsqueezedInfo = unsqueezeTensor(rewriter, op, src, dim); - if (failed(unsqueezedInfo)) { - return rewriter.notifyMatchFailure(op, - "cannot generate unsqueeze tensor op"); + ArrayRef srcShape = srcTensorType.getSizes(); + // `src` has a reduced rank. Hence add 1. + int64_t srcRank = srcShape.size() + 1; + int64_t dimInt = 0; + if (matchPattern(dim, m_TorchConstantInt(&dimInt))) { + dimInt = toPositiveDim(dimInt, srcRank); + if (!isValidDim(dimInt, srcRank)) + return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); + + sizes.append(srcShape.begin(), srcShape.end()); + sizes.insert(sizes.begin() + dimInt, 1); + + } else { + sizes.resize(srcShape.size() + 1, kUnknownSize); } - src = *unsqueezedInfo; + Type srcType = srcTensorType.getWithSizesAndDtype( + llvm::makeArrayRef(sizes), srcTensorType.getOptionalDtype()); + src = rewriter.create(loc, srcType, src, dim); rewriter.replaceOpWithNewOp( op, op.getSelf().getType(), self, src, dim, start, startPlusOne, /*step=*/one); @@ -3529,7 +3303,7 @@ public: op, "Expected the input tensor to have sizes"); BaseTensorType subType = inputType - .getWithSizesAndDtype(llvm::ArrayRef(inputType.getSizes()), + .getWithSizesAndDtype(llvm::makeArrayRef(inputType.getSizes()), resultType.getOptionalDtype()) .cast(); @@ -3556,29 +3330,6 @@ public: }; } // namespace -namespace { -// Decompose `aten.norm.ScalarOpt_dim` op to `aten.linalg_vector_norm` op -class DecomposeAtenNormScalarOptDimOp - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenNormScalarOptDimOp op, - PatternRewriter &rewriter) const override { - Location loc = op->getLoc(); - Value none = rewriter.create(loc); - Value ord = op.getP(); - if (ord.getType().isa()) { - ord = rewriter.create( - loc, rewriter.getF64FloatAttr(2.0)); - } - rewriter.replaceOpWithNewOp( - op, op.getType(), op.getSelf(), ord, op.getDim(), op.getKeepdim(), - /*dtype=*/none); - return success(); - } -}; -} // namespace - namespace { class DecomposeAtenRandintLowOp : public OpRewritePattern { public: @@ -3775,40 +3526,6 @@ public: }; } // namespace -namespace { -// Decompose `aten.randn_like` op into `aten.randn.generator` op. -class DecomposeAtenRandnLikeOp : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenRandnLikeOp op, - PatternRewriter &rewriter) const override { - // Only `none`, `contiguous` and `preserve` memory_format is supported. - if (!op.getMemoryFormat().getType().isa()) { - int64_t memoryFormat; - if (!matchPattern(op.getMemoryFormat(), - m_TorchConstantInt(&memoryFormat))) - return rewriter.notifyMatchFailure( - op, "unimplemented: the memory format should be specified in " - "an integer constant"); - if (memoryFormat != torch_upstream::MemoryFormat::Contiguous && - memoryFormat != torch_upstream::MemoryFormat::Preserve) - return rewriter.notifyMatchFailure( - op, "unimplemented: only none, contiguous and preserve " - "memory_format is supported"); - } - Value none = rewriter.create(op.getLoc()); - auto sizeListType = - Torch::ListType::get(Torch::IntType::get(op.getContext())); - Value sizeList = - rewriter.create(op.getLoc(), sizeListType, op.getSelf()); - rewriter.replaceOpWithNewOp( - op, op.getType(), sizeList, /*generator=*/none, op.getDtype(), - op.getLayout(), op.getDevice(), op.getPinMemory()); - return success(); - } -}; -} // namespace - namespace { class DecomposeAtenVarMeanOp : public OpRewritePattern { public: @@ -3829,49 +3546,6 @@ public: }; } // namespace -namespace { -class DecomposeAtenNewEmptyStridedOp - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenNewEmptyStridedOp op, - PatternRewriter &rewriter) const override { - SmallVector sizeListInts, strideListInts; - if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizeListInts))) - return rewriter.notifyMatchFailure( - op, "all size list elements must be constant ints"); - if (!matchPattern(op.getStride(), - m_TorchListOfConstantInts(strideListInts))) - return rewriter.notifyMatchFailure( - op, "all stride list elements must be constant ints"); - - // We only support the cases with default stride values. - // For ex: aten.new_empty_strided(self, size=[2, 3, 4], stride=[12, 4, 1]) - // Here the stride[0] == size[1] * size[2], stride[1] == size[2], and - // stride[2] == 1. - bool isDefaultStride = true; - for (unsigned i = 0; i < strideListInts.size(); i++) { - int64_t defaultStride = 1; - for (unsigned j = i + 1; j < sizeListInts.size(); j++) - defaultStride *= sizeListInts[j]; - if (defaultStride != strideListInts[i]) { - isDefaultStride = false; - break; - } - } - - if (!isDefaultStride) - return rewriter.notifyMatchFailure( - op, "only default strides supported for new_empty_strided op"); - - rewriter.replaceOpWithNewOp( - op, op.getType(), op.getSelf(), op.getSize(), op.getDtype(), - op.getLayout(), op.getDevice(), op.getPinMemory()); - return success(); - } -}; -} // namespace - namespace { class DecomposeComplexOpsPass : public DecomposeComplexOpsBase { @@ -3917,7 +3591,6 @@ public: DecomposeConstantTensorAllocLikeOp>(patterns); addPatternIfTargetOpIsIllegal< DecomposeConstantTensorAllocLikeOp>(patterns); - addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -3925,7 +3598,6 @@ public: addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal< DecomposeAtenConvolutionBackwardOverrideableOp>(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -3968,11 +3640,8 @@ public: addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal< - DecomposeAtenBernoulliLikeOp>( + addPatternIfTargetOpIsIllegal( patterns); - addPatternIfTargetOpIsIllegal< - DecomposeAtenBernoulliLikeOp>(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -4019,7 +3688,6 @@ public: addPatternIfTargetOpIsIllegal( patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -4027,12 +3695,9 @@ public: addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); GreedyRewriteConfig config; config.useTopDownTraversal = true; @@ -4050,4 +3715,4 @@ std::unique_ptr> mlir::torch::Torch::createDecomposeComplexOpsPass( ArrayRef legalOps) { return std::make_unique(legalOps); -} +} \ No newline at end of file diff --git a/lib/Dialect/Torch/Transforms/EraseModuleInitializer.cpp b/lib/Dialect/Torch/Transforms/EraseModuleInitializer.cpp index b5dcbbf58..450d84b22 100644 --- a/lib/Dialect/Torch/Transforms/EraseModuleInitializer.cpp +++ b/lib/Dialect/Torch/Transforms/EraseModuleInitializer.cpp @@ -10,9 +10,9 @@ #include "PassDetail.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/IRMapping.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" diff --git a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp index da8be9b17..a93db8d30 100644 --- a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp +++ b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp @@ -10,9 +10,9 @@ #include "PassDetail.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/IRMapping.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" @@ -244,7 +244,7 @@ createGlobalSlotModuleInitializer(ModuleOp module, SymbolTable &symbolTable, continue; opsToMove.push_back(&op); } - IRMapping mapping; + BlockAndValueMapping mapping; for (Operation *op : opsToMove) { // The ops are used by `torch.slot` ops in the enclosing module. // Cloning avoids needing to handle those uses specially. @@ -329,7 +329,7 @@ template <> struct llvm::DenseMapInfo { // currently only analyzes a subset of ops. static LogicalResult analyzeInstances(func::FuncOp func, ArrayRef argInstances, - IRMapping &mapping) { + BlockAndValueMapping &mapping) { for (auto &argInstance : argInstances) mapping.map(func.getArgument(argInstance.argIndex), argInstance.instance); auto walkResult = func.walk([&](PrimGetAttrOp op) { @@ -349,7 +349,7 @@ static LogicalResult analyzeInstances(func::FuncOp func, } static FailureOr -createMonomorphizationForCall(func::CallOp op, IRMapping &mapping, +createMonomorphizationForCall(func::CallOp op, BlockAndValueMapping &mapping, SymbolTable &symbolTable) { auto func = symbolTable.lookup(op.getCallee()); Monomorphization monomorphization; @@ -410,7 +410,7 @@ public: private: LogicalResult generateNewMonomorphizations(const Monomorphization &m) { auto func = m.func; - IRMapping mapping; + BlockAndValueMapping mapping; if (failed(analyzeInstances(func, m.argInstances, mapping))) return failure(); auto walkResult = func.walk([&](func::CallOp op) { @@ -495,7 +495,7 @@ verifyPublicMonomorphizations(ModuleOp module, SymbolTable &symbolTable, // Rewrite `func`, given that all values of `NnModuleType` have been mapped in // `mapping` to corresponding global instances. static LogicalResult rewriteMonomorphizedFuncClone( - func::FuncOp func, IRMapping mapping, SymbolTable &symbolTable, + func::FuncOp func, BlockAndValueMapping mapping, SymbolTable &symbolTable, DenseMap &newFuncs, ObjectGraphInfo &objectGraphInfo) { @@ -662,7 +662,7 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) { } for (auto &kv : newFuncs) { - IRMapping mapping; + BlockAndValueMapping mapping; if (failed(analyzeInstances(kv.second, kv.first.argInstances, mapping))) return failure(); if (failed(rewriteMonomorphizedFuncClone(kv.second, mapping, symbolTable, diff --git a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp index 76b57fe8c..e48055570 100644 --- a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp +++ b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp @@ -27,8 +27,8 @@ #include "mlir/Analysis/DataFlowFramework.h" #include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/IRMapping.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" @@ -373,7 +373,7 @@ class InlineGlobalSlotsPass // big deal. SmallVector slice = getBackwardSliceIncludingRoot(initialValue); - IRMapping mapping; + BlockAndValueMapping mapping; OpBuilder builder(op); for (Operation *opInSlice : slice) builder.clone(*opInSlice, mapping); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 1f21a3656..a2db26627 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -285,16 +285,19 @@ public: } }; -class VerifyBackendContractNoDecompositionsPass - : public VerifyBackendContractNoDecompositionsBase { +class VerifyBackendContractPass + : public VerifyBackendContractBase { public: - VerifyBackendContractNoDecompositionsPass() = default; - + VerifyBackendContractPass() = default; + VerifyBackendContractPass(bool decompose, + ArrayRef backendLegalOps) { + this->decompose = decompose; + this->backendLegalOps = backendLegalOps; + } void runOnOperation() override { MLIRContext *context = &getContext(); ConversionTarget target = - getBackendContractTarget(context, /*decompose*/false, - /*backendLegalOps*/{}); + getBackendContractTarget(context, decompose, backendLegalOps); if (!satisfiesBackendContract(getOperation(), target, /*actuallyEmitDiagnostics=*/true)) { @@ -312,8 +315,10 @@ mlir::torch::Torch::createLowerToBackendContractPass( } std::unique_ptr> -mlir::torch::Torch::createVerifyBackendContractNoDecompositionsPass() { - return std::make_unique(); +mlir::torch::Torch::createVerifyBackendContractPass( + bool decompose, ArrayRef backendLegalOps) { + return std::make_unique(decompose, + backendLegalOps); } // The backend contract guarantees that ops with decompositions available will @@ -342,7 +347,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); - target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -350,7 +354,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); - target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -359,7 +362,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); - target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -392,7 +394,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); - target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -441,10 +442,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); - target.addIllegalOp(); target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); for (std::string opName : backendLegalOps) { target.addLegalOp(OperationName(opName, context)); } diff --git a/lib/Dialect/Torch/Transforms/Passes.cpp b/lib/Dialect/Torch/Transforms/Passes.cpp index 934ff7c25..4455ec1a7 100644 --- a/lib/Dialect/Torch/Transforms/Passes.cpp +++ b/lib/Dialect/Torch/Transforms/Passes.cpp @@ -106,7 +106,6 @@ void mlir::torch::Torch::createTorchSimplificationPipeline( // Clean up again to avoid needing to to back around the fixed-point // iteration. pm.addNestedPass(createCanonicalizerPass()); - pm.addNestedPass(createRecomposeComplexOps()); // Reduce variants of ops to a smaller set of primitives. pm.addNestedPass(createReduceOpVariantsPass()); pm.addNestedPass(createCanonicalizerPass()); diff --git a/lib/Dialect/Torch/Transforms/PrepareForGlobalizeObjectGraph.cpp b/lib/Dialect/Torch/Transforms/PrepareForGlobalizeObjectGraph.cpp index 279cbc41d..eef98ee54 100644 --- a/lib/Dialect/Torch/Transforms/PrepareForGlobalizeObjectGraph.cpp +++ b/lib/Dialect/Torch/Transforms/PrepareForGlobalizeObjectGraph.cpp @@ -10,6 +10,7 @@ #include "PassDetail.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Transforms/DialectConversion.h" diff --git a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp deleted file mode 100644 index 7a5269946..000000000 --- a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp +++ /dev/null @@ -1,103 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Also available under a BSD-style license. See LICENSE. -// -//===----------------------------------------------------------------------===// - -#include "PassDetail.h" - -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" -#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" - -using namespace mlir; -using namespace mlir::torch; -using namespace mlir::torch::Torch; - -namespace { -class RecomposeSliceCopy_ : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenCopy_Op op, - PatternRewriter &rewriter) const override { - if (!op.getSelf().getDefiningOp() || - !isa(op.getSelf().getDefiningOp())) - return failure(); - auto sliceOp = cast(op.getSelf().getDefiningOp()); - - // Get indices - int64_t dim; - if (!matchPattern(sliceOp.getDim(), m_TorchConstantInt(&dim))) - return failure(); - int64_t end; - if (!matchPattern(sliceOp.getEnd(), m_TorchConstantInt(&end))) - return failure(); - - Value newEnd = sliceOp.getEnd(); - if (end < 0) { - Value dimSize = rewriter.create( - op.getLoc(), sliceOp.getSelf(), sliceOp.getDim()); - newEnd = - rewriter.create(op.getLoc(), dimSize, sliceOp.getEnd()); - } - - Value noneVal = rewriter.create(op.getLoc()); - Value falseVal = rewriter.create(op.getLoc(), false); - - // Create IndexPut_Op - BaseTensorType tensorType = op->getResultTypes()[0].cast(); - Value range = rewriter.create( - op.getLoc(), tensorType, sliceOp.getStart(), newEnd, sliceOp.getStep(), - /*dtype=*/noneVal, /*layout=*/noneVal, /*device=*/noneVal, - /*pin_memory=*/noneVal); - - SmallVector indicesVector; - for (auto i = 0; i < dim - 1; i++) - indicesVector.push_back(noneVal); - indicesVector.push_back(range); - Value indices = rewriter.create( - op.getLoc(), - Torch::ListType::get(op->getContext(), - Torch::OptionalType::get(tensorType)), - indicesVector); - - rewriter.replaceOpWithNewOp( - op, op->getResultTypes(), sliceOp.getSelf(), indices, op.getSrc(), - /*accumulate=*/falseVal, /*unsafe=*/falseVal); - - return success(); - } -}; -} // namespace - -namespace { -class RecomposeComplexOps - : public DecomposeComplexOpsBase { -public: - RecomposeComplexOps() = default; - void runOnOperation() override { - MLIRContext *context = &getContext(); - RewritePatternSet patterns(context); - - // pattern.add calls go here - patterns.add(context); - - GreedyRewriteConfig config; - config.useTopDownTraversal = true; - config.maxIterations = GreedyRewriteConfig::kNoLimit; - - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config))) { - return signalPassFailure(); - } - } -}; -} // namespace - -std::unique_ptr> -mlir::torch::Torch::createRecomposeComplexOps() { - return std::make_unique(); -} diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 550f0de5a..6427f9ad0 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -59,6 +59,7 @@ #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" #include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinOps.h" @@ -80,9 +81,7 @@ using namespace mlir::torch::Torch; // ----------------------------------------------------------------------------- static Type getTypeForDTypeInteger(MLIRContext *context, int64_t dtypeInt) { - FailureOr result = - getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt); - return failed(result) ? Type() : *result; + return getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt); } static Type getDtypeOrDefault(MLIRContext *context, Value optionalDtype, @@ -112,6 +111,24 @@ static torch_upstream::TypeKind getTypeKind(Type type) { return torch_upstream::TypeKind::AnyType; } +/// Returns the dtype that assumes information from both `lhs` and `rhs`. +/// Returns `std::nullopt` if the types are contradictory. Note this can only +/// be used on the `dtype` from tensors and can't be used on other types like +/// scalar types. +static std::optional meetElementTypes(Type lhs, Type rhs) { + auto isNullOrBuiltIn = [](Type type) { return !type || isBuiltInType(type); }; + (void)isNullOrBuiltIn; + assert(isNullOrBuiltIn(lhs) && "`lhs` must be a builtin type"); + assert(isNullOrBuiltIn(rhs) && "`rhs` must be a builtin type"); + + if (!lhs) + return rhs; + if (!rhs) + return lhs; + if (lhs == rhs) + return lhs; + return std::nullopt; +} enum class OptionalKnowledge { unKnown, @@ -458,8 +475,7 @@ private: void visitAtenToDtypeLikeOp(OpTy op, ArrayRef operands); template void visitTypeConversionOp(OpTy op, ArrayRef operands); - template - void visitAtenCatLikeOp(OpTy op, ArrayRef operands); + void visitAtenCatOp(AtenCatOp op, ArrayRef operands); template void visitAtenSoftmaxLikeOp(OpTy op, ArrayRef operands); @@ -547,9 +563,7 @@ static Type getPromotedResultDType(ValueKnowledge *tensor, Type scalarType) { /*skipRankCheck=*/true); state = updateResultTypeState(getDefaultDtypeForTorchScalar(scalarType), state); - FailureOr result = - getTypeForScalarType(scalarType.getContext(), result_type(state)); - return failed(result) ? Type() : *result; + return getTypeForScalarType(scalarType.getContext(), result_type(state)); } static SmallVector> @@ -586,8 +600,7 @@ static Type getPromotedResultType(MLIRContext *context, return Type(); state = updateResultTypeState(tensor, rankIsNonZero, state, skipRankCheck); } - FailureOr result = getTypeForScalarType(context, result_type(state)); - return failed(result) ? Type() : *result; + return getTypeForScalarType(context, result_type(state)); } static Type getPromotedResultTypeAssumingNonZeroRank( @@ -636,26 +649,23 @@ void TypeAnalysis::visitOperation(Operation *op, AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopyOp, AtenCumsumOp, AtenLayerNormOp, AtenClampOp, AtenClampMinOp, AtenClampMaxOp, AtenNegOp, AtenFloorOp, Aten_SoftmaxBackwardDataOp, AtenDropoutOp, - AtenTanhBackwardOp, AtenHardtanhBackwardOp, - Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp, AtenAbsOp, - AtenThresholdOp, AtenSquareOp, AtenUniformOp, AtenBernoulliOp, - AtenBernoulli_FloatOp, AtenBernoulliTensorOp, + AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp, + AtenAbsOp, AtenThresholdOp, AtenSquareOp, AtenUniformOp, + AtenBernoulliOp, AtenBernoulli_FloatOp, AtenBernoulliTensorOp, ValsemVariantAtenBernoulliFloatOp, AtenBernoulliTensorOp, - AtenBernoulliPOp, AtenFillScalarOp, AtenHardsigmoidOp, AtenCloneOp, - AtenHardswishOp, AtenSiluOp, AtenHardtanhOp, AtenMaskedSelectOp, - AtenMaxPool2dOp, AtenAvgPool2dOp, AtenAdaptiveAvgPool2dOp, - AtenFlattenUsingIntsOp, AtenSqueezeOp, AtenSqueezeDimOp, - AtenUnsqueezeOp, AtenViewOp, Aten_UnsafeViewOp, AtenReshapeOp, - Aten_ReshapeAliasOp, AtenResize_Op, AtenTransposeIntOp, AtenTOp, - AtenPermuteOp, AtenIndexSelectOp, AtenSelectIntOp, - AtenSelectScatterOp, AtenNarrowOp, AtenSliceTensorOp, - AtenScatterReduceTwoOp, AtenSliceScatterOp, AtenGatherOp, - AtenExpandOp, AtenExpandAsOp, AtenBroadcastToOp, AtenRepeatOp, - AtenConstantPadNdOp, AtenPadOp, AtenZero_Op, AtenIndexTensorOp, - Aten_IndexPutImplOp, AtenIndexPutOp, AtenCopyOp, AtenZeroOp, - AtenIndexPutHackedTwinOp, AtenPreluOp, AtenMaskedFillScalarOp, - AtenFlipOp, PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp, - AtenMaskedFillTensorOp, AtenRollOp, AtenPowTensorTensorOp, + AtenFillScalarOp, AtenHardsigmoidOp, AtenCloneOp, AtenHardswishOp, + AtenSiluOp, AtenHardtanhOp, AtenMaskedSelectOp, AtenMaxPool2dOp, + AtenAvgPool2dOp, AtenAdaptiveAvgPool2dOp, AtenFlattenUsingIntsOp, + AtenSqueezeOp, AtenSqueezeDimOp, AtenUnsqueezeOp, AtenViewOp, + Aten_UnsafeViewOp, AtenReshapeOp, Aten_ReshapeAliasOp, AtenResize_Op, + AtenTransposeIntOp, AtenTOp, AtenPermuteOp, AtenIndexSelectOp, + AtenSelectIntOp, AtenSelectScatterOp, AtenNarrowOp, AtenSliceTensorOp, + AtenSliceScatterOp, AtenGatherOp, AtenExpandOp, AtenExpandAsOp, + AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp, + AtenZero_Op, AtenIndexTensorOp, Aten_IndexPutImplOp, AtenIndexPutOp, + AtenCopyOp, AtenZeroOp, AtenIndexPutHackedTwinOp, AtenPreluOp, + AtenMaskedFillScalarOp, AtenFlipOp, PrimAbsScalarOp, AtenNumpyTOp, + AtenTriuOp, AtenMaskedFillTensorOp, AtenRollOp, AtenPowTensorTensorOp, AtenLiftFreshCopyOp, AtenIndexTensorHackedTwinOp, AtenUpsampleNearest2dOp, AtenMishOp, AtenRoundOp, AtenFillTensorOp, AtenUpsampleNearest2dBackwardOp, AtenLeakyReluBackwardOp>(op)) { @@ -960,16 +970,9 @@ void TypeAnalysis::visitOperation(Operation *op, } else if (auto newEmpty = dyn_cast(op)) { visitConstantTensorNewLikeOp(newEmpty, operands); return; - } else if (auto newEmptyStrided = dyn_cast(op)) { - visitConstantTensorNewLikeOp(newEmptyStrided, - operands); - return; } else if (auto randLike = dyn_cast(op)) { visitConstantTensorAllocLikeOp(randLike, operands); return; - } else if (auto randLike = dyn_cast(op)) { - visitConstantTensorAllocLikeOp(randLike, operands); - return; } else if (auto toCopy = dyn_cast(op)) { visitConstantTensorAllocLikeOp(toCopy, operands); return; @@ -1005,10 +1008,7 @@ void TypeAnalysis::visitOperation(Operation *op, } if (auto cat = dyn_cast(op)) { - visitAtenCatLikeOp(cat, operands); - return; - } else if (auto stack = dyn_cast(op)) { - visitAtenCatLikeOp(stack, operands); + visitAtenCatOp(cat, operands); return; } @@ -1114,22 +1114,6 @@ void TypeAnalysis::visitOperation(Operation *op, return; } - if (auto bucketize = dyn_cast(op)) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - bool outInt32; - if (matchPattern(bucketize.getOutInt32(), m_TorchConstantBool(&outInt32)) && - outInt32) { - knowledge.dtype = - IntegerType::get(op->getContext(), 32, IntegerType::Signed); - } else { - knowledge.dtype = - IntegerType::get(op->getContext(), 64, IntegerType::Signed); - } - incorporateKnowledge(bucketize.getResult(), knowledge); - return; - } - // Otherwise, this is an unknown operation, so reset the state. setAllToEntryStates(results); return; @@ -1354,26 +1338,30 @@ void TypeAnalysis::visitTypeConversionOp( // `torch.aten.cat` concatenates the given sequence of seq tensors in the given // dimension. The output has the same sizes as the input for all dimensions // except the given dimension. -template -void TypeAnalysis::visitAtenCatLikeOp(OpTy op, - ArrayRef operands) { +void TypeAnalysis::visitAtenCatOp(AtenCatOp op, + ArrayRef operands) { auto tensorList = op.getTensors(); auto knowledge = ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - auto listConstruct = tensorList.template getDefiningOp(); + auto listConstruct = tensorList.getDefiningOp(); if (!listConstruct) { incorporateKnowledge(op.getResult(), knowledge); return; } - SmallVector tensors = llvm::to_vector( - llvm::map_range(listConstruct.getElements(), [&](Value v) -> ValueKnowledge* { - return &getLatticeElement(v)->getValue(); + auto tensors = llvm::to_vector<4>( + llvm::map_range(listConstruct.getElements(), [&](Value v) -> ValueKnowledge { + return getLatticeElement(v)->getValue(); })); - - knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank( - op->getContext(), tensors); - incorporateKnowledge(op->getResult(0), knowledge); + for (auto tensor : tensors) { + auto newDtype = meetElementTypes(knowledge.dtype, tensor.dtype); + if (!newDtype.has_value()) { + incorporateKnowledge(op.getResult(), knowledge); + return; + } + knowledge.dtype = newDtype.value(); + } + incorporateKnowledge(op.getResult(), knowledge); } void TypeAnalysis::visitNumToTensorOp(PrimNumToTensorScalarOp op) { @@ -1448,16 +1436,12 @@ static Type getMostRefinedStaticType(Value v, DataFlowSolver &solver) { if (!latticeElement) return nullptr; const ValueKnowledge &knowledge = latticeElement->getValue(); - if (!knowledge.isInitialized) - return nullptr; return getRefinedTensorType(tensorType, knowledge); } else if (auto optionalType = v.getType().dyn_cast()) { const ValueState *latticeElement = solver.lookupState(v); if (!latticeElement) return nullptr; const ValueKnowledge &knowledge = latticeElement->getValue(); - if (!knowledge.isInitialized) - return nullptr; if (knowledge.optional == OptionalKnowledge::isNone) return Torch::NoneType::get(v.getContext()); else if (knowledge.optional == OptionalKnowledge::notNone) { @@ -1472,8 +1456,6 @@ static Type getMostRefinedStaticType(Value v, DataFlowSolver &solver) { if (!latticeElement) return nullptr; const ValueKnowledge &knowledge = latticeElement->getValue(); - if (!knowledge.isInitialized) - return nullptr; if (knowledge.kind == torch_upstream::TypeKind::IntType) return Torch::IntType::get(v.getContext()); if (knowledge.kind == torch_upstream::TypeKind::FloatType) diff --git a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp index 7b74e2c50..4e0411552 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp @@ -46,15 +46,10 @@ static LogicalResult refineDtypeCalculateResult(DtypeCalculateOp op, impliedTypeFromDtype = *torchType; } else if (auto originalResultType = result.getType().dyn_cast()) { - FailureOr builtinType = - getTypeForScalarType(op->getContext(), dtypeScalarType); - if (failed(builtinType)) { - return rewriter.notifyMatchFailure( - op, "Failed to convert `dtypeScalarType` to a builtin type"); - } impliedTypeFromDtype = originalResultType.cast().getWithSizesAndDtype( - originalResultType.getOptionalSizes(), *builtinType); + originalResultType.getOptionalSizes(), + getTypeForScalarType(op->getContext(), dtypeScalarType)); } else { return rewriter.notifyMatchFailure(op, "Unimplemented: Expected result type to " diff --git a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp index f8d3651d9..71d6731e1 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp @@ -10,7 +10,7 @@ #include "PassDetail.h" #include "SimplifyAbstractInterpCalculationsUtils.h" -#include "mlir/IR/IRMapping.h" +#include "mlir/IR/BlockAndValueMapping.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" @@ -47,7 +47,7 @@ public: Block *afterBlock = rewriter.splitBlock(op->getBlock(), op->getIterator()); SmallVector blocksToMerge; - IRMapping bvm; + BlockAndValueMapping bvm; // TODO: Helper for region().front() auto condition = cast(op.getRegion().front().getTerminator()); @@ -129,7 +129,8 @@ public: // Truncate the list of users to the number of users we're going to // interpret. allUsers.resize(numUsersToInterpret); - auto usersToInterpret = ArrayRef(allUsers).take_front(numUsersToInterpret); + auto usersToInterpret = + makeArrayRef(allUsers).take_front(numUsersToInterpret); // For each mutating op (which must be in the same block), we save the // current state of the list as a vector of Value's. These will then @@ -335,7 +336,7 @@ static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op, auto originalResultType = result.getType().cast(); auto impliedTypesFromShape = originalResultType.cast() - .getWithSizesAndDtype(ArrayRef(sizes), + .getWithSizesAndDtype(makeArrayRef(sizes), originalResultType.getOptionalDtype()) .cast(); diff --git a/lib/Dialect/Torch/Utils/TorchUpstream.cpp b/lib/Dialect/Torch/Utils/TorchUpstream.cpp index 2dce14ef9..37ffffabd 100644 --- a/lib/Dialect/Torch/Utils/TorchUpstream.cpp +++ b/lib/Dialect/Torch/Utils/TorchUpstream.cpp @@ -8,8 +8,6 @@ #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" -#include "llvm/Support/ErrorHandling.h" - namespace mlir { namespace torch { namespace torch_upstream { @@ -128,23 +126,6 @@ ScalarType result_type(const ResultTypeState &in_state) { combine_categories(in_state.zeroResult, in_state.wrappedResult)); } -ReductionType get_reduction_enum(const llvm::StringRef &reduce) { - if (reduce == "max" || reduce == "amax") { - return torch_upstream::ReductionType::MAX; - } else if (reduce == "mean") { - return torch_upstream::ReductionType::MEAN; - } else if (reduce == "min" || reduce == "amin") { - return torch_upstream::ReductionType::MIN; - } else if (reduce == "sum") { - return torch_upstream::ReductionType::SUM; - } else if (reduce == "prod") { - return torch_upstream::ReductionType::PROD; - } else { - llvm_unreachable( - "'reduce' argument must be either sum, prod, mean, amax or amin"); - } -} - } // namespace torch_upstream } // namespace torch } // namespace mlir diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index d7fdf9481..1d67b24e8 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -83,10 +83,9 @@ Type Torch::getTypeForTorchType( llvm::report_fatal_error("unhandled type for getTypeForTorchType"); } -FailureOr -Torch::getTypeForScalarType(MLIRContext *context, - torch_upstream::ScalarType dtypeInt, - mlir::IntegerType::SignednessSemantics signedness) { +Type Torch::getTypeForScalarType( + MLIRContext *context, torch_upstream::ScalarType dtypeInt, + mlir::IntegerType::SignednessSemantics signedness) { switch (dtypeInt) { case torch_upstream::ScalarType::Float: return Float32Type::get(context); @@ -111,8 +110,6 @@ Torch::getTypeForScalarType(MLIRContext *context, return mlir::ComplexType::get(Float64Type::get(context)); case torch_upstream::ScalarType::ComplexDouble: return mlir::ComplexType::get(Float128Type::get(context)); - case torch_upstream::ScalarType::Undefined: - return failure(); default: llvm::report_fatal_error("unhandled type for getTypeForScalarType"); } @@ -126,7 +123,6 @@ Torch::getTorchTypeForScalarType(MLIRContext *context, return Torch::FloatType::get(context); case torch_upstream::ScalarType::Long: return Torch::IntType::get(context); - case torch_upstream::ScalarType::Undefined: default: return failure(); } diff --git a/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp b/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp index 79b3d4229..f352e0175 100644 --- a/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp +++ b/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp @@ -32,11 +32,11 @@ namespace { struct TorchConversionInlinerInterface : public DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, - IRMapping &valueMapping) const final { + BlockAndValueMapping &valueMapping) const final { return true; } bool isLegalToInline(Operation *, Region *, bool wouldBeCloned, - IRMapping &) const final { + BlockAndValueMapping &) const final { return true; } }; diff --git a/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp b/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp index c858edb62..1b83cce37 100644 --- a/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp +++ b/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp @@ -75,8 +75,8 @@ LogicalResult FromBuiltinTensorOp::verify() { // FromI64Op //===----------------------------------------------------------------------===// -OpFoldResult FromI64Op::fold(FoldAdaptor adaptor) { - auto attr = adaptor.getOperand().dyn_cast_or_null(); +OpFoldResult FromI64Op::fold(llvm::ArrayRef operands) { + auto attr = operands[0].dyn_cast_or_null(); if (attr) { return attr; } else { @@ -88,8 +88,8 @@ OpFoldResult FromI64Op::fold(FoldAdaptor adaptor) { // ToI64Op //===----------------------------------------------------------------------===// -OpFoldResult ToI64Op::fold(FoldAdaptor adaptor) { - auto attr = adaptor.getOperand().dyn_cast_or_null(); +OpFoldResult ToI64Op::fold(llvm::ArrayRef operands) { + auto attr = operands[0].dyn_cast_or_null(); if (attr) { return attr; } else { @@ -101,8 +101,8 @@ OpFoldResult ToI64Op::fold(FoldAdaptor adaptor) { // ToF64Op //===----------------------------------------------------------------------===// -OpFoldResult ToF64Op::fold(FoldAdaptor adaptor) { - auto attr = adaptor.getOperand().dyn_cast_or_null(); +OpFoldResult ToF64Op::fold(llvm::ArrayRef operands) { + auto attr = operands[0].dyn_cast_or_null(); if (attr) { return attr; } else { @@ -114,8 +114,8 @@ OpFoldResult ToF64Op::fold(FoldAdaptor adaptor) { // FromF64Op //===----------------------------------------------------------------------===// -OpFoldResult FromF64Op::fold(FoldAdaptor adaptor) { - auto attr = adaptor.getOperand().dyn_cast_or_null(); +OpFoldResult FromF64Op::fold(llvm::ArrayRef operands) { + auto attr = operands[0].dyn_cast_or_null(); if (attr) { return attr; } else { diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp index 5f3a2609b..3794602a8 100644 --- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp +++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Transforms/DialectConversion.h" diff --git a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt index a5d5f9b70..eaa15b00e 100644 --- a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt +++ b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt @@ -11,7 +11,7 @@ set(LinkedLibs MLIRIR TorchMLIRTorchConversionToMLProgram MLIRMemRefTransforms) -if(TORCH_MLIR_ENABLE_STABLEHLO) +if(TORCH_MLIR_ENABLE_MHLO) list(APPEND LinkedLibs ChloPasses) endif() @@ -21,7 +21,7 @@ add_mlir_library(TorchMLIRTorchConversionPasses Passes.cpp VerifyLinalgOnTensorsBackendContract.cpp VerifyTosaBackendContract.cpp - VerifyStablehloBackendContract.cpp + VerifyMhloBackendContract.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/TorchConversion/Transforms diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index 14d8f360b..ffffce244 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -21,8 +21,9 @@ #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" #include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" -#ifdef TORCH_MLIR_ENABLE_STABLEHLO -#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" +#ifdef TORCH_MLIR_ENABLE_MHLO +#include "mhlo/transforms/passes.h" +#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #endif #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" @@ -52,13 +53,12 @@ void mlir::torch::registerTorchConversionPasses() { "Pipeline lowering torch backend contract to TOSA backend " "contract.", TorchConversion::createTorchBackendToTosaBackendPipeline); -#ifdef TORCH_MLIR_ENABLE_STABLEHLO - mlir::PassPipelineRegistration< - TorchConversion::StablehloBackendPipelineOptions>( - "torch-backend-to-stablehlo-backend-pipeline", - "Pipeline lowering torch backend contract to StableHLO backend " +#ifdef TORCH_MLIR_ENABLE_MHLO + mlir::PassPipelineRegistration( + "torch-backend-to-mhlo-backend-pipeline", + "Pipeline lowering torch backend contract to MHLO backend " "contract.", - TorchConversion::createTorchBackendToStablehloBackendPipeline); + TorchConversion::createTorchBackendToMhloBackendPipeline); #endif } @@ -121,12 +121,11 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline( pm.addPass(TorchConversion::createVerifyTosaBackendContractPass()); } -#ifdef TORCH_MLIR_ENABLE_STABLEHLO -void TorchConversion::createTorchBackendToStablehloBackendPipeline( +#ifdef TORCH_MLIR_ENABLE_MHLO +void TorchConversion::createTorchBackendToMhloBackendPipeline( OpPassManager &pm, - const TorchConversion::StablehloBackendPipelineOptions &options) { - // Generate Stablehlo ops. - pm.addNestedPass(createConvertTorchToStablehloPass( + const TorchConversion::MhloBackendPipelineOptions &options) { + pm.addNestedPass(createConvertTorchToMhloPass( options.enableStaticShape, options.enableI32Index)); // Clean up any non-canonical code introduced above.. @@ -134,13 +133,21 @@ void TorchConversion::createTorchBackendToStablehloBackendPipeline( // The resolution of `dim` ops tends to create identical ops. CSE them. pm.addNestedPass(createCSEPass()); + // Convert CHLO ops to MHLO ops + pm.addNestedPass(mhlo::createChloLegalizeToHloPass()); + // Clean up any non-canonical code introduced above.. + pm.addNestedPass(createCanonicalizerPass()); + // The resolution of `dim` ops tends to create identical ops. CSE them. + pm.addNestedPass(createCSEPass()); + // Finish the type conversion from `torch` types to the types of the - // StableHLO backend contract. + // MHLO backend contract. pm.addPass(TorchConversion::createFuncBackendTypeConversionPass()); pm.addNestedPass( TorchConversion::createFinalizingBackendTypeConversionPass()); - - // Verify that we have lowered to Stablehlo and Chlo ops. - pm.addPass(TorchConversion::createVerifyStablehloBackendContractPass()); + // Verify that we have lowered to the form that MHLO backends + // expect. This fails compilation (signalPassFailure) if the IR is not in the + // correct form. + pm.addPass(TorchConversion::createVerifyMhloBackendContractPass()); } #endif diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyMhloBackendContract.cpp similarity index 66% rename from lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp rename to lib/Dialect/TorchConversion/Transforms/VerifyMhloBackendContract.cpp index 888f29ade..aebf27599 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp +++ b/lib/Dialect/TorchConversion/Transforms/VerifyMhloBackendContract.cpp @@ -6,9 +6,10 @@ // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// -#ifdef TORCH_MLIR_ENABLE_STABLEHLO +#ifdef TORCH_MLIR_ENABLE_MHLO #include "PassDetail.h" +#include "mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Shape/IR/Shape.h" @@ -17,7 +18,6 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/Transforms/DialectConversion.h" #include "stablehlo/dialect/ChloOps.h" -#include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" using namespace mlir; @@ -25,15 +25,17 @@ using namespace mlir::torch; using namespace mlir::torch::TorchConversion; namespace { -class VerifyStablehloBackendContractPass - : public VerifyStablehloBackendContractBase< - VerifyStablehloBackendContractPass> { +class VerifyMhloBackendContractPass + : public VerifyMhloBackendContractBase { void runOnOperation() override { + MLIRContext *context = &getContext(); + auto module = getOperation(); TypeConverter converter; converter.addConversion([](Type type) -> Type { auto elemTy = type; - if (isa(type)) + if (isa(type)) { elemTy = type.cast().getElementType(); + } if (BaseMemRefType::isValidElementType(elemTy)) return type; return nullptr; @@ -41,7 +43,6 @@ class VerifyStablehloBackendContractPass auto opHasLegalTypes = [&](Operation *op) { return converter.isLegal(op); }; - MLIRContext *context = &getContext(); ConversionTarget target(*context); // Structural operations. @@ -49,16 +50,26 @@ class VerifyStablehloBackendContractPass // Shape operations. target.addDynamicallyLegalOp(opHasLegalTypes); + target.addLegalDialect(); target.addLegalDialect(); - target.addLegalDialect(); target.addLegalDialect(); target.addLegalDialect(); + + RewritePatternSet patterns(context); + if (failed(applyFullConversion(module, target, std::move(patterns)))) { + // We avoid `module.emitError()` so that mlir-print-op-on-diagnostics + // doesn't unnecessarily spew out the entire module. + emitError(module.getLoc()) + << "Module does not conform to the MHLO backend contract. " + "See dialect conversion legality information above."; + return signalPassFailure(); + } } }; } // namespace std::unique_ptr> -mlir::torch::TorchConversion::createVerifyStablehloBackendContractPass() { - return std::make_unique(); +mlir::torch::TorchConversion::createVerifyMhloBackendContractPass() { + return std::make_unique(); } -#endif // TORCH_MLIR_ENABLE_STABLEHLO +#endif // TORCH_MLIR_ENABLE_MHLO diff --git a/lib/InitAll.cpp b/lib/InitAll.cpp index 87a2b8f39..b01d62152 100644 --- a/lib/InitAll.cpp +++ b/lib/InitAll.cpp @@ -20,10 +20,6 @@ #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" #include "torch-mlir/RefBackend/Passes.h" -#ifdef TORCH_MLIR_ENABLE_STABLEHLO -#include "mhlo/transforms/passes.h" -#endif - void mlir::torch::registerAllDialects(mlir::DialectRegistry ®istry) { registry.insert(); registry.insert(); @@ -38,11 +34,4 @@ void mlir::torch::registerAllPasses() { mlir::torch::registerConversionPasses(); mlir::torch::RefBackend::registerRefBackendPasses(); mlir::torch::TMTensor::registerPasses(); - -#ifdef TORCH_MLIR_ENABLE_STABLEHLO - mlir::mhlo::registerSymbolicShapeOptimizationPass(); - mlir::mhlo::registerStablehloLegalizeToHloPass(); - mlir::mhlo::registerChloLegalizeToHloPass(); - mlir::mhlo::registerHloLegalizeToLinalgPass(); -#endif // TORCH_MLIR_ENABLE_STABLEHLO } diff --git a/lib/RefBackend/RefBackend.cpp b/lib/RefBackend/RefBackend.cpp index 597f46381..a8f6766f2 100644 --- a/lib/RefBackend/RefBackend.cpp +++ b/lib/RefBackend/RefBackend.cpp @@ -392,7 +392,7 @@ Operation *createLinalgCopyOp(OpBuilder &b, Location loc, Value from, loc, /*inputs=*/from, /*outputs=*/to, - /*indexingMaps=*/llvm::ArrayRef({id, id}), + /*indexingMaps=*/llvm::makeArrayRef({id, id}), /*iteratorTypes=*/iteratorTypes, [](OpBuilder &b, Location loc, ValueRange args) { b.create(loc, args.front()); diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 0ba37a8ec..62f9b9ec4 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -45,16 +45,14 @@ endif() declare_mlir_python_sources(TorchMLIRPythonSources) declare_mlir_python_sources(TorchMLIRPythonExtensions) -if (NOT TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS) - declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel - ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" - ADD_TO_PARENT TorchMLIRPythonSources - SOURCES - __init__.py - compiler_utils.py - dynamo.py - ) -endif() +declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel + ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" + ADD_TO_PARENT TorchMLIRPythonSources + SOURCES + __init__.py + compiler_utils.py + dynamo.py +) declare_mlir_python_sources(TorchMLIRPythonSources.Dialects ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" @@ -93,9 +91,7 @@ if(TORCH_MLIR_ENABLE_LTC) endif() # Reference backend has a separate check for TORCH_MLIR_ENABLE_LTC, since it # generates a dummy Python library when disabled. -if(NOT TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS) - add_subdirectory(torch_mlir/csrc/reference_lazy_backend) -endif() +add_subdirectory(torch_mlir/csrc/reference_lazy_backend) ################################################################################ # Optionally handle JIT IR importer. diff --git a/python/torch_mlir/__init__.py b/python/torch_mlir/__init__.py index 443512a6d..3f08bb173 100644 --- a/python/torch_mlir/__init__.py +++ b/python/torch_mlir/__init__.py @@ -44,9 +44,9 @@ class OutputType(Enum): # as taking the `TORCH` output type and lowering it to TOSA. TOSA = "tosa" - # This output type consists of `stablehlo` dialect ops. It can be thought of - # as taking the `TORCH` output type and lowering it to StableHLO. - STABLEHLO = "stablehlo" + # This output type consists of `mhlo` dialect ops. It can be thought of + # as taking the `TORCH` output type and lowering it to MHLO. + MHLO = "mhlo" # Raw output of the JIT IR importer. This is not expected to be useful # for end-users, but can be convenient for development or reporting bugs. @@ -242,7 +242,7 @@ class ExampleArgs: BACKEND_LEGAL_OPS = { OutputType.TOSA: ['torch.aten.flatten.using_ints', 'torch.aten.native_layer_norm', 'torch.aten.linear'], OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints', ], - OutputType.STABLEHLO: [], + OutputType.MHLO: [], } @@ -290,7 +290,7 @@ def compile(model: torch.nn.Module, # We only allow `backend_legal_ops` to be specified for the `"torch"` # output type because the other output types actually invoke their - # respective backends (Linalg, TOSA, or STABLEHLO), and those backends have + # respective backends (Linalg, TOSA, or MHLO), and those backends have # very specific requirements about the ops which are legal. # See `BACKEND_LEGAL_OPS` for more details. if backend_legal_ops is not None: @@ -404,14 +404,14 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with: print(mb.module) return mb.module - elif output_type == OutputType.STABLEHLO: + elif output_type == OutputType.MHLO: run_pipeline_with_repro_report( mb.module, - "builtin.module(torch-backend-to-stablehlo-backend-pipeline)", - "Lowering Torch Backend IR -> StableHLO Backend IR") + "builtin.module(torch-backend-to-mhlo-backend-pipeline)", + "Lowering Torch Backend IR -> MHLO Backend IR") if verbose: print("\n====================") - print("StableHLO Backend IR") + print("MHLO Backend IR") print(mb.module) return mb.module raise Exception(f"Unknown OutputType: {output_type}") diff --git a/python/torch_mlir/compiler_utils.py b/python/torch_mlir/compiler_utils.py index c275c0b2b..35b0151e9 100644 --- a/python/torch_mlir/compiler_utils.py +++ b/python/torch_mlir/compiler_utils.py @@ -44,7 +44,7 @@ def run_pipeline_with_repro_report(module, # Lower module in place to make it ready for compiler backends. with module.context: pm = PassManager.parse(pipeline) - pm.run(module.operation) + pm.run(module) except Exception as e: # TODO: More robust. # - don't arbitrarily clutter up /tmp. When a test suite has many diff --git a/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt b/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt index 3293c6e2f..68a604e28 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt +++ b/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt @@ -71,7 +71,6 @@ add_library(torch_mlir_ltc_backend SHARED mlir_node.cpp ops/device_data.cpp ops/generic.cpp - utils/jit_utils.cpp utils/tensor_utils.cpp ) target_compile_features(torch_mlir_ltc_backend PRIVATE cxx_std_17) diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp index 0182952f8..ec234dc77 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp @@ -16,9 +16,6 @@ #include #include #include "torch-mlir-c/Registration.h" -#include "torch-mlir-c/Transforms.h" -#include "mlir-c/IR.h" -#include "mlir-c/Pass.h" #include "../../dialects/torch/importer/jit_ir/csrc/function_importer.h" #include "backend_impl.h" @@ -26,7 +23,6 @@ #include "mlir_node.h" #include "utils/debug.h" #include "utils/exception.h" -#include "utils/jit_utils.h" #include "utils/string_utils.h" #include "utils/sys_utils.h" @@ -139,11 +135,6 @@ ComputationPtr TorchMlirLoweringContext::Build() { graph_->block()->registerOutput(output); } - // During operations lowering JIT may insert ScalarImplicit ops which output - // type !torch.number doesn't represent any existing MLIR type and should be - // refined either to Torch::IntType or Torch::FloatType. - torch::jit::ConvertScalarImplicit(graph_); - // Generate MLIR. MlirOperation func_op = torch_mlir::importJitFunctionAsFuncOp( /*context=*/mlir_context_, @@ -151,35 +142,12 @@ ComputationPtr TorchMlirLoweringContext::Build() { /*getArgAttribute=*/[](int) -> MlirAttribute { return {nullptr}; }, /*importOptions=*/{/*assumeTensorsHaveValueSemantics=*/true}); - - // Convert MlirOperation to MlirModule. - MlirLocation loc = mlirLocationUnknownGet(mlir_context_); - MlirModule module_op = mlirModuleCreateEmpty(loc); - MlirBlock block = mlirModuleGetBody(module_op); - mlirBlockAppendOwnedOperation(block, func_op); - - // Apply passes to verify generated MLIR. - auto pass_manager = mlirPassManagerCreate(mlir_context_); - mlirPassManagerAddOwnedPass( - pass_manager, - mlirCreateVerifyBackendContractNoDecompositions() - ); - - MlirLogicalResult result = mlirPassManagerRunOnOp( - pass_manager, - mlirModuleGetOperation(module_op) - ); - - if (mlirLogicalResultIsFailure(result)) { - throw std::runtime_error("MLIR verification has failed."); - } - - return CreateComputation(module_op); + return CreateComputation(func_op); } -ComputationPtr TorchMlirLoweringContext::CreateComputation(MlirModule module_op) { +ComputationPtr TorchMlirLoweringContext::CreateComputation(MlirOperation func_op) { return std::make_shared( - module_op, mlir_context_, graph_, parameter_names_, input_output_aliases_); + func_op, mlir_context_, graph_, parameter_names_, input_output_aliases_); } torch::jit::Value* TorchMlirLoweringContext::GetOutputOp(const Output& output) { @@ -327,11 +295,11 @@ void TorchMlirLoweringContext::RegisterMlirDialects() { /////////////////////////////////////////////////////////////////////////////// TorchMlirComputation::TorchMlirComputation( - MlirModule module_op, MlirContext mlir_context, + MlirOperation func_op, MlirContext mlir_context, const std::shared_ptr& graph, std::unordered_map parameters_map, InputOutputAliases input_output_aliases) - : module_op_(std::move(module_op)), mlir_context_(std::move(mlir_context)), + : func_op_(std::move(func_op)), mlir_context_(std::move(mlir_context)), graph_(graph), input_output_aliases_(input_output_aliases), parameters_map_(parameters_map) { @@ -372,14 +340,7 @@ std::shared_ptr TorchMlirComputation::graph() const { return graph_; } -MlirOperation TorchMlirComputation::func_op() const { - MlirBlock block = mlirModuleGetBody(module_op_); - return mlirBlockGetFirstOperation(block); -} - -MlirModule TorchMlirComputation::module_op() const { - return module_op_; -} +MlirOperation TorchMlirComputation::func_op() const { return func_op_; } MlirContext TorchMlirComputation::mlir_context() const { return mlir_context_; @@ -424,7 +385,7 @@ const std::string TorchMlirComputation::to_string() const { *ss_ptr << std::string(part.data, part.length); }; std::stringstream ss; - mlirOperationPrint(mlirModuleGetOperation(module_op_), print_callback, &ss); + mlirOperationPrint(func_op_, print_callback, &ss); return ss.str(); } diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h index f62a71ce7..61e18f410 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h @@ -73,7 +73,7 @@ public: // embedded builder (returned by the builder() API). torch::lazy::ComputationPtr Build() override; - virtual torch::lazy::ComputationPtr CreateComputation(MlirModule module_op); + virtual torch::lazy::ComputationPtr CreateComputation(MlirOperation func_op); // Retrieves the lowered operation for an output. If the requested output is // not available yet, the graph behind the output's Node is lowered, and the @@ -123,7 +123,7 @@ public: using InputOutputAlias = TorchMlirLoweringContext::InputOutputAlias; TorchMlirComputation( - MlirModule module_op, MlirContext mlir_context, + MlirOperation func_op, MlirContext mlir_context, const std::shared_ptr& graph, std::unordered_map parameters_map, InputOutputAliases input_output_aliases); @@ -142,8 +142,6 @@ public: MlirOperation func_op() const; - MlirModule module_op() const; - MlirContext mlir_context() const; virtual const std::string debug_string() const; @@ -157,7 +155,7 @@ protected: std::vector parameter_shapes_; Shape result_shape_; - MlirModule module_op_; + MlirOperation func_op_; MlirContext mlir_context_; std::shared_ptr graph_; InputOutputAliases input_output_aliases_; diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp index 6bed4513d..9e40285df 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp @@ -36,17 +36,6 @@ TorchMlirOpVector LowerTorchMlirBuiltin( const std::vector tensor_types, const std::vector& arguments, const std::vector& kwarguments) { - // Workaround for ListType::isSubtypeOfExt behavoir which leads to - // the problems with JIT schema matching, so we need to keep - // c10::ListType empty before magic_method->call function call. - auto dummy_graph = torch::jit::Graph(); - for (auto arg : arguments) { - torch::jit::Value* value = arg.value(dummy_graph); - if (value->type()->kind() == c10::TypeKind::ListType) { - value->setType(c10::ListType::create(c10::TensorType::get())); - } - } - auto builtin = std::make_shared(sym, at::nullopt); auto magic_method = std::make_shared("", builtin); diff --git a/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp b/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp index 1cbb07262..2061fbee3 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp @@ -49,12 +49,5 @@ std::vector compute_shape_where( return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_bucketize( - const at::Tensor& self, const at::Tensor& boundaries, bool out_int32, - bool right) { - auto dtype = out_int32 ? at::kInt : at::kLong; - return {Shape(dtype, self.sizes().vec())}; -} - } // namespace lazy } // namespace torch diff --git a/python/torch_mlir/csrc/base_lazy_backend/utils/jit_utils.cpp b/python/torch_mlir/csrc/base_lazy_backend/utils/jit_utils.cpp deleted file mode 100644 index 8d64f9fb7..000000000 --- a/python/torch_mlir/csrc/base_lazy_backend/utils/jit_utils.cpp +++ /dev/null @@ -1,45 +0,0 @@ -#include "jit_utils.h" - -#include - -#include - -namespace torch { -namespace jit { - -void ConvertScalarImplicit(std::shared_ptr& graph) { - DepthFirstGraphNodeIterator it(graph); - for (auto* node = it.next(); node != nullptr; node = it.next()) { - if (node->kind() != c10::aten::ScalarImplicit) { - continue; - } - - auto input = node->input(0); - auto scalar_type = input->type()->cast()->scalarType(); - TORCH_CHECK(scalar_type, "scalar type is not defined for input value"); - - NodeKind node_type; - TypePtr output_type; - if (c10::isIntegralType(*scalar_type, false)) { - node_type = c10::aten::IntImplicit; - output_type = IntType::get(); - } else if (c10::isFloatingType(*scalar_type)) { - node_type = c10::aten::FloatImplicit; - output_type = FloatType::get(); - } else { - throw std::runtime_error( - "Expected isIntegralType or isFloatingType"); - } - - Value * output = graph - ->create(node_type, {input}) - ->insertBefore(node) - ->output() - ->setType(output_type); - node->output()->replaceAllUsesWith(output); - node->destroy(); - } -} - -} // namespace jit -} // namespace torch diff --git a/python/torch_mlir/csrc/base_lazy_backend/utils/jit_utils.h b/python/torch_mlir/csrc/base_lazy_backend/utils/jit_utils.h deleted file mode 100644 index 2c4214cfc..000000000 --- a/python/torch_mlir/csrc/base_lazy_backend/utils/jit_utils.h +++ /dev/null @@ -1,10 +0,0 @@ -#include - -namespace torch { -namespace jit { - -// Convert ScalarImplicit to IntImplicit or FloatImplicit. -TORCH_API void ConvertScalarImplicit(std::shared_ptr& graph); - -} // namespace jit -} // namespace torch diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index 3acfae8df..3d69d52ca 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -127,9 +127,6 @@ def aten〇gelu_backward〡shape(grad_output: List[int], self: List[int], approx def aten〇leaky_relu_backward〡shape(grad_output: List[int], self: List[int], negative_slope: float, self_is_result: bool) -> List[int]: return upstream_shape_functions.unary(grad_output) -def aten〇hardtanh_backward〡shape(grad_output: List[int], self: List[int], min_val: float, max_val: float) -> List[int]: - return upstream_shape_functions.unary(grad_output) - def aten〇ceil〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -196,9 +193,6 @@ def aten〇dropout〡shape(input: List[int], p: float, train: bool) -> List[int] def aten〇gelu〡shape(self: List[int], approximate: str = "none") -> List[int]: return upstream_shape_functions.unary(self) -def aten〇bucketize〇Tensor〡shape(self: List[int], boundaries: List[int], out_int32: bool = False, right: bool = False) -> List[int]: - return upstream_shape_functions.unary(self) - def aten〇contiguous〡shape(self: List[int], memory_format: int = 0) -> List[int]: return upstream_shape_functions.unary(self) @@ -289,16 +283,16 @@ def aten〇mean〡shape(self: List[int], dtype: Optional[int] = None) -> List[in def aten〇var〡shape(self: List[int], unbiased: bool = True) -> List[int]: return [] -def prims〇var〡shape(inp: List[int], dims: Optional[List[int]], correction: float, output_dtype: Optional[int] = None) -> List[int]: +def prims〇var〡shape(inp: List[int], dims: Optional[List[int]], correction: int, output_dtype: Optional[int] = None) -> List[int]: return upstream_shape_functions.sum_mean_dim(inp, dims, False, None) def aten〇var〇dim〡shape(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None) -def aten〇var〇correction〡shape(self: List[int], dim: Optional[List[int]] = None, correction: Optional[float] = None, keepdim: bool = False) -> List[int]: +def aten〇var〇correction〡shape(self: List[int], dim: Optional[List[int]] = None, correction: Optional[int] = None, keepdim: bool = False) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None) -def aten〇var_mean〇correction〡shape(self: List[int], dim: Optional[List[int]] = None, correction: Optional[float] = None, keepdim: bool = False) -> Tuple[List[int], List[int]]: +def aten〇var_mean〇correction〡shape(self: List[int], dim: Optional[List[int]] = None, correction: Optional[int] = None, keepdim: bool = False) -> Tuple[List[int], List[int]]: out = upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None) return out, out @@ -311,7 +305,7 @@ def aten〇std〡shape(self: List[int], unbiased: bool = True) -> List[int]: def aten〇std〇dim〡shape(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None) -def aten〇std〇correction〡shape(self: List[int], dim: Optional[List[int]] = None, correction: Optional[float] = None, keepdim: bool = False) -> List[int]: +def aten〇std〇correction〡shape(self: List[int], dim: Optional[List[int]] = None, correction: Optional[int] = None, keepdim: bool = False) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None) def _reduce_along_dim(self: List[int], dim: int, keepdim: bool): @@ -552,9 +546,6 @@ def aten〇new_ones〡shape(self: List[int], size: List[int], dtype: Optional[in def aten〇new_empty〡shape(self: List[int], size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: return size -def aten〇new_empty_strided〡shape(self: List[int], size: List[int], stride: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: - return size - def aten〇_to_copy〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, non_blocking: bool = False, memory_format: Optional[int] = None) -> List[int]: return upstream_shape_functions.unary(self) @@ -586,9 +577,6 @@ def aten〇bernoulli〇float〡shape(self: List[int], p: float = 0.5, generator: def aten〇bernoulli〇Tensor〡shape(self: List[int], p: List[int], generator: Any = None) -> List[int]: return self -def aten〇bernoulli〇p〡shape(self: List[int], p: float, generator: Any = None) -> List[int]: - return self - def aten〇_index_put_impl〡shape(self: List[int], indices: List[Optional[List[int]]], values: List[int], accumulate: bool = False, unsafe: bool = False) -> List[int]: return upstream_shape_functions.unary(self) @@ -601,9 +589,6 @@ def aten〇cumsum〡shape(self: List[int], dim: int, dtype: Optional[int] = None def aten〇rand_like〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]: return self -def aten〇randn_like〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]: - return self - def aten〇randint〇low〡shape(low: int, high: int, size: List[int], dtype: Optional[int] = 4, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: return size @@ -814,9 +799,6 @@ def aten〇select〇int〡shape(self: List[int], dim: int, index: int) -> List[i def aten〇select_scatter〡shape(self: List[int], src: List[int], dim: int, index: int) -> List[int]: return self -def aten〇scatter_reduce〇two〡shape(self: List[int], dim: int, index: List[int], src: List[int], reduce: str, include_self: bool = True) -> List[int]: - return self - def aten〇index_select〡shape(self: List[int], dim: int, index: List[int]) -> List[int]: return upstream_shape_functions.index_select(self, dim, index) @@ -977,9 +959,6 @@ def aten〇index〇Tensor_hacked_twin〡shape(self: List[int], indices: List[Lis def aten〇cat〡shape(tensors: List[List[int]], dim: int = 0) -> List[int]: return upstream_shape_functions.cat(tensors, dim) -def aten〇stack〡shape(tensors: List[List[int]], dim: int = 0) -> List[int]: - return upstream_shape_functions.stack(tensors, dim) - def aten〇fft_fft〡shape(self: List[int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> List[int]: return self @@ -1031,9 +1010,6 @@ def aten〇linalg_vector_norm〡shape(self: List[int], ord: float = 2, dim: Opti def aten〇frobenius_norm〇dim〡shape(self: List[int], dim: List[int], keepdim: bool = False) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0) -def aten〇norm〇ScalarOpt_dim〡shape(self: List[int], p: Optional[float], dim: List[int], keepdim: bool = False) -> List[int]: - return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0) - def aten〇upsample_nearest2d〡shape(self: List[int], output_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> List[int]: return [self[0], self[1], output_size[0], output_size[1]] @@ -1340,15 +1316,6 @@ def aten〇gt〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtyp assert not is_complex_dtype(other_dtype), "`self` cannot be complex" return torch.bool -@check_dtype_function( - _check_two_tensor_op(input_error_types={torch.complex64, torch.complex128})) -def aten〇ge〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: - other_rank, other_dtype = other_rank_dtype - self_rank, self_dtype = self_rank_dtype - assert not is_complex_dtype(self_dtype), "`self` cannot be complex" - assert not is_complex_dtype(other_dtype), "`self` cannot be complex" - return torch.bool - @check_dtype_function( _check_tensors_with_the_same_dtype( num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0.0) + @@ -1394,15 +1361,6 @@ def aten〇lt〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtyp assert not is_complex_dtype(other_dtype), "`self` cannot be complex" return torch.bool -@check_dtype_function( - _check_two_tensor_op(input_error_types={torch.complex64, torch.complex128})) -def aten〇le〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: - other_rank, other_dtype = other_rank_dtype - self_rank, self_dtype = self_rank_dtype - assert not is_complex_dtype(self_dtype), "`self` cannot be complex" - assert not is_complex_dtype(other_dtype), "`self` cannot be complex" - return torch.bool - @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) @@ -1755,7 +1713,7 @@ def aten〇conv2d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: @check_dtype_function( _check_tensors_with_the_same_dtype( tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1)], - error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.float16}) + + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.float16, torch.bfloat16}) + [ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)), ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.bool)), ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)), @@ -1764,8 +1722,8 @@ def aten〇conv2d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: def aten〇conv_transpose2d〇input〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), output_padding: List[int] = (0, 0), groups: int = 1, dilation: List[int] = (1, 1)) -> int: input_rank, input_dtype = input_rank_dtype weight_rank, weight_dtype = weight_rank_dtype - assert (input_dtype == torch.int64 or not is_integer_dtype(input_dtype)) and input_dtype not in [torch.float16] - assert (weight_dtype == torch.int64 or not is_integer_dtype(weight_dtype)) and weight_dtype not in [torch.float16] + assert (input_dtype == torch.int64 or not is_integer_dtype(input_dtype)) and input_dtype not in [torch.float16, torch.bfloat16] + assert (weight_dtype == torch.int64 or not is_integer_dtype(weight_dtype)) and weight_dtype not in [torch.float16, torch.bfloat16] ranks: List[Optional[int]] = [input_rank, weight_rank] dtypes = [input_dtype, weight_dtype] return promote_dtypes(ranks, dtypes) diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py index 3820e311a..fa962f41d 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py @@ -183,7 +183,7 @@ def generate_library(globals_) -> str: mb.import_function(function) # Clean up the IR a bit before writing it out. pm = PassManager.parse("builtin.module(canonicalize)", context=mb.module.context) - pm.run(mb.module.operation) + pm.run(mb.module) # Munge the IR a bit to make it more systematically accessible. asm = mb.module.operation.get_asm() # We'd like a unique function prefix to avoid collisions with user- diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 56bb1ac25..de6510225 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -283,7 +283,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): "aten::masked_fill.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)", "aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", "aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)", - "aten::clamp.Tensor : (Tensor, Tensor?, Tensor?) -> (Tensor)", "aten::clamp_min : (Tensor, Scalar) -> (Tensor)", "aten::clamp_max : (Tensor, Scalar) -> (Tensor)", "aten::log2 : (Tensor) -> (Tensor)", @@ -318,7 +317,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::maximum : (Tensor, Tensor) -> (Tensor)") emit("aten::minimum : (Tensor, Tensor) -> (Tensor)") emit("aten::mish : (Tensor) -> (Tensor)") - emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True) + emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)") emit("aten::gelu : (Tensor, str) -> (Tensor)") emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)") emit("aten::pow.Tensor_Tensor : (Tensor, Tensor) -> (Tensor)") @@ -332,12 +331,10 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::rand_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)") emit("aten::bernoulli : (Tensor, Generator?) -> (Tensor)") emit("aten::bernoulli_.float : (Tensor, float, Generator?) -> (Tensor)") - emit("aten::bernoulli.p : (Tensor, float, Generator?) -> (Tensor)") emit("aten::randint.low : (int, int, int[], int?, int?, Device?, bool?) -> (Tensor)") emit_with_mutating_variants("aten::bernoulli.Tensor : (Tensor, Tensor, Generator?) -> (Tensor)") emit("aten::randn : (int[], int?, int?, Device?, bool?) -> (Tensor)") emit("aten::randn.generator : (int[], Generator?, int?, int?, Device?, bool?) -> (Tensor)") - emit("aten::randn_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)") emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)") emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True) @@ -416,11 +413,11 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::mean : (Tensor, int?) -> (Tensor)") emit("aten::std : (Tensor, bool) -> (Tensor)") emit("aten::std.dim : (Tensor, int[]?, bool, bool) -> (Tensor)") - emit("aten::std.correction : (Tensor, int[]?, Scalar?, bool) -> (Tensor)") + emit("aten::std.correction : (Tensor, int[]?, int?, bool) -> (Tensor)") emit("aten::var : (Tensor, bool) -> (Tensor)") emit("aten::var.dim : (Tensor, int[]?, bool, bool) -> (Tensor)") - emit("aten::var.correction : (Tensor, int[]?, Scalar?, bool) -> (Tensor)") - emit("aten::var_mean.correction : (Tensor, int[]?, Scalar?, bool) -> (Tensor, Tensor)") + emit("aten::var.correction : (Tensor, int[]?, int?, bool) -> (Tensor)") + emit("aten::var_mean.correction : (Tensor, int[]?, int?, bool) -> (Tensor, Tensor)") emit("aten::var_mean : (Tensor, bool) -> (Tensor, Tensor)") emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)") emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)") @@ -469,7 +466,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::_embedding_bag : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int) -> (Tensor, Tensor, Tensor, Tensor)") emit("aten::empty_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)") emit("aten::new_empty : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)") - emit("aten::new_empty_strided : (Tensor, int[], int[], int?, int?, Device?, bool?) -> (Tensor)") emit("aten::zeros_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)") emit("aten::ones_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)") emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)") @@ -489,6 +485,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)") emit("aten::select.int : (Tensor, int, int) -> (Tensor)") emit("aten::size.int : (Tensor, int) -> (int)", has_folder=True) + emit("aten::stack : (Tensor[], int) -> (Tensor)") emit("aten::sum : (Tensor, int?) -> (Tensor)") emit("aten::sum.dim_IntList : (Tensor, int[]?, bool, int?) -> (Tensor)") emit("aten::max : (Tensor) -> (Tensor)") @@ -510,8 +507,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::len.Tensor : (Tensor) -> (int)") emit("aten::cpu : (Tensor) -> (Tensor)") emit("aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)") - emit_with_mutating_variants("aten::scatter_add : (Tensor, int, Tensor, Tensor) -> (Tensor)") - emit_with_mutating_variants("aten::scatter_reduce.two : (Tensor, int, Tensor, Tensor, str, bool) -> (Tensor)") + emit("aten::scatter_add : (Tensor, int, Tensor, Tensor) -> (Tensor)") emit("aten::IntImplicit : (Tensor) -> (int)") emit("aten::FloatImplicit : (Tensor) -> (float)") emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)") @@ -562,7 +558,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): # List ops. emit("aten::cat : (Tensor[], int) -> (Tensor)", has_folder=True) - emit("aten::stack : (Tensor[], int) -> (Tensor)", has_folder=True) emit("aten::append.t : (t[], t) -> (t[])") emit("aten::add.t : (t[], t[]) -> (t[])", has_canonicalizer=True) emit("aten::eq.int_list : (int[], int[]) -> (bool)", has_folder=True) @@ -584,9 +579,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): # Type conversion ops. emit("aten::Float.Scalar : (Scalar) -> (float)", has_folder=True) emit("aten::Float.str : (str) -> (float)") - emit("aten::Int.float : (float) -> (int)", has_folder=True) + emit("aten::Int.float : (float) -> (int)") emit("aten::Int.Scalar : (Scalar) -> (int)", has_folder=True) - emit("aten::Int.bool : (bool) -> (int)", has_folder=True) # Primitive ops emit("aten::__range_length : (int, int, int) -> (int)", has_folder=True) @@ -607,7 +601,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::neg.int : (int) -> (int)", has_folder=True) emit("aten::log.int : (int) -> (float)") emit("aten::add.float_int : (float, int) -> (float)") - emit("aten::sub.float : (float, float) -> (float)", has_folder=True) + emit("aten::sub.float : (float, float) -> (float)") emit("aten::mul.float : (float, float) -> (float)") emit("aten::div.float : (float, float) -> (float)", has_folder=True) emit("aten::neg.float : (float) -> (float)") @@ -619,7 +613,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::ge.float_int : (float, int) -> (bool)") emit("aten::ne.float_int : (float, int) -> (bool)") emit("aten::gt.float_int : (float, int) -> (bool)") - emit("aten::pow.int_float : (int, float) -> (float)", has_folder=True) emit("aten::__and__.bool : (bool, bool) -> (bool)") emit("aten::ne.bool : (bool, bool) -> (bool)", has_folder=True) emit("aten::__is__ : (t1, t2) -> (bool)", has_folder=True) @@ -641,12 +634,11 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::eq.device : (Device, Device) -> (bool)") emit("aten::ceil.float : (float) -> (int)", has_folder=True) emit("aten::narrow : (Tensor, int, int, int) -> (Tensor)") - emit("aten::ScalarImplicit : (Tensor) -> (Scalar)", has_canonicalizer=True) + emit("aten::ScalarImplicit : (Tensor) -> (Scalar)") # backprop ops emit("aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)") emit("aten::tanh_backward : (Tensor, Tensor) -> (Tensor)") - emit("aten::hardtanh_backward : (Tensor, Tensor, Scalar, Scalar) -> (Tensor)") emit("aten::gelu_backward : (Tensor, Tensor, str) -> (Tensor)") emit("aten::_log_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)") emit("aten::native_layer_norm_backward : (Tensor, Tensor, int[], Tensor, Tensor, Tensor?, Tensor?, bool[]) -> (Tensor, Tensor, Tensor)") @@ -666,7 +658,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("prim::TupleUnpack : (Any) -> (...)", has_canonicalizer=True) emit("prim::NumToTensor.Scalar : (Scalar) -> (Tensor)") emit("prim::min.self_int : (int[]) -> (int)", has_folder=True) - emit("prim::min.int : (int, int) -> (int)", has_folder=True) + emit("prim::min.int : (int, int) -> (int)") emit("prim::max.self_int : (int[]) -> (int)") emit("prim::max.int : (int, int) -> (int)", has_folder=True) emit("prim::RaiseException : (str, str?) -> ()") @@ -683,7 +675,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): # ========================================================================== emit("prims::convert_element_type : (Tensor, int) -> (Tensor)") - emit("prims::var : (Tensor, int[]?, float, int?) -> (Tensor)") + emit("prims::var : (Tensor, int[]?, int, int?) -> (Tensor)") emit("prims::sqrt : (Tensor) -> (Tensor)") # ========================================================================== diff --git a/python/torch_mlir/dynamo.py b/python/torch_mlir/dynamo.py index 36b98c305..5b580ca57 100644 --- a/python/torch_mlir/dynamo.py +++ b/python/torch_mlir/dynamo.py @@ -8,7 +8,7 @@ from typing import List import torch from torch._functorch.compile_utils import strip_overloads from torch._decomp import get_decompositions -from torch._dynamo.backends.common import aot_autograd +from torch._dynamo.optimizations.training import aot_autograd import functorch import warnings @@ -49,20 +49,6 @@ def _get_decomposition_table(): # (the upstream decomposition we use here does), even though we have # support for aten.native_batch_norm_backward. aten._native_batch_norm_legit_functional, - aten.native_group_norm, - aten.split.Tensor, - aten.split_with_sizes, - aten.norm.ScalarOpt_dim, - aten.embedding_dense_backward, - aten.native_layer_norm_backward, - aten.slice_backward, - aten.select_backward, - aten.upsample_bilinear2d.vec, - aten.mse_loss_backward, - aten.native_group_norm_backward, - aten.sigmoid_backward, - aten._native_batch_norm_legit, - aten._native_batch_norm_legit_no_training ]) diff --git a/python/torch_mlir_e2e_test/configs/__init__.py b/python/torch_mlir_e2e_test/configs/__init__.py index 4ca4c3dce..36fab40bd 100644 --- a/python/torch_mlir_e2e_test/configs/__init__.py +++ b/python/torch_mlir_e2e_test/configs/__init__.py @@ -7,6 +7,6 @@ from .lazy_tensor_core import LazyTensorCoreTestConfig from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig from .native_torch import NativeTorchTestConfig from .torchscript import TorchScriptTestConfig -from .stablehlo_backend import StablehloBackendTestConfig +from .mhlo_backend import MhloBackendTestConfig from .tosa_backend import TosaBackendTestConfig from .torchdynamo import TorchDynamoTestConfig diff --git a/python/torch_mlir_e2e_test/configs/stablehlo_backend.py b/python/torch_mlir_e2e_test/configs/mhlo_backend.py similarity index 74% rename from python/torch_mlir_e2e_test/configs/stablehlo_backend.py rename to python/torch_mlir_e2e_test/configs/mhlo_backend.py index 45f32bb0b..0b7b32534 100644 --- a/python/torch_mlir_e2e_test/configs/stablehlo_backend.py +++ b/python/torch_mlir_e2e_test/configs/mhlo_backend.py @@ -8,8 +8,12 @@ from typing import Any import torch import torch_mlir -from torch_mlir_e2e_test.stablehlo_backends.abc import StablehloBackend -from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem +from torch_mlir_e2e_test.mhlo_backends.abc import MhloBackend +from torch_mlir_e2e_test.framework import ( + TestConfig, + Trace, + TraceItem +) from torch_mlir_e2e_test.utils import convert_annotations_to_placeholders from .utils import ( recursively_convert_to_numpy, @@ -17,20 +21,20 @@ from .utils import ( ) -class StablehloBackendTestConfig(TestConfig): +class MhloBackendTestConfig(TestConfig): """Base class for TestConfig's that are implemented with linalg-on-tensors. This class handles all the common lowering that torch-mlir does before reaching the linalg-on-tensors abstraction level. """ - - def __init__(self, backend: StablehloBackend): + def __init__(self, backend: MhloBackend): super().__init__() self.backend = backend def compile(self, program: torch.nn.Module) -> Any: example_args = convert_annotations_to_placeholders(program.forward) - module = torch_mlir.compile(program, example_args, output_type="stablehlo") + module = torch_mlir.compile( + program, example_args, output_type="mhlo") return self.backend.compile(module) @@ -42,6 +46,7 @@ class StablehloBackendTestConfig(TestConfig): outputs = getattr(backend_module, item.symbol)(*numpy_inputs) output = recursively_convert_from_numpy(outputs) result.append( - TraceItem(symbol=item.symbol, inputs=item.inputs, output=output) - ) + TraceItem(symbol=item.symbol, + inputs=item.inputs, + output=output)) return result diff --git a/python/torch_mlir_e2e_test/configs/torchdynamo.py b/python/torch_mlir_e2e_test/configs/torchdynamo.py index 2b16b1b92..044059818 100644 --- a/python/torch_mlir_e2e_test/configs/torchdynamo.py +++ b/python/torch_mlir_e2e_test/configs/torchdynamo.py @@ -5,7 +5,6 @@ from typing import List -import numpy import torch import torch._dynamo as dynamo import torch_mlir @@ -15,14 +14,6 @@ from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem -def _returns_empty_tuple(fx_graph: torch.fx.GraphModule) -> bool: - for node in fx_graph.graph.nodes: - if node.op == "output": - assert len(node.args) == 1, "Output node must have a single argument" - node_arg = node.args[0] - if node_arg != (): - return False - return True @make_simple_dynamo_backend def _refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule, @@ -41,16 +32,6 @@ def _refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule, # for that right now since it is still very early stages, but eventually # this Config should test that path (and maybe the current behavior can # be moved to a `legacy_frontend_via_torchdynamo` config). - - # Torch-MLIR does not support returning an empty tuple. The reason is - # that both returning an empty tuple and returning `None` results in MLIR - # functions that have as a return type `()`. In other words, there is no - # way of differentiating between the two. Moreover, since Torch-MLIR treats - # inputs as having value semantics, graphs that return nothing are no-ops to - # Torch-MLIR. - if _returns_empty_tuple(fx_graph): - return fx_graph - mlir_module = torch_mlir.compile( fx_graph, example_inputs, output_type="linalg-on-tensors") backend = refbackend.RefBackendLinalgOnTensorsBackend() @@ -58,18 +39,13 @@ def _refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule, loaded = backend.load(compiled) def compiled_callable(*inputs): - def refine_result_type(_result): - if isinstance(_result, tuple): - return tuple(refine_result_type(x) for x in _result) - elif isinstance(_result, numpy.ndarray): - return torch.from_numpy(_result) - elif isinstance(_result, (bool, int, float)): - return _result - else: - raise ValueError(f"Unhandled return type {type(_result)}") inputs = [x.numpy() for x in inputs] result = loaded.forward(*inputs) - return refine_result_type(result) + if not isinstance(result, tuple): + result = torch.from_numpy(result) + else: + result = tuple(torch.from_numpy(x) for x in result) + return result return compiled_callable @@ -83,17 +59,13 @@ class TorchDynamoTestConfig(TestConfig): return program def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace: - def item_symbol_that_clones_inputs(*inputs): - cloned_inputs = [x.clone() for x in inputs] - result = getattr(artifact, item.symbol)(*cloned_inputs) - return result # TODO: Deepcopy the torch.nn.Module, so that if the program is # stateful then it does not mutate the original compiled program. result: Trace = [] for item in trace: f = lambda method, *inputs: method(*inputs) dynamo_f = dynamo.optimize(_refbackend_torchdynamo_backend)(f) - output = dynamo_f(item_symbol_that_clones_inputs, *item.inputs) + output = dynamo_f(getattr(artifact, item.symbol), *item.inputs) result.append( TraceItem(symbol=item.symbol, inputs=item.inputs, diff --git a/python/torch_mlir_e2e_test/framework.py b/python/torch_mlir_e2e_test/framework.py index f1fbad2ec..aae4aa925 100644 --- a/python/torch_mlir_e2e_test/framework.py +++ b/python/torch_mlir_e2e_test/framework.py @@ -184,8 +184,8 @@ class TestUtils: def rand(self, *sizes, low=0.0, high=1.0): return torch.empty(sizes).uniform_(low, high) - def randint(self, *sizes, low=0, high=10, dtype=torch.int64): - return torch.randint(low, high, sizes, dtype=dtype) + def randint(self, *sizes, low=0, high=10): + return torch.randint(low, high, sizes) def nans(self, *sizes): vals = torch.empty(sizes) diff --git a/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py index 305eb7ca0..39376479a 100644 --- a/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py +++ b/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py @@ -162,7 +162,7 @@ LOWERING_PIPELINE = "builtin.module(" + ",".join([ "convert-math-to-libm", "convert-linalg-to-llvm", "expand-strided-metadata", - "finalize-memref-to-llvm", + "convert-memref-to-llvm", "lower-affine", "func.func(convert-arith-to-llvm)", "convert-func-to-llvm", diff --git a/python/torch_mlir_e2e_test/stablehlo_backends/__init__.py b/python/torch_mlir_e2e_test/mhlo_backends/__init__.py similarity index 100% rename from python/torch_mlir_e2e_test/stablehlo_backends/__init__.py rename to python/torch_mlir_e2e_test/mhlo_backends/__init__.py diff --git a/python/torch_mlir_e2e_test/stablehlo_backends/abc.py b/python/torch_mlir_e2e_test/mhlo_backends/abc.py similarity index 76% rename from python/torch_mlir_e2e_test/stablehlo_backends/abc.py rename to python/torch_mlir_e2e_test/mhlo_backends/abc.py index dbecbcc26..8fc51ac00 100644 --- a/python/torch_mlir_e2e_test/stablehlo_backends/abc.py +++ b/python/torch_mlir_e2e_test/mhlo_backends/abc.py @@ -10,30 +10,29 @@ import torch from torch_mlir.ir import Module -# A type shared between the result of `StablehloBackend.compile` and the -# input to `StablehloBackend.load`. Each backend will likely have a +# A type shared between the result of `MhloBackend.compile` and the +# input to `MhloBackend.load`. Each backend will likely have a # different definition of this type. -CompiledArtifact = TypeVar("CompiledArtifact") +CompiledArtifact = TypeVar('CompiledArtifact') # A wrapper around a backend-specific loaded program representation # that uniformly translates the `x.method(...)` interface expected of # Torch modules into appropriate lower-level operations. -Invoker = TypeVar("Invoker") +Invoker = TypeVar('Invoker') -class StablehloBackend(abc.ABC): - """The interface to an StableHLO backend. +class MhloBackend(abc.ABC): + """The interface to an MHLO backend. Backends are recommended to raise meaningful exceptions in case of error, ideally with easy reproduction instructions. """ - @abc.abstractmethod def compile(self, module: Module) -> CompiledArtifact: """Compile the provided MLIR module into a compiled artifact. - The module adheres to the StableHLO backend contract - (see the VerifyStablehloBackendContract pass). + The module adheres to the MHLO backend contract + (see the VerifyMhloBackendContract pass). The compiled artifact can be any type, but must be correctly interpreted by the `load` method. diff --git a/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py b/python/torch_mlir_e2e_test/mhlo_backends/linalg_on_tensors.py similarity index 66% rename from python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py rename to python/torch_mlir_e2e_test/mhlo_backends/linalg_on_tensors.py index b285c46b8..3ac1d6cd6 100644 --- a/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py +++ b/python/torch_mlir_e2e_test/mhlo_backends/linalg_on_tensors.py @@ -7,32 +7,28 @@ from torch_mlir.ir import * from torch_mlir.passmanager import * from torch_mlir.compiler_utils import run_pipeline_with_repro_report -from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import ( - RefBackendLinalgOnTensorsBackend, -) +from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend -from .abc import StablehloBackend +from .abc import MhloBackend __all__ = [ - "LinalgOnTensorsStablehloBackend", + "LinalgOnTensorsMhloBackend", ] - -class LinalgOnTensorsStablehloBackend(StablehloBackend): - """Main entry-point for the linalg-on-tensors based StableHLO backend. +class LinalgOnTensorsMhloBackend(MhloBackend): + """Main entry-point for the linalg-on-tensors based MHLO backend. This currently uses the linalg-on-tensors RefBackend for actual execution. """ - def __init__(self): super().__init__() self.refbackend = RefBackendLinalgOnTensorsBackend() def compile(self, imported_module: Module): - """Compiles an imported module that satisfied the StableHLO backend contract. + """Compiles an imported module that satisfied the MHLO backend contract. Args: - imported_module: The MLIR module consisting of funcs in the StableHLO + imported_module: The MLIR module consisting of funcs in the MHLO dialect. Returns: An opaque, backend specific compiled artifact object that can be @@ -40,9 +36,8 @@ class LinalgOnTensorsStablehloBackend(StablehloBackend): """ run_pipeline_with_repro_report( imported_module, - "builtin.module(func.func(chlo-legalize-to-hlo),stablehlo-legalize-to-hlo,func.func(canonicalize,cse,symbolic-shape-optimization,hlo-legalize-to-linalg,canonicalize))", - "Lowering StableHLO to Linalg-on-Tensors", - ) + "builtin.module(func.func(symbolic-shape-optimization),func.func(hlo-legalize-to-linalg),func.func(canonicalize))", + "Lowering MLIR-HLO to Linalg-on-Tensors") return self.refbackend.compile(imported_module) def load(self, module): diff --git a/python/torch_mlir_e2e_test/test_suite/__init__.py b/python/torch_mlir_e2e_test/test_suite/__init__.py index b0ea4dd8b..9dd80b0d2 100644 --- a/python/torch_mlir_e2e_test/test_suite/__init__.py +++ b/python/torch_mlir_e2e_test/test_suite/__init__.py @@ -8,10 +8,10 @@ # to the backend contract. COMMON_TORCH_MLIR_LOWERING_XFAILS = { "QuantizedMLP_basic", + "NormalizeModule_basic", "ResNet18Module_basic", "ResNet18StaticModule_basic", "MobilenetV3Module_basic", - "ReduceMaxAlongDimUnsignedInt_basic", } def register_all_tests(): @@ -44,7 +44,7 @@ def register_all_tests(): from . import histogram_binning_calibration from . import rng from . import cast - from . import scatter + from . import index_put from . import pooling from . import return_types from . import control_flow diff --git a/python/torch_mlir_e2e_test/test_suite/backprop.py b/python/torch_mlir_e2e_test/test_suite/backprop.py index 46d61d0e6..ece52a210 100644 --- a/python/torch_mlir_e2e_test/test_suite/backprop.py +++ b/python/torch_mlir_e2e_test/test_suite/backprop.py @@ -58,28 +58,6 @@ def TanhBackward_basic(module, tu: TestUtils): # ============================================================================== -class HardtanhBackwardModule(torch.nn.Module): - - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1, -1], torch.float32, True), - ]) - def forward(self, grad_out, input): - return torch.ops.aten.hardtanh_backward(grad_out, input, min_val=0.2, max_val=0.5) - - -@register_test_case(module_factory=lambda: HardtanhBackwardModule()) -def HardtanhBackward_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 20), tu.rand(10, 20)) - -# ============================================================================== - - class ConvolutionBackwardModule2D(torch.nn.Module): def __init__(self): diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py index 96b67e09c..1e791fed6 100644 --- a/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/python/torch_mlir_e2e_test/test_suite/basic.py @@ -621,106 +621,6 @@ def TensorsConcatNegativeDimModule_basic(module, tu: TestUtils): # ============================================================================== -class TensorsConcatPromoteDTypeModule(torch.nn.Module): - - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1, -1, -1], torch.bool, True), - ([-1, -1, -1], torch.int32, True), - ([-1, -1, -1], torch.int64, True), - ]) - def forward(self, x, y, z): - return torch.cat([x, y, z], dim=-2) - - -@register_test_case(module_factory=lambda: TensorsConcatPromoteDTypeModule()) -def TensorsConcatPromoteDTypeModule_basic(module, tu: TestUtils): - module.forward(tu.randint(2, 2, 4, low=0, high=2).bool(), - tu.randint(2, 1, 4, low=0, high=100).int(), - tu.randint(2, 3, 4, low=0, high=100).long()) - - -# ============================================================================== - - -class TensorsStackModule(torch.nn.Module): - - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ]) - def forward(self, x, y, z): - return torch.stack([x, y, z], 1) - - -@register_test_case(module_factory=lambda: TensorsStackModule()) -def TensorsStackModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 3, 4), tu.rand(2, 3, 4), tu.rand(2, 3, 4)) - - -# ============================================================================== - - -class TensorsStackNegativeDimModule(torch.nn.Module): - - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ]) - def forward(self, x, y, z): - return torch.stack([x, y, z], dim=-2) - - -@register_test_case(module_factory=lambda: TensorsStackNegativeDimModule()) -def TensorsStackNegativeDimModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 3, 4), tu.rand(2, 3, 4), tu.rand(2, 3, 4)) - - -# ============================================================================== - - -class TensorsStackPromoteDTypeModule(torch.nn.Module): - - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1, -1, -1], torch.bool, True), - ([-1, -1, -1], torch.int32, True), - ([-1, -1, -1], torch.int64, True), - ]) - def forward(self, x, y, z): - return torch.cat([x, y, z], dim=-2) - - -@register_test_case(module_factory=lambda: TensorsStackPromoteDTypeModule()) -def TensorsStackPromoteDTypeModule_basic(module, tu: TestUtils): - module.forward(tu.randint(2, 3, 4, low=0, high=2).bool(), - tu.randint(2, 3, 4, low=0, high=100).int(), - tu.randint(2, 3, 4, low=0, high=100).long()) - - -# ============================================================================== - - class GatherModule(torch.nn.Module): def __init__(self): @@ -828,7 +728,7 @@ class AddSizeIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: AddSizeIntModule()) def AddSizeIntModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 3)) + module.forward(torch.randn(3, 3)) # ============================================================================== @@ -853,7 +753,7 @@ class AddSizeIntNegDimModule(torch.nn.Module): @register_test_case(module_factory=lambda: AddSizeIntNegDimModule()) def AddSizeIntNegDimModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 3)) + module.forward(torch.randn(3, 3)) # ============================================================================== @@ -1004,7 +904,7 @@ class SoftmaxIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: SoftmaxIntModule()) def SoftmaxIntModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 2, 4)) + module.forward(torch.randn(3, 2, 4)) # ============================================================================== @@ -1026,7 +926,7 @@ class _SoftmaxModule(torch.nn.Module): @register_test_case(module_factory=lambda: _SoftmaxModule()) def _SoftmaxModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 2, 4)) + module.forward(torch.randn(3, 2, 4)) # ============================================================================== @@ -1050,7 +950,7 @@ class SoftmaxIntNegDimModule(torch.nn.Module): @register_test_case(module_factory=lambda: SoftmaxIntNegDimModule()) def SoftmaxIntNegDimModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 2, 4)) + module.forward(torch.randn(3, 2, 4)) # ============================================================================== @@ -1074,7 +974,7 @@ class SoftmaxIntArgTypeF64Module(torch.nn.Module): @register_test_case(module_factory=lambda: SoftmaxIntArgTypeF64Module()) def SoftmaxIntArgTypeF64Module_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 2, 4).double()) + module.forward(torch.randn(3, 2, 4).double()) # ============================================================================== @@ -1096,7 +996,7 @@ class _LogSoftmaxModule(torch.nn.Module): @register_test_case(module_factory=lambda: _LogSoftmaxModule()) def _LogSoftmaxModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 2, 4)) + module.forward(torch.randn(3, 2, 4)) # ============================================================================== @@ -1365,30 +1265,11 @@ class LogSoftmaxIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: LogSoftmaxIntModule()) def LogSoftmaxIntModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 2, 4).double()) + module.forward(torch.randn(3, 2, 4).double()) # ============================================================================== -class PrimMinIntModule(torch.nn.Module): - - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ]) - def forward(self): - return torch.ops.prim.min(1, -1) - - -@register_test_case(module_factory=lambda: PrimMinIntModule()) -def PrimMinIntModule_basic(module, tu: TestUtils): - module.forward() - - -# ============================================================================== class NumToTensorIntModule(torch.nn.Module): @@ -1906,48 +1787,6 @@ def IndexTensorModule_basic(module, tu: TestUtils): module.forward(tu.rand(5), tu.randint(2, 3, high=4)) -# ============================================================================== -class IndexTensorStaticModule(torch.nn.Module): - - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([4, 5], torch.float32, True), - ([2, 3], torch.int64, True), - ]) - def forward(self, x, index): - return torch.ops.aten.index(x, (index, )) - - -@register_test_case(module_factory=lambda: IndexTensorStaticModule()) -def IndexTensorStaticModule_basic(module, tu: TestUtils): - module.forward(tu.rand(4, 5), tu.randint(2, 3, high=4)) - -# ============================================================================== -class IndexTensorMultiIndexStaticModule(torch.nn.Module): - - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([4, 5], torch.float32, True), - ([2, 3], torch.int64, True), - ([2, 3], torch.int64, True), - ]) - def forward(self, x, index1, index2): - return torch.ops.aten.index(x, (index1, index2)) - - -@register_test_case(module_factory=lambda: IndexTensorMultiIndexStaticModule()) -def IndexTensorMultiIndexStaticModule_basic(module, tu: TestUtils): - module.forward(tu.rand(4, 5), tu.randint(2, 3, high=4), tu.randint(2, 3, high=4)) - - # ============================================================================== @@ -2787,46 +2626,6 @@ def LenStrModule_basic(module, tu: TestUtils): # ============================================================================== -class IntFloatModule(torch.nn.Module): - - def __init__(self): - super().__init__() - self.value = 1.0 - - @export - @annotate_args([ - None, - ]) - def forward(self): - return torch.ops.aten.Int(self.value) - -@register_test_case(module_factory=lambda: IntFloatModule()) -def IntFloatModule_basic(module, tu: TestUtils): - module.forward() - - -# ============================================================================== - -class AtenSubFloatModule(torch.nn.Module): - - def __init__(self): - super().__init__() - self.value1 = 1.0 - self.value2 = 2.0 - - @export - @annotate_args([ - None, - ]) - def forward(self): - return float(torch.ops.aten.sub(self.value1, self.value2)) - -@register_test_case(module_factory=lambda: AtenSubFloatModule()) -def AtenSubFloatModule_basic(module, tu: TestUtils): - module.forward() - - -# ============================================================================== class ScalarImplicitFloatModule(torch.nn.Module): @@ -2868,25 +2667,6 @@ def ScalarImplicitIntModule_basic(module, tu: TestUtils): # ============================================================================== -class PowIntFloat(torch.nn.Module): - - def __init__(self): - super().__init__() - self.value = 2 - self.power_value = 3.0 - - @export - @annotate_args([ - None, - ]) - def forward(self): - return torch.ops.aten.pow(self.value, self.power_value) - -@register_test_case(module_factory=lambda: IntFloatModule()) -def PowIntFloatModule_basic(module, tu: TestUtils): - module.forward() - -# ============================================================================== class BaddbmmDynamicModule(torch.nn.Module): @@ -3171,7 +2951,7 @@ class AtenEmbeddingBagSumExample(torch.nn.Module): @register_test_case(module_factory=lambda: AtenEmbeddingBagSumExample()) def AtenEmbeddingBagSumExample_basic(module, tu: TestUtils): - weight = tu.rand(100, 10) + weight = torch.rand(100, 10) indices = torch.LongTensor([0, 1, 2, 2, 0, 2, 1, 3, 20, 50, 99, 2, 4, 5, 6, 7, 34, 54]) offsets = torch.LongTensor([0, 3, 5, 7, 9, 10, 15]) module.forward(weight, indices, offsets) @@ -3193,7 +2973,7 @@ class Aten_EmbeddingBagExample(torch.nn.Module): @register_test_case(module_factory=lambda: Aten_EmbeddingBagExample()) def Aten_EmbeddingBagExample_basic(module, tu: TestUtils): - weight = tu.rand(100, 10) + weight = torch.rand(100, 10) indices = torch.LongTensor([0, 1, 2, 2, 0, 2, 1, 3, 20, 50, 99, 2, 4, 5, 6, 7, 34, 54]) offsets = torch.LongTensor([0, 3, 5, 7, 9, 10, 15]) module.forward(weight, indices, offsets) @@ -3234,23 +3014,6 @@ class CumsumStaticModule(torch.nn.Module): def CumsumStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 7, 4)) -class CumsumStaticNegativeDimModule(torch.nn.Module): - - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([2, 7, 4], torch.float32, True), - ]) - def forward(self, val): - return torch.ops.aten.cumsum(val, dim=-1) - -@register_test_case(module_factory=lambda: CumsumStaticNegativeDimModule()) -def CumsumStaticNegativeDimModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 7, 4)) - # ============================================================================== class AtenToDeviceModule(torch.nn.Module): @@ -3268,7 +3031,7 @@ class AtenToDeviceModule(torch.nn.Module): @register_test_case(module_factory=lambda: AtenToDeviceModule()) def AtenToDeviceModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 4)) + module.forward(torch.randn(2, 4)) # ============================================================================== @@ -3360,113 +3123,4 @@ class SortIntListReverse(torch.nn.Module): @register_test_case(module_factory=lambda: SortIntListReverse()) def SortIntListReverse_basic(module, tu: TestUtils): - module.forward() - -# ============================================================================== - -class BucketizeTensorModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ([-1], torch.int64, True), - ]) - def forward(self, input, boundaries): - return torch.bucketize(input, boundaries) - -@register_test_case(module_factory=lambda: BucketizeTensorModule()) -def BucketizeTensorModule_basic(module, tu: TestUtils): - module.forward(torch.tensor([[0, 2, 5, 7], [1, 3, 4, 6]]), torch.tensor([1, 4, 6])) - -class BucketizeTensorOutInt32RightModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ([-1], torch.int64, True), - ]) - def forward(self, input, boundaries): - return torch.bucketize(input, boundaries, out_int32=True, right=True) - -@register_test_case(module_factory=lambda: BucketizeTensorOutInt32RightModule()) -def BucketizeTensorOutInt32RightModule_basic(module, tu: TestUtils): - module.forward(torch.tensor([[0, 2, 5, 7], [1, 3, 4, 6]]), torch.tensor([1, 4, 6])) - -class BucketizeTensorFloatModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1], torch.float32, True), - ]) - def forward(self, input, boundaries): - return torch.bucketize(input, boundaries) - -@register_test_case(module_factory=lambda: BucketizeTensorFloatModule()) -def BucketizeTensorFloatModule_basic(module, tu: TestUtils): - module.forward(tu.rand(15, 17), torch.sort(tu.rand(16)).values) - -class BucketizeTensorStaticModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([2, 4], torch.int64, True), - ([3], torch.int64, True), - ]) - def forward(self, input, boundaries): - return torch.bucketize(input, boundaries) - -@register_test_case(module_factory=lambda: BucketizeTensorStaticModule()) -def BucketizeTensorStaticModule_basic(module, tu: TestUtils): - module.forward(torch.tensor([[0, 2, 5, 7], [1, 3, 4, 6]]), torch.tensor([1, 4, 6])) - -class BucketizeTensorStaticFloatModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([15, 17], torch.float32, True), - ([16], torch.float32, True), - ]) - def forward(self, input, boundaries): - return torch.bucketize(input, boundaries) - -@register_test_case(module_factory=lambda: BucketizeTensorStaticFloatModule()) -def BucketizeTensorStaticFloatModule_basic(module, tu: TestUtils): - module.forward(tu.rand(15, 17), torch.sort(tu.rand(16)).values) - - -# ============================================================================== - -class AtenFloatScalarModule(torch.nn.Module): - - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([], torch.int64, True), - ]) - def forward(self, x): - a = torch.ops.aten.ScalarImplicit(x) - return torch.ops.aten.Float(a) - - -@register_test_case(module_factory=lambda: AtenFloatScalarModule()) -def AtenFloatScalarModule_basic(module, tu: TestUtils): - module.forward(tu.randint(high=5)) + module.forward() \ No newline at end of file diff --git a/python/torch_mlir_e2e_test/test_suite/cast.py b/python/torch_mlir_e2e_test/test_suite/cast.py index 613ba7e3b..b6867ad31 100644 --- a/python/torch_mlir_e2e_test/test_suite/cast.py +++ b/python/torch_mlir_e2e_test/test_suite/cast.py @@ -64,7 +64,7 @@ class TensorToFloatZeroRank(torch.nn.Module): @register_test_case(module_factory=lambda: TensorToFloatZeroRank()) def TensorToFloatZeroRank_basic(module, tu: TestUtils): - module.forward(tu.rand().to(torch.float64)) + module.forward(torch.rand((), dtype=torch.float64)) # ============================================================================== @@ -83,7 +83,7 @@ class TensorToFloat(torch.nn.Module): @register_test_case(module_factory=lambda: TensorToFloat()) def TensorToFloat_basic(module, tu: TestUtils): - module.forward(tu.rand(1, 1).to(torch.float64)) + module.forward(torch.rand((1, 1), dtype=torch.float64)) # ============================================================================== diff --git a/python/torch_mlir_e2e_test/test_suite/constant_alloc.py b/python/torch_mlir_e2e_test/test_suite/constant_alloc.py index 9755b8735..43157750b 100644 --- a/python/torch_mlir_e2e_test/test_suite/constant_alloc.py +++ b/python/torch_mlir_e2e_test/test_suite/constant_alloc.py @@ -1095,7 +1095,7 @@ class ZeroFloat32Module(torch.nn.Module): @register_test_case(module_factory=lambda: ZeroFloat32Module()) def ZeroFloat32Module_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 2)) + module.forward(torch.rand(3, 2)) class ZeroInt32Module(torch.nn.Module): @@ -1382,27 +1382,6 @@ def MaskedFillScalarFloatValueModule_basic(module, tu: TestUtils): tu.randint(2, 3, high=2).to(dtype=torch.bool)) -class MaskedFillScalarFloatValueStaticModule(torch.nn.Module): - - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([2, 3], torch.int64, True), - ([2, 3], torch.bool, True), - ]) - def forward(self, x, mask): - return torch.ops.aten.masked_fill(x, mask, value=-0.01) - - -@register_test_case(module_factory=lambda: MaskedFillScalarFloatValueStaticModule()) -def MaskedFillScalarFloatValueStaticModule_basic(module, tu: TestUtils): - module.forward(tu.randint(2, 3, low=-10, high=10), - tu.randint(2, 3, high=2).to(dtype=torch.bool)) - - class MaskedFillTensorFloatValueModule(torch.nn.Module): def __init__(self): @@ -1466,27 +1445,3 @@ class MaskedFillTensorIntValueStaticModule(torch.nn.Module): def MaskedFillTensorIntValueStaticModule_basic(module, tu: TestUtils): module.forward(tu.randint(2, 3), tu.randint(2, 3, high=2).to(dtype=torch.bool), tu.randint()) - - -# ============================================================================== - - -class NewEmptyStridedModuleDefaultDtype(torch.nn.Module): - - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([2, 3, 4], torch.float32, True), - ]) - def forward(self, a): - x = torch.ops.aten.new_empty_strided(a, size=[2, 3, 4], stride=[12, 4, 1]) - y = x.copy_(a) - return x + y - - -@register_test_case(module_factory=lambda: NewEmptyStridedModuleDefaultDtype()) -def NewEmptyStridedModuleDefaultDtype_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 3, 4)) diff --git a/python/torch_mlir_e2e_test/test_suite/conv.py b/python/torch_mlir_e2e_test/test_suite/conv.py index d36c8b75a..a2236f3ef 100644 --- a/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/python/torch_mlir_e2e_test/test_suite/conv.py @@ -165,7 +165,7 @@ class Convolution2DModule(torch.nn.Module): @register_test_case(module_factory=lambda: Convolution2DModule()) def Convolution2DModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2)) + module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2)) class Convolution2DStaticModule(torch.nn.Module): @@ -191,7 +191,7 @@ class Convolution2DStaticModule(torch.nn.Module): @register_test_case(module_factory=lambda: Convolution2DStaticModule()) def Convolution2DStaticModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2)) + module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2)) class Convolution2DStridedModule(torch.nn.Module): def __init__(self): @@ -216,7 +216,7 @@ class Convolution2DStridedModule(torch.nn.Module): @register_test_case(module_factory=lambda: Convolution2DStridedModule()) def Convolution2DStridedModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2)) + module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2)) class _Convolution2DAllFalseModule(torch.nn.Module): def __init__(self): @@ -245,7 +245,7 @@ class _Convolution2DAllFalseModule(torch.nn.Module): @register_test_case(module_factory=lambda: _Convolution2DAllFalseModule()) def _Convolution2DAllFalseModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2)) + module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2)) class _Convolution2DBenchmarkModule(torch.nn.Module): def __init__(self): @@ -274,7 +274,7 @@ class _Convolution2DBenchmarkModule(torch.nn.Module): @register_test_case(module_factory=lambda: _Convolution2DBenchmarkModule()) def _Convolution2DBenchmarkModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2)) + module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2)) class _Convolution2DDeterministicModule(torch.nn.Module): def __init__(self): @@ -303,7 +303,7 @@ class _Convolution2DDeterministicModule(torch.nn.Module): @register_test_case(module_factory=lambda: _Convolution2DDeterministicModule()) def _Convolution2DDeterministicModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2)) + module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2)) class _Convolution2DCudnnModule(torch.nn.Module): def __init__(self): @@ -332,7 +332,7 @@ class _Convolution2DCudnnModule(torch.nn.Module): @register_test_case(module_factory=lambda: _Convolution2DCudnnModule()) def _Convolution2DCudnnModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2)) + module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2)) class _Convolution2DTF32Module(torch.nn.Module): def __init__(self): @@ -361,7 +361,7 @@ class _Convolution2DTF32Module(torch.nn.Module): @register_test_case(module_factory=lambda: _Convolution2DTF32Module()) def _Convolution2DTF32Module_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2)) + module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2)) class _ConvolutionDeprecated2DAllFalseModule(torch.nn.Module): def __init__(self): @@ -389,7 +389,7 @@ class _ConvolutionDeprecated2DAllFalseModule(torch.nn.Module): @register_test_case(module_factory=lambda: _ConvolutionDeprecated2DAllFalseModule()) def _ConvolutionDeprecated2DAllFalseModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2)) + module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2)) class _ConvolutionDeprecated2DBenchmarkModule(torch.nn.Module): def __init__(self): @@ -417,7 +417,7 @@ class _ConvolutionDeprecated2DBenchmarkModule(torch.nn.Module): @register_test_case(module_factory=lambda: _ConvolutionDeprecated2DBenchmarkModule()) def _ConvolutionDeprecated2DBenchmarkModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2)) + module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2)) class _ConvolutionDeprecated2DDeterministicModule(torch.nn.Module): def __init__(self): @@ -445,7 +445,7 @@ class _ConvolutionDeprecated2DDeterministicModule(torch.nn.Module): @register_test_case(module_factory=lambda: _ConvolutionDeprecated2DDeterministicModule()) def _ConvolutionDeprecated2DDeterministicModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2)) + module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2)) class _ConvolutionDeprecated2DCudnnModule(torch.nn.Module): def __init__(self): @@ -473,7 +473,7 @@ class _ConvolutionDeprecated2DCudnnModule(torch.nn.Module): @register_test_case(module_factory=lambda: _ConvolutionDeprecated2DCudnnModule()) def _ConvolutionDeprecated2DCudnnModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2)) + module.forward(torch.randn(3, 3, 10, 10), torch.randn(3, 3, 2, 2)) class ConvolutionModule2DGroups(torch.nn.Module): def __init__(self): @@ -498,7 +498,7 @@ class ConvolutionModule2DGroups(torch.nn.Module): @register_test_case(module_factory=lambda: ConvolutionModule2DGroups()) def ConvolutionModule2DGroups_basic(module, tu: TestUtils): - module.forward(tu.rand(1, 32, 4, 4), tu.rand(32, 8, 3, 3)) + module.forward(torch.randn(1, 32, 4, 4), torch.randn(32, 8, 3, 3)) # ============================================================================== @@ -527,7 +527,7 @@ class ConvolutionModule2DTranspose(torch.nn.Module): @register_test_case(module_factory=lambda: ConvolutionModule2DTranspose()) def ConvolutionModule2DTranspose_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 3, 4, 4), tu.rand(3, 3, 2, 2)) + module.forward(torch.randn(3, 3, 4, 4), torch.randn(3, 3, 2, 2)) class ConvolutionModule2DTransposeStrided(torch.nn.Module): @@ -554,7 +554,7 @@ class ConvolutionModule2DTransposeStrided(torch.nn.Module): @register_test_case(module_factory=lambda: ConvolutionModule2DTransposeStrided()) def ConvolutionModule2DTransposeStrided_basic(module, tu: TestUtils): - module.forward(tu.rand(5, 2, 5, 6), tu.rand(2, 5, 2, 2)) + module.forward(torch.randn(5, 2, 5, 6), torch.randn(2, 5, 2, 2)) class ConvolutionModule2DTransposeStridedStatic(torch.nn.Module): @@ -581,7 +581,7 @@ class ConvolutionModule2DTransposeStridedStatic(torch.nn.Module): @register_test_case(module_factory=lambda: ConvolutionModule2DTransposeStridedStatic()) def ConvolutionModule2DTransposeStridedStatic_basic(module, tu: TestUtils): - module.forward(tu.rand(5, 2, 5, 6), tu.rand(2, 5, 2, 2)) + module.forward(torch.randn(5, 2, 5, 6), torch.randn(2, 5, 2, 2)) class Conv_Transpose2dModule(torch.nn.Module): @@ -608,7 +608,7 @@ class Conv_Transpose2dModule(torch.nn.Module): @register_test_case(module_factory=lambda: Conv_Transpose2dModule()) def Conv_Transpose2dModule_basic(module, tu: TestUtils): - module.forward(tu.rand(5, 2, 5, 6), tu.rand(2, 5, 2, 2)) + module.forward(torch.randn(5, 2, 5, 6), torch.randn(2, 5, 2, 2)) class UpSampleNearest2d(torch.nn.Module): diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py index ba23c5222..41f7c8364 100644 --- a/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -152,7 +152,7 @@ class ElementwiseAtenWhereSelfModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseAtenWhereSelfModule()) def ElementwiseAtenWhereSelfModule_basic(module, tu: TestUtils): - module.forward(torch.zeros(1, 1, 5, 5, dtype=torch.bool), tu.rand(1, 12, 5, 5), tu.rand()) + module.forward(torch.zeros(1, 1, 5, 5, dtype=torch.bool), torch.rand(1, 12, 5, 5), torch.rand(())) # ============================================================================== @@ -774,26 +774,6 @@ class RsubIntModule_noalpha(torch.nn.Module): def RsubIntModule_noalpha_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, high=100)) -# ============================================================================== - - -class RsubInt0d_NumToTensor_Module(torch.nn.Module): - - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ]) - def forward(self): - x = torch.ops.prim.NumToTensor(5) - return torch.rsub(x, 2) - - -@register_test_case(module_factory=lambda: RsubInt0d_NumToTensor_Module()) -def RsubInt0d_NumToTensor_Module_basic(module, tu: TestUtils): - module.forward() # ============================================================================== @@ -1520,7 +1500,7 @@ class ElementwiseRemainderScalarModule_Float(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Float()) def ElementwiseRemainderScalarModule_Float_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 3)) + module.forward(torch.rand(10, 3)) # ============================================================================== @@ -1939,52 +1919,6 @@ def ElementwiseAddScalarFloatModule_basic(module, tu: TestUtils): # ============================================================================== -class ElementwiseAddScalar_NumToTensorFloat_Module(torch.nn.Module): - - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ]) - def forward(self): - x = torch.ops.prim.NumToTensor(5.0) - return torch.add(x, 3) - - -@register_test_case( - module_factory=lambda: ElementwiseAddScalar_NumToTensorFloat_Module()) -def ElementwiseAddScalar_NumToTensorFloat_Module_basic(module, tu: TestUtils): - module.forward() - - -# ============================================================================== - - -class ElementwiseAddScalar_TensorLiteralInt32_Module(torch.nn.Module): - - def __init__(self): - super().__init__() - self.x = torch.tensor(2, dtype=torch.int32) - - @export - @annotate_args([ - None, - ]) - def forward(self): - return torch.add(self.x, 3) - - -@register_test_case( - module_factory=lambda: ElementwiseAddScalar_TensorLiteralInt32_Module()) -def ElementwiseAddScalar_TensorLiteralInt32_Module_basic(module, tu: TestUtils): - module.forward() - - -# ============================================================================== - - class ElementwiseCloneModule(torch.nn.Module): def __init__(self): @@ -2029,30 +1963,6 @@ def ElementwiseCloneContiguousModule_basic(module, tu: TestUtils): # ============================================================================== -class ElementwiseCloneChannelsLastMemoryFormatModule(torch.nn.Module): - - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float32, True), - ]) - def forward(self, x): - return torch.clone(x, memory_format=torch.channels_last) - - -@register_test_case( - module_factory=lambda: ElementwiseCloneChannelsLastMemoryFormatModule()) -def ElementwiseCloneChannelsLastMemoryFormatModule_basic( - module, tu: TestUtils): - module.forward(tu.rand(2, 3, 4, 5)) - - -# ============================================================================== - - class LiftFreshCopyModule(torch.nn.Module): def __init__(self): @@ -2379,7 +2289,7 @@ class ElementwiseAtenLogicalOrOpRandomFloatModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpRandomFloatModule()) def ElementwiseAtenLogicalOrOpRandomFloatModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 3, 3, 5), tu.rand(2, 3, 3, 5)) + module.forward(torch.rand(2, 3, 3, 5), torch.rand(2, 3, 3, 5)) # ============================================================================== @@ -2737,26 +2647,6 @@ class AtenRoundFloatModule(torch.nn.Module): def AtenRoundFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand(5, 5, low = -3.0, high = 3.0)) - -class AtenRoundFloatHalfToEvenModule(torch.nn.Module): - - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ]) - def forward(self, x): - return torch.ops.aten.round(x) - - -@register_test_case(module_factory=lambda: AtenRoundFloatHalfToEvenModule()) -def AtenRoundFloatHalfToEvenModule_basic(module, tu: TestUtils): - module.forward(torch.FloatTensor([[0.5, 1.5], [-0.5, -1.5]])) - - class AtenRoundIntModule(torch.nn.Module): def __init__(self): @@ -2795,7 +2685,7 @@ class Fill_TensorFloat64WithFloat32(torch.nn.Module): @register_test_case(module_factory=lambda: Fill_TensorFloat64WithFloat32()) def Fill_TensorFloat64WithFloat32_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 2, 4)) + module.forward(torch.randn(3, 2, 4)) class Fill_TensorFloat64WithFloat64(torch.nn.Module): @@ -2814,7 +2704,7 @@ class Fill_TensorFloat64WithFloat64(torch.nn.Module): @register_test_case(module_factory=lambda: Fill_TensorFloat64WithFloat64()) def Fill_TensorFloat64WithFloat64_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 2, 4).to(torch.float64)) + module.forward(torch.randn(3, 2, 4).to(torch.float64)) class Fill_TensorFloat64WithInt64(torch.nn.Module): @@ -2833,7 +2723,7 @@ class Fill_TensorFloat64WithInt64(torch.nn.Module): @register_test_case(module_factory=lambda: Fill_TensorFloat64WithInt64()) def Fill_TensorFloat64WithInt64_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 2, 4).to(torch.float64)) + module.forward(torch.randn(3, 2, 4).to(torch.float64)) # ============================================================================== diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py b/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py index 0b7214365..89d334628 100644 --- a/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py +++ b/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py @@ -144,46 +144,6 @@ def ElementwiseGeFloatIntScalarModule_basic(module, tu: TestUtils): # ============================================================================== -class ElementwiseGeFloatTensorModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1], torch.float32, True), - ]) - def forward(self, x, y): - return torch.ge(x, y) - - -@register_test_case(module_factory=lambda: ElementwiseGeFloatTensorModule()) -def ElementwiseGeFloatTensorModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 5), tu.rand(5)) - -# ============================================================================== - -class ElementwiseGeIntTensorModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ([-1], torch.int64, True), - ]) - def forward(self, x, y): - return torch.ge(x, y) - - -@register_test_case(module_factory=lambda: ElementwiseGeIntTensorModule()) -def ElementwiseGeIntTensorModule_basic(module, tu: TestUtils): - module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10)) - -# ============================================================================== - class ElementwiseGtFloatTensorModule(torch.nn.Module): def __init__(self): super().__init__() @@ -358,46 +318,6 @@ def ElementwiseLeFloatIntScalarModule_basic(module, tu: TestUtils): # ============================================================================== -class ElementwiseLeFloatTensorModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1, -1], torch.float32, True), - ([-1], torch.float32, True), - ]) - def forward(self, x, y): - return torch.le(x, y) - - -@register_test_case(module_factory=lambda: ElementwiseLeFloatTensorModule()) -def ElementwiseLeFloatTensorModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 5), tu.rand(5)) - -# ============================================================================== - -class ElementwiseLeIntTensorModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1, -1], torch.int64, True), - ([-1], torch.int64, True), - ]) - def forward(self, x, y): - return torch.le(x, y) - - -@register_test_case(module_factory=lambda: ElementwiseLeIntTensorModule()) -def ElementwiseLeIntTensorModule_basic(module, tu: TestUtils): - module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10)) - -# ============================================================================== - class ElementwiseLtFloatTensorModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/python/torch_mlir_e2e_test/test_suite/histogram_binning_calibration.py b/python/torch_mlir_e2e_test/test_suite/histogram_binning_calibration.py index 9e6e2588b..ca95e36a3 100644 --- a/python/torch_mlir_e2e_test/test_suite/histogram_binning_calibration.py +++ b/python/torch_mlir_e2e_test/test_suite/histogram_binning_calibration.py @@ -89,7 +89,7 @@ class HistogramBinningCalibrationByFeature(torch.nn.Module): @register_test_case(module_factory=lambda: HistogramBinningCalibrationByFeature()) def HBC_basic(module, tu: TestUtils): - logits = tu.rand(NUM_LOGITS) + logits = torch.rand(NUM_LOGITS, dtype=torch.float) segment_lengths: Tensor = tu.randint(NUM_LOGITS, high=2).to(torch.int) segment_offsets: Tensor = torch.cumsum(segment_lengths, 0) segment_offsets: Tensor = torch.cat( diff --git a/python/torch_mlir_e2e_test/test_suite/scatter.py b/python/torch_mlir_e2e_test/test_suite/index_put.py similarity index 78% rename from python/torch_mlir_e2e_test/test_suite/scatter.py rename to python/torch_mlir_e2e_test/test_suite/index_put.py index 11fad62fa..4fb74511b 100644 --- a/python/torch_mlir_e2e_test/test_suite/scatter.py +++ b/python/torch_mlir_e2e_test/test_suite/index_put.py @@ -818,151 +818,3 @@ class IndexPutHackedTwin3DIntAccumulateModule(torch.nn.Module): def IndexPutHackedTwin3DIntAccumulateModule_basic(module, tu: TestUtils): module.forward(tu.randint(10, 8, 6, high=1000), tu.randint(5, high=4), tu.randint(5, 8, 6, high=1000)) - -# ============================================================================== - -class ScatterReduceFloatModule(torch.nn.Module): - include_self: bool - reduce_type: str - - def __init__(self, reduce_type: str, include_self: bool): - super().__init__() - self.include_self = include_self - self.reduce_type = reduce_type - - @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.int64, True), - ([-1, -1, -1], torch.float32, True), - ]) - def forward(self, input, index, src): - return torch.ops.aten.scatter_reduce(input, 0, index, src, self.reduce_type, include_self=self.include_self) - - -@register_test_case( - module_factory=lambda: ScatterReduceFloatModule("sum", False)) -def ScatterReduceFloatSumModule(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), - tu.rand(5, 8, 6)) -@register_test_case( - module_factory=lambda: ScatterReduceFloatModule("sum", True)) -def ScatterReduceFloatSumModuleIncludeSelf(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), - tu.rand(5, 8, 6)) -@register_test_case( - module_factory=lambda: ScatterReduceFloatModule("prod", False)) -def ScatterReduceFloatProdModule(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), - tu.rand(5, 8, 6)) -@register_test_case( - module_factory=lambda: ScatterReduceFloatModule("prod", True)) -def ScatterReduceFloatProdModuleIncludeSelf(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), - tu.rand(5, 8, 6)) -@register_test_case( - module_factory=lambda: ScatterReduceFloatModule("amax", False)) -def ScatterReduceFloatMaxModule(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), - tu.rand(5, 8, 6)) -@register_test_case( - module_factory=lambda: ScatterReduceFloatModule("amax", True)) -def ScatterReduceFloatMaxModuleIncludeSelf(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), - tu.rand(5, 8, 6)) -@register_test_case( - module_factory=lambda: ScatterReduceFloatModule("amin", False)) -def ScatterReduceFloatMinModule(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), - tu.rand(5, 8, 6)) -@register_test_case( - module_factory=lambda: ScatterReduceFloatModule("amin", True)) -def ScatterReduceFloatMinModuleIncludeSelf(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), - tu.rand(5, 8, 6)) -@register_test_case( - module_factory=lambda: ScatterReduceFloatModule("mean", False)) -def ScatterReduceFloatMeanModule(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), - tu.rand(5, 8, 6)) -@register_test_case( - module_factory=lambda: ScatterReduceFloatModule("mean", True)) -def ScatterReduceFloatMeanModuleIncludeSelf(module, tu: TestUtils): - module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), - tu.rand(5, 8, 6)) - -# ============================================================================== - -class ScatterReduceIntModule(torch.nn.Module): - include_self: bool - reduce_type: str - - def __init__(self, reduce_type: str, include_self: bool): - super().__init__() - self.include_self = include_self - self.reduce_type = reduce_type - - @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int32, True), - ([-1, -1, -1], torch.int64, True), - ([-1, -1, -1], torch.int32, True), - ]) - def forward(self, input, index, src): - return torch.ops.aten.scatter_reduce(input, 0, index, src, self.reduce_type, include_self=self.include_self) - - -@register_test_case( - module_factory=lambda: ScatterReduceIntModule("sum", False)) -def ScatterReduceIntSumModule(module, tu: TestUtils): - module.forward(tu.randint(10, 8, 6, dtype=torch.int32, high=10), tu.randint(2, 4, 3, high=4), - tu.randint(5, 8, 6, dtype=torch.int32, high=10)) -@register_test_case( - module_factory=lambda: ScatterReduceIntModule("sum", True)) -def ScatterReduceIntSumModuleIncludeSelf(module, tu: TestUtils): - module.forward(tu.randint(10, 8, 6, dtype=torch.int32, high=10), tu.randint(2, 4, 3, high=4), - tu.randint(5, 8, 6, dtype=torch.int32, high=10)) -@register_test_case( - module_factory=lambda: ScatterReduceIntModule("prod", False)) -def ScatterReduceIntProdModule(module, tu: TestUtils): - module.forward(tu.randint(10, 8, 6, dtype=torch.int32, high=10), tu.randint(2, 4, 3, high=4), - tu.randint(5, 8, 6, dtype=torch.int32, high=10)) -@register_test_case( - module_factory=lambda: ScatterReduceIntModule("prod", True)) -def ScatterReduceIntProdModuleIncludeSelf(module, tu: TestUtils): - module.forward(tu.randint(10, 8, 6, dtype=torch.int32, high=10), tu.randint(2, 4, 3, high=4), - tu.randint(5, 8, 6, dtype=torch.int32, high=10)) -@register_test_case( - module_factory=lambda: ScatterReduceIntModule("amax", False)) -def ScatterReduceIntMaxModule(module, tu: TestUtils): - module.forward(tu.randint(10, 8, 6, dtype=torch.int32, high=10), tu.randint(2, 4, 3, high=4), - tu.randint(5, 8, 6, dtype=torch.int32, high=10)) -@register_test_case( - module_factory=lambda: ScatterReduceIntModule("amax", True)) -def ScatterReduceIntMaxModuleIncludeSelf(module, tu: TestUtils): - module.forward(tu.randint(10, 8, 6, dtype=torch.int32, high=10), tu.randint(2, 4, 3, high=4), - tu.randint(5, 8, 6, dtype=torch.int32, high=10)) -@register_test_case( - module_factory=lambda: ScatterReduceIntModule("amin", False)) -def ScatterReduceIntMinModule(module, tu: TestUtils): - module.forward(tu.randint(10, 8, 6, dtype=torch.int32, high=10), tu.randint(2, 4, 3, high=4), - tu.randint(5, 8, 6, dtype=torch.int32, high=10)) -@register_test_case( - module_factory=lambda: ScatterReduceIntModule("amin", True)) -def ScatterReduceIntMinModuleIncludeSelf(module, tu: TestUtils): - module.forward(tu.randint(10, 8, 6, dtype=torch.int32, high=10), tu.randint(2, 4, 3, high=4), - tu.randint(5, 8, 6, dtype=torch.int32, high=10)) -@register_test_case( - module_factory=lambda: ScatterReduceIntModule("mean", False)) -def ScatterReduceIntMeanModule(module, tu: TestUtils): - module.forward(tu.randint(10, 8, 6, dtype=torch.int32, high=10), tu.randint(2, 4, 3, high=4), - tu.randint(5, 8, 6, dtype=torch.int32, high=10)) -@register_test_case( - module_factory=lambda: ScatterReduceIntModule("mean", True)) -def ScatterReduceIntMeanModuleIncludeSelf(module, tu: TestUtils): - module.forward(tu.randint(10, 8, 6, dtype=torch.int32, high=10), tu.randint(2, 4, 3, high=4), - tu.randint(5, 8, 6, dtype=torch.int32, high=10)) - -# ============================================================================== diff --git a/python/torch_mlir_e2e_test/test_suite/index_select.py b/python/torch_mlir_e2e_test/test_suite/index_select.py index 6a426b454..1bb575ccb 100644 --- a/python/torch_mlir_e2e_test/test_suite/index_select.py +++ b/python/torch_mlir_e2e_test/test_suite/index_select.py @@ -28,7 +28,7 @@ class IndexSelectSingleIdxModule(torch.nn.Module): @register_test_case(module_factory=lambda: IndexSelectSingleIdxModule()) def IndexSelectSingleIdxModule_basic(module, tu: TestUtils): - module.forward(tu.rand(4, 5, 6), torch.tensor([2])) + module.forward(torch.randn(4, 5, 6), torch.tensor([2])) class IndexSelectTwoIdxModule(torch.nn.Module): @@ -47,7 +47,7 @@ class IndexSelectTwoIdxModule(torch.nn.Module): @register_test_case(module_factory=lambda: IndexSelectTwoIdxModule()) def IndexSelectTwoIdxModule_basic(module, tu: TestUtils): - module.forward(tu.rand(4, 5, 6), torch.tensor([2, 4])) + module.forward(torch.randn(4, 5, 6), torch.tensor([2, 4])) class IndexSelectWholeDimensionModule(torch.nn.Module): @@ -66,7 +66,7 @@ class IndexSelectWholeDimensionModule(torch.nn.Module): @register_test_case(module_factory=lambda: IndexSelectWholeDimensionModule()) def IndexSelectWholeDimensionModule_basic(module, tu: TestUtils): - module.forward(tu.rand(4, 5, 6), torch.tensor([0, 1, 2, 3])) + module.forward(torch.randn(4, 5, 6), torch.tensor([0, 1, 2, 3])) class IndexSelectWholeTensorModule(torch.nn.Module): @@ -85,7 +85,7 @@ class IndexSelectWholeTensorModule(torch.nn.Module): @register_test_case(module_factory=lambda: IndexSelectWholeTensorModule()) def IndexSelectWholeTensorModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3), torch.tensor([0, 1, 2])) + module.forward(torch.randn(3), torch.tensor([0, 1, 2])) class IndexSelectDynamicModule(torch.nn.Module): @@ -104,7 +104,7 @@ class IndexSelectDynamicModule(torch.nn.Module): @register_test_case(module_factory=lambda: IndexSelectDynamicModule()) def IndexSelectDynamicModulebasic(module, tu: TestUtils): - module.forward(tu.rand(4, 5, 6), torch.tensor([0, 4])) + module.forward(torch.randn(4, 5, 6), torch.tensor([0, 4])) class IndexSelectDynamicInputSizeModule(torch.nn.Module): @@ -123,7 +123,7 @@ class IndexSelectDynamicInputSizeModule(torch.nn.Module): @register_test_case(module_factory=lambda: IndexSelectDynamicInputSizeModule()) def IndexSelectDynamicInputSizeModule_basic(module, tu: TestUtils): - module.forward(tu.rand(4, 5, 6), torch.tensor([0, 2])) + module.forward(torch.randn(4, 5, 6), torch.tensor([0, 2])) class IndexSelectDynamicIndexSizeModule(torch.nn.Module): @@ -142,4 +142,4 @@ class IndexSelectDynamicIndexSizeModule(torch.nn.Module): @register_test_case(module_factory=lambda: IndexSelectDynamicIndexSizeModule()) def IndexSelectDynamicIndexSizeModule_basic(module, tu: TestUtils): - module.forward(tu.rand(4, 5, 6), torch.tensor([1, 2])) + module.forward(torch.randn(4, 5, 6), torch.tensor([1, 2])) diff --git a/python/torch_mlir_e2e_test/test_suite/nll_loss.py b/python/torch_mlir_e2e_test/test_suite/nll_loss.py index 9dcd2eff2..f5eeb1f2c 100644 --- a/python/torch_mlir_e2e_test/test_suite/nll_loss.py +++ b/python/torch_mlir_e2e_test/test_suite/nll_loss.py @@ -189,7 +189,7 @@ class NllLossModule_backwardWeight(torch.nn.Module): @register_test_case(module_factory=lambda: NllLossModule_backwardWeight()) def NllLossModuleBackwardWeight_basic(module, tu: TestUtils): module.forward(tu.rand(3), tu.rand(3, 4), torch.tensor([2, 3, 0]), - tu.rand(4), torch.tensor(3.)) + torch.rand(4), torch.tensor(3.)) @@ -279,7 +279,7 @@ class NllLossModule_backwardMeanWeight(torch.nn.Module): @register_test_case(module_factory=lambda: NllLossModule_backwardMeanWeight()) def NllLossModuleBackwardMeanWeight_basic(module, tu: TestUtils): module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]), - tu.rand(4), torch.tensor(3.)) + torch.rand(4), torch.tensor(3.)) class NllLossModule_backwardSum(torch.nn.Module): @@ -338,7 +338,7 @@ class NllLossModule_backwardSumWeight(torch.nn.Module): @register_test_case(module_factory=lambda: NllLossModule_backwardSumWeight()) def NllLossModuleBackwardSumWeight_basic(module, tu: TestUtils): module.forward(tu.rand(1), tu.rand(3, 4), torch.tensor([2, 3, 0]), - tu.rand(4), torch.tensor(3.)) + torch.rand(4), torch.tensor(3.)) class NllLossModule_backward1D(torch.nn.Module): @@ -397,7 +397,7 @@ class NllLossModule_backward1DWeight(torch.nn.Module): @register_test_case(module_factory=lambda: NllLossModule_backward1DWeight()) def NllLossModuleBackward1DWeight_basic(module, tu: TestUtils): module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]), - tu.rand(3), torch.tensor(3.)) + torch.rand(3), torch.tensor(3.)) class NllLossModule_backward1DMean(torch.nn.Module): @@ -456,7 +456,7 @@ class NllLossModule_backward1DMeanWeight(torch.nn.Module): @register_test_case(module_factory=lambda: NllLossModule_backward1DMeanWeight()) def NllLossModuleBackward1DMeanWeight_basic(module, tu: TestUtils): module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]), - tu.rand(3), torch.tensor(3.)) + torch.rand(3), torch.tensor(3.)) class NllLossModule_backward1DSum(torch.nn.Module): @@ -515,4 +515,4 @@ class NllLossModule_backward1DSumWeight(torch.nn.Module): @register_test_case(module_factory=lambda: NllLossModule_backward1DSumWeight()) def NllLossModuleBackward1DSumWeight_basic(module, tu: TestUtils): module.forward(tu.rand(1), tu.rand(3), torch.tensor([2, 3, 0]), - tu.rand(3), torch.tensor(3.)) + torch.rand(3), torch.tensor(3.)) diff --git a/python/torch_mlir_e2e_test/test_suite/reduction.py b/python/torch_mlir_e2e_test/test_suite/reduction.py index 70f5cef84..2029727ca 100644 --- a/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -315,44 +315,6 @@ def ReduceMaxAlongDim_basic(module, tu: TestUtils): # ============================================================================== -class ReduceMaxAlongDimSignedInt(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1, -1, -1], torch.int64, True), - ]) - def forward(self, a): - return torch.ops.aten.max(a, 1) - - -@register_test_case(module_factory=lambda: ReduceMaxAlongDimSignedInt()) -def ReduceMaxAlongDimSignedInt_basic(module, tu: TestUtils): - module.forward(tu.randint(3, 4, 5, low=-100, high=100)) - -# ============================================================================== - -class ReduceMaxAlongDimUnsignedInt(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1, -1, -1], torch.uint8, True), - ]) - def forward(self, a): - return torch.ops.aten.max(a, 1) - - -@register_test_case(module_factory=lambda: ReduceMaxAlongDimUnsignedInt()) -def ReduceMaxAlongDimUnsignedInt_basic(module, tu: TestUtils): - module.forward(tu.randint(3, 4, 5, low=-100, high=100).to(torch.uint8)) - -# ============================================================================== - class ReduceMaxAlongDimNegative(torch.nn.Module): def __init__(self): super().__init__() @@ -586,7 +548,7 @@ class ReduceL1NormModule(torch.nn.Module): @register_test_case(module_factory=lambda: ReduceL1NormModule()) def ReduceL1NormModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 4, 5)) + module.forward(torch.rand(3, 4, 5)) # ============================================================================== @@ -622,7 +584,7 @@ class ReduceL2NormModule(torch.nn.Module): @register_test_case(module_factory=lambda: ReduceL2NormModule()) def ReduceL2NormModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 4, 5)) + module.forward(torch.rand(3, 4, 5)) # ============================================================================== @@ -640,7 +602,7 @@ class ReduceLN3NormModule(torch.nn.Module): @register_test_case(module_factory=lambda: ReduceLN3NormModule()) def ReduceLN3NormModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 4, 5)) + module.forward(torch.rand(3, 4, 5)) # ============================================================================== @@ -658,7 +620,7 @@ class ReduceL3NormAllDimsModule(torch.nn.Module): @register_test_case(module_factory=lambda: ReduceL3NormAllDimsModule()) def ReduceL3NormAllDimsModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 4, 5)) + module.forward(torch.rand(3, 4, 5)) # ============================================================================== @@ -676,47 +638,7 @@ class ReduceL3NormKeepDimModule(torch.nn.Module): @register_test_case(module_factory=lambda: ReduceL3NormKeepDimModule()) def ReduceL3NormKeepDimModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 4, 5)) - -# ============================================================================== - -class NormScalarOptDimModule(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.p = 3.0 - - @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) - def forward(self, a): - return torch.ops.aten.norm(a, self.p, dim=[0, 1], keepdim=False) - -@register_test_case(module_factory=lambda: NormScalarOptDimModule()) -def NormScalarOptDimModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 4, 5)) - - -# ============================================================================== - -class NormScalarOptDimKeepDimModule(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.p = 3.0 - - @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ]) - def forward(self, a): - return torch.ops.aten.norm(a, self.p, dim=[0, 1], keepdim=True) - -@register_test_case(module_factory=lambda: NormScalarOptDimKeepDimModule()) -def NormScalarOptDimKeepDimModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 4, 5)) - + module.forward(torch.rand(3, 4, 5)) # ============================================================================== class ReduceFrobeniusNormModule(torch.nn.Module): @@ -733,7 +655,7 @@ class ReduceFrobeniusNormModule(torch.nn.Module): @register_test_case(module_factory=lambda: ReduceFrobeniusNormModule()) def ReduceFrobeniusNormModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 4, 5)) + module.forward(torch.rand(3, 4, 5)) # ============================================================================== class ReduceFrobeniusNormKeepDimModule(torch.nn.Module): @@ -750,7 +672,7 @@ class ReduceFrobeniusNormKeepDimModule(torch.nn.Module): @register_test_case(module_factory=lambda: ReduceFrobeniusNormKeepDimModule()) def ReduceFrobeniusNormKeepDimModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 4, 5)) + module.forward(torch.rand(3, 4, 5)) # ============================================================================== diff --git a/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/python/torch_mlir_e2e_test/test_suite/reshape_like.py index 1c2d810c1..7ac4be9e4 100644 --- a/python/torch_mlir_e2e_test/test_suite/reshape_like.py +++ b/python/torch_mlir_e2e_test/test_suite/reshape_like.py @@ -655,23 +655,6 @@ class ViewNoChangeStaticModule(torch.nn.Module): def ViewNoChangeStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 5, 6)) -class ViewNegativeStaticModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([1, 128], torch.float32, True), - ]) - - def forward(self, a): - return a.view(-1, 128) - -@register_test_case(module_factory=lambda: ViewNegativeStaticModule()) -def ViewNegativeStaticModule_basic(module, tu: TestUtils): - module.forward(tu.rand(1, 128)) - # ============================================================================== class ReshapeAliasExpandModule(torch.nn.Module): diff --git a/python/torch_mlir_e2e_test/test_suite/rng.py b/python/torch_mlir_e2e_test/test_suite/rng.py index 89fc81b8b..6096712f7 100644 --- a/python/torch_mlir_e2e_test/test_suite/rng.py +++ b/python/torch_mlir_e2e_test/test_suite/rng.py @@ -215,38 +215,6 @@ def BernoulliTensorModule_basic(module, tu: TestUtils): # ============================================================================== -class BernoulliPModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float64, True), - ([-1, -1, -1], torch.float64, True), - ]) - def forward(self, x, y): - a = torch.ops.aten.bernoulli(x, 0.4) - b = torch.ops.aten.bernoulli(y, 0.7) - mean = torch.cat([ - torch.flatten(torch.mean(a)), - torch.flatten(torch.mean(b)), - ]) - std = torch.cat([ - torch.flatten(torch.std(a)), - torch.flatten(torch.std(b)), - ]) - return mean, std - - -@register_test_case(module_factory=lambda: BernoulliPModule()) -def BernoulliPModule_basic(module, tu: TestUtils): - module.forward( - tu.rand(512, 512, 16).double(), - tu.rand(512, 512, 16).double()) - -# ============================================================================== - class RandLikeModule(torch.nn.Module): def __init__(self): super().__init__() @@ -396,46 +364,3 @@ class RandnGeneratorModule(torch.nn.Module): @register_test_case(module_factory=lambda: RandnGeneratorModule()) def RandnGeneratorModule_basic(module, tu: TestUtils): module.forward() - - -# ============================================================================== - -class RandnLikeModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float64, True), - ]) - def forward(self, x): - a = torch.ops.aten.randn_like(x) - std = torch.std(a) - return std - - -@register_test_case(module_factory=lambda: RandnLikeModule()) -def RandnLikeModule_basic(module, tu: TestUtils): - module.forward(tu.rand(4, 512, 1024).double()) - -# ============================================================================== - -class RandnLikeDtypeModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1, -1], torch.float64, True), - ]) - def forward(self, x): - a = torch.ops.aten.randn_like(x, dtype=torch.float32) - std = torch.std(a) - return std - - -@register_test_case(module_factory=lambda: RandnLikeDtypeModule()) -def RandnLikeDtypeModule_basic(module, tu: TestUtils): - module.forward(tu.rand(256, 1024).double()) diff --git a/python/torch_mlir_e2e_test/test_suite/scalar.py b/python/torch_mlir_e2e_test/test_suite/scalar.py index 74717d99f..95879b44e 100644 --- a/python/torch_mlir_e2e_test/test_suite/scalar.py +++ b/python/torch_mlir_e2e_test/test_suite/scalar.py @@ -75,7 +75,7 @@ class SubFloatModule(torch.nn.Module): @register_test_case(module_factory=lambda: SubFloatModule()) def SubFloatModule_basic(module, tu: TestUtils): - module.forward(tu.rand().double(), tu.rand().double()) + module.forward(torch.rand(()).double(), torch.rand(()).double()) # ============================================================================== @@ -146,7 +146,7 @@ class DivFloatModule(torch.nn.Module): @register_test_case(module_factory=lambda: DivFloatModule()) def DivFloatModule_basic(module, tu: TestUtils): - module.forward(tu.rand().double(), tu.rand().double()) + module.forward(torch.rand(()).double(), torch.rand(()).double()) # ============================================================================== @@ -175,7 +175,7 @@ class CeilFloatModule(torch.nn.Module): @register_test_case(module_factory=lambda: CeilFloatModule()) def CeilFloatModule_basic(module, tu: TestUtils): - module.forward(tu.rand().double(), tu.rand().double()) + module.forward(torch.rand(()).double(), torch.rand(()).double()) # ============================================================================== @@ -339,59 +339,6 @@ def BoolIntConstantModule_basic(module, tu: TestUtils): # ============================================================================== -class AtenIntBoolOpModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([], torch.bool, True), - ]) - def forward(self, x): - return int(torch.ops.aten.Int(x)) - - -@register_test_case(module_factory=lambda: AtenIntBoolOpModule()) -def AtenIntBoolOpModule_basic(module, tu: TestUtils): - module.forward(tu.randint(low=0, high=2).bool()) - - -class AtenIntBoolOpConstTrueModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ]) - def forward(self): - return int(torch.ops.aten.Int(True)) - - -@register_test_case(module_factory=lambda: AtenIntBoolOpConstTrueModule()) -def AtenIntBoolOpConstTrueModule_basic(module, tu: TestUtils): - module.forward() - - -class AtenIntBoolOpConstFalseModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ]) - def forward(self): - return int(torch.ops.aten.Int(False)) - - -@register_test_case(module_factory=lambda: AtenIntBoolOpConstFalseModule()) -def AtenIntBoolOpConstFalseModule_basic(module, tu: TestUtils): - module.forward() - -# ============================================================================== - class AtenIntTensorByteDtypeModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py b/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py index 25f73e349..d9d0bd121 100644 --- a/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py +++ b/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py @@ -121,7 +121,7 @@ class GeFloatModule(torch.nn.Module): @register_test_case(module_factory=lambda: GeFloatModule()) def GeFloatModule_basic(module, tu: TestUtils): - module.forward(tu.rand().double(), tu.rand().double()) + module.forward(torch.randn(()).double(), torch.randn(()).double()) # ============================================================================== @@ -144,7 +144,7 @@ class GeFloatIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: GeFloatIntModule()) def GeFloatIntModule_basic(module, tu: TestUtils): - module.forward(tu.rand().double(), tu.randint(low=-100, high=100)) + module.forward(torch.randn(()).double(), tu.randint(low=-100, high=100)) # ============================================================================== @@ -167,7 +167,7 @@ class NeFloatIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: NeFloatIntModule()) def NeFloatIntModule_basic(module, tu: TestUtils): - module.forward(tu.rand().double(), tu.randint(low=-100, high=100)) + module.forward(torch.randn(()).double(), tu.randint(low=-100, high=100)) # ============================================================================== @@ -190,4 +190,4 @@ class GtFloatIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: GtFloatIntModule()) def GtFloatIntModule_basic(module, tu: TestUtils): - module.forward(tu.rand().double(), tu.randint(low=-100, high=100)) + module.forward(torch.randn(()).double(), tu.randint(low=-100, high=100)) diff --git a/python/torch_mlir_e2e_test/test_suite/slice_like.py b/python/torch_mlir_e2e_test/test_suite/slice_like.py index 1e8566826..8776a0b4c 100644 --- a/python/torch_mlir_e2e_test/test_suite/slice_like.py +++ b/python/torch_mlir_e2e_test/test_suite/slice_like.py @@ -243,30 +243,12 @@ class SelectIntModule(torch.nn.Module): ([-1, -1], torch.int64, True), ]) def forward(self, x): - return torch.select(x, dim=0, index=0) + return x.select(0,0) @register_test_case(module_factory=lambda: SelectIntModule()) def SelectIntModule_basic(module, tu: TestUtils): - module.forward(tu.randint(5, 5, high=10)) - - -class SelectIntNegativeDimAndIndexStaticModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([5, 5], torch.int64, True), - ]) - def forward(self, x): - return torch.select(x, dim=-1, index=-1) - - -@register_test_case(module_factory=lambda: SelectIntNegativeDimAndIndexStaticModule()) -def SelectIntNegativeDimAndIndexStaticModule_basic(module, tu: TestUtils): - module.forward(tu.randint(5, 5, high=10)) + module.forward(tu.randint(5,5, high=10)) # ============================================================================== @@ -384,7 +366,7 @@ class SelectScatterModule(torch.nn.Module): @register_test_case(module_factory=lambda: SelectScatterModule()) def SelectScattertModule_basic(module, tu: TestUtils): - module.forward(tu.rand(6, 8, 5), tu.rand(8, 5)) + module.forward(torch.rand(6, 8, 5), torch.rand(8, 5)) class SelectScatterStaticModule(torch.nn.Module): def __init__(self): @@ -402,7 +384,7 @@ class SelectScatterStaticModule(torch.nn.Module): @register_test_case(module_factory=lambda: SelectScatterStaticModule()) def SelectScattertStaticModule_basic(module, tu: TestUtils): - module.forward(tu.rand(6, 8, 5), tu.rand(6, 5)) + module.forward(torch.rand(6, 8, 5), torch.rand(6, 5)) # ============================================================================== @@ -481,47 +463,3 @@ class NarrowVerticalTest2(torch.nn.Module): @register_test_case(module_factory=lambda: NarrowVerticalTest2()) def NarrowVerticalTest2_basic(module, tu: TestUtils): module.forward(tu.rand(6,4)) - -# ============================================================================== - -class SliceCopy_Module(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([10, 4, 4], torch.float32, True), - ([4, 4, 4], torch.float32, True), - ]) - def forward(self, x, y): - xslice = torch.ops.aten.slice(x, 0, 2, 6, 1) - xslice.copy_(y) - return x - - -@register_test_case(module_factory=lambda: SliceCopy_Module()) -def SliceCopy_Module_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 4, 4), tu.rand(4, 4, 4)) - -# ============================================================================== - -class SliceCopyNegative_Module(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ]) - def forward(self, x, y): - xslice = torch.ops.aten.slice(x, 0, 2, -4, 1) - xslice.copy_(y) - return x - - -@register_test_case(module_factory=lambda: SliceCopyNegative_Module()) -def SliceCopyNegative_Module_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 4, 4), tu.rand(4, 4, 4)) diff --git a/python/torch_mlir_e2e_test/test_suite/threshold.py b/python/torch_mlir_e2e_test/test_suite/threshold.py index 674f88e89..8efa7e7e2 100644 --- a/python/torch_mlir_e2e_test/test_suite/threshold.py +++ b/python/torch_mlir_e2e_test/test_suite/threshold.py @@ -99,7 +99,7 @@ class Threshold1dFloatModule(torch.nn.Module): @register_test_case(module_factory=lambda: Threshold1dFloatModule()) def Threshold1dFloatModule_basic(module, tu: TestUtils): - module.forward(tu.rand(4)) + module.forward(torch.randn(4)) class Threshold2dFloatModule(torch.nn.Module): @@ -117,7 +117,7 @@ class Threshold2dFloatModule(torch.nn.Module): @register_test_case(module_factory=lambda: Threshold2dFloatModule()) def Threshold2dFloatModule_basic(module, tu: TestUtils): - module.forward(tu.rand(4, 5)) + module.forward(torch.randn(4, 5)) class Threshold3dFloatModule(torch.nn.Module): @@ -135,7 +135,7 @@ class Threshold3dFloatModule(torch.nn.Module): @register_test_case(module_factory=lambda: Threshold3dFloatModule()) def Threshold3dFloatModule_basic(module, tu: TestUtils): - module.forward(tu.rand(4, 5, 6)) + module.forward(torch.randn(4, 5, 6)) class ThresholdBackward1dIntModule(torch.nn.Module): @@ -211,7 +211,7 @@ class ThresholdBackward1dFloatModule(torch.nn.Module): @register_test_case(module_factory=lambda: ThresholdBackward1dFloatModule()) def ThresholdBackward1dFloatModule_basic(module, tu: TestUtils): - module.forward(tu.rand(4), tu.rand(4)) + module.forward(torch.randn(4), torch.randn(4)) class ThresholdBackward2dFloatModule(torch.nn.Module): @@ -230,7 +230,7 @@ class ThresholdBackward2dFloatModule(torch.nn.Module): @register_test_case(module_factory=lambda: ThresholdBackward2dFloatModule()) def ThresholdBackward2dFloatModule_basic(module, tu: TestUtils): - module.forward(tu.rand(4, 5), tu.rand(4, 5)) + module.forward(torch.randn(4, 5), torch.randn(4, 5)) class ThresholdBackward3dFloatModule(torch.nn.Module): @@ -249,7 +249,7 @@ class ThresholdBackward3dFloatModule(torch.nn.Module): @register_test_case(module_factory=lambda: ThresholdBackward3dFloatModule()) def ThresholdBackward3dFloatModule_basic(module, tu: TestUtils): - module.forward(tu.rand(4, 5, 6), tu.rand(4, 5, 6)) + module.forward(torch.randn(4, 5, 6), torch.randn(4, 5, 6)) class ThresholdBackward1dMixedModule(torch.nn.Module): @@ -268,7 +268,7 @@ class ThresholdBackward1dMixedModule(torch.nn.Module): @register_test_case(module_factory=lambda: ThresholdBackward1dMixedModule()) def ThresholdBackward1dMixedModule_basic(module, tu: TestUtils): - module.forward(tu.rand(4), tu.randint(4, high=10)) + module.forward(torch.randn(4), tu.randint(4, high=10)) class ThresholdBackward2dMixedModule(torch.nn.Module): @@ -287,7 +287,7 @@ class ThresholdBackward2dMixedModule(torch.nn.Module): @register_test_case(module_factory=lambda: ThresholdBackward2dMixedModule()) def ThresholdBackward2dMixedModule_basic(module, tu: TestUtils): - module.forward(tu.randint(4, 5, high=20), tu.rand(4, 5)) + module.forward(tu.randint(4, 5, high=20), torch.randn(4, 5)) class ThresholdBackward3dMixedModule(torch.nn.Module): @@ -306,4 +306,4 @@ class ThresholdBackward3dMixedModule(torch.nn.Module): @register_test_case(module_factory=lambda: ThresholdBackward3dMixedModule()) def ThresholdBackward3dMixedModule_basic(module, tu: TestUtils): - module.forward(tu.rand(4, 5, 6), tu.randint(4, 5, 6, high=10)) + module.forward(torch.randn(4, 5, 6), tu.randint(4, 5, 6, high=10)) diff --git a/python/torch_mlir_e2e_test/test_suite/type_promotion.py b/python/torch_mlir_e2e_test/test_suite/type_promotion.py index 41c03ec18..f2ff36fd8 100644 --- a/python/torch_mlir_e2e_test/test_suite/type_promotion.py +++ b/python/torch_mlir_e2e_test/test_suite/type_promotion.py @@ -51,7 +51,7 @@ class TypePromotionDifferentCategoryModule(torch.nn.Module): @register_test_case( module_factory=lambda: TypePromotionDifferentCategoryModule()) def TypePromotionDifferentCategoryModule_basic(module, tu: TestUtils): - module.forward(tu.randint(4, high=10), tu.rand(4)) + module.forward(tu.randint(4, high=10), torch.randn(4)) class TypePromotionSameCategoryZeroRankWiderModule(torch.nn.Module): diff --git a/pytorch-hash.txt b/pytorch-hash.txt index 65012500d..0cb776b0d 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -c54ce93106ef7d893be87a9f7b0e0bd98724b539 +a4dd47e06e53b0ee51081a62aa63a98fde260d67 diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index 93c8cf99c..b4f7a6340 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -1,3 +1,4 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torch==2.1.0.dev20230310 +torch==2.0.0.dev20230116 +torchvision==0.15.0.dev20230116 diff --git a/requirements.txt b/requirements.txt index f346b53da..01478ee1f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,6 @@ --r pytorch-requirements.txt -r build-requirements.txt --r test-requirements.txt + +# Test Requirements +pillow +dill +multiprocess diff --git a/setup.py b/setup.py index 11459041d..9ba763351 100644 --- a/setup.py +++ b/setup.py @@ -41,14 +41,12 @@ from setuptools import setup, Extension from setuptools.command.build_ext import build_ext from setuptools.command.build_py import build_py +import torch PACKAGE_VERSION = os.environ.get("TORCH_MLIR_PYTHON_PACKAGE_VERSION") or "0.0.1" # If true, enable LTC build by default TORCH_MLIR_ENABLE_LTC_DEFAULT = True -TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS = int(os.environ.get('TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS', False)) -if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS: - import torch # Build phase discovery is unreliable. Just tell it what phases to run. class CustomBuild(_build): @@ -92,7 +90,6 @@ class CMakeBuild(build_py): f"-DCMAKE_C_VISIBILITY_PRESET=hidden", f"-DCMAKE_CXX_VISIBILITY_PRESET=hidden", f"-DTORCH_MLIR_ENABLE_LTC={'ON' if enable_ltc else 'OFF'}", - f"-DTORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS={'ON' if TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else 'OFF'}", ] os.makedirs(cmake_build_dir, exist_ok=True) @@ -146,7 +143,7 @@ with open("README.md", "r", encoding="utf-8") as fh: setup( - name="torch-mlir" if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else "torch-mlir-core", + name="torch-mlir", version=f"{PACKAGE_VERSION}", author="Sean Silva", author_email="silvasean@google.com", @@ -161,8 +158,10 @@ setup( }, ext_modules=[ CMakeExtension("torch_mlir._mlir_libs._jit_ir_importer"), - ] if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else [CMakeExtension("torch_mlir._mlir_libs._torchMlir")], - install_requires=["numpy", ] + ( - [f"torch=={torch.__version__}".split("+", 1)[0], ] if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else []), + ], + install_requires=[ + "numpy", + f"torch=={torch.__version__}".split("+", 1)[0], + ], zip_safe=False, ) diff --git a/test-requirements.txt b/test-requirements.txt deleted file mode 100644 index e752531e2..000000000 --- a/test-requirements.txt +++ /dev/null @@ -1,5 +0,0 @@ --r torchvision-requirements.txt - -pillow -dill -multiprocess diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 51407b488..8b444bd1d 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,7 +1,7 @@ llvm_canonicalize_cmake_booleans( MLIR_ENABLE_BINDINGS_PYTHON TORCH_MLIR_ENABLE_JIT_IR_IMPORTER - TORCH_MLIR_ENABLE_STABLEHLO + TORCH_MLIR_ENABLE_MHLO ) configure_lit_site_cfg( diff --git a/test/Conversion/TorchToMhlo/basic.mlir b/test/Conversion/TorchToMhlo/basic.mlir index aae5c91e7..bea58bd40 100644 --- a/test/Conversion/TorchToMhlo/basic.mlir +++ b/test/Conversion/TorchToMhlo/basic.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s +// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s // ----- @@ -7,7 +7,7 @@ // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[T1:.*]] = stablehlo.convert %[[T0]] : tensor +// CHECK: %[[T1:.*]] = mhlo.copy %[[T0]] : tensor // CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.clone$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -19,7 +19,7 @@ func.func @torch.aten.clone$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vt // ----- // CHECK-LABEL: func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> { -// CHECK: %[[VAL_0:.*]] = stablehlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[VAL_0:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor -> !torch.vtensor<[],f32> // CHECK: return %[[VAL_1]] : !torch.vtensor<[],f32> func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> { @@ -30,7 +30,7 @@ func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> { // ----- // CHECK-LABEL: func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> { -// CHECK: %[[VAL_0:.*]] = stablehlo.constant dense<1> : tensor<2xi64> +// CHECK: %[[VAL_0:.*]] = mhlo.constant dense<1> : tensor<2xi64> // CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<2xi64> -> !torch.vtensor<[2],si64> // CHECK: return %[[VAL_1]] : !torch.vtensor<[2],si64> func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> { @@ -45,8 +45,8 @@ func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> { // CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[T0:.*]] = torch_c.to_i64 %[[INT1]] // CHECK: %[[T1:.*]] = tensor.from_elements %[[T0]] : tensor<1xi64> -// CHECK: %[[T2:.*]] = stablehlo.convert %[[T1]] : tensor<1xi64> -// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xi64>) -> tensor +// CHECK: %[[T2:.*]] = mhlo.convert %[[T1]] : tensor<1xi64> +// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xi64>) -> tensor // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[],si64> // CHECK: return %[[T4]] : !torch.vtensor<[],si64> func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[], si64> { @@ -75,7 +75,7 @@ func.func @torch.aten.contiguous(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vt // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?],f32> -> tensor // CHECK: %[[VAL_2:.*]] = "chlo.constant_like"(%[[VAL_1]]) {value = 1.000000e+00 : f32} : (tensor) -> tensor -// CHECK: %[[VAL_3:.*]] = stablehlo.divide %[[VAL_2]], %[[VAL_1]] : tensor +// CHECK: %[[VAL_3:.*]] = mhlo.divide %[[VAL_2]], %[[VAL_1]] : tensor // CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor -> !torch.vtensor<[?,?,?],f32> // CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?,?],f32> func.func @torch.aten.reciprocal(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { @@ -91,7 +91,7 @@ func.func @torch.aten.reciprocal(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.v // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],f32> -> tensor<4x3xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.int 0 // CHECK: %[[VAL_3:.*]] = torch.constant.int 1 -// CHECK: %[[VAL_4:.*]] = stablehlo.transpose %[[VAL_1]], dims = [1, 0] : (tensor<4x3xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_4:.*]] = "mhlo.transpose"(%[[VAL_1]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x3xf32>) -> tensor<3x4xf32> // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> // CHECK: return %[[VAL_5]] : !torch.vtensor<[3,4],f32> func.func @torch.aten.transpose$basic(%arg0: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,4],f32> { @@ -118,7 +118,7 @@ func.func @torch.aten.transpose$basic(%arg0: !torch.vtensor<[4,3],f32>) -> !torc // CHECK: %[[VAL_7:.*]] = arith.constant 1 : index // CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1:.*]], %[[VAL_7]] : tensor // CHECK: %[[VAL_9:.*]] = tensor.from_elements %[[VAL_4]], %[[VAL_6]], %[[VAL_8]] : tensor<3xindex> -// CHECK: %[[VAL_10:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_1]], %[[VAL_9]], dims = [1, 2] : (tensor, tensor<3xindex>) -> tensor<8x4x?xf32> +// CHECK: %[[VAL_10:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_1]], %[[VAL_9]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor, tensor<3xindex>) -> tensor<8x4x?xf32> // CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<8x4x?xf32> -> !torch.vtensor<[8,4,?],f32> // CHECK: return %[[VAL_11]] : !torch.vtensor<[8,4,?],f32> func.func @torch.aten.broadcast_to$dynamic_implicit(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[8,4,?],f32> { @@ -135,15 +135,15 @@ func.func @torch.aten.broadcast_to$dynamic_implicit(%arg0: !torch.vtensor<[?,?], // CHECK-LABEL: func.func @torch.aten.batch_norm$training( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor -// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<3xf32> -// CHECK: %[[VAL_3:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<3xf32> +// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32> +// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32> // CHECK: %true = torch.constant.bool true // CHECK: %float1.000000e-01 = torch.constant.float 1.000000e-01 // CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05 // CHECK: %[[VAL_4:.*]] = arith.constant 1 : index // CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor // CHECK: %[[VAL_6:.*]] = tensor.from_elements %[[VAL_5]] : tensor<1xindex> -// CHECK: %[[VAL_7:.*]], %[[VAL_8:.*]], %[[VAL_9:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor, tensor<3xf32>, tensor<3xf32>) -> (tensor, tensor<3xf32>, tensor<3xf32>) +// CHECK: %[[VAL_7:.*]], %[[VAL_8:.*]], %[[VAL_9:.*]] = "mhlo.batch_norm_training"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor, tensor<3xf32>, tensor<3xf32>) -> (tensor, tensor<3xf32>, tensor<3xf32>) // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,3,?,?],f32> // CHECK: return %[[VAL_8]] : !torch.vtensor<[?,3,?,?],f32> func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> { @@ -161,8 +161,8 @@ func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>) // CHECK-LABEL: func.func @torch.aten.batch_norm$inference( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor -// CHECK: %[[T1:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<3xf32> -// CHECK: %[[T2:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<3xf32> +// CHECK: %[[T1:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32> +// CHECK: %[[T2:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32> // CHECK: %[[TRUE:.*]] = torch.constant.bool true // CHECK: %[[FALSE:.*]] = torch.constant.bool false // CHECK: %[[FLOAT1:.*]].000000e-01 = torch.constant.float 1.000000e-01 @@ -171,7 +171,7 @@ func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>) // CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor // CHECK: %[[T4:.*]] = tensor.from_elements %[[T3]] : tensor<1xindex> // CHECK: %[[T5:.*]] = tensor.cast %[[T0]] : tensor to tensor -// CHECK: %[[T6:.*]] = "stablehlo.batch_norm_inference"(%[[T5]], %[[T2]], %[[T1]], %[[T1]], %[[T2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor +// CHECK: %[[T6:.*]] = "mhlo.batch_norm_inference"(%[[T5]], %[[T2]], %[[T1]], %[[T1]], %[[T2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor // CHECK: %[[T7:.*]] = tensor.cast %[[T6]] : tensor to tensor // CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor -> !torch.vtensor<[?,3,?,?],f32> // CHECK: return %[[T8]] : !torch.vtensor<[?,3,?,?],f32> @@ -192,19 +192,19 @@ func.func @torch.aten.batch_norm$inference(%arg0: !torch.vtensor<[?,3,?,?],f32>) // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor // CHECK: %none = torch.constant.none -// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<3xf32> -// CHECK: %[[VAL_3:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<3xf32> +// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32> +// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32> // CHECK: %true = torch.constant.bool true // CHECK: %float1.000000e-01 = torch.constant.float 1.000000e-01 // CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05 // CHECK: %[[VAL_4:.*]] = arith.constant 1 : index // CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor // CHECK: %[[VAL_6:.*]] = tensor.from_elements %[[VAL_5]] : tensor<1xindex> -// CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor -// CHECK: %[[VAL_8:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_6]], dims = [] : (tensor, tensor<1xindex>) -> tensor<3xf32> -// CHECK: %[[VAL_9:.*]] = stablehlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[VAL_10:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_9]], %[[VAL_6]], dims = [] : (tensor, tensor<1xindex>) -> tensor<3xf32> -// CHECK: %[[VAL_11:.*]], %[[VAL_12:.*]], %[[VAL_13:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_8]], %[[VAL_10]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor, tensor<3xf32>, tensor<3xf32>) -> (tensor, tensor<3xf32>, tensor<3xf32>) +// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor +// CHECK: %[[VAL_8:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_7]], %[[VAL_6]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor<3xf32> +// CHECK: %[[VAL_9:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[VAL_10:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_9]], %[[VAL_6]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor<3xf32> +// CHECK: %[[VAL_11:.*]], %[[VAL_12:.*]], %[[VAL_13:.*]] = "mhlo.batch_norm_training"(%[[VAL_1]], %[[VAL_8]], %[[VAL_10]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor, tensor<3xf32>, tensor<3xf32>) -> (tensor, tensor<3xf32>, tensor<3xf32>) // CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor -> !torch.vtensor<[?,3,?,?],f32> // CHECK: return %[[VAL_14]] : !torch.vtensor<[?,3,?,?],f32> func.func @torch.aten.batch_norm$no_bias_weight(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> { @@ -222,28 +222,28 @@ func.func @torch.aten.batch_norm$no_bias_weight(%arg0: !torch.vtensor<[?,3,?,?], // CHECK-LABEL: func @torch.aten.native_layer_norm( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,7,4,5],f32>) -> !torch.vtensor<[3,7,4,5],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,7,4,5],f32> -> tensor<3x7x4x5xf32> -// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<4x5xf32> -// CHECK: %[[VAL_3:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<4x5xf32> +// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<4x5xf32> +// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<4x5xf32> // CHECK: %int4 = torch.constant.int 4 // CHECK: %int5 = torch.constant.int 5 // CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05 // CHECK: %true = torch.constant.bool true // CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<[1, 21, 20]> : tensor<3xi64> -// CHECK: %[[VAL_6:.*]] = stablehlo.dynamic_reshape %[[VAL_1]], %[[VAL_5]] : (tensor<3x7x4x5xf32>, tensor<3xi64>) -> tensor<1x21x20xf32> -// CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<21xf32> -// CHECK: %[[VAL_8:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<21xf32> -// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]] = "stablehlo.batch_norm_training"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) -> (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) -// CHECK: %[[VAL_12:.*]] = stablehlo.constant dense<[3, 7, 4, 5]> : tensor<4xi64> -// CHECK: %[[VAL_13:.*]] = stablehlo.dynamic_reshape %[[VAL_9]], %[[VAL_12]] : (tensor<1x21x20xf32>, tensor<4xi64>) -> tensor<3x7x4x5xf32> -// CHECK: %[[VAL_14:.*]] = stablehlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64> -// CHECK: %[[VAL_15:.*]] = stablehlo.dynamic_reshape %[[VAL_10]], %[[VAL_14]] : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32> -// CHECK: %[[VAL_16:.*]] = stablehlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64> -// CHECK: %[[VAL_17:.*]] = stablehlo.dynamic_reshape %[[VAL_11]], %[[VAL_16]] : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32> -// CHECK: %[[VAL_18:.*]] = stablehlo.broadcast_in_dim %[[VAL_3]], dims = [2, 3] : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32> -// CHECK: %[[VAL_19:.*]] = stablehlo.broadcast_in_dim %[[VAL_2]], dims = [2, 3] : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32> -// CHECK: %[[VAL_20:.*]] = stablehlo.multiply %[[VAL_13]], %[[VAL_18]] : tensor<3x7x4x5xf32> -// CHECK: %[[VAL_21:.*]] = stablehlo.add %[[VAL_20]], %[[VAL_19]] : tensor<3x7x4x5xf32> +// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<[1, 21, 20]> : tensor<3xi64> +// CHECK: %[[VAL_6:.*]] = mhlo.dynamic_reshape %[[VAL_1]], %[[VAL_5]] : (tensor<3x7x4x5xf32>, tensor<3xi64>) -> tensor<1x21x20xf32> +// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor<21xf32> +// CHECK: %[[VAL_8:.*]] = mhlo.constant dense<0.000000e+00> : tensor<21xf32> +// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]] = "mhlo.batch_norm_training"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) -> (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) +// CHECK: %[[VAL_12:.*]] = mhlo.constant dense<[3, 7, 4, 5]> : tensor<4xi64> +// CHECK: %[[VAL_13:.*]] = mhlo.dynamic_reshape %[[VAL_9]], %[[VAL_12]] : (tensor<1x21x20xf32>, tensor<4xi64>) -> tensor<3x7x4x5xf32> +// CHECK: %[[VAL_14:.*]] = mhlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64> +// CHECK: %[[VAL_15:.*]] = mhlo.dynamic_reshape %[[VAL_10]], %[[VAL_14]] : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32> +// CHECK: %[[VAL_16:.*]] = mhlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64> +// CHECK: %[[VAL_17:.*]] = mhlo.dynamic_reshape %[[VAL_11]], %[[VAL_16]] : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32> +// CHECK: %[[VAL_18:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_3]]) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32> +// CHECK: %[[VAL_19:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_2]]) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32> +// CHECK: %[[VAL_20:.*]] = mhlo.multiply %[[VAL_13]], %[[VAL_18]] : tensor<3x7x4x5xf32> +// CHECK: %[[VAL_21:.*]] = mhlo.add %[[VAL_20]], %[[VAL_19]] : tensor<3x7x4x5xf32> // CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21:.*]] : tensor<3x7x4x5xf32> -> !torch.vtensor<[3,7,4,5],f32> // CHECK: return %[[VAL_22]] : !torch.vtensor<[3,7,4,5],f32> func.func @torch.aten.native_layer_norm(%arg0: !torch.vtensor<[3,7,4,5],f32>) -> !torch.vtensor<[3,7,4,5],f32> { @@ -267,8 +267,8 @@ func.func @torch.aten.native_layer_norm(%arg0: !torch.vtensor<[3,7,4,5],f32>) -> // CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[ARG0]], %[[ARG1]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],si32>) -> !torch.list // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[T2:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si32> -> tensor -// CHECK: %[[T3:.*]] = stablehlo.convert %[[T2]] : (tensor) -> tensor -// CHECK: %[[T4:.*]] = stablehlo.concatenate %[[T1]], %[[T3]], dim = 0 : (tensor, tensor) -> tensor +// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : (tensor) -> tensor +// CHECK: %[[T4:.*]] = "mhlo.concatenate"(%[[T1]], %[[T3]]) {dimension = 0 : i64} : (tensor, tensor) -> tensor // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.cat$convert(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],f32> { @@ -287,7 +287,7 @@ func.func @torch.aten.cat$convert(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torc // CHECK: %[[VAL_0:.*]] = torch.prim.ListConstruct %[[ARG_0]], %[[ARG_1]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[VAL_3:.*]] = stablehlo.concatenate %[[VAL_1]], %[[VAL_2]], dim = 0 : (tensor, tensor) -> tensor +// CHECK: %[[VAL_3:.*]] = "mhlo.concatenate"(%[[VAL_1]], %[[VAL_2]]) {dimension = 0 : i64} : (tensor, tensor) -> tensor // CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.cat(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { diff --git a/test/Conversion/TorchToMhlo/elementwise.mlir b/test/Conversion/TorchToMhlo/elementwise.mlir index b1d560e4f..6b3faace0 100644 --- a/test/Conversion/TorchToMhlo/elementwise.mlir +++ b/test/Conversion/TorchToMhlo/elementwise.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s +// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s // CHECK-LABEL: func.func @torch.aten.gelu( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -7,12 +7,12 @@ // CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) {value = 1.000000e+00 : f32} : (tensor) -> tensor // CHECK: %[[T2:.*]] = "chlo.constant_like"(%[[T0]]) {value = 2.000000e+00 : f32} : (tensor) -> tensor // CHECK: %[[T3:.*]] = "chlo.constant_like"(%[[T0]]) {value = 5.000000e-01 : f32} : (tensor) -> tensor -// CHECK: %[[T4:.*]] = stablehlo.rsqrt %[[T2]] : tensor -// CHECK: %[[T5:.*]] = stablehlo.multiply %[[T0]], %[[T4]] : tensor +// CHECK: %[[T4:.*]] = mhlo.rsqrt %[[T2]] : tensor +// CHECK: %[[T5:.*]] = mhlo.multiply %[[T0]], %[[T4]] : tensor // CHECK: %[[T6:.*]] = chlo.erf %[[T5]] : tensor -> tensor -// CHECK: %[[T7:.*]] = stablehlo.add %[[T6]], %[[T1]] : tensor -// CHECK: %[[T8:.*]] = stablehlo.multiply %[[T7]], %[[T3]] : tensor -// CHECK: %[[T9:.*]] = stablehlo.multiply %[[T0]], %[[T8]] : tensor +// CHECK: %[[T7:.*]] = mhlo.add %[[T6]], %[[T1]] : tensor +// CHECK: %[[T8:.*]] = mhlo.multiply %[[T7]], %[[T3]] : tensor +// CHECK: %[[T9:.*]] = mhlo.multiply %[[T0]], %[[T8]] : tensor // CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T10]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.gelu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -26,7 +26,7 @@ func.func @torch.aten.gelu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[ // CHECK-LABEL: func.func @torch.aten.tanh$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[T1:.*]] = stablehlo.tanh %[[T0]] : tensor +// CHECK: %[[T1:.*]] = mhlo.tanh %[[T0]] : tensor // CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -39,7 +39,7 @@ func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte // CHECK-LABEL: func.func @torch.aten.log$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[T1:.*]] = stablehlo.log %[[T0]] : tensor +// CHECK: %[[T1:.*]] = mhlo.log %[[T0]] : tensor // CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.log$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -52,7 +52,7 @@ func.func @torch.aten.log$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten // CHECK-LABEL: func.func @torch.aten.exp$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[T1:.*]] = stablehlo.exponential %[[T0]] : tensor +// CHECK: %[[T1:.*]] = mhlo.exponential %[[T0]] : tensor // CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.exp$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -65,7 +65,7 @@ func.func @torch.aten.exp$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten // CHECK-LABEL: func.func @torch.aten.neg$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[T1:.*]] = stablehlo.negate %[[T0]] : tensor +// CHECK: %[[T1:.*]] = mhlo.negate %[[T0]] : tensor // CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.neg$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -78,7 +78,7 @@ func.func @torch.aten.neg$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten // CHECK-LABEL: func.func @torch.aten.rsqrt$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[T1:.*]] = stablehlo.rsqrt %[[T0]] : tensor +// CHECK: %[[T1:.*]] = mhlo.rsqrt %[[T0]] : tensor // CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -91,7 +91,7 @@ func.func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vt // CHECK-LABEL: func.func @torch.aten.sigmoid$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[T1:.*]] = stablehlo.logistic %[[T0]] : tensor +// CHECK: %[[T1:.*]] = mhlo.logistic %[[T0]] : tensor // CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.sigmoid$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -108,8 +108,8 @@ func.func @torch.aten.sigmoid$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch. // CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]] // CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> -// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> -// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> +// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor // CHECK: %[[T4:.*]] = chlo.broadcast_add %[[T0]], %[[T3]] : (tensor, tensor) -> tensor // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32> @@ -130,11 +130,11 @@ func.func @torch.aten.addscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc // CHECK: %[[INT2:.*]] = torch.constant.int 2 // CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]] // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> -// CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> -// CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T3:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> +// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor // CHECK: %[[FROM_ELEMENTS_0:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64> -// CHECK: %[[T5:.*]] = stablehlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xi64>) -> tensor<1xf32> -// CHECK: %[[T6:.*]] = stablehlo.reshape %[[T5]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T5:.*]] = mhlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xi64>) -> tensor<1xf32> +// CHECK: %[[T6:.*]] = mhlo.reshape %[[T5]] : (tensor<1xf32>) -> tensor // CHECK: %[[T7:.*]] = chlo.broadcast_multiply %[[T4]], %[[T6]] : (tensor, tensor) -> tensor // CHECK: %[[T8:.*]] = chlo.broadcast_add %[[T0]], %[[T7]] : (tensor, tensor) -> tensor // CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor -> !torch.vtensor<[?,?],f32> @@ -171,8 +171,8 @@ func.func @torch.aten.addtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // CHECK: %[[INT2:.*]] = torch.constant.int 2 // CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]] // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64> -// CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> -// CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T3:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> +// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor // CHECK: %[[T5:.*]] = chlo.broadcast_multiply %[[T1]], %[[T4]] : (tensor, tensor) -> tensor // CHECK: %[[T6:.*]] = chlo.broadcast_add %[[T0]], %[[T5]] : (tensor, tensor) -> tensor // CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor -> !torch.vtensor<[?,?],f32> @@ -190,7 +190,7 @@ func.func @torch.aten.addtensor$alpha(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],si32> -> tensor // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si64> -> tensor // CHECK: %[[INT1:.*]] = torch.constant.int 1 -// CHECK: %[[T2:.*]] = stablehlo.convert %[[T0]] : (tensor) -> tensor +// CHECK: %[[T2:.*]] = mhlo.convert %[[T0]] : (tensor) -> tensor // CHECK: %[[T3:.*]] = chlo.broadcast_add %[[T2]], %[[T1]] : (tensor, tensor) -> tensor // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[?,?],si64> // CHECK: return %[[T4]] : !torch.vtensor<[?,?],si64> @@ -209,8 +209,8 @@ func.func @torch.aten.addtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1 // CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]] // CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> -// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> -// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> +// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor // CHECK: %[[T4:.*]] = chlo.broadcast_subtract %[[T0]], %[[T3]] : (tensor, tensor) -> tensor // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32> @@ -230,8 +230,8 @@ func.func @torch.aten.subscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc // CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]] // CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> -// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> -// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> +// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor // CHECK: %[[T4:.*]] = chlo.broadcast_subtract %[[T3]], %[[T0]] : (tensor, tensor) -> tensor // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32> @@ -252,11 +252,11 @@ func.func @torch.aten.rsubscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !tor // CHECK: %[[INT2:.*]] = torch.constant.int 2 // CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]] // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> -// CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> -// CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T3:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> +// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor // CHECK: %[[FROM_ELEMENTS_0:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64> -// CHECK: %[[T5:.*]] = stablehlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xi64>) -> tensor<1xf32> -// CHECK: %[[T6:.*]] = stablehlo.reshape %[[T5]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T5:.*]] = mhlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xi64>) -> tensor<1xf32> +// CHECK: %[[T6:.*]] = mhlo.reshape %[[T5]] : (tensor<1xf32>) -> tensor // CHECK: %[[T7:.*]] = chlo.broadcast_multiply %[[T4]], %[[T6]] : (tensor, tensor) -> tensor // CHECK: %[[T8:.*]] = chlo.broadcast_subtract %[[T0]], %[[T7]] : (tensor, tensor) -> tensor // CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor -> !torch.vtensor<[?,?],f32> @@ -293,8 +293,8 @@ func.func @torch.aten.subtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // CHECK: %[[INT2:.*]] = torch.constant.int 2 // CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]] // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64> -// CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> -// CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T3:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> +// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor // CHECK: %[[T5:.*]] = chlo.broadcast_multiply %[[T1]], %[[T4]] : (tensor, tensor) -> tensor // CHECK: %[[T6:.*]] = chlo.broadcast_subtract %[[T0]], %[[T5]] : (tensor, tensor) -> tensor // CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor -> !torch.vtensor<[?,?],f32> @@ -312,7 +312,7 @@ func.func @torch.aten.subtensor$alpha(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],si32> -> tensor // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si64> -> tensor // CHECK: %[[INT1:.*]] = torch.constant.int 1 -// CHECK: %[[T2:.*]] = stablehlo.convert %[[T0]] : (tensor) -> tensor +// CHECK: %[[T2:.*]] = mhlo.convert %[[T0]] : (tensor) -> tensor // CHECK: %[[T3:.*]] = chlo.broadcast_subtract %[[T2]], %[[T1]] : (tensor, tensor) -> tensor // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[?,?],si64> // CHECK: return %[[T4]] : !torch.vtensor<[?,?],si64> @@ -330,8 +330,8 @@ func.func @torch.aten.subtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1 // CHECK: %[[INT9:.*]] = torch.constant.int 9 // CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]] // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> -// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> -// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> +// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor // CHECK: %[[T4:.*]] = chlo.broadcast_multiply %[[T0]], %[[T3]] : (tensor, tensor) -> tensor // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32> @@ -363,8 +363,8 @@ func.func @torch.aten.multensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // CHECK: %[[INT9:.*]] = torch.constant.int 9 // CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]] // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> -// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> -// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> +// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor // CHECK: %[[T4:.*]] = chlo.broadcast_divide %[[T0]], %[[T3]] : (tensor, tensor) -> tensor // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32> @@ -396,8 +396,8 @@ func.func @torch.aten.divtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // CHECK: %[[INT3:.*]] = torch.constant.int 3 // CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT3]] // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> -// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> -// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> +// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor // CHECK: %[[T4:.*]] = chlo.broadcast_compare %[[T0]], %[[T3]] {compare_type = #chlo, comparison_direction = #chlo} : (tensor, tensor) -> tensor // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor -> !torch.vtensor<[?,?],i1> // CHECK: return %[[T5]] : !torch.vtensor<[?,?],i1> @@ -471,7 +471,7 @@ func.func @torch.aten.ne.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch. // CHECK: %[[INT0:.*]] = torch.constant.int 0 // CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT0]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[T2:.*]] = stablehlo.transpose %[[T0]], dims = [1, 0] : (tensor<4x64xf32>) -> tensor<64x4xf32> +// CHECK: %[[T2:.*]] = "mhlo.transpose"(%[[T0]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x64xf32>) -> tensor<64x4xf32> // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<64x4xf32> -> !torch.vtensor<[64,4],f32> // CHECK: return %[[T3]] : !torch.vtensor<[64,4],f32> func.func @torch.aten.permute$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[64,4],f32> { @@ -488,7 +488,7 @@ func.func @torch.aten.permute$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) {value = 0.000000e+00 : f32} : (tensor) -> tensor -// CHECK: %[[T2:.*]] = stablehlo.maximum %[[T0]], %[[T1]] : tensor +// CHECK: %[[T2:.*]] = mhlo.maximum %[[T0]], %[[T1]] : tensor // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T3]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.relu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -503,11 +503,11 @@ func.func @torch.aten.relu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[ // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[T1:.*]] = torch_c.to_f64 %[[ARG1]] // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xf64> -// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xf64>) -> tensor<1xf32> -// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xf64>) -> tensor<1xf32> +// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor // CHECK: %[[FROM_ELEMENTS_0:.*]] = tensor.from_elements %[[T1]] : tensor<1xf64> -// CHECK: %[[T4:.*]] = stablehlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xf64>) -> tensor<1xf32> -// CHECK: %[[T5:.*]] = stablehlo.reshape %[[T4]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T4:.*]] = mhlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xf64>) -> tensor<1xf32> +// CHECK: %[[T5:.*]] = mhlo.reshape %[[T4]] : (tensor<1xf32>) -> tensor // CHECK: %[[T6:.*]] = chlo.broadcast_multiply %[[T3]], %[[T5]] : (tensor, tensor) -> tensor // CHECK: %[[T7:.*]] = chlo.broadcast_add %[[T0]], %[[T6]] : (tensor, tensor) -> tensor // CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor -> !torch.vtensor<[?,?],f32> @@ -525,8 +525,8 @@ func.func @torch.aten.addscalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1 // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[T2:.*]] = torch_c.to_f64 %[[ARG2]] // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]] : tensor<1xf64> -// CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xf64>) -> tensor<1xf32> -// CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T3:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xf64>) -> tensor<1xf32> +// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor // CHECK: %[[T5:.*]] = chlo.broadcast_multiply %[[T1]], %[[T4]] : (tensor, tensor) -> tensor // CHECK: %[[T6:.*]] = chlo.broadcast_add %[[T0]], %[[T5]] : (tensor, tensor) -> tensor // CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor -> !torch.vtensor<[?,?],f32> @@ -543,8 +543,8 @@ func.func @torch.aten.addtensor$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[T1:.*]] = torch_c.to_i64 %[[ARG1]] // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> -// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> -// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> +// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor // CHECK: %[[T4:.*]] = chlo.broadcast_multiply %[[T0]], %[[T3]] : (tensor, tensor) -> tensor // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32> @@ -560,8 +560,8 @@ func.func @torch.aten.mulscalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[T1:.*]] = torch_c.to_i64 %[[ARG1]] // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> -// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> -// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> +// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor // CHECK: %[[T4:.*]] = chlo.broadcast_divide %[[T0]], %[[T3]] : (tensor, tensor) -> tensor // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32> @@ -577,8 +577,8 @@ func.func @torch.aten.divscalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[T1:.*]] = torch_c.to_i64 %[[ARG1]] // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> -// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> -// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> +// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor // CHECK: %[[T4:.*]] = chlo.broadcast_compare %[[T0]], %[[T3]] {compare_type = #chlo, comparison_direction = #chlo} : (tensor, tensor) -> tensor // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor -> !torch.vtensor<[?,?],i1> // CHECK: return %[[T5]] : !torch.vtensor<[?,?],i1> @@ -595,10 +595,10 @@ func.func @torch.aten.gt.scalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1 // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor // CHECK: %[[STR:.*]] = torch.constant.str "trunc" // CHECK: %[[T2:.*]] = chlo.broadcast_divide %[[T0]], %[[T1]] : (tensor, tensor) -> tensor -// CHECK: %[[T3:.*]] = stablehlo.sign %[[T2]] : tensor -// CHECK: %[[T4:.*]] = stablehlo.abs %[[T2]] : tensor -// CHECK: %[[T5:.*]] = stablehlo.floor %[[T4]] : tensor -// CHECK: %[[T6:.*]] = stablehlo.multiply %[[T3]], %[[T5]] : tensor +// CHECK: %[[T3:.*]] = mhlo.sign %[[T2]] : tensor +// CHECK: %[[T4:.*]] = mhlo.abs %[[T2]] : tensor +// CHECK: %[[T5:.*]] = mhlo.floor %[[T4]] : tensor +// CHECK: %[[T6:.*]] = mhlo.multiply %[[T3]], %[[T5]] : tensor // CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor -> !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[T7]] : !torch.vtensor<[?,?,?,?],f32> func.func @torch.aten.div.Tensor_mode$trunc(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { @@ -615,7 +615,7 @@ func.func @torch.aten.div.Tensor_mode$trunc(%arg0: !torch.vtensor<[?,?,?,?],f32> // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor // CHECK: %[[STR:.*]] = torch.constant.str "floor" // CHECK: %[[T2:.*]] = chlo.broadcast_divide %[[T0]], %[[T1]] : (tensor, tensor) -> tensor -// CHECK: %[[T3:.*]] = stablehlo.floor %[[T2]] : tensor +// CHECK: %[[T3:.*]] = mhlo.floor %[[T2]] : tensor // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[T4]] : !torch.vtensor<[?,?,?,?],f32> func.func @torch.aten.div.Tensor_mode$floor(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { diff --git a/test/Conversion/TorchToMhlo/gather.mlir b/test/Conversion/TorchToMhlo/gather.mlir index ea4ca9b82..a20b32d49 100644 --- a/test/Conversion/TorchToMhlo/gather.mlir +++ b/test/Conversion/TorchToMhlo/gather.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s +// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s // CHECK-LABEL: func.func @torch.aten.index_select$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,4],f32>, %[[ARG1:.*]]: !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,4],f32> { @@ -10,8 +10,8 @@ // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor // CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 // CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64> -// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather, indices_are_sorted = false} : (tensor, tensor<2xi64>, tensor<2xi64>) -> tensor<2x4xf32> -// CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor<2x4xf32> +// CHECK: %[[T5:.*]] = "mhlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #mhlo.gather, indices_are_sorted = false} : (tensor, tensor<2xi64>, tensor<2xi64>) -> tensor<2x4xf32> +// CHECK: %[[T6:.*]] = mhlo.convert %[[T5]] : tensor<2x4xf32> // CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32> // CHECK: return %[[T7]] : !torch.vtensor<[2,4],f32> func.func @torch.aten.index_select$basic(%arg0: !torch.vtensor<[?,4],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,4],f32> { @@ -31,8 +31,8 @@ func.func @torch.aten.index_select$basic(%arg0: !torch.vtensor<[?,4],f32>, %arg1 // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor // CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 // CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64> -// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather, indices_are_sorted = false} : (tensor, tensor, tensor<2xi64>) -> tensor -// CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor +// CHECK: %[[T5:.*]] = "mhlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #mhlo.gather, indices_are_sorted = false} : (tensor, tensor, tensor<2xi64>) -> tensor +// CHECK: %[[T6:.*]] = mhlo.convert %[[T5]] : tensor // CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T7]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.embedding$basic(%weight: !torch.vtensor<[?,?],f32>, %indices: !torch.vtensor<[?], si64>) -> !torch.vtensor<[?,?],f32> { @@ -53,8 +53,8 @@ func.func @torch.aten.embedding$basic(%weight: !torch.vtensor<[?,?],f32>, %indic // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor // CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 // CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64> -// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather, indices_are_sorted = false} : (tensor, tensor, tensor<2xi64>) -> tensor -// CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor +// CHECK: %[[T5:.*]] = "mhlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #mhlo.gather, indices_are_sorted = false} : (tensor, tensor, tensor<2xi64>) -> tensor +// CHECK: %[[T6:.*]] = mhlo.convert %[[T5]] : tensor // CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor -> !torch.vtensor<[?,1,?],f32> // CHECK: return %[[T7]] : !torch.vtensor<[?,1,?],f32> func.func @torch.aten.embedding$rank_two_indices(%weight: !torch.vtensor<[?,?],f32>, %indices: !torch.vtensor<[?,1], si64>) -> !torch.vtensor<[?,1,?],f32> { diff --git a/test/Conversion/TorchToMhlo/linear.mlir b/test/Conversion/TorchToMhlo/linear.mlir index 628969956..165c874ea 100644 --- a/test/Conversion/TorchToMhlo/linear.mlir +++ b/test/Conversion/TorchToMhlo/linear.mlir @@ -1,10 +1,10 @@ -// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s +// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s // CHECK-LABEL: func.func @torch.aten.mm$basic$static( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,3],f32>, %[[ARG1:.*]]: !torch.vtensor<[3,3],f32>) -> !torch.vtensor<[2,3],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32> // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[3,3],f32> -> tensor<3x3xf32> -// CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor<2x3xf32>, tensor<3x3xf32>) -> tensor<2x3xf32> +// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<2x3xf32>, tensor<3x3xf32>) -> tensor<2x3xf32> // CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<2x3xf32> to tensor<2x3xf32> // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32> // CHECK: return %[[T4]] : !torch.vtensor<[2,3],f32> @@ -19,7 +19,7 @@ func.func @torch.aten.mm$basic$static(%arg0: !torch.vtensor<[2,3],f32>, %arg1: ! // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3],f32>, %[[ARG1:.*]]: !torch.vtensor<[3,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,3],f32> -> tensor // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[3,?],f32> -> tensor<3x?xf32> -// CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor, tensor<3x?xf32>) -> tensor +// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor, tensor<3x?xf32>) -> tensor // CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor to tensor // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T4]] : !torch.vtensor<[?,?],f32> @@ -44,8 +44,8 @@ func.func @torch.aten.mm$basic$dynamic(%arg0: !torch.vtensor<[?,3],f32>, %arg1: // CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C2]] : tensor<10x4x5xf32> // CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> -// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [0, 1, 2] : (tensor<10x4x5xf32>, tensor<3xi64>) -> tensor<10x4x5xf32> -// CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #stablehlo.dot} : (tensor<10x3x4xf32>, tensor<10x4x5xf32>) -> tensor<10x3x5xf32> +// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<10x4x5xf32>, tensor<3xi64>) -> tensor<10x4x5xf32> +// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot} : (tensor<10x3x4xf32>, tensor<10x4x5xf32>) -> tensor<10x3x5xf32> // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<10x3x5xf32> to tensor<10x3x5xf32> // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<10x3x5xf32> -> !torch.vtensor<[10,3,5],f32> // CHECK: return %[[T12]] : !torch.vtensor<[10,3,5],f32> @@ -70,8 +70,8 @@ func.func @torch.aten.bmm$basic$static(%arg0: !torch.vtensor<[10,3,4],f32>, %arg // CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C2]] : tensor // CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> -// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [0, 1, 2] : (tensor, tensor<3xi64>) -> tensor -// CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #stablehlo.dot} : (tensor, tensor) -> tensor +// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor, tensor<3xi64>) -> tensor +// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot} : (tensor, tensor) -> tensor // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor to tensor // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor -> !torch.vtensor<[?,?,?],f32> // CHECK: return %[[T12]] : !torch.vtensor<[?,?,?],f32> @@ -96,8 +96,8 @@ func.func @torch.aten.bmm$basic$dynamic(%arg0: !torch.vtensor<[?,?,4],f32>, %arg // CHECK: %[[T6:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<256x120xf32> // CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> -// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T0]], %[[T8]], dims = [1, 2] : (tensor<256x120xf32>, tensor<3xi64>) -> tensor<4x256x120xf32> -// CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T9]], %[[T1]]) {dot_dimension_numbers = #stablehlo.dot} : (tensor<4x256x120xf32>, tensor<4x120x256xf32>) -> tensor<4x256x256xf32> +// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T0]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x120xf32>, tensor<3xi64>) -> tensor<4x256x120xf32> +// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T9]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot} : (tensor<4x256x120xf32>, tensor<4x120x256xf32>) -> tensor<4x256x256xf32> // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<4x256x256xf32> to tensor<4x256x256xf32> // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x256x256xf32> -> !torch.vtensor<[4,256,256],f32> // CHECK: return %[[T12]] : !torch.vtensor<[4,256,256],f32> @@ -122,8 +122,8 @@ func.func @torch.aten.matmul$basic$static(%arg0: !torch.vtensor<[256,120],f32>, // CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<256x?xf32> // CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> -// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [1, 2] : (tensor<256x?xf32>, tensor<3xi64>) -> tensor<4x256x?xf32> -// CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #stablehlo.dot} : (tensor<4x?x256xf32>, tensor<4x256x?xf32>) -> tensor<4x?x?xf32> +// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x?xf32>, tensor<3xi64>) -> tensor<4x256x?xf32> +// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot} : (tensor<4x?x256xf32>, tensor<4x256x?xf32>) -> tensor<4x?x?xf32> // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<4x?x?xf32> to tensor<4x?x?xf32> // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x?x?xf32> -> !torch.vtensor<[4,?,?],f32> // CHECK: return %[[T12]] : !torch.vtensor<[4,?,?],f32> @@ -145,8 +145,8 @@ func.func @torch.aten.matmul$basic$dynamic(%arg0: !torch.vtensor<[4,?,256],f32>, // CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor<256xf32> // CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 // CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64> -// CHECK: %[[T7:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T6]], dims = [1] : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> -// CHECK: %[[T8:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T7]]) {dot_dimension_numbers = #stablehlo.dot} : (tensor<1x?x256xf32>, tensor<1x256xf32>) -> tensor<1x?xf32> +// CHECK: %[[T7:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> +// CHECK: %[[T8:.*]] = "mhlo.dot_general"(%[[T0]], %[[T7]]) {dot_dimension_numbers = #mhlo.dot} : (tensor<1x?x256xf32>, tensor<1x256xf32>) -> tensor<1x?xf32> // CHECK: %[[T9:.*]] = tensor.cast %[[T8]] : tensor<1x?xf32> to tensor<1x?xf32> // CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<1x?xf32> -> !torch.vtensor<[1,?],f32> // CHECK: return %[[T10]] : !torch.vtensor<[1,?],f32> @@ -168,8 +168,8 @@ func.func @torch.aten.matmul$3dx1d(%arg0: !torch.vtensor<[1,?,256],f32>, %arg1: // CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0_0]] : tensor<256xf32> // CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 // CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64> -// CHECK: %[[T7:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T0]], %[[T6]], dims = [1] : (tensor<256xf32>, tensor<2xi64>) -> tensor -// CHECK: %[[T8:.*]] = "stablehlo.dot_general"(%[[T7]], %[[T1]]) {dot_dimension_numbers = #stablehlo.dot} : (tensor, tensor) -> tensor +// CHECK: %[[T7:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T0]], %[[T6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, tensor<2xi64>) -> tensor +// CHECK: %[[T8:.*]] = "mhlo.dot_general"(%[[T7]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot} : (tensor, tensor) -> tensor // CHECK: %[[T9:.*]] = tensor.cast %[[T8]] : tensor to tensor // CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T10]] : !torch.vtensor<[?,?],f32> @@ -184,7 +184,7 @@ func.func @torch.aten.matmul$1dx3d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256],f32>) -> !torch.vtensor<[?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,256],f32> -> tensor // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32> -// CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor, tensor<256xf32>) -> tensor +// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor, tensor<256xf32>) -> tensor // CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor to tensor // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[?],f32> // CHECK: return %[[T4]] : !torch.vtensor<[?],f32> @@ -199,7 +199,7 @@ func.func @torch.aten.matmul$2dx1d(%arg0: !torch.vtensor<[?,256],f32>, %arg1: !t // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32> // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256,?],f32> -> tensor<256x?xf32> -// CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor<256xf32>, tensor<256x?xf32>) -> tensor +// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<256xf32>, tensor<256x?xf32>) -> tensor // CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor to tensor // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[?],f32> // CHECK: return %[[T4]] : !torch.vtensor<[?],f32> @@ -214,7 +214,7 @@ func.func @torch.aten.matmul$1dx2d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256],f32>) -> !torch.vtensor<[],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32> // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32> -// CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor<256xf32>, tensor<256xf32>) -> tensor +// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<256xf32>, tensor<256xf32>) -> tensor // CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor to tensor // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[],f32> // CHECK: return %[[T4]] : !torch.vtensor<[],f32> @@ -228,7 +228,7 @@ func.func @torch.aten.matmul$1dx1d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor // CHECK-LABEL: func.func @torch.aten.matmul$proj( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,256],f32>) -> !torch.vtensor<[?,?,256],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,256],f32> -> tensor -// CHECK: %[[T1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<256x256xf32> +// CHECK: %[[T1:.*]] = mhlo.constant dense<1.000000e+00> : tensor<256x256xf32> // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor // CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 @@ -239,8 +239,8 @@ func.func @torch.aten.matmul$1dx1d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor // CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<256x256xf32> // CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> -// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [1, 2] : (tensor<256x256xf32>, tensor<3xi64>) -> tensor -// CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #stablehlo.dot} : (tensor, tensor) -> tensor +// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x256xf32>, tensor<3xi64>) -> tensor +// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot} : (tensor, tensor) -> tensor // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor to tensor // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor -> !torch.vtensor<[?,?,256],f32> // CHECK: return %[[T12]] : !torch.vtensor<[?,?,256],f32> @@ -255,8 +255,8 @@ func.func @torch.aten.matmul$proj(%arg0: !torch.vtensor<[?,?,256],f32>) -> !torc // CHECK-LABEL: func.func @torch.aten.mm$proj( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,256],f32>) -> !torch.vtensor<[?,256],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,256],f32> -> tensor -// CHECK: %[[T1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<256x256xf32> -// CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor, tensor<256x256xf32>) -> tensor +// CHECK: %[[T1:.*]] = mhlo.constant dense<1.000000e+00> : tensor<256x256xf32> +// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor, tensor<256x256xf32>) -> tensor // CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor to tensor // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[?,256],f32> // CHECK: return %[[T4]] : !torch.vtensor<[?,256],f32> @@ -284,7 +284,7 @@ func.func @torch.aten.mm$proj(%arg0: !torch.vtensor<[?,256],f32>) -> !torch.vten // CHECK: %[[T_11:.*]] = torch.prim.ListConstruct %[[T_7]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_12:.*]] = torch.prim.ListConstruct : () -> !torch.list // CHECK: %[[T_13:.*]] = torch.constant.bool false -// CHECK: %[[T_14:.*]] = stablehlo.convolution(%[[T_0]], %[[T_1]]) +// CHECK: %[[T_14:.*]] = mhlo.convolution(%[[T_0]], %[[T_1]]) // CHECK-SAME{LITERAL}: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 1], pad = [[4, 4], [2, 2]], rhs_dilate = [3, 1]} {batch_group_count = 1 : i64, feature_group_count = 3 : i64} : (tensor, tensor) -> tensor // CHECK: %[[T_15:.*]] = torch_c.from_builtin_tensor %[[T_14]] : tensor -> !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[T_15]] : !torch.vtensor<[?,?,?,?],f32> @@ -321,14 +321,14 @@ func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: ! // CHECK: %[[T_6:.*]] = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_7:.*]] = torch.prim.ListConstruct : () -> !torch.list // CHECK: %false = torch.constant.bool false -// CHECK: %[[T_8:.*]] = stablehlo.convolution(%[[T_0]], %[[T_1]]) +// CHECK: %[[T_8:.*]] = mhlo.convolution(%[[T_0]], %[[T_1]]) // CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 1], pad = [[4, 4], [2, 2]], rhs_dilate = [3, 1]} {batch_group_count = 1 : i64, feature_group_count = 3 : i64} : (tensor, tensor) -> tensor // CHECK: %[[IDX_0:.*]] = arith.constant 0 : index // CHECK: %[[T_9:.*]] = tensor.dim %[[T_2]], %[[IDX_0]] : tensor // CHECK: %[[T_10:.*]] = arith.index_cast %[[T_9]] : index to i64 // CHECK: %[[VAL_0:.*]] = arith.constant 1 : i64 // CHECK: %[[T_11:.*]] = tensor.from_elements %[[T_10]], %[[VAL_0]], %[[VAL_0]] : tensor<3xi64> -// CHECK: %[[T_12:.*]] = stablehlo.dynamic_reshape %[[T_2]], %[[T_11]] : (tensor, tensor<3xi64>) -> tensor +// CHECK: %[[T_12:.*]] = mhlo.dynamic_reshape %[[T_2]], %[[T_11]] : (tensor, tensor<3xi64>) -> tensor // CHECK: %[[T_13:.*]] = chlo.broadcast_add %[[T_8]], %[[T_12]] : (tensor, tensor) -> tensor // CHECK: %[[T_14:.*]] = torch_c.from_builtin_tensor %[[T_13]] : tensor -> !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[T_14]] : !torch.vtensor<[?,?,?,?],f32> @@ -360,8 +360,8 @@ func.func @torch.aten.convolution$bias(%arg0: !torch.vtensor<[?,?,?,?],f32>, %ar // CHECK: %[[T_2:.*]] = torch_c.to_i64 %int1 // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[T_5:.*]] = stablehlo.reverse %[[T_1]], dims = [2, 3] : tensor<2x4x3x3xf32> -// CHECK: %[[T_6:.*]] = stablehlo.convolution(%[[T_0]], %[[T_5]]) +// CHECK: %[[T_5:.*]] = "mhlo.reverse"(%[[T_1]]) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x4x3x3xf32>) -> tensor<2x4x3x3xf32> +// CHECK: %[[T_6:.*]] = mhlo.convolution(%[[T_0]], %[[T_5]]) // CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<2x4x3x3xf32>) -> tensor<1x4x9x9xf32> // CHECK: %[[T_7:.*]] = torch_c.from_builtin_tensor %[[T_6]] : tensor<1x4x9x9xf32> -> !torch.vtensor<[1,4,9,9],f32> // CHECK: return %[[T_7]] : !torch.vtensor<[1,4,9,9],f32> @@ -392,8 +392,8 @@ func.func @torch.aten.convolution$transposed_basic(%arg0: !torch.vtensor<[1,2,7, // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[T_6:.*]] = stablehlo.reverse %1, dims = [2, 3] : tensor<2x4x3x3xf32> -// CHECK: %[[T_7:.*]] = stablehlo.convolution(%[[T_0]], %[[T_6]]) +// CHECK: %[[T_6:.*]] = "mhlo.reverse"(%1) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x4x3x3xf32>) -> tensor<2x4x3x3xf32> +// CHECK: %[[T_7:.*]] = mhlo.convolution(%[[T_0]], %[[T_6]]) // CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<2x4x3x3xf32>) -> tensor<1x4x15x15xf32> // CHECK: %[[T_8:.*]] = torch_c.from_builtin_tensor %[[T_7]] : tensor<1x4x15x15xf32> -> !torch.vtensor<[1,4,15,15],f32> // CHECK: return %[[T_8]] : !torch.vtensor<[1,4,15,15],f32> @@ -426,11 +426,11 @@ func.func @torch.aten.convolution$transposed_stride(%arg0: !torch.vtensor<[1,2,7 // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[T_6:.*]] = stablehlo.reverse %[[T_1]], dims = [2, 3] : tensor<2x4x3x3xf32> -// CHECK: %[[T_7:.*]] = stablehlo.convolution(%[[T_0]], %[[T_6]]) +// CHECK: %[[T_6:.*]] = "mhlo.reverse"(%[[T_1]]) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x4x3x3xf32>) -> tensor<2x4x3x3xf32> +// CHECK: %[[T_7:.*]] = mhlo.convolution(%[[T_0]], %[[T_6]]) // CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<2x4x3x3xf32>) -> tensor<1x4x15x15xf32> -// CHECK: %[[T_8:.*]] = stablehlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[T_9:.*]] = stablehlo.pad %[[T_7]], %[[T_8]], low = [0, 0, 0, 0], high = [0, 0, 1, 1], interior = [0, 0, 0, 0] : (tensor<1x4x15x15xf32>, tensor) -> tensor<1x4x16x16xf32> +// CHECK: %[[T_8:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[T_9:.*]] = "mhlo.pad"(%[[T_7]], %[[T_8]]) {edge_padding_high = dense<[0, 0, 1, 1]> : vector<4xi64>, edge_padding_low = dense<0> : vector<4xi64>, interior_padding = dense<0> : vector<4xi64>} : (tensor<1x4x15x15xf32>, tensor) -> tensor<1x4x16x16xf32> // CHECK: %[[T_10:.*]] = torch_c.from_builtin_tensor %[[T_9:.*]] : tensor<1x4x16x16xf32> -> !torch.vtensor<[1,4,16,16],f32> // CHECK: return %[[T_10]] : !torch.vtensor<[1,4,16,16],f32> func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor<[1,2,7,7],f32>, %arg1: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,16,16],f32> { @@ -462,7 +462,7 @@ func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[T_6:.*]] = stablehlo.reverse %1, dims = [2, 3] : tensor<2x2x3x3xf32> +// CHECK: %[[T_6:.*]] = "mhlo.reverse"(%1) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x2x3x3xf32>) -> tensor<2x2x3x3xf32> // CHECK: %[[IDX_0:.*]] = arith.constant 0 : index // CHECK: %[[T_7:.*]] = tensor.dim %[[T_6]], %[[IDX_0]] : tensor<2x2x3x3xf32> // CHECK: %[[T_8:.*]] = arith.index_cast %[[T_7]] : index to i64 @@ -479,11 +479,11 @@ func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor // CHECK: %[[T_15:.*]] = arith.divsi %[[T_8]], %[[T_24]] : i64 // CHECK: %[[T_16:.*]] = arith.muli %[[T_10]], %[[T_24]] : i64 // CHECK: %[[T_17:.*]] = tensor.from_elements %[[T_24]], %[[T_15]], %[[T_10]], %[[T_12]], %[[T_14]] : tensor<5xi64> -// CHECK: %[[T_18:.*]] = stablehlo.dynamic_reshape %[[T_6]], %[[T_17]] : (tensor<2x2x3x3xf32>, tensor<5xi64>) -> tensor<2x1x2x3x3xf32> -// CHECK: %[[T_19:.*]] = stablehlo.transpose %[[T_18]], dims = [1, 0, 2, 3, 4] : (tensor<2x1x2x3x3xf32>) -> tensor<1x2x2x3x3xf32> +// CHECK: %[[T_18:.*]] = mhlo.dynamic_reshape %[[T_6]], %[[T_17]] : (tensor<2x2x3x3xf32>, tensor<5xi64>) -> tensor<2x1x2x3x3xf32> +// CHECK: %[[T_19:.*]] = "mhlo.transpose"(%[[T_18]]) {permutation = dense<[1, 0, 2, 3, 4]> : tensor<5xi64>} : (tensor<2x1x2x3x3xf32>) -> tensor<1x2x2x3x3xf32> // CHECK: %[[T_20:.*]] = tensor.from_elements %[[T_15]], %[[T_16]], %[[T_12]], %[[T_14]] : tensor<4xi64> -// CHECK: %[[T_21:.*]] = stablehlo.dynamic_reshape %[[T_19]], %[[T_20]] : (tensor<1x2x2x3x3xf32>, tensor<4xi64>) -> tensor<1x4x3x3xf32> -// CHECK: %[[T_22:.*]] = stablehlo.convolution(%[[T_0]], %[[T_21]]) +// CHECK: %[[T_21:.*]] = mhlo.dynamic_reshape %[[T_19]], %[[T_20]] : (tensor<1x2x2x3x3xf32>, tensor<4xi64>) -> tensor<1x4x3x3xf32> +// CHECK: %[[T_22:.*]] = mhlo.convolution(%[[T_0]], %[[T_21]]) // CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 2 : i64} : (tensor<1x2x7x7xf32>, tensor<1x4x3x3xf32>) -> tensor<1x4x15x15xf32> // CHECK: %[[T_23:.*]] = torch_c.from_builtin_tensor %[[T_22]] : tensor<1x4x15x15xf32> -> !torch.vtensor<[1,4,15,15],f32> // CHECK: return %[[T_23]] : !torch.vtensor<[1,4,15,15],f32> diff --git a/test/Conversion/TorchToMhlo/lit.local.cfg b/test/Conversion/TorchToMhlo/lit.local.cfg index d4f752cd7..829a5662f 100644 --- a/test/Conversion/TorchToMhlo/lit.local.cfg +++ b/test/Conversion/TorchToMhlo/lit.local.cfg @@ -1,2 +1,2 @@ -if not config.enable_stablehlo: +if not config.enable_mhlo: config.unsupported = True diff --git a/test/Conversion/TorchToMhlo/pooling.mlir b/test/Conversion/TorchToMhlo/pooling.mlir index 98805bdd8..684eb7828 100644 --- a/test/Conversion/TorchToMhlo/pooling.mlir +++ b/test/Conversion/TorchToMhlo/pooling.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s +// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s // ----- @@ -13,11 +13,11 @@ // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_6:.*]] = stablehlo.constant dense<-3.40282347E+38> : tensor -// CHECK: %[[VAL_7:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_6]]) ({ +// CHECK: %[[VAL_6:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor +// CHECK: %[[VAL_7:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_6]]) ({ // CHECK: ^bb0(%[[VAL_8:.*]]: tensor, %[[VAL_9:.*]]: tensor): -// CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor -// CHECK: stablehlo.return %[[VAL_10]] : tensor +// CHECK: %[[VAL_10:.*]] = mhlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor +// CHECK: mhlo.return %[[VAL_10]] : tensor // CHECK: }) {padding = dense<0> : tensor<4x2xi64>, window_dilations = dense<[1, 1, 2, 1]> : tensor<4xi64>, window_dimensions = dense<[1, 1, 2, 2]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor, tensor) -> tensor // CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[VAL_11]] : !torch.vtensor<[?,?,?,?],f32> @@ -45,11 +45,11 @@ func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch // CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<-3.40282347E+38> : tensor -// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({ +// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor +// CHECK: %[[VAL_6:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({ // CHECK: ^bb0(%[[VAL_8:.*]]: tensor, %[[VAL_9:.*]]: tensor): -// CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor -// CHECK: stablehlo.return %[[VAL_10]] : tensor +// CHECK: %[[VAL_10:.*]] = mhlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor +// CHECK: mhlo.return %[[VAL_10]] : tensor // CHECK: }) // CHECK-SAME{LITERAL}: {padding = dense<[[0, 0], [0, 0], [2, 2], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<[1, 1, 2, 1]> : tensor<4xi64>, window_dimensions = dense<[1, 1, 2, 2]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor, tensor) -> tensor // CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor -> !torch.vtensor<[?,?,?,?],f32> @@ -80,7 +80,7 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) - // CHECK: %[[T2:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T4:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[T5:.*]] = stablehlo.constant dense<-3.40282347E+38> : tensor +// CHECK: %[[T5:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor // CHECK: %[[T6:.*]] = arith.index_cast %[[DIM]] : index to i64 @@ -93,18 +93,18 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) - // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T6]], %[[T7]], %[[T8]] : tensor<3xi64> // CHECK: %[[T9:.*]] = arith.muli %[[T8]], %[[T7]] : i64 // CHECK: %[[FROM_ELEMENTS_2:.*]] = tensor.from_elements %[[T6]], %[[T9]] : tensor<2xi64> -// CHECK: %[[T10:.*]] = stablehlo.dynamic_iota %[[FROM_ELEMENTS_2]], dim = 1 : (tensor<2xi64>) -> tensor -// CHECK: %[[T11:.*]] = stablehlo.dynamic_reshape %[[T10]], %[[FROM_ELEMENTS]] : (tensor, tensor<3xi64>) -> tensor -// CHECK: %[[T12:.*]] = stablehlo.constant dense<0> : tensor -// CHECK: %[[T13:.*]]:2 = "stablehlo.reduce_window"(%[[T0]], %[[T11]], %[[T5]], %[[T12]]) ({ +// CHECK: %[[T10:.*]] = "mhlo.dynamic_iota"(%[[FROM_ELEMENTS_2]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor +// CHECK: %[[T11:.*]] = mhlo.dynamic_reshape %[[T10]], %[[FROM_ELEMENTS]] : (tensor, tensor<3xi64>) -> tensor +// CHECK: %[[T12:.*]] = mhlo.constant dense<0> : tensor +// CHECK: %[[T13:.*]]:2 = "mhlo.reduce_window"(%[[T0]], %[[T11]], %[[T5]], %[[T12]]) ({ // CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor, %[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): -// CHECK: %[[T16:.*]] = stablehlo.compare GE, %[[ARG1]], %[[ARG3]], FLOAT : (tensor, tensor) -> tensor -// CHECK: %[[T17:.*]] = stablehlo.select %[[T16]], %[[ARG1]], %[[ARG3]] : tensor, tensor -// CHECK: %[[T18:.*]] = stablehlo.compare EQ, %[[ARG1]], %[[ARG3]], FLOAT : (tensor, tensor) -> tensor -// CHECK: %[[T19:.*]] = stablehlo.minimum %[[ARG2]], %[[ARG4]] : tensor -// CHECK: %[[T20:.*]] = stablehlo.select %[[T16]], %[[ARG2]], %[[ARG4]] : tensor, tensor -// CHECK: %[[T21:.*]] = stablehlo.select %[[T18]], %[[T19]], %[[T20]] : tensor, tensor -// CHECK: stablehlo.return %[[T17]], %[[T21]] : tensor, tensor +// CHECK: %[[T16:.*]] = mhlo.compare GE, %[[ARG1]], %[[ARG3]], FLOAT : (tensor, tensor) -> tensor +// CHECK: %[[T17:.*]] = mhlo.select %[[T16]], %[[ARG1]], %[[ARG3]] : tensor, tensor +// CHECK: %[[T18:.*]] = mhlo.compare EQ, %[[ARG1]], %[[ARG3]], FLOAT : (tensor, tensor) -> tensor +// CHECK: %[[T19:.*]] = mhlo.minimum %[[ARG2]], %[[ARG4]] : tensor +// CHECK: %[[T20:.*]] = mhlo.select %[[T16]], %[[ARG2]], %[[ARG4]] : tensor, tensor +// CHECK: %[[T21:.*]] = mhlo.select %[[T18]], %[[T19]], %[[T20]] : tensor, tensor +// CHECK: mhlo.return %[[T17]], %[[T21]] : tensor, tensor // CHECK: }) {padding = dense<0> : tensor<3x2xi64>, window_dilations = dense<1> : tensor<3xi64>, window_dimensions = dense<[1, 3, 3]> : tensor<3xi64>, window_strides = dense<[1, 2, 2]> : tensor<3xi64>} : (tensor, tensor, tensor, tensor) -> (tensor, tensor) // CHECK: %[[T14:.*]] = torch_c.from_builtin_tensor %[[T13]]#0 : tensor -> !torch.vtensor<[?,?,?],f32> // CHECK: %[[T15:.*]] = torch_c.from_builtin_tensor %[[T13]]#1 : tensor -> !torch.vtensor<[?,?,?],si64> @@ -136,13 +136,13 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32> // CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({ +// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[VAL_6:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({ // CHECK: ^bb0(%[[IVAL_0:.*]]: tensor, %[[IVAL_1:.*]]: tensor): -// CHECK: %[[IVAL_2:.*]] = stablehlo.add %[[IVAL_0]], %[[IVAL_1]] : tensor -// CHECK: stablehlo.return %[[IVAL_2]] : tensor +// CHECK: %[[IVAL_2:.*]] = mhlo.add %[[IVAL_0]], %[[IVAL_1]] : tensor +// CHECK: mhlo.return %[[IVAL_2]] : tensor // CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor, tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor +// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor // CHECK: %[[IDX_0:.*]] = arith.constant 0 : index // CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor // CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : index to i64 @@ -156,14 +156,14 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32> // CHECK: %[[VAL_14:.*]] = tensor.dim %[[VAL_1]], %[[IDX_3]] : tensor // CHECK: %[[VAL_15:.*]] = arith.index_cast %[[VAL_14]] : index to i64 // CHECK: %[[VAL_16:.*]] = tensor.from_elements %[[VAL_9]], %[[VAL_11]], %[[VAL_13]], %[[VAL_15]] : tensor<4xi64> -// CHECK: %[[VAL_17:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_16]], dims = [] : (tensor, tensor<4xi64>) -> tensor -// CHECK: %[[VAL_18:.*]] = stablehlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[VAL_19:.*]] = "stablehlo.reduce_window"(%[[VAL_17]], %[[VAL_18]]) ({ +// CHECK: %[[VAL_17:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_7]], %[[VAL_16]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<4xi64>) -> tensor +// CHECK: %[[VAL_18:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[VAL_19:.*]] = "mhlo.reduce_window"(%[[VAL_17]], %[[VAL_18]]) ({ // CHECK: ^bb0(%[[IVAL_3:.*]]: tensor, %[[IVAL_4:.*]]: tensor): -// CHECK: %[[IVAL_5:.*]] = stablehlo.add %[[IVAL_3]], %[[IVAL_4]] : tensor -// CHECK: stablehlo.return %[[IVAL_5]] : tensor +// CHECK: %[[IVAL_5:.*]] = mhlo.add %[[IVAL_3]], %[[IVAL_4]] : tensor +// CHECK: mhlo.return %[[IVAL_5]] : tensor // CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor, tensor) -> tensor -// CHECK: %[[VAL_20:.*]] = stablehlo.divide %[[VAL_6]], %[[VAL_19]] : tensor +// CHECK: %[[VAL_20:.*]] = mhlo.divide %[[VAL_6]], %[[VAL_19]] : tensor // CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor -> !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[VAL_21]] : !torch.vtensor<[?,?,?,?],f32> func.func @torch.aten.avg_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { @@ -193,14 +193,14 @@ func.func @torch.aten.avg_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch // CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T2:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[T4:.*]] = stablehlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[T5:.*]] = "stablehlo.reduce_window"(%[[T0]], %[[T4]]) ({ +// CHECK: %[[T4:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[T5:.*]] = "mhlo.reduce_window"(%[[T0]], %[[T4]]) ({ // CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): -// CHECK: %[[T10:.*]] = stablehlo.add %[[ARG1]], %[[ARG2]] : tensor -// CHECK: stablehlo.return %[[T10]] : tensor +// CHECK: %[[T10:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor +// CHECK: mhlo.return %[[T10]] : tensor // CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor, tensor) -> tensor -// CHECK: %[[T6:.*]] = stablehlo.constant dense<9> : tensor -// CHECK: %[[T7:.*]] = stablehlo.convert %[[T6]] : (tensor) -> tensor +// CHECK: %[[T6:.*]] = mhlo.constant dense<9> : tensor +// CHECK: %[[T7:.*]] = mhlo.convert %[[T6]] : (tensor) -> tensor // CHECK: %[[T8:.*]] = chlo.broadcast_divide %[[T5]], %[[T7]] : (tensor, tensor) -> tensor // CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor -> !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[T9]] : !torch.vtensor<[?,?,?,?],f32> diff --git a/test/Conversion/TorchToMhlo/view_like.mlir b/test/Conversion/TorchToMhlo/view_like.mlir index 8a6ec8d72..70f3570d8 100644 --- a/test/Conversion/TorchToMhlo/view_like.mlir +++ b/test/Conversion/TorchToMhlo/view_like.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s +// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s // CHECK-LABEL: func.func @torch.aten.slice.strided$slice_like( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { @@ -42,7 +42,7 @@ // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T10]], %[[C0_I64_5]], %[[C0_I64_5]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T21]], %[[T18]], %[[T19]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor +// CHECK: %[[T22:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor // CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor -> !torch.vtensor<[?,?,?],f32> // CHECK: return %[[T23]] : !torch.vtensor<[?,?,?],f32> func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { @@ -97,7 +97,7 @@ func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T10]], %[[C0_I64_5]], %[[C0_I64_5]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T21]], %[[T18]], %[[T19]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<2x65x256xf32> +// CHECK: %[[T22:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<2x65x256xf32> // CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<2x65x256xf32> -> !torch.vtensor<[2,65,256],f32> // CHECK: return %[[T23]] : !torch.vtensor<[2,65,256],f32> func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[2,65,256],f32> { @@ -152,7 +152,7 @@ func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,6 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_5]], %[[T10]], %[[C0_I64_5]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T17]], %[[T21]], %[[T19]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor +// CHECK: %[[T22:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor // CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor -> !torch.vtensor<[?,1,?],f32> // CHECK: return %[[T23]] : !torch.vtensor<[?,1,?],f32> func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,1,?],f32> { @@ -207,7 +207,7 @@ func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_5]], %[[T10]], %[[C0_I64_5]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T17]], %[[T21]], %[[T19]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x1x256xf32> +// CHECK: %[[T22:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x1x256xf32> // CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<4x1x256xf32> -> !torch.vtensor<[4,1,256],f32> // CHECK: return %[[T23]] : !torch.vtensor<[4,1,256],f32> func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,1,256],f32> { @@ -247,7 +247,7 @@ func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,2 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_4]], %[[C0_I64]], %[[C0_I64_4]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_5:.*]] = tensor.from_elements %[[T3]], %[[T7]], %[[T5]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T8:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS]]_6 : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor +// CHECK: %[[T8:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS]]_6 : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor // CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor -> !torch.vtensor<[?,?,?],f32> // CHECK: return %[[T9]] : !torch.vtensor<[?,?,?],f32> func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { @@ -287,7 +287,7 @@ func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_4]], %[[C0_I64]], %[[C0_I64_4]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_5:.*]] = tensor.from_elements %[[T3]], %[[T7]], %[[T5]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T8:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS]]_6 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x33x256xf32> +// CHECK: %[[T8:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS]]_6 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x33x256xf32> // CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<4x33x256xf32> -> !torch.vtensor<[4,33,256],f32> // CHECK: return %[[T9]] : !torch.vtensor<[4,33,256],f32> func.func @torch.aten.slice.none.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,33,256],f32> { @@ -313,8 +313,8 @@ func.func @torch.aten.slice.none.static$slice_like(%arg0: !torch.vtensor<[4,65,2 // CHECK: %[[T5:.*]] = arith.muli %[[T4]], %[[T3]] : i64 // CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : i64 to index // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]], %[[T3]] : tensor<2xi64> -// CHECK: %[[T7:.*]] = stablehlo.compute_reshape_shape %[[T6]], %[[FROM_ELEMENTS]] : (index, tensor<2xi64>) -> tensor<2xi64> -// CHECK: %[[T8:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[T7]] : (tensor, tensor<2xi64>) -> tensor +// CHECK: %[[T7:.*]] = mhlo.compute_reshape_shape %[[T6]], %[[FROM_ELEMENTS]] : (index, tensor<2xi64>) -> tensor<2xi64> +// CHECK: %[[T8:.*]] = mhlo.dynamic_reshape %[[T0]], %[[T7]] : (tensor, tensor<2xi64>) -> tensor // CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor -> !torch.vtensor<[?,224],f32> // CHECK: return %[[T9]] : !torch.vtensor<[?,224],f32> func.func @torch.aten.view$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,224],f32> { @@ -346,8 +346,8 @@ func.func @torch.aten.view$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch // CHECK: %[[T9:.*]] = arith.muli %[[T8]], %[[T5]] : i64 // CHECK: %[[T10:.*]] = arith.index_cast %[[T9]] : i64 to index // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]], %[[T3]], %[[T4]], %[[T5]] : tensor<4xi64> -// CHECK: %[[T11:.*]] = stablehlo.compute_reshape_shape %[[T10]], %[[FROM_ELEMENTS]] : (index, tensor<4xi64>) -> tensor<4xi64> -// CHECK: %[[T12:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[T11]] : (tensor, tensor<4xi64>) -> tensor +// CHECK: %[[T11:.*]] = mhlo.compute_reshape_shape %[[T10]], %[[FROM_ELEMENTS]] : (index, tensor<4xi64>) -> tensor<4xi64> +// CHECK: %[[T12:.*]] = mhlo.dynamic_reshape %[[T0]], %[[T11]] : (tensor, tensor<4xi64>) -> tensor // CHECK: %[[T13:.*]] = torch_c.from_builtin_tensor %[[T12]] : tensor -> !torch.vtensor<[?,120,4,64],f32> // CHECK: return %[[T13]] : !torch.vtensor<[?,120,4,64],f32> func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.vtensor<[?,120,4,64],f32> { @@ -367,7 +367,7 @@ func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> ! // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[],f32> -> tensor // CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list -// CHECK: %[[T2:.*]] = stablehlo.reshape %[[T0]] : (tensor) -> tensor<1xf32> +// CHECK: %[[T2:.*]] = mhlo.reshape %[[T0]] : (tensor) -> tensor<1xf32> // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<1xf32> -> !torch.vtensor<[1],f32> // CHECK: return %[[T3]] : !torch.vtensor<[1],f32> func.func @torch.aten.view$to_rank1(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { @@ -383,7 +383,7 @@ func.func @torch.aten.view$to_rank1(%arg0: !torch.vtensor<[],f32>) -> !torch.vte // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1],f32>) -> !torch.vtensor<[],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1],f32> -> tensor<1xf32> // CHECK: %[[T1:.*]] = torch.prim.ListConstruct : () -> !torch.list -// CHECK: %[[T2:.*]] = stablehlo.reshape %[[T0]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T2:.*]] = mhlo.reshape %[[T0]] : (tensor<1xf32>) -> tensor // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor -> !torch.vtensor<[],f32> // CHECK: return %[[T3]] : !torch.vtensor<[],f32> func.func @torch.aten.view$to_rank0(%arg0: !torch.vtensor<[1],f32>) -> !torch.vtensor<[],f32> { @@ -425,7 +425,7 @@ func.func @torch.aten.squeeze.dim$0$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32 // CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<4xi64> -// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<4xi64>) -> tensor +// CHECK: %[[T5:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<4xi64>) -> tensor // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor -> !torch.vtensor<[?,?,1,?],f32> // CHECK: return %[[T6]] : !torch.vtensor<[?,?,1,?],f32> func.func @torch.aten.squeeze.dim$1(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,?,1,?],f32> { @@ -453,7 +453,7 @@ func.func @torch.aten.squeeze.dim$1(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> ! // CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<4xi64> -// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<4xi64>) -> tensor +// CHECK: %[[T5:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<4xi64>) -> tensor // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor -> !torch.vtensor<[?,1,?,?],f32> // CHECK: return %[[T6]] : !torch.vtensor<[?,1,?,?],f32> func.func @torch.aten.squeeze.dim$from_end(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,1,?,?],f32> { @@ -477,7 +477,7 @@ func.func @torch.aten.squeeze.dim$from_end(%arg0: !torch.vtensor<[?,1,?,1,?],f32 // CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<2x1x2x1x2xf32> // CHECK: %[[T3:.*]] = arith.index_cast %[[DIM_1]] : index to i64 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]] : tensor<3xi64> -// CHECK: %[[T4:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<2x1x2x1x2xf32>, tensor<3xi64>) -> tensor<2x2x2xf32> +// CHECK: %[[T4:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<2x1x2x1x2xf32>, tensor<3xi64>) -> tensor<2x2x2xf32> // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<2x2x2xf32> -> !torch.vtensor<[2,2,2],f32> // CHECK: return %[[T5]] : !torch.vtensor<[2,2,2],f32> func.func @torch.aten.squeeze$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,2,2],f32> { @@ -505,7 +505,7 @@ func.func @torch.aten.squeeze$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64 // CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<5xi64> -// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<5xi64>) -> tensor<1x?x?x?x?xf32> +// CHECK: %[[T5:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<5xi64>) -> tensor<1x?x?x?x?xf32> // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<1x?x?x?x?xf32> -> !torch.vtensor<[1,?,?,?,?],f32> // CHECK: return %[[T6]] : !torch.vtensor<[1,?,?,?,?],f32> func.func @torch.aten.unsqueeze$dim$0(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1,?,?,?,?],f32> { @@ -534,7 +534,7 @@ func.func @torch.aten.unsqueeze$dim$0(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> ! // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64 // CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[C1_I64]], %[[T2]], %[[T3]], %[[T4]] : tensor<5xi64> -// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<5xi64>) -> tensor +// CHECK: %[[T5:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<5xi64>) -> tensor // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor -> !torch.vtensor<[?,1,?,?,?],f32> // CHECK: return %[[T6]] : !torch.vtensor<[?,1,?,?,?],f32> func.func @torch.aten.unsqueeze$dim$1(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,1,?,?,?],f32> { @@ -563,7 +563,7 @@ func.func @torch.aten.unsqueeze$dim$1(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> ! // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64 // CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]], %[[C1_I64]], %[[T4]] : tensor<5xi64> -// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<5xi64>) -> tensor +// CHECK: %[[T5:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<5xi64>) -> tensor // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor -> !torch.vtensor<[?,?,?,1,?],f32> // CHECK: return %[[T6]] : !torch.vtensor<[?,?,?,1,?],f32> func.func @torch.aten.unsqueeze$from_end(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,1,?],f32> { diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 73b2cbeab..dc037353e 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -939,25 +939,6 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[3,5],si64>) -> !torch.vten return %0 : !torch.vtensor<[3,5],i1> } -// ----- -// CHECK-LABEL: func.func @torch.aten.to.dtype( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,128],i1>) -> !torch.vtensor<[1,128],si64> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,128],i1> -> tensor<1x128xi1> -// CHECK: %[[VAL_2:.*]] = torch.constant.int 4 -// CHECK: %[[VAL_3:.*]] = torch.constant.none -// CHECK: %[[VAL_4:.*]] = torch.constant.bool false -// CHECK: %[[VAL_5:.*]] = "tosa.cast"(%[[VAL_1]]) : (tensor<1x128xi1>) -> tensor<1x128xi64> -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1x128xi64> -> !torch.vtensor<[1,128],si64> -// CHECK: return %[[VAL_6]] : !torch.vtensor<[1,128],si64> -// CHECK: } -func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[1,128],i1>) -> !torch.vtensor<[1,128],si64> { - %int4 = torch.constant.int 4 - %none = torch.constant.none - %false = torch.constant.bool false - %0 = torch.aten.to.dtype %arg0, %int4, %false, %false, %none : !torch.vtensor<[1,128],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,128],si64> - return %0 : !torch.vtensor<[1,128],si64> -} - // ----- // CHECK-LABEL: func.func @torch.aten.gather( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,4,3],f32>, diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 7c6dac439..cb5e3ead1 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -502,24 +502,6 @@ func.func @torch.prim.max.int$constant() -> !torch.int { return %0 : !torch.int } -// CHECK-LABEL: func.func @torch.prim.min.int$identity( -// CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.int { -// CHECK: return %[[ARG]] : !torch.int -func.func @torch.prim.min.int$identity(%arg0: !torch.int) -> !torch.int { - %0 = torch.prim.min.int %arg0, %arg0 : !torch.int, !torch.int -> !torch.int - return %0 : !torch.int -} - -// CHECK-LABEL: func.func @torch.prim.min.int$constant() -> !torch.int { -// CHECK: %[[INT1:.*]] = torch.constant.int -1 -// CHECK: return %[[INT1]] : !torch.int -func.func @torch.prim.min.int$constant() -> !torch.int { - %int-1 = torch.constant.int -1 - %int3 = torch.constant.int 3 - %0 = torch.prim.min.int %int-1, %int3 : !torch.int, !torch.int -> !torch.int - return %0 : !torch.int -} - // CHECK-LABEL: func.func @torch.prim.min.self_int$basic() -> !torch.int { // CHECK: %[[M1:.*]] = torch.constant.int -1 // CHECK: return %[[M1]] : !torch.int @@ -1062,16 +1044,6 @@ func.func @torch.aten.remainder.int() -> !torch.int { return %ret : !torch.int } -// CHECK-LABEL: func.func @torch.aten.pow.int_float() -> !torch.float { -// CHECK: %[[FLOAT_8:.*]] = torch.constant.float 8.000000e+00 -// CHECK: return %[[FLOAT_8]] : !torch.float -func.func @torch.aten.pow.int_float() -> !torch.float { - %cst2 = torch.constant.int 2 - %float3.0 = torch.constant.float 3.0 - %ret = torch.aten.pow.int_float %cst2, %float3.0: !torch.int, !torch.float -> !torch.float - return %ret : !torch.float -} - // CHECK-LABEL: func.func @torch.prim.dtype$bfloat16( // CHECK-SAME: %[[T:.*]]: !torch.tensor<*,bf16>) -> !torch.int { // CHECK: %[[CST:.*]] = torch.constant.int 15 @@ -1201,26 +1173,6 @@ func.func @torch.tensor_static_info_cast$refine(%arg0: !torch.vtensor<[], f32>) return %1 : !torch.vtensor } -// CHECK-LABEL: func.func @torch.tensor_static_info_cast$refine$dtype( -// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor { -// CHECK-NEXT: %[[RESULT:.*]] = torch.aten.relu %[[ARG]] : !torch.vtensor<[],f32> -> !torch.vtensor -// CHECK-NEXT: return %[[RESULT]] : !torch.vtensor -func.func @torch.tensor_static_info_cast$refine$dtype(%arg0: !torch.vtensor<[], f32>) -> !torch.vtensor { - %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[],f32> to !torch.vtensor<[],unk> - %1 = torch.aten.relu %0 : !torch.vtensor<[],unk> -> !torch.vtensor - return %1 : !torch.vtensor -} - -// CHECK-LABEL: func.func @torch.tensor_static_info_cast$refine$shape( -// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor { -// CHECK-NEXT: %[[RESULT:.*]] = torch.aten.relu %[[ARG]] : !torch.vtensor<[],f32> -> !torch.vtensor -// CHECK-NEXT: return %[[RESULT]] : !torch.vtensor -func.func @torch.tensor_static_info_cast$refine$shape(%arg0: !torch.vtensor<[], f32>) -> !torch.vtensor { - %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[],f32> to !torch.vtensor<*,f32> - %1 = torch.aten.relu %0 : !torch.vtensor<*,f32> -> !torch.vtensor - return %1 : !torch.vtensor -} - // CHECK-LABEL: func.func @torch.tensor_static_info_cast$no_refine( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor) -> !torch.vtensor { // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor to !torch.vtensor<[],f32> @@ -1232,28 +1184,6 @@ func.func @torch.tensor_static_info_cast$no_refine(%arg0: !torch.vtensor) -> !to return %1 : !torch.vtensor } -// CHECK-LABEL: func.func @torch.tensor_static_info_cast$no_refine$dtype( -// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],unk>) -> !torch.vtensor { -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[],unk> to !torch.vtensor<[],f32> -// CHECK: %[[RESULT:.*]] = torch.aten.relu %[[CAST]] : !torch.vtensor<[],f32> -> !torch.vtensor -// CHECK: return %[[RESULT]] : !torch.vtensor -func.func @torch.tensor_static_info_cast$no_refine$dtype(%arg0: !torch.vtensor<[],unk>) -> !torch.vtensor { - %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[],unk> to !torch.vtensor<[],f32> - %1 = torch.aten.relu %0 : !torch.vtensor<[],f32> -> !torch.vtensor - return %1 : !torch.vtensor -} - -// CHECK-LABEL: func.func @torch.tensor_static_info_cast$no_refine$shape( -// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor { -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<*,f32> to !torch.vtensor<[],f32> -// CHECK: %[[RESULT:.*]] = torch.aten.relu %[[CAST]] : !torch.vtensor<[],f32> -> !torch.vtensor -// CHECK: return %[[RESULT]] : !torch.vtensor -func.func @torch.tensor_static_info_cast$no_refine$shape(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor { - %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor<[],f32> - %1 = torch.aten.relu %0 : !torch.vtensor<[],f32> -> !torch.vtensor - return %1 : !torch.vtensor -} - // CHECK-LABEL: func.func @torch.tensor_static_info_cast$refine_allowed_ops( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.tuple { // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[],f32> to !torch.vtensor @@ -1323,15 +1253,6 @@ func.func @torch.aten.Int.Tensor(%arg0: !torch.int) -> !torch.int { return %scalar : !torch.int } -// CHECK-LABEL: func.func @torch.aten.Int.float() -> !torch.int { -// CHECK: %[[NUM:.*]] = torch.constant.int 1 -// CHECK: return %[[NUM]] : !torch.int -func.func @torch.aten.Int.float() -> !torch.int { - %float1 = torch.constant.float 1.0 - %int1 = torch.aten.Int.float %float1 : !torch.float -> !torch.int - return %int1 : !torch.int -} - // CHECK-LABEL: func.func @torch.aten.Float.Tensor( // CHECK-SAME: %[[NUM:.*]]: !torch.float) -> !torch.float { // CHECK: %[[T:.*]] = torch.prim.NumToTensor.Scalar %[[NUM]] : !torch.float -> !torch.vtensor<[],f64> @@ -1728,16 +1649,6 @@ func.func @torch.aten.sub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[], return %2 : !torch.vtensor<[],si64> } -// CHECK-LABEL: func.func @torch.aten.sub.float$fold() -> !torch.float { -// CHECK: %[[FLOAT_1:.*]] = torch.constant.float -1.000000e+00 -// CHECK: return %[[FLOAT_1]] : !torch.float -func.func @torch.aten.sub.float$fold() -> !torch.float { - %float1 = torch.constant.float 1.0 - %float2 = torch.constant.float 2.0 - %0 = torch.aten.sub.float %float1, %float2 : !torch.float, !torch.float -> !torch.float - return %0 : !torch.float -} - // CHECK-LABEL: func.func @torch.aten.mul.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { // CHECK: %[[INT6]] = torch.constant.int 6 // CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> @@ -1880,52 +1791,3 @@ func.func @torch.aten.slice.tensor$fold_full_domain_slice(%arg0: !torch.vtensor< %0 = torch.aten.slice.Tensor %arg0, %int0, %int0, %int-1, %int1 : !torch.vtensor<[4], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4], f32> return %0 : !torch.vtensor<[4],f32> } - -// CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { -// CHECK: %int-1 = torch.constant.int -1 -// CHECK: %[[VAL_0:.*]] = torch.prim.NumToTensor.Scalar %int-1 : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[VAL_0]] : !torch.vtensor<[],si64> -func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { - %int2 = torch.constant.int 2 - %int3 = torch.constant.int 3 - %0 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> - %2 = torch.aten.rsub.Scalar %0, %int2, %int3 : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> - return %2 : !torch.vtensor<[],si64> -} - -// CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { -// CHECK: %int-1 = torch.constant.int -1 -// CHECK: %int1 = torch.constant.int 1 -// CHECK: %[[VAL_0:.*]] = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[VAL_1:.*]] = torch.prim.NumToTensor.Scalar %int-1 : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[VAL_1]] : !torch.vtensor<[],si64> -func.func @torch.aten.rsub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { - %int1 = torch.constant.int 1 - %int2 = torch.constant.int 2 - %int3 = torch.constant.int 3 - %0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64> - %2 = torch.aten.rsub.Scalar %0, %int2, %int3 : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> - return %2 : !torch.vtensor<[],si64> -} - -// CHECK-LABEL: func.func @torch.aten.ScalarImplicit$canonicalize_numtotensor_0d() -> !torch.number { -// CHECK: %int1 = torch.constant.int 1 -// CHECK: %[[VAL_0:.*]] = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[VAL_1:.*]] = torch.derefine %int1 : !torch.int to !torch.number -// CHECK: return %[[VAL_1]] : !torch.number -func.func @torch.aten.ScalarImplicit$canonicalize_numtotensor_0d() -> !torch.number { - %int1 = torch.constant.int 1 - %0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64> - %1 = torch.aten.ScalarImplicit %0 : !torch.vtensor<[],si64> -> !torch.number - return %1 : !torch.number -} - -// CHECK-LABEL: func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d() -> !torch.number { -// CHECK: %int1 = torch.constant.int 1 -// CHECK: %[[VAL_0:.*]] = torch.derefine %int1 : !torch.int to !torch.number -// CHECK: return %[[VAL_0]] : !torch.number -func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d() -> !torch.number { - %0 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> - %1 = torch.aten.ScalarImplicit %0 : !torch.vtensor<[],si64> -> !torch.number - return %1 : !torch.number -} diff --git a/test/Dialect/Torch/invalid.mlir b/test/Dialect/Torch/invalid.mlir index 254a348cd..84e7d63da 100644 --- a/test/Dialect/Torch/invalid.mlir +++ b/test/Dialect/Torch/invalid.mlir @@ -101,17 +101,17 @@ torch.class_type @c { // ----- // expected-error @+1 {{'torch.type_bound' must be attached to an argument of !torch.tensor/!torch.vtensor type}} -func.func private @f(%arg0: i32 {torch.type_bound = !torch.tensor<*,f32>}) +func.func @f(%arg0: i32 {torch.type_bound = !torch.tensor<*,f32>}) // ----- // expected-error @+1 {{'torch.type_bound' must be TypeAttr}} -func.func private @f(%arg0: i32 {torch.type_bound = 1}) +func.func @f(%arg0: i32 {torch.type_bound = 1}) // ----- // expected-error @+1 {{'torch.type_bound' must be of !torch.tensor/!torch.vtensor type}} -func.func private @f(%arg0: i32 {torch.type_bound = i32}) +func.func @f(%arg0: i32 {torch.type_bound = i32}) // ----- @@ -265,19 +265,3 @@ torch.global_slot.module_initializer { @tensor(%1 : !torch.tensor) ] } - -// ----- - -func.func @torch.tensor_static_info_cast$shape_mismatch(%arg0: !torch.vtensor<[],unk>) -> !torch.vtensor<[?],unk> { - // expected-error@+1 {{'torch.tensor_static_info_cast' op operand type '!torch.vtensor<[],unk>' and result type '!torch.vtensor<[?],unk>' are cast incompatible}} - %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[],unk> to !torch.vtensor<[?],unk> - return %0 : !torch.vtensor<[?],unk> -} - -// ----- - -func.func @torch.tensor_static_info_cast$dtype_mismatch(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor<*,f64> { - // expected-error@+1 {{'torch.tensor_static_info_cast' op operand type '!torch.vtensor<*,f32>' and result type '!torch.vtensor<*,f64>' are cast incompatible}} - %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor<*,f64> - return %0 : !torch.vtensor<*,f64> -} diff --git a/test/Dialect/Torch/refine-types-ops.mlir b/test/Dialect/Torch/refine-types-ops.mlir index 6fc29daab..e058e0d67 100644 --- a/test/Dialect/Torch/refine-types-ops.mlir +++ b/test/Dialect/Torch/refine-types-ops.mlir @@ -137,22 +137,6 @@ func.func @torch.aten.cat(%t0: !torch.tensor<[?,1,4], f32>, %t1: !torch.tensor<[ return %ret : !torch.tensor } -// ----- -// CHECK-LABEL: func.func @torch.aten.cat$promote_type( -// CHECK-SAME: %[[T1:.*]]: !torch.tensor<[2,1,4],i1>, -// CHECK-SAME: %[[T2:.*]]: !torch.tensor<[2,3,4],si64>) -> !torch.tensor { -// CHECK: %[[INT1:.*]] = torch.constant.int 1 -// CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]] : (!torch.tensor<[2,1,4],i1>, !torch.tensor<[2,3,4],si64>) -> !torch.list -// CHECK: %[[RET:.*]] = torch.aten.cat %[[TENSORS]], %[[INT1]] : !torch.list, !torch.int -> !torch.tensor<*,si64> -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor -// CHECK: return %[[CAST]] : !torch.tensor -func.func @torch.aten.cat$promote_type(%t0: !torch.tensor<[2,1,4], i1>, %t1: !torch.tensor<[2,3,4], si64>) -> !torch.tensor { - %int1 = torch.constant.int 1 - %tensorList = torch.prim.ListConstruct %t0, %t1: (!torch.tensor<[2,1,4], i1>, !torch.tensor<[2,3,4], si64>) -> !torch.list - %ret = torch.aten.cat %tensorList, %int1 : !torch.list, !torch.int -> !torch.tensor - return %ret : !torch.tensor -} - // ----- // CHECK-LABEL: func.func @torch.aten._shape_as_tensor( // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?,1,4],f32>) -> !torch.tensor { diff --git a/test/Dialect/Torch/refine-types.mlir b/test/Dialect/Torch/refine-types.mlir index ae02a0339..391fd6d2a 100644 --- a/test/Dialect/Torch/refine-types.mlir +++ b/test/Dialect/Torch/refine-types.mlir @@ -161,27 +161,3 @@ func.func @torch.aten.zeros_like(%arg: !torch.vtensor) { %2 = torch.aten.zeros_like %arg, %int6, %int0, %cpu, %false, %int1 : !torch.vtensor, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.int -> !torch.vtensor return } - -// ----- - -// The data-flow analysis does not always propagate information to the entire graph. -// This results in some lattice elements being uninitialized, which must be properly -// handled when using the lattice elements to rewrite the graph. -// In this particular case, the presence of the loop causes `torch.copy.to_vtensor` -// to end up with an uninitialized lattice element. This is the simplest graph I was -// able to come up with that reproduces such behavior. - -// CHECK-LABEL: func.func @uninitialized_lattice_elements( -// CHECK: %{{.*}} = torch.copy.to_vtensor %{{.*}} : !torch.vtensor<*,f32> - -func.func @uninitialized_lattice_elements(%arg0: !torch.vtensor<*,f32>, %arg3: !torch.tensor) -> !torch.vtensor<*,f32> { - %true = torch.constant.bool true - %1 = torch.constant.int 0 - %2 = torch.prim.Loop %1, %true, init(%arg3) { - ^bb0(%arg1: !torch.int, %arg2: !torch.tensor): - torch.prim.Loop.condition %true, iter(%arg2 : !torch.tensor) - } : (!torch.int, !torch.bool, !torch.tensor) -> !torch.tensor - %3 = torch.tensor_static_info_cast %2 : !torch.tensor to !torch.tensor<*,f32> - %4 = torch.copy.to_vtensor %3 : !torch.vtensor<*,f32> - return %4 : !torch.vtensor<*,f32> -} diff --git a/test/Dialect/Torch/verify-backend-contract-error.mlir b/test/Dialect/Torch/verify-backend-contract-error.mlir index eb9c6c581..5accee126 100644 --- a/test/Dialect/Torch/verify-backend-contract-error.mlir +++ b/test/Dialect/Torch/verify-backend-contract-error.mlir @@ -1,7 +1,7 @@ -// RUN: torch-mlir-opt -torch-verify-backend-contract-no-decompositions -split-input-file -verify-diagnostics %s -func.func @f(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor { - // expected-error @below {{unsupported by backend contract: tensor with unknown rank}} - // expected-note @below {{this is likely due to a missing transfer function}} - %t = torch.aten.t %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor - return %t : !torch.vtensor +// RUN: torch-mlir-opt -torch-verify-backend-contract -split-input-file -verify-diagnostics %s +func.func @f(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + // expected-error @+2 {{found an op that was marked as backend illegal}} + // expected-note @+1 {{this is likely due to}} + %t = torch.aten.t %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> + return %t : !torch.vtensor<[?,?],f32> } diff --git a/test/lit.site.cfg.py.in b/test/lit.site.cfg.py.in index f7ac86747..339975f3e 100644 --- a/test/lit.site.cfg.py.in +++ b/test/lit.site.cfg.py.in @@ -17,7 +17,7 @@ config.llvm_exe_ext = "@EXEEXT@" config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@" config.python_executable = "@Python3_EXECUTABLE@" config.enable_jit_ir_importer = @TORCH_MLIR_ENABLE_JIT_IR_IMPORTER@ -config.enable_stablehlo = @TORCH_MLIR_ENABLE_STABLEHLO@ +config.enable_mhlo = @TORCH_MLIR_ENABLE_MHLO@ import lit.llvm lit.llvm.initialize(lit_config, config) diff --git a/test/python/smoketest.py b/test/python/smoketest.py index bb97927e9..88e0a10f7 100644 --- a/test/python/smoketest.py +++ b/test/python/smoketest.py @@ -5,8 +5,3 @@ from torch_mlir.dialects import torch with torch_mlir.ir.Context() as ctx: torch.register_dialect(ctx) - with torch_mlir.ir.Location.unknown() as loc: - module = torch_mlir.ir.Module.create(loc) - with torch_mlir.ir.InsertionPoint.at_block_begin(module.body): - n = torch.ConstantNoneOp() - module.operation.print() \ No newline at end of file diff --git a/tools/torch-mlir-opt/torch-mlir-opt.cpp b/tools/torch-mlir-opt/torch-mlir-opt.cpp index 9bf123480..9976b06b7 100644 --- a/tools/torch-mlir-opt/torch-mlir-opt.cpp +++ b/tools/torch-mlir-opt/torch-mlir-opt.cpp @@ -12,10 +12,6 @@ #include "mlir/Tools/mlir-opt/MlirOptMain.h" #include "torch-mlir/InitAll.h" -#ifdef TORCH_MLIR_ENABLE_STABLEHLO -#include "stablehlo/dialect/Register.h" -#endif - using namespace mlir; int main(int argc, char **argv) { @@ -25,10 +21,7 @@ int main(int argc, char **argv) { DialectRegistry registry; registerAllDialects(registry); mlir::torch::registerAllDialects(registry); - -#ifdef TORCH_MLIR_ENABLE_STABLEHLO - mlir::stablehlo::registerAllDialects(registry); -#endif + return mlir::asMainReturnCode( mlir::MlirOptMain(argc, argv, "MLIR modular optimizer driver\n", registry, /*preloadDialectsInContext=*/false)); diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt deleted file mode 100644 index 8e5329074..000000000 --- a/torchvision-requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ --f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html ---pre -torchvision==0.15.0.dev20230310 diff --git a/utils/bazel/docker/Dockerfile b/utils/bazel/docker/Dockerfile index 7f78226b4..2382b6100 100644 --- a/utils/bazel/docker/Dockerfile +++ b/utils/bazel/docker/Dockerfile @@ -13,12 +13,12 @@ RUN apt-get update && \ unzip # Install clang -ARG REPO_NAME="deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-16 main" +ARG REPO_NAME="deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy main" RUN echo $REPO_NAME >> /etc/apt/sources.list.d/llvm.list && \ wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key | apt-key add - && \ apt-get update && \ apt-get install -y \ - clang-16 + clang # Install bazel ARG ARCH="x86_64" @@ -29,8 +29,6 @@ RUN wget -q https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSIO # Install torch-mlir requirements COPY requirements.txt /opt/app/requirements.txt COPY build-requirements.txt /opt/app/build-requirements.txt -COPY test-requirements.txt /opt/app/test-requirements.txt -COPY torchvision-requirements.txt /opt/app/torchvision-requirements.txt COPY pytorch-requirements.txt /opt/app/pytorch-requirements.txt WORKDIR /opt/app RUN python3 -m pip install --upgrade pip diff --git a/utils/bazel/torch-mlir-overlay/BUILD.bazel b/utils/bazel/torch-mlir-overlay/BUILD.bazel index ec4d1cb90..86b4060b8 100644 --- a/utils/bazel/torch-mlir-overlay/BUILD.bazel +++ b/utils/bazel/torch-mlir-overlay/BUILD.bazel @@ -268,7 +268,7 @@ gentbl_cc_library( ( [ "-gen-pass-decls", - "-DTORCH_MLIR_ENABLE_STABLEHLO", + "-DTORCH_MLIR_ENABLE_MHLO", ], "include/torch-mlir/Conversion/Passes.h.inc", ), @@ -434,13 +434,13 @@ cc_library( ) cc_library( - name = "TorchMLIRTorchToStablehlo", + name = "TorchMLIRTorchToMhlo", srcs = glob([ "lib/Conversion/*.h", - "lib/Conversion/TorchToStablehlo/*.h", - "lib/Conversion/TorchToStablehlo/*.cpp", + "lib/Conversion/TorchToMhlo/*.h", + "lib/Conversion/TorchToMhlo/*.cpp", ]), - hdrs = glob(["include/torch-mlir/Conversion/TorchToStablehlo/*.h"]), + hdrs = glob(["include/torch-mlir/Conversion/TorchToMhlo/*.h"]), strip_include_prefix = "include", deps = [ ":TorchMLIRConversionPassesIncGen", @@ -449,7 +449,6 @@ cc_library( ":TorchMLIRTorchConversionDialect", "@llvm-project//mlir:Dialect", "@mlir-hlo//:mlir_hlo", - "@mlir-hlo//:transforms_passes", ], ) @@ -461,16 +460,13 @@ cc_library( hdrs = [ "include/torch-mlir/Conversion/Passes.h", ], - defines = [ - "TORCH_MLIR_ENABLE_STABLEHLO", - ], strip_include_prefix = "include", deps = [ ":TorchMLIRTorchConversionToMLProgram", ":TorchMLIRTorchToArith", ":TorchMLIRTorchToLinalg", + ":TorchMLIRTorchToMhlo", ":TorchMLIRTorchToSCF", - ":TorchMLIRTorchToStablehlo", ":TorchMLIRTorchToTMTensor", ":TorchMLIRTorchToTosa", ], @@ -493,8 +489,8 @@ cc_library( ":TorchMLIRTorchPasses", ":TorchMLIRTorchToArith", ":TorchMLIRTorchToLinalg", + ":TorchMLIRTorchToMhlo", ":TorchMLIRTorchToSCF", - ":TorchMLIRTorchToStablehlo", ":TorchMLIRTorchToTMTensor", ":TorchMLIRTorchToTosa", "@llvm-project//mlir:ConversionPasses", diff --git a/utils/bazel/torch-mlir-overlay/test/BUILD.bazel b/utils/bazel/torch-mlir-overlay/test/BUILD.bazel index d29391305..2db2a7751 100644 --- a/utils/bazel/torch-mlir-overlay/test/BUILD.bazel +++ b/utils/bazel/torch-mlir-overlay/test/BUILD.bazel @@ -23,7 +23,7 @@ expand_template( # All disabled, but required to substituted because they are not in quotes. "@MLIR_ENABLE_BINDINGS_PYTHON@": "0", "@TORCH_MLIR_ENABLE_JIT_IR_IMPORTER@": "0", - "@TORCH_MLIR_ENABLE_STABLEHLO@": "0", + "@TORCH_MLIR_ENABLE_MHLO@": "0", }, template = "lit.site.cfg.py.in", ) diff --git a/whl-requirements.txt b/whl-requirements.txt index f628a4180..554744257 100644 --- a/whl-requirements.txt +++ b/whl-requirements.txt @@ -1,5 +1,4 @@ -f build-requirements.txt --f pytorch-requirements.txt # Packaging requirements. packaging