diff --git a/frontends/pytorch/e2e_testing/torchscript/basic.py b/frontends/pytorch/e2e_testing/torchscript/basic.py new file mode 100644 index 000000000..485e06853 --- /dev/null +++ b/frontends/pytorch/e2e_testing/torchscript/basic.py @@ -0,0 +1,69 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import torch + +from torch_mlir.torchscript.e2e_test.framework import TestUtils +from torch_mlir.torchscript.e2e_test.registry import register_test_case +from torch_mlir.torchscript.annotations import annotate_args, export + +# ============================================================================== + +class MmModule(torch.nn.Module): + def __init__(self): + super().__init__() + @export + @annotate_args([ + None, + ([-1, -1], torch.float32), + ([-1, -1], torch.float32), + ]) + def forward(self, lhs, rhs): + return torch.mm(lhs, rhs) + +@register_test_case(module_factory=lambda: MmModule()) +def MmModule_basic(module, tu: TestUtils): + module.forward(tu.rand(4, 4), tu.rand(4, 4)) + +@register_test_case(module_factory=lambda: MmModule()) +def MmModule_chained(module, tu: TestUtils): + res = module.forward(tu.rand(4, 4), tu.rand(4, 4)) + module.forward(res, res) + +# ============================================================================== + +class TanhModule(torch.nn.Module): + def __init__(self): + super().__init__() + @export + @annotate_args([ + None, + ([2, 3, -1], torch.float32), + ]) + def forward(self, x): + return torch.tanh(x) + +@register_test_case(module_factory=lambda: TanhModule()) +def TanhModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 1)) + +# ============================================================================== + +class MmTanhModule(torch.nn.Module): + def __init__(self): + super().__init__() + @export + @annotate_args([ + None, + ([-1, -1], torch.float32), + ([-1, -1], torch.float32), + ]) + def forward(self, lhs, rhs): + return torch.tanh(self.matmul(lhs, rhs)) + def matmul(self, lhs, rhs): + return torch.mm(lhs, rhs) + +@register_test_case(module_factory=lambda: MmTanhModule()) +def MmTanhModule_basic(module, tu: TestUtils): + module.forward(tu.rand(4, 2), tu.rand(2, 4)) diff --git a/frontends/pytorch/e2e_testing/torchscript/main.py b/frontends/pytorch/e2e_testing/torchscript/main.py new file mode 100644 index 000000000..4b06cc774 --- /dev/null +++ b/frontends/pytorch/e2e_testing/torchscript/main.py @@ -0,0 +1,40 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import argparse + +from torch_mlir.torchscript.e2e_test.framework import run_tests, report_results +from torch_mlir.torchscript.e2e_test.registry import GLOBAL_TEST_REGISTRY + +# Available test configs. +from torch_mlir.torchscript.e2e_test.configs import ( + RefBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig +) + +# Import tests to register them in the global registry. +import basic + +def main(): + parser = argparse.ArgumentParser(description='Run torchscript e2e tests.') + parser.add_argument('--config', + choices=['native_torch', 'torchscript', 'refbackend'], + default='refbackend', + help=''' +Meaning of options: +"refbackend": run through npcomp's RefBackend. +"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic). +"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly). +''') + args = parser.parse_args() + if args.config == 'refbackend': + config = RefBackendTestConfig() + elif args.config == 'native_torch': + config = NativeTorchTestConfig() + elif args.config == 'torchscript': + config = TorchScriptTestConfig() + results = run_tests(GLOBAL_TEST_REGISTRY, config) + report_results(results) + +if __name__ == '__main__': + main() diff --git a/frontends/pytorch/python/torch_mlir/torchscript/__init__.py b/frontends/pytorch/python/torch_mlir/torchscript/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/frontends/pytorch/python/torch_mlir/torchscript/annotations.py b/frontends/pytorch/python/torch_mlir/torchscript/annotations.py new file mode 100644 index 000000000..57b9dd92b --- /dev/null +++ b/frontends/pytorch/python/torch_mlir/torchscript/annotations.py @@ -0,0 +1,95 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import List, Optional, Tuple + +import torch + +import torch_mlir + +# Decorators + +# Currently, these decorators are very low-level and map 1:1 with +# methods on `torch_mlir.ClassAnnotator`. Eventually, we expect there to +# be a more elaborate Python layer which allows all the different annotations +# to be expressed conveniently and gives clearer error reports when +# the annotations aren't acceptable. + + +def export(fn): + """Decorator that tells the npcomp compiler that a method is exported. + + By default, no methods are exported, which is very important for + the compiler, because otherwise most Torch programs consist of a sea + of tiny exported functions with no rank or dtype information + (see `annotate_args`), which the compiler cannot do much with. + + Note that this is different from `torch.jit.export`, which controls + which methods are scripted in the first place. For non-`forward` methods, + using this decorator usually means you also need `torch.jit.export`. + Conceptually, this decorator is annotating the scripted module, but is + applied to the original `torch.nn.Module` for convenience. + """ + fn._npcomp_export = True + return fn + + +ArgAnnotation = Tuple[List[int], torch.dtype] + + +# TODO: Replace with py3 extended argument annotations when available. +# See https://www.python.org/dev/peps/pep-0593/ +def annotate_args(annotations: List[Optional[ArgAnnotation]]): + """Decorator that tells the npcomp compiler information about arguments. + + The `annotations` should be a list of the same length as the number of + argument to the method (including `self`). Each list entry is either: + - None, corresponding to providing the compiler with no information. + - A 2-tuple consisting of a shape and a dtype, such as + `([2, 3, 4], torch.float32)`. A dimension with an unknown size can be + indicated by using `-1` as the size. This provides the compiler a + guarantee that the argument will always dynamically have the described + shape and dtype. + """ + + # TODO: Check the number of arguments matches the number of arg annotations. + def decorator(fn): + fn._npcomp_arg_annotations = annotations + return fn + + return decorator + + +# Utilities for extracting decorated information into torch_mlir.ClassAnnotator. + + +def _recursively_extract_annotations( + module: torch.nn.Module, scripted: torch.jit.ScriptModule, + class_annotator: torch_mlir.ClassAnnotator): + assert module.__class__.__name__ == scripted.original_name, "script module does not come from specified module" + + # Extract information on methods. + for method_name, scripted_method in scripted.__dict__.items(): + if not isinstance(scripted_method, torch.ScriptMethod): + continue + method = getattr(module, method_name) + if hasattr(method, '_npcomp_export'): + class_annotator.exportPath(scripted._c._type(), [method_name]) + if hasattr(method, '_npcomp_arg_annotations'): + class_annotator.annotateShapesAndDtypes( + scripted._c._type(), [method_name], + method._npcomp_arg_annotations) + # Recurse. + for name, child in module.named_children(): + scripted_child = getattr(scripted, name) + _recursively_extract_annotations(child, scripted_child, + class_annotator) + + +def extract_annotations(program: torch.nn.Module, + scripted: torch.jit.ScriptModule, + class_annotator: torch_mlir.ClassAnnotator): + """Populate the ClassAnnotator with annotations extracted from `program`.""" + class_annotator.exportNone(scripted._c._type()) + _recursively_extract_annotations(program, scripted, class_annotator) diff --git a/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/__init__.py b/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/configs/__init__.py b/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/configs/__init__.py new file mode 100644 index 000000000..935aca84a --- /dev/null +++ b/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/configs/__init__.py @@ -0,0 +1,7 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .ref_backend import RefBackendTestConfig +from .native_torch import NativeTorchTestConfig +from .torchscript import TorchScriptTestConfig diff --git a/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/configs/native_torch.py b/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/configs/native_torch.py new file mode 100644 index 000000000..6afb97a0b --- /dev/null +++ b/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/configs/native_torch.py @@ -0,0 +1,33 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import copy +from typing import Any + +import torch + +from torch_mlir.torchscript.e2e_test.framework import TestConfig, Trace, TraceItem + + +class NativeTorchTestConfig(TestConfig): + """TestConfig that just runs the torch.nn.Module without compiling""" + def __init__(self): + super().__init__() + + def compile(self, program: torch.nn.Module) -> torch.nn.Module: + return program + + def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace: + # TODO: Deepcopy the torch.nn.Module, so that if the program is + # stateful then it does not mutate the original compiled program. + result: Trace = [] + for item in trace: + outputs = getattr(artifact, item.symbol)(*item.inputs) + if isinstance(outputs, torch.Tensor): + outputs = [outputs] + result.append( + TraceItem(symbol=item.symbol, + inputs=item.inputs, + outputs=outputs)) + return result diff --git a/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/configs/ref_backend.py b/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/configs/ref_backend.py new file mode 100644 index 000000000..486a71417 --- /dev/null +++ b/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/configs/ref_backend.py @@ -0,0 +1,47 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Any + +import numpy as np +import torch + +import torch_mlir +from npcomp.compiler.pytorch.backend import refjit, frontend_lowering +from torch_mlir.torchscript.e2e_test.framework import TestConfig, Trace, TraceItem +from torch_mlir.torchscript.annotations import extract_annotations + + +class RefBackendTestConfig(TestConfig): + """TestConfig that just runs the torch.nn.Module through RefBackend.""" + def __init__(self): + super().__init__() + self.backend = refjit.CompilerBackend() + + def compile(self, program: torch.nn.Module) -> Any: + mb = torch_mlir.ModuleBuilder() + scripted = torch.jit.script(program) + class_annotator = torch_mlir.ClassAnnotator() + + extract_annotations(program, scripted, class_annotator) + + mb.import_module(scripted._c, class_annotator) + # Lower module in place. + frontend_lowering.lower_object_graph(mb.module) + return self.backend.compile(mb.module) + + def run(self, artifact: Any, trace: Trace) -> Trace: + jit_module = self.backend.load(artifact) + result: Trace = [] + for item in trace: + numpy_inputs = [t.numpy() for t in item.inputs] + outputs = getattr(jit_module, item.symbol)(*numpy_inputs) + if isinstance(outputs, np.ndarray): + outputs = [outputs] + torch_outputs = [torch.tensor(ndarray) for ndarray in outputs] + result.append( + TraceItem(symbol=item.symbol, + inputs=item.inputs, + outputs=torch_outputs)) + return result diff --git a/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/configs/torchscript.py b/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/configs/torchscript.py new file mode 100644 index 000000000..378ecd81a --- /dev/null +++ b/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/configs/torchscript.py @@ -0,0 +1,34 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import copy +from typing import Any + +import torch + +from torch_mlir.torchscript.e2e_test.framework import TestConfig, Trace, TraceItem + + +class TorchScriptTestConfig(TestConfig): + """TestConfig that runs the torch.nn.Module through TorchScript""" + def __init__(self): + super().__init__() + + def compile(self, program: torch.nn.Module) -> torch.jit.ScriptModule: + return torch.jit.script(program) + + def run(self, artifact: torch.jit.ScriptModule, trace: Trace) -> Trace: + # TODO: Deepcopy the torch.jit.ScriptModule, so that if the program is + # stateful then it does not mutate the original compiled program. + + result: Trace = [] + for item in trace: + outputs = getattr(artifact, item.symbol)(*item.inputs) + if isinstance(outputs, torch.Tensor): + outputs = [outputs] + result.append( + TraceItem(symbol=item.symbol, + inputs=item.inputs, + outputs=outputs)) + return result diff --git a/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/framework.py b/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/framework.py new file mode 100644 index 000000000..b869ea6b0 --- /dev/null +++ b/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/framework.py @@ -0,0 +1,260 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +""" +# End-to-end testing framework for TorchScript. + +For the purposes of this framework, "end to end" means the first "end" is +a `torch.nn.Module`, and the second "end" is execution. + +## Architecture + +A program for this testing framework is considered to be a `torch.nn.Module`, +which has a public interface consisting of its methods and instance attributes. + +A test in the framework consists conceputally of a list of calls into +the methods of a module (TODO: extend to instance attributes). It is expected +that the outputs match between the program run on a backend (controlled by +a TestConfig) and a golden trace obtained by running on native Torch (without +compiling or TorchScript'ing). +""" + +import abc +from typing import Any, Callable, List, NamedTuple, TypeVar + +import torch + + +class TraceItem(NamedTuple): + # The externally visible symbol name that is called. + # For example `"forward"` or `"submodule.forward"`. + symbol: str + # The list of inputs to the call. + inputs: List[torch.Tensor] # TODO: Support more types. + # The outputs from the call. + # Sometimes this field is treated as golden outputs from a test. + # Sometimes this field is treated as ignored, such as the input trace + # provided to `TestConfig.run`. + outputs: List[torch.Tensor] # TODO: Support more types + + +# A trace of invocations to the program. +# This is an ordered sequence of external invocations to a program's +# public boundary. +Trace = List[TraceItem] + +# A type shared between the result of `TestConfig.compile` and the input +# to `TestConfig.run`. Each backend will likely have a different definition of +# this type. +CompiledArtifact = TypeVar('CompiledArtifact') + + +class TestConfig(abc.ABC): + """The interface implemented by backends to run tests. + + The testing framework expects to be able to call `compile` to compile + a torch.nn.Module, and then pass the compiled artifact to `run` to run it. + + Note that the definition of "compiled artifact" here is quite loose, and + this interface allows for many different use cases besides simple testing. + + For example, this interface can be overridden to be a "data collector" + to gather information across all the test cases. For example, + a compiler backend could override "compile" to just return some IR at a + useful intermediate abstraction level (rather than the final compiled + artifact), and then have "run" save this intermediate IR + the trace as + input to some lower-level software stack's testing format. + + The set of TestConfig's is expected to be pluggable and provided by + users to suit their own needs. We provide a few configs out of the box + in the `configs` submodule of this package, but those are intended + to be for basic inspiration and enough for our own testing. + Backends to npcomp will likely have more elaborate TestConfig's, such + as `compile` being "compile for such-and-such DSP with these vectorization + cost model flags" and `run` being "connect to Android phone with + device ID 1234 and upload a program to run on it's DSP core, and also set + power throttling settings to 'performance'". + + That is also why this class is not called "backend", as it + encapsulates potentially many specific details of the test configuration + process as well. There isn't a general way to disentangle test configuration + from the compile/run process specific to a logical backend, since each + backend (compiler backend and runtime target) will have an arbitrarily + wild and wonderful set of possible configurations that we cannot predict. + """ + # This is not a frontend-lowered module, to allow various testing at the PyTorch level. + # We can have a helper class NpcompBackendTestConfig which does that. + @abc.abstractmethod + def compile(self, program: torch.nn.Module) -> CompiledArtifact: + """Compile the provided torch.nn.Module into a compiled artifact""" + pass + + # Any should match result of `compile`. + + @abc.abstractmethod + def run(self, artifact: CompiledArtifact, trace: Trace) -> Trace: + """Run the compiled artifact produced by `compile`. + + The backend should load the compiled artifact and call the + symbol names listed in `trace` with their respective inputs (the outputs + of `trace` should be ignored). A new identical trace with outputs + populated should be returned. + + This method should assume that `artifact` is being shared with + multiple parallel invocations of `run`, and so it should not be mutated. + This property is typicaly trivially satisfied for a true + "compiled artifact", but some backends don't directly involve a + compiled artifact per se (like a backend for which `CompiledArtifact` is + `torch.nn.Module` and `run` just invokes the torch.nn.Module itself) + + Args: + artifact: a compiled artifact produced by `compile`. + trace: The external invocations to stimulate the module. + Returns: + A trace with outputs recorded according to the results of running + on this backend. + """ + pass + + +# Utilities for common testing trace generation. +# Also, resets the random seed for reproducibility. +# TODO: If generating in parallel, how to have manual_seed be local? +class TestUtils: + """Utilities for executing a test. + + Test cases are provided an instance of this class to make test cases + more succinct. + + For reproducibility, this class also resets the random seed. + TODO: Figure out how to seed reset properly scoped to just a test case + (such as when running tests in parallel) + """ + def __init__(self): + torch.manual_seed(0) + + # TODO: Add zeros/ones/etc. as convenient. + def rand(self, *sizes): + if len(sizes) == 0: + return torch.rand([]) + return torch.rand(*sizes) + + +class Test(NamedTuple): + """A description of a test as produced by the test frontend. + """ + # Stable name for error reporting. + # + # This name's stability is also useful for backend, which want to + # generate their own lower-level test suites based on this framework. + # + # It is expected that those backends will need additional + # metadata to describe their test configurations, so having a unique + # key to keep that information associated is important. + unique_name: str + # A callable which produces the module under test. + # This is a callable to allow lazily creating the module. + program_factory: Callable[[], torch.nn.Module] + # A callable which provides external stimuli to the module. + # The first parameter is a torch.nn.Module (or a `_Tracer` wrapping that + # module, actually). + # The secon parameter is a `TestUtils` instance for convenience. + program_invoker: Callable[[Any, TestUtils], None] + + +class TestResult(NamedTuple): + # Stable unique name for error reporting and test suite configuration. + # + # Tests frequently need some additional data (such as expected pass/fail + # status, desired test configurations, etc.), and this gives a key to + # associate to. This avoids extending this class arbitrarily for every + # possible requirement from the test framework. + # + # This name is also useful for backends that are generating their own + # lower-level test suites from this framework for the same reasons, though + # those reasons are stronger because we cannot simply extend this + # class. + unique_name: str # Should match Test.unique_name for corresponding test. + # The trace produced by the backend. + trace: Trace + # The golden trace which `trace` is expected to match. + golden_trace: Trace + + +class _Tracer: + """Wrapper around a `torch.nn.Module` that records calls into it. + + The inputs and outputs of each call are recorded in a Trace. + """ + module: torch.nn.Module + trace: Trace + + def __init__(self, module: torch.nn.Module): + self.module = module + self.trace = [] + + def __getattr__(self, name): + # TODO: Handle `module.foo.bar.baz` nesting. + # For now, we are limited to attributes of the top-level module. + def invoke(*args): + raw_outputs = getattr(self.module, name)(*args) + if isinstance(raw_outputs, torch.Tensor): + outputs = [raw_outputs] + self.trace.append( + TraceItem(symbol=name, inputs=args, outputs=outputs)) + return raw_outputs + + return invoke + + def get_trace(self): + return self.trace + + +def _generate_golden_trace(test: Test) -> Trace: + tracer = _Tracer(test.program_factory()) + test.program_invoker(tracer, TestUtils()) + return tracer.get_trace() + + +def run_tests(tests: List[Test], config: TestConfig) -> List[TestResult]: + """Invoke the given `Test`'s with the provided `TestConfig`.""" + results = [] + for test in tests: + golden_trace = _generate_golden_trace(test) + # TODO: Precompile everything in parallel. + compiled = config.compile(test.program_factory()) + # TODO: Run in parallel. + trace = config.run(compiled, golden_trace) + results.append( + TestResult(unique_name=test.unique_name, + trace=trace, + golden_trace=golden_trace)) + return results + + +def report_results(results: List[TestResult]): + """Provide a basic error report summarizing various TestResult's.""" + for result in results: + failed = False + for item_num, (item, golden_item) in enumerate( + zip(result.trace, result.golden_trace)): + assert item.symbol == golden_item.symbol + assert len(item.inputs) == len(golden_item.inputs) + assert len(item.outputs) == len(golden_item.outputs) + for input, golden_input in zip(item.inputs, golden_item.inputs): + assert torch.allclose(input, golden_input) + for output_num, (output, golden_output) in enumerate( + zip(item.outputs, golden_item.outputs)): + # TODO: Refine error message. Things to consider: + # - Very large tensors -- don't spew, but give useful info + # - Smaller tensors / primitives -- want to show exact values + # - Machine parseable format? + if not torch.allclose(output, golden_output): + print( + f'Error: in call #{item_num} into the module: result #{output_num} not close in call to "{item.symbol}"' + ) + failed = True + if failed: + print('FAILURE "{}"'.format(result.unique_name)) + else: + print('SUCCESS "{}"'.format(result.unique_name)) diff --git a/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/registry.py b/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/registry.py new file mode 100644 index 000000000..34e616604 --- /dev/null +++ b/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/registry.py @@ -0,0 +1,30 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Callable + +import torch + +from .framework import Test + +# The global registry of tests. +GLOBAL_TEST_REGISTRY = [] + + +def register_test_case(module_factory: Callable[[], torch.nn.Module]): + """Convenient decorator-based test registration. + + Adds a `framework.Test` to the global test registry based on the decorated + function. The test's `unique_name` is taken from the function name, the + test's `program_factory` is taken from `module_factory`, and the + `program_invoker` is the decorated function. + """ + def decorator(f): + GLOBAL_TEST_REGISTRY.append( + Test(unique_name=f.__name__, + program_factory=module_factory, + program_invoker=f)) + return f + + return decorator diff --git a/frontends/pytorch/test/ivalue_import/annotations/sugar.py b/frontends/pytorch/test/ivalue_import/annotations/sugar.py new file mode 100644 index 000000000..ac9af702c --- /dev/null +++ b/frontends/pytorch/test/ivalue_import/annotations/sugar.py @@ -0,0 +1,50 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +# RUN: %PYTHON %s | FileCheck %s + +import torch + +import torch_mlir +from torch_mlir.torchscript.annotations import ( + annotate_args, export, extract_annotations +) + +class MmModule(torch.nn.Module): + def __init__(self): + super().__init__() + @export + @annotate_args([ + None, + ([3, 4], torch.float32), + ([4, 5], torch.float32), + ]) + def forward(self, lhs, rhs): + return torch.mm(lhs, rhs) + +module = MmModule() +annotator = torch_mlir.ClassAnnotator() +extract_annotations(module, torch.jit.script(module), annotator) +print(annotator) + +# CHECK: ClassAnnotator { +# CHECK: ClassAnnotation('__torch__.MmModule') { +# CHECK: MethodAnnotation('forward') { +# CHECK: isExported = true +# CHECK: argAnnotations = +# CHECK: ArgAnnotation(0) { +# CHECK: dtype = +# CHECK: shape = +# CHECK: } +# CHECK: ArgAnnotation(1) { +# CHECK: dtype = Float +# CHECK: shape = [3, 4] +# CHECK: } +# CHECK: ArgAnnotation(2) { +# CHECK: dtype = Float +# CHECK: shape = [4, 5] +# CHECK: } +# CHECK: } +# CHECK: } +# CHECK: } diff --git a/frontends/pytorch/test/torchscript_e2e_test/README.md b/frontends/pytorch/test/torchscript_e2e_test/README.md new file mode 100644 index 000000000..2055e3b25 --- /dev/null +++ b/frontends/pytorch/test/torchscript_e2e_test/README.md @@ -0,0 +1,2 @@ +This directory is for testing the e2e_test framework itself. +It is not for holding e2e tests themselves!!! diff --git a/frontends/pytorch/test/torchscript_e2e_test/basic.py b/frontends/pytorch/test/torchscript_e2e_test/basic.py new file mode 100644 index 000000000..502e805c3 --- /dev/null +++ b/frontends/pytorch/test/torchscript_e2e_test/basic.py @@ -0,0 +1,36 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +# RUN: %PYTHON %s | FileCheck %s + +import torch + +from torch_mlir.torchscript.e2e_test.framework import run_tests, report_results, TestUtils +from torch_mlir.torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY +from torch_mlir.torchscript.e2e_test.configs import TorchScriptTestConfig + + +class MmModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, lhs, rhs): + return torch.mm(lhs, rhs) + + +# TODO: Refine messages. +# CHECK: SUCCESS "MmModule_basic" +@register_test_case(module_factory=lambda: MmModule()) +def MmModule_basic(module, tu: TestUtils): + module.forward(tu.rand(4, 4), tu.rand(4, 4)) + + +def main(): + config = TorchScriptTestConfig() + results = run_tests(GLOBAL_TEST_REGISTRY, config) + report_results(results) + + +if __name__ == '__main__': + main() diff --git a/frontends/pytorch/test/torchscript_e2e_test/miscompile.py b/frontends/pytorch/test/torchscript_e2e_test/miscompile.py new file mode 100644 index 000000000..e3bbc8f22 --- /dev/null +++ b/frontends/pytorch/test/torchscript_e2e_test/miscompile.py @@ -0,0 +1,42 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See frontends/pytorch/LICENSE for license information. + +# RUN: %PYTHON %s | FileCheck %s + +import torch + +from torch_mlir.torchscript.e2e_test.framework import run_tests, report_results, TestUtils +from torch_mlir.torchscript.e2e_test.registry import register_test_case, GLOBAL_TEST_REGISTRY +from torch_mlir.torchscript.e2e_test.configs import TorchScriptTestConfig + + +class MmModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, lhs, rhs): + # Use torch.jit.is_scripting() to fake a miscompile. + # The non-scripted code will take one path, and the scripted code + # will take another path. + if torch.jit.is_scripting(): + return torch.mm(rhs, lhs) + return torch.mm(lhs, rhs) + + +# TODO: Refine error messages. +# CHECK: Error: in call #0 into the module: result #0 not close in call to "forward" +# CHECK: FAILURE "MmModule_basic" +@register_test_case(module_factory=lambda: MmModule()) +def MmModule_basic(module, tu: TestUtils): + module.forward(tu.rand(4, 4), tu.rand(4, 4)) + + +def main(): + config = TorchScriptTestConfig() + results = run_tests(GLOBAL_TEST_REGISTRY, config) + report_results(results) + + +if __name__ == '__main__': + main()