Skip to content

Commit

Permalink
k
Browse files Browse the repository at this point in the history
  • Loading branch information
pablonyx committed Dec 15, 2024
1 parent 4f5a2b4 commit 537bf21
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 19 deletions.
2 changes: 2 additions & 0 deletions backend/onyx/configs/app_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@
except ValueError:
POSTGRES_IDLE_SESSIONS_TIMEOUT = POSTGRES_IDLE_SESSIONS_TIMEOUT_DEFAULT

# Opt-in flag read from the USE_IAM_AUTH environment variable; it is True only
# when the variable is set to "true" (case-insensitive) and defaults to False.
USE_IAM_AUTH = os.getenv("USE_IAM_AUTH", "False").lower() == "true"

# REDIS_SSL is True only when the env var is "true" (case-insensitive).
REDIS_SSL = os.getenv("REDIS_SSL", "").lower() == "true"
# An unset or empty REDIS_HOST falls back to "localhost".
REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
# Parsed as an int; 6379 is used when REDIS_PORT is unset.
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
Expand Down
97 changes: 78 additions & 19 deletions backend/onyx/db/engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import contextlib
import os
import re
import subprocess
import threading
import time
from collections.abc import AsyncGenerator
Expand Down Expand Up @@ -49,6 +51,9 @@
SYNC_DB_API = "psycopg2"
ASYNC_DB_API = "asyncpg"

# When True, build_connection_string() fetches an IAM auth token and uses it in
# place of POSTGRES_PASSWORD. True only when the USE_IAM_AUTH environment
# variable is "true" (case-insensitive); defaults to False.
# NOTE(review): the same expression is also defined in
# backend/onyx/configs/app_configs.py — consider importing it from there so the
# two cannot drift apart.
USE_IAM_AUTH = os.getenv("USE_IAM_AUTH", "False").lower() == "true"

# global so we don't create more than one engine per process
# outside of being best practice, this is needed so we can properly pool
# connections and not create a new pool on every request
Expand Down Expand Up @@ -127,6 +132,73 @@ def is_valid_schema_name(name: str) -> bool:
return SCHEMA_NAME_REGEX.match(name) is not None


def get_iam_auth_token(
    host: str, port: str, user: str, region: str = "us-west-2"  # Change as needed
) -> str:
    """
    Generate an IAM authentication token using the AWS CLI.

    Runs ``aws rds generate-db-auth-token`` as a subprocess, so the AWS CLI
    must be installed and AWS credentials must be configured in the
    environment. In production, prefer using boto3 for proper error handling.

    Args:
        host: Database hostname passed to ``--hostname``.
        port: Database port passed to ``--port`` (converted with ``str()``).
        user: Database username passed to ``--username``.
        region: AWS region passed to ``--region``.

    Returns:
        The command's stdout with surrounding whitespace stripped.

    Raises:
        HTTPException: (status 500) if the command exits non-zero or cannot
            be started at all (for example, the ``aws`` executable is missing).
    """
    # Passed as an argument list (no shell), so host/user/region values are
    # never interpreted by a shell.
    cmd = [
        "aws",
        "rds",
        "generate-db-auth-token",
        "--hostname",
        host,
        "--port",
        str(port),
        "--username",
        user,
        "--region",
        region,
    ]
    try:
        # check_output only captures stdout by default; capture stderr as well
        # so the failure reason is available for the log message below.
        raw_output = subprocess.check_output(cmd, text=True, stderr=subprocess.PIPE)
    except subprocess.CalledProcessError as e:
        logger.error(f"Error generating IAM auth token: {e.stderr or e.output}")
        raise HTTPException(
            status_code=500, detail="Failed to generate IAM auth token"
        ) from e
    except OSError as e:
        # Raised when the process cannot be started, e.g. FileNotFoundError
        # because the `aws` executable is not on PATH.
        logger.error(f"Error generating IAM auth token: {e}")
        raise HTTPException(
            status_code=500, detail="Failed to generate IAM auth token"
        ) from e
    return raw_output.strip()


def build_connection_string(
    *,
    db_api: str = ASYNC_DB_API,
    user: str = POSTGRES_USER,
    password: str = POSTGRES_PASSWORD,
    host: str = POSTGRES_HOST,
    port: str = POSTGRES_PORT,
    db: str = POSTGRES_DB,
    app_name: str | None = None,
    use_iam: bool = USE_IAM_AUTH,
    region: str = "us-west-2",
) -> str:
    """
    Build a connection string that supports both password and IAM auth modes.

    If ``use_iam`` is True, a token is fetched via ``get_iam_auth_token`` and
    used in place of ``password`` (which is ignored), and ``sslmode=require``
    is added to the query string.

    Args:
        db_api: Driver name placed after ``postgresql+`` in the URL scheme.
        user: Database username.
        password: Database password; used verbatim, only when ``use_iam`` is False.
        host: Database hostname.
        port: Database port.
        db: Database name.
        app_name: If truthy, appended as the ``application_name`` query parameter.
        use_iam: Whether to authenticate with an IAM token instead of ``password``.
        region: AWS region forwarded to ``get_iam_auth_token``.

    Returns:
        A URL of the form
        ``postgresql+<db_api>://<user>:<secret>@<host>:<port>/<db>[?<params>]``.

    NOTE(review): a comment in get_sqlalchemy_async_engine states that asyncpg
    cannot accept application_name in the connection string — confirm before
    passing ``app_name`` together with the async driver.
    """
    from urllib.parse import quote

    query_params: list[str] = []
    if use_iam:
        # Generate a token and use that as the password
        token = get_iam_auth_token(host=host, port=port, user=user, region=region)
        # The token is opaque text embedded in the password slot of a URL, so
        # any reserved URL character in it must be percent-encoded or the URL
        # will be split at the wrong place.
        secret = quote(token, safe="")
        # Include SSL mode requirement for IAM auth
        # SSL root cert can be included if needed: sslrootcert=rds-combined-ca-bundle.pem
        # This example uses sslmode=require for simplicity.
        query_params.append("sslmode=require")
    else:
        secret = password

    if app_name:
        query_params.append(f"application_name={app_name}")

    conn_str = f"postgresql+{db_api}://{user}:{secret}@{host}:{port}/{db}"
    # Track query parameters explicitly rather than searching the URL for "?",
    # which would misfire if a credential happened to contain that character.
    if query_params:
        conn_str = f"{conn_str}?{'&'.join(query_params)}"
    return conn_str


class SqlEngine:
"""Class to manage a global SQLAlchemy engine (needed for proper resource control).
Will eventually subsume most of the standalone functions in this file.
Expand All @@ -152,7 +224,7 @@ def __init__(self) -> None:
def _init_engine(cls, **engine_kwargs: Any) -> Engine:
"""Private helper method to create and return an Engine."""
connection_string = build_connection_string(
db_api=SYNC_DB_API, app_name=cls._app_name + "_sync"
db_api=SYNC_DB_API, app_name=cls._app_name + "_sync", use_iam=USE_IAM_AUTH
)
merged_kwargs = {**cls.DEFAULT_ENGINE_KWARGS, **engine_kwargs}
return create_engine(connection_string, **merged_kwargs)
Expand Down Expand Up @@ -221,31 +293,18 @@ def get_all_tenant_ids() -> list[str] | list[None]:
return valid_tenants


def build_connection_string(
    *,
    db_api: str = ASYNC_DB_API,
    user: str = POSTGRES_USER,
    password: str = POSTGRES_PASSWORD,
    host: str = POSTGRES_HOST,
    port: str = POSTGRES_PORT,
    db: str = POSTGRES_DB,
    app_name: str | None = None,
) -> str:
    """Build a password-authenticated Postgres connection URL.

    The URL has the form ``postgresql+<db_api>://<user>:<password>@<host>:<port>/<db>``.
    When ``app_name`` is truthy it is appended as the ``application_name``
    query parameter.
    """
    base_url = f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
    if not app_name:
        return base_url
    return f"{base_url}?application_name={app_name}"


def get_sqlalchemy_engine() -> Engine:
    """Return the synchronous engine provided by ``SqlEngine.get_engine()``."""
    sync_engine = SqlEngine.get_engine()
    return sync_engine


def get_sqlalchemy_async_engine() -> AsyncEngine:
global _ASYNC_ENGINE
if _ASYNC_ENGINE is None:
# Underlying asyncpg cannot accept application_name directly in the connection string
# https://github.com/MagicStack/asyncpg/issues/798
connection_string = build_connection_string()
connection_string = build_connection_string(
db_api=ASYNC_DB_API,
app_name=SqlEngine.get_app_name() + "_async",
use_iam=USE_IAM_AUTH,
)
_ASYNC_ENGINE = create_async_engine(
connection_string,
connect_args={
Expand Down

0 comments on commit 537bf21

Please sign in to comment.