-
Notifications
You must be signed in to change notification settings - Fork 1
/
fast_text.py
52 lines (44 loc) · 1.46 KB
/
fast_text.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import subprocess
from pymongo import MongoClient
from pathlib import Path
import numpy as np
# Module-level MongoDB handles: a default-constructed client, the
# 'crawled_news' database, and its 'crawled_news' collection, which
# train_fast_text() iterates over.
client = MongoClient()
db = client.get_database('crawled_news')
collection = db.get_collection('crawled_news')
def train_fast_text(input_path="textfiles/tempfile_input.txt",
                    output_path="model/fast_text",
                    fasttext_bin="./fasttext"):
    """Dump all crawled articles to a text file and train a fastText skipgram model.

    Every document's ``content`` field from the module-level ``collection`` is
    written to ``input_path``, each followed by a newline. The fastText
    executable is then run as
    ``<fasttext_bin> skipgram -input <input_path> -output <output_path>``.

    Args:
        input_path: Corpus file to (over)write before training.
        output_path: Value passed to fastText's ``-output`` flag.
        fasttext_bin: Path to the fastText executable.

    Raises:
        KeyError: If a document has no ``content`` field.
        subprocess.CalledProcessError: If fastText exits with a non-zero status.
    """
    # Make sure the target directories exist; on a fresh checkout open() and
    # fastText would otherwise fail because they are missing.
    Path(input_path).parent.mkdir(parents=True, exist_ok=True)
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    # Explicit UTF-8: crawled text is arbitrary Unicode and the platform's
    # default encoding may not be able to represent all of it.
    with open(input_path, "w", encoding="utf-8") as f:
        # Project only the field that is used instead of fetching whole documents.
        for doc in collection.find({}, {"content": 1}):
            f.write(doc["content"])
            f.write("\n")

    # check=True so a failed training run raises instead of passing silently.
    subprocess.run(
        [fasttext_bin, "skipgram", "-input", input_path, "-output", output_path],
        check=True,
    )
def query_fast_text(text, dim=100, script="fast_text_vector.sh"):
    """Return the element-wise mean of the word vectors printed for *text*.

    Runs ``sh <script> <text>`` and parses its stdout. Each line is expected
    to look like ``<word> <v1> ... <v_dim>``. This is presumably fastText
    ``print-word-vectors`` output; confirm against the script.

    Parsing is best-effort:
    - A component that does not parse as a float becomes 0.0, and
      ``ValueError unparseable`` is printed.
    - A line that does not carry exactly ``dim`` components contributes an
      all-zero vector to the average.

    Args:
        text: Text to embed, passed to the script as a single argument.
        dim: Expected vector dimensionality.
        script: Shell script that prints the word vectors.

    Returns:
        A numpy array of shape ``(dim,)``. If the script prints no lines, an
        all-zero vector is returned.

    Raises:
        subprocess.CalledProcessError: If the script exits with a non-zero status.
    """
    # A list argv (no shell=True) keeps *text* out of shell parsing at this
    # call. How the script itself uses "$1" is not visible here.
    stdout = subprocess.check_output(["sh", script, text])

    vectors = []
    # splitlines() also keeps a final line that has no trailing newline,
    # which the previous split(b'\n')[:-1] silently dropped.
    for line in stdout.splitlines():
        row = []
        # The first token is the word itself; the rest are vector components.
        for token in line.split()[1:]:
            try:
                row.append(float(token))
            except ValueError:
                print("ValueError unparseable")
                row.append(0.0)
        if len(row) != dim:
            row = [0.0] * dim
        vectors.append(row)

    if not vectors:
        # Averaging zero rows would divide by zero and return a scalar NaN.
        return np.zeros(dim)
    return np.mean(np.array(vectors), axis=0)
if __name__ == '__main__':
    # Smoke test: embed a deliberately noisy string. Call train_fast_text()
    # first if the model has not been built yet.
    # Raw string: the backslashes are presumably meant literally. In a plain
    # literal "\a" becomes a BEL control character and "\ " is an invalid
    # escape sequence (SyntaxWarning on Python 3.12+).
    v = query_fast_text(r"minister dying now \asd k alsdlkalsndkas\ ds")
    print(v)