diff --git a/CHANGELOG.md b/CHANGELOG.md index 394959d95..38f95b045 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ ### Fixes - Improve description of using artifactory tokens for installation of IL + - Change `confusion_matrix` in `SingleLabelClassifyAggregationLogic` such that it can be persisted in a file repository ### Deprecations ... diff --git a/src/intelligence_layer/examples/classify/classify.py b/src/intelligence_layer/examples/classify/classify.py index d3dcb89e6..611e032f4 100644 --- a/src/intelligence_layer/examples/classify/classify.py +++ b/src/intelligence_layer/examples/classify/classify.py @@ -90,9 +90,9 @@ class AggregatedSingleLabelClassifyEvaluation(BaseModel): """ percentage_correct: float - confusion_matrix: Mapping[tuple[str, str], int] - by_label: Mapping[str, AggregatedLabelInfo] - missing_labels: Mapping[str, int] + confusion_matrix: dict[str, dict[str, int]] + by_label: dict[str, AggregatedLabelInfo] + missing_labels: dict[str, int] class SingleLabelClassifyAggregationLogic( @@ -105,14 +105,16 @@ def aggregate( ) -> AggregatedSingleLabelClassifyEvaluation: acc = MeanAccumulator() missing_labels: dict[str, int] = defaultdict(int) - confusion_matrix: dict[tuple[str, str], int] = defaultdict(int) + confusion_matrix: dict[str, dict[str, int]] = defaultdict( + lambda: defaultdict(int) + ) by_label: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int)) for evaluation in evaluations: acc.add(1.0 if evaluation.correct else 0.0) if evaluation.expected_label_missing: missing_labels[evaluation.expected] += 1 else: - confusion_matrix[(evaluation.predicted, evaluation.expected)] += 1 + confusion_matrix[evaluation.predicted][evaluation.expected] += 1 by_label[evaluation.predicted]["predicted"] += 1 by_label[evaluation.expected]["expected"] += 1 diff --git a/tests/use_cases/classify/test_classify.py b/tests/use_cases/classify/test_classify.py index b54bd7334..f3f7015c7 100644 --- a/tests/use_cases/classify/test_classify.py +++ b/tests/use_cases/classify/test_classify.py @@ -1,3 +1,4 @@ +from pathlib import Path from typing import Iterable, List, Sequence from pytest import fixture @@ -9,11 +10,15 @@ DatasetRepository, Evaluator, Example, + FileAggregationRepository, + FileDatasetRepository, + FileEvaluationRepository, + FileRunRepository, InMemoryAggregationRepository, InMemoryDatasetRepository, InMemoryEvaluationRepository, + InMemoryRunRepository, Runner, - RunRepository, ) from intelligence_layer.examples import ( AggregatedMultiLabelClassifyEvaluation, @@ -25,6 +30,14 @@ MultiLabelClassifyEvaluationLogic, MultiLabelClassifyOutput, ) +from intelligence_layer.examples.classify.classify import ( + AggregatedSingleLabelClassifyEvaluation, + SingleLabelClassifyAggregationLogic, + SingleLabelClassifyEvaluationLogic, +) +from intelligence_layer.examples.classify.prompt_based_classify import ( + PromptBasedClassify, +) @fixture @@ -143,7 +156,7 @@ def multi_label_classify_aggregation_logic() -> MultiLabelClassifyAggregationLog @fixture def classify_evaluator( in_memory_dataset_repository: DatasetRepository, - in_memory_run_repository: RunRepository, + in_memory_run_repository: InMemoryRunRepository, in_memory_evaluation_repository: InMemoryEvaluationRepository, multi_label_classify_evaluation_logic: MultiLabelClassifyEvaluationLogic, ) -> Evaluator[ @@ -178,11 +191,28 @@ def classify_aggregator( ) +@fixture +def classify_aggregator_file_repo( + in_memory_evaluation_repository: InMemoryEvaluationRepository, + file_aggregation_repository: FileAggregationRepository, + multi_label_classify_aggregation_logic: MultiLabelClassifyAggregationLogic, +) -> Aggregator[ + MultiLabelClassifyEvaluation, + AggregatedMultiLabelClassifyEvaluation, +]: + return Aggregator( + in_memory_evaluation_repository, + file_aggregation_repository, + "multi-label-classify", + multi_label_classify_aggregation_logic, + ) + + @fixture def classify_runner( embedding_based_classify: Task[ClassifyInput, MultiLabelClassifyOutput], in_memory_dataset_repository: DatasetRepository, - in_memory_run_repository: RunRepository, + in_memory_run_repository: InMemoryRunRepository, ) -> Runner[ClassifyInput, MultiLabelClassifyOutput]: return Runner( embedding_based_classify, @@ -240,3 +270,74 @@ def test_multi_label_classify_evaluator_full_dataset( assert {"positive", "negative", "finance", "school"} == set( aggregation_overview.statistics.class_metrics.keys() ) + + +def test_single_label_classify_with_file_repository( + tmp_path: Path, +) -> None: + in_memory_dataset_repository = FileDatasetRepository(tmp_path) + examples = [ + Example( + input=ClassifyInput( + chunk=TextChunk("I am happy."), + labels=frozenset(["happy", "sad", "angry"]), + ), + expected_output="happy", + ), + Example( + input=ClassifyInput( + chunk=TextChunk("I am sad."), + labels=frozenset(["happy", "sad", "angry"]), + ), + expected_output="sad", + ), + Example( + input=ClassifyInput( + chunk=TextChunk("I am angry."), + labels=frozenset(["happy", "sad", "angry"]), + ), + expected_output="angry", + ), + ] + dataset_id = in_memory_dataset_repository.create_dataset( + examples=examples, dataset_name="single-label-classify" + ).id + + in_memory_run_repository = FileRunRepository(tmp_path) + classify_runner = Runner( + PromptBasedClassify(), + in_memory_dataset_repository, + in_memory_run_repository, + "single-label-classify", + ) + run_overview = classify_runner.run_dataset(dataset_id) + + in_memory_evaluation_repository = FileEvaluationRepository(tmp_path) + classify_evaluator = Evaluator( + in_memory_dataset_repository, + in_memory_run_repository, + in_memory_evaluation_repository, + "single-label-classify", + SingleLabelClassifyEvaluationLogic(), + ) + evaluation_overview = classify_evaluator.evaluate_runs(run_overview.id) + + aggregation_file_repository = FileAggregationRepository(tmp_path) + classify_aggregator_file_repo = Aggregator( + in_memory_evaluation_repository, + aggregation_file_repository, + "single-label-classify", + SingleLabelClassifyAggregationLogic(), + ) + + aggregation_overview = classify_aggregator_file_repo.aggregate_evaluation( + evaluation_overview.id + ) + + aggregation_overview_from_file_repository = ( + aggregation_file_repository.aggregation_overview( + aggregation_overview.id, AggregatedSingleLabelClassifyEvaluation + ) + ) + + assert aggregation_overview_from_file_repository == aggregation_overview