torch-mlir/examples/torchscript_resnet_inferenc...

415 lines
669 KiB
Plaintext
Raw Normal View History

{
"cells": [
{
"cell_type": "code",
"execution_count": 14,
"id": "0060537a",
"metadata": {},
"outputs": [],
"source": [
"# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.\n",
"# See https://llvm.org/LICENSE.txt for license information.\n",
"# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n",
"# Also available under a BSD-style license. See LICENSE."
]
},
{
"cell_type": "markdown",
"id": "0e6ebf55",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"### Configuring Jupyter kernel.\n",
"\n",
"We assume that you have followed the instructions for setting up Torch-MLIR. See [README.md](https://github.com/llvm/torch-mlir) if not.\n",
"\n",
"To run this notebook, you need to configure Jupyter to access the Torch-MLIR Python modules that are built as part of your development setup. An easy way to do this is to run the following command with the same Python (and shell) that is correctly set up and able to run the Torch-MLIR end-to-end tests with RefBackend:\n",
"\n",
"```shell\n",
"python -m ipykernel install --user --name=torch-mlir --env PYTHONPATH \"$PYTHONPATH\"\n",
"```\n",
"\n",
"You should then have an option in Jupyter to select this kernel for running this notebook.\n",
"\n",
"**TODO**: Make this notebook standalone, depending only on pip-installable packages.\n",
"\n",
"\n"
]
},
{
"cell_type": "markdown",
"id": "2aa20cfc",
"metadata": {},
"source": [
"### Additional dependencies for this notebook"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "ee25f979",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Defaulting to user installation because normal site-packages is not writeable\n",
"Requirement already satisfied: requests in /usr/lib/python3/dist-packages (2.25.1)\n",
"Requirement already satisfied: pillow in /usr/lib/python3/dist-packages (9.0.1)\n"
]
}
],
"source": [
"# %pip (unlike !pip or !python -m pip) installs into the environment of the running kernel.\n",
"%pip install requests pillow"
]
},
{
"cell_type": "markdown",
"id": "c8c95904",
"metadata": {},
"source": [
"## Imports"
]
},
{
"cell_type": "markdown",
"id": "c4149213",
"metadata": {},
"source": [
"### torch-mlir imports"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "847868f0",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchvision\n",
"\n",
"import torch_mlir\n",
"from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder\n",
"from torch_mlir.dialects.torch.importer.jit_ir.torchscript_annotations import extract_annotations\n",
"\n",
"from torch_mlir.passmanager import PassManager\n",
"from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend"
]
},
{
"cell_type": "markdown",
"id": "91992cea",
"metadata": {},
"source": [
"### General dependencies"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "a80963f8",
"metadata": {},
"outputs": [],
"source": [
"import io\n",
"\n",
"import requests\n",
"from PIL import Image"
]
},
{
"cell_type": "markdown",
"id": "322eaa75",
"metadata": {},
"source": [
"### Utilities"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "f4bdf926",
"metadata": {},
"outputs": [],
"source": [
"def compile_and_load_on_refbackend(module):\n",
"    \"\"\"Compile an MLIR module with the Torch-MLIR reference backend and load it.\n",
"\n",
"    The reference backend consumes linalg-on-tensors IR, so `module` should\n",
"    already be lowered to that form, e.g. by\n",
"    `torch_mlir.compile(..., output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)`.\n",
"\n",
"    Returns the loaded module, whose methods (such as `forward`) can be invoked.\n",
"    \"\"\"\n",
"    refbackend = RefBackendLinalgOnTensorsBackend()\n",
"    return refbackend.load(refbackend.compile(module))"
]
},
{
"cell_type": "markdown",
"id": "71cc1403",
"metadata": {},
"source": [
"## Basic tanh module"
]
},
{
"cell_type": "markdown",
"id": "269f3dc5",
"metadata": {},
"source": [
"A tiny module that is easier to understand and inspect than a full ResNet: it compiles `torch.tanh` and runs it on the reference backend."
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "6f46e706",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([-0.7615941, 0.7615941, 0. ], dtype=float32)"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class TanhModule(torch.nn.Module):\n",
"    \"\"\"Applies torch.tanh elementwise to its single input.\"\"\"\n",
"\n",
"    def forward(self, a):\n",
"        return torch.tanh(a)\n",
"\n",
"\n",
"# torch.ones(3) is only an example input for compilation.\n",
"# Lowering targets linalg-on-tensors, the form the reference backend accepts.\n",
"example_input = torch.ones(3)\n",
"compiled = torch_mlir.compile(TanhModule(), example_input, output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)\n",
"\n",
"# Load the compiled module on the reference backend and run it on a NumPy array.\n",
"jit_module = compile_and_load_on_refbackend(compiled)\n",
"jit_module.forward(torch.tensor([-1.0, 1.0, 0.0]).numpy())"
]
},
{
"cell_type": "markdown",
"id": "b7f8773b",
"metadata": {},
"source": [
"## ResNet Inference"
]
},
{
"cell_type": "markdown",
"id": "dd226e86",
"metadata": {},
"source": [
"One-time preparation: download the ImageNet class labels and build the image preprocessing pipeline."
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "88e8a5c3",
"metadata": {},
"outputs": [],
"source": [
"def _load_labels(timeout: float = 30.0):\n",
"    \"\"\"Download the ImageNet class names, one label per line of the file.\n",
"\n",
"    Raises `requests.HTTPError` on a non-2xx response instead of silently\n",
"    returning the lines of an error page. `timeout` is passed to `requests.get`\n",
"    so a stalled connection raises instead of hanging the notebook.\n",
"    \"\"\"\n",
"    response = requests.get(\n",
"        \"https://raw.githubusercontent.com/cathyzhyi/ml-data/main/imagenet-classes.txt\",\n",
"        timeout=timeout,\n",
"    )\n",
"    response.raise_for_status()\n",
"    labels = [line.strip() for line in response.text.splitlines()]\n",
"    return labels\n",
"IMAGENET_LABELS = _load_labels()\n",
"\n",
"def _get_preprocess_transforms():\n",
"    # See preprocessing specification at: https://pytorch.org/vision/stable/models.html\n",
"    T = torchvision.transforms\n",
"    return T.Compose(\n",
"        [\n",
"            T.Resize(256),\n",
"            T.CenterCrop(224),\n",
"            T.ToTensor(),\n",
"            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
"        ]\n",
"    )\n",
"PREPROCESS_TRANSFORMS = _get_preprocess_transforms()"
]
},
{
"cell_type": "markdown",
"id": "d42e8a1b",
"metadata": {},
"source": [
"Define helper functions for downloading an image and turning it into a batched, normalized input tensor."
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "abd53b61",
"metadata": {},
"outputs": [],
"source": [
"def fetch_image(url: str, timeout: float = 30.0):\n",
"    \"\"\"Download the image at `url` and return it as an RGB `PIL.Image.Image`.\n",
"\n",
"    Raises `requests.HTTPError` on a non-2xx response. `timeout` is passed to\n",
"    `requests.get` so a stalled connection raises instead of hanging.\n",
"    \"\"\"\n",
"    # Use some \"realistic\" User-Agent so that we aren't mistaken for being a scraper.\n",
"    headers = {\"User-Agent\": \"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.131 Safari/537.36\"}\n",
"    response = requests.get(url, headers=headers, timeout=timeout)\n",
"    response.raise_for_status()\n",
"    # Unlike `response.raw`, `response.content` has any gzip/deflate encoding already decoded.\n",
"    return Image.open(io.BytesIO(response.content)).convert(\"RGB\")\n",
"\n",
"def preprocess_image(img: Image.Image):\n",
"    \"\"\"Apply PREPROCESS_TRANSFORMS to `img` and add a leading batch dimension of size 1.\"\"\"\n",
"    return torch.unsqueeze(PREPROCESS_TRANSFORMS(img), 0)"
]
},
{
"cell_type": "markdown",
"id": "dbec8e96",
"metadata": {},
"source": [
"### Fetch our sample image."
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "8270ef94",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAk4AAAHgCAIAAADcxXWhAAEAAElEQVR4nKz92Y5kSZIlCBIRL3eVRTdTM3M391g8q6o7uqsxQKF+ax7nJ+azBpipt8EUenKJCM+IcHczU9NFtrvxQkTzwCLqlp3VtQBzkfDQVBOV5QozE9E5hw7h//X/9n9PKYQQUl6Mxbp2TV1bS1dXV99995vbmztDVQr89PTy80+f94cvZCdjwTmP4FSobft3bz/c3983TTPNw6dPP/3tpz+HOH347v7f/2//yw+//d1tfwsKqqAEMabTMuWclZAFu3blK6sKHGGZYbebd/uHu7t2e9WtVp0qpCWmlMgAEfVNA5eL4NcL4b92iYiqGmMA4HA4fHp8+X/8v/7f//nv/+npeU9EkbO1Zr3urYMP3713zhDI6TjuH4/TGCvXdF3X9vW3H96/fXPnvXeITdNYg8ycc44xhphDTirEKjFxzpljGsdxt9uN4wgA1lpVjTGGEKqqatu2rmvnnDHGGENEh5ddzjmllHNWVSIq/1pVFQAgIhG9/hcRhyksy7IsS0qJU8w5A4gxpm2qN2/e3NxcOeeIqK7rpvJARgROw/T8/Pz09HQ4HEIIIgIAkjnnzMzlTVprEVGAwRmy2DRd3/dt1SFiCCmEkGMKcSai6+v1u3fvtlcr7z0RGTCPj0+PX55FpO/XVVUtSzydTuM4Hg4HgHL/0Vob09K27d32+u7u7up6W1VVzvHz58+//PLTNE2q+ubNm2+++XB1dWWtn6bp+ellv98/7w8558yac55DWmIwxvqqurrejuM4DMeYFmUhgso7X7vbN3ff/+77m+2N9z7M8XA4jMPy+PjoXfvNN998++G7br0SgZf97p//+tPPv/xl9/xlnk+b9fru7m7ddNfX1+t+FUI4HQ7e+7bv+743ziKir6u2qTtfb9Zt31dEIAqqjIiAIqoGTVmbgTXGyCyqOh1GY4z3tq5r772qLmEOIfR9b4wBkJwzAFRVZQgBwPy6dAEQECFniDFNS6yqyntLBIiACAogArvn8NPffvnP//v/549/+sfT6bTdbn7z+998++233//2d9ZaUByG4en5cDweVch5G/IMKMy82+12T885ZxEZx/Hnv/203+/HcVRVZo4xOmObpmm9//777//tD//m7du363ZLRMu4nE6n0zQCABCSMUg6huVw2A3DMA6DI2OMUVVOICJEZK1FBZEsKc/LeDod5nl21tRts1qt1tvtarMmopTyPM/7w+l4PA5TVFVVAABRNMZ0Xdeueu/atu/atlbVGJecc+YowkQU4jiO4/G4H8ZTSgkACLBFyynHGEWzMcYYTJyXZTHGLMtCaK21UP4LNIUlQlaE9Xq73W4d1SIS5nA6nabTUNd1ZYx1dHt7+9vffv/2/s577yq32+0eHp+enp6eng+HwyEkIaKbm7uu697c3L579+7u9rrv+8paRWuq3lc1Ih5Pu+fHp5fdl/3+ZRgPHMM0n3bPLy+7p2WZiMhbUsQQuOvXbdutV9sffvh3/+v/+oe7u5vMIabh46ef/vjHP/70019Pp1EYiAwiWpAQ52lZcs5Etq7rqm6dc4LYNI2z1bIswzTGGHMWESFQ731fN845C1huGiJyTERUvru6rquqEoSUuK3aFHNcwjzPYZ5zzrWvmqax1hIREEo5XwAS5znPQz7sx8O4H0DNm6v7f/dv//Af/i//4e9++Ld13QFo4pzSnDjHNO+Px8NxN04LABDBp0+fHp8evnz5stvtvvvuu+NhSCmN47Lf74fTpKptu24aJ5ROw25esrWWjBcBBONc1dSttdYR1d5v16ubqyt7Xk0AiAgAIsrMiLparVar1Xq9RTAzhhItRGQeZ+chRQGI1lRdZ4hIVadp+vL45ZdffhnH+fpm/c0339zd3TVNAwiggAYAwDlnk89JcmKyNiVWsERQYiGAEGF5tpxFVUXEWusr68zrCfDfe5U9Vp
6t/Kbrululv/s3P8xJmubTvIRxHAXVOecrenp6Wq/76+36/fv37+++HU7L/uV4OBxiXqraGYSmaWprQwhtUxFRiXbnaKGAhN575xz4qgQba+00TeUxKSUiyjkvy4KIZVmoatmQAFA+rF6u8s5fvxdzuYioqioRSSnFGJmZmUsgaZrGGMPMiOi9N8ZYaxUpxlACrYh47xGxvGeDZK3lyxVCAAAgJXQGqbznRAkRy2spS7mT5W2HkEQEEZ+/PE/TPI5TznmaFmOMKopIuT8l1AFQefPGmDmG8iZfMwBVdc7d3NxcX1+v1+uqqpi1vIHyGYmodk5VFZeYUwghxLiEOcaYUgAUS8Y51zRNt+pLJtE0ze3tbeXqYRgev7zknF+ej7vdzjq/TdG5ap7nEEJK6XA4eE+bzebNmzd91XjvQwj7/d5b673vuq6EOhExzhpj2rZ1rlIFZlAAohKhAFAVFEBFhQCcc5VFROyrpnyJ1gIiACBSY611zlkDAFTWkqqKoqpeng4Qz3kcIpTbparMmpIAEBGKQkp8PB5//PHHP/7xjyGGb7/99ttvv3nz7s1ms9lut+t1ZQl2+9USUwihrtvtdpskPL88Pj4+ltxru92WLDCFyMzjOKYcrXFV7XNMh+N+n/nx8fGf/v4f7+/vf/Pdb795923TdCISQhBQRATSLDKH6TSNyzLGGBnKSSIgiGistQDQNe08pxDn8o1ba601ROSqyjlXHlNWSEm8rq6aGPOyhGma5iWW9d80zfqm93VlLZXUMHNMKaUURbKx2La189d1U51Op2VZNOaoiRSMMahKBGjIoq3ruqwrBERjEcpBgarKwll4GAYiqmwu+VzZXwCQVSxQzvnl5SXFxRgjIC/7/ePj08t+N40xq7RN3/adAviq8k2tyvOyIMJMlDKDHZxviHAcT8fTfpymkEJmtt45rZqubWMHpCIiwsJqrI0xcgZQ+/nzZ2vp5587QI5peHp++Pnnvz0/P4eQENGQR8Smdl9lzK7kzUQUYjTGCEP54FVVeY+qOg0nEcnCRgxaZ42xZBBxWEJZcmXtIaIxBEAxxmUOyzTP8xyXBQCcsapaVRUaUtWYMzCzStnCS1hEpPKN91Xf9wDwstv99a8/3d+/DSFMy5RzdJW1lsrB9fj08vDwcDjsvnz5Ms1Dznkcx4eHBxWMMY7jHGMsB1HOcVnYVMBZc0wpsSFGY5va13Xd972zFkRAZJoWgp0t+6qk54AiIjlnRBdjLvmmtd45KevPew9U+YoQHLM6VzVN470X0XEc9vv96XQCgO12e3v7Zr1ee+sBABBUSrYAxlgiAwzONoAGEZwDVkg5pTwDSAhzvSBbh4gIYIgsGfyXldz5LPjvDnvlq7LWbrer77//fpgjALzsDkTIzKvVql/V//5/+8PNzdV2tRaBMITHL7sf//i3eQpLGo7Ho7dmvV5D0wCANdg0TcldSr1FaMkaYz0Raea+79u29d4/Pz+fTqeUUll8JewhYlVVzjlVLf9veRLn3GuQ/jX2AJSdVj4CIlZVXUJdOTJEpJwU3nsAKLGhFLKSEyuM43w8jdM0lVDnvS8BpjwbiMYYS42oqgCqzIpyjtCXUMfMIIqIzrkSyFUVRY2hruvatlv16Xg87vfH4/Goes5XQggAAgCqxhgDCETUtm276puuU1WeVQnrrnXO3b97t1qtutWKiJZhnpblOA6H4ZQzlw+IiD5LCQwp5xBBVQGEzPkuee+bplmv103dOee8933XV1XlTc0pD6d/OhwOwzitdi9dt5qWebfbzfPc9/3V1erDhw8fPnzofK2qcQlExCnVdX2uTphFhFWE84toSk3XN1XlnAUAUIAsYomy5ByjiCAassYaRwT4r9aps0jkL5kMWEOvkazUMYioAHpZ4URwvuGIAKiKmSXnfDwed7v9H//pzz/99Iuq3t7e3t3dtW1rra2qap5HALDGzfNcjjZVHYajq92yLPMwSspN09zc3BhDVeWH0/
H27sY6w5y8rZjT8Xg8HA6VrZZlOQyn4/H4l7/8rWvazeZqtVpdXV01TbPZbNq+8waU8hRMtoY5iaiIlPym8r723tq
"text/plain": [
"<PIL.Image.Image image mode=RGB size=590x480 at 0x7F36601D8EB0>"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sample_image_url = \"https://upload.wikimedia.org/wikipedia/commons/thumb/3/31/Red_Smooth_Saluki.jpg/590px-Red_Smooth_Saluki.jpg\"\n",
"img = fetch_image(sample_image_url)\n",
"# Keep the raw image for display; the preprocessed tensor is what the model consumes.\n",
"img_preprocessed = preprocess_image(img)\n",
"\n",
"img"
]
},
{
"cell_type": "markdown",
"id": "55852670",
"metadata": {},
"source": [
"### Define the module and compile it"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a80462ba",
"metadata": {},
"outputs": [],
"source": [
"resnet18 = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)\n",
"# Put the model in inference mode before compiling it.\n",
"resnet18.eval()\n",
"# The (1, 3, 224, 224) example input matches one preprocessed image:\n",
"# a batch of one, 3 RGB channels, 224x224 after center-cropping.\n",
"compiled = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)\n",
"jit_module = compile_and_load_on_refbackend(compiled)"
]
},
{
"cell_type": "markdown",
"id": "2890c8c6",
"metadata": {},
"source": [
"### Execute the classification!"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "23476f38",
"metadata": {},
"outputs": [],
"source": [
"# The compiled module takes and returns NumPy arrays, so convert at the boundary.\n",
"input_array = img_preprocessed.numpy()\n",
"logits = torch.from_numpy(jit_module.forward(input_array))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a826c789",
"metadata": {},
"outputs": [],
"source": [
"# Torch-MLIR doesn't currently support these final postprocessing operations, so perform them in Torch.\n",
"def top3_possibilities(logits, k: int = 3):\n",
"    \"\"\"Return the `k` most likely labels for the first batch element of `logits`.\n",
"\n",
"    `logits` is indexed as [batch, class]; only batch element 0 is used.\n",
"    Returns a list of `(label, percentage)` pairs, most likely first, where the\n",
"    percentage is the softmax probability over classes scaled to 0-100.\n",
"    \"\"\"\n",
"    percentages = torch.nn.functional.softmax(logits, dim=1)[0] * 100\n",
"    # topk avoids sorting every class just to read off the first k.\n",
"    top_percentages, top_indexes = torch.topk(percentages, k)\n",
"    return [\n",
"        (IMAGENET_LABELS[idx], pct)\n",
"        for idx, pct in zip(top_indexes.tolist(), top_percentages.tolist())\n",
"    ]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "355ca74e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('Saluki, gazelle hound', 74.8702163696289),\n",
" ('Ibizan hound, Ibizan Podenco', 18.07537841796875),\n",
" ('whippet', 6.3394775390625)]"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Show the three most likely labels with their softmax percentages.\n",
"top3_possibilities(logits=logits)"
]
}
],
"metadata": {
"interpreter": {
"hash": "767d51c1340bd893661ea55ea3124f6de3c7a262a8b4abca0554b478b1e2ff90"
},
"kernelspec": {
"display_name": "torch-mlir",
"language": "python",
"name": "torch-mlir"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}