From 79f3db1a759de818d84650ce74faa303468c1f42 Mon Sep 17 00:00:00 2001 From: salvaRC Date: Fri, 4 Oct 2024 10:45:51 -0700 Subject: [PATCH] fix reproducing spring-mesh results with DYffusion --- src/configs/diffusion/dyffusion.yaml | 3 +++ src/configs/experiment/spring_mesh.yaml | 3 --- src/configs/experiment/spring_mesh_dyffusion.yaml | 1 + .../experiment/spring_mesh_interpolation.yaml | 7 +++++-- .../experiment/spring_mesh_time_conditioned.yaml | 4 ++++ src/diffusion/dyffusion.py | 12 ++++++++---- 6 files changed, 21 insertions(+), 9 deletions(-) diff --git a/src/configs/diffusion/dyffusion.yaml b/src/configs/diffusion/dyffusion.yaml index 6fdb085..7b92bde 100644 --- a/src/configs/diffusion/dyffusion.yaml +++ b/src/configs/diffusion/dyffusion.yaml @@ -50,5 +50,8 @@ diffusion: # python run.py mode=test diffusion.refine_intermediate_predictions=True logger.wandb.id=??? refine_intermediate_predictions: False + # Set to True to also cold-sample the last step; if False, the forecaster's direct prediction of x_{t+h} is used for the last step instead (only relevant when sampling_type='cold') + use_cold_sampling_for_last_step: False + timesteps: ${datamodule.horizon} # Do not change, it is automatically inferred by DYffusion log_every_t: null diff --git a/src/configs/experiment/spring_mesh.yaml b/src/configs/experiment/spring_mesh.yaml index 294c869..4134940 100644 --- a/src/configs/experiment/spring_mesh.yaml +++ b/src/configs/experiment/spring_mesh.yaml @@ -16,9 +16,6 @@ datamodule: prediction_horizon: 804 window: 1 -model: - dropout: 0.05 - module: optimizer: lr: 4e-4 diff --git a/src/configs/experiment/spring_mesh_dyffusion.yaml b/src/configs/experiment/spring_mesh_dyffusion.yaml index 548b792..6363613 100644 --- a/src/configs/experiment/spring_mesh_dyffusion.yaml +++ b/src/configs/experiment/spring_mesh_dyffusion.yaml @@ -14,6 +14,7 @@ diffusion: interpolator_run_id: ??? 
# Please fill in the wandb run id of the trained interpolator refine_intermediate_predictions: True forward_conditioning: "data" + use_cold_sampling_for_last_step: False logger: wandb: diff --git a/src/configs/experiment/spring_mesh_interpolation.yaml b/src/configs/experiment/spring_mesh_interpolation.yaml index ae96783..3d9e1c9 100644 --- a/src/configs/experiment/spring_mesh_interpolation.yaml +++ b/src/configs/experiment/spring_mesh_interpolation.yaml @@ -10,11 +10,14 @@ defaults: name: "SpringMesh-Interpolation${datamodule.horizon}h" +model: + dropout: 0.05 + module: enable_inference_dropout: True -trainer: - max_epochs: 400 +#trainer: +# max_epochs: 400 logger: wandb: diff --git a/src/configs/experiment/spring_mesh_time_conditioned.yaml b/src/configs/experiment/spring_mesh_time_conditioned.yaml index 41a20ec..be47204 100644 --- a/src/configs/experiment/spring_mesh_time_conditioned.yaml +++ b/src/configs/experiment/spring_mesh_time_conditioned.yaml @@ -9,5 +9,9 @@ defaults: - _self_ name: "SpringMesh-MH${datamodule.horizon}-TC" + +model: + dropout: 0.05 + module: enable_inference_dropout: True \ No newline at end of file diff --git a/src/diffusion/dyffusion.py b/src/diffusion/dyffusion.py index 0ad13f4..bd104df 100644 --- a/src/diffusion/dyffusion.py +++ b/src/diffusion/dyffusion.py @@ -28,6 +28,7 @@ def __init__( refine_intermediate_predictions: bool = False, prediction_timesteps: Optional[Sequence[float]] = None, enable_interpolator_dropout: Union[bool, str] = True, + use_cold_sampling_for_last_step: bool = False, log_every_t: Union[str, int] = None, *args, **kwargs, @@ -378,10 +379,13 @@ def sample_loop( x_interpolated_s_next = x0_hat # for the last step, we use the final x0_hat prediction if self.hparams.sampling_type in ["cold"]: - # D(x_s, s) - x_interpolated_s = self.q_sample(**q_sample_kwargs, t=step_s, **sc_kw) if s > 0 else x_s - # for s = 0, we have x_s_degraded = x_s, so we just directly return x_s_degraded_next - x_s = x_s - x_interpolated_s + 
x_interpolated_s_next + if is_last_step and not self.hparams.use_cold_sampling_for_last_step: + x_s = x0_hat + else: + # D(x_s, s) + x_interpolated_s = self.q_sample(**q_sample_kwargs, t=step_s, **sc_kw) if s > 0 else x_s + # for s = 0, we have x_s_degraded = x_s, so we just directly return x_s_degraded_next + x_s = x_s - x_interpolated_s + x_interpolated_s_next elif self.hparams.sampling_type == "naive": x_s = x_interpolated_s_next