-
Notifications
You must be signed in to change notification settings - Fork 0
/
model_train.py
56 lines (41 loc) · 1.57 KB
/
model_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from MLProto.MLProto.Proto import Proto
import data_pull as dp
import argparse
import json
import itertools
def _expand_grid(hypers):
    """Return every combination of the hyper-parameter options in *hypers*.

    Each key of *hypers* maps to a list of candidate values; the result is the
    Cartesian product, as a list of ``{key: value}`` dicts in key order.  A
    value that is not a list is treated as a single fixed option instead of
    being iterated (iterating a string would split it into characters).

    Raises:
        ValueError: if *hypers* is empty.
    """
    if not hypers:
        raise ValueError('hyper-parameter config is empty')
    keys = list(hypers)
    options = [v if isinstance(v, list) else [v] for v in hypers.values()]
    return [dict(zip(keys, combo)) for combo in itertools.product(*options)]


def train(config_path='model_train_config.json', epochs=10):
    """Grid-search Proto models over the hyper-parameters in a JSON config.

    Loads *config_path* and expands it into every combination of options.
    Pulls the data once via ``dp.cloud_to_df()``.  Then, for each
    combination, it builds a ``Proto`` model, trains, evaluates and saves it,
    printing progress to stdout.  Finally it prints the identifier of the
    model with the smallest ``model.loss``.

    Args:
        config_path: JSON file mapping each hyper-parameter read below
            (depth, node_counts, batch, test_size, loss, learning_rate,
            past_window) to a list of candidate values.  ``node_counts``
            options are sequences that get truncated to ``depth`` entries.
        epochs: epoch count passed to ``Proto.train`` for every model.

    Returns:
        The identifier of the best model.

    Raises:
        ValueError: if the config is empty or yields no combinations
            (e.g. one option list is empty).
    """
    # load training config values
    with open(config_path, 'r', encoding='utf-8') as config_file:
        hypers = json.load(config_file)
    # expand into one dict per configuration; validate before the
    # (potentially slow) data pull so a bad config fails fast
    configs = _expand_grid(hypers)
    if not configs:
        raise ValueError('hyper-parameter grid produced no configurations '
                         '(is one of the option lists empty?)')

    # get data once; it is shared by every model
    data = dp.cloud_to_df()
    print(data)
    print()

    # (loss, identifier) per trained model; a list rather than a dict keyed
    # by loss so models with equal loss do not overwrite each other
    results = []
    for iteration, config in enumerate(configs):
        # create model with given config; enumerate gives each model a
        # unique identifier so saved models do not clobber one another
        model = Proto('model' + str(iteration), data, 2,
                      depth=config['depth'],
                      node_counts=config['node_counts'][:config['depth']],
                      batch=config['batch'],
                      test_size=config['test_size'],
                      loss=config['loss'],
                      learning_rate=config['learning_rate'],
                      past_window=config['past_window'])
        # write a summary for each model to stdout (presumably redirected
        # to a log file by the caller)
        print(model.identifier + '\n----------------------------------------')
        # NOTE(review): if model.model is a Keras model, summary() prints
        # itself and returns None, so this line also prints "None" — confirm
        print(model.model.summary())
        print('TRAINING\n')
        # NOTE(review): meaning of the two positional flags is defined by
        # Proto.train — confirm
        model.train(epochs, True, True)
        # evaluate and record loss
        print('EVALUATING\n')
        model.evaluate()
        # NOTE(review): `loss` was passed in as config['loss'] (presumably a
        # loss-function name); ranking by model.loss is only meaningful if
        # Proto.evaluate() replaces it with the evaluation loss — confirm
        results.append((model.loss, model.identifier))
        # save model to models directory
        print('SAVING\n')
        model.save_model()
        print('FINISHED ' + model.identifier)

    best_identifier = min(results, key=lambda pair: pair[0])[1]
    print('BEST MODEL: ' + best_identifier + '\n')
    return best_identifier
if '__main__' == __name__:
    # Script entry point: run the full hyper-parameter sweep.  Nothing runs
    # when this module is imported.
    train()