diff --git a/CHANGELOG.md b/CHANGELOG.md index 0c266b804..b6c3e3e0f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,10 +3,14 @@ ## Unreleased ### Breaking Changes +- breaking change: `ExpandChunksOutput` now returns `ChunkWithStartEndIndices` instead of `TextChunk` +- breaking change: `MultipleChunkRetrieverQa`'s `AnswerSource` now contains `EnrichedChunk` instead of just the `TextChunk` + ### New Features ### Fixes +- fix: `ChunkWithIndices` now additionally returns end_index ## 0.9.1 diff --git a/src/intelligence_layer/core/__init__.py b/src/intelligence_layer/core/__init__.py index 3966e8828..307031287 100644 --- a/src/intelligence_layer/core/__init__.py +++ b/src/intelligence_layer/core/__init__.py @@ -3,7 +3,7 @@ from .chunk import ChunkOutput as ChunkOutput from .chunk import ChunkWithIndices as ChunkWithIndices from .chunk import ChunkWithIndicesOutput as ChunkWithIndicesOutput -from .chunk import ChunkWithStartIndex as ChunkWithStartIndex +from .chunk import ChunkWithStartEndIndices as ChunkWithStartEndIndices from .chunk import TextChunk as TextChunk from .detect_language import DetectLanguage as DetectLanguage from .detect_language import DetectLanguageInput as DetectLanguageInput diff --git a/src/intelligence_layer/core/chunk.py b/src/intelligence_layer/core/chunk.py index b71194a57..fca702d58 100644 --- a/src/intelligence_layer/core/chunk.py +++ b/src/intelligence_layer/core/chunk.py @@ -62,16 +62,18 @@ def do_run(self, input: ChunkInput, task_span: TaskSpan) -> ChunkOutput: return ChunkOutput(chunks=chunks) -class ChunkWithStartIndex(BaseModel): - """A `TextChunk` and its `start_index` relative to its parent document. +class ChunkWithStartEndIndices(BaseModel): + """A `TextChunk` and its `start_index` and `end_index` within the given text. Attributes: chunk: The actual text. - start_index: The character start index of the chunk within the respective document. + start_index: The character start index of the chunk within the given text. + end_index: The character end index of the chunk within the given text. """ chunk: TextChunk start_index: int + end_index: int class ChunkWithIndicesOutput(BaseModel): @@ -81,7 +83,7 @@ class ChunkWithIndicesOutput(BaseModel): chunks_with_indices: A list of smaller sections of the input text with the respective start_index. """ - chunks_with_indices: Sequence[ChunkWithStartIndex] + chunks_with_indices: Sequence[ChunkWithStartEndIndices] class ChunkWithIndices(Task[ChunkInput, ChunkWithIndicesOutput]): @@ -98,13 +100,19 @@ class ChunkWithIndices(Task[ChunkInput, ChunkWithIndicesOutput]): def __init__(self, model: AlephAlphaModel, max_tokens_per_chunk: int = 512): super().__init__() - self._splitter = TextSplitter.from_huggingface_tokenizer(model.get_tokenizer()) + self._splitter = TextSplitter.from_huggingface_tokenizer( + model.get_tokenizer(), trim_chunks=False + ) self._max_tokens_per_chunk = max_tokens_per_chunk def do_run(self, input: ChunkInput, task_span: TaskSpan) -> ChunkWithIndicesOutput: chunks_with_indices = [ - ChunkWithStartIndex(chunk=TextChunk(t[1]), start_index=t[0]) - for t in self._splitter.chunk_indices( + ChunkWithStartEndIndices( + chunk=TextChunk(chunk), + start_index=start_index, + end_index=start_index + len(chunk), + ) + for (start_index, chunk) in self._splitter.chunk_indices( input.text, self._max_tokens_per_chunk ) ] diff --git a/src/intelligence_layer/use_cases/qa/multiple_chunk_retriever_qa.py b/src/intelligence_layer/use_cases/qa/multiple_chunk_retriever_qa.py index 229af5f3a..e6bfaa001 100644 --- a/src/intelligence_layer/use_cases/qa/multiple_chunk_retriever_qa.py +++ b/src/intelligence_layer/use_cases/qa/multiple_chunk_retriever_qa.py @@ -23,9 +23,14 @@ from .single_chunk_qa import SingleChunkQa, SingleChunkQaInput, SingleChunkQaOutput -class AnswerSource(BaseModel, Generic[ID]): +class EnrichedChunk(BaseModel, Generic[ID]): document_id: ID chunk: TextChunk + indices: tuple[int, int] + + +class AnswerSource(BaseModel, Generic[ID]): + chunk: EnrichedChunk[ID] highlights: Sequence[ScoredTextHighlight] @@ -76,7 +81,7 @@ def _combine_input_texts(chunks: Sequence[str]) -> tuple[TextChunk, Sequence[int combined_text = "" for chunk in chunks: start_indices.append(len(combined_text)) - combined_text += chunk + "\n\n" + combined_text += chunk.strip() + "\n\n" return (TextChunk(combined_text.strip()), start_indices) @staticmethod @@ -97,9 +102,11 @@ def _get_highlights_per_chunk( if highlight.start < next_start and highlight.end > current_start: highlights_with_indices_fixed = ScoredTextHighlight( start=max(0, highlight.start - current_start), - end=highlight.end - current_start - if isinstance(next_start, float) - else min(next_start, highlight.end - current_start), + end=( + highlight.end - current_start + if isinstance(next_start, float) + else min(next_start, highlight.end - current_start) + ), score=highlight.score, ) current_overlaps.append(highlights_with_indices_fixed) @@ -109,12 +116,12 @@ def _get_highlights_per_chunk( def _expand_search_result_chunks( self, search_results: Sequence[SearchResult[ID]], task_span: TaskSpan - ) -> Sequence[tuple[ID, TextChunk]]: + ) -> Sequence[EnrichedChunk[ID]]: grouped_results: dict[ID, list[SearchResult[ID]]] = defaultdict(list) for result in search_results: grouped_results[result.id].append(result) - chunks_to_insert: list[tuple[ID, TextChunk]] = [] + chunks_to_insert: list[EnrichedChunk[ID]] = [] for id, results in grouped_results.items(): input = ExpandChunksInput( document_id=id, chunks_found=[r.document_chunk for r in results] @@ -123,7 +130,13 @@ def _expand_search_result_chunks( for chunk in expand_chunks_output.chunks: if len(chunks_to_insert) >= self._insert_chunk_number: break - chunks_to_insert.append((id, chunk)) + chunks_to_insert.append( + EnrichedChunk( + document_id=id, + chunk=chunk.chunk, + indices=(chunk.start_index, chunk.end_index), + ) + ) return chunks_to_insert @@ -142,7 +155,7 @@ def do_run( ) chunk_for_prompt, chunk_start_indices = self._combine_input_texts( - [c[1] for c in chunks_to_insert] + [c.chunk for c in chunks_to_insert] ) single_chunk_qa_input = SingleChunkQaInput( @@ -163,11 +176,10 @@ def do_run( answer=single_chunk_qa_output.answer, sources=[ AnswerSource( - document_id=id_and_chunk[0], - chunk=id_and_chunk[1], + chunk=enriched_chunk, highlights=highlights, ) - for id_and_chunk, highlights in zip( + for enriched_chunk, highlights in zip( chunks_to_insert, highlights_per_chunk ) ], diff --git a/src/intelligence_layer/use_cases/search/expand_chunks.py b/src/intelligence_layer/use_cases/search/expand_chunks.py index edbace1ef..5425751a7 100644 --- a/src/intelligence_layer/use_cases/search/expand_chunks.py +++ b/src/intelligence_layer/use_cases/search/expand_chunks.py @@ -4,7 +4,11 @@ from intelligence_layer.connectors import BaseRetriever, DocumentChunk from intelligence_layer.connectors.retrievers.base_retriever import ID -from intelligence_layer.core.chunk import ChunkInput, ChunkWithIndices, TextChunk +from intelligence_layer.core.chunk import ( + ChunkInput, + ChunkWithIndices, + ChunkWithStartEndIndices, +) from intelligence_layer.core.model import AlephAlphaModel from intelligence_layer.core.task import Task from intelligence_layer.core.tracer.tracer import TaskSpan @@ -16,7 +20,7 @@ class ExpandChunksInput(BaseModel, Generic[ID]): class ExpandChunksOutput(BaseModel): - chunks: Sequence[TextChunk] + chunks: Sequence[ChunkWithStartEndIndices] class ExpandChunks(Generic[ID], Task[ExpandChunksInput[ID], ExpandChunksOutput]): @@ -50,34 +54,26 @@ def do_run( ).chunks_with_indices overlapping_chunk_indices = self._overlapping_chunk_indices( - [c.start_index for c in chunk_with_indices], + [(c.start_index, c.end_index) for c in chunk_with_indices], [(chunk.start, chunk.end) for chunk in input.chunks_found], ) return ExpandChunksOutput( - chunks=[ - chunk_with_indices[index].chunk for index in overlapping_chunk_indices - ] + chunks=[chunk_with_indices[index] for index in overlapping_chunk_indices], ) def _overlapping_chunk_indices( self, - chunk_start_indices: Sequence[int], + chunk_indices: Sequence[tuple[int, int]], target_ranges: Sequence[tuple[int, int]], ) -> list[int]: - n = len(chunk_start_indices) overlapping_indices: list[int] = [] - for i in range(n): - if i < n - 1: - chunk_end: float = chunk_start_indices[i + 1] - else: - chunk_end = float("inf") - + for i in range(len(chunk_indices)): if any( ( - chunk_start_indices[i] <= target_range[1] - and chunk_end > target_range[0] + chunk_indices[i][0] <= target_range[1] + and chunk_indices[i][1] > target_range[0] ) for target_range in target_ranges ): diff --git a/tests/core/test_chunk.py b/tests/core/test_chunk.py index e3968abbb..637eeb609 100644 --- a/tests/core/test_chunk.py +++ b/tests/core/test_chunk.py @@ -54,3 +54,7 @@ def test_chunk_with_indices( c.start_index < output.chunks_with_indices[idx + 1].start_index for idx, c in enumerate(output.chunks_with_indices[:-1]) ) + assert all( + c.end_index == output.chunks_with_indices[idx + 1].start_index + for idx, c in enumerate(output.chunks_with_indices[:-1]) + ) diff --git a/tests/use_cases/qa/test_multiple_chunk_retriever_qa.py b/tests/use_cases/qa/test_multiple_chunk_retriever_qa.py index ed8003744..029a7e34b 100644 --- a/tests/use_cases/qa/test_multiple_chunk_retriever_qa.py +++ b/tests/use_cases/qa/test_multiple_chunk_retriever_qa.py @@ -12,7 +12,7 @@ def multiple_chunk_retriever_qa( return MultipleChunkRetrieverQa(retriever=asymmetric_in_memory_retriever) -def test_retriever_based_qa_using_in_memory_retriever( +def test_multiple_chunk_retriever_qa_using_in_memory_retriever( multiple_chunk_retriever_qa: MultipleChunkRetrieverQa[int], no_op_tracer: NoOpTracer, ) -> None: diff --git a/tests/use_cases/search/test_expand_chunk.py b/tests/use_cases/search/test_expand_chunk.py index a6a9c5fc9..c4d0fe0b2 100644 --- a/tests/use_cases/search/test_expand_chunk.py +++ b/tests/use_cases/search/test_expand_chunk.py @@ -130,7 +130,7 @@ def test_expand_chunk_works_for_wholly_included_chunk( ) assert ( wholly_included_expand_chunk_input.chunks_found[0].text - in expand_chunk_output.chunks[0] + in expand_chunk_output.chunks[0].chunk ) @@ -165,6 +165,6 @@ def test_expand_chunk_works_for_multiple_chunks( assert len(expand_chunk_output.chunks) == 3 - combined_chunks = "\n\n".join(expand_chunk_output.chunks) + combined_chunks = "".join(chunk.chunk for chunk in expand_chunk_output.chunks) for chunk_found in multiple_chunks_expand_chunk_input.chunks_found: assert chunk_found.text in combined_chunks