[NFC] Expose both raw Torch dialect and Torch dialect in backend form with Dynamo/FX (#3541)

This is a non-functional change. It merely allows intercepting the Torch
dialect during TorchDynamo export at two stages:
1. `OutputType.RAW`: This gets us the Torch dialect as imported from the
FX graph.
2. `OutputType.TORCH`: This gets us the Torch dialect after the raw
Torch dialect goes through DecomposeComplexOps and ReduceOpVariants.

Prior to this, there was no way to access the Torch dialect in
backend-compliant form (i.e., right after running the
`torchdynamo-export-to-torch-backend-pipeline`) because both
[here](https://sourcegraph.com/github.com/llvm/torch-mlir@5e4f00acb13f3f849a05e5ac28ee39307a5fdbff/-/blob/python/torch_mlir/fx.py?L33)
and
[here](https://sourcegraph.com/github.com/llvm/torch-mlir@5e4f00acb13f3f849a05e5ac28ee39307a5fdbff/-/blob/python/torch_mlir/compiler_utils.py?L138)
the same `OutputType.TORCH` was used, meaning the second condition could
never be reached.

Since the default behavior is unchanged, this is an NFC.
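
For illustration, a minimal usage sketch of the two interception points (the toy module and input shape are made up, and `OutputType` is assumed to be importable from `torch_mlir.compiler_utils`, where the diff below touches it):

```python
import torch
import torch.nn as nn

from torch_mlir import fx
from torch_mlir.compiler_utils import OutputType


# Hypothetical toy module, only for illustration.
class AddOne(nn.Module):
    def forward(self, x):
        return x + 1


x = torch.randn(4)

# Stage 1: raw Torch dialect, straight from the FX importer (the new default).
raw_mod = fx.export_and_import(AddOne(), x, output_type=OutputType.RAW)

# Stage 2: Torch dialect in backend form, i.e. after the raw Torch dialect has
# gone through the torchdynamo-export-to-torch-backend-pipeline
# (DecomposeComplexOps, ReduceOpVariants).
backend_mod = fx.export_and_import(AddOne(), x, output_type=OutputType.TORCH)
```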
pull/3543/head
Sambhav Jain 2024-07-14 10:33:47 -07:00 committed by GitHub
parent 5e4f00acb1
commit cdbcf519f7
2 changed files with 11 additions and 10 deletions

python/torch_mlir/compiler_utils.py

@@ -82,12 +82,12 @@ def run_pipeline_with_repro_report(
 class OutputType(Enum):
-    # Output torch dialect. When converting from FX, this will be immediately
-    # after the import from FX to MLIR. When converting from torchscript,
-    # this will come after some cleanup passes which attempt to de-alias,
-    # decompose and infer shapes. These should be roughly the same level of
-    # abstraction since those steps are done within PyTorch itself
-    # when coming directly from Dynamo/FX.
+    # Output torch dialect in backend form. When converting from TorchDynamo,
+    # this comes after some decomposition and reduce op variants passes are
+    # applied to the raw torch dialect. When converting from TorchScript, this
+    # comes after some cleanup passes which attempt to de-alias, decompose and infer shapes.
+    # These should be roughly the same level of abstraction since those
+    # steps are done within PyTorch itself when coming directly from Dynamo/FX.
     TORCH = "torch"
     # The output type contains a mix of `linalg`-on-tensors ops, `scf`, and
@@ -104,7 +104,8 @@ class OutputType(Enum):
     # as taking the `TORCH` output type and lowering it to StableHLO.
     STABLEHLO = "stablehlo"
-    # Raw output of the JIT IR importer. This is not expected to be useful
+    # Raw output of the JIT IR importer in the TorchScript frontend or that of
+    # the FX IR importer in the TorchDynamo frontend. This is not expected to be useful
     # for end-users, but can be convenient for development or reporting bugs.
     RAW = "raw"
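
For reference (not part of this patch), a raw module such as `raw_mod` from the sketch above could in principle be lowered to the same backend form by hand with `run_pipeline_with_repro_report` from this file; the pipeline name is taken from the commit message, while the description string and exact call shape are assumptions:

```python
from torch_mlir.compiler_utils import run_pipeline_with_repro_report

# Runs the pass pipeline in place on the raw-torch module (assumed usage).
run_pipeline_with_repro_report(
    raw_mod,
    "builtin.module(torchdynamo-export-to-torch-backend-pipeline)",
    "Lowering raw Torch dialect -> Torch Backend IR",
)
```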

python/torch_mlir/fx.py

@@ -30,7 +30,7 @@ def _module_lowering(
     extra_library_file_name=None,
 ):
-    if output_type == OutputType.TORCH:
+    if output_type == OutputType.RAW:
         if verbose:
             print(torch_mod)
         return torch_mod
@@ -50,7 +50,7 @@ def _module_lowering(
 def export_and_import(
     f: Union[nn.Module, ExportedProgram],
     *args,
-    output_type: Union[str, OutputType] = OutputType.TORCH,
+    output_type: Union[str, OutputType] = OutputType.RAW,
     fx_importer: Optional[FxImporter] = None,
     dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
     experimental_support_mutation: bool = False,
@@ -99,7 +99,7 @@ def export_and_import(
 def stateless_fx_import(
     gm: torch.fx.GraphModule,
-    output_type: Union[str, OutputType] = OutputType.TORCH,
+    output_type: Union[str, OutputType] = OutputType.RAW,
     fx_importer: Optional[FxImporter] = None,
     hooks: Optional[FxImporterHooks] = None,
     model_name: str = "main",