
Commit 6ce0d0c

Fix bugs in training loop.
wiktorlazarski committed May 10, 2022
1 parent b6361b4 commit 6ce0d0c
Showing 2 changed files with 6 additions and 1 deletion.
2 changes: 2 additions & 0 deletions scripts/training/lightning_modules.py
@@ -23,6 +23,8 @@ def __init__(
size_augmentation_keys: t.Optional[t.List[str]] = None,
content_augmentation_keys: t.Optional[t.List[str]] = None,
):
super().__init__()

self.dataset_root = dataset_root
self.nn_image_input_resolution = nn_image_input_resolution
self.batch_size = batch_size
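
The added super().__init__() call runs the parent constructor before the attributes are set; without it the base class state is never initialized and errors tend to surface once the Trainer starts using the module. A minimal sketch of the pattern, assuming the class subclasses pytorch_lightning.LightningDataModule (the class name below is hypothetical):

import typing as t

import pytorch_lightning as pl


class HeadSegmentationDataModule(pl.LightningDataModule):  # hypothetical name
    def __init__(
        self,
        dataset_root: str,
        nn_image_input_resolution: int,
        batch_size: int,
        size_augmentation_keys: t.Optional[t.List[str]] = None,
        content_augmentation_keys: t.Optional[t.List[str]] = None,
    ):
        # Initialize the LightningDataModule base class first; skipping this
        # leaves its internal state unset and breaks the data module once the
        # Trainer begins calling its hooks.
        super().__init__()

        self.dataset_root = dataset_root
        self.nn_image_input_resolution = nn_image_input_resolution
        self.batch_size = batch_size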
5 changes: 4 additions & 1 deletion scripts/training/train.py
@@ -75,14 +75,17 @@ def main(configs: omegaconf.DictConfig) -> None:
log_every_n_steps=1,
callbacks=[early_stop_callback, model_ckpt_callback],
max_epochs=configs.training.max_epochs,
weights_save_path="models",
gpus=1 if configs.training.with_gpu else 0,
)

# Train loop
logger.info("🏋️ Starting training loop.")
nn_trainer.fit(nn_module, dataset_module)

# Display best model based on monitored metric
logger.info(f"🥇 Best model: {model_ckpt_callback.best_model_path}")
nn_module.load_from_checkpoint(model_ckpt_callback.best_model_path)

# Test loop
logger.info("🧪 Starting testing loop.")
nn_trainer.test(nn_module, dataset_module)
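
One detail worth noting about the checkpoint-restore step: in PyTorch Lightning, load_from_checkpoint is a classmethod that builds and returns a new module instance rather than loading weights into the instance it is called on. A hedged sketch of the usual pattern (SegmentationModule is a hypothetical stand-in for nn_module's class):

# `load_from_checkpoint` returns a fresh module restored from the best
# checkpoint; capture the result instead of expecting `nn_module` to be
# mutated in place. `SegmentationModule` is a hypothetical class name.
best_module = SegmentationModule.load_from_checkpoint(
    model_ckpt_callback.best_model_path
)

# Test loop on the restored best model.
nn_trainer.test(best_module, dataset_module)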
