# Create a vocabulary wrapper
import argparse
import json
import os
import pickle
from collections import Counter

import nltk  # word_tokenize needs the 'punkt' tokenizer data: nltk.download('punkt')
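
# Caption sources per dataset. build_vocab() resolves each entry relative to
# <data_path>/<data_name>/, e.g. data/f30k_precomp/train_caps.txt under the
# default --data_path ('data'); adjust these if your data folder is laid out
# differently (the layout is inferred from the path join in build_vocab).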
annotations = {
    'coco_precomp': ['train_caps.txt', 'dev_caps.txt'],
    'coco': ['annotations/captions_train2014.json',
             'annotations/captions_val2014.json'],
    'f8k_precomp': ['train_caps.txt', 'dev_caps.txt'],
    '10crop_precomp': ['train_caps.txt', 'dev_caps.txt'],
    'f30k_precomp': ['train_caps.txt', 'dev_caps.txt'],
    'f8k': ['dataset_flickr8k.json'],
    'f30k': ['dataset_flickr30k.json'],
}

class Vocabulary(object):
    """Simple vocabulary wrapper mapping words to integer ids and back."""

    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __call__(self, word):
        # Out-of-vocabulary words fall back to the '<unk>' token.
        if word not in self.word2idx:
            return self.word2idx['<unk>']
        return self.word2idx[word]

    def __len__(self):
        return len(self.word2idx)
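
# A quick sanity check of the wrapper (illustrative only; indices depend on
# insertion order, and '<unk>' must be added before lookups of unseen words):
#
#   v = Vocabulary()
#   v.add_word('<unk>')               # index 0
#   v.add_word('dog')                 # index 1
#   assert v('dog') == 1
#   assert v('zebra') == v('<unk>')   # unseen word -> '<unk>' index
#   assert len(v) == 2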

def from_coco_json(path):
    # pycocotools is only needed for the 'coco' setting, so import it here.
    from pycocotools.coco import COCO
    coco = COCO(path)
    captions = []
    for idx in coco.anns.keys():
        captions.append(str(coco.anns[idx]['caption']))
    return captions

def from_flickr_json(path):
    with open(path, 'r') as f:
        dataset = json.load(f)['images']
    captions = []
    for d in dataset:
        captions += [str(x['raw']) for x in d['sentences']]
    return captions
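
# The Flickr JSON files follow the Karpathy-split layout, roughly:
#   {"images": [{"sentences": [{"raw": "A dog runs ..."}, ...]}, ...]}
# (shape inferred from the accesses above; only the 'raw' caption strings
# are used here).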

def from_txt(txt):
    # Open in text mode so downstream .lower()/tokenization gets str, not bytes.
    captions = []
    with open(txt, 'r', encoding='utf-8') as f:
        for line in f:
            captions.append(line.strip())
    return captions

def build_vocab(data_path, data_name, jsons, threshold):
    """Build a simple vocabulary wrapper from the captions of `data_name`."""
    counter = Counter()
    for path in jsons[data_name]:
        full_path = os.path.join(data_path, data_name, path)
        if data_name == 'coco':
            captions = from_coco_json(full_path)
        elif data_name in ('f8k', 'f30k'):
            captions = from_flickr_json(full_path)
        else:
            captions = from_txt(full_path)
        for i, caption in enumerate(captions):
            tokens = nltk.tokenize.word_tokenize(caption.lower())
            counter.update(tokens)
            if i % 1000 == 0:
                print("\r[%d/%d] tokenized the captions." % (i, len(captions)),
                      end='')
        print()

    # Discard words whose occurrence count is below the threshold, then sort
    # the survivors by frequency, most frequent first.
    counts = [(word, cnt) for word, cnt in counter.items() if cnt >= threshold]
    counts.sort(key=lambda x: x[1], reverse=True)
    print('%d words passed the frequency threshold.' % len(counts))

    # Create a vocab wrapper and add some special tokens.
    vocab = Vocabulary()
    vocab.add_word('<pad>')
    vocab.add_word('<start>')
    vocab.add_word('<end>')
    vocab.add_word('<unk>')

    # Keep only the `chosen_nums` most frequent words.
    chosen_nums = 256
    for word, count in counts[:chosen_nums]:
        vocab.add_word(word)
    return vocab

def main(data_path, data_name):
    vocab = build_vocab(data_path, data_name, jsons=annotations, threshold=300)
    with open('%s_vocab.pkl' % data_name, 'wb') as f:
        pickle.dump(vocab, f, pickle.HIGHEST_PROTOCOL)
    print("Saved vocabulary file to %s_vocab.pkl" % data_name)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', default='data')
    parser.add_argument('--data_name', default='f30k',
                        help='{coco,f8k,f30k,10crop}_precomp|coco|f8k|f30k')
    opt = parser.parse_args()
    main(opt.data_path, opt.data_name)
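
# Reloading the saved vocabulary elsewhere (a minimal sketch; pickle needs the
# Vocabulary class importable at load time, e.g. `from vocab import Vocabulary`):
#
#   import pickle
#   with open('f30k_vocab.pkl', 'rb') as f:
#       vocab = pickle.load(f)
#   print(len(vocab), vocab('<start>'))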