From 210c6de728267150ff583e076a775879fa60c47b Mon Sep 17 00:00:00 2001
From: Florian Heiderich <159080706+FlorianHeiderichAA@users.noreply.github.com>
Date: Mon, 4 Nov 2024 11:29:25 +0100
Subject: [PATCH] feat: compute precision, recall and f1-score by class in SingleLabelClassifyEvaluationLogic (#1110)

* feat: compute precision, recall and f1-score by class in `SingleLabelClassifyEvaluationLogic`

* doc: improve confusion_matrix description in docstring

---------

Co-authored-by: Sebastian Niehus <165138846+SebastianNiehusAA@users.noreply.github.com>
---
 CHANGELOG.md                               |   2 +-
 .../examples/classify/classify.py          | 116 +++++++++++++++++-
 tests/examples/classify/test_classify.py   |  61 ++++++++-
 .../classify/test_prompt_based_classify.py |  10 ++
 4 files changed, 184 insertions(+), 5 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2910b0cb2..b3d11a644 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,7 +9,7 @@
 - Introduced `InstructionFinetuningDataRepository` for storing and retrieving finetuning samples. Comes in two implementations:
   - `PostgresInstructionFinetuningDataRepository` to work with data stored in a Postgres database.
   - `FileInstructionFinetuningDataRepository` to work with data stored in the local file-system.
-
+- Compute precision, recall and f1-score by class in `SingleLabelClassifyAggregationLogic`
 
 ### Fixes
 ...
diff --git a/src/intelligence_layer/examples/classify/classify.py b/src/intelligence_layer/examples/classify/classify.py
index f710a3456..e245e9538 100644
--- a/src/intelligence_layer/examples/classify/classify.py
+++ b/src/intelligence_layer/examples/classify/classify.py
@@ -85,12 +85,18 @@ class AggregatedSingleLabelClassifyEvaluation(BaseModel):
 
     Attributes:
         percentage_correct: Percentage of answers that were considered to be correct.
-        confusion_matrix: A matrix showing the predicted classifications vs the expected classifications.
+        precision_by_class: Precision for each class
+        recall_by_class: Recall for each class
+        f1_by_class: f1-score for each class
+        confusion_matrix: A matrix showing the predicted classifications vs the expected classifications. First key refers to the rows of the confusion matrix (=actual prediction), second key refers to the columns of the matrix (=expected value).
         by_label: Each label along side the counts how often it was expected or predicted.
         missing_labels: Each expected label which is missing in the set of possible labels in the task input and the number of its occurrences.
""" percentage_correct: float + precision_by_class: dict[str, float | None] + recall_by_class: dict[str, float | None] + f1_by_class: dict[str, float] confusion_matrix: dict[str, dict[str, int]] by_label: dict[str, AggregatedLabelInfo] missing_labels: dict[str, int] @@ -101,6 +107,105 @@ class SingleLabelClassifyAggregationLogic( SingleLabelClassifyEvaluation, AggregatedSingleLabelClassifyEvaluation ] ): + @staticmethod + def _true_positives( + confusion_matrix: dict[str, dict[str, int]], predicted_class: str + ) -> int: + return confusion_matrix[predicted_class][predicted_class] + + @staticmethod + def _false_positives( + confusion_matrix: dict[str, dict[str, int]], predicted_class: str + ) -> int: + expected_classes_for_predicted_class = confusion_matrix[predicted_class].keys() + return sum( + confusion_matrix[predicted_class][e] + for e in expected_classes_for_predicted_class + if e != predicted_class + ) + + @staticmethod + def _false_negatives( + confusion_matrix: dict[str, dict[str, int]], predicted_class: str + ) -> int: + predicted_classes = confusion_matrix.keys() + return sum( + confusion_matrix[p][predicted_class] + for p in predicted_classes + if p != predicted_class + ) + + @staticmethod + def _precision(true_positives: int, false_positives: int) -> float | None: + if true_positives + false_positives == 0: + return None + return true_positives / (true_positives + false_positives) + + @staticmethod + def _recall(true_positives: int, false_negatives: int) -> float | None: + if true_positives + false_negatives == 0: + return None + return true_positives / (true_positives + false_negatives) + + @staticmethod + def _f1(true_positives: int, false_positives: int, false_negatives: int) -> float: + return ( + 2 + * true_positives + / (2 * true_positives + false_positives + false_negatives) + ) + + @staticmethod + def _precision_by_class( + confusion_matrix: dict[str, dict[str, int]], predicted_classes: list[str] + ) -> dict[str, float | None]: + return { + predicted_class: SingleLabelClassifyAggregationLogic._precision( + true_positives=SingleLabelClassifyAggregationLogic._true_positives( + confusion_matrix, predicted_class + ), + false_positives=SingleLabelClassifyAggregationLogic._false_positives( + confusion_matrix, predicted_class + ), + ) + for predicted_class in predicted_classes + } + + @staticmethod + def _recall_by_class( + confusion_matrix: dict[str, dict[str, int]], predicted_classes: list[str] + ) -> dict[str, float | None]: + return { + predicted_class: SingleLabelClassifyAggregationLogic._recall( + true_positives=SingleLabelClassifyAggregationLogic._true_positives( + confusion_matrix, predicted_class + ), + false_negatives=SingleLabelClassifyAggregationLogic._false_negatives( + confusion_matrix, predicted_class + ), + ) + for predicted_class in predicted_classes + } + + @staticmethod + def _f1_by_class( + confusion_matrix: dict[str, dict[str, int]], predicted_classes: list[str] + ) -> dict[str, float]: + return { + predicted_class: SingleLabelClassifyAggregationLogic._f1( + SingleLabelClassifyAggregationLogic._true_positives( + confusion_matrix, predicted_class + ), + SingleLabelClassifyAggregationLogic._false_positives( + confusion_matrix, predicted_class + ), + SingleLabelClassifyAggregationLogic._false_negatives( + confusion_matrix, predicted_class + ), + ) + for predicted_class in predicted_classes + } + def aggregate( self, evaluations: Iterable[SingleLabelClassifyEvaluation] ) -> AggregatedSingleLabelClassifyEvaluation: @@ -110,7 +215,11 @@ def aggregate( lambda: 
         )
         by_label: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
+
+        predicted_classes = []
+
         for evaluation in evaluations:
+            predicted_classes.append(evaluation.predicted)
             acc.add(1.0 if evaluation.correct else 0.0)
             if evaluation.expected_label_missing:
                 missing_labels[evaluation.expected] += 1
@@ -126,6 +235,11 @@ def aggregate(
         return AggregatedSingleLabelClassifyEvaluation(
             percentage_correct=acc.extract(),
             confusion_matrix=confusion_matrix,
+            precision_by_class=self._precision_by_class(
+                confusion_matrix, predicted_classes
+            ),
+            recall_by_class=self._recall_by_class(confusion_matrix, predicted_classes),
+            f1_by_class=self._f1_by_class(confusion_matrix, predicted_classes),
             by_label={
                 label: AggregatedLabelInfo(
                     expected_count=counts["expected"],
diff --git a/tests/examples/classify/test_classify.py b/tests/examples/classify/test_classify.py
index cd1278a68..56fb04135 100644
--- a/tests/examples/classify/test_classify.py
+++ b/tests/examples/classify/test_classify.py
@@ -33,6 +33,8 @@
     MultiLabelClassifyEvaluationLogic,
     MultiLabelClassifyOutput,
     Probability,
+    SingleLabelClassifyAggregationLogic,
+    SingleLabelClassifyEvaluation,
 )
 
 
@@ -339,10 +341,13 @@ def test_confusion_matrix_in_single_label_classify_aggregation_is_compatible_wit
     aggregated_single_label_classify_evaluation = (
         AggregatedSingleLabelClassifyEvaluation(
             percentage_correct=0.123,
+            precision_by_class={"happy": 1, "sad": 1, "angry": 0},
+            recall_by_class={"happy": 1, "sad": 2 / 3, "angry": None},
+            f1_by_class={"happy": 1, "sad": 4 / 5, "angry": 0},
             confusion_matrix={
-                "happy": {"happy": 1},
-                "sad": {"sad": 2},
-                "angry": {"sad": 1},
+                "happy": {"happy": 1, "sad": 0, "angry": 0},
+                "sad": {"happy": 0, "sad": 2, "angry": 0},
+                "angry": {"happy": 0, "sad": 1, "angry": 0},
             },
             by_label={
                 "happy": AggregatedLabelInfo(expected_count=1, predicted_count=1),
@@ -374,3 +379,53 @@ def test_confusion_matrix_in_single_label_classify_aggregation_is_compatible_wit
     )
 
     assert aggregation_overview_from_file_repository == aggregation_overview
+
+
+def test_single_label_classify_aggregation_logic_aggregate() -> None:
+    evaluations = [
+        SingleLabelClassifyEvaluation(
+            correct=True,
+            predicted="happy",
+            expected="happy",
+            expected_label_missing=False,
+        ),
+        SingleLabelClassifyEvaluation(
+            correct=True, predicted="sad", expected="sad", expected_label_missing=False
+        ),
+        SingleLabelClassifyEvaluation(
+            correct=True, predicted="sad", expected="sad", expected_label_missing=False
+        ),
+        SingleLabelClassifyEvaluation(
+            correct=False,
+            predicted="angry",
+            expected="sad",
+            expected_label_missing=False,
+        ),
+    ]
+    aggregated_single_label_classify_evaluation = (
+        SingleLabelClassifyAggregationLogic().aggregate(evaluations)
+    )
+    expected_aggregated_single_label_classify_evaluation = (
+        AggregatedSingleLabelClassifyEvaluation(
+            percentage_correct=3 / 4,
+            precision_by_class={"happy": 1, "sad": 1, "angry": 0},
+            recall_by_class={"happy": 1, "sad": 2 / 3, "angry": None},
+            f1_by_class={"happy": 1, "sad": 4 / 5, "angry": 0},
+            confusion_matrix={
+                "happy": {"happy": 1, "sad": 0, "angry": 0},
+                "sad": {"happy": 0, "sad": 2, "angry": 0},
+                "angry": {"happy": 0, "sad": 1, "angry": 0},
+            },
+            by_label={
+                "happy": AggregatedLabelInfo(expected_count=1, predicted_count=1),
+                "sad": AggregatedLabelInfo(expected_count=3, predicted_count=2),
+                "angry": AggregatedLabelInfo(expected_count=0, predicted_count=1),
+            },
+            missing_labels={},
+        )
+    )
+
+    assert (
+        aggregated_single_label_classify_evaluation
+        == expected_aggregated_single_label_classify_evaluation
+    )
diff --git a/tests/examples/classify/test_prompt_based_classify.py b/tests/examples/classify/test_prompt_based_classify.py
index 6d1bcb52f..bef24508e 100644
--- a/tests/examples/classify/test_prompt_based_classify.py
+++ b/tests/examples/classify/test_prompt_based_classify.py
@@ -327,6 +327,16 @@ def test_can_aggregate_evaluations(
     )
 
     assert aggregation_overview.statistics.percentage_correct == 0.5
+    assert aggregation_overview.statistics.confusion_matrix == {
+        "positive": {"positive": 1, "negative": 0},
+        "negative": {"positive": 1, "negative": 0},
+    }
+    assert aggregation_overview.statistics.recall_by_class["positive"] == 1 / 2
+    assert aggregation_overview.statistics.precision_by_class["positive"] == 1
+    assert aggregation_overview.statistics.f1_by_class["positive"] == 2 / 3
+    assert aggregation_overview.statistics.precision_by_class["negative"] == 0
+    assert aggregation_overview.statistics.recall_by_class["negative"] is None
+    assert aggregation_overview.statistics.f1_by_class["negative"] == 0
 
 
 @pytest.mark.filterwarnings("ignore::UserWarning")
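For reference, a minimal standalone sketch (outside the patch; `per_class_metrics` is a hypothetical helper name used only for illustration) of the per-class arithmetic the new static methods implement, using the same nested confusion-matrix shape the aggregation builds: the first key is the predicted label (row), the second key the expected label (column).

```python
# Minimal sketch, not part of the patch: per-class precision/recall/F1 from a
# nested confusion-matrix dict shaped like the one the aggregation logic builds,
# i.e. confusion_matrix[predicted_label][expected_label] -> count.
# The matrix literal mirrors the "happy"/"sad"/"angry" example from the tests.
confusion_matrix: dict[str, dict[str, int]] = {
    "happy": {"happy": 1, "sad": 0, "angry": 0},
    "sad": {"happy": 0, "sad": 2, "angry": 0},
    "angry": {"happy": 0, "sad": 1, "angry": 0},
}


def per_class_metrics(  # hypothetical helper, for illustration only
    matrix: dict[str, dict[str, int]], label: str
) -> tuple[float | None, float | None, float]:
    # True positives sit on the diagonal; false positives are the rest of the
    # label's row (predicted as `label`, expected otherwise); false negatives
    # are the rest of the label's column (expected `label`, predicted otherwise).
    tp = matrix[label][label]
    fp = sum(count for expected, count in matrix[label].items() if expected != label)
    fn = sum(row.get(label, 0) for predicted, row in matrix.items() if predicted != label)
    # Precision and recall fall back to None when their denominator is zero, as in
    # the patch's _precision and _recall; F1 uses the 2*TP / (2*TP + FP + FN) form,
    # which is safe here because every row in this example matrix has a count.
    precision = tp / (tp + fp) if tp + fp else None
    recall = tp / (tp + fn) if tp + fn else None
    f1 = 2 * tp / (2 * tp + fp + fn)
    return precision, recall, f1


for label in confusion_matrix:
    print(label, per_class_metrics(confusion_matrix, label))
```

For "sad" this gives TP=2, FP=0, FN=1, hence precision 1, recall 2/3 and F1 4/5, matching the values asserted in `test_single_label_classify_aggregation_logic_aggregate` above.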