mirror of https://github.com/llvm/torch-mlir
Add support for bfloat16 in fximporter (#2896)
this introduces an additional soft dependency on the python ml_dtypes python packages in order to support bfloat16 Addresses #2843pull/2839/merge
parent
e7a09440d3
commit
77b7550997
|
@ -43,6 +43,13 @@ from torch.fx import (
|
|||
Graph,
|
||||
GraphModule,
|
||||
)
|
||||
try:
|
||||
import ml_dtypes
|
||||
except ModuleNotFoundError:
|
||||
# The third-party ml_dtypes package provides some optional
|
||||
# low precision data-types. If used in this file, it is
|
||||
# conditional.
|
||||
ml_dtypes = None
|
||||
|
||||
from torch.fx.node import (
|
||||
Argument as NodeArgument,
|
||||
|
@ -137,7 +144,6 @@ TORCH_DTYPE_TO_NPY_TYPE = {
|
|||
torch.int16: np.int16,
|
||||
torch.int32: np.int32,
|
||||
torch.int64: np.int64,
|
||||
# torch.bf16: None, there's no equivalent np datatype so this isn't supported right now
|
||||
torch.float16: np.float16,
|
||||
torch.float32: np.float32,
|
||||
torch.float64: np.float64,
|
||||
|
@ -146,6 +152,8 @@ TORCH_DTYPE_TO_NPY_TYPE = {
|
|||
torch.complex64: np.complex64,
|
||||
torch.complex128: np.complex128,
|
||||
}
|
||||
if ml_dtypes is not None:
|
||||
TORCH_DTYPE_TO_NPY_TYPE[torch.bfloat16] = ml_dtypes.bfloat16
|
||||
|
||||
TORCH_DTYPE_TO_INT = {
|
||||
torch.uint8: 0,
|
||||
|
@ -1090,6 +1098,10 @@ def _make_vtensor_literal_op(
|
|||
) -> Operation:
|
||||
mapping = py_attr_tracker.track(tensor)
|
||||
if mapping.is_empty:
|
||||
# check support for bfloat16
|
||||
assert (
|
||||
not (tensor.dtype == torch.bfloat16 and ml_dtypes is None)
|
||||
), f"torch.bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n"
|
||||
# Resolve the attribute.
|
||||
npy_dtype = TORCH_DTYPE_TO_NPY_TYPE.get(tensor.dtype)
|
||||
assert (
|
||||
|
@ -1115,7 +1127,7 @@ def _make_vtensor_literal_op(
|
|||
type=element_type, array=np_tensor, shape=np_tensor.shape
|
||||
)
|
||||
else:
|
||||
bytes_view = memoryview(np_tensor)
|
||||
bytes_view = np_tensor.view(npy_dtype)
|
||||
tensor_type = create_mlir_tensor_type(tensor)
|
||||
shape_desc = "_".join([str(d) for d in tensor.shape])
|
||||
blob_name = f"torch_tensor_{shape_desc}_{str(tensor.dtype)}"
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
pillow
|
||||
dill
|
||||
multiprocess
|
||||
onnx==1.15.0
|
||||
onnx==1.15.0
|
||||
|
|
Loading…
Reference in New Issue