Skip to content

Commit

Permalink
Replace deprecated jax.tree_* functions with jax.tree.*
Browse files Browse the repository at this point in the history
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25.

PiperOrigin-RevId: 635974880
  • Loading branch information
Jake VanderPlas authored and The oryx Authors committed May 22, 2024
1 parent e7a0987 commit b1a7685
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions oryx/examples/notebooks/probabilistic_programming.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -905,10 +905,10 @@
"source": [
"@jit\n",
"def run_chain(key, weights):\n",
" flat_state, sample_tree = jax.tree_flatten(weights)\n",
" flat_state, sample_tree = jax.tree.flatten(weights)\n",
"\n",
" def flat_log_prob(*states):\n",
" return target_log_prob(jax.tree_unflatten(sample_tree, states))\n",
" return target_log_prob(jax.tree.unflatten(sample_tree, states))\n",
"\n",
" def trace_fn(_, results):\n",
" return results.inner_results.accepted_results.target_log_prob\n",
Expand All @@ -922,7 +922,7 @@
" trace_fn=trace_fn,\n",
" current_state=flat_state,\n",
" seed=key)\n",
" samples = jax.tree_unflatten(sample_tree, flat_states)\n",
" samples = jax.tree.unflatten(sample_tree, flat_states)\n",
" return samples, log_probs\n",
"posterior_weights, log_probs = run_chain(random.PRNGKey(0), weights)"
]
Expand Down
2 changes: 1 addition & 1 deletion oryx/experimental/nn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def _call(self, *args, **kwargs):

def init(self, init_key, *args, name=None, **kwargs):
"""Initializes a Template into a Layer."""
specs = jax.tree_map(state.make_array_spec, args)
specs = jax.tree.map(state.make_array_spec, args)
kwargs = dict(
cls=self.cls,
specs=specs,
Expand Down

0 comments on commit b1a7685

Please sign in to comment.