From 8677c04dad82214aa3a42434bc847eeee94ba78c Mon Sep 17 00:00:00 2001
From: Florian Vahl
Date: Sun, 27 Oct 2024 14:15:57 +0100
Subject: [PATCH] Add better norm

---
 .../run_diffusion_context_transformer_robotpy    |  5 ++++-
 ...train_diffusion_context_transformer_robot.py  | 17 ++++++++++++++---
 2 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/ddlitlab2024/ml/run_diffusion_context_transformer_robotpy b/ddlitlab2024/ml/run_diffusion_context_transformer_robotpy
index 9d6bc2a..3189977 100644
--- a/ddlitlab2024/ml/run_diffusion_context_transformer_robotpy
+++ b/ddlitlab2024/ml/run_diffusion_context_transformer_robotpy
@@ -47,7 +47,7 @@ model = TrajectoryTransformerModel(
 ).to(device)
 ema = EMA(model, beta=0.9999)
 
-scheduler = DDIMScheduler(beta_schedule="squaredcos_cap_v2")
+scheduler = DDIMScheduler(beta_schedule="squaredcos_cap_v2", clip_sample=False)
 scheduler.config.num_train_timesteps = train_timesteps
 
 # Load the model
@@ -97,6 +97,9 @@ def sample_trajectory(length=1000, diffusion_steps=15):
 
         context = torch.cat([context, sampled_trajectory], dim=1)
 
+    # Undo the normalization
+    context = context * model.std + model.mean
+
     # Plot the sampled trajectory
     plt.figure(figsize=(12, 6))
     for j in range(trajectory_dim):
diff --git a/ddlitlab2024/ml/train_diffusion_context_transformer_robot.py b/ddlitlab2024/ml/train_diffusion_context_transformer_robot.py
index 2d30bbf..7a39ce3 100644
--- a/ddlitlab2024/ml/train_diffusion_context_transformer_robot.py
+++ b/ddlitlab2024/ml/train_diffusion_context_transformer_robot.py
@@ -23,6 +23,10 @@ def __init__(self, num_joints, hidden_dim, num_layers, num_heads, max_action_con
             num_joints=num_joints, hidden_dim=hidden_dim, num_layers=num_layers, num_heads=num_heads, max_seq_len=trajectory_prediction_length
         )
 
+        # Store normalization parameters
+        self.register_buffer("mean", torch.zeros(num_joints))
+        self.register_buffer("std", torch.ones(num_joints))
+
     def forward(self, past_actions, noisy_action_predictions, step):
         # Encode the past actions
         context = self.action_history_encoder(past_actions)  # This can be cached during inference TODO
@@ -157,8 +161,10 @@ def forward(self, x):
 # Drop every second data point to reduce the sequence length (subsample) TODO proper subsampling
 data = data[::3]
 
-# Normalize the joint data (-pi to pi) to (-1, 1)
-data = data / np.pi
+# Normalize the joint data
+stds = data.std()
+means = data.mean()
+data = (data - means) / stds
 
 # Chunk the data into sequences of 50 timesteps
 timesteps = action_context_length + trajectory_prediction_length
@@ -188,6 +194,11 @@ def forward(self, x):
     max_action_context_length=action_context_length,
     trajectory_prediction_length=trajectory_prediction_length,
 ).to(device)
+
+# Add normalization parameters to the model
+model.mean = torch.tensor(means.values).to(device)
+model.std = torch.tensor(stds.values).to(device)
+
 ema = EMA(model, beta=0.9999)
 
 
@@ -196,7 +207,7 @@ def forward(self, x):
     optimizer, max_lr=lr, total_steps=epochs * (num_samples // batch_size)
 )
 
-scheduler = DDIMScheduler(beta_schedule="squaredcos_cap_v2")
+scheduler = DDIMScheduler(beta_schedule="squaredcos_cap_v2", clip_sample=False)
 scheduler.config.num_train_timesteps = train_timesteps
 
 # Training loop
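
The key idea behind the patch is that `register_buffer` makes the dataset mean and std part of the model's state dict, so they are checkpointed with the weights, follow the module through `.to(device)`, and can be reused by the inference script to undo the normalization. Below is a minimal, self-contained sketch of that mechanism in plain PyTorch; the module and variable names (`NormalizedModel`, `normalize`, `denormalize`) are illustrative and do not appear in the patched scripts.

import torch
from torch import nn


class NormalizedModel(nn.Module):
    def __init__(self, num_joints: int):
        super().__init__()
        # Buffers appear in state_dict() but are not updated by the optimizer
        self.register_buffer("mean", torch.zeros(num_joints))
        self.register_buffer("std", torch.ones(num_joints))

    def normalize(self, joints: torch.Tensor) -> torch.Tensor:
        return (joints - self.mean) / self.std

    def denormalize(self, joints: torch.Tensor) -> torch.Tensor:
        return joints * self.std + self.mean


if __name__ == "__main__":
    model = NormalizedModel(num_joints=3)
    data = torch.randn(100, 3) * 0.5 + 0.2  # stand-in for joint trajectories

    # Assigning tensors to registered buffer names updates the buffers,
    # analogous to model.mean = ... / model.std = ... in the training script
    model.mean = data.mean(dim=0)
    model.std = data.std(dim=0)

    # Buffers round-trip through the checkpoint, so a separate run script
    # can denormalize sampled trajectories after loading the state dict
    restored = NormalizedModel(num_joints=3)
    restored.load_state_dict(model.state_dict())
    normalized = restored.normalize(data)
    print(torch.allclose(restored.denormalize(normalized), data, atol=1e-5))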