From 25d736edf6e1e24306ed48315bcd8a9107f3a537 Mon Sep 17 00:00:00 2001
From: "niklas.finken"
Date: Tue, 23 Apr 2024 17:49:13 +0200
Subject: [PATCH] `MultipleChunkRetrieverQa` consumes chunks greedily

---
 .../qa/multiple_chunk_retriever_qa.py         | 24 ++++++++++---------
 1 file changed, 13 insertions(+), 11 deletions(-)

diff --git a/src/intelligence_layer/use_cases/qa/multiple_chunk_retriever_qa.py b/src/intelligence_layer/use_cases/qa/multiple_chunk_retriever_qa.py
index e6bfaa001..1dd61afaa 100644
--- a/src/intelligence_layer/use_cases/qa/multiple_chunk_retriever_qa.py
+++ b/src/intelligence_layer/use_cases/qa/multiple_chunk_retriever_qa.py
@@ -1,4 +1,5 @@
 from collections import defaultdict
+from copy import deepcopy
 from typing import Generic, Optional, Sequence
 
 from pydantic import BaseModel
@@ -117,27 +118,28 @@ def _get_highlights_per_chunk(
     def _expand_search_result_chunks(
         self, search_results: Sequence[SearchResult[ID]], task_span: TaskSpan
     ) -> Sequence[EnrichedChunk[ID]]:
-        grouped_results: dict[ID, list[SearchResult[ID]]] = defaultdict(list)
-        for result in search_results:
-            grouped_results[result.id].append(result)
         chunks_to_insert: list[EnrichedChunk[ID]] = []
-        for id, results in grouped_results.items():
+        for result in search_results:
             input = ExpandChunksInput(
-                document_id=id, chunks_found=[r.document_chunk for r in results]
+                document_id=result.id, chunks_found=[result.document_chunk]
             )
             expand_chunks_output = self._expand_chunks.run(input, task_span)
             for chunk in expand_chunks_output.chunks:
                 if len(chunks_to_insert) >= self._insert_chunk_number:
                     break
-                chunks_to_insert.append(
-                    EnrichedChunk(
-                        document_id=id,
-                        chunk=chunk.chunk,
-                        indices=(chunk.start_index, chunk.end_index),
-                    )
+
+                enriched_chunk = EnrichedChunk(
+                    document_id=result.id,
+                    chunk=chunk.chunk,
+                    indices=(chunk.start_index, chunk.end_index),
                 )
+                if enriched_chunk in chunks_to_insert:
+                    continue
+
+                chunks_to_insert.append(enriched_chunk)
+
         return chunks_to_insert
 
     def do_run(
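
Reviewer note (not part of the patch): the rewritten loop collects expanded chunks greedily in retrieval order, skips duplicates, and stops filling once `self._insert_chunk_number` chunks are gathered. Below is a minimal, self-contained Python sketch of that strategy. `ExpandedChunk`, `collect_chunks_greedily`, and the `expand` callable are simplified stand-ins assumed for illustration; they are not the intelligence_layer API (which uses `SearchResult`, `ExpandChunksInput`, `EnrichedChunk`, and the `ExpandChunks` task).

from dataclasses import dataclass
from typing import Callable, Sequence


@dataclass(frozen=True)
class ExpandedChunk:
    # Simplified stand-in for intelligence_layer's EnrichedChunk.
    document_id: str
    text: str
    indices: tuple[int, int]


def collect_chunks_greedily(
    ranked_results: Sequence[tuple[str, str]],  # (document_id, found chunk text) in retrieval order
    expand: Callable[[str, str], Sequence[ExpandedChunk]],  # stand-in for the ExpandChunks task
    insert_chunk_number: int,
) -> list[ExpandedChunk]:
    collected: list[ExpandedChunk] = []
    for document_id, found_chunk in ranked_results:
        # Expand each result on its own instead of grouping all results of a
        # document first, so better-ranked results claim the limited slots first.
        for chunk in expand(document_id, found_chunk):
            if len(collected) >= insert_chunk_number:
                break  # cap reached; only the inner loop stops, as in the patched method
            if chunk in collected:
                # Overlapping results from the same document can expand to the
                # same chunk; keep only the first occurrence.
                continue
            collected.append(chunk)
    return collected


if __name__ == "__main__":
    # Toy expansion: every found chunk expands to itself plus one neighbouring chunk.
    def expand(document_id: str, found: str) -> list[ExpandedChunk]:
        return [
            ExpandedChunk(document_id, found, (0, len(found))),
            ExpandedChunk(document_id, found + " (context)", (0, len(found) + 10)),
        ]

    results = [("doc-1", "first hit"), ("doc-1", "first hit"), ("doc-2", "second hit")]
    print(collect_chunks_greedily(results, expand, insert_chunk_number=3))

Because duplicates are skipped rather than counted, a slot freed by a duplicate expansion remains available to the next, lower-ranked result, mirroring the behaviour of the patched method.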