Skip to content

Commit

Permalink
- moved learning rate calculation out of unit loop to avoid unnecessary recalculation
Browse files Browse the repository at this point in the history

- added linear scheduler to one example to showcase usage
  • Loading branch information
kim-mskw committed Dec 3, 2024
1 parent 56bcb6b commit 1545b57
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
5 changes: 1 addition & 4 deletions assume/reinforcement_learning/algorithms/base_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(
self.float_type = self.learning_role.float_type

def update_learning_rate(
self, optimizers: list[th.optim.Optimizer] | th.optim.Optimizer
self, optimizers: list[th.optim.Optimizer] | th.optim.Optimizer, learning_rate: float
) -> None:
"""
Update the optimizers learning rate using the current learning rate schedule and the current progress remaining (from 1 to 0).
Expand All @@ -88,9 +88,6 @@ def update_learning_rate(
- https://github.com/DLR-RM/stable-baselines3/blob/512eea923afad6f6da4bb53d72b6ea4c6d856e59/stable_baselines3/common/utils.py#L68
"""
learning_rate = self.learning_role.calc_lr_from_progress(
self.learning_role.get_progress_remaining()
)

if not isinstance(optimizers, list):
optimizers = [optimizers]
Expand Down
6 changes: 5 additions & 1 deletion assume/reinforcement_learning/algorithms/matd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,13 +404,17 @@ def update_policy(self):
self.learning_role.get_progress_remaining()
)

learning_rate = self.learning_role.calc_lr_from_progress(
self.learning_role.get_progress_remaining()
)

# loop again over all units to avoid update call for every gradient step, as it will be ambiguous
for u_id, unit_strategy in self.learning_role.rl_strats.items():
self.update_learning_rate(
[
self.learning_role.critics[u_id].optimizer,
self.learning_role.rl_strats[u_id].actor.optimizer,
]
], learning_rate=learning_rate
)
unit_strategy.action_noise.update_noise_decay(updated_noise_decay)

Expand Down
2 changes: 2 additions & 0 deletions examples/inputs/example_02a/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@ base:
max_bid_price: 100
algorithm: matd3
learning_rate: 0.001
learning_rate_schedule: linear
training_episodes: 50
episodes_collecting_initial_experience: 5
train_freq: 24h
gradient_steps: -1
batch_size: 256
gamma: 0.99
device: cpu
action_noise_schedule: linear
noise_sigma: 0.1
noise_scale: 1
noise_dt: 1
Expand Down

0 comments on commit 1545b57

Please sign in to comment.