Merge branch 'llvm:main' into windows-autogen_ltc_backend.py

pull/1310/head
Ryuta Suzuki 2022-10-11 06:39:21 +09:00 committed by GitHub
commit 65ca4c8454
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
87 changed files with 1272 additions and 1098 deletions

View File

@ -1,8 +1,6 @@
name: Roll PyTorch
on:
schedule:
- cron: '0 12 * * *'
workflow_dispatch:
jobs:
@ -34,8 +32,8 @@ jobs:
python -m pip download -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre "torch==${PT_RELEASE}"
# Read the commit hash from the downloaded whl file without extracting it
PT_HASH=$(unzip -p torch-"${PT_RELEASE}"*.whl torch/version.py | grep git_version | awk '{ print $3 }' | tr -d "'")
echo "${PT_HASH}" | cmp - pytorch-version.txt --quiet
PT_HASH_CHANGED=$?
PT_HASH_CHANGED=0
echo "${PT_HASH}" | cmp - pytorch-version.txt --quiet || PT_HASH_CHANGED=$?
echo "${PT_HASH}" > pytorch-version.txt
rm torch-"${PT_RELEASE}"*.whl
# Write the release and hash to the environment file so that we can
@ -44,13 +42,14 @@ jobs:
echo "PT_RELEASE=${PT_RELEASE}" >> ${GITHUB_ENV}
echo "PT_HASH_CHANGED=${PT_HASH_CHANGED}" >> ${GITHUB_ENV}
- name: Build and test
- name: Build and test (in-tree), also update ODS and shape library
if: env.PT_HASH_CHANGED != '0'
run: |
cd ${GITHUB_WORKSPACE}
TM_PACKAGES="out-of-tree" TM_USE_PYTORCH_BINARY="OFF" \
TM_PACKAGES="in-tree" TM_USE_PYTORCH_BINARY="OFF" \
TORCH_MLIR_SRC_PYTORCH_BRANCH="${{ env.PT_HASH }}" \
TORCH_MLIR_SRC_PYTORCH_RELEASE="${{ env.PT_RELEASE }}" \
TM_UPDATE_ODS_AND_SHAPE_LIB="ON" \
./build_tools/python_deploy/build_linux_packages.sh
- name: Push changes to main branch
@ -61,5 +60,5 @@ jobs:
git config user.name "Roll PyTorch Action"
git fetch --recurse-submodules=no
git checkout main
git add pytorch-version.txt pytorch-requirements.txt
git add pytorch-version.txt pytorch-requirements.txt lib/Dialect/Torch/Transforms/ShapeLibrary.cpp include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
git diff --cached --exit-code || (git commit -m "update PyTorch version to ${{ env.PT_RELEASE }}" && git push --set-upstream origin main)

View File

@ -61,7 +61,7 @@ jobs:
if: ${{ matrix.os-arch == 'ubuntu-x86_64' }}
run: |
cd $GITHUB_WORKSPACE
TM_PACKAGES="${{ matrix.llvm-build }}" TM_USE_PYTORCH_BINARY="${{ matrix.torch-binary }}" ./build_tools/python_deploy/build_linux_packages.sh
TORCH_MLIR_SRC_PYTORCH_BRANCH="$(cat pytorch-version.txt)" TM_PACKAGES="${{ matrix.llvm-build }}" TM_USE_PYTORCH_BINARY="${{ matrix.torch-binary }}" ./build_tools/python_deploy/build_linux_packages.sh
- name: Configure os-arch='macos-arm64' llvm-build='in-tree' torch-binary='${{ matrix.torch-binary }}'
# cross compile, can't test arm64
if: ${{ matrix.os-arch == 'macos-arm64' && matrix.llvm-build == 'in-tree' }}

View File

@ -94,6 +94,15 @@ jobs:
with:
release_id: ${{ github.event.inputs.release_id }}
publish_releases:
needs:
- build_linux
- build_macos
# Publish even if one of the builds failed
if: ${{ always() }}
steps:
- name: Invoke Publish Releases Page
uses: benc-uk/workflow-dispatch@v1
with:

View File

@ -12,7 +12,6 @@ blacklist:
- detach
- item
- size
- where
- copy_
# Disabled for consistency with TS backend

View File

@ -58,7 +58,8 @@ checkout_pytorch() {
git reset --hard FETCH_HEAD
else
cd "${PYTORCH_ROOT}"
git reset --hard HEAD
git fetch --depth=1 origin "${TORCH_MLIR_SRC_PYTORCH_BRANCH}"
git reset --hard FETCH_HEAD
fi
git clean -df
git submodule update --init --depth 1 --recursive

View File

@ -53,14 +53,16 @@ TM_PACKAGES="${TM_PACKAGES:-torch-mlir}"
TM_USE_PYTORCH_BINARY="${TM_USE_PYTORCH_BINARY:-ON}"
# Skip running tests if you want quick iteration
TM_SKIP_TESTS="${TM_SKIP_TESTS:-OFF}"
# Update ODS and shape library files
TM_UPDATE_ODS_AND_SHAPE_LIB="${TM_UPDATE_ODS_AND_SHAPE_LIB:-OFF}"
PKG_VER_FILE="${repo_root}"/torch_mlir_package_version ; [ -f "$PKG_VER_FILE" ] && . "$PKG_VER_FILE"
TORCH_MLIR_PYTHON_PACKAGE_VERSION="${TORCH_MLIR_PYTHON_PACKAGE_VERSION:-0.0.1}"
echo "Setting torch-mlir Python Package version to: ${TORCH_MLIR_PYTHON_PACKAGE_VERSION}"
TORCH_MLIR_SRC_PYTORCH_REPO="${TORCH_MLIR_SRC_PYTORCH_REPO:-pytorch/pytorch}"
export TORCH_MLIR_SRC_PYTORCH_REPO="${TORCH_MLIR_SRC_PYTORCH_REPO:-pytorch/pytorch}"
echo "Setting torch-mlir PyTorch Repo for source builds to: ${TORCH_MLIR_SRC_PYTORCH_REPO}"
TORCH_MLIR_SRC_PYTORCH_BRANCH="${TORCH_MLIR_SRC_PYTORCH_BRANCH:-master}"
export TORCH_MLIR_SRC_PYTORCH_BRANCH="${TORCH_MLIR_SRC_PYTORCH_BRANCH:-master}"
echo "Setting torch-mlir PyTorch version for source builds to: ${TORCH_MLIR_SRC_PYTORCH_BRANCH}"
function run_on_host() {
@ -109,6 +111,7 @@ function run_on_host() {
-e "TM_PYTHON_VERSIONS=${TM_PYTHON_VERSIONS}" \
-e "TM_PACKAGES=${package}" \
-e "TM_SKIP_TESTS=${TM_SKIP_TESTS}" \
-e "TM_UPDATE_ODS_AND_SHAPE_LIB=${TM_UPDATE_ODS_AND_SHAPE_LIB}" \
-e "TM_USE_PYTORCH_BINARY=${TM_USE_PYTORCH_BINARY}" \
-e "TORCH_MLIR_SRC_PYTORCH_REPO=${TORCH_MLIR_SRC_PYTORCH_REPO}" \
-e "TORCH_MLIR_SRC_PYTORCH_BRANCH=${TORCH_MLIR_SRC_PYTORCH_BRANCH}" \
@ -152,6 +155,12 @@ function run_in_docker() {
in-tree)
setup_venv "$python_version"
build_in_tree "$TM_USE_PYTORCH_BINARY" "$python_version"
if [ "${TM_UPDATE_ODS_AND_SHAPE_LIB}" == "ON" ]; then
pushd /main_checkout/torch-mlir
./build_tools/update_torch_ods.sh
./build_tools/update_shape_lib.sh
popd
fi
if [ "${TM_SKIP_TESTS}" == "OFF" ]; then
test_in_tree;
fi

View File

@ -28,23 +28,25 @@ Two setups are possible to build: in-tree and out-of-tree. The in-tree setup is
### Building torch-mlir in-tree
The following command generates configuration files to build the project *in-tree*, that is, using llvm/llvm-project as the main build. This will build LLVM as well as torch-mlir and its subprojects.
The following command generates configuration files to build the project *in-tree*, that is, using llvm/llvm-project as the main build. This will build LLVM as well as torch-mlir and its subprojects. On Windows, use the "Developer PowerShell for Visual Studio" to ensure that the compiler and linker binaries are in the `PATH` variable.
```shell
cmake -GNinja -Bbuild \
-DCMAKE_BUILD_TYPE=Release \
-DCMAKE_C_COMPILER=clang \
-DCMAKE_CXX_COMPILER=clang++ \
-DPython3_FIND_VIRTUALENV=ONLY \
-DLLVM_ENABLE_PROJECTS=mlir \
-DLLVM_EXTERNAL_PROJECTS="torch-mlir;torch-mlir-dialects" \
-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR=`pwd` \
-DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR=`pwd`/externals/llvm-external-projects/torch-mlir-dialects \
-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$PWD" \
-DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR="$PWD"/externals/llvm-external-projects/torch-mlir-dialects \
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
-DLLVM_TARGETS_TO_BUILD=host \
externals/llvm-project/llvm
```
The following additional quality of life flags can be used to reduce build time:
* Enabling clang on Linux
```shell
-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++
```
* Enabling ccache:
```shell
-DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache
@ -73,8 +75,6 @@ If you have built llvm-project separately in the directory `$LLVM_INSTALL_DIR`,
```shell
cmake -GNinja -Bbuild \
-DCMAKE_BUILD_TYPE=Release \
-DCMAKE_C_COMPILER=clang \
-DCMAKE_CXX_COMPILER=clang++ \
-DPython3_FIND_VIRTUALENV=ONLY \
-DMLIR_DIR="$LLVM_INSTALL_DIR/lib/cmake/mlir/" \
-DLLVM_DIR="$LLVM_INSTALL_DIR/lib/cmake/llvm/" \
@ -82,7 +82,7 @@ cmake -GNinja -Bbuild \
-DLLVM_TARGETS_TO_BUILD=host \
.
```
The same QoL CMake flags can be used to enable ccache and lld. Be sure to have built LLVM with `-DLLVM_ENABLE_PROJECTS=mlir`.
The same QoL CMake flags can be used to enable clang, ccache, and lld. Be sure to have built LLVM with `-DLLVM_ENABLE_PROJECTS=mlir`.
Be aware that the installed version of LLVM needs in general to match the committed version in `externals/llvm-project`. Using a different version may or may not work.
@ -105,15 +105,25 @@ cmake --build build
```
## Setup Python Environment to export the built Python packages
### Linux and macOS
```shell
export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/examples
```
### Windows PowerShell
```shell
$env:PYTHONPATH = "$PWD/build/tools/torch-mlir/python_packages/torch_mlir;$PWD/examples"
```
## Testing MLIR output in various dialects
To test the compiler's output to the different MLIR dialects, you can use the example `examples/torchscript_resnet18_all_output_types.py`.
Make sure you have activated the virtualenv and set the `PYTHONPATH` above:
Make sure you have activated the virtualenv and set the `PYTHONPATH` above
(if running on Windows, modify the environment variable as shown above):
```shell
source mlir_venv/bin/activate
export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/examples

View File

@ -463,6 +463,10 @@ TOSA_PASS_SET = {
"ArangeStartIntModule_basic",
"ArangeStartNegativeStepIntModule_basic",
"ArangeZeroElementOutputModule_basic",
"NumToTensorIntModule_basic",
"ToDtypeBoolLayoutNoneStaticModule_basic",
"ToCopyBoolDTypeStaticModule_basic",
"HardTanhIntModule_basic",
}
LTC_XFAIL_SET = {
@ -569,8 +573,8 @@ LTC_XFAIL_SET = {
"IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
"LiftFreshCopyModule_basic",
"Matmul_dot",
"Matmul_matvec",
"MulIntModule_basic",
"DivIntModule_basic",
"NeFloatIntModule_basic",
"NeIntModule_basic",
"NewEmptyModuleDefaultDtype_basic",
@ -632,4 +636,8 @@ LTC_XFAIL_SET = {
"ElementwiseRemainderScalarModule_Bool_basic",
"AtenIntTensorByteDtypeModule_basic",
"AtenIntTensorCharDtypeModule_basic",
"UpSampleNearest2dDynamicFactor_basic",
"UpSampleNearest2dDynamicSize_basic",
"UpSampleNearest2dStaticFactor_basic",
"UpSampleNearest2dStaticSize_basic",
}

View File

@ -7,8 +7,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/Passes.h"
@ -134,7 +134,7 @@ struct TMTensorBufferizePass
bufferization::BufferizeTypeConverter typeConverter;
// Mark all Standard operations legal.
target.addLegalDialect<arith::ArithmeticDialect, func::FuncDialect,
target.addLegalDialect<arith::ArithDialect, func::FuncDialect,
memref::MemRefDialect, tensor::TensorDialect>();
// Mark all TMTensor operations illegal as long as they work on tensors.

View File

@ -7,8 +7,8 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
@ -101,7 +101,7 @@ namespace {
struct TMTensorToLoopsPass : public TMTensorToLoopsBase<TMTensorToLoopsPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<linalg::LinalgDialect, func::FuncDialect,
mlir::arith::ArithmeticDialect, math::MathDialect,
mlir::arith::ArithDialect, math::MathDialect,
memref::MemRefDialect, scf::SCFDialect>();
}

View File

@ -1,5 +1,5 @@
set(LIBS
MLIRArithmeticDialect
MLIRArithDialect
MLIRDialect
MLIRLinalgDialect
MLIRMemRefDialect

View File

@ -7,12 +7,12 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Dialect.h"
@ -39,7 +39,7 @@ int main(int argc, char **argv) {
// Local dialects
mlir::torch::TMTensor::TMTensorDialect,
// Upstream dialects
mlir::arith::ArithmeticDialect, mlir::linalg::LinalgDialect,
mlir::arith::ArithDialect, mlir::linalg::LinalgDialect,
mlir::func::FuncDialect, mlir::memref::MemRefDialect,
mlir::scf::SCFDialect, mlir::tensor::TensorDialect>();

@ -1 +1 @@
Subproject commit bebc96956b76bdbc36f1d82a788c810e5b12e2c5
Subproject commit 6f46ff3765dcdc178b9cf52ebd8c03437806798a

2
externals/mlir-hlo vendored

@ -1 +1 @@
Subproject commit 7b0ecf7827e3fc07d2af90e147bcedc165bc78ac
Subproject commit 2f7c1454bbe4c4ad0ae1c86c5539ac58b6053b6a

View File

@ -204,6 +204,10 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchValueTensorTypeGet(
MLIR_CAPI_EXPORTED MlirType
torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context);
/// Gets the !torch.vtensor type with the tensor attribute.
MLIR_CAPI_EXPORTED MlirType
torchMlirTorchValueTensorTypeGetFromAttribute(MlirAttribute attr);
//===----------------------------------------------------------------------===//
// !torch.none type.
//===----------------------------------------------------------------------===//

View File

@ -0,0 +1,22 @@
//===-- torch-mlir-c/Transforms.h - C API for torch passes --------*- C -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header declares the registration and creation method for
// transformation passes.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_C_TRANSFORMS_H
#define TORCHMLIR_C_TRANSFORMS_H
#include "mlir-c/Support.h"
#include "torch-mlir/Dialect/Torch/Transforms/Transforms.capi.h.inc"
#endif // TORCHMLIR_C_TRANSFORMS_H

View File

@ -53,6 +53,9 @@ template <typename T>
llvm::Optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
ArrayRef<T> vec, ArrayRef<int64_t> shape);
LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
Value src, Type destType, Value &result);
// Creates a TOSA operation and performs shape inference on the individual
// op. This allows shape inference during the framework to TOSA lowering.
template <typename TosaOp, typename... Args>

View File

@ -2192,6 +2192,53 @@ def Torch_AtenBitwiseAnd_TensorOp : Torch_Op<"aten.bitwise_and_.Tensor", [
}];
}
def Torch_AtenBitwiseOrTensorOp : Torch_Op<"aten.bitwise_or.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::bitwise_or.Tensor : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$other
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenBitwiseOrTensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenBitwiseOrTensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenBitwiseOr_TensorOp : Torch_Op<"aten.bitwise_or_.Tensor", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::bitwise_or_.Tensor : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$other
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenBitwiseOr_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenBitwiseOr_TensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenThresholdOp : Torch_Op<"aten.threshold", [
AllowsTypeRefinement,
HasValueSemantics,
@ -3387,6 +3434,30 @@ def Torch_AtenMatmulOp : Torch_Op<"aten.matmul", [
}];
}
def Torch_AtenMvOp : Torch_Op<"aten.mv", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::mv : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$vec
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMvOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenMvOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [
AllowsTypeRefinement,
HasValueSemantics,
@ -7399,6 +7470,31 @@ def Torch_AtenAsStridedScatterOp : Torch_Op<"aten.as_strided_scatter", [
}];
}
def Torch_AtenUpsampleNearest2dVecOp : Torch_Op<"aten.upsample_nearest2d.vec", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::upsample_nearest2d.vec : (Tensor, int[]?, float[]?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchOptionalListOfTorchIntType:$output_size,
AnyTorchOptionalListOfTorchFloatType:$scale_factors
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenUpsampleNearest2dVecOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenUpsampleNearest2dVecOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}
def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [
AllowsTypeRefinement,
HasValueSemantics,
@ -8356,6 +8452,31 @@ def Torch_AtenMulIntOp : Torch_Op<"aten.mul.int", [
let hasFolder = 1;
}
def Torch_AtenDivIntOp : Torch_Op<"aten.div.int", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::div.int : (int, int) -> (float)`";
let arguments = (ins
Torch_IntType:$a,
Torch_IntType:$b
);
let results = (outs
Torch_FloatType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenDivIntOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenDivIntOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasFolder = 1;
}
def Torch_AtenNegIntOp : Torch_Op<"aten.neg.int", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -383,6 +383,7 @@ class ListOf<list<Type> allowedTypes, string descr> :
def AnyTorchListOfTorchBoolType : ListOf<[Torch_BoolType], "Bool list type (bool[])">;
def AnyTorchListOfTorchIntType : ListOf<[Torch_IntType], "Int list type (int[])">;
def AnyTorchListOfTorchFloatType : ListOf<[Torch_FloatType], "Float list type (float[])">;
def AnyTorchListOfTorchStringType : ListOf<[Torch_StringType], "Str list type (str[])">;
def AnyTorchListOfTensorType:
ListOf<[AnyTorchTensorType], "Any int list type (Tensor[])">;
@ -390,7 +391,7 @@ def AnyTorchListOfOptionalTensorType :
ListOf<[AnyTorchOptionalTensorType],
"Any optional tensor list type (Tensor?[])">;
def AnyTorchOptionalListOfTorchIntType : OptionalOf<AnyTorchListOfTorchIntType, "Optional torch int list type (int[]?)">;
def AnyTorchOptionalListOfTorchFloatType : OptionalOf<AnyTorchListOfTorchFloatType, "Optional torch float list type (float[]?)">;
// Note: TorchScript does not consider !torch.bool to be a Scalar.
def AnyTorchScalarType :
Type<CPred<"isValidSubtype($_self, ::mlir::torch::Torch::NumberType::get($_self.getContext()))">,

View File

@ -1,5 +1,7 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
mlir_tablegen(Transforms.capi.h.inc -gen-pass-capi-header)
mlir_tablegen(Transforms.capi.cpp.inc -gen-pass-capi-impl)
add_public_tablegen_target(TorchMLIRTorchPassIncGen)
add_mlir_doc(Passes TorchMLIRTorchTransforms ./ -gen-pass-doc)

View File

@ -21,6 +21,8 @@ class ModuleOp;
namespace torch {
namespace Torch {
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
std::unique_ptr<OperationPass<ModuleOp>> createGlobalizeObjectGraphPass();
std::unique_ptr<OperationPass<ModuleOp>>
@ -109,6 +111,8 @@ std::unique_ptr<OperationPass<ModuleOp>>
createLowerToBackendContractPass(int maxIterations, bool decompose,
ArrayRef<std::string> backendLegalOps);
std::unique_ptr<OperationPass<ModuleOp>> createVerifyBackendContractPass();
StringRef getShapeLibrary();
} // namespace Torch
@ -116,6 +120,13 @@ StringRef getShapeLibrary();
/// Registers all Torch transformation passes.
void registerTorchPasses();
//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//
#define GEN_PASS_REGISTRATION
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
} // namespace torch
} // namespace mlir

View File

@ -329,4 +329,16 @@ def LowerToBackendContract
let dependentDialects = ["func::FuncDialect"];
}
def VerifyBackendContract
: Pass<"torch-verify-backend-contract", "ModuleOp"> {
let summary = "Check that program satisfies backend contract.";
let constructor =
"mlir::torch::Torch::createVerifyBackendContractPass()";
let description = [{
This pass performs a set of inspections to check that program satisfies backend
contract. In case of check failure it prints out the error message and returns
`signalPassFailure()` status.
}];
}
#endif // TORCHMLIR_TORCH_PASSES

View File

@ -52,6 +52,9 @@ int getTensorRank(Value tensor);
bool isViewLikeOp(Operation *op);
Value getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter, Location loc,
float value, Type dtype);
} // namespace Torch
} // namespace torch
} // namespace mlir

View File

@ -3,6 +3,7 @@ add_mlir_public_c_api_library(TorchMLIRCAPI
Registration.cpp
TorchOps.cpp
TorchTypes.cpp
Transforms.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir-c/
@ -16,6 +17,7 @@ add_mlir_public_c_api_library(TorchMLIRCAPI
MLIRSupport
TorchMLIRTorchDialect
TorchMLIRInitAll
TorchMLIRTorchPasses
)
torch_mlir_target_includes(TorchMLIRCAPI)

View File

@ -246,6 +246,14 @@ MlirType torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(
Torch::ValueTensorType::getWithLeastStaticInformation(unwrap(context)));
}
MlirType torchMlirTorchValueTensorTypeGetFromAttribute(MlirAttribute attr) {
auto attrTensorType =
unwrap(attr).cast<TypedAttr>().getType().cast<RankedTensorType>();
return wrap(Torch::ValueTensorType::get(attrTensorType.getContext(),
attrTensorType.getShape(),
attrTensorType.getElementType()));
}
//===----------------------------------------------------------------------===//
// torch.none type.
//===----------------------------------------------------------------------===//

View File

@ -0,0 +1,27 @@
//===- CAPIPasses.cpp - C API for Transformations Passes ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/CAPI/Pass.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
// Must include the declarations as they carry important visibility attributes.
#include "torch-mlir/Dialect/Torch/Transforms/Transforms.capi.h.inc"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
#ifdef __cplusplus
extern "C" {
#endif
#include "torch-mlir/Dialect/Torch/Transforms/Transforms.capi.cpp.inc"
#ifdef __cplusplus
}
#endif

View File

@ -10,7 +10,7 @@
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
#include "../PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
@ -121,6 +121,25 @@ public:
};
} // namespace
namespace {
class ConvertAtenDivIntOp : public OpConversionPattern<AtenDivIntOp> {
public:
using OpConversionPattern<AtenDivIntOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenDivIntOp op,
typename OpConversionPattern<AtenDivIntOp>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value a =
convertScalarToDtype(rewriter, loc, adaptor.a(), rewriter.getF64Type());
Value b =
convertScalarToDtype(rewriter, loc, adaptor.b(), rewriter.getF64Type());
rewriter.replaceOpWithNewOp<arith::DivFOp>(op, a, b);
return success();
}
};
} // namespace
namespace {
// Lowers aten integer comparison ops.
template <typename AtenOp, arith::CmpIPredicate Pred>
@ -300,7 +319,7 @@ class ConvertTorchToArith : public ConvertTorchToArithBase<ConvertTorchToArith>
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<func::FuncDialect>();
registry.insert<arith::ArithmeticDialect>();
registry.insert<arith::ArithDialect>();
registry.insert<tensor::TensorDialect>();
registry.insert<cf::ControlFlowDialect>();
registry.insert<math::MathDialect>();
@ -311,7 +330,7 @@ public:
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<Torch::TorchDialect, func::FuncDialect,
arith::ArithmeticDialect, tensor::TensorDialect,
arith::ArithDialect, tensor::TensorDialect,
cf::ControlFlowDialect, math::MathDialect>();
TypeConverter typeConverter;
@ -374,6 +393,8 @@ public:
target.addIllegalOp<AtenSubFloatOp>();
patterns.add<ConvertAtenBinaryOp<AtenSubFloatOp, arith::SubFOp>>(
typeConverter, context);
target.addIllegalOp<AtenDivIntOp>();
patterns.add<ConvertAtenDivIntOp>(typeConverter, context);
target.addIllegalOp<AtenDivFloatOp>();
patterns.add<ConvertAtenBinaryOp<AtenDivFloatOp, arith::DivFOp>>(
typeConverter, context);

View File

@ -16,7 +16,7 @@
#include "../PassDetail.h"
#include "PopulatePatterns.h"
#include "Utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@ -252,11 +252,11 @@ public:
llvm::all_of(expandShape,
[](int64_t value) { return value == kUnknownSize; })) {
for (int i = 0; i < collapseShape.size(); i++) {
for (size_t i = 0; i < collapseShape.size(); i++) {
collapseIndices.push_back(i);
}
for (int i = 0; i < expandShape.size(); i++) {
for (size_t i = 0; i < expandShape.size(); i++) {
expandIndices.push_back(i);
}
@ -290,8 +290,8 @@ public:
op, "total number of elements mismatch in the expansion");
}
static LogicalResult solveDynamicSize(SmallVector<int64_t> &inputShape,
SmallVector<int64_t> &outputShape) {
static void solveDynamicSize(SmallVector<int64_t> &inputShape,
SmallVector<int64_t> &outputShape) {
int64_t inputProduct = 1;
int64_t outputProduct = 1;
@ -316,7 +316,7 @@ public:
if (inputDynamicValues + outputDynamicValues == 1) {
if (inputDynamicValues) {
int64_t missingValue = outputProduct / inputProduct;
for (int i = 0; i < inputShape.size(); i++) {
for (size_t i = 0; i < inputShape.size(); i++) {
if (inputShape[i] == -1) {
inputShape[i] = missingValue;
break;
@ -324,7 +324,7 @@ public:
}
} else {
int64_t missingValue = inputProduct / outputProduct;
for (int i = 0; i < outputShape.size(); i++) {
for (size_t i = 0; i < outputShape.size(); i++) {
if (outputShape[i] == -1) {
outputShape[i] = missingValue;
break;
@ -332,8 +332,6 @@ public:
}
}
}
return success();
}
LogicalResult
@ -625,9 +623,6 @@ public:
}
}
int64_t inputCount = inputAssociations.size();
int64_t outputCount = outputAssociations.size();
// Check if the shapes already match up to dynamic sizes. If so, we can just
// cast as the result type because the previous loop sets up the necessary
// dim checks in case of dynamic sizes.

View File

@ -12,9 +12,10 @@
#include "../PassDetail.h"
#include "PopulatePatterns.h"
#include "Utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
@ -755,6 +756,146 @@ public:
};
} // namespace
// `getScaledDims` scales the `dim` value with a scale factor `ScaleFactor`.
// The `dim` and `scaleFactor` are assumed to be of index and float type
// respectively. `scaledDim = int(floor(float(dim) * scaleFactor))`.
static Value getScaledDims(OpBuilder &builder, Location loc, Value dim,
Value scaleFactor) {
Value dimInt = castIndexToInt64(builder, loc, dim);
Value dimFp =
builder.create<arith::SIToFPOp>(loc, scaleFactor.getType(), dimInt);
Value scaleDim = builder.create<arith::MulFOp>(loc, dimFp, scaleFactor);
Value floorDim = builder.create<math::FloorOp>(loc, scaleDim);
Value scaledDimToIndex = castIntToIndex(
builder, loc,
builder.create<arith::FPToSIOp>(loc, dimInt.getType(), floorDim));
return scaledDimToIndex;
}
// `getScaleFactor` returns the scale factor from input to output dimension.
// The `dim` and `scaledDim` are assumed to be of index and int64 type
// respectively. scale_factor = (scaled_dim // dim).
static Value getScaleFactor(OpBuilder &builder, Location loc, Value dim,
Value scaledDim) {
Value dimInt = castIndexToInt64(builder, loc, dim);
Value scaleFactorInt =
builder.create<arith::CeilDivSIOp>(loc, scaledDim, dimInt);
return scaleFactorInt;
}
// N, C, H, W = input_tensor.shape
// N, C, H_scaled, W_scaled = out_tensor.shape
// H_factor, W_factor = H_scaled/H, W_scaled/W
// for i in range(N):
// for j in range(C):
// for k in range(H_scaled):
// for l in range(W_scaled):
// out_tensor[i, j, k, l] = input[i, j, k//H_factor, l//W_factor]
namespace {
class ConvertAtenUpsampleNearest2dVecOp
: public OpConversionPattern<AtenUpsampleNearest2dVecOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenUpsampleNearest2dVecOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
Value input = adaptor.input();
Type resultType = getTypeConverter()->convertType(op.getResult().getType());
auto inputType = input.getType().cast<RankedTensorType>();
auto inputRank = inputType.getRank();
Type elementType = inputType.getElementType();
SmallVector<Value> dims = getTensorSizes(rewriter, loc, input);
SmallVector<Value, 2> scaleFactorsInt;
// The dimension at which the scaling starts.
unsigned hDimOffset = 2;
if (!adaptor.scale_factors().getType().isa<Torch::NoneType>()) {
SmallVector<Value, 2> scaleFactorsTorchFloat;
if (!getListConstructElements(op.scale_factors(), scaleFactorsTorchFloat))
return rewriter.notifyMatchFailure(
op, "unimplemented: the scale_factors is not constructed from "
"ListConstruct");
SmallVector<Value, 2> scaleFactorsFloatValues;
scaleFactorsFloatValues = getTypeConvertedValues(
rewriter, loc, getTypeConverter(), scaleFactorsTorchFloat);
// Convert float values to int values.
// int_value = (int64_t)ceil(float_value)
for (auto floatValue : scaleFactorsFloatValues) {
Value ceilVal = rewriter.create<math::CeilOp>(loc, floatValue);
Value intVal = rewriter.create<arith::FPToSIOp>(
loc, rewriter.getI64Type(), ceilVal);
scaleFactorsInt.push_back(intVal);
}
for (unsigned i = 0; i < scaleFactorsFloatValues.size(); i++)
dims[hDimOffset + i] = getScaledDims(
rewriter, loc, dims[hDimOffset + i], scaleFactorsFloatValues[i]);
} else {
SmallVector<Value, 2> outputSizeTorchInt;
if (!getListConstructElements(op.output_size(), outputSizeTorchInt))
return rewriter.notifyMatchFailure(
op, "unimplemented: the output_size is not constructed from "
"ListConstruct");
SmallVector<Value, 2> outputSizeIntValues;
outputSizeIntValues = getTypeConvertedValues(
rewriter, loc, getTypeConverter(), outputSizeTorchInt);
for (unsigned i = 0; i < outputSizeTorchInt.size(); i++) {
auto scaleFactorVal = getScaleFactor(
rewriter, loc, dims[hDimOffset + i], outputSizeIntValues[i]);
scaleFactorsInt.push_back(scaleFactorVal);
dims[hDimOffset + i] =
castIntToIndex(rewriter, loc, outputSizeIntValues[i]);
}
}
Value outTensor =
rewriter.create<linalg::InitTensorOp>(loc, dims, elementType);
AffineMap idMap = rewriter.getMultiDimIdentityMap(inputRank);
SmallVector<StringRef> iteratorTypes(inputRank,
getParallelIteratorTypeName());
Value finalRes =
rewriter
.create<linalg::GenericOp>(
loc, outTensor.getType(), ValueRange{}, outTensor,
/*indexingMaps=*/idMap,
/*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
SmallVector<Value> indices;
for (unsigned i = 0; i < inputRank; i++)
indices.push_back(b.create<linalg::IndexOp>(loc, i));
for (unsigned i = 0; i < (inputRank - hDimOffset); i++)
indices[i + hDimOffset] = b.create<arith::FloorDivSIOp>(
loc, indices[i + hDimOffset],
castIntToIndex(rewriter, loc, scaleFactorsInt[i]));
Value retVal =
b.create<tensor::ExtractOp>(loc, input, indices);
b.create<linalg::YieldOp>(loc, retVal);
})
.getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, finalRes);
return success();
}
};
} // namespace
void mlir::torch::torch_to_linalg::
populateIndirectDataMovementPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
@ -770,4 +911,6 @@ void mlir::torch::torch_to_linalg::
patterns.add<ConvertAtenIndexTensorOp>(typeConverter, context);
target.addIllegalOp<AtenEmbeddingBagPaddingIdxOp>();
patterns.add<ConvertAtenEmbeddingBagPaddingIdxOp>(typeConverter, context);
target.addIllegalOp<AtenUpsampleNearest2dVecOp>();
patterns.add<ConvertAtenUpsampleNearest2dVecOp>(typeConverter, context);
}

View File

@ -12,7 +12,7 @@
#include "../PassDetail.h"
#include "PopulatePatterns.h"
#include "Utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

View File

@ -12,7 +12,7 @@
#include "../PassDetail.h"
#include "PopulatePatterns.h"
#include "Utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

View File

@ -12,7 +12,7 @@
#include "../PassDetail.h"
#include "PopulatePatterns.h"
#include "Utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

View File

@ -12,7 +12,7 @@
#include "../PassDetail.h"
#include "PopulatePatterns.h"
#include "Utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
@ -152,12 +152,10 @@ public:
nestedLoc, oldIndex.getType(),
rewriter.create<linalg::IndexOp>(loc, dim));
Value predicate;
if (inElementType.isa<mlir::FloatType>())
predicate = rewriter.create<arith::CmpFOp>(
nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
auto resultMax = rewriter.create<arith::SelectOp>(
nestedLoc, predicate, newValue, oldValue);
auto resultMax = rewriter.create<arith::MaxFOp>(
nestedLoc, newValue, oldValue);
Value predicate = rewriter.create<arith::CmpFOp>(
nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
auto resultIndex = rewriter.create<arith::SelectOp>(
nestedLoc, predicate, newIndex, oldIndex);
nestedBuilder.create<linalg::YieldOp>(

View File

@ -12,7 +12,7 @@
#include "../PassDetail.h"
#include "PopulatePatterns.h"
#include "Utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"

View File

@ -11,7 +11,7 @@
#include "../PassDetail.h"
#include "PopulatePatterns.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

View File

@ -11,7 +11,7 @@
#include "../PassDetail.h"
#include "PopulatePatterns.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@ -43,7 +43,7 @@ public:
registry.insert<math::MathDialect>();
registry.insert<func::FuncDialect>();
registry.insert<tensor::TensorDialect>();
registry.insert<arith::ArithmeticDialect>();
registry.insert<arith::ArithDialect>();
registry.insert<cf::ControlFlowDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry);
}
@ -53,7 +53,7 @@ public:
ConversionTarget target(*context);
target.addLegalDialect<linalg::LinalgDialect, func::FuncDialect,
cf::ControlFlowDialect, math::MathDialect,
tensor::TensorDialect, arith::ArithmeticDialect>();
tensor::TensorDialect, arith::ArithDialect>();
target.addLegalOp<TorchConversion::GetNextSeedOp>();
TypeConverter typeConverter;

View File

@ -12,7 +12,7 @@
#include "../PassDetail.h"
#include "PopulatePatterns.h"
#include "Utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
@ -199,6 +199,22 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
return b.create<arith::AndIOp>(loc, lhs, rhs);
}
if (auto bitwiseOrTensor = dyn_cast<AtenBitwiseOrTensorOp>(op)) {
if (bitwiseOrTensor.getType()
.cast<ValueTensorType>()
.getDtype()
.isa<mlir::FloatType>()) {
bitwiseOrTensor.emitError(
"Bitwise_Or does not support floating point dtype");
return nullptr;
}
Type dtype = converter->convertType(bitwiseOrTensor.getType())
.cast<RankedTensorType>()
.getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
return b.create<arith::OrIOp>(loc, lhs, rhs);
}
if (auto logicalOr = dyn_cast<AtenLogicalOrOp>(op)) {
MLIRContext *context = op->getContext();
Type floatDtype = mlir::FloatType::getF64(context);
@ -1006,12 +1022,12 @@ public:
AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp,
AtenPowTensorTensorOp, AtenLog2Op, AtenLog1pOp, AtenRsqrtOp,
AtenDivScalarOp, AtenRemainderScalarOp, AtenAbsOp,
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp,
AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp,
AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
AtenLeScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp,
AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp,
AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp,
AtenBitwiseNotOp>(op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
@ -1483,10 +1499,10 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp,
AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog1pOp,
AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp,
AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp,
AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
AtenBitwiseOrTensorOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp,
AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp,
AtenRemainderScalarOp, AtenBitwiseNotOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context);

View File

@ -11,7 +11,7 @@
#include "../PassDetail.h"
#include "PopulatePatterns.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

View File

@ -14,7 +14,7 @@
#include "./PopulatePatterns.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/utils/hlo_utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/ChloOps.h"
#include "torch-mlir/Conversion/Utils/Utils.h"

View File

@ -13,7 +13,7 @@
#include "./MhloLegalizeUtils.h"
#include "./PopulatePatterns.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"

View File

@ -13,7 +13,7 @@
#include "./MhloLegalizeUtils.h"
#include "./PopulatePatterns.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/ChloOps.h"
#include "torch-mlir/Conversion/Utils/Utils.h"

View File

@ -9,7 +9,7 @@
#include "./MhloLegalizeUtils.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

View File

@ -13,7 +13,7 @@
#include "./MhloLegalizeUtils.h"
#include "./PopulatePatterns.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/ChloOps.h"
#include "torch-mlir/Conversion/Utils/Utils.h"

View File

@ -13,7 +13,7 @@
#include "./MhloLegalizeUtils.h"
#include "./PopulatePatterns.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"

View File

@ -12,7 +12,7 @@
#include "../PassDetail.h"
#include "./PopulatePatterns.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Matchers.h"
@ -42,14 +42,14 @@ public:
registry.insert<chlo::ChloDialect>();
registry.insert<mhlo::MhloDialect>();
registry.insert<tensor::TensorDialect>();
registry.insert<arith::ArithmeticDialect>();
registry.insert<arith::ArithDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry);
}
void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<chlo::ChloDialect, mhlo::MhloDialect,
tensor::TensorDialect, arith::ArithmeticDialect>();
tensor::TensorDialect, arith::ArithDialect>();
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });

View File

@ -13,7 +13,7 @@
#include "./MhloLegalizeUtils.h"
#include "./PopulatePatterns.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"

View File

@ -10,7 +10,7 @@
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
#include "../PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"
@ -321,7 +321,7 @@ namespace {
class ConvertTorchToSCF : public ConvertTorchToSCFBase<ConvertTorchToSCF> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<scf::SCFDialect, arith::ArithmeticDialect>();
registry.insert<scf::SCFDialect, arith::ArithDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry);
}
@ -329,7 +329,7 @@ public:
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<Torch::TorchDialect, scf::SCFDialect,
arith::ArithmeticDialect>();
arith::ArithDialect>();
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });

View File

@ -10,7 +10,7 @@
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "../PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@ -614,7 +614,7 @@ public:
registry.insert<linalg::LinalgDialect>();
registry.insert<func::FuncDialect>();
registry.insert<tensor::TensorDialect>();
registry.insert<arith::ArithmeticDialect>();
registry.insert<arith::ArithDialect>();
registry.insert<TMTensorDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry);
}
@ -623,7 +623,7 @@ public:
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<linalg::LinalgDialect, func::FuncDialect,
tensor::TensorDialect, arith::ArithmeticDialect,
tensor::TensorDialect, arith::ArithDialect,
Torch::TorchDialect, TMTensorDialect>();
TypeConverter typeConverter;

View File

@ -12,7 +12,7 @@
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
#include "../PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Traits.h"
@ -3055,6 +3055,133 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
return success();
}
template <>
LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
PrimNumToTensorScalarOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
TypeConverter *typeConverter = this->getTypeConverter();
RankedTensorType resultType =
typeConverter->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
// Only supports integer operand type, because for the floating point operand
// type result tensor has to be of type `f64` which is not supported in the
// tosa.
int64_t initValue;
if (!matchPattern(op.a(), m_TorchConstantInt(&initValue)))
return rewriter.notifyMatchFailure(
op, "unimplemented: input should be a torch constant int");
DenseElementsAttr constAttr = DenseElementsAttr::get(resultType, {initValue});
rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, resultType, constAttr);
return success();
}
template <>
LogicalResult ConvertAtenOp<ValsemVariantAtenCopyOp>::matchAndRewrite(
ValsemVariantAtenCopyOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Not a tensor type.
auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
auto srcType = adaptor.src().getType().dyn_cast<TensorType>();
if (!selfType || !selfType.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "Only tensor types with static shape are supported");
if (!srcType || !srcType.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "Only tensor types with static shape are supported");
// The non_blocking should be a constant `False`.
bool nonBlocking;
if (!matchPattern(op.non_blocking(), m_TorchConstantBool(&nonBlocking))) {
return rewriter.notifyMatchFailure(
op, "unimplemented: non_blocking must be a constant");
} else if (nonBlocking) {
return rewriter.notifyMatchFailure(
op, "unimplemented: non_blocking is expected to be false");
}
SmallVector<int64_t> selfShape(selfType.getShape());
SmallVector<int64_t> srcShape(srcType.getShape());
if (llvm::equal(selfShape, srcShape) || selfShape.size() == 0) {
// If we reach here, then it means the given case is handled by implicit
// broadcasting done by tosa.
Value result;
if (failed(tosa::tosaCastTensorToType(
rewriter, op, adaptor.src(),
getTypeConverter()->convertType(op.getType()), result)))
return rewriter.notifyMatchFailure(
op, "unimplemented: cast to result type not supported");
rewriter.replaceOp(op, result);
return success();
}
return rewriter.notifyMatchFailure(
op, "unimplemented: valsem.aten.copy op not supported for this case.");
}
// Legalizes the torch.aten.to.dtype op
template <>
LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
AtenToDtypeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Not a tensor type.
auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
if (!selfType || !selfType.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "Only tensor types with static shape are supported");
// The non_blocking arg should be a constant `False`.
bool nonBlocking;
if (!matchPattern(op.non_blocking(), m_TorchConstantBool(&nonBlocking))) {
return rewriter.notifyMatchFailure(
op, "unimplemented: non_blocking arg must be a constant");
} else if (nonBlocking) {
return rewriter.notifyMatchFailure(
op, "unimplemented: non_blocking arg is expected to be false");
}
// The copy arg should be a constant `False`.
bool copy;
if (!matchPattern(op.copy(), m_TorchConstantBool(&copy))) {
return rewriter.notifyMatchFailure(
op, "unimplemented: copy arg must be a constant");
} else if (copy) {
return rewriter.notifyMatchFailure(
op, "unimplemented: copy arg is expected to be false");
}
// Only `none`, `contiguous` and `preserve` memory_format is supported.
if (!op.memory_format().getType().isa<Torch::NoneType>()) {
int64_t memoryFormat;
if (!matchPattern(op.memory_format(), m_TorchConstantInt(&memoryFormat)))
return rewriter.notifyMatchFailure(
op, "unimplemented: the memory format should be specified in "
"an integer constant");
if (memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
memoryFormat != torch_upstream::MemoryFormat::Preserve)
return rewriter.notifyMatchFailure(
op, "unimplemented: only none, contiguous and preserve "
"memory_format is supported");
}
auto resultTy = getTypeConverter()
->convertType(op.getResult().getType())
.cast<RankedTensorType>();
Value result;
if (failed(tosa::tosaCastTensorToType(rewriter, op, adaptor.self(), resultTy,
result)))
return rewriter.notifyMatchFailure(op, "conversion to result type failed");
rewriter.replaceOp(op, result);
return success();
}
template <typename AtenOpT, typename TosaOpT>
class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
public:
@ -3511,7 +3638,7 @@ public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<tosa::TosaDialect>();
registry.insert<tensor::TensorDialect>();
registry.insert<arith::ArithmeticDialect>();
registry.insert<arith::ArithDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry);
}
@ -3519,7 +3646,7 @@ public:
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<tosa::TosaDialect, tensor::TensorDialect,
arith::ArithmeticDialect>();
arith::ArithDialect>();
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
@ -3704,6 +3831,9 @@ public:
INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
INSERT_ATENOP_PATTERN(ValsemVariantAtenCopyOp);
INSERT_ATENOP_PATTERN(AtenToDtypeOp);
#undef INSERT_ATENOP_PATTERN
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \

View File

@ -221,6 +221,64 @@ llvm::Optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
return const_op.getResult();
}
static LogicalResult checkValidityOfCast(Type src, Type dest) {
if ((src.isInteger(64) && dest.isInteger(32)) ||
(src.isInteger(32) && dest.isInteger(64)) ||
(src.isInteger(64) && dest.isInteger(1)) ||
(src.isInteger(32) && dest.isInteger(1)) ||
(src.isInteger(8) && dest.isInteger(1)) ||
(src.isF32() && dest.isInteger(1))) {
return success();
}
return failure();
}
// Template specialization for float
LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
Value src, Type destType, Value &result) {
Type srcElemTy = src.getType().dyn_cast<TensorType>().getElementType();
Type destElemTy = destType.dyn_cast<TensorType>().getElementType();
if (failed(checkValidityOfCast(srcElemTy, destElemTy)))
return rewriter.notifyMatchFailure(
op, "casting to result dtype is invalid or unsupported");
if (destElemTy.isInteger(1)) {
auto srcType = src.getType().dyn_cast<TensorType>();
SmallVector<int64_t> srcShape(srcType.getShape());
uint64_t num_total_elements = 1;
for (int64_t a : srcShape)
num_total_elements *= a;
llvm::Optional<Value> constOp;
if (srcElemTy.isInteger(64)) {
SmallVector<int64_t> values(num_total_elements, 0);
constOp =
tosa::getConstTensor<int64_t>(rewriter, op, values, srcShape).value();
} else if (srcElemTy.isInteger(32)) {
SmallVector<int32_t> values(num_total_elements, 0);
constOp =
tosa::getConstTensor<int32_t>(rewriter, op, values, srcShape).value();
} else if (srcElemTy.isF32()) {
SmallVector<float> values(num_total_elements, 0.0);
constOp =
tosa::getConstTensor<float>(rewriter, op, values, srcShape).value();
} else if (srcElemTy.isInteger(8)) {
SmallVector<int8_t> values(num_total_elements, 0);
constOp =
tosa::getConstTensor<int8_t>(rewriter, op, values, srcShape).value();
}
Value equalToZero = rewriter.create<tosa::EqualOp>(op->getLoc(), destType,
src, constOp.value());
result = rewriter.create<tosa::LogicalNotOp>(op->getLoc(), destType,
equalToZero);
} else {
result = rewriter.create<tosa::CastOp>(op->getLoc(), destType, src);
}
return success();
}
// Template instantiation
template llvm::Optional<Value> getConstTensor<int32_t>(PatternRewriter &,
Operation *,

View File

@ -5,7 +5,7 @@ add_mlir_conversion_library(TorchMLIRConversionUtils
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/Utils
LINK_LIBS PUBLIC
MLIRArithmeticDialect
MLIRArithDialect
MLIRLinalgDialect
TorchMLIRTorchDialect
)

View File

@ -9,7 +9,7 @@
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

View File

@ -2175,6 +2175,20 @@ OpFoldResult AtenDivFloatOp::fold(ArrayRef<Attribute> operands) {
return nullptr;
}
//===----------------------------------------------------------------------===//
// AtenDivIntOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenDivIntOp::fold(ArrayRef<Attribute> operands) {
int64_t lhs, rhs;
bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs));
bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs));
if (lConstant && rConstant)
return getF64FloatAttr(getContext(), double(lhs) / rhs);
return nullptr;
}
//===----------------------------------------------------------------------===//
// AtenCeilFloatOp
//===----------------------------------------------------------------------===//
@ -2185,8 +2199,6 @@ OpFoldResult AtenCeilFloatOp::fold(ArrayRef<Attribute> operands) {
return nullptr;
}
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// PrimMaxIntOp
//===----------------------------------------------------------------------===//

View File

@ -631,6 +631,21 @@ public:
};
} // namespace
// Decompose aten.mv into: aten.matmul.
namespace {
class DecomposeAtenMvOp : public OpRewritePattern<AtenMvOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenMvOp op,
PatternRewriter &rewriter) const override {
Value lhs = op.self();
Value rhs = op.vec();
rewriter.replaceOpWithNewOp<AtenMatmulOp>(op, op.getType(), lhs, rhs);
return success();
}
};
} // namespace
// ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6)
static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
Value input) {
@ -2185,8 +2200,9 @@ public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(Aten_ToCopyOp op,
PatternRewriter &rewriter) const override {
Value zero = rewriter.create<ConstantFloatOp>(
op.getLoc(), rewriter.getF64FloatAttr(0.0));
Type resultDtype = op.getType().cast<BaseTensorType>().getDtype();
Value zero = getConstantWithGivenDtypeAndValue(rewriter, op.getLoc(), 0.0,
resultDtype);
Value emptyTensor = rewriter.create<AtenFullLikeOp>(
op.getLoc(), op.getType(), op.self(), zero, op.dtype(), op.layout(),
op.device(), op.pin_memory(), op.memory_format());
@ -2859,6 +2875,8 @@ public:
patterns.add<DecomposeAtenSelectIntOp>(context);
target.addIllegalOp<AtenSelectIntOp>();
patterns.add<DecomposeAtenMatmulOp>(context);
target.addIllegalOp<AtenMvOp>();
patterns.add<DecomposeAtenMvOp>(context);
target.addIllegalOp<AtenTOp>();
patterns.add<DecomposeAtenTOp>(context);
patterns.add<DecomposeAten_LogSoftmaxBackwardDataOp>(context);

View File

@ -242,6 +242,16 @@ public:
});
}
};
class VerifyBackendContractPass
: public VerifyBackendContractBase<VerifyBackendContractPass> {
public:
void runOnOperation() override {
if (!satisfiesBackendContract(getOperation(), /*actuallyEmitDiagnostics=*/true)) {
return signalPassFailure();
}
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
@ -250,3 +260,8 @@ mlir::torch::Torch::createLowerToBackendContractPass(
return std::make_unique<LowerToBackendContractPass>(maxIterations, decompose,
backendLegalOps);
}
std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::Torch::createVerifyBackendContractPass() {
return std::make_unique<VerifyBackendContractPass>();
}

View File

@ -11,17 +11,8 @@
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//
namespace {
#define GEN_PASS_REGISTRATION
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
} // end namespace
void mlir::torch::registerTorchPasses() {
::registerPasses();
mlir::torch::registerPasses();
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
"torchscript-module-to-torch-backend-pipeline",
"Pipeline lowering TorchScript object graph IR to Torch backend form.",

View File

@ -701,7 +701,7 @@ void TypeAnalysis::visitOperation(Operation *op,
AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp, AtenMaskedFillTensorOp,
AtenRollOp, AtenPowTensorTensorOp, AtenLiftFreshCopyOp,
AtenIndexTensorHackedTwinOp>(op)) {
AtenIndexTensorHackedTwinOp, AtenUpsampleNearest2dVecOp>(op)) {
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
}
@ -754,7 +754,7 @@ void TypeAnalysis::visitOperation(Operation *op,
// Promote the two dtypes assuming non-zero rank.
if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp,
Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp,
Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp, AtenMvOp,
AtenConvolutionOverrideableOp, AtenConvTranspose2dInputOp>(op)) {
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
@ -767,8 +767,8 @@ void TypeAnalysis::visitOperation(Operation *op,
// Promote the two dtypes assuming possibly-zero rank.
if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp, AtenDivTensorOp,
AtenDivTensorModeOp, Aten__And__TensorOp, AtenMinimumOp,
AtenMaximumOp, AtenBitwiseAndTensorOp, AtenThresholdBackwardOp,
AtenFloorDivideOp>(op)) {
AtenMaximumOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp,
AtenThresholdBackwardOp, AtenFloorDivideOp>(op)) {
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
knowledge.dtype = getPromotedResultType(

File diff suppressed because it is too large Load Diff

View File

@ -163,3 +163,18 @@ bool Torch::isViewLikeOp(Operation *op) {
TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
AtenNarrowOp, AtenToDeviceOp>(op);
}
Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
Location loc, float value,
Type dtype) {
// Creating constants satisfying backend contract.
if (dtype.isInteger(64) || dtype.isInteger(32) || dtype.isInteger(8) ||
dtype.isInteger(1))
return rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr((int64_t)value));
if (dtype.isF64() || dtype.isF32() || dtype.isF16() || dtype.isBF16())
return rewriter.create<ConstantFloatOp>(loc,
rewriter.getF64FloatAttr(value));
llvm::report_fatal_error(
"unhandled type for getConstantWithGivenDtypeAndValue");
}

View File

@ -8,13 +8,13 @@
//===----------------------------------------------------------------------===//
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"

View File

@ -34,13 +34,13 @@ using namespace mlir::tosa;
// Pass registration
//===----------------------------------------------------------------------===//
namespace {
namespace reg {
#define GEN_PASS_REGISTRATION
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc"
} // end namespace
} // end namespace reg
void mlir::torch::registerTorchConversionPasses() {
::registerPasses();
reg::registerPasses();
mlir::PassPipelineRegistration<>(
"torch-backend-to-linalg-on-tensors-backend-pipeline",
"Pipeline lowering torch backend contract to linalg-on-tensors backend "

View File

@ -10,7 +10,7 @@
#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@ -71,8 +71,7 @@ class VerifyLinalgOnTensorsBackendContractPass
// Basic scalar operations.
target.addDynamicallyLegalDialect<func::FuncDialect>(isLegalScalarOp);
target.addDynamicallyLegalDialect<math::MathDialect>(isLegalScalarOp);
target.addDynamicallyLegalDialect<arith::ArithmeticDialect>(
isLegalScalarOp);
target.addDynamicallyLegalDialect<arith::ArithDialect>(isLegalScalarOp);
// Tensor operations should go through linalg and the tensor dialect.
target.addDynamicallyLegalDialect<linalg::LinalgDialect>(opHasLegalTypes);

View File

@ -10,7 +10,7 @@
#include "PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@ -53,7 +53,7 @@ class VerifyMhloBackendContractPass
target.addLegalDialect<mhlo::MhloDialect>();
target.addLegalDialect<chlo::ChloDialect>();
target.addLegalDialect<tensor::TensorDialect>();
target.addLegalDialect<arith::ArithmeticDialect>();
target.addLegalDialect<arith::ArithDialect>();
RewritePatternSet patterns(context);
if (failed(applyFullConversion(module, target, std::move(patterns)))) {

View File

@ -9,7 +9,7 @@
#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"

View File

@ -15,7 +15,7 @@
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@ -304,7 +304,7 @@ class ExpandOpsForLLVM : public ExpandOpsForLLVMBase<ExpandOpsForLLVM> {
ConversionTarget target(*context);
target.addLegalDialect<func::FuncDialect>();
target.addLegalDialect<math::MathDialect>();
target.addLegalDialect<arith::ArithmeticDialect>();
target.addLegalDialect<arith::ArithDialect>();
target.addIllegalOp<math::TanhOp>();
target.addIllegalOp<math::ErfOp>();
if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
@ -352,7 +352,7 @@ class MemrefCopyOpToLinalg : public OpRewritePattern<memref::CopyOp> {
LogicalResult matchAndRewrite(memref::CopyOp copyOp,
PatternRewriter &rewriter) const override {
Operation *linalgCopy = createLinalgCopyOp(
rewriter, copyOp.getLoc(), copyOp.source(), copyOp.target());
rewriter, copyOp.getLoc(), copyOp.getSource(), copyOp.getTarget());
rewriter.replaceOp(copyOp, linalgCopy->getResults());
return success();
}

View File

@ -65,7 +65,7 @@ def run_pipeline_with_repro_report(module,
{description} failed with the following diagnostics:
{sys.stderr.getvalue()}
Error can be reproduced with:
For Torch-MLIR developers, the error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='{pipeline}' {filename}
Add '{debug_options}' to get the IR dump for debugging purpose.
"""

View File

@ -391,7 +391,7 @@ at::Tensor LazyNativeFunctions::new_empty_strided_symint(
c10::optional<at::Device> device,
c10::optional<bool> pin_memory) {
return at::functionalization::
functionalize_aten_op<ATEN_OP(new_empty_strided)>::call(
functionalize_aten_op_symint<ATEN_OP(new_empty_strided)>::call(
self, size, stride, dtype, layout, device, pin_memory);
}
@ -400,7 +400,7 @@ at::Tensor LazyNativeFunctions::narrow_copy_symint(
int64_t dim,
c10::SymInt start,
c10::SymInt length) {
return at::functionalization::functionalize_aten_op<ATEN_OP(
return at::functionalization::functionalize_aten_op_symint<ATEN_OP(
narrow_copy)>::call(self, dim, start, length);
}
at::Tensor LazyNativeFunctions::pixel_shuffle(
@ -426,7 +426,7 @@ at::Tensor LazyNativeFunctions::slice_backward_symint(
c10::SymInt start,
c10::SymInt end,
c10::SymInt step) {
return at::functionalization::functionalize_aten_op<ATEN_OP(
return at::functionalization::functionalize_aten_op_symint<ATEN_OP(
slice_backward)>::call(grad_output, input_sizes, dim, start, end, step);
}
at::Tensor LazyNativeFunctions::diagonal_backward(

View File

@ -600,6 +600,9 @@ def atennumpy_T(self: List[int]) -> List[int]:
def atenmatmul(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.matmul(self, other)
def atenmv(self: List[int], vec: List[int]) -> List[int]:
return upstream_shape_functions.mv(self, vec)
def atenmm(self: List[int], mat2: List[int]) -> List[int]:
return upstream_shape_functions.mm(self, mat2)
@ -863,6 +866,9 @@ def atenminimum(self: List[int], other: List[int]) -> List[int]:
def atenmaximum(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.broadcast(self, other)
def atenbitwise_orTensor(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.broadcast(self, other)
def atenbitwise_andTensor(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.broadcast(self, other)
@ -1195,6 +1201,9 @@ def atenlinalg_vector_norm(self: List[int], ord: float = 2, dim: Optional[Lis
def atenfrobenius_normdim(self: List[int], dim: List[int], keepdim: bool = False) -> List[int]:
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0)
def atenupsample_nearest2dvec(input: List[int], output_size: Optional[List[int]], scale_factors: Optional[List[float]]) -> List[int]:
return upstream_shape_functions.upsample_nearest2d(input, output_size, scale_factors)
# ==============================================================================
# Shape library generator main().
# ==============================================================================

View File

@ -34,6 +34,8 @@ TORCH_TYPE_TO_ODS_TYPE = {
"bool?": "AnyTorchOptionalBoolType",
"float": "Torch_FloatType",
"float?": "AnyTorchOptionalFloatType",
"float[]": "AnyTorchListOfTorchFloatType",
"float[]?": "AnyTorchOptionalListOfTorchFloatType",
"t[]": "AnyTorchListType",
"t": "AnyTorchType",
"t1": "AnyTorchType",
@ -285,6 +287,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
"aten::abs : (Tensor) -> (Tensor)",
"aten::reciprocal : (Tensor) -> (Tensor)",
"aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::bitwise_or.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)",
"aten::square : (Tensor) -> (Tensor)",
"aten::unsqueeze : (Tensor, int) -> (Tensor)",
@ -333,6 +336,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::mm : (Tensor, Tensor) -> (Tensor)")
emit("aten::addmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)")
emit("aten::matmul : (Tensor, Tensor) -> (Tensor)")
emit("aten::mv : (Tensor, Tensor) -> (Tensor)")
emit(
"aten::conv2d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)"
)
@ -515,6 +519,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)")
emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)")
emit("aten::as_strided_scatter : (Tensor, Tensor, int[], int[], int?) -> (Tensor)")
emit("aten::upsample_nearest2d.vec : (Tensor, int[]?, float[]?) -> (Tensor)")
# Dict ops.
@ -566,6 +571,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::add.int : (int, int) -> (int)", has_folder=True)
emit("aten::sub.int : (int, int) -> (int)", has_folder=True)
emit("aten::mul.int : (int, int) -> (int)", has_folder=True)
emit("aten::div.int : (int, int) -> (float)", has_folder=True)
emit("aten::neg.int : (int) -> (int)", has_folder=True)
emit("aten::log.int : (int) -> (float)")
emit("aten::add.float_int : (float, int) -> (float)")

View File

@ -276,7 +276,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.prim.ListConstruct", loc,
torchMlirTorchListTypeGet(
getMlirTypeFromTorchType(loc, list.elementType())),
getMlirTypeFromTorchType(loc, list.elementType(), importOptions)),
elems);
return mlirOperationGetResult(operation, 0);
}
@ -291,8 +291,8 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.prim.DictConstruct", loc,
torchMlirTorchDictTypeGet(
getMlirTypeFromTorchType(loc, dict.keyType()),
getMlirTypeFromTorchType(loc, dict.valueType())),
getMlirTypeFromTorchType(loc, dict.keyType(), importOptions),
getMlirTypeFromTorchType(loc, dict.valueType(), importOptions)),
keys, values);
return mlirOperationGetResult(operation, 0);
}
@ -368,10 +368,20 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
at::Tensor tensor = ivalue.toTensor().contiguous();
MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
MlirOperation tensorOp = createMlirOperationAtEnd(
importBlock, "torch.tensor.literal", loc,
torchMlirTorchNonValueTensorTypeGetFromAttribute(denseElements),
toMlirNamedAttribute("value", denseElements));
MlirOperation tensorOp;
if (importOptions.assumeTensorsHaveValueSemantics) {
tensorOp = createMlirOperationAtEnd(
importBlock, "torch.vtensor.literal", loc,
torchMlirTorchValueTensorTypeGetFromAttribute(denseElements),
toMlirNamedAttribute("value", denseElements));
} else {
tensorOp = createMlirOperationAtEnd(
importBlock, "torch.tensor.literal", loc,
torchMlirTorchNonValueTensorTypeGetFromAttribute(denseElements),
toMlirNamedAttribute("value", denseElements));
}
MlirValue tensorReprValue = mlirOperationGetResult(tensorOp, 0);
// Construct the complete tensor value. This is trivial for most tensors, but
@ -384,9 +394,16 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
// compiler stages that are building a statically modeled quantization
// representation will need to convert this to their representation.
std::vector<int64_t> shape(tensor.sizes().begin(), tensor.sizes().end());
MlirType quantizedTensorType = torchMlirTorchNonValueTensorTypeGet(
context, shape.size(), shape.data(),
getMlirTypeForTorchScalarType(loc, tensor.scalar_type()));
MlirType quantizedTensorType;
if (importOptions.assumeTensorsHaveValueSemantics) {
quantizedTensorType = torchMlirTorchValueTensorTypeGet(
context, shape.size(), shape.data(),
getMlirTypeForTorchScalarType(loc, tensor.scalar_type()));
} else {
quantizedTensorType = torchMlirTorchNonValueTensorTypeGet(
context, shape.size(), shape.data(),
getMlirTypeForTorchScalarType(loc, tensor.scalar_type()));
}
if (tensor.qscheme() == c10::kPerTensorAffine) {
MlirValue qScale = importIValue(c10::IValue(tensor.q_scale()));
MlirValue zeroPoint = importIValue(c10::IValue(tensor.q_zero_point()));
@ -463,7 +480,7 @@ void IValueImporter::importClassType(c10::ClassType *classType) {
"name", mlirStringAttrGet(
context, toMlirStringRef(classAttribute.getName()))),
toMlirNamedAttribute("type", mlirTypeAttrGet(getMlirTypeFromTorchType(
loc, classAttribute.getType()))),
loc, classAttribute.getType(), importOptions))),
isPrivate);
}

View File

@ -124,10 +124,16 @@ ModuleBuilder::ModuleBuilder(pybind11::object contextObj)
}
torch::jit::StrongFunctionPtr
ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) {
ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function,
py::object maybeImportOptions) {
ImportOptions importOptions;
if (!maybeImportOptions.is_none()) {
importOptions = py::cast<ImportOptions>(maybeImportOptions);
}
MlirBlock block = getBodyBlock();
MlirOperation terminator = this->terminator;
MlirOperation func = importJitFunctionAsFuncOp(context, function.function_);
MlirOperation func = importJitFunctionAsFuncOp(context, function.function_,
[](int) -> MlirAttribute { return {nullptr}; }, importOptions);
mlirBlockInsertOwnedOperationBefore(block, terminator, func);
return function;
}
@ -182,7 +188,8 @@ void ModuleBuilder::bind(py::module &m) {
.def(py::init<py::object>(), py::arg("context") = py::none())
.def_property_readonly("context", &ModuleBuilder::getContextObj)
.def_property_readonly("module", &ModuleBuilder::getModuleObj)
.def("import_function", &ModuleBuilder::importFunction)
.def("import_function", &ModuleBuilder::importFunction, py::arg("function"),
py::arg("importOptions") = py::none())
.def("import_module", &ModuleBuilder::importModule, py::arg("module"),
py::arg("classAnnotator") = py::none(),
py::arg("importOptions") = py::none());

View File

@ -39,7 +39,8 @@ public:
// Just a bit of naming cruft.
// Returns the same function, making it suitable as a nested decorator.
torch::jit::StrongFunctionPtr
importFunction(torch::jit::StrongFunctionPtr function);
importFunction(torch::jit::StrongFunctionPtr function,
py::object maybeImportOptions);
// Imports a torch::jit::Module into the current module, using the
// annotations, if not none, provided in `maybeClassAnnotator` which should be

View File

@ -198,10 +198,17 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
c10::attr::value)))));
} else if (output->type()->cast<c10::TensorType>()) {
MlirAttribute attr = importAttribute(loc, node, c10::attr::value);
op = createMlirOperation(
"torch.tensor.literal", loc,
torchMlirTorchNonValueTensorTypeGetFromAttribute(attr),
toMlirNamedAttribute("value", attr));
if (importOptions.assumeTensorsHaveValueSemantics) {
op = createMlirOperation(
"torch.vtensor.literal", loc,
torchMlirTorchValueTensorTypeGetFromAttribute(attr),
toMlirNamedAttribute("value", attr));
} else {
op = createMlirOperation(
"torch.tensor.literal", loc,
torchMlirTorchNonValueTensorTypeGetFromAttribute(attr),
toMlirNamedAttribute("value", attr));
}
} else if (output->type()->cast<c10::DeviceObjType>()) {
op = createMlirOperation(
"torch.constant.device", loc,

View File

@ -2513,6 +2513,25 @@ def ToCopyWithDTypeFalsePinMemoryModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 2, 4))
class ToCopyBoolDTypeStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([1, 1, 5, 5], torch.uint8, True),
])
def forward(self, x):
return torch.ops.aten._to_copy(x, dtype=torch.bool)
@register_test_case(module_factory=lambda: ToCopyBoolDTypeStaticModule())
def ToCopyBoolDTypeStaticModule_basic(module, tu: TestUtils):
module.forward(tu.randint(1, 1, 5, 5).to(dtype=torch.uint8))
# ==============================================================================

View File

@ -705,3 +705,80 @@ class Conv_Transpose3dModule(torch.nn.Module):
@register_test_case(module_factory=lambda: Conv_Transpose3dModule())
def Conv_Transpose3dModule_basic(module, tu: TestUtils):
module.forward(torch.randn(5, 2, 5, 6, 4), torch.randn(2, 5, 2, 2, 2))
class UpSampleNearest2dSameSize(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, inputVec):
return torch._C._nn.upsample_nearest2d(inputVec,
output_size=[11, 11],
scale_factors=None)
@register_test_case(module_factory=lambda: UpSampleNearest2dSameSize())
def UpSampleNearest2dStaticSize_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 1, 4, 4))
class UpSampleNearest2dDiffSize(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
def forward(self, inputVec):
return torch._C._nn.upsample_nearest2d(inputVec,
output_size=[8, 11],
scale_factors=None)
@register_test_case(module_factory=lambda: UpSampleNearest2dDiffSize())
def UpSampleNearest2dDynamicSize_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 2, 2))
class UpSampleNearest2dDiffFactor(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
def forward(self, inputVec):
return torch._C._nn.upsample_nearest2d(inputVec,
output_size=None,
scale_factors=[2.3, 4.7])
@register_test_case(module_factory=lambda: UpSampleNearest2dDiffFactor())
def UpSampleNearest2dDynamicFactor_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 2, 2))
class UpSampleNearest2dSameFactor(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, inputVec):
return torch._C._nn.upsample_nearest2d(inputVec,
output_size=None,
scale_factors=[2.0, 2.0])
@register_test_case(module_factory=lambda: UpSampleNearest2dSameFactor())
def UpSampleNearest2dStaticFactor_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4, 4))

View File

@ -1553,6 +1553,31 @@ def ElementwiseAndIntegerModule_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseOrIntegerModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int32, True),
([-1, -1], torch.int64, True),
])
def forward(self, x, y):
return torch.bitwise_or(x, y)
@register_test_case(module_factory=lambda: ElementwiseOrIntegerModule())
def ElementwiseOrIntegerModule_basic(module, tu: TestUtils):
module.forward(
tu.randint(3, 4, low=-10, high=10).to(torch.int32),
tu.randint(3, 4, low=-10, high=10))
# ==============================================================================
class ElementwiseNotIntegerModule(torch.nn.Module):
def __init__(self):

View File

@ -209,3 +209,20 @@ class MatmulBroadcastBatchDim(torch.nn.Module):
def MatmulBroadcastBatchDim_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 5, 6, 7), tu.rand(5, 7, 6))
# ==============================================================================
class Mv(torch.nn.Module):
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
([-1], torch.float32, True),
])
def forward(self, m, v):
return torch.mv(m, v)
@register_test_case(module_factory=lambda: Mv())
def Mv_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 2), tu.rand(2))

View File

@ -104,6 +104,31 @@ def MulIntModule_basic(module, tu: TestUtils):
# ==============================================================================
class DivIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([], torch.int64, True),
([], torch.int64, True),
])
def forward(self, lhs, rhs):
# Cast the result to float to make e2e test baseline result to be a float.
# Without the cast, baseline result is a Tensor which is unexpected.
return float(torch.ops.aten.div(int(lhs), int(rhs)))
@register_test_case(module_factory=lambda: DivIntModule())
def DivIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(low=-10, high=10), tu.randint(low=3, high=10))
# ==============================================================================
class DivFloatModule(torch.nn.Module):
def __init__(self):

View File

@ -193,13 +193,13 @@ def ToDtypeLayoutStridedModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5))
class ToDtypeBoolLayoutNoneModule(torch.nn.Module):
class ToDtypeBoolLayoutNoneStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([None, ([-1, -1], torch.float32, True)])
@annotate_args([None, ([3, 5], torch.int64, True)])
def forward(self, x):
return torch.ops.aten.to(x,
dtype=torch.bool,
@ -211,9 +211,9 @@ class ToDtypeBoolLayoutNoneModule(torch.nn.Module):
memory_format=None)
@register_test_case(module_factory=lambda: ToDtypeBoolLayoutNoneModule())
def ToDtypeBoolLayoutNoneModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5))
@register_test_case(module_factory=lambda: ToDtypeBoolLayoutNoneStaticModule())
def ToDtypeBoolLayoutNoneStaticModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 5))
class TypeAsSameModule(torch.nn.Module):

View File

@ -1,3 +1,3 @@
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
--pre
torch==1.13.0.dev20220930
torch==1.13.0.dev20221004

View File

@ -1 +1 @@
3bf7094ddb95e8b9bcd2b1f35589e729aa2f4248
9f3d8fec5747fde5191618eb895fbec2d50edf93

View File

@ -847,3 +847,69 @@ func.func @torch.aten.arange.start_step() -> !torch.vtensor<[5],si64> {
%0 = torch.aten.arange.start_step %int0, %int5, %int1, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[5],si64>
return %0 : !torch.vtensor<[5],si64>
}
// -----
// CHECK-LABEL: func.func @torch.prim.NumToTensor.Scalar() -> !torch.vtensor<[],si64> {
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[VAL_0:.*]] = "tosa.const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<i64> -> !torch.vtensor<[],si64>
// CHECK: return %[[VAL_1]] : !torch.vtensor<[],si64>
func.func @torch.prim.NumToTensor.Scalar() -> !torch.vtensor<[],si64> {
%int1 = torch.constant.int 1
%0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64>
return %0 : !torch.vtensor<[],si64>
}
// -----
// CHECK-LABEL: func.func @torch.valsem.aten.copy(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,1,5,5],ui8>) -> !torch.vtensor<[1,1,5,5],i1> {
// CHECK: %[[INP:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,1,5,5],ui8> -> tensor<1x1x5x5xi8>
// CHECK: %[[CST5:.*]] = torch.constant.int 5
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[CST11:.*]] = torch.constant.int 11
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[VAL_0:.*]] = "tosa.const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
// CHECK: %[[VAL_1:.*]] = "tosa.const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
// CHECK: %[[VAL_2:.*]] = "tosa.equal"(%[[VAL_0]], %[[VAL_1]]) : (tensor<i64>, tensor<i64>) -> tensor<i1>
// CHECK: %[[VAL_3:.*]] = "tosa.logical_not"(%[[VAL_2]]) : (tensor<i1>) -> tensor<i1>
// CHECK: %[[VAL_4:.*]] = "tosa.const"() {value = dense<0> : tensor<1x1x5x5xi8>} : () -> tensor<1x1x5x5xi8>
// CHECK: %[[VAL_5:.*]] = "tosa.equal"(%[[INP]], %[[VAL_4]]) : (tensor<1x1x5x5xi8>, tensor<1x1x5x5xi8>) -> tensor<1x1x5x5xi1>
// CHECK: %[[VAL_6:.*]] = "tosa.logical_not"(%[[VAL_5]]) : (tensor<1x1x5x5xi1>) -> tensor<1x1x5x5xi1>
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x1x5x5xi1> -> !torch.vtensor<[1,1,5,5],i1>
// CHECK: return %[[VAL_7]] : !torch.vtensor<[1,1,5,5],i1>
func.func @torch.valsem.aten.copy(%arg0: !torch.vtensor<[1,1,5,5],ui8>) -> !torch.vtensor<[1,1,5,5],i1> {
%int5 = torch.constant.int 5
%int1 = torch.constant.int 1
%int11 = torch.constant.int 11
%none = torch.constant.none
%false = torch.constant.bool false
%int0 = torch.constant.int 0
%0 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
%1 = torch.aten.to.dtype %0, %int11, %false, %false, %none : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],i1>
%2 = torch.prim.ListConstruct %int1, %int1, %int5, %int5 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%3 = torch.aten.broadcast_to %1, %2 : !torch.vtensor<[],i1>, !torch.list<int> -> !torch.vtensor<[1,1,5,5],i1>
%4 = torch.valsem.aten.copy %3, %arg0, %false : !torch.vtensor<[1,1,5,5],i1>, !torch.vtensor<[1,1,5,5],ui8>, !torch.bool -> !torch.vtensor<[1,1,5,5],i1>
return %4 : !torch.vtensor<[1,1,5,5],i1>
}
// -----
// CHECK-LABEL: func.func @torch.aten.to.dtype(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[3,5],si64>) -> !torch.vtensor<[3,5],i1> {
// CHECK: %[[INP:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[3,5],si64> -> tensor<3x5xi64>
// CHECK: %[[CST11:.*]] = torch.constant.int 11
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[VAL_0:.*]] = "tosa.const"() {value = dense<0> : tensor<3x5xi64>} : () -> tensor<3x5xi64>
// CHECK: %[[VAL_1:.*]] = "tosa.equal"(%[[INP]], %[[VAL_0]]) : (tensor<3x5xi64>, tensor<3x5xi64>) -> tensor<3x5xi1>
// CHECK: %[[VAL_2:.*]] = "tosa.logical_not"(%[[VAL_1]]) : (tensor<3x5xi1>) -> tensor<3x5xi1>
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x5xi1> -> !torch.vtensor<[3,5],i1>
// CHECK: return %[[VAL_3]] : !torch.vtensor<[3,5],i1>
func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[3,5],si64>) -> !torch.vtensor<[3,5],i1> {
%int11 = torch.constant.int 11
%none = torch.constant.none
%false = torch.constant.bool false
%0 = torch.aten.to.dtype %arg0, %int11, %false, %false, %none : !torch.vtensor<[3,5],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,5],i1>
return %0 : !torch.vtensor<[3,5],i1>
}

View File

@ -1351,6 +1351,16 @@ func.func @torch.aten.div.float$fold_cst_operands() -> !torch.float {
return %0 : !torch.float
}
// CHECK-LABEL: func.func @torch.aten.div.int$fold_cst_operands(
// CHECK: %[[CST:.*]] = torch.constant.float 5.000000e-01
// CHECK: return %[[CST]] : !torch.float
func.func @torch.aten.div.int$fold_cst_operands() -> !torch.float {
%int2 = torch.constant.int 2
%int4 = torch.constant.int 4
%0 = torch.aten.div.int %int2, %int4 : !torch.int, !torch.int -> !torch.float
return %0 : !torch.float
}
// CHECK-LABEL: func.func @torch.aten.to.dtype_layout$same_dtype(
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> {
// CHECK-NEXT: return %[[ARG]] : !torch.tensor<[?,?],f32>

View File

@ -0,0 +1,42 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See LICENSE.pytorch for license information.
import typing
import torch
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder
# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s
mb = ModuleBuilder()
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.ones_i32 = torch.ones(1, dtype=torch.int32)
self.ones_qint8 = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.qint8)
self.arange = torch.nn.Parameter(torch.arange(3.0))
# CHECK: %[[ARANGE:.*]] = torch.vtensor.literal(dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>) : !torch.vtensor<[3],f32>
# CHECK: %[[ONES_I32:.*]] = torch.vtensor.literal(dense<1> : tensor<1xsi32>) : !torch.vtensor<[1],si32>
# CHECK: %[[ONES_QINT8_DATA:.*]] = torch.vtensor.literal(dense<1> : tensor<1xsi8>) : !torch.vtensor<[1],si8>
# CHECK: %[[SCALE:.*]] = torch.constant.float 1.000000e+00
# CHECK: %[[ZERO_POINT:.*]] = torch.constant.int 0
# CHECK: %[[ONES_QINT8:.*]] = torch.per_tensor_affine.create %[[ONES_QINT8_DATA]], %[[SCALE]], %[[ZERO_POINT]] : !torch.vtensor<[1],si8>, !torch.float, !torch.int -> !torch.vtensor<[1],!torch.qint8>
# CHECK: %[[ROOT:.*]] = torch.nn_module {
# CHECK: torch.slot "arange", %[[ARANGE]] : !torch.vtensor<[3],f32>
# CHECK: torch.slot "ones_i32", %[[ONES_I32]] : !torch.vtensor<[1],si32>
# CHECK: torch.slot "ones_qint8", %[[ONES_QINT8]] : !torch.vtensor<[1],!torch.qint8>
# CHECK: }
test_module = TestModule()
recursivescriptmodule = torch.jit.script(test_module)
import_options = ImportOptions()
import_options.assumeTensorsHaveValueSemantics = True
class_annotator = ClassAnnotator()
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
mb.import_module(recursivescriptmodule._c, class_annotator, import_options)
mb.module.operation.print()

View File

@ -5,7 +5,7 @@
import typing
import torch
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder
from utils import create_script_function
@ -162,3 +162,34 @@ graph():
mb.module.operation.print()
print()
# CHECK-LABEL: func.func @__torch__.prim_Constant_scalar() -> !torch.number {
# CHECK: %[[A:.*]] = torch.tensor.literal
# CHECK: %[[RET:.*]] = torch.aten.ScalarImplicit
# CHECK: return %[[RET]] : !torch.number
import_options = ImportOptions()
import_options.assumeTensorsHaveValueSemantics = False
mb.import_function(create_script_function("__torch__.prim_Constant_scalar", """
graph():
%0 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
%1 : Scalar = aten::ScalarImplicit(%0)
return (%1)
""", parse_tensor_constants=True), import_options)
mb.module.operation.print()
print()
# CHECK-LABEL: func.func @__torch__.prim_Constant_scalar_value_semantics() -> !torch.number {
# CHECK: %[[A:.*]] = torch.vtensor.literal
# CHECK: %[[RET:.*]] = torch.aten.ScalarImplicit
# CHECK: return %[[RET]] : !torch.number
import_options.assumeTensorsHaveValueSemantics = True
mb.import_function(create_script_function("__torch__.prim_Constant_scalar_value_semantics", """
graph():
%0 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
%1 : Scalar = aten::ScalarImplicit(%0)
return (%1)
""", parse_tensor_constants=True), import_options)
mb.module.operation.print()
print()

View File

@ -10,6 +10,6 @@ from torch._C import CompilationUnit
# RUN: %PYTHON %s
# Import TorchScript IR string as ScriptFunction.
def create_script_function(func_name, ts_ir_str):
def create_script_function(func_name, ts_ir_str, **kwargs):
cu = CompilationUnit()
return cu.create_function(func_name, torch._C.parse_ir(ts_ir_str))
return cu.create_function(func_name, torch._C.parse_ir(ts_ir_str, **kwargs))

View File

@ -337,7 +337,7 @@ cc_library(
strip_include_prefix = "include",
deps = [
":TorchMLIRTorchDialect",
"@llvm-project//mlir:ArithmeticDialect",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:ControlFlowDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgDialect",
@ -373,7 +373,7 @@ cc_library(
":TorchMLIRTorchBackendTypeConversion",
":TorchMLIRTorchConversionDialect",
":TorchMLIRTorchDialect",
"@llvm-project//mlir:ArithmeticDialect",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:ControlFlowDialect",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:LinalgDialect",
@ -397,7 +397,7 @@ cc_library(
":TorchMLIRConversionPassesIncGen",
":TorchMLIRTorchBackendTypeConversion",
":TorchMLIRTorchConversionDialect",
"@llvm-project//mlir:ArithmeticDialect",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:ControlFlowDialect",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:LinalgDialect",
@ -809,7 +809,7 @@ cc_library(
":TorchMLIRRefBackendPassIncGen",
":TorchMLIRTorchBackendTypeConversion",
":TorchMLIRTorchConversionDialect",
"@llvm-project//mlir:ArithmeticTransforms",
"@llvm-project//mlir:ArithTransforms",
"@llvm-project//mlir:LinalgDialect",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:MathTransforms",