diff --git a/python/test/compile_api/basic.py b/python/test/compile_api/basic.py
index 5e9d14159..af46f69c5 100644
--- a/python/test/compile_api/basic.py
+++ b/python/test/compile_api/basic.py
@@ -34,6 +34,11 @@ print(torch_mlir.compile(TanhModule(), placeholder))
 # CHECK-LABEL: @forward
 # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[?,2],f32> -> !torch.vtensor<[?,2],f32>
 
+# Basic smoke test for the raw output type.
+print(torch_mlir.compile(TanhModule(), tanh_example_input, output_type=torch_mlir.OutputType.RAW))
+# CHECK: torch.nn_module {
+# CHECK: } : !torch.nn.Module<"__torch__.TanhModule">
+
 class MmModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/python/torch_mlir/__init__.py b/python/torch_mlir/__init__.py
index b7410e1f0..740005f25 100644
--- a/python/torch_mlir/__init__.py
+++ b/python/torch_mlir/__init__.py
@@ -30,6 +30,9 @@ class OutputType(Enum):
     # as taking the `TORCH` output type and lowering it so that tensor
     # computations are done with `linalg`-on-tensors ops.
     LINALG_ON_TENSORS = 2
+    # Raw output of the JIT IR importer. This is not expected to be useful
+    # for end-users, but can be convenient for development or reporting bugs.
+    RAW = 3
 
 
 class TensorPlaceholder:
@@ -140,21 +143,28 @@ def compile(model: torch.nn.Module,
 
     mb = ModuleBuilder()
     mb.import_module(scripted._c, class_annotator)
 
+    if output_type == OutputType.RAW:
+        return mb.module
+
     run_pipeline_with_repro_report(mb.module,
                                    "torchscript-module-to-torch-backend-pipeline",
                                    "Lowering TorchScript IR -> Torch Backend IR")
     if output_type == OutputType.TORCH:
-        pass
-    elif output_type == OutputType.TOSA:
+        return mb.module
+
+    if output_type == OutputType.TOSA:
         run_pipeline_with_repro_report(
             mb.module,
             "torch-backend-to-tosa-backend-pipeline",
             "Lowering Torch Backend IR -> TOSA Backend IR")
-    else:
-        assert output_type == OutputType.LINALG_ON_TENSORS
+        return mb.module
+
+    if output_type == OutputType.LINALG_ON_TENSORS:
         run_pipeline_with_repro_report(
             mb.module,
             "torch-backend-to-linalg-on-tensors-backend-pipeline",
             "Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR")
-    return mb.module
+        return mb.module
+
+    raise Exception(f"Unknown OutputType: {output_type}")