diff --git a/backend/alembic/env.py b/backend/alembic/env.py
index 8393985cc5f..e73db597ac5 100644
--- a/backend/alembic/env.py
+++ b/backend/alembic/env.py
@@ -26,19 +26,22 @@
 # target_metadata = mymodel.Base.metadata
 target_metadata = [Base.metadata, ResultModelBase.metadata]
 
+
 def get_schema_options() -> str:
     x_args_raw = context.get_x_argument()
     x_args = {}
     for arg in x_args_raw:
-        for pair in arg.split(','):
-            if '=' in pair:
-                key, value = pair.split('=', 1)
+        for pair in arg.split(","):
+            if "=" in pair:
+                key, value = pair.split("=", 1)
                 x_args[key] = value
-    schema_name = x_args.get('schema', 'public')
+    schema_name = x_args.get("schema", "public")
     return schema_name
 
+
 EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
 
+
 def include_object(
     object: SchemaItem,
     name: str,
@@ -56,10 +59,9 @@ def run_migrations_offline() -> None:
     url = build_connection_string()
     schema = get_schema_options()
 
-
     context.configure(
         url=url,
-        target_metadata=target_metadata, # type: ignore
+        target_metadata=target_metadata,  # type: ignore
         literal_binds=True,
         dialect_opts={"paramstyle": "named"},
         version_table_schema=schema,
@@ -69,17 +71,18 @@ def run_migrations_offline() -> None:
     with context.begin_transaction():
         context.run_migrations()
 
+
 def do_run_migrations(connection: Connection) -> None:
     schema = get_schema_options()
 
     connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema}"'))
-    connection.execute(text('COMMIT'))
+    connection.execute(text("COMMIT"))
 
     connection.execute(text(f'SET search_path TO "{schema}"'))
 
     context.configure(
         connection=connection,
-        target_metadata=target_metadata, # type: ignore
+        target_metadata=target_metadata,  # type: ignore
         version_table_schema=schema,
         include_schemas=True,
         compare_type=True,
@@ -89,6 +92,7 @@ def do_run_migrations(connection: Connection) -> None:
     with context.begin_transaction():
         context.run_migrations()
 
+
 async def run_async_migrations() -> None:
     print("Running async migrations")
     """Run migrations in 'online' mode."""
@@ -102,10 +106,12 @@ async def run_async_migrations() -> None:
 
     await connectable.dispose()
 
+
 def run_migrations_online() -> None:
     """Run migrations in 'online' mode."""
     asyncio.run(run_async_migrations())
 
+
 if context.is_offline_mode():
     run_migrations_offline()
 else:
diff --git a/backend/alembic/versions/da4c21c69164_chosen_assistants_changed_to_jsonb.py b/backend/alembic/versions/da4c21c69164_chosen_assistants_changed_to_jsonb.py
index 1f10430d2c0..c6d30519219 100644
--- a/backend/alembic/versions/da4c21c69164_chosen_assistants_changed_to_jsonb.py
+++ b/backend/alembic/versions/da4c21c69164_chosen_assistants_changed_to_jsonb.py
@@ -24,11 +24,11 @@ def upgrade() -> None:
         sa.text('SELECT id, chosen_assistants FROM "user"')
     )
     op.drop_column(
-        'user',
+        "user",
         "chosen_assistants",
     )
     op.add_column(
-        'user',
+        "user",
         sa.Column(
             "chosen_assistants",
             postgresql.JSONB(astext_type=sa.Text()),
@@ -38,7 +38,7 @@ def upgrade() -> None:
     for id, chosen_assistants in existing_ids_and_chosen_assistants:
         conn.execute(
             sa.text(
-                'UPDATE user SET chosen_assistants = :chosen_assistants WHERE id = :id'
+                "UPDATE user SET chosen_assistants = :chosen_assistants WHERE id = :id"
            ),
             {"chosen_assistants": json.dumps(chosen_assistants), "id": id},
         )
@@ -47,20 +47,20 @@ def upgrade() -> None:
 def downgrade() -> None:
     conn = op.get_bind()
     existing_ids_and_chosen_assistants = conn.execute(
-        sa.text('SELECT id, chosen_assistants FROM user')
+        sa.text("SELECT id, chosen_assistants FROM user")
     )
     op.drop_column(
-        'user',
+        "user",
         "chosen_assistants",
     )
     op.add_column(
-        'user',
+        "user",
         sa.Column("chosen_assistants", postgresql.ARRAY(sa.Integer()), nullable=True),
     )
     for id, chosen_assistants in existing_ids_and_chosen_assistants:
         conn.execute(
             sa.text(
-                'UPDATE user SET chosen_assistants = :chosen_assistants WHERE id = :id'
+                "UPDATE user SET chosen_assistants = :chosen_assistants WHERE id = :id"
             ),
             {"chosen_assistants": chosen_assistants, "id": id},
         )
diff --git a/backend/alembic/versions/dbaa756c2ccf_embedding_models.py b/backend/alembic/versions/dbaa756c2ccf_embedding_models.py
index f1b0a7c45f4..2250836d561 100644
--- a/backend/alembic/versions/dbaa756c2ccf_embedding_models.py
+++ b/backend/alembic/versions/dbaa756c2ccf_embedding_models.py
@@ -16,6 +16,7 @@
 branch_labels = None
 depends_on = None
 
+
 def upgrade() -> None:
     op.create_table(
         "embedding_model",
@@ -68,24 +69,22 @@ def upgrade() -> None:
         column("query_prefix", String),
         column("passage_prefix", String),
         column("index_name", String),
-        column("status", sa.Enum(IndexModelStatus, name="indexmodelstatus", native=False)),
+        column(
+            "status", sa.Enum(IndexModelStatus, name="indexmodelstatus", native=False)
+        ),
     )
 
     # Insert the old embedding model
     op.bulk_insert(
         EmbeddingModel,
-        [
-            old_embedding_model
-        ],
+        [old_embedding_model],
     )
 
     # If the user has not overridden the embedding model, insert the new default model
     if not user_overridden_embedding_model:
         op.bulk_insert(
             EmbeddingModel,
-            [
-                new_embedding_model
-            ],
+            [new_embedding_model],
         )
 
     op.add_column(
@@ -123,6 +122,7 @@ def upgrade() -> None:
         postgresql_where=sa.text("status = 'FUTURE'"),
     )
 
+
 def downgrade() -> None:
     op.drop_constraint(
         "index_attempt__embedding_model_fk", "index_attempt", type_="foreignkey"
diff --git a/backend/alembic_utils.py b/backend/alembic_utils.py
index 980186d9b17..85d63b34944 100644
--- a/backend/alembic_utils.py
+++ b/backend/alembic_utils.py
@@ -7,6 +7,7 @@
 
 ENCRYPTION_KEY_SECRET = os.environ.get("ENCRYPTION_KEY_SECRET") or ""
 
+
 def _get_trimmed_key(key: str) -> bytes:
     encoded_key = key.encode()
     key_length = len(encoded_key)
@@ -20,6 +21,7 @@ def _get_trimmed_key(key: str) -> bytes:
 
     return encoded_key
 
+
 def encrypt_string(input_str: str) -> bytes:
     if not ENCRYPTION_KEY_SECRET:
         return input_str.encode()
@@ -35,8 +37,10 @@ def encrypt_string(input_str: str) -> bytes:
 
     return iv + encrypted_data
 
+
 NUM_POSTPROCESSED_RESULTS = 20
 
+
 class IndexModelStatus(str, Enum):
     PAST = "PAST"
     PRESENT = "PRESENT"
@@ -56,7 +60,6 @@ class SearchType(str, Enum):
     SEMANTIC = "semantic"
 
 
-
 class DocumentSource(str, Enum):
     # Special case, document passed in via Danswer APIs without specifying a source type
     INGESTION_API = "ingestion_api"
diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py
index df972bbd65e..9bf23c1477d 100644
--- a/backend/danswer/auth/users.py
+++ b/backend/danswer/auth/users.py
@@ -182,6 +182,7 @@ def send_user_verification_email(
         s.login(SMTP_USER, SMTP_PASS)
         s.send_message(msg)
 
+
 def verify_sso_token(token: str) -> dict:
     try:
         payload = jwt.decode(token, "SSO_SECRET_KEY", algorithms=["HS256"])
@@ -232,14 +233,13 @@ async def create_user_session(user: User, tenant_id: str) -> str:
         "sub": str(user.id),
         "email": user.email,
         "tenant_id": tenant_id,
-        "exp": datetime.utcnow() + timedelta(seconds=SESSION_EXPIRE_TIME_SECONDS)
+        "exp": datetime.utcnow() + timedelta(seconds=SESSION_EXPIRE_TIME_SECONDS),
     }
     token = jwt.encode(payload, SECRET_JWT_KEY, algorithm="HS256")
     return token
 
 
-
 class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
     reset_password_token_secret = USER_AUTH_SECRET
     verification_token_secret = USER_AUTH_SECRET
diff --git a/backend/danswer/background/celery/celery_app.py b/backend/danswer/background/celery/celery_app.py
index 2367f498b89..fb285c69bc7 100644
--- a/backend/danswer/background/celery/celery_app.py
+++ b/backend/danswer/background/celery/celery_app.py
@@ -1,4 +1,3 @@
-
 from danswer.background.update import get_all_tenant_ids
 import logging
 import time
@@ -476,5 +475,5 @@ def schedule_tenant_tasks() -> None:
         }
     )
 
 
-schedule_tenant_tasks()
+schedule_tenant_tasks()
diff --git a/backend/danswer/background/celery/tasks/pruning/tasks.py b/backend/danswer/background/celery/tasks/pruning/tasks.py
index 076e93d07bb..1f2a2eec84e 100644
--- a/backend/danswer/background/celery/tasks/pruning/tasks.py
+++ b/backend/danswer/background/celery/tasks/pruning/tasks.py
@@ -46,14 +46,16 @@ def check_for_prune_task(tenant_id: str | None) -> None:
             kwargs=dict(
                 connector_id=cc_pair.connector.id,
                 credential_id=cc_pair.credential.id,
-                tenant_id=tenant_id
+                tenant_id=tenant_id,
             )
         )
 
 
 @build_celery_task_wrapper(name_cc_prune_task)
 @celery_app.task(name="prune_documents_task", soft_time_limit=JOB_TIMEOUT)
-def prune_documents_task(connector_id: int, credential_id: int, tenant_id: str | None) -> None:
+def prune_documents_task(
+    connector_id: int, credential_id: int, tenant_id: str | None
+) -> None:
     """connector pruning task. For a cc pair, this task pulls all document IDs from the source
     and compares those IDs to locally stored documents and deletes all locally stored IDs missing
     from the most recently pulled document ID list"""
@@ -113,7 +115,7 @@ def prune_documents_task(connector_id: int, credential_id: int, tenant_id: str |
             connector_id=connector_id,
             credential_id=credential_id,
             document_index=document_index,
-            tenant_id=tenant_id
+            tenant_id=tenant_id,
         )
     except Exception as e:
         task_logger.exception(
diff --git a/backend/danswer/background/connector_deletion.py b/backend/danswer/background/connector_deletion.py
index 4a2a7699f0a..ab42f83a005 100644
--- a/backend/danswer/background/connector_deletion.py
+++ b/backend/danswer/background/connector_deletion.py
@@ -49,7 +49,7 @@ def delete_connector_credential_pair_batch(
     connector_id: int,
     credential_id: int,
     document_index: DocumentIndex,
-    tenant_id: str | None
+    tenant_id: str | None,
 ) -> None:
     """
     Removes a batch of documents ids from a cc-pair.
If no other cc-pair uses a document anymore @@ -134,7 +134,11 @@ def delete_connector_credential_pair_batch( max_retries=3, ) def document_by_cc_pair_cleanup_task( - self: Task, document_id: str, connector_id: int, credential_id: int, tenant_id: str | None + self: Task, + document_id: str, + connector_id: int, + credential_id: int, + tenant_id: str | None, ) -> bool: task_logger.info(f"document_id={document_id}") diff --git a/backend/danswer/background/indexing/run_indexing.py b/backend/danswer/background/indexing/run_indexing.py index 1fa233a8a8c..c0986f7f210 100644 --- a/backend/danswer/background/indexing/run_indexing.py +++ b/backend/danswer/background/indexing/run_indexing.py @@ -1,4 +1,3 @@ - import time import traceback from datetime import datetime @@ -44,7 +43,7 @@ def _get_connector_runner( attempt: IndexAttempt, start_time: datetime, end_time: datetime, - tenant_id: str | None + tenant_id: str | None, ) -> ConnectorRunner: """ NOTE: `start_time` and `end_time` are only used for poll connectors @@ -81,9 +80,7 @@ def _get_connector_runner( def _run_indexing( - db_session: Session, - index_attempt: IndexAttempt, - tenant_id: str | None + db_session: Session, index_attempt: IndexAttempt, tenant_id: str | None ) -> None: """ 1. Get documents which are either new or updated from specified application @@ -104,7 +101,6 @@ def _run_indexing( primary_index_name=index_name, secondary_index_name=None ) - embedding_model = DefaultIndexingEmbedder.from_db_search_settings( search_settings=search_settings ) @@ -173,7 +169,7 @@ def _run_indexing( attempt=index_attempt, start_time=window_start, end_time=window_end, - tenant_id=tenant_id + tenant_id=tenant_id, ) all_connector_doc_ids: set[str] = set() @@ -201,7 +197,9 @@ def _run_indexing( db_session.refresh(index_attempt) if index_attempt.status != IndexingStatus.IN_PROGRESS: # Likely due to user manually disabling it or model swap - raise RuntimeError(f"Index Attempt was canceled, status is {index_attempt.status}") + raise RuntimeError( + f"Index Attempt was canceled, status is {index_attempt.status}" + ) batch_description = [] for doc in doc_batch: @@ -388,14 +386,13 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA return attempt + def run_indexing_entrypoint( index_attempt_id: int, tenant_id: str | None, connector_credential_pair_id: int, is_ee: bool = False, ) -> None: - - try: if is_ee: global_version.set_ee() @@ -410,8 +407,10 @@ def run_indexing_entrypoint( attempt = _prepare_index_attempt(db_session, index_attempt_id) logger.info( - f"Indexing starting for tenant {tenant_id}: " if tenant_id is not None else "" + - f"connector='{attempt.connector_credential_pair.connector.name}' " + f"Indexing starting for tenant {tenant_id}: " + if tenant_id is not None + else "" + + f"connector='{attempt.connector_credential_pair.connector.name}' " f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' " f"credentials='{attempt.connector_credential_pair.connector_id}'" ) @@ -419,10 +418,14 @@ def run_indexing_entrypoint( _run_indexing(db_session, attempt, tenant_id) logger.info( - f"Indexing finished for tenant {tenant_id}: " if tenant_id is not None else "" + - f"connector='{attempt.connector_credential_pair.connector.name}' " + f"Indexing finished for tenant {tenant_id}: " + if tenant_id is not None + else "" + + f"connector='{attempt.connector_credential_pair.connector.name}' " f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' " 
f"credentials='{attempt.connector_credential_pair.connector_id}'" ) except Exception as e: - logger.exception(f"Indexing job with ID '{index_attempt_id}' for tenant {tenant_id} failed due to {e}") + logger.exception( + f"Indexing job with ID '{index_attempt_id}' for tenant {tenant_id} failed due to {e}" + ) diff --git a/backend/danswer/background/task_utils.py b/backend/danswer/background/task_utils.py index d0035e51198..904131d422e 100644 --- a/backend/danswer/background/task_utils.py +++ b/backend/danswer/background/task_utils.py @@ -14,7 +14,9 @@ from danswer.db.tasks import register_task -def name_cc_cleanup_task(connector_id: int, credential_id: int, tenant_id: str | None = None) -> str: +def name_cc_cleanup_task( + connector_id: int, credential_id: int, tenant_id: str | None = None +) -> str: task_name = f"cleanup_connector_credential_pair_{connector_id}_{credential_id}" if tenant_id is not None: task_name += f"_{tenant_id}" diff --git a/backend/danswer/background/update.py b/backend/danswer/background/update.py index 7ce4eff5c91..50c2c4286e2 100755 --- a/backend/danswer/background/update.py +++ b/backend/danswer/background/update.py @@ -150,8 +150,9 @@ def _mark_run_failed( """Main funcs""" -def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob], tenant_id: str | None) -> None: - +def create_indexing_jobs( + existing_jobs: dict[int, Future | SimpleJob], tenant_id: str | None +) -> None: """Creates new indexing jobs for each connector / credential pair which is: 1. Enabled 2. `refresh_frequency` time has passed since the last indexing run for this pair @@ -300,7 +301,6 @@ def kickoff_indexing_jobs( secondary_client: Client | SimpleJobClient, tenant_id: str | None, ) -> dict[int, Future | SimpleJob]: - existing_jobs_copy = existing_jobs.copy() engine = get_sqlalchemy_engine(schema=tenant_id) @@ -408,14 +408,22 @@ def kickoff_indexing_jobs( def get_all_tenant_ids() -> list[str] | list[None]: if not MULTI_TENANT: return [None] - with Session(get_sqlalchemy_engine(schema='public')) as session: - result = session.execute(text(""" + with Session(get_sqlalchemy_engine(schema="public")) as session: + result = session.execute( + text( + """ SELECT schema_name FROM information_schema.schemata WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'public') - """)) + """ + ) + ) tenant_ids = [row[0] for row in result] - valid_tenants = [tenant for tenant in tenant_ids if tenant is None or not tenant.startswith('pg_')] + valid_tenants = [ + tenant + for tenant in tenant_ids + if tenant is None or not tenant.startswith("pg_") + ] return valid_tenants @@ -485,14 +493,18 @@ def update_loop( for tenant_id in tenants: try: - logger.debug(f"Processing {'index attempts' if tenant_id is None else f'tenant {tenant_id}'}") + logger.debug( + f"Processing {'index attempts' if tenant_id is None else f'tenant {tenant_id}'}" + ) engine = get_sqlalchemy_engine(schema=tenant_id) with Session(engine) as db_session: check_index_swap(db_session=db_session) if not MULTI_TENANT: search_settings = get_current_search_settings(db_session) if search_settings.provider_type is None: - logger.notice("Running a first inference to warm up embedding model") + logger.notice( + "Running a first inference to warm up embedding model" + ) embedding_model = EmbeddingModel.from_db_model( search_settings=search_settings, server_host=INDEXING_MODEL_SERVER_HOST, @@ -504,13 +516,9 @@ def update_loop( tenant_jobs = existing_jobs.get(tenant_id, {}) tenant_jobs = cleanup_indexing_jobs( - existing_jobs=tenant_jobs, - 
tenant_id=tenant_id - ) - create_indexing_jobs( - existing_jobs=tenant_jobs, - tenant_id=tenant_id + existing_jobs=tenant_jobs, tenant_id=tenant_id ) + create_indexing_jobs(existing_jobs=tenant_jobs, tenant_id=tenant_id) tenant_jobs = kickoff_indexing_jobs( existing_jobs=tenant_jobs, client=client_primary, @@ -521,7 +529,9 @@ def update_loop( existing_jobs[tenant_id] = tenant_jobs except Exception as e: - logger.exception(f"Failed to process tenant {tenant_id or 'default'}: {e}") + logger.exception( + f"Failed to process tenant {tenant_id or 'default'}: {e}" + ) except Exception as e: logger.exception(f"Failed to run update due to {e}") @@ -530,6 +540,7 @@ def update_loop( if sleep_time > 0: time.sleep(sleep_time) + def update__main() -> None: set_is_ee_based_on_env_variable() init_sqlalchemy_engine(POSTGRES_INDEXER_APP_NAME) diff --git a/backend/danswer/chat/load_yamls.py b/backend/danswer/chat/load_yamls.py index f87a213ea68..e8a19c158b2 100644 --- a/backend/danswer/chat/load_yamls.py +++ b/backend/danswer/chat/load_yamls.py @@ -18,8 +18,7 @@ def load_prompts_from_yaml( - db_session: Session, - prompts_yaml: str = PROMPTS_YAML + db_session: Session, prompts_yaml: str = PROMPTS_YAML ) -> None: with open(prompts_yaml, "r") as file: data = yaml.safe_load(file) @@ -47,7 +46,6 @@ def load_personas_from_yaml( personas_yaml: str = PERSONAS_YAML, default_chunks: float = MAX_CHUNKS_FED_TO_CHAT, ) -> None: - with open(personas_yaml, "r") as file: data = yaml.safe_load(file) @@ -100,9 +98,7 @@ def load_personas_from_yaml( llm_model_version_override = "gpt-4o" existing_persona = ( - db_session.query(Persona) - .filter(Persona.name == persona["name"]) - .first() + db_session.query(Persona).filter(Persona.name == persona["name"]).first() ) upsert_persona( @@ -135,9 +131,9 @@ def load_personas_from_yaml( db_session=db_session, ) + def load_input_prompts_from_yaml( - db_session: Session, - input_prompts_yaml: str = INPUT_PROMPT_YAML + db_session: Session, input_prompts_yaml: str = INPUT_PROMPT_YAML ) -> None: with open(input_prompts_yaml, "r") as file: data = yaml.safe_load(file) @@ -159,8 +155,6 @@ def load_input_prompts_from_yaml( ) - - def load_chat_yamls( db_session: Session, prompt_yaml: str = PROMPTS_YAML, diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 87bbeca8998..e74c3752649 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -331,7 +331,6 @@ def stream_chat_message_objects( except GenAIDisabledException: raise RuntimeError("LLM is disabled. 
Can't use chat flow without LLM.") - llm_provider = llm.config.model_provider llm_model_name = llm.config.model_name @@ -843,7 +842,6 @@ def stream_chat_message( litellm_additional_headers: dict[str, str] | None = None, is_connected: Callable[[], bool] | None = None, ) -> Iterator[str]: - objects = stream_chat_message_objects( new_msg_req=new_msg_req, user=user, diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index 72742167a3d..abaa77daebb 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -399,7 +399,7 @@ STRIPE_WEBHOOK_SECRET = os.environ.get( "STRIPE_WEBHOOK_SECRET", - "whsec_1cd766cd6bd08590aa8c46ab5c21ac32cad77c29de2e09a152a01971d6f405d3" + "whsec_1cd766cd6bd08590aa8c46ab5c21ac32cad77c29de2e09a152a01971d6f405d3", ) DEFAULT_SCHEMA = os.environ.get("DEFAULT_SCHEMA", "public") diff --git a/backend/danswer/connectors/file/connector.py b/backend/danswer/connectors/file/connector.py index 23877cfe8a6..bbf494d87ee 100644 --- a/backend/danswer/connectors/file/connector.py +++ b/backend/danswer/connectors/file/connector.py @@ -159,7 +159,7 @@ def __init__( self, file_locations: list[Path | str], batch_size: int = INDEX_BATCH_SIZE, - tenant_id: str | None = None + tenant_id: str | None = None, ) -> None: self.file_locations = [Path(file_location) for file_location in file_locations] self.batch_size = batch_size diff --git a/backend/danswer/db/engine.py b/backend/danswer/db/engine.py index 463d508bdf6..7b926518d8a 100644 --- a/backend/danswer/db/engine.py +++ b/backend/danswer/db/engine.py @@ -196,8 +196,10 @@ def build_connection_string( def init_sqlalchemy_engine(app_name: str) -> None: SqlEngine.set_app_name(app_name) + _engines: dict[str, Engine] = {} + # NOTE: this is a hack to allow for multiple postgres schemas per engine for now. 
def get_sqlalchemy_engine(*, schema: str | None = DEFAULT_SCHEMA) -> Engine: if schema is None: @@ -205,16 +207,14 @@ def get_sqlalchemy_engine(*, schema: str | None = DEFAULT_SCHEMA) -> Engine: global _engines if schema not in _engines: - connection_string = build_connection_string( - db_api=SYNC_DB_API - ) + connection_string = build_connection_string(db_api=SYNC_DB_API) _engines[schema] = create_engine( connection_string, pool_size=40, max_overflow=10, pool_pre_ping=POSTGRES_POOL_PRE_PING, pool_recycle=POSTGRES_POOL_RECYCLE, - connect_args={"options": f"-c search_path={schema}"} + connect_args={"options": f"-c search_path={schema}"}, ) return _engines[schema] @@ -237,13 +237,16 @@ def get_sqlalchemy_async_engine() -> AsyncEngine: ) return _ASYNC_ENGINE -current_tenant_id = contextvars.ContextVar( - "current_tenant_id", default=DEFAULT_SCHEMA -) + +current_tenant_id = contextvars.ContextVar("current_tenant_id", default=DEFAULT_SCHEMA) + def get_session_context_manager() -> ContextManager[Session]: tenant_id = current_tenant_id.get() - return contextlib.contextmanager(lambda: get_session(override_tenant_id=tenant_id))() + return contextlib.contextmanager( + lambda: get_session(override_tenant_id=tenant_id) + )() + def get_current_tenant_id(request: Request) -> str | None: if not MULTI_TENANT: @@ -257,7 +260,9 @@ def get_current_tenant_id(request: Request) -> str | None: payload = jwt.decode(token, SECRET_JWT_KEY, algorithms=["HS256"]) tenant_id = payload.get("tenant_id") if not tenant_id: - raise HTTPException(status_code=400, detail="Invalid token: tenant_id missing") + raise HTTPException( + status_code=400, detail="Invalid token: tenant_id missing" + ) current_tenant_id.set(tenant_id) return tenant_id except (DecodeError, InvalidTokenError): @@ -265,22 +270,29 @@ def get_current_tenant_id(request: Request) -> str | None: except Exception: raise HTTPException(status_code=500, detail="Internal server error") + def get_session( tenant_id: str = Depends(get_current_tenant_id), - override_tenant_id: str | None = None + override_tenant_id: str | None = None, ) -> Generator[Session, None, None]: if override_tenant_id: tenant_id = override_tenant_id - with Session(get_sqlalchemy_engine(schema=tenant_id), expire_on_commit=False) as session: + with Session( + get_sqlalchemy_engine(schema=tenant_id), expire_on_commit=False + ) as session: yield session -async def get_async_session(tenant_id: str | None = None) -> AsyncGenerator[AsyncSession, None]: + +async def get_async_session( + tenant_id: str | None = None, +) -> AsyncGenerator[AsyncSession, None]: async with AsyncSession( get_sqlalchemy_async_engine(), expire_on_commit=False ) as async_session: yield async_session + async def warm_up_connections( sync_connections_to_warm_up: int = 20, async_connections_to_warm_up: int = 20 ) -> None: @@ -304,6 +316,7 @@ async def warm_up_connections( for async_conn in async_connections: await async_conn.close() + def get_session_factory() -> sessionmaker[Session]: global SessionFactory if SessionFactory is None: diff --git a/backend/danswer/db/llm.py b/backend/danswer/db/llm.py index 89872394317..698081aa018 100644 --- a/backend/danswer/db/llm.py +++ b/backend/danswer/db/llm.py @@ -154,7 +154,6 @@ def fetch_embedding_provider( def fetch_default_provider(db_session: Session) -> FullLLMProvider | None: - provider_model = db_session.scalar( select(LLMProviderModel).where( LLMProviderModel.is_default_provider == True # noqa: E712 diff --git a/backend/danswer/db/search_settings.py 
b/backend/danswer/db/search_settings.py index 9217784e9fb..11edcf96be2 100644 --- a/backend/danswer/db/search_settings.py +++ b/backend/danswer/db/search_settings.py @@ -205,4 +205,3 @@ def update_search_settings_status( ) -> None: search_settings.status = new_status db_session.commit() - diff --git a/backend/danswer/db_setup.py b/backend/danswer/db_setup.py index 75054675403..0246fa56c60 100644 --- a/backend/danswer/db_setup.py +++ b/backend/danswer/db_setup.py @@ -1,7 +1,9 @@ from danswer.db.connector import create_initial_default_connector from danswer.db.connector_credential_pair import associate_default_cc_pair from danswer.db.credentials import create_initial_public_credential -from ee.danswer.db.standard_answer import create_initial_default_standard_answer_category +from ee.danswer.db.standard_answer import ( + create_initial_default_standard_answer_category, +) from danswer.db.persona import delete_old_default_personas from danswer.chat.load_yamls import load_chat_yamls from danswer.tools.built_in_tools import auto_add_search_tool_to_personas @@ -12,6 +14,7 @@ logger = setup_logger() + def setup_postgres(db_session: Session) -> None: logger.notice("Verifying default connector/credential exist.") create_initial_public_credential(db_session) diff --git a/backend/danswer/document_index/factory.py b/backend/danswer/document_index/factory.py index 290b1e26a2e..80476f0c075 100644 --- a/backend/danswer/document_index/factory.py +++ b/backend/danswer/document_index/factory.py @@ -8,7 +8,7 @@ def get_default_document_index( primary_index_name: str | None = None, indices: list[str] | None = None, - secondary_index_name: str | None = None + secondary_index_name: str | None = None, ) -> DocumentIndex: """Primary index is the index that is used for querying/updating etc. Secondary index is for when both the currently used index and the upcoming @@ -19,10 +19,8 @@ def get_default_document_index( if not indices: raise ValueError("No indices provided") - return VespaIndex( - indices=indices, - secondary_index_name=secondary_index_name - ) + return VespaIndex(indices=indices, secondary_index_name=secondary_index_name) + def get_current_primary_default_document_index(db_session: Session) -> DocumentIndex: """ diff --git a/backend/danswer/document_index/interfaces.py b/backend/danswer/document_index/interfaces.py index c1a4f18f1ca..99404686222 100644 --- a/backend/danswer/document_index/interfaces.py +++ b/backend/danswer/document_index/interfaces.py @@ -100,7 +100,7 @@ def ensure_indices_exist( self, embedding_dims: list[int] | None = None, index_embedding_dim: int | None = None, - secondary_index_embedding_dim: int | None = None + secondary_index_embedding_dim: int | None = None, ) -> None: """ Verify that the document index exists and is consistent with the expectations in the code. 
diff --git a/backend/danswer/document_index/vespa/chunk_retrieval.py b/backend/danswer/document_index/vespa/chunk_retrieval.py index b026ba0bcf7..e4b2ad83ce2 100644 --- a/backend/danswer/document_index/vespa/chunk_retrieval.py +++ b/backend/danswer/document_index/vespa/chunk_retrieval.py @@ -341,8 +341,6 @@ def query_vespa( return inference_chunks - - def _get_chunks_via_batch_search( index_name: str, chunk_requests: list[VespaChunkRequest], diff --git a/backend/danswer/document_index/vespa/index.py b/backend/danswer/document_index/vespa/index.py index 6fda7ee058a..13d6868ae82 100644 --- a/backend/danswer/document_index/vespa/index.py +++ b/backend/danswer/document_index/vespa/index.py @@ -127,9 +127,8 @@ def ensure_indices_exist( self, embedding_dims: list[int] | None = None, index_embedding_dim: int | None = None, - secondary_index_embedding_dim: int | None = None + secondary_index_embedding_dim: int | None = None, ) -> None: - if embedding_dims is None: if index_embedding_dim is not None: embedding_dims = [index_embedding_dim] @@ -192,23 +191,33 @@ def ensure_indices_exist( for i, index_name in enumerate(self.indices): embedding_dim = embedding_dims[i] - logger.info(f"Creating index: {index_name} with embedding dimension: {embedding_dim}") + logger.info( + f"Creating index: {index_name} with embedding dimension: {embedding_dim}" + ) schema = schema_template.replace( DANSWER_CHUNK_REPLACEMENT_PAT, index_name ).replace(VESPA_DIM_REPLACEMENT_PAT, str(embedding_dim)) - schema = schema.replace(TENANT_ID_PAT, TENANT_ID_REPLACEMENT if MULTI_TENANT else "") + schema = schema.replace( + TENANT_ID_PAT, TENANT_ID_REPLACEMENT if MULTI_TENANT else "" + ) schema = add_ngrams_to_schema(schema) if needs_reindexing else schema zip_dict[f"schemas/{index_name}.sd"] = schema.encode("utf-8") if self.secondary_index_name: - logger.info("Creating secondary index:" - f"{self.secondary_index_name} with embedding dimension: {secondary_index_embedding_dim}") + logger.info( + "Creating secondary index:" + f"{self.secondary_index_name} with embedding dimension: {secondary_index_embedding_dim}" + ) upcoming_schema = schema_template.replace( DANSWER_CHUNK_REPLACEMENT_PAT, self.secondary_index_name ).replace(VESPA_DIM_REPLACEMENT_PAT, str(secondary_index_embedding_dim)) - upcoming_schema = upcoming_schema.replace(TENANT_ID_PAT, TENANT_ID_REPLACEMENT if MULTI_TENANT else "") - zip_dict[f"schemas/{self.secondary_index_name}.sd"] = upcoming_schema.encode("utf-8") + upcoming_schema = upcoming_schema.replace( + TENANT_ID_PAT, TENANT_ID_REPLACEMENT if MULTI_TENANT else "" + ) + zip_dict[ + f"schemas/{self.secondary_index_name}.sd" + ] = upcoming_schema.encode("utf-8") zip_file = in_memory_zip_from_file_bytes(zip_dict) @@ -220,7 +229,6 @@ def ensure_indices_exist( f"Failed to prepare Vespa Danswer Indexes. 
Response: {response.text}" ) - def index( self, chunks: list[DocMetadataAwareIndexChunk], diff --git a/backend/danswer/document_index/vespa/indexing_utils.py b/backend/danswer/document_index/vespa/indexing_utils.py index 01a61527004..b35fc05c8ae 100644 --- a/backend/danswer/document_index/vespa/indexing_utils.py +++ b/backend/danswer/document_index/vespa/indexing_utils.py @@ -98,7 +98,7 @@ def get_existing_documents_from_chunks( try: chunk_existence_future = { executor.submit( - _does_document_exist, + _does_document_exist, str(get_uuid_from_chunk(chunk)), index_name, http_client, diff --git a/backend/danswer/document_index/vespa/shared_utils/vespa_request_builders.py b/backend/danswer/document_index/vespa/shared_utils/vespa_request_builders.py index 22a62a9bdf4..312e779647f 100644 --- a/backend/danswer/document_index/vespa/shared_utils/vespa_request_builders.py +++ b/backend/danswer/document_index/vespa/shared_utils/vespa_request_builders.py @@ -20,7 +20,6 @@ def build_vespa_filters(filters: IndexFilters, include_hidden: bool = False) -> str: - def _build_or_filters(key: str, vals: list[str] | None) -> str: if vals is None: return "" @@ -58,7 +57,6 @@ def _build_time_filter( if filters.tenant_id: filter_str += f'({TENANT_ID} contains "{filters.tenant_id}") and ' - # CAREFUL touching this one, currently there is no second ACL double-check post retrieval if filters.access_control_list is not None: filter_str += _build_or_filters( diff --git a/backend/danswer/indexing/models.py b/backend/danswer/indexing/models.py index 1a672e899cc..39cfa2cca0c 100644 --- a/backend/danswer/indexing/models.py +++ b/backend/danswer/indexing/models.py @@ -55,10 +55,12 @@ def to_short_descriptor(self) -> str: f"Chunk ID: '{self.chunk_id}'; {self.source_document.to_short_descriptor()}" ) + class IndexChunk(DocAwareChunk): embeddings: ChunkEmbedding title_embedding: Embedding | None + # TODO(rkuo): currently, this extra metadata sent during indexing is just for speed, # but full consistency happens on background sync class DocMetadataAwareIndexChunk(IndexChunk): diff --git a/backend/danswer/llm/answering/prompts/citations_prompt.py b/backend/danswer/llm/answering/prompts/citations_prompt.py index 79f0b2bfd1f..a21881b0a68 100644 --- a/backend/danswer/llm/answering/prompts/citations_prompt.py +++ b/backend/danswer/llm/answering/prompts/citations_prompt.py @@ -105,7 +105,9 @@ def compute_max_document_tokens_for_persona( prompt = persona.prompts[0] if persona.prompts else get_default_prompt__read_only() return compute_max_document_tokens( prompt_config=PromptConfig.from_model(prompt), - llm_config=get_main_llm_from_tuple(get_llms_for_persona(persona, db_session=db_session)).config, + llm_config=get_main_llm_from_tuple( + get_llms_for_persona(persona, db_session=db_session) + ).config, actual_user_input=actual_user_input, max_llm_token_override=max_llm_token_override, ) diff --git a/backend/danswer/llm/factory.py b/backend/danswer/llm/factory.py index b3b98311c7e..6bc4b8e80cf 100644 --- a/backend/danswer/llm/factory.py +++ b/backend/danswer/llm/factory.py @@ -16,11 +16,13 @@ logger = setup_logger() + def get_main_llm_from_tuple( llms: tuple[LLM, LLM], ) -> LLM: return llms[0] + def get_llms_for_persona( persona: Persona, db_session: Session, @@ -62,6 +64,7 @@ def _create_llm(model: str) -> LLM: custom_config=llm_provider.custom_config, additional_headers=additional_headers, ) + return _create_llm(model), _create_llm(fast_model) @@ -74,14 +77,12 @@ def get_default_llms( if DISABLE_GENERATIVE_AI: raise 
GenAIDisabledException() - if db_session is None: with get_session_context_manager() as db_session: llm_provider = fetch_default_provider(db_session) else: llm_provider = fetch_default_provider(db_session) - if not llm_provider: raise ValueError("No default LLM provider found") diff --git a/backend/danswer/main.py b/backend/danswer/main.py index d71e9bc07bc..919961813a2 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -171,7 +171,6 @@ def include_router_with_global_prefix_prepended( application.include_router(router, **final_kwargs) - def translate_saved_search_settings(db_session: Session) -> None: kv_store = get_dynamic_config_store() @@ -239,7 +238,7 @@ def mark_reindex_flag(db_session: Session) -> None: def setup_vespa( document_index: DocumentIndex, embedding_dims: list[int], - secondary_embedding_dim: int | None = None + secondary_embedding_dim: int | None = None, ) -> bool: # Vespa startup is a bit slow, so give it a few seconds wait_time = 5 @@ -249,7 +248,7 @@ def setup_vespa( logger.notice(f"Setting up Vespa (attempt {x+1}/{VESPA_ATTEMPTS})...") document_index.ensure_indices_exist( embedding_dims=embedding_dims, - secondary_index_embedding_dim=secondary_embedding_dim + secondary_index_embedding_dim=secondary_embedding_dim, ) return True except Exception: @@ -258,6 +257,7 @@ def setup_vespa( logger.exception("Error ensuring multi-tenant indices exist") return False + @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator: init_sqlalchemy_engine(POSTGRES_WEB_APP_NAME) @@ -288,7 +288,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: # Break bad state for thrashing indexes if secondary_search_settings and DISABLE_INDEX_UPDATE_ON_SWAP: - expire_index_attempts( search_settings_id=search_settings.id, db_session=db_session ) @@ -301,14 +300,12 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: logger.notice(f'Using Embedding model: "{search_settings.model_name}"') if search_settings.query_prefix or search_settings.passage_prefix: - logger.notice(f'Query embedding prefix: "{search_settings.query_prefix}"') logger.notice( f'Passage embedding prefix: "{search_settings.passage_prefix}"' ) if search_settings: - if not search_settings.disable_rerank_for_streaming: logger.notice("Reranking is enabled.") @@ -352,13 +349,17 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: setup_vespa( document_index, [model.dim for model in SUPPORTED_EMBEDDING_MODELS], - secondary_embedding_dim=secondary_search_settings.model_dim if secondary_search_settings else None + secondary_embedding_dim=secondary_search_settings.model_dim + if secondary_search_settings + else None, ) else: document_index = get_default_document_index( indices=[search_settings.index_name], - secondary_index_name=secondary_search_settings.index_name if secondary_search_settings else None + secondary_index_name=secondary_search_settings.index_name + if secondary_search_settings + else None, ) setup_vespa( @@ -368,7 +369,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: IndexingSetting.from_db_model(secondary_search_settings).model_dim if secondary_search_settings else None - ) + ), ) logger.notice(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}") @@ -485,9 +486,7 @@ def get_application() -> FastAPI: include_router_with_global_prefix_prepended( application, token_rate_limit_settings_router ) - include_router_with_global_prefix_prepended( - application, tenants_router - ) + include_router_with_global_prefix_prepended(application, tenants_router) 
include_router_with_global_prefix_prepended(application, indexing_router) if AUTH_TYPE == AuthType.DISABLED: diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py index 9fe81027457..7a981432c3e 100644 --- a/backend/danswer/one_shot_answer/answer_question.py +++ b/backend/danswer/one_shot_answer/answer_question.py @@ -128,7 +128,6 @@ def stream_answer_objects( persona = temporary_persona if temporary_persona else chat_session.persona - llm, fast_llm = get_llms_for_persona(persona=persona, db_session=db_session) llm_tokenizer = get_tokenizer( @@ -148,9 +147,7 @@ def stream_answer_objects( ) rephrased_query = query_req.query_override or thread_based_query_rephrase( - user_query=query_msg.message, - history_str=history_str, - db_session=db_session + user_query=query_msg.message, history_str=history_str, db_session=db_session ) # Given back ahead of the documents for latency reasons @@ -219,7 +216,9 @@ def stream_answer_objects( question=query_msg.message, answer_style_config=answer_config, prompt_config=PromptConfig.from_model(prompt), - llm=get_main_llm_from_tuple(get_llms_for_persona(persona=chat_session.persona, db_session=db_session)), + llm=get_main_llm_from_tuple( + get_llms_for_persona(persona=chat_session.persona, db_session=db_session) + ), single_message_history=history_str, tools=[search_tool] if search_tool else [], force_use_tool=( diff --git a/backend/danswer/search/enums.py b/backend/danswer/search/enums.py index c0166c80c03..28f81704789 100644 --- a/backend/danswer/search/enums.py +++ b/backend/danswer/search/enums.py @@ -11,6 +11,7 @@ class RecencyBiasSetting(str, Enum): # Determine based on query if to use base_decay or favor_recent AUTO = "auto" + class OptionalSearchSetting(str, Enum): ALWAYS = "always" NEVER = "never" diff --git a/backend/danswer/search/preprocessing/preprocessing.py b/backend/danswer/search/preprocessing/preprocessing.py index 1493712f337..e2e00a84d73 100644 --- a/backend/danswer/search/preprocessing/preprocessing.py +++ b/backend/danswer/search/preprocessing/preprocessing.py @@ -34,6 +34,7 @@ logger = setup_logger() + def query_analysis(query: str) -> tuple[bool, list[str]]: analysis_model = QueryAnalysisModel() return analysis_model.predict(query) @@ -156,7 +157,6 @@ def retrieval_preprocessing( None if bypass_acl else build_access_filters_for_user(user, db_session) ) - final_filters = IndexFilters( source_type=preset_filters.source_type or predicted_source_filters, document_set=preset_filters.document_set, diff --git a/backend/danswer/secondary_llm_flows/query_expansion.py b/backend/danswer/secondary_llm_flows/query_expansion.py index 9c4ef00e5d9..cb2032402d8 100644 --- a/backend/danswer/secondary_llm_flows/query_expansion.py +++ b/backend/danswer/secondary_llm_flows/query_expansion.py @@ -16,10 +16,13 @@ from danswer.utils.text_processing import count_punctuation from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel from sqlalchemy.orm import Session + logger = setup_logger() -def llm_multilingual_query_expansion(query: str, language: str, db_session: Session) -> str: +def llm_multilingual_query_expansion( + query: str, language: str, db_session: Session +) -> str: def _get_rephrase_messages() -> list[dict[str, str]]: messages = [ { @@ -66,7 +69,8 @@ def multilingual_query_expansion( else: query_rephrases = [ - llm_multilingual_query_expansion(query, language, db_session) for language in languages + llm_multilingual_query_expansion(query, language, db_session) + 
for language in languages ] return query_rephrases @@ -138,7 +142,7 @@ def thread_based_query_rephrase( db_session: Session, llm: LLM | None = None, size_heuristic: int = 200, - punctuation_heuristic: int = 10 + punctuation_heuristic: int = 10, ) -> str: if not history_str: return user_query diff --git a/backend/danswer/secondary_llm_flows/query_validation.py b/backend/danswer/secondary_llm_flows/query_validation.py index 56785ecbc33..17d385419a0 100644 --- a/backend/danswer/secondary_llm_flows/query_validation.py +++ b/backend/danswer/secondary_llm_flows/query_validation.py @@ -48,7 +48,8 @@ def extract_answerability_bool(model_raw: str) -> bool: def get_query_answerability( db_session: Session, - user_query: str, skip_check: bool = DISABLE_LLM_QUERY_ANSWERABILITY, + user_query: str, + skip_check: bool = DISABLE_LLM_QUERY_ANSWERABILITY, ) -> tuple[str, bool]: if skip_check: return "Query Answerability Evaluation feature is turned off", True @@ -69,7 +70,9 @@ def get_query_answerability( def stream_query_answerability( - db_session: Session, user_query: str, skip_check: bool = DISABLE_LLM_QUERY_ANSWERABILITY, + db_session: Session, + user_query: str, + skip_check: bool = DISABLE_LLM_QUERY_ANSWERABILITY, ) -> Iterator[str]: if skip_check: yield get_json_line( diff --git a/backend/danswer/secondary_llm_flows/source_filter.py b/backend/danswer/secondary_llm_flows/source_filter.py index a46fbb0548c..449b6c4eb07 100644 --- a/backend/danswer/secondary_llm_flows/source_filter.py +++ b/backend/danswer/secondary_llm_flows/source_filter.py @@ -193,5 +193,7 @@ def _extract_source_filters_from_llm_out( while True: user_input = input("Query to Extract Sources: ") sources = extract_source_filter( - user_input, get_main_llm_from_tuple(get_default_llms(db_session=db_session)), db_session + user_input, + get_main_llm_from_tuple(get_default_llms(db_session=db_session)), + db_session, ) diff --git a/backend/danswer/server/manage/search_settings.py b/backend/danswer/server/manage/search_settings.py index 55fa4b7d4e0..c8433467f6c 100644 --- a/backend/danswer/server/manage/search_settings.py +++ b/backend/danswer/server/manage/search_settings.py @@ -62,7 +62,6 @@ def set_new_search_settings( search_settings = get_current_search_settings(db_session) if search_settings_new.index_name is None: - # We define index name here index_name = f"danswer_chunk_{clean_model_name(search_settings_new.model_name)}" if ( diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index 78c74ad69cd..372af94e9cb 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -226,7 +226,8 @@ def rename_chat_session( try: llm, _ = get_default_llms( - additional_headers=get_litellm_additional_request_headers(request.headers), db_session=db_session + additional_headers=get_litellm_additional_request_headers(request.headers), + db_session=db_session, ) except GenAIDisabledException: # This may be longer than what the LLM tends to produce but is the most @@ -431,7 +432,9 @@ def get_max_document_tokens( raise HTTPException(status_code=404, detail="Persona not found") return MaxSelectedDocumentTokens( - max_tokens=compute_max_document_tokens_for_persona(persona, db_session=db_session), + max_tokens=compute_max_document_tokens_for_persona( + persona, db_session=db_session + ), ) @@ -481,7 +484,9 @@ def seed_chat( root_message = get_or_create_root_message( chat_session_id=new_chat_session.id, 
db_session=db_session ) - llm, fast_llm = get_llms_for_persona(persona=new_chat_session.persona, db_session=db_session) + llm, fast_llm = get_llms_for_persona( + persona=new_chat_session.persona, db_session=db_session + ) tokenizer = get_tokenizer( model_name=llm.config.model_name, diff --git a/backend/danswer/server/query_and_chat/query_backend.py b/backend/danswer/server/query_and_chat/query_backend.py index e3484094477..bc8a44efd1e 100644 --- a/backend/danswer/server/query_and_chat/query_backend.py +++ b/backend/danswer/server/query_and_chat/query_backend.py @@ -133,13 +133,17 @@ def get_tags( @basic_router.post("/query-validation") def query_validation( - simple_query: SimpleQueryRequest, _: User = Depends(current_user), db_session: Session = Depends(get_session) + simple_query: SimpleQueryRequest, + _: User = Depends(current_user), + db_session: Session = Depends(get_session), ) -> QueryValidationResponse: # Note if weak model prompt is chosen, this check does not occur and will simply return that # the query is valid, this is because weaker models cannot really handle this task well. # Additionally, some weak model servers cannot handle concurrent inferences. logger.notice(f"Validating query: {simple_query.query}") - reasoning, answerable = get_query_answerability(db_session=db_session, user_query=simple_query.query) + reasoning, answerable = get_query_answerability( + db_session=db_session, user_query=simple_query.query + ) return QueryValidationResponse(reasoning=reasoning, answerable=answerable) @@ -247,14 +251,19 @@ def get_search_session( # No search responses are answered with a conversational generative AI response @basic_router.post("/stream-query-validation") def stream_query_validation( - simple_query: SimpleQueryRequest, _: User = Depends(current_user), db_session: Session = Depends(get_session) + simple_query: SimpleQueryRequest, + _: User = Depends(current_user), + db_session: Session = Depends(get_session), ) -> StreamingResponse: # Note if weak model prompt is chosen, this check does not occur and will simply return that # the query is valid, this is because weaker models cannot really handle this task well. # Additionally, some weak model servers cannot handle concurrent inferences. 
logger.notice(f"Validating query: {simple_query.query}") return StreamingResponse( - stream_query_answerability(user_query=simple_query.query, db_session=db_session), media_type="application/json" + stream_query_answerability( + user_query=simple_query.query, db_session=db_session + ), + media_type="application/json", ) diff --git a/backend/danswer/server/settings/api.py b/backend/danswer/server/settings/api.py index f8a58b6e37a..c6871544690 100644 --- a/backend/danswer/server/settings/api.py +++ b/backend/danswer/server/settings/api.py @@ -36,7 +36,6 @@ basic_router = APIRouter(prefix="/settings") - @admin_router.put("") def put_settings( settings: Settings, _: User | None = Depends(current_admin_user) diff --git a/backend/danswer/server/tenants/api.py b/backend/danswer/server/tenants/api.py index 916060f6652..a88abcaf507 100644 --- a/backend/danswer/server/tenants/api.py +++ b/backend/danswer/server/tenants/api.py @@ -21,8 +21,11 @@ logger = setup_logger() basic_router = APIRouter(prefix="/tenants") + @basic_router.post("/create") -def create_tenant(tenant_id: str, _: None= Depends(control_plane_dep)) -> dict[str, str]: +def create_tenant( + tenant_id: str, _: None = Depends(control_plane_dep) +) -> dict[str, str]: if not MULTI_TENANT: raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled") @@ -47,14 +50,14 @@ async def sso_callback( raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled") payload = verify_sso_token(sso_token) - user = await user_manager.sso_authenticate( - payload["email"], payload["tenant_id"] - ) + user = await user_manager.sso_authenticate(payload["email"], payload["tenant_id"]) tenant_id = payload["tenant_id"] schema_exists = await check_schema_exists(tenant_id) if not schema_exists: - raise HTTPException(status_code=403, detail="Your Danswer app has not been set up yet!") + raise HTTPException( + status_code=403, detail="Your Danswer app has not been set up yet!" 
+ ) session_token = await create_user_session(user, payload["tenant_id"]) diff --git a/backend/danswer/server/tenants/provisioning.py b/backend/danswer/server/tenants/provisioning.py index 9f14f9ad7e9..185a28ec2f0 100644 --- a/backend/danswer/server/tenants/provisioning.py +++ b/backend/danswer/server/tenants/provisioning.py @@ -15,7 +15,9 @@ from danswer.db.connector import create_initial_default_connector from danswer.db.connector_credential_pair import associate_default_cc_pair from danswer.db.credentials import create_initial_public_credential -from ee.danswer.db.standard_answer import create_initial_default_standard_answer_category +from ee.danswer.db.standard_answer import ( + create_initial_default_standard_answer_category, +) from danswer.db.persona import delete_old_default_personas from danswer.chat.load_yamls import load_chat_yamls from danswer.tools.built_in_tools import auto_add_search_tool_to_personas @@ -33,40 +35,46 @@ logger = setup_logger() + def run_alembic_migrations(schema_name: str) -> None: logger.info(f"Starting Alembic migrations for schema: {schema_name}") try: current_dir = os.path.dirname(os.path.abspath(__file__)) - root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..')) - alembic_ini_path = os.path.join(root_dir, 'alembic.ini') + root_dir = os.path.abspath(os.path.join(current_dir, "..", "..", "..")) + alembic_ini_path = os.path.join(root_dir, "alembic.ini") # Configure Alembic alembic_cfg = Config(alembic_ini_path) - alembic_cfg.set_main_option('sqlalchemy.url', build_connection_string()) + alembic_cfg.set_main_option("sqlalchemy.url", build_connection_string()) # Prepare the x arguments x_arguments = [f"schema={schema_name}"] - alembic_cfg.cmd_opts.x = x_arguments # type: ignore + alembic_cfg.cmd_opts.x = x_arguments # type: ignore # Run migrations programmatically - command.upgrade(alembic_cfg, 'head') + command.upgrade(alembic_cfg, "head") - logger.info(f"Alembic migrations completed successfully for schema: {schema_name}") + logger.info( + f"Alembic migrations completed successfully for schema: {schema_name}" + ) except Exception as e: logger.exception(f"Alembic migration failed for schema {schema_name}: {str(e)}") raise + def create_tenant_schema(tenant_id: str) -> None: with Session(get_sqlalchemy_engine()) as db_session: with db_session.begin(): result = db_session.execute( - text(""" + text( + """ SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name - """), - {"schema_name": tenant_id} + """ + ), + {"schema_name": tenant_id}, ) schema_exists = result.scalar() is not None @@ -97,9 +105,7 @@ def setup_postgres_and_initial_settings(db_session: Session) -> None: logger.notice(f'Using Embedding model: "{search_settings.model_name}"') if search_settings.query_prefix or search_settings.passage_prefix: logger.notice(f'Query embedding prefix: "{search_settings.query_prefix}"') - logger.notice( - f'Passage embedding prefix: "{search_settings.passage_prefix}"' - ) + logger.notice(f'Passage embedding prefix: "{search_settings.passage_prefix}"') if search_settings: if not search_settings.disable_rerank_for_streaming: @@ -122,7 +128,6 @@ def setup_postgres_and_initial_settings(db_session: Session) -> None: # ensure Vespa is setup correctly logger.notice("Verifying Document Index(s) is/are available.") - logger.notice("Verifying default connector/credential exist.") create_initial_public_credential(db_session) create_initial_default_connector(db_session) @@ -142,12 +147,12 @@ def 
setup_postgres_and_initial_settings(db_session: Session) -> None: async def check_schema_exists(tenant_id: str) -> bool: - get_async_session_context = contextlib.asynccontextmanager( - get_async_session - ) + get_async_session_context = contextlib.asynccontextmanager(get_async_session) async with get_async_session_context() as session: result = await session.execute( - text("SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name"), - {"schema_name": tenant_id} + text( + "SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name" + ), + {"schema_name": tenant_id}, ) return result.scalar() is not None diff --git a/backend/ee/danswer/auth/users.py b/backend/ee/danswer/auth/users.py index 2a7848e6377..1b931c7b311 100644 --- a/backend/ee/danswer/auth/users.py +++ b/backend/ee/danswer/auth/users.py @@ -19,6 +19,7 @@ logger = setup_logger() + def verify_auth_setting() -> None: # All the Auth flows are valid for EE version logger.notice(f"Using Auth Type: {AUTH_TYPE.value}") @@ -72,7 +73,6 @@ async def optional_user_( return user - def api_key_dep( request: Request, db_session: Session = Depends(get_session) ) -> User | None: @@ -92,7 +92,6 @@ def api_key_dep( return user - def get_default_admin_user_emails_() -> list[str]: seed_config = get_seed_config() if seed_config and seed_config.admin_user_emails: diff --git a/backend/ee/danswer/server/query_and_chat/chat_backend.py b/backend/ee/danswer/server/query_and_chat/chat_backend.py index a5613fed918..6c89bd349b5 100644 --- a/backend/ee/danswer/server/query_and_chat/chat_backend.py +++ b/backend/ee/danswer/server/query_and_chat/chat_backend.py @@ -272,9 +272,7 @@ def handle_send_message_simple_with_history( ) rephrased_query = req.query_override or thread_based_query_rephrase( - user_query=query, - history_str=history_str, - db_session=db_session + user_query=query, history_str=history_str, db_session=db_session ) if req.retrieval_options is None and req.search_doc_ids is None: diff --git a/backend/ee/danswer/server/query_and_chat/query_backend.py b/backend/ee/danswer/server/query_and_chat/query_backend.py index 59e4180b2c2..ca544a4dde6 100644 --- a/backend/ee/danswer/server/query_and_chat/query_backend.py +++ b/backend/ee/danswer/server/query_and_chat/query_backend.py @@ -153,7 +153,9 @@ def get_answer_with_quote( raise KeyError("Must provide persona ID or Persona Config") llm = get_main_llm_from_tuple( - get_default_llms(db_session=db_session) if not persona else get_llms_for_persona(persona, db_session=db_session) + get_default_llms(db_session=db_session) + if not persona + else get_llms_for_persona(persona, db_session=db_session) ) input_tokens = get_max_input_tokens( model_name=llm.config.model_name, model_provider=llm.config.model_provider diff --git a/backend/scripts/query_time_check/seed_dummy_docs.py b/backend/scripts/query_time_check/seed_dummy_docs.py index f1d5089d438..1d3e5b4432b 100644 --- a/backend/scripts/query_time_check/seed_dummy_docs.py +++ b/backend/scripts/query_time_check/seed_dummy_docs.py @@ -94,7 +94,7 @@ def generate_dummy_chunk( ), document_sets={document_set for document_set in document_set_names}, boost=random.randint(-1, 1), - tenant_id=None + tenant_id=None, ) diff --git a/backend/shared_configs/configs.py b/backend/shared_configs/configs.py index 5089fce5130..c7d158c2252 100644 --- a/backend/shared_configs/configs.py +++ b/backend/shared_configs/configs.py @@ -77,31 +77,33 @@ "query_prefix", ] + class SupportedEmbeddingModel(BaseModel): name: str dim: int 
index_name: str + SUPPORTED_EMBEDDING_MODELS = [ SupportedEmbeddingModel( name="intfloat/e5-small-v2", dim=384, - index_name="danswer_chunk_intfloat_e5_small_v2" + index_name="danswer_chunk_intfloat_e5_small_v2", ), SupportedEmbeddingModel( name="intfloat/e5-large-v2", dim=1024, - index_name="danswer_chunk_intfloat_e5_large_v2" + index_name="danswer_chunk_intfloat_e5_large_v2", ), SupportedEmbeddingModel( name="sentence-transformers/all-distilroberta-v1", dim=768, - index_name="danswer_chunk_sentence_transformers_all_distilroberta_v1" + index_name="danswer_chunk_sentence_transformers_all_distilroberta_v1", ), SupportedEmbeddingModel( name="sentence-transformers/all-mpnet-base-v2", dim=768, - index_name="danswer_chunk_sentence_transformers_all_mpnet_base_v2" + index_name="danswer_chunk_sentence_transformers_all_mpnet_base_v2", ), ] diff --git a/web/src/app/admin/indexing/status/CCPairIndexingStatusTable.tsx b/web/src/app/admin/indexing/status/CCPairIndexingStatusTable.tsx index 7165abfb843..36ab1f9863e 100644 --- a/web/src/app/admin/indexing/status/CCPairIndexingStatusTable.tsx +++ b/web/src/app/admin/indexing/status/CCPairIndexingStatusTable.tsx @@ -214,8 +214,9 @@ function ConnectorRow({ return ( { router.push(`/admin/connector/${ccPairsIndexingStatus.cc_pair_id}`); }} @@ -530,4 +531,4 @@ export function CCPairIndexingStatusTable({ ); -} \ No newline at end of file +} diff --git a/web/src/app/admin/settings/interfaces.ts b/web/src/app/admin/settings/interfaces.ts index 510c3ad50d2..9d1430308a6 100644 --- a/web/src/app/admin/settings/interfaces.ts +++ b/web/src/app/admin/settings/interfaces.ts @@ -61,7 +61,6 @@ export interface CombinedSettings { webVersion: string | null; } - export const defaultCombinedSettings: CombinedSettings = { settings: { chat_page_enabled: true, @@ -87,4 +86,4 @@ export const defaultCombinedSettings: CombinedSettings = { customAnalyticsScript: null, isMobile: false, webVersion: null, -}; \ No newline at end of file +}; diff --git a/web/src/app/auth/login/page.tsx b/web/src/app/auth/login/page.tsx index e4adcb36116..fd032f4fffd 100644 --- a/web/src/app/auth/login/page.tsx +++ b/web/src/app/auth/login/page.tsx @@ -89,7 +89,7 @@ const Page = async ({ /> )} - + {authTypeMetadata?.authType === "basic" && (
diff --git a/web/src/app/auth/sso-callback/page.tsx b/web/src/app/auth/sso-callback/page.tsx index 998f18b6cdf..090f8be2d5d 100644 --- a/web/src/app/auth/sso-callback/page.tsx +++ b/web/src/app/auth/sso-callback/page.tsx @@ -17,7 +17,7 @@ export default function SSOCallback() { if (verificationStartedRef.current) { return; } - + verificationStartedRef.current = true; const hashParams = new URLSearchParams(window.location.hash.slice(1)); const ssoToken = hashParams.get("sso_token"); @@ -27,7 +27,7 @@ export default function SSOCallback() { return; } - window.history.replaceState(null, '', window.location.pathname); + window.history.replaceState(null, "", window.location.pathname); if (!ssoToken) { setError("No SSO token found"); @@ -36,17 +36,14 @@ export default function SSOCallback() { try { setAuthStatus("Verifying SSO token..."); - const response = await fetch( - `/api/tenants/auth/sso-callback`, - { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - credentials: "include", - body: JSON.stringify({ sso_token: ssoToken }), - } - ) + const response = await fetch(`/api/tenants/auth/sso-callback`, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + credentials: "include", + body: JSON.stringify({ sso_token: ssoToken }), + }); if (response.ok) { setAuthStatus("Authentication successful!"); @@ -64,7 +61,6 @@ export default function SSOCallback() { verifyToken(); }, [router, searchParams]); - return (
diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx index be3ccac4add..20722a9cbad 100644 --- a/web/src/app/chat/ChatPage.tsx +++ b/web/src/app/chat/ChatPage.tsx @@ -233,7 +233,6 @@ export function ChatPage({ destructureValue(user?.preferences.default_model || "openai:gpt-4o") ); - if (personaDefault) { llmOverrideManager.setLlmOverride(personaDefault); } else if (user?.preferences.default_model) { @@ -962,7 +961,7 @@ export function ChatPage({ console.log("model override", modelOverRide); console.log(modelOverRide?.name); console.log(llmOverrideManager.llmOverride.name); - console.log("HII") + console.log("HII"); console.log(llmOverrideManager.globalDefault.name); let frozenSessionId = currentSessionId(); updateCanContinue(false, frozenSessionId); diff --git a/web/src/app/chat/folders/FolderList.tsx b/web/src/app/chat/folders/FolderList.tsx index 3c01f99f6c1..01e69b3a1a4 100644 --- a/web/src/app/chat/folders/FolderList.tsx +++ b/web/src/app/chat/folders/FolderList.tsx @@ -206,7 +206,6 @@ const FolderItem = ({ /> ) : (
- {editedFolderName || folder.folder_name}
)} diff --git a/web/src/app/chat/input/ChatInputBar.tsx b/web/src/app/chat/input/ChatInputBar.tsx index 117442c5002..64535d82b20 100644 --- a/web/src/app/chat/input/ChatInputBar.tsx +++ b/web/src/app/chat/input/ChatInputBar.tsx @@ -120,7 +120,6 @@ export function ChatInputBar({ const { llmProviders } = useChatContext(); const [_, llmName] = getFinalLLM(llmProviders, selectedAssistant, null); - const suggestionsRef = useRef(null); const [showSuggestions, setShowSuggestions] = useState(false); const [showPrompts, setShowPrompts] = useState(false); diff --git a/web/src/app/chat/lib.tsx b/web/src/app/chat/lib.tsx index 2cde39a771b..9dddd80ab26 100644 --- a/web/src/app/chat/lib.tsx +++ b/web/src/app/chat/lib.tsx @@ -154,8 +154,8 @@ export async function* sendMessage({ }): AsyncGenerator { const documentsAreSelected = selectedDocumentIds && selectedDocumentIds.length > 0; - - console.log("llm ovverride deatilas", modelProvider, modelVersion) + + console.log("llm ovverride deatilas", modelProvider, modelVersion); const body = JSON.stringify({ alternate_assistant_id: alternateAssistantId, @@ -168,30 +168,30 @@ export async function* sendMessage({ regenerate, retrieval_options: !documentsAreSelected ? { - run_search: - promptId === null || + run_search: + promptId === null || promptId === undefined || queryOverride || forceSearch - ? "always" - : "auto", - real_time: true, - filters: filters, - } + ? "always" + : "auto", + real_time: true, + filters: filters, + } : null, query_override: queryOverride, prompt_override: systemPromptOverride ? { - system_prompt: systemPromptOverride, - } + system_prompt: systemPromptOverride, + } : null, llm_override: temperature || modelVersion ? { - temperature, - model_provider: modelProvider, - model_version: modelVersion, - } + temperature, + model_provider: modelProvider, + model_version: modelVersion, + } : null, use_existing_user_message: useExistingUserMessage, }); @@ -431,11 +431,11 @@ export function processRawChatHistory( // this is identical to what is computed at streaming time ...(messageInfo.message_type === "assistant" ? { - retrievalType: retrievalType, - query: messageInfo.rephrased_query, - documents: messageInfo?.context_docs?.top_documents || [], - citations: messageInfo?.citations || {}, - } + retrievalType: retrievalType, + query: messageInfo.rephrased_query, + documents: messageInfo?.context_docs?.top_documents || [], + citations: messageInfo?.citations || {}, + } : {}), toolCalls: messageInfo.tool_calls, parentMessageId: messageInfo.parent_message, @@ -600,7 +600,8 @@ export function buildChatUrl( const finalSearchParams: string[] = []; if (chatSessionId) { finalSearchParams.push( - `${search ? SEARCH_PARAM_NAMES.SEARCH_ID : SEARCH_PARAM_NAMES.CHAT_ID + `${ + search ? 
SEARCH_PARAM_NAMES.SEARCH_ID : SEARCH_PARAM_NAMES.CHAT_ID }=${chatSessionId}` ); } diff --git a/web/src/app/ee/admin/plan/BillingSettings.tsx b/web/src/app/ee/admin/plan/BillingSettings.tsx index c952eae8842..769c6bc67d4 100644 --- a/web/src/app/ee/admin/plan/BillingSettings.tsx +++ b/web/src/app/ee/admin/plan/BillingSettings.tsx @@ -27,15 +27,13 @@ export function BillingSettings({ newUser }: { newUser: boolean }) { const searchParams = useSearchParams(); const [isOpen, setIsOpen] = useState(false); - const [isNewUserOpen, setIsNewUserOpen] = useState(true) + const [isNewUserOpen, setIsNewUserOpen] = useState(true); const [newSeats, setNewSeats] = useState(null); const [newPlan, setNewPlan] = useState(null); - const { popup, setPopup } = usePopup(); - useEffect(() => { const success = searchParams.get("success"); if (success === "true") { @@ -68,13 +66,11 @@ export function BillingSettings({ newUser }: { newUser: boolean }) { setNewPlan(cloudSettings.planType); } }, [cloudSettings]); - if (!cloudSettings) { return null; } - const features = [ { name: "All Connector Access", included: true }, { name: "Basic Support", included: true }, @@ -124,11 +120,10 @@ export function BillingSettings({ newUser }: { newUser: boolean }) { Cookies.set("new_auth_user", "false"); }; - if (newSeats === null || currentPlan === undefined) { return null; } - + return (
{newUser && isNewUserOpen && ( @@ -143,8 +138,8 @@ export function BillingSettings({ newUser }: { newUser: boolean }) {
- We're thrilled to have you on board! Here, you can manage your - billing settings and explore your plan details. + We're thrilled to have you on board! Here, you can manage + your billing settings and explore your plan details.
diff --git a/web/src/app/ee/admin/plan/StripeCheckoutButton.tsx b/web/src/app/ee/admin/plan/StripeCheckoutButton.tsx index 90d579a030e..604596ddf18 100644 --- a/web/src/app/ee/admin/plan/StripeCheckoutButton.tsx +++ b/web/src/app/ee/admin/plan/StripeCheckoutButton.tsx @@ -51,10 +51,11 @@ export function StripeCheckoutButton({ return ( diff --git a/web/src/app/ee/layout.tsx b/web/src/app/ee/layout.tsx index 86fe042c23d..fadbd592f4d 100644 --- a/web/src/app/ee/layout.tsx +++ b/web/src/app/ee/layout.tsx @@ -16,4 +16,4 @@ export default async function AdminLayout({ } return children; -} \ No newline at end of file +} diff --git a/web/src/middleware.ts b/web/src/middleware.ts index 9332e82cda6..8e08eee0689 100644 --- a/web/src/middleware.ts +++ b/web/src/middleware.ts @@ -13,7 +13,6 @@ const eePaths = [ "/admin/performance/custom-analytics/:path*", "/admin/standard-answer/:path*", "/admin/plan/:path*", - ]; // removes the "/:path*" from the end