[Pipeline] Use dedicated simplification pipeline for TorchDynamo frontend (#3376)

Discord Thread:
https://discord.com/channels/636084430946959380/1238330633328005243

## Context: 

[This](https://github.com/llvm/torch-mlir/blob/main/python/torch_mlir/fx.py#L61)
was updated to support e2e tests for the TorchDynamo frontend in
Torch-MLIR, where we run FX decompositions and import the FX IR to
generate Torch dialect, followed by
`torch-function-to-torch-backend-pipeline`, skipping only the shape/type
refinement for now. However, we should be able to skip many of the torch
simplification passes, as depicted in the [frontend
roadmap](https://github.com/llvm/torch-mlir/blob/main/docs/images/roadmap_frontend.png).

Based on IREE's TorchDynamo
[pipeline](https://github.com/iree-org/iree/blob/main/compiler/plugins/input/Torch/InputConversion/Passes.cpp#L29),
the only two passes we seem to require are: `ReduceOpVariantsPass` and
`DecomposeComplexOpsPass`. This is inline with our findings as well
based on initial exploration.

This PR creates a dedicated frontend simplification pipeline for
TorchDynamo / FX Importer which calls only `ReduceOpVariantsPass` and
`DecomposeComplexOpsPass`. We rely on the e2e fx_importer tests to
ensure we're not regressing by removing many of the passes that were
historically needed for TorchScript.

One notable change here is that we do not call the
`LowerToBackendContractPass` anymore, which used to call
`TorchSimplificationPipeline` iteratively until VerifyBackendContract
was clean. Some of this was required for the shape/type refinement to
converge, which seems a non-issue for Dynamo frontend. Do we anticipate
this (the iterative invocation of TorchSimplificationPipeline followed
by VerifyBackendContract) to be worth retaining in the Dynamo frontend
pipeline? If so, I can make those changes, PLMK.
pull/3350/head
Sambhav Jain 2024-05-22 05:23:18 -07:00 committed by GitHub
parent 560ca24771
commit 6e485574e5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 25 additions and 14 deletions

1
.gitignore vendored
View File

@ -11,6 +11,7 @@ externals/pytorch/
libtorch* libtorch*
/build/ /build/
.build-cache/
/setup_build/ /setup_build/
__pycache__ __pycache__
*.pyc *.pyc

View File

@ -73,6 +73,11 @@ struct TorchLoweringPipelineOptions
void createTorchScriptModuleToTorchBackendPipeline( void createTorchScriptModuleToTorchBackendPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options); OpPassManager &pm, const TorchLoweringPipelineOptions &options);
/// Creates a pipeline that lowers the graph IR that is produced by
/// TorchDynamo export into the form expected by torch-verify-backend-contract.
void createTorchDynamoExportToTorchBackendPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options);
/// Creates a pipeline that lowers a flat list of funcs and global slots /// Creates a pipeline that lowers a flat list of funcs and global slots
/// with the torch and aten dialects and mutable arrays and converts it to /// with the torch and aten dialects and mutable arrays and converts it to
/// the form required by torch-verify-backend-contract. /// the form required by torch-verify-backend-contract.

View File

@ -17,6 +17,10 @@ void mlir::torch::registerTorchPasses() {
"torchscript-module-to-torch-backend-pipeline", "torchscript-module-to-torch-backend-pipeline",
"Pipeline lowering TorchScript object graph IR to Torch backend form.", "Pipeline lowering TorchScript object graph IR to Torch backend form.",
mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline); mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline);
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
"torchdynamo-export-to-torch-backend-pipeline",
"Pipeline lowering TorchDynamo exported graph IR to Torch backend form.",
mlir::torch::Torch::createTorchDynamoExportToTorchBackendPipeline);
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>( mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
"torch-function-to-torch-backend-pipeline", "torch-function-to-torch-backend-pipeline",
"Pipeline lowering a Torch function to Torch backend form.", "Pipeline lowering a Torch function to Torch backend form.",
@ -59,6 +63,18 @@ void mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline(
createTorchFunctionToTorchBackendPipeline(pm, options); createTorchFunctionToTorchBackendPipeline(pm, options);
} }
void mlir::torch::Torch::createTorchDynamoExportToTorchBackendPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
pm.addNestedPass<func::FuncOp>(
createReduceOpVariantsPass(options.extraLibrary));
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
if (options.decompose) {
pm.addNestedPass<func::FuncOp>(
Torch::createDecomposeComplexOpsPass(options.backendLegalOps));
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
}
}
void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline( void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
OpPassManager &pm, const TorchLoweringPipelineOptions &options) { OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
// Incorporate user annotations and remove signature Python-isms. // Incorporate user annotations and remove signature Python-isms.

View File

@ -27,7 +27,6 @@ def _module_lowering(
verbose, verbose,
output_type, output_type,
torch_mod, torch_mod,
backend_legal_ops=None,
extra_library_file_name=None, extra_library_file_name=None,
): ):
@ -35,23 +34,13 @@ def _module_lowering(
if verbose: if verbose:
print(torch_mod) print(torch_mod)
return torch_mod return torch_mod
# TODO: pass backend_legal_ops/extra_library_file_name by caller # TODO: pass extra_library_file_name by caller
if backend_legal_ops is None:
backend_legal_ops = []
if extra_library_file_name is None: if extra_library_file_name is None:
extra_library_file_name = "" extra_library_file_name = ""
option_string = ( option_string = "{extra-library=" + extra_library_file_name + "}"
"{backend-legal-ops="
+ ",".join(backend_legal_ops)
+ " extra-library="
+ extra_library_file_name
+ " shape-dtype-refine="
+ ("false" if not backend_legal_ops and not extra_library_file_name else "true")
+ "}"
)
run_pipeline_with_repro_report( run_pipeline_with_repro_report(
torch_mod, torch_mod,
f"builtin.module(torch-function-to-torch-backend-pipeline{option_string})", f"builtin.module(torchdynamo-export-to-torch-backend-pipeline{option_string})",
"Lowering TorchFX IR -> Torch Backend IR", "Lowering TorchFX IR -> Torch Backend IR",
enable_ir_printing=verbose, enable_ir_printing=verbose,
) )