Potential performance regression due to JAX version #746

Open
junpenglao opened this issue Oct 6, 2024 · 2 comments

Comments

@junpenglao
Member

Describe the issue as clearly as possible:

Our benchmark runtime increased by more than 2x after upgrading JAX to 0.4.34.

Reproduced locally:
On JAX 0.4.30

```
-------------------------------------------------------------------------------- benchmark: 2 tests -------------------------------------------------------------------------------
Name (time in s)            Min               Max              Mean            StdDev            Median               IQR            Outliers     OPS            Rounds  Iterations
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_regression_nuts     4.2739 (1.0)      5.0262 (1.0)      4.7685 (1.0)      0.3398 (1.0)      4.9671 (1.0)      0.5408 (1.0)           1;0  0.2097 (1.0)           5           1
test_regression_hmc      7.2055 (1.69)     8.1514 (1.62)     7.6479 (1.60)     0.4128 (1.22)     7.5257 (1.52)     0.7291 (1.35)          2;0  0.1308 (0.62)          5           1
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
```

On JAX 0.4.34

```
---------------------------------------------------------------------------------- benchmark: 2 tests ---------------------------------------------------------------------------------
Name (time in s)             Min                Max               Mean            StdDev             Median               IQR            Outliers     OPS            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_regression_nuts      9.2754 (1.0)      10.2643 (1.0)       9.6681 (1.0)      0.3660 (1.0)       9.6078 (1.0)      0.3647 (1.0)           2;0  0.1034 (1.0)           5           1
test_regression_hmc      19.7752 (2.13)     21.4303 (2.09)     20.6382 (2.13)     0.7185 (1.96)     20.4793 (2.13)     1.2633 (3.46)          2;0  0.0485 (0.47)          5           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
```
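The size of the regression can be read off the Mean column of the two tables. The small sketch below uses a hypothetical `slowdown` helper (not part of Blackjax or pytest-benchmark) that divides each test's JAX 0.4.34 mean by its JAX 0.4.30 mean with `awk`:

```shell
# Hypothetical helper: ratio of the new mean runtime to the old one.
# $1 = mean on JAX 0.4.30 (s), $2 = mean on JAX 0.4.34 (s); prints new/old to 2 decimals.
slowdown() {
  awk -v old="$1" -v new="$2" 'BEGIN { printf "%.2f\n", new / old }'
}

# Means copied from the two benchmark tables above.
echo "test_regression_nuts: $(slowdown 4.7685 9.6681)x"
echo "test_regression_hmc:  $(slowdown 7.6479 20.6382)x"
```

With the reported means this gives about 2.03x for `test_regression_nuts` and 2.70x for `test_regression_hmc`, so the HMC benchmark regressed noticeably more than NUTS.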

Steps/code to reproduce the bug:

Pin a JAX version and run:

```shell
pytest --benchmark-only
```
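To compare the two versions side by side, one option is to install each pinned JAX release into its own virtual environment and save each benchmark run under a distinct name. The sketch below is a hypothetical helper (`bench_jax_version` and the `.venv-jax-*` directory names are illustrative), assuming it is run from the root of a Blackjax checkout, that `jax` and `jaxlib` wheels with the same version number exist for both releases, and that the benchmarks need nothing beyond Blackjax, JAX and `pytest-benchmark`:

```shell
#!/usr/bin/env bash
# Hypothetical helper: benchmark the current checkout against one pinned JAX version.
# Assumes: run from the repository root; jax and jaxlib share the same version number.
bench_jax_version() {
  version="$1"
  venv=".venv-jax-${version}"

  # One isolated environment per JAX version so the two installs never mix.
  python3 -m venv "$venv"
  "$venv/bin/pip" install --quiet -e . pytest-benchmark "jax==${version}" "jaxlib==${version}"

  # --benchmark-save stores this run under .benchmarks/ with the given name.
  "$venv/bin/python" -m pytest --benchmark-only --benchmark-save="jax-${version}"
}

# Usage (each call creates a venv and downloads packages):
#   bench_jax_version 0.4.30
#   bench_jax_version 0.4.34
```

The saved runs can then be inspected together with pytest-benchmark's `pytest-benchmark compare` command.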


Expected result:

n.a

Error message:

n.a

Blackjax/JAX/jaxlib/Python version information:

n.a

Context for the issue:

No response

@junpenglao
Member Author

Related: pyro-ppl/numpyro#1867
For the likely root cause and a workaround, see jax-ml/jax#23822.
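If the regression turns out to come from the thunk-based XLA CPU runtime that became the default in JAX 0.4.32 (an assumption here; jax-ml/jax#23822 is the place to confirm the actual root cause and recommended workaround), one way to test that hypothesis is to switch back to the previous CPU runtime through `XLA_FLAGS` and rerun the benchmarks from the same shell:

```shell
# Assumption: the slowdown is caused by XLA's thunk-based CPU runtime.
# Append the flag to any existing XLA_FLAGS instead of overwriting them.
export XLA_FLAGS="${XLA_FLAGS:+$XLA_FLAGS }--xla_cpu_use_thunk_runtime=false"
echo "XLA_FLAGS=$XLA_FLAGS"

# Then rerun the benchmarks from the same shell, e.g.:
#   pytest --benchmark-only
```

XLA reads this variable when the backend initializes, so it must be set before the Python process imports JAX. If the 0.4.34 timings drop back to roughly the 0.4.30 numbers with the flag set, that points at the CPU runtime; if not, the cause lies elsewhere.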

@ColCarroll
Contributor

ColCarroll commented Oct 29, 2024 via email
