Commit f3aacbc: finish rebase

pablonyx committed Sep 27, 2024
1 parent 4896ad9 commit f3aacbc
Showing 18 changed files with 80 additions and 65 deletions.
16 changes: 0 additions & 16 deletions backend/alembic/env.py
@@ -1,7 +1,6 @@
 import asyncio
 from logging.config import fileConfig

-from typing import Tuple
 from alembic import context
 from danswer.db.engine import build_connection_string
 from danswer.db.models import Base
@@ -40,21 +39,6 @@ def get_schema_options() -> str:

 EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}

-
-def include_object(
-    object: SchemaItem,
-    name: str,
-    type_: str,
-    reflected: bool,
-    compare_to: SchemaItem | None,
-) -> bool:
-    if type_ == "table" and name in EXCLUDE_TABLES:
-        return False
-    return True
-
-EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
-
-
 def include_object(
     object: SchemaItem,
     name: str,
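
For context, the deduplicated include_object hook is Alembic's standard mechanism for keeping tables out of autogenerate. A minimal sketch of how such a hook is typically wired up in an env.py, with the configure() call below being illustrative (placeholder URL, metadata) rather than code from this diff:

    from alembic import context
    from sqlalchemy.schema import SchemaItem

    EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}

    def include_object(
        object: SchemaItem,
        name: str,
        type_: str,
        reflected: bool,
        compare_to: SchemaItem | None,
    ) -> bool:
        # Keep Celery's kombu tables out of autogenerated migrations
        if type_ == "table" and name in EXCLUDE_TABLES:
            return False
        return True

    def run_migrations_offline() -> None:
        # Illustrative wiring: the hook is passed via include_object
        context.configure(
            url="postgresql://localhost/danswer",  # placeholder URL
            target_metadata=None,  # Base.metadata in the real env.py
            include_object=include_object,
        )
        with context.begin_transaction():
            context.run_migrations()
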
1 change: 0 additions & 1 deletion backend/danswer/background/celery/celery_app.py
@@ -1,5 +1,4 @@
-
 from danswer.configs.app_configs import MULTI_TENANT
 from danswer.background.update import get_all_tenant_ids
 import logging
 import time
8 changes: 5 additions & 3 deletions backend/danswer/background/celery/tasks/pruning/tasks.py
@@ -27,11 +27,11 @@
     name="check_for_prune_task",
     soft_time_limit=JOB_TIMEOUT,
 )
-def check_for_prune_task() -> None:
+def check_for_prune_task(tenant_id: str | None) -> None:
     """Runs periodically to check if any prune tasks should be run and adds them
     to the queue"""

-    with Session(get_sqlalchemy_engine()) as db_session:
+    with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
         all_cc_pairs = get_connector_credential_pairs(db_session)

         for cc_pair in all_cc_pairs:
@@ -46,13 +46,14 @@ def check_for_prune_task() -> None:
                 kwargs=dict(
                     connector_id=cc_pair.connector.id,
                     credential_id=cc_pair.credential.id,
+                    tenant_id=tenant_id
                 )
             )


 @build_celery_task_wrapper(name_cc_prune_task)
 @celery_app.task(name="prune_documents_task", soft_time_limit=JOB_TIMEOUT)
-def prune_documents_task(connector_id: int, credential_id: int) -> None:
+def prune_documents_task(connector_id: int, credential_id: int, tenant_id: str | None) -> None:
     """connector pruning task. For a cc pair, this task pulls all document IDs from the source
     and compares those IDs to locally stored documents and deletes all locally stored IDs missing
     from the most recently pulled document ID list"""
@@ -112,6 +113,7 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None:
             connector_id=connector_id,
             credential_id=credential_id,
             document_index=document_index,
+            tenant_id=tenant_id
         )
     except Exception as e:
         task_logger.exception(
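
The docstring above fully determines the pruning step: a set difference between the source's current document IDs and the locally stored ones. A minimal sketch of that comparison, with function and parameter names that are illustrative rather than Danswer's internals:

    def prune_stale_documents(
        source_doc_ids: set[str],
        local_doc_ids: set[str],
        delete_doc_fn,
    ) -> int:
        # IDs that exist locally but no longer exist at the source
        stale_ids = local_doc_ids - source_doc_ids
        for doc_id in stale_ids:
            delete_doc_fn(doc_id)
        return len(stale_ids)

    # Example: local {"a", "b", "c"} vs. source {"a", "b"} deletes "c".
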
5 changes: 0 additions & 5 deletions backend/danswer/background/connector_deletion.py
@@ -35,11 +35,6 @@
 from danswer.document_index.interfaces import UpdateRequest
 from danswer.server.documents.models import ConnectorCredentialPairIdentifier
 from danswer.utils.logger import setup_logger
-from danswer.utils.variable_functionality import (
-    fetch_versioned_implementation_with_fallback,
-)
-from danswer.utils.variable_functionality import noop_fallback
-from danswer.configs.app_configs import DEFAULT_SCHEMA

 logger = setup_logger()

13 changes: 11 additions & 2 deletions backend/danswer/background/indexing/run_indexing.py
@@ -388,14 +388,23 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexAttempt:

     return attempt

-def run_indexing_entrypoint(index_attempt_id: int, tenant_id: str | None, is_ee: bool = False) -> None:
+def run_indexing_entrypoint(
+    index_attempt_id: int,
+    tenant_id: str | None,
+    connector_credential_pair_id: int,
+    is_ee: bool = False,
+) -> None:
+
+
     try:
         if is_ee:
             global_version.set_ee()

         # set the indexing attempt ID so that all log messages from this process
         # will have it added as a prefix
-        IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)
+        IndexAttemptSingleton.set_cc_and_index_id(
+            index_attempt_id, connector_credential_pair_id
+        )

         with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
             attempt = _prepare_index_attempt(db_session, index_attempt_id)
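
set_cc_and_index_id stores both IDs process-wide so every log line from the indexing process can carry them as a prefix. A sketch of that pattern using a logging.Filter; the classes below illustrate the technique and are not Danswer's IndexAttemptSingleton:

    import logging

    class AttemptContext:
        """Illustrative process-wide holder for the two IDs."""
        index_attempt_id: int | None = None
        cc_pair_id: int | None = None

        @classmethod
        def set_cc_and_index_id(cls, index_attempt_id: int, cc_pair_id: int) -> None:
            cls.index_attempt_id = index_attempt_id
            cls.cc_pair_id = cc_pair_id

    class AttemptPrefixFilter(logging.Filter):
        def filter(self, record: logging.LogRecord) -> bool:
            # Prefix every message emitted through this handler with the context
            if AttemptContext.index_attempt_id is not None:
                record.msg = (
                    f"[Attempt {AttemptContext.index_attempt_id}"
                    f"/CC Pair {AttemptContext.cc_pair_id}] {record.msg}"
                )
            return True

    handler = logging.StreamHandler()
    handler.addFilter(AttemptPrefixFilter())
    logging.getLogger().addHandler(handler)
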
2 changes: 1 addition & 1 deletion backend/danswer/configs/chat_configs.py
@@ -93,4 +93,4 @@
 VESPA_SEARCHER_THREADS = int(os.environ.get("VESPA_SEARCHER_THREADS") or 2)

 # Enable in-house model for detecting connector-based filtering in queries
-ENABLE_CONNECTOR_CLASSIFIER = os.environ.get("ENABLE_CONNECTOR_CLASSIFIER", False)
+ENABLE_CONNECTOR_CLASSIFIER = os.environ.get("ENABLE_CONNECTOR_CLASSIFIER", False)
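
One caveat worth flagging on the line above: os.environ.get("ENABLE_CONNECTOR_CLASSIFIER", False) returns the raw string whenever the variable is set, so ENABLE_CONNECTOR_CLASSIFIER=false still evaluates truthy. A stricter parse would look like the following (a suggestion, not part of this commit):

    import os

    # Environment variables are strings; compare explicitly so that
    # ENABLE_CONNECTOR_CLASSIFIER=false actually disables the feature.
    ENABLE_CONNECTOR_CLASSIFIER = (
        os.environ.get("ENABLE_CONNECTOR_CLASSIFIER", "").lower() == "true"
    )
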
2 changes: 1 addition & 1 deletion backend/danswer/db/engine.py
@@ -206,7 +206,7 @@ def get_sqlalchemy_engine(*, schema: str | None = DEFAULT_SCHEMA) -> Engine:
     global _engines
     if schema not in _engines:
         connection_string = build_connection_string(
-            db_api=SYNC_DB_API, app_name=f"{POSTGRES_APP_NAME}_{schema}_sync"
+            db_api=SYNC_DB_API
         )
         _engines[schema] = create_engine(
             connection_string,
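
For context, get_sqlalchemy_engine keeps one engine per schema in a module-level cache; the change above only drops the schema-specific app_name from the connection string. A minimal sketch of the per-schema caching pattern, assuming the schema is applied via PostgreSQL's search_path on connect (the event hook is an assumption, not shown in this diff):

    from sqlalchemy import create_engine, event
    from sqlalchemy.engine import Engine

    _engines: dict[str, Engine] = {}

    def get_engine_for_schema(connection_string: str, schema: str = "public") -> Engine:
        # Lazily create and cache one engine per tenant schema
        if schema not in _engines:
            engine = create_engine(connection_string, pool_size=5, max_overflow=10)

            @event.listens_for(engine, "connect")
            def _set_search_path(dbapi_conn, connection_record):
                # Every new DBAPI connection is pinned to this tenant's schema
                with dbapi_conn.cursor() as cursor:
                    cursor.execute(f'SET search_path TO "{schema}"')

            _engines[schema] = engine
        return _engines[schema]
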
7 changes: 1 addition & 6 deletions backend/danswer/db_setup.py
@@ -1,9 +1,7 @@
-
-from danswer.llm.llm_initialization import load_llm_providers
 from danswer.db.connector import create_initial_default_connector
 from danswer.db.connector_credential_pair import associate_default_cc_pair
 from danswer.db.credentials import create_initial_public_credential
-from danswer.db.standard_answer import create_initial_default_standard_answer_category
+from ee.danswer.db.standard_answer import create_initial_default_standard_answer_category
 from danswer.db.persona import delete_old_default_personas
 from danswer.chat.load_yamls import load_chat_yamls
 from danswer.tools.built_in_tools import auto_add_search_tool_to_personas
@@ -23,9 +21,6 @@ def setup_postgres(db_session: Session) -> None:
     logger.notice("Verifying default standard answer category exists.")
     create_initial_default_standard_answer_category(db_session)

-    logger.notice("Loading LLM providers from env variables")
-    load_llm_providers(db_session)
-
     logger.notice("Loading default Prompts and Personas")
     delete_old_default_personas(db_session)
     load_chat_yamls(db_session)
13 changes: 11 additions & 2 deletions backend/danswer/document_index/factory.py
@@ -14,13 +14,22 @@ def get_default_document_index(
     Secondary index is for when both the currently used index and the upcoming
     index both need to be updated, updates are applied to both indices"""
     # Currently only supporting Vespa

+    indices = [primary_index_name] if primary_index_name is not None else indices
+    if not indices:
+        raise ValueError("No indices provided")
+
     return VespaIndex(
+        indices=indices,
         secondary_index_name=secondary_index_name
     )

+def get_current_primary_default_document_index(db_session: Session) -> DocumentIndex:
+    """
+    TODO: Use redis to cache this or something
+    """
+    search_settings = get_current_search_settings(db_session)
+    return get_default_document_index(
+        primary_index_name=search_settings.index_name,
+        secondary_index_name=None,
+    )
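
A hedged usage sketch of the new helper: it resolves the current search settings and builds a single-index VespaIndex over [search_settings.index_name]. The session wiring below is illustrative:

    from sqlalchemy.orm import Session

    # Illustrative call site; engine/session setup mirrors the rest of the codebase.
    with Session(get_sqlalchemy_engine(schema=None)) as db_session:
        document_index = get_current_primary_default_document_index(db_session)
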
2 changes: 1 addition & 1 deletion backend/danswer/document_index/vespa/index.py
@@ -129,7 +129,7 @@ def ensure_indices_exist(
     index_embedding_dim: int | None = None,
     secondary_index_embedding_dim: int | None = None
 ) -> None:

     if embedding_dims is None:
         if index_embedding_dim is not None:
             embedding_dims = [index_embedding_dim]
60 changes: 44 additions & 16 deletions backend/danswer/main.py
@@ -1,5 +1,5 @@
 from danswer.utils.gpu_utils import gpu_status_request
-
+from danswer.document_index.vespa.index import VespaIndex
 import time
 import traceback
 from collections.abc import AsyncGenerator
@@ -21,12 +21,14 @@
 from danswer.document_index.interfaces import DocumentIndex
 from danswer.configs.app_configs import MULTI_TENANT
 from danswer import __version__
+from sqlalchemy.orm import Session
 from danswer.auth.schemas import UserCreate
 from danswer.auth.schemas import UserRead
 from danswer.auth.schemas import UserUpdate
 from danswer.auth.users import auth_backend
 from danswer.auth.users import fastapi_users
-from sqlalchemy.orm import Session
+from danswer.server.settings.store import load_settings
+from danswer.server.settings.store import store_settings
 from danswer.indexing.models import IndexingSetting
 from danswer.configs.app_configs import APP_API_PREFIX
 from danswer.configs.app_configs import APP_HOST
@@ -43,9 +45,6 @@
 from danswer.configs.constants import KV_REINDEX_KEY
 from danswer.configs.constants import KV_SEARCH_SETTINGS
 from danswer.configs.constants import POSTGRES_WEB_APP_NAME
-from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
-from danswer.configs.model_configs import GEN_AI_API_KEY
-from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.db.connector import check_connectors_exist
 from danswer.db.connector_credential_pair import get_connector_credential_pairs
 from danswer.db.connector_credential_pair import resync_cc_pair
@@ -93,7 +92,6 @@
 from danswer.server.manage.get_state import router as state_router
 from danswer.server.manage.llm.api import admin_router as llm_admin_router
 from danswer.server.manage.llm.api import basic_router as llm_router
-from danswer.server.manage.llm.models import LLMProviderUpsertRequest
 from danswer.server.manage.search_settings import router as search_settings_router
 from danswer.server.manage.slack_bot import router as slack_bot_management_router
 from danswer.server.manage.users import router as user_router
@@ -105,8 +103,6 @@
 from danswer.server.query_and_chat.query_backend import basic_router as query_router
 from danswer.server.settings.api import admin_router as settings_admin_router
 from danswer.server.settings.api import basic_router as settings_router
-from danswer.server.settings.store import load_settings
-from danswer.server.settings.store import store_settings
 from danswer.server.token_rate_limits.api import (
     router as token_rate_limit_settings_router,
 )
@@ -244,25 +240,23 @@ def setup_vespa(
     document_index: DocumentIndex,
     embedding_dims: list[int],
     secondary_embedding_dim: int | None = None
-) -> None:
+) -> bool:
     # Vespa startup is a bit slow, so give it a few seconds
-    WAIT_SECONDS = 5
     wait_time = 5
     VESPA_ATTEMPTS = 5
     for x in range(VESPA_ATTEMPTS):
         try:
             logger.notice(f"Setting up Vespa (attempt {x+1}/{VESPA_ATTEMPTS})...")
             document_index.ensure_indices_exist(
-                index_embedding_dim=index_setting.model_dim,
-                secondary_index_embedding_dim=secondary_index_setting.model_dim
-                if secondary_index_setting
-                else None,
+                embedding_dims=embedding_dims,
+                secondary_index_embedding_dim=secondary_embedding_dim
             )
-            break
+            return True
         except Exception:
             logger.notice(f"Waiting on Vespa, retrying in {wait_time} seconds...")
             time.sleep(wait_time)
+            logger.exception("Error ensuring multi-tenant indices exist")

+    return False

 @asynccontextmanager
 async def lifespan(app: FastAPI) -> AsyncGenerator:
@@ -395,6 +389,40 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
     yield


+def update_default_multipass_indexing(db_session: Session) -> None:
+    docs_exist = check_docs_exist(db_session)
+    connectors_exist = check_connectors_exist(db_session)
+    logger.debug(f"Docs exist: {docs_exist}, Connectors exist: {connectors_exist}")
+
+    if not docs_exist and not connectors_exist:
+        logger.info(
+            "No existing docs or connectors found. Checking GPU availability for multipass indexing."
+        )
+        gpu_available = gpu_status_request()
+        logger.info(f"GPU available: {gpu_available}")
+
+        current_settings = get_current_search_settings(db_session)
+
+        logger.notice(f"Updating multipass indexing setting to: {gpu_available}")
+        updated_settings = SavedSearchSettings.from_db_model(current_settings)
+        # Enable multipass indexing if GPU is available or if using a cloud provider
+        updated_settings.multipass_indexing = (
+            gpu_available or current_settings.cloud_provider is not None
+        )
+        update_current_search_settings(db_session, updated_settings)
+
+        # Update settings with GPU availability
+        settings = load_settings()
+        settings.gpu_enabled = gpu_available
+        store_settings(settings)
+        logger.notice(f"Updated settings with GPU availability: {gpu_available}")
+
+    else:
+        logger.debug(
+            "Existing docs or connectors found. Skipping multipass indexing update."
+        )
+
+
 def log_http_error(_: Request, exc: Exception) -> JSONResponse:
     status_code = getattr(exc, "status_code", 500)
     if status_code >= 400:
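
Since setup_vespa now reports failure through its return value rather than raising out of the retry loop, callers have to check it. A sketch of the expected call pattern, where the index name and embedding dimension are placeholders:

    # Hypothetical caller; argument values are illustrative only.
    document_index = VespaIndex(indices=["danswer_chunk"], secondary_index_name=None)

    if not setup_vespa(document_index, embedding_dims=[768], secondary_embedding_dim=None):
        # All VESPA_ATTEMPTS retries failed; abort startup explicitly.
        raise RuntimeError("Vespa setup did not complete after retries")
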
2 changes: 1 addition & 1 deletion backend/danswer/one_shot_answer/answer_question.py
@@ -219,7 +219,7 @@ def stream_answer_objects(
         question=query_msg.message,
         answer_style_config=answer_config,
         prompt_config=PromptConfig.from_model(prompt),
-        llm=get_main_llm_from_tuple(get_llms_for_persona(persona=chat_session.persona, db_session=db_session))
+        llm=get_main_llm_from_tuple(get_llms_for_persona(persona=chat_session.persona, db_session=db_session)),
         single_message_history=history_str,
         tools=[search_tool] if search_tool else [],
         force_use_tool=(
1 change: 0 additions & 1 deletion backend/danswer/one_shot_answer/models.py
@@ -70,7 +70,6 @@ class DirectQARequest(ChunkContext):

     messages: list[ThreadMessage]
     prompt_id: int | None = None
-    persona_id: int
     multilingual_query_expansion: list[str] | None = None
     retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
     rerank_settings: RerankingDetails | None = None
1 change: 0 additions & 1 deletion backend/danswer/server/manage/administrative.py
@@ -40,7 +40,6 @@
 from danswer.server.manage.models import HiddenUpdateRequest
 from danswer.server.models import StatusResponse
 from danswer.utils.logger import setup_logger
-from danswer.db.engine import current_tenant_id

 router = APIRouter(prefix="/manage")
 logger = setup_logger()
6 changes: 1 addition & 5 deletions backend/danswer/server/tenants/provisioning.py
@@ -12,11 +12,10 @@
 from danswer.db.swap_index import check_index_swap

 from sqlalchemy.orm import Session
-from danswer.llm.llm_initialization import load_llm_providers
 from danswer.db.connector import create_initial_default_connector
 from danswer.db.connector_credential_pair import associate_default_cc_pair
 from danswer.db.credentials import create_initial_public_credential
-from danswer.db.standard_answer import create_initial_default_standard_answer_category
+from ee.danswer.db.standard_answer import create_initial_default_standard_answer_category
 from danswer.db.persona import delete_old_default_personas
 from danswer.chat.load_yamls import load_chat_yamls
 from danswer.tools.built_in_tools import auto_add_search_tool_to_personas
@@ -132,9 +131,6 @@ def setup_postgres_and_initial_settings(db_session: Session) -> None:
     logger.notice("Verifying default standard answer category exists.")
     create_initial_default_standard_answer_category(db_session)

-    logger.notice("Loading LLM providers from env variables")
-    load_llm_providers(db_session)
-
     logger.notice("Loading default Prompts and Personas")
     delete_old_default_personas(db_session)
     load_chat_yamls(db_session)
3 changes: 2 additions & 1 deletion backend/scripts/query_time_check/seed_dummy_docs.py
@@ -94,6 +94,7 @@ def generate_dummy_chunk(
         ),
         document_sets={document_set for document_set in document_set_names},
         boost=random.randint(-1, 1),
+        tenant_id=None
     )


@@ -124,7 +125,7 @@ def seed_dummy_docs(
     index_name = search_settings.index_name
     embedding_dim = search_settings.model_dim

-    vespa_index = VespaIndex(index_name=index_name, secondary_index_name=None)
+    vespa_index = VespaIndex(indices=[index_name], secondary_index_name=None)
     print(index_name)

     all_chunks = []
2 changes: 1 addition & 1 deletion backend/scripts/query_time_check/test_query_times.py
@@ -64,7 +64,7 @@ def test_hybrid_retrieval_times(
     search_settings = get_current_search_settings(db_session)
     index_name = search_settings.index_name

-    vespa_index = VespaIndex(index_name=index_name, secondary_index_name=None)
+    vespa_index = VespaIndex(indices=[index_name], secondary_index_name=None)

     # Generate random queries
     queries = [f"Random Query {i}" for i in range(number_of_queries)]
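
Both scripts make the same mechanical migration: the old single-index constructor becomes the list form, with single-index callers wrapping the name in a list. Sketched for comparison (the multi-index call illustrates what the new signature permits; the index names there are invented):

    # Old constructor (pre-commit):
    #   vespa_index = VespaIndex(index_name=index_name, secondary_index_name=None)

    # New constructor: a list of indices, even for a single index.
    vespa_index = VespaIndex(indices=[index_name], secondary_index_name=None)

    # The list form also permits multiple indices at once (names invented here):
    multi_index = VespaIndex(
        indices=["danswer_chunk_alpha", "danswer_chunk_beta"],
        secondary_index_name=None,
    )
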
1 change: 0 additions & 1 deletion backend/tests/integration/common_utils/reset.py
@@ -17,7 +17,6 @@
 from danswer.db.swap_index import check_index_swap
 from danswer.document_index.vespa.index import DOCUMENT_ID_ENDPOINT
 from danswer.document_index.vespa.index import VespaIndex
-from danswer.indexing.models import IndexingSetting
 from danswer.db_setup import setup_postgres
 from danswer.main import setup_vespa
 from danswer.utils.logger import setup_logger
