diff --git a/assume/common/base.py b/assume/common/base.py index 8d6145249..4b8f6a8b6 100644 --- a/assume/common/base.py +++ b/assume/common/base.py @@ -68,17 +68,19 @@ def __init__( for strategy in self.bidding_strategies.values() ): self.outputs["actions"] = TensorFastSeries(value=0.0, index=self.index) - self.outputs["exploration_noise"] = TensorFastSeries( - value=0.0, - index=self.index, - ) self.outputs["reward"] = FastSeries(value=0.0, index=self.index) self.outputs["regret"] = FastSeries(value=0.0, index=self.index) - # RL data stored as lists to simplify storing to the buffer - self.outputs["rl_observations"] = [] - self.outputs["rl_actions"] = [] - self.outputs["rl_rewards"] = [] + self.avg_op_time = 0 + self.total_op_time = 0 + + + # RL data stored as lists to simplify storing to the buffer + self.outputs["rl_observations"] = [] + self.outputs["rl_actions"] = [] + self.outputs["rl_rewards"] = [] + self.outputs["rl_log_probs"] = [] + def calculate_bids( self, @@ -742,6 +744,14 @@ def __init__( # them into suitable format for recurrent neural networks self.num_timeseries_obs_dim = num_timeseries_obs_dim + self.rl_algorithm_name = kwargs.get("algorithm", "matd3") + if self.rl_algorithm_name == "matd3": + from assume.reinforcement_learning.algorithms.matd3 import get_actions + self.get_actions = get_actions + elif self.rl_algorithm_name == "ppo": + from assume.reinforcement_learning.algorithms.ppo import get_actions + self.get_actions = get_actions + class LearningConfig(TypedDict): """ diff --git a/assume/reinforcement_learning/__init__.py b/assume/reinforcement_learning/__init__.py index a10131609..152fcbbdb 100644 --- a/assume/reinforcement_learning/__init__.py +++ b/assume/reinforcement_learning/__init__.py @@ -3,4 +3,5 @@ # SPDX-License-Identifier: AGPL-3.0-or-later from assume.reinforcement_learning.buffer import ReplayBuffer +from assume.reinforcement_learning.buffer import RolloutBuffer from assume.reinforcement_learning.learning_role import Learning diff --git a/assume/reinforcement_learning/algorithms/__init__.py b/assume/reinforcement_learning/algorithms/__init__.py index 645e5c991..cb23e79b6 100644 --- a/assume/reinforcement_learning/algorithms/__init__.py +++ b/assume/reinforcement_learning/algorithms/__init__.py @@ -7,9 +7,11 @@ from assume.reinforcement_learning.neural_network_architecture import ( MLPActor, LSTMActor, + DistActor, ) actor_architecture_aliases: dict[str, type[nn.Module]] = { "mlp": MLPActor, "lstm": LSTMActor, + "dist": DistActor, } diff --git a/assume/reinforcement_learning/algorithms/base_algorithm.py b/assume/reinforcement_learning/algorithms/base_algorithm.py index 11aa71f2b..56fd0acea 100644 --- a/assume/reinforcement_learning/algorithms/base_algorithm.py +++ b/assume/reinforcement_learning/algorithms/base_algorithm.py @@ -34,32 +34,17 @@ def __init__( # init learning_role as object of Learning class learning_role, learning_rate=1e-4, - episodes_collecting_initial_experience=100, batch_size=1024, - tau=0.005, gamma=0.99, - gradient_steps=-1, - policy_delay=2, - target_policy_noise=0.2, - target_noise_clip=0.5, actor_architecture="mlp", + **kwargs, # allow additional params for specific algorithms ): super().__init__() self.learning_role = learning_role self.learning_rate = learning_rate - self.episodes_collecting_initial_experience = ( - episodes_collecting_initial_experience - ) self.batch_size = batch_size self.gamma = gamma - self.tau = tau - - self.gradient_steps = gradient_steps - - self.policy_delay = policy_delay - self.target_noise_clip = 
target_noise_clip - self.target_policy_noise = target_policy_noise if actor_architecture in actor_architecture_aliases.keys(): self.actor_architecture_class = actor_architecture_aliases[ diff --git a/assume/reinforcement_learning/algorithms/matd3.py b/assume/reinforcement_learning/algorithms/matd3.py index a84857bd1..cb26e2448 100644 --- a/assume/reinforcement_learning/algorithms/matd3.py +++ b/assume/reinforcement_learning/algorithms/matd3.py @@ -11,7 +11,7 @@ from assume.common.base import LearningStrategy from assume.reinforcement_learning.algorithms.base_algorithm import RLAlgorithm -from assume.reinforcement_learning.learning_utils import polyak_update +from assume.reinforcement_learning.learning_utils import polyak_update, collect_obs_for_central_critic from assume.reinforcement_learning.neural_network_architecture import CriticTD3 logger = logging.getLogger(__name__) @@ -28,7 +28,7 @@ class TD3(RLAlgorithm): Original paper: https://arxiv.org/pdf/1802.09477.pdf """ - + def __init__( self, learning_role, @@ -46,16 +46,16 @@ def __init__( super().__init__( learning_role, learning_rate, - episodes_collecting_initial_experience, batch_size, - tau, gamma, - gradient_steps, - policy_delay, - target_policy_noise, - target_noise_clip, actor_architecture, ) + self.episodes_collecting_initial_experience = episodes_collecting_initial_experience + self.tau = tau + self.gradient_steps = gradient_steps + self.policy_delay = policy_delay + self.target_policy_noise = target_policy_noise + self.target_noise_clip = target_noise_clip self.n_updates = 0 def save_params(self, directory): @@ -201,6 +201,8 @@ def load_actor_params(self, directory: str) -> None: except Exception: logger.warning(f"No actor values loaded for agent {u_id}") + + def initialize_policy(self, actors_and_critics: dict = None) -> None: """ Create actor and critic networks for reinforcement learning. @@ -293,7 +295,7 @@ def create_critics(self) -> None: This method initializes critic networks for each agent in the reinforcement learning setup. Notes: - The observation dimension need to be the same, due to the centralized criic that all actors share. + The observation dimension need to be the same, due to the centralized critic that all actors share. If you have units with different observation dimensions. They need to have different critics and hence learning roles. 
""" n_agents = len(self.learning_role.rl_strats) @@ -458,47 +460,14 @@ def update_policy(self): all_actions = actions.view(self.batch_size, -1) - # this takes the unique observations from all other agents assuming that - # the unique observations are at the end of the observation vector - temp = th.cat( - ( - states[:, :i, self.obs_dim - self.unique_obs_dim :].reshape( - self.batch_size, -1 - ), - states[ - :, i + 1 :, self.obs_dim - self.unique_obs_dim : - ].reshape(self.batch_size, -1), - ), - axis=1, + #collect observations for critic + all_states = collect_obs_for_central_critic( + states, i, self.obs_dim, self.unique_obs_dim, self.batch_size ) - - # the final all_states vector now contains the current agent's observation - # and the unique observations from all other agents - all_states = th.cat( - (states[:, i, :].reshape(self.batch_size, -1), temp), axis=1 - ).view(self.batch_size, -1) - # all_states = states[:, i, :].reshape(self.batch_size, -1) - - # this is the same as above but for the next states - temp = th.cat( - ( - next_states[ - :, :i, self.obs_dim - self.unique_obs_dim : - ].reshape(self.batch_size, -1), - next_states[ - :, i + 1 :, self.obs_dim - self.unique_obs_dim : - ].reshape(self.batch_size, -1), - ), - axis=1, + all_next_states = collect_obs_for_central_critic( + next_states, i, self.obs_dim, self.unique_obs_dim, self.batch_size ) - # the final all_next_states vector now contains the current agent's observation - # and the unique observations from all other agents - all_next_states = th.cat( - (next_states[:, i, :].reshape(self.batch_size, -1), temp), axis=1 - ).view(self.batch_size, -1) - # all_next_states = next_states[:, i, :].reshape(self.batch_size, -1) - with th.no_grad(): # Compute the next Q-values: min over all critics targets next_q_values = th.cat( @@ -548,3 +517,73 @@ def update_policy(self): actor.parameters(), actor_target.parameters(), self.tau ) i += 1 + + +def get_actions(rl_strategy, next_observation): + """ + Gets actions for a unit based on the observation using MATD3. + + Args: + rl_strategy (RLStrategy): The strategy containing relevant information. + next_observation (torch.Tensor): The observation. + + Returns: + torch.Tensor: The actions containing two bid prices. + tuple: The noise (if applicable). + + Note: + If the agent is in learning mode, the actions are chosen by the actor neuronal net and noise is added to the action. + In the first x episodes, the agent is in initial exploration mode, where the action is chosen by noise only to explore + the entire action space. X is defined by episodes_collecting_initial_experience. + If the agent is not in learning mode, the actions are chosen by the actor neuronal net without noise. 
+ """ + + actor = rl_strategy.actor + device = rl_strategy.device + float_type = rl_strategy.float_type + act_dim = rl_strategy.act_dim + learning_mode = rl_strategy.learning_mode + perform_evaluation = rl_strategy.perform_evaluation + action_noise = rl_strategy.action_noise + collect_initial_experience_mode = rl_strategy.collect_initial_experience_mode + + # distinction whether we are in learning mode or not to handle exploration realised with noise + if learning_mode and not perform_evaluation: + # if we are in learning mode the first x episodes we want to explore the entire action space + # to get a good initial experience, in the area around the costs of the agent + if collect_initial_experience_mode: + # define current action as solely noise + noise = ( + th.normal(mean=0.0, std=0.2, size=(1, act_dim), dtype=float_type) + .to(device) + .squeeze() + ) + + # ============================================================================= + # 2.1 Get Actions and handle exploration + # ============================================================================= + base_bid = next_observation[-1] + + # add noise to the last dimension of the observation + # needs to be adjusted if observation space is changed, because only makes sense + # if the last dimension of the observation space are the marginal cost + curr_action = noise + base_bid.clone().detach() + + else: + # if we are not in the initial exploration phase we choose the action with the actor neural net + # and add noise to the action + curr_action = actor(next_observation).detach() # calls the forward method of the actor network + noise = th.tensor( + action_noise.noise(), device=device, dtype=float_type + ) + curr_action += noise + else: + # if we are not in learning mode we just use the actor neural net to get the action without adding noise + curr_action = actor(next_observation).detach() + noise = tuple(0 for _ in range(act_dim)) + + # Clamp actions to be within the valid action space bounds + curr_action = curr_action.clamp(-1, 1) + + return curr_action, noise + diff --git a/assume/reinforcement_learning/algorithms/ppo.py b/assume/reinforcement_learning/algorithms/ppo.py new file mode 100644 index 000000000..788719bfc --- /dev/null +++ b/assume/reinforcement_learning/algorithms/ppo.py @@ -0,0 +1,604 @@ +# SPDX-FileCopyrightText: ASSUME Developers +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +import logging +import os + +import torch as th +from torch.nn import functional as F +from torch.optim import Adam + +from assume.common.base import LearningStrategy +from assume.reinforcement_learning.algorithms.base_algorithm import RLAlgorithm +from assume.reinforcement_learning.neural_network_architecture import CriticPPO +from assume.reinforcement_learning.learning_utils import collect_obs_for_central_critic + + +logger = logging.getLogger(__name__) + + +class PPO(RLAlgorithm): + """ + Proximal Policy Optimization (PPO) is a robust and efficient policy gradient method for reinforcement learning. + It strikes a balance between trust-region methods and simpler approaches by using clipped objective functions. + PPO avoids large updates to the policy by restricting changes to stay within a specified range, which helps stabilize training. + The key improvements include the introduction of a surrogate objective that limits policy updates and ensures efficient learning, + as well as the use of multiple epochs of stochastic gradient descent on batches of data. 
+ + Open AI Spinning guide: https://spinningup.openai.com/en/latest/algorithms/ppo.html# + + Original paper: https://arxiv.org/pdf/1802.09477.pdf + """ + + # Change order and mandatory parameters in the superclass, removed and newly added parameters + def __init__( + self, + learning_role, + learning_rate: float, + gamma: float, # Discount factor for future rewards + gradient_steps: int, # Number of steps for updating the policy + clip_ratio: float, # Clipping parameter for policy updates + vf_coef: float, # Value function coefficient in the loss function + entropy_coef: float, # Entropy coefficient for exploration + max_grad_norm: float, # Gradient clipping value + gae_lambda: float, # GAE lambda for advantage estimation + actor_architecture: str, + ): + super().__init__( + learning_role=learning_role, + learning_rate=learning_rate, + gamma=gamma, + actor_architecture=actor_architecture, + ) + self.gradient_steps = gradient_steps + self.clip_ratio = clip_ratio + self.vf_coef = vf_coef + self.entropy_coef = entropy_coef + self.max_grad_norm = max_grad_norm + self.gae_lambda = gae_lambda + self.n_updates = 0 # Number of updates performed + self.batch_size = learning_role.batch_size + + # write error if different actor_architecture than dist is used + if actor_architecture != "dist": + raise ValueError( + "PPO only supports the 'dist' actor architecture. Please define 'dist' as actor architecture in config." + ) + + # Unchanged method from MATD3 + def save_params(self, directory): + """ + This method saves the parameters of both the actor and critic networks associated with the learning role. It organizes the + saved parameters into separate directories for critics and actors within the specified base directory. + + Args: + directory (str): The base directory for saving the parameters. + """ + self.save_critic_params(directory=f"{directory}/critics") + self.save_actor_params(directory=f"{directory}/actors") + + + + # Centralized + def save_critic_params(self, directory): + """ + Save the parameters of critic networks. + + This method saves the parameters of the critic networks, including the critic's state_dict, critic_target's state_dict. It organizes the saved parameters into a directory structure specific to the critic + associated with each learning strategy. + + Args: + directory (str): The base directory for saving the parameters. + """ + os.makedirs(directory, exist_ok=True) + for u_id in self.learning_role.rl_strats.keys(): + obj = { + "critic": self.learning_role.critics[u_id].state_dict(), + # "critic_target": self.learning_role.target_critics[u_id].state_dict(), + "critic_optimizer": self.learning_role.critics[ + u_id + ].optimizer.state_dict(), + } + path = f"{directory}/critic_{u_id}.pt" + th.save(obj, path) + + # Removed actor_target in comparison to MATD3 + def save_actor_params(self, directory): + """ + Save the parameters of actor networks. + + This method saves the parameters of the actor networks, including the actor's state_dict, actor_target's state_dict, and + the actor's optimizer state_dict. It organizes the saved parameters into a directory structure specific to the actor + associated with each learning strategy. + + Args: + directory (str): The base directory for saving the parameters. 
+ """ + os.makedirs(directory, exist_ok=True) + for u_id in self.learning_role.rl_strats.keys(): + obj = { + "actor": self.learning_role.rl_strats[u_id].actor.state_dict(), + # "actor_target": self.learning_role.rl_strats[ + # u_id + # ].actor_target.state_dict(), + "actor_optimizer": self.learning_role.rl_strats[ + u_id + ].actor.optimizer.state_dict(), + } + path = f"{directory}/actor_{u_id}.pt" + th.save(obj, path) + + # Unchanged method from MATD3 + def load_params(self, directory: str) -> None: + """ + Load the parameters of both actor and critic networks. + + This method loads the parameters of both the actor and critic networks associated with the learning role from the specified + directory. It uses the `load_critic_params` and `load_actor_params` methods to load the respective parameters. + + Args: + directory (str): The directory from which the parameters should be loaded. + """ + self.load_critic_params(directory) + self.load_actor_params(directory) + + + + # Centralized + def load_critic_params(self, directory: str) -> None: + """ + Load the parameters of critic networks from a specified directory. + + This method loads the parameters of critic networks, including the critic's state_dict, critic_target's state_dict, and + the critic's optimizer state_dict, from the specified directory. It iterates through the learning strategies associated + with the learning role, loads the respective parameters, and updates the critic and target critic networks accordingly. + + Args: + directory (str): The directory from which the parameters should be loaded. + """ + logger.info("Loading critic parameters...") + + if not os.path.exists(directory): + logger.warning( + "Specified directory for loading the critics does not exist! Starting with randomly initialized values!" + ) + return + + for u_id in self.learning_role.rl_strats.keys(): + try: + critic_params = self.load_obj( + directory=f"{directory}/critics/critic_{str(u_id)}.pt" + ) + self.learning_role.critics[u_id].load_state_dict( + critic_params["critic"] + ) + self.learning_role.critics[u_id].optimizer.load_state_dict( + critic_params["critic_optimizer"] + ) + except Exception: + logger.warning(f"No critic values loaded for agent {u_id}") + + # Removed actor_target in comparison to MATD3 + def load_actor_params(self, directory: str) -> None: + """ + Load the parameters of actor networks from a specified directory. + + This method loads the parameters of actor networks, including the actor's state_dict, actor_target's state_dict, and + the actor's optimizer state_dict, from the specified directory. It iterates through the learning strategies associated + with the learning role, loads the respective parameters, and updates the actor and target actor networks accordingly. + + Args: + directory (str): The directory from which the parameters should be loaded. + """ + logger.info("Loading actor parameters...") + if not os.path.exists(directory): + logger.warning( + "Specified directory for loading the actors does not exist! Starting with randomly initialized values!" 
+ ) + return + + for u_id in self.learning_role.rl_strats.keys(): + try: + actor_params = self.load_obj( + directory=f"{directory}/actors/actor_{str(u_id)}.pt" + ) + self.learning_role.rl_strats[u_id].actor.load_state_dict( + actor_params["actor"] + ) + # self.learning_role.rl_strats[u_id].actor_target.load_state_dict( + # actor_params["actor_target"] + # ) + self.learning_role.rl_strats[u_id].actor.optimizer.load_state_dict( + actor_params["actor_optimizer"] + ) + except Exception: + logger.warning(f"No actor values loaded for agent {u_id}") + + + # Centralized + def initialize_policy(self, actors_and_critics: dict = None) -> None: + """ + Create actor and critic networks for reinforcement learning. + + If `actors_and_critics` is None, this method creates new actor and critic networks. + If `actors_and_critics` is provided, it assigns existing networks to the respective attributes. + + Args: + actors_and_critics (dict): The actor and critic networks to be assigned. + + """ + if actors_and_critics is None: + self.create_actors() + self.create_critics() + + else: + self.learning_role.critics = actors_and_critics["critics"] + # self.learning_role.target_critics = actors_and_critics["target_critics"] + for u_id, unit_strategy in self.learning_role.rl_strats.items(): + unit_strategy.actor = actors_and_critics["actors"][u_id] + # unit_strategy.actor_target = actors_and_critics["actor_targets"][u_id] + + self.obs_dim = actors_and_critics["obs_dim"] + self.act_dim = actors_and_critics["act_dim"] + self.unique_obs_dim = actors_and_critics["unique_obs_dim"] + + # Removed actor_target in comparison to MATD3 + def create_actors(self) -> None: + """ + Create actor networks for reinforcement learning for each unit strategy. + + This method initializes actor networks and their corresponding target networks for each unit strategy. + The actors are designed to map observations to action probabilities in a reinforcement learning setting. + + The created actor networks are associated with each unit strategy and stored as attributes. + + Notes: + The observation dimension need to be the same, due to the centralized criic that all actors share. + If you have units with different observation dimensions. They need to have different critics and hence learning roles. + + """ + + obs_dim_list = [] + act_dim_list = [] + + for _, unit_strategy in self.learning_role.rl_strats.items(): + unit_strategy.actor = self.actor_architecture_class( + obs_dim=unit_strategy.obs_dim, + act_dim=unit_strategy.act_dim, + float_type=self.float_type, + unique_obs_dim=unit_strategy.unique_obs_dim, + num_timeseries_obs_dim=unit_strategy.num_timeseries_obs_dim, + ).to(self.device) + + unit_strategy.actor.optimizer = Adam( + unit_strategy.actor.parameters(), lr=self.learning_rate + ) + + obs_dim_list.append(unit_strategy.obs_dim) + act_dim_list.append(unit_strategy.act_dim) + + if len(set(obs_dim_list)) > 1: + raise ValueError( + "All observation dimensions must be the same for all RL agents" + ) + else: + self.obs_dim = obs_dim_list[0] + + if len(set(act_dim_list)) > 1: + raise ValueError("All action dimensions must be the same for all RL agents") + else: + self.act_dim = act_dim_list[0] + + + # Centralized + def create_critics(self) -> None: + """ + Create decentralized critic networks for reinforcement learning. + + This method initializes a separate critic network for each agent in the reinforcement learning setup. + Each critic learns to predict the value function based on the individual agent's observation. 
+ + Notes: + Each agent has its own critic, so the critic is no longer shared among all agents. + """ + + n_agents = len(self.learning_role.rl_strats) + strategy: LearningStrategy + unique_obs_dim_list = [] + + for u_id, strategy in self.learning_role.rl_strats.items(): + self.learning_role.critics[u_id] = CriticPPO( + n_agents=n_agents, + obs_dim=strategy.obs_dim, + act_dim=strategy.act_dim, + unique_obs_dim=strategy.unique_obs_dim, + float_type=self.float_type, + ) + + self.learning_role.critics[u_id].optimizer = Adam( + self.learning_role.critics[u_id].parameters(), lr=self.learning_rate + ) + + self.learning_role.critics[u_id] = self.learning_role.critics[u_id].to( + self.device + ) + + unique_obs_dim_list.append(strategy.unique_obs_dim) + + # check if all unique_obs_dim are the same and raise an error if not + # if they are all the same, set the unique_obs_dim attribute + if len(set(unique_obs_dim_list)) > 1: + raise ValueError( + "All unique_obs_dim values must be the same for all RL agents" + ) + else: + self.unique_obs_dim = unique_obs_dim_list[0] + + + # Centralized + def extract_policy(self) -> dict: + """ + Extract actor and critic networks. + + This method extracts the actor and critic networks associated with each learning strategy and organizes them into a + dictionary structure. The extracted networks include actors, and critics. The resulting + dictionary is typically used for saving and sharing these networks. + + Returns: + dict: The extracted actor and critic networks. + """ + actors = {} + + for u_id, unit_strategy in self.learning_role.rl_strats.items(): + actors[u_id] = unit_strategy.actor + + actors_and_critics = { + "actors": actors, + "critics": self.learning_role.critics, + "obs_dim": self.obs_dim, + "act_dim": self.act_dim, + "unique_obs_dim": self.unique_obs_dim, + } + + return actors_and_critics + + def get_values(self, states, actions): + """ + Gets values for a unit based on the observation using PPO. + + Args: + rl_strategy (RLStrategy): The strategy containing relevant information. + next_observation (torch.Tensor): The observation. + + Returns: + torch.Tensor: The value of the observation. + """ + #counter iterating over all agents for dynamic buffer slice + i=0 + + #get length of all states to pass it on as batch size, since the entire buffer is used for the PPO + buffer_length = len(states) + all_actions = actions.view(buffer_length, -1) + # Initialize an empty tensor to store all values + all_values = th.empty(0, buffer_length, 1) + + for u_id in self.learning_role.rl_strats.keys(): + + all_states = collect_obs_for_central_critic(states, i, self.obs_dim, self.unique_obs_dim, buffer_length) + + critic = self.learning_role.critics[u_id] + + # Pass the current states through the critic network to get value estimates. 
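+            # Shape note (derived from collect_obs_for_central_critic and the view above):
+            # all_states has shape (buffer_length, obs_dim + (n_agents - 1) * unique_obs_dim),
+            # i.e. agent i's full observation plus only the unique tail of every other agent's
+            # observation, while all_actions stacks every agent's action into
+            # shape (buffer_length, n_agents * act_dim).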
+ values = critic(all_states, all_actions) + + if all_values.numel() == 0: + all_values = values + else: + all_values = th.cat((all_values, values), dim=1) + + i=i+1 + + return all_values + + def get_advantages(self, rewards, values): + + # Compute advantages using Generalized Advantage Estimation (GAE) + advantages = [] + advantage = 0 + returns = [] + + # Iterate through the collected experiences in reverse order to calculate advantages and returns + for t in reversed(range(len(rewards))): + + logger.debug(f"Reward: {t}") + + if t == len(rewards) - 1: + next_value = 0 + else: + next_value = values[t + 1] + + # Temporal difference delta Equation 12 from PPO paper + delta = ( + - values[t] + rewards[t] + self.gamma * next_value + ) # Use self.gamma for discount factor + + logger.debug(f"Delta: {delta}") + + # GAE advantage Equation 11 from PPO paper + advantage = ( + delta + self.gamma * self.gae_lambda * advantage + ) # Use self.gae_lambda for advantage estimation + + logger.debug(f"Last_advantage: {advantage}") + + advantages.insert(0, advantage) + returns.insert(0, advantage + values[t]) + + # Convert advantages and returns to tensors + advantages = th.tensor(advantages, dtype=th.float32, device=self.device) + returns = th.tensor(returns, dtype=th.float32, device=self.device) + + #Normalize advantages + #in accordance with spinning up and mappo version of PPO + mean_advantages = th.nanmean(advantages) + std_advantages = th.std(advantages) + advantages = (advantages - mean_advantages) / (std_advantages + 1e-5) + + #TODO: Should we detach here? I though becaus of normalisation not being included in backward + # but unsure if this is correct + return advantages, returns + + + def update_policy(self): + """ + Perform policy updates using PPO with the clipped objective. + """ + + logger.debug("Updating Policy") + # We will iterate for multiple epochs to update both the policy (actor) and value (critic) networks + # The number of epochs controls how many times we update using the same collected data (from the buffer). + + + # Retrieve experiences from the buffer + # The collected experiences (observations, actions, rewards, log_probs) are stored in the buffer. + full_transitions = self.learning_role.buffer.get() + full_values = self.get_values(full_transitions.observations, full_transitions.actions) + full_advantages, full_returns = self.get_advantages(full_transitions.rewards, full_values) + + + + for _ in range(self.gradient_steps): + self.n_updates += 1 + + transitions, batch_inds = self.learning_role.buffer.sample(self.batch_size) + states = transitions.observations + actions = transitions.actions + log_probs = transitions.log_probs + advantages = full_advantages[batch_inds] + returns = full_returns[batch_inds] + values = self.get_values(states, actions) # always use updated values --> check later if best + + # Iterate through over each agent's strategy + # Each agent has its own actor. Critic (value network) is centralized. 
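+            # Background for the per-agent loop below (explanatory comment):
+            # full_advantages above follow the GAE recursion
+            #     delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
+            #     A_t     = delta_t + gamma * lambda * A_{t+1}
+            # and each agent's actor is then updated with the clipped surrogate objective.
+            # Worked example with clip_ratio = 0.2: for an advantage of 2.0 and a probability
+            # ratio of 1.5, the ratio is clipped to 1.2, so min(1.5 * 2.0, 1.2 * 2.0) = 2.4
+            # caps how far a single minibatch can push the policy towards that action.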
+ for u_id in self.learning_role.rl_strats.keys(): + + # Centralized + critic = self.learning_role.critics[u_id] + # Decentralized + actor = self.learning_role.rl_strats[u_id].actor + + + # Evaluate the new log-probabilities and entropy under the current policy + action_distribution = actor(states)[1] + new_log_probs = action_distribution.log_prob(actions).sum(-1) + + + entropy = action_distribution.entropy().sum(-1) + + # Compute the ratio of new policy to old policy + ratio = (new_log_probs - log_probs).exp() + + logger.debug(f"Ratio: {ratio}") + + # Surrogate loss calculation + surrogate1 = ratio * advantages + surrogate2 = ( + th.clamp(ratio, 1.0 - self.clip_ratio, 1.0 + self.clip_ratio) + * advantages + ) # Use self.clip_ratio + + logger.debug(f"surrogate1: {surrogate1}") + logger.debug(f"surrogate2: {surrogate2}") + + # Final policy loss (clipped surrogate loss) equation 7 from PPO paper + policy_loss = th.min(surrogate1, surrogate2).mean() + + logger.debug(f"policy_loss: {policy_loss}") + + # Value loss (mean squared error between the predicted values and returns) + value_loss = F.mse_loss(returns, values.squeeze()) + + logger.debug(f"value loss: {value_loss}") + + # Total loss: policy loss + value loss - entropy bonus + # euqation 9 from PPO paper multiplied with -1 to enable minimizing + total_loss = ( + - policy_loss + + self.vf_coef * value_loss + - self.entropy_coef * entropy.mean() + ) # Use self.vf_coef and self.entropy_coef + + logger.debug(f"total loss: {total_loss}") + + # Zero the gradients and perform backpropagation for both actor and critic + actor.optimizer.zero_grad() + critic.optimizer.zero_grad() + total_loss.backward(retain_graph=True) + + # Clip gradients to prevent gradient explosion + th.nn.utils.clip_grad_norm_( + actor.parameters(), self.max_grad_norm + ) # Use self.max_grad_norm + th.nn.utils.clip_grad_norm_( + critic.parameters(), self.max_grad_norm + ) # Use self.max_grad_norm + + # Perform optimization steps + actor.optimizer.step() + critic.optimizer.step() + + + + +def get_actions(rl_strategy, next_observation): + """ + Gets actions for a unit based on the observation using PPO. + + Args: + rl_strategy (RLStrategy): The strategy containing relevant information. + next_observation (torch.Tensor): The observation. + + Returns: + torch.Tensor: The sampled actions. + torch.Tensor: The log probability of the sampled actions. + """ + logger.debug("ppo.py: Get_actions method") + + actor = rl_strategy.actor + device = rl_strategy.device + learning_mode = rl_strategy.learning_mode + perform_evaluation = rl_strategy.perform_evaluation + + # Pass observation through the actor network to get action logits (mean of action distribution) + action_logits, action_distribution = actor(next_observation) + action_logits = action_logits.detach() + logger.debug(f"Action logits: {action_logits}") + + logger.debug(f"Action distribution: {action_distribution}") + + if learning_mode and not perform_evaluation: + + # Sample an action from the distribution + sampled_action = action_distribution.sample().to(device) + + else: + # If not in learning mode or during evaluation, use the mean of the action distribution + sampled_action = action_logits.detach() + + logger.debug(f"Sampled action: {sampled_action}") + + # Get the log probability of the sampled actions (for later PPO loss calculation) + # Sum the log probabilities across all action dimensions TODO: Why sum? 
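+    # Answer to the TODO above: the action dimensions are modelled as independent, so the
+    # joint log-probability of the multi-dimensional action is the sum of the per-dimension
+    # log-probabilities (the log of a product of densities). Assuming action_distribution is
+    # a torch.distributions.Normal, an equivalent formulation would be
+    #     th.distributions.Independent(action_distribution, 1).log_prob(sampled_action)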
+ log_prob_action = action_distribution.log_prob(sampled_action).sum(dim=-1) + + # Detach the log probability tensor to stop gradient tracking (since we only need the value for later) + log_prob_action = log_prob_action.detach() + + logger.debug(f"Detached log probability of the sampled action: {log_prob_action}") + + + return sampled_action, log_prob_action + + + diff --git a/assume/reinforcement_learning/buffer.py b/assume/reinforcement_learning/buffer.py index efcc1f2ba..b84de1056 100644 --- a/assume/reinforcement_learning/buffer.py +++ b/assume/reinforcement_learning/buffer.py @@ -141,6 +141,8 @@ def add( self.rewards[self.pos : self.pos + len_obs] = reward.copy() self.pos += len_obs + + # Circular buffer if self.pos + len_obs >= self.buffer_size: self.full = True self.pos = 0 @@ -165,10 +167,359 @@ def sample(self, batch_size: int) -> ReplayBufferSamples: batch_inds = np.random.randint(0, upper_bound - 1, size=batch_size) data = ( - self.observations[batch_inds, :, :], + self.observations[batch_inds, :, :], # current observation self.actions[batch_inds, :, :], - self.observations[batch_inds + 1, :, :], + self.observations[batch_inds + 1, :, :], # next observation self.rewards[batch_inds], ) return ReplayBufferSamples(*tuple(map(self.to_torch, data))) + + +class RolloutBufferTransitions(NamedTuple): + """ + A named tuple that represents the data stored in a rollout buffer for PPO. + + Attributes: + observations (torch.Tensor): The observations of the agents. + actions (torch.Tensor): The actions taken by the agents. + log_probs (torch.Tensor): The log probabilities of the actions taken. + advantages (torch.Tensor): The advantages calculated using GAE. + returns (torch.Tensor): The returns (discounted rewards) calculated. + """ + + observations: th.Tensor + actions: th.Tensor + rewards: th.Tensor + log_probs: th.Tensor + + +class RolloutBuffer: + def __init__( + self, + obs_dim: int, + act_dim: int, + n_rl_units: int, + device: str, + float_type, + buffer_size: int, + ): + """ + A class that represents a rollout buffer for storing observations, actions, and rewards. + The buffer starts empty and is dynamically expanded when needed. + + Args: + obs_dim (int): The dimension of the observation space. + act_dim (int): The dimension of the action space. + n_rl_units (int): The number of reinforcement learning units. + device (str): The device to use for storing the data (e.g., 'cpu' or 'cuda'). + float_type (torch.dtype): The data type to use for the stored data. 
+ buffer_size (int): The maximal size of the buffer + """ + + self.obs_dim = obs_dim + self.act_dim = act_dim + self.n_rl_units = n_rl_units + self.device = device + self.buffer_size = buffer_size + + # Start with no buffer (None), will be created dynamically when first data is added + self.observations = ( + None # Stores the agent's observations (states) at each timestep + ) + self.actions = None # Stores the actions taken by the agent + self.rewards = None # Stores the rewards received after each action + self.log_probs = None # Stores the log-probabilities of the actions, used to compute the ratio for policy update + + # self.values = ( + # None # Stores the value estimates (critic's predictions) of each state + # ) + # self.advantages = None # Stores the computed advantages using GAE (Generalized Advantage Estimation), central to PPO's policy updates + # self.returns = None # Stores the discounted rewards (also known as returns), used to compute the value loss for training the critic + + self.pos = 0 + self.full = False + + # Datatypes for numpy and PyTorch + self.np_float_type = np.float16 if float_type == th.float16 else np.float32 + self.th_float_type = float_type + + def initialize_buffer(self, size): + """Initializes the buffer with the given size.""" + self.observations = np.zeros( + (size, self.n_rl_units, self.obs_dim), dtype=self.np_float_type + ) + self.actions = np.zeros( + (size, self.n_rl_units, self.act_dim), dtype=self.np_float_type + ) + self.rewards = np.zeros((size, self.n_rl_units), dtype=self.np_float_type) + self.log_probs = np.zeros((size, self.n_rl_units), dtype=np.float32) + # self.values = np.zeros((size, self.n_rl_units), dtype=np.float32) + # self.advantages = np.zeros((size, self.n_rl_units), dtype=np.float32) + # self.returns = np.zeros((size, self.n_rl_units), dtype=np.float32) + + def expand_buffer(self, additional_size): + """Expands the buffer by the given additional size and checks if there is enough memory available.""" + + # Calculation of the memory requirement for all 7 arrays + additional_memory_usage = ( + np.zeros( + (additional_size, self.n_rl_units, self.obs_dim), + dtype=self.np_float_type, + ).nbytes + + np.zeros( + (additional_size, self.n_rl_units, self.act_dim), + dtype=self.np_float_type, + ).nbytes + + np.zeros( + (additional_size, self.n_rl_units), dtype=self.np_float_type + ).nbytes # rewards + + np.zeros( + (additional_size, self.n_rl_units), dtype=np.float32 + ).nbytes # log_probs + # + np.zeros( + # (additional_size, self.n_rl_units), dtype=np.float32 + # ).nbytes # values + # + np.zeros( + # (additional_size, self.n_rl_units), dtype=np.float32 + # ).nbytes # advantages + # + np.zeros( + # (additional_size, self.n_rl_units), dtype=np.float32 + # ).nbytes # returns + ) + + # Check whether enough memory is available + if psutil is not None: + mem_available = psutil.virtual_memory().available + if additional_memory_usage > mem_available: + # Conversion to GB + additional_memory_usage_gb = additional_memory_usage / 1e9 + mem_available_gb = mem_available / 1e9 + raise MemoryError( + f"{additional_memory_usage_gb:.2f}GB required, but only {mem_available_gb:.2f}GB available." + ) + + if self.pos + additional_size > self.buffer_size: + warnings.warn( + f"Expanding the buffer will exceed the maximum buffer size of {self.buffer_size}. " + f"Current position: {self.pos}, additional size: {additional_size}." 
+ ) + + self.observations = np.concatenate( + ( + self.observations, + np.zeros( + (additional_size, self.n_rl_units, self.obs_dim), + dtype=self.np_float_type, + ), + ), + axis=0, + ) + self.actions = np.concatenate( + ( + self.actions, + np.zeros( + (additional_size, self.n_rl_units, self.act_dim), + dtype=self.np_float_type, + ), + ), + axis=0, + ) + self.rewards = np.concatenate( + ( + self.rewards, + np.zeros( + (additional_size, self.n_rl_units), dtype=self.np_float_type + ), + ), + axis=0, + ) + self.log_probs = np.concatenate( + ( + self.log_probs, + np.zeros((additional_size, self.n_rl_units), dtype=np.float32), + ), + axis=0, + ) + # self.values = np.concatenate( + # ( + # self.values, + # np.zeros((additional_size, self.n_rl_units), dtype=np.float32), + # ), + # axis=0, + # ) + # self.advantages = np.concatenate( + # ( + # self.advantages, + # np.zeros((additional_size, self.n_rl_units), dtype=np.float32), + # ), + # axis=0, + # ) + # self.returns = np.concatenate( + # ( + # self.returns, + # np.zeros((additional_size, self.n_rl_units), dtype=np.float32), + # ), + # axis=0, + # ) + + def add( + self, + obs: np.array, + actions: np.array, + reward: np.array, + log_probs: np.array, + ): + """ + Adds an observation, action, reward, and log probabilities of all agents to the rollout buffer. + If the buffer does not exist, it will be initialized. If the buffer is full, it will be expanded. + + Args: + obs (numpy.ndarray): The observation to add. + actions (numpy.ndarray): The actions to add. + reward (numpy.ndarray): The reward to add. + log_probs (numpy.ndarray): The log probabilities of the actions taken. + """ + len_obs = obs.shape[0] + + if self.observations is None: + # Initialize buffer with initial size if it's the first add + self.initialize_buffer(len_obs) + + elif self.pos + len_obs > self.observations.shape[0]: + # If the buffer is full, expand it + self.expand_buffer(len_obs) + + # Add data to the buffer + self.observations[self.pos : self.pos + len_obs] = obs.copy() + self.actions[self.pos : self.pos + len_obs] = actions.copy() + self.rewards[self.pos : self.pos + len_obs] = reward.copy() + self.log_probs[self.pos : self.pos + len_obs] = log_probs.squeeze(-1).copy() + + self.pos += len_obs + + def reset(self): + """ + Resets the buffer, clearing all stored data. + Might be needed if policy is changed within one episode, then it needs to be killed and initalized again. + + """ + self.observations = None + self.actions = None + self.rewards = None + self.log_probs = None + # self.values = None + # self.advantages = None + # self.returns = None + self.pos = 0 + self.full = False + + # def compute_returns_and_advantages(self, last_values, dones): + # """ + # Compute the returns and advantages using Generalized Advantage Estimation (GAE). + + # Args: + # last_values (np.array): Value estimates for the last observation. + # dones (np.array): Whether the episode has finished for each agent. + # """ + # # Initialize the last advantage to 0. This will accumulate as we move backwards in time. + # last_advantage = 0 + + # # Loop backward through all the steps in the buffer to calculate returns and advantages. + # # This is because GAE (Generalized Advantage Estimation) relies on future rewards, + # # so we compute it from the last step back to the first step. + # for step in reversed(range(self.pos)): + + # # If we are at the last step in the buffer + # if step == self.pos - 1: + # # If it's the last step, check whether the episode has finished using `dones`. 
+ # # `next_non_terminal` is 0 if the episode has ended, 1 if it's ongoing. + # next_non_terminal = 1.0 - dones + # # Use the provided last values (value estimates for the final observation in the episode) + # next_values = last_values + # else: + # # For other steps, use the mask to determine if the episode is ongoing. + # # If `masks[step + 1]` is 1, the episode is ongoing; if it's 0, the episode has ended. + # next_non_terminal = self.masks[step + 1] + # # Use the value of the next time step to compute the future returns + # next_values = self.values[step + 1] + + # # Temporal difference (TD) error, also known as delta: + # # This is the difference between the reward obtained at this step and the estimated value of this step + # # plus the discounted value of the next step (if the episode is ongoing). + # # This measures how "off" the value function is at predicting the future return. + # delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step] + + # # Compute the advantage for this step using GAE: + # # `delta` is the immediate advantage, and we add to it the discounted future advantage, + # # scaled by the factor `lambda` (from GAE). This allows for a more smooth approximation of advantage. + # # `next_non_terminal` ensures that if the episode has ended, the future advantage stops accumulating. + # self.advantages[step] = last_advantage = delta + self.gamma * self.gae_lambda * next_non_terminal * last_advantage + + # # The return is the advantage plus the baseline value estimate. + # # This makes sure that the return includes both the immediate rewards and the learned value of future rewards. + # self.returns[step] = self.advantages[step] + self.values[step] + + def to_torch(self, array: np.array, copy=True): + """ + Converts a numpy array to a PyTorch tensor. Note: It copies the data by default. + + Args: + array (numpy.ndarray): The numpy array to convert. + copy (bool, optional): Whether to copy or not the data + (may be useful to avoid changing things by reference). Defaults to True. + + Returns: + torch.Tensor: The converted PyTorch tensor. + """ + + if copy: + return th.tensor(array, dtype=self.th_float_type, device=self.device) + + return th.as_tensor(array, dtype=self.th_float_type, device=self.device) + + def get(self) -> RolloutBufferTransitions: + """ + Get all data stored in the buffer and convert it to PyTorch tensors. + Returns the observations, actions, log_probs, advantages, returns, and masks. + """ + data = ( + self.observations[: self.pos], + self.actions[: self.pos], + self.rewards[: self.pos], + self.log_probs[: self.pos], + # self.masks[:self.pos], + ) + + return RolloutBufferTransitions(*tuple(map(self.to_torch, data))) + + + def sample(self, batch_size: int) -> RolloutBufferTransitions: + """ + Samples a random batch of experiences from the rollout buffer. + Unlike the replay buffer, this samples only from the current rollout data (up to self.pos) + and includes log probabilities needed for PPO updates. + + Args: + batch_size (int): The number of experiences to sample. + + Returns: + RolloutBufferTransitions: A named tuple containing the sampled observations, actions, rewards, + and log probabilities. + + Raises: + Exception: If there are less than batch_size entries in the buffer. 
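+
+        Example (illustrative):
+
+            transitions, batch_inds = rollout_buffer.sample(batch_size=64)
+            old_log_probs = transitions.log_probs  # used to form the PPO probability ratio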
+ """ + if self.pos < batch_size: + raise Exception(f"Not enough entries in buffer (need {batch_size}, have {self.pos})") + + batch_inds = np.random.randint(0, self.pos, size=batch_size) + + data = ( + self.observations[batch_inds, :, :], + self.actions[batch_inds, :, :], + self.rewards[batch_inds], + self.log_probs[batch_inds], + ) + + return RolloutBufferTransitions(*tuple(map(self.to_torch, data))), batch_inds # also return the indices of the sampled minibatch episodes \ No newline at end of file diff --git a/assume/reinforcement_learning/learning_role.py b/assume/reinforcement_learning/learning_role.py index aadcc2b41..94f64bc98 100644 --- a/assume/reinforcement_learning/learning_role.py +++ b/assume/reinforcement_learning/learning_role.py @@ -12,9 +12,9 @@ from assume.common.base import LearningConfig, LearningStrategy from assume.common.utils import datetime2timestamp -from assume.reinforcement_learning.algorithms.base_algorithm import RLAlgorithm from assume.reinforcement_learning.algorithms.matd3 import TD3 -from assume.reinforcement_learning.buffer import ReplayBuffer +from assume.reinforcement_learning.algorithms.ppo import PPO +from assume.reinforcement_learning.buffer import ReplayBuffer, RolloutBuffer from assume.reinforcement_learning.learning_utils import linear_schedule_func logger = logging.getLogger(__name__) @@ -31,23 +31,29 @@ class Learning(Role): """ + # TD3 and PPO (Replay buffer, gradient steps, early stopping, self.eval_episodes_done potentiall irrelevant for PPO) def __init__( self, learning_config: LearningConfig, start: datetime = None, end: datetime = None, ): - # how many learning roles do exist and how are they named - self.buffer: ReplayBuffer = None + # General parameters + self.rl_algorithm_name = learning_config.get("algorithm", "matd3") + self.early_stopping_steps = learning_config.get(self.rl_algorithm_name, {}).get( + "early_stopping_steps", 10 + ) + self.early_stopping_threshold = learning_config.get( + self.rl_algorithm_name, {} + ).get("early_stopping_threshold", 0.05) self.episodes_done = 0 + # dict[key, value] self.rl_strats: dict[int, LearningStrategy] = {} - self.rl_algorithm = learning_config.get("algorithm", "matd3") - self.actor_architecture = learning_config.get("actor_architecture", "mlp") + + # For centralized critic in MATD3 self.critics = {} - self.target_critics = {} # define whether we train model or evaluate it - self.training_episodes = learning_config["training_episodes"] self.learning_mode = learning_config["learning_mode"] self.continue_learning = learning_config["continue_learning"] self.perform_evaluation = learning_config["perform_evaluation"] @@ -56,39 +62,26 @@ def __init__( "trained_policies_load_path", self.trained_policies_save_path ) - # if early_stopping_steps are not provided then set default to no early stopping (early_stopping_steps need to be greater than validation_episodes) - self.early_stopping_steps = learning_config.get( - "early_stopping_steps", - int( - self.training_episodes - / learning_config.get("validation_episodes_interval", 5) - + 1 - ), + + self.learning_rate = learning_config["learning_rate"] + self.actor_architecture = learning_config.get(self.rl_algorithm_name, {}).get( + "actor_architecture", "mlp" ) - self.early_stopping_threshold = learning_config.get( - "early_stopping_threshold", 0.05 + self.training_episodes = learning_config[ + "training_episodes" + ] + self.train_freq = learning_config.get(self.rl_algorithm_name, {}).get( + "train_freq" ) - - cuda_device = ( - learning_config["device"] - if 
"cuda" in learning_config.get("device", "cpu") - else "cpu" + self.batch_size = learning_config.get(self.rl_algorithm_name, {}).get( + "batch_size", 128 ) - self.device = th.device(cuda_device if th.cuda.is_available() else "cpu") - - # future: add option to choose between float16 and float32 - # float_type = learning_config.get("float_type", "float32") - self.float_type = th.float - - th.backends.cuda.matmul.allow_tf32 = True - th.backends.cudnn.allow_tf32 = True - if start is not None: self.start = datetime2timestamp(start) if end is not None: self.end = datetime2timestamp(end) - self.learning_rate = learning_config.get("learning_rate", 1e-4) + self.gamma = learning_config.get(self.rl_algorithm_name, {}).get("gamma", 0.99) self.learning_rate_schedule = learning_config.get( "learning_rate_schedule", None ) @@ -103,34 +96,96 @@ def __init__( self.calc_noise_from_progress = linear_schedule_func(noise_dt) else: self.calc_noise_from_progress = lambda x: noise_dt - + # if we do not have initial experience collected we will get an error as no samples are available on the # buffer from which we can draw experience to adapt the strategy, hence we set it to minimum one episode - self.episodes_collecting_initial_experience = max( - learning_config.get("episodes_collecting_initial_experience", 5), 1 + learning_config.get(self.rl_algorithm_name, {}).get( + "episodes_collecting_initial_experience", 5 + ), + 1, + ) + + # if early_stopping_steps are not provided then set default to no early stopping (early_stopping_steps need to be greater than validation_episodes) + self.early_stopping_steps = learning_config.get( + "early_stopping_steps", + int( + self.training_episodes + / learning_config.get("validation_episodes_interval", 5) + + 1 + ), + ) + self.early_stopping_threshold = learning_config.get( + "early_stopping_threshold", 0.05 + ) + + # if early_stopping_steps are not provided then set default to no early stopping (early_stopping_steps need to be greater than validation_episodes) + self.early_stopping_steps = learning_config.get( + "early_stopping_steps", + int( + self.training_episodes + / learning_config.get("validation_episodes_interval", 5) + + 1 + ), + ) + self.early_stopping_threshold = learning_config.get( + "early_stopping_threshold", 0.05 ) - self.train_freq = learning_config.get("train_freq", "1h") self.gradient_steps = ( int(self.train_freq[:-1]) if learning_config.get("gradient_steps", -1) == -1 else learning_config["gradient_steps"] ) - self.batch_size = learning_config.get("batch_size", 128) - self.gamma = learning_config.get("gamma", 0.99) - self.eval_episodes_done = 0 + # Algorithm-specific parameters + if self.rl_algorithm_name == "matd3": + self.buffer: ReplayBuffer = None + self.target_critics = {} + self.noise_sigma = learning_config["matd3"]["noise_sigma"] + self.noise_scale = learning_config["matd3"]["noise_scale"] + + elif self.rl_algorithm_name == "ppo": + self.buffer: RolloutBuffer = None + + # Potentially more parameters for PPO + self.steps_per_epoch = learning_config["ppo"].get("steps_per_epoch", 10) + self.clip_ratio = learning_config["ppo"].get("clip_ratio", 0.2) + self.entropy_coeff = learning_config["ppo"].get("entropy_coeff", 0.02) + self.value_coeff = learning_config["ppo"].get("value_coeff", 0.5) + self.max_grad_norm = learning_config["ppo"].get("max_grad_norm", 0.5) + self.gae_lambda = learning_config["ppo"].get("gae_lambda", 0.95) + + + + cuda_device = ( + learning_config["device"] + if "cuda" in learning_config.get("device", "cpu") + else "cpu" + ) + + 
self.device = th.device(cuda_device if th.cuda.is_available() else "cpu") + + # future: add option to choose between float16 and float32 + # float_type = learning_config.get("float_type", "float32") + self.float_type = th.float + + th.backends.cuda.matmul.allow_tf32 = True + th.backends.cudnn.allow_tf32 = True + - # function that initializes learning, needs to be an extra function so that it can be called after buffer is given to Role - self.create_learning_algorithm(self.rl_algorithm) - # store evaluation values + # Initialize the algorithm depending on the type + self.create_learning_algorithm(self.rl_algorithm_name) + + # Initialize evaluation metrics + self.eval_episodes_done = 0 self.max_eval = defaultdict(lambda: -1e9) self.rl_eval = defaultdict(list) - # list of avg_changes + # List of avg changes self.avg_rewards = [] + # TD3 and PPO def load_inter_episodic_data(self, inter_episodic_data): """ Load the inter-episodic data from the dict stored across simulation runs. @@ -139,6 +194,7 @@ def load_inter_episodic_data(self, inter_episodic_data): inter_episodic_data (dict): The inter-episodic data to be loaded. """ + # TODO: Make this function of algorithm so that we loose case sensitivity here self.episodes_done = inter_episodic_data["episodes_done"] self.eval_episodes_done = inter_episodic_data["eval_episodes_done"] self.max_eval = inter_episodic_data["max_eval"] @@ -146,13 +202,15 @@ def load_inter_episodic_data(self, inter_episodic_data): self.avg_rewards = inter_episodic_data["avg_all_eval"] self.buffer = inter_episodic_data["buffer"] - # if enough initial experience was collected according to specifications in learning config - # turn off initial exploration and go into full learning mode - if self.episodes_done > self.episodes_collecting_initial_experience: - self.turn_off_initial_exploration() + if self.rl_algorithm_name == "matd3": + # if enough initial experience was collected according to specifications in learning config + # turn off initial exploration and go into full learning mode + if self.episodes_done > self.episodes_collecting_initial_experience: + self.turn_off_initial_exploration() self.initialize_policy(inter_episodic_data["actors_and_critics"]) + # TD3 and PPO def get_inter_episodic_data(self): """ Dump the inter-episodic data to a dict for storing across simulation runs. @@ -171,6 +229,7 @@ def get_inter_episodic_data(self): "actors_and_critics": self.rl_algorithm.extract_policy(), } + # TD3 and PPO def setup(self) -> None: """ Set up the learning role for reinforcement learning training. @@ -197,16 +256,37 @@ def save_buffer_and_update(self, content: dict, meta: dict) -> None: meta (dict): The metadata associated with the message. 
(not needed yet) """ - if content.get("type") == "save_buffer_and_update": - data = content["data"] - self.buffer.add( - obs=data[0], - actions=data[1], - reward=data[2], - ) + if self.rl_algorithm_name == "matd3": + if content.get("type") == "save_buffer_and_update": + data = content["data"] + self.buffer.add( + obs=data[0], + actions=data[1], + reward=data[2], + ) + + self.update_policy() + + elif self.rl_algorithm_name == "ppo": + + logger.debug("save_buffer_and_update in learning_role.py") - self.update_policy() + if content.get("type") == "save_buffer_and_update": + data = content["data"] + self.buffer.add( + obs=data[0], + actions=data[1], + reward=data[2], + log_probs=data[3], + ) + + self.update_policy() + + # since the PPO is an on-policy algorithm it onyl uses the expercience collected with the current policy + # after the policy-update which ultimately changes the policy, theb buffer needs to be cleared + self.buffer.reset() + # TD3 def turn_off_initial_exploration(self) -> None: """ Disable initial exploration mode for all learning strategies. @@ -224,36 +304,31 @@ def get_progress_remaining(self) -> float: Get the remaining learning progress from the simulation run. """ - total_duration = self.end - self.start - elapsed_duration = self.context.current_timestamp - self.start + for _, unit in self.rl_strats.items(): + unit.action_noise.scale = stored_scale - learning_episodes = ( - self.training_episodes - self.episodes_collecting_initial_experience - ) + def get_noise_scale(self) -> None: + """ + Get the noise scale from the first learning strategy (unit) in rl_strats. - if self.episodes_done < self.episodes_collecting_initial_experience: - progress_remaining = 1 - else: - progress_remaining = ( - 1 - - ( - (self.episodes_done - self.episodes_collecting_initial_experience) - / learning_episodes - ) - - ((1 / learning_episodes) * (elapsed_duration / total_duration)) - ) + Notes: + The noise scale is the same for all learning strategies (units) in rl_strats, so we only need to get it from one unit. + It is only depended on the number of updates done so far, which is determined by the number of episodes done and the update frequency. + + """ + stored_scale = list(self.rl_strats.values())[0].action_noise.scale - return progress_remaining + return stored_scale - def create_learning_algorithm(self, algorithm: RLAlgorithm): + def create_learning_algorithm(self, algorithm: str): """ - Create and initialize the reinforcement learning algorithm. + Create and initialize the reinforcement learning algorithm, based on defined algorithm type. This method creates and initializes the reinforcement learning algorithm based on the specified algorithm name. The algorithm is associated with the learning role and configured with relevant hyperparameters. Args: - algorithm (RLAlgorithm): The name of the reinforcement learning algorithm. + algorithm (str): The name of the reinforcement learning algorithm. 
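+
+        Example (illustrative):
+
+            learning_role.create_learning_algorithm("matd3")  # instantiates TD3
+            learning_role.create_learning_algorithm("ppo")    # instantiates PPO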
""" if algorithm == "matd3": self.rl_algorithm = TD3( @@ -265,9 +340,27 @@ def create_learning_algorithm(self, algorithm: RLAlgorithm): gamma=self.gamma, actor_architecture=self.actor_architecture, ) + elif algorithm == "ppo": + self.rl_algorithm = PPO( + learning_role=self, + learning_rate=self.learning_rate, + gamma=self.gamma, # Discount factor + gradient_steps=self.gradient_steps, # Number of epochs for policy updates + clip_ratio=self.clip_ratio, # PPO-specific clipping parameter + vf_coef=self.value_coeff, # Coefficient for value function loss + entropy_coef=self.entropy_coeff, # Coefficient for entropy to encourage exploration + max_grad_norm=self.max_grad_norm, # Maximum gradient norm for clipping + gae_lambda=self.gae_lambda, # Lambda for Generalized Advantage Estimation (GAE) + actor_architecture=self.actor_architecture, # Actor network architecture + ) else: logger.error(f"Learning algorithm {algorithm} not implemented!") + # Loop over rl_strats + # self.rl_algorithm an die Learning Strategy übergeben + # Damit die Learning Strategy auf act/get_actions zugreifen kann + + # TD3 def initialize_policy(self, actors_and_critics: dict = None) -> None: """ Initialize the policy of the reinforcement learning agent considering the respective algorithm. @@ -289,6 +382,7 @@ def initialize_policy(self, actors_and_critics: dict = None) -> None: f"Folder with pretrained policies {directory} does not exist" ) + # TD3 and PPO def update_policy(self) -> None: """ Update the policy of the reinforcement learning agent. @@ -300,8 +394,11 @@ def update_policy(self) -> None: Notes: This method is typically scheduled to run periodically during training to continuously improve the agent's policy. """ - if self.episodes_done > self.episodes_collecting_initial_experience: + if self.rl_algorithm_name == "ppo": self.rl_algorithm.update_policy() + else: + if self.episodes_done > self.episodes_collecting_initial_experience: + self.rl_algorithm.update_policy() def compare_and_save_policies(self, metrics: dict) -> bool: """ diff --git a/assume/reinforcement_learning/learning_unit_operator.py b/assume/reinforcement_learning/learning_unit_operator.py index 9f9c5f427..31b3fde86 100644 --- a/assume/reinforcement_learning/learning_unit_operator.py +++ b/assume/reinforcement_learning/learning_unit_operator.py @@ -146,14 +146,21 @@ def write_learning_to_output(self, orderbook: Orderbook, market_id: str) -> None } ) + # Only for MATD3, not for PPO + # Check if exploration_noise is not empty (MATD3) action_tuple = unit.outputs["actions"].at[start] - noise_tuple = unit.outputs["exploration_noise"].at[start] + if "exploration_noise" in unit.outputs and hasattr(unit.outputs["exploration_noise"].at[start], "numel"): + noise_tuple = unit.outputs["exploration_noise"].at[start] + action_dim = action_tuple.numel() for i in range(action_dim): - output_dict[f"exploration_noise_{i}"] = ( - noise_tuple[i] if action_dim > 1 else noise_tuple - ) + # Only for MATD3, not for PPO + if "exploration_noise" in unit.outputs and hasattr(unit.outputs["exploration_noise"].loc[start], "numel"): + output_dict[f"exploration_noise_{i}"] = ( + noise_tuple[i] if action_dim > 1 else noise_tuple + ) + # For MATD3 and PPO output_dict[f"actions_{i}"] = ( action_tuple[i] if action_dim > 1 else action_tuple ) @@ -172,6 +179,7 @@ def write_learning_to_output(self, orderbook: Orderbook, market_id: str) -> None }, ) + # Executed in the interval set by train_frequency async def write_to_learning_role( self, ) -> None: @@ -179,6 +187,9 @@ async def 
write_to_learning_role( self, ) -> None: """ Writes learning results to the learning agent. """ + + # print("write_to_learning_role in learning_unit_operator.py") + if len(self.rl_units) == 0: return @@ -187,11 +198,12 @@ async def write_to_learning_role( device = self.learning_strategies["device"] learning_unit_count = len(self.rl_units) + # How many reward values are available in the first learning unit -> equals the number of steps values_len = len(self.rl_units[0].outputs["rl_rewards"]) # return if no data is available if values_len == 0: return - + all_observations = th.zeros( (values_len, learning_unit_count, obs_dim), device=device ) @@ -200,17 +212,38 @@ async def write_to_learning_role( ) all_rewards = [] + # For PPO + # dimensions: (steps, learning units, 1), i.e. one log-prob per step and learning unit + all_log_probs = th.zeros( + (values_len, learning_unit_count, 1), device=device + ) + + # i is the index of the learning unit, unit is the learning unit object for i, unit in enumerate(self.rl_units): + # Stack the list of per-step tensors into a single tensor obs_tensor = th.stack(unit.outputs["rl_observations"][:values_len], dim=0) + actions_tensor = th.stack( unit.outputs["rl_actions"][:values_len], dim=0 ).reshape(-1, act_dim) + # Three dimensions: steps, learning units, observation/action dimension + # The second dimension indexes the learning units all_observations[:, i, :] = obs_tensor all_actions[:, i, :] = actions_tensor all_rewards.append(unit.outputs["rl_rewards"]) + # For PPO + # Check whether the list of tensors is not empty and whether the tensors contain elements + if unit.outputs["rl_log_probs"]: # and all(t.numel() > 0 for t in unit.outputs["rl_log_probs"][:values_len]): + + log_prob_tensor = th.stack( + unit.outputs["rl_log_probs"][:values_len], dim=0 + ).unsqueeze(-1) + + all_log_probs[:, i, :] = log_prob_tensor + # reset the outputs unit.reset_saved_rl_data() @@ -227,8 +260,18 @@ async def write_to_learning_role( .numpy() .reshape(-1, learning_unit_count, act_dim) ) + + all_rewards = np.array(all_rewards).reshape(-1, learning_unit_count) - rl_agent_data = (all_observations, all_actions, all_rewards) + + # For PPO + if unit.outputs["rl_log_probs"]: # and all(t.numel() > 0 for t in unit.outputs["rl_log_probs"][:values_len]): + all_log_probs = all_log_probs.detach().cpu().numpy().reshape(-1, learning_unit_count, 1) + + rl_agent_data = (all_observations, all_actions, all_rewards, all_log_probs) + # For MATD3 + else: + rl_agent_data = (all_observations, all_actions, all_rewards) learning_role_addr = self.context.data.get("learning_agent_addr") diff --git a/assume/reinforcement_learning/learning_utils.py b/assume/reinforcement_learning/learning_utils.py index 661240231..bc8e90ffc 100644 --- a/assume/reinforcement_learning/learning_utils.py +++ b/assume/reinforcement_learning/learning_utils.py @@ -3,6 +3,8 @@ # SPDX-License-Identifier: AGPL-3.0-or-later from collections.abc import Callable +from datetime import timedelta +import pandas as pd from datetime import datetime from typing import TypedDict @@ -10,12 +12,14 @@ import torch as th +# TD3 and PPO class ObsActRew(TypedDict): observation: list[th.Tensor] action: list[th.Tensor] reward: list[th.Tensor] +# TD3 and PPO observation_dict = dict[list[datetime], ObsActRew] # A schedule takes the remaining progress as input @@ -23,6 +27,8 @@ class ObsActRew(TypedDict): Schedule = Callable[[float], float] + +# TD3 # Ornstein-Uhlenbeck Noise # from https://github.com/songrotek/DDPG/blob/master/ou_noise.py class OUNoise:
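The write_to_learning_role changes above collect the PPO log-probabilities per learning unit and stack them into a tensor of shape (steps, learning units, 1) before the data is passed on to the learning role. The following standalone sketch (toy sizes and illustrative names only, not part of the patch) reproduces that stacking convention:

import torch as th

steps, n_units = 4, 2  # toy sizes for illustration
# one scalar log-prob per step, as appended by the bidding strategy
rl_log_probs = [th.zeros(()) for _ in range(steps)]

all_log_probs = th.zeros((steps, n_units, 1))
for i in range(n_units):
    # th.stack turns the list of scalar tensors into a tensor of shape (steps,);
    # unsqueeze(-1) adds the trailing dimension used when filling all_log_probs[:, i, :]
    log_prob_tensor = th.stack(rl_log_probs[:steps], dim=0).unsqueeze(-1)
    all_log_probs[:, i, :] = log_prob_tensor

print(all_log_probs.shape)  # torch.Size([4, 2, 1])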
@@ -56,6 +62,7 @@ def noise(self): return noise +# TD3 class NormalActionNoise: """ A gaussian action noise @@ -80,6 +87,7 @@ def update_noise_decay(self, updated_decay: float): self.dt = updated_decay +# TD3 def polyak_update(params, target_params, tau: float): """ Perform a Polyak average update on ``target_params`` using ``params``: @@ -153,3 +161,47 @@ def func(_): return val return func + + +def collect_obs_for_central_critic( + states: th.Tensor, i: int, obs_dim: int, unique_obs_dim: int, batch_size: int +) -> th.Tensor: + """ + This function assembles the observations from all agents for the central critic. + In detail, it concatenates the unique observations of all other agents with the full observation of the current agent, so that the shared part of the observations is included only once. + + Args: + states (th.Tensor): The stacked observations of all agents + i (int): Index of the current agent + obs_dim (int): Dimension of a single observation + unique_obs_dim (int): Dimension of the unique part of each observation + batch_size (int): Number of samples in the batch + + Returns: + th.Tensor: The combined observations for the central critic + """ + # Assemble the observations for the central critic + + # this takes the unique observations from all other agents assuming that + # the unique observations are at the end of the observation vector + temp = th.cat( + ( + states[:, :i, obs_dim - unique_obs_dim :].reshape( + batch_size, -1 + ), + states[ + :, i + 1 :, obs_dim - unique_obs_dim : + ].reshape(batch_size, -1), + ), + axis=1, + ) + + # the final all_states vector now contains the current agent's observation + # and the unique observations from all other agents + all_states = th.cat( + (states[:, i, :].reshape(batch_size, -1), temp), axis=1 + ).view(batch_size, -1) + + + return all_states + + + diff --git a/assume/reinforcement_learning/neural_network_architecture.py b/assume/reinforcement_learning/neural_network_architecture.py index 565d6998e..3cd8babc0 100644 --- a/assume/reinforcement_learning/neural_network_architecture.py +++ b/assume/reinforcement_learning/neural_network_architecture.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: AGPL-3.0-or-later +import numpy as np import torch as th from torch import nn from torch.nn import functional as F @@ -91,6 +92,55 @@ def q1_forward(self, obs, actions): return x +class CriticPPO(nn.Module): + """Critic Network for Proximal Policy Optimization (PPO). + + Centralized critic, meaning that it has access to the observation space of all competitive learning agents.
+ + Args: + n_agents (int): Number of agents + obs_dim (int): Dimension of each state + act_dim (int): Dimension of each action + """ + + def __init__(self, n_agents: int, obs_dim: int, act_dim: int, float_type, unique_obs_dim: int = 0): + super().__init__() + + self.obs_dim = obs_dim + unique_obs_dim * (n_agents - 1) + self.act_dim = act_dim * n_agents + + if n_agents <= 50: + self.FC_1 = nn.Linear(self.obs_dim + self.act_dim, 512, dtype=float_type) + self.FC_2 = nn.Linear(512, 256, dtype=float_type) + self.FC_3 = nn.Linear(256, 128, dtype=float_type) + self.FC_4 = nn.Linear(128, 1, dtype=float_type) + else: + self.FC_1 = nn.Linear(self.obs_dim + self.act_dim, 1024, dtype=float_type) + self.FC_2 = nn.Linear(1024, 512, dtype=float_type) + self.FC_3 = nn.Linear(512, 128, dtype=float_type) + self.FC_4 = nn.Linear(128, 1, dtype=float_type) + + for layer in [self.FC_1, self.FC_2, self.FC_3, self.FC_4]: + nn.init.orthogonal_(layer.weight, gain=np.sqrt(2)) + nn.init.constant_(layer.bias, 0.0) + + def forward(self, obs, actions): + """ + Args: + obs (torch.Tensor): The observations + actions (torch.Tensor): The actions + + """ + + xu = th.cat([obs, actions], dim=-1) + + x = F.relu(self.FC_1(xu)) + x = F.relu(self.FC_2(x)) + x = F.relu(self.FC_3(x)) + value = self.FC_4(x) + + return value + class Actor(nn.Module): """ Parent class for actor networks. @@ -115,10 +165,44 @@ def __init__(self, obs_dim: int, act_dim: int, float_type, *args, **kwargs): def forward(self, obs): x = F.relu(self.FC1(obs)) x = F.relu(self.FC2(x)) + # Works with MATD3, output of softsign: [-1, 1] x = F.softsign(self.FC3(x)) + # x = th.tanh(self.FC3(x)) + # Tested for PPO, scales the output to [0, 1] range + #x = th.sigmoid(self.FC3(x)) + return x + +class DistActor(MLPActor): + """ + Actor based on the MLP actor network that constructs a probability distribution for the action definition. + """ + def __init__(self, obs_dim: int, act_dim: int, float_type, *args, **kwargs): + super().__init__(obs_dim, act_dim, float_type, *args, **kwargs) + + + def initialize_weights(self, final_gain=np.sqrt(2)): + for layer in [self.FC1, self.FC2]: + nn.init.orthogonal_(layer.weight, gain=np.sqrt(2)) + nn.init.constant_(layer.bias, 0.0) + # use smaller gain for final layer + nn.init.orthogonal_(self.FC3.weight, gain=final_gain) + nn.init.constant_(self.FC3.bias, 0.0) + + + def forward(self, obs): + x = F.relu(self.FC1(obs)) + x = F.relu(self.FC2(x)) + # Works with MATD3, output of softsign: [-1, 1] + x = F.softsign(self.FC3(x)) + + # Create a normal distribution for continuous actions (with assumed standard deviation of + # TODO: 0.01/0.0 as in marlbenchmark or 1.0 or scheduled decrease?) + dist = th.distributions.Normal(x, 0.2) # --> possibly expose as a hyperparameter and possibly use sigmoid (0, 1) + + return x, dist class LSTMActor(Actor): diff --git a/assume/reinforcement_learning/raw_ppo.py b/assume/reinforcement_learning/raw_ppo.py new file mode 100644 index 000000000..0a59a1e59 --- /dev/null +++ b/assume/reinforcement_learning/raw_ppo.py @@ -0,0 +1,231 @@ +# SPDX-FileCopyrightText: ASSUME Developers +# +# SPDX-License-Identifier: AGPL-3.0-or-later + +from collections import deque + +import gym +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim + + +class MLPActorCritic(nn.Module): + """ + Simple MLP Actor-Critic network with separate actor and critic heads.
+ """ + + def __init__(self, obs_dim, act_dim): + super().__init__() + self.shared = nn.Sequential( + nn.Linear(obs_dim, 64), nn.ReLU(), nn.Linear(64, 64), nn.ReLU() + ) + # Actor head + self.actor = nn.Linear(64, act_dim) + # Critic head + self.critic = nn.Linear(64, 1) + + def forward(self, obs): + shared_out = self.shared(obs) + return self.actor(shared_out), self.critic(shared_out) + + def act(self, obs): + logits, value = self(obs) + dist = torch.distributions.Categorical(logits=logits) + action = dist.sample() + return action, dist.log_prob(action), value + + def evaluate_actions(self, obs, actions): + logits, values = self(obs) + dist = torch.distributions.Categorical(logits=logits) + action_log_probs = dist.log_prob(actions) + dist_entropy = dist.entropy() + return action_log_probs, torch.squeeze(values, dim=-1), dist_entropy + + +class PPO: + """ + Proximal Policy Optimization (PPO) implementation in PyTorch. + """ + + def __init__( + self, + env, + actor_critic, + clip_param=0.2, + entcoeff=0.01, + optim_stepsize=1e-3, + optim_epochs=4, + gamma=0.99, + lam=0.95, + batch_size=64, + ): + self.env = env + self.actor_critic = actor_critic + self.clip_param = clip_param + self.entcoeff = entcoeff + self.optim_epochs = optim_epochs + self.optim_stepsize = optim_stepsize + self.gamma = gamma + self.lam = lam + self.batch_size = batch_size + self.optimizer = optim.Adam( + self.actor_critic.parameters(), lr=self.optim_stepsize + ) + + def discount_rewards(self, rewards, dones, gamma): + """ + Compute discounted rewards. + """ + discounted_rewards = [] + r = 0 + for reward, done in zip(reversed(rewards), reversed(dones)): + if done: + r = 0 + r = reward + gamma * r + discounted_rewards.insert(0, r) + return discounted_rewards + + def compute_gae(self, rewards, values, dones, gamma, lam): + """ + Compute Generalized Advantage Estimation (GAE). + """ + adv = 0 + advantages = [] + for t in reversed(range(len(rewards))): + delta = rewards[t] + gamma * values[t + 1] * (1 - dones[t]) - values[t] + adv = delta + gamma * lam * adv * (1 - dones[t]) + advantages.insert(0, adv) + return advantages + + def rollout(self, timesteps_per_actorbatch): + """ + Collect trajectories by running the policy in the environment. + """ + # Reset env + obs = self.env.reset() + ( + obs_list, + actions_list, + rewards_list, + dones_list, + log_probs_list, + values_list, + ) = [], [], [], [], [], [] + for _ in range(timesteps_per_actorbatch): + obs_tensor = torch.FloatTensor(obs).unsqueeze(0) + action, log_prob, value = self.actor_critic.act(obs_tensor) + + obs_list.append(obs_tensor) + actions_list.append(action) + log_probs_list.append(log_prob) + values_list.append(value) + + next_obs, reward, done, _ = self.env.step(action.item()) + rewards_list.append(reward) + dones_list.append(done) + + obs = next_obs + if done: + obs = self.env.reset() + + obs_tensor = torch.FloatTensor(obs).unsqueeze(0) + _, _, last_value = self.actor_critic.act(obs_tensor) + + values_list.append(last_value) + + return { + "observations": torch.cat(obs_list), + "actions": torch.cat(actions_list), + "log_probs": torch.cat(log_probs_list), + "values": torch.cat(values_list), + "rewards": rewards_list, + "dones": dones_list, + } + + def ppo_update(self, batch, clip_param, entcoeff): + """ + Update the policy using PPO objective. 
+ """ + observations, actions, old_log_probs, returns, advantages = batch + + for _ in range(self.optim_epochs): + new_log_probs, values, entropy = self.actor_critic.evaluate_actions( + observations, actions + ) + + ratio = torch.exp(new_log_probs - old_log_probs) + surr1 = ratio * advantages + surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantages + + policy_loss = -torch.min(surr1, surr2).mean() + value_loss = (returns - values).pow(2).mean() + entropy_loss = entropy.mean() + + loss = policy_loss + 0.5 * value_loss - entcoeff * entropy_loss + + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + def train(self, total_timesteps, timesteps_per_actorbatch, log_interval=100): + """ + Main training loop. + """ + total_timesteps_done = 0 + reward_history = deque(maxlen=100) + + while total_timesteps_done < total_timesteps: + # Rollout + batch = self.rollout(timesteps_per_actorbatch) + observations = batch["observations"] + actions = batch["actions"] + old_log_probs = batch["log_probs"] + rewards = batch["rewards"] + dones = batch["dones"] + values = batch["values"].detach() + + # Compute discounted rewards and advantages + returns = torch.FloatTensor( + self.discount_rewards(rewards, dones, self.gamma) + ) + advantages = torch.FloatTensor( + self.compute_gae(rewards, values.numpy(), dones, self.gamma, self.lam) + ) + + # Normalize advantages + advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + + # Update the policy using PPO + batch_data = (observations, actions, old_log_probs, returns, advantages) + self.ppo_update(batch_data, self.clip_param, self.entcoeff) + + total_timesteps_done += timesteps_per_actorbatch + avg_reward = np.mean(rewards) + reward_history.append(avg_reward) + + if total_timesteps_done % log_interval == 0: + print( + f"Timesteps: {total_timesteps_done}, Avg Reward: {np.mean(reward_history)}" + ) + + +# Example usage with CartPole environment +env = gym.make("CartPole-v1") +obs_dim = env.observation_space.shape[0] +act_dim = env.action_space.n + +actor_critic = MLPActorCritic(obs_dim, act_dim) +ppo = PPO( + env, + actor_critic, + clip_param=0.2, + entcoeff=0.01, + optim_stepsize=1e-3, + optim_epochs=4, + gamma=0.99, + lam=0.95, +) + +ppo.train(total_timesteps=10000, timesteps_per_actorbatch=256) diff --git a/assume/scenario/loader_csv.py b/assume/scenario/loader_csv.py index 547d818cb..e45a84e0a 100644 --- a/assume/scenario/loader_csv.py +++ b/assume/scenario/loader_csv.py @@ -26,6 +26,8 @@ from assume.strategies import BaseStrategy from assume.world import World +# from assume.reinforcement_learning.learning_utils import calculate_total_timesteps_per_episode + logger = logging.getLogger(__name__) @@ -437,6 +439,9 @@ def load_config_and_create_forecaster( start = pd.Timestamp(config["start_date"]) end = pd.Timestamp(config["end_date"]) + # New addition for PPO + time_step = pd.Timedelta(config["time_step"]) + index = pd.date_range( start=start, end=end + timedelta(days=1), @@ -527,6 +532,8 @@ def load_config_and_create_forecaster( "demand_units": demand_units, "dsm_units": dsm_units, "forecaster": forecaster, + # New addition for PPO + "time_step": time_step, } @@ -849,64 +856,57 @@ def run_learning( verbose: bool = False, ) -> None: """ - Train Deep Reinforcement Learning (DRL) agents to act in a simulated market environment. - - This function runs multiple episodes of simulation to train DRL agents, performs evaluation, and saves the best runs. 
It maintains the buffer and learned agents in memory to avoid resetting them with each new run. - - Args: - world (World): An instance of the World class representing the simulation environment. - inputs_path (str): The path to the folder containing input files necessary for the simulation. - scenario (str): The name of the scenario for the simulation. - study_case (str): The specific study case for the simulation. - - Note: - - The function uses a ReplayBuffer to store experiences for training the DRL agents. - - It iterates through training episodes, updating the agents and evaluating their performance at regular intervals. - - Initial exploration is active at the beginning and is disabled after a certain number of episodes to improve the performance of DRL algorithms. - - Upon completion of training, the function performs an evaluation run using the best policy learned during training. - - The best policies are chosen based on the average reward obtained during the evaluation runs, and they are saved for future use. + Train Deep Reinforcement Learning (DRL) agents (either MATD3 or PPO) to act in a simulated market environment. + This function runs multiple episodes of simulation to train DRL agents, performs evaluation, and saves the best runs. + It maintains the buffer and learned agents in memory to avoid resetting them with each new run. """ - from assume.reinforcement_learning.buffer import ReplayBuffer + from assume.reinforcement_learning.buffer import ReplayBuffer, RolloutBuffer if not verbose: logger.setLevel(logging.WARNING) - # remove csv path so that nothing is written while learning temp_csv_path = world.export_csv_path world.export_csv_path = "" - # initialize policies already here to set the obs_dim and act_dim in the learning role actors_and_critics = None - world.learning_role.initialize_policy(actors_and_critics=actors_and_critics) + world.learning_role.initialize_policy(actors_and_critics=actors_and_critics) # Leads to the initialization of the Learning role, makes world.learning_role.rl_algorithm_name accessible + world.output_role.delete_similar_runs() # check if we already stored policies for this simulation save_path = world.learning_config["trained_policies_save_path"] if Path(save_path).is_dir(): - # we are in learning mode and about to train new policies, which might overwrite existing ones accept = input( f"{save_path=} exists - should we overwrite current learnings? 
(y/N) " ) if not accept.lower().startswith("y"): - # stop here - do not start learning or save anything - raise AssumeException("don't overwrite existing strategies") + raise AssumeException("Don't overwrite existing strategies") - # ----------------------------------------- - # Load scenario data to reuse across episodes + # Load scenario data scenario_data = load_config_and_create_forecaster(inputs_path, scenario, study_case) - # ----------------------------------------- - # Information that needs to be stored across episodes, aka one simulation run + # For PPO buffer size calculation + validation_interval_from_config = world.learning_config.get( + "validation_episodes_interval", 5 + ) + + buffer_cls = ( + ReplayBuffer + if world.learning_role.rl_algorithm_name == "matd3" + else RolloutBuffer + ) + buffer = buffer_cls( + buffer_size=int(float(world.learning_config.get("buffer_size", 5e5))), + obs_dim=world.learning_role.rl_algorithm.obs_dim, + act_dim=world.learning_role.rl_algorithm.act_dim, + n_rl_units=len(world.learning_role.rl_strats), + device=world.learning_role.device, + float_type=world.learning_role.float_type, + ) + inter_episodic_data = { - "buffer": ReplayBuffer( - buffer_size=int(world.learning_config.get("replay_buffer_size", 5e5)), - obs_dim=world.learning_role.rl_algorithm.obs_dim, - act_dim=world.learning_role.rl_algorithm.act_dim, - n_rl_units=len(world.learning_role.rl_strats), - device=world.learning_role.device, - float_type=world.learning_role.float_type, - ), + "buffer": buffer, "actors_and_critics": None, "max_eval": defaultdict(lambda: -1e9), "all_eval": defaultdict(list), @@ -919,16 +919,19 @@ def run_learning( validation_interval = min( world.learning_role.training_episodes, - world.learning_config.get("validation_episodes_interval", 5), + validation_interval_from_config, ) eval_episode = 1 + # Training loop with integrated validation after a certain number of episodes for episode in tqdm( range(1, world.learning_role.training_episodes + 1), desc="Training Episodes", ): - # TODO normally, loading twice should not create issues, somehow a scheduling issue is raised currently + + # print("loader_csv: Episode: ", episode) + if episode != 1: setup_world( world=world, @@ -941,51 +944,73 @@ def run_learning( # Give the newly initialized learning role the needed information across episodes world.learning_role.load_inter_episodic_data(inter_episodic_data) - world.run() + world.run() # triggers calculate_bids() - # ----------------------------------------- - # Store updated information across episodes inter_episodic_data = world.learning_role.get_inter_episodic_data() inter_episodic_data["episodes_done"] = episode - # evaluation run: - if ( - episode % validation_interval == 0 - and episode + # Perform validation at regular intervals + if episode % validation_interval == 0 and ( + episode >= world.learning_role.episodes_collecting_initial_experience + validation_interval + if world.learning_role.rl_algorithm_name == "matd3" + else episode >= validation_interval # For PPO ): + + logger.debug(f"Validation of loader_csv after episode {episode}") + world.reset() - # load evaluation run setup_world( world=world, scenario_data=scenario_data, study_case=study_case, - perform_evaluation=True, + perform_evaluation=True, # perform evaluation triggers save_buffer_and_update, which triggers update_policy() eval_episode=eval_episode, ) world.learning_role.load_inter_episodic_data(inter_episodic_data) - world.run() - total_rewards = world.output_role.get_sum_reward() + if 
world.learning_role.rl_algorithm_name == "matd3": + total_rewards = world.output_role.get_sum_reward() - if len(total_rewards) == 0: - raise AssumeException("No rewards were collected during evaluation run") - - avg_reward = np.mean(total_rewards) + if len(total_rewards) == 0: + raise AssumeException("No rewards were collected during evaluation run") + avg_reward = np.mean(total_rewards) + terminate = world.learning_role.compare_and_save_policies( + {"avg_reward": avg_reward} + ) - # check reward improvement in evaluation run - # and store best run in eval folder - terminate = world.learning_role.compare_and_save_policies( - {"avg_reward": avg_reward} - ) + if world.learning_role.rl_algorithm_name == "ppo": + # TODO: add surrogate loss as a parameter to compare_and_save_policies + # PPO uses the surrogate loss to monitor policy updates. + # The surrogate loss quantifies how much the new policy has changed compared to the old one. + # If the surrogate loss becomes too small or too large, it can indicate issues: + # - A very small value may mean that the policy is near its optimum. + # - A large value could indicate excessive policy updates, leading to instability. + # + # It may be useful to terminate the training early based on the surrogate loss, + # especially if no significant improvement is expected, or if the model becomes unstable. + # + # In this example, the surrogate_loss could be computed, and then + # `compare_and_save_policies` can be used to check whether the training should be terminated. + + # surrogate_loss = + # terminate = world.learning_role.compare_and_save_policies({"surrogate_loss": surrogate_loss}) + + # Reset the PPO Rollout Buffer after each update + inter_episodic_data["buffer"].reset() + + total_rewards = world.output_role.get_sum_reward() + avg_reward = np.mean(total_rewards) + terminate = world.learning_role.compare_and_save_policies( + {"avg_reward": avg_reward} + ) inter_episodic_data["eval_episodes_done"] = eval_episode - # if we have not improved in the last x evaluations, we stop loop if terminate: break @@ -993,20 +1018,18 @@ def run_learning( world.reset() - # if at end of simulation save last policies if episode == (world.learning_role.training_episodes): world.learning_role.rl_algorithm.save_params( directory=f"{world.learning_role.trained_policies_save_path}/last_policies" ) - # container shutdown implicitly with new initialisation logger.info("################") logger.info("Training finished, Start evaluation run") world.export_csv_path = temp_csv_path world.reset() - # load scenario for evaluation + # Based on the parameters passed to setup_world, it is automatically recognized whether training or evaluation is to be performed. Here the final evaluation run is set up.
setup_world( world=world, scenario_data=scenario_data, @@ -1016,6 +1039,8 @@ def run_learning( world.learning_role.load_inter_episodic_data(inter_episodic_data) + logger.debug("Evaluation finished") + if __name__ == "__main__": data = read_grid(Path("examples/inputs/example_01d")) diff --git a/assume/strategies/learning_strategies.py b/assume/strategies/learning_strategies.py index 279539fb3..21bd04052 100644 --- a/assume/strategies/learning_strategies.py +++ b/assume/strategies/learning_strategies.py @@ -138,7 +138,8 @@ def __init__(self, *args, **kwargs): # based on learning config self.algorithm = kwargs.get("algorithm", "matd3") - actor_architecture = kwargs.get("actor_architecture", "mlp") + algo_config = kwargs.get(self.algorithm, {}) + actor_architecture = algo_config.get("actor_architecture", "mlp") if actor_architecture in actor_architecture_aliases.keys(): self.actor_architecture_class = actor_architecture_aliases[ @@ -180,11 +181,15 @@ def __init__(self, *args, **kwargs): elif Path(kwargs["trained_policies_save_path"]).is_dir(): self.load_actor_params(load_path=kwargs["trained_policies_save_path"]) + # Ensure action_noise is defined even when not in learning or evaluation mode + self.action_noise = None + self.collect_initial_experience_mode = None else: raise FileNotFoundError( f"No policies were provided for DRL unit {self.unit_id}!. Please provide a valid path to the trained policies." ) + def calculate_bids( self, unit: SupportsMinMax, @@ -242,7 +247,9 @@ def calculate_bids( # ============================================================================= # 2. Get the Actions, based on the observations # ============================================================================= - actions, noise = self.get_actions(next_observation) + # Depending on the algorithm, we call the algorithm-specific function that passes the observation through the actor and generates actions + # extra_info is either noise (MATD3) or log_probs (PPO) + actions, extra_info = self.get_actions(self, next_observation) # ============================================================================= # 3. Transform Actions into bids # ============================================================================= @@ -287,73 +294,19 @@ def calculate_bids( # store results in unit outputs as series to be written to the database by the unit operator unit.outputs["actions"].at[start] = actions - unit.outputs["exploration_noise"].at[start] = noise - - return bids - - def get_actions(self, next_observation): - """ - Determines actions based on the current observation, applying noise for exploration if in learning mode. - - Args - ---- - next_observation : torch.Tensor - The current observation data that influences bid prices. - - Returns - ------- - torch.Tensor - Actions that include bid prices for both inflexible and flexible components. - - Notes - ----- - If the agent is in learning mode, noise is added to encourage exploration. In initial exploration, - actions are derived from noise and the marginal cost to explore the action space around the marginal cost. When not in learning mode, - actions are generated by the actor network without added noise. Actions are clamped to [-1, 1].
- """ - - # distinction whether we are in learning mode or not to handle exploration realised with noise - if self.learning_mode and not self.perform_evaluation: - # if we are in learning mode the first x episodes we want to explore the entire action space - # to get a good initial experience, in the area around the costs of the agent - if self.collect_initial_experience_mode: - # define current action as solely noise - noise = ( - th.normal( - mean=0.0, std=0.2, size=(1, self.act_dim), dtype=self.float_type - ) - .to(self.device) - .squeeze() - ) + # unit.outputs["exploration_noise"][start] = noise + # TODO: Make this algo specific function + # Check if extra_info is noise or log_probs and store it accordingly - # ============================================================================= - # 2.1 Get Actions and handle exploration - # ============================================================================= - base_bid = next_observation[-1] - - # add noise to the last dimension of the observation - # needs to be adjusted if observation space is changed, because only makes sense - # if the last dimension of the observation space are the marginal cost - curr_action = noise + base_bid.clone().detach() - - else: - # if we are not in the initial exploration phase we choose the action with the actor neural net - # and add noise to the action - curr_action = self.actor(next_observation).detach() - noise = th.tensor( - self.action_noise.noise(), device=self.device, dtype=self.float_type - ) - curr_action += noise + if isinstance(extra_info, th.Tensor) and extra_info.shape == actions.shape: + unit.outputs["exploration_noise"].at[start] = extra_info # It's noise else: - # if we are not in learning mode we just use the actor neural net to get the action without adding noise - curr_action = self.actor(next_observation).detach() + unit.outputs["rl_log_probs"].append(extra_info) # It's log_probs + - # noise is an tensor with zeros, because we are not in learning mode - noise = th.zeros(self.act_dim, dtype=self.float_type) - - curr_action = curr_action.clamp(-1, 1) + bids = self.remove_empty_bids(bids) - return curr_action, noise + return bids def create_observation( self, diff --git a/assume/world.py b/assume/world.py index 75b21c8d2..ab1c86e08 100644 --- a/assume/world.py +++ b/assume/world.py @@ -6,7 +6,7 @@ import logging import sys import time -from datetime import datetime +from datetime import datetime, timedelta from pathlib import Path from mango import ( @@ -392,14 +392,45 @@ def add_rl_unit_operator(self, id: str = "Operator-RL") -> None: # after creation of an agent - we set additional context params if self.learning_mode: + # Extract algorithm from the Learning_Config + algorithm = self.learning_config.get("algorithm", "matd3") + + # Select correct train_freq based on the algorithm + if algorithm == "matd3": + train_freq = self.learning_config.get("matd3", {}).get("train_freq", "24h") + elif algorithm == "ppo": + train_freq = self.learning_config.get("ppo", {}).get("train_freq", "24h") + else: + train_freq = "24h" # Standard value if algorithm is not defined + unit_operator_agent._role_context.data.update( { "learning_agent_addr": self.learning_agent_addr, "train_start": self.start, "train_end": self.end, - "train_freq": self.learning_config.get("train_freq", "24h"), + "train_freq": train_freq, } ) + + + # Convert train_freq to hours for comparison + freq_value = int(train_freq[:-1]) # Extract the numerical value + freq_unit = train_freq[-1] # Extract the time unit (h for hours, d for 
days) + + # Convert the train_freq into hours + if freq_unit == "h": + train_freq_hours = freq_value + elif freq_unit == "d": + train_freq_hours = freq_value * 24 + else: + train_freq_hours = 24 # Default to 24 hours + + # Calculate time difference in hours + duration_hours = int((self.end - self.start) / timedelta(hours=1)) + + # Check if train_freq is larger than the time difference + if train_freq_hours > duration_hours: + print(f"Warning: The train frequency ({train_freq_hours}h) is larger than the time difference between start and end ({duration_hours}h).") else: unit_operator_agent._role_context.data.update( @@ -620,11 +651,13 @@ async def async_run(self, start_ts: datetime, end_ts: datetime): start_ts (datetime.datetime): The start timestamp for the simulation run. end_ts (datetime.datetime): The end timestamp for the simulation run. """ - logger.debug("activating container") + if not self.learning_mode: + logger.info("activating container") # agent is implicit added to self.container._agents async with activate(self.container) as c: await tasks_complete_or_sleeping(c) - logger.debug("all agents up - starting simulation") + if not self.learning_mode: + logger.info("all agents up - starting simulation") pbar = tqdm(total=end_ts - start_ts) # allow registration before first opening diff --git a/docker_configs/dashboard-definitions/ASSUME_Actor_Comparison.json b/docker_configs/dashboard-definitions/ASSUME_Actor_Comparison.json new file mode 100644 index 000000000..84a151fe2 --- /dev/null +++ b/docker_configs/dashboard-definitions/ASSUME_Actor_Comparison.json @@ -0,0 +1,2279 @@ +{ + "annotations": { + "list": [ + { + "builtIn": 1, + "datasource": { + "type": "grafana", + "uid": "-- Grafana --" + }, + "enable": true, + "hide": true, + "iconColor": "rgba(0, 211, 255, 1)", + "name": "Annotations & Alerts", + "target": { + "limit": 100, + "matchAny": false, + "tags": [], + "type": "dashboard" + }, + "type": "dashboard" + } + ] + }, + "description": "This dashboard offers various perspectives (performance, run time and robustness) to compare different Actor architectures.", + "editable": true, + "fiscalYearStartMonth": 0, + "graphTooltip": 0, + "id": 3, + "links": [], + "liveNow": false, + "panels": [ + { + "datasource": { + "type": "postgres", + "uid": "P7B13B9DF907EC40C" + }, + "description": "", + "gridPos": { + "h": 2, + "w": 24, + "x": 0, + "y": 0 + }, + "id": 10, + "options": { + "code": { + "language": "plaintext", + "showLineNumbers": false, + "showMiniMap": false + }, + "content": "# Welcome to ASSUMES Actor Comparison Dashboard\n", + "mode": "markdown" + }, + "pluginVersion": "10.4.0", + "targets": [ + { + "datasource": { + "type": "postgres", + "uid": "P7B13B9DF907EC40C" + }, + "format": "time_series", + "group": [], + "metricColumn": "none", + "rawQuery": false, + "rawSql": "SELECT\n datetime AS \"time\",\n power\nFROM market_dispatch\nWHERE\n $__timeFilter(datetime)\nORDER BY 1", + "refId": "A", + "select": [ + [ + { + "params": [ + "power" + ], + "type": "column" + } + ] + ], + "table": "market_dispatch", + "timeColumn": "datetime", + "timeColumnType": "timestamp", + "where": [ + { + "name": "$__timeFilter", + "params": [], + "type": "macro" + } + ] + } + ], + "type": "text" + }, + { + "collapsed": false, + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 2 + }, + "id": 16, + "panels": [], + "title": "Market Result Comparison", + "type": "row" + }, + { + "datasource": { + "type": "grafana-postgresql-datasource", + "uid": "P7B13B9DF907EC40C" + }, + "fieldConfig": { + 
"defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisGridShow": false, + "axisLabel": "Price in [€/MWh]", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineStyle": { + "fill": "solid" + }, + "lineWidth": 0.5, + "pointSize": 2, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "line" + } + }, + "mappings": [], + "max": 90, + "min": 33, + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "#a3222240", + "value": 55.7 + }, + { + "color": "#a322226e", + "value": 85.7 + } + ] + } + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "Max. Price" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "#8f8f8f52", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Avg. Price" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "#404040", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Min. Price" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "#5f6a7a82", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byRegexp", + "options": "Max. Price" + }, + "properties": [ + { + "id": "custom.fillBelowTo", + "value": "Min. Price" + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Q1" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "#00968252", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byRegexp", + "options": "Q3" + }, + "properties": [ + { + "id": "custom.fillBelowTo", + "value": "Q1" + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Q3" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "#0096824f", + "mode": "fixed" + } + } + ] + } + ] + }, + "gridPos": { + "h": 13, + "w": 12, + "x": 0, + "y": 3 + }, + "id": 25, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "pluginVersion": "10.4.0", + "targets": [ + { + "datasource": { + "type": "grafana-postgresql-datasource", + "uid": "P7B13B9DF907EC40C" + }, + "editorMode": "code", + "format": "table", + "rawQuery": true, + "rawSql": "SELECT\r\n product_start AS \"time\",\r\n avg(price) AS \"Avg. Price\",\r\n min(price) AS \"Min. Price\", \r\n percentile_disc (0.25) WITHIN GROUP ( ORDER BY price ) AS \"Q1\", \r\n percentile_disc (0.75) WITHIN GROUP ( ORDER BY price ) AS \"Q3\", \r\n max(price) AS \"Max. 
Price\"\r\nFROM market_meta\r\nWHERE (SUBSTRING(simulation, 1, regexp_instr(simulation, 'run')-2) IN ($case_study)) AND $__timeFilter(product_start) \r\nGROUP BY product_start\r\nORDER BY product_start", + "refId": "A", + "sql": { + "columns": [ + { + "parameters": [], + "type": "function" + } + ], + "groupBy": [ + { + "property": { + "type": "string" + }, + "type": "groupBy" + } + ], + "limit": 50 + } + } + ], + "title": "Market Clearing Price - Case 2 with LSTM", + "type": "timeseries" + }, + { + "datasource": { + "type": "grafana-postgresql-datasource", + "uid": "P7B13B9DF907EC40C" + }, + "description": "Accepted price and (flexible load) bid prices per unit in the chosen market", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic-by-name" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": true, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "area" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "transparent", + "value": null + }, + { + "color": "red", + "value": 85.7 + } + ] + }, + "unit": "currencyEUR" + }, + "overrides": [ + { + "matcher": { + "id": "byRegexp", + "options": "price .*" + }, + "properties": [ + { + "id": "unit", + "value": "€/MW" + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Accepted price: pp_1 - EOM" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "#a22223", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Bid price - flex:" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "#009682", + "mode": "shades" + } + } + ] + } + ] + }, + "gridPos": { + "h": 13, + "w": 12, + "x": 12, + "y": 3 + }, + "id": 26, + "options": { + "legend": { + "calcs": [ + "min", + "max", + "mean" + ], + "displayMode": "table", + "placement": "bottom", + "showLegend": true, + "sortBy": "Name", + "sortDesc": false + }, + "tooltip": { + "mode": "multi", + "sort": "desc" + } + }, + "pluginVersion": "9.2.15", + "targets": [ + { + "datasource": { + "type": "postgres", + "uid": "P7B13B9DF907EC40C" + }, + "editorMode": "code", + "format": "time_series", + "group": [ + { + "params": [ + "$__interval", + "none" + ], + "type": "time" + }, + { + "params": [ + "unit_id" + ], + "type": "column" + } + ], + "hide": false, + "metricColumn": "none", + "rawQuery": true, + "rawSql": "SELECT\r\n $__timeGroupAlias(start_time,$__interval),\r\n avg(accepted_price::float) AS \"Accepted price:\",\r\n max(price) AS \"Bid price - flex:\",\r\n concat(unit_id, ' - ', market_id) as \"unit_id\"\r\nFROM market_orders\r\nWHERE\r\n $__timeFilter(start_time) AND\r\n unit_id <> 'demand_EOM' AND \r\n simulation = '$simulation'\r\nGROUP BY 1, unit_id, market_id\r\nORDER BY 1", + "refId": "A", + "select": [ + [ + { + "params": [ + "original_price" + ], + "type": "column" + }, + { + "params": [ + "avg" + ], + "type": "aggregate" + }, + { + "params": [ + "price" + ], + "type": "alias" + } + ], + [ + { + "params": [ + "unit_id" + ], + "type": "column" + }, + { + "params": [ + "unit_id" + ], + 
"type": "alias" + } + ] + ], + "sql": { + "columns": [ + { + "parameters": [], + "type": "function" + } + ], + "groupBy": [ + { + "property": { + "type": "string" + }, + "type": "groupBy" + } + ], + "limit": 50 + }, + "table": "market_orders", + "timeColumn": "start_time", + "timeColumnType": "timestamp", + "where": [ + { + "name": "$__timeFilter", + "params": [], + "type": "macro" + }, + { + "datatype": "text", + "name": "", + "params": [ + "market_id", + "=", + "'$market'" + ], + "type": "expression" + }, + { + "datatype": "text", + "name": "", + "params": [ + "simulation", + "=", + "'$simulation'" + ], + "type": "expression" + } + ] + } + ], + "title": "Bid Prices: $simulation", + "transformations": [ + { + "id": "filterFieldsByName", + "options": { + "byVariable": false, + "include": { + "names": [ + "Time", + "Bid price - flex: pp_1 - EOM", + "Bid price - flex: pp_10 - EOM", + "Bid price - flex: pp_11 - EOM", + "Bid price - flex: pp_2 - EOM", + "Bid price - flex: pp_3 - EOM", + "Bid price - flex: pp_4 - EOM", + "Bid price - flex: pp_5 - EOM", + "Bid price - flex: pp_6 - EOM", + "Bid price - flex: pp_7 - EOM", + "Bid price - flex: pp_8 - EOM", + "Bid price - flex: pp_9 - EOM", + "Accepted price: pp_1 - EOM" + ] + } + } + } + ], + "type": "timeseries" + }, + { + "datasource": { + "type": "grafana-postgresql-datasource", + "uid": "P7B13B9DF907EC40C" + }, + "description": "Averaged over all simulation runs.", + "fieldConfig": { + "defaults": { + "color": { + "fixedColor": "text", + "mode": "fixed" + }, + "custom": { + "align": "auto", + "cellOptions": { + "type": "color-text" + }, + "filterable": false, + "inspect": false + }, + "decimals": 2, + "mappings": [ + { + "options": { + "example_02a_harder": { + "color": "dark-blue", + "index": 0, + "text": "Case 1" + }, + "example_02a_harder_LF": { + "color": "semi-dark-blue", + "index": 1, + "text": "Case 1 (LF)" + }, + "example_02a_harder_lstm": { + "color": "blue", + "index": 2, + "text": "Case 1 with LSTM" + }, + "example_02a_harder_lstm_LF": { + "color": "light-blue", + "index": 3, + "text": "Case 1 with LSTM (LF)" + }, + "example_02b_harder": { + "color": "dark-orange", + "index": 4, + "text": "Case 2" + }, + "example_02b_harder_lstm": { + "color": "super-light-orange", + "index": 5, + "text": "Case 2 with LSTM" + } + }, + "type": "value" + } + ], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + } + ] + } + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "Case Study" + }, + "properties": [ + { + "id": "custom.width", + "value": 163 + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Overall Average" + }, + "properties": [ + { + "id": "custom.width", + "value": 172 + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Std. Dev" + }, + "properties": [ + { + "id": "custom.width", + "value": 82 + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Std. Dev." 
+ }, + "properties": [ + { + "id": "custom.width", + "value": 138 + } + ] + } + ] + }, + "gridPos": { + "h": 8, + "w": 9, + "x": 0, + "y": 16 + }, + "id": 24, + "options": { + "cellHeight": "sm", + "footer": { + "countRows": false, + "fields": [], + "reducer": [ + "sum" + ], + "show": false + }, + "showHeader": true, + "sortBy": [] + }, + "pluginVersion": "10.4.0", + "targets": [ + { + "datasource": { + "type": "grafana-postgresql-datasource", + "uid": "P7B13B9DF907EC40C" + }, + "editorMode": "code", + "format": "table", + "rawQuery": true, + "rawSql": "SELECT\r\n simulation as \"Case Study\", \r\n avg(\"Avg. Price\") as \"Overall Average\", \r\n stddev(\"Avg. Price\") as \"Std. Dev.\"\r\nFROM (\r\n SELECT\r\n SUBSTRING(simulation, 1, regexp_instr(simulation, 'run')-2) as \"simulation\", \r\n substr(simulation, regexp_instr(simulation, 'run')) as \"run\",\r\n avg(price) AS \"Avg. Price\"\r\n FROM market_meta\r\n WHERE (\"market_id\" LIKE 'EOM')\r\n GROUP BY simulation\r\n ORDER BY 1, LENGTH(simulation), 2\r\n) as subselect\r\nGROUP BY simulation", + "refId": "A", + "sql": { + "columns": [ + { + "parameters": [], + "type": "function" + } + ], + "groupBy": [ + { + "property": { + "type": "string" + }, + "type": "groupBy" + } + ], + "limit": 50 + } + } + ], + "title": "Overall Avg. Market Clearing Price", + "type": "table" + }, + { + "datasource": { + "type": "grafana-postgresql-datasource", + "uid": "P7B13B9DF907EC40C" + }, + "description": "", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic-by-name" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "hidden", + "fillOpacity": 80, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineWidth": 1, + "scaleDistribution": { + "type": "linear" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [ + { + "options": { + "example_02a_harder": { + "color": "dark-blue", + "index": 0, + "text": "Case 1" + }, + "example_02a_harder_LF": { + "color": "semi-dark-blue", + "index": 1, + "text": "Case 1 (LF)" + }, + "example_02a_harder_lstm": { + "color": "blue", + "index": 2, + "text": "Case 1 with LSTM" + }, + "example_02a_harder_lstm_LF": { + "color": "light-blue", + "index": 3, + "text": "Case 1 with LSTM (LF)" + }, + "example_02b_harder": { + "color": "dark-orange", + "index": 4, + "text": "Case 2" + }, + "example_02b_harder_lstm": { + "color": "super-light-orange", + "index": 5, + "text": "Case 2 with LSTM" + } + }, + "type": "value" + } + ], + "thresholds": { + "mode": "percentage", + "steps": [ + { + "color": "green", + "value": null + } + ] + } + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "Case 2" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "dark-orange", + "mode": "fixed" + } + } + ] + } + ] + }, + "gridPos": { + "h": 8, + "w": 15, + "x": 9, + "y": 16 + }, + "id": 23, + "options": { + "barRadius": 0, + "barWidth": 0.97, + "colorByField": "simulation", + "fullHighlight": false, + "groupWidth": 0.7, + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "orientation": "auto", + "showValue": "always", + "stacking": "none", + "tooltip": { + "mode": "multi", + "sort": "none" + }, + "xTickLabelRotation": 0, + "xTickLabelSpacing": 0 + }, + "targets": [ + { + "datasource": { + "type": "grafana-postgresql-datasource", + "uid": "P7B13B9DF907EC40C" + }, + "editorMode": 
"code", + "format": "table", + "hide": false, + "rawQuery": true, + "rawSql": "SELECT\r\n SUBSTRING(simulation, 1, regexp_instr(simulation, 'run')-2) as \"simulation\", \r\n substr(simulation, regexp_instr(simulation, 'run')) as \"run\",\r\n avg(price) AS \"Avg. Price\"\r\nFROM market_meta\r\nWHERE (\"market_id\" LIKE 'EOM')\r\nGROUP BY simulation\r\nORDER BY 1, LENGTH(simulation), 2;", + "refId": "Market Clearing Price", + "sql": { + "columns": [ + { + "parameters": [], + "type": "function" + } + ], + "groupBy": [ + { + "property": { + "type": "string" + }, + "type": "groupBy" + } + ], + "limit": 50 + } + } + ], + "title": "Avg. Market Clearing Price per Run", + "type": "barchart" + }, + { + "collapsed": false, + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 24 + }, + "id": 18, + "panels": [], + "title": "Computation Time Comparison", + "type": "row" + }, + { + "datasource": { + "type": "grafana-postgresql-datasource", + "uid": "P7B13B9DF907EC40C" + }, + "description": "Comparison of Total Run Times for each Simulation Run", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic-by-name" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "bars", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineStyle": { + "fill": "solid" + }, + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "decimals": 2, + "mappings": [ + { + "options": { + "1": { + "color": "dark-blue", + "index": 0, + "text": "Case 1" + }, + "2": { + "color": "semi-dark-blue", + "index": 1, + "text": "Case 1 (LF)" + }, + "3": { + "color": "blue", + "index": 2, + "text": "Case 1 with LSTM" + }, + "4": { + "color": "light-blue", + "index": 3, + "text": "Case 1 with LSTM (LF)" + }, + "5": { + "color": "semi-dark-purple", + "index": 4, + "text": "Case 2" + }, + "6": { + "color": "light-purple", + "index": 5, + "text": "Case 2 with LSTM" + } + }, + "type": "value" + } + ], + "min": 0, + "thresholds": { + "mode": "percentage", + "steps": [ + { + "color": "green", + "value": null + } + ] + }, + "unit": "s" + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "min" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "super-light-green", + "mode": "fixed" + } + }, + { + "id": "custom.drawStyle", + "value": "points" + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "avg" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "light-yellow", + "mode": "fixed" + } + }, + { + "id": "custom.drawStyle", + "value": "points" + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "max" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "super-light-red", + "mode": "fixed" + } + }, + { + "id": "custom.drawStyle", + "value": "bars" + }, + { + "id": "custom.showPoints", + "value": "always" + }, + { + "id": "custom.lineWidth", + "value": 0 + } + ] + }, + { + "matcher": { + "id": "byRegexp", + "options": "max" + }, + "properties": [ + { + "id": "custom.fillBelowTo", + "value": "min" + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "index" + }, + "properties": [ + { + "id": 
"custom.axisGridShow", + "value": false + }, + { + "id": "custom.axisPlacement", + "value": "auto" + }, + { + "id": "max", + "value": 6.5 + }, + { + "id": "unit", + "value": "time: " + }, + { + "id": "min", + "value": 0.5 + } + ] + } + ] + }, + "gridPos": { + "h": 11, + "w": 11, + "x": 0, + "y": 25 + }, + "id": 19, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "multi", + "sort": "none" + }, + "xField": "index" + }, + "pluginVersion": "10.4.2", + "targets": [ + { + "datasource": { + "type": "grafana-postgresql-datasource", + "uid": "P7B13B9DF907EC40C" + }, + "editorMode": "code", + "format": "table", + "hide": true, + "rawQuery": true, + "rawSql": "SELECT \r\n SUBSTRING(ident, 1, 18) as scenario,\r\n SUBSTRING(ident, 20) as run,\r\n value \r\nFROM \r\n kpis \r\nWHERE \r\n variable = 'total_run_time'\r\nORDER BY \r\n scenario, LENGTH(ident), run", + "refId": "A", + "sql": { + "columns": [ + { + "parameters": [], + "type": "function" + } + ], + "groupBy": [ + { + "property": { + "type": "string" + }, + "type": "groupBy" + } + ], + "limit": 50 + } + }, + { + "datasource": { + "type": "grafana-postgresql-datasource", + "uid": "P7B13B9DF907EC40C" + }, + "editorMode": "code", + "format": "table", + "hide": false, + "rawQuery": true, + "rawSql": "SELECT\r\n scenario, \r\n ROW_NUMBER() OVER (ORDER BY scenario) as index, \r\n min(value),\r\n avg(value), \r\n max(value)\r\nFROM (\r\n SELECT \r\n SUBSTRING(ident, 1, regexp_instr(ident, 'run')-2) as scenario,\r\n SUBSTRING(ident, regexp_instr(ident, 'run')) as run,\r\n value \r\n FROM \r\n kpis \r\n WHERE \r\n variable = 'total_run_time'\r\n ORDER BY \r\n scenario, LENGTH(ident), run\r\n ) as subselect\r\nGROUP BY \r\n scenario", + "refId": "B", + "sql": { + "columns": [ + { + "parameters": [], + "type": "function" + } + ], + "groupBy": [ + { + "property": { + "type": "string" + }, + "type": "groupBy" + } + ], + "limit": 50 + } + } + ], + "title": "Total Run Time", + "transformations": [ + { + "disabled": true, + "id": "calculateField", + "options": { + "alias": "index_2", + "index": { + "asPercentile": false + }, + "mode": "index", + "reduce": { + "reducer": "sum" + }, + "replaceFields": false + } + }, + { + "disabled": true, + "id": "calculateField", + "options": { + "alias": "index", + "binary": { + "left": "index_2", + "right": "1" + }, + "mode": "binary", + "reduce": { + "reducer": "sum" + } + } + } + ], + "type": "trend" + }, + { + "datasource": { + "type": "grafana-postgresql-datasource", + "uid": "P7B13B9DF907EC40C" + }, + "description": "How long takes one simulation episode?", + "fieldConfig": { + "defaults": { + "color": { + "fixedColor": "green", + "mode": "palette-classic-by-name" + }, + "custom": { + "fillOpacity": 80, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineWidth": 1 + }, + "fieldMinMax": false, + "mappings": [ + { + "options": { + "example_02a_harder": { + "color": "dark-blue", + "index": 0, + "text": "Case 1" + }, + "example_02a_harder_LF": { + "color": "semi-dark-blue", + "index": 1, + "text": "Case 1 (LF)" + }, + "example_02a_harder_lstm": { + "color": "blue", + "index": 2, + "text": "Case 1 with LSTM" + }, + "example_02a_harder_lstm_LF": { + "color": "light-blue", + "index": 3, + "text": "Case 1 with LSTM (LF)" + }, + "example_02b_harder": { + "color": "semi-dark-purple", + "index": 4, + "text": "Case 2" + }, + "example_02b_harder_lstm": { + "color": "light-purple", + 
"index": 5, + "text": "Case 2 with LSTM" + } + }, + "type": "value" + } + ], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + } + ] + }, + "unit": "s" + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "Case 1" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "semi-dark-blue", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Case 1 (LF)" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "semi-dark-blue", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Case 1 with LSTM" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "blue", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Case 1 with LSTM (LF)" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "light-blue", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Case 2" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "dark-orange", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Case 2 with LSTM" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "super-light-orange", + "mode": "fixed" + } + } + ] + } + ] + }, + "gridPos": { + "h": 15, + "w": 9, + "x": 11, + "y": 25 + }, + "id": 20, + "options": { + "bucketCount": 200, + "combine": false, + "legend": { + "calcs": [ + "mean", + "stdDev", + "count" + ], + "displayMode": "table", + "placement": "bottom", + "showLegend": true + } + }, + "pluginVersion": "10.4.2", + "targets": [ + { + "datasource": { + "type": "grafana-postgresql-datasource", + "uid": "P7B13B9DF907EC40C" + }, + "editorMode": "code", + "format": "table", + "hide": false, + "rawQuery": true, + "rawSql": "SELECT\r\n variable, \r\n ident,\r\n value as \"Case 1\"\r\nFROM \r\n kpis\r\nWHERE \r\n variable = 'run_time' AND \r\n SUBSTRING(ident, 1, regexp_instr(ident, 'run')-2) = 'example_02a_harder' AND \r\n (CASE \r\n WHEN 'exploration' in ($sim_type_filter) AND 'evaluation' in ($sim_type_filter) AND 'learning' in ($sim_type_filter) THEN 1=1\r\n WHEN 'exploration' in ($sim_type_filter) AND 'learning' in ($sim_type_filter) THEN ident !~ 'eval'\r\n WHEN 'exploration' in ($sim_type_filter) AND 'evaluation' in ($sim_type_filter) THEN ident ~ '[0-9]{1,2}[_]{1}[1-5]{1}$' OR ident ~ 'eval'\r\n WHEN 'evaluation' in ($sim_type_filter) AND 'learning' in ($sim_type_filter) THEN ident !~ '[0-9]{1,2}[_]{1}[1-5]{1}$' OR ident ~ 'eval'\r\n WHEN 'exploration' IN ($sim_type_filter) THEN ident ~ '[0-9]{1,2}[_]{1}[1-5]{1}$' AND ident !~ 'eval'\r\n WHEN 'evaluation' IN ($sim_type_filter) THEN ident ~ 'eval'\r\n WHEN 'learning' IN ($sim_type_filter) THEN ident !~ '[0-9]{1,2}[_]{1}[1-5]{1}$' AND ident !~ 'eval'\r\n END)", + "refId": "Case1", + "sql": { + "columns": [ + { + "parameters": [], + "type": "function" + } + ], + "groupBy": [ + { + "property": { + "type": "string" + }, + "type": "groupBy" + } + ], + "limit": 50 + } + }, + { + "datasource": { + "type": "grafana-postgresql-datasource", + "uid": "P7B13B9DF907EC40C" + }, + "editorMode": "code", + "format": "table", + "hide": false, + "rawQuery": true, + "rawSql": "SELECT\r\n variable, \r\n ident,\r\n value as \"Case 1 (LF)\"\r\nFROM \r\n kpis\r\nWHERE \r\n variable = 'run_time' AND \r\n SUBSTRING(ident, 1, regexp_instr(ident, 'run')-2) = 'example_02a_harder_LF' AND \r\n (CASE \r\n WHEN 'exploration' in 
($sim_type_filter) AND 'evaluation' in ($sim_type_filter) AND 'learning' in ($sim_type_filter) THEN 1=1\r\n WHEN 'exploration' in ($sim_type_filter) AND 'learning' in ($sim_type_filter) THEN ident !~ 'eval'\r\n WHEN 'exploration' in ($sim_type_filter) AND 'evaluation' in ($sim_type_filter) THEN ident ~ '[0-9]{1,2}[_]{1}[1-5]{1}$' OR ident ~ 'eval'\r\n WHEN 'evaluation' in ($sim_type_filter) AND 'learning' in ($sim_type_filter) THEN ident !~ '[0-9]{1,2}[_]{1}[1-5]{1}$' OR ident ~ 'eval'\r\n WHEN 'exploration' IN ($sim_type_filter) THEN ident ~ '[0-9]{1,2}[_]{1}[1-5]{1}$' AND ident !~ 'eval'\r\n WHEN 'evaluation' IN ($sim_type_filter) THEN ident ~ 'eval'\r\n WHEN 'learning' IN ($sim_type_filter) THEN ident !~ '[0-9]{1,2}[_]{1}[1-5]{1}$' AND ident !~ 'eval'\r\n END)", + "refId": "Case1_LF", + "sql": { + "columns": [ + { + "parameters": [], + "type": "function" + } + ], + "groupBy": [ + { + "property": { + "type": "string" + }, + "type": "groupBy" + } + ], + "limit": 50 + } + }, + { + "datasource": { + "type": "grafana-postgresql-datasource", + "uid": "P7B13B9DF907EC40C" + }, + "editorMode": "code", + "format": "table", + "hide": false, + "rawQuery": true, + "rawSql": "SELECT\r\n variable, \r\n ident,\r\n value as \"Case 1 with LSTM\"\r\nFROM \r\n kpis\r\nWHERE \r\n variable = 'run_time' AND\r\n SUBSTRING(ident, 1, regexp_instr(ident, 'run')-2) = 'example_02a_harder_lstm' AND \r\n (CASE \r\n WHEN 'exploration' in ($sim_type_filter) AND 'evaluation' in ($sim_type_filter) AND 'learning' in ($sim_type_filter) THEN 1=1\r\n WHEN 'exploration' in ($sim_type_filter) AND 'learning' in ($sim_type_filter) THEN ident !~ 'eval'\r\n WHEN 'exploration' in ($sim_type_filter) AND 'evaluation' in ($sim_type_filter) THEN ident ~ '[0-9]{1,2}[_]{1}[1-5]{1}$' OR ident ~ 'eval'\r\n WHEN 'evaluation' in ($sim_type_filter) AND 'learning' in ($sim_type_filter) THEN ident !~ '[0-9]{1,2}[_]{1}[1-5]{1}$' OR ident ~ 'eval'\r\n WHEN 'exploration' IN ($sim_type_filter) THEN ident ~ '[0-9]{1,2}[_]{1}[1-5]{1}$' AND ident !~ 'eval'\r\n WHEN 'evaluation' IN ($sim_type_filter) THEN ident ~ 'eval'\r\n WHEN 'learning' IN ($sim_type_filter) THEN ident !~ '[0-9]{1,2}[_]{1}[1-5]{1}$' AND ident !~ 'eval'\r\n END)", + "refId": "Case1_LSTM", + "sql": { + "columns": [ + { + "parameters": [], + "type": "function" + } + ], + "groupBy": [ + { + "property": { + "type": "string" + }, + "type": "groupBy" + } + ], + "limit": 50 + } + }, + { + "datasource": { + "type": "grafana-postgresql-datasource", + "uid": "P7B13B9DF907EC40C" + }, + "editorMode": "code", + "format": "table", + "hide": false, + "rawQuery": true, + "rawSql": "SELECT\r\n variable, \r\n ident,\r\n value as \"Case 1 with LSTM (LF)\"\r\nFROM \r\n kpis\r\nWHERE \r\n variable = 'run_time' AND\r\n SUBSTRING(ident, 1, regexp_instr(ident, 'run')-2) = 'example_02a_harder_lstm_LF' AND \r\n (CASE \r\n WHEN 'exploration' in ($sim_type_filter) AND 'evaluation' in ($sim_type_filter) AND 'learning' in ($sim_type_filter) THEN 1=1\r\n WHEN 'exploration' in ($sim_type_filter) AND 'learning' in ($sim_type_filter) THEN ident !~ 'eval'\r\n WHEN 'exploration' in ($sim_type_filter) AND 'evaluation' in ($sim_type_filter) THEN ident ~ '[0-9]{1,2}[_]{1}[1-5]{1}$' OR ident ~ 'eval'\r\n WHEN 'evaluation' in ($sim_type_filter) AND 'learning' in ($sim_type_filter) THEN ident !~ '[0-9]{1,2}[_]{1}[1-5]{1}$' OR ident ~ 'eval'\r\n WHEN 'exploration' IN ($sim_type_filter) THEN ident ~ '[0-9]{1,2}[_]{1}[1-5]{1}$' AND ident !~ 'eval'\r\n WHEN 'evaluation' IN ($sim_type_filter) THEN ident ~ 'eval'\r\n WHEN 
'learning' IN ($sim_type_filter) THEN ident !~ '[0-9]{1,2}[_]{1}[1-5]{1}$' AND ident !~ 'eval'\r\n END)", + "refId": "Case 1_LSTM_LF", + "sql": { + "columns": [ + { + "parameters": [], + "type": "function" + } + ], + "groupBy": [ + { + "property": { + "type": "string" + }, + "type": "groupBy" + } + ], + "limit": 50 + } + }, + { + "datasource": { + "type": "grafana-postgresql-datasource", + "uid": "P7B13B9DF907EC40C" + }, + "editorMode": "code", + "format": "table", + "hide": false, + "rawQuery": true, + "rawSql": "SELECT\r\n variable, \r\n ident,\r\n value as \"Case 2\"\r\nFROM \r\n kpis\r\nWHERE \r\n variable = 'run_time' AND \r\n SUBSTRING(ident, 1, regexp_instr(ident, 'run')-2) = 'example_02b_harder' AND \r\n (CASE \r\n WHEN 'exploration' in ($sim_type_filter) AND 'evaluation' in ($sim_type_filter) AND 'learning' in ($sim_type_filter) THEN 1=1\r\n WHEN 'exploration' in ($sim_type_filter) AND 'learning' in ($sim_type_filter) THEN ident !~ 'eval'\r\n WHEN 'exploration' in ($sim_type_filter) AND 'evaluation' in ($sim_type_filter) THEN ident ~ '[0-9]{1,2}[_]{1}[1-5]{1}$' OR ident ~ 'eval'\r\n WHEN 'evaluation' in ($sim_type_filter) AND 'learning' in ($sim_type_filter) THEN ident !~ '[0-9]{1,2}[_]{1}[1-5]{1}$' OR ident ~ 'eval'\r\n WHEN 'exploration' IN ($sim_type_filter) THEN ident ~ '[0-9]{1,2}[_]{1}[1-5]{1}$' AND ident !~ 'eval'\r\n WHEN 'evaluation' IN ($sim_type_filter) THEN ident ~ 'eval'\r\n WHEN 'learning' IN ($sim_type_filter) THEN ident !~ '[0-9]{1,2}[_]{1}[1-5]{1}$' AND ident !~ 'eval'\r\n END)", + "refId": "Case2", + "sql": { + "columns": [ + { + "parameters": [], + "type": "function" + } + ], + "groupBy": [ + { + "property": { + "type": "string" + }, + "type": "groupBy" + } + ], + "limit": 50 + } + }, + { + "datasource": { + "type": "grafana-postgresql-datasource", + "uid": "P7B13B9DF907EC40C" + }, + "editorMode": "code", + "format": "table", + "hide": false, + "rawQuery": true, + "rawSql": "SELECT\r\n variable, \r\n ident,\r\n value as \"Case 2 with LSTM\"\r\nFROM \r\n kpis\r\nWHERE \r\n variable = 'run_time' AND \r\n SUBSTRING(ident, 1, regexp_instr(ident, 'run')-2) = 'example_02b_harder_lstm' AND \r\n (CASE \r\n WHEN 'exploration' in ($sim_type_filter) AND 'evaluation' in ($sim_type_filter) AND 'learning' in ($sim_type_filter) THEN 1=1\r\n WHEN 'exploration' in ($sim_type_filter) AND 'learning' in ($sim_type_filter) THEN ident !~ 'eval'\r\n WHEN 'exploration' in ($sim_type_filter) AND 'evaluation' in ($sim_type_filter) THEN ident ~ '[0-9]{1,2}[_]{1}[1-5]{1}$' OR ident ~ 'eval'\r\n WHEN 'evaluation' in ($sim_type_filter) AND 'learning' in ($sim_type_filter) THEN ident !~ '[0-9]{1,2}[_]{1}[1-5]{1}$' OR ident ~ 'eval'\r\n WHEN 'exploration' IN ($sim_type_filter) THEN ident ~ '[0-9]{1,2}[_]{1}[1-5]{1}$' AND ident !~ 'eval'\r\n WHEN 'evaluation' IN ($sim_type_filter) THEN ident ~ 'eval'\r\n WHEN 'learning' IN ($sim_type_filter) THEN ident !~ '[0-9]{1,2}[_]{1}[1-5]{1}$' AND ident !~ 'eval'\r\n END)", + "refId": "Case2_LSTM", + "sql": { + "columns": [ + { + "parameters": [], + "type": "function" + } + ], + "groupBy": [ + { + "property": { + "type": "string" + }, + "type": "groupBy" + } + ], + "limit": 50 + } + } + ], + "title": "Run Time per Simulation Episode", + "type": "histogram" + }, + { + "datasource": { + "type": "grafana-postgresql-datasource", + "uid": "P7B13B9DF907EC40C" + }, + "description": "How many episodes were needed for training, until early stopping was triggered?", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + 
"axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "bars", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "decimals": 1, + "fieldMinMax": false, + "mappings": [ + { + "options": { + "1": { + "color": "dark-blue", + "index": 0, + "text": "Case 1" + }, + "2": { + "color": "semi-dark-blue", + "index": 1, + "text": "Case 1 (LF)" + }, + "3": { + "color": "blue", + "index": 2, + "text": "Case 1 with LSTM" + }, + "4": { + "color": "light-blue", + "index": 3, + "text": "Case 1 with LSTM (LF)" + }, + "5": { + "color": "semi-dark-purple", + "index": 4, + "text": "Case 2" + }, + "6": { + "color": "light-purple", + "index": 5, + "text": "Case 2 with LSTM" + } + }, + "type": "value" + } + ], + "min": 0, + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "short" + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "min. episodes" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "super-light-green", + "mode": "fixed" + } + }, + { + "id": "custom.drawStyle", + "value": "points" + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "min. eval episodes" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "light-green", + "mode": "fixed" + } + }, + { + "id": "custom.drawStyle", + "value": "points" + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "avg. episodes" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "super-light-yellow", + "mode": "fixed" + } + }, + { + "id": "custom.drawStyle", + "value": "points" + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "avg. eval episodes" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "light-yellow", + "mode": "fixed" + } + }, + { + "id": "custom.drawStyle", + "value": "points" + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "max. episodes" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "super-light-red", + "mode": "fixed" + } + }, + { + "id": "custom.fillBelowTo", + "value": "min. episodes" + }, + { + "id": "custom.showPoints", + "value": "always" + }, + { + "id": "custom.lineWidth", + "value": 0 + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "max. eval episodes" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "light-red", + "mode": "fixed" + } + }, + { + "id": "custom.fillBelowTo", + "value": "min. 
eval episodes" + }, + { + "id": "custom.showPoints", + "value": "always" + }, + { + "id": "custom.lineWidth", + "value": 0 + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "index" + }, + "properties": [ + { + "id": "unit", + "value": "time: " + }, + { + "id": "custom.axisGridShow", + "value": false + }, + { + "id": "max", + "value": 6.5 + }, + { + "id": "min", + "value": 0.5 + } + ] + } + ] + }, + "gridPos": { + "h": 11, + "w": 11, + "x": 0, + "y": 36 + }, + "id": 21, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + }, + "xField": "index" + }, + "pluginVersion": "10.4.2", + "targets": [ + { + "datasource": { + "type": "grafana-postgresql-datasource", + "uid": "P7B13B9DF907EC40C" + }, + "editorMode": "code", + "format": "table", + "hide": false, + "rawQuery": true, + "rawSql": "CREATE EXTENSION IF NOT EXISTS tablefunc; \r\n\r\nSELECT \r\n scenario,\r\n ROW_NUMBER() OVER (ORDER BY scenario) as index,\r\n min(episodes_done) as \"min. episodes\", \r\n avg(episodes_done) as \"avg. episodes\", \r\n max(episodes_done) as \"max. episodes\"\r\nFROM (\r\n SELECT \r\n SUBSTRING(ident, 1, regexp_instr(ident, 'run')-2) as scenario, \r\n SUBSTRING(ident, regexp_instr(ident, 'run')) as run, \r\n episodes_done::NUMERIC\r\n FROM crosstab( \r\n 'SELECT ident, variable, value \r\n FROM kpis \r\n WHERE variable = ''episodes_done'' OR variable = ''eval_episodes_done'' \r\n ORDER BY ident, LENGTH(ident), variable',\r\n 'VALUES(''episodes_done''), (''eval_episodes_done'')' \r\n ) as ct (ident text, episodes_done text, eval_episodes_done text)\r\n ORDER BY scenario, LENGTH(ident), run) as subselect\r\nGROUP BY scenario", + "refId": "A", + "sql": { + "columns": [ + { + "parameters": [], + "type": "function" + } + ], + "groupBy": [ + { + "property": { + "type": "string" + }, + "type": "groupBy" + } + ], + "limit": 50 + } + }, + { + "datasource": { + "type": "grafana-postgresql-datasource", + "uid": "P7B13B9DF907EC40C" + }, + "editorMode": "code", + "format": "table", + "hide": true, + "rawQuery": true, + "rawSql": "CREATE EXTENSION IF NOT EXISTS tablefunc; \r\n\r\nSELECT\r\n scenario\r\nFROM (\r\n SELECT \r\n SUBSTRING(ident, 1, regexp_instr(ident, 'run')-2) as scenario, \r\n SUBSTRING(ident, regexp_instr(ident, 'run')) as run, \r\n episodes_done, \r\n eval_episodes_done \r\n FROM crosstab( \r\n 'SELECT ident, variable, value \r\n FROM kpis \r\n WHERE variable = ''episodes_done'' OR variable = ''eval_episodes_done'' \r\n ORDER BY ident, LENGTH(ident), variable', \r\n 'VALUES(''episodes_done''), (''eval_episodes_done'')' \r\n ) as ct (ident text, episodes_done text, eval_episodes_done text)\r\n ORDER BY scenario, LENGTH(ident), run) as subselect\r\nGROUP BY \r\n scenario ", + "refId": "B", + "sql": { + "columns": [ + { + "parameters": [], + "type": "function" + } + ], + "groupBy": [ + { + "property": { + "type": "string" + }, + "type": "groupBy" + } + ], + "limit": 50 + } + } + ], + "title": "Number of Episodes", + "type": "trend" + }, + { + "collapsed": true, + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 47 + }, + "id": 27, + "panels": [ + { + "datasource": { + "type": "grafana-postgresql-datasource", + "uid": "P7B13B9DF907EC40C" + }, + "description": "How much profit make all learning units combined per evaluation episode? 
(Trained models in eval_episode 1 are mostly based on exploration runs and performing badly.)", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "fillOpacity": 80, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineWidth": 1, + "scaleDistribution": { + "type": "linear" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "min": -3000000, + "thresholds": { + "mode": "percentage", + "steps": [ + { + "color": "text" + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 3, + "w": 3, + "x": 0, + "y": 50 + }, + "id": 22, + "options": { + "barRadius": 0, + "barWidth": 0.97, + "fullHighlight": false, + "groupWidth": 0.7, + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "orientation": "vertical", + "showValue": "auto", + "stacking": "none", + "tooltip": { + "mode": "single", + "sort": "none" + }, + "xField": "episode", + "xTickLabelRotation": 0, + "xTickLabelSpacing": 0 + }, + "targets": [ + { + "datasource": { + "type": "grafana-postgresql-datasource", + "uid": "P7B13B9DF907EC40C" + }, + "editorMode": "code", + "format": "table", + "rawQuery": true, + "rawSql": "SELECT\r\n SUBSTRING(simulation, 1, regexp_instr(simulation, 'eval')-2) as \"simulation run\",\r\n episode, \r\n sum(profit) as \"Sum (Learning Agents Profit)\"\r\nFROM rl_params\r\nWHERE simulation ~ '${simulation}_' AND perform_evaluation is true \r\nGROUP BY simulation, episode\r\nORDER BY 1, episode", + "refId": "A", + "sql": { + "columns": [ + { + "parameters": [], + "type": "function" + } + ], + "groupBy": [ + { + "property": { + "type": "string" + }, + "type": "groupBy" + } + ], + "limit": 50 + } + } + ], + "title": "Total Evaluation Profit", + "type": "barchart" + } + ], + "title": "RIP - Plot Friedhof", + "type": "row" + } + ], + "refresh": "", + "schemaVersion": 39, + "tags": [], + "templating": { + "list": [ + { + "current": { + "selected": false, + "text": "example_02b_harder_run_4", + "value": "example_02b_harder_run_4" + }, + "datasource": { + "type": "postgres", + "uid": "P7B13B9DF907EC40C" + }, + "definition": "SELECT DISTINCT\nSUBSTRING(m.simulation, 0, LENGTH(m.simulation) +1 - strpos(REVERSE(m.simulation),'_')) AS market_simulation\nFROM rl_params m\nwhere learning_mode is True and perform_evaluation is False", + "description": "Can choose which simulation we want to show ", + "hide": 0, + "includeAll": false, + "multi": false, + "name": "simulation", + "options": [], + "query": "SELECT DISTINCT\nSUBSTRING(m.simulation, 0, LENGTH(m.simulation) +1 - strpos(REVERSE(m.simulation),'_')) AS market_simulation\nFROM rl_params m\nwhere learning_mode is True and perform_evaluation is False", + "refresh": 2, + "regex": "", + "skipUrlSync": false, + "sort": 1, + "type": "query" + }, + { + "current": { + "selected": false, + "text": "pp_6", + "value": "pp_6" + }, + "datasource": { + "type": "postgres", + "uid": "P7B13B9DF907EC40C" + }, + "definition": "SELECT DISTINCT unit\nFROM rl_params\nwhere simulation ~ '^${simulation}_[1-9]+'", + "description": "All units that have an reinforcment learning strategy and hence have the Rl specific parameteres logged", + "hide": 0, + "includeAll": false, + "label": "rl_unit", + "multi": false, + "name": "rl_unit", + "options": [], + "query": "SELECT DISTINCT unit\nFROM 
rl_params\nwhere simulation ~ '^${simulation}_[1-9]+'", + "refresh": 2, + "regex": "", + "skipUrlSync": false, + "sort": 0, + "type": "query" + }, + { + "current": { + "selected": false, + "text": "1551402000000", + "value": "1551402000000" + }, + "datasource": { + "type": "postgres", + "uid": "P7B13B9DF907EC40C" + }, + "definition": "", + "hide": 2, + "includeAll": false, + "multi": false, + "name": "timeRange", + "options": [], + "query": "SELECT MIN(datetime), MAX(datetime) FROM rl_params", + "refresh": 1, + "regex": "", + "skipUrlSync": false, + "sort": 0, + "type": "query" + }, + { + "current": { + "selected": true, + "text": [ + "example_02b_harder" + ], + "value": [ + "example_02b_harder" + ] + }, + "description": "The four different scenario and actor combinations.", + "hide": 0, + "includeAll": false, + "label": "Case Study", + "multi": true, + "name": "case_study", + "options": [ + { + "selected": false, + "text": "example_02a_harder", + "value": "example_02a_harder" + }, + { + "selected": false, + "text": "example_02a_harder_LF", + "value": "example_02a_harder_LF" + }, + { + "selected": false, + "text": "example_02a_harder_lstm", + "value": "example_02a_harder_lstm" + }, + { + "selected": false, + "text": "example_02a_harder_lstm_LF", + "value": "example_02a_harder_lstm_LF" + }, + { + "selected": true, + "text": "example_02b_harder", + "value": "example_02b_harder" + }, + { + "selected": false, + "text": "example_02b_harder_lstm", + "value": "example_02b_harder_lstm" + } + ], + "query": "example_02a_harder, example_02a_harder_LF, example_02a_harder_lstm, example_02a_harder_lstm_LF, example_02b_harder, example_02b_harder_lstm", + "queryValue": "", + "skipUrlSync": false, + "type": "custom" + }, + { + "current": { + "selected": true, + "text": [ + "exploration" + ], + "value": [ + "exploration" + ] + }, + "description": "Charts can be filtered by type of simulation (exploration, learning or evaluation episodes).", + "hide": 0, + "includeAll": false, + "label": "Simulation Type", + "multi": true, + "name": "sim_type_filter", + "options": [ + { + "selected": true, + "text": "exploration", + "value": "exploration" + }, + { + "selected": false, + "text": "learning", + "value": "learning" + }, + { + "selected": false, + "text": "evaluation", + "value": "evaluation" + } + ], + "query": "exploration, learning, evaluation", + "queryValue": "", + "skipUrlSync": false, + "type": "custom" + } + ] + }, + "time": { + "from": "2019-03-03T23:00:00.000Z", + "to": "2019-03-09T03:00:00.000Z" + }, + "timepicker": { + "refresh_intervals": [ + "5s", + "1m", + "5m", + "15m", + "30m", + "1h", + "2h" + ] + }, + "timezone": "", + "title": "ASSUME: Actor Comparison", + "uid": "bdnvfato1dlogd", + "version": 16, + "weekStart": "" +} \ No newline at end of file diff --git a/docs/source/learning.rst b/docs/source/learning.rst index 336755507..e02faa5f3 100644 --- a/docs/source/learning.rst +++ b/docs/source/learning.rst @@ -31,7 +31,7 @@ interacting in the same environment. The Markov game for :math:`N` agents consis a set of observations :math:`O_1, \ldots, O_N`, and a state transition function :math:`P: S \times A_1 \times \ldots \times A_N \rightarrow \mathcal{P}(S)` dependent on the state and actions of all agents. After taking action :math:`a_i \in A_i` in state :math:`s_i \in S` according to a policy :math:`\pi_i: O_i \rightarrow A_i`, every agent :math:`i` is transitioned into the new state :math:`s'_i \in S`. 
Each agent receives a reward :math:`r_i` according to the individual reward function :math:`R_i` and a private observation correlated with the state :math:`o_i: S \rightarrow O_i`. -Like MDP, each agent :math:`i` learns an optimal policy :math:`\pi_i^*(s)` that maximizes its expected reward. +As in a Markov Decision Process (MDP), each agent :math:`i` learns an optimal policy :math:`\pi_i^*(s)` that maximizes its expected reward. To enable multi-agent learning some adjustments are needed within the learning algorithm to get from the TD3 to an MATD3 algorithm. Other authors used similar tweaks to improve the MADDPG algorithm and derive the MA-TD3 algorithm. diff --git a/docs/source/release_notes.rst b/docs/source/release_notes.rst index f8ef30027..fd7e7c291 100644 --- a/docs/source/release_notes.rst +++ b/docs/source/release_notes.rst @@ -42,6 +42,10 @@ Upcoming Release - **Tutorial 07**: Aligned Amiris loader with changes in format in Amiris compare (https://gitlab.com/fame-framework/fame-io/-/issues/203 and https://gitlab.com/fame-framework/fame-io/-/issues/208) - **Powerplant**: Remove duplicate `Powerplant.set_dispatch_plan()` which broke multi-market bidding +**Features:** + - **PPO Integration:** The Proximal Policy Optimization (PPO) algorithm has been integrated into the framework as an additional reinforcement learning option for training agents. PPO is a widely used policy gradient method that has proven effective across a broad range of applications. A tutorial on how to use this feature is coming soon. + + v0.4.3 - (11th November 2024) =========================================== diff --git a/examples/examples.py b/examples/examples.py index 321b3fd8c..fa0ae07f4 100644 --- a/examples/examples.py +++ b/examples/examples.py @@ -65,7 +65,10 @@ # example_01g is used in the tutorial notebook #6: Advanced order types example # # DRL references case for learning advancement testing - "small_learning_1": {"scenario": "example_02a", "study_case": "base"}, + "small_learning_1": { + "scenario": "example_02a", + "study_case": "base", + }, "small_learning_2": {"scenario": "example_02b", "study_case": "base"}, "small_learning_3": {"scenario": "example_02c", "study_case": "base"}, # DRL cases with lstm instead of mlp as actor neural network architecture diff --git a/examples/inputs/example_01a/forecasts_df.csv b/examples/inputs/example_01a/forecasts_df.csv new file mode 100644 index 000000000..c2bae940e --- /dev/null +++ b/examples/inputs/example_01a/forecasts_df.csv @@ -0,0 +1,746 @@ +,fuel_price_natural gas,availability_Unit 4,availability_Unit 3,fuel_price_co2,fuel_price_oil,residual_load_EOM,fuel_price_lignite,availability_Unit 2,availability_Unit 1,fuel_price_uranium,price_EOM,fuel_price_biomass,demand_EOM,fuel_price_hard coal +2019-01-01 00:00:00,25.0,1.0,1.0,25.0,40.0,2163.3,2.0,1.0,1.0,1.0,45.05,20.0,2163.3,10.0 +2019-01-01 01:00:00,25.0,1.0,1.0,25.0,40.0,2082.7,2.0,1.0,1.0,1.0,45.05,20.0,2082.7,10.0 +2019-01-01 02:00:00,25.0,1.0,1.0,25.0,40.0,2005.7,2.0,1.0,1.0,1.0,45.05,20.0,2005.7,10.0 +2019-01-01 03:00:00,25.0,1.0,1.0,25.0,40.0,1965.6,2.0,1.0,1.0,1.0,25.65,20.0,1965.6,10.0 +2019-01-01 
04:00:00,25.0,1.0,1.0,25.0,40.0,1954.85,2.0,1.0,1.0,1.0,25.65,20.0,1954.85,10.0 +2019-01-01 05:00:00,25.0,1.0,1.0,25.0,40.0,1931.75,2.0,1.0,1.0,1.0,25.65,20.0,1931.75,10.0 +2019-01-01 06:00:00,25.0,1.0,1.0,25.0,40.0,1906.1,2.0,1.0,1.0,1.0,25.65,20.0,1906.1,10.0 +2019-01-01 07:00:00,25.0,1.0,1.0,25.0,40.0,1943.75,2.0,1.0,1.0,1.0,25.65,20.0,1943.75,10.0 +2019-01-01 08:00:00,25.0,1.0,1.0,25.0,40.0,1984.65,2.0,1.0,1.0,1.0,25.65,20.0,1984.65,10.0 +2019-01-01 09:00:00,25.0,1.0,1.0,25.0,40.0,2111.8,2.0,1.0,1.0,1.0,45.05,20.0,2111.8,10.0 +2019-01-01 10:00:00,25.0,1.0,1.0,25.0,40.0,2246.25,2.0,1.0,1.0,1.0,45.05,20.0,2246.25,10.0 +2019-01-01 11:00:00,25.0,1.0,1.0,25.0,40.0,2389.6,2.0,1.0,1.0,1.0,45.05,20.0,2389.6,10.0 +2019-01-01 12:00:00,25.0,1.0,1.0,25.0,40.0,2456.15,2.0,1.0,1.0,1.0,45.05,20.0,2456.15,10.0 +2019-01-01 13:00:00,25.0,1.0,1.0,25.0,40.0,2439.6,2.0,1.0,1.0,1.0,45.05,20.0,2439.6,10.0 +2019-01-01 14:00:00,25.0,1.0,1.0,25.0,40.0,2426.7000000000003,2.0,1.0,1.0,1.0,45.05,20.0,2426.7000000000003,10.0 +2019-01-01 15:00:00,25.0,1.0,1.0,25.0,40.0,2448.35,2.0,1.0,1.0,1.0,45.05,20.0,2448.35,10.0 +2019-01-01 16:00:00,25.0,1.0,1.0,25.0,40.0,2534.65,2.0,1.0,1.0,1.0,45.05,20.0,2534.65,10.0 +2019-01-01 17:00:00,25.0,1.0,1.0,25.0,40.0,2694.85,2.0,1.0,1.0,1.0,45.05,20.0,2694.85,10.0 +2019-01-01 18:00:00,25.0,1.0,1.0,25.0,40.0,2733.2,2.0,1.0,1.0,1.0,45.05,20.0,2733.2,10.0 +2019-01-01 19:00:00,25.0,1.0,1.0,25.0,40.0,2682.55,2.0,1.0,1.0,1.0,45.05,20.0,2682.55,10.0 +2019-01-01 20:00:00,25.0,1.0,1.0,25.0,40.0,2567.55,2.0,1.0,1.0,1.0,45.05,20.0,2567.55,10.0 +2019-01-01 21:00:00,25.0,1.0,1.0,25.0,40.0,2486.85,2.0,1.0,1.0,1.0,45.05,20.0,2486.85,10.0 +2019-01-01 22:00:00,25.0,1.0,1.0,25.0,40.0,2437.0,2.0,1.0,1.0,1.0,45.05,20.0,2437.0,10.0 +2019-01-01 23:00:00,25.0,1.0,1.0,25.0,40.0,2297.85,2.0,1.0,1.0,1.0,45.05,20.0,2297.85,10.0 +2019-01-02 00:00:00,25.0,1.0,1.0,25.0,40.0,2191.4,2.0,1.0,1.0,1.0,45.05,20.0,2191.4,10.0 +2019-01-02 01:00:00,25.0,1.0,1.0,25.0,40.0,2116.25,2.0,1.0,1.0,1.0,45.05,20.0,2116.25,10.0 +2019-01-02 02:00:00,25.0,1.0,1.0,25.0,40.0,2096.9,2.0,1.0,1.0,1.0,45.05,20.0,2096.9,10.0 +2019-01-02 03:00:00,25.0,1.0,1.0,25.0,40.0,2121.4,2.0,1.0,1.0,1.0,45.05,20.0,2121.4,10.0 +2019-01-02 04:00:00,25.0,1.0,1.0,25.0,40.0,2192.8,2.0,1.0,1.0,1.0,45.05,20.0,2192.8,10.0 +2019-01-02 05:00:00,25.0,1.0,1.0,25.0,40.0,2346.55,2.0,1.0,1.0,1.0,45.05,20.0,2346.55,10.0 +2019-01-02 06:00:00,25.0,1.0,1.0,25.0,40.0,2635.85,2.0,1.0,1.0,1.0,45.05,20.0,2635.85,10.0 +2019-01-02 07:00:00,25.0,1.0,1.0,25.0,40.0,2908.45,2.0,1.0,1.0,1.0,45.05,20.0,2908.45,10.0 +2019-01-02 08:00:00,25.0,1.0,1.0,25.0,40.0,3075.7,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3075.7,10.0 +2019-01-02 09:00:00,25.0,1.0,1.0,25.0,40.0,3171.5,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3171.5,10.0 +2019-01-02 10:00:00,25.0,1.0,1.0,25.0,40.0,3223.6,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3223.6,10.0 +2019-01-02 11:00:00,25.0,1.0,1.0,25.0,40.0,3289.55,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3289.55,10.0 +2019-01-02 12:00:00,25.0,1.0,1.0,25.0,40.0,3302.5,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3302.5,10.0 +2019-01-02 13:00:00,25.0,1.0,1.0,25.0,40.0,3265.85,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3265.85,10.0 +2019-01-02 14:00:00,25.0,1.0,1.0,25.0,40.0,3205.7,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3205.7,10.0 +2019-01-02 15:00:00,25.0,1.0,1.0,25.0,40.0,3198.45,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3198.45,10.0 +2019-01-02 16:00:00,25.0,1.0,1.0,25.0,40.0,3259.35,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3259.35,10.0 +2019-01-02 
17:00:00,25.0,1.0,1.0,25.0,40.0,3422.0,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3422.0,10.0 +2019-01-02 18:00:00,25.0,1.0,1.0,25.0,40.0,3405.75,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3405.75,10.0 +2019-01-02 19:00:00,25.0,1.0,1.0,25.0,40.0,3321.95,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3321.95,10.0 +2019-01-02 20:00:00,25.0,1.0,1.0,25.0,40.0,3135.05,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3135.05,10.0 +2019-01-02 21:00:00,25.0,1.0,1.0,25.0,40.0,2977.15,2.0,1.0,1.0,1.0,45.05,20.0,2977.15,10.0 +2019-01-02 22:00:00,25.0,1.0,1.0,25.0,40.0,2881.65,2.0,1.0,1.0,1.0,45.05,20.0,2881.65,10.0 +2019-01-02 23:00:00,25.0,1.0,1.0,25.0,40.0,2698.15,2.0,1.0,1.0,1.0,45.05,20.0,2698.15,10.0 +2019-01-03 00:00:00,25.0,1.0,1.0,25.0,40.0,2550.85,2.0,1.0,1.0,1.0,45.05,20.0,2550.85,10.0 +2019-01-03 01:00:00,25.0,1.0,1.0,25.0,40.0,2502.1,2.0,1.0,1.0,1.0,45.05,20.0,2502.1,10.0 +2019-01-03 02:00:00,25.0,1.0,1.0,25.0,40.0,2487.9,2.0,1.0,1.0,1.0,45.05,20.0,2487.9,10.0 +2019-01-03 03:00:00,25.0,1.0,1.0,25.0,40.0,2482.85,2.0,1.0,1.0,1.0,45.05,20.0,2482.85,10.0 +2019-01-03 04:00:00,25.0,1.0,1.0,25.0,40.0,2521.6,2.0,1.0,1.0,1.0,45.05,20.0,2521.6,10.0 +2019-01-03 05:00:00,25.0,1.0,1.0,25.0,40.0,2652.4,2.0,1.0,1.0,1.0,45.05,20.0,2652.4,10.0 +2019-01-03 06:00:00,25.0,1.0,1.0,25.0,40.0,2892.3,2.0,1.0,1.0,1.0,45.05,20.0,2892.3,10.0 +2019-01-03 07:00:00,25.0,1.0,1.0,25.0,40.0,3111.35,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3111.35,10.0 +2019-01-03 08:00:00,25.0,1.0,1.0,25.0,40.0,3240.1,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3240.1,10.0 +2019-01-03 09:00:00,25.0,1.0,1.0,25.0,40.0,3304.45,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3304.45,10.0 +2019-01-03 10:00:00,25.0,1.0,1.0,25.0,40.0,3333.7,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3333.7,10.0 +2019-01-03 11:00:00,25.0,1.0,1.0,25.0,40.0,3378.35,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3378.35,10.0 +2019-01-03 12:00:00,25.0,1.0,1.0,25.0,40.0,3390.2,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3390.2,10.0 +2019-01-03 13:00:00,25.0,1.0,1.0,25.0,40.0,3331.7,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3331.7,10.0 +2019-01-03 14:00:00,25.0,1.0,1.0,25.0,40.0,3278.8,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3278.8,10.0 +2019-01-03 15:00:00,25.0,1.0,1.0,25.0,40.0,3250.55,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3250.55,10.0 +2019-01-03 16:00:00,25.0,1.0,1.0,25.0,40.0,3315.5,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3315.5,10.0 +2019-01-03 17:00:00,25.0,1.0,1.0,25.0,40.0,3468.75,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3468.75,10.0 +2019-01-03 18:00:00,25.0,1.0,1.0,25.0,40.0,3452.75,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3452.75,10.0 +2019-01-03 19:00:00,25.0,1.0,1.0,25.0,40.0,3355.9,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3355.9,10.0 +2019-01-03 20:00:00,25.0,1.0,1.0,25.0,40.0,3174.25,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3174.25,10.0 +2019-01-03 21:00:00,25.0,1.0,1.0,25.0,40.0,3029.6000000000004,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3029.6000000000004,10.0 +2019-01-03 22:00:00,25.0,1.0,1.0,25.0,40.0,2911.75,2.0,1.0,1.0,1.0,45.05,20.0,2911.75,10.0 +2019-01-03 23:00:00,25.0,1.0,1.0,25.0,40.0,2747.45,2.0,1.0,1.0,1.0,45.05,20.0,2747.45,10.0 +2019-01-04 00:00:00,25.0,1.0,1.0,25.0,40.0,2593.5,2.0,1.0,1.0,1.0,45.05,20.0,2593.5,10.0 +2019-01-04 01:00:00,25.0,1.0,1.0,25.0,40.0,2501.8,2.0,1.0,1.0,1.0,45.05,20.0,2501.8,10.0 +2019-01-04 02:00:00,25.0,1.0,1.0,25.0,40.0,2476.3,2.0,1.0,1.0,1.0,45.05,20.0,2476.3,10.0 +2019-01-04 03:00:00,25.0,1.0,1.0,25.0,40.0,2497.55,2.0,1.0,1.0,1.0,45.05,20.0,2497.55,10.0 +2019-01-04 04:00:00,25.0,1.0,1.0,25.0,40.0,2552.15,2.0,1.0,1.0,1.0,45.05,20.0,2552.15,10.0 +2019-01-04 
05:00:00,25.0,1.0,1.0,25.0,40.0,2668.9,2.0,1.0,1.0,1.0,45.05,20.0,2668.9,10.0 +2019-01-04 06:00:00,25.0,1.0,1.0,25.0,40.0,2914.55,2.0,1.0,1.0,1.0,45.05,20.0,2914.55,10.0 +2019-01-04 07:00:00,25.0,1.0,1.0,25.0,40.0,3181.35,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3181.35,10.0 +2019-01-04 08:00:00,25.0,1.0,1.0,25.0,40.0,3341.15,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3341.15,10.0 +2019-01-04 09:00:00,25.0,1.0,1.0,25.0,40.0,3417.55,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3417.55,10.0 +2019-01-04 10:00:00,25.0,1.0,1.0,25.0,40.0,3453.35,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3453.35,10.0 +2019-01-04 11:00:00,25.0,1.0,1.0,25.0,40.0,3484.95,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3484.95,10.0 +2019-01-04 12:00:00,25.0,1.0,1.0,25.0,40.0,3513.4,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3513.4,10.0 +2019-01-04 13:00:00,25.0,1.0,1.0,25.0,40.0,3452.75,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3452.75,10.0 +2019-01-04 14:00:00,25.0,1.0,1.0,25.0,40.0,3362.3,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3362.3,10.0 +2019-01-04 15:00:00,25.0,1.0,1.0,25.0,40.0,3327.9,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3327.9,10.0 +2019-01-04 16:00:00,25.0,1.0,1.0,25.0,40.0,3391.75,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3391.75,10.0 +2019-01-04 17:00:00,25.0,1.0,1.0,25.0,40.0,3496.3,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3496.3,10.0 +2019-01-04 18:00:00,25.0,1.0,1.0,25.0,40.0,3464.75,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3464.75,10.0 +2019-01-04 19:00:00,25.0,1.0,1.0,25.0,40.0,3365.9,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3365.9,10.0 +2019-01-04 20:00:00,25.0,1.0,1.0,25.0,40.0,3167.05,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3167.05,10.0 +2019-01-04 21:00:00,25.0,1.0,1.0,25.0,40.0,3016.75,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3016.75,10.0 +2019-01-04 22:00:00,25.0,1.0,1.0,25.0,40.0,2924.35,2.0,1.0,1.0,1.0,45.05,20.0,2924.35,10.0 +2019-01-04 23:00:00,25.0,1.0,1.0,25.0,40.0,2724.7,2.0,1.0,1.0,1.0,45.05,20.0,2724.7,10.0 +2019-01-05 00:00:00,25.0,1.0,1.0,25.0,40.0,2582.8,2.0,1.0,1.0,1.0,45.05,20.0,2582.8,10.0 +2019-01-05 01:00:00,25.0,1.0,1.0,25.0,40.0,2465.75,2.0,1.0,1.0,1.0,45.05,20.0,2465.75,10.0 +2019-01-05 02:00:00,25.0,1.0,1.0,25.0,40.0,2401.5,2.0,1.0,1.0,1.0,45.05,20.0,2401.5,10.0 +2019-01-05 03:00:00,25.0,1.0,1.0,25.0,40.0,2377.75,2.0,1.0,1.0,1.0,45.05,20.0,2377.75,10.0 +2019-01-05 04:00:00,25.0,1.0,1.0,25.0,40.0,2381.7,2.0,1.0,1.0,1.0,45.05,20.0,2381.7,10.0 +2019-01-05 05:00:00,25.0,1.0,1.0,25.0,40.0,2385.1,2.0,1.0,1.0,1.0,45.05,20.0,2385.1,10.0 +2019-01-05 06:00:00,25.0,1.0,1.0,25.0,40.0,2416.5,2.0,1.0,1.0,1.0,45.05,20.0,2416.5,10.0 +2019-01-05 07:00:00,25.0,1.0,1.0,25.0,40.0,2538.05,2.0,1.0,1.0,1.0,45.05,20.0,2538.05,10.0 +2019-01-05 08:00:00,25.0,1.0,1.0,25.0,40.0,2715.3,2.0,1.0,1.0,1.0,45.05,20.0,2715.3,10.0 +2019-01-05 09:00:00,25.0,1.0,1.0,25.0,40.0,2879.6,2.0,1.0,1.0,1.0,45.05,20.0,2879.6,10.0 +2019-01-05 10:00:00,25.0,1.0,1.0,25.0,40.0,2985.6,2.0,1.0,1.0,1.0,45.05,20.0,2985.6,10.0 +2019-01-05 11:00:00,25.0,1.0,1.0,25.0,40.0,3046.05,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3046.05,10.0 +2019-01-05 12:00:00,25.0,1.0,1.0,25.0,40.0,3039.6,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3039.6,10.0 +2019-01-05 13:00:00,25.0,1.0,1.0,25.0,40.0,2975.9,2.0,1.0,1.0,1.0,45.05,20.0,2975.9,10.0 +2019-01-05 14:00:00,25.0,1.0,1.0,25.0,40.0,2923.65,2.0,1.0,1.0,1.0,45.05,20.0,2923.65,10.0 +2019-01-05 15:00:00,25.0,1.0,1.0,25.0,40.0,2901.85,2.0,1.0,1.0,1.0,45.05,20.0,2901.85,10.0 +2019-01-05 16:00:00,25.0,1.0,1.0,25.0,40.0,2952.7,2.0,1.0,1.0,1.0,45.05,20.0,2952.7,10.0 +2019-01-05 
17:00:00,25.0,1.0,1.0,25.0,40.0,3094.45,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3094.45,10.0 +2019-01-05 18:00:00,25.0,1.0,1.0,25.0,40.0,3096.85,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3096.85,10.0 +2019-01-05 19:00:00,25.0,1.0,1.0,25.0,40.0,3005.0,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3005.0,10.0 +2019-01-05 20:00:00,25.0,1.0,1.0,25.0,40.0,2814.35,2.0,1.0,1.0,1.0,45.05,20.0,2814.35,10.0 +2019-01-05 21:00:00,25.0,1.0,1.0,25.0,40.0,2696.8,2.0,1.0,1.0,1.0,45.05,20.0,2696.8,10.0 +2019-01-05 22:00:00,25.0,1.0,1.0,25.0,40.0,2644.85,2.0,1.0,1.0,1.0,45.05,20.0,2644.85,10.0 +2019-01-05 23:00:00,25.0,1.0,1.0,25.0,40.0,2496.0,2.0,1.0,1.0,1.0,45.05,20.0,2496.0,10.0 +2019-01-06 00:00:00,25.0,1.0,1.0,25.0,40.0,2330.15,2.0,1.0,1.0,1.0,45.05,20.0,2330.15,10.0 +2019-01-06 01:00:00,25.0,1.0,1.0,25.0,40.0,2227.3,2.0,1.0,1.0,1.0,45.05,20.0,2227.3,10.0 +2019-01-06 02:00:00,25.0,1.0,1.0,25.0,40.0,2160.5,2.0,1.0,1.0,1.0,45.05,20.0,2160.5,10.0 +2019-01-06 03:00:00,25.0,1.0,1.0,25.0,40.0,2152.2,2.0,1.0,1.0,1.0,45.05,20.0,2152.2,10.0 +2019-01-06 04:00:00,25.0,1.0,1.0,25.0,40.0,2154.3,2.0,1.0,1.0,1.0,45.05,20.0,2154.3,10.0 +2019-01-06 05:00:00,25.0,1.0,1.0,25.0,40.0,2132.85,2.0,1.0,1.0,1.0,45.05,20.0,2132.85,10.0 +2019-01-06 06:00:00,25.0,1.0,1.0,25.0,40.0,2112.85,2.0,1.0,1.0,1.0,45.05,20.0,2112.85,10.0 +2019-01-06 07:00:00,25.0,1.0,1.0,25.0,40.0,2167.55,2.0,1.0,1.0,1.0,45.05,20.0,2167.55,10.0 +2019-01-06 08:00:00,25.0,1.0,1.0,25.0,40.0,2297.35,2.0,1.0,1.0,1.0,45.05,20.0,2297.35,10.0 +2019-01-06 09:00:00,25.0,1.0,1.0,25.0,40.0,2472.95,2.0,1.0,1.0,1.0,45.05,20.0,2472.95,10.0 +2019-01-06 10:00:00,25.0,1.0,1.0,25.0,40.0,2622.1,2.0,1.0,1.0,1.0,45.05,20.0,2622.1,10.0 +2019-01-06 11:00:00,25.0,1.0,1.0,25.0,40.0,2750.0,2.0,1.0,1.0,1.0,45.05,20.0,2750.0,10.0 +2019-01-06 12:00:00,25.0,1.0,1.0,25.0,40.0,2773.15,2.0,1.0,1.0,1.0,45.05,20.0,2773.15,10.0 +2019-01-06 13:00:00,25.0,1.0,1.0,25.0,40.0,2704.3,2.0,1.0,1.0,1.0,45.05,20.0,2704.3,10.0 +2019-01-06 14:00:00,25.0,1.0,1.0,25.0,40.0,2673.75,2.0,1.0,1.0,1.0,45.05,20.0,2673.75,10.0 +2019-01-06 15:00:00,25.0,1.0,1.0,25.0,40.0,2635.45,2.0,1.0,1.0,1.0,45.05,20.0,2635.45,10.0 +2019-01-06 16:00:00,25.0,1.0,1.0,25.0,40.0,2722.9,2.0,1.0,1.0,1.0,45.05,20.0,2722.9,10.0 +2019-01-06 17:00:00,25.0,1.0,1.0,25.0,40.0,2898.1,2.0,1.0,1.0,1.0,45.05,20.0,2898.1,10.0 +2019-01-06 18:00:00,25.0,1.0,1.0,25.0,40.0,2943.5,2.0,1.0,1.0,1.0,45.05,20.0,2943.5,10.0 +2019-01-06 19:00:00,25.0,1.0,1.0,25.0,40.0,2864.6,2.0,1.0,1.0,1.0,45.05,20.0,2864.6,10.0 +2019-01-06 20:00:00,25.0,1.0,1.0,25.0,40.0,2747.65,2.0,1.0,1.0,1.0,45.05,20.0,2747.65,10.0 +2019-01-06 21:00:00,25.0,1.0,1.0,25.0,40.0,2670.0,2.0,1.0,1.0,1.0,45.05,20.0,2670.0,10.0 +2019-01-06 22:00:00,25.0,1.0,1.0,25.0,40.0,2636.2,2.0,1.0,1.0,1.0,45.05,20.0,2636.2,10.0 +2019-01-06 23:00:00,25.0,1.0,1.0,25.0,40.0,2501.95,2.0,1.0,1.0,1.0,45.05,20.0,2501.95,10.0 +2019-01-07 00:00:00,25.0,1.0,1.0,25.0,40.0,2353.4,2.0,1.0,1.0,1.0,45.05,20.0,2353.4,10.0 +2019-01-07 01:00:00,25.0,1.0,1.0,25.0,40.0,2265.65,2.0,1.0,1.0,1.0,45.05,20.0,2265.65,10.0 +2019-01-07 02:00:00,25.0,1.0,1.0,25.0,40.0,2241.7,2.0,1.0,1.0,1.0,45.05,20.0,2241.7,10.0 +2019-01-07 03:00:00,25.0,1.0,1.0,25.0,40.0,2275.5,2.0,1.0,1.0,1.0,45.05,20.0,2275.5,10.0 +2019-01-07 04:00:00,25.0,1.0,1.0,25.0,40.0,2367.8,2.0,1.0,1.0,1.0,45.05,20.0,2367.8,10.0 +2019-01-07 05:00:00,25.0,1.0,1.0,25.0,40.0,2585.05,2.0,1.0,1.0,1.0,45.05,20.0,2585.05,10.0 +2019-01-07 06:00:00,25.0,1.0,1.0,25.0,40.0,3015.7,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3015.7,10.0 +2019-01-07 
07:00:00,25.0,1.0,1.0,25.0,40.0,3331.45,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3331.45,10.0 +2019-01-07 08:00:00,25.0,1.0,1.0,25.0,40.0,3464.15,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3464.15,10.0 +2019-01-07 09:00:00,25.0,1.0,1.0,25.0,40.0,3478.2,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3478.2,10.0 +2019-01-07 10:00:00,25.0,1.0,1.0,25.0,40.0,3542.6,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3542.6,10.0 +2019-01-07 11:00:00,25.0,1.0,1.0,25.0,40.0,3589.3,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3589.3,10.0 +2019-01-07 12:00:00,25.0,1.0,1.0,25.0,40.0,3585.75,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3585.75,10.0 +2019-01-07 13:00:00,25.0,1.0,1.0,25.0,40.0,3587.6,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3587.6,10.0 +2019-01-07 14:00:00,25.0,1.0,1.0,25.0,40.0,3570.95,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3570.95,10.0 +2019-01-07 15:00:00,25.0,1.0,1.0,25.0,40.0,3572.1,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3572.1,10.0 +2019-01-07 16:00:00,25.0,1.0,1.0,25.0,40.0,3625.75,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3625.75,10.0 +2019-01-07 17:00:00,25.0,1.0,1.0,25.0,40.0,3734.0,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3734.0,10.0 +2019-01-07 18:00:00,25.0,1.0,1.0,25.0,40.0,3691.5,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3691.5,10.0 +2019-01-07 19:00:00,25.0,1.0,1.0,25.0,40.0,3595.65,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3595.65,10.0 +2019-01-07 20:00:00,25.0,1.0,1.0,25.0,40.0,3428.9,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3428.9,10.0 +2019-01-07 21:00:00,25.0,1.0,1.0,25.0,40.0,3276.45,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3276.45,10.0 +2019-01-07 22:00:00,25.0,1.0,1.0,25.0,40.0,3114.75,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3114.75,10.0 +2019-01-07 23:00:00,25.0,1.0,1.0,25.0,40.0,2918.1,2.0,1.0,1.0,1.0,45.05,20.0,2918.1,10.0 +2019-01-08 00:00:00,25.0,1.0,1.0,25.0,40.0,2726.35,2.0,1.0,1.0,1.0,45.05,20.0,2726.35,10.0 +2019-01-08 01:00:00,25.0,1.0,1.0,25.0,40.0,2637.75,2.0,1.0,1.0,1.0,45.05,20.0,2637.75,10.0 +2019-01-08 02:00:00,25.0,1.0,1.0,25.0,40.0,2594.5,2.0,1.0,1.0,1.0,45.05,20.0,2594.5,10.0 +2019-01-08 03:00:00,25.0,1.0,1.0,25.0,40.0,2623.85,2.0,1.0,1.0,1.0,45.05,20.0,2623.85,10.0 +2019-01-08 04:00:00,25.0,1.0,1.0,25.0,40.0,2704.8500000000004,2.0,1.0,1.0,1.0,45.05,20.0,2704.8500000000004,10.0 +2019-01-08 05:00:00,25.0,1.0,1.0,25.0,40.0,2874.4,2.0,1.0,1.0,1.0,45.05,20.0,2874.4,10.0 +2019-01-08 06:00:00,25.0,1.0,1.0,25.0,40.0,3236.3,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3236.3,10.0 +2019-01-08 07:00:00,25.0,1.0,1.0,25.0,40.0,3550.7,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3550.7,10.0 +2019-01-08 08:00:00,25.0,1.0,1.0,25.0,40.0,3672.35,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3672.35,10.0 +2019-01-08 09:00:00,25.0,1.0,1.0,25.0,40.0,3677.75,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3677.75,10.0 +2019-01-08 10:00:00,25.0,1.0,1.0,25.0,40.0,3709.45,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3709.45,10.0 +2019-01-08 11:00:00,25.0,1.0,1.0,25.0,40.0,3739.6,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3739.6,10.0 +2019-01-08 12:00:00,25.0,1.0,1.0,25.0,40.0,3727.7,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3727.7,10.0 +2019-01-08 13:00:00,25.0,1.0,1.0,25.0,40.0,3717.7,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3717.7,10.0 +2019-01-08 14:00:00,25.0,1.0,1.0,25.0,40.0,3678.4,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3678.4,10.0 +2019-01-08 15:00:00,25.0,1.0,1.0,25.0,40.0,3658.05,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3658.05,10.0 +2019-01-08 16:00:00,25.0,1.0,1.0,25.0,40.0,3695.4,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3695.4,10.0 +2019-01-08 17:00:00,25.0,1.0,1.0,25.0,40.0,3790.15,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3790.15,10.0 +2019-01-08 
18:00:00,25.0,1.0,1.0,25.0,40.0,3763.35,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3763.35,10.0 +2019-01-08 19:00:00,25.0,1.0,1.0,25.0,40.0,3674.5,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3674.5,10.0 +2019-01-08 20:00:00,25.0,1.0,1.0,25.0,40.0,3479.45,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3479.45,10.0 +2019-01-08 21:00:00,25.0,1.0,1.0,25.0,40.0,3301.55,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3301.55,10.0 +2019-01-08 22:00:00,25.0,1.0,1.0,25.0,40.0,3147.05,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3147.05,10.0 +2019-01-08 23:00:00,25.0,1.0,1.0,25.0,40.0,2963.15,2.0,1.0,1.0,1.0,45.05,20.0,2963.15,10.0 +2019-01-09 00:00:00,25.0,1.0,1.0,25.0,40.0,2777.55,2.0,1.0,1.0,1.0,45.05,20.0,2777.55,10.0 +2019-01-09 01:00:00,25.0,1.0,1.0,25.0,40.0,2685.7,2.0,1.0,1.0,1.0,45.05,20.0,2685.7,10.0 +2019-01-09 02:00:00,25.0,1.0,1.0,25.0,40.0,2638.6,2.0,1.0,1.0,1.0,45.05,20.0,2638.6,10.0 +2019-01-09 03:00:00,25.0,1.0,1.0,25.0,40.0,2661.5,2.0,1.0,1.0,1.0,45.05,20.0,2661.5,10.0 +2019-01-09 04:00:00,25.0,1.0,1.0,25.0,40.0,2738.35,2.0,1.0,1.0,1.0,45.05,20.0,2738.35,10.0 +2019-01-09 05:00:00,25.0,1.0,1.0,25.0,40.0,2890.0,2.0,1.0,1.0,1.0,45.05,20.0,2890.0,10.0 +2019-01-09 06:00:00,25.0,1.0,1.0,25.0,40.0,3246.9,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3246.9,10.0 +2019-01-09 07:00:00,25.0,1.0,1.0,25.0,40.0,3569.55,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3569.55,10.0 +2019-01-09 08:00:00,25.0,1.0,1.0,25.0,40.0,3703.15,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3703.15,10.0 +2019-01-09 09:00:00,25.0,1.0,1.0,25.0,40.0,3730.45,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3730.45,10.0 +2019-01-09 10:00:00,25.0,1.0,1.0,25.0,40.0,3772.75,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3772.75,10.0 +2019-01-09 11:00:00,25.0,1.0,1.0,25.0,40.0,3816.55,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3816.55,10.0 +2019-01-09 12:00:00,25.0,1.0,1.0,25.0,40.0,3792.45,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3792.45,10.0 +2019-01-09 13:00:00,25.0,1.0,1.0,25.0,40.0,3776.8,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3776.8,10.0 +2019-01-09 14:00:00,25.0,1.0,1.0,25.0,40.0,3742.7,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3742.7,10.0 +2019-01-09 15:00:00,25.0,1.0,1.0,25.0,40.0,3707.4,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3707.4,10.0 +2019-01-09 16:00:00,25.0,1.0,1.0,25.0,40.0,3728.8,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3728.8,10.0 +2019-01-09 17:00:00,25.0,1.0,1.0,25.0,40.0,3845.15,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3845.15,10.0 +2019-01-09 18:00:00,25.0,1.0,1.0,25.0,40.0,3804.35,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3804.35,10.0 +2019-01-09 19:00:00,25.0,1.0,1.0,25.0,40.0,3714.85,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3714.85,10.0 +2019-01-09 20:00:00,25.0,1.0,1.0,25.0,40.0,3503.05,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3503.05,10.0 +2019-01-09 21:00:00,25.0,1.0,1.0,25.0,40.0,3325.15,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3325.15,10.0 +2019-01-09 22:00:00,25.0,1.0,1.0,25.0,40.0,3174.7,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3174.7,10.0 +2019-01-09 23:00:00,25.0,1.0,1.0,25.0,40.0,2947.75,2.0,1.0,1.0,1.0,45.05,20.0,2947.75,10.0 +2019-01-10 00:00:00,25.0,1.0,1.0,25.0,40.0,2816.15,2.0,1.0,1.0,1.0,45.05,20.0,2816.15,10.0 +2019-01-10 01:00:00,25.0,1.0,1.0,25.0,40.0,2697.95,2.0,1.0,1.0,1.0,45.05,20.0,2697.95,10.0 +2019-01-10 02:00:00,25.0,1.0,1.0,25.0,40.0,2649.9,2.0,1.0,1.0,1.0,45.05,20.0,2649.9,10.0 +2019-01-10 03:00:00,25.0,1.0,1.0,25.0,40.0,2657.15,2.0,1.0,1.0,1.0,45.05,20.0,2657.15,10.0 +2019-01-10 04:00:00,25.0,1.0,1.0,25.0,40.0,2692.25,2.0,1.0,1.0,1.0,45.05,20.0,2692.25,10.0 +2019-01-10 05:00:00,25.0,1.0,1.0,25.0,40.0,2830.65,2.0,1.0,1.0,1.0,45.05,20.0,2830.65,10.0 +2019-01-10 
06:00:00,25.0,1.0,1.0,25.0,40.0,3177.8,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3177.8,10.0 +2019-01-10 07:00:00,25.0,1.0,1.0,25.0,40.0,3467.45,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3467.45,10.0 +2019-01-10 08:00:00,25.0,1.0,1.0,25.0,40.0,3590.6,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3590.6,10.0 +2019-01-10 09:00:00,25.0,1.0,1.0,25.0,40.0,3593.15,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3593.15,10.0 +2019-01-10 10:00:00,25.0,1.0,1.0,25.0,40.0,3636.4,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3636.4,10.0 +2019-01-10 11:00:00,25.0,1.0,1.0,25.0,40.0,3665.65,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3665.65,10.0 +2019-01-10 12:00:00,25.0,1.0,1.0,25.0,40.0,3652.6,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3652.6,10.0 +2019-01-10 13:00:00,25.0,1.0,1.0,25.0,40.0,3625.95,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3625.95,10.0 +2019-01-10 14:00:00,25.0,1.0,1.0,25.0,40.0,3573.15,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3573.15,10.0 +2019-01-10 15:00:00,25.0,1.0,1.0,25.0,40.0,3550.55,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3550.55,10.0 +2019-01-10 16:00:00,25.0,1.0,1.0,25.0,40.0,3570.75,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3570.75,10.0 +2019-01-10 17:00:00,25.0,1.0,1.0,25.0,40.0,3688.05,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3688.05,10.0 +2019-01-10 18:00:00,25.0,1.0,1.0,25.0,40.0,3654.5,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3654.5,10.0 +2019-01-10 19:00:00,25.0,1.0,1.0,25.0,40.0,3570.0,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3570.0,10.0 +2019-01-10 20:00:00,25.0,1.0,1.0,25.0,40.0,3393.3,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3393.3,10.0 +2019-01-10 21:00:00,25.0,1.0,1.0,25.0,40.0,3210.3,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3210.3,10.0 +2019-01-10 22:00:00,25.0,1.0,1.0,25.0,40.0,3107.25,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3107.25,10.0 +2019-01-10 23:00:00,25.0,1.0,1.0,25.0,40.0,2918.55,2.0,1.0,1.0,1.0,45.05,20.0,2918.55,10.0 +2019-01-11 00:00:00,25.0,1.0,1.0,25.0,40.0,2809.6000000000004,2.0,1.0,1.0,1.0,45.05,20.0,2809.6000000000004,10.0 +2019-01-11 01:00:00,25.0,1.0,1.0,25.0,40.0,2705.1,2.0,1.0,1.0,1.0,45.05,20.0,2705.1,10.0 +2019-01-11 02:00:00,25.0,1.0,1.0,25.0,40.0,2679.4,2.0,1.0,1.0,1.0,45.05,20.0,2679.4,10.0 +2019-01-11 03:00:00,25.0,1.0,1.0,25.0,40.0,2702.0,2.0,1.0,1.0,1.0,45.05,20.0,2702.0,10.0 +2019-01-11 04:00:00,25.0,1.0,1.0,25.0,40.0,2770.4,2.0,1.0,1.0,1.0,45.05,20.0,2770.4,10.0 +2019-01-11 05:00:00,25.0,1.0,1.0,25.0,40.0,2914.55,2.0,1.0,1.0,1.0,45.05,20.0,2914.55,10.0 +2019-01-11 06:00:00,25.0,1.0,1.0,25.0,40.0,3293.9,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3293.9,10.0 +2019-01-11 07:00:00,25.0,1.0,1.0,25.0,40.0,3599.05,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3599.05,10.0 +2019-01-11 08:00:00,25.0,1.0,1.0,25.0,40.0,3736.35,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3736.35,10.0 +2019-01-11 09:00:00,25.0,1.0,1.0,25.0,40.0,3759.75,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3759.75,10.0 +2019-01-11 10:00:00,25.0,1.0,1.0,25.0,40.0,3792.85,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3792.85,10.0 +2019-01-11 11:00:00,25.0,1.0,1.0,25.0,40.0,3831.55,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3831.55,10.0 +2019-01-11 12:00:00,25.0,1.0,1.0,25.0,40.0,3803.0,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3803.0,10.0 +2019-01-11 13:00:00,25.0,1.0,1.0,25.0,40.0,3736.6,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3736.6,10.0 +2019-01-11 14:00:00,25.0,1.0,1.0,25.0,40.0,3648.95,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3648.95,10.0 +2019-01-11 15:00:00,25.0,1.0,1.0,25.0,40.0,3600.85,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3600.85,10.0 +2019-01-11 16:00:00,25.0,1.0,1.0,25.0,40.0,3618.25,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3618.25,10.0 +2019-01-11 
17:00:00,25.0,1.0,1.0,25.0,40.0,3718.0,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3718.0,10.0 +2019-01-11 18:00:00,25.0,1.0,1.0,25.0,40.0,3661.0,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3661.0,10.0 +2019-01-11 19:00:00,25.0,1.0,1.0,25.0,40.0,3571.75,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3571.75,10.0 +2019-01-11 20:00:00,25.0,1.0,1.0,25.0,40.0,3353.9,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3353.9,10.0 +2019-01-11 21:00:00,25.0,1.0,1.0,25.0,40.0,3170.4,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3170.4,10.0 +2019-01-11 22:00:00,25.0,1.0,1.0,25.0,40.0,3049.85,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3049.85,10.0 +2019-01-11 23:00:00,25.0,1.0,1.0,25.0,40.0,2862.8,2.0,1.0,1.0,1.0,45.05,20.0,2862.8,10.0 +2019-01-12 00:00:00,25.0,1.0,1.0,25.0,40.0,2657.1,2.0,1.0,1.0,1.0,45.05,20.0,2657.1,10.0 +2019-01-12 01:00:00,25.0,1.0,1.0,25.0,40.0,2533.2,2.0,1.0,1.0,1.0,45.05,20.0,2533.2,10.0 +2019-01-12 02:00:00,25.0,1.0,1.0,25.0,40.0,2446.6,2.0,1.0,1.0,1.0,45.05,20.0,2446.6,10.0 +2019-01-12 03:00:00,25.0,1.0,1.0,25.0,40.0,2411.35,2.0,1.0,1.0,1.0,45.05,20.0,2411.35,10.0 +2019-01-12 04:00:00,25.0,1.0,1.0,25.0,40.0,2429.65,2.0,1.0,1.0,1.0,45.05,20.0,2429.65,10.0 +2019-01-12 05:00:00,25.0,1.0,1.0,25.0,40.0,2448.15,2.0,1.0,1.0,1.0,45.05,20.0,2448.15,10.0 +2019-01-12 06:00:00,25.0,1.0,1.0,25.0,40.0,2522.15,2.0,1.0,1.0,1.0,45.05,20.0,2522.15,10.0 +2019-01-12 07:00:00,25.0,1.0,1.0,25.0,40.0,2680.95,2.0,1.0,1.0,1.0,45.05,20.0,2680.95,10.0 +2019-01-12 08:00:00,25.0,1.0,1.0,25.0,40.0,2858.05,2.0,1.0,1.0,1.0,45.05,20.0,2858.05,10.0 +2019-01-12 09:00:00,25.0,1.0,1.0,25.0,40.0,3033.45,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3033.45,10.0 +2019-01-12 10:00:00,25.0,1.0,1.0,25.0,40.0,3158.5,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3158.5,10.0 +2019-01-12 11:00:00,25.0,1.0,1.0,25.0,40.0,3229.15,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3229.15,10.0 +2019-01-12 12:00:00,25.0,1.0,1.0,25.0,40.0,3208.6,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3208.6,10.0 +2019-01-12 13:00:00,25.0,1.0,1.0,25.0,40.0,3128.2,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3128.2,10.0 +2019-01-12 14:00:00,25.0,1.0,1.0,25.0,40.0,3074.55,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3074.55,10.0 +2019-01-12 15:00:00,25.0,1.0,1.0,25.0,40.0,3049.8,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3049.8,10.0 +2019-01-12 16:00:00,25.0,1.0,1.0,25.0,40.0,3102.55,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3102.55,10.0 +2019-01-12 17:00:00,25.0,1.0,1.0,25.0,40.0,3228.65,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3228.65,10.0 +2019-01-12 18:00:00,25.0,1.0,1.0,25.0,40.0,3214.4,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3214.4,10.0 +2019-01-12 19:00:00,25.0,1.0,1.0,25.0,40.0,3109.9,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3109.9,10.0 +2019-01-12 20:00:00,25.0,1.0,1.0,25.0,40.0,2941.1000000000004,2.0,1.0,1.0,1.0,45.05,20.0,2941.1000000000004,10.0 +2019-01-12 21:00:00,25.0,1.0,1.0,25.0,40.0,2798.75,2.0,1.0,1.0,1.0,45.05,20.0,2798.75,10.0 +2019-01-12 22:00:00,25.0,1.0,1.0,25.0,40.0,2737.7,2.0,1.0,1.0,1.0,45.05,20.0,2737.7,10.0 +2019-01-12 23:00:00,25.0,1.0,1.0,25.0,40.0,2606.9,2.0,1.0,1.0,1.0,45.05,20.0,2606.9,10.0 +2019-01-13 00:00:00,25.0,1.0,1.0,25.0,40.0,2431.55,2.0,1.0,1.0,1.0,45.05,20.0,2431.55,10.0 +2019-01-13 01:00:00,25.0,1.0,1.0,25.0,40.0,2323.05,2.0,1.0,1.0,1.0,45.05,20.0,2323.05,10.0 +2019-01-13 02:00:00,25.0,1.0,1.0,25.0,40.0,2257.5,2.0,1.0,1.0,1.0,45.05,20.0,2257.5,10.0 +2019-01-13 03:00:00,25.0,1.0,1.0,25.0,40.0,2218.85,2.0,1.0,1.0,1.0,45.05,20.0,2218.85,10.0 +2019-01-13 04:00:00,25.0,1.0,1.0,25.0,40.0,2211.95,2.0,1.0,1.0,1.0,45.05,20.0,2211.95,10.0 +2019-01-13 
05:00:00,25.0,1.0,1.0,25.0,40.0,2195.75,2.0,1.0,1.0,1.0,45.05,20.0,2195.75,10.0 +2019-01-13 06:00:00,25.0,1.0,1.0,25.0,40.0,2170.25,2.0,1.0,1.0,1.0,45.05,20.0,2170.25,10.0 +2019-01-13 07:00:00,25.0,1.0,1.0,25.0,40.0,2276.3,2.0,1.0,1.0,1.0,45.05,20.0,2276.3,10.0 +2019-01-13 08:00:00,25.0,1.0,1.0,25.0,40.0,2461.55,2.0,1.0,1.0,1.0,45.05,20.0,2461.55,10.0 +2019-01-13 09:00:00,25.0,1.0,1.0,25.0,40.0,2670.55,2.0,1.0,1.0,1.0,45.05,20.0,2670.55,10.0 +2019-01-13 10:00:00,25.0,1.0,1.0,25.0,40.0,2829.45,2.0,1.0,1.0,1.0,45.05,20.0,2829.45,10.0 +2019-01-13 11:00:00,25.0,1.0,1.0,25.0,40.0,2985.7,2.0,1.0,1.0,1.0,45.05,20.0,2985.7,10.0 +2019-01-13 12:00:00,25.0,1.0,1.0,25.0,40.0,2992.1,2.0,1.0,1.0,1.0,45.05,20.0,2992.1,10.0 +2019-01-13 13:00:00,25.0,1.0,1.0,25.0,40.0,2940.1,2.0,1.0,1.0,1.0,45.05,20.0,2940.1,10.0 +2019-01-13 14:00:00,25.0,1.0,1.0,25.0,40.0,2895.55,2.0,1.0,1.0,1.0,45.05,20.0,2895.55,10.0 +2019-01-13 15:00:00,25.0,1.0,1.0,25.0,40.0,2874.55,2.0,1.0,1.0,1.0,45.05,20.0,2874.55,10.0 +2019-01-13 16:00:00,25.0,1.0,1.0,25.0,40.0,2922.4,2.0,1.0,1.0,1.0,45.05,20.0,2922.4,10.0 +2019-01-13 17:00:00,25.0,1.0,1.0,25.0,40.0,3061.95,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3061.95,10.0 +2019-01-13 18:00:00,25.0,1.0,1.0,25.0,40.0,3105.45,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3105.45,10.0 +2019-01-13 19:00:00,25.0,1.0,1.0,25.0,40.0,3032.35,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3032.35,10.0 +2019-01-13 20:00:00,25.0,1.0,1.0,25.0,40.0,2892.75,2.0,1.0,1.0,1.0,45.05,20.0,2892.75,10.0 +2019-01-13 21:00:00,25.0,1.0,1.0,25.0,40.0,2819.0,2.0,1.0,1.0,1.0,45.05,20.0,2819.0,10.0 +2019-01-13 22:00:00,25.0,1.0,1.0,25.0,40.0,2803.7,2.0,1.0,1.0,1.0,45.05,20.0,2803.7,10.0 +2019-01-13 23:00:00,25.0,1.0,1.0,25.0,40.0,2673.3,2.0,1.0,1.0,1.0,45.05,20.0,2673.3,10.0 +2019-01-14 00:00:00,25.0,1.0,1.0,25.0,40.0,2542.55,2.0,1.0,1.0,1.0,45.05,20.0,2542.55,10.0 +2019-01-14 01:00:00,25.0,1.0,1.0,25.0,40.0,2435.3500000000004,2.0,1.0,1.0,1.0,45.05,20.0,2435.3500000000004,10.0 +2019-01-14 02:00:00,25.0,1.0,1.0,25.0,40.0,2402.75,2.0,1.0,1.0,1.0,45.05,20.0,2402.75,10.0 +2019-01-14 03:00:00,25.0,1.0,1.0,25.0,40.0,2429.2,2.0,1.0,1.0,1.0,45.05,20.0,2429.2,10.0 +2019-01-14 04:00:00,25.0,1.0,1.0,25.0,40.0,2514.75,2.0,1.0,1.0,1.0,45.05,20.0,2514.75,10.0 +2019-01-14 05:00:00,25.0,1.0,1.0,25.0,40.0,2734.55,2.0,1.0,1.0,1.0,45.05,20.0,2734.55,10.0 +2019-01-14 06:00:00,25.0,1.0,1.0,25.0,40.0,3157.15,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3157.15,10.0 +2019-01-14 07:00:00,25.0,1.0,1.0,25.0,40.0,3498.55,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3498.55,10.0 +2019-01-14 08:00:00,25.0,1.0,1.0,25.0,40.0,3627.4,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3627.4,10.0 +2019-01-14 09:00:00,25.0,1.0,1.0,25.0,40.0,3634.6,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3634.6,10.0 +2019-01-14 10:00:00,25.0,1.0,1.0,25.0,40.0,3651.8,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3651.8,10.0 +2019-01-14 11:00:00,25.0,1.0,1.0,25.0,40.0,3692.35,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3692.35,10.0 +2019-01-14 12:00:00,25.0,1.0,1.0,25.0,40.0,3683.0,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3683.0,10.0 +2019-01-14 13:00:00,25.0,1.0,1.0,25.0,40.0,3682.7,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3682.7,10.0 +2019-01-14 14:00:00,25.0,1.0,1.0,25.0,40.0,3651.45,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3651.45,10.0 +2019-01-14 15:00:00,25.0,1.0,1.0,25.0,40.0,3604.8,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3604.8,10.0 +2019-01-14 16:00:00,25.0,1.0,1.0,25.0,40.0,3582.85,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3582.85,10.0 +2019-01-14 
17:00:00,25.0,1.0,1.0,25.0,40.0,3740.6,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3740.6,10.0 +2019-01-14 18:00:00,25.0,1.0,1.0,25.0,40.0,3718.95,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3718.95,10.0 +2019-01-14 19:00:00,25.0,1.0,1.0,25.0,40.0,3627.7,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3627.7,10.0 +2019-01-14 20:00:00,25.0,1.0,1.0,25.0,40.0,3428.0,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3428.0,10.0 +2019-01-14 21:00:00,25.0,1.0,1.0,25.0,40.0,3256.45,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3256.45,10.0 +2019-01-14 22:00:00,25.0,1.0,1.0,25.0,40.0,3104.75,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3104.75,10.0 +2019-01-14 23:00:00,25.0,1.0,1.0,25.0,40.0,2898.65,2.0,1.0,1.0,1.0,45.05,20.0,2898.65,10.0 +2019-01-15 00:00:00,25.0,1.0,1.0,25.0,40.0,2764.5,2.0,1.0,1.0,1.0,45.05,20.0,2764.5,10.0 +2019-01-15 01:00:00,25.0,1.0,1.0,25.0,40.0,2677.45,2.0,1.0,1.0,1.0,45.05,20.0,2677.45,10.0 +2019-01-15 02:00:00,25.0,1.0,1.0,25.0,40.0,2648.0,2.0,1.0,1.0,1.0,45.05,20.0,2648.0,10.0 +2019-01-15 03:00:00,25.0,1.0,1.0,25.0,40.0,2658.05,2.0,1.0,1.0,1.0,45.05,20.0,2658.05,10.0 +2019-01-15 04:00:00,25.0,1.0,1.0,25.0,40.0,2712.55,2.0,1.0,1.0,1.0,45.05,20.0,2712.55,10.0 +2019-01-15 05:00:00,25.0,1.0,1.0,25.0,40.0,2885.3,2.0,1.0,1.0,1.0,45.05,20.0,2885.3,10.0 +2019-01-15 06:00:00,25.0,1.0,1.0,25.0,40.0,3224.55,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3224.55,10.0 +2019-01-15 07:00:00,25.0,1.0,1.0,25.0,40.0,3553.8,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3553.8,10.0 +2019-01-15 08:00:00,25.0,1.0,1.0,25.0,40.0,3676.35,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3676.35,10.0 +2019-01-15 09:00:00,25.0,1.0,1.0,25.0,40.0,3677.25,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3677.25,10.0 +2019-01-15 10:00:00,25.0,1.0,1.0,25.0,40.0,3724.55,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3724.55,10.0 +2019-01-15 11:00:00,25.0,1.0,1.0,25.0,40.0,3746.8,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3746.8,10.0 +2019-01-15 12:00:00,25.0,1.0,1.0,25.0,40.0,3715.65,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3715.65,10.0 +2019-01-15 13:00:00,25.0,1.0,1.0,25.0,40.0,3708.65,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3708.65,10.0 +2019-01-15 14:00:00,25.0,1.0,1.0,25.0,40.0,3689.45,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3689.45,10.0 +2019-01-15 15:00:00,25.0,1.0,1.0,25.0,40.0,3667.9,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3667.9,10.0 +2019-01-15 16:00:00,25.0,1.0,1.0,25.0,40.0,3667.75,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3667.75,10.0 +2019-01-15 17:00:00,25.0,1.0,1.0,25.0,40.0,3769.5,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3769.5,10.0 +2019-01-15 18:00:00,25.0,1.0,1.0,25.0,40.0,3737.35,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3737.35,10.0 +2019-01-15 19:00:00,25.0,1.0,1.0,25.0,40.0,3652.3,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3652.3,10.0 +2019-01-15 20:00:00,25.0,1.0,1.0,25.0,40.0,3463.9,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3463.9,10.0 +2019-01-15 21:00:00,25.0,1.0,1.0,25.0,40.0,3288.4,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3288.4,10.0 +2019-01-15 22:00:00,25.0,1.0,1.0,25.0,40.0,3138.75,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3138.75,10.0 +2019-01-15 23:00:00,25.0,1.0,1.0,25.0,40.0,2929.25,2.0,1.0,1.0,1.0,45.05,20.0,2929.25,10.0 +2019-01-16 00:00:00,25.0,1.0,1.0,25.0,40.0,2758.1,2.0,1.0,1.0,1.0,45.05,20.0,2758.1,10.0 +2019-01-16 01:00:00,25.0,1.0,1.0,25.0,40.0,2650.55,2.0,1.0,1.0,1.0,45.05,20.0,2650.55,10.0 +2019-01-16 02:00:00,25.0,1.0,1.0,25.0,40.0,2596.6,2.0,1.0,1.0,1.0,45.05,20.0,2596.6,10.0 +2019-01-16 03:00:00,25.0,1.0,1.0,25.0,40.0,2604.8500000000004,2.0,1.0,1.0,1.0,45.05,20.0,2604.8500000000004,10.0 +2019-01-16 
04:00:00,25.0,1.0,1.0,25.0,40.0,2680.2,2.0,1.0,1.0,1.0,45.05,20.0,2680.2,10.0 +2019-01-16 05:00:00,25.0,1.0,1.0,25.0,40.0,2839.9,2.0,1.0,1.0,1.0,45.05,20.0,2839.9,10.0 +2019-01-16 06:00:00,25.0,1.0,1.0,25.0,40.0,3240.25,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3240.25,10.0 +2019-01-16 07:00:00,25.0,1.0,1.0,25.0,40.0,3586.95,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3586.95,10.0 +2019-01-16 08:00:00,25.0,1.0,1.0,25.0,40.0,3694.8,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3694.8,10.0 +2019-01-16 09:00:00,25.0,1.0,1.0,25.0,40.0,3700.25,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3700.25,10.0 +2019-01-16 10:00:00,25.0,1.0,1.0,25.0,40.0,3742.0,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3742.0,10.0 +2019-01-16 11:00:00,25.0,1.0,1.0,25.0,40.0,3774.15,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3774.15,10.0 +2019-01-16 12:00:00,25.0,1.0,1.0,25.0,40.0,3745.85,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3745.85,10.0 +2019-01-16 13:00:00,25.0,1.0,1.0,25.0,40.0,3737.8,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3737.8,10.0 +2019-01-16 14:00:00,25.0,1.0,1.0,25.0,40.0,3665.45,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3665.45,10.0 +2019-01-16 15:00:00,25.0,1.0,1.0,25.0,40.0,3641.9,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3641.9,10.0 +2019-01-16 16:00:00,25.0,1.0,1.0,25.0,40.0,3641.2,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3641.2,10.0 +2019-01-16 17:00:00,25.0,1.0,1.0,25.0,40.0,3782.85,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3782.85,10.0 +2019-01-16 18:00:00,25.0,1.0,1.0,25.0,40.0,3759.9,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3759.9,10.0 +2019-01-16 19:00:00,25.0,1.0,1.0,25.0,40.0,3677.9,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3677.9,10.0 +2019-01-16 20:00:00,25.0,1.0,1.0,25.0,40.0,3464.55,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3464.55,10.0 +2019-01-16 21:00:00,25.0,1.0,1.0,25.0,40.0,3290.15,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3290.15,10.0 +2019-01-16 22:00:00,25.0,1.0,1.0,25.0,40.0,3143.1,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3143.1,10.0 +2019-01-16 23:00:00,25.0,1.0,1.0,25.0,40.0,2955.25,2.0,1.0,1.0,1.0,45.05,20.0,2955.25,10.0 +2019-01-17 00:00:00,25.0,1.0,1.0,25.0,40.0,2806.4,2.0,1.0,1.0,1.0,45.05,20.0,2806.4,10.0 +2019-01-17 01:00:00,25.0,1.0,1.0,25.0,40.0,2709.8,2.0,1.0,1.0,1.0,45.05,20.0,2709.8,10.0 +2019-01-17 02:00:00,25.0,1.0,1.0,25.0,40.0,2648.45,2.0,1.0,1.0,1.0,45.05,20.0,2648.45,10.0 +2019-01-17 03:00:00,25.0,1.0,1.0,25.0,40.0,2648.75,2.0,1.0,1.0,1.0,45.05,20.0,2648.75,10.0 +2019-01-17 04:00:00,25.0,1.0,1.0,25.0,40.0,2731.2,2.0,1.0,1.0,1.0,45.05,20.0,2731.2,10.0 +2019-01-17 05:00:00,25.0,1.0,1.0,25.0,40.0,2890.55,2.0,1.0,1.0,1.0,45.05,20.0,2890.55,10.0 +2019-01-17 06:00:00,25.0,1.0,1.0,25.0,40.0,3265.5,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3265.5,10.0 +2019-01-17 07:00:00,25.0,1.0,1.0,25.0,40.0,3611.0,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3611.0,10.0 +2019-01-17 08:00:00,25.0,1.0,1.0,25.0,40.0,3729.15,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3729.15,10.0 +2019-01-17 09:00:00,25.0,1.0,1.0,25.0,40.0,3723.15,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3723.15,10.0 +2019-01-17 10:00:00,25.0,1.0,1.0,25.0,40.0,3752.65,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3752.65,10.0 +2019-01-17 11:00:00,25.0,1.0,1.0,25.0,40.0,3801.5,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3801.5,10.0 +2019-01-17 12:00:00,25.0,1.0,1.0,25.0,40.0,3798.05,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3798.05,10.0 +2019-01-17 13:00:00,25.0,1.0,1.0,25.0,40.0,3778.05,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3778.05,10.0 +2019-01-17 14:00:00,25.0,1.0,1.0,25.0,40.0,3732.3,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3732.3,10.0 +2019-01-17 
15:00:00,25.0,1.0,1.0,25.0,40.0,3685.15,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3685.15,10.0 +2019-01-17 16:00:00,25.0,1.0,1.0,25.0,40.0,3670.45,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3670.45,10.0 +2019-01-17 17:00:00,25.0,1.0,1.0,25.0,40.0,3807.35,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3807.35,10.0 +2019-01-17 18:00:00,25.0,1.0,1.0,25.0,40.0,3804.0,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3804.0,10.0 +2019-01-17 19:00:00,25.0,1.0,1.0,25.0,40.0,3701.25,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3701.25,10.0 +2019-01-17 20:00:00,25.0,1.0,1.0,25.0,40.0,3522.15,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3522.15,10.0 +2019-01-17 21:00:00,25.0,1.0,1.0,25.0,40.0,3326.2,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3326.2,10.0 +2019-01-17 22:00:00,25.0,1.0,1.0,25.0,40.0,3166.9,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3166.9,10.0 +2019-01-17 23:00:00,25.0,1.0,1.0,25.0,40.0,2949.15,2.0,1.0,1.0,1.0,45.05,20.0,2949.15,10.0 +2019-01-18 00:00:00,25.0,1.0,1.0,25.0,40.0,2822.65,2.0,1.0,1.0,1.0,45.05,20.0,2822.65,10.0 +2019-01-18 01:00:00,25.0,1.0,1.0,25.0,40.0,2705.15,2.0,1.0,1.0,1.0,45.05,20.0,2705.15,10.0 +2019-01-18 02:00:00,25.0,1.0,1.0,25.0,40.0,2636.85,2.0,1.0,1.0,1.0,45.05,20.0,2636.85,10.0 +2019-01-18 03:00:00,25.0,1.0,1.0,25.0,40.0,2651.85,2.0,1.0,1.0,1.0,45.05,20.0,2651.85,10.0 +2019-01-18 04:00:00,25.0,1.0,1.0,25.0,40.0,2715.4,2.0,1.0,1.0,1.0,45.05,20.0,2715.4,10.0 +2019-01-18 05:00:00,25.0,1.0,1.0,25.0,40.0,2850.15,2.0,1.0,1.0,1.0,45.05,20.0,2850.15,10.0 +2019-01-18 06:00:00,25.0,1.0,1.0,25.0,40.0,3207.95,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3207.95,10.0 +2019-01-18 07:00:00,25.0,1.0,1.0,25.0,40.0,3519.1,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3519.1,10.0 +2019-01-18 08:00:00,25.0,1.0,1.0,25.0,40.0,3623.05,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3623.05,10.0 +2019-01-18 09:00:00,25.0,1.0,1.0,25.0,40.0,3612.6,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3612.6,10.0 +2019-01-18 10:00:00,25.0,1.0,1.0,25.0,40.0,3608.0,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3608.0,10.0 +2019-01-18 11:00:00,25.0,1.0,1.0,25.0,40.0,3625.5,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3625.5,10.0 +2019-01-18 12:00:00,25.0,1.0,1.0,25.0,40.0,3583.0,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3583.0,10.0 +2019-01-18 13:00:00,25.0,1.0,1.0,25.0,40.0,3514.85,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3514.85,10.0 +2019-01-18 14:00:00,25.0,1.0,1.0,25.0,40.0,3457.15,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3457.15,10.0 +2019-01-18 15:00:00,25.0,1.0,1.0,25.0,40.0,3418.0,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3418.0,10.0 +2019-01-18 16:00:00,25.0,1.0,1.0,25.0,40.0,3433.5,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3433.5,10.0 +2019-01-18 17:00:00,25.0,1.0,1.0,25.0,40.0,3599.75,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3599.75,10.0 +2019-01-18 18:00:00,25.0,1.0,1.0,25.0,40.0,3587.15,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3587.15,10.0 +2019-01-18 19:00:00,25.0,1.0,1.0,25.0,40.0,3489.6,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3489.6,10.0 +2019-01-18 20:00:00,25.0,1.0,1.0,25.0,40.0,3290.1,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3290.1,10.0 +2019-01-18 21:00:00,25.0,1.0,1.0,25.0,40.0,3139.35,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3139.35,10.0 +2019-01-18 22:00:00,25.0,1.0,1.0,25.0,40.0,3022.8,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3022.8,10.0 +2019-01-18 23:00:00,25.0,1.0,1.0,25.0,40.0,2887.4,2.0,1.0,1.0,1.0,45.05,20.0,2887.4,10.0 +2019-01-19 00:00:00,25.0,1.0,1.0,25.0,40.0,2757.35,2.0,1.0,1.0,1.0,45.05,20.0,2757.35,10.0 +2019-01-19 01:00:00,25.0,1.0,1.0,25.0,40.0,2649.3,2.0,1.0,1.0,1.0,45.05,20.0,2649.3,10.0 +2019-01-19 
02:00:00,25.0,1.0,1.0,25.0,40.0,2575.25,2.0,1.0,1.0,1.0,45.05,20.0,2575.25,10.0 +2019-01-19 03:00:00,25.0,1.0,1.0,25.0,40.0,2542.55,2.0,1.0,1.0,1.0,45.05,20.0,2542.55,10.0 +2019-01-19 04:00:00,25.0,1.0,1.0,25.0,40.0,2529.4,2.0,1.0,1.0,1.0,45.05,20.0,2529.4,10.0 +2019-01-19 05:00:00,25.0,1.0,1.0,25.0,40.0,2521.0,2.0,1.0,1.0,1.0,45.05,20.0,2521.0,10.0 +2019-01-19 06:00:00,25.0,1.0,1.0,25.0,40.0,2554.15,2.0,1.0,1.0,1.0,45.05,20.0,2554.15,10.0 +2019-01-19 07:00:00,25.0,1.0,1.0,25.0,40.0,2689.6,2.0,1.0,1.0,1.0,45.05,20.0,2689.6,10.0 +2019-01-19 08:00:00,25.0,1.0,1.0,25.0,40.0,2868.35,2.0,1.0,1.0,1.0,45.05,20.0,2868.35,10.0 +2019-01-19 09:00:00,25.0,1.0,1.0,25.0,40.0,3016.45,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3016.45,10.0 +2019-01-19 10:00:00,25.0,1.0,1.0,25.0,40.0,3093.5,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3093.5,10.0 +2019-01-19 11:00:00,25.0,1.0,1.0,25.0,40.0,3114.6,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3114.6,10.0 +2019-01-19 12:00:00,25.0,1.0,1.0,25.0,40.0,3061.4,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3061.4,10.0 +2019-01-19 13:00:00,25.0,1.0,1.0,25.0,40.0,2959.45,2.0,1.0,1.0,1.0,45.05,20.0,2959.45,10.0 +2019-01-19 14:00:00,25.0,1.0,1.0,25.0,40.0,2884.05,2.0,1.0,1.0,1.0,45.05,20.0,2884.05,10.0 +2019-01-19 15:00:00,25.0,1.0,1.0,25.0,40.0,2874.35,2.0,1.0,1.0,1.0,45.05,20.0,2874.35,10.0 +2019-01-19 16:00:00,25.0,1.0,1.0,25.0,40.0,2932.25,2.0,1.0,1.0,1.0,45.05,20.0,2932.25,10.0 +2019-01-19 17:00:00,25.0,1.0,1.0,25.0,40.0,3144.0,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3144.0,10.0 +2019-01-19 18:00:00,25.0,1.0,1.0,25.0,40.0,3175.8,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3175.8,10.0 +2019-01-19 19:00:00,25.0,1.0,1.0,25.0,40.0,3086.2,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3086.2,10.0 +2019-01-19 20:00:00,25.0,1.0,1.0,25.0,40.0,2903.3,2.0,1.0,1.0,1.0,45.05,20.0,2903.3,10.0 +2019-01-19 21:00:00,25.0,1.0,1.0,25.0,40.0,2790.05,2.0,1.0,1.0,1.0,45.05,20.0,2790.05,10.0 +2019-01-19 22:00:00,25.0,1.0,1.0,25.0,40.0,2756.0,2.0,1.0,1.0,1.0,45.05,20.0,2756.0,10.0 +2019-01-19 23:00:00,25.0,1.0,1.0,25.0,40.0,2652.8,2.0,1.0,1.0,1.0,45.05,20.0,2652.8,10.0 +2019-01-20 00:00:00,25.0,1.0,1.0,25.0,40.0,2531.9,2.0,1.0,1.0,1.0,45.05,20.0,2531.9,10.0 +2019-01-20 01:00:00,25.0,1.0,1.0,25.0,40.0,2435.45,2.0,1.0,1.0,1.0,45.05,20.0,2435.45,10.0 +2019-01-20 02:00:00,25.0,1.0,1.0,25.0,40.0,2381.0,2.0,1.0,1.0,1.0,45.05,20.0,2381.0,10.0 +2019-01-20 03:00:00,25.0,1.0,1.0,25.0,40.0,2352.6,2.0,1.0,1.0,1.0,45.05,20.0,2352.6,10.0 +2019-01-20 04:00:00,25.0,1.0,1.0,25.0,40.0,2333.05,2.0,1.0,1.0,1.0,45.05,20.0,2333.05,10.0 +2019-01-20 05:00:00,25.0,1.0,1.0,25.0,40.0,2310.95,2.0,1.0,1.0,1.0,45.05,20.0,2310.95,10.0 +2019-01-20 06:00:00,25.0,1.0,1.0,25.0,40.0,2278.05,2.0,1.0,1.0,1.0,45.05,20.0,2278.05,10.0 +2019-01-20 07:00:00,25.0,1.0,1.0,25.0,40.0,2363.3,2.0,1.0,1.0,1.0,45.05,20.0,2363.3,10.0 +2019-01-20 08:00:00,25.0,1.0,1.0,25.0,40.0,2525.85,2.0,1.0,1.0,1.0,45.05,20.0,2525.85,10.0 +2019-01-20 09:00:00,25.0,1.0,1.0,25.0,40.0,2694.95,2.0,1.0,1.0,1.0,45.05,20.0,2694.95,10.0 +2019-01-20 10:00:00,25.0,1.0,1.0,25.0,40.0,2787.25,2.0,1.0,1.0,1.0,45.05,20.0,2787.25,10.0 +2019-01-20 11:00:00,25.0,1.0,1.0,25.0,40.0,2871.95,2.0,1.0,1.0,1.0,45.05,20.0,2871.95,10.0 +2019-01-20 12:00:00,25.0,1.0,1.0,25.0,40.0,2860.55,2.0,1.0,1.0,1.0,45.05,20.0,2860.55,10.0 +2019-01-20 13:00:00,25.0,1.0,1.0,25.0,40.0,2778.4,2.0,1.0,1.0,1.0,45.05,20.0,2778.4,10.0 +2019-01-20 14:00:00,25.0,1.0,1.0,25.0,40.0,2703.95,2.0,1.0,1.0,1.0,45.05,20.0,2703.95,10.0 +2019-01-20 15:00:00,25.0,1.0,1.0,25.0,40.0,2688.55,2.0,1.0,1.0,1.0,45.05,20.0,2688.55,10.0 +2019-01-20 
16:00:00,25.0,1.0,1.0,25.0,40.0,2752.85,2.0,1.0,1.0,1.0,45.05,20.0,2752.85,10.0 +2019-01-20 17:00:00,25.0,1.0,1.0,25.0,40.0,2974.75,2.0,1.0,1.0,1.0,45.05,20.0,2974.75,10.0 +2019-01-20 18:00:00,25.0,1.0,1.0,25.0,40.0,3070.1,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3070.1,10.0 +2019-01-20 19:00:00,25.0,1.0,1.0,25.0,40.0,3011.55,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3011.55,10.0 +2019-01-20 20:00:00,25.0,1.0,1.0,25.0,40.0,2889.25,2.0,1.0,1.0,1.0,45.05,20.0,2889.25,10.0 +2019-01-20 21:00:00,25.0,1.0,1.0,25.0,40.0,2817.3,2.0,1.0,1.0,1.0,45.05,20.0,2817.3,10.0 +2019-01-20 22:00:00,25.0,1.0,1.0,25.0,40.0,2839.0,2.0,1.0,1.0,1.0,45.05,20.0,2839.0,10.0 +2019-01-20 23:00:00,25.0,1.0,1.0,25.0,40.0,2735.75,2.0,1.0,1.0,1.0,45.05,20.0,2735.75,10.0 +2019-01-21 00:00:00,25.0,1.0,1.0,25.0,40.0,2651.5,2.0,1.0,1.0,1.0,45.05,20.0,2651.5,10.0 +2019-01-21 01:00:00,25.0,1.0,1.0,25.0,40.0,2573.35,2.0,1.0,1.0,1.0,45.05,20.0,2573.35,10.0 +2019-01-21 02:00:00,25.0,1.0,1.0,25.0,40.0,2541.4,2.0,1.0,1.0,1.0,45.05,20.0,2541.4,10.0 +2019-01-21 03:00:00,25.0,1.0,1.0,25.0,40.0,2557.4,2.0,1.0,1.0,1.0,45.05,20.0,2557.4,10.0 +2019-01-21 04:00:00,25.0,1.0,1.0,25.0,40.0,2640.5,2.0,1.0,1.0,1.0,45.05,20.0,2640.5,10.0 +2019-01-21 05:00:00,25.0,1.0,1.0,25.0,40.0,2841.0,2.0,1.0,1.0,1.0,45.05,20.0,2841.0,10.0 +2019-01-21 06:00:00,25.0,1.0,1.0,25.0,40.0,3255.4,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3255.4,10.0 +2019-01-21 07:00:00,25.0,1.0,1.0,25.0,40.0,3591.05,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3591.05,10.0 +2019-01-21 08:00:00,25.0,1.0,1.0,25.0,40.0,3684.05,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3684.05,10.0 +2019-01-21 09:00:00,25.0,1.0,1.0,25.0,40.0,3685.5,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3685.5,10.0 +2019-01-21 10:00:00,25.0,1.0,1.0,25.0,40.0,3698.45,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3698.45,10.0 +2019-01-21 11:00:00,25.0,1.0,1.0,25.0,40.0,3725.05,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3725.05,10.0 +2019-01-21 12:00:00,25.0,1.0,1.0,25.0,40.0,3689.65,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3689.65,10.0 +2019-01-21 13:00:00,25.0,1.0,1.0,25.0,40.0,3664.7,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3664.7,10.0 +2019-01-21 14:00:00,25.0,1.0,1.0,25.0,40.0,3609.65,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3609.65,10.0 +2019-01-21 15:00:00,25.0,1.0,1.0,25.0,40.0,3579.65,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3579.65,10.0 +2019-01-21 16:00:00,25.0,1.0,1.0,25.0,40.0,3582.25,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3582.25,10.0 +2019-01-21 17:00:00,25.0,1.0,1.0,25.0,40.0,3765.7,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3765.7,10.0 +2019-01-21 18:00:00,25.0,1.0,1.0,25.0,40.0,3779.65,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3779.65,10.0 +2019-01-21 19:00:00,25.0,1.0,1.0,25.0,40.0,3695.25,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3695.25,10.0 +2019-01-21 20:00:00,25.0,1.0,1.0,25.0,40.0,3514.15,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3514.15,10.0 +2019-01-21 21:00:00,25.0,1.0,1.0,25.0,40.0,3348.75,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3348.75,10.0 +2019-01-21 22:00:00,25.0,1.0,1.0,25.0,40.0,3228.05,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3228.05,10.0 +2019-01-21 23:00:00,25.0,1.0,1.0,25.0,40.0,3029.15,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3029.15,10.0 +2019-01-22 00:00:00,25.0,1.0,1.0,25.0,40.0,2894.4,2.0,1.0,1.0,1.0,45.05,20.0,2894.4,10.0 +2019-01-22 01:00:00,25.0,1.0,1.0,25.0,40.0,2817.35,2.0,1.0,1.0,1.0,45.05,20.0,2817.35,10.0 +2019-01-22 02:00:00,25.0,1.0,1.0,25.0,40.0,2764.95,2.0,1.0,1.0,1.0,45.05,20.0,2764.95,10.0 +2019-01-22 03:00:00,25.0,1.0,1.0,25.0,40.0,2756.75,2.0,1.0,1.0,1.0,45.05,20.0,2756.75,10.0 +2019-01-22 
04:00:00,25.0,1.0,1.0,25.0,40.0,2814.4,2.0,1.0,1.0,1.0,45.05,20.0,2814.4,10.0 +2019-01-22 05:00:00,25.0,1.0,1.0,25.0,40.0,2973.65,2.0,1.0,1.0,1.0,45.05,20.0,2973.65,10.0 +2019-01-22 06:00:00,25.0,1.0,1.0,25.0,40.0,3339.0,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3339.0,10.0 +2019-01-22 07:00:00,25.0,1.0,1.0,25.0,40.0,3662.55,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3662.55,10.0 +2019-01-22 08:00:00,25.0,1.0,1.0,25.0,40.0,3775.65,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3775.65,10.0 +2019-01-22 09:00:00,25.0,1.0,1.0,25.0,40.0,3790.05,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3790.05,10.0 +2019-01-22 10:00:00,25.0,1.0,1.0,25.0,40.0,3807.5,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3807.5,10.0 +2019-01-22 11:00:00,25.0,1.0,1.0,25.0,40.0,3826.9,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3826.9,10.0 +2019-01-22 12:00:00,25.0,1.0,1.0,25.0,40.0,3800.2,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3800.2,10.0 +2019-01-22 13:00:00,25.0,1.0,1.0,25.0,40.0,3773.0,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3773.0,10.0 +2019-01-22 14:00:00,25.0,1.0,1.0,25.0,40.0,3742.55,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3742.55,10.0 +2019-01-22 15:00:00,25.0,1.0,1.0,25.0,40.0,3710.55,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3710.55,10.0 +2019-01-22 16:00:00,25.0,1.0,1.0,25.0,40.0,3702.75,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3702.75,10.0 +2019-01-22 17:00:00,25.0,1.0,1.0,25.0,40.0,3840.3,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3840.3,10.0 +2019-01-22 18:00:00,25.0,1.0,1.0,25.0,40.0,3833.25,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3833.25,10.0 +2019-01-22 19:00:00,25.0,1.0,1.0,25.0,40.0,3754.35,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3754.35,10.0 +2019-01-22 20:00:00,25.0,1.0,1.0,25.0,40.0,3555.65,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3555.65,10.0 +2019-01-22 21:00:00,25.0,1.0,1.0,25.0,40.0,3375.5,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3375.5,10.0 +2019-01-22 22:00:00,25.0,1.0,1.0,25.0,40.0,3240.9,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3240.9,10.0 +2019-01-22 23:00:00,25.0,1.0,1.0,25.0,40.0,3056.35,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3056.35,10.0 +2019-01-23 00:00:00,25.0,1.0,1.0,25.0,40.0,2936.65,2.0,1.0,1.0,1.0,45.05,20.0,2936.65,10.0 +2019-01-23 01:00:00,25.0,1.0,1.0,25.0,40.0,2835.6,2.0,1.0,1.0,1.0,45.05,20.0,2835.6,10.0 +2019-01-23 02:00:00,25.0,1.0,1.0,25.0,40.0,2796.5,2.0,1.0,1.0,1.0,45.05,20.0,2796.5,10.0 +2019-01-23 03:00:00,25.0,1.0,1.0,25.0,40.0,2798.4,2.0,1.0,1.0,1.0,45.05,20.0,2798.4,10.0 +2019-01-23 04:00:00,25.0,1.0,1.0,25.0,40.0,2838.35,2.0,1.0,1.0,1.0,45.05,20.0,2838.35,10.0 +2019-01-23 05:00:00,25.0,1.0,1.0,25.0,40.0,2987.45,2.0,1.0,1.0,1.0,45.05,20.0,2987.45,10.0 +2019-01-23 06:00:00,25.0,1.0,1.0,25.0,40.0,3332.9,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3332.9,10.0 +2019-01-23 07:00:00,25.0,1.0,1.0,25.0,40.0,3663.9,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3663.9,10.0 +2019-01-23 08:00:00,25.0,1.0,1.0,25.0,40.0,3757.9,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3757.9,10.0 +2019-01-23 09:00:00,25.0,1.0,1.0,25.0,40.0,3771.7,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3771.7,10.0 +2019-01-23 10:00:00,25.0,1.0,1.0,25.0,40.0,3787.35,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3787.35,10.0 +2019-01-23 11:00:00,25.0,1.0,1.0,25.0,40.0,3797.95,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3797.95,10.0 +2019-01-23 12:00:00,25.0,1.0,1.0,25.0,40.0,3773.2,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3773.2,10.0 +2019-01-23 13:00:00,25.0,1.0,1.0,25.0,40.0,3749.0,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3749.0,10.0 +2019-01-23 14:00:00,25.0,1.0,1.0,25.0,40.0,3704.5,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3704.5,10.0 +2019-01-23 
15:00:00,25.0,1.0,1.0,25.0,40.0,3666.35,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3666.35,10.0 +2019-01-23 16:00:00,25.0,1.0,1.0,25.0,40.0,3661.8,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3661.8,10.0 +2019-01-23 17:00:00,25.0,1.0,1.0,25.0,40.0,3796.5,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3796.5,10.0 +2019-01-23 18:00:00,25.0,1.0,1.0,25.0,40.0,3801.5,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3801.5,10.0 +2019-01-23 19:00:00,25.0,1.0,1.0,25.0,40.0,3716.25,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3716.25,10.0 +2019-01-23 20:00:00,25.0,1.0,1.0,25.0,40.0,3525.65,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3525.65,10.0 +2019-01-23 21:00:00,25.0,1.0,1.0,25.0,40.0,3358.3,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3358.3,10.0 +2019-01-23 22:00:00,25.0,1.0,1.0,25.0,40.0,3244.15,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3244.15,10.0 +2019-01-23 23:00:00,25.0,1.0,1.0,25.0,40.0,3046.5,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3046.5,10.0 +2019-01-24 00:00:00,25.0,1.0,1.0,25.0,40.0,2935.3,2.0,1.0,1.0,1.0,45.05,20.0,2935.3,10.0 +2019-01-24 01:00:00,25.0,1.0,1.0,25.0,40.0,2834.55,2.0,1.0,1.0,1.0,45.05,20.0,2834.55,10.0 +2019-01-24 02:00:00,25.0,1.0,1.0,25.0,40.0,2751.2,2.0,1.0,1.0,1.0,45.05,20.0,2751.2,10.0 +2019-01-24 03:00:00,25.0,1.0,1.0,25.0,40.0,2748.6,2.0,1.0,1.0,1.0,45.05,20.0,2748.6,10.0 +2019-01-24 04:00:00,25.0,1.0,1.0,25.0,40.0,2780.95,2.0,1.0,1.0,1.0,45.05,20.0,2780.95,10.0 +2019-01-24 05:00:00,25.0,1.0,1.0,25.0,40.0,2926.25,2.0,1.0,1.0,1.0,45.05,20.0,2926.25,10.0 +2019-01-24 06:00:00,25.0,1.0,1.0,25.0,40.0,3254.55,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3254.55,10.0 +2019-01-24 07:00:00,25.0,1.0,1.0,25.0,40.0,3570.3,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3570.3,10.0 +2019-01-24 08:00:00,25.0,1.0,1.0,25.0,40.0,3695.75,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3695.75,10.0 +2019-01-24 09:00:00,25.0,1.0,1.0,25.0,40.0,3721.25,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3721.25,10.0 +2019-01-24 10:00:00,25.0,1.0,1.0,25.0,40.0,3761.8,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3761.8,10.0 +2019-01-24 11:00:00,25.0,1.0,1.0,25.0,40.0,3783.55,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3783.55,10.0 +2019-01-24 12:00:00,25.0,1.0,1.0,25.0,40.0,3758.9,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3758.9,10.0 +2019-01-24 13:00:00,25.0,1.0,1.0,25.0,40.0,3741.5,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3741.5,10.0 +2019-01-24 14:00:00,25.0,1.0,1.0,25.0,40.0,3699.55,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3699.55,10.0 +2019-01-24 15:00:00,25.0,1.0,1.0,25.0,40.0,3693.6,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3693.6,10.0 +2019-01-24 16:00:00,25.0,1.0,1.0,25.0,40.0,3674.7,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3674.7,10.0 +2019-01-24 17:00:00,25.0,1.0,1.0,25.0,40.0,3785.1,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3785.1,10.0 +2019-01-24 18:00:00,25.0,1.0,1.0,25.0,40.0,3765.85,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3765.85,10.0 +2019-01-24 19:00:00,25.0,1.0,1.0,25.0,40.0,3681.55,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3681.55,10.0 +2019-01-24 20:00:00,25.0,1.0,1.0,25.0,40.0,3485.7,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3485.7,10.0 +2019-01-24 21:00:00,25.0,1.0,1.0,25.0,40.0,3319.35,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3319.35,10.0 +2019-01-24 22:00:00,25.0,1.0,1.0,25.0,40.0,3188.45,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3188.45,10.0 +2019-01-24 23:00:00,25.0,1.0,1.0,25.0,40.0,3007.1,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3007.1,10.0 +2019-01-25 00:00:00,25.0,1.0,1.0,25.0,40.0,2894.8,2.0,1.0,1.0,1.0,45.05,20.0,2894.8,10.0 +2019-01-25 01:00:00,25.0,1.0,1.0,25.0,40.0,2806.55,2.0,1.0,1.0,1.0,45.05,20.0,2806.55,10.0 +2019-01-25 
02:00:00,25.0,1.0,1.0,25.0,40.0,2763.65,2.0,1.0,1.0,1.0,45.05,20.0,2763.65,10.0 +2019-01-25 03:00:00,25.0,1.0,1.0,25.0,40.0,2753.75,2.0,1.0,1.0,1.0,45.05,20.0,2753.75,10.0 +2019-01-25 04:00:00,25.0,1.0,1.0,25.0,40.0,2809.55,2.0,1.0,1.0,1.0,45.05,20.0,2809.55,10.0 +2019-01-25 05:00:00,25.0,1.0,1.0,25.0,40.0,2960.8,2.0,1.0,1.0,1.0,45.05,20.0,2960.8,10.0 +2019-01-25 06:00:00,25.0,1.0,1.0,25.0,40.0,3294.9,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3294.9,10.0 +2019-01-25 07:00:00,25.0,1.0,1.0,25.0,40.0,3606.15,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3606.15,10.0 +2019-01-25 08:00:00,25.0,1.0,1.0,25.0,40.0,3721.95,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3721.95,10.0 +2019-01-25 09:00:00,25.0,1.0,1.0,25.0,40.0,3751.5,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3751.5,10.0 +2019-01-25 10:00:00,25.0,1.0,1.0,25.0,40.0,3770.15,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3770.15,10.0 +2019-01-25 11:00:00,25.0,1.0,1.0,25.0,40.0,3778.15,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3778.15,10.0 +2019-01-25 12:00:00,25.0,1.0,1.0,25.0,40.0,3739.35,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3739.35,10.0 +2019-01-25 13:00:00,25.0,1.0,1.0,25.0,40.0,3668.5,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3668.5,10.0 +2019-01-25 14:00:00,25.0,1.0,1.0,25.0,40.0,3612.75,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3612.75,10.0 +2019-01-25 15:00:00,25.0,1.0,1.0,25.0,40.0,3570.25,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3570.25,10.0 +2019-01-25 16:00:00,25.0,1.0,1.0,25.0,40.0,3578.05,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3578.05,10.0 +2019-01-25 17:00:00,25.0,1.0,1.0,25.0,40.0,3717.75,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3717.75,10.0 +2019-01-25 18:00:00,25.0,1.0,1.0,25.0,40.0,3722.75,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3722.75,10.0 +2019-01-25 19:00:00,25.0,1.0,1.0,25.0,40.0,3642.05,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3642.05,10.0 +2019-01-25 20:00:00,25.0,1.0,1.0,25.0,40.0,3445.55,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3445.55,10.0 +2019-01-25 21:00:00,25.0,1.0,1.0,25.0,40.0,3303.65,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3303.65,10.0 +2019-01-25 22:00:00,25.0,1.0,1.0,25.0,40.0,3208.5,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3208.5,10.0 +2019-01-25 23:00:00,25.0,1.0,1.0,25.0,40.0,3034.9,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3034.9,10.0 +2019-01-26 00:00:00,25.0,1.0,1.0,25.0,40.0,2898.9,2.0,1.0,1.0,1.0,45.05,20.0,2898.9,10.0 +2019-01-26 01:00:00,25.0,1.0,1.0,25.0,40.0,2769.3,2.0,1.0,1.0,1.0,45.05,20.0,2769.3,10.0 +2019-01-26 02:00:00,25.0,1.0,1.0,25.0,40.0,2691.6,2.0,1.0,1.0,1.0,45.05,20.0,2691.6,10.0 +2019-01-26 03:00:00,25.0,1.0,1.0,25.0,40.0,2647.4,2.0,1.0,1.0,1.0,45.05,20.0,2647.4,10.0 +2019-01-26 04:00:00,25.0,1.0,1.0,25.0,40.0,2631.5,2.0,1.0,1.0,1.0,45.05,20.0,2631.5,10.0 +2019-01-26 05:00:00,25.0,1.0,1.0,25.0,40.0,2618.6,2.0,1.0,1.0,1.0,45.05,20.0,2618.6,10.0 +2019-01-26 06:00:00,25.0,1.0,1.0,25.0,40.0,2655.6,2.0,1.0,1.0,1.0,45.05,20.0,2655.6,10.0 +2019-01-26 07:00:00,25.0,1.0,1.0,25.0,40.0,2806.95,2.0,1.0,1.0,1.0,45.05,20.0,2806.95,10.0 +2019-01-26 08:00:00,25.0,1.0,1.0,25.0,40.0,2992.7,2.0,1.0,1.0,1.0,45.05,20.0,2992.7,10.0 +2019-01-26 09:00:00,25.0,1.0,1.0,25.0,40.0,3145.4,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3145.4,10.0 +2019-01-26 10:00:00,25.0,1.0,1.0,25.0,40.0,3240.45,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3240.45,10.0 +2019-01-26 11:00:00,25.0,1.0,1.0,25.0,40.0,3273.6,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3273.6,10.0 +2019-01-26 12:00:00,25.0,1.0,1.0,25.0,40.0,3245.1,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3245.1,10.0 +2019-01-26 13:00:00,25.0,1.0,1.0,25.0,40.0,3144.1,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3144.1,10.0 +2019-01-26 
14:00:00,25.0,1.0,1.0,25.0,40.0,3073.8,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3073.8,10.0 +2019-01-26 15:00:00,25.0,1.0,1.0,25.0,40.0,3041.6,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3041.6,10.0 +2019-01-26 16:00:00,25.0,1.0,1.0,25.0,40.0,3050.1000000000004,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3050.1000000000004,10.0 +2019-01-26 17:00:00,25.0,1.0,1.0,25.0,40.0,3205.0,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3205.0,10.0 +2019-01-26 18:00:00,25.0,1.0,1.0,25.0,40.0,3238.7,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3238.7,10.0 +2019-01-26 19:00:00,25.0,1.0,1.0,25.0,40.0,3131.4,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3131.4,10.0 +2019-01-26 20:00:00,25.0,1.0,1.0,25.0,40.0,2947.2,2.0,1.0,1.0,1.0,45.05,20.0,2947.2,10.0 +2019-01-26 21:00:00,25.0,1.0,1.0,25.0,40.0,2853.8,2.0,1.0,1.0,1.0,45.05,20.0,2853.8,10.0 +2019-01-26 22:00:00,25.0,1.0,1.0,25.0,40.0,2811.6,2.0,1.0,1.0,1.0,45.05,20.0,2811.6,10.0 +2019-01-26 23:00:00,25.0,1.0,1.0,25.0,40.0,2658.4,2.0,1.0,1.0,1.0,45.05,20.0,2658.4,10.0 +2019-01-27 00:00:00,25.0,1.0,1.0,25.0,40.0,2515.4,2.0,1.0,1.0,1.0,45.05,20.0,2515.4,10.0 +2019-01-27 01:00:00,25.0,1.0,1.0,25.0,40.0,2415.35,2.0,1.0,1.0,1.0,45.05,20.0,2415.35,10.0 +2019-01-27 02:00:00,25.0,1.0,1.0,25.0,40.0,2342.15,2.0,1.0,1.0,1.0,45.05,20.0,2342.15,10.0 +2019-01-27 03:00:00,25.0,1.0,1.0,25.0,40.0,2325.15,2.0,1.0,1.0,1.0,45.05,20.0,2325.15,10.0 +2019-01-27 04:00:00,25.0,1.0,1.0,25.0,40.0,2324.15,2.0,1.0,1.0,1.0,45.05,20.0,2324.15,10.0 +2019-01-27 05:00:00,25.0,1.0,1.0,25.0,40.0,2303.3500000000004,2.0,1.0,1.0,1.0,45.05,20.0,2303.3500000000004,10.0 +2019-01-27 06:00:00,25.0,1.0,1.0,25.0,40.0,2262.6,2.0,1.0,1.0,1.0,45.05,20.0,2262.6,10.0 +2019-01-27 07:00:00,25.0,1.0,1.0,25.0,40.0,2375.0,2.0,1.0,1.0,1.0,45.05,20.0,2375.0,10.0 +2019-01-27 08:00:00,25.0,1.0,1.0,25.0,40.0,2508.5,2.0,1.0,1.0,1.0,45.05,20.0,2508.5,10.0 +2019-01-27 09:00:00,25.0,1.0,1.0,25.0,40.0,2685.5,2.0,1.0,1.0,1.0,45.05,20.0,2685.5,10.0 +2019-01-27 10:00:00,25.0,1.0,1.0,25.0,40.0,2815.7,2.0,1.0,1.0,1.0,45.05,20.0,2815.7,10.0 +2019-01-27 11:00:00,25.0,1.0,1.0,25.0,40.0,2944.8,2.0,1.0,1.0,1.0,45.05,20.0,2944.8,10.0 +2019-01-27 12:00:00,25.0,1.0,1.0,25.0,40.0,2934.5,2.0,1.0,1.0,1.0,45.05,20.0,2934.5,10.0 +2019-01-27 13:00:00,25.0,1.0,1.0,25.0,40.0,2812.5,2.0,1.0,1.0,1.0,45.05,20.0,2812.5,10.0 +2019-01-27 14:00:00,25.0,1.0,1.0,25.0,40.0,2738.9500000000003,2.0,1.0,1.0,1.0,45.05,20.0,2738.9500000000003,10.0 +2019-01-27 15:00:00,25.0,1.0,1.0,25.0,40.0,2702.65,2.0,1.0,1.0,1.0,45.05,20.0,2702.65,10.0 +2019-01-27 16:00:00,25.0,1.0,1.0,25.0,40.0,2739.95,2.0,1.0,1.0,1.0,45.05,20.0,2739.95,10.0 +2019-01-27 17:00:00,25.0,1.0,1.0,25.0,40.0,2941.15,2.0,1.0,1.0,1.0,45.05,20.0,2941.15,10.0 +2019-01-27 18:00:00,25.0,1.0,1.0,25.0,40.0,3023.1,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3023.1,10.0 +2019-01-27 19:00:00,25.0,1.0,1.0,25.0,40.0,2978.5,2.0,1.0,1.0,1.0,45.05,20.0,2978.5,10.0 +2019-01-27 20:00:00,25.0,1.0,1.0,25.0,40.0,2867.25,2.0,1.0,1.0,1.0,45.05,20.0,2867.25,10.0 +2019-01-27 21:00:00,25.0,1.0,1.0,25.0,40.0,2793.2,2.0,1.0,1.0,1.0,45.05,20.0,2793.2,10.0 +2019-01-27 22:00:00,25.0,1.0,1.0,25.0,40.0,2803.6,2.0,1.0,1.0,1.0,45.05,20.0,2803.6,10.0 +2019-01-27 23:00:00,25.0,1.0,1.0,25.0,40.0,2655.95,2.0,1.0,1.0,1.0,45.05,20.0,2655.95,10.0 +2019-01-28 00:00:00,25.0,1.0,1.0,25.0,40.0,2541.65,2.0,1.0,1.0,1.0,45.05,20.0,2541.65,10.0 +2019-01-28 01:00:00,25.0,1.0,1.0,25.0,40.0,2456.9,2.0,1.0,1.0,1.0,45.05,20.0,2456.9,10.0 +2019-01-28 02:00:00,25.0,1.0,1.0,25.0,40.0,2434.5,2.0,1.0,1.0,1.0,45.05,20.0,2434.5,10.0 +2019-01-28 
03:00:00,25.0,1.0,1.0,25.0,40.0,2458.25,2.0,1.0,1.0,1.0,45.05,20.0,2458.25,10.0 +2019-01-28 04:00:00,25.0,1.0,1.0,25.0,40.0,2547.7,2.0,1.0,1.0,1.0,45.05,20.0,2547.7,10.0 +2019-01-28 05:00:00,25.0,1.0,1.0,25.0,40.0,2729.85,2.0,1.0,1.0,1.0,45.05,20.0,2729.85,10.0 +2019-01-28 06:00:00,25.0,1.0,1.0,25.0,40.0,3171.25,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3171.25,10.0 +2019-01-28 07:00:00,25.0,1.0,1.0,25.0,40.0,3522.25,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3522.25,10.0 +2019-01-28 08:00:00,25.0,1.0,1.0,25.0,40.0,3647.2,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3647.2,10.0 +2019-01-28 09:00:00,25.0,1.0,1.0,25.0,40.0,3695.6,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3695.6,10.0 +2019-01-28 10:00:00,25.0,1.0,1.0,25.0,40.0,3734.15,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3734.15,10.0 +2019-01-28 11:00:00,25.0,1.0,1.0,25.0,40.0,3768.0,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3768.0,10.0 +2019-01-28 12:00:00,25.0,1.0,1.0,25.0,40.0,3765.65,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3765.65,10.0 +2019-01-28 13:00:00,25.0,1.0,1.0,25.0,40.0,3752.15,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3752.15,10.0 +2019-01-28 14:00:00,25.0,1.0,1.0,25.0,40.0,3713.8,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3713.8,10.0 +2019-01-28 15:00:00,25.0,1.0,1.0,25.0,40.0,3685.2,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3685.2,10.0 +2019-01-28 16:00:00,25.0,1.0,1.0,25.0,40.0,3663.9,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3663.9,10.0 +2019-01-28 17:00:00,25.0,1.0,1.0,25.0,40.0,3766.4,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3766.4,10.0 +2019-01-28 18:00:00,25.0,1.0,1.0,25.0,40.0,3759.45,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3759.45,10.0 +2019-01-28 19:00:00,25.0,1.0,1.0,25.0,40.0,3646.3,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3646.3,10.0 +2019-01-28 20:00:00,25.0,1.0,1.0,25.0,40.0,3455.25,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3455.25,10.0 +2019-01-28 21:00:00,25.0,1.0,1.0,25.0,40.0,3284.4,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3284.4,10.0 +2019-01-28 22:00:00,25.0,1.0,1.0,25.0,40.0,3157.0,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3157.0,10.0 +2019-01-28 23:00:00,25.0,1.0,1.0,25.0,40.0,2950.95,2.0,1.0,1.0,1.0,45.05,20.0,2950.95,10.0 +2019-01-29 00:00:00,25.0,1.0,1.0,25.0,40.0,2789.7,2.0,1.0,1.0,1.0,45.05,20.0,2789.7,10.0 +2019-01-29 01:00:00,25.0,1.0,1.0,25.0,40.0,2701.05,2.0,1.0,1.0,1.0,45.05,20.0,2701.05,10.0 +2019-01-29 02:00:00,25.0,1.0,1.0,25.0,40.0,2658.4,2.0,1.0,1.0,1.0,45.05,20.0,2658.4,10.0 +2019-01-29 03:00:00,25.0,1.0,1.0,25.0,40.0,2682.3,2.0,1.0,1.0,1.0,45.05,20.0,2682.3,10.0 +2019-01-29 04:00:00,25.0,1.0,1.0,25.0,40.0,2740.65,2.0,1.0,1.0,1.0,45.05,20.0,2740.65,10.0 +2019-01-29 05:00:00,25.0,1.0,1.0,25.0,40.0,2923.15,2.0,1.0,1.0,1.0,45.05,20.0,2923.15,10.0 +2019-01-29 06:00:00,25.0,1.0,1.0,25.0,40.0,3293.75,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3293.75,10.0 +2019-01-29 07:00:00,25.0,1.0,1.0,25.0,40.0,3580.75,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3580.75,10.0 +2019-01-29 08:00:00,25.0,1.0,1.0,25.0,40.0,3679.8,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3679.8,10.0 +2019-01-29 09:00:00,25.0,1.0,1.0,25.0,40.0,3693.65,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3693.65,10.0 +2019-01-29 10:00:00,25.0,1.0,1.0,25.0,40.0,3696.55,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3696.55,10.0 +2019-01-29 11:00:00,25.0,1.0,1.0,25.0,40.0,3723.6,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3723.6,10.0 +2019-01-29 12:00:00,25.0,1.0,1.0,25.0,40.0,3684.75,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3684.75,10.0 +2019-01-29 13:00:00,25.0,1.0,1.0,25.0,40.0,3655.75,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3655.75,10.0 +2019-01-29 
14:00:00,25.0,1.0,1.0,25.0,40.0,3597.4,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3597.4,10.0 +2019-01-29 15:00:00,25.0,1.0,1.0,25.0,40.0,3559.3,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3559.3,10.0 +2019-01-29 16:00:00,25.0,1.0,1.0,25.0,40.0,3508.7,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3508.7,10.0 +2019-01-29 17:00:00,25.0,1.0,1.0,25.0,40.0,3659.3,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3659.3,10.0 +2019-01-29 18:00:00,25.0,1.0,1.0,25.0,40.0,3720.45,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3720.45,10.0 +2019-01-29 19:00:00,25.0,1.0,1.0,25.0,40.0,3637.25,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3637.25,10.0 +2019-01-29 20:00:00,25.0,1.0,1.0,25.0,40.0,3468.85,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3468.85,10.0 +2019-01-29 21:00:00,25.0,1.0,1.0,25.0,40.0,3284.0,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3284.0,10.0 +2019-01-29 22:00:00,25.0,1.0,1.0,25.0,40.0,3158.5,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3158.5,10.0 +2019-01-29 23:00:00,25.0,1.0,1.0,25.0,40.0,2959.9,2.0,1.0,1.0,1.0,45.05,20.0,2959.9,10.0 +2019-01-30 00:00:00,25.0,1.0,1.0,25.0,40.0,2837.05,2.0,1.0,1.0,1.0,45.05,20.0,2837.05,10.0 +2019-01-30 01:00:00,25.0,1.0,1.0,25.0,40.0,2760.1,2.0,1.0,1.0,1.0,45.05,20.0,2760.1,10.0 +2019-01-30 02:00:00,25.0,1.0,1.0,25.0,40.0,2710.6,2.0,1.0,1.0,1.0,45.05,20.0,2710.6,10.0 +2019-01-30 03:00:00,25.0,1.0,1.0,25.0,40.0,2713.25,2.0,1.0,1.0,1.0,45.05,20.0,2713.25,10.0 +2019-01-30 04:00:00,25.0,1.0,1.0,25.0,40.0,2770.75,2.0,1.0,1.0,1.0,45.05,20.0,2770.75,10.0 +2019-01-30 05:00:00,25.0,1.0,1.0,25.0,40.0,2947.9,2.0,1.0,1.0,1.0,45.05,20.0,2947.9,10.0 +2019-01-30 06:00:00,25.0,1.0,1.0,25.0,40.0,3309.5,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3309.5,10.0 +2019-01-30 07:00:00,25.0,1.0,1.0,25.0,40.0,3575.2,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3575.2,10.0 +2019-01-30 08:00:00,25.0,1.0,1.0,25.0,40.0,3642.55,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3642.55,10.0 +2019-01-30 09:00:00,25.0,1.0,1.0,25.0,40.0,3642.1,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3642.1,10.0 +2019-01-30 10:00:00,25.0,1.0,1.0,25.0,40.0,3753.8,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3753.8,10.0 +2019-01-30 11:00:00,25.0,1.0,1.0,25.0,40.0,3796.9,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3796.9,10.0 +2019-01-30 12:00:00,25.0,1.0,1.0,25.0,40.0,3767.0,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3767.0,10.0 +2019-01-30 13:00:00,25.0,1.0,1.0,25.0,40.0,3757.8,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3757.8,10.0 +2019-01-30 14:00:00,25.0,1.0,1.0,25.0,40.0,3691.6,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3691.6,10.0 +2019-01-30 15:00:00,25.0,1.0,1.0,25.0,40.0,3638.6,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3638.6,10.0 +2019-01-30 16:00:00,25.0,1.0,1.0,25.0,40.0,3608.7,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3608.7,10.0 +2019-01-30 17:00:00,25.0,1.0,1.0,25.0,40.0,3733.15,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3733.15,10.0 +2019-01-30 18:00:00,25.0,1.0,1.0,25.0,40.0,3754.65,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3754.65,10.0 +2019-01-30 19:00:00,25.0,1.0,1.0,25.0,40.0,3675.1,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3675.1,10.0 +2019-01-30 20:00:00,25.0,1.0,1.0,25.0,40.0,3490.9,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3490.9,10.0 +2019-01-30 21:00:00,25.0,1.0,1.0,25.0,40.0,3313.05,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3313.05,10.0 +2019-01-30 22:00:00,25.0,1.0,1.0,25.0,40.0,3191.25,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3191.25,10.0 +2019-01-30 23:00:00,25.0,1.0,1.0,25.0,40.0,2999.8,2.0,1.0,1.0,1.0,45.05,20.0,2999.8,10.0 +2019-01-31 00:00:00,25.0,1.0,1.0,25.0,40.0,2855.3,2.0,1.0,1.0,1.0,45.05,20.0,2855.3,10.0 +2019-01-31 
01:00:00,25.0,1.0,1.0,25.0,40.0,2764.9,2.0,1.0,1.0,1.0,45.05,20.0,2764.9,10.0 +2019-01-31 02:00:00,25.0,1.0,1.0,25.0,40.0,2717.45,2.0,1.0,1.0,1.0,45.05,20.0,2717.45,10.0 +2019-01-31 03:00:00,25.0,1.0,1.0,25.0,40.0,2719.7,2.0,1.0,1.0,1.0,45.05,20.0,2719.7,10.0 +2019-01-31 04:00:00,25.0,1.0,1.0,25.0,40.0,2775.65,2.0,1.0,1.0,1.0,45.05,20.0,2775.65,10.0 +2019-01-31 05:00:00,25.0,1.0,1.0,25.0,40.0,2917.6,2.0,1.0,1.0,1.0,45.05,20.0,2917.6,10.0 +2019-01-31 06:00:00,25.0,1.0,1.0,25.0,40.0,3240.25,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3240.25,10.0 +2019-01-31 07:00:00,25.0,1.0,1.0,25.0,40.0,3533.85,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3533.85,10.0 +2019-01-31 08:00:00,25.0,1.0,1.0,25.0,40.0,3623.45,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3623.45,10.0 +2019-01-31 09:00:00,25.0,1.0,1.0,25.0,40.0,3630.25,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3630.25,10.0 +2019-01-31 10:00:00,25.0,1.0,1.0,25.0,40.0,3674.6,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3674.6,10.0 +2019-01-31 11:00:00,25.0,1.0,1.0,25.0,40.0,3687.2,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3687.2,10.0 +2019-01-31 12:00:00,25.0,1.0,1.0,25.0,40.0,3647.25,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3647.25,10.0 +2019-01-31 13:00:00,25.0,1.0,1.0,25.0,40.0,3609.55,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3609.55,10.0 +2019-01-31 14:00:00,25.0,1.0,1.0,25.0,40.0,3554.75,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3554.75,10.0 +2019-01-31 15:00:00,25.0,1.0,1.0,25.0,40.0,3513.25,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3513.25,10.0 +2019-01-31 16:00:00,25.0,1.0,1.0,25.0,40.0,3482.8,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3482.8,10.0 +2019-01-31 17:00:00,25.0,1.0,1.0,25.0,40.0,3643.25,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3643.25,10.0 +2019-01-31 18:00:00,25.0,1.0,1.0,25.0,40.0,3688.5,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3688.5,10.0 +2019-01-31 19:00:00,25.0,1.0,1.0,25.0,40.0,3630.45,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3630.45,10.0 +2019-01-31 20:00:00,25.0,1.0,1.0,25.0,40.0,3467.25,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3467.25,10.0 +2019-01-31 21:00:00,25.0,1.0,1.0,25.0,40.0,3308.05,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3308.05,10.0 +2019-01-31 22:00:00,25.0,1.0,1.0,25.0,40.0,3203.45,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3203.45,10.0 +2019-01-31 23:00:00,25.0,1.0,1.0,25.0,40.0,3026.95,2.0,1.0,1.0,1.0,53.50000000000001,20.0,3026.95,10.0 +2019-02-01 00:00:00,25.0,1.0,1.0,25.0,40.0,2892.2,2.0,1.0,1.0,1.0,45.05,20.0,2892.2,10.0 diff --git a/examples/inputs/example_01a/forecasts_df.csv.license b/examples/inputs/example_01a/forecasts_df.csv.license new file mode 100644 index 000000000..a6ae06366 --- /dev/null +++ b/examples/inputs/example_01a/forecasts_df.csv.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: ASSUME Developers + +SPDX-License-Identifier: AGPL-3.0-or-later diff --git a/examples/inputs/example_02a/config.yaml b/examples/inputs/example_02a/config.yaml index fb8f7488b..e5c6a9961 100644 --- a/examples/inputs/example_02a/config.yaml +++ b/examples/inputs/example_02a/config.yaml @@ -1,6 +1,45 @@ -# SPDX-FileCopyrightText: ASSUME Developers -# -# SPDX-License-Identifier: AGPL-3.0-or-later +tiny: + start_date: 2019-01-01 00:00 + end_date: 2019-01-05 00:00 + save_frequency_hours: null + time_step: 1h + learning_mode: true + learning_config: + algorithm: ppo + continue_learning: false + device: cpu + gradient_steps: 10 + learning_rate: 0.0003 + max_bid_price: 100 + trained_policies_save_path: null + training_episodes: 100 + validation_episodes_interval: 10 + matd3: + actor_architecture: mlp + batch_size: 64 + episodes_collecting_initial_experience: 3 + gamma: 
0.99
+ noise_dt: 1
+ noise_scale: 1
+ noise_sigma: 0.1
+ train_freq: 24h
+ markets_config:
+ EOM:
+ market_mechanism: pay_as_clear
+ maximum_bid_price: 3000
+ maximum_bid_volume: 100000
+ minimum_bid_price: -500
+ opening_duration: 1h
+ opening_frequency: 1h
+ operator: EOM_operator
+ price_unit: EUR/MWh
+ product_type: energy
+ products:
+ - count: 1
+ duration: 1h
+ first_delivery: 1h
+ volume_unit: MWh
+
base:
start_date: 2019-03-01 00:00
@@ -10,25 +49,24 @@ base:
save_frequency_hours: null
learning_config:
- continue_learning: false
- trained_policies_save_path: null
- max_bid_price: 100
algorithm: matd3
- learning_rate: 0.001
- learning_rate_schedule: linear
- training_episodes: 50
- episodes_collecting_initial_experience: 5
- train_freq: 24h
- gradient_steps: -1
- batch_size: 256
- gamma: 0.99
+ continue_learning: false
device: cpu
- action_noise_schedule: linear
- noise_sigma: 0.1
- noise_scale: 1
- noise_dt: 1
- validation_episodes_interval: 5
-
+ gradient_steps: -1
+ learning_rate: 0.001
+ max_bid_price: 100
+ trained_policies_save_path: null
+ training_episodes: 100
+ validation_episodes_interval: 10
+ learning_mode: true
+ matd3:
+ actor_architecture: mlp
+ batch_size: 64
+ episodes_collecting_initial_experience: 3
+ gamma: 0.99
+ noise_dt: 1
+ noise_scale: 1
+ noise_sigma: 0.1
markets_config:
EOM:
operator: EOM_operator
@@ -49,32 +87,44 @@ base:
base_lstm:
start_date: 2019-03-01 00:00
- end_date: 2019-03-31 00:00
time_step: 1h
- learning_mode: true
- save_frequency_hours: null
+
+base_ppo:
+ save_frequency_hours: null
+ start_date: 2019-03-01 00:00
+ time_step: 1h
+ end_date: 2019-03-31 00:00
+ learning_mode: true
learning_config:
+ algorithm: ppo
continue_learning: false
- trained_policies_save_path: null
- max_bid_price: 100
- algorithm: matd3
- learning_rate: 0.001
- training_episodes: 50
- episodes_collecting_initial_experience: 5
- train_freq: 24h
- gradient_steps: -1
- batch_size: 256
- gamma: 0.99
device: cpu
- noise_sigma: 0.1
- noise_scale: 1
- noise_dt: 1
- validation_episodes_interval: 5
- early_stopping_steps: 10
- early_stopping_threshold: 0.05
- actor_architecture: lstm
-
+ gradient_steps: 10
+ learning_rate: 0.0003
+ max_bid_price: 100
+ trained_policies_save_path: null
+ training_episodes: 100
+ validation_episodes_interval: 10
+ matd3:
+ actor_architecture: mlp
+ batch_size: 64
+ episodes_collecting_initial_experience: 3
+ gamma: 0.99
+ noise_dt: 1
+ noise_scale: 1
+ noise_sigma: 0.1
+ train_freq: 24h
+ ppo:
+ actor_architecture: dist
+ batch_size: 11
+ clip_ratio: 0.05
+ entropy_coef: 0.005
+ gae_lambda: 0.95
+ gamma: 0.99
+ max_grad_norm: 0.3
+ train_freq: 33h
+ vf_coef: 0.75
markets_config:
EOM:
operator: EOM_operator
@@ -93,32 +143,42 @@ base_lstm:
price_unit: EUR/MWh
market_mechanism: pay_as_clear
-tiny:
+
+tiny_ppo:
+ save_frequency_hours: null
start_date: 2019-01-01 00:00
- end_date: 2019-01-05 00:00
time_step: 1h
+ end_date: 2019-01-05 00:00
learning_mode: true
- save_frequency_hours: null
-
learning_config:
+ algorithm: ppo
continue_learning: false
+ device: cpu
+ gradient_steps: 10
+ learning_rate: 0.0003
trained_policies_save_path: null
+ training_episodes: 100
+ validation_episodes_interval: 10
max_bid_price: 100
- algorithm: matd3
- learning_rate: 0.001
- training_episodes: 10
- episodes_collecting_initial_experience: 3
- train_freq: 24h
- gradient_steps: -1
- batch_size: 64
- gamma: 0.99
- device: cpu
- noise_sigma: 0.1
- noise_scale: 1
- noise_dt: 1
- validation_episodes_interval: 5
- actor_architecture: mlp
-
+ matd3:
+ actor_architecture: mlp
+ batch_size: 64
+ episodes_collecting_initial_experience: 3
+ gamma: 0.99
+ noise_dt: 1
+ noise_scale: 1
+ noise_sigma: 0.1
+ train_freq: 24h
+ ppo:
+ actor_architecture: dist
+ batch_size: 11
+ clip_ratio: 0.05
+ entropy_coef: 0.005
+ gae_lambda: 0.95
+ gamma: 0.99
+ max_grad_norm: 0.3
+ train_freq: 33h
+ vf_coef: 0.75
markets_config:
EOM:
operator: EOM_operator
@@ -131,8 +191,4 @@ tiny:
opening_frequency: 1h
opening_duration: 1h
volume_unit: MWh
- maximum_bid_volume: 100000
- maximum_bid_price: 3000
- minimum_bid_price: -500
- price_unit: EUR/MWh
- market_mechanism: pay_as_clear
+
diff --git a/examples/inputs/example_02a/config_backup.yaml b/examples/inputs/example_02a/config_backup.yaml
new file mode 100644
index 000000000..8da7961f6
--- /dev/null
+++ b/examples/inputs/example_02a/config_backup.yaml
@@ -0,0 +1,137 @@
+# SPDX-FileCopyrightText: ASSUME Developers
+#
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+tiny:
+ start_date: 2019-01-01 00:00
+ end_date: 2019-01-05 00:00
+ time_step: 1h
+ save_frequency_hours: null
+ learning_mode: True
+
+ learning_config:
+ continue_learning: False
+ trained_policies_save_path: null
+ max_bid_price: 100
+ algorithm: matd3
+ actor_architecture: mlp
+ learning_rate: 0.001
+ training_episodes: 10
+ episodes_collecting_initial_experience: 3
+ train_freq: 24h
+ gradient_steps: -1
+ batch_size: 64
+ gamma: 0.99
+ device: cpu
+ noise_sigma: 0.1
+ noise_scale: 1
+ noise_dt: 1
+ validation_episodes_interval: 5
+
+ markets_config:
+ EOM:
+ operator: EOM_operator
+ product_type: energy
+ products:
+ - duration: 1h
+ count: 1
+ first_delivery: 1h
+ opening_frequency: 1h
+ opening_duration: 1h
+ volume_unit: MWh
+ maximum_bid_volume: 100000
+ maximum_bid_price: 3000
+ minimum_bid_price: -500
+ price_unit: EUR/MWh
+ market_mechanism: pay_as_clear
+
+
+base:
+ start_date: 2019-03-01 00:00
+ end_date: 2019-03-31 00:00
+ time_step: 1h
+ save_frequency_hours: null
+ learning_mode: True
+
+ learning_config:
+ continue_learning: False
+ trained_policies_save_path: null
+ max_bid_price: 100
+ algorithm: matd3
+ actor_architecture: mlp
+ learning_rate: 0.001
+ training_episodes: 50
+ episodes_collecting_initial_experience: 5
+ train_freq: 24h
+ gradient_steps: -1
+ batch_size: 256
+ gamma: 0.99
+ device: cpu
+ noise_sigma: 0.1
+ noise_scale: 1
+ noise_dt: 1
+ validation_episodes_interval: 5
+ early_stopping_steps: 10
+ early_stopping_threshold: 0.05
+
+ markets_config:
+ EOM:
+ operator: EOM_operator
+ product_type: energy
+ products:
+ - duration: 1h
+ count: 1
+ first_delivery: 1h
+ opening_frequency: 1h
+ opening_duration: 1h
+ volume_unit: MWh
+ maximum_bid_volume: 100000
+ maximum_bid_price: 3000
+ minimum_bid_price: -500
+ price_unit: EUR/MWh
+ market_mechanism: pay_as_clear
+
+base_lstm:
+ start_date: 2019-03-01 00:00
+ end_date: 2019-03-31 00:00
+ time_step: 1h
+ save_frequency_hours: null
+ learning_mode: True
+
+ learning_config:
+ continue_learning: False
+ trained_policies_save_path: null
+ max_bid_price: 100
+ algorithm: matd3
+ actor_architecture: lstm
+ learning_rate: 0.001
+ training_episodes: 50
+ episodes_collecting_initial_experience: 5
+ train_freq: 24h
+ gradient_steps: -1
+ batch_size: 256
+ gamma: 0.99
+ device: cpu
+ noise_sigma: 0.1
+ noise_scale: 1
+ noise_dt: 1
+ validation_episodes_interval: 5
+ early_stopping_steps: 10
+ early_stopping_threshold: 0.05
+
+ markets_config:
+ EOM:
+ operator: EOM_operator
+ product_type: energy
+ products:
+ - duration: 1h
+ count: 1
+ first_delivery: 1h
+ opening_frequency: 1h
+ opening_duration: 1h
+
volume_unit: MWh + maximum_bid_volume: 100000 + maximum_bid_price: 3000 + minimum_bid_price: -500 + price_unit: EUR/MWh + market_mechanism: pay_as_clear diff --git a/examples/inputs/example_02a/forecasts_df.csv b/examples/inputs/example_02a/forecasts_df.csv new file mode 100644 index 000000000..24ef716d7 --- /dev/null +++ b/examples/inputs/example_02a/forecasts_df.csv @@ -0,0 +1,122 @@ +,fuel_price_natural gas,fuel_price_co2,availability_pp_4,availability_pp_1,availability_pp_5,fuel_price_oil,residual_load_EOM,price_EOM,availability_pp_7,availability_pp_2,availability_pp_6,fuel_price_hard coal,fuel_price_biomass,fuel_price_uranium,demand_EOM,fuel_price_lignite,availability_pp_3 +2019-01-01 00:00:00,26.0,25.0,1.0,1.0,1.0,22.0,4352.7,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,4352.7,1.8,1.0 +2019-01-01 01:00:00,26.0,25.0,1.0,1.0,1.0,22.0,4180.2,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,4180.2,1.8,1.0 +2019-01-01 02:00:00,26.0,25.0,1.0,1.0,1.0,22.0,4011.3,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,4011.3,1.8,1.0 +2019-01-01 03:00:00,26.0,25.0,1.0,1.0,1.0,22.0,3949.0,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,3949.0,1.8,1.0 +2019-01-01 04:00:00,26.0,25.0,1.0,1.0,1.0,22.0,3927.3,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,3927.3,1.8,1.0 +2019-01-01 05:00:00,26.0,25.0,1.0,1.0,1.0,22.0,3881.8,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,3881.8,1.8,1.0 +2019-01-01 06:00:00,26.0,25.0,1.0,1.0,1.0,22.0,3816.8,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,3816.8,1.8,1.0 +2019-01-01 07:00:00,26.0,25.0,1.0,1.0,1.0,22.0,3889.7,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,3889.7,1.8,1.0 +2019-01-01 08:00:00,26.0,25.0,1.0,1.0,1.0,22.0,3967.4,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,3967.4,1.8,1.0 +2019-01-01 09:00:00,26.0,25.0,1.0,1.0,1.0,22.0,4221.5,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,4221.5,1.8,1.0 +2019-01-01 10:00:00,26.0,25.0,1.0,1.0,1.0,22.0,4491.2,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,4491.2,1.8,1.0 +2019-01-01 11:00:00,26.0,25.0,1.0,1.0,1.0,22.0,4787.2,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,4787.2,1.8,1.0 +2019-01-01 12:00:00,26.0,25.0,1.0,1.0,1.0,22.0,4916.5,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,4916.5,1.8,1.0 +2019-01-01 13:00:00,26.0,25.0,1.0,1.0,1.0,22.0,4888.5,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,4888.5,1.8,1.0 +2019-01-01 14:00:00,26.0,25.0,1.0,1.0,1.0,22.0,4865.0,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,4865.0,1.8,1.0 +2019-01-01 15:00:00,26.0,25.0,1.0,1.0,1.0,22.0,4896.9,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,4896.9,1.8,1.0 +2019-01-01 16:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5081.1,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5081.1,1.8,1.0 +2019-01-01 17:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5396.799999999999,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5396.799999999999,1.8,1.0 +2019-01-01 18:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5468.9,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5468.9,1.8,1.0 +2019-01-01 19:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5377.8,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5377.8,1.8,1.0 +2019-01-01 20:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5154.7,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5154.7,1.8,1.0 +2019-01-01 21:00:00,26.0,25.0,1.0,1.0,1.0,22.0,4988.9,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,4988.9,1.8,1.0 +2019-01-01 22:00:00,26.0,25.0,1.0,1.0,1.0,22.0,4884.1,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,4884.1,1.8,1.0 +2019-01-01 23:00:00,26.0,25.0,1.0,1.0,1.0,22.0,4611.0,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,4611.0,1.8,1.0 +2019-01-02 00:00:00,26.0,25.0,1.0,1.0,1.0,22.0,4406.4,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,4406.4,1.8,1.0 +2019-01-02 01:00:00,26.0,25.0,1.0,1.0,1.0,22.0,4238.200000000001,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,4238.200000000001,1.8,1.0 +2019-01-02 
02:00:00,26.0,25.0,1.0,1.0,1.0,22.0,4187.0,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,4187.0,1.8,1.0 +2019-01-02 03:00:00,26.0,25.0,1.0,1.0,1.0,22.0,4237.8,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,4237.8,1.8,1.0 +2019-01-02 04:00:00,26.0,25.0,1.0,1.0,1.0,22.0,4408.8,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,4408.8,1.8,1.0 +2019-01-02 05:00:00,26.0,25.0,1.0,1.0,1.0,22.0,4709.7,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,4709.7,1.8,1.0 +2019-01-02 06:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5273.2,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5273.2,1.8,1.0 +2019-01-02 07:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5812.7,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5812.7,1.8,1.0 +2019-01-02 08:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6132.2,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6132.2,1.8,1.0 +2019-01-02 09:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6299.5,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6299.5,1.8,1.0 +2019-01-02 10:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6415.1,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6415.1,1.8,1.0 +2019-01-02 11:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6544.9,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6544.9,1.8,1.0 +2019-01-02 12:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6573.9,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6573.9,1.8,1.0 +2019-01-02 13:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6504.3,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6504.3,1.8,1.0 +2019-01-02 14:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6382.2,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6382.2,1.8,1.0 +2019-01-02 15:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6370.4,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6370.4,1.8,1.0 +2019-01-02 16:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6489.2,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6489.2,1.8,1.0 +2019-01-02 17:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6808.3,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6808.3,1.8,1.0 +2019-01-02 18:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6782.9,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6782.9,1.8,1.0 +2019-01-02 19:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6604.8,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6604.8,1.8,1.0 +2019-01-02 20:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6233.6,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6233.6,1.8,1.0 +2019-01-02 21:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5921.2,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5921.2,1.8,1.0 +2019-01-02 22:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5728.8,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5728.8,1.8,1.0 +2019-01-02 23:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5362.3,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5362.3,1.8,1.0 +2019-01-03 00:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5068.0,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5068.0,1.8,1.0 +2019-01-03 01:00:00,26.0,25.0,1.0,1.0,1.0,22.0,4972.4,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,4972.4,1.8,1.0 +2019-01-03 02:00:00,26.0,25.0,1.0,1.0,1.0,22.0,4943.7,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,4943.7,1.8,1.0 +2019-01-03 03:00:00,26.0,25.0,1.0,1.0,1.0,22.0,4932.7,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,4932.7,1.8,1.0 +2019-01-03 04:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5011.2,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5011.2,1.8,1.0 +2019-01-03 05:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5273.1,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5273.1,1.8,1.0 +2019-01-03 06:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5752.0,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5752.0,1.8,1.0 +2019-01-03 07:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6190.0,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6190.0,1.8,1.0 +2019-01-03 08:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6442.3,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6442.3,1.8,1.0 +2019-01-03 
09:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6578.9,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6578.9,1.8,1.0 +2019-01-03 10:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6644.6,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6644.6,1.8,1.0 +2019-01-03 11:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6742.900000000001,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6742.900000000001,1.8,1.0 +2019-01-03 12:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6764.3,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6764.3,1.8,1.0 +2019-01-03 13:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6643.6,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6643.6,1.8,1.0 +2019-01-03 14:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6557.3,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6557.3,1.8,1.0 +2019-01-03 15:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6492.1,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6492.1,1.8,1.0 +2019-01-03 16:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6617.0,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6617.0,1.8,1.0 +2019-01-03 17:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6923.8,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6923.8,1.8,1.0 +2019-01-03 18:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6893.5,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6893.5,1.8,1.0 +2019-01-03 19:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6697.5,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6697.5,1.8,1.0 +2019-01-03 20:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6326.8,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6326.8,1.8,1.0 +2019-01-03 21:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6031.2,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6031.2,1.8,1.0 +2019-01-03 22:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5809.7,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5809.7,1.8,1.0 +2019-01-03 23:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5481.8,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5481.8,1.8,1.0 +2019-01-04 00:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5165.6,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5165.6,1.8,1.0 +2019-01-04 01:00:00,26.0,25.0,1.0,1.0,1.0,22.0,4975.3,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,4975.3,1.8,1.0 +2019-01-04 02:00:00,26.0,25.0,1.0,1.0,1.0,22.0,4919.3,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,4919.3,1.8,1.0 +2019-01-04 03:00:00,26.0,25.0,1.0,1.0,1.0,22.0,4965.0,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,4965.0,1.8,1.0 +2019-01-04 04:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5069.7,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5069.7,1.8,1.0 +2019-01-04 05:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5311.5,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5311.5,1.8,1.0 +2019-01-04 06:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5803.299999999999,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5803.299999999999,1.8,1.0 +2019-01-04 07:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6339.1,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6339.1,1.8,1.0 +2019-01-04 08:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6653.5,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6653.5,1.8,1.0 +2019-01-04 09:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6818.5,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6818.5,1.8,1.0 +2019-01-04 10:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6939.7,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6939.7,1.8,1.0 +2019-01-04 11:00:00,26.0,25.0,1.0,1.0,1.0,22.0,7007.6,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,7007.6,1.8,1.0 +2019-01-04 12:00:00,26.0,25.0,1.0,1.0,1.0,22.0,7040.2,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,7040.2,1.8,1.0 +2019-01-04 13:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6935.3,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6935.3,1.8,1.0 +2019-01-04 14:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6755.7,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6755.7,1.8,1.0 +2019-01-04 15:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6618.8,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6618.8,1.8,1.0 +2019-01-04 
16:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6752.8,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6752.8,1.8,1.0 +2019-01-04 17:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6955.6,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6955.6,1.8,1.0 +2019-01-04 18:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6888.9,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6888.9,1.8,1.0 +2019-01-04 19:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6687.1,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6687.1,1.8,1.0 +2019-01-04 20:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6292.6,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6292.6,1.8,1.0 +2019-01-04 21:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5997.8,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5997.8,1.8,1.0 +2019-01-04 22:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5814.3,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5814.3,1.8,1.0 +2019-01-04 23:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5421.9,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5421.9,1.8,1.0 +2019-01-05 00:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5116.6,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5116.6,1.8,1.0 +2019-01-05 01:00:00,26.0,25.0,1.0,1.0,1.0,22.0,4899.1,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,4899.1,1.8,1.0 +2019-01-05 02:00:00,26.0,25.0,1.0,1.0,1.0,22.0,4777.1,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,4777.1,1.8,1.0 +2019-01-05 03:00:00,26.0,25.0,1.0,1.0,1.0,22.0,4728.8,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,4728.8,1.8,1.0 +2019-01-05 04:00:00,26.0,25.0,1.0,1.0,1.0,22.0,4742.1,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,4742.1,1.8,1.0 +2019-01-05 05:00:00,26.0,25.0,1.0,1.0,1.0,22.0,4749.4,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,4749.4,1.8,1.0 +2019-01-05 06:00:00,26.0,25.0,1.0,1.0,1.0,22.0,4845.3,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,4845.3,1.8,1.0 +2019-01-05 07:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5087.599999999999,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5087.599999999999,1.8,1.0 +2019-01-05 08:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5428.1,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5428.1,1.8,1.0 +2019-01-05 09:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5772.8,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5772.8,1.8,1.0 +2019-01-05 10:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5991.7,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5991.7,1.8,1.0 +2019-01-05 11:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6112.1,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6112.1,1.8,1.0 +2019-01-05 12:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6101.0,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6101.0,1.8,1.0 +2019-01-05 13:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5979.6,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5979.6,1.8,1.0 +2019-01-05 14:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5861.8,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5861.8,1.8,1.0 +2019-01-05 15:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5816.8,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5816.8,1.8,1.0 +2019-01-05 16:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5921.299999999999,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5921.299999999999,1.8,1.0 +2019-01-05 17:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6206.0,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6206.0,1.8,1.0 +2019-01-05 18:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6212.5,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6212.5,1.8,1.0 +2019-01-05 19:00:00,26.0,25.0,1.0,1.0,1.0,22.0,6030.3,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,6030.3,1.8,1.0 +2019-01-05 20:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5661.5,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5661.5,1.8,1.0 +2019-01-05 21:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5425.0,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5425.0,1.8,1.0 +2019-01-05 22:00:00,26.0,25.0,1.0,1.0,1.0,22.0,5319.6,55.708333333333336,1.0,1.0,1.0,8.5,21.0,0.9,5319.6,1.8,1.0 +2019-01-05 
+2019-01-06 00:00:00,26.0,25.0,1.0,1.0,1.0,22.0,4692.1,36.15625,1.0,1.0,1.0,8.5,21.0,0.9,4692.1,1.8,1.0
diff --git a/examples/inputs/example_02a/forecasts_df.csv.license b/examples/inputs/example_02a/forecasts_df.csv.license
new file mode 100644
index 000000000..a6ae06366
--- /dev/null
+++ b/examples/inputs/example_02a/forecasts_df.csv.license
@@ -0,0 +1,3 @@
+SPDX-FileCopyrightText: ASSUME Developers
+
+SPDX-License-Identifier: AGPL-3.0-or-later
diff --git a/examples/inputs/example_02b/config.yaml b/examples/inputs/example_02b/config.yaml
index d50d8a170..e40f88867 100644
--- a/examples/inputs/example_02b/config.yaml
+++ b/examples/inputs/example_02b/config.yaml
@@ -14,21 +14,20 @@ base:
     trained_policies_save_path: null
     max_bid_price: 100
     algorithm: matd3
-    actor_architecture: mlp
-    learning_rate: 0.001
-    training_episodes: 100
-    episodes_collecting_initial_experience: 3
-    train_freq: 24h
-    gradient_steps: 1
-    batch_size: 256
-    gamma: 0.99
     device: cpu
-    noise_sigma: 0.1
-    noise_scale: 1
-    noise_dt: 1
-    validation_episodes_interval: 5
-    early_stopping_steps: 10
-    early_stopping_threshold: 0.05
+    learning_rate: 0.0003
+    validation_episodes_interval: 10 # after how many episodes the validation starts and the policy is updated
+    training_episodes: 100
+    gradient_steps: 10
+    matd3:
+      actor_architecture: mlp
+      train_freq: 24h # how often write_to_learning_role gets called
+      episodes_collecting_initial_experience: 3
+      batch_size: 64
+      gamma: 0.99
+      noise_sigma: 0.1
+      noise_scale: 1
+      noise_dt: 1
 
   markets_config:
     EOM:
@@ -59,21 +58,20 @@ base_lstm:
     trained_policies_save_path: null
     max_bid_price: 100
     algorithm: matd3
-    actor_architecture: lstm
+    device: cpu
     learning_rate: 0.001
+    validation_episodes_interval: 10 # after how many episodes the validation starts and the policy is updated
     training_episodes: 100
-    episodes_collecting_initial_experience: 3
-    train_freq: 24h
-    gradient_steps: 1
-    batch_size: 256
-    gamma: 0.99
-    device: cpu
-    noise_sigma: 0.1
-    noise_scale: 1
-    noise_dt: 1
-    validation_episodes_interval: 5
-    early_stopping_steps: 10
-    early_stopping_threshold: 0.05
+    gradient_steps: -1
+    matd3:
+      actor_architecture: lstm
+      train_freq: 24h # how often write_to_learning_role gets called
+      episodes_collecting_initial_experience: 3
+      batch_size: 64
+      gamma: 0.99
+      noise_sigma: 0.1
+      noise_scale: 1
+      noise_dt: 1
 
   markets_config:
     EOM:
diff --git a/examples/inputs/example_02c/config.yaml b/examples/inputs/example_02c/config.yaml
index 0cd3f4091..694e76248 100644
--- a/examples/inputs/example_02c/config.yaml
+++ b/examples/inputs/example_02c/config.yaml
@@ -14,21 +14,20 @@ base:
     trained_policies_save_path: null
     max_bid_price: 100
     algorithm: matd3
-    actor_architecture: mlp
+    device: cpu
     learning_rate: 0.001
+    validation_episodes_interval: 10 # after how many episodes the validation starts and the policy is updated
     training_episodes: 100
-    episodes_collecting_initial_experience: 3
-    train_freq: 24h
-    gradient_steps: 1
-    batch_size: 256
-    gamma: 0.99
-    device: cpu
-    noise_sigma: 0.1
-    noise_scale: 1
-    noise_dt: 1
-    validation_episodes_interval: 5
-    early_stopping_steps: 10
-    early_stopping_threshold: 0.05
+    gradient_steps: -1
+    matd3:
+      actor_architecture: mlp
+      train_freq: 24h # how often write_to_learning_role gets called
+      episodes_collecting_initial_experience: 3
+      batch_size: 64
+      gamma: 0.99
+      noise_sigma: 0.1
+      noise_scale: 1
+      noise_dt: 1
 
   markets_config:
     EOM:
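Across these examples the learning_config is restructured: shared settings (learning_rate, validation_episodes_interval, training_episodes, gradient_steps, device) stay at the top level, while algorithm-specific settings move into a nested block named after the algorithm (here matd3). The following is a minimal sketch of reading one of the files above and flattening that sub-dict back into a single dict; the scenario key "base" and the merge step are illustrative assumptions, not ASSUME's actual config loader.

# Illustrative only: load a restructured config and overlay the
# algorithm-specific sub-dict on the shared learning_config keys.
import yaml

with open("examples/inputs/example_02b/config.yaml") as f:
    config = yaml.safe_load(f)

learning_config = config["base"]["learning_config"]
algorithm = learning_config.get("algorithm", "matd3")

# keep every shared key, then merge in the nested algorithm block (e.g. "matd3")
flat = {k: v for k, v in learning_config.items() if k != algorithm}
flat.update(learning_config.get(algorithm, {}))

print(algorithm, flat["learning_rate"], flat["batch_size"], flat["train_freq"])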
diff --git a/examples/inputs/example_02d/config.yaml b/examples/inputs/example_02d/config.yaml
index e52c39c78..b79e8dbe6 100644
--- a/examples/inputs/example_02d/config.yaml
+++ b/examples/inputs/example_02d/config.yaml
@@ -12,24 +12,22 @@ dam:
   learning_config:
     continue_learning: False
     trained_policies_save_path: null
-    max_bid_price: 200
+    max_bid_price: 100
     algorithm: matd3
-    actor_architecture: mlp
-    learning_rate: 0.0001
-    training_episodes: 30
-    episodes_collecting_initial_experience: 5
-    train_freq: 24h
+    device: cpu
+    learning_rate: 0.001
+    validation_episodes_interval: 10 # after how many episodes the validation starts and the policy is updated
+    training_episodes: 100
     gradient_steps: -1
-    batch_size: 128
-    gamma: 0.99
-    device: cuda:0
-    noise_sigma: 0.1
-    noise_scale: 1
-    noise_dt: 1
-    validation_episodes_interval: 5
-    order_types: ["SB", "BB", "LB"]
-    early_stopping_steps: 10
-    early_stopping_threshold: 0.05
+    matd3:
+      actor_architecture: mlp
+      train_freq: 24h # how often write_to_learning_role gets called
+      episodes_collecting_initial_experience: 3
+      batch_size: 64
+      gamma: 0.99
+      noise_sigma: 0.1
+      noise_scale: 1
+      noise_dt: 1
 
   markets_config:
     EOM:
@@ -62,23 +60,24 @@ tiny:
   learning_mode: True
   learning_config:
-    continue_learning: True
+    continue_learning: False
     trained_policies_save_path: null
     max_bid_price: 100
     algorithm: matd3
-    actor_architecture: mlp
+    device: cpu
     learning_rate: 0.001
-    training_episodes: 3
-    episodes_collecting_initial_experience: 1
-    train_freq: 24h
+    validation_episodes_interval: 10 # after how many episodes the validation starts and the policy is updated
+    training_episodes: 100
     gradient_steps: -1
-    batch_size: 128
-    gamma: 0.99
-    device: cpu
-    noise_sigma: 0.1
-    noise_scale: 1
-    noise_dt: 1
-    validation_episodes_interval: 1
+    matd3:
+      actor_architecture: mlp
+      train_freq: 24h # how often write_to_learning_role gets called
+      episodes_collecting_initial_experience: 3
+      batch_size: 64
+      gamma: 0.99
+      noise_sigma: 0.1
+      noise_scale: 1
+      noise_dt: 1
 
   markets_config:
     EOM:
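The dam scenario above previously requested device: cuda:0, while the restructured configs pin every example to device: cpu. Where a scenario should use the configured device only when it is actually present, a guard along the following lines is a common pattern; this is an illustrative sketch, not part of this change set.

# Illustrative fallback: honour the configured device, drop to CPU when CUDA is absent.
import torch as th

requested = "cuda:0"  # e.g. learning_config.get("device", "cpu")
device = th.device(requested if th.cuda.is_available() else "cpu")
print(device)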
diff --git a/examples/inputs/example_02e/config.yaml b/examples/inputs/example_02e/config.yaml
index 5011950cf..7a08fab36 100644
--- a/examples/inputs/example_02e/config.yaml
+++ b/examples/inputs/example_02e/config.yaml
@@ -12,21 +12,22 @@ tiny:
   learning_config:
     continue_learning: False
     trained_policies_save_path: null
-    max_bid_price: 50
+    max_bid_price: 100
     algorithm: matd3
-    actor_architecture: mlp
+    device: cpu
     learning_rate: 0.001
-    training_episodes: 5
-    validation_episodes_interval: 2
-    episodes_collecting_initial_experience: 1
-    train_freq: 24h
+    validation_episodes_interval: 10 # after how many episodes the validation starts and the policy is updated
+    training_episodes: 100
     gradient_steps: -1
-    batch_size: 64
-    gamma: 0.99
-    device: cpu
-    noise_sigma: 0.1
-    noise_scale: 1
-    noise_dt: 1
+    matd3:
+      actor_architecture: mlp
+      train_freq: 24h # how often write_to_learning_role gets called
+      episodes_collecting_initial_experience: 3
+      batch_size: 64
+      gamma: 0.99
+      noise_sigma: 0.1
+      noise_scale: 1
+      noise_dt: 1
 
   markets_config:
     EOM:
@@ -58,21 +59,20 @@ base:
     trained_policies_save_path: null
     max_bid_price: 100
     algorithm: matd3
-    actor_architecture: mlp
-    learning_rate: 0.0001
-    training_episodes: 200
-    episodes_collecting_initial_experience: 5
-    train_freq: 1000h
-    gradient_steps: -1
-    batch_size: 256
-    gamma: 0.999
     device: cpu
-    noise_sigma: 0.1
-    noise_scale: 1
-    noise_dt: 1
-    validation_episodes_interval: 5
-    early_stopping_steps: 10
-    early_stopping_threshold: 0.05
+    learning_rate: 0.001
+    validation_episodes_interval: 10 # after how many episodes the validation starts and the policy is updated
+    training_episodes: 100
+    gradient_steps: -1
+    matd3:
+      actor_architecture: mlp
+      train_freq: 24h # how often write_to_learning_role gets called
+      episodes_collecting_initial_experience: 3
+      batch_size: 64
+      gamma: 0.99
+      noise_sigma: 0.1
+      noise_scale: 1
+      noise_dt: 1
 
   markets_config:
     EOM:
diff --git a/examples/inputs/example_03a/config.yaml b/examples/inputs/example_03a/config.yaml
index 9acc8f80a..89b146a18 100644
--- a/examples/inputs/example_03a/config.yaml
+++ b/examples/inputs/example_03a/config.yaml
@@ -14,20 +14,20 @@ base_case_2019:
     trained_policies_save_path: null
     max_bid_price: 100
     algorithm: matd3
-    learning_rate: 0.001
-    training_episodes: 50
-    episodes_collecting_initial_experience: 5
-    train_freq: 24h
-    gradient_steps: 1
-    batch_size: 256
-    gamma: 0.99
     device: cpu
-    noise_sigma: 0.1
-    noise_scale: 1
-    noise_dt: 1
-    validation_episodes_interval: 5
-    early_stopping_steps: 10
-    early_stopping_threshold: 0.05
+    learning_rate: 0.001
+    validation_episodes_interval: 10 # after how many episodes the validation starts and the policy is updated
+    training_episodes: 100
+    gradient_steps: -1
+    matd3:
+      actor_architecture: mlp
+      train_freq: 24h # how often write_to_learning_role gets called
+      episodes_collecting_initial_experience: 3
+      batch_size: 64
+      gamma: 0.99
+      noise_sigma: 0.1
+      noise_scale: 1
+      noise_dt: 1
 
   markets_config:
     EOM:
diff --git a/examples/inputs/example_03b/config.yaml b/examples/inputs/example_03b/config.yaml
index 176397bd3..15f60ddb3 100644
--- a/examples/inputs/example_03b/config.yaml
+++ b/examples/inputs/example_03b/config.yaml
@@ -14,20 +14,20 @@ base_case_2021:
     trained_policies_save_path: null
     max_bid_price: 100
     algorithm: matd3
-    learning_rate: 0.001
-    training_episodes: 50
-    episodes_collecting_initial_experience: 5
-    train_freq: 24h
-    gradient_steps: 1
-    batch_size: 256
-    gamma: 0.99
     device: cpu
-    noise_sigma: 0.1
-    noise_scale: 1
-    noise_dt: 1
-    validation_episodes_interval: 5
-    early_stopping_steps: 10
-    early_stopping_threshold: 0.05
+    learning_rate: 0.001
+    validation_episodes_interval: 10 # after how many episodes the validation starts and the policy is updated
+    training_episodes: 100
+    gradient_steps: -1
+    matd3:
+      actor_architecture: mlp
+      train_freq: 24h # how often write_to_learning_role gets called
+      episodes_collecting_initial_experience: 3
+      batch_size: 64
+      gamma: 0.99
+      noise_sigma: 0.1
+      noise_scale: 1
+      noise_dt: 1
 
   markets_config:
     EOM:
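Every nested matd3 block keeps train_freq: 24h, commented as the interval at which write_to_learning_role gets called. Against the hourly series in forecasts_df.csv above, that string can be sanity-checked as follows; the parsing shown here is an illustrative assumption, not ASSUME's implementation.

# Illustrative: convert the train_freq string into a number of hourly simulation steps.
import pandas as pd

train_freq = "24h"
steps_per_update = int(pd.Timedelta(train_freq) / pd.Timedelta("1h"))
print(steps_per_update)  # 24 -> learning data is written once per simulated day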
\"episodes_collecting_initial_experience\": 3,\n", + " \"batch_size\": 64,\n", + " \"gamma\": 0.99,\n", + " \"noise_sigma\": 0.1,\n", + " \"noise_scale\": 1,\n", + " \"noise_dt\": 1,\n", + " }\n", "}" ] }, @@ -880,6 +882,16 @@ " yaml.safe_dump(data, file)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "fcf605c7", + "metadata": {}, + "outputs": [], + "source": [ + "data" + ] + }, { "cell_type": "markdown", "id": "4bea575f",