From 18ef40acaf39b1f32dd1e292ce3883a1dad76ed1 Mon Sep 17 00:00:00 2001
From: Maksim Levental
Date: Mon, 11 Apr 2022 16:17:44 -0500
Subject: [PATCH] Fixes a bug in the use of upstream `normalize_function` in
 our `normalize_args_kwargs` (in eager mode) and introduces unit tests. (#740)

NB: `shouldnt_normalize2` and `shouldnt_normalize3` currently XPASS, i.e.,
args *will* successfully normalize despite being incorrect, due to an
[upstream bug](https://github.com/pytorch/pytorch/issues/75342).
---
 .../test/eager_mode/normalize_args_kwargs.py | 84 +++++++++++++++++++
 .../eager_mode/torch_mlir_dispatch.py        |  6 +-
 2 files changed, 88 insertions(+), 2 deletions(-)
 create mode 100644 python/test/eager_mode/normalize_args_kwargs.py

diff --git a/python/test/eager_mode/normalize_args_kwargs.py b/python/test/eager_mode/normalize_args_kwargs.py
new file mode 100644
index 000000000..2949c8fa8
--- /dev/null
+++ b/python/test/eager_mode/normalize_args_kwargs.py
@@ -0,0 +1,84 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+# RUN: %PYTHON %s | FileCheck %s
+
+
+import torch
+
+from torch_mlir.eager_mode.torch_mlir_dispatch import normalize_args_kwargs
+
+
+def run_test(test, XFAIL=False, XPASS=False):
+    try:
+        test()
+        print(("X" if XPASS else "") + f"PASS - {test.__name__}")
+    except Exception as e:
+        print(("X" if XFAIL else "") + f"FAIL - {test.__name__}")
+        print(e)
+
+
+# CHECK: PASS - should_normalize
+def should_normalize():
+    target = torch.ops.aten.max_pool2d_with_indices.default.overloadpacket
+    args = (torch.randn((1, 3, 32, 32)),)
+    kwargs = {"kernel_size": [3, 3]}
+    golden = {
+        "kernel_size": [3, 3],
+        # This is due to the schema for max_pool2d_with_indices defining
+        # the stride arg as int[2] stride=[].
+        "stride": [],
+        "padding": [0, 0],
+        "dilation": [1, 1],
+        "ceil_mode": False,
+    }
+
+    new_args, new_kwargs = normalize_args_kwargs(target, args, kwargs)
+    for arg, new_arg in zip(args, new_args):
+        assert torch.allclose(arg, new_arg)
+    for k, v in new_kwargs.items():
+        assert v == golden[k]
+
+
+# CHECK: FAIL - shouldnt_normalize1
+# CHECK: Couldn't normalize args and kwargs
+def shouldnt_normalize1():
+    target = torch.ops.aten.max_pool2d_with_indices.default.overloadpacket
+    args = (torch.randn((1, 3, 32, 32)),)
+    kwargs = {"stride": []}
+    normalize_args_kwargs(target, args, kwargs)
+
+
+# These next two tests XPASS because of https://github.com/pytorch/pytorch/issues/75342.
+# I.e., they should fail, but in fact they pass because of the upstream bug.
+# The reason for the bug is a fast-path branch in operator_schemas.normalize_function
+# that doesn't do rigorous type checking and hence lets type mismatches slip through.
+# TODO(max): change these to FAIL when the upstream bug is fixed.
+
+# CHECK: XPASS - shouldnt_normalize2
+def shouldnt_normalize2():
+    target = torch.ops.aten.max_pool2d_with_indices.default.overloadpacket
+    args = (torch.randn((1, 3, 32, 32)),)
+    kwargs = {"kernel_size": []}
+    normalize_args_kwargs(target, args, kwargs)
+
+
+# CHECK: XPASS - shouldnt_normalize3
+def shouldnt_normalize3():
+    target = torch.ops.aten.max_pool2d_with_indices.default.overloadpacket
+    args = (torch.randn((1, 3, 32, 32)),)
+    kwargs = {"kernel_size": [3, 3], "padding": None}
+    normalize_args_kwargs(target, args, kwargs)
+
+
+def main():
+    run_test(should_normalize)
+    run_test(shouldnt_normalize1)
+    run_test(shouldnt_normalize2, XPASS=True)
+    run_test(shouldnt_normalize3, XPASS=True)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/python/torch_mlir/eager_mode/torch_mlir_dispatch.py b/python/torch_mlir/eager_mode/torch_mlir_dispatch.py
index a9e72b72a..7f57d6081 100644
--- a/python/torch_mlir/eager_mode/torch_mlir_dispatch.py
+++ b/python/torch_mlir/eager_mode/torch_mlir_dispatch.py
@@ -57,8 +57,10 @@ def normalize_args_kwargs(target: Callable, args: Tuple[Any], kwargs: Dict[str,
     arg_types = map_aggregate(args, type)
     assert isinstance(arg_types, tuple)
-    arg_types = tuple([create_type_hint(i) for i in arg_types])
-    kwarg_types = {k: type(v) for k, v in kwargs.items()}
+    arg_types = map_aggregate(map_aggregate(args, type), create_type_hint)
+    kwarg_types = {
+        k: create_type_hint(map_aggregate(v, type)) for k, v in kwargs.items()
+    }
     new_args_and_kwargs = normalize_function(
         target, args, kwargs, arg_types, kwarg_types, normalize_to_only_use_kwargs=False
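
For reviewers, a minimal sketch (not part of the patch, assuming the torch.fx
internals of the PyTorch release this was developed against) of why building
kwarg_types with create_type_hint(map_aggregate(v, type)) instead of plain
type(v) matters: type(v) collapses a kwarg like kernel_size=[3, 3] to the bare
type `list`, which the schema matcher cannot line up against the declared
int[2] (i.e., List[int]) parameter, while the composition recovers the
parameterized hint.

    from torch.fx.node import map_aggregate
    from torch.fx.operator_schemas import create_type_hint

    kwargs = {"kernel_size": [3, 3]}

    # Old construction: only the container type survives, with no element info.
    old_kwarg_types = {k: type(v) for k, v in kwargs.items()}
    print(old_kwarg_types)  # {'kernel_size': <class 'list'>}

    # New construction: map_aggregate recurses into the list to collect the
    # element types, and create_type_hint rebuilds a parameterized typing hint.
    new_kwarg_types = {
        k: create_type_hint(map_aggregate(v, type)) for k, v in kwargs.items()
    }
    print(new_kwarg_types)  # {'kernel_size': typing.List[int]}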
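
And a hedged sketch of the upstream fast path behind the two XPASS tests
(also not part of the patch): it calls torch.fx's normalize_function directly,
which our normalize_args_kwargs wraps, with the same target and args as
shouldnt_normalize2. Because the fast-path branch binds arguments by signature
without rigorous type checking, the invalid kernel_size=[] normalizes anyway
(pytorch/pytorch#75342).

    import torch
    from torch.fx.operator_schemas import normalize_function

    target = torch.ops.aten.max_pool2d_with_indices.default.overloadpacket
    args = (torch.randn((1, 3, 32, 32)),)
    # kernel_size=[] cannot satisfy the schema's int[2], but binding by
    # name/arity alone succeeds, so normalization "works" regardless.
    result = normalize_function(target, args, {"kernel_size": []})
    print(result)  # an ArgsKwargsPair rather than None, despite the bad kwarg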