mirror of https://github.com/llvm/torch-mlir
[FxImporter] Add an e2e test example for FxImporter (#3331)
parent 75d1d72059
commit 20d4d16d32

README.md | 17
README.md
@@ -76,6 +76,23 @@ pip install torch-mlir -f https://github.com/llvm/torch-mlir-release/releases/ex

## Demos

### FxImporter ResNet18
```shell
# Get the latest example if you haven't checked out the code
wget https://raw.githubusercontent.com/llvm/torch-mlir/main/projects/pt1/examples/fximporter_resnet18.py

# Run ResNet18 as a standalone script.
python projects/pt1/examples/fximporter_resnet18.py

# Output
load image from https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg
...
PyTorch prediction
[('Labrador retriever', 70.65674591064453), ('golden retriever', 4.988346099853516), ('Saluki, gazelle hound', 4.477451324462891)]
torch-mlir prediction
[('Labrador retriever', 70.6567153930664), ('golden retriever', 4.988325119018555), ('Saluki, gazelle hound', 4.477458477020264)]
```

### TorchScript ResNet18

Standalone script to Convert a PyTorch ResNet18 model to MLIR and run it on the CPU Backend:
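The demo above downloads and runs the full ResNet18 script added by this commit. For readers who only want to see the frontend API it exercises, here is a minimal sketch of `fx.export_and_import` on a toy module; the `TinyModel` class and tensor shapes are illustrative assumptions and not part of the commit, while the entry point and the `output_type` string are taken from the script below.

```python
# Minimal FxImporter sketch (illustrative, not part of this commit).
import torch
from torch_mlir import fx


class TinyModel(torch.nn.Module):  # hypothetical toy module
    def forward(self, x):
        return torch.relu(x) + 1.0


# Export the module and import the resulting graph, lowering to the
# linalg-on-tensors dialect, as the new demo does for ResNet18.
module = fx.export_and_import(
    TinyModel(),
    torch.randn(2, 3),
    output_type="linalg-on-tensors",
)
print(module)  # dump the generated MLIR
```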
projects/pt1/examples/_example_utils.py (new file)
@@ -0,0 +1,52 @@
```python
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

from PIL import Image
import requests

import torch
from torchvision import transforms


DEFAULT_IMAGE_URL = (
    "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
)
DEFAULT_LABEL_URL = (
    "https://raw.githubusercontent.com/cathyzhyi/ml-data/main/imagenet-classes.txt"
)


def load_and_preprocess_image(url: str = DEFAULT_IMAGE_URL):
    headers = {
        "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36"
    }
    img = Image.open(requests.get(url, headers=headers, stream=True).raw).convert("RGB")
    # preprocessing pipeline
    preprocess = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    img_preprocessed = preprocess(img)
    return torch.unsqueeze(img_preprocessed, 0)


def load_labels(url: str = DEFAULT_LABEL_URL):
    classes_text = requests.get(
        url=url,
        stream=True,
    ).text
    labels = [line.strip() for line in classes_text.splitlines()]
    return labels


def top3_possibilities(res, labels):
    _, indexes = torch.sort(res, descending=True)
    percentage = torch.nn.functional.softmax(res, dim=1)[0] * 100
    top3 = [(labels[idx], percentage[idx].item()) for idx in indexes[0][:3]]
    return top3
```
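The file above factors the image, label, and top-3 helpers out of the existing TorchScript demo so both demos can share them. A small usage sketch, assuming it is saved as `projects/pt1/examples/_example_utils.py` next to the caller and that the default image/label URLs and the torchvision ResNet18 weights are reachable:

```python
# Usage sketch for the shared helpers (illustrative, not part of this commit).
import torch
import torchvision.models as models

from _example_utils import load_and_preprocess_image, load_labels, top3_possibilities

img = load_and_preprocess_image()  # 1x3x224x224 normalized NCHW tensor
labels = load_labels()             # list of ImageNet class names
resnet18 = models.resnet18(pretrained=True).eval()
with torch.no_grad():
    print(top3_possibilities(resnet18(img), labels))
```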
projects/pt1/examples/fximporter_resnet18.py (new file)
@@ -0,0 +1,59 @@
```python
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

import sys
from pathlib import Path

import torch
import torch.utils._pytree as pytree
import torchvision.models as models
from torch_mlir import fx
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
from torch_mlir_e2e_test.configs.utils import (
    recursively_convert_to_numpy,
)

sys.path.append(str(Path(__file__).absolute().parent))
from _example_utils import (
    top3_possibilities,
    load_and_preprocess_image,
    load_labels,
    DEFAULT_IMAGE_URL,
)


print("load image from " + DEFAULT_IMAGE_URL, file=sys.stderr)
img = load_and_preprocess_image(DEFAULT_IMAGE_URL)
labels = load_labels()

resnet18 = models.resnet18(pretrained=True).eval()
module = fx.export_and_import(
    resnet18,
    torch.ones(1, 3, 224, 224),
    output_type="linalg-on-tensors",
    func_name=resnet18.__class__.__name__,
)
backend = refbackend.RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(module)
fx_module = backend.load(compiled)

params = {
    **dict(resnet18.named_buffers(remove_duplicate=False)),
}
params_flat, params_spec = pytree.tree_flatten(params)
params_flat = list(params_flat)
with torch.no_grad():
    numpy_inputs = recursively_convert_to_numpy(params_flat + [img])

golden_prediction = top3_possibilities(resnet18.forward(img), labels)
print("PyTorch prediction")
print(golden_prediction)

prediction = top3_possibilities(
    torch.from_numpy(getattr(fx_module, resnet18.__class__.__name__)(*numpy_inputs)),
    labels,
)
print("torch-mlir prediction")
print(prediction)
```
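Note the calling convention the new script sets up: the compiled entry point is named after the model class via `func_name=resnet18.__class__.__name__` (i.e. `ResNet`), and it is invoked with the model's flattened buffers followed by the image, matching the `params_flat + [img]` list built with pytree. To inspect the signature this produces, the module returned by `fx.export_and_import` can be printed before it is handed to the RefBackend; the sketch below is illustrative and the `print` call is an assumption about how one would inspect it, not part of the commit.

```python
# Sketch: dump the linalg-on-tensors IR before compiling with the RefBackend.
# (Illustrative, not part of this commit.)
import torch
import torchvision.models as models
from torch_mlir import fx

resnet18 = models.resnet18(pretrained=True).eval()
module = fx.export_and_import(
    resnet18,
    torch.ones(1, 3, 224, 224),
    output_type="linalg-on-tensors",
    func_name=resnet18.__class__.__name__,
)
# Expect func.func @ResNet(...): the lifted buffers first, then the image input,
# which is the order the demo passes its numpy inputs in.
print(module)
```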
projects/pt1/examples/torchscript_resnet18.py
```diff
@@ -4,71 +4,36 @@
 # Also available under a BSD-style license. See LICENSE.
 
 import sys
-
-from PIL import Image
-import requests
+from pathlib import Path
 
 import torch
 import torchvision.models as models
-from torchvision import transforms
-
 from torch_mlir import torchscript
 from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
 
-
-def load_and_preprocess_image(url: str):
-    headers = {
-        "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36"
-    }
-    img = Image.open(requests.get(url, headers=headers, stream=True).raw).convert("RGB")
-    # preprocessing pipeline
-    preprocess = transforms.Compose(
-        [
-            transforms.Resize(256),
-            transforms.CenterCrop(224),
-            transforms.ToTensor(),
-            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
-        ]
-    )
-    img_preprocessed = preprocess(img)
-    return torch.unsqueeze(img_preprocessed, 0)
-
-
-def load_labels():
-    classes_text = requests.get(
-        "https://raw.githubusercontent.com/cathyzhyi/ml-data/main/imagenet-classes.txt",
-        stream=True,
-    ).text
-    labels = [line.strip() for line in classes_text.splitlines()]
-    return labels
-
-
-def top3_possibilities(res):
-    _, indexes = torch.sort(res, descending=True)
-    percentage = torch.nn.functional.softmax(res, dim=1)[0] * 100
-    top3 = [(labels[idx], percentage[idx].item()) for idx in indexes[0][:3]]
-    return top3
+sys.path.append(str(Path(__file__).absolute().parent))
+from _example_utils import (
+    top3_possibilities,
+    load_and_preprocess_image,
+    load_labels,
+    DEFAULT_IMAGE_URL,
+)
 
 
 def predictions(torch_func, jit_func, img, labels):
-    golden_prediction = top3_possibilities(torch_func(img))
+    golden_prediction = top3_possibilities(torch_func(img), labels)
     print("PyTorch prediction")
     print(golden_prediction)
-    prediction = top3_possibilities(torch.from_numpy(jit_func(img.numpy())))
+    prediction = top3_possibilities(torch.from_numpy(jit_func(img.numpy())), labels)
     print("torch-mlir prediction")
     print(prediction)
 
 
-image_url = (
-    "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
-)
-
-print("load image from " + image_url, file=sys.stderr)
-img = load_and_preprocess_image(image_url)
+print("load image from " + DEFAULT_IMAGE_URL, file=sys.stderr)
+img = load_and_preprocess_image(DEFAULT_IMAGE_URL)
 labels = load_labels()
 
-resnet18 = models.resnet18(pretrained=True)
-resnet18.train(False)
+resnet18 = models.resnet18(pretrained=True).eval()
 module = torchscript.compile(
     resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors"
 )
```
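Taken together, the two demos now differ only in the frontend: the existing script compiles through `torchscript.compile`, while the new one goes through `fx.export_and_import`; both lower to the linalg-on-tensors dialect and both import the RefBackend to execute the result. A condensed side-by-side sketch, with argument values copied from the scripts above; this is illustrative and not a new API.

```python
# Side-by-side sketch of the two torch-mlir frontends used in this commit.
# (Illustrative, not part of the commit.)
import torch
import torchvision.models as models
from torch_mlir import fx
from torch_mlir import torchscript

resnet18 = models.resnet18(pretrained=True).eval()
example_input = torch.ones(1, 3, 224, 224)

# Existing demo: TorchScript frontend.
module_ts = torchscript.compile(
    resnet18, example_input, output_type="linalg-on-tensors"
)

# New demo: FxImporter frontend.
module_fx = fx.export_and_import(
    resnet18,
    example_input,
    output_type="linalg-on-tensors",
    func_name=resnet18.__class__.__name__,
)
```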