Skip to content

Commit

Permalink
Message
Browse files Browse the repository at this point in the history
  • Loading branch information
jcollopy-tulane committed Apr 29, 2024
1 parent 54d16ba commit a4a329b
Showing 1 changed file with 23 additions and 12 deletions.
35 changes: 23 additions & 12 deletions nlp/app/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
cnn = pickle.load(open(lr_path, 'rb'))


labels = ['loss', 'win']

@app.route('/', methods=['GET', 'POST'])
@app.route('/index', methods=['GET', 'POST'])
def index():
Expand All @@ -28,27 +26,40 @@ def index():
model = bnb
text = process_text(input_field)
text = bnb_vectorizer.transform([text])
pred_labels = bnb.predict(text)
probas = bnb.predict_proba(text)
pred = pred_labels[0]
proba = probas[0, pred]
positive_proba = probas[:, 1]
if positive_proba > 0.5:
prediction == "WIN":
proba = positive_proba
else:
prediction == "LOSS":
proba = 1 - positive_proba
elif model_choice == 'lr':
model = lr
text = process_text(input_field)
text = lr_vectorizer.transform([text])
pred_labels = lr.predict(text)
probas = lr.predict_proba(text)
pred = pred_labels[0]
proba = probas[0, pred]
positive_proba = probas[:, 1]
if positive_proba > 0.5:
prediction == "WIN":
proba = positive_proba
else:
prediction == "LOSS":
proba = 1 - positive_proba
elif model_choice == 'cnn':
# For CNN, assuming preprocessing is handled differently or is built-in
model_path = '/Users/jackiecollopy/Downloads/project-reddit/nlp/cnn_model.h5'
model = load_model(model_path, compile=False)
text = basic_process(input_field)
text = cnn_process(text)
predictions_proba = model.predict(text)
pred = np.argmax(predictions_proba)
proba = predictions_proba[0]
probas = model.predict(text)
positive_proba = probas[:, 1]
if positive_proba > 0.5:
prediction == "WIN":
proba = positive_proba
else:
prediction == "LOSS":
proba = 1 - positive_proba
elif model_choice == 'bert':
tokenizer = BertTokenizerFast.from_pretrained('prajjwal1/bert-mini')
model = AutoModelForSequenceClassification.from_pretrained('/Users/jackiecollopy/Downloads/project-reddit/notebooks/bert.pth')
Expand All @@ -58,5 +69,5 @@ def index():
predicted_class = torch.argmax(outputs.logits).item()

return render_template('myform.html', title='', form=form,
prediction=labels[pred], confidence='%.2f' % (proba * 100))
prediction=prediction, confidence='%.2f' % (proba * 100))
return render_template('myform.html', title='', form=form, prediction=None, confidence=None)

0 comments on commit a4a329b

Please sign in to comment.