diff --git a/python/test/eager_mode/annotate_args.py b/python/test/eager_mode/annotate_args.py
new file mode 100644
index 000000000..c933e4960
--- /dev/null
+++ b/python/test/eager_mode/annotate_args.py
@@ -0,0 +1,99 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+# RUN: %PYTHON %s | FileCheck %s
+
+
+import torch
+
+from framework import run_test
+from torch_mlir.eager_mode.torch_mlir_dispatch import (
+    annotate_args_kwargs,
+    normalize_args_kwargs,
+    build_script_function,
+)
+
+
+# CHECK: Torch Tensor (shape=(1, 3, 32, 32), dtype=torch.float32)
+# CHECK: Torch Tensor (shape=(1, 3, 32, 32), dtype=torch.float32)
+# CHECK: Torch Tensor (shape=(1, 3, 32, 32), dtype=torch.float32)
+# -----
+# CHECK: PASS - simple
+@run_test
+def simple():
+    target = torch.ops.aten.addmm.default
+    A = torch.randn(1, 3, 32, 32)
+    B = torch.randn(1, 3, 32, 32)
+    C = torch.randn(1, 3, 32, 32)
+    args = (A, B, C)
+    kwargs = dict(beta=1, alpha=1)
+
+    new_args, new_kwargs = normalize_args_kwargs(target.overloadpacket, args, kwargs)
+    script_fun = build_script_function(target._schema, new_args, new_kwargs)
+    annotations, *_ = annotate_args_kwargs(script_fun, new_args, new_kwargs)
+    for annot in annotations:
+        print(annot)
+
+
+# CHECK: Torch Tensor (shape=(-1, 3, 32, 32), dtype=torch.float32)
+# CHECK: Torch Tensor (shape=(-1, 3, 32, 32), dtype=torch.float32)
+# CHECK: Torch Tensor (shape=(-1, 3, 32, 32), dtype=torch.float32)
+# -----
+# CHECK: PASS - handle_zero_dim
+@run_test
+def handle_zero_dim():
+    target = torch.ops.aten.addmm.default
+    A = torch.randn(0, 3, 32, 32)
+    B = torch.randn(0, 3, 32, 32)
+    C = torch.randn(0, 3, 32, 32)
+    args = (A, B, C)
+    kwargs = dict(beta=1, alpha=1)
+
+    new_args, new_kwargs = normalize_args_kwargs(target.overloadpacket, args, kwargs)
+    script_fun = build_script_function(target._schema, new_args, new_kwargs)
+    annotations, *_ = annotate_args_kwargs(script_fun, new_args, new_kwargs)
+    for annot in annotations:
+        print(annot)
+
+
+# CHECK: Torch Tensor (shape=(2, 5, 2, 3), dtype=torch.float32)
+# CHECK: Torch Tensor (shape=(5,), dtype=torch.float32)
+# CHECK: Torch Tensor (shape=(5,), dtype=torch.float32)
+# CHECK: Torch Tensor (shape=(5,), dtype=torch.float32)
+# CHECK: Torch Tensor (shape=(5,), dtype=torch.float32)
+# CHECK: Torch Tensor (shape=(2, 5, 2, 3), dtype=torch.float32)
+# CHECK: Torch Tensor (shape=(5,), dtype=torch.float32)
+# CHECK: Torch Tensor (shape=(5,), dtype=torch.float32)
+# -----
+# CHECK: PASS - correctly_order_kwargs
+@run_test
+def correctly_order_kwargs():
+    target = torch.ops.aten.native_batch_norm.out
+
+    input = torch.randn(2, 5, 2, 3)
+    weight = torch.randn(5)
+    bias = torch.randn(5)
+    running_mean = torch.randn(5)
+    running_var = torch.randn(5)
+    args = (input, weight, bias, running_mean, running_var)
+
+    out = torch.empty_like(input)
+    save_mean = torch.empty_like(running_mean)
+    save_invstd = torch.empty_like(running_var)
+
+    kwargs = dict(
+        training=False,
+        momentum=0.1,
+        eps=0.0001,
+        out=out,
+        save_mean=save_mean,
+        save_invstd=save_invstd,
+    )
+
+    new_args, new_kwargs = normalize_args_kwargs(target.overloadpacket, args, kwargs)
+    script_fun = build_script_function(target._schema, new_args, new_kwargs)
+    annotations, *_ = annotate_args_kwargs(script_fun, new_args, new_kwargs)
+    for annot in annotations:
+        print(annot)
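Note on the handle_zero_dim test above: its CHECK lines show that a size-0 dimension (shape (0, 3, 32, 32)) is printed as -1, i.e. annotated as dynamic. A minimal sketch of that mapping, inferred only from the CHECK output; annotate_shape is a hypothetical helper, not part of torch_mlir_dispatch's API:

# Hypothetical helper illustrating the shape normalization the
# handle_zero_dim CHECK lines imply: size-0 dims become dynamic (-1).
def annotate_shape(shape):
    return tuple(-1 if dim == 0 else dim for dim in shape)

assert annotate_shape((0, 3, 32, 32)) == (-1, 3, 32, 32)
assert annotate_shape((1, 3, 32, 32)) == (1, 3, 32, 32)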
diff --git a/python/torch_mlir/eager_mode/torch_mlir_dispatch.py b/python/torch_mlir/eager_mode/torch_mlir_dispatch.py
index ad2b69f40..ca4ac0828 100644
--- a/python/torch_mlir/eager_mode/torch_mlir_dispatch.py
+++ b/python/torch_mlir/eager_mode/torch_mlir_dispatch.py
@@ -168,8 +168,7 @@ def annotate_args_kwargs(
         if isinstance(arg, np.ndarray):
             tensor_kwargs[arg_idxs[kw]] = (arg, normalized_kwargs[kw].dtype)
 
-    for i in range(len(tensor_kwargs)):
-        arg, arg_dtype = tensor_kwargs[i]
+    for _i, (arg, arg_dtype) in sorted(tensor_kwargs.items()):
         annotations.append(TorchTensorType(shape=tuple(arg.shape), dtype=arg_dtype))
         tensor_kwargs_flat.append(arg)
 
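The dispatch change above replaces positional indexing into tensor_kwargs with iteration over its items sorted by schema argument index. The old loop assumed the dict keys were exactly 0..len-1, which breaks when keyword tensors map to non-contiguous schema positions, as for native_batch_norm.out, whose out/save_mean/save_invstd arguments sit after the positional ones. A minimal standalone sketch of the failure mode, using placeholder indices and values rather than torch-mlir's real data:

# Placeholder stand-in for tensor_kwargs: keys are schema argument
# positions, values are (arg, dtype) pairs.
tensor_kwargs = {7: ("out", "f32"), 5: ("save_mean", "f32"), 6: ("save_invstd", "f32")}

# Old loop: assumes keys are exactly 0..len-1, so it raises KeyError here
# (and would emit annotations out of schema order even when it did not).
try:
    for i in range(len(tensor_kwargs)):
        arg, arg_dtype = tensor_kwargs[i]
except KeyError as e:
    print("old loop fails on non-contiguous keys:", e)

# New loop: (index, value) pairs sorted by schema position, so annotations
# come out in the order the correctly_order_kwargs test checks for.
for _i, (arg, arg_dtype) in sorted(tensor_kwargs.items()):
    print(_i, arg, arg_dtype)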