diff --git a/src/intelligence_layer/connectors/retrievers/in_memory_retriever.py b/src/intelligence_layer/connectors/retrievers/in_memory_retriever.py index 94166b820..8a18d9807 100644 --- a/src/intelligence_layer/connectors/retrievers/in_memory_retriever.py +++ b/src/intelligence_layer/connectors/retrievers/in_memory_retriever.py @@ -126,14 +126,14 @@ def _add_texts_to_memory(self, documents: Sequence[Document]) -> None: ) def get_filtered_documents_with_scores( - self, query: str, limit: int, filter: models.Filter + self, query: str, filter: models.Filter ) -> Sequence[SearchResult]: """Specific method for `InMemoryRetriever` to support filtering search results.""" query_embedding = self._embed(query, self._query_representation) search_result = self._search_client.search( collection_name=self._collection_name, query_vector=query_embedding, - limit=limit, + limit=self._k, query_filter=filter, ) return [self._point_to_search_result(point) for point in search_result] diff --git a/src/intelligence_layer/use_cases/classify/embedding_based_classify.py b/src/intelligence_layer/use_cases/classify/embedding_based_classify.py index 63fa56731..7116dec16 100644 --- a/src/intelligence_layer/use_cases/classify/embedding_based_classify.py +++ b/src/intelligence_layer/use_cases/classify/embedding_based_classify.py @@ -12,7 +12,7 @@ RetrieverType, ) from intelligence_layer.core.logger import DebugLogger -from intelligence_layer.core.task import Chunk, Probability, Task +from intelligence_layer.core.task import Chunk, Probability from intelligence_layer.use_cases.classify.classify import ( Classify, ClassifyInput, @@ -111,7 +111,7 @@ def __init__( retriever = InMemoryRetriever( client, documents=documents, - k=len(documents), + k=scoring.value, retriever_type=RetrieverType.SYMMETRIC, ) self._filter_search = FilterSearch(retriever) @@ -149,7 +149,6 @@ def _label_search( ) -> SearchOutput: search_input = FilterSearchInput( query=chunk, - limit=self._scoring.value, filter=models.Filter( must=[ models.FieldCondition( diff --git a/src/intelligence_layer/use_cases/search/filter_search.py b/src/intelligence_layer/use_cases/search/filter_search.py index 469305632..0dbcf7e7b 100644 --- a/src/intelligence_layer/use_cases/search/filter_search.py +++ b/src/intelligence_layer/use_cases/search/filter_search.py @@ -14,12 +14,10 @@ class FilterSearchInput(BaseModel): Attributes: query: The text to be searched with. - limit: The maximum number of items to be retrieved. filter: Conditions to filter by as offered by Qdrant. """ query: str - limit: int filter: Filter @@ -44,7 +42,6 @@ class FilterSearch(Task[FilterSearchInput, SearchOutput]): >>> task = FilterSearch(retriever) >>> input = FilterSearchInput( >>> query="When did East and West Germany reunite?" - >>> limit=1, >>> filter=models.Filter( >>> must=[ >>> models.FieldCondition( @@ -64,6 +61,6 @@ def __init__(self, in_memory_retriever: InMemoryRetriever): def run(self, input: FilterSearchInput, logger: DebugLogger) -> SearchOutput: results = self._in_memory_retriever.get_filtered_documents_with_scores( - input.query, input.limit, input.filter + input.query, input.filter ) return SearchOutput(results=results) diff --git a/tests/use_cases/classify/test_embedding_based_classify.py b/tests/use_cases/classify/test_embedding_based_classify.py index 8714859ab..f68b9768b 100644 --- a/tests/use_cases/classify/test_embedding_based_classify.py +++ b/tests/use_cases/classify/test_embedding_based_classify.py @@ -63,6 +63,39 @@ def test_embedding_based_classify_raises_for_unknown_label( embedding_based_classify.run(classify_input, NoOpDebugLogger()) +def test_embedding_based_classify_works_for_empty_labels_in_request( + embedding_based_classify: EmbeddingBasedClassify, +) -> None: + classify_input = ClassifyInput( + chunk=Chunk("This is good"), + labels=frozenset(), + ) + result = embedding_based_classify.run(classify_input, NoOpDebugLogger()) + assert result.scores == {} + + +def test_embedding_based_classify_works_without_examples( + client: Client, +) -> None: + labels_with_examples = [ + LabelWithExamples( + name="positive", + examples=[], + ), + LabelWithExamples( + name="negative", + examples=[], + ), + ] + embedding_based_classify = EmbeddingBasedClassify(labels_with_examples, client) + classify_input = ClassifyInput( + chunk=Chunk("This is good"), + labels=frozenset(), + ) + result = embedding_based_classify.run(classify_input, NoOpDebugLogger()) + assert result.scores == {} + + def test_can_evaluate_embedding_based_classify( embedding_based_classify: EmbeddingBasedClassify, ) -> None: diff --git a/tests/use_cases/search/test_filter_search.py b/tests/use_cases/search/test_filter_search.py index 7f8e09a30..5959fd2ff 100644 --- a/tests/use_cases/search/test_filter_search.py +++ b/tests/use_cases/search/test_filter_search.py @@ -22,8 +22,8 @@ def in_memory_retriever_documents() -> Sequence[Document]: metadata={"type": "doc"}, ), Document( - text="Cats are small animals. Well, I do not fit at all but I am of the correct type.", - metadata={"type": "doc"}, + text="Cats are small animals. Well, I do not fit at all and I am of the correct type.", + metadata={"type": "no doc"}, ), Document( text="Germany reunited in 1990. This document fits perfectly but it is of the wrong type.", @@ -44,7 +44,6 @@ def test_filter_search( ) -> None: search_input = FilterSearchInput( query="When did Germany reunite?", - limit=1, filter=models.Filter( must=[ models.FieldCondition(