diff --git a/nlp/app/routes.py b/nlp/app/routes.py index 3785e68..bc5e373 100644 --- a/nlp/app/routes.py +++ b/nlp/app/routes.py @@ -22,7 +22,7 @@ def index(): input_field = form.input_field.data model_choice = form.model_choice.data prediction = None - proba = None # Initialize proba variable + proba = None if model_choice == 'bnb': model = bnb @@ -36,24 +36,25 @@ def index(): else: prediction = "LOSS" proba = 1 - positive_proba - prediction = None - proba = None + + elif model_choice == 'lr': + model = lr text = process_text(input_field) text = lr_vectorizer.transform([text]) probas = lr.predict_proba(text) positive_proba = probas[:, 1] - prediction = None + if positive_proba > 0.5: prediction == "WIN" proba = positive_proba else: prediction == "LOSS" proba = 1 - positive_proba - prediction = None - proba = None + elif model_choice == 'cnn': + # For CNN, assuming preprocessing is handled differently or is built-in model_path = '/Users/jackiecollopy/Downloads/project-reddit/nlp/cnn_model.h5' model = load_model(model_path, compile=False) @@ -61,7 +62,6 @@ def index(): text = cnn_process(text) probas = model.predict(text) positive_proba = probas[:, 1] - prediction = None if positive_proba > 0.5: prediction == "WIN"