Adding an E2E Test
Overview
Adding support for a Torch operator in Torch-MLIR should always be accompanied
by at least one end-to-end test to make sure the implementation of the op
matches the behavior of PyTorch. The tests live in the
torch-mlir/python/torch_mlir_e2e_test/test_suite/ directory. When adding a new
test, choose the file that best matches the op you're testing; if no existing
file is a good match, add a new file for your op.
An E2E Test Deconstructed
In order to understand how to create an end-to-end test for your op, let's break down an existing test to see what the different parts mean:
class IndexTensorModule3dInput(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
        ([-1, -1], torch.int64, True),
    ])
    def forward(self, x, index):
        return torch.ops.aten.index(x, (index,))


@register_test_case(module_factory=lambda: IndexTensorModule3dInput())
def IndexTensorModule3dInput_basic(module, tu: TestUtils):
    module.forward(tu.rand(5, 4, 3), tu.randint(2, 3, high=3))
Class Name
class IndexTensorModule3dInput(torch.nn.Module):
The class name should always contain the name of the op that is being tested. This makes it easy to search for tests for a particular op. Often, an op will require multiple tests to make sure different paths in the compilation work as expected. In such cases, it is customary to add extra information to the class name about what is being tested. In this example, the op is being tested with a rank-3 tensor as an input.
__init__ Method
def __init__(self):
    super().__init__()
In most tests, the __init__ method simply calls the __init__ method of the
torch.nn.Module class. However, sometimes this method can be used to
initialize parameters needed in the forward method. An example of such a case
is in the E2E test for Resnet18.
@export and @annotate_args Decorators
@export
@annotate_args([
    None,
    ([-1, -1, -1], torch.float32, True),
    ([-1, -1], torch.int64, True),
])
The @export decorator lets the importer know which methods in the class will
be public after the torch.nn.Module gets imported into the torch dialect. All
E2E tests should have this decorator on the forward method.
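As a rough mental model of what a marker decorator of this kind might do, here is a minimal sketch in plain Python. The names export_sketch, _is_exported, and exported_methods are hypothetical and are not the Torch-MLIR implementation; the sketch only shows how a decorator can tag a method so that a later pass can find it.

```python
def export_sketch(fn):
    """Hypothetical stand-in for @export: tag the function, return it unchanged."""
    fn._is_exported = True  # hypothetical attribute name, chosen for this sketch
    return fn


class ToyModule:
    @export_sketch
    def forward(self, x):
        return x

    def helper(self, x):  # not tagged, so this sketch treats it as private
        return x


def exported_methods(cls):
    """Return the sorted names of methods on cls that carry the tag."""
    return sorted(
        name for name, value in vars(cls).items()
        if callable(value) and getattr(value, "_is_exported", False)
    )
```

In this sketch, exported_methods(ToyModule) reports only forward, because helper was never decorated.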
The @annotate_args decorator is used to give the importer information about
the arguments of the method being decorated, which can then be propagated
further into the IR of the body of the method. The list of annotations must
have one annotation for each argument, including the self argument. The self
argument always gets the annotation None, while the other inputs get an
annotation with three fields in the following order:
- Shape of the input tensor. Use -1 for dynamic dimensions
- Dtype of the input tensor
- Boolean representing whether the input tensor has value semantics. This will always be true for E2E tests, since the Torch-MLIR backend contract requires all tensors in the IR to eventually have value semantics.
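The rules above can be restated as a small checker. This is a sketch under stated assumptions: check_annotations is a hypothetical helper, not part of Torch-MLIR, and dtypes are written as strings only to keep the example free of a torch import.

```python
import inspect


def check_annotations(fn, annotations):
    """Check an annotation list against fn's signature (hypothetical helper).

    Rules taken from the text: one entry per argument including self,
    None for self, and (shape, dtype, has_value_semantics) for the rest.
    """
    params = list(inspect.signature(fn).parameters)
    if len(annotations) != len(params):
        raise ValueError(
            f"expected {len(params)} annotations, got {len(annotations)}")
    if annotations[0] is not None:
        raise ValueError("the annotation for self must be None")
    for name, annotation in zip(params[1:], annotations[1:]):
        shape, dtype, has_value_semantics = annotation
        if not all(isinstance(d, int) and d >= -1 for d in shape):
            raise ValueError(f"{name}: each dimension must be -1 or a size")
        if has_value_semantics is not True:
            raise ValueError(f"{name}: E2E tests always use value semantics")
    return True


def forward(self, x, index):  # same argument list as the example test
    pass


ok = check_annotations(forward, [
    None,
    ([-1, -1, -1], "float32", True),
    ([-1, -1], "int64", True),
])
```

A list that omits the entry for self, or that has fewer entries than arguments, raises ValueError in this sketch.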
From the structure of the annotations for the arguments other than the self
argument, it is clear that only tensor arguments are supported. This means that
if an op requires an input other than a tensor, you need to do one of the
following:
- Create the value in the method body
- Create the value as a class parameter in the __init__ method
- In the case of certain values such as ints and floats, you can pass a zero-rank tensor as an input and use int(input) or float(input) in the method body to turn the tensor into a scalar int or float, respectively.
forward Method
def forward(self, x, index):
    return torch.ops.aten.index(x, (index,))
The forward method should be a simple test of your op. In other words, it will
almost always take the form of simply returning the result of calling your
op. The call to your op should always be made using torch.ops.aten.{op_name}
to make it very clear which ATen op is being tested. Some ATen ops have
different variants under the same base name, such as aten.mean, which also has
a variant aten.mean.dim. At the Python level, such ops are accessed by just
their base name, and the right variant is chosen based on the inputs given. For
example, to test aten.mean.dim the test should use
torch.ops.aten.mean(..., dim=...).
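The idea of one base name covering several variants can be illustrated with a toy dispatcher in plain Python. This is only an analogy: mean_default, mean_dim, and mean below are hypothetical functions over nested lists, not PyTorch code, and PyTorch's real overload resolution is more general.

```python
def mean_default(rows):
    """Toy analogue of aten.mean: average over every element."""
    flat = [value for row in rows for value in row]
    return sum(flat) / len(flat)


def mean_dim(rows, dim):
    """Toy analogue of aten.mean.dim for a 2-D nested list."""
    if dim == 0:  # average down each column
        return [sum(col) / len(col) for col in zip(*rows)]
    if dim == 1:  # average across each row
        return [sum(row) / len(row) for row in rows]
    raise ValueError("dim must be 0 or 1 for a 2-D input")


def mean(rows, dim=None):
    """Single entry point: the variant is picked from the arguments given."""
    if dim is None:
        return mean_default(rows)
    return mean_dim(rows, dim)


data = [[1.0, 2.0], [3.0, 4.0]]
overall = mean(data)            # takes the mean_default path
per_column = mean(data, dim=0)  # takes the mean_dim path
```

The practical consequence for a test is the same as in the text: to exercise a specific variant, pass the arguments that select it.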
@register_test_case Decorator
@register_test_case(module_factory=lambda: IndexTensorModule3dInput())
The @register_test_case decorator is used to register the test case function.
The module_factory argument should be a function that, when called, produces an
instance of the test class. This function will be used to create the first
argument passed to the test case function.
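To make the role of module_factory concrete, here is a minimal registry sketch in plain Python. TEST_REGISTRY, register_test_case_sketch, and run_all are hypothetical names introduced for illustration; the real registry in Torch-MLIR may be organized differently.

```python
TEST_REGISTRY = []  # hypothetical global list of registered tests


def register_test_case_sketch(module_factory):
    """Hypothetical stand-in for @register_test_case."""
    def decorator(test_fn):
        # Store the factory next to the test so a runner can build the module later.
        TEST_REGISTRY.append((test_fn.__name__, module_factory, test_fn))
        return test_fn
    return decorator


class ToyModule:
    def forward(self, x):
        return x * 2


@register_test_case_sketch(module_factory=lambda: ToyModule())
def ToyModule_basic(module, tu):
    return module.forward(21)


def run_all():
    """Build a fresh module per test and pass it as the first argument."""
    results = {}
    for name, factory, test_fn in TEST_REGISTRY:
        results[name] = test_fn(factory(), None)  # no TestUtils in this sketch
    return results
```

Passing a factory rather than an instance lets a runner construct a fresh module each time it invokes the test function.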
Test Case Function
def IndexTensorModule3dInput_basic(module, tu: TestUtils):
    module.forward(tu.rand(5, 4, 3), tu.randint(2, 3, high=3))
The convention for naming the test case function is to use the name of the
test class with the suffix _basic. The test function always takes an instance
of the test class as the first argument and a TestUtils object as the second
argument. The TestUtils object has methods, such as tu.rand and tu.randint,
that allow the creation of random tensors in a way that makes sure the compiled
module and the golden trace receive the same tensors as input. Therefore, all
random inputs should be generated through the TestUtils object.
Things to Consider When Creating New Tests
- Do you need negative numbers? If so, tu.rand and tu.randint both allow you to specify a lower and upper bound for random number generation
- Make sure the annotation of the forward method matches the input types and shapes
- If an op takes optional flag arguments, there should be a test for each flag that is supported
- If there are tricky edge cases that your op needs to handle, have a test for each edge case
- Always follow the style and conventions of the file you're adding a test to. An attempt has been made to keep a consistent style across all E2E test files, but file-specific variations do exist