From cb47b379dc1bad37ad25e41eb61c66fee95d2e52 Mon Sep 17 00:00:00 2001 From: Christophe Haen Date: Fri, 27 Oct 2023 11:21:21 +0200 Subject: [PATCH] Return a ServiceUnavailable Error when a db is not reachable, and retry connecting --- environment.yml | 1 + pyproject.toml | 5 ++ setup.cfg | 1 + src/diracx/core/exceptions.py | 1 + src/diracx/db/__init__.py | 4 +- src/diracx/db/exceptions.py | 2 + src/diracx/db/os/utils.py | 26 +++++-- src/diracx/db/sql/utils.py | 34 ++++++++- src/diracx/routers/__init__.py | 98 ++++++++++++++++++++------ tests/db/opensearch/test_connection.py | 17 ++--- tests/db/test_dummyDB.py | 9 +++ tests/routers/test_generic.py | 16 +++++ 12 files changed, 170 insertions(+), 44 deletions(-) create mode 100644 src/diracx/db/exceptions.py diff --git a/environment.yml b/environment.yml index 8a11e9920..7834cef17 100644 --- a/environment.yml +++ b/environment.yml @@ -8,6 +8,7 @@ dependencies: - aiohttp - aiomysql - aiosqlite + - asyncache - azure-core - cachetools ######## diff --git a/pyproject.toml b/pyproject.toml index 599daabdd..d792de200 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,11 @@ ignore_missing_imports = true module = 'authlib.*' ignore_missing_imports = true +# https://github.com/hephex/asyncache/pull/18 +[[tool.mypy.overrides]] +module = 'asyncache.*' +ignore_missing_imports = true + [tool.pytest.ini_options] addopts = ["-v", "--cov=diracx", "--cov-report=term-missing"] asyncio_mode = "auto" diff --git a/setup.cfg b/setup.cfg index 0006c3161..2c286a5a0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,6 +31,7 @@ install_requires = aiohttp aiomysql aiosqlite + asyncache azure-core cachetools m2crypto >=0.38.0 diff --git a/src/diracx/core/exceptions.py b/src/diracx/core/exceptions.py index efb00f4c6..4eef53ece 100644 --- a/src/diracx/core/exceptions.py +++ b/src/diracx/core/exceptions.py @@ -9,6 +9,7 @@ def __init__(self, status_code: int, data): class DiracError(RuntimeError): http_status_code = status.HTTP_400_BAD_REQUEST 
+ http_headers: dict[str, str] | None = None def __init__(self, detail: str = "Unknown"): self.detail = detail diff --git a/src/diracx/db/__init__.py b/src/diracx/db/__init__.py index 6390e3fb5..824eb5624 100644 --- a/src/diracx/db/__init__.py +++ b/src/diracx/db/__init__.py @@ -1,5 +1,5 @@ from __future__ import annotations -__all__ = ("sql", "os") +__all__ = ("sql", "os", "exceptions") -from . import os, sql +from . import exceptions, os, sql diff --git a/src/diracx/db/exceptions.py b/src/diracx/db/exceptions.py new file mode 100644 index 000000000..ca0cf0ecc --- /dev/null +++ b/src/diracx/db/exceptions.py @@ -0,0 +1,2 @@ +class DBUnavailable(Exception): + pass diff --git a/src/diracx/db/os/utils.py b/src/diracx/db/os/utils.py index 2eb842bc0..b5feed959 100644 --- a/src/diracx/db/os/utils.py +++ b/src/diracx/db/os/utils.py @@ -14,6 +14,7 @@ from diracx.core.exceptions import InvalidQueryError from diracx.core.extensions import select_from_extension +from diracx.db.exceptions import DBUnavailable OS_DATE_FORMAT = "%Y-%m-%dT%H:%M:%S.%f%z" @@ -24,7 +25,7 @@ class OpenSearchDBError(Exception): pass -class OpenSearchDBUnavailable(OpenSearchDBError): +class OpenSearchDBUnavailable(DBUnavailable, OpenSearchDBError): pass @@ -93,16 +94,27 @@ async def client_context(self) -> AsyncIterator[None]: """ assert self._client is None, "client_context cannot be nested" async with AsyncOpenSearch(**self._connection_kwargs) as self._client: - if not await self._client.ping(): - raise OpenSearchDBUnavailable( - f"Failed to connect to {self.__class__.__qualname__}" - ) yield self._client = None + async def ping(self): + """ + Check whether the connection to the DB is still working. + We could enable the ``pre_ping`` in the engine, but this would + be run at every query. + """ + if not await self.client.ping(): + raise OpenSearchDBUnavailable( + f"Failed to connect to {self.__class__.__qualname__}" + ) + async def __aenter__(self): - """This is entered on every request.
It does nothing""" - assert self._client is None, "client_context hasn't been entered" + """This is entered on every request. + At the moment it does nothing, however, we keep it here + in case we ever want to use the OpenSearch equivalent of a transaction + """ + + assert self._client is not None, "client_context hasn't been entered" return self async def __aexit__(self, exc_type, exc, tb): diff --git a/src/diracx/db/sql/utils.py b/src/diracx/db/sql/utils.py index 60e15183f..d729436fc 100644 --- a/src/diracx/db/sql/utils.py +++ b/src/diracx/db/sql/utils.py @@ -13,7 +13,8 @@ from pydantic import parse_obj_as from sqlalchemy import Column as RawColumn -from sqlalchemy import DateTime, Enum, MetaData +from sqlalchemy import DateTime, Enum, MetaData, select +from sqlalchemy.exc import OperationalError from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql import expression @@ -21,6 +22,7 @@ from diracx.core.exceptions import InvalidQueryError from diracx.core.extensions import select_from_extension from diracx.core.settings import SqlalchemyDsn +from diracx.db.exceptions import DBUnavailable if TYPE_CHECKING: from sqlalchemy.types import TypeEngine @@ -66,6 +68,14 @@ def EnumColumn(enum_type, **kwargs): return Column(Enum(enum_type, native_enum=False, length=16), **kwargs) +class SQLDBError(Exception): + pass + + +class SQLDBUnavailable(DBUnavailable, SQLDBError): + """Used whenever we encounter a problem with the DB connection""" + + class BaseSQLDB(metaclass=ABCMeta): """This should be the base class of all the DiracX DBs""" @@ -146,7 +156,10 @@ async def engine_context(self) -> AsyncIterator[None]: """ assert self._engine is None, "engine_context cannot be nested" - engine = create_async_engine(self._db_url) + # Set the pool_recycle to 30 minutes + # That should prevent the problem of MySQL expiring connections + # after 60 minutes by default + engine = create_async_engine(self._db_url,
pool_recycle=60 * 30) self._engine = engine yield @@ -166,8 +179,12 @@ async def __aenter__(self): This is called by the Dependency mechanism (see ``db_transaction``), It will create a new connection/transaction for each route call. """ + assert self._conn.get() is None, "BaseSQLDB context cannot be nested" + try: + self._conn.set(await self.engine.connect().__aenter__()) + except Exception as e: + raise SQLDBUnavailable("Cannot connect to DB") from e - self._conn.set(await self.engine.connect().__aenter__()) return self async def __aexit__(self, exc_type, exc, tb): @@ -181,6 +198,17 @@ async def __aexit__(self, exc_type, exc, tb): await self._conn.get().__aexit__(exc_type, exc, tb) self._conn.set(None) + async def ping(self): + """ + Check whether the connection to the DB is still working. + We could enable the ``pool_pre_ping`` in the engine, but this would + be run at every query. + """ + try: + await self.conn.scalar(select(1)) + except OperationalError as e: + raise SQLDBUnavailable("Cannot ping the DB") from e + def apply_search_filters(table, stmt, search): # Apply any filters diff --git a/src/diracx/routers/__init__.py b/src/diracx/routers/__init__.py index 41356259b..0ef13ed5d 100644 --- a/src/diracx/routers/__init__.py +++ b/src/diracx/routers/__init__.py @@ -3,11 +3,15 @@ import inspect import logging import os +from collections.abc import AsyncGenerator +from contextlib import AbstractAsyncContextManager from functools import partial -from typing import Any, AsyncContextManager, AsyncGenerator, Iterable, TypeVar +from typing import Any, Iterable, TypeVar import dotenv -from fastapi import APIRouter, Depends, Request +from asyncache import cached +from cachetools import TTLCache +from fastapi import APIRouter, Depends, Request, status from fastapi.dependencies.models import Dependant from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response @@ -15,9 +19,13 @@ from pydantic import parse_raw_as from
diracx.core.config import ConfigSource -from diracx.core.exceptions import DiracError, DiracHttpResponse +from diracx.core.exceptions import ( + DiracError, + DiracHttpResponse, +) from diracx.core.extensions import select_from_extension from diracx.core.utils import dotenv_files_from_environment +from diracx.db.exceptions import DBUnavailable from diracx.db.os.utils import BaseOSDB from diracx.db.sql.utils import BaseSQLDB @@ -26,7 +34,8 @@ from .fastapi_classes import DiracFastAPI, DiracxRouter T = TypeVar("T") -T2 = TypeVar("T2", bound=AsyncContextManager) +T2 = TypeVar("T2", bound=AbstractAsyncContextManager[BaseSQLDB | BaseOSDB]) + logger = logging.getLogger(__name__) @@ -59,21 +68,33 @@ def create_app_inner( # Override the configuration source app.dependency_overrides[ConfigSource.create] = config_source.read_config + fail_startup = True # Add the SQL DBs to the application available_sql_db_classes: set[type[BaseSQLDB]] = set() for db_name, db_url in database_urls.items(): - sql_db_classes = BaseSQLDB.available_implementations(db_name) - # The first DB is the highest priority one - sql_db = sql_db_classes[0](db_url=db_url) - app.lifetime_functions.append(sql_db.engine_context) - # Add overrides for all the DB classes, including those from extensions - # This means vanilla DiracX routers get an instance of the extension's DB - for sql_db_class in sql_db_classes: - assert sql_db_class.transaction not in app.dependency_overrides - available_sql_db_classes.add(sql_db_class) - app.dependency_overrides[sql_db_class.transaction] = partial( - db_transaction, sql_db - ) + try: + sql_db_classes = BaseSQLDB.available_implementations(db_name) + + # The first DB is the highest priority one + sql_db = sql_db_classes[0](db_url=db_url) + + app.lifetime_functions.append(sql_db.engine_context) + # Add overrides for all the DB classes, including those from extensions + # This means vanilla DiracX routers get an instance of the extension's DB + for sql_db_class in 
sql_db_classes: + assert sql_db_class.transaction not in app.dependency_overrides + available_sql_db_classes.add(sql_db_class) + app.dependency_overrides[sql_db_class.transaction] = partial( + db_transaction, sql_db + ) + + # At least one DB works, so we do not fail the startup + fail_startup = False + except Exception: + logger.exception("Failed to initialize DB %s", db_name) + + if fail_startup: + raise Exception("No SQL database could be initialized, aborting") # Add the OpenSearch DBs to the application available_os_db_classes: set[type[BaseOSDB]] = set() @@ -87,7 +108,9 @@ def create_app_inner( for os_db_class in os_db_classes: assert os_db_class.session not in app.dependency_overrides available_os_db_classes.add(os_db_class) - app.dependency_overrides[os_db_class.session] = partial(db_session, os_db) + app.dependency_overrides[os_db_class.session] = partial( + db_transaction, os_db + ) # Load the requested routers routers: dict[str, APIRouter] = {} @@ -145,6 +168,7 @@ def create_app_inner( # Add exception handlers app.add_exception_handler(DiracError, dirac_error_handler) app.add_exception_handler(DiracHttpResponse, http_response_handler) + app.add_exception_handler(DBUnavailable, route_unavailable_error_hander) # TODO: remove the CORSMiddleware once we figure out how to launch # diracx and diracx-web under the same origin @@ -199,7 +223,9 @@ def create_app() -> DiracFastAPI: def dirac_error_handler(request: Request, exc: DiracError) -> Response: return JSONResponse( - status_code=exc.http_status_code, content={"detail": exc.detail} + status_code=exc.http_status_code, + content={"detail": exc.detail}, + headers=exc.http_headers, ) @@ -207,6 +233,14 @@ def http_response_handler(request: Request, exc: DiracHttpResponse) -> Response: return JSONResponse(status_code=exc.status_code, content=exc.data) +def route_unavailable_error_hander(request: Request, exc: DBUnavailable): + return JSONResponse( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + 
headers={"Retry-After": "10"}, + content={"detail": str(exc.args)}, + ) + + def find_dependents( obj: APIRouter | Iterable[Dependant], cls: type[T] ) -> Iterable[type[T]]: @@ -225,10 +259,30 @@ def find_dependents( yield from find_dependents(dependency.dependencies, cls) +_db_alive_cache: TTLCache = TTLCache(maxsize=1024, ttl=10) + + +@cached(_db_alive_cache) +async def is_db_unavailable(db: BaseSQLDB) -> str: + """Cache the result of pinging the DB + (exceptions are not cachable) + """ + try: + await db.ping() + return "" + except DBUnavailable as e: + return e.args[0] + + async def db_transaction(db: T2) -> AsyncGenerator[T2, None]: + """ + Initiate a DB transaction. + """ + + # Entering the context already triggers a connection to the DB + # that may fail async with db: + # Check whether the connection still works before executing the query + if reason := await is_db_unavailable(db): + raise DBUnavailable(reason) yield db - - -async def db_session(db: T2) -> AsyncGenerator[T2, None]: - yield db diff --git a/tests/db/opensearch/test_connection.py b/tests/db/opensearch/test_connection.py index 728faec5c..953cb6d19 100644 --- a/tests/db/opensearch/test_connection.py +++ b/tests/db/opensearch/test_connection.py @@ -9,12 +9,10 @@ async def _ensure_db_unavailable(db: DummyOSDB): """Helper function which raises an exception if we manage to connect to the DB.""" - # Normally we would use "async with db.client_context()" but here - # __aenter__ is used explicitly to ensure the exception is raised - # while entering the context manager - acm = db.client_context() - with pytest.raises(OpenSearchDBUnavailable): - await acm.__aenter__() + async with db.client_context(): + async with db: + with pytest.raises(OpenSearchDBUnavailable): + await db.ping() async def test_connection(dummy_opensearch_db: DummyOSDB): @@ -85,10 +83,9 @@ async def test_sanity_checks(opensearch_conn_kwargs): db = DummyOSDB(opensearch_conn_kwargs) # Check that the client is not available before entering 
the context manager with pytest.raises(RuntimeError): - await db.client.ping() + await db.ping() # It shouldn't be possible to enter the context manager twice async with db.client_context(): - assert await db.client.ping() - with pytest.raises(AssertionError): - await db.__aenter__() + async with db: + await db.ping() diff --git a/tests/db/test_dummyDB.py b/tests/db/test_dummyDB.py index 029971ae5..767b2593d 100644 --- a/tests/db/test_dummyDB.py +++ b/tests/db/test_dummyDB.py @@ -7,6 +7,7 @@ from diracx.core.exceptions import InvalidQueryError from diracx.db.sql.dummy.db import DummyDB +from diracx.db.sql.utils import SQLDBUnavailable # Each DB test class must defined a fixture looking like this one # It allows to get an instance of an in memory DB, @@ -67,3 +68,11 @@ async def test_insert_and_summary(dummy_db: DummyDB): } ], ) + + +async def test_bad_connection(): + dummy_db = DummyDB("mysql+aiomysql://tata:yoyo@db.invalid:3306/name") + async with dummy_db.engine_context(): + with pytest.raises(SQLDBUnavailable): + async with dummy_db: + dummy_db.ping() diff --git a/tests/routers/test_generic.py b/tests/routers/test_generic.py index 4fe306d9c..99db2352c 100644 --- a/tests/routers/test_generic.py +++ b/tests/routers/test_generic.py @@ -1,3 +1,6 @@ +import pytest + + def test_openapi(test_client): r = test_client.get("/api/openapi.json") assert r.status_code == 200 @@ -15,3 +18,16 @@ def test_installation_metadata(test_client): assert r.status_code == 200 assert r.json() + + +@pytest.mark.xfail(reason="TODO") +def test_unavailable_db(monkeypatch, test_client): + # TODO + # That does not work because test_client is already initialized + monkeypatch.setenv( + "DIRACX_DB_URL_JOBDB", "mysql+aiomysql://tata:yoyo@dbod.cern.ch:3306/name" + ) + + r = test_client.get("/api/job/123") + assert r.status_code == 503 + assert r.json()