mirror of https://github.com/llvm/torch-mlir
torch_mlir.compile: Add OutputType.RAW
This can help with development and reporting bugs.pull/865/head snapshot-20220519.460
parent
10c8e3c593
commit
2af53ce434
|
@ -34,6 +34,11 @@ print(torch_mlir.compile(TanhModule(), placeholder))
|
||||||
# CHECK-LABEL: @forward
|
# CHECK-LABEL: @forward
|
||||||
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[?,2],f32> -> !torch.vtensor<[?,2],f32>
|
# CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[?,2],f32> -> !torch.vtensor<[?,2],f32>
|
||||||
|
|
||||||
|
# Basic smoke test for the raw output type.
|
||||||
|
print(torch_mlir.compile(TanhModule(), tanh_example_input, output_type=torch_mlir.OutputType.RAW))
|
||||||
|
# CHECK: torch.nn_module {
|
||||||
|
# CHECK: } : !torch.nn.Module<"__torch__.TanhModule">
|
||||||
|
|
||||||
class MmModule(torch.nn.Module):
|
class MmModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
|
@ -30,6 +30,9 @@ class OutputType(Enum):
|
||||||
# as taking the `TORCH` output type and lowering it so that tensor
|
# as taking the `TORCH` output type and lowering it so that tensor
|
||||||
# computations are done with `linalg`-on-tensors ops.
|
# computations are done with `linalg`-on-tensors ops.
|
||||||
LINALG_ON_TENSORS = 2
|
LINALG_ON_TENSORS = 2
|
||||||
|
# Raw output of the JIT IR importer. This is not expected to be useful
|
||||||
|
# for end-users, but can be convenient for development or reporting bugs.
|
||||||
|
RAW = 3
|
||||||
|
|
||||||
|
|
||||||
class TensorPlaceholder:
|
class TensorPlaceholder:
|
||||||
|
@ -140,21 +143,28 @@ def compile(model: torch.nn.Module,
|
||||||
mb = ModuleBuilder()
|
mb = ModuleBuilder()
|
||||||
mb.import_module(scripted._c, class_annotator)
|
mb.import_module(scripted._c, class_annotator)
|
||||||
|
|
||||||
|
if output_type == OutputType.RAW:
|
||||||
|
return mb.module
|
||||||
|
|
||||||
run_pipeline_with_repro_report(mb.module,
|
run_pipeline_with_repro_report(mb.module,
|
||||||
"torchscript-module-to-torch-backend-pipeline",
|
"torchscript-module-to-torch-backend-pipeline",
|
||||||
"Lowering TorchScript IR -> Torch Backend IR")
|
"Lowering TorchScript IR -> Torch Backend IR")
|
||||||
|
|
||||||
if output_type == OutputType.TORCH:
|
if output_type == OutputType.TORCH:
|
||||||
pass
|
return mb.module
|
||||||
elif output_type == OutputType.TOSA:
|
|
||||||
|
if output_type == OutputType.TOSA:
|
||||||
run_pipeline_with_repro_report(
|
run_pipeline_with_repro_report(
|
||||||
mb.module,
|
mb.module,
|
||||||
"torch-backend-to-tosa-backend-pipeline",
|
"torch-backend-to-tosa-backend-pipeline",
|
||||||
"Lowering Torch Backend IR -> TOSA Backend IR")
|
"Lowering Torch Backend IR -> TOSA Backend IR")
|
||||||
else:
|
return mb.module
|
||||||
assert output_type == OutputType.LINALG_ON_TENSORS
|
|
||||||
|
if output_type == OutputType.LINALG_ON_TENSORS:
|
||||||
run_pipeline_with_repro_report(
|
run_pipeline_with_repro_report(
|
||||||
mb.module,
|
mb.module,
|
||||||
"torch-backend-to-linalg-on-tensors-backend-pipeline",
|
"torch-backend-to-linalg-on-tensors-backend-pipeline",
|
||||||
"Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR")
|
"Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR")
|
||||||
return mb.module
|
return mb.module
|
||||||
|
|
||||||
|
raise Exception(f"Unknown OutputType: {output_type}")
|
||||||
|
|
Loading…
Reference in New Issue