Skip to content

Commit

Permalink
Merge pull request #1728 from neptune-ai/rj/protobuffers
Browse files Browse the repository at this point in the history
Added support for fetching data with protocol buffers
  • Loading branch information
Raalsky authored Apr 9, 2024
2 parents f9da5f2 + 64ad22a commit 9505ead
Show file tree
Hide file tree
Showing 22 changed files with 1,458 additions and 69 deletions.
5 changes: 3 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,12 @@ repos:
- id: mypy
args: [ --config-file, pyproject.toml ]
pass_filenames: false
additional_dependencies: [ types-click ]
additional_dependencies: [ types-click, mypy-protobuf ]
exclude: |
(?x)(
^tests/unit/data/|
^.github/license_header\.txt
^.github/license_header\.txt|
^src/neptune/api/proto/
)
default_language_version:
python: python3
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
- Stop sending `X-Neptune-LegacyClient` header ([#1715](https://github.com/neptune-ai/neptune-client/pull/1715))
- Use `tqdm.auto` ([#1717](https://github.com/neptune-ai/neptune-client/pull/1717))
- Fields DTO conversion reworked ([#1722](https://github.com/neptune-ai/neptune-client/pull/1722))
- Added support for Protocol Buffers ([#1728](https://github.com/neptune-ai/neptune-client/pull/1728))

### Features
- ?
Expand Down
2 changes: 2 additions & 0 deletions dev_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@ pytest-xdist
vega_datasets
backoff
altair
icecream
bokeh
matplotlib
seaborn
plotly
tensorflow
torch
typing_extensions>=4.6.0
grpcio-tools

# e2e
scikit-learn
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ requests-oauthlib = ">=1.0.0"
websocket-client = ">=0.35.0, !=1.0.0"
urllib3 = "*"
swagger-spec-validator = ">=2.7.4"
protobuf = "^4.0.0"

# Built-in integrations
boto3 = ">=1.28.0"
Expand Down Expand Up @@ -127,7 +128,7 @@ neptune = "neptune.cli.__main__:main"
[tool.black]
line-length = 120
target-version = ['py37', 'py38', 'py39', 'py310', 'py311', 'py312']
include = '\.pyi?$'
# `include` is a single regular expression, not a comma-separated list: a `,`
# placed after `$` can never match, so the pattern would select no files during
# recursive discovery. Generated `*_pb2.py` / `*_pb2.pyi` files already end in
# `.py` / `.pyi`, so the plain suffix pattern covers them.
include = '\.pyi?$'
exclude = '''
/(
\.git
Expand Down
128 changes: 126 additions & 2 deletions src/neptune/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,13 @@
)

import abc
import re
from dataclasses import dataclass
from dataclasses import field as dataclass_field
from datetime import datetime
from datetime import (
datetime,
timezone,
)
from enum import Enum
from typing import (
Any,
Expand All @@ -58,6 +62,19 @@
TypeVar,
)

from neptune.api.proto.neptune_pb.api.model.attributes_pb2 import ProtoAttributeDefinitionDTO
from neptune.api.proto.neptune_pb.api.model.leaderboard_entries_pb2 import (
ProtoAttributeDTO,
ProtoAttributesDTO,
ProtoBoolAttributeDTO,
ProtoDatetimeAttributeDTO,
ProtoFloatAttributeDTO,
ProtoFloatSeriesAttributeDTO,
ProtoIntAttributeDTO,
ProtoLeaderboardEntriesSearchResultDTO,
ProtoStringAttributeDTO,
ProtoStringSetAttributeDTO,
)
from neptune.internal.utils.iso_dates import parse_iso_date
from neptune.internal.utils.run_state import RunState

Expand Down Expand Up @@ -120,7 +137,19 @@ def from_dict(data: Dict[str, Any]) -> Field:
@staticmethod
def from_model(model: Any) -> Field:
field_type = str(model.type)
return Field._registry[field_type].from_model(model.__getattribute__(f"{field_type}Properties"))
return Field._registry[field_type].from_model(model.__getattr__(f"{field_type}Properties"))

@staticmethod
def from_proto(data: Any) -> Field:
    """Build the concrete field object for a protobuf attribute message.

    ``str(data.type)`` is used as the key into ``Field._registry`` to pick the
    implementing class. The type-specific payload is read from the attribute
    of ``data`` named ``<snake_case(type)>_properties`` and passed to that
    class's own ``from_proto``.

    Args:
        data: Message exposing ``type`` and the matching ``*_properties``
            attribute.

    Returns:
        Whatever the selected class's ``from_proto`` returns.
    """
    field_type = str(data.type)
    properties_attr = f"{camel_to_snake(field_type)}_properties"
    # The getattr() builtin is the supported way to do a dynamic attribute
    # lookup; calling the __getattribute__ dunder by hand is not.
    return Field._registry[field_type].from_proto(getattr(data, properties_attr))


def camel_to_snake(name: str) -> str:
    """Convert a camelCase / PascalCase identifier to snake_case.

    Two substitution passes insert underscores at word boundaries, then the
    whole result is lower-cased. Names without upper-case letters are
    returned unchanged.
    """
    separator = r"\1_\2"
    # Pass 1: split before a capitalised word (an upper-case letter followed by
    # lower-case letters) that has any character in front of it.
    words_split = re.sub("(.)([A-Z][a-z]+)", separator, name)
    # Pass 2: split the remaining lower-case-or-digit -> upper-case transitions.
    return re.sub("([a-z0-9])([A-Z])", separator, words_split).lower()


class FieldVisitor(Generic[Ret], abc.ABC):
Expand Down Expand Up @@ -189,6 +218,10 @@ def from_dict(data: Dict[str, Any]) -> FloatField:
def from_model(model: Any) -> FloatField:
return FloatField(path=model.attributeName, value=model.value)

@staticmethod
def from_proto(data: ProtoFloatAttributeDTO) -> FloatField:
    """Create a ``FloatField`` from a protobuf float attribute message.

    The message's ``attribute_name`` becomes the field path and its ``value``
    is carried over unchanged.
    """
    path, value = data.attribute_name, data.value
    return FloatField(path=path, value=value)


@dataclass
class IntField(Field, field_type=FieldType.INT):
Expand All @@ -205,6 +238,10 @@ def from_dict(data: Dict[str, Any]) -> IntField:
def from_model(model: Any) -> IntField:
return IntField(path=model.attributeName, value=model.value)

@staticmethod
def from_proto(data: ProtoIntAttributeDTO) -> IntField:
    """Create an ``IntField`` from a protobuf int attribute message.

    ``attribute_name`` is used as the path; ``value`` is copied as-is.
    """
    path, value = data.attribute_name, data.value
    return IntField(path=path, value=value)


@dataclass
class BoolField(Field, field_type=FieldType.BOOL):
Expand All @@ -221,6 +258,10 @@ def from_dict(data: Dict[str, Any]) -> BoolField:
def from_model(model: Any) -> BoolField:
return BoolField(path=model.attributeName, value=model.value)

@staticmethod
def from_proto(data: ProtoBoolAttributeDTO) -> BoolField:
    """Create a ``BoolField`` from a protobuf bool attribute message.

    ``attribute_name`` is used as the path; ``value`` is copied as-is.
    """
    path, value = data.attribute_name, data.value
    return BoolField(path=path, value=value)


@dataclass
class StringField(Field, field_type=FieldType.STRING):
Expand All @@ -237,6 +278,10 @@ def from_dict(data: Dict[str, Any]) -> StringField:
def from_model(model: Any) -> StringField:
return StringField(path=model.attributeName, value=model.value)

@staticmethod
def from_proto(data: ProtoStringAttributeDTO) -> StringField:
    """Create a ``StringField`` from a protobuf string attribute message.

    ``attribute_name`` is used as the path; ``value`` is copied as-is.
    """
    path, value = data.attribute_name, data.value
    return StringField(path=path, value=value)


@dataclass
class DateTimeField(Field, field_type=FieldType.DATETIME):
Expand All @@ -253,6 +298,12 @@ def from_dict(data: Dict[str, Any]) -> DateTimeField:
def from_model(model: Any) -> DateTimeField:
return DateTimeField(path=model.attributeName, value=parse_iso_date(model.value))

@staticmethod
def from_proto(data: ProtoDatetimeAttributeDTO) -> DateTimeField:
    """Create a ``DateTimeField`` from a protobuf datetime attribute message.

    ``data.value`` is divided by 1000 and interpreted as a POSIX timestamp,
    producing a timezone-aware UTC ``datetime``.
    NOTE(review): this implies ``value`` is epoch milliseconds - confirm
    against the proto schema.
    """
    path = data.attribute_name
    epoch_seconds = data.value / 1000.0
    moment = datetime.fromtimestamp(epoch_seconds, tz=timezone.utc)
    return DateTimeField(path=path, value=moment)


@dataclass
class FileField(Field, field_type=FieldType.FILE):
Expand All @@ -271,6 +322,10 @@ def from_dict(data: Dict[str, Any]) -> FileField:
def from_model(model: Any) -> FileField:
return FileField(path=model.attributeName, name=model.name, ext=model.ext, size=model.size)

@staticmethod
def from_proto(data: Any) -> FileField:
raise NotImplementedError()


@dataclass
class FileSetField(Field, field_type=FieldType.FILE_SET):
Expand All @@ -287,6 +342,10 @@ def from_dict(data: Dict[str, Any]) -> FileSetField:
def from_model(model: Any) -> FileSetField:
return FileSetField(path=model.attributeName, size=model.size)

@staticmethod
def from_proto(data: Any) -> FileSetField:
raise NotImplementedError()


@dataclass
class FloatSeriesField(Field, field_type=FieldType.FLOAT_SERIES):
Expand All @@ -304,6 +363,11 @@ def from_dict(data: Dict[str, Any]) -> FloatSeriesField:
def from_model(model: Any) -> FloatSeriesField:
return FloatSeriesField(path=model.attributeName, last=model.last)

@staticmethod
def from_proto(data: ProtoFloatSeriesAttributeDTO) -> FloatSeriesField:
    """Create a ``FloatSeriesField`` from a protobuf float-series message.

    ``last`` is taken from the message only when ``data.HasField("last")``
    reports it as present; otherwise the field is created with ``last=None``.
    """
    if data.HasField("last"):
        last_value = data.last
    else:
        last_value = None
    return FloatSeriesField(path=data.attribute_name, last=last_value)


@dataclass
class StringSeriesField(Field, field_type=FieldType.STRING_SERIES):
Expand All @@ -321,6 +385,10 @@ def from_dict(data: Dict[str, Any]) -> StringSeriesField:
def from_model(model: Any) -> StringSeriesField:
return StringSeriesField(path=model.attributeName, last=model.last)

@staticmethod
def from_proto(data: Any) -> StringSeriesField:
raise NotImplementedError()


@dataclass
class ImageSeriesField(Field, field_type=FieldType.IMAGE_SERIES):
Expand All @@ -338,6 +406,10 @@ def from_dict(data: Dict[str, Any]) -> ImageSeriesField:
def from_model(model: Any) -> ImageSeriesField:
return ImageSeriesField(path=model.attributeName, last_step=model.lastStep)

@staticmethod
def from_proto(data: Any) -> ImageSeriesField:
raise NotImplementedError()


@dataclass
class StringSetField(Field, field_type=FieldType.STRING_SET):
Expand All @@ -354,6 +426,10 @@ def from_dict(data: Dict[str, Any]) -> StringSetField:
def from_model(model: Any) -> StringSetField:
return StringSetField(path=model.attributeName, values=set(model.values))

@staticmethod
def from_proto(data: ProtoStringSetAttributeDTO) -> StringSetField:
    """Create a ``StringSetField`` from a protobuf string-set message.

    The items of ``data.value`` are collected into a Python ``set``, so
    duplicates collapse and ordering is not retained.
    """
    path = data.attribute_name
    unique_values = set(data.value)
    return StringSetField(path=path, values=unique_values)


@dataclass
class GitCommit:
Expand All @@ -368,6 +444,10 @@ def from_dict(data: Dict[str, Any]) -> GitCommit:
def from_model(model: Any) -> GitCommit:
return GitCommit(commit_id=model.commitId)

@staticmethod
def from_proto(data: Any) -> GitCommit:
raise NotImplementedError()


@dataclass
class GitRefField(Field, field_type=FieldType.GIT_REF):
Expand All @@ -386,6 +466,10 @@ def from_model(model: Any) -> GitRefField:
commit = GitCommit.from_model(model.commit) if model.commit is not None else None
return GitRefField(path=model.attributeName, commit=commit)

@staticmethod
def from_proto(data: ProtoAttributeDTO) -> GitRefField:
raise NotImplementedError()


@dataclass
class ObjectStateField(Field, field_type=FieldType.OBJECT_STATE):
Expand All @@ -404,6 +488,10 @@ def from_model(model: Any) -> ObjectStateField:
value = RunState.from_api(str(model.value)).value
return ObjectStateField(path=model.attributeName, value=value)

@staticmethod
def from_proto(data: Any) -> ObjectStateField:
raise NotImplementedError()


@dataclass
class NotebookRefField(Field, field_type=FieldType.NOTEBOOK_REF):
Expand All @@ -421,6 +509,10 @@ def from_dict(data: Dict[str, Any]) -> NotebookRefField:
def from_model(model: Any) -> NotebookRefField:
return NotebookRefField(path=model.attributeName, notebook_name=model.notebookName)

@staticmethod
def from_proto(data: Any) -> NotebookRefField:
raise NotImplementedError()


@dataclass
class ArtifactField(Field, field_type=FieldType.ARTIFACT):
Expand All @@ -437,6 +529,10 @@ def from_dict(data: Dict[str, Any]) -> ArtifactField:
def from_model(model: Any) -> ArtifactField:
return ArtifactField(path=model.attributeName, hash=model.hash)

@staticmethod
def from_proto(data: Any) -> ArtifactField:
raise NotImplementedError()


@dataclass
class LeaderboardEntry:
Expand All @@ -455,6 +551,23 @@ def from_model(model: Any) -> LeaderboardEntry:
object_id=model.experimentId, fields=[Field.from_model(field) for field in model.attributes]
)

@staticmethod
def from_proto(data: ProtoAttributesDTO) -> LeaderboardEntry:
    """Create a ``LeaderboardEntry`` from a protobuf attributes message.

    Only attributes whose ``str(type)`` is one of the field types listed below
    are converted via ``Field.from_proto``; attributes of any other type are
    left out of the resulting entry.
    """
    convertible_types = {
        member.value
        for member in (
            FieldType.STRING,
            FieldType.BOOL,
            FieldType.INT,
            FieldType.FLOAT,
            FieldType.DATETIME,
            FieldType.STRING_SET,
            FieldType.FLOAT_SERIES,
        )
    }

    object_id = data.experiment_id
    converted = []
    for attribute in data.attributes:
        if str(attribute.type) not in convertible_types:
            continue
        converted.append(Field.from_proto(attribute))
    return LeaderboardEntry(object_id=object_id, fields=converted)


@dataclass
class LeaderboardEntriesSearchResult:
Expand All @@ -475,6 +588,13 @@ def from_model(result: Any) -> LeaderboardEntriesSearchResult:
matching_item_count=result.matchingItemCount,
)

@staticmethod
def from_proto(data: ProtoLeaderboardEntriesSearchResultDTO) -> LeaderboardEntriesSearchResult:
    """Create a search result from its protobuf message.

    Every item of ``data.entries`` is converted with
    ``LeaderboardEntry.from_proto``; ``matching_item_count`` is copied as-is.
    """
    converted_entries = list(map(LeaderboardEntry.from_proto, data.entries))
    return LeaderboardEntriesSearchResult(
        entries=converted_entries,
        matching_item_count=data.matching_item_count,
    )


@dataclass
class FieldDefinition:
Expand All @@ -488,3 +608,7 @@ def from_dict(data: Dict[str, Any]) -> FieldDefinition:
@staticmethod
def from_model(model: Any) -> FieldDefinition:
    """Create a ``FieldDefinition`` from an API model object.

    ``model.name`` becomes the path and ``model.type`` is converted to a
    ``FieldType`` member.
    """
    path = model.name
    field_type = FieldType(model.type)
    return FieldDefinition(path=path, type=field_type)

@staticmethod
def from_proto(data: ProtoAttributeDefinitionDTO) -> FieldDefinition:
    """Create a ``FieldDefinition`` from a protobuf attribute-definition message.

    ``data.name`` becomes the path and ``data.type`` is converted to a
    ``FieldType`` member.
    """
    path = data.name
    field_type = FieldType(data.type)
    return FieldDefinition(path=path, type=field_type)
15 changes: 15 additions & 0 deletions src/neptune/api/proto/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#
# Copyright (c) 2024, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
15 changes: 15 additions & 0 deletions src/neptune/api/proto/neptune_pb/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#
# Copyright (c) 2024, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
15 changes: 15 additions & 0 deletions src/neptune/api/proto/neptune_pb/api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#
# Copyright (c) 2024, Neptune Labs Sp. z o.o.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
Loading

0 comments on commit 9505ead

Please sign in to comment.