mirror of https://github.com/llvm/torch-mlir
[Pipeline] Use dedicated simplification pipeline for TorchDynamo frontend (#3376)
Discord Thread: https://discord.com/channels/636084430946959380/1238330633328005243 ## Context: [This](https://github.com/llvm/torch-mlir/blob/main/python/torch_mlir/fx.py#L61) was updated to support e2e tests for the TorchDynamo frontend in Torch-MLIR, where we run FX decompositions and import the FX IR to generate Torch dialect, followed by `torch-function-to-torch-backend-pipeline`, skipping only the shape/type refinement for now. However, we should be able to skip many of the torch simplification passes, as depicted in the [frontend roadmap](https://github.com/llvm/torch-mlir/blob/main/docs/images/roadmap_frontend.png). Based on IREE's TorchDynamo [pipeline](https://github.com/iree-org/iree/blob/main/compiler/plugins/input/Torch/InputConversion/Passes.cpp#L29), the only two passes we seem to require are: `ReduceOpVariantsPass` and `DecomposeComplexOpsPass`. This is inline with our findings as well based on initial exploration. This PR creates a dedicated frontend simplification pipeline for TorchDynamo / FX Importer which calls only `ReduceOpVariantsPass` and `DecomposeComplexOpsPass`. We rely on the e2e fx_importer tests to ensure we're not regressing by removing many of the passes that were historically needed for TorchScript. One notable change here is that we do not call the `LowerToBackendContractPass` anymore, which used to call `TorchSimplificationPipeline` iteratively until VerifyBackendContract was clean. Some of this was required for the shape/type refinement to converge, which seems a non-issue for Dynamo frontend. Do we anticipate this (the iterative invocation of TorchSimplificationPipeline followed by VerifyBackendContract) to be worth retaining in the Dynamo frontend pipeline? If so, I can make those changes, PLMK.pull/3350/head
parent
560ca24771
commit
6e485574e5
|
@ -11,6 +11,7 @@ externals/pytorch/
|
||||||
libtorch*
|
libtorch*
|
||||||
|
|
||||||
/build/
|
/build/
|
||||||
|
.build-cache/
|
||||||
/setup_build/
|
/setup_build/
|
||||||
__pycache__
|
__pycache__
|
||||||
*.pyc
|
*.pyc
|
||||||
|
|
|
@ -73,6 +73,11 @@ struct TorchLoweringPipelineOptions
|
||||||
void createTorchScriptModuleToTorchBackendPipeline(
|
void createTorchScriptModuleToTorchBackendPipeline(
|
||||||
OpPassManager &pm, const TorchLoweringPipelineOptions &options);
|
OpPassManager &pm, const TorchLoweringPipelineOptions &options);
|
||||||
|
|
||||||
|
/// Creates a pipeline that lowers the graph IR that is produced by
|
||||||
|
/// TorchDynamo export into the form expected by torch-verify-backend-contract.
|
||||||
|
void createTorchDynamoExportToTorchBackendPipeline(
|
||||||
|
OpPassManager &pm, const TorchLoweringPipelineOptions &options);
|
||||||
|
|
||||||
/// Creates a pipeline that lowers a flat list of funcs and global slots
|
/// Creates a pipeline that lowers a flat list of funcs and global slots
|
||||||
/// with the torch and aten dialects and mutable arrays and converts it to
|
/// with the torch and aten dialects and mutable arrays and converts it to
|
||||||
/// the form required by torch-verify-backend-contract.
|
/// the form required by torch-verify-backend-contract.
|
||||||
|
|
|
@ -17,6 +17,10 @@ void mlir::torch::registerTorchPasses() {
|
||||||
"torchscript-module-to-torch-backend-pipeline",
|
"torchscript-module-to-torch-backend-pipeline",
|
||||||
"Pipeline lowering TorchScript object graph IR to Torch backend form.",
|
"Pipeline lowering TorchScript object graph IR to Torch backend form.",
|
||||||
mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline);
|
mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline);
|
||||||
|
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
||||||
|
"torchdynamo-export-to-torch-backend-pipeline",
|
||||||
|
"Pipeline lowering TorchDynamo exported graph IR to Torch backend form.",
|
||||||
|
mlir::torch::Torch::createTorchDynamoExportToTorchBackendPipeline);
|
||||||
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
mlir::PassPipelineRegistration<Torch::TorchLoweringPipelineOptions>(
|
||||||
"torch-function-to-torch-backend-pipeline",
|
"torch-function-to-torch-backend-pipeline",
|
||||||
"Pipeline lowering a Torch function to Torch backend form.",
|
"Pipeline lowering a Torch function to Torch backend form.",
|
||||||
|
@ -59,6 +63,18 @@ void mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline(
|
||||||
createTorchFunctionToTorchBackendPipeline(pm, options);
|
createTorchFunctionToTorchBackendPipeline(pm, options);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void mlir::torch::Torch::createTorchDynamoExportToTorchBackendPipeline(
|
||||||
|
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
|
||||||
|
pm.addNestedPass<func::FuncOp>(
|
||||||
|
createReduceOpVariantsPass(options.extraLibrary));
|
||||||
|
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||||
|
if (options.decompose) {
|
||||||
|
pm.addNestedPass<func::FuncOp>(
|
||||||
|
Torch::createDecomposeComplexOpsPass(options.backendLegalOps));
|
||||||
|
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
|
void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
|
||||||
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
|
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
|
||||||
// Incorporate user annotations and remove signature Python-isms.
|
// Incorporate user annotations and remove signature Python-isms.
|
||||||
|
|
|
@ -27,7 +27,6 @@ def _module_lowering(
|
||||||
verbose,
|
verbose,
|
||||||
output_type,
|
output_type,
|
||||||
torch_mod,
|
torch_mod,
|
||||||
backend_legal_ops=None,
|
|
||||||
extra_library_file_name=None,
|
extra_library_file_name=None,
|
||||||
):
|
):
|
||||||
|
|
||||||
|
@ -35,23 +34,13 @@ def _module_lowering(
|
||||||
if verbose:
|
if verbose:
|
||||||
print(torch_mod)
|
print(torch_mod)
|
||||||
return torch_mod
|
return torch_mod
|
||||||
# TODO: pass backend_legal_ops/extra_library_file_name by caller
|
# TODO: pass extra_library_file_name by caller
|
||||||
if backend_legal_ops is None:
|
|
||||||
backend_legal_ops = []
|
|
||||||
if extra_library_file_name is None:
|
if extra_library_file_name is None:
|
||||||
extra_library_file_name = ""
|
extra_library_file_name = ""
|
||||||
option_string = (
|
option_string = "{extra-library=" + extra_library_file_name + "}"
|
||||||
"{backend-legal-ops="
|
|
||||||
+ ",".join(backend_legal_ops)
|
|
||||||
+ " extra-library="
|
|
||||||
+ extra_library_file_name
|
|
||||||
+ " shape-dtype-refine="
|
|
||||||
+ ("false" if not backend_legal_ops and not extra_library_file_name else "true")
|
|
||||||
+ "}"
|
|
||||||
)
|
|
||||||
run_pipeline_with_repro_report(
|
run_pipeline_with_repro_report(
|
||||||
torch_mod,
|
torch_mod,
|
||||||
f"builtin.module(torch-function-to-torch-backend-pipeline{option_string})",
|
f"builtin.module(torchdynamo-export-to-torch-backend-pipeline{option_string})",
|
||||||
"Lowering TorchFX IR -> Torch Backend IR",
|
"Lowering TorchFX IR -> Torch Backend IR",
|
||||||
enable_ir_printing=verbose,
|
enable_ir_printing=verbose,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue