Deserialization of float8_e5m2 #148

Open
wonjeon opened this issue Apr 25, 2024 · 2 comments

wonjeon commented Apr 25, 2024

I tried the following code snippet, and it doesn't seem to work. Is this an already known issue?

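>>> import numpy as np
>>> from ml_dtypes import float8_e5m2  # float8_e5m2 is provided by ml_dtypes
>>> a = float8_e5m2(1.5)
>>> np.save("a.npy", a)
>>> b = np.load("a.npy")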
Traceback (most recent call last):
  File "/home/wonjeo01/.local/lib/python3.10/site-packages/numpy/lib/format.py", line 640, in _read_array_header
    dtype = descr_to_dtype(d['descr'])
  File "/home/wonjeo01/.local/lib/python3.10/site-packages/numpy/lib/format.py", line 309, in descr_to_dtype
    return numpy.dtype(descr)
TypeError: data type '<f1' not understood

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/wonjeo01/.local/lib/python3.10/site-packages/numpy/lib/npyio.py", line 432, in load
    return format.read_array(fid, allow_pickle=allow_pickle,
  File "/home/wonjeo01/.local/lib/python3.10/site-packages/numpy/lib/format.py", line 765, in read_array
    shape, fortran_order, dtype = _read_array_header(
  File "/home/wonjeo01/.local/lib/python3.10/site-packages/numpy/lib/format.py", line 643, in _read_array_header
    raise ValueError(msg.format(d['descr'])) from e
ValueError: descr is not a valid dtype descriptor: '<f1'

jakevdp (Collaborator) commented Apr 25, 2024

Thanks - yeah this is a known issue (similar to what's reported in jax-ml/jax#8494).

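Unfortunately, numpy's serialization only recognizes numpy's built-in dtypes, and numpy currently offers no way to extend that. That is why the '<f1' descriptor written to the file header cannot be mapped back to a dtype on load. The best workaround for the time being is to save the raw bytes as uint8 and view them as float8_e5m2 after loading, something like this: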

>>> np.save('a.npy', a.view('uint8'))
>>> np.load('a.npy').view(float8_e5m2)
array(1.5, dtype='float8_e5m2')
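
If this comes up often, the workaround above could be wrapped in a pair of small helpers. The following is only a sketch: `save_raw_bytes` and `load_raw_bytes` are hypothetical names, not part of numpy or ml_dtypes, and the sketch assumes a one-byte dtype such as float8_e5m2 so that viewing the data as uint8 leaves the shape unchanged. The original dtype is not recorded in the file, so the caller has to pass it again when loading.

```python
import numpy as np


def save_raw_bytes(file, arr):
    """Save a 1-byte-per-element array as uint8 so np.load can parse the header.

    Hypothetical helper: `file` is a path or a binary file object.
    """
    arr = np.asarray(arr)
    if arr.dtype.itemsize != 1:
        # Assumption of this sketch: only 1-byte dtypes (e.g. float8 variants).
        raise ValueError("this sketch only handles 1-byte dtypes")
    # Reinterpret the same bytes as uint8, which numpy can serialize.
    np.save(file, arr.view(np.uint8))


def load_raw_bytes(file, dtype):
    """Load the uint8 payload and reinterpret its bytes as `dtype`."""
    raw = np.load(file)
    if raw.dtype != np.uint8:
        raise ValueError("expected a uint8 payload")
    # The dtype is not stored in the file; the caller must supply it.
    return raw.view(dtype)
```

With ml_dtypes installed, `save_raw_bytes('a.npy', float8_e5m2(1.5))` followed by `load_raw_bytes('a.npy', float8_e5m2)` should round-trip the value from the original report, the same way the two-line version does.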

wonjeon (Author) commented Apr 29, 2024

@jakevdp Thanks for your response and the information on the workaround. Confirmed that it works.
