Hi! Thanks for this implementation!
I am trying to use this implementation but I am running into the error in the title of this issue. Here is what I am working with:
```python
import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
from typing import Callable, Tuple
import haiku as hk
from jax.random import PRNGKey, split
import optax
import matplotlib.pyplot as plt
from optax_swag import swag


def nll(apply_fn: Callable):
    def _nll(params, batch: Tuple[jax.Array, jax.Array]) -> float:
        x, y = batch
        out = apply_fn(params, x)
        ll = stats.norm.logpdf(out, y)
        return -ll.sum()
    return _nll


def generate_data():
    x = jnp.linspace(0, 10, 25).reshape(-1, 1)
    y = jnp.sin(0.4 * x) + 3
    return x, y


def make_small_mlp():
    relu = jax.nn.relu

    def small_mlp(x):
        mlp = hk.Sequential([
            hk.Linear(50), relu,
            hk.Linear(50), relu,
            hk.Linear(50), relu,
            hk.Linear(1)])
        return mlp(x)

    return hk.transform(small_mlp)


def train_model(params, model_apply, data, opt_init, opt_update, epochs, loss_fn):
    loss_fn = nll(model_apply)
    x, y = data
    opt_state = opt_init(params)
    print(opt_state)

    @jax.jit
    def train_one_epoch(params, opt_state):
        nll_val, grad = jax.value_and_grad(loss_fn)(params, (x, y))
        updates, opt_state = opt_update(grad, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state, nll_val

    for i in range(epochs):
        params, opt_state, nll_val = train_one_epoch(params, opt_state)
        print(f"STEP {i} | NLL: {nll_val}")

    preds = model_apply(params, x)
    plt.plot(x, y)
    plt.plot(x, preds)
    plt.show()
    return params


model_init_key, _, _, _, _ = split(PRNGKey(123), 5)
x, y = generate_data()
mlp = make_small_mlp()
params = mlp.init(model_init_key, x[0])
model_apply = lambda params, x: mlp.apply(params, None, x)
opt_init, opt_update = optax.chain(optax.adam(1e-3), swag(5, 5))
params = train_model(params, model_apply, (x, y), opt_init, opt_update, 500, nll)
# 'ValueError: Expected dict, got None.'
```
Any ideas of what I may be doing wrong? Thanks so much!
Can you also add a full stack trace?
@activatedgeek sure thing!
```
Traceback (most recent call last):
  File "/Users/pscemama/jax/blah.py", line 86, in <module>
    params = train_model(params, model_apply, (x, y), opt_init, opt_update, 500, nll)
  File "/Users/pscemama/jax/blah.py", line 66, in train_model
    params, opt_state, nll_val = train_one_epoch(params, opt_state)
  File "/Users/pscemama/jax/blah.py", line 61, in train_one_epoch
    updates, opt_state = opt_update(grad, opt_state)
  File "/Users/pscemama/jax/.jax_env/lib/python3.10/site-packages/optax/_src/combine.py", line 59, in update_fn
    updates, new_s = fn(updates, s, params, **extra_args)
  File "/Users/pscemama/jax/.jax_env/lib/python3.10/site-packages/optax/_src/base.py", line 311, in update
    return tx.update(updates, state, params)
  File "/Users/pscemama/jax/.jax_env/lib/python3.10/site-packages/optax_swag/transform.py", line 81, in update_fn
    next_mean = jax.tree_util.tree_map(lambda mu, np: jnp.where(update_mask, (n * mu + np) / (n + 1), mu),
ValueError: Expected dict, got None.
```
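Judging by the traceback, the SWAG transform ends up receiving `params=None`: the inner `opt_update(grad, opt_state)` call never passes the current parameters, and optax forwards that `None` down the chain until the `tree_map` over the parameter pytree in `optax_swag` fails. A minimal sketch of the adjusted training step, assuming that missing argument is the cause (not verified against `optax_swag` itself):

```python
# Sketch, not a confirmed fix: assumes the SWAG transform needs the current
# parameters (e.g. to update its running mean), so they are passed as the
# third argument of opt_update.
@jax.jit
def train_one_epoch(params, opt_state):
    nll_val, grad = jax.value_and_grad(loss_fn)(params, (x, y))
    # pass `params` so parameter-dependent transforms in the chain can see them
    updates, opt_state = opt_update(grad, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, nll_val
```

For reference, optax's `GradientTransformation.update(updates, state, params=None)` takes the parameters as an optional third argument, and transformations that track parameter statistics generally require it.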