
Commit

quick typing fix
pablonyx committed Sep 28, 2024
1 parent fd4fbbf commit cba2284
Showing 3 changed files with 36 additions and 14 deletions.
4 changes: 4 additions & 0 deletions backend/danswer/configs/app_configs.py
@@ -388,3 +388,7 @@
ENTERPRISE_EDITION_ENABLED = (
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
)


MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true"
SECRET_JWT_KEY = os.environ.get("SECRET_JWT_KEY", "")
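
These two settings gate the multi-tenant path in the DB layer below: MULTI_TENANT switches tenant resolution on, and SECRET_JWT_KEY signs the tenant token that engine.py later reads from the "tenant_details" cookie. As a rough illustration only (not code from this commit), such a token could be minted with PyJWT; the library, the HS256 algorithm, and the "tenant_id" claim name are all assumptions here:

import jwt  # PyJWT; an assumption, not shown in this commit

from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.app_configs import SECRET_JWT_KEY


def mint_tenant_token(tenant_id: str) -> str:
    # Hypothetical helper: sign the tenant id so the DB layer can trust the cookie value.
    if not MULTI_TENANT:
        raise RuntimeError("tenant tokens are only needed in multi-tenant mode")
    return jwt.encode({"tenant_id": tenant_id}, SECRET_JWT_KEY, algorithm="HS256")
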
1 change: 1 addition & 0 deletions backend/danswer/configs/constants.py
@@ -39,6 +39,7 @@
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
POSTGRES_UNKNOWN_APP_NAME = "unknown"
POSTGRES_DEFAULT_SCHEMA = "public"

# API Keys
DANSWER_API_KEY_PREFIX = "API_KEY__"
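
POSTGRES_DEFAULT_SCHEMA names the schema used when no tenant is resolved. In a schema-per-tenant setup (which the engine changes below suggest, though this diff does not show it), the schema is typically applied per connection via search_path; a minimal sketch, assuming SQLAlchemy and a placeholder DSN:

from sqlalchemy import create_engine, text

from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA

# Placeholder DSN for illustration only.
engine = create_engine("postgresql+psycopg2://user:pass@localhost:5432/danswer")

with engine.connect() as conn:
    # Unqualified table names resolve against this schema ("public" unless overridden).
    conn.execute(text(f'SET search_path TO "{POSTGRES_DEFAULT_SCHEMA}"'))
    conn.execute(text("SELECT 1"))
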
45 changes: 31 additions & 14 deletions backend/danswer/db/engine.py
@@ -34,7 +34,7 @@
from danswer.configs.app_configs import POSTGRES_PORT
from danswer.configs.app_configs import POSTGRES_USER
from danswer.configs.app_configs import SECRET_JWT_KEY
from danswer.configs.constants import DEFAULT_SCHEMA
from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from danswer.utils.logger import setup_logger

@@ -47,11 +47,10 @@
# global so we don't create more than one engine per process
# outside of being best practice, this is needed so we can properly pool
# connections and not create a new pool on every request
_ASYNC_ENGINE: AsyncEngine | None = None

_ASYNC_ENGINE: AsyncEngine | None = None
SessionFactory: sessionmaker[Session] | None = None


if LOG_POSTGRES_LATENCY:
# Function to log before query execution
@event.listens_for(Engine, "before_cursor_execute")
@@ -187,17 +186,13 @@ def get_app_name(cls) -> str:
return cls._app_name


# Global variable for async engine
_ASYNC_ENGINE: AsyncEngine | None = None


def build_connection_string(
*,
db_api: str = ASYNC_DB_API,
user: str = POSTGRES_USER,
password: str = POSTGRES_PASSWORD,
host: str = POSTGRES_HOST,
port: int = POSTGRES_PORT,
port: str = POSTGRES_PORT,
db: str = POSTGRES_DB,
app_name: str | None = None,
) -> str:
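
The port parameter's annotation changes from int to str here, which is the "quick typing fix" named in the commit message and matches how the value is interpolated into a DSN. The function body is collapsed in this view; an illustrative stand-in, not the actual implementation, with placeholder defaults:

def build_connection_string_sketch(
    *,
    db_api: str = "asyncpg",
    user: str = "postgres",
    password: str = "password",
    host: str = "localhost",
    port: str = "5432",  # str, matching the typing fix: the value is only interpolated into the URL
    db: str = "danswer",
    app_name: str | None = None,
) -> str:
    # Hypothetical stand-in; defaults are placeholders, not the project's real config values.
    url = f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
    if app_name:
        url += f"?application_name={app_name}"
    return url
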
@@ -236,21 +231,23 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:


# Context variable to store the current tenant ID
current_tenant_id = contextvars.ContextVar("current_tenant_id", default=DEFAULT_SCHEMA)
current_tenant_id = contextvars.ContextVar(
"current_tenant_id", default=POSTGRES_DEFAULT_SCHEMA
)


# Dependency to get the current tenant ID and set the context variable
def get_current_tenant_id(request: Request) -> str:
"""Dependency that extracts the tenant ID from the JWT token in the request and sets the context variable."""
if not MULTI_TENANT:
tenant_id = DEFAULT_SCHEMA
tenant_id = POSTGRES_DEFAULT_SCHEMA
current_tenant_id.set(tenant_id)
return tenant_id

token = request.cookies.get("tenant_details")
if not token:
# If no token is present, use the default schema or handle accordingly
tenant_id = DEFAULT_SCHEMA
tenant_id = POSTGRES_DEFAULT_SCHEMA
current_tenant_id.set(tenant_id)
return tenant_id
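
The rest of get_current_tenant_id is collapsed in this view; presumably it verifies the JWT carried in the tenant_details cookie and falls back to the default schema on failure. A hedged sketch of that shape only (PyJWT, the HS256 algorithm, and the "tenant_id" claim are assumptions, not taken from this diff; SECRET_JWT_KEY and POSTGRES_DEFAULT_SCHEMA are already imported at the top of engine.py):

import jwt  # PyJWT; an assumption, not shown in this diff


def _tenant_id_from_token(token: str) -> str:
    # Illustrative only: verify the signed cookie value, fall back to the default schema.
    try:
        payload = jwt.decode(token, SECRET_JWT_KEY, algorithms=["HS256"])
        return payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
    except jwt.InvalidTokenError:
        return POSTGRES_DEFAULT_SCHEMA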

@@ -304,12 +301,32 @@ def get_session_context_manager() -> ContextManager[Session]:
return contextlib.contextmanager(get_session)()


SessionFactory: sessionmaker[Session] | None = None


def get_session_factory() -> sessionmaker[Session]:
"""Get a session factory."""
global SessionFactory
if SessionFactory is None:
SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
return SessionFactory
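
get_session_factory lazily builds a module-level sessionmaker bound to the shared engine, so callers outside the FastAPI dependency path can open sessions from the same pool. A short usage sketch:

from sqlalchemy import text

# Usage sketch: open a short-lived session from the shared factory.
factory = get_session_factory()
with factory() as session:
    session.execute(text("SELECT 1"))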


async def warm_up_connections(
sync_connections_to_warm_up: int = 20, async_connections_to_warm_up: int = 20
) -> None:
sync_postgres_engine = get_sqlalchemy_engine()
connections = [
sync_postgres_engine.connect() for _ in range(sync_connections_to_warm_up)
]
for conn in connections:
conn.execute(text("SELECT 1"))
for conn in connections:
conn.close()

async_postgres_engine = get_sqlalchemy_async_engine()
async_connections = [
await async_postgres_engine.connect()
for _ in range(async_connections_to_warm_up)
]
for async_conn in async_connections:
await async_conn.execute(text("SELECT 1"))
for async_conn in async_connections:
await async_conn.close()
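
warm_up_connections opens and exercises a batch of sync and async connections so both pools are populated before the first real request. It would typically be awaited once at application startup; a sketch assuming a FastAPI lifespan-style hook, which this commit does not show:

from contextlib import asynccontextmanager

from fastapi import FastAPI


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Fill both connection pools before serving traffic (sketch only).
    await warm_up_connections()
    yield


app = FastAPI(lifespan=lifespan)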
