Skip to content

Commit

Permalink
Return a ServiceUnavailable Error when a db is not reachable, and
Browse files Browse the repository at this point in the history
retry connecting
  • Loading branch information
chaen committed Nov 2, 2023
1 parent 4066558 commit cb47b37
Show file tree
Hide file tree
Showing 12 changed files with 170 additions and 44 deletions.
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ dependencies:
- aiohttp
- aiomysql
- aiosqlite
- asyncache
- azure-core
- cachetools
########
Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ ignore_missing_imports = true
module = 'authlib.*'
ignore_missing_imports = true

# https://github.com/hephex/asyncache/pull/18
[[tool.mypy.overrides]]
module = 'asyncache.*'
ignore_missing_imports = true

[tool.pytest.ini_options]
addopts = ["-v", "--cov=diracx", "--cov-report=term-missing"]
asyncio_mode = "auto"
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ install_requires =
aiohttp
aiomysql
aiosqlite
asyncache
azure-core
cachetools
m2crypto >=0.38.0
Expand Down
1 change: 1 addition & 0 deletions src/diracx/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ def __init__(self, status_code: int, data):

class DiracError(RuntimeError):
http_status_code = status.HTTP_400_BAD_REQUEST
http_headers: dict[str, str] | None = None

def __init__(self, detail: str = "Unknown"):
self.detail = detail
Expand Down
4 changes: 2 additions & 2 deletions src/diracx/db/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations

__all__ = ("sql", "os")
__all__ = ("sql", "os", "exceptions")

from . import os, sql
from . import exceptions, os, sql
2 changes: 2 additions & 0 deletions src/diracx/db/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class DBUnavailable(Exception):
pass
26 changes: 19 additions & 7 deletions src/diracx/db/os/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from diracx.core.exceptions import InvalidQueryError
from diracx.core.extensions import select_from_extension
from diracx.db.exceptions import DBUnavailable

OS_DATE_FORMAT = "%Y-%m-%dT%H:%M:%S.%f%z"

Expand All @@ -24,7 +25,7 @@ class OpenSearchDBError(Exception):
pass


class OpenSearchDBUnavailable(OpenSearchDBError):
class OpenSearchDBUnavailable(DBUnavailable, OpenSearchDBError):
pass


Expand Down Expand Up @@ -93,16 +94,27 @@ async def client_context(self) -> AsyncIterator[None]:
"""
assert self._client is None, "client_context cannot be nested"
async with AsyncOpenSearch(**self._connection_kwargs) as self._client:
if not await self._client.ping():
raise OpenSearchDBUnavailable(
f"Failed to connect to {self.__class__.__qualname__}"
)
yield
self._client = None

async def ping(self):
"""
Check whether the connection to the DB is still working.
We could enable the ``pre_ping`` in the engine, but this would
be ran at every query.
"""
if not await self.client.ping():
raise OpenSearchDBUnavailable(
f"Failed to connect to {self.__class__.__qualname__}"
)

async def __aenter__(self):
"""This is entered on every request. It does nothing"""
assert self._client is None, "client_context hasn't been entered"
"""This is entered on every request.
At the moment it does nothing, however, we keep it here
in case we ever want to use OpenSearch equivalent of a transaction
"""

assert self._client is not None, "client_context hasn't been entered"
return self

async def __aexit__(self, exc_type, exc, tb):
Expand Down
34 changes: 31 additions & 3 deletions src/diracx/db/sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@

from pydantic import parse_obj_as
from sqlalchemy import Column as RawColumn
from sqlalchemy import DateTime, Enum, MetaData
from sqlalchemy import DateTime, Enum, MetaData, select
from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import expression

from diracx.core.exceptions import InvalidQueryError
from diracx.core.extensions import select_from_extension
from diracx.core.settings import SqlalchemyDsn
from diracx.db.exceptions import DBUnavailable

if TYPE_CHECKING:
from sqlalchemy.types import TypeEngine
Expand Down Expand Up @@ -66,6 +68,14 @@ def EnumColumn(enum_type, **kwargs):
return Column(Enum(enum_type, native_enum=False, length=16), **kwargs)


class SQLDBError(Exception):
pass


class SQLDBUnavailable(DBUnavailable, SQLDBError):
"""Used whenever we encounter a problem with the B connection"""


class BaseSQLDB(metaclass=ABCMeta):
"""This should be the base class of all the DiracX DBs"""

Expand Down Expand Up @@ -146,7 +156,10 @@ async def engine_context(self) -> AsyncIterator[None]:
"""
assert self._engine is None, "engine_context cannot be nested"

engine = create_async_engine(self._db_url)
# Set the pool_recycle to 30mn
# That should prevent the problem of MySQL expiring connection
# after 60mn by default
engine = create_async_engine(self._db_url, pool_recycle=60 * 30)
self._engine = engine

yield
Expand All @@ -166,8 +179,12 @@ async def __aenter__(self):
This is called by the Dependency mechanism (see ``db_transaction``),
It will create a new connection/transaction for each route call.
"""
assert self._conn.get() is None, "BaseSQLDB context cannot be nested"
try:
self._conn.set(await self.engine.connect().__aenter__())
except Exception as e:
raise SQLDBUnavailable("Cannot connect to DB") from e

self._conn.set(await self.engine.connect().__aenter__())
return self

async def __aexit__(self, exc_type, exc, tb):
Expand All @@ -181,6 +198,17 @@ async def __aexit__(self, exc_type, exc, tb):
await self._conn.get().__aexit__(exc_type, exc, tb)
self._conn.set(None)

async def ping(self):
"""
Check whether the connection to the DB is still working.
We could enable the ``pre_ping`` in the engine, but this would
be ran at every query.
"""
try:
await self.conn.scalar(select(1))
except OperationalError as e:
raise SQLDBUnavailable("Cannot ping the DB") from e

Check warning on line 210 in src/diracx/db/sql/utils.py

View check run for this annotation

Codecov / codecov/patch

src/diracx/db/sql/utils.py#L209-L210

Added lines #L209 - L210 were not covered by tests


def apply_search_filters(table, stmt, search):
# Apply any filters
Expand Down
98 changes: 76 additions & 22 deletions src/diracx/routers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,29 @@
import inspect
import logging
import os
from collections.abc import AsyncGenerator
from contextlib import AbstractAsyncContextManager
from functools import partial
from typing import Any, AsyncContextManager, AsyncGenerator, Iterable, TypeVar
from typing import Any, Iterable, TypeVar

import dotenv
from fastapi import APIRouter, Depends, Request
from asyncache import cached
from cachetools import TTLCache
from fastapi import APIRouter, Depends, Request, status
from fastapi.dependencies.models import Dependant
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response
from fastapi.routing import APIRoute
from pydantic import parse_raw_as

from diracx.core.config import ConfigSource
from diracx.core.exceptions import DiracError, DiracHttpResponse
from diracx.core.exceptions import (
DiracError,
DiracHttpResponse,
)
from diracx.core.extensions import select_from_extension
from diracx.core.utils import dotenv_files_from_environment
from diracx.db.exceptions import DBUnavailable
from diracx.db.os.utils import BaseOSDB
from diracx.db.sql.utils import BaseSQLDB

Expand All @@ -26,7 +34,8 @@
from .fastapi_classes import DiracFastAPI, DiracxRouter

T = TypeVar("T")
T2 = TypeVar("T2", bound=AsyncContextManager)
T2 = TypeVar("T2", bound=AbstractAsyncContextManager[BaseSQLDB | BaseOSDB])


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -59,21 +68,33 @@ def create_app_inner(
# Override the configuration source
app.dependency_overrides[ConfigSource.create] = config_source.read_config

fail_startup = True
# Add the SQL DBs to the application
available_sql_db_classes: set[type[BaseSQLDB]] = set()
for db_name, db_url in database_urls.items():
sql_db_classes = BaseSQLDB.available_implementations(db_name)
# The first DB is the highest priority one
sql_db = sql_db_classes[0](db_url=db_url)
app.lifetime_functions.append(sql_db.engine_context)
# Add overrides for all the DB classes, including those from extensions
# This means vanilla DiracX routers get an instance of the extension's DB
for sql_db_class in sql_db_classes:
assert sql_db_class.transaction not in app.dependency_overrides
available_sql_db_classes.add(sql_db_class)
app.dependency_overrides[sql_db_class.transaction] = partial(
db_transaction, sql_db
)
try:
sql_db_classes = BaseSQLDB.available_implementations(db_name)

# The first DB is the highest priority one
sql_db = sql_db_classes[0](db_url=db_url)

app.lifetime_functions.append(sql_db.engine_context)
# Add overrides for all the DB classes, including those from extensions
# This means vanilla DiracX routers get an instance of the extension's DB
for sql_db_class in sql_db_classes:
assert sql_db_class.transaction not in app.dependency_overrides
available_sql_db_classes.add(sql_db_class)
app.dependency_overrides[sql_db_class.transaction] = partial(
db_transaction, sql_db
)

# At least one DB works, so we do not fail the startup
fail_startup = False
except Exception:
logger.exception("Failed to initialize DB %s", db_name)

Check warning on line 94 in src/diracx/routers/__init__.py

View check run for this annotation

Codecov / codecov/patch

src/diracx/routers/__init__.py#L93-L94

Added lines #L93 - L94 were not covered by tests

if fail_startup:
raise Exception("No SQL database could be initialized, aborting")

Check warning on line 97 in src/diracx/routers/__init__.py

View check run for this annotation

Codecov / codecov/patch

src/diracx/routers/__init__.py#L97

Added line #L97 was not covered by tests

# Add the OpenSearch DBs to the application
available_os_db_classes: set[type[BaseOSDB]] = set()
Expand All @@ -87,7 +108,9 @@ def create_app_inner(
for os_db_class in os_db_classes:
assert os_db_class.session not in app.dependency_overrides
available_os_db_classes.add(os_db_class)
app.dependency_overrides[os_db_class.session] = partial(db_session, os_db)
app.dependency_overrides[os_db_class.session] = partial(
db_transaction, os_db
)

# Load the requested routers
routers: dict[str, APIRouter] = {}
Expand Down Expand Up @@ -145,6 +168,7 @@ def create_app_inner(
# Add exception handlers
app.add_exception_handler(DiracError, dirac_error_handler)
app.add_exception_handler(DiracHttpResponse, http_response_handler)
app.add_exception_handler(DBUnavailable, route_unavailable_error_hander)

# TODO: remove the CORSMiddleware once we figure out how to launch
# diracx and diracx-web under the same origin
Expand Down Expand Up @@ -199,14 +223,24 @@ def create_app() -> DiracFastAPI:

def dirac_error_handler(request: Request, exc: DiracError) -> Response:
return JSONResponse(
status_code=exc.http_status_code, content={"detail": exc.detail}
status_code=exc.http_status_code,
content={"detail": exc.detail},
headers=exc.http_headers,
)


def http_response_handler(request: Request, exc: DiracHttpResponse) -> Response:
return JSONResponse(status_code=exc.status_code, content=exc.data)


def route_unavailable_error_hander(request: Request, exc: DBUnavailable):
return JSONResponse(

Check warning on line 237 in src/diracx/routers/__init__.py

View check run for this annotation

Codecov / codecov/patch

src/diracx/routers/__init__.py#L237

Added line #L237 was not covered by tests
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
headers={"Retry-After": "10"},
content={"detail": str(exc.args)},
)


def find_dependents(
obj: APIRouter | Iterable[Dependant], cls: type[T]
) -> Iterable[type[T]]:
Expand All @@ -225,10 +259,30 @@ def find_dependents(
yield from find_dependents(dependency.dependencies, cls)


_db_alive_cache: TTLCache = TTLCache(maxsize=1024, ttl=10)


@cached(_db_alive_cache)
async def is_db_unavailable(db: BaseSQLDB) -> str:
"""Cache the result of pinging the DB
(exceptions are not cachable)
"""
try:
await db.ping()
return ""
except DBUnavailable as e:
return e.args[0]

Check warning on line 274 in src/diracx/routers/__init__.py

View check run for this annotation

Codecov / codecov/patch

src/diracx/routers/__init__.py#L272-L274

Added lines #L272 - L274 were not covered by tests


async def db_transaction(db: T2) -> AsyncGenerator[T2, None]:
"""
Initiate a DB transaction.
"""

# Entering the context already triggers a connection to the DB
# that may fail
async with db:
# Check whether the connection still works before executing the query
if reason := await is_db_unavailable(db):
raise DBUnavailable(reason)

Check warning on line 287 in src/diracx/routers/__init__.py

View check run for this annotation

Codecov / codecov/patch

src/diracx/routers/__init__.py#L287

Added line #L287 was not covered by tests
yield db


async def db_session(db: T2) -> AsyncGenerator[T2, None]:
yield db
17 changes: 7 additions & 10 deletions tests/db/opensearch/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,10 @@

async def _ensure_db_unavailable(db: DummyOSDB):
"""Helper function which raises an exception if we manage to connect to the DB."""
# Normally we would use "async with db.client_context()" but here
# __aenter__ is used explicitly to ensure the exception is raised
# while entering the context manager
acm = db.client_context()
with pytest.raises(OpenSearchDBUnavailable):
await acm.__aenter__()
async with db.client_context():
async with db:
with pytest.raises(OpenSearchDBUnavailable):
await db.ping()


async def test_connection(dummy_opensearch_db: DummyOSDB):
Expand Down Expand Up @@ -85,10 +83,9 @@ async def test_sanity_checks(opensearch_conn_kwargs):
db = DummyOSDB(opensearch_conn_kwargs)
# Check that the client is not available before entering the context manager
with pytest.raises(RuntimeError):
await db.client.ping()
await db.ping()

# It shouldn't be possible to enter the context manager twice
async with db.client_context():
assert await db.client.ping()
with pytest.raises(AssertionError):
await db.__aenter__()
async with db:
await db.ping()
9 changes: 9 additions & 0 deletions tests/db/test_dummyDB.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from diracx.core.exceptions import InvalidQueryError
from diracx.db.sql.dummy.db import DummyDB
from diracx.db.sql.utils import SQLDBUnavailable

# Each DB test class must defined a fixture looking like this one
# It allows to get an instance of an in memory DB,
Expand Down Expand Up @@ -67,3 +68,11 @@ async def test_insert_and_summary(dummy_db: DummyDB):
}
],
)


async def test_bad_connection():
dummy_db = DummyDB("mysql+aiomysql://tata:[email protected]:3306/name")
async with dummy_db.engine_context():
with pytest.raises(SQLDBUnavailable):
async with dummy_db:
dummy_db.ping()
Loading

0 comments on commit cb47b37

Please sign in to comment.