From 0f6fd8b080d1e67bff729430429851e534dcabc9 Mon Sep 17 00:00:00 2001 From: Florian Vahl Date: Sun, 27 Oct 2024 12:56:18 +0100 Subject: [PATCH] Don't clip --- ddlitlab2024/ml/train_diffusion_transformer_robot.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ddlitlab2024/ml/train_diffusion_transformer_robot.py b/ddlitlab2024/ml/train_diffusion_transformer_robot.py index c090ff2..c6454e0 100644 --- a/ddlitlab2024/ml/train_diffusion_transformer_robot.py +++ b/ddlitlab2024/ml/train_diffusion_transformer_robot.py @@ -165,12 +165,15 @@ def forward(self, x): optimizer, max_lr=lr, total_steps=epochs * (num_samples // batch_size) ) -scheduler = DDIMScheduler(beta_schedule="squaredcos_cap_v2") +scheduler = DDIMScheduler(beta_schedule="squaredcos_cap_v2", clip_sample=False) scheduler.config.num_train_timesteps = train_timesteps # Training loop for epoch in range(epochs): # Number of training epochs mean_loss = 0 + # Shuffle the data for each epoch + real_trajectories = real_trajectories[torch.randperm(real_trajectories.size(0))] + for batch in tqdm(range(num_samples // batch_size)): targets = real_trajectories[batch * batch_size : (batch + 1) * batch_size].to(device)