-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
64 lines (54 loc) · 1.61 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import random
import numpy as np
import torch
import json
def set_seed(seed): # set the random seed for reproducibility
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
def strip_eos(sents):
return [sent[:sent.index('<eos>')] if '<eos>' in sent else sent
for sent in sents]
def load_sent(path):
sents = []
with open(path) as f:
for line in f:
sent = json.loads(line)
sent['text'] = sent['text'].split()
sents.append(sent)
return sents
def write_sent(sents, path):
with open(path, 'w') as f:
for s in sents:
f.write(' '.join(s) + '\n')
def write_doc(docs, path):
with open(path, 'w') as f:
for d in docs:
for s in d:
f.write(' '.join(s) + '\n')
f.write('\n')
def write_z(z, path):
with open(path, 'w') as f:
for zi in z:
for zij in zi:
f.write('%f ' % zij)
f.write('\n')
def logging(s, path, print_=True):
if print_:
print(s)
if path:
with open(path, 'a+') as f:
f.write(s + '\n')
def lerp(t, p, q):
return (1-t) * p + t * q
# spherical interpolation https://github.com/soumith/dcgan.torch/issues/14#issuecomment-199171316
def slerp(t, p, q):
o = np.arccos(np.dot(p/np.linalg.norm(p), q/np.linalg.norm(q)))
so = np.sin(o)
return np.sin((1-t)*o) / so * p + np.sin(t*o) / so * q
def interpolate(z1, z2, n):
z = []
for i in range(n):
zi = lerp(1.0*i/(n-1), z1, z2)
z.append(np.expand_dims(zi, axis=0))
return np.concatenate(z, axis=0)