feat: compute precision, recall and f1-score by class in SingleLabelClassifyEvaluationLogic (#1110)

* feat: compute precision, recall and f1-score by class in `SingleLabelClassifyEvaluationLogic`
* doc: improve confusion_matrix description in docstring

---------

Co-authored-by: Sebastian Niehus <[email protected]>
FlorianHeiderichAA and SebastianNiehusAA authored Nov 4, 2024
1 parent 38e6638 commit 210c6de
Showing 4 changed files with 184 additions and 5 deletions.
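For context, the metrics this commit adds follow the standard per-class definitions. Below is a minimal standalone sketch (not the repository's code; names are illustrative) of how precision, recall and F1 are derived from per-class true-positive, false-positive and false-negative counts, with None used where a metric is undefined, matching the optional return types in the diff below.

# Minimal sketch of per-class precision, recall and F1 from raw counts.
# Illustrative only; the repository's implementation lives in
# SingleLabelClassifyAggregationLogic (see the diff below).


def precision(tp: int, fp: int) -> float | None:
    # Undefined (None) when the class was never predicted.
    return tp / (tp + fp) if tp + fp else None


def recall(tp: int, fn: int) -> float | None:
    # Undefined (None) when the class was never expected.
    return tp / (tp + fn) if tp + fn else None


def f1(tp: int, fp: int, fn: int) -> float:
    # Zero-denominator guard added only for this sketch; in the aggregation
    # below, every class that occurs in the data has tp + fp + fn >= 1.
    return 2 * tp / (2 * tp + fp + fn) if tp + fp + fn else 0.0


# Example: 2 true positives, 0 false positives, 1 false negative.
print(precision(2, 0), recall(2, 1), f1(2, 0, 1))  # -> 1.0, ~0.667, 0.8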
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -9,7 +9,7 @@
- Introduced `InstructionFinetuningDataRepository` for storing and retrieving finetuning samples. Comes in two implementations:
- `PostgresInstructionFinetuningDataRepository` to work with data stored in a Postgres database.
- `FileInstructionFinetuningDataRepository` to work with data stored in the local file-system.

- Compute precision, recall and f1-score by class in `SingleLabelClassifyAggregationLogic`

### Fixes
...
116 changes: 115 additions & 1 deletion src/intelligence_layer/examples/classify/classify.py
@@ -85,12 +85,18 @@ class AggregatedSingleLabelClassifyEvaluation(BaseModel):
Attributes:
percentage_correct: Percentage of answers that were considered to be correct.
confusion_matrix: A matrix showing the predicted classifications vs the expected classifications.
precision_by_class: Precision for each class
recall_by_class: Recall for each class
f1_by_class: f1-score for each class
confusion_matrix: A matrix showing the predicted classifications vs the expected classifications. First key refers to the rows of the confusion matrix (=actual prediction), second key refers to the columns of the matrix (=expected value).
by_label: Each label alongside counts of how often it was expected or predicted.
missing_labels: Each expected label which is missing in the set of possible labels in the task input and the number of its occurrences.
"""

percentage_correct: float
precision_by_class: dict[str, float | None]
recall_by_class: dict[str, float | None]
f1_by_class: dict[str, float]
confusion_matrix: dict[str, dict[str, int]]
by_label: dict[str, AggregatedLabelInfo]
missing_labels: dict[str, int]
@@ -101,6 +107,105 @@ class SingleLabelClassifyAggregationLogic(
SingleLabelClassifyEvaluation, AggregatedSingleLabelClassifyEvaluation
]
):
@staticmethod
def _true_positives(
confusion_matrix: dict[str, dict[str, int]], predicted_class: str
) -> int:
return confusion_matrix[predicted_class][predicted_class]

@staticmethod
def _false_positives(
confusion_matrix: dict[str, dict[str, int]], predicted_class: str
) -> int:
expected_classes_for_predicted_class = confusion_matrix[predicted_class].keys()
return sum(
confusion_matrix[predicted_class][e]
for e in expected_classes_for_predicted_class
if e != predicted_class
)

@staticmethod
def _false_negatives(
confusion_matrix: dict[str, dict[str, int]], predicted_class: str
) -> int:
predicted_classes = confusion_matrix.keys()
return sum(
confusion_matrix[p][predicted_class]
for p in predicted_classes
if p != predicted_class
)

@staticmethod
def _precision(true_positives: int, false_positives: int) -> float | None:
if true_positives + false_positives == 0:
return None
return true_positives / (true_positives + false_positives)

@staticmethod
def _recall(true_positives: int, false_negatives: int) -> float | None:
if true_positives + false_negatives == 0:
return None
return true_positives / (true_positives + false_negatives)

@staticmethod
def _f1(true_positives: int, false_positives: int, false_negatives: int) -> float:
return (
2
* true_positives
/ (2 * true_positives + false_positives + false_negatives)
)

@staticmethod
def _precision_by_class(
confusion_matrix: dict[str, dict[str, int]], predicted_classes: list[str]
) -> dict[str, float | None]:
return {
predicted_class: SingleLabelClassifyAggregationLogic._precision(
true_positives=SingleLabelClassifyAggregationLogic._true_positives(
confusion_matrix, predicted_class
),
false_positives=SingleLabelClassifyAggregationLogic._false_positives(
confusion_matrix, predicted_class
),
)
for predicted_class in predicted_classes
}

@staticmethod
def _recall_by_class(
confusion_matrix: dict[str, dict[str, int]], predicted_classes: list[str]
) -> dict[str, float | None]:
return {
predicted_class: SingleLabelClassifyAggregationLogic._recall(
true_positives=SingleLabelClassifyAggregationLogic._true_positives(
confusion_matrix, predicted_class
),
false_negatives=SingleLabelClassifyAggregationLogic._false_negatives(
confusion_matrix, predicted_class
),
)
for predicted_class in predicted_classes
}

@staticmethod
def _f1_by_class(
confusion_matrix: dict[str, dict[str, int]], predicted_classes: list[str]
) -> dict[str, float]:
return {
predicted_class: SingleLabelClassifyAggregationLogic._f1(
SingleLabelClassifyAggregationLogic._true_positives(
confusion_matrix, predicted_class
),
SingleLabelClassifyAggregationLogic._false_positives(
confusion_matrix, predicted_class
),
SingleLabelClassifyAggregationLogic._false_negatives(
confusion_matrix, predicted_class
),
)
for predicted_class in predicted_classes
}

def aggregate(
self, evaluations: Iterable[SingleLabelClassifyEvaluation]
) -> AggregatedSingleLabelClassifyEvaluation:
@@ -110,7 +215,11 @@ def aggregate(
lambda: defaultdict(int)
)
by_label: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))

predicted_classes = []

for evaluation in evaluations:
predicted_classes.append(evaluation.predicted)
acc.add(1.0 if evaluation.correct else 0.0)
if evaluation.expected_label_missing:
missing_labels[evaluation.expected] += 1
@@ -126,6 +235,11 @@
return AggregatedSingleLabelClassifyEvaluation(
percentage_correct=acc.extract(),
confusion_matrix=confusion_matrix,
precision_by_class=self._precision_by_class(
confusion_matrix, predicted_classes
),
recall_by_class=self._recall_by_class(confusion_matrix, predicted_classes),
f1_by_class=self._f1_by_class(confusion_matrix, predicted_classes),
by_label={
label: AggregatedLabelInfo(
expected_count=counts["expected"],
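As documented in the docstring above, the confusion matrix is keyed as confusion_matrix[predicted][expected]: the outer key selects a row (the actual prediction), the inner key a column (the expected label). A small illustrative slice (not the repository's code) shows how the three helper counts read off that layout for one class:

# Illustrative only: how tp/fp/fn for class "sad" are read from the
# confusion_matrix[predicted][expected] layout used above.
cm = {
    "happy": {"happy": 1, "sad": 0, "angry": 0},
    "sad": {"happy": 0, "sad": 2, "angry": 0},
    "angry": {"happy": 0, "sad": 1, "angry": 0},
}
c = "sad"
tp = cm[c][c]                                    # 2: diagonal entry, predicted and expected "sad"
fp = sum(v for e, v in cm[c].items() if e != c)  # 0: rest of the "sad" row (predicted "sad", expected other)
fn = sum(cm[p][c] for p in cm if p != c)         # 1: rest of the "sad" column (expected "sad", predicted other)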
61 changes: 58 additions & 3 deletions tests/examples/classify/test_classify.py
@@ -33,6 +33,8 @@
MultiLabelClassifyEvaluationLogic,
MultiLabelClassifyOutput,
Probability,
SingleLabelClassifyAggregationLogic,
SingleLabelClassifyEvaluation,
)


@@ -339,10 +341,13 @@ def test_confusion_matrix_in_single_label_classify_aggregation_is_compatible_wit
aggregated_single_label_classify_evaluation = (
AggregatedSingleLabelClassifyEvaluation(
percentage_correct=0.123,
precision_by_class={"happy": 1, "sad": 1, "angry": 0},
recall_by_class={"happy": 1, "sad": 2 / 3, "angry": None},
f1_by_class={"happy": 1, "sad": 4 / 5, "angry": 0},
confusion_matrix={
"happy": {"happy": 1},
"sad": {"sad": 2},
"angry": {"sad": 1},
"happy": {"happy": 1, "sad": 0, "angry": 0},
"sad": {"happy": 0, "sad": 2, "angry": 0},
"angry": {"happy": 0, "sad": 1, "angry": 0},
},
by_label={
"happy": AggregatedLabelInfo(expected_count=1, predicted_count=1),
@@ -374,3 +379,53 @@
)

assert aggregation_overview_from_file_repository == aggregation_overview


def test_single_label_classify_aggregation_logic_aggregate() -> None:
evaluations = [
SingleLabelClassifyEvaluation(
correct=True,
predicted="happy",
expected="happy",
expected_label_missing=False,
),
SingleLabelClassifyEvaluation(
correct=True, predicted="sad", expected="sad", expected_label_missing=False
),
SingleLabelClassifyEvaluation(
correct=True, predicted="sad", expected="sad", expected_label_missing=False
),
SingleLabelClassifyEvaluation(
correct=False,
predicted="angry",
expected="sad",
expected_label_missing=False,
),
]
aggregated_single_label_classify_evaluation = (
SingleLabelClassifyAggregationLogic().aggregate(evaluations)
)
expected_aggregated_single_label_classify_evaluation = (
AggregatedSingleLabelClassifyEvaluation(
percentage_correct=3 / 4,
precision_by_class={"happy": 1, "sad": 1, "angry": 0},
recall_by_class={"happy": 1, "sad": 2 / 3, "angry": None},
f1_by_class={"happy": 1, "sad": 4 / 5, "angry": 0},
confusion_matrix={
"happy": {"happy": 1, "sad": 0, "angry": 0},
"sad": {"happy": 0, "sad": 2, "angry": 0},
"angry": {"happy": 0, "sad": 1, "angry": 0},
},
by_label={
"happy": AggregatedLabelInfo(expected_count=1, predicted_count=1),
"sad": AggregatedLabelInfo(expected_count=3, predicted_count=2),
"angry": AggregatedLabelInfo(expected_count=0, predicted_count=1),
},
missing_labels={},
)
)

assert (
aggregated_single_label_classify_evaluation
== expected_aggregated_single_label_classify_evaluation
)
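Reading the expected per-class values off the confusion matrix above: "sad" has 2 true positives (its diagonal entry), 0 false positives (the rest of its row) and 1 false negative (the "sad" example predicted as "angry"), giving precision 1, recall 2/3 and F1 = 2·2 / (2·2 + 0 + 1) = 4/5. "angry" never appears as an expected label, so its recall is undefined (None), while its single wrong prediction yields precision and F1 of 0.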
10 changes: 10 additions & 0 deletions tests/examples/classify/test_prompt_based_classify.py
@@ -327,6 +327,16 @@ def test_can_aggregate_evaluations(
)

assert aggregation_overview.statistics.percentage_correct == 0.5
assert aggregation_overview.statistics.confusion_matrix == {
"positive": {"positive": 1, "negative": 0},
"negative": {"positive": 1, "negative": 0},
}
assert aggregation_overview.statistics.recall_by_class["positive"] == 1 / 2
assert aggregation_overview.statistics.precision_by_class["positive"] == 1
assert aggregation_overview.statistics.f1_by_class["positive"] == 2 / 3
assert aggregation_overview.statistics.precision_by_class["negative"] == 0
assert aggregation_overview.statistics.recall_by_class["negative"] is None
assert aggregation_overview.statistics.f1_by_class["negative"] == 0
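These assertions are consistent with the asserted matrix: "positive" is predicted once and expected twice (1 true positive, 0 false positives, 1 false negative), giving precision 1, recall 1/2 and F1 = 2·1 / (2·1 + 0 + 1) = 2/3; "negative" is predicted once but never expected (0 true positives, 1 false positive, 0 false negatives), so its precision and F1 are 0 and its recall is None (undefined).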


@pytest.mark.filterwarnings("ignore::UserWarning")
