
Switch from SGD optimizer to AdamW for patch classification tasks #690

Merged · 9 commits · Oct 18, 2024
6 changes: 2 additions & 4 deletions configs/core/tests/offline/embeddings.yaml
@@ -26,11 +26,9 @@ model:
         out_features: &NUM_CLASSES 2
     criterion: torch.nn.CrossEntropyLoss
     optimizer:
-      class_path: torch.optim.SGD
+      class_path: torch.optim.AdamW
       init_args:
-        lr: ${oc.env:LR_VALUE, 0.1}
-        momentum: 0.9
-        weight_decay: 0.0
+        lr: ${oc.env:LR_VALUE, 0.0003}
     lr_scheduler:
       class_path: torch.optim.lr_scheduler.CosineAnnealingLR
       init_args:
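The `${oc.env:LR_VALUE, 0.0003}` syntax is an OmegaConf environment-variable resolver: it reads `LR_VALUE` if set and otherwise falls back to 0.0003. As a rough sketch of what the new optimizer block amounts to once the config is instantiated (the head dimensions below are placeholders, not values from this file):

```python
import os

from torch import nn, optim

head = nn.Linear(768, 2)  # placeholder stand-in for the configured classification head

# ${oc.env:LR_VALUE, 0.0003} -> use the LR_VALUE env var if present, else 0.0003
lr = float(os.environ.get("LR_VALUE", "0.0003"))

# AdamW with PyTorch defaults (betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01),
# replacing the previous SGD(lr=..., momentum=0.9, weight_decay=0.0) setup.
optimizer = optim.AdamW(head.parameters(), lr=lr)
```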
6 changes: 2 additions & 4 deletions configs/vision/pathology/offline/classification/bach.yaml
@@ -53,11 +53,9 @@ model:
         out_features: &NUM_CLASSES 4
     criterion: torch.nn.CrossEntropyLoss
     optimizer:
-      class_path: torch.optim.SGD
+      class_path: torch.optim.AdamW
       init_args:
-        lr: &LR_VALUE ${oc.env:LR_VALUE, 0.000625}
-        momentum: 0.9
-        weight_decay: 0.0
+        lr: ${oc.env:LR_VALUE, 0.0003}
     lr_scheduler:
       class_path: torch.optim.lr_scheduler.CosineAnnealingLR
       init_args:
@@ -18,7 +18,7 @@ trainer:
           filename: best
           save_last: true
           save_top_k: 1
-          monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/BinaryAccuracy}
+          monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/BinaryBalancedAccuracy}
           mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max}
       - class_path: lightning.pytorch.callbacks.EarlyStopping
         init_args:
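For context on the monitoring change: balanced accuracy is the mean of per-class recall, so on imbalanced binary patch datasets it is much harder to inflate by always predicting the majority class than plain accuracy. A minimal illustration (using scikit-learn here rather than eva's own metric class):

```python
from sklearn.metrics import accuracy_score, balanced_accuracy_score

# Toy imbalanced split: 8 negatives, 2 positives; the classifier misses one positive.
y_true = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]

print(accuracy_score(y_true, y_pred))           # 0.9  - inflated by the majority class
print(balanced_accuracy_score(y_true, y_pred))  # 0.75 - mean of per-class recall (1.0, 0.5)
```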
@@ -18,7 +18,7 @@ trainer:
           filename: best
           save_last: true
           save_top_k: 1
-          monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/BinaryAccuracy}
+          monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/BinaryBalancedAccuracy}
           mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max}
       - class_path: lightning.pytorch.callbacks.EarlyStopping
         init_args:
8 changes: 3 additions & 5 deletions configs/vision/pathology/offline/classification/crc.yaml
@@ -53,11 +53,9 @@ model:
         out_features: &NUM_CLASSES 9
     criterion: torch.nn.CrossEntropyLoss
     optimizer:
-      class_path: torch.optim.SGD
+      class_path: torch.optim.AdamW
       init_args:
-        lr: &LR_VALUE ${oc.env:LR_VALUE, 0.01}
-        momentum: 0.9
-        weight_decay: 0.0
+        lr: ${oc.env:LR_VALUE, 0.0003}
     lr_scheduler:
       class_path: torch.optim.lr_scheduler.CosineAnnealingLR
       init_args:
@@ -104,7 +102,7 @@ data:
         split: val
     dataloaders:
       train:
-        batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 4096}
+        batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 256}
         num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4}
         shuffle: true
       val:
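A back-of-the-envelope view of what the smaller batch size means for schedule length (the dataset size here is purely illustrative, not the CRC split size):

```python
num_train_samples = 100_000  # illustrative only
batch_size = 256             # new default, previously 4,096
max_steps = 12_500           # patch-level step budget from docs/leaderboards.md

steps_per_epoch = num_train_samples // batch_size  # 390
epochs = max_steps / steps_per_epoch               # ~32 effective epochs
print(steps_per_epoch, round(epochs, 1))
```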
10 changes: 4 additions & 6 deletions configs/vision/pathology/offline/classification/mhist.yaml
@@ -18,12 +18,12 @@ trainer:
           filename: best
           save_last: true
           save_top_k: 1
-          monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/BinaryAccuracy}
+          monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/BinaryBalancedAccuracy}
           mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max}
       - class_path: lightning.pytorch.callbacks.EarlyStopping
         init_args:
           min_delta: 0
-          patience: 51
+          patience: 70
           monitor: *MONITOR_METRIC
           mode: *MONITOR_METRIC_MODE
       - class_path: eva.callbacks.ClassificationEmbeddingsWriter
@@ -53,11 +53,9 @@ model:
         out_features: 1
     criterion: torch.nn.BCEWithLogitsLoss
     optimizer:
-      class_path: torch.optim.SGD
+      class_path: torch.optim.AdamW
       init_args:
-        lr: &LR_VALUE ${oc.env:LR_VALUE, 0.000625}
-        momentum: 0.9
-        weight_decay: 0.0
+        lr: ${oc.env:LR_VALUE, 0.0003}
     lr_scheduler:
       class_path: torch.optim.lr_scheduler.CosineAnnealingLR
       init_args:
@@ -18,7 +18,7 @@ trainer:
           filename: best
           save_last: true
           save_top_k: 1
-          monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/BinaryAccuracy}
+          monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/BinaryBalancedAccuracy}
           mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max}
       - class_path: lightning.pytorch.callbacks.EarlyStopping
         init_args:
@@ -54,11 +54,9 @@ model:
         out_features: 1
     criterion: torch.nn.BCEWithLogitsLoss
     optimizer:
-      class_path: torch.optim.SGD
+      class_path: torch.optim.AdamW
       init_args:
-        lr: &LR_VALUE ${oc.env:LR_VALUE, 0.01}
-        momentum: 0.9
-        weight_decay: 0.0
+        lr: ${oc.env:LR_VALUE, 0.0003}
     lr_scheduler:
       class_path: torch.optim.lr_scheduler.CosineAnnealingLR
       init_args:
@@ -118,7 +116,7 @@ data:
         split: test
     dataloaders:
       train:
-        batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 4096}
+        batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 256}
         num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4}
         shuffle: true
       val:
6 changes: 2 additions & 4 deletions configs/vision/pathology/online/classification/bach.yaml
@@ -45,11 +45,9 @@ model:
         out_features: &NUM_CLASSES 4
     criterion: torch.nn.CrossEntropyLoss
     optimizer:
-      class_path: torch.optim.SGD
+      class_path: torch.optim.AdamW
       init_args:
-        lr: &LR_VALUE ${oc.env:LR_VALUE, 0.000625}
-        momentum: 0.9
-        weight_decay: 0.0
+        lr: ${oc.env:LR_VALUE, 0.0003}
     lr_scheduler:
       class_path: torch.optim.lr_scheduler.CosineAnnealingLR
       init_args:
8 changes: 3 additions & 5 deletions configs/vision/pathology/online/classification/crc.yaml
@@ -45,11 +45,9 @@ model:
         out_features: &NUM_CLASSES 9
     criterion: torch.nn.CrossEntropyLoss
     optimizer:
-      class_path: torch.optim.SGD
+      class_path: torch.optim.AdamW
       init_args:
-        lr: &LR_VALUE ${oc.env:LR_VALUE, 0.01}
-        momentum: 0.9
-        weight_decay: 0.0
+        lr: ${oc.env:LR_VALUE, 0.0003}
     lr_scheduler:
       class_path: torch.optim.lr_scheduler.CosineAnnealingLR
       init_args:
@@ -86,7 +84,7 @@ data:
         split: val
     dataloaders:
       train:
-        batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 4096}
+        batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 256}
         num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4}
         shuffle: true
       val:
10 changes: 4 additions & 6 deletions configs/vision/pathology/online/classification/mhist.yaml
@@ -17,12 +17,12 @@ trainer:
           filename: best
           save_last: true
           save_top_k: 1
-          monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/BinaryAccuracy}
+          monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/BinaryBalancedAccuracy}
           mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max}
       - class_path: lightning.pytorch.callbacks.EarlyStopping
         init_args:
           min_delta: 0
-          patience: 51
+          patience: 70
           monitor: *MONITOR_METRIC
           mode: *MONITOR_METRIC_MODE
     logger:
@@ -45,11 +45,9 @@ model:
         out_features: 1
     criterion: torch.nn.BCEWithLogitsLoss
     optimizer:
-      class_path: torch.optim.SGD
+      class_path: torch.optim.AdamW
       init_args:
-        lr: &LR_VALUE ${oc.env:LR_VALUE, 0.000625}
-        momentum: 0.9
-        weight_decay: 0.0
+        lr: ${oc.env:LR_VALUE, 0.0003}
     lr_scheduler:
       class_path: torch.optim.lr_scheduler.CosineAnnealingLR
       init_args:
@@ -17,7 +17,7 @@ trainer:
           filename: best
           save_last: true
           save_top_k: 1
-          monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/BinaryAccuracy}
+          monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/BinaryBalancedAccuracy}
           mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max}
       - class_path: lightning.pytorch.callbacks.EarlyStopping
         init_args:
@@ -45,11 +45,9 @@ model:
         out_features: 1
     criterion: torch.nn.BCEWithLogitsLoss
     optimizer:
-      class_path: torch.optim.SGD
+      class_path: torch.optim.AdamW
       init_args:
-        lr: &LR_VALUE ${oc.env:LR_VALUE, 0.01}
-        momentum: 0.9
-        weight_decay: 0.0
+        lr: ${oc.env:LR_VALUE, 0.0003}
     lr_scheduler:
       class_path: torch.optim.lr_scheduler.CosineAnnealingLR
       init_args:
@@ -91,7 +89,7 @@ data:
         split: test
     dataloaders:
       train:
-        batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 4096}
+        batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 256}
         num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4}
         shuffle: true
       val:
6 changes: 2 additions & 4 deletions configs/vision/tests/offline/panda.yaml
@@ -34,11 +34,9 @@ model:
         output_size: &NUM_CLASSES 6
     criterion: torch.nn.CrossEntropyLoss
     optimizer:
-      class_path: torch.optim.SGD
+      class_path: torch.optim.AdamW
       init_args:
-        lr: &LR_VALUE ${oc.env:LR_VALUE, 0.00004}
-        momentum: 0.9
-        weight_decay: 0.0
+        lr: ${oc.env:LR_VALUE, 0.0003}
     lr_scheduler:
       class_path: torch.optim.lr_scheduler.CosineAnnealingLR
       init_args:
6 changes: 2 additions & 4 deletions configs/vision/tests/offline/patch_camelyon.yaml
@@ -47,11 +47,9 @@ model:
         out_features: 1
     criterion: torch.nn.BCEWithLogitsLoss
     optimizer:
-      class_path: torch.optim.SGD
+      class_path: torch.optim.AdamW
       init_args:
-        lr: ${oc.env:LR_VALUE, 0.1}
-        momentum: 0.9
-        weight_decay: 0.0
+        lr: ${oc.env:LR_VALUE, 0.0003}
     lr_scheduler:
       class_path: torch.optim.lr_scheduler.CosineAnnealingLR
       init_args:
6 changes: 2 additions & 4 deletions configs/vision/tests/online/patch_camelyon.yaml
@@ -18,11 +18,9 @@ model:
         out_features: 1
     criterion: torch.nn.BCEWithLogitsLoss
     optimizer:
-      class_path: torch.optim.SGD
+      class_path: torch.optim.AdamW
       init_args:
-        lr: ${oc.env:LR_VALUE, 0.1}
-        momentum: 0.9
-        weight_decay: 0.0
+        lr: ${oc.env:LR_VALUE, 0.0003}
     lr_scheduler:
       class_path: torch.optim.lr_scheduler.CosineAnnealingLR
       init_args:
18 changes: 8 additions & 10 deletions docs/leaderboards.md
@@ -38,26 +38,24 @@ We selected this approach to prioritize reliable, robust and fair FM-evaluation
 | **Dropout**                     | 0.0                   | 0.0                    | 0.0                    |
 | **Hidden activation function**  | n/a                   | ReLU                   | n/a                    |
 | **Output activation function**  | none                  | none                   | none                   |
-| **Number of steps**             | 12,500                | 12,500 (2)             | 2,000                  |
-| **Base batch size**             | 4,096 (1)             | 32                     | 64                     |
-| **Base learning rate**          | 0.01 (1)              | 0.001                  | 0.0001                 |
-| **Early stopping**              | 5% * [Max epochs]     | 10% * [Max epochs] (3) | 10% * [Max epochs] (3) |
+| **Number of steps**             | 12,500                | 12,500 (1)             | 2,000                  |
+| **Base batch size**             | 256                   | 32                     | 64                     |
+| **Base learning rate**          | 0.0003                | 0.001                  | 0.0001                 |
+| **Early stopping**              | 5% * [Max epochs]     | 10% * [Max epochs] (2) | 10% * [Max epochs] (2) |
 | **Optimizer**                   | SGD                   | AdamW                  | AdamW                  |
 | **Momentum**                    | 0.9                   | n/a                    | n/a                    |
 | **Weight Decay**                | 0.0                   | n/a                    | n/a                    |
 | **betas**                       | n/a                   | [0.9, 0.999]           | [0.9, 0.999]           |
 | **LR Schedule**                 | Cosine without warmup | Cosine without warmup  | PolynomialLR           |
 | **Loss**                        | Cross entropy         | Cross entropy          | Dice                   |
-| **number of patches per slide** | 1                     | dataset specific (4)   | dataset specific (4)   |
+| **number of patches per slide** | 1                     | dataset specific (3)   | dataset specific (3)   |


-(1) For smaller datasets (e.g. BACH with 400 samples) we reduce the batch size to 256 and scale the learning rate accordingly.
+(1) Upper cap at a maximum of 100 epochs.

-(2) Upper cap at a maximum of 100 epochs.
+(2) Lower cap at a minimum of 8 epochs.

-(3) Lower cap at a minimum of 8 epochs.
-
-(4) Number of patches per slide depends on task and slide size. E.g. for PANDA and Camelyon16 we use a max of 1,000 and 10,000 random patches per slide respectively.
+(3) Number of patches per slide depends on task and slide size. E.g. for `PANDASmall` and `Camelyon16Small` we use a max of 200 and 1000 random patches per slide respectively.


 - [1]: [Virchow: A Million-Slide Digital Pathology Foundation Model, 2024](https://arxiv.org/pdf/2309.07778.pdf)
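To summarize the updated patch-level defaults from the table in plain PyTorch terms, here is a minimal sketch (the linear probe dimensions and the random dataset are placeholders; eva builds the real modules and dataloaders from the task configs):

```python
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

# Placeholder linear probe over frozen foundation-model embeddings.
probe = nn.Linear(768, 9)
criterion = nn.CrossEntropyLoss()

# New patch-level defaults: AdamW at lr=3e-4 (PyTorch default betas / weight decay),
# batch size 256, 12,500 steps, cosine schedule without warmup.
optimizer = optim.AdamW(probe.parameters(), lr=3e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=12_500)

# Placeholder data; real runs iterate over pre-computed patch embeddings.
loader = DataLoader(
    TensorDataset(torch.randn(1024, 768), torch.randint(0, 9, (1024,))),
    batch_size=256,
    shuffle=True,
)

step = 0
while step < 12_500:
    for embeddings, targets in loader:
        optimizer.zero_grad()
        loss = criterion(probe(embeddings), targets)
        loss.backward()
        optimizer.step()
        scheduler.step()
        step += 1
        if step >= 12_500:
            break
```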