From 19e70fd0389a4836d8cadd56568092f7b902b8d3 Mon Sep 17 00:00:00 2001 From: Ivo Schaper Date: Mon, 8 Apr 2024 15:21:06 +0200 Subject: [PATCH] fix types --- .../use_cases/search/search.py | 18 ++- tests/use_cases/search/test_search.py | 108 ++++++++---------- 2 files changed, 52 insertions(+), 74 deletions(-) diff --git a/src/intelligence_layer/use_cases/search/search.py b/src/intelligence_layer/use_cases/search/search.py index e8f6c1629..a0f7c59c1 100644 --- a/src/intelligence_layer/use_cases/search/search.py +++ b/src/intelligence_layer/use_cases/search/search.py @@ -8,11 +8,7 @@ SearchResult, ) from intelligence_layer.core import Task, TaskSpan -from intelligence_layer.evaluation import ( - EvaluationLogic, - Example, - SuccessfulExampleOutput, -) +from intelligence_layer.evaluation import Example, SingleOutputEvaluationLogic from intelligence_layer.evaluation.aggregation.aggregator import AggregationLogic @@ -90,17 +86,17 @@ class SearchEvaluation(BaseModel): class SearchEvaluationLogic( - EvaluationLogic[ + Generic[ID], + SingleOutputEvaluationLogic[ SearchInput, SearchOutput[ID], ExpectedSearchOutput, SearchEvaluation - ] + ], ): - def do_evaluate( + def do_evaluate_single_output( self, example: Example[SearchInput, ExpectedSearchOutput], - *output: SuccessfulExampleOutput[SearchOutput[ID]], + output: SearchOutput[ID], ) -> SearchEvaluation: - assert len(output) == 1 - results = output[0].output.results + results = output.results def overlaps(a: tuple[int, int], b: tuple[int, int]) -> bool: a_start, a_end = a diff --git a/tests/use_cases/search/test_search.py b/tests/use_cases/search/test_search.py index 2a02a5f8a..87169113b 100644 --- a/tests/use_cases/search/test_search.py +++ b/tests/use_cases/search/test_search.py @@ -10,7 +10,6 @@ ) from intelligence_layer.core import NoOpTracer from intelligence_layer.evaluation import Example -from intelligence_layer.evaluation.run.domain import SuccessfulExampleOutput from intelligence_layer.use_cases import ( ExpectedSearchOutput, Search, @@ -55,8 +54,8 @@ def example( @fixture -def search_eval_logic() -> SearchEvaluationLogic: - return SearchEvaluationLogic() +def search_eval_logic() -> SearchEvaluationLogic[str]: + return SearchEvaluationLogic[str]() def test_search( @@ -78,92 +77,75 @@ def test_search( def test_search_evaluation_logic_works_for_overlapping_output( example: Example[SearchInput, ExpectedSearchOutput], - search_eval_logic: SearchEvaluationLogic, + search_eval_logic: SearchEvaluationLogic[str], ) -> None: - output = SuccessfulExampleOutput( - run_id="1", - example_id="1", - output=SearchOutput( - results=[ - SearchResult[str]( - id="1", - score=0.5, - document_chunk=DocumentChunk(text="llo", start=2, end=5), - ) - ] - ), + output = SearchOutput( + results=[ + SearchResult( + id="1", + score=0.5, + document_chunk=DocumentChunk(text="llo", start=2, end=5), + ) + ] ) - eval = search_eval_logic.do_evaluate(example, output) + eval = search_eval_logic.do_evaluate_single_output(example, output) assert eval.rank == 1 - assert eval.similarity_score == output.output.results[0].score + assert eval.similarity_score == output.results[0].score def test_search_evaluation_logic_works_for_wholly_included_output( example: Example[SearchInput, ExpectedSearchOutput], - search_eval_logic: SearchEvaluationLogic, + search_eval_logic: SearchEvaluationLogic[str], ) -> None: - output = SuccessfulExampleOutput( - run_id="1", - example_id="1", - output=SearchOutput( - results=[ - SearchResult( - id="1", - score=0.5, - document_chunk=DocumentChunk(text="l", start=2, end=3), - ) - ] - ), + output = SearchOutput( + results=[ + SearchResult( + id="1", + score=0.5, + document_chunk=DocumentChunk(text="l", start=2, end=3), + ) + ] ) - eval = search_eval_logic.do_evaluate(example, *[output]) + eval = search_eval_logic.do_evaluate_single_output(example, output) assert eval.rank == 1 - assert eval.similarity_score == output.output.results[0].score + assert eval.similarity_score == output.results[0].score def test_search_evaluation_logic_works_for_identical_ranges( example: Example[SearchInput, ExpectedSearchOutput], - search_eval_logic: SearchEvaluationLogic, + search_eval_logic: SearchEvaluationLogic[str], ) -> None: - logic = SearchEvaluationLogic() - output = SuccessfulExampleOutput( - run_id="1", - example_id="1", - output=SearchOutput( - results=[ - SearchResult( - id="1", - score=0.5, - document_chunk=DocumentChunk(text="hallo", start=0, end=5), - ) - ] - ), + output = SearchOutput( + results=[ + SearchResult( + id="1", + score=0.5, + document_chunk=DocumentChunk(text="hallo", start=0, end=5), + ) + ] ) - eval = search_eval_logic.do_evaluate(example, *[output]) + eval = search_eval_logic.do_evaluate_single_output(example, output) assert eval.rank == 1 - assert eval.similarity_score == output.output.results[0].score + assert eval.similarity_score == output.results[0].score def test_search_evaluation_logic_works_for_non_overlapping_output( example: Example[SearchInput, ExpectedSearchOutput], - search_eval_logic: SearchEvaluationLogic, + search_eval_logic: SearchEvaluationLogic[str], ) -> None: - output = SuccessfulExampleOutput( - run_id="1", - example_id="1", - output=SearchOutput( - results=[ - SearchResult( - id="1", - score=0.5, - document_chunk=DocumentChunk(text=" test.", start=5, end=10), - ) - ] - ), + output = SearchOutput( + results=[ + SearchResult( + id="1", + score=0.5, + document_chunk=DocumentChunk(text=" test.", start=5, end=10), + ) + ] ) - eval = search_eval_logic.do_evaluate(example, *[output]) + eval = search_eval_logic.do_evaluate_single_output(example, output) assert not eval.rank assert not eval.similarity_score