fix: Change confusion_matrix in `SingleLabelClassifyAggregationLogic` such that it can be persisted on disc

TASK: IL-475
FlorianSchepersAA committed May 13, 2024
1 parent 0733d2f commit 84b4e6b
Showing 3 changed files with 120 additions and 8 deletions.
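
Context for the change: the old `confusion_matrix` field used tuple keys (`Mapping[tuple[str, str], int]`). JSON object keys must be strings, so a tuple-keyed mapping cannot be dumped to plain JSON the way a nested `dict[str, dict[str, int]]` can, which is presumably why the old shape could not be persisted by the file repositories. A minimal standalone sketch (plain `json` for illustration, values made up, not part of the commit):

import json

# Old shape: (predicted, expected) tuple as the key.
tuple_keyed = {("happy", "sad"): 2}
try:
    json.dumps(tuple_keyed)
except TypeError as error:
    print(error)  # keys must be str, int, float, bool or None, not tuple

# New shape: predicted -> expected -> count, round-trips losslessly.
nested = {"happy": {"sad": 2}}
assert json.loads(json.dumps(nested)) == nested
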
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -10,6 +10,7 @@

### Fixes
- Improve description of using artifactory tokens for installation of IL
- Change `confusion_matrix` in `SingleLabelClassifyAggregationLogic` such that it can be persisted in a file repository

### Deprecations
...
12 changes: 7 additions & 5 deletions src/intelligence_layer/examples/classify/classify.py
@@ -90,9 +90,9 @@ class AggregatedSingleLabelClassifyEvaluation(BaseModel):
"""

percentage_correct: float
-confusion_matrix: Mapping[tuple[str, str], int]
-by_label: Mapping[str, AggregatedLabelInfo]
-missing_labels: Mapping[str, int]
+confusion_matrix: dict[str, dict[str, int]]
+by_label: dict[str, AggregatedLabelInfo]
+missing_labels: dict[str, int]


class SingleLabelClassifyAggregationLogic(
@@ -105,14 +105,16 @@ def aggregate(
) -> AggregatedSingleLabelClassifyEvaluation:
acc = MeanAccumulator()
missing_labels: dict[str, int] = defaultdict(int)
-confusion_matrix: dict[tuple[str, str], int] = defaultdict(int)
+confusion_matrix: dict[str, dict[str, int]] = defaultdict(
+    lambda: defaultdict(int)
+)
by_label: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
for evaluation in evaluations:
acc.add(1.0 if evaluation.correct else 0.0)
if evaluation.expected_label_missing:
missing_labels[evaluation.expected] += 1
else:
-confusion_matrix[(evaluation.predicted, evaluation.expected)] += 1
+confusion_matrix[evaluation.predicted][evaluation.expected] += 1
by_label[evaluation.predicted]["predicted"] += 1
by_label[evaluation.expected]["expected"] += 1

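
The `aggregate` method above counts into nested `defaultdict`s keyed by predicted and then expected label. A small standalone sketch of that pattern with made-up (predicted, expected) pairs, not taken from the commit:

from collections import defaultdict

# predicted label -> expected label -> count, mirroring the new confusion_matrix shape
confusion_matrix: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))

results = [("happy", "happy"), ("happy", "sad"), ("sad", "sad")]  # hypothetical pairs
for predicted, expected in results:
    confusion_matrix[predicted][expected] += 1

# Plain-dict view; prints {'happy': {'happy': 1, 'sad': 1}, 'sad': {'sad': 1}}
print({predicted: dict(counts) for predicted, counts in confusion_matrix.items()})

In the commit the `defaultdict` is handed straight to the pydantic model, which should validate it into plain dicts; the comprehension above is only there to make the printed output readable.
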
115 changes: 112 additions & 3 deletions tests/use_cases/classify/test_classify.py
@@ -1,3 +1,4 @@
from pathlib import Path
from typing import Iterable, List, Sequence

from pytest import fixture
@@ -13,7 +14,19 @@
InMemoryDatasetRepository,
InMemoryEvaluationRepository,
Runner,
RunRepository,
)
from intelligence_layer.evaluation.aggregation.file_aggregation_repository import (
FileAggregationRepository,
)
from intelligence_layer.evaluation.dataset.file_dataset_repository import (
FileDatasetRepository,
)
from intelligence_layer.evaluation.evaluation.file_evaluation_repository import (
FileEvaluationRepository,
)
from intelligence_layer.evaluation.run.file_run_repository import FileRunRepository
from intelligence_layer.evaluation.run.in_memory_run_repository import (
InMemoryRunRepository,
)
from intelligence_layer.examples import (
AggregatedMultiLabelClassifyEvaluation,
@@ -25,6 +38,14 @@
MultiLabelClassifyEvaluationLogic,
MultiLabelClassifyOutput,
)
from intelligence_layer.examples.classify.classify import (
AggregatedSingleLabelClassifyEvaluation,
SingleLabelClassifyAggregationLogic,
SingleLabelClassifyEvaluationLogic,
)
from intelligence_layer.examples.classify.prompt_based_classify import (
PromptBasedClassify,
)


@fixture
@@ -143,7 +164,7 @@ def multi_label_classify_aggregation_logic() -> MultiLabelClassifyAggregationLog
@fixture
def classify_evaluator(
in_memory_dataset_repository: DatasetRepository,
-in_memory_run_repository: RunRepository,
+in_memory_run_repository: InMemoryRunRepository,
in_memory_evaluation_repository: InMemoryEvaluationRepository,
multi_label_classify_evaluation_logic: MultiLabelClassifyEvaluationLogic,
) -> Evaluator[
@@ -178,11 +199,28 @@ def classify_aggregator(
)


@fixture
def classify_aggregator_file_repo(
in_memory_evaluation_repository: InMemoryEvaluationRepository,
file_aggregation_repository: FileAggregationRepository,
multi_label_classify_aggregation_logic: MultiLabelClassifyAggregationLogic,
) -> Aggregator[
MultiLabelClassifyEvaluation,
AggregatedMultiLabelClassifyEvaluation,
]:
return Aggregator(
in_memory_evaluation_repository,
file_aggregation_repository,
"multi-label-classify",
multi_label_classify_aggregation_logic,
)


@fixture
def classify_runner(
embedding_based_classify: Task[ClassifyInput, MultiLabelClassifyOutput],
in_memory_dataset_repository: DatasetRepository,
-in_memory_run_repository: RunRepository,
+in_memory_run_repository: InMemoryRunRepository,
) -> Runner[ClassifyInput, MultiLabelClassifyOutput]:
return Runner(
embedding_based_classify,
@@ -240,3 +278,74 @@ def test_multi_label_classify_evaluator_full_dataset(
assert {"positive", "negative", "finance", "school"} == set(
aggregation_overview.statistics.class_metrics.keys()
)


def test_single_label_classify_with_file_repository(
tmp_path: Path,
) -> None:
in_memory_dataset_repository = FileDatasetRepository(tmp_path)
examples = [
Example(
input=ClassifyInput(
chunk=TextChunk("I am happy."),
labels=frozenset(["happy", "sad", "angry"]),
),
expected_output="happy",
),
Example(
input=ClassifyInput(
chunk=TextChunk("I am sad."),
labels=frozenset(["happy", "sad", "angry"]),
),
expected_output="sad",
),
Example(
input=ClassifyInput(
chunk=TextChunk("I am angry."),
labels=frozenset(["happy", "sad", "angry"]),
),
expected_output="angry",
),
]
dataset_id = in_memory_dataset_repository.create_dataset(
examples=examples, dataset_name="single-label-classify"
).id

in_memory_run_repository = FileRunRepository(tmp_path)
classify_runner = Runner(
PromptBasedClassify(),
in_memory_dataset_repository,
in_memory_run_repository,
"single-label-classify",
)
run_overview = classify_runner.run_dataset(dataset_id)

in_memory_evaluation_repository = FileEvaluationRepository(tmp_path)
classify_evaluator = Evaluator(
in_memory_dataset_repository,
in_memory_run_repository,
in_memory_evaluation_repository,
"single-label-classify",
SingleLabelClassifyEvaluationLogic(),
)
evaluation_overview = classify_evaluator.evaluate_runs(run_overview.id)

aggregation_file_repository = FileAggregationRepository(tmp_path)
classify_aggregator_file_repo = Aggregator(
in_memory_evaluation_repository,
aggregation_file_repository,
"single-label-classify",
SingleLabelClassifyAggregationLogic(),
)

aggregation_overview = classify_aggregator_file_repo.aggregate_evaluation(
evaluation_overview.id
)

aggregation_overview_from_file_repository = (
aggregation_file_repository.aggregation_overview(
aggregation_overview.id, AggregatedSingleLabelClassifyEvaluation
)
)

assert aggregation_overview_from_file_repository == aggregation_overview
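
The final assertion only holds if the aggregation result survives the write/read round trip through `FileAggregationRepository`, i.e. if `confusion_matrix` serializes and deserializes losslessly. A condensed model-level sketch of that round trip, assuming pydantic v2 APIs and that the four fields shown in the diff above are the model's only required fields:

from intelligence_layer.examples.classify.classify import (
    AggregatedSingleLabelClassifyEvaluation,
)

# Hypothetical aggregation result; empty by_label/missing_labels keep the sketch short.
aggregated = AggregatedSingleLabelClassifyEvaluation(
    percentage_correct=1.0,
    confusion_matrix={"happy": {"happy": 2}},
    by_label={},
    missing_labels={},
)

# With nested string-keyed dicts this JSON round trip is lossless; with the old
# tuple-keyed mapping it was not, which is what this commit addresses.
restored = AggregatedSingleLabelClassifyEvaluation.model_validate_json(
    aggregated.model_dump_json()
)
assert restored == aggregated
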

