-
Notifications
You must be signed in to change notification settings - Fork 7
/
utils.py
51 lines (42 loc) · 1.34 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os
import pickle
import torch
import numpy as np
def save(toBeSaved, filename, mode='wb'):
    """Pickle ``toBeSaved`` to ``filename``, creating parent directories as needed.

    Args:
        toBeSaved: Any picklable object.
        filename: Destination path. A bare file name with no directory
            component is written relative to the current working directory.
        mode: Mode passed to ``open``; ``pickle.dump`` needs a binary
            writable file.

    Raises:
        OSError: If the directory cannot be created or the file cannot be
            opened.
    """
    dirname = os.path.dirname(filename)
    # os.path.dirname() returns '' for a bare file name and os.makedirs('')
    # raises, so only create a directory when there is one to create.
    if dirname:
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(dirname, exist_ok=True)
    # The context manager closes the file even if pickle.dump() raises.
    with open(filename, mode) as file:
        pickle.dump(toBeSaved, file)
def load(filename, mode='rb'):
    """Load and return the object pickled in ``filename``.

    Args:
        filename: Path of the pickle file to read.
        mode: Mode passed to ``open``; ``pickle.load`` needs a binary
            readable file.

    Returns:
        The unpickled object.

    Raises:
        OSError: If the file cannot be opened.
    """
    # NOTE: unpickling can execute arbitrary code; only load files from a
    # trusted source.
    # The context manager closes the file even if pickle.load() raises.
    with open(filename, mode) as file:
        return pickle.load(file)
def pad_sents(sents, pad_token):
    """Right-pad every sentence with ``pad_token`` to the longest length.

    Args:
        sents: List of sentences, each a list of tokens.
        pad_token: Value appended to sentences shorter than the longest one.

    Returns:
        A new list of new lists, all of equal length; ``sents`` itself is
        not modified. An empty input yields an empty list.
    """
    # default=0 keeps an empty batch from raising ValueError in max().
    max_len = max(map(len, sents), default=0)
    return [sent + [pad_token] * (max_len - len(sent)) for sent in sents]
def sort_sents(sents, reverse=True):
    """Sort ``sents`` in place by sentence length and return the same list.

    Args:
        sents: List of sentences; it is reordered in place.
        reverse: When true (the default) the longest sentence comes first.

    Returns:
        The same list object that was passed in, now sorted.
    """
    sents.sort(key=len, reverse=reverse)
    return sents
def get_mask(sents, unmask_idx=1, mask_idx=0):
    """Build a padding mask for a batch of variable-length sentences.

    Args:
        sents: List of sentences (anything supporting ``len``).
        unmask_idx: Value written at positions covered by a real token.
        mask_idx: Value written at positions past the end of a sentence.

    Returns:
        A list with one row per sentence, each as long as the longest
        sentence: ``unmask_idx`` repeated ``len(sentence)`` times followed
        by ``mask_idx``. An empty input yields an empty list.
    """
    # default=0 keeps an empty batch from raising ValueError in max().
    max_len = max(map(len, sents), default=0)
    return [
        [unmask_idx] * len(sent) + [mask_idx] * (max_len - len(sent))
        for sent in sents
    ]
def get_lens(sents):
    """Return a list holding the length of each sentence, in input order."""
    return list(map(len, sents))
def get_max_len(sents):
    """Return the length of the longest sentence in ``sents``.

    Raises:
        ValueError: If ``sents`` is empty.
    """
    return max(map(len, sents))
def truncate_sents(sents, length):
    """Return a new list with each sentence cut to at most ``length`` items.

    Each sentence is sliced as ``sent[:length]``, so the input list and its
    sentences are left untouched.
    """
    head = slice(None, length)
    return [sent[head] for sent in sents]
def get_loss_weight(labels, label_order):
    """Return the relative frequency of each label as a tensor.

    Args:
        labels: Array-like of observed labels (NumPy array or plain
            sequence).
        label_order: Iterable of label values; the output follows its order.

    Returns:
        A 1-D ``torch.Tensor`` with one entry per value in ``label_order``:
        the number of elements of ``labels`` equal to that value divided by
        ``len(labels)``.
    """
    # A plain list compared with a scalar (`labels == lo`) is a single False,
    # which would make every count 0; an ndarray compares element-wise.
    labels = np.asarray(labels)
    return torch.tensor(
        [np.sum(labels == lo) / len(labels) for lo in label_order]
    )