From 5062268e4daab69bce4025a92c3cad460b2dfdf4 Mon Sep 17 00:00:00 2001
From: Sampreet
Date: Sat, 22 Aug 2020 20:02:50 +0530
Subject: [PATCH 1/4] minor error fixes

---
 genrl/deep/agents/a2c/a2c.py | 12 +++++-------
 genrl/deep/agents/base.py    |  2 +-
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/genrl/deep/agents/a2c/a2c.py b/genrl/deep/agents/a2c/a2c.py
index 2b8f4793..1fb0025a 100644
--- a/genrl/deep/agents/a2c/a2c.py
+++ b/genrl/deep/agents/a2c/a2c.py
@@ -101,10 +101,10 @@ def _create_model(self) -> None:
         """
         Creates actor critic model and initialises optimizers
         """
+        input_dim, action_dim, discrete, action_lim = get_env_properties(
+            self.env, self.network
+        )
         if isinstance(self.network, str):
-            input_dim, action_dim, discrete, action_lim = get_env_properties(
-                self.env, self.network
-            )
             self.ac = get_model("ac", self.network)(
                 input_dim,
                 action_dim,
@@ -114,10 +114,8 @@
                 discrete,
                 action_lim=action_lim,
             ).to(self.device)
-
         else:
             self.ac = self.network.to(self.device)
-            action_dim = self.network.action_dim
 
         if self.noise is not None:
             self.noise = self.noise(
@@ -205,8 +203,8 @@ def get_hyperparams(self) -> Dict[str, Any]:
             "network": self.network,
             "batch_size": self.batch_size,
             "gamma": self.gamma,
-            "lr_actor": self.lr_actor,
-            "lr_critic": self.lr_critic,
+            "lr_policy": self.lr_policy,
+            "lr_value": self.lr_value,
             "rollout_size": self.rollout_size,
             "policy_weights": self.ac.actor.state_dict(),
             "value_weights": self.ac.critic.state_dict(),
diff --git a/genrl/deep/agents/base.py b/genrl/deep/agents/base.py
index 39252353..0a28c05e 100644
--- a/genrl/deep/agents/base.py
+++ b/genrl/deep/agents/base.py
@@ -102,7 +102,7 @@ def update_params(self, update_interval: int) -> None:
         """
         raise NotImplementedError
 
-    def get_hyperparameters(self) -> Dict[str, Any]:
+    def get_hyperparams(self) -> Dict[str, Any]:
         """Get relevant hyperparameters to save
 
         Returns:

From 33a81df41bb0507a1fc3e7b6d3f8b45c77bff856 Mon Sep 17 00:00:00 2001
From: sampreet-arthi
Date: Wed, 14 Oct 2020 07:34:45 +0000
Subject: [PATCH 2/4] Changed episodes comparison from == to >=

---
 genrl/agents/__init__.py                  | 3 +--
 genrl/trainers/offpolicy.py               | 2 +-
 tests/test_agents/test_bandit/__init__.py | 4 ++--
 3 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/genrl/agents/__init__.py b/genrl/agents/__init__.py
index 3257caff..f1c7977b 100644
--- a/genrl/agents/__init__.py
+++ b/genrl/agents/__init__.py
@@ -15,6 +15,7 @@
     NeuralNoiseSamplingAgent,
 )
 from genrl.agents.bandits.contextual.variational import VariationalAgent  # noqa
+from genrl.agents.bandits.multiarmed.base import MABAgent  # noqa
 from genrl.agents.bandits.multiarmed.bayesian import BayesianUCBMABAgent  # noqa
 from genrl.agents.bandits.multiarmed.bernoulli_mab import BernoulliMAB  # noqa
 from genrl.agents.bandits.multiarmed.epsgreedy import EpsGreedyMABAgent  # noqa
@@ -41,5 +42,3 @@
 from genrl.agents.deep.sac.sac import SAC  # noqa
 from genrl.agents.deep.td3.td3 import TD3  # noqa
 from genrl.agents.deep.vpg.vpg import VPG  # noqa
-
-from genrl.agents.bandits.multiarmed.base import MABAgent  # noqa; noqa; noqa
diff --git a/genrl/trainers/offpolicy.py b/genrl/trainers/offpolicy.py
index 7e0571c2..d6484485 100644
--- a/genrl/trainers/offpolicy.py
+++ b/genrl/trainers/offpolicy.py
@@ -166,7 +166,7 @@ def train(self) -> None:
                 if self.episodes % self.log_interval == 0:
                     self.log(timestep)
 
-                if self.episodes == self.epochs:
+                if self.episodes >= self.epochs:
                     break
 
             if timestep >= self.start_update and timestep % self.update_interval == 0:
diff --git a/tests/test_agents/test_bandit/__init__.py b/tests/test_agents/test_bandit/__init__.py
index 4411dff3..8faedc3d 100644
--- a/tests/test_agents/test_bandit/__init__.py
+++ b/tests/test_agents/test_bandit/__init__.py
@@ -1,6 +1,6 @@
 from tests.test_agents.test_bandit.test_cb_agents import TestCBAgent  # noqa
 from tests.test_agents.test_bandit.test_data_bandits import TestDataBandit  # noqa
 from tests.test_agents.test_bandit.test_mab_agents import TestMABAgent  # noqa
-from tests.test_agents.test_bandit.test_multi_armed_bandits import (
-    TestMultiArmedBandit,  # noqa
+from tests.test_agents.test_bandit.test_multi_armed_bandits import (  # noqa
+    TestMultiArmedBandit,
 )

From f04d573c47db9c27ccee8b41f6bba5054e438ce3 Mon Sep 17 00:00:00 2001
From: Sampreet
Date: Sat, 17 Oct 2020 00:28:42 +0530
Subject: [PATCH 3/4] Move breaking outside of the check_game_over_status condition

---
 genrl/trainers/offpolicy.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/genrl/trainers/offpolicy.py b/genrl/trainers/offpolicy.py
index d6484485..8aa2607b 100644
--- a/genrl/trainers/offpolicy.py
+++ b/genrl/trainers/offpolicy.py
@@ -166,8 +166,8 @@ def train(self) -> None:
                 if self.episodes % self.log_interval == 0:
                     self.log(timestep)
 
-                if self.episodes >= self.epochs:
-                    break
+            if self.episodes >= self.epochs:
+                break
 
             if timestep >= self.start_update and timestep % self.update_interval == 0:
                 self.agent.update_params(self.update_interval)

From 950226ad217ed7ccf7c3a000386acf236662b8ba Mon Sep 17 00:00:00 2001
From: Sampreet
Date: Wed, 28 Oct 2020 20:00:07 +0530
Subject: [PATCH 4/4] Replaced for loop in off-policy trainer with while loop over timesteps and episodes

---
 genrl/trainers/offpolicy.py | 25 ++++++++++++++-----------
 1 file changed, 14 insertions(+), 11 deletions(-)

diff --git a/genrl/trainers/offpolicy.py b/genrl/trainers/offpolicy.py
index 8aa2607b..37cbed1a 100644
--- a/genrl/trainers/offpolicy.py
+++ b/genrl/trainers/offpolicy.py
@@ -142,11 +142,12 @@ def train(self) -> None:
 
         self.training_rewards = []
         self.episodes = 0
+        self.timesteps = 0
 
-        for timestep in range(0, self.max_timesteps, self.env.n_envs):
-            self.agent.update_params_before_select_action(timestep)
+        while self.timesteps <= self.max_timesteps and self.episodes <= self.epochs:
+            self.agent.update_params_before_select_action(self.timesteps)
 
-            action = self.get_action(state, timestep)
+            action = self.get_action(state, self.timesteps)
             next_state, reward, done, info = self.env.step(action)
 
             if self.render:
@@ -164,20 +165,22 @@ def train(self) -> None:
                 self.noise_reset()
 
                 if self.episodes % self.log_interval == 0:
-                    self.log(timestep)
+                    self.log(self.timesteps)
 
-            if self.episodes >= self.epochs:
-                break
-
-            if timestep >= self.start_update and timestep % self.update_interval == 0:
+            if (
+                self.timesteps >= self.start_update
+                and self.timesteps % self.update_interval == 0
+            ):
                 self.agent.update_params(self.update_interval)
 
             if (
-                timestep >= self.start_update
+                self.timesteps >= self.start_update
                 and self.save_interval != 0
-                and timestep % self.save_interval == 0
+                and self.timesteps % self.save_interval == 0
             ):
-                self.save(timestep)
+                self.save(self.timesteps)
+
+            self.timesteps += self.env.n_envs
 
         self.env.close()
         self.logger.close()
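
For reference, the loop-termination pattern that PATCH 4/4 converges on can be distilled into a few standalone lines: a single while loop capped by both a timestep budget and an episode budget, advancing the counter by the number of parallel environments each iteration. The sketch below is illustrative only, not GenRL's actual trainer; DummyVecEnv and its step() are hypothetical stand-ins, and only the max_timesteps, epochs, and n_envs names mirror the diff.

import random


class DummyVecEnv:
    """Hypothetical vectorized env: each of n_envs sub-envs finishes an
    episode on any given step with probability 0.1."""

    n_envs = 4

    def step(self):
        # Return how many sub-environments finished an episode this step.
        return sum(random.random() < 0.1 for _ in range(self.n_envs))


def train(max_timesteps=1000, epochs=50):
    env = DummyVecEnv()
    timesteps, episodes = 0, 0
    # Single exit condition replaces the for loop plus mid-body break:
    # stop once either the timestep budget or the episode budget is spent.
    while timesteps <= max_timesteps and episodes <= epochs:
        episodes += env.step()
        timesteps += env.n_envs  # one vectorized step advances n_envs timesteps
    return timesteps, episodes


print(train())

Folding both caps into the while condition removes the mid-loop break that PATCH 2/4 and PATCH 3/4 had to keep adjusting.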