Skip to content

Commit

Permalink
add search aggregation logic
Browse files Browse the repository at this point in the history
  • Loading branch information
ivo-1 committed Apr 8, 2024
1 parent 5d8f6f9 commit 6a3d5d2
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 11 deletions.
1 change: 1 addition & 0 deletions src/intelligence_layer/use_cases/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from .qa.single_chunk_qa import SingleChunkQaInput as SingleChunkQaInput
from .qa.single_chunk_qa import SingleChunkQaOutput as SingleChunkQaOutput
from .search.search import AggregatedSearchEvaluation as AggregatedSearchEvaluation
from .search.search import ChunkFound as ChunkFound
from .search.search import ExpectedSearchOutput as ExpectedSearchOutput
from .search.search import Search as Search
from .search.search import SearchAggregationLogic as SearchAggregationLogic
Expand Down
18 changes: 7 additions & 11 deletions src/intelligence_layer/use_cases/search/search.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Generic, Iterable, Optional, Sequence
from typing import Generic, Iterable, Mapping, Optional, Sequence

from pydantic import BaseModel

Expand Down Expand Up @@ -125,11 +125,6 @@ def overlaps(a: tuple[int, int], b: tuple[int, int]) -> bool:
return SearchEvaluation(rank=rank, similarity_score=score)


class MeanTopK(BaseModel):
    """Pairs a top-k cutoff with the mean computed for that cutoff.

    NOTE(review): this commit removes the class in favor of a plain
    ``Mapping[int, float]`` on the aggregation result.
    """

    # the k cutoff this mean applies to
    top_k: int
    # mean value over all evaluations for this cutoff — presumably the
    # fraction found within top-k; confirm against the aggregation logic
    mean: float


class ChunkFound(BaseModel):
    """Counts of expected chunks that the retriever actually returned."""

    # found => chunk was within top-k results of retriever
    found_count: int
    # total number of chunks that were expected to be found
    expected_count: int
Expand All @@ -139,7 +134,7 @@ class ChunkFound(BaseModel):
class AggregatedSearchEvaluation(BaseModel):
mean_score: float
mean_reciprocal_rank: float
mean_top_ks: Sequence[MeanTopK]
mean_top_ks: Mapping[int, float]
chunk_found: ChunkFound


Expand All @@ -164,6 +159,8 @@ def aggregate(
chunk_found = True if evaluation.rank else False
chunk_found_accumulator.add(chunk_found)
if chunk_found:
assert evaluation.similarity_score and evaluation.rank

score_accumulator.add(evaluation.similarity_score)
reciprocal_rank_accumulator.add(1 / evaluation.rank)
for top_k in self.top_ks_to_evaluate:
Expand All @@ -174,10 +171,9 @@ def aggregate(
return AggregatedSearchEvaluation(
mean_score=score_accumulator.extract(),
mean_reciprocal_rank=reciprocal_rank_accumulator.extract(),
mean_top_ks=[
MeanTopK(top_k=top_k, mean=acc.extract())
for top_k, acc in top_k_accumulator.items()
],
mean_top_ks={
top_k: acc.extract() for top_k, acc in top_k_accumulator.items()
},
chunk_found=ChunkFound(
found_count=int(chunk_found_accumulator._acc),
expected_count=chunk_found_accumulator._n,
Expand Down
49 changes: 49 additions & 0 deletions tests/use_cases/search/test_search.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from statistics import mean
from typing import Sequence

from pytest import fixture
Expand All @@ -13,6 +14,8 @@
from intelligence_layer.use_cases import (
ExpectedSearchOutput,
Search,
SearchAggregationLogic,
SearchEvaluation,
SearchEvaluationLogic,
SearchInput,
SearchOutput,
Expand Down Expand Up @@ -58,6 +61,22 @@ def search_eval_logic() -> SearchEvaluationLogic[str]:
return SearchEvaluationLogic[str]()


@fixture
def search_evaluations() -> Sequence[SearchEvaluation]:
    """Five evaluations: three hits (ranks 1, 3, 10) and two misses."""
    hits = [(1, 0.7), (3, 0.6), (10, 0.5)]
    misses = [(None, None)] * 2
    return [
        SearchEvaluation(rank=rank, similarity_score=score)
        for rank, score in hits + misses
    ]


@fixture
def search_aggregation_logic() -> SearchAggregationLogic:
    """Aggregation logic configured to evaluate the top-1 and top-3 cutoffs."""
    cutoffs = [1, 3]
    return SearchAggregationLogic(top_ks_to_evaluate=cutoffs)


def test_search(
search: Search[int],
no_op_tracer: NoOpTracer,
Expand Down Expand Up @@ -149,3 +168,33 @@ def test_search_evaluation_logic_works_for_non_overlapping_output(

assert not eval.rank
assert not eval.similarity_score


def test_search_aggregation_logic_works(
    search_evaluations: Sequence[SearchEvaluation],
    search_aggregation_logic: SearchAggregationLogic,
) -> None:
    """Aggregating five evaluations (three hits, two misses) yields the
    expected mean score, mean reciprocal rank, top-k fractions, and
    chunk-found counts.
    """
    aggregation = search_aggregation_logic.aggregate(search_evaluations)

    # Filter with `is not None` rather than truthiness so a legitimate
    # 0.0 similarity score would not be silently dropped; also avoid
    # shadowing the builtin `eval` as the loop variable.
    scores = [
        evaluation.similarity_score
        for evaluation in search_evaluations
        if evaluation.similarity_score is not None
    ]
    reciprocal_ranks = [
        1 / evaluation.rank
        for evaluation in search_evaluations
        if evaluation.rank is not None
    ]

    assert aggregation.mean_score == mean(scores) == 0.6
    # round() both sides: the reciprocal ranks are non-terminating floats
    assert (
        round(aggregation.mean_reciprocal_rank, 5)
        == round(mean(reciprocal_ranks), 5)
        == round((1 + (1 / 3) + (1 / 10)) / 3, 5)
    )
    assert aggregation.mean_top_ks
    assert aggregation.chunk_found.found_count == 3
    assert aggregation.chunk_found.expected_count == len(search_evaluations) == 5
    assert aggregation.chunk_found.percentage == 3 / 5
    assert aggregation.mean_top_ks[1] == 1 / 3
    assert aggregation.mean_top_ks[3] == 2 / 3

0 comments on commit 6a3d5d2

Please sign in to comment.