From bfd7964d5b04ab1a49aad9be6ae0453850f5b593 Mon Sep 17 00:00:00 2001 From: jcollopy-tulane Date: Sun, 28 Apr 2024 20:53:52 -0500 Subject: [PATCH] Fixing Probas --- nlp/app/routes.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/nlp/app/routes.py b/nlp/app/routes.py index 923359c..984a593 100644 --- a/nlp/app/routes.py +++ b/nlp/app/routes.py @@ -25,14 +25,18 @@ def index(): model = bnb text = process_text(input_field) text = bnb_vectorizer.transform([text]) - pred = bnb.predict(text) - proba = bnb.predict_proba(text)[:, 1] + pred_labels = bnb.predict(text) + probas = bnb.predict_proba(text) + pred = pred_labels[0] + proba = probas[0, pred] elif model_choice == 'lr': model = lr text = process_text(input_field) text = lr_vectorizer.transform([text]) - pred = lr.predict(text) - proba = lr.predict_proba(text)[:, 1] + pred_labels = lr.predict(text) + probas = lr.predict_proba(text) + pred = pred_labels[0] + proba = probas[0, pred] elif model_choice == 'cnn': # For CNN, assuming preprocessing is handled differently or is built-in model = cnn