From abba4c87a5b38cb1366b261752d451445552c45e Mon Sep 17 00:00:00 2001 From: Florian Schepers Date: Mon, 13 May 2024 18:07:59 +0200 Subject: [PATCH] fix: Change `confusion_matrix` in `SingleLabelClassifyAggregationLogic` such that it can be persisted on disc TASK: IL-475 --- CHANGELOG.md | 1 + src/intelligence_layer/examples/__init__.py | 1 + .../examples/classify/classify.py | 12 +-- tests/conftest.py | 43 ++++++++++- tests/evaluation/conftest.py | 40 +--------- tests/use_cases/classify/test_classify.py | 76 ++++++++++++++++++- 6 files changed, 125 insertions(+), 48 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 60e88b254..e66d35ce1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ Check the how-to for detailed information [here](./src/documentation/how_tos/how ### Fixes - Improve description of using artifactory tokens for installation of IL + - Change `confusion_matrix` in `SingleLabelClassifyAggregationLogic` such that it can be persisted in a file repository ### Deprecations ... diff --git a/src/intelligence_layer/examples/__init__.py b/src/intelligence_layer/examples/__init__.py index cc8c06709..d3132b138 100644 --- a/src/intelligence_layer/examples/__init__.py +++ b/src/intelligence_layer/examples/__init__.py @@ -1,3 +1,4 @@ +from .classify.classify import AggregatedLabelInfo as AggregatedLabelInfo from .classify.classify import ( AggregatedMultiLabelClassifyEvaluation as AggregatedMultiLabelClassifyEvaluation, ) diff --git a/src/intelligence_layer/examples/classify/classify.py b/src/intelligence_layer/examples/classify/classify.py index d3dcb89e6..611e032f4 100644 --- a/src/intelligence_layer/examples/classify/classify.py +++ b/src/intelligence_layer/examples/classify/classify.py @@ -90,9 +90,9 @@ class AggregatedSingleLabelClassifyEvaluation(BaseModel): """ percentage_correct: float - confusion_matrix: Mapping[tuple[str, str], int] - by_label: Mapping[str, AggregatedLabelInfo] - missing_labels: Mapping[str, int] + confusion_matrix: dict[str, dict[str, int]] + by_label: dict[str, AggregatedLabelInfo] + missing_labels: dict[str, int] class SingleLabelClassifyAggregationLogic( @@ -105,14 +105,16 @@ def aggregate( ) -> AggregatedSingleLabelClassifyEvaluation: acc = MeanAccumulator() missing_labels: dict[str, int] = defaultdict(int) - confusion_matrix: dict[tuple[str, str], int] = defaultdict(int) + confusion_matrix: dict[str, dict[str, int]] = defaultdict( + lambda: defaultdict(int) + ) by_label: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int)) for evaluation in evaluations: acc.add(1.0 if evaluation.correct else 0.0) if evaluation.expected_label_missing: missing_labels[evaluation.expected] += 1 else: - confusion_matrix[(evaluation.predicted, evaluation.expected)] += 1 + confusion_matrix[evaluation.predicted][evaluation.expected] += 1 by_label[evaluation.predicted]["predicted"] += 1 by_label[evaluation.expected]["expected"] += 1 diff --git a/tests/conftest.py b/tests/conftest.py index e4eb66ae7..6b740a5c5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,13 +17,21 @@ QdrantInMemoryRetriever, RetrieverType, ) -from intelligence_layer.core import LuminousControlModel, NoOpTracer, Task, TaskSpan +from intelligence_layer.core import ( + LuminousControlModel, + NoOpTracer, + Task, + TaskSpan, + utc_now, +) from intelligence_layer.evaluation import ( AsyncInMemoryEvaluationRepository, + EvaluationOverview, InMemoryAggregationRepository, InMemoryDatasetRepository, InMemoryEvaluationRepository, InMemoryRunRepository, + RunOverview, ) @@ -154,3 +162,36 @@ def in_memory_aggregation_repository() -> InMemoryAggregationRepository: @fixture() def async_in_memory_evaluation_repository() -> AsyncInMemoryEvaluationRepository: return AsyncInMemoryEvaluationRepository() + + +@fixture +def run_overview() -> RunOverview: + return RunOverview( + dataset_id="dataset-id", + id="run-id-1", + start=utc_now(), + end=utc_now(), + failed_example_count=0, + successful_example_count=3, + description="test run overview 1", + ) + + +@fixture +def evaluation_id() -> str: + return "evaluation-id-1" + + +@fixture +def evaluation_overview( + evaluation_id: str, run_overview: RunOverview +) -> EvaluationOverview: + return EvaluationOverview( + id=evaluation_id, + start_date=utc_now(), + end_date=utc_now(), + successful_evaluation_count=1, + failed_evaluation_count=1, + run_overviews=frozenset([run_overview]), + description="test evaluation overview 1", + ) diff --git a/tests/evaluation/conftest.py b/tests/evaluation/conftest.py index 052351de5..c15362589 100644 --- a/tests/evaluation/conftest.py +++ b/tests/evaluation/conftest.py @@ -1,4 +1,3 @@ -from datetime import datetime from os import getenv from pathlib import Path from typing import Iterable, Sequence @@ -29,7 +28,6 @@ InMemoryDatasetRepository, InMemoryRunRepository, Runner, - RunOverview, ) from tests.conftest import DummyStringInput, DummyStringOutput @@ -69,11 +67,6 @@ def sequence_examples() -> Iterable[Example[str, None]]: ] -@fixture -def evaluation_id() -> str: - return "evaluation-id-1" - - @fixture def successful_example_evaluation( evaluation_id: str, @@ -115,45 +108,16 @@ def dummy_aggregated_evaluation() -> DummyAggregatedEvaluation: return DummyAggregatedEvaluation(score=0.5) -@fixture -def run_overview() -> RunOverview: - return RunOverview( - dataset_id="dataset-id", - id="run-id-1", - start=utc_now(), - end=utc_now(), - failed_example_count=0, - successful_example_count=3, - description="test run overview 1", - ) - - -@fixture -def evaluation_overview( - evaluation_id: str, run_overview: RunOverview -) -> EvaluationOverview: - return EvaluationOverview( - id=evaluation_id, - start_date=utc_now(), - end_date=utc_now(), - successful_evaluation_count=1, - failed_evaluation_count=1, - run_overviews=frozenset([run_overview]), - description="test evaluation overview 1", - ) - - @fixture def aggregation_overview( evaluation_overview: EvaluationOverview, dummy_aggregated_evaluation: DummyAggregatedEvaluation, ) -> AggregationOverview[DummyAggregatedEvaluation]: - now = datetime.now() return AggregationOverview( evaluation_overviews=frozenset([evaluation_overview]), id="aggregation-id", - start=now, - end=now, + start=utc_now(), + end=utc_now(), successful_evaluation_count=5, crashed_during_evaluation_count=3, description="dummy-evaluator", diff --git a/tests/use_cases/classify/test_classify.py b/tests/use_cases/classify/test_classify.py index b54bd7334..80fddae2a 100644 --- a/tests/use_cases/classify/test_classify.py +++ b/tests/use_cases/classify/test_classify.py @@ -1,22 +1,29 @@ +from pathlib import Path from typing import Iterable, List, Sequence +from uuid import uuid4 from pytest import fixture from intelligence_layer.connectors import AlephAlphaClientProtocol -from intelligence_layer.core import Task, TextChunk +from intelligence_layer.core import Task, TextChunk, utc_now from intelligence_layer.evaluation import ( + AggregationOverview, Aggregator, DatasetRepository, + EvaluationOverview, Evaluator, Example, + FileAggregationRepository, InMemoryAggregationRepository, InMemoryDatasetRepository, InMemoryEvaluationRepository, + InMemoryRunRepository, Runner, - RunRepository, ) from intelligence_layer.examples import ( + AggregatedLabelInfo, AggregatedMultiLabelClassifyEvaluation, + AggregatedSingleLabelClassifyEvaluation, ClassifyInput, EmbeddingBasedClassify, LabelWithExamples, @@ -143,7 +150,7 @@ def multi_label_classify_aggregation_logic() -> MultiLabelClassifyAggregationLog @fixture def classify_evaluator( in_memory_dataset_repository: DatasetRepository, - in_memory_run_repository: RunRepository, + in_memory_run_repository: InMemoryRunRepository, in_memory_evaluation_repository: InMemoryEvaluationRepository, multi_label_classify_evaluation_logic: MultiLabelClassifyEvaluationLogic, ) -> Evaluator[ @@ -178,11 +185,28 @@ def classify_aggregator( ) +@fixture +def classify_aggregator_file_repo( + in_memory_evaluation_repository: InMemoryEvaluationRepository, + file_aggregation_repository: FileAggregationRepository, + multi_label_classify_aggregation_logic: MultiLabelClassifyAggregationLogic, +) -> Aggregator[ + MultiLabelClassifyEvaluation, + AggregatedMultiLabelClassifyEvaluation, +]: + return Aggregator( + in_memory_evaluation_repository, + file_aggregation_repository, + "multi-label-classify", + multi_label_classify_aggregation_logic, + ) + + @fixture def classify_runner( embedding_based_classify: Task[ClassifyInput, MultiLabelClassifyOutput], in_memory_dataset_repository: DatasetRepository, - in_memory_run_repository: RunRepository, + in_memory_run_repository: InMemoryRunRepository, ) -> Runner[ClassifyInput, MultiLabelClassifyOutput]: return Runner( embedding_based_classify, @@ -240,3 +264,47 @@ def test_multi_label_classify_evaluator_full_dataset( assert {"positive", "negative", "finance", "school"} == set( aggregation_overview.statistics.class_metrics.keys() ) + + +def test_confusion_matrix_in_single_label_classify_aggregation_is_compatible_with_file_repository( + evaluation_overview: EvaluationOverview, + tmp_path: Path, +) -> None: + aggregated_single_label_classify_evaluation = ( + AggregatedSingleLabelClassifyEvaluation( + percentage_correct=0.123, + confusion_matrix={ + "happy": {"happy": 1}, + "sad": {"sad": 2}, + "angry": {"sad": 1}, + }, + by_label={ + "happy": AggregatedLabelInfo(expected_count=1, predicted_count=1), + "sad": AggregatedLabelInfo(expected_count=3, predicted_count=2), + "angry": AggregatedLabelInfo(expected_count=0, predicted_count=1), + }, + missing_labels={"tired": 1}, + ) + ) + + aggregation_overview = AggregationOverview( + id=str(uuid4()), + evaluation_overviews=frozenset([evaluation_overview]), + start=utc_now(), + end=utc_now(), + successful_evaluation_count=5, + crashed_during_evaluation_count=3, + statistics=aggregated_single_label_classify_evaluation, + description="dummy-aggregator", + ) + + aggregation_file_repository = FileAggregationRepository(tmp_path) + aggregation_file_repository.store_aggregation_overview(aggregation_overview) + + aggregation_overview_from_file_repository = ( + aggregation_file_repository.aggregation_overview( + aggregation_overview.id, AggregatedSingleLabelClassifyEvaluation + ) + ) + + assert aggregation_overview_from_file_repository == aggregation_overview