From f8e75f659df159c39f49a25825f0bd51bbd52b1c Mon Sep 17 00:00:00 2001 From: Matthias Gehre <93204396+mgehre-amd@users.noreply.github.com> Date: Thu, 13 Jul 2023 15:07:54 +0200 Subject: [PATCH] Add make_fx_tosa variant to end2end tests (#2240) * Add make_fx_tosa variant to end2end tests * e2e_testing/xfail_sets.py: Add make_fx_tosa xfail for stable --- .../python_deploy/build_linux_packages.sh | 5 ++- e2e_testing/main.py | 7 +++- e2e_testing/xfail_sets.py | 36 +++++++++++++++++++ python/torch_mlir/__init__.py | 19 ++++++++-- .../configs/tosa_backend.py | 5 +-- 5 files changed, 65 insertions(+), 7 deletions(-) diff --git a/build_tools/python_deploy/build_linux_packages.sh b/build_tools/python_deploy/build_linux_packages.sh index bd9c0cee6..f00423ffc 100755 --- a/build_tools/python_deploy/build_linux_packages.sh +++ b/build_tools/python_deploy/build_linux_packages.sh @@ -300,7 +300,10 @@ function test_in_tree() { exit 1 ;; esac - + + echo ":::: Run make_fx + TOSA e2e integration tests" + python -m e2e_testing.main --config=make_fx_tosa -v + echo ":::: Run TorchDynamo e2e integration tests" python -m e2e_testing.main --config=torchdynamo -v diff --git a/e2e_testing/main.py b/e2e_testing/main.py index 13e7ba7c8..3893edee4 100644 --- a/e2e_testing/main.py +++ b/e2e_testing/main.py @@ -29,6 +29,7 @@ from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsT from .xfail_sets import ( LINALG_XFAIL_SET, + MAKE_FX_TOSA_PASS_SET, STABLEHLO_PASS_SET, TOSA_PASS_SET, LTC_XFAIL_SET, @@ -42,7 +43,7 @@ from torch_mlir_e2e_test.test_suite import register_all_tests register_all_tests() def _get_argparse(): - config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "tosa", "lazy_tensor_core", "torchdynamo"] + config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "make_fx_tosa", "tosa", "lazy_tensor_core", "torchdynamo"] parser = argparse.ArgumentParser(description="Run torchscript e2e tests.") parser.add_argument("-c", "--config", choices=config_choices, @@ -94,6 +95,10 @@ def main(): config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend()) xfail_set = all_test_unique_names - TOSA_PASS_SET crashing_set = set() + elif args.config == "make_fx_tosa": + config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend(), use_make_fx=True) + xfail_set = all_test_unique_names - MAKE_FX_TOSA_PASS_SET + crashing_set = set() elif args.config == "stablehlo": config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend()) xfail_set = all_test_unique_names - STABLEHLO_PASS_SET diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 3f21ad6ff..6a9a981c7 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -11,6 +11,7 @@ # might be used to keep more elaborate sets of testing configurations). from torch_mlir_e2e_test.test_suite import COMMON_TORCH_MLIR_LOWERING_XFAILS +from torch_mlir._version import torch_version_for_comparison, version LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS @@ -1113,6 +1114,41 @@ TOSA_PASS_SET = { "ChunkListUnpackUneven_Module_basic", } +MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | { +### Tests additionally passing in make_fx_tosa + "NativeGroupNormBackwardModule_basic", + "TensorFloatModule_basic", + "TensorIntModule_basic", +}) - { +### Test failing in make_fx_tosa but not in tosa + + # failed to lower torch.aten.empty.memory_format + "BatchNorm1DModule_basic", + "BatchNorm1DWith2DInputModule_basic", + "BatchNorm2DModule_basic", + "BatchNorm3DModule_basic", + "BatchNorm1DStaticShapeModule_basic", + + # Dynamic shape, has extra unsupported broadcast ops + "Matmul_3d", + + # failed to legalize operation 'torch.aten.max_pool2d_with_indices + "MaxPool2dEmptyStrideStaticModule_basic", + "MaxPool2dStaticCeilModeTrueModule_basic", + "MaxPool2dStaticModule_basic", + "ResNet18StaticModule_basic", + + # Unimplemented operator 'aten._index_put_impl_.hacked_twin' + "IndexPutImpl1DFloatNonAccumulateModule_basic", + "IndexPutImpl1DIntNonAccumulateModule_basic", +} + +if torch_version_for_comparison() < version.parse("2.1.0.dev"): + MAKE_FX_TOSA_PASS_SET -= { + # 'tensor.expand_shape' op expected rank expansion, but found source rank 1 >= result rank 1 + "ReshapeCollapseModule_basic", + } + LTC_CRASHING_SET = { # https://github.com/llvm/torch-mlir/issues/2186 "Add_Module_basic" diff --git a/python/torch_mlir/__init__.py b/python/torch_mlir/__init__.py index 836d3fdfc..8de6cc1a1 100644 --- a/python/torch_mlir/__init__.py +++ b/python/torch_mlir/__init__.py @@ -13,6 +13,8 @@ import tempfile from torch._functorch.compile_utils import strip_overloads import torch import torch.fx +from torch_mlir.dynamo import _get_decomposition_table +from torch.fx.experimental.proxy_tensor import make_fx from .compiler_utils import run_pipeline_with_repro_report from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder @@ -225,8 +227,11 @@ class ExampleArgs: # they know what they are doing and that their trace is # correct for any specific concrete size. shape = [s if s != -1 else 7 for s in arg.shape] - example_args_for_trace.append( - torch.ones(*shape, dtype=arg.dtype)) + if len(shape) == 0: + example_args_for_trace.append(torch.tensor(1)) + else: + example_args_for_trace.append( + torch.ones(*shape, dtype=arg.dtype)) else: assert isinstance(arg, torch.Tensor) example_args_for_trace.append(arg) @@ -313,7 +318,8 @@ def compile(model: torch.nn.Module, ignore_traced_shapes=False, backend_legal_ops: Optional[Sequence[str]] = None, extra_library: Iterable[Callable] = [], - verbose: bool = False): + verbose: bool = False, + use_make_fx: bool = False): """Convert a PyTorch model to MLIR. Args: @@ -367,6 +373,13 @@ def compile(model: torch.nn.Module, else: backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, []) + if use_make_fx: + args = example_args._get_for_tracing(use_tracing=True, ignore_traced_shapes=True)["forward"] + model = make_fx( + model, + decomposition_table=_get_decomposition_table())(*args) + + # For FX-based models, automatically strip overloads. if isinstance(model, torch.fx.GraphModule): strip_overloads(model) diff --git a/python/torch_mlir_e2e_test/configs/tosa_backend.py b/python/torch_mlir_e2e_test/configs/tosa_backend.py index 8b41cfeda..89b90567b 100644 --- a/python/torch_mlir_e2e_test/configs/tosa_backend.py +++ b/python/torch_mlir_e2e_test/configs/tosa_backend.py @@ -23,14 +23,15 @@ class TosaBackendTestConfig(TestConfig): This class handles all the common lowering that torch-mlir does before reaching the linalg-on-tensors abstraction level. """ - def __init__(self, backend: TosaBackend): + def __init__(self, backend: TosaBackend, use_make_fx: bool = False): super().__init__() self.backend = backend + self.use_make_fx = use_make_fx def compile(self, program: torch.nn.Module) -> Any: example_args = convert_annotations_to_placeholders(program.forward) module = torch_mlir.compile( - program, example_args, output_type="tosa") + program, example_args, output_type="tosa", use_make_fx=self.use_make_fx) return self.backend.compile(module)