Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MultipleChunkRetrieverBasedQa-task #734

Merged
merged 6 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
- The implementation of the HuggingFace repository creation and deletion got moved to `HuggingFaceRepository`
### New Features
- feature: HuggingFaceDataset- & AggregationRepositories now have an explicit `create_repository` function.
- feature: Add `MultipleChunkRetrieverBasedQa`, a task that performs better and faster on retriever-QA, especially with longer-context models

### Fixes
...

Expand Down
2 changes: 1 addition & 1 deletion src/examples/quickstart_task.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
"version": "3.11.0"
}
},
"nbformat": 4,
Expand Down
6 changes: 6 additions & 0 deletions src/intelligence_layer/use_cases/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@
from .qa.multiple_chunk_qa import MultipleChunkQaInput as MultipleChunkQaInput
from .qa.multiple_chunk_qa import MultipleChunkQaOutput as MultipleChunkQaOutput
from .qa.multiple_chunk_qa import Subanswer as Subanswer
from .qa.multiple_chunk_retriever_qa import (
MulMultipleChunkRetrieverQaOutput as MulMultipleChunkRetrieverQaOutput,
)
from .qa.multiple_chunk_retriever_qa import (
MultipleChunkRetrieverQa as MultipleChunkRetrieverQa,
)
from .qa.retriever_based_qa import EnrichedSubanswer as EnrichedSubanswer
from .qa.retriever_based_qa import RetrieverBasedQa as RetrieverBasedQa
from .qa.retriever_based_qa import RetrieverBasedQaInput as RetrieverBasedQaInput
Expand Down
152 changes: 152 additions & 0 deletions src/intelligence_layer/use_cases/qa/multiple_chunk_retriever_qa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
from typing import Generic, Optional, Sequence

from pydantic import BaseModel

from intelligence_layer.connectors.retrievers.base_retriever import (
ID,
BaseRetriever,
SearchResult,
)
from intelligence_layer.core.chunk import TextChunk
from intelligence_layer.core.task import Task
from intelligence_layer.core.text_highlight import ScoredTextHighlight
from intelligence_layer.core.tracer.tracer import TaskSpan
from intelligence_layer.use_cases.search.search import Search, SearchInput

from .retriever_based_qa import RetrieverBasedQaInput
from .single_chunk_qa import SingleChunkQa, SingleChunkQaInput, SingleChunkQaOutput


class AnswerSource(BaseModel, Generic[ID]):
    """One retrieved chunk that contributed to the answer, with its highlights."""

    # The retrieved search result (document chunk plus retriever metadata).
    chunk: SearchResult[ID]
    # Highlight spans; offsets are relative to this chunk's text.
    highlights: Sequence[ScoredTextHighlight]


class MulMultipleChunkRetrieverQaOutput(BaseModel, Generic[ID]):
    """Output of :class:`MultipleChunkRetrieverQa`.

    NOTE(review): the leading "Mul" looks like a typo for
    ``MultipleChunkRetrieverQaOutput``; the name is publicly re-exported from
    ``use_cases/__init__.py``, so renaming here would break callers — flagging
    instead of fixing.
    """

    # The generated answer; typed Optional, so it may be None.
    answer: Optional[str]
    # The chunks the answer is based on, each with chunk-relative highlights.
    sources: Sequence[AnswerSource[ID]]


class MultipleChunkRetrieverQa(
    Task[RetrieverBasedQaInput, MulMultipleChunkRetrieverQaOutput[ID]], Generic[ID]
):
    """Answer a question based on documents found by a retriever.

    `MultipleChunkRetrieverQa` is a task that answers a question based on a set of documents.
    It relies on some retriever of type `BaseRetriever` that has the ability to access texts.
    In contrast to the regular `RetrieverBasedQa`, this task injects multiple chunks into one
    `SingleChunkQa` task run.

    Note:
        `model` provided should be a control-type model.

    Args:
        retriever: Used to access and return a set of texts.
        k: number of top chunk search results to inject into :class:`SingleChunkQa`-task.
        single_chunk_qa: The task that is used to generate an answer from the combined
            top-k chunks. Defaults to :class:`SingleChunkQa`.

    Example:
        >>> import os
        >>> from intelligence_layer.connectors import DocumentIndexClient
        >>> from intelligence_layer.connectors import DocumentIndexRetriever
        >>> from intelligence_layer.core import InMemoryTracer
        >>> from intelligence_layer.use_cases import MultipleChunkRetrieverQa, RetrieverBasedQaInput


        >>> token = os.getenv("AA_TOKEN")
        >>> document_index = DocumentIndexClient(token)
        >>> retriever = DocumentIndexRetriever(document_index, "aleph-alpha", "wikipedia-de", 3)
        >>> task = MultipleChunkRetrieverQa(retriever, k=2)
        >>> input_data = RetrieverBasedQaInput(question="When was Rome founded?")
        >>> tracer = InMemoryTracer()
        >>> output = task.run(input_data, tracer)
    """

    def __init__(
        self,
        retriever: BaseRetriever[ID],
        k: int = 5,
        single_chunk_qa: Task[SingleChunkQaInput, SingleChunkQaOutput] | None = None,
    ):
        super().__init__()
        self._search = Search(retriever)
        self._k = k
        self._single_chunk_qa = single_chunk_qa or SingleChunkQa()

    @staticmethod
    def _combine_input_texts(chunks: Sequence[str]) -> tuple[TextChunk, Sequence[int]]:
        """Concatenate *chunks* (blank-line separated) into one text.

        Returns the combined text and, for each chunk, its start offset within
        the combined text — needed later to map highlights back onto chunks.
        """
        start_indices: list[int] = []
        combined_text = ""
        for chunk in chunks:
            start_indices.append(len(combined_text))
            combined_text += chunk + "\n\n"
        return (TextChunk(combined_text.strip()), start_indices)

    @staticmethod
    def _get_highlights_per_chunk(
        chunk_start_indices: Sequence[int], highlights: Sequence[ScoredTextHighlight]
    ) -> Sequence[Sequence[ScoredTextHighlight]]:
        """Distribute *highlights* (offsets in the combined text) onto the chunks.

        Returns one list per chunk; every highlight overlapping a chunk is
        clamped to that chunk's bounds and shifted to chunk-relative offsets.
        """
        overlapping_ranges = []
        for i in range(len(chunk_start_indices)):
            current_start = chunk_start_indices[i]
            # End boundary of this chunk; the last chunk extends to infinity.
            next_start = (
                chunk_start_indices[i + 1]
                if i + 1 < len(chunk_start_indices)
                else float("inf")
            )

            current_overlaps = []
            for highlight in highlights:
                if highlight.start < next_start and highlight.end > current_start:
                    # Clamp in absolute coordinates first, then shift into
                    # chunk-relative coordinates. (Shifting before clamping
                    # would compare a relative offset to an absolute one,
                    # producing wrong ends for every chunk after the first.)
                    clamped_end = (
                        min(highlight.end, next_start)
                        if isinstance(next_start, int)
                        else highlight.end
                    )
                    current_overlaps.append(
                        ScoredTextHighlight(
                            start=max(0, highlight.start - current_start),
                            end=clamped_end - current_start,
                            score=highlight.score,
                        )
                    )

            overlapping_ranges.append(current_overlaps)
        return overlapping_ranges

    def do_run(
        self, input: RetrieverBasedQaInput, task_span: TaskSpan
    ) -> MulMultipleChunkRetrieverQaOutput[ID]:
        """Search, combine the top-k chunks, answer once, and map highlights back.

        Args:
            input: Contains the question (and optional language) to answer.
            task_span: Span under which the sub-task runs are traced.
        """
        search_output = self._search.run(
            SearchInput(query=input.question), task_span
        ).results
        # Ascending sort on purpose: the model performs better when the most
        # relevant information is at the end of the prompt.
        sorted_search_output = sorted(
            search_output,
            key=lambda output: output.score,
        )[-self._k :]

        chunk, chunk_start_indices = self._combine_input_texts(
            [output.document_chunk.text for output in sorted_search_output]
        )

        single_chunk_qa_input = SingleChunkQaInput(
            chunk=chunk,
            question=input.question,
            language=input.language,
        )

        single_chunk_qa_output = self._single_chunk_qa.run(
            single_chunk_qa_input, task_span
        )

        highlights_per_chunk = self._get_highlights_per_chunk(
            chunk_start_indices, single_chunk_qa_output.highlights
        )

        return MulMultipleChunkRetrieverQaOutput(
            answer=single_chunk_qa_output.answer,
            sources=[
                AnswerSource(
                    chunk=chunk,
                    highlights=highlights,
                )
                for chunk, highlights in zip(sorted_search_output, highlights_per_chunk)
            ],
        )
4 changes: 2 additions & 2 deletions src/intelligence_layer/use_cases/qa/retriever_based_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,14 @@ class RetrieverBasedQa(
Args:
retriever: Used to access and return a set of texts.
qa_task: The task that is used to generate an answer for a single chunk (retrieved through
the retriever). Defaults to :class:`SingleChunkQa` .
the retriever). Defaults to :class:`MultipleChunkQa` .

Example:
>>> import os
>>> from intelligence_layer.connectors import DocumentIndexClient
>>> from intelligence_layer.connectors import DocumentIndexRetriever
>>> from intelligence_layer.core import InMemoryTracer
>>> from intelligence_layer.use_cases import RetrieverBasedQa, RetrieverBasedQaInput, SingleChunkQa
>>> from intelligence_layer.use_cases import RetrieverBasedQa, RetrieverBasedQaInput


>>> token = os.getenv("AA_TOKEN")
Expand Down
Loading