torch-mlir/frontends/pytorch/e2e_testing/torchscript/main.py


# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import argparse
import re
import sys
from torch_mlir.torchscript.e2e_test.framework import run_tests
from torch_mlir.torchscript.e2e_test.reporting import report_results
from torch_mlir.torchscript.e2e_test.registry import GLOBAL_TEST_REGISTRY
# Available test configs.
from torch_mlir.torchscript.e2e_test.configs import (
    RefBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig
)
# Import tests to register them in the global registry.
# TODO: Use a relative import.
# That requires invoking this file as a "package" though, which makes it
# not possible to just do `python main.py`. Instead, it requires something
# like `python -m torchscript.main` which is annoying because it can only
# be run from a specific directory.
# TODO: Find out best practices for python "main" files.
import basic
import vision_models
import mlp
import batchnorm
import quantized_models
import elementwise


def _get_argparse():
    parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
    parser.add_argument('--config',
                        choices=['native_torch', 'torchscript', 'refbackend'],
                        default='refbackend',
                        help='''
Meaning of options:
"refbackend": run through npcomp's RefBackend.
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
''')
    parser.add_argument('--filter', default='.*', help='''
Regular expression specifying which tests to include in this run.
''')
    parser.add_argument('-v', '--verbose',
                        default=False,
                        action='store_true',
                        help='report test results with additional detail')
    return parser


def main():
    args = _get_argparse().parse_args()
    # Find the selected config.
    if args.config == 'refbackend':
        config = RefBackendTestConfig()
    elif args.config == 'native_torch':
        config = NativeTorchTestConfig()
    elif args.config == 'torchscript':
        config = TorchScriptTestConfig()
    # Find the selected tests, and emit a diagnostic if none are found.
    tests = [
        test for test in GLOBAL_TEST_REGISTRY
        if re.match(args.filter, test.unique_name)
    ]
    if len(tests) == 0:
        print(
            f'ERROR: the provided filter {args.filter!r} does not match any tests'
        )
        print('The available tests are:')
        for test in GLOBAL_TEST_REGISTRY:
            print(test.unique_name)
        sys.exit(1)
    # Run the tests.
    results = run_tests(tests, config)
    # Report the test results.
    report_results(results, args.verbose)


if __name__ == '__main__':
    main()
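
The `--filter` option is applied with `re.match`, which anchors the pattern at the start of `unique_name` but not at the end. The sketch below shows what that means for test selection. It is a standalone illustration, not part of `main.py`: the `Test` tuple, the `REGISTRY` list, the `select_tests` helper, and all test names are hypothetical stand-ins for `GLOBAL_TEST_REGISTRY` and its entries.

```python
import re
from collections import namedtuple

# Hypothetical stand-in for a registry entry; only the `unique_name`
# attribute that main() reads is modeled here.
Test = namedtuple('Test', ['unique_name'])

# Hypothetical test names, for illustration only.
REGISTRY = [
    Test('MlpModule_basic'),
    Test('ElementwiseUnaryModule_basic'),
    Test('ElementwiseBinaryModule_basic'),
]


def select_tests(pattern, registry):
    """Return the tests whose unique_name matches `pattern` from its start."""
    return [t for t in registry if re.match(pattern, t.unique_name)]


# A bare prefix selects every test that begins with it.
print([t.unique_name for t in select_tests('Elementwise', REGISTRY)])

# A substring that is not a prefix selects nothing, because re.match
# only succeeds at position 0 of the name.
print([t.unique_name for t in select_tests('basic', REGISTRY)])

# Prepending `.*` turns the filter into a "contains" search.
print([t.unique_name for t in select_tests('.*Binary', REGISTRY)])
```

Under these assumptions, a filter such as `basic` would take the "does not match any tests" path in `main()`, while `.*basic` would select all three hypothetical tests.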