Skip to content

Commit

Permalink
merge again
Browse files Browse the repository at this point in the history
  • Loading branch information
LostVector committed Sep 30, 2024
2 parents 98a89d9 + 728a41a commit 1dfb2c7
Show file tree
Hide file tree
Showing 20 changed files with 378 additions and 166 deletions.
16 changes: 13 additions & 3 deletions backend/danswer/background/indexing/run_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from danswer.db.models import IndexModelStatus
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.embedder import DefaultIndexingEmbedder
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
from danswer.utils.logger import IndexAttemptSingleton
from danswer.utils.logger import setup_logger
Expand Down Expand Up @@ -103,15 +104,24 @@ def _run_indexing(
)

embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=search_settings
search_settings=search_settings,
heartbeat=IndexingHeartbeat(
index_attempt_id=index_attempt.id,
db_session=db_session,
# let the world know we're still making progress after
# every 10 batches
freq=10,
),
)

indexing_pipeline = build_indexing_pipeline(
attempt_id=index_attempt.id,
embedder=embedding_model,
document_index=document_index,
ignore_time_skip=index_attempt.from_beginning
or (search_settings.status == IndexModelStatus.FUTURE),
ignore_time_skip=(
index_attempt.from_beginning
or (search_settings.status == IndexModelStatus.FUTURE)
),
db_session=db_session,
)

Expand Down
15 changes: 14 additions & 1 deletion backend/danswer/indexing/chunker.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
get_metadata_keys_to_ignore,
)
from danswer.connectors.models import Document
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.indexing.models import DocAwareChunk
from danswer.natural_language_processing.utils import BaseTokenizer
from danswer.utils.logger import setup_logger
Expand Down Expand Up @@ -123,6 +124,7 @@ def __init__(
chunk_token_limit: int = DOC_EMBEDDING_CONTEXT_SIZE,
chunk_overlap: int = CHUNK_OVERLAP,
mini_chunk_size: int = MINI_CHUNK_SIZE,
heartbeat: Heartbeat | None = None,
) -> None:
from llama_index.text_splitter import SentenceSplitter

Expand All @@ -131,6 +133,7 @@ def __init__(
self.enable_multipass = enable_multipass
self.enable_large_chunks = enable_large_chunks
self.tokenizer = tokenizer
self.heartbeat = heartbeat

self.blurb_splitter = SentenceSplitter(
tokenizer=tokenizer.tokenize,
Expand Down Expand Up @@ -255,7 +258,7 @@ def _create_chunk(
# If the chunk does not have any useable content, it will not be indexed
return chunks

def chunk(self, document: Document) -> list[DocAwareChunk]:
def _handle_single_document(self, document: Document) -> list[DocAwareChunk]:
# Specifically for reproducing an issue with gmail
if document.source == DocumentSource.GMAIL:
logger.debug(f"Chunking {document.semantic_identifier}")
Expand Down Expand Up @@ -302,3 +305,13 @@ def chunk(self, document: Document) -> list[DocAwareChunk]:
normal_chunks.extend(large_chunks)

return normal_chunks

def chunk(self, documents: list[Document]) -> list[DocAwareChunk]:
final_chunks: list[DocAwareChunk] = []
for document in documents:
final_chunks.extend(self._handle_single_document(document))

if self.heartbeat:
self.heartbeat.heartbeat()

return final_chunks
42 changes: 11 additions & 31 deletions backend/danswer/indexing/embedder.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
from abc import ABC
from abc import abstractmethod

from sqlalchemy.orm import Session

from danswer.db.models import IndexModelStatus
from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import IndexChunk
Expand All @@ -24,6 +20,9 @@


class IndexingEmbedder(ABC):
"""Converts chunks into chunks with embeddings. Note that one chunk may have
multiple embeddings associated with it."""

def __init__(
self,
model_name: str,
Expand All @@ -33,6 +32,7 @@ def __init__(
provider_type: EmbeddingProvider | None,
api_key: str | None,
api_url: str | None,
heartbeat: Heartbeat | None,
):
self.model_name = model_name
self.normalize = normalize
Expand All @@ -54,6 +54,7 @@ def __init__(
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,
retrim_content=True,
heartbeat=heartbeat,
)

@abstractmethod
Expand All @@ -74,6 +75,7 @@ def __init__(
provider_type: EmbeddingProvider | None = None,
api_key: str | None = None,
api_url: str | None = None,
heartbeat: Heartbeat | None = None,
):
super().__init__(
model_name,
Expand All @@ -83,6 +85,7 @@ def __init__(
provider_type,
api_key,
api_url,
heartbeat,
)

@log_function_time()
Expand Down Expand Up @@ -166,7 +169,7 @@ def embed_chunks(
title_embed_dict[title] = title_embedding

new_embedded_chunk = IndexChunk(
**chunk.dict(),
**chunk.model_dump(),
embeddings=ChunkEmbedding(
full_embedding=chunk_embeddings[0],
mini_chunk_embeddings=chunk_embeddings[1:],
Expand All @@ -180,7 +183,7 @@ def embed_chunks(

@classmethod
def from_db_search_settings(
cls, search_settings: SearchSettings
cls, search_settings: SearchSettings, heartbeat: Heartbeat | None = None
) -> "DefaultIndexingEmbedder":
return cls(
model_name=search_settings.model_name,
Expand All @@ -190,28 +193,5 @@ def from_db_search_settings(
provider_type=search_settings.provider_type,
api_key=search_settings.api_key,
api_url=search_settings.api_url,
heartbeat=heartbeat,
)


def get_embedding_model_from_search_settings(
db_session: Session, index_model_status: IndexModelStatus = IndexModelStatus.PRESENT
) -> IndexingEmbedder:
search_settings: SearchSettings | None
if index_model_status == IndexModelStatus.PRESENT:
search_settings = get_current_search_settings(db_session)
elif index_model_status == IndexModelStatus.FUTURE:
search_settings = get_secondary_search_settings(db_session)
if not search_settings:
raise RuntimeError("No secondary index configured")
else:
raise RuntimeError("Not supporting embedding model rollbacks")

return DefaultIndexingEmbedder(
model_name=search_settings.model_name,
normalize=search_settings.normalize,
query_prefix=search_settings.query_prefix,
passage_prefix=search_settings.passage_prefix,
provider_type=search_settings.provider_type,
api_key=search_settings.api_key,
api_url=search_settings.api_url,
)
41 changes: 41 additions & 0 deletions backend/danswer/indexing/indexing_heartbeat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import abc
from typing import Any

from sqlalchemy import func
from sqlalchemy.orm import Session

from danswer.db.index_attempt import get_index_attempt
from danswer.utils.logger import setup_logger

logger = setup_logger()


class Heartbeat(abc.ABC):
"""Useful for any long-running work that goes through a bunch of items
and needs to occasionally give updates on progress.
e.g. chunking, embedding, updating vespa, etc."""

@abc.abstractmethod
def heartbeat(self, metadata: Any = None) -> None:
raise NotImplementedError


class IndexingHeartbeat(Heartbeat):
def __init__(self, index_attempt_id: int, db_session: Session, freq: int):
self.cnt = 0

self.index_attempt_id = index_attempt_id
self.db_session = db_session
self.freq = freq

def heartbeat(self, metadata: Any = None) -> None:
self.cnt += 1
if self.cnt % self.freq == 0:
index_attempt = get_index_attempt(
db_session=self.db_session, index_attempt_id=self.index_attempt_id
)
if index_attempt:
index_attempt.time_updated = func.now()
self.db_session.commit()
else:
logger.error("Index attempt not found, this should not happen!")
20 changes: 10 additions & 10 deletions backend/danswer/indexing/indexing_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from danswer.document_index.interfaces import DocumentMetadata
from danswer.indexing.chunker import Chunker
from danswer.indexing.embedder import IndexingEmbedder
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.utils.logger import setup_logger
Expand Down Expand Up @@ -283,18 +284,10 @@ def index_doc_batch(
return 0, 0

logger.debug("Starting chunking")
chunks: list[DocAwareChunk] = []
for document in ctx.updatable_docs:
chunks.extend(chunker.chunk(document=document))
chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs)

logger.debug("Starting embedding")
chunks_with_embeddings = (
embedder.embed_chunks(
chunks=chunks,
)
if chunks
else []
)
chunks_with_embeddings = embedder.embed_chunks(chunks) if chunks else []

updatable_ids = [doc.id for doc in ctx.updatable_docs]

Expand Down Expand Up @@ -406,6 +399,13 @@ def build_indexing_pipeline(
tokenizer=embedder.embedding_model.tokenizer,
enable_multipass=multipass,
enable_large_chunks=enable_large_chunks,
# after every doc, update status in case there are a bunch of
# really long docs
heartbeat=IndexingHeartbeat(
index_attempt_id=attempt_id, db_session=db_session, freq=1
)
if attempt_id
else None,
)

return partial(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.models import SearchSettings
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.utils.logger import setup_logger
Expand Down Expand Up @@ -95,6 +96,7 @@ def __init__(
api_url: str | None,
provider_type: EmbeddingProvider | None,
retrim_content: bool = False,
heartbeat: Heartbeat | None = None,
) -> None:
self.api_key = api_key
self.provider_type = provider_type
Expand All @@ -107,6 +109,7 @@ def __init__(
self.tokenizer = get_tokenizer(
model_name=model_name, provider_type=provider_type
)
self.heartbeat = heartbeat

model_server_url = build_model_server_url(server_host, server_port)
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
Expand Down Expand Up @@ -166,6 +169,9 @@ def _batch_encode_texts(

response = self._make_model_server_request(embed_request)
embeddings.extend(response.embeddings)

if self.heartbeat:
self.heartbeat.heartbeat()
return embeddings

def encode(
Expand Down
18 changes: 18 additions & 0 deletions backend/tests/unit/danswer/indexing/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import Any

import pytest

from danswer.indexing.indexing_heartbeat import Heartbeat


class MockHeartbeat(Heartbeat):
def __init__(self) -> None:
self.call_count = 0

def heartbeat(self, metadata: Any = None) -> None:
self.call_count += 1


@pytest.fixture
def mock_heartbeat() -> MockHeartbeat:
return MockHeartbeat()
Loading

0 comments on commit 1dfb2c7

Please sign in to comment.