[Q] How to properly save and load fp8 NumPy arrays? #207
Unfortunately NumPy's array serialization only works with NumPy's built-in dtypes. Probably the easiest way to serialize arrays with custom dtypes is to view them as unsigned int:
import ml_dtypes
import numpy as np
import json
# Create the array
x = np.array([.2, .4, .6], dtype=ml_dtypes.float8_e5m2)
np.save('x.npy', x.view('uint8'))
x2 = np.load('x.npy').view(ml_dtypes.float8_e5m2)
print(np.all(x == x2))
# True
Your approach of serializing the raw bytes also works, though I'd recommend not naming the file with a
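The same view trick should extend to other ml_dtypes types, provided the unsigned integer view has the same itemsize (1 byte for the float8 types, 2 bytes for bfloat16). Below is a minimal sketch; the helper names `save_custom` / `load_custom` are hypothetical, and the loader has to be told the target dtype because the `.npy` file only records the unsigned integer type:

```python
import os
import tempfile

import ml_dtypes
import numpy as np

# Unsigned integer dtype with the same width (in bytes) as the custom dtype.
_UINT_BY_SIZE = {1: np.uint8, 2: np.uint16, 4: np.uint32, 8: np.uint64}


def save_custom(path, arr):
    """Save arr's raw bits as an unsigned-integer .npy file (hypothetical helper)."""
    np.save(path, arr.view(_UINT_BY_SIZE[arr.dtype.itemsize]))


def load_custom(path, dtype):
    """Load bits written by save_custom and reinterpret them as `dtype`."""
    return np.load(path).view(dtype)


x = np.array([0.2, 0.4, 0.6], dtype=ml_dtypes.float8_e5m2)
y = np.array([1.5, -2.0], dtype=ml_dtypes.bfloat16)

with tempfile.TemporaryDirectory() as tmp:
    save_custom(os.path.join(tmp, "x.npy"), x)
    save_custom(os.path.join(tmp, "y.npy"), y)
    x2 = load_custom(os.path.join(tmp, "x.npy"), ml_dtypes.float8_e5m2)
    y2 = load_custom(os.path.join(tmp, "y.npy"), ml_dtypes.bfloat16)

print(np.array_equal(x, x2), np.array_equal(y, y2))  # True True
```

Because `view` only reinterprets the bits, the round trip is exact; nothing is converted through float32 on the way.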
Hi Jake, thank you for your reply! I have one additional question. I tried to use pickle:
import ml_dtypes
import numpy as np
import pickle
# Create the array
a = np.array([.2, .4, .6], dtype=ml_dtypes.float8_e5m2)
b = np.array([.2, .4, .6], dtype=ml_dtypes.float8_e4m3)
# Save
with open('a.npy.pkl', "wb") as f:
    pickle.dump(a, f)
with open('b.npy.pkl', "wb") as f:
    pickle.dump(b, f)
# Load back
a2 = np.load('a.npy.pkl', allow_pickle=True)
b2 = np.load('b.npy.pkl', allow_pickle=True)
print(np.all(a == a2))
print(np.all(b == b2))
It seems to work out of the box and saves the ml_dtypes dtype info into the file. What are the disadvantages of using pickle? Cons I found:
Yes, pickle works, but it has downsides. The two you mention are the main issues: unpickling allows arbitrary code execution, and pickled files will often fail to load in an environment with different package versions.
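One way to keep the dtype information in the file without pickle is to store each array's raw bits next to its dtype name (as a plain string array) in an `.npz` archive, which `np.load` can then read with `allow_pickle=False`. This is a sketch under assumptions: the helper names and the `__dtype` key suffix are hypothetical, and it assumes each array's `dtype.name` (e.g. `'float8_e5m2'`) matches an attribute of the `ml_dtypes` module:

```python
import os
import tempfile

import ml_dtypes
import numpy as np


def save_npz_custom(path, **arrays):
    """Write each array's raw bits plus its dtype name; no pickled objects."""
    payload = {}
    for key, arr in arrays.items():
        payload[key] = arr.view(f"u{arr.dtype.itemsize}")  # e.g. "u1" for float8
        payload[key + "__dtype"] = np.array(arr.dtype.name)  # plain unicode array
    np.savez(path, **payload)


def load_npz_custom(path):
    """Rebuild arrays written by save_npz_custom without enabling pickle."""
    out = {}
    with np.load(path, allow_pickle=False) as data:
        for key in data.files:
            if key.endswith("__dtype"):
                continue
            name = data[key + "__dtype"].item()
            # Assumes the stored name is an attribute of ml_dtypes,
            # e.g. "float8_e5m2" -> ml_dtypes.float8_e5m2.
            out[key] = data[key].view(np.dtype(getattr(ml_dtypes, name)))
    return out


a = np.array([0.2, 0.4, 0.6], dtype=ml_dtypes.float8_e5m2)
b = np.array([0.2, 0.4, 0.6], dtype=ml_dtypes.float8_e4m3fn)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "arrays.npz")
    save_npz_custom(path, a=a, b=b)
    loaded = load_npz_custom(path)

print(np.array_equal(a, loaded["a"]), np.array_equal(b, loaded["b"]))
```

The archive holds only integer and string arrays, so loading it does not execute stored code; the cost is that the loader needs its own rule for mapping the stored name back to a dtype.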
I would like to save and load a float8_e5m2 (fp8) array. I initially tried using the standard numpy.save() and numpy.load() functions, but loading fails. I found that I can save and load float8 arrays using a lower-level API (ndarray.tobytes / np.frombuffer). Is this approach considered best practice for this case?
@jakevdp Jake, can you comment on it?
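The snippet from the question was not preserved, so the following is only a minimal sketch of the `ndarray.tobytes` / `np.frombuffer` round trip it describes. It highlights what that approach leaves to the caller: the raw bytes carry no shape or dtype, so both must be stored or known separately (here they are simply reused from `x`):

```python
import os
import tempfile

import ml_dtypes
import numpy as np

x = np.array([[0.2, 0.4], [0.6, 0.8]], dtype=ml_dtypes.float8_e5m2)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "x.bin")
    # tobytes() writes only the element bytes (C order); shape and dtype are lost.
    with open(path, "wb") as f:
        f.write(x.tobytes())
    with open(path, "rb") as f:
        raw = f.read()

# frombuffer returns a flat array over the bytes; the caller must supply
# the dtype and restore the shape.
x2 = np.frombuffer(raw, dtype=ml_dtypes.float8_e5m2).reshape(x.shape)
print(np.array_equal(x, x2), x2.flags.writeable)  # True False
```

Since `raw` is an immutable `bytes` object, the resulting array is read-only; calling `.copy()` on it gives a writable array.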