[fx importer] support fx importer with lower torch versions (#3486)

pull/3494/head
Yuanqiang Liu 2024-06-24 15:39:19 +08:00 committed by GitHub
parent fc19709daa
commit 61f37ae8a3
1 changed file with 30 additions and 12 deletions


@@ -151,11 +151,17 @@ TORCH_DTYPE_TO_MLIR_TYPE_ASM = {
     torch.complex32: "complex<f16>",
     torch.complex64: "complex<f32>",
     torch.complex128: "complex<f64>",
-    torch.float8_e5m2: "f8E5M2",
-    torch.float8_e4m3fn: "f8E4M3FN",
-    torch.float8_e5m2fnuz: "f8E5M2FNUZ",
-    torch.float8_e4m3fnuz: "f8E4M3FNUZ",
 }
+# Type entries that only exist in newer torch versions.
+OPTIONAL_TORCH_DTYPE_TO_MLIR_TYPE_ASM = {
+    "float8_e5m2": "f8E5M2",
+    "float8_e4m3fn": "f8E4M3FN",
+    "float8_e5m2fnuz": "f8E5M2FNUZ",
+    "float8_e4m3fnuz": "f8E4M3FNUZ",
+}
+for dtype_str, dtype_asm in OPTIONAL_TORCH_DTYPE_TO_MLIR_TYPE_ASM.items():
+    if hasattr(torch, dtype_str):
+        TORCH_DTYPE_TO_MLIR_TYPE_ASM[getattr(torch, dtype_str)] = dtype_asm
 
 TORCH_DTYPE_TO_MLIR_TYPE: Dict[torch.dtype, Callable[[], IrType]] = {
     torch.float16: lambda: F16Type.get(),
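The hasattr probe above is the whole compatibility trick: the float8 dtypes only exist as attributes of the torch module on newer releases (around torch 2.1; treat the exact version as an assumption), so referencing torch.float8_e5m2 directly would raise AttributeError at module-import time on older builds. A minimal standalone sketch of the pattern, with illustrative table contents rather than the importer's real ones:

import torch

# Base table: dtypes present in every supported torch release.
DTYPE_TO_ASM = {torch.float32: "f32", torch.bfloat16: "bf16"}

# Probe optional dtypes by attribute name so this module still imports
# cleanly on torch builds that predate float8 support.
_OPTIONAL_DTYPE_TO_ASM = {
    "float8_e5m2": "f8E5M2",
    "float8_e4m3fn": "f8E4M3FN",
}
for name, asm in _OPTIONAL_DTYPE_TO_ASM.items():
    if hasattr(torch, name):
        DTYPE_TO_ASM[getattr(torch, name)] = asm

# On a recent torch this prints the float8 entries; on an older torch
# the table simply omits them and later lookups fail with KeyError.
print(DTYPE_TO_ASM)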
@@ -173,11 +179,17 @@ TORCH_DTYPE_TO_MLIR_TYPE: Dict[torch.dtype, Callable[[], IrType]] = {
     torch.complex32: lambda: ComplexType.get(F16Type.get()),
     torch.complex64: lambda: ComplexType.get(F32Type.get()),
     torch.complex128: lambda: ComplexType.get(F64Type.get()),
-    torch.float8_e5m2: lambda: Float8E5M2Type.get(),
-    torch.float8_e5m2fnuz: lambda: Float8E5M2FNUZType.get(),
-    torch.float8_e4m3fn: lambda: Float8E4M3FNType.get(),
-    torch.float8_e4m3fnuz: lambda: Float8E4M3FNUZType.get(),
 }
+# Type entries that only exist in newer torch versions.
+OPTIONAL_TORCH_DTYPE_TO_MLIR_TYPE = {
+    "float8_e5m2": lambda: Float8E5M2Type.get(),
+    "float8_e4m3fn": lambda: Float8E4M3FNType.get(),
+    "float8_e5m2fnuz": lambda: Float8E5M2FNUZType.get(),
+    "float8_e4m3fnuz": lambda: Float8E4M3FNUZType.get(),
+}
+for dtype_str, mlir_type in OPTIONAL_TORCH_DTYPE_TO_MLIR_TYPE.items():
+    if hasattr(torch, dtype_str):
+        TORCH_DTYPE_TO_MLIR_TYPE[getattr(torch, dtype_str)] = mlir_type
 
 TORCH_DTYPE_TO_NPY_TYPE = {
     # torch.qint8: None, # no equivalent np datatype
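Note that this second table stores thunks (lambda: F16Type.get()) rather than IrType objects: MLIR type construction needs a live Context, which does not exist when the module is imported, so the table defers construction until the importer is actually running. A hedged sketch of that constraint; the import path is an assumption (in-tree the importer uses relative imports from its own ir module):

from torch_mlir import ir  # assumed packaging; adjust for your build

with ir.Context():
    # Inside an active Context the thunk can be evaluated safely.
    get_f16 = lambda: ir.F16Type.get()
    print(get_f16())  # -> f16

# Calling ir.F16Type.get() with no Context active raises instead, which
# is why the dtype table maps to lambdas rather than eager IrType values.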
@@ -215,11 +227,17 @@ TORCH_DTYPE_TO_INT = {
     # torch.quint8: 13,
     # torch.qint32 14
     torch.bfloat16: 15,
-    torch.float8_e5m2: 23,
-    torch.float8_e4m3fn: 24,
-    torch.float8_e5m2fnuz: 25,
-    torch.float8_e4m3fnuz: 26,
 }
+# Type entries that only exist in newer torch versions.
+OPTIONAL_TORCH_DTYPE_TO_INT = {
+    "float8_e5m2": 23,
+    "float8_e4m3fn": 24,
+    "float8_e5m2fnuz": 25,
+    "float8_e4m3fnuz": 26,
+}
+for dtype_str, dtype_int in OPTIONAL_TORCH_DTYPE_TO_INT.items():
+    if hasattr(torch, dtype_str):
+        TORCH_DTYPE_TO_INT[getattr(torch, dtype_str)] = dtype_int
 
 TORCH_MEMORY_FORMAT_TO_INT = {
     torch.contiguous_format: 0,
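The integer codes here appear to mirror torch's own c10::ScalarType enum values (Bool = 11, BFloat16 = 15, Float8_e5m2 = 23, and so on), which is why the float8 entries keep their fixed numbering even though they are now registered conditionally. A hedged usage sketch, assuming the TORCH_DTYPE_TO_INT table above is in scope; the helper name is hypothetical and not part of this patch:

import torch

def dtype_to_scalar_type_int(dtype: torch.dtype) -> int:
    """Hypothetical helper: map a torch dtype to its integer code."""
    try:
        return TORCH_DTYPE_TO_INT[dtype]
    except KeyError:
        # On older torch builds the optional float8 entries were never
        # registered, so float8 (or any unsupported) dtypes land here
        # with a clear error instead of a bare KeyError.
        raise ValueError(f"unsupported torch dtype for fx import: {dtype}") from None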