[torch] Add support for f8 types for linalg conversion (#3436)

Linalg conversion requires mapping for f8 types
pull/3437/head
Rob Suderman 2024-06-07 13:59:38 -07:00 committed by GitHub
parent 7f188eb824
commit 75af64fc12
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 43 additions and 18 deletions

View File

@ -103,7 +103,16 @@ enum class TypeKind {
_(c10::qint32, QInt32) /* 14 */ \
_(at::BFloat16, BFloat16) /* 15 */ \
_(c10::quint4x2, QUInt4x2) /* 16 */ \
_(c10::quint2x4, QUInt2x4) /* 17 */
_(c10::quint2x4, QUInt2x4) /* 17 */ \
_(c10::bits1x8, Bits1x8) /* 18 */ \
_(c10::bits2x4, Bits2x4) /* 19 */ \
_(c10::bits4x2, Bits4x2) /* 20 */ \
_(c10::bits8, Bits8) /* 21 */ \
_(c10::bits16, Bits16) /* 22 */ \
_(c10::Float8_e5m2, Float8_e5m2) /* 23 */ \
_(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */ \
_(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */ \
_(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */
enum class ScalarType : int8_t {
#define DEFINE_ENUM(_1, n) n,

View File

@ -80,6 +80,14 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
if (complexElemType.isF64())
return torch_upstream::ScalarType::ComplexDouble;
}
if (isa<Float8E5M2Type>(type))
return torch_upstream::ScalarType::Float8_e5m2;
if (isa<Float8E4M3FNType>(type))
return torch_upstream::ScalarType::Float8_e4m3fn;
if (isa<Float8E5M2FNUZType>(type))
return torch_upstream::ScalarType::Float8_e5m2fnuz;
if (isa<Float8E4M3FNUZType>(type))
return torch_upstream::ScalarType::Float8_e4m3fnuz;
llvm::report_fatal_error("unhandled type for getScalarTypeForType");
}
Type Torch::getTypeForTorchType(
@ -128,6 +136,14 @@ Torch::getTypeForScalarType(MLIRContext *context,
return mlir::ComplexType::get(Float32Type::get(context));
case torch_upstream::ScalarType::ComplexDouble:
return mlir::ComplexType::get(Float64Type::get(context));
case torch_upstream::ScalarType::Float8_e5m2:
return Float8E5M2Type::get(context);
case torch_upstream::ScalarType::Float8_e4m3fn:
return Float8E4M3FNType::get(context);
case torch_upstream::ScalarType::Float8_e5m2fnuz:
return Float8E5M2FNUZType::get(context);
case torch_upstream::ScalarType::Float8_e4m3fnuz:
return Float8E4M3FNUZType::get(context);
case torch_upstream::ScalarType::Undefined:
return failure();
default: