Add dtype functions for two tensor promotion ops (#1831)

This commit adds dtype functions for ops in RefineTypes under the
category of "Promote the two dtypes". The only ops not added here are
convolution ops, since they take an optional tensor argument, and the
dtype pipeline currently does not correctly handle that case. I will
add a follow up patch fixing this.

This commit also adds two helper functions that perform a very
thorough testing of dtype functions. The helper function
`_check_two_tensor_op` is able to independently test invalid input
dtypes and invalid output dtypes.

Lastly, this commit also XFAILs "MobilenetV3Module_basic".
pull/1865/head
Ramiro Leal-Cavazos 2023-02-01 22:30:27 +00:00 committed by GitHub
parent 83d4e89d25
commit 981ac88758
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 1289 additions and 271 deletions

View File

@ -85,6 +85,11 @@ TORCHDYNAMO_XFAIL_SET = {
"ElementwisePreluModule_basic",
# error: op lowering missing. Issue: https://github.com/llvm/torch-mlir/issues/1792
"StdCorrectionKeepDimModule_basic",
# Dtype function transition failures
"MobilenetV3Module_basic",
"ResNet18Module_basic",
"ResNet18StaticModule_basic",
}
MHLO_PASS_SET = {

File diff suppressed because it is too large Load Diff

View File

@ -693,10 +693,9 @@ void TypeAnalysis::visitOperation(Operation *op,
}
// Promote the two dtypes assuming non-zero rank.
if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp,
Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp, AtenMvOp,
AtenConvolutionOverrideableOp, AtenConvTranspose2dInputOp,
AtenMseLossOp>(op)) {
if (isa<AtenConv2dOp, AtenConvolutionOp,
Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp,
AtenConvolutionOverrideableOp, AtenConvTranspose2dInputOp>(op)) {
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
@ -705,20 +704,6 @@ void TypeAnalysis::visitOperation(Operation *op,
return;
}
// Promote the two dtypes assuming possibly-zero rank.
if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp, AtenDivTensorOp,
AtenDivTensorModeOp, Aten__And__TensorOp, AtenMinimumOp,
AtenMaximumOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp,
AtenBitwiseXorTensorOp, AtenThresholdBackwardOp>(op)) {
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
knowledge.dtype = getPromotedResultType(
op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()},
getRankIsNonZeroArray(op->getOperands()));
incorporateKnowledge(op->getResult(0), knowledge);
return;
}
// Dtype is always float32, except for bfloat16, float64 and nullptr after
// promotion and assuming possible-zero rank.
if (isa<AtenAtan2Op>(op)) {

View File

@ -190,6 +190,8 @@ class SimplifyDtypeCalculationsPass
patterns.insert<DecomposePromoteDtypesOp>(context);
patterns.insert<RefineNumToTensorScalarOpType>(context);
PrimIfOp::getCanonicalizationPatterns(patterns, context);
// TODO: Debug visitation order to make this more efficient.
// A single linear scan should suffice.
GreedyRewriteConfig config;

View File

@ -12,7 +12,7 @@ from torch import device
import torch.jit._shape_functions as upstream_shape_functions
from .testing_framework import Invocation, ErrorInvocation, TensorOfShape, LongTensorOfShape, NonZeroDTensorWithDtype, ZeroDTensorWithDtype, check_shape_function, check_dtype_function
from .library_generator import generate_library, not_present_in_registry, promote_dtypes, get_dtype_of_scalar
from .library_generator import generate_library, not_present_in_registry, promote_dtypes, get_dtype_of_scalar, is_integer_dtype, is_float_dtype, is_complex_dtype
# ==============================================================================
# Shape Functions
@ -1017,180 +1017,295 @@ def atenupsample_nearest2d〡shape(self: List[int], output_size: List[int], s
# Dtype Functions
# ==============================================================================
def _get_invocations_for_op_with_tensor_arg_followed_by(*args):
"""Generate invocations that thoroughly test the first tensor arg of the op.
# All the torch types sorted in decreasing order of priority during type promotion.
_SORTED_TORCH_TYPES = [
torch.complex128, torch.complex64,
torch.float64, torch.float32, torch.float16, torch.bfloat16,
torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8, torch.bool
]
This is meant to be used by ops where the entire dtype computation involves
at most the first tensor argument of the op. If an dtype function uses other
arguments, custom invocations should be created to test the logic of the
dtype function instead of using this helper function.
def _check_tensors_with_the_same_dtype(
num_of_tensors: Optional[int] = None,
tensor_shapes: Optional[list[tuple[int]]] = None,
error_types: Optional[set[int]] = None, *args, **kwargs):
"""Create invocations where all tensors have the same dtype.
This function generates invocations with `num_of_tensors` tensors
that all have the same dtype. It creates an invocation for every
possible dtype. For dtypes in `error_types`, the invocations are
error invocations.
One can also specify the shapes of the tensors. Either `num_of_tensors`
or `tensor_shapes` must be specified whenever this function is called.
The extra *args and **kwargs arguments are passed to the invocations.
"""
return [
Invocation(NonZeroDTensorWithDtype(torch.float32), *args),
Invocation(NonZeroDTensorWithDtype(torch.float64), *args),
Invocation(NonZeroDTensorWithDtype(torch.bfloat16), *args),
Invocation(NonZeroDTensorWithDtype(torch.int64), *args),
Invocation(NonZeroDTensorWithDtype(torch.int32), *args),
Invocation(NonZeroDTensorWithDtype(torch.bool), *args),
Invocation(ZeroDTensorWithDtype(torch.float32), *args),
Invocation(ZeroDTensorWithDtype(torch.float64), *args),
Invocation(ZeroDTensorWithDtype(torch.bfloat16), *args),
Invocation(ZeroDTensorWithDtype(torch.int64), *args),
Invocation(ZeroDTensorWithDtype(torch.int32), *args),
Invocation(ZeroDTensorWithDtype(torch.bool), *args),
]
invocations = []
for type_ in _SORTED_TORCH_TYPES:
tensors = []
if tensor_shapes is None and num_of_tensors is not None:
tensors = [NonZeroDTensorWithDtype(type_)] * num_of_tensors
elif tensor_shapes is not None and num_of_tensors is None:
for tensor_shape in tensor_shapes:
tensors.append(TensorOfShape(*tensor_shape, dtype=type_))
else:
assert False, \
"Either `num_of_tensors` or `tensor_shapes` must be specified"
def _get_invocations_for_fp_only_op_with_tensor_arg_followed_by(*args):
"""Generate invocations for floating point only op."""
return [
Invocation(NonZeroDTensorWithDtype(torch.float32), *args),
Invocation(NonZeroDTensorWithDtype(torch.float64), *args),
Invocation(NonZeroDTensorWithDtype(torch.bfloat16), *args),
ErrorInvocation(NonZeroDTensorWithDtype(torch.int64), *args),
ErrorInvocation(NonZeroDTensorWithDtype(torch.int32), *args),
ErrorInvocation(NonZeroDTensorWithDtype(torch.bool), *args),
Invocation(ZeroDTensorWithDtype(torch.float32), *args),
Invocation(ZeroDTensorWithDtype(torch.float64), *args),
Invocation(ZeroDTensorWithDtype(torch.bfloat16), *args),
ErrorInvocation(ZeroDTensorWithDtype(torch.int64), *args),
ErrorInvocation(ZeroDTensorWithDtype(torch.int32), *args),
ErrorInvocation(ZeroDTensorWithDtype(torch.bool), *args),
]
if error_types is not None and type_ in error_types:
invocations.append(ErrorInvocation(*tensors, *args, **kwargs))
else:
invocations.append(Invocation(*tensors, *args, **kwargs))
return invocations
def _check_two_tensor_op(
input_error_types: Optional[set[int]] = None,
output_error_types: Optional[set[int]] = None, **kwargs):
"""Generate invocations for basic two-tensor dtype functions.
This helper function is meant to be used to check dtype functions that
take two tensor operands and either return the promoted result or
return a constant dtype based on the tensor dtypes.
The testing performed is thorough enough to be able to detect if dtypes
are invalid as inputs or as outputs to the PyTorch op. Invalid dtypes
must be specified in `input_error_types` and `output_error_types` to
ensure the invocations are error invocations.
"""
if input_error_types is not None and output_error_types is not None:
assert len(input_error_types.intersection(output_error_types)) == 0, \
"An invalid input type implies an invalid output type, " \
"so there is no need to repeat the type in the `output_error_types` set"
all_error_types = set()
all_error_types |= set() if input_error_types is None else input_error_types
all_error_types |= set() if output_error_types is None else output_error_types
def check_two_tensors_with_one_varying_dtype_at_a_time(**kwargs):
"""Create invocations where one tensor varies its dtype.
This helper function creates invocations with two tensors where one
tensor varies its dtype while the other one stays constant. The varying
is done for both tensors and the varying is performed over every possible
dtype.
This function helps identify when a dtype is an invalid input dtype
for dtype functions that do promotion.
"""
# We will only create invocations for dtypes with priorities less than
# or equal to the highest priority valid type. By setting the non-varying
# tensor dtype to be the highest priority valid type, we ensure that
# every promotion results in a valid dtype. This allows the invocations
# to test in isolation assertions on input types.
constant_type = None
constant_type_index = None
for e, type_ in enumerate(_SORTED_TORCH_TYPES):
if type_ not in all_error_types:
constant_type = type_
constant_type_index = e
break
assert constant_type is not None, \
"Unable to find a constant type. Make sure the union of " \
"`input_error_types` and `output_error_types` is not all possible types."
invocations = []
for type_ in _SORTED_TORCH_TYPES[constant_type_index:]:
tensor_1 = NonZeroDTensorWithDtype(type_)
tensor_2 = NonZeroDTensorWithDtype(constant_type)
if input_error_types is not None and type_ in input_error_types:
invocations += [ErrorInvocation(tensor_1, tensor_2, **kwargs),
ErrorInvocation(tensor_2, tensor_1, **kwargs)]
else:
invocations += [Invocation(tensor_1, tensor_2, **kwargs),
Invocation(tensor_2, tensor_1, **kwargs)]
return invocations
same_dtype_invocations = _check_tensors_with_the_same_dtype(
num_of_tensors=2, error_types=all_error_types, **kwargs)
varying_dtype_invocations = \
check_two_tensors_with_one_varying_dtype_at_a_time(**kwargs)
return same_dtype_invocations + varying_dtype_invocations
def _get_dtype_of_floating_point_op(input_dtype: int) -> int:
if input_dtype in (torch.float64, torch.bfloat16, torch.float16):
if (is_float_dtype(input_dtype) and input_dtype != torch.float32) \
or is_complex_dtype(input_dtype):
return input_dtype
return torch.float32
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16}))
def atentanh〡dtype(self_rank: int, self_dtype: int) -> int:
assert self_dtype != torch.float16
return _get_dtype_of_floating_point_op(self_dtype)
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16}))
def atenexp〡dtype(self_rank: int, self_dtype: int) -> int:
assert self_dtype != torch.float16
return _get_dtype_of_floating_point_op(self_dtype)
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(
num_of_tensors=1, error_types={torch.float16, torch.complex64, torch.complex128}))
def atenexpm1〡dtype(self_rank: int, self_dtype: int) -> int:
assert not is_complex_dtype(self_dtype), "`self` cannot be complex"
assert self_dtype != torch.float16, "`self` cannot have float16 dtype"
return _get_dtype_of_floating_point_op(self_dtype)
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16}))
def atensin〡dtype(self_rank: int, self_dtype: int) -> int:
assert self_dtype != torch.float16
return _get_dtype_of_floating_point_op(self_dtype)
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16}))
def atencos〡dtype(self_rank: int, self_dtype: int) -> int:
assert self_dtype != torch.float16
return _get_dtype_of_floating_point_op(self_dtype)
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16}))
def atensigmoid〡dtype(self_rank: int, self_dtype: int) -> int:
assert self_dtype != torch.float16
return _get_dtype_of_floating_point_op(self_dtype)
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def atenreciprocal〡dtype(self_rank: int, self_dtype: int) -> int:
return _get_dtype_of_floating_point_op(self_dtype)
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16}))
def atensqrt〡dtype(self_rank: int, self_dtype: int) -> int:
assert self_dtype != torch.float16
return _get_dtype_of_floating_point_op(self_dtype)
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16}))
def atenlog〡dtype(self_rank: int, self_dtype: int) -> int:
assert self_dtype != torch.float16
return _get_dtype_of_floating_point_op(self_dtype)
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16}))
def atenlog2〡dtype(self_rank: int, self_dtype: int) -> int:
assert self_dtype != torch.float16
return _get_dtype_of_floating_point_op(self_dtype)
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16}))
def atenlog1p〡dtype(self_rank: int, self_dtype: int) -> int:
assert self_dtype != torch.float16
return _get_dtype_of_floating_point_op(self_dtype)
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16}))
def atenrsqrt〡dtype(self_rank: int, self_dtype: int) -> int:
assert self_dtype != torch.float16
return _get_dtype_of_floating_point_op(self_dtype)
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16, torch.complex64, torch.complex128}))
def atenerf〡dtype(self_rank: int, self_dtype: int) -> int:
assert not is_complex_dtype(self_dtype) and self_dtype != torch.float16
return _get_dtype_of_floating_point_op(self_dtype)
@check_dtype_function(_get_invocations_for_fp_only_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(
num_of_tensors=1, error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32,
torch.int64, torch.float16, torch.complex64, torch.complex128}))
def atensoftplus〡dtype(self_rank: int, self_dtype: int, beta: Union[int, float] = 1, threshold: Union[int, float] = 20) -> int:
assert self_dtype not in (torch.int64, torch.int32, torch.bool)
assert not is_complex_dtype(self_dtype) and not is_integer_dtype(self_dtype) and self_dtype != torch.float16
return _get_dtype_of_floating_point_op(self_dtype)
@check_dtype_function(_get_invocations_for_fp_only_op_with_tensor_arg_followed_by([0]))
@check_dtype_function(_check_tensors_with_the_same_dtype(
num_of_tensors=1, error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}, dim=[0]))
def atenfrobenius_normdim〡dtype(self_rank: int, self_dtype: int, dim: List[int], keepdim: bool = False) -> int:
assert self_dtype not in (torch.int64, torch.int32, torch.bool)
assert not is_integer_dtype(self_dtype)
if self_dtype == torch.complex128:
return torch.float64
elif self_dtype == torch.complex64:
return torch.float32
return _get_dtype_of_floating_point_op(self_dtype)
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.float16}))
def primssqrt〡dtype(self_rank: int, self_dtype: int) -> int:
assert self_dtype != torch.float16
return _get_dtype_of_floating_point_op(self_dtype)
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def atenall〡dtype(self_rank: int, self_dtype: int) -> int:
return torch.bool
return torch.uint8 if self_dtype == torch.uint8 else torch.bool
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def atenany〡dtype(self_rank: int, self_dtype: int) -> int:
return torch.bool
return torch.uint8 if self_dtype == torch.uint8 else torch.bool
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by(0.0))
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0))
def ateneqScalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int:
return torch.bool
@check_dtype_function(
_get_invocations_for_op_with_tensor_arg_followed_by(NonZeroDTensorWithDtype(torch.float)))
@check_dtype_function(_check_two_tensor_op())
def ateneqTensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
return torch.bool
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by(0.0))
@check_dtype_function(
_check_tensors_with_the_same_dtype(
num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0.0) +
_check_tensors_with_the_same_dtype(
num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0))
def atengeScalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int:
assert not is_complex_dtype(self_dtype), "`self` cannot be complex"
return torch.bool
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by(0.0))
@check_dtype_function(
_check_tensors_with_the_same_dtype(
num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0.0) +
_check_tensors_with_the_same_dtype(
num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0))
def atengtScalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int:
assert not is_complex_dtype(self_dtype), "`self` cannot be complex"
return torch.bool
@check_dtype_function(
_get_invocations_for_op_with_tensor_arg_followed_by(NonZeroDTensorWithDtype(torch.float)))
_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128}))
def atengtTensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
return torch.bool
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by(0.0))
def atenleScalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int:
assert not is_complex_dtype(self_dtype), "`self` cannot be complex"
assert not is_complex_dtype(other_dtype), "`self` cannot be complex"
return torch.bool
@check_dtype_function(
_get_invocations_for_op_with_tensor_arg_followed_by(NonZeroDTensorWithDtype(torch.float)))
_check_tensors_with_the_same_dtype(
num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0.0) +
_check_tensors_with_the_same_dtype(
num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0))
def atenleScalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int:
assert not is_complex_dtype(self_dtype), "`self` cannot be complex"
return torch.bool
@check_dtype_function(_check_two_tensor_op())
def atenlogical_and〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
return torch.bool
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def atenlogical_not〡dtype(self_rank: int, self_dtype: int) -> int:
return torch.bool
@check_dtype_function(
_get_invocations_for_op_with_tensor_arg_followed_by(NonZeroDTensorWithDtype(torch.float)))
@check_dtype_function(_check_two_tensor_op())
def atenlogical_or〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
return torch.bool
@check_dtype_function(
_get_invocations_for_op_with_tensor_arg_followed_by(NonZeroDTensorWithDtype(torch.float)))
@check_dtype_function(_check_two_tensor_op())
def atenlogical_xor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
return torch.bool
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by(0.0))
@check_dtype_function(
_check_tensors_with_the_same_dtype(
num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0.0) +
_check_tensors_with_the_same_dtype(
num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=0))
def atenltScalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int:
assert not is_complex_dtype(self_dtype), "`self` cannot be complex"
return torch.bool
@check_dtype_function(
_get_invocations_for_op_with_tensor_arg_followed_by(NonZeroDTensorWithDtype(torch.float)))
_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128}))
def atenltTensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
assert not is_complex_dtype(self_dtype), "`self` cannot be complex"
assert not is_complex_dtype(other_dtype), "`self` cannot be complex"
return torch.bool
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by(0.0))
@check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0))
def atenneScalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float]) -> int:
return torch.bool
@ -1205,55 +1320,231 @@ def atenadd〡dtype(a: Union[int, float], b: Union[int, float]) -> int:
dtypes = [get_dtype_of_scalar(a), get_dtype_of_scalar(b)]
return promote_dtypes(ranks, dtypes)
@check_dtype_function([
Invocation(NonZeroDTensorWithDtype(torch.complex64)),
Invocation(NonZeroDTensorWithDtype(torch.complex128)),
Invocation(NonZeroDTensorWithDtype(torch.float)),
Invocation(NonZeroDTensorWithDtype(torch.double)),
Invocation(NonZeroDTensorWithDtype(torch.bool)),
Invocation(NonZeroDTensorWithDtype(torch.uint8)),
Invocation(NonZeroDTensorWithDtype(torch.int8)),
Invocation(NonZeroDTensorWithDtype(torch.int16)),
Invocation(NonZeroDTensorWithDtype(torch.int32)),
Invocation(NonZeroDTensorWithDtype(torch.int64)),
ErrorInvocation(NonZeroDTensorWithDtype(torch.float16)),
ErrorInvocation(NonZeroDTensorWithDtype(torch.bfloat16)),
])
@check_dtype_function(
_check_tensors_with_the_same_dtype(
num_of_tensors=1, error_types={torch.float16, torch.bfloat16}))
def atenfft_fft〡dtype(self_rank: int, self_dtype: int, n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> int:
if self_dtype == torch.complex64 or self_dtype == torch.complex128:
if is_complex_dtype(self_dtype):
return self_dtype
elif self_dtype == torch.float:
return torch.complex64
elif self_dtype == torch.double:
return torch.complex128
elif self_dtype == torch.bool or self_dtype == torch.uint8 or \
self_dtype == torch.int8 or self_dtype == torch.int16 or \
self_dtype == torch.int32 or self_dtype == torch.int64:
elif is_integer_dtype(self_dtype):
return torch.complex64
else:
assert False, "Unsupported dtype"
@check_dtype_function([
Invocation(NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.float32)),
Invocation(ZeroDTensorWithDtype(torch.float64), NonZeroDTensorWithDtype(torch.float32)),
Invocation(ZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.float64)),
Invocation(NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32)),
])
def atenfloor_divide〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
@check_dtype_function(
_check_tensors_with_the_same_dtype(
num_of_tensors=1, error_types={torch.bool}, other=0.0) +
_check_tensors_with_the_same_dtype(
num_of_tensors=1, error_types={torch.bool}, other=0))
def atenrsubScalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float], alpha: Union[int, float] = 1) -> int:
assert self_dtype != torch.bool, "`self` cannot have bool dtype"
return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)])
@check_dtype_function(
_check_two_tensor_op(input_error_types={torch.float16, torch.bfloat16, torch.float32,
torch.float64, torch.complex64, torch.complex128}))
def aten__and__Tensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
assert is_integer_dtype(self_dtype), "Expected `self` to have integer dtype"
assert is_integer_dtype(other_dtype), "Expected `other` to have integer dtype"
ranks: List[Optional[int]] = [self_rank, other_rank]
dtypes = [self_dtype, other_dtype]
return promote_dtypes(ranks, dtypes)
@check_dtype_function([
Invocation(NonZeroDTensorWithDtype(torch.float32), other=0),
Invocation(NonZeroDTensorWithDtype(torch.int64), other=0.0),
Invocation(NonZeroDTensorWithDtype(torch.float16), other=0.0),
Invocation(ZeroDTensorWithDtype(torch.float32), other=0),
Invocation(ZeroDTensorWithDtype(torch.int64), other=0.0),
Invocation(ZeroDTensorWithDtype(torch.float16), other=0.0)
])
def atenrsubScalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float], alpha: Union[int, float] = 1) -> int:
return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)])
@check_dtype_function(_check_two_tensor_op())
def atenaddTensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int, alpha: Union[int, float] = 1) -> int:
ranks: List[Optional[int]] = [self_rank, other_rank]
dtypes = [self_dtype, other_dtype]
return promote_dtypes(ranks, dtypes)
@check_dtype_function(
_check_two_tensor_op(input_error_types={torch.float16, torch.bfloat16, torch.float32,
torch.float64, torch.complex64, torch.complex128}))
def atenbitwise_andTensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
assert is_integer_dtype(self_dtype), "Expected `self` to have integer dtype"
assert is_integer_dtype(other_dtype), "Expected `other` to have integer dtype"
ranks: List[Optional[int]] = [self_rank, other_rank]
dtypes = [self_dtype, other_dtype]
return promote_dtypes(ranks, dtypes)
@check_dtype_function(
_check_two_tensor_op(input_error_types={torch.float16, torch.bfloat16, torch.float32,
torch.float64, torch.complex64, torch.complex128}))
def atenbitwise_orTensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
assert is_integer_dtype(self_dtype), "Expected `self` to have integer dtype"
assert is_integer_dtype(other_dtype), "Expected `other` to have integer dtype"
ranks: List[Optional[int]] = [self_rank, other_rank]
dtypes = [self_dtype, other_dtype]
return promote_dtypes(ranks, dtypes)
@check_dtype_function(
_check_two_tensor_op(input_error_types={torch.float16, torch.bfloat16, torch.float32,
torch.float64, torch.complex64, torch.complex128}))
def atenbitwise_xorTensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
assert is_integer_dtype(self_dtype), "Expected `self` to have integer dtype"
assert is_integer_dtype(other_dtype), "Expected `other` to have integer dtype"
ranks: List[Optional[int]] = [self_rank, other_rank]
dtypes = [self_dtype, other_dtype]
return promote_dtypes(ranks, dtypes)
@check_dtype_function(
_check_tensors_with_the_same_dtype(
tensor_shapes=[(2, 3, 4), (2, 4, 3)], error_types={torch.float16, torch.bool}) +
# Different width
[ErrorInvocation(TensorOfShape(2, 3, 4, dtype=torch.float64),
TensorOfShape(2, 4, 3, dtype=torch.float32)),
# Different type
ErrorInvocation(TensorOfShape(2, 3, 4, dtype=torch.float32),
TensorOfShape(2, 4, 3, dtype=torch.int32))])
def atenbmm〡dtype(self_rank: int, self_dtype: int, mat2_rank: int, mat2_dtype: int) -> int:
assert self_dtype not in [torch.float16, torch.bool], \
"Expected dtype of `self` to not be float16 or bool"
assert mat2_dtype not in [torch.float16, torch.bool], \
"Expected dtype of `mat2` to not be float16 or bool"
assert self_dtype == mat2_dtype, "`self` and `mat2` must have the same dtype"
return self_dtype
@check_dtype_function(_check_two_tensor_op())
def atendivTensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
ranks: List[Optional[int]] = [self_rank, other_rank]
dtypes = [self_dtype, other_dtype]
promoted_dtype = promote_dtypes(ranks, dtypes)
if is_complex_dtype(promoted_dtype) or \
(is_float_dtype(promoted_dtype) and promoted_dtype != torch.float32):
return promoted_dtype
else:
return torch.float32
@check_dtype_function(_check_two_tensor_op(rounding_mode=None))
def atendivTensor_mode〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int, rounding_mode: Optional[str]) -> int:
ranks: List[Optional[int]] = [self_rank, other_rank]
dtypes = [self_dtype, other_dtype]
promoted_dtype = promote_dtypes(ranks, dtypes)
if is_complex_dtype(promoted_dtype) or \
(is_float_dtype(promoted_dtype) and promoted_dtype != torch.float32):
return promoted_dtype
else:
return torch.float32
@check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128}, output_error_types={torch.bool}))
def atenfloor_divide〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
assert not is_complex_dtype(self_dtype), "`self` cannot be complex"
assert not is_complex_dtype(other_dtype), "`other` cannot be complex"
ranks: List[Optional[int]] = [self_rank, other_rank]
dtypes = [self_dtype, other_dtype]
promoted_dtype = promote_dtypes(ranks, dtypes)
assert promoted_dtype != torch.bool, "Result dtype for aten.floor_divide bool"
return promoted_dtype
@check_dtype_function(
_check_tensors_with_the_same_dtype(
tensor_shapes=[(2, 3, 4), (2, 4, 3)], error_types={torch.float16, torch.bool}) +
# Different width
[ErrorInvocation(TensorOfShape(2, 3, 4, dtype=torch.float64),
TensorOfShape(2, 4, 3, dtype=torch.float32)),
# Different type
ErrorInvocation(TensorOfShape(2, 3, 4, dtype=torch.float32),
TensorOfShape(2, 4, 3, dtype=torch.int32))])
def atenmatmul〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
assert self_dtype not in [torch.float16, torch.bool], \
"Expected dtype of `self` to not be float16 or bool"
assert other_dtype not in [torch.float16, torch.bool], \
"Expected dtype of `other` to not be float16 or bool"
assert self_dtype == other_dtype, "`self` and `other` must have the same dtype"
return self_dtype
@check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128}))
def atenmaximum〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
assert not is_complex_dtype(self_dtype), "`self` cannot be complex"
assert not is_complex_dtype(other_dtype), "`other` cannot be complex"
ranks: List[Optional[int]] = [self_rank, other_rank]
dtypes = [self_dtype, other_dtype]
return promote_dtypes(ranks, dtypes)
@check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128}))
def atenminimum〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
assert not is_complex_dtype(self_dtype), "`self` cannot be complex"
assert not is_complex_dtype(other_dtype), "`other` cannot be complex"
ranks: List[Optional[int]] = [self_rank, other_rank]
dtypes = [self_dtype, other_dtype]
return promote_dtypes(ranks, dtypes)
@check_dtype_function(
_check_tensors_with_the_same_dtype(
tensor_shapes=[(3, 4), (4, 3)], error_types={torch.float16, torch.bool}) +
# Different width
[ErrorInvocation(TensorOfShape(3, 4, dtype=torch.float64),
TensorOfShape(4, 3, dtype=torch.float32)),
# Different type
ErrorInvocation(TensorOfShape(3, 4, dtype=torch.float32),
TensorOfShape(4, 3, dtype=torch.int32))])
def atenmm〡dtype(self_rank: int, self_dtype: int, mat2_rank: int, mat2_dtype: int) -> int:
assert self_dtype not in [torch.float16, torch.bool], \
"Expected dtype of `self` to not be float16 or bool"
assert mat2_dtype not in [torch.float16, torch.bool], \
"Expected dtype of `mat2` to not be float16 or bool"
assert self_dtype == mat2_dtype, "`self` and `mat2` must have the same dtype"
return self_dtype
@check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128},
output_error_types={torch.bool, torch.uint8, torch.int8, torch.int16,
torch.int32, torch.int64, torch.bfloat16}))
def atenmse_loss〡dtype(self_rank: int, self_dtype: int, target_rank: int, target_dtype: int, reduction: int = 1) -> int:
assert not is_complex_dtype(self_dtype), "`self` cannot be complex"
assert not is_complex_dtype(target_dtype), "`target` cannot be complex"
ranks: List[Optional[int]] = [self_rank, target_rank]
dtypes = [self_dtype, target_dtype]
promoted_dtype = promote_dtypes(ranks, dtypes)
assert is_float_dtype(promoted_dtype) and promoted_dtype != torch.bfloat16, \
"Expected promoted dtype to be float but not `bfloat16`"
return promoted_dtype
@check_dtype_function(_check_two_tensor_op())
def atenmulTensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int) -> int:
ranks: List[Optional[int]] = [self_rank, other_rank]
dtypes = [self_dtype, other_dtype]
return promote_dtypes(ranks, dtypes)
@check_dtype_function(
_check_tensors_with_the_same_dtype(
tensor_shapes=[(3, 4), (4,)], error_types={torch.float16, torch.bool}) +
# Different width
[ErrorInvocation(TensorOfShape(3, 4, dtype=torch.float64),
TensorOfShape(4, dtype=torch.float32)),
# Different type
ErrorInvocation(TensorOfShape(3, 4, dtype=torch.float32),
TensorOfShape(4, dtype=torch.int32))])
def atenmv〡dtype(self_rank: int, self_dtype: int, vec_rank: int, vec_dtype: int) -> int:
assert self_dtype not in [torch.float16, torch.bool], \
"Expected dtype of `self` to not be float16 or bool"
assert vec_dtype not in [torch.float16, torch.bool], \
"Expected dtype of `vec` to not be float16 or bool"
assert self_dtype == vec_dtype, "`self` and `vec` must have the same dtype"
ranks: List[Optional[int]] = [self_rank, vec_rank]
dtypes = [self_dtype, vec_dtype]
return promote_dtypes(ranks, dtypes)
@check_dtype_function(_check_two_tensor_op(input_error_types={torch.bool}))
def atensubTensor〡dtype(self_rank: int, self_dtype: int, other_rank: int, other_dtype: int, alpha: Union[int, float] = 1) -> int:
assert self_dtype != torch.bool, "`self` cannot be of bool dtype"
assert other_dtype != torch.bool, "`other` cannot be of bool dtype"
ranks: List[Optional[int]] = [self_rank, other_rank]
dtypes = [self_dtype, other_dtype]
return promote_dtypes(ranks, dtypes)
@check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128}, output_error_types={torch.bool, torch.float16}, threshold=0))
def atenthreshold_backward〡dtype(grad_output_rank: int, grad_output_dtype: int, self_rank: int, self_dtype: int, threshold: Union[int, float]) -> int:
assert not is_complex_dtype(grad_output_dtype), "`grad_output` cannot be complex"
assert not is_complex_dtype(self_dtype), "`self` cannot be complex"
ranks: List[Optional[int]] = [grad_output_rank, self_rank]
dtypes = [grad_output_dtype, self_dtype]
promoted_dtype = promote_dtypes(ranks, dtypes)
assert promoted_dtype not in [torch.bool, torch.float16], \
"Result dtype for aten.threshold_backward cannot be bool or float16"
return promoted_dtype
# ==============================================================================
# Main

View File

@ -14,6 +14,16 @@ from torch_mlir.passmanager import PassManager
from .registry import Registry
def is_integer_dtype(dtype: int) -> bool:
return dtype in [torch.bool, torch.uint8, torch.int8,
torch.int16, torch.int32, torch.int64]
def is_complex_dtype(dtype: int) -> bool:
return dtype in [torch.complex64, torch.complex128]
def is_float_dtype(dtype: int) -> bool:
return dtype in [torch.float16, torch.bfloat16, torch.float32, torch.float64]
def get_dtype_of_scalar(scalar: Union[int, float]) -> int:
# This is hacky. `NumToTensor` is the only PyTorch op for scalars
# that when `jit.script`ed converts a float scalar to a tensor

View File

@ -9,6 +9,7 @@
COMMON_TORCH_MLIR_LOWERING_XFAILS = {
"QuantizedMLP_basic",
"NormalizeModule_basic",
"MobilenetV3Module_basic",
}
def register_all_tests():

View File

@ -235,30 +235,6 @@ func.func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>,
return %ret : !torch.tensor
}
// -----
// CHECK-LABEL: func.func @torch.aten.Matmul.Broadcast.Matrix(
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<*,f32>,
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<*,f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func.func @torch.aten.Matmul.Broadcast.Matrix(%arg0: !torch.vtensor<*,f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<*,f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
return %0 : !torch.tensor
}
// -----
// CHECK-LABEL: func.func @torch.aten.Matmul.Broadcast.Vector(
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<*,f32>,
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<*,f32>) -> !torch.tensor {
// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<*,f32>, !torch.vtensor<*,f32> -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func.func @torch.aten.Matmul.Broadcast.Vector(%arg0: !torch.vtensor<*,f32>, %arg1: !torch.vtensor<*,f32>) -> !torch.tensor {
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<*,f32>, !torch.vtensor<*,f32> -> !torch.tensor
return %0 : !torch.tensor
}
// -----
// CHECK-LABEL: func.func @torch.aten.to.dtype(
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor