From dd790675716fe0655dfeaf21ee03c8a1f39fafd2 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Wed, 6 Sep 2023 04:00:28 -0700 Subject: [PATCH] Fall cleaning: Re-organize sources so that each target type is contained and optional. While working on https://github.com/openxla/iree/pull/14917, I noticed that it is somewhat hard to take a dependency on torch-mlir such that one only builds deps for the target(s) of interest (in this case Linalg). I noticed that some ifdef'ey optionality was added for stablehlo, but this was not mirrored for the others. Further, it does the switching very deep in the dependency graph vs having top-level directories and defines gating entire features. In addition, I noticed that a lot of things in the Linalg path were broken down to a fine level of detail but were not actually shared/shareable outside of that target. I opted to clump these together into TorchToLinalg. It is easy enough to "promote" them to common with this new organization if the need arises. General approach: * Isolate each conversion target in one of TorchToLinalg, TorchToStablehlo, TorchToTosa. * Gate each by top-level CMake flags and defines. * Common conversions go in a Common/ directory (currently Arith and SCF). * Pull target specific conversions out of TorchConversion/Transforms and put in their top-level directory. * General maintenance on the build graph and registration stuff that had bitrotted and was blocking progress. The main functional change for people taking a source dep is that `#include "torch-mlir/Conversion/Passes.h"` no longer is a one stop shop: For optional conversions, you have to include the dedicated `Passes.h` of each and take a library dep. See `InitAll.cpp` which does it right (and *is* a one stop shop still). --- CMakeLists.txt | 11 +- include/torch-mlir/Conversion/CMakeLists.txt | 19 ++- include/torch-mlir/Conversion/Passes.h | 15 +- include/torch-mlir/Conversion/Passes.td | 130 +---------------- .../TorchConversionToMLProgram.h | 23 --- .../Conversion/TorchToLinalg/CMakeLists.txt | 4 + .../Conversion/TorchToLinalg/Passes.h | 41 ++++++ .../Conversion/TorchToLinalg/Passes.td | 132 ++++++++++++++++++ .../Conversion/TorchToLinalg/TorchToLinalg.h | 24 ---- .../Conversion/TorchToSCF/TorchToSCF.h | 22 --- .../TorchToStablehlo/CMakeLists.txt | 5 + .../Conversion/TorchToStablehlo/Passes.h | 51 +++++++ .../Conversion/TorchToStablehlo/Passes.td | 43 ++++++ .../TorchToStablehlo/TorchToStablehlo.h | 26 ---- .../Conversion/TorchToTosa/CMakeLists.txt | 5 + .../Conversion/TorchToTosa/Passes.h | 35 +++++ .../Conversion/TorchToTosa/Passes.td | 33 +++++ .../Dialect/Torch/Transforms/Passes.h | 9 -- .../TorchConversion/Transforms/Passes.h | 35 ----- .../TorchConversion/Transforms/Passes.td | 17 --- lib/CAPI/CMakeLists.txt | 3 + lib/CAPI/Transforms.cpp | 5 + lib/CMakeLists.txt | 8 ++ lib/Conversion/CMakeLists.txt | 36 ++--- lib/Conversion/Common/CMakeLists.txt | 28 ++++ .../{TorchToArith => Common}/TorchToArith.cpp | 22 +-- .../{TorchToSCF => Common}/TorchToSCF.cpp | 2 +- lib/Conversion/Passes.cpp | 16 --- .../TorchConversionToMLProgram/CMakeLists.txt | 22 --- lib/Conversion/TorchToArith/CMakeLists.txt | 20 --- lib/Conversion/TorchToLinalg/CMakeLists.txt | 15 +- lib/Conversion/TorchToLinalg/DataMovement.cpp | 13 +- .../TorchToLinalg/IndirectDataMovement.cpp | 4 +- lib/Conversion/TorchToLinalg/Linear.cpp | 4 +- .../Conversion/TorchToLinalg/PassDetail.h | 19 +-- lib/Conversion/TorchToLinalg/Passes.cpp | 75 ++++++++++ lib/Conversion/TorchToLinalg/Pooling.cpp | 4 +- lib/Conversion/TorchToLinalg/Random.cpp | 4 +- lib/Conversion/TorchToLinalg/Reduction.cpp | 4 +- .../TorchToLinalg/TensorConstructors.cpp | 4 +- .../TorchToLinalg/TensorScalarInterop.cpp | 4 +- .../TorchConversionToMLProgram.cpp | 7 +- .../TorchToLinalg/TorchToLinalg.cpp | 4 +- .../TorchToTMTensor.cpp | 8 +- .../TorchToLinalg/Uncategorized.cpp | 4 +- lib/Conversion/TorchToLinalg/Utils.cpp | 4 +- .../VerifyLinalgOnTensorsBackendContract.cpp | 10 +- lib/Conversion/TorchToSCF/CMakeLists.txt | 22 --- lib/Conversion/TorchToStablehlo/Basic.cpp | 4 +- .../TorchToStablehlo/CMakeLists.txt | 9 +- .../TorchToStablehlo/GatherScatter.cpp | 4 +- lib/Conversion/TorchToStablehlo/Linear.cpp | 4 +- .../Conversion/TorchToStablehlo/PassDetail.h | 19 +-- lib/Conversion/TorchToStablehlo/Passes.cpp | 62 ++++++++ lib/Conversion/TorchToStablehlo/Pooling.cpp | 4 +- lib/Conversion/TorchToStablehlo/Reduction.cpp | 4 +- .../StablehloLegalizeUtils.cpp | 2 +- .../TorchToStablehlo/TorchToStablehlo.cpp | 4 +- .../VerifyStablehloBackendContract.cpp | 13 +- lib/Conversion/TorchToStablehlo/ViewLike.cpp | 4 +- lib/Conversion/TorchToTMTensor/CMakeLists.txt | 23 --- lib/Conversion/TorchToTosa/CMakeLists.txt | 6 +- .../Conversion/TorchToTosa/PassDetail.h | 18 ++- lib/Conversion/TorchToTosa/Passes.cpp | 61 ++++++++ lib/Conversion/TorchToTosa/TorchToTosa.cpp | 4 +- .../VerifyTosaBackendContract.cpp | 7 +- lib/Dialect/Torch/Transforms/Passes.cpp | 11 +- .../TorchConversion/Transforms/CMakeLists.txt | 13 -- .../TorchConversion/Transforms/Passes.cpp | 126 +---------------- lib/InitAll.cpp | 18 +++ 70 files changed, 800 insertions(+), 671 deletions(-) delete mode 100644 include/torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h create mode 100644 include/torch-mlir/Conversion/TorchToLinalg/CMakeLists.txt create mode 100644 include/torch-mlir/Conversion/TorchToLinalg/Passes.h create mode 100644 include/torch-mlir/Conversion/TorchToLinalg/Passes.td delete mode 100644 include/torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h delete mode 100644 include/torch-mlir/Conversion/TorchToSCF/TorchToSCF.h create mode 100644 include/torch-mlir/Conversion/TorchToStablehlo/CMakeLists.txt create mode 100644 include/torch-mlir/Conversion/TorchToStablehlo/Passes.h create mode 100644 include/torch-mlir/Conversion/TorchToStablehlo/Passes.td delete mode 100644 include/torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h create mode 100644 include/torch-mlir/Conversion/TorchToTosa/CMakeLists.txt create mode 100644 include/torch-mlir/Conversion/TorchToTosa/Passes.h create mode 100644 include/torch-mlir/Conversion/TorchToTosa/Passes.td create mode 100644 lib/Conversion/Common/CMakeLists.txt rename lib/Conversion/{TorchToArith => Common}/TorchToArith.cpp (96%) rename lib/Conversion/{TorchToSCF => Common}/TorchToSCF.cpp (99%) delete mode 100644 lib/Conversion/TorchConversionToMLProgram/CMakeLists.txt delete mode 100644 lib/Conversion/TorchToArith/CMakeLists.txt rename include/torch-mlir/Conversion/TorchToArith/TorchToArith.h => lib/Conversion/TorchToLinalg/PassDetail.h (55%) create mode 100644 lib/Conversion/TorchToLinalg/Passes.cpp rename lib/Conversion/{TorchConversionToMLProgram => TorchToLinalg}/TorchConversionToMLProgram.cpp (96%) rename lib/Conversion/{TorchToTMTensor => TorchToLinalg}/TorchToTMTensor.cpp (99%) rename lib/{Dialect/TorchConversion/Transforms => Conversion/TorchToLinalg}/VerifyLinalgOnTensorsBackendContract.cpp (95%) delete mode 100644 lib/Conversion/TorchToSCF/CMakeLists.txt rename include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h => lib/Conversion/TorchToStablehlo/PassDetail.h (54%) create mode 100644 lib/Conversion/TorchToStablehlo/Passes.cpp rename lib/{Dialect/TorchConversion/Transforms => Conversion/TorchToStablehlo}/VerifyStablehloBackendContract.cpp (86%) delete mode 100644 lib/Conversion/TorchToTMTensor/CMakeLists.txt rename include/torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h => lib/Conversion/TorchToTosa/PassDetail.h (56%) create mode 100644 lib/Conversion/TorchToTosa/Passes.cpp rename lib/{Dialect/TorchConversion/Transforms => Conversion/TorchToTosa}/VerifyTosaBackendContract.cpp (91%) diff --git a/CMakeLists.txt b/CMakeLists.txt index a3c636fc6..c01a7d43a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -36,10 +36,19 @@ macro(torch_mlir_add_llvm_external_project name identifier location) set(LLVM_EXTERNAL_PROJECTS ${LLVM_EXTERNAL_PROJECTS} CACHE STRING "" FORCE) endmacro() -option(TORCH_MLIR_ENABLE_STABLEHLO "Add stablehlo dialect" ON) +# Optional conversion targets. +if(TORCH_MLIR_ENABLE_LINALG) + add_definitions(-DTORCH_MLIR_ENABLE_LINALG) +endif() +option(TORCH_MLIR_ENABLE_STABLEHLO "Add stablehlo dialect conversions" ON) if(TORCH_MLIR_ENABLE_STABLEHLO) add_definitions(-DTORCH_MLIR_ENABLE_STABLEHLO) endif() +option(TORCH_MLIR_ENABLE_TOSA "Add tosa dialect conversions" ON) +if(TORCH_MLIR_ENABLE_TOSA) + add_definitions(-DTORCH_MLIR_ENABLE_TOSA) +endif() +option(TORCH_MLIR_ENABLE_LINALG "Add linalg dialect" ON) option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON) option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF) diff --git a/include/torch-mlir/Conversion/CMakeLists.txt b/include/torch-mlir/Conversion/CMakeLists.txt index d65523149..90fa85ba2 100644 --- a/include/torch-mlir/Conversion/CMakeLists.txt +++ b/include/torch-mlir/Conversion/CMakeLists.txt @@ -1,9 +1,14 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) -if(TORCH_MLIR_ENABLE_STABLEHLO) - mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO) -else() - mlir_tablegen(Passes.h.inc -gen-pass-decls) -endif() -add_public_tablegen_target(TorchMLIRConversionPassIncGen) +mlir_tablegen(Passes.h.inc -gen-pass-decls) +add_public_tablegen_target(TorchMLIRConversionCommonPassIncGen) +add_mlir_doc(Passes TorchMLIRConversionCommonPasses ./ -gen-pass-doc) -add_mlir_doc(Passes TorchMLIRConversionPasses ./ -gen-pass-doc) +if(TORCH_MLIR_ENABLE_LINALG) + add_subdirectory(TorchToLinalg) +endif() +if(TORCH_MLIR_ENABLE_TOSA) + add_subdirectory(TorchToTosa) +endif() +if(TORCH_MLIR_ENABLE_STABLEHLO) + add_subdirectory(TorchToStablehlo) +endif() \ No newline at end of file diff --git a/include/torch-mlir/Conversion/Passes.h b/include/torch-mlir/Conversion/Passes.h index 8ab6eb56b..df7515be2 100644 --- a/include/torch-mlir/Conversion/Passes.h +++ b/include/torch-mlir/Conversion/Passes.h @@ -7,16 +7,21 @@ // //===----------------------------------------------------------------------===// -#ifndef TORCHMLIR_CONVERSION_PASSES_H -#define TORCHMLIR_CONVERSION_PASSES_H +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include namespace mlir { namespace torch { -/// Registers all torch-mlir conversion passes. +std::unique_ptr> createConvertTorchToArithPass(); +std::unique_ptr> createConvertTorchToSCFPass(); + +// Note that this only registers common conversion passes. Backend +// specific passes with their own Passes.h in a subdirectory must be +// included/registered explicitly as they are all optional. void registerConversionPasses(); } // namespace torch } // namespace mlir - -#endif // TORCHMLIR_CONVERSION_PASSES_H diff --git a/include/torch-mlir/Conversion/Passes.td b/include/torch-mlir/Conversion/Passes.td index 3a130f472..1b7b7d281 100644 --- a/include/torch-mlir/Conversion/Passes.td +++ b/include/torch-mlir/Conversion/Passes.td @@ -13,7 +13,7 @@ include "mlir/Pass/PassBase.td" //===----------------------------------------------------------------------===// -// Torch conversions +// Common conversions //===----------------------------------------------------------------------===// def ConvertTorchToArith : Pass<"convert-torch-to-arith", "func::FuncOp"> { @@ -26,132 +26,4 @@ def ConvertTorchToSCF: Pass<"convert-torch-to-scf", "func::FuncOp"> { let constructor = "mlir::torch::createConvertTorchToSCFPass()"; } -def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "func::FuncOp"> { - let summary = "Convert recognized Torch ops to Linalg ops"; - let description = [{ - Convert ATen ops to linalg ops. - - This pass's main responsibility is to bridge the world between ops - that safely terminate the program in case of operand shape mismatches - (ATen) and ops where such mismatches are undefined behavior (linalg). - - To model the termination of the program for implementing error guards, - we use the `cf.assert` op. - This is a design decision that is at variance from other passes in the - ecosystem, which use the - `shape` dialect's witness system (`shape.cstr_*` family of ops feeding into - `shape.assuming` regions). This is a change in design decisions - from those passes (which the authors of this pass have contributed to). - The reasons for this change are heuristic, but boil down to: - 1. The modeling of `shape.assuming` is odd, as it uses a region, which is - not a good fit for modeling error guards. Regions mark a "start" and an - "end" (which is their nesting property). But - modeling assertions in the program doesn't fit into that. For assertions, - only the "start" matters (once tested, a predicate remains true "forever" - -- it doesn't end at the "yield" of the region). - Thus, having regions places arbitrary "end"s that just add IR structure - that has no semantic value for modeling this problem! (and to make things - worse the "end"s, which we don't need, are what require "yielding" - values, which interrupts use-def chains). Consider the different - structural properties of regions: - a. IsolatedFromAbove region: - - "start" interrupts use-def chains, - - "end" interrupts use-def chains - - structurally protects from intra-block upward and downward - code motion - b. Capturing region (like `shape.assuming`): - - "start" does not interrupt use-def chains, - - "end" interrupts use-def chains - - structurally protects from intra-block upward and downward - code motion - c. What we "ideally" want: - - "start" interrupts use-def chains (can be pruned though) - - no "end" IR structure! - - structurally protects from intra-block upward code motion - (but not downward code motion!) - - Observation: We probably can't get all of this, but overall this - problem is much better suited for a "MemorySSA"-like - abstraction, call it "EffectSSA" which is constructed on-demand - based on MLIR's effect modeling system (rather than - `shape.assuming`, which only covers the effects the IR creator - encoded -- with witnesses/`shape.assuming` -- it is easy to forget - to handle effects other than those encoded in the - witness structure). - 2. The presence of `shape.assuming` regions tends to create highly nested - IR structures, which don't interoperate well with any other IR - structures, and creates very bulky IR (and IR creation code). In general - if we are going to do anything with anything (e.g. canonicalize) we - end up needing need to either: - a. Flatten the `shape.assuming` IR (defeating the purpose of having - it). - b. Do some sort of shape.assuming "region merging". - c. Have special patterns that handle a subset of special cases (looking - through "yields" and such) and don't generalize. - 3. Witnesses tend to encourage non-scalable peephole transformations, which - tend to make analyses/transformations non-robust to the presence of - control flow and side effecting ops (easy to forget to handle side - effects other than those modeled by the witness system). - 4. All this code operates on ranked tensors, for which using individual - SSA values for sizes (rather than a "shape type") seems to - work really well at this level of abstraction based on prior experience - in other projects. (unranked code tends to benefit from having a discrete - "shape type" to model shapes). - - We will see if we end up needing something like `shape.assuming`, but for - now, it seems likely we can do something simpler and just bypass it. The - design of having an EffectSSA that is constructed on-demand seems very - compelling for modeling effects more broadly. - }]; - let constructor = "mlir::torch::createConvertTorchToLinalgPass()"; -} - -def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> { - let summary = "Convert Torch ops to TOSA ops"; - let description = [{ - This pass assumes that TOSA ops are responsible for emitting error - guards in case of shape mismatches. - }]; - let constructor = "mlir::torch::createConvertTorchToTosaPass()"; -} - -def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> { - let summary = "Convert recognized Torch ops to TMTensor/Linalg ops"; - let description = [{ - Convert ATen ops to tmtensor/linalg ops. - - This pass is similar to the TorchToLinalg pass; the difference is that this - pass also makes use of TMTensor Dialect, which the former one doesn't. - }]; - let constructor = "mlir::torch::createConvertTorchToTMTensorPass()"; -} - -def ConvertTorchConversionToMLProgram : Pass<"convert-torch-conversion-to-mlprogram", "ModuleOp"> { - let summary = "Convert recognized TorchConversion ops to MLProgram ops"; - let description = [{ - Convert TorchConversion ops to mlprogram ops. - }]; - let constructor = "mlir::torch::createConvertTorchConversionToMLProgramPass()"; -} - -#ifdef TORCH_MLIR_ENABLE_STABLEHLO -def ConvertTorchToStablehlo : Pass<"convert-torch-to-stablehlo", "func::FuncOp"> { - let summary = "Convert Torch ops to Stablehlo ops"; - let description = [{ - Convert Torch ops to Stablehlo ops. - }]; - let constructor = "mlir::torch::createConvertTorchToStablehloPass()"; - - // Specify any options. - let options = [ - Option<"enableStaticShape", "enable-static-shape", "bool", /*default=*/"false", - "Enable static shape conversion">, - // The i64 calculation is much slower than i32 on some devices, such as - // Nvidia GPU. One can truncate from i64 to i32 since dimension sizes - // are unlikely to exceed the range of i32(4GiB) - Option<"enableI32Index", "enable-i32-index", "bool", /*default=*/"false", - "Enable truncate index from i64 to i32(unsafely)">, - ]; -} -#endif - #endif // TORCHMLIR_CONVERSION_PASSES diff --git a/include/torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h b/include/torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h deleted file mode 100644 index 6d14ec927..000000000 --- a/include/torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h +++ /dev/null @@ -1,23 +0,0 @@ -//===------------------------------------------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Also available under a BSD-style license. See LICENSE. -// -//===----------------------------------------------------------------------===// - -#ifndef TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H -#define TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H - -#include "mlir/IR/BuiltinOps.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace torch { -std::unique_ptr> -createConvertTorchConversionToMLProgramPass(); -} -} // namespace mlir - -#endif // TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H diff --git a/include/torch-mlir/Conversion/TorchToLinalg/CMakeLists.txt b/include/torch-mlir/Conversion/TorchToLinalg/CMakeLists.txt new file mode 100644 index 000000000..1b78c36b1 --- /dev/null +++ b/include/torch-mlir/Conversion/TorchToLinalg/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls) +add_public_tablegen_target(TorchMLIRConversionLinalgPassIncGen) +add_mlir_doc(Passes TorchMLIRConversionLinalgPasses ./ -gen-pass-doc) diff --git a/include/torch-mlir/Conversion/TorchToLinalg/Passes.h b/include/torch-mlir/Conversion/TorchToLinalg/Passes.h new file mode 100644 index 000000000..adc5a31ce --- /dev/null +++ b/include/torch-mlir/Conversion/TorchToLinalg/Passes.h @@ -0,0 +1,41 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_LINALG_PASSES_H +#define TORCHMLIR_CONVERSION_LINALG_PASSES_H + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include + +namespace mlir { +namespace torch { + +/// Creates a pipeline that lowers from the torch backend contract to the +/// linalg-on-tensors backend contract. +void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm); + +std::unique_ptr> +createVerifyLinalgOnTensorsBackendContractPass(); + +std::unique_ptr> createConvertTorchToLinalgPass(); + +std::unique_ptr> +createConvertTorchConversionToMLProgramPass(); + +std::unique_ptr> createConvertTorchToTMTensorPass(); + +/// Registers all torch-mlir conversion passes. +void registerLinalgConversionPasses(); + +} // namespace torch +} // namespace mlir + +#endif // TORCHMLIR_CONVERSION_PASSES_H diff --git a/include/torch-mlir/Conversion/TorchToLinalg/Passes.td b/include/torch-mlir/Conversion/TorchToLinalg/Passes.td new file mode 100644 index 000000000..168b52613 --- /dev/null +++ b/include/torch-mlir/Conversion/TorchToLinalg/Passes.td @@ -0,0 +1,132 @@ +//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_LINALG_PASSES +#define TORCHMLIR_CONVERSION_LINALG_PASSES + +include "mlir/Pass/PassBase.td" + +//===----------------------------------------------------------------------===// +// Torch conversions +//===----------------------------------------------------------------------===// + +def ConvertTorchToArith : Pass<"convert-torch-to-arith", "func::FuncOp"> { + let summary = "Convert recognized Torch ops to Std ops"; + let constructor = "mlir::torch::createConvertTorchToArithPass()"; +} + +def ConvertTorchToSCF: Pass<"convert-torch-to-scf", "func::FuncOp"> { + let summary = "Convert recognized Torch ops to SCF ops"; + let constructor = "mlir::torch::createConvertTorchToSCFPass()"; +} + +def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "func::FuncOp"> { + let summary = "Convert recognized Torch ops to Linalg ops"; + let description = [{ + Convert ATen ops to linalg ops. + + This pass's main responsibility is to bridge the world between ops + that safely terminate the program in case of operand shape mismatches + (ATen) and ops where such mismatches are undefined behavior (linalg). + + To model the termination of the program for implementing error guards, + we use the `cf.assert` op. + This is a design decision that is at variance from other passes in the + ecosystem, which use the + `shape` dialect's witness system (`shape.cstr_*` family of ops feeding into + `shape.assuming` regions). This is a change in design decisions + from those passes (which the authors of this pass have contributed to). + The reasons for this change are heuristic, but boil down to: + 1. The modeling of `shape.assuming` is odd, as it uses a region, which is + not a good fit for modeling error guards. Regions mark a "start" and an + "end" (which is their nesting property). But + modeling assertions in the program doesn't fit into that. For assertions, + only the "start" matters (once tested, a predicate remains true "forever" + -- it doesn't end at the "yield" of the region). + Thus, having regions places arbitrary "end"s that just add IR structure + that has no semantic value for modeling this problem! (and to make things + worse the "end"s, which we don't need, are what require "yielding" + values, which interrupts use-def chains). Consider the different + structural properties of regions: + a. IsolatedFromAbove region: + - "start" interrupts use-def chains, + - "end" interrupts use-def chains + - structurally protects from intra-block upward and downward + code motion + b. Capturing region (like `shape.assuming`): + - "start" does not interrupt use-def chains, + - "end" interrupts use-def chains + - structurally protects from intra-block upward and downward + code motion + c. What we "ideally" want: + - "start" interrupts use-def chains (can be pruned though) + - no "end" IR structure! + - structurally protects from intra-block upward code motion + (but not downward code motion!) + - Observation: We probably can't get all of this, but overall this + problem is much better suited for a "MemorySSA"-like + abstraction, call it "EffectSSA" which is constructed on-demand + based on MLIR's effect modeling system (rather than + `shape.assuming`, which only covers the effects the IR creator + encoded -- with witnesses/`shape.assuming` -- it is easy to forget + to handle effects other than those encoded in the + witness structure). + 2. The presence of `shape.assuming` regions tends to create highly nested + IR structures, which don't interoperate well with any other IR + structures, and creates very bulky IR (and IR creation code). In general + if we are going to do anything with anything (e.g. canonicalize) we + end up needing need to either: + a. Flatten the `shape.assuming` IR (defeating the purpose of having + it). + b. Do some sort of shape.assuming "region merging". + c. Have special patterns that handle a subset of special cases (looking + through "yields" and such) and don't generalize. + 3. Witnesses tend to encourage non-scalable peephole transformations, which + tend to make analyses/transformations non-robust to the presence of + control flow and side effecting ops (easy to forget to handle side + effects other than those modeled by the witness system). + 4. All this code operates on ranked tensors, for which using individual + SSA values for sizes (rather than a "shape type") seems to + work really well at this level of abstraction based on prior experience + in other projects. (unranked code tends to benefit from having a discrete + "shape type" to model shapes). + + We will see if we end up needing something like `shape.assuming`, but for + now, it seems likely we can do something simpler and just bypass it. The + design of having an EffectSSA that is constructed on-demand seems very + compelling for modeling effects more broadly. + }]; + let constructor = "mlir::torch::createConvertTorchToLinalgPass()"; +} + +def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> { + let summary = "Convert recognized Torch ops to TMTensor/Linalg ops"; + let description = [{ + Convert ATen ops to tmtensor/linalg ops. + + This pass is similar to the TorchToLinalg pass; the difference is that this + pass also makes use of TMTensor Dialect, which the former one doesn't. + }]; + let constructor = "mlir::torch::createConvertTorchToTMTensorPass()"; +} + +def ConvertTorchConversionToMLProgram : Pass<"convert-torch-conversion-to-mlprogram", "ModuleOp"> { + let summary = "Convert recognized TorchConversion ops to MLProgram ops"; + let description = [{ + Convert TorchConversion ops to mlprogram ops. + }]; + let constructor = "mlir::torch::createConvertTorchConversionToMLProgramPass()"; +} + +def VerifyLinalgOnTensorsBackendContract : Pass<"torch-verify-linalg-on-tensors-backend-contract", "ModuleOp"> { + let summary = "Verifies conformity to the linalg-on-tensors backend contract"; + let constructor = "mlir::torch::createVerifyLinalgOnTensorsBackendContractPass()"; +} + +#endif // TORCHMLIR_CONVERSION_LINALG_PASSES diff --git a/include/torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h b/include/torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h deleted file mode 100644 index 53caf8598..000000000 --- a/include/torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h +++ /dev/null @@ -1,24 +0,0 @@ -//===------------------------------------------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Also available under a BSD-style license. See LICENSE. -// -//===----------------------------------------------------------------------===// - -#ifndef TORCHMLIR_CONVERSION_ATENTOLINALG_ATENTOLINALG_H -#define TORCHMLIR_CONVERSION_ATENTOLINALG_ATENTOLINALG_H - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Pass/Pass.h" -#include - -namespace mlir { -namespace torch { -std::unique_ptr> createConvertTorchToLinalgPass(); -} -} // namespace mlir - -#endif // TORCHMLIR_CONVERSION_ATENTOLINALG_ATENTOLINALG_H diff --git a/include/torch-mlir/Conversion/TorchToSCF/TorchToSCF.h b/include/torch-mlir/Conversion/TorchToSCF/TorchToSCF.h deleted file mode 100644 index 7b869dae4..000000000 --- a/include/torch-mlir/Conversion/TorchToSCF/TorchToSCF.h +++ /dev/null @@ -1,22 +0,0 @@ -//===------------------------------------------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Also available under a BSD-style license. See LICENSE. -// -//===----------------------------------------------------------------------===// - -#ifndef TORCHMLIR_CONVERSION_TORCHTOSCF_TORCHTOSCF_H -#define TORCHMLIR_CONVERSION_TORCHTOSCF_TORCHTOSCF_H - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace torch { -std::unique_ptr> createConvertTorchToSCFPass(); -} -} // namespace mlir - -#endif // TORCHMLIR_CONVERSION_TORCHTOSCF_TORCHTOSCF_H diff --git a/include/torch-mlir/Conversion/TorchToStablehlo/CMakeLists.txt b/include/torch-mlir/Conversion/TorchToStablehlo/CMakeLists.txt new file mode 100644 index 000000000..ccb92e011 --- /dev/null +++ b/include/torch-mlir/Conversion/TorchToStablehlo/CMakeLists.txt @@ -0,0 +1,5 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls) +add_public_tablegen_target(TorchMLIRConversionStablehloPassIncGen) + +add_mlir_doc(Passes TorchMLIRConversionStablehloPasses ./ -gen-pass-doc) diff --git a/include/torch-mlir/Conversion/TorchToStablehlo/Passes.h b/include/torch-mlir/Conversion/TorchToStablehlo/Passes.h new file mode 100644 index 000000000..cfcb8a1da --- /dev/null +++ b/include/torch-mlir/Conversion/TorchToStablehlo/Passes.h @@ -0,0 +1,51 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_STABLEHLO_PASSES_H +#define TORCHMLIR_CONVERSION_STABLEHLO_PASSES_H + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include + +namespace mlir { +namespace torch { + +struct StablehloBackendPipelineOptions + : public PassPipelineOptions { + Option enableStaticShape{ + *this, "enable-static-shape", + llvm::cl::desc("Enable static shape conversion."), llvm::cl::init(false)}; + // The i64 calculation is much slower than i32 on some devices, such as + // Nvidia GPU. One can truncate from i64 to i32 since dimension sizes + // are unlikely to exceed the range of i32(4GiB) + Option enableI32Index{ + *this, "enable-i32-index", + llvm::cl::desc("Enable truncate index from i64 to i32(unsafely)"), + llvm::cl::init(false)}; +}; + +void createTorchBackendToStablehloBackendPipeline( + OpPassManager &pm, const StablehloBackendPipelineOptions &options); +std::unique_ptr> +createVerifyStablehloBackendContractPass(); + +std::unique_ptr> +createConvertTorchToStablehloPass(); +std::unique_ptr> +createConvertTorchToStablehloPass(bool enableStaticShape, bool enableI32Index); + +/// Registers all torch-mlir conversion passes. +void registerStablehloConversionPasses(); + +} // namespace torch +} // namespace mlir + +#endif // TORCHMLIR_CONVERSION_STABLEHLO_PASSES_H diff --git a/include/torch-mlir/Conversion/TorchToStablehlo/Passes.td b/include/torch-mlir/Conversion/TorchToStablehlo/Passes.td new file mode 100644 index 000000000..d13ed0528 --- /dev/null +++ b/include/torch-mlir/Conversion/TorchToStablehlo/Passes.td @@ -0,0 +1,43 @@ +//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_STABLEHLO_PASSES +#define TORCHMLIR_CONVERSION_STABLEHLO_PASSES + +include "mlir/Pass/PassBase.td" + +//===----------------------------------------------------------------------===// +// Torch conversions +//===----------------------------------------------------------------------===// + +def ConvertTorchToStablehlo : Pass<"convert-torch-to-stablehlo", "func::FuncOp"> { + let summary = "Convert Torch ops to Stablehlo ops"; + let description = [{ + Convert Torch ops to Stablehlo ops. + }]; + let constructor = "mlir::torch::createConvertTorchToStablehloPass()"; + + // Specify any options. + let options = [ + Option<"enableStaticShape", "enable-static-shape", "bool", /*default=*/"false", + "Enable static shape conversion">, + // The i64 calculation is much slower than i32 on some devices, such as + // Nvidia GPU. One can truncate from i64 to i32 since dimension sizes + // are unlikely to exceed the range of i32(4GiB) + Option<"enableI32Index", "enable-i32-index", "bool", /*default=*/"false", + "Enable truncate index from i64 to i32(unsafely)">, + ]; +} + +def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contract", "ModuleOp"> { + let summary = "Verifies conformity to the stablehlo backend contract"; + let constructor = "mlir::torch::createVerifyStablehloBackendContractPass()"; +} + +#endif // TORCHMLIR_CONVERSION_STABLEHLO_PASSES diff --git a/include/torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h b/include/torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h deleted file mode 100644 index c19260159..000000000 --- a/include/torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h +++ /dev/null @@ -1,26 +0,0 @@ -//===------------------------------------------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Also available under a BSD-style license. See LICENSE. -// -//===----------------------------------------------------------------------===// - -#ifndef TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H -#define TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Pass/Pass.h" -#include - -namespace mlir { -namespace torch { -std::unique_ptr> -createConvertTorchToStablehloPass(); -std::unique_ptr> -createConvertTorchToStablehloPass(bool enableStaticShape, bool enableI32Index); -} // namespace torch -} // namespace mlir - -#endif // TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H diff --git a/include/torch-mlir/Conversion/TorchToTosa/CMakeLists.txt b/include/torch-mlir/Conversion/TorchToTosa/CMakeLists.txt new file mode 100644 index 000000000..d8036af77 --- /dev/null +++ b/include/torch-mlir/Conversion/TorchToTosa/CMakeLists.txt @@ -0,0 +1,5 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls) +add_public_tablegen_target(TorchMLIRConversionTosaPassIncGen) + +add_mlir_doc(Passes TorchMLIRConversionTosaPasses ./ -gen-pass-doc) diff --git a/include/torch-mlir/Conversion/TorchToTosa/Passes.h b/include/torch-mlir/Conversion/TorchToTosa/Passes.h new file mode 100644 index 000000000..c9b0af837 --- /dev/null +++ b/include/torch-mlir/Conversion/TorchToTosa/Passes.h @@ -0,0 +1,35 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_TOSA_PASSES_H +#define TORCHMLIR_CONVERSION_TOSA_PASSES_H + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include + +namespace mlir { +namespace torch { + +/// Creates a pipeline that lowers from the torch backend contract to the +/// TOSA backend contract. +void createTorchBackendToTosaBackendPipeline(OpPassManager &pm); + +std::unique_ptr> createVerifyTosaBackendContractPass(); + +std::unique_ptr> createConvertTorchToTosaPass(); + +/// Registers all torch-mlir conversion passes. +void registerTosaConversionPasses(); + +} // namespace torch +} // namespace mlir + +#endif // TORCHMLIR_CONVERSION_PASSES_H diff --git a/include/torch-mlir/Conversion/TorchToTosa/Passes.td b/include/torch-mlir/Conversion/TorchToTosa/Passes.td new file mode 100644 index 000000000..d2ab7d3e8 --- /dev/null +++ b/include/torch-mlir/Conversion/TorchToTosa/Passes.td @@ -0,0 +1,33 @@ +//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_TOSA_PASSES +#define TORCHMLIR_CONVERSION_TOSA_PASSES + +include "mlir/Pass/PassBase.td" + +//===----------------------------------------------------------------------===// +// Torch conversions +//===----------------------------------------------------------------------===// + +def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> { + let summary = "Convert Torch ops to TOSA ops"; + let description = [{ + This pass assumes that TOSA ops are responsible for emitting error + guards in case of shape mismatches. + }]; + let constructor = "mlir::torch::createConvertTorchToTosaPass()"; +} + +def VerifyTosaBackendContract : Pass<"torch-verify-tosa-backend-contract", "ModuleOp"> { + let summary = "Verifies conformity to the linalg-on-tensors backend contract"; + let constructor = "mlir::torch::createVerifyTosaBackendContractPass()"; +} + +#endif // TORCHMLIR_CONVERSION_TOSA_PASSES diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h index 84efddcc9..ddbb2d633 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h @@ -21,8 +21,6 @@ class ModuleOp; namespace torch { namespace Torch { -#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" - std::unique_ptr> createGlobalizeObjectGraphPass(); std::unique_ptr> @@ -141,13 +139,6 @@ static const char kTorchOpPrefix[] = R"(torch.)"; /// Registers all Torch transformation passes. void registerTorchPasses(); -//===----------------------------------------------------------------------===// -// Pass registration -//===----------------------------------------------------------------------===// - -#define GEN_PASS_REGISTRATION -#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" - } // namespace torch } // namespace mlir diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h index d762bd840..6a6d42225 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h @@ -22,36 +22,6 @@ class ModuleOp; namespace torch { namespace TorchConversion { -/// Creates a pipeline that lowers from the torch backend contract to the -/// linalg-on-tensors backend contract. -void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm); - -/// Creates a pipeline that lowers from the torch backend contract to the -/// TOSA backend contract. -void createTorchBackendToTosaBackendPipeline(OpPassManager &pm); - -// Do not register the stablehlo options if the stablehlo target is disabled -#ifdef TORCH_MLIR_ENABLE_STABLEHLO -struct StablehloBackendPipelineOptions - : public PassPipelineOptions { - Option enableStaticShape{ - *this, "enable-static-shape", - llvm::cl::desc("Enable static shape conversion."), llvm::cl::init(false)}; - // The i64 calculation is much slower than i32 on some devices, such as - // Nvidia GPU. One can truncate from i64 to i32 since dimension sizes - // are unlikely to exceed the range of i32(4GiB) - Option enableI32Index{ - *this, "enable-i32-index", - llvm::cl::desc("Enable truncate index from i64 to i32(unsafely)"), - llvm::cl::init(false)}; -}; - -void createTorchBackendToStablehloBackendPipeline( - OpPassManager &pm, const StablehloBackendPipelineOptions &options); -std::unique_ptr> -createVerifyStablehloBackendContractPass(); -#endif - std::unique_ptr> createFuncBackendTypeConversionPass(); std::unique_ptr> @@ -65,11 +35,6 @@ createFinalizingBackendTypeConversionPass(); std::unique_ptr> createUnpackQuantTensorPass(); std::unique_ptr> createConvertCustomQuantOpPass(); -std::unique_ptr> -createVerifyLinalgOnTensorsBackendContractPass(); - -std::unique_ptr> createVerifyTosaBackendContractPass(); - } // namespace TorchConversion /// Registers all Torch transformation passes. diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td index 4d3e16a81..f81a7defa 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td @@ -32,23 +32,6 @@ def FinalizingBackendTypeConversion }]; } -def VerifyLinalgOnTensorsBackendContract : Pass<"torch-verify-linalg-on-tensors-backend-contract", "ModuleOp"> { - let summary = "Verifies conformity to the linalg-on-tensors backend contract"; - let constructor = "mlir::torch::TorchConversion::createVerifyLinalgOnTensorsBackendContractPass()"; -} - -def VerifyTosaBackendContract : Pass<"torch-verify-tosa-backend-contract", "ModuleOp"> { - let summary = "Verifies conformity to the linalg-on-tensors backend contract"; - let constructor = "mlir::torch::TorchConversion::createVerifyTosaBackendContractPass()"; -} - -#ifdef TORCH_MLIR_ENABLE_STABLEHLO -def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contract", "ModuleOp"> { - let summary = "Verifies conformity to the stablehlo backend contract"; - let constructor = "mlir::torch::TorchConversion::createVerifyStablehloBackendContractPass()"; -} -#endif // TORCH_MLIR_ENABLE_STABLEHLO - // The following passes are for a one-off conversion of a specific kind of quantized group matmul. // They should not be included in default lowering flows until further along. def UnpackQuantTensor : Pass<"torch-unpack-quant-tensor", "func::FuncOp"> { diff --git a/lib/CAPI/CMakeLists.txt b/lib/CAPI/CMakeLists.txt index d71796ae8..8dea1689c 100644 --- a/lib/CAPI/CMakeLists.txt +++ b/lib/CAPI/CMakeLists.txt @@ -8,6 +8,9 @@ add_mlir_public_c_api_library(TorchMLIRCAPI ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/torch-mlir-c/ + DEPENDS + TorchMLIRTorchPassIncGen + ENABLE_AGGREGATION LINK_COMPONENTS Core diff --git a/lib/CAPI/Transforms.cpp b/lib/CAPI/Transforms.cpp index f0f57a72d..c29467aaa 100644 --- a/lib/CAPI/Transforms.cpp +++ b/lib/CAPI/Transforms.cpp @@ -9,6 +9,11 @@ #include "mlir/CAPI/Pass.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" +namespace { +#define GEN_PASS_REGISTRATION +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" +} // namespace + // Must include the declarations as they carry important visibility attributes. #include "torch-mlir/Dialect/Torch/Transforms/Transforms.capi.h.inc" diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index fb1da3fb6..24e3b2a5c 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -23,8 +23,16 @@ set(LinkedLibs TorchMLIRRefBackend ) +# Conditionally link in backends if enabled. +if(TORCH_MLIR_ENABLE_LINALG) + list(APPEND LinkedLibs TorchMLIRTorchToLinalg) +endif() +if(TORCH_MLIR_ENABLE_TOSA) + list(APPEND LinkedLibs TorchMLIRTorchToTosa) +endif() if(TORCH_MLIR_ENABLE_STABLEHLO) list(APPEND LinkedLibs + TorchMLIRTorchToStablehlo MhloPasses MhloToLinalg StablehloToMhlo diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index d72563b1e..8790915b1 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -1,36 +1,18 @@ -add_subdirectory(TorchToLinalg) -add_subdirectory(TorchToSCF) -add_subdirectory(TorchToArith) -add_subdirectory(TorchToTosa) +if(TORCH_MLIR_ENABLE_LINALG) + add_subdirectory(TorchToLinalg) +endif() +if(TORCH_MLIR_ENABLE_TOSA) + add_subdirectory(TorchToTosa) +endif() if(TORCH_MLIR_ENABLE_STABLEHLO) add_subdirectory(TorchToStablehlo) endif() -add_subdirectory(TorchToTMTensor) -add_subdirectory(TorchConversionToMLProgram) add_subdirectory(Utils) -# TODO: Automate this with add_torch_mlir_conversion_library. -set(linked_libs TorchMLIRTorchToLinalg - TorchMLIRTorchToSCF - TorchMLIRTorchToArith - TorchMLIRTorchToTosa - TorchMLIRTorchToTMTensor - TorchMLIRTorchConversionToMLProgram - TorchMLIRConversionUtils) -if(TORCH_MLIR_ENABLE_STABLEHLO) - list(APPEND linked_libs TorchMLIRTorchToStablehlo) -endif() - add_mlir_library(TorchMLIRConversionPasses Passes.cpp - - DEPENDS - TorchMLIRConversionPassIncGen - - LINK_COMPONENTS - Core - LINK_LIBS PUBLIC - ${linked_libs} - #${torch_mlir_conversion_libs} + TorchMLIRConversionUtils ) + +add_subdirectory(Common) \ No newline at end of file diff --git a/lib/Conversion/Common/CMakeLists.txt b/lib/Conversion/Common/CMakeLists.txt new file mode 100644 index 000000000..b14a16730 --- /dev/null +++ b/lib/Conversion/Common/CMakeLists.txt @@ -0,0 +1,28 @@ +add_mlir_conversion_library(TorchMLIRConversionCommon + TorchToArith.cpp + TorchToSCF.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion + + DEPENDS + TorchMLIRConversionCommonPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRIR + MLIRFuncDialect + MLIRPass + MLIRMathDialect + MLIRSCFDialect + MLIRTransforms + TorchMLIRTorchDialect + TorchMLIRTorchConversionDialect + TorchMLIRTorchConversionPasses + TorchMLIRTorchUtils +) + +torch_mlir_target_includes(TorchMLIRConversionCommon) diff --git a/lib/Conversion/TorchToArith/TorchToArith.cpp b/lib/Conversion/Common/TorchToArith.cpp similarity index 96% rename from lib/Conversion/TorchToArith/TorchToArith.cpp rename to lib/Conversion/Common/TorchToArith.cpp index 9e3cc2f75..fcfe11754 100644 --- a/lib/Conversion/TorchToArith/TorchToArith.cpp +++ b/lib/Conversion/Common/TorchToArith.cpp @@ -7,7 +7,7 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h" +#include "torch-mlir/Conversion/Passes.h" #include "../PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -43,7 +43,8 @@ public: LogicalResult matchAndRewrite(AtenDimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto rank = rewriter.create(op->getLoc(), adaptor.getSelf()); + auto rank = + rewriter.create(op->getLoc(), adaptor.getSelf()); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), rank); return success(); @@ -74,7 +75,8 @@ public: matchAndRewrite(AtenOp op, typename OpConversionPattern::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.template replaceOpWithNewOp(op, adaptor.getA(), adaptor.getB()); + rewriter.template replaceOpWithNewOp(op, adaptor.getA(), + adaptor.getB()); return success(); } }; @@ -112,10 +114,10 @@ public: typename OpConversionPattern::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Value a = - convertScalarToDtype(rewriter, loc, adaptor.getA(), rewriter.getF64Type()); - Value b = - convertScalarToDtype(rewriter, loc, adaptor.getB(), rewriter.getF64Type()); + Value a = convertScalarToDtype(rewriter, loc, adaptor.getA(), + rewriter.getF64Type()); + Value b = convertScalarToDtype(rewriter, loc, adaptor.getB(), + rewriter.getF64Type()); rewriter.replaceOpWithNewOp(op, a, b); return success(); } @@ -180,7 +182,8 @@ public: })); return success(); } - if (auto elements = op.getValueAttr().dyn_cast()) { + if (auto elements = + op.getValueAttr().dyn_cast()) { if (auto type = elements.getType().dyn_cast()) { if (auto intType = type.getElementType().dyn_cast()) { Type builtinTensorElemTy = @@ -357,7 +360,8 @@ public: // ----------------------------------------------------------------------------- namespace { -class ConvertTorchToArith : public ConvertTorchToArithBase { +class ConvertTorchToArith + : public ConvertTorchToArithBase { public: void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); diff --git a/lib/Conversion/TorchToSCF/TorchToSCF.cpp b/lib/Conversion/Common/TorchToSCF.cpp similarity index 99% rename from lib/Conversion/TorchToSCF/TorchToSCF.cpp rename to lib/Conversion/Common/TorchToSCF.cpp index 146959151..c2b3f085d 100644 --- a/lib/Conversion/TorchToSCF/TorchToSCF.cpp +++ b/lib/Conversion/Common/TorchToSCF.cpp @@ -7,7 +7,7 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" +#include "torch-mlir/Conversion/Passes.h" #include "../PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/Passes.cpp b/lib/Conversion/Passes.cpp index 45714601d..509ef0441 100644 --- a/lib/Conversion/Passes.cpp +++ b/lib/Conversion/Passes.cpp @@ -9,22 +9,6 @@ #include "torch-mlir/Conversion/Passes.h" -#ifdef TORCH_MLIR_ENABLE_STABLEHLO -#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" -#include "transforms/passes.h" -#endif // TORCH_MLIR_ENABLE_STABLEHLO - -#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" -#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" -#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h" -#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" -#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" -#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" - -//===----------------------------------------------------------------------===// -// Pass registration -//===----------------------------------------------------------------------===// - namespace { #define GEN_PASS_REGISTRATION #include "torch-mlir/Conversion/Passes.h.inc" diff --git a/lib/Conversion/TorchConversionToMLProgram/CMakeLists.txt b/lib/Conversion/TorchConversionToMLProgram/CMakeLists.txt deleted file mode 100644 index b89ffbb43..000000000 --- a/lib/Conversion/TorchConversionToMLProgram/CMakeLists.txt +++ /dev/null @@ -1,22 +0,0 @@ -add_mlir_conversion_library(TorchMLIRTorchConversionToMLProgram - TorchConversionToMLProgram.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchConversionToMLProgram - - DEPENDS - TorchMLIRConversionPassIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - MLIRIR - MLIRLinalgDialect - MLIRMLProgramDialect - MLIRMathDialect - MLIRPass - TorchMLIRTorchDialect -) - -torch_mlir_target_includes(TorchMLIRTorchConversionToMLProgram) diff --git a/lib/Conversion/TorchToArith/CMakeLists.txt b/lib/Conversion/TorchToArith/CMakeLists.txt deleted file mode 100644 index 4524c3b07..000000000 --- a/lib/Conversion/TorchToArith/CMakeLists.txt +++ /dev/null @@ -1,20 +0,0 @@ -add_mlir_conversion_library(TorchMLIRTorchToArith - TorchToArith.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToArith - - DEPENDS - TorchMLIRConversionPassIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - MLIRIR - MLIRPass - MLIRFuncDialect - TorchMLIRTorchDialect -) - -torch_mlir_target_includes(TorchMLIRTorchToArith) diff --git a/lib/Conversion/TorchToLinalg/CMakeLists.txt b/lib/Conversion/TorchToLinalg/CMakeLists.txt index ece929597..8913479f4 100644 --- a/lib/Conversion/TorchToLinalg/CMakeLists.txt +++ b/lib/Conversion/TorchToLinalg/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_conversion_library(TorchMLIRTorchToLinalg + Passes.cpp DataMovement.cpp IndirectDataMovement.cpp Linear.cpp @@ -7,25 +8,37 @@ add_mlir_conversion_library(TorchMLIRTorchToLinalg Reduction.cpp TensorConstructors.cpp TensorScalarInterop.cpp + TorchConversionToMLProgram.cpp TorchToLinalg.cpp + TorchToTMTensor.cpp Uncategorized.cpp Utils.cpp + VerifyLinalgOnTensorsBackendContract.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToLinalg DEPENDS - TorchMLIRConversionPassIncGen + TorchMLIRConversionLinalgPassIncGen LINK_COMPONENTS Core LINK_LIBS PUBLIC MLIRIR + MLIRFuncDialect MLIRPass MLIRLinalgDialect MLIRMathDialect + MLIRMLProgramDialect + MLIRSCFDialect + MLIRTransforms + TorchMLIRConversionCommon TorchMLIRTorchDialect + TorchMLIRTorchConversionDialect + TorchMLIRTorchConversionPasses + TorchMLIRTorchUtils + TorchMLIRTMTensorDialect ) torch_mlir_target_includes(TorchMLIRTorchToLinalg) diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 6cf6eb3be..1e4a2b2df 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -7,13 +7,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/TypeSupport.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" -#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" - -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -21,7 +15,12 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" +#include "mlir/IR/TypeSupport.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp index cfbac2632..42ac1a800 100644 --- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 65f08a4d7..3a5e8f78d 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/include/torch-mlir/Conversion/TorchToArith/TorchToArith.h b/lib/Conversion/TorchToLinalg/PassDetail.h similarity index 55% rename from include/torch-mlir/Conversion/TorchToArith/TorchToArith.h rename to lib/Conversion/TorchToLinalg/PassDetail.h index ab708557b..36a83aa50 100644 --- a/include/torch-mlir/Conversion/TorchToArith/TorchToArith.h +++ b/lib/Conversion/TorchToLinalg/PassDetail.h @@ -1,4 +1,4 @@ -//===------------------------------------------------------------*- C++ -*-===// +//===- PassDetail.h - Conversion Pass class details -------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -7,17 +7,20 @@ // //===----------------------------------------------------------------------===// -#ifndef TORCHMLIR_CONVERSION_ATENTOSTD_ATENTOSTD_H -#define TORCHMLIR_CONVERSION_ATENTOSTD_ATENTOSTD_H +#ifndef TORCHMLIR_CONVERSION_LINALG_PASSDETAIL_H +#define TORCHMLIR_CONVERSION_LINALG_PASSDETAIL_H #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" -#include namespace mlir { namespace torch { -std::unique_ptr> createConvertTorchToArithPass(); -} -} // namespace mlir -#endif // TORCHMLIR_CONVERSION_ATENTOSTD_ATENTOSTD_H +#define GEN_PASS_CLASSES +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h.inc" + +} // namespace torch +} // end namespace mlir + +#endif // TORCHMLIR_CONVERSION_LINALG_PASSDETAIL_H diff --git a/lib/Conversion/TorchToLinalg/Passes.cpp b/lib/Conversion/TorchToLinalg/Passes.cpp new file mode 100644 index 000000000..83bed2627 --- /dev/null +++ b/lib/Conversion/TorchToLinalg/Passes.cpp @@ -0,0 +1,75 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h" + +#include "mlir/Dialect/Func/Transforms/Passes.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" +#include "torch-mlir/Conversion/Passes.h" +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" +#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" + +using namespace mlir; +using namespace mlir::torch; + +//===----------------------------------------------------------------------===// +// Pass registration +//===----------------------------------------------------------------------===// + +namespace { +#define GEN_PASS_REGISTRATION +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h.inc" +} // end namespace + +void mlir::torch::registerLinalgConversionPasses() { + ::registerPasses(); + mlir::PassPipelineRegistration<>( + "torch-backend-to-linalg-on-tensors-backend-pipeline", + "Pipeline lowering torch backend contract to linalg-on-tensors backend " + "contract.", + createTorchBackendToLinalgOnTensorsBackendPipeline); +} + +void mlir::torch::createTorchBackendToLinalgOnTensorsBackendPipeline( + OpPassManager &pm) { + // Lower to linalg + guards which is the input to codegen backends. + // We do this first as it tends to involve pattern-matching against constants, + // (e.g. dimensions which must be constant in a ranked programming model) + // and those constants get somewhat obscured by TorchToArith. + pm.addNestedPass(createConvertTorchToTMTensorPass()); + pm.addNestedPass(createConvertTorchToLinalgPass()); + pm.addNestedPass(createConvertTorchToSCFPass()); + pm.addNestedPass(createConvertTorchToArithPass()); + pm.addPass(createConvertTorchConversionToMLProgramPass()); + pm.addNestedPass(memref::createExpandOpsPass()); + + // Clean up any non-canonical code introduced above.. + pm.addNestedPass(createCanonicalizerPass()); + // Resolve `dim` ops on tensors (which currently live in the `memref` + // dialect for some reason -- we don't have memrefs at this level). + pm.addNestedPass( + memref::createResolveShapedTypeResultDimsPass()); + // The resolution of `dim` ops tends to create identical ops. CSE them. + pm.addNestedPass(createCSEPass()); + + // Finish the type conversion from `torch` types to the types of the + // linalg-on-tensors backend contract. + pm.addPass(TorchConversion::createFuncBackendTypeConversionPass()); + pm.addNestedPass(createCanonicalizerPass()); + pm.addNestedPass( + TorchConversion::createFinalizingBackendTypeConversionPass()); + + // Verify that we have lowered to the form that linalg on tensors backends + // expect. This fails compilation (signalPassFailure) if the IR is not in the + // correct form. + pm.addPass(createVerifyLinalgOnTensorsBackendContractPass()); +} diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index 1d7ff925b..a49432b34 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/TorchToLinalg/Random.cpp b/lib/Conversion/TorchToLinalg/Random.cpp index e1a3e416c..fac632dfc 100644 --- a/lib/Conversion/TorchToLinalg/Random.cpp +++ b/lib/Conversion/TorchToLinalg/Random.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp index 4078fbaa3..3b5646169 100644 --- a/lib/Conversion/TorchToLinalg/Reduction.cpp +++ b/lib/Conversion/TorchToLinalg/Reduction.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp index 7e73fabd8..f4e7cf859 100644 --- a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp +++ b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp b/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp index 262d3cf62..a8a050eef 100644 --- a/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp +++ b/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" diff --git a/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp b/lib/Conversion/TorchToLinalg/TorchConversionToMLProgram.cpp similarity index 96% rename from lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp rename to lib/Conversion/TorchToLinalg/TorchConversionToMLProgram.cpp index eab81c2be..e0a4f2653 100644 --- a/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp +++ b/lib/Conversion/TorchToLinalg/TorchConversionToMLProgram.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -82,7 +82,8 @@ public: // temp = multiplier * currentSeed + incrementStep Value mul = rewriter.create(loc, currentSeed, multiplier); Value seed = rewriter.create(loc, mul, incrementStep); - globalVar = rewriter.create(loc, seed, globalVar, ValueRange()); + globalVar = + rewriter.create(loc, seed, globalVar, ValueRange()); rewriter.create( loc, SymbolRefAttr::get(op->getContext(), getSeedGobalVarName()), globalVar); diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp index 1f9f4b17b..e9c8ca021 100644 --- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp +++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToLinalg/TorchToTMTensor.cpp similarity index 99% rename from lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp rename to lib/Conversion/TorchToLinalg/TorchToTMTensor.cpp index a2d58daac..bec346870 100644 --- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp +++ b/lib/Conversion/TorchToLinalg/TorchToTMTensor.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -1273,13 +1273,13 @@ public: // Set the values in the input tensor to the smallest element of that // type TypedAttr minAttr = getNumericLimit(rewriter, srcType.getElementType(), - /*getMin=*/true); + /*getMin=*/true); normalizationValue = rewriter.create(loc, minAttr); } else if (reduceEnum == torch_upstream::ReductionType::MIN) { // Set the values in the input tensor to the largest element of that // type TypedAttr maxAttr = getNumericLimit(rewriter, srcType.getElementType(), - /*getMin=*/false); + /*getMin=*/false); normalizationValue = rewriter.create(loc, maxAttr); } diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 8684b68d9..8259d0416 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index 4a47790b0..1dee87725 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "Utils.h" +#include "./Utils.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp b/lib/Conversion/TorchToLinalg/VerifyLinalgOnTensorsBackendContract.cpp similarity index 95% rename from lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp rename to lib/Conversion/TorchToLinalg/VerifyLinalgOnTensorsBackendContract.cpp index 93d7de825..79e7cd64e 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp +++ b/lib/Conversion/TorchToLinalg/VerifyLinalgOnTensorsBackendContract.cpp @@ -7,7 +7,8 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" +#include "./PassDetail.h" +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -24,7 +25,6 @@ #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" -#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" #include "mlir/IR/BuiltinOps.h" @@ -33,7 +33,6 @@ using namespace mlir::torch; using namespace mlir::torch::TorchConversion; using namespace TMTensor; - namespace { class VerifyLinalgOnTensorsBackendContractPass : public VerifyLinalgOnTensorsBackendContractBase< @@ -96,7 +95,8 @@ class VerifyLinalgOnTensorsBackendContractPass // We avoid `module.emitError()` so that mlir-print-op-on-diagnostics // doesn't unnecessarily spew out the entire module. emitError(module.getLoc()) - << "Module does not conform to the linalg-on-tensors backend contract. " + << "Module does not conform to the linalg-on-tensors backend " + "contract. " "See dialect conversion legality information above."; return signalPassFailure(); } @@ -105,6 +105,6 @@ class VerifyLinalgOnTensorsBackendContractPass } // namespace std::unique_ptr> -mlir::torch::TorchConversion::createVerifyLinalgOnTensorsBackendContractPass() { +mlir::torch::createVerifyLinalgOnTensorsBackendContractPass() { return std::make_unique(); } diff --git a/lib/Conversion/TorchToSCF/CMakeLists.txt b/lib/Conversion/TorchToSCF/CMakeLists.txt deleted file mode 100644 index 06a888b3e..000000000 --- a/lib/Conversion/TorchToSCF/CMakeLists.txt +++ /dev/null @@ -1,22 +0,0 @@ -add_mlir_conversion_library(TorchMLIRTorchToSCF - TorchToSCF.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToSCF - - DEPENDS - TorchMLIRConversionPassIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - MLIRIR - MLIRPass - MLIRSCFDialect - MLIRFuncDialect - TorchMLIRTorchDialect - TorchMLIRTorchConversionDialect -) - -torch_mlir_target_includes(TorchMLIRTorchToSCF) diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 0a2ed02e0..1d31c3ab9 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" +#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/TorchToStablehlo/CMakeLists.txt b/lib/Conversion/TorchToStablehlo/CMakeLists.txt index 84a560cd7..bd5a9057e 100644 --- a/lib/Conversion/TorchToStablehlo/CMakeLists.txt +++ b/lib/Conversion/TorchToStablehlo/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_conversion_library(TorchMLIRTorchToStablehlo + Passes.cpp TorchToStablehlo.cpp StablehloLegalizeUtils.cpp Basic.cpp @@ -7,21 +8,25 @@ add_mlir_conversion_library(TorchMLIRTorchToStablehlo ViewLike.cpp Reduction.cpp Pooling.cpp + VerifyStablehloBackendContract.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToStablehlo DEPENDS - TorchMLIRConversionPassIncGen + TorchMLIRConversionStablehloPassIncGen LINK_COMPONENTS Core LINK_LIBS PUBLIC + MLIRBufferTransforms MLIRIR MLIRPass - MLIRBufferTransforms + MLIRTransforms StablehloOps + TorchMLIRConversionCommon + TorchMLIRTorchConversionPasses TorchMLIRTorchDialect TorchMLIRConversionUtils ) diff --git a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp index 9c8123bfd..3e098583b 100644 --- a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp +++ b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" +#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/TorchToStablehlo/Linear.cpp b/lib/Conversion/TorchToStablehlo/Linear.cpp index 71d679aea..a3e76a971 100644 --- a/lib/Conversion/TorchToStablehlo/Linear.cpp +++ b/lib/Conversion/TorchToStablehlo/Linear.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" +#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h b/lib/Conversion/TorchToStablehlo/PassDetail.h similarity index 54% rename from include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h rename to lib/Conversion/TorchToStablehlo/PassDetail.h index a6d774a64..0cc875206 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h +++ b/lib/Conversion/TorchToStablehlo/PassDetail.h @@ -1,4 +1,4 @@ -//===------------------------------------------------------------*- C++ -*-===// +//===- PassDetail.h - Conversion Pass class details -------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -7,17 +7,20 @@ // //===----------------------------------------------------------------------===// -#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H -#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H +#ifndef TORCHMLIR_CONVERSION_STABLEHLO_PASSDETAIL_H +#define TORCHMLIR_CONVERSION_STABLEHLO_PASSDETAIL_H #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" -#include namespace mlir { namespace torch { -std::unique_ptr> createConvertTorchToTosaPass(); -} -} // namespace mlir -#endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H +#define GEN_PASS_CLASSES +#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h.inc" + +} // namespace torch +} // end namespace mlir + +#endif // TORCHMLIR_CONVERSION_STABLEHLO_PASSDETAIL_H diff --git a/lib/Conversion/TorchToStablehlo/Passes.cpp b/lib/Conversion/TorchToStablehlo/Passes.cpp new file mode 100644 index 000000000..4fec1de8b --- /dev/null +++ b/lib/Conversion/TorchToStablehlo/Passes.cpp @@ -0,0 +1,62 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h" + +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" +#include "torch-mlir/Conversion/Passes.h" +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" +#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" +#include "transforms/passes.h" + +using namespace mlir; +using namespace mlir::torch; + +//===----------------------------------------------------------------------===// +// Pass registration +//===----------------------------------------------------------------------===// + +namespace { +#define GEN_PASS_REGISTRATION +#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h.inc" +} // end namespace + +void mlir::torch::registerStablehloConversionPasses() { + ::registerPasses(); + mlir::PassPipelineRegistration( + "torch-backend-to-stablehlo-backend-pipeline", + "Pipeline lowering torch backend contract to StableHLO backend " + "contract.", + createTorchBackendToStablehloBackendPipeline); +} + +void mlir::torch::createTorchBackendToStablehloBackendPipeline( + OpPassManager &pm, const StablehloBackendPipelineOptions &options) { + // Generate Stablehlo ops. + pm.addNestedPass(createConvertTorchToStablehloPass( + options.enableStaticShape, options.enableI32Index)); + // Lowering remained ops to Arith + pm.addNestedPass(createConvertTorchToArithPass()); + + // Clean up any non-canonical code introduced above.. + pm.addNestedPass(createCanonicalizerPass()); + // The resolution of `dim` ops tends to create identical ops. CSE them. + pm.addNestedPass(createCSEPass()); + + // Finish the type conversion from `torch` types to the types of the + // StableHLO backend contract. + pm.addPass(TorchConversion::createFuncBackendTypeConversionPass()); + pm.addNestedPass(createCanonicalizerPass()); + pm.addNestedPass( + TorchConversion::createFinalizingBackendTypeConversionPass()); + + // Verify that we have lowered to Stablehlo and Chlo ops. + pm.addPass(createVerifyStablehloBackendContractPass()); +} diff --git a/lib/Conversion/TorchToStablehlo/Pooling.cpp b/lib/Conversion/TorchToStablehlo/Pooling.cpp index 7c28a2fd3..8522a9032 100644 --- a/lib/Conversion/TorchToStablehlo/Pooling.cpp +++ b/lib/Conversion/TorchToStablehlo/Pooling.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" +#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp index 36f4d49e9..209fb944b 100644 --- a/lib/Conversion/TorchToStablehlo/Reduction.cpp +++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" +#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp index a25a66bbb..d814d8073 100644 --- a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp +++ b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp @@ -11,7 +11,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "stablehlo/dialect/StablehloOps.h" -#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" +#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include diff --git a/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp b/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp index 4bcc02344..95172af1f 100644 --- a/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp +++ b/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" +#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp b/lib/Conversion/TorchToStablehlo/VerifyStablehloBackendContract.cpp similarity index 86% rename from lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp rename to lib/Conversion/TorchToStablehlo/VerifyStablehloBackendContract.cpp index 888f29ade..d818558ae 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp +++ b/lib/Conversion/TorchToStablehlo/VerifyStablehloBackendContract.cpp @@ -6,8 +6,9 @@ // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// -#ifdef TORCH_MLIR_ENABLE_STABLEHLO -#include "PassDetail.h" + +#include "./PassDetail.h" +#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -18,11 +19,9 @@ #include "mlir/Transforms/DialectConversion.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" -#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" using namespace mlir; using namespace mlir::torch; -using namespace mlir::torch::TorchConversion; namespace { class VerifyStablehloBackendContractPass @@ -45,7 +44,8 @@ class VerifyStablehloBackendContractPass ConversionTarget target(*context); // Structural operations. - target.addDynamicallyLegalOp(opHasLegalTypes); + target.addDynamicallyLegalOp( + opHasLegalTypes); // Shape operations. target.addDynamicallyLegalOp(opHasLegalTypes); @@ -58,7 +58,6 @@ class VerifyStablehloBackendContractPass } // namespace std::unique_ptr> -mlir::torch::TorchConversion::createVerifyStablehloBackendContractPass() { +mlir::torch::createVerifyStablehloBackendContractPass() { return std::make_unique(); } -#endif // TORCH_MLIR_ENABLE_STABLEHLO diff --git a/lib/Conversion/TorchToStablehlo/ViewLike.cpp b/lib/Conversion/TorchToStablehlo/ViewLike.cpp index ea19092e6..763a0c712 100644 --- a/lib/Conversion/TorchToStablehlo/ViewLike.cpp +++ b/lib/Conversion/TorchToStablehlo/ViewLike.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" +#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/TorchToTMTensor/CMakeLists.txt b/lib/Conversion/TorchToTMTensor/CMakeLists.txt deleted file mode 100644 index d05d8277c..000000000 --- a/lib/Conversion/TorchToTMTensor/CMakeLists.txt +++ /dev/null @@ -1,23 +0,0 @@ -add_mlir_conversion_library(TorchMLIRTorchToTMTensor -TorchToTMTensor.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToTMTensor - - DEPENDS - TorchMLIRConversionPassIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - MLIRIR - MLIRPass - MLIRLinalgDialect - MLIRMathDialect - TorchMLIRTorchDialect - TorchMLIRTMTensorDialect - TorchMLIRTorchUtils -) - -torch_mlir_target_includes(TorchMLIRTorchToTMTensor) diff --git a/lib/Conversion/TorchToTosa/CMakeLists.txt b/lib/Conversion/TorchToTosa/CMakeLists.txt index 909ee3bcb..e6f9a0b15 100644 --- a/lib/Conversion/TorchToTosa/CMakeLists.txt +++ b/lib/Conversion/TorchToTosa/CMakeLists.txt @@ -1,13 +1,15 @@ add_mlir_conversion_library(TorchMLIRTorchToTosa + Passes.cpp TorchToTosa.cpp TosaLegalizeUtils.cpp TosaLegalizeCommon.cpp + VerifyTosaBackendContract.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToTosa DEPENDS - TorchMLIRConversionPassIncGen + TorchMLIRConversionTosaPassIncGen LINK_COMPONENTS Core @@ -16,6 +18,8 @@ add_mlir_conversion_library(TorchMLIRTorchToTosa MLIRIR MLIRPass MLIRTosaDialect + MLIRTransforms + TorchMLIRTorchConversionPasses TorchMLIRConversionUtils TorchMLIRTorchDialect ) diff --git a/include/torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h b/lib/Conversion/TorchToTosa/PassDetail.h similarity index 56% rename from include/torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h rename to lib/Conversion/TorchToTosa/PassDetail.h index 2b42c3291..9f56b159f 100644 --- a/include/torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h +++ b/lib/Conversion/TorchToTosa/PassDetail.h @@ -1,4 +1,4 @@ -//===------------------------------------------------------------*- C++ -*-===// +//===- PassDetail.h - Conversion Pass class details -------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -7,16 +7,20 @@ // //===----------------------------------------------------------------------===// -#ifndef TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H -#define TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H +#ifndef TORCHMLIR_CONVERSION_TOSA_PASSDETAIL_H +#define TORCHMLIR_CONVERSION_TOSA_PASSDETAIL_H #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" namespace mlir { namespace torch { -std::unique_ptr> createConvertTorchToTMTensorPass(); -} -} // namespace mlir -#endif // TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H +#define GEN_PASS_CLASSES +#include "torch-mlir/Conversion/TorchToTosa/Passes.h.inc" + +} // namespace torch +} // end namespace mlir + +#endif // TORCHMLIR_CONVERSION_TOSA_PASSDETAIL_H diff --git a/lib/Conversion/TorchToTosa/Passes.cpp b/lib/Conversion/TorchToTosa/Passes.cpp new file mode 100644 index 000000000..67b9e5347 --- /dev/null +++ b/lib/Conversion/TorchToTosa/Passes.cpp @@ -0,0 +1,61 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchToTosa/Passes.h" + +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" +#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::tosa; + +//===----------------------------------------------------------------------===// +// Pass registration +//===----------------------------------------------------------------------===// + +namespace { +#define GEN_PASS_REGISTRATION +#include "torch-mlir/Conversion/TorchToTosa/Passes.h.inc" +} // end namespace + +void mlir::torch::registerTosaConversionPasses() { + ::registerPasses(); + mlir::PassPipelineRegistration<>( + "torch-backend-to-tosa-backend-pipeline", + "Pipeline lowering torch backend contract to TOSA backend " + "contract.", + createTorchBackendToTosaBackendPipeline); +} + +void mlir::torch::createTorchBackendToTosaBackendPipeline(OpPassManager &pm) { + pm.addNestedPass(createConvertTorchToTosaPass()); + // Perform rank broadcasting so TosaToLinalg pass works + pm.addNestedPass(createTosaMakeBroadcastablePass()); + + // Clean up any non-canonical code introduced above.. + pm.addNestedPass(createCanonicalizerPass()); + // The resolution of `dim` ops tends to create identical ops. CSE them. + pm.addNestedPass(createCSEPass()); + + // Finish the type conversion from `torch` types to the types of the + // TOSA backend contract. + pm.addPass(TorchConversion::createFuncBackendTypeConversionPass()); + pm.addNestedPass(createCanonicalizerPass()); + pm.addNestedPass( + TorchConversion::createFinalizingBackendTypeConversionPass()); + + // Verify that we have lowered to the form that TOSA backends + // expect. This fails compilation (signalPassFailure) if the IR is not in the + // correct form. + pm.addPass(createVerifyTosaBackendContractPass()); +} diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index bf2f20d82..e47b04e37 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -7,12 +7,12 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" +#include "torch-mlir/Conversion/TorchToTosa/Passes.h" #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h" #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h" #include "torch-mlir/Conversion/Utils/Utils.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp b/lib/Conversion/TorchToTosa/VerifyTosaBackendContract.cpp similarity index 91% rename from lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp rename to lib/Conversion/TorchToTosa/VerifyTosaBackendContract.cpp index a29e14a3d..770fb1195 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp +++ b/lib/Conversion/TorchToTosa/VerifyTosaBackendContract.cpp @@ -7,7 +7,8 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" +#include "./PassDetail.h" +#include "torch-mlir/Conversion/TorchToTosa/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -16,11 +17,9 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Transforms/DialectConversion.h" -#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" using namespace mlir; using namespace mlir::torch; -using namespace mlir::torch::TorchConversion; namespace { class VerifyTosaBackendContractPass @@ -62,6 +61,6 @@ class VerifyTosaBackendContractPass } // namespace std::unique_ptr> -mlir::torch::TorchConversion::createVerifyTosaBackendContractPass() { +mlir::torch::createVerifyTosaBackendContractPass() { return std::make_unique(); } diff --git a/lib/Dialect/Torch/Transforms/Passes.cpp b/lib/Dialect/Torch/Transforms/Passes.cpp index 407e90247..b4222c871 100644 --- a/lib/Dialect/Torch/Transforms/Passes.cpp +++ b/lib/Dialect/Torch/Transforms/Passes.cpp @@ -11,8 +11,17 @@ #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" +//===----------------------------------------------------------------------===// +// Pass registration +//===----------------------------------------------------------------------===// + +namespace { +#define GEN_PASS_REGISTRATION +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" +} // namespace + void mlir::torch::registerTorchPasses() { - mlir::torch::registerPasses(); + ::registerPasses(); mlir::PassPipelineRegistration( "torchscript-module-to-torch-backend-pipeline", "Pipeline lowering TorchScript object graph IR to Torch backend form.", diff --git a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt index 6495e4682..22c3e7454 100644 --- a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt +++ b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt @@ -1,20 +1,10 @@ set(LinkedLibs MLIRFuncTransforms MLIRIR - MLIRLinalgTransforms - MLIRMemRefTransforms MLIRPass - MLIRTosaTransforms - MLIRVectorTransforms TorchMLIRTorchConversionDialect - TorchMLIRTorchConversionToMLProgram TorchMLIRTorchDialect TorchMLIRTorchPasses - TorchMLIRTorchToArith - TorchMLIRTorchToLinalg - TorchMLIRTorchToSCF - TorchMLIRTorchToTMTensor - TorchMLIRTorchToTosa ) if(TORCH_MLIR_ENABLE_STABLEHLO) @@ -27,9 +17,6 @@ add_mlir_library(TorchMLIRTorchConversionPasses Passes.cpp ConvertCustomQuantOp.cpp UnpackQuantTensor.cpp - VerifyLinalgOnTensorsBackendContract.cpp - VerifyTosaBackendContract.cpp - VerifyStablehloBackendContract.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/TorchConversion/Transforms diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index 09e99057e..e65770878 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -8,27 +8,10 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" -#include "mlir/Conversion/Passes.h" -#include "mlir/Dialect/Func/Transforms/Passes.h" -#include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/MemRef/Transforms/Passes.h" -#include "mlir/Dialect/Tosa/Transforms/Passes.h" #include "mlir/Pass/PassManager.h" -#include "mlir/Transforms/Passes.h" -#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" -#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" -#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h" -#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" -#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" -#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" -#ifdef TORCH_MLIR_ENABLE_STABLEHLO -#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" -#endif -#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" using namespace mlir; using namespace mlir::torch; -using namespace mlir::tosa; //===----------------------------------------------------------------------===// // Pass registration @@ -39,111 +22,4 @@ namespace reg { #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc" } // end namespace reg -void mlir::torch::registerTorchConversionPasses() { - reg::registerPasses(); - mlir::PassPipelineRegistration<>( - "torch-backend-to-linalg-on-tensors-backend-pipeline", - "Pipeline lowering torch backend contract to linalg-on-tensors backend " - "contract.", - TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline); - - mlir::PassPipelineRegistration<>( - "torch-backend-to-tosa-backend-pipeline", - "Pipeline lowering torch backend contract to TOSA backend " - "contract.", - TorchConversion::createTorchBackendToTosaBackendPipeline); -#ifdef TORCH_MLIR_ENABLE_STABLEHLO - mlir::PassPipelineRegistration< - TorchConversion::StablehloBackendPipelineOptions>( - "torch-backend-to-stablehlo-backend-pipeline", - "Pipeline lowering torch backend contract to StableHLO backend " - "contract.", - TorchConversion::createTorchBackendToStablehloBackendPipeline); -#endif -} - -void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline( - OpPassManager &pm) { - // Lower to linalg + guards which is the input to codegen backends. - // We do this first as it tends to involve pattern-matching against constants, - // (e.g. dimensions which must be constant in a ranked programming model) - // and those constants get somewhat obscured by TorchToArith. - pm.addNestedPass(createConvertTorchToTMTensorPass()); - pm.addNestedPass(createConvertTorchToLinalgPass()); - pm.addNestedPass(createConvertTorchToSCFPass()); - pm.addNestedPass(createConvertTorchToArithPass()); - pm.addPass(createConvertTorchConversionToMLProgramPass()); - pm.addNestedPass(memref::createExpandOpsPass()); - - // Clean up any non-canonical code introduced above.. - pm.addNestedPass(createCanonicalizerPass()); - // Resolve `dim` ops on tensors (which currently live in the `memref` - // dialect for some reason -- we don't have memrefs at this level). - pm.addNestedPass( - memref::createResolveShapedTypeResultDimsPass()); - // The resolution of `dim` ops tends to create identical ops. CSE them. - pm.addNestedPass(createCSEPass()); - - // Finish the type conversion from `torch` types to the types of the - // linalg-on-tensors backend contract. - pm.addPass(TorchConversion::createFuncBackendTypeConversionPass()); - pm.addNestedPass(createCanonicalizerPass()); - pm.addNestedPass( - TorchConversion::createFinalizingBackendTypeConversionPass()); - - // Verify that we have lowered to the form that linalg on tensors backends - // expect. This fails compilation (signalPassFailure) if the IR is not in the - // correct form. - pm.addPass(TorchConversion::createVerifyLinalgOnTensorsBackendContractPass()); -} - -void TorchConversion::createTorchBackendToTosaBackendPipeline( - OpPassManager &pm) { - pm.addNestedPass(createConvertTorchToTosaPass()); - // Perform rank broadcasting so TosaToLinalg pass works - pm.addNestedPass(createTosaMakeBroadcastablePass()); - - // Clean up any non-canonical code introduced above.. - pm.addNestedPass(createCanonicalizerPass()); - // The resolution of `dim` ops tends to create identical ops. CSE them. - pm.addNestedPass(createCSEPass()); - - // Finish the type conversion from `torch` types to the types of the - // TOSA backend contract. - pm.addPass(TorchConversion::createFuncBackendTypeConversionPass()); - pm.addNestedPass(createCanonicalizerPass()); - pm.addNestedPass( - TorchConversion::createFinalizingBackendTypeConversionPass()); - - // Verify that we have lowered to the form that TOSA backends - // expect. This fails compilation (signalPassFailure) if the IR is not in the - // correct form. - pm.addPass(TorchConversion::createVerifyTosaBackendContractPass()); -} - -#ifdef TORCH_MLIR_ENABLE_STABLEHLO -void TorchConversion::createTorchBackendToStablehloBackendPipeline( - OpPassManager &pm, - const TorchConversion::StablehloBackendPipelineOptions &options) { - // Generate Stablehlo ops. - pm.addNestedPass(createConvertTorchToStablehloPass( - options.enableStaticShape, options.enableI32Index)); - // Lowering remained ops to Arith - pm.addNestedPass(createConvertTorchToArithPass()); - - // Clean up any non-canonical code introduced above.. - pm.addNestedPass(createCanonicalizerPass()); - // The resolution of `dim` ops tends to create identical ops. CSE them. - pm.addNestedPass(createCSEPass()); - - // Finish the type conversion from `torch` types to the types of the - // StableHLO backend contract. - pm.addPass(TorchConversion::createFuncBackendTypeConversionPass()); - pm.addNestedPass(createCanonicalizerPass()); - pm.addNestedPass( - TorchConversion::createFinalizingBackendTypeConversionPass()); - - // Verify that we have lowered to Stablehlo and Chlo ops. - pm.addPass(TorchConversion::createVerifyStablehloBackendContractPass()); -} -#endif +void mlir::torch::registerTorchConversionPasses() { reg::registerPasses(); } diff --git a/lib/InitAll.cpp b/lib/InitAll.cpp index 88bceb013..9ea0f3f26 100644 --- a/lib/InitAll.cpp +++ b/lib/InitAll.cpp @@ -21,8 +21,17 @@ #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" #include "torch-mlir/RefBackend/Passes.h" +#ifdef TORCH_MLIR_ENABLE_LINALG +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h" +#endif + #ifdef TORCH_MLIR_ENABLE_STABLEHLO #include "mhlo/transforms/passes.h" +#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h" +#endif + +#ifdef TORCH_MLIR_ENABLE_TOSA +#include "torch-mlir/Conversion/TorchToTosa/Passes.h" #endif void mlir::torch::registerAllDialects(mlir::DialectRegistry ®istry) { @@ -41,11 +50,20 @@ void mlir::torch::registerAllPasses() { mlir::torch::RefBackend::registerRefBackendPasses(); mlir::torch::TMTensor::registerPasses(); +#ifdef TORCH_MLIR_ENABLE_LINALG + mlir::torch::registerLinalgConversionPasses(); +#endif // TORCH_MLIR_ENABLE_LINALG + #ifdef TORCH_MLIR_ENABLE_STABLEHLO mlir::mhlo::registerSymbolicShapeOptimizationPass(); mlir::mhlo::registerStablehloLegalizeToHloPass(); mlir::mhlo::registerChloLegalizeToHloPass(); mlir::mhlo::registerHloLegalizeToLinalgPass(); mlir::mhlo::registerTestUnfuseBatchNormPass(); + mlir::torch::registerStablehloConversionPasses(); #endif // TORCH_MLIR_ENABLE_STABLEHLO + +#ifdef TORCH_MLIR_ENABLE_TOSA + mlir::torch::registerTosaConversionPasses(); +#endif // TORCH_MLIR_ENABLE_TOSA }