def paginate_over(
    getter: Paginatable,
    extract_entries: Callable[[T], List[Entry]],
    page_size: int = 50,
    limit: Optional[int] = None,
    **kwargs: Any,
) -> Iterator[Entry]:
    """
    Generic approach to pagination via `NextPage`.

    Repeatedly calls `getter` with a `NextPage` describing how many entries to
    request and which page token to continue from, and lazily yields the
    entries that `extract_entries` pulls out of every response.

    Args:
        getter: Callable invoked as `getter(**kwargs, next_page=NextPage(...))`.
            Its result must expose a `next_page` attribute that is either
            `None` or an object with a `next_page_token` attribute.
        extract_entries: Turns a single `getter` result into a list of entries.
        page_size: Maximum number of entries requested per call.
        limit: Maximum total number of entries to yield; `None` means no cap.
        **kwargs: Passed through unchanged to every `getter` call.

    Yields:
        Entries in the order they are returned by consecutive pages.

    Raises:
        ValueError: If `page_size` is not positive or `limit` is negative.
            Because this is a generator, the error surfaces on first iteration.
    """
    if page_size < 1:
        raise ValueError(f"page_size must be a positive integer, got {page_size}")
    if limit is not None and limit < 0:
        raise ValueError(f"limit must be a non-negative integer or None, got {limit}")

    # Number of entries we are still allowed to yield; None means "unbounded".
    remaining = limit
    next_page_token = None

    # A single loop handles the first and all following pages, so every request
    # (including the first one) never asks for more than is still needed, and
    # `limit=0` finishes without performing any request at all.
    while remaining is None or remaining > 0:
        to_fetch = page_size if remaining is None else min(page_size, remaining)

        data = getter(**kwargs, next_page=NextPage(limit=to_fetch, next_page_token=next_page_token))
        results = extract_entries(data)

        # Only the overall limit truncates the output: a page that happens to
        # carry more entries than requested is not silently dropped when the
        # caller did not ask for a cap.
        yield from itertools.islice(results, remaining)
        if remaining is not None:
            remaining -= min(len(results), remaining)

        if data.next_page is None or data.next_page.next_page_token is None:
            break
        next_page_token = data.next_page.next_page_token
def test__limit():
    """
    `limit` caps the total number of yielded entries and shrinks the last request.

    With `page_size=3` and `limit=5` the second request must ask for only the
    two entries that are still missing, and the third prepared response must
    never be requested.
    """
    # given
    responses = [
        Mock(next_page=NextPage(next_page_token="aa", limit=None)),
        Mock(next_page=NextPage(next_page_token="bb", limit=None)),
        Mock(next_page=None),
    ]
    getter = Mock(side_effect=responses)
    expected_calls = [
        call(a=1, b=2, next_page=NextPage(next_page_token=None, limit=3)),
        call(a=1, b=2, next_page=NextPage(next_page_token="aa", limit=2)),
    ]

    # when
    paginator = paginate_over(
        getter=getter,
        extract_entries=extract_entries,
        page_size=3,
        limit=5,
        a=1,
        b=2,
    )
    entries = list(paginator)

    # then
    assert entries == [1, 2, 3, 1, 2]

    assert getter.call_count == 2
    assert getter.call_args_list == expected_calls