diff --git a/frontends/pytorch/examples/__init__.py b/frontends/pytorch/examples/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/frontends/pytorch/examples/cos_e2e.py b/frontends/pytorch/examples/cos_e2e.py new file mode 100644 index 000000000..994e9e387 --- /dev/null +++ b/frontends/pytorch/examples/cos_e2e.py @@ -0,0 +1,32 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +import torch +import torch_mlir + +import npcomp +from npcomp.compiler.pytorch.backend import refjit +from npcomp.compiler.utils import logging + +import test_utils + +logging.enable() + +torch.manual_seed(0) +input = torch.rand(2, 3) + +mb = torch_mlir.ModuleBuilder() +with mb.capture_function("cos", [input]) as f: + result = torch.cos(input) + f.returns([result]) + +backend = refjit.CompilerBackend() +jit_module = backend.load(backend.compile(mb.module)) + +logging.debug(f"Executing jit_module.cos") +test_utils.compare_outputs(torch.cos, jit_module.cos, input) + +# This fails because ModuleBuilder represents torch.cos with a constant: +# https://github.com/llvm/mlir-npcomp/issues/135 +test_utils.compare_outputs(torch.cos, jit_module.cos, input + 1) diff --git a/frontends/pytorch/examples/mm_e2e.py b/frontends/pytorch/examples/mm_e2e.py index a99fff1e0..fa993b5d7 100644 --- a/frontends/pytorch/examples/mm_e2e.py +++ b/frontends/pytorch/examples/mm_e2e.py @@ -2,15 +2,15 @@ # This file is licensed under a pytorch-style license # See frontends/pytorch/LICENSE for license information. -import sys -import numpy as np import torch import torch_mlir import npcomp -from npcomp.compiler.pytorch.backend.refjit import * +from npcomp.compiler.pytorch.backend import refjit from npcomp.compiler.utils import logging +import test_utils + logging.enable() torch.manual_seed(0) @@ -22,12 +22,8 @@ with mb.capture_function("mm", [lhs, rhs]) as f: result = torch.mm(lhs, rhs) f.returns([result]) -backend = CompilerBackend() +backend = refjit.CompilerBackend() jit_module = backend.load(backend.compile(mb.module)) -jit_result = jit_module.mm(lhs.numpy(), rhs.numpy()) - -print(f"PyTorch Result = {result.numpy()}", file=sys.stderr) -print(f"JIT Result = {jit_result}", file=sys.stderr) - -np.testing.assert_allclose(result.numpy(), jit_result) +test_utils.compare_outputs(torch.mm, jit_module.mm, lhs, rhs) +test_utils.compare_outputs(torch.mm, jit_module.mm, lhs + 1, rhs - 1) diff --git a/frontends/pytorch/examples/mul_maximum_e2e.py b/frontends/pytorch/examples/mul_maximum_e2e.py index 2cb0d8c21..c916f661b 100644 --- a/frontends/pytorch/examples/mul_maximum_e2e.py +++ b/frontends/pytorch/examples/mul_maximum_e2e.py @@ -2,15 +2,15 @@ # This file is licensed under a pytorch-style license # See frontends/pytorch/LICENSE for license information. -import sys -import numpy as np import torch import torch_mlir import npcomp -from npcomp.compiler.pytorch.backend.refjit import * +from npcomp.compiler.pytorch.backend import refjit from npcomp.compiler.utils import logging +import test_utils + logging.enable() lhs = torch.ones((4, 6, 1)) @@ -18,19 +18,20 @@ rhs = torch.ones((1, 1, 3)) * 0.6 bias = torch.ones((1, 1, 3)) * 0.2 threshold = torch.tensor((0.75, 0.25, 0.10)) + +def mul_maximum(lhs, rhs, threshold, bias): + return torch.maximum(lhs * rhs, threshold) + bias + + mb = torch_mlir.ModuleBuilder() with mb.capture_function("mul_maximum", [lhs, rhs, threshold, bias]) as f: - result = torch.maximum(lhs * rhs, threshold) - result = result + bias + result = mul_maximum(lhs, rhs, threshold, bias) f.returns([result]) -backend = CompilerBackend() +backend = refjit.CompilerBackend() jit_module = backend.load(backend.compile(mb.module)) -jit_result = jit_module.mul_maximum(lhs.numpy(), rhs.numpy(), threshold.numpy(), - bias.numpy()) - -print(f"PyTorch Result = {result.numpy()}", file=sys.stderr) -print(f"JIT Result = {jit_result}", file=sys.stderr) - -np.testing.assert_allclose(result.numpy(), jit_result) +test_utils.compare_outputs(mul_maximum, jit_module.mul_maximum, lhs, rhs, + threshold, bias) +test_utils.compare_outputs(mul_maximum, jit_module.mul_maximum, lhs + 1, + rhs + 2, threshold, bias) diff --git a/frontends/pytorch/examples/test_utils.py b/frontends/pytorch/examples/test_utils.py new file mode 100644 index 000000000..25a1b8efb --- /dev/null +++ b/frontends/pytorch/examples/test_utils.py @@ -0,0 +1,29 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +import sys +import textwrap + +import numpy as np + +INDENT = " " + + +def _indent(value): + return textwrap.indent(str(value), INDENT) + + +def compare_outputs(torch_func, jit_func, *args): + print('—' * 80) + + print(f"Input args:\n{_indent(args)}", file=sys.stderr) + result = torch_func(*args) + jit_result = jit_func(*args) + + np.testing.assert_allclose(result.numpy(), jit_result) + + # Only print these if the test passes, as np.testing will print them if it + # fails. + print(f"PyTorch Result:\n{_indent(result.numpy())}", file=sys.stderr) + print(f"JIT Result:\n{_indent(jit_result)}", file=sys.stderr) diff --git a/python/npcomp/compiler/pytorch/backend/refjit.py b/python/npcomp/compiler/pytorch/backend/refjit.py index 23def3656..f3cb4ddd6 100644 --- a/python/npcomp/compiler/pytorch/backend/refjit.py +++ b/python/npcomp/compiler/pytorch/backend/refjit.py @@ -4,6 +4,8 @@ import os +import torch + from mlir.ir import * from mlir.passmanager import * from npcomp.compiler.generic.backend import refjit as refjit_backend @@ -25,6 +27,20 @@ TORCH_TO_TCF_PASSES = ( is_enabled = refjit_backend.is_enabled +class TorchJitModuleInvoker(refjit_backend.JitModuleInvoker): + """Allows torch.Tensor inputs to be passed to module invocations.""" + + def __getitem__(self, function_name: str): + numpy_invoke = super().__getitem__(function_name) + + def invoke(*args): + args = tuple( + arg.numpy() if isinstance(arg, torch.Tensor) else arg for arg in args) + return numpy_invoke(*args) + + return invoke + + class CompilerBackend: """Main entry-point for the backend.""" @@ -68,9 +84,6 @@ class CompilerBackend: imported_module, refjit_backend.get_runtime_libs()) return jit_module - def load(self, jit_module): - """Loads a compiled artifact into the runtime. - - Since this is a JIT instead of an AOT compiler, - """ - return refjit_backend.JitModuleInvoker(jit_module) + def load(self, jit_module) -> TorchJitModuleInvoker: + """Loads a compiled artifact into the runtime.""" + return TorchJitModuleInvoker(jit_module)