Skip to content

Commit

Permalink
let user decide evaluation batch size
Browse files Browse the repository at this point in the history
  • Loading branch information
salvaRC committed Nov 29, 2023
1 parent 38a51cb commit e4523b3
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/configs/experiment/navier_stokes.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ name: "NavierStokes"

datamodule:
batch_size: 32
eval_batch_size: ${datamodule.batch_size}
eval_batch_size: 4 # effectively eval_batch_size *= number of predictions ($module.num_predictions; default=20)
horizon: 16
prediction_horizon: 64
window: 1
Expand Down
2 changes: 1 addition & 1 deletion src/configs/experiment/oisst_pacific.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ name: "OISSTv2PacificSubset"

datamodule:
batch_size: 64
eval_batch_size: ${datamodule.batch_size}
eval_batch_size: 6 # effectively eval_batch_size *= number of predictions ($module.num_predictions; default=20)
horizon: 7
prediction_horizon: 7
boxes: [84, 85, 86, 87, 88, 89, 108, 109, 110, 111, 112]
Expand Down
2 changes: 1 addition & 1 deletion src/configs/experiment/spring_mesh.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ name: "SpringMesh"
datamodule:
physical_system: 'spring-mesh'
batch_size: 64
eval_batch_size: 512
eval_batch_size: 16 # effectively eval_batch_size *= number of predictions ($module.num_predictions; default=20)
horizon: 134
prediction_horizon: 804
window: 1
Expand Down
8 changes: 3 additions & 5 deletions src/utilities/config_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import math
import os
import time
import warnings
Expand Down Expand Up @@ -464,13 +463,12 @@ def check_config_values(config: DictConfig):
config.save_config_to_wandb = False

if config.module.get("num_predictions", 1) > 1:
# adapt the evaluation batch size to the number of predictions
bs, ebs = config.datamodule.batch_size, config.datamodule.eval_batch_size
if ebs >= bs:
# reduce the eval batch size to account for the number of predictions
config.datamodule.eval_batch_size = max(1, int(bs // math.sqrt(config.module.num_predictions)))
effective_ebs = ebs * config.module.num_predictions
log.info(
f"Reducing eval batch size from {ebs} to {config.datamodule.eval_batch_size} to match the number of predictions."
f"Note that the effective evaluation batch size will be multiplied by the number of "
f"predictions={config.module.num_predictions} for a total of {effective_ebs}!"
)


Expand Down

0 comments on commit e4523b3

Please sign in to comment.