mirror of https://github.com/llvm/torch-mlir
Add a new `torch_mlir.compile` method.
This makes it much easier to convert models and hides all the ClassAnnotator complexity. It also adds a new example, `torchscript_resnet18_all_output_types.py`, which shows the ResNet18 IR for all output types, and moves `run_pipeline_with_repro_report` to `torch_mlir.compiler_utils`.
parent 578d0ec292
commit 075464fa74
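For orientation, here is a minimal sketch of how the new API is meant to be used. It uses only names defined in the hunks below (`torch_mlir.compile`, `OutputType`, and the refbackend from the updated example); it is an illustrative usage sketch, not part of the diff itself:

import torch
import torchvision.models as models

import torch_mlir
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend

# Compile a pretrained ResNet18 down to linalg-on-tensors IR. The example
# tensor fixes the static input shape used during shape inference.
resnet18 = models.resnet18(pretrained=True)
resnet18.train(False)
module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224),
                            output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)

# Run it through the reference backend, as the updated example does.
backend = refbackend.RefBackendLinalgOnTensorsBackend()
jit_module = backend.load(backend.compile(module))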
@@ -3,20 +3,19 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 # Also available under a BSD-style license. See LICENSE.
 
+import sys
+
 from PIL import Image
 import requests
 
 import torch
 import torchvision.models as models
 from torchvision import transforms
 
-from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder
-
-from torch_mlir.passmanager import PassManager
+import torch_mlir
 from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
 
 
-mb = ModuleBuilder()
-
 def load_and_preprocess_image(url: str):
     headers = {
         'User-Agent':
@@ -60,59 +59,17 @@ def predictions(torch_func, jit_func, img, labels):
     print("torch-mlir prediction")
     print(prediction)
 
 
-class ResNet18Module(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.resnet = models.resnet18(pretrained=True)
-        self.train(False)
-
-    def forward(self, img):
-        return self.resnet.forward(img)
-
-
-class TestModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.s = ResNet18Module()
-
-    def forward(self, x):
-        return self.s.forward(x)
-
-
-image_url = (
-    "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
-)
-import sys
+image_url = "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
 
 print("load image from " + image_url, file=sys.stderr)
 img = load_and_preprocess_image(image_url)
 labels = load_labels()
 
-test_module = TestModule()
-class_annotator = ClassAnnotator()
-recursivescriptmodule = torch.jit.script(test_module)
-torch.jit.save(recursivescriptmodule, "/tmp/foo.pt")
-
-class_annotator.exportNone(recursivescriptmodule._c._type())
-class_annotator.exportPath(recursivescriptmodule._c._type(), ["forward"])
-class_annotator.annotateArgs(
-    recursivescriptmodule._c._type(),
-    ["forward"],
-    [
-        None,
-        ([-1, -1, -1, -1], torch.float32, True),
-    ],
-)
-# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
-mb.import_module(recursivescriptmodule._c, class_annotator)
-
+resnet18 = models.resnet18(pretrained=True)
+resnet18.train(False)
+module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)
 backend = refbackend.RefBackendLinalgOnTensorsBackend()
-with mb.module.context:
-    pm = PassManager.parse('torchscript-module-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline')
-    pm.run(mb.module)
-
-compiled = backend.compile(mb.module)
+compiled = backend.compile(module)
 jit_module = backend.load(compiled)
 
-predictions(test_module.forward, jit_module.forward, img, labels)
+predictions(resnet18.forward, jit_module.forward, img, labels)
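Note the trade-off this hunk makes: the removed ClassAnnotator path annotated the input as `([-1, -1, -1, -1], torch.float32, True)`, where -1 marks a dynamic dimension, while `torch_mlir.compile` derives static sizes from the example tensor (see the `list(arg.shape)` logic in `__init__.py` below). A hedged sketch of what keeping dynamic sizes would still look like via the manual path this commit removes (assumes a `resnet18` model as above):

import torch
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder

scripted = torch.jit.script(resnet18)  # assumes `resnet18` from the example above
annotator = ClassAnnotator()
annotator.exportNone(scripted._c._type())
annotator.exportPath(scripted._c._type(), ["forward"])
# -1 means "dynamic" for that dimension; True marks value semantics.
annotator.annotateArgs(scripted._c._type(), ["forward"],
                       [None, ([-1, -1, -1, -1], torch.float32, True)])
mb = ModuleBuilder()
mb.import_module(scripted._c, annotator)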
@@ -0,0 +1,20 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+import torch
+import torchvision
+
+import torch_mlir
+
+resnet18 = torchvision.models.resnet18(pretrained=True)
+resnet18.eval()
+
+module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type=torch_mlir.OutputType.TORCH)
+print("TORCH OutputType\n", module.operation.get_asm(large_elements_limit=10))
+module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)
+print("LINALG_ON_TENSORS OutputType\n", module.operation.get_asm(large_elements_limit=10))
+# TODO: Debug why this is so slow.
+module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type=torch_mlir.OutputType.TOSA)
+print("TOSA OutputType\n", module.operation.get_asm(large_elements_limit=10))
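Since the full IR dumps are large, `get_asm(large_elements_limit=10)` elides big tensor constants. A small, hypothetical variation on this example (same API, invented file names) writes each dump to a file instead of printing it:

# Hypothetical helper: dump each output type's IR to a file instead of stdout.
# Continues from the example above (`torch`, `torch_mlir`, `resnet18` in scope).
for output_type in (torch_mlir.OutputType.TORCH,
                    torch_mlir.OutputType.LINALG_ON_TENSORS,
                    torch_mlir.OutputType.TOSA):
    module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224),
                                output_type=output_type)
    with open(f"resnet18.{output_type.name.lower()}.mlir", "w") as f:
        f.write(module.operation.get_asm(large_elements_limit=10))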
File diff suppressed because one or more lines are too long
@@ -20,6 +20,14 @@ add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=torch_mlir.")
 declare_mlir_python_sources(TorchMLIRPythonSources)
 declare_mlir_python_sources(TorchMLIRPythonExtensions)
 
+declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel
+  ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
+  ADD_TO_PARENT TorchMLIRPythonSources
+  SOURCES
+    __init__.py
+    compiler_utils.py
+)
+
 declare_mlir_python_sources(TorchMLIRPythonSources.Dialects
   ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
   ADD_TO_PARENT TorchMLIRPythonSources
@@ -0,0 +1,94 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+from typing import List
+from enum import Enum
+
+import torch
+
+from torch_mlir.passmanager import PassManager
+from .compiler_utils import run_pipeline_with_repro_report
+from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder
+
+
+class OutputType(Enum):
+    """The kind of output that `torch_mlir.compile` can produce.
+
+    In MLIR terminology, this describes the mix of dialects that will be
+    produced by the conversion process.
+    """
+    # This output type consists of `torch` dialect ops that have been converted
+    # maximally to value semantics, decomposed, and had shapes inferred.
+    TORCH = 0
+    # This output type consists of `tosa` dialect ops. It can be thought of
+    # as taking the `TORCH` output type and lowering it to TOSA.
+    TOSA = 1
+    # This output type consists of a mix of `linalg`-on-tensors ops, `scf`, and
+    # `arith` ops (and also `math` and `tm_tensor`). It can be thought of
+    # as taking the `TORCH` output type and lowering it so that tensor
+    # computations are done with `linalg`-on-tensors ops.
+    LINALG_ON_TENSORS = 2
+
+
+def compile(model: torch.nn.Module,
+            example_args: List[torch.Tensor],
+            output_type: OutputType = OutputType.TORCH):
+    """Convert a PyTorch model to MLIR.
+
+    Args:
+        model: The PyTorch model to convert.
+        example_args: A list of example arguments to use when inferring the
+            shapes of the arguments to the `forward` method of the model.
+            A single tensor is treated as a list of a single tensor.
+        output_type: The kind of output to produce. See `OutputType` for more
+            details.
+
+    Returns:
+        An MLIR module that contains the converted model in the specified
+        output type.
+    """
+
+    # TODO: Don't hardcode "forward". See `torch.onnx.export` and
+    # `torch.jit.trace_module` for API inspiration.
+    # TODO: Support dynamic dimension sizes. See `torch.onnx.export`'s
+    # `dynamic_axes` for API inspiration, or do something more ergonomic
+    # like a tensor wrapper possibly.
+    # TODO: Support tracing the model instead of scripting it.
+    scripted = torch.jit.script(model)
+
+    if isinstance(example_args, torch.Tensor):
+        example_args = [example_args]
+
+    class_annotator = ClassAnnotator()
+    forward_annotation = [None]
+    for arg in example_args:
+        # Assume that all tensors have value semantics for now.
+        forward_annotation.append((list(arg.shape), arg.dtype, True))
+    class_annotator.exportNone(scripted._c._type())
+    class_annotator.exportPath(scripted._c._type(), ["forward"])
+    class_annotator.annotateArgs(
+        scripted._c._type(), ["forward"], forward_annotation)
+
+    mb = ModuleBuilder()
+    mb.import_module(scripted._c, class_annotator)
+
+    run_pipeline_with_repro_report(mb.module,
+                                   "torchscript-module-to-torch-backend-pipeline",
+                                   "Lowering TorchScript IR -> Torch Backend IR")
+
+    if output_type == OutputType.TORCH:
+        pass
+    elif output_type == OutputType.TOSA:
+        run_pipeline_with_repro_report(
+            mb.module,
+            "torch-backend-to-tosa-backend-pipeline",
+            "Lowering Torch Backend IR -> TOSA Backend IR")
+    else:
+        assert output_type == OutputType.LINALG_ON_TENSORS
+        run_pipeline_with_repro_report(
+            mb.module,
+            "torch-backend-to-linalg-on-tensors-backend-pipeline",
+            "Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR")
+    return mb.module
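To make the annotation loop above concrete, here is what `forward_annotation` evaluates to for the single `torch.ones(1, 3, 224, 224)` example tensor used in the ResNet18 examples (a worked example, not part of the diff):

import torch

# What the loop builds for example_args = [torch.ones(1, 3, 224, 224)]:
forward_annotation = [
    None,  # the implicit `self` argument of `forward` gets no annotation
    ([1, 3, 224, 224], torch.float32, True),  # (shape, dtype, has value semantics)
]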
@@ -0,0 +1,58 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+from io import StringIO
+import os
+import sys
+import tempfile
+from torch_mlir.passmanager import PassManager
+from torch_mlir.ir import StringAttr
+
+def get_module_name_for_debug_dump(module):
+    """Gets a name suitable for a debug dump.
+
+    The name is not guaranteed to be unique.
+    """
+    if "torch.debug_module_name" not in module.operation.attributes:
+        return "UnnamedModule"
+    return StringAttr(module.operation.attributes["torch.debug_module_name"]).value
+
+def run_pipeline_with_repro_report(module,
+                                   pipeline: str,
+                                   description: str):
+    """Runs `pipeline` on `module`, with a nice repro report if it fails."""
+    module_name = get_module_name_for_debug_dump(module)
+    try:
+        original_stderr = sys.stderr
+        sys.stderr = StringIO()
+        asm_for_error_report = module.operation.get_asm(
+            large_elements_limit=10, enable_debug_info=True)
+        # Lower module in place to make it ready for compiler backends.
+        with module.context:
+            pm = PassManager.parse(pipeline)
+            pm.run(module)
+    except Exception as e:
+        # TODO: More robust.
+        # - don't arbitrarily clutter up /tmp. When a test suite has many
+        #   tests, this can be a big disk cost (also, /tmp/ is frequently a
+        #   RAM fs, which increases worries about capacity).
+        # - don't have colliding filenames (hard to do without cluttering
+        #   up /tmp)
+        # - if we do have colliding filenames, writes should at least
+        #   avoid being racy.
+        filename = os.path.join(tempfile.gettempdir(), module_name + ".mlir")
+        with open(filename, 'w') as f:
+            f.write(asm_for_error_report)
+        debug_options = "-print-ir-after-all -mlir-disable-threading"
+        raise Exception(f"""
+{description} failed with the following diagnostics:
+{sys.stderr.getvalue()}
+
+Error can be reproduced with:
+$ torch-mlir-opt -pass-pipeline='{pipeline}' {filename}
+Add '{debug_options}' to get the IR dump for debugging purposes.
+""") from None
+    finally:
+        sys.stderr = original_stderr
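A minimal sketch of calling this helper directly, using only the signature above; the pipeline name is one of the pipelines that `torch_mlir.compile` invokes, and `module` is assumed to be a torch-mlir module already at the Torch backend contract (e.g. `mb.module` from the `compile` function above):

from torch_mlir.compiler_utils import run_pipeline_with_repro_report

# Lowers `module` in place; on failure, raises with a repro file path
# and the captured MLIR diagnostics.
run_pipeline_with_repro_report(
    module,
    "torch-backend-to-tosa-backend-pipeline",
    "Lowering Torch Backend IR -> TOSA Backend IR")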
@@ -13,8 +13,7 @@ from torch_mlir.runtime import *
 # Imported for side effects.
 import torch_mlir.all_passes_registration
 import torch_mlir.dialects.torch
-
-from torch_mlir_e2e_test.utils import run_pipeline_with_repro_report
+from torch_mlir.compiler_utils import run_pipeline_with_repro_report
 
 from .abc import LinalgOnTensorsBackend
 
@@ -14,7 +14,8 @@ import torch
 
 from torch_mlir_e2e_test.linalg_on_tensors_backends.abc import LinalgOnTensorsBackend
 from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem
-from torch_mlir_e2e_test.utils import run_pipeline_with_repro_report
+from torch_mlir.compiler_utils import run_pipeline_with_repro_report
+
 from .utils import (
     recursively_convert_to_numpy,
     recursively_convert_from_numpy,
@@ -14,7 +14,7 @@ import torch
 
 from torch_mlir_e2e_test.tosa_backends.abc import TosaBackend
 from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem
-from torch_mlir_e2e_test.utils import run_pipeline_with_repro_report
+from torch_mlir.compiler_utils import run_pipeline_with_repro_report
 from .utils import (
     recursively_convert_to_numpy,
     recursively_convert_from_numpy,
@@ -12,7 +12,8 @@ import torch
 
 from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder
 from torch_mlir.dialects.torch.importer.jit_ir.torchscript_annotations import extract_annotations
-from torch_mlir_e2e_test.utils import run_pipeline_with_repro_report
+from torch_mlir.compiler_utils import run_pipeline_with_repro_report
 
+
 def recursively_convert_to_numpy(o: Any):
     if isinstance(o, torch.Tensor):
@@ -7,8 +7,8 @@ from torch_mlir.ir import *
 from torch_mlir.passmanager import *
 # Imported for side effects.
 import torch_mlir.all_passes_registration
+from torch_mlir.compiler_utils import run_pipeline_with_repro_report
 
-from torch_mlir_e2e_test.utils import run_pipeline_with_repro_report
 from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
 
 from .abc import TosaBackend
@@ -1,58 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-# Also available under a BSD-style license. See LICENSE.
-
-from io import StringIO
-import os
-import sys
-import tempfile
-from torch_mlir.passmanager import PassManager
-from torch_mlir.ir import StringAttr
-
-def get_module_name_for_debug_dump(module):
-    """Gets a name suitable for a debug dump.
-
-    The name is not guaranteed to be unique.
-    """
-    if "torch.debug_module_name" not in module.operation.attributes:
-        return "UnnamedModule"
-    return StringAttr(module.operation.attributes["torch.debug_module_name"]).value
-
-def run_pipeline_with_repro_report(module,
-                                   pipeline: str,
-                                   description: str):
-    """Runs `pipeline` on `module`, with a nice repro report if it fails."""
-    module_name = get_module_name_for_debug_dump(module)
-    try:
-        original_stderr = sys.stderr
-        sys.stderr = StringIO()
-        asm_for_error_report = module.operation.get_asm(
-            large_elements_limit=10, enable_debug_info=True)
-        # Lower module in place to make it ready for compiler backends.
-        with module.context:
-            pm = PassManager.parse(pipeline)
-            pm.run(module)
-    except Exception as e:
-        # TODO: More robust.
-        # - don't arbitrarily clutter up /tmp. When a test suite has many
-        #   tests, this can be a big disk cost (also, /tmp/ is frequently a
-        #   RAM fs, which increases worries about capacity).
-        # - don't have colliding filenames (hard to do without cluttering
-        #   up /tmp)
-        # - if we do have colliding filenames, writes should at least
-        #   avoid being racy.
-        filename = os.path.join(tempfile.gettempdir(), module_name + ".mlir")
-        with open(filename, 'w') as f:
-            f.write(asm_for_error_report)
-        debug_options = "-print-ir-after-all -mlir-disable-threading"
-        raise Exception(f"""
-{description} failed with the following diagnostics:
-{sys.stderr.getvalue()}
-
-Error can be reproduced with:
-$ torch-mlir-opt -pass-pipeline='{pipeline}' {filename}
-Add '{debug_options}' to get the IR dump for debugging purposes.
-""") from None
-    finally:
-        sys.stderr = original_stderr