-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
38 lines (27 loc) · 954 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
def get_mdl_params(model):
    """Return a flat 1-D CPU copy of all parameters of ``model``.

    Each parameter is cloned, detached from the autograd graph, moved to
    the CPU and flattened; the pieces are concatenated in the order given
    by ``model.parameters()``. Because every piece is a clone, the result
    does not share storage with the model.
    """
    flat_copies = [
        weights.clone().detach().cpu().reshape(-1)
        for weights in model.parameters()
    ]
    return torch.cat(flat_copies)
def param_to_vector(model):
    """Concatenate all parameters of ``model`` into one 1-D tensor.

    Unlike a detached copy, the parameters are flattened without
    ``detach()``, so the result stays connected to them in the autograd
    graph. Parameters are concatenated in ``model.parameters()`` order.
    """
    flat_views = [weights.reshape(-1) for weights in model.parameters()]
    return torch.cat(flat_views)
def set_client_from_params(device, model, params):
    """Load a flat parameter vector into ``model`` and move it to ``device``.

    ``params`` is consumed front to back: for each parameter (in
    ``model.parameters()`` order) the next ``numel()`` entries are reshaped
    to that parameter's shape and copied into its ``.data`` in place.
    Entries past the model's total parameter count are never read.

    Returns the result of ``model.to(device)``.
    """
    offset = 0
    for tensor in model.parameters():
        end = offset + tensor.numel()
        chunk = params[offset:end]
        tensor.data.copy_(chunk.reshape(tensor.shape))
        offset = end
    return model.to(device)
def get_params_list_with_shape(model, param_list):
    """Split a flat parameter vector into tensors shaped like ``model``'s parameters.

    Args:
        model: Module whose ``parameters()`` define the order and shapes.
        param_list: 1-D tensor holding the flattened values, laid out in
            ``model.parameters()`` order.

    Returns:
        A list with one tensor per model parameter, each reshaped to the
        corresponding parameter's shape.
    """
    vec_with_shape = []
    idx = 0
    for param in model.parameters():
        length = param.numel()
        vec_with_shape.append(param_list[idx:idx + length].reshape(param.shape))
        # Advance past the slice just consumed; without this every
        # parameter would be read from the start of ``param_list``.
        idx += length
    return vec_with_shape