Skip to content

Commit

Permalink
Endpoints added for enchanced fields search (#1751)
Browse files Browse the repository at this point in the history
  • Loading branch information
Raalsky authored Apr 19, 2024
1 parent 352b1ad commit 24a49f2
Show file tree
Hide file tree
Showing 9 changed files with 779 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
- Added support for Protocol Buffers ([#1728](https://github.com/neptune-ai/neptune-client/pull/1728))
- Series values DTO conversion reworked with protocol buffer support ([#1738](https://github.com/neptune-ai/neptune-client/pull/1738))
- Series values fetching reworked with protocol buffer support ([#1744](https://github.com/neptune-ai/neptune-client/pull/1744))
- Added support for enhanced field definitions querying ([#1751](https://github.com/neptune-ai/neptune-client/pull/1751))

### Fixes
- Fixed `tqdm.notebook` import only in Notebook environment ([#1716](https://github.com/neptune-ai/neptune-client/pull/1716))
Expand Down
102 changes: 102 additions & 0 deletions src/neptune/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@
"StringSeriesValues",
"StringPointValue",
"ImageSeriesValues",
"QueryFieldDefinitionsResult",
"NextPage",
"QueryFieldsResult",
)

import abc
Expand Down Expand Up @@ -601,6 +604,105 @@ def from_proto(data: ProtoLeaderboardEntriesSearchResultDTO) -> LeaderboardEntri
)


@dataclass
class NextPage:
limit: Optional[int]
next_page_token: Optional[str]

@staticmethod
def from_dict(data: Dict[str, Any]) -> NextPage:
return NextPage(limit=data.get("limit"), next_page_token=data.get("nextPageToken"))

@staticmethod
def from_model(model: Any) -> NextPage:
return NextPage(limit=model.limit, next_page_token=model.nextPageToken)

@staticmethod
def from_proto(data: Any) -> NextPage:
raise NotImplementedError()

def to_dto(self) -> Dict[str, Any]:
return {
"limit": self.limit,
"nextPageToken": self.next_page_token,
}


@dataclass
class QueryFieldsExperimentResult:
object_id: str
object_key: str
fields: List[Field]

@staticmethod
def from_dict(data: Dict[str, Any]) -> QueryFieldsExperimentResult:
return QueryFieldsExperimentResult(
object_id=data["experimentId"],
object_key=data["experimentShortId"],
fields=[Field.from_dict(field) for field in data["attributes"]],
)

@staticmethod
def from_model(model: Any) -> QueryFieldsExperimentResult:
return QueryFieldsExperimentResult(
object_id=model.experimentId,
object_key=model.experimentShortId,
fields=[Field.from_model(field) for field in model.attributes],
)

@staticmethod
def from_proto(data: Any) -> QueryFieldsExperimentResult:
raise NotImplementedError()


@dataclass
class QueryFieldsResult:
entries: List[QueryFieldsExperimentResult]
next_page: NextPage

@staticmethod
def from_dict(data: Dict[str, Any]) -> QueryFieldsResult:
return QueryFieldsResult(
entries=[QueryFieldsExperimentResult.from_dict(entry) for entry in data["entries"]],
next_page=NextPage.from_dict(data["nextPage"]),
)

@staticmethod
def from_model(model: Any) -> QueryFieldsResult:
return QueryFieldsResult(
entries=[QueryFieldsExperimentResult.from_model(entry) for entry in model.entries],
next_page=NextPage.from_model(model.nextPage),
)

@staticmethod
def from_proto(data: Any) -> QueryFieldsResult:
raise NotImplementedError()


@dataclass
class QueryFieldDefinitionsResult:
entries: List[FieldDefinition]
next_page: NextPage

@staticmethod
def from_dict(data: Dict[str, Any]) -> QueryFieldDefinitionsResult:
return QueryFieldDefinitionsResult(
entries=[FieldDefinition.from_dict(entry) for entry in data["entries"]],
next_page=NextPage.from_dict(data["nextPage"]),
)

@staticmethod
def from_model(model: Any) -> QueryFieldDefinitionsResult:
return QueryFieldDefinitionsResult(
entries=[FieldDefinition.from_model(entry) for entry in model.entries],
next_page=NextPage.from_model(model.nextPage),
)

@staticmethod
def from_proto(data: Any) -> QueryFieldDefinitionsResult:
raise NotImplementedError()


@dataclass
class FieldDefinition:
path: str
Expand Down
60 changes: 60 additions & 0 deletions src/neptune/api/pagination.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#
# Copyright (c) 2024, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ("paginate_over",)

import abc
from dataclasses import dataclass
from typing import (
Any,
Callable,
Iterable,
Iterator,
Optional,
TypeVar,
)

from typing_extensions import Protocol

from neptune.api.models import NextPage


@dataclass
class WithPagination(abc.ABC):
next_page: Optional[NextPage]


T = TypeVar("T", bound=WithPagination)
Entry = TypeVar("Entry")


class Paginatable(Protocol):
def __call__(self, *, next_page: Optional[NextPage] = None, **kwargs: Any) -> Any: ...


def paginate_over(
getter: Paginatable,
extract_entries: Callable[[T], Iterable[Entry]],
**kwargs: Any,
) -> Iterator[Entry]:
"""
Generic approach to pagination via `NextPage`
"""
data = getter(**kwargs, next_page=None)
yield from extract_entries(data)

while data.next_page is not None and data.next_page.next_page_token is not None:
data = getter(**kwargs, next_page=data.next_page)
yield from extract_entries(data)
58 changes: 58 additions & 0 deletions src/neptune/internal/backends/hosted_neptune_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@
ImageSeriesValues,
IntField,
LeaderboardEntry,
NextPage,
QueryFieldDefinitionsResult,
QueryFieldsResult,
StringField,
StringSeriesField,
StringSeriesValues,
Expand Down Expand Up @@ -1025,6 +1028,31 @@ def get_float_series_values(
except HTTPNotFound:
raise FetchAttributeNotFoundException(path_to_str(path))

@with_api_exceptions_handler
def query_fields_within_project(
self,
project_id: QualifiedName,
field_names_filter: Optional[List[str]] = None,
experiment_ids_filter: Optional[List[str]] = None,
next_page: Optional[NextPage] = None,
) -> QueryFieldsResult:
pagination = {"nextPage": next_page.to_dto()} if next_page else {}
params = {
"projectIdentifier": project_id,
"query": {
**pagination,
"attributeNamesFilter": field_names_filter,
"experimentIdsFilter": experiment_ids_filter,
},
**DEFAULT_REQUEST_KWARGS,
}

try:
result = self.leaderboard_client.api.queryAttributesWithinProject(**params).response().result
return QueryFieldsResult.from_model(result)
except HTTPNotFound:
raise ProjectNotFound(project_id=project_id)

@with_api_exceptions_handler
def fetch_atom_attribute_values(
self, container_id: str, container_type: ContainerType, path: List[str]
Expand Down Expand Up @@ -1143,6 +1171,36 @@ def get_model_version_url(
base_url = self.get_display_address()
return f"{base_url}/{workspace}/{project_name}/m/{model_id}/v/{sys_id}"

def query_fields_definitions_within_project(
self,
project_id: QualifiedName,
field_name_regex: Optional[str] = None,
experiment_ids_filter: Optional[List[str]] = None,
next_page: Optional[NextPage] = None,
) -> QueryFieldDefinitionsResult:
pagination = {"nextPage": next_page.to_dto()} if next_page else {}
params = {
"projectIdentifier": project_id,
"query": {
**pagination,
"experimentIdsFilter": experiment_ids_filter,
"attributeNameRegex": field_name_regex,
},
}

try:
data = (
self.leaderboard_client.api.queryAttributeDefinitionsWithinProject(
**params,
**DEFAULT_REQUEST_KWARGS,
)
.response()
.result
)
return QueryFieldDefinitionsResult.from_model(data)
except HTTPNotFound:
raise ProjectNotFound(project_id=project_id)

def get_fields_definitions(
self,
container_id: str,
Expand Down
21 changes: 21 additions & 0 deletions src/neptune/internal/backends/neptune_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
ImageSeriesValues,
IntField,
LeaderboardEntry,
NextPage,
QueryFieldDefinitionsResult,
QueryFieldsResult,
StringField,
StringSeriesField,
StringSeriesValues,
Expand Down Expand Up @@ -334,3 +337,21 @@ def search_leaderboard_entries(
@abc.abstractmethod
def list_fileset_files(self, attribute: List[str], container_id: str, path: str) -> List[FileEntry]:
pass

@abc.abstractmethod
def query_fields_definitions_within_project(
self,
project_id: QualifiedName,
field_name_regex: Optional[str] = None,
experiment_ids_filter: Optional[List[str]] = None,
next_page: Optional[NextPage] = None,
) -> QueryFieldDefinitionsResult: ...

@abc.abstractmethod
def query_fields_within_project(
self,
project_id: QualifiedName,
field_names_filter: Optional[List[str]] = None,
experiment_ids_filter: Optional[List[str]] = None,
next_page: Optional[NextPage] = None,
) -> QueryFieldsResult: ...
27 changes: 27 additions & 0 deletions src/neptune/internal/backends/neptune_backend_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@
ImageSeriesValues,
IntField,
LeaderboardEntry,
NextPage,
QueryFieldDefinitionsResult,
QueryFieldsResult,
StringField,
StringPointValue,
StringSeriesField,
Expand Down Expand Up @@ -803,3 +806,27 @@ def get_fields_with_paths_filter(
self, container_id: str, container_type: ContainerType, paths: List[str], use_proto: Optional[bool] = None
) -> List[Field]:
return []

def query_fields_definitions_within_project(
self,
project_id: QualifiedName,
field_name_regex: Optional[str] = None,
experiment_ids_filter: Optional[List[str]] = None,
next_page: Optional[NextPage] = None,
) -> QueryFieldDefinitionsResult:
return QueryFieldDefinitionsResult(
entries=[],
next_page=NextPage(next_page_token=None, limit=0),
)

def query_fields_within_project(
self,
project_id: QualifiedName,
field_names_filter: Optional[List[str]] = None,
experiment_ids_filter: Optional[List[str]] = None,
next_page: Optional[NextPage] = None,
) -> QueryFieldsResult:
return QueryFieldsResult(
entries=[],
next_page=NextPage(next_page_token=None, limit=0),
)
26 changes: 25 additions & 1 deletion src/neptune/internal/backends/offline_neptune_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
ImageSeriesValues,
IntField,
LeaderboardEntry,
NextPage,
QueryFieldDefinitionsResult,
QueryFieldsResult,
StringField,
StringSeriesField,
StringSeriesValues,
Expand All @@ -46,7 +49,10 @@
from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock
from neptune.internal.backends.nql import NQLQuery
from neptune.internal.container_type import ContainerType
from neptune.internal.id_formats import UniqueId
from neptune.internal.id_formats import (
QualifiedName,
UniqueId,
)
from neptune.typing import ProgressBarType


Expand Down Expand Up @@ -170,3 +176,21 @@ def search_leaderboard_entries(
use_proto: Optional[bool] = None,
) -> Generator[LeaderboardEntry, None, None]:
raise NeptuneOfflineModeFetchException

def query_fields_definitions_within_project(
self,
project_id: QualifiedName,
field_name_regex: Optional[str] = None,
experiment_ids_filter: Optional[List[str]] = None,
next_page: Optional[NextPage] = None,
) -> QueryFieldDefinitionsResult:
raise NeptuneOfflineModeFetchException

def query_fields_within_project(
self,
project_id: QualifiedName,
field_names_filter: Optional[List[str]] = None,
experiment_ids_filter: Optional[List[str]] = None,
next_page: Optional[NextPage] = None,
) -> QueryFieldsResult:
raise NeptuneOfflineModeFetchException
Loading

0 comments on commit 24a49f2

Please sign in to comment.