torch-mlir/examples/resnet_inference.ipynb

475 lines
667 KiB
Plaintext
Raw Normal View History

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "0060537a",
"metadata": {},
"outputs": [],
"source": [
"# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n",
"# See https://llvm.org/LICENSE.txt for license information.\n",
"# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception"
]
},
{
"cell_type": "markdown",
"id": "0e6ebf55",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"### Configuring the Jupyter kernel\n",
"\n",
"We assume that you have followed the instructions for setting up torch-mlir. See [README.md](https://github.com/llvm/torch-mlir) if not.\n",
"\n",
"To run this notebook, you need to configure Jupyter to access the torch-mlir Python modules that are built as part of your development setup. An easy way to do this is to run the following command with the same Python (and shell) that is correctly set up and able to run the torch-mlir end-to-end tests with RefBackend:\n",
"\n",
"```shell\n",
"python -m ipykernel install --user --name=torch-mlir --env PYTHONPATH \"$PYTHONPATH\"\n",
"```\n",
"\n",
"You should then have an option in Jupyter to select this kernel for running this notebook.\n",
"\n",
"**TODO**: Make this notebook standalone and work based entirely on pip-installable packages.\n",
"\n",
"\n"
]
},
{
"cell_type": "markdown",
"id": "2aa20cfc",
"metadata": {},
"source": [
"### Additional dependencies for this notebook"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "ee25f979",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Defaulting to user installation because normal site-packages is not writeable\n",
"Requirement already satisfied: requests in /usr/lib/python3/dist-packages (2.25.1)\n",
"Requirement already satisfied: pillow in /usr/lib/python3/dist-packages (8.1.2)\n"
]
}
],
"source": [
"# Use the %pip magic so the packages are installed into the environment of the\n",
"# running kernel, not whichever `python` happens to be first on the shell PATH.\n",
"%pip install requests pillow"
]
},
{
"cell_type": "markdown",
"id": "c8c95904",
"metadata": {},
"source": [
"## Imports"
]
},
{
"cell_type": "markdown",
"id": "c4149213",
"metadata": {},
"source": [
"### torch-mlir imports"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "847868f0",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchvision\n",
"\n",
"from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder\n",
"from torch_mlir.dialects.torch.importer.jit_ir.torchscript_annotations import extract_annotations\n",
"from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export\n",
"\n",
"from torch_mlir.passmanager import PassManager\n",
"from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend"
]
},
{
"cell_type": "markdown",
"id": "91992cea",
"metadata": {},
"source": [
"### General dependencies"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a80963f8",
"metadata": {},
"outputs": [],
"source": [
"import requests\n",
"from PIL import Image"
]
},
{
"cell_type": "markdown",
"id": "322eaa75",
"metadata": {},
"source": [
"### Utilities"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "f4bdf926",
"metadata": {},
"outputs": [],
"source": [
"BACKEND = RefBackendLinalgOnTensorsBackend()\n",
"\n",
"def compile_module(program: torch.nn.Module):\n",
"    \"\"\"Compile a torch.nn.Module into a compiled artifact.\n",
"\n",
"    The module is scripted with TorchScript, imported into MLIR together with\n",
"    its annotations, lowered with the linalg-on-tensors backend pipeline and\n",
"    finally handed to the RefBackend instance stored in `BACKEND`.\n",
"\n",
"    This artifact is suitable for inclusion in a user's application. It only\n",
"    depends on the RefBackend runtime.\n",
"    \"\"\"\n",
"    # 1. Script the program.\n",
"    ts_module = torch.jit.script(program)\n",
"\n",
"    # 2. Extract annotations.\n",
"    annotator = ClassAnnotator()\n",
"    extract_annotations(program, ts_module, annotator)\n",
"\n",
"    # 3. Import the TorchScript module into MLIR.\n",
"    builder = ModuleBuilder()\n",
"    builder.import_module(ts_module._c, annotator)\n",
"    mlir_module = builder.module\n",
"\n",
"    # 4. Lower the MLIR from TorchScript to RefBackend, passing through linalg-on-tensors.\n",
"    pipeline = PassManager.parse(\n",
"        'torchscript-module-to-linalg-on-tensors-backend-pipeline',\n",
"        mlir_module.context,\n",
"    )\n",
"    pipeline.run(mlir_module)\n",
"\n",
"    # 5. Invoke RefBackend to compile to compiled artifact form.\n",
"    return BACKEND.compile(mlir_module)"
]
},
{
"cell_type": "markdown",
"id": "71cc1403",
"metadata": {},
"source": [
"## Basic tanh module"
]
},
{
"cell_type": "markdown",
"id": "269f3dc5",
"metadata": {},
"source": [
"A simple tiny module that is easier to understand and look at than a full ResNet."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "aed1869b",
"metadata": {},
"outputs": [],
"source": [
"class TanhModule(torch.nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
"\n",
" # The `export` annotation controls which parts of the model the torch-mlir\n",
" # compiler should assume are externally accessible. By default,\n",
" # the torch-mlir compiler will only export the explicitly exported functions.\n",
" # NOTE: This is different from `torch.jit.export`. The `torch.jit.export`\n",
" # decorator controls which methods of the original torch.nn.Module get\n",
" # compiled into TorchScript. This decorator\n",
" # (`torch_mlir_e2e_test.torchscript.annotations.export`) controls which TorchScript\n",
" # methods are compiled by torch-mlir. \n",
" @export\n",
" # The `annotate_args` annotation provides metadata to the torch-mlir compiler\n",
" # regarding the constraints on arguments. The value `None` means that\n",
" # no additional information is provided for that argument. Otherwise,\n",
" # it is a 3-tuple specifying the shape (`-1` for unknown extent along a dimension),\n",
" # along with the dtype and whether the tensor has\n",
" # value semantics (this would be False if you need to mutate an input\n",
" # tensor in-place).\n",
" @annotate_args([\n",
" None,\n",
" ([-1], torch.float32, True)\n",
" ])\n",
" def forward(self, a):\n",
" return torch.tanh(a)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "71a421a0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([-0.7615941, 0.7615941, 0. ], dtype=float32)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Create the module and compile it.\n",
"compiled = compile_module(TanhModule())\n",
"# Loads the compiled artifact into the runtime\n",
"jit_module = BACKEND.load(compiled)\n",
"# Run it!\n",
"jit_module.forward(torch.tensor([-1.0, 1.0, 0.0]).numpy())"
]
},
{
"cell_type": "markdown",
"id": "b7f8773b",
"metadata": {},
"source": [
"## ResNet Inference"
]
},
{
"cell_type": "markdown",
"id": "dd226e86",
"metadata": {},
"source": [
"Do some one-time preparation."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "88e8a5c3",
"metadata": {},
"outputs": [],
"source": [
"def _load_labels():\n",
"    \"\"\"Download the ImageNet class names, one label per line of the file.\n",
"\n",
"    Returns:\n",
"        A list of stripped label strings in file order. The list is later\n",
"        indexed by predicted class index, so the order must not be changed.\n",
"\n",
"    Raises:\n",
"        requests.HTTPError: If the server answers with an error status.\n",
"        requests.Timeout: If the server does not respond in time.\n",
"    \"\"\"\n",
"    response = requests.get(\n",
"        \"https://raw.githubusercontent.com/cathyzhyi/ml-data/main/imagenet-classes.txt\",\n",
"        # Never hang the notebook forever on a stalled connection.\n",
"        timeout=30,\n",
"    )\n",
"    # Fail loudly instead of silently turning an error page into bogus labels.\n",
"    response.raise_for_status()\n",
"    return [line.strip() for line in response.text.splitlines()]\n",
"\n",
"IMAGENET_LABELS = _load_labels()\n",
"\n",
"def _get_preprocess_transforms():\n",
"    \"\"\"Build the torchvision pipeline used to preprocess input images.\n",
"\n",
"    The steps follow the preprocessing specification at\n",
"    https://pytorch.org/vision/stable/models.html\n",
"    \"\"\"\n",
"    transforms = torchvision.transforms\n",
"    # Per-channel normalization constants taken from the specification above.\n",
"    channel_mean = [0.485, 0.456, 0.406]\n",
"    channel_std = [0.229, 0.224, 0.225]\n",
"    steps = [\n",
"        transforms.Resize(256),\n",
"        transforms.CenterCrop(224),\n",
"        transforms.ToTensor(),\n",
"        transforms.Normalize(mean=channel_mean, std=channel_std),\n",
"    ]\n",
"    return transforms.Compose(steps)\n",
"\n",
"PREPROCESS_TRANSFORMS = _get_preprocess_transforms()"
]
},
{
"cell_type": "markdown",
"id": "d42e8a1b",
"metadata": {},
"source": [
"Define some helper functions."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "abd53b61",
"metadata": {},
"outputs": [],
"source": [
"def fetch_image(url: str) -> Image.Image:\n",
"    \"\"\"Download the image at `url` and return it as an RGB PIL image.\n",
"\n",
"    Raises:\n",
"        requests.HTTPError: If the server answers with an error status.\n",
"        requests.Timeout: If the server does not respond in time.\n",
"    \"\"\"\n",
"    # Use some \"realistic\" User-Agent so that we aren't mistaken for being a scraper.\n",
"    headers = {\"User-Agent\": \"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.131 Safari/537.36\"}\n",
"    # The with-block guarantees the streamed connection is released.\n",
"    with requests.get(url, headers=headers, stream=True, timeout=30) as response:\n",
"        # Surface HTTP errors here rather than as an opaque PIL decoding failure.\n",
"        response.raise_for_status()\n",
"        # The raw stream is not content-decoded (e.g. gzip) unless asked to.\n",
"        response.raw.decode_content = True\n",
"        # convert() loads the whole image, so nothing is read after the\n",
"        # connection is closed at the end of the with-block.\n",
"        return Image.open(response.raw).convert(\"RGB\")\n",
"\n",
"def preprocess_image(img: Image.Image):\n",
"    \"\"\"Apply PREPROCESS_TRANSFORMS to `img` and add a leading batch dimension.\"\"\"\n",
"    return torch.unsqueeze(PREPROCESS_TRANSFORMS(img), 0)"
]
},
{
"cell_type": "markdown",
"id": "dbec8e96",
"metadata": {},
"source": [
"### Fetch our sample image."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "8270ef94",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAk4AAAHgCAIAAADcxXWhAAEAAElEQVR4nKz92Y5kSZIlCBIRL3eVRTdTM3M391g8q6o7uqsxQKF+ax7nJ+azBpipt8EUenKJCM+IcHczU9NFtrvxQkTzwCLqlp3VtQBzkfDQVBOV5QozE9E5hw7h//X/9n9PKYQQUl6Mxbp2TV1bS1dXV99995vbmztDVQr89PTy80+f94cvZCdjwTmP4FSobft3bz/c3983TTPNw6dPP/3tpz+HOH347v7f/2//yw+//d1tfwsKqqAEMabTMuWclZAFu3blK6sKHGGZYbebd/uHu7t2e9WtVp0qpCWmlMgAEfVNA5eL4NcL4b92iYiqGmMA4HA4fHp8+X/8v/7f//nv/+npeU9EkbO1Zr3urYMP3713zhDI6TjuH4/TGCvXdF3X9vW3H96/fXPnvXeITdNYg8ycc44xhphDTirEKjFxzpljGsdxt9uN4wgA1lpVjTGGEKqqatu2rmvnnDHGGENEh5ddzjmllHNWVSIq/1pVFQAgIhG9/hcRhyksy7IsS0qJU8w5A4gxpm2qN2/e3NxcOeeIqK7rpvJARgROw/T8/Pz09HQ4HEIIIgIAkjnnzMzlTVprEVGAwRmy2DRd3/dt1SFiCCmEkGMKcSai6+v1u3fvtlcr7z0RGTCPj0+PX55FpO/XVVUtSzydTuM4Hg4HgHL/0Vob09K27d32+u7u7up6W1VVzvHz58+//PLTNE2q+ubNm2+++XB1dWWtn6bp+ellv98/7w8558yac55DWmIwxvqqurrejuM4DMeYFmUhgso7X7vbN3ff/+77m+2N9z7M8XA4jMPy+PjoXfvNN998++G7br0SgZf97p//+tPPv/xl9/xlnk+b9fru7m7ddNfX1+t+FUI4HQ7e+7bv+743ziKir6u2qTtfb9Zt31dEIAqqjIiAIqoGTVmbgTXGyCyqOh1GY4z3tq5r772qLmEOIfR9b4wBkJwzAFRVZQgBwPy6dAEQECFniDFNS6yqyntLBIiACAogArvn8NPffvnP//v/549/+sfT6bTdbn7z+998++233//2d9ZaUByG4en5cDweVch5G/IMKMy82+12T885ZxEZx/Hnv/203+/HcVRVZo4xOmObpmm9//777//tD//m7du363ZLRMu4nE6n0zQCABCSMUg6huVw2A3DMA6DI2OMUVVOICJEZK1FBZEsKc/LeDod5nl21tRts1qt1tvtarMmopTyPM/7w+l4PA5TVFVVAABRNMZ0Xdeueu/atu/atlbVGJecc+YowkQU4jiO4/G4H8ZTSgkACLBFyynHGEWzMcYYTJyXZTHGLMtCaK21UP4LNIUlQlaE9Xq73W4d1SIS5nA6nabTUNd1ZYx1dHt7+9vffv/2/s577yq32+0eHp+enp6eng+HwyEkIaKbm7uu697c3L579+7u9rrv+8paRWuq3lc1Ih5Pu+fHp5fdl/3+ZRgPHMM0n3bPLy+7p2WZiMhbUsQQuOvXbdutV9sffvh3/+v/+oe7u5vMIabh46ef/vjHP/70019Pp1EYiAwiWpAQ52lZcs5Etq7rqm6dc4LYNI2z1bIswzTGGHMWESFQ731fN845C1huGiJyTERUvru6rquqEoSUuK3aFHNcwjzPYZ5zzrWvmqax1hIREEo5XwAS5znPQz7sx8O4H0DNm6v7f/dv//Af/i//4e9++Ld13QFo4pzSnDjHNO+Px8NxN04LABDBp0+fHp8evnz5stvtvvvuu+NhSCmN47Lf74fTpKptu24aJ5ROw25esrWWjBcBBONc1dSttdYR1d5v16ubqyt7Xk0AiAgAIsrMiLparVar1Xq9RTAzhhItRGQeZ+chRQGI1lRdZ4hIVadp+vL45ZdffhnH+fpm/c0339zd3TVNAwiggAYAwDlnk89JcmKyNiVWsERQYiGAEGF5tpxFVUXEWusr68zrCfDfe5U9Vp
6t/Kbrululv/s3P8xJmubTvIRxHAXVOecrenp6Wq/76+36/fv37+++HU7L/uV4OBxiXqraGYSmaWprQwhtUxFRiXbnaKGAhN575xz4qgQba+00TeUxKSUiyjkvy4KIZVmoatmQAFA+rF6u8s5fvxdzuYioqioRSSnFGJmZmUsgaZrGGMPMiOi9N8ZYaxUpxlACrYh47xGxvGeDZK3lyxVCAAAgJXQGqbznRAkRy2spS7mT5W2HkEQEEZ+/PE/TPI5TznmaFmOMKopIuT8l1AFQefPGmDmG8iZfMwBVdc7d3NxcX1+v1+uqqpi1vIHyGYmodk5VFZeYUwghxLiEOcaYUgAUS8Y51zRNt+pLJtE0ze3tbeXqYRgev7zknF+ej7vdzjq/TdG5ap7nEEJK6XA4eE+bzebNmzd91XjvQwj7/d5b673vuq6EOhExzhpj2rZ1rlIFZlAAohKhAFAVFEBFhQCcc5VFROyrpnyJ1gIiACBSY611zlkDAFTWkqqKoqpeng4Qz3kcIpTbparMmpIAEBGKQkp8PB5//PHHP/7xjyGGb7/99ttvv3nz7s1ms9lut+t1ZQl2+9USUwihrtvtdpskPL88Pj4+ltxru92WLDCFyMzjOKYcrXFV7XNMh+N+n/nx8fGf/v4f7+/vf/Pdb795923TdCISQhBQRATSLDKH6TSNyzLGGBnKSSIgiGistQDQNe08pxDn8o1ba601ROSqyjlXHlNWSEm8rq6aGPOyhGma5iWW9d80zfqm93VlLZXUMHNMKaUURbKx2La189d1U51Op2VZNOaoiRSMMahKBGjIoq3ruqwrBERjEcpBgarKwll4GAYiqmwu+VzZXwCQVSxQzvnl5SXFxRgjIC/7/ePj08t+N40xq7RN3/adAviq8k2tyvOyIMJMlDKDHZxviHAcT8fTfpymkEJmtt45rZqubWMHpCIiwsJqrI0xcgZQ+/nzZ2vp5587QI5peHp++Pnnvz0/P4eQENGQR8Smdl9lzK7kzUQUYjTGCEP54FVVeY+qOg0nEcnCRgxaZ42xZBBxWEJZcmXtIaIxBEAxxmUOyzTP8xyXBQCcsapaVRUaUtWYMzCzStnCS1hEpPKN91Xf9wDwstv99a8/3d+/DSFMy5RzdJW1lsrB9fj08vDwcDjsvnz5Ms1Dznkcx4eHBxWMMY7jHGMsB1HOcVnYVMBZc0wpsSFGY5va13Xd972zFkRAZJoWgp0t+6qk54AiIjlnRBdjLvmmtd45KevPew9U+YoQHLM6VzVN470X0XEc9vv96XQCgO12e3v7Zr1ee+sBABBUSrYAxlgiAwzONoAGEZwDVkg5pTwDSAhzvSBbh4gIYIgsGfyXldz5LPjvDnvlq7LWbrer77//fpgjALzsDkTIzKvVql/V//5/+8PNzdV2tRaBMITHL7sf//i3eQpLGo7Ho7dmvV5D0wCANdg0TcldSr1FaMkaYz0Raea+79u29d4/Pz+fTqeUUll8JewhYlVVzjlVLf9veRLn3GuQ/jX2AJSdVj4CIlZVXUJdOTJEpJwU3nsAKLGhFLKSEyuM43w8jdM0lVDnvS8BpjwbiMYYS42oqgCqzIpyjtCXUMfMIIqIzrkSyFUVRY2hruvatlv16Xg87vfH4/Goes5XQggAAgCqxhgDCETUtm276puuU1WeVQnrrnXO3b97t1qtutWKiJZhnpblOA6H4ZQzlw+IiD5LCQwp5xBBVQGEzPkuee+bplmv103dOee8933XV1XlTc0pD6d/OhwOwzitdi9dt5qWebfbzfPc9/3V1erDhw8fPnzofK2qcQlExCnVdX2uTphFhFWE84toSk3XN1XlnAUAUIAsYomy5ByjiCAassYaRwT4r9aps0jkL5kMWEOvkazUMYioAHpZ4URwvuGIAKiKmSXnfDwed7v9H//pzz/99Iuq3t7e3t3dtW1rra2qap5HALDGzfNcjjZVHYajq92yLPMwSspN09zc3BhDVeWH0/
H27sY6w5y8rZjT8Xg8HA6VrZZlOQyn4/H4l7/8rWvazeZqtVpdXV01TbPZbNq+8waU8hRMtoY5iaiIlPym8r723tq
"text/plain": [
"<PIL.Image.Image image mode=RGB size=590x480 at 0x7FC9CE7278E0>"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"img = fetch_image(\"https://upload.wikimedia.org/wikipedia/commons/thumb/3/31/Red_Smooth_Saluki.jpg/590px-Red_Smooth_Saluki.jpg\")\n",
"img_preprocessed = preprocess_image(img)\n",
"\n",
"img"
]
},
{
"cell_type": "markdown",
"id": "55852670",
"metadata": {},
"source": [
"### Define the module and compile it"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "3af95ee7",
"metadata": {},
"outputs": [],
"source": [
"class ResNet18Module(torch.nn.Module):\n",
"    \"\"\"Wraps torchvision's pretrained ResNet-18 so torch-mlir can compile it.\"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.resnet = torchvision.models.resnet18(pretrained=True)\n",
"        # Inference only: switch this module and its children out of training mode.\n",
"        self.eval()\n",
"\n",
"    # Export `forward` and describe its tensor argument as a 1x3x224x224 float32\n",
"    # tensor with value semantics. The leading `None` gives no extra information\n",
"    # for the first argument (presumably `self`).\n",
"    @export\n",
"    @annotate_args([\n",
"        None,\n",
"        ([1, 3, 224, 224], torch.float32, True),\n",
"    ])\n",
"    def forward(self, img):\n",
"        return self.resnet.forward(img)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "89cbebf6",
"metadata": {},
"outputs": [],
"source": [
"# Create the module and compile it.\n",
"compiled = compile_module(ResNet18Module())\n",
"# Load it for in-process execution.\n",
"jit_module = BACKEND.load(compiled)"
]
},
{
"cell_type": "markdown",
"id": "2890c8c6",
"metadata": {},
"source": [
"### Execute the classification!"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "23476f38",
"metadata": {},
"outputs": [],
"source": [
"logits = torch.from_numpy(jit_module.forward(img_preprocessed.numpy()))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "a826c789",
"metadata": {},
"outputs": [],
"source": [
"# torch-mlir doesn't currently support these final postprocessing operations, so perform them in Torch.\n",
"def top3_possibilities(logits):\n",
"    \"\"\"Return the three highest-scoring labels with their softmax scores in percent.\n",
"\n",
"    `logits` is indexed as (batch, class); only the first batch entry is used.\n",
"    The result is a list of (label, percentage) tuples, best match first.\n",
"    \"\"\"\n",
"    percentages = torch.nn.functional.softmax(logits, dim=1)[0] * 100\n",
"    ranked_classes = torch.sort(logits, descending=True).indices[0]\n",
"    results = []\n",
"    for class_idx in ranked_classes[:3]:\n",
"        results.append((IMAGENET_LABELS[class_idx], percentages[class_idx].item()))\n",
"    return results"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "355ca74e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('Saluki, gazelle hound', 74.8702163696289),\n",
" ('Ibizan hound, Ibizan Podenco', 18.07537841796875),\n",
" ('whippet', 6.3394775390625)]"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"top3_possibilities(logits)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5a622c8f",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"interpreter": {
"hash": "767d51c1340bd893661ea55ea3124f6de3c7a262a8b4abca0554b478b1e2ff90"
},
"kernelspec": {
"display_name": "torch-mlir",
"language": "python",
"name": "torch-mlir"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}