From 6b7fbe1bb6998ec1f8af5549ef63d09296df9af6 Mon Sep 17 00:00:00 2001 From: Niklas Finken <71568758+NickyHavoc@users.noreply.github.com> Date: Thu, 18 Apr 2024 10:21:15 +0200 Subject: [PATCH] fix: `ExpectedSearchOutput` and `SearchEvaluationLogic` include generic `ID` (#752) * Adjust `ExpectedSearchOutput` and logic to include generic `ID` * adjust CHANGELOG.md * fix spelling --- CHANGELOG.md | 4 +++- .../use_cases/search/search.py | 23 ++++++++++--------- tests/use_cases/search/test_search.py | 17 ++++++-------- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1b7e29ae5..fc55f82c8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,9 @@ - feature: `ExpandChunks` task takes a retriever and some search results to expand the chunks to the desired length ### Fixes -... +- fix: `ExpectedSearchOutput` has only relevant fields and supports generic document-`ID` rather than just str +- fix: `SearchEvaluationLogic` explicitly compares documents by ids + ## 0.9.0 diff --git a/src/intelligence_layer/use_cases/search/search.py b/src/intelligence_layer/use_cases/search/search.py index 6d5348b0d..1472cf06c 100644 --- a/src/intelligence_layer/use_cases/search/search.py +++ b/src/intelligence_layer/use_cases/search/search.py @@ -75,13 +75,10 @@ def do_run(self, input: SearchInput, task_span: TaskSpan) -> SearchOutput[ID]: return SearchOutput(results=results) -class ExpectedSearchOutput(BaseModel): - document_id: str +class ExpectedSearchOutput(BaseModel, Generic[ID]): + document_id: ID start_idx: int end_idx: int - origin_chunk: str - answer: str - task_label: str class SearchEvaluation(BaseModel): @@ -92,26 +89,30 @@ class SearchEvaluation(BaseModel): class SearchEvaluationLogic( Generic[ID], SingleOutputEvaluationLogic[ - SearchInput, SearchOutput[ID], ExpectedSearchOutput, SearchEvaluation + SearchInput, SearchOutput[ID], ExpectedSearchOutput[ID], SearchEvaluation ], ): def do_evaluate_single_output( self, - example: Example[SearchInput, ExpectedSearchOutput], + example: Example[SearchInput, ExpectedSearchOutput[ID]], output: SearchOutput[ID], ) -> SearchEvaluation: results = output.results - def overlaps(a: tuple[int, int], b: tuple[int, int]) -> bool: - a_start, a_end = a - b_start, b_end = b + def same_document(id_a: ID, id_b: ID) -> bool: + return id_a == id_b + + def chunks_overlap(range_a: tuple[int, int], range_b: tuple[int, int]) -> bool: + a_start, a_end = range_a + b_start, b_end = range_b return a_start < b_end and b_start < a_end rank, score = next( ( (index + 1, result.score) for index, result in enumerate(results) - if overlaps( + if same_document(result.id, example.expected_output.document_id) + and chunks_overlap( (result.document_chunk.start, result.document_chunk.end), ( example.expected_output.start_idx, diff --git a/tests/use_cases/search/test_search.py b/tests/use_cases/search/test_search.py index 470534c87..9e36637e8 100644 --- a/tests/use_cases/search/test_search.py +++ b/tests/use_cases/search/test_search.py @@ -38,21 +38,18 @@ def search(asymmetric_in_memory_retriever: QdrantInMemoryRetriever) -> Search[in @fixture -def expected_output() -> ExpectedSearchOutput: +def expected_output() -> ExpectedSearchOutput[str]: return ExpectedSearchOutput( document_id="1", start_idx=0, end_idx=5, - origin_chunk="hallo", - answer="", - task_label="", ) @fixture def example( - expected_output: ExpectedSearchOutput, -) -> Example[SearchInput, ExpectedSearchOutput]: + expected_output: ExpectedSearchOutput[str], +) -> Example[SearchInput, ExpectedSearchOutput[str]]: return Example(input=SearchInput(query=""), expected_output=expected_output) @@ -95,7 +92,7 @@ def test_search( def test_search_evaluation_logic_works_for_overlapping_output( - example: Example[SearchInput, ExpectedSearchOutput], + example: Example[SearchInput, ExpectedSearchOutput[str]], search_eval_logic: SearchEvaluationLogic[str], ) -> None: output = SearchOutput( @@ -114,7 +111,7 @@ def test_search_evaluation_logic_works_for_overlapping_output( def test_search_evaluation_logic_works_for_wholly_included_output( - example: Example[SearchInput, ExpectedSearchOutput], + example: Example[SearchInput, ExpectedSearchOutput[str]], search_eval_logic: SearchEvaluationLogic[str], ) -> None: output = SearchOutput( @@ -133,7 +130,7 @@ def test_search_evaluation_logic_works_for_wholly_included_output( def test_search_evaluation_logic_works_for_identical_ranges( - example: Example[SearchInput, ExpectedSearchOutput], + example: Example[SearchInput, ExpectedSearchOutput[str]], search_eval_logic: SearchEvaluationLogic[str], ) -> None: output = SearchOutput( @@ -152,7 +149,7 @@ def test_search_evaluation_logic_works_for_identical_ranges( def test_search_evaluation_logic_works_for_non_overlapping_output( - example: Example[SearchInput, ExpectedSearchOutput], + example: Example[SearchInput, ExpectedSearchOutput[str]], search_eval_logic: SearchEvaluationLogic[str], ) -> None: output = SearchOutput(