diff --git a/CHANGELOG.md b/CHANGELOG.md index aa2499457..df3a9eed2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ - Better value validation for `state` parameter of `fetch_*_table()` methods ([#1616](https://github.com/neptune-ai/neptune-client/pull/1616)) - Parse `datetime` attribute values in `fetch_runs_table()` ([#1634](https://github.com/neptune-ai/neptune-client/pull/1634)) - Better handle limit in `fetch_*_table()` methods ([#1644](https://github.com/neptune-ai/neptune-client/pull/1644)) +- Fix pagination handling in table fetching ([#1651](https://github.com/neptune-ai/neptune-client/pull/1651)) ### Changes - Use literals instead of str for Mode typing ([#1586](https://github.com/neptune-ai/neptune-client/pull/1586)) diff --git a/src/neptune/api/searching_entries.py b/src/neptune/api/searching_entries.py index 99c479fec..b728d0270 100644 --- a/src/neptune/api/searching_entries.py +++ b/src/neptune/api/searching_entries.py @@ -27,6 +27,10 @@ from bravado.client import construct_request # type: ignore from bravado.config import RequestConfig # type: ignore +from typing_extensions import ( + Literal, + TypeAlias, +) from neptune.internal.backends.api_model import ( AttributeType, @@ -54,6 +58,8 @@ SUPPORTED_ATTRIBUTE_TYPES = {item.value for item in AttributeType} +SORT_BY_COLUMN_TYPE: TypeAlias = Literal["string", "datetime", "integer", "boolean", "float"] + class NoLimit(int): def __gt__(self, other: Any) -> bool: @@ -83,17 +89,18 @@ def get_single_page( limit: int, offset: int, sort_by: Optional[str] = None, - sort_by_column_type: Optional[str] = None, + sort_by_column_type: Optional[SORT_BY_COLUMN_TYPE] = None, ascending: bool = False, types: Optional[Iterable[str]] = None, query: Optional["NQLQuery"] = None, searching_after: Optional[str] = None, ) -> Any: normalized_query = query or NQLEmptyQuery() + sort_by_column_type = sort_by_column_type if sort_by_column_type else AttributeType.STRING.value if sort_by and searching_after: sort_by_as_nql = NQLQueryAttribute( name=sort_by, - type=NQLAttributeType.STRING, + type=NQLAttributeType(sort_by_column_type), operator=NQLAttributeOperator.GREATER_THAN, value=searching_after, ) @@ -167,7 +174,7 @@ def iter_over_pages( limit: Optional[int] = None, sort_by: str = "sys/id", max_offset: int = MAX_SERVER_OFFSET, - sort_by_column_type: Optional[str] = None, + sort_by_column_type: Optional[SORT_BY_COLUMN_TYPE] = None, ascending: bool = False, progress_bar: Optional[ProgressBarType] = None, **kwargs: Any, @@ -223,8 +230,7 @@ def iter_over_pages( if offset == 0 and last_page is not None: total += result.get("matchingItemCount", 0) - if total > limit: - total = limit + total = min(total, limit) page = _entries_from_page(result) extracted_records += len(page) diff --git a/src/neptune/internal/backends/nql.py b/src/neptune/internal/backends/nql.py index 1435ecd11..f324a7097 100644 --- a/src/neptune/internal/backends/nql.py +++ b/src/neptune/internal/backends/nql.py @@ -66,6 +66,9 @@ class NQLAttributeType(str, Enum): STRING_SET = "stringSet" EXPERIMENT_STATE = "experimentState" BOOLEAN = "bool" + DATETIME = "datetime" + INTEGER = "integer" + FLOAT = "float" @dataclass