```python
import numpy as np
import jax.numpy as jnp

a = jnp.array([jnp.nan], dtype=jnp.float32)
np.testing.assert_array_equal(a, a)  # No error

a = jnp.array([jnp.nan], dtype=jnp.bfloat16)
np.testing.assert_array_equal(a, a)  # AssertionError
```
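As a stopgap, the comparison can be done on float32 copies. Below is a minimal sketch, assuming `ml_dtypes` is installed (JAX's `jnp.bfloat16` is the ml_dtypes type, and `np.asarray` turns a JAX array into a NumPy one). `assert_array_equal_via_float32` is a hypothetical helper, not part of NumPy, JAX, or ml_dtypes. The upcast is lossless for bfloat16 input because every bfloat16 value is exactly representable in float32.

```python
import numpy as np
import ml_dtypes  # source of the bfloat16 dtype (jnp.bfloat16 is the same type)


def assert_array_equal_via_float32(actual, desired, **kwargs):
    """Hypothetical helper: compare bfloat16 arrays after a lossless float32 upcast.

    np.asarray also accepts JAX arrays, so this can be called on jnp inputs.
    """
    np.testing.assert_array_equal(
        np.asarray(actual).astype(np.float32),
        np.asarray(desired).astype(np.float32),
        **kwargs,
    )


a = np.array([np.nan], dtype=ml_dtypes.bfloat16)
assert_array_equal_via_float32(a, a)  # no error: NaNs in matching positions compare equal
```

This only makes sense for bfloat16 (or other narrow float) inputs; upcasting large integers to float32 could hide real differences.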
This is one of many places where NumPy hard-codes logic for its built-in set of dtypes. Downstream dtype implementations such as ml_dtypes have no way to change this. We might consider raising the issue upstream with NumPy.
The cause is that `np.testing.assert_array_equal()` does not recognise bfloat16 as a "number" type: https://github.com/numpy/numpy/blob/b3ddf2fd33232b8939f48c7c68a61c10257cd0c5/numpy/testing/_private/utils.py#L773
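To illustrate why no downstream package can opt in, here is a minimal sketch of a "number" check keyed on NumPy's built-in type codes, similar in spirit to the linked line. It is not NumPy's actual source: `BUILTIN_NUMERIC_CODES` and `is_builtin_number` are hypothetical names, and it assumes `ml_dtypes` is installed. A custom dtype such as bfloat16 carries its own type character, so it never matches a list of built-in codes, and NumPy presumably falls back to a plain `==` comparison in which NaN never equals NaN.

```python
import numpy as np
import ml_dtypes

# Hypothetical stand-in for a hard-coded "numeric dtype" test: membership in
# NumPy's built-in type codes for booleans, integers, floats, and complexes.
BUILTIN_NUMERIC_CODES = "?" + np.typecodes["AllInteger"] + np.typecodes["AllFloat"]


def is_builtin_number(arr):
    """Return True only if arr's dtype is one of NumPy's built-in numeric types."""
    return np.asarray(arr).dtype.char in BUILTIN_NUMERIC_CODES


print(is_builtin_number(np.array([np.nan], dtype=np.float32)))          # True
print(is_builtin_number(np.array([np.nan], dtype=ml_dtypes.bfloat16)))  # False
```

Because the list is fixed inside NumPy, the only fix is a change upstream (for example, checking dtype properties instead of type characters).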