diff --git a/examples/torchdynamo_resnet18.py b/examples/torchdynamo_resnet18.py
new file mode 100644
index 000000000..44d155b5d
--- /dev/null
+++ b/examples/torchdynamo_resnet18.py
@@ -0,0 +1,94 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+import sys
+from typing import List
+
+from PIL import Image
+import requests
+
+import torch
+import torch._dynamo as dynamo
+import torchvision.models as models
+from torchvision import transforms
+
+import torch_mlir
+from torch_mlir.dynamo import make_simple_dynamo_backend
+from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
+
+
+def load_and_preprocess_image(url: str):
+    headers = {
+        'User-Agent':
+        'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36'
+    }
+    img = Image.open(requests.get(url, headers=headers,
+                                  stream=True).raw).convert("RGB")
+    # preprocessing pipeline
+    preprocess = transforms.Compose([
+        transforms.Resize(256),
+        transforms.CenterCrop(224),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                             std=[0.229, 0.224, 0.225]),
+    ])
+    img_preprocessed = preprocess(img)
+    return torch.unsqueeze(img_preprocessed, 0)
+
+
+def load_labels():
+    classes_text = requests.get(
+        "https://raw.githubusercontent.com/cathyzhyi/ml-data/main/imagenet-classes.txt",
+        stream=True,
+    ).text
+    labels = [line.strip() for line in classes_text.splitlines()]
+    return labels
+
+
+def top3_possibilities(res):
+    _, indexes = torch.sort(res, descending=True)
+    percentage = torch.nn.functional.softmax(res, dim=1)[0] * 100
+    top3 = [(labels[idx], percentage[idx].item()) for idx in indexes[0][:3]]
+    return top3
+
+
+def predictions(torch_func, jit_func, img, labels):
+    golden_prediction = top3_possibilities(torch_func(img))
+    print("PyTorch prediction")
+    print(golden_prediction)
+    prediction = top3_possibilities(torch.from_numpy(jit_func(img.numpy())))
+    print("torch-mlir prediction")
+    print(prediction)
+
+image_url = "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
+
+print("load image from " + image_url, file=sys.stderr)
+img = load_and_preprocess_image(image_url)
+labels = load_labels()
+
+@make_simple_dynamo_backend
+def refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule,
+                                   example_inputs: List[torch.Tensor]):
+    mlir_module = torch_mlir.compile(
+        fx_graph, example_inputs, output_type="linalg-on-tensors")
+    backend = refbackend.RefBackendLinalgOnTensorsBackend()
+    compiled = backend.compile(mlir_module)
+    loaded = backend.load(compiled)
+
+    def compiled_callable(*inputs):
+        inputs = [x.numpy() for x in inputs]
+        result = loaded.forward(*inputs)
+        if not isinstance(result, tuple):
+            result = torch.from_numpy(result)
+        else:
+            result = tuple(torch.from_numpy(x) for x in result)
+        return result
+    return compiled_callable
+
+resnet18 = models.resnet18(pretrained=True)
+resnet18.train(False)
+dynamo_callable = dynamo.optimize(refbackend_torchdynamo_backend)(resnet18)
+
+predictions(resnet18.forward, lambda x: dynamo_callable(torch.from_numpy(x)).numpy(), img, labels)
diff --git a/python/torch_mlir_e2e_test/configs/torchdynamo.py b/python/torch_mlir_e2e_test/configs/torchdynamo.py
index 293fe15d1..044059818 100644
--- a/python/torch_mlir_e2e_test/configs/torchdynamo.py
+++ b/python/torch_mlir_e2e_test/configs/torchdynamo.py
@@ -16,8 +16,8 @@ from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
 
 
 @make_simple_dynamo_backend
-def refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule,
-                                   example_inputs: List[torch.Tensor]):
+def _refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule,
+                                    example_inputs: List[torch.Tensor]):
     # Use the LinalgOnTensors backend, since it is the most complete.
     # In theory we could mix and match TorchDynamo with the other backends,
     # since they all lower through the same backend contract.
@@ -49,10 +49,6 @@ def refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule,
     return compiled_callable
 
 
-@dynamo.optimize(refbackend_torchdynamo_backend)
-def f(method, *inputs):
-    return method(*inputs)
-
 
 class TorchDynamoTestConfig(TestConfig):
     """TestConfig that runs the torch.nn.Module with TorchDynamo"""
@@ -67,7 +63,9 @@ class TorchDynamoTestConfig(TestConfig):
         # stateful then it does not mutate the original compiled program.
         result: Trace = []
         for item in trace:
-            output = f(getattr(artifact, item.symbol), *item.inputs)
+            f = lambda method, *inputs: method(*inputs)
+            dynamo_f = dynamo.optimize(_refbackend_torchdynamo_backend)(f)
+            output = dynamo_f(getattr(artifact, item.symbol), *item.inputs)
             result.append(
                 TraceItem(symbol=item.symbol,
                           inputs=item.inputs,
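Note (not part of the patch): the backend wrapper added in examples/torchdynamo_resnet18.py is not specific to ResNet18; the same decorated function can drive any eager module through TorchDynamo and torch-mlir's RefBackend. Below is a minimal sketch of that reuse, assuming the example file's imports and its `refbackend_torchdynamo_backend` are in scope; `TinyModel` and its shapes are hypothetical, chosen purely for illustration.

```python
import torch
import torch._dynamo as dynamo


class TinyModel(torch.nn.Module):
    """Hypothetical toy module used only to illustrate the backend wiring."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return torch.tanh(self.linear(x))


model = TinyModel().eval()
# Same pattern as the ResNet18 example: TorchDynamo captures the FX graph and
# hands it to the torch-mlir backend defined in the example file, which lowers
# it to linalg-on-tensors and runs it on the RefBackend.
compiled = dynamo.optimize(refbackend_torchdynamo_backend)(model)
print(compiled(torch.rand(1, 4)))
```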