Fix torchscript_resnet18_e2e.py and resnet_inference.ipynb

Fix the tests to run with refbackend.
pull/329/head
Yi Zhang 2021-09-24 18:03:25 -04:00
parent cd7053dfde
commit aa10ec66a7
2 changed files with 64 additions and 103 deletions

File diff suppressed because one or more lines are too long

View File

@ -7,17 +7,15 @@ import requests
import torch
import torchvision.models as models
from torchvision import transforms
import typing
import torch_mlir
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder
import npcomp
from npcomp.passmanager import PassManager
from npcomp.compiler.pytorch.backend import refbackend
from npcomp.compiler.utils import logging
mb = torch_mlir.ModuleBuilder()
mb = ModuleBuilder()
def load_and_preprocess_image(url: str):
headers = {
@ -58,7 +56,7 @@ def predictions(torch_func, jit_func, img, labels):
golden_prediction = top3_possibilities(torch_func(img))
print("PyTorch prediction")
print(golden_prediction)
prediction = top3_possibilities(torch.from_numpy(jit_func(img)))
prediction = top3_possibilities(torch.from_numpy(jit_func(img.numpy())))
print("NPCOMP prediction")
print(prediction)
@ -92,7 +90,7 @@ img = load_and_preprocess_image(image_url)
labels = load_labels()
test_module = TestModule()
class_annotator = torch_mlir.ClassAnnotator()
class_annotator = ClassAnnotator()
recursivescriptmodule = torch.jit.script(test_module)
torch.jit.save(recursivescriptmodule, "/tmp/foo.pt")
@ -110,8 +108,13 @@ class_annotator.annotateArgs(
mb.import_module(recursivescriptmodule._c, class_annotator)
backend = refbackend.RefBackendNpcompBackend()
PassManager.parse("torchscript-to-npcomp-backend-pipeline").run(mb.module)
compiled = backend.compile(mb.module)
with npcomp.ir.Context() as ctx:
npcomp.register_all_dialects(ctx)
lowered_mlir_module = npcomp.ir.Module.parse(str(mb.module))
pm = PassManager.parse('torchscript-to-npcomp-backend-pipeline')
pm.run(lowered_mlir_module)
compiled = backend.compile(lowered_mlir_module)
jit_module = backend.load(compiled)
predictions(test_module.forward, jit_module.forward, img, labels)