mirror of https://github.com/llvm/torch-mlir
Add an e2e test example for Resnet18
Show an example of classifying image from https://commons.wikimedia.org/wiki/File:YellowLabradorLooking_new.jpg with Resnet18pull/264/head
parent
8494455282
commit
93816ee21a
|
@ -0,0 +1,106 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
from PIL import Image
|
||||
import requests
|
||||
import torch
|
||||
import torchvision.models as models
|
||||
from torchvision import transforms
|
||||
import typing
|
||||
|
||||
import torch_mlir
|
||||
|
||||
import npcomp
|
||||
from npcomp.compiler.pytorch.backend import refjit, frontend_lowering, iree
|
||||
from npcomp.compiler.utils import logging
|
||||
|
||||
logging.enable()
|
||||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
|
||||
def load_and_preprocess_image(url: str):
|
||||
img = Image.open(requests.get(url, stream=True).raw).convert("RGB")
|
||||
# preprocessing pipeline
|
||||
preprocess = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(256),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
]
|
||||
)
|
||||
img_preprocessed = preprocess(img)
|
||||
return torch.unsqueeze(img_preprocessed, 0)
|
||||
|
||||
|
||||
def load_labels():
|
||||
classes_text = requests.get(
|
||||
"https://raw.githubusercontent.com/cathyzhyi/ml-data/main/imagenet-classes.txt",
|
||||
stream=True,
|
||||
).text
|
||||
labels = [line.strip() for line in classes_text.splitlines()]
|
||||
return labels
|
||||
|
||||
|
||||
def top3_possibilities(res):
|
||||
_, indexes = torch.sort(res, descending=True)
|
||||
percentage = torch.nn.functional.softmax(res, dim=1)[0] * 100
|
||||
top3 = [(labels[idx], percentage[idx].item()) for idx in indexes[0][:3]]
|
||||
return top3
|
||||
|
||||
|
||||
def predictions(torch_func, jit_func, img, labels):
|
||||
golden_prediction = top3_possibilities(torch_func(img))
|
||||
print("PyTorch prediction")
|
||||
print(golden_prediction)
|
||||
prediction = top3_possibilities(torch.from_numpy(jit_func(img)))
|
||||
print("NPCOMP prediction")
|
||||
print(prediction)
|
||||
|
||||
|
||||
class ResNet18Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.resnet = models.resnet18(pretrained=True)
|
||||
self.train(False)
|
||||
|
||||
def forward(self, img):
|
||||
return self.resnet.forward(img)
|
||||
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.s = ResNet18Module()
|
||||
|
||||
def forward(self, x):
|
||||
return self.s.forward(x)
|
||||
|
||||
|
||||
test_module = TestModule()
|
||||
class_annotator = torch_mlir.ClassAnnotator()
|
||||
recursivescriptmodule = torch.jit.script(test_module)
|
||||
torch.jit.save(recursivescriptmodule, "/tmp/foo.pt")
|
||||
|
||||
class_annotator.exportNone(recursivescriptmodule._c._type())
|
||||
class_annotator.exportPath(recursivescriptmodule._c._type(), ["forward"])
|
||||
class_annotator.annotateArgs(
|
||||
recursivescriptmodule._c._type(),
|
||||
["forward"],
|
||||
[None, ([-1, -1, -1, -1], torch.float32, True),],
|
||||
)
|
||||
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
|
||||
mb.import_module(recursivescriptmodule._c, class_annotator)
|
||||
|
||||
backend = refjit.RefjitNpcompBackend()
|
||||
compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module))
|
||||
jit_module = backend.load(compiled)
|
||||
|
||||
image_url = (
|
||||
"https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
|
||||
)
|
||||
print("load image from " + image_url)
|
||||
img = load_and_preprocess_image(image_url)
|
||||
labels = load_labels()
|
||||
predictions(test_module.forward, jit_module.forward, img, labels)
|
Loading…
Reference in New Issue