mirror of https://github.com/llvm/torch-mlir
[torch] Add support for f8 types for linalg conversion (#3436)
Linalg conversion requires mapping for f8 typespull/3437/head
parent
7f188eb824
commit
75af64fc12
|
@ -86,24 +86,33 @@ enum class TypeKind {
|
||||||
// at:: and c10:: parts of the macro are never used within the compiler -- we
|
// at:: and c10:: parts of the macro are never used within the compiler -- we
|
||||||
// only use this for the enum values.
|
// only use this for the enum values.
|
||||||
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \
|
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \
|
||||||
_(uint8_t, Byte) /* 0 */ \
|
_(uint8_t, Byte) /* 0 */ \
|
||||||
_(int8_t, Char) /* 1 */ \
|
_(int8_t, Char) /* 1 */ \
|
||||||
_(int16_t, Short) /* 2 */ \
|
_(int16_t, Short) /* 2 */ \
|
||||||
_(int, Int) /* 3 */ \
|
_(int, Int) /* 3 */ \
|
||||||
_(int64_t, Long) /* 4 */ \
|
_(int64_t, Long) /* 4 */ \
|
||||||
_(at::Half, Half) /* 5 */ \
|
_(at::Half, Half) /* 5 */ \
|
||||||
_(float, Float) /* 6 */ \
|
_(float, Float) /* 6 */ \
|
||||||
_(double, Double) /* 7 */ \
|
_(double, Double) /* 7 */ \
|
||||||
_(c10::complex<c10::Half>, ComplexHalf) /* 8 */ \
|
_(c10::complex<c10::Half>, ComplexHalf) /* 8 */ \
|
||||||
_(c10::complex<float>, ComplexFloat) /* 9 */ \
|
_(c10::complex<float>, ComplexFloat) /* 9 */ \
|
||||||
_(c10::complex<double>, ComplexDouble) /* 10 */ \
|
_(c10::complex<double>, ComplexDouble) /* 10 */ \
|
||||||
_(bool, Bool) /* 11 */ \
|
_(bool, Bool) /* 11 */ \
|
||||||
_(c10::qint8, QInt8) /* 12 */ \
|
_(c10::qint8, QInt8) /* 12 */ \
|
||||||
_(c10::quint8, QUInt8) /* 13 */ \
|
_(c10::quint8, QUInt8) /* 13 */ \
|
||||||
_(c10::qint32, QInt32) /* 14 */ \
|
_(c10::qint32, QInt32) /* 14 */ \
|
||||||
_(at::BFloat16, BFloat16) /* 15 */ \
|
_(at::BFloat16, BFloat16) /* 15 */ \
|
||||||
_(c10::quint4x2, QUInt4x2) /* 16 */ \
|
_(c10::quint4x2, QUInt4x2) /* 16 */ \
|
||||||
_(c10::quint2x4, QUInt2x4) /* 17 */
|
_(c10::quint2x4, QUInt2x4) /* 17 */ \
|
||||||
|
_(c10::bits1x8, Bits1x8) /* 18 */ \
|
||||||
|
_(c10::bits2x4, Bits2x4) /* 19 */ \
|
||||||
|
_(c10::bits4x2, Bits4x2) /* 20 */ \
|
||||||
|
_(c10::bits8, Bits8) /* 21 */ \
|
||||||
|
_(c10::bits16, Bits16) /* 22 */ \
|
||||||
|
_(c10::Float8_e5m2, Float8_e5m2) /* 23 */ \
|
||||||
|
_(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */ \
|
||||||
|
_(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */ \
|
||||||
|
_(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */
|
||||||
|
|
||||||
enum class ScalarType : int8_t {
|
enum class ScalarType : int8_t {
|
||||||
#define DEFINE_ENUM(_1, n) n,
|
#define DEFINE_ENUM(_1, n) n,
|
||||||
|
|
|
@ -80,6 +80,14 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
|
||||||
if (complexElemType.isF64())
|
if (complexElemType.isF64())
|
||||||
return torch_upstream::ScalarType::ComplexDouble;
|
return torch_upstream::ScalarType::ComplexDouble;
|
||||||
}
|
}
|
||||||
|
if (isa<Float8E5M2Type>(type))
|
||||||
|
return torch_upstream::ScalarType::Float8_e5m2;
|
||||||
|
if (isa<Float8E4M3FNType>(type))
|
||||||
|
return torch_upstream::ScalarType::Float8_e4m3fn;
|
||||||
|
if (isa<Float8E5M2FNUZType>(type))
|
||||||
|
return torch_upstream::ScalarType::Float8_e5m2fnuz;
|
||||||
|
if (isa<Float8E4M3FNUZType>(type))
|
||||||
|
return torch_upstream::ScalarType::Float8_e4m3fnuz;
|
||||||
llvm::report_fatal_error("unhandled type for getScalarTypeForType");
|
llvm::report_fatal_error("unhandled type for getScalarTypeForType");
|
||||||
}
|
}
|
||||||
Type Torch::getTypeForTorchType(
|
Type Torch::getTypeForTorchType(
|
||||||
|
@ -128,6 +136,14 @@ Torch::getTypeForScalarType(MLIRContext *context,
|
||||||
return mlir::ComplexType::get(Float32Type::get(context));
|
return mlir::ComplexType::get(Float32Type::get(context));
|
||||||
case torch_upstream::ScalarType::ComplexDouble:
|
case torch_upstream::ScalarType::ComplexDouble:
|
||||||
return mlir::ComplexType::get(Float64Type::get(context));
|
return mlir::ComplexType::get(Float64Type::get(context));
|
||||||
|
case torch_upstream::ScalarType::Float8_e5m2:
|
||||||
|
return Float8E5M2Type::get(context);
|
||||||
|
case torch_upstream::ScalarType::Float8_e4m3fn:
|
||||||
|
return Float8E4M3FNType::get(context);
|
||||||
|
case torch_upstream::ScalarType::Float8_e5m2fnuz:
|
||||||
|
return Float8E5M2FNUZType::get(context);
|
||||||
|
case torch_upstream::ScalarType::Float8_e4m3fnuz:
|
||||||
|
return Float8E4M3FNUZType::get(context);
|
||||||
case torch_upstream::ScalarType::Undefined:
|
case torch_upstream::ScalarType::Undefined:
|
||||||
return failure();
|
return failure();
|
||||||
default:
|
default:
|
||||||
|
|
Loading…
Reference in New Issue