diff --git a/safir/src/safir/database/_pagination.py b/safir/src/safir/database/_pagination.py index 0027915a..12ef3e67 100644 --- a/safir/src/safir/database/_pagination.py +++ b/safir/src/safir/database/_pagination.py @@ -423,21 +423,8 @@ async def query_object( return await self._paginated_query( session, stmt, cursor=cursor, limit=limit, scalar=True ) - - # No pagination was required. Run the simple query in the correct - # sorted order and return it with no cursors. - stmt = self._cursor_type.apply_order(stmt) - result = await session.scalars(stmt) - entries = [ - self._entry_type.model_validate(r, from_attributes=True) - for r in result.all() - ] - return PaginatedList[E, C]( - entries=entries, - count=len(entries), - prev_cursor=None, - next_cursor=None, - ) + else: + return await self._full_query(session, stmt, scalar=True) async def query_row( self, @@ -480,11 +467,40 @@ async def query_row( return await self._paginated_query( session, stmt, cursor=cursor, limit=limit ) + else: + return await self._full_query(session, stmt) + + async def _full_query( + self, + session: async_scoped_session, + stmt: Select[tuple], + *, + scalar: bool = False, + ) -> PaginatedList[E, C]: + """Perform a full, unpaginated query. + + Parameters + ---------- + session + Database session within which to run the query. + stmt + Select statement to execute. Pagination and ordering will be + added, so this statement should not already have limits or order + clauses applied. + scalar + If `True`, the query returns one ORM object for each row instead + of a tuple of columns. - # No pagination was required. Run the simple query in the correct - # sorted order and return it with no cursors. + Returns + ------- + PaginatedList + Results of the query wrapped with pagination information. + """ stmt = self._cursor_type.apply_order(stmt) - result = await session.execute(stmt) + if scalar: + result = await session.scalars(stmt) + else: + result = await session.execute(stmt) entries = [ self._entry_type.model_validate(r, from_attributes=True) for r in result.all()