diff --git a/ddlitlab2024/ml/train_diffusion_transformer_robot.py b/ddlitlab2024/ml/train_diffusion_transformer_robot.py index c090ff2..c6454e0 100644 --- a/ddlitlab2024/ml/train_diffusion_transformer_robot.py +++ b/ddlitlab2024/ml/train_diffusion_transformer_robot.py @@ -165,12 +165,15 @@ def forward(self, x): optimizer, max_lr=lr, total_steps=epochs * (num_samples // batch_size) ) -scheduler = DDIMScheduler(beta_schedule="squaredcos_cap_v2") +scheduler = DDIMScheduler(beta_schedule="squaredcos_cap_v2", clip_sample=False) scheduler.config.num_train_timesteps = train_timesteps # Training loop for epoch in range(epochs): # Number of training epochs mean_loss = 0 + # Shuffle the data for each epoch + real_trajectories = real_trajectories[torch.randperm(real_trajectories.size(0))] + for batch in tqdm(range(num_samples // batch_size)): targets = real_trajectories[batch * batch_size : (batch + 1) * batch_size].to(device)