Skip to content

Commit

Permalink
feat: Add warning to SingleLabelClassifyEvaluationLogic on missing in…
Browse files Browse the repository at this point in the history
…put label

IL-367
  • Loading branch information
SebastianNiehusTNG authored and JohannesWesch committed Apr 4, 2024
1 parent 95033f3 commit b51d98f
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 22 deletions.
1 change: 0 additions & 1 deletion src/intelligence_layer/use_cases/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
MultiLabelClassifyEvaluationLogic as MultiLabelClassifyEvaluationLogic,
)
from .classify.classify import MultiLabelClassifyOutput as MultiLabelClassifyOutput
from .classify.classify import PerformanceScores as PerformanceScores
from .classify.classify import Probability as Probability
from .classify.classify import (
SingleLabelClassifyAggregationLogic as SingleLabelClassifyAggregationLogic,
Expand Down
47 changes: 26 additions & 21 deletions src/intelligence_layer/use_cases/classify/classify.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from collections import defaultdict
from typing import Iterable, Mapping, NewType, Sequence

Expand Down Expand Up @@ -69,20 +70,6 @@ class SingleLabelClassifyEvaluation(BaseModel):
expected_label_missing: bool


class PerformanceScores(BaseModel):
"""The relevant metrics resulting from a confusion matrix in a classification run.
Attributes:
precision: Proportion of correctly predicted classes to all predicted classes.
recall: Proportion of correctly predicted classes to all expected classes.
f1: Aggregated performance, formally the harmonic mean of precision and recall.
"""

precision: float
recall: float
f1: float


class AggregatedLabelInfo(BaseModel):
expected_count: int
predicted_count: int
Expand Down Expand Up @@ -154,6 +141,10 @@ def do_evaluate_single_output(
sorted_classes = sorted(
output.scores.items(), key=lambda item: item[1], reverse=True
)
if example.expected_output not in example.input.labels:
warn_message = f"[WARNING] Example with ID '{example.id}' has expected label '{example.expected_output}', which is not part of the example's input labels."
warnings.warn(warn_message, RuntimeWarning)

predicted = sorted_classes[0][0]
if predicted == example.expected_output:
correct = True
Expand Down Expand Up @@ -183,6 +174,20 @@ class MultiLabelClassifyEvaluation(BaseModel):
fn: frozenset[str]


class MultiLabelClassifyMetrics(BaseModel):
"""The relevant metrics resulting from a confusion matrix in a classification run.
Attributes:
precision: Proportion of correctly predicted classes to all predicted classes.
recall: Proportion of correctly predicted classes to all expected classes.
f1: Aggregated performance, formally the harmonic mean of precision and recall.
"""

precision: float
recall: float
f1: float


class AggregatedMultiLabelClassifyEvaluation(BaseModel):
"""The aggregated evaluation of a multi-label classify dataset.
Expand All @@ -193,9 +198,9 @@ class AggregatedMultiLabelClassifyEvaluation(BaseModel):
"""

class_metrics: Mapping[str, PerformanceScores]
micro_avg: PerformanceScores
macro_avg: PerformanceScores
class_metrics: Mapping[str, MultiLabelClassifyMetrics]
micro_avg: MultiLabelClassifyMetrics
macro_avg: MultiLabelClassifyMetrics


class MultiLabelClassifyAggregationLogic(
Expand Down Expand Up @@ -243,7 +248,7 @@ def aggregate(
else 0
)

class_metrics[label] = PerformanceScores(
class_metrics[label] = MultiLabelClassifyMetrics(
precision=precision, recall=recall, f1=f1
)

Expand All @@ -255,19 +260,19 @@ def aggregate(
sum_f1 += f1

try:
micro_avg = PerformanceScores(
micro_avg = MultiLabelClassifyMetrics(
precision=sum_tp / (sum_tp + sum_fp),
recall=sum_tp / (sum_tp + sum_fn),
f1=(2 * (sum_tp / (sum_tp + sum_fp)) * (sum_tp / (sum_tp + sum_fn)))
/ ((sum_tp / (sum_tp + sum_fp)) + (sum_tp / (sum_tp + sum_fn))),
)
except ZeroDivisionError:
micro_avg = PerformanceScores(
micro_avg = MultiLabelClassifyMetrics(
precision=0,
recall=0,
f1=0,
)
macro_avg = PerformanceScores(
macro_avg = MultiLabelClassifyMetrics(
precision=sum_precision / len(class_metrics),
recall=sum_recall / len(class_metrics),
f1=sum_f1 / len(class_metrics),
Expand Down
30 changes: 30 additions & 0 deletions tests/use_cases/classify/test_prompt_based_classify.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Sequence

import pytest
from pytest import fixture

from intelligence_layer.core import InMemoryTracer, NoOpTracer, TextChunk
Expand Down Expand Up @@ -216,6 +217,35 @@ def test_can_evaluate_classify(
assert evaluation.correct is True


def test_classify_warns_on_missing_label(
in_memory_dataset_repository: InMemoryDatasetRepository,
classify_runner: Runner[ClassifyInput, SingleLabelClassifyOutput],
in_memory_evaluation_repository: InMemoryEvaluationRepository,
classify_evaluator: Evaluator[
ClassifyInput,
SingleLabelClassifyOutput,
Sequence[str],
SingleLabelClassifyEvaluation,
],
prompt_based_classify: PromptBasedClassify,
) -> None:
example = Example(
input=ClassifyInput(
chunk=TextChunk("This is good"),
labels=frozenset({"positive", "negative"}),
),
expected_output="SomethingElse",
)

dataset_id = in_memory_dataset_repository.create_dataset(
examples=[example], dataset_name="test-dataset"
).id

run_overview = classify_runner.run_dataset(dataset_id)

pytest.warns(RuntimeWarning, classify_evaluator.evaluate_runs, run_overview.id)


def test_can_aggregate_evaluations(
classify_evaluator: Evaluator[
ClassifyInput,
Expand Down

0 comments on commit b51d98f

Please sign in to comment.