torch-mlir/frontends/pytorch/examples/mm_e2e.py

30 lines
730 B
Python
Raw Normal View History

# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import torch
import torch_mlir
import npcomp
from npcomp.compiler.pytorch.backend import refjit
from npcomp.compiler.utils import logging
import test_utils
# End-to-end example: capture a single torch.mm call with the torch_mlir
# ModuleBuilder, compile it with the npcomp reference JIT backend, and
# compare the compiled function's output with eager PyTorch.

# Turn on npcomp's logging before any capture/compile work happens.
logging.enable()

# Fixed seed so the random operands are the same on every run.
torch.manual_seed(0)
lhs = torch.rand(2, 3)
rhs = torch.rand(3, 4)

mb = torch_mlir.ModuleBuilder()
# The torch.mm call must run *inside* the capture context, and the value to
# return must be declared via f.returns() before the context exits.
# NOTE(review): presumably ops executed inside this context are recorded
# into a function named "mm" in mb.module -- confirm against torch_mlir.
with mb.capture_function("mm", [lhs, rhs]) as f:
    result = torch.mm(lhs, rhs)
    f.returns([result])

# Compile the captured module and load it; the captured function is then
# reachable by name as an attribute of the loaded module (jit_module.mm).
backend = refjit.CompilerBackend()
jit_module = backend.load(backend.compile(mb.module))

# Check the compiled function against eager torch.mm: first on the operands
# used during capture, then on shifted operands that were not seen at
# capture time.
# NOTE(review): exact comparison/tolerance semantics live in
# test_utils.compare_outputs, which is not visible here.
test_utils.compare_outputs(torch.mm, jit_module.mm, lhs, rhs)
test_utils.compare_outputs(torch.mm, jit_module.mm, lhs + 1, rhs - 1)