diff --git a/examples/baselines/ppo/ppo_rgb.py b/examples/baselines/ppo/ppo_rgb.py
new file mode 100644
index 000000000..ef8a4f072
--- /dev/null
+++ b/examples/baselines/ppo/ppo_rgb.py
@@ -0,0 +1,506 @@
+# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_continuous_actionpy
+import os
+import random
+import time
+from dataclasses import dataclass
+
+import gymnasium as gym
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import tyro
+from torch.distributions.normal import Normal
+from torch.utils.tensorboard import SummaryWriter
+
+# ManiSkill specific imports
+import mani_skill2.envs
+from mani_skill2.utils.wrappers.flatten import FlattenActionSpaceWrapper, FlattenRGBDObservationWrapper
+from mani_skill2.utils.wrappers.record import RecordEpisode
+from mani_skill2.vector.wrappers.gymnasium import ManiSkillVectorEnv
+
+@dataclass
+class Args:
+    exp_name: str = os.path.basename(__file__)[: -len(".py")]
+    """the name of this experiment"""
+    seed: int = 1
+    """seed of the experiment"""
+    torch_deterministic: bool = True
+    """if toggled, `torch.backends.cudnn.deterministic` is set to True"""
+    cuda: bool = True
+    """if toggled, cuda will be enabled by default"""
+    track: bool = False
+    """if toggled, this experiment will be tracked with Weights and Biases"""
+    wandb_project_name: str = "cleanRL"
+    """the wandb's project name"""
+    wandb_entity: str = None
+    """the entity (team) of wandb's project"""
+    capture_video: bool = True
+    """whether to capture videos of the agent performances (check out the `videos` folder)"""
+    save_model: bool = True
+    """whether to save the model into the `runs/{run_name}` folder"""
+    upload_model: bool = False
+    """whether to upload the saved model to huggingface"""
+    hf_entity: str = ""
+    """the user or org name of the model repository from the Hugging Face Hub"""
+
+    # Algorithm specific arguments
+    env_id: str = "PickCube-v1"
+    """the id of the environment"""
+    total_timesteps: int = 10000000
+    """total timesteps of the experiment"""
+    learning_rate: float = 3e-4
+    """the learning rate of the optimizer"""
+    num_envs: int = 512
+    """the number of parallel environments"""
+    num_eval_envs: int = 8
+    """the number of parallel evaluation environments"""
+    num_steps: int = 50
+    """the number of steps to run in each environment per policy rollout"""
+    anneal_lr: bool = False
+    """Toggle learning rate annealing for policy and value networks"""
+    gamma: float = 0.8
+    """the discount factor gamma"""
+    gae_lambda: float = 0.9
+    """the lambda for generalized advantage estimation"""
+    num_minibatches: int = 32
+    """the number of mini-batches"""
+    update_epochs: int = 4
+    """the K epochs to update the policy"""
+    norm_adv: bool = True
+    """Toggles advantages normalization"""
+    clip_coef: float = 0.2
+    """the surrogate clipping coefficient"""
+    clip_vloss: bool = False
+    """Toggles whether or not to use a clipped loss for the value function, as per the paper."""
+    ent_coef: float = 0.0
+    """coefficient of the entropy"""
+    vf_coef: float = 0.5
+    """coefficient of the value function"""
+    max_grad_norm: float = 0.5
+    """the maximum norm for the gradient clipping"""
+    target_kl: float = 0.1
+    """the target KL divergence threshold"""
+    eval_freq: int = 25
+    """evaluation frequency in terms of iterations"""
+    finite_horizon_gae: bool = True
+    """whether to use the finite-horizon formulation of GAE"""
+
+    # to be filled in runtime
+    batch_size: int = 0
+    """the batch size (computed in runtime)"""
+    minibatch_size: int = 0
+    """the mini-batch size (computed in runtime)"""
+    num_iterations: int = 0
+    """the number of iterations (computed in runtime)"""
+
+def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
+    torch.nn.init.orthogonal_(layer.weight, std)
+    torch.nn.init.constant_(layer.bias, bias_const)
+    return layer
+
+class DictArray(object):
+    def __init__(self, buffer_shape, element_space, data_dict=None, device=None):
+        self.buffer_shape = buffer_shape
+        if data_dict:
+            self.data = data_dict
+        else:
+            assert isinstance(element_space, gym.spaces.dict.Dict)
+            self.data = {}
+            for k, v in element_space.items():
+                if isinstance(v, gym.spaces.dict.Dict):
+                    self.data[k] = DictArray(buffer_shape, v)
+                else:
+                    self.data[k] = torch.zeros(buffer_shape + v.shape).to(device)
+
+    def keys(self):
+        return self.data.keys()
+
+    def __getitem__(self, index):
+        if isinstance(index, str):
+            return self.data[index]
+        return {
+            k: v[index] for k, v in self.data.items()
+        }
+
+    def __setitem__(self, index, value):
+        if isinstance(index, str):
+            self.data[index] = value
+            return
+        for k, v in value.items():
+            self.data[k][index] = v
+
+    @property
+    def shape(self):
+        return self.buffer_shape
+
+    def reshape(self, shape):
+        t = len(self.buffer_shape)
+        new_dict = {}
+        for k, v in self.data.items():
+            if isinstance(v, DictArray):
+                new_dict[k] = v.reshape(shape)
+            else:
+                new_dict[k] = v.reshape(shape + v.shape[t:])
+        new_buffer_shape = next(iter(new_dict.values())).shape[:len(shape)]
+        return DictArray(new_buffer_shape, None, data_dict=new_dict)
+
+class NatureCNN(nn.Module):
+    def __init__(self, sample_obs):
+        super().__init__()
+
+        extractors = {}
+
+        self.out_features = 0
+        feature_size = 256
+        in_channels=sample_obs["rgbd"].shape[-1]
+        image_size=(sample_obs["rgbd"].shape[1], sample_obs["rgbd"].shape[2])
+        state_size=sample_obs["state"].shape[-1]
+
+        # here we use a NatureCNN architecture to process images, but any architecture is permissible here
+        cnn = nn.Sequential(
+            nn.Conv2d(
+                in_channels=in_channels,
+                out_channels=32,
+                kernel_size=8,
+                stride=4,
+                padding=0,
+            ),
+            nn.ReLU(),
+            nn.Conv2d(
+                in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=0
+            ),
+            nn.ReLU(),
+            nn.Conv2d(
+                in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=0
+            ),
+            nn.ReLU(),
+            nn.Flatten(),
+        )
+
+        # to easily figure out the dimensions after flattening, we pass a test tensor through the CNN
+        with torch.no_grad():
+            n_flatten = cnn(sample_obs["rgbd"].float().permute(0,3,1,2).cpu()).shape[1]
+            fc = nn.Sequential(nn.Linear(n_flatten, feature_size), nn.ReLU())
+        extractors["rgbd"] = nn.Sequential(cnn, fc)
+        self.out_features += feature_size
+
+        # for state data we simply pass it through a single linear layer
+        extractors["state"] = nn.Linear(state_size, 64)
+        self.out_features += 64
+
+        self.extractors = nn.ModuleDict(extractors)
+
+    def forward(self, observations) -> torch.Tensor:
+        encoded_tensor_list = []
+        # self.extractors contains nn.Modules that do all of the processing
+        for key, extractor in self.extractors.items():
+            obs = observations[key]
+            if key == "rgbd":
+                obs = obs.float().permute(0,3,1,2)
+                obs = obs / 255
+            encoded_tensor_list.append(extractor(obs))
+        return torch.cat(encoded_tensor_list, dim=1)
+
+class Agent(nn.Module):
+    def __init__(self, envs, sample_obs):
+        super().__init__()
+        self.feature_net = NatureCNN(sample_obs=sample_obs)
+        # latent_size = np.array(envs.unwrapped.single_observation_space.shape).prod()
+        latent_size = self.feature_net.out_features
+        self.critic = nn.Sequential(
+            layer_init(nn.Linear(latent_size, 512)),
+            nn.Tanh(),
+            layer_init(nn.Linear(512, 1)),
+        )
+        self.actor_mean = nn.Sequential(
+            layer_init(nn.Linear(latent_size, 512)),
+            nn.Tanh(),
+            layer_init(nn.Linear(512, np.prod(envs.unwrapped.single_action_space.shape)), std=0.01*np.sqrt(2)),
+        )
+        self.actor_logstd = nn.Parameter(torch.ones(1, np.prod(envs.unwrapped.single_action_space.shape)) * -0.5)
+    def get_features(self, x):
+        return self.feature_net(x)
+    def get_value(self, x):
+        x = self.feature_net(x)
+        return self.critic(x)
+    def get_action(self, x, deterministic=False):
+        x = self.feature_net(x)
+        action_mean = self.actor_mean(x)
+        if deterministic:
+            return action_mean
+        action_logstd = self.actor_logstd.expand_as(action_mean)
+        action_std = torch.exp(action_logstd)
+        probs = Normal(action_mean, action_std)
+        return probs.sample()
+    def get_action_and_value(self, x, action=None):
+        x = self.feature_net(x)
+        action_mean = self.actor_mean(x)
+        action_logstd = self.actor_logstd.expand_as(action_mean)
+        action_std = torch.exp(action_logstd)
+        probs = Normal(action_mean, action_std)
+        if action is None:
+            action = probs.sample()
+        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)
+
+
+if __name__ == "__main__":
+    args = tyro.cli(Args)
+    args.batch_size = int(args.num_envs * args.num_steps)
+    args.minibatch_size = int(args.batch_size // args.num_minibatches)
+    args.num_iterations = args.total_timesteps // args.batch_size
+    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
+
+    if args.track:
+        import wandb
+
+        wandb.init(
+            project=args.wandb_project_name,
+            entity=args.wandb_entity,
+            sync_tensorboard=True,
+            config=vars(args),
+            name=run_name,
+            monitor_gym=True,
+            save_code=True,
+        )
+    writer = SummaryWriter(f"runs/{run_name}")
+    writer.add_text(
+        "hyperparameters",
+        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
+    )
+
+    # TRY NOT TO MODIFY: seeding
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    torch.backends.cudnn.deterministic = args.torch_deterministic
+
+    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
+
+    # env setup
+    env_kwargs = dict(obs_mode="rgbd", control_mode="pd_joint_delta_pos", render_mode="rgb_array")
+    envs = gym.make(args.env_id, num_envs=args.num_envs, **env_kwargs)
+    eval_envs = gym.make(args.env_id, num_envs=args.num_eval_envs, **env_kwargs)
+
+    # the rgbd obs mode returns a dict of data; we flatten it so there is just an "rgbd" key and a "state" key
+    envs = FlattenRGBDObservationWrapper(envs, rgb_only=True)
+    eval_envs = FlattenRGBDObservationWrapper(eval_envs, rgb_only=True)
+    if isinstance(envs.action_space, gym.spaces.Dict):
+        envs = FlattenActionSpaceWrapper(envs)
+        eval_envs = FlattenActionSpaceWrapper(eval_envs)
+    if args.capture_video:
+        eval_envs = RecordEpisode(eval_envs, output_dir=f"runs/{run_name}/videos", save_trajectory=False, video_fps=30)
+    envs = ManiSkillVectorEnv(envs, args.num_envs, ignore_terminations=False, **env_kwargs)
+    eval_envs = ManiSkillVectorEnv(eval_envs, args.num_eval_envs, ignore_terminations=True, **env_kwargs)
+    assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"
+
+
+
+    # ALGO Logic: Storage setup
+    obs = DictArray((args.num_steps, args.num_envs), envs.single_observation_space, device=device)
+    actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
+    logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
+    rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
+    dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
+    values = torch.zeros((args.num_steps, args.num_envs)).to(device)
+
+    # TRY NOT TO MODIFY: start the game
+    global_step = 0
+    start_time = time.time()
+    next_obs, _ = envs.reset(seed=args.seed)
+    eval_obs, _ = eval_envs.reset(seed=args.seed)
+    next_done = torch.zeros(args.num_envs, device=device)
+    eps_returns = torch.zeros(args.num_envs, dtype=torch.float, device=device)
+    eps_lens = np.zeros(args.num_envs)
+    place_rew = torch.zeros(args.num_envs, device=device)
+    print(f"####")
+    print(f"args.num_iterations={args.num_iterations} args.num_envs={args.num_envs} args.num_eval_envs={args.num_eval_envs}")
+    print(f"args.minibatch_size={args.minibatch_size} args.batch_size={args.batch_size} args.update_epochs={args.update_epochs}")
+    print(f"####")
+    agent = Agent(envs, sample_obs=next_obs).to(device)
+    optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
+
+    for iteration in range(1, args.num_iterations + 1):
+        print(f"Epoch: {iteration}, global_step={global_step}")
+        final_values = torch.zeros((args.num_steps, args.num_envs), device=device)
+        agent.eval()
+        if iteration % args.eval_freq == 1:
+            # evaluate
+            print("Evaluating")
+            eval_done = False
+            while not eval_done:
+                with torch.no_grad():
+                    eval_obs, _, eval_terminations, eval_truncations, eval_infos = eval_envs.step(agent.get_action(eval_obs, deterministic=True))
+                    if eval_truncations.any():
+                        eval_done = True
+                        info = eval_infos["final_info"]
+                        episodic_return = info['episode']['r'].mean().cpu().numpy()
+                        print(f"eval_episodic_return={episodic_return}")
+                        writer.add_scalar("charts/eval_success_rate", info["success"].float().mean().cpu().numpy(), global_step)
+                        writer.add_scalar("charts/eval_episodic_return", episodic_return, global_step)
+                        writer.add_scalar("charts/eval_episodic_length", info["elapsed_steps"].float().mean().cpu().numpy(), global_step)
+
+        if args.save_model and iteration % args.eval_freq == 1:
+            model_path = f"runs/{run_name}/{args.exp_name}_{iteration}.cleanrl_model"
+            torch.save(agent.state_dict(), model_path)
+            print(f"model saved to {model_path}")
+        # Annealing the rate if instructed to do so.
+        if args.anneal_lr:
+            frac = 1.0 - (iteration - 1.0) / args.num_iterations
+            lrnow = frac * args.learning_rate
+            optimizer.param_groups[0]["lr"] = lrnow
+
+        for step in range(0, args.num_steps):
+            global_step += args.num_envs
+            obs[step] = next_obs
+            dones[step] = next_done
+
+            # ALGO LOGIC: action logic
+            with torch.no_grad():
+                action, logprob, _, value = agent.get_action_and_value(next_obs)
+                values[step] = value.flatten()
+            actions[step] = action
+            logprobs[step] = logprob
+
+            # TRY NOT TO MODIFY: execute the game and log data.
+            next_obs, reward, terminations, truncations, infos = envs.step(action)
+            next_done = torch.logical_or(terminations, truncations).to(torch.float32)
+            rewards[step] = reward.view(-1)
+
+            if "final_info" in infos:
+                info = infos["final_info"]
+                done_mask = info["_final_info"]
+                episodic_return = info['episode']['r'][done_mask].mean().cpu().numpy()
+                writer.add_scalar("charts/success_rate", info["success"][done_mask].float().mean().cpu().numpy(), global_step)
+                writer.add_scalar("charts/episodic_return", episodic_return, global_step)
+                writer.add_scalar("charts/episodic_length", info["elapsed_steps"][done_mask].float().mean().cpu().numpy(), global_step)
+                for k in info["final_observation"]:
+                    info["final_observation"][k] = info["final_observation"][k][done_mask]
+                final_values[step, torch.arange(args.num_envs, device=device)[done_mask]] = agent.get_value(info["final_observation"]).view(-1)
+
+        # bootstrap value according to termination and truncation
+        with torch.no_grad():
+            next_value = agent.get_value(next_obs).reshape(1, -1)
+            advantages = torch.zeros_like(rewards).to(device)
+            lastgaelam = 0
+            for t in reversed(range(args.num_steps)):
+                if t == args.num_steps - 1:
+                    next_not_done = 1.0 - next_done
+                    nextvalues = next_value
+                else:
+                    next_not_done = 1.0 - dones[t + 1]
+                    nextvalues = values[t + 1]
+                real_next_values = next_not_done * nextvalues + final_values[t]  # t instead of t+1
+                # if next_not_done is 1, nextvalues is computed from the correct next_obs
+                # if next_not_done is 1, final_values is always 0
+                # if next_not_done is 0, then use final_values, which is computed according to bootstrap_at_done
+                if args.finite_horizon_gae:
+                    """
+                    See GAE paper equation (16) line 1; we compute the GAE based on this expansion only:
+                    1        *( -V(s_t) + r_t + gamma * V(s_{t+1}) )
+                    lambda   *( -V(s_t) + r_t + gamma * r_{t+1} + gamma^2 * V(s_{t+2}) )
+                    lambda^2 *( -V(s_t) + r_t + gamma * r_{t+1} + gamma^2 * r_{t+2} + gamma^3 * V(s_{t+3}) )
+                    lambda^3 *( -V(s_t) + r_t + gamma * r_{t+1} + gamma^2 * r_{t+2} + gamma^3 * r_{t+3} + gamma^4 * V(s_{t+4}) )
+                    ...
+                    We then normalize by the sum of the lambda^i (instead of by 1 - lambda)
+                    """
+                    if t == args.num_steps - 1:  # initialize
+                        lam_coef_sum = 0.
+                        reward_term_sum = 0.  # the sum of the second term
+                        value_term_sum = 0.  # the sum of the third term
+                    lam_coef_sum = lam_coef_sum * next_not_done
+                    reward_term_sum = reward_term_sum * next_not_done
+                    value_term_sum = value_term_sum * next_not_done
+
+                    lam_coef_sum = 1 + args.gae_lambda * lam_coef_sum
+                    reward_term_sum = args.gae_lambda * args.gamma * reward_term_sum + lam_coef_sum * rewards[t]
+                    value_term_sum = args.gae_lambda * args.gamma * value_term_sum + args.gamma * real_next_values
+
+                    advantages[t] = (reward_term_sum + value_term_sum) / lam_coef_sum - values[t]
+                else:
+                    delta = rewards[t] + args.gamma * real_next_values - values[t]
+                    advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * next_not_done * lastgaelam  # Here we should actually use next_not_terminated, but we don't have lastgaelam if terminated
+            returns = advantages + values
+
+        # flatten the batch
+        b_obs = obs.reshape((-1,))
+        b_logprobs = logprobs.reshape(-1)
+        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
+        b_advantages = advantages.reshape(-1)
+        b_returns = returns.reshape(-1)
+        b_values = values.reshape(-1)
+
+        # Optimizing the policy and value network
+        agent.train()
+        b_inds = np.arange(args.batch_size)
+        clipfracs = []
+        for epoch in range(args.update_epochs):
+            np.random.shuffle(b_inds)
+            for start in range(0, args.batch_size, args.minibatch_size):
+                end = start + args.minibatch_size
+                mb_inds = b_inds[start:end]
+
+                _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds])
+                logratio = newlogprob - b_logprobs[mb_inds]
+                ratio = logratio.exp()
+
+                with torch.no_grad():
+                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
+                    old_approx_kl = (-logratio).mean()
+                    approx_kl = ((ratio - 1) - logratio).mean()
+                    clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]
+
+                mb_advantages = b_advantages[mb_inds]
+                if args.norm_adv:
+                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)
+
+                # Policy loss
+                pg_loss1 = -mb_advantages * ratio
+                pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
+                pg_loss = torch.max(pg_loss1, pg_loss2).mean()
+
+                # Value loss
+                newvalue = newvalue.view(-1)
+                if args.clip_vloss:
+                    v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
+                    v_clipped = b_values[mb_inds] + torch.clamp(
+                        newvalue - b_values[mb_inds],
+                        -args.clip_coef,
+                        args.clip_coef,
+                    )
+                    v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
+                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
+                    v_loss = 0.5 * v_loss_max.mean()
+                else:
+                    v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()
+
+                entropy_loss = entropy.mean()
+                loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
+
+                optimizer.zero_grad()
+                loss.backward()
+                nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
+                optimizer.step()
+
+            if args.target_kl is not None and approx_kl > args.target_kl:
+                break
+
+        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
+        var_y = np.var(y_true)
+        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
+
+        # TRY NOT TO MODIFY: record rewards for plotting purposes
+        writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
+        writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
+        writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
+        writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
+        writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step) + writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step) + writer.add_scalar("losses/explained_variance", explained_var, global_step) + print("SPS:", int(global_step / (time.time() - start_time))) + writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + if args.save_model: + model_path = f"runs/{run_name}/{args.exp_name}_final.cleanrl_model" + torch.save(agent.state_dict(), model_path) + print(f"model saved to {model_path}") + + envs.close() + writer.close() diff --git a/mani_skill2/utils/wrappers/flatten.py b/mani_skill2/utils/wrappers/flatten.py index 24c72bf4e..46d91092a 100644 --- a/mani_skill2/utils/wrappers/flatten.py +++ b/mani_skill2/utils/wrappers/flatten.py @@ -17,9 +17,10 @@ class FlattenRGBDObservationWrapper(gym.ObservationWrapper): Flattens the rgbd mode observations into a dictionary with two keys, "rgbd" and "state" """ - def __init__(self, env) -> None: + def __init__(self, env, rgb_only=False) -> None: self.base_env: BaseEnv = env.unwrapped super().__init__(env) + self.rgb_only = rgb_only new_obs = self.observation(self.base_env._init_raw_obs) self.base_env._update_obs_space(new_obs) @@ -29,7 +30,8 @@ def observation(self, observation: Dict): images = [] for cam_data in sensor_data.values(): images.append(cam_data["rgb"]) - images.append(cam_data["depth"]) + if not self.rgb_only: + images.append(cam_data["depth"]) images = torch.concat(images, axis=-1) # flatten the rest of the data which should just be state data observation = flatten_state_dict(observation, use_torch=True)