Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-45281: Break apart safir.database for ease of maintenance #271

Merged
merged 1 commit into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/safir/database/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""Utility functions for database management."""

from __future__ import annotations

from ._connection import create_async_session, create_database_engine
from ._datetime import datetime_from_db, datetime_to_db
from ._initialize import DatabaseInitializationError, initialize_database

__all__ = [
"DatabaseInitializationError",
"create_async_session",
"create_database_engine",
"datetime_from_db",
"datetime_to_db",
"initialize_database",
]
139 changes: 7 additions & 132 deletions src/safir/database.py → src/safir/database/_connection.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
"""Utility functions for database management."""
"""Managing database engines and sessions."""

from __future__ import annotations

import asyncio
from datetime import UTC, datetime, timedelta
from typing import overload
from urllib.parse import quote, urlparse

from pydantic import SecretStr
Expand All @@ -15,28 +13,21 @@
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.schema import CreateSchema
from sqlalchemy.sql.expression import Select
from sqlalchemy.sql.schema import MetaData
from structlog.stdlib import BoundLogger

__all__ = [
"DatabaseInitializationError",
"create_async_session",
"create_database_engine",
"datetime_from_db",
"datetime_to_db",
"initialize_database",
]


class DatabaseInitializationError(Exception):
"""Database initialization failed."""


def _build_database_url(url: str, password: str | SecretStr | None) -> str:
"""Build the authenticated URL for the database.

The database scheme is forced to ``postgresql+asyncpg`` if it is
``postgresql``.

Parameters
----------
url
Expand All @@ -47,12 +38,13 @@ def _build_database_url(url: str, password: str | SecretStr | None) -> str:
Returns
-------
url
The URL including the password.
URL including the password.

Raises
------
ValueError
A password was provided but the connection URL has no username.
Raised if a password was provided but the connection URL has no
username.
"""
parsed_url = urlparse(url)
if parsed_url.scheme == "postgresql":
Expand All @@ -74,66 +66,6 @@ def _build_database_url(url: str, password: str | SecretStr | None) -> str:
return parsed_url.geturl()


@overload
def datetime_from_db(time: datetime) -> datetime: ...


@overload
def datetime_from_db(time: None) -> None: ...


def datetime_from_db(time: datetime | None) -> datetime | None:
"""Add the UTC time zone to a naive datetime from the database.

Parameters
----------
time
The naive datetime from the database, or `None`.

Returns
-------
datetime.datetime or None
`None` if the input was none, otherwise a timezone-aware version of
the same `~datetime.datetime` in the UTC timezone.
"""
if not time:
return None
if time.tzinfo not in (None, UTC):
raise ValueError(f"datetime {time} not in UTC")
return time.replace(tzinfo=UTC)


@overload
def datetime_to_db(time: datetime) -> datetime: ...


@overload
def datetime_to_db(time: None) -> None: ...


def datetime_to_db(time: datetime | None) -> datetime | None:
"""Strip time zone for storing a datetime in the database.

Parameters
----------
time
The timezone-aware `~datetime.datetime` in the UTC time zone, or
`None`.

Returns
-------
datetime.datetime or None
`None` if the input was `None`, otherwise the same
`~datetime.datetime` but timezone-naive and thus suitable for storing
in a SQL database.
"""
if not time:
return None
if time.utcoffset() != timedelta(seconds=0):
raise ValueError(f"datetime {time} not in UTC")
return time.replace(tzinfo=None)


def create_database_engine(
url: str,
password: str | SecretStr | None,
Expand Down Expand Up @@ -231,60 +163,3 @@ async def create_async_session(
async with session.begin():
await session.execute(statement.limit(1))
return session


async def initialize_database(
engine: AsyncEngine,
logger: BoundLogger,
*,
schema: MetaData,
reset: bool = False,
) -> None:
"""Create and initialize a new database.

Parameters
----------
engine
Database engine to use. Create with `create_database_engine`.
logger
Logger used to report problems
schema
Metadata for the database schema. Generally this will be
``Base.metadata`` where ``Base`` is the declarative base used as the
base class for all ORM table definitions. The caller must ensure that
all table definitions have been imported by Python before calling this
function, or parts of the schema will be missing.
reset
If set to `True`, drop all tables and reprovision the database.
Useful when running tests with an external database.

Raises
------
DatabaseInitializationError
After five attempts, the database still could not be initialized.
This is normally due to some connectivity issue to the database.
"""
success = False
error = None
for _ in range(5):
try:
async with engine.begin() as conn:
if schema.schema is not None:
await conn.execute(CreateSchema(schema.schema, True))
if reset:
await conn.run_sync(schema.drop_all)
await conn.run_sync(schema.create_all)
success = True
except (ConnectionRefusedError, OperationalError, OSError) as e:
logger.info("database not ready, waiting two seconds")
error = str(e)
await asyncio.sleep(2)
continue
if success:
logger.info("initialized database schema")
break
if not success:
msg = "database schema initialization failed (database not reachable?)"
logger.error(msg)
await engine.dispose()
raise DatabaseInitializationError(error)
71 changes: 71 additions & 0 deletions src/safir/database/_datetime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""datetime management for databases."""

from __future__ import annotations

from datetime import UTC, datetime, timedelta
from typing import overload

__all__ = [
"datetime_from_db",
"datetime_to_db",
]


@overload
def datetime_from_db(time: datetime) -> datetime: ...


@overload
def datetime_from_db(time: None) -> None: ...


def datetime_from_db(time: datetime | None) -> datetime | None:
"""Add the UTC time zone to a naive datetime from the database.

Parameters
----------
time
The naive datetime from the database, or `None`.

Returns
-------
datetime.datetime or None
`None` if the input was none, otherwise a timezone-aware version of
the same `~datetime.datetime` in the UTC timezone.
"""
if not time:
return None
if time.tzinfo not in (None, UTC):
raise ValueError(f"datetime {time} not in UTC")
return time.replace(tzinfo=UTC)


@overload
def datetime_to_db(time: datetime) -> datetime: ...


@overload
def datetime_to_db(time: None) -> None: ...


def datetime_to_db(time: datetime | None) -> datetime | None:
"""Strip time zone for storing a datetime in the database.

Parameters
----------
time
The timezone-aware `~datetime.datetime` in the UTC time zone, or
`None`.

Returns
-------
datetime.datetime or None
`None` if the input was `None`, otherwise the same
`~datetime.datetime` but timezone-naive and thus suitable for storing
in a SQL database.
"""
if not time:
return None
if time.utcoffset() != timedelta(seconds=0):
raise ValueError(f"datetime {time} not in UTC")
return time.replace(tzinfo=None)
77 changes: 77 additions & 0 deletions src/safir/database/_initialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""Database initialization."""

from __future__ import annotations

import asyncio

from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.schema import CreateSchema
from sqlalchemy.sql.schema import MetaData
from structlog.stdlib import BoundLogger

__all__ = [
"DatabaseInitializationError",
"initialize_database",
]


class DatabaseInitializationError(Exception):
"""Database initialization failed."""


async def initialize_database(
engine: AsyncEngine,
logger: BoundLogger,
*,
schema: MetaData,
reset: bool = False,
) -> None:
"""Create and initialize a new database.

Parameters
----------
engine
Database engine to use. Create with `create_database_engine`.
logger
Logger used to report problems
schema
Metadata for the database schema. Generally this will be
``Base.metadata`` where ``Base`` is the declarative base used as the
base class for all ORM table definitions. The caller must ensure that
all table definitions have been imported by Python before calling this
function, or parts of the schema will be missing.
reset
If set to `True`, drop all tables and reprovision the database.
Useful when running tests with an external database.

Raises
------
DatabaseInitializationError
After five attempts, the database still could not be initialized.
This is normally due to some connectivity issue to the database.
"""
success = False
error = None
for _ in range(5):
try:
async with engine.begin() as conn:
if schema.schema is not None:
await conn.execute(CreateSchema(schema.schema, True))
if reset:
await conn.run_sync(schema.drop_all)
await conn.run_sync(schema.create_all)
success = True
except (ConnectionRefusedError, OperationalError, OSError) as e:
logger.info("database not ready, waiting two seconds")
error = str(e)
await asyncio.sleep(2)
continue
if success:
logger.info("initialized database schema")
break
if not success:
msg = "database schema initialization failed (database not reachable?)"
logger.error(msg)
await engine.dispose()
raise DatabaseInitializationError(error)
2 changes: 1 addition & 1 deletion tests/database_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
from sqlalchemy.orm import declarative_base

from safir.database import (
_build_database_url,
create_async_session,
create_database_engine,
datetime_from_db,
datetime_to_db,
initialize_database,
)
from safir.database._connection import _build_database_url

TEST_DATABASE_PASSWORD = os.environ["TEST_DATABASE_PASSWORD"]

Expand Down