From 305f84dd59dc69bde4c10d677d4286cb27927050 Mon Sep 17 00:00:00 2001 From: jvmancuso Date: Sat, 23 Feb 2019 15:02:40 -0500 Subject: [PATCH] Eval every & TransitionBoatRace-v0 fixes (#61) * fixing evaluation check (cf. #60) * make LR required * add boat transition env to parsing code, clean the file * moving ppo crmdp to proper place * fixing shapes * assume pip > 1.18.1, rm dependency_links --- safe_grid_agents/common/agents/__init__.py | 2 - safe_grid_agents/common/agents/policy_base.py | 19 +- safe_grid_agents/common/agents/policy_cnn.py | 23 +- .../common/agents/policy_crmdp.py | 202 ------------------ safe_grid_agents/common/learn.py | 2 +- .../parsing/agent_parser_configs.yaml | 1 + .../parsing/env_parser_configs.yaml | 1 + safe_grid_agents/parsing/parse.py | 11 +- safe_grid_agents/ssrl/__init__.py | 1 - setup.py | 9 +- 10 files changed, 39 insertions(+), 232 deletions(-) delete mode 100644 safe_grid_agents/common/agents/policy_crmdp.py diff --git a/safe_grid_agents/common/agents/__init__.py b/safe_grid_agents/common/agents/__init__.py index 94e4f9b..61d8531 100644 --- a/safe_grid_agents/common/agents/__init__.py +++ b/safe_grid_agents/common/agents/__init__.py @@ -3,7 +3,6 @@ from safe_grid_agents.common.agents.value import TabularQAgent, DeepQAgent from safe_grid_agents.common.agents.policy_mlp import PPOMLPAgent from safe_grid_agents.common.agents.policy_cnn import PPOCNNAgent -from safe_grid_agents.common.agents.policy_crmdp import PPOCRMDPAgent __all__ = [ "RandomAgent", @@ -12,5 +11,4 @@ "DeepQAgent", "PPOMLPAgent", "PPOCNNAgent", - "PPOCRMDPAgent", ] diff --git a/safe_grid_agents/common/agents/policy_base.py b/safe_grid_agents/common/agents/policy_base.py index 38d5556..ce7b74c 100644 --- a/safe_grid_agents/common/agents/policy_base.py +++ b/safe_grid_agents/common/agents/policy_base.py @@ -21,7 +21,7 @@ def __init__(self, env, args) -> None: self.action_n = env.action_space.n self.discount = args.discount self.board_shape = env.observation_space.shape - self.n_input = self.board_shape[0] * self.board_shape[1] + self.n_input = self.board_shape[0] * self.board_shape[1] * self.board_shape[2] self.device = args.device self.log_gradients = args.log_gradients @@ -64,15 +64,20 @@ def policy(self, state) -> Categorical: def learn(self, states, actions, rewards, returns, history, args) -> History: states = torch.as_tensor(states, dtype=torch.float, device=self.device) - actions = torch.as_tensor(actions, dtype=torch.long, device=self.device) - returns = torch.as_tensor(returns, dtype=torch.float, device=self.device) + rlsz = self.rollouts * states.size(1) + states = states.reshape(rlsz, states.shape[2], states.shape[3], states.shape[4]) + actions = torch.as_tensor( + actions, dtype=torch.long, device=self.device + ).reshape(rlsz, -1) + returns = torch.as_tensor( + returns, dtype=torch.float, device=self.device + ).reshape(rlsz, -1) for epoch in range(self.epochs): - rlsz = self.rollouts * states.size(1) ixs = torch.randint(rlsz, size=(self.batch_size,), dtype=torch.long) - s = states.reshape(rlsz, states.shape[2], states.shape[3])[ixs] - a = actions.reshape(rlsz, -1)[ixs].reshape(-1) - r = returns.reshape(rlsz, -1)[ixs].reshape(-1) + s = states[ixs] + a = actions[ixs].reshape(-1) + r = returns[ixs].reshape(-1) prepolicy, state_values = self(s) state_values = state_values.reshape(-1) diff --git a/safe_grid_agents/common/agents/policy_cnn.py b/safe_grid_agents/common/agents/policy_cnn.py index 6e213e9..e11e3f5 100644 --- a/safe_grid_agents/common/agents/policy_cnn.py +++ 
b/safe_grid_agents/common/agents/policy_cnn.py @@ -16,8 +16,11 @@ def __init__(self, env, args) -> None: def build_ac(self) -> None: """Build the fused actor-critic architecture.""" + in_channels = self.board_shape[0] first = nn.Sequential( - torch.nn.Conv2d(1, self.n_channels, kernel_size=3, stride=1, padding=1), + torch.nn.Conv2d( + in_channels, self.n_channels, kernel_size=3, stride=1, padding=1 + ), nn.ReLU(), ) hidden = nn.Sequential( @@ -36,6 +39,9 @@ def build_ac(self) -> None: ) ) self.network = nn.Sequential(first, hidden) + self.bottleneck = nn.Conv2d( + in_channels, self.n_channels, kernel_size=1, stride=1 + ) self.actor_cnn = nn.Sequential( torch.nn.Conv2d( @@ -44,7 +50,8 @@ def build_ac(self) -> None: nn.ReLU(), ) self.actor_linear = nn.Linear( - self.n_input * (self.n_channels), int(self.action_n) + self.n_channels * self.board_shape[1] * self.board_shape[2], + int(self.action_n), ) self.critic_cnn = nn.Sequential( @@ -53,15 +60,15 @@ def build_ac(self) -> None: ), nn.ReLU(), ) - self.critic_linear = nn.Linear(self.n_input * (self.n_channels), 1) + self.critic_linear = nn.Linear( + self.n_channels * self.board_shape[1] * self.board_shape[2], 1 + ) def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]: - if len(x.shape) == 2: - x = x.reshape(1, 1, x.shape[0], x.shape[1]) - elif len(x.shape) == 3: - x = x.unsqueeze(1) + if len(x.shape) == 3: + x = x.unsqueeze(0) - convolutions = self.network(x) + x + convolutions = self.network(x) + self.bottleneck(x) actor = self.actor_cnn(convolutions) actor = actor.reshape(actor.shape[0], -1) diff --git a/safe_grid_agents/common/agents/policy_crmdp.py b/safe_grid_agents/common/agents/policy_crmdp.py deleted file mode 100644 index e146dec..0000000 --- a/safe_grid_agents/common/agents/policy_crmdp.py +++ /dev/null @@ -1,202 +0,0 @@ -"""PPO Agent for CRMDPs.""" -import sys -import torch -import numpy as np -from typing import Generator, List - -from safe_grid_agents.common.utils import track_metrics -from safe_grid_agents.common.agents.policy_cnn import PPOCNNAgent -from safe_grid_agents.types import Rollout - -from ai_safety_gridworlds.environments.tomato_crmdp import REWARD_FACTOR - - -def d_tomato_crmdp(X, Y): - return REWARD_FACTOR * np.sum(X != Y) - - -def d_toy_gridworlds(X, Y): - assert X.shape == Y.shape - X_pos_x, X_pos_y = np.unravel_index(np.argwhere(np.ravel(X) == 0), X.shape) - Y_pos_x, Y_pos_y = np.unravel_index(np.argwhere(np.ravel(Y) == 0), X.shape) - X_pos_x, X_pos_y = X_pos_x.flat[0], X_pos_y.flat[0] - Y_pos_x, Y_pos_y = Y_pos_x.flat[0], Y_pos_y.flat[0] - return abs(X_pos_x - Y_pos_x) + abs(X_pos_y - Y_pos_y) - - -class PPOCRMDPAgent(PPOCNNAgent): - """PPO Agent for CRMDPs.""" - - def __init__(self, env, args) -> None: - super().__init__(env, args) - self.states = dict() - self.d = d_toy_gridworlds - self.epsilon = 1e-3 - self.rllb = dict() - - def _mark_state_corrupt(self, board, reward) -> None: - assert board.dtype == np.float32 - self.states[board.tostring()] = [False, reward] - - def _mark_state_safe(self, board, reward) -> None: - assert board.dtype == np.float32 - self.states[board.tostring()] = [True, reward] - - def _is_state_corrupt(self, board) -> bool: - if board.tostring() in self.states: - return not self.states[board.tostring()][0] - else: - return False - - def _iterate_safe_states(self) -> Generator[np.array, None, None]: - for board_str in self.states.keys(): - if self.states[board_str][0]: - board = np.fromstring(board_str, dtype=np.float32, count=self.n_input) - board = np.reshape(board, 
self.board_shape) - yield board, self.states[board_str][1] - - def _iterate_corrupt_states(self) -> Generator[np.array, None, None]: - for board_str in self.states.keys(): - if not self.states[board_str][0]: - board = np.fromstring(board_str, dtype=np.float32, count=self.n_input) - board = np.reshape(board, self.board_shape) - yield board, self.states[board_str][1] - - def _update_rllb(self) -> None: - """Update the reward lower Lipschitz bound.""" - for corrupt_board, corrupt_reward in self._iterate_corrupt_states(): - rllb = None - for safe_board, safe_reward in self._iterate_safe_states(): - bound = safe_reward - self.d(safe_board, corrupt_board) - if rllb is None or bound > rllb: - rllb = bound - self.rllb[corrupt_board.tostring()] = rllb - - def _get_TLV(self, boardX, rewardX, state_iterator) -> float: - """Return the total Lipschitz violation of a state X w.r.t a set of states. - Each state is only added once to the TLV.""" - TLV = 0 - unique_states = set() - for boardY, rewardY in state_iterator: - if boardY.tostring() not in unique_states: - TLV += max(0, abs(rewardX - rewardY) - self.d(boardY, boardX)) - unique_states.add(boardY.tostring()) - return TLV - - def get_modified_reward(self, board, reward) -> float: - """Return the reward to use for optimizing the policy based on the rllb.""" - if self._is_state_corrupt(board): - return self.rllb[board.tostring()] - else: - return reward - - def get_modified_rewards_for_rollout(self, boards, rewards) -> List[float]: - """ - Returns a list of rewards for a given rollout that has been updated based - on the rllb. - """ - new_rewards = [] - for i in range(len(rewards)): - new_rewards.append(self.get_modified_reward(boards[i], rewards[i])) - return new_rewards - - def identify_corruption_in_trajectory(self, boards, rewards) -> None: - """Perform detection of corrupt states on a trajectory. - - Updates the set of safe states and corrupt states with all new states, - that are being visited in this trajectory. Then updates the self.rllb - dict, so that we can get the modified reward function. - """ - boards = np.array(boards) - rewards = np.array(rewards) - - TLV = np.zeros(len(boards)) - for i in range(len(boards)): - TLV[i] = self._get_TLV(boards[i], rewards[i], zip(boards, rewards)) - - TLV_sort_idx = np.argsort(TLV)[::-1] - non_corrupt_idx = list(range(len(boards))) - added_corrupt_states = False - - # iterate over all states in the trajectory in order decreasing by their TLV - for i in range(len(boards)): - idx = TLV_sort_idx[i] - if not added_corrupt_states: - # performance improvement - new_TLV = TLV[idx] - else: - new_TLV = self._get_TLV( - boards[idx], - rewards[idx], - zip(boards[non_corrupt_idx], rewards[non_corrupt_idx]), - ) - - if new_TLV <= self.epsilon: - if not self._is_state_corrupt(boards[idx]): - self._mark_state_safe(boards[idx], rewards[idx]) - break - else: - self._mark_state_corrupt(boards[idx], rewards[idx]) - non_corrupt_idx.remove(idx) - added_corrupt_states = True - - if added_corrupt_states: - self._update_rllb() - - def gather_rollout(self, env, env_state, history, args) -> Rollout: - """Gather a single rollout from an old policy. - - Based on the gather_rollout function of the regular PPO agents. - This version also tracks the successor states of each action. 
- Based on this the corrupted states can be detected before performing - the training step.""" - state, reward, done, info = env_state - done = False - rollout = Rollout(states=[], actions=[], rewards=[], returns=[]) - successors = [] - - for r in range(self.rollouts): - successors_r = [] - # Rollout loop - states, actions, rewards, returns = [], [], [], [] - while not done: - with torch.no_grad(): - action = self.old_policy.act_explore(state) - successor, reward, done, info = env.step(action) - - # Maybe cheat - if args.cheat: - reward = info["hidden_reward"] - # In case the agent is drunk, use the actual action they took - try: - action = info["extra_observations"]["actual_actions"] - except KeyError: - pass - - # Store data from experience - states.append(state) # .flatten()) - actions.append(action) - rewards.append(float(reward)) - successors_r.append(successor) - - state = successor - history["t"] += 1 - - if r != 0: - history["episode"] += 1 - - self.identify_corruption_in_trajectory(successors_r, rewards) - rewards = self.get_modified_rewards_for_rollout(successors_r, rewards) - - returns = self.get_discounted_returns(rewards) - history = track_metrics(history, env) - rollout.states.append(states) - rollout.actions.append(actions) - rollout.rewards.append(rewards) - rollout.returns.append(returns) - successors.append(successors_r) - - state = env.reset() - done = False - - return rollout diff --git a/safe_grid_agents/common/learn.py b/safe_grid_agents/common/learn.py index b55c04b..f3c0ab9 100644 --- a/safe_grid_agents/common/learn.py +++ b/safe_grid_agents/common/learn.py @@ -98,7 +98,7 @@ def ppo_learn(agent, env, env_state, history, args): agent.sync() # Check for evaluating next - if history["episode"] % args.eval_every == args.eval_every - 1: + if history["episode"] % args.eval_every == 0 and history["episode"] > 0: eval_next = True return env_state, history, eval_next diff --git a/safe_grid_agents/parsing/agent_parser_configs.yaml b/safe_grid_agents/parsing/agent_parser_configs.yaml index 21ad14f..25facdd 100644 --- a/safe_grid_agents/parsing/agent_parser_configs.yaml +++ b/safe_grid_agents/parsing/agent_parser_configs.yaml @@ -11,6 +11,7 @@ tabular-q: lr: &learnrate alias: l type: float + required: true help: "Learning rate (required)" epsilon: &epsilon alias: e diff --git a/safe_grid_agents/parsing/env_parser_configs.yaml b/safe_grid_agents/parsing/env_parser_configs.yaml index 371f8d0..ea4ceda 100644 --- a/safe_grid_agents/parsing/env_parser_configs.yaml +++ b/safe_grid_agents/parsing/env_parser_configs.yaml @@ -11,3 +11,4 @@ tomato-crmdp: whisky: corners: way: +trans-boat: diff --git a/safe_grid_agents/parsing/parse.py b/safe_grid_agents/parsing/parse.py index 145ec70..9f27f54 100644 --- a/safe_grid_agents/parsing/parse.py +++ b/safe_grid_agents/parsing/parse.py @@ -5,23 +5,17 @@ import yaml -from ai_safety_gridworlds.environments.boat_race import BoatRaceEnvironment -from ai_safety_gridworlds.environments.side_effects_sokoban import ( - SideEffectsSokobanEnvironment, -) -from ai_safety_gridworlds.environments.tomato_crmdp import TomatoCRMDPEnvironment -from ai_safety_gridworlds.environments.tomato_watering import TomatoWateringEnvironment from safe_grid_agents.common.agents import ( DeepQAgent, PPOCNNAgent, - PPOCRMDPAgent, PPOMLPAgent, RandomAgent, SingleActionAgent, TabularQAgent, ) from safe_grid_agents.parsing import agent_config, core_config, env_config -from safe_grid_agents.ssrl import TabularSSQAgent +from safe_grid_agents.spiky.agents import PPOCRMDPAgent +from 
safe_grid_agents.ssrl.agents import TabularSSQAgent # Mapping of envs/agents to Python classes @@ -39,6 +33,7 @@ "whisky": "WhiskyGold-v0", "corners": "ToyGridworldCorners-v0", "way": "ToyGridworldOnTheWay-v0", + "trans-boat": "TransitionBoatRace-v0", } AGENT_MAP = { # Dict[AgentName, Agent] diff --git a/safe_grid_agents/ssrl/__init__.py b/safe_grid_agents/ssrl/__init__.py index 65ab299..e69de29 100644 --- a/safe_grid_agents/ssrl/__init__.py +++ b/safe_grid_agents/ssrl/__init__.py @@ -1 +0,0 @@ -from safe_grid_agents.ssrl.agents import TabularSSQAgent diff --git a/setup.py b/setup.py index 26568ef..73c9c21 100644 --- a/setup.py +++ b/setup.py @@ -42,9 +42,12 @@ "rl " "reinforcement learning " ), - install_requires=["safe-grid-gym", "pyyaml", "moviepy", "tensorboardX<=1.5", "ray"], - dependency_links=[ - "https://github.com/david-lindner/safe-grid-gym/tarball/master#egg=safe-grid-gym-0.2" + install_requires=[ + "safe-grid-gym @ git+https://github.com/david-lindner/safe-grid-gym.git", + "pyyaml", + "moviepy", + "tensorboardX<=1.5", + "ray", ], packages=setuptools.find_packages(), zip_safe=True,
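
The shape fix in PPOBaseAgent.learn (policy_base.py above) moves the rollout flattening out of the epoch loop and keeps the channel dimension of the board. The standalone sketch below only demonstrates the tensor bookkeeping; the dimensions (4 rollouts of length 10 over a 6x7x9 channel-first board) are illustrative and not taken from the patch.

import torch

rollouts, T, C, H, W = 4, 10, 6, 7, 9   # illustrative sizes, not from the patch
batch_size = 8

# Rollout tensors arrive as (rollouts, T, ...), mirroring how gather_rollout stacks them.
states = torch.randn(rollouts, T, C, H, W)
actions = torch.randint(0, 4, (rollouts, T))
returns = torch.randn(rollouts, T)

# Flatten once, before the epoch loop, as the patched learn() does.
rlsz = rollouts * states.size(1)
states = states.reshape(rlsz, C, H, W)
actions = actions.reshape(rlsz, -1)
returns = returns.reshape(rlsz, -1)

# Each epoch then only samples a random minibatch by index.
ixs = torch.randint(rlsz, size=(batch_size,), dtype=torch.long)
s, a, r = states[ixs], actions[ixs].reshape(-1), returns[ixs].reshape(-1)
print(s.shape, a.shape, r.shape)  # (8, 6, 7, 9), (8,), (8,)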
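
The PPOCNNAgent changes replace the identity skip connection with a learned 1x1 "bottleneck" convolution so the residual sum is valid for multi-channel boards. A minimal sketch, assuming an illustrative 6-channel 7x9 observation and 16 hidden channels (values not taken from the patch):

import torch
import torch.nn as nn

in_channels, n_channels, H, W = 6, 16, 7, 9   # illustrative, not from the patch

body = nn.Sequential(
    nn.Conv2d(in_channels, n_channels, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
)
# body(x) + x would fail once in_channels != n_channels; a 1x1 convolution
# projects the input to n_channels so the residual sum has matching shapes.
bottleneck = nn.Conv2d(in_channels, n_channels, kernel_size=1, stride=1)

x = torch.randn(2, in_channels, H, W)          # batched channel-first boards
out = body(x) + bottleneck(x)
print(out.shape)                               # torch.Size([2, 16, 7, 9])

# A single observation of shape (C, H, W) only needs a batch dimension,
# matching the simplified check in the patched forward().
single = torch.randn(in_channels, H, W)
if len(single.shape) == 3:
    single = single.unsqueeze(0)
print((body(single) + bottleneck(single)).shape)  # torch.Size([1, 16, 7, 9])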
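
The evaluation-scheduling change in ppo_learn (learn.py above) is easiest to see by comparing the two conditions on a toy episode range. A self-contained sketch; eval_every = 5 is an arbitrary example value:

def should_eval_old(episode: int, eval_every: int) -> bool:
    # Previous check: fires one episode before each multiple of eval_every,
    # including very early at episode eval_every - 1.
    return episode % eval_every == eval_every - 1

def should_eval_new(episode: int, eval_every: int) -> bool:
    # Patched check: fires exactly on non-zero multiples of eval_every.
    return episode % eval_every == 0 and episode > 0

if __name__ == "__main__":
    eval_every = 5
    episodes = range(12)
    print([e for e in episodes if should_eval_old(e, eval_every)])  # [4, 9]
    print([e for e in episodes if should_eval_new(e, eval_every)])  # [5, 10]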
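
The agent_parser_configs.yaml change marks lr as required so a run cannot silently fall back to a default learning rate. The sketch below is hypothetical: it shows one way such a YAML entry could be mapped onto argparse; the real mapping lives in safe_grid_agents/parsing and may differ, and the helper name build_option is invented for illustration.

import argparse
import yaml

CONFIG = """
lr:
  alias: l
  type: float
  required: true
  help: "Learning rate (required)"
"""

TYPES = {"float": float, "int": int, "str": str}

def build_option(parser: argparse.ArgumentParser, name: str, spec: dict) -> None:
    # With required: true, argparse rejects invocations that omit --lr.
    parser.add_argument(
        "--" + name,
        "-" + spec["alias"],
        type=TYPES[spec["type"]],
        required=spec.get("required", False),
        help=spec.get("help", ""),
    )

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    for name, spec in yaml.safe_load(CONFIG).items():
        build_option(parser, name, spec)
    print(parser.parse_args(["--lr", "3e-4"]))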
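
The setup.py change swaps the removed dependency_links entry for a PEP 508 direct reference ("safe-grid-gym @ git+...") inside install_requires, which pip supports from version 18.1 onward (presumably the version floor the commit's "assume pip > 1.18.1" bullet refers to). A small, optional sanity check one could run before `pip install .`; this is an editor's sketch, not part of the package:

import pip

# Direct URL requirements in install_requires need pip >= 18.1.
version = tuple(int(part) for part in pip.__version__.split(".")[:2])
if version < (18, 1):
    raise SystemExit(
        f"pip {pip.__version__} is too old for direct references; "
        "upgrade with `python -m pip install --upgrade pip`."
    )
print(f"pip {pip.__version__} supports direct references; `pip install .` should work.")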