diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index aa5054983fa..0cb4ae2326c 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -93,9 +93,9 @@ from danswer.utils.telemetry import optional_telemetry from danswer.utils.telemetry import RecordType from danswer.utils.variable_functionality import fetch_versioned_implementation -from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR from shared_configs.configs import MULTI_TENANT from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA +from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR logger = setup_logger() @@ -510,19 +510,23 @@ async def get_user_manager( # This strategy is used to add tenant_id to the JWT token class TenantAwareJWTStrategy(JWTStrategy): - async def write_token(self, user: User) -> str: + async def _create_token_data(self, user: User, impersonate: bool = False) -> dict: tenant_id = get_tenant_id_for_email(user.email) data = { "sub": str(user.id), "aud": self.token_audience, "tenant_id": tenant_id, } + return data + + async def write_token(self, user: User) -> str: + data = await self._create_token_data(user) return generate_jwt( data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm ) -def get_jwt_strategy() -> JWTStrategy: +def get_jwt_strategy() -> TenantAwareJWTStrategy: return TenantAwareJWTStrategy( secret=USER_AUTH_SECRET, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS, diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index fb6b4996363..b76ad6042d1 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -478,3 +478,7 @@ # JWT configuration JWT_ALGORITHM = "HS256" + +# Super Users +SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]')) +SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key") diff --git a/backend/danswer/connectors/file/connector.py b/backend/danswer/connectors/file/connector.py index d07a224478e..13744f02b2f 100644 --- a/backend/danswer/connectors/file/connector.py +++ b/backend/danswer/connectors/file/connector.py @@ -27,8 +27,8 @@ from danswer.file_processing.extract_file_text import read_text_file from danswer.file_store.file_store import get_default_file_store from danswer.utils.logger import setup_logger -from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA +from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR logger = setup_logger() diff --git a/backend/danswer/danswerbot/slack/listener.py b/backend/danswer/danswerbot/slack/listener.py index a40dbe9a9b9..2078d621325 100644 --- a/backend/danswer/danswerbot/slack/listener.py +++ b/backend/danswer/danswerbot/slack/listener.py @@ -57,10 +57,10 @@ from danswer.server.manage.models import SlackBotTokens from danswer.utils.logger import setup_logger from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable -from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR from shared_configs.configs import MODEL_SERVER_HOST from shared_configs.configs import MODEL_SERVER_PORT from shared_configs.configs import SLACK_CHANNEL_ID +from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR logger = setup_logger() diff --git a/backend/danswer/db/engine.py b/backend/danswer/db/engine.py index a1fbbddc65f..9dc6024ab8c 100644 --- a/backend/danswer/db/engine.py +++ b/backend/danswer/db/engine.py @@ -37,10 +37,10 @@ from danswer.configs.app_configs import USER_AUTH_SECRET from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME from danswer.utils.logger import setup_logger -from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR from shared_configs.configs import MULTI_TENANT from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA from shared_configs.configs import TENANT_ID_PREFIX +from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR logger = setup_logger() diff --git a/backend/danswer/key_value_store/store.py b/backend/danswer/key_value_store/store.py index 5fec7dea40c..8e60dab88b1 100644 --- a/backend/danswer/key_value_store/store.py +++ b/backend/danswer/key_value_store/store.py @@ -16,9 +16,9 @@ from danswer.key_value_store.interface import KvKeyNotFoundError from danswer.redis.redis_pool import get_redis_client from danswer.utils.logger import setup_logger -from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR from shared_configs.configs import MULTI_TENANT from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA +from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR logger = setup_logger() diff --git a/backend/danswer/server/auth_check.py b/backend/danswer/server/auth_check.py index 69aede4241f..4300bc464cb 100644 --- a/backend/danswer/server/auth_check.py +++ b/backend/danswer/server/auth_check.py @@ -10,6 +10,7 @@ from danswer.auth.users import current_user_with_expired_token from danswer.configs.app_configs import APP_API_PREFIX from danswer.server.danswer_api.ingestion import api_key_dep +from ee.danswer.auth.users import current_cloud_superuser from ee.danswer.server.tenants.access import control_plane_dep @@ -100,6 +101,7 @@ def check_router_auth( or depends_fn == api_key_dep or depends_fn == current_user_with_expired_token or depends_fn == control_plane_dep + or depends_fn == current_cloud_superuser ): found_auth = True break diff --git a/backend/danswer/server/manage/models.py b/backend/danswer/server/manage/models.py index 6e021b562ff..e24b96e9a1e 100644 --- a/backend/danswer/server/manage/models.py +++ b/backend/danswer/server/manage/models.py @@ -57,6 +57,7 @@ class UserInfo(BaseModel): oidc_expiry: datetime | None = None current_token_created_at: datetime | None = None current_token_expiry_length: int | None = None + is_cloud_superuser: bool = False organization_name: str | None = None @classmethod @@ -65,6 +66,7 @@ def from_model( user: User, current_token_created_at: datetime | None = None, expiry_length: int | None = None, + is_cloud_superuser: bool = False, organization_name: str | None = None, ) -> "UserInfo": return cls( @@ -90,6 +92,7 @@ def from_model( oidc_expiry=user.oidc_expiry if TRACK_EXTERNAL_IDP_EXPIRY else None, current_token_created_at=current_token_created_at, current_token_expiry_length=expiry_length, + is_cloud_superuser=is_cloud_superuser, ) diff --git a/backend/danswer/server/manage/users.py b/backend/danswer/server/manage/users.py index f0675c39282..9e385908fa4 100644 --- a/backend/danswer/server/manage/users.py +++ b/backend/danswer/server/manage/users.py @@ -35,6 +35,7 @@ from danswer.configs.app_configs import AUTH_TYPE from danswer.configs.app_configs import ENABLE_EMAIL_INVITES from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS +from danswer.configs.app_configs import SUPER_USERS from danswer.configs.app_configs import VALID_EMAIL_DOMAINS from danswer.configs.constants import AuthType from danswer.db.auth import get_total_users_count @@ -476,6 +477,7 @@ def verify_user_logged_in( # NOTE: this does not use `current_user` / `current_admin_user` because we don't want # to enforce user verification here - the frontend always wants to get the info about # the current user regardless of if they are currently verified + if user is None: # if auth type is disabled, return a dummy user with preferences from # the key-value store @@ -502,6 +504,7 @@ def verify_user_logged_in( user, current_token_created_at=token_created_at, expiry_length=SESSION_EXPIRE_TIME_SECONDS, + is_cloud_superuser=user.email in SUPER_USERS, organization_name=organization_name, ) diff --git a/backend/danswer/server/query_and_chat/token_limit.py b/backend/danswer/server/query_and_chat/token_limit.py index ec94e2ece4d..d439e15a379 100644 --- a/backend/danswer/server/query_and_chat/token_limit.py +++ b/backend/danswer/server/query_and_chat/token_limit.py @@ -21,7 +21,7 @@ from danswer.utils.logger import setup_logger from danswer.utils.variable_functionality import fetch_versioned_implementation from ee.danswer.db.token_limit import fetch_all_global_token_rate_limits -from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR +from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR logger = setup_logger() diff --git a/backend/danswer/utils/logger.py b/backend/danswer/utils/logger.py index bc872b0c423..bd784513898 100644 --- a/backend/danswer/utils/logger.py +++ b/backend/danswer/utils/logger.py @@ -5,7 +5,6 @@ from logging.handlers import RotatingFileHandler from typing import Any -from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR from shared_configs.configs import DEV_LOGGING_ENABLED from shared_configs.configs import LOG_FILE_NAME from shared_configs.configs import LOG_LEVEL @@ -13,6 +12,7 @@ from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA from shared_configs.configs import SLACK_CHANNEL_ID from shared_configs.configs import TENANT_ID_PREFIX +from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR logging.addLevelName(logging.INFO + 5, "NOTICE") diff --git a/backend/ee/danswer/auth/users.py b/backend/ee/danswer/auth/users.py index 18dff6ab064..1ad384555c1 100644 --- a/backend/ee/danswer/auth/users.py +++ b/backend/ee/danswer/auth/users.py @@ -1,9 +1,13 @@ from fastapi import Depends from fastapi import HTTPException from fastapi import Request +from fastapi import status from sqlalchemy.orm import Session +from danswer.auth.users import current_admin_user from danswer.configs.app_configs import AUTH_TYPE +from danswer.configs.app_configs import SUPER_CLOUD_API_KEY +from danswer.configs.app_configs import SUPER_USERS from danswer.configs.constants import AuthType from danswer.db.engine import get_session from danswer.db.models import User @@ -68,3 +72,19 @@ def get_default_admin_user_emails_() -> list[str]: if seed_config and seed_config.admin_user_emails: return seed_config.admin_user_emails return [] + + +async def current_cloud_superuser( + request: Request, + user: User | None = Depends(current_admin_user), +) -> User | None: + api_key = request.headers.get("Authorization", "").replace("Bearer ", "") + if api_key != SUPER_CLOUD_API_KEY: + raise HTTPException(status_code=401, detail="Invalid API key") + + if user and user.email not in SUPER_USERS: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Access denied. User must be a cloud superuser to perform this action.", + ) + return user diff --git a/backend/ee/danswer/background/celery/apps/primary.py b/backend/ee/danswer/background/celery/apps/primary.py index b9929068688..fecc21b58ef 100644 --- a/backend/ee/danswer/background/celery/apps/primary.py +++ b/backend/ee/danswer/background/celery/apps/primary.py @@ -28,8 +28,8 @@ run_external_group_permission_sync, ) from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report -from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR from shared_configs.configs import MULTI_TENANT +from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR logger = setup_logger() diff --git a/backend/ee/danswer/db/query_history.py b/backend/ee/danswer/db/query_history.py index 868afef23ce..b6a79cb7727 100644 --- a/backend/ee/danswer/db/query_history.py +++ b/backend/ee/danswer/db/query_history.py @@ -29,6 +29,7 @@ def fetch_chat_sessions_eagerly_by_time( filters: list[ColumnElement | BinaryExpression] = [ ChatSession.time_created.between(start, end) ] + if initial_id: filters.append(ChatSession.id < initial_id) subquery = ( diff --git a/backend/ee/danswer/server/middleware/tenant_tracking.py b/backend/ee/danswer/server/middleware/tenant_tracking.py index f9fe75425e0..2ab9c946df5 100644 --- a/backend/ee/danswer/server/middleware/tenant_tracking.py +++ b/backend/ee/danswer/server/middleware/tenant_tracking.py @@ -14,6 +14,7 @@ from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR from shared_configs.configs import MULTI_TENANT from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA +from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> None: @@ -22,11 +23,11 @@ async def set_tenant_id( request: Request, call_next: Callable[[Request], Awaitable[Response]] ) -> Response: try: - tenant_id = POSTGRES_DEFAULT_SCHEMA - - if MULTI_TENANT: - tenant_id = _get_tenant_id_from_request(request, logger) - + tenant_id = ( + _get_tenant_id_from_request(request, logger) + if MULTI_TENANT + else POSTGRES_DEFAULT_SCHEMA + ) CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) return await call_next(request) diff --git a/backend/ee/danswer/server/tenants/api.py b/backend/ee/danswer/server/tenants/api.py index 66485975f31..05d66b4c582 100644 --- a/backend/ee/danswer/server/tenants/api.py +++ b/backend/ee/danswer/server/tenants/api.py @@ -2,29 +2,36 @@ from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException +from fastapi import Response +from danswer.auth.users import auth_backend from danswer.auth.users import current_admin_user +from danswer.auth.users import get_jwt_strategy +from danswer.auth.users import get_tenant_id_for_email from danswer.auth.users import User from danswer.configs.app_configs import WEB_DOMAIN from danswer.db.engine import get_session_with_tenant from danswer.db.notification import create_notification +from danswer.db.users import get_user_by_email from danswer.server.settings.store import load_settings from danswer.server.settings.store import store_settings from danswer.setup import setup_danswer from danswer.utils.logger import setup_logger +from ee.danswer.auth.users import current_cloud_superuser from ee.danswer.configs.app_configs import STRIPE_SECRET_KEY from ee.danswer.server.tenants.access import control_plane_dep from ee.danswer.server.tenants.billing import fetch_billing_information from ee.danswer.server.tenants.billing import fetch_tenant_stripe_information from ee.danswer.server.tenants.models import BillingInformation from ee.danswer.server.tenants.models import CreateTenantRequest +from ee.danswer.server.tenants.models import ImpersonateRequest from ee.danswer.server.tenants.models import ProductGatingRequest from ee.danswer.server.tenants.provisioning import add_users_to_tenant from ee.danswer.server.tenants.provisioning import ensure_schema_exists from ee.danswer.server.tenants.provisioning import run_alembic_migrations from ee.danswer.server.tenants.provisioning import user_owns_a_tenant -from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR from shared_configs.configs import MULTI_TENANT +from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR stripe.api_key = STRIPE_SECRET_KEY @@ -132,3 +139,30 @@ async def create_customer_portal_session(_: User = Depends(current_admin_user)) except Exception as e: logger.exception("Failed to create customer portal session") raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/impersonate") +async def impersonate_user( + impersonate_request: ImpersonateRequest, + _: User = Depends(current_cloud_superuser), +) -> Response: + """Allows a cloud superuser to impersonate another user by generating an impersonation JWT token""" + tenant_id = get_tenant_id_for_email(impersonate_request.email) + + with get_session_with_tenant(tenant_id) as tenant_session: + user_to_impersonate = get_user_by_email( + impersonate_request.email, tenant_session + ) + if user_to_impersonate is None: + raise HTTPException(status_code=404, detail="User not found") + token = await get_jwt_strategy().write_token(user_to_impersonate) + + response = await auth_backend.transport.get_login_response(token) + response.set_cookie( + key="fastapiusersauth", + value=token, + httponly=True, + secure=True, + samesite="lax", + ) + return response diff --git a/backend/ee/danswer/server/tenants/models.py b/backend/ee/danswer/server/tenants/models.py index 30f656c0824..2c1fdbecdb3 100644 --- a/backend/ee/danswer/server/tenants/models.py +++ b/backend/ee/danswer/server/tenants/models.py @@ -29,3 +29,7 @@ class BillingInformation(BaseModel): class CheckoutSessionCreationResponse(BaseModel): id: str + + +class ImpersonateRequest(BaseModel): + email: str diff --git a/backend/shared_configs/configs.py b/backend/shared_configs/configs.py index 5f25deedfca..d4378251aa5 100644 --- a/backend/shared_configs/configs.py +++ b/backend/shared_configs/configs.py @@ -1,4 +1,3 @@ -import contextvars import os from typing import List from urllib.parse import urlparse @@ -134,10 +133,6 @@ def validate_cors_origin(origin: str) -> None: POSTGRES_DEFAULT_SCHEMA = os.environ.get("POSTGRES_DEFAULT_SCHEMA") or "public" -CURRENT_TENANT_ID_CONTEXTVAR = contextvars.ContextVar( - "current_tenant_id", default=POSTGRES_DEFAULT_SCHEMA -) - # Prefix used for all tenant ids TENANT_ID_PREFIX = "tenant_" diff --git a/backend/shared_configs/contextvars.py b/backend/shared_configs/contextvars.py new file mode 100644 index 00000000000..df66b141c6e --- /dev/null +++ b/backend/shared_configs/contextvars.py @@ -0,0 +1,8 @@ +import contextvars + +from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA + +# Context variable for the current tenant id +CURRENT_TENANT_ID_CONTEXTVAR = contextvars.ContextVar( + "current_tenant_id", default=POSTGRES_DEFAULT_SCHEMA +) diff --git a/web/src/app/auth/impersonate/page.tsx b/web/src/app/auth/impersonate/page.tsx new file mode 100644 index 00000000000..1a2c77d2cdb --- /dev/null +++ b/web/src/app/auth/impersonate/page.tsx @@ -0,0 +1,132 @@ +"use client"; +import AuthFlowContainer from "@/components/auth/AuthFlowContainer"; +import { HealthCheckBanner } from "@/components/health/healthcheck"; +import { useUser } from "@/components/user/UserProvider"; +import { redirect, useRouter } from "next/navigation"; +import { Formik, Form, Field } from "formik"; +import * as Yup from "yup"; +import { usePopup } from "@/components/admin/connectors/Popup"; + +const ImpersonateSchema = Yup.object().shape({ + email: Yup.string().email("Invalid email").required("Required"), + apiKey: Yup.string().required("Required"), +}); + +export default function ImpersonatePage() { + const router = useRouter(); + const { user, isLoadingUser, isCloudSuperuser } = useUser(); + const { popup, setPopup } = usePopup(); + + if (isLoadingUser) { + return null; + } + + if (!user) { + redirect("/auth/login"); + } + + if (!isCloudSuperuser) { + redirect("/search"); + } + + const handleImpersonate = async (values: { + email: string; + apiKey: string; + }) => { + try { + const response = await fetch("/api/tenants/impersonate", { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${values.apiKey}`, + }, + body: JSON.stringify({ email: values.email }), + credentials: "same-origin", + }); + + if (!response.ok) { + const errorData = await response.json(); + setPopup({ + message: errorData.detail || "Failed to impersonate user", + type: "error", + }); + } else { + router.push("/search"); + } + } catch (error) { + setPopup({ + message: + error instanceof Error ? error.message : "Failed to impersonate user", + type: "error", + }); + } + }; + + return ( + + {popup} +
+ +
+ +
+

+ Impersonate User +

+ + + {({ errors, touched }) => ( +
+
+ +
+ {errors.email && touched.email && ( +
+ {errors.email} +
+ )} +
+
+ +
+ +
+ {errors.apiKey && touched.apiKey && ( +
+ {errors.apiKey} +
+ )} +
+
+ + +
+ )} +
+ +
+ Note: This feature is only available for @danswer.ai administrators +
+
+
+ ); +} diff --git a/web/src/app/auth/login/LoginText.tsx b/web/src/app/auth/login/LoginText.tsx index e31aeb81321..7b5eb97fb8e 100644 --- a/web/src/app/auth/login/LoginText.tsx +++ b/web/src/app/auth/login/LoginText.tsx @@ -5,11 +5,6 @@ import { SettingsContext } from "@/components/settings/SettingsProvider"; export const LoginText = () => { const settings = useContext(SettingsContext); - - // if (!settings) { - // throw new Error("SettingsContext is not available"); - // } - return ( <> Log In to{" "} diff --git a/web/src/components/auth/AuthFlowContainer.tsx b/web/src/components/auth/AuthFlowContainer.tsx index 3be441a0a7b..c6790028088 100644 --- a/web/src/components/auth/AuthFlowContainer.tsx +++ b/web/src/components/auth/AuthFlowContainer.tsx @@ -7,7 +7,7 @@ export default function AuthFlowContainer({ }) { return (
-
+
{children}
diff --git a/web/src/components/user/UserProvider.tsx b/web/src/components/user/UserProvider.tsx index 5a3e7cbc099..48ea8826f22 100644 --- a/web/src/components/user/UserProvider.tsx +++ b/web/src/components/user/UserProvider.tsx @@ -11,6 +11,7 @@ interface UserContextType { isAdmin: boolean; isCurator: boolean; refreshUser: () => Promise; + isCloudSuperuser: boolean; } const UserContext = createContext(undefined); @@ -67,6 +68,7 @@ export function UserProvider({ refreshUser, isAdmin: upToDateUser?.role === UserRole.ADMIN, isCurator: upToDateUser?.role === UserRole.CURATOR, + isCloudSuperuser: upToDateUser?.is_cloud_superuser ?? false, }} > {children} diff --git a/web/src/lib/types.ts b/web/src/lib/types.ts index 25d98a1df38..f532a09ee31 100644 --- a/web/src/lib/types.ts +++ b/web/src/lib/types.ts @@ -42,6 +42,7 @@ export interface User { current_token_created_at?: Date; current_token_expiry_length?: number; oidc_expiry?: Date; + is_cloud_superuser?: boolean; organization_name: string | null; }