Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auth specific rate limiting #3463

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion backend/ee/onyx/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from onyx.configs.constants import AuthType
from onyx.main import get_application as get_application_base
from onyx.main import include_router_with_global_prefix_prepended
from onyx.server.middleware.rate_limiting import get_auth_rate_limiters
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import MULTI_TENANT
Expand Down Expand Up @@ -75,6 +76,7 @@ def get_application() -> FastAPI:
),
prefix="/auth/oauth",
tags=["auth"],
dependencies=get_auth_rate_limiters(),
)

# Need basic auth router for `logout` endpoint
Expand All @@ -83,6 +85,7 @@ def get_application() -> FastAPI:
fastapi_users.get_logout_router(auth_backend),
prefix="/auth",
tags=["auth"],
dependencies=get_auth_rate_limiters(),
)

if AUTH_TYPE == AuthType.OIDC:
Expand All @@ -98,6 +101,7 @@ def get_application() -> FastAPI:
),
prefix="/auth/oidc",
tags=["auth"],
dependencies=get_auth_rate_limiters(),
)

# need basic auth router for `logout` endpoint
Expand All @@ -106,10 +110,15 @@ def get_application() -> FastAPI:
fastapi_users.get_auth_router(auth_backend),
prefix="/auth",
tags=["auth"],
dependencies=get_auth_rate_limiters(),
)

elif AUTH_TYPE == AuthType.SAML:
include_router_with_global_prefix_prepended(application, saml_router)
include_router_with_global_prefix_prepended(
application,
saml_router,
dependencies=get_auth_rate_limiters(),
)

# RBAC / group access control
include_router_with_global_prefix_prepended(application, user_group_router)
Expand Down
1 change: 1 addition & 0 deletions backend/model_server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def _move_files_recursively(source: Path, dest: Path, overwrite: bool = False) -
the files in the existing huggingface cache that don't exist in the temp
huggingface cache.
"""

for item in source.iterdir():
target_path = dest / item.relative_to(source)
if item.is_dir():
Expand Down
19 changes: 19 additions & 0 deletions backend/onyx/configs/app_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,25 @@
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""

# Rate limiting for auth endpoints


RATE_LIMIT_WINDOW_SECONDS: int | None = None
_rate_limit_window_seconds_str = os.environ.get("RATE_LIMIT_WINDOW_SECONDS")
if _rate_limit_window_seconds_str is not None:
try:
RATE_LIMIT_WINDOW_SECONDS = int(_rate_limit_window_seconds_str)
except ValueError:
pass

RATE_LIMIT_MAX_REQUESTS: int | None = None
_rate_limit_max_requests_str = os.environ.get("RATE_LIMIT_MAX_REQUESTS")
if _rate_limit_max_requests_str is not None:
try:
RATE_LIMIT_MAX_REQUESTS = int(_rate_limit_max_requests_str)
except ValueError:
pass

# Used for general redis things
REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0))

Expand Down
15 changes: 15 additions & 0 deletions backend/onyx/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@
from onyx.server.manage.slack_bot import router as slack_bot_management_router
from onyx.server.manage.users import router as user_router
from onyx.server.middleware.latency_logging import add_latency_logging_middleware
from onyx.server.middleware.rate_limiting import close_limiter
from onyx.server.middleware.rate_limiting import get_auth_rate_limiters
from onyx.server.middleware.rate_limiting import setup_limiter
from onyx.server.onyx_api.ingestion import router as onyx_api_router
from onyx.server.openai_assistants_api.full_openai_assistants_api import (
get_full_openai_assistants_api_router,
Expand Down Expand Up @@ -194,8 +197,15 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
setup_multitenant_onyx()

optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})

# Set up rate limiter
await setup_limiter()

yield

# Close rate limiter
await close_limiter()


def log_http_error(_: Request, exc: Exception) -> JSONResponse:
status_code = getattr(exc, "status_code", 500)
Expand Down Expand Up @@ -288,32 +298,37 @@ def get_application() -> FastAPI:
fastapi_users.get_auth_router(auth_backend),
prefix="/auth",
tags=["auth"],
dependencies=get_auth_rate_limiters(),
)

include_router_with_global_prefix_prepended(
application,
fastapi_users.get_register_router(UserRead, UserCreate),
prefix="/auth",
tags=["auth"],
dependencies=get_auth_rate_limiters(),
)

include_router_with_global_prefix_prepended(
application,
fastapi_users.get_reset_password_router(),
prefix="/auth",
tags=["auth"],
dependencies=get_auth_rate_limiters(),
)
include_router_with_global_prefix_prepended(
application,
fastapi_users.get_verify_router(UserRead),
prefix="/auth",
tags=["auth"],
dependencies=get_auth_rate_limiters(),
)
include_router_with_global_prefix_prepended(
application,
fastapi_users.get_users_router(UserRead, UserUpdate),
prefix="/users",
tags=["users"],
dependencies=get_auth_rate_limiters(),
)

if AUTH_TYPE == AuthType.GOOGLE_OAUTH:
Expand Down
46 changes: 46 additions & 0 deletions backend/onyx/server/middleware/rate_limiting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from collections.abc import Callable
from typing import List

from fastapi import Depends
from fastapi import Request
from fastapi_limiter import FastAPILimiter
from fastapi_limiter.depends import RateLimiter
from redis import asyncio as aioredis

from onyx.configs.app_configs import RATE_LIMIT_MAX_REQUESTS
from onyx.configs.app_configs import RATE_LIMIT_WINDOW_SECONDS
from onyx.configs.app_configs import REDIS_HOST
from onyx.configs.app_configs import REDIS_PASSWORD
from onyx.configs.app_configs import REDIS_PORT


async def setup_limiter() -> None:
redis = await aioredis.from_url(
f"redis://{REDIS_HOST}:{REDIS_PORT}", password=REDIS_PASSWORD
)
await FastAPILimiter.init(redis)


async def close_limiter() -> None:
await FastAPILimiter.close()


def rate_limit_key(request: Request) -> str:
return (
request.client.host if request.client else "unknown"
) # Use IP address for unauthenticated users


# Custom rate limiter that uses the client's IP address
def get_auth_rate_limiters() -> List[Callable]:
if not (RATE_LIMIT_MAX_REQUESTS and RATE_LIMIT_WINDOW_SECONDS):
return []

return [
Depends(
RateLimiter(
times=RATE_LIMIT_MAX_REQUESTS,
seconds=RATE_LIMIT_WINDOW_SECONDS,
)
)
]
3 changes: 2 additions & 1 deletion backend/requirements/default.txt
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,5 @@ stripe==10.12.0
urllib3==2.2.3
mistune==0.8.4
sentry-sdk==2.14.0
prometheus_client==0.21.0
prometheus_client==0.21.0
fastapi-limiter==0.1.6
7 changes: 7 additions & 0 deletions web/src/app/auth/login/EmailPasswordForm.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,14 @@ export function EmailPasswordForm({
errorMsg =
"An account already exists with the specified email.";
}
if (response.status === 429) {
errorMsg = "Too many requests. Please try again later.";
}
setPopup({
type: "error",
message: `Failed to sign up - ${errorMsg}`,
});
setIsWorking(false);
return;
}
}
Expand All @@ -87,6 +91,9 @@ export function EmailPasswordForm({
} else if (errorDetail === "NO_WEB_LOGIN_AND_HAS_NO_PASSWORD") {
errorMsg = "Create an account to set a password";
}
if (loginResponse.status === 429) {
errorMsg = "Too many requests. Please try again later.";
}
setPopup({
type: "error",
message: `Failed to login - ${errorMsg}`,
Expand Down
Loading