mirror of https://github.com/llvm/torch-mlir
34 lines
810 B
Python
34 lines
810 B
Python
# -*- Python -*-
|
|
# This file is licensed under a pytorch-style license
|
|
# See frontends/pytorch/LICENSE for license information.
|
|
|
|
import torch
|
|
import torch_mlir
|
|
|
|
import npcomp
|
|
from npcomp.compiler.pytorch.backend import refjit, frontend_lowering
|
|
from npcomp.compiler.utils import logging
|
|
|
|
import test_utils
|
|
|
|
# Enable the logging helper imported from npcomp.compiler.utils.
logging.enable()

# Seed PyTorch's random number generator with a fixed value.
torch.manual_seed(0)

# 2x2 tensor of ones, used below as the example input for tracing and as
# the base input for the output comparisons.
arg0 = torch.ones(2, 2)
|
|
|
|
def fun(a):
    """Return tanh(a), computed through the ``out=`` form of ``torch.tanh``.

    A fresh 2x2 zero tensor is allocated and handed to ``torch.tanh`` as
    its output buffer; that buffer is what gets returned.
    """
    result = torch.zeros(2, 2)
    # Write the result into the pre-allocated tensor instead of using the
    # value returned by torch.tanh.
    torch.tanh(a, out=result)
    return result
|
|
|
|
# Trace `fun` applied to `arg0` into a torch_mlir module as a function
# named "test"; the value produced by `fun` is registered as its return.
mb = torch_mlir.ModuleBuilder()
with mb.capture_function("test", [arg0]) as f:
    f.returns([fun(arg0)])

# Compile with the reference JIT backend. `refjit` is the backend module
# this file imports; the previously used `iree` name was never imported and
# raised NameError.
# NOTE(review): assumes refjit exposes `RefjitNpcompBackend` (mirroring the
# naming of the IREE backend class) -- confirm against the npcomp version
# in use.
backend = refjit.RefjitNpcompBackend()
jit_module = backend.load(
    backend.compile(frontend_lowering.lower_module(mb.module)))

# Check the compiled function against the eager function it was traced
# from. `fun` takes a single tensor, so one input is passed per comparison;
# the earlier calls compared against torch.mm and an undefined `arg1`.
test_utils.compare_outputs(fun, jit_module.test, arg0)
test_utils.compare_outputs(fun, jit_module.test, arg0 + 1)
|