Skip to content

Commit

Permalink
Fixes monkey patching attribute error triggered by cl/629129623.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 629179580
  • Loading branch information
Nolan Miller authored and The oryx Authors committed Apr 29, 2024
1 parent 23ab8f2 commit e7a0987
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions oryx/core/interpreters/inverse/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,10 +258,10 @@ def convert_element_type_ildj(incells, outcells, *, new_dtype, **params):
jax.scipy.special.logit = custom_inverse(jax.scipy.special.logit)
jax.nn.sigmoid = jax.scipy.special.expit
jax.nn.softplus = custom_inverse(jax.nn.softplus)
jax.scipy.special.expit.def_inverse_unary(f_inv=jax.scipy.special.logit,
f_ildj=expit_ildj)
jax.scipy.special.logit.def_inverse_unary(f_inv=jax.scipy.special.expit,
f_ildj=logit_ildj)
jax.nn.softplus.def_inverse_unary(f_inv=softplus_inv,
f_ildj=softplus_ildj)

jax.scipy.special.expit.def_inverse_unary(
f_inv=jax.scipy.special.logit, f_ildj=expit_ildj
) # pytype: disable=attribute-error
jax.scipy.special.logit.def_inverse_unary(
f_inv=jax.scipy.special.expit, f_ildj=logit_ildj
) # pytype: disable=attribute-error
jax.nn.softplus.def_inverse_unary(f_inv=softplus_inv, f_ildj=softplus_ildj) # pytype: disable=attribute-error

0 comments on commit e7a0987

Please sign in to comment.