diff --git a/.gitignore b/.gitignore index 00a5bc96f..7cc823a3f 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,7 @@ externals/pytorch/ libtorch* /build/ +.build-cache/ /setup_build/ __pycache__ *.pyc diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h index d4cceb05d..aef6baa5d 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h @@ -73,6 +73,11 @@ struct TorchLoweringPipelineOptions void createTorchScriptModuleToTorchBackendPipeline( OpPassManager &pm, const TorchLoweringPipelineOptions &options); +/// Creates a pipeline that lowers the graph IR that is produced by +/// TorchDynamo export into the form expected by torch-verify-backend-contract. +void createTorchDynamoExportToTorchBackendPipeline( + OpPassManager &pm, const TorchLoweringPipelineOptions &options); + /// Creates a pipeline that lowers a flat list of funcs and global slots /// with the torch and aten dialects and mutable arrays and converts it to /// the form required by torch-verify-backend-contract. diff --git a/lib/Dialect/Torch/Transforms/Passes.cpp b/lib/Dialect/Torch/Transforms/Passes.cpp index d01eac967..3ed8dc324 100644 --- a/lib/Dialect/Torch/Transforms/Passes.cpp +++ b/lib/Dialect/Torch/Transforms/Passes.cpp @@ -17,6 +17,10 @@ void mlir::torch::registerTorchPasses() { "torchscript-module-to-torch-backend-pipeline", "Pipeline lowering TorchScript object graph IR to Torch backend form.", mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline); + mlir::PassPipelineRegistration( + "torchdynamo-export-to-torch-backend-pipeline", + "Pipeline lowering TorchDynamo exported graph IR to Torch backend form.", + mlir::torch::Torch::createTorchDynamoExportToTorchBackendPipeline); mlir::PassPipelineRegistration( "torch-function-to-torch-backend-pipeline", "Pipeline lowering a Torch function to Torch backend form.", @@ -59,6 +63,18 @@ void mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline( createTorchFunctionToTorchBackendPipeline(pm, options); } +void mlir::torch::Torch::createTorchDynamoExportToTorchBackendPipeline( + OpPassManager &pm, const TorchLoweringPipelineOptions &options) { + pm.addNestedPass( + createReduceOpVariantsPass(options.extraLibrary)); + pm.addNestedPass(createCanonicalizerPass()); + if (options.decompose) { + pm.addNestedPass( + Torch::createDecomposeComplexOpsPass(options.backendLegalOps)); + pm.addNestedPass(createCanonicalizerPass()); + } +} + void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline( OpPassManager &pm, const TorchLoweringPipelineOptions &options) { // Incorporate user annotations and remove signature Python-isms. diff --git a/python/torch_mlir/fx.py b/python/torch_mlir/fx.py index 834cffd63..b8765b659 100644 --- a/python/torch_mlir/fx.py +++ b/python/torch_mlir/fx.py @@ -27,7 +27,6 @@ def _module_lowering( verbose, output_type, torch_mod, - backend_legal_ops=None, extra_library_file_name=None, ): @@ -35,23 +34,13 @@ def _module_lowering( if verbose: print(torch_mod) return torch_mod - # TODO: pass backend_legal_ops/extra_library_file_name by caller - if backend_legal_ops is None: - backend_legal_ops = [] + # TODO: pass extra_library_file_name by caller if extra_library_file_name is None: extra_library_file_name = "" - option_string = ( - "{backend-legal-ops=" - + ",".join(backend_legal_ops) - + " extra-library=" - + extra_library_file_name - + " shape-dtype-refine=" - + ("false" if not backend_legal_ops and not extra_library_file_name else "true") - + "}" - ) + option_string = "{extra-library=" + extra_library_file_name + "}" run_pipeline_with_repro_report( torch_mod, - f"builtin.module(torch-function-to-torch-backend-pipeline{option_string})", + f"builtin.module(torchdynamo-export-to-torch-backend-pipeline{option_string})", "Lowering TorchFX IR -> Torch Backend IR", enable_ir_printing=verbose, )