Removed import typo in torchfx example

pull/325/head
Ramiro Leal-Cavazos 2021-09-22 23:13:45 +00:00 committed by Sean Silva
parent 603e068e45
commit 2b18aad807
1 changed files with 4 additions and 4 deletions

View File

@ -8,9 +8,9 @@ import npcomp
from npcomp.compiler.pytorch.backend import refbackend
from npcomp.passmanager import PassManager
from torchfx2iree.builder import build_module
from torchfx2iree.annotator import annotate_forward_args
from torchfx2iree.torch_mlir_types import TorchTensorType
from torchfx.builder import build_module
from torchfx.annotator import annotate_forward_args
from torchfx.torch_mlir_types import TorchTensorType
class MyModule(torch.nn.Module):
@ -49,5 +49,5 @@ jit_module = backend.load(compiled)
print("\n\nRunning Forward Function")
t = torch.rand((2, 2), dtype=torch.float)
print("Compiled result:\n", jit_module.forward(t, t))
print("Compiled result:\n", jit_module.forward(t.numpy(), t.numpy()))
print("\nExpected result:\n", module.forward(t, t))