diff --git a/CMakeLists.txt b/CMakeLists.txt index a3c636fc6..c01a7d43a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -36,10 +36,19 @@ macro(torch_mlir_add_llvm_external_project name identifier location) set(LLVM_EXTERNAL_PROJECTS ${LLVM_EXTERNAL_PROJECTS} CACHE STRING "" FORCE) endmacro() -option(TORCH_MLIR_ENABLE_STABLEHLO "Add stablehlo dialect" ON) +# Optional conversion targets. +if(TORCH_MLIR_ENABLE_LINALG) + add_definitions(-DTORCH_MLIR_ENABLE_LINALG) +endif() +option(TORCH_MLIR_ENABLE_STABLEHLO "Add stablehlo dialect conversions" ON) if(TORCH_MLIR_ENABLE_STABLEHLO) add_definitions(-DTORCH_MLIR_ENABLE_STABLEHLO) endif() +option(TORCH_MLIR_ENABLE_TOSA "Add tosa dialect conversions" ON) +if(TORCH_MLIR_ENABLE_TOSA) + add_definitions(-DTORCH_MLIR_ENABLE_TOSA) +endif() +option(TORCH_MLIR_ENABLE_LINALG "Add linalg dialect" ON) option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON) option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF) diff --git a/include/torch-mlir/Conversion/CMakeLists.txt b/include/torch-mlir/Conversion/CMakeLists.txt index d65523149..90fa85ba2 100644 --- a/include/torch-mlir/Conversion/CMakeLists.txt +++ b/include/torch-mlir/Conversion/CMakeLists.txt @@ -1,9 +1,14 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) -if(TORCH_MLIR_ENABLE_STABLEHLO) - mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO) -else() - mlir_tablegen(Passes.h.inc -gen-pass-decls) -endif() -add_public_tablegen_target(TorchMLIRConversionPassIncGen) +mlir_tablegen(Passes.h.inc -gen-pass-decls) +add_public_tablegen_target(TorchMLIRConversionCommonPassIncGen) +add_mlir_doc(Passes TorchMLIRConversionCommonPasses ./ -gen-pass-doc) -add_mlir_doc(Passes TorchMLIRConversionPasses ./ -gen-pass-doc) +if(TORCH_MLIR_ENABLE_LINALG) + add_subdirectory(TorchToLinalg) +endif() +if(TORCH_MLIR_ENABLE_TOSA) + add_subdirectory(TorchToTosa) +endif() +if(TORCH_MLIR_ENABLE_STABLEHLO) + add_subdirectory(TorchToStablehlo) +endif() \ No newline at end of file diff --git a/include/torch-mlir/Conversion/Passes.h b/include/torch-mlir/Conversion/Passes.h index 8ab6eb56b..df7515be2 100644 --- a/include/torch-mlir/Conversion/Passes.h +++ b/include/torch-mlir/Conversion/Passes.h @@ -7,16 +7,21 @@ // //===----------------------------------------------------------------------===// -#ifndef TORCHMLIR_CONVERSION_PASSES_H -#define TORCHMLIR_CONVERSION_PASSES_H +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include namespace mlir { namespace torch { -/// Registers all torch-mlir conversion passes. +std::unique_ptr> createConvertTorchToArithPass(); +std::unique_ptr> createConvertTorchToSCFPass(); + +// Note that this only registers common conversion passes. Backend +// specific passes with their own Passes.h in a subdirectory must be +// included/registered explicitly as they are all optional. void registerConversionPasses(); } // namespace torch } // namespace mlir - -#endif // TORCHMLIR_CONVERSION_PASSES_H diff --git a/include/torch-mlir/Conversion/Passes.td b/include/torch-mlir/Conversion/Passes.td index 3a130f472..1b7b7d281 100644 --- a/include/torch-mlir/Conversion/Passes.td +++ b/include/torch-mlir/Conversion/Passes.td @@ -13,7 +13,7 @@ include "mlir/Pass/PassBase.td" //===----------------------------------------------------------------------===// -// Torch conversions +// Common conversions //===----------------------------------------------------------------------===// def ConvertTorchToArith : Pass<"convert-torch-to-arith", "func::FuncOp"> { @@ -26,132 +26,4 @@ def ConvertTorchToSCF: Pass<"convert-torch-to-scf", "func::FuncOp"> { let constructor = "mlir::torch::createConvertTorchToSCFPass()"; } -def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "func::FuncOp"> { - let summary = "Convert recognized Torch ops to Linalg ops"; - let description = [{ - Convert ATen ops to linalg ops. - - This pass's main responsibility is to bridge the world between ops - that safely terminate the program in case of operand shape mismatches - (ATen) and ops where such mismatches are undefined behavior (linalg). - - To model the termination of the program for implementing error guards, - we use the `cf.assert` op. - This is a design decision that is at variance from other passes in the - ecosystem, which use the - `shape` dialect's witness system (`shape.cstr_*` family of ops feeding into - `shape.assuming` regions). This is a change in design decisions - from those passes (which the authors of this pass have contributed to). - The reasons for this change are heuristic, but boil down to: - 1. The modeling of `shape.assuming` is odd, as it uses a region, which is - not a good fit for modeling error guards. Regions mark a "start" and an - "end" (which is their nesting property). But - modeling assertions in the program doesn't fit into that. For assertions, - only the "start" matters (once tested, a predicate remains true "forever" - -- it doesn't end at the "yield" of the region). - Thus, having regions places arbitrary "end"s that just add IR structure - that has no semantic value for modeling this problem! (and to make things - worse the "end"s, which we don't need, are what require "yielding" - values, which interrupts use-def chains). Consider the different - structural properties of regions: - a. IsolatedFromAbove region: - - "start" interrupts use-def chains, - - "end" interrupts use-def chains - - structurally protects from intra-block upward and downward - code motion - b. Capturing region (like `shape.assuming`): - - "start" does not interrupt use-def chains, - - "end" interrupts use-def chains - - structurally protects from intra-block upward and downward - code motion - c. What we "ideally" want: - - "start" interrupts use-def chains (can be pruned though) - - no "end" IR structure! - - structurally protects from intra-block upward code motion - (but not downward code motion!) - - Observation: We probably can't get all of this, but overall this - problem is much better suited for a "MemorySSA"-like - abstraction, call it "EffectSSA" which is constructed on-demand - based on MLIR's effect modeling system (rather than - `shape.assuming`, which only covers the effects the IR creator - encoded -- with witnesses/`shape.assuming` -- it is easy to forget - to handle effects other than those encoded in the - witness structure). - 2. The presence of `shape.assuming` regions tends to create highly nested - IR structures, which don't interoperate well with any other IR - structures, and creates very bulky IR (and IR creation code). In general - if we are going to do anything with anything (e.g. canonicalize) we - end up needing need to either: - a. Flatten the `shape.assuming` IR (defeating the purpose of having - it). - b. Do some sort of shape.assuming "region merging". - c. Have special patterns that handle a subset of special cases (looking - through "yields" and such) and don't generalize. - 3. Witnesses tend to encourage non-scalable peephole transformations, which - tend to make analyses/transformations non-robust to the presence of - control flow and side effecting ops (easy to forget to handle side - effects other than those modeled by the witness system). - 4. All this code operates on ranked tensors, for which using individual - SSA values for sizes (rather than a "shape type") seems to - work really well at this level of abstraction based on prior experience - in other projects. (unranked code tends to benefit from having a discrete - "shape type" to model shapes). - - We will see if we end up needing something like `shape.assuming`, but for - now, it seems likely we can do something simpler and just bypass it. The - design of having an EffectSSA that is constructed on-demand seems very - compelling for modeling effects more broadly. - }]; - let constructor = "mlir::torch::createConvertTorchToLinalgPass()"; -} - -def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> { - let summary = "Convert Torch ops to TOSA ops"; - let description = [{ - This pass assumes that TOSA ops are responsible for emitting error - guards in case of shape mismatches. - }]; - let constructor = "mlir::torch::createConvertTorchToTosaPass()"; -} - -def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> { - let summary = "Convert recognized Torch ops to TMTensor/Linalg ops"; - let description = [{ - Convert ATen ops to tmtensor/linalg ops. - - This pass is similar to the TorchToLinalg pass; the difference is that this - pass also makes use of TMTensor Dialect, which the former one doesn't. - }]; - let constructor = "mlir::torch::createConvertTorchToTMTensorPass()"; -} - -def ConvertTorchConversionToMLProgram : Pass<"convert-torch-conversion-to-mlprogram", "ModuleOp"> { - let summary = "Convert recognized TorchConversion ops to MLProgram ops"; - let description = [{ - Convert TorchConversion ops to mlprogram ops. - }]; - let constructor = "mlir::torch::createConvertTorchConversionToMLProgramPass()"; -} - -#ifdef TORCH_MLIR_ENABLE_STABLEHLO -def ConvertTorchToStablehlo : Pass<"convert-torch-to-stablehlo", "func::FuncOp"> { - let summary = "Convert Torch ops to Stablehlo ops"; - let description = [{ - Convert Torch ops to Stablehlo ops. - }]; - let constructor = "mlir::torch::createConvertTorchToStablehloPass()"; - - // Specify any options. - let options = [ - Option<"enableStaticShape", "enable-static-shape", "bool", /*default=*/"false", - "Enable static shape conversion">, - // The i64 calculation is much slower than i32 on some devices, such as - // Nvidia GPU. One can truncate from i64 to i32 since dimension sizes - // are unlikely to exceed the range of i32(4GiB) - Option<"enableI32Index", "enable-i32-index", "bool", /*default=*/"false", - "Enable truncate index from i64 to i32(unsafely)">, - ]; -} -#endif - #endif // TORCHMLIR_CONVERSION_PASSES diff --git a/include/torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h b/include/torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h deleted file mode 100644 index 6d14ec927..000000000 --- a/include/torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h +++ /dev/null @@ -1,23 +0,0 @@ -//===------------------------------------------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Also available under a BSD-style license. See LICENSE. -// -//===----------------------------------------------------------------------===// - -#ifndef TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H -#define TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H - -#include "mlir/IR/BuiltinOps.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace torch { -std::unique_ptr> -createConvertTorchConversionToMLProgramPass(); -} -} // namespace mlir - -#endif // TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H diff --git a/include/torch-mlir/Conversion/TorchToLinalg/CMakeLists.txt b/include/torch-mlir/Conversion/TorchToLinalg/CMakeLists.txt new file mode 100644 index 000000000..1b78c36b1 --- /dev/null +++ b/include/torch-mlir/Conversion/TorchToLinalg/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls) +add_public_tablegen_target(TorchMLIRConversionLinalgPassIncGen) +add_mlir_doc(Passes TorchMLIRConversionLinalgPasses ./ -gen-pass-doc) diff --git a/include/torch-mlir/Conversion/TorchToLinalg/Passes.h b/include/torch-mlir/Conversion/TorchToLinalg/Passes.h new file mode 100644 index 000000000..adc5a31ce --- /dev/null +++ b/include/torch-mlir/Conversion/TorchToLinalg/Passes.h @@ -0,0 +1,41 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_LINALG_PASSES_H +#define TORCHMLIR_CONVERSION_LINALG_PASSES_H + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include + +namespace mlir { +namespace torch { + +/// Creates a pipeline that lowers from the torch backend contract to the +/// linalg-on-tensors backend contract. +void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm); + +std::unique_ptr> +createVerifyLinalgOnTensorsBackendContractPass(); + +std::unique_ptr> createConvertTorchToLinalgPass(); + +std::unique_ptr> +createConvertTorchConversionToMLProgramPass(); + +std::unique_ptr> createConvertTorchToTMTensorPass(); + +/// Registers all torch-mlir conversion passes. +void registerLinalgConversionPasses(); + +} // namespace torch +} // namespace mlir + +#endif // TORCHMLIR_CONVERSION_PASSES_H diff --git a/include/torch-mlir/Conversion/TorchToLinalg/Passes.td b/include/torch-mlir/Conversion/TorchToLinalg/Passes.td new file mode 100644 index 000000000..168b52613 --- /dev/null +++ b/include/torch-mlir/Conversion/TorchToLinalg/Passes.td @@ -0,0 +1,132 @@ +//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_LINALG_PASSES +#define TORCHMLIR_CONVERSION_LINALG_PASSES + +include "mlir/Pass/PassBase.td" + +//===----------------------------------------------------------------------===// +// Torch conversions +//===----------------------------------------------------------------------===// + +def ConvertTorchToArith : Pass<"convert-torch-to-arith", "func::FuncOp"> { + let summary = "Convert recognized Torch ops to Std ops"; + let constructor = "mlir::torch::createConvertTorchToArithPass()"; +} + +def ConvertTorchToSCF: Pass<"convert-torch-to-scf", "func::FuncOp"> { + let summary = "Convert recognized Torch ops to SCF ops"; + let constructor = "mlir::torch::createConvertTorchToSCFPass()"; +} + +def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "func::FuncOp"> { + let summary = "Convert recognized Torch ops to Linalg ops"; + let description = [{ + Convert ATen ops to linalg ops. + + This pass's main responsibility is to bridge the world between ops + that safely terminate the program in case of operand shape mismatches + (ATen) and ops where such mismatches are undefined behavior (linalg). + + To model the termination of the program for implementing error guards, + we use the `cf.assert` op. + This is a design decision that is at variance from other passes in the + ecosystem, which use the + `shape` dialect's witness system (`shape.cstr_*` family of ops feeding into + `shape.assuming` regions). This is a change in design decisions + from those passes (which the authors of this pass have contributed to). + The reasons for this change are heuristic, but boil down to: + 1. The modeling of `shape.assuming` is odd, as it uses a region, which is + not a good fit for modeling error guards. Regions mark a "start" and an + "end" (which is their nesting property). But + modeling assertions in the program doesn't fit into that. For assertions, + only the "start" matters (once tested, a predicate remains true "forever" + -- it doesn't end at the "yield" of the region). + Thus, having regions places arbitrary "end"s that just add IR structure + that has no semantic value for modeling this problem! (and to make things + worse the "end"s, which we don't need, are what require "yielding" + values, which interrupts use-def chains). Consider the different + structural properties of regions: + a. IsolatedFromAbove region: + - "start" interrupts use-def chains, + - "end" interrupts use-def chains + - structurally protects from intra-block upward and downward + code motion + b. Capturing region (like `shape.assuming`): + - "start" does not interrupt use-def chains, + - "end" interrupts use-def chains + - structurally protects from intra-block upward and downward + code motion + c. What we "ideally" want: + - "start" interrupts use-def chains (can be pruned though) + - no "end" IR structure! + - structurally protects from intra-block upward code motion + (but not downward code motion!) + - Observation: We probably can't get all of this, but overall this + problem is much better suited for a "MemorySSA"-like + abstraction, call it "EffectSSA" which is constructed on-demand + based on MLIR's effect modeling system (rather than + `shape.assuming`, which only covers the effects the IR creator + encoded -- with witnesses/`shape.assuming` -- it is easy to forget + to handle effects other than those encoded in the + witness structure). + 2. The presence of `shape.assuming` regions tends to create highly nested + IR structures, which don't interoperate well with any other IR + structures, and creates very bulky IR (and IR creation code). In general + if we are going to do anything with anything (e.g. canonicalize) we + end up needing need to either: + a. Flatten the `shape.assuming` IR (defeating the purpose of having + it). + b. Do some sort of shape.assuming "region merging". + c. Have special patterns that handle a subset of special cases (looking + through "yields" and such) and don't generalize. + 3. Witnesses tend to encourage non-scalable peephole transformations, which + tend to make analyses/transformations non-robust to the presence of + control flow and side effecting ops (easy to forget to handle side + effects other than those modeled by the witness system). + 4. All this code operates on ranked tensors, for which using individual + SSA values for sizes (rather than a "shape type") seems to + work really well at this level of abstraction based on prior experience + in other projects. (unranked code tends to benefit from having a discrete + "shape type" to model shapes). + + We will see if we end up needing something like `shape.assuming`, but for + now, it seems likely we can do something simpler and just bypass it. The + design of having an EffectSSA that is constructed on-demand seems very + compelling for modeling effects more broadly. + }]; + let constructor = "mlir::torch::createConvertTorchToLinalgPass()"; +} + +def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> { + let summary = "Convert recognized Torch ops to TMTensor/Linalg ops"; + let description = [{ + Convert ATen ops to tmtensor/linalg ops. + + This pass is similar to the TorchToLinalg pass; the difference is that this + pass also makes use of TMTensor Dialect, which the former one doesn't. + }]; + let constructor = "mlir::torch::createConvertTorchToTMTensorPass()"; +} + +def ConvertTorchConversionToMLProgram : Pass<"convert-torch-conversion-to-mlprogram", "ModuleOp"> { + let summary = "Convert recognized TorchConversion ops to MLProgram ops"; + let description = [{ + Convert TorchConversion ops to mlprogram ops. + }]; + let constructor = "mlir::torch::createConvertTorchConversionToMLProgramPass()"; +} + +def VerifyLinalgOnTensorsBackendContract : Pass<"torch-verify-linalg-on-tensors-backend-contract", "ModuleOp"> { + let summary = "Verifies conformity to the linalg-on-tensors backend contract"; + let constructor = "mlir::torch::createVerifyLinalgOnTensorsBackendContractPass()"; +} + +#endif // TORCHMLIR_CONVERSION_LINALG_PASSES diff --git a/include/torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h b/include/torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h deleted file mode 100644 index 53caf8598..000000000 --- a/include/torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h +++ /dev/null @@ -1,24 +0,0 @@ -//===------------------------------------------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Also available under a BSD-style license. See LICENSE. -// -//===----------------------------------------------------------------------===// - -#ifndef TORCHMLIR_CONVERSION_ATENTOLINALG_ATENTOLINALG_H -#define TORCHMLIR_CONVERSION_ATENTOLINALG_ATENTOLINALG_H - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Pass/Pass.h" -#include - -namespace mlir { -namespace torch { -std::unique_ptr> createConvertTorchToLinalgPass(); -} -} // namespace mlir - -#endif // TORCHMLIR_CONVERSION_ATENTOLINALG_ATENTOLINALG_H diff --git a/include/torch-mlir/Conversion/TorchToSCF/TorchToSCF.h b/include/torch-mlir/Conversion/TorchToSCF/TorchToSCF.h deleted file mode 100644 index 7b869dae4..000000000 --- a/include/torch-mlir/Conversion/TorchToSCF/TorchToSCF.h +++ /dev/null @@ -1,22 +0,0 @@ -//===------------------------------------------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Also available under a BSD-style license. See LICENSE. -// -//===----------------------------------------------------------------------===// - -#ifndef TORCHMLIR_CONVERSION_TORCHTOSCF_TORCHTOSCF_H -#define TORCHMLIR_CONVERSION_TORCHTOSCF_TORCHTOSCF_H - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace torch { -std::unique_ptr> createConvertTorchToSCFPass(); -} -} // namespace mlir - -#endif // TORCHMLIR_CONVERSION_TORCHTOSCF_TORCHTOSCF_H diff --git a/include/torch-mlir/Conversion/TorchToStablehlo/CMakeLists.txt b/include/torch-mlir/Conversion/TorchToStablehlo/CMakeLists.txt new file mode 100644 index 000000000..ccb92e011 --- /dev/null +++ b/include/torch-mlir/Conversion/TorchToStablehlo/CMakeLists.txt @@ -0,0 +1,5 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls) +add_public_tablegen_target(TorchMLIRConversionStablehloPassIncGen) + +add_mlir_doc(Passes TorchMLIRConversionStablehloPasses ./ -gen-pass-doc) diff --git a/include/torch-mlir/Conversion/TorchToStablehlo/Passes.h b/include/torch-mlir/Conversion/TorchToStablehlo/Passes.h new file mode 100644 index 000000000..cfcb8a1da --- /dev/null +++ b/include/torch-mlir/Conversion/TorchToStablehlo/Passes.h @@ -0,0 +1,51 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_STABLEHLO_PASSES_H +#define TORCHMLIR_CONVERSION_STABLEHLO_PASSES_H + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include + +namespace mlir { +namespace torch { + +struct StablehloBackendPipelineOptions + : public PassPipelineOptions { + Option enableStaticShape{ + *this, "enable-static-shape", + llvm::cl::desc("Enable static shape conversion."), llvm::cl::init(false)}; + // The i64 calculation is much slower than i32 on some devices, such as + // Nvidia GPU. One can truncate from i64 to i32 since dimension sizes + // are unlikely to exceed the range of i32(4GiB) + Option enableI32Index{ + *this, "enable-i32-index", + llvm::cl::desc("Enable truncate index from i64 to i32(unsafely)"), + llvm::cl::init(false)}; +}; + +void createTorchBackendToStablehloBackendPipeline( + OpPassManager &pm, const StablehloBackendPipelineOptions &options); +std::unique_ptr> +createVerifyStablehloBackendContractPass(); + +std::unique_ptr> +createConvertTorchToStablehloPass(); +std::unique_ptr> +createConvertTorchToStablehloPass(bool enableStaticShape, bool enableI32Index); + +/// Registers all torch-mlir conversion passes. +void registerStablehloConversionPasses(); + +} // namespace torch +} // namespace mlir + +#endif // TORCHMLIR_CONVERSION_STABLEHLO_PASSES_H diff --git a/include/torch-mlir/Conversion/TorchToStablehlo/Passes.td b/include/torch-mlir/Conversion/TorchToStablehlo/Passes.td new file mode 100644 index 000000000..d13ed0528 --- /dev/null +++ b/include/torch-mlir/Conversion/TorchToStablehlo/Passes.td @@ -0,0 +1,43 @@ +//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_STABLEHLO_PASSES +#define TORCHMLIR_CONVERSION_STABLEHLO_PASSES + +include "mlir/Pass/PassBase.td" + +//===----------------------------------------------------------------------===// +// Torch conversions +//===----------------------------------------------------------------------===// + +def ConvertTorchToStablehlo : Pass<"convert-torch-to-stablehlo", "func::FuncOp"> { + let summary = "Convert Torch ops to Stablehlo ops"; + let description = [{ + Convert Torch ops to Stablehlo ops. + }]; + let constructor = "mlir::torch::createConvertTorchToStablehloPass()"; + + // Specify any options. + let options = [ + Option<"enableStaticShape", "enable-static-shape", "bool", /*default=*/"false", + "Enable static shape conversion">, + // The i64 calculation is much slower than i32 on some devices, such as + // Nvidia GPU. One can truncate from i64 to i32 since dimension sizes + // are unlikely to exceed the range of i32(4GiB) + Option<"enableI32Index", "enable-i32-index", "bool", /*default=*/"false", + "Enable truncate index from i64 to i32(unsafely)">, + ]; +} + +def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contract", "ModuleOp"> { + let summary = "Verifies conformity to the stablehlo backend contract"; + let constructor = "mlir::torch::createVerifyStablehloBackendContractPass()"; +} + +#endif // TORCHMLIR_CONVERSION_STABLEHLO_PASSES diff --git a/include/torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h b/include/torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h deleted file mode 100644 index c19260159..000000000 --- a/include/torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h +++ /dev/null @@ -1,26 +0,0 @@ -//===------------------------------------------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Also available under a BSD-style license. See LICENSE. -// -//===----------------------------------------------------------------------===// - -#ifndef TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H -#define TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Pass/Pass.h" -#include - -namespace mlir { -namespace torch { -std::unique_ptr> -createConvertTorchToStablehloPass(); -std::unique_ptr> -createConvertTorchToStablehloPass(bool enableStaticShape, bool enableI32Index); -} // namespace torch -} // namespace mlir - -#endif // TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H diff --git a/include/torch-mlir/Conversion/TorchToTosa/CMakeLists.txt b/include/torch-mlir/Conversion/TorchToTosa/CMakeLists.txt new file mode 100644 index 000000000..d8036af77 --- /dev/null +++ b/include/torch-mlir/Conversion/TorchToTosa/CMakeLists.txt @@ -0,0 +1,5 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls) +add_public_tablegen_target(TorchMLIRConversionTosaPassIncGen) + +add_mlir_doc(Passes TorchMLIRConversionTosaPasses ./ -gen-pass-doc) diff --git a/include/torch-mlir/Conversion/TorchToTosa/Passes.h b/include/torch-mlir/Conversion/TorchToTosa/Passes.h new file mode 100644 index 000000000..c9b0af837 --- /dev/null +++ b/include/torch-mlir/Conversion/TorchToTosa/Passes.h @@ -0,0 +1,35 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_TOSA_PASSES_H +#define TORCHMLIR_CONVERSION_TOSA_PASSES_H + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include + +namespace mlir { +namespace torch { + +/// Creates a pipeline that lowers from the torch backend contract to the +/// TOSA backend contract. +void createTorchBackendToTosaBackendPipeline(OpPassManager &pm); + +std::unique_ptr> createVerifyTosaBackendContractPass(); + +std::unique_ptr> createConvertTorchToTosaPass(); + +/// Registers all torch-mlir conversion passes. +void registerTosaConversionPasses(); + +} // namespace torch +} // namespace mlir + +#endif // TORCHMLIR_CONVERSION_PASSES_H diff --git a/include/torch-mlir/Conversion/TorchToTosa/Passes.td b/include/torch-mlir/Conversion/TorchToTosa/Passes.td new file mode 100644 index 000000000..d2ab7d3e8 --- /dev/null +++ b/include/torch-mlir/Conversion/TorchToTosa/Passes.td @@ -0,0 +1,33 @@ +//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_TOSA_PASSES +#define TORCHMLIR_CONVERSION_TOSA_PASSES + +include "mlir/Pass/PassBase.td" + +//===----------------------------------------------------------------------===// +// Torch conversions +//===----------------------------------------------------------------------===// + +def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> { + let summary = "Convert Torch ops to TOSA ops"; + let description = [{ + This pass assumes that TOSA ops are responsible for emitting error + guards in case of shape mismatches. + }]; + let constructor = "mlir::torch::createConvertTorchToTosaPass()"; +} + +def VerifyTosaBackendContract : Pass<"torch-verify-tosa-backend-contract", "ModuleOp"> { + let summary = "Verifies conformity to the linalg-on-tensors backend contract"; + let constructor = "mlir::torch::createVerifyTosaBackendContractPass()"; +} + +#endif // TORCHMLIR_CONVERSION_TOSA_PASSES diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h index 84efddcc9..ddbb2d633 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h @@ -21,8 +21,6 @@ class ModuleOp; namespace torch { namespace Torch { -#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" - std::unique_ptr> createGlobalizeObjectGraphPass(); std::unique_ptr> @@ -141,13 +139,6 @@ static const char kTorchOpPrefix[] = R"(torch.)"; /// Registers all Torch transformation passes. void registerTorchPasses(); -//===----------------------------------------------------------------------===// -// Pass registration -//===----------------------------------------------------------------------===// - -#define GEN_PASS_REGISTRATION -#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" - } // namespace torch } // namespace mlir diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h index d762bd840..6a6d42225 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h @@ -22,36 +22,6 @@ class ModuleOp; namespace torch { namespace TorchConversion { -/// Creates a pipeline that lowers from the torch backend contract to the -/// linalg-on-tensors backend contract. -void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm); - -/// Creates a pipeline that lowers from the torch backend contract to the -/// TOSA backend contract. -void createTorchBackendToTosaBackendPipeline(OpPassManager &pm); - -// Do not register the stablehlo options if the stablehlo target is disabled -#ifdef TORCH_MLIR_ENABLE_STABLEHLO -struct StablehloBackendPipelineOptions - : public PassPipelineOptions { - Option enableStaticShape{ - *this, "enable-static-shape", - llvm::cl::desc("Enable static shape conversion."), llvm::cl::init(false)}; - // The i64 calculation is much slower than i32 on some devices, such as - // Nvidia GPU. One can truncate from i64 to i32 since dimension sizes - // are unlikely to exceed the range of i32(4GiB) - Option enableI32Index{ - *this, "enable-i32-index", - llvm::cl::desc("Enable truncate index from i64 to i32(unsafely)"), - llvm::cl::init(false)}; -}; - -void createTorchBackendToStablehloBackendPipeline( - OpPassManager &pm, const StablehloBackendPipelineOptions &options); -std::unique_ptr> -createVerifyStablehloBackendContractPass(); -#endif - std::unique_ptr> createFuncBackendTypeConversionPass(); std::unique_ptr> @@ -65,11 +35,6 @@ createFinalizingBackendTypeConversionPass(); std::unique_ptr> createUnpackQuantTensorPass(); std::unique_ptr> createConvertCustomQuantOpPass(); -std::unique_ptr> -createVerifyLinalgOnTensorsBackendContractPass(); - -std::unique_ptr> createVerifyTosaBackendContractPass(); - } // namespace TorchConversion /// Registers all Torch transformation passes. diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td index 4d3e16a81..f81a7defa 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td @@ -32,23 +32,6 @@ def FinalizingBackendTypeConversion }]; } -def VerifyLinalgOnTensorsBackendContract : Pass<"torch-verify-linalg-on-tensors-backend-contract", "ModuleOp"> { - let summary = "Verifies conformity to the linalg-on-tensors backend contract"; - let constructor = "mlir::torch::TorchConversion::createVerifyLinalgOnTensorsBackendContractPass()"; -} - -def VerifyTosaBackendContract : Pass<"torch-verify-tosa-backend-contract", "ModuleOp"> { - let summary = "Verifies conformity to the linalg-on-tensors backend contract"; - let constructor = "mlir::torch::TorchConversion::createVerifyTosaBackendContractPass()"; -} - -#ifdef TORCH_MLIR_ENABLE_STABLEHLO -def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contract", "ModuleOp"> { - let summary = "Verifies conformity to the stablehlo backend contract"; - let constructor = "mlir::torch::TorchConversion::createVerifyStablehloBackendContractPass()"; -} -#endif // TORCH_MLIR_ENABLE_STABLEHLO - // The following passes are for a one-off conversion of a specific kind of quantized group matmul. // They should not be included in default lowering flows until further along. def UnpackQuantTensor : Pass<"torch-unpack-quant-tensor", "func::FuncOp"> { diff --git a/lib/CAPI/CMakeLists.txt b/lib/CAPI/CMakeLists.txt index d71796ae8..8dea1689c 100644 --- a/lib/CAPI/CMakeLists.txt +++ b/lib/CAPI/CMakeLists.txt @@ -8,6 +8,9 @@ add_mlir_public_c_api_library(TorchMLIRCAPI ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/torch-mlir-c/ + DEPENDS + TorchMLIRTorchPassIncGen + ENABLE_AGGREGATION LINK_COMPONENTS Core diff --git a/lib/CAPI/Transforms.cpp b/lib/CAPI/Transforms.cpp index f0f57a72d..c29467aaa 100644 --- a/lib/CAPI/Transforms.cpp +++ b/lib/CAPI/Transforms.cpp @@ -9,6 +9,11 @@ #include "mlir/CAPI/Pass.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" +namespace { +#define GEN_PASS_REGISTRATION +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" +} // namespace + // Must include the declarations as they carry important visibility attributes. #include "torch-mlir/Dialect/Torch/Transforms/Transforms.capi.h.inc" diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index fb1da3fb6..24e3b2a5c 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -23,8 +23,16 @@ set(LinkedLibs TorchMLIRRefBackend ) +# Conditionally link in backends if enabled. +if(TORCH_MLIR_ENABLE_LINALG) + list(APPEND LinkedLibs TorchMLIRTorchToLinalg) +endif() +if(TORCH_MLIR_ENABLE_TOSA) + list(APPEND LinkedLibs TorchMLIRTorchToTosa) +endif() if(TORCH_MLIR_ENABLE_STABLEHLO) list(APPEND LinkedLibs + TorchMLIRTorchToStablehlo MhloPasses MhloToLinalg StablehloToMhlo diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index d72563b1e..8790915b1 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -1,36 +1,18 @@ -add_subdirectory(TorchToLinalg) -add_subdirectory(TorchToSCF) -add_subdirectory(TorchToArith) -add_subdirectory(TorchToTosa) +if(TORCH_MLIR_ENABLE_LINALG) + add_subdirectory(TorchToLinalg) +endif() +if(TORCH_MLIR_ENABLE_TOSA) + add_subdirectory(TorchToTosa) +endif() if(TORCH_MLIR_ENABLE_STABLEHLO) add_subdirectory(TorchToStablehlo) endif() -add_subdirectory(TorchToTMTensor) -add_subdirectory(TorchConversionToMLProgram) add_subdirectory(Utils) -# TODO: Automate this with add_torch_mlir_conversion_library. -set(linked_libs TorchMLIRTorchToLinalg - TorchMLIRTorchToSCF - TorchMLIRTorchToArith - TorchMLIRTorchToTosa - TorchMLIRTorchToTMTensor - TorchMLIRTorchConversionToMLProgram - TorchMLIRConversionUtils) -if(TORCH_MLIR_ENABLE_STABLEHLO) - list(APPEND linked_libs TorchMLIRTorchToStablehlo) -endif() - add_mlir_library(TorchMLIRConversionPasses Passes.cpp - - DEPENDS - TorchMLIRConversionPassIncGen - - LINK_COMPONENTS - Core - LINK_LIBS PUBLIC - ${linked_libs} - #${torch_mlir_conversion_libs} + TorchMLIRConversionUtils ) + +add_subdirectory(Common) \ No newline at end of file diff --git a/lib/Conversion/Common/CMakeLists.txt b/lib/Conversion/Common/CMakeLists.txt new file mode 100644 index 000000000..b14a16730 --- /dev/null +++ b/lib/Conversion/Common/CMakeLists.txt @@ -0,0 +1,28 @@ +add_mlir_conversion_library(TorchMLIRConversionCommon + TorchToArith.cpp + TorchToSCF.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion + + DEPENDS + TorchMLIRConversionCommonPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRIR + MLIRFuncDialect + MLIRPass + MLIRMathDialect + MLIRSCFDialect + MLIRTransforms + TorchMLIRTorchDialect + TorchMLIRTorchConversionDialect + TorchMLIRTorchConversionPasses + TorchMLIRTorchUtils +) + +torch_mlir_target_includes(TorchMLIRConversionCommon) diff --git a/lib/Conversion/TorchToArith/TorchToArith.cpp b/lib/Conversion/Common/TorchToArith.cpp similarity index 96% rename from lib/Conversion/TorchToArith/TorchToArith.cpp rename to lib/Conversion/Common/TorchToArith.cpp index 9e3cc2f75..fcfe11754 100644 --- a/lib/Conversion/TorchToArith/TorchToArith.cpp +++ b/lib/Conversion/Common/TorchToArith.cpp @@ -7,7 +7,7 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h" +#include "torch-mlir/Conversion/Passes.h" #include "../PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -43,7 +43,8 @@ public: LogicalResult matchAndRewrite(AtenDimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto rank = rewriter.create(op->getLoc(), adaptor.getSelf()); + auto rank = + rewriter.create(op->getLoc(), adaptor.getSelf()); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), rank); return success(); @@ -74,7 +75,8 @@ public: matchAndRewrite(AtenOp op, typename OpConversionPattern::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.template replaceOpWithNewOp(op, adaptor.getA(), adaptor.getB()); + rewriter.template replaceOpWithNewOp(op, adaptor.getA(), + adaptor.getB()); return success(); } }; @@ -112,10 +114,10 @@ public: typename OpConversionPattern::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Value a = - convertScalarToDtype(rewriter, loc, adaptor.getA(), rewriter.getF64Type()); - Value b = - convertScalarToDtype(rewriter, loc, adaptor.getB(), rewriter.getF64Type()); + Value a = convertScalarToDtype(rewriter, loc, adaptor.getA(), + rewriter.getF64Type()); + Value b = convertScalarToDtype(rewriter, loc, adaptor.getB(), + rewriter.getF64Type()); rewriter.replaceOpWithNewOp(op, a, b); return success(); } @@ -180,7 +182,8 @@ public: })); return success(); } - if (auto elements = op.getValueAttr().dyn_cast()) { + if (auto elements = + op.getValueAttr().dyn_cast()) { if (auto type = elements.getType().dyn_cast()) { if (auto intType = type.getElementType().dyn_cast()) { Type builtinTensorElemTy = @@ -357,7 +360,8 @@ public: // ----------------------------------------------------------------------------- namespace { -class ConvertTorchToArith : public ConvertTorchToArithBase { +class ConvertTorchToArith + : public ConvertTorchToArithBase { public: void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); diff --git a/lib/Conversion/TorchToSCF/TorchToSCF.cpp b/lib/Conversion/Common/TorchToSCF.cpp similarity index 99% rename from lib/Conversion/TorchToSCF/TorchToSCF.cpp rename to lib/Conversion/Common/TorchToSCF.cpp index 146959151..c2b3f085d 100644 --- a/lib/Conversion/TorchToSCF/TorchToSCF.cpp +++ b/lib/Conversion/Common/TorchToSCF.cpp @@ -7,7 +7,7 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" +#include "torch-mlir/Conversion/Passes.h" #include "../PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/Passes.cpp b/lib/Conversion/Passes.cpp index 45714601d..509ef0441 100644 --- a/lib/Conversion/Passes.cpp +++ b/lib/Conversion/Passes.cpp @@ -9,22 +9,6 @@ #include "torch-mlir/Conversion/Passes.h" -#ifdef TORCH_MLIR_ENABLE_STABLEHLO -#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" -#include "transforms/passes.h" -#endif // TORCH_MLIR_ENABLE_STABLEHLO - -#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" -#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" -#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h" -#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" -#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" -#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" - -//===----------------------------------------------------------------------===// -// Pass registration -//===----------------------------------------------------------------------===// - namespace { #define GEN_PASS_REGISTRATION #include "torch-mlir/Conversion/Passes.h.inc" diff --git a/lib/Conversion/TorchConversionToMLProgram/CMakeLists.txt b/lib/Conversion/TorchConversionToMLProgram/CMakeLists.txt deleted file mode 100644 index b89ffbb43..000000000 --- a/lib/Conversion/TorchConversionToMLProgram/CMakeLists.txt +++ /dev/null @@ -1,22 +0,0 @@ -add_mlir_conversion_library(TorchMLIRTorchConversionToMLProgram - TorchConversionToMLProgram.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchConversionToMLProgram - - DEPENDS - TorchMLIRConversionPassIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - MLIRIR - MLIRLinalgDialect - MLIRMLProgramDialect - MLIRMathDialect - MLIRPass - TorchMLIRTorchDialect -) - -torch_mlir_target_includes(TorchMLIRTorchConversionToMLProgram) diff --git a/lib/Conversion/TorchToArith/CMakeLists.txt b/lib/Conversion/TorchToArith/CMakeLists.txt deleted file mode 100644 index 4524c3b07..000000000 --- a/lib/Conversion/TorchToArith/CMakeLists.txt +++ /dev/null @@ -1,20 +0,0 @@ -add_mlir_conversion_library(TorchMLIRTorchToArith - TorchToArith.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToArith - - DEPENDS - TorchMLIRConversionPassIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - MLIRIR - MLIRPass - MLIRFuncDialect - TorchMLIRTorchDialect -) - -torch_mlir_target_includes(TorchMLIRTorchToArith) diff --git a/lib/Conversion/TorchToLinalg/CMakeLists.txt b/lib/Conversion/TorchToLinalg/CMakeLists.txt index ece929597..8913479f4 100644 --- a/lib/Conversion/TorchToLinalg/CMakeLists.txt +++ b/lib/Conversion/TorchToLinalg/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_conversion_library(TorchMLIRTorchToLinalg + Passes.cpp DataMovement.cpp IndirectDataMovement.cpp Linear.cpp @@ -7,25 +8,37 @@ add_mlir_conversion_library(TorchMLIRTorchToLinalg Reduction.cpp TensorConstructors.cpp TensorScalarInterop.cpp + TorchConversionToMLProgram.cpp TorchToLinalg.cpp + TorchToTMTensor.cpp Uncategorized.cpp Utils.cpp + VerifyLinalgOnTensorsBackendContract.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToLinalg DEPENDS - TorchMLIRConversionPassIncGen + TorchMLIRConversionLinalgPassIncGen LINK_COMPONENTS Core LINK_LIBS PUBLIC MLIRIR + MLIRFuncDialect MLIRPass MLIRLinalgDialect MLIRMathDialect + MLIRMLProgramDialect + MLIRSCFDialect + MLIRTransforms + TorchMLIRConversionCommon TorchMLIRTorchDialect + TorchMLIRTorchConversionDialect + TorchMLIRTorchConversionPasses + TorchMLIRTorchUtils + TorchMLIRTMTensorDialect ) torch_mlir_target_includes(TorchMLIRTorchToLinalg) diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 6cf6eb3be..1e4a2b2df 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -7,13 +7,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/TypeSupport.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" -#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" - -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -21,7 +15,12 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" +#include "mlir/IR/TypeSupport.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp index cfbac2632..42ac1a800 100644 --- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 65f08a4d7..3a5e8f78d 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/include/torch-mlir/Conversion/TorchToArith/TorchToArith.h b/lib/Conversion/TorchToLinalg/PassDetail.h similarity index 55% rename from include/torch-mlir/Conversion/TorchToArith/TorchToArith.h rename to lib/Conversion/TorchToLinalg/PassDetail.h index ab708557b..36a83aa50 100644 --- a/include/torch-mlir/Conversion/TorchToArith/TorchToArith.h +++ b/lib/Conversion/TorchToLinalg/PassDetail.h @@ -1,4 +1,4 @@ -//===------------------------------------------------------------*- C++ -*-===// +//===- PassDetail.h - Conversion Pass class details -------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -7,17 +7,20 @@ // //===----------------------------------------------------------------------===// -#ifndef TORCHMLIR_CONVERSION_ATENTOSTD_ATENTOSTD_H -#define TORCHMLIR_CONVERSION_ATENTOSTD_ATENTOSTD_H +#ifndef TORCHMLIR_CONVERSION_LINALG_PASSDETAIL_H +#define TORCHMLIR_CONVERSION_LINALG_PASSDETAIL_H #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" -#include namespace mlir { namespace torch { -std::unique_ptr> createConvertTorchToArithPass(); -} -} // namespace mlir -#endif // TORCHMLIR_CONVERSION_ATENTOSTD_ATENTOSTD_H +#define GEN_PASS_CLASSES +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h.inc" + +} // namespace torch +} // end namespace mlir + +#endif // TORCHMLIR_CONVERSION_LINALG_PASSDETAIL_H diff --git a/lib/Conversion/TorchToLinalg/Passes.cpp b/lib/Conversion/TorchToLinalg/Passes.cpp new file mode 100644 index 000000000..83bed2627 --- /dev/null +++ b/lib/Conversion/TorchToLinalg/Passes.cpp @@ -0,0 +1,75 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h" + +#include "mlir/Dialect/Func/Transforms/Passes.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" +#include "torch-mlir/Conversion/Passes.h" +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" +#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" + +using namespace mlir; +using namespace mlir::torch; + +//===----------------------------------------------------------------------===// +// Pass registration +//===----------------------------------------------------------------------===// + +namespace { +#define GEN_PASS_REGISTRATION +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h.inc" +} // end namespace + +void mlir::torch::registerLinalgConversionPasses() { + ::registerPasses(); + mlir::PassPipelineRegistration<>( + "torch-backend-to-linalg-on-tensors-backend-pipeline", + "Pipeline lowering torch backend contract to linalg-on-tensors backend " + "contract.", + createTorchBackendToLinalgOnTensorsBackendPipeline); +} + +void mlir::torch::createTorchBackendToLinalgOnTensorsBackendPipeline( + OpPassManager &pm) { + // Lower to linalg + guards which is the input to codegen backends. + // We do this first as it tends to involve pattern-matching against constants, + // (e.g. dimensions which must be constant in a ranked programming model) + // and those constants get somewhat obscured by TorchToArith. + pm.addNestedPass(createConvertTorchToTMTensorPass()); + pm.addNestedPass(createConvertTorchToLinalgPass()); + pm.addNestedPass(createConvertTorchToSCFPass()); + pm.addNestedPass(createConvertTorchToArithPass()); + pm.addPass(createConvertTorchConversionToMLProgramPass()); + pm.addNestedPass(memref::createExpandOpsPass()); + + // Clean up any non-canonical code introduced above.. + pm.addNestedPass(createCanonicalizerPass()); + // Resolve `dim` ops on tensors (which currently live in the `memref` + // dialect for some reason -- we don't have memrefs at this level). + pm.addNestedPass( + memref::createResolveShapedTypeResultDimsPass()); + // The resolution of `dim` ops tends to create identical ops. CSE them. + pm.addNestedPass(createCSEPass()); + + // Finish the type conversion from `torch` types to the types of the + // linalg-on-tensors backend contract. + pm.addPass(TorchConversion::createFuncBackendTypeConversionPass()); + pm.addNestedPass(createCanonicalizerPass()); + pm.addNestedPass( + TorchConversion::createFinalizingBackendTypeConversionPass()); + + // Verify that we have lowered to the form that linalg on tensors backends + // expect. This fails compilation (signalPassFailure) if the IR is not in the + // correct form. + pm.addPass(createVerifyLinalgOnTensorsBackendContractPass()); +} diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index 1d7ff925b..a49432b34 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/TorchToLinalg/Random.cpp b/lib/Conversion/TorchToLinalg/Random.cpp index e1a3e416c..fac632dfc 100644 --- a/lib/Conversion/TorchToLinalg/Random.cpp +++ b/lib/Conversion/TorchToLinalg/Random.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp index 4078fbaa3..3b5646169 100644 --- a/lib/Conversion/TorchToLinalg/Reduction.cpp +++ b/lib/Conversion/TorchToLinalg/Reduction.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp index 7e73fabd8..f4e7cf859 100644 --- a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp +++ b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp b/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp index 262d3cf62..a8a050eef 100644 --- a/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp +++ b/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" diff --git a/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp b/lib/Conversion/TorchToLinalg/TorchConversionToMLProgram.cpp similarity index 96% rename from lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp rename to lib/Conversion/TorchToLinalg/TorchConversionToMLProgram.cpp index eab81c2be..e0a4f2653 100644 --- a/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp +++ b/lib/Conversion/TorchToLinalg/TorchConversionToMLProgram.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -82,7 +82,8 @@ public: // temp = multiplier * currentSeed + incrementStep Value mul = rewriter.create(loc, currentSeed, multiplier); Value seed = rewriter.create(loc, mul, incrementStep); - globalVar = rewriter.create(loc, seed, globalVar, ValueRange()); + globalVar = + rewriter.create(loc, seed, globalVar, ValueRange()); rewriter.create( loc, SymbolRefAttr::get(op->getContext(), getSeedGobalVarName()), globalVar); diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp index 1f9f4b17b..e9c8ca021 100644 --- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp +++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToLinalg/TorchToTMTensor.cpp similarity index 99% rename from lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp rename to lib/Conversion/TorchToLinalg/TorchToTMTensor.cpp index a2d58daac..bec346870 100644 --- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp +++ b/lib/Conversion/TorchToLinalg/TorchToTMTensor.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -1273,13 +1273,13 @@ public: // Set the values in the input tensor to the smallest element of that // type TypedAttr minAttr = getNumericLimit(rewriter, srcType.getElementType(), - /*getMin=*/true); + /*getMin=*/true); normalizationValue = rewriter.create(loc, minAttr); } else if (reduceEnum == torch_upstream::ReductionType::MIN) { // Set the values in the input tensor to the largest element of that // type TypedAttr maxAttr = getNumericLimit(rewriter, srcType.getElementType(), - /*getMin=*/false); + /*getMin=*/false); normalizationValue = rewriter.create(loc, maxAttr); } diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 8684b68d9..8259d0416 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index 4a47790b0..1dee87725 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "Utils.h" +#include "./Utils.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp b/lib/Conversion/TorchToLinalg/VerifyLinalgOnTensorsBackendContract.cpp similarity index 95% rename from lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp rename to lib/Conversion/TorchToLinalg/VerifyLinalgOnTensorsBackendContract.cpp index 93d7de825..79e7cd64e 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp +++ b/lib/Conversion/TorchToLinalg/VerifyLinalgOnTensorsBackendContract.cpp @@ -7,7 +7,8 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" +#include "./PassDetail.h" +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -24,7 +25,6 @@ #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" -#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" #include "mlir/IR/BuiltinOps.h" @@ -33,7 +33,6 @@ using namespace mlir::torch; using namespace mlir::torch::TorchConversion; using namespace TMTensor; - namespace { class VerifyLinalgOnTensorsBackendContractPass : public VerifyLinalgOnTensorsBackendContractBase< @@ -96,7 +95,8 @@ class VerifyLinalgOnTensorsBackendContractPass // We avoid `module.emitError()` so that mlir-print-op-on-diagnostics // doesn't unnecessarily spew out the entire module. emitError(module.getLoc()) - << "Module does not conform to the linalg-on-tensors backend contract. " + << "Module does not conform to the linalg-on-tensors backend " + "contract. " "See dialect conversion legality information above."; return signalPassFailure(); } @@ -105,6 +105,6 @@ class VerifyLinalgOnTensorsBackendContractPass } // namespace std::unique_ptr> -mlir::torch::TorchConversion::createVerifyLinalgOnTensorsBackendContractPass() { +mlir::torch::createVerifyLinalgOnTensorsBackendContractPass() { return std::make_unique(); } diff --git a/lib/Conversion/TorchToSCF/CMakeLists.txt b/lib/Conversion/TorchToSCF/CMakeLists.txt deleted file mode 100644 index 06a888b3e..000000000 --- a/lib/Conversion/TorchToSCF/CMakeLists.txt +++ /dev/null @@ -1,22 +0,0 @@ -add_mlir_conversion_library(TorchMLIRTorchToSCF - TorchToSCF.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToSCF - - DEPENDS - TorchMLIRConversionPassIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - MLIRIR - MLIRPass - MLIRSCFDialect - MLIRFuncDialect - TorchMLIRTorchDialect - TorchMLIRTorchConversionDialect -) - -torch_mlir_target_includes(TorchMLIRTorchToSCF) diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 0a2ed02e0..1d31c3ab9 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" +#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/TorchToStablehlo/CMakeLists.txt b/lib/Conversion/TorchToStablehlo/CMakeLists.txt index 84a560cd7..bd5a9057e 100644 --- a/lib/Conversion/TorchToStablehlo/CMakeLists.txt +++ b/lib/Conversion/TorchToStablehlo/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_conversion_library(TorchMLIRTorchToStablehlo + Passes.cpp TorchToStablehlo.cpp StablehloLegalizeUtils.cpp Basic.cpp @@ -7,21 +8,25 @@ add_mlir_conversion_library(TorchMLIRTorchToStablehlo ViewLike.cpp Reduction.cpp Pooling.cpp + VerifyStablehloBackendContract.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToStablehlo DEPENDS - TorchMLIRConversionPassIncGen + TorchMLIRConversionStablehloPassIncGen LINK_COMPONENTS Core LINK_LIBS PUBLIC + MLIRBufferTransforms MLIRIR MLIRPass - MLIRBufferTransforms + MLIRTransforms StablehloOps + TorchMLIRConversionCommon + TorchMLIRTorchConversionPasses TorchMLIRTorchDialect TorchMLIRConversionUtils ) diff --git a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp index 9c8123bfd..3e098583b 100644 --- a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp +++ b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" +#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/TorchToStablehlo/Linear.cpp b/lib/Conversion/TorchToStablehlo/Linear.cpp index 71d679aea..a3e76a971 100644 --- a/lib/Conversion/TorchToStablehlo/Linear.cpp +++ b/lib/Conversion/TorchToStablehlo/Linear.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" +#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h b/lib/Conversion/TorchToStablehlo/PassDetail.h similarity index 54% rename from include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h rename to lib/Conversion/TorchToStablehlo/PassDetail.h index a6d774a64..0cc875206 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h +++ b/lib/Conversion/TorchToStablehlo/PassDetail.h @@ -1,4 +1,4 @@ -//===------------------------------------------------------------*- C++ -*-===// +//===- PassDetail.h - Conversion Pass class details -------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -7,17 +7,20 @@ // //===----------------------------------------------------------------------===// -#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H -#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H +#ifndef TORCHMLIR_CONVERSION_STABLEHLO_PASSDETAIL_H +#define TORCHMLIR_CONVERSION_STABLEHLO_PASSDETAIL_H #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" -#include namespace mlir { namespace torch { -std::unique_ptr> createConvertTorchToTosaPass(); -} -} // namespace mlir -#endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H +#define GEN_PASS_CLASSES +#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h.inc" + +} // namespace torch +} // end namespace mlir + +#endif // TORCHMLIR_CONVERSION_STABLEHLO_PASSDETAIL_H diff --git a/lib/Conversion/TorchToStablehlo/Passes.cpp b/lib/Conversion/TorchToStablehlo/Passes.cpp new file mode 100644 index 000000000..4fec1de8b --- /dev/null +++ b/lib/Conversion/TorchToStablehlo/Passes.cpp @@ -0,0 +1,62 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h" + +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" +#include "torch-mlir/Conversion/Passes.h" +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" +#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" +#include "transforms/passes.h" + +using namespace mlir; +using namespace mlir::torch; + +//===----------------------------------------------------------------------===// +// Pass registration +//===----------------------------------------------------------------------===// + +namespace { +#define GEN_PASS_REGISTRATION +#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h.inc" +} // end namespace + +void mlir::torch::registerStablehloConversionPasses() { + ::registerPasses(); + mlir::PassPipelineRegistration( + "torch-backend-to-stablehlo-backend-pipeline", + "Pipeline lowering torch backend contract to StableHLO backend " + "contract.", + createTorchBackendToStablehloBackendPipeline); +} + +void mlir::torch::createTorchBackendToStablehloBackendPipeline( + OpPassManager &pm, const StablehloBackendPipelineOptions &options) { + // Generate Stablehlo ops. + pm.addNestedPass(createConvertTorchToStablehloPass( + options.enableStaticShape, options.enableI32Index)); + // Lowering remained ops to Arith + pm.addNestedPass(createConvertTorchToArithPass()); + + // Clean up any non-canonical code introduced above.. + pm.addNestedPass(createCanonicalizerPass()); + // The resolution of `dim` ops tends to create identical ops. CSE them. + pm.addNestedPass(createCSEPass()); + + // Finish the type conversion from `torch` types to the types of the + // StableHLO backend contract. + pm.addPass(TorchConversion::createFuncBackendTypeConversionPass()); + pm.addNestedPass(createCanonicalizerPass()); + pm.addNestedPass( + TorchConversion::createFinalizingBackendTypeConversionPass()); + + // Verify that we have lowered to Stablehlo and Chlo ops. + pm.addPass(createVerifyStablehloBackendContractPass()); +} diff --git a/lib/Conversion/TorchToStablehlo/Pooling.cpp b/lib/Conversion/TorchToStablehlo/Pooling.cpp index 7c28a2fd3..8522a9032 100644 --- a/lib/Conversion/TorchToStablehlo/Pooling.cpp +++ b/lib/Conversion/TorchToStablehlo/Pooling.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" +#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp index 36f4d49e9..209fb944b 100644 --- a/lib/Conversion/TorchToStablehlo/Reduction.cpp +++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" +#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp index a25a66bbb..d814d8073 100644 --- a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp +++ b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp @@ -11,7 +11,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "stablehlo/dialect/StablehloOps.h" -#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" +#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include diff --git a/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp b/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp index 4bcc02344..95172af1f 100644 --- a/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp +++ b/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" +#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp b/lib/Conversion/TorchToStablehlo/VerifyStablehloBackendContract.cpp similarity index 86% rename from lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp rename to lib/Conversion/TorchToStablehlo/VerifyStablehloBackendContract.cpp index 888f29ade..d818558ae 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp +++ b/lib/Conversion/TorchToStablehlo/VerifyStablehloBackendContract.cpp @@ -6,8 +6,9 @@ // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// -#ifdef TORCH_MLIR_ENABLE_STABLEHLO -#include "PassDetail.h" + +#include "./PassDetail.h" +#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -18,11 +19,9 @@ #include "mlir/Transforms/DialectConversion.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" -#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" using namespace mlir; using namespace mlir::torch; -using namespace mlir::torch::TorchConversion; namespace { class VerifyStablehloBackendContractPass @@ -45,7 +44,8 @@ class VerifyStablehloBackendContractPass ConversionTarget target(*context); // Structural operations. - target.addDynamicallyLegalOp(opHasLegalTypes); + target.addDynamicallyLegalOp( + opHasLegalTypes); // Shape operations. target.addDynamicallyLegalOp(opHasLegalTypes); @@ -58,7 +58,6 @@ class VerifyStablehloBackendContractPass } // namespace std::unique_ptr> -mlir::torch::TorchConversion::createVerifyStablehloBackendContractPass() { +mlir::torch::createVerifyStablehloBackendContractPass() { return std::make_unique(); } -#endif // TORCH_MLIR_ENABLE_STABLEHLO diff --git a/lib/Conversion/TorchToStablehlo/ViewLike.cpp b/lib/Conversion/TorchToStablehlo/ViewLike.cpp index ea19092e6..763a0c712 100644 --- a/lib/Conversion/TorchToStablehlo/ViewLike.cpp +++ b/lib/Conversion/TorchToStablehlo/ViewLike.cpp @@ -7,9 +7,9 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" +#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" diff --git a/lib/Conversion/TorchToTMTensor/CMakeLists.txt b/lib/Conversion/TorchToTMTensor/CMakeLists.txt deleted file mode 100644 index d05d8277c..000000000 --- a/lib/Conversion/TorchToTMTensor/CMakeLists.txt +++ /dev/null @@ -1,23 +0,0 @@ -add_mlir_conversion_library(TorchMLIRTorchToTMTensor -TorchToTMTensor.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToTMTensor - - DEPENDS - TorchMLIRConversionPassIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - MLIRIR - MLIRPass - MLIRLinalgDialect - MLIRMathDialect - TorchMLIRTorchDialect - TorchMLIRTMTensorDialect - TorchMLIRTorchUtils -) - -torch_mlir_target_includes(TorchMLIRTorchToTMTensor) diff --git a/lib/Conversion/TorchToTosa/CMakeLists.txt b/lib/Conversion/TorchToTosa/CMakeLists.txt index 909ee3bcb..e6f9a0b15 100644 --- a/lib/Conversion/TorchToTosa/CMakeLists.txt +++ b/lib/Conversion/TorchToTosa/CMakeLists.txt @@ -1,13 +1,15 @@ add_mlir_conversion_library(TorchMLIRTorchToTosa + Passes.cpp TorchToTosa.cpp TosaLegalizeUtils.cpp TosaLegalizeCommon.cpp + VerifyTosaBackendContract.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToTosa DEPENDS - TorchMLIRConversionPassIncGen + TorchMLIRConversionTosaPassIncGen LINK_COMPONENTS Core @@ -16,6 +18,8 @@ add_mlir_conversion_library(TorchMLIRTorchToTosa MLIRIR MLIRPass MLIRTosaDialect + MLIRTransforms + TorchMLIRTorchConversionPasses TorchMLIRConversionUtils TorchMLIRTorchDialect ) diff --git a/include/torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h b/lib/Conversion/TorchToTosa/PassDetail.h similarity index 56% rename from include/torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h rename to lib/Conversion/TorchToTosa/PassDetail.h index 2b42c3291..9f56b159f 100644 --- a/include/torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h +++ b/lib/Conversion/TorchToTosa/PassDetail.h @@ -1,4 +1,4 @@ -//===------------------------------------------------------------*- C++ -*-===// +//===- PassDetail.h - Conversion Pass class details -------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -7,16 +7,20 @@ // //===----------------------------------------------------------------------===// -#ifndef TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H -#define TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H +#ifndef TORCHMLIR_CONVERSION_TOSA_PASSDETAIL_H +#define TORCHMLIR_CONVERSION_TOSA_PASSDETAIL_H #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" namespace mlir { namespace torch { -std::unique_ptr> createConvertTorchToTMTensorPass(); -} -} // namespace mlir -#endif // TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H +#define GEN_PASS_CLASSES +#include "torch-mlir/Conversion/TorchToTosa/Passes.h.inc" + +} // namespace torch +} // end namespace mlir + +#endif // TORCHMLIR_CONVERSION_TOSA_PASSDETAIL_H diff --git a/lib/Conversion/TorchToTosa/Passes.cpp b/lib/Conversion/TorchToTosa/Passes.cpp new file mode 100644 index 000000000..67b9e5347 --- /dev/null +++ b/lib/Conversion/TorchToTosa/Passes.cpp @@ -0,0 +1,61 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchToTosa/Passes.h" + +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" +#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::tosa; + +//===----------------------------------------------------------------------===// +// Pass registration +//===----------------------------------------------------------------------===// + +namespace { +#define GEN_PASS_REGISTRATION +#include "torch-mlir/Conversion/TorchToTosa/Passes.h.inc" +} // end namespace + +void mlir::torch::registerTosaConversionPasses() { + ::registerPasses(); + mlir::PassPipelineRegistration<>( + "torch-backend-to-tosa-backend-pipeline", + "Pipeline lowering torch backend contract to TOSA backend " + "contract.", + createTorchBackendToTosaBackendPipeline); +} + +void mlir::torch::createTorchBackendToTosaBackendPipeline(OpPassManager &pm) { + pm.addNestedPass(createConvertTorchToTosaPass()); + // Perform rank broadcasting so TosaToLinalg pass works + pm.addNestedPass(createTosaMakeBroadcastablePass()); + + // Clean up any non-canonical code introduced above.. + pm.addNestedPass(createCanonicalizerPass()); + // The resolution of `dim` ops tends to create identical ops. CSE them. + pm.addNestedPass(createCSEPass()); + + // Finish the type conversion from `torch` types to the types of the + // TOSA backend contract. + pm.addPass(TorchConversion::createFuncBackendTypeConversionPass()); + pm.addNestedPass(createCanonicalizerPass()); + pm.addNestedPass( + TorchConversion::createFinalizingBackendTypeConversionPass()); + + // Verify that we have lowered to the form that TOSA backends + // expect. This fails compilation (signalPassFailure) if the IR is not in the + // correct form. + pm.addPass(createVerifyTosaBackendContractPass()); +} diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index bf2f20d82..e47b04e37 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -7,12 +7,12 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" +#include "torch-mlir/Conversion/TorchToTosa/Passes.h" #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h" #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h" #include "torch-mlir/Conversion/Utils/Utils.h" -#include "../PassDetail.h" +#include "./PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp b/lib/Conversion/TorchToTosa/VerifyTosaBackendContract.cpp similarity index 91% rename from lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp rename to lib/Conversion/TorchToTosa/VerifyTosaBackendContract.cpp index a29e14a3d..770fb1195 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp +++ b/lib/Conversion/TorchToTosa/VerifyTosaBackendContract.cpp @@ -7,7 +7,8 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" +#include "./PassDetail.h" +#include "torch-mlir/Conversion/TorchToTosa/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -16,11 +17,9 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Transforms/DialectConversion.h" -#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" using namespace mlir; using namespace mlir::torch; -using namespace mlir::torch::TorchConversion; namespace { class VerifyTosaBackendContractPass @@ -62,6 +61,6 @@ class VerifyTosaBackendContractPass } // namespace std::unique_ptr> -mlir::torch::TorchConversion::createVerifyTosaBackendContractPass() { +mlir::torch::createVerifyTosaBackendContractPass() { return std::make_unique(); } diff --git a/lib/Dialect/Torch/Transforms/Passes.cpp b/lib/Dialect/Torch/Transforms/Passes.cpp index 407e90247..b4222c871 100644 --- a/lib/Dialect/Torch/Transforms/Passes.cpp +++ b/lib/Dialect/Torch/Transforms/Passes.cpp @@ -11,8 +11,17 @@ #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" +//===----------------------------------------------------------------------===// +// Pass registration +//===----------------------------------------------------------------------===// + +namespace { +#define GEN_PASS_REGISTRATION +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc" +} // namespace + void mlir::torch::registerTorchPasses() { - mlir::torch::registerPasses(); + ::registerPasses(); mlir::PassPipelineRegistration( "torchscript-module-to-torch-backend-pipeline", "Pipeline lowering TorchScript object graph IR to Torch backend form.", diff --git a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt index 6495e4682..22c3e7454 100644 --- a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt +++ b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt @@ -1,20 +1,10 @@ set(LinkedLibs MLIRFuncTransforms MLIRIR - MLIRLinalgTransforms - MLIRMemRefTransforms MLIRPass - MLIRTosaTransforms - MLIRVectorTransforms TorchMLIRTorchConversionDialect - TorchMLIRTorchConversionToMLProgram TorchMLIRTorchDialect TorchMLIRTorchPasses - TorchMLIRTorchToArith - TorchMLIRTorchToLinalg - TorchMLIRTorchToSCF - TorchMLIRTorchToTMTensor - TorchMLIRTorchToTosa ) if(TORCH_MLIR_ENABLE_STABLEHLO) @@ -27,9 +17,6 @@ add_mlir_library(TorchMLIRTorchConversionPasses Passes.cpp ConvertCustomQuantOp.cpp UnpackQuantTensor.cpp - VerifyLinalgOnTensorsBackendContract.cpp - VerifyTosaBackendContract.cpp - VerifyStablehloBackendContract.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/TorchConversion/Transforms diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index 09e99057e..e65770878 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -8,27 +8,10 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" -#include "mlir/Conversion/Passes.h" -#include "mlir/Dialect/Func/Transforms/Passes.h" -#include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/MemRef/Transforms/Passes.h" -#include "mlir/Dialect/Tosa/Transforms/Passes.h" #include "mlir/Pass/PassManager.h" -#include "mlir/Transforms/Passes.h" -#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" -#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" -#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h" -#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" -#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" -#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" -#ifdef TORCH_MLIR_ENABLE_STABLEHLO -#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" -#endif -#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" using namespace mlir; using namespace mlir::torch; -using namespace mlir::tosa; //===----------------------------------------------------------------------===// // Pass registration @@ -39,111 +22,4 @@ namespace reg { #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc" } // end namespace reg -void mlir::torch::registerTorchConversionPasses() { - reg::registerPasses(); - mlir::PassPipelineRegistration<>( - "torch-backend-to-linalg-on-tensors-backend-pipeline", - "Pipeline lowering torch backend contract to linalg-on-tensors backend " - "contract.", - TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline); - - mlir::PassPipelineRegistration<>( - "torch-backend-to-tosa-backend-pipeline", - "Pipeline lowering torch backend contract to TOSA backend " - "contract.", - TorchConversion::createTorchBackendToTosaBackendPipeline); -#ifdef TORCH_MLIR_ENABLE_STABLEHLO - mlir::PassPipelineRegistration< - TorchConversion::StablehloBackendPipelineOptions>( - "torch-backend-to-stablehlo-backend-pipeline", - "Pipeline lowering torch backend contract to StableHLO backend " - "contract.", - TorchConversion::createTorchBackendToStablehloBackendPipeline); -#endif -} - -void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline( - OpPassManager &pm) { - // Lower to linalg + guards which is the input to codegen backends. - // We do this first as it tends to involve pattern-matching against constants, - // (e.g. dimensions which must be constant in a ranked programming model) - // and those constants get somewhat obscured by TorchToArith. - pm.addNestedPass(createConvertTorchToTMTensorPass()); - pm.addNestedPass(createConvertTorchToLinalgPass()); - pm.addNestedPass(createConvertTorchToSCFPass()); - pm.addNestedPass(createConvertTorchToArithPass()); - pm.addPass(createConvertTorchConversionToMLProgramPass()); - pm.addNestedPass(memref::createExpandOpsPass()); - - // Clean up any non-canonical code introduced above.. - pm.addNestedPass(createCanonicalizerPass()); - // Resolve `dim` ops on tensors (which currently live in the `memref` - // dialect for some reason -- we don't have memrefs at this level). - pm.addNestedPass( - memref::createResolveShapedTypeResultDimsPass()); - // The resolution of `dim` ops tends to create identical ops. CSE them. - pm.addNestedPass(createCSEPass()); - - // Finish the type conversion from `torch` types to the types of the - // linalg-on-tensors backend contract. - pm.addPass(TorchConversion::createFuncBackendTypeConversionPass()); - pm.addNestedPass(createCanonicalizerPass()); - pm.addNestedPass( - TorchConversion::createFinalizingBackendTypeConversionPass()); - - // Verify that we have lowered to the form that linalg on tensors backends - // expect. This fails compilation (signalPassFailure) if the IR is not in the - // correct form. - pm.addPass(TorchConversion::createVerifyLinalgOnTensorsBackendContractPass()); -} - -void TorchConversion::createTorchBackendToTosaBackendPipeline( - OpPassManager &pm) { - pm.addNestedPass(createConvertTorchToTosaPass()); - // Perform rank broadcasting so TosaToLinalg pass works - pm.addNestedPass(createTosaMakeBroadcastablePass()); - - // Clean up any non-canonical code introduced above.. - pm.addNestedPass(createCanonicalizerPass()); - // The resolution of `dim` ops tends to create identical ops. CSE them. - pm.addNestedPass(createCSEPass()); - - // Finish the type conversion from `torch` types to the types of the - // TOSA backend contract. - pm.addPass(TorchConversion::createFuncBackendTypeConversionPass()); - pm.addNestedPass(createCanonicalizerPass()); - pm.addNestedPass( - TorchConversion::createFinalizingBackendTypeConversionPass()); - - // Verify that we have lowered to the form that TOSA backends - // expect. This fails compilation (signalPassFailure) if the IR is not in the - // correct form. - pm.addPass(TorchConversion::createVerifyTosaBackendContractPass()); -} - -#ifdef TORCH_MLIR_ENABLE_STABLEHLO -void TorchConversion::createTorchBackendToStablehloBackendPipeline( - OpPassManager &pm, - const TorchConversion::StablehloBackendPipelineOptions &options) { - // Generate Stablehlo ops. - pm.addNestedPass(createConvertTorchToStablehloPass( - options.enableStaticShape, options.enableI32Index)); - // Lowering remained ops to Arith - pm.addNestedPass(createConvertTorchToArithPass()); - - // Clean up any non-canonical code introduced above.. - pm.addNestedPass(createCanonicalizerPass()); - // The resolution of `dim` ops tends to create identical ops. CSE them. - pm.addNestedPass(createCSEPass()); - - // Finish the type conversion from `torch` types to the types of the - // StableHLO backend contract. - pm.addPass(TorchConversion::createFuncBackendTypeConversionPass()); - pm.addNestedPass(createCanonicalizerPass()); - pm.addNestedPass( - TorchConversion::createFinalizingBackendTypeConversionPass()); - - // Verify that we have lowered to Stablehlo and Chlo ops. - pm.addPass(TorchConversion::createVerifyStablehloBackendContractPass()); -} -#endif +void mlir::torch::registerTorchConversionPasses() { reg::registerPasses(); } diff --git a/lib/InitAll.cpp b/lib/InitAll.cpp index 88bceb013..9ea0f3f26 100644 --- a/lib/InitAll.cpp +++ b/lib/InitAll.cpp @@ -21,8 +21,17 @@ #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" #include "torch-mlir/RefBackend/Passes.h" +#ifdef TORCH_MLIR_ENABLE_LINALG +#include "torch-mlir/Conversion/TorchToLinalg/Passes.h" +#endif + #ifdef TORCH_MLIR_ENABLE_STABLEHLO #include "mhlo/transforms/passes.h" +#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h" +#endif + +#ifdef TORCH_MLIR_ENABLE_TOSA +#include "torch-mlir/Conversion/TorchToTosa/Passes.h" #endif void mlir::torch::registerAllDialects(mlir::DialectRegistry ®istry) { @@ -41,11 +50,20 @@ void mlir::torch::registerAllPasses() { mlir::torch::RefBackend::registerRefBackendPasses(); mlir::torch::TMTensor::registerPasses(); +#ifdef TORCH_MLIR_ENABLE_LINALG + mlir::torch::registerLinalgConversionPasses(); +#endif // TORCH_MLIR_ENABLE_LINALG + #ifdef TORCH_MLIR_ENABLE_STABLEHLO mlir::mhlo::registerSymbolicShapeOptimizationPass(); mlir::mhlo::registerStablehloLegalizeToHloPass(); mlir::mhlo::registerChloLegalizeToHloPass(); mlir::mhlo::registerHloLegalizeToLinalgPass(); mlir::mhlo::registerTestUnfuseBatchNormPass(); + mlir::torch::registerStablehloConversionPasses(); #endif // TORCH_MLIR_ENABLE_STABLEHLO + +#ifdef TORCH_MLIR_ENABLE_TOSA + mlir::torch::registerTosaConversionPasses(); +#endif // TORCH_MLIR_ENABLE_TOSA }