[FxImporter] Add an e2e test example for FxImporter (#3331)

pull/3335/head
penguin_wwy 2024-05-14 00:45:19 +08:00 committed by GitHub
parent 75d1d72059
commit 20d4d16d32
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 141 additions and 48 deletions

View File

@ -76,6 +76,23 @@ pip install torch-mlir -f https://github.com/llvm/torch-mlir-release/releases/ex
## Demos
### FxImporter ResNet18
```shell
# Get the latest example if you haven't checked out the code
wget https://raw.githubusercontent.com/llvm/torch-mlir/main/projects/pt1/examples/fximporter_resnet18.py
# Run ResNet18 as a standalone script.
python projects/pt1/examples/fximporter_resnet18.py
# Output
load image from https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg
...
PyTorch prediction
[('Labrador retriever', 70.65674591064453), ('golden retriever', 4.988346099853516), ('Saluki, gazelle hound', 4.477451324462891)]
torch-mlir prediction
[('Labrador retriever', 70.6567153930664), ('golden retriever', 4.988325119018555), ('Saluki, gazelle hound', 4.477458477020264)]
```
### TorchScript ResNet18
Standalone script to Convert a PyTorch ResNet18 model to MLIR and run it on the CPU Backend:

View File

@ -0,0 +1,52 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
from PIL import Image
import requests
import torch
from torchvision import transforms
DEFAULT_IMAGE_URL = (
"https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
)
DEFAULT_LABEL_URL = (
"https://raw.githubusercontent.com/cathyzhyi/ml-data/main/imagenet-classes.txt"
)
def load_and_preprocess_image(url: str = DEFAULT_IMAGE_URL):
headers = {
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36"
}
img = Image.open(requests.get(url, headers=headers, stream=True).raw).convert("RGB")
# preprocessing pipeline
preprocess = transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
img_preprocessed = preprocess(img)
return torch.unsqueeze(img_preprocessed, 0)
def load_labels(url: str = DEFAULT_LABEL_URL):
classes_text = requests.get(
url=url,
stream=True,
).text
labels = [line.strip() for line in classes_text.splitlines()]
return labels
def top3_possibilities(res, labels):
_, indexes = torch.sort(res, descending=True)
percentage = torch.nn.functional.softmax(res, dim=1)[0] * 100
top3 = [(labels[idx], percentage[idx].item()) for idx in indexes[0][:3]]
return top3

View File

@ -0,0 +1,59 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import sys
from pathlib import Path
import torch
import torch.utils._pytree as pytree
import torchvision.models as models
from torch_mlir import fx
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
from torch_mlir_e2e_test.configs.utils import (
recursively_convert_to_numpy,
)
sys.path.append(str(Path(__file__).absolute().parent))
from _example_utils import (
top3_possibilities,
load_and_preprocess_image,
load_labels,
DEFAULT_IMAGE_URL,
)
print("load image from " + DEFAULT_IMAGE_URL, file=sys.stderr)
img = load_and_preprocess_image(DEFAULT_IMAGE_URL)
labels = load_labels()
resnet18 = models.resnet18(pretrained=True).eval()
module = fx.export_and_import(
resnet18,
torch.ones(1, 3, 224, 224),
output_type="linalg-on-tensors",
func_name=resnet18.__class__.__name__,
)
backend = refbackend.RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(module)
fx_module = backend.load(compiled)
params = {
**dict(resnet18.named_buffers(remove_duplicate=False)),
}
params_flat, params_spec = pytree.tree_flatten(params)
params_flat = list(params_flat)
with torch.no_grad():
numpy_inputs = recursively_convert_to_numpy(params_flat + [img])
golden_prediction = top3_possibilities(resnet18.forward(img), labels)
print("PyTorch prediction")
print(golden_prediction)
prediction = top3_possibilities(
torch.from_numpy(getattr(fx_module, resnet18.__class__.__name__)(*numpy_inputs)),
labels,
)
print("torch-mlir prediction")
print(prediction)

View File

@ -4,71 +4,36 @@
# Also available under a BSD-style license. See LICENSE.
import sys
from PIL import Image
import requests
from pathlib import Path
import torch
import torchvision.models as models
from torchvision import transforms
from torch_mlir import torchscript
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
def load_and_preprocess_image(url: str):
headers = {
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36"
}
img = Image.open(requests.get(url, headers=headers, stream=True).raw).convert("RGB")
# preprocessing pipeline
preprocess = transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
img_preprocessed = preprocess(img)
return torch.unsqueeze(img_preprocessed, 0)
def load_labels():
classes_text = requests.get(
"https://raw.githubusercontent.com/cathyzhyi/ml-data/main/imagenet-classes.txt",
stream=True,
).text
labels = [line.strip() for line in classes_text.splitlines()]
return labels
def top3_possibilities(res):
_, indexes = torch.sort(res, descending=True)
percentage = torch.nn.functional.softmax(res, dim=1)[0] * 100
top3 = [(labels[idx], percentage[idx].item()) for idx in indexes[0][:3]]
return top3
sys.path.append(str(Path(__file__).absolute().parent))
from _example_utils import (
top3_possibilities,
load_and_preprocess_image,
load_labels,
DEFAULT_IMAGE_URL,
)
def predictions(torch_func, jit_func, img, labels):
golden_prediction = top3_possibilities(torch_func(img))
golden_prediction = top3_possibilities(torch_func(img), labels)
print("PyTorch prediction")
print(golden_prediction)
prediction = top3_possibilities(torch.from_numpy(jit_func(img.numpy())))
prediction = top3_possibilities(torch.from_numpy(jit_func(img.numpy())), labels)
print("torch-mlir prediction")
print(prediction)
image_url = (
"https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
)
print("load image from " + image_url, file=sys.stderr)
img = load_and_preprocess_image(image_url)
print("load image from " + DEFAULT_IMAGE_URL, file=sys.stderr)
img = load_and_preprocess_image(DEFAULT_IMAGE_URL)
labels = load_labels()
resnet18 = models.resnet18(pretrained=True)
resnet18.train(False)
resnet18 = models.resnet18(pretrained=True).eval()
module = torchscript.compile(
resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors"
)