From cba2284c54af8e2c0105353c4483913973ba6312 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Sat, 28 Sep 2024 10:29:15 -0700 Subject: [PATCH] quick typing fix --- backend/danswer/configs/app_configs.py | 4 +++ backend/danswer/configs/constants.py | 1 + backend/danswer/db/engine.py | 45 ++++++++++++++++++-------- 3 files changed, 36 insertions(+), 14 deletions(-) diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index 4ccc3c0bafc..a1248461db6 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -388,3 +388,7 @@ ENTERPRISE_EDITION_ENABLED = ( os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true" ) + + +MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true" +SECRET_JWT_KEY = os.environ.get("SECRET_JWT_KEY", "") diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index 52314db920c..3580817f3e2 100644 --- a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -39,6 +39,7 @@ POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy" POSTGRES_PERMISSIONS_APP_NAME = "permissions" POSTGRES_UNKNOWN_APP_NAME = "unknown" +POSTGRES_DEFAULT_SCHEMA = "public" # API Keys DANSWER_API_KEY_PREFIX = "API_KEY__" diff --git a/backend/danswer/db/engine.py b/backend/danswer/db/engine.py index 204dd165567..48b029d5b22 100644 --- a/backend/danswer/db/engine.py +++ b/backend/danswer/db/engine.py @@ -34,7 +34,7 @@ from danswer.configs.app_configs import POSTGRES_PORT from danswer.configs.app_configs import POSTGRES_USER from danswer.configs.app_configs import SECRET_JWT_KEY -from danswer.configs.constants import DEFAULT_SCHEMA +from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME from danswer.utils.logger import setup_logger @@ -47,11 +47,10 @@ # global so we don't create more than one engine per process # outside of being best practice, this is needed so we can properly pool # connections and not create a new pool on every request -_ASYNC_ENGINE: AsyncEngine | None = None +_ASYNC_ENGINE: AsyncEngine | None = None SessionFactory: sessionmaker[Session] | None = None - if LOG_POSTGRES_LATENCY: # Function to log before query execution @event.listens_for(Engine, "before_cursor_execute") @@ -187,17 +186,13 @@ def get_app_name(cls) -> str: return cls._app_name -# Global variable for async engine -_ASYNC_ENGINE: AsyncEngine | None = None - - def build_connection_string( *, db_api: str = ASYNC_DB_API, user: str = POSTGRES_USER, password: str = POSTGRES_PASSWORD, host: str = POSTGRES_HOST, - port: int = POSTGRES_PORT, + port: str = POSTGRES_PORT, db: str = POSTGRES_DB, app_name: str | None = None, ) -> str: @@ -236,21 +231,23 @@ def get_sqlalchemy_async_engine() -> AsyncEngine: # Context variable to store the current tenant ID -current_tenant_id = contextvars.ContextVar("current_tenant_id", default=DEFAULT_SCHEMA) +current_tenant_id = contextvars.ContextVar( + "current_tenant_id", default=POSTGRES_DEFAULT_SCHEMA +) # Dependency to get the current tenant ID and set the context variable def get_current_tenant_id(request: Request) -> str: """Dependency that extracts the tenant ID from the JWT token in the request and sets the context variable.""" if not MULTI_TENANT: - tenant_id = DEFAULT_SCHEMA + tenant_id = POSTGRES_DEFAULT_SCHEMA current_tenant_id.set(tenant_id) return tenant_id token = request.cookies.get("tenant_details") if not token: # If no token is present, use the default schema or handle accordingly - tenant_id = DEFAULT_SCHEMA + tenant_id = POSTGRES_DEFAULT_SCHEMA current_tenant_id.set(tenant_id) return tenant_id @@ -304,12 +301,32 @@ def get_session_context_manager() -> ContextManager[Session]: return contextlib.contextmanager(get_session)() -SessionFactory: sessionmaker[Session] | None = None - - def get_session_factory() -> sessionmaker[Session]: """Get a session factory.""" global SessionFactory if SessionFactory is None: SessionFactory = sessionmaker(bind=get_sqlalchemy_engine()) return SessionFactory + + +async def warm_up_connections( + sync_connections_to_warm_up: int = 20, async_connections_to_warm_up: int = 20 +) -> None: + sync_postgres_engine = get_sqlalchemy_engine() + connections = [ + sync_postgres_engine.connect() for _ in range(sync_connections_to_warm_up) + ] + for conn in connections: + conn.execute(text("SELECT 1")) + for conn in connections: + conn.close() + + async_postgres_engine = get_sqlalchemy_async_engine() + async_connections = [ + await async_postgres_engine.connect() + for _ in range(async_connections_to_warm_up) + ] + for async_conn in async_connections: + await async_conn.execute(text("SELECT 1")) + for async_conn in async_connections: + await async_conn.close()