diff --git a/src/intelligence_layer/use_cases/__init__.py b/src/intelligence_layer/use_cases/__init__.py index a97a5d630..cc8c06709 100644 --- a/src/intelligence_layer/use_cases/__init__.py +++ b/src/intelligence_layer/use_cases/__init__.py @@ -62,9 +62,9 @@ from .qa.single_chunk_qa import SingleChunkQa as SingleChunkQa from .qa.single_chunk_qa import SingleChunkQaInput as SingleChunkQaInput from .qa.single_chunk_qa import SingleChunkQaOutput as SingleChunkQaOutput -from .search.expand_chunk import ExpandChunkInput as ExpandChunkInput -from .search.expand_chunk import ExpandChunkOutput as ExpandChunkOutput -from .search.expand_chunk import ExpandChunks as ExpandChunks +from .search.expand_chunks import ExpandChunks as ExpandChunks +from .search.expand_chunks import ExpandChunksInput as ExpandChunksInput +from .search.expand_chunks import ExpandChunksOutput as ExpandChunksOutput from .search.search import AggregatedSearchEvaluation as AggregatedSearchEvaluation from .search.search import ChunkFound as ChunkFound from .search.search import ExpectedSearchOutput as ExpectedSearchOutput diff --git a/src/intelligence_layer/use_cases/qa/multiple_chunk_retriever_qa.py b/src/intelligence_layer/use_cases/qa/multiple_chunk_retriever_qa.py index 45bcfc183..6c591c7a8 100644 --- a/src/intelligence_layer/use_cases/qa/multiple_chunk_retriever_qa.py +++ b/src/intelligence_layer/use_cases/qa/multiple_chunk_retriever_qa.py @@ -1,3 +1,4 @@ +from collections import defaultdict from typing import Generic, Optional, Sequence from pydantic import BaseModel @@ -8,9 +9,14 @@ SearchResult, ) from intelligence_layer.core.chunk import TextChunk +from intelligence_layer.core.model import ControlModel, LuminousControlModel from intelligence_layer.core.task import Task from intelligence_layer.core.text_highlight import ScoredTextHighlight from intelligence_layer.core.tracer.tracer import TaskSpan +from intelligence_layer.use_cases.search.expand_chunks import ( + ExpandChunks, + ExpandChunksInput, +) from intelligence_layer.use_cases.search.search import Search, SearchInput from .retriever_based_qa import RetrieverBasedQaInput @@ -18,13 +24,15 @@ class AnswerSource(BaseModel, Generic[ID]): - search_result: SearchResult[ID] + document_id: ID + chunk: TextChunk highlights: Sequence[ScoredTextHighlight] class MultipleChunkRetrieverQaOutput(BaseModel, Generic[ID]): answer: Optional[str] sources: Sequence[AnswerSource[ID]] + search_results: Sequence[SearchResult[ID]] class MultipleChunkRetrieverQa( @@ -66,13 +74,17 @@ class MultipleChunkRetrieverQa( def __init__( self, retriever: BaseRetriever[ID], - k: int = 5, + model: ControlModel | None = None, + insert_chunk_number: int = 5, + insert_chunk_size: int = 256, single_chunk_qa: Task[SingleChunkQaInput, SingleChunkQaOutput] | None = None, ): super().__init__() + self._model = model or LuminousControlModel("luminous-supreme-control") self._search = Search(retriever) - self._k = k - self._single_chunk_qa = single_chunk_qa or SingleChunkQa() + self._expand_chunks = ExpandChunks(retriever, self._model, insert_chunk_size) + self._single_chunk_qa = single_chunk_qa or SingleChunkQa(self._model) + self._insert_chunk_number = insert_chunk_number @staticmethod def _combine_input_texts(chunks: Sequence[str]) -> tuple[TextChunk, Sequence[int]]: @@ -111,23 +123,46 @@ def _get_highlights_per_chunk( overlapping_ranges.append(current_overlaps) return overlapping_ranges + def _expand_search_result_chunks( + self, search_results: Sequence[SearchResult[ID]], task_span: TaskSpan + ) -> Sequence[tuple[ID, TextChunk]]: + grouped_results: dict[ID, list[SearchResult[ID]]] = defaultdict(list) + for result in search_results: + grouped_results[result.id].append(result) + + chunks_to_insert: list[tuple[ID, TextChunk]] = [] + for id, results in grouped_results.items(): + input = ExpandChunksInput( + document_id=id, chunks_found=[r.document_chunk for r in results] + ) + expand_chunks_output = self._expand_chunks.run(input, task_span) + for chunk in expand_chunks_output.chunks: + if len(chunks_to_insert) >= self._insert_chunk_number: + break + chunks_to_insert.append((id, chunk)) + + return chunks_to_insert + def do_run( self, input: RetrieverBasedQaInput, task_span: TaskSpan ) -> MultipleChunkRetrieverQaOutput[ID]: search_output = self._search.run( SearchInput(query=input.question), task_span ).results - sorted_search_output = sorted( - search_output, - key=lambda output: output.score, # not reversing on purpose because model performs better if relevant info is at the end - )[-self._k :] + sorted_search_results = sorted( + search_output, key=lambda output: output.score, reverse=True + ) + + chunks_to_insert = self._expand_search_result_chunks( + sorted_search_results, task_span + ) - chunk, chunk_start_indices = self._combine_input_texts( - [output.document_chunk.text for output in sorted_search_output] + chunk_for_prompt, chunk_start_indices = self._combine_input_texts( + [c[1] for c in chunks_to_insert] ) single_chunk_qa_input = SingleChunkQaInput( - chunk=chunk, + chunk=chunk_for_prompt, question=input.question, language=input.language, ) @@ -144,9 +179,13 @@ def do_run( answer=single_chunk_qa_output.answer, sources=[ AnswerSource( - search_result=chunk, + document_id=id_and_chunk[0], + chunk=id_and_chunk[1], highlights=highlights, ) - for chunk, highlights in zip(sorted_search_output, highlights_per_chunk) + for id_and_chunk, highlights in zip( + chunks_to_insert, highlights_per_chunk + ) ], + search_results=sorted_search_results, ) diff --git a/src/intelligence_layer/use_cases/search/expand_chunk.py b/src/intelligence_layer/use_cases/search/expand_chunks.py similarity index 88% rename from src/intelligence_layer/use_cases/search/expand_chunk.py rename to src/intelligence_layer/use_cases/search/expand_chunks.py index 4d24147ac..f13a2e784 100644 --- a/src/intelligence_layer/use_cases/search/expand_chunk.py +++ b/src/intelligence_layer/use_cases/search/expand_chunks.py @@ -10,16 +10,16 @@ from intelligence_layer.core.tracer.tracer import TaskSpan -class ExpandChunkInput(BaseModel, Generic[ID]): +class ExpandChunksInput(BaseModel, Generic[ID]): document_id: ID chunks_found: Sequence[DocumentChunk] -class ExpandChunkOutput(BaseModel): +class ExpandChunksOutput(BaseModel): chunks: Sequence[TextChunk] -class ExpandChunks(Generic[ID], Task[ExpandChunkInput[ID], ExpandChunkOutput]): +class ExpandChunks(Generic[ID], Task[ExpandChunksInput[ID], ExpandChunksOutput]): def __init__( self, retriever: BaseRetriever[ID], @@ -31,8 +31,8 @@ def __init__( self._chunk_with_indices = ChunkWithIndices(model, max_chunk_size) def do_run( - self, input: ExpandChunkInput[ID], task_span: TaskSpan - ) -> ExpandChunkOutput: + self, input: ExpandChunksInput[ID], task_span: TaskSpan + ) -> ExpandChunksOutput: full_doc = self._retriever.get_full_document(input.document_id) if not full_doc: raise RuntimeError(f"No document for id '{input.document_id}' found") @@ -46,7 +46,7 @@ def do_run( [(chunk.start, chunk.end) for chunk in input.chunks_found], ) - return ExpandChunkOutput( + return ExpandChunksOutput( chunks=[ chunk_with_indices[index].chunk for index in overlapping_chunk_indices ] diff --git a/tests/use_cases/qa/test_multiple_chunk_retriever_qa.py b/tests/use_cases/qa/test_multiple_chunk_retriever_qa.py index 45b285fbc..c905e33ff 100644 --- a/tests/use_cases/qa/test_multiple_chunk_retriever_qa.py +++ b/tests/use_cases/qa/test_multiple_chunk_retriever_qa.py @@ -1,8 +1,6 @@ from pytest import fixture -from intelligence_layer.connectors.retrievers.qdrant_in_memory_retriever import ( - QdrantInMemoryRetriever, -) +from intelligence_layer.connectors import QdrantInMemoryRetriever from intelligence_layer.core import NoOpTracer from intelligence_layer.use_cases import MultipleChunkRetrieverQa, RetrieverBasedQaInput diff --git a/tests/use_cases/search/test_expand_chunk.py b/tests/use_cases/search/test_expand_chunk.py index 474037a81..a6a9c5fc9 100644 --- a/tests/use_cases/search/test_expand_chunk.py +++ b/tests/use_cases/search/test_expand_chunk.py @@ -8,7 +8,7 @@ QdrantInMemoryRetriever, ) from intelligence_layer.core import LuminousControlModel, NoOpTracer -from intelligence_layer.use_cases import ExpandChunkInput, ExpandChunks +from intelligence_layer.use_cases import ExpandChunks, ExpandChunksInput @fixture @@ -46,8 +46,8 @@ def in_memory_retriever_documents() -> Sequence[Document]: def build_expand_chunk_input( document: Document, index_ranges: Sequence[tuple[int, int]] -) -> ExpandChunkInput[int]: - return ExpandChunkInput( +) -> ExpandChunksInput[int]: + return ExpandChunksInput( document_id=0, chunks_found=[ DocumentChunk( @@ -63,7 +63,7 @@ def build_expand_chunk_input( @fixture def wholly_included_expand_chunk_input( in_memory_retriever_documents: Sequence[Document], -) -> ExpandChunkInput[int]: +) -> ExpandChunksInput[int]: assert len(in_memory_retriever_documents) == 1 start_index, end_index = ( int(len(in_memory_retriever_documents[0].text) * 0.5), @@ -78,7 +78,7 @@ def wholly_included_expand_chunk_input( @fixture def overlapping_expand_chunk_input( in_memory_retriever_documents: Sequence[Document], -) -> ExpandChunkInput[int]: +) -> ExpandChunksInput[int]: assert len(in_memory_retriever_documents) == 1 start_index, end_index = ( int(len(in_memory_retriever_documents[0].text) * 0.2), @@ -93,7 +93,7 @@ def overlapping_expand_chunk_input( @fixture def multiple_chunks_expand_chunk_input( in_memory_retriever_documents: Sequence[Document], -) -> ExpandChunkInput[int]: +) -> ExpandChunksInput[int]: assert len(in_memory_retriever_documents) == 1 start_index_1, end_index_1 = ( int(len(in_memory_retriever_documents[0].text) * 0.3), @@ -113,7 +113,7 @@ def multiple_chunks_expand_chunk_input( def test_expand_chunk_works_for_wholly_included_chunk( asymmetric_in_memory_retriever: QdrantInMemoryRetriever, luminous_control_model: LuminousControlModel, - wholly_included_expand_chunk_input: ExpandChunkInput[int], + wholly_included_expand_chunk_input: ExpandChunksInput[int], no_op_tracer: NoOpTracer, ) -> None: expand_chunk_task = ExpandChunks( @@ -137,7 +137,7 @@ def test_expand_chunk_works_for_wholly_included_chunk( def test_expand_chunk_works_for_overlapping_chunk( asymmetric_in_memory_retriever: QdrantInMemoryRetriever, luminous_control_model: LuminousControlModel, - overlapping_expand_chunk_input: ExpandChunkInput[int], + overlapping_expand_chunk_input: ExpandChunksInput[int], no_op_tracer: NoOpTracer, ) -> None: expand_chunk_task = ExpandChunks( @@ -153,7 +153,7 @@ def test_expand_chunk_works_for_overlapping_chunk( def test_expand_chunk_works_for_multiple_chunks( asymmetric_in_memory_retriever: QdrantInMemoryRetriever, luminous_control_model: LuminousControlModel, - multiple_chunks_expand_chunk_input: ExpandChunkInput[int], + multiple_chunks_expand_chunk_input: ExpandChunksInput[int], no_op_tracer: NoOpTracer, ) -> None: expand_chunk_task = ExpandChunks(