Add type promotion code to refine types.
The types fall into ordered categories:
complex > floating > integral > boolean (> means the left-hand
side is the higher category).
The operands have different priority levels:
dimensioned tensor > 0-dim tensor > scalar == wrapped 0-dim tensor.
These are represented by `ResultTypeState.dimResult`,
`ResultTypeState.zeroResult` and `ResultTypeState.wrappedResult` in
the source code.
For operands of the same priority, the result type should be the
highest category with sufficient width to hold all operands.
By default, only the highest-priority operands participate in the type
promotion logic. Lower-priority operands participate only if they are in
a higher category than the higher-priority operands.
For example, <[],f32> (lower priority) and <[1], si64> tensors would
result in a <[?],f32> tensor because floating > integral. Another example:
<[],f64> (lower priority) and <[1], f32> tensors would result in a
<[?], f32> tensor because f32 and f64 are in the same category.
The ScalarType enum definition, type promotion table, ResultTypeState
struct definition and some helpers are copied from
aten/src/ATen/native/TypeProperties.*
Other references:
- https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc
- https://github.com/pytorch/pytorch/issues/9515
Other minor changes:
1. Fix `visitExpandLikeOp` to consider cases where the given sizes list
size is larger than the input rank.
2. Add back the somehow deleted `torch.aten.softmax.int` tests in
decompose-complex-ops.mlir.
2021-10-21 03:31:28 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
//
|
|
|
|
// This source code is copied from PyTorch, and remains licensed under
|
|
|
|
// the PyTorch BSD-style license available at
|
|
|
|
// https://github.com/pytorch/pytorch/blob/master/LICENSE
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
|
|
|
|
|
2023-02-22 07:05:55 +08:00
|
|
|
#include "llvm/Support/ErrorHandling.h"
|
|
|
|
|
Add type promotion code to refine types.
The types fall into ordered categories:
complex > floating > integral > boolean (> means the left-hand
side is the higher category).
The operands have different priority levels:
dimensioned tensor > 0-dim tensor > scalar == wrapped 0-dim tensor.
These are represented by `ResultTypeState.dimResult`,
`ResultTypeState.zeroResult` and `ResultTypeState.wrappedResult` in
the source code.
For operands of the same priority, the result type should be the
highest category with sufficient width to hold all operands.
By default, only the highest-priority operands participate in the type
promotion logic. Lower-priority operands participate only if they are in
a higher category than the higher-priority operands.
For example, <[],f32> (lower priority) and <[1], si64> tensors would
result in a <[?],f32> tensor because floating > integral. Another example:
<[],f64> (lower priority) and <[1], f32> tensors would result in a
<[?], f32> tensor because f32 and f64 are in the same category.
The ScalarType enum definition, type promotion table, ResultTypeState
struct definition and some helpers are copied from
aten/src/ATen/native/TypeProperties.*
Other references:
- https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc
- https://github.com/pytorch/pytorch/issues/9515
Other minor changes:
1. Fix `visitExpandLikeOp` to consider cases where the given sizes list
size is larger than the input rank.
2. Add back the somehow deleted `torch.aten.softmax.int` tests in
decompose-complex-ops.mlir.
2021-10-21 03:31:28 +08:00
|
|
|
namespace mlir {
|
|
|
|
namespace torch {
|
|
|
|
namespace torch_upstream {
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ScalarType enum related code are copied from c10/core/ScalarType.h.
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Reports whether `t` is one of the quantized-integer scalar types
// (QInt8, QUInt8, QInt32, QUInt4x2, QUInt2x4).
static inline bool isQIntType(ScalarType t) {
  // Don't forget to extend this when adding new QInt types
  switch (t) {
  case ScalarType::QInt8:
  case ScalarType::QUInt8:
  case ScalarType::QInt32:
  case ScalarType::QUInt4x2:
  case ScalarType::QUInt2x4:
    return true;
  default:
    return false;
  }
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Type promotion related code are copied from
|
|
|
|
// aten/src/ATen/native/TypeProperties.*.
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Returns the scalar type that `a` and `b` promote to, using PyTorch's
// promotion lookup table (generated according to NumPy's promote_types).
//
// Returns ScalarType::Undefined when:
//  - either input is Undefined;
//  - a quantized type is combined with a different type (this also asserts
//    when assertions are enabled);
//  - either input lies outside the rows/columns the table explicitly lists
//    (Byte ... BFloat16);
//  - the table itself has no defined promotion (`ud` entries).
// Identical quantized types promote to themselves.
static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
  // This is generated according to NumPy's promote_types
  constexpr auto u1 = ScalarType::Byte;
  constexpr auto i1 = ScalarType::Char;
  constexpr auto i2 = ScalarType::Short;
  constexpr auto i4 = ScalarType::Int;
  constexpr auto i8 = ScalarType::Long;
  constexpr auto f2 = ScalarType::Half;
  constexpr auto f4 = ScalarType::Float;
  constexpr auto f8 = ScalarType::Double;
  constexpr auto c2 = ScalarType::ComplexHalf;
  constexpr auto c4 = ScalarType::ComplexFloat;
  constexpr auto c8 = ScalarType::ComplexDouble;
  constexpr auto b1 = ScalarType::Bool;
  constexpr auto bf = ScalarType::BFloat16;
  constexpr auto ud = ScalarType::Undefined;
  if (a == ud || b == ud) {
    return ScalarType::Undefined;
  }

  // For QInt types, we only allow exact match
  if (isQIntType(a) && a == b) {
    return a;
  }

  if (isQIntType(a) || isQIntType(b)) {
    assert(false && "promoteTypes with quantized numbers is not handled yet; "
                    "figure out what the correct rules should be");
    // With assertions disabled, do not fall through to the table. Some
    // quantized types (QUInt4x2, QUInt2x4) have no explicit row/column
    // there and would read a default-initialized entry.
    return ScalarType::Undefined;
  }

  // The table below lists rows/columns only for Byte ... BFloat16 (in that
  // enum order). Entries for any later enumerator would be
  // value-initialized, i.e. the enumerator with value 0 rather than a real
  // promotion. So report "no defined promotion" for them instead.
  constexpr int numTableTypes = static_cast<int>(ScalarType::BFloat16) + 1;
  const int aIndex = static_cast<int>(a);
  const int bIndex = static_cast<int>(b);
  if (aIndex >= numTableTypes || bIndex >= numTableTypes) {
    return ScalarType::Undefined;
  }

  // this matrix has to be consistent with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX
  // so that's why we have to add undefined as we are not sure what is the
  // correct values for the type promotions in complex type cases.
  static constexpr ScalarType _promoteTypesLookup[static_cast<int>(
      ScalarType::NumOptions)][static_cast<int>(ScalarType::NumOptions)] = {
      /*        u1  i1  i2  i4  i8  f2  f4  f8  c2  c4  c8  b1  q1  q2  q3  bf*/
      /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, ud, c4, c8, u1, ud, ud, ud, bf},
      /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, ud, c4, c8, i1, ud, ud, ud, bf},
      /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, ud, c4, c8, i2, ud, ud, ud, bf},
      /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, ud, c4, c8, i4, ud, ud, ud, bf},
      /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, ud, c4, c8, i8, ud, ud, ud, bf},
      /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, ud, c4, c8, f2, ud, ud, ud, f4},
      /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, ud, c4, c8, f4, ud, ud, ud, f4},
      /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, ud, c8, c8, f8, ud, ud, ud, f8},
      /* c2 */ {ud, ud, ud, ud, ud, ud, ud, ud, c2, c4, c8, ud, ud, ud, ud, ud},
      /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, ud, ud, ud, c4},
      /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, ud, ud, ud, c8},
      /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, ud, c4, c8, b1, ud, ud, ud, bf},
      /* q1 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
      /* q2 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
      /* q3 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
      /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, ud, c4, c8, bf, ud, ud, ud, bf},
  };
  return _promoteTypesLookup[aIndex][bIndex];
}
|
|
|
|
|
|
|
|
// Reports whether `t` is a real floating-point scalar type
// (Double, Float, Half or BFloat16). Complex types are not included.
static inline bool isFloatingType(ScalarType t) {
  switch (t) {
  case ScalarType::Double:
  case ScalarType::Float:
  case ScalarType::Half:
  case ScalarType::BFloat16:
    return true;
  default:
    return false;
  }
}
|
|
|
|
|
|
|
|
// Reports whether `t` is a complex scalar type
// (ComplexHalf, ComplexFloat or ComplexDouble).
static inline bool isComplexType(ScalarType t) {
  switch (t) {
  case ScalarType::ComplexHalf:
  case ScalarType::ComplexFloat:
  case ScalarType::ComplexDouble:
    return true;
  default:
    return false;
  }
}
|
|
|
|
|
|
|
|
// Combines the accumulated type of a higher-priority operand group
// (`higher`) with that of a lower-priority group (`lower`).
//
// `higher` wins outright when it is complex, or when it is floating and
// `lower` is not complex. The two types are promoted together (ignoring
// Undefined) when `higher` is Bool or `lower` is floating/complex. In that
// case `lower` is in a higher category. Otherwise `higher` is kept, falling
// back to `lower` only if `higher` is Undefined.
//
// NOTE(review): newer upstream PyTorch appears to keep the float width of
// `higher` when `lower` is complex (via toComplexType). This copy instead
// promotes through the table, e.g. Float + ComplexDouble -> ComplexDouble.
// Confirm which behavior is intended.
static inline ScalarType combine_categories(ScalarType higher,
                                            ScalarType lower) {
  const bool lowerIsComplex = isComplexType(lower);

  const bool higherDominates =
      isComplexType(higher) || (isFloatingType(higher) && !lowerIsComplex);
  if (higherDominates)
    return higher;

  const bool lowerInHigherCategory =
      higher == ScalarType::Bool || isFloatingType(lower) || lowerIsComplex;
  if (lowerInHigherCategory)
    return promote_skip_undefined(higher, lower);

  return higher == ScalarType::Undefined ? lower : higher;
}
|
|
|
|
|
|
|
|
// Promotes `a` and `b`, treating ScalarType::Undefined as "no operand".
// If exactly one side is Undefined, the other side is returned unchanged.
// If both are Undefined, Undefined is returned. Otherwise the result of
// promoteTypes(a, b) is returned.
ScalarType promote_skip_undefined(ScalarType a, ScalarType b) {
  const bool aIsUndefined = a == ScalarType::Undefined;
  const bool bIsUndefined = b == ScalarType::Undefined;
  if (aIsUndefined || bIsUndefined)
    return aIsUndefined ? b : a;
  return promoteTypes(a, b);
}
|
|
|
|
|
|
|
|
// Computes the final result type from the per-priority accumulators in
// `in_state`. The two lowest-priority groups (zeroResult and wrappedResult)
// are folded first. That combination is then folded under dimResult, which
// has the highest priority.
ScalarType result_type(const ResultTypeState &in_state) {
  const ScalarType lowPriorityResult =
      combine_categories(in_state.zeroResult, in_state.wrappedResult);
  return combine_categories(in_state.dimResult, lowPriorityResult);
}
|
|
|
|
|
2023-02-22 07:05:55 +08:00
|
|
|
// Maps a reduction name to the corresponding torch_upstream::ReductionType.
//
// Accepted names:
//   "max" / "amax" -> MAX
//   "mean"         -> MEAN
//   "min" / "amin" -> MIN
//   "sum"          -> SUM
//   "prod"         -> PROD
// Any other name is a programmer error and hits llvm_unreachable.
ReductionType get_reduction_enum(const llvm::StringRef &reduce) {
  if (reduce == "max" || reduce == "amax")
    return torch_upstream::ReductionType::MAX;
  if (reduce == "mean")
    return torch_upstream::ReductionType::MEAN;
  if (reduce == "min" || reduce == "amin")
    return torch_upstream::ReductionType::MIN;
  if (reduce == "sum")
    return torch_upstream::ReductionType::SUM;
  if (reduce == "prod")
    return torch_upstream::ReductionType::PROD;
  // The message lists every name accepted above (the previous text omitted
  // "max" and "min").
  llvm_unreachable("'reduce' argument must be either sum, prod, mean, max, "
                   "amax, min or amin");
}
|
|
|
|
|
Add type promotion code to refine types.
The types fall into ordered categories:
complex > floating > integral > boolean (> means the left-hand
side is the higher category).
The operands have different priority levels:
dimensioned tensor > 0-dim tensor > scalar == wrapped 0-dim tensor.
These are represented by `ResultTypeState.dimResult`,
`ResultTypeState.zeroResult` and `ResultTypeState.wrappedResult` in
the source code.
For operands of the same priority, the result type should be the
highest category with sufficient width to hold all operands.
By default, only the highest-priority operands participate in the type
promotion logic. Lower-priority operands participate only if they are in
a higher category than the higher-priority operands.
For example, <[],f32> (lower priority) and <[1], si64> tensors would
result in a <[?],f32> tensor because floating > integral. Another example:
<[],f64> (lower priority) and <[1], f32> tensors would result in a
<[?], f32> tensor because f32 and f64 are in the same category.
The ScalarType enum definition, type promotion table, ResultTypeState
struct definition and some helpers are copied from
aten/src/ATen/native/TypeProperties.*
Other references:
- https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc
- https://github.com/pytorch/pytorch/issues/9515
Other minor changes:
1. Fix `visitExpandLikeOp` to consider cases where the given sizes list
size is larger than the input rank.
2. Add back the somehow deleted `torch.aten.softmax.int` tests in
decompose-complex-ops.mlir.
2021-10-21 03:31:28 +08:00
|
|
|
} // namespace torch_upstream
|
|
|
|
} // namespace torch
|
|
|
|
} // namespace mlir
|