
Commit

Return a ServiceUnavailable Error when an SQL db is not reachable, and retry connecting
chaen committed Oct 31, 2023
1 parent b726d2a commit 4f50e7d
Showing 8 changed files with 135 additions and 20 deletions.
1 change: 1 addition & 0 deletions environment.yml
@@ -8,6 +8,7 @@ dependencies:
- aiohttp
- aiomysql
- aiosqlite
- asyncache
- azure-core
- cachetools
########
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -43,6 +43,11 @@ ignore_missing_imports = true
module = 'authlib.*'
ignore_missing_imports = true

# https://github.com/hephex/asyncache/pull/18
[[tool.mypy.overrides]]
module = 'asyncache.*'
ignore_missing_imports = true

[tool.pytest.ini_options]
addopts = ["-v", "--cov=diracx", "--cov-report=term-missing"]
asyncio_mode = "auto"
1 change: 1 addition & 0 deletions setup.cfg
@@ -31,6 +31,7 @@ install_requires =
aiohttp
aiomysql
aiosqlite
asyncache
azure-core
cachetools
m2crypto >=0.38.0
8 changes: 8 additions & 0 deletions src/diracx/core/exceptions.py
@@ -9,6 +9,7 @@ def __init__(self, status_code: int, data):

class DiracError(RuntimeError):
http_status_code = status.HTTP_400_BAD_REQUEST
http_headers: dict[str, str] | None = None

def __init__(self, detail: str = "Unknown"):
self.detail = detail
@@ -42,3 +43,10 @@ class JobNotFound(Exception):
def __init__(self, job_id: int):
self.job_id: int = job_id
super().__init__(f"Job {job_id} not found")


class RouteUnavailableError(DiracError):
""" "The route is not available (bad init)"""

http_status_code = status.HTTP_503_SERVICE_UNAVAILABLE
http_headers = {"Retry-After": "10"}
34 changes: 31 additions & 3 deletions src/diracx/db/sql/utils.py
@@ -13,7 +13,8 @@

from pydantic import parse_obj_as
from sqlalchemy import Column as RawColumn
from sqlalchemy import DateTime, Enum, MetaData
from sqlalchemy import DateTime, Enum, MetaData, select
from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import expression
@@ -66,6 +67,14 @@ def EnumColumn(enum_type, **kwargs):
return Column(Enum(enum_type, native_enum=False, length=16), **kwargs)


class SQLDBError(Exception):
pass


class SQLDBConnectionError(SQLDBError):
"""Used whenever we encounter a problem with the B connection"""


class BaseSQLDB(metaclass=ABCMeta):
"""This should be the base class of all the DiracX DBs"""

@@ -146,7 +155,10 @@ async def engine_context(self) -> AsyncIterator[None]:
"""
assert self._engine is None, "engine_context cannot be nested"

engine = create_async_engine(self._db_url)
# Set the pool_recycle to 30 minutes.
# That should prevent the problem of MySQL expiring connections
# after 60 minutes by default.
engine = create_async_engine(self._db_url, pool_recycle=60 * 30)
self._engine = engine

yield
@@ -166,8 +178,12 @@ async def __aenter__(self):
This is called by the Dependency mechanism (see ``db_transaction``),
It will create a new connection/transaction for each route call.
"""
assert self._conn.get() is None, "BaseSQLDB context cannot be nested"
try:
self._conn.set(await self.engine.connect().__aenter__())
except Exception as e:
raise SQLDBConnectionError("Cannot connect to DB") from e

self._conn.set(await self.engine.connect().__aenter__())
return self

async def __aexit__(self, exc_type, exc, tb):
@@ -181,6 +197,18 @@ async def __aexit__(self, exc_type, exc, tb):
await self._conn.get().__aexit__(exc_type, exc, tb)
self._conn.set(None)

async def ping(self) -> tuple[bool, str]:
"""
Check whether the connection to the DB is still working.
We could enable the ``pre_ping`` in the engine, but this would
be run at every query.
"""
try:
await self.conn.scalar(select(1))
return True, ""
except OperationalError as e:
return False, repr(e)


def apply_search_filters(table, stmt, search):
# Apply any filters
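
To see the new ping logic in isolation, here is a small self-contained sketch. It uses an in-memory SQLite database and a standalone ping helper instead of the real MySQL URL and the BaseSQLDB class, so those names are stand-ins: a trivial SELECT 1 either succeeds or raises OperationalError, which is what BaseSQLDB.ping reports, while pool_recycle=60 * 30 asks the pool to refresh connections before MySQL's default 60-minute timeout.

import asyncio

from sqlalchemy import select
from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine


async def ping(conn: AsyncConnection) -> tuple[bool, str]:
    # Same approach as BaseSQLDB.ping: run a trivial query on demand instead of
    # enabling pre-ping on the engine, which would check before every query.
    try:
        await conn.scalar(select(1))
        return True, ""
    except OperationalError as e:
        return False, repr(e)


async def main() -> None:
    # In-memory SQLite stands in for the MySQL URL used in production.
    engine = create_async_engine("sqlite+aiosqlite://", pool_recycle=60 * 30)
    async with engine.connect() as conn:
        print(await ping(conn))  # (True, "")
    await engine.dispose()


asyncio.run(main())
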
81 changes: 64 additions & 17 deletions src/diracx/routers/__init__.py
@@ -7,6 +7,8 @@
from typing import Any, AsyncContextManager, AsyncGenerator, Iterable, TypeVar

import dotenv
from asyncache import cached
from cachetools import TTLCache
from fastapi import APIRouter, Depends, Request
from fastapi.dependencies.models import Dependant
from fastapi.middleware.cors import CORSMiddleware
@@ -15,18 +17,24 @@
from pydantic import parse_raw_as

from diracx.core.config import ConfigSource
from diracx.core.exceptions import DiracError, DiracHttpResponse
from diracx.core.exceptions import (
DiracError,
DiracHttpResponse,
RouteUnavailableError,
)
from diracx.core.extensions import select_from_extension
from diracx.core.utils import dotenv_files_from_environment
from diracx.db.os.utils import BaseOSDB
from diracx.db.sql.utils import BaseSQLDB
from diracx.db.sql.utils import BaseSQLDB, SQLDBConnectionError

from ..core.settings import ServiceSettingsBase
from .auth import verify_dirac_access_token
from .fastapi_classes import DiracFastAPI, DiracxRouter

T = TypeVar("T")
T2 = TypeVar("T2", bound=AsyncContextManager)
T3 = TypeVar("T3", bound=BaseSQLDB)


logger = logging.getLogger(__name__)

@@ -59,21 +67,33 @@ def create_app_inner(
# Override the configuration source
app.dependency_overrides[ConfigSource.create] = config_source.read_config

fail_startup = True
# Add the SQL DBs to the application
available_sql_db_classes: set[type[BaseSQLDB]] = set()
for db_name, db_url in database_urls.items():
sql_db_classes = BaseSQLDB.available_implementations(db_name)
# The first DB is the highest priority one
sql_db = sql_db_classes[0](db_url=db_url)
app.lifetime_functions.append(sql_db.engine_context)
# Add overrides for all the DB classes, including those from extensions
# This means vanilla DiracX routers get an instance of the extension's DB
for sql_db_class in sql_db_classes:
assert sql_db_class.transaction not in app.dependency_overrides
available_sql_db_classes.add(sql_db_class)
app.dependency_overrides[sql_db_class.transaction] = partial(
db_transaction, sql_db
)
try:
sql_db_classes = BaseSQLDB.available_implementations(db_name)

# The first DB is the highest priority one
sql_db = sql_db_classes[0](db_url=db_url)

app.lifetime_functions.append(sql_db.engine_context)
# Add overrides for all the DB classes, including those from extensions
# This means vanilla DiracX routers get an instance of the extension's DB
for sql_db_class in sql_db_classes:
assert sql_db_class.transaction not in app.dependency_overrides
available_sql_db_classes.add(sql_db_class)
app.dependency_overrides[sql_db_class.transaction] = partial(
db_transaction, sql_db
)

# At least one DB works, so we do not fail the startup
fail_startup = False
except Exception:
logger.exception(f"Failed to initialize DB {db_name}, {db_url}")

if fail_startup:
raise Exception("No SQL database could be initialized, aborting")

# Add the OpenSearch DBs to the application
available_os_db_classes: set[type[BaseOSDB]] = set()
@@ -199,7 +219,9 @@ def create_app() -> DiracFastAPI:

def dirac_error_handler(request: Request, exc: DiracError) -> Response:
return JSONResponse(
status_code=exc.http_status_code, content={"detail": exc.detail}
status_code=exc.http_status_code,
content={"detail": exc.detail},
headers=exc.http_headers,
)


@@ -225,9 +247,34 @@ def find_dependents(
yield from find_dependents(dependency.dependencies, cls)


_db_alive_cache: TTLCache = TTLCache(maxsize=1024, ttl=10)


@cached(_db_alive_cache)
async def is_db_alive(db: T3):
"""Cache the result of pinging the DB"""
is_alive, reason = await db.ping()
logger.debug("Pinged db %s, is_alive %s, reason %s", type(db), is_alive, reason)
if not is_alive:
raise SQLDBConnectionError(reason)


async def db_transaction(db: T2) -> AsyncGenerator[T2, None]:
async with db:
yield db
"""
Initiate a DB transaction.
:raises: RouteUnavailableError in case the connection to the DB fails
"""

# Entering the context already triggers a connection to the DB
# that may fail, hence the try/except
try:
async with db:
# Check whether the connection still works before executing the query
await is_db_alive(db)
yield db
except SQLDBConnectionError as e:
raise RouteUnavailableError(repr(e)) from e


async def db_session(db: T2) -> AsyncGenerator[T2, None]:
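
The caching around is_db_alive is what keeps the health check cheap: asyncache's cached decorator memoises an async function in a cachetools TTLCache, so each DB is pinged at most once per 10-second window no matter how many requests arrive, and a failed ping still bubbles up as SQLDBConnectionError, which db_transaction converts to RouteUnavailableError. A small standalone sketch of that pattern follows; the slow_health_check name and the sleep are illustrative, not part of DiracX.

import asyncio

from asyncache import cached
from cachetools import TTLCache

_health_cache: TTLCache = TTLCache(maxsize=16, ttl=10)


@cached(_health_cache)
async def slow_health_check(name: str) -> str:
    # Stand-in for BaseSQLDB.ping: pretend the round trip to the DB is costly.
    print(f"actually pinging {name}")
    await asyncio.sleep(0.1)
    return "alive"


async def main() -> None:
    # Only the first call within the 10 s TTL runs the body; the second call
    # returns the cached result immediately.
    print(await slow_health_check("JobDB"))
    print(await slow_health_check("JobDB"))


asyncio.run(main())
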
9 changes: 9 additions & 0 deletions tests/db/test_dummyDB.py
@@ -7,6 +7,7 @@

from diracx.core.exceptions import InvalidQueryError
from diracx.db.sql.dummy.db import DummyDB
from diracx.db.sql.utils import SQLDBConnectionError

# Each DB test class must define a fixture looking like this one.
# It allows getting an instance of an in-memory DB,
@@ -67,3 +68,11 @@ async def test_insert_and_summary(dummy_db: DummyDB):
}
],
)


async def test_bad_connection():
dummy_db = DummyDB("mysql+aiomysql://tata:[email protected]:3306/name")
async with dummy_db.engine_context():
with pytest.raises(SQLDBConnectionError):
async with dummy_db:
await dummy_db.ping()
16 changes: 16 additions & 0 deletions tests/routers/test_generic.py
@@ -1,3 +1,6 @@
import pytest


def test_openapi(test_client):
r = test_client.get("/api/openapi.json")
assert r.status_code == 200
@@ -15,3 +18,16 @@ def test_installation_metadata(test_client):

assert r.status_code == 200
assert r.json()


@pytest.mark.xfail(reason="TODO")
def test_unavailable_db(monkeypatch, test_client):
# TODO
# That does not work because test_client is already initialized
monkeypatch.setenv(
"DIRACX_DB_URL_JOBDB", "mysql+aiomysql://tata:[email protected]:3306/name"
)

r = test_client.get("/api/job/123")
assert r.status_code == 503
assert r.json()
