Add context transformer
Flova committed Oct 27, 2024
1 parent 0f6fd8b commit b6ab2ce
Showing 6 changed files with 444 additions and 462 deletions.
116 changes: 116 additions & 0 deletions ddlitlab2024/ml/run_diffusion_context_transformer_robot.py
@@ -0,0 +1,116 @@
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import torch
import torch.nn.functional as F # noqa
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from ema_pytorch import EMA
from ddlitlab2024.ml.train_diffusion_context_transformer_robot import TrajectoryTransformerModel

# Check if CUDA is available and set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Define hyperparameters
hidden_dim = 256
num_layers = 4
num_heads = 4
sequence_length = 100
train_timesteps = 1000
action_context_length = 100
trajectory_prediction_length = 10
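# Note: these hyperparameters are assumed to match the ones used in
# train_diffusion_context_transformer_robot.py when the checkpoint below was trained.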

# Names of the commanded joints (the time column of the data is dropped)
joints = [
    "LHipYaw",
    "RHipYaw",
    "LHipRoll",
    "RHipRoll",
    "LHipPitch",
    "RHipPitch",
    "LKnee",
    "RKnee",
    "LAnklePitch",
    "RAnklePitch",
    "LAnkleRoll",
    "RAnkleRoll",
]
trajectory_dim = len(joints)
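# With the 12 leg joints above, trajectory_dim is 12, matching the 3x4 subplot grid used for plotting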

# Initialize the Transformer model and move it to the device
model = TrajectoryTransformerModel(
    num_joints=trajectory_dim,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
    num_heads=num_heads,
    max_action_context_length=action_context_length,
    trajectory_prediction_length=trajectory_prediction_length,
).to(device)
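# Keep an exponential moving average (EMA) of the model weights; the EMA copy is what gets
# loaded below and used for sampling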
ema = EMA(model, beta=0.9999)

# DDIM noise scheduler; the squared-cosine beta schedule is assumed to match the training setup
scheduler = DDIMScheduler(beta_schedule="squaredcos_cap_v2")
scheduler.config.num_train_timesteps = train_timesteps

# Load the trained EMA weights (map_location lets a GPU-trained checkpoint load on CPU)
ema.load_state_dict(torch.load("trajectory_transformer_model.pth", map_location=device))


# Sampling a new trajectory after training
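# The trajectory is generated autoregressively in chunks: each chunk of trajectory_prediction_length
# steps is denoised from pure noise, conditioned on up to action_context_length previously generated
# steps, and then appended to the context.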
def sample_trajectory(length=1000, diffusion_steps=15):
    scheduler.set_timesteps(diffusion_steps)

    # Start with an empty context; it grows as chunks are generated
    context = torch.zeros(1, 0, trajectory_dim).to(device)

    for _ in range(length // trajectory_prediction_length):
        # Each chunk starts from Gaussian noise
        sampled_trajectory = torch.randn(1, trajectory_prediction_length, trajectory_dim).to(device)

        plt.figure(figsize=(12, 6))

        for t in scheduler.timesteps:
            with torch.no_grad():
                # Predict the noise residual, conditioned on the most recent context
                noise_pred = ema(
                    context[:, -min(action_context_length, context.size(1)) :, :],
                    sampled_trajectory,
                    torch.tensor([t], device=device),
                )

                # Update the trajectory with the scheduler's reverse diffusion step
                sampled_trajectory = scheduler.step(noise_pred, t, sampled_trajectory).prev_sample

            # Plot the sampled trajectory
            for j in range(trajectory_dim):
                plt.subplot(3, 4, j + 1)
                color = cm.viridis(t / scheduler.timesteps[0])
                plt.plot(
                    torch.arange(trajectory_prediction_length) + context.size(1),
                    sampled_trajectory[0, :, j].cpu(),
                    label=f"Step {t}",
                    color=color,
                )
                # Scale the y-axis to the range of the training data
                plt.ylim(-1, 1)
                plt.title(f"Joint {joints[j]}")

        # Plot the context and the sampled trajectory
        for j in range(trajectory_dim):
            plt.subplot(3, 4, j + 1)
            plt.plot(context[0, :, j].cpu(), label="Context")
            plt.title(f"Joint {joints[j]}")

        plt.xlabel("Time")
        plt.ylabel("Amplitude")
        plt.legend()
        plt.show()

        context = torch.cat([context, sampled_trajectory], dim=1)

    # Plot the sampled trajectory
    plt.figure(figsize=(12, 6))
    for j in range(trajectory_dim):
        plt.subplot(3, 4, j + 1)
        plt.plot(context[0, :, j].cpu(), label="Sampled Trajectory")
        # Scale the y-axis to the range of the training data
        plt.ylim(-1, 1)
        plt.title(f"Joint {joints[j]}")
        plt.xlabel("Time")
        plt.ylabel("Amplitude")
        plt.legend()
    plt.show()


# Generate and plot 20 trajectories with the trained model
for _ in range(20):
    sample_trajectory()
72 changes: 1 addition & 71 deletions ddlitlab2024/ml/run_diffusion_transformer.py
@@ -1,84 +1,14 @@
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F # noqa
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from ema_pytorch import EMA
from torch import nn
from ddlitlab2024.ml.train_diffusion_transformer import TrajectoryTransformerModel

# Check if CUDA is available and set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class TrajectoryTransformerModel(nn.Module):
    def __init__(self, num_joints, hidden_dim, num_layers, num_heads, max_seq_len):
        super().__init__()
        self.embedding = nn.Linear(num_joints, hidden_dim)
        self.positional_encoding = PositionalEncoding(hidden_dim, max_seq_len + 1)
        self.step_encoding = StepToken(hidden_dim, device=device)
        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=hidden_dim,
                nhead=num_heads,
                dim_feedforward=hidden_dim,
                batch_first=True,
                norm_first=True,
                activation="gelu",
            ),
            num_layers=num_layers,
        )
        self.fc_out = nn.Linear(hidden_dim, num_joints)

    def forward(self, x, step):
        # x shape: (batch_size, seq_len, joint, num_bins)
        # Flatten the joint and bin dimensions into a single token dimension
        x = x.view(x.size(0), x.size(1), -1)
        # Embed the input
        x = self.embedding(x)
        # Positional encoding
        x += self.positional_encoding(x)
        # Add token for the step
        x = torch.cat([self.step_encoding(step), x], dim=1)
        # Memory tensor (not used)
        memory = torch.zeros(x.size(0), 1, x.size(2)).to(x.device)
        # Pass through the transformer decoder
        out = self.transformer_decoder(x, memory)  # Causal mask applied
        # Remove the step token
        out = out[:, 1:]
        # Final classification layer (logits for each bin)
        return self.fc_out(out)
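# Note: the cross-attention memory is a single zero vector, so the decoder effectively attends only
# to its own input tokens (the diffusion-step token plus the noisy trajectory tokens).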


# Positional Encoding class for the Transformer
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.pe = pe.unsqueeze(0)

    def forward(self, x):
        return self.pe[:, : x.size(1)].to(x.device)


# Sinusoidal step encoding
class StepToken(nn.Module):
    def __init__(self, dim, device=device):
        super().__init__()
        self.dim = dim
        self.token = nn.Parameter(torch.randn(1, dim // 2, device=device))

    def forward(self, x):
        half_dim = self.dim // 4
        emb = torch.exp(torch.arange(half_dim, device=x.device) * -np.log(10000) / (half_dim - 1))
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos(), self.token.expand((x.size(0), self.dim // 2))), dim=-1).unsqueeze(1)
        return emb
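# Note: the resulting step embedding concatenates dim // 4 sine features, dim // 4 cosine features,
# and a learned token of size dim // 2, giving one extra token of width dim that encodes the
# diffusion step.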


# Define dimensions for the Transformer model
trajectory_dim = 1 # 1D input for the sine wave
hidden_dim = 256
79 changes: 6 additions & 73 deletions ddlitlab2024/ml/run_diffusion_transformer_robot.py
@@ -1,84 +1,14 @@
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F # noqa
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from ema_pytorch import EMA
from torch import nn
from ddlitlab2024.ml.train_diffusion_transformer_robot import TrajectoryTransformerModel

# Check if CUDA is available and set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class TrajectoryTransformerModel(nn.Module):
    def __init__(self, num_joints, hidden_dim, num_layers, num_heads, max_seq_len):
        super().__init__()
        self.embedding = nn.Linear(num_joints, hidden_dim)
        self.positional_encoding = PositionalEncoding(hidden_dim, max_seq_len + 1)
        self.step_encoding = StepToken(hidden_dim, device=device)
        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=hidden_dim,
                nhead=num_heads,
                dim_feedforward=hidden_dim,
                batch_first=True,
                norm_first=True,
                activation="gelu",
            ),
            num_layers=num_layers,
        )
        self.fc_out = nn.Linear(hidden_dim, num_joints)

    def forward(self, x, step):
        # x shape: (batch_size, seq_len, joint, num_bins)
        # Flatten the joint and bin dimensions into a single token dimension
        x = x.view(x.size(0), x.size(1), -1)
        # Embed the input
        x = self.embedding(x)
        # Positional encoding
        x += self.positional_encoding(x)
        # Add token for the step
        x = torch.cat([self.step_encoding(step), x], dim=1)
        # Memory tensor (not used)
        memory = torch.zeros(x.size(0), 1, x.size(2)).to(x.device)
        # Pass through the transformer decoder
        out = self.transformer_decoder(x, memory)  # Causal mask applied
        # Remove the step token
        out = out[:, 1:]
        # Final classification layer (logits for each bin)
        return self.fc_out(out)


# Positional Encoding class for the Transformer
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.pe = pe.unsqueeze(0)

    def forward(self, x):
        return self.pe[:, : x.size(1)].to(x.device)


# Sinusoidal step encoding
class StepToken(nn.Module):
    def __init__(self, dim, device=device):
        super().__init__()
        self.dim = dim
        self.token = nn.Parameter(torch.randn(1, dim // 2, device=device))

    def forward(self, x):
        half_dim = self.dim // 4
        emb = torch.exp(torch.arange(half_dim, device=x.device) * -np.log(10000) / (half_dim - 1))
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos(), self.token.expand((x.size(0), self.dim // 2))), dim=-1).unsqueeze(1)
        return emb


# Define hyperparameters
hidden_dim = 256
num_layers = 4
@@ -113,7 +43,7 @@ def forward(self, x):
).to(device)
ema = EMA(model, beta=0.9999)

scheduler = DDIMScheduler(beta_schedule="squaredcos_cap_v2")
scheduler = DDIMScheduler(beta_schedule="squaredcos_cap_v2", clip_sample=False)
scheduler.config.num_train_timesteps = train_timesteps

# Load the model
@@ -142,13 +72,16 @@ def sample_trajectory(length=sequence_length, step_size=100, diffusion_steps=15)
# Plot the context and the sampled trajectory
context = torch.cat([context, sampled_trajectory], dim=1)

# Undo the normalization
context = context * model.std + model.mean

# Plot the sampled trajectory
plt.figure(figsize=(12, 6))
for j in range(trajectory_dim):
plt.subplot(3, 4, j + 1)
plt.plot(context[0, :, j].cpu(), label="Sampled Trajectory")
# Scale the y-axis to the range of the training data
plt.ylim(-1, 1)
plt.ylim(-3.5, 3.5)
plt.title(f"Joint {joints[j]}")
plt.xlabel("Time")
plt.ylabel("Amplitude")
