diff --git a/.github/workflows/buildAndTest.yml b/.github/workflows/buildAndTest.yml index 6440d370c..bff0da7bc 100644 --- a/.github/workflows/buildAndTest.yml +++ b/.github/workflows/buildAndTest.yml @@ -37,6 +37,7 @@ jobs: -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$GITHUB_WORKSPACE" \ -DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR="${GITHUB_WORKSPACE}/external/llvm-external-projects/torch-mlir-dialects" \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DTORCH_MLIR_ENABLE_MHLO=ON \ -DLLVM_TARGETS_TO_BUILD=host ninja check-torch-mlir-all - name: RefBackend - TorchScript end-to-end tests @@ -81,6 +82,7 @@ jobs: -DLLVM_ENABLE_PROJECTS=mlir \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ -DLLVM_TARGETS_TO_BUILD=host \ + -DTORCH_MLIR_ENABLE_MHLO=ON \ externals/llvm-project/llvm ninja -Cllvm-build @@ -94,6 +96,7 @@ jobs: -DMLIR_DIR="$(pwd)/llvm-build/lib/cmake/mlir/" \ -DLLVM_DIR="$(pwd)/llvm-build/lib/cmake/llvm/" \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DTORCH_MLIR_ENABLE_MHLO=ON \ -DPython3_EXECUTABLE=$(which python) \ . ninja -Cbuild check-torch-mlir-all diff --git a/.gitmodules b/.gitmodules index 62a290ea6..c62e461c9 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "external/llvm-project"] path = externals/llvm-project url = https://github.com/llvm/llvm-project.git +[submodule "externals/mlir-hlo"] + path = externals/mlir-hlo + url = https://github.com/tensorflow/mlir-hlo.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 4b0f48c87..10c35bcfc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -36,12 +36,18 @@ macro(torch_mlir_add_llvm_external_project name identifier location) set(LLVM_EXTERNAL_PROJECTS ${LLVM_EXTERNAL_PROJECTS} CACHE STRING "" FORCE) endmacro() +option(TORCH_MLIR_ENABLE_MHLO "Add mhlo dialect" ON) +if(TORCH_MLIR_ENABLE_MHLO) + add_definitions(-DTORCH_MLIR_ENABLE_MHLO) +endif() + torch_mlir_add_llvm_external_project( torch-mlir-dialects TORCH_MLIR_DIALECTS ${CMAKE_CURRENT_SOURCE_DIR}/externals/llvm-external-projects/torch-mlir-dialects) if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR) + message(STATUS "Torch-MLIR out-of-tree build.") # Out-of-tree build #------------------------------------------------------------------------------- @@ -82,10 +88,14 @@ if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR) set(BACKEND_PACKAGE_STRING "LLVM ${LLVM_PACKAGE_VERSION}") add_subdirectory(externals/llvm-external-projects/torch-mlir-dialects) else() + message(STATUS "Torch-MLIR in-tree build.") # In-tree build with LLVM_EXTERNAL_PROJECTS=torch-mlir # FIXME: This should really be inherited from the LLVM tree. In particular, # it's going to change when cross-compiling. set(MLIR_TABLEGEN_EXE mlir-tblgen) + if (TORCH_MLIR_ENABLE_MHLO) + set(MLIR_PDLL_TABLEGEN_EXE mlir-pdll) + endif() option(MLIR_ENABLE_BINDINGS_PYTHON "Enables MLIR Python Bindings" OFF) option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON) @@ -97,6 +107,15 @@ else() set(MLIR_INCLUDE_DIRS "${MLIR_INCLUDE_DIR};${MLIR_GENERATED_INCLUDE_DIR}") endif() +if (TORCH_MLIR_ENABLE_MHLO) + set(MHLO_BUILD_EMBEDDED ON) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo + ${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo + EXCLUDE_FROM_ALL) + include_directories(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo/include) + include_directories(${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo/include) +endif() + set(TORCH_MLIR_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}") set(TORCH_MLIR_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}") message(STATUS "Building torch-mlir project at ${TORCH_MLIR_SOURCE_DIR} (into ${TORCH_MLIR_BINARY_DIR})") diff --git a/externals/mlir-hlo b/externals/mlir-hlo new file mode 160000 index 000000000..eb1042390 --- /dev/null +++ b/externals/mlir-hlo @@ -0,0 +1 @@ +Subproject commit eb1042390d39131fe7e330b54dc5f29a79c9a072 diff --git a/include/torch-mlir/Conversion/CMakeLists.txt b/include/torch-mlir/Conversion/CMakeLists.txt index 607ed2f96..9ee80b304 100644 --- a/include/torch-mlir/Conversion/CMakeLists.txt +++ b/include/torch-mlir/Conversion/CMakeLists.txt @@ -1,5 +1,9 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls) +if(TORCH_MLIR_ENABLE_MHLO) + mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_MHLO) +else() + mlir_tablegen(Passes.h.inc -gen-pass-decls) +endif() add_public_tablegen_target(TorchMLIRConversionPassIncGen) add_mlir_doc(Passes TorchMLIRConversionPasses ./ -gen-pass-doc) diff --git a/include/torch-mlir/Conversion/Passes.td b/include/torch-mlir/Conversion/Passes.td index 02a376d19..7b6c9c6d1 100644 --- a/include/torch-mlir/Conversion/Passes.td +++ b/include/torch-mlir/Conversion/Passes.td @@ -125,4 +125,14 @@ def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> { let constructor = "mlir::torch::createConvertTorchToTMTensorPass()"; } +#ifdef TORCH_MLIR_ENABLE_MHLO +def ConvertTorchToMhlo : Pass<"convert-torch-to-mhlo", "func::FuncOp"> { + let summary = "Convert Torch ops to MHLO ops"; + let description = [{ + Convert Torch ops to mhlo ops. + }]; + let constructor = "mlir::torch::createConvertTorchToMhloPass()"; +} +#endif + #endif // TORCHMLIR_CONVERSION_PASSES diff --git a/include/torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h b/include/torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h new file mode 100644 index 000000000..ef1337770 --- /dev/null +++ b/include/torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h @@ -0,0 +1,23 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H +#define TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" +#include + +namespace mlir { +namespace torch { +std::unique_ptr> createConvertTorchToMhloPass(); +} // namespace torch +} // namespace mlir + +#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h index ce6ef9da1..b83857f9c 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h @@ -34,6 +34,13 @@ void createTorchBackendToTosaBackendPipeline( OpPassManager &pm, const torch::Torch::TorchLoweringPipelineOptions &options); +// Do not register the torch-to-mhlo pipeline if mhlo target is disabled +#ifdef TORCH_MLIR_ENABLE_MHLO +void createTorchBackendToMhloBackendPipeline( + OpPassManager &pm, + const torch::Torch::TorchLoweringPipelineOptions &options); +#endif + std::unique_ptr> createVerifyInvariantsBeforeBackendLoweringPass(); diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index 9cc4019bb..7ff59d7c3 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -2,11 +2,22 @@ add_subdirectory(TorchToLinalg) add_subdirectory(TorchToSCF) add_subdirectory(TorchToStd) add_subdirectory(TorchToTosa) +if(TORCH_MLIR_ENABLE_MHLO) + add_subdirectory(TorchToMhlo) +endif() add_subdirectory(TorchToTMTensor) add_subdirectory(Utils) # TODO: Automate this with add_torch_mlir_conversion_library. -#get_property(torch_mlir_conversion_libs GLOBAL PROPERTY TORCH_MLIR_CONVERSION_LIBS) +set(linked_libs TorchMLIRTorchToLinalg + TorchMLIRTorchToSCF + TorchMLIRTorchToStd + TorchMLIRTorchToTosa + TorchMLIRTorchToTMTensor + TorchMLIRConversionUtils) +if(TORCH_MLIR_ENABLE_MHLO) + list(APPEND linked_libs TorchMLIRTorchToMhlo) +endif() add_mlir_library(TorchMLIRConversionPasses Passes.cpp @@ -18,11 +29,6 @@ add_mlir_library(TorchMLIRConversionPasses Core LINK_LIBS PUBLIC - TorchMLIRTorchToLinalg - TorchMLIRTorchToSCF - TorchMLIRTorchToStd - TorchMLIRTorchToTosa - TorchMLIRTorchToTMTensor - TorchMLIRConversionUtils + ${linked_libs} #${torch_mlir_conversion_libs} ) diff --git a/lib/Conversion/Passes.cpp b/lib/Conversion/Passes.cpp index 710e91187..f21bb010b 100644 --- a/lib/Conversion/Passes.cpp +++ b/lib/Conversion/Passes.cpp @@ -13,6 +13,7 @@ #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" #include "torch-mlir/Conversion/TorchToStd/TorchToStd.h" #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" +#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" //===----------------------------------------------------------------------===// diff --git a/lib/Conversion/TorchToMhlo/BasicOp.cpp b/lib/Conversion/TorchToMhlo/BasicOp.cpp new file mode 100644 index 000000000..96839fbfb --- /dev/null +++ b/lib/Conversion/TorchToMhlo/BasicOp.cpp @@ -0,0 +1,71 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" + +#include "../PassDetail.h" +#include "./PopulatePatterns.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "torch-mlir/Conversion/Utils/Utils.h" +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" +#include +#include + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::Torch; + + +namespace { +template +class ConvertAtenOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; +} // namespace + +// AtenTanhOp +namespace { +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenTanhOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value self = adaptor.self(); + auto selfTy = self.getType().cast(); + if (selfTy && selfTy.getElementType().isa()) { + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), self); + return success(); + } else { + return op.emitError( + "Only floating-point datatype legalization currently supported"); + } +} +} // namespace + +void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality( + TypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target) { + MLIRContext *context = patterns.getContext(); + +#define INSERT_ATENOP_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context); + INSERT_ATENOP_PATTERN(AtenTanhOp); +#undef INSERT_ATENOP_PATTERN + +} diff --git a/lib/Conversion/TorchToMhlo/CMakeLists.txt b/lib/Conversion/TorchToMhlo/CMakeLists.txt new file mode 100644 index 000000000..d8ebeb765 --- /dev/null +++ b/lib/Conversion/TorchToMhlo/CMakeLists.txt @@ -0,0 +1,22 @@ +add_mlir_conversion_library(TorchMLIRTorchToMhlo + TorchToMhlo.cpp + BasicOp.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToMhlo + + DEPENDS + MhloDialect + TorchMLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MhloDialect + TorchMLIRTorchDialect +) + +torch_mlir_target_includes(TorchMLIRTorchToMhlo) diff --git a/lib/Conversion/TorchToMhlo/PopulatePatterns.h b/lib/Conversion/TorchToMhlo/PopulatePatterns.h new file mode 100644 index 000000000..72168af14 --- /dev/null +++ b/lib/Conversion/TorchToMhlo/PopulatePatterns.h @@ -0,0 +1,27 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H +#define TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H + +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace torch { +namespace torch_to_mhlo { + +void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter, + RewritePatternSet &patterns, + ConversionTarget &target); + +} // namespace torch_to_mhlo +} // namespace torch +} // namespace mlir + +#endif // TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H diff --git a/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp b/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp new file mode 100644 index 000000000..94a64d1bc --- /dev/null +++ b/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp @@ -0,0 +1,66 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" + +#include "../PassDetail.h" +#include "./PopulatePatterns.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Traits.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Transforms/DialectConversion.h" +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" +#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::Torch; + +namespace { + +class ConvertTorchToMhlo : public ConvertTorchToMhloBase { +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + registry.insert(); + TorchConversion::getBackendTypeConversionDependentDialects(registry); + } + void runOnOperation() override { + MLIRContext *context = &getContext(); + ConversionTarget target(*context); + target.addLegalDialect(); + + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + TorchConversion::setupBackendTypeConversion(target, typeConverter); + + RewritePatternSet patterns(context); + + torch_to_mhlo::populateBasicOpPatternsAndLegality(typeConverter, patterns, + target); + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr> +mlir::torch::createConvertTorchToMhloPass() { + return std::make_unique(); +} diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index 9550e2bba..2b439a2a5 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -20,6 +20,9 @@ #include "torch-mlir/Conversion/TorchToStd/TorchToStd.h" #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" +#ifdef TORCH_MLIR_ENABLE_MHLO +#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" +#endif #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" using namespace mlir; @@ -42,11 +45,19 @@ void mlir::torch::registerTorchConversionPasses() { "Pipeline lowering torch backend contract to linalg-on-tensors backend " "contract.", TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline); + mlir::PassPipelineRegistration( "torch-backend-to-tosa-backend-pipeline", "Pipeline lowering torch backend contract to TOSA backend " "contract.", TorchConversion::createTorchBackendToTosaBackendPipeline); +#ifdef TORCH_MLIR_ENABLE_MHLO + mlir::PassPipelineRegistration( + "torch-backend-to-mhlo-backend-pipeline", + "Pipeline lowering torch backend contract to MHLO backend " + "contract.", + TorchConversion::createTorchBackendToMhloBackendPipeline); +#endif } void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline( @@ -118,3 +129,26 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline( // correct form. pm.addPass(TorchConversion::createVerifyTosaBackendContractPass()); } + +#ifdef TORCH_MLIR_ENABLE_MHLO +void TorchConversion::createTorchBackendToMhloBackendPipeline( + OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) { + // Check some invariants to catch errors in a clear way. + pm.addPass( + TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass()); + + pm.addNestedPass(createConvertTorchToMhloPass()); + + if (options.optimize) { + // Clean up any non-canonical code introduced above.. + pm.addNestedPass(createCanonicalizerPass()); + // The resolution of `dim` ops tends to create identical ops. CSE them. + pm.addNestedPass(createCSEPass()); + } + // Finish the type conversion from `torch` types to the types of the + // MHLO backend contract. + pm.addPass(TorchConversion::createFuncBackendTypeConversionPass()); + pm.addNestedPass( + TorchConversion::createFinalizingBackendTypeConversionPass()); +} +#endif diff --git a/python/torch_mlir/__init__.py b/python/torch_mlir/__init__.py index 79b6c4512..828ab67d7 100644 --- a/python/torch_mlir/__init__.py +++ b/python/torch_mlir/__init__.py @@ -44,6 +44,10 @@ class OutputType(Enum): # for end-users, but can be convenient for development or reporting bugs. RAW = 3 + # This output type consists of `mhlo` dialect ops. It can be thought of + # as taking the `TORCH` output type and lowering it to MHLO. + MHLO = 4 + @staticmethod def get(spec: Union[str, "OutputType"]) -> "OutputType": """Gets an OutputType from allowed way to specify one. @@ -118,7 +122,8 @@ _example_arg = Union[TensorPlaceholder, torch.Tensor] def compile(model: torch.nn.Module, example_args: Union[_example_arg, Sequence[_example_arg]], output_type: Union[str, "OutputType"] = OutputType.TORCH, - use_tracing=False): + use_tracing: bool = False, + verbose: bool = False): """Convert a PyTorch model to MLIR. Args: @@ -180,6 +185,11 @@ def compile(model: torch.nn.Module, "torchscript-module-to-torch-backend-pipeline", "Lowering TorchScript IR -> Torch Backend IR") + if verbose: + print("\n====================") + print("Torch Backend IR") + print(mb.module) + if output_type == OutputType.TORCH: return mb.module @@ -188,6 +198,10 @@ def compile(model: torch.nn.Module, mb.module, "torch-backend-to-tosa-backend-pipeline", "Lowering Torch Backend IR -> TOSA Backend IR") + if verbose: + print("\n====================") + print("TOSA Backend IR") + print(mb.module) return mb.module if output_type == OutputType.LINALG_ON_TENSORS: @@ -195,6 +209,20 @@ def compile(model: torch.nn.Module, mb.module, "torch-backend-to-linalg-on-tensors-backend-pipeline", "Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR") + if verbose: + print("\n====================") + print("LINALG Backend IR") + print(mb.module) return mb.module + elif output_type == OutputType.MHLO: + run_pipeline_with_repro_report( + mb.module, + "torch-backend-to-mhlo-backend-pipeline", + "Lowering Torch Backend IR -> MHLO Backend IR") + if verbose: + print("\n====================") + print("MHLO Backend IR") + print(mb.module) + return mb.module raise Exception(f"Unknown OutputType: {output_type}") diff --git a/setup.py b/setup.py index 849cbfb61..85ce480fd 100644 --- a/setup.py +++ b/setup.py @@ -75,6 +75,7 @@ class CMakeBuild(build_py): f"-DLLVM_TARGETS_TO_BUILD=host", f"-DMLIR_ENABLE_BINDINGS_PYTHON=ON", f"-DLLVM_ENABLE_PROJECTS=mlir", + f"-DTORCH_MLIR_ENABLE_MHLO=ON", f"-DLLVM_EXTERNAL_PROJECTS=torch-mlir;torch-mlir-dialects", f"-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR={src_dir}", f"-DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR={src_dir}/externals/llvm-external-projects/torch-mlir-dialects", diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 628bee112..a5098d8aa 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,6 +1,7 @@ llvm_canonicalize_cmake_booleans( MLIR_ENABLE_BINDINGS_PYTHON TORCH_MLIR_ENABLE_JIT_IR_IMPORTER + TORCH_MLIR_ENABLE_MHLO ) configure_lit_site_cfg( diff --git a/test/Conversion/TorchToMhlo/basic.mlir b/test/Conversion/TorchToMhlo/basic.mlir new file mode 100644 index 000000000..c76d30f2e --- /dev/null +++ b/test/Conversion/TorchToMhlo/basic.mlir @@ -0,0 +1,12 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func @torch.aten.tanh$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = mhlo.tanh %[[VAL_1]] : tensor +// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %0 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +} diff --git a/test/Conversion/TorchToMhlo/lit.local.cfg b/test/Conversion/TorchToMhlo/lit.local.cfg new file mode 100644 index 000000000..829a5662f --- /dev/null +++ b/test/Conversion/TorchToMhlo/lit.local.cfg @@ -0,0 +1,2 @@ +if not config.enable_mhlo: + config.unsupported = True diff --git a/test/lit.site.cfg.py.in b/test/lit.site.cfg.py.in index 20f8b2729..94c93a05c 100644 --- a/test/lit.site.cfg.py.in +++ b/test/lit.site.cfg.py.in @@ -15,6 +15,7 @@ config.llvm_exe_ext = "@EXEEXT@" config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@" config.python_executable = sys.executable config.enable_jit_ir_importer = @TORCH_MLIR_ENABLE_JIT_IR_IMPORTER@ +config.enable_mhlo = @TORCH_MLIR_ENABLE_MHLO@ import lit.llvm lit.llvm.initialize(lit_config, config)