[torch] Add support for f8 types for linalg conversion (#3436)

Linalg conversion requires mapping for f8 types
pull/3437/head
Rob Suderman 2024-06-07 13:59:38 -07:00 committed by GitHub
parent 7f188eb824
commit 75af64fc12
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 43 additions and 18 deletions

View File

@ -86,24 +86,33 @@ enum class TypeKind {
// at:: and c10:: parts of the macro are never used within the compiler -- we // at:: and c10:: parts of the macro are never used within the compiler -- we
// only use this for the enum values. // only use this for the enum values.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \ #define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \
_(uint8_t, Byte) /* 0 */ \ _(uint8_t, Byte) /* 0 */ \
_(int8_t, Char) /* 1 */ \ _(int8_t, Char) /* 1 */ \
_(int16_t, Short) /* 2 */ \ _(int16_t, Short) /* 2 */ \
_(int, Int) /* 3 */ \ _(int, Int) /* 3 */ \
_(int64_t, Long) /* 4 */ \ _(int64_t, Long) /* 4 */ \
_(at::Half, Half) /* 5 */ \ _(at::Half, Half) /* 5 */ \
_(float, Float) /* 6 */ \ _(float, Float) /* 6 */ \
_(double, Double) /* 7 */ \ _(double, Double) /* 7 */ \
_(c10::complex<c10::Half>, ComplexHalf) /* 8 */ \ _(c10::complex<c10::Half>, ComplexHalf) /* 8 */ \
_(c10::complex<float>, ComplexFloat) /* 9 */ \ _(c10::complex<float>, ComplexFloat) /* 9 */ \
_(c10::complex<double>, ComplexDouble) /* 10 */ \ _(c10::complex<double>, ComplexDouble) /* 10 */ \
_(bool, Bool) /* 11 */ \ _(bool, Bool) /* 11 */ \
_(c10::qint8, QInt8) /* 12 */ \ _(c10::qint8, QInt8) /* 12 */ \
_(c10::quint8, QUInt8) /* 13 */ \ _(c10::quint8, QUInt8) /* 13 */ \
_(c10::qint32, QInt32) /* 14 */ \ _(c10::qint32, QInt32) /* 14 */ \
_(at::BFloat16, BFloat16) /* 15 */ \ _(at::BFloat16, BFloat16) /* 15 */ \
_(c10::quint4x2, QUInt4x2) /* 16 */ \ _(c10::quint4x2, QUInt4x2) /* 16 */ \
_(c10::quint2x4, QUInt2x4) /* 17 */ _(c10::quint2x4, QUInt2x4) /* 17 */ \
_(c10::bits1x8, Bits1x8) /* 18 */ \
_(c10::bits2x4, Bits2x4) /* 19 */ \
_(c10::bits4x2, Bits4x2) /* 20 */ \
_(c10::bits8, Bits8) /* 21 */ \
_(c10::bits16, Bits16) /* 22 */ \
_(c10::Float8_e5m2, Float8_e5m2) /* 23 */ \
_(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */ \
_(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */ \
_(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */
enum class ScalarType : int8_t { enum class ScalarType : int8_t {
#define DEFINE_ENUM(_1, n) n, #define DEFINE_ENUM(_1, n) n,

View File

@ -80,6 +80,14 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
if (complexElemType.isF64()) if (complexElemType.isF64())
return torch_upstream::ScalarType::ComplexDouble; return torch_upstream::ScalarType::ComplexDouble;
} }
if (isa<Float8E5M2Type>(type))
return torch_upstream::ScalarType::Float8_e5m2;
if (isa<Float8E4M3FNType>(type))
return torch_upstream::ScalarType::Float8_e4m3fn;
if (isa<Float8E5M2FNUZType>(type))
return torch_upstream::ScalarType::Float8_e5m2fnuz;
if (isa<Float8E4M3FNUZType>(type))
return torch_upstream::ScalarType::Float8_e4m3fnuz;
llvm::report_fatal_error("unhandled type for getScalarTypeForType"); llvm::report_fatal_error("unhandled type for getScalarTypeForType");
} }
Type Torch::getTypeForTorchType( Type Torch::getTypeForTorchType(
@ -128,6 +136,14 @@ Torch::getTypeForScalarType(MLIRContext *context,
return mlir::ComplexType::get(Float32Type::get(context)); return mlir::ComplexType::get(Float32Type::get(context));
case torch_upstream::ScalarType::ComplexDouble: case torch_upstream::ScalarType::ComplexDouble:
return mlir::ComplexType::get(Float64Type::get(context)); return mlir::ComplexType::get(Float64Type::get(context));
case torch_upstream::ScalarType::Float8_e5m2:
return Float8E5M2Type::get(context);
case torch_upstream::ScalarType::Float8_e4m3fn:
return Float8E4M3FNType::get(context);
case torch_upstream::ScalarType::Float8_e5m2fnuz:
return Float8E5M2FNUZType::get(context);
case torch_upstream::ScalarType::Float8_e4m3fnuz:
return Float8E4M3FNUZType::get(context);
case torch_upstream::ScalarType::Undefined: case torch_upstream::ScalarType::Undefined:
return failure(); return failure();
default: default: