mirror of https://github.com/llvm/torch-mlir
Merge branch 'llvm:main' into windows-autogen_ltc_backend.py
commit 65ca4c8454

@@ -1,8 +1,6 @@
 name: Roll PyTorch

 on:
-  schedule:
-    - cron: '0 12 * * *'
   workflow_dispatch:

 jobs:

@@ -34,8 +32,8 @@ jobs:
           python -m pip download -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre "torch==${PT_RELEASE}"
           # Read the commit hash from the downloaded whl file without extracting it
           PT_HASH=$(unzip -p torch-"${PT_RELEASE}"*.whl torch/version.py | grep git_version | awk '{ print $3 }' | tr -d "'")
-          echo "${PT_HASH}" | cmp - pytorch-version.txt --quiet
-          PT_HASH_CHANGED=$?
+          PT_HASH_CHANGED=0
+          echo "${PT_HASH}" | cmp - pytorch-version.txt --quiet || PT_HASH_CHANGED=$?
           echo "${PT_HASH}" > pytorch-version.txt
           rm torch-"${PT_RELEASE}"*.whl
           # Write the release and hash to the environment file so that we can

@@ -44,13 +42,14 @@ jobs:
          echo "PT_RELEASE=${PT_RELEASE}" >> ${GITHUB_ENV}
          echo "PT_HASH_CHANGED=${PT_HASH_CHANGED}" >> ${GITHUB_ENV}

-    - name: Build and test
+    - name: Build and test (in-tree), also update ODS and shape library
       if: env.PT_HASH_CHANGED != '0'
       run: |
         cd ${GITHUB_WORKSPACE}
-        TM_PACKAGES="out-of-tree" TM_USE_PYTORCH_BINARY="OFF" \
+        TM_PACKAGES="in-tree" TM_USE_PYTORCH_BINARY="OFF" \
         TORCH_MLIR_SRC_PYTORCH_BRANCH="${{ env.PT_HASH }}" \
         TORCH_MLIR_SRC_PYTORCH_RELEASE="${{ env.PT_RELEASE }}" \
+        TM_UPDATE_ODS_AND_SHAPE_LIB="ON" \
         ./build_tools/python_deploy/build_linux_packages.sh

    - name: Push changes to main branch

@@ -61,5 +60,5 @@ jobs:
         git config user.name "Roll PyTorch Action"
         git fetch --recurse-submodules=no
         git checkout main
-        git add pytorch-version.txt pytorch-requirements.txt
+        git add pytorch-version.txt pytorch-requirements.txt lib/Dialect/Torch/Transforms/ShapeLibrary.cpp include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
         git diff --cached --exit-code || (git commit -m "update PyTorch version to ${{ env.PT_RELEASE }}" && git push --set-upstream origin main)
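
One behavioral note on the hash-change detection rewritten above: GitHub Actions runs `run:` blocks with bash errexit enabled, so the old two-step form (`cmp ...` then `PT_HASH_CHANGED=$?`) would abort the step whenever `cmp` found a difference, before the status could be recorded. The `|| PT_HASH_CHANGED=$?` form records the exit status without tripping errexit. A minimal standalone sketch of the pattern (file names hypothetical):

```bash
#!/usr/bin/env bash
set -e                                              # errexit, as in Actions' default shell

status=0
echo "new-hash" | cmp - stored-hash.txt --quiet || status=$?
echo "hash changed: ${status}"                      # 0 = unchanged, non-zero = changed
```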

@@ -61,7 +61,7 @@ jobs:
       if: ${{ matrix.os-arch == 'ubuntu-x86_64' }}
       run: |
         cd $GITHUB_WORKSPACE
-        TM_PACKAGES="${{ matrix.llvm-build }}" TM_USE_PYTORCH_BINARY="${{ matrix.torch-binary }}" ./build_tools/python_deploy/build_linux_packages.sh
+        TORCH_MLIR_SRC_PYTORCH_BRANCH="$(cat pytorch-version.txt)" TM_PACKAGES="${{ matrix.llvm-build }}" TM_USE_PYTORCH_BINARY="${{ matrix.torch-binary }}" ./build_tools/python_deploy/build_linux_packages.sh
     - name: Configure os-arch='macos-arm64' llvm-build='in-tree' torch-binary='${{ matrix.torch-binary }}'
       # cross compile, can't test arm64
       if: ${{ matrix.os-arch == 'macos-arm64' && matrix.llvm-build == 'in-tree' }}

@@ -94,6 +94,15 @@ jobs:
       with:
         release_id: ${{ github.event.inputs.release_id }}

+  publish_releases:
+    needs:
+      - build_linux
+      - build_macos
+
+    # Publish even if one of the builds failed
+    if: ${{ always() }}
+
+    steps:
     - name: Invoke Publish Releases Page
       uses: benc-uk/workflow-dispatch@v1
       with:
@@ -12,7 +12,6 @@ blacklist:
 - detach
 - item
 - size
-- where
 - copy_

 # Disabled for consistency with TS backend
@@ -58,7 +58,8 @@ checkout_pytorch() {
     git reset --hard FETCH_HEAD
   else
     cd "${PYTORCH_ROOT}"
-    git reset --hard HEAD
+    git fetch --depth=1 origin "${TORCH_MLIR_SRC_PYTORCH_BRANCH}"
+    git reset --hard FETCH_HEAD
   fi
   git clean -df
   git submodule update --init --depth 1 --recursive
@@ -53,14 +53,16 @@ TM_PACKAGES="${TM_PACKAGES:-torch-mlir}"
 TM_USE_PYTORCH_BINARY="${TM_USE_PYTORCH_BINARY:-ON}"
 # Skip running tests if you want quick iteration
 TM_SKIP_TESTS="${TM_SKIP_TESTS:-OFF}"
+# Update ODS and shape library files
+TM_UPDATE_ODS_AND_SHAPE_LIB="${TM_UPDATE_ODS_AND_SHAPE_LIB:-OFF}"

 PKG_VER_FILE="${repo_root}"/torch_mlir_package_version ; [ -f "$PKG_VER_FILE" ] && . "$PKG_VER_FILE"
 TORCH_MLIR_PYTHON_PACKAGE_VERSION="${TORCH_MLIR_PYTHON_PACKAGE_VERSION:-0.0.1}"
 echo "Setting torch-mlir Python Package version to: ${TORCH_MLIR_PYTHON_PACKAGE_VERSION}"

-TORCH_MLIR_SRC_PYTORCH_REPO="${TORCH_MLIR_SRC_PYTORCH_REPO:-pytorch/pytorch}"
+export TORCH_MLIR_SRC_PYTORCH_REPO="${TORCH_MLIR_SRC_PYTORCH_REPO:-pytorch/pytorch}"
 echo "Setting torch-mlir PyTorch Repo for source builds to: ${TORCH_MLIR_SRC_PYTORCH_REPO}"
-TORCH_MLIR_SRC_PYTORCH_BRANCH="${TORCH_MLIR_SRC_PYTORCH_BRANCH:-master}"
+export TORCH_MLIR_SRC_PYTORCH_BRANCH="${TORCH_MLIR_SRC_PYTORCH_BRANCH:-master}"
 echo "Setting torch-mlir PyTorch version for source builds to: ${TORCH_MLIR_SRC_PYTORCH_BRANCH}"

 function run_on_host() {

@@ -109,6 +111,7 @@ function run_on_host() {
        -e "TM_PYTHON_VERSIONS=${TM_PYTHON_VERSIONS}" \
        -e "TM_PACKAGES=${package}" \
        -e "TM_SKIP_TESTS=${TM_SKIP_TESTS}" \
+       -e "TM_UPDATE_ODS_AND_SHAPE_LIB=${TM_UPDATE_ODS_AND_SHAPE_LIB}" \
        -e "TM_USE_PYTORCH_BINARY=${TM_USE_PYTORCH_BINARY}" \
        -e "TORCH_MLIR_SRC_PYTORCH_REPO=${TORCH_MLIR_SRC_PYTORCH_REPO}" \
        -e "TORCH_MLIR_SRC_PYTORCH_BRANCH=${TORCH_MLIR_SRC_PYTORCH_BRANCH}" \

@@ -152,6 +155,12 @@ function run_in_docker() {
     in-tree)
       setup_venv "$python_version"
       build_in_tree "$TM_USE_PYTORCH_BINARY" "$python_version"
+      if [ "${TM_UPDATE_ODS_AND_SHAPE_LIB}" == "ON" ]; then
+        pushd /main_checkout/torch-mlir
+        ./build_tools/update_torch_ods.sh
+        ./build_tools/update_shape_lib.sh
+        popd
+      fi
      if [ "${TM_SKIP_TESTS}" == "OFF" ]; then
        test_in_tree;
      fi
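
Taken together, these changes introduce a `TM_UPDATE_ODS_AND_SHAPE_LIB` knob that regenerates the ODS and shape-library files as part of an in-tree build. Reproducing locally what the Roll PyTorch workflow now invokes is just a matter of setting the same variables; a sketch (values copied from the workflow above; the branch and release values are illustrative placeholders):

```bash
TM_PACKAGES="in-tree" TM_USE_PYTORCH_BINARY="OFF" \
TORCH_MLIR_SRC_PYTORCH_BRANCH="<pytorch-commit-hash>" \
TORCH_MLIR_SRC_PYTORCH_RELEASE="<release>" \
TM_UPDATE_ODS_AND_SHAPE_LIB="ON" \
./build_tools/python_deploy/build_linux_packages.sh
```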

@@ -28,23 +28,25 @@ Two setups are possible to build: in-tree and out-of-tree. The in-tree setup is

 ### Building torch-mlir in-tree

-The following command generates configuration files to build the project *in-tree*, that is, using llvm/llvm-project as the main build. This will build LLVM as well as torch-mlir and its subprojects.
+The following command generates configuration files to build the project *in-tree*, that is, using llvm/llvm-project as the main build. This will build LLVM as well as torch-mlir and its subprojects. On Windows, use the "Developer PowerShell for Visual Studio" to ensure that the compiler and linker binaries are in the `PATH` variable.

 ```shell
 cmake -GNinja -Bbuild \
   -DCMAKE_BUILD_TYPE=Release \
-  -DCMAKE_C_COMPILER=clang \
-  -DCMAKE_CXX_COMPILER=clang++ \
   -DPython3_FIND_VIRTUALENV=ONLY \
   -DLLVM_ENABLE_PROJECTS=mlir \
   -DLLVM_EXTERNAL_PROJECTS="torch-mlir;torch-mlir-dialects" \
-  -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR=`pwd` \
-  -DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR=`pwd`/externals/llvm-external-projects/torch-mlir-dialects \
+  -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$PWD" \
+  -DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR="$PWD"/externals/llvm-external-projects/torch-mlir-dialects \
   -DMLIR_ENABLE_BINDINGS_PYTHON=ON \
   -DLLVM_TARGETS_TO_BUILD=host \
   externals/llvm-project/llvm
 ```
 The following additional quality of life flags can be used to reduce build time:
+* Enabling clang on Linux
+  ```shell
+  -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++
+  ```
 * Enabling ccache:
   ```shell
   -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache

@@ -73,8 +75,6 @@ If you have built llvm-project separately in the directory `$LLVM_INSTALL_DIR`,
 ```shell
 cmake -GNinja -Bbuild \
   -DCMAKE_BUILD_TYPE=Release \
-  -DCMAKE_C_COMPILER=clang \
-  -DCMAKE_CXX_COMPILER=clang++ \
   -DPython3_FIND_VIRTUALENV=ONLY \
   -DMLIR_DIR="$LLVM_INSTALL_DIR/lib/cmake/mlir/" \
   -DLLVM_DIR="$LLVM_INSTALL_DIR/lib/cmake/llvm/" \

@@ -82,7 +82,7 @@ cmake -GNinja -Bbuild \
   -DLLVM_TARGETS_TO_BUILD=host \
   .
 ```
-The same QoL CMake flags can be used to enable ccache and lld. Be sure to have built LLVM with `-DLLVM_ENABLE_PROJECTS=mlir`.
+The same QoL CMake flags can be used to enable clang, ccache, and lld. Be sure to have built LLVM with `-DLLVM_ENABLE_PROJECTS=mlir`.

 Be aware that the installed version of LLVM needs in general to match the committed version in `externals/llvm-project`. Using a different version may or may not work.

@@ -105,15 +105,25 @@ cmake --build build
 ```

 ## Setup Python Environment to export the built Python packages
+
+### Linux and macOS

 ```shell
 export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/examples
 ```

+### Windows PowerShell
+
+```shell
+$env:PYTHONPATH = "$PWD/build/tools/torch-mlir/python_packages/torch_mlir;$PWD/examples"
+```
+
 ## Testing MLIR output in various dialects

 To test the compiler's output to the different MLIR dialects, you can use the example `examples/torchscript_resnet18_all_output_types.py`.

-Make sure you have activated the virtualenv and set the `PYTHONPATH` above:
+Make sure you have activated the virtualenv and set the `PYTHONPATH` above
+(if running on Windows, modify the environment variable as shown above):
 ```shell
 source mlir_venv/bin/activate
 export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/examples
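
A quick sanity check that `PYTHONPATH` is wired up correctly before running the example — the `torch_mlir` import name comes from the package path in the docs; this two-liner is a sketch, not part of the documented flow:

```shell
python -c "import torch_mlir; print(torch_mlir.__file__)"
python examples/torchscript_resnet18_all_output_types.py
```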

@@ -463,6 +463,10 @@ TOSA_PASS_SET = {
     "ArangeStartIntModule_basic",
     "ArangeStartNegativeStepIntModule_basic",
     "ArangeZeroElementOutputModule_basic",
+    "NumToTensorIntModule_basic",
+    "ToDtypeBoolLayoutNoneStaticModule_basic",
+    "ToCopyBoolDTypeStaticModule_basic",
+    "HardTanhIntModule_basic",
 }

 LTC_XFAIL_SET = {

@@ -569,8 +573,8 @@ LTC_XFAIL_SET = {
     "IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
     "LiftFreshCopyModule_basic",
     "Matmul_dot",
-    "Matmul_matvec",
     "MulIntModule_basic",
+    "DivIntModule_basic",
     "NeFloatIntModule_basic",
     "NeIntModule_basic",
     "NewEmptyModuleDefaultDtype_basic",

@@ -632,4 +636,8 @@ LTC_XFAIL_SET = {
     "ElementwiseRemainderScalarModule_Bool_basic",
     "AtenIntTensorByteDtypeModule_basic",
     "AtenIntTensorCharDtypeModule_basic",
+    "UpSampleNearest2dDynamicFactor_basic",
+    "UpSampleNearest2dDynamicSize_basic",
+    "UpSampleNearest2dStaticFactor_basic",
+    "UpSampleNearest2dStaticSize_basic",
 }
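
For readers unfamiliar with these registries: they are plain Python sets of test names that the e2e harness consults when reporting results. A hypothetical sketch of the idea (the helper and result fields here are illustrative, not the project's actual harness API):

```python
def classify(result, xfail_set):
    """Map a raw pass/fail plus an xfail registry to a reported status."""
    if result.unique_name in xfail_set:
        # A test listed here is expected to fail; an actual pass is suspicious.
        return "XFAIL" if not result.passed else "XPASS"
    return "PASS" if result.passed else "FAIL"
```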

@@ -7,8 +7,8 @@
 //===----------------------------------------------------------------------===//

 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
-#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/Passes.h"

@@ -134,7 +134,7 @@ struct TMTensorBufferizePass
     bufferization::BufferizeTypeConverter typeConverter;

     // Mark all Standard operations legal.
-    target.addLegalDialect<arith::ArithmeticDialect, func::FuncDialect,
+    target.addLegalDialect<arith::ArithDialect, func::FuncDialect,
                            memref::MemRefDialect, tensor::TensorDialect>();

     // Mark all TMTensor operations illegal as long as they work on tensors.
@@ -7,8 +7,8 @@
 //
 //===----------------------------------------------------------------------===//

-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
-#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Math/IR/Math.h"

@@ -101,7 +101,7 @@ namespace {
 struct TMTensorToLoopsPass : public TMTensorToLoopsBase<TMTensorToLoopsPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<linalg::LinalgDialect, func::FuncDialect,
-                    mlir::arith::ArithmeticDialect, math::MathDialect,
+                    mlir::arith::ArithDialect, math::MathDialect,
                     memref::MemRefDialect, scf::SCFDialect>();
   }
@@ -1,5 +1,5 @@
 set(LIBS
-  MLIRArithmeticDialect
+  MLIRArithDialect
   MLIRDialect
   MLIRLinalgDialect
   MLIRMemRefDialect
@@ -7,12 +7,12 @@
 //
 //===----------------------------------------------------------------------===//

-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/SCF/Transforms/Passes.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/Dialect.h"

@@ -39,7 +39,7 @@ int main(int argc, char **argv) {
       // Local dialects
       mlir::torch::TMTensor::TMTensorDialect,
       // Upstream dialects
-      mlir::arith::ArithmeticDialect, mlir::linalg::LinalgDialect,
+      mlir::arith::ArithDialect, mlir::linalg::LinalgDialect,
       mlir::func::FuncDialect, mlir::memref::MemRefDialect,
       mlir::scf::SCFDialect, mlir::tensor::TensorDialect>();
@@ -1 +1 @@
-Subproject commit bebc96956b76bdbc36f1d82a788c810e5b12e2c5
+Subproject commit 6f46ff3765dcdc178b9cf52ebd8c03437806798a

@@ -1 +1 @@
-Subproject commit 7b0ecf7827e3fc07d2af90e147bcedc165bc78ac
+Subproject commit 2f7c1454bbe4c4ad0ae1c86c5539ac58b6053b6a
@@ -204,6 +204,10 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchValueTensorTypeGet(
 MLIR_CAPI_EXPORTED MlirType
 torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context);

+/// Gets the !torch.vtensor type with the tensor attribute.
+MLIR_CAPI_EXPORTED MlirType
+torchMlirTorchValueTensorTypeGetFromAttribute(MlirAttribute attr);
+
 //===----------------------------------------------------------------------===//
 // !torch.none type.
 //===----------------------------------------------------------------------===//
@@ -0,0 +1,22 @@
+//===-- torch-mlir-c/Transforms.h - C API for torch passes --------*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header declares the registration and creation method for
+// transformation passes.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TORCHMLIR_C_TRANSFORMS_H
+#define TORCHMLIR_C_TRANSFORMS_H
+
+#include "mlir-c/Support.h"
+
+#include "torch-mlir/Dialect/Torch/Transforms/Transforms.capi.h.inc"
+
+#endif // TORCHMLIR_C_TRANSFORMS_H
@@ -53,6 +53,9 @@ template <typename T>
 llvm::Optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
                                      ArrayRef<T> vec, ArrayRef<int64_t> shape);

+LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
+                                   Value src, Type destType, Value &result);
+
 // Creates a TOSA operation and performs shape inference on the individual
 // op. This allows shape inference during the framework to TOSA lowering.
 template <typename TosaOp, typename... Args>
@@ -2192,6 +2192,53 @@ def Torch_AtenBitwiseAnd_TensorOp : Torch_Op<"aten.bitwise_and_.Tensor", [
   }];
 }

+def Torch_AtenBitwiseOrTensorOp : Torch_Op<"aten.bitwise_or.Tensor", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::bitwise_or.Tensor : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$other
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenBitwiseOrTensorOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenBitwiseOrTensorOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
+def Torch_AtenBitwiseOr_TensorOp : Torch_Op<"aten.bitwise_or_.Tensor", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::bitwise_or_.Tensor : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$other
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenBitwiseOr_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenBitwiseOr_TensorOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenThresholdOp : Torch_Op<"aten.threshold", [
     AllowsTypeRefinement,
     HasValueSemantics,

@@ -3387,6 +3434,30 @@ def Torch_AtenMatmulOp : Torch_Op<"aten.matmul", [
   }];
 }

+def Torch_AtenMvOp : Torch_Op<"aten.mv", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::mv : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$vec
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenMvOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenMvOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [
     AllowsTypeRefinement,
     HasValueSemantics,

@@ -7399,6 +7470,31 @@ def Torch_AtenAsStridedScatterOp : Torch_Op<"aten.as_strided_scatter", [
   }];
 }

+def Torch_AtenUpsampleNearest2dVecOp : Torch_Op<"aten.upsample_nearest2d.vec", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::upsample_nearest2d.vec : (Tensor, int[]?, float[]?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchOptionalListOfTorchIntType:$output_size,
+    AnyTorchOptionalListOfTorchFloatType:$scale_factors
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenUpsampleNearest2dVecOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void AtenUpsampleNearest2dVecOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+}
+
 def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [
     AllowsTypeRefinement,
     HasValueSemantics,

@@ -8356,6 +8452,31 @@ def Torch_AtenMulIntOp : Torch_Op<"aten.mul.int", [
   let hasFolder = 1;
 }

+def Torch_AtenDivIntOp : Torch_Op<"aten.div.int", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::div.int : (int, int) -> (float)`";
+  let arguments = (ins
+    Torch_IntType:$a,
+    Torch_IntType:$b
+  );
+  let results = (outs
+    Torch_FloatType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenDivIntOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenDivIntOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+  let hasFolder = 1;
+}
+
 def Torch_AtenNegIntOp : Torch_Op<"aten.neg.int", [
     AllowsTypeRefinement,
     HasValueSemantics,
@@ -383,6 +383,7 @@ class ListOf<list<Type> allowedTypes, string descr> :

 def AnyTorchListOfTorchBoolType : ListOf<[Torch_BoolType], "Bool list type (bool[])">;
 def AnyTorchListOfTorchIntType : ListOf<[Torch_IntType], "Int list type (int[])">;
+def AnyTorchListOfTorchFloatType : ListOf<[Torch_FloatType], "Float list type (float[])">;
 def AnyTorchListOfTorchStringType : ListOf<[Torch_StringType], "Str list type (str[])">;
 def AnyTorchListOfTensorType:
   ListOf<[AnyTorchTensorType], "Any int list type (Tensor[])">;

@@ -390,7 +391,7 @@ def AnyTorchListOfOptionalTensorType :
   ListOf<[AnyTorchOptionalTensorType],
         "Any optional tensor list type (Tensor?[])">;
 def AnyTorchOptionalListOfTorchIntType : OptionalOf<AnyTorchListOfTorchIntType, "Optional torch int list type (int[]?)">;
-
+def AnyTorchOptionalListOfTorchFloatType : OptionalOf<AnyTorchListOfTorchFloatType, "Optional torch float list type (float[]?)">;
 // Note: TorchScript does not consider !torch.bool to be a Scalar.
 def AnyTorchScalarType :
   Type<CPred<"isValidSubtype($_self, ::mlir::torch::Torch::NumberType::get($_self.getContext()))">,
@@ -1,5 +1,7 @@
 set(LLVM_TARGET_DEFINITIONS Passes.td)
 mlir_tablegen(Passes.h.inc -gen-pass-decls)
+mlir_tablegen(Transforms.capi.h.inc -gen-pass-capi-header)
+mlir_tablegen(Transforms.capi.cpp.inc -gen-pass-capi-impl)
 add_public_tablegen_target(TorchMLIRTorchPassIncGen)

 add_mlir_doc(Passes TorchMLIRTorchTransforms ./ -gen-pass-doc)
@@ -21,6 +21,8 @@ class ModuleOp;
 namespace torch {
 namespace Torch {

+#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
+
 std::unique_ptr<OperationPass<ModuleOp>> createGlobalizeObjectGraphPass();

 std::unique_ptr<OperationPass<ModuleOp>>

@@ -109,6 +111,8 @@ std::unique_ptr<OperationPass<ModuleOp>>
 createLowerToBackendContractPass(int maxIterations, bool decompose,
                                  ArrayRef<std::string> backendLegalOps);

+std::unique_ptr<OperationPass<ModuleOp>> createVerifyBackendContractPass();
+
 StringRef getShapeLibrary();

 } // namespace Torch

@@ -116,6 +120,13 @@ StringRef getShapeLibrary();
 /// Registers all Torch transformation passes.
 void registerTorchPasses();

+//===----------------------------------------------------------------------===//
+// Pass registration
+//===----------------------------------------------------------------------===//
+
+#define GEN_PASS_REGISTRATION
+#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
+
 } // namespace torch
 } // namespace mlir
@@ -329,4 +329,16 @@ def LowerToBackendContract
   let dependentDialects = ["func::FuncDialect"];
 }

+def VerifyBackendContract
+    : Pass<"torch-verify-backend-contract", "ModuleOp"> {
+  let summary = "Check that program satisfies backend contract.";
+  let constructor =
+    "mlir::torch::Torch::createVerifyBackendContractPass()";
+  let description = [{
+    This pass performs a set of inspections to check that program satisfies backend
+    contract. In case of check failure it prints out the error message and returns
+    `signalPassFailure()` status.
+  }];
+}
+
 #endif // TORCHMLIR_TORCH_PASSES
@@ -52,6 +52,9 @@ int getTensorRank(Value tensor);

 bool isViewLikeOp(Operation *op);

+Value getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter, Location loc,
+                                        float value, Type dtype);
+
 } // namespace Torch
 } // namespace torch
 } // namespace mlir
@@ -3,6 +3,7 @@ add_mlir_public_c_api_library(TorchMLIRCAPI
   Registration.cpp
   TorchOps.cpp
   TorchTypes.cpp
+  Transforms.cpp

   ADDITIONAL_HEADER_DIRS
   ${PROJECT_SOURCE_DIR}/include/torch-mlir-c/

@@ -16,6 +17,7 @@ add_mlir_public_c_api_library(TorchMLIRCAPI
   MLIRSupport
   TorchMLIRTorchDialect
   TorchMLIRInitAll
+  TorchMLIRTorchPasses
 )

 torch_mlir_target_includes(TorchMLIRCAPI)
@@ -246,6 +246,14 @@ MlirType torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(
       Torch::ValueTensorType::getWithLeastStaticInformation(unwrap(context)));
 }

+MlirType torchMlirTorchValueTensorTypeGetFromAttribute(MlirAttribute attr) {
+  auto attrTensorType =
+      unwrap(attr).cast<TypedAttr>().getType().cast<RankedTensorType>();
+  return wrap(Torch::ValueTensorType::get(attrTensorType.getContext(),
+                                          attrTensorType.getShape(),
+                                          attrTensorType.getElementType()));
+}
+
 //===----------------------------------------------------------------------===//
 // torch.none type.
 //===----------------------------------------------------------------------===//
@@ -0,0 +1,27 @@
+//===- CAPIPasses.cpp - C API for Transformations Passes ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/CAPI/Pass.h"
+#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
+
+// Must include the declarations as they carry important visibility attributes.
+#include "torch-mlir/Dialect/Torch/Transforms/Transforms.capi.h.inc"
+
+using namespace mlir;
+using namespace mlir::torch;
+using namespace mlir::torch::Torch;
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#include "torch-mlir/Dialect/Torch/Transforms/Transforms.capi.cpp.inc"
+
+#ifdef __cplusplus
+}
+#endif
@@ -10,7 +10,7 @@
 #include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"

 #include "../PassDetail.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Math/IR/Math.h"

@@ -121,6 +121,25 @@ public:
 };
 } // namespace

+namespace {
+class ConvertAtenDivIntOp : public OpConversionPattern<AtenDivIntOp> {
+public:
+  using OpConversionPattern<AtenDivIntOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenDivIntOp op,
+                  typename OpConversionPattern<AtenDivIntOp>::OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value a =
+        convertScalarToDtype(rewriter, loc, adaptor.a(), rewriter.getF64Type());
+    Value b =
+        convertScalarToDtype(rewriter, loc, adaptor.b(), rewriter.getF64Type());
+    rewriter.replaceOpWithNewOp<arith::DivFOp>(op, a, b);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Lowers aten integer comparison ops.
 template <typename AtenOp, arith::CmpIPredicate Pred>

@@ -300,7 +319,7 @@ class ConvertTorchToArith : public ConvertTorchToArithBase<ConvertTorchToArith>
 public:
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<func::FuncDialect>();
-    registry.insert<arith::ArithmeticDialect>();
+    registry.insert<arith::ArithDialect>();
     registry.insert<tensor::TensorDialect>();
     registry.insert<cf::ControlFlowDialect>();
     registry.insert<math::MathDialect>();

@@ -311,7 +330,7 @@ public:
     MLIRContext *context = &getContext();
     ConversionTarget target(*context);
     target.addLegalDialect<Torch::TorchDialect, func::FuncDialect,
-                           arith::ArithmeticDialect, tensor::TensorDialect,
+                           arith::ArithDialect, tensor::TensorDialect,
                            cf::ControlFlowDialect, math::MathDialect>();

     TypeConverter typeConverter;

@@ -374,6 +393,8 @@ public:
     target.addIllegalOp<AtenSubFloatOp>();
     patterns.add<ConvertAtenBinaryOp<AtenSubFloatOp, arith::SubFOp>>(
         typeConverter, context);
+    target.addIllegalOp<AtenDivIntOp>();
+    patterns.add<ConvertAtenDivIntOp>(typeConverter, context);
     target.addIllegalOp<AtenDivFloatOp>();
     patterns.add<ConvertAtenBinaryOp<AtenDivFloatOp, arith::DivFOp>>(
         typeConverter, context);
@@ -16,7 +16,7 @@
 #include "../PassDetail.h"
 #include "PopulatePatterns.h"
 #include "Utils.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"

@@ -252,11 +252,11 @@ public:
         llvm::all_of(expandShape,
                      [](int64_t value) { return value == kUnknownSize; })) {

-      for (int i = 0; i < collapseShape.size(); i++) {
+      for (size_t i = 0; i < collapseShape.size(); i++) {
         collapseIndices.push_back(i);
       }

-      for (int i = 0; i < expandShape.size(); i++) {
+      for (size_t i = 0; i < expandShape.size(); i++) {
         expandIndices.push_back(i);
       }

@@ -290,8 +290,8 @@ public:
           op, "total number of elements mismatch in the expansion");
   }

-  static LogicalResult solveDynamicSize(SmallVector<int64_t> &inputShape,
-                                        SmallVector<int64_t> &outputShape) {
+  static void solveDynamicSize(SmallVector<int64_t> &inputShape,
+                               SmallVector<int64_t> &outputShape) {
     int64_t inputProduct = 1;
     int64_t outputProduct = 1;

@@ -316,7 +316,7 @@ public:
     if (inputDynamicValues + outputDynamicValues == 1) {
       if (inputDynamicValues) {
         int64_t missingValue = outputProduct / inputProduct;
-        for (int i = 0; i < inputShape.size(); i++) {
+        for (size_t i = 0; i < inputShape.size(); i++) {
           if (inputShape[i] == -1) {
             inputShape[i] = missingValue;
             break;

@@ -324,7 +324,7 @@ public:
         }
       } else {
         int64_t missingValue = inputProduct / outputProduct;
-        for (int i = 0; i < outputShape.size(); i++) {
+        for (size_t i = 0; i < outputShape.size(); i++) {
           if (outputShape[i] == -1) {
             outputShape[i] = missingValue;
             break;

@@ -332,8 +332,6 @@ public:
         }
       }
     }
-
-    return success();
   }

   LogicalResult

@@ -625,9 +623,6 @@ public:
       }
     }

-    int64_t inputCount = inputAssociations.size();
-    int64_t outputCount = outputAssociations.size();
-
     // Check if the shapes already match up to dynamic sizes. If so, we can just
     // cast as the result type because the previous loop sets up the necessary
     // dim checks in case of dynamic sizes.
@@ -12,9 +12,10 @@
 #include "../PassDetail.h"
 #include "PopulatePatterns.h"
 #include "Utils.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Matchers.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"

@@ -755,6 +756,146 @@ public:
 };
 } // namespace

+// `getScaledDims` scales the `dim` value with a scale factor `ScaleFactor`.
+// The `dim` and `scaleFactor` are assumed to be of index and float type
+// respectively. `scaledDim = int(floor(float(dim) * scaleFactor))`.
+static Value getScaledDims(OpBuilder &builder, Location loc, Value dim,
+                           Value scaleFactor) {
+
+  Value dimInt = castIndexToInt64(builder, loc, dim);
+  Value dimFp =
+      builder.create<arith::SIToFPOp>(loc, scaleFactor.getType(), dimInt);
+  Value scaleDim = builder.create<arith::MulFOp>(loc, dimFp, scaleFactor);
+  Value floorDim = builder.create<math::FloorOp>(loc, scaleDim);
+  Value scaledDimToIndex = castIntToIndex(
+      builder, loc,
+      builder.create<arith::FPToSIOp>(loc, dimInt.getType(), floorDim));
+
+  return scaledDimToIndex;
+}
+
+// `getScaleFactor` returns the scale factor from input to output dimension.
+// The `dim` and `scaledDim` are assumed to be of index and int64 type
+// respectively. scale_factor = (scaled_dim // dim).
+static Value getScaleFactor(OpBuilder &builder, Location loc, Value dim,
+                            Value scaledDim) {
+  Value dimInt = castIndexToInt64(builder, loc, dim);
+  Value scaleFactorInt =
+      builder.create<arith::CeilDivSIOp>(loc, scaledDim, dimInt);
+  return scaleFactorInt;
+}
+
+// N, C, H, W = input_tensor.shape
+// N, C, H_scaled, W_scaled = out_tensor.shape
+// H_factor, W_factor = H_scaled/H, W_scaled/W
+
+// for i in range(N):
+//   for j in range(C):
+//     for k in range(H_scaled):
+//       for l in range(W_scaled):
+//         out_tensor[i, j, k, l] = input[i, j, k//H_factor, l//W_factor]
+
+namespace {
+class ConvertAtenUpsampleNearest2dVecOp
+    : public OpConversionPattern<AtenUpsampleNearest2dVecOp> {
+
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenUpsampleNearest2dVecOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Location loc = op->getLoc();
+    Value input = adaptor.input();
+
+    Type resultType = getTypeConverter()->convertType(op.getResult().getType());
+    auto inputType = input.getType().cast<RankedTensorType>();
+    auto inputRank = inputType.getRank();
+    Type elementType = inputType.getElementType();
+
+    SmallVector<Value> dims = getTensorSizes(rewriter, loc, input);
+    SmallVector<Value, 2> scaleFactorsInt;
+
+    // The dimension at which the scaling starts.
+    unsigned hDimOffset = 2;
+
+    if (!adaptor.scale_factors().getType().isa<Torch::NoneType>()) {
+      SmallVector<Value, 2> scaleFactorsTorchFloat;
+      if (!getListConstructElements(op.scale_factors(), scaleFactorsTorchFloat))
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: the scale_factors is not constructed from "
+                "ListConstruct");
+      SmallVector<Value, 2> scaleFactorsFloatValues;
+      scaleFactorsFloatValues = getTypeConvertedValues(
+          rewriter, loc, getTypeConverter(), scaleFactorsTorchFloat);
+      // Convert float values to int values.
+      // int_value = (int64_t)ceil(float_value)
+      for (auto floatValue : scaleFactorsFloatValues) {
+        Value ceilVal = rewriter.create<math::CeilOp>(loc, floatValue);
+        Value intVal = rewriter.create<arith::FPToSIOp>(
+            loc, rewriter.getI64Type(), ceilVal);
+        scaleFactorsInt.push_back(intVal);
+      }
+
+      for (unsigned i = 0; i < scaleFactorsFloatValues.size(); i++)
+        dims[hDimOffset + i] = getScaledDims(
+            rewriter, loc, dims[hDimOffset + i], scaleFactorsFloatValues[i]);
+
+    } else {
+
+      SmallVector<Value, 2> outputSizeTorchInt;
+      if (!getListConstructElements(op.output_size(), outputSizeTorchInt))
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: the output_size is not constructed from "
+                "ListConstruct");
+      SmallVector<Value, 2> outputSizeIntValues;
+      outputSizeIntValues = getTypeConvertedValues(
+          rewriter, loc, getTypeConverter(), outputSizeTorchInt);
+
+      for (unsigned i = 0; i < outputSizeTorchInt.size(); i++) {
+        auto scaleFactorVal = getScaleFactor(
+            rewriter, loc, dims[hDimOffset + i], outputSizeIntValues[i]);
+        scaleFactorsInt.push_back(scaleFactorVal);
+        dims[hDimOffset + i] =
+            castIntToIndex(rewriter, loc, outputSizeIntValues[i]);
+      }
+    }
+
+    Value outTensor =
+        rewriter.create<linalg::InitTensorOp>(loc, dims, elementType);
+
+    AffineMap idMap = rewriter.getMultiDimIdentityMap(inputRank);
+    SmallVector<StringRef> iteratorTypes(inputRank,
+                                         getParallelIteratorTypeName());
+
+    Value finalRes =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, outTensor.getType(), ValueRange{}, outTensor,
+                /*indexingMaps=*/idMap,
+                /*iteratorTypes=*/iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  SmallVector<Value> indices;
+                  for (unsigned i = 0; i < inputRank; i++)
+                    indices.push_back(b.create<linalg::IndexOp>(loc, i));
+
+                  for (unsigned i = 0; i < (inputRank - hDimOffset); i++)
+                    indices[i + hDimOffset] = b.create<arith::FloorDivSIOp>(
+                        loc, indices[i + hDimOffset],
+                        castIntToIndex(rewriter, loc, scaleFactorsInt[i]));
+
+                  Value retVal =
+                      b.create<tensor::ExtractOp>(loc, input, indices);
+                  b.create<linalg::YieldOp>(loc, retVal);
+                })
+            .getResult(0);
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, finalRes);
+    return success();
+  }
+};
+} // namespace
+
 void mlir::torch::torch_to_linalg::
     populateIndirectDataMovementPatternsAndLegality(
         TypeConverter &typeConverter, RewritePatternSet &patterns,

@@ -770,4 +911,6 @@ void mlir::torch::torch_to_linalg::
   patterns.add<ConvertAtenIndexTensorOp>(typeConverter, context);
   target.addIllegalOp<AtenEmbeddingBagPaddingIdxOp>();
   patterns.add<ConvertAtenEmbeddingBagPaddingIdxOp>(typeConverter, context);
+  target.addIllegalOp<AtenUpsampleNearest2dVecOp>();
+  patterns.add<ConvertAtenUpsampleNearest2dVecOp>(typeConverter, context);
 }
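
The loop nest sketched in the comment block of the new lowering above is easiest to sanity-check in plain NumPy; here is a reference sketch of the nearest-neighbor semantics the pattern implements (integer scale factors assumed, as in the generated IR):

```python
import numpy as np

def upsample_nearest2d(x, h_factor, w_factor):
    """out[n, c, k, l] = x[n, c, k // h_factor, l // w_factor]"""
    n, c, h, w = x.shape
    out = np.empty((n, c, h * h_factor, w * w_factor), dtype=x.dtype)
    for k in range(h * h_factor):
        for l in range(w * w_factor):
            # Each output pixel reads the nearest (floor-divided) input pixel.
            out[:, :, k, l] = x[:, :, k // h_factor, l // w_factor]
    return out
```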

@@ -12,7 +12,7 @@
 #include "../PassDetail.h"
 #include "PopulatePatterns.h"
 #include "Utils.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -12,7 +12,7 @@
 #include "../PassDetail.h"
 #include "PopulatePatterns.h"
 #include "Utils.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -12,7 +12,7 @@
 #include "../PassDetail.h"
 #include "PopulatePatterns.h"
 #include "Utils.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -12,7 +12,7 @@
 #include "../PassDetail.h"
 #include "PopulatePatterns.h"
 #include "Utils.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Math/IR/Math.h"

@@ -152,12 +152,10 @@ public:
             nestedLoc, oldIndex.getType(),
             rewriter.create<linalg::IndexOp>(loc, dim));

-        Value predicate;
-        if (inElementType.isa<mlir::FloatType>())
-          predicate = rewriter.create<arith::CmpFOp>(
-              nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
-        auto resultMax = rewriter.create<arith::SelectOp>(
-            nestedLoc, predicate, newValue, oldValue);
+        auto resultMax = rewriter.create<arith::MaxFOp>(
+            nestedLoc, newValue, oldValue);
+        Value predicate = rewriter.create<arith::CmpFOp>(
+            nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
         auto resultIndex = rewriter.create<arith::SelectOp>(
             nestedLoc, predicate, newIndex, oldIndex);
         nestedBuilder.create<linalg::YieldOp>(
@@ -12,7 +12,7 @@
 #include "../PassDetail.h"
 #include "PopulatePatterns.h"
 #include "Utils.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Math/IR/Math.h"
@@ -11,7 +11,7 @@

 #include "../PassDetail.h"
 #include "PopulatePatterns.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -11,7 +11,7 @@

 #include "../PassDetail.h"
 #include "PopulatePatterns.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"

@@ -43,7 +43,7 @@ public:
     registry.insert<math::MathDialect>();
     registry.insert<func::FuncDialect>();
     registry.insert<tensor::TensorDialect>();
-    registry.insert<arith::ArithmeticDialect>();
+    registry.insert<arith::ArithDialect>();
     registry.insert<cf::ControlFlowDialect>();
     TorchConversion::getBackendTypeConversionDependentDialects(registry);
   }

@@ -53,7 +53,7 @@ public:
     ConversionTarget target(*context);
     target.addLegalDialect<linalg::LinalgDialect, func::FuncDialect,
                            cf::ControlFlowDialect, math::MathDialect,
-                           tensor::TensorDialect, arith::ArithmeticDialect>();
+                           tensor::TensorDialect, arith::ArithDialect>();
    target.addLegalOp<TorchConversion::GetNextSeedOp>();

    TypeConverter typeConverter;
@@ -12,7 +12,7 @@
 #include "../PassDetail.h"
 #include "PopulatePatterns.h"
 #include "Utils.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Math/IR/Math.h"

@@ -199,6 +199,22 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
       Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
       return b.create<arith::AndIOp>(loc, lhs, rhs);
     }
+    if (auto bitwiseOrTensor = dyn_cast<AtenBitwiseOrTensorOp>(op)) {
+      if (bitwiseOrTensor.getType()
+              .cast<ValueTensorType>()
+              .getDtype()
+              .isa<mlir::FloatType>()) {
+        bitwiseOrTensor.emitError(
+            "Bitwise_Or does not support floating point dtype");
+        return nullptr;
+      }
+      Type dtype = converter->convertType(bitwiseOrTensor.getType())
+                       .cast<RankedTensorType>()
+                       .getElementType();
+      Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
+      Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
+      return b.create<arith::OrIOp>(loc, lhs, rhs);
+    }
     if (auto logicalOr = dyn_cast<AtenLogicalOrOp>(op)) {
       MLIRContext *context = op->getContext();
       Type floatDtype = mlir::FloatType::getF64(context);

@@ -1006,12 +1022,12 @@ public:
             AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp,
             AtenPowTensorTensorOp, AtenLog2Op, AtenLog1pOp, AtenRsqrtOp,
             AtenDivScalarOp, AtenRemainderScalarOp, AtenAbsOp,
-            AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
-            AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
-            AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp,
-            AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
-            AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
-            AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
+            AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp,
+            AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
+            AtenLeScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp,
+            AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp,
+            AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
+            AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
             AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp,
             AtenBitwiseNotOp>(op))
     return rewriter.notifyMatchFailure(op, "not a supported elementwise op");

@@ -1483,10 +1499,10 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
       AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp,
       AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog1pOp,
       AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
-      AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
-      AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp,
-      AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp,
-      AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
+      AtenBitwiseOrTensorOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp,
+      AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
+      AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
+      AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
       AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp,
       AtenRemainderScalarOp, AtenBitwiseNotOp>();
   patterns.add<ConvertElementwiseOp>(typeConverter, context);
@@ -11,7 +11,7 @@

 #include "../PassDetail.h"
 #include "PopulatePatterns.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -14,7 +14,7 @@
 #include "./PopulatePatterns.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "mlir-hlo/utils/hlo_utils.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "stablehlo/dialect/ChloOps.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
@@ -13,7 +13,7 @@
 #include "./MhloLegalizeUtils.h"
 #include "./PopulatePatterns.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
@@ -13,7 +13,7 @@
 #include "./MhloLegalizeUtils.h"
 #include "./PopulatePatterns.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "stablehlo/dialect/ChloOps.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
@@ -9,7 +9,7 @@

 #include "./MhloLegalizeUtils.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@@ -13,7 +13,7 @@
 #include "./MhloLegalizeUtils.h"
 #include "./PopulatePatterns.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "stablehlo/dialect/ChloOps.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
@@ -13,7 +13,7 @@
 #include "./MhloLegalizeUtils.h"
 #include "./PopulatePatterns.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
@@ -12,7 +12,7 @@
 #include "../PassDetail.h"
 #include "./PopulatePatterns.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Traits.h"
 #include "mlir/IR/Matchers.h"

@@ -42,14 +42,14 @@ public:
     registry.insert<chlo::ChloDialect>();
     registry.insert<mhlo::MhloDialect>();
     registry.insert<tensor::TensorDialect>();
-    registry.insert<arith::ArithmeticDialect>();
+    registry.insert<arith::ArithDialect>();
     TorchConversion::getBackendTypeConversionDependentDialects(registry);
   }
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ConversionTarget target(*context);
     target.addLegalDialect<chlo::ChloDialect, mhlo::MhloDialect,
-                           tensor::TensorDialect, arith::ArithmeticDialect>();
+                           tensor::TensorDialect, arith::ArithDialect>();

     TypeConverter typeConverter;
     typeConverter.addConversion([](Type type) { return type; });
@@ -13,7 +13,7 @@
 #include "./MhloLegalizeUtils.h"
 #include "./PopulatePatterns.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
@@ -10,7 +10,7 @@
 #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"

 #include "../PassDetail.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/DialectConversion.h"

@@ -321,7 +321,7 @@ namespace {
 class ConvertTorchToSCF : public ConvertTorchToSCFBase<ConvertTorchToSCF> {
 public:
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<scf::SCFDialect, arith::ArithmeticDialect>();
+    registry.insert<scf::SCFDialect, arith::ArithDialect>();
     TorchConversion::getBackendTypeConversionDependentDialects(registry);
   }

@@ -329,7 +329,7 @@ public:
     MLIRContext *context = &getContext();
     ConversionTarget target(*context);
     target.addLegalDialect<Torch::TorchDialect, scf::SCFDialect,
-                           arith::ArithmeticDialect>();
+                           arith::ArithDialect>();

     TypeConverter typeConverter;
     typeConverter.addConversion([](Type type) { return type; });
@@ -10,7 +10,7 @@
 #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"

 #include "../PassDetail.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"

@@ -614,7 +614,7 @@ public:
     registry.insert<linalg::LinalgDialect>();
     registry.insert<func::FuncDialect>();
     registry.insert<tensor::TensorDialect>();
-    registry.insert<arith::ArithmeticDialect>();
+    registry.insert<arith::ArithDialect>();
     registry.insert<TMTensorDialect>();
     TorchConversion::getBackendTypeConversionDependentDialects(registry);
   }

@@ -623,7 +623,7 @@ public:
     MLIRContext *context = &getContext();
     ConversionTarget target(*context);
     target.addLegalDialect<linalg::LinalgDialect, func::FuncDialect,
-                           tensor::TensorDialect, arith::ArithmeticDialect,
+                           tensor::TensorDialect, arith::ArithDialect,
                            Torch::TorchDialect, TMTensorDialect>();

     TypeConverter typeConverter;
@@ -12,7 +12,7 @@
 #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
 
 #include "../PassDetail.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Traits.h"

@@ -3055,6 +3055,133 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
   return success();
 }
 
+template <>
+LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
+    PrimNumToTensorScalarOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  TypeConverter *typeConverter = this->getTypeConverter();
+  RankedTensorType resultType =
+      typeConverter->convertType(op->getResult(0).getType())
+          .cast<RankedTensorType>();
+
+  // Only supports integer operand type, because for the floating point operand
+  // type result tensor has to be of type `f64` which is not supported in the
+  // tosa.
+  int64_t initValue;
+  if (!matchPattern(op.a(), m_TorchConstantInt(&initValue)))
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: input should be a torch constant int");
+
+  DenseElementsAttr constAttr = DenseElementsAttr::get(resultType, {initValue});
+  rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, resultType, constAttr);
+  return success();
+}
+
+template <>
+LogicalResult ConvertAtenOp<ValsemVariantAtenCopyOp>::matchAndRewrite(
+    ValsemVariantAtenCopyOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  // Not a tensor type.
+  auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
+  auto srcType = adaptor.src().getType().dyn_cast<TensorType>();
+  if (!selfType || !selfType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        op, "Only tensor types with static shape are supported");
+
+  if (!srcType || !srcType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        op, "Only tensor types with static shape are supported");
+
+  // The non_blocking should be a constant `False`.
+  bool nonBlocking;
+  if (!matchPattern(op.non_blocking(), m_TorchConstantBool(&nonBlocking))) {
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: non_blocking must be a constant");
+  } else if (nonBlocking) {
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: non_blocking is expected to be false");
+  }
+
+  SmallVector<int64_t> selfShape(selfType.getShape());
+  SmallVector<int64_t> srcShape(srcType.getShape());
+
+  if (llvm::equal(selfShape, srcShape) || selfShape.size() == 0) {
+    // If we reach here, then it means the given case is handled by implicit
+    // broadcasting done by tosa.
+    Value result;
+    if (failed(tosa::tosaCastTensorToType(
+            rewriter, op, adaptor.src(),
+            getTypeConverter()->convertType(op.getType()), result)))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: cast to result type not supported");
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+  return rewriter.notifyMatchFailure(
+      op, "unimplemented: valsem.aten.copy op not supported for this case.");
+}
+
+// Legalizes the torch.aten.to.dtype op
+template <>
+LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
+    AtenToDtypeOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  // Not a tensor type.
+  auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
+  if (!selfType || !selfType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        op, "Only tensor types with static shape are supported");
+
+  // The non_blocking arg should be a constant `False`.
+  bool nonBlocking;
+  if (!matchPattern(op.non_blocking(), m_TorchConstantBool(&nonBlocking))) {
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: non_blocking arg must be a constant");
+  } else if (nonBlocking) {
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: non_blocking arg is expected to be false");
+  }
+
+  // The copy arg should be a constant `False`.
+  bool copy;
+  if (!matchPattern(op.copy(), m_TorchConstantBool(&copy))) {
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: copy arg must be a constant");
+  } else if (copy) {
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: copy arg is expected to be false");
+  }
+
+  // Only `none`, `contiguous` and `preserve` memory_format is supported.
+  if (!op.memory_format().getType().isa<Torch::NoneType>()) {
+    int64_t memoryFormat;
+    if (!matchPattern(op.memory_format(), m_TorchConstantInt(&memoryFormat)))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: the memory format should be specified in "
+              "an integer constant");
+    if (memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
+        memoryFormat != torch_upstream::MemoryFormat::Preserve)
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: only none, contiguous and preserve "
+              "memory_format is supported");
+  }
+
+  auto resultTy = getTypeConverter()
+                      ->convertType(op.getResult().getType())
+                      .cast<RankedTensorType>();
+
+  Value result;
+  if (failed(tosa::tosaCastTensorToType(rewriter, op, adaptor.self(), resultTy,
+                                        result)))
+    return rewriter.notifyMatchFailure(op, "conversion to result type failed");
+
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
 template <typename AtenOpT, typename TosaOpT>
 class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
 public:
@@ -3511,7 +3638,7 @@ public:
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<tosa::TosaDialect>();
     registry.insert<tensor::TensorDialect>();
-    registry.insert<arith::ArithmeticDialect>();
+    registry.insert<arith::ArithDialect>();
     TorchConversion::getBackendTypeConversionDependentDialects(registry);
   }

@@ -3519,7 +3646,7 @@ public:
     MLIRContext *context = &getContext();
     ConversionTarget target(*context);
     target.addLegalDialect<tosa::TosaDialect, tensor::TensorDialect,
-                           arith::ArithmeticDialect>();
+                           arith::ArithDialect>();
 
     TypeConverter typeConverter;
     typeConverter.addConversion([](Type type) { return type; });

@@ -3704,6 +3831,9 @@ public:
     INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
     INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
     INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);
+    INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
+    INSERT_ATENOP_PATTERN(ValsemVariantAtenCopyOp);
+    INSERT_ATENOP_PATTERN(AtenToDtypeOp);
 #undef INSERT_ATENOP_PATTERN
 
 #define INSERT_CLONE_ATENOP_PATTERN(AtenOp)                                    \
@@ -221,6 +221,64 @@ llvm::Optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
   return const_op.getResult();
 }
 
+static LogicalResult checkValidityOfCast(Type src, Type dest) {
+  if ((src.isInteger(64) && dest.isInteger(32)) ||
+      (src.isInteger(32) && dest.isInteger(64)) ||
+      (src.isInteger(64) && dest.isInteger(1)) ||
+      (src.isInteger(32) && dest.isInteger(1)) ||
+      (src.isInteger(8) && dest.isInteger(1)) ||
+      (src.isF32() && dest.isInteger(1))) {
+    return success();
+  }
+  return failure();
+}
+
+// Template specialization for float
+LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
+                                   Value src, Type destType, Value &result) {
+
+  Type srcElemTy = src.getType().dyn_cast<TensorType>().getElementType();
+  Type destElemTy = destType.dyn_cast<TensorType>().getElementType();
+
+  if (failed(checkValidityOfCast(srcElemTy, destElemTy)))
+    return rewriter.notifyMatchFailure(
+        op, "casting to result dtype is invalid or unsupported");
+
+  if (destElemTy.isInteger(1)) {
+    auto srcType = src.getType().dyn_cast<TensorType>();
+    SmallVector<int64_t> srcShape(srcType.getShape());
+    uint64_t num_total_elements = 1;
+    for (int64_t a : srcShape)
+      num_total_elements *= a;
+
+    llvm::Optional<Value> constOp;
+    if (srcElemTy.isInteger(64)) {
+      SmallVector<int64_t> values(num_total_elements, 0);
+      constOp =
+          tosa::getConstTensor<int64_t>(rewriter, op, values, srcShape).value();
+    } else if (srcElemTy.isInteger(32)) {
+      SmallVector<int32_t> values(num_total_elements, 0);
+      constOp =
+          tosa::getConstTensor<int32_t>(rewriter, op, values, srcShape).value();
+    } else if (srcElemTy.isF32()) {
+      SmallVector<float> values(num_total_elements, 0.0);
+      constOp =
+          tosa::getConstTensor<float>(rewriter, op, values, srcShape).value();
+    } else if (srcElemTy.isInteger(8)) {
+      SmallVector<int8_t> values(num_total_elements, 0);
+      constOp =
+          tosa::getConstTensor<int8_t>(rewriter, op, values, srcShape).value();
+    }
+    Value equalToZero = rewriter.create<tosa::EqualOp>(op->getLoc(), destType,
+                                                       src, constOp.value());
+    result = rewriter.create<tosa::LogicalNotOp>(op->getLoc(), destType,
+                                                 equalToZero);
+  } else {
+    result = rewriter.create<tosa::CastOp>(op->getLoc(), destType, src);
+  }
+  return success();
+}
+
 // Template instantiation
 template llvm::Optional<Value> getConstTensor<int32_t>(PatternRewriter &,
                                                        Operation *,
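The i1 branch of tosaCastTensorToType encodes PyTorch's truthiness rule: casting to bool yields true exactly when an element is nonzero, hence the equal-against-zero compare followed by logical-not. A scalar analogue of that sequence, as a minimal sketch (not part of the diff):

    #include <cstdint>

    // Scalar analogue of the tensor-level bool cast above:
    // x -> !(x == 0), i.e. any nonzero value becomes true.
    bool castToBool(int64_t x) { return !(x == 0); }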
@@ -5,7 +5,7 @@ add_mlir_conversion_library(TorchMLIRConversionUtils
   ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/Utils
 
   LINK_LIBS PUBLIC
-  MLIRArithmeticDialect
+  MLIRArithDialect
   MLIRLinalgDialect
   TorchMLIRTorchDialect
 )

@@ -9,7 +9,7 @@
 
 #include "torch-mlir/Conversion/Utils/Utils.h"
 
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -2175,6 +2175,20 @@ OpFoldResult AtenDivFloatOp::fold(ArrayRef<Attribute> operands) {
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// AtenDivIntOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenDivIntOp::fold(ArrayRef<Attribute> operands) {
+  int64_t lhs, rhs;
+  bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs));
+  bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs));
+  if (lConstant && rConstant)
+    return getF64FloatAttr(getContext(), double(lhs) / rhs);
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // AtenCeilFloatOp
 //===----------------------------------------------------------------------===//
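Because aten.div.int produces a Torch float, the new fold promotes before dividing rather than performing C-style integer division. A scalar sketch of the folded computation (illustrative only):

    #include <cstdint>

    // aten.div.int is true division: div(2, 4) folds to 0.5, whereas C's
    // integer division 2 / 4 would yield 0. Promote first, then divide.
    double foldDivInt(int64_t lhs, int64_t rhs) {
      return static_cast<double>(lhs) / static_cast<double>(rhs);
    }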
@@ -2185,8 +2199,6 @@
   return nullptr;
 }
 
-//===----------------------------------------------------------------------===//
-
 //===----------------------------------------------------------------------===//
 // PrimMaxIntOp
 //===----------------------------------------------------------------------===//
@@ -631,6 +631,21 @@ public:
 };
 } // namespace
 
+// Decompose aten.mv into: aten.matmul.
+namespace {
+class DecomposeAtenMvOp : public OpRewritePattern<AtenMvOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenMvOp op,
+                                PatternRewriter &rewriter) const override {
+    Value lhs = op.self();
+    Value rhs = op.vec();
+    rewriter.replaceOpWithNewOp<AtenMatmulOp>(op, op.getType(), lhs, rhs);
+    return success();
+  }
+};
+} // namespace
+
 // ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6)
 static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
                              Value input) {
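The one-line rewrite is sound because aten.matmul already covers the 2-D x 1-D case with exactly aten.mv's semantics. A shape-level sketch of the contract being relied on (hypothetical helper, not in the patch):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // matmul on a 2-D lhs and 1-D rhs contracts the inner dimension and
    // returns a 1-D result: [n, m] x [m] -> [n], which is mv's signature.
    std::vector<int64_t> mvResultShape(const std::vector<int64_t> &mat,
                                       const std::vector<int64_t> &vec) {
      assert(mat.size() == 2 && vec.size() == 1 && mat[1] == vec[0]);
      return {mat[0]};
    }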
@@ -2185,8 +2200,9 @@ public:
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(Aten_ToCopyOp op,
                                 PatternRewriter &rewriter) const override {
-    Value zero = rewriter.create<ConstantFloatOp>(
-        op.getLoc(), rewriter.getF64FloatAttr(0.0));
+    Type resultDtype = op.getType().cast<BaseTensorType>().getDtype();
+    Value zero = getConstantWithGivenDtypeAndValue(rewriter, op.getLoc(), 0.0,
+                                                   resultDtype);
     Value emptyTensor = rewriter.create<AtenFullLikeOp>(
         op.getLoc(), op.getType(), op.self(), zero, op.dtype(), op.layout(),
         op.device(), op.pin_memory(), op.memory_format());

@@ -2859,6 +2875,8 @@ public:
     patterns.add<DecomposeAtenSelectIntOp>(context);
     target.addIllegalOp<AtenSelectIntOp>();
     patterns.add<DecomposeAtenMatmulOp>(context);
+    target.addIllegalOp<AtenMvOp>();
+    patterns.add<DecomposeAtenMvOp>(context);
     target.addIllegalOp<AtenTOp>();
     patterns.add<DecomposeAtenTOp>(context);
     patterns.add<DecomposeAten_LogSoftmaxBackwardDataOp>(context);
@@ -242,6 +242,16 @@ public:
     });
   }
 };
+
+class VerifyBackendContractPass
+    : public VerifyBackendContractBase<VerifyBackendContractPass> {
+public:
+  void runOnOperation() override {
+    if (!satisfiesBackendContract(getOperation(), /*actuallyEmitDiagnostics=*/true)) {
+      return signalPassFailure();
+    }
+  }
+};
 } // namespace
 
 std::unique_ptr<OperationPass<ModuleOp>>

@@ -250,3 +260,8 @@ mlir::torch::Torch::createLowerToBackendContractPass(
   return std::make_unique<LowerToBackendContractPass>(maxIterations, decompose,
                                                       backendLegalOps);
 }
+
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::torch::Torch::createVerifyBackendContractPass() {
+  return std::make_unique<VerifyBackendContractPass>();
+}
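The new pass is a check-only counterpart to LowerToBackendContractPass: it runs satisfiesBackendContract in diagnostic mode and signals failure without mutating the IR. A hedged sketch of how a client might append it to a pass manager (the wrapper function is illustrative, not part of this commit):

    #include "mlir/Pass/PassManager.h"
    #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"

    // Append the verification-only pass; it emits diagnostics and fails the
    // pipeline if the module does not satisfy the Torch backend contract.
    void addBackendContractCheck(mlir::OpPassManager &pm) {
      pm.addPass(mlir::torch::Torch::createVerifyBackendContractPass());
    }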
@@ -11,17 +11,8 @@
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/Passes.h"
 
-//===----------------------------------------------------------------------===//
-// Pass registration
-//===----------------------------------------------------------------------===//
-
-namespace {
-#define GEN_PASS_REGISTRATION
-#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
-} // end namespace
-
 void mlir::torch::registerTorchPasses() {
-  ::registerPasses();
+  mlir::torch::registerPasses();
   mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
       "torchscript-module-to-torch-backend-pipeline",
       "Pipeline lowering TorchScript object graph IR to Torch backend form.",
@@ -701,7 +701,7 @@ void TypeAnalysis::visitOperation(Operation *op,
           AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
           PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp, AtenMaskedFillTensorOp,
           AtenRollOp, AtenPowTensorTensorOp, AtenLiftFreshCopyOp,
-          AtenIndexTensorHackedTwinOp>(op)) {
+          AtenIndexTensorHackedTwinOp, AtenUpsampleNearest2dVecOp>(op)) {
     return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
   }
 

@@ -754,7 +754,7 @@ void TypeAnalysis::visitOperation(Operation *op,
 
   // Promote the two dtypes assuming non-zero rank.
   if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp,
-          Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp,
+          Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp, AtenMvOp,
           AtenConvolutionOverrideableOp, AtenConvTranspose2dInputOp>(op)) {
     auto knowledge =
         ValueKnowledge::getTensorPessimisticValueState(op->getContext());

@@ -767,8 +767,8 @@ void TypeAnalysis::visitOperation(Operation *op,
   // Promote the two dtypes assuming possibly-zero rank.
   if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp, AtenDivTensorOp,
           AtenDivTensorModeOp, Aten__And__TensorOp, AtenMinimumOp,
-          AtenMaximumOp, AtenBitwiseAndTensorOp, AtenThresholdBackwardOp,
-          AtenFloorDivideOp>(op)) {
+          AtenMaximumOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp,
+          AtenThresholdBackwardOp, AtenFloorDivideOp>(op)) {
     auto knowledge =
         ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     knowledge.dtype = getPromotedResultType(
(File diff suppressed because it is too large.)
@@ -163,3 +163,18 @@ bool Torch::isViewLikeOp(Operation *op) {
       TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
       AtenNarrowOp, AtenToDeviceOp>(op);
 }
+
+Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
+                                               Location loc, float value,
+                                               Type dtype) {
+  // Creating constants satisfying backend contract.
+  if (dtype.isInteger(64) || dtype.isInteger(32) || dtype.isInteger(8) ||
+      dtype.isInteger(1))
+    return rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr((int64_t)value));
+  if (dtype.isF64() || dtype.isF32() || dtype.isF16() || dtype.isBF16())
+    return rewriter.create<ConstantFloatOp>(loc,
+                                            rewriter.getF64FloatAttr(value));
+  llvm::report_fatal_error(
+      "unhandled type for getConstantWithGivenDtypeAndValue");
+}
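The helper exists because the backend contract expects scalar constants whose Torch kind matches the tensor dtype: integer dtypes get torch.constant.int, floating-point dtypes get torch.constant.float. Note the (int64_t)value cast on the integer path, which truncates toward zero; a one-line sketch of that behavior (illustrative only):

    #include <cstdint>

    // On the integer path the float payload is narrowed with a C-style cast,
    // i.e. truncation toward zero: 0.9 -> 0, -1.7 -> -1.
    int64_t integerPayload(float value) { return (int64_t)value; }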
@@ -8,13 +8,13 @@
 //===----------------------------------------------------------------------===//
 
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/Transforms/InliningUtils.h"
-#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
+#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 

@@ -34,13 +34,13 @@ using namespace mlir::tosa;
 // Pass registration
 //===----------------------------------------------------------------------===//
 
-namespace {
+namespace reg {
 #define GEN_PASS_REGISTRATION
 #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc"
-} // end namespace
+} // end namespace reg
 
 void mlir::torch::registerTorchConversionPasses() {
-  ::registerPasses();
+  reg::registerPasses();
   mlir::PassPipelineRegistration<>(
       "torch-backend-to-linalg-on-tensors-backend-pipeline",
       "Pipeline lowering torch backend contract to linalg-on-tensors backend "
@@ -10,7 +10,7 @@
 #include "PassDetail.h"
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"

@@ -71,8 +71,7 @@ class VerifyLinalgOnTensorsBackendContractPass
     // Basic scalar operations.
     target.addDynamicallyLegalDialect<func::FuncDialect>(isLegalScalarOp);
     target.addDynamicallyLegalDialect<math::MathDialect>(isLegalScalarOp);
-    target.addDynamicallyLegalDialect<arith::ArithmeticDialect>(
-        isLegalScalarOp);
+    target.addDynamicallyLegalDialect<arith::ArithDialect>(isLegalScalarOp);
 
     // Tensor operations should go through linalg and the tensor dialect.
     target.addDynamicallyLegalDialect<linalg::LinalgDialect>(opHasLegalTypes);
@@ -10,7 +10,7 @@
 #include "PassDetail.h"
 
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"

@@ -53,7 +53,7 @@ class VerifyMhloBackendContractPass
     target.addLegalDialect<mhlo::MhloDialect>();
     target.addLegalDialect<chlo::ChloDialect>();
     target.addLegalDialect<tensor::TensorDialect>();
-    target.addLegalDialect<arith::ArithmeticDialect>();
+    target.addLegalDialect<arith::ArithDialect>();
 
     RewritePatternSet patterns(context);
     if (failed(applyFullConversion(module, target, std::move(patterns)))) {
@@ -9,7 +9,7 @@
 
 #include "PassDetail.h"
 
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"

@@ -15,7 +15,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "PassDetail.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"

@@ -304,7 +304,7 @@ class ExpandOpsForLLVM : public ExpandOpsForLLVMBase<ExpandOpsForLLVM> {
     ConversionTarget target(*context);
     target.addLegalDialect<func::FuncDialect>();
     target.addLegalDialect<math::MathDialect>();
-    target.addLegalDialect<arith::ArithmeticDialect>();
+    target.addLegalDialect<arith::ArithDialect>();
     target.addIllegalOp<math::TanhOp>();
     target.addIllegalOp<math::ErfOp>();
     if (failed(applyPartialConversion(func, target, std::move(patterns)))) {

@@ -352,7 +352,7 @@ class MemrefCopyOpToLinalg : public OpRewritePattern<memref::CopyOp> {
   LogicalResult matchAndRewrite(memref::CopyOp copyOp,
                                 PatternRewriter &rewriter) const override {
     Operation *linalgCopy = createLinalgCopyOp(
-        rewriter, copyOp.getLoc(), copyOp.source(), copyOp.target());
+        rewriter, copyOp.getLoc(), copyOp.getSource(), copyOp.getTarget());
     rewriter.replaceOp(copyOp, linalgCopy->getResults());
     return success();
   }
@@ -65,7 +65,7 @@ def run_pipeline_with_repro_report(module,
         {description} failed with the following diagnostics:
         {sys.stderr.getvalue()}
 
-        Error can be reproduced with:
+        For Torch-MLIR developers, the error can be reproduced with:
         $ torch-mlir-opt -pass-pipeline='{pipeline}' {filename}
         Add '{debug_options}' to get the IR dump for debugging purpose.
         """
@@ -391,7 +391,7 @@ at::Tensor LazyNativeFunctions::new_empty_strided_symint(
     c10::optional<at::Device> device,
     c10::optional<bool> pin_memory) {
   return at::functionalization::
-      functionalize_aten_op<ATEN_OP(new_empty_strided)>::call(
+      functionalize_aten_op_symint<ATEN_OP(new_empty_strided)>::call(
           self, size, stride, dtype, layout, device, pin_memory);
 }
 

@@ -400,7 +400,7 @@ at::Tensor LazyNativeFunctions::narrow_copy_symint(
     int64_t dim,
     c10::SymInt start,
     c10::SymInt length) {
-  return at::functionalization::functionalize_aten_op<ATEN_OP(
+  return at::functionalization::functionalize_aten_op_symint<ATEN_OP(
       narrow_copy)>::call(self, dim, start, length);
 }
 at::Tensor LazyNativeFunctions::pixel_shuffle(

@@ -426,7 +426,7 @@ at::Tensor LazyNativeFunctions::slice_backward_symint(
     c10::SymInt start,
     c10::SymInt end,
     c10::SymInt step) {
-  return at::functionalization::functionalize_aten_op<ATEN_OP(
+  return at::functionalization::functionalize_aten_op_symint<ATEN_OP(
       slice_backward)>::call(grad_output, input_sizes, dim, start, end, step);
 }
 at::Tensor LazyNativeFunctions::diagonal_backward(
@@ -600,6 +600,9 @@ def aten〇numpy_T(self: List[int]) -> List[int]:
 def aten〇matmul(self: List[int], other: List[int]) -> List[int]:
     return upstream_shape_functions.matmul(self, other)
 
+def aten〇mv(self: List[int], vec: List[int]) -> List[int]:
+    return upstream_shape_functions.mv(self, vec)
+
 def aten〇mm(self: List[int], mat2: List[int]) -> List[int]:
     return upstream_shape_functions.mm(self, mat2)
 

@@ -863,6 +866,9 @@ def aten〇minimum(self: List[int], other: List[int]) -> List[int]:
 def aten〇maximum(self: List[int], other: List[int]) -> List[int]:
     return upstream_shape_functions.broadcast(self, other)
 
+def aten〇bitwise_or〇Tensor(self: List[int], other: List[int]) -> List[int]:
+    return upstream_shape_functions.broadcast(self, other)
+
 def aten〇bitwise_and〇Tensor(self: List[int], other: List[int]) -> List[int]:
     return upstream_shape_functions.broadcast(self, other)
 

@@ -1195,6 +1201,9 @@ def aten〇linalg_vector_norm(self: List[int], ord: float = 2, dim: Optional[Lis
 def aten〇frobenius_norm〇dim(self: List[int], dim: List[int], keepdim: bool = False) -> List[int]:
     return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0)
 
+def aten〇upsample_nearest2d〇vec(input: List[int], output_size: Optional[List[int]], scale_factors: Optional[List[float]]) -> List[int]:
+    return upstream_shape_functions.upsample_nearest2d(input, output_size, scale_factors)
+
 # ==============================================================================
 # Shape library generator main().
 # ==============================================================================
@@ -34,6 +34,8 @@ TORCH_TYPE_TO_ODS_TYPE = {
     "bool?": "AnyTorchOptionalBoolType",
     "float": "Torch_FloatType",
     "float?": "AnyTorchOptionalFloatType",
+    "float[]": "AnyTorchListOfTorchFloatType",
+    "float[]?": "AnyTorchOptionalListOfTorchFloatType",
     "t[]": "AnyTorchListType",
     "t": "AnyTorchType",
     "t1": "AnyTorchType",

@@ -285,6 +287,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
         "aten::abs : (Tensor) -> (Tensor)",
         "aten::reciprocal : (Tensor) -> (Tensor)",
         "aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)",
+        "aten::bitwise_or.Tensor : (Tensor, Tensor) -> (Tensor)",
         "aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)",
         "aten::square : (Tensor) -> (Tensor)",
         "aten::unsqueeze : (Tensor, int) -> (Tensor)",

@@ -333,6 +336,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::mm : (Tensor, Tensor) -> (Tensor)")
     emit("aten::addmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)")
     emit("aten::matmul : (Tensor, Tensor) -> (Tensor)")
+    emit("aten::mv : (Tensor, Tensor) -> (Tensor)")
     emit(
         "aten::conv2d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)"
     )

@@ -515,6 +519,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)")
     emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)")
     emit("aten::as_strided_scatter : (Tensor, Tensor, int[], int[], int?) -> (Tensor)")
+    emit("aten::upsample_nearest2d.vec : (Tensor, int[]?, float[]?) -> (Tensor)")
 
 
     # Dict ops.

@@ -566,6 +571,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::add.int : (int, int) -> (int)", has_folder=True)
     emit("aten::sub.int : (int, int) -> (int)", has_folder=True)
     emit("aten::mul.int : (int, int) -> (int)", has_folder=True)
+    emit("aten::div.int : (int, int) -> (float)", has_folder=True)
     emit("aten::neg.int : (int) -> (int)", has_folder=True)
     emit("aten::log.int : (int) -> (float)")
     emit("aten::add.float_int : (float, int) -> (float)")
@@ -276,7 +276,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
   MlirOperation operation = createMlirOperationAtEnd(
       importBlock, "torch.prim.ListConstruct", loc,
       torchMlirTorchListTypeGet(
-          getMlirTypeFromTorchType(loc, list.elementType())),
+          getMlirTypeFromTorchType(loc, list.elementType(), importOptions)),
       elems);
   return mlirOperationGetResult(operation, 0);
 }

@@ -291,8 +291,8 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
   MlirOperation operation = createMlirOperationAtEnd(
       importBlock, "torch.prim.DictConstruct", loc,
       torchMlirTorchDictTypeGet(
-          getMlirTypeFromTorchType(loc, dict.keyType()),
-          getMlirTypeFromTorchType(loc, dict.valueType())),
+          getMlirTypeFromTorchType(loc, dict.keyType(), importOptions),
+          getMlirTypeFromTorchType(loc, dict.valueType(), importOptions)),
       keys, values);
   return mlirOperationGetResult(operation, 0);
 }

@@ -368,10 +368,20 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
   at::Tensor tensor = ivalue.toTensor().contiguous();
   MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
 
-  MlirOperation tensorOp = createMlirOperationAtEnd(
-      importBlock, "torch.tensor.literal", loc,
-      torchMlirTorchNonValueTensorTypeGetFromAttribute(denseElements),
-      toMlirNamedAttribute("value", denseElements));
+  MlirOperation tensorOp;
+
+  if (importOptions.assumeTensorsHaveValueSemantics) {
+    tensorOp = createMlirOperationAtEnd(
+        importBlock, "torch.vtensor.literal", loc,
+        torchMlirTorchValueTensorTypeGetFromAttribute(denseElements),
+        toMlirNamedAttribute("value", denseElements));
+  } else {
+    tensorOp = createMlirOperationAtEnd(
+        importBlock, "torch.tensor.literal", loc,
+        torchMlirTorchNonValueTensorTypeGetFromAttribute(denseElements),
+        toMlirNamedAttribute("value", denseElements));
+  }
+
   MlirValue tensorReprValue = mlirOperationGetResult(tensorOp, 0);
 
   // Construct the complete tensor value. This is trivial for most tensors, but

@@ -384,9 +394,16 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
   // compiler stages that are building a statically modeled quantization
   // representation will need to convert this to their representation.
   std::vector<int64_t> shape(tensor.sizes().begin(), tensor.sizes().end());
-  MlirType quantizedTensorType = torchMlirTorchNonValueTensorTypeGet(
-      context, shape.size(), shape.data(),
-      getMlirTypeForTorchScalarType(loc, tensor.scalar_type()));
+  MlirType quantizedTensorType;
+  if (importOptions.assumeTensorsHaveValueSemantics) {
+    quantizedTensorType = torchMlirTorchValueTensorTypeGet(
+        context, shape.size(), shape.data(),
+        getMlirTypeForTorchScalarType(loc, tensor.scalar_type()));
+  } else {
+    quantizedTensorType = torchMlirTorchNonValueTensorTypeGet(
+        context, shape.size(), shape.data(),
+        getMlirTypeForTorchScalarType(loc, tensor.scalar_type()));
+  }
   if (tensor.qscheme() == c10::kPerTensorAffine) {
     MlirValue qScale = importIValue(c10::IValue(tensor.q_scale()));
     MlirValue zeroPoint = importIValue(c10::IValue(tensor.q_zero_point()));

@@ -463,7 +480,7 @@ void IValueImporter::importClassType(c10::ClassType *classType) {
             "name", mlirStringAttrGet(
                         context, toMlirStringRef(classAttribute.getName()))),
         toMlirNamedAttribute("type", mlirTypeAttrGet(getMlirTypeFromTorchType(
-                                         loc, classAttribute.getType()))),
+                                         loc, classAttribute.getType(), importOptions))),
         isPrivate);
   }
 
@@ -124,10 +124,16 @@ ModuleBuilder::ModuleBuilder(pybind11::object contextObj)
 }
 
 torch::jit::StrongFunctionPtr
-ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) {
+ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function,
+                              py::object maybeImportOptions) {
+  ImportOptions importOptions;
+  if (!maybeImportOptions.is_none()) {
+    importOptions = py::cast<ImportOptions>(maybeImportOptions);
+  }
   MlirBlock block = getBodyBlock();
   MlirOperation terminator = this->terminator;
-  MlirOperation func = importJitFunctionAsFuncOp(context, function.function_);
+  MlirOperation func = importJitFunctionAsFuncOp(context, function.function_,
+      [](int) -> MlirAttribute { return {nullptr}; }, importOptions);
   mlirBlockInsertOwnedOperationBefore(block, terminator, func);
   return function;
 }

@@ -182,7 +188,8 @@ void ModuleBuilder::bind(py::module &m) {
       .def(py::init<py::object>(), py::arg("context") = py::none())
       .def_property_readonly("context", &ModuleBuilder::getContextObj)
       .def_property_readonly("module", &ModuleBuilder::getModuleObj)
-      .def("import_function", &ModuleBuilder::importFunction)
+      .def("import_function", &ModuleBuilder::importFunction, py::arg("function"),
+           py::arg("importOptions") = py::none())
       .def("import_module", &ModuleBuilder::importModule, py::arg("module"),
            py::arg("classAnnotator") = py::none(),
           py::arg("importOptions") = py::none());

@@ -39,7 +39,8 @@ public:
   // Just a bit of naming cruft.
   // Returns the same function, making it suitable as a nested decorator.
   torch::jit::StrongFunctionPtr
-  importFunction(torch::jit::StrongFunctionPtr function);
+  importFunction(torch::jit::StrongFunctionPtr function,
+                 py::object maybeImportOptions);
 
   // Imports a torch::jit::Module into the current module, using the
   // annotations, if not none, provided in `maybeClassAnnotator` which should be
@@ -198,10 +198,17 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
                                  c10::attr::value)))));
   } else if (output->type()->cast<c10::TensorType>()) {
     MlirAttribute attr = importAttribute(loc, node, c10::attr::value);
-    op = createMlirOperation(
-        "torch.tensor.literal", loc,
-        torchMlirTorchNonValueTensorTypeGetFromAttribute(attr),
-        toMlirNamedAttribute("value", attr));
+    if (importOptions.assumeTensorsHaveValueSemantics) {
+      op = createMlirOperation(
+          "torch.vtensor.literal", loc,
+          torchMlirTorchValueTensorTypeGetFromAttribute(attr),
+          toMlirNamedAttribute("value", attr));
+    } else {
+      op = createMlirOperation(
+          "torch.tensor.literal", loc,
+          torchMlirTorchNonValueTensorTypeGetFromAttribute(attr),
+          toMlirNamedAttribute("value", attr));
+    }
   } else if (output->type()->cast<c10::DeviceObjType>()) {
     op = createMlirOperation(
         "torch.constant.device", loc,
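Both importer changes in this commit reduce to the same decision: with assumeTensorsHaveValueSemantics set, tensor constants are emitted as torch.vtensor.literal (value semantics) rather than the default torch.tensor.literal. A hypothetical distillation of that branch:

    // Hypothetical helper capturing the choice made in both importers above:
    // value-semantic imports produce torch.vtensor.literal, the default
    // produces the non-value torch.tensor.literal.
    const char *tensorLiteralOpName(bool assumeTensorsHaveValueSemantics) {
      return assumeTensorsHaveValueSemantics ? "torch.vtensor.literal"
                                             : "torch.tensor.literal";
    }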
@@ -2513,6 +2513,25 @@ def ToCopyWithDTypeFalsePinMemoryModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 2, 4))
 
 
+class ToCopyBoolDTypeStaticModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 1, 5, 5], torch.uint8, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten._to_copy(x, dtype=torch.bool)
+
+
+@register_test_case(module_factory=lambda: ToCopyBoolDTypeStaticModule())
+def ToCopyBoolDTypeStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(1, 1, 5, 5).to(dtype=torch.uint8))
+
+
 # ==============================================================================
@@ -705,3 +705,80 @@ class Conv_Transpose3dModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: Conv_Transpose3dModule())
 def Conv_Transpose3dModule_basic(module, tu: TestUtils):
     module.forward(torch.randn(5, 2, 5, 6, 4), torch.randn(2, 5, 2, 2, 2))
+
+class UpSampleNearest2dSameSize(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, inputVec):
+        return torch._C._nn.upsample_nearest2d(inputVec,
+                                               output_size=[11, 11],
+                                               scale_factors=None)
+
+
+@register_test_case(module_factory=lambda: UpSampleNearest2dSameSize())
+def UpSampleNearest2dStaticSize_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 4, 4))
+
+
+class UpSampleNearest2dDiffSize(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
+    def forward(self, inputVec):
+        return torch._C._nn.upsample_nearest2d(inputVec,
+                                               output_size=[8, 11],
+                                               scale_factors=None)
+
+
+@register_test_case(module_factory=lambda: UpSampleNearest2dDiffSize())
+def UpSampleNearest2dDynamicSize_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 2, 2))
+
+
+class UpSampleNearest2dDiffFactor(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
+    def forward(self, inputVec):
+        return torch._C._nn.upsample_nearest2d(inputVec,
+                                               output_size=None,
+                                               scale_factors=[2.3, 4.7])
+
+
+@register_test_case(module_factory=lambda: UpSampleNearest2dDiffFactor())
+def UpSampleNearest2dDynamicFactor_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 2, 2))
+
+
+class UpSampleNearest2dSameFactor(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, inputVec):
+        return torch._C._nn.upsample_nearest2d(inputVec,
+                                               output_size=None,
+                                               scale_factors=[2.0, 2.0])
+
+
+@register_test_case(module_factory=lambda: UpSampleNearest2dSameFactor())
+def UpSampleNearest2dStaticFactor_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 4, 4))
@@ -1553,6 +1553,31 @@ def ElementwiseAndIntegerModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseOrIntegerModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int32, True),
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x, y):
+        return torch.bitwise_or(x, y)
+
+
+@register_test_case(module_factory=lambda: ElementwiseOrIntegerModule())
+def ElementwiseOrIntegerModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.randint(3, 4, low=-10, high=10).to(torch.int32),
+        tu.randint(3, 4, low=-10, high=10))
+
+
+# ==============================================================================
+
+
 class ElementwiseNotIntegerModule(torch.nn.Module):
 
     def __init__(self):
@@ -209,3 +209,20 @@ class MatmulBroadcastBatchDim(torch.nn.Module):
 def MatmulBroadcastBatchDim_basic(module, tu: TestUtils):
     module.forward(tu.rand(4, 5, 6, 7), tu.rand(5, 7, 6))
 
+# ==============================================================================
+
+class Mv(torch.nn.Module):
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.float32, True),
+    ])
+    def forward(self, m, v):
+        return torch.mv(m, v)
+
+
+@register_test_case(module_factory=lambda: Mv())
+def Mv_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 2), tu.rand(2))
@@ -104,6 +104,31 @@ def MulIntModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class DivIntModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([], torch.int64, True),
+        ([], torch.int64, True),
+    ])
+    def forward(self, lhs, rhs):
+        # Cast the result to float to make e2e test baseline result to be a float.
+        # Without the cast, baseline result is a Tensor which is unexpected.
+        return float(torch.ops.aten.div(int(lhs), int(rhs)))
+
+
+@register_test_case(module_factory=lambda: DivIntModule())
+def DivIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(low=-10, high=10), tu.randint(low=3, high=10))
+
+
+# ==============================================================================
+
+
 class DivFloatModule(torch.nn.Module):
 
     def __init__(self):
@@ -193,13 +193,13 @@ def ToDtypeLayoutStridedModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 5))
 
 
-class ToDtypeBoolLayoutNoneModule(torch.nn.Module):
+class ToDtypeBoolLayoutNoneStaticModule(torch.nn.Module):
 
     def __init__(self):
         super().__init__()
 
     @export
-    @annotate_args([None, ([-1, -1], torch.float32, True)])
+    @annotate_args([None, ([3, 5], torch.int64, True)])
     def forward(self, x):
         return torch.ops.aten.to(x,
                                  dtype=torch.bool,

@@ -211,9 +211,9 @@
                                  memory_format=None)
 
 
-@register_test_case(module_factory=lambda: ToDtypeBoolLayoutNoneModule())
-def ToDtypeBoolLayoutNoneModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(3, 5))
+@register_test_case(module_factory=lambda: ToDtypeBoolLayoutNoneStaticModule())
+def ToDtypeBoolLayoutNoneStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 5))
 
 
 class TypeAsSameModule(torch.nn.Module):
@@ -1,3 +1,3 @@
 -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
 --pre
-torch==1.13.0.dev20220930
+torch==1.13.0.dev20221004

@@ -1 +1 @@
-3bf7094ddb95e8b9bcd2b1f35589e729aa2f4248
+9f3d8fec5747fde5191618eb895fbec2d50edf93
@@ -847,3 +847,69 @@ func.func @torch.aten.arange.start_step() -> !torch.vtensor<[5],si64> {
   %0 = torch.aten.arange.start_step %int0, %int5, %int1, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[5],si64>
   return %0 : !torch.vtensor<[5],si64>
 }
+
+// -----
+// CHECK-LABEL:   func.func @torch.prim.NumToTensor.Scalar() -> !torch.vtensor<[],si64> {
+// CHECK:           %[[CST1:.*]] = torch.constant.int 1
+// CHECK:           %[[VAL_0:.*]] = "tosa.const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
+// CHECK:           %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<i64> -> !torch.vtensor<[],si64>
+// CHECK:           return %[[VAL_1]] : !torch.vtensor<[],si64>
+func.func @torch.prim.NumToTensor.Scalar() -> !torch.vtensor<[],si64> {
+  %int1 = torch.constant.int 1
+  %0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64>
+  return %0 : !torch.vtensor<[],si64>
+}
+
+// -----
+// CHECK-LABEL:   func.func @torch.valsem.aten.copy(
+// CHECK-SAME:        %[[ARG_0:.*]]: !torch.vtensor<[1,1,5,5],ui8>) -> !torch.vtensor<[1,1,5,5],i1> {
+// CHECK:           %[[INP:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,1,5,5],ui8> -> tensor<1x1x5x5xi8>
+// CHECK:           %[[CST5:.*]] = torch.constant.int 5
+// CHECK:           %[[CST1:.*]] = torch.constant.int 1
+// CHECK:           %[[CST11:.*]] = torch.constant.int 11
+// CHECK:           %[[NONE:.*]] = torch.constant.none
+// CHECK:           %[[FALSE:.*]] = torch.constant.bool false
+// CHECK:           %[[CST0:.*]] = torch.constant.int 0
+// CHECK:           %[[VAL_0:.*]] = "tosa.const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
+// CHECK:           %[[VAL_1:.*]] = "tosa.const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
+// CHECK:           %[[VAL_2:.*]] = "tosa.equal"(%[[VAL_0]], %[[VAL_1]]) : (tensor<i64>, tensor<i64>) -> tensor<i1>
+// CHECK:           %[[VAL_3:.*]] = "tosa.logical_not"(%[[VAL_2]]) : (tensor<i1>) -> tensor<i1>
+// CHECK:           %[[VAL_4:.*]] = "tosa.const"() {value = dense<0> : tensor<1x1x5x5xi8>} : () -> tensor<1x1x5x5xi8>
+// CHECK:           %[[VAL_5:.*]] = "tosa.equal"(%[[INP]], %[[VAL_4]]) : (tensor<1x1x5x5xi8>, tensor<1x1x5x5xi8>) -> tensor<1x1x5x5xi1>
+// CHECK:           %[[VAL_6:.*]] = "tosa.logical_not"(%[[VAL_5]]) : (tensor<1x1x5x5xi1>) -> tensor<1x1x5x5xi1>
+// CHECK:           %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x1x5x5xi1> -> !torch.vtensor<[1,1,5,5],i1>
+// CHECK:           return %[[VAL_7]] : !torch.vtensor<[1,1,5,5],i1>
+func.func @torch.valsem.aten.copy(%arg0: !torch.vtensor<[1,1,5,5],ui8>) -> !torch.vtensor<[1,1,5,5],i1> {
+  %int5 = torch.constant.int 5
+  %int1 = torch.constant.int 1
+  %int11 = torch.constant.int 11
+  %none = torch.constant.none
+  %false = torch.constant.bool false
+  %int0 = torch.constant.int 0
+  %0 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
+  %1 = torch.aten.to.dtype %0, %int11, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],i1>
+  %2 = torch.prim.ListConstruct %int1, %int1, %int5, %int5 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.broadcast_to %1, %2 : !torch.vtensor<[],i1>, !torch.list<int> -> !torch.vtensor<[1,1,5,5],i1>
+  %4 = torch.valsem.aten.copy %3, %arg0, %false : !torch.vtensor<[1,1,5,5],i1>, !torch.vtensor<[1,1,5,5],ui8>, !torch.bool -> !torch.vtensor<[1,1,5,5],i1>
+  return %4 : !torch.vtensor<[1,1,5,5],i1>
+}
+
+// -----
+// CHECK-LABEL:   func.func @torch.aten.to.dtype(
+// CHECK-SAME:        %[[ARG_0:.*]]: !torch.vtensor<[3,5],si64>) -> !torch.vtensor<[3,5],i1> {
+// CHECK:           %[[INP:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[3,5],si64> -> tensor<3x5xi64>
+// CHECK:           %[[CST11:.*]] = torch.constant.int 11
+// CHECK:           %[[NONE:.*]] = torch.constant.none
+// CHECK:           %[[FALSE:.*]] = torch.constant.bool false
+// CHECK:           %[[VAL_0:.*]] = "tosa.const"() {value = dense<0> : tensor<3x5xi64>} : () -> tensor<3x5xi64>
+// CHECK:           %[[VAL_1:.*]] = "tosa.equal"(%[[INP]], %[[VAL_0]]) : (tensor<3x5xi64>, tensor<3x5xi64>) -> tensor<3x5xi1>
+// CHECK:           %[[VAL_2:.*]] = "tosa.logical_not"(%[[VAL_1]]) : (tensor<3x5xi1>) -> tensor<3x5xi1>
+// CHECK:           %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x5xi1> -> !torch.vtensor<[3,5],i1>
+// CHECK:           return %[[VAL_3]] : !torch.vtensor<[3,5],i1>
+func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[3,5],si64>) -> !torch.vtensor<[3,5],i1> {
+  %int11 = torch.constant.int 11
+  %none = torch.constant.none
+  %false = torch.constant.bool false
+  %0 = torch.aten.to.dtype %arg0, %int11, %false, %false, %none : !torch.vtensor<[3,5],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,5],i1>
+  return %0 : !torch.vtensor<[3,5],i1>
+}
@@ -1351,6 +1351,16 @@ func.func @torch.aten.div.float$fold_cst_operands() -> !torch.float {
   return %0 : !torch.float
 }
 
+// CHECK-LABEL:   func.func @torch.aten.div.int$fold_cst_operands(
+// CHECK:           %[[CST:.*]] = torch.constant.float 5.000000e-01
+// CHECK:           return %[[CST]] : !torch.float
+func.func @torch.aten.div.int$fold_cst_operands() -> !torch.float {
+  %int2 = torch.constant.int 2
+  %int4 = torch.constant.int 4
+  %0 = torch.aten.div.int %int2, %int4 : !torch.int, !torch.int -> !torch.float
+  return %0 : !torch.float
+}
+
 // CHECK-LABEL:   func.func @torch.aten.to.dtype_layout$same_dtype(
 // CHECK-SAME:        %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> {
 // CHECK-NEXT:      return %[[ARG]] : !torch.tensor<[?,?],f32>
@@ -0,0 +1,42 @@
+# -*- Python -*-
+# This file is licensed under a pytorch-style license
+# See LICENSE.pytorch for license information.
+
+import typing
+
+import torch
+from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder
+
+# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s
+
+mb = ModuleBuilder()
+
+class TestModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.ones_i32 = torch.ones(1, dtype=torch.int32)
+        self.ones_qint8 = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.qint8)
+        self.arange = torch.nn.Parameter(torch.arange(3.0))
+
+# CHECK: %[[ARANGE:.*]] = torch.vtensor.literal(dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>) : !torch.vtensor<[3],f32>
+# CHECK: %[[ONES_I32:.*]] = torch.vtensor.literal(dense<1> : tensor<1xsi32>) : !torch.vtensor<[1],si32>
+# CHECK: %[[ONES_QINT8_DATA:.*]] = torch.vtensor.literal(dense<1> : tensor<1xsi8>) : !torch.vtensor<[1],si8>
+# CHECK: %[[SCALE:.*]] = torch.constant.float 1.000000e+00
+# CHECK: %[[ZERO_POINT:.*]] = torch.constant.int 0
+# CHECK: %[[ONES_QINT8:.*]] = torch.per_tensor_affine.create %[[ONES_QINT8_DATA]], %[[SCALE]], %[[ZERO_POINT]] : !torch.vtensor<[1],si8>, !torch.float, !torch.int -> !torch.vtensor<[1],!torch.qint8>
+# CHECK: %[[ROOT:.*]] = torch.nn_module {
+# CHECK:   torch.slot "arange", %[[ARANGE]] : !torch.vtensor<[3],f32>
+# CHECK:   torch.slot "ones_i32", %[[ONES_I32]] : !torch.vtensor<[1],si32>
+# CHECK:   torch.slot "ones_qint8", %[[ONES_QINT8]] : !torch.vtensor<[1],!torch.qint8>
+# CHECK: }
+test_module = TestModule()
+recursivescriptmodule = torch.jit.script(test_module)
+
+import_options = ImportOptions()
+import_options.assumeTensorsHaveValueSemantics = True
+
+class_annotator = ClassAnnotator()
+
+# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
+mb.import_module(recursivescriptmodule._c, class_annotator, import_options)
+mb.module.operation.print()
@@ -5,7 +5,7 @@
 import typing
 
 import torch
-from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
+from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder
 
 from utils import create_script_function
 

@@ -162,3 +162,34 @@ graph():
 
 mb.module.operation.print()
 print()
+
+# CHECK-LABEL: func.func @__torch__.prim_Constant_scalar() -> !torch.number {
+# CHECK: %[[A:.*]] = torch.tensor.literal
+# CHECK: %[[RET:.*]] = torch.aten.ScalarImplicit
+# CHECK: return %[[RET]] : !torch.number
+import_options = ImportOptions()
+import_options.assumeTensorsHaveValueSemantics = False
+mb.import_function(create_script_function("__torch__.prim_Constant_scalar", """
+graph():
+  %0 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
+  %1 : Scalar = aten::ScalarImplicit(%0)
+  return (%1)
+""", parse_tensor_constants=True), import_options)
+
+mb.module.operation.print()
+print()
+
+# CHECK-LABEL: func.func @__torch__.prim_Constant_scalar_value_semantics() -> !torch.number {
+# CHECK: %[[A:.*]] = torch.vtensor.literal
+# CHECK: %[[RET:.*]] = torch.aten.ScalarImplicit
+# CHECK: return %[[RET]] : !torch.number
+import_options.assumeTensorsHaveValueSemantics = True
+mb.import_function(create_script_function("__torch__.prim_Constant_scalar_value_semantics", """
+graph():
+  %0 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
+  %1 : Scalar = aten::ScalarImplicit(%0)
+  return (%1)
+""", parse_tensor_constants=True), import_options)
+
+mb.module.operation.print()
+print()
@@ -10,6 +10,6 @@ from torch._C import CompilationUnit
 # RUN: %PYTHON %s
 
 # Import TorchScript IR string as ScriptFunction.
-def create_script_function(func_name, ts_ir_str):
+def create_script_function(func_name, ts_ir_str, **kwargs):
     cu = CompilationUnit()
-    return cu.create_function(func_name, torch._C.parse_ir(ts_ir_str))
+    return cu.create_function(func_name, torch._C.parse_ir(ts_ir_str, **kwargs))
@@ -337,7 +337,7 @@ cc_library(
     strip_include_prefix = "include",
     deps = [
         ":TorchMLIRTorchDialect",
-        "@llvm-project//mlir:ArithmeticDialect",
+        "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:ControlFlowDialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:LinalgDialect",

@@ -373,7 +373,7 @@ cc_library(
         ":TorchMLIRTorchBackendTypeConversion",
         ":TorchMLIRTorchConversionDialect",
         ":TorchMLIRTorchDialect",
-        "@llvm-project//mlir:ArithmeticDialect",
+        "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:ControlFlowDialect",
         "@llvm-project//mlir:Dialect",
         "@llvm-project//mlir:LinalgDialect",

@@ -397,7 +397,7 @@ cc_library(
         ":TorchMLIRConversionPassesIncGen",
         ":TorchMLIRTorchBackendTypeConversion",
         ":TorchMLIRTorchConversionDialect",
-        "@llvm-project//mlir:ArithmeticDialect",
+        "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:ControlFlowDialect",
         "@llvm-project//mlir:Dialect",
         "@llvm-project//mlir:LinalgDialect",

@@ -809,7 +809,7 @@ cc_library(
        ":TorchMLIRRefBackendPassIncGen",
        ":TorchMLIRTorchBackendTypeConversion",
        ":TorchMLIRTorchConversionDialect",
-        "@llvm-project//mlir:ArithmeticTransforms",
+        "@llvm-project//mlir:ArithTransforms",
        "@llvm-project//mlir:LinalgDialect",
        "@llvm-project//mlir:LinalgTransforms",
        "@llvm-project//mlir:MathTransforms",