Skip to content

Commit

Permalink
Add an RL example in Jax (#55)
Browse files Browse the repository at this point in the history
* Add a Jax+RL example based on rejax.PPO

Signed-off-by: Fabrice Normandin <[email protected]>

* Remove some of the unused code

Signed-off-by: Fabrice Normandin <[email protected]>

* Move things around a bit

Signed-off-by: Fabrice Normandin <[email protected]>

* Update version requirements for jax/torch

Signed-off-by: Fabrice Normandin <[email protected]>

* Use xtills for cleaner Jit with annotations

Signed-off-by: Fabrice Normandin <[email protected]>

* Save gif every epoch

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix rendering of classic-control gymnax envs

Signed-off-by: Fabrice Normandin <[email protected]>

* Add a "pure jax" training loop option

Signed-off-by: Fabrice Normandin <[email protected]>

* Fused training step in Lightning module

Signed-off-by: Fabrice Normandin <[email protected]>

* Works without hash warnings now!

Signed-off-by: Fabrice Normandin <[email protected]>

* Reorganize the code a bit

Signed-off-by: Fabrice Normandin <[email protected]>

* Use vmap to train multiple agents in parallel

Signed-off-by: Fabrice Normandin <[email protected]>

* Add a jax analogue to lightning.Trainer

Signed-off-by: Fabrice Normandin <[email protected]>

* Add the equivalent of lightning.Callback for jax

Signed-off-by: Fabrice Normandin <[email protected]>

* Log hyper-parameters

Signed-off-by: Fabrice Normandin <[email protected]>

* Progress bar almost works

Signed-off-by: Fabrice Normandin <[email protected]>

* Managed to get the progress bar to work!

Signed-off-by: Fabrice Normandin <[email protected]>

* Move the trainer + callback to a different file

Signed-off-by: Fabrice Normandin <[email protected]>

* Make stuff generic (not tied to PPOLearner)

Signed-off-by: Fabrice Normandin <[email protected]>

* Update gymnax to improve rendering performance

Signed-off-by: Fabrice Normandin <[email protected]>

* Add configs, tweak experiment/main

Signed-off-by: Fabrice Normandin <[email protected]>

* wip: fixing issues in experiment.py

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix config now that network is optional

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix issue with progress bar callback!

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix duplicated code in main.py

Signed-off-by: Fabrice Normandin <[email protected]>

* Move tests / Lightning wrapper to test file

Signed-off-by: Fabrice Normandin <[email protected]>

* Rename things, add docstring to JaxTrainer

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix links in docstrings of JaxTrainer / JaxModule

Signed-off-by: Fabrice Normandin <[email protected]>

* Tweak the docs of JaxModule/JaxTrainer

Signed-off-by: Fabrice Normandin <[email protected]>

* Use regression fixtures in test

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix the ref in the JaxTrainer docstring

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix small errors that break CI

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix bug in test_rejax

Signed-off-by: Fabrice Normandin <[email protected]>

* "fix" config schema generation errors

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix test_rejax function

Signed-off-by: Fabrice Normandin <[email protected]>

* Test the `train` method to replicate rejax.PPO

Signed-off-by: Fabrice Normandin <[email protected]>

* Move Jax typing utils to a new module

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix default param causing preallocation of GPU mem

Signed-off-by: Fabrice Normandin <[email protected]>

* Add comments in conftest.py

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix test for rejax, add more todos in conftest.py

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix bug in lightning wrapper for rejax.PPO

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix issue in test_config from conftest change

Signed-off-by: Fabrice Normandin <[email protected]>

* (temp) make the tests run in unit test runs

Signed-off-by: Fabrice Normandin <[email protected]>

* Tweaks to the jax typing utils

Signed-off-by: Fabrice Normandin <[email protected]>

* Move the JaxTrainer to a new "trainers" dir

Signed-off-by: Fabrice Normandin <[email protected]>

* Simplify docs in `jax_trainer.py`

Signed-off-by: Fabrice Normandin <[email protected]>

* Move things around, add pytest.mark.slow marks

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix bug with config target type inference

Signed-off-by: Fabrice Normandin <[email protected]>

* Move things around in jax_rl_example_test.py

Signed-off-by: Fabrice Normandin <[email protected]>

* Add some docstrings

Signed-off-by: Fabrice Normandin <[email protected]>

* Re-organize tests, update regression files

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix the missing indexing in test for equivalence

Signed-off-by: Fabrice Normandin <[email protected]>

* Don't use file_regression with gifs

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix issue with jax_rl_example_test.test_lightning

Signed-off-by: Fabrice Normandin <[email protected]>

---------

Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice authored Oct 11, 2024
1 parent a5acd0b commit 682cce6
Show file tree
Hide file tree
Showing 27 changed files with 2,843 additions and 155 deletions.
3 changes: 3 additions & 0 deletions .regression_files/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
*.gif
# Ignore tensor regression files.
*.npz
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
val/episode_lengths:
max: '2.e+02'
mean: '2.e+02'
min: '2.e+02'
shape: []
sum: '2.e+02'
val/rewards:
max: '-1.222e+03'
mean: '-1.222e+03'
min: '-1.222e+03'
shape: []
sum: '-1.222e+03'
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
cumulative_reward:
max: '-6.495e+02'
mean: '-1.229e+03'
min: '-1.878e+03'
shape:
- 76
- 128
sum: '-1.196e+07'
episode_length:
max: 200
mean: '2.e+02'
min: 200
shape:
- 76
- 128
sum: 1945600
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
cumulative_reward:
max: '-6.495e+02'
mean: '-1.229e+03'
min: '-1.878e+03'
shape:
- 76
- 128
sum: '-1.196e+07'
episode_length:
max: 200
mean: '2.e+02'
min: 200
shape:
- 76
- 128
sum: 1945600
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
cumulative_reward:
max: '-4.319e-01'
mean: '-5.755e+02'
min: '-1.872e+03'
shape:
- 76
- 128
sum: '-5.599e+06'
episode_length:
max: 200
mean: '2.e+02'
min: 200
shape:
- 76
- 128
sum: 1945600
11 changes: 11 additions & 0 deletions docs/examples/jax_rl_example.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
---
additional_python_references:
- project.algorithms.jax_rl_example
- project.trainers.jax_trainer
---

# Reinforcement Learning (Jax)

## JaxTrainer

The `JaxTrainer` is
2 changes: 2 additions & 0 deletions docs/generate_reference_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
# based on https://github.com/mkdocstrings/mkdocstrings/blob/5802b1ef5ad9bf6077974f777bd55f32ce2bc219/docs/gen_doc_stubs.py#L25


import os
from logging import getLogger as get_logger
from pathlib import Path

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
logger = get_logger(__name__)


Expand Down
2 changes: 2 additions & 0 deletions project/algorithms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from .example import ExampleAlgorithm
from .hf_example import HFExample
from .jax_example import JaxExample
from .jax_rl_example import JaxRLExample
from .no_op import NoOp

__all__ = [
"ExampleAlgorithm",
"JaxExample",
"NoOp",
"HFExample",
"JaxRLExample",
]
55 changes: 41 additions & 14 deletions project/algorithms/callbacks/samples_per_second.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
from typing import Literal
from typing import Any, Literal

from lightning import LightningModule, Trainer
from torch import Tensor
Expand All @@ -11,11 +11,11 @@


class MeasureSamplesPerSecondCallback(Callback[BatchType, StepOutputType]):
def __init__(self):
def __init__(self, num_optimizers: int | None = None):
super().__init__()
self.last_step_times: dict[Literal["train", "val", "test"], float] = {}
self.last_update_time: dict[int, float | None] = {}
self.num_optimizers: int | None = None
self.num_optimizers: int | None = num_optimizers

@override
def on_shared_epoch_start(
Expand Down Expand Up @@ -56,19 +56,44 @@ def on_shared_batch_end(
now = time.perf_counter()
if phase in self.last_step_times:
elapsed = now - self.last_step_times[phase]
if is_sequence_of(batch, Tensor):
batch_size = batch[0].shape[0]
pl_module.log(
f"{phase}/samples_per_second",
batch_size / elapsed,
prog_bar=True,
on_step=True,
on_epoch=True,
sync_dist=True,
)
batch_size = self.get_num_samples(batch)
self.log(
f"{phase}/samples_per_second",
batch_size / elapsed,
module=pl_module,
trainer=trainer,
prog_bar=True,
on_step=True,
on_epoch=True,
sync_dist=True,
batch_size=batch_size,
)
# todo: support other kinds of batches
self.last_step_times[phase] = now

def log(
self,
name: str,
value: Any,
module: LightningModule | Any,
trainer: Trainer | Any,
**kwargs,
):
# Used to possibly customize how the values are logged (e.g. for non-LightningModules).
# By default, uses the LightningModule.log method.
return module.log(
name,
value,
**kwargs,
)

def get_num_samples(self, batch: BatchType) -> int:
if is_sequence_of(batch, Tensor):
return batch[0].shape[0]
raise NotImplementedError(
f"Don't know how many 'samples' there are in batch of type {type(batch)}"
)

@override
def on_before_optimizer_step(
self,
Expand All @@ -89,9 +114,11 @@ def on_before_optimizer_step(
key = "ups"
else:
key = f"optimizer_{opt_idx}/ups"
pl_module.log(
self.log(
key,
updates_per_second,
module=pl_module,
trainer=trainer,
prog_bar=False,
on_step=True,
)
36 changes: 6 additions & 30 deletions project/algorithms/jax_example.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import dataclasses
import logging
import os
from collections.abc import Callable
from typing import Concatenate, Literal, ParamSpec, TypeVar
from typing import Literal

import chex
import flax.linen
import jax
import rich
Expand All @@ -21,8 +21,6 @@
from project.datamodules.image_classification.mnist import MNISTDataModule
from project.utils.typing_utils.protocols import ClassificationDataModule

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"


def flatten(x: jax.Array) -> jax.Array:
return x.reshape((x.shape[0], -1))
Expand Down Expand Up @@ -58,8 +56,8 @@ class JaxFcNet(flax.linen.Module):
num_features: int = 256

@flax.linen.compact
def __call__(self, x: jax.Array):
x = flatten(x)
def __call__(self, x: jax.Array, forward_rng: chex.PRNGKey | None = None):
# x = flatten(x)
x = flax.linen.Dense(features=self.num_features)(x)
x = flax.linen.relu(x)
x = flax.linen.Dense(features=self.num_classes)(x)
Expand Down Expand Up @@ -89,6 +87,8 @@ def __init__(
hp: HParams = HParams(),
):
super().__init__()
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

self.datamodule = datamodule
self.hp = hp or self.HParams()

Expand Down Expand Up @@ -193,30 +193,6 @@ def to_channels_last(x: jax.Array) -> jax.Array:
return x.transpose(0, 2, 3, 1)


P = ParamSpec("P")
Out = TypeVar("Out")


def jit(
fn: Callable[P, Out],
) -> Callable[P, Out]:
"""Small type hint fix for jax's `jit` (preserves the signature of the callable)."""
return jax.jit(fn) # type: ignore


In = TypeVar("In")
Aux = TypeVar("Aux")


def value_and_grad(
fn: Callable[Concatenate[In, P], tuple[Out, Aux]],
argnums: Literal[0] = 0,
has_aux: Literal[True] = True,
) -> Callable[Concatenate[In, P], tuple[tuple[Out, Aux], In]]:
"""Small type hint fix for jax's `value_and_grad` (preserves the signature of the callable)."""
return jax.value_and_grad(fn, argnums=argnums, has_aux=has_aux) # type: ignore


def main():
logging.basicConfig(
level=logging.INFO, format="%(message)s", handlers=[rich.logging.RichHandler()]
Expand Down
Loading

0 comments on commit 682cce6

Please sign in to comment.