diff --git a/examples/baselines/act/README.md b/examples/baselines/act/README.md
new file mode 100644
index 000000000..24785f1b8
--- /dev/null
+++ b/examples/baselines/act/README.md
@@ -0,0 +1,96 @@
+# Action Chunking with Transformers (ACT)
+
+Code for running the ACT algorithm based on ["Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware"](https://arxiv.org/pdf/2304.13705). It is adapted from the [original code](https://github.com/tonyzhaozh/act).
+
+## Installation
+
+To get started, we recommend using conda/mamba to create a new environment and install the dependencies:
+
+```bash
+conda create -n act-ms python=3.9
+conda activate act-ms
+pip install -e .
+```
+
+## Demonstration Download and Preprocessing
+
+For fast downloads and smaller file sizes, ManiSkill demonstrations are stored by default in a highly reduced/compressed format that does not include observation data. Run the following commands to download the demonstrations and convert them to a format that includes observation data and the desired action space.
+
+```bash
+python -m mani_skill.utils.download_demo "PickCube-v1"
+```
+
+```bash
+python -m mani_skill.trajectory.replay_trajectory \
+  --traj-path ~/.maniskill/demos/PickCube-v1/motionplanning/trajectory.h5 \
+  --use-first-env-state -c pd_ee_delta_pos -o state \
+  --save-traj --num-procs 10 -b cpu
+```
+
+Set `-o rgbd` for RGBD observations. Note that the control mode can heavily influence how well Behavior Cloning performs. The ACT paper reports degraded performance when using delta joint positions as actions instead of target joint positions. By default, we recommend `pd_joint_delta_pos` as the control mode since all tasks can be solved with it, although it is harder to learn with BC than `pd_ee_delta_pos` or `pd_ee_delta_pose` for robots that support those control modes. Finally, the type of demonstration data can also impact performance: neural-network-generated demonstrations are typically easier to learn from than human- or motion-planning-generated ones.
+
+## Training
+
+We provide scripts to train ACT on demonstrations. Make sure to use the same sim backend as the one the demonstrations were collected with.
+
+Note that some demonstrations are slow (e.g. motion-planned or human-teleoperated) and can exceed the default max episode steps, which is a problem because imitation learning algorithms learn to solve the task at the same speed as the demonstrations. In this case, you can use the `--max-episode-steps` flag to set a higher value so that the policy has time to solve the task. A general recommendation is to set `--max-episode-steps` to about 2x the mean length of the demonstrations you are training on. We provide recommended values for the demonstrations in the `examples.sh` script.
+
+Example training run, learning from 100 demonstrations generated via motion planning on the PickCube-v1 task:
+```bash
+python train.py --env-id PickCube-v1 \
+  --demo-path ~/.maniskill/demos/PickCube-v1/motionplanning/trajectory.state.pd_ee_delta_pos.cpu.h5 \
+  --control-mode "pd_ee_delta_pos" --sim-backend "cpu" --num-demos 100 --max_episode_steps 100 \
+  --total_iters 30000
+```
+
+## Train and Evaluate with GPU Simulation
+
+You can also choose to train on trajectories generated in the GPU simulation and evaluate much faster with the GPU simulation.
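+
+Regardless of which backend you train with, it can help to sanity-check a converted dataset before pointing `train.py` at it. The replay tool saves a JSON metadata file next to the `.h5` file, and `train.py` reads the control mode from it in the same way. A minimal sketch (assumes `h5py` is installed; the path is the state dataset converted earlier):
+
+```python
+import json
+from pathlib import Path
+
+import h5py
+
+demo = Path("~/.maniskill/demos/PickCube-v1/motionplanning/"
+            "trajectory.state.pd_ee_delta_pos.cpu.h5").expanduser()
+
+# metadata recorded by the replay tool (env kwargs, per-episode info, ...)
+meta = json.loads(demo.with_suffix(".json").read_text())
+print(meta["env_info"]["env_kwargs"].get("control_mode"))  # e.g. pd_ee_delta_pos
+
+with h5py.File(demo, "r") as f:
+    print(len(f.keys()), "trajectories")  # keys look like traj_0, traj_1, ...
+```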
+However, as most demonstrations are generated with the CPU simulation (via motion planning or teleoperation), you may observe worse performance when evaluating with the GPU simulation than with the CPU simulation. This can be partially alleviated by using the replay trajectory tool to replay the trajectories in the GPU simulation.
+
+It is also recommended not to save videos if you are using a lot of parallel environments, as the videos can get very large.
+
+To replay trajectories in the GPU simulation, you can use the following command. Note that this can be a bit slow, as the replay trajectory tool is currently not optimized for GPU-parallelized environments.
+
+```bash
+python -m mani_skill.trajectory.replay_trajectory \
+  --traj-path ~/.maniskill/demos/PickCube-v1/motionplanning/trajectory.h5 \
+  --use-first-env-state -c pd_ee_delta_pos -o state \
+  --save-traj --num-procs 1 -b gpu --count 100 # process only 100 trajectories
+```
+
+Once the GPU-backend demonstration dataset is ready, you can use the following command to train and evaluate on the GPU simulation.
+
+```bash
+python train.py --env-id PickCube-v1 \
+  --demo-path ~/.maniskill/demos/PickCube-v1/motionplanning/trajectory.state.pd_ee_delta_pos.cuda.h5 \
+  --control-mode "pd_ee_delta_pos" --sim-backend "gpu" --num-demos 100 --max_episode_steps 100 \
+  --total_iters 30000 \
+  --num-eval-envs 100 --no-capture-video
+```
+
+## Citation
+
+If you use this baseline, please cite the following:
+```
+@inproceedings{DBLP:conf/rss/ZhaoKLF23,
+  author    = {Tony Z. Zhao and
+               Vikash Kumar and
+               Sergey Levine and
+               Chelsea Finn},
+  editor    = {Kostas E. Bekris and
+               Kris Hauser and
+               Sylvia L. Herbert and
+               Jingjin Yu},
+  title     = {Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware},
+  booktitle = {Robotics: Science and Systems XIX, Daegu, Republic of Korea, July
+               10-14, 2023},
+  year      = {2023},
+  url       = {https://doi.org/10.15607/RSS.2023.XIX.016},
+  doi       = {10.15607/RSS.2023.XIX.016},
+  timestamp = {Thu, 20 Jul 2023 15:37:49 +0200},
+  biburl    = {https://dblp.org/rec/conf/rss/ZhaoKLF23.bib},
+  bibsource = {dblp computer science bibliography, https://dblp.org}
+}
+```
diff --git a/examples/baselines/act/act/detr/backbone.py b/examples/baselines/act/act/detr/backbone.py
new file mode 100644
index 000000000..d1328e435
--- /dev/null
+++ b/examples/baselines/act/act/detr/backbone.py
@@ -0,0 +1,129 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+Backbone modules.
+"""
+from collections import OrderedDict
+
+import torch
+import torch.nn.functional as F
+import torchvision
+from torch import nn
+from torchvision.models._utils import IntermediateLayerGetter
+from typing import Dict, List
+
+from act.utils import NestedTensor, is_main_process
+from act.detr.position_encoding import build_position_encoding
+
+import IPython
+e = IPython.embed
+
+class FrozenBatchNorm2d(torch.nn.Module):
+    """
+    BatchNorm2d where the batch statistics and the affine parameters are fixed.
+
+    Copy-paste from torchvision.misc.ops with added eps before rsqrt,
+    without which models other than torchvision's resnet[18,34,50,101]
+    produce nans.
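+
+    The statistics and affine parameters are registered as buffers and folded
+    into a single per-channel scale and bias in forward(), so they are never
+    updated during training.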
+ """ + + def __init__(self, n): + super(FrozenBatchNorm2d, self).__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + num_batches_tracked_key = prefix + 'num_batches_tracked' + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super(FrozenBatchNorm2d, self)._load_from_state_dict( + state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + + def forward(self, x): + # move reshapes to the beginning + # to make it fuser-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + eps = 1e-5 + scale = w * (rv + eps).rsqrt() + bias = b - rm * scale + return x * scale + bias + + +class BackboneBase(nn.Module): + + def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): + super().__init__() + # for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this? + # if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: + # parameter.requires_grad_(False) + if return_interm_layers: + return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} + else: + return_layers = {'layer4': "0"} + self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) + self.num_channels = num_channels + + def forward(self, tensor): + xs = self.body(tensor) + return xs + # out: Dict[str, NestedTensor] = {} + # for name, x in xs.items(): + # m = tensor_list.mask + # assert m is not None + # mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] + # out[name] = NestedTensor(x, mask) + # return out + + +class Backbone(BackboneBase): + """ResNet backbone with frozen BatchNorm.""" + def __init__(self, name: str, + train_backbone: bool, + return_interm_layers: bool, + dilation: bool, + include_depth: bool): + backbone = getattr(torchvision.models, name)( + replace_stride_with_dilation=[False, False, dilation], + pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained # TODO do we want frozen batch_norm?? 
+ + # for rgbd data + if include_depth: + w = backbone.conv1.weight + w = torch.cat([w, torch.full((64, 1, 7, 7), 0)], dim=1) + backbone.conv1.weight = nn.Parameter(w) + + num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 + super().__init__(backbone, train_backbone, num_channels, return_interm_layers) + + +class Joiner(nn.Sequential): + def __init__(self, backbone, position_embedding): + super().__init__(backbone, position_embedding) + + def forward(self, tensor_list: NestedTensor): + xs = self[0](tensor_list) + out: List[NestedTensor] = [] + pos = [] + for name, x in xs.items(): + out.append(x) + # position encoding + pos.append(self[1](x).to(x.dtype)) + + return out, pos + + +def build_backbone(args): + position_embedding = build_position_encoding(args) + train_backbone = args.lr_backbone > 0 + return_interm_layers = args.masks + backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation, args.include_depth) + model = Joiner(backbone, position_embedding) + model.num_channels = backbone.num_channels + return model diff --git a/examples/baselines/act/act/detr/detr_vae.py b/examples/baselines/act/act/detr/detr_vae.py new file mode 100644 index 000000000..65c2b2811 --- /dev/null +++ b/examples/baselines/act/act/detr/detr_vae.py @@ -0,0 +1,141 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +DETR model and criterion classes. +""" +import torch +from torch import nn +from torch.autograd import Variable +from act.detr.transformer import build_transformer, TransformerEncoder, TransformerEncoderLayer + +import numpy as np + +import IPython +e = IPython.embed + + +def reparametrize(mu, logvar): + std = logvar.div(2).exp() + eps = Variable(std.data.new(std.size()).normal_()) + return mu + std * eps + + +def get_sinusoid_encoding_table(n_position, d_hid): + def get_position_angle_vec(position): + return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] + + sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + + return torch.FloatTensor(sinusoid_table).unsqueeze(0) + + +class DETRVAE(nn.Module): + """ This is the DETR module that performs object detection """ + def __init__(self, backbones, transformer, encoder, state_dim, action_dim, num_queries): + super().__init__() + self.num_queries = num_queries + self.transformer = transformer + self.encoder = encoder + hidden_dim = transformer.d_model + self.action_head = nn.Linear(hidden_dim, action_dim) + self.query_embed = nn.Embedding(num_queries, hidden_dim) + if backbones is not None: + self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1) + self.backbones = nn.ModuleList(backbones) + self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim) + else: + self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim) + self.backbones = None + + # encoder extra parameters + self.latent_dim = 32 # size of latent z + self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding + self.encoder_state_proj = nn.Linear(state_dim, hidden_dim) # project state to embedding + self.encoder_action_proj = nn.Linear(action_dim, hidden_dim) # project action to embedding + self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var + self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, 
hidden_dim)) # [CLS], state, actions + + # decoder extra parameters + self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding + self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for state and proprio + + def forward(self, obs, actions=None): + is_training = actions is not None + state = obs['state'] if self.backbones is not None else obs + bs = state.shape[0] + + if is_training: + # project CLS token, state sequence, and action sequence to embedding dim + cls_embed = self.cls_embed.weight # (1, hidden_dim) + cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim) + state_embed = self.encoder_state_proj(state) # (bs, hidden_dim) + state_embed = torch.unsqueeze(state_embed, axis=1) # (bs, 1, hidden_dim) + action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim) + # concat them together to form an input to the CVAE encoder + encoder_input = torch.cat([cls_embed, state_embed, action_embed], axis=1) # (bs, seq+2, hidden_dim) + encoder_input = encoder_input.permute(1, 0, 2) # (seq+2, bs, hidden_dim) + # no masking is applied to all parts of the CVAE encoder input + is_pad = torch.full((bs, encoder_input.shape[0]), False).to(state.device) # False: not a padding + # obtain position embedding + pos_embed = self.pos_table.clone().detach() + pos_embed = pos_embed.permute(1, 0, 2) # (seq+2, 1, hidden_dim) + # query CVAE encoder + encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad) + encoder_output = encoder_output[0] # take cls output only + latent_info = self.latent_proj(encoder_output) + mu = latent_info[:, :self.latent_dim] + logvar = latent_info[:, self.latent_dim:] + latent_sample = reparametrize(mu, logvar) + latent_input = self.latent_out_proj(latent_sample) + else: + mu = logvar = None + latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(state.device) + latent_input = self.latent_out_proj(latent_sample) + + # CVAE decoder + if self.backbones is not None: + vis_data = obs['rgb'] + if "depth" in obs: + vis_data = torch.cat([vis_data, obs['depth']], dim=2) + num_cams = vis_data.shape[1] + + # Image observation features and position embeddings + all_cam_features = [] + all_cam_pos = [] + for cam_id in range(num_cams): + features, pos = self.backbones[0](vis_data[:, cam_id]) # HARDCODED + features = features[0] # take the last layer feature # (batch, hidden_dim, H, W) + pos = pos[0] # (1, hidden_dim, H, W) + all_cam_features.append(self.input_proj(features)) + all_cam_pos.append(pos) + + # proprioception features (state) + proprio_input = self.input_proj_robot_state(state) + # fold camera dimension into width dimension + src = torch.cat(all_cam_features, axis=3) # (batch, hidden_dim, 4, 8) + pos = torch.cat(all_cam_pos, axis=3) # (batch, hidden_dim, 4, 8) + hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0] # (batch, num_queries, hidden_dim) + else: + state = self.input_proj_robot_state(state) + hs = self.transformer(None, None, self.query_embed.weight, None, latent_input, state, self.additional_pos_embed.weight)[0] + + a_hat = self.action_head(hs) + return a_hat, [mu, logvar] + + +def build_encoder(args): + d_model = args.hidden_dim # 256 + dropout = args.dropout # 0.1 + nhead = args.nheads # 8 + dim_feedforward = args.dim_feedforward # 2048 + num_encoder_layers = args.enc_layers # 4 # TODO shared with VAE decoder + normalize_before = 
args.pre_norm # False + activation = "relu" + + encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, + dropout, activation, normalize_before) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + return encoder diff --git a/examples/baselines/act/act/detr/position_encoding.py b/examples/baselines/act/act/detr/position_encoding.py new file mode 100644 index 000000000..f7585ccd3 --- /dev/null +++ b/examples/baselines/act/act/detr/position_encoding.py @@ -0,0 +1,93 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Various positional encodings for the transformer. +""" +import math +import torch +from torch import nn + +from act.utils import NestedTensor + +import IPython +e = IPython.embed + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, tensor): + x = tensor + # mask = tensor_list.mask + # assert mask is not None + # not_mask = ~mask + + not_mask = torch.ones_like(x[0, [0]]) + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class PositionEmbeddingLearned(nn.Module): + """ + Absolute pos embedding, learned. 
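+    Each row and column index (up to 50) has a trainable embedding; the code
+    for a pixel is the concatenation of its column and row embeddings.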
+ """ + def __init__(self, num_pos_feats=256): + super().__init__() + self.row_embed = nn.Embedding(50, num_pos_feats) + self.col_embed = nn.Embedding(50, num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.row_embed.weight) + nn.init.uniform_(self.col_embed.weight) + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + h, w = x.shape[-2:] + i = torch.arange(w, device=x.device) + j = torch.arange(h, device=x.device) + x_emb = self.col_embed(i) + y_emb = self.row_embed(j) + pos = torch.cat([ + x_emb.unsqueeze(0).repeat(h, 1, 1), + y_emb.unsqueeze(1).repeat(1, w, 1), + ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) + return pos + + +def build_position_encoding(args): + N_steps = args.hidden_dim // 2 + if args.position_embedding in ('v2', 'sine'): + # TODO find a better way of exposing other arguments + position_embedding = PositionEmbeddingSine(N_steps, normalize=True) + elif args.position_embedding in ('v3', 'learned'): + position_embedding = PositionEmbeddingLearned(N_steps) + else: + raise ValueError(f"not supported {args.position_embedding}") + + return position_embedding diff --git a/examples/baselines/act/act/detr/transformer.py b/examples/baselines/act/act/detr/transformer.py new file mode 100644 index 000000000..335a41406 --- /dev/null +++ b/examples/baselines/act/act/detr/transformer.py @@ -0,0 +1,313 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +DETR Transformer class. + +Copy-paste from torch.nn.Transformer with modifications: + * positional encodings are passed in MHattention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers +""" +import copy +from typing import Optional, List + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +import IPython +e = IPython.embed + +class Transformer(nn.Module): + + def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, + num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False, + return_intermediate_dec=False): + super().__init__() + + encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, + dropout, activation, normalize_before) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, + dropout, activation, normalize_before) + decoder_norm = nn.LayerNorm(d_model) + self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, + return_intermediate=return_intermediate_dec) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None): + if src is None: + bs = proprio_input.shape[0] + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) + pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim + src = torch.stack([latent_input, proprio_input], axis=0) + # TODO flatten only when input has H and W + elif len(src.shape) == 4: # has H and W + # flatten NxCxHxW to HWxNxC + bs, c, h, w = src.shape + src = src.flatten(2).permute(2, 0, 1) + pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1) + 
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) + # mask = mask.flatten(1) + + additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim + pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0) + + addition_input = torch.stack([latent_input, proprio_input], axis=0) + src = torch.cat([addition_input, src], axis=0) + + tgt = torch.zeros_like(query_embed) + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) + hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, + pos=pos_embed, query_pos=query_embed) + hs = hs.transpose(1, 2) + return hs + + +class TransformerEncoder(nn.Module): + + def __init__(self, encoder_layer, num_layers, norm=None): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward(self, src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + output = src + + for layer in self.layers: + output = layer(output, src_mask=mask, + src_key_padding_mask=src_key_padding_mask, pos=pos) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(nn.Module): + + def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + output = tgt + + intermediate = [] + + for layer in self.layers: + output = layer(output, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, query_pos=query_pos) + if self.return_intermediate: + intermediate.append(self.norm(output)) + + if self.norm is not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output.unsqueeze(0) + + +class TransformerEncoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = 
self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def forward_pre(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + def forward(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +class TransformerDecoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward_pre(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + 
self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) + return self.forward_post(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def build_transformer(args): + return Transformer( + d_model=args.hidden_dim, + dropout=args.dropout, + nhead=args.nheads, + dim_feedforward=args.dim_feedforward, + num_encoder_layers=args.enc_layers, + num_decoder_layers=args.dec_layers, + normalize_before=args.pre_norm, + return_intermediate_dec=True, + ) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") diff --git a/examples/baselines/act/act/evaluate.py b/examples/baselines/act/act/evaluate.py new file mode 100644 index 000000000..fcd684500 --- /dev/null +++ b/examples/baselines/act/act/evaluate.py @@ -0,0 +1,94 @@ +from collections import defaultdict +import gymnasium +import numpy as np +import torch + +from mani_skill.utils import common + +def evaluate(n: int, agent, eval_envs, eval_kwargs): + stats, num_queries, temporal_agg, max_timesteps, device, sim_backend = eval_kwargs.values() + + use_visual_obs = isinstance(eval_envs.single_observation_space.sample(), dict) + delta_control = not stats + if not delta_control: + pre_process = lambda s_obs: (s_obs - stats['state_mean'].cpu().numpy()) / stats['state_std'].cpu().numpy() + post_process = lambda a: a * stats['action_std'].cpu().numpy() + stats['action_mean'].cpu().numpy() + + # create action table for temporal ensembling + action_dim = eval_envs.action_space.shape[-1] + num_envs = eval_envs.num_envs + if temporal_agg: + query_frequency = 1 + all_time_actions = np.zeros([num_envs, max_timesteps, max_timesteps+num_queries, action_dim], dtype=np.float32) + else: + query_frequency = num_queries + actions_to_take = np.zeros([num_envs, num_queries, action_dim]) + + agent.eval() + with torch.no_grad(): + eval_metrics = defaultdict(list) + obs, info = eval_envs.reset() + ts, eps_count = 0, 0 + while eps_count < n: + # pre-process obs + if use_visual_obs: + obs['state'] = pre_process(obs['state']) if not delta_control else obs['state'] # (num_envs, obs_dim) + obs = {k: common.to_tensor(v, device) for k, v in obs.items()} + else: + obs = pre_process(obs) if not delta_control else obs # (num_envs, obs_dim) + obs = common.to_tensor(obs, device) + + # query policy + if ts % query_frequency == 0: + action_seq = agent.get_action(obs) # (num_envs, num_queries, action_dim) + if sim_backend == "cpu": + action_seq = action_seq.cpu().numpy() + + # we assume ignore_terminations=True. Otherwise, some envs could be done + # earlier, so we would need to temporally ensemble at corresponding timestep + # for each env. 
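+            # ACT temporal ensembling: the policy is queried every step, and the
+            # action executed at step ts averages all still-valid predictions for
+            # ts from the last num_queries queries, weighted by w_i = exp(-k * i)
+            # (index 0 = oldest prediction, so older chunks get larger weights).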
+ if temporal_agg: + assert query_frequency == 1, "query_frequency != 1 has not been implemented for temporal_agg==1." + all_time_actions[:, ts, ts:ts+num_queries] = action_seq # (num_envs, num_queries, act_dim) + actions_for_curr_step = all_time_actions[:, :, ts] # (num_envs, max_timesteps, act_dim) + # since we pad the action with 0 in 'delta_pos' control mode, this causes error. + #actions_populated = np.all(actions_for_curr_step[0] != 0, axis=1) # (max_timesteps,) + actions_populated = np.zeros(max_timesteps, dtype=bool) # (max_timesteps,) + actions_populated[max(0, ts + 1 - num_queries):ts+1] = True + actions_for_curr_step = actions_for_curr_step[:, actions_populated] # (num_envs, num_populated, act_dim) + k = 0.01 + if ts < num_queries: + exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step[0])), dtype=np.float32) # (num_populated,) + exp_weights = exp_weights / exp_weights.sum() # (num_populated,) + exp_weights = np.tile(exp_weights, (num_envs, 1)) # (num_envs, num_populated) + exp_weights = np.expand_dims(exp_weights, axis=-1) # (num_envs, num_populated, 1) + raw_action = (actions_for_curr_step * exp_weights).sum(axis=1) # (num_envs, act_dim) + else: + if ts % query_frequency == 0: + actions_to_take = action_seq + raw_action = actions_to_take[:, ts % query_frequency] + action = post_process(raw_action) if not delta_control else raw_action # (num_envs, act_dim) + + # step the environment + obs, rew, terminated, truncated, info = eval_envs.step(action) + ts += 1 + + # collect episode info + if truncated.any(): + assert truncated.all() == truncated.any(), "all episodes should truncate at the same time for fair evaluation with other algorithms" + if isinstance(info["final_info"], dict): + for k, v in info["final_info"]["episode"].items(): + eval_metrics[k].append(v.float().cpu().numpy()) + else: + for final_info in info["final_info"]: + for k, v in final_info["episode"].items(): + eval_metrics[k].append(v) + # new episodes begin + eps_count += num_envs + ts = 0 + all_time_actions = np.zeros([num_envs, max_timesteps, max_timesteps+num_queries, action_dim], dtype=np.float32) + + agent.train() + for k in eval_metrics.keys(): + eval_metrics[k] = np.stack(eval_metrics[k]) + return eval_metrics diff --git a/examples/baselines/act/act/make_env.py b/examples/baselines/act/act/make_env.py new file mode 100644 index 000000000..2d9162a26 --- /dev/null +++ b/examples/baselines/act/act/make_env.py @@ -0,0 +1,46 @@ +from typing import Optional +import gymnasium as gym +import mani_skill.envs +from mani_skill.utils import gym_utils +from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv +from mani_skill.utils.wrappers import RecordEpisode, CPUGymWrapper + + +def make_eval_envs(env_id, num_envs: int, sim_backend: str, env_kwargs: dict, other_kwargs: dict, video_dir: Optional[str] = None, wrappers: list[gym.Wrapper] = []): + """Create vectorized environment for evaluation and/or recording videos. + For CPU vectorized environments only the first parallel environment is used to record videos. + For GPU vectorized environments all parallel environments are used to record videos. + + Args: + env_id: the environment id + num_envs: the number of parallel environments + sim_backend: the simulation backend to use. can be "cpu" or "gpu + env_kwargs: the environment kwargs. You can also pass in max_episode_steps in env_kwargs to override the default max episode steps for the environment. + video_dir: the directory to save the videos. If None no videos are recorded. 
+ wrappers: the list of wrappers to apply to the environment. + """ + if sim_backend == "cpu": + def cpu_make_env(env_id, seed, video_dir=None, env_kwargs = dict(), other_kwargs = dict()): + def thunk(): + env = gym.make(env_id, reconfiguration_freq=1, **env_kwargs) + for wrapper in wrappers: + env = wrapper(env) + env = CPUGymWrapper(env, ignore_terminations=True, record_metrics=True) + if video_dir: + env = RecordEpisode(env, output_dir=video_dir, save_trajectory=False, info_on_video=True, source_type="act", source_desc="act evaluation rollout") + env.action_space.seed(seed) + env.observation_space.seed(seed) + return env + + return thunk + vector_cls = gym.vector.SyncVectorEnv if num_envs == 1 else lambda x : gym.vector.AsyncVectorEnv(x, context="forkserver") + env = vector_cls([cpu_make_env(env_id, seed, video_dir if seed == 0 else None, env_kwargs, other_kwargs) for seed in range(num_envs)]) + else: + env = gym.make(env_id, num_envs=num_envs, sim_backend=sim_backend, reconfiguration_freq=1, **env_kwargs) + max_episode_steps = gym_utils.find_max_episode_steps_value(env) + for wrapper in wrappers: + env = wrapper(env) + if video_dir: + env = RecordEpisode(env, output_dir=video_dir, save_trajectory=False, save_video=True, source_type="act", source_desc="act evaluation rollout", max_steps_per_video=max_episode_steps) + env = ManiSkillVectorEnv(env, ignore_terminations=True, record_metrics=True) + return env diff --git a/examples/baselines/act/act/utils.py b/examples/baselines/act/act/utils.py new file mode 100644 index 000000000..c6da02856 --- /dev/null +++ b/examples/baselines/act/act/utils.py @@ -0,0 +1,161 @@ +from torch.utils.data.sampler import Sampler +import numpy as np +import torch +import torch.distributed as dist +from torch import Tensor +from h5py import File, Group, Dataset +from typing import Optional + + +class NestedTensor(object): + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + +def is_main_process(): + return get_rank() == 0 + + +class IterationBasedBatchSampler(Sampler): + """Wraps a BatchSampler. 
Resamples from it until a specified number of iterations have been sampled.
+    References:
+        https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/data/samplers/iteration_based_batch_sampler.py
+    """
+
+    def __init__(self, batch_sampler, num_iterations, start_iter=0):
+        self.batch_sampler = batch_sampler
+        self.num_iterations = num_iterations
+        self.start_iter = start_iter
+
+    def __iter__(self):
+        iteration = self.start_iter
+        while iteration < self.num_iterations:
+            # if the underlying sampler has a set_epoch method, like
+            # DistributedSampler, used for making each process see
+            # a different split of the dataset, then set it
+            if hasattr(self.batch_sampler.sampler, "set_epoch"):
+                self.batch_sampler.sampler.set_epoch(iteration)
+            for batch in self.batch_sampler:
+                yield batch
+                iteration += 1
+                if iteration >= self.num_iterations:
+                    break
+
+    def __len__(self):
+        return self.num_iterations - self.start_iter
+
+
+def worker_init_fn(worker_id, base_seed=None):
+    """The function is designed for pytorch multi-process dataloader.
+    Note that we use the pytorch random generator to generate a base_seed.
+    Please try to be consistent.
+    References:
+        https://pytorch.org/docs/stable/notes/faq.html#dataloader-workers-random-seed
+    """
+    if base_seed is None:
+        base_seed = torch.IntTensor(1).random_().item()
+    # print(worker_id, base_seed)
+    np.random.seed(base_seed + worker_id)
+
+TARGET_KEY_TO_SOURCE_KEY = {
+    'states': 'env_states',
+    'observations': 'obs',
+    'success': 'success',
+    'next_observations': 'obs',
+    # 'dones': 'dones',
+    # 'rewards': 'rewards',
+    'actions': 'actions',
+}
+
+def load_content_from_h5_file(file):
+    if isinstance(file, (File, Group)):
+        return {key: load_content_from_h5_file(file[key]) for key in list(file.keys())}
+    elif isinstance(file, Dataset):
+        return file[()]
+    else:
+        raise NotImplementedError(f"Unsupported h5 file type: {type(file)}")
+
+def load_hdf5(path):
+    print('Loading HDF5 file', path)
+    file = File(path, 'r')
+    ret = load_content_from_h5_file(file)
+    file.close()
+    print('Loaded')
+    return ret
+
+def load_traj_hdf5(path, num_traj=None):
+    print('Loading HDF5 file', path)
+    file = File(path, 'r')
+    keys = list(file.keys())
+    if num_traj is not None:
+        assert num_traj <= len(keys), f"num_traj: {num_traj} > len(keys): {len(keys)}"
+    keys = sorted(keys, key=lambda x: int(x.split('_')[-1]))
+    keys = keys[:num_traj]
+    ret = {
+        key: load_content_from_h5_file(file[key]) for key in keys
+    }
+    file.close()
+    print('Loaded')
+    return ret
+
+def load_demo_dataset(path, keys=['observations', 'actions'], num_traj=None, concat=True):
+    # assert num_traj is None
+    raw_data = load_traj_hdf5(path, num_traj)
+    # raw_data has keys like: ['traj_0', 'traj_1', ...]
+    # raw_data['traj_0'] has keys like: ['actions', 'dones', 'env_states', 'infos', ...]
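+    # Note: each trajectory stores one more observation than actions (the final
+    # state is included), which is why 'observations' drops the last entry and
+    # 'next_observations' drops the first when concatenating below.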
+ _traj = raw_data['traj_0'] + for key in keys: + source_key = TARGET_KEY_TO_SOURCE_KEY[key] + assert source_key in _traj, f"key: {source_key} not in traj_0: {_traj.keys()}" + dataset = {} + for target_key in keys: + # if 'next' in target_key: + # raise NotImplementedError('Please carefully deal with the length of trajectory') + source_key = TARGET_KEY_TO_SOURCE_KEY[target_key] + dataset[target_key] = [ raw_data[idx][source_key] for idx in raw_data ] + if isinstance(dataset[target_key][0], np.ndarray) and concat: + if target_key in ['observations', 'states'] and \ + len(dataset[target_key][0]) > len(raw_data['traj_0']['actions']): + dataset[target_key] = np.concatenate([ + t[:-1] for t in dataset[target_key] + ], axis=0) + elif target_key in ['next_observations', 'next_states'] and \ + len(dataset[target_key][0]) > len(raw_data['traj_0']['actions']): + dataset[target_key] = np.concatenate([ + t[1:] for t in dataset[target_key] + ], axis=0) + else: + dataset[target_key] = np.concatenate(dataset[target_key], axis=0) + + print('Load', target_key, dataset[target_key].shape) + else: + print('Load', target_key, len(dataset[target_key]), type(dataset[target_key][0])) + return dataset diff --git a/examples/baselines/act/examples.sh b/examples/baselines/act/examples.sh new file mode 100644 index 000000000..b315a9b10 --- /dev/null +++ b/examples/baselines/act/examples.sh @@ -0,0 +1,91 @@ +### Example scripts for training ACT that have some results ### + +# Learning from motion planning generated demonstrations + +# PickCube-v1 + +# state +python -m mani_skill.trajectory.replay_trajectory \ + --traj-path ~/.maniskill/demos/PickCube-v1/motionplanning/trajectory.h5 \ + --use-first-env-state -c pd_ee_delta_pos -o state \ + --save-traj --num-procs 10 -b cpu + +python train.py --env-id PickCube-v1 \ + --demo-path ~/.maniskill/demos/PickCube-v1/motionplanning/trajectory.state.pd_ee_delta_pos.cpu.h5 \ + --control-mode "pd_ee_delta_pos" --sim-backend "cpu" --max_episode_steps 100 --total_iters 5000 --save_freq 5000 + +# rgbd +python -m mani_skill.trajectory.replay_trajectory \ + --traj-path ~/.maniskill/demos/PickCube-v1/motionplanning/trajectory.h5 \ + --use-first-env-state -c pd_ee_delta_pos -o rgbd \ + --save-traj --num-procs 10 -b cpu + +python train_rgbd.py --env-id PickCube-v1 \ + --demo-path ~/.maniskill/demos/PickCube-v1/motionplanning/trajectory.rgbd.pd_ee_delta_pos.cpu.h5 \ + --control-mode "pd_ee_delta_pos" --sim-backend "cpu" --max_episode_steps 100 --total_iters 20000 --save_freq 20000 + +# PushCube-v1 + +# state +python -m mani_skill.trajectory.replay_trajectory \ + --traj-path ~/.maniskill/demos/PushCube-v1/motionplanning/trajectory.h5 \ + --use-first-env-state -c pd_ee_delta_pos -o state \ + --save-traj --num-procs 10 -b cpu + +python train.py --env-id PushCube-v1 \ + --demo-path ~/.maniskill/demos/PushCube-v1/motionplanning/trajectory.state.pd_ee_delta_pos.cpu.h5 \ + --control-mode "pd_ee_delta_pos" --sim-backend "cpu" --max_episode_steps 100 --total_iters 20000 --save_freq 20000 + +# rgbd +python -m mani_skill.trajectory.replay_trajectory \ + --traj-path ~/.maniskill/demos/PushCube-v1/motionplanning/trajectory.h5 \ + --use-first-env-state -c pd_ee_delta_pos -o rgbd \ + --save-traj --num-procs 10 -b cpu + +python train.py --env-id PushCube-v1 \ + --demo-path ~/.maniskill/demos/PushCube-v1/motionplanning/trajectory.rgbd.pd_ee_delta_pos.cpu.h5 \ + --control-mode "pd_ee_delta_pos" --sim-backend "cpu" --max_episode_steps 100 --total_iters 100000 --save_freq 100000 + +# StackCube-v1 + +# state 
+python -m mani_skill.trajectory.replay_trajectory \ + --traj-path ~/.maniskill/demos/StackCube-v1/motionplanning/trajectory.h5 \ + --use-first-env-state -c pd_ee_delta_pos -o state \ + --save-traj --num-procs 10 -b cpu + +python train.py --env-id StackCube-v1 \ + --demo-path ~/.maniskill/demos/StackCube-v1/motionplanning/trajectory.state.pd_ee_delta_pos.cpu.h5 \ + --control-mode "pd_ee_delta_pos" --sim-backend "cpu" --max_episode_steps 200 --total_iters 30000 --save_freq 30000 + +# rgbd +python -m mani_skill.trajectory.replay_trajectory \ + --traj-path ~/.maniskill/demos/StackCube-v1/motionplanning/trajectory.h5 \ + --use-first-env-state -c pd_ee_delta_pos -o rgbd \ + --save-traj --num-procs 10 -b cpu + +python train.py --env-id StackCube-v1 \ + --demo-path ~/.maniskill/demos/StackCube-v1/motionplanning/trajectory.rgbd.pd_ee_delta_pos.cpu.h5 \ + --control-mode "pd_ee_delta_pos" --sim-backend "cpu" --max_episode_steps 200 --total_iters 100000 --save_freq 100000 + +# PegInsertionSide-v1 + +# state +python -m mani_skill.trajectory.replay_trajectory \ + --traj-path ~/.maniskill/demos/PegInsertionSide-v1/motionplanning/trajectory.h5 \ + --use-first-env-state -c pd_ee_delta_pose -o state \ + --save-traj --num-procs 10 -b cpu + +python train.py --env-id PegInsertionSide-v1 \ + --demo-path ~/.maniskill/demos/PegInsertionSide-v1/motionplanning/trajectory.state.pd_ee_delta_pose.cpu.h5 \ + --control-mode "pd_ee_delta_pose" --sim-backend "cpu" --max_episode_steps 300 --total_iters 300000 --save_freq 300000 + +# rgbd +python -m mani_skill.trajectory.replay_trajectory \ + --traj-path ~/.maniskill/demos/PegInsertionSide-v1/motionplanning/trajectory.h5 \ + --use-first-env-state -c pd_ee_delta_pose -o rgbd \ + --save-traj --num-procs 10 -b cpu + +python train.py --env-id PegInsertionSide-v1 \ + --demo-path ~/.maniskill/demos/PegInsertionSide-v1/motionplanning/trajectory.rgbd.pd_ee_delta_pose.cpu.h5 \ + --control-mode "pd_ee_delta_pose" --sim-backend "cpu" --max_episode_steps 300 --total_iters 1000000 --save_freq 1000000 diff --git a/examples/baselines/act/setup.py b/examples/baselines/act/setup.py new file mode 100644 index 000000000..960b2ec1b --- /dev/null +++ b/examples/baselines/act/setup.py @@ -0,0 +1,17 @@ +from setuptools import setup, find_packages + +setup( + name="act", + version="0.1.0", + packages=find_packages(), + install_requires=[ + "torchvision", + "diffusers", + "tensorboard", + "wandb", + "mani_skill" + ], + description="A minimal setup for ACT for ManiSkill", + long_description=open("README.md").read(), + long_description_content_type="text/markdown", +) diff --git a/examples/baselines/act/train.py b/examples/baselines/act/train.py new file mode 100644 index 000000000..3f9fdd52c --- /dev/null +++ b/examples/baselines/act/train.py @@ -0,0 +1,452 @@ +ALGO_NAME = 'BC_ACT_state' + +import argparse +import os +import random +from distutils.util import strtobool +import time +import gymnasium as gym +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +import torchvision.transforms as T +from torch.utils.tensorboard import SummaryWriter +from act.evaluate import evaluate +from mani_skill.utils import common, gym_utils +from mani_skill.utils.registration import REGISTERED_ENVS + +from collections import defaultdict + +from torch.utils.data.dataset import Dataset +from torch.utils.data.sampler import RandomSampler, BatchSampler +from torch.utils.data.dataloader import DataLoader +from act.utils import IterationBasedBatchSampler, 
worker_init_fn +from act.make_env import make_eval_envs +from diffusers.training_utils import EMAModel +from act.detr.transformer import build_transformer +from act.detr.detr_vae import build_encoder, DETRVAE +from dataclasses import dataclass, field +from typing import Optional, List +import tyro + +@dataclass +class Args: + exp_name: Optional[str] = None + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + torch_deterministic: bool = True + """if toggled, `torch.backends.cudnn.deterministic=False`""" + cuda: bool = True + """if toggled, cuda will be enabled by default""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "ManiSkill" + """the wandb's project name""" + wandb_entity: Optional[str] = None + """the entity (team) of wandb's project""" + capture_video: bool = True + """whether to capture videos of the agent performances (check out `videos` folder)""" + + env_id: str = "PickCube-v1" + """the id of the environment""" + demo_path: str = 'pickcube.trajectory.state.pd_joint_delta_pos.cpu.h5' + """the path of demo dataset (pkl or h5)""" + num_demos: Optional[int] = None + """number of trajectories to load from the demo dataset""" + total_iters: int = 1_000_000 + """total timesteps of the experiment""" + batch_size: int = 1024 + """the batch size of sample from the replay memory""" + + # ACT specific arguments + lr: float = 1e-4 + """the learning rate of the Action Chunking with Transformers""" + kl_weight: float = 10 + """weight for the kl loss term""" + temporal_agg: bool = True + """if toggled, temporal ensembling will be performed""" + + # Backbone + position_embedding: str = 'sine' + backbone: str = 'resnet18' + lr_backbone: float = 1e-5 + masks: bool = False + dilation: bool = False + + # Transformer + enc_layers: int = 4 + dec_layers: int = 7 + dim_feedforward: int = 1600 + hidden_dim: int = 512 + dropout: float = 0.1 + nheads: int = 8 + num_queries: int = 30 + pre_norm: bool = False + + # Environment/experiment specific arguments + max_episode_steps: Optional[int] = None + """Change the environments' max_episode_steps to this value. Sometimes necessary if the demonstrations being imitated are too short. Typically the default + max episode steps of environments in ManiSkill are tuned lower so reinforcement learning agents can learn faster.""" + log_freq: int = 1000 + """the frequency of logging the training metrics""" + eval_freq: int = 5000 + """the frequency of evaluating the agent on the evaluation environments""" + save_freq: Optional[int] = None + """the frequency of saving the model checkpoints. By default this is None and will only save checkpoints based on the best evaluation metrics.""" + num_eval_episodes: int = 100 + """the number of episodes to evaluate the agent on""" + num_eval_envs: int = 10 + """the number of parallel environments to evaluate the agent on""" + sim_backend: str = "cpu" + """the simulation backend to use for evaluation environments. can be "cpu" or "gpu""" + num_dataload_workers: int = 0 + """the number of workers to use for loading the training data in the torch dataloader""" + control_mode: str = 'pd_joint_delta_pos' + """the control mode to use for the evaluation environments. 
Must match the control mode of the demonstration dataset."""
+
+    # additional tags/configs for logging purposes to wandb and shared comparisons with other algorithms
+    demo_type: Optional[str] = None
+
+
+class SmallDemoDataset_ACTPolicy(Dataset): # Load everything into GPU memory
+    def __init__(self, data_path, num_queries, device, num_traj):
+        if data_path[-4:] == '.pkl':
+            raise NotImplementedError()
+        else:
+            from act.utils import load_demo_dataset
+            trajectories = load_demo_dataset(data_path, num_traj=num_traj, concat=False)
+            # trajectories['observations'] is a list of np.ndarray (L+1, obs_dim)
+            # trajectories['actions'] is a list of np.ndarray (L, act_dim)
+
+        for k, v in trajectories.items():
+            for i in range(len(v)):
+                trajectories[k][i] = torch.Tensor(v[i]).to(device)
+
+        # When the robot reaches the goal state, its joints and gripper fingers need to remain stationary
+        if 'delta_pos' in args.control_mode or args.control_mode == 'base_pd_joint_vel_arm_pd_joint_vel':
+            self.pad_action_arm = torch.zeros((trajectories['actions'][0].shape[1]-1,), device=device)
+            # to make the arm stay still, we pad the action with 0 in 'delta_pos' control mode
+            # gripper action needs to be copied from the last action
+        # else:
+        #     raise NotImplementedError(f'Control Mode {args.control_mode} not supported')
+
+        self.slices = []
+        self.num_traj = len(trajectories['actions'])
+        for traj_idx in range(self.num_traj):
+            episode_len = trajectories['actions'][traj_idx].shape[0]
+            self.slices += [
+                (traj_idx, ts) for ts in range(episode_len)
+            ]
+
+        print(f"Length of Dataset: {len(self.slices)}")
+
+        self.num_queries = num_queries
+        self.trajectories = trajectories
+        self.delta_control = 'delta' in args.control_mode
+        self.norm_stats = self.get_norm_stats() if not self.delta_control else None
+
+    def __getitem__(self, index):
+        traj_idx, ts = self.slices[index]
+
+        # get observation at ts only
+        obs = self.trajectories['observations'][traj_idx][ts]
+        # get num_queries actions
+        act_seq = self.trajectories['actions'][traj_idx][ts:ts+self.num_queries]
+        action_len = act_seq.shape[0]
+
+        # Pad after the trajectory, so all the observations are utilized in training
+        if action_len < self.num_queries:
+            if 'delta_pos' in args.control_mode or args.control_mode == 'base_pd_joint_vel_arm_pd_joint_vel':
+                gripper_action = act_seq[-1, -1]
+                pad_action = torch.cat((self.pad_action_arm, gripper_action[None]), dim=0)
+                act_seq = torch.cat([act_seq, pad_action.repeat(self.num_queries-action_len, 1)], dim=0)
+                # making the robot (arm and gripper) stay still
+            elif not self.delta_control:
+                target = act_seq[-1]
+                act_seq = torch.cat([act_seq, target.repeat(self.num_queries-action_len, 1)], dim=0)
+
+        # normalize obs and act_seq
+        if not self.delta_control:
+            obs = (obs - self.norm_stats["state_mean"][0]) / self.norm_stats["state_std"][0]
+            act_seq = (act_seq - self.norm_stats["action_mean"]) / self.norm_stats["action_std"]
+
+        return {
+            'observations': obs,
+            'actions': act_seq,
+        }
+
+    def __len__(self):
+        return len(self.slices)
+
+    def get_norm_stats(self):
+        # compute dataset-wide mean/std of states and actions, used to normalize
+        # inputs and targets for absolute (non-delta) control modes; std is
+        # clipped away from zero so constant dimensions do not blow up
+        all_states = torch.cat(self.trajectories['observations'], dim=0)
+        all_actions = torch.cat(self.trajectories['actions'], dim=0)
+
+        state_mean = all_states.mean(dim=0, keepdim=True)
+        state_std = torch.clip(all_states.std(dim=0, keepdim=True), 1e-2, np.inf)
+        action_mean = all_actions.mean(dim=0, keepdim=True)
+        action_std = torch.clip(all_actions.std(dim=0, keepdim=True), 1e-2, np.inf)
+
+        return {
+            "state_mean": state_mean, "state_std": state_std,
+            "action_mean": action_mean, "action_std": action_std,
+        }
+
+
+class Agent(nn.Module):
+    def __init__(self, env, args):
+        super().__init__()
+        assert len(env.single_observation_space.shape) == 1 # (obs_dim,)
+        assert len(env.single_action_space.shape) == 1 # (act_dim,)
+        #assert (env.single_action_space.high == 1).all() and (env.single_action_space.low == -1).all()
+
+        self.kl_weight = args.kl_weight
+        self.state_dim = env.single_observation_space.shape[0]
+        self.act_dim = env.single_action_space.shape[0]
+
+        # CNN backbone
+        backbones = None
+
+        # CVAE decoder
+        transformer = build_transformer(args)
+
+        # CVAE encoder
+        encoder = build_encoder(args)
+
+        # ACT ( CVAE encoder + (CNN backbones + CVAE decoder) )
+        self.model = DETRVAE(
+            backbones,
+            transformer,
+            encoder,
+            state_dim=self.state_dim,
+            action_dim=self.act_dim,
+            num_queries=args.num_queries,
+        )
+
+    def compute_loss(self, obs, action_seq):
+        # forward pass
+        a_hat, (mu, logvar) = self.model(obs, action_seq)
+
+        # compute l1 loss and kl loss
+        total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
+        all_l1 = F.l1_loss(action_seq, a_hat, reduction='none')
+        l1 = all_l1.mean()
+
+        # store all loss
+        loss_dict = dict()
+        loss_dict['l1'] = l1
+        loss_dict['kl'] = total_kld[0]
+        loss_dict['loss'] = loss_dict['l1'] + loss_dict['kl'] * self.kl_weight
+        return loss_dict
+
+    def get_action(self, obs):
+        # forward pass
+        a_hat, (_, _) = self.model(obs) # no action, sample from prior
+        return a_hat
+
+
+def kl_divergence(mu, logvar):
+    batch_size = mu.size(0)
+    assert batch_size != 0
+    if mu.data.ndimension() == 4:
+        mu = mu.view(mu.size(0), mu.size(1))
+    if logvar.data.ndimension() == 4:
+        logvar = logvar.view(logvar.size(0), logvar.size(1))
+
+    klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
+    total_kld = klds.sum(1).mean(0, True)
+    dimension_wise_kld = klds.mean(0)
+    mean_kld = klds.mean(1).mean(0, True)
+
+    return total_kld, dimension_wise_kld, mean_kld
+
+def save_ckpt(run_name, tag):
+    os.makedirs(f'runs/{run_name}/checkpoints', exist_ok=True)
+    ema.copy_to(ema_agent.parameters())
+    torch.save({
+        'norm_stats': dataset.norm_stats,
+        'agent': agent.state_dict(),
+        'ema_agent': ema_agent.state_dict(),
+    }, f'runs/{run_name}/checkpoints/{tag}.pt')
+
+if __name__ == "__main__":
+    args = tyro.cli(Args)
+    if args.exp_name is None:
+        args.exp_name = os.path.basename(__file__)[: -len(".py")]
+        run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
+    else:
+        run_name = args.exp_name
+
+    if args.demo_path.endswith('.h5'):
+        import json
+        json_file = args.demo_path[:-2] + 'json'
+        with open(json_file, 'r') as f:
+            demo_info = json.load(f)
+        if 'control_mode' in demo_info['env_info']['env_kwargs']:
+            control_mode = demo_info['env_info']['env_kwargs']['control_mode']
+        elif 'control_mode' in demo_info['episodes'][0]:
+            control_mode = 
demo_info['episodes'][0]['control_mode'] + else: + raise Exception('Control mode not found in json') + assert control_mode == args.control_mode, f"Control mode mismatched. Dataset has control mode {control_mode}, but args has control mode {args.control_mode}" + + # TRY NOT TO MODIFY: seeding + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = args.torch_deterministic + + device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") + + # env setup + env_kwargs = dict(control_mode=args.control_mode, reward_mode="sparse", obs_mode="state", render_mode="rgb_array") + if args.max_episode_steps is not None: + env_kwargs["max_episode_steps"] = args.max_episode_steps + other_kwargs = None + envs = make_eval_envs(args.env_id, args.num_eval_envs, args.sim_backend, env_kwargs, other_kwargs, video_dir=f'runs/{run_name}/videos' if args.capture_video else None) + + if args.track: + import wandb + config = vars(args) + config["eval_env_cfg"] = dict(**env_kwargs, num_envs=args.num_eval_envs, env_id=args.env_id, env_horizon=gym_utils.find_max_episode_steps_value(envs)) + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=config, + name=run_name, + save_code=True, + group="ACT", + tags=["act"] + ) + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # dataloader setup + dataset = SmallDemoDataset_ACTPolicy(args.demo_path, args.num_queries, device, num_traj=args.num_demos) + sampler = RandomSampler(dataset, replacement=False) + batch_sampler = BatchSampler(sampler, batch_size=args.batch_size, drop_last=True) + batch_sampler = IterationBasedBatchSampler(batch_sampler, args.total_iters) + train_dataloader = DataLoader( + dataset, + batch_sampler=batch_sampler, + num_workers=args.num_dataload_workers, + worker_init_fn=lambda worker_id: worker_init_fn(worker_id, base_seed=args.seed), + ) + if args.num_demos is None: + args.num_demos = len(dataset) + + # agent setup + agent = Agent(envs, args).to(device) + + # optimizer setup + param_dicts = [ + {"params": [p for n, p in agent.named_parameters() if "backbone" not in n and p.requires_grad]}, + { + "params": [p for n, p in agent.named_parameters() if "backbone" in n and p.requires_grad], + "lr": args.lr_backbone, + }, + ] + optimizer = optim.AdamW(param_dicts, lr=args.lr, weight_decay=1e-4) + + # LR drop by a factor of 10 after lr_drop iters + lr_drop = int((2/3)*args.total_iters) + lr_scheduler = optim.lr_scheduler.StepLR(optimizer, lr_drop) + + # Exponential Moving Average + # accelerates training and improves stability + # holds a copy of the model weights + ema = EMAModel(parameters=agent.parameters(), power=0.75) + ema_agent = Agent(envs, args).to(device) + + # Evaluation + eval_kwargs = dict( + stats=dataset.norm_stats, num_queries=args.num_queries, temporal_agg=args.temporal_agg, + max_timesteps=gym_utils.find_max_episode_steps_value(envs), device=device, sim_backend=args.sim_backend + ) + + # ---------------------------------------------------------------------------- # + # Training begins. 
+    # ---------------------------------------------------------------------------- #
+    agent.train()
+
+    best_eval_metrics = defaultdict(float)
+    timings = defaultdict(float)
+
+    for iteration, data_batch in enumerate(train_dataloader):
+        cur_iter = iteration + 1
+
+        # forward and compute loss
+        loss_dict = agent.compute_loss(
+            obs=data_batch['observations'], # (B, obs_dim)
+            action_seq=data_batch['actions'], # (B, num_queries, act_dim)
+        )
+        total_loss = loss_dict['loss'] # total_loss = l1 + kl * self.kl_weight
+
+        # backward
+        optimizer.zero_grad()
+        total_loss.backward()
+        optimizer.step()
+        lr_scheduler.step() # step lr scheduler every batch, this is different from standard pytorch behavior
+        last_tick = time.time()
+
+        # update Exponential Moving Average of the model weights
+        ema.step(agent.parameters())
+        # TRY NOT TO MODIFY: record losses for plotting purposes
+        if cur_iter % args.log_freq == 0:
+            print(f"Iteration {cur_iter}, loss: {total_loss.item()}")
+            writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], cur_iter)
+            writer.add_scalar("losses/total_loss", total_loss.item(), cur_iter)
+            for k, v in timings.items():
+                writer.add_scalar(f"time/{k}", v, cur_iter)
+
+        # Evaluation
+        if cur_iter % args.eval_freq == 0:
+            last_tick = time.time()
+
+            ema.copy_to(ema_agent.parameters())
+
+            eval_metrics = evaluate(args.num_eval_episodes, ema_agent, envs, eval_kwargs)
+            timings["eval"] += time.time() - last_tick
+
+            print(f"Evaluated {len(eval_metrics['success_at_end'])} episodes")
+            for k in eval_metrics.keys():
+                eval_metrics[k] = np.mean(eval_metrics[k])
+                writer.add_scalar(f"eval/{k}", eval_metrics[k], cur_iter)
+                print(f"{k}: {eval_metrics[k]:.4f}")
+
+            save_on_best_metrics = ["success_once", "success_at_end"]
+            for k in save_on_best_metrics:
+                if k in eval_metrics and eval_metrics[k] > best_eval_metrics[k]:
+                    best_eval_metrics[k] = eval_metrics[k]
+                    save_ckpt(run_name, f"best_eval_{k}")
+                    print(f'New best {k}_rate: {eval_metrics[k]:.4f}. 
Saving checkpoint.')
+
+        # Checkpoint
+        if args.save_freq is not None and cur_iter % args.save_freq == 0:
+            save_ckpt(run_name, str(cur_iter))
+
+    envs.close()
+    writer.close()
diff --git a/examples/baselines/act/train_rgbd.py b/examples/baselines/act/train_rgbd.py
new file mode 100644
index 000000000..335c41aa0
--- /dev/null
+++ b/examples/baselines/act/train_rgbd.py
@@ -0,0 +1,607 @@
+ALGO_NAME = 'BC_ACT_rgbd'
+
+import argparse
+import os
+import random
+from distutils.util import strtobool
+from functools import partial
+import time
+import gymnasium as gym
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+import torchvision.transforms as T
+from torch.utils.tensorboard import SummaryWriter
+from act.evaluate import evaluate
+from mani_skill.envs.sapien_env import BaseEnv
+from mani_skill.utils import common, gym_utils
+from mani_skill.utils.registration import REGISTERED_ENVS
+
+from collections import defaultdict
+
+from torch.utils.data.dataset import Dataset
+from torch.utils.data.sampler import RandomSampler, BatchSampler
+from torch.utils.data.dataloader import DataLoader
+from act.utils import IterationBasedBatchSampler, worker_init_fn
+from act.make_env import make_eval_envs
+from diffusers.training_utils import EMAModel
+from act.detr.backbone import build_backbone
+from act.detr.transformer import build_transformer
+from act.detr.detr_vae import build_encoder, DETRVAE
+from dataclasses import dataclass, field
+from typing import Optional, List, Dict
+import tyro
+
+@dataclass
+class Args:
+    exp_name: Optional[str] = None
+    """the name of this experiment"""
+    seed: int = 1
+    """seed of the experiment"""
+    torch_deterministic: bool = True
+    """if toggled, `torch.backends.cudnn.deterministic` is set to True"""
+    cuda: bool = True
+    """if toggled, cuda will be enabled by default"""
+    track: bool = False
+    """if toggled, this experiment will be tracked with Weights and Biases"""
+    wandb_project_name: str = "ManiSkill"
+    """the wandb's project name"""
+    wandb_entity: Optional[str] = None
+    """the entity (team) of wandb's project"""
+    capture_video: bool = True
+    """whether to capture videos of the agent performances (check out `videos` folder)"""
+
+    env_id: str = "PickCube-v1"
+    """the id of the environment"""
+    demo_path: str = 'pickcube.trajectory.rgbd.pd_joint_delta_pos.cpu.h5'
+    """the path of demo dataset (pkl or h5)"""
+    num_demos: Optional[int] = None
+    """number of trajectories to load from the demo dataset"""
+    total_iters: int = 1_000_000
+    """total number of training iterations of the experiment"""
+    batch_size: int = 256
+    """the batch size sampled from the demonstration dataset"""
+
+    # ACT specific arguments
+    lr: float = 1e-4
+    """the learning rate of the Action Chunking with Transformers model (excluding the backbone)"""
+    kl_weight: float = 10
+    """weight for the KL divergence loss term"""
+    temporal_agg: bool = True
+    """if toggled, temporal ensembling will be performed"""
+
+    # Backbone
+    position_embedding: str = 'sine'
+    backbone: str = 'resnet18'
+    lr_backbone: float = 1e-5
+    masks: bool = False
+    dilation: bool = False
+    include_depth: bool = True
+
+    # Transformer
+    enc_layers: int = 4
+    dec_layers: int = 7
+    dim_feedforward: int = 1600
+    hidden_dim: int = 512
+    dropout: float = 0.1
+    nheads: int = 8
+    num_queries: int = 30
+    pre_norm: bool = False
+
+    # Environment/experiment specific arguments
+    max_episode_steps: Optional[int] = None
+    """Change the environments' max_episode_steps to this value. 
Sometimes necessary if the demonstrations being imitated are too short. Typically the default
+    max episode steps of environments in ManiSkill are tuned lower so reinforcement learning agents can learn faster."""
+    log_freq: int = 1000
+    """the frequency of logging the training metrics"""
+    eval_freq: int = 5000
+    """the frequency of evaluating the agent on the evaluation environments"""
+    save_freq: Optional[int] = None
+    """the frequency of saving the model checkpoints. By default this is None and will only save checkpoints based on the best evaluation metrics."""
+    num_eval_episodes: int = 100
+    """the number of episodes to evaluate the agent on"""
+    num_eval_envs: int = 10
+    """the number of parallel environments to evaluate the agent on"""
+    sim_backend: str = "cpu"
+    """the simulation backend to use for evaluation environments. Can be "cpu" or "gpu"."""
+    num_dataload_workers: int = 0
+    """the number of workers to use for loading the training data in the torch dataloader"""
+    control_mode: str = 'pd_joint_delta_pos'
+    """the control mode to use for the evaluation environments. Must match the control mode of the demonstration dataset."""
+
+    # additional tags/configs for logging purposes to wandb and shared comparisons with other algorithms
+    demo_type: Optional[str] = None
+
+
+class FlattenRGBDObservationWrapper(gym.ObservationWrapper):
+    """
+    Flattens the rgbd mode observations into a dictionary with separate image ("rgb"/"depth") and "state" keys
+
+    Args:
+        rgb (bool): Whether to include rgb images in the observation
+        depth (bool): Whether to include depth images in the observation
+        state (bool): Whether to include state data in the observation
+
+    Note that the returned observations will have a "rgb" and/or "depth" key depending on the rgb/depth bool flags. 
"""
+
+    def __init__(self, env, rgb=True, depth=True, state=True) -> None:
+        self.base_env: BaseEnv = env.unwrapped
+        super().__init__(env)
+        self.include_rgb = rgb
+        self.include_depth = depth
+        self.include_state = state
+        self.transforms = T.Compose(
+            [
+                T.Resize((224, 224), antialias=True),
+            ]
+        ) # resize the input image to 224x224, the minimum input size expected by the pre-trained torchvision backbones
+        new_obs = self.observation(self.base_env._init_raw_obs)
+        self.base_env.update_obs_space(new_obs)
+
+    def observation(self, observation: Dict):
+        sensor_data = observation.pop("sensor_data")
+        del observation["sensor_param"]
+        images_rgb = []
+        images_depth = []
+        for cam_data in sensor_data.values():
+            if self.include_rgb:
+                resized_rgb = self.transforms(
+                    cam_data["rgb"].permute(0, 3, 1, 2)
+                ) # (N, 3, 224, 224)
+                images_rgb.append(resized_rgb)
+            if self.include_depth:
+                depth = (cam_data["depth"].to(torch.float32) / 1024).to(torch.float16)
+                resized_depth = self.transforms(
+                    depth.permute(0, 3, 1, 2)
+                ) # (N, 1, 224, 224)
+                images_depth.append(resized_depth)
+
+        if self.include_rgb:
+            rgb = torch.stack(images_rgb, dim=1) # (N, num_cams, 3, 224, 224), uint8
+        if self.include_depth:
+            depth = torch.stack(images_depth, dim=1) # (N, num_cams, 1, 224, 224), float16
+
+        # flatten the rest of the data which should just be state data
+        observation = common.flatten_state_dict(observation, use_torch=True)
+        ret = dict()
+        if self.include_state:
+            ret["state"] = observation
+        if self.include_rgb and not self.include_depth:
+            ret["rgb"] = rgb
+        elif self.include_rgb and self.include_depth:
+            ret["rgb"] = rgb
+            ret["depth"] = depth
+        elif self.include_depth and not self.include_rgb:
+            ret["depth"] = depth
+        return ret
+
+
+class SmallDemoDataset_ACTPolicy(Dataset): # Load everything into memory
+    def __init__(self, data_path, num_queries, num_traj, include_depth=True):
+        if data_path[-4:] == '.pkl':
+            raise NotImplementedError()
+        else:
+            from act.utils import load_demo_dataset
+            trajectories = load_demo_dataset(data_path, num_traj=num_traj, concat=False)
+            # trajectories['observations'] is a list of np.ndarray (L+1, obs_dim)
+            # trajectories['actions'] is a list of np.ndarray (L, act_dim)
+            print('Raw trajectories loaded, starting to pre-process the observations...')
+
+        self.include_depth = include_depth
+        self.transforms = T.Compose(
+            [
+                T.Resize((224, 224), antialias=True),
+            ]
+        ) # pre-trained models from torchvision.models expect input image to be at least 224x224
+
+        # Pre-process the observations, make them align with the obs returned by the FlattenRGBDObservationWrapper
+        obs_traj_dict_list = []
+        for obs_traj_dict in trajectories['observations']:
+            obs_traj_dict = self.process_obs(obs_traj_dict)
+            obs_traj_dict_list.append(obs_traj_dict)
+        trajectories['observations'] = obs_traj_dict_list
+        self.obs_keys = list(obs_traj_dict.keys())
+
+        # Pre-process the actions
+        for i in range(len(trajectories['actions'])):
+            trajectories['actions'][i] = torch.Tensor(trajectories['actions'][i])
+        print('Obs/action pre-processing is done.')
+
+        # When the robot reaches the goal state, its joints and gripper fingers need to remain stationary
+        if 'delta_pos' in args.control_mode or args.control_mode == 'base_pd_joint_vel_arm_pd_joint_vel':
+            self.pad_action_arm = torch.zeros((trajectories['actions'][0].shape[1]-1,))
+            # to make the arm stay still, we pad the action with 0 in 'delta_pos' control mode
+            # gripper action needs to be copied from the last action
+        # else:
+        #     raise NotImplementedError(f'Control Mode {args.control_mode} not supported')
+
+        self.slices = []
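+        # Each (traj_idx, ts) pair appended to self.slices below marks a chunk
+        # start: every timestep of every demo can begin a num_queries-long action chunk.
+        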
self.num_traj = len(trajectories['actions']) + for traj_idx in range(self.num_traj): + episode_len = trajectories['actions'][traj_idx].shape[0] + self.slices += [ + (traj_idx, ts) for ts in range(episode_len) + ] + + print(f"Length of Dataset: {len(self.slices)}") + + self.num_queries = num_queries + self.trajectories = trajectories + self.delta_control = 'delta' in args.control_mode + self.norm_stats = self.get_norm_stats() if not self.delta_control else None + + def __getitem__(self, index): + traj_idx, ts = self.slices[index] + + # get state at start_ts only + state = self.trajectories['observations'][traj_idx]['state'][ts] + # get num_queries actions + act_seq = self.trajectories['actions'][traj_idx][ts:ts+self.num_queries] + action_len = act_seq.shape[0] + + # Pad after the trajectory, so all the observations are utilized in training + if action_len < self.num_queries: + if 'delta_pos' in args.control_mode or args.control_mode == 'base_pd_joint_vel_arm_pd_joint_vel': + gripper_action = act_seq[-1, -1] + pad_action = torch.cat((self.pad_action_arm, gripper_action[None]), dim=0) + act_seq = torch.cat([act_seq, pad_action.repeat(self.num_queries-action_len, 1)], dim=0) + # making the robot (arm and gripper) stay still + elif not self.delta_control: + target = act_seq[-1] + act_seq = torch.cat([act_seq, target.repeat(self.num_queries-action_len, 1)], dim=0) + + # normalize state and act_seq + if not self.delta_control: + state = (state - self.norm_stats["state_mean"][0]) / self.norm_stats["state_std"][0] + act_seq = (act_seq - self.norm_stats["action_mean"]) / self.norm_stats["action_std"] + + # get rgb or rgbd data at start_ts and combine with state to form obs + if self.include_depth: + rgb = self.trajectories['observations'][traj_idx]['rgb'][ts] + depth = self.trajectories['observations'][traj_idx]['depth'][ts] + obs = dict(state=state, rgb=rgb, depth=depth) + else: + rgb = self.trajectories['observations'][traj_idx]['rgb'][ts] + obs = dict(state=state, rgb=rgb) + + return { + 'observations': obs, + 'actions': act_seq, + } + + def __len__(self): + return len(self.slices) + + def process_obs(self, obs_dict): + # get rgbd data + sensor_data = obs_dict.pop("sensor_data") + del obs_dict["sensor_param"] + images_rgb = [] + images_depth = [] + for cam_data in sensor_data.values(): + rgb = torch.from_numpy(cam_data["rgb"]) # (ep_len, H, W, 3) + resized_rgb = self.transforms( + rgb.permute(0, 3, 1, 2) + ) # (ep_len, 3, 224, 224); pre-trained models from torchvision.models expect input image to be at least 224x224 + images_rgb.append(resized_rgb) + if self.include_depth: + depth = torch.Tensor(cam_data["depth"].astype(np.float32) / 1024).to(torch.float16) # (ep_len, H, W, 1) + resized_depth = self.transforms( + depth.permute(0, 3, 1, 2) + ) # (ep_len, 1, 224, 224); pre-trained models from torchvision.models expect input image to be at least 224x224 + images_depth.append(resized_depth) + rgb = torch.stack(images_rgb, dim=1) # (ep_len, num_cams, 3, 224, 224) # still uint8 + if self.include_depth: + depth = torch.stack(images_depth, dim=1) # (ep_len, num_cams, 1, 224, 224) # float16 + + # flatten the rest of the data which should just be state data + obs_dict['extra'] = {k: v[:, None] if len(v.shape) == 1 else v for k, v in obs_dict['extra'].items()} # dirty fix for data that has one dimension (e.g. 
is_grasped) + obs_dict = common.flatten_state_dict(obs_dict, use_torch=True) + + processed_obs = dict(state=obs_dict, rgb=rgb, depth=depth) if self.include_depth else dict(state=obs_dict, rgb=rgb) + + return processed_obs + + def get_norm_stats(self): + all_state_data = [] + all_action_data = [] + for traj_idx, ts in self.slices: + state = self.trajectories['observations'][traj_idx]['state'][ts] + act_seq = self.trajectories['actions'][traj_idx][ts:ts+self.num_queries] + action_len = act_seq.shape[0] + if action_len < self.num_queries: + target_pos = act_seq[-1] + act_seq = torch.cat([act_seq, target_pos.repeat(self.num_queries-action_len, 1)], dim=0) + all_state_data.append(state) + all_action_data.append(act_seq) + + all_state_data = torch.stack(all_state_data) + all_action_data = torch.concatenate(all_action_data) + + # normalize obs (state) data + state_mean = all_state_data.mean(dim=0, keepdim=True) + state_std = all_state_data.std(dim=0, keepdim=True) + state_std = torch.clip(state_std, 1e-2, np.inf) # clipping + + # normalize action data + action_mean = all_action_data.mean(dim=0, keepdim=True) + action_std = all_action_data.std(dim=0, keepdim=True) + action_std = torch.clip(action_std, 1e-2, np.inf) # clipping + + stats = {"action_mean": action_mean, "action_std": action_std, + "state_mean": state_mean, "state_std": state_std, + "example_state": state} + + return stats + + +class Agent(nn.Module): + def __init__(self, env, args): + super().__init__() + assert len(env.single_observation_space['state'].shape) == 1 # (obs_dim,) + assert len(env.single_observation_space['rgb'].shape) == 4 # (num_cams, C, H, W) + assert len(env.single_action_space.shape) == 1 # (act_dim,) + #assert (env.single_action_space.high == 1).all() and (env.single_action_space.low == -1).all() + + self.state_dim = env.single_observation_space['state'].shape[0] + self.act_dim = env.single_action_space.shape[0] + self.kl_weight = args.kl_weight + self.normalize = T.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + + # CNN backbone + backbones = [] + backbone = build_backbone(args) + backbones.append(backbone) + + # CVAE decoder + transformer = build_transformer(args) + + # CVAE encoder + encoder = build_encoder(args) + + # ACT ( CVAE encoder + (CNN backbones + CVAE decoder) ) + self.model = DETRVAE( + backbones, + transformer, + encoder, + state_dim=self.state_dim, + action_dim=self.act_dim, + num_queries=args.num_queries, + ) + + def compute_loss(self, obs, action_seq): + # normalize rgb data + obs['rgb'] = obs['rgb'].float() / 255.0 + obs['rgb'] = self.normalize(obs['rgb']) + + # depth data + if args.include_depth: + obs['depth'] = obs['depth'].float() + + # forward pass + a_hat, (mu, logvar) = self.model(obs, action_seq) + + # compute l1 loss and kl loss + total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar) + all_l1 = F.l1_loss(action_seq, a_hat, reduction='none') + l1 = all_l1.mean() + + # store all loss + loss_dict = dict() + loss_dict['l1'] = l1 + loss_dict['kl'] = total_kld[0] + loss_dict['loss'] = loss_dict['l1'] + loss_dict['kl'] * self.kl_weight + return loss_dict + + def get_action(self, obs): + # normalize rgb data + obs['rgb'] = obs['rgb'].float() / 255.0 + obs['rgb'] = self.normalize(obs['rgb']) + + # depth data + if args.include_depth: + obs['depth'] = obs['depth'].float() + + # forward pass + a_hat, (_, _) = self.model(obs) # no action, sample from prior + + return a_hat + + +def kl_divergence(mu, logvar): + batch_size = mu.size(0) + assert batch_size != 0 + if 
mu.data.ndimension() == 4: + mu = mu.view(mu.size(0), mu.size(1)) + if logvar.data.ndimension() == 4: + logvar = logvar.view(logvar.size(0), logvar.size(1)) + + klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()) + total_kld = klds.sum(1).mean(0, True) + dimension_wise_kld = klds.mean(0) + mean_kld = klds.mean(1).mean(0, True) + + return total_kld, dimension_wise_kld, mean_kld + +def save_ckpt(run_name, tag): + os.makedirs(f'runs/{run_name}/checkpoints', exist_ok=True) + ema.copy_to(ema_agent.parameters()) + torch.save({ + 'norm_stats': dataset.norm_stats, + 'agent': agent.state_dict(), + 'ema_agent': ema_agent.state_dict(), + }, f'runs/{run_name}/checkpoints/{tag}.pt') + +if __name__ == "__main__": + args = tyro.cli(Args) + + if args.exp_name is None: + args.exp_name = os.path.basename(__file__)[: -len(".py")] + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + else: + run_name = args.exp_name + + if args.demo_path.endswith('.h5'): + import json + json_file = args.demo_path[:-2] + 'json' + with open(json_file, 'r') as f: + demo_info = json.load(f) + if 'control_mode' in demo_info['env_info']['env_kwargs']: + control_mode = demo_info['env_info']['env_kwargs']['control_mode'] + elif 'control_mode' in demo_info['episodes'][0]: + control_mode = demo_info['episodes'][0]['control_mode'] + else: + raise Exception('Control mode not found in json') + assert control_mode == args.control_mode, f"Control mode mismatched. Dataset has control mode {control_mode}, but args has control mode {args.control_mode}" + + # TRY NOT TO MODIFY: seeding + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = args.torch_deterministic + + device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") + + # env setup + env_kwargs = dict(control_mode=args.control_mode, reward_mode="sparse", obs_mode="rgbd" if args.include_depth else "rgb", render_mode="rgb_array") + if args.max_episode_steps is not None: + env_kwargs["max_episode_steps"] = args.max_episode_steps + other_kwargs = None + wrappers = [partial(FlattenRGBDObservationWrapper, depth=args.include_depth)] + envs = make_eval_envs(args.env_id, args.num_eval_envs, args.sim_backend, env_kwargs, other_kwargs, video_dir=f'runs/{run_name}/videos' if args.capture_video else None, wrappers=wrappers) + + if args.track: + import wandb + config = vars(args) + config["eval_env_cfg"] = dict(**env_kwargs, num_envs=args.num_eval_envs, env_id=args.env_id, env_horizon=gym_utils.find_max_episode_steps_value(envs)) + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=config, + name=run_name, + save_code=True, + group="ACT", + tags=["act"] + ) + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # dataloader setup + dataset = SmallDemoDataset_ACTPolicy(args.demo_path, args.num_queries, num_traj=args.num_demos, include_depth=args.include_depth) + sampler = RandomSampler(dataset, replacement=False) + batch_sampler = BatchSampler(sampler, batch_size=args.batch_size, drop_last=True) + batch_sampler = IterationBasedBatchSampler(batch_sampler, args.total_iters) + train_dataloader = DataLoader( + dataset, + batch_sampler=batch_sampler, + num_workers=args.num_dataload_workers, + worker_init_fn=lambda worker_id: worker_init_fn(worker_id, base_seed=args.seed), + ) + if 
args.num_demos is None:
+        args.num_demos = len(dataset)
+
+    # agent setup
+    agent = Agent(envs, args).to(device)
+
+    # optimizer setup
+    param_dicts = [
+        {"params": [p for n, p in agent.named_parameters() if "backbone" not in n and p.requires_grad]},
+        {
+            "params": [p for n, p in agent.named_parameters() if "backbone" in n and p.requires_grad],
+            "lr": args.lr_backbone,
+        },
+    ]
+    optimizer = optim.AdamW(param_dicts, lr=args.lr, weight_decay=1e-4)
+
+    # LR drop by a factor of 10 after lr_drop iters
+    lr_drop = int((2/3)*args.total_iters)
+    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, lr_drop)
+
+    # Exponential Moving Average
+    # accelerates training and improves stability
+    # holds a copy of the model weights
+    ema = EMAModel(parameters=agent.parameters(), power=0.75)
+    ema_agent = Agent(envs, args).to(device)
+
+    # Evaluation
+    eval_kwargs = dict(
+        stats=dataset.norm_stats, num_queries=args.num_queries, temporal_agg=args.temporal_agg,
+        max_timesteps=gym_utils.find_max_episode_steps_value(envs), device=device, sim_backend=args.sim_backend
+    )
+
+    # ---------------------------------------------------------------------------- #
+    # Training begins.
+    # ---------------------------------------------------------------------------- #
+    agent.train()
+
+    best_eval_metrics = defaultdict(float)
+    timings = defaultdict(float)
+
+    for iteration, data_batch in enumerate(train_dataloader):
+        cur_iter = iteration + 1
+
+        # move data to the training device
+        obs_batch_dict = data_batch['observations']
+        obs_batch_dict = {k: v.to(device, non_blocking=True) for k, v in obs_batch_dict.items()}
+        act_batch = data_batch['actions'].to(device, non_blocking=True)
+
+        # forward and compute loss
+        loss_dict = agent.compute_loss(
+            obs=obs_batch_dict, # obs_batch_dict['state'] is (B, obs_dim)
+            action_seq=act_batch, # (B, num_queries, act_dim)
+        )
+        total_loss = loss_dict['loss'] # total_loss = l1 + kl * self.kl_weight
+
+        # backward
+        optimizer.zero_grad()
+        total_loss.backward()
+        optimizer.step()
+        lr_scheduler.step() # step lr scheduler every batch, this is different from standard pytorch behavior
+        last_tick = time.time()
+
+        # update Exponential Moving Average of the model weights
+        ema.step(agent.parameters())
+        # TRY NOT TO MODIFY: record losses for plotting purposes
+        if cur_iter % args.log_freq == 0:
+            print(f"Iteration {cur_iter}, loss: {total_loss.item()}")
+            writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], cur_iter)
+            writer.add_scalar("charts/backbone_learning_rate", optimizer.param_groups[1]["lr"], cur_iter)
+            writer.add_scalar("losses/total_loss", total_loss.item(), cur_iter)
+            for k, v in timings.items():
+                writer.add_scalar(f"time/{k}", v, cur_iter)
+
+        # Evaluation
+        if cur_iter % args.eval_freq == 0:
+            last_tick = time.time()
+
+            ema.copy_to(ema_agent.parameters())
+
+            eval_metrics = evaluate(args.num_eval_episodes, ema_agent, envs, eval_kwargs)
+            timings["eval"] += time.time() - last_tick
+
+            print(f"Evaluated {len(eval_metrics['success_at_end'])} episodes")
+            for k in eval_metrics.keys():
+                eval_metrics[k] = np.mean(eval_metrics[k])
+                writer.add_scalar(f"eval/{k}", eval_metrics[k], cur_iter)
+                print(f"{k}: {eval_metrics[k]:.4f}")
+
+            save_on_best_metrics = ["success_once", "success_at_end"]
+            for k in save_on_best_metrics:
+                if k in eval_metrics and eval_metrics[k] > best_eval_metrics[k]:
+                    best_eval_metrics[k] = eval_metrics[k]
+                    save_ckpt(run_name, f"best_eval_{k}")
+                    print(f'New best {k}_rate: {eval_metrics[k]:.4f}. 
Saving checkpoint.') + + # Checkpoint + if args.save_freq is not None and cur_iter % args.save_freq == 0: + save_ckpt(run_name, str(cur_iter)) + + envs.close() + writer.close()
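+
+# Example invocation (hypothetical paths; flags correspond to the Args dataclass above):
+#   python train_rgbd.py --env-id PickCube-v1 \
+#     --demo-path ~/.maniskill/demos/PickCube-v1/motionplanning/trajectory.rgbd.pd_joint_delta_pos.cpu.h5 \
+#     --control-mode "pd_joint_delta_pos" --sim-backend "cpu" --num-demos 100 \
+#     --max_episode_steps 100 --total_iters 30000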