Skip to content

Commit

Permalink
Factor out full queries in PaginatedQueryRunner
Browse files Browse the repository at this point in the history
The `query_object` and `query_row` methods had duplicate code for
the unpaginated case. Refactor that into a helper method.
  • Loading branch information
rra committed Nov 20, 2024
1 parent 78071b3 commit ff2c088
Showing 1 changed file with 34 additions and 18 deletions.
52 changes: 34 additions & 18 deletions safir/src/safir/database/_pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,21 +423,8 @@ async def query_object(
return await self._paginated_query(
session, stmt, cursor=cursor, limit=limit, scalar=True
)

# No pagination was required. Run the simple query in the correct
# sorted order and return it with no cursors.
stmt = self._cursor_type.apply_order(stmt)
result = await session.scalars(stmt)
entries = [
self._entry_type.model_validate(r, from_attributes=True)
for r in result.all()
]
return PaginatedList[E, C](
entries=entries,
count=len(entries),
prev_cursor=None,
next_cursor=None,
)
else:
return await self._full_query(session, stmt, scalar=True)

async def query_row(
self,
Expand Down Expand Up @@ -480,11 +467,40 @@ async def query_row(
return await self._paginated_query(
session, stmt, cursor=cursor, limit=limit
)
else:
return await self._full_query(session, stmt)

async def _full_query(
self,
session: async_scoped_session,
stmt: Select[tuple],
*,
scalar: bool = False,
) -> PaginatedList[E, C]:
"""Perform a full, unpaginated query.
Parameters
----------
session
Database session within which to run the query.
stmt
Select statement to execute. Pagination and ordering will be
added, so this statement should not already have limits or order
clauses applied.
scalar
If `True`, the query returns one ORM object for each row instead
of a tuple of columns.
# No pagination was required. Run the simple query in the correct
# sorted order and return it with no cursors.
Returns
-------
PaginatedList
Results of the query wrapped with pagination information.
"""
stmt = self._cursor_type.apply_order(stmt)
result = await session.execute(stmt)
if scalar:
result = await session.scalars(stmt)
else:
result = await session.execute(stmt)
entries = [
self._entry_type.model_validate(r, from_attributes=True)
for r in result.all()
Expand Down

0 comments on commit ff2c088

Please sign in to comment.