-
Notifications
You must be signed in to change notification settings - Fork 6
/
config_patch.py
105 lines (87 loc) · 2.47 KB
/
config_patch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import ml_collections as mlc
import os
def get_config(arg=None):
    """Build the training configuration for the "patch" benchmark.

    The returned ``ml_collections.ConfigDict`` describes the inner-loop layer
    (TTT / self-attention / linear attention), the training routine, the
    optimizer, and the ImageNet-2012 input and evaluation pipelines.

    Args:
      arg: Optional argument, presumably supplied by a config-file loader.
        It is not used by this function.

    Returns:
      A populated ``ml_collections.ConfigDict``.
    """
    del arg  # Unused; kept so the signature stays loader-compatible.

    config = mlc.ConfigDict()
    config.benchmark = "patch"

    ###
    # Inner Loop
    ###
    config.inner = dict()
    config.inner.layer_type = 'TTT'  # 'TTT' | 'self_attention' | 'linear_attention'

    # Linear Attention
    config.inner.linear_attention = dict()
    config.inner.linear_attention.elu = True
    config.inner.linear_attention.normalizer = 'adaptive'  # 'adaptive' | 'constant'

    # TTT Layer
    config.inner.TTT = dict()
    config.inner.TTT.inner_encoder = "mlp_2"
    config.inner.TTT.inner_itr = 1
    config.inner.TTT.inner_lr = (1.,)
    config.inner.TTT.train_init = True
    config.inner.TTT.inner_encoder_init = "xavier_uniform"
    config.inner.TTT.inner_encoder_bias = True
    config.inner.TTT.decoder_LN = True
    config.inner.TTT.SGD = False  # For patch benchmark, use batch GD for inner optimization

    ###
    # Routine & Logging
    ###
    config.total_epochs = 90
    config.resume = False

    ###
    # Optimizer & Scheduler
    ###
    config.grad_clip_norm = 1.0
    config.optax_name = "scale_by_adam"
    config.optax = dict(mu_dtype="bfloat16")
    config.lr = 0.001
    config.wd = 0.0001
    config.wd_mults = [(".*/kernel$", 1.0)]

    ###
    # Common
    ###
    config.seed = 0
    config.tf_seed = 0
    config.num_classes = 1000
    config.loss = "softmax_xent"
    config.model = "small"  # "tiny" | "small"
    config.tfds_path = ""  # TODO: Change to your custom data path
    # Link every data_dir to config.tfds_path through one FieldReference, so that
    # overriding tfds_path after construction (e.g. via a command-line flag)
    # propagates. Copying the plain string would freeze data_dir at "".
    tfds_path_ref = config.get_ref("tfds_path")

    config.input = {}
    config.input.data = dict(
        name="imagenet2012",
        split="train",
        data_dir=tfds_path_ref,
    )
    config.input.batch_size = 1024
    config.input.accum_time = 1
    config.input.cache_raw = True  # Needs up to 120GB of RAM!
    config.input.shuffle_buffer_size = 250_000
    config.mixup = dict(p=0.2, fold_in=None)

    # The one-hot depth comes from num_classes so the two cannot drift apart.
    # This string is built once, here, from the current num_classes value.
    config.pp_common = (
        "|value_range(-1, 1)"
        f"|onehot({config.num_classes}, key='label', key_result='labels')"
        "|keep('image', 'labels')"
    )
    config.input.pp = (
        "decode_jpeg_and_inception_crop(224)"
        "|flip_lr"
        "|randaug(2,10)"
    ) + config.pp_common

    pp_eval = (
        "decode"
        "|resize_small(256)"
        "|central_crop(224)"
    ) + config.pp_common

    config.evals = {
        "type": "classification",
        "data": dict(
            name="imagenet2012",
            split="validation",
            data_dir=tfds_path_ref,
        ),
        "pp_fn": pp_eval,
        "loss_name": config.loss,
    }
    ###
    return config