-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
91 lines (74 loc) · 3 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#ANCHOR Libraries
import numpy as np
import torch
import os
import seaborn as sns
import matplotlib.pyplot as plt
import copy
#ANCHOR Print table of zeros and non-zeros count
def print_nonzeros(model):
    """Print a per-parameter sparsity table and return the percentage of non-zero weights.

    For every entry of ``model.named_parameters()`` one line is printed with the
    number of non-zero elements, the total number of elements, the non-zero
    percentage, the number of zero ("pruned") elements and the tensor shape.
    A summary line with the overall counts, the compression rate
    (``total / nonzero``) and the pruned percentage follows.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs whose ``parameter.data`` is a tensor
            convertible to NumPy via ``.cpu().numpy()``.

    Returns:
        Percentage of non-zero elements over all parameters, rounded to one
        decimal place. ``0.0`` is returned when the model has no elements at
        all (the original code raised ``ZeroDivisionError`` in that case).
    """
    nonzero = total = 0
    for name, p in model.named_parameters():
        tensor = p.data.cpu().numpy()
        nz_count = int(np.count_nonzero(tensor))
        # ``tensor.size`` is always a plain int; ``np.prod(tensor.shape)`` yields
        # the float 1.0 for 0-dim parameters, which distorts the table.
        total_params = tensor.size
        nonzero += nz_count
        total += total_params
        # A zero-element parameter has no meaningful percentage: show nan
        # instead of dividing by zero.
        layer_pct = 100 * nz_count / total_params if total_params else float("nan")
        print(f'{name:20} | nonzeros = {nz_count:7} / {total_params:7} ({layer_pct:6.2f}%) | total_pruned = {total_params - nz_count :7} | shape = {tensor.shape}')

    # Guard the summary divisions: a fully pruned model has an unbounded
    # compression rate, and an empty model has nothing pruned.
    compression = total / nonzero if nonzero else float("inf")
    pruned_pct = 100 * (total - nonzero) / total if total else 0.0
    print(f'alive: {nonzero}, pruned : {total - nonzero}, total: {total}, Compression rate : {compression:10.2f}x ({pruned_pct:6.2f}% pruned)')

    if not total:
        return 0.0
    return round((nonzero / total) * 100, 1)
def original_initialization(mask_temp, initial_state_dict, model=None):
    """Reset a model to its saved initial values, applying a mask to the weights.

    Parameters are visited in ``model.named_parameters()`` order:

    * a parameter whose name contains ``"weight"`` is set to
      ``mask_temp[i] * initial_state_dict[name]``, where ``i`` counts the
      weight parameters seen so far;
    * a parameter whose name contains ``"bias"`` is set to a copy of
      ``initial_state_dict[name]``.

    Args:
        mask_temp: Indexable collection of arrays, one per weight parameter,
            in ``named_parameters()`` order. Each is multiplied element-wise
            with the NumPy view of the matching initial weight.
        initial_state_dict: Mapping from parameter name to its initial tensor.
        model: Model to reset. When omitted, the module-level global ``model``
            is used, which is how the function originally worked.

    Raises:
        NameError: If ``model`` is not given and no module-level ``model``
            exists.
    """
    if model is None:
        # Backward-compatible fallback to the module-level ``model`` the
        # original implementation relied on (same NameError when it is absent).
        try:
            model = globals()["model"]
        except KeyError:
            raise NameError("name 'model' is not defined") from None

    step = 0
    for name, param in model.named_parameters():
        if "weight" in name:
            masked = mask_temp[step] * initial_state_dict[name].cpu().numpy()
            # Keep the parameter's device and dtype: a float64 mask would
            # otherwise silently turn the weight into a double tensor.
            param.data = torch.from_numpy(masked).to(device=param.device, dtype=param.dtype)
            step += 1
        if "bias" in name:
            # Clone so the live parameter does not share storage with the saved
            # snapshot; in-place updates would otherwise overwrite the snapshot.
            param.data = initial_state_dict[name].clone().to(param.device)
#ANCHOR Checks if the directory exists and, if not, creates it
def checkdir(directory):
    """Create ``directory`` (including missing parents) if the path does not exist.

    Args:
        directory: Path of the directory to ensure.

    Nothing is done when the path already exists, whatever its type.
    """
    if not os.path.exists(directory):
        # ``exist_ok=True`` closes the race where another process creates the
        # directory between the existence check and this call.
        os.makedirs(directory, exist_ok=True)
#FIXME
def plot_train_test_stats(stats,
                          epoch_num,
                          key1='train',
                          key2='test',
                          key1_label=None,
                          key2_label=None,
                          xlabel=None,
                          ylabel=None,
                          title=None,
                          yscale=None,
                          ylim_bottom=None,
                          ylim_top=None,
                          savefig=None,
                          sns_style='darkgrid'
                          ):
    """Plot two per-epoch series from ``stats`` on the current pyplot figure.

    Args:
        stats: Mapping indexed by ``key1`` and ``key2``; each entry must have
            length ``epoch_num``.
        epoch_num: Number of epochs; the x axis is ``np.arange(epoch_num)``.
        key1, key2: Keys of the two series to draw.
        key1_label, key2_label: Legend labels passed to ``plt.plot``.
        xlabel, ylabel, title, yscale: Applied through the matching pyplot
            call only when not ``None``.
        ylim_bottom, ylim_top: Y-axis limits, each applied only when not
            ``None``.
        savefig: Target passed to ``plt.savefig``; when ``None`` the figure is
            shown with ``plt.show()`` instead.
        sns_style: Style name handed to ``sns.set_style``.

    Raises:
        AssertionError: If either series does not have ``epoch_num`` entries.
    """
    # Both series must provide exactly one value per epoch.
    for series_key in (key1, key2):
        assert len(stats[series_key]) == epoch_num, "len(stats['{}'])({}) != epoch_num({})".format(series_key, len(stats[series_key]), epoch_num)

    # Start from a cleared current figure and apply the seaborn style.
    plt.clf()
    sns.set_style(sns_style)

    epochs = np.arange(epoch_num)
    for series_key, legend_label in ((key1, key1_label), (key2, key2_label)):
        plt.plot(epochs, stats[series_key], label=legend_label)

    # Optional decorations: each pyplot call runs only when a value was given.
    optional_settings = (
        (plt.xlabel, xlabel),
        (plt.ylabel, ylabel),
        (plt.title, title),
        (plt.yscale, yscale),
    )
    for apply_setting, value in optional_settings:
        if value is not None:
            apply_setting(value)

    if ylim_bottom is not None:
        plt.ylim(bottom=ylim_bottom)
    if ylim_top is not None:
        plt.ylim(top=ylim_top)

    # Legend is anchored just outside the right edge of the axes.
    plt.legend(bbox_to_anchor=(1.04,0.5), loc="center left", borderaxespad=0, fancybox=True)

    if savefig is None:
        plt.show()
    else:
        plt.savefig(savefig, bbox_inches='tight')