forked from tbmoon/facenet
-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
74 lines (55 loc) · 1.89 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import numpy as np
import torch
import visdom
from path import Path
class ModelSaver():
def __init__(self):
self._previous_acc = 0.
self._current_acc = 0.
@property
def previous_acc(self):
return self._previous_acc
@property
def current_acc(self):
return self._current_acc
@current_acc.setter
def current_acc(self, value):
self._current_acc = value
@previous_acc.setter
def previous_acc(self, value):
self._previous_acc = value
def __set_accuracy(self, accuracy):
self.previous_acc, self.current_acc = self.current_acc, accuracy
def save_if_best(self, accuracy, state):
if accuracy > self.current_acc:
self.__set_accuracy(accuracy)
torch.save(state, 'log/best_state.pth')
def create_if_not_exist(path):
path = Path(path)
if not path.exists(): path.touch()
def init_log_just_created(path):
create_if_not_exist(path)
with open(path, 'r') as f:
if len(f.readlines()) <= 0:
init_log_line(path)
def init_log_line(path):
with open(path, 'w') as f:
f.write('time,epoch,acc,loss,layers,bs,lr\n')
class VisdomLinePlotter(object):
"""Plots to Visdom"""
def __init__(self, env_name='main'):
self.viz = visdom.Visdom()
self.viz.check_connection()
self.env = env_name
self.plots = {}
def plot(self, var_name, split_name, x, y):
if var_name not in self.plots:
self.plots[var_name] = self.viz.line(X=np.array([x, x]), Y=np.array([y, y]), env=self.env, opts=dict(
legend=[split_name],
title=var_name,
xlabel='Epochs',
ylabel=var_name
))
else:
self.viz.line(X=np.array([x]), Y=np.array([y]), env=self.env, win=self.plots[var_name],
name=split_name, update='append')