Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add a Jax+RL example based on rejax.PPO Signed-off-by: Fabrice Normandin <[email protected]> * Remove some of the unused code Signed-off-by: Fabrice Normandin <[email protected]> * Move things around a bit Signed-off-by: Fabrice Normandin <[email protected]> * Update version requirements for jax/torch Signed-off-by: Fabrice Normandin <[email protected]> * Use xtills for cleaner Jit with annotations Signed-off-by: Fabrice Normandin <[email protected]> * Save gif every epoch Signed-off-by: Fabrice Normandin <[email protected]> * Fix rendering of classic-control gymnax envs Signed-off-by: Fabrice Normandin <[email protected]> * Add a "pure jax" training loop option Signed-off-by: Fabrice Normandin <[email protected]> * Fused training step in Lightning module Signed-off-by: Fabrice Normandin <[email protected]> * Works without hash warnings now! Signed-off-by: Fabrice Normandin <[email protected]> * Reorganize the code a bit Signed-off-by: Fabrice Normandin <[email protected]> * Use vmap to train multiple agents in parallel Signed-off-by: Fabrice Normandin <[email protected]> * Add a jax analogue to lightning.Trainer Signed-off-by: Fabrice Normandin <[email protected]> * Add the equivalent of lightning.Callback for jax Signed-off-by: Fabrice Normandin <[email protected]> * Log hyper-parameters Signed-off-by: Fabrice Normandin <[email protected]> * Progress bar almost works Signed-off-by: Fabrice Normandin <[email protected]> * Managed to get the progress bar to work! Signed-off-by: Fabrice Normandin <[email protected]> * Move the trainer + callback to a different file Signed-off-by: Fabrice Normandin <[email protected]> * Make stuff generic (not tied to PPOLearner) Signed-off-by: Fabrice Normandin <[email protected]> * Update gymnax to improve rendering performance Signed-off-by: Fabrice Normandin <[email protected]> * Add configs, tweak experiment/main Signed-off-by: Fabrice Normandin <[email protected]> * wip: fixing issues in experiment.py Signed-off-by: Fabrice Normandin <[email protected]> * Fix config now that network is optional Signed-off-by: Fabrice Normandin <[email protected]> * Fix issue with progress bar callback! Signed-off-by: Fabrice Normandin <[email protected]> * Fix duplicated code in main.py Signed-off-by: Fabrice Normandin <[email protected]> * Move tests / Lightning wrapper to test file Signed-off-by: Fabrice Normandin <[email protected]> * Rename things, add docstring to JaxTrainer Signed-off-by: Fabrice Normandin <[email protected]> * Fix links in docstrings of JaxTrainer / JaxModule Signed-off-by: Fabrice Normandin <[email protected]> * Tweak the docs of JaxModule/JaxTrainer Signed-off-by: Fabrice Normandin <[email protected]> * Use regression fixtures in test Signed-off-by: Fabrice Normandin <[email protected]> * Fix the ref in the JaxTrainer docstring Signed-off-by: Fabrice Normandin <[email protected]> * Fix small errors that break CI Signed-off-by: Fabrice Normandin <[email protected]> * Fix bug in test_rejax Signed-off-by: Fabrice Normandin <[email protected]> * "fix" config schema generation errors Signed-off-by: Fabrice Normandin <[email protected]> * Fix test_rejax function Signed-off-by: Fabrice Normandin <[email protected]> * Test the `train` method to replicate rejax.PPO Signed-off-by: Fabrice Normandin <[email protected]> * Move Jax typing utils to a new module Signed-off-by: Fabrice Normandin <[email protected]> * Fix default param causing preallocation of GPU mem Signed-off-by: Fabrice Normandin <[email protected]> * Add comments in conftest.py Signed-off-by: Fabrice Normandin <[email protected]> * Fix test for rejax, add more todos in conftest.py Signed-off-by: Fabrice Normandin <[email protected]> * Fix bug in lightning wrapper for rejax.PPO Signed-off-by: Fabrice Normandin <[email protected]> * Fix issue in test_config from conftest change Signed-off-by: Fabrice Normandin <[email protected]> * (temp) make the tests run in unit test runs Signed-off-by: Fabrice Normandin <[email protected]> * Tweaks to the jax typing utils Signed-off-by: Fabrice Normandin <[email protected]> * Move the JaxTrainer to a new "trainers" dir Signed-off-by: Fabrice Normandin <[email protected]> * Simplify docs in `jax_trainer.py` Signed-off-by: Fabrice Normandin <[email protected]> * Move things around, add pytest.mark.slow marks Signed-off-by: Fabrice Normandin <[email protected]> * Fix bug with config target type inference Signed-off-by: Fabrice Normandin <[email protected]> * Move things around in jax_rl_example_test.py Signed-off-by: Fabrice Normandin <[email protected]> * Add some docstrings Signed-off-by: Fabrice Normandin <[email protected]> * Re-organize tests, update regression files Signed-off-by: Fabrice Normandin <[email protected]> * Fix the missing indexing in test for equivalence Signed-off-by: Fabrice Normandin <[email protected]> * Don't use file_regression with gifs Signed-off-by: Fabrice Normandin <[email protected]> * Fix issue with jax_rl_example_test.test_lightning Signed-off-by: Fabrice Normandin <[email protected]> --------- Signed-off-by: Fabrice Normandin <[email protected]>
- Loading branch information