Best Practice for Passing/Storing Training Progress for Curriculum Learning in Brax #537

Open
hukz18 opened this issue Oct 19, 2024 · 2 comments

hukz18 commented Oct 19, 2024

Hi Brax team,

I’m working on a reinforcement learning project using Brax to train a PPO agent and I’m trying to implement curriculum learning by adjusting the environment's difficulty dynamically based on the training progress (e.g., current_steps or number of episodes). My goal is to pass this information to the environment during training so that I can change certain parameters (like gravity, object mass, etc.) as the agent progresses.

I’ve thought of a solution where I modify the training code to pass the current training progress into the environment’s reset function. Here’s a simplified example of what I have in mind:

reset_fn = jax.jit(jax.vmap(lambda x: env.reset(x, current_step)))

However, this requires modifying reset_fn in the training loop (brax/training/agents/ppo/train.py) to pass the training progress manually. I would also need to modify the reset function of every wrapper so that current_step can be passed through.
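For illustration, here is a minimal sketch of passing progress as an explicit argument rather than closing over it. `toy_reset` and its gravity ramp are hypothetical stand-ins, not Brax APIs. The point is that a Python value captured by a jitted lambda is fixed when the function is traced, whereas an explicit argument mapped with `in_axes=(0, None)` is shared across the batch and can change between calls.

```python
import jax
import jax.numpy as jp


def toy_reset(rng, current_step):
    """Hypothetical stand-in for env.reset(rng, current_step)."""
    # Assumed schedule: ramp gravity magnitude from 1.0 to 9.81 over 1000 steps.
    frac = jp.clip(current_step / 1000.0, 0.0, 1.0)
    gravity = 1.0 + frac * (9.81 - 1.0)
    obs = jax.random.normal(rng, (3,))
    return obs, gravity


# rng is batched along axis 0; current_step is a shared scalar (None = not mapped).
reset_fn = jax.jit(jax.vmap(toy_reset, in_axes=(0, None)))

rngs = jax.random.split(jax.random.PRNGKey(0), 4)
obs_a, gravity_a = reset_fn(rngs, jp.asarray(0))
obs_b, gravity_b = reset_fn(rngs, jp.asarray(1000))
```

Because `current_step` is a traced argument here, the two calls reuse the same compiled function and still see different progress values.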

I've also tried simply storing a scalar value in the environment, e.g. self.num_episodes = 0, and calling self.num_episodes = self.num_episodes + 1 in the reset function. Unfortunately, this value never actually changes despite the reset calls. So I wonder if there's a way to achieve this without changing the training code of Brax itself.
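A likely explanation, sketched with a hypothetical `ToyEnv` (not a Brax class): when reset runs under `jax.jit`, Python side effects such as attribute assignment execute only while JAX traces the function, not on each call of the compiled result.

```python
import jax
import jax.numpy as jp


class ToyEnv:
    """Hypothetical minimal environment used only to show the effect."""

    def __init__(self):
        self.num_episodes = 0  # plain Python attribute, invisible to JAX

    def reset(self, x):
        # Python side effect: runs during tracing, not on every compiled call.
        self.num_episodes = self.num_episodes + 1
        return x + self.num_episodes


env = ToyEnv()
reset_fn = jax.jit(env.reset)

# Three calls with the same input shape and dtype trigger a single trace.
outputs = [float(reset_fn(jp.zeros(()))) for _ in range(3)]
count_after_calls = env.num_episodes
```

After the three calls the attribute has been incremented once, and every call returns the value baked in at trace time. This is why state that must evolve during training is usually carried in arrays passed through the function rather than in object attributes.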

Question:
Is there a better practice for passing or storing training progress information (like current_steps) in Brax for curriculum learning? Specifically:

- Is modifying the training code the best approach, or can this be handled more elegantly by the environment itself?
- Can we store or retrieve the training progress (e.g., current_steps) in the environment without needing to modify the reset function directly?

I’d appreciate any advice or best practices you can suggest for implementing this kind of feature in Brax.

Thanks for your help!

@juhorsch

Hi,

Any chance you figured it out?
Thanks!

hukz18 commented Nov 27, 2024

Hi, I ended up modifying the step function of AutoResetWrapper and adding an episode_num entry to state.info to keep track of the number of episodes for each environment. The change inside the step function of AutoResetWrapper looks like this:

    state.info['episode_num'] = jp.where(
        state.done, state.info['episode_num'] + 1, state.info['episode_num']
    )

You also need to add the 'episode_num' key to the state.info dict elsewhere. The total number of environment steps can be tracked in a similar way.
However, this workaround can't keep track of the progress inside the environment class, so I'd like to keep the issue open for now.
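As a rough standalone sketch of this counter pattern, the snippet below uses a plain dict in place of state.info so that it runs without Brax. The `total_steps` key, the `update_counters` helper, and the linear `difficulty` schedule are assumptions added for illustration, not part of the wrapper.

```python
import jax
import jax.numpy as jp


def update_counters(info, done):
    """Advance per-environment counters, mirroring the jp.where update above."""
    done = done.astype(bool)
    new_info = dict(info)
    new_info['episode_num'] = jp.where(
        done, info['episode_num'] + 1, info['episode_num']
    )
    new_info['total_steps'] = info['total_steps'] + 1
    return new_info


def difficulty(episode_num, ramp_episodes=100):
    """Hypothetical linear schedule in [0, 1], one value per environment."""
    return jp.clip(episode_num / ramp_episodes, 0.0, 1.0)


num_envs = 3
info = {
    'episode_num': jp.zeros(num_envs, dtype=jp.int32),
    'total_steps': jp.zeros(num_envs, dtype=jp.int32),
}
step_fn = jax.jit(update_counters)

# Two simulated steps; done flags are floats (0.0 or 1.0), one per environment.
info = step_fn(info, jp.array([0.0, 1.0, 0.0]))
info = step_fn(info, jp.array([0.0, 1.0, 1.0]))
level = difficulty(info['episode_num'], ramp_episodes=4)
```

Since the counters are arrays carried through the jitted function, they update on every call, and a value such as `level` could then scale a parameter like gravity or mass on a per-environment basis.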

Hope it helps!
