mirror of https://github.com/llvm/torch-mlir
Fix torchscript_resnet18_e2e.py and resnet_inference.ipynb
Fix the tests to run with refbackend.pull/329/head
parent
cd7053dfde
commit
aa10ec66a7
File diff suppressed because one or more lines are too long
|
@ -7,17 +7,15 @@ import requests
|
|||
import torch
|
||||
import torchvision.models as models
|
||||
from torchvision import transforms
|
||||
import typing
|
||||
|
||||
import torch_mlir
|
||||
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder
|
||||
|
||||
import npcomp
|
||||
from npcomp.passmanager import PassManager
|
||||
from npcomp.compiler.pytorch.backend import refbackend
|
||||
from npcomp.compiler.utils import logging
|
||||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
|
||||
mb = ModuleBuilder()
|
||||
|
||||
def load_and_preprocess_image(url: str):
|
||||
headers = {
|
||||
|
@ -58,7 +56,7 @@ def predictions(torch_func, jit_func, img, labels):
|
|||
golden_prediction = top3_possibilities(torch_func(img))
|
||||
print("PyTorch prediction")
|
||||
print(golden_prediction)
|
||||
prediction = top3_possibilities(torch.from_numpy(jit_func(img)))
|
||||
prediction = top3_possibilities(torch.from_numpy(jit_func(img.numpy())))
|
||||
print("NPCOMP prediction")
|
||||
print(prediction)
|
||||
|
||||
|
@ -92,7 +90,7 @@ img = load_and_preprocess_image(image_url)
|
|||
labels = load_labels()
|
||||
|
||||
test_module = TestModule()
|
||||
class_annotator = torch_mlir.ClassAnnotator()
|
||||
class_annotator = ClassAnnotator()
|
||||
recursivescriptmodule = torch.jit.script(test_module)
|
||||
torch.jit.save(recursivescriptmodule, "/tmp/foo.pt")
|
||||
|
||||
|
@ -110,8 +108,13 @@ class_annotator.annotateArgs(
|
|||
mb.import_module(recursivescriptmodule._c, class_annotator)
|
||||
|
||||
backend = refbackend.RefBackendNpcompBackend()
|
||||
PassManager.parse("torchscript-to-npcomp-backend-pipeline").run(mb.module)
|
||||
compiled = backend.compile(mb.module)
|
||||
with npcomp.ir.Context() as ctx:
|
||||
npcomp.register_all_dialects(ctx)
|
||||
lowered_mlir_module = npcomp.ir.Module.parse(str(mb.module))
|
||||
pm = PassManager.parse('torchscript-to-npcomp-backend-pipeline')
|
||||
pm.run(lowered_mlir_module)
|
||||
|
||||
compiled = backend.compile(lowered_mlir_module)
|
||||
jit_module = backend.load(compiled)
|
||||
|
||||
predictions(test_module.forward, jit_module.forward, img, labels)
|
||||
|
|
Loading…
Reference in New Issue