# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

# RUN: %PYTHON %s | FileCheck %s

import torch

from framework import run_test
from torch_mlir.eager_mode.torch_mlir_dispatch import normalize_args_kwargs


# CHECK: PASS - should_normalize
@run_test
def should_normalize():
    target = torch.ops.aten.max_pool2d_with_indices.default.overloadpacket
    args = (torch.randn((1, 3, 32, 32)),)
    kwargs = {"kernel_size": [3, 3]}
    golden = {
        "kernel_size": [3, 3],
        # The schema for max_pool2d_with_indices defines the stride arg as
        # int[2] stride=[], so the expected normalized value is an empty list
        # (see the schema sketch after this dict).
        "stride": [],
        "padding": [0, 0],
        "dilation": [1, 1],
        "ceil_mode": False,
    }
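    # For reference, a sketch of the op's schema (an approximation, not a
    # verbatim copy of native_functions.yaml; printing
    # torch.ops.aten.max_pool2d_with_indices.default._schema should show the
    # exact form):
    #   max_pool2d_with_indices(Tensor self, int[2] kernel_size,
    #                           int[2] stride=[], int[2] padding=0,
    #                           int[2] dilation=1, bool ceil_mode=False)
    #       -> (Tensor, Tensor)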
    new_args, new_kwargs = normalize_args_kwargs(target, args, kwargs)
    for arg, new_arg in zip(args, new_args):
        assert torch.allclose(arg, new_arg)
    for k, v in new_kwargs.items():
        assert v == golden[k]


# CHECK: FAIL - shouldnt_normalize1
# CHECK: Couldn't normalize args and kwargs
@run_test
def shouldnt_normalize1():
    target = torch.ops.aten.max_pool2d_with_indices.default.overloadpacket
    args = (torch.randn((1, 3, 32, 32)),)
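    # kernel_size has no default in the schema, but it is omitted here, so
    # normalization is expected to fail (hence the FAIL and "Couldn't
    # normalize" CHECK lines above).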
    kwargs = {"stride": []}
    normalize_args_kwargs(target, args, kwargs)


# The next two tests are XPASS because of https://github.com/pytorch/pytorch/issues/75342
# I.e., they should fail, but in fact they pass because of the upstream bug.
# The reason for the bug is a fast-path branch in operator_schemas.normalize_function
# that doesn't do rigorous type checking, and hence lets type mismatches slip through.
# TODO(max): change these to FAIL when the upstream bug is fixed.
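# A rough sketch of how one might poke the upstream fast path directly, kept
# as a comment so the FileCheck output stays unchanged (assumption: the
# operator_schemas module referred to above is torch.fx.operator_schemas):
#
#   from torch.fx.operator_schemas import normalize_function
#   normalize_function(torch.ops.aten.max_pool2d_with_indices,
#                      (torch.randn((1, 3, 32, 32)),), {"kernel_size": []},
#                      normalize_to_only_use_kwargs=True)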


# CHECK: XPASS - shouldnt_normalize2
@run_test(XPASS=True)
def shouldnt_normalize2():
    target = torch.ops.aten.max_pool2d_with_indices.default.overloadpacket
    args = (torch.randn((1, 3, 32, 32)),)
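    # An empty list is not a valid int[2] kernel_size, so this call should be
    # rejected; it currently slips through the fast path described above.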
    kwargs = {"kernel_size": []}
    normalize_args_kwargs(target, args, kwargs)


# CHECK: XPASS - shouldnt_normalize3
@run_test(XPASS=True)
def shouldnt_normalize3():
    target = torch.ops.aten.max_pool2d_with_indices.default.overloadpacket
    args = (torch.randn((1, 3, 32, 32)),)
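    # padding=None is a type mismatch for the int[2] padding arg, so this call
    # should also be rejected; it too slips through the upstream fast path.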
    kwargs = {"kernel_size": [3, 3], "padding": None}
    normalize_args_kwargs(target, args, kwargs)