# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

import io

from PIL import Image
import requests
import torch
from torchvision import transforms

DEFAULT_IMAGE_URL = (
    "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
)
DEFAULT_LABEL_URL = (
    "https://raw.githubusercontent.com/cathyzhyi/ml-data/main/imagenet-classes.txt"
)

# Seconds to wait on the network before giving up. Without an explicit
# timeout, ``requests`` may block indefinitely on an unresponsive server.
DEFAULT_TIMEOUT = 30.0


def load_and_preprocess_image(
    url: str = DEFAULT_IMAGE_URL, *, timeout: float = DEFAULT_TIMEOUT
):
    """Download an image and turn it into a normalized single-image batch.

    Args:
        url: Location of the image to fetch.
        timeout: Network timeout in seconds passed to ``requests.get``.

    Returns:
        A float tensor of shape (1, 3, 224, 224): the RGB image resized so its
        short side is 256, center-cropped to 224x224, scaled to [0, 1] and
        normalized per channel, with a leading batch dimension of 1.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection failures or timeouts.
    """
    # Some hosts (e.g. Wikimedia) reject requests without a browser-like UA.
    headers = {
        "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36"
    }
    response = requests.get(url, headers=headers, timeout=timeout)
    # Fail loudly on 4xx/5xx instead of handing an HTML error page to PIL.
    response.raise_for_status()
    # ``response.content`` is already decoded from any Content-Encoding
    # (unlike ``response.raw``), and reading it releases the connection.
    img = Image.open(io.BytesIO(response.content)).convert("RGB")

    # Preprocessing pipeline; the mean/std values are the commonly used
    # ImageNet channel statistics.
    preprocess = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    img_preprocessed = preprocess(img)
    # Add the batch dimension expected by the model.
    return torch.unsqueeze(img_preprocessed, 0)


def load_labels(url: str = DEFAULT_LABEL_URL, *, timeout: float = DEFAULT_TIMEOUT):
    """Download a newline-separated class-label file.

    Args:
        url: Location of the text file, one label per line.
        timeout: Network timeout in seconds passed to ``requests.get``.

    Returns:
        A list of labels with surrounding whitespace stripped. Line order is
        preserved (blank lines are kept) so list index ``i`` is class ``i``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection failures or timeouts.
    """
    response = requests.get(url=url, timeout=timeout)
    response.raise_for_status()
    return [line.strip() for line in response.text.splitlines()]


def top3_possibilities(res, labels):
    """Return the three most likely labels for the first item of a batch.

    Args:
        res: Model output scores; assumed 2-D of shape (batch, num_classes),
            as implied by ``softmax(dim=1)`` below. Only row 0 is inspected.
        labels: Sequence mapping class index to label text.

    Returns:
        A list of up to three ``(label, percentage)`` tuples, most likely
        first, where ``percentage`` is the softmax probability times 100.
    """
    percentage = torch.nn.functional.softmax(res, dim=1)[0] * 100
    # topk on the single row we need, instead of fully sorting every row of
    # the batch; clamp k so fewer than three classes still works.
    k = min(3, res[0].numel())
    top_indexes = torch.topk(res[0], k=k).indices.tolist()
    return [(labels[idx], percentage[idx].item()) for idx in top_indexes]