# Adding an E2E Test

## Overview

Adding support for a Torch operator in Torch-MLIR should always be accompanied
by at least one end-to-end test to make sure the implementation of the op
matches the behavior of PyTorch. The tests live in the
`torch-mlir/python/torch_mlir_e2e_test/test_suite/` directory. When adding a new
test, choose the file that best matches the op you're testing; if no existing
file is a good match, add a new file for your op.

## An E2E Test Deconstructed

In order to understand how to create an end-to-end test for your op, let's break
down an existing test to see what the different parts mean:

```python
import torch

# Module paths as used by the test_suite files at the time of writing;
# they may differ in other versions of Torch-MLIR.
from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export


class IndexTensorModule3dInput(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
        ([-1, -1], torch.int64, True),
    ])
    def forward(self, x, index):
        return torch.ops.aten.index(x, (index,))


@register_test_case(module_factory=lambda: IndexTensorModule3dInput())
def IndexTensorModule3dInput_basic(module, tu: TestUtils):
    module.forward(tu.rand(5, 4, 3), tu.randint(2, 3, high=3))
```

### Class Name

```python
class IndexTensorModule3dInput(torch.nn.Module):
```

The class name should always contain the name of the op that is being
tested. This makes it easy to search for tests for a particular op. Often, an
op will require multiple tests to make sure different paths in the compilation
work as expected. In such cases, it is customary to add extra information to
the class name about what is being tested. In this example, the op is being
tested with a rank-3 tensor as an input.
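Because each class name embeds the op name, finding the tests for an op can be as simple as a substring search. The sketch below is a hypothetical helper, not part of Torch-MLIR. `op_to_camel_case` and `find_tests_for_op` are made-up names. The sketch assumes class names spell the op in CamelCase, as `IndexTensorModule3dInput` does for `index`.

```python
from typing import Iterable, List


def op_to_camel_case(op_name: str) -> str:
    """Turn a snake_case ATen op name (e.g. "index_select") into CamelCase."""
    return "".join(part.capitalize() for part in op_name.split("_"))


def find_tests_for_op(op_name: str, class_names: Iterable[str]) -> List[str]:
    """Return the class names that contain the CamelCase form of `op_name`.

    This is a plain, case-sensitive substring match, so searching for
    "index" also returns "IndexSelect..." classes.
    """
    needle = op_to_camel_case(op_name)
    return [name for name in class_names if needle in name]


# Illustrative class names that follow the naming convention.
classes = [
    "IndexTensorModule3dInput",
    "IndexSelectSingleIdxModule",
    "MeanDimModule",
]
print(find_tests_for_op("index", classes))
# ['IndexTensorModule3dInput', 'IndexSelectSingleIdxModule']
print(find_tests_for_op("index_select", classes))
# ['IndexSelectSingleIdxModule']
```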

### `__init__` Method

```python
def __init__(self):
    super().__init__()
```

In most tests, the `__init__` method simply calls the `__init__` method of the
`torch.nn.Module` class. However, sometimes this method can be used to
initialize parameters needed in the `forward` method. An example of such a case
is in the [E2E test for Resnet18](https://github.com/llvm/torch-mlir/blob/ba17a4d6c09b4bbb4ef21b1d8d4a93cb056be109/python/torch_mlir_e2e_test/test_suite/vision_models.py#L17-L22).

### `@export` and `@annotate_args` Decorators

```python
@export
@annotate_args([
    None,
    ([-1, -1, -1], torch.float32, True),
    ([-1, -1], torch.int64, True),
])
```

The [`@export` decorator](https://github.com/llvm/torch-mlir/blob/ba17a4d6c09b4bbb4ef21b1d8d4a93cb056be109/python/torch_mlir_e2e_test/torchscript/annotations.py#L30)
lets the importer know which methods in the class will be public after the
`torch.nn.Module` gets imported into the `torch` dialect. All E2E tests should
have this decorator on the `forward` method.

The [`@annotate_args` decorator](https://github.com/llvm/torch-mlir/blob/ba17a4d6c09b4bbb4ef21b1d8d4a93cb056be109/python/torch_mlir_e2e_test/torchscript/annotations.py#L53)
gives the importer information about the arguments of the decorated method,
which can then be propagated further into the IR of the method body. The list
of annotations **must** have one annotation for each argument, including the
`self` argument. The `self` argument always gets the annotation `None`, while
each of the other inputs gets an annotation with three fields, in the following
order:

1. Shape of the input tensor. Use `-1` for dynamic dimensions.
2. Dtype of the input tensor.
3. Boolean representing whether the input tensor
   [has value semantics](https://github.com/llvm/torch-mlir/blob/ba17a4d6c09b4bbb4ef21b1d8d4a93cb056be109/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/class_annotator.h#L54-L67).
   This will always be true for E2E tests, since the
   [Torch-MLIR backend contract](architecture.md#the-backend-contract) requires
   all tensors in the IR to eventually have value semantics.

From the structure of the annotations for the arguments other than `self`, it
is clear that only tensor arguments are supported. This means that if an op
requires an input other than a tensor, you need to do one of the following:

- Create the value in the method body.
- Create the value in the `__init__` method and store it as an attribute of the
  module.
- In the case of certain values such as `int`s and `float`s, pass a zero-rank
  tensor as an input (annotated with the shape `[]`) and use `int(input)` or
  `float(input)` in the method body to turn the tensor into a scalar `int` or
  `float`, respectively.
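The annotation rules above can be summarized as a small checker. The following is a dependency-free sketch, **not** Torch-MLIR's implementation. `export_sketch`, `annotate_args_sketch`, `check_annotations`, and the attribute names they use are hypothetical. Dtypes are written as strings so that the code runs without PyTorch. In this sketch the decorators only attach metadata to the function for a later consumer to read. `check_annotations` then enforces three rules:

- one annotation per argument;
- `None` for `self`;
- a `(shape, dtype, has_value_semantics)` triple for every other argument, with value semantics set to `True`.

```python
import inspect
from typing import Any, Callable, List, Optional, Tuple

# Hypothetical attribute names used only by this sketch.
EXPORT_ATTR = "_sketch_export"
ANNOTATIONS_ATTR = "_sketch_arg_annotations"

# (shape, dtype, has_value_semantics); dtypes are plain strings here.
ArgAnnotation = Tuple[List[int], Any, bool]


def export_sketch(fn: Callable) -> Callable:
    """Mark `fn` as a method that should be public after import."""
    setattr(fn, EXPORT_ATTR, True)
    return fn


def annotate_args_sketch(annotations: List[Optional[ArgAnnotation]]):
    """Attach per-argument annotations (including `self`) to `fn`."""
    def decorator(fn: Callable) -> Callable:
        setattr(fn, ANNOTATIONS_ATTR, annotations)
        return fn
    return decorator


def check_annotations(fn: Callable) -> None:
    """Raise ValueError if `fn` breaks the E2E annotation rules."""
    if not getattr(fn, EXPORT_ATTR, False):
        raise ValueError(f"{fn.__name__} is not exported")
    annotations = getattr(fn, ANNOTATIONS_ATTR, None)
    if annotations is None:
        raise ValueError(f"{fn.__name__} has no argument annotations")
    params = list(inspect.signature(fn).parameters)
    if len(annotations) != len(params):
        raise ValueError(
            f"expected {len(params)} annotations (one per argument, "
            f"including self), got {len(annotations)}")
    if annotations[0] is not None:
        raise ValueError("the annotation for `self` must be None")
    for name, annotation in zip(params[1:], annotations[1:]):
        if annotation is None or len(annotation) != 3:
            raise ValueError(f"{name!r} needs a (shape, dtype, bool) triple")
        shape, _dtype, has_value_semantics = annotation
        if any(dim < -1 for dim in shape):
            raise ValueError(f"{name!r}: use -1 for dynamic dimensions")
        if has_value_semantics is not True:
            raise ValueError(f"{name!r}: E2E tests require value semantics")


class SketchModule:
    @export_sketch
    @annotate_args_sketch([
        None,
        ([-1, -1, -1], "float32", True),
        ([-1, -1], "int64", True),
    ])
    def forward(self, x, index):
        return x  # placeholder body; the check only looks at the signature


class MissingSelfAnnotation:
    @export_sketch
    @annotate_args_sketch([([-1], "float32", True)])  # no `None` for self
    def forward(self, x):
        return x


check_annotations(SketchModule.forward)  # passes silently
try:
    check_annotations(MissingSelfAnnotation.forward)
except ValueError as err:
    print(err)  # expected 2 annotations (one per argument, including self), got 1
```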

### `forward` Method

```python
def forward(self, x, index):
    return torch.ops.aten.index(x, (index,))
```

The `forward` method should be a simple test of your op: it will almost always
just return the result of calling your op. The call to your op should
**always** be made using `torch.ops.aten.{op_name}` to make it very clear which
ATen op is being tested. Some ATen ops have several variants under the same
base name, such as `aten.mean`, which also has a variant `aten.mean.dim`. At the
Python level, such ops are accessed by just their base name, and the right
variant is chosen based on the inputs given. For example, to test
`aten.mean.dim` the test should use `torch.ops.aten.mean(..., dim=...)`.

### `@register_test_case` Decorator

```python
@register_test_case(module_factory=lambda: IndexTensorModule3dInput())
```

The `@register_test_case` decorator is used to register the test case
function. The `module_factory` argument should be a function that, when called,
produces an instance of the test class. This function will be used to create
the first argument passed to the test case function.
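The registration step can be pictured with the toy registry below. It is a simplified, hypothetical sketch rather than Torch-MLIR's code. `register_test_case_sketch`, `TEST_REGISTRY`, `RegisteredTest`, `DoubleModule`, and `run_all` are made-up names, and a Python list stands in for a tensor. The decorator stores each test function together with its `module_factory`. The runner then calls the factory to build the module that becomes the test function's first argument.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict


@dataclass
class RegisteredTest:
    name: str
    module_factory: Callable[[], Any]
    test_fn: Callable[[Any, Any], Any]


# Hypothetical global registry, keyed by test function name.
TEST_REGISTRY: Dict[str, RegisteredTest] = {}


def register_test_case_sketch(module_factory: Callable[[], Any]):
    """Record the decorated test function together with its module factory."""
    def decorator(fn):
        if fn.__name__ in TEST_REGISTRY:
            raise ValueError(f"duplicate test name: {fn.__name__}")
        TEST_REGISTRY[fn.__name__] = RegisteredTest(fn.__name__, module_factory, fn)
        return fn
    return decorator


class DoubleModule:
    """Toy stand-in for a torch.nn.Module under test."""

    def forward(self, x):
        return [2 * v for v in x]


@register_test_case_sketch(module_factory=lambda: DoubleModule())
def DoubleModule_basic(module, tu):
    # Simplification: return the result so the runner has something to show.
    return module.forward([1, 2, 3])


def run_all(tu: Any = None) -> Dict[str, Any]:
    """Build a fresh module per test and pass it as the first argument."""
    results = {}
    for name, test in TEST_REGISTRY.items():
        module = test.module_factory()
        results[name] = test.test_fn(module, tu)
    return results


print(run_all())  # {'DoubleModule_basic': [2, 4, 6]}
```

One benefit of registering a factory rather than a module instance is that a runner can build a fresh module every time it needs one.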

### Test Case Function

```python
def IndexTensorModule3dInput_basic(module, tu: TestUtils):
    module.forward(tu.rand(5, 4, 3), tu.randint(2, 3, high=3))
```

By convention, the test case function has the same name as the test class,
suffixed with `_basic`. The test function always takes an instance of the test
class as the first argument and a
[`TestUtils`](https://github.com/llvm/torch-mlir/blob/8e880a2d009b67d45fb07434ab62ec2066a11185/python/torch_mlir_e2e_test/torchscript/framework.py#L167)
object as the second argument. `TestUtils` provides methods, such as
[`tu.rand`](https://github.com/llvm/torch-mlir/blob/8e880a2d009b67d45fb07434ab62ec2066a11185/python/torch_mlir_e2e_test/torchscript/framework.py#L182)
and
[`tu.randint`](https://github.com/llvm/torch-mlir/blob/8e880a2d009b67d45fb07434ab62ec2066a11185/python/torch_mlir_e2e_test/torchscript/framework.py#L185),
that create random tensors in a way that ensures the compiled module and the
golden trace receive the same tensors as input. Therefore, all random inputs
should be generated through the `TestUtils` object. In the example above,
`tu.rand(5, 4, 3)` creates a float tensor of shape `(5, 4, 3)`, and
`tu.randint(2, 3, high=3)` creates an integer tensor of shape `(2, 3)` with
values in the range `[0, 3)`, which are valid indices into the first dimension
of `x`.
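The idea of running the compiled module on the same inputs as the golden trace can be illustrated with a toy trace-and-replay harness. It is a hypothetical sketch, not Torch-MLIR's framework. `SeededUtils`, `Tracer`, `generate_golden_trace`, and `check_against_trace` are made-up names, and Python lists stand in for tensors. A fixed seed is assumed as the way to make inputs reproducible. The tracer records each `forward` call on the reference module together with its inputs and output. The "compiled" module is then run on exactly those recorded inputs, so both sides see identical data.

```python
import random
from typing import Any, Callable, List, Tuple


class SeededUtils:
    """Hypothetical stand-in for TestUtils that returns lists, not tensors."""

    def __init__(self, seed: int = 0):
        self._rng = random.Random(seed)  # fixed seed => reproducible inputs

    def rand(self, n: int, low: float = 0.0, high: float = 1.0) -> List[float]:
        return [self._rng.uniform(low, high) for _ in range(n)]


class Tracer:
    """Wraps a module and records (inputs, output) for each forward call."""

    def __init__(self, module: Any):
        self._module = module
        self.trace: List[Tuple[tuple, Any]] = []

    def forward(self, *inputs):
        output = self._module.forward(*inputs)
        self.trace.append((inputs, output))
        return output


def generate_golden_trace(module_factory: Callable[[], Any], test_fn) -> list:
    """Run the test against the reference module and record every call."""
    tracer = Tracer(module_factory())
    test_fn(tracer, SeededUtils())
    return tracer.trace


def check_against_trace(compiled_module: Any, trace: list) -> bool:
    """Replay the recorded inputs (no new random data) and compare outputs."""
    # Exact equality is enough for this toy; a real harness would use a tolerance.
    return all(compiled_module.forward(*inputs) == expected
               for inputs, expected in trace)


class ScaleModule:
    """Reference ("eager") module."""

    def forward(self, x):
        return [2.0 * v for v in x]


class CompiledScaleModule:
    """Pretend output of a compiler backend."""

    def forward(self, x):
        return [v + v for v in x]


def ScaleModule_basic(module, tu):
    # Negative inputs via explicit bounds, as with `tu.rand`'s lower/upper bounds.
    module.forward(tu.rand(4, low=-1.0, high=1.0))


trace = generate_golden_trace(lambda: ScaleModule(), ScaleModule_basic)
print(len(trace), check_against_trace(CompiledScaleModule(), trace))  # 1 True
```

In this sketch, replaying the recorded inputs is what guarantees both modules see identical data, while the fixed seed makes reruns reproducible.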

## Things to Consider When Creating New Tests

- Do you need negative numbers? If so,
  [`tu.rand`](https://github.com/llvm/torch-mlir/blob/8e880a2d009b67d45fb07434ab62ec2066a11185/python/torch_mlir_e2e_test/torchscript/framework.py#L182)
  and
  [`tu.randint`](https://github.com/llvm/torch-mlir/blob/8e880a2d009b67d45fb07434ab62ec2066a11185/python/torch_mlir_e2e_test/torchscript/framework.py#L185)
  both allow you to specify a lower and upper bound for random number
  generation.
- Make sure the annotation of the `forward` method matches the input types and
  shapes.
- If an op takes optional flag arguments, there should be a test for each flag
  that is supported.
- If there are tricky edge cases that your op needs to handle, have a test for
  each edge case.
- Always follow the style and conventions of the file you're adding a test
  in. An attempt has been made to keep all E2E test files consistent in style,
  but file-specific variations do exist.