Fix single-element tuple construction in abstract interp library (#2258)

Single element tuples in Python need a comma after the
element. However, the `registry.py` file, which generates the expected
abstract interpretation function signatures, was not inserting the
comma. This commit changes the expected signature generator to add a
comma after the last element in any non-empty default tuple argument.
pull/2260/head
Ramiro Leal-Cavazos 2023-06-22 11:27:40 -07:00 committed by GitHub
parent 96b14e952e
commit 6f2bf31291
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 12 deletions

View File

@ -465,10 +465,10 @@ def aten_unsafe_view〡shape(self: List[int], size: List[int]) -> List[int]:
def atenresize_〡shape(self: List[int], size: List[int], memory_format: Optional[int] = None) -> List[int]:
return size
def atenmax_pool2d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), ceil_mode: bool = False) -> List[int]:
def atenmax_pool2d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), ceil_mode: bool = False) -> List[int]:
return upstream_shape_functions.max_pool2d(self, kernel_size, stride, padding, dilation, ceil_mode)
def atenmax_pool2d_with_indices〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), ceil_mode: bool = False) -> Tuple[List[int], List[int]]:
def atenmax_pool2d_with_indices〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), ceil_mode: bool = False) -> Tuple[List[int], List[int]]:
maxpool2d = indices = upstream_shape_functions.max_pool2d(self, kernel_size, stride, padding, dilation, ceil_mode)
return maxpool2d, indices
@ -522,7 +522,7 @@ def avg_pool2d(input: List[int], kernel_size: List[int], stride: List[int], padd
else:
return [nbatch, nInputPlane, outputHeight, outputWidth]
def atenavg_pool2d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> List[int]:
def atenavg_pool2d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> List[int]:
return avg_pool2d(self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
def atenadaptive_avg_pool2d〡shape(self: List[int], output_size: List[int]) -> List[int]:
@ -859,10 +859,10 @@ def atenview_as_complex〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
else:
assert False, "Unsupported dtype"
def atenconv2d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), groups: int = 1) -> List[int]:
def atenconv2d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1,), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), groups: int = 1) -> List[int]:
return upstream_shape_functions.conv2d(input, weight, bias, stride, padding, dilation, groups)
def atenconv_transpose2dinput〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), output_padding: List[int] = (0, 0), groups: int = 1, dilation: List[int] = (1, 1)) -> List[int]:
def atenconv_transpose2dinput〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1,), padding: List[int] = (0, 0,), output_padding: List[int] = (0, 0,), groups: int = 1, dilation: List[int] = (1, 1,)) -> List[int]:
return upstream_shape_functions.conv_transpose2d_input(input, weight, bias, stride, padding, output_padding, groups, dilation)
def atenconvolution〡shape(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int) -> List[int]:
@ -1350,7 +1350,7 @@ def atenadaptive_avg_pool2d〡dtype(self_rank_dtype: Tuple[int, int], output_
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2]))
def atenavg_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> int:
def atenavg_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype
@ -1631,12 +1631,12 @@ def atenmasked_select〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dty
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2]))
def atenmax_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), ceil_mode: bool = False) -> int:
def atenmax_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), ceil_mode: bool = False) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2]))
def atenmax_pool2d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), ceil_mode: bool = False) -> Tuple[int, int]:
def atenmax_pool2d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), ceil_mode: bool = False) -> Tuple[int, int]:
self_rank, self_dtype = self_rank_dtype
return self_dtype, torch.int64
@ -2326,7 +2326,7 @@ def aten_convolutiondeprecated〡dtype(input_rank_dtype: Tuple[int, int],
Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)),
Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16))
])
def atenconv2d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), groups: int = 1) -> int:
def atenconv2d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1, 1,), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), groups: int = 1) -> int:
input_rank, input_dtype = input_rank_dtype
return input_dtype
@ -2337,7 +2337,7 @@ def atenconv2d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype:
Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)),
Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16))
])
def atenconv_transpose2dinput〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), output_padding: List[int] = (0, 0), groups: int = 1, dilation: List[int] = (1, 1)) -> int:
def atenconv_transpose2dinput〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1, 1,), padding: List[int] = (0, 0,), output_padding: List[int] = (0, 0,), groups: int = 1, dilation: List[int] = (1, 1,)) -> int:
input_rank, input_dtype = input_rank_dtype
return input_dtype

View File

@ -32,8 +32,14 @@ def _get_default_value(arg: "SIG_ATTR_TYPE") -> str:
# testing against the real ops, and tuples work fine in all
# the places this kicks in (e.g. conv dilations -- we aren't
# mutating those lists).
default_debug = arg["default_debug"].replace(
'[', '(').replace(']', ')')
default_list = arg["default_debug"]
# (,) is not a valid empty tuple contruction in Python, so
# we must handle the emtpy case separately.
if default_list == "[]":
default_debug = "()"
else:
default_debug = default_list.replace(
"[", "(").replace("]", ",)")
elif arg["pytype"] == "str":
default_debug = repr(arg["default_debug"]).replace("'", '"')
else: