Merge pull request #256 from assume-framework/fix_avg_reward_query
adjust query of reward during training
kim-mskw authored Dec 7, 2023
2 parents 357f0ba + a1314f7 commit dd9a744
Showing 4 changed files with 12 additions and 10 deletions.
assume/common/outputs.py (14 changes: 8 additions & 6 deletions)

```diff
@@ -8,6 +8,7 @@
 from datetime import datetime
 from pathlib import Path
 
+import numpy as np
 import pandas as pd
 from dateutil import rrule as rr
 from mango import Role
@@ -415,13 +416,14 @@ async def on_stop(self):
 
     def get_sum_reward(self):
         query = text(
-            f"select reward FROM rl_params where simulation='{self.simulation_id}'"
+            f"select unit, SUM(reward) FROM rl_params where simulation='{self.simulation_id}' GROUP BY unit"
         )
 
-        avg_reward = 0
         with self.db.begin() as db:
-            reward = db.execute(query).fetchall()
-            if len(reward):
-                avg_reward = sum(r[0] for r in reward) / len(reward)
+            rewards_by_unit = db.execute(query).fetchall()
 
-        return avg_reward
+        # convert into a numpy array
+        rewards_by_unit = [r[1] for r in rewards_by_unit]
+        rewards_by_unit = np.array(rewards_by_unit)
+
+        return rewards_by_unit
```
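Why this matters: the old query averaged over individual reward rows (one per unit and timestep), while the new one sums rewards per unit and hands the per-unit totals to the caller, so the resulting metric becomes the mean total reward per unit rather than the mean per-row reward. A minimal sketch of the difference on made-up rows (unit names and values are illustrative only):

```python
import numpy as np

# hypothetical rl_params rows as (unit, reward), one row per unit per timestep
rows = [("pp_1", 1.0), ("pp_1", 2.0), ("pp_2", 4.0)]

# old behavior: mean over all reward rows
old_avg = sum(r[1] for r in rows) / len(rows)  # (1 + 2 + 4) / 3 ≈ 2.33

# new behavior: SUM(reward) GROUP BY unit, then the caller averages the totals
totals = {}
for unit, reward in rows:
    totals[unit] = totals.get(unit, 0.0) + reward
rewards_by_unit = np.array(list(totals.values()))  # array([3., 4.])
new_avg = rewards_by_unit.mean()                   # 3.5
```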
assume/common/scenario_loader.py (3 changes: 2 additions & 1 deletion)

```diff
@@ -647,7 +647,8 @@ def run_learning(world: World, inputs_path: str, scenario: str, study_case: str)
 
         world.run()
 
-        avg_reward = world.output_role.get_sum_reward()
+        total_rewards = world.output_role.get_sum_reward()
+        avg_reward = np.mean(total_rewards)
         # check reward improvement in validation run
         world.learning_config["trained_actors_path"] = old_path
         if avg_reward > best_reward:
```
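The caller now receives the array of per-unit totals and reduces it with `np.mean` before the checkpoint comparison (this assumes `numpy` is already imported in scenario_loader.py). A simplified sketch of the surrounding logic, with the `best_reward` bookkeeping reduced to its essence:

```python
import numpy as np

best_reward = -np.inf  # assumed initialization; the actual value lives in run_learning

total_rewards = np.array([3.0, 4.0])  # stand-in for world.output_role.get_sum_reward()
avg_reward = np.mean(total_rewards)   # scalar mean over per-unit totals

if avg_reward > best_reward:
    best_reward = avg_reward
    # ... keep the actors trained in this validation run (omitted) ...
```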
assume/common/units_operator.py (3 changes: 1 addition & 2 deletions)

```diff
@@ -20,7 +20,6 @@
     MetaDict,
     OpeningMessage,
     Orderbook,
-    OrderBookMessage,
     RegistrationMessage,
 )
 from assume.common.utils import aggregate_step_amount
@@ -523,7 +522,7 @@ def write_to_learning(
         all_observations = all_observations.squeeze().cpu().numpy()
         all_actions = all_actions.squeeze().cpu().numpy()
         all_rewards = np.array(all_rewards)
-        rl_agent_data = (np.array(all_observations), all_actions, all_rewards)
+        rl_agent_data = (all_observations, all_actions, all_rewards)
 
         learning_role_id = self.context.data_dict.get("learning_agent_id")
         learning_role_addr = self.context.data_dict.get("learning_agent_addr")
```
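The second hunk removes a redundant conversion: `.cpu().numpy()` already yields an ndarray, and wrapping an existing ndarray in `np.array(...)` makes a fresh copy by default, so the extra call was pure overhead. A quick illustration (assuming torch is installed; the shape is arbitrary):

```python
import numpy as np
import torch

t = torch.zeros(4, 3)
arr = t.squeeze().cpu().numpy()  # already a numpy.ndarray
assert isinstance(arr, np.ndarray)

copied = np.array(arr)           # default copy=True: allocates a new array
assert copied is not arr and np.array_equal(copied, arr)
```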
examples/inputs/example_02a/config.yaml (2 changes: 1 addition & 1 deletion)

```diff
@@ -17,7 +17,7 @@ base:
     max_bid_price: 100
     algorithm: matd3
     learning_rate: 0.001
-    training_episodes: 100
+    training_episodes: 50
     episodes_collecting_initial_experience: 5
     train_freq: 24
     gradient_steps: -1
```
