mirror of https://github.com/llvm/torch-mlir
60 lines
1.7 KiB
Python
60 lines
1.7 KiB
Python
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||
|
# Also available under a BSD-style license. See LICENSE.
|
||
|
|
||
|
import sys
|
||
|
from pathlib import Path
|
||
|
|
||
|
import torch
|
||
|
import torch.utils._pytree as pytree
|
||
|
import torchvision.models as models
|
||
|
from torch_mlir import fx
|
||
|
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
|
||
|
from torch_mlir_e2e_test.configs.utils import (
|
||
|
recursively_convert_to_numpy,
|
||
|
)
|
||
|
|
||
|
sys.path.append(str(Path(__file__).absolute().parent))
|
||
|
from _example_utils import (
|
||
|
top3_possibilities,
|
||
|
load_and_preprocess_image,
|
||
|
load_labels,
|
||
|
DEFAULT_IMAGE_URL,
|
||
|
)
|
||
|
|
||
|
|
||
|
print("load image from " + DEFAULT_IMAGE_URL, file=sys.stderr)
|
||
|
img = load_and_preprocess_image(DEFAULT_IMAGE_URL)
|
||
|
labels = load_labels()
|
||
|
|
||
|
resnet18 = models.resnet18(pretrained=True).eval()
|
||
|
module = fx.export_and_import(
|
||
|
resnet18,
|
||
|
torch.ones(1, 3, 224, 224),
|
||
|
output_type="linalg-on-tensors",
|
||
|
func_name=resnet18.__class__.__name__,
|
||
|
)
|
||
|
backend = refbackend.RefBackendLinalgOnTensorsBackend()
|
||
|
compiled = backend.compile(module)
|
||
|
fx_module = backend.load(compiled)
|
||
|
|
||
|
params = {
|
||
|
**dict(resnet18.named_buffers(remove_duplicate=False)),
|
||
|
}
|
||
|
params_flat, params_spec = pytree.tree_flatten(params)
|
||
|
params_flat = list(params_flat)
|
||
|
with torch.no_grad():
|
||
|
numpy_inputs = recursively_convert_to_numpy(params_flat + [img])
|
||
|
|
||
|
golden_prediction = top3_possibilities(resnet18.forward(img), labels)
|
||
|
print("PyTorch prediction")
|
||
|
print(golden_prediction)
|
||
|
|
||
|
prediction = top3_possibilities(
|
||
|
torch.from_numpy(getattr(fx_module, resnet18.__class__.__name__)(*numpy_inputs)),
|
||
|
labels,
|
||
|
)
|
||
|
print("torch-mlir prediction")
|
||
|
print(prediction)
|