Fall cleaning: Re-organize sources so that each target type is contained and optional.

While working on https://github.com/openxla/iree/pull/14917, I noticed that it is somewhat hard to take a dependency on torch-mlir such that one only builds deps for the target(s) of interest (in this case Linalg). I noticed that some ifdef'ey optionality was added for stablehlo, but this was not mirrored for the others. Further, it does the switching very deep in the dependency graph vs having top-level directories and defines gating entire features. In addition, I noticed that a lot of things in the Linalg path were broken down to a fine level of detail but were not actually shared/shareable outside of that target. I opted to clump these together into TorchToLinalg. It is easy enough to "promote" them to common with this new organization if the need arises.

General approach:

* Isolate each conversion target in one of TorchToLinalg, TorchToStablehlo, TorchToTosa.
* Gate each by top-level CMake flags and defines.
* Common conversions go in a Common/ directory (currently Arith and SCF).
* Pull target specific conversions out of TorchConversion/Transforms and put in their top-level directory.
* General maintenance on the build graph and registration stuff that had bitrotted and was blocking progress.

The main functional change for people taking a source dep is that `#include "torch-mlir/Conversion/Passes.h"` no longer is a one stop shop: For optional conversions, you have to include the dedicated `Passes.h` of each and take a library dep. See `InitAll.cpp` which does it right (and *is* a one stop shop still).
isolate_optional_targets
Stella Laurenzo 2023-09-06 04:00:28 -07:00
parent 9cb5d38cd1
commit dd79067571
70 changed files with 800 additions and 671 deletions

View File

@ -36,10 +36,19 @@ macro(torch_mlir_add_llvm_external_project name identifier location)
set(LLVM_EXTERNAL_PROJECTS ${LLVM_EXTERNAL_PROJECTS} CACHE STRING "" FORCE) set(LLVM_EXTERNAL_PROJECTS ${LLVM_EXTERNAL_PROJECTS} CACHE STRING "" FORCE)
endmacro() endmacro()
option(TORCH_MLIR_ENABLE_STABLEHLO "Add stablehlo dialect" ON) # Optional conversion targets.
if(TORCH_MLIR_ENABLE_LINALG)
add_definitions(-DTORCH_MLIR_ENABLE_LINALG)
endif()
option(TORCH_MLIR_ENABLE_STABLEHLO "Add stablehlo dialect conversions" ON)
if(TORCH_MLIR_ENABLE_STABLEHLO) if(TORCH_MLIR_ENABLE_STABLEHLO)
add_definitions(-DTORCH_MLIR_ENABLE_STABLEHLO) add_definitions(-DTORCH_MLIR_ENABLE_STABLEHLO)
endif() endif()
option(TORCH_MLIR_ENABLE_TOSA "Add tosa dialect conversions" ON)
if(TORCH_MLIR_ENABLE_TOSA)
add_definitions(-DTORCH_MLIR_ENABLE_TOSA)
endif()
option(TORCH_MLIR_ENABLE_LINALG "Add linalg dialect" ON)
option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON) option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON)
option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF) option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF)

View File

@ -1,9 +1,14 @@
set(LLVM_TARGET_DEFINITIONS Passes.td) set(LLVM_TARGET_DEFINITIONS Passes.td)
if(TORCH_MLIR_ENABLE_STABLEHLO) mlir_tablegen(Passes.h.inc -gen-pass-decls)
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO) add_public_tablegen_target(TorchMLIRConversionCommonPassIncGen)
else() add_mlir_doc(Passes TorchMLIRConversionCommonPasses ./ -gen-pass-doc)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
endif()
add_public_tablegen_target(TorchMLIRConversionPassIncGen)
add_mlir_doc(Passes TorchMLIRConversionPasses ./ -gen-pass-doc) if(TORCH_MLIR_ENABLE_LINALG)
add_subdirectory(TorchToLinalg)
endif()
if(TORCH_MLIR_ENABLE_TOSA)
add_subdirectory(TorchToTosa)
endif()
if(TORCH_MLIR_ENABLE_STABLEHLO)
add_subdirectory(TorchToStablehlo)
endif()

View File

@ -7,16 +7,21 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_CONVERSION_PASSES_H #include "mlir/Dialect/Func/IR/FuncOps.h"
#define TORCHMLIR_CONVERSION_PASSES_H #include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include <memory>
namespace mlir { namespace mlir {
namespace torch { namespace torch {
/// Registers all torch-mlir conversion passes. std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToArithPass();
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToSCFPass();
// Note that this only registers common conversion passes. Backend
// specific passes with their own Passes.h in a subdirectory must be
// included/registered explicitly as they are all optional.
void registerConversionPasses(); void registerConversionPasses();
} // namespace torch } // namespace torch
} // namespace mlir } // namespace mlir
#endif // TORCHMLIR_CONVERSION_PASSES_H

View File

@ -13,7 +13,7 @@
include "mlir/Pass/PassBase.td" include "mlir/Pass/PassBase.td"
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Torch conversions // Common conversions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def ConvertTorchToArith : Pass<"convert-torch-to-arith", "func::FuncOp"> { def ConvertTorchToArith : Pass<"convert-torch-to-arith", "func::FuncOp"> {
@ -26,132 +26,4 @@ def ConvertTorchToSCF: Pass<"convert-torch-to-scf", "func::FuncOp"> {
let constructor = "mlir::torch::createConvertTorchToSCFPass()"; let constructor = "mlir::torch::createConvertTorchToSCFPass()";
} }
def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "func::FuncOp"> {
let summary = "Convert recognized Torch ops to Linalg ops";
let description = [{
Convert ATen ops to linalg ops.
This pass's main responsibility is to bridge the world between ops
that safely terminate the program in case of operand shape mismatches
(ATen) and ops where such mismatches are undefined behavior (linalg).
To model the termination of the program for implementing error guards,
we use the `cf.assert` op.
This is a design decision that is at variance from other passes in the
ecosystem, which use the
`shape` dialect's witness system (`shape.cstr_*` family of ops feeding into
`shape.assuming` regions). This is a change in design decisions
from those passes (which the authors of this pass have contributed to).
The reasons for this change are heuristic, but boil down to:
1. The modeling of `shape.assuming` is odd, as it uses a region, which is
not a good fit for modeling error guards. Regions mark a "start" and an
"end" (which is their nesting property). But
modeling assertions in the program doesn't fit into that. For assertions,
only the "start" matters (once tested, a predicate remains true "forever"
-- it doesn't end at the "yield" of the region).
Thus, having regions places arbitrary "end"s that just add IR structure
that has no semantic value for modeling this problem! (and to make things
worse the "end"s, which we don't need, are what require "yielding"
values, which interrupts use-def chains). Consider the different
structural properties of regions:
a. IsolatedFromAbove region:
- "start" interrupts use-def chains,
- "end" interrupts use-def chains
- structurally protects from intra-block upward and downward
code motion
b. Capturing region (like `shape.assuming`):
- "start" does not interrupt use-def chains,
- "end" interrupts use-def chains
- structurally protects from intra-block upward and downward
code motion
c. What we "ideally" want:
- "start" interrupts use-def chains (can be pruned though)
- no "end" IR structure!
- structurally protects from intra-block upward code motion
(but not downward code motion!)
- Observation: We probably can't get all of this, but overall this
problem is much better suited for a "MemorySSA"-like
abstraction, call it "EffectSSA" which is constructed on-demand
based on MLIR's effect modeling system (rather than
`shape.assuming`, which only covers the effects the IR creator
encoded -- with witnesses/`shape.assuming` -- it is easy to forget
to handle effects other than those encoded in the
witness structure).
2. The presence of `shape.assuming` regions tends to create highly nested
IR structures, which don't interoperate well with any other IR
structures, and creates very bulky IR (and IR creation code). In general
if we are going to do anything with anything (e.g. canonicalize) we
end up needing need to either:
a. Flatten the `shape.assuming` IR (defeating the purpose of having
it).
b. Do some sort of shape.assuming "region merging".
c. Have special patterns that handle a subset of special cases (looking
through "yields" and such) and don't generalize.
3. Witnesses tend to encourage non-scalable peephole transformations, which
tend to make analyses/transformations non-robust to the presence of
control flow and side effecting ops (easy to forget to handle side
effects other than those modeled by the witness system).
4. All this code operates on ranked tensors, for which using individual
SSA values for sizes (rather than a "shape type") seems to
work really well at this level of abstraction based on prior experience
in other projects. (unranked code tends to benefit from having a discrete
"shape type" to model shapes).
We will see if we end up needing something like `shape.assuming`, but for
now, it seems likely we can do something simpler and just bypass it. The
design of having an EffectSSA that is constructed on-demand seems very
compelling for modeling effects more broadly.
}];
let constructor = "mlir::torch::createConvertTorchToLinalgPass()";
}
def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> {
let summary = "Convert Torch ops to TOSA ops";
let description = [{
This pass assumes that TOSA ops are responsible for emitting error
guards in case of shape mismatches.
}];
let constructor = "mlir::torch::createConvertTorchToTosaPass()";
}
def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> {
let summary = "Convert recognized Torch ops to TMTensor/Linalg ops";
let description = [{
Convert ATen ops to tmtensor/linalg ops.
This pass is similar to the TorchToLinalg pass; the difference is that this
pass also makes use of TMTensor Dialect, which the former one doesn't.
}];
let constructor = "mlir::torch::createConvertTorchToTMTensorPass()";
}
def ConvertTorchConversionToMLProgram : Pass<"convert-torch-conversion-to-mlprogram", "ModuleOp"> {
let summary = "Convert recognized TorchConversion ops to MLProgram ops";
let description = [{
Convert TorchConversion ops to mlprogram ops.
}];
let constructor = "mlir::torch::createConvertTorchConversionToMLProgramPass()";
}
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
def ConvertTorchToStablehlo : Pass<"convert-torch-to-stablehlo", "func::FuncOp"> {
let summary = "Convert Torch ops to Stablehlo ops";
let description = [{
Convert Torch ops to Stablehlo ops.
}];
let constructor = "mlir::torch::createConvertTorchToStablehloPass()";
// Specify any options.
let options = [
Option<"enableStaticShape", "enable-static-shape", "bool", /*default=*/"false",
"Enable static shape conversion">,
// The i64 calculation is much slower than i32 on some devices, such as
// Nvidia GPU. One can truncate from i64 to i32 since dimension sizes
// are unlikely to exceed the range of i32(4GiB)
Option<"enableI32Index", "enable-i32-index", "bool", /*default=*/"false",
"Enable truncate index from i64 to i32(unsafely)">,
];
}
#endif
#endif // TORCHMLIR_CONVERSION_PASSES #endif // TORCHMLIR_CONVERSION_PASSES

View File

@ -1,23 +0,0 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H
#define TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace torch {
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTorchConversionToMLProgramPass();
}
} // namespace mlir
#endif // TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H

View File

@ -0,0 +1,4 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
add_public_tablegen_target(TorchMLIRConversionLinalgPassIncGen)
add_mlir_doc(Passes TorchMLIRConversionLinalgPasses ./ -gen-pass-doc)

View File

@ -0,0 +1,41 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_CONVERSION_LINALG_PASSES_H
#define TORCHMLIR_CONVERSION_LINALG_PASSES_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include <memory>
namespace mlir {
namespace torch {
/// Creates a pipeline that lowers from the torch backend contract to the
/// linalg-on-tensors backend contract.
void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm);
std::unique_ptr<OperationPass<ModuleOp>>
createVerifyLinalgOnTensorsBackendContractPass();
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToLinalgPass();
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTorchConversionToMLProgramPass();
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTMTensorPass();
/// Registers all torch-mlir conversion passes.
void registerLinalgConversionPasses();
} // namespace torch
} // namespace mlir
#endif // TORCHMLIR_CONVERSION_PASSES_H

View File

@ -0,0 +1,132 @@
//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_CONVERSION_LINALG_PASSES
#define TORCHMLIR_CONVERSION_LINALG_PASSES
include "mlir/Pass/PassBase.td"
//===----------------------------------------------------------------------===//
// Torch conversions
//===----------------------------------------------------------------------===//
def ConvertTorchToArith : Pass<"convert-torch-to-arith", "func::FuncOp"> {
let summary = "Convert recognized Torch ops to Std ops";
let constructor = "mlir::torch::createConvertTorchToArithPass()";
}
def ConvertTorchToSCF: Pass<"convert-torch-to-scf", "func::FuncOp"> {
let summary = "Convert recognized Torch ops to SCF ops";
let constructor = "mlir::torch::createConvertTorchToSCFPass()";
}
def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "func::FuncOp"> {
let summary = "Convert recognized Torch ops to Linalg ops";
let description = [{
Convert ATen ops to linalg ops.
This pass's main responsibility is to bridge the world between ops
that safely terminate the program in case of operand shape mismatches
(ATen) and ops where such mismatches are undefined behavior (linalg).
To model the termination of the program for implementing error guards,
we use the `cf.assert` op.
This is a design decision that is at variance from other passes in the
ecosystem, which use the
`shape` dialect's witness system (`shape.cstr_*` family of ops feeding into
`shape.assuming` regions). This is a change in design decisions
from those passes (which the authors of this pass have contributed to).
The reasons for this change are heuristic, but boil down to:
1. The modeling of `shape.assuming` is odd, as it uses a region, which is
not a good fit for modeling error guards. Regions mark a "start" and an
"end" (which is their nesting property). But
modeling assertions in the program doesn't fit into that. For assertions,
only the "start" matters (once tested, a predicate remains true "forever"
-- it doesn't end at the "yield" of the region).
Thus, having regions places arbitrary "end"s that just add IR structure
that has no semantic value for modeling this problem! (and to make things
worse the "end"s, which we don't need, are what require "yielding"
values, which interrupts use-def chains). Consider the different
structural properties of regions:
a. IsolatedFromAbove region:
- "start" interrupts use-def chains,
- "end" interrupts use-def chains
- structurally protects from intra-block upward and downward
code motion
b. Capturing region (like `shape.assuming`):
- "start" does not interrupt use-def chains,
- "end" interrupts use-def chains
- structurally protects from intra-block upward and downward
code motion
c. What we "ideally" want:
- "start" interrupts use-def chains (can be pruned though)
- no "end" IR structure!
- structurally protects from intra-block upward code motion
(but not downward code motion!)
- Observation: We probably can't get all of this, but overall this
problem is much better suited for a "MemorySSA"-like
abstraction, call it "EffectSSA" which is constructed on-demand
based on MLIR's effect modeling system (rather than
`shape.assuming`, which only covers the effects the IR creator
encoded -- with witnesses/`shape.assuming` -- it is easy to forget
to handle effects other than those encoded in the
witness structure).
2. The presence of `shape.assuming` regions tends to create highly nested
IR structures, which don't interoperate well with any other IR
structures, and creates very bulky IR (and IR creation code). In general
if we are going to do anything with anything (e.g. canonicalize) we
end up needing need to either:
a. Flatten the `shape.assuming` IR (defeating the purpose of having
it).
b. Do some sort of shape.assuming "region merging".
c. Have special patterns that handle a subset of special cases (looking
through "yields" and such) and don't generalize.
3. Witnesses tend to encourage non-scalable peephole transformations, which
tend to make analyses/transformations non-robust to the presence of
control flow and side effecting ops (easy to forget to handle side
effects other than those modeled by the witness system).
4. All this code operates on ranked tensors, for which using individual
SSA values for sizes (rather than a "shape type") seems to
work really well at this level of abstraction based on prior experience
in other projects. (unranked code tends to benefit from having a discrete
"shape type" to model shapes).
We will see if we end up needing something like `shape.assuming`, but for
now, it seems likely we can do something simpler and just bypass it. The
design of having an EffectSSA that is constructed on-demand seems very
compelling for modeling effects more broadly.
}];
let constructor = "mlir::torch::createConvertTorchToLinalgPass()";
}
def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> {
let summary = "Convert recognized Torch ops to TMTensor/Linalg ops";
let description = [{
Convert ATen ops to tmtensor/linalg ops.
This pass is similar to the TorchToLinalg pass; the difference is that this
pass also makes use of TMTensor Dialect, which the former one doesn't.
}];
let constructor = "mlir::torch::createConvertTorchToTMTensorPass()";
}
def ConvertTorchConversionToMLProgram : Pass<"convert-torch-conversion-to-mlprogram", "ModuleOp"> {
let summary = "Convert recognized TorchConversion ops to MLProgram ops";
let description = [{
Convert TorchConversion ops to mlprogram ops.
}];
let constructor = "mlir::torch::createConvertTorchConversionToMLProgramPass()";
}
def VerifyLinalgOnTensorsBackendContract : Pass<"torch-verify-linalg-on-tensors-backend-contract", "ModuleOp"> {
let summary = "Verifies conformity to the linalg-on-tensors backend contract";
let constructor = "mlir::torch::createVerifyLinalgOnTensorsBackendContractPass()";
}
#endif // TORCHMLIR_CONVERSION_LINALG_PASSES

View File

@ -1,24 +0,0 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_CONVERSION_ATENTOLINALG_ATENTOLINALG_H
#define TORCHMLIR_CONVERSION_ATENTOLINALG_ATENTOLINALG_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/Pass.h"
#include <memory>
namespace mlir {
namespace torch {
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToLinalgPass();
}
} // namespace mlir
#endif // TORCHMLIR_CONVERSION_ATENTOLINALG_ATENTOLINALG_H

View File

@ -1,22 +0,0 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_CONVERSION_TORCHTOSCF_TORCHTOSCF_H
#define TORCHMLIR_CONVERSION_TORCHTOSCF_TORCHTOSCF_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace torch {
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToSCFPass();
}
} // namespace mlir
#endif // TORCHMLIR_CONVERSION_TORCHTOSCF_TORCHTOSCF_H

View File

@ -0,0 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
add_public_tablegen_target(TorchMLIRConversionStablehloPassIncGen)
add_mlir_doc(Passes TorchMLIRConversionStablehloPasses ./ -gen-pass-doc)

View File

@ -0,0 +1,51 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_CONVERSION_STABLEHLO_PASSES_H
#define TORCHMLIR_CONVERSION_STABLEHLO_PASSES_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include <memory>
namespace mlir {
namespace torch {
struct StablehloBackendPipelineOptions
: public PassPipelineOptions<StablehloBackendPipelineOptions> {
Option<bool> enableStaticShape{
*this, "enable-static-shape",
llvm::cl::desc("Enable static shape conversion."), llvm::cl::init(false)};
// The i64 calculation is much slower than i32 on some devices, such as
// Nvidia GPU. One can truncate from i64 to i32 since dimension sizes
// are unlikely to exceed the range of i32(4GiB)
Option<bool> enableI32Index{
*this, "enable-i32-index",
llvm::cl::desc("Enable truncate index from i64 to i32(unsafely)"),
llvm::cl::init(false)};
};
void createTorchBackendToStablehloBackendPipeline(
OpPassManager &pm, const StablehloBackendPipelineOptions &options);
std::unique_ptr<OperationPass<ModuleOp>>
createVerifyStablehloBackendContractPass();
std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchToStablehloPass();
std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchToStablehloPass(bool enableStaticShape, bool enableI32Index);
/// Registers all torch-mlir conversion passes.
void registerStablehloConversionPasses();
} // namespace torch
} // namespace mlir
#endif // TORCHMLIR_CONVERSION_STABLEHLO_PASSES_H

View File

@ -0,0 +1,43 @@
//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_CONVERSION_STABLEHLO_PASSES
#define TORCHMLIR_CONVERSION_STABLEHLO_PASSES
include "mlir/Pass/PassBase.td"
//===----------------------------------------------------------------------===//
// Torch conversions
//===----------------------------------------------------------------------===//
def ConvertTorchToStablehlo : Pass<"convert-torch-to-stablehlo", "func::FuncOp"> {
let summary = "Convert Torch ops to Stablehlo ops";
let description = [{
Convert Torch ops to Stablehlo ops.
}];
let constructor = "mlir::torch::createConvertTorchToStablehloPass()";
// Specify any options.
let options = [
Option<"enableStaticShape", "enable-static-shape", "bool", /*default=*/"false",
"Enable static shape conversion">,
// The i64 calculation is much slower than i32 on some devices, such as
// Nvidia GPU. One can truncate from i64 to i32 since dimension sizes
// are unlikely to exceed the range of i32(4GiB)
Option<"enableI32Index", "enable-i32-index", "bool", /*default=*/"false",
"Enable truncate index from i64 to i32(unsafely)">,
];
}
def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contract", "ModuleOp"> {
let summary = "Verifies conformity to the stablehlo backend contract";
let constructor = "mlir::torch::createVerifyStablehloBackendContractPass()";
}
#endif // TORCHMLIR_CONVERSION_STABLEHLO_PASSES

View File

@ -1,26 +0,0 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H
#define TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include <memory>
namespace mlir {
namespace torch {
std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchToStablehloPass();
std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchToStablehloPass(bool enableStaticShape, bool enableI32Index);
} // namespace torch
} // namespace mlir
#endif // TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H

View File

@ -0,0 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
add_public_tablegen_target(TorchMLIRConversionTosaPassIncGen)
add_mlir_doc(Passes TorchMLIRConversionTosaPasses ./ -gen-pass-doc)

View File

@ -0,0 +1,35 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_CONVERSION_TOSA_PASSES_H
#define TORCHMLIR_CONVERSION_TOSA_PASSES_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include <memory>
namespace mlir {
namespace torch {
/// Creates a pipeline that lowers from the torch backend contract to the
/// TOSA backend contract.
void createTorchBackendToTosaBackendPipeline(OpPassManager &pm);
std::unique_ptr<OperationPass<ModuleOp>> createVerifyTosaBackendContractPass();
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass();
/// Registers all torch-mlir conversion passes.
void registerTosaConversionPasses();
} // namespace torch
} // namespace mlir
#endif // TORCHMLIR_CONVERSION_PASSES_H

View File

@ -0,0 +1,33 @@
//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_CONVERSION_TOSA_PASSES
#define TORCHMLIR_CONVERSION_TOSA_PASSES
include "mlir/Pass/PassBase.td"
//===----------------------------------------------------------------------===//
// Torch conversions
//===----------------------------------------------------------------------===//
def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> {
let summary = "Convert Torch ops to TOSA ops";
let description = [{
This pass assumes that TOSA ops are responsible for emitting error
guards in case of shape mismatches.
}];
let constructor = "mlir::torch::createConvertTorchToTosaPass()";
}
def VerifyTosaBackendContract : Pass<"torch-verify-tosa-backend-contract", "ModuleOp"> {
let summary = "Verifies conformity to the linalg-on-tensors backend contract";
let constructor = "mlir::torch::createVerifyTosaBackendContractPass()";
}
#endif // TORCHMLIR_CONVERSION_TOSA_PASSES

View File

@ -21,8 +21,6 @@ class ModuleOp;
namespace torch { namespace torch {
namespace Torch { namespace Torch {
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
std::unique_ptr<OperationPass<ModuleOp>> createGlobalizeObjectGraphPass(); std::unique_ptr<OperationPass<ModuleOp>> createGlobalizeObjectGraphPass();
std::unique_ptr<OperationPass<ModuleOp>> std::unique_ptr<OperationPass<ModuleOp>>
@ -141,13 +139,6 @@ static const char kTorchOpPrefix[] = R"(torch.)";
/// Registers all Torch transformation passes. /// Registers all Torch transformation passes.
void registerTorchPasses(); void registerTorchPasses();
//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//
#define GEN_PASS_REGISTRATION
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
} // namespace torch } // namespace torch
} // namespace mlir } // namespace mlir

View File

@ -22,36 +22,6 @@ class ModuleOp;
namespace torch { namespace torch {
namespace TorchConversion { namespace TorchConversion {
/// Creates a pipeline that lowers from the torch backend contract to the
/// linalg-on-tensors backend contract.
void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm);
/// Creates a pipeline that lowers from the torch backend contract to the
/// TOSA backend contract.
void createTorchBackendToTosaBackendPipeline(OpPassManager &pm);
// Do not register the stablehlo options if the stablehlo target is disabled
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
struct StablehloBackendPipelineOptions
: public PassPipelineOptions<StablehloBackendPipelineOptions> {
Option<bool> enableStaticShape{
*this, "enable-static-shape",
llvm::cl::desc("Enable static shape conversion."), llvm::cl::init(false)};
// The i64 calculation is much slower than i32 on some devices, such as
// Nvidia GPU. One can truncate from i64 to i32 since dimension sizes
// are unlikely to exceed the range of i32(4GiB)
Option<bool> enableI32Index{
*this, "enable-i32-index",
llvm::cl::desc("Enable truncate index from i64 to i32(unsafely)"),
llvm::cl::init(false)};
};
void createTorchBackendToStablehloBackendPipeline(
OpPassManager &pm, const StablehloBackendPipelineOptions &options);
std::unique_ptr<OperationPass<ModuleOp>>
createVerifyStablehloBackendContractPass();
#endif
std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass(); std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass();
std::unique_ptr<OperationPass<func::FuncOp>> std::unique_ptr<OperationPass<func::FuncOp>>
@ -65,11 +35,6 @@ createFinalizingBackendTypeConversionPass();
std::unique_ptr<OperationPass<func::FuncOp>> createUnpackQuantTensorPass(); std::unique_ptr<OperationPass<func::FuncOp>> createUnpackQuantTensorPass();
std::unique_ptr<OperationPass<func::FuncOp>> createConvertCustomQuantOpPass(); std::unique_ptr<OperationPass<func::FuncOp>> createConvertCustomQuantOpPass();
std::unique_ptr<OperationPass<ModuleOp>>
createVerifyLinalgOnTensorsBackendContractPass();
std::unique_ptr<OperationPass<ModuleOp>> createVerifyTosaBackendContractPass();
} // namespace TorchConversion } // namespace TorchConversion
/// Registers all Torch transformation passes. /// Registers all Torch transformation passes.

View File

@ -32,23 +32,6 @@ def FinalizingBackendTypeConversion
}]; }];
} }
def VerifyLinalgOnTensorsBackendContract : Pass<"torch-verify-linalg-on-tensors-backend-contract", "ModuleOp"> {
let summary = "Verifies conformity to the linalg-on-tensors backend contract";
let constructor = "mlir::torch::TorchConversion::createVerifyLinalgOnTensorsBackendContractPass()";
}
def VerifyTosaBackendContract : Pass<"torch-verify-tosa-backend-contract", "ModuleOp"> {
let summary = "Verifies conformity to the linalg-on-tensors backend contract";
let constructor = "mlir::torch::TorchConversion::createVerifyTosaBackendContractPass()";
}
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contract", "ModuleOp"> {
let summary = "Verifies conformity to the stablehlo backend contract";
let constructor = "mlir::torch::TorchConversion::createVerifyStablehloBackendContractPass()";
}
#endif // TORCH_MLIR_ENABLE_STABLEHLO
// The following passes are for a one-off conversion of a specific kind of quantized group matmul. // The following passes are for a one-off conversion of a specific kind of quantized group matmul.
// They should not be included in default lowering flows until further along. // They should not be included in default lowering flows until further along.
def UnpackQuantTensor : Pass<"torch-unpack-quant-tensor", "func::FuncOp"> { def UnpackQuantTensor : Pass<"torch-unpack-quant-tensor", "func::FuncOp"> {

View File

@ -8,6 +8,9 @@ add_mlir_public_c_api_library(TorchMLIRCAPI
ADDITIONAL_HEADER_DIRS ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir-c/ ${PROJECT_SOURCE_DIR}/include/torch-mlir-c/
DEPENDS
TorchMLIRTorchPassIncGen
ENABLE_AGGREGATION ENABLE_AGGREGATION
LINK_COMPONENTS LINK_COMPONENTS
Core Core

View File

@ -9,6 +9,11 @@
#include "mlir/CAPI/Pass.h" #include "mlir/CAPI/Pass.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
namespace {
#define GEN_PASS_REGISTRATION
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
} // namespace
// Must include the declarations as they carry important visibility attributes. // Must include the declarations as they carry important visibility attributes.
#include "torch-mlir/Dialect/Torch/Transforms/Transforms.capi.h.inc" #include "torch-mlir/Dialect/Torch/Transforms/Transforms.capi.h.inc"

View File

@ -23,8 +23,16 @@ set(LinkedLibs
TorchMLIRRefBackend TorchMLIRRefBackend
) )
# Conditionally link in backends if enabled.
if(TORCH_MLIR_ENABLE_LINALG)
list(APPEND LinkedLibs TorchMLIRTorchToLinalg)
endif()
if(TORCH_MLIR_ENABLE_TOSA)
list(APPEND LinkedLibs TorchMLIRTorchToTosa)
endif()
if(TORCH_MLIR_ENABLE_STABLEHLO) if(TORCH_MLIR_ENABLE_STABLEHLO)
list(APPEND LinkedLibs list(APPEND LinkedLibs
TorchMLIRTorchToStablehlo
MhloPasses MhloPasses
MhloToLinalg MhloToLinalg
StablehloToMhlo StablehloToMhlo

View File

@ -1,36 +1,18 @@
add_subdirectory(TorchToLinalg) if(TORCH_MLIR_ENABLE_LINALG)
add_subdirectory(TorchToSCF) add_subdirectory(TorchToLinalg)
add_subdirectory(TorchToArith) endif()
add_subdirectory(TorchToTosa) if(TORCH_MLIR_ENABLE_TOSA)
add_subdirectory(TorchToTosa)
endif()
if(TORCH_MLIR_ENABLE_STABLEHLO) if(TORCH_MLIR_ENABLE_STABLEHLO)
add_subdirectory(TorchToStablehlo) add_subdirectory(TorchToStablehlo)
endif() endif()
add_subdirectory(TorchToTMTensor)
add_subdirectory(TorchConversionToMLProgram)
add_subdirectory(Utils) add_subdirectory(Utils)
# TODO: Automate this with add_torch_mlir_conversion_library.
set(linked_libs TorchMLIRTorchToLinalg
TorchMLIRTorchToSCF
TorchMLIRTorchToArith
TorchMLIRTorchToTosa
TorchMLIRTorchToTMTensor
TorchMLIRTorchConversionToMLProgram
TorchMLIRConversionUtils)
if(TORCH_MLIR_ENABLE_STABLEHLO)
list(APPEND linked_libs TorchMLIRTorchToStablehlo)
endif()
add_mlir_library(TorchMLIRConversionPasses add_mlir_library(TorchMLIRConversionPasses
Passes.cpp Passes.cpp
DEPENDS
TorchMLIRConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
${linked_libs} TorchMLIRConversionUtils
#${torch_mlir_conversion_libs}
) )
add_subdirectory(Common)

View File

@ -0,0 +1,28 @@
add_mlir_conversion_library(TorchMLIRConversionCommon
TorchToArith.cpp
TorchToSCF.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion
DEPENDS
TorchMLIRConversionCommonPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRArithDialect
MLIRIR
MLIRFuncDialect
MLIRPass
MLIRMathDialect
MLIRSCFDialect
MLIRTransforms
TorchMLIRTorchDialect
TorchMLIRTorchConversionDialect
TorchMLIRTorchConversionPasses
TorchMLIRTorchUtils
)
torch_mlir_target_includes(TorchMLIRConversionCommon)

View File

@ -7,7 +7,7 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h" #include "torch-mlir/Conversion/Passes.h"
#include "../PassDetail.h" #include "../PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
@ -43,7 +43,8 @@ public:
LogicalResult LogicalResult
matchAndRewrite(AtenDimOp op, OpAdaptor adaptor, matchAndRewrite(AtenDimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
auto rank = rewriter.create<tensor::RankOp>(op->getLoc(), adaptor.getSelf()); auto rank =
rewriter.create<tensor::RankOp>(op->getLoc(), adaptor.getSelf());
rewriter.replaceOpWithNewOp<arith::IndexCastOp>( rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
op, getTypeConverter()->convertType(op.getType()), rank); op, getTypeConverter()->convertType(op.getType()), rank);
return success(); return success();
@ -74,7 +75,8 @@ public:
matchAndRewrite(AtenOp op, matchAndRewrite(AtenOp op,
typename OpConversionPattern<AtenOp>::OpAdaptor adaptor, typename OpConversionPattern<AtenOp>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
rewriter.template replaceOpWithNewOp<BinOp>(op, adaptor.getA(), adaptor.getB()); rewriter.template replaceOpWithNewOp<BinOp>(op, adaptor.getA(),
adaptor.getB());
return success(); return success();
} }
}; };
@ -112,10 +114,10 @@ public:
typename OpConversionPattern<AtenDivIntOp>::OpAdaptor adaptor, typename OpConversionPattern<AtenDivIntOp>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc(); Location loc = op.getLoc();
Value a = Value a = convertScalarToDtype(rewriter, loc, adaptor.getA(),
convertScalarToDtype(rewriter, loc, adaptor.getA(), rewriter.getF64Type()); rewriter.getF64Type());
Value b = Value b = convertScalarToDtype(rewriter, loc, adaptor.getB(),
convertScalarToDtype(rewriter, loc, adaptor.getB(), rewriter.getF64Type()); rewriter.getF64Type());
rewriter.replaceOpWithNewOp<arith::DivFOp>(op, a, b); rewriter.replaceOpWithNewOp<arith::DivFOp>(op, a, b);
return success(); return success();
} }
@ -180,7 +182,8 @@ public:
})); }));
return success(); return success();
} }
if (auto elements = op.getValueAttr().dyn_cast<DenseResourceElementsAttr>()) { if (auto elements =
op.getValueAttr().dyn_cast<DenseResourceElementsAttr>()) {
if (auto type = elements.getType().dyn_cast<RankedTensorType>()) { if (auto type = elements.getType().dyn_cast<RankedTensorType>()) {
if (auto intType = type.getElementType().dyn_cast<IntegerType>()) { if (auto intType = type.getElementType().dyn_cast<IntegerType>()) {
Type builtinTensorElemTy = Type builtinTensorElemTy =
@ -357,7 +360,8 @@ public:
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
namespace { namespace {
class ConvertTorchToArith : public ConvertTorchToArithBase<ConvertTorchToArith> { class ConvertTorchToArith
: public ConvertTorchToArithBase<ConvertTorchToArith> {
public: public:
void getDependentDialects(DialectRegistry &registry) const override { void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<func::FuncDialect>(); registry.insert<func::FuncDialect>();

View File

@ -7,7 +7,7 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" #include "torch-mlir/Conversion/Passes.h"
#include "../PassDetail.h" #include "../PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"

View File

@ -9,22 +9,6 @@
#include "torch-mlir/Conversion/Passes.h" #include "torch-mlir/Conversion/Passes.h"
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "transforms/passes.h"
#endif // TORCH_MLIR_ENABLE_STABLEHLO
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//
namespace { namespace {
#define GEN_PASS_REGISTRATION #define GEN_PASS_REGISTRATION
#include "torch-mlir/Conversion/Passes.h.inc" #include "torch-mlir/Conversion/Passes.h.inc"

View File

@ -1,22 +0,0 @@
add_mlir_conversion_library(TorchMLIRTorchConversionToMLProgram
TorchConversionToMLProgram.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchConversionToMLProgram
DEPENDS
TorchMLIRConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRLinalgDialect
MLIRMLProgramDialect
MLIRMathDialect
MLIRPass
TorchMLIRTorchDialect
)
torch_mlir_target_includes(TorchMLIRTorchConversionToMLProgram)

View File

@ -1,20 +0,0 @@
add_mlir_conversion_library(TorchMLIRTorchToArith
TorchToArith.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToArith
DEPENDS
TorchMLIRConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRFuncDialect
TorchMLIRTorchDialect
)
torch_mlir_target_includes(TorchMLIRTorchToArith)

View File

@ -1,4 +1,5 @@
add_mlir_conversion_library(TorchMLIRTorchToLinalg add_mlir_conversion_library(TorchMLIRTorchToLinalg
Passes.cpp
DataMovement.cpp DataMovement.cpp
IndirectDataMovement.cpp IndirectDataMovement.cpp
Linear.cpp Linear.cpp
@ -7,25 +8,37 @@ add_mlir_conversion_library(TorchMLIRTorchToLinalg
Reduction.cpp Reduction.cpp
TensorConstructors.cpp TensorConstructors.cpp
TensorScalarInterop.cpp TensorScalarInterop.cpp
TorchConversionToMLProgram.cpp
TorchToLinalg.cpp TorchToLinalg.cpp
TorchToTMTensor.cpp
Uncategorized.cpp Uncategorized.cpp
Utils.cpp Utils.cpp
VerifyLinalgOnTensorsBackendContract.cpp
ADDITIONAL_HEADER_DIRS ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToLinalg ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToLinalg
DEPENDS DEPENDS
TorchMLIRConversionPassIncGen TorchMLIRConversionLinalgPassIncGen
LINK_COMPONENTS LINK_COMPONENTS
Core Core
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRIR MLIRIR
MLIRFuncDialect
MLIRPass MLIRPass
MLIRLinalgDialect MLIRLinalgDialect
MLIRMathDialect MLIRMathDialect
MLIRMLProgramDialect
MLIRSCFDialect
MLIRTransforms
TorchMLIRConversionCommon
TorchMLIRTorchDialect TorchMLIRTorchDialect
TorchMLIRTorchConversionDialect
TorchMLIRTorchConversionPasses
TorchMLIRTorchUtils
TorchMLIRTMTensorDialect
) )
torch_mlir_target_includes(TorchMLIRTorchToLinalg) torch_mlir_target_includes(TorchMLIRTorchToLinalg)

View File

@ -7,13 +7,7 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinTypes.h" #include "./PassDetail.h"
#include "mlir/IR/TypeSupport.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "../PassDetail.h"
#include "PopulatePatterns.h" #include "PopulatePatterns.h"
#include "Utils.h" #include "Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
@ -21,7 +15,12 @@
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h" #include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeSupport.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Conversion/TorchToLinalg/Passes.h"
#include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

View File

@ -7,9 +7,9 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "torch-mlir/Conversion/TorchToLinalg/Passes.h"
#include "../PassDetail.h" #include "./PassDetail.h"
#include "PopulatePatterns.h" #include "PopulatePatterns.h"
#include "Utils.h" #include "Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"

View File

@ -7,9 +7,9 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "torch-mlir/Conversion/TorchToLinalg/Passes.h"
#include "../PassDetail.h" #include "./PassDetail.h"
#include "PopulatePatterns.h" #include "PopulatePatterns.h"
#include "Utils.h" #include "Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"

View File

@ -1,4 +1,4 @@
//===------------------------------------------------------------*- C++ -*-===// //===- PassDetail.h - Conversion Pass class details -------------*- C++ -*-===//
// //
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information. // See https://llvm.org/LICENSE.txt for license information.
@ -7,17 +7,20 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_CONVERSION_ATENTOSTD_ATENTOSTD_H #ifndef TORCHMLIR_CONVERSION_LINALG_PASSDETAIL_H
#define TORCHMLIR_CONVERSION_ATENTOSTD_ATENTOSTD_H #define TORCHMLIR_CONVERSION_LINALG_PASSDETAIL_H
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include <memory>
namespace mlir { namespace mlir {
namespace torch { namespace torch {
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToArithPass();
}
} // namespace mlir
#endif // TORCHMLIR_CONVERSION_ATENTOSTD_ATENTOSTD_H #define GEN_PASS_CLASSES
#include "torch-mlir/Conversion/TorchToLinalg/Passes.h.inc"
} // namespace torch
} // end namespace mlir
#endif // TORCHMLIR_CONVERSION_LINALG_PASSDETAIL_H

View File

@ -0,0 +1,75 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToLinalg/Passes.h"
#include "mlir/Dialect/Func/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "torch-mlir/Conversion/Passes.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::torch;
//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//
namespace {
#define GEN_PASS_REGISTRATION
#include "torch-mlir/Conversion/TorchToLinalg/Passes.h.inc"
} // end namespace
void mlir::torch::registerLinalgConversionPasses() {
::registerPasses();
mlir::PassPipelineRegistration<>(
"torch-backend-to-linalg-on-tensors-backend-pipeline",
"Pipeline lowering torch backend contract to linalg-on-tensors backend "
"contract.",
createTorchBackendToLinalgOnTensorsBackendPipeline);
}
void mlir::torch::createTorchBackendToLinalgOnTensorsBackendPipeline(
OpPassManager &pm) {
// Lower to linalg + guards which is the input to codegen backends.
// We do this first as it tends to involve pattern-matching against constants,
// (e.g. dimensions which must be constant in a ranked programming model)
// and those constants get somewhat obscured by TorchToArith.
pm.addNestedPass<func::FuncOp>(createConvertTorchToTMTensorPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchToLinalgPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchToSCFPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());
pm.addPass(createConvertTorchConversionToMLProgramPass());
pm.addNestedPass<func::FuncOp>(memref::createExpandOpsPass());
// Clean up any non-canonical code introduced above..
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// Resolve `dim` ops on tensors (which currently live in the `memref`
// dialect for some reason -- we don't have memrefs at this level).
pm.addNestedPass<func::FuncOp>(
memref::createResolveShapedTypeResultDimsPass());
// The resolution of `dim` ops tends to create identical ops. CSE them.
pm.addNestedPass<func::FuncOp>(createCSEPass());
// Finish the type conversion from `torch` types to the types of the
// linalg-on-tensors backend contract.
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(
TorchConversion::createFinalizingBackendTypeConversionPass());
// Verify that we have lowered to the form that linalg on tensors backends
// expect. This fails compilation (signalPassFailure) if the IR is not in the
// correct form.
pm.addPass(createVerifyLinalgOnTensorsBackendContractPass());
}

View File

@ -7,9 +7,9 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "torch-mlir/Conversion/TorchToLinalg/Passes.h"
#include "../PassDetail.h" #include "./PassDetail.h"
#include "PopulatePatterns.h" #include "PopulatePatterns.h"
#include "Utils.h" #include "Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"

View File

@ -7,9 +7,9 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "torch-mlir/Conversion/TorchToLinalg/Passes.h"
#include "../PassDetail.h" #include "./PassDetail.h"
#include "PopulatePatterns.h" #include "PopulatePatterns.h"
#include "Utils.h" #include "Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"

View File

@ -7,9 +7,9 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "torch-mlir/Conversion/TorchToLinalg/Passes.h"
#include "../PassDetail.h" #include "./PassDetail.h"
#include "PopulatePatterns.h" #include "PopulatePatterns.h"
#include "Utils.h" #include "Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"

View File

@ -7,9 +7,9 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "torch-mlir/Conversion/TorchToLinalg/Passes.h"
#include "../PassDetail.h" #include "./PassDetail.h"
#include "PopulatePatterns.h" #include "PopulatePatterns.h"
#include "Utils.h" #include "Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"

View File

@ -7,9 +7,9 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "torch-mlir/Conversion/TorchToLinalg/Passes.h"
#include "../PassDetail.h" #include "./PassDetail.h"
#include "PopulatePatterns.h" #include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"

View File

@ -7,9 +7,9 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" #include "torch-mlir/Conversion/TorchToLinalg/Passes.h"
#include "../PassDetail.h" #include "./PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
@ -82,7 +82,8 @@ public:
// temp = multiplier * currentSeed + incrementStep // temp = multiplier * currentSeed + incrementStep
Value mul = rewriter.create<arith::MulIOp>(loc, currentSeed, multiplier); Value mul = rewriter.create<arith::MulIOp>(loc, currentSeed, multiplier);
Value seed = rewriter.create<arith::AddIOp>(loc, mul, incrementStep); Value seed = rewriter.create<arith::AddIOp>(loc, mul, incrementStep);
globalVar = rewriter.create<tensor::InsertOp>(loc, seed, globalVar, ValueRange()); globalVar =
rewriter.create<tensor::InsertOp>(loc, seed, globalVar, ValueRange());
rewriter.create<ml_program::GlobalStoreOp>( rewriter.create<ml_program::GlobalStoreOp>(
loc, SymbolRefAttr::get(op->getContext(), getSeedGobalVarName()), loc, SymbolRefAttr::get(op->getContext(), getSeedGobalVarName()),
globalVar); globalVar);

View File

@ -7,9 +7,9 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "torch-mlir/Conversion/TorchToLinalg/Passes.h"
#include "../PassDetail.h" #include "./PassDetail.h"
#include "PopulatePatterns.h" #include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Complex/IR/Complex.h"

View File

@ -7,9 +7,9 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" #include "torch-mlir/Conversion/TorchToLinalg/Passes.h"
#include "../PassDetail.h" #include "./PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/Linalg.h"
@ -1273,13 +1273,13 @@ public:
// Set the values in the input tensor to the smallest element of that // Set the values in the input tensor to the smallest element of that
// type // type
TypedAttr minAttr = getNumericLimit(rewriter, srcType.getElementType(), TypedAttr minAttr = getNumericLimit(rewriter, srcType.getElementType(),
/*getMin=*/true); /*getMin=*/true);
normalizationValue = rewriter.create<arith::ConstantOp>(loc, minAttr); normalizationValue = rewriter.create<arith::ConstantOp>(loc, minAttr);
} else if (reduceEnum == torch_upstream::ReductionType::MIN) { } else if (reduceEnum == torch_upstream::ReductionType::MIN) {
// Set the values in the input tensor to the largest element of that // Set the values in the input tensor to the largest element of that
// type // type
TypedAttr maxAttr = getNumericLimit(rewriter, srcType.getElementType(), TypedAttr maxAttr = getNumericLimit(rewriter, srcType.getElementType(),
/*getMin=*/false); /*getMin=*/false);
normalizationValue = rewriter.create<arith::ConstantOp>(loc, maxAttr); normalizationValue = rewriter.create<arith::ConstantOp>(loc, maxAttr);
} }

View File

@ -7,9 +7,9 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "torch-mlir/Conversion/TorchToLinalg/Passes.h"
#include "../PassDetail.h" #include "./PassDetail.h"
#include "PopulatePatterns.h" #include "PopulatePatterns.h"
#include "Utils.h" #include "Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"

View File

@ -7,9 +7,9 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "Utils.h" #include "./Utils.h"
#include "../PassDetail.h" #include "./PassDetail.h"
#include "PopulatePatterns.h" #include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"

View File

@ -7,7 +7,8 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "PassDetail.h" #include "./PassDetail.h"
#include "torch-mlir/Conversion/TorchToLinalg/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
@ -24,7 +25,6 @@
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
@ -33,7 +33,6 @@ using namespace mlir::torch;
using namespace mlir::torch::TorchConversion; using namespace mlir::torch::TorchConversion;
using namespace TMTensor; using namespace TMTensor;
namespace { namespace {
class VerifyLinalgOnTensorsBackendContractPass class VerifyLinalgOnTensorsBackendContractPass
: public VerifyLinalgOnTensorsBackendContractBase< : public VerifyLinalgOnTensorsBackendContractBase<
@ -96,7 +95,8 @@ class VerifyLinalgOnTensorsBackendContractPass
// We avoid `module.emitError()` so that mlir-print-op-on-diagnostics // We avoid `module.emitError()` so that mlir-print-op-on-diagnostics
// doesn't unnecessarily spew out the entire module. // doesn't unnecessarily spew out the entire module.
emitError(module.getLoc()) emitError(module.getLoc())
<< "Module does not conform to the linalg-on-tensors backend contract. " << "Module does not conform to the linalg-on-tensors backend "
"contract. "
"See dialect conversion legality information above."; "See dialect conversion legality information above.";
return signalPassFailure(); return signalPassFailure();
} }
@ -105,6 +105,6 @@ class VerifyLinalgOnTensorsBackendContractPass
} // namespace } // namespace
std::unique_ptr<OperationPass<ModuleOp>> std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::TorchConversion::createVerifyLinalgOnTensorsBackendContractPass() { mlir::torch::createVerifyLinalgOnTensorsBackendContractPass() {
return std::make_unique<VerifyLinalgOnTensorsBackendContractPass>(); return std::make_unique<VerifyLinalgOnTensorsBackendContractPass>();
} }

View File

@ -1,22 +0,0 @@
add_mlir_conversion_library(TorchMLIRTorchToSCF
TorchToSCF.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToSCF
DEPENDS
TorchMLIRConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRSCFDialect
MLIRFuncDialect
TorchMLIRTorchDialect
TorchMLIRTorchConversionDialect
)
torch_mlir_target_includes(TorchMLIRTorchToSCF)

View File

@ -7,9 +7,9 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "torch-mlir/Conversion/TorchToStablehlo/Passes.h"
#include "../PassDetail.h" #include "./PassDetail.h"
#include "PopulatePatterns.h" #include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"

View File

@ -1,4 +1,5 @@
add_mlir_conversion_library(TorchMLIRTorchToStablehlo add_mlir_conversion_library(TorchMLIRTorchToStablehlo
Passes.cpp
TorchToStablehlo.cpp TorchToStablehlo.cpp
StablehloLegalizeUtils.cpp StablehloLegalizeUtils.cpp
Basic.cpp Basic.cpp
@ -7,21 +8,25 @@ add_mlir_conversion_library(TorchMLIRTorchToStablehlo
ViewLike.cpp ViewLike.cpp
Reduction.cpp Reduction.cpp
Pooling.cpp Pooling.cpp
VerifyStablehloBackendContract.cpp
ADDITIONAL_HEADER_DIRS ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToStablehlo ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToStablehlo
DEPENDS DEPENDS
TorchMLIRConversionPassIncGen TorchMLIRConversionStablehloPassIncGen
LINK_COMPONENTS LINK_COMPONENTS
Core Core
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRBufferTransforms
MLIRIR MLIRIR
MLIRPass MLIRPass
MLIRBufferTransforms MLIRTransforms
StablehloOps StablehloOps
TorchMLIRConversionCommon
TorchMLIRTorchConversionPasses
TorchMLIRTorchDialect TorchMLIRTorchDialect
TorchMLIRConversionUtils TorchMLIRConversionUtils
) )

View File

@ -7,9 +7,9 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "torch-mlir/Conversion/TorchToStablehlo/Passes.h"
#include "../PassDetail.h" #include "./PassDetail.h"
#include "PopulatePatterns.h" #include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"

View File

@ -7,9 +7,9 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "torch-mlir/Conversion/TorchToStablehlo/Passes.h"
#include "../PassDetail.h" #include "./PassDetail.h"
#include "PopulatePatterns.h" #include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"

View File

@ -1,4 +1,4 @@
//===------------------------------------------------------------*- C++ -*-===// //===- PassDetail.h - Conversion Pass class details -------------*- C++ -*-===//
// //
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information. // See https://llvm.org/LICENSE.txt for license information.
@ -7,17 +7,20 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H #ifndef TORCHMLIR_CONVERSION_STABLEHLO_PASSDETAIL_H
#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H #define TORCHMLIR_CONVERSION_STABLEHLO_PASSDETAIL_H
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include <memory>
namespace mlir { namespace mlir {
namespace torch { namespace torch {
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass();
}
} // namespace mlir
#endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H #define GEN_PASS_CLASSES
#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h.inc"
} // namespace torch
} // end namespace mlir
#endif // TORCHMLIR_CONVERSION_STABLEHLO_PASSDETAIL_H

View File

@ -0,0 +1,62 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "torch-mlir/Conversion/Passes.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
#include "transforms/passes.h"
using namespace mlir;
using namespace mlir::torch;
//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//
namespace {
#define GEN_PASS_REGISTRATION
#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h.inc"
} // end namespace
void mlir::torch::registerStablehloConversionPasses() {
::registerPasses();
mlir::PassPipelineRegistration<StablehloBackendPipelineOptions>(
"torch-backend-to-stablehlo-backend-pipeline",
"Pipeline lowering torch backend contract to StableHLO backend "
"contract.",
createTorchBackendToStablehloBackendPipeline);
}
void mlir::torch::createTorchBackendToStablehloBackendPipeline(
OpPassManager &pm, const StablehloBackendPipelineOptions &options) {
// Generate Stablehlo ops.
pm.addNestedPass<func::FuncOp>(createConvertTorchToStablehloPass(
options.enableStaticShape, options.enableI32Index));
// Lowering remained ops to Arith
pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());
// Clean up any non-canonical code introduced above..
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// The resolution of `dim` ops tends to create identical ops. CSE them.
pm.addNestedPass<func::FuncOp>(createCSEPass());
// Finish the type conversion from `torch` types to the types of the
// StableHLO backend contract.
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(
TorchConversion::createFinalizingBackendTypeConversionPass());
// Verify that we have lowered to Stablehlo and Chlo ops.
pm.addPass(createVerifyStablehloBackendContractPass());
}

View File

@ -7,9 +7,9 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "torch-mlir/Conversion/TorchToStablehlo/Passes.h"
#include "../PassDetail.h" #include "./PassDetail.h"
#include "PopulatePatterns.h" #include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"

View File

@ -7,9 +7,9 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "torch-mlir/Conversion/TorchToStablehlo/Passes.h"
#include "../PassDetail.h" #include "./PassDetail.h"
#include "PopulatePatterns.h" #include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"

View File

@ -11,7 +11,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/StablehloOps.h" #include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "torch-mlir/Conversion/TorchToStablehlo/Passes.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include <numeric> #include <numeric>

View File

@ -7,9 +7,9 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "torch-mlir/Conversion/TorchToStablehlo/Passes.h"
#include "../PassDetail.h" #include "./PassDetail.h"
#include "PopulatePatterns.h" #include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"

View File

@ -6,8 +6,9 @@
// Also available under a BSD-style license. See LICENSE. // Also available under a BSD-style license. See LICENSE.
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
#include "PassDetail.h" #include "./PassDetail.h"
#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
@ -18,11 +19,9 @@
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h" #include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
using namespace mlir::torch::TorchConversion;
namespace { namespace {
class VerifyStablehloBackendContractPass class VerifyStablehloBackendContractPass
@ -45,7 +44,8 @@ class VerifyStablehloBackendContractPass
ConversionTarget target(*context); ConversionTarget target(*context);
// Structural operations. // Structural operations.
target.addDynamicallyLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>(opHasLegalTypes); target.addDynamicallyLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>(
opHasLegalTypes);
// Shape operations. // Shape operations.
target.addDynamicallyLegalOp<shape::ShapeOfOp>(opHasLegalTypes); target.addDynamicallyLegalOp<shape::ShapeOfOp>(opHasLegalTypes);
@ -58,7 +58,6 @@ class VerifyStablehloBackendContractPass
} // namespace } // namespace
std::unique_ptr<OperationPass<ModuleOp>> std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::TorchConversion::createVerifyStablehloBackendContractPass() { mlir::torch::createVerifyStablehloBackendContractPass() {
return std::make_unique<VerifyStablehloBackendContractPass>(); return std::make_unique<VerifyStablehloBackendContractPass>();
} }
#endif // TORCH_MLIR_ENABLE_STABLEHLO

View File

@ -7,9 +7,9 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "torch-mlir/Conversion/TorchToStablehlo/Passes.h"
#include "../PassDetail.h" #include "./PassDetail.h"
#include "PopulatePatterns.h" #include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"

View File

@ -1,23 +0,0 @@
add_mlir_conversion_library(TorchMLIRTorchToTMTensor
TorchToTMTensor.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToTMTensor
DEPENDS
TorchMLIRConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRLinalgDialect
MLIRMathDialect
TorchMLIRTorchDialect
TorchMLIRTMTensorDialect
TorchMLIRTorchUtils
)
torch_mlir_target_includes(TorchMLIRTorchToTMTensor)

View File

@ -1,13 +1,15 @@
add_mlir_conversion_library(TorchMLIRTorchToTosa add_mlir_conversion_library(TorchMLIRTorchToTosa
Passes.cpp
TorchToTosa.cpp TorchToTosa.cpp
TosaLegalizeUtils.cpp TosaLegalizeUtils.cpp
TosaLegalizeCommon.cpp TosaLegalizeCommon.cpp
VerifyTosaBackendContract.cpp
ADDITIONAL_HEADER_DIRS ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToTosa ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToTosa
DEPENDS DEPENDS
TorchMLIRConversionPassIncGen TorchMLIRConversionTosaPassIncGen
LINK_COMPONENTS LINK_COMPONENTS
Core Core
@ -16,6 +18,8 @@ add_mlir_conversion_library(TorchMLIRTorchToTosa
MLIRIR MLIRIR
MLIRPass MLIRPass
MLIRTosaDialect MLIRTosaDialect
MLIRTransforms
TorchMLIRTorchConversionPasses
TorchMLIRConversionUtils TorchMLIRConversionUtils
TorchMLIRTorchDialect TorchMLIRTorchDialect
) )

View File

@ -1,4 +1,4 @@
//===------------------------------------------------------------*- C++ -*-===// //===- PassDetail.h - Conversion Pass class details -------------*- C++ -*-===//
// //
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information. // See https://llvm.org/LICENSE.txt for license information.
@ -7,16 +7,20 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H #ifndef TORCHMLIR_CONVERSION_TOSA_PASSDETAIL_H
#define TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H #define TORCHMLIR_CONVERSION_TOSA_PASSDETAIL_H
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
namespace mlir { namespace mlir {
namespace torch { namespace torch {
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTMTensorPass();
}
} // namespace mlir
#endif // TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H #define GEN_PASS_CLASSES
#include "torch-mlir/Conversion/TorchToTosa/Passes.h.inc"
} // namespace torch
} // end namespace mlir
#endif // TORCHMLIR_CONVERSION_TOSA_PASSDETAIL_H

View File

@ -0,0 +1,61 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToTosa/Passes.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::tosa;
//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//
namespace {
#define GEN_PASS_REGISTRATION
#include "torch-mlir/Conversion/TorchToTosa/Passes.h.inc"
} // end namespace
void mlir::torch::registerTosaConversionPasses() {
::registerPasses();
mlir::PassPipelineRegistration<>(
"torch-backend-to-tosa-backend-pipeline",
"Pipeline lowering torch backend contract to TOSA backend "
"contract.",
createTorchBackendToTosaBackendPipeline);
}
void mlir::torch::createTorchBackendToTosaBackendPipeline(OpPassManager &pm) {
pm.addNestedPass<func::FuncOp>(createConvertTorchToTosaPass());
// Perform rank broadcasting so TosaToLinalg pass works
pm.addNestedPass<func::FuncOp>(createTosaMakeBroadcastablePass());
// Clean up any non-canonical code introduced above..
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// The resolution of `dim` ops tends to create identical ops. CSE them.
pm.addNestedPass<func::FuncOp>(createCSEPass());
// Finish the type conversion from `torch` types to the types of the
// TOSA backend contract.
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(
TorchConversion::createFinalizingBackendTypeConversionPass());
// Verify that we have lowered to the form that TOSA backends
// expect. This fails compilation (signalPassFailure) if the IR is not in the
// correct form.
pm.addPass(createVerifyTosaBackendContractPass());
}

View File

@ -7,12 +7,12 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" #include "torch-mlir/Conversion/TorchToTosa/Passes.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h" #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h" #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
#include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h"
#include "../PassDetail.h" #include "./PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h"

View File

@ -7,7 +7,8 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "PassDetail.h" #include "./PassDetail.h"
#include "torch-mlir/Conversion/TorchToTosa/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
@ -16,11 +17,9 @@
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
using namespace mlir::torch::TorchConversion;
namespace { namespace {
class VerifyTosaBackendContractPass class VerifyTosaBackendContractPass
@ -62,6 +61,6 @@ class VerifyTosaBackendContractPass
} // namespace } // namespace
std::unique_ptr<OperationPass<ModuleOp>> std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::TorchConversion::createVerifyTosaBackendContractPass() { mlir::torch::createVerifyTosaBackendContractPass() {
return std::make_unique<VerifyTosaBackendContractPass>(); return std::make_unique<VerifyTosaBackendContractPass>();
} }

View File

@ -11,8 +11,17 @@
#include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Passes.h"
//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//
namespace {
#define GEN_PASS_REGISTRATION
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h.inc"
} // namespace
void mlir::torch::registerTorchPasses() { void mlir::torch::registerTorchPasses() {
mlir::torch::registerPasses(); ::registerPasses();
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>( mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
"torchscript-module-to-torch-backend-pipeline", "torchscript-module-to-torch-backend-pipeline",
"Pipeline lowering TorchScript object graph IR to Torch backend form.", "Pipeline lowering TorchScript object graph IR to Torch backend form.",

View File

@ -1,20 +1,10 @@
set(LinkedLibs set(LinkedLibs
MLIRFuncTransforms MLIRFuncTransforms
MLIRIR MLIRIR
MLIRLinalgTransforms
MLIRMemRefTransforms
MLIRPass MLIRPass
MLIRTosaTransforms
MLIRVectorTransforms
TorchMLIRTorchConversionDialect TorchMLIRTorchConversionDialect
TorchMLIRTorchConversionToMLProgram
TorchMLIRTorchDialect TorchMLIRTorchDialect
TorchMLIRTorchPasses TorchMLIRTorchPasses
TorchMLIRTorchToArith
TorchMLIRTorchToLinalg
TorchMLIRTorchToSCF
TorchMLIRTorchToTMTensor
TorchMLIRTorchToTosa
) )
if(TORCH_MLIR_ENABLE_STABLEHLO) if(TORCH_MLIR_ENABLE_STABLEHLO)
@ -27,9 +17,6 @@ add_mlir_library(TorchMLIRTorchConversionPasses
Passes.cpp Passes.cpp
ConvertCustomQuantOp.cpp ConvertCustomQuantOp.cpp
UnpackQuantTensor.cpp UnpackQuantTensor.cpp
VerifyLinalgOnTensorsBackendContract.cpp
VerifyTosaBackendContract.cpp
VerifyStablehloBackendContract.cpp
ADDITIONAL_HEADER_DIRS ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/TorchConversion/Transforms ${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/TorchConversion/Transforms

View File

@ -8,27 +8,10 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
#include "mlir/Conversion/Passes.h"
#include "mlir/Dialect/Func/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#endif
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
using namespace mlir::tosa;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Pass registration // Pass registration
@ -39,111 +22,4 @@ namespace reg {
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc" #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc"
} // end namespace reg } // end namespace reg
void mlir::torch::registerTorchConversionPasses() { void mlir::torch::registerTorchConversionPasses() { reg::registerPasses(); }
reg::registerPasses();
mlir::PassPipelineRegistration<>(
"torch-backend-to-linalg-on-tensors-backend-pipeline",
"Pipeline lowering torch backend contract to linalg-on-tensors backend "
"contract.",
TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline);
mlir::PassPipelineRegistration<>(
"torch-backend-to-tosa-backend-pipeline",
"Pipeline lowering torch backend contract to TOSA backend "
"contract.",
TorchConversion::createTorchBackendToTosaBackendPipeline);
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
mlir::PassPipelineRegistration<
TorchConversion::StablehloBackendPipelineOptions>(
"torch-backend-to-stablehlo-backend-pipeline",
"Pipeline lowering torch backend contract to StableHLO backend "
"contract.",
TorchConversion::createTorchBackendToStablehloBackendPipeline);
#endif
}
void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
OpPassManager &pm) {
// Lower to linalg + guards which is the input to codegen backends.
// We do this first as it tends to involve pattern-matching against constants,
// (e.g. dimensions which must be constant in a ranked programming model)
// and those constants get somewhat obscured by TorchToArith.
pm.addNestedPass<func::FuncOp>(createConvertTorchToTMTensorPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchToLinalgPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchToSCFPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());
pm.addPass(createConvertTorchConversionToMLProgramPass());
pm.addNestedPass<func::FuncOp>(memref::createExpandOpsPass());
// Clean up any non-canonical code introduced above..
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// Resolve `dim` ops on tensors (which currently live in the `memref`
// dialect for some reason -- we don't have memrefs at this level).
pm.addNestedPass<func::FuncOp>(
memref::createResolveShapedTypeResultDimsPass());
// The resolution of `dim` ops tends to create identical ops. CSE them.
pm.addNestedPass<func::FuncOp>(createCSEPass());
// Finish the type conversion from `torch` types to the types of the
// linalg-on-tensors backend contract.
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(
TorchConversion::createFinalizingBackendTypeConversionPass());
// Verify that we have lowered to the form that linalg on tensors backends
// expect. This fails compilation (signalPassFailure) if the IR is not in the
// correct form.
pm.addPass(TorchConversion::createVerifyLinalgOnTensorsBackendContractPass());
}
void TorchConversion::createTorchBackendToTosaBackendPipeline(
OpPassManager &pm) {
pm.addNestedPass<func::FuncOp>(createConvertTorchToTosaPass());
// Perform rank broadcasting so TosaToLinalg pass works
pm.addNestedPass<func::FuncOp>(createTosaMakeBroadcastablePass());
// Clean up any non-canonical code introduced above..
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// The resolution of `dim` ops tends to create identical ops. CSE them.
pm.addNestedPass<func::FuncOp>(createCSEPass());
// Finish the type conversion from `torch` types to the types of the
// TOSA backend contract.
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(
TorchConversion::createFinalizingBackendTypeConversionPass());
// Verify that we have lowered to the form that TOSA backends
// expect. This fails compilation (signalPassFailure) if the IR is not in the
// correct form.
pm.addPass(TorchConversion::createVerifyTosaBackendContractPass());
}
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
void TorchConversion::createTorchBackendToStablehloBackendPipeline(
OpPassManager &pm,
const TorchConversion::StablehloBackendPipelineOptions &options) {
// Generate Stablehlo ops.
pm.addNestedPass<func::FuncOp>(createConvertTorchToStablehloPass(
options.enableStaticShape, options.enableI32Index));
// Lowering remained ops to Arith
pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());
// Clean up any non-canonical code introduced above..
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// The resolution of `dim` ops tends to create identical ops. CSE them.
pm.addNestedPass<func::FuncOp>(createCSEPass());
// Finish the type conversion from `torch` types to the types of the
// StableHLO backend contract.
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(
TorchConversion::createFinalizingBackendTypeConversionPass());
// Verify that we have lowered to Stablehlo and Chlo ops.
pm.addPass(TorchConversion::createVerifyStablehloBackendContractPass());
}
#endif

View File

@ -21,8 +21,17 @@
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
#include "torch-mlir/RefBackend/Passes.h" #include "torch-mlir/RefBackend/Passes.h"
#ifdef TORCH_MLIR_ENABLE_LINALG
#include "torch-mlir/Conversion/TorchToLinalg/Passes.h"
#endif
#ifdef TORCH_MLIR_ENABLE_STABLEHLO #ifdef TORCH_MLIR_ENABLE_STABLEHLO
#include "mhlo/transforms/passes.h" #include "mhlo/transforms/passes.h"
#include "torch-mlir/Conversion/TorchToStablehlo/Passes.h"
#endif
#ifdef TORCH_MLIR_ENABLE_TOSA
#include "torch-mlir/Conversion/TorchToTosa/Passes.h"
#endif #endif
void mlir::torch::registerAllDialects(mlir::DialectRegistry &registry) { void mlir::torch::registerAllDialects(mlir::DialectRegistry &registry) {
@ -41,11 +50,20 @@ void mlir::torch::registerAllPasses() {
mlir::torch::RefBackend::registerRefBackendPasses(); mlir::torch::RefBackend::registerRefBackendPasses();
mlir::torch::TMTensor::registerPasses(); mlir::torch::TMTensor::registerPasses();
#ifdef TORCH_MLIR_ENABLE_LINALG
mlir::torch::registerLinalgConversionPasses();
#endif // TORCH_MLIR_ENABLE_LINALG
#ifdef TORCH_MLIR_ENABLE_STABLEHLO #ifdef TORCH_MLIR_ENABLE_STABLEHLO
mlir::mhlo::registerSymbolicShapeOptimizationPass(); mlir::mhlo::registerSymbolicShapeOptimizationPass();
mlir::mhlo::registerStablehloLegalizeToHloPass(); mlir::mhlo::registerStablehloLegalizeToHloPass();
mlir::mhlo::registerChloLegalizeToHloPass(); mlir::mhlo::registerChloLegalizeToHloPass();
mlir::mhlo::registerHloLegalizeToLinalgPass(); mlir::mhlo::registerHloLegalizeToLinalgPass();
mlir::mhlo::registerTestUnfuseBatchNormPass(); mlir::mhlo::registerTestUnfuseBatchNormPass();
mlir::torch::registerStablehloConversionPasses();
#endif // TORCH_MLIR_ENABLE_STABLEHLO #endif // TORCH_MLIR_ENABLE_STABLEHLO
#ifdef TORCH_MLIR_ENABLE_TOSA
mlir::torch::registerTosaConversionPasses();
#endif // TORCH_MLIR_ENABLE_TOSA
} }