Skip to content

Commit

Permalink
probas
Browse files Browse the repository at this point in the history
  • Loading branch information
jcollopy-tulane committed Apr 30, 2024
1 parent e366dba commit 39c7fe2
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 15 deletions.
8 changes: 4 additions & 4 deletions nlp/app/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,14 @@ def index():
model = load_model(model_path, compile=False)
text = basic_process(input_field)
text = cnn_process(text)
preds = model.predict(text)
probas = (preds > 0.5).astype(int)
probas = model.predict(text)
preds = (probas > 0.5).astype(int)
if preds == 1:
prediction = "WIN"
proba = probas[1]
proba = probas
else:
prediction = "LOSS"
proba = proba[0]
proba = probas
elif model_choice == 'bert':
tokenizer = BertTokenizerFast.from_pretrained('prajjwal1/bert-mini')
model = AutoModelForSequenceClassification.from_pretrained('/Users/jackiecollopy/Downloads/project-reddit/notebooks/bert.pth')
Expand Down
43 changes: 32 additions & 11 deletions notebooks/Experiment-CNN-1.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 27,
"id": "8cfa7100-4fed-4538-b8ce-213a5efb4d1f",
"metadata": {},
"outputs": [
Expand All @@ -386,15 +386,15 @@
],
"source": [
"predictions = best_model_cnn.predict(X_val)\n",
"predictions = (predictions > 0.5).astype(int) \n",
"preds = (predictions > 0.5).astype(int) \n",
"\n",
"f1 = f1_score(y_val, predictions)\n",
"f1 = f1_score(y_val, preds)\n",
"print(\"F1 Score:\", round(f1,3))\n",
"# Calculate Precision\n",
"precision = precision_score(y_val, predictions)\n",
"precision = precision_score(y_val, preds)\n",
"print(\"Precision:\", round(precision, 3))\n",
"# Calculate recall\n",
"recall = recall_score(y_val, predictions)\n",
"recall = recall_score(y_val, preds)\n",
"print(\"Recall:\", round(recall, 3))"
]
},
Expand Down Expand Up @@ -459,7 +459,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 24,
"id": "9612ac53-cc11-4ee6-acff-d33cabf98e73",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -490,7 +490,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 25,
"id": "ff0623cb-4337-4e8f-8f38-6f2a3a215edf",
"metadata": {},
"outputs": [
Expand All @@ -506,7 +506,7 @@
" [0]])"
]
},
"execution_count": 22,
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -561,11 +561,32 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 28,
"id": "e96b0c61-e235-4fb2-88af-6149c17184a6",
"metadata": {},
"outputs": [],
"source": []
"outputs": [
{
"ename": "NameError",
"evalue": "name 'cnn_model' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[28], line 6\u001b[0m\n\u001b[1;32m 4\u001b[0m results_df[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mActual_Label\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m y_test\n\u001b[1;32m 5\u001b[0m results_df[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPredicted_Label\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m preds\n\u001b[0;32m----> 6\u001b[0m probs \u001b[38;5;241m=\u001b[39m \u001b[43mcnn_model\u001b[49m\u001b[38;5;241m.\u001b[39mpredict(X_test)\n\u001b[1;32m 7\u001b[0m results_df[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mpredict_proba\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m probs\n\u001b[1;32m 8\u001b[0m results_df\u001b[38;5;241m.\u001b[39mhead()\n",
"\u001b[0;31mNameError\u001b[0m: name 'cnn_model' is not defined"
]
}
],
"source": [
"## Getting a dataframe\n",
"\n",
"results_df = test_df.copy()\n",
"results_df[\"Actual_Label\"] = y_test\n",
"results_df[\"Predicted_Label\"] = preds\n",
"probs = cnn_model.predict(X_test)\n",
"results_df['predict_proba'] = probs\n",
"results_df.head()"
]
},
{
"cell_type": "code",
Expand Down

0 comments on commit 39c7fe2

Please sign in to comment.