Early stopping, learning rate and noise decay (#475)
No early stopping by default
Linear schedule for learning rate and action noise decay
Basic scheduling options adapted from SB3

Co-authored-by: kim-mskw
mthede authored Dec 3, 2024
1 parent 16cbea2 commit 44fb0fc
Showing 12 changed files with 206 additions and 40 deletions.
2 changes: 2 additions & 0 deletions assume/common/base.py
@@ -754,6 +754,7 @@ class LearningConfig(TypedDict):
algorithm: str
actor_architecture: str
learning_rate: float
learning_rate_schedule: str
training_episodes: int
episodes_collecting_initial_experience: int
train_freq: str
@@ -764,6 +765,7 @@ class LearningConfig(TypedDict):
noise_sigma: float
noise_scale: int
noise_dt: int
action_noise_schedule: str
trained_policies_save_path: str
early_stopping_steps: int
early_stopping_threshold: float
4 changes: 2 additions & 2 deletions assume/common/outputs.py
@@ -234,7 +234,7 @@ def handle_output_message(self, content: dict, meta: MetaDict):
"market_meta",
"market_dispatch",
"unit_dispatch",
"rl_learning_params",
"rl_params",
]:
# these can be processed as a single dataframe
self.write_buffers[content_type].extend(content_data)
@@ -449,7 +449,7 @@ async def store_dfs(self):
df = self.convert_market_dispatch(data_list)
case "unit_dispatch":
df = self.convert_unit_dispatch(data_list)
case "rl_learning_params":
case "rl_params":
df = self.convert_rl_params(data_list)
case "grid_flows":
dfs = []
24 changes: 24 additions & 0 deletions assume/reinforcement_learning/algorithms/base_algorithm.py
@@ -73,6 +73,30 @@ def __init__(
self.device = self.learning_role.device
self.float_type = self.learning_role.float_type

def update_learning_rate(
self,
optimizers: list[th.optim.Optimizer] | th.optim.Optimizer,
learning_rate: float,
) -> None:
"""
Update the optimizers' learning rate using the current learning rate schedule and the current progress remaining (from 1 to 0).
Args:
optimizers (List[th.optim.Optimizer] | th.optim.Optimizer): An optimizer or a list of optimizers.
Note:
Adapted from SB3:
- https://github.com/DLR-RM/stable-baselines3/blob/512eea923afad6f6da4bb53d72b6ea4c6d856e59/stable_baselines3/common/base_class.py#L286
- https://github.com/DLR-RM/stable-baselines3/blob/512eea923afad6f6da4bb53d72b6ea4c6d856e59/stable_baselines3/common/utils.py#L68
"""

if not isinstance(optimizers, list):
optimizers = [optimizers]
for optimizer in optimizers:
for param_group in optimizer.param_groups:
param_group["lr"] = learning_rate

def update_policy(self):
logger.error(
"No policy update function of the used Rl algorithm was defined. Please define how the policies should be updated in the specific algorithm you use"
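A brief usage sketch of the update_learning_rate contract above; the helper name set_scheduled_lr and the torch module are stand-ins for illustration, and only the param_groups handling mirrors the method:

import torch as th

# Illustrative stand-in for update_learning_rate(): accept a single optimizer or a list
# and write the scheduled learning rate into every param_group of each optimizer.
def set_scheduled_lr(optimizers, learning_rate):
    if not isinstance(optimizers, list):
        optimizers = [optimizers]
    for optimizer in optimizers:
        for param_group in optimizer.param_groups:
            param_group["lr"] = learning_rate

opt = th.optim.Adam(th.nn.Linear(4, 2).parameters(), lr=1e-4)
set_scheduled_lr(opt, 5e-5)       # a single optimizer works ...
set_scheduled_lr([opt], 2.5e-5)   # ... and so does a list of optimizers
print(opt.param_groups[0]["lr"])  # 2.5e-05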
31 changes: 29 additions & 2 deletions assume/reinforcement_learning/algorithms/matd3.py
@@ -265,7 +265,10 @@ def create_actors(self) -> None:
unit_strategy.actor_target.train(mode=False)

unit_strategy.actor.optimizer = Adam(
unit_strategy.actor.parameters(), lr=self.learning_rate
unit_strategy.actor.parameters(),
lr=self.learning_role.calc_lr_from_progress(
1
), # 1=100% of simulation remaining, uses learning_rate from config as starting point
)

obs_dim_list.append(unit_strategy.obs_dim)
@@ -314,7 +317,10 @@ def create_critics(self) -> None:
)

self.learning_role.critics[u_id].optimizer = Adam(
self.learning_role.critics[u_id].parameters(), lr=self.learning_rate
self.learning_role.critics[u_id].parameters(),
lr=self.learning_role.calc_lr_from_progress(
1
), # 1 = 100% of simulation remaining, uses learning_rate from config as starting point
)

self.learning_role.target_critics[u_id].load_state_dict(
@@ -392,6 +398,27 @@ def update_policy(self):

logger.debug("Updating Policy")
n_rl_agents = len(self.learning_role.rl_strats.keys())

# update noise decay and learning rate
updated_noise_decay = self.learning_role.calc_noise_from_progress(
self.learning_role.get_progress_remaining()
)

learning_rate = self.learning_role.calc_lr_from_progress(
self.learning_role.get_progress_remaining()
)

# loop over all units again here, so the schedule update happens once per policy update rather than once per gradient step
for u_id, unit_strategy in self.learning_role.rl_strats.items():
self.update_learning_rate(
[
self.learning_role.critics[u_id].optimizer,
self.learning_role.rl_strats[u_id].actor.optimizer,
],
learning_rate=learning_rate,
)
unit_strategy.action_noise.update_noise_decay(updated_noise_decay)

for _ in range(self.gradient_steps):
self.n_updates += 1
i = 0
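Taken together, each update_policy() call now refreshes both decays once per unit before running the gradient steps. A runnable stand-in of that step, using plain torch modules instead of the ASSUME actor and critic networks and example values for the scheduled quantities:

import torch as th

# Stand-in objects, not the ASSUME classes: one scheduled learning rate is written into
# both the critic and the actor optimizer of a unit, and the noise decay is refreshed.
actor, critic = th.nn.Linear(8, 1), th.nn.Linear(8, 1)
actor_opt = th.optim.Adam(actor.parameters(), lr=1e-4)
critic_opt = th.optim.Adam(critic.parameters(), lr=1e-4)

scheduled_lr = 4e-5       # e.g. a linear schedule from 1e-4 with 40% progress remaining
scheduled_noise_dt = 0.4  # e.g. the linearly decayed noise dt at the same point

for opt in (critic_opt, actor_opt):
    for group in opt.param_groups:
        group["lr"] = scheduled_lr
# unit_strategy.action_noise.update_noise_decay(scheduled_noise_dt) would follow here

print(critic_opt.param_groups[0]["lr"], actor_opt.param_groups[0]["lr"])  # 4e-05 4e-05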
83 changes: 62 additions & 21 deletions assume/reinforcement_learning/learning_role.py
@@ -4,15 +4,18 @@

import logging
from collections import defaultdict
from datetime import datetime
from pathlib import Path

import torch as th
from mango import Role

from assume.common.base import LearningConfig, LearningStrategy
from assume.common.utils import datetime2timestamp
from assume.reinforcement_learning.algorithms.base_algorithm import RLAlgorithm
from assume.reinforcement_learning.algorithms.matd3 import TD3
from assume.reinforcement_learning.buffer import ReplayBuffer
from assume.reinforcement_learning.learning_utils import linear_schedule_func

logger = logging.getLogger(__name__)

@@ -31,13 +34,11 @@ class Learning(Role):
def __init__(
self,
learning_config: LearningConfig,
start: datetime = None,
end: datetime = None,
):
# how many learning roles do exist and how are they named
self.buffer: ReplayBuffer = None
self.early_stopping_steps = learning_config.get("early_stopping_steps", 10)
self.early_stopping_threshold = learning_config.get(
"early_stopping_threshold", 0.05
)
self.episodes_done = 0
self.rl_strats: dict[int, LearningStrategy] = {}
self.rl_algorithm = learning_config.get("algorithm", "matd3")
@@ -55,6 +56,19 @@ def __init__(
"trained_policies_load_path", self.trained_policies_save_path
)

# if early_stopping_steps is not provided, default to no early stopping (early_stopping_steps must be greater than the number of validation runs)
self.early_stopping_steps = learning_config.get(
"early_stopping_steps",
int(
self.training_episodes
/ learning_config.get("validation_episodes_interval", 5)
+ 1
),
)
self.early_stopping_threshold = learning_config.get(
"early_stopping_threshold", 0.05
)

cuda_device = (
learning_config["device"]
if "cuda" in learning_config.get("device", "cpu")
@@ -69,7 +83,26 @@ def __init__(
th.backends.cuda.matmul.allow_tf32 = True
th.backends.cudnn.allow_tf32 = True

if start is not None:
self.start = datetime2timestamp(start)
if end is not None:
self.end = datetime2timestamp(end)

self.learning_rate = learning_config.get("learning_rate", 1e-4)
self.learning_rate_schedule = learning_config.get(
"learning_rate_schedule", None
)
if self.learning_rate_schedule == "linear":
self.calc_lr_from_progress = linear_schedule_func(self.learning_rate)
else:
self.calc_lr_from_progress = lambda x: self.learning_rate

noise_dt = learning_config.get("noise_dt", 1)
self.action_noise_schedule = learning_config.get("action_noise_schedule", None)
if self.action_noise_schedule == "linear":
self.calc_noise_from_progress = linear_schedule_func(noise_dt)
else:
self.calc_noise_from_progress = lambda x: noise_dt

# if we do not have initial experience collected we will get an error as no samples are available on the
# buffer from which we can draw experience to adapt the strategy, hence we set it to minimum one episode
@@ -118,8 +151,6 @@ def load_inter_episodic_data(self, inter_episodic_data):
if self.episodes_done > self.episodes_collecting_initial_experience:
self.turn_off_initial_exploration()

self.set_noise_scale(inter_episodic_data["noise_scale"])

self.initialize_policy(inter_episodic_data["actors_and_critics"])

def get_inter_episodic_data(self):
@@ -138,7 +169,6 @@ def get_inter_episodic_data(self):
"avg_all_eval": self.avg_rewards,
"buffer": self.buffer,
"actors_and_critics": self.rl_algorithm.extract_policy(),
"noise_scale": self.get_noise_scale(),
}

def setup(self) -> None:
@@ -189,26 +219,31 @@ def turn_off_initial_exploration(self) -> None:
for _, unit in self.rl_strats.items():
unit.collect_initial_experience_mode = False

def set_noise_scale(self, stored_scale) -> None:
"""
Set the noise scale for all learning strategies (units) in rl_strats.
"""
for _, unit in self.rl_strats.items():
unit.action_noise.scale = stored_scale

def get_noise_scale(self) -> None:
"""
Get the noise scale from the first learning strategy (unit) in rl_strats.

Notes:
The noise scale is the same for all learning strategies (units) in rl_strats, so we only need to get it from one unit.
It is only depended on the number of updates done so far, which is determined by the number of episodes done and the update frequency.

"""
stored_scale = list(self.rl_strats.values())[0].action_noise.scale

return stored_scale

def get_progress_remaining(self) -> float:
"""
Get the remaining learning progress from the simulation run.
"""
total_duration = self.end - self.start
elapsed_duration = self.context.current_timestamp - self.start

learning_episodes = (
self.training_episodes - self.episodes_collecting_initial_experience
)

if self.episodes_done < self.episodes_collecting_initial_experience:
progress_remaining = 1
else:
progress_remaining = (
1
- (
(self.episodes_done - self.episodes_collecting_initial_experience)
/ learning_episodes
)
- ((1 / learning_episodes) * (elapsed_duration / total_duration))
)

return progress_remaining

def create_learning_algorithm(self, algorithm: RLAlgorithm):
"""
@@ -329,6 +364,12 @@ def compare_and_save_policies(self, metrics: dict) -> bool:
logger.info(
f"Stopping training as no improvement above {self.early_stopping_threshold} in last {self.early_stopping_steps} evaluations for {metric}"
)
if (
self.learning_rate_schedule or self.action_noise_schedule
) is not None:
logger.info(
f"Learning rate schedule ({self.learning_rate_schedule}) or action noise schedule ({self.action_noise_schedule}) were scheduled to decay, further learning improvement can be possible. End value of schedule may not have been reached."
)

self.rl_algorithm.save_params(
directory=f"{self.trained_policies_save_path}/last_policies"
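Two worked examples of the new logic in this file, with made-up numbers: the default early_stopping_steps is chosen to exceed the number of validation runs, so early stopping stays off unless configured explicitly, and the remaining progress combines finished learning episodes with the position inside the current episode.

# Illustrative numbers only.

# Default early_stopping_steps: larger than the number of validations, i.e. disabled.
training_episodes = 50
validation_episodes_interval = 5
early_stopping_steps = int(training_episodes / validation_episodes_interval + 1)  # 11

# Remaining learning progress, mirroring get_progress_remaining() above.
episodes_done = 12
episodes_collecting_initial_experience = 2
learning_episodes = training_episodes - episodes_collecting_initial_experience  # 48
elapsed_fraction_of_episode = 0.25  # (current_timestamp - start) / (end - start)

progress_remaining = (
    1
    - (episodes_done - episodes_collecting_initial_experience) / learning_episodes
    - (1 / learning_episodes) * elapsed_fraction_of_episode
)
print(early_stopping_steps, round(progress_remaining, 4))  # 11 0.7865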
2 changes: 1 addition & 1 deletion assume/reinforcement_learning/learning_unit_operator.py
@@ -167,7 +167,7 @@ def write_learning_to_output(self, orderbook: Orderbook, market_id: str) -> None
receiver_addr=db_addr,
content={
"context": "write_results",
"type": "rl_learning_params",
"type": "rl_params",
"data": output_agent_list,
},
)
66 changes: 64 additions & 2 deletions assume/reinforcement_learning/learning_utils.py
@@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: AGPL-3.0-or-later

from collections.abc import Callable
from datetime import datetime
from typing import TypedDict

@@ -17,6 +18,10 @@ class ObsActRew(TypedDict):

observation_dict = dict[list[datetime], ObsActRew]

# A schedule takes the remaining progress as input
# and outputs a scalar (e.g. learning rate, action noise scale ...)
Schedule = Callable[[float], float]


# Ornstein-Uhlenbeck Noise
# from https://github.com/songrotek/DDPG/blob/master/ou_noise.py
@@ -64,10 +69,16 @@ def __init__(self, action_dimension, mu=0.0, sigma=0.1, scale=1.0, dt=0.9998):
self.dt = dt

def noise(self):
noise = self.scale * np.random.normal(self.mu, self.sigma, self.act_dimension)
self.scale = self.dt * self.scale # if self.scale >= 0.1 else self.scale
noise = (
self.dt
* self.scale
* np.random.normal(self.mu, self.sigma, self.act_dimension)
)
return noise

def update_noise_decay(self, updated_decay: float):
self.dt = updated_decay
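A small usage sketch of the revised noise behaviour above; the class name DecayedActionNoise is a stand-in, its body mirrors the diff. The decay factor dt now scales every drawn sample directly and is refreshed from the schedule via update_noise_decay, instead of shrinking the stored scale a little further on each call.

import numpy as np

class DecayedActionNoise:  # stand-in name, body mirrors the diff above
    def __init__(self, action_dimension, mu=0.0, sigma=0.1, scale=1.0, dt=0.9998):
        self.act_dimension = action_dimension
        self.mu, self.sigma, self.scale, self.dt = mu, sigma, scale, dt

    def noise(self):
        return self.dt * self.scale * np.random.normal(self.mu, self.sigma, self.act_dimension)

    def update_noise_decay(self, updated_decay: float):
        self.dt = updated_decay

noise = DecayedActionNoise(action_dimension=1)
print(abs(noise.noise()))      # amplitude governed by dt * scale (roughly 1.0 here)
noise.update_noise_decay(0.4)  # e.g. the scheduled decay late in training
print(abs(noise.noise()))      # same draw process, amplitude scaled down by 0.4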


def polyak_update(params, target_params, tau: float):
"""
@@ -91,3 +102,54 @@ def polyak_update(params, target_params, tau: float):
for param, target_param in zip(params, target_params):
target_param.data.mul_(1 - tau)
th.add(target_param.data, param.data, alpha=tau, out=target_param.data)


def linear_schedule_func(
start: float, end: float = 0, end_fraction: float = 1
) -> Schedule:
"""
Create a function that interpolates linearly between start and end
between ``progress_remaining`` = 1 and ``progress_remaining`` = 1 - ``end_fraction``.
Args:
start: value to start with if ``progress_remaining`` = 1
end: value to end with if ``progress_remaining`` = 0
end_fraction: fraction of ``progress_remaining``
where end is reached, e.g. 0.1 means end is reached after 10%
of the complete training process.
Returns:
Linear schedule function.
Note:
Adapted from SB3: https://github.com/DLR-RM/stable-baselines3/blob/512eea923afad6f6da4bb53d72b6ea4c6d856e59/stable_baselines3/common/utils.py#L100
"""

def func(progress_remaining: float) -> float:
if (1 - progress_remaining) > end_fraction:
return end
else:
return start + (1 - progress_remaining) * (end - start) / end_fraction

return func


def constant_schedule(val: float) -> Schedule:
"""
Create a function that returns a constant. It is useful for learning rate schedule (to avoid code duplication)
Args:
val: constant value
Returns:
Constant schedule function.
Note:
From SB3: https://github.com/DLR-RM/stable-baselines3/blob/512eea923afad6f6da4bb53d72b6ea4c6d856e59/stable_baselines3/common/utils.py#L124
"""

def func(_):
return val

return func
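A short usage sketch of the two schedule helpers defined above; the import path follows the diff, and the parameter values are examples:

from assume.reinforcement_learning.learning_utils import (
    constant_schedule,
    linear_schedule_func,
)

lr_schedule = linear_schedule_func(start=1e-4)  # decays linearly to 0 over the whole run
noise_schedule = linear_schedule_func(start=0.9998, end=0.1, end_fraction=0.8)
fixed = constant_schedule(1e-4)

print(lr_schedule(1.0), lr_schedule(0.5), lr_schedule(0.0))  # 0.0001 5e-05 0.0
print(noise_schedule(1.0), noise_schedule(0.1))              # 0.9998 0.1 (end reached after 80% of training)
print(fixed(0.3))                                            # 0.0001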
1 change: 0 additions & 1 deletion assume/scenario/loader_csv.py
@@ -913,7 +913,6 @@ def run_learning(
"avg_all_eval": [],
"episodes_done": 0,
"eval_episodes_done": 0,
"noise_scale": world.learning_config.get("noise_scale", 1.0),
}

# -----------------------------------------
5 changes: 4 additions & 1 deletion assume/world.py
@@ -254,7 +254,10 @@ def setup_learning(self) -> None:
# if so, we initiate the rl learning role with parameters
from assume.reinforcement_learning.learning_role import Learning

self.learning_role = Learning(self.learning_config)
self.learning_role = Learning(
self.learning_config, start=self.start, end=self.end
)

# separate process does not support buffer and learning
self.learning_agent_addr = addr(self.addr, "learning_agent")
rl_agent = agent_composed_of(