torch-mlir/projects/pt1/examples/_example_utils.py

53 lines
1.7 KiB
Python
Raw Normal View History

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
from PIL import Image
import requests
import torch
from torchvision import transforms
# Image fetched by load_and_preprocess_image() when the caller passes no URL.
DEFAULT_IMAGE_URL = (
    "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
)
# Text file fetched by load_labels() when the caller passes no URL; it is read
# as one class label per line.
DEFAULT_LABEL_URL = (
    "https://raw.githubusercontent.com/cathyzhyi/ml-data/main/imagenet-classes.txt"
)
def load_and_preprocess_image(url: str = DEFAULT_IMAGE_URL, timeout: float = 30.0):
    """Download an image and convert it into a normalized, batched tensor.

    Args:
        url: Address of the image to download.
        timeout: Seconds handed to ``requests.get`` as its timeout, so a
            stalled server cannot block the caller forever.

    Returns:
        The preprocessed image as a tensor with a leading batch dimension of
        size 1, i.e. shape ``(1, 3, 224, 224)`` after the RGB conversion and
        the 224x224 center crop.

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    # A browser-like User-Agent is sent; presumably some hosts reject the
    # default client identification (NOTE(review): confirm this is still needed).
    headers = {
        "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36"
    }
    # The context manager releases the connection even if decoding fails.
    with requests.get(url, headers=headers, stream=True, timeout=timeout) as response:
        # Fail with a clear HTTP error instead of letting PIL choke on an
        # HTML error page.
        response.raise_for_status()
        # The raw stream skips Content-Encoding handling unless asked to decode.
        response.raw.decode_content = True
        # convert() forces the pixel data to be read while the stream is open.
        img = Image.open(response.raw).convert("RGB")
    # preprocessing pipeline: resize, crop to 224x224, convert to a tensor and
    # apply per-channel mean/std normalization
    preprocess = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    img_preprocessed = preprocess(img)
    # Add the batch dimension in front.
    return torch.unsqueeze(img_preprocessed, 0)
def load_labels(url: str = DEFAULT_LABEL_URL, timeout: float = 30.0):
    """Download a text file of class labels and return them as a list.

    Args:
        url: Address of a text file holding one label per line.
        timeout: Seconds handed to ``requests.get`` as its timeout, so a
            stalled server cannot block the caller forever.

    Returns:
        A list with one whitespace-stripped string per line of the file.
        Blank lines are kept (as empty strings) so list positions keep
        matching line numbers.

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    # No stream=True: the whole body is consumed right away via .text.
    response = requests.get(url=url, timeout=timeout)
    # Without this check an error page would be returned as if it were labels.
    response.raise_for_status()
    return [line.strip() for line in response.text.splitlines()]
def top3_possibilities(res, labels):
    """Return the three highest-scoring labels with their softmax percentages.

    Only the first row of ``res`` is reported.

    Args:
        res: Score tensor indexed as ``[row, class]``; presumably the raw
            output of a classifier for a batch of inputs (verify against
            caller).
        labels: Sequence of label names, indexed by class position.

    Returns:
        A list of ``(label, percentage)`` tuples ordered from the highest
        score downwards. It holds fewer than three entries when ``res`` has
        fewer than three classes.
    """
    # Softmax over the class dimension, scaled to percent, first row only.
    scores = res.softmax(dim=1)[0] * 100
    # Class positions of the first row, best score first.
    ranking = torch.sort(res, descending=True).indices[0]
    best = []
    for class_idx in ranking[:3]:
        best.append((labels[class_idx], scores[class_idx].item()))
    return best