Skip to content

Commit

Permalink
Edge case bug fix in dyffusion training with small batch size
Browse files Browse the repository at this point in the history
  • Loading branch information
salvaRC committed Nov 28, 2023
1 parent 266fdad commit 38a51cb
Showing 1 changed file with 17 additions and 13 deletions.
30 changes: 17 additions & 13 deletions src/diffusion/dyffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,20 +504,24 @@ def p_losses(self, xt_last: Tensor, condition: Tensor, t: Tensor, static_conditi
lam1 = self.hparams.lambda_reconstruction
lam2 = self.hparams.lambda_reconstruction2

# Get the interpolated
# since we do not need to interpolate xt_0, we can skip all batches where t=0
t_nonzero = t > 0
x_interpolated = self.q_sample(
x_end=condition[t_nonzero],
x0=xt_last[t_nonzero],
t=t[t_nonzero],
static_condition=None if static_condition is None else static_condition[t_nonzero],
num_predictions=1, # sample one interpolation prediction
)
# Now, simply concatenate the inital_conditions for t=0 with the interpolated data for t>0
# Create the inputs for the forecasting model
# 1. For t=0, simply use the initial conditions
x_t = condition.clone()
x_t[t_nonzero] = x_interpolated.to(x_t.dtype)
# assert torch.all(x_t[t == 0] == condition[t == 0]), f'x_t[t == 0] != condition[t == 0]'

# 2. For t>0, we need to interpolate the data using the interpolator
t_nonzero = t > 0
if t_nonzero.any():
x_interpolated = self.q_sample(
x_end=condition[t_nonzero],
x0=xt_last[t_nonzero],
t=t[t_nonzero],
static_condition=None if static_condition is None else static_condition[t_nonzero],
num_predictions=1, # sample one interpolation prediction
)
# Now, simply concatenate the inital_conditions for t=0 with the interpolated data for t>0
x_t[t_nonzero] = x_interpolated.to(x_t.dtype)
# assert torch.all(x_t[t == 0] == condition[t == 0])

# Train the forward predictions (i.e. predict xt_last from xt_t)
xt_last_target = xt_last
xt_last_pred = self.predict_x_last(condition=condition, x_t=x_t, t=t, static_condition=static_condition)
Expand Down

0 comments on commit 38a51cb

Please sign in to comment.