Adding Action Chunking with Transformers (ACT) to baselines #640

Open

wants to merge 34 commits into base: main

Commits (changes shown from 9 of 34 commits):
43d027b  Create train.py (ywchoi02, Oct 20, 2024)
cc0cb60  Create train_rgb.py (ywchoi02, Oct 20, 2024)
e154fe3  Create evaluate.py (ywchoi02, Oct 20, 2024)
2e126e7  Create make_env.py (ywchoi02, Oct 20, 2024)
7abc4cb  Create utils.py (ywchoi02, Oct 20, 2024)
f48fdc0  Create backbone.py (ywchoi02, Oct 20, 2024)
aaae4b7  Create detr_vae.py (ywchoi02, Oct 20, 2024)
07b5741  Create position_encoding.py (ywchoi02, Oct 20, 2024)
1f89c29  Create transformer.py (ywchoi02, Oct 20, 2024)
479841f  Update evaluate.py (ywchoi02, Oct 21, 2024)
bdd1660  Create README.md (ywchoi02, Oct 21, 2024)
09d5411  Changed to absolute import (ywchoi02, Oct 21, 2024)
622a36e  Update absolute import (ywchoi02, Oct 21, 2024)
4932d9b  change to absolute import (ywchoi02, Oct 21, 2024)
37eb2af  change to absolute import (ywchoi02, Oct 21, 2024)
e5ebaa4  fix import (ywchoi02, Oct 22, 2024)
e5e0aeb  fix import (ywchoi02, Oct 22, 2024)
35979e9  fix import (ywchoi02, Oct 22, 2024)
04972f3  Create setup.py (ywchoi02, Oct 22, 2024)
673d453  Merge branch 'haosulab:main' into main (ywchoi02, Nov 2, 2024)
4add469  Create examples.sh (ywchoi02, Nov 2, 2024)
ddcfb1b  Merge branch 'haosulab:main' into main (ywchoi02, Nov 17, 2024)
edeeea5  Update train.py (ywchoi02, Nov 17, 2024)
492bfbd  Update (vectorized) evaluate.py (ywchoi02, Nov 17, 2024)
7f5f636  Update examples.sh (ywchoi02, Nov 17, 2024)
3a954a1  Merge branch 'haosulab:main' into main (ywchoi02, Nov 21, 2024)
ea902fb  Merge branch 'haosulab:main' into main (ywchoi02, Nov 26, 2024)
ff53d8c  Update evaluate.py (incorporate visual_data) (ywchoi02, Nov 26, 2024)
a7989e9  Update backbone.py (for rgbd data) (ywchoi02, Nov 26, 2024)
d2b3eb6  Update detr_vae.py (for rgbd data) (ywchoi02, Nov 26, 2024)
881401f  Create train_rgbd.py (ywchoi02, Nov 26, 2024)
d7b33f0  Delete examples/baselines/act/train_rgb.py (ywchoi02, Nov 26, 2024)
663fd82  Update examples.sh (include rgbd examples) (ywchoi02, Nov 26, 2024)
8735baf  Update train_rgbd.py (bug fix) (ywchoi02, Nov 26, 2024)
examples/baselines/act/act/detr/backbone.py (122 additions, 0 deletions)
@@ -0,0 +1,122 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Backbone modules.
"""
from collections import OrderedDict

import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from typing import Dict, List

from ..utils import NestedTensor, is_main_process
Member comment:
Use absolute imports when possible; it is just the style choice this repo uses.

Author reply:
I changed them to absolute imports, but I'm not sure if they are correct. Please let me know if they need to be fixed. I also added a README file.
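For illustration, a hedged sketch of the conversion being requested, assuming the package root becomes importable as act (e.g., via the setup.py added later in this PR); the exact module paths are assumptions inferred from the layout examples/baselines/act/act/detr/:

# Relative imports, as written in this diff:
#   from ..utils import NestedTensor, is_main_process
#   from .position_encoding import build_position_encoding
# Absolute equivalents under the assumed package root `act`:
from act.utils import NestedTensor, is_main_process
from act.detr.position_encoding import build_position_encoding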


from .position_encoding import build_position_encoding

import IPython
e = IPython.embed

class FrozenBatchNorm2d(torch.nn.Module):
"""
BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.ops.misc with added eps before rsqrt,
    without which models other than torchvision.models.resnet[18,34,50,101]
    produce NaNs.
"""

def __init__(self, n):
super(FrozenBatchNorm2d, self).__init__()
self.register_buffer("weight", torch.ones(n))
self.register_buffer("bias", torch.zeros(n))
self.register_buffer("running_mean", torch.zeros(n))
self.register_buffer("running_var", torch.ones(n))

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
num_batches_tracked_key = prefix + 'num_batches_tracked'
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]

super(FrozenBatchNorm2d, self)._load_from_state_dict(
state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)

def forward(self, x):
# move reshapes to the beginning
# to make it fuser-friendly
w = self.weight.reshape(1, -1, 1, 1)
b = self.bias.reshape(1, -1, 1, 1)
rv = self.running_var.reshape(1, -1, 1, 1)
rm = self.running_mean.reshape(1, -1, 1, 1)
eps = 1e-5
scale = w * (rv + eps).rsqrt()
bias = b - rm * scale
return x * scale + bias


class BackboneBase(nn.Module):

def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
super().__init__()
# for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
# if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
# parameter.requires_grad_(False)
if return_interm_layers:
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
else:
return_layers = {'layer4': "0"}
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
self.num_channels = num_channels

def forward(self, tensor):
xs = self.body(tensor)
return xs
# out: Dict[str, NestedTensor] = {}
# for name, x in xs.items():
# m = tensor_list.mask
# assert m is not None
# mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
# out[name] = NestedTensor(x, mask)
# return out


class Backbone(BackboneBase):
"""ResNet backbone with frozen BatchNorm."""
def __init__(self, name: str,
train_backbone: bool,
return_interm_layers: bool,
dilation: bool):
backbone = getattr(torchvision.models, name)(
replace_stride_with_dilation=[False, False, dilation],
pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained # TODO do we want frozen batch_norm??
num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)


class Joiner(nn.Sequential):
def __init__(self, backbone, position_embedding):
super().__init__(backbone, position_embedding)

def forward(self, tensor_list: NestedTensor):
xs = self[0](tensor_list)
out: List[NestedTensor] = []
pos = []
for name, x in xs.items():
out.append(x)
# position encoding
pos.append(self[1](x).to(x.dtype))

return out, pos


def build_backbone(args):
position_embedding = build_position_encoding(args)
train_backbone = args.lr_backbone > 0
return_interm_layers = args.masks
backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
model = Joiner(backbone, position_embedding)
model.num_channels = backbone.num_channels
return model
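For context, a minimal usage sketch (not part of the diff) showing how build_backbone might be wired up. The field names mirror what build_backbone and build_position_encoding read from args; the concrete values and the act.detr.backbone import path are illustrative assumptions, and a torchvision version that still accepts the pretrained= kwarg used inside Backbone is assumed:

from types import SimpleNamespace
import torch
from act.detr.backbone import build_backbone  # assumed absolute import path

args = SimpleNamespace(
    backbone="resnet18",        # torchvision model name; 512 output channels
    lr_backbone=1e-5,           # > 0 marks the backbone as trainable
    masks=False,                # False: only layer4 features are returned
    dilation=False,             # keep the last ResNet stage undilated
    position_embedding="sine",  # 'sine'/'v2' or 'learned'/'v3'
    hidden_dim=256,             # split into 128 sine dims per spatial axis
)

model = build_backbone(args)
features, pos = model(torch.randn(1, 3, 128, 128))
print(features[0].shape)  # torch.Size([1, 512, 4, 4])
print(pos[0].shape)       # torch.Size([1, 256, 4, 4])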
examples/baselines/act/act/detr/detr_vae.py (139 additions, 0 deletions)
@@ -0,0 +1,139 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR model and criterion classes.
"""
import torch
from torch import nn
from .transformer import build_transformer, TransformerEncoder, TransformerEncoderLayer

import numpy as np

import IPython
e = IPython.embed


def reparametrize(mu, logvar):
    std = logvar.div(2).exp()
    eps = torch.randn_like(std)  # replaces the deprecated torch.autograd.Variable pattern
    return mu + std * eps


def get_sinusoid_encoding_table(n_position, d_hid):
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1

return torch.FloatTensor(sinusoid_table).unsqueeze(0)


class DETRVAE(nn.Module):
""" This is the DETR module that performs object detection """
def __init__(self, backbones, transformer, encoder, state_dim, action_dim, num_queries):
super().__init__()
self.num_queries = num_queries
self.transformer = transformer
self.encoder = encoder
hidden_dim = transformer.d_model
self.action_head = nn.Linear(hidden_dim, action_dim)
self.query_embed = nn.Embedding(num_queries, hidden_dim)
if backbones is not None:
self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
self.backbones = nn.ModuleList(backbones)
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
else:
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
self.backbones = None

# encoder extra parameters
self.latent_dim = 32 # size of latent z
self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
self.encoder_state_proj = nn.Linear(state_dim, hidden_dim) # project state to embedding
self.encoder_action_proj = nn.Linear(action_dim, hidden_dim) # project action to embedding
        self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2)  # project hidden state to latent mean and log-variance
self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim)) # [CLS], state, actions

# decoder extra parameters
self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for state and proprio

def forward(self, obs, actions=None):
is_training = actions is not None
state = obs['state'] if self.backbones is not None else obs
bs = state.shape[0]

if is_training:
# project CLS token, state sequence, and action sequence to embedding dim
cls_embed = self.cls_embed.weight # (1, hidden_dim)
cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
state_embed = self.encoder_state_proj(state) # (bs, hidden_dim)
state_embed = torch.unsqueeze(state_embed, axis=1) # (bs, 1, hidden_dim)
action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
# concat them together to form an input to the CVAE encoder
encoder_input = torch.cat([cls_embed, state_embed, action_embed], axis=1) # (bs, seq+2, hidden_dim)
encoder_input = encoder_input.permute(1, 0, 2) # (seq+2, bs, hidden_dim)
            # no padding mask is applied to any part of the CVAE encoder input
is_pad = torch.full((bs, encoder_input.shape[0]), False).to(state.device) # False: not a padding
# obtain position embedding
pos_embed = self.pos_table.clone().detach()
pos_embed = pos_embed.permute(1, 0, 2) # (seq+2, 1, hidden_dim)
# query CVAE encoder
encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad)
encoder_output = encoder_output[0] # take cls output only
latent_info = self.latent_proj(encoder_output)
mu = latent_info[:, :self.latent_dim]
logvar = latent_info[:, self.latent_dim:]
latent_sample = reparametrize(mu, logvar)
latent_input = self.latent_out_proj(latent_sample)
else:
mu = logvar = None
latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(state.device)
latent_input = self.latent_out_proj(latent_sample)

# CVAE decoder
if self.backbones is not None:
vis_data = obs['rgb'] if "rgb" in obs else obs['rgbd']
num_cams = vis_data.shape[1]

# Image observation features and position embeddings
all_cam_features = []
all_cam_pos = []
for cam_id in range(num_cams):
features, pos = self.backbones[0](vis_data[:, cam_id]) # HARDCODED
features = features[0] # take the last layer feature # (batch, hidden_dim, H, W)
pos = pos[0] # (1, hidden_dim, H, W)
all_cam_features.append(self.input_proj(features))
all_cam_pos.append(pos)

# proprioception features (state)
proprio_input = self.input_proj_robot_state(state)
# fold camera dimension into width dimension
            src = torch.cat(all_cam_features, axis=3)  # (batch, hidden_dim, H, num_cams*W)
            pos = torch.cat(all_cam_pos, axis=3)  # (1, hidden_dim, H, num_cams*W)
hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0] # (batch, num_queries, hidden_dim)
else:
state = self.input_proj_robot_state(state)
hs = self.transformer(None, None, self.query_embed.weight, None, latent_input, state, self.additional_pos_embed.weight)[0]

a_hat = self.action_head(hs)
return a_hat, [mu, logvar]


def build_encoder(args):
d_model = args.hidden_dim # 256
dropout = args.dropout # 0.1
nhead = args.nheads # 8
dim_feedforward = args.dim_feedforward # 2048
num_encoder_layers = args.enc_layers # 4 # TODO shared with VAE decoder
normalize_before = args.pre_norm # False
activation = "relu"

encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
dropout, activation, normalize_before)
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

return encoder
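For readers new to the CVAE setup, a minimal sketch (not part of the diff) of how the (a_hat, [mu, logvar]) outputs are typically consumed by an ACT-style training objective; the exact loss lives in this PR's train.py, and kl_weight below is an illustrative value:

import torch
import torch.nn.functional as F

def act_loss(a_hat, actions, mu, logvar, kl_weight=10.0):
    """L1 action reconstruction plus KL(q(z|s,a) || N(0, I)), following the ACT recipe."""
    recon = F.l1_loss(a_hat, actions)
    # closed-form KL divergence of a diagonal Gaussian from the standard normal
    kld = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(dim=-1).mean()
    return recon + kl_weight * kld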
examples/baselines/act/act/detr/position_encoding.py (93 additions, 0 deletions)
@@ -0,0 +1,93 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Various positional encodings for the transformer.
"""
import math
import torch
from torch import nn

from ..utils import NestedTensor

import IPython
e = IPython.embed

class PositionEmbeddingSine(nn.Module):
"""
    This is a more standard version of the position embedding, very similar to the
    one used in the "Attention Is All You Need" paper, generalized to work on images.
"""
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale

def forward(self, tensor):
x = tensor
# mask = tensor_list.mask
# assert mask is not None
# not_mask = ~mask

not_mask = torch.ones_like(x[0, [0]])
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos


class PositionEmbeddingLearned(nn.Module):
"""
Absolute pos embedding, learned.
"""
def __init__(self, num_pos_feats=256):
super().__init__()
self.row_embed = nn.Embedding(50, num_pos_feats)
self.col_embed = nn.Embedding(50, num_pos_feats)
self.reset_parameters()

def reset_parameters(self):
nn.init.uniform_(self.row_embed.weight)
nn.init.uniform_(self.col_embed.weight)

def forward(self, tensor_list: NestedTensor):
x = tensor_list.tensors
h, w = x.shape[-2:]
i = torch.arange(w, device=x.device)
j = torch.arange(h, device=x.device)
x_emb = self.col_embed(i)
y_emb = self.row_embed(j)
pos = torch.cat([
x_emb.unsqueeze(0).repeat(h, 1, 1),
y_emb.unsqueeze(1).repeat(1, w, 1),
], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
return pos


def build_position_encoding(args):
N_steps = args.hidden_dim // 2
if args.position_embedding in ('v2', 'sine'):
# TODO find a better way of exposing other arguments
position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
elif args.position_embedding in ('v3', 'learned'):
position_embedding = PositionEmbeddingLearned(N_steps)
else:
raise ValueError(f"not supported {args.position_embedding}")

return position_embedding
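As a sanity check (not part of the diff), a minimal sketch of PositionEmbeddingSine on a dummy backbone feature map; the act.detr.position_encoding import path is an assumption. Note that the returned encoding's batch dimension collapses to 1 regardless of the input batch size, which matches the (1, hidden_dim, H, W) comments in detr_vae.py where it is broadcast:

import torch
from act.detr.position_encoding import PositionEmbeddingSine  # assumed absolute import path

pe = PositionEmbeddingSine(num_pos_feats=128, normalize=True)  # num_pos_feats = hidden_dim // 2
feat = torch.randn(2, 512, 4, 8)  # (batch, channels, H, W) backbone features
pos = pe(feat)
# 128 sine/cosine dims per spatial axis, concatenated along the channel dimension
assert pos.shape == (1, 256, 4, 8)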