{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "4b063b1e",
"metadata": {},
"outputs": [],
"source": [
"# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n",
"# See https://llvm.org/LICENSE.txt for license information.\n",
"# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception"
]
},
{
"cell_type": "markdown",
"id": "bd0ab562",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"### Configuring jupyter kernel.\n",
"\n",
"We assume that you have followed the instructions for setting up npcomp. See [README.md](https://github.com/llvm/torch-mlir) if not.\n",
"\n",
"To run this notebook, you need to configure jupyter to access the npcomp Python modules that are built as part of your development setup. An easy way to do this is to run the following command with the same Python (and shell) that is correctly set up and able to run the npcomp end-to-end tests with refbackend:\n",
"\n",
"```shell\n",
"python -m ipykernel install --user --name=npcomp --env PYTHONPATH \"$PYTHONPATH\"\n",
"```\n",
"\n",
"You should then have an option in jupyter to select this kernel for running this notebook.\n",
"\n",
"**TODO**: Make this notebook standalone and work based entirely on pip-installable packages.\n",
"\n",
"\n"
]
},
{
"cell_type": "markdown",
"id": "d6d92319",
"metadata": {},
"source": [
"### Additional dependencies for this notebook"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b992f23",
"metadata": {},
"outputs": [],
"source": [
"!python -m pip install requests pillow"
]
},
{
"cell_type": "markdown",
"id": "10788b70",
"metadata": {},
"source": [
"## Imports"
]
},
{
"cell_type": "markdown",
"id": "0c7ee78e",
"metadata": {},
"source": [
"### Npcomp imports"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6b9ba793",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchvision\n",
"\n",
"from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder\n",
"from torch_mlir.dialects.torch.importer.jit_ir.torchscript_annotations import extract_annotations\n",
"from npcomp_torchscript.annotations import annotate_args, export\n",
"\n",
"import npcomp\n",
"from npcomp.passmanager import PassManager\n",
"from npcomp.compiler.pytorch.backend import refbackend"
]
},
{
"cell_type": "markdown",
"id": "c4977302",
"metadata": {},
"source": [
"### General dependencies"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b1b42df6",
"metadata": {},
"outputs": [],
"source": [
"import requests\n",
"from PIL import Image"
]
},
{
"cell_type": "markdown",
"id": "cfa688f3",
"metadata": {},
"source": [
"### Utilities"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "2da1e08c",
"metadata": {},
"outputs": [],
"source": [
"def compile_module(program: torch.nn.Module):\n",
" \"\"\"Compiles a torch.nn.Module into a compiled artifact.\n",
" \n",
" This artifact is suitable for inclusion in a user's application. It only\n",
" depends on the refbackend runtime.\n",
" \"\"\"\n",
" ## Script the program.\n",
" scripted = torch.jit.script(program)\n",
"\n",
" ## Extract annotations.\n",
" class_annotator = ClassAnnotator()\n",
" extract_annotations(program, scripted, class_annotator)\n",
"\n",
" ## Import the TorchScript module into MLIR.\n",
" mb = ModuleBuilder()\n",
" mb.import_module(scripted._c, class_annotator)\n",
"\n",
" ## Lower the MLIR from TorchScript to refbackend, passing through npcomp's backend contract.\n",
" with npcomp.ir.Context() as ctx:\n",
" npcomp.register_all_dialects(ctx)\n",
" lowered_mlir_module = npcomp.ir.Module.parse(str(mb.module))\n",
" pm = PassManager.parse('torchscript-to-npcomp-backend-pipeline')\n",
" pm.run(lowered_mlir_module)\n",
"\n",
" ## Invoke refbackend to compile to compiled artifact form.\n",
" backend = refbackend.RefBackendNpcompBackend()\n",
" return backend.compile(lowered_mlir_module)"
]
},
{
"cell_type": "markdown",
"id": "5e2f0d75",
"metadata": {},
"source": [
"## Basic tanh module"
]
},
{
"cell_type": "markdown",
"id": "2e9551a9",
"metadata": {},
"source": [
"A simple tiny module that is easier to understand and look at than a full ResNet."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "e06d3c25",
"metadata": {},
"outputs": [],
"source": [
"class TanhModule(torch.nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
"\n",
" # The `export` annotation controls which parts of the model the npcomp\n",
" # compiler should assume are externally accessible. By default,\n",
" # the npcomp compiler will only export the explicitly exported functions.\n",
" # NOTE: This is different from `torch.jit.export`. The `torch.jit.export`\n",
" # decorator controls which methods of the original torch.nn.Module get\n",
" # compiled into TorchScript. This decorator\n",
" # (`npcomp_torchscript.annotations.export`) controls which TorchScript\n",
" # methods are compiled by npcomp. \n",
" @export\n",
" # The `annotate_args` annotation provides metadata to the npcomp compiler\n",
" # regarding the constraints on arguments. The value `None` means that\n",
" # no additional information is provided for that argument. Otherwise,\n",
" # it is a 3-tuple specifying the shape (`-1` for unknown extent along a dimension),\n",
" # along with the dtype and whether the tensor has\n",
" # value semantics (this would be False if you need to mutate an input\n",
" # tensor in-place).\n",
" @annotate_args([\n",
" None,\n",
" ([-1], torch.float32, True)\n",
" ])\n",
" def forward(self, a):\n",
" return torch.tanh(a)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "c7d8c369",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([-0.7615941, 0.7615941, 0. ], dtype=float32)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Create the module and compile it.\n",
"compiled = compile_module(TanhModule())\n",
"# Loads the compiled artifact into the runtime\n",
"jit_module = refbackend.RefBackendNpcompBackend().load(compiled)\n",
"# Run it!\n",
"jit_module.forward(torch.tensor([-1.0, 1.0, 0.0]).numpy())"
]
},
{
"cell_type": "markdown",
"id": "fdecb7c8",
"metadata": {},
"source": [
"## ResNet Inference"
]
},
{
"cell_type": "markdown",
"id": "7fc20f95",
"metadata": {},
"source": [
"Do some one-time preparation."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "5a5f6ef2",
"metadata": {},
"outputs": [],
"source": [
"def _load_labels():\n",
" classes_text = requests.get(\n",
" \"https://raw.githubusercontent.com/cathyzhyi/ml-data/main/imagenet-classes.txt\",\n",
" stream=True,\n",
" ).text\n",
" labels = [line.strip() for line in classes_text.splitlines()]\n",
" return labels\n",
"IMAGENET_LABELS = _load_labels()\n",
"\n",
"def _get_preprocess_transforms():\n",
" # See preprocessing specification at: https://pytorch.org/vision/stable/models.html\n",
" T = torchvision.transforms\n",
" return T.Compose(\n",
" [\n",
" T.Resize(256),\n",
" T.CenterCrop(224),\n",
" T.ToTensor(),\n",
" T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
" ]\n",
" )\n",
"PREPROCESS_TRANSFORMS = _get_preprocess_transforms()"
]
},
{
"cell_type": "markdown",
"id": "8e31f4c4",
"metadata": {},
"source": [
"Define some helper functions."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "acf648ce",
"metadata": {},
"outputs": [],
"source": [
"def fetch_image(url: str):\n",
" # Use some \"realistic\" User-Agent so that we aren't mistaken for being a scraper.\n",
" headers = {\"User-Agent\": \"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.131 Safari/537.36\"}\n",
" return Image.open(requests.get(url, headers=headers, stream=True).raw).convert(\"RGB\")\n",
"\n",
"def preprocess_image(img: Image):\n",
" # Preprocess and add a batch dimension.\n",
" return torch.unsqueeze(PREPROCESS_TRANSFORMS(img), 0)"
]
},
{
"cell_type": "markdown",
"id": "903353e2",
"metadata": {},
"source": [
"### Fetch our sample image."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "45e6988a",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAk4AAAHgCAIAAADcxXWhAAEAAElEQVR4nKT919JkWZIeirn7EluF+nWKEt1dNQOc08QcoxEGPhYv+RJ8Jt6Dl4ARp2d6qqars6oy//xFqK2WcueFR0RlD0AQ4NlWlpYVGWKLtVx8/vnn+H/7v/8/UgohzikFY6Wuq7rxztirq6tvvvnd7c2dpSqG8vz8+vOHT7vdjqonY4yzNYITprZdvX3z/uHhTdPU49R//Pjhrx9+CHH8+puHf/jf/viHP/zuvn0AABEAgpDCcRpjyYCGGdpuVXkrDCXBPMF2O21f93dvymazWS47EUhzTDkQEREtmgb+9iAAAED47x3MLCLGGADY7/efP3/+f/6//vRf/st/eXnZIplYsrV2teqMh2++fuccEsDxMLw+7+c+edcs2nW7oq++evfm/sFXziB1dWWNKaXknGOMIeaQMgMyc0wl51xiGoZhu90OwwAA1loRiTGGEKqqatu2rmvnnDHGGENE+9dtzjmllHMWESLSf62qCgAQkYgufyJiP4Z5nud5TimVFHPOAGyMaZrm4f725ubGOUNEdV03lSeixKbv+5eXl+fn5/1+H0JgZgAopZSUSyl6ktZaRBQRrsQY0zTdYrFo6o7AhBBCCCmlOAcycH29efvufrNZee8tAYJ/enp6+vzCzIvFqqqaeZ6Px+MwDPv9HgBEijForY1pbtv2bnN9d3d3db2pqirn+OnTr7/88ss4jiJyf3///v37q6sba/04ji/Pr7vdbrvdx5JLllTyFNIcgzHWV9XV9WYYhr4/xDRLyQbJV66qqqs3V7///e9vNjfe+zDF/e4wDMPT04t39fv3X3/19Tfdal2EX7b7n3766eeff96+/nWapvVqdXd3t2y7m6vr1WIZQjju9977drFYLBbGWUT0ddU0zdq51WrRLisiYAGRgoiCLCIGja7KwCWGXAqLyHAYjDHe+7quvbciMs9zCPNisTDGAHLOGQCqqjKIAGAvS7cAECBCzhBjGue5qhrvLREgAiIIADNsn9OHDx/+8//7P/35z38+9vvNZv2773731Vdfffv731nrQfA4jC/P28OhFxFnfeBen/52u90+v+ScmXkYhp//+mG32w3DICKllBijM7Zpmtb7b7/93b/5/u/ePrxbLFYGaR70+Y6MAoRkjJBMYdzv94ehD4de17YI5pylABE5YwGAOXPK0zwcj/tpHp2xTVO1q/X6arNarYhMSmkaw+6wPxz6fhxFUAQEQASNMV3XLRYLU7Vd17VtK1JijDnHXCJzIYIQwjAeD4ddPxxTSgCAKCuucklzjCzZGGMM5pynMBtj5nkmtNZ6QLLWAtA8zxNlAFqtVpv1tTOWGcIUhsOx7/umrq013tLt7e0ffvftw8ODr6yr3Ha7ffr8/Pnl9fllu98fQ8pEdHNz13Xd/c3t27dv725uV4vWOYeIUF1VVYUox8P++fnz6/bzYfvaD/uUwjT2r68vu9fneR6JyDpCxDGUxWLVtt1qef3993//7/7dv7u7u8olxNT/+vHDn//8jx8+fDgeBy5A5BCMpTmEMM9zzJnI1nVd1a1zDgCrtnHOzVPsxyHGmDMzM4HU3jdN460zgAhAgAYwpWSQ0JC1tq7rqqoYIaXU+mVKKYQwj1OYppxz7aumaay1RASEzMwAgqAW8hU+H4/H434kNrfXD//r3//x//J//g9///2/qesaAHLOMYdU4pymw+GwO2zHcQYAIvj48ePT09Pnp0/b7fabb7457PuU0jBMu92+P44i2LZt03RAx31/nKdonCfjmQHBOFc1TeOMtWRq7zer5c3VlRUR3VqICCDMXEohwOVyuVwuV6sVgRsxqLdg5mkYnK2SZRBrTdV1SEQiMo7j56fPv/zyyzAM1zfr9+/f393dN013+nICQDCusinHwjmzMTalAmKJAAREAIDJgH4bJxYRZrbWeu+9Mf+1k/vvH8
ysPvJygV3X3d7e/t3ffRdjbJqP0xyGYWAE55z39Pz8vFq115vNu3fv3ty/Hw9h+3o47IbIqaoqg9Q0jXc2zb5tGiJSb5dzLqUIEiJ6751z4CvnHBFZa8dx1PeklIgo5zzPMyLqshAR3ZAAoBcr50PP/PJczPkgoqqqmDmlFGMspZRyciRN0xhjSimI4r03xqj3mmJWR8vM3ntE1HNWA1TORwjhtAYM6hJMKTlKiKff0tMDQD1t/UKD8vT54ziOwzDmnMdxNsbptej9UVcHQHryxpgQQilF3WqMMYQkIs65m5ub6+vr1WpTVVUpoieg10hErnYeRHCOOYUQQoxzmGKMKQVAtoTeu6ZpFotFXdfOVk3T3N7eVq7uj8PT01PO/Pqy22631vlNyta7aZrUf+/3e+/9er2+v7/v6qZyPoSw2+28td77k3l1lpmNs8aYumuNdyJQCggAkSEEAbiYZRYmgco6sIBIi6rRG2stIAIAEtbOGu+cNQBABqmUAixCKCJA+OWWAVDHhrowShFdMkgoAimlw+Hww7/8+Oc//zmE8NVXX3311fv7t/fr9Xqz2axWjSV43a/DnEJIdV1v1lcJxpeXl6enJ429NpuNRoEpxFLKMAwpR2ttXfmU0v6w2+Xy9PT0pz/96e3929998+37t181TQeFpzifzpYkM09hHIZhmoYcZ0QEIGYWQQJjrQWArmmnKWnYJCLWWmMtkfV15Zyz1gIgEZE1GnhdXV3FmOc5DOM4z7MxTuO5q9WiqiprKecsUnKJKaWUInM2Ftu2dc7UTXU8Hud5zjmGHAHAGIOnbYXG2RprXVeAiIYQ1FCguvlSct/3hLZy3vta93LTNAjAzACUc359fY0xGosM/LrbPT09v273/TRzgabr2q4TAF9VvqkLlCmMSEKEnHJ2wXtvEIfxuD/uhnGcU0hcjHeeq6Zr59gJCTOzFC7FWhtjLIVB7KdPn6y1P//cAJaY+ueXx59//uvLy1MICcEZYxFNV+sq0ojZadxMRCFEjIaZS2E1I96jiIz9MTOXUpiMs84aY8kYwBACgCCIrj1ENIYAYE4xTvM0TfM4xXkGAGesiFRVRUQMknPmUpg55xxymvPMDFVV1bZZLBYAsN1uf/rpp4eHNyGEcR5zjq6y5E4m9Pn59fHxcb/ffv78eRzHXOIwDI+Pj8IYYxyGKcaohkjNqatKyZJSilkMFTS2qX1d14vFwhkHzMA8jiNdgkg1Q4DIzDkxQYkx55wByFrrHFvj1esAN95XCKYUdK5q6s77SgPD3W53PB4BYLPZ3N7erlYrbzyImnIABEQw5AgLSHG2AjCI4BwUgZRTyhNgTDMn55N1iAgABsmS+R/3cP/NQx+VtXaz2fz+WxumCQBet3siSpyXy+ViWf/D//bHm5urzWohmcYxvDzuf/zzX+bxw5zGw+HgrVmtVl3bkIA1pmkazY1O+ZaxxhhjPRFJLovFom1b7/3Ly8vxeEwpqfdSt4eIVVU550RE/1e/xDl3cdIX36OPRlNDfWdV1erq1GQws1oK7z0AxBhLQU1kOScROQzxeDyO46iuznuvng8RCU7+RnNE/QkpoK4u55zOrk5tMZI459WRiyCyoEUNsZeLdDgcdruDZg/qyM+XwCLGGAMIRFR3bbtcNN1ChMskQlh3rXPu4e3b5XLZLZdENPfTOIfD0O/7Y8nFeGetBUKfWYOMlHOIICIaHuld8t43TbNabpqmcc557xfdovK1c1VJ3B/H/X7fD+Ny+9oultM0bbfbaZoWi8XV1dXXX3/99ddft1UNLHEORFRSquvaGKPmj5mLcCllV3JKXbtoqso5CwAgAJnZEmXOKUVm1uDEGkcE9F/hDs4ikT9HMmANXTyZiIho0AlyRiyI4HTDEQFQhEspKfPhcNhut3/+p798+PBBRG7vru/u7uqutdZWVTVNk27eeRzIQFU5EemHo2twnudhGDjlpmlubm6MMVXl++Ph9u7GOlNKcc5wLo
fDYb/fV7bSNL3fH/7yl790zWKzWi+X66urq7pt1utlu+i8AaE8BmOTDZxFhItofOO8rypvrdfwTiNLY5CZBRiJpmn
"text/plain": [
"<PIL.Image.Image image mode=RGB size=590x480 at 0x7F3A1DCF95E0>"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"img = fetch_image(\"https://upload.wikimedia.org/wikipedia/commons/thumb/3/31/Red_Smooth_Saluki.jpg/590px-Red_Smooth_Saluki.jpg\")\n",
"img_preprocessed = preprocess_image(img)\n",
"\n",
"img"
]
},
{
"cell_type": "markdown",
"id": "19f53027",
"metadata": {},
"source": [
"### Define the module and compile it"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "5d1fc533",
"metadata": {},
"outputs": [],
"source": [
"class ResNet18Module(torch.nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.resnet = torchvision.models.resnet18(pretrained=True)\n",
" self.train(False)\n",
" @export\n",
" @annotate_args([\n",
" None,\n",
" ([1, 3, 224, 224], torch.float32, True),\n",
" ])\n",
" def forward(self, img):\n",
" return self.resnet.forward(img)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "16ce5fad",
"metadata": {},
"outputs": [],
"source": [
"# Create the module and compile it.\n",
"compiled = compile_module(ResNet18Module())\n",
"# Load it for in-process execution.\n",
"jit_module = refbackend.RefBackendNpcompBackend().load(compiled)"
]
},
{
"cell_type": "markdown",
"id": "7a9fb286",
"metadata": {},
"source": [
"### Execute the classification!"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7e6b9d4c",
"metadata": {},
"outputs": [],
"source": [
"logits = torch.from_numpy(jit_module.forward(img_preprocessed.numpy()))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ba4a5d62",
"metadata": {},
"outputs": [],
"source": [
"# Npcomp doesn't currently support these final postprocessing operations, so perform them in Torch.\n",
"def top3_possibilities(logits):\n",
" _, indexes = torch.sort(logits, descending=True)\n",
" percentage = torch.nn.functional.softmax(logits, dim=1)[0] * 100\n",
" top3 = [(IMAGENET_LABELS[idx], percentage[idx].item()) for idx in indexes[0][:3]]\n",
" return top3"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "fcbd2a75",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('Saluki, gazelle hound', 75.50347900390625),\n",
" ('Ibizan hound, Ibizan Podenco', 17.549449920654297),\n",
" ('whippet', 6.24681282043457)]"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"top3_possibilities(logits)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "npcomp",
"language": "python",
"name": "npcomp"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}