mirror of https://github.com/llvm/torch-mlir
Update examples.
TorchFX example has been simplified, since it seems to be hitting that weird RefBackend bug. Will dig into that.
pull/333/head
parent e687d39074
commit 64ce5d54d3
File diff suppressed because one or more lines are too long
@@ -397,7 +397,8 @@ def _add_handler(func_builder: _ForwardFunctionBuilder,
     and an argument named `other`'
     tensor_type = TorchTensorType().to_mlir(func_builder.context)
     int_type = PythonType(int).to_mlir(func_builder.context)
-    int_attr = ir.IntegerAttr.get(int_type, 1)
+    attr_type = ir.Type.parse('i64', func_builder.context)
+    int_attr = ir.IntegerAttr.get(attr_type, 1)
     alpha_arg = torch_d.ConstantIntOp(int_type,
                                       int_attr,
                                       loc=func_builder.loc,
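
The new attribute code parses a builtin `i64` type directly and builds the IntegerAttr from it, presumably because an IntegerAttr wants a builtin integer storage type rather than the dialect-level type that PythonType(int).to_mlir() returns. A minimal sketch of that construction, assuming the MLIR Python bindings are importable as `mlir.ir` (the real code reaches them through the project's own wrappers and takes the context from `func_builder`):

    # Hedged sketch, not part of this commit: build the `1 : i64` attribute
    # the same way the new lines do.
    from mlir import ir  # assumed import path for the MLIR Python bindings

    with ir.Context() as ctx:
        attr_type = ir.Type.parse('i64', ctx)        # builtin storage type
        int_attr = ir.IntegerAttr.get(attr_type, 1)  # the alpha=1 constant
        print(int_attr)                              # prints: 1 : i64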
@@ -2,15 +2,20 @@
 # This file is licensed under a pytorch-style license
 # See frontends/pytorch/LICENSE for license information.
 
+# From the torch-mlir root, run with:
+# `python -m examples.torchfx.examples.example_add_tanh_sigmoid`
+# (after setting up python environment with write_env_file.sh)
+
 import torch
 from torch.fx.experimental.fx_acc import acc_tracer
-import npcomp
-from npcomp.compiler.pytorch.backend import refbackend
-from npcomp.passmanager import PassManager
+import torch_mlir
+from torch_mlir.dialects.torch import register_dialect
+from torch_mlir.passmanager import PassManager
+from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
 
-from torchfx.builder import build_module
-from torchfx.annotator import annotate_forward_args
-from torchfx.torch_mlir_types import TorchTensorType
+from ..builder import build_module
+from ..annotator import annotate_forward_args
+from ..torch_mlir_types import TorchTensorType
 
 
 class MyModule(torch.nn.Module):
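
The example now imports its helpers package-relatively (`..builder`, `..annotator`, `..torch_mlir_types`), which is why the new header comment says to run it with `python -m examples.torchfx.examples.example_add_tanh_sigmoid` from the repo root: relative imports only resolve when the file is executed as part of its package. For orientation only, a stand-in sketch of the FX tracing step these imports feed, using stock torch.fx instead of the fx_acc acc_tracer the example actually imports:

    # Hedged stand-in, not from the diff: trace the example's module with
    # torch.fx.symbolic_trace to show the kind of GraphModule that
    # build_module() later converts to torch-dialect MLIR.
    import torch
    import torch.fx

    class MyModule(torch.nn.Module):
        def forward(self, x, y):
            return torch.tanh(x) + torch.sigmoid(y)

    traced = torch.fx.symbolic_trace(MyModule())
    print(traced.graph)  # call_function nodes for tanh, sigmoid, add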
@@ -18,7 +23,9 @@ class MyModule(torch.nn.Module):
         super().__init__()
 
     def forward(self, x, y):
-        return torch.tanh(x) + torch.sigmoid(y)
+        # TODO: Debug issue with RefBackend
+        #return torch.tanh(x) + torch.sigmoid(y)
+        return torch.tanh(x)
 
 
 module = MyModule()
@@ -33,18 +40,17 @@ torch_mlir_module = build_module(traced_module)
 
 print("\n\nTORCH MLIR")
 torch_mlir_module.dump()
+print(torch_mlir_module.operation.verify())
 
-with npcomp.ir.Context() as ctx:
-    npcomp.register_all_dialects(ctx)
-    lowered_mlir_module = npcomp.ir.Module.parse(str(torch_mlir_module))
-    pm = PassManager.parse('torchscript-to-npcomp-backend-pipeline')
-    pm.run(lowered_mlir_module)
+with torch_mlir_module.context:
+    pm = PassManager.parse('torchscript-to-linalg-on-tensors-backend-pipeline')
+    pm.run(torch_mlir_module)
 
 print("\n\nLOWERED MLIR")
-lowered_mlir_module.dump()
+torch_mlir_module.dump()
 
-backend = refbackend.RefBackendNpcompBackend()
-compiled = backend.compile(lowered_mlir_module)
+backend = RefBackendLinalgOnTensorsBackend()
+compiled = backend.compile(torch_mlir_module)
 jit_module = backend.load(compiled)
 
 print("\n\nRunning Forward Function")
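
After the change the example no longer round-trips the module through a separate npcomp context; it runs the lowering pipeline in place on the torch-mlir module and hands that same module to the linalg-on-tensors reference backend. A condensed sketch of that flow, assembled only from the calls this hunk introduces (the `compile_and_run` wrapper is an illustrative name, and `torch_mlir_module` is assumed to come from build_module as earlier in the example):

    # Hedged sketch of the post-change flow: lower in place, compile with the
    # reference backend, and invoke the loaded module on numpy arrays.
    from torch_mlir.passmanager import PassManager
    from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend

    def compile_and_run(torch_mlir_module, *numpy_args):
        with torch_mlir_module.context:
            pm = PassManager.parse('torchscript-to-linalg-on-tensors-backend-pipeline')
            pm.run(torch_mlir_module)
        backend = RefBackendLinalgOnTensorsBackend()
        jit_module = backend.load(backend.compile(torch_mlir_module))
        return jit_module.forward(*numpy_args)

    # Usage (illustrative): compile_and_run(torch_mlir_module, numpy.ones((2, 2), numpy.float32))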
@@ -10,10 +10,9 @@ from torchvision import transforms
 
 from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder
 
-import npcomp
-from npcomp.passmanager import PassManager
-from npcomp.compiler.pytorch.backend import refbackend
-from npcomp.compiler.utils import logging
+from torch_mlir.passmanager import PassManager
+from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
 
 
 mb = ModuleBuilder()
@@ -57,7 +56,7 @@ def predictions(torch_func, jit_func, img, labels):
     print("PyTorch prediction")
     print(golden_prediction)
     prediction = top3_possibilities(torch.from_numpy(jit_func(img.numpy())))
-    print("NPCOMP prediction")
+    print("torch-mlir prediction")
     print(prediction)
 
 
@@ -107,14 +106,12 @@ class_annotator.annotateArgs(
 # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
 mb.import_module(recursivescriptmodule._c, class_annotator)
 
-backend = refbackend.RefBackendNpcompBackend()
-with npcomp.ir.Context() as ctx:
-    npcomp.register_all_dialects(ctx)
-    lowered_mlir_module = npcomp.ir.Module.parse(str(mb.module))
-    pm = PassManager.parse('torchscript-to-npcomp-backend-pipeline')
-    pm.run(lowered_mlir_module)
+backend = refbackend.RefBackendLinalgOnTensorsBackend()
+with mb.module.context:
+    pm = PassManager.parse('torchscript-to-linalg-on-tensors-backend-pipeline')
+    pm.run(mb.module)
 
-compiled = backend.compile(lowered_mlir_module)
+compiled = backend.compile(mb.module)
 jit_module = backend.load(compiled)
 
 predictions(test_module.forward, jit_module.forward, img, labels)
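
The TorchScript ResNet example gets the same treatment: the npcomp context/parse round trip is replaced by running the linalg-on-tensors pipeline directly on the ModuleBuilder's module and compiling with the reference backend. A condensed sketch of that path as it stands after the change; the torch.nn.Linear stand-in and the elided class-annotation calls are assumptions (the real example scripts resnet18 and annotates argument shapes), so this is illustrative rather than runnable end to end:

    # Hedged sketch: script a module, import it via ModuleBuilder, lower, and
    # compile with the linalg-on-tensors reference backend.
    import torch
    from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ModuleBuilder
    from torch_mlir.passmanager import PassManager
    from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend

    scripted = torch.jit.script(torch.nn.Linear(4, 2))  # stand-in for resnet18
    class_annotator = ClassAnnotator()
    # ... export / annotateArgs calls go here, as in the full example ...

    mb = ModuleBuilder()
    mb.import_module(scripted._c, class_annotator)

    backend = refbackend.RefBackendLinalgOnTensorsBackend()
    with mb.module.context:
        pm = PassManager.parse('torchscript-to-linalg-on-tensors-backend-pipeline')
        pm.run(mb.module)
    jit_module = backend.load(backend.compile(mb.module))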