# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.

import torch
import torch_mlir

import npcomp
from npcomp.compiler.pytorch.backend import refjit, frontend_lowering
from npcomp.compiler.utils import logging

import test_utils

logging.enable()

# Seed the generator before drawing the sample input so that the random
# tensor is the same on every run.
torch.manual_seed(0)
input_tensor = torch.rand(2, 3)

# Record torch.cos applied to the sample input as a function named "cos".
mb = torch_mlir.ModuleBuilder()
with mb.capture_function("cos", [input_tensor]) as f:
    result = torch.cos(input_tensor)
    f.returns([result])

# Lower the captured module, then compile and load it with the refjit backend.
backend = refjit.CompilerBackend()
lowered_module = frontend_lowering.lower_module(mb.module)
jit_module = backend.load(backend.compile(lowered_module))

logging.debug("Executing jit_module.cos")
test_utils.compare_outputs(torch.cos, jit_module.cos, input_tensor)

# This fails because ModuleBuilder represents torch.cos with a constant:
# https://github.com/llvm/mlir-npcomp/issues/135
test_utils.compare_outputs(torch.cos, jit_module.cos, input_tensor + 1)