Skip to content

Commit

Permalink
Merge pull request #1856 from neptune-ai/rj/lineage
Browse files Browse the repository at this point in the history
Added support for `lineage` parameter when fetching float series values
  • Loading branch information
Raalsky authored Sep 2, 2024
2 parents adb90ff + 975556d commit d57023d
Show file tree
Hide file tree
Showing 9 changed files with 39 additions and 70 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
- Series values DTO conversion reworked with protocol buffer support ([#1738](https://github.com/neptune-ai/neptune-client/pull/1738))
- Series values fetching reworked with protocol buffer support ([#1744](https://github.com/neptune-ai/neptune-client/pull/1744))
- Added support for enhanced field definitions querying ([#1751](https://github.com/neptune-ai/neptune-client/pull/1751))
- Added support for `NQL` `MATCHES` operator ([#1863](https://github.com/neptune-ai/neptune-client/pull/1863))

### Fixes
- Fixed `tqdm.notebook` import only in Notebook environment ([#1716](https://github.com/neptune-ai/neptune-client/pull/1716))
Expand Down
15 changes: 12 additions & 3 deletions src/neptune/attributes/series/fetchable_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import abc
from datetime import datetime
from functools import partial
from typing import (
Dict,
Generic,
Expand Down Expand Up @@ -50,14 +51,22 @@ def make_row(entry: Row, include_timestamp: bool = True) -> Dict[str, Union[str,

class FetchableSeries(Generic[Row]):
@abc.abstractmethod
def _fetch_values_from_backend(self, limit: int, from_step: Optional[float] = None) -> Row: ...
def _fetch_values_from_backend(
self, limit: int, from_step: Optional[float] = None, include_inherited: bool = True
) -> Row: ...

def fetch_values(self, *, include_timestamp: bool = True, progress_bar: Optional[ProgressBarType] = None):
def fetch_values(
self,
*,
include_timestamp: bool = True,
progress_bar: Optional[ProgressBarType] = None,
include_inherited: bool = True,
):
import pandas as pd

path = path_to_str(self._path) if hasattr(self, "_path") else ""
data = fetch_series_values(
getter=self._fetch_values_from_backend,
getter=partial(self._fetch_values_from_backend, include_inherited=include_inherited),
path=path,
progress_bar=progress_bar,
)
Expand Down
5 changes: 4 additions & 1 deletion src/neptune/attributes/series/float_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,14 @@ def fetch_last(self) -> float:
val = self._backend.get_float_series_attribute(self._container_id, self._container_type, self._path)
return val.last

def _fetch_values_from_backend(self, limit: int, from_step: Optional[float] = None) -> FloatSeriesValues:
def _fetch_values_from_backend(
self, limit: int, from_step: Optional[float] = None, include_inherited: bool = True
) -> FloatSeriesValues:
return self._backend.get_float_series_values(
container_id=self._container_id,
container_type=self._container_type,
path=self._path,
from_step=from_step,
limit=limit,
include_inherited=include_inherited,
)
4 changes: 3 additions & 1 deletion src/neptune/attributes/series/string_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ def fetch_last(self) -> str:
val = self._backend.get_string_series_attribute(self._container_id, self._container_type, self._path)
return val.last

def _fetch_values_from_backend(self, limit: int, from_step: Optional[float] = None) -> StringSeriesValues:
def _fetch_values_from_backend(
self, limit: int, from_step: Optional[float] = None, include_inherited: bool = True
) -> StringSeriesValues:
return self._backend.get_string_series_values(
container_id=self._container_id,
container_type=self._container_type,
Expand Down
47 changes: 15 additions & 32 deletions src/neptune/internal/backends/hosted_neptune_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@
LeaderboardEntry,
NextPage,
QueryFieldDefinitionsResult,
QueryFieldsResult,
StringField,
StringSeriesField,
StringSeriesValues,
Expand Down Expand Up @@ -994,6 +993,7 @@ def get_float_series_values(
limit: int,
from_step: Optional[float] = None,
use_proto: Optional[bool] = None,
include_inherited: bool = True,
) -> FloatSeriesValues:
use_proto = use_proto if use_proto is not None else self.use_proto

Expand All @@ -1003,6 +1003,10 @@ def get_float_series_values(
"limit": limit,
"skipToStep": from_step,
}

if not include_inherited:
params["lineage"] = "NONE"

try:
if use_proto:
result = (
Expand All @@ -1028,31 +1032,6 @@ def get_float_series_values(
except HTTPNotFound:
raise FetchAttributeNotFoundException(path_to_str(path))

@with_api_exceptions_handler
def query_fields_within_project(
self,
project_id: QualifiedName,
field_names_filter: Optional[List[str]] = None,
experiment_ids_filter: Optional[List[str]] = None,
next_page: Optional[NextPage] = None,
) -> QueryFieldsResult:
pagination = {"nextPage": next_page.to_dto()} if next_page else {}
params = {
"projectIdentifier": project_id,
"query": {
**pagination,
"attributeNamesFilter": field_names_filter,
"experimentIdsFilter": experiment_ids_filter,
},
**DEFAULT_REQUEST_KWARGS,
}

try:
result = self.leaderboard_client.api.queryAttributesWithinProject(**params).response().result
return QueryFieldsResult.from_model(result)
except HTTPNotFound:
raise ProjectNotFound(project_id=project_id)

@with_api_exceptions_handler
def fetch_atom_attribute_values(
self, container_id: str, container_type: ContainerType, path: List[str]
Expand Down Expand Up @@ -1087,16 +1066,18 @@ def _get_file_set_download_request(self, container_id: str, container_type: Cont
raise FetchAttributeNotFoundException(path_to_str(path))

@with_api_exceptions_handler
def _get_column_types(self, project_id: UniqueId, column: str, types: Optional[Iterable[str]] = None) -> List[Any]:
def _get_column_types(self, project_id: UniqueId, column: str) -> List[Any]:
params = {
"projectIdentifier": project_id,
"search": column,
"type": types,
"params": {},
"query": {
"attributeNameFilter": {"mustMatchRegexes": [column]},
},
**DEFAULT_REQUEST_KWARGS,
}
try:
return self.leaderboard_client.api.searchLeaderboardAttributes(**params).response().result.entries
return (
self.leaderboard_client.api.queryAttributeDefinitionsWithinProject(**params).response().result.entries
)
except HTTPNotFound as e:
raise ProjectNotFound(project_id=project_id) from e

Expand All @@ -1119,6 +1100,8 @@ def search_leaderboard_entries(

step_size = min(default_step_size, limit) if limit else default_step_size

columns = set(columns) | {sort_by} if columns else {sort_by}

types_filter = list(map(lambda container_type: container_type.to_api(), types)) if types else None
attributes_filter = {"attributeFilters": [{"path": column} for column in columns]} if columns else {}

Expand All @@ -1127,7 +1110,7 @@ def search_leaderboard_entries(
elif sort_by == "sys/id":
sort_by_column_type = FieldType.STRING.value
else:
sort_by_column_type_candidates = self._get_column_types(project_id, sort_by, types_filter)
sort_by_column_type_candidates = self._get_column_types(project_id, sort_by)
sort_by_column_type = _get_column_type_from_entries(sort_by_column_type_candidates, sort_by)

try:
Expand Down
11 changes: 1 addition & 10 deletions src/neptune/internal/backends/neptune_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
LeaderboardEntry,
NextPage,
QueryFieldDefinitionsResult,
QueryFieldsResult,
StringField,
StringSeriesField,
StringSeriesValues,
Expand Down Expand Up @@ -273,6 +272,7 @@ def get_float_series_values(
limit: int,
from_step: Optional[float] = None,
use_proto: Optional[bool] = None,
include_inherited: bool = True,
) -> FloatSeriesValues: ...

@abc.abstractmethod
Expand Down Expand Up @@ -346,12 +346,3 @@ def query_fields_definitions_within_project(
experiment_ids_filter: Optional[List[str]] = None,
next_page: Optional[NextPage] = None,
) -> QueryFieldDefinitionsResult: ...

@abc.abstractmethod
def query_fields_within_project(
self,
project_id: QualifiedName,
field_names_filter: Optional[List[str]] = None,
experiment_ids_filter: Optional[List[str]] = None,
next_page: Optional[NextPage] = None,
) -> QueryFieldsResult: ...
14 changes: 1 addition & 13 deletions src/neptune/internal/backends/neptune_backend_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
LeaderboardEntry,
NextPage,
QueryFieldDefinitionsResult,
QueryFieldsResult,
StringField,
StringPointValue,
StringSeriesField,
Expand Down Expand Up @@ -472,6 +471,7 @@ def get_float_series_values(
limit: int,
from_step: Optional[float] = None,
use_proto: Optional[bool] = None,
include_inherited: bool = True,
) -> FloatSeriesValues:
val = self._get_attribute(container_id, container_type, path, FloatSeries)
return FloatSeriesValues(
Expand Down Expand Up @@ -818,15 +818,3 @@ def query_fields_definitions_within_project(
entries=[],
next_page=NextPage(next_page_token=None, limit=0),
)

def query_fields_within_project(
self,
project_id: QualifiedName,
field_names_filter: Optional[List[str]] = None,
experiment_ids_filter: Optional[List[str]] = None,
next_page: Optional[NextPage] = None,
) -> QueryFieldsResult:
return QueryFieldsResult(
entries=[],
next_page=NextPage(next_page_token=None, limit=0),
)
1 change: 1 addition & 0 deletions src/neptune/internal/backends/nql.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class NQLAttributeOperator(str, Enum):
CONTAINS = "CONTAINS"
GREATER_THAN = ">"
LESS_THAN = "<"
MATCHES = "MATCHES"


class NQLAttributeType(str, Enum):
Expand Down
11 changes: 1 addition & 10 deletions src/neptune/internal/backends/offline_neptune_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
LeaderboardEntry,
NextPage,
QueryFieldDefinitionsResult,
QueryFieldsResult,
StringField,
StringSeriesField,
StringSeriesValues,
Expand Down Expand Up @@ -123,6 +122,7 @@ def get_float_series_values(
limit: int,
from_step: Optional[float] = None,
use_proto: Optional[bool] = None,
include_inherited: bool = True,
) -> FloatSeriesValues:
raise NeptuneOfflineModeFetchException

Expand Down Expand Up @@ -185,12 +185,3 @@ def query_fields_definitions_within_project(
next_page: Optional[NextPage] = None,
) -> QueryFieldDefinitionsResult:
raise NeptuneOfflineModeFetchException

def query_fields_within_project(
self,
project_id: QualifiedName,
field_names_filter: Optional[List[str]] = None,
experiment_ids_filter: Optional[List[str]] = None,
next_page: Optional[NextPage] = None,
) -> QueryFieldsResult:
raise NeptuneOfflineModeFetchException

0 comments on commit d57023d

Please sign in to comment.