From 38a51cb59f0a08ea592c11f6333f73068573a937 Mon Sep 17 00:00:00 2001 From: salvaRC Date: Mon, 27 Nov 2023 21:30:42 -0800 Subject: [PATCH] Edge case bug fix in dyffusion training with small batch size --- src/diffusion/dyffusion.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/src/diffusion/dyffusion.py b/src/diffusion/dyffusion.py index 04d0be4..0ad13f4 100644 --- a/src/diffusion/dyffusion.py +++ b/src/diffusion/dyffusion.py @@ -504,20 +504,24 @@ def p_losses(self, xt_last: Tensor, condition: Tensor, t: Tensor, static_conditi lam1 = self.hparams.lambda_reconstruction lam2 = self.hparams.lambda_reconstruction2 - # Get the interpolated - # since we do not need to interpolate xt_0, we can skip all batches where t=0 - t_nonzero = t > 0 - x_interpolated = self.q_sample( - x_end=condition[t_nonzero], - x0=xt_last[t_nonzero], - t=t[t_nonzero], - static_condition=None if static_condition is None else static_condition[t_nonzero], - num_predictions=1, # sample one interpolation prediction - ) - # Now, simply concatenate the inital_conditions for t=0 with the interpolated data for t>0 + # Create the inputs for the forecasting model + # 1. For t=0, simply use the initial conditions x_t = condition.clone() - x_t[t_nonzero] = x_interpolated.to(x_t.dtype) - # assert torch.all(x_t[t == 0] == condition[t == 0]), f'x_t[t == 0] != condition[t == 0]' + + # 2. 
For t>0, we need to interpolate the data using the interpolator + t_nonzero = t > 0 + if t_nonzero.any(): + x_interpolated = self.q_sample( + x_end=condition[t_nonzero], + x0=xt_last[t_nonzero], + t=t[t_nonzero], + static_condition=None if static_condition is None else static_condition[t_nonzero], + num_predictions=1, # sample one interpolation prediction + ) + # Now, simply concatenate the initial conditions for t=0 with the interpolated data for t>0 + x_t[t_nonzero] = x_interpolated.to(x_t.dtype) + # assert torch.all(x_t[t == 0] == condition[t == 0]) + # Train the forward predictions (i.e. predict xt_last from xt_t) xt_last_target = xt_last xt_last_pred = self.predict_x_last(condition=condition, x_t=x_t, t=t, static_condition=static_condition)