diff --git a/src/intelligence_layer/use_cases/__init__.py b/src/intelligence_layer/use_cases/__init__.py
index 85f233daa..de0947313 100644
--- a/src/intelligence_layer/use_cases/__init__.py
+++ b/src/intelligence_layer/use_cases/__init__.py
@@ -15,7 +15,6 @@ MultiLabelClassifyEvaluationLogic as MultiLabelClassifyEvaluationLogic,
 )
 from .classify.classify import MultiLabelClassifyOutput as MultiLabelClassifyOutput
-from .classify.classify import PerformanceScores as PerformanceScores
 from .classify.classify import Probability as Probability
 from .classify.classify import (
     SingleLabelClassifyAggregationLogic as SingleLabelClassifyAggregationLogic,
 )
diff --git a/src/intelligence_layer/use_cases/classify/classify.py b/src/intelligence_layer/use_cases/classify/classify.py
index 083d5d4b7..a2493c4ea 100644
--- a/src/intelligence_layer/use_cases/classify/classify.py
+++ b/src/intelligence_layer/use_cases/classify/classify.py
@@ -1,3 +1,4 @@
+import warnings
 from collections import defaultdict
 from typing import Iterable, Mapping, NewType, Sequence

@@ -69,20 +70,6 @@ class SingleLabelClassifyEvaluation(BaseModel):
     expected_label_missing: bool


-class PerformanceScores(BaseModel):
-    """The relevant metrics resulting from a confusion matrix in a classification run.
-
-    Attributes:
-        precision: Proportion of correctly predicted classes to all predicted classes.
-        recall: Proportion of correctly predicted classes to all expected classes.
-        f1: Aggregated performance, formally the harmonic mean of precision and recall.
-    """
-
-    precision: float
-    recall: float
-    f1: float
-
-
 class AggregatedLabelInfo(BaseModel):
     expected_count: int
     predicted_count: int
@@ -154,6 +141,10 @@ def do_evaluate_single_output(
         sorted_classes = sorted(
             output.scores.items(), key=lambda item: item[1], reverse=True
         )
+        if example.expected_output not in example.input.labels:
+            warn_message = f"[WARNING] Example with ID '{example.id}' has expected label '{example.expected_output}', which is not part of the example's input labels."
+            warnings.warn(warn_message, RuntimeWarning)
+
         predicted = sorted_classes[0][0]
         if predicted == example.expected_output:
             correct = True
@@ -183,6 +174,20 @@ class MultiLabelClassifyEvaluation(BaseModel):
     fn: frozenset[str]


+class MultiLabelClassifyMetrics(BaseModel):
+    """The relevant metrics resulting from a confusion matrix in a classification run.
+
+    Attributes:
+        precision: Proportion of correctly predicted classes to all predicted classes.
+        recall: Proportion of correctly predicted classes to all expected classes.
+        f1: Aggregated performance, formally the harmonic mean of precision and recall.
+    """
+
+    precision: float
+    recall: float
+    f1: float
+
+
 class AggregatedMultiLabelClassifyEvaluation(BaseModel):
     """The aggregated evaluation of a multi-label classify dataset.

@@ -193,9 +198,9 @@ class AggregatedMultiLabelClassifyEvaluation(BaseModel):

     """

-    class_metrics: Mapping[str, PerformanceScores]
-    micro_avg: PerformanceScores
-    macro_avg: PerformanceScores
+    class_metrics: Mapping[str, MultiLabelClassifyMetrics]
+    micro_avg: MultiLabelClassifyMetrics
+    macro_avg: MultiLabelClassifyMetrics


 class MultiLabelClassifyAggregationLogic(
@@ -243,7 +248,7 @@ def aggregate(
                 else 0
             )

-            class_metrics[label] = PerformanceScores(
+            class_metrics[label] = MultiLabelClassifyMetrics(
                 precision=precision, recall=recall, f1=f1
             )

@@ -255,19 +260,19 @@ def aggregate(
             sum_f1 += f1

         try:
-            micro_avg = PerformanceScores(
+            micro_avg = MultiLabelClassifyMetrics(
                 precision=sum_tp / (sum_tp + sum_fp),
                 recall=sum_tp / (sum_tp + sum_fn),
                 f1=(2 * (sum_tp / (sum_tp + sum_fp)) * (sum_tp / (sum_tp + sum_fn)))
                 / ((sum_tp / (sum_tp + sum_fp)) + (sum_tp / (sum_tp + sum_fn))),
             )
         except ZeroDivisionError:
-            micro_avg = PerformanceScores(
+            micro_avg = MultiLabelClassifyMetrics(
                 precision=0,
                 recall=0,
                 f1=0,
             )
-        macro_avg = PerformanceScores(
+        macro_avg = MultiLabelClassifyMetrics(
             precision=sum_precision / len(class_metrics),
             recall=sum_recall / len(class_metrics),
             f1=sum_f1 / len(class_metrics),
diff --git a/tests/use_cases/classify/test_prompt_based_classify.py b/tests/use_cases/classify/test_prompt_based_classify.py
index a54b58bb4..35e6c8cc6 100644
--- a/tests/use_cases/classify/test_prompt_based_classify.py
+++ b/tests/use_cases/classify/test_prompt_based_classify.py
@@ -1,5 +1,6 @@
 from typing import Sequence

+import pytest
 from pytest import fixture

 from intelligence_layer.core import InMemoryTracer, NoOpTracer, TextChunk
@@ -216,6 +217,35 @@ def test_can_evaluate_classify(
     assert evaluation.correct is True


+def test_classify_warns_on_missing_label(
+    in_memory_dataset_repository: InMemoryDatasetRepository,
+    classify_runner: Runner[ClassifyInput, SingleLabelClassifyOutput],
+    in_memory_evaluation_repository: InMemoryEvaluationRepository,
+    classify_evaluator: Evaluator[
+        ClassifyInput,
+        SingleLabelClassifyOutput,
+        Sequence[str],
+        SingleLabelClassifyEvaluation,
+    ],
+    prompt_based_classify: PromptBasedClassify,
+) -> None:
+    example = Example(
+        input=ClassifyInput(
+            chunk=TextChunk("This is good"),
+            labels=frozenset({"positive", "negative"}),
+        ),
+        expected_output="SomethingElse",
+    )
+
+    dataset_id = in_memory_dataset_repository.create_dataset(
+        examples=[example], dataset_name="test-dataset"
+    ).id
+
+    run_overview = classify_runner.run_dataset(dataset_id)
+
+    pytest.warns(RuntimeWarning, classify_evaluator.evaluate_runs, run_overview.id)
+
+
 def test_can_aggregate_evaluations(
     classify_evaluator: Evaluator[
         ClassifyInput,
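Note: the new test asserts the warning through the legacy callable form of pytest.warns. As a minimal, self-contained sketch of the same assertion pattern (the helper function below is illustrative only and not part of the patch), the context-manager form of pytest.warns also works and additionally lets the test match on the warning message:

import warnings

import pytest


def emit_missing_label_warning(example_id: str, expected_label: str) -> None:
    # Illustrative stand-in for the check added in do_evaluate_single_output:
    # warn when the expected label is not among the example's input labels,
    # but let evaluation continue.
    warnings.warn(
        f"[WARNING] Example with ID '{example_id}' has expected label "
        f"'{expected_label}', which is not part of the example's input labels.",
        RuntimeWarning,
    )


def test_warns_on_missing_label_with_context_manager() -> None:
    # The warning must be emitted inside the `with` block; `match` is a regex
    # applied against the warning message.
    with pytest.warns(RuntimeWarning, match="not part of the example's input labels"):
        emit_missing_label_warning("example-1", "SomethingElse")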