diff --git a/environment.yml b/environment.yml index 8a11e9920..7834cef17 100644 --- a/environment.yml +++ b/environment.yml @@ -8,6 +8,7 @@ dependencies: - aiohttp - aiomysql - aiosqlite + - asyncache - azure-core - cachetools ######## diff --git a/pyproject.toml b/pyproject.toml index 599daabdd..5f87b1ad8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,11 @@ ignore_missing_imports = true module = 'authlib.*' ignore_missing_imports = true +# https://github.com/hephex/asyncache/pull/18 +[[tool.mypy.overrides]] +module = 'asyncache.*' + + [tool.pytest.ini_options] addopts = ["-v", "--cov=diracx", "--cov-report=term-missing"] asyncio_mode = "auto" diff --git a/setup.cfg b/setup.cfg index 0006c3161..2c286a5a0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,6 +31,7 @@ install_requires = aiohttp aiomysql aiosqlite + asyncache azure-core cachetools m2crypto >=0.38.0 diff --git a/src/diracx/core/exceptions.py b/src/diracx/core/exceptions.py index efb00f4c6..2a1d0846a 100644 --- a/src/diracx/core/exceptions.py +++ b/src/diracx/core/exceptions.py @@ -9,6 +9,7 @@ def __init__(self, status_code: int, data): class DiracError(RuntimeError): http_status_code = status.HTTP_400_BAD_REQUEST + http_headers: dict[str, str] | None = None def __init__(self, detail: str = "Unknown"): self.detail = detail @@ -42,3 +43,10 @@ class JobNotFound(Exception): def __init__(self, job_id: int): self.job_id: int = job_id super().__init__(f"Job {job_id} not found") + + +class RouteUnavailableError(DiracError): + """The route is not available (bad init)""" + + http_status_code = status.HTTP_503_SERVICE_UNAVAILABLE + http_headers = {"Retry-After": "10"} diff --git a/src/diracx/db/sql/utils.py b/src/diracx/db/sql/utils.py index 60e15183f..8beb19b35 100644 --- a/src/diracx/db/sql/utils.py +++ b/src/diracx/db/sql/utils.py @@ -13,7 +13,8 @@ from pydantic import parse_obj_as from sqlalchemy import Column as RawColumn -from sqlalchemy import DateTime, Enum, MetaData +from sqlalchemy import
DateTime, Enum, MetaData, select +from sqlalchemy.exc import OperationalError from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql import expression @@ -66,6 +67,14 @@ def EnumColumn(enum_type, **kwargs): return Column(Enum(enum_type, native_enum=False, length=16), **kwargs) +class SQLDBError(Exception): + pass + + +class SQLDBConnectionError(SQLDBError): + """Used whenever we encounter a problem with the DB connection""" + + class BaseSQLDB(metaclass=ABCMeta): """This should be the base class of all the DiracX DBs""" @@ -146,7 +155,10 @@ async def engine_context(self) -> AsyncIterator[None]: """ assert self._engine is None, "engine_context cannot be nested" - engine = create_async_engine(self._db_url) + # Set the pool_recycle to 30mn + # That should prevent the problem of MySQL expiring connection + # after 60mn by default + engine = create_async_engine(self._db_url, pool_recycle=60 * 30) self._engine = engine yield @@ -166,8 +178,12 @@ async def __aenter__(self): This is called by the Dependency mechanism (see ``db_transaction``), It will create a new connection/transaction for each route call. """ + assert self._conn.get() is None, "BaseSQLDB context cannot be nested" + try: + self._conn.set(await self.engine.connect().__aenter__()) + except Exception as e: + raise SQLDBConnectionError("Cannot connect to DB") from e - self._conn.set(await self.engine.connect().__aenter__()) return self async def __aexit__(self, exc_type, exc, tb): @@ -181,6 +197,18 @@ async def __aexit__(self, exc_type, exc, tb): await self._conn.get().__aexit__(exc_type, exc, tb) self._conn.set(None) + async def ping(self) -> tuple[bool, str]: + """ + Check whether the connection to the DB is still working. + We could enable the ``pool_pre_ping`` in the engine, but this would + be run at every query.
+ """ + try: + await self.conn.scalar(select(1)) + return True, "" + except OperationalError as e: + return False, repr(e) + def apply_search_filters(table, stmt, search): # Apply any filters diff --git a/src/diracx/routers/__init__.py b/src/diracx/routers/__init__.py index 41356259b..1a0808b44 100644 --- a/src/diracx/routers/__init__.py +++ b/src/diracx/routers/__init__.py @@ -7,6 +7,8 @@ from typing import Any, AsyncContextManager, AsyncGenerator, Iterable, TypeVar import dotenv +from asyncache import cached +from cachetools import TTLCache from fastapi import APIRouter, Depends, Request from fastapi.dependencies.models import Dependant from fastapi.middleware.cors import CORSMiddleware @@ -15,11 +17,15 @@ from pydantic import parse_raw_as from diracx.core.config import ConfigSource -from diracx.core.exceptions import DiracError, DiracHttpResponse +from diracx.core.exceptions import ( + DiracError, + DiracHttpResponse, + RouteUnavailableError, +) from diracx.core.extensions import select_from_extension from diracx.core.utils import dotenv_files_from_environment from diracx.db.os.utils import BaseOSDB -from diracx.db.sql.utils import BaseSQLDB +from diracx.db.sql.utils import BaseSQLDB, SQLDBConnectionError from ..core.settings import ServiceSettingsBase from .auth import verify_dirac_access_token @@ -27,6 +33,8 @@ T = TypeVar("T") T2 = TypeVar("T2", bound=AsyncContextManager) +T3 = TypeVar("T3", bound=BaseSQLDB) + logger = logging.getLogger(__name__) @@ -59,21 +67,33 @@ def create_app_inner( # Override the configuration source app.dependency_overrides[ConfigSource.create] = config_source.read_config + fail_startup = True # Add the SQL DBs to the application available_sql_db_classes: set[type[BaseSQLDB]] = set() for db_name, db_url in database_urls.items(): - sql_db_classes = BaseSQLDB.available_implementations(db_name) - # The first DB is the highest priority one - sql_db = sql_db_classes[0](db_url=db_url) - app.lifetime_functions.append(sql_db.engine_context) - 
# Add overrides for all the DB classes, including those from extensions - # This means vanilla DiracX routers get an instance of the extension's DB - for sql_db_class in sql_db_classes: - assert sql_db_class.transaction not in app.dependency_overrides - available_sql_db_classes.add(sql_db_class) - app.dependency_overrides[sql_db_class.transaction] = partial( - db_transaction, sql_db - ) + try: + sql_db_classes = BaseSQLDB.available_implementations(db_name) + + # The first DB is the highest priority one + sql_db = sql_db_classes[0](db_url=db_url) + + app.lifetime_functions.append(sql_db.engine_context) + # Add overrides for all the DB classes, including those from extensions + # This means vanilla DiracX routers get an instance of the extension's DB + for sql_db_class in sql_db_classes: + assert sql_db_class.transaction not in app.dependency_overrides + available_sql_db_classes.add(sql_db_class) + app.dependency_overrides[sql_db_class.transaction] = partial( + db_transaction, sql_db + ) + + # At least one DB works, so we do not fail the startup + fail_startup = False + except Exception: + logger.exception(f"Failed to initialize DB {db_name}, {db_url}") + + if fail_startup: + raise Exception("No SQL database could be initialized, aborting") # Add the OpenSearch DBs to the application available_os_db_classes: set[type[BaseOSDB]] = set() @@ -199,7 +219,9 @@ def create_app() -> DiracFastAPI: def dirac_error_handler(request: Request, exc: DiracError) -> Response: return JSONResponse( - status_code=exc.http_status_code, content={"detail": exc.detail} + status_code=exc.http_status_code, + content={"detail": exc.detail}, + headers=exc.http_headers, ) @@ -225,9 +247,34 @@ def find_dependents( yield from find_dependents(dependency.dependencies, cls) +_db_alive_cache: TTLCache = TTLCache(maxsize=1024, ttl=10) + + +@cached(_db_alive_cache) +async def is_db_alive(db: T3): + """Cache the result of pinging the DB""" + is_alive, reason = await db.ping() + logger.debug("Pinged db 
%s, is_alive %s, reason %s", type(db), is_alive, reason) + if not is_alive: + raise SQLDBConnectionError(reason) + + async def db_transaction(db: T2) -> AsyncGenerator[T2, None]: - async with db: - yield db + """ + Initiate a DB transaction. + + :raises: RouteUnavailableError in case the connection to the DB fails + """ + + # Entering the context already triggers a connection to the DB + # that may fail, hence the try/except + try: + async with db: + # Check whether the connection still works before executing the query + await is_db_alive(db) + yield db + except SQLDBConnectionError as e: + raise RouteUnavailableError(repr(e)) from e async def db_session(db: T2) -> AsyncGenerator[T2, None]: diff --git a/tests/db/test_dummyDB.py b/tests/db/test_dummyDB.py index 029971ae5..8fa262d3f 100644 --- a/tests/db/test_dummyDB.py +++ b/tests/db/test_dummyDB.py @@ -7,6 +7,7 @@ from diracx.core.exceptions import InvalidQueryError from diracx.db.sql.dummy.db import DummyDB +from diracx.db.sql.utils import SQLDBConnectionError # Each DB test class must defined a fixture looking like this one # It allows to get an instance of an in memory DB, @@ -67,3 +68,11 @@ async def test_insert_and_summary(dummy_db: DummyDB): } ], ) + + +async def test_bad_connection(): + dummy_db = DummyDB("mysql+aiomysql://tata:yoyo@db.invalid:3306/name") + async with dummy_db.engine_context(): + with pytest.raises(SQLDBConnectionError): + async with dummy_db: + dummy_db.ping() diff --git a/tests/routers/test_generic.py b/tests/routers/test_generic.py index 4fe306d9c..99db2352c 100644 --- a/tests/routers/test_generic.py +++ b/tests/routers/test_generic.py @@ -1,3 +1,6 @@ +import pytest + + def test_openapi(test_client): r = test_client.get("/api/openapi.json") assert r.status_code == 200 @@ -15,3 +18,16 @@ def test_installation_metadata(test_client): assert r.status_code == 200 assert r.json() + + +@pytest.mark.xfail(reason="TODO") +def test_unavailable_db(monkeypatch, test_client): + # TODO + # That does 
not work because test_client is already initialized + monkeypatch.setenv( + "DIRACX_DB_URL_JOBDB", "mysql+aiomysql://tata:yoyo@dbod.cern.ch:3306/name" + ) + + r = test_client.get("/api/job/123") + assert r.status_code == 503 + assert r.json()