2021-09-30 00:03:40 +08:00
|
|
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
# See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
# Also available under a BSD-style license. See LICENSE.
|
2021-07-29 23:06:02 +08:00
|
|
|
|
|
|
|
from PIL import Image
|
|
|
|
import requests
|
|
|
|
import torch
|
|
|
|
import torchvision.models as models
|
|
|
|
from torchvision import transforms
|
|
|
|
|
2021-09-25 06:03:25 +08:00
|
|
|
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder
|
2021-07-29 23:06:02 +08:00
|
|
|
|
2021-09-28 07:44:07 +08:00
|
|
|
from torch_mlir.passmanager import PassManager
|
|
|
|
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
|
|
|
|
|
2021-07-29 23:06:02 +08:00
|
|
|
|
2021-09-25 06:03:25 +08:00
|
|
|
# Builder that the scripted TestModule is imported into further down
# (mb.import_module); mb.module is then run through the lowering pipeline
# and handed to the reference backend for compilation.
mb = ModuleBuilder()
|
2021-08-10 04:55:20 +08:00
|
|
|
|
2021-07-29 23:06:02 +08:00
|
|
|
def load_and_preprocess_image(url: str, timeout: float = 30.0):
    """Download an image and turn it into a normalized single-image batch.

    Args:
        url: HTTP(S) URL of the image to fetch.
        timeout: Seconds to wait on the server before giving up; forwarded to
            ``requests.get`` so a stalled download cannot hang the script.

    Returns:
        A float tensor of shape (1, 3, 224, 224): the RGB image resized to 256,
        center-cropped to 224x224, converted to a tensor, normalized with the
        mean/std constants below, and given a leading batch dimension.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    import io

    # NOTE(review): presumably a browser-like User-Agent is sent because some
    # image hosts reject the default requests User-Agent — confirm.
    headers = {
        'User-Agent':
        'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36'
    }
    with requests.get(url, headers=headers, timeout=timeout) as response:
        # Fail loudly on 4xx/5xx instead of handing an HTML error page to PIL.
        response.raise_for_status()
        # Decode from the fully read body: ``response.raw`` skips requests'
        # Content-Encoding handling and would break on compressed responses.
        img = Image.open(io.BytesIO(response.content)).convert("RGB")

    # preprocessing pipeline
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    img_preprocessed = preprocess(img)
    # Add the leading batch dimension expected by the model.
    return torch.unsqueeze(img_preprocessed, 0)
|
|
|
|
|
|
|
|
|
|
|
|
def load_labels(
        url: str = "https://raw.githubusercontent.com/cathyzhyi/ml-data/main/imagenet-classes.txt",
        timeout: float = 30.0):
    """Download the class-name list used to label model outputs.

    Args:
        url: Location of a plain-text file with one class name per line.
        timeout: Seconds to wait on the server before giving up.

    Returns:
        A list with one stripped entry per line of the file. Entry ``i`` is
        used as the label for class index ``i``, so blank lines are kept to
        preserve that alignment.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    response = requests.get(url, timeout=timeout)
    # Without this, a 404 page would be silently parsed as a label list.
    response.raise_for_status()
    return [line.strip() for line in response.text.splitlines()]
|
|
|
|
|
|
|
|
|
|
|
|
def top3_possibilities(res, class_labels=None):
    """Return the (up to) three highest-scoring classes in ``res``.

    Args:
        res: 2-D score tensor; only row 0 is inspected, and softmax is taken
            over dim 1.
        class_labels: Sequence mapping a class index to its label. When
            omitted, the module-level ``labels`` loaded by the script is used,
            so existing one-argument calls keep working.

    Returns:
        A list of ``(label, percentage)`` tuples, highest score first, where
        ``percentage`` is the softmax probability scaled to 0-100.
    """
    if class_labels is None:
        # Backward-compatible fallback to the script-level global.
        class_labels = labels
    percentage = torch.nn.functional.softmax(res, dim=1)[0] * 100
    # topk avoids sorting every class when only three are needed; clamp k so
    # inputs with fewer than three classes do not raise.
    k = min(3, res.shape[1])
    _, indexes = torch.topk(res, k, dim=1)
    return [(class_labels[idx], percentage[idx].item())
            for idx in indexes[0].tolist()]


def predictions(torch_func, jit_func, img, labels):
    """Print the top-3 predictions of the eager model and the compiled module.

    Args:
        torch_func: Callable taking and returning torch tensors (eager model).
        jit_func: Callable taking a NumPy array and returning a NumPy array
            (the torch-mlir compiled entry point).
        img: Input tensor passed to both callables.
        labels: Class-index -> label sequence used to name the predictions.
    """
    golden_prediction = top3_possibilities(torch_func(img), labels)
    print("PyTorch prediction")
    print(golden_prediction)
    # The compiled module works on NumPy arrays, so convert at the boundary.
    prediction = top3_possibilities(
        torch.from_numpy(jit_func(img.numpy())), labels)
    print("torch-mlir prediction")
    print(prediction)
|
|
|
|
|
|
|
|
|
|
|
|
class ResNet18Module(torch.nn.Module):
    """Wrapper owning a torchvision ResNet-18 built with ``pretrained=True``.

    The wrapper is put into evaluation mode at construction time, and
    ``forward`` simply delegates to the wrapped network.
    """

    def __init__(self) -> None:
        super().__init__()
        self.resnet = models.resnet18(pretrained=True)
        # eval() is documented as equivalent to train(False); it applies to
        # this module and, recursively, to the nested resnet.
        self.eval()

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Run the wrapped ResNet-18 on ``img`` and return its output."""
        return self.resnet.forward(img)
|
|
|
|
|
|
|
|
|
|
|
|
class TestModule(torch.nn.Module):
    """Outer module that the script hands to ``torch.jit.script``.

    It owns one ``ResNet18Module`` under the attribute name ``s`` and forwards
    its input to it unchanged.
    """

    def __init__(self) -> None:
        super().__init__()
        # The attribute name ``s`` is part of the scripted module's structure.
        self.s = ResNet18Module()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the wrapped ResNet18Module's output for ``x``."""
        return self.s.forward(x)
|
|
|
|
|
|
|
|
|
2021-08-10 04:55:20 +08:00
|
|
|
import os
import sys
import tempfile

image_url = (
    "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
)

print("load image from " + image_url, file=sys.stderr)
img = load_and_preprocess_image(image_url)
labels = load_labels()

test_module = TestModule()
class_annotator = ClassAnnotator()
recursivescriptmodule = torch.jit.script(test_module)
# Save a serialized copy of the scripted module. Use the platform temp
# directory rather than a hard-coded "/tmp", which does not exist everywhere
# (e.g. Windows).
torch.jit.save(recursivescriptmodule,
               os.path.join(tempfile.gettempdir(), "foo.pt"))

# All annotator calls are keyed on the scripted module's class type.
module_type = recursivescriptmodule._c._type()
# Export nothing by default, then export only the `forward` method.
class_annotator.exportNone(module_type)
class_annotator.exportPath(module_type, ["forward"])
# One annotation per argument of forward: None for `self`, then the input.
# NOTE(review): -1 presumably marks a dynamic dimension and the trailing True
# presumably marks value semantics — confirm against ClassAnnotator docs.
class_annotator.annotateArgs(
    module_type,
    ["forward"],
    [
        None,
        ([-1, -1, -1, -1], torch.float32, True),
    ],
)
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
mb.import_module(recursivescriptmodule._c, class_annotator)

backend = refbackend.RefBackendLinalgOnTensorsBackend()
with mb.module.context:
    pm = PassManager.parse('torchscript-module-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline')
    pm.run(mb.module)

compiled = backend.compile(mb.module)
jit_module = backend.load(compiled)

predictions(test_module.forward, jit_module.forward, img, labels)
|