mirror of https://github.com/llvm/torch-mlir
Add support for bfloat16 in fximporter (#2896)
this introduces an additional soft dependency on the python ml_dtypes python packages in order to support bfloat16 Addresses #2843pull/2839/merge
parent
e7a09440d3
commit
77b7550997
|
@ -43,6 +43,13 @@ from torch.fx import (
|
||||||
Graph,
|
Graph,
|
||||||
GraphModule,
|
GraphModule,
|
||||||
)
|
)
|
||||||
|
try:
|
||||||
|
import ml_dtypes
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
# The third-party ml_dtypes package provides some optional
|
||||||
|
# low precision data-types. If used in this file, it is
|
||||||
|
# conditional.
|
||||||
|
ml_dtypes = None
|
||||||
|
|
||||||
from torch.fx.node import (
|
from torch.fx.node import (
|
||||||
Argument as NodeArgument,
|
Argument as NodeArgument,
|
||||||
|
@ -137,7 +144,6 @@ TORCH_DTYPE_TO_NPY_TYPE = {
|
||||||
torch.int16: np.int16,
|
torch.int16: np.int16,
|
||||||
torch.int32: np.int32,
|
torch.int32: np.int32,
|
||||||
torch.int64: np.int64,
|
torch.int64: np.int64,
|
||||||
# torch.bf16: None, there's no equivalent np datatype so this isn't supported right now
|
|
||||||
torch.float16: np.float16,
|
torch.float16: np.float16,
|
||||||
torch.float32: np.float32,
|
torch.float32: np.float32,
|
||||||
torch.float64: np.float64,
|
torch.float64: np.float64,
|
||||||
|
@ -146,6 +152,8 @@ TORCH_DTYPE_TO_NPY_TYPE = {
|
||||||
torch.complex64: np.complex64,
|
torch.complex64: np.complex64,
|
||||||
torch.complex128: np.complex128,
|
torch.complex128: np.complex128,
|
||||||
}
|
}
|
||||||
|
if ml_dtypes is not None:
|
||||||
|
TORCH_DTYPE_TO_NPY_TYPE[torch.bfloat16] = ml_dtypes.bfloat16
|
||||||
|
|
||||||
TORCH_DTYPE_TO_INT = {
|
TORCH_DTYPE_TO_INT = {
|
||||||
torch.uint8: 0,
|
torch.uint8: 0,
|
||||||
|
@ -1090,6 +1098,10 @@ def _make_vtensor_literal_op(
|
||||||
) -> Operation:
|
) -> Operation:
|
||||||
mapping = py_attr_tracker.track(tensor)
|
mapping = py_attr_tracker.track(tensor)
|
||||||
if mapping.is_empty:
|
if mapping.is_empty:
|
||||||
|
# check support for bfloat16
|
||||||
|
assert (
|
||||||
|
not (tensor.dtype == torch.bfloat16 and ml_dtypes is None)
|
||||||
|
), f"torch.bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
|
||||||
# Resolve the attribute.
|
# Resolve the attribute.
|
||||||
npy_dtype = TORCH_DTYPE_TO_NPY_TYPE.get(tensor.dtype)
|
npy_dtype = TORCH_DTYPE_TO_NPY_TYPE.get(tensor.dtype)
|
||||||
assert (
|
assert (
|
||||||
|
@ -1115,7 +1127,7 @@ def _make_vtensor_literal_op(
|
||||||
type=element_type, array=np_tensor, shape=np_tensor.shape
|
type=element_type, array=np_tensor, shape=np_tensor.shape
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
bytes_view = memoryview(np_tensor)
|
bytes_view = np_tensor.view(npy_dtype)
|
||||||
tensor_type = create_mlir_tensor_type(tensor)
|
tensor_type = create_mlir_tensor_type(tensor)
|
||||||
shape_desc = "_".join([str(d) for d in tensor.shape])
|
shape_desc = "_".join([str(d) for d in tensor.shape])
|
||||||
blob_name = f"torch_tensor_{shape_desc}_{str(tensor.dtype)}"
|
blob_name = f"torch_tensor_{shape_desc}_{str(tensor.dtype)}"
|
||||||
|
|
Loading…
Reference in New Issue