Skip to content

Commit

Permalink
Dont clip
Browse files Browse the repository at this point in the history
  • Loading branch information
Flova committed Oct 27, 2024
1 parent c0d836d commit 0f6fd8b
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion ddlitlab2024/ml/train_diffusion_transformer_robot.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,15 @@ def forward(self, x):
optimizer, max_lr=lr, total_steps=epochs * (num_samples // batch_size)
)

scheduler = DDIMScheduler(beta_schedule="squaredcos_cap_v2")
scheduler = DDIMScheduler(beta_schedule="squaredcos_cap_v2", clip_sample=False)
scheduler.config.num_train_timesteps = train_timesteps

# Training loop
for epoch in range(epochs): # Number of training epochs
mean_loss = 0
# Shuffle the data for each epoch
real_trajectories = real_trajectories[torch.randperm(real_trajectories.size(0))]

for batch in tqdm(range(num_samples // batch_size)):
targets = real_trajectories[batch * batch_size : (batch + 1) * batch_size].to(device)

Expand Down

0 comments on commit 0f6fd8b

Please sign in to comment.