Add support for bfloat16 in fximporter (#2896)

This introduces an additional soft dependency on the Python ml_dtypes
package in order to support bfloat16.

Addresses #2843
pull/2839/merge
Daniel Garvey 2024-02-14 16:24:25 -06:00 committed by GitHub
parent e7a09440d3
commit 77b7550997
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 3 deletions

View File

@ -43,6 +43,13 @@ from torch.fx import (
Graph,
GraphModule,
)
try:
import ml_dtypes
except ModuleNotFoundError:
# The third-party ml_dtypes package provides some optional
# low-precision data types. It is a soft dependency: any use of it
# in this file must be guarded by an `ml_dtypes is not None` check.
ml_dtypes = None
from torch.fx.node import (
Argument as NodeArgument,
@ -137,7 +144,6 @@ TORCH_DTYPE_TO_NPY_TYPE = {
torch.int16: np.int16,
torch.int32: np.int32,
torch.int64: np.int64,
# torch.bf16: None, there's no equivalent np datatype so this isn't supported right now
torch.float16: np.float16,
torch.float32: np.float32,
torch.float64: np.float64,
@ -146,6 +152,8 @@ TORCH_DTYPE_TO_NPY_TYPE = {
torch.complex64: np.complex64,
torch.complex128: np.complex128,
}
if ml_dtypes is not None:
TORCH_DTYPE_TO_NPY_TYPE[torch.bfloat16] = ml_dtypes.bfloat16
TORCH_DTYPE_TO_INT = {
torch.uint8: 0,
@ -1090,6 +1098,10 @@ def _make_vtensor_literal_op(
) -> Operation:
mapping = py_attr_tracker.track(tensor)
if mapping.is_empty:
# check support for bfloat16
assert (
not (tensor.dtype == torch.bfloat16 and ml_dtypes is None)
), f"torch.bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
# Resolve the attribute.
npy_dtype = TORCH_DTYPE_TO_NPY_TYPE.get(tensor.dtype)
assert (
@ -1115,7 +1127,7 @@ def _make_vtensor_literal_op(
type=element_type, array=np_tensor, shape=np_tensor.shape
)
else:
bytes_view = memoryview(np_tensor)
bytes_view = np_tensor.view(npy_dtype)
tensor_type = create_mlir_tensor_type(tensor)
shape_desc = "_".join([str(d) for d in tensor.shape])
blob_name = f"torch_tensor_{shape_desc}_{str(tensor.dtype)}"

View File

@ -1,4 +1,4 @@
pillow
dill
multiprocess
onnx==1.15.0
onnx==1.15.0