From 9b8e10bf1050f1cd2ce93a3da192daa44a84dcb0 Mon Sep 17 00:00:00 2001
From: MorganCThomas
Date: Mon, 27 May 2024 17:42:21 +0200
Subject: [PATCH] more sensible pre-training defaults

---
 scripts/pretrain/config.yaml             | 14 +++++++-------
 scripts/pretrain/pretrain_single_node.py |  6 +++---
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/scripts/pretrain/config.yaml b/scripts/pretrain/config.yaml
index f881e923..1cd537ac 100644
--- a/scripts/pretrain/config.yaml
+++ b/scripts/pretrain/config.yaml
@@ -14,15 +14,15 @@ dataset_log_dir: /tmp/pretrain # if recomputing dataset, save it here
 # Model configuration
 model: gru # gru, lstm, or gpt2
 custom_model_factory: null # Path to a custom model factory (e.g. my_module.create_model)
-model_log_dir: /tmp/pretrain # save model here
+model_log_dir: test #/tmp/pretrain # save model here
 
 # Training configuration
-lr: 0.0001
+lr: 0.001
 lr_scheduler: StepLR
 lr_scheduler_kwargs:
-  step_size: 1
-  gamma: 1.0 # no decay
-epochs: 10
-batch_size: 8
-randomize_smiles: False
+  step_size: 500
+  gamma: 0.97 # 1.0 = no decay
+epochs: 50
+batch_size: 128
+randomize_smiles: True # Sample a random variant during training, therefore, for 10-fold augmentation on a dataset for 5 epochs, do 10*5=50 epochs.
 num_test_smiles: 100
diff --git a/scripts/pretrain/pretrain_single_node.py b/scripts/pretrain/pretrain_single_node.py
index 419191d2..dcae2a46 100644
--- a/scripts/pretrain/pretrain_single_node.py
+++ b/scripts/pretrain/pretrain_single_node.py
@@ -217,11 +217,11 @@ def main(cfg: "DictConfig"):
                         "mols", wandb.Image(image), step=total_smiles
                     )
                     logger.log_scalar(
-                        "lr", lr_scheduler.get_lr()[0], step=total_smiles
+                        "lr", lr_scheduler.get_last_lr()[0], step=total_smiles
                     )
 
-        # Decay learning rate
-        lr_scheduler.step()
+            # Decay learning rate
+            lr_scheduler.step()
 
         save_path = Path(cfg.model_log_dir) / f"pretrained_actor_epoch_{epoch}.pt"
         torch.save(actor_training.state_dict(), save_path)
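
For context, a minimal sketch (not part of the patch) of what the new scheduler defaults do, assuming lr_scheduler.step() is now called once per training batch; the model, optimizer, and loop below are placeholders rather than the script's actual objects:

    import torch

    model = torch.nn.Linear(8, 8)  # placeholder model, not the actual actor network
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # lr: 0.001
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=500, gamma=0.97  # multiply lr by 0.97 every 500 steps
    )

    for batch_idx in range(2000):  # stand-in for the per-batch training loop
        optimizer.step()
        scheduler.step()  # stepping per batch matches step_size=500
        if (batch_idx + 1) % 500 == 0:
            # get_last_lr() reports the lr actually in use; get_lr() is an
            # internal method and warns when called outside the scheduler.
            print(batch_idx + 1, scheduler.get_last_lr()[0])

With these values the learning rate roughly halves after about 11,500 batches (0.97^23 ~= 0.5), instead of staying constant as with the old step_size=1, gamma=1.0 settings.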