Skip to content

Commit

Permalink
Merge branch 'fix-version-conflicts' of github.com:adrianoesch/newsre…
Browse files Browse the repository at this point in the history
…clib into fix-version-conflicts
  • Loading branch information
adrianoesch committed Mar 18, 2024
2 parents 389385b + e06a906 commit 36bd49d
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 4 deletions.
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,16 @@ of the corresponding modules.
num_heads: 15
```
For training the `NRMS` model on the `MINDlarge` dataset, execute the following command:

```python
python newsreclib/train.py experiment=nrms_mindlarge_pretrainedemb_celoss_bertsent
```

To understand how to adjust configuration files when transitioning from smaller to larger datasets, refer to the examples provided in `nrms_mindsmall_pretrainedemb_celoss_bertsent` and `nrms_mindlarge_pretrainedemb_celoss_bertsent`. These files will guide you in scaling your configurations appropriately.

*Note:* The same procedure applies for the advanced configuration shown below.

## Advanced Configuration

The advanced scenario depicts a more complex experimental setting.
Expand Down
12 changes: 11 additions & 1 deletion configs/eval.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,25 @@ defaults:
- _self_
- data: null # choose datamodule with `test_dataloader()` for evaluation
- model: null
- callbacks: default.yaml
- logger: many_loggers.yaml
- trainer: default.yaml
- paths: default.yaml
- extras: default.yaml
- hydra: default.yaml

# experiment configs allow for version control of specific hyperparameters
# e.g. best hyperparameters for given model and datamodule
- experiment: null

# task name, determines output directory path
task_name: "eval"

# tags to help you identify your experiments
# you can overwrite this in experiment configs
# overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
tags: ["eval"]

# passing checkpoint path is necessary for evaluation
ckpt_path: ???
# example: logs/train/runs/nrms_mindsmall_pretrainedemb_celoss_bertsent_s42/2024-03-08_07-48-40/checkpoints/last.ckpt
ckpt_path: ???
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# @package _global_

# to execute this experiment run:
# python train.py experiment=example

defaults:
- override /data: mind_rec_bert_sent.yaml
- override /model: nrms.yaml
- override /callbacks: default.yaml
- override /logger: many_loggers.yaml
- override /trainer: gpu.yaml

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["nrms", "mindlarge", "pretrainedemb", "celoss", "bertsent"]

seed: 42

data:
dataset_size: "large"

model:
use_plm: False
pretrained_embeddings_path: ${paths.data_dir}MINDlarge_train/transformed_word_embeddings.npy
embed_dim: 300
num_heads: 15
query_dim: 200
dropout_probability: 0.2

callbacks:
early_stopping:
patience: 5

trainer:
max_epochs: 20

logger:
wandb:
name: "nrms_mindlarge_pretrainedemb_celoss_bertsent_s42"
tags: ${tags}
group: "mind"
2 changes: 1 addition & 1 deletion environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ dependencies:
- sphinx=5.*
- tokenizers=0.*
- torchaudio=2.*
- torchmetrics=0.*
- torchmetrics=1.*
- tqdm=4.*
- transformers=4.*
- wandb=0.*
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ tokenizers==0.13.2
torch==2.0.1
torchaudio==2.0.2
torch-geometric==2.3.0
torchmetrics==0.11.4
torchmetrics==1.3.1
torchvision==0.15.2
tqdm==4.65.0
transformers==4.36.0
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"seaborn==0.12.2",
"sphinx==5.0.2",
"tokenizers==0.15.2",
"torchmetrics==0.11.4",
"torchmetrics==1.3.1",
"tqdm==4.65.0",
"transformers==4.38.2",
"wandb==0.15.3",
Expand Down

0 comments on commit 36bd49d

Please sign in to comment.