Our benchmark runtime increased more than 2x after upgrading JAX to 0.4.34
Reproduced locally.

On JAX 0.4.30 (pytest-benchmark, 2 tests; times in seconds, ratio to the best value in each column in parentheses):

| Name | Min | Max | Mean | StdDev | Median | IQR | Outliers | OPS | Rounds | Iterations |
|---|---|---|---|---|---|---|---|---|---|---|
| test_regression_nuts | 4.2739 (1.0) | 5.0262 (1.0) | 4.7685 (1.0) | 0.3398 (1.0) | 4.9671 (1.0) | 0.5408 (1.0) | 1;0 | 0.2097 (1.0) | 5 | 1 |
| test_regression_hmc | 7.2055 (1.69) | 8.1514 (1.62) | 7.6479 (1.60) | 0.4128 (1.22) | 7.5257 (1.52) | 0.7291 (1.35) | 2;0 | 0.1308 (0.62) | 5 | 1 |
On JAX 0.4.34:

| Name | Min | Max | Mean | StdDev | Median | IQR | Outliers | OPS | Rounds | Iterations |
|---|---|---|---|---|---|---|---|---|---|---|
| test_regression_nuts | 9.2754 (1.0) | 10.2643 (1.0) | 9.6681 (1.0) | 0.3660 (1.0) | 9.6078 (1.0) | 0.3647 (1.0) | 2;0 | 0.1034 (1.0) | 5 | 1 |
| test_regression_hmc | 19.7752 (2.13) | 21.4303 (2.09) | 20.6382 (2.13) | 0.7185 (1.96) | 20.4793 (2.13) | 1.2633 (3.46) | 2;0 | 0.0485 (0.47) | 5 | 1 |
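A quick check of the "more than 2x" claim against the mean times reported in the two benchmark tables for JAX 0.4.30 and 0.4.34. The `MEANS` dictionary and the `slowdown` helper are illustrative only and are not part of the Blackjax benchmark suite.

```python
# Mean wall-clock times in seconds, copied from the pytest-benchmark results
# for JAX 0.4.30 and JAX 0.4.34.
MEANS = {
    "test_regression_nuts": {"0.4.30": 4.7685, "0.4.34": 9.6681},
    "test_regression_hmc": {"0.4.30": 7.6479, "0.4.34": 20.6382},
}


def slowdown(times: dict, old: str = "0.4.30", new: str = "0.4.34") -> float:
    """Ratio of the new mean runtime to the old one (> 1 means slower)."""
    return times[new] / times[old]


for name, times in MEANS.items():
    print(f"{name}: {slowdown(times):.2f}x slower on JAX {'0.4.34'}")
```

This gives about 2.03x for NUTS and 2.70x for HMC. The minima tell the same story (about 2.17x and 2.74x), while the NUTS median ratio is slightly under 2x (about 1.93x).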
To reproduce: pin JAX to a given version and run `pytest --benchmark-only`.
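A minimal sketch of these steps in Python, assuming the benchmarks are run from the repository root with pytest-benchmark installed, and assuming jaxlib is pinned to the same version as jax (the report does not say how jaxlib was pinned). The helpers `commands_for` and `reproduce` are hypothetical. By default the script only builds the command lists; it executes them only when called with `run=True`.

```python
import subprocess
import sys

# JAX versions compared in this report; jaxlib is assumed to match jax.
VERSIONS = ("0.4.30", "0.4.34")


def commands_for(version: str) -> list:
    """Commands that pin jax/jaxlib to `version`, then run only the benchmarks."""
    return [
        [sys.executable, "-m", "pip", "install",
         f"jax=={version}", f"jaxlib=={version}"],
        [sys.executable, "-m", "pytest", "--benchmark-only"],
    ]


def reproduce(versions=VERSIONS, run: bool = False) -> list:
    """Return every command in order; execute them only when run=True."""
    all_cmds = [cmd for v in versions for cmd in commands_for(v)]
    if run:
        for cmd in all_cmds:
            # check=True stops at the first failing install or benchmark run.
            subprocess.run(cmd, check=True)
    return all_cmds
```

Calling `reproduce(run=True)` installs each version in turn and runs the benchmark suite after each install, so the two result tables can be compared directly.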
### Expected result:

```shell
n.a
```

Blackjax/JAX/jaxlib/Python version information: n.a

Context for the issue: No response
Related: pyro-ppl/numpyro#1867. For the likely root cause and a workaround, see jax-ml/jax#23822.
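The comment points to jax-ml/jax#23822 without quoting the workaround. A commonly cited mitigation for CPU slowdowns starting with jaxlib 0.4.32 is to disable XLA's new CPU "thunk" runtime through `XLA_FLAGS`. The sketch below assumes that is the workaround in question; verify the flag name against that thread, since such flags can be removed in later releases. The helper `disable_cpu_thunk_runtime` is hypothetical, and it must run before `jax` is first imported.

```python
import os

# Assumed workaround flag: reverts XLA:CPU to its pre-thunk runtime.
FLAG = "--xla_cpu_use_thunk_runtime=false"


def disable_cpu_thunk_runtime() -> str:
    """Append FLAG to XLA_FLAGS, keeping any existing flags, and return the result."""
    current = os.environ.get("XLA_FLAGS", "")
    if FLAG not in current.split():
        current = f"{current} {FLAG}".strip()
    os.environ["XLA_FLAGS"] = current
    return current


disable_cpu_thunk_runtime()
# import jax  # import jax only after XLA_FLAGS has been set
```

Setting the flag in the environment (for example in a `conftest.py`) keeps the benchmark code itself unchanged, which makes it easy to rerun the same suite with and without the workaround.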