From 8b28a236a11f0cd9b7d041a3fd77ca1e20d9dac4 Mon Sep 17 00:00:00 2001 From: Florian Schepers Date: Mon, 13 May 2024 18:07:59 +0200 Subject: [PATCH] fix: Change `confusion_matrix` in `SingleLabelClassifyAggregationLogic` such that it can be persisted on disc TASK: IL-475 --- CHANGELOG.md | 1 + .../examples/classify/classify.py | 12 +-- tests/conftest.py | 40 +++++++++- tests/evaluation/conftest.py | 37 +-------- tests/use_cases/classify/test_classify.py | 76 ++++++++++++++++++- 5 files changed, 121 insertions(+), 45 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 394959d95..38f95b045 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ ### Fixes - Improve description of using artifactory tokens for installation of IL + - Change `confusion_matrix` in `SingleLabelClassifyAggregationLogic` such that it can be persisted in a file repository ### Deprecations ... diff --git a/src/intelligence_layer/examples/classify/classify.py b/src/intelligence_layer/examples/classify/classify.py index d3dcb89e6..611e032f4 100644 --- a/src/intelligence_layer/examples/classify/classify.py +++ b/src/intelligence_layer/examples/classify/classify.py @@ -90,9 +90,9 @@ class AggregatedSingleLabelClassifyEvaluation(BaseModel): """ percentage_correct: float - confusion_matrix: Mapping[tuple[str, str], int] - by_label: Mapping[str, AggregatedLabelInfo] - missing_labels: Mapping[str, int] + confusion_matrix: dict[str, dict[str, int]] + by_label: dict[str, AggregatedLabelInfo] + missing_labels: dict[str, int] class SingleLabelClassifyAggregationLogic( @@ -105,14 +105,16 @@ def aggregate( ) -> AggregatedSingleLabelClassifyEvaluation: acc = MeanAccumulator() missing_labels: dict[str, int] = defaultdict(int) - confusion_matrix: dict[tuple[str, str], int] = defaultdict(int) + confusion_matrix: dict[str, dict[str, int]] = defaultdict( + lambda: defaultdict(int) + ) by_label: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int)) for evaluation in evaluations: acc.add(1.0 if evaluation.correct else 0.0) if evaluation.expected_label_missing: missing_labels[evaluation.expected] += 1 else: - confusion_matrix[(evaluation.predicted, evaluation.expected)] += 1 + confusion_matrix[evaluation.predicted][evaluation.expected] += 1 by_label[evaluation.predicted]["predicted"] += 1 by_label[evaluation.expected]["expected"] += 1 diff --git a/tests/conftest.py b/tests/conftest.py index 9016f142e..2c9191420 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,12 +18,20 @@ QdrantInMemoryRetriever, RetrieverType, ) -from intelligence_layer.core import LuminousControlModel, NoOpTracer, Task, TaskSpan +from intelligence_layer.core import ( + LuminousControlModel, + NoOpTracer, + Task, + TaskSpan, + utc_now, +) from intelligence_layer.evaluation import ( + EvaluationOverview, InMemoryAggregationRepository, InMemoryDatasetRepository, InMemoryEvaluationRepository, InMemoryRunRepository, + RunOverview, ) @@ -155,3 +163,33 @@ def in_memory_evaluation_repository() -> InMemoryEvaluationRepository: @fixture def in_memory_aggregation_repository() -> InMemoryAggregationRepository: return InMemoryAggregationRepository() + + +@fixture +def run_overview() -> RunOverview: + return RunOverview( + dataset_id="dataset-id", + id="run-id-1", + start=utc_now(), + end=utc_now(), + failed_example_count=0, + successful_example_count=3, + description="test run overview 1", + ) + + +@fixture +def evaluation_id() -> str: + return "evaluation-id-1" + + +@fixture +def evaluation_overview( + evaluation_id: str, run_overview: 
RunOverview +) -> EvaluationOverview: + return EvaluationOverview( + id=evaluation_id, + start=utc_now(), + run_overviews=frozenset([run_overview]), + description="test evaluation overview 1", + ) diff --git a/tests/evaluation/conftest.py b/tests/evaluation/conftest.py index 8167d7c91..8aa2be86f 100644 --- a/tests/evaluation/conftest.py +++ b/tests/evaluation/conftest.py @@ -1,4 +1,3 @@ -from datetime import datetime from os import getenv from pathlib import Path from typing import Iterable, Sequence @@ -30,7 +29,6 @@ InMemoryRunRepository, InstructComparisonArgillaAggregationLogic, Runner, - RunOverview, ) from tests.conftest import DummyStringInput, DummyStringOutput @@ -70,11 +68,6 @@ def sequence_examples() -> Iterable[Example[str, None]]: ] -@fixture -def evaluation_id() -> str: - return "evaluation-id-1" - - @fixture def successful_example_evaluation( evaluation_id: str, @@ -116,42 +109,16 @@ def dummy_aggregated_evaluation() -> DummyAggregatedEvaluation: return DummyAggregatedEvaluation(score=0.5) -@fixture -def run_overview() -> RunOverview: - return RunOverview( - dataset_id="dataset-id", - id="run-id-1", - start=utc_now(), - end=utc_now(), - failed_example_count=0, - successful_example_count=3, - description="test run overview 1", - ) - - -@fixture -def evaluation_overview( - evaluation_id: str, run_overview: RunOverview -) -> EvaluationOverview: - return EvaluationOverview( - id=evaluation_id, - start=utc_now(), - run_overviews=frozenset([run_overview]), - description="test evaluation overview 1", - ) - - @fixture def aggregation_overview( evaluation_overview: EvaluationOverview, dummy_aggregated_evaluation: DummyAggregatedEvaluation, ) -> AggregationOverview[DummyAggregatedEvaluation]: - now = datetime.now() return AggregationOverview( evaluation_overviews=frozenset([evaluation_overview]), id="aggregation-id", - start=now, - end=now, + start=utc_now(), + end=utc_now(), successful_evaluation_count=5, crashed_during_evaluation_count=3, description="dummy-evaluator", diff --git a/tests/use_cases/classify/test_classify.py b/tests/use_cases/classify/test_classify.py index b54bd7334..8de01a67e 100644 --- a/tests/use_cases/classify/test_classify.py +++ b/tests/use_cases/classify/test_classify.py @@ -1,22 +1,28 @@ +from pathlib import Path from typing import Iterable, List, Sequence +from uuid import uuid4 from pytest import fixture from intelligence_layer.connectors import AlephAlphaClientProtocol -from intelligence_layer.core import Task, TextChunk +from intelligence_layer.core import Task, TextChunk, utc_now from intelligence_layer.evaluation import ( + AggregationOverview, Aggregator, DatasetRepository, + EvaluationOverview, Evaluator, Example, + FileAggregationRepository, InMemoryAggregationRepository, InMemoryDatasetRepository, InMemoryEvaluationRepository, + InMemoryRunRepository, Runner, - RunRepository, ) from intelligence_layer.examples import ( AggregatedMultiLabelClassifyEvaluation, + AggregatedSingleLabelClassifyEvaluation, ClassifyInput, EmbeddingBasedClassify, LabelWithExamples, @@ -25,6 +31,7 @@ MultiLabelClassifyEvaluationLogic, MultiLabelClassifyOutput, ) +from intelligence_layer.examples.classify.classify import AggregatedLabelInfo @fixture @@ -143,7 +150,7 @@ def multi_label_classify_aggregation_logic() -> MultiLabelClassifyAggregationLog @fixture def classify_evaluator( in_memory_dataset_repository: DatasetRepository, - in_memory_run_repository: RunRepository, + in_memory_run_repository: InMemoryRunRepository, in_memory_evaluation_repository: 
InMemoryEvaluationRepository, multi_label_classify_evaluation_logic: MultiLabelClassifyEvaluationLogic, ) -> Evaluator[ @@ -178,11 +185,28 @@ def classify_aggregator( ) +@fixture +def classify_aggregator_file_repo( + in_memory_evaluation_repository: InMemoryEvaluationRepository, + file_aggregation_repository: FileAggregationRepository, + multi_label_classify_aggregation_logic: MultiLabelClassifyAggregationLogic, +) -> Aggregator[ + MultiLabelClassifyEvaluation, + AggregatedMultiLabelClassifyEvaluation, +]: + return Aggregator( + in_memory_evaluation_repository, + file_aggregation_repository, + "multi-label-classify", + multi_label_classify_aggregation_logic, + ) + + @fixture def classify_runner( embedding_based_classify: Task[ClassifyInput, MultiLabelClassifyOutput], in_memory_dataset_repository: DatasetRepository, - in_memory_run_repository: RunRepository, + in_memory_run_repository: InMemoryRunRepository, ) -> Runner[ClassifyInput, MultiLabelClassifyOutput]: return Runner( embedding_based_classify, @@ -240,3 +264,47 @@ def test_multi_label_classify_evaluator_full_dataset( assert {"positive", "negative", "finance", "school"} == set( aggregation_overview.statistics.class_metrics.keys() ) + + +def test_confusion_matrix_in_single_label_classify_aggregation_is_compatible_with_file_repository( + evaluation_overview: EvaluationOverview, + tmp_path: Path, +) -> None: + aggregated_single_label_classify_evaluation = ( + AggregatedSingleLabelClassifyEvaluation( + percentage_correct=0.123, + confusion_matrix={ + "happy": {"happy": 1}, + "sad": {"sad": 2}, + "angry": {"sad": 1}, + }, + by_label={ + "happy": AggregatedLabelInfo(expected_count=1, predicted_count=1), + "sad": AggregatedLabelInfo(expected_count=3, predicted_count=2), + "angry": AggregatedLabelInfo(expected_count=0, predicted_count=1), + }, + missing_labels={"tired": 1}, + ) + ) + + aggregation_overview = AggregationOverview( + id=str(uuid4()), + evaluation_overviews=frozenset([evaluation_overview]), + start=utc_now(), + end=utc_now(), + successful_evaluation_count=5, + crashed_during_evaluation_count=3, + statistics=aggregated_single_label_classify_evaluation, + description="dummy-aggregator", + ) + + aggregation_file_repository = FileAggregationRepository(tmp_path) + aggregation_file_repository.store_aggregation_overview(aggregation_overview) + + aggregation_overview_from_file_repository = ( + aggregation_file_repository.aggregation_overview( + aggregation_overview.id, AggregatedSingleLabelClassifyEvaluation + ) + ) + + assert aggregation_overview_from_file_repository == aggregation_overview
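
Why the key type change matters: `FileAggregationRepository` persists `AggregationOverview` models, including the statistics they carry, to disk, and the old `Mapping[tuple[str, str], int]` field could not survive that round trip (IL-475). JSON-style serialization only allows string object keys, whereas the nested `dict[str, dict[str, int]]` shape introduced above maps onto a plain JSON object. The sketch below is a minimal, standard-library-only illustration of that difference. It does not touch the actual repository code; the helper name `build_confusion_matrix` is made up for the example and simply mirrors the counting loop in `SingleLabelClassifyAggregationLogic.aggregate`.

import json
from collections import defaultdict


def build_confusion_matrix(pairs: list[tuple[str, str]]) -> dict[str, dict[str, int]]:
    # Count (predicted, expected) pairs into nested dicts keyed by the
    # predicted label first, then the expected label: the new confusion_matrix shape.
    matrix: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
    for predicted, expected in pairs:
        matrix[predicted][expected] += 1
    # Return plain dicts so the result is an explicit, serialization-friendly structure.
    return {predicted: dict(counts) for predicted, counts in matrix.items()}


pairs = [("happy", "happy"), ("sad", "sad"), ("sad", "sad"), ("angry", "sad")]

# Old shape: tuple keys are rejected by the JSON encoder
# ("TypeError: keys must be str, int, float, bool or None, not tuple").
old_shape = {("happy", "happy"): 1, ("sad", "sad"): 2, ("angry", "sad"): 1}
try:
    json.dumps(old_shape)
except TypeError as error:
    print(f"old shape is not serializable: {error}")

# New shape: nested string-keyed dicts serialize cleanly.
print(json.dumps(build_confusion_matrix(pairs), indent=2))

Running the sketch prints the TypeError for the tuple-keyed form and a nested JSON object for the new form. The new test `test_confusion_matrix_in_single_label_classify_aggregation_is_compatible_with_file_repository` checks the same property end to end by storing an aggregation overview in a `FileAggregationRepository` and reading it back.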