torch_mlir.compile: Add OutputType.RAW

This can help with development and reporting bugs.
pull/865/head snapshot-20220519.460
Sean Silva 2022-05-19 10:28:29 +00:00
parent 10c8e3c593
commit 2af53ce434
2 changed files with 20 additions and 5 deletions

View File

@ -34,6 +34,11 @@ print(torch_mlir.compile(TanhModule(), placeholder))
# CHECK-LABEL: @forward
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[?,2],f32> -> !torch.vtensor<[?,2],f32>
# Basic smoke test for the raw output type.
print(torch_mlir.compile(TanhModule(), tanh_example_input, output_type=torch_mlir.OutputType.RAW))
# CHECK: torch.nn_module {
# CHECK: } : !torch.nn.Module<"__torch__.TanhModule">
class MmModule(torch.nn.Module):
def __init__(self):
super().__init__()

View File

@ -30,6 +30,9 @@ class OutputType(Enum):
# as taking the `TORCH` output type and lowering it so that tensor
# computations are done with `linalg`-on-tensors ops.
LINALG_ON_TENSORS = 2
# Raw output of the JIT IR importer. This is not expected to be useful
# for end-users, but can be convenient for development or reporting bugs.
RAW = 3
class TensorPlaceholder:
@ -140,21 +143,28 @@ def compile(model: torch.nn.Module,
mb = ModuleBuilder()
mb.import_module(scripted._c, class_annotator)
if output_type == OutputType.RAW:
return mb.module
run_pipeline_with_repro_report(mb.module,
"torchscript-module-to-torch-backend-pipeline",
"Lowering TorchScript IR -> Torch Backend IR")
if output_type == OutputType.TORCH:
pass
elif output_type == OutputType.TOSA:
return mb.module
if output_type == OutputType.TOSA:
run_pipeline_with_repro_report(
mb.module,
"torch-backend-to-tosa-backend-pipeline",
"Lowering Torch Backend IR -> TOSA Backend IR")
else:
assert output_type == OutputType.LINALG_ON_TENSORS
return mb.module
if output_type == OutputType.LINALG_ON_TENSORS:
run_pipeline_with_repro_report(
mb.module,
"torch-backend-to-linalg-on-tensors-backend-pipeline",
"Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR")
return mb.module
return mb.module
raise Exception(f"Unknown OutputType: {output_type}")