more sensible pre-training defaults
MorganCThomas committed May 27, 2024
1 parent a5263c6 commit 9b8e10b
Showing 2 changed files with 10 additions and 10 deletions.
14 changes: 7 additions & 7 deletions scripts/pretrain/config.yaml
@@ -14,15 +14,15 @@ dataset_log_dir: /tmp/pretrain # if recomputing dataset, save it here
 # Model configuration
 model: gru # gru, lstm, or gpt2
 custom_model_factory: null # Path to a custom model factory (e.g. my_module.create_model)
-model_log_dir: /tmp/pretrain # save model here
+model_log_dir: test #/tmp/pretrain # save model here
 
 # Training configuration
-lr: 0.0001
+lr: 0.001
 lr_scheduler: StepLR
 lr_scheduler_kwargs:
-  step_size: 1
-  gamma: 1.0 # no decay
-epochs: 10
-batch_size: 8
-randomize_smiles: False
+  step_size: 500
+  gamma: 0.97 # 1.0 = no decay
+epochs: 50
+batch_size: 128
+randomize_smiles: True # Sample a random variant during training, therefore, for 10-fold augmentation on a dataset for 5 epochs, do 10*5=50 epochs.
 num_test_smiles: 100
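Note (not part of the commit): with the new defaults, PyTorch's StepLR multiplies the learning rate by 0.97 every 500 scheduler steps, so decay is gradual rather than disabled as it was with gamma 1.0. A minimal sketch of that behaviour under these settings; the linear layer and Adam optimizer below are placeholders standing in for the script's actual GRU/LSTM/GPT2 actor and optimizer:

import torch

model = torch.nn.Linear(8, 8)  # placeholder model, not the repository's actor
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.97)

for step in range(1, 1501):
    optimizer.step()    # one optimizer update per training batch
    scheduler.step()    # lr is multiplied by 0.97 once every 500 scheduler steps
    if step % 500 == 0:
        print(step, scheduler.get_last_lr()[0])  # e.g. 0.00097, 0.0009409, ...

The randomize_smiles comment amounts to: since a random SMILES variant is sampled each pass, augmentation accumulates across epochs, so the epoch count should be the augmentation factor times the desired passes over the canonical data (e.g. 10 * 5 = 50).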
6 changes: 3 additions & 3 deletions scripts/pretrain/pretrain_single_node.py
@@ -217,11 +217,11 @@ def main(cfg: "DictConfig"):
                     "mols", wandb.Image(image), step=total_smiles
                 )
                 logger.log_scalar(
-                    "lr", lr_scheduler.get_lr()[0], step=total_smiles
+                    "lr", lr_scheduler.get_last_lr()[0], step=total_smiles
                 )
 
-            # Decay learning rate
-            lr_scheduler.step()
+            # Decay learning rate
+            lr_scheduler.step()
 
         save_path = Path(cfg.model_log_dir) / f"pretrained_actor_epoch_{epoch}.pt"
         torch.save(actor_training.state_dict(), save_path)
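For context on the get_last_lr() change, a sketch rather than repository code; the parameter and optimizer here are placeholders: in recent PyTorch versions, get_last_lr() returns the learning rate(s) most recently applied to the optimizer, which is what a logger should record, while calling get_lr() directly emits a UserWarning because it is intended for the scheduler's internal use.

import torch

param = torch.nn.Parameter(torch.zeros(1))  # placeholder parameter, not the actor's weights
optimizer = torch.optim.Adam([param], lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.97)

optimizer.step()
scheduler.step()
print(scheduler.get_last_lr()[0])  # lr currently set on the optimizer's param group: 0.001
# scheduler.get_lr() called outside the scheduler's own step() emits a UserWarning
# in recent PyTorch versions, which is why the logging line was changed.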
