diff --git a/nlp/app/routes.py b/nlp/app/routes.py index 54ca5fa..5259d8e 100644 --- a/nlp/app/routes.py +++ b/nlp/app/routes.py @@ -14,8 +14,6 @@ cnn = pickle.load(open(lr_path, 'rb')) -labels = ['loss', 'win'] - @app.route('/', methods=['GET', 'POST']) @app.route('/index', methods=['GET', 'POST']) def index(): @@ -28,27 +26,40 @@ def index(): model = bnb text = process_text(input_field) text = bnb_vectorizer.transform([text]) - pred_labels = bnb.predict(text) probas = bnb.predict_proba(text) - pred = pred_labels[0] - proba = probas[0, pred] + positive_proba = probas[:, 1] + if positive_proba > 0.5: + prediction == "WIN": + proba = positive_proba + else: + prediction == "LOSS": + proba = 1 - positive_proba elif model_choice == 'lr': model = lr text = process_text(input_field) text = lr_vectorizer.transform([text]) - pred_labels = lr.predict(text) probas = lr.predict_proba(text) - pred = pred_labels[0] - proba = probas[0, pred] + positive_proba = probas[:, 1] + if positive_proba > 0.5: + prediction == "WIN": + proba = positive_proba + else: + prediction == "LOSS": + proba = 1 - positive_proba elif model_choice == 'cnn': # For CNN, assuming preprocessing is handled differently or is built-in model_path = '/Users/jackiecollopy/Downloads/project-reddit/nlp/cnn_model.h5' model = load_model(model_path, compile=False) text = basic_process(input_field) text = cnn_process(text) - predictions_proba = model.predict(text) - pred = np.argmax(predictions_proba) - proba = predictions_proba[0] + probas = model.predict(text) + positive_proba = probas[:, 1] + if positive_proba > 0.5: + prediction == "WIN": + proba = positive_proba + else: + prediction == "LOSS": + proba = 1 - positive_proba elif model_choice == 'bert': tokenizer = BertTokenizerFast.from_pretrained('prajjwal1/bert-mini') model = AutoModelForSequenceClassification.from_pretrained('/Users/jackiecollopy/Downloads/project-reddit/notebooks/bert.pth') @@ -58,5 +69,5 @@ def index(): predicted_class = torch.argmax(outputs.logits).item() return render_template('myform.html', title='', form=form, - prediction=labels[pred], confidence='%.2f' % (proba * 100)) + prediction=prediction, confidence='%.2f' % (proba * 100)) return render_template('myform.html', title='', form=form, prediction=None, confidence=None)