mirror of https://github.com/llvm/torch-mlir
torch_mlir.compile: Add OutputType.RAW
This can help with development and reporting bugs.pull/865/head snapshot-20220519.460
parent
10c8e3c593
commit
2af53ce434
|
@ -34,6 +34,11 @@ print(torch_mlir.compile(TanhModule(), placeholder))
|
|||
# CHECK-LABEL: @forward
|
||||
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[?,2],f32> -> !torch.vtensor<[?,2],f32>
|
||||
|
||||
# Basic smoke test for the raw output type.
|
||||
print(torch_mlir.compile(TanhModule(), tanh_example_input, output_type=torch_mlir.OutputType.RAW))
|
||||
# CHECK: torch.nn_module {
|
||||
# CHECK: } : !torch.nn.Module<"__torch__.TanhModule">
|
||||
|
||||
class MmModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
|
@ -30,6 +30,9 @@ class OutputType(Enum):
|
|||
# as taking the `TORCH` output type and lowering it so that tensor
|
||||
# computations are done with `linalg`-on-tensors ops.
|
||||
LINALG_ON_TENSORS = 2
|
||||
# Raw output of the JIT IR importer. This is not expected to be useful
|
||||
# for end-users, but can be convenient for development or reporting bugs.
|
||||
RAW = 3
|
||||
|
||||
|
||||
class TensorPlaceholder:
|
||||
|
@ -140,21 +143,28 @@ def compile(model: torch.nn.Module,
|
|||
mb = ModuleBuilder()
|
||||
mb.import_module(scripted._c, class_annotator)
|
||||
|
||||
if output_type == OutputType.RAW:
|
||||
return mb.module
|
||||
|
||||
run_pipeline_with_repro_report(mb.module,
|
||||
"torchscript-module-to-torch-backend-pipeline",
|
||||
"Lowering TorchScript IR -> Torch Backend IR")
|
||||
|
||||
if output_type == OutputType.TORCH:
|
||||
pass
|
||||
elif output_type == OutputType.TOSA:
|
||||
return mb.module
|
||||
|
||||
if output_type == OutputType.TOSA:
|
||||
run_pipeline_with_repro_report(
|
||||
mb.module,
|
||||
"torch-backend-to-tosa-backend-pipeline",
|
||||
"Lowering Torch Backend IR -> TOSA Backend IR")
|
||||
else:
|
||||
assert output_type == OutputType.LINALG_ON_TENSORS
|
||||
return mb.module
|
||||
|
||||
if output_type == OutputType.LINALG_ON_TENSORS:
|
||||
run_pipeline_with_repro_report(
|
||||
mb.module,
|
||||
"torch-backend-to-linalg-on-tensors-backend-pipeline",
|
||||
"Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR")
|
||||
return mb.module
|
||||
|
||||
raise Exception(f"Unknown OutputType: {output_type}")
|
||||
|
|
Loading…
Reference in New Issue