-
Notifications
You must be signed in to change notification settings - Fork 0
/
beautify.py
89 lines (69 loc) · 2.42 KB
/
beautify.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import argparse
# import tensorflow as tf
# from tensorflow.python.summary import event_accumulator as ea
from tensorboard.backend.event_processing import event_accumulator as ea
from matplotlib import pyplot as plt
from matplotlib import colors as colors
import seaborn as sns
# Global seaborn theme applied to every figure this script draws:
# dark grid background with the compact "paper" context (font sizes / line widths).
sns.set(context="paper", style="darkgrid")
def _split_runs(x, y, restart_step):
    """Split parallel step/value lists into runs delimited by ``restart_step``.

    Every index whose step equals ``restart_step`` starts a new run; a run
    extends up to (not including) the next such index, and the last run
    extends to the end of the data. Points before the first restart are kept
    as their own run, and if ``restart_step`` never occurs the whole series is
    a single run.

    Returns:
        list of ``(xs, ys)`` tuples, one per non-empty run (empty input -> []).
    """
    starts = [i for i, step in enumerate(x) if step == restart_step]
    if not starts or starts[0] != 0:
        starts.insert(0, 0)
    bounds = starts + [len(x)]
    return [(x[a:b], y[a:b]) for a, b in zip(bounds, bounds[1:]) if b > a]


def _smooth_curve(xs, ys, window):
    """Return a block-averaged copy of one run.

    The run is cut into consecutive, non-overlapping windows of ``window``
    points; each window contributes its first step and the mean of its values
    (the final window may be shorter and is averaged over its real length).
    The run's last raw point is appended so the smoothed curve ends where the
    raw curve does.
    """
    smooth_x = []
    smooth_y = []
    for i in range(0, len(xs), window):
        chunk = ys[i:i + window]
        smooth_x.append(xs[i])
        smooth_y.append(sum(chunk) / len(chunk))
    if xs:
        smooth_x.append(xs[-1])
        smooth_y.append(ys[-1])
    return smooth_x, smooth_y


def plot(params):
    """Beautify a TensorBoard log: re-plot one scalar tag with seaborn styling.

    The tag's series is split into runs (a new run starts wherever the step
    equals ``restart_step``). Each run is drawn on figure 1 twice: the raw
    curve in a translucent version of the colour, and a block-averaged
    smoothed curve on top. The figure is then shown.

    Args:
        params: dict with keys
            ``'logdir'`` -- path handed to ``EventAccumulator``;
            ``'smooth'`` -- smoothing window size in points (coerced to int);
            ``'color'`` -- matplotlib colour spec for both curves;
            ``'tag'`` (optional, default ``'Return1000'``) -- scalar tag to plot;
            ``'restart_step'`` (optional, default ``2048``) -- step value that
            marks the first point of a run.

    Raises:
        ValueError: if the smoothing window is smaller than 1.
    """
    log_path = params['logdir']
    # argparse may deliver a float (e.g. 100.0); range() needs an int step.
    window = int(params['smooth'])
    if window < 1:
        raise ValueError(
            f"smoothing window must be >= 1, got {params['smooth']!r}")
    color_code = params['color']
    target_tag = params.get('tag', 'Return1000')
    restart_step = params.get('restart_step', 2048)

    acc = ea.EventAccumulator(log_path)
    acc.Reload()
    # only support scalar now
    scalar_list = acc.Tags()['scalars']
    print(scalar_list)

    for tag in scalar_list:
        if tag != target_tag:
            continue
        events = acc.Scalars(tag)
        x = [int(s.step) for s in events]
        y = [s.value for s in events]
        runs = _split_runs(x, y, restart_step)
        if not runs:
            continue

        plt.figure(1)
        plt.subplot(111)
        plt.title(tag)
        raw_color = colors.to_rgba(color_code, alpha=0.4)
        for run_x, run_y in runs:
            smooth_x, smooth_y = _smooth_curve(run_x, run_y, window)
            plt.plot(run_x, run_y, color=raw_color)
            plt.plot(smooth_x, smooth_y, color=color_code, linewidth=1.5)
        plt.show()
if __name__ == '__main__':
    # CLI entry point: collect options into a plain dict and hand it to plot().
    parser = argparse.ArgumentParser(
        description='Re-plot a scalar from a TensorBoard event file with seaborn styling.')
    parser.add_argument('--logdir', default='./logdir', type=str, help='logdir to event file')
    # int, not float: the value is used as a window length / range() step.
    parser.add_argument('--smooth', default=100, type=int, help='window size for average smoothing')
    parser.add_argument('--color', default='#4169E1', type=str, help='HTML code for the figure')
    args = parser.parse_args()
    params = vars(args)  # convert to ordinary dict
    plot(params)