diff --git a/src/safir/database/__init__.py b/src/safir/database/__init__.py new file mode 100644 index 00000000..8f7e5b95 --- /dev/null +++ b/src/safir/database/__init__.py @@ -0,0 +1,16 @@ +"""Utility functions for database management.""" + +from __future__ import annotations + +from ._connection import create_async_session, create_database_engine +from ._datetime import datetime_from_db, datetime_to_db +from ._initialize import DatabaseInitializationError, initialize_database + +__all__ = [ + "DatabaseInitializationError", + "create_async_session", + "create_database_engine", + "datetime_from_db", + "datetime_to_db", + "initialize_database", +] diff --git a/src/safir/database.py b/src/safir/database/_connection.py similarity index 55% rename from src/safir/database.py rename to src/safir/database/_connection.py index d5aade53..dd66e73a 100644 --- a/src/safir/database.py +++ b/src/safir/database/_connection.py @@ -1,10 +1,8 @@ -"""Utility functions for database management.""" +"""Managing database engines and sessions.""" from __future__ import annotations import asyncio -from datetime import UTC, datetime, timedelta -from typing import overload from urllib.parse import quote, urlparse from pydantic import SecretStr @@ -15,28 +13,21 @@ async_sessionmaker, create_async_engine, ) -from sqlalchemy.schema import CreateSchema from sqlalchemy.sql.expression import Select -from sqlalchemy.sql.schema import MetaData from structlog.stdlib import BoundLogger __all__ = [ - "DatabaseInitializationError", "create_async_session", "create_database_engine", - "datetime_from_db", - "datetime_to_db", - "initialize_database", ] -class DatabaseInitializationError(Exception): - """Database initialization failed.""" - - def _build_database_url(url: str, password: str | SecretStr | None) -> str: """Build the authenticated URL for the database. + The database scheme is forced to ``postgresql+asyncpg`` if it is + ``postgresql``. + Parameters ---------- url @@ -47,12 +38,13 @@ def _build_database_url(url: str, password: str | SecretStr | None) -> str: Returns ------- url - The URL including the password. + URL including the password. Raises ------ ValueError - A password was provided but the connection URL has no username. + Raised if a password was provided but the connection URL has no + username. """ parsed_url = urlparse(url) if parsed_url.scheme == "postgresql": @@ -74,66 +66,6 @@ def _build_database_url(url: str, password: str | SecretStr | None) -> str: return parsed_url.geturl() -@overload -def datetime_from_db(time: datetime) -> datetime: ... - - -@overload -def datetime_from_db(time: None) -> None: ... - - -def datetime_from_db(time: datetime | None) -> datetime | None: - """Add the UTC time zone to a naive datetime from the database. - - Parameters - ---------- - time - The naive datetime from the database, or `None`. - - Returns - ------- - datetime.datetime or None - `None` if the input was none, otherwise a timezone-aware version of - the same `~datetime.datetime` in the UTC timezone. - """ - if not time: - return None - if time.tzinfo not in (None, UTC): - raise ValueError(f"datetime {time} not in UTC") - return time.replace(tzinfo=UTC) - - -@overload -def datetime_to_db(time: datetime) -> datetime: ... - - -@overload -def datetime_to_db(time: None) -> None: ... - - -def datetime_to_db(time: datetime | None) -> datetime | None: - """Strip time zone for storing a datetime in the database. - - Parameters - ---------- - time - The timezone-aware `~datetime.datetime` in the UTC time zone, or - `None`. - - Returns - ------- - datetime.datetime or None - `None` if the input was `None`, otherwise the same - `~datetime.datetime` but timezone-naive and thus suitable for storing - in a SQL database. - """ - if not time: - return None - if time.utcoffset() != timedelta(seconds=0): - raise ValueError(f"datetime {time} not in UTC") - return time.replace(tzinfo=None) - - def create_database_engine( url: str, password: str | SecretStr | None, @@ -231,60 +163,3 @@ async def create_async_session( async with session.begin(): await session.execute(statement.limit(1)) return session - - -async def initialize_database( - engine: AsyncEngine, - logger: BoundLogger, - *, - schema: MetaData, - reset: bool = False, -) -> None: - """Create and initialize a new database. - - Parameters - ---------- - engine - Database engine to use. Create with `create_database_engine`. - logger - Logger used to report problems - schema - Metadata for the database schema. Generally this will be - ``Base.metadata`` where ``Base`` is the declarative base used as the - base class for all ORM table definitions. The caller must ensure that - all table definitions have been imported by Python before calling this - function, or parts of the schema will be missing. - reset - If set to `True`, drop all tables and reprovision the database. - Useful when running tests with an external database. - - Raises - ------ - DatabaseInitializationError - After five attempts, the database still could not be initialized. - This is normally due to some connectivity issue to the database. - """ - success = False - error = None - for _ in range(5): - try: - async with engine.begin() as conn: - if schema.schema is not None: - await conn.execute(CreateSchema(schema.schema, True)) - if reset: - await conn.run_sync(schema.drop_all) - await conn.run_sync(schema.create_all) - success = True - except (ConnectionRefusedError, OperationalError, OSError) as e: - logger.info("database not ready, waiting two seconds") - error = str(e) - await asyncio.sleep(2) - continue - if success: - logger.info("initialized database schema") - break - if not success: - msg = "database schema initialization failed (database not reachable?)" - logger.error(msg) - await engine.dispose() - raise DatabaseInitializationError(error) diff --git a/src/safir/database/_datetime.py b/src/safir/database/_datetime.py new file mode 100644 index 00000000..19ec9594 --- /dev/null +++ b/src/safir/database/_datetime.py @@ -0,0 +1,71 @@ +"""datetime management for databases.""" + +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from typing import overload + +__all__ = [ + "datetime_from_db", + "datetime_to_db", +] + + +@overload +def datetime_from_db(time: datetime) -> datetime: ... + + +@overload +def datetime_from_db(time: None) -> None: ... + + +def datetime_from_db(time: datetime | None) -> datetime | None: + """Add the UTC time zone to a naive datetime from the database. + + Parameters + ---------- + time + The naive datetime from the database, or `None`. + + Returns + ------- + datetime.datetime or None + `None` if the input was none, otherwise a timezone-aware version of + the same `~datetime.datetime` in the UTC timezone. + """ + if not time: + return None + if time.tzinfo not in (None, UTC): + raise ValueError(f"datetime {time} not in UTC") + return time.replace(tzinfo=UTC) + + +@overload +def datetime_to_db(time: datetime) -> datetime: ... + + +@overload +def datetime_to_db(time: None) -> None: ... + + +def datetime_to_db(time: datetime | None) -> datetime | None: + """Strip time zone for storing a datetime in the database. + + Parameters + ---------- + time + The timezone-aware `~datetime.datetime` in the UTC time zone, or + `None`. + + Returns + ------- + datetime.datetime or None + `None` if the input was `None`, otherwise the same + `~datetime.datetime` but timezone-naive and thus suitable for storing + in a SQL database. + """ + if not time: + return None + if time.utcoffset() != timedelta(seconds=0): + raise ValueError(f"datetime {time} not in UTC") + return time.replace(tzinfo=None) diff --git a/src/safir/database/_initialize.py b/src/safir/database/_initialize.py new file mode 100644 index 00000000..8a022a51 --- /dev/null +++ b/src/safir/database/_initialize.py @@ -0,0 +1,77 @@ +"""Database initialization.""" + +from __future__ import annotations + +import asyncio + +from sqlalchemy.exc import OperationalError +from sqlalchemy.ext.asyncio import AsyncEngine +from sqlalchemy.schema import CreateSchema +from sqlalchemy.sql.schema import MetaData +from structlog.stdlib import BoundLogger + +__all__ = [ + "DatabaseInitializationError", + "initialize_database", +] + + +class DatabaseInitializationError(Exception): + """Database initialization failed.""" + + +async def initialize_database( + engine: AsyncEngine, + logger: BoundLogger, + *, + schema: MetaData, + reset: bool = False, +) -> None: + """Create and initialize a new database. + + Parameters + ---------- + engine + Database engine to use. Create with `create_database_engine`. + logger + Logger used to report problems + schema + Metadata for the database schema. Generally this will be + ``Base.metadata`` where ``Base`` is the declarative base used as the + base class for all ORM table definitions. The caller must ensure that + all table definitions have been imported by Python before calling this + function, or parts of the schema will be missing. + reset + If set to `True`, drop all tables and reprovision the database. + Useful when running tests with an external database. + + Raises + ------ + DatabaseInitializationError + After five attempts, the database still could not be initialized. + This is normally due to some connectivity issue to the database. + """ + success = False + error = None + for _ in range(5): + try: + async with engine.begin() as conn: + if schema.schema is not None: + await conn.execute(CreateSchema(schema.schema, True)) + if reset: + await conn.run_sync(schema.drop_all) + await conn.run_sync(schema.create_all) + success = True + except (ConnectionRefusedError, OperationalError, OSError) as e: + logger.info("database not ready, waiting two seconds") + error = str(e) + await asyncio.sleep(2) + continue + if success: + logger.info("initialized database schema") + break + if not success: + msg = "database schema initialization failed (database not reachable?)" + logger.error(msg) + await engine.dispose() + raise DatabaseInitializationError(error) diff --git a/tests/database_test.py b/tests/database_test.py index 7fec9fbd..033ae0f6 100644 --- a/tests/database_test.py +++ b/tests/database_test.py @@ -15,13 +15,13 @@ from sqlalchemy.orm import declarative_base from safir.database import ( - _build_database_url, create_async_session, create_database_engine, datetime_from_db, datetime_to_db, initialize_database, ) +from safir.database._connection import _build_database_url TEST_DATABASE_PASSWORD = os.environ["TEST_DATABASE_PASSWORD"]