Add a new `torch_mlir.compile` method.

This makes it much easier to convert models and hides all the
ClassAnnotator complexity.

This also adds a new example `torchscript_resnet18_all_output_types.py`
which shows the ResNet18 IR for all output types.

Also:

- This moves `run_pipeline_with_repro_report` to
  `torch_mlir.compiler_utils`.
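
For example, a minimal usage sketch (adapted from the examples added in this
commit):

    import torch
    import torchvision
    import torch_mlir

    resnet18 = torchvision.models.resnet18(pretrained=True)
    resnet18.eval()
    module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224),
                                output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)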
pull/772/head
Sean Silva 2022-04-20 00:30:09 +00:00
parent 578d0ec292
commit 075464fa74
12 changed files with 260 additions and 241 deletions


@@ -3,20 +3,19 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 # Also available under a BSD-style license. See LICENSE.
 
+import sys
+
 from PIL import Image
 import requests
 import torch
 import torchvision.models as models
 from torchvision import transforms
 
-from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder
-from torch_mlir.passmanager import PassManager
+import torch_mlir
 from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
 
-mb = ModuleBuilder()
-
 def load_and_preprocess_image(url: str):
     headers = {
         'User-Agent':
@@ -60,59 +59,17 @@ def predictions(torch_func, jit_func, img, labels):
     print("torch-mlir prediction")
     print(prediction)
 
-class ResNet18Module(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.resnet = models.resnet18(pretrained=True)
-        self.train(False)
-
-    def forward(self, img):
-        return self.resnet.forward(img)
-
-class TestModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.s = ResNet18Module()
-
-    def forward(self, x):
-        return self.s.forward(x)
-
-image_url = (
-    "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
-)
-import sys
+image_url = "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
 print("load image from " + image_url, file=sys.stderr)
 img = load_and_preprocess_image(image_url)
 labels = load_labels()
 
-test_module = TestModule()
-class_annotator = ClassAnnotator()
-recursivescriptmodule = torch.jit.script(test_module)
-torch.jit.save(recursivescriptmodule, "/tmp/foo.pt")
-class_annotator.exportNone(recursivescriptmodule._c._type())
-class_annotator.exportPath(recursivescriptmodule._c._type(), ["forward"])
-class_annotator.annotateArgs(
-    recursivescriptmodule._c._type(),
-    ["forward"],
-    [
-        None,
-        ([-1, -1, -1, -1], torch.float32, True),
-    ],
-)
-# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
-mb.import_module(recursivescriptmodule._c, class_annotator)
+resnet18 = models.resnet18(pretrained=True)
+resnet18.train(False)
+module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)
 
 backend = refbackend.RefBackendLinalgOnTensorsBackend()
-with mb.module.context:
-    pm = PassManager.parse('torchscript-module-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline')
-    pm.run(mb.module)
-
-compiled = backend.compile(mb.module)
+compiled = backend.compile(module)
 jit_module = backend.load(compiled)
 
-predictions(test_module.forward, jit_module.forward, img, labels)
+predictions(resnet18.forward, jit_module.forward, img, labels)


@@ -0,0 +1,20 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+import torch
+import torchvision
+
+import torch_mlir
+
+resnet18 = torchvision.models.resnet18(pretrained=True)
+resnet18.eval()
+
+module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type=torch_mlir.OutputType.TORCH)
+print("TORCH OutputType\n", module.operation.get_asm(large_elements_limit=10))
+module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)
+print("LINALG_ON_TENSORS OutputType\n", module.operation.get_asm(large_elements_limit=10))
+# TODO: Debug why this is so slow.
+module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type=torch_mlir.OutputType.TOSA)
+print("TOSA OutputType\n", module.operation.get_asm(large_elements_limit=10))

File diff suppressed because one or more lines are too long


@@ -20,6 +20,14 @@ add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=torch_mlir.")
 declare_mlir_python_sources(TorchMLIRPythonSources)
 declare_mlir_python_sources(TorchMLIRPythonExtensions)
 
+declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel
+  ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
+  ADD_TO_PARENT TorchMLIRPythonSources
+  SOURCES
+    __init__.py
+    compiler_utils.py
+)
+
 declare_mlir_python_sources(TorchMLIRPythonSources.Dialects
   ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
   ADD_TO_PARENT TorchMLIRPythonSources


@@ -0,0 +1,94 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+from typing import List
+from enum import Enum
+
+import torch
+
+from torch_mlir.passmanager import PassManager
+from .compiler_utils import run_pipeline_with_repro_report
+from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder
+
+
+class OutputType(Enum):
+    """The kind of output that `torch_mlir.compile` can produce.
+
+    In MLIR terminology, this describes the mix of dialects that will be
+    produced by the conversion process.
+    """
+    # This output type consists of `torch` dialect ops that have been converted
+    # maximally to value semantics, decomposed, and shapes have been inferred.
+    TORCH = 0
+    # This output type consists of `tosa` dialect ops. It can be thought of
+    # as taking the `TORCH` output type and lowering it to TOSA.
+    TOSA = 1
+    # The output type contains a mix of `linalg`-on-tensors ops, `scf`, and
+    # `arith` ops (and also `math` and `tm_tensor`). It can be thought of
+    # as taking the `TORCH` output type and lowering it so that tensor
+    # computations are done with `linalg`-on-tensors ops.
+    LINALG_ON_TENSORS = 2
+
+
+def compile(model: torch.nn.Module,
+            example_args: List[torch.Tensor],
+            output_type: OutputType = OutputType.TORCH):
+    """Convert a PyTorch model to MLIR.
+
+    Args:
+        model: The PyTorch model to convert.
+        example_args: A list of example arguments to use when inferring the
+            shapes of the arguments to the `forward` method of the model.
+            A single tensor is treated as a list of a single tensor.
+        output_type: The kind of output to produce. See `OutputType` for more
+            details.
+
+    Returns:
+        An MLIR module that contains the converted model in the specified
+        output type.
+    """
+    # TODO: Don't hardcode "forward". See `torch.onnx.export` and
+    # `torch.jit.trace_module` for API inspiration.
+    # TODO: Support dynamic dimension sizes. See `torch.onnx.export`'s
+    # `dynamic_axes` for API inspiration, or do something more ergonomic
+    # like a tensor wrapper possibly.
+    # TODO: Support tracing the model instead of scripting it.
+    scripted = torch.jit.script(model)
+    if isinstance(example_args, torch.Tensor):
+        example_args = [example_args]
+
+    class_annotator = ClassAnnotator()
+    forward_annotation = [None]
+    for arg in example_args:
+        # Assume that all tensors have value semantics for now.
+        forward_annotation.append((list(arg.shape), arg.dtype, True))
+    class_annotator.exportNone(scripted._c._type())
+    class_annotator.exportPath(scripted._c._type(), ["forward"])
+    class_annotator.annotateArgs(
+        scripted._c._type(), ["forward"], forward_annotation)
+
+    mb = ModuleBuilder()
+    mb.import_module(scripted._c, class_annotator)
+
+    run_pipeline_with_repro_report(mb.module,
+                                   "torchscript-module-to-torch-backend-pipeline",
+                                   "Lowering TorchScript IR -> Torch Backend IR")
+
+    if output_type == OutputType.TORCH:
+        pass
+    elif output_type == OutputType.TOSA:
+        run_pipeline_with_repro_report(
+            mb.module,
+            "torch-backend-to-tosa-backend-pipeline",
+            "Lowering Torch Backend IR -> TOSA Backend IR")
+    else:
+        assert output_type == OutputType.LINALG_ON_TENSORS
+        run_pipeline_with_repro_report(
+            mb.module,
+            "torch-backend-to-linalg-on-tensors-backend-pipeline",
+            "Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR")
+    return mb.module
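
As the docstring above notes, `example_args` may also be a list of tensors, one
per `forward` argument. A minimal sketch of that case (the `TwoInputModule`
module here is hypothetical, not part of this commit):

    import torch
    import torch_mlir

    class TwoInputModule(torch.nn.Module):
        def forward(self, x, y):
            # Shape/dtype annotations for `x` and `y` are derived from
            # the corresponding entries in `example_args`.
            return x + y

    # output_type defaults to torch_mlir.OutputType.TORCH.
    module = torch_mlir.compile(TwoInputModule(),
                                [torch.ones(2, 3), torch.ones(2, 3)])
    print(module)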


@@ -0,0 +1,58 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+from io import StringIO
+import os
+import sys
+import tempfile
+
+from torch_mlir.passmanager import PassManager
+from torch_mlir.ir import StringAttr
+
+
+def get_module_name_for_debug_dump(module):
+    """Gets a name suitable for a debug dump.
+
+    The name is not guaranteed to be unique.
+    """
+    if not "torch.debug_module_name" in module.operation.attributes:
+        return "UnnammedModule"
+    return StringAttr(module.operation.attributes["torch.debug_module_name"]).value
+
+
+def run_pipeline_with_repro_report(module,
+                                   pipeline: str,
+                                   description: str):
+    """Runs `pipeline` on `module`, with a nice repro report if it fails."""
+    module_name = get_module_name_for_debug_dump(module)
+    try:
+        original_stderr = sys.stderr
+        sys.stderr = StringIO()
+        asm_for_error_report = module.operation.get_asm(
+            large_elements_limit=10, enable_debug_info=True)
+        # Lower module in place to make it ready for compiler backends.
+        with module.context:
+            pm = PassManager.parse(pipeline)
+            pm.run(module)
+    except Exception as e:
+        # TODO: More robust.
+        #  - don't arbitrarily clutter up /tmp. When a test suite has many
+        #    tests, this can be a big disk cost (also, /tmp/ is frequently a
+        #    RAM fs, which increases worries about capacity).
+        #  - don't have colliding filenames (hard to do without cluttering
+        #    up /tmp)
+        #  - if we do have colliding filenames, writes should at least
+        #    avoid being racy.
+        filename = os.path.join(tempfile.gettempdir(), module_name + ".mlir")
+        with open(filename, 'w') as f:
+            f.write(asm_for_error_report)
+        debug_options = "-print-ir-after-all -mlir-disable-threading"
+        raise Exception(f"""
+{description} failed with the following diagnostics:
+{sys.stderr.getvalue()}
+Error can be reproduced with:
+$ torch-mlir-opt -pass-pipeline='{pipeline}' {filename}
+Add '{debug_options}' to get the IR dump for debugging purpose.
+""") from None
+    finally:
+        sys.stderr = original_stderr
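
For reference, `torch_mlir.compile` invokes this helper as follows; the same
pattern applies to any registered pass pipeline (a usage sketch, assuming
`module` is a `torch`-dialect module imported via `ModuleBuilder`):

    from torch_mlir.compiler_utils import run_pipeline_with_repro_report

    run_pipeline_with_repro_report(
        module,
        "torch-backend-to-tosa-backend-pipeline",
        "Lowering Torch Backend IR -> TOSA Backend IR")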


@@ -13,8 +13,7 @@ from torch_mlir.runtime import *
 # Imported for side effects.
 import torch_mlir.all_passes_registration
 import torch_mlir.dialects.torch
 
-from torch_mlir_e2e_test.utils import run_pipeline_with_repro_report
+from torch_mlir.compiler_utils import run_pipeline_with_repro_report
 
 from .abc import LinalgOnTensorsBackend


@@ -14,7 +14,8 @@ import torch
 from torch_mlir_e2e_test.linalg_on_tensors_backends.abc import LinalgOnTensorsBackend
 from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem
-from torch_mlir_e2e_test.utils import run_pipeline_with_repro_report
+from torch_mlir.compiler_utils import run_pipeline_with_repro_report
+
 from .utils import (
     recursively_convert_to_numpy,
     recursively_convert_from_numpy,


@@ -14,7 +14,7 @@ import torch
 from torch_mlir_e2e_test.tosa_backends.abc import TosaBackend
 from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem
-from torch_mlir_e2e_test.utils import run_pipeline_with_repro_report
+from torch_mlir.compiler_utils import run_pipeline_with_repro_report
 from .utils import (
     recursively_convert_to_numpy,
     recursively_convert_from_numpy,


@@ -12,7 +12,8 @@ import torch
 from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder
 from torch_mlir.dialects.torch.importer.jit_ir.torchscript_annotations import extract_annotations
-from torch_mlir_e2e_test.utils import run_pipeline_with_repro_report
+from torch_mlir.compiler_utils import run_pipeline_with_repro_report
+
 def recursively_convert_to_numpy(o: Any):
     if isinstance(o, torch.Tensor):


@@ -7,8 +7,8 @@ from torch_mlir.ir import *
 from torch_mlir.passmanager import *
 # Imported for side effects.
 import torch_mlir.all_passes_registration
+from torch_mlir.compiler_utils import run_pipeline_with_repro_report
-from torch_mlir_e2e_test.utils import run_pipeline_with_repro_report
 from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
 from .abc import TosaBackend


@@ -1,58 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-# Also available under a BSD-style license. See LICENSE.
-
-from io import StringIO
-import os
-import sys
-import tempfile
-
-from torch_mlir.passmanager import PassManager
-from torch_mlir.ir import StringAttr
-
-
-def get_module_name_for_debug_dump(module):
-    """Gets a name suitable for a debug dump.
-
-    The name is not guaranteed to be unique.
-    """
-    if not "torch.debug_module_name" in module.operation.attributes:
-        return "UnnammedModule"
-    return StringAttr(module.operation.attributes["torch.debug_module_name"]).value
-
-
-def run_pipeline_with_repro_report(module,
-                                   pipeline: str,
-                                   description: str):
-    """Runs `pipeline` on `module`, with a nice repro report if it fails."""
-    module_name = get_module_name_for_debug_dump(module)
-    try:
-        original_stderr = sys.stderr
-        sys.stderr = StringIO()
-        asm_for_error_report = module.operation.get_asm(
-            large_elements_limit=10, enable_debug_info=True)
-        # Lower module in place to make it ready for compiler backends.
-        with module.context:
-            pm = PassManager.parse(pipeline)
-            pm.run(module)
-    except Exception as e:
-        # TODO: More robust.
-        #  - don't arbitrarily clutter up /tmp. When a test suite has many
-        #    tests, this can be a big disk cost (also, /tmp/ is frequently a
-        #    RAM fs, which increases worries about capacity).
-        #  - don't have colliding filenames (hard to do without cluttering
-        #    up /tmp)
-        #  - if we do have have colliding filenames, writes should at least
-        #    avoid being racy.
-        filename = os.path.join(tempfile.gettempdir(), module_name + ".mlir")
-        with open(filename, 'w') as f:
-            f.write(asm_for_error_report)
-        debug_options = "-print-ir-after-all -mlir-disable-threading"
-        raise Exception(f"""
-{description} failed with the following diagnostics:
-{sys.stderr.getvalue()}
-Error can be reproduced with:
-$ torch-mlir-opt -pass-pipeline='{pipeline}' {filename}
-Add '{debug_options}' to get the IR dump for debugging purpose.
-""") from None
-    finally:
-        sys.stderr = original_stderr