mirror of https://github.com/llvm/torch-mlir
[torchdynamo] Add ResNet18 example with TorchDynamo
This is a minor variation on our other resnet18 examples swapping in TorchDynamo. We replicate the refbackend_torchdynamo_backend out of the e2e test config to avoid making that appear like a public API. Also, some minor cleanups to TorchDynamoTestConfig.pull/1696/head
parent
98d80a642a
commit
b1f9e09f85
|
@ -0,0 +1,94 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
from PIL import Image
|
||||
import requests
|
||||
|
||||
import torch
|
||||
import torch._dynamo as dynamo
|
||||
import torchvision.models as models
|
||||
from torchvision import transforms
|
||||
|
||||
import torch_mlir
|
||||
from torch_mlir.dynamo import make_simple_dynamo_backend
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
|
||||
|
||||
|
||||
def load_and_preprocess_image(url: str):
|
||||
headers = {
|
||||
'User-Agent':
|
||||
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36'
|
||||
}
|
||||
img = Image.open(requests.get(url, headers=headers,
|
||||
stream=True).raw).convert("RGB")
|
||||
# preprocessing pipeline
|
||||
preprocess = transforms.Compose([
|
||||
transforms.Resize(256),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225]),
|
||||
])
|
||||
img_preprocessed = preprocess(img)
|
||||
return torch.unsqueeze(img_preprocessed, 0)
|
||||
|
||||
|
||||
def load_labels():
|
||||
classes_text = requests.get(
|
||||
"https://raw.githubusercontent.com/cathyzhyi/ml-data/main/imagenet-classes.txt",
|
||||
stream=True,
|
||||
).text
|
||||
labels = [line.strip() for line in classes_text.splitlines()]
|
||||
return labels
|
||||
|
||||
|
||||
def top3_possibilities(res):
|
||||
_, indexes = torch.sort(res, descending=True)
|
||||
percentage = torch.nn.functional.softmax(res, dim=1)[0] * 100
|
||||
top3 = [(labels[idx], percentage[idx].item()) for idx in indexes[0][:3]]
|
||||
return top3
|
||||
|
||||
|
||||
def predictions(torch_func, jit_func, img, labels):
|
||||
golden_prediction = top3_possibilities(torch_func(img))
|
||||
print("PyTorch prediction")
|
||||
print(golden_prediction)
|
||||
prediction = top3_possibilities(torch.from_numpy(jit_func(img.numpy())))
|
||||
print("torch-mlir prediction")
|
||||
print(prediction)
|
||||
|
||||
image_url = "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
|
||||
|
||||
print("load image from " + image_url, file=sys.stderr)
|
||||
img = load_and_preprocess_image(image_url)
|
||||
labels = load_labels()
|
||||
|
||||
@make_simple_dynamo_backend
|
||||
def refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule,
|
||||
example_inputs: List[torch.Tensor]):
|
||||
mlir_module = torch_mlir.compile(
|
||||
fx_graph, example_inputs, output_type="linalg-on-tensors")
|
||||
backend = refbackend.RefBackendLinalgOnTensorsBackend()
|
||||
compiled = backend.compile(mlir_module)
|
||||
loaded = backend.load(compiled)
|
||||
|
||||
def compiled_callable(*inputs):
|
||||
inputs = [x.numpy() for x in inputs]
|
||||
result = loaded.forward(*inputs)
|
||||
if not isinstance(result, tuple):
|
||||
result = torch.from_numpy(result)
|
||||
else:
|
||||
result = tuple(torch.from_numpy(x) for x in result)
|
||||
return result
|
||||
return compiled_callable
|
||||
|
||||
resnet18 = models.resnet18(pretrained=True)
|
||||
resnet18.train(False)
|
||||
dynamo_callable = dynamo.optimize(refbackend_torchdynamo_backend)(resnet18)
|
||||
|
||||
predictions(resnet18.forward, lambda x: dynamo_callable(torch.from_numpy(x)).numpy(), img, labels)
|
|
@ -16,8 +16,8 @@ from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
|
|||
|
||||
|
||||
@make_simple_dynamo_backend
|
||||
def refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule,
|
||||
example_inputs: List[torch.Tensor]):
|
||||
def _refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule,
|
||||
example_inputs: List[torch.Tensor]):
|
||||
# Use the LinalgOnTensors backend, since it is the most complete.
|
||||
# In theory we could mix and match TorchDynamo with the other backends,
|
||||
# since they all lower through the same backend contract.
|
||||
|
@ -49,10 +49,6 @@ def refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule,
|
|||
return compiled_callable
|
||||
|
||||
|
||||
@dynamo.optimize(refbackend_torchdynamo_backend)
|
||||
def f(method, *inputs):
|
||||
return method(*inputs)
|
||||
|
||||
class TorchDynamoTestConfig(TestConfig):
|
||||
"""TestConfig that runs the torch.nn.Module with TorchDynamo"""
|
||||
|
||||
|
@ -67,7 +63,9 @@ class TorchDynamoTestConfig(TestConfig):
|
|||
# stateful then it does not mutate the original compiled program.
|
||||
result: Trace = []
|
||||
for item in trace:
|
||||
output = f(getattr(artifact, item.symbol), *item.inputs)
|
||||
f = lambda method, *inputs: method(*inputs)
|
||||
dynamo_f = dynamo.optimize(_refbackend_torchdynamo_backend)(f)
|
||||
output = dynamo_f(getattr(artifact, item.symbol), *item.inputs)
|
||||
result.append(
|
||||
TraceItem(symbol=item.symbol,
|
||||
inputs=item.inputs,
|
||||
|
|
Loading…
Reference in New Issue