Skip to content

Commit

Permalink
fix reproducing spring-mesh results with DYffusion
Browse files Browse the repository at this point in the history
  • Loading branch information
salvaRC committed Oct 4, 2024
1 parent 832574f commit 79f3db1
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 9 deletions.
3 changes: 3 additions & 0 deletions src/configs/diffusion/dyffusion.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,8 @@ diffusion:
# python run.py mode=test diffusion.refine_intermediate_predictions=True logger.wandb.id=???
refine_intermediate_predictions: False

# Set to True to use the direct forecaster's prediction of x_{t+h} rather than a cold-sampled one (when sampling_type='cold')
use_cold_sampling_for_last_step: False

timesteps: ${datamodule.horizon} # Do not change, it is automatically inferred by DYffusion
log_every_t: null
3 changes: 0 additions & 3 deletions src/configs/experiment/spring_mesh.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@ datamodule:
prediction_horizon: 804
window: 1

model:
dropout: 0.05

module:
optimizer:
lr: 4e-4
Expand Down
1 change: 1 addition & 0 deletions src/configs/experiment/spring_mesh_dyffusion.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ diffusion:
interpolator_run_id: ??? # Please fill in the wandb run id of the trained interpolator
refine_intermediate_predictions: True
forward_conditioning: "data"
use_cold_sampling_for_last_step: False

logger:
wandb:
Expand Down
7 changes: 5 additions & 2 deletions src/configs/experiment/spring_mesh_interpolation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@ defaults:

name: "SpringMesh-Interpolation${datamodule.horizon}h"

model:
dropout: 0.05

module:
enable_inference_dropout: True

trainer:
max_epochs: 400
#trainer:
# max_epochs: 400

logger:
wandb:
Expand Down
4 changes: 4 additions & 0 deletions src/configs/experiment/spring_mesh_time_conditioned.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,9 @@ defaults:
- _self_

name: "SpringMesh-MH${datamodule.horizon}-TC"

model:
dropout: 0.05

module:
enable_inference_dropout: True
12 changes: 8 additions & 4 deletions src/diffusion/dyffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(
refine_intermediate_predictions: bool = False,
prediction_timesteps: Optional[Sequence[float]] = None,
enable_interpolator_dropout: Union[bool, str] = True,
use_cold_sampling_for_last_step: bool = False,
log_every_t: Union[str, int] = None,
*args,
**kwargs,
Expand Down Expand Up @@ -378,10 +379,13 @@ def sample_loop(
x_interpolated_s_next = x0_hat # for the last step, we use the final x0_hat prediction

if self.hparams.sampling_type in ["cold"]:
# D(x_s, s)
x_interpolated_s = self.q_sample(**q_sample_kwargs, t=step_s, **sc_kw) if s > 0 else x_s
# for s = 0, we have x_s_degraded = x_s, so we just directly return x_s_degraded_next
x_s = x_s - x_interpolated_s + x_interpolated_s_next
if is_last_step and not self.hparams.use_cold_sampling_for_last_step:
x_s = x0_hat
else:
# D(x_s, s)
x_interpolated_s = self.q_sample(**q_sample_kwargs, t=step_s, **sc_kw) if s > 0 else x_s
# for s = 0, we have x_s_degraded = x_s, so we just directly return x_s_degraded_next
x_s = x_s - x_interpolated_s + x_interpolated_s_next

elif self.hparams.sampling_type == "naive":
x_s = x_interpolated_s_next
Expand Down

0 comments on commit 79f3db1

Please sign in to comment.