Expand chunk task #746

Closed · wants to merge 6 commits
4 changes: 2 additions & 2 deletions CHANGELOG.md
@@ -3,10 +3,10 @@
## Unreleased

### Breaking Changes
...
- breaking change: `MultipleChunkRetrieverQaOutput` now returns `sources` and `search_results`

### New Features
...
- feature: `ExpandChunks` task takes a retriever and a set of search results and expands the found chunks to the desired length

### Fixes
...
3 changes: 3 additions & 0 deletions src/intelligence_layer/core/__init__.py
@@ -1,6 +1,9 @@
from .chunk import Chunk as Chunk
from .chunk import ChunkInput as ChunkInput
from .chunk import ChunkOutput as ChunkOutput
from .chunk import ChunkWithIndices as ChunkWithIndices
from .chunk import ChunkWithIndicesOutput as ChunkWithIndicesOutput
from .chunk import ChunkWithStartIndex as ChunkWithStartIndex
from .chunk import TextChunk as TextChunk
from .detect_language import DetectLanguage as DetectLanguage
from .detect_language import DetectLanguageInput as DetectLanguageInput
51 changes: 50 additions & 1 deletion src/intelligence_layer/core/chunk.py
@@ -19,7 +19,7 @@


class ChunkInput(BaseModel):
-    """The input for a `ChunkTask`.
+    """The input for a `Chunk`-task.

    Attributes:
        text: A text of arbitrary length.
@@ -60,3 +60,52 @@ def do_run(self, input: ChunkInput, task_span: TaskSpan) -> ChunkOutput:
            for t in self._splitter.chunks(input.text, self._max_tokens_per_chunk)
        ]
        return ChunkOutput(chunks=chunks)


class ChunkWithStartIndex(BaseModel):
    """A `TextChunk` and its `start_index` relative to its parent document.

    Attributes:
        chunk: The actual text.
        start_index: The character start index of the chunk within the respective document.
    """

    chunk: TextChunk
    start_index: int


class ChunkWithIndicesOutput(BaseModel):
    """The output of a `ChunkWithIndices`-task.

    Attributes:
        chunks_with_indices: A list of smaller sections of the input text with the respective start_index.
    """

    chunks_with_indices: Sequence[ChunkWithStartIndex]


class ChunkWithIndices(Task[ChunkInput, ChunkWithIndicesOutput]):
    """Splits a longer text into smaller text chunks and returns the chunks' start indices.

    Provide a text of any length and chunk it into smaller pieces using a
    tokenizer that is available within the Aleph Alpha client. For each chunk, the respective
    start index relative to the document is also returned.

    Args:
        model: A valid Aleph Alpha model.
        max_tokens_per_chunk: The maximum number of tokens to fit into one chunk.
    """

    def __init__(self, model: AlephAlphaModel, max_tokens_per_chunk: int = 512):
        super().__init__()
        self._splitter = TextSplitter.from_huggingface_tokenizer(model.get_tokenizer())
        self._max_tokens_per_chunk = max_tokens_per_chunk

    def do_run(self, input: ChunkInput, task_span: TaskSpan) -> ChunkWithIndicesOutput:
        chunks_with_indices = [
            ChunkWithStartIndex(chunk=TextChunk(t[1]), start_index=t[0])
            for t in self._splitter.chunk_indices(
                input.text, self._max_tokens_per_chunk
            )
        ]
        return ChunkWithIndicesOutput(chunks_with_indices=chunks_with_indices)
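
A minimal usage sketch of the new `ChunkWithIndices` task; the input text, chunk size, and tracer are illustrative assumptions, not part of this diff (the model name mirrors the default used elsewhere in the PR):

```python
from intelligence_layer.core import ChunkInput, ChunkWithIndices, InMemoryTracer
from intelligence_layer.core.model import LuminousControlModel

# Any AlephAlphaModel should work here; only its tokenizer is used.
model = LuminousControlModel("luminous-supreme-control")
task = ChunkWithIndices(model, max_tokens_per_chunk=128)

output = task.run(ChunkInput(text="A long document ..."), InMemoryTracer())
for c in output.chunks_with_indices:
    # start_index is the chunk's character offset within the original text.
    print(c.start_index, c.chunk)
```
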
3 changes: 3 additions & 0 deletions src/intelligence_layer/use_cases/__init__.py
@@ -62,6 +62,9 @@
from .qa.single_chunk_qa import SingleChunkQa as SingleChunkQa
from .qa.single_chunk_qa import SingleChunkQaInput as SingleChunkQaInput
from .qa.single_chunk_qa import SingleChunkQaOutput as SingleChunkQaOutput
from .search.expand_chunks import ExpandChunks as ExpandChunks
from .search.expand_chunks import ExpandChunksInput as ExpandChunksInput
from .search.expand_chunks import ExpandChunksOutput as ExpandChunksOutput
from .search.search import AggregatedSearchEvaluation as AggregatedSearchEvaluation
from .search.search import ChunkFound as ChunkFound
from .search.search import ExpectedSearchOutput as ExpectedSearchOutput
87 changes: 55 additions & 32 deletions src/intelligence_layer/use_cases/qa/multiple_chunk_retriever_qa.py
@@ -1,3 +1,4 @@
from collections import defaultdict
from typing import Generic, Optional, Sequence

from pydantic import BaseModel
@@ -8,23 +9,30 @@
    SearchResult,
)
from intelligence_layer.core.chunk import TextChunk
from intelligence_layer.core.model import ControlModel, LuminousControlModel
from intelligence_layer.core.task import Task
from intelligence_layer.core.text_highlight import ScoredTextHighlight
from intelligence_layer.core.tracer.tracer import TaskSpan
from intelligence_layer.use_cases.search.expand_chunks import (
    ExpandChunks,
    ExpandChunksInput,
)
from intelligence_layer.use_cases.search.search import Search, SearchInput

from .retriever_based_qa import RetrieverBasedQaInput
from .single_chunk_qa import SingleChunkQa, SingleChunkQaInput, SingleChunkQaOutput


class AnswerSource(BaseModel, Generic[ID]):
-    search_result: SearchResult[ID]
+    document_id: ID
+    chunk: TextChunk
    highlights: Sequence[ScoredTextHighlight]


class MultipleChunkRetrieverQaOutput(BaseModel, Generic[ID]):
    answer: Optional[str]
    sources: Sequence[AnswerSource[ID]]
+    search_results: Sequence[SearchResult[ID]]


class MultipleChunkRetrieverQa(
@@ -45,34 +53,22 @@ class MultipleChunkRetrieverQa(
-        k: number of top chunk search results to inject into :class:`SingleChunkQa`-task.
        qa_task: The task that is used to generate an answer for a single chunk (retrieved through
            the retriever). Defaults to :class:`SingleChunkQa`.
-
-    Example:
-        >>> import os
-        >>> from intelligence_layer.connectors import DocumentIndexClient
-        >>> from intelligence_layer.connectors import DocumentIndexRetriever
-        >>> from intelligence_layer.core import InMemoryTracer
-        >>> from intelligence_layer.use_cases import MultipleChunkRetrieverQa, RetrieverBasedQaInput
-
-        >>> token = os.getenv("AA_TOKEN")
-        >>> document_index = DocumentIndexClient(token)
-        >>> retriever = DocumentIndexRetriever(document_index, "aleph-alpha", "wikipedia-de", 3)
-        >>> task = MultipleChunkRetrieverQa(retriever, k=2)
-        >>> input_data = RetrieverBasedQaInput(question="When was Rome founded?")
-        >>> tracer = InMemoryTracer()
-        >>> output = task.run(input_data, tracer)
    """

    def __init__(
        self,
        retriever: BaseRetriever[ID],
-        k: int = 5,
+        model: ControlModel | None = None,
+        insert_chunk_number: int = 5,
+        insert_chunk_size: int = 256,
        single_chunk_qa: Task[SingleChunkQaInput, SingleChunkQaOutput] | None = None,
    ):
        super().__init__()
+        self._model = model or LuminousControlModel("luminous-supreme-control")
        self._search = Search(retriever)
-        self._k = k
-        self._single_chunk_qa = single_chunk_qa or SingleChunkQa()
+        self._expand_chunks = ExpandChunks(retriever, self._model, insert_chunk_size)
+        self._single_chunk_qa = single_chunk_qa or SingleChunkQa(self._model)
+        self._insert_chunk_number = insert_chunk_number

    @staticmethod
    def _combine_input_texts(chunks: Sequence[str]) -> tuple[TextChunk, Sequence[int]]:
@@ -101,33 +97,56 @@ def _get_highlights_per_chunk(
                if highlight.start < next_start and highlight.end > current_start:
                    highlights_with_indices_fixed = ScoredTextHighlight(
                        start=max(0, highlight.start - current_start),
-                        end=min(highlight.end - current_start, next_start)
-                        if isinstance(next_start, int)
-                        else highlight.end,
+                        end=highlight.end - current_start
+                        if isinstance(next_start, float)
+                        else min(next_start, highlight.end - current_start),
                        score=highlight.score,
                    )
                    current_overlaps.append(highlights_with_indices_fixed)

            overlapping_ranges.append(current_overlaps)
        return overlapping_ranges
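
To make the boundary fix concrete: each highlight arrives in the coordinates of the combined prompt text and must be clamped into the local coordinates of the chunk it overlaps, with the last chunk treated as open-ended (`float("inf")`). A standalone sketch of that clamping idea with hypothetical offsets (not the exact production expression):

```python
# Chunks of the combined text start at these character offsets; the last
# chunk has no successor, so its upper bound is float("inf").
chunk_starts = [0, 100, 250]
highlight = (90, 130)  # hypothetical global (start, end) of one highlight

for i, current_start in enumerate(chunk_starts):
    next_start = chunk_starts[i + 1] if i + 1 < len(chunk_starts) else float("inf")
    if highlight[0] < next_start and highlight[1] > current_start:
        local_start = max(0, highlight[0] - current_start)
        local_end = min(highlight[1], next_start) - current_start
        print(f"chunk {i}: ({local_start}, {local_end})")
# chunk 0: (90, 100)
# chunk 1: (0, 30)
```
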

    def _expand_search_result_chunks(
        self, search_results: Sequence[SearchResult[ID]], task_span: TaskSpan
    ) -> Sequence[tuple[ID, TextChunk]]:
        grouped_results: dict[ID, list[SearchResult[ID]]] = defaultdict(list)
        for result in search_results:
            grouped_results[result.id].append(result)

        chunks_to_insert: list[tuple[ID, TextChunk]] = []
        for id, results in grouped_results.items():
            input = ExpandChunksInput(
                document_id=id, chunks_found=[r.document_chunk for r in results]
            )
            expand_chunks_output = self._expand_chunks.run(input, task_span)
            for chunk in expand_chunks_output.chunks:
                if len(chunks_to_insert) >= self._insert_chunk_number:
                    break
                chunks_to_insert.append((id, chunk))

        return chunks_to_insert
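
The helper above groups hits by document so each document is expanded only once, then caps how many chunks are inserted overall. A toy illustration of that group-then-cap pattern, with plain strings standing in for retriever output (all ids hypothetical):

```python
from collections import defaultdict

hits = [("doc-a", "A1"), ("doc-b", "B1"), ("doc-a", "A2"), ("doc-c", "C1")]
insert_chunk_number = 3  # overall chunk budget

grouped: dict[str, list[str]] = defaultdict(list)
for doc_id, text in hits:
    grouped[doc_id].append(text)  # one bucket per document, hit order preserved

chunks_to_insert: list[tuple[str, str]] = []
for doc_id, texts in grouped.items():
    for text in texts:  # stand-in for the ExpandChunks output of this document
        if len(chunks_to_insert) >= insert_chunk_number:
            break
        chunks_to_insert.append((doc_id, text))

print(chunks_to_insert)  # [('doc-a', 'A1'), ('doc-a', 'A2'), ('doc-b', 'B1')]
```
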

    def do_run(
        self, input: RetrieverBasedQaInput, task_span: TaskSpan
    ) -> MultipleChunkRetrieverQaOutput[ID]:
        search_output = self._search.run(
            SearchInput(query=input.question), task_span
        ).results
-        sorted_search_output = sorted(
-            search_output,
-            key=lambda output: output.score,  # not reversing on purpose because model performs better if relevant info is at the end
-        )[-self._k :]
+        sorted_search_results = sorted(
+            search_output, key=lambda output: output.score, reverse=True
+        )
+
+        chunks_to_insert = self._expand_search_result_chunks(
+            sorted_search_results, task_span
+        )

-        chunk, chunk_start_indices = self._combine_input_texts(
-            [output.document_chunk.text for output in sorted_search_output]
+        chunk_for_prompt, chunk_start_indices = self._combine_input_texts(
+            [c[1] for c in chunks_to_insert]
        )

        single_chunk_qa_input = SingleChunkQaInput(
-            chunk=chunk,
+            chunk=chunk_for_prompt,
            question=input.question,
            language=input.language,
        )
@@ -144,9 +163,13 @@ def do_run(
            answer=single_chunk_qa_output.answer,
            sources=[
                AnswerSource(
-                    search_result=chunk,
+                    document_id=id_and_chunk[0],
+                    chunk=id_and_chunk[1],
                    highlights=highlights,
                )
-                for chunk, highlights in zip(sorted_search_output, highlights_per_chunk)
+                for id_and_chunk, highlights in zip(
+                    chunks_to_insert, highlights_per_chunk
+                )
            ],
+            search_results=sorted_search_results,
        )
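
With the doctest gone from the docstring, a hedged sketch of how the reworked task is now constructed, mirroring the removed example; the Document Index arguments and the question are illustrative:

```python
import os

from intelligence_layer.connectors import DocumentIndexClient, DocumentIndexRetriever
from intelligence_layer.core import InMemoryTracer
from intelligence_layer.use_cases import MultipleChunkRetrieverQa, RetrieverBasedQaInput

token = os.getenv("AA_TOKEN")
document_index = DocumentIndexClient(token)
retriever = DocumentIndexRetriever(document_index, "aleph-alpha", "wikipedia-de", 3)

# `k` is gone; the chunk budget is now set via `insert_chunk_number` and
# `insert_chunk_size` (the values below are the new defaults).
task = MultipleChunkRetrieverQa(retriever, insert_chunk_number=5, insert_chunk_size=256)

input_data = RetrieverBasedQaInput(question="When was Rome founded?")
output = task.run(input_data, InMemoryTracer())
print(output.answer)
print(len(output.sources), len(output.search_results))
```
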
86 changes: 86 additions & 0 deletions src/intelligence_layer/use_cases/search/expand_chunks.py
@@ -0,0 +1,86 @@
from typing import Generic, Sequence

from pydantic import BaseModel

from intelligence_layer.connectors import BaseRetriever, DocumentChunk
from intelligence_layer.connectors.retrievers.base_retriever import ID
from intelligence_layer.core.chunk import ChunkInput, ChunkWithIndices, TextChunk
from intelligence_layer.core.model import AlephAlphaModel
from intelligence_layer.core.task import Task
from intelligence_layer.core.tracer.tracer import TaskSpan


class ExpandChunksInput(BaseModel, Generic[ID]):
    document_id: ID
    chunks_found: Sequence[DocumentChunk]


class ExpandChunksOutput(BaseModel):
    chunks: Sequence[TextChunk]


class ExpandChunks(Generic[ID], Task[ExpandChunksInput[ID], ExpandChunksOutput]):
    """Expand chunks found during search.

    Args:
        retriever: Used to access and return a set of texts.
        model: The model's tokenizer is used to calculate the correct size of the returned chunks.
        max_chunk_size: The maximum chunk size of each returned chunk.
    """

    def __init__(
        self,
        retriever: BaseRetriever[ID],
        model: AlephAlphaModel,
        max_chunk_size: int = 512,
    ):
        super().__init__()
        self._retriever = retriever
        self._chunk_with_indices = ChunkWithIndices(model, max_chunk_size)

    def do_run(
        self, input: ExpandChunksInput[ID], task_span: TaskSpan
    ) -> ExpandChunksOutput:
        full_doc = self._retriever.get_full_document(input.document_id)
        if not full_doc:
            raise RuntimeError(f"No document for id '{input.document_id}' found")

        chunk_with_indices = self._chunk_with_indices.run(
            ChunkInput(text=full_doc.text), task_span
        ).chunks_with_indices

        overlapping_chunk_indices = self._overlapping_chunk_indices(
            [c.start_index for c in chunk_with_indices],
            [(chunk.start, chunk.end) for chunk in input.chunks_found],
        )

        return ExpandChunksOutput(
            chunks=[
                chunk_with_indices[index].chunk for index in overlapping_chunk_indices
            ]
        )

    def _overlapping_chunk_indices(
        self,
        chunk_start_indices: Sequence[int],
        target_ranges: Sequence[tuple[int, int]],
    ) -> list[int]:
        n = len(chunk_start_indices)
        overlapping_indices: list[int] = []

        for i in range(n):
            if i < n - 1:
                chunk_end: float = chunk_start_indices[i + 1]
            else:
                chunk_end = float("inf")

            if any(
                (
                    chunk_start_indices[i] <= target_range[1]
                    and chunk_end > target_range[0]
                )
                for target_range in target_ranges
            ):
                overlapping_indices.append(i)

        return overlapping_indices
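
Read in isolation, `_overlapping_chunk_indices` returns every re-chunked window that intersects at least one found range, where a window runs from its start offset to the next window's start (the last one is open-ended). A standalone, runnable restatement with hypothetical offsets:

```python
def overlapping_chunk_indices(
    chunk_start_indices: list[int], target_ranges: list[tuple[int, int]]
) -> list[int]:
    n = len(chunk_start_indices)
    overlapping: list[int] = []
    for i in range(n):
        # A chunk spans [start_i, start_{i+1}); the last chunk is unbounded.
        chunk_end = chunk_start_indices[i + 1] if i < n - 1 else float("inf")
        if any(
            chunk_start_indices[i] <= end and chunk_end > start
            for start, end in target_ranges
        ):
            overlapping.append(i)
    return overlapping

# Windows start at 0, 200, 400; the search hit covered characters 150-250,
# so the two windows touching that range are returned for expansion.
print(overlapping_chunk_indices([0, 200, 400], [(150, 250)]))  # [0, 1]
```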