diff --git a/src/examples/user_journey.ipynb b/src/examples/user_journey.ipynb index 7ce24220d..e831c18fe 100644 --- a/src/examples/user_journey.ipynb +++ b/src/examples/user_journey.ipynb @@ -8,7 +8,10 @@ "source": [ "import json\n", "\n", + "import numpy\n", + "import pandas\n", "from dotenv import load_dotenv\n", + "from matplotlib import pyplot\n", "\n", "from intelligence_layer.core import InMemoryTracer, LuminousControlModel, TextChunk\n", "from intelligence_layer.evaluation import (\n", @@ -442,7 +445,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "jupyter": { + "is_executing": true + } + }, "outputs": [], "source": [ "runner_prompt_adjusted = Runner(\n", @@ -461,7 +468,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "jupyter": { + "is_executing": true + } + }, "outputs": [], "source": [ "aggregation_overview_prompt_adjusted" @@ -479,7 +490,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "jupyter": { + "is_executing": true + } + }, "outputs": [], "source": [ "classify_with_extended = PromptBasedClassify(\n", @@ -497,7 +512,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "jupyter": { + "is_executing": true + } + }, "outputs": [], "source": [ "runner_with_extended = Runner(\n", @@ -516,7 +535,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "jupyter": { + "is_executing": true + } + }, "outputs": [], "source": [ "aggregation_overview_with_extended" @@ -532,17 +555,21 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "jupyter": { + "is_executing": true + } + }, "outputs": [], "source": [ - "lineages = [\n", + "incorrect_predictions_lineages = [\n", " lineage\n", " for lineage in evaluator.evaluation_lineages(eval_overview_prompt_adjusted.id)\n", " if not isinstance(lineage.evaluation.result, FailedExampleEvaluation)\n", " and not lineage.evaluation.result.correct\n", "]\n", "\n", - "df = evaluation_lineages_to_pandas(lineages)\n", + "df = evaluation_lineages_to_pandas(incorrect_predictions_lineages)\n", "df[\"input\"] = [i.chunk for i in df[\"input\"]]\n", "df[\"predicted\"] = [r.predicted for r in df[\"result\"]]\n", "df.reset_index()[[\"example_id\", \"input\", \"expected_output\", \"predicted\"]]" @@ -552,7 +579,96 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "As you can see there are plenty of option on how to further enhance the accuracy of our classify task. Notice, for instance, that so far we did not tell our classification task what each class means.\n", + "So let's analyze this in more depth by visualizing how often each label was expected or predicted in a histogram. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "by_labels = aggregation_overview_with_extended.statistics.by_label\n", + "\n", + "expected_counts_by_labels = {\n", + " label: by_labels[label].expected_count for label in by_labels.keys()\n", + "}\n", + "predicted_counts_by_labels = {\n", + " label: by_labels[label].predicted_count for label in by_labels.keys()\n", + "}\n", + "\n", + "x_axis = numpy.arange(len(expected_counts_by_labels.keys()))\n", + "pyplot.bar(\n", + " x_axis - 0.2, expected_counts_by_labels.values(), width=0.4, label=\"expected counts\"\n", + ")\n", + "pyplot.bar(\n", + " x_axis + 0.2,\n", + " predicted_counts_by_labels.values(),\n", + " width=0.4,\n", + " label=\"predicted counts\",\n", + ")\n", + "pyplot.ylabel(\"Classification count\")\n", + "pyplot.xlabel(\"Labels\")\n", + "pyplot.legend()\n", + "_ = pyplot.xticks(x_axis, by_labels.keys(), rotation=45)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As we can see the our task tends to overpredict the `Customer` label while it underpredicts `Infrastructure`, `CEO Office` and `Product`.\n", + "\n", + "We can get even more insight into the classification behaviour of our task by analysing its cross-matrix. From the off-diagonal cells in the cross-matrix we can see the explicit misslabeling for each class. This helps us to see if a specific class is frequently misslabeld as a particular other class. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "confusion_matrix = aggregation_overview_with_extended.statistics.confusion_matrix\n", + "\n", + "data = []\n", + "for (predicted_label, expected_label), count in confusion_matrix.items():\n", + " data.append(\n", + " {\n", + " \"Expected Label\": expected_label,\n", + " \"Predicted Label\": predicted_label,\n", + " \"Count\": count,\n", + " }\n", + " )\n", + "\n", + "df = pandas.DataFrame(data)\n", + "df = df.pivot(index=\"Expected Label\", columns=\"Predicted Label\", values=\"Count\")\n", + "df = df.fillna(0)\n", + "df = df.reindex(\n", + " index=labels, columns=labels, fill_value=0\n", + ") # this will add any labels that were neither expected nor predicted\n", + "df = df.style.background_gradient(cmap=\"grey\", vmin=df.min().min(), vmax=df.max().max())\n", + "df = df.format(\"{:.0f}\")\n", + "df" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In our case we can see that the bias towards the `Customer` class does not come at the cost of one particular other class, but is caused by a more general mislabeling. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As you can see there is plenty of room for further improvements of our classification task. \n", + "\n", + "Notice, for instance, that so far we did not tell our classification task what each class means.\n", "\n", "The model had to 'guess' what we mean by each class purely from the given labels. In order to tackle this issue you could use the `PromptBasedClassifyWithDefinitions` task. This task allows you to also provide a short description for each class.\n", "\n", @@ -576,7 +692,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.8" + "version": "3.11.7" } }, "nbformat": 4,