diff --git a/python/torch_mlir/compiler_utils.py b/python/torch_mlir/compiler_utils.py index c1315abd4..cb2799f85 100644 --- a/python/torch_mlir/compiler_utils.py +++ b/python/torch_mlir/compiler_utils.py @@ -82,12 +82,12 @@ def run_pipeline_with_repro_report( class OutputType(Enum): - # Output torch dialect. When converting from FX, this will be immediately - # after the import from FX to MLIR. When converting from torchscript, - # this will come after some cleanup passes which attempt to de-alias, - # decompose and infer shapes. These should be roughly the same level of - # abstraction since those steps are done within PyTorch itself - # when coming directly from Dynamo/FX. + # Output torch dialect in backend form. When converting from TorchDynamo, + # this comes after some decomposition and reduce op variants passes are + # applied to the raw torch dialect. When converting from TorchScript, this + # comes after some cleanup passes which attempt to de-alias, decompose and infer shapes. + # These should be roughly the same level of abstraction since those + # steps are done within PyTorch itself when coming directly from Dynamo/FX. TORCH = "torch" # The output type contains a mix of `linalg`-on-tensors ops, `scf`, and @@ -104,7 +104,8 @@ class OutputType(Enum): # as taking the `TORCH` output type and lowering it to StableHLO. STABLEHLO = "stablehlo" - # Raw output of the JIT IR importer. This is not expected to be useful + # Raw output of the JIT IR importer in the TorchScript frontend or that of + # the FX IR importer in the TorchDynamo frontend. This is not expected to be useful # for end-users, but can be convenient for development or reporting bugs. RAW = "raw" diff --git a/python/torch_mlir/fx.py b/python/torch_mlir/fx.py index 5cd7d2d6e..0d9ad77d2 100644 --- a/python/torch_mlir/fx.py +++ b/python/torch_mlir/fx.py @@ -30,7 +30,7 @@ def _module_lowering( extra_library_file_name=None, ): - if output_type == OutputType.TORCH: + if output_type == OutputType.RAW: if verbose: print(torch_mod) return torch_mod @@ -50,7 +50,7 @@ def _module_lowering( def export_and_import( f: Union[nn.Module, ExportedProgram], *args, - output_type: Union[str, OutputType] = OutputType.TORCH, + output_type: Union[str, OutputType] = OutputType.RAW, fx_importer: Optional[FxImporter] = None, dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, experimental_support_mutation: bool = False, @@ -99,7 +99,7 @@ def export_and_import( def stateless_fx_import( gm: torch.fx.GraphModule, - output_type: Union[str, OutputType] = OutputType.TORCH, + output_type: Union[str, OutputType] = OutputType.RAW, fx_importer: Optional[FxImporter] = None, hooks: Optional[FxImporterHooks] = None, model_name: str = "main",