From c8c5c75f8a417e98777a28f02aa0d3720d238c4d Mon Sep 17 00:00:00 2001 From: Federico Pizarro Bejarano Date: Mon, 17 Jun 2024 00:58:36 -0400 Subject: [PATCH] Fixing some minor issues (#156) --- examples/cbf/config_overrides/ppo_config.yaml | 2 +- examples/mpsc/config_overrides/cartpole/ppo_cartpole.yaml | 2 +- .../mpsc/config_overrides/quadrotor_2D/ppo_quadrotor_2D.yaml | 2 +- .../rl/config_overrides/quadrotor_2D/ppo_quadrotor_2D.yaml | 2 +- .../rl/config_overrides/quadrotor_3D/ppo_quadrotor_3D.yaml | 2 +- safe_control_gym/controllers/ppo/ppo.yaml | 2 +- safe_control_gym/controllers/sac/sac.yaml | 2 +- safe_control_gym/experiments/base_experiment.py | 5 +++-- 8 files changed, 10 insertions(+), 9 deletions(-) diff --git a/examples/cbf/config_overrides/ppo_config.yaml b/examples/cbf/config_overrides/ppo_config.yaml index 5d79442df..77f98bef2 100644 --- a/examples/cbf/config_overrides/ppo_config.yaml +++ b/examples/cbf/config_overrides/ppo_config.yaml @@ -2,7 +2,7 @@ algo: ppo algo_config: # model args hidden_dim: 64 - activation: relu + activation: tanh norm_obs: False norm_reward: False clip_obs: 10.0 diff --git a/examples/mpsc/config_overrides/cartpole/ppo_cartpole.yaml b/examples/mpsc/config_overrides/cartpole/ppo_cartpole.yaml index 5d79442df..77f98bef2 100644 --- a/examples/mpsc/config_overrides/cartpole/ppo_cartpole.yaml +++ b/examples/mpsc/config_overrides/cartpole/ppo_cartpole.yaml @@ -2,7 +2,7 @@ algo: ppo algo_config: # model args hidden_dim: 64 - activation: relu + activation: tanh norm_obs: False norm_reward: False clip_obs: 10.0 diff --git a/examples/mpsc/config_overrides/quadrotor_2D/ppo_quadrotor_2D.yaml b/examples/mpsc/config_overrides/quadrotor_2D/ppo_quadrotor_2D.yaml index 6ee9f0b72..1f9fde511 100644 --- a/examples/mpsc/config_overrides/quadrotor_2D/ppo_quadrotor_2D.yaml +++ b/examples/mpsc/config_overrides/quadrotor_2D/ppo_quadrotor_2D.yaml @@ -2,7 +2,7 @@ algo: ppo algo_config: # model args hidden_dim: 256 - activation: relu + activation: tanh # loss args use_gae: True diff --git a/examples/rl/config_overrides/quadrotor_2D/ppo_quadrotor_2D.yaml b/examples/rl/config_overrides/quadrotor_2D/ppo_quadrotor_2D.yaml index e0128c2b4..74bd15488 100644 --- a/examples/rl/config_overrides/quadrotor_2D/ppo_quadrotor_2D.yaml +++ b/examples/rl/config_overrides/quadrotor_2D/ppo_quadrotor_2D.yaml @@ -2,7 +2,7 @@ algo: ppo algo_config: # model args hidden_dim: 128 - activation: relu + activation: tanh # loss args use_gae: True diff --git a/examples/rl/config_overrides/quadrotor_3D/ppo_quadrotor_3D.yaml b/examples/rl/config_overrides/quadrotor_3D/ppo_quadrotor_3D.yaml index 166626c6a..3b2aee7a3 100644 --- a/examples/rl/config_overrides/quadrotor_3D/ppo_quadrotor_3D.yaml +++ b/examples/rl/config_overrides/quadrotor_3D/ppo_quadrotor_3D.yaml @@ -2,7 +2,7 @@ algo: ppo algo_config: # model args hidden_dim: 128 - activation: relu + activation: tanh # loss args use_gae: True diff --git a/safe_control_gym/controllers/ppo/ppo.yaml b/safe_control_gym/controllers/ppo/ppo.yaml index 04126e91f..39ad17867 100644 --- a/safe_control_gym/controllers/ppo/ppo.yaml +++ b/safe_control_gym/controllers/ppo/ppo.yaml @@ -1,6 +1,6 @@ # Model args hidden_dim: 64 -activation: 'tanh' +activation: tanh norm_obs: False norm_reward: False clip_obs: 10 diff --git a/safe_control_gym/controllers/sac/sac.yaml b/safe_control_gym/controllers/sac/sac.yaml index 6cecd6667..6f68389b5 100644 --- a/safe_control_gym/controllers/sac/sac.yaml +++ b/safe_control_gym/controllers/sac/sac.yaml @@ -1,6 +1,6 @@ # model args hidden_dim: 256 -activation: 'relu' +activation: relu norm_obs: False norm_reward: False clip_obs: 10. diff --git a/safe_control_gym/experiments/base_experiment.py b/safe_control_gym/experiments/base_experiment.py index be54622a2..d8f86c536 100644 --- a/safe_control_gym/experiments/base_experiment.py +++ b/safe_control_gym/experiments/base_experiment.py @@ -170,8 +170,9 @@ def _select_action(self, obs, info): if self.safety_filter is not None: physical_action = self.env.denormalize_action(action) unextended_obs = obs[:self.env.symbolic.nx] - certified_action, _ = self.safety_filter.certify_action(unextended_obs, physical_action, info) - action = self.env.normalize_action(certified_action) + certified_action, success = self.safety_filter.certify_action(unextended_obs, physical_action, info) + if success: + action = self.env.normalize_action(certified_action) return action