torch-mlir/projects/pt1/examples/fximporter_resnet18.py

60 lines
1.7 KiB
Python
Raw Normal View History

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import sys
from pathlib import Path
import torch
import torch.utils._pytree as pytree
import torchvision.models as models
from torch_mlir import fx
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
from torch_mlir_e2e_test.configs.utils import (
recursively_convert_to_numpy,
)
sys.path.append(str(Path(__file__).absolute().parent))
from _example_utils import (
top3_possibilities,
load_and_preprocess_image,
load_labels,
DEFAULT_IMAGE_URL,
)
print("load image from " + DEFAULT_IMAGE_URL, file=sys.stderr)
img = load_and_preprocess_image(DEFAULT_IMAGE_URL)
labels = load_labels()
resnet18 = models.resnet18(pretrained=True).eval()
module = fx.export_and_import(
resnet18,
torch.ones(1, 3, 224, 224),
output_type="linalg-on-tensors",
func_name=resnet18.__class__.__name__,
)
backend = refbackend.RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(module)
fx_module = backend.load(compiled)
params = {
**dict(resnet18.named_buffers(remove_duplicate=False)),
}
params_flat, params_spec = pytree.tree_flatten(params)
params_flat = list(params_flat)
with torch.no_grad():
numpy_inputs = recursively_convert_to_numpy(params_flat + [img])
golden_prediction = top3_possibilities(resnet18.forward(img), labels)
print("PyTorch prediction")
print(golden_prediction)
prediction = top3_possibilities(
torch.from_numpy(getattr(fx_module, resnet18.__class__.__name__)(*numpy_inputs)),
labels,
)
print("torch-mlir prediction")
print(prediction)