-
Notifications
You must be signed in to change notification settings - Fork 0
/
text_gen_api.py
91 lines (76 loc) · 2.99 KB
/
text_gen_api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#importing libraries
from flask import Flask
from flask_restplus import Api, Resource, fields
# from flask_pymongo import PyMongo
from pymongo import MongoClient
import fastai
from fastai import *
import pathlib
import pickle
# HACK: alias PosixPath to WindowsPath so that fastai learners pickled on a
# POSIX machine can be unpickled on Windows (load_learner otherwise raises
# "cannot instantiate PosixPath"). `temp` keeps the original class so it
# could be restored later. Presumably this script runs on Windows — confirm.
temp = pathlib.PosixPath
pathlib.PosixPath = pathlib.WindowsPath
from fastai.text.all import *
# CONFIGURING MONGODB
# Connection to the MongoDB Server
# SECURITY NOTE(review): username and password are hard-coded in this URI.
# Move them to environment variables / a config file and rotate the password.
mongoClient = MongoClient ('mongodb+srv://hamza:[email protected]/myFirstDatabase?retryWrites=true&w=majority')
# Connection to the database
db = mongoClient.posts
#Collection
collection = db.record
# reading all the generated text
# NOTE(review): pymongo cursors are single-use — this module-level cursor is
# exhausted after its first full iteration, so code that re-iterates it later
# will see no documents. Request handlers should open a fresh cursor instead.
generated_text = collection.find()
app = Flask(__name__)
api = Api(app,version='1.0', title='Text Generation',
description='A simple Topic Detection and Text Generation API')
# NOTE: `api` is rebound to the namespace here; all later @api.* decorators
# refer to the TEXTGEN namespace, not the top-level Api object.
api = api.namespace('TEXTGEN', description='Topic Detection and Text Generation')
# Swagger request model: the POST payload is {"Input": "<some text>"}.
a_text = api.model('USER_INPUT', {'Input' : fields.String('Enter the text')})
@api.route('/generated_text')
class Text_Gen(Resource):
    """REST resource: list stored generations (GET) and create one (POST)."""

    def get(self):
        '''Show all the Generated Text in Database'''
        # Open a fresh cursor per request — pymongo cursors are single-use,
        # so the module-level `generated_text` cursor cannot be reused here.
        db1 = mongoClient.posts
        collection1 = db1.record
        generated_text1 = collection1.find()
        # Project only the two public fields (drops Mongo's ObjectId `_id`,
        # which is not JSON-serializable).
        return [
            {"Category": document['Category'],
             "Generated_Text": document['Generated_Text']}
            for document in generated_text1
        ]

    @api.expect(a_text)
    def post(self):
        '''Detect Topic and Generate Text'''
        user_input = api.payload
        ml_result = ml_model(user_input['Input'])
        # Persist the generated text and its category.
        # FIX: Collection.insert() was deprecated in PyMongo 3 and removed in
        # PyMongo 4 — insert_one() is the supported single-document API.
        collection.insert_one(ml_result)
        # Build the response from explicit keys: insert_one() mutates
        # ml_result in place by adding an `_id` ObjectId, which Flask could
        # not serialize if the whole dict were returned.
        return ({"Category": ml_result['Category'],
                 "Generated_Text": ml_result['Generated_Text']}), 201
def ml_model(user_text):
    """Classify *user_text* into a topic and generate fresh text for it.

    Returns a dict ``{"Category": <str>, "Generated_Text": <str>}`` suitable
    for insertion into the Mongo collection.
    """
    # Loading the fastai classification learner from disk.
    # (Param renamed from `input`, which shadowed the builtin.)
    classifier = load_learner('reddit_classifier.pkl')
    # Predict the topic; fastai returns (category, index, probabilities).
    cat, _, probs = classifier.predict(user_text)
    cat = cat.capitalize()
    print(cat)
    # First generation attempt.
    clean_text = generate_text(user_text)
    # FIX: the original iterated the module-level cursor `generated_text`,
    # which is exhausted after its first traversal, so duplicates were only
    # detected on the very first call — and even then it retried just once.
    # Query the collection directly and retry (bounded, since generation is
    # stochastic) until the text is not already stored.
    retries = 0
    while retries < 5 and collection.find_one({"Generated_Text": clean_text}) is not None:
        clean_text = generate_text(user_text)
        retries += 1
    prediction = {"Category": cat, "Generated_Text": clean_text}
    return prediction
def generate_text(text_input):
    """Generate a text continuation of *text_input* with a pickled GPT-2 model.

    Returns the decoded text with newlines and double quotes stripped.
    """
    # FIX: load the tokenizer and model once and memoize them on the function
    # object — the original re-unpickled both large artifacts on every call
    # and leaked both file handles (open() without close()).
    if not hasattr(generate_text, '_cache'):
        # NOTE(review): pickle.load assumes these are trusted local
        # artifacts; never unpickle data from an untrusted source.
        with open('gpt2_tokenizer.pkl', 'rb') as f:
            tokenizer = pickle.load(f)
        with open('gpt2_language.pkl', 'rb') as f:
            model = pickle.load(f)
        generate_text._cache = (tokenizer, model)
    tokenizer, model = generate_text._cache
    # Encode the prompt as a PyTorch tensor of token ids.
    inputs = tokenizer.encode(text_input, return_tensors='pt')
    # Sample up to 100 tokens (the original comment claimed 200; the code
    # has always passed max_length=100). do_sample=True => non-deterministic.
    outputs = model.generate(inputs, max_length=100, do_sample=True)
    # Decode and strip newlines / double quotes from the generated text.
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    text = text.replace("\n", "")
    text = text.replace("\"", "")
    return text
# Run the Flask development server when executed directly.
# NOTE(review): debug=True enables the interactive Werkzeug debugger and
# auto-reload — never deploy with it enabled.
if __name__ == '__main__':
    app.run(debug=True)