Skip to content

Commit

Permalink
Fix binominal distribution (#1860)
Browse files Browse the repository at this point in the history
* Comments added why the binominal_dispatch function will run into an infinite loop

* Fix for the infinite loop problem of binominal_dispatch.

* Update of the fix to a more concise version.

* Linting fix

* Changed to the proposed solution, to correct log1_p value correctly
  • Loading branch information
InfinityMod authored Sep 10, 2024
1 parent bb7767e commit f5aca91
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions numpyro/distributions/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ def _binom_inv_cond_fn(val):
return cond_exclude_large_mu & (geom_acc <= n)

log1_p = jnp.log1p(-p)
# Make sure p=0 is never taken into account as a fix for possible zeros in p.
log1_p = jnp.where(log1_p == 0, -jnp.finfo(log1_p.dtype).tiny, log1_p)
ret = lax.while_loop(_binom_inv_cond_fn, _binom_inv_body_fn, (-1, key, 0.0))
return ret[0]

Expand Down

0 comments on commit f5aca91

Please sign in to comment.