diff --git a/.github/workflows/RollPyTorch.yml b/.github/workflows/RollPyTorch.yml
index 0a7fa5672..85cebda32 100644
--- a/.github/workflows/RollPyTorch.yml
+++ b/.github/workflows/RollPyTorch.yml
@@ -1,8 +1,6 @@
 name: Roll PyTorch

 on:
-  schedule:
-    - cron: '0 12 * * *'
   workflow_dispatch:

 jobs:
@@ -34,8 +32,8 @@ jobs:
         python -m pip download -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre "torch==${PT_RELEASE}"
         # Read the commit hash from the downloaded whl file without extracting it
         PT_HASH=$(unzip -p torch-"${PT_RELEASE}"*.whl torch/version.py | grep git_version | awk '{ print $3 }' | tr -d "'")
-        echo "${PT_HASH}" | cmp - pytorch-version.txt --quiet
-        PT_HASH_CHANGED=$?
+        PT_HASH_CHANGED=0
+        echo "${PT_HASH}" | cmp - pytorch-version.txt --quiet || PT_HASH_CHANGED=$?
         echo "${PT_HASH}" > pytorch-version.txt
         rm torch-"${PT_RELEASE}"*.whl
         # Write the release and hash to the environment file so that we can
@@ -44,13 +42,14 @@ jobs:
         echo "PT_RELEASE=${PT_RELEASE}" >> ${GITHUB_ENV}
         echo "PT_HASH_CHANGED=${PT_HASH_CHANGED}" >> ${GITHUB_ENV}

-    - name: Build and test
+    - name: Build and test (in-tree), also update ODS and shape library
       if: env.PT_HASH_CHANGED != '0'
       run: |
         cd ${GITHUB_WORKSPACE}
-        TM_PACKAGES="out-of-tree" TM_USE_PYTORCH_BINARY="OFF" \
+        TM_PACKAGES="in-tree" TM_USE_PYTORCH_BINARY="OFF" \
         TORCH_MLIR_SRC_PYTORCH_BRANCH="${{ env.PT_HASH }}" \
         TORCH_MLIR_SRC_PYTORCH_RELEASE="${{ env.PT_RELEASE }}" \
+        TM_UPDATE_ODS_AND_SHAPE_LIB="ON" \
         ./build_tools/python_deploy/build_linux_packages.sh

     - name: Push changes to main branch
@@ -61,5 +60,5 @@ jobs:
         git config user.name "Roll PyTorch Action"
         git fetch --recurse-submodules=no
         git checkout main
-        git add pytorch-version.txt pytorch-requirements.txt
+        git add pytorch-version.txt pytorch-requirements.txt lib/Dialect/Torch/Transforms/ShapeLibrary.cpp include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
         git diff --cached --exit-code || (git commit -m "update PyTorch version to ${{ env.PT_RELEASE }}" && git push --set-upstream origin main)
diff --git a/.github/workflows/buildAndTest.yml b/.github/workflows/buildAndTest.yml
index 56b34e9b2..117f6fbf8 100644
--- a/.github/workflows/buildAndTest.yml
+++ b/.github/workflows/buildAndTest.yml
@@ -61,7 +61,7 @@ jobs:
       if: ${{ matrix.os-arch == 'ubuntu-x86_64' }}
       run: |
         cd $GITHUB_WORKSPACE
-        TM_PACKAGES="${{ matrix.llvm-build }}" TM_USE_PYTORCH_BINARY="${{ matrix.torch-binary }}" ./build_tools/python_deploy/build_linux_packages.sh
+        TORCH_MLIR_SRC_PYTORCH_BRANCH="$(cat pytorch-version.txt)" TM_PACKAGES="${{ matrix.llvm-build }}" TM_USE_PYTORCH_BINARY="${{ matrix.torch-binary }}" ./build_tools/python_deploy/build_linux_packages.sh
     - name: Configure os-arch='macos-arm64' llvm-build='in-tree' torch-binary='${{ matrix.torch-binary }}' # cross compile, can't test arm64
       if: ${{ matrix.os-arch == 'macos-arm64' && matrix.llvm-build == 'in-tree' }}
diff --git a/.github/workflows/buildRelease.yml b/.github/workflows/buildRelease.yml
index 7bed797c4..ef9b92925 100644
--- a/.github/workflows/buildRelease.yml
+++ b/.github/workflows/buildRelease.yml
@@ -94,6 +94,15 @@ jobs:
       with:
         release_id: ${{ github.event.inputs.release_id }}

+  publish_releases:
+    needs:
+      - build_linux
+      - build_macos
+
+    # Publish even if one of the builds failed
+    if: ${{ always() }}
+
+    steps:
     - name: Invoke Publish Releases Page
       uses: benc-uk/workflow-dispatch@v1
       with:
diff --git a/build_tools/autogen_ltc_backend.yaml b/build_tools/autogen_ltc_backend.yaml
index 0148b865a..42cb09e71 100644
--- a/build_tools/autogen_ltc_backend.yaml
+++ b/build_tools/autogen_ltc_backend.yaml
@@ -12,7 +12,6 @@ blacklist:
 - detach
 - item
 - size
-- where
 - copy_

 # Disabled for consistency with TS backend
diff --git a/build_tools/build_libtorch.sh b/build_tools/build_libtorch.sh
index 024ba836b..6c5fd68db 100755
--- a/build_tools/build_libtorch.sh
+++ b/build_tools/build_libtorch.sh
@@ -58,7 +58,8 @@ checkout_pytorch() {
     git reset --hard FETCH_HEAD
   else
     cd "${PYTORCH_ROOT}"
-    git reset --hard HEAD
+    git fetch --depth=1 origin "${TORCH_MLIR_SRC_PYTORCH_BRANCH}"
+    git reset --hard FETCH_HEAD
   fi
   git clean -df
   git submodule update --init --depth 1 --recursive
diff --git a/build_tools/python_deploy/build_linux_packages.sh b/build_tools/python_deploy/build_linux_packages.sh
index 5b19063d8..aa86e384c 100755
--- a/build_tools/python_deploy/build_linux_packages.sh
+++ b/build_tools/python_deploy/build_linux_packages.sh
@@ -53,14 +53,16 @@ TM_PACKAGES="${TM_PACKAGES:-torch-mlir}"
 TM_USE_PYTORCH_BINARY="${TM_USE_PYTORCH_BINARY:-ON}"
 # Skip running tests if you want quick iteration
 TM_SKIP_TESTS="${TM_SKIP_TESTS:-OFF}"
+# Update ODS and shape library files
+TM_UPDATE_ODS_AND_SHAPE_LIB="${TM_UPDATE_ODS_AND_SHAPE_LIB:-OFF}"

 PKG_VER_FILE="${repo_root}"/torch_mlir_package_version ; [ -f "$PKG_VER_FILE" ] && . "$PKG_VER_FILE"
 TORCH_MLIR_PYTHON_PACKAGE_VERSION="${TORCH_MLIR_PYTHON_PACKAGE_VERSION:-0.0.1}"
 echo "Setting torch-mlir Python Package version to: ${TORCH_MLIR_PYTHON_PACKAGE_VERSION}"

-TORCH_MLIR_SRC_PYTORCH_REPO="${TORCH_MLIR_SRC_PYTORCH_REPO:-pytorch/pytorch}"
+export TORCH_MLIR_SRC_PYTORCH_REPO="${TORCH_MLIR_SRC_PYTORCH_REPO:-pytorch/pytorch}"
 echo "Setting torch-mlir PyTorch Repo for source builds to: ${TORCH_MLIR_SRC_PYTORCH_REPO}"
-TORCH_MLIR_SRC_PYTORCH_BRANCH="${TORCH_MLIR_SRC_PYTORCH_BRANCH:-master}"
+export TORCH_MLIR_SRC_PYTORCH_BRANCH="${TORCH_MLIR_SRC_PYTORCH_BRANCH:-master}"
 echo "Setting torch-mlir PyTorch version for source builds to: ${TORCH_MLIR_SRC_PYTORCH_BRANCH}"

 function run_on_host() {
@@ -109,6 +111,7 @@ function run_on_host() {
       -e "TM_PYTHON_VERSIONS=${TM_PYTHON_VERSIONS}" \
       -e "TM_PACKAGES=${package}" \
       -e "TM_SKIP_TESTS=${TM_SKIP_TESTS}" \
+      -e "TM_UPDATE_ODS_AND_SHAPE_LIB=${TM_UPDATE_ODS_AND_SHAPE_LIB}" \
       -e "TM_USE_PYTORCH_BINARY=${TM_USE_PYTORCH_BINARY}" \
       -e "TORCH_MLIR_SRC_PYTORCH_REPO=${TORCH_MLIR_SRC_PYTORCH_REPO}" \
       -e "TORCH_MLIR_SRC_PYTORCH_BRANCH=${TORCH_MLIR_SRC_PYTORCH_BRANCH}" \
@@ -152,6 +155,12 @@ function run_in_docker() {
         in-tree)
           setup_venv "$python_version"
           build_in_tree "$TM_USE_PYTORCH_BINARY" "$python_version"
+          if [ "${TM_UPDATE_ODS_AND_SHAPE_LIB}" == "ON" ]; then
+            pushd /main_checkout/torch-mlir
+            ./build_tools/update_torch_ods.sh
+            ./build_tools/update_shape_lib.sh
+            popd
+          fi
           if [ "${TM_SKIP_TESTS}" == "OFF" ]; then
             test_in_tree;
           fi
diff --git a/docs/development.md b/docs/development.md
index 1811b8450..aaeb8667b 100644
--- a/docs/development.md
+++ b/docs/development.md
@@ -28,23 +28,25 @@ Two setups are possible to build: in-tree and out-of-tree. The in-tree setup is
 ### Building torch-mlir in-tree

-The following command generates configuration files to build the project *in-tree*, that is, using llvm/llvm-project as the main build. This will build LLVM as well as torch-mlir and its subprojects.
+The following command generates configuration files to build the project *in-tree*, that is, using llvm/llvm-project as the main build. This will build LLVM as well as torch-mlir and its subprojects. On Windows, use the "Developer PowerShell for Visual Studio" to ensure that the compiler and linker binaries are in the `PATH` variable.

 ```shell
 cmake -GNinja -Bbuild \
   -DCMAKE_BUILD_TYPE=Release \
-  -DCMAKE_C_COMPILER=clang \
-  -DCMAKE_CXX_COMPILER=clang++ \
   -DPython3_FIND_VIRTUALENV=ONLY \
   -DLLVM_ENABLE_PROJECTS=mlir \
   -DLLVM_EXTERNAL_PROJECTS="torch-mlir;torch-mlir-dialects" \
-  -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR=`pwd` \
-  -DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR=`pwd`/externals/llvm-external-projects/torch-mlir-dialects \
+  -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$PWD" \
+  -DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR="$PWD"/externals/llvm-external-projects/torch-mlir-dialects \
   -DMLIR_ENABLE_BINDINGS_PYTHON=ON \
   -DLLVM_TARGETS_TO_BUILD=host \
   externals/llvm-project/llvm
 ```

 The following additional quality of life flags can be used to reduce build time:
+* Enabling clang on Linux
+```shell
+  -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++
+```
 * Enabling ccache:
 ```shell
   -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache
 ```
@@ -73,8 +75,6 @@ If you have built llvm-project separately in the directory `$LLVM_INSTALL_DIR`,
 ```shell
 cmake -GNinja -Bbuild \
   -DCMAKE_BUILD_TYPE=Release \
-  -DCMAKE_C_COMPILER=clang \
-  -DCMAKE_CXX_COMPILER=clang++ \
   -DPython3_FIND_VIRTUALENV=ONLY \
   -DMLIR_DIR="$LLVM_INSTALL_DIR/lib/cmake/mlir/" \
   -DLLVM_DIR="$LLVM_INSTALL_DIR/lib/cmake/llvm/" \
@@ -82,7 +82,7 @@ cmake -GNinja -Bbuild \
   -DLLVM_TARGETS_TO_BUILD=host \
   .
 ```
-The same QoL CMake flags can be used to enable ccache and lld. Be sure to have built LLVM with `-DLLVM_ENABLE_PROJECTS=mlir`.
+The same QoL CMake flags can be used to enable clang, ccache, and lld. Be sure to have built LLVM with `-DLLVM_ENABLE_PROJECTS=mlir`.

 Be aware that the installed version of LLVM needs in general to match the committed version in `externals/llvm-project`. Using a different version may or may not work.

@@ -105,15 +105,25 @@ cmake --build build
 ```

 ## Setup Python Environment to export the built Python packages
+
+### Linux and macOS
+
 ```shell
 export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/examples
 ```

+### Windows PowerShell
+
+```shell
+$env:PYTHONPATH = "$PWD/build/tools/torch-mlir/python_packages/torch_mlir;$PWD/examples"
+```
+
 ## Testing MLIR output in various dialects

 To test the compiler's output to the different MLIR dialects, you can use the example `examples/torchscript_resnet18_all_output_types.py`.
-Make sure you have activated the virtualenv and set the `PYTHONPATH` above:
+Make sure you have activated the virtualenv and set the `PYTHONPATH` above
+(if running on Windows, modify the environment variable as shown above):
 ```shell
 source mlir_venv/bin/activate
 export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/examples
diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index 231b3cbe6..d00bcc76e 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -463,6 +463,10 @@ TOSA_PASS_SET = {
     "ArangeStartIntModule_basic",
     "ArangeStartNegativeStepIntModule_basic",
     "ArangeZeroElementOutputModule_basic",
+    "NumToTensorIntModule_basic",
+    "ToDtypeBoolLayoutNoneStaticModule_basic",
+    "ToCopyBoolDTypeStaticModule_basic",
+    "HardTanhIntModule_basic",
 }

 LTC_XFAIL_SET = {
@@ -569,8 +573,8 @@ LTC_XFAIL_SET = {
     "IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
     "LiftFreshCopyModule_basic",
     "Matmul_dot",
-    "Matmul_matvec",
     "MulIntModule_basic",
+    "DivIntModule_basic",
     "NeFloatIntModule_basic",
     "NeIntModule_basic",
     "NewEmptyModuleDefaultDtype_basic",
@@ -632,4 +636,8 @@ LTC_XFAIL_SET = {
     "ElementwiseRemainderScalarModule_Bool_basic",
     "AtenIntTensorByteDtypeModule_basic",
     "AtenIntTensorCharDtypeModule_basic",
+    "UpSampleNearest2dDynamicFactor_basic",
+    "UpSampleNearest2dDynamicSize_basic",
+    "UpSampleNearest2dStaticFactor_basic",
+    "UpSampleNearest2dStaticSize_basic",
 }
diff --git a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/Bufferize.cpp b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/Bufferize.cpp
index e39c8413b..a79c0e09f 100644
--- a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/Bufferize.cpp
+++ b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/Bufferize.cpp
@@ -7,8 +7,8 @@
 //===----------------------------------------------------------------------===//

 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
-#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/Passes.h"
@@ -134,7 +134,7 @@ struct TMTensorBufferizePass
   bufferization::BufferizeTypeConverter typeConverter;

   // Mark all Standard operations legal.
-  target.addLegalDialect<arith::ArithmeticDialect, func::FuncDialect,
-                         memref::MemRefDialect, tensor::TensorDialect>();
+  target.addLegalDialect<arith::ArithDialect, func::FuncDialect,
+                         memref::MemRefDialect, tensor::TensorDialect>();

   // Mark all TMTensor operations illegal as long as they work on tensors.
diff --git a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp
index 3d68625b7..d8af2ef5c 100644
--- a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp
+++ b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp
@@ -7,8 +7,8 @@
 //
 //===----------------------------------------------------------------------===//

-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
-#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Math/IR/Math.h"
@@ -101,7 +101,7 @@ namespace {
 struct TMTensorToLoopsPass : public TMTensorToLoopsBase<TMTensorToLoopsPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<linalg::LinalgDialect, func::FuncDialect,
-                    mlir::arith::ArithmeticDialect, math::MathDialect,
+                    mlir::arith::ArithDialect, math::MathDialect,
                     memref::MemRefDialect, scf::SCFDialect>();
   }
diff --git a/externals/llvm-external-projects/torch-mlir-dialects/tools/torch-mlir-dialects-opt/CMakeLists.txt b/externals/llvm-external-projects/torch-mlir-dialects/tools/torch-mlir-dialects-opt/CMakeLists.txt
index 4b753e310..102d60bce 100644
--- a/externals/llvm-external-projects/torch-mlir-dialects/tools/torch-mlir-dialects-opt/CMakeLists.txt
+++ b/externals/llvm-external-projects/torch-mlir-dialects/tools/torch-mlir-dialects-opt/CMakeLists.txt
@@ -1,5 +1,5 @@
 set(LIBS
-  MLIRArithmeticDialect
+  MLIRArithDialect
   MLIRDialect
   MLIRLinalgDialect
   MLIRMemRefDialect
diff --git a/externals/llvm-external-projects/torch-mlir-dialects/tools/torch-mlir-dialects-opt/torch-mlir-dialects-opt.cpp b/externals/llvm-external-projects/torch-mlir-dialects/tools/torch-mlir-dialects-opt/torch-mlir-dialects-opt.cpp
index 1dfc756bb..5e65bb28d 100644
--- a/externals/llvm-external-projects/torch-mlir-dialects/tools/torch-mlir-dialects-opt/torch-mlir-dialects-opt.cpp
+++ b/externals/llvm-external-projects/torch-mlir-dialects/tools/torch-mlir-dialects-opt/torch-mlir-dialects-opt.cpp
@@ -7,12 +7,12 @@
 //
 //===----------------------------------------------------------------------===//

-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/SCF/Transforms/Passes.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/Dialect.h"
@@ -39,7 +39,7 @@ int main(int argc, char **argv) {
       // Local dialects
       mlir::torch::TMTensor::TMTensorDialect,
       // Upstream dialects
-      mlir::arith::ArithmeticDialect, mlir::linalg::LinalgDialect,
+      mlir::arith::ArithDialect, mlir::linalg::LinalgDialect,
       mlir::func::FuncDialect, mlir::memref::MemRefDialect,
       mlir::scf::SCFDialect, mlir::tensor::TensorDialect>();
diff --git a/externals/llvm-project b/externals/llvm-project
index bebc96956..6f46ff376 160000
--- a/externals/llvm-project
+++ b/externals/llvm-project
@@ -1 +1 @@
-Subproject commit bebc96956b76bdbc36f1d82a788c810e5b12e2c5
+Subproject commit 6f46ff3765dcdc178b9cf52ebd8c03437806798a
diff --git a/externals/mlir-hlo b/externals/mlir-hlo
index 7b0ecf782..2f7c1454b 160000
--- a/externals/mlir-hlo
+++ b/externals/mlir-hlo
@@ -1 +1 @@
-Subproject commit 7b0ecf7827e3fc07d2af90e147bcedc165bc78ac
+Subproject commit 2f7c1454bbe4c4ad0ae1c86c5539ac58b6053b6a
diff --git a/include/torch-mlir-c/TorchTypes.h b/include/torch-mlir-c/TorchTypes.h
index f459960ee..e499c6744 100644
--- a/include/torch-mlir-c/TorchTypes.h
+++ b/include/torch-mlir-c/TorchTypes.h
@@ -204,6 +204,10 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchValueTensorTypeGet(
 MLIR_CAPI_EXPORTED MlirType
 torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context);

+/// Gets the !torch.vtensor type with the tensor attribute.
+MLIR_CAPI_EXPORTED MlirType
+torchMlirTorchValueTensorTypeGetFromAttribute(MlirAttribute attr);
+
 //===----------------------------------------------------------------------===//
 // !torch.none type.
 //===----------------------------------------------------------------------===//
diff --git a/include/torch-mlir-c/Transforms.h b/include/torch-mlir-c/Transforms.h
new file mode 100644
index 000000000..ccadd8069
--- /dev/null
+++ b/include/torch-mlir-c/Transforms.h
@@ -0,0 +1,22 @@
+//===-- torch-mlir-c/Transforms.h - C API for torch passes --------*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header declares the registration and creation method for
+// transformation passes.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TORCHMLIR_C_TRANSFORMS_H
+#define TORCHMLIR_C_TRANSFORMS_H
+
+#include "mlir-c/Support.h"
+
+#include "torch-mlir/Dialect/Torch/Transforms/Transforms.capi.h.inc"
+
+#endif // TORCHMLIR_C_TRANSFORMS_H
diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
index 44f00eea9..3d8e4bffc 100644
--- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
+++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
@@ -53,6 +53,9 @@ template <typename T>
 llvm::Optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
                                      ArrayRef<T> vec, ArrayRef<int64_t> shape);

+LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
+                                   Value src, Type destType, Value &result);
+
 // Creates a TOSA operation and performs shape inference on the individual
 // op. This allows shape inference during the framework to TOSA lowering.
 template <typename TosaOp, typename... Args>
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 8daa2f77c..6eafdcf46 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -2192,6 +2192,53 @@ def Torch_AtenBitwiseAnd_TensorOp : Torch_Op<"aten.bitwise_and_.Tensor", [
   }];
 }

+def Torch_AtenBitwiseOrTensorOp : Torch_Op<"aten.bitwise_or.Tensor", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::bitwise_or.Tensor : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$other
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenBitwiseOrTensorOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenBitwiseOrTensorOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
+def Torch_AtenBitwiseOr_TensorOp : Torch_Op<"aten.bitwise_or_.Tensor", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::bitwise_or_.Tensor : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$other
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenBitwiseOr_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenBitwiseOr_TensorOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenThresholdOp : Torch_Op<"aten.threshold", [
   AllowsTypeRefinement,
   HasValueSemantics,
@@ -3387,6 +3434,30 @@ def Torch_AtenMatmulOp : Torch_Op<"aten.matmul", [
   }];
 }

+def Torch_AtenMvOp : Torch_Op<"aten.mv", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::mv : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$vec
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenMvOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenMvOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [
   AllowsTypeRefinement,
   HasValueSemantics,
@@ -7399,6 +7470,31 @@ def Torch_AtenAsStridedScatterOp : Torch_Op<"aten.as_strided_scatter", [
   }];
 }

+def Torch_AtenUpsampleNearest2dVecOp : Torch_Op<"aten.upsample_nearest2d.vec", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::upsample_nearest2d.vec : (Tensor, int[]?, float[]?) -> (Tensor)`";
-> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchOptionalListOfTorchIntType:$output_size, + AnyTorchOptionalListOfTorchFloatType:$scale_factors + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenUpsampleNearest2dVecOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenUpsampleNearest2dVecOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [ AllowsTypeRefinement, HasValueSemantics, @@ -8356,6 +8452,31 @@ def Torch_AtenMulIntOp : Torch_Op<"aten.mul.int", [ let hasFolder = 1; } +def Torch_AtenDivIntOp : Torch_Op<"aten.div.int", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::div.int : (int, int) -> (float)`"; + let arguments = (ins + Torch_IntType:$a, + Torch_IntType:$b + ); + let results = (outs + Torch_FloatType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenDivIntOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenDivIntOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasFolder = 1; +} + def Torch_AtenNegIntOp : Torch_Op<"aten.neg.int", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td index 92d6186d6..0cb46745c 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td @@ -383,6 +383,7 @@ class ListOf allowedTypes, string descr> : def AnyTorchListOfTorchBoolType : ListOf<[Torch_BoolType], "Bool list type (bool[])">; def AnyTorchListOfTorchIntType : ListOf<[Torch_IntType], "Int list type (int[])">; +def AnyTorchListOfTorchFloatType : ListOf<[Torch_FloatType], "Float list type (float[])">; def AnyTorchListOfTorchStringType : ListOf<[Torch_StringType], "Str list type (str[])">; def AnyTorchListOfTensorType: ListOf<[AnyTorchTensorType], "Any int list type (Tensor[])">; @@ -390,7 +391,7 @@ def AnyTorchListOfOptionalTensorType : ListOf<[AnyTorchOptionalTensorType], "Any optional tensor list type (Tensor?[])">; def AnyTorchOptionalListOfTorchIntType : OptionalOf; - +def AnyTorchOptionalListOfTorchFloatType : OptionalOf; // Note: TorchScript does not consider !torch.bool to be a Scalar. 
 def AnyTorchScalarType : Type<Or<[
diff --git a/include/torch-mlir/Dialect/Torch/Transforms/CMakeLists.txt b/include/torch-mlir/Dialect/Torch/Transforms/CMakeLists.txt
index 71b95bfa3..017facf43 100644
--- a/include/torch-mlir/Dialect/Torch/Transforms/CMakeLists.txt
+++ b/include/torch-mlir/Dialect/Torch/Transforms/CMakeLists.txt
@@ -1,5 +1,7 @@
 set(LLVM_TARGET_DEFINITIONS Passes.td)
 mlir_tablegen(Passes.h.inc -gen-pass-decls)
+mlir_tablegen(Transforms.capi.h.inc -gen-pass-capi-header)
+mlir_tablegen(Transforms.capi.cpp.inc -gen-pass-capi-impl)
 add_public_tablegen_target(TorchMLIRTorchPassIncGen)
 add_mlir_doc(Passes TorchMLIRTorchTransforms ./ -gen-pass-doc)
diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h
index 98ee5151e..29c24035b 100644
--- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h
+++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h
@@ -21,6 +21,8 @@ class ModuleOp;
 namespace torch {
 namespace Torch {

+#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
+
 std::unique_ptr<OperationPass<ModuleOp>> createGlobalizeObjectGraphPass();

 std::unique_ptr<OperationPass<ModuleOp>>
@@ -109,6 +111,8 @@ std::unique_ptr<OperationPass<ModuleOp>>
 createLowerToBackendContractPass(int maxIterations, bool decompose,
                                  ArrayRef<std::string> backendLegalOps);

+std::unique_ptr<OperationPass<ModuleOp>> createVerifyBackendContractPass();
+
 StringRef getShapeLibrary();

 } // namespace Torch
@@ -116,6 +120,13 @@ StringRef getShapeLibrary();
 /// Registers all Torch transformation passes.
 void registerTorchPasses();

+//===----------------------------------------------------------------------===//
+// Pass registration
+//===----------------------------------------------------------------------===//
+
+#define GEN_PASS_REGISTRATION
+#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
+
 } // namespace torch
 } // namespace mlir
diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td
index c1ce31aa6..216c8b8ca 100644
--- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td
+++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td
@@ -329,4 +329,16 @@ def LowerToBackendContract
   let dependentDialects = ["func::FuncDialect"];
 }

+def VerifyBackendContract
+    : Pass<"torch-verify-backend-contract", "ModuleOp"> {
+  let summary = "Check that program satisfies backend contract.";
+  let constructor =
+    "mlir::torch::Torch::createVerifyBackendContractPass()";
+  let description = [{
+    This pass performs a set of inspections to check that the program satisfies the
+    backend contract. In case of a check failure it prints out the error message and
+    returns a `signalPassFailure()` status.
+  }];
+}
+
 #endif // TORCHMLIR_TORCH_PASSES
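The `GEN_PASS_REGISTRATION` include above exposes the pass under `torch-verify-backend-contract`, and the declared factory lets C++ clients append it to a pipeline. A minimal sketch of that usage (illustrative only, not part of this patch; assumes only the headers introduced above):

```cpp
#include "mlir/Pass/PassManager.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"

// Sketch: append the new verification pass to a module-level pipeline.
// `createVerifyBackendContractPass` is the factory declared in Passes.h.
void addBackendContractCheck(mlir::PassManager &pm) {
  pm.addPass(mlir::torch::Torch::createVerifyBackendContractPass());
}
```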
diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
index d98dc2eca..323846fe3 100644
--- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h
+++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -52,6 +52,9 @@ int getTensorRank(Value tensor);

 bool isViewLikeOp(Operation *op);

+Value getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter, Location loc,
+                                        float value, Type dtype);
+
 } // namespace Torch
 } // namespace torch
 } // namespace mlir
diff --git a/lib/CAPI/CMakeLists.txt b/lib/CAPI/CMakeLists.txt
index 87977a86f..d71796ae8 100644
--- a/lib/CAPI/CMakeLists.txt
+++ b/lib/CAPI/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_public_c_api_library(TorchMLIRCAPI
   Registration.cpp
   TorchOps.cpp
   TorchTypes.cpp
+  Transforms.cpp

   ADDITIONAL_HEADER_DIRS
   ${PROJECT_SOURCE_DIR}/include/torch-mlir-c/
@@ -16,6 +17,7 @@ add_mlir_public_c_api_library(TorchMLIRCAPI
   MLIRSupport
   TorchMLIRTorchDialect
   TorchMLIRInitAll
+  TorchMLIRTorchPasses
 )

 torch_mlir_target_includes(TorchMLIRCAPI)
diff --git a/lib/CAPI/TorchTypes.cpp b/lib/CAPI/TorchTypes.cpp
index 7465fc06e..27dbda38f 100644
--- a/lib/CAPI/TorchTypes.cpp
+++ b/lib/CAPI/TorchTypes.cpp
@@ -246,6 +246,14 @@ MlirType torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(
       Torch::ValueTensorType::getWithLeastStaticInformation(unwrap(context)));
 }

+MlirType torchMlirTorchValueTensorTypeGetFromAttribute(MlirAttribute attr) {
+  auto attrTensorType =
+      unwrap(attr).cast<TypedAttr>().getType().cast<RankedTensorType>();
+  return wrap(Torch::ValueTensorType::get(attrTensorType.getContext(),
+                                          attrTensorType.getShape(),
+                                          attrTensorType.getElementType()));
+}
+
 //===----------------------------------------------------------------------===//
 // torch.none type.
 //===----------------------------------------------------------------------===//
diff --git a/lib/CAPI/Transforms.cpp b/lib/CAPI/Transforms.cpp
new file mode 100644
index 000000000..f0f57a72d
--- /dev/null
+++ b/lib/CAPI/Transforms.cpp
@@ -0,0 +1,27 @@
+//===- CAPIPasses.cpp - C API for Transformations Passes ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/CAPI/Pass.h"
+#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
+
+// Must include the declarations as they carry important visibility attributes.
+#include "torch-mlir/Dialect/Torch/Transforms/Transforms.capi.h.inc"
+
+using namespace mlir;
+using namespace mlir::torch;
+using namespace mlir::torch::Torch;
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#include "torch-mlir/Dialect/Torch/Transforms/Transforms.capi.cpp.inc"
+
+#ifdef __cplusplus
+}
+#endif
+#include "torch-mlir/Dialect/Torch/Transforms/Transforms.capi.h.inc" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::Torch; + +#ifdef __cplusplus +extern "C" { +#endif + +#include "torch-mlir/Dialect/Torch/Transforms/Transforms.capi.cpp.inc" + +#ifdef __cplusplus +} +#endif diff --git a/lib/Conversion/TorchToArith/TorchToArith.cpp b/lib/Conversion/TorchToArith/TorchToArith.cpp index 1c1a2009d..c6e4aee57 100644 --- a/lib/Conversion/TorchToArith/TorchToArith.cpp +++ b/lib/Conversion/TorchToArith/TorchToArith.cpp @@ -10,7 +10,7 @@ #include "torch-mlir/Conversion/TorchToArith/TorchToArith.h" #include "../PassDetail.h" -#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Math/IR/Math.h" @@ -121,6 +121,25 @@ public: }; } // namespace +namespace { +class ConvertAtenDivIntOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenDivIntOp op, + typename OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value a = + convertScalarToDtype(rewriter, loc, adaptor.a(), rewriter.getF64Type()); + Value b = + convertScalarToDtype(rewriter, loc, adaptor.b(), rewriter.getF64Type()); + rewriter.replaceOpWithNewOp(op, a, b); + return success(); + } +}; +} // namespace + namespace { // Lowers aten integer comparison ops. template @@ -300,7 +319,7 @@ class ConvertTorchToArith : public ConvertTorchToArithBase public: void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); - registry.insert(); + registry.insert(); registry.insert(); registry.insert(); registry.insert(); @@ -311,7 +330,7 @@ public: MLIRContext *context = &getContext(); ConversionTarget target(*context); target.addLegalDialect(); TypeConverter typeConverter; @@ -374,6 +393,8 @@ public: target.addIllegalOp(); patterns.add>( typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add>( typeConverter, context); diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index d66f18393..20c7f4168 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -16,7 +16,7 @@ #include "../PassDetail.h" #include "PopulatePatterns.h" #include "Utils.h" -#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -252,11 +252,11 @@ public: llvm::all_of(expandShape, [](int64_t value) { return value == kUnknownSize; })) { - for (int i = 0; i < collapseShape.size(); i++) { + for (size_t i = 0; i < collapseShape.size(); i++) { collapseIndices.push_back(i); } - for (int i = 0; i < expandShape.size(); i++) { + for (size_t i = 0; i < expandShape.size(); i++) { expandIndices.push_back(i); } @@ -290,8 +290,8 @@ public: op, "total number of elements mismatch in the expansion"); } - static LogicalResult solveDynamicSize(SmallVector &inputShape, - SmallVector &outputShape) { + static void solveDynamicSize(SmallVector &inputShape, + SmallVector &outputShape) { int64_t inputProduct = 1; int64_t outputProduct = 1; @@ -316,7 +316,7 @@ public: if (inputDynamicValues + 
@@ -300,7 +319,7 @@ class ConvertTorchToArith
     : public ConvertTorchToArithBase<ConvertTorchToArith> {
 public:
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<func::FuncDialect>();
-    registry.insert<arith::ArithmeticDialect>();
+    registry.insert<arith::ArithDialect>();
     registry.insert<tensor::TensorDialect>();
     registry.insert<cf::ControlFlowDialect>();
     registry.insert<math::MathDialect>();
@@ -311,7 +330,7 @@ public:
     MLIRContext *context = &getContext();
     ConversionTarget target(*context);
     target.addLegalDialect<Torch::TorchDialect, func::FuncDialect,
-                           arith::ArithmeticDialect, tensor::TensorDialect,
+                           arith::ArithDialect, tensor::TensorDialect,
                            cf::ControlFlowDialect, math::MathDialect>();

     TypeConverter typeConverter;
@@ -374,6 +393,8 @@ public:
     target.addIllegalOp<AtenSubIntOp>();
     patterns.add<ConvertAtenBinaryOp<AtenSubIntOp, arith::SubIOp>>(
         typeConverter, context);
+    target.addIllegalOp<AtenDivIntOp>();
+    patterns.add<ConvertAtenDivIntOp>(typeConverter, context);
     target.addIllegalOp<AtenMulIntOp>();
     patterns.add<ConvertAtenBinaryOp<AtenMulIntOp, arith::MulIOp>>(
         typeConverter, context);
diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index d66f18393..20c7f4168 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -16,7 +16,7 @@
 #include "../PassDetail.h"
 #include "PopulatePatterns.h"
 #include "Utils.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -252,11 +252,11 @@ public:
         llvm::all_of(expandShape,
                      [](int64_t value) { return value == kUnknownSize; })) {

-      for (int i = 0; i < collapseShape.size(); i++) {
+      for (size_t i = 0; i < collapseShape.size(); i++) {
         collapseIndices.push_back(i);
       }

-      for (int i = 0; i < expandShape.size(); i++) {
+      for (size_t i = 0; i < expandShape.size(); i++) {
         expandIndices.push_back(i);
       }

@@ -290,8 +290,8 @@ public:
         op, "total number of elements mismatch in the expansion");
   }

-  static LogicalResult solveDynamicSize(SmallVector<int64_t> &inputShape,
-                                        SmallVector<int64_t> &outputShape) {
+  static void solveDynamicSize(SmallVector<int64_t> &inputShape,
+                               SmallVector<int64_t> &outputShape) {
     int64_t inputProduct = 1;
     int64_t outputProduct = 1;

@@ -316,7 +316,7 @@ public:
     if (inputDynamicValues + outputDynamicValues == 1) {
       if (inputDynamicValues) {
         int64_t missingValue = outputProduct / inputProduct;
-        for (int i = 0; i < inputShape.size(); i++) {
+        for (size_t i = 0; i < inputShape.size(); i++) {
           if (inputShape[i] == -1) {
             inputShape[i] = missingValue;
             break;
@@ -324,7 +324,7 @@ public:
       } else {
         int64_t missingValue = inputProduct / outputProduct;
-        for (int i = 0; i < outputShape.size(); i++) {
+        for (size_t i = 0; i < outputShape.size(); i++) {
           if (outputShape[i] == -1) {
             outputShape[i] = missingValue;
             break;
@@ -332,8 +332,6 @@ public:
         }
       }
     }
-
-    return success();
   }

   LogicalResult
@@ -625,9 +623,6 @@ public:
       }
     }

-    int64_t inputCount = inputAssociations.size();
-    int64_t outputCount = outputAssociations.size();
-
     // Check if the shapes already match up to dynamic sizes. If so, we can just
     // cast as the result type because the previous loop sets up the necessary
     // dim checks in case of dynamic sizes.
diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
index 28454c590..a94b09594 100644
--- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
@@ -12,9 +12,10 @@
 #include "../PassDetail.h"
 #include "PopulatePatterns.h"
 #include "Utils.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Matchers.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
@@ -755,6 +756,146 @@ public:
 };
 } // namespace

+// `getScaledDims` scales the `dim` value with a scale factor `scaleFactor`.
+// The `dim` and `scaleFactor` are assumed to be of index and float type
+// respectively. `scaledDim = int(floor(float(dim) * scaleFactor))`.
+static Value getScaledDims(OpBuilder &builder, Location loc, Value dim,
+                           Value scaleFactor) {
+
+  Value dimInt = castIndexToInt64(builder, loc, dim);
+  Value dimFp =
+      builder.create<arith::SIToFPOp>(loc, scaleFactor.getType(), dimInt);
+  Value scaleDim = builder.create<arith::MulFOp>(loc, dimFp, scaleFactor);
+  Value floorDim = builder.create<math::FloorOp>(loc, scaleDim);
+  Value scaledDimToIndex = castIntToIndex(
+      builder, loc,
+      builder.create<arith::FPToSIOp>(loc, dimInt.getType(), floorDim));
+
+  return scaledDimToIndex;
+}
+
+// `getScaleFactor` returns the scale factor from input to output dimension.
+// The `dim` and `scaledDim` are assumed to be of index and int64 type
+// respectively. scale_factor = (scaled_dim // dim).
+static Value getScaleFactor(OpBuilder &builder, Location loc, Value dim,
+                            Value scaledDim) {
+  Value dimInt = castIndexToInt64(builder, loc, dim);
+  Value scaleFactorInt =
+      builder.create<arith::FloorDivSIOp>(loc, scaledDim, dimInt);
+  return scaleFactorInt;
+}
+
+// N, C, H, W = input_tensor.shape
+// N, C, H_scaled, W_scaled = out_tensor.shape
+// H_factor, W_factor = H_scaled/H, W_scaled/W
+//
+// for i in range(N):
+//   for j in range(C):
+//     for k in range(H_scaled):
+//       for l in range(W_scaled):
+//         out_tensor[i, j, k, l] = input[i, j, k//H_factor, l//W_factor]
+
+namespace {
+class ConvertAtenUpsampleNearest2dVecOp
+    : public OpConversionPattern<AtenUpsampleNearest2dVecOp> {
+
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenUpsampleNearest2dVecOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Location loc = op->getLoc();
+    Value input = adaptor.input();
+
+    Type resultType = getTypeConverter()->convertType(op.getResult().getType());
+    auto inputType = input.getType().cast<RankedTensorType>();
+    auto inputRank = inputType.getRank();
+    Type elementType = inputType.getElementType();
+
+    SmallVector<Value> dims = getTensorSizes(rewriter, loc, input);
+    SmallVector<Value> scaleFactorsInt;
+
+    // The dimension at which the scaling starts.
+    unsigned hDimOffset = 2;
+
+    if (!adaptor.scale_factors().getType().isa<Torch::NoneType>()) {
+      SmallVector<Value> scaleFactorsTorchFloat;
+      if (!getListConstructElements(op.scale_factors(), scaleFactorsTorchFloat))
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: the scale_factors is not constructed from "
+                "ListConstruct");
+      SmallVector<Value> scaleFactorsFloatValues;
+      scaleFactorsFloatValues = getTypeConvertedValues(
+          rewriter, loc, getTypeConverter(), scaleFactorsTorchFloat);
+      // Convert float values to int values.
+      // int_value = (int64_t)ceil(float_value)
+      for (auto floatValue : scaleFactorsFloatValues) {
+        Value ceilVal = rewriter.create<math::CeilOp>(loc, floatValue);
+        Value intVal = rewriter.create<arith::FPToSIOp>(
+            loc, rewriter.getI64Type(), ceilVal);
+        scaleFactorsInt.push_back(intVal);
+      }
+
+      for (unsigned i = 0; i < scaleFactorsFloatValues.size(); i++)
+        dims[hDimOffset + i] = getScaledDims(
+            rewriter, loc, dims[hDimOffset + i], scaleFactorsFloatValues[i]);
+
+    } else {
+
+      SmallVector<Value> outputSizeTorchInt;
+      if (!getListConstructElements(op.output_size(), outputSizeTorchInt))
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: the output_size is not constructed from "
+                "ListConstruct");
+      SmallVector<Value> outputSizeIntValues;
+      outputSizeIntValues = getTypeConvertedValues(
+          rewriter, loc, getTypeConverter(), outputSizeTorchInt);
+
+      for (unsigned i = 0; i < outputSizeTorchInt.size(); i++) {
+        auto scaleFactorVal = getScaleFactor(
+            rewriter, loc, dims[hDimOffset + i], outputSizeIntValues[i]);
+        scaleFactorsInt.push_back(scaleFactorVal);
+        dims[hDimOffset + i] =
+            castIntToIndex(rewriter, loc, outputSizeIntValues[i]);
+      }
+    }
+
+    Value outTensor =
+        rewriter.create<linalg::InitTensorOp>(loc, dims, elementType);
+
+    AffineMap idMap = rewriter.getMultiDimIdentityMap(inputRank);
+    SmallVector<StringRef> iteratorTypes(inputRank,
+                                         getParallelIteratorTypeName());
+
+    Value finalRes =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, outTensor.getType(), ValueRange{}, outTensor,
+                /*indexingMaps=*/idMap,
+                /*iteratorTypes=*/iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  SmallVector<Value> indices;
+                  for (unsigned i = 0; i < inputRank; i++)
+                    indices.push_back(b.create<linalg::IndexOp>(loc, i));
+
+                  for (unsigned i = 0; i < (inputRank - hDimOffset); i++)
+                    indices[i + hDimOffset] = b.create<arith::FloorDivSIOp>(
+                        loc, indices[i + hDimOffset],
+                        castIntToIndex(rewriter, loc, scaleFactorsInt[i]));
+
+                  Value retVal =
+                      b.create<tensor::ExtractOp>(loc, input, indices);
+                  b.create<linalg::YieldOp>(loc, retVal);
+                })
+            .getResult(0);
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, finalRes);
+    return success();
+  }
+};
+} // namespace
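The pseudocode comment above is the whole algorithm: output sizes are `floor(dim * scale_factor)` (or taken directly from `output_size`), and each output coordinate maps back to an input coordinate by integer floor-division with the per-dimension factor. A standalone illustration of that index arithmetic (plain C++, illustrative only; integer-factor case for simplicity):

```cpp
#include <cmath>
#include <cstdint>
#include <iostream>

int main() {
  // H = 4 upsampled with scale_factor = 2.0 -> H_scaled = floor(4 * 2.0) = 8.
  int64_t dim = 4;
  double scaleFactor = 2.0;
  int64_t scaledDim = static_cast<int64_t>(std::floor(dim * scaleFactor));
  // Integer factor used inside the generic op body: scaled_dim // dim.
  int64_t factor = scaledDim / dim; // 8 / 4 = 2
  for (int64_t k = 0; k < scaledDim; ++k)
    std::cout << "out[" << k << "] <- in[" << k / factor << "]\n";
}
```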
 void mlir::torch::torch_to_linalg::
     populateIndirectDataMovementPatternsAndLegality(
         TypeConverter &typeConverter, RewritePatternSet &patterns,
@@ -770,4 +911,6 @@ void mlir::torch::torch_to_linalg::
   patterns.add<ConvertAtenIndexSelectOp>(typeConverter, context);
   target.addIllegalOp<AtenIndexTensorOp>();
   patterns.add<ConvertAtenIndexTensorOp>(typeConverter, context);
+  target.addIllegalOp<AtenUpsampleNearest2dVecOp>();
+  patterns.add<ConvertAtenUpsampleNearest2dVecOp>(typeConverter, context);
 }
diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index 9c01db32c..c9e92b625 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -12,7 +12,7 @@
 #include "../PassDetail.h"
 #include "PopulatePatterns.h"
 #include "Utils.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp
index d1828c474..fc1379e92 100644
--- a/lib/Conversion/TorchToLinalg/Pooling.cpp
+++ b/lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -12,7 +12,7 @@
 #include "../PassDetail.h"
 #include "PopulatePatterns.h"
 #include "Utils.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
diff --git a/lib/Conversion/TorchToLinalg/Random.cpp b/lib/Conversion/TorchToLinalg/Random.cpp
index c8118b21d..427b42181 100644
--- a/lib/Conversion/TorchToLinalg/Random.cpp
+++ b/lib/Conversion/TorchToLinalg/Random.cpp
@@ -12,7 +12,7 @@
 #include "../PassDetail.h"
 #include "PopulatePatterns.h"
 #include "Utils.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp
index dc4e9704c..256d88a5a 100644
--- a/lib/Conversion/TorchToLinalg/Reduction.cpp
+++ b/lib/Conversion/TorchToLinalg/Reduction.cpp
@@ -12,7 +12,7 @@
 #include "../PassDetail.h"
 #include "PopulatePatterns.h"
 #include "Utils.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Math/IR/Math.h"
@@ -152,12 +152,10 @@ public:
           nestedLoc, oldIndex.getType(),
           rewriter.create<linalg::IndexOp>(loc, dim));

-      Value predicate;
-      if (inElementType.isa<mlir::FloatType>())
-        predicate = rewriter.create<arith::CmpFOp>(
-            nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
-      auto resultMax = rewriter.create<arith::SelectOp>(
-          nestedLoc, predicate, newValue, oldValue);
+      auto resultMax = rewriter.create<arith::MaxFOp>(
+          nestedLoc, newValue, oldValue);
+      Value predicate = rewriter.create<arith::CmpFOp>(
+          nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
       auto resultIndex = rewriter.create<arith::SelectOp>(
           nestedLoc, predicate, newIndex, oldIndex);
       nestedBuilder.create<linalg::YieldOp>(
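The rewritten reduction body computes the running maximum with `arith.maxf` and keeps the `OGT` compare only to propagate the index; per element it is equivalent to this update step (plain C++ sketch, illustrative values):

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  double oldValue = 1.5, newValue = 2.0;
  int64_t oldIndex = 0, newIndex = 3;
  double resultMax = std::max(newValue, oldValue);                  // arith.maxf
  int64_t resultIndex = newValue > oldValue ? newIndex : oldIndex;  // cmpf + select
  std::cout << resultMax << " at index " << resultIndex << "\n";    // 2 at index 3
}
```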
diff --git a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp
index 013fdbb48..592e71229 100644
--- a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp
+++ b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp
@@ -12,7 +12,7 @@
 #include "../PassDetail.h"
 #include "PopulatePatterns.h"
 #include "Utils.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Math/IR/Math.h"
diff --git a/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp b/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp
index 24b487c5b..d1d13b776 100644
--- a/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp
+++ b/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp
@@ -11,7 +11,7 @@
 #include "../PassDetail.h"
 #include "PopulatePatterns.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index f8ebc349f..484307ecf 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -11,7 +11,7 @@
 #include "../PassDetail.h"
 #include "PopulatePatterns.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -43,7 +43,7 @@ public:
     registry.insert<linalg::LinalgDialect>();
     registry.insert<func::FuncDialect>();
     registry.insert<math::MathDialect>();
-    registry.insert<arith::ArithmeticDialect>();
+    registry.insert<arith::ArithDialect>();
     registry.insert<tensor::TensorDialect>();
     TorchConversion::getBackendTypeConversionDependentDialects(registry);
   }
@@ -53,7 +53,7 @@ public:
     ConversionTarget target(*context);
     target.addLegalDialect<linalg::LinalgDialect, func::FuncDialect,
-                           cf::ControlFlowDialect, math::MathDialect,
-                           tensor::TensorDialect, arith::ArithmeticDialect>();
+                           cf::ControlFlowDialect, math::MathDialect,
+                           tensor::TensorDialect, arith::ArithDialect>();
     target.addLegalOp<TorchConversion::GetNextSeedOp>();

     TypeConverter typeConverter;
diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index 70b158e09..eadde8ee0 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -12,7 +12,7 @@
 #include "../PassDetail.h"
 #include "PopulatePatterns.h"
 #include "Utils.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Math/IR/Math.h"
@@ -199,6 +199,22 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
     return b.create<arith::AndIOp>(loc, lhs, rhs);
   }
+  if (auto bitwiseOrTensor = dyn_cast<AtenBitwiseOrTensorOp>(op)) {
+    if (bitwiseOrTensor.getType()
+            .cast<BaseTensorType>()
+            .getDtype()
+            .isa<mlir::FloatType>()) {
+      bitwiseOrTensor.emitError(
+          "Bitwise_Or does not support floating point dtype");
+      return nullptr;
+    }
+    Type dtype = converter->convertType(bitwiseOrTensor.getType())
+                     .cast<RankedTensorType>()
+                     .getElementType();
+    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
+    Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
+    return b.create<arith::OrIOp>(loc, lhs, rhs);
+  }
   if (auto logicalOr = dyn_cast<AtenLogicalOrOp>(op)) {
     MLIRContext *context = op->getContext();
     Type floatDtype = mlir::FloatType::getF64(context);
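`aten.bitwise_or.Tensor` is lowered only for integer dtypes; the payload is a plain integer OR, for which no floating-point counterpart exists (hence the error above). A tiny standalone example of the scalar operation (illustrative C++):

```cpp
#include <cstdint>
#include <iostream>

int main() {
  int32_t a = 0b1010, b = 0b0110;
  std::cout << (a | b) << "\n"; // 14 == 0b1110; there is no `|` for float
}
```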
@@ -1006,12 +1022,12 @@ public:
             AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp,
             AtenPowTensorTensorOp, AtenLog2Op, AtenLog1pOp, AtenRsqrtOp,
             AtenDivScalarOp, AtenRemainderScalarOp, AtenAbsOp,
-            AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
-            AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
-            AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp,
-            AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
-            AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
-            AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
+            AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp,
+            AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
+            AtenLeScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp,
+            AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp,
+            AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
+            AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
             AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp,
             AtenBitwiseNotOp>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
@@ -1483,10 +1499,10 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
       AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp,
       AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog1pOp,
       AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
-      AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
-      AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp,
-      AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp,
-      AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
+      AtenBitwiseOrTensorOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp,
+      AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
+      AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
+      AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
       AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp,
       AtenRemainderScalarOp, AtenBitwiseNotOp>();
   patterns.add<ConvertElementwiseOp>(typeConverter, context);
diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp
index 57a50a688..8cdab3960 100644
--- a/lib/Conversion/TorchToLinalg/Utils.cpp
+++ b/lib/Conversion/TorchToLinalg/Utils.cpp
@@ -11,7 +11,7 @@
 #include "../PassDetail.h"
 #include "PopulatePatterns.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
diff --git a/lib/Conversion/TorchToMhlo/Basic.cpp b/lib/Conversion/TorchToMhlo/Basic.cpp
index a61202c0e..3e7cbf3b0 100644
--- a/lib/Conversion/TorchToMhlo/Basic.cpp
+++ b/lib/Conversion/TorchToMhlo/Basic.cpp
@@ -14,7 +14,7 @@
 #include "./PopulatePatterns.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "mlir-hlo/utils/hlo_utils.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "stablehlo/dialect/ChloOps.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
diff --git a/lib/Conversion/TorchToMhlo/Gather.cpp b/lib/Conversion/TorchToMhlo/Gather.cpp
index 1b1863347..12da1c8e3 100644
--- a/lib/Conversion/TorchToMhlo/Gather.cpp
+++ b/lib/Conversion/TorchToMhlo/Gather.cpp
@@ -13,7 +13,7 @@
 #include "./MhloLegalizeUtils.h"
 #include "./PopulatePatterns.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
"torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" diff --git a/lib/Conversion/TorchToMhlo/Linear.cpp b/lib/Conversion/TorchToMhlo/Linear.cpp index 2f5c7df4d..a97933ce8 100644 --- a/lib/Conversion/TorchToMhlo/Linear.cpp +++ b/lib/Conversion/TorchToMhlo/Linear.cpp @@ -13,7 +13,7 @@ #include "./MhloLegalizeUtils.h" #include "./PopulatePatterns.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "stablehlo/dialect/ChloOps.h" #include "torch-mlir/Conversion/Utils/Utils.h" diff --git a/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp b/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp index 5cab3e7d1..9a93981c5 100644 --- a/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp +++ b/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp @@ -9,7 +9,7 @@ #include "./MhloLegalizeUtils.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" diff --git a/lib/Conversion/TorchToMhlo/Pooling.cpp b/lib/Conversion/TorchToMhlo/Pooling.cpp index 514f941a4..60529e886 100644 --- a/lib/Conversion/TorchToMhlo/Pooling.cpp +++ b/lib/Conversion/TorchToMhlo/Pooling.cpp @@ -13,7 +13,7 @@ #include "./MhloLegalizeUtils.h" #include "./PopulatePatterns.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "stablehlo/dialect/ChloOps.h" #include "torch-mlir/Conversion/Utils/Utils.h" diff --git a/lib/Conversion/TorchToMhlo/Reduction.cpp b/lib/Conversion/TorchToMhlo/Reduction.cpp index a185a27d4..b4f815436 100644 --- a/lib/Conversion/TorchToMhlo/Reduction.cpp +++ b/lib/Conversion/TorchToMhlo/Reduction.cpp @@ -13,7 +13,7 @@ #include "./MhloLegalizeUtils.h" #include "./PopulatePatterns.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" diff --git a/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp b/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp index 67ff28c39..21d991bd7 100644 --- a/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp +++ b/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp @@ -12,7 +12,7 @@ #include "../PassDetail.h" #include "./PopulatePatterns.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Traits.h" #include "mlir/IR/Matchers.h" @@ -42,14 +42,14 @@ public: registry.insert(); registry.insert(); registry.insert(); - registry.insert(); + registry.insert(); TorchConversion::getBackendTypeConversionDependentDialects(registry); } void runOnOperation() override { MLIRContext *context = &getContext(); ConversionTarget target(*context); target.addLegalDialect(); + tensor::TensorDialect, arith::ArithDialect>(); TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); diff --git a/lib/Conversion/TorchToMhlo/ViewLike.cpp b/lib/Conversion/TorchToMhlo/ViewLike.cpp index 
index e89f3f389..baf34d7cb 100644
--- a/lib/Conversion/TorchToMhlo/ViewLike.cpp
+++ b/lib/Conversion/TorchToMhlo/ViewLike.cpp
@@ -13,7 +13,7 @@
 #include "./MhloLegalizeUtils.h"
 #include "./PopulatePatterns.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
diff --git a/lib/Conversion/TorchToSCF/TorchToSCF.cpp b/lib/Conversion/TorchToSCF/TorchToSCF.cpp
index 0f833eab5..b0c2e821c 100644
--- a/lib/Conversion/TorchToSCF/TorchToSCF.cpp
+++ b/lib/Conversion/TorchToSCF/TorchToSCF.cpp
@@ -10,7 +10,7 @@
 #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"

 #include "../PassDetail.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -321,7 +321,7 @@ namespace {
 class ConvertTorchToSCF : public ConvertTorchToSCFBase<ConvertTorchToSCF> {
 public:
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<scf::SCFDialect, arith::ArithmeticDialect>();
+    registry.insert<scf::SCFDialect, arith::ArithDialect>();
     TorchConversion::getBackendTypeConversionDependentDialects(registry);
   }

@@ -329,7 +329,7 @@ public:
     MLIRContext *context = &getContext();
     ConversionTarget target(*context);
     target.addLegalDialect<Torch::TorchDialect, scf::SCFDialect,
-                           arith::ArithmeticDialect>();
+                           arith::ArithDialect>();

     TypeConverter typeConverter;
     typeConverter.addConversion([](Type type) { return type; });
diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
index d92a29554..03dcd1fd2 100644
--- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
+++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
@@ -10,7 +10,7 @@
 #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"

 #include "../PassDetail.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -614,7 +614,7 @@ public:
     registry.insert<linalg::LinalgDialect>();
     registry.insert<func::FuncDialect>();
     registry.insert<tensor::TensorDialect>();
-    registry.insert<arith::ArithmeticDialect>();
+    registry.insert<arith::ArithDialect>();
     registry.insert<math::MathDialect>();
     TorchConversion::getBackendTypeConversionDependentDialects(registry);
   }
@@ -623,7 +623,7 @@ public:
     MLIRContext *context = &getContext();
     ConversionTarget target(*context);
     target.addLegalDialect<linalg::LinalgDialect, func::FuncDialect,
-                           tensor::TensorDialect, arith::ArithmeticDialect,
-                           Torch::TorchDialect, TMTensorDialect>();
+                           tensor::TensorDialect, arith::ArithDialect,
+                           Torch::TorchDialect, TMTensorDialect>();

     TypeConverter typeConverter;
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 1e7407da3..a4fc236e1 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -12,7 +12,7 @@
 #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"

 #include "../PassDetail.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Traits.h"
@@ -3055,6 +3055,133 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
   return success();
 }

+template <>
+LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
+    PrimNumToTensorScalarOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  TypeConverter *typeConverter = this->getTypeConverter();
+  RankedTensorType resultType =
+      typeConverter->convertType(op->getResult(0).getType())
+          .cast<RankedTensorType>();
+
+  // Only supports integer operand type, because a floating point operand
+  // would require an `f64` result tensor type, which is not supported in
+  // TOSA.
+  int64_t initValue;
+  if (!matchPattern(op.a(), m_TorchConstantInt(&initValue)))
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: input should be a torch constant int");
+
+  DenseElementsAttr constAttr = DenseElementsAttr::get(resultType, {initValue});
+  rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, resultType, constAttr);
+  return success();
+}
+
+template <>
+LogicalResult ConvertAtenOp<ValsemVariantAtenCopyOp>::matchAndRewrite(
+    ValsemVariantAtenCopyOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  // Not a tensor type.
+  auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
+  auto srcType = adaptor.src().getType().dyn_cast<TensorType>();
+  if (!selfType || !selfType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        op, "Only tensor types with static shape are supported");
+
+  if (!srcType || !srcType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        op, "Only tensor types with static shape are supported");
+
+  // The non_blocking should be a constant `False`.
+  bool nonBlocking;
+  if (!matchPattern(op.non_blocking(), m_TorchConstantBool(&nonBlocking))) {
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: non_blocking must be a constant");
+  } else if (nonBlocking) {
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: non_blocking is expected to be false");
+  }
+
+  SmallVector<int64_t> selfShape(selfType.getShape());
+  SmallVector<int64_t> srcShape(srcType.getShape());
+
+  if (llvm::equal(selfShape, srcShape) || selfShape.size() == 0) {
+    // If we reach here, then it means the given case is handled by implicit
+    // broadcasting done by tosa.
+    Value result;
+    if (failed(tosa::tosaCastTensorToType(
+            rewriter, op, adaptor.src(),
+            getTypeConverter()->convertType(op.getType()), result)))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: cast to result type not supported");
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+  return rewriter.notifyMatchFailure(
+      op, "unimplemented: valsem.aten.copy op not supported for this case.");
+}
+
+// Legalizes the torch.aten.to.dtype op
+template <>
+LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
+    AtenToDtypeOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  // Not a tensor type.
+  auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
+  if (!selfType || !selfType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        op, "Only tensor types with static shape are supported");
+
+  // The non_blocking arg should be a constant `False`.
+  bool nonBlocking;
+  if (!matchPattern(op.non_blocking(), m_TorchConstantBool(&nonBlocking))) {
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: non_blocking arg must be a constant");
+  } else if (nonBlocking) {
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: non_blocking arg is expected to be false");
+  }
+
+  // The copy arg should be a constant `False`.
+  bool copy;
+  if (!matchPattern(op.copy(), m_TorchConstantBool(&copy))) {
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: copy arg must be a constant");
+  } else if (copy) {
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: copy arg is expected to be false");
+  }
+
+  // Only `none`, `contiguous` and `preserve` memory_format is supported.
+
+// Legalizes the torch.aten.to.dtype op
+template <>
+LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
+    AtenToDtypeOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  // Not a tensor type.
+  auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
+  if (!selfType || !selfType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        op, "Only tensor types with static shape are supported");
+
+  // The non_blocking arg should be a constant `False`.
+  bool nonBlocking;
+  if (!matchPattern(op.non_blocking(), m_TorchConstantBool(&nonBlocking))) {
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: non_blocking arg must be a constant");
+  } else if (nonBlocking) {
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: non_blocking arg is expected to be false");
+  }
+
+  // The copy arg should be a constant `False`.
+  bool copy;
+  if (!matchPattern(op.copy(), m_TorchConstantBool(&copy))) {
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: copy arg must be a constant");
+  } else if (copy) {
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: copy arg is expected to be false");
+  }
+
+  // Only `none`, `contiguous` and `preserve` memory_format is supported.
+  if (!op.memory_format().getType().isa<Torch::NoneType>()) {
+    int64_t memoryFormat;
+    if (!matchPattern(op.memory_format(), m_TorchConstantInt(&memoryFormat)))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: the memory format should be specified in "
+              "an integer constant");
+    if (memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
+        memoryFormat != torch_upstream::MemoryFormat::Preserve)
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: only none, contiguous and preserve "
+              "memory_format is supported");
+  }
+
+  auto resultTy = getTypeConverter()
+                      ->convertType(op.getResult().getType())
+                      .cast<RankedTensorType>();
+
+  Value result;
+  if (failed(tosa::tosaCastTensorToType(rewriter, op, adaptor.self(), resultTy,
+                                        result)))
+    return rewriter.notifyMatchFailure(op, "conversion to result type failed");
+
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
 template <typename AtenOpT, typename TosaOpT>
 class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
 public:
@@ -3511,7 +3638,7 @@ public:
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<tosa::TosaDialect>();
     registry.insert<tensor::TensorDialect>();
-    registry.insert<arith::ArithmeticDialect>();
+    registry.insert<arith::ArithDialect>();
     TorchConversion::getBackendTypeConversionDependentDialects(registry);
   }
@@ -3519,7 +3646,7 @@ public:
     MLIRContext *context = &getContext();
     ConversionTarget target(*context);
     target.addLegalDialect<tosa::TosaDialect, tensor::TensorDialect,
-                           arith::ArithmeticDialect>();
+                           arith::ArithDialect>();
     TypeConverter typeConverter;
     typeConverter.addConversion([](Type type) { return type; });
@@ -3704,6 +3831,9 @@ public:
     INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
     INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
     INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);
+    INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
+    INSERT_ATENOP_PATTERN(ValsemVariantAtenCopyOp);
+    INSERT_ATENOP_PATTERN(AtenToDtypeOp);
 #undef INSERT_ATENOP_PATTERN
 
 #define INSERT_CLONE_ATENOP_PATTERN(AtenOp)                                    \
diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
index e0c213837..685a6dd86 100644
--- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
+++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -221,6 +221,64 @@ llvm::Optional<Value> getConstTensor(PatternRewriter &rewriter,
   return const_op.getResult();
 }
 
+static LogicalResult checkValidityOfCast(Type src, Type dest) {
+  if ((src.isInteger(64) && dest.isInteger(32)) ||
+      (src.isInteger(32) && dest.isInteger(64)) ||
+      (src.isInteger(64) && dest.isInteger(1)) ||
+      (src.isInteger(32) && dest.isInteger(1)) ||
+      (src.isInteger(8) && dest.isInteger(1)) ||
+      (src.isF32() && dest.isInteger(1))) {
+    return success();
+  }
+  return failure();
+}
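`checkValidityOfCast` above is a deliberate whitelist: only the listed (source, destination) element-type pairs are lowered, and every i1 destination is handled below by materializing a zero tensor and computing logical_not(equal(src, 0)), i.e. src != 0. A standalone mirror of the whitelist; the enum and function names are illustrative:

#include <utility>

enum class ElemTy { I1, I8, I32, I64, F32 };

// Same pairs as checkValidityOfCast; anything else fails the match.
static bool castIsSupported(ElemTy src, ElemTy dest) {
  const std::pair<ElemTy, ElemTy> allowed[] = {
      {ElemTy::I64, ElemTy::I32}, {ElemTy::I32, ElemTy::I64},
      {ElemTy::I64, ElemTy::I1},  {ElemTy::I32, ElemTy::I1},
      {ElemTy::I8, ElemTy::I1},   {ElemTy::F32, ElemTy::I1}};
  for (const auto &pair : allowed)
    if (pair.first == src && pair.second == dest)
      return true;
  return false;
}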
+
+// Template specialization for float
+LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
+                                   Value src, Type destType, Value &result) {
+
+  Type srcElemTy = src.getType().dyn_cast<TensorType>().getElementType();
+  Type destElemTy = destType.dyn_cast<TensorType>().getElementType();
+
+  if (failed(checkValidityOfCast(srcElemTy, destElemTy)))
+    return rewriter.notifyMatchFailure(
+        op, "casting to result dtype is invalid or unsupported");
+
+  if (destElemTy.isInteger(1)) {
+    auto srcType = src.getType().dyn_cast<TensorType>();
+    SmallVector<int64_t> srcShape(srcType.getShape());
+    uint64_t num_total_elements = 1;
+    for (int64_t a : srcShape)
+      num_total_elements *= a;
+
+    llvm::Optional<Value> constOp;
+    if (srcElemTy.isInteger(64)) {
+      SmallVector<int64_t> values(num_total_elements, 0);
+      constOp =
+          tosa::getConstTensor<int64_t>(rewriter, op, values, srcShape).value();
+    } else if (srcElemTy.isInteger(32)) {
+      SmallVector<int32_t> values(num_total_elements, 0);
+      constOp =
+          tosa::getConstTensor<int32_t>(rewriter, op, values, srcShape).value();
+    } else if (srcElemTy.isF32()) {
+      SmallVector<float> values(num_total_elements, 0.0);
+      constOp =
+          tosa::getConstTensor<float>(rewriter, op, values, srcShape).value();
+    } else if (srcElemTy.isInteger(8)) {
+      SmallVector<int8_t> values(num_total_elements, 0);
+      constOp =
+          tosa::getConstTensor<int8_t>(rewriter, op, values, srcShape).value();
+    }
+    Value equalToZero = rewriter.create<tosa::EqualOp>(op->getLoc(), destType,
+                                                       src, constOp.value());
+    result = rewriter.create<tosa::LogicalNotOp>(op->getLoc(), destType,
+                                                 equalToZero);
+  } else {
+    result = rewriter.create<tosa::CastOp>(op->getLoc(), destType, src);
+  }
+  return success();
+}
+
 // Template instantiation
 template llvm::Optional<Value> getConstTensor<int32_t>(PatternRewriter &,
                                                        Operation *,
diff --git a/lib/Conversion/Utils/CMakeLists.txt b/lib/Conversion/Utils/CMakeLists.txt
index 6b352bdc5..3f0f67b49 100644
--- a/lib/Conversion/Utils/CMakeLists.txt
+++ b/lib/Conversion/Utils/CMakeLists.txt
@@ -5,7 +5,7 @@ add_mlir_conversion_library(TorchMLIRConversionUtils
   ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/Utils
 
   LINK_LIBS PUBLIC
-    MLIRArithmeticDialect
+    MLIRArithDialect
     MLIRLinalgDialect
     TorchMLIRTorchDialect
 )
diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp
index a0bf7bb67..d800a6907 100644
--- a/lib/Conversion/Utils/Utils.cpp
+++ b/lib/Conversion/Utils/Utils.cpp
@@ -9,7 +9,7 @@
 
 #include "torch-mlir/Conversion/Utils/Utils.h"
 
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index b0289af82..3f9105d76 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -2175,6 +2175,20 @@ OpFoldResult AtenDivFloatOp::fold(ArrayRef<Attribute> operands) {
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// AtenDivIntOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenDivIntOp::fold(ArrayRef<Attribute> operands) {
+  int64_t lhs, rhs;
+  bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs));
+  bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs));
+  if (lConstant && rConstant)
+    return getF64FloatAttr(getContext(), double(lhs) / rhs);
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
 // AtenCeilFloatOp
 //===----------------------------------------------------------------------===//
 
@@ -2185,8 +2199,6 @@ OpFoldResult AtenCeilFloatOp::fold(ArrayRef<Attribute> operands) {
   return nullptr;
 }
 
-//===----------------------------------------------------------------------===//
-
 //===----------------------------------------------------------------------===//
 // PrimMaxIntOp
 //===----------------------------------------------------------------------===//
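The new `AtenDivIntOp` folder implements Python true division: two constant torch ints fold to a float constant, so 7 / 2 folds to 3.5 rather than 3. A standalone check of that arithmetic; the function name is illustrative:

#include <cassert>
#include <cstdint>

// Mirrors the fold body: double(lhs) / rhs, never integer division.
static double foldDivInt(int64_t lhs, int64_t rhs) {
  return static_cast<double>(lhs) / rhs;
}

int main() {
  assert(foldDivInt(7, 2) == 3.5);
  assert(foldDivInt(-1, 4) == -0.25);
  return 0;
}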
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 73c47ec26..4574b2399 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -631,6 +631,21 @@
 };
 } // namespace
 
+// Decompose aten.mv into: aten.matmul.
+namespace {
+class DecomposeAtenMvOp : public OpRewritePattern<AtenMvOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenMvOp op,
+                                PatternRewriter &rewriter) const override {
+    Value lhs = op.self();
+    Value rhs = op.vec();
+    rewriter.replaceOpWithNewOp<AtenMatmulOp>(op, op.getType(), lhs, rhs);
+    return success();
+  }
+};
+} // namespace
+
 // ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6)
 static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
                              Value input) {
@@ -2185,8 +2200,9 @@ public:
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(Aten_ToCopyOp op,
                                 PatternRewriter &rewriter) const override {
-    Value zero = rewriter.create<ConstantFloatOp>(
-        op.getLoc(), rewriter.getF64FloatAttr(0.0));
+    Type resultDtype = op.getType().cast<BaseTensorType>().getDtype();
+    Value zero = getConstantWithGivenDtypeAndValue(rewriter, op.getLoc(), 0.0,
+                                                   resultDtype);
     Value emptyTensor = rewriter.create<AtenFullLikeOp>(
         op.getLoc(), op.getType(), op.self(), zero, op.dtype(), op.layout(),
         op.device(), op.pin_memory(), op.memory_format());
@@ -2859,6 +2875,8 @@ public:
     patterns.add(context);
     target.addIllegalOp();
     patterns.add(context);
+    target.addIllegalOp<AtenMvOp>();
+    patterns.add<DecomposeAtenMvOp>(context);
     target.addIllegalOp();
     patterns.add(context);
     patterns.add(context);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index c9fd3a484..9307030e2 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -242,6 +242,16 @@ public:
     });
   }
 };
+
+class VerifyBackendContractPass
+    : public VerifyBackendContractBase<VerifyBackendContractPass> {
+public:
+  void runOnOperation() override {
+    if (!satisfiesBackendContract(getOperation(), /*actuallyEmitDiagnostics=*/true)) {
+      return signalPassFailure();
+    }
+  }
+};
 } // namespace
 
 std::unique_ptr<OperationPass<ModuleOp>>
@@ -250,3 +260,8 @@ mlir::torch::Torch::createLowerToBackendContractPass(
   return std::make_unique<LowerToBackendContractPass>(maxIterations, decompose,
                                                       backendLegalOps);
 }
+
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::torch::Torch::createVerifyBackendContractPass() {
+  return std::make_unique<VerifyBackendContractPass>();
+}
diff --git a/lib/Dialect/Torch/Transforms/Passes.cpp b/lib/Dialect/Torch/Transforms/Passes.cpp
index 3681eef27..0e7a8eac9 100644
--- a/lib/Dialect/Torch/Transforms/Passes.cpp
+++ b/lib/Dialect/Torch/Transforms/Passes.cpp
@@ -11,17 +11,8 @@
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/Passes.h"
 
-//===----------------------------------------------------------------------===//
-// Pass registration
-//===----------------------------------------------------------------------===//
-
-namespace {
-#define GEN_PASS_REGISTRATION
-#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
-} // end namespace
-
 void mlir::torch::registerTorchPasses() {
-  ::registerPasses();
+  mlir::torch::registerPasses();
   mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
       "torchscript-module-to-torch-backend-pipeline",
       "Pipeline lowering TorchScript object graph IR to Torch backend form.",
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index b286546ec..00c7f257f 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -701,7 +701,7 @@ void TypeAnalysis::visitOperation(Operation *op,
           AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
           PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp, AtenMaskedFillTensorOp,
           AtenRollOp, AtenPowTensorTensorOp, AtenLiftFreshCopyOp,
-          AtenIndexTensorHackedTwinOp>(op)) {
+          AtenIndexTensorHackedTwinOp,
AtenUpsampleNearest2dVecOp>(op)) { return incorporateKnowledge(op->getResult(0), operands[0]->getValue()); } @@ -754,7 +754,7 @@ void TypeAnalysis::visitOperation(Operation *op, // Promote the two dtypes assuming non-zero rank. if (isa(op)) { auto knowledge = ValueKnowledge::getTensorPessimisticValueState(op->getContext()); @@ -767,8 +767,8 @@ void TypeAnalysis::visitOperation(Operation *op, // Promote the two dtypes assuming possibly-zero rank. if (isa(op)) { + AtenMaximumOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp, + AtenThresholdBackwardOp, AtenFloorDivideOp>(op)) { auto knowledge = ValueKnowledge::getTensorPessimisticValueState(op->getContext()); knowledge.dtype = getPromotedResultType( diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp index 95b8c703d..48d62564c 100644 --- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp @@ -23,877 +23,6 @@ StringRef mlir::torch::Torch::getShapeLibrary() { #endif // clang-format off return "module {\n" -" func.func @__torch__.torch._decomp.decompositions.nll_loss_backward(%arg0: !torch.tensor, %arg1: !torch.tensor, %arg2: !torch.tensor, %arg3: !torch.optional, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.tensor) -> !torch.tensor {\n" -" %float-1.000000e00 = torch.constant.float -1.000000e+00\n" -" %str = torch.constant.str \"Expected a single element grad_output tensor, but got: {}\"\n" -" %str_0 = torch.constant.str \"Expected a tensor of dimension 1 and tensor.size[0] == {} but got: dimension {} and tensor.size[0] == {}\"\n" -" %str_1 = torch.constant.str \"AssertionError: weight tensor should be defined either for all or no classes\"\n" -" %int-1 = torch.constant.int -1\n" -" %str_2 = torch.constant.str \"{} ({} elements)\"\n" -" %str_3 = torch.constant.str \"expected total_weight to be a single element tensor, got: \"\n" -" %str_4 = torch.constant.str \"AssertionError: \"\n" -" %str_5 = torch.constant.str \"size mismatch (got input: {}, target: {})\"\n" -" %true = torch.constant.bool true\n" -" %str_6 = torch.constant.str \"AssertionError: 0D or 1D target tensor expected, multi-target not supported\"\n" -" %none = torch.constant.none\n" -" %str_7 = torch.constant.str \"AssertionError: input tensor should be 1D or 2D\"\n" -" %false = torch.constant.bool false\n" -" %int0 = torch.constant.int 0\n" -" %int2 = torch.constant.int 2\n" -" %int1 = torch.constant.int 1\n" -" %0 = torch.prim.Uninitialized : !torch.optional\n" -" %1 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" -" %2 = torch.aten.le.int %int0, %1 : !torch.int, !torch.int -> !torch.bool\n" -" %3 = torch.prim.If %2 -> (!torch.bool) {\n" -" %35 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" -" %36 = torch.aten.le.int %35, %int2 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %36 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" torch.prim.If %3 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_7, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %4 = torch.aten.dim %arg2 : !torch.tensor -> !torch.int\n" -" %5 = torch.aten.le.int %4, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %5 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_6, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %6 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" -" %7 = torch.aten.eq.int 
%6, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" %8 = torch.prim.If %7 -> (!torch.bool) {\n" -" %35 = torch.aten.dim %arg2 : !torch.tensor -> !torch.int\n" -" %36 = torch.aten.eq.int %35, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %36 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" %9 = torch.prim.If %8 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %35 = torch.aten.size.int %arg1, %int0 : !torch.tensor, !torch.int -> !torch.int\n" -" %36 = torch.aten.size.int %arg2, %int0 : !torch.tensor, !torch.int -> !torch.int\n" -" %37 = torch.aten.eq.int %35, %36 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %37 : !torch.bool\n" -" }\n" -" torch.prim.If %9 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" %35 = torch.aten.size %arg1 : !torch.tensor -> !torch.list\n" -" %36 = torch.aten.size %arg2 : !torch.tensor -> !torch.list\n" -" %37 = torch.aten.format(%str_5, %35, %36) : !torch.str, !torch.list, !torch.list -> !torch.str\n" -" %38 = torch.aten.add.str %str_4, %37 : !torch.str, !torch.str -> !torch.str\n" -" torch.prim.RaiseException %38, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %10 = torch.aten.numel %arg6 : !torch.tensor -> !torch.int\n" -" %11 = torch.aten.eq.int %10, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %11 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" %35 = torch.aten.size %arg6 : !torch.tensor -> !torch.list\n" -" %36 = torch.aten.numel %arg6 : !torch.tensor -> !torch.int\n" -" %37 = torch.aten.format(%str_2, %35, %36) : !torch.str, !torch.list, !torch.int -> !torch.str\n" -" %38 = torch.prim.TupleConstruct %str_3, %37 : !torch.str, !torch.str -> !torch.tuple\n" -" %39 = torch.aten.str %38 : !torch.tuple -> !torch.str\n" -" %40 = torch.aten.add.str %str_4, %39 : !torch.str, !torch.str -> !torch.str\n" -" torch.prim.RaiseException %40, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %12 = torch.aten.__is__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" -" %13 = torch.prim.If %12 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %35 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.tensor\n" -" %36 = torch.aten.numel %35 : !torch.tensor -> !torch.int\n" -" %37 = torch.aten.size.int %arg1, %int-1 : !torch.tensor, !torch.int -> !torch.int\n" -" %38 = torch.aten.eq.int %36, %37 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %38 : !torch.bool\n" -" }\n" -" %14 = torch.prim.If %13 -> (!torch.optional) {\n" -" torch.prim.If.yield %arg3 : !torch.optional\n" -" } else {\n" -" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield %0 : !torch.optional\n" -" }\n" -" %15 = torch.aten.eq.int %arg4, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" %16 = torch.prim.If %15 -> (!torch.bool) {\n" -" %35 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" -" %36 = torch.aten.eq.int %35, %int2 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %36 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" torch.prim.If %16 -> () {\n" -" %35 = torch.aten.dim %arg0 : !torch.tensor -> !torch.int\n" -" %36 = torch.aten.eq.int %35, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" %37 = torch.prim.If %36 -> (!torch.bool) {\n" -" %38 = torch.aten.size.int %arg0, %int0 : !torch.tensor, !torch.int -> !torch.int\n" -" 
%39 = torch.aten.size.int %arg1, %int0 : !torch.tensor, !torch.int -> !torch.int\n" -" %40 = torch.aten.eq.int %38, %39 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %40 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" torch.prim.If %37 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" %38 = torch.aten.size.int %arg1, %int0 : !torch.tensor, !torch.int -> !torch.int\n" -" %39 = torch.aten.dim %arg0 : !torch.tensor -> !torch.int\n" -" %40 = torch.aten.size.int %arg0, %int0 : !torch.tensor, !torch.int -> !torch.int\n" -" %41 = torch.aten.format(%str_0, %38, %39, %40) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str\n" -" %42 = torch.aten.add.str %str_4, %41 : !torch.str, !torch.str -> !torch.str\n" -" torch.prim.RaiseException %42, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" torch.prim.If.yield\n" -" } else {\n" -" %35 = torch.aten.dim %arg0 : !torch.tensor -> !torch.int\n" -" %36 = torch.aten.le.int %35, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" %37 = torch.prim.If %36 -> (!torch.bool) {\n" -" %38 = torch.aten.numel %arg0 : !torch.tensor -> !torch.int\n" -" %39 = torch.aten.eq.int %38, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %39 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" torch.prim.If %37 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" %38 = torch.aten.size %arg0 : !torch.tensor -> !torch.list\n" -" %39 = torch.aten.format(%str, %38) : !torch.str, !torch.list -> !torch.str\n" -" %40 = torch.aten.add.str %str_4, %39 : !torch.str, !torch.str -> !torch.str\n" -" torch.prim.RaiseException %40, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" torch.prim.If.yield\n" -" }\n" -" %17 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" -" %18 = torch.aten.lt.int %17, %int2 : !torch.int, !torch.int -> !torch.bool\n" -" %19 = torch.prim.If %18 -> (!torch.int) {\n" -" torch.prim.If.yield %int0 : !torch.int\n" -" } else {\n" -" torch.prim.If.yield %int1 : !torch.int\n" -" }\n" -" %20 = torch.aten.eq.int %arg4, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" %21 = torch.prim.If %20 -> (!torch.tensor) {\n" -" %35 = torch.aten.div.Tensor %arg0, %arg6 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" torch.prim.If.yield %35 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %arg0 : !torch.tensor\n" -" }\n" -" %22 = torch.aten.unsqueeze %arg2, %19 : !torch.tensor, !torch.int -> !torch.tensor\n" -" %23 = torch.aten.zeros_like %arg1, %none, %none, %none, %none, %none : !torch.tensor, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" -" %24 = torch.operator \"aten.scatter.value\"(%23, %19, %22, %float-1.000000e00) : (!torch.tensor, !torch.int, !torch.tensor, !torch.float) -> !torch.tensor\n" -" %25 = torch.aten.dim %24 : !torch.tensor -> !torch.int\n" -" %26 = torch.aten.dim %21 : !torch.tensor -> !torch.int\n" -" %27 = torch.aten.gt.int %25, %26 : !torch.int, !torch.int -> !torch.bool\n" -" %28 = torch.prim.If %27 -> (!torch.bool) {\n" -" %35 = torch.aten.dim %21 : !torch.tensor -> !torch.int\n" -" %36 = torch.aten.gt.int %35, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %36 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" %29 = torch.prim.If %28 -> (!torch.tensor) {\n" -" %35 = torch.aten.unsqueeze %21, %19 : !torch.tensor, !torch.int -> !torch.tensor\n" -" 
torch.prim.If.yield %35 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %21 : !torch.tensor\n" -" }\n" -" %30 = torch.aten.__isnot__ %14, %none : !torch.optional, !torch.none -> !torch.bool\n" -" %31 = torch.prim.If %30 -> (!torch.tensor) {\n" -" %35 = torch.prim.unchecked_cast %14 : !torch.optional -> !torch.tensor\n" -" %36 = torch.prim.ListConstruct : () -> !torch.list\n" -" %37 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" -" torch.prim.Loop %37, %true, init() {\n" -" ^bb0(%arg7: !torch.int):\n" -" %42 = torch.aten.append.t %36, %int1 : !torch.list, !torch.int -> !torch.list\n" -" torch.prim.Loop.condition %true, iter()\n" -" } : (!torch.int, !torch.bool) -> ()\n" -" %38 = torch.aten.size.int %35, %int0 : !torch.tensor, !torch.int -> !torch.int\n" -" %39 = torch.aten._set_item.t %36, %19, %38 : !torch.list, !torch.int, !torch.int -> !torch.list\n" -" %40 = torch.aten.reshape %35, %36 : !torch.tensor, !torch.list -> !torch.tensor\n" -" %41 = torch.aten.mul.Tensor %29, %40 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" torch.prim.If.yield %41 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %29 : !torch.tensor\n" -" }\n" -" %32 = torch.aten.ge.int %arg5, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" %33 = torch.prim.If %32 -> (!torch.tensor) {\n" -" %35 = torch.aten.ne.Scalar %22, %arg5 : !torch.tensor, !torch.int -> !torch.tensor\n" -" %36 = torch.aten.where.ScalarOther %35, %31, %int0 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" torch.prim.If.yield %36 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %31 : !torch.tensor\n" -" }\n" -" %34 = torch.aten.mul.Tensor %24, %33 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" return %34 : !torch.tensor\n" -" }\n" -" func.func @__torch__.torch._decomp.decompositions._nll_loss_backward(%arg0: !torch.tensor, %arg1: !torch.tensor, %arg2: !torch.tensor, %arg3: !torch.optional, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.tensor) -> !torch.tensor {\n" -" %true = torch.constant.bool true\n" -" %false = torch.constant.bool false\n" -" %float-1.000000e00 = torch.constant.float -1.000000e+00\n" -" %none = torch.constant.none\n" -" %int2 = torch.constant.int 2\n" -" %int0 = torch.constant.int 0\n" -" %int1 = torch.constant.int 1\n" -" %0 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" -" %1 = torch.aten.lt.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" -" %2 = torch.prim.If %1 -> (!torch.int) {\n" -" torch.prim.If.yield %int0 : !torch.int\n" -" } else {\n" -" torch.prim.If.yield %int1 : !torch.int\n" -" }\n" -" %3 = torch.aten.eq.int %arg4, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" %4 = torch.prim.If %3 -> (!torch.tensor) {\n" -" %18 = torch.aten.div.Tensor %arg0, %arg6 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" torch.prim.If.yield %18 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %arg0 : !torch.tensor\n" -" }\n" -" %5 = torch.aten.unsqueeze %arg2, %2 : !torch.tensor, !torch.int -> !torch.tensor\n" -" %6 = torch.aten.zeros_like %arg1, %none, %none, %none, %none, %none : !torch.tensor, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" -" %7 = torch.operator \"aten.scatter.value\"(%6, %2, %5, %float-1.000000e00) : (!torch.tensor, !torch.int, !torch.tensor, !torch.float) -> !torch.tensor\n" -" %8 = torch.aten.dim %7 : !torch.tensor -> !torch.int\n" -" %9 = torch.aten.dim %4 : !torch.tensor -> !torch.int\n" -" %10 = torch.aten.gt.int %8, %9 : !torch.int, !torch.int -> !torch.bool\n" -" %11 = 
torch.prim.If %10 -> (!torch.bool) {\n" -" %18 = torch.aten.dim %4 : !torch.tensor -> !torch.int\n" -" %19 = torch.aten.gt.int %18, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %19 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" %12 = torch.prim.If %11 -> (!torch.tensor) {\n" -" %18 = torch.aten.unsqueeze %4, %2 : !torch.tensor, !torch.int -> !torch.tensor\n" -" torch.prim.If.yield %18 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %4 : !torch.tensor\n" -" }\n" -" %13 = torch.aten.__isnot__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" -" %14 = torch.prim.If %13 -> (!torch.tensor) {\n" -" %18 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.tensor\n" -" %19 = torch.prim.ListConstruct : () -> !torch.list\n" -" %20 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" -" torch.prim.Loop %20, %true, init() {\n" -" ^bb0(%arg7: !torch.int):\n" -" %25 = torch.aten.append.t %19, %int1 : !torch.list, !torch.int -> !torch.list\n" -" torch.prim.Loop.condition %true, iter()\n" -" } : (!torch.int, !torch.bool) -> ()\n" -" %21 = torch.aten.size.int %18, %int0 : !torch.tensor, !torch.int -> !torch.int\n" -" %22 = torch.aten._set_item.t %19, %2, %21 : !torch.list, !torch.int, !torch.int -> !torch.list\n" -" %23 = torch.aten.reshape %18, %19 : !torch.tensor, !torch.list -> !torch.tensor\n" -" %24 = torch.aten.mul.Tensor %12, %23 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" torch.prim.If.yield %24 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %12 : !torch.tensor\n" -" }\n" -" %15 = torch.aten.ge.int %arg5, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" %16 = torch.prim.If %15 -> (!torch.tensor) {\n" -" %18 = torch.aten.ne.Scalar %5, %arg5 : !torch.tensor, !torch.int -> !torch.tensor\n" -" %19 = torch.aten.where.ScalarOther %18, %14, %int0 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" torch.prim.If.yield %19 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %14 : !torch.tensor\n" -" }\n" -" %17 = torch.aten.mul.Tensor %7, %16 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" return %17 : !torch.tensor\n" -" }\n" -" func.func @__torch__.torch._decomp.decompositions.nll_loss2d_backward(%arg0: !torch.tensor, %arg1: !torch.tensor, %arg2: !torch.tensor, %arg3: !torch.optional, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.tensor) -> !torch.tensor {\n" -" %true = torch.constant.bool true\n" -" %float-1.000000e00 = torch.constant.float -1.000000e+00\n" -" %str = torch.constant.str \"expected total_weight to be a single element tensor, got: {} ( {}, elements)\"\n" -" %str_0 = torch.constant.str \"size mismatch (got input: {}, target: {}\"\n" -" %false = torch.constant.bool false\n" -" %str_1 = torch.constant.str \"only batches of spatial targets supported (3D tensors) but got targets of dimension: {}\"\n" -" %none = torch.constant.none\n" -" %str_2 = torch.constant.str \"AssertionError: \"\n" -" %str_3 = torch.constant.str \"only batches of spatial inputs supported (4D tensors), but got input of dimension: {}\"\n" -" %int4 = torch.constant.int 4\n" -" %int3 = torch.constant.int 3\n" -" %int0 = torch.constant.int 0\n" -" %int2 = torch.constant.int 2\n" -" %int1 = torch.constant.int 1\n" -" %0 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" -" %1 = torch.aten.eq.int %0, %int4 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" %29 = torch.aten.dim %arg1 : !torch.tensor -> 
!torch.int\n" -" %30 = torch.aten.format(%str_3, %29) : !torch.str, !torch.int -> !torch.str\n" -" %31 = torch.aten.add.str %str_2, %30 : !torch.str, !torch.str -> !torch.str\n" -" torch.prim.RaiseException %31, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %2 = torch.aten.dim %arg2 : !torch.tensor -> !torch.int\n" -" %3 = torch.aten.eq.int %2, %int3 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" %29 = torch.aten.dim %arg2 : !torch.tensor -> !torch.int\n" -" %30 = torch.aten.format(%str_1, %29) : !torch.str, !torch.int -> !torch.str\n" -" %31 = torch.aten.add.str %str_2, %30 : !torch.str, !torch.str -> !torch.str\n" -" torch.prim.RaiseException %31, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %4 = torch.aten.size.int %arg1, %int0 : !torch.tensor, !torch.int -> !torch.int\n" -" %5 = torch.aten.size.int %arg2, %int0 : !torch.tensor, !torch.int -> !torch.int\n" -" %6 = torch.aten.eq.int %4, %5 : !torch.int, !torch.int -> !torch.bool\n" -" %7 = torch.prim.If %6 -> (!torch.bool) {\n" -" %29 = torch.aten.size.int %arg1, %int2 : !torch.tensor, !torch.int -> !torch.int\n" -" %30 = torch.aten.size.int %arg2, %int1 : !torch.tensor, !torch.int -> !torch.int\n" -" %31 = torch.aten.eq.int %29, %30 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %31 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" %8 = torch.prim.If %7 -> (!torch.bool) {\n" -" %29 = torch.aten.size.int %arg1, %int3 : !torch.tensor, !torch.int -> !torch.int\n" -" %30 = torch.aten.size.int %arg2, %int2 : !torch.tensor, !torch.int -> !torch.int\n" -" %31 = torch.aten.eq.int %29, %30 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %31 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" torch.prim.If %8 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" %29 = torch.aten.size %arg1 : !torch.tensor -> !torch.list\n" -" %30 = torch.aten.size %arg2 : !torch.tensor -> !torch.list\n" -" %31 = torch.aten.format(%str_0, %29, %30) : !torch.str, !torch.list, !torch.list -> !torch.str\n" -" %32 = torch.aten.add.str %str_2, %31 : !torch.str, !torch.str -> !torch.str\n" -" torch.prim.RaiseException %32, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %9 = torch.aten.numel %arg6 : !torch.tensor -> !torch.int\n" -" %10 = torch.aten.eq.int %9, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %10 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" %29 = torch.aten.size %arg6 : !torch.tensor -> !torch.list\n" -" %30 = torch.aten.numel %arg6 : !torch.tensor -> !torch.int\n" -" %31 = torch.aten.format(%str, %29, %30) : !torch.str, !torch.list, !torch.int -> !torch.str\n" -" %32 = torch.aten.add.str %str_2, %31 : !torch.str, !torch.str -> !torch.str\n" -" torch.prim.RaiseException %32, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %11 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" -" %12 = torch.aten.lt.int %11, %int2 : !torch.int, !torch.int -> !torch.bool\n" -" %13 = torch.prim.If %12 -> (!torch.int) {\n" -" torch.prim.If.yield %int0 : !torch.int\n" -" } else {\n" -" torch.prim.If.yield %int1 : !torch.int\n" -" }\n" -" %14 = torch.aten.eq.int %arg4, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" %15 = torch.prim.If %14 -> (!torch.tensor) {\n" -" %29 = torch.aten.div.Tensor %arg0, %arg6 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" 
torch.prim.If.yield %29 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %arg0 : !torch.tensor\n" -" }\n" -" %16 = torch.aten.unsqueeze %arg2, %13 : !torch.tensor, !torch.int -> !torch.tensor\n" -" %17 = torch.aten.zeros_like %arg1, %none, %none, %none, %none, %none : !torch.tensor, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" -" %18 = torch.operator \"aten.scatter.value\"(%17, %13, %16, %float-1.000000e00) : (!torch.tensor, !torch.int, !torch.tensor, !torch.float) -> !torch.tensor\n" -" %19 = torch.aten.dim %18 : !torch.tensor -> !torch.int\n" -" %20 = torch.aten.dim %15 : !torch.tensor -> !torch.int\n" -" %21 = torch.aten.gt.int %19, %20 : !torch.int, !torch.int -> !torch.bool\n" -" %22 = torch.prim.If %21 -> (!torch.bool) {\n" -" %29 = torch.aten.dim %15 : !torch.tensor -> !torch.int\n" -" %30 = torch.aten.gt.int %29, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %30 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" %23 = torch.prim.If %22 -> (!torch.tensor) {\n" -" %29 = torch.aten.unsqueeze %15, %13 : !torch.tensor, !torch.int -> !torch.tensor\n" -" torch.prim.If.yield %29 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %15 : !torch.tensor\n" -" }\n" -" %24 = torch.aten.__isnot__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" -" %25 = torch.prim.If %24 -> (!torch.tensor) {\n" -" %29 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.tensor\n" -" %30 = torch.prim.ListConstruct : () -> !torch.list\n" -" %31 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" -" torch.prim.Loop %31, %true, init() {\n" -" ^bb0(%arg7: !torch.int):\n" -" %36 = torch.aten.append.t %30, %int1 : !torch.list, !torch.int -> !torch.list\n" -" torch.prim.Loop.condition %true, iter()\n" -" } : (!torch.int, !torch.bool) -> ()\n" -" %32 = torch.aten.size.int %29, %int0 : !torch.tensor, !torch.int -> !torch.int\n" -" %33 = torch.aten._set_item.t %30, %13, %32 : !torch.list, !torch.int, !torch.int -> !torch.list\n" -" %34 = torch.aten.reshape %29, %30 : !torch.tensor, !torch.list -> !torch.tensor\n" -" %35 = torch.aten.mul.Tensor %23, %34 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" torch.prim.If.yield %35 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %23 : !torch.tensor\n" -" }\n" -" %26 = torch.aten.ge.int %arg5, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" %27 = torch.prim.If %26 -> (!torch.tensor) {\n" -" %29 = torch.aten.ne.Scalar %16, %arg5 : !torch.tensor, !torch.int -> !torch.tensor\n" -" %30 = torch.aten.where.ScalarOther %29, %25, %int0 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" torch.prim.If.yield %30 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %25 : !torch.tensor\n" -" }\n" -" %28 = torch.aten.mul.Tensor %18, %27 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" return %28 : !torch.tensor\n" -" }\n" -" func.func @__torch__.torch._decomp.decompositions._log_softmax_backward_data(%arg0: !torch.tensor, %arg1: !torch.tensor, %arg2: !torch.int, %arg3: !torch.int) -> !torch.tensor {\n" -" %false = torch.constant.bool false\n" -" %int1 = torch.constant.int 1\n" -" %none = torch.constant.none\n" -" %true = torch.constant.bool true\n" -" %0 = torch.aten.exp %arg1 : !torch.tensor -> !torch.tensor\n" -" %1 = torch.prim.ListConstruct %arg2 : (!torch.int) -> !torch.list\n" -" %2 = torch.aten.sum.dim_IntList %arg0, %1, %true, %none : !torch.tensor, !torch.list, !torch.bool, !torch.none -> !torch.tensor\n" 
-" %3 = torch.aten.mul.Tensor %0, %2 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" %4 = torch.aten.sub.Tensor %arg0, %3, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" %5 = torch.prim.dtype %arg0 : !torch.tensor -> !torch.int\n" -" %6 = torch.aten.ne.int %5, %arg3 : !torch.int, !torch.int -> !torch.bool\n" -" %7 = torch.prim.If %6 -> (!torch.tensor) {\n" -" %8 = torch.aten.to.dtype %4, %arg3, %false, %false, %none : !torch.tensor, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor\n" -" torch.prim.If.yield %8 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %4 : !torch.tensor\n" -" }\n" -" return %7 : !torch.tensor\n" -" }\n" -" func.func @__torch__.torch._decomp.decompositions._cast_grad_to_input_dtype(%arg0: !torch.tensor, %arg1: !torch.tensor, %arg2: !torch.int) -> !torch.tensor {\n" -" %none = torch.constant.none\n" -" %false = torch.constant.bool false\n" -" %0 = torch.prim.dtype %arg0 : !torch.tensor -> !torch.int\n" -" %1 = torch.aten.ne.int %0, %arg2 : !torch.int, !torch.int -> !torch.bool\n" -" %2 = torch.prim.If %1 -> (!torch.tensor) {\n" -" %3 = torch.aten.to.dtype %arg1, %arg2, %false, %false, %none : !torch.tensor, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor\n" -" torch.prim.If.yield %3 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %arg1 : !torch.tensor\n" -" }\n" -" return %2 : !torch.tensor\n" -" }\n" -" func.func @__torch__.torch._decomp.decompositions._softmax_backward_data(%arg0: !torch.tensor, %arg1: !torch.tensor, %arg2: !torch.int, %arg3: !torch.int) -> !torch.tensor {\n" -" %false = torch.constant.bool false\n" -" %int1 = torch.constant.int 1\n" -" %none = torch.constant.none\n" -" %true = torch.constant.bool true\n" -" %0 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" %1 = torch.prim.ListConstruct %arg2 : (!torch.int) -> !torch.list\n" -" %2 = torch.aten.sum.dim_IntList %0, %1, %true, %none : !torch.tensor, !torch.list, !torch.bool, !torch.none -> !torch.tensor\n" -" %3 = torch.aten.mul.Tensor %arg1, %2 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" %4 = torch.aten.sub.Tensor %0, %3, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" %5 = torch.prim.dtype %arg0 : !torch.tensor -> !torch.int\n" -" %6 = torch.aten.ne.int %5, %arg3 : !torch.int, !torch.int -> !torch.bool\n" -" %7 = torch.prim.If %6 -> (!torch.tensor) {\n" -" %8 = torch.aten.to.dtype %4, %arg3, %false, %false, %none : !torch.tensor, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor\n" -" torch.prim.If.yield %8 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %4 : !torch.tensor\n" -" }\n" -" return %7 : !torch.tensor\n" -" }\n" -" func.func @__torch__.torch._decomp.decompositions.log_sigmoid_forward(%arg0: !torch.tensor) -> !torch.tuple {\n" -" %int1 = torch.constant.int 1\n" -" %none = torch.constant.none\n" -" %int0 = torch.constant.int 0\n" -" %0 = torch.prim.ListConstruct : () -> !torch.list\n" -" %1 = torch.aten.new_zeros %arg0, %0, %none, %none, %none, %none : !torch.tensor, !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" -" %2 = torch.aten.minimum %1, %arg0 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" %3 = torch.aten.abs %arg0 : !torch.tensor -> !torch.tensor\n" -" %4 = torch.aten.neg %3 : !torch.tensor -> !torch.tensor\n" -" %5 = torch.aten.exp %4 : !torch.tensor -> !torch.tensor\n" -" %6 = torch.operator \"prim.is_cuda\"(%arg0) : (!torch.tensor) -> !torch.bool\n" -" 
%7 = torch.prim.If %6 -> (!torch.tensor) {\n" -" %11 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list\n" -" %12 = torch.aten.new_zeros %arg0, %11, %none, %none, %none, %none : !torch.tensor, !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" -" torch.prim.If.yield %12 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %5 : !torch.tensor\n" -" }\n" -" %8 = torch.aten.log1p %5 : !torch.tensor -> !torch.tensor\n" -" %9 = torch.aten.sub.Tensor %2, %8, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" %10 = torch.prim.TupleConstruct %9, %7 : !torch.tensor, !torch.tensor -> !torch.tuple\n" -" return %10 : !torch.tuple\n" -" }\n" -" func.func @__torch__.torch._decomp.decompositions_for_jvp.native_layer_norm_backward(%arg0: !torch.tensor, %arg1: !torch.tensor, %arg2: !torch.list, %arg3: !torch.tensor, %arg4: !torch.tensor, %arg5: !torch.optional, %arg6: !torch.optional, %arg7: !torch.list) -> !torch.tuple, optional, optional> {\n" -" %false = torch.constant.bool false\n" -" %true = torch.constant.bool true\n" -" %none = torch.constant.none\n" -" %int0 = torch.constant.int 0\n" -" %int1 = torch.constant.int 1\n" -" %int2 = torch.constant.int 2\n" -" %0 = torch.aten.size %arg1 : !torch.tensor -> !torch.list\n" -" %1 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" -" %2 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" -" %3 = torch.aten.sub.int %1, %2 : !torch.int, !torch.int -> !torch.int\n" -" %4 = torch.aten.slice.t %0, %3, %none, %int1 : !torch.list, !torch.int, !torch.none, !torch.int -> !torch.list\n" -" %5 = torch.aten.slice.t %0, %none, %3, %int1 : !torch.list, !torch.none, !torch.int, !torch.int -> !torch.list\n" -" %6 = torch.prim.ListConstruct : () -> !torch.list\n" -" %7 = torch.aten.__range_length %3, %1, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" -" torch.prim.Loop %7, %true, init() {\n" -" ^bb0(%arg8: !torch.int):\n" -" %17 = torch.aten.__derive_index %arg8, %3, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" -" %18 = torch.aten.append.t %6, %17 : !torch.list, !torch.int -> !torch.list\n" -" torch.prim.Loop.condition %true, iter()\n" -" } : (!torch.int, !torch.bool) -> ()\n" -" %8 = torch.prim.ListConstruct : () -> !torch.list\n" -" %9 = torch.aten.__range_length %int0, %3, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" -" torch.prim.Loop %9, %true, init() {\n" -" ^bb0(%arg8: !torch.int):\n" -" %17 = torch.aten.__derive_index %arg8, %int0, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" -" %18 = torch.aten.append.t %8, %17 : !torch.list, !torch.int -> !torch.list\n" -" torch.prim.Loop.condition %true, iter()\n" -" } : (!torch.int, !torch.bool) -> ()\n" -" %10 = torch.aten.len.t %4 : !torch.list -> !torch.int\n" -" %11 = torch.prim.Loop %10, %true, init(%int1) {\n" -" ^bb0(%arg8: !torch.int, %arg9: !torch.int):\n" -" %17 = torch.aten.__getitem__.t %4, %arg8 : !torch.list, !torch.int -> !torch.int\n" -" %18 = torch.aten.mul.int %arg9, %17 : !torch.int, !torch.int -> !torch.int\n" -" torch.prim.Loop.condition %true, iter(%18 : !torch.int)\n" -" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n" -" %12 = torch.aten.len.t %5 : !torch.list -> !torch.int\n" -" %13 = torch.prim.Loop %12, %true, init(%int1) {\n" -" ^bb0(%arg8: !torch.int, %arg9: !torch.int):\n" -" %17 = torch.aten.__getitem__.t %5, %arg8 : !torch.list, !torch.int -> !torch.int\n" -" %18 = torch.aten.mul.int %arg9, %17 : !torch.int, !torch.int -> !torch.int\n" -" 
torch.prim.Loop.condition %true, iter(%18 : !torch.int)\n" -" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n" -" %14 = torch.aten.le.int %13, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" %15 = torch.prim.If %14 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %17 = torch.aten.le.int %11, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %17 : !torch.bool\n" -" }\n" -" %16 = torch.prim.If %15 -> (!torch.tuple, optional, optional>) {\n" -" %17 = torch.aten.new_zeros %arg1, %0, %none, %none, %none, %none : !torch.tensor, !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" -" %18 = torch.aten.slice.t %0, %3, %none, %int1 : !torch.list, !torch.int, !torch.none, !torch.int -> !torch.list\n" -" %19 = torch.aten.new_zeros %arg1, %18, %none, %none, %none, %none : !torch.tensor, !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" -" %20 = torch.aten.slice.t %0, %3, %none, %int1 : !torch.list, !torch.int, !torch.none, !torch.int -> !torch.list\n" -" %21 = torch.aten.new_zeros %arg1, %20, %none, %none, %none, %none : !torch.tensor, !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" -" %22 = torch.prim.TupleConstruct %17, %19, %21 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tuple\n" -" %23 = torch.derefine %22 : !torch.tuple to !torch.tuple, optional, optional>\n" -" torch.prim.If.yield %23 : !torch.tuple, optional, optional>\n" -" } else {\n" -" %17 = torch.aten.mean.dim %arg1, %6, %true, %none : !torch.tensor, !torch.list, !torch.bool, !torch.none -> !torch.tensor\n" -" %18 = torch.aten.var.dim %arg1, %6, %false, %true : !torch.tensor, !torch.list, !torch.bool, !torch.bool -> !torch.tensor\n" -" %19 = torch.aten.reciprocal %arg4 : !torch.tensor -> !torch.tensor\n" -" %20 = torch.aten.pow.Tensor_Scalar %19, %int2 : !torch.tensor, !torch.int -> !torch.tensor\n" -" %21 = torch.aten.sub.Tensor %20, %18, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" %22 = torch.aten.detach %21 : !torch.tensor -> !torch.tensor\n" -" %23 = torch.aten.add.Tensor %18, %22, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" %24 = torch.aten.sqrt %23 : !torch.tensor -> !torch.tensor\n" -" %25 = torch.aten.reciprocal %24 : !torch.tensor -> !torch.tensor\n" -" %26 = torch.aten.sub.Tensor %arg1, %17, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" %27 = torch.aten.mul.Tensor %26, %25 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" %28 = torch.aten.__isnot__ %arg5, %none : !torch.optional, !torch.none -> !torch.bool\n" -" %29 = torch.prim.If %28 -> (!torch.tensor) {\n" -" %46 = torch.prim.unchecked_cast %arg5 : !torch.optional -> !torch.tensor\n" -" %47 = torch.aten.mul.Tensor %arg0, %46 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" torch.prim.If.yield %47 : !torch.tensor\n" -" } else {\n" -" torch.prim.If.yield %arg0 : !torch.tensor\n" -" }\n" -" %30 = torch.aten.mul.Scalar %29, %11 : !torch.tensor, !torch.int -> !torch.tensor\n" -" %31 = torch.aten.sum.dim_IntList %29, %6, %true, %none : !torch.tensor, !torch.list, !torch.bool, !torch.none -> !torch.tensor\n" -" %32 = torch.aten.mul.Tensor %29, %27 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" %33 = torch.aten.sum.dim_IntList %32, %6, %true, %none : !torch.tensor, !torch.list, !torch.bool, !torch.none -> !torch.tensor\n" -" %34 = torch.aten.mul.Tensor %27, %33 : !torch.tensor, !torch.tensor -> 
!torch.tensor\n" -" %35 = torch.aten.sub.Tensor %30, %31, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" %36 = torch.aten.sub.Tensor %35, %34, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" %37 = torch.aten.__getitem__.t %arg7, %int0 : !torch.list, !torch.int -> !torch.bool\n" -" %38 = torch.prim.If %37 -> (!torch.tensor) {\n" -" %46 = torch.aten.div.Scalar %25, %11 : !torch.tensor, !torch.int -> !torch.tensor\n" -" %47 = torch.aten.mul.Tensor %46, %36 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" torch.prim.If.yield %47 : !torch.tensor\n" -" } else {\n" -" %46 = torch.aten.zeros_like %arg1, %none, %none, %none, %none, %none : !torch.tensor, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" -" torch.prim.If.yield %46 : !torch.tensor\n" -" }\n" -" %39 = torch.aten.__getitem__.t %arg7, %int1 : !torch.list, !torch.int -> !torch.bool\n" -" %40 = torch.prim.If %39 -> (!torch.bool) {\n" -" %46 = torch.aten.__isnot__ %arg5, %none : !torch.optional, !torch.none -> !torch.bool\n" -" torch.prim.If.yield %46 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" %41 = torch.prim.If %40 -> (!torch.tensor) {\n" -" %46 = torch.aten.len.t %8 : !torch.list -> !torch.int\n" -" %47 = torch.aten.gt.int %46, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" %48 = torch.prim.If %47 -> (!torch.tensor) {\n" -" %49 = torch.aten.mul.Tensor %arg0, %27 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" %50 = torch.aten.sum.dim_IntList %49, %8, %false, %none : !torch.tensor, !torch.list, !torch.bool, !torch.none -> !torch.tensor\n" -" torch.prim.If.yield %50 : !torch.tensor\n" -" } else {\n" -" %49 = torch.aten.mul.Tensor %arg0, %27 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" torch.prim.If.yield %49 : !torch.tensor\n" -" }\n" -" torch.prim.If.yield %48 : !torch.tensor\n" -" } else {\n" -" %46 = torch.aten.__isnot__ %arg5, %none : !torch.optional, !torch.none -> !torch.bool\n" -" %47 = torch.prim.If %46 -> (!torch.tensor) {\n" -" %48 = torch.prim.unchecked_cast %arg5 : !torch.optional -> !torch.tensor\n" -" %49 = torch.aten.zeros_like %48, %none, %none, %none, %none, %none : !torch.tensor, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" -" torch.prim.If.yield %49 : !torch.tensor\n" -" } else {\n" -" %48 = torch.prim.ListConstruct : () -> !torch.list\n" -" %49 = torch.aten.zeros %48, %none, %none, %none, %none : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" -" torch.prim.If.yield %49 : !torch.tensor\n" -" }\n" -" torch.prim.If.yield %47 : !torch.tensor\n" -" }\n" -" %42 = torch.aten.__getitem__.t %arg7, %int2 : !torch.list, !torch.int -> !torch.bool\n" -" %43 = torch.prim.If %42 -> (!torch.bool) {\n" -" %46 = torch.aten.__isnot__ %arg6, %none : !torch.optional, !torch.none -> !torch.bool\n" -" torch.prim.If.yield %46 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" %44 = torch.prim.If %43 -> (!torch.tensor) {\n" -" %46 = torch.aten.len.t %8 : !torch.list -> !torch.int\n" -" %47 = torch.aten.gt.int %46, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" %48 = torch.prim.If %47 -> (!torch.tensor) {\n" -" %49 = torch.aten.sum.dim_IntList %arg0, %8, %false, %none : !torch.tensor, !torch.list, !torch.bool, !torch.none -> !torch.tensor\n" -" torch.prim.If.yield %49 : !torch.tensor\n" -" } else {\n" -" %49 = torch.aten.clone %arg0, %none : !torch.tensor, !torch.none -> 
!torch.tensor\n" -" torch.prim.If.yield %49 : !torch.tensor\n" -" }\n" -" torch.prim.If.yield %48 : !torch.tensor\n" -" } else {\n" -" %46 = torch.aten.__isnot__ %arg6, %none : !torch.optional, !torch.none -> !torch.bool\n" -" %47 = torch.prim.If %46 -> (!torch.tensor) {\n" -" %48 = torch.prim.unchecked_cast %arg6 : !torch.optional -> !torch.tensor\n" -" %49 = torch.aten.zeros_like %48, %none, %none, %none, %none, %none : !torch.tensor, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" -" torch.prim.If.yield %49 : !torch.tensor\n" -" } else {\n" -" %48 = torch.prim.ListConstruct : () -> !torch.list\n" -" %49 = torch.aten.zeros %48, %none, %none, %none, %none : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" -" torch.prim.If.yield %49 : !torch.tensor\n" -" }\n" -" torch.prim.If.yield %47 : !torch.tensor\n" -" }\n" -" %45 = torch.prim.TupleConstruct %38, %41, %44 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tuple, optional, optional>\n" -" torch.prim.If.yield %45 : !torch.tuple, optional, optional>\n" -" }\n" -" return %16 : !torch.tuple, optional, optional>\n" -" }\n" -" func.func @__torch__.torch._decomp.decompositions_for_jvp.recompute_mean_var(%arg0: !torch.tensor, %arg1: !torch.tensor, %arg2: !torch.list, %arg3: !torch.bool) -> !torch.tuple {\n" -" %false = torch.constant.bool false\n" -" %none = torch.constant.none\n" -" %int1 = torch.constant.int 1\n" -" %int2 = torch.constant.int 2\n" -" %0 = torch.aten.mean.dim %arg0, %arg2, %arg3, %none : !torch.tensor, !torch.list, !torch.bool, !torch.none -> !torch.tensor\n" -" %1 = torch.aten.var.dim %arg0, %arg2, %false, %arg3 : !torch.tensor, !torch.list, !torch.bool, !torch.bool -> !torch.tensor\n" -" %2 = torch.aten.reciprocal %arg1 : !torch.tensor -> !torch.tensor\n" -" %3 = torch.aten.mul.Scalar %2, %int1 : !torch.tensor, !torch.int -> !torch.tensor\n" -" %4 = torch.aten.pow.Tensor_Scalar %3, %int2 : !torch.tensor, !torch.int -> !torch.tensor\n" -" %5 = torch.aten.sub.Tensor %4, %1, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" %6 = torch.aten.detach %5 : !torch.tensor -> !torch.tensor\n" -" %7 = torch.aten.add.Tensor %1, %6, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" %8 = torch.aten.sqrt %7 : !torch.tensor -> !torch.tensor\n" -" %9 = torch.aten.reciprocal %8 : !torch.tensor -> !torch.tensor\n" -" %10 = torch.aten.mul.Scalar %9, %int1 : !torch.tensor, !torch.int -> !torch.tensor\n" -" %11 = torch.prim.TupleConstruct %0, %10 : !torch.tensor, !torch.tensor -> !torch.tuple\n" -" return %11 : !torch.tuple\n" -" }\n" -" func.func @__torch__.torch._decomp.decompositions_for_jvp.native_batch_norm_backward(%arg0: !torch.tensor, %arg1: !torch.tensor, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional, %arg7: !torch.bool, %arg8: !torch.float, %arg9: !torch.list) -> !torch.tuple, optional> {\n" -" %str = torch.constant.str \"AssertionError: \"\n" -" %true = torch.constant.bool true\n" -" %str_0 = torch.constant.str \"AssertionError: when train=True, save_mean and save_invstd are required\"\n" -" %false = torch.constant.bool false\n" -" %none = torch.constant.none\n" -" %str_1 = torch.constant.str \"AssertionError: rank of the input must be at least 2\"\n" -" %int2 = torch.constant.int 2\n" -" %int1 = torch.constant.int 1\n" -" %int0 = torch.constant.int 0\n" -" %float1.000000e00 = torch.constant.float 1.000000e+00\n" -" %0 = torch.prim.Uninitialized 
: !torch.tensor\n" -" %1 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" -" %2 = torch.aten.ge.int %1, %int2 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %2 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %3 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" -" %4 = torch.prim.Loop %3, %true, init(%int1) {\n" -" ^bb0(%arg10: !torch.int, %arg11: !torch.int):\n" -" %34 = torch.aten.size.int %arg1, %arg10 : !torch.tensor, !torch.int -> !torch.int\n" -" %35 = torch.aten.mul.int %arg11, %34 : !torch.int, !torch.int -> !torch.int\n" -" torch.prim.Loop.condition %true, iter(%35 : !torch.int)\n" -" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n" -" %5 = torch.aten.size.int %arg1, %int1 : !torch.tensor, !torch.int -> !torch.int\n" -" %6 = torch.operator \"aten.div.int\"(%4, %5) : (!torch.int, !torch.int) -> !torch.float\n" -" %7:2 = torch.prim.If %arg7 -> (!torch.tensor, !torch.tensor) {\n" -" %34 = torch.aten.__isnot__ %arg5, %none : !torch.optional, !torch.none -> !torch.bool\n" -" %35 = torch.prim.If %34 -> (!torch.bool) {\n" -" %52 = torch.aten.__isnot__ %arg6, %none : !torch.optional, !torch.none -> !torch.bool\n" -" torch.prim.If.yield %52 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" torch.prim.If %35 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %36 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list\n" -" %37 = torch.aten.dim %arg1 : !torch.tensor -> !torch.int\n" -" %38 = torch.prim.ListConstruct : () -> !torch.list\n" -" %39 = torch.aten.__range_length %int2, %37, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" -" torch.prim.Loop %39, %true, init() {\n" -" ^bb0(%arg10: !torch.int):\n" -" %52 = torch.aten.__derive_index %arg10, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" -" %53 = torch.aten.append.t %38, %52 : !torch.list, !torch.int -> !torch.list\n" -" torch.prim.Loop.condition %true, iter()\n" -" } : (!torch.int, !torch.bool) -> ()\n" -" %40 = torch.aten.add.t %36, %38 : !torch.list, !torch.list -> !torch.list\n" -" %41 = torch.aten.__isnot__ %arg6, %none : !torch.optional, !torch.none -> !torch.bool\n" -" %42 = torch.prim.If %41 -> (!torch.tensor) {\n" -" %52 = torch.prim.unchecked_cast %arg6 : !torch.optional -> !torch.tensor\n" -" torch.prim.If.yield %52 : !torch.tensor\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield %0 : !torch.tensor\n" -" }\n" -" %43 = torch.aten.mean.dim %arg1, %40, %false, %none : !torch.tensor, !torch.list, !torch.bool, !torch.none -> !torch.tensor\n" -" %44 = torch.aten.var.dim %arg1, %40, %false, %false : !torch.tensor, !torch.list, !torch.bool, !torch.bool -> !torch.tensor\n" -" %45 = torch.aten.reciprocal %42 : !torch.tensor -> !torch.tensor\n" -" %46 = torch.aten.pow.Tensor_Scalar %45, %int2 : !torch.tensor, !torch.int -> !torch.tensor\n" -" %47 = torch.aten.sub.Tensor %46, %44, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" %48 = torch.aten.detach %47 : !torch.tensor -> !torch.tensor\n" -" %49 = torch.aten.add.Tensor %44, %48, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" %50 = torch.aten.sqrt %49 : !torch.tensor -> !torch.tensor\n" -" %51 = torch.aten.reciprocal %50 : !torch.tensor -> 
!torch.tensor\n" -" torch.prim.If.yield %43, %51 : !torch.tensor, !torch.tensor\n" -" } else {\n" -" %34 = torch.aten.__isnot__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" -" %35 = torch.prim.If %34 -> (!torch.bool) {\n" -" %39 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.tensor\n" -" %40 = torch.aten.__isnot__ %arg4, %none : !torch.optional, !torch.none -> !torch.bool\n" -" torch.prim.If.yield %40 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" %36:2 = torch.prim.If %35 -> (!torch.tensor, !torch.tensor) {\n" -" %39 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.tensor\n" -" %40 = torch.prim.unchecked_cast %arg4 : !torch.optional -> !torch.tensor\n" -" torch.prim.If.yield %40, %39 : !torch.tensor, !torch.tensor\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield %0, %0 : !torch.tensor, !torch.tensor\n" -" }\n" -" %37 = torch.aten.add.Scalar %36#0, %arg8, %int1 : !torch.tensor, !torch.float, !torch.int -> !torch.tensor\n" -" %38 = torch.aten.rsqrt %37 : !torch.tensor -> !torch.tensor\n" -" torch.prim.If.yield %36#1, %38 : !torch.tensor, !torch.tensor\n" -" }\n" -" %8 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list\n" -" %9 = torch.operator \"aten.mul.left_t\"(%8, %1) : (!torch.list, !torch.int) -> !torch.list\n" -" %10 = torch.aten.size.int %arg1, %int1 : !torch.tensor, !torch.int -> !torch.int\n" -" %11 = torch.aten._set_item.t %9, %int1, %10 : !torch.list, !torch.int, !torch.int -> !torch.list\n" -" %12 = torch.prim.ListConstruct : () -> !torch.list\n" -" torch.prim.Loop %1, %true, init() {\n" -" ^bb0(%arg10: !torch.int):\n" -" %34 = torch.aten.ne.int %arg10, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %34 -> () {\n" -" %35 = torch.aten.append.t %12, %arg10 : !torch.list, !torch.int -> !torch.list\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.If.yield\n" -" }\n" -" torch.prim.Loop.condition %true, iter()\n" -" } : (!torch.int, !torch.bool) -> ()\n" -" %13 = torch.aten.reshape %7#0, %9 : !torch.tensor, !torch.list -> !torch.tensor\n" -" %14 = torch.aten.div.float %float1.000000e00, %6 : !torch.float, !torch.float -> !torch.float\n" -" %15 = torch.aten.sum.dim_IntList %arg0, %12, %false, %none : !torch.tensor, !torch.list, !torch.bool, !torch.none -> !torch.tensor\n" -" %16 = torch.aten.sub.Tensor %arg1, %13, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" %17 = torch.aten.mul.Tensor %arg0, %16 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" %18 = torch.aten.sum.dim_IntList %17, %12, %false, %none : !torch.tensor, !torch.list, !torch.bool, !torch.none -> !torch.tensor\n" -" %19 = torch.aten.mul.Scalar %15, %14 : !torch.tensor, !torch.float -> !torch.tensor\n" -" %20 = torch.aten.reshape %19, %9 : !torch.tensor, !torch.list -> !torch.tensor\n" -" %21 = torch.aten.mul.Scalar %18, %14 : !torch.tensor, !torch.float -> !torch.tensor\n" -" %22 = torch.aten.mul.Tensor %7#1, %7#1 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" %23 = torch.aten.mul.Tensor %21, %22 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" %24 = torch.aten.reshape %23, %9 : !torch.tensor, !torch.list -> !torch.tensor\n" -" %25 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" -" %26 = torch.prim.If %25 -> (!torch.tensor) {\n" -" %34 = torch.aten.reshape %7#1, %9 : !torch.tensor, !torch.list -> !torch.tensor\n" -" %35 = torch.aten.mul.Scalar %34, 
%float1.000000e00 : !torch.tensor, !torch.float -> !torch.tensor\n" -" torch.prim.If.yield %35 : !torch.tensor\n" -" } else {\n" -" %34 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.tensor\n" -" %35 = torch.aten.mul.Tensor %7#1, %34 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" %36 = torch.aten.reshape %35, %9 : !torch.tensor, !torch.list -> !torch.tensor\n" -" torch.prim.If.yield %36 : !torch.tensor\n" -" }\n" -" %27 = torch.prim.If %arg7 -> (!torch.tensor) {\n" -" %34 = torch.aten.sub.Tensor %arg1, %13, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" %35 = torch.aten.mul.Tensor %34, %24 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" %36 = torch.aten.sub.Tensor %arg0, %35, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" %37 = torch.aten.sub.Tensor %36, %20, %int1 : !torch.tensor, !torch.tensor, !torch.int -> !torch.tensor\n" -" %38 = torch.aten.mul.Tensor %37, %26 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" torch.prim.If.yield %38 : !torch.tensor\n" -" } else {\n" -" %34 = torch.aten.mul.Tensor %arg0, %26 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" torch.prim.If.yield %34 : !torch.tensor\n" -" }\n" -" %28 = torch.aten.__getitem__.t %arg9, %int1 : !torch.list, !torch.int -> !torch.bool\n" -" %29 = torch.prim.If %28 -> (!torch.tensor) {\n" -" %34 = torch.aten.mul.Tensor %18, %7#1 : !torch.tensor, !torch.tensor -> !torch.tensor\n" -" torch.prim.If.yield %34 : !torch.tensor\n" -" } else {\n" -" %34 = torch.aten.__isnot__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" -" %35 = torch.prim.If %34 -> (!torch.tensor) {\n" -" %36 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.tensor\n" -" %37 = torch.aten.zeros_like %36, %none, %none, %none, %none, %none : !torch.tensor, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" -" torch.prim.If.yield %37 : !torch.tensor\n" -" } else {\n" -" %36 = torch.prim.ListConstruct : () -> !torch.list\n" -" %37 = torch.aten.zeros %36, %none, %none, %none, %none : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" -" torch.prim.If.yield %37 : !torch.tensor\n" -" }\n" -" torch.prim.If.yield %35 : !torch.tensor\n" -" }\n" -" %30 = torch.aten.__getitem__.t %arg9, %int2 : !torch.list, !torch.int -> !torch.bool\n" -" %31 = torch.prim.If %30 -> (!torch.tensor) {\n" -" torch.prim.If.yield %15 : !torch.tensor\n" -" } else {\n" -" %34 = torch.aten.zeros_like %15, %none, %none, %none, %none, %none : !torch.tensor, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor\n" -" torch.prim.If.yield %34 : !torch.tensor\n" -" }\n" -" %32 = torch.prim.TupleConstruct %27, %29, %31 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tuple\n" -" %33 = torch.derefine %32 : !torch.tuple to !torch.tuple, optional>\n" -" return %33 : !torch.tuple, optional>\n" -" }\n" -" func.func @__torch__.torch._decomp.decompositions_for_jvp.prod(%arg0: !torch.list) -> !torch.int {\n" -" %true = torch.constant.bool true\n" -" %int1 = torch.constant.int 1\n" -" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" -" %1 = torch.prim.Loop %0, %true, init(%int1) {\n" -" ^bb0(%arg1: !torch.int, %arg2: !torch.int):\n" -" %2 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list, !torch.int -> !torch.int\n" -" %3 = torch.aten.mul.int %arg2, %2 : !torch.int, !torch.int -> !torch.int\n" -" torch.prim.Loop.condition %true, iter(%3 : !torch.int)\n" -" } : (!torch.int, !torch.bool, 
!torch.int) -> !torch.int\n" -" return %1 : !torch.int\n" -" }\n" -" func.func @__torch__.torch._decomp.decompositions.cudnn_batch_norm_backward(%arg0: !torch.tensor, %arg1: !torch.tensor, %arg2: !torch.tensor, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional, %arg7: !torch.float, %arg8: !torch.tensor) -> !torch.tuple {\n" -" %true = torch.constant.bool true\n" -" %0 = torch.prim.ListConstruct %true, %true, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list\n" -" %result0, %result1, %result2 = torch.aten.native_batch_norm_backward %arg1, %arg0, %arg2, %arg3, %arg4, %arg5, %arg6, %true, %arg7, %0 : !torch.tensor, !torch.tensor, !torch.tensor, !torch.optional, !torch.optional, !torch.optional, !torch.optional, !torch.bool, !torch.float, !torch.list -> !torch.tensor, !torch.tensor, !torch.tensor\n" -" %1 = torch.prim.TupleConstruct %result0, %result1, %result2 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tuple\n" -" return %1 : !torch.tuple\n" -" }\n" " func.func @__torch__.torch.jit._shape_functions.unary(%arg0: !torch.list) -> !torch.list {\n" " %true = torch.constant.bool true\n" " %0 = torch.prim.ListConstruct : () -> !torch.list\n" @@ -5629,83 +4758,95 @@ StringRef mlir::torch::Torch::getShapeLibrary() { " } : (!torch.int, !torch.bool) -> ()\n" " return %12 : !torch.list\n" " }\n" -" func.func @__torch__.torch.jit._shape_functions.upsample_nearest2d(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional>) -> !torch.optional> {\n" -" %str = torch.constant.str \"AssertionError: Either output_size or scale_factors must be presented\"\n" -" %str_0 = torch.constant.str \"AssertionError: \"\n" -" %str_1 = torch.constant.str \"AssertionError: Must specify exactly one of output_size and scale_factors\"\n" +" func.func @__torch__.torch.jit._shape_functions.upsample_nearest2d(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional>) -> !torch.list {\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %str_0 = torch.constant.str \"AssertionError: Must specify exactly one of output_size and scale_factors\"\n" +" %str_1 = torch.constant.str \"AssertionError: Either output_size or scale_factors must be presented\"\n" +" %false = torch.constant.bool false\n" " %none = torch.constant.none\n" " %int0 = torch.constant.int 0\n" " %int1 = torch.constant.int 1\n" " %int2 = torch.constant.int 2\n" " %int3 = torch.constant.int 3\n" -" %0 = torch.prim.ListConstruct : () -> !torch.list\n" -" %1 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" -" %2 = torch.aten.append.t %0, %1 : !torch.list, !torch.int -> !torch.list\n" -" %3 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" -" %4 = torch.aten.append.t %0, %3 : !torch.list, !torch.int -> !torch.list\n" -" %5 = torch.aten.__isnot__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool\n" -" %6 = torch.prim.If %5 -> (!torch.optional>) {\n" -" %7 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list\n" -" %8 = torch.aten.__is__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool\n" -" torch.prim.If %8 -> () {\n" +" %0 = torch.prim.Uninitialized : !torch.optional>\n" +" %1 = torch.prim.ListConstruct : () -> !torch.list\n" +" %2 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %3 = torch.aten.append.t %1, %2 : !torch.list, !torch.int -> !torch.list\n" +" %4 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" 
+" %5 = torch.aten.append.t %1, %4 : !torch.list, !torch.int -> !torch.list\n" +" %6 = torch.aten.__is__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.bool) {\n" +" %11 = torch.aten.__is__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" torch.prim.If.yield %11 : !torch.bool\n" +" } else {\n" +" %11 = torch.prim.unchecked_cast %arg2 : !torch.optional> -> !torch.list\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %7 -> () {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" %8 = torch.aten.__isnot__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.optional>) {\n" +" %11 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list\n" +" %12 = torch.aten.__is__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %13 = torch.prim.If %12 -> (!torch.optional>) {\n" +" torch.prim.If.yield %arg2 : !torch.optional>\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield %0 : !torch.optional>\n" +" }\n" +" %14 = torch.aten.len.t %11 : !torch.list -> !torch.int\n" +" %15 = torch.aten.eq.int %14, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %15 -> () {\n" " torch.prim.If.yield\n" " } else {\n" -" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %9 = torch.aten.len.t %7 : !torch.list -> !torch.int\n" -" %10 = torch.aten.eq.int %9, %int2 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %10 -> () {\n" +" %16 = torch.aten.__getitem__.t %11, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %17 = torch.aten.append.t %1, %16 : !torch.list, !torch.int -> !torch.list\n" +" %18 = torch.aten.__getitem__.t %11, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %19 = torch.aten.append.t %1, %18 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield %13 : !torch.optional>\n" +" } else {\n" +" torch.prim.If.yield %arg2 : !torch.optional>\n" +" }\n" +" %10 = torch.aten.__isnot__ %9, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" torch.prim.If %10 -> () {\n" +" %11 = torch.prim.unchecked_cast %9 : !torch.optional> -> !torch.list\n" +" %12 = torch.aten.__is__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" torch.prim.If %12 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %11 = torch.aten.__getitem__.t %7, %int0 : !torch.list, !torch.int -> !torch.int\n" -" %12 = torch.aten.append.t %0, %11 : !torch.list, !torch.int -> !torch.list\n" -" %13 = torch.aten.__getitem__.t %7, %int1 : !torch.list, !torch.int -> !torch.int\n" -" %14 = torch.aten.append.t %0, %13 : !torch.list, !torch.int -> !torch.list\n" -" %15 = torch.derefine %0 : !torch.list to !torch.optional>\n" -" torch.prim.If.yield %15 : !torch.optional>\n" -" } else {\n" -" %7 = torch.aten.__isnot__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool\n" -" %8 = torch.prim.If %7 -> (!torch.optional>) {\n" -" %9 = torch.prim.unchecked_cast %arg2 : !torch.optional> -> !torch.list\n" -" %10 = torch.aten.__is__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool\n" -" torch.prim.If %10 -> () {\n" -" torch.prim.If.yield\n" -" } 
else {\n" -" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %11 = torch.aten.len.t %9 : !torch.list -> !torch.int\n" -" %12 = torch.aten.eq.int %11, %int2 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %12 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %13 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int\n" -" %14 = torch.aten.__getitem__.t %9, %int0 : !torch.list, !torch.int -> !torch.float\n" -" %15 = torch.operator \"aten.mul.int_float\"(%13, %14) : (!torch.int, !torch.float) -> !torch.float\n" -" %16 = torch.aten.Int.float %15 : !torch.float -> !torch.int\n" -" %17 = torch.aten.append.t %0, %16 : !torch.list, !torch.int -> !torch.list\n" -" %18 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list, !torch.int -> !torch.int\n" -" %19 = torch.aten.__getitem__.t %9, %int1 : !torch.list, !torch.int -> !torch.float\n" -" %20 = torch.operator \"aten.mul.int_float\"(%18, %19) : (!torch.int, !torch.float) -> !torch.float\n" -" %21 = torch.aten.Int.float %20 : !torch.float -> !torch.int\n" -" %22 = torch.aten.append.t %0, %21 : !torch.list, !torch.int -> !torch.list\n" -" %23 = torch.derefine %0 : !torch.list to !torch.optional>\n" -" torch.prim.If.yield %23 : !torch.optional>\n" +" %13 = torch.aten.len.t %11 : !torch.list -> !torch.int\n" +" %14 = torch.aten.eq.int %13, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %14 -> () {\n" +" torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" %9 = torch.derefine %none : !torch.none to !torch.optional>\n" -" torch.prim.If.yield %9 : !torch.optional>\n" +" torch.prim.If.yield\n" " }\n" -" torch.prim.If.yield %8 : !torch.optional>\n" +" %15 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %16 = torch.aten.__getitem__.t %11, %int0 : !torch.list, !torch.int -> !torch.float\n" +" %17 = torch.operator \"aten.mul.int_float\"(%15, %16) : (!torch.int, !torch.float) -> !torch.float\n" +" %18 = torch.aten.Int.float %17 : !torch.float -> !torch.int\n" +" %19 = torch.aten.append.t %1, %18 : !torch.list, !torch.int -> !torch.list\n" +" %20 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list, !torch.int -> !torch.int\n" +" %21 = torch.aten.__getitem__.t %11, %int1 : !torch.list, !torch.int -> !torch.float\n" +" %22 = torch.operator \"aten.mul.int_float\"(%20, %21) : (!torch.int, !torch.float) -> !torch.float\n" +" %23 = torch.aten.Int.float %22 : !torch.float -> !torch.int\n" +" %24 = torch.aten.append.t %1, %23 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" " }\n" -" return %6 : !torch.optional>\n" +" return %1 : !torch.list\n" " }\n" " func.func @__torch__.torch.jit._shape_functions.argmax(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.list {\n" " %true = torch.constant.bool true\n" @@ -6723,6 +5864,10 @@ StringRef mlir::torch::Torch::getShapeLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.matmul(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.mv\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.mv(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %0 : 
!torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.mm\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.mm(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -7238,6 +6383,10 @@ StringRef mlir::torch::Torch::getShapeLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.bitwise_or.Tensor\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.bitwise_and.Tensor\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -7850,6 +6999,10 @@ StringRef mlir::torch::Torch::getShapeLibrary() { " %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" " return %2 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d.vec\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional>) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.upsample_nearest2d(%arg0, %arg1, %arg2) : (!torch.list, !torch.optional>, !torch.optional>) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" "}\n" ""; // clang-format on diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index 1ff3b1608..d729f81ae 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -163,3 +163,18 @@ bool Torch::isViewLikeOp(Operation *op) { TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp, AtenNarrowOp, AtenToDeviceOp>(op); } + +Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter, + Location loc, float value, + Type dtype) { + // Creating constants satisfying backend contract. 
+ if (dtype.isInteger(64) || dtype.isInteger(32) || dtype.isInteger(8) || + dtype.isInteger(1)) + return rewriter.create<Torch::ConstantIntOp>( + loc, rewriter.getI64IntegerAttr((int64_t)value)); + if (dtype.isF64() || dtype.isF32() || dtype.isF16() || dtype.isBF16()) + return rewriter.create<Torch::ConstantFloatOp>(loc, + rewriter.getF64FloatAttr(value)); + llvm::report_fatal_error( + "unhandled type for getConstantWithGivenDtypeAndValue"); +} diff --git a/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp b/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp index 9be768986..f352e0175 100644 --- a/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp +++ b/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp @@ -8,13 +8,13 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" -#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/Transforms/InliningUtils.h" -#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index f7eb50aa6..3e5b38969 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -34,13 +34,13 @@ using namespace mlir::tosa; // Pass registration //===----------------------------------------------------------------------===// -namespace { +namespace reg { #define GEN_PASS_REGISTRATION #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc" -} // end namespace +} // end namespace reg void mlir::torch::registerTorchConversionPasses() { - ::registerPasses(); + reg::registerPasses(); mlir::PassPipelineRegistration<>( "torch-backend-to-linalg-on-tensors-backend-pipeline", "Pipeline lowering torch backend contract to linalg-on-tensors backend " diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp index 62c3e8faf..cae5586b8 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp +++ b/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp @@ -10,7 +10,7 @@ #include "PassDetail.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -71,8 +71,7 @@ class VerifyLinalgOnTensorsBackendContractPass // Basic scalar operations. target.addDynamicallyLegalDialect<func::FuncDialect>(isLegalScalarOp); target.addDynamicallyLegalDialect<math::MathDialect>(isLegalScalarOp); - target.addDynamicallyLegalDialect<arith::ArithmeticDialect>( - isLegalScalarOp); + target.addDynamicallyLegalDialect<arith::ArithDialect>(isLegalScalarOp); // Tensor operations should go through linalg and the tensor dialect.
target.addDynamicallyLegalDialect<linalg::LinalgDialect>(opHasLegalTypes); diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyMhloBackendContract.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyMhloBackendContract.cpp index c28ac45eb..bbb7a0a08 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyMhloBackendContract.cpp +++ b/lib/Dialect/TorchConversion/Transforms/VerifyMhloBackendContract.cpp @@ -10,7 +10,7 @@ #include "PassDetail.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -53,7 +53,7 @@ class VerifyMhloBackendContractPass target.addLegalDialect<mhlo::MhloDialect>(); target.addLegalDialect<chlo::ChloDialect>(); target.addLegalDialect<tensor::TensorDialect>(); - target.addLegalDialect<arith::ArithmeticDialect>(); + target.addLegalDialect<arith::ArithDialect>(); RewritePatternSet patterns(context); if (failed(applyFullConversion(module, target, std::move(patterns)))) { diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp index e86948ebb..a29e14a3d 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp +++ b/lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp @@ -9,7 +9,7 @@ #include "PassDetail.h" -#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" diff --git a/lib/RefBackend/RefBackend.cpp b/lib/RefBackend/RefBackend.cpp index cde11a481..873c92e93 100644 --- a/lib/RefBackend/RefBackend.cpp +++ b/lib/RefBackend/RefBackend.cpp @@ -15,7 +15,7 @@ //===----------------------------------------------------------------------===// #include "PassDetail.h" -#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" @@ -304,7 +304,7 @@ class ExpandOpsForLLVM : public ExpandOpsForLLVMBase<ExpandOpsForLLVM> { ConversionTarget target(*context); target.addLegalDialect<func::FuncDialect>(); target.addLegalDialect<math::MathDialect>(); - target.addLegalDialect<arith::ArithmeticDialect>(); + target.addLegalDialect<arith::ArithDialect>(); target.addIllegalOp<math::TanhOp>(); target.addIllegalOp<math::ErfOp>(); if (failed(applyPartialConversion(func, target, std::move(patterns)))) { @@ -352,7 +352,7 @@ class MemrefCopyOpToLinalg : public OpRewritePattern<memref::CopyOp> { LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter &rewriter) const override { Operation *linalgCopy = createLinalgCopyOp( - rewriter, copyOp.getLoc(), copyOp.source(), copyOp.target()); + rewriter, copyOp.getLoc(), copyOp.getSource(), copyOp.getTarget()); rewriter.replaceOp(copyOp, linalgCopy->getResults()); return success(); } diff --git a/python/torch_mlir/compiler_utils.py b/python/torch_mlir/compiler_utils.py index 6c1e3dbd4..35b0151e9 100644 --- a/python/torch_mlir/compiler_utils.py +++ b/python/torch_mlir/compiler_utils.py @@ -65,7 +65,7 @@ def run_pipeline_with_repro_report(module, {description} failed with the following diagnostics: {sys.stderr.getvalue()} - Error can be reproduced with: + For Torch-MLIR developers, the error can be reproduced with: $ torch-mlir-opt -pass-pipeline='{pipeline}' {filename} Add '{debug_options}' to get the IR dump for debugging purpose.
""" diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp index aaad8a0b7..207c77961 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp @@ -391,7 +391,7 @@ at::Tensor LazyNativeFunctions::new_empty_strided_symint( c10::optional device, c10::optional pin_memory) { return at::functionalization:: - functionalize_aten_op::call( + functionalize_aten_op_symint::call( self, size, stride, dtype, layout, device, pin_memory); } @@ -400,7 +400,7 @@ at::Tensor LazyNativeFunctions::narrow_copy_symint( int64_t dim, c10::SymInt start, c10::SymInt length) { - return at::functionalization::functionalize_aten_op::call(self, dim, start, length); } at::Tensor LazyNativeFunctions::pixel_shuffle( @@ -426,7 +426,7 @@ at::Tensor LazyNativeFunctions::slice_backward_symint( c10::SymInt start, c10::SymInt end, c10::SymInt step) { - return at::functionalization::functionalize_aten_op::call(grad_output, input_sizes, dim, start, end, step); } at::Tensor LazyNativeFunctions::diagonal_backward( diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py index c0e96c40b..0913adb2b 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py @@ -600,6 +600,9 @@ def aten〇numpy_T(self: List[int]) -> List[int]: def aten〇matmul(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.matmul(self, other) +def aten〇mv(self: List[int], vec: List[int]) -> List[int]: + return upstream_shape_functions.mv(self, vec) + def aten〇mm(self: List[int], mat2: List[int]) -> List[int]: return upstream_shape_functions.mm(self, mat2) @@ -863,6 +866,9 @@ def aten〇minimum(self: List[int], other: List[int]) -> List[int]: def aten〇maximum(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, other) +def aten〇bitwise_or〇Tensor(self: List[int], other: List[int]) -> List[int]: + return upstream_shape_functions.broadcast(self, other) + def aten〇bitwise_and〇Tensor(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, other) @@ -1195,6 +1201,9 @@ def aten〇linalg_vector_norm(self: List[int], ord: float = 2, dim: Optional[Lis def aten〇frobenius_norm〇dim(self: List[int], dim: List[int], keepdim: bool = False) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0) +def aten〇upsample_nearest2d〇vec(input: List[int], output_size: Optional[List[int]], scale_factors: Optional[List[float]]) -> List[int]: + return upstream_shape_functions.upsample_nearest2d(input, output_size, scale_factors) + # ============================================================================== # Shape library generator main(). 
# ============================================================================== diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index fc807b88f..35b8694e6 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -34,6 +34,8 @@ TORCH_TYPE_TO_ODS_TYPE = { "bool?": "AnyTorchOptionalBoolType", "float": "Torch_FloatType", "float?": "AnyTorchOptionalFloatType", + "float[]": "AnyTorchListOfTorchFloatType", + "float[]?": "AnyTorchOptionalListOfTorchFloatType", "t[]": "AnyTorchListType", "t": "AnyTorchType", "t1": "AnyTorchType", @@ -285,6 +287,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): "aten::abs : (Tensor) -> (Tensor)", "aten::reciprocal : (Tensor) -> (Tensor)", "aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)", + "aten::bitwise_or.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)", "aten::square : (Tensor) -> (Tensor)", "aten::unsqueeze : (Tensor, int) -> (Tensor)", @@ -333,6 +336,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::mm : (Tensor, Tensor) -> (Tensor)") emit("aten::addmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)") emit("aten::matmul : (Tensor, Tensor) -> (Tensor)") + emit("aten::mv : (Tensor, Tensor) -> (Tensor)") emit( "aten::conv2d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)" ) @@ -515,6 +519,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)") emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)") emit("aten::as_strided_scatter : (Tensor, Tensor, int[], int[], int?) -> (Tensor)") + emit("aten::upsample_nearest2d.vec : (Tensor, int[]?, float[]?) -> (Tensor)") # Dict ops. 
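
The emit lines in the hunks above only register the ODS for the new ops; the shape functions added in shape_lib_gen.py are what let them clear the Torch backend contract. As a rough, hypothetical sketch of the net effect — MvModule and the exact torch_mlir.compile invocation below are illustrative assumptions about the Python API of this era, not part of this patch — the newly wired-up aten::mv could be lowered end to end like this:

import torch
import torch_mlir

class MvModule(torch.nn.Module):
    def forward(self, m, v):
        # aten::mv : (Tensor, Tensor) -> (Tensor), registered by the ODS
        # change above; its shape function, upstream_shape_functions.mv,
        # maps an [m, n] matrix and an [n] vector to an [m] result.
        return torch.mv(m, v)

# Hypothetical usage: lower to the linalg-on-tensors backend contract, which
# requires the shape library entries added in this patch to succeed.
compiled = torch_mlir.compile(MvModule(), (torch.rand(2, 3), torch.rand(3)),
                              output_type="linalg-on-tensors")
print(compiled)
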
@@ -566,6 +571,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::add.int : (int, int) -> (int)", has_folder=True) emit("aten::sub.int : (int, int) -> (int)", has_folder=True) emit("aten::mul.int : (int, int) -> (int)", has_folder=True) + emit("aten::div.int : (int, int) -> (float)", has_folder=True) emit("aten::neg.int : (int) -> (int)", has_folder=True) emit("aten::log.int : (int) -> (float)") emit("aten::add.float_int : (float, int) -> (float)") diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/ivalue_importer.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/ivalue_importer.cpp index 8a8dff5e2..5e9a2a0dd 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/ivalue_importer.cpp +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/ivalue_importer.cpp @@ -276,7 +276,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) { MlirOperation operation = createMlirOperationAtEnd( importBlock, "torch.prim.ListConstruct", loc, torchMlirTorchListTypeGet( - getMlirTypeFromTorchType(loc, list.elementType())), + getMlirTypeFromTorchType(loc, list.elementType(), importOptions)), elems); return mlirOperationGetResult(operation, 0); } @@ -291,8 +291,8 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) { MlirOperation operation = createMlirOperationAtEnd( importBlock, "torch.prim.DictConstruct", loc, torchMlirTorchDictTypeGet( - getMlirTypeFromTorchType(loc, dict.keyType()), - getMlirTypeFromTorchType(loc, dict.valueType())), + getMlirTypeFromTorchType(loc, dict.keyType(), importOptions), + getMlirTypeFromTorchType(loc, dict.valueType(), importOptions)), keys, values); return mlirOperationGetResult(operation, 0); } @@ -368,10 +368,20 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) { at::Tensor tensor = ivalue.toTensor().contiguous(); MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc); - MlirOperation tensorOp = createMlirOperationAtEnd( - importBlock, "torch.tensor.literal", loc, - torchMlirTorchNonValueTensorTypeGetFromAttribute(denseElements), - toMlirNamedAttribute("value", denseElements)); + MlirOperation tensorOp; + + if (importOptions.assumeTensorsHaveValueSemantics) { + tensorOp = createMlirOperationAtEnd( + importBlock, "torch.vtensor.literal", loc, + torchMlirTorchValueTensorTypeGetFromAttribute(denseElements), + toMlirNamedAttribute("value", denseElements)); + } else { + tensorOp = createMlirOperationAtEnd( + importBlock, "torch.tensor.literal", loc, + torchMlirTorchNonValueTensorTypeGetFromAttribute(denseElements), + toMlirNamedAttribute("value", denseElements)); + } + MlirValue tensorReprValue = mlirOperationGetResult(tensorOp, 0); // Construct the complete tensor value. This is trivial for most tensors, but @@ -384,9 +394,16 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) { // compiler stages that are building a statically modeled quantization // representation will need to convert this to their representation. 
std::vector<int64_t> shape(tensor.sizes().begin(), tensor.sizes().end()); - MlirType quantizedTensorType = torchMlirTorchNonValueTensorTypeGet( - context, shape.size(), shape.data(), - getMlirTypeForTorchScalarType(loc, tensor.scalar_type())); + MlirType quantizedTensorType; + if (importOptions.assumeTensorsHaveValueSemantics) { + quantizedTensorType = torchMlirTorchValueTensorTypeGet( + context, shape.size(), shape.data(), + getMlirTypeForTorchScalarType(loc, tensor.scalar_type())); + } else { + quantizedTensorType = torchMlirTorchNonValueTensorTypeGet( + context, shape.size(), shape.data(), + getMlirTypeForTorchScalarType(loc, tensor.scalar_type())); + } if (tensor.qscheme() == c10::kPerTensorAffine) { MlirValue qScale = importIValue(c10::IValue(tensor.q_scale())); MlirValue zeroPoint = importIValue(c10::IValue(tensor.q_zero_point())); @@ -463,7 +480,7 @@ void IValueImporter::importClassType(c10::ClassType *classType) { "name", mlirStringAttrGet( context, toMlirStringRef(classAttribute.getName()))), toMlirNamedAttribute("type", mlirTypeAttrGet(getMlirTypeFromTorchType( - loc, classAttribute.getType()))), + loc, classAttribute.getType(), importOptions))), isPrivate); } diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/module_builder.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/module_builder.cpp index da85a9b24..ca4bd600f 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/module_builder.cpp +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/module_builder.cpp @@ -124,10 +124,16 @@ ModuleBuilder::ModuleBuilder(pybind11::object contextObj) } torch::jit::StrongFunctionPtr -ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) { +ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function, + py::object maybeImportOptions) { + ImportOptions importOptions; + if (!maybeImportOptions.is_none()) { + importOptions = py::cast<ImportOptions>(maybeImportOptions); + } MlirBlock block = getBodyBlock(); MlirOperation terminator = this->terminator; - MlirOperation func = importJitFunctionAsFuncOp(context, function.function_); + MlirOperation func = importJitFunctionAsFuncOp(context, function.function_, + [](int) -> MlirAttribute { return {nullptr}; }, importOptions); mlirBlockInsertOwnedOperationBefore(block, terminator, func); return function; } @@ -182,7 +188,8 @@ void ModuleBuilder::bind(py::module &m) { .def(py::init<py::object>(), py::arg("context") = py::none()) .def_property_readonly("context", &ModuleBuilder::getContextObj) .def_property_readonly("module", &ModuleBuilder::getModuleObj) - .def("import_function", &ModuleBuilder::importFunction) + .def("import_function", &ModuleBuilder::importFunction, py::arg("function"), + py::arg("importOptions") = py::none()) .def("import_module", &ModuleBuilder::importModule, py::arg("module"), py::arg("classAnnotator") = py::none(), py::arg("importOptions") = py::none()); diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/module_builder.h b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/module_builder.h index 6e1e0beea..08695e15f 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/module_builder.h +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/module_builder.h @@ -39,7 +39,8 @@ public: // Just a bit of naming cruft. // Returns the same function, making it suitable as a nested decorator.
torch::jit::StrongFunctionPtr - importFunction(torch::jit::StrongFunctionPtr function); + importFunction(torch::jit::StrongFunctionPtr function, + py::object maybeImportOptions); // Imports a torch::jit::Module into the current module, using the // annotations, if not none, provided in `maybeClassAnnotator` which should be diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp index 9a9bebb9b..1dec6ab27 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp @@ -198,10 +198,17 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock, c10::attr::value))))); } else if (output->type()->cast<c10::TensorType>()) { MlirAttribute attr = importAttribute(loc, node, c10::attr::value); - op = createMlirOperation( - "torch.tensor.literal", loc, - torchMlirTorchNonValueTensorTypeGetFromAttribute(attr), - toMlirNamedAttribute("value", attr)); + if (importOptions.assumeTensorsHaveValueSemantics) { + op = createMlirOperation( + "torch.vtensor.literal", loc, + torchMlirTorchValueTensorTypeGetFromAttribute(attr), + toMlirNamedAttribute("value", attr)); + } else { + op = createMlirOperation( + "torch.tensor.literal", loc, + torchMlirTorchNonValueTensorTypeGetFromAttribute(attr), + toMlirNamedAttribute("value", attr)); + } } else if (output->type()->cast<c10::DeviceObjType>()) { op = createMlirOperation( "torch.constant.device", loc, diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py index cdfcbb54f..cf945d0cf 100644 --- a/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/python/torch_mlir_e2e_test/test_suite/basic.py @@ -2513,6 +2513,25 @@ def ToCopyWithDTypeFalsePinMemoryModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 2, 4)) +class ToCopyBoolDTypeStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 1, 5, 5], torch.uint8, True), + ]) + def forward(self, x): + return torch.ops.aten._to_copy(x, dtype=torch.bool) + + +@register_test_case(module_factory=lambda: ToCopyBoolDTypeStaticModule()) +def ToCopyBoolDTypeStaticModule_basic(module, tu: TestUtils): + module.forward(tu.randint(1, 1, 5, 5).to(dtype=torch.uint8)) + + # ============================================================================== diff --git a/python/torch_mlir_e2e_test/test_suite/conv.py b/python/torch_mlir_e2e_test/test_suite/conv.py index a098db07a..c95deba3b 100644 --- a/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/python/torch_mlir_e2e_test/test_suite/conv.py @@ -705,3 +705,80 @@ class Conv_Transpose3dModule(torch.nn.Module): @register_test_case(module_factory=lambda: Conv_Transpose3dModule()) def Conv_Transpose3dModule_basic(module, tu: TestUtils): module.forward(torch.randn(5, 2, 5, 6, 4), torch.randn(2, 5, 2, 2, 2)) + +class UpSampleNearest2dSameSize(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, inputVec): + return torch._C._nn.upsample_nearest2d(inputVec, + output_size=[11, 11], + scale_factors=None) + + +@register_test_case(module_factory=lambda: UpSampleNearest2dSameSize()) +def UpSampleNearest2dStaticSize_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4)) + + +class UpSampleNearest2dDiffSize(torch.nn.Module): + + def __init__(self): + super().__init__() + + +
@export + @annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)]) + def forward(self, inputVec): + return torch._C._nn.upsample_nearest2d(inputVec, + output_size=[8, 11], + scale_factors=None) + + +@register_test_case(module_factory=lambda: UpSampleNearest2dDiffSize()) +def UpSampleNearest2dDynamicSize_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 2, 2)) + + +class UpSampleNearest2dDiffFactor(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)]) + def forward(self, inputVec): + return torch._C._nn.upsample_nearest2d(inputVec, + output_size=None, + scale_factors=[2.3, 4.7]) + + +@register_test_case(module_factory=lambda: UpSampleNearest2dDiffFactor()) +def UpSampleNearest2dDynamicFactor_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 2, 2)) + + +class UpSampleNearest2dSameFactor(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, inputVec): + return torch._C._nn.upsample_nearest2d(inputVec, + output_size=None, + scale_factors=[2.0, 2.0]) + + +@register_test_case(module_factory=lambda: UpSampleNearest2dSameFactor()) +def UpSampleNearest2dStaticFactor_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 4, 4)) diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py index 6770d7237..d14d4bb2d 100644 --- a/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -1553,6 +1553,31 @@ def ElementwiseAndIntegerModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseOrIntegerModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ([-1, -1], torch.int64, True), + ]) + def forward(self, x, y): + return torch.bitwise_or(x, y) + + +@register_test_case(module_factory=lambda: ElementwiseOrIntegerModule()) +def ElementwiseOrIntegerModule_basic(module, tu: TestUtils): + module.forward( + tu.randint(3, 4, low=-10, high=10).to(torch.int32), + tu.randint(3, 4, low=-10, high=10)) + + +# ============================================================================== + + class ElementwiseNotIntegerModule(torch.nn.Module): def __init__(self): diff --git a/python/torch_mlir_e2e_test/test_suite/matmul.py b/python/torch_mlir_e2e_test/test_suite/matmul.py index e1ecfa6a3..e40086bb7 100644 --- a/python/torch_mlir_e2e_test/test_suite/matmul.py +++ b/python/torch_mlir_e2e_test/test_suite/matmul.py @@ -209,3 +209,20 @@ class MatmulBroadcastBatchDim(torch.nn.Module): def MatmulBroadcastBatchDim_basic(module, tu: TestUtils): module.forward(tu.rand(4, 5, 6, 7), tu.rand(5, 7, 6)) +# ============================================================================== + +class Mv(torch.nn.Module): + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.float32, True), + ]) + def forward(self, m, v): + return torch.mv(m, v) + + +@register_test_case(module_factory=lambda: Mv()) +def Mv_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 2), tu.rand(2)) \ No newline at end of file diff --git a/python/torch_mlir_e2e_test/test_suite/scalar.py b/python/torch_mlir_e2e_test/test_suite/scalar.py index 16ce64bb0..95879b44e 100644 --- 
a/python/torch_mlir_e2e_test/test_suite/scalar.py +++ b/python/torch_mlir_e2e_test/test_suite/scalar.py @@ -104,6 +104,31 @@ def MulIntModule_basic(module, tu: TestUtils): # ============================================================================== +class DivIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([], torch.int64, True), + ([], torch.int64, True), + ]) + def forward(self, lhs, rhs): + # Cast the result to float so that the e2e test baseline result is a float. + # Without the cast, the baseline result is a Tensor, which is unexpected. + return float(torch.ops.aten.div(int(lhs), int(rhs))) + + +@register_test_case(module_factory=lambda: DivIntModule()) +def DivIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(low=-10, high=10), tu.randint(low=3, high=10)) + + # ============================================================================== + + class DivFloatModule(torch.nn.Module): def __init__(self): diff --git a/python/torch_mlir_e2e_test/test_suite/type_conversion.py b/python/torch_mlir_e2e_test/test_suite/type_conversion.py index 2df66184e..53f2d2e0a 100644 --- a/python/torch_mlir_e2e_test/test_suite/type_conversion.py +++ b/python/torch_mlir_e2e_test/test_suite/type_conversion.py @@ -193,13 +193,13 @@ def ToDtypeLayoutStridedModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 5)) -class ToDtypeBoolLayoutNoneModule(torch.nn.Module): +class ToDtypeBoolLayoutNoneStaticModule(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([None, ([-1, -1], torch.float32, True)]) + @annotate_args([None, ([3, 5], torch.int64, True)]) def forward(self, x): return torch.ops.aten.to(x, dtype=torch.bool, @@ -211,9 +211,9 @@ class ToDtypeBoolLayoutNoneModule(torch.nn.Module): memory_format=None) -@register_test_case(module_factory=lambda: ToDtypeBoolLayoutNoneModule()) -def ToDtypeBoolLayoutNoneModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 5)) +@register_test_case(module_factory=lambda: ToDtypeBoolLayoutNoneStaticModule()) +def ToDtypeBoolLayoutNoneStaticModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 5)) class TypeAsSameModule(torch.nn.Module): diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index dc016a434..84230deb9 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torch==1.13.0.dev20220930 +torch==1.13.0.dev20221004 diff --git a/pytorch-version.txt b/pytorch-version.txt index e85cfd939..591132c78 100644 --- a/pytorch-version.txt +++ b/pytorch-version.txt @@ -1 +1 @@ -3bf7094ddb95e8b9bcd2b1f35589e729aa2f4248 +9f3d8fec5747fde5191618eb895fbec2d50edf93 diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 2edcb9b82..48061c15c 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -847,3 +847,69 @@ func.func @torch.aten.arange.start_step() -> !torch.vtensor<[5],si64> { %0 = torch.aten.arange.start_step %int0, %int5, %int1, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[5],si64> return %0 : !torch.vtensor<[5],si64> } + +// ----- +// CHECK-LABEL: func.func @torch.prim.NumToTensor.Scalar() -> !torch.vtensor<[],si64> { +// CHECK: %[[CST1:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_0:.*]] = "tosa.const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
+// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<i64> -> !torch.vtensor<[],si64> +// CHECK: return %[[VAL_1]] : !torch.vtensor<[],si64> +func.func @torch.prim.NumToTensor.Scalar() -> !torch.vtensor<[],si64> { + %int1 = torch.constant.int 1 + %0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64> + return %0 : !torch.vtensor<[],si64> +} + +// ----- +// CHECK-LABEL: func.func @torch.valsem.aten.copy( +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,1,5,5],ui8>) -> !torch.vtensor<[1,1,5,5],i1> { +// CHECK: %[[INP:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,1,5,5],ui8> -> tensor<1x1x5x5xi8> +// CHECK: %[[CST5:.*]] = torch.constant.int 5 +// CHECK: %[[CST1:.*]] = torch.constant.int 1 +// CHECK: %[[CST11:.*]] = torch.constant.int 11 +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[CST0:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_0:.*]] = "tosa.const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64> +// CHECK: %[[VAL_2:.*]] = "tosa.equal"(%[[VAL_0]], %[[VAL_1]]) : (tensor<i64>, tensor<i64>) -> tensor<i1> +// CHECK: %[[VAL_3:.*]] = "tosa.logical_not"(%[[VAL_2]]) : (tensor<i1>) -> tensor<i1> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() {value = dense<0> : tensor<1x1x5x5xi8>} : () -> tensor<1x1x5x5xi8> +// CHECK: %[[VAL_5:.*]] = "tosa.equal"(%[[INP]], %[[VAL_4]]) : (tensor<1x1x5x5xi8>, tensor<1x1x5x5xi8>) -> tensor<1x1x5x5xi1> +// CHECK: %[[VAL_6:.*]] = "tosa.logical_not"(%[[VAL_5]]) : (tensor<1x1x5x5xi1>) -> tensor<1x1x5x5xi1> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x1x5x5xi1> -> !torch.vtensor<[1,1,5,5],i1> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[1,1,5,5],i1> +func.func @torch.valsem.aten.copy(%arg0: !torch.vtensor<[1,1,5,5],ui8>) -> !torch.vtensor<[1,1,5,5],i1> { + %int5 = torch.constant.int 5 + %int1 = torch.constant.int 1 + %int11 = torch.constant.int 11 + %none = torch.constant.none + %false = torch.constant.bool false + %int0 = torch.constant.int 0 + %0 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64> + %1 = torch.aten.to.dtype %0, %int11, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],i1> + %2 = torch.prim.ListConstruct %int1, %int1, %int5, %int5 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> + %3 = torch.aten.broadcast_to %1, %2 : !torch.vtensor<[],i1>, !torch.list<int> -> !torch.vtensor<[1,1,5,5],i1> + %4 = torch.valsem.aten.copy %3, %arg0, %false : !torch.vtensor<[1,1,5,5],i1>, !torch.vtensor<[1,1,5,5],ui8>, !torch.bool -> !torch.vtensor<[1,1,5,5],i1> + return %4 : !torch.vtensor<[1,1,5,5],i1> +} + +// ----- +// CHECK-LABEL: func.func @torch.aten.to.dtype( +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[3,5],si64>) -> !torch.vtensor<[3,5],i1> { +// CHECK: %[[INP:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[3,5],si64> -> tensor<3x5xi64> +// CHECK: %[[CST11:.*]] = torch.constant.int 11 +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[VAL_0:.*]] = "tosa.const"() {value = dense<0> : tensor<3x5xi64>} : () -> tensor<3x5xi64> +// CHECK: %[[VAL_1:.*]] = "tosa.equal"(%[[INP]], %[[VAL_0]]) : (tensor<3x5xi64>, tensor<3x5xi64>) -> tensor<3x5xi1> +// CHECK: %[[VAL_2:.*]] = "tosa.logical_not"(%[[VAL_1]]) : (tensor<3x5xi1>) -> tensor<3x5xi1> +// CHECK: %[[VAL_3:.*]] =
torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x5xi1> -> !torch.vtensor<[3,5],i1> +// CHECK: return %[[VAL_3]] : !torch.vtensor<[3,5],i1> +func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[3,5],si64>) -> !torch.vtensor<[3,5],i1> { + %int11 = torch.constant.int 11 + %none = torch.constant.none + %false = torch.constant.bool false + %0 = torch.aten.to.dtype %arg0, %int11, %false, %false, %none : !torch.vtensor<[3,5],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,5],i1> + return %0 : !torch.vtensor<[3,5],i1> +} diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 109e39810..b231a3bfe 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -1351,6 +1351,16 @@ func.func @torch.aten.div.float$fold_cst_operands() -> !torch.float { return %0 : !torch.float } +// CHECK-LABEL: func.func @torch.aten.div.int$fold_cst_operands( +// CHECK: %[[CST:.*]] = torch.constant.float 5.000000e-01 +// CHECK: return %[[CST]] : !torch.float +func.func @torch.aten.div.int$fold_cst_operands() -> !torch.float { + %int2 = torch.constant.int 2 + %int4 = torch.constant.int 4 + %0 = torch.aten.div.int %int2, %int4 : !torch.int, !torch.int -> !torch.float + return %0 : !torch.float +} + // CHECK-LABEL: func.func @torch.aten.to.dtype_layout$same_dtype( // CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> { // CHECK-NEXT: return %[[ARG]] : !torch.tensor<[?,?],f32> diff --git a/test/python/importer/jit_ir/ivalue_import/tensors-value-semantics.py b/test/python/importer/jit_ir/ivalue_import/tensors-value-semantics.py new file mode 100644 index 000000000..e57c20fe5 --- /dev/null +++ b/test/python/importer/jit_ir/ivalue_import/tensors-value-semantics.py @@ -0,0 +1,42 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See LICENSE.pytorch for license information. 
+ +import typing + +import torch +from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder + +# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s + +mb = ModuleBuilder() + +class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.ones_i32 = torch.ones(1, dtype=torch.int32) + self.ones_qint8 = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.qint8) + self.arange = torch.nn.Parameter(torch.arange(3.0)) + +# CHECK: %[[ARANGE:.*]] = torch.vtensor.literal(dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>) : !torch.vtensor<[3],f32> +# CHECK: %[[ONES_I32:.*]] = torch.vtensor.literal(dense<1> : tensor<1xsi32>) : !torch.vtensor<[1],si32> +# CHECK: %[[ONES_QINT8_DATA:.*]] = torch.vtensor.literal(dense<1> : tensor<1xsi8>) : !torch.vtensor<[1],si8> +# CHECK: %[[SCALE:.*]] = torch.constant.float 1.000000e+00 +# CHECK: %[[ZERO_POINT:.*]] = torch.constant.int 0 +# CHECK: %[[ONES_QINT8:.*]] = torch.per_tensor_affine.create %[[ONES_QINT8_DATA]], %[[SCALE]], %[[ZERO_POINT]] : !torch.vtensor<[1],si8>, !torch.float, !torch.int -> !torch.vtensor<[1],!torch.qint8> +# CHECK: %[[ROOT:.*]] = torch.nn_module { +# CHECK: torch.slot "arange", %[[ARANGE]] : !torch.vtensor<[3],f32> +# CHECK: torch.slot "ones_i32", %[[ONES_I32]] : !torch.vtensor<[1],si32> +# CHECK: torch.slot "ones_qint8", %[[ONES_QINT8]] : !torch.vtensor<[1],!torch.qint8> +# CHECK: } +test_module = TestModule() +recursivescriptmodule = torch.jit.script(test_module) + +import_options = ImportOptions() +import_options.assumeTensorsHaveValueSemantics = True + +class_annotator = ClassAnnotator() + +# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule. +mb.import_module(recursivescriptmodule._c, class_annotator, import_options) +mb.module.operation.print() diff --git a/test/python/importer/jit_ir/node_import/prim.py b/test/python/importer/jit_ir/node_import/prim.py index 21ec33c92..2565c6c41 100644 --- a/test/python/importer/jit_ir/node_import/prim.py +++ b/test/python/importer/jit_ir/node_import/prim.py @@ -5,7 +5,7 @@ import typing import torch -from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder +from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder from utils import create_script_function @@ -162,3 +162,34 @@ graph(): mb.module.operation.print() print() + +# CHECK-LABEL: func.func @__torch__.prim_Constant_scalar() -> !torch.number { +# CHECK: %[[A:.*]] = torch.tensor.literal +# CHECK: %[[RET:.*]] = torch.aten.ScalarImplicit +# CHECK: return %[[RET]] : !torch.number +import_options = ImportOptions() +import_options.assumeTensorsHaveValueSemantics = False +mb.import_function(create_script_function("__torch__.prim_Constant_scalar", """ +graph(): + %0 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]() + %1 : Scalar = aten::ScalarImplicit(%0) + return (%1) +""", parse_tensor_constants=True), import_options) + +mb.module.operation.print() +print() + +# CHECK-LABEL: func.func @__torch__.prim_Constant_scalar_value_semantics() -> !torch.number { +# CHECK: %[[A:.*]] = torch.vtensor.literal +# CHECK: %[[RET:.*]] = torch.aten.ScalarImplicit +# CHECK: return %[[RET]] : !torch.number +import_options.assumeTensorsHaveValueSemantics = True +mb.import_function(create_script_function("__torch__.prim_Constant_scalar_value_semantics", """ +graph(): + %0 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]() + %1 : Scalar = 
aten::ScalarImplicit(%0) + return (%1) +""", parse_tensor_constants=True), import_options) + +mb.module.operation.print() +print() diff --git a/test/python/importer/jit_ir/node_import/utils.py b/test/python/importer/jit_ir/node_import/utils.py index 6e8d1ac45..613ccb6a8 100644 --- a/test/python/importer/jit_ir/node_import/utils.py +++ b/test/python/importer/jit_ir/node_import/utils.py @@ -10,6 +10,6 @@ from torch._C import CompilationUnit # RUN: %PYTHON %s # Import TorchScript IR string as ScriptFunction. -def create_script_function(func_name, ts_ir_str): +def create_script_function(func_name, ts_ir_str, **kwargs): cu = CompilationUnit() - return cu.create_function(func_name, torch._C.parse_ir(ts_ir_str)) + return cu.create_function(func_name, torch._C.parse_ir(ts_ir_str, **kwargs)) diff --git a/utils/bazel/torch-mlir-overlay/BUILD.bazel b/utils/bazel/torch-mlir-overlay/BUILD.bazel index 1bd831223..6dab2682f 100644 --- a/utils/bazel/torch-mlir-overlay/BUILD.bazel +++ b/utils/bazel/torch-mlir-overlay/BUILD.bazel @@ -337,7 +337,7 @@ cc_library( strip_include_prefix = "include", deps = [ ":TorchMLIRTorchDialect", - "@llvm-project//mlir:ArithmeticDialect", + "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ControlFlowDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgDialect", @@ -373,7 +373,7 @@ cc_library( ":TorchMLIRTorchBackendTypeConversion", ":TorchMLIRTorchConversionDialect", ":TorchMLIRTorchDialect", - "@llvm-project//mlir:ArithmeticDialect", + "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ControlFlowDialect", "@llvm-project//mlir:Dialect", "@llvm-project//mlir:LinalgDialect", @@ -397,7 +397,7 @@ cc_library( ":TorchMLIRConversionPassesIncGen", ":TorchMLIRTorchBackendTypeConversion", ":TorchMLIRTorchConversionDialect", - "@llvm-project//mlir:ArithmeticDialect", + "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ControlFlowDialect", "@llvm-project//mlir:Dialect", "@llvm-project//mlir:LinalgDialect", @@ -809,7 +809,7 @@ cc_library( ":TorchMLIRRefBackendPassIncGen", ":TorchMLIRTorchBackendTypeConversion", ":TorchMLIRTorchConversionDialect", - "@llvm-project//mlir:ArithmeticTransforms", + "@llvm-project//mlir:ArithTransforms", "@llvm-project//mlir:LinalgDialect", "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:MathTransforms",