mirror of https://github.com/llvm/torch-mlir
Removed import typo in torchfx example
parent
603e068e45
commit
2b18aad807
|
@ -8,9 +8,9 @@ import npcomp
|
|||
from npcomp.compiler.pytorch.backend import refbackend
|
||||
from npcomp.passmanager import PassManager
|
||||
|
||||
from torchfx2iree.builder import build_module
|
||||
from torchfx2iree.annotator import annotate_forward_args
|
||||
from torchfx2iree.torch_mlir_types import TorchTensorType
|
||||
from torchfx.builder import build_module
|
||||
from torchfx.annotator import annotate_forward_args
|
||||
from torchfx.torch_mlir_types import TorchTensorType
|
||||
|
||||
|
||||
class MyModule(torch.nn.Module):
|
||||
|
@ -49,5 +49,5 @@ jit_module = backend.load(compiled)
|
|||
|
||||
print("\n\nRunning Forward Function")
|
||||
t = torch.rand((2, 2), dtype=torch.float)
|
||||
print("Compiled result:\n", jit_module.forward(t, t))
|
||||
print("Compiled result:\n", jit_module.forward(t.numpy(), t.numpy()))
|
||||
print("\nExpected result:\n", module.forward(t, t))
|
||||
|
|
Loading…
Reference in New Issue