mirror of https://github.com/llvm/torch-mlir
Add cos_e2e.py, test_utils and support for tensor inputs (#134)
parent
e2405e3ca8
commit
699bf5df45
|
@ -0,0 +1,32 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
|
||||
import npcomp
|
||||
from npcomp.compiler.pytorch.backend import refjit
|
||||
from npcomp.compiler.utils import logging
|
||||
|
||||
import test_utils
|
||||
|
||||
logging.enable()
|
||||
|
||||
torch.manual_seed(0)
|
||||
input = torch.rand(2, 3)
|
||||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
with mb.capture_function("cos", [input]) as f:
|
||||
result = torch.cos(input)
|
||||
f.returns([result])
|
||||
|
||||
backend = refjit.CompilerBackend()
|
||||
jit_module = backend.load(backend.compile(mb.module))
|
||||
|
||||
logging.debug(f"Executing jit_module.cos")
|
||||
test_utils.compare_outputs(torch.cos, jit_module.cos, input)
|
||||
|
||||
# This fails because ModuleBuilder represents torch.cos with a constant:
|
||||
# https://github.com/llvm/mlir-npcomp/issues/135
|
||||
test_utils.compare_outputs(torch.cos, jit_module.cos, input + 1)
|
|
@ -2,15 +2,15 @@
|
|||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
import sys
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch_mlir
|
||||
|
||||
import npcomp
|
||||
from npcomp.compiler.pytorch.backend.refjit import *
|
||||
from npcomp.compiler.pytorch.backend import refjit
|
||||
from npcomp.compiler.utils import logging
|
||||
|
||||
import test_utils
|
||||
|
||||
logging.enable()
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
@ -22,12 +22,8 @@ with mb.capture_function("mm", [lhs, rhs]) as f:
|
|||
result = torch.mm(lhs, rhs)
|
||||
f.returns([result])
|
||||
|
||||
backend = CompilerBackend()
|
||||
backend = refjit.CompilerBackend()
|
||||
jit_module = backend.load(backend.compile(mb.module))
|
||||
|
||||
jit_result = jit_module.mm(lhs.numpy(), rhs.numpy())
|
||||
|
||||
print(f"PyTorch Result = {result.numpy()}", file=sys.stderr)
|
||||
print(f"JIT Result = {jit_result}", file=sys.stderr)
|
||||
|
||||
np.testing.assert_allclose(result.numpy(), jit_result)
|
||||
test_utils.compare_outputs(torch.mm, jit_module.mm, lhs, rhs)
|
||||
test_utils.compare_outputs(torch.mm, jit_module.mm, lhs + 1, rhs - 1)
|
||||
|
|
|
@ -2,15 +2,15 @@
|
|||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
import sys
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch_mlir
|
||||
|
||||
import npcomp
|
||||
from npcomp.compiler.pytorch.backend.refjit import *
|
||||
from npcomp.compiler.pytorch.backend import refjit
|
||||
from npcomp.compiler.utils import logging
|
||||
|
||||
import test_utils
|
||||
|
||||
logging.enable()
|
||||
|
||||
lhs = torch.ones((4, 6, 1))
|
||||
|
@ -18,19 +18,20 @@ rhs = torch.ones((1, 1, 3)) * 0.6
|
|||
bias = torch.ones((1, 1, 3)) * 0.2
|
||||
threshold = torch.tensor((0.75, 0.25, 0.10))
|
||||
|
||||
|
||||
def mul_maximum(lhs, rhs, threshold, bias):
|
||||
return torch.maximum(lhs * rhs, threshold) + bias
|
||||
|
||||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
with mb.capture_function("mul_maximum", [lhs, rhs, threshold, bias]) as f:
|
||||
result = torch.maximum(lhs * rhs, threshold)
|
||||
result = result + bias
|
||||
result = mul_maximum(lhs, rhs, threshold, bias)
|
||||
f.returns([result])
|
||||
|
||||
backend = CompilerBackend()
|
||||
backend = refjit.CompilerBackend()
|
||||
jit_module = backend.load(backend.compile(mb.module))
|
||||
|
||||
jit_result = jit_module.mul_maximum(lhs.numpy(), rhs.numpy(), threshold.numpy(),
|
||||
bias.numpy())
|
||||
|
||||
print(f"PyTorch Result = {result.numpy()}", file=sys.stderr)
|
||||
print(f"JIT Result = {jit_result}", file=sys.stderr)
|
||||
|
||||
np.testing.assert_allclose(result.numpy(), jit_result)
|
||||
test_utils.compare_outputs(mul_maximum, jit_module.mul_maximum, lhs, rhs,
|
||||
threshold, bias)
|
||||
test_utils.compare_outputs(mul_maximum, jit_module.mul_maximum, lhs + 1,
|
||||
rhs + 2, threshold, bias)
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
import sys
|
||||
import textwrap
|
||||
|
||||
import numpy as np
|
||||
|
||||
INDENT = " "
|
||||
|
||||
|
||||
def _indent(value):
|
||||
return textwrap.indent(str(value), INDENT)
|
||||
|
||||
|
||||
def compare_outputs(torch_func, jit_func, *args):
|
||||
print('—' * 80)
|
||||
|
||||
print(f"Input args:\n{_indent(args)}", file=sys.stderr)
|
||||
result = torch_func(*args)
|
||||
jit_result = jit_func(*args)
|
||||
|
||||
np.testing.assert_allclose(result.numpy(), jit_result)
|
||||
|
||||
# Only print these if the test passes, as np.testing will print them if it
|
||||
# fails.
|
||||
print(f"PyTorch Result:\n{_indent(result.numpy())}", file=sys.stderr)
|
||||
print(f"JIT Result:\n{_indent(jit_result)}", file=sys.stderr)
|
|
@ -4,6 +4,8 @@
|
|||
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from mlir.ir import *
|
||||
from mlir.passmanager import *
|
||||
from npcomp.compiler.generic.backend import refjit as refjit_backend
|
||||
|
@ -25,6 +27,20 @@ TORCH_TO_TCF_PASSES = (
|
|||
is_enabled = refjit_backend.is_enabled
|
||||
|
||||
|
||||
class TorchJitModuleInvoker(refjit_backend.JitModuleInvoker):
|
||||
"""Allows torch.Tensor inputs to be passed to module invocations."""
|
||||
|
||||
def __getitem__(self, function_name: str):
|
||||
numpy_invoke = super().__getitem__(function_name)
|
||||
|
||||
def invoke(*args):
|
||||
args = tuple(
|
||||
arg.numpy() if isinstance(arg, torch.Tensor) else arg for arg in args)
|
||||
return numpy_invoke(*args)
|
||||
|
||||
return invoke
|
||||
|
||||
|
||||
class CompilerBackend:
|
||||
"""Main entry-point for the backend."""
|
||||
|
||||
|
@ -68,9 +84,6 @@ class CompilerBackend:
|
|||
imported_module, refjit_backend.get_runtime_libs())
|
||||
return jit_module
|
||||
|
||||
def load(self, jit_module):
|
||||
"""Loads a compiled artifact into the runtime.
|
||||
|
||||
Since this is a JIT instead of an AOT compiler,
|
||||
"""
|
||||
return refjit_backend.JitModuleInvoker(jit_module)
|
||||
def load(self, jit_module) -> TorchJitModuleInvoker:
|
||||
"""Loads a compiled artifact into the runtime."""
|
||||
return TorchJitModuleInvoker(jit_module)
|
||||
|
|
Loading…
Reference in New Issue