diff --git a/oryx/examples/notebooks/probabilistic_programming.ipynb b/oryx/examples/notebooks/probabilistic_programming.ipynb index e339b87..7da562a 100644 --- a/oryx/examples/notebooks/probabilistic_programming.ipynb +++ b/oryx/examples/notebooks/probabilistic_programming.ipynb @@ -905,10 +905,10 @@ "source": [ "@jit\n", "def run_chain(key, weights):\n", - " flat_state, sample_tree = jax.tree_flatten(weights)\n", + " flat_state, sample_tree = jax.tree.flatten(weights)\n", "\n", " def flat_log_prob(*states):\n", - " return target_log_prob(jax.tree_unflatten(sample_tree, states))\n", + " return target_log_prob(jax.tree.unflatten(sample_tree, states))\n", "\n", " def trace_fn(_, results):\n", " return results.inner_results.accepted_results.target_log_prob\n", @@ -922,7 +922,7 @@ " trace_fn=trace_fn,\n", " current_state=flat_state,\n", " seed=key)\n", - " samples = jax.tree_unflatten(sample_tree, flat_states)\n", + " samples = jax.tree.unflatten(sample_tree, flat_states)\n", " return samples, log_probs\n", "posterior_weights, log_probs = run_chain(random.PRNGKey(0), weights)" ] diff --git a/oryx/experimental/nn/base.py b/oryx/experimental/nn/base.py index 1dbbddb..836a5f9 100644 --- a/oryx/experimental/nn/base.py +++ b/oryx/experimental/nn/base.py @@ -336,7 +336,7 @@ def _call(self, *args, **kwargs): def init(self, init_key, *args, name=None, **kwargs): """Initializes a Template into a Layer.""" - specs = jax.tree_map(state.make_array_spec, args) + specs = jax.tree.map(state.make_array_spec, args) kwargs = dict( cls=self.cls, specs=specs,