Skip to content

Commit

Permalink
Add better norm
Browse files Browse the repository at this point in the history
  • Loading branch information
Flova committed Oct 27, 2024
1 parent b6ab2ce commit 8677c04
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
5 changes: 4 additions & 1 deletion ddlitlab2024/ml/run_diffusion_context_transformer_robotpy
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ model = TrajectoryTransformerModel(
).to(device)
ema = EMA(model, beta=0.9999)

scheduler = DDIMScheduler(beta_schedule="squaredcos_cap_v2")
scheduler = DDIMScheduler(beta_schedule="squaredcos_cap_v2", clip_sample=False)
scheduler.config.num_train_timesteps = train_timesteps

# Load the model
Expand Down Expand Up @@ -97,6 +97,9 @@ def sample_trajectory(length=1000, diffusion_steps=15):

context = torch.cat([context, sampled_trajectory], dim=1)

# Undo the normalization
context = context * model.std + model.mean

# Plot the sampled trajectory
plt.figure(figsize=(12, 6))
for j in range(trajectory_dim):
Expand Down
17 changes: 14 additions & 3 deletions ddlitlab2024/ml/train_diffusion_context_transformer_robot.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ def __init__(self, num_joints, hidden_dim, num_layers, num_heads, max_action_con
num_joints=num_joints, hidden_dim=hidden_dim, num_layers=num_layers, num_heads=num_heads, max_seq_len=trajectory_prediction_length
)

# Store normalization parameters
self.register_buffer("mean", torch.zeros(num_joints))
self.register_buffer("std", torch.ones(num_joints))

def forward(self, past_actions, noisy_action_predictions, step):
# Encode the past actions
context = self.action_history_encoder(past_actions) # This can be cached during inference TODO
Expand Down Expand Up @@ -157,8 +161,10 @@ def forward(self, x):
# Drop every second data point to reduce the sequence length (subsample) TODO proper subsampling
data = data[::3]

# Normalize the joint data (-pi to pi) to (-1, 1)
data = data / np.pi
# Normalize the joint data
stds = data.std()
means = data.mean()
data = (data - means) / stds

# Chunk the data into sequences of 50 timesteps
timesteps = action_context_length + trajectory_prediction_length
Expand Down Expand Up @@ -188,6 +194,11 @@ def forward(self, x):
max_action_context_length=action_context_length,
trajectory_prediction_length=trajectory_prediction_length,
).to(device)

# Add normalization parameters to the model
model.mean = torch.tensor(means.values).to(device)
model.std = torch.tensor(stds.values).to(device)

ema = EMA(model, beta=0.9999)


Expand All @@ -196,7 +207,7 @@ def forward(self, x):
optimizer, max_lr=lr, total_steps=epochs * (num_samples // batch_size)
)

scheduler = DDIMScheduler(beta_schedule="squaredcos_cap_v2")
scheduler = DDIMScheduler(beta_schedule="squaredcos_cap_v2", clip_sample=False)
scheduler.config.num_train_timesteps = train_timesteps

# Training loop
Expand Down

0 comments on commit 8677c04

Please sign in to comment.