-
Notifications
You must be signed in to change notification settings - Fork 19
/
train_common.py
69 lines (55 loc) · 2.21 KB
/
train_common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#!/usr/bin/env python3
import latplan.main
from latplan.util import *
# Length (in epochs) of the GS temperature-annealing phase
# (presumably Gumbel-Softmax, given the min/max temperature settings below).
gs_annealing_epoch = 1000
# Epochs trained after the annealing window; total epochs = annealing + main.
main_epoch = 1000

# Hyperparameter specification handed to latplan.main.main().
# Each entry is either a list or a single value:
#   * list     -> a set of candidate values for a hyperparameter choice;
#   * non-list -> a fixed hyperparameter.
# Key order is kept deliberately stable.
parameters = dict(
    test_noise=False,   # if true, noise is added during testing
    test_hard=True,     # if true, latent output is discrete
    train_noise=True,   # if true, noise is added during training
    train_hard=False,   # if true, latent output is discrete
    dropout_z=False,
    noise=0.2,
    dropout=0.2,
    optimizer="radam",
    min_temperature=0.5,
    # Annealing runs from epoch 0 to gs_annealing_epoch; the remaining
    # main_epoch epochs follow it.
    epoch=gs_annealing_epoch + main_epoch,
    gs_annealing_start=0,
    gs_annealing_end=gs_annealing_epoch,
    clipnorm=0.1,
    batch_size=[400],
    lr=[0.001],
    N=[100],            # latent space size
    zerosuppress=0.1,
    densify=False,
    max_temperature=[5.0],
    # Encoder/decoder hyperparameters. A given network class reads only the
    # subset relevant to it (e.g. CubeSpaceAE_AMA3Conv uses the convolutional
    # settings and ignores the fully-connected ones). Unused entries are still
    # recorded but have no effect on the network.
    # -- convolutional --
    conv_channel=[32],
    conv_channel_increment=[1],
    conv_kernel=[5],
    conv_pooling=[1],   # no pooling
    conv_per_pooling=[1],
    # NOTE(review): an older comment here read "has_conv_layer = True; so just
    # one convolution", which does not match a depth of 3 -- confirm intent.
    conv_depth=[3],
    # -- fully connected --
    fc_width=[100],
    fc_depth=[2],
    # -- aae --
    A=[6000],
    aae_activation=['relu'],
    aae_width=[1000],
    aae_depth=[2],
    eff_regularizer=[None],
    beta_d=[1],
    beta_z=[1],
    output="GaussianOutput(sigma=0.1)",
)

if __name__ == "__main__":
    latplan.main.main(parameters)