Commit

feat: Add histogram to evaluate classification result
TASK: IL-347
MerlinKallenbornTNG authored and FlorianSchepersAA committed Apr 8, 2024
1 parent 3ec5e2f commit cd85b21
Showing 1 changed file with 126 additions and 10 deletions.
136 changes: 126 additions & 10 deletions src/examples/user_journey.ipynb
@@ -8,7 +8,10 @@
"source": [
"import json\n",
"\n",
"import numpy\n",
"import pandas\n",
"from dotenv import load_dotenv\n",
"from matplotlib import pyplot\n",
"\n",
"from intelligence_layer.core import InMemoryTracer, LuminousControlModel, TextChunk\n",
"from intelligence_layer.evaluation import (\n",
@@ -442,7 +445,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"jupyter": {
"is_executing": true
}
},
"outputs": [],
"source": [
"runner_prompt_adjusted = Runner(\n",
@@ -461,7 +468,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"jupyter": {
"is_executing": true
}
},
"outputs": [],
"source": [
"aggregation_overview_prompt_adjusted"
@@ -479,7 +490,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"jupyter": {
"is_executing": true
}
},
"outputs": [],
"source": [
"classify_with_extended = PromptBasedClassify(\n",
@@ -497,7 +512,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"jupyter": {
"is_executing": true
}
},
"outputs": [],
"source": [
"runner_with_extended = Runner(\n",
@@ -516,7 +535,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"jupyter": {
"is_executing": true
}
},
"outputs": [],
"source": [
"aggregation_overview_with_extended"
@@ -532,17 +555,21 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"jupyter": {
"is_executing": true
}
},
"outputs": [],
"source": [
"lineages = [\n",
"incorrect_predictions_lineages = [\n",
" lineage\n",
" for lineage in evaluator.evaluation_lineages(eval_overview_prompt_adjusted.id)\n",
" if not isinstance(lineage.evaluation.result, FailedExampleEvaluation)\n",
" and not lineage.evaluation.result.correct\n",
"]\n",
"\n",
"df = evaluation_lineages_to_pandas(lineages)\n",
"df = evaluation_lineages_to_pandas(incorrect_predictions_lineages)\n",
"df[\"input\"] = [i.chunk for i in df[\"input\"]]\n",
"df[\"predicted\"] = [r.predicted for r in df[\"result\"]]\n",
"df.reset_index()[[\"example_id\", \"input\", \"expected_output\", \"predicted\"]]"
@@ -552,7 +579,96 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"As you can see there are plenty of options on how to further enhance the accuracy of our classify task. Notice, for instance, that so far we did not tell our classification task what each class means.\n",
"So let's analyze this in more depth by visualizing how often each label was expected or predicted in a histogram. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"jupyter": {
"is_executing": true
}
},
"outputs": [],
"source": [
"by_labels = aggregation_overview_with_extended.statistics.by_label\n",
"\n",
"expected_counts_by_labels = {\n",
" label: by_labels[label].expected_count for label in by_labels.keys()\n",
"}\n",
"predicted_counts_by_labels = {\n",
" label: by_labels[label].predicted_count for label in by_labels.keys()\n",
"}\n",
"\n",
"x_axis = numpy.arange(len(expected_counts_by_labels.keys()))\n",
"pyplot.bar(\n",
" x_axis - 0.2, expected_counts_by_labels.values(), width=0.4, label=\"expected counts\"\n",
")\n",
"pyplot.bar(\n",
" x_axis + 0.2,\n",
" predicted_counts_by_labels.values(),\n",
" width=0.4,\n",
" label=\"predicted counts\",\n",
")\n",
"pyplot.ylabel(\"Classification count\")\n",
"pyplot.xlabel(\"Labels\")\n",
"pyplot.legend()\n",
"_ = pyplot.xticks(x_axis, by_labels.keys(), rotation=45)"
]
},
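The histogram above compares expected and predicted counts per label. The same comparison can also be summarized numerically. The following sketch is not part of the committed notebook; it uses dummy counts standing in for the `expected_count` and `predicted_count` fields read from `statistics.by_label` above:

```python
# Sketch (not from the commit): rank labels by prediction bias, i.e.
# predicted count minus expected count. A positive delta means the label
# is overpredicted, a negative delta means it is underpredicted.
# The counts below are dummy values for illustration.
expected = {"Customer": 10, "Infrastructure": 8, "CEO Office": 5, "Product": 7}
predicted = {"Customer": 16, "Infrastructure": 5, "CEO Office": 3, "Product": 6}

bias = {label: predicted[label] - expected[label] for label in expected}
for label, delta in sorted(bias.items(), key=lambda item: item[1], reverse=True):
    print(f"{label:15s} {delta:+d}")
```

Sorting by the delta makes the most over- and underpredicted labels immediately visible, which is the same signal the grouped bars convey graphically.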
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we can see, our task tends to overpredict the `Customer` label while it underpredicts `Infrastructure`, `CEO Office`, and `Product`.\n",
"\n",
"We can get even more insight into the classification behavior of our task by analyzing its confusion matrix. The off-diagonal cells of the confusion matrix show the explicit mislabelings for each class. This helps us see whether a specific class is frequently mislabeled as a particular other class. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"confusion_matrix = aggregation_overview_with_extended.statistics.confusion_matrix\n",
"\n",
"data = []\n",
"for (predicted_label, expected_label), count in confusion_matrix.items():\n",
" data.append(\n",
" {\n",
" \"Expected Label\": expected_label,\n",
" \"Predicted Label\": predicted_label,\n",
" \"Count\": count,\n",
" }\n",
" )\n",
"\n",
"df = pandas.DataFrame(data)\n",
"df = df.pivot(index=\"Expected Label\", columns=\"Predicted Label\", values=\"Count\")\n",
"df = df.fillna(0)\n",
"df = df.reindex(\n",
" index=labels, columns=labels, fill_value=0\n",
") # this will add any labels that were neither expected nor predicted\n",
"df = df.style.background_gradient(cmap=\"grey\", vmin=df.min().min(), vmax=df.max().max())\n",
"df = df.format(\"{:.0f}\")\n",
"df"
]
},
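Raw counts in the confusion matrix can be hard to compare when classes have different sizes. A common follow-up, not part of this commit, is to row-normalize the matrix so each cell shows the fraction of a class's examples that received each prediction. The sketch below uses dummy counts mimicking the `{(predicted_label, expected_label): count}` shape consumed in the cell above:

```python
import pandas

# Sketch (not from the commit): row-normalize a confusion matrix so each
# row sums to 1.0 and off-diagonal cells read as mislabel rates.
# Dummy counts, keyed as (predicted_label, expected_label).
confusion_matrix = {
    ("Customer", "Customer"): 8,
    ("Customer", "Product"): 3,
    ("Product", "Product"): 4,
    ("Customer", "Infrastructure"): 2,
    ("Infrastructure", "Infrastructure"): 6,
}

rows = [
    {"Expected Label": expected, "Predicted Label": predicted, "Count": count}
    for (predicted, expected), count in confusion_matrix.items()
]
df = (
    pandas.DataFrame(rows)
    .pivot(index="Expected Label", columns="Predicted Label", values="Count")
    .fillna(0)
)
rates = df.div(df.sum(axis=1), axis=0)  # each row now sums to 1.0
print(rates.round(2))
```

With this view, a cell like `rates.loc["Product", "Customer"]` directly answers "what fraction of `Product` examples were mislabeled as `Customer`?", independent of class size.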
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In our case, we can see that the bias towards the `Customer` class does not come at the cost of one particular other class but is caused by a more general mislabeling. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As you can see, there is plenty of room for further improvement of our classification task. \n",
"\n",
"Notice, for instance, that so far we did not tell our classification task what each class means.\n",
"\n",
"The model had to 'guess' what we mean by each class purely from the given labels. To tackle this issue, you could use the `PromptBasedClassifyWithDefinitions` task. This task allows you to also provide a short description for each class.\n",
"\n",
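The diff is truncated here, so the `PromptBasedClassifyWithDefinitions` usage itself is not shown. As a purely illustrative sketch of the idea of pairing each label with a definition, assuming a simple label-plus-definition structure (the class, field names, labels, and definitions below are hypothetical, not taken from the intelligence_layer API):

```python
# Hypothetical sketch: attach a short definition to each label, as a
# PromptBasedClassifyWithDefinitions-style task would consume them.
# The structure, labels, and definitions are illustrative only; consult
# the intelligence_layer documentation for the real constructor signature.
from dataclasses import dataclass


@dataclass(frozen=True)
class LabelWithDefinition:
    name: str
    definition: str

    def render(self) -> str:
        return f"{self.name}: {self.definition}"


labels = [
    LabelWithDefinition("Customer", "Inquiries from or about customers."),
    LabelWithDefinition("Infrastructure", "Servers, networking, and internal tooling."),
]

# Render the definitions as they could appear in a classification prompt,
# giving the model more to go on than the bare label names.
prompt_block = "\n".join(label.render() for label in labels)
print(prompt_block)
```

Supplying such definitions removes the need for the model to guess what each class means from its name alone.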
@@ -576,7 +692,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
"version": "3.11.7"
}
},
"nbformat": 4,