From 42821ae27bb225658b3f6913b13942bec64bf591 Mon Sep 17 00:00:00 2001 From: Russ Allbery Date: Mon, 16 Dec 2024 15:52:53 -0800 Subject: [PATCH] Add abstraction for counted paginated queries `PaginatedQueryRunner` has a separate `query_count` method to return the count, but since the total entry count isn't included in the `PaginatedList` data structure, it's annoying to pass between components of a service. Either the result plus the count have to be returned as a tuple, or the service has to add its own data structure to wrap `PaginatedList`. Avoid this by introducing a `CountedPaginatedQueryRunner` and a corresponding `CountedPaginatedList` that always includes a `count` attribute. This can be used by services such as Gafaelfawr that always want to count the total number of entries, either because the table is small or because the count can always be satisfied from the table indices. --- docs/user-guide/database/pagination.rst | 52 +++++++- safir/src/safir/database/__init__.py | 4 + safir/src/safir/database/_pagination.py | 160 ++++++++++++++++++++++-- safir/tests/database_test.py | 47 +++++-- 4 files changed, 243 insertions(+), 20 deletions(-) diff --git a/docs/user-guide/database/pagination.rst b/docs/user-guide/database/pagination.rst index 2d3efca5..415eb603 100644 --- a/docs/user-guide/database/pagination.rst +++ b/docs/user-guide/database/pagination.rst @@ -264,10 +264,60 @@ Here is a very simplified example of a route handler that sets this header: Here, ``perform_query`` is a wrapper around `~safir.database.PaginatedQueryRunner` that constructs and runs the query. A real route handler would have more query parameters and more documentation. -Note that this example also sets a non-standard ``X-Total-Count`` header containing the total count of entries returned by the underlying query without pagination. 
+Including result counts +----------------------- + +The example above also sets a non-standard ``X-Total-Count`` header containing the total count of entries returned by the underlying query without pagination. `~safir.database.PaginatedQueryRunner.query_count` will return this information. There is no standard way to return this information to the client, but ``X-Total-Count`` is a widely-used informal standard. +If you will always want to include the count, use `~safir.database.CountedPaginatedQueryRunner` instead. +Its `~safir.database.CountedPaginatedQueryRunner.query_object` and `~safir.database.CountedPaginatedQueryRunner.query_row` methods will return a `~safir.database.CountedPaginatedList`, which contains a ``count`` attribute holding the count. +This is equivalent to calling `~safir.database.PaginatedQueryRunner.query_object` or `~safir.database.PaginatedQueryRunner.query_row` followed by `~safir.database.PaginatedQueryRunner.query_count`, but the encapsulation into a data structure makes it easier to pass the results between components of the service. + +Here's the same code above but using that approach: + +.. code-block:: python + +.. code-block:: python + :emphasize-lines: 27, 34 + + @router.get("/query", response_class=Model) + async def query( + *, + cursor: Annotated[ + str | None, + Query( + title="Pagination cursor", + description="Cursor to navigate paginated results", + ), + ] = None, + limit: Annotated[ + int, + Query( + title="Row limit", + description="Maximum number of entries to return", + examples=[100], + ge=1, + le=100, + ), + ] = 100, + request: Request, + response: Response, + ) -> list[Model]: + parsed_cursor = None + if cursor: + parsed_cursor = ModelCursor.from_str(cursor) + runner = CountedPaginatedQueryRunner(Model, ModelCursor) + stmt = build_query(...) 
+ results = await runner.query_object( + session, stmt, cursor=parsed_cursor, limit=limit + ) + if cursor or limit: + response.headers["Link"] = results.link_header(request.url) + response.headers["X-Total-Count"] = str(results.count) + return results.entries + Including links in the response ------------------------------- diff --git a/safir/src/safir/database/__init__.py b/safir/src/safir/database/__init__.py index 5fb86255..3f8d69d3 100644 --- a/safir/src/safir/database/__init__.py +++ b/safir/src/safir/database/__init__.py @@ -17,6 +17,8 @@ initialize_database, ) from ._pagination import ( + CountedPaginatedList, + CountedPaginatedQueryRunner, DatetimeIdCursor, InvalidCursorError, PaginatedList, @@ -28,6 +30,8 @@ __all__ = [ "AlembicConfigError", + "CountedPaginatedList", + "CountedPaginatedQueryRunner", "DatabaseInitializationError", "DatetimeIdCursor", "InvalidCursorError", diff --git a/safir/src/safir/database/_pagination.py b/safir/src/safir/database/_pagination.py index bcaa1eb9..bef0264d 100644 --- a/safir/src/safir/database/_pagination.py +++ b/safir/src/safir/database/_pagination.py @@ -34,6 +34,8 @@ """Type of an entry in a paginated list.""" __all__ = [ + "CountedPaginatedList", + "CountedPaginatedQueryRunner", "DatetimeIdCursor", "InvalidCursorError", "PaginatedList", @@ -321,19 +323,18 @@ def __str__(self) -> str: class PaginatedList(Generic[E, C]): """Paginated SQL results with accompanying pagination metadata. - Holds a paginated list of any Pydantic type, complete with a count and - cursors. Can hold any type of entry and any type of cursor, but implicitly - requires the entry type be one that is meaningfully paginated by that type - of cursor. + Holds a paginated list of any Pydantic type with pagination cursors. Can + hold any type of entry and any type of cursor, but implicitly requires the + entry type be one that is meaningfully paginated by that type of cursor. 
""" entries: list[E] """A batch of entries.""" - next_cursor: C | None = None + next_cursor: C | None """Cursor for the next batch of entries.""" - prev_cursor: C | None = None + prev_cursor: C | None """Cursor for the previous batch of entries.""" def first_url(self, current_url: URL) -> str: @@ -418,7 +419,7 @@ def link_header(self, current_url: URL) -> str: class PaginatedQueryRunner(Generic[E, C]): - """Construct and run database queries that return paginated results. + """Run database queries that return paginated results. This class implements the logic for keyset pagination based on arbitrary SQLAlchemy ORM where clauses. @@ -688,3 +689,148 @@ async def _paginated_query( return PaginatedList[E, C]( entries=entries, prev_cursor=prev_cursor, next_cursor=next_cursor ) + + +@dataclass +class CountedPaginatedList(PaginatedList[E, C]): + """Paginated SQL results with pagination metadata and total count. + + Holds a paginated list of any Pydantic type, complete with a count and + cursors. Can hold any type of entry and any type of cursor, but implicitly + requires the entry type be one that is meaningfully paginated by that type + of cursor. + """ + + count: int + """Total number of entries if queried without pagination.""" + + +class CountedPaginatedQueryRunner(PaginatedQueryRunner[E, C]): + """Run database queries that return paginated results with counts. + + This variation of `PaginatedQueryRunner` always runs a second query to + count the total number of available entries if queried without pagination. + It should only be used on small tables or with queries that can be + satisfied from the table indices; otherwise, the count query could be + undesirably slow. + + Parameters + ---------- + entry_type + Type of each entry returned by the queries. This must be a Pydantic + model. + cursor_type + Type of the pagination cursor, which encapsulates the logic of how + entries are sorted and what set of keys is used to retrieve the next + or previous batch of entries. 
+ """ + + async def query_object( + self, + session: async_scoped_session, + stmt: Select[tuple], + *, + cursor: C | None = None, + limit: int | None = None, + ) -> CountedPaginatedList[E, C]: + """Perform a query for objects with an optional cursor and limit. + + Perform the query provided in ``stmt`` with appropriate sorting and + pagination as determined by the cursor type. Also performs a second + query to get the total count of entries if retrieved without + pagination. + + This method should be used with queries that return a single + SQLAlchemy model. The provided query will be run with the session + `~sqlalchemy.ext.asyncio.async_scoped_session.scalars` method and the + resulting object passed to Pydantic's ``model_validate`` to convert to + ``entry_type``. For queries returning a tuple of attributes, use + `query_row` instead. + + Unfortunately, this distinction cannot be type-checked, so be careful + to use the correct method. + + Parameters + ---------- + session + Database session within which to run the query. + stmt + Select statement to execute. Pagination and ordering will be + added, so this statement should not already have limits or order + clauses applied. This statement must return a list of SQLAlchemy + ORM models that can be converted to ``entry_type`` by Pydantic. + cursor + If present, continue from the provided keyset cursor. + limit + If present, limit the result count to at most this number of rows. + + Returns + ------- + CountedPaginatedList + Results of the query wrapped with pagination information and a + count of the total number of entries. 
+ """ + result = await super().query_object( + session, stmt, cursor=cursor, limit=limit + ) + count = await self.query_count(session, stmt) + return CountedPaginatedList[E, C]( + entries=result.entries, + next_cursor=result.next_cursor, + prev_cursor=result.prev_cursor, + count=count, + ) + + async def query_row( + self, + session: async_scoped_session, + stmt: Select[tuple], + *, + cursor: C | None = None, + limit: int | None = None, + ) -> CountedPaginatedList[E, C]: + """Perform a query for attributes with an optional cursor and limit. + + Perform the query provided in ``stmt`` with appropriate sorting and + pagination as determined by the cursor type. Also performs a second + query to get the total count of entries if retrieved without + pagination. + + This method should be used with queries that return a list of + attributes that can be converted to the ``entry_type`` Pydantic model. + For queries returning a single ORM object, use `query_object` instead. + + Unfortunately, this distinction cannot be type-checked, so be careful + to use the correct method. + + Parameters + ---------- + session + Database session within which to run the query. + stmt + Select statement to execute. Pagination and ordering will be + added, so this statement should not already have limits or order + clauses applied. This statement must return a tuple of attributes + that can be converted to ``entry_type`` by Pydantic's + ``model_validate``. + cursor + If present, continue from the provided keyset cursor. + limit + If present, limit the result count to at most this number of rows. + + Returns + ------- + CountedPaginatedList + Results of the query wrapped with pagination information and a + count of the total number of entries. 
+ """ + result = await super().query_row( + session, stmt, cursor=cursor, limit=limit + ) + count = await self.query_count(session, stmt) + return CountedPaginatedList[E, C]( + entries=result.entries, + next_cursor=result.next_cursor, + prev_cursor=result.prev_cursor, + count=count, + ) diff --git a/safir/tests/database_test.py b/safir/tests/database_test.py index 4ddddbf8..0ac77d64 100644 --- a/safir/tests/database_test.py +++ b/safir/tests/database_test.py @@ -26,6 +26,7 @@ from starlette.datastructures import URL from safir.database import ( + CountedPaginatedQueryRunner, DatetimeIdCursor, PaginatedQueryRunner, PaginationLinkData, @@ -410,11 +411,12 @@ async def test_pagination(database_url: str, database_password: str) -> None: # Query by object and test the pagination cursors going backwards and # forwards. - builder = PaginatedQueryRunner(PaginationModel, TableCursor) + runner = PaginatedQueryRunner(PaginationModel, TableCursor) + counted_runner = CountedPaginatedQueryRunner(PaginationModel, TableCursor) async with session.begin(): stmt: Select[tuple] = select(PaginationTable) - assert await builder.query_count(session, stmt) == 7 - result = await builder.query_object(session, stmt, limit=2) + assert await runner.query_count(session, stmt) == 7 + result = await runner.query_object(session, stmt, limit=2) assert_model_lists_equal(result.entries, rows[:2]) assert not result.prev_cursor base_url = URL("https://example.com/query") @@ -427,7 +429,15 @@ async def test_pagination(database_url: str, database_password: str) -> None: assert result.prev_url(base_url) is None assert str(result.next_cursor) == "1600000000.5_1" - result = await builder.query_object( + counted_result = await counted_runner.query_object( + session, stmt, limit=2 + ) + assert counted_result.entries == result.entries + assert counted_result.prev_cursor == result.prev_cursor + assert counted_result.next_cursor == result.next_cursor + assert counted_result.count == 7 + + result = await 
runner.query_object( session, stmt, cursor=result.next_cursor, limit=3 ) assert_model_lists_equal(result.entries, rows[2:5]) @@ -447,7 +457,7 @@ async def test_pagination(database_url: str, database_password: str) -> None: assert result.prev_url(base_url) == prev_url next_cursor = result.next_cursor - result = await builder.query_object( + result = await runner.query_object( session, stmt, cursor=result.prev_cursor ) assert_model_lists_equal(result.entries, rows[:2]) @@ -457,7 +467,7 @@ async def test_pagination(database_url: str, database_password: str) -> None: f'<{base_url!s}&cursor={result.next_cursor}>; rel="next"' ) - result = await builder.query_object(session, stmt, cursor=next_cursor) + result = await runner.query_object(session, stmt, cursor=next_cursor) assert_model_lists_equal(result.entries, rows[5:]) assert not result.next_cursor base_url = URL("https://example.com/query") @@ -468,14 +478,14 @@ async def test_pagination(database_url: str, database_password: str) -> None: ) prev_cursor = result.prev_cursor - result = await builder.query_object(session, stmt, cursor=prev_cursor) + result = await runner.query_object(session, stmt, cursor=prev_cursor) assert_model_lists_equal(result.entries, rows[:5]) assert result.link_header(base_url) == ( f'<{base_url!s}>; rel="first", ' f'<{base_url!s}?cursor={result.next_cursor}>; rel="next"' ) - result = await builder.query_object( + result = await runner.query_object( session, stmt, cursor=prev_cursor, limit=2 ) assert_model_lists_equal(result.entries, rows[3:5]) @@ -490,26 +500,39 @@ async def test_pagination(database_url: str, database_password: str) -> None: # function. 
async with session.begin(): stmt = select(PaginationTable.time, PaginationTable.id) - result = await builder.query_row(session, stmt, limit=2) + result = await runner.query_row(session, stmt, limit=2) assert_model_lists_equal(result.entries, rows[:2]) - assert await builder.query_count(session, stmt) == 7 + assert await runner.query_count(session, stmt) == 7 + + counted_result = await counted_runner.query_row(session, stmt, limit=2) + assert counted_result.entries == result.entries + assert counted_result.prev_cursor == result.prev_cursor + assert counted_result.next_cursor == result.next_cursor + assert counted_result.count == 7 # Querying for the entire table should return the everything with no # pagination cursors. Try this with both an object query and an attribute # query. async with session.begin(): - result = await builder.query_object(session, select(PaginationTable)) + stmt = select(PaginationTable) + result = await runner.query_object(session, stmt) assert_model_lists_equal(result.entries, rows) assert not result.next_cursor assert not result.prev_cursor stmt = select(PaginationTable.id, PaginationTable.time) - result = await builder.query_row(session, stmt) + result = await runner.query_row(session, stmt) assert_model_lists_equal(result.entries, rows) assert not result.next_cursor assert not result.prev_cursor base_url = URL("https://example.com/query?foo=b") assert result.link_header(base_url) == (f'<{base_url!s}>; rel="first"') + counted_result = await counted_runner.query_row(session, stmt) + assert counted_result.entries == result.entries + assert not counted_result.next_cursor + assert not counted_result.prev_cursor + assert counted_result.count == len(counted_result.entries) + def test_link_data() -> None: header = (