mirror of https://github.com/llvm/torch-mlir
Fixes a bug in use of upstream `normalize_function` in our `normalize_args_kwargs` (in eager mode) and introduces unit tests. (#740)
NB: `shouldnt_normalize2` and `shouldnt_normalize3` currently XPASS i.e., args *will* successfully normalize despite being incorrect due to an [upstream bug](https://github.com/pytorch/pytorch/issues/75342).pull/745/head
parent
9ec0683e92
commit
18ef40acaf
|
@ -0,0 +1,84 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from torch_mlir.eager_mode.torch_mlir_dispatch import normalize_args_kwargs
|
||||
|
||||
|
||||
def run_test(test, XFAIL=False, XPASS=False):
|
||||
try:
|
||||
test()
|
||||
print(("X" if XPASS else "") + f"PASS - {test.__name__}")
|
||||
except Exception as e:
|
||||
print(("X" if XFAIL else "") + f"FAIL - {test.__name__}")
|
||||
print(e)
|
||||
|
||||
|
||||
# CHECK: PASS - should_normalize
|
||||
def should_normalize():
|
||||
target = torch.ops.aten.max_pool2d_with_indices.default.overloadpacket
|
||||
args = (torch.randn((1, 3, 32, 32)),)
|
||||
kwargs = {"kernel_size": [3, 3]}
|
||||
golden = {
|
||||
"kernel_size": [3, 3],
|
||||
# This is due to the schema for max_pool2d_with_indices defining
|
||||
# the stride arg as int[2] stride=[].
|
||||
"stride": [],
|
||||
"padding": [0, 0],
|
||||
"dilation": [1, 1],
|
||||
"ceil_mode": False,
|
||||
}
|
||||
|
||||
new_args, new_kwargs = normalize_args_kwargs(target, args, kwargs)
|
||||
for arg, new_arg in zip(args, new_args):
|
||||
assert torch.allclose(arg, new_arg)
|
||||
for k, v in new_kwargs.items():
|
||||
assert v == golden[k]
|
||||
|
||||
|
||||
# CHECK: FAIL - shouldnt_normalize1
|
||||
# CHECK: Couldn't normalize args and kwargs
|
||||
def shouldnt_normalize1():
|
||||
target = torch.ops.aten.max_pool2d_with_indices.default.overloadpacket
|
||||
args = (torch.randn((1, 3, 32, 32)),)
|
||||
kwargs = {"stride": []}
|
||||
normalize_args_kwargs(target, args, kwargs)
|
||||
|
||||
|
||||
# This next two tests are XPASS because of https://github.com/pytorch/pytorch/issues/75342
|
||||
# I.e., they should fail but in fact they pass because of the upstream bug.
|
||||
# The reason for the bug is a fast path branch in operator_schemas.normalize_function
|
||||
# that doesn't do rigorous type checking, and hence lets type mistmatches slip through.
|
||||
# TODO(max): change these to FAIL when the upstream bug is fixed.
|
||||
|
||||
# CHECK: XPASS - shouldnt_normalize2
|
||||
def shouldnt_normalize2():
|
||||
target = torch.ops.aten.max_pool2d_with_indices.default.overloadpacket
|
||||
args = (torch.randn((1, 3, 32, 32)),)
|
||||
kwargs = {"kernel_size": []}
|
||||
normalize_args_kwargs(target, args, kwargs)
|
||||
|
||||
|
||||
# CHECK: XPASS - shouldnt_normalize3
|
||||
def shouldnt_normalize3():
|
||||
target = torch.ops.aten.max_pool2d_with_indices.default.overloadpacket
|
||||
args = (torch.randn((1, 3, 32, 32)),)
|
||||
kwargs = {"kernel_size": [3, 3], "padding": None}
|
||||
normalize_args_kwargs(target, args, kwargs)
|
||||
|
||||
|
||||
def main():
|
||||
run_test(should_normalize)
|
||||
run_test(shouldnt_normalize1)
|
||||
run_test(shouldnt_normalize2, XPASS=True)
|
||||
run_test(shouldnt_normalize3, XPASS=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -57,8 +57,10 @@ def normalize_args_kwargs(target: Callable, args: Tuple[Any], kwargs: Dict[str,
|
|||
|
||||
arg_types = map_aggregate(args, type)
|
||||
assert isinstance(arg_types, tuple)
|
||||
arg_types = tuple([create_type_hint(i) for i in arg_types])
|
||||
kwarg_types = {k: type(v) for k, v in kwargs.items()}
|
||||
arg_types = map_aggregate(map_aggregate(args, type), create_type_hint)
|
||||
kwarg_types = {
|
||||
k: create_type_hint(map_aggregate(v, type)) for k, v in kwargs.items()
|
||||
}
|
||||
|
||||
new_args_and_kwargs = normalize_function(
|
||||
target, args, kwargs, arg_types, kwarg_types, normalize_to_only_use_kwargs=False
|
||||
|
|
Loading…
Reference in New Issue