mirror of https://github.com/llvm/torch-mlir
[NFC] Expose both raw Torch dialect and Torch dialect in backend form with Dynamo/FX (#3541)
This is a non-functional change. It merely allows intercepting the Torch dialect during TorchDynamo export at two stages: 1. `OutputType.RAW`: This gets us the torch dialect as-imported from the FX graph 2. `OutputType.TORCH`: This gets us the torch dialect after the raw torch goes through DecomposeComplexOps and ReduceOpVariants. Prior to this, there was no way of accessing the Torch dialect in backend compliant form (right after running the `torchdynamo-export-to-torch-backend-pipeline`) because both [here](https://sourcegraph.com/github.com/llvm/torch-mlir@5e4f00acb13f3f849a05e5ac28ee39307a5fdbff/-/blob/python/torch_mlir/fx.py?L33) and [here](https://sourcegraph.com/github.com/llvm/torch-mlir@5e4f00acb13f3f849a05e5ac28ee39307a5fdbff/-/blob/python/torch_mlir/compiler_utils.py?L138) the same `OutputType.TORCH` were used, meaning the 2nd condition would never be reached. Since the default behavior is unchanged, this is an NFC.pull/3543/head
parent
5e4f00acb1
commit
cdbcf519f7
|
@ -82,12 +82,12 @@ def run_pipeline_with_repro_report(
|
||||||
|
|
||||||
class OutputType(Enum):
|
class OutputType(Enum):
|
||||||
|
|
||||||
# Output torch dialect. When converting from FX, this will be immediately
|
# Output torch dialect in backend form. When converting from TorchDynamo,
|
||||||
# after the import from FX to MLIR. When converting from torchscript,
|
# this comes after some decomposition and reduce op variants passes are
|
||||||
# this will come after some cleanup passes which attempt to de-alias,
|
# applied to the raw torch dialect. When converting from TorchScript, this
|
||||||
# decompose and infer shapes. These should be roughly the same level of
|
# comes after some cleanup passes which attempt to de-alias, decompose and infer shapes.
|
||||||
# abstraction since those steps are done within PyTorch itself
|
# These should be roughly the same level of abstraction since those
|
||||||
# when coming directly from Dynamo/FX.
|
# steps are done within PyTorch itself when coming directly from Dynamo/FX.
|
||||||
TORCH = "torch"
|
TORCH = "torch"
|
||||||
|
|
||||||
# The output type contains a mix of `linalg`-on-tensors ops, `scf`, and
|
# The output type contains a mix of `linalg`-on-tensors ops, `scf`, and
|
||||||
|
@ -104,7 +104,8 @@ class OutputType(Enum):
|
||||||
# as taking the `TORCH` output type and lowering it to StableHLO.
|
# as taking the `TORCH` output type and lowering it to StableHLO.
|
||||||
STABLEHLO = "stablehlo"
|
STABLEHLO = "stablehlo"
|
||||||
|
|
||||||
# Raw output of the JIT IR importer. This is not expected to be useful
|
# Raw output of the JIT IR importer in the TorchScript frontend or that of
|
||||||
|
# the FX IR importer in the TorchDynamo frontend. This is not expected to be useful
|
||||||
# for end-users, but can be convenient for development or reporting bugs.
|
# for end-users, but can be convenient for development or reporting bugs.
|
||||||
RAW = "raw"
|
RAW = "raw"
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,7 @@ def _module_lowering(
|
||||||
extra_library_file_name=None,
|
extra_library_file_name=None,
|
||||||
):
|
):
|
||||||
|
|
||||||
if output_type == OutputType.TORCH:
|
if output_type == OutputType.RAW:
|
||||||
if verbose:
|
if verbose:
|
||||||
print(torch_mod)
|
print(torch_mod)
|
||||||
return torch_mod
|
return torch_mod
|
||||||
|
@ -50,7 +50,7 @@ def _module_lowering(
|
||||||
def export_and_import(
|
def export_and_import(
|
||||||
f: Union[nn.Module, ExportedProgram],
|
f: Union[nn.Module, ExportedProgram],
|
||||||
*args,
|
*args,
|
||||||
output_type: Union[str, OutputType] = OutputType.TORCH,
|
output_type: Union[str, OutputType] = OutputType.RAW,
|
||||||
fx_importer: Optional[FxImporter] = None,
|
fx_importer: Optional[FxImporter] = None,
|
||||||
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
|
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
|
||||||
experimental_support_mutation: bool = False,
|
experimental_support_mutation: bool = False,
|
||||||
|
@ -99,7 +99,7 @@ def export_and_import(
|
||||||
|
|
||||||
def stateless_fx_import(
|
def stateless_fx_import(
|
||||||
gm: torch.fx.GraphModule,
|
gm: torch.fx.GraphModule,
|
||||||
output_type: Union[str, OutputType] = OutputType.TORCH,
|
output_type: Union[str, OutputType] = OutputType.RAW,
|
||||||
fx_importer: Optional[FxImporter] = None,
|
fx_importer: Optional[FxImporter] = None,
|
||||||
hooks: Optional[FxImporterHooks] = None,
|
hooks: Optional[FxImporterHooks] = None,
|
||||||
model_name: str = "main",
|
model_name: str = "main",
|
||||||
|
|
Loading…
Reference in New Issue