@@ -12,7 +12,7 @@ from torch import device
import torch.jit._shape_functions as upstream_shape_functions
from .testing_framework import Invocation, ErrorInvocation, TensorOfShape, LongTensorOfShape, NonZeroDTensorWithDtype, ZeroDTensorWithDtype, check_shape_function, check_dtype_function
from .library_generator import generate_library, not_present_in_registry, promote_dtypes, get_dtype_of_scalar
from .library_generator import generate_library, not_present_in_registry, promote_dtypes, get_dtype_of_scalar, is_integer_dtype, is_float_dtype, is_complex_dtype
# ==============================================================================
# Shape Functions
@@ -1017,180 +1017,295 @@ def aten〇upsample_nearest2d〡shape(self: List[int], output_size: List[int], s
# Dtype Functions
# ==============================================================================
def _get_invocations_for_op_with_tensor_arg_followed_by(*args):
    """Generate invocations that thoroughly test the first tensor arg of the op.
# All the torch types sorted in decreasing order of priority during type promotion.
_SORTED_TORCH_TYPES = [
    torch.complex128, torch.complex64,
    torch.float64, torch.float32, torch.float16, torch.bfloat16,
    torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8, torch.bool
]
    This is meant to be used by ops where the entire dtype computation involves
    at most the first tensor argument of the op. If a dtype function uses other
    arguments, custom invocations should be created to test the logic of the
    dtype function instead of using this helper function.
def _check_tensors_with_the_same_dtype(
        num_of_tensors: Optional[int] = None,
        tensor_shapes: Optional[list[tuple[int]]] = None,
        error_types: Optional[set[int]] = None, *args, **kwargs):
    """Create invocations where all tensors have the same dtype.

    This function generates invocations with `num_of_tensors` tensors
    that all have the same dtype. It creates an invocation for every
    possible dtype. For dtypes in `error_types`, the invocations are
    error invocations.

    One can also specify the shapes of the tensors. Either `num_of_tensors`
    or `tensor_shapes` must be specified whenever this function is called.

    The extra *args and **kwargs arguments are passed to the invocations.
    """
    return [
        Invocation(NonZeroDTensorWithDtype(torch.float32), *args),
        Invocation(NonZeroDTensorWithDtype(torch.float64), *args),
        Invocation(NonZeroDTensorWithDtype(torch.bfloat16), *args),
        Invocation(NonZeroDTensorWithDtype(torch.int64), *args),
        Invocation(NonZeroDTensorWithDtype(torch.int32), *args),
        Invocation(NonZeroDTensorWithDtype(torch.bool), *args),
        Invocation(ZeroDTensorWithDtype(torch.float32), *args),
        Invocation(ZeroDTensorWithDtype(torch.float64), *args),
        Invocation(ZeroDTensorWithDtype(torch.bfloat16), *args),
        Invocation(ZeroDTensorWithDtype(torch.int64), *args),
        Invocation(ZeroDTensorWithDtype(torch.int32), *args),
        Invocation(ZeroDTensorWithDtype(torch.bool), *args),
    ]
    invocations = []
    for type_ in _SORTED_TORCH_TYPES:
        tensors = []
        if tensor_shapes is None and num_of_tensors is not None:
            tensors = [NonZeroDTensorWithDtype(type_)] * num_of_tensors
        elif tensor_shapes is not None and num_of_tensors is None:
            for tensor_shape in tensor_shapes:
                tensors.append(TensorOfShape(*tensor_shape, dtype=type_))
        else:
            assert False, \
                "Either `num_of_tensors` or `tensor_shapes` must be specified"
def _get_invocations_for_fp_only_op_with_tensor_arg_followed_by(*args):
    """Generate invocations for a floating-point-only op."""
    return [
        Invocation(NonZeroDTensorWithDtype(torch.float32), *args),
        Invocation(NonZeroDTensorWithDtype(torch.float64), *args),
        Invocation(NonZeroDTensorWithDtype(torch.bfloat16), *args),
        ErrorInvocation(NonZeroDTensorWithDtype(torch.int64), *args),
        ErrorInvocation(NonZeroDTensorWithDtype(torch.int32), *args),
        ErrorInvocation(NonZeroDTensorWithDtype(torch.bool), *args),
        Invocation(ZeroDTensorWithDtype(torch.float32), *args),
        Invocation(ZeroDTensorWithDtype(torch.float64), *args),
        Invocation(ZeroDTensorWithDtype(torch.bfloat16), *args),
        ErrorInvocation(ZeroDTensorWithDtype(torch.int64), *args),
        ErrorInvocation(ZeroDTensorWithDtype(torch.int32), *args),
        ErrorInvocation(ZeroDTensorWithDtype(torch.bool), *args),
    ]
        if error_types is not None and type_ in error_types:
            invocations.append(ErrorInvocation(*tensors, *args, **kwargs))
        else:
            invocations.append(Invocation(*tensors, *args, **kwargs))
    return invocations
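# For a concrete sense of the expansion (an illustrative sketch, derived from the
# loop above): a decorator argument such as
#     _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16})
# walks `_SORTED_TORCH_TYPES` and yields one invocation per dtype, e.g.
#     Invocation(NonZeroDTensorWithDtype(torch.complex128)), ...,
#     ErrorInvocation(NonZeroDTensorWithDtype(torch.float16)), ...,
#     Invocation(NonZeroDTensorWithDtype(torch.bool))
# with only the dtypes listed in `error_types` becoming `ErrorInvocation`s.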
def _check_two_tensor_op(
        input_error_types: Optional[set[int]] = None,
        output_error_types: Optional[set[int]] = None, **kwargs):
    """Generate invocations for basic two-tensor dtype functions.

    This helper function is meant to be used to check dtype functions that
    take two tensor operands and either return the promoted result or
    return a constant dtype based on the tensor dtypes.

    The testing performed is thorough enough to be able to detect if dtypes
    are invalid as inputs or as outputs to the PyTorch op. Invalid dtypes
    must be specified in `input_error_types` and `output_error_types` to
    ensure the invocations are error invocations.
    """
    if input_error_types is not None and output_error_types is not None:
        assert len(input_error_types.intersection(output_error_types)) == 0, \
            "An invalid input type implies an invalid output type, " \
            "so there is no need to repeat the type in the `output_error_types` set"
    all_error_types = set()
    all_error_types |= set() if input_error_types is None else input_error_types
    all_error_types |= set() if output_error_types is None else output_error_types
    def check_two_tensors_with_one_varying_dtype_at_a_time(**kwargs):
        """Create invocations where one tensor varies its dtype.

        This helper function creates invocations with two tensors where one
        tensor varies its dtype while the other one stays constant. The
        variation is applied to both tensor positions and covers every
        possible dtype.

        This function helps identify when a dtype is an invalid input dtype
        for dtype functions that do promotion.
        """
        # We will only create invocations for dtypes with priorities less than
        # or equal to the highest priority valid type. By setting the non-varying
        # tensor dtype to be the highest priority valid type, we ensure that
        # every promotion results in a valid dtype. This allows the invocations
        # to test assertions on input types in isolation.
        constant_type = None
        constant_type_index = None
        for e, type_ in enumerate(_SORTED_TORCH_TYPES):
            if type_ not in all_error_types:
                constant_type = type_
                constant_type_index = e
                break
        assert constant_type is not None, \
            "Unable to find a constant type. Make sure the union of " \
            "`input_error_types` and `output_error_types` is not all possible types."
        invocations = []
        for type_ in _SORTED_TORCH_TYPES[constant_type_index:]:
            tensor_1 = NonZeroDTensorWithDtype(type_)
            tensor_2 = NonZeroDTensorWithDtype(constant_type)
            if input_error_types is not None and type_ in input_error_types:
                invocations += [ErrorInvocation(tensor_1, tensor_2, **kwargs),
                                ErrorInvocation(tensor_2, tensor_1, **kwargs)]
            else:
                invocations += [Invocation(tensor_1, tensor_2, **kwargs),
                                Invocation(tensor_2, tensor_1, **kwargs)]
        return invocations

    same_dtype_invocations = _check_tensors_with_the_same_dtype(
        num_of_tensors=2, error_types=all_error_types, **kwargs)
    varying_dtype_invocations = \
        check_two_tensors_with_one_varying_dtype_at_a_time(**kwargs)
    return same_dtype_invocations + varying_dtype_invocations
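# Putting the two pieces together (illustrative sketch, derived from the helpers
# above): a call like
#     _check_two_tensor_op(input_error_types={torch.complex64, torch.complex128})
# produces (a) same-dtype pairs for every entry of `_SORTED_TORCH_TYPES`, with the
# complex dtypes emitted as `ErrorInvocation`s, and (b) mixed pairs in which one
# tensor varies over the dtypes at or below `torch.float64` (the highest-priority
# valid type in this case) while the other stays `torch.float64`, in both argument
# orders.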
def _get_dtype_of_floating_point_op(input_dtype: int) -> int:
    if input_dtype in (torch.float64, torch.bfloat16, torch.float16):
    if (is_float_dtype(input_dtype) and input_dtype != torch.float32) \
            or is_complex_dtype(input_dtype):
        return input_dtype
    return torch.float32
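# Illustrative behavior of the helper above: with the updated predicate, float64,
# float16, bfloat16, and the complex dtypes are returned unchanged, while every
# other input (integers, bool, and float32 itself) maps to torch.float32, e.g.
#     _get_dtype_of_floating_point_op(torch.bfloat16)  # -> torch.bfloat16
#     _get_dtype_of_floating_point_op(torch.int64)     # -> torch.float32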
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16}))
def aten〇tanh〡dtype(self_rank: int, self_dtype: int) -> int:
    assert self_dtype != torch.float16
    return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16}))
def aten〇exp〡dtype(self_rank: int, self_dtype: int) -> int:
    assert self_dtype != torch.float16
    return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(
    num_of_tensors=1, error_types={torch.float16, torch.complex64, torch.complex128}))
def aten〇expm1〡dtype(self_rank: int, self_dtype: int) -> int:
    assert not is_complex_dtype(self_dtype), "`self` cannot be complex"
    assert self_dtype != torch.float16, "`self` cannot have float16 dtype"
    return _get_dtype_of_floating_point_op(self_dtype)
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16}))
def aten〇sin〡dtype(self_rank: int, self_dtype: int) -> int:
    assert self_dtype != torch.float16
    return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16}))
def aten〇cos〡dtype(self_rank: int, self_dtype: int) -> int:
    assert self_dtype != torch.float16
    return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16}))
def aten〇sigmoid〡dtype(self_rank: int, self_dtype: int) -> int:
    assert self_dtype != torch.float16
    return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇reciprocal〡dtype(self_rank: int, self_dtype: int) -> int:
    return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16}))
def aten〇sqrt〡dtype(self_rank: int, self_dtype: int) -> int:
    assert self_dtype != torch.float16
    return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16}))
def aten〇log〡dtype(self_rank: int, self_dtype: int) -> int:
    assert self_dtype != torch.float16
    return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16}))
def aten〇log2〡dtype(self_rank: int, self_dtype: int) -> int:
    assert self_dtype != torch.float16
    return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16}))
def aten〇log1p〡dtype(self_rank: int, self_dtype: int) -> int:
    assert self_dtype != torch.float16
    return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16}))
def aten〇rsqrt〡dtype(self_rank: int, self_dtype: int) -> int:
    assert self_dtype != torch.float16
    return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16, torch.complex64, torch.complex128}))
def aten〇erf〡dtype(self_rank: int, self_dtype: int) -> int:
    assert not is_complex_dtype(self_dtype) and self_dtype != torch.float16
    return _get_dtype_of_floating_point_op(self_dtype)
@check_dtype_function(_get_invocations_for_fp_only_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(
    num_of_tensors=1, error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32,
                                   torch.int64, torch.float16, torch.complex64, torch.complex128}))
def aten〇softplus〡dtype(self_rank: int, self_dtype: int, beta: Union[int, float] = 1, threshold: Union[int, float] = 20) -> int:
    assert self_dtype not in (torch.int64, torch.int32, torch.bool)
    assert not is_complex_dtype(self_dtype) and not is_integer_dtype(self_dtype) and self_dtype != torch.float16
    return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_fp_only_op_with_tensor_arg_followed_by([0]))
@check_dtype_function(_check_tensors_with_the_same_dtype(
    num_of_tensors=1, error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}, dim=[0]))
def aten〇frobenius_norm〇dim〡dtype(self_rank: int, self_dtype: int, dim: List[int], keepdim: bool = False) -> int:
    assert self_dtype not in (torch.int64, torch.int32, torch.bool)
    assert not is_integer_dtype(self_dtype)
    if self_dtype == torch.complex128:
        return torch.float64
    elif self_dtype == torch.complex64:
        return torch.float32
    return _get_dtype_of_floating_point_op(self_dtype)
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16}))
def prims〇sqrt〡dtype(self_rank: int, self_dtype: int) -> int:
    assert self_dtype != torch.float16
    return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇all〡dtype(self_rank: int, self_dtype: int) -> int:
    return torch.bool
    return torch.uint8 if self_dtype == torch.uint8 else torch.bool

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇any〡dtype(self_rank: int, self_dtype: int) -> int:
    return torch.bool
    return torch.uint8 if self_dtype == torch.uint8 else torch.bool

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by(0.0))
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) +
                      _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0))
def aten〇eq〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int:
    return torch.bool
@check_dtype_function(
    _get_invocations_for_op_with_tensor_arg_followed_by(NonZeroDTensorWithDtype(torch.float)))
@check_dtype_function(_check_two_tensor_op())
def aten〇eq〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
    return torch.bool

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by(0.0))
@check_dtype_function(
    _check_tensors_with_the_same_dtype(
        num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0.0) +
    _check_tensors_with_the_same_dtype(
        num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0))
def aten〇ge〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int:
    assert not is_complex_dtype(self_dtype), "`self` cannot be complex"
    return torch.bool

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by(0.0))
@check_dtype_function(
    _check_tensors_with_the_same_dtype(
        num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0.0) +
    _check_tensors_with_the_same_dtype(
        num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0))
def aten〇gt〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int:
    assert not is_complex_dtype(self_dtype), "`self` cannot be complex"
    return torch.bool

@check_dtype_function(
    _get_invocations_for_op_with_tensor_arg_followed_by(NonZeroDTensorWithDtype(torch.float)))
    _check_two_tensor_op(input_error_types={torch.complex64, torch.complex128}))
def aten〇gt〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
    return torch.bool
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by(0.0))
def aten〇le〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int:
    assert not is_complex_dtype(self_dtype), "`self` cannot be complex"
    assert not is_complex_dtype(other_dtype), "`other` cannot be complex"
    return torch.bool

@check_dtype_function(
    _get_invocations_for_op_with_tensor_arg_followed_by(NonZeroDTensorWithDtype(torch.float)))
    _check_tensors_with_the_same_dtype(
        num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0.0) +
    _check_tensors_with_the_same_dtype(
        num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0))
def aten〇le〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int:
    assert not is_complex_dtype(self_dtype), "`self` cannot be complex"
    return torch.bool

@check_dtype_function(_check_two_tensor_op())
def aten〇logical_and〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
    return torch.bool

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇logical_not〡dtype(self_rank: int, self_dtype: int) -> int:
    return torch.bool
@check_dtype_function(
    _get_invocations_for_op_with_tensor_arg_followed_by(NonZeroDTensorWithDtype(torch.float)))
@check_dtype_function(_check_two_tensor_op())
def aten〇logical_or〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
    return torch.bool

@check_dtype_function(
    _get_invocations_for_op_with_tensor_arg_followed_by(NonZeroDTensorWithDtype(torch.float)))
@check_dtype_function(_check_two_tensor_op())
def aten〇logical_xor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
    return torch.bool

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by(0.0))
@check_dtype_function(
    _check_tensors_with_the_same_dtype(
        num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0.0) +
    _check_tensors_with_the_same_dtype(
        num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0))
def aten〇lt〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int:
    assert not is_complex_dtype(self_dtype), "`self` cannot be complex"
    return torch.bool

@check_dtype_function(
    _get_invocations_for_op_with_tensor_arg_followed_by(NonZeroDTensorWithDtype(torch.float)))
    _check_two_tensor_op(input_error_types={torch.complex64, torch.complex128}))
def aten〇lt〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
    assert not is_complex_dtype(self_dtype), "`self` cannot be complex"
    assert not is_complex_dtype(other_dtype), "`other` cannot be complex"
    return torch.bool

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by(0.0))
@check_dtype_function(
    _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) +
    _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0))
def aten〇ne〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int:
    return torch.bool
@@ -1205,55 +1320,231 @@ def aten〇add〡dtype(a: Union[int, float], b: Union[int, float]) -> int:
    dtypes = [get_dtype_of_scalar(a), get_dtype_of_scalar(b)]
    return promote_dtypes(ranks, dtypes)
@check_dtype_function([
    Invocation(NonZeroDTensorWithDtype(torch.complex64)),
    Invocation(NonZeroDTensorWithDtype(torch.complex128)),
    Invocation(NonZeroDTensorWithDtype(torch.float)),
    Invocation(NonZeroDTensorWithDtype(torch.double)),
    Invocation(NonZeroDTensorWithDtype(torch.bool)),
    Invocation(NonZeroDTensorWithDtype(torch.uint8)),
    Invocation(NonZeroDTensorWithDtype(torch.int8)),
    Invocation(NonZeroDTensorWithDtype(torch.int16)),
    Invocation(NonZeroDTensorWithDtype(torch.int32)),
    Invocation(NonZeroDTensorWithDtype(torch.int64)),
    ErrorInvocation(NonZeroDTensorWithDtype(torch.float16)),
    ErrorInvocation(NonZeroDTensorWithDtype(torch.bfloat16)),
])
@check_dtype_function(
    _check_tensors_with_the_same_dtype(
        num_of_tensors=1, error_types={torch.float16, torch.bfloat16}))
def aten〇fft_fft〡dtype(self_rank: int, self_dtype: int, n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> int:
    if self_dtype == torch.complex64 or self_dtype == torch.complex128:
    if is_complex_dtype(self_dtype):
        return self_dtype
    elif self_dtype == torch.float:
        return torch.complex64
    elif self_dtype == torch.double:
        return torch.complex128
    elif self_dtype == torch.bool or self_dtype == torch.uint8 or \
            self_dtype == torch.int8 or self_dtype == torch.int16 or \
            self_dtype == torch.int32 or self_dtype == torch.int64:
    elif is_integer_dtype(self_dtype):
        return torch.complex64
    else:
        assert False, "Unsupported dtype"
@check_dtype_function([
    Invocation(NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.float32)),
    Invocation(ZeroDTensorWithDtype(torch.float64), NonZeroDTensorWithDtype(torch.float32)),
    Invocation(ZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.float64)),
    Invocation(NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32)),
])
def aten〇floor_divide〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
@check_dtype_function(
    _check_tensors_with_the_same_dtype(
        num_of_tensors=1, error_types={torch.bool}, other=0.0) +
    _check_tensors_with_the_same_dtype(
        num_of_tensors=1, error_types={torch.bool}, other=0))
def aten〇rsub〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float], alpha: Union[int, float] = 1) -> int:
    assert self_dtype != torch.bool, "`self` cannot have bool dtype"
    return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)])
@check_dtype_function(
    _check_two_tensor_op(input_error_types={torch.float16, torch.bfloat16, torch.float32,
                                            torch.float64, torch.complex64, torch.complex128}))
def aten〇__and__〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
    assert is_integer_dtype(self_dtype), "Expected `self` to have integer dtype"
    assert is_integer_dtype(other_dtype), "Expected `other` to have integer dtype"
    ranks: List[Optional[int]] = [self_rank, other_rank]
    dtypes = [self_dtype, other_dtype]
    return promote_dtypes(ranks, dtypes)

@check_dtype_function([
    Invocation(NonZeroDTensorWithDtype(torch.float32), other=0),
    Invocation(NonZeroDTensorWithDtype(torch.int64), other=0.0),
    Invocation(NonZeroDTensorWithDtype(torch.float16), other=0.0),
    Invocation(ZeroDTensorWithDtype(torch.float32), other=0),
    Invocation(ZeroDTensorWithDtype(torch.int64), other=0.0),
    Invocation(ZeroDTensorWithDtype(torch.float16), other=0.0)
])
def aten〇rsub〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float], alpha: Union[int, float] = 1) -> int:
    return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)])
@check_dtype_function(_check_two_tensor_op())
def aten〇add〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int, alpha: Union[int, float] = 1) -> int:
    ranks: List[Optional[int]] = [self_rank, other_rank]
    dtypes = [self_dtype, other_dtype]
    return promote_dtypes(ranks, dtypes)
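# Sketch of the promotion being exercised here (assuming `promote_dtypes` follows
# PyTorch's result_type rules, with `None` in the ranks list denoting a scalar
# operand, as in the `rsub` functions above): two non-zero-rank tensors promote by
# dtype alone, e.g.
#     aten〇add〇Tensor〡dtype(self_rank=3, self_dtype=torch.int32,
#                            other_rank=3, other_dtype=torch.float32)
# is expected to return torch.float32.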
@check_dtype_function(
    _check_two_tensor_op(input_error_types={torch.float16, torch.bfloat16, torch.float32,
                                            torch.float64, torch.complex64, torch.complex128}))
def aten〇bitwise_and〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
    assert is_integer_dtype(self_dtype), "Expected `self` to have integer dtype"
    assert is_integer_dtype(other_dtype), "Expected `other` to have integer dtype"
    ranks: List[Optional[int]] = [self_rank, other_rank]
    dtypes = [self_dtype, other_dtype]
    return promote_dtypes(ranks, dtypes)

@check_dtype_function(
    _check_two_tensor_op(input_error_types={torch.float16, torch.bfloat16, torch.float32,
                                            torch.float64, torch.complex64, torch.complex128}))
def aten〇bitwise_or〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
    assert is_integer_dtype(self_dtype), "Expected `self` to have integer dtype"
    assert is_integer_dtype(other_dtype), "Expected `other` to have integer dtype"
    ranks: List[Optional[int]] = [self_rank, other_rank]
    dtypes = [self_dtype, other_dtype]
    return promote_dtypes(ranks, dtypes)

@check_dtype_function(
    _check_two_tensor_op(input_error_types={torch.float16, torch.bfloat16, torch.float32,
                                            torch.float64, torch.complex64, torch.complex128}))
def aten〇bitwise_xor〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
    assert is_integer_dtype(self_dtype), "Expected `self` to have integer dtype"
    assert is_integer_dtype(other_dtype), "Expected `other` to have integer dtype"
    ranks: List[Optional[int]] = [self_rank, other_rank]
    dtypes = [self_dtype, other_dtype]
    return promote_dtypes(ranks, dtypes)
@check_dtype_function(
    _check_tensors_with_the_same_dtype(
        tensor_shapes=[(2, 3, 4), (2, 4, 3)], error_types={torch.float16, torch.bool}) +
    # Different width
    [ErrorInvocation(TensorOfShape(2, 3, 4, dtype=torch.float64),
                     TensorOfShape(2, 4, 3, dtype=torch.float32)),
     # Different type
     ErrorInvocation(TensorOfShape(2, 3, 4, dtype=torch.float32),
                     TensorOfShape(2, 4, 3, dtype=torch.int32))])
def aten〇bmm〡dtype(self_rank: int, self_dtype: int, mat2_rank: int, mat2_dtype: int) -> int:
    assert self_dtype not in [torch.float16, torch.bool], \
        "Expected dtype of `self` to not be float16 or bool"
    assert mat2_dtype not in [torch.float16, torch.bool], \
        "Expected dtype of `mat2` to not be float16 or bool"
    assert self_dtype == mat2_dtype, "`self` and `mat2` must have the same dtype"
    return self_dtype
@check_dtype_function(_check_two_tensor_op())
def aten〇div〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
    ranks: List[Optional[int]] = [self_rank, other_rank]
    dtypes = [self_dtype, other_dtype]
    promoted_dtype = promote_dtypes(ranks, dtypes)
    if is_complex_dtype(promoted_dtype) or \
            (is_float_dtype(promoted_dtype) and promoted_dtype != torch.float32):
        return promoted_dtype
    else:
        return torch.float32
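# Concretely (derived from the branch above): if the promoted dtype is integral or
# bool, or is exactly torch.float32, the division result falls back to torch.float32;
# any other float promotion, as well as complex promotions, is preserved, e.g.
#     aten〇div〇Tensor〡dtype(2, torch.int64, 2, torch.int64)      # -> torch.float32
#     aten〇div〇Tensor〡dtype(2, torch.float64, 2, torch.float64)  # -> torch.float64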
@check_dtype_function(_check_two_tensor_op(rounding_mode=None))
def aten〇div〇Tensor_mode〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int, rounding_mode: Optional[str]) -> int:
    ranks: List[Optional[int]] = [self_rank, other_rank]
    dtypes = [self_dtype, other_dtype]
    promoted_dtype = promote_dtypes(ranks, dtypes)
    if is_complex_dtype(promoted_dtype) or \
            (is_float_dtype(promoted_dtype) and promoted_dtype != torch.float32):
        return promoted_dtype
    else:
        return torch.float32

@check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128}, output_error_types={torch.bool}))
def aten〇floor_divide〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
    assert not is_complex_dtype(self_dtype), "`self` cannot be complex"
    assert not is_complex_dtype(other_dtype), "`other` cannot be complex"
    ranks: List[Optional[int]] = [self_rank, other_rank]
    dtypes = [self_dtype, other_dtype]
    promoted_dtype = promote_dtypes(ranks, dtypes)
    assert promoted_dtype != torch.bool, "Result dtype for aten.floor_divide cannot be bool"
    return promoted_dtype
@check_dtype_function(
    _check_tensors_with_the_same_dtype(
        tensor_shapes=[(2, 3, 4), (2, 4, 3)], error_types={torch.float16, torch.bool}) +
    # Different width
    [ErrorInvocation(TensorOfShape(2, 3, 4, dtype=torch.float64),
                     TensorOfShape(2, 4, 3, dtype=torch.float32)),
     # Different type
     ErrorInvocation(TensorOfShape(2, 3, 4, dtype=torch.float32),
                     TensorOfShape(2, 4, 3, dtype=torch.int32))])
def aten〇matmul〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
    assert self_dtype not in [torch.float16, torch.bool], \
        "Expected dtype of `self` to not be float16 or bool"
    assert other_dtype not in [torch.float16, torch.bool], \
        "Expected dtype of `other` to not be float16 or bool"
    assert self_dtype == other_dtype, "`self` and `other` must have the same dtype"
    return self_dtype

@check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128}))
def aten〇maximum〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
    assert not is_complex_dtype(self_dtype), "`self` cannot be complex"
    assert not is_complex_dtype(other_dtype), "`other` cannot be complex"
    ranks: List[Optional[int]] = [self_rank, other_rank]
    dtypes = [self_dtype, other_dtype]
    return promote_dtypes(ranks, dtypes)

@check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128}))
def aten〇minimum〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
    assert not is_complex_dtype(self_dtype), "`self` cannot be complex"
    assert not is_complex_dtype(other_dtype), "`other` cannot be complex"
    ranks: List[Optional[int]] = [self_rank, other_rank]
    dtypes = [self_dtype, other_dtype]
    return promote_dtypes(ranks, dtypes)
@check_dtype_function(
    _check_tensors_with_the_same_dtype(
        tensor_shapes=[(3, 4), (4, 3)], error_types={torch.float16, torch.bool}) +
    # Different width
    [ErrorInvocation(TensorOfShape(3, 4, dtype=torch.float64),
                     TensorOfShape(4, 3, dtype=torch.float32)),
     # Different type
     ErrorInvocation(TensorOfShape(3, 4, dtype=torch.float32),
                     TensorOfShape(4, 3, dtype=torch.int32))])
def aten〇mm〡dtype(self_rank: int, self_dtype: int, mat2_rank: int, mat2_dtype: int) -> int:
    assert self_dtype not in [torch.float16, torch.bool], \
        "Expected dtype of `self` to not be float16 or bool"
    assert mat2_dtype not in [torch.float16, torch.bool], \
        "Expected dtype of `mat2` to not be float16 or bool"
    assert self_dtype == mat2_dtype, "`self` and `mat2` must have the same dtype"
    return self_dtype

@check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128},
                                           output_error_types={torch.bool, torch.uint8, torch.int8, torch.int16,
                                                               torch.int32, torch.int64, torch.bfloat16}))
def aten〇mse_loss〡dtype(self_rank: int, self_dtype: int, target_rank: int, target_dtype: int, reduction: int = 1) -> int:
    assert not is_complex_dtype(self_dtype), "`self` cannot be complex"
    assert not is_complex_dtype(target_dtype), "`target` cannot be complex"
    ranks: List[Optional[int]] = [self_rank, target_rank]
    dtypes = [self_dtype, target_dtype]
    promoted_dtype = promote_dtypes(ranks, dtypes)
    assert is_float_dtype(promoted_dtype) and promoted_dtype != torch.bfloat16, \
        "Expected promoted dtype to be float but not `bfloat16`"
    return promoted_dtype
@check_dtype_function(_check_two_tensor_op())
def aten〇mul〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
    ranks: List[Optional[int]] = [self_rank, other_rank]
    dtypes = [self_dtype, other_dtype]
    return promote_dtypes(ranks, dtypes)

@check_dtype_function(
    _check_tensors_with_the_same_dtype(
        tensor_shapes=[(3, 4), (4,)], error_types={torch.float16, torch.bool}) +
    # Different width
    [ErrorInvocation(TensorOfShape(3, 4, dtype=torch.float64),
                     TensorOfShape(4, dtype=torch.float32)),
     # Different type
     ErrorInvocation(TensorOfShape(3, 4, dtype=torch.float32),
                     TensorOfShape(4, dtype=torch.int32))])
def aten〇mv〡dtype(self_rank: int, self_dtype: int, vec_rank: int, vec_dtype: int) -> int:
    assert self_dtype not in [torch.float16, torch.bool], \
        "Expected dtype of `self` to not be float16 or bool"
    assert vec_dtype not in [torch.float16, torch.bool], \
        "Expected dtype of `vec` to not be float16 or bool"
    assert self_dtype == vec_dtype, "`self` and `vec` must have the same dtype"
    ranks: List[Optional[int]] = [self_rank, vec_rank]
    dtypes = [self_dtype, vec_dtype]
    return promote_dtypes(ranks, dtypes)
@check_dtype_function(_check_two_tensor_op(input_error_types={torch.bool}))
def aten〇sub〇Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int, alpha: Union[int, float] = 1) -> int:
    assert self_dtype != torch.bool, "`self` cannot be of bool dtype"
    assert other_dtype != torch.bool, "`other` cannot be of bool dtype"
    ranks: List[Optional[int]] = [self_rank, other_rank]
    dtypes = [self_dtype, other_dtype]
    return promote_dtypes(ranks, dtypes)

@check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128}, output_error_types={torch.bool, torch.float16}, threshold=0))
def aten〇threshold_backward〡dtype(grad_output_rank: int, grad_output_dtype: int, self_rank: int, self_dtype: int, threshold: Union[int, float]) -> int:
    assert not is_complex_dtype(grad_output_dtype), "`grad_output` cannot be complex"
    assert not is_complex_dtype(self_dtype), "`self` cannot be complex"
    ranks: List[Optional[int]] = [grad_output_rank, self_rank]
    dtypes = [grad_output_dtype, self_dtype]
    promoted_dtype = promote_dtypes(ranks, dtypes)
    assert promoted_dtype not in [torch.bool, torch.float16], \
        "Result dtype for aten.threshold_backward cannot be bool or float16"
    return promoted_dtype
# ==============================================================================
# Main