Add cos_e2e.py, test_utils and support for tensor inputs (#134)

pull/136/head
Phoenix Meadowlark 2020-11-24 19:02:50 -08:00 committed by GitHub
parent e2405e3ca8
commit 699bf5df45
6 changed files with 100 additions and 29 deletions

View File

@@ -0,0 +1,32 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import torch
import torch_mlir
import npcomp
from npcomp.compiler.pytorch.backend import refjit
from npcomp.compiler.utils import logging
import test_utils
logging.enable()
torch.manual_seed(0)
input = torch.rand(2, 3)
mb = torch_mlir.ModuleBuilder()
with mb.capture_function("cos", [input]) as f:
  result = torch.cos(input)
  f.returns([result])
backend = refjit.CompilerBackend()
jit_module = backend.load(backend.compile(mb.module))
logging.debug(f"Executing jit_module.cos")
test_utils.compare_outputs(torch.cos, jit_module.cos, input)
# This fails because ModuleBuilder represents torch.cos with a constant:
# https://github.com/llvm/mlir-npcomp/issues/135
test_utils.compare_outputs(torch.cos, jit_module.cos, input + 1)
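As the comment notes, the last comparison is expected to fail while https://github.com/llvm/mlir-npcomp/issues/135 is open: the captured function bakes the traced value in as a constant, so the compiled cos ignores its argument. A minimal sketch (not part of this commit) of how the script could tolerate that known failure, assuming compare_outputs surfaces the mismatch as np.testing's AssertionError:

# Hypothetical guard while issue #135 is open; not part of this change.
try:
  test_utils.compare_outputs(torch.cos, jit_module.cos, input + 1)
except AssertionError:
  logging.debug("Known failure: the captured cos ignores its input (issue #135)")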

View File

@@ -2,15 +2,15 @@
 # This file is licensed under a pytorch-style license
 # See frontends/pytorch/LICENSE for license information.
-import sys
-import numpy as np
 import torch
 import torch_mlir
 import npcomp
-from npcomp.compiler.pytorch.backend.refjit import *
+from npcomp.compiler.pytorch.backend import refjit
 from npcomp.compiler.utils import logging
+import test_utils
 logging.enable()
 torch.manual_seed(0)
@@ -22,12 +22,8 @@ with mb.capture_function("mm", [lhs, rhs]) as f:
   result = torch.mm(lhs, rhs)
   f.returns([result])
-backend = CompilerBackend()
+backend = refjit.CompilerBackend()
 jit_module = backend.load(backend.compile(mb.module))
-jit_result = jit_module.mm(lhs.numpy(), rhs.numpy())
-print(f"PyTorch Result = {result.numpy()}", file=sys.stderr)
-print(f"JIT Result = {jit_result}", file=sys.stderr)
-np.testing.assert_allclose(result.numpy(), jit_result)
+test_utils.compare_outputs(torch.mm, jit_module.mm, lhs, rhs)
+test_utils.compare_outputs(torch.mm, jit_module.mm, lhs + 1, rhs - 1)

View File

@@ -2,15 +2,15 @@
 # This file is licensed under a pytorch-style license
 # See frontends/pytorch/LICENSE for license information.
-import sys
-import numpy as np
 import torch
 import torch_mlir
 import npcomp
-from npcomp.compiler.pytorch.backend.refjit import *
+from npcomp.compiler.pytorch.backend import refjit
 from npcomp.compiler.utils import logging
+import test_utils
 logging.enable()
 lhs = torch.ones((4, 6, 1))
@@ -18,19 +18,20 @@ rhs = torch.ones((1, 1, 3)) * 0.6
 bias = torch.ones((1, 1, 3)) * 0.2
 threshold = torch.tensor((0.75, 0.25, 0.10))
+def mul_maximum(lhs, rhs, threshold, bias):
+  return torch.maximum(lhs * rhs, threshold) + bias
 mb = torch_mlir.ModuleBuilder()
 with mb.capture_function("mul_maximum", [lhs, rhs, threshold, bias]) as f:
-  result = torch.maximum(lhs * rhs, threshold)
-  result = result + bias
+  result = mul_maximum(lhs, rhs, threshold, bias)
   f.returns([result])
-backend = CompilerBackend()
+backend = refjit.CompilerBackend()
 jit_module = backend.load(backend.compile(mb.module))
-jit_result = jit_module.mul_maximum(lhs.numpy(), rhs.numpy(), threshold.numpy(),
-                                    bias.numpy())
-print(f"PyTorch Result = {result.numpy()}", file=sys.stderr)
-print(f"JIT Result = {jit_result}", file=sys.stderr)
-np.testing.assert_allclose(result.numpy(), jit_result)
+test_utils.compare_outputs(mul_maximum, jit_module.mul_maximum, lhs, rhs,
+                           threshold, bias)
+test_utils.compare_outputs(mul_maximum, jit_module.mul_maximum, lhs + 1,
+                           rhs + 2, threshold, bias)

View File

@@ -0,0 +1,29 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import sys
import textwrap
import numpy as np
INDENT = " "
def _indent(value):
  return textwrap.indent(str(value), INDENT)
def compare_outputs(torch_func, jit_func, *args):
  print("-" * 80, file=sys.stderr)  # separator between comparisons
print(f"Input args:\n{_indent(args)}", file=sys.stderr)
result = torch_func(*args)
jit_result = jit_func(*args)
np.testing.assert_allclose(result.numpy(), jit_result)
# Only print these if the test passes, as np.testing will print them if it
# fails.
print(f"PyTorch Result:\n{_indent(result.numpy())}", file=sys.stderr)
print(f"JIT Result:\n{_indent(jit_result)}", file=sys.stderr)

View File

@@ -4,6 +4,8 @@
 import os
+import torch
 from mlir.ir import *
 from mlir.passmanager import *
 from npcomp.compiler.generic.backend import refjit as refjit_backend
@@ -25,6 +27,20 @@ TORCH_TO_TCF_PASSES = (
 is_enabled = refjit_backend.is_enabled
+class TorchJitModuleInvoker(refjit_backend.JitModuleInvoker):
+  """Allows torch.Tensor inputs to be passed to module invocations."""
+
+  def __getitem__(self, function_name: str):
+    numpy_invoke = super().__getitem__(function_name)
+
+    def invoke(*args):
+      args = tuple(
+          arg.numpy() if isinstance(arg, torch.Tensor) else arg for arg in args)
+      return numpy_invoke(*args)
+
+    return invoke
 class CompilerBackend:
   """Main entry-point for the backend."""
@@ -68,9 +84,6 @@ class CompilerBackend:
         imported_module, refjit_backend.get_runtime_libs())
     return jit_module
-  def load(self, jit_module):
-    """Loads a compiled artifact into the runtime.
-
-    Since this is a JIT instead of an AOT compiler,
-    """
-    return refjit_backend.JitModuleInvoker(jit_module)
+  def load(self, jit_module) -> TorchJitModuleInvoker:
+    """Loads a compiled artifact into the runtime."""
+    return TorchJitModuleInvoker(jit_module)
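The practical effect of TorchJitModuleInvoker is that the example scripts no longer need to call .numpy() on their arguments before invoking the compiled module. A rough sketch of the call path; mb, lhs, and rhs are assumed to come from the mm example above rather than being defined here:

# Sketch only, reusing names from the mm example above.
backend = refjit.CompilerBackend()
jit_module = backend.load(backend.compile(mb.module))  # a TorchJitModuleInvoker
# invoke() converts torch.Tensor args to ndarrays, so both calls behave the same:
jit_module.mm(lhs, rhs)
jit_module.mm(lhs.numpy(), rhs.numpy())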