-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add types for contrib.stochastic_support and improve the HSGP module #1907
Add types for contrib.stochastic_support and improve the HSGP module #1907
Conversation
@@ -3,6 +3,9 @@ | |||
|
|||
from abc import ABC, abstractmethod | |||
from collections import OrderedDict, namedtuple | |||
from typing import Any, Callable, OrderedDict as OrderedDictType | |||
|
|||
from jaxlib.xla_extension import ArrayImpl |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you know what the implications of using ArrayImpl
vs jax.numpy.ndarray
are?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a great question! Thanks for the feedback! So we started trying with type-hints with ArrayImpl
in #1819 (comment)
On the other hand in https://jax.readthedocs.io/en/latest/jax.typing.html its suggested to use jax.typing.ArrayLike
. However when I use this for x: ArrayLike
and have inside the function x.ndim
, mypy
complains with
error: Item "builtins.bool" of "Array | ndarray[Any, Any] | numpy.bool | number[Any] | builtins.bool | int | float | complex" has no attribute "ndim" [union-attr]
Hence, this is why I am continuing with ArrayImpl
and waiting for feedback :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would like to avoid importing stuff from jaxlib. I guess it is better to use jax.Array
(or jnp.ndarray)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Re ndim: you can use jnp.ndim(x), which will accept ArrayLike instances, including python scalars.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great input! I'll try that then 🙂!
ok! After changing the types (and also updating #1906). I think this is the way to go as suggested ( I even caught minor issues) :)
As always, feedback is welcome! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great! I like ArrayLike -> Array, looks consistent with jax api. If not much work, let's also replace jnp.ndarray by jax.Array (optional).
Related to #299