From 3983d34e3f7201961aeee6696399a284c99d5d64 Mon Sep 17 00:00:00 2001
From: Nick Harder
Date: Thu, 30 Nov 2023 11:22:08 +0100
Subject: [PATCH] Adjust query of reward during training

- before, it returned the mean of all rewards
- now the reward is aggregated per unit, which is better

---
 assume/common/outputs.py                | 14 ++++++++------
 assume/common/scenario_loader.py        |  3 ++-
 assume/common/units_operator.py         |  3 +--
 examples/examples.py                    |  4 ++--
 examples/inputs/example_02a/config.yaml |  2 +-
 5 files changed, 14 insertions(+), 12 deletions(-)

diff --git a/assume/common/outputs.py b/assume/common/outputs.py
index d17d84d4..079bac9b 100644
--- a/assume/common/outputs.py
+++ b/assume/common/outputs.py
@@ -8,6 +8,7 @@
 from datetime import datetime
 from pathlib import Path
 
+import numpy as np
 import pandas as pd
 from dateutil import rrule as rr
 from mango import Role
@@ -415,13 +416,14 @@ async def on_stop(self):
 
     def get_sum_reward(self):
         query = text(
-            f"select reward FROM rl_params where simulation='{self.simulation_id}'"
+            f"select unit, SUM(reward) FROM rl_params where simulation='{self.simulation_id}' GROUP BY unit"
         )
-        avg_reward = 0
         with self.db.begin() as db:
-            reward = db.execute(query).fetchall()
-            if len(reward):
-                avg_reward = sum(r[0] for r in reward) / len(reward)
+            rewards_by_unit = db.execute(query).fetchall()
 
-        return avg_reward
+        # convert into a numpy array
+        rewards_by_unit = [r[1] for r in rewards_by_unit]
+        rewards_by_unit = np.array(rewards_by_unit)
+
+        return rewards_by_unit

diff --git a/assume/common/scenario_loader.py b/assume/common/scenario_loader.py
index e149b10c..ec218b7a 100644
--- a/assume/common/scenario_loader.py
+++ b/assume/common/scenario_loader.py
@@ -647,7 +647,8 @@ def run_learning(world: World, inputs_path: str, scenario: str, study_case: str)
 
         world.run()
 
-        avg_reward = world.output_role.get_sum_reward()
+        total_rewards = world.output_role.get_sum_reward()
+        avg_reward = np.mean(total_rewards)
         # check reward improvement in validation run
         world.learning_config["trained_actors_path"] = old_path
         if avg_reward > best_reward:

diff --git a/assume/common/units_operator.py b/assume/common/units_operator.py
index ed166206..28c17036 100644
--- a/assume/common/units_operator.py
+++ b/assume/common/units_operator.py
@@ -20,7 +20,6 @@
     MetaDict,
     OpeningMessage,
     Orderbook,
-    OrderBookMessage,
     RegistrationMessage,
 )
 from assume.common.utils import aggregate_step_amount
@@ -523,7 +522,7 @@ def write_to_learning(
         all_observations = all_observations.squeeze().cpu().numpy()
         all_actions = all_actions.squeeze().cpu().numpy()
         all_rewards = np.array(all_rewards)
-        rl_agent_data = (np.array(all_observations), all_actions, all_rewards)
+        rl_agent_data = (all_observations, all_actions, all_rewards)
 
         learning_role_id = self.context.data_dict.get("learning_agent_id")
         learning_role_addr = self.context.data_dict.get("learning_agent_addr")

diff --git a/examples/examples.py b/examples/examples.py
index 3a1b3903..412124e3 100644
--- a/examples/examples.py
+++ b/examples/examples.py
@@ -58,8 +58,8 @@
     - local_db: without database and grafana
     - timescale: with database and grafana (note: you need docker installed)
     """
-    data_format = "timescale"  # "local_db" or "timescale"
-    example = "learning_small"
+    data_format = "local_db"  # "local_db" or "timescale"
+    example = "small"
 
     if data_format == "local_db":
         db_uri = f"sqlite:///./examples/local_db/assume_db_{example}.db"

diff --git a/examples/inputs/example_02a/config.yaml b/examples/inputs/example_02a/config.yaml
index b168c335..5b03fc76 100644
--- a/examples/inputs/example_02a/config.yaml
+++ b/examples/inputs/example_02a/config.yaml
@@ -17,7 +17,7 @@ base:
     max_bid_price: 100
     algorithm: matd3
     learning_rate: 0.001
-    training_episodes: 100
+    training_episodes: 50
     episodes_collecting_initial_experience: 5
     train_freq: 24
     gradient_steps: -1
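Note: the sketch below is not part of the patch. It is a minimal illustration, against an assumed three-column rl_params layout (simulation, unit, reward) in an in-memory SQLite database, of how the reworked query returns one summed reward per unit, which run_learning() then averages, compared with the old query's mean over all rows. The table layout, unit names, and values are assumptions for illustration only; the real ASSUME rl_params table has more columns.

import numpy as np
from sqlalchemy import create_engine, text

engine = create_engine("sqlite://")  # in-memory stand-in for the real output DB
simulation_id = "example_sim"

with engine.begin() as db:
    # assumed minimal layout of rl_params, for illustration only
    db.execute(text("CREATE TABLE rl_params (simulation TEXT, unit TEXT, reward REAL)"))
    db.execute(
        text("INSERT INTO rl_params VALUES (:sim, :unit, :reward)"),
        [
            {"sim": simulation_id, "unit": "pp_1", "reward": 1.0},
            {"sim": simulation_id, "unit": "pp_1", "reward": 3.0},
            {"sim": simulation_id, "unit": "pp_2", "reward": -2.0},
        ],
    )

# old behaviour: a single mean over every row, regardless of unit
with engine.begin() as db:
    rows = db.execute(
        text(f"select reward FROM rl_params where simulation='{simulation_id}'")
    ).fetchall()
old_avg = sum(r[0] for r in rows) / len(rows)  # (1 + 3 - 2) / 3 = 0.67

# new behaviour: one summed reward per unit, averaged afterwards in run_learning()
with engine.begin() as db:
    rewards_by_unit = db.execute(
        text(
            f"select unit, SUM(reward) FROM rl_params "
            f"where simulation='{simulation_id}' GROUP BY unit"
        )
    ).fetchall()
rewards_by_unit = np.array([r[1] for r in rewards_by_unit])  # [4.0, -2.0]
new_avg = np.mean(rewards_by_unit)  # (4 - 2) / 2 = 1.0

print(old_avg, rewards_by_unit, new_avg)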