From 0bc09e56994b028cad606ceefe0622fe5def915f Mon Sep 17 00:00:00 2001 From: Russ Allbery Date: Thu, 14 Nov 2024 17:23:19 -0800 Subject: [PATCH] Add generic paginated query support Lift the pagninated query support from Gafaelfawr into Safir and make it generic so that it can be reused in Wobbly. Pagination support consists of a cursor class, which defines how results are ordered and paginated and how the cursor is serialized and deserialized, a paginated query runner that contains the logic for applying limits and keyset pagination, and a paginated result container that wraps a list of Pydantic models with metadata about the pagination. Also define a cursor type for the common case of results ordered by a `datetime` and then a unique key, in descending order, so that cursors for that case can be easily constructed. --- docs/user-guide/database/index.rst | 1 + docs/user-guide/database/pagination.rst | 230 ++++++++++ safir/src/safir/database/__init__.py | 12 + safir/src/safir/database/_pagination.py | 537 ++++++++++++++++++++++++ safir/tests/database_test.py | 176 +++++++- 5 files changed, 953 insertions(+), 3 deletions(-) create mode 100644 docs/user-guide/database/pagination.rst create mode 100644 safir/src/safir/database/_pagination.py diff --git a/docs/user-guide/database/index.rst b/docs/user-guide/database/index.rst index 77e68c3e..eb3ebd8c 100644 --- a/docs/user-guide/database/index.rst +++ b/docs/user-guide/database/index.rst @@ -25,5 +25,6 @@ Guides session datetime retry + pagination testing schema diff --git a/docs/user-guide/database/pagination.rst b/docs/user-guide/database/pagination.rst new file mode 100644 index 00000000..03974cd8 --- /dev/null +++ b/docs/user-guide/database/pagination.rst @@ -0,0 +1,230 @@ +################# +Paginated queries +################# + +Pagination in a web API is the mechanism for returning a partial answer to a query with many results, alongside an easy way for the client to retrieve the next batch of results. +Implementing pagination for any query that may return a large number of results is considered best practice for web APIs. +Most clients will only need the first batch or two of results, batching results reduces latency, and shorter replies are easier to manage and create less memory pressure on both the server and client. + +There are numerous ways to manage pagination (see `this blog post `__ for a good introductory overview). +Safir provides a generic implementation of keyset pagination with client-side cursors, which is often a reasonable choice. + +Elements of a paginated query +============================= + +A paginated query is a fairly complex operation that usually goes through multiple layers of data types. + +First, your application will construct a SQL query that returns the full, unpaginated data set. +This may use other query parameters or information from the client to limit the results in ways unrelated to pagination. +This SQL query should not order the results; the order will be applied when the paginated query is done. + +Second, your application optionally may provide either a limit on the number of results to return or a cursor indicating where in the overall list of results to pick up from. +If neither a limit nor a cursor is provided, the query is not paginated; you can still use the facilities discussed here for simplicity (and to avoid needing special cases for the non-paginated case), but pagination will not be done. + +Third, your application passes the SQL query and any limit or cursor, along with a database session, into `~safir.database.PaginatedQueryRunner` to perform the query. +This will apply the sort order and any restrictions from the limit or cursor and then execute the query in that session. +It will return a `~safir.database.PaginatedList`, which holds the results along with pagination information. + +Finally, the application will return, via the API handler, the list of entries included in the `~safir.database.PaginatedList` along with information about how to obtain the next or previous group of entries and the total number of records. +This pagination information is generally returned in HTTP headers, although if you wish to return it in a data structure wrapper around the results, you can do that instead. + +Defining the cursor +=================== + +To use the generic paginated query support, first you must define the cursor that will be used for the queries. +The cursor class defines the following information needed for paginated queries: + +#. How to construct a cursor to get all entries before or after a given entry. +#. How to serialize to and deserialize from a string so that the cursor can be returned to an API client and sent back to retrieve the next batch of results. +#. The sort order the cursor represents. + A cursor class represents one and only one sort order, since keyset cursors rely on the sort order not changing. + (Your application can use multiple cursors and thus support multiple sort orders, however.) +#. How to apply the cursor to limit a SQL statement. +#. How to invert a cursor (change it from going down the full list of results to going up the full list of results for previous links). + +In the general case, your application must define the cursor by creating a subclass of `~safir.database.PaginationCursor` and implementing its abstract methods. + +In the very common case that the API results are sorted first by some timestamp in descending order (most recent first) and then by an auto-increment unique key (most recently inserted row first), Safir provides `~safir.database.DatetimeIdCursor`, which is a generic cursor implementation that implements that ordering and keyset pagination policy. +In this case, you need only subclass `~safir.database.DatetimeIdCursor` and provide the SQLAlchemy ORM model columns that correspond to the timestamp and the unique key. + +For example, if you are requesting paginated results from a table whose ORM model is named ``Job``, whose timestamp field is ``Job.creation_time``, and whose unique key is ``Job.id``, you can use the following cursor: + +.. code-block:: python + + from safir.database import DatetimeIdCursor + from sqlalchemy.orm import InstrumentedAttribute + + + class JobCursor(DatetimeIdCursor): + @staticmethod + def id_column() -> InstrumentedAttribute: + return Job.id + + @staticmethod + def time_column() -> InstrumentedAttribute: + return Job.creation_time + +(These are essentially class properties, but due to limitations in Python abstract data types and property decorators, they're implemented as static methods.) + +In this case, `~safir.database.DatetimeIdCursor` will handle all of the other details for you, including serialization and deserialization. + +Performing paginated queries +============================ + +Parse the cursor and limit +-------------------------- + +Handlers for routes that return paginated results should take an optional pagination cursor as a query parameter. +This will be used by the client to move forward and backwards through the results. + +The parameter declaration should generally look something like the following: + +.. code-block:: + + @router.get("/query", response_class=Model) + async def query( + *, + cursor: Annotated[ + ModelCursor | None, + Query( + title="Pagination cursor", + description=( + "Optional cursor used when moving between pages of results" + ), + ), + BeforeValidator(lambda c: SomeCursor.from_str(c) if c else None), + ], + limit: Annotated[ + int, + Query( + title="Row limit", + description="Maximum number of entries to return", + examples=[100], + ge=1, + le=100, + ), + ] = 100, + request: Request, + response: Response, + ) -> list[Model]: + ... + +You should be able to use your class's implementation of `~safir.database.PaginationCursor.from_str` as a validator, which lets FastAPI validate the syntax of the cursor for you and handle syntax errors. +Since the cursor is optional (the first query won't have a cursor), you'll need a small wrapper to handle `None`, as shown above. + +Also note the ``limit`` parameter, which should also be used on any paginated route. +This sets the size of each block of results. + +As shown here, you will generally want to set some upper limit on how large the limit can be and set a default limit if none was provided. +This ensures that clients cannot retrieve the full list of results with one query. + +If the clients are sufficiently trusted or if you're certain the application can handle returning the full list of objects without creating resource problems, you can allow ``limit`` to be omitted and default it to `None`. +The paginated query support in Safir will treat that as an unlimited query and will return all of the available results. +In this case, you should change the type to ``int | None`` and remove the ``le`` constraint on the parameter. + +Create the runner +----------------- + +The first step of performing a paginated query is to create a `~safir.database.PaginatedQueryRunner` object. +Its constructor takes as arguments the type of the Pydantic model that will hold each returned object and the type of the cursor that will be used for pagination. + +.. code-block:: + + runner = PaginatedQueryRunner(Job, JobCursor) + +Construct the query +------------------- + +Then, define the SQL query as a SQLAlchemy `~sqlalchemy.Select` statement. +You can do this in two ways: either a query that returns a single SQLAlchemy ORM model, or a query for a list of specific columns. +Other combinations are not supported. + +For example: + +.. code-block:: + + stmt = select(Job).where(Job.username == "someuser") + +Or, an example of selecting specific columns: + +.. code-block:: + + stmt = select(Job.id, Job.timestamp, Job.description) + +Ensure that all of the attributes required to create a cursor are included in the query and in the Pydantic model. + +In either case, the data returned by the query must be sufficient to construct the Pydantic model passed as the first argument to the `~safir.database.PaginatedQueryRunner` constructor. +The query result will be passed into the ``model_validate`` method of that model. +Among other things, this means that all necessary attributes must be present and the model must be able to handle any data conversion required. + +If the model includes any timestamps, the model validation must be able to convert them from the time format stored in the database (see :doc:`datetime`) to an appropriate Python `~datetime.datetime`. +The easiest way to do this is to declare those fields as having the `safir.pydantic.UtcDatetime` type. +See :ref:`pydantic-datetime` for more information. + +Run the query +------------- + +Finally, you can run the query. +There are two ways to do this depending on how the query is structured. + +If the SQL query returns a single ORM model for each result row, use `~safir.database.PaginatedQueryRunner.query_object`: + +.. code-block:: + + results = await runner.query_object( + session, stmt, cursor=cursor, limit=limit + ) + +If the SQL query returns a tuple of individually selected attributes that correspond to the fields of the result model (the first parameter to the `~safir.database.PaginatedQueryRunner` constructor), use `~safir.database.PaginatedQueryRunner.query_row`: + +.. code-block:: + + results = await runner.query_row(session, stmt, cursor=cursor, limit=limit) + +Either way, the results will be a `~safir.database.PaginatedList` wrapping a list of Pydantic models of the appropriate type. + +Returning paginated results +=========================== + +HTTP provides the ``Link`` header (:rfc:`8288`) to declare relationships between multiple web responses. +Using a ``Link`` header with relation types ``first``, ``next``, and ``prev`` is a standard way of providing the client with pagination information. + +The Safir `~safir.database.PaginatedList` type provides a method, `~safir.database.PaginatedList.link_header`, which returns the contents of an HTTP ``Link`` header for a given paginated result. +It takes as its argument the base URL for the query (usually the current URL of a route handler). +This is the recommended way to return pagination information alongside a result. + +Here is a very simplified example of a route handler that sets this header: + +.. code-block:: python + + @router.get("/query", response_class=Model) + async def query( + *, + cursor: Annotated[ + ModelCursor | None, + Query(), + BeforeValidator(ModelCursor.from_str), + ], + limit: int | None = Query(), + session: Annotated[ + async_scoped_session, Depends(db_session_dependency) + ], + request: Request, + response: Response, + ) -> list[Model]: + runner = PydanticQueryRunner(Model, ModelCursor) + stmt = build_query(...) + results = await runner.query_object( + session, stmt, cursor=cursor, limit=limit + ) + if cursor or limit: + response.headers["Link"] = results.link_header(request.url) + response.headers["X-Total-Count"] = str(results.count) + return results.entries + +Here, ``perform_query`` is a wrapper around `~safir.databsae.PaginatedQueryRunner` that constructs and runs the query. +A real route handler would have more query parameters and more documentation. + +Note that this example also sets a non-standard ``X-Total-Count`` header containing the total count of entries returned by the underlying query without pagination. +`~safir.database.PaginatedQueryRunner` obtains this information by default, since the count query is often fast for databases to perform. +There is no standard way to return this information to the client, but ``X-Total-Count`` is a widely-used informal standard. diff --git a/safir/src/safir/database/__init__.py b/safir/src/safir/database/__init__.py index cc2fbea2..afb2353d 100644 --- a/safir/src/safir/database/__init__.py +++ b/safir/src/safir/database/__init__.py @@ -16,11 +16,23 @@ drop_database, initialize_database, ) +from ._pagination import ( + DatetimeIdCursor, + InvalidCursorError, + PaginatedList, + PaginatedQueryRunner, + PaginationCursor, +) from ._retry import RetryP, RetryT, retry_async_transaction __all__ = [ "AlembicConfigError", "DatabaseInitializationError", + "DatetimeIdCursor", + "InvalidCursorError", + "PaginationCursor", + "PaginatedList", + "PaginatedQueryRunner", "create_async_session", "create_database_engine", "datetime_from_db", diff --git a/safir/src/safir/database/_pagination.py b/safir/src/safir/database/_pagination.py new file mode 100644 index 00000000..a5963e59 --- /dev/null +++ b/safir/src/safir/database/_pagination.py @@ -0,0 +1,537 @@ +"""Support for paginated database queries. + +This pagination support uses keyset pagination rather than relying on database +cursors, since the latter interact poorly with horizontally scaled services. +""" + +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +from dataclasses import dataclass +from datetime import UTC, datetime +from typing import Generic, Self, TypeVar +from urllib.parse import parse_qs, urlencode + +from pydantic import BaseModel +from sqlalchemy import Select, and_, func, or_, select +from sqlalchemy.ext.asyncio import async_scoped_session +from sqlalchemy.orm import DeclarativeBase, InstrumentedAttribute +from starlette.datastructures import URL + +from safir.fastapi import ClientRequestError +from safir.models import ErrorLocation + +from ._datetime import datetime_to_db + +C = TypeVar("C", bound="PaginationCursor") +"""Type of a cursor for a paginated list.""" + +E = TypeVar("E", bound="BaseModel") +"""Type of an entry in a paginated list.""" + +__all__ = [ + "DatetimeIdCursor", + "InvalidCursorError", + "PaginationCursor", + "PaginatedList", + "PaginatedQueryRunner", +] + + +class InvalidCursorError(ClientRequestError): + """The provided cursor was invalid.""" + + error = "invalid_cursor" + + def __init__(self, message: str) -> None: + super().__init__(message, ErrorLocation.query, ["cursor"]) + + +@dataclass +class PaginationCursor(Generic[E], metaclass=ABCMeta): + """Generic pagnination cursor for keyset pagination. + + The generic type parameter is the Pydantic model into which each row will + be converted, not the ORM model. + """ + + previous: bool + """Whether to search backwards instead of forwards.""" + + @classmethod + @abstractmethod + def from_entry(cls, entry: E, *, reverse: bool = False) -> Self: + """Construct a cursor with an entry as a bound. + + Builds a cursor to get the entries after the provided entry, or before + the provided entry if ``reverse`` is set to `True`. + + Parameters + ---------- + entry + Basis of the cursor. + reverse + Whether to create a previous cursor. + + Returns + ------- + PaginationCursor + Requested cursor. + """ + + @classmethod + @abstractmethod + def from_str(cls, cursor: str) -> Self: + """Build cursor from the string serialization form. + + Parameters + ---------- + cursor + Serialized form of the cursor. + + Returns + ------- + DatetimeIdCursor + The cursor represented as an object. + + Raises + ------ + InvalidCursorError + Raised if the cursor is not valid. + """ + + @classmethod + @abstractmethod + def apply_order(cls, stmt: Select, *, reverse: bool = False) -> Select: + """Apply the sort order of the cursor to a select statement. + + This is independent of the cursor and only needs to know the + underlying ORM fields, so it is available as a class method on the + cursor class, allowing it to be used without a cursor (such as for the + initial query). This does, however, mean that the caller has to + explicitly say whether to reverse the order, which is required when + using a previous cursor. + + Parameters + ---------- + stmt + SQL select statement. + reverse + Whether to reverse the sort order. + + Returns + ------- + Select + The same select statement but sorted in the order expected by the + cursor. + """ + + @abstractmethod + def apply_cursor(self, stmt: Select) -> Select: + """Apply the restrictions from the cursor to a select statement. + + Parameters + ---------- + stmt + Select statement to modify. + + Returns + ------- + Select + Modified select statement. + """ + + @abstractmethod + def invert(self) -> Self: + """Return the inverted cursor (going the opposite direction). + + Parameters + ---------- + cursor + Cursor to invert. + + Returns + ------- + DatetimeIdCursor + The inverted cursor. + """ + + +@dataclass +class DatetimeIdCursor(PaginationCursor[E], metaclass=ABCMeta): + """Pagination cursor using a `~datetime.datetime` and unique column ID. + + Cursors that first order by time and then by unique column ID can subclass + this class and only define the ``id_column`` and ``time_column`` static + methods to return the ORM model fields for the timestamp and column ID. + + Examples + -------- + Here is a specialization of this cursor class for a simple ORM model where + the timestamp field to order by is named ``creation_time`` and the unique + row ID is named ``id``. + + .. code-block:: python + + class TableCursor(DatetimeIdCursor): + @staticmethod + def id_column() -> InstrumentedAttribute: + return Table.id + + @staticmethod + def time_column() -> InstrumentedAttribute: + return Table.creation_time + """ + + time: datetime + """Time position.""" + + id: int + """Unique ID position.""" + + @staticmethod + @abstractmethod + def id_column() -> InstrumentedAttribute: + """Return SQL model attribute holding the ID.""" + + @staticmethod + @abstractmethod + def time_column() -> InstrumentedAttribute: + """Return SQL model attribute holding the time position.""" + + @classmethod + def from_str(cls, cursor: str) -> Self: + previous = cursor.startswith("p") + if previous: + cursor = cursor[1:] + try: + time, id = cursor.split("_") + return cls( + time=datetime.fromtimestamp(float(time), tz=UTC), + id=int(id), + previous=previous, + ) + except Exception as e: + raise InvalidCursorError(f"Invalid cursor: {e!s}") from e + + @classmethod + def apply_order(cls, stmt: Select, *, reverse: bool = False) -> Select: + if reverse: + return stmt.order_by(cls.time_column(), cls.id_column()) + else: + return stmt.order_by( + cls.time_column().desc(), cls.id_column().desc() + ) + + def apply_cursor(self, stmt: Select) -> Select: + time = datetime_to_db(self.time) + time_column = self.time_column() + id_column = self.id_column() + if self.previous: + return stmt.where( + or_( + time_column > time, + and_(time_column == time, id_column > self.id), + ) + ) + else: + return stmt.where( + or_( + time_column < time, + and_(time_column == time, id_column <= self.id), + ) + ) + + def invert(self) -> Self: + return type(self)( + time=self.time, id=self.id, previous=not self.previous + ) + + def __str__(self) -> str: + previous = "p" if self.previous else "" + timestamp = self.time.timestamp() + + # Remove a trailing .0, but keep the fractional portion if it matters. + if int(timestamp) == timestamp: + timestamp = int(timestamp) + + return f"{previous}{timestamp!s}_{self.id!s}" + + +@dataclass +class PaginatedList(Generic[E, C]): + """Paginated SQL results with accompanying pagination metadata. + + Holds a paginated list of any Pydantic type, complete with a count and + cursors. Can hold any type of entry, but uses a `DatetimeIdCursor`, so + implicitly requires the type be one that is meaningfully paginated by that + type of cursor. + """ + + entries: list[E] + """The history entries.""" + + count: int + """Total available entries.""" + + next_cursor: C | None = None + """Cursor for the next batch of entries.""" + + prev_cursor: C | None = None + """Cursor for the previous batch of entries.""" + + def link_header(self, base_url: URL) -> str: + """Construct an RFC 8288 ``Link`` header for a paginated result. + + Parameters + ---------- + base_url + The starting URL of the current group of entries. + """ + first_url = base_url.remove_query_params("cursor") + header = f'<{first_url!s}>; rel="first"' + params = parse_qs(first_url.query) + if self.next_cursor: + params["cursor"] = [str(self.next_cursor)] + next_url = first_url.replace(query=urlencode(params, doseq=True)) + header += f', <{next_url!s}>; rel="next"' + if self.prev_cursor: + params["cursor"] = [str(self.prev_cursor)] + prev_url = first_url.replace(query=urlencode(params, doseq=True)) + header += f', <{prev_url!s}>; rel="prev"' + return header + + +class PaginatedQueryRunner(Generic[E, C]): + """Construct and run database queries that return paginated results. + + This class implements the logic for keyset pagination based on arbitrary + SQLAlchemy ORM where clauses. + + Parameters + ---------- + entry_type + Type of each entry returned by the queries. This must be a Pydantic + model. + cursor_type + Type of the pagination cursor, which encapsulates the logic of how + entries are sorted and what set of keys is used to retrieve the next + or previous batch of entries. + """ + + def __init__(self, entry_type: type[E], cursor_type: type[C]) -> None: + self._entry_type = entry_type + self._cursor_type = cursor_type + + async def query_object( + self, + session: async_scoped_session, + stmt: Select[tuple[DeclarativeBase]], + *, + cursor: C | None = None, + limit: int | None = None, + ) -> PaginatedList[E, C]: + """Perform a query for objects with an optional cursor and limit. + + Perform the query provided in ``stmt`` with appropriate sorting and + pagination as determined by the cursor type. + + This method should be used with queries that return a single + SQLAlchemy model. The provided query will be run with the + `Select.scalars` method and the resulting object passed to Pydantic's + ``model_validate`` to convert to ``entry_type``. For queries returning + a tuple of attributes, use `query_row` instead. + + Parameters + ---------- + session + Database session within which to run the query. + stmt + Select statement to execute. Pagination and ordering will be + added, so this statement should not already have limits or order + clauses applied. This statement must return a list of SQLAlchemy + ORM models that can be converted to ``entry_type`` by Pydantic. + cursor + If present, continue from the provided keyset cursor. + limit + If present, limit the result count to at most this number of rows. + + Returns + ------- + PaginatedList + Results of the query wrapped with pagination information. + """ + if cursor or limit: + return await self._paginated_query( + session, stmt, cursor=cursor, limit=limit, scalar=True + ) + + # No pagination was required. Run the simple query in the correct + # sorted order and return it with no cursors. + stmt = self._cursor_type.apply_order(stmt) + result = await session.scalars(stmt) + entries = [ + self._entry_type.model_validate(r, from_attributes=True) + for r in result.all() + ] + return PaginatedList[E, C]( + entries=entries, + count=len(entries), + prev_cursor=None, + next_cursor=None, + ) + + async def query_row( + self, + session: async_scoped_session, + stmt: Select[tuple], + *, + cursor: C | None = None, + limit: int | None = None, + ) -> PaginatedList[E, C]: + """Perform a query for attributes with an optional cursor and limit. + + Perform the query provided in ``stmt`` with appropriate sorting and + pagination as determined by the cursor type. + + This method should be used with queries that return a list of + attributes that can be converted to the ``entry_type`` Pydantic model. + For queries returning a single ORM object, use `query_object` instead. + + Parameters + ---------- + session + Database session within which to run the query. + stmt + Select statement to execute. Pagination and ordering will be + added, so this statement should not already have limits or order + clauses applied. This statement must return a tuple of attributes + that can be converted to ``entry_type`` by Pydantic's + ``model_validate``. + cursor + If present, continue from the provided keyset cursor. + limit + If present, limit the result count to at most this number of rows. + + Returns + ------- + PaginatedList + Results of the query wrapped with pagination information. + """ + if cursor or limit: + return await self._paginated_query( + session, stmt, cursor=cursor, limit=limit + ) + + # No pagination was required. Run the simple query in the correct + # sorted order and return it with no cursors. + stmt = self._cursor_type.apply_order(stmt) + result = await session.execute(stmt) + entries = [ + self._entry_type.model_validate(r, from_attributes=True) + for r in result.all() + ] + return PaginatedList[E, C]( + entries=entries, + count=len(entries), + prev_cursor=None, + next_cursor=None, + ) + + async def _paginated_query( + self, + session: async_scoped_session, + stmt: Select[tuple], + *, + cursor: C | None = None, + limit: int | None = None, + scalar: bool = False, + ) -> PaginatedList[E, C]: + """Perform a paginated query. + + The internal implementation details of the complicated case for + `query`, where either a cursor or a limit is in play. + + Parameters + ---------- + session + Database session within which to run the query. + stmt + Select statement to execute. Pagination and ordering will be + added, so this statement should not already have limits or order + clauses applied. + cursor + If present, continue from the provided keyset cursor. + limit + If present, limit the result count to at most this number of rows. + scalar + If `True`, the query returns one ORM object for each row instead + of a tuple of columns. + + Returns + ------- + PaginatedList + Results of the query wrapped with pagination information. + """ + limited_stmt = stmt + + # Apply the cursor, if there is one. + if cursor: + limited_stmt = cursor.apply_cursor(limited_stmt) + + # When retrieving a previous set of results using a previous cursor, + # we have to reverse the sort algorithm so that the cursor boundary + # can be applied correctly. We'll then later reverse the result set to + # return it in proper forward-sorted order. + if cursor and cursor.previous: + limited_stmt = cursor.apply_order(limited_stmt, reverse=True) + else: + limited_stmt = self._cursor_type.apply_order(limited_stmt) + + # Grab one more element than the query limit so that we know whether + # to create a cursor (because there are more elements) and what the + # cursor value should be (for forward cursors). + if limit: + limited_stmt = limited_stmt.limit(limit + 1) + + # Execute the query twice, once to get the next bach of results and + # once to get the count of all entries without pagination. + if scalar: + result = await session.scalars(limited_stmt) + else: + result = await session.execute(limited_stmt) + entries = [ + self._entry_type.model_validate(r, from_attributes=True) + for r in result.all() + ] + count_stmt = select(func.count()).select_from(stmt.subquery()) + count = await session.scalar(count_stmt) or 0 + + # Calculate the cursors and remove the extra element we asked for. + prev_cursor = None + next_cursor = None + if cursor and cursor.previous: + if limit: + next_cursor = cursor.invert() + if len(entries) > limit: + prev = entries[limit - 1] + prev_cursor = self._cursor_type.from_entry(prev) + entries = entries[:limit] + + # Reverse the results again if we did a reverse sort because we + # were using a previous cursor. + entries.reverse() + else: + if cursor: + prev_cursor = cursor.invert() + if limit and len(entries) > limit: + next_cursor = self._cursor_type.from_entry(entries[limit]) + entries = entries[:limit] + + # Return the results. + return PaginatedList[E, C]( + entries=entries, + count=count, + prev_cursor=prev_cursor, + next_cursor=next_cursor, + ) diff --git a/safir/tests/database_test.py b/safir/tests/database_test.py index 2c9131a7..cdf4bc31 100644 --- a/safir/tests/database_test.py +++ b/safir/tests/database_test.py @@ -5,20 +5,29 @@ import asyncio import os import subprocess +from dataclasses import dataclass from datetime import UTC, datetime, timedelta, timezone from pathlib import Path -from typing import Any +from typing import Any, Self from urllib.parse import unquote, urlparse import pytest import structlog from pydantic import BaseModel, SecretStr from pydantic_core import Url -from sqlalchemy import Column, MetaData, String, Table +from sqlalchemy import Column, MetaData, String, Table, select from sqlalchemy.exc import OperationalError, ProgrammingError -from sqlalchemy.future import select +from sqlalchemy.orm import ( + DeclarativeBase, + InstrumentedAttribute, + Mapped, + mapped_column, +) +from starlette.datastructures import URL from safir.database import ( + DatetimeIdCursor, + PaginatedQueryRunner, create_async_session, create_database_engine, datetime_from_db, @@ -32,6 +41,7 @@ unstamp_database, ) from safir.database._connection import build_database_url +from safir.pydantic import UtcDatetime from .support.alembic import BaseV1, BaseV2, UserV1, UserV2, config @@ -316,3 +326,163 @@ async def list_v2() -> list[str]: assert not event_loop.run_until_complete(check(config_path)) event_loop.run_until_complete(stamp_database_async(engine, config_path)) assert event_loop.run_until_complete(check(config_path)) + + +class PaginationBase(DeclarativeBase): + pass + + +class PaginationTable(PaginationBase): + __tablename__ = "table" + + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + time: Mapped[datetime] + + def __repr__(self) -> str: + return f"PaginationTable(id={self.id}, time={self.time})" + + +class PaginationModel(BaseModel): + id: int + time: UtcDatetime + + +@dataclass +class TableCursor(DatetimeIdCursor[PaginationModel]): + @staticmethod + def id_column() -> InstrumentedAttribute: + return PaginationTable.id + + @staticmethod + def time_column() -> InstrumentedAttribute: + return PaginationTable.time + + @classmethod + def from_entry( + cls, entry: PaginationModel, *, reverse: bool = False + ) -> Self: + return cls(time=entry.time, id=entry.id, previous=reverse) + + +def naive_datetime(timestamp: float) -> datetime: + """Construct timezone-naive datetimes for tests.""" + return datetime_to_db(datetime.fromtimestamp(timestamp, tz=UTC)) + + +def assert_model_lists_equal( + a: list[PaginationModel], b: list[PaginationTable] +) -> None: + assert len(a) == len(b) + for index, entry in enumerate(a): + assert entry.id == b[index].id, f"element {index} id" + orm_time = b[index].time.replace(tzinfo=UTC) + assert entry.time == orm_time, f"element {index} time" + + +@pytest.mark.asyncio +async def test_pagination(database_url: str, database_password: str) -> None: + logger = structlog.get_logger(__name__) + engine = create_database_engine(database_url, database_password) + await initialize_database( + engine, logger, schema=PaginationBase.metadata, reset=True + ) + session = await create_async_session(engine, logger) + + rows = [ + PaginationTable(time=naive_datetime(1500000000)), + PaginationTable(time=naive_datetime(1510000000)), + PaginationTable(time=naive_datetime(1520000000)), + PaginationTable(time=naive_datetime(1520000000)), + PaginationTable(time=naive_datetime(1600000000.5)), + PaginationTable(time=naive_datetime(1600000000.5)), + PaginationTable(time=naive_datetime(1610000000)), + ] + async with session.begin(): + for row in rows: + session.add(row) + + # Rows will be returned from the database in reverse order, so change the + # rows data structure to match. + rows.reverse() + + # Query by object and test the pagination cursors going backwards and + # forwards. + builder = PaginatedQueryRunner(PaginationModel, TableCursor) + async with session.begin(): + result = await builder.query_object( + session, select(PaginationTable), limit=2 + ) + assert_model_lists_equal(result.entries, rows[:2]) + assert result.count == 7 + assert not result.prev_cursor + assert result.link_header(URL("https://example.com/query")) == ( + '; rel="first", ' + f";" + ' rel="next"' + ) + assert str(result.next_cursor) == "1600000000.5_5" + result = await builder.query_object( + session, + select(PaginationTable), + cursor=result.next_cursor, + limit=3, + ) + assert_model_lists_equal(result.entries, rows[2:5]) + assert result.count == 7 + assert str(result.next_cursor) == "1510000000_2" + assert str(result.prev_cursor) == "p1600000000.5_5" + assert result.link_header( + URL("https://example.com/query?foo=bar&cursor=xxxx") + ) == ( + '; rel="first", ' + f";" + ' rel="next", ' + f";" + ' rel="prev"' + ) + next_cursor = result.next_cursor + result = await builder.query_object( + session, select(PaginationTable), cursor=result.prev_cursor + ) + assert_model_lists_equal(result.entries, rows[:2]) + assert result.count == 7 + result = await builder.query_object( + session, select(PaginationTable), cursor=next_cursor + ) + assert_model_lists_equal(result.entries, rows[5:]) + assert result.count == 7 + assert not result.next_cursor + result = await builder.query_object( + session, select(PaginationTable), cursor=result.prev_cursor + ) + assert_model_lists_equal(result.entries, rows[:5]) + assert result.count == 7 + + # Perform one of the queries by attribute instead to test the query_row + # function. + async with session.begin(): + result = await builder.query_row( + session, select(PaginationTable.time, PaginationTable.id), limit=2 + ) + assert_model_lists_equal(result.entries, rows[:2]) + assert result.count == 7 + + # Querying for the entire table should return the everything with no + # pagination cursors. Try this with both an object query and an attribute + # query. + async with session.begin(): + result = await builder.query_object(session, select(PaginationTable)) + assert_model_lists_equal(result.entries, rows) + assert result.count == 7 + assert not result.next_cursor + assert not result.prev_cursor + result = await builder.query_row( + session, select(PaginationTable.id, PaginationTable.time) + ) + assert_model_lists_equal(result.entries, rows) + assert result.count == 7 + assert not result.next_cursor + assert not result.prev_cursor + assert result.link_header(URL("https://example.com/query?foo=b")) == ( + '; rel="first"' + )