mirror of https://github.com/llvm/torch-mlir
[MHLO] Init MHLO integration. (#1083)
Co-authored-by: Bairen Yi <yibairen.byron@bytedance.com>
Co-authored-by: Jiawei Wu <xremold@gmail.com>
Co-authored-by: Tianyou Guo <tianyou.gty@alibaba-inc.com>
Co-authored-by: Xu Yan <yancey.yx@alibaba-inc.com>
Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>

pull/1084/head
parent 647e75e029
commit c61c99e887
@@ -37,6 +37,7 @@ jobs:
 -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$GITHUB_WORKSPACE" \
 -DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR="${GITHUB_WORKSPACE}/external/llvm-external-projects/torch-mlir-dialects" \
 -DMLIR_ENABLE_BINDINGS_PYTHON=ON \
+-DTORCH_MLIR_ENABLE_MHLO=ON \
 -DLLVM_TARGETS_TO_BUILD=host
 ninja check-torch-mlir-all
 - name: RefBackend - TorchScript end-to-end tests

@@ -81,6 +82,7 @@ jobs:
 -DLLVM_ENABLE_PROJECTS=mlir \
 -DMLIR_ENABLE_BINDINGS_PYTHON=ON \
 -DLLVM_TARGETS_TO_BUILD=host \
+-DTORCH_MLIR_ENABLE_MHLO=ON \
 externals/llvm-project/llvm
 ninja -Cllvm-build

@@ -94,6 +96,7 @@ jobs:
 -DMLIR_DIR="$(pwd)/llvm-build/lib/cmake/mlir/" \
 -DLLVM_DIR="$(pwd)/llvm-build/lib/cmake/llvm/" \
 -DMLIR_ENABLE_BINDINGS_PYTHON=ON \
+-DTORCH_MLIR_ENABLE_MHLO=ON \
 -DPython3_EXECUTABLE=$(which python) \
 .
 ninja -Cbuild check-torch-mlir-all
@@ -1,3 +1,6 @@
 [submodule "external/llvm-project"]
 	path = externals/llvm-project
 	url = https://github.com/llvm/llvm-project.git
+[submodule "externals/mlir-hlo"]
+	path = externals/mlir-hlo
+	url = https://github.com/tensorflow/mlir-hlo.git
@@ -36,12 +36,18 @@ macro(torch_mlir_add_llvm_external_project name identifier location)
   set(LLVM_EXTERNAL_PROJECTS ${LLVM_EXTERNAL_PROJECTS} CACHE STRING "" FORCE)
 endmacro()
 
+option(TORCH_MLIR_ENABLE_MHLO "Add mhlo dialect" ON)
+if(TORCH_MLIR_ENABLE_MHLO)
+  add_definitions(-DTORCH_MLIR_ENABLE_MHLO)
+endif()
+
 torch_mlir_add_llvm_external_project(
   torch-mlir-dialects
   TORCH_MLIR_DIALECTS
   ${CMAKE_CURRENT_SOURCE_DIR}/externals/llvm-external-projects/torch-mlir-dialects)
 
 if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
+  message(STATUS "Torch-MLIR out-of-tree build.")
   # Out-of-tree build
 
   #-------------------------------------------------------------------------------

@@ -82,10 +88,14 @@ if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
   set(BACKEND_PACKAGE_STRING "LLVM ${LLVM_PACKAGE_VERSION}")
   add_subdirectory(externals/llvm-external-projects/torch-mlir-dialects)
 else()
+  message(STATUS "Torch-MLIR in-tree build.")
   # In-tree build with LLVM_EXTERNAL_PROJECTS=torch-mlir
   # FIXME: This should really be inherited from the LLVM tree. In particular,
   # it's going to change when cross-compiling.
   set(MLIR_TABLEGEN_EXE mlir-tblgen)
+  if (TORCH_MLIR_ENABLE_MHLO)
+    set(MLIR_PDLL_TABLEGEN_EXE mlir-pdll)
+  endif()
 
   option(MLIR_ENABLE_BINDINGS_PYTHON "Enables MLIR Python Bindings" OFF)
   option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON)

@@ -97,6 +107,15 @@ else()
   set(MLIR_INCLUDE_DIRS "${MLIR_INCLUDE_DIR};${MLIR_GENERATED_INCLUDE_DIR}")
 endif()
 
+if (TORCH_MLIR_ENABLE_MHLO)
+  set(MHLO_BUILD_EMBEDDED ON)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo
+    ${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo
+    EXCLUDE_FROM_ALL)
+  include_directories(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo/include)
+  include_directories(${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo/include)
+endif()
+
 set(TORCH_MLIR_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
 set(TORCH_MLIR_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")
 message(STATUS "Building torch-mlir project at ${TORCH_MLIR_SOURCE_DIR} (into ${TORCH_MLIR_BINARY_DIR})")
@@ -0,0 +1 @@
+Subproject commit eb1042390d39131fe7e330b54dc5f29a79c9a072
@@ -1,5 +1,9 @@
 set(LLVM_TARGET_DEFINITIONS Passes.td)
+if(TORCH_MLIR_ENABLE_MHLO)
+  mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_MHLO)
+else()
 mlir_tablegen(Passes.h.inc -gen-pass-decls)
+endif()
 add_public_tablegen_target(TorchMLIRConversionPassIncGen)
 
 add_mlir_doc(Passes TorchMLIRConversionPasses ./ -gen-pass-doc)
@@ -125,4 +125,14 @@ def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> {
   let constructor = "mlir::torch::createConvertTorchToTMTensorPass()";
 }
 
+#ifdef TORCH_MLIR_ENABLE_MHLO
+def ConvertTorchToMhlo : Pass<"convert-torch-to-mhlo", "func::FuncOp"> {
+  let summary = "Convert Torch ops to MHLO ops";
+  let description = [{
+    Convert Torch ops to mhlo ops.
+  }];
+  let constructor = "mlir::torch::createConvertTorchToMhloPass()";
+}
+#endif
+
 #endif // TORCHMLIR_CONVERSION_PASSES
@@ -0,0 +1,23 @@
+//===------------------------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H
+#define TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Pass/Pass.h"
+#include <memory>
+
+namespace mlir {
+namespace torch {
+std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToMhloPass();
+} // namespace torch
+} // namespace mlir
+
+#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H
@@ -34,6 +34,13 @@ void createTorchBackendToTosaBackendPipeline(
     OpPassManager &pm,
     const torch::Torch::TorchLoweringPipelineOptions &options);
 
+// Do not register the torch-to-mhlo pipeline if mhlo target is disabled
+#ifdef TORCH_MLIR_ENABLE_MHLO
+void createTorchBackendToMhloBackendPipeline(
+    OpPassManager &pm,
+    const torch::Torch::TorchLoweringPipelineOptions &options);
+#endif
+
 std::unique_ptr<OperationPass<ModuleOp>>
 createVerifyInvariantsBeforeBackendLoweringPass();
@@ -2,11 +2,22 @@ add_subdirectory(TorchToLinalg)
 add_subdirectory(TorchToSCF)
 add_subdirectory(TorchToStd)
 add_subdirectory(TorchToTosa)
+if(TORCH_MLIR_ENABLE_MHLO)
+  add_subdirectory(TorchToMhlo)
+endif()
 add_subdirectory(TorchToTMTensor)
 add_subdirectory(Utils)
 
 # TODO: Automate this with add_torch_mlir_conversion_library.
-#get_property(torch_mlir_conversion_libs GLOBAL PROPERTY TORCH_MLIR_CONVERSION_LIBS)
+set(linked_libs TorchMLIRTorchToLinalg
+                TorchMLIRTorchToSCF
+                TorchMLIRTorchToStd
+                TorchMLIRTorchToTosa
+                TorchMLIRTorchToTMTensor
+                TorchMLIRConversionUtils)
+if(TORCH_MLIR_ENABLE_MHLO)
+  list(APPEND linked_libs TorchMLIRTorchToMhlo)
+endif()
 
 add_mlir_library(TorchMLIRConversionPasses
   Passes.cpp

@@ -18,11 +29,6 @@ add_mlir_library(TorchMLIRConversionPasses
   Core
 
   LINK_LIBS PUBLIC
-  TorchMLIRTorchToLinalg
-  TorchMLIRTorchToSCF
-  TorchMLIRTorchToStd
-  TorchMLIRTorchToTosa
-  TorchMLIRTorchToTMTensor
-  TorchMLIRConversionUtils
+  ${linked_libs}
   #${torch_mlir_conversion_libs}
 )
@@ -13,6 +13,7 @@
 #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
 #include "torch-mlir/Conversion/TorchToStd/TorchToStd.h"
 #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
+#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
 #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
 
 //===----------------------------------------------------------------------===//
@@ -0,0 +1,71 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
+
+#include "../PassDetail.h"
+#include "./PopulatePatterns.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "torch-mlir/Conversion/Utils/Utils.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
+#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
+#include <iostream>
+#include <numeric>
+
+using namespace mlir;
+using namespace mlir::torch;
+using namespace mlir::torch::Torch;
+
+
+namespace {
+template <typename AtenOpT>
+class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+// AtenTanhOp
+namespace {
+template <>
+LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
+    AtenTanhOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value self = adaptor.self();
+  auto selfTy = self.getType().cast<TensorType>();
+  if (selfTy && selfTy.getElementType().isa<mlir::FloatType>()) {
+    rewriter.replaceOpWithNewOp<mhlo::TanhOp>(
+        op, getTypeConverter()->convertType(op.getType()), self);
+    return success();
+  } else {
+    return op.emitError(
+        "Only floating-point datatype legalization currently supported");
+  }
+}
+} // namespace
+
+void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target) {
+  MLIRContext *context = patterns.getContext();
+
+#define INSERT_ATENOP_PATTERN(AtenOp) \
+  target.addIllegalOp<AtenOp>(); \
+  patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
+  INSERT_ATENOP_PATTERN(AtenTanhOp);
+#undef INSERT_ATENOP_PATTERN
+
+}
@@ -0,0 +1,22 @@
+add_mlir_conversion_library(TorchMLIRTorchToMhlo
+  TorchToMhlo.cpp
+  BasicOp.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToMhlo
+
+  DEPENDS
+  MhloDialect
+  TorchMLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRPass
+  MhloDialect
+  TorchMLIRTorchDialect
+)
+
+torch_mlir_target_includes(TorchMLIRTorchToMhlo)
@@ -0,0 +1,27 @@
+//===------------------------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H
+#define TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace torch {
+namespace torch_to_mhlo {
+
+void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter,
+                                        RewritePatternSet &patterns,
+                                        ConversionTarget &target);
+
+} // namespace torch_to_mhlo
+} // namespace torch
+} // namespace mlir
+
+#endif // TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H
@@ -0,0 +1,66 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
+
+#include "../PassDetail.h"
+#include "./PopulatePatterns.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Traits.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
+#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
+
+using namespace mlir;
+using namespace mlir::torch;
+using namespace mlir::torch::Torch;
+
+namespace {
+
+class ConvertTorchToMhlo : public ConvertTorchToMhloBase<ConvertTorchToMhlo> {
+public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<mhlo::MhloDialect>();
+    registry.insert<tensor::TensorDialect>();
+    registry.insert<arith::ArithmeticDialect>();
+    TorchConversion::getBackendTypeConversionDependentDialects(registry);
+  }
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    ConversionTarget target(*context);
+    target.addLegalDialect<mhlo::MhloDialect, tensor::TensorDialect,
+                           arith::ArithmeticDialect, Torch::TorchDialect>();
+
+    TypeConverter typeConverter;
+    typeConverter.addConversion([](Type type) { return type; });
+    TorchConversion::setupBackendTypeConversion(target, typeConverter);
+
+    RewritePatternSet patterns(context);
+
+    torch_to_mhlo::populateBasicOpPatternsAndLegality(typeConverter, patterns,
+                                                      target);
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::torch::createConvertTorchToMhloPass() {
+  return std::make_unique<ConvertTorchToMhlo>();
+}
@@ -20,6 +20,9 @@
 #include "torch-mlir/Conversion/TorchToStd/TorchToStd.h"
 #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
 #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
+#ifdef TORCH_MLIR_ENABLE_MHLO
+#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
+#endif
 #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
 
 using namespace mlir;

@@ -42,11 +45,19 @@ void mlir::torch::registerTorchConversionPasses() {
       "Pipeline lowering torch backend contract to linalg-on-tensors backend "
      "contract.",
       TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline);
 
   mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
       "torch-backend-to-tosa-backend-pipeline",
       "Pipeline lowering torch backend contract to TOSA backend "
       "contract.",
       TorchConversion::createTorchBackendToTosaBackendPipeline);
+
+#ifdef TORCH_MLIR_ENABLE_MHLO
+  mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
+      "torch-backend-to-mhlo-backend-pipeline",
+      "Pipeline lowering torch backend contract to MHLO backend "
+      "contract.",
+      TorchConversion::createTorchBackendToMhloBackendPipeline);
+#endif
 }
 
 void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(

@@ -118,3 +129,26 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
   // correct form.
   pm.addPass(TorchConversion::createVerifyTosaBackendContractPass());
 }
+
+#ifdef TORCH_MLIR_ENABLE_MHLO
+void TorchConversion::createTorchBackendToMhloBackendPipeline(
+    OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
+  // Check some invariants to catch errors in a clear way.
+  pm.addPass(
+      TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass());
+
+  pm.addNestedPass<func::FuncOp>(createConvertTorchToMhloPass());
+
+  if (options.optimize) {
+    // Clean up any non-canonical code introduced above..
+    pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
+    // The resolution of `dim` ops tends to create identical ops. CSE them.
+    pm.addNestedPass<func::FuncOp>(createCSEPass());
+  }
+  // Finish the type conversion from `torch` types to the types of the
+  // MHLO backend contract.
+  pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
+  pm.addNestedPass<func::FuncOp>(
+      TorchConversion::createFinalizingBackendTypeConversionPass());
+}
+#endif
@@ -44,6 +44,10 @@ class OutputType(Enum):
     # for end-users, but can be convenient for development or reporting bugs.
     RAW = 3
 
+    # This output type consists of `mhlo` dialect ops. It can be thought of
+    # as taking the `TORCH` output type and lowering it to MHLO.
+    MHLO = 4
+
     @staticmethod
     def get(spec: Union[str, "OutputType"]) -> "OutputType":
         """Gets an OutputType from allowed way to specify one.

@@ -118,7 +122,8 @@ _example_arg = Union[TensorPlaceholder, torch.Tensor]
 def compile(model: torch.nn.Module,
             example_args: Union[_example_arg, Sequence[_example_arg]],
             output_type: Union[str, "OutputType"] = OutputType.TORCH,
-            use_tracing=False):
+            use_tracing: bool = False,
+            verbose: bool = False):
     """Convert a PyTorch model to MLIR.
 
     Args:

@@ -180,6 +185,11 @@ def compile(model: torch.nn.Module,
         "torchscript-module-to-torch-backend-pipeline",
         "Lowering TorchScript IR -> Torch Backend IR")
 
+    if verbose:
+        print("\n====================")
+        print("Torch Backend IR")
+        print(mb.module)
+
     if output_type == OutputType.TORCH:
         return mb.module

@@ -188,6 +198,10 @@
             mb.module,
             "torch-backend-to-tosa-backend-pipeline",
             "Lowering Torch Backend IR -> TOSA Backend IR")
+        if verbose:
+            print("\n====================")
+            print("TOSA Backend IR")
+            print(mb.module)
         return mb.module
 
     if output_type == OutputType.LINALG_ON_TENSORS:

@@ -195,6 +209,20 @@
             mb.module,
             "torch-backend-to-linalg-on-tensors-backend-pipeline",
             "Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR")
+        if verbose:
+            print("\n====================")
+            print("LINALG Backend IR")
+            print(mb.module)
         return mb.module
 
+    elif output_type == OutputType.MHLO:
+        run_pipeline_with_repro_report(
+            mb.module,
+            "torch-backend-to-mhlo-backend-pipeline",
+            "Lowering Torch Backend IR -> MHLO Backend IR")
+        if verbose:
+            print("\n====================")
+            print("MHLO Backend IR")
+            print(mb.module)
+        return mb.module
     raise Exception(f"Unknown OutputType: {output_type}")
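The Python frontend hunks above add an MHLO output type and a verbose flag to torch_mlir.compile. A minimal usage sketch follows; it is not part of this commit's diff, it assumes torch-mlir was configured with -DTORCH_MLIR_ENABLE_MHLO=ON, and the TanhModule class plus the input shape are made up for illustration (aten.tanh is the only op the new TorchToMhlo conversion handles at this point).

# Usage sketch, not from the diff. Assumes an MHLO-enabled build of torch-mlir.
import torch
import torch_mlir


class TanhModule(torch.nn.Module):
    # Hypothetical example module; only ops with TorchToMhlo patterns lower,
    # and this commit covers aten.tanh alone.
    def forward(self, x):
        return torch.tanh(x)


# OutputType.MHLO is the enum value added above; verbose=True prints the
# intermediate backend IR as it is lowered.
compiled = torch_mlir.compile(TanhModule(),
                              torch.ones(2, 3),
                              output_type=torch_mlir.OutputType.MHLO,
                              verbose=True)
print(compiled)  # mhlo-dialect module from torch-backend-to-mhlo-backend-pipeline

With verbose=True, the intermediate Torch backend IR and the final MHLO backend IR are printed, mirroring the print blocks added in the hunks above.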
setup.py
@@ -75,6 +75,7 @@ class CMakeBuild(build_py):
 f"-DLLVM_TARGETS_TO_BUILD=host",
 f"-DMLIR_ENABLE_BINDINGS_PYTHON=ON",
 f"-DLLVM_ENABLE_PROJECTS=mlir",
+f"-DTORCH_MLIR_ENABLE_MHLO=ON",
 f"-DLLVM_EXTERNAL_PROJECTS=torch-mlir;torch-mlir-dialects",
 f"-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR={src_dir}",
 f"-DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR={src_dir}/externals/llvm-external-projects/torch-mlir-dialects",
@@ -1,6 +1,7 @@
 llvm_canonicalize_cmake_booleans(
   MLIR_ENABLE_BINDINGS_PYTHON
   TORCH_MLIR_ENABLE_JIT_IR_IMPORTER
+  TORCH_MLIR_ENABLE_MHLO
 )
 
 configure_lit_site_cfg(
@@ -0,0 +1,12 @@
+// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL:  func.func @torch.aten.tanh$basic(
+// CHECK-SAME:     %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK:          %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:          %[[VAL_2:.*]] = mhlo.tanh %[[VAL_1]] : tensor<?x?xf32>
+// CHECK:          %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:          return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
@@ -0,0 +1,2 @@
+if not config.enable_mhlo:
+    config.unsupported = True
@@ -15,6 +15,7 @@ config.llvm_exe_ext = "@EXEEXT@"
 config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@"
 config.python_executable = sys.executable
 config.enable_jit_ir_importer = @TORCH_MLIR_ENABLE_JIT_IR_IMPORTER@
+config.enable_mhlo = @TORCH_MLIR_ENABLE_MHLO@
 
 import lit.llvm
 lit.llvm.initialize(lit_config, config)