mirror of https://github.com/llvm/torch-mlir
499 lines
668 KiB
Plaintext
499 lines
668 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 1,
|
||
|
"id": "4b063b1e",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n",
|
||
|
"# See https://llvm.org/LICENSE.txt for license information.\n",
|
||
|
"# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "bd0ab562",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Setup\n",
|
||
|
"\n",
|
||
|
"### Configuring jupyter kernel.\n",
|
||
|
"\n",
|
||
|
"We assume that you have followed the instructions for setting up npcomp with IREE support. See [README.md](https://github.com/llvm/mlir-npcomp) if not.\n",
|
||
|
"\n",
|
||
|
"To run this notebook, you need to configure jupyter to access the npcomp Python modules that are built as part of your development setup. An easy way to do this is to run the following command with the same Python (and shell) that is correctly set up and able to run the npcomp end-to-end tests with IREE:\n",
|
||
|
"\n",
|
||
|
"```shell\n",
|
||
|
"python -m ipykernel install --user --name=npcomp --env PYTHONPATH \"$PYTHONPATH\"\n",
|
||
|
"```\n",
|
||
|
"\n",
|
||
|
"You should then have an option in jupyter to select this kernel for running this notebook.\n",
|
||
|
"\n",
|
||
|
"**TODO**: Make this notebook standalone and work based entirely on pip-installable packages.\n",
|
||
|
"\n",
|
||
|
"\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "d6d92319",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### Additional dependencies for this notebook"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 2,
|
||
|
"id": "8b992f23",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Defaulting to user installation because normal site-packages is not writeable\n",
|
||
|
"Requirement already satisfied: requests in /usr/lib/python3/dist-packages (2.25.1)\n",
|
||
|
"Requirement already satisfied: pillow in /usr/lib/python3/dist-packages (8.1.2)\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"!python -m pip install requests pillow"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "10788b70",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Imports"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "0c7ee78e",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### Npcomp imports"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 3,
|
||
|
"id": "6b9ba793",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import torch\n",
|
||
|
"import torchvision\n",
|
||
|
"import torch_mlir\n",
|
||
|
"\n",
|
||
|
"from npcomp.passmanager import PassManager\n",
|
||
|
"\n",
|
||
|
"from npcomp_torchscript.annotations import annotate_args, export\n",
|
||
|
"from torch_mlir.torchscript_annotations import extract_annotations"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "2e4b6e49",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### IREE imports"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 4,
|
||
|
"id": "7d19c0f7",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import iree.runtime as ireert\n",
|
||
|
"import iree.compiler as ireec"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "c4977302",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### General dependencies"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 5,
|
||
|
"id": "b1b42df6",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import requests\n",
|
||
|
"from PIL import Image"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "cfa688f3",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### Utilities"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 6,
|
||
|
"id": "2da1e08c",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def compile_to_iree_flatbuffer(program: torch.nn.Module):\n",
|
||
|
" \"\"\"Compiles a torch.nn.Module into an IREE flatbuffer compiled artifact.\n",
|
||
|
" \n",
|
||
|
" This artifact is suitable for inclusion in a user's application. It only\n",
|
||
|
" depends on the IREE runtime.\n",
|
||
|
" \"\"\"\n",
|
||
|
" ## Script the program.\n",
|
||
|
" scripted = torch.jit.script(program)\n",
|
||
|
"\n",
|
||
|
" ## Extract annotations.\n",
|
||
|
" class_annotator = torch_mlir.ClassAnnotator()\n",
|
||
|
" extract_annotations(program, scripted, class_annotator)\n",
|
||
|
"\n",
|
||
|
" ## Import the TorchScript module into MLIR.\n",
|
||
|
" mb = torch_mlir.ModuleBuilder()\n",
|
||
|
" mb.import_module(scripted._c, class_annotator)\n",
|
||
|
"\n",
|
||
|
" ## Lower the MLIR from TorchScript to IREE, passing through npcomp's backend contract.\n",
|
||
|
" with mb.module.context:\n",
|
||
|
" pipeline_str = \",\".join([\n",
|
||
|
" # Lower from the TorchScript MLIR representation to the npcomp backend contract.\n",
|
||
|
" \"torchscript-to-npcomp-backend-pipeline\",\n",
|
||
|
" # Lower from the npcomp backend contract to IREE-specific IR.\n",
|
||
|
" # This is a very lightweight process, as IREE's frontend contract\n",
|
||
|
" # and npcomp's backend contract are largely the same.\n",
|
||
|
" \"npcomp-backend-to-iree-frontend-pipeline\",\n",
|
||
|
" ])\n",
|
||
|
" pm = PassManager.parse(pipeline_str)\n",
|
||
|
" pm.run(mb.module)\n",
|
||
|
"\n",
|
||
|
" ## Invoke IREE to compile to its flatbuffer compiled artifact form.\n",
|
||
|
" return ireec.compile_str(str(mb.module), target_backends=[\"dylib-llvm-aot\"])\n",
|
||
|
"\n",
|
||
|
"def load_iree_flatbuffer_to_context(flatbuffer: str):\n",
|
||
|
" \"\"\"Loads an IREE flatbuffer into a fresh IREE context.\n",
|
||
|
" \n",
|
||
|
" This is suitable for simple in-process execution of IREE flatbuffers.\n",
|
||
|
" \"\"\"\n",
|
||
|
" iree_config = ireert.Config(driver_name=\"dylib\")\n",
|
||
|
" ctx = ireert.SystemContext(config=iree_config)\n",
|
||
|
" ctx.add_vm_module(ireert.VmModule.from_flatbuffer(flatbuffer))\n",
|
||
|
" return ctx"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "5e2f0d75",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Basic tanh module"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "2e9551a9",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"A simple tiny module that is easier to understand and look at than a full ResNet."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 7,
|
||
|
"id": "e06d3c25",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class TanhModule(torch.nn.Module):\n",
|
||
|
" def __init__(self):\n",
|
||
|
" super().__init__()\n",
|
||
|
"\n",
|
||
|
" # The `export` annotation controls which parts of the model the npcomp\n",
|
||
|
" # compiler should assume are externally accessible. By default,\n",
|
||
|
" # the npcomp compiler will only export the explicitly exported functions.\n",
|
||
|
" # NOTE: THis is different from `torch.jit.export`. The `torch.jit.export`\n",
|
||
|
" # decorator controls which methods of the original torch.nn.Module get\n",
|
||
|
" # compiled into TorchScript. This decorator\n",
|
||
|
" # (`npcomp_torchscript.annotations.export`) controls which TorchScript\n",
|
||
|
" # methods are compiled by npcomp. \n",
|
||
|
" @export\n",
|
||
|
" # The `annotate_args` annotation provides metadata to the npcomp compiler\n",
|
||
|
" # regarding the constraints on arguments. The value `None` means that\n",
|
||
|
" # no additional information is provided for that argument. Otherwise,\n",
|
||
|
" # it is a 3-tuple specifying the shape (`-1` for unknown extent along a dimension),\n",
|
||
|
" # along with the dtype and whether the tensor has\n",
|
||
|
" # value semantics (this would be False if you need to mutate an input\n",
|
||
|
" # tensor in-place).\n",
|
||
|
" @annotate_args([\n",
|
||
|
" None,\n",
|
||
|
" ([-1], torch.float32, True)\n",
|
||
|
" ])\n",
|
||
|
" def forward(self, a):\n",
|
||
|
" return torch.tanh(a)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 8,
|
||
|
"id": "c7d8c369",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"array([-0.7615942, 0.7615942, 0. ], dtype=float32)"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 8,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Create the module and compile it.\n",
|
||
|
"flatbuffer = compile_to_iree_flatbuffer(TanhModule())\n",
|
||
|
"# Use an in-process IREE runtime to execute the artifact.\n",
|
||
|
"ctx = load_iree_flatbuffer_to_context(flatbuffer)\n",
|
||
|
"# Run it!\n",
|
||
|
"ctx.modules.module[\"forward\"](torch.tensor([-1.0, 1.0, 0.0]).numpy())"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "fdecb7c8",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## ResNet Inference"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "7fc20f95",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"Do some one-time preparation."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 9,
|
||
|
"id": "5a5f6ef2",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def _load_labels():\n",
|
||
|
" classes_text = requests.get(\n",
|
||
|
" \"https://raw.githubusercontent.com/cathyzhyi/ml-data/main/imagenet-classes.txt\",\n",
|
||
|
" stream=True,\n",
|
||
|
" ).text\n",
|
||
|
" labels = [line.strip() for line in classes_text.splitlines()]\n",
|
||
|
" return labels\n",
|
||
|
"IMAGENET_LABELS = _load_labels()\n",
|
||
|
"\n",
|
||
|
"def _get_preprocess_transforms():\n",
|
||
|
" # See preprocessing specification at: https://pytorch.org/vision/stable/models.html\n",
|
||
|
" T = torchvision.transforms\n",
|
||
|
" return T.Compose(\n",
|
||
|
" [\n",
|
||
|
" T.Resize(256),\n",
|
||
|
" T.CenterCrop(224),\n",
|
||
|
" T.ToTensor(),\n",
|
||
|
" T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
|
||
|
" ]\n",
|
||
|
" )\n",
|
||
|
"PREPROCESS_TRANSFORMS = _get_preprocess_transforms()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "8e31f4c4",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"Define some helper functions."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 10,
|
||
|
"id": "acf648ce",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def fetch_image(url: str):\n",
|
||
|
" # Use some \"realistic\" User-Agent so that we aren't mistaken for being a scraper.\n",
|
||
|
" headers = {\"User-Agent\": \"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.131 Safari/537.36\"}\n",
|
||
|
" return Image.open(requests.get(url, headers=headers, stream=True).raw).convert(\"RGB\")\n",
|
||
|
"\n",
|
||
|
"def preprocess_image(img: Image):\n",
|
||
|
" # Preprocess and add a batch dimension.\n",
|
||
|
" return torch.unsqueeze(PREPROCESS_TRANSFORMS(img), 0)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "903353e2",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### Fetch our sample image."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 11,
|
||
|
"id": "45e6988a",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAk4AAAHgCAIAAADcxXWhAAEAAElEQVR4nKz92Y5kSZIlCBIRL3eVRTdTM3M391g8q6o7uqsxQKF+ax7nJ+azBpipt8EUenKJCM+IcHczU9NFtrvxQkTzwCLqlp3VtQBzkfDQVBOV5QozE9E5hw7h//X/9n9PKYQQUl6Mxbp2TV1bS1dXV99995vbmztDVQr89PTy80+f94cvZCdjwTmP4FSobft3bz/c3983TTPNw6dPP/3tpz+HOH347v7f/2//yw+//d1tfwsKqqAEMabTMuWclZAFu3blK6sKHGGZYbebd/uHu7t2e9WtVp0qpCWmlMgAEfVNA5eL4NcL4b92iYiqGmMA4HA4fHp8+X/8v/7f//nv/+npeU9EkbO1Zr3urYMP3713zhDI6TjuH4/TGCvXdF3X9vW3H96/fXPnvXeITdNYg8ycc44xhphDTirEKjFxzpljGsdxt9uN4wgA1lpVjTGGEKqqatu2rmvnnDHGGENEh5ddzjmllHNWVSIq/1pVFQAgIhG9/hcRhyksy7IsS0qJU8w5A4gxpm2qN2/e3NxcOeeIqK7rpvJARgROw/T8/Pz09HQ4HEIIIgIAkjnnzMzlTVprEVGAwRmy2DRd3/dt1SFiCCmEkGMKcSai6+v1u3fvtlcr7z0RGTCPj0+PX55FpO/XVVUtSzydTuM4Hg4HgHL/0Vob09K27d32+u7u7up6W1VVzvHz58+//PLTNE2q+ubNm2+++XB1dWWtn6bp+ellv98/7w8558yac55DWmIwxvqqurrejuM4DMeYFmUhgso7X7vbN3ff/+77m+2N9z7M8XA4jMPy+PjoXfvNN998++G7br0SgZf97p//+tPPv/xl9/xlnk+b9fru7m7ddNfX1+t+FUI4HQ7e+7bv+743ziKir6u2qTtfb9Zt31dEIAqqjIiAIqoGTVmbgTXGyCyqOh1GY4z3tq5r772qLmEOIfR9b4wBkJwzAFRVZQgBwPy6dAEQECFniDFNS6yqyntLBIiACAogArvn8NPffvnP//v/549/+sfT6bTdbn7z+998++233//2d9ZaUByG4en5cDweVch5G/IMKMy82+12T885ZxEZx/Hnv/203+/HcVRVZo4xOmObpmm9//777//tD//m7du363ZLRMu4nE6n0zQCABCSMUg6huVw2A3DMA6DI2OMUVVOICJEZK1FBZEsKc/LeDod5nl21tRts1qt1tvtarMmopTyPM/7w+l4PA5TVFVVAABRNMZ0Xdeueu/atu/atlbVGJecc+YowkQU4jiO4/G4H8ZTSgkACLBFyynHGEWzMcYYTJyXZTHGLMtCaK21UP4LNIUlQlaE9Xq73W4d1SIS5nA6nabTUNd1ZYx1dHt7+9vffv/2/s577yq32+0eHp+enp6eng+HwyEkIaKbm7uu697c3L579+7u9rrv+8paRWuq3lc1Ih5Pu+fHp5fdl/3+ZRgPHMM0n3bPLy+7p2WZiMhbUsQQuOvXbdutV9sffvh3/+v/+oe7u5vMIabh46ef/vjHP/70019Pp1EYiAwiWpAQ52lZcs5Etq7rqm6dc4LYNI2z1bIswzTGGHMWESFQ731fN845C1huGiJyTERUvru6rquqEoSUuK3aFHNcwjzPYZ5zzrWvmqax1hIREEo5XwAS5znPQz7sx8O4H0DNm6v7f/dv//Af/i//4e9++Ld13QFo4pzSnDjHNO+Px8NxN04LABDBp0+fHp8evnz5stvtvvvuu+NhSCmN47Lf74fTpKptu24aJ5ROw25esrWWjBcBBONc1dSttdYR1d5v16ubqyt7Xk0AiAgAIsrMiLparVar1Xq9RTAzhhItRGQeZ+chRQGI1lRdZ4hIVadp+vL45ZdffhnH+fpm/c0339zd3TVNAwiggAYAwDlnk89JcmKyNiVWsERQYiGAEGF5tpxFVUXEWusr68zrCfDfe5U9Vp6t/Kbrululv/s3P8xJmubTvIRxHAXVOecrenp6Wq/76+36/fv37+++HU7L/uV4OBxiXqraGYSmaWprQwhtUxFRiXbnaKGAhN575xz4qgQba+00TeUxKSUiyjkvy4KIZVmoatmQAFA+rF6u8s5fvxdzuYioqioRSSnFGJmZmUsgaZrGGMPMiOi9N8ZYaxUpxlACrYh47xGxvGeDZK3lyxVCAAAgJXQGqbznRAkRy2spS7mT5W2HkEQEEZ+/PE/TPI5TznmaFmOMKopIuT8l1AFQefPGmDmG8iZfMwBVdc7d3NxcX1+v1+uqqpi1vIHyGYmodk5VFZeYUwghxLiEOcaYUgAUS8Y51zRNt+pLJtE0ze3tbeXqYRgev7zknF+ej7vdzjq/TdG5ap7nEEJK6XA4eE+bzebNmzd91XjvQwj7/d5b673vuq6EOhExzhpj2rZ1rlIFZlAAohKhAFAVFEBFhQCcc5VFROyrpnyJ1gIiACBSY611zlkDAFTWkqqKoqpeng4Qz3kcIpTbparMmpIAEBGKQkp8PB5//PHHP/7xjyGGb7/99ttvv3nz7s1ms9lut+t1ZQl2+9USUwihrtvtdpskPL88Pj4+ltxru92WLDCFyMzjOKYcrXFV7XNMh+N+n/nx8fGf/v4f7+/vf/Pdb795923TdCISQhBQRATSLDKH6TSNyzLGGBnKSSIgiGistQDQNe08pxDn8o1ba601ROSqyjlXHlNWSEm8rq6aGPOyhGma5iWW9d80zfqm93VlLZXUMHNMKaUURbKx2La189d1U51Op2VZNOaoiRSMMahKBGjIoq3ruqwrBERjEcpBgarKwll4GAYiqmwu+VzZXwCQVSxQzvnl5SXFxRgjIC/7/ePj08t+N40xq7RN3/adAviq8k2tyvOyIMJMlDKDHZxviHAcT8fTfpymkEJmtt45rZqubWMHpCIiwsJqrI0xcgZQ+/nzZ2vp5587QI5peHp++Pnnvz0/P4eQENGQR8Smdl9lzK7kzUQUYjTGCEP54FVVeY+qOg0nEcnCRgxaZ42xZBBxWEJZcmXtIaIxBEAxxmUOyzTP8xyXBQCcsapaVRUaUtWYMzCzStnCS1hEpPKN91Xf9wDwstv99a8/3d+/DSFMy5RzdJW1lsrB9fj08vDwcDjsvnz5Ms1Dznkcx4eHBxWMMY7jHGMsB1HOcVnYVMBZc0wpsSFGY5va13Xd972zFkRAZJoWgp0t+6qk54AiIjlnRBdjLvmmtd45KevPew9U+YoQHLM6VzVN470X0XEc9vv96XQCgO12e3v7Zr1ee+sBABBUSrYAxlgiAwzONoAGEZwDVkg5pTwDSAhzvSBbh4gIYIgsGfyXldz5LPjvDnvlq7LWbrer77//fpgjALzsDkTIzKvVql/V//5/+8PNzdV2tRaBMITHL7sf//i3eQpLGo7Ho7dmvV5D0wCANdg0TcldSr1FaMkaYz0Raea+79u29d4/Pz+fTqeUUll8JewhYlVVzjlVLf9veRLn3GuQ/jX2AJSdVj4CIlZVXUJdOTJEpJwU3nsAKLGhFLKSEyuM43w8jdM0lVDnvS8BpjwbiMYYS42oqgCqzIpyjtCXUMfMIIqIzrkSyFUVRY2hruvatlv16Xg87vfH4/Goes5XQggAAgCqxhgDCETUtm276puuU1WeVQnrrnXO3b97t1qtutWKiJZhnpblOA6H4ZQzlw+IiD5LCQwp5xBBVQGEzPkuee+bplmv103dOee8933XV1XlTc0pD6d/OhwOwzitdi9dt5qWebfbzfPc9/3V1erDhw8fPnzofK2qcQlExCnVdX2uTphFhFWE84toSk3XN1XlnAUAUIAsYomy5ByjiCAassYaRwT4r9aps0jkL5kMWEOvkazUMYioAHpZ4URwvuGIAKiKmSXnfDwed7v9H//pzz/99Iuq3t7e3t3dtW1rra2qap5HALDGzfNcjjZVHYajq92yLPMwSspN09zc3BhDVeWH0/H27sY6w5y8rZjT8Xg8HA6VrZZlOQyn4/H4l7/8rWvazeZqtVpdXV01TbPZbNq+8waU8hRMtoY5iaiIlPym8r723tq
|
||
|
"text/plain": [
|
||
|
"<PIL.Image.Image image mode=RGB size=590x480 at 0x7FC621B8C6A0>"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 11,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"img = fetch_image(\"https://upload.wikimedia.org/wikipedia/commons/thumb/3/31/Red_Smooth_Saluki.jpg/590px-Red_Smooth_Saluki.jpg\")\n",
|
||
|
"img_preprocessed = preprocess_image(img)\n",
|
||
|
"\n",
|
||
|
"img"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "19f53027",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### Define the module and compile it"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 12,
|
||
|
"id": "5d1fc533",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class ResNet18Module(torch.nn.Module):\n",
|
||
|
" def __init__(self):\n",
|
||
|
" super().__init__()\n",
|
||
|
" self.resnet = torchvision.models.resnet18(pretrained=True)\n",
|
||
|
" self.train(False)\n",
|
||
|
" @export\n",
|
||
|
" @annotate_args([\n",
|
||
|
" None,\n",
|
||
|
" ([1, 3, 224, 224], torch.float32, True),\n",
|
||
|
" ])\n",
|
||
|
" def forward(self, img):\n",
|
||
|
" return self.resnet.forward(img)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 13,
|
||
|
"id": "16ce5fad",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Create the module and compile it.\n",
|
||
|
"flatbuffer = compile_to_iree_flatbuffer(ResNet18Module())\n",
|
||
|
"# Load it for in-process execution.\n",
|
||
|
"ctx = load_iree_flatbuffer_to_context(flatbuffer)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "7a9fb286",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### Execute the classification!"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 14,
|
||
|
"id": "7e6b9d4c",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"logits = torch.from_numpy(ctx.modules.module[\"forward\"](img_preprocessed.numpy()))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 15,
|
||
|
"id": "ba4a5d62",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Npcomp doesn't currently support these final postprocessing operations, so perform them in Torch.\n",
|
||
|
"def top3_possibilities(logits):\n",
|
||
|
" _, indexes = torch.sort(logits, descending=True)\n",
|
||
|
" percentage = torch.nn.functional.softmax(logits, dim=1)[0] * 100\n",
|
||
|
" top3 = [(IMAGENET_LABELS[idx], percentage[idx].item()) for idx in indexes[0][:3]]\n",
|
||
|
" return top3"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 16,
|
||
|
"id": "fcbd2a75",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"[('Saluki, gazelle hound', 74.8702163696289),\n",
|
||
|
" ('Ibizan hound, Ibizan Podenco', 18.07537841796875),\n",
|
||
|
" ('whippet', 6.3394775390625)]"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 16,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"top3_possibilities(logits)"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "npcomp",
|
||
|
"language": "python",
|
||
|
"name": "npcomp"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.9.2"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 5
|
||
|
}
|