From 6e485574e55cadc441470457e49470f5e6ac54d0 Mon Sep 17 00:00:00 2001
From: Sambhav Jain
Date: Wed, 22 May 2024 05:23:18 -0700
Subject: [PATCH] [Pipeline] Use dedicated simplification pipeline for
 TorchDynamo frontend (#3376)

Discord Thread: https://discord.com/channels/636084430946959380/1238330633328005243

## Context:

[This](https://github.com/llvm/torch-mlir/blob/main/python/torch_mlir/fx.py#L61) was updated to support e2e tests for the TorchDynamo frontend in Torch-MLIR: we run FX decompositions, import the FX IR to generate the Torch dialect, and then run `torch-function-to-torch-backend-pipeline`, skipping only shape/type refinement for now. However, we should be able to skip many of the Torch simplification passes as well, as depicted in the [frontend roadmap](https://github.com/llvm/torch-mlir/blob/main/docs/images/roadmap_frontend.png).

Based on IREE's TorchDynamo [pipeline](https://github.com/iree-org/iree/blob/main/compiler/plugins/input/Torch/InputConversion/Passes.cpp#L29), the only two passes we seem to require are `ReduceOpVariantsPass` and `DecomposeComplexOpsPass`. This is in line with our findings from initial exploration as well.

This PR creates a dedicated frontend simplification pipeline for TorchDynamo / FX Importer that calls only `ReduceOpVariantsPass` and `DecomposeComplexOpsPass`. We rely on the e2e fx_importer tests to ensure that dropping the passes historically needed for TorchScript does not introduce regressions.

One notable change is that we no longer call `LowerToBackendContractPass`, which used to invoke `TorchSimplificationPipeline` iteratively until `VerifyBackendContract` was clean. Some of that was required for shape/type refinement to converge, which seems to be a non-issue for the Dynamo frontend. Do we anticipate this (the iterative invocation of `TorchSimplificationPipeline` followed by `VerifyBackendContract`) to be worth retaining in the Dynamo frontend pipeline? If so, I can make those changes, PLMK.
---
 .gitignore                                  |  1 +
 .../Dialect/Torch/Transforms/Passes.h       |  5 +++++
 lib/Dialect/Torch/Transforms/Passes.cpp     | 16 ++++++++++++++++
 python/torch_mlir/fx.py                     | 17 +++--------------
 4 files changed, 25 insertions(+), 14 deletions(-)

diff --git a/.gitignore b/.gitignore
index 00a5bc96f..7cc823a3f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,6 +11,7 @@
 externals/pytorch/
 libtorch*
 /build/
+.build-cache/
 /setup_build/
 __pycache__
 *.pyc
diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h
index d4cceb05d..aef6baa5d 100644
--- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h
+++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h
@@ -73,6 +73,11 @@ struct TorchLoweringPipelineOptions
 void createTorchScriptModuleToTorchBackendPipeline(
     OpPassManager &pm, const TorchLoweringPipelineOptions &options);
 
+/// Creates a pipeline that lowers the graph IR that is produced by
+/// TorchDynamo export into the form expected by torch-verify-backend-contract.
+void createTorchDynamoExportToTorchBackendPipeline(
+    OpPassManager &pm, const TorchLoweringPipelineOptions &options);
+
 /// Creates a pipeline that lowers a flat list of funcs and global slots
 /// with the torch and aten dialects and mutable arrays and converts it to
 /// the form required by torch-verify-backend-contract.
diff --git a/lib/Dialect/Torch/Transforms/Passes.cpp b/lib/Dialect/Torch/Transforms/Passes.cpp
index d01eac967..3ed8dc324 100644
--- a/lib/Dialect/Torch/Transforms/Passes.cpp
+++ b/lib/Dialect/Torch/Transforms/Passes.cpp
@@ -17,6 +17,10 @@ void mlir::torch::registerTorchPasses() {
       "torchscript-module-to-torch-backend-pipeline",
       "Pipeline lowering TorchScript object graph IR to Torch backend form.",
       mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline);
+  mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
+      "torchdynamo-export-to-torch-backend-pipeline",
+      "Pipeline lowering TorchDynamo exported graph IR to Torch backend form.",
+      mlir::torch::Torch::createTorchDynamoExportToTorchBackendPipeline);
   mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
       "torch-function-to-torch-backend-pipeline",
       "Pipeline lowering a Torch function to Torch backend form.",
       mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline);
@@ -59,6 +63,18 @@ void mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline(
   createTorchFunctionToTorchBackendPipeline(pm, options);
 }
 
+void mlir::torch::Torch::createTorchDynamoExportToTorchBackendPipeline(
+    OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
+  pm.addNestedPass<func::FuncOp>(
+      createReduceOpVariantsPass(options.extraLibrary));
+  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
+  if (options.decompose) {
+    pm.addNestedPass<func::FuncOp>(
+        Torch::createDecomposeComplexOpsPass(options.backendLegalOps));
+    pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
+  }
+}
+
 void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
     OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
   // Incorporate user annotations and remove signature Python-isms.
diff --git a/python/torch_mlir/fx.py b/python/torch_mlir/fx.py
index 834cffd63..b8765b659 100644
--- a/python/torch_mlir/fx.py
+++ b/python/torch_mlir/fx.py
@@ -27,7 +27,6 @@
 def _module_lowering(
     verbose,
     output_type,
     torch_mod,
-    backend_legal_ops=None,
     extra_library_file_name=None,
 ):
@@ -35,23 +34,13 @@ def _module_lowering(
         if verbose:
             print(torch_mod)
         return torch_mod
-    # TODO: pass backend_legal_ops/extra_library_file_name by caller
-    if backend_legal_ops is None:
-        backend_legal_ops = []
+    # TODO: pass extra_library_file_name by caller
     if extra_library_file_name is None:
         extra_library_file_name = ""
-    option_string = (
-        "{backend-legal-ops="
-        + ",".join(backend_legal_ops)
-        + " extra-library="
-        + extra_library_file_name
-        + " shape-dtype-refine="
-        + ("false" if not backend_legal_ops and not extra_library_file_name else "true")
-        + "}"
-    )
+    option_string = "{extra-library=" + extra_library_file_name + "}"
     run_pipeline_with_repro_report(
         torch_mod,
-        f"builtin.module(torch-function-to-torch-backend-pipeline{option_string})",
+        f"builtin.module(torchdynamo-export-to-torch-backend-pipeline{option_string})",
         "Lowering TorchFX IR -> Torch Backend IR",
         enable_ir_printing=verbose,
     )
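
## Usage sketch (not part of the patch)

A minimal way to exercise the new pipeline end to end, assuming a torch-mlir build whose Python package exposes `torch_mlir.fx.export_and_import` and `torch_mlir.compiler_utils.OutputType` (neither is introduced by this patch). With this change, lowering to `OutputType.TORCH` routes through `torchdynamo-export-to-torch-backend-pipeline` rather than `torch-function-to-torch-backend-pipeline`:

```python
# Sketch only: exercises the FX importer path that, with this patch, runs
# torchdynamo-export-to-torch-backend-pipeline under the hood.
# Assumes torch plus a torch-mlir build with the Python bindings installed.
import torch

from torch_mlir import fx
from torch_mlir.compiler_utils import OutputType


class MulAdd(torch.nn.Module):
    def forward(self, x, y):
        return x * y + y


# torch.export + FX import produce Torch dialect IR; _module_lowering then
# applies ReduceOpVariants and (when decomposition is enabled) DecomposeComplexOps,
# each followed by a canonicalizer pass, to reach Torch backend contract form.
module = fx.export_and_import(
    MulAdd(),
    torch.randn(3, 4),
    torch.randn(3, 4),
    output_type=OutputType.TORCH,
)
print(module)
```

The registered pipeline should also be invocable directly on Torch dialect IR with `torch-mlir-opt --pass-pipeline='builtin.module(torchdynamo-export-to-torch-backend-pipeline)'`.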