app.py
from flask import Flask, render_template, request
import torch
import re
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
from collections import OrderedDict


class GPT2PPL:
    def __init__(self, model_id="gpt2"):
        self.model_id = model_id
        self.model = GPT2LMHeadModel.from_pretrained(model_id)
        self.tokenizer = GPT2TokenizerFast.from_pretrained(model_id)
        self.max_length = self.model.config.n_positions  # model context window (1024 for gpt2)
        self.stride = 512  # sliding-window stride used by getPPL for long inputs

    def getResults(self, threshold):
        # `threshold` is the average per-line perplexity; lower values suggest
        # AI-generated text
        if threshold < 60:
            label = 0
            return "The Text is generated by AI.", label
        elif threshold < 80:
            label = 0
            return "The Text most probably contains parts which are generated by AI. (requires more text for better judgment)", label
        else:
            label = 1
            return "The Text is written by a Human.", label

    def __call__(self, sentence):
        """
        Takes a block of text, computes the perplexity of the whole text, then
        splits it into sentences and computes the perplexity of each one.
        "Perplexity per line" is the average sentence perplexity and
        "Burstiness" is the maximum sentence perplexity.
        """
        results = OrderedDict()

        # count the alphanumeric characters in the text
        total_valid_char = re.findall("[a-zA-Z0-9]+", sentence)
        total_valid_char = sum([len(x) for x in total_valid_char])
        if total_valid_char < 100:
            return {"status": "Please input more text (minimum 100 characters)"}, "Please input more text (minimum 100 characters)"

        # split on sentence-ending punctuation followed by a space or opening
        # bracket, and on newlines, then drop empty fragments
        lines = re.split(r'(?<=[.?!][ \[\(])|(?<=\n)\s*', sentence)
        lines = list(filter(lambda x: (x is not None) and (len(x) > 0), lines))

        # perplexity of the full text
        ppl = self.getPPL(sentence)
        results["Perplexity"] = ppl

        offset = ""
        Perplexity_per_line = []
        for line in lines:
            if re.search("[a-zA-Z0-9]+", line) is None:
                continue
            if len(offset) > 0:
                line = offset + line
                offset = ""
            # remove a leading newline or space, if present
            if line[0] == "\n" or line[0] == " ":
                line = line[1:]
            # remove a trailing newline or space; if the sentence ends with an
            # opening bracket, carry it over to the start of the next sentence
            if line[-1] == "\n" or line[-1] == " ":
                line = line[:-1]
            elif line[-1] == "[" or line[-1] == "(":
                offset = line[-1]
                line = line[:-1]
            ppl = self.getPPL(line)
            Perplexity_per_line.append(ppl)

        results["Perplexity per line"] = sum(Perplexity_per_line) / len(Perplexity_per_line)
        results["Burstiness"] = max(Perplexity_per_line)

        out, label = self.getResults(results["Perplexity per line"])
        results["label"] = label
        return results, out

    def getPPL(self, sentence):
        encodings = self.tokenizer(sentence, return_tensors="pt")
        seq_len = encodings.input_ids.size(1)

        nlls = []
        prev_end_loc = 0
        # sliding-window evaluation: each window covers up to max_length tokens,
        # but only tokens not already scored by the previous window contribute
        for begin_loc in range(0, seq_len, self.stride):
            end_loc = min(begin_loc + self.max_length, seq_len)
            trg_len = end_loc - prev_end_loc  # tokens newly covered by this window
            input_ids = encodings.input_ids[:, begin_loc:end_loc]
            target_ids = input_ids.clone()
            target_ids[:, :-trg_len] = -100  # ignore tokens that were already scored

            with torch.no_grad():
                outputs = self.model(input_ids, labels=target_ids)
                # outputs.loss is the mean negative log-likelihood over the unmasked
                # targets; scale it by the window's target length to approximate the sum
                neg_log_likelihood = outputs.loss * trg_len

            nlls.append(neg_log_likelihood)
            prev_end_loc = end_loc
            if end_loc == seq_len:
                break

        # perplexity = exp(total negative log-likelihood / number of tokens)
        ppl = int(torch.exp(torch.stack(nlls).sum() / end_loc))
        return ppl
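
# A minimal sketch of how GPT2PPL could be used on its own, outside the Flask
# route below. The example text is a placeholder and must contain at least 100
# alphanumeric characters (per the check in __call__); it is kept as comments so
# that importing this module does not trigger an extra model run:
#
#     detector = GPT2PPL()
#     results, verdict = detector("Paste at least 100 characters of text here ...")
#     print(results["Perplexity"], results["Perplexity per line"],
#           results["Burstiness"], verdict)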

# Initialize the Flask app
app = Flask(__name__)

# Initialize the GPT2PPL model
model = GPT2PPL()


@app.route('/', methods=['GET', 'POST'])
def index():
    if request.method == 'POST':
        sentence = request.form['sentence']
        results, output = model(sentence)
        return render_template('index.html', sentence=sentence, results=results, output=output)
    return render_template('index.html')


if __name__ == '__main__':
    app.run(debug=True)
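
# To exercise the endpoint once the server is running, something like the snippet
# below should work (this assumes the `requests` package is installed; Flask's
# development server listens on http://127.0.0.1:5000 by default). The form field
# name "sentence" matches what index() reads from request.form:
#
#     import requests
#     resp = requests.post("http://127.0.0.1:5000/",
#                          data={"sentence": "At least 100 characters of text ..."})
#     print(resp.status_code)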