-
Notifications
You must be signed in to change notification settings - Fork 1
/
toSVM.py
278 lines (208 loc) · 7.78 KB
/
toSVM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
## read from mongodb.features.xxx
## generate svm feature file
import config
import pickle
import sys, color
import pymongo, os, re
from collections import defaultdict
from bson.objectid import ObjectId
import logging
## ---------------------- global ----------------------- ##
# MongoClient replaces pymongo.Connection, which was removed in PyMongo 3.0
# (MongoClient is available since PyMongo 2.4).
db = pymongo.MongoClient(config.mongo_addr)[config.db_name]
# udocID -> eid mapping; declared as module state but not populated by any code visible here
udocID_eid = {}
# handle to the feature-settings collection; assigned in run()
co_feature_setting = None
# raw setting-ID string from the command line; assigned in the __main__ block
setting_id_str = ''
# emotion name -> class index, numbered in sorted order of the emotions
# labelled 'LJ40K' in the `emotions` collection (queried at import time)
eids = {
    emotion: i
    for i, emotion in enumerate(sorted(d['emotion'] for d in db['emotions'].find({'label': 'LJ40K'})))
}
## ------------------------------------------------------- ##
def parse_src_setting_ids(text=None):
    """Parse the setting-ID string given on the command line.

    IDs are 24-digit hexadecimal MongoDB ObjectIds, matched case-insensitively
    and separated by any run of non-alphanumeric characters (``,`` ``.`` ``;``
    ``:`` whitespace, brackets, ...), e.g.
    "537086fcd4388c7e81676914,537086fcd4388c7e81676915 537086fcd4388c7e81676916".

    :param text: string to parse; defaults to the module-level ``setting_id_str``.
    :returns: sorted list of unique lower-cased IDs, or False when no ID is
        present or any token is not a valid 24-hex-digit ID.
    """
    if text is None:
        text = setting_id_str
    tokens = [tok for tok in re.split(r'[^0-9a-z]+', text.lower()) if tok]
    if not tokens:
        return False
    # Reject the whole input instead of silently truncating or skipping a
    # malformed ID: a typo would otherwise select the wrong (or no) setting.
    bad = [tok for tok in tokens if not re.match(r'^[0-9a-f]{24}$', tok)]
    if bad:
        logging.error('invalid setting ID(s): ' + ', '.join(bad))
        return False
    # Duplicates would fuse a setting with itself and emit repeated feature ids.
    return sorted(set(tokens))
## src_setting_id: used for mongo fetching
## dest_setting_id: used for train/test naming
def obtain_dest_setting_id(src_setting_ids, collection=None):
    """Map the source setting ID(s) to the ID used to name the output files.

    - no IDs (False / None / empty): returns False
    - one ID ("normal mode"): returns that ID unchanged
    - several IDs ("fused mode"): looks up the document whose ``sources``
      field is the sorted, comma-joined IDs; if none exists, inserts
      ``{'sources': ..., 'feature_name': 'fusion'}``. Returns the document's
      ``_id`` as a string.

    :param collection: feature-settings collection; defaults to the
        module-level ``co_feature_setting``.
    """
    if collection is None:
        collection = co_feature_setting
    # input format error
    if not src_setting_ids:
        return False
    # normal mode
    if len(src_setting_ids) == 1:
        return src_setting_ids[0]
    # fused mode: sort so the same set of sources always maps to the same
    # fusion document, whatever order the caller passed them in
    sources = ','.join(sorted(src_setting_ids))
    mdoc = collection.find_one({'sources': sources})
    if mdoc:
        # already fused, reuse the existing fusion id
        return str(mdoc['_id'])
    # new fusion: generate a fused setting id
    return str(collection.insert({'sources': sources, 'feature_name': 'fusion'}))
def is_dest_files_exist(dest_paths):
    """Return True if every path among ``dest_paths``' values (the ``_root_``
    directory included) exists on disk, otherwise False."""
    return all(os.path.exists(path) for path in dest_paths.values())
def get_dest_paths(dest_setting_id, ext='txt', root='tmp'):
    """Build the output paths for ``dest_setting_id`` and ensure ``root`` exists.

    Returns ``{'_root_': root, 'train': ..., 'test': ..., 'gold': ...}`` where
    each file path is ``<root>/<dest_setting_id>.<ftype>.<ext>``. Keys wrapped
    in underscores are metadata and are not output files.

    :param ext: file extension for the output files.
    :param root: output directory (relative to the CWD unless absolute);
        created, including missing parents, if needed.
    """
    if not os.path.isdir(root):
        try:
            os.makedirs(root)
        except OSError:
            # tolerate another process creating it between the check and makedirs
            if not os.path.isdir(root):
                raise
    dest_paths = {'_root_': root}
    for ftype in ('train', 'test', 'gold'):
        fn = '.'.join([dest_setting_id, ftype, ext])
        dest_paths[ftype] = os.path.join(root, fn)
    return dest_paths
# src_setting_ids: [537086fcd4388c7e81676914, 537086fcd4388c7e81676915, ...]
def generate_feature_vectors(src_setting_ids):
    """Collect the feature documents of each source setting into sparse vectors.

    For each setting ID, ``feature_name`` is read from the feature-settings
    collection and every document in ``features.<feature_name>`` whose
    ``setting`` equals that ID is read. Feature names are prefixed with
    ``<setting_id>#`` before being mapped to integer feature ids (fids,
    assigned 0, 1, 2, ... in first-seen order), so the same name from two
    fused settings gets two distinct fids.

    :returns: ``{eid: {udocID: [(fid, value), ...]}}`` with eid taken from the
        module-level ``eids`` map, or False (after logging an error) when a
        setting is missing/malformed or has no documents.
    """
    from bson.errors import InvalidId

    feature_pool = {}  # prefixed feature name -> fid
    feature_vectors = {}
    for src_setting_id in src_setting_ids:
        # find feature_name --> collection_name
        try:
            setting = co_feature_setting.find_one({'_id': ObjectId(src_setting_id)})
        except (InvalidId, TypeError):
            setting = None
        if not setting or 'feature_name' not in setting:
            logging.error('check the format feature setting: ' + color.render(src_setting_id, 'y') + ' in mongodb')
            return False
        collection_name = 'features.' + setting['feature_name']
        ## use src_setting_id as prefix
        prefix = src_setting_id

        # Iterate the cursor once (instead of count() followed by find()) and
        # detect "no instances" afterwards.
        number = 0
        for mdoc in db[collection_name].find({'setting': src_setting_id}):
            number += 1
            ## use emotion index as eid
            eid = eids[mdoc['emotion']]
            vectors = feature_vectors.setdefault(eid, defaultdict(list))
            # Touch the entry so documents without features still yield a
            # (label-only) vector, WITHOUT wiping features that an earlier
            # fused setting already added for this udocID.
            vector = vectors[mdoc['udocID']]
            for f_name, f_value in (mdoc['feature'] or []):
                # combine f_name with prefix, then map it to a fid
                f_name = '#'.join([prefix, f_name])
                if f_name not in feature_pool:
                    feature_pool[f_name] = len(feature_pool)
                vector.append((feature_pool[f_name], f_value))
        if number == 0:
            logging.error("can't find any instances with id " + color.render(src_setting_id, 'y') + ' in ' + color.render(collection_name, 'g'))
            return False
    return feature_vectors
def tranform_to_svm_format(feature_vectors):
    """Render sparse feature vectors as ``label fid:value ...`` text lines.

    :param feature_vectors: ``{eid: {udocID: [(fid, value), ...]}}``, e.g.
        ``{38: {31000: [(3, 1), (2, 1), (1, 6), (0, 2)]}}``
    :returns: ``defaultdict(list)`` mapping eid to ``[(udocID, line), ...]``,
        e.g. ``{38: [(31000, '38 0:2 1:6 2:1 3:1')]}``. Each line is the eid
        followed by the pairs sorted by fid; an empty vector gives just the eid.

    NOTE(review): fids start at 0; some SVM tools (e.g. LIBLINEAR) require
    indices starting at 1 — confirm against the consumer of these files.
    """
    str_feature_vectors = defaultdict(list)
    for eid, vectors in feature_vectors.items():
        for udocID, pairs in vectors.items():
            # Build a real list (the old map()-then-insert breaks wherever
            # map() returns an iterator).
            tokens = [str(eid)]
            tokens.extend(str(fid) + ':' + str(value)
                          for fid, value in sorted(pairs, key=lambda p: p[0]))
            str_feature_vectors[eid].append((udocID, ' '.join(tokens)))
    return str_feature_vectors
# {'test': 'tmp/5380557a3681dfc8523cd24e.test.txt', 'train': 'tmp/5380557a3681dfc8523cd24e.train.txt', '_root_': 'tmp', 'gold': 'tmp/5380557a3681dfc8523cd24e.gold.txt'}
def generate_train_test_files(str_feature_vectors, dest_paths, n_train=800):
    """Split each class's vectors into train/test and write the output files.

    Classes are processed in ascending eid order. Within a class the vectors
    are sorted by udocID; the first ``n_train`` lines go to the train file,
    the rest to the test file, and one gold label (the eid) per test line to
    the gold file. A class with ``n_train`` or fewer vectors adds nothing to
    test/gold.

    :param str_feature_vectors: ``{eid: [(udocID, line), ...]}``
    :param dest_paths: must contain 'train', 'test' and 'gold' paths; keys
        starting with '_' are skipped.
    :param n_train: training instances per class (default 800).
    """
    fw = {}
    try:
        for ftype, path in dest_paths.items():
            if not ftype.startswith('_'):
                fw[ftype] = open(path, 'w')
        for eid in sorted(str_feature_vectors):
            ## sort by udocID
            vectors = sorted(str_feature_vectors[eid], key=lambda x: x[0])
            train, test = vectors[:n_train], vectors[n_train:]
            # One newline per line: an empty slice writes nothing rather than a
            # blank line (which SVM readers may reject as malformed input).
            fw['train'].write(''.join(line + '\n' for _, line in train))
            fw['test'].write(''.join(line + '\n' for _, line in test))
            fw['gold'].write(''.join(str(eid) + '\n' for _ in test))
    finally:
        # close every file that was opened, even if writing failed
        for f in fw.values():
            f.close()
def run():
    """Convert the setting(s) named in ``setting_id_str`` into train/test/gold files.

    Skips the conversion when all output files already exist, unless
    ``config.overwrite`` is set. Exits the process with status -1 when the
    setting IDs cannot be parsed or no feature vectors are produced.
    """
    global co_feature_setting
    # collection pointer of feature settings
    co_feature_setting = db[config.co_feature_setting_name]
    # sorted src_setting_id(s), or False on malformed/empty input
    src_setting_ids = parse_src_setting_ids()
    if not src_setting_ids:
        # everything below needs a non-empty list of IDs
        logging.error('no valid setting ID in ' + color.render(repr(setting_id_str), 'y'))
        sys.exit(-1)
    dest_setting_id = obtain_dest_setting_id(src_setting_ids)
    dest_paths = get_dest_paths(dest_setting_id)

    ## logging
    logging.debug('src_setting_ids: ' + color.render(','.join(src_setting_ids), 'y'))
    for ftype, fn in sorted(dest_paths.items()):
        logging.debug(ftype + ': ' + color.render(fn, 'g'))
    logging.info('dest_setting_id: ' + color.render(dest_setting_id, 'y'))

    # files all exist already and overwriting was not requested
    if is_dest_files_exist(dest_paths) and not config.overwrite:
        logging.info('all files are existed')
        return True

    logging.info('generate feature vectors')
    feature_vectors = generate_feature_vectors(src_setting_ids)
    if not feature_vectors:
        sys.exit(-1)
    logging.info('transform to svm format')
    str_feature_vectors = tranform_to_svm_format(feature_vectors)
    logging.info('generate train/test files')
    generate_train_test_files(str_feature_vectors, dest_paths)
    return True
if __name__ == '__main__':
    import getopt

    add_opts = [
        ('setting_id', ['<setting_id>: specify setting ID(s) (e.g., 537086fcd4388c7e81676914, or 537086fcd4388c7e81676914,537c6c90d4388c0e27069e7b)',
                        ' which can be retrieved from the mongo collection features.settings' ]),
    ]

    # gnu_getopt accepts options before or after the setting ID(s); plain
    # getopt stops at the first positional argument, and the old code then
    # mistook a leading option such as "-v" for the setting ID.
    try:
        opts, args = getopt.gnu_getopt(sys.argv[1:], 'hvo', ['help', 'verbose', 'overwrite'])
    except getopt.GetoptError:
        config.help('toSVM', addon=add_opts, args=['<setting_id>'], exit=2)

    ## read options
    for opt, arg in opts:
        if opt in ('-h', '--help'): config.help('toSVM', args=['setting_id'], addon=add_opts)
        elif opt in ('-v', '--verbose'): config.verbose = True
        elif opt in ('-o', '--overwrite'): config.overwrite = True

    # The positional arguments are the setting ID(s); several unquoted IDs are
    # re-joined with spaces, which parse_src_setting_ids() treats as separators.
    if not args:
        config.help('toSVM', addon=add_opts, args=['<setting_id>'], exit=2)
    setting_id_str = ' '.join(args).strip()

    ## set log level
    loglevel = logging.DEBUG if config.verbose else logging.INFO
    logging.basicConfig(format='[%(levelname)s] %(message)s', level=loglevel)
    run()