Skip to content

Commit

Permalink
add super user
Browse files Browse the repository at this point in the history
  • Loading branch information
pablonyx committed Nov 2, 2024
1 parent 5d9b836 commit ec8ae2b
Show file tree
Hide file tree
Showing 24 changed files with 236 additions and 27 deletions.
10 changes: 7 additions & 3 deletions backend/danswer/auth/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR


logger = setup_logger()
Expand Down Expand Up @@ -510,19 +510,23 @@ async def get_user_manager(

# This strategy is used to add tenant_id to the JWT token
class TenantAwareJWTStrategy(JWTStrategy):
async def write_token(self, user: User) -> str:
async def _create_token_data(self, user: User, impersonate: bool = False) -> dict:
tenant_id = get_tenant_id_for_email(user.email)
data = {
"sub": str(user.id),
"aud": self.token_audience,
"tenant_id": tenant_id,
}
return data

async def write_token(self, user: User) -> str:
data = await self._create_token_data(user)
return generate_jwt(
data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm
)


def get_jwt_strategy() -> JWTStrategy:
def get_jwt_strategy() -> TenantAwareJWTStrategy:
return TenantAwareJWTStrategy(
secret=USER_AUTH_SECRET,
lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS,
Expand Down
4 changes: 4 additions & 0 deletions backend/danswer/configs/app_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,3 +478,7 @@

# JWT configuration
JWT_ALGORITHM = "HS256"

# Super Users
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["[email protected]"]'))
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
2 changes: 1 addition & 1 deletion backend/danswer/connectors/file/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
from danswer.file_processing.extract_file_text import read_text_file
from danswer.file_store.file_store import get_default_file_store
from danswer.utils.logger import setup_logger
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

logger = setup_logger()

Expand Down
2 changes: 1 addition & 1 deletion backend/danswer/danswerbot/slack/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@
from danswer.server.manage.models import SlackBotTokens
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import SLACK_CHANNEL_ID
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

logger = setup_logger()

Expand Down
2 changes: 1 addition & 1 deletion backend/danswer/db/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@
from danswer.configs.app_configs import USER_AUTH_SECRET
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from danswer.utils.logger import setup_logger
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

logger = setup_logger()

Expand Down
2 changes: 1 addition & 1 deletion backend/danswer/key_value_store/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

logger = setup_logger()

Expand Down
2 changes: 2 additions & 0 deletions backend/danswer/server/auth_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from danswer.auth.users import current_user_with_expired_token
from danswer.configs.app_configs import APP_API_PREFIX
from danswer.server.danswer_api.ingestion import api_key_dep
from ee.danswer.auth.users import current_cloud_superuser
from ee.danswer.server.tenants.access import control_plane_dep


Expand Down Expand Up @@ -100,6 +101,7 @@ def check_router_auth(
or depends_fn == api_key_dep
or depends_fn == current_user_with_expired_token
or depends_fn == control_plane_dep
or depends_fn == current_cloud_superuser
):
found_auth = True
break
Expand Down
3 changes: 3 additions & 0 deletions backend/danswer/server/manage/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class UserInfo(BaseModel):
oidc_expiry: datetime | None = None
current_token_created_at: datetime | None = None
current_token_expiry_length: int | None = None
is_cloud_superuser: bool = False
organization_name: str | None = None

@classmethod
Expand All @@ -65,6 +66,7 @@ def from_model(
user: User,
current_token_created_at: datetime | None = None,
expiry_length: int | None = None,
is_cloud_superuser: bool = False,
organization_name: str | None = None,
) -> "UserInfo":
return cls(
Expand All @@ -90,6 +92,7 @@ def from_model(
oidc_expiry=user.oidc_expiry if TRACK_EXTERNAL_IDP_EXPIRY else None,
current_token_created_at=current_token_created_at,
current_token_expiry_length=expiry_length,
is_cloud_superuser=is_cloud_superuser,
)


Expand Down
3 changes: 3 additions & 0 deletions backend/danswer/server/manage/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import ENABLE_EMAIL_INVITES
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.configs.app_configs import SUPER_USERS
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
from danswer.configs.constants import AuthType
from danswer.db.auth import get_total_users_count
Expand Down Expand Up @@ -476,6 +477,7 @@ def verify_user_logged_in(
# NOTE: this does not use `current_user` / `current_admin_user` because we don't want
# to enforce user verification here - the frontend always wants to get the info about
# the current user regardless of if they are currently verified

if user is None:
# if auth type is disabled, return a dummy user with preferences from
# the key-value store
Expand All @@ -502,6 +504,7 @@ def verify_user_logged_in(
user,
current_token_created_at=token_created_at,
expiry_length=SESSION_EXPIRE_TIME_SECONDS,
is_cloud_superuser=user.email in SUPER_USERS,
organization_name=organization_name,
)

Expand Down
2 changes: 1 addition & 1 deletion backend/danswer/server/query_and_chat/token_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
from ee.danswer.db.token_limit import fetch_all_global_token_rate_limits
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR


logger = setup_logger()
Expand Down
2 changes: 1 addition & 1 deletion backend/danswer/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
from logging.handlers import RotatingFileHandler
from typing import Any

from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import DEV_LOGGING_ENABLED
from shared_configs.configs import LOG_FILE_NAME
from shared_configs.configs import LOG_LEVEL
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import SLACK_CHANNEL_ID
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR


logging.addLevelName(logging.INFO + 5, "NOTICE")
Expand Down
20 changes: 20 additions & 0 deletions backend/ee/danswer/auth/users.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from fastapi import status
from sqlalchemy.orm import Session

from danswer.auth.users import current_admin_user
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import SUPER_CLOUD_API_KEY
from danswer.configs.app_configs import SUPER_USERS
from danswer.configs.constants import AuthType
from danswer.db.engine import get_session
from danswer.db.models import User
Expand Down Expand Up @@ -68,3 +72,19 @@ def get_default_admin_user_emails_() -> list[str]:
if seed_config and seed_config.admin_user_emails:
return seed_config.admin_user_emails
return []


async def current_cloud_superuser(
request: Request,
user: User | None = Depends(current_admin_user),
) -> User | None:
api_key = request.headers.get("Authorization", "").replace("Bearer ", "")
if api_key != SUPER_CLOUD_API_KEY:
raise HTTPException(status_code=401, detail="Invalid API key")

if user and user.email not in SUPER_USERS:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User must be a cloud superuser to perform this action.",
)
return user
2 changes: 1 addition & 1 deletion backend/ee/danswer/background/celery/apps/primary.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
run_external_group_permission_sync,
)
from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

logger = setup_logger()

Expand Down
1 change: 1 addition & 0 deletions backend/ee/danswer/db/query_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def fetch_chat_sessions_eagerly_by_time(
filters: list[ColumnElement | BinaryExpression] = [
ChatSession.time_created.between(start, end)
]

if initial_id:
filters.append(ChatSession.id < initial_id)
subquery = (
Expand Down
11 changes: 6 additions & 5 deletions backend/ee/danswer/server/middleware/tenant_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR


def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> None:
Expand All @@ -22,11 +23,11 @@ async def set_tenant_id(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
try:
tenant_id = POSTGRES_DEFAULT_SCHEMA

if MULTI_TENANT:
tenant_id = _get_tenant_id_from_request(request, logger)

tenant_id = (
_get_tenant_id_from_request(request, logger)
if MULTI_TENANT
else POSTGRES_DEFAULT_SCHEMA
)
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
return await call_next(request)

Expand Down
36 changes: 35 additions & 1 deletion backend/ee/danswer/server/tenants/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,36 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response

from danswer.auth.users import auth_backend
from danswer.auth.users import current_admin_user
from danswer.auth.users import get_jwt_strategy
from danswer.auth.users import get_tenant_id_for_email
from danswer.auth.users import User
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.db.engine import get_session_with_tenant
from danswer.db.notification import create_notification
from danswer.db.users import get_user_by_email
from danswer.server.settings.store import load_settings
from danswer.server.settings.store import store_settings
from danswer.setup import setup_danswer
from danswer.utils.logger import setup_logger
from ee.danswer.auth.users import current_cloud_superuser
from ee.danswer.configs.app_configs import STRIPE_SECRET_KEY
from ee.danswer.server.tenants.access import control_plane_dep
from ee.danswer.server.tenants.billing import fetch_billing_information
from ee.danswer.server.tenants.billing import fetch_tenant_stripe_information
from ee.danswer.server.tenants.models import BillingInformation
from ee.danswer.server.tenants.models import CreateTenantRequest
from ee.danswer.server.tenants.models import ImpersonateRequest
from ee.danswer.server.tenants.models import ProductGatingRequest
from ee.danswer.server.tenants.provisioning import add_users_to_tenant
from ee.danswer.server.tenants.provisioning import ensure_schema_exists
from ee.danswer.server.tenants.provisioning import run_alembic_migrations
from ee.danswer.server.tenants.provisioning import user_owns_a_tenant
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR


stripe.api_key = STRIPE_SECRET_KEY
Expand Down Expand Up @@ -132,3 +139,30 @@ async def create_customer_portal_session(_: User = Depends(current_admin_user))
except Exception as e:
logger.exception("Failed to create customer portal session")
raise HTTPException(status_code=500, detail=str(e))


@router.post("/impersonate")
async def impersonate_user(
impersonate_request: ImpersonateRequest,
_: User = Depends(current_cloud_superuser),
) -> Response:
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
tenant_id = get_tenant_id_for_email(impersonate_request.email)

with get_session_with_tenant(tenant_id) as tenant_session:
user_to_impersonate = get_user_by_email(
impersonate_request.email, tenant_session
)
if user_to_impersonate is None:
raise HTTPException(status_code=404, detail="User not found")
token = await get_jwt_strategy().write_token(user_to_impersonate)

response = await auth_backend.transport.get_login_response(token)
response.set_cookie(
key="fastapiusersauth",
value=token,
httponly=True,
secure=True,
samesite="lax",
)
return response
4 changes: 4 additions & 0 deletions backend/ee/danswer/server/tenants/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,7 @@ class BillingInformation(BaseModel):

class CheckoutSessionCreationResponse(BaseModel):
id: str


class ImpersonateRequest(BaseModel):
email: str
5 changes: 0 additions & 5 deletions backend/shared_configs/configs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import contextvars
import os
from typing import List
from urllib.parse import urlparse
Expand Down Expand Up @@ -134,10 +133,6 @@ def validate_cors_origin(origin: str) -> None:

POSTGRES_DEFAULT_SCHEMA = os.environ.get("POSTGRES_DEFAULT_SCHEMA") or "public"

CURRENT_TENANT_ID_CONTEXTVAR = contextvars.ContextVar(
"current_tenant_id", default=POSTGRES_DEFAULT_SCHEMA
)

# Prefix used for all tenant ids
TENANT_ID_PREFIX = "tenant_"

Expand Down
8 changes: 8 additions & 0 deletions backend/shared_configs/contextvars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import contextvars

from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA

# Context variable for the current tenant id
CURRENT_TENANT_ID_CONTEXTVAR = contextvars.ContextVar(
"current_tenant_id", default=POSTGRES_DEFAULT_SCHEMA
)
Loading

0 comments on commit ec8ae2b

Please sign in to comment.