-
Notifications
You must be signed in to change notification settings - Fork 0
/
pretrain_params.py
45 lines (37 loc) · 1.48 KB
/
pretrain_params.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from dataclasses import dataclass
import torchvision.models as models
from omegaconf import SI
# Alphabetically sorted names of every lowercase, non-dunder, callable attribute
# of ``torchvision.models``.
# NOTE(review): presumably the set of valid values for ``PretrainParams.arch`` —
# nothing in this file enforces that; confirm against the caller.
model_names = sorted(
    attr_name
    for attr_name, attr_value in vars(models).items()
    if attr_name.islower()
    and not attr_name.startswith("__")
    and callable(attr_value)
)
@dataclass
class PretrainParams:
    """Hyper-parameters for MoCo-style pre-training with optional MoCHi settings.

    ``epochs`` and ``mochi_tau`` default to OmegaConf ``SI`` interpolations, so
    their values come from other config nodes (``trainer_params.max_epochs``
    and ``pretraining_params.moco_t``) rather than being set here.

    NOTE(review): the "(default: ...)" remarks below were carried over from the
    original comments and look like upstream reference values, not the defaults
    of this class — confirm.
    """

    # Path to the dataset containing train and val folders.
    data: str = "~/Imagenet-100"
    arch: str = "resnet50"
    workers: int = 8
    batch_size: int = 128
    # Interpolated from the trainer configuration.
    epochs: int = SI("${trainer_params.max_epochs}")
    lr: float = 0.03
    # NOTE(review): presumably the epochs at which a step LR schedule decays;
    # confirm against the training loop.
    schedule: tuple = (120, 160)
    momentum: float = 0.9
    weight_decay: float = 0.0001

    # --- MoCo-specific configs ---
    moco_dim: int = 128
    # Queue size; number of negative keys (default: 65536).
    moco_k: int = 16384
    # MoCo momentum of updating the key encoder (default: 0.999).
    moco_m: float = 0.999
    # Softmax temperature (default: 0.07).
    moco_t: float = 0.2

    # --- Options for MoCo v2 ---
    # Use an MLP head.
    mlp: bool = True
    # Use MoCo v2 data augmentation.
    aug_plus: bool = True
    # Use a cosine LR schedule.
    cos: bool = True

    # region MOCHI PARAMETERS
    # Use MoCHi.
    mochi: bool = True
    # Hard negative pool size.
    mochi_N: int = 1024
    # Number of harder negatives between two pairs of hard negatives.
    mochi_s: int = 1024
    # Number of harder negatives between a negative and the positive.
    mochi_s_prime: int = 512
    # Interpolated from ``moco_t`` under ``pretraining_params``.
    mochi_tau: float = SI("${pretraining_params.moco_t}")
    # Learning-rate warmup epochs; hard negatives are not synthesized during them.
    mochi_warmup: int = 10
    # endregion