Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add types for contrib.stochastic_support and improve the HSGP module #1907

Merged
merged 8 commits into from
Nov 15, 2024

Conversation

juanitorduz
Copy link
Contributor

Related to #299

@juanitorduz juanitorduz marked this pull request as draft November 13, 2024 22:22
@juanitorduz juanitorduz marked this pull request as ready for review November 14, 2024 11:05
@@ -3,6 +3,9 @@

from abc import ABC, abstractmethod
from collections import OrderedDict, namedtuple
from typing import Any, Callable, OrderedDict as OrderedDictType

from jaxlib.xla_extension import ArrayImpl
Copy link
Contributor

@tillahoffmann tillahoffmann Nov 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know what the implications of using ArrayImpl vs jax.numpy.ndarray are?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a great question! Thanks for the feedback! So we started trying with type-hints with ArrayImpl in #1819 (comment)

On the other hand in https://jax.readthedocs.io/en/latest/jax.typing.html its suggested to use jax.typing.ArrayLike. However when I use this for x: ArrayLike and have inside the function x.ndim, mypy complains with

error: Item "builtins.bool" of "Array | ndarray[Any, Any] | numpy.bool | number[Any] | builtins.bool | int | float | complex" has no attribute "ndim"  [union-attr]

Hence, this is why I am continuing with ArrayImpl and waiting for feedback :)

Copy link
Member

@fehiepsi fehiepsi Nov 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to avoid importing stuff from jaxlib. I guess it is better to use jax.Array (or jnp.ndarray)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re ndim: you can use jnp.ndim(x), which will accept ArrayLike instances, including python scalars.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great input! I'll try that then 🙂!

@juanitorduz juanitorduz changed the title Add types for contrib.stochastic_support Add types for contrib.stochastic_support and improve the HSGP module Nov 15, 2024
@juanitorduz juanitorduz marked this pull request as draft November 15, 2024 08:41
@juanitorduz juanitorduz marked this pull request as ready for review November 15, 2024 12:26
@juanitorduz
Copy link
Contributor Author

juanitorduz commented Nov 15, 2024

ok! After changing the types (and also updating #1906). I think this is the way to go as suggested ( I even caught minor issues) :)

JAX Typing Best Practices

When annotating JAX arrays in public API functions, we recommend using ArrayLike for array inputs, and Array for array outputs.

As always, feedback is welcome!

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great! I like ArrayLike -> Array, looks consistent with jax api. If not much work, let's also replace jnp.ndarray by jax.Array (optional).

@fehiepsi fehiepsi merged commit 66921bd into pyro-ppl:master Nov 15, 2024
4 checks passed
@juanitorduz juanitorduz deleted the types-stochastic-support branch November 15, 2024 13:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants