diff --git a/frontends/pytorch/examples/torchscript_resnet18_e2e.py b/frontends/pytorch/examples/torchscript_resnet18_e2e.py new file mode 100644 index 000000000..2778a1d9b --- /dev/null +++ b/frontends/pytorch/examples/torchscript_resnet18_e2e.py @@ -0,0 +1,106 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +from PIL import Image +import requests +import torch +import torchvision.models as models +from torchvision import transforms +import typing + +import torch_mlir + +import npcomp +from npcomp.compiler.pytorch.backend import refjit, frontend_lowering, iree +from npcomp.compiler.utils import logging + +logging.enable() + +mb = torch_mlir.ModuleBuilder() + +def load_and_preprocess_image(url: str): + img = Image.open(requests.get(url, stream=True).raw).convert("RGB") + # preprocessing pipeline + preprocess = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + img_preprocessed = preprocess(img) + return torch.unsqueeze(img_preprocessed, 0) + + +def load_labels(): + classes_text = requests.get( + "https://raw.githubusercontent.com/cathyzhyi/ml-data/main/imagenet-classes.txt", + stream=True, + ).text + labels = [line.strip() for line in classes_text.splitlines()] + return labels + + +def top3_possibilities(res): + _, indexes = torch.sort(res, descending=True) + percentage = torch.nn.functional.softmax(res, dim=1)[0] * 100 + top3 = [(labels[idx], percentage[idx].item()) for idx in indexes[0][:3]] + return top3 + + +def predictions(torch_func, jit_func, img, labels): + golden_prediction = top3_possibilities(torch_func(img)) + print("PyTorch prediction") + print(golden_prediction) + prediction = top3_possibilities(torch.from_numpy(jit_func(img))) + print("NPCOMP prediction") + print(prediction) + + +class ResNet18Module(torch.nn.Module): + def __init__(self): + super().__init__() + self.resnet = models.resnet18(pretrained=True) + self.train(False) + + def forward(self, img): + return self.resnet.forward(img) + + +class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.s = ResNet18Module() + + def forward(self, x): + return self.s.forward(x) + + +test_module = TestModule() +class_annotator = torch_mlir.ClassAnnotator() +recursivescriptmodule = torch.jit.script(test_module) +torch.jit.save(recursivescriptmodule, "/tmp/foo.pt") + +class_annotator.exportNone(recursivescriptmodule._c._type()) +class_annotator.exportPath(recursivescriptmodule._c._type(), ["forward"]) +class_annotator.annotateArgs( + recursivescriptmodule._c._type(), + ["forward"], + [None, ([-1, -1, -1, -1], torch.float32, True),], +) +# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule. +mb.import_module(recursivescriptmodule._c, class_annotator) + +backend = refjit.RefjitNpcompBackend() +compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module)) +jit_module = backend.load(compiled) + +image_url = ( + "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg" +) +print("load image from " + image_url) +img = load_and_preprocess_image(image_url) +labels = load_labels() +predictions(test_module.forward, jit_module.forward, img, labels)