mirror of https://github.com/llvm/torch-mlir
[MHLO] add new options to pipeline (#1331)
parent
71b1f0dd7a
commit
7f63a17a46
|
@ -36,9 +36,22 @@ void createTorchBackendToTosaBackendPipeline(
|
|||
|
||||
// Do not register the torch-to-mhlo pipeline if mhlo target is disabled
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO
|
||||
struct MhloBackendPipelineOptions
|
||||
: public PassPipelineOptions<MhloBackendPipelineOptions> {
|
||||
Option<bool> enableStaticShape{
|
||||
*this, "enable-static-shape",
|
||||
llvm::cl::desc("Enable static shape conversion."), llvm::cl::init(false)};
|
||||
// The i64 calculation is much slower than i32 on some devices, such as
|
||||
// Nvidia GPU. One can truncate from i64 to i32 since dimension sizes
|
||||
// are unlikely to exceed the range of i32(4GiB)
|
||||
Option<bool> enableI32Index{
|
||||
*this, "enable-i32-index",
|
||||
llvm::cl::desc("Enable truncate index from i64 to i32(unsafely)"),
|
||||
llvm::cl::init(false)};
|
||||
};
|
||||
|
||||
void createTorchBackendToMhloBackendPipeline(
|
||||
OpPassManager &pm,
|
||||
const torch::Torch::TorchLoweringPipelineOptions &options);
|
||||
OpPassManager &pm, const MhloBackendPipelineOptions &options);
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createVerifyMhloBackendContractPass();
|
||||
#endif
|
||||
|
||||
|
|
|
@ -53,7 +53,7 @@ void mlir::torch::registerTorchConversionPasses() {
|
|||
"contract.",
|
||||
TorchConversion::createTorchBackendToTosaBackendPipeline);
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO
|
||||
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
||||
mlir::PassPipelineRegistration<TorchConversion::MhloBackendPipelineOptions>(
|
||||
"torch-backend-to-mhlo-backend-pipeline",
|
||||
"Pipeline lowering torch backend contract to MHLO backend "
|
||||
"contract.",
|
||||
|
@ -121,8 +121,10 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
|
|||
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO
|
||||
void TorchConversion::createTorchBackendToMhloBackendPipeline(
|
||||
OpPassManager &pm, const Torch::TorchLoweringPipelineOptions &options) {
|
||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToMhloPass());
|
||||
OpPassManager &pm,
|
||||
const TorchConversion::MhloBackendPipelineOptions &options) {
|
||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToMhloPass(
|
||||
options.enableStaticShape, options.enableI32Index));
|
||||
|
||||
// Clean up any non-canonical code introduced above..
|
||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||
|
|
Loading…
Reference in New Issue