feat: ExpandChunks caches chunked documents by id
NickyHavoc committed Apr 24, 2024
1 parent da7a2b3 commit b3dbc53
Showing 2 changed files with 29 additions and 11 deletions.
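
The change memoizes the retrieve-and-chunk step per document id with functools.lru_cache, so repeated ExpandChunks runs against the same document no longer re-fetch and re-chunk it. Below is a minimal, self-contained sketch of that pattern under assumed names — CachedChunker, SimpleChunk, and the dict-backed stand-in retriever are illustrative, not the library's API:

from functools import lru_cache
from typing import Sequence


class SimpleChunk:
    """Stand-in for ChunkWithStartEndIndices: a piece of text plus its character span."""

    def __init__(self, text: str, start_index: int, end_index: int) -> None:
        self.text = text
        self.start_index = start_index
        self.end_index = end_index


class CachedChunker:
    """Sketch of the commit's pattern: retrieval and chunking are memoized per document id."""

    def __init__(self, documents: dict[str, str], max_chunk_size: int = 32) -> None:
        # `documents` stands in for the retriever; its keys act as document ids.
        self._documents = documents
        self._max_chunk_size = max_chunk_size

    @lru_cache(maxsize=100)
    def _retrieve_and_chunk(self, document_id: str) -> Sequence[SimpleChunk]:
        # Cached per (instance, document_id): repeated calls skip retrieval and chunking.
        return self._chunk_text(self._retrieve_text(document_id))

    def _retrieve_text(self, document_id: str) -> str:
        if document_id not in self._documents:
            raise RuntimeError(f"No document for id '{document_id}' found")
        return self._documents[document_id]

    def _chunk_text(self, text: str) -> Sequence[SimpleChunk]:
        # Naive fixed-width chunking; the real ChunkWithIndices task chunks on token boundaries.
        size = self._max_chunk_size
        return [
            SimpleChunk(text[i : i + size], i, min(i + size, len(text)))
            for i in range(0, len(text), size)
        ]

Calling chunker._retrieve_and_chunk("doc-1") twice retrieves and chunks once; the second call returns the cached sequence.
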
36 changes: 26 additions & 10 deletions src/intelligence_layer/use_cases/search/expand_chunks.py
@@ -1,13 +1,15 @@
from functools import lru_cache
from typing import Generic, Sequence

from pydantic import BaseModel

from intelligence_layer.connectors import BaseRetriever, DocumentChunk
from intelligence_layer.connectors.retrievers.base_retriever import ID
from intelligence_layer.core.chunk import (
from intelligence_layer.core import (
ChunkInput,
ChunkWithIndices,
ChunkWithStartEndIndices,
NoOpTracer,
)
from intelligence_layer.core.model import AlephAlphaModel
from intelligence_layer.core.task import Task
@@ -41,27 +43,41 @@ def __init__(
super().__init__()
self._retriever = retriever
self._chunk_with_indices = ChunkWithIndices(model, max_chunk_size)
self._no_op_tracer = NoOpTracer()

def do_run(
self, input: ExpandChunksInput[ID], task_span: TaskSpan
) -> ExpandChunksOutput:
full_doc = self._retriever.get_full_document(input.document_id)
if not full_doc:
raise RuntimeError(f"No document for id '{input.document_id}' found")

chunk_with_indices = self._chunk_with_indices.run(
ChunkInput(text=full_doc.text), task_span
).chunks_with_indices
chunked_text = self._retrieve_and_chunk(input.document_id)

overlapping_chunk_indices = self._overlapping_chunk_indices(
[(c.start_index, c.end_index) for c in chunk_with_indices],
[(c.start_index, c.end_index) for c in chunked_text],
[(chunk.start, chunk.end) for chunk in input.chunks_found],
)

return ExpandChunksOutput(
chunks=[chunk_with_indices[index] for index in overlapping_chunk_indices],
chunks=[chunked_text[index] for index in overlapping_chunk_indices],
)

@lru_cache(maxsize=100)
def _retrieve_and_chunk(
self, document_id: ID
) -> Sequence[ChunkWithStartEndIndices]:
text = self._retrieve_text(document_id)
return self._chunk_text(text)

def _retrieve_text(self, document_id: ID) -> str:
full_document = self._retriever.get_full_document(document_id)
if not full_document:
raise RuntimeError(f"No document for id '{document_id}' found")
return full_document.text

def _chunk_text(self, text: str) -> Sequence[ChunkWithStartEndIndices]:
# A fixed NoOpTracer is used so the cached {ID: Sequence[ChunkWithStartEndIndices]} mapping does not depend on the per-run task_span
return self._chunk_with_indices.run(
ChunkInput(text=text), self._no_op_tracer
).chunks_with_indices

def _overlapping_chunk_indices(
self,
chunk_indices: Sequence[tuple[int, int]],
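Because lru_cache keys on all arguments, self included, repeated runs for the same document id on the same ExpandChunks instance are served from the cache, while entries from different instances do not collide. A tiny standalone check of that behaviour (the Chunker class here is hypothetical, not part of the library):

from functools import lru_cache


class Chunker:
    def __init__(self) -> None:
        self.calls = 0  # counts how often the uncached body actually runs

    @lru_cache(maxsize=100)
    def chunk(self, document_id: str) -> tuple[str, ...]:
        self.calls += 1
        return tuple(f"{document_id}-chunk-{i}" for i in range(3))


chunker = Chunker()
chunker.chunk("doc-1")
chunker.chunk("doc-1")  # cache hit: the body does not run again
chunker.chunk("doc-2")
assert chunker.calls == 2
print(Chunker.chunk.cache_info())  # CacheInfo(hits=1, misses=2, maxsize=100, currsize=2)

One trade-off of caching on an instance method is that cache entries hold a reference to the instance, so it stays alive while its entries remain in the cache.
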
4 changes: 3 additions & 1 deletion tests/use_cases/qa/test_multiple_chunk_retriever_qa.py
@@ -2,6 +2,7 @@

from intelligence_layer.connectors import QdrantInMemoryRetriever
from intelligence_layer.core import NoOpTracer
from intelligence_layer.core.tracer.in_memory_tracer import InMemoryTracer
from intelligence_layer.use_cases import MultipleChunkRetrieverQa, RetrieverBasedQaInput


@@ -18,7 +19,8 @@ def test_multiple_chunk_retriever_qa_using_in_memory_retriever(
) -> None:
question = "When was Robert Moses born?"
input = RetrieverBasedQaInput(question=question)
output = multiple_chunk_retriever_qa.run(input, no_op_tracer)
tracer = InMemoryTracer()
output = multiple_chunk_retriever_qa.run(input, tracer)
assert output.answer
assert "1888" in output.answer
assert len(output.sources) == 5
