From 629df2ab566aa9e0c55626922ed863b5917a1600 Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Fri, 29 Mar 2024 09:03:51 +0100 Subject: [PATCH 1/8] Protobuffers added --- .pre-commit-config.yaml | 5 +- dev_requirements.txt | 2 + pyproject.toml | 3 +- src/neptune/api/proto/__init__.py | 15 + src/neptune/api/proto/neptune_pb/__init__.py | 15 + .../api/proto/neptune_pb/api/__init__.py | 15 + .../proto/neptune_pb/api/model/__init__.py | 15 + .../neptune_pb/api/model/attributes_pb2.py | 29 ++ .../neptune_pb/api/model/attributes_pb2.pyi | 51 +++ .../api/model/leaderboard_entries_pb2.py | 47 +++ .../api/model/leaderboard_entries_pb2.pyi | 345 ++++++++++++++++++ .../backends/hosted_neptune_backend.py | 52 +++ .../internal/backends/neptune_backend.py | 10 + .../internal/backends/neptune_backend_mock.py | 9 + 14 files changed, 610 insertions(+), 3 deletions(-) create mode 100644 src/neptune/api/proto/__init__.py create mode 100644 src/neptune/api/proto/neptune_pb/__init__.py create mode 100644 src/neptune/api/proto/neptune_pb/api/__init__.py create mode 100644 src/neptune/api/proto/neptune_pb/api/model/__init__.py create mode 100644 src/neptune/api/proto/neptune_pb/api/model/attributes_pb2.py create mode 100644 src/neptune/api/proto/neptune_pb/api/model/attributes_pb2.pyi create mode 100644 src/neptune/api/proto/neptune_pb/api/model/leaderboard_entries_pb2.py create mode 100644 src/neptune/api/proto/neptune_pb/api/model/leaderboard_entries_pb2.pyi diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 15ed5f3cd..dfdc17fbd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -34,11 +34,12 @@ repos: - id: mypy args: [ --config-file, pyproject.toml ] pass_filenames: false - additional_dependencies: [ types-click ] + additional_dependencies: [ types-click, mypy-protobuf ] exclude: | (?x)( ^tests/unit/data/| - ^.github/license_header\.txt + ^.github/license_header\.txt| + ^src/neptune/api/proto/ ) default_language_version: python: python3 diff --git a/dev_requirements.txt b/dev_requirements.txt index 6d51c2889..5cb19474a 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -16,6 +16,7 @@ pytest-xdist vega_datasets backoff altair +icecream bokeh matplotlib seaborn @@ -23,6 +24,7 @@ plotly tensorflow torch typing_extensions>=4.6.0 +grpcio-tools # e2e scikit-learn diff --git a/pyproject.toml b/pyproject.toml index 87cd1beaa..3e2f0bc36 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ requests-oauthlib = ">=1.0.0" websocket-client = ">=0.35.0, !=1.0.0" urllib3 = "*" swagger-spec-validator = ">=2.7.4" +protobuf = "^4.0.0" # Built-in integrations boto3 = ">=1.28.0" @@ -127,7 +128,7 @@ neptune = "neptune.cli.__main__:main" [tool.black] line-length = 120 target-version = ['py37', 'py38', 'py39', 'py310', 'py311', 'py312'] -include = '\.pyi?$' +include = '\.pyi?$,\_pb2\.py$' exclude = ''' /( \.git diff --git a/src/neptune/api/proto/__init__.py b/src/neptune/api/proto/__init__.py new file mode 100644 index 000000000..665b8500e --- /dev/null +++ b/src/neptune/api/proto/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright (c) 2024, Neptune Labs Sp. z o.o. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/src/neptune/api/proto/neptune_pb/__init__.py b/src/neptune/api/proto/neptune_pb/__init__.py new file mode 100644 index 000000000..665b8500e --- /dev/null +++ b/src/neptune/api/proto/neptune_pb/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright (c) 2024, Neptune Labs Sp. z o.o. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/src/neptune/api/proto/neptune_pb/api/__init__.py b/src/neptune/api/proto/neptune_pb/api/__init__.py new file mode 100644 index 000000000..665b8500e --- /dev/null +++ b/src/neptune/api/proto/neptune_pb/api/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright (c) 2024, Neptune Labs Sp. z o.o. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/src/neptune/api/proto/neptune_pb/api/model/__init__.py b/src/neptune/api/proto/neptune_pb/api/model/__init__.py new file mode 100644 index 000000000..665b8500e --- /dev/null +++ b/src/neptune/api/proto/neptune_pb/api/model/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright (c) 2024, Neptune Labs Sp. z o.o. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/src/neptune/api/proto/neptune_pb/api/model/attributes_pb2.py b/src/neptune/api/proto/neptune_pb/api/model/attributes_pb2.py new file mode 100644 index 000000000..98cc5432c --- /dev/null +++ b/src/neptune/api/proto/neptune_pb/api/model/attributes_pb2.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: model/attributes.proto +# Protobuf Python Version: 4.25.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16model/attributes.proto\x12\x11neptune.api.model\"a\n\x1eProtoAttributesSearchResultDTO\x12?\n\x07\x65ntries\x18\x01 \x03(\x0b\x32..neptune.api.model.ProtoAttributeDefinitionDTO\"9\n\x1bProtoAttributeDefinitionDTO\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\tB4\n0ml.neptune.leaderboard.api.model.proto.generatedP\x01\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'model.attributes_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + _globals['DESCRIPTOR']._options = None + _globals['DESCRIPTOR']._serialized_options = b'\n0ml.neptune.leaderboard.api.model.proto.generatedP\001' + _globals['_PROTOATTRIBUTESSEARCHRESULTDTO']._serialized_start=45 + _globals['_PROTOATTRIBUTESSEARCHRESULTDTO']._serialized_end=142 + _globals['_PROTOATTRIBUTEDEFINITIONDTO']._serialized_start=144 + _globals['_PROTOATTRIBUTEDEFINITIONDTO']._serialized_end=201 +# @@protoc_insertion_point(module_scope) diff --git a/src/neptune/api/proto/neptune_pb/api/model/attributes_pb2.pyi b/src/neptune/api/proto/neptune_pb/api/model/attributes_pb2.pyi new file mode 100644 index 000000000..f2acfeeb3 --- /dev/null +++ b/src/neptune/api/proto/neptune_pb/api/model/attributes_pb2.pyi @@ -0,0 +1,51 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" +import builtins +import collections.abc +import google.protobuf.descriptor +import google.protobuf.internal.containers +import google.protobuf.message +import sys + +if sys.version_info >= (3, 8): + import typing as typing_extensions +else: + import typing_extensions + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +@typing_extensions.final +class ProtoAttributesSearchResultDTO(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + ENTRIES_FIELD_NUMBER: builtins.int + @property + def entries(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ProtoAttributeDefinitionDTO]: ... + def __init__( + self, + *, + entries: collections.abc.Iterable[global___ProtoAttributeDefinitionDTO] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["entries", b"entries"]) -> None: ... + +global___ProtoAttributesSearchResultDTO = ProtoAttributesSearchResultDTO + +@typing_extensions.final +class ProtoAttributeDefinitionDTO(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + NAME_FIELD_NUMBER: builtins.int + TYPE_FIELD_NUMBER: builtins.int + name: builtins.str + type: builtins.str + def __init__( + self, + *, + name: builtins.str = ..., + type: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["name", b"name", "type", b"type"]) -> None: ... + +global___ProtoAttributeDefinitionDTO = ProtoAttributeDefinitionDTO diff --git a/src/neptune/api/proto/neptune_pb/api/model/leaderboard_entries_pb2.py b/src/neptune/api/proto/neptune_pb/api/model/leaderboard_entries_pb2.py new file mode 100644 index 000000000..f6a735ee7 --- /dev/null +++ b/src/neptune/api/proto/neptune_pb/api/model/leaderboard_entries_pb2.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: model/leaderboard_entries.proto +# Protobuf Python Version: 4.25.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1fmodel/leaderboard_entries.proto\x12\x11neptune.api.model\"\xb3\x01\n&ProtoLeaderboardEntriesSearchResultDTO\x12\x1b\n\x13matching_item_count\x18\x01 \x01(\x03\x12\x1e\n\x11total_group_count\x18\x02 \x01(\x03H\x00\x88\x01\x01\x12\x36\n\x07\x65ntries\x18\x03 \x03(\x0b\x32%.neptune.api.model.ProtoAttributesDTOB\x14\n\x12_total_group_count\"\xd1\x01\n\x12ProtoAttributesDTO\x12\x15\n\rexperiment_id\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12\x12\n\nproject_id\x18\x03 \x01(\t\x12\x17\n\x0forganization_id\x18\x04 \x01(\t\x12\x14\n\x0cproject_name\x18\x05 \x01(\t\x12\x19\n\x11organization_name\x18\x06 \x01(\t\x12\x38\n\nattributes\x18\x07 \x03(\x0b\x32$.neptune.api.model.ProtoAttributeDTO\"\xed\x05\n\x11ProtoAttributeDTO\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12\x44\n\x0eint_properties\x18\x03 \x01(\x0b\x32\'.neptune.api.model.ProtoIntAttributeDTOH\x00\x88\x01\x01\x12H\n\x10\x66loat_properties\x18\x04 \x01(\x0b\x32).neptune.api.model.ProtoFloatAttributeDTOH\x01\x88\x01\x01\x12J\n\x11string_properties\x18\x05 \x01(\x0b\x32*.neptune.api.model.ProtoStringAttributeDTOH\x02\x88\x01\x01\x12\x46\n\x0f\x62ool_properties\x18\x06 \x01(\x0b\x32(.neptune.api.model.ProtoBoolAttributeDTOH\x03\x88\x01\x01\x12N\n\x13\x64\x61tetime_properties\x18\x07 \x01(\x0b\x32,.neptune.api.model.ProtoDatetimeAttributeDTOH\x04\x88\x01\x01\x12Q\n\x15string_set_properties\x18\x08 \x01(\x0b\x32-.neptune.api.model.ProtoStringSetAttributeDTOH\x05\x88\x01\x01\x12U\n\x17\x66loat_series_properties\x18\t \x01(\x0b\x32/.neptune.api.model.ProtoFloatSeriesAttributeDTOH\x06\x88\x01\x01\x42\x11\n\x0f_int_propertiesB\x13\n\x11_float_propertiesB\x14\n\x12_string_propertiesB\x12\n\x10_bool_propertiesB\x16\n\x14_datetime_propertiesB\x18\n\x16_string_set_propertiesB\x1a\n\x18_float_series_properties\"U\n\x14ProtoIntAttributeDTO\x12\x16\n\x0e\x61ttribute_name\x18\x01 \x01(\t\x12\x16\n\x0e\x61ttribute_type\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\x03\"W\n\x16ProtoFloatAttributeDTO\x12\x16\n\x0e\x61ttribute_name\x18\x01 \x01(\t\x12\x16\n\x0e\x61ttribute_type\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\x01\"X\n\x17ProtoStringAttributeDTO\x12\x16\n\x0e\x61ttribute_name\x18\x01 \x01(\t\x12\x16\n\x0e\x61ttribute_type\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\t\"V\n\x15ProtoBoolAttributeDTO\x12\x16\n\x0e\x61ttribute_name\x18\x01 \x01(\t\x12\x16\n\x0e\x61ttribute_type\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\x08\"Z\n\x19ProtoDatetimeAttributeDTO\x12\x16\n\x0e\x61ttribute_name\x18\x01 \x01(\t\x12\x16\n\x0e\x61ttribute_type\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\x03\"[\n\x1aProtoStringSetAttributeDTO\x12\x16\n\x0e\x61ttribute_name\x18\x01 \x01(\t\x12\x16\n\x0e\x61ttribute_type\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x03(\t\"\xd1\x02\n\x1cProtoFloatSeriesAttributeDTO\x12\x16\n\x0e\x61ttribute_name\x18\x01 \x01(\t\x12\x16\n\x0e\x61ttribute_type\x18\x02 \x01(\t\x12\x16\n\tlast_step\x18\x03 \x01(\x01H\x00\x88\x01\x01\x12\x11\n\x04last\x18\x04 \x01(\x01H\x01\x88\x01\x01\x12\x10\n\x03min\x18\x05 \x01(\x01H\x02\x88\x01\x01\x12\x10\n\x03max\x18\x06 \x01(\x01H\x03\x88\x01\x01\x12\x14\n\x07\x61verage\x18\x07 \x01(\x01H\x04\x88\x01\x01\x12\x15\n\x08variance\x18\x08 \x01(\x01H\x05\x88\x01\x01\x12\x45\n\x06\x63onfig\x18\t \x01(\x0b\x32\x35.neptune.api.model.ProtoFloatSeriesAttributeConfigDTOB\x0c\n\n_last_stepB\x07\n\x05_lastB\x06\n\x04_minB\x06\n\x04_maxB\n\n\x08_averageB\x0b\n\t_variance\"t\n\"ProtoFloatSeriesAttributeConfigDTO\x12\x10\n\x03min\x18\x01 \x01(\x01H\x00\x88\x01\x01\x12\x10\n\x03max\x18\x02 \x01(\x01H\x01\x88\x01\x01\x12\x11\n\x04unit\x18\x03 \x01(\tH\x02\x88\x01\x01\x42\x06\n\x04_minB\x06\n\x04_maxB\x07\n\x05_unitB4\n0ml.neptune.leaderboard.api.model.proto.generatedP\x01\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'model.leaderboard_entries_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + _globals['DESCRIPTOR']._options = None + _globals['DESCRIPTOR']._serialized_options = b'\n0ml.neptune.leaderboard.api.model.proto.generatedP\001' + _globals['_PROTOLEADERBOARDENTRIESSEARCHRESULTDTO']._serialized_start=55 + _globals['_PROTOLEADERBOARDENTRIESSEARCHRESULTDTO']._serialized_end=234 + _globals['_PROTOATTRIBUTESDTO']._serialized_start=237 + _globals['_PROTOATTRIBUTESDTO']._serialized_end=446 + _globals['_PROTOATTRIBUTEDTO']._serialized_start=449 + _globals['_PROTOATTRIBUTEDTO']._serialized_end=1198 + _globals['_PROTOINTATTRIBUTEDTO']._serialized_start=1200 + _globals['_PROTOINTATTRIBUTEDTO']._serialized_end=1285 + _globals['_PROTOFLOATATTRIBUTEDTO']._serialized_start=1287 + _globals['_PROTOFLOATATTRIBUTEDTO']._serialized_end=1374 + _globals['_PROTOSTRINGATTRIBUTEDTO']._serialized_start=1376 + _globals['_PROTOSTRINGATTRIBUTEDTO']._serialized_end=1464 + _globals['_PROTOBOOLATTRIBUTEDTO']._serialized_start=1466 + _globals['_PROTOBOOLATTRIBUTEDTO']._serialized_end=1552 + _globals['_PROTODATETIMEATTRIBUTEDTO']._serialized_start=1554 + _globals['_PROTODATETIMEATTRIBUTEDTO']._serialized_end=1644 + _globals['_PROTOSTRINGSETATTRIBUTEDTO']._serialized_start=1646 + _globals['_PROTOSTRINGSETATTRIBUTEDTO']._serialized_end=1737 + _globals['_PROTOFLOATSERIESATTRIBUTEDTO']._serialized_start=1740 + _globals['_PROTOFLOATSERIESATTRIBUTEDTO']._serialized_end=2077 + _globals['_PROTOFLOATSERIESATTRIBUTECONFIGDTO']._serialized_start=2079 + _globals['_PROTOFLOATSERIESATTRIBUTECONFIGDTO']._serialized_end=2195 +# @@protoc_insertion_point(module_scope) diff --git a/src/neptune/api/proto/neptune_pb/api/model/leaderboard_entries_pb2.pyi b/src/neptune/api/proto/neptune_pb/api/model/leaderboard_entries_pb2.pyi new file mode 100644 index 000000000..7cd4e6ff9 --- /dev/null +++ b/src/neptune/api/proto/neptune_pb/api/model/leaderboard_entries_pb2.pyi @@ -0,0 +1,345 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" +import builtins +import collections.abc +import google.protobuf.descriptor +import google.protobuf.internal.containers +import google.protobuf.message +import sys +import typing + +if sys.version_info >= (3, 8): + import typing as typing_extensions +else: + import typing_extensions + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +@typing_extensions.final +class ProtoLeaderboardEntriesSearchResultDTO(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + MATCHING_ITEM_COUNT_FIELD_NUMBER: builtins.int + TOTAL_GROUP_COUNT_FIELD_NUMBER: builtins.int + ENTRIES_FIELD_NUMBER: builtins.int + matching_item_count: builtins.int + total_group_count: builtins.int + @property + def entries(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ProtoAttributesDTO]: ... + def __init__( + self, + *, + matching_item_count: builtins.int = ..., + total_group_count: builtins.int | None = ..., + entries: collections.abc.Iterable[global___ProtoAttributesDTO] | None = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["_total_group_count", b"_total_group_count", "total_group_count", b"total_group_count"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["_total_group_count", b"_total_group_count", "entries", b"entries", "matching_item_count", b"matching_item_count", "total_group_count", b"total_group_count"]) -> None: ... + def WhichOneof(self, oneof_group: typing_extensions.Literal["_total_group_count", b"_total_group_count"]) -> typing_extensions.Literal["total_group_count"] | None: ... + +global___ProtoLeaderboardEntriesSearchResultDTO = ProtoLeaderboardEntriesSearchResultDTO + +@typing_extensions.final +class ProtoAttributesDTO(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + EXPERIMENT_ID_FIELD_NUMBER: builtins.int + TYPE_FIELD_NUMBER: builtins.int + PROJECT_ID_FIELD_NUMBER: builtins.int + ORGANIZATION_ID_FIELD_NUMBER: builtins.int + PROJECT_NAME_FIELD_NUMBER: builtins.int + ORGANIZATION_NAME_FIELD_NUMBER: builtins.int + ATTRIBUTES_FIELD_NUMBER: builtins.int + experiment_id: builtins.str + type: builtins.str + project_id: builtins.str + organization_id: builtins.str + project_name: builtins.str + organization_name: builtins.str + @property + def attributes(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ProtoAttributeDTO]: ... + def __init__( + self, + *, + experiment_id: builtins.str = ..., + type: builtins.str = ..., + project_id: builtins.str = ..., + organization_id: builtins.str = ..., + project_name: builtins.str = ..., + organization_name: builtins.str = ..., + attributes: collections.abc.Iterable[global___ProtoAttributeDTO] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["attributes", b"attributes", "experiment_id", b"experiment_id", "organization_id", b"organization_id", "organization_name", b"organization_name", "project_id", b"project_id", "project_name", b"project_name", "type", b"type"]) -> None: ... + +global___ProtoAttributesDTO = ProtoAttributesDTO + +@typing_extensions.final +class ProtoAttributeDTO(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + NAME_FIELD_NUMBER: builtins.int + TYPE_FIELD_NUMBER: builtins.int + INT_PROPERTIES_FIELD_NUMBER: builtins.int + FLOAT_PROPERTIES_FIELD_NUMBER: builtins.int + STRING_PROPERTIES_FIELD_NUMBER: builtins.int + BOOL_PROPERTIES_FIELD_NUMBER: builtins.int + DATETIME_PROPERTIES_FIELD_NUMBER: builtins.int + STRING_SET_PROPERTIES_FIELD_NUMBER: builtins.int + FLOAT_SERIES_PROPERTIES_FIELD_NUMBER: builtins.int + name: builtins.str + type: builtins.str + @property + def int_properties(self) -> global___ProtoIntAttributeDTO: ... + @property + def float_properties(self) -> global___ProtoFloatAttributeDTO: ... + @property + def string_properties(self) -> global___ProtoStringAttributeDTO: ... + @property + def bool_properties(self) -> global___ProtoBoolAttributeDTO: ... + @property + def datetime_properties(self) -> global___ProtoDatetimeAttributeDTO: ... + @property + def string_set_properties(self) -> global___ProtoStringSetAttributeDTO: ... + @property + def float_series_properties(self) -> global___ProtoFloatSeriesAttributeDTO: ... + def __init__( + self, + *, + name: builtins.str = ..., + type: builtins.str = ..., + int_properties: global___ProtoIntAttributeDTO | None = ..., + float_properties: global___ProtoFloatAttributeDTO | None = ..., + string_properties: global___ProtoStringAttributeDTO | None = ..., + bool_properties: global___ProtoBoolAttributeDTO | None = ..., + datetime_properties: global___ProtoDatetimeAttributeDTO | None = ..., + string_set_properties: global___ProtoStringSetAttributeDTO | None = ..., + float_series_properties: global___ProtoFloatSeriesAttributeDTO | None = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["_bool_properties", b"_bool_properties", "_datetime_properties", b"_datetime_properties", "_float_properties", b"_float_properties", "_float_series_properties", b"_float_series_properties", "_int_properties", b"_int_properties", "_string_properties", b"_string_properties", "_string_set_properties", b"_string_set_properties", "bool_properties", b"bool_properties", "datetime_properties", b"datetime_properties", "float_properties", b"float_properties", "float_series_properties", b"float_series_properties", "int_properties", b"int_properties", "string_properties", b"string_properties", "string_set_properties", b"string_set_properties"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["_bool_properties", b"_bool_properties", "_datetime_properties", b"_datetime_properties", "_float_properties", b"_float_properties", "_float_series_properties", b"_float_series_properties", "_int_properties", b"_int_properties", "_string_properties", b"_string_properties", "_string_set_properties", b"_string_set_properties", "bool_properties", b"bool_properties", "datetime_properties", b"datetime_properties", "float_properties", b"float_properties", "float_series_properties", b"float_series_properties", "int_properties", b"int_properties", "name", b"name", "string_properties", b"string_properties", "string_set_properties", b"string_set_properties", "type", b"type"]) -> None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_bool_properties", b"_bool_properties"]) -> typing_extensions.Literal["bool_properties"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_datetime_properties", b"_datetime_properties"]) -> typing_extensions.Literal["datetime_properties"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_float_properties", b"_float_properties"]) -> typing_extensions.Literal["float_properties"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_float_series_properties", b"_float_series_properties"]) -> typing_extensions.Literal["float_series_properties"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_int_properties", b"_int_properties"]) -> typing_extensions.Literal["int_properties"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_string_properties", b"_string_properties"]) -> typing_extensions.Literal["string_properties"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_string_set_properties", b"_string_set_properties"]) -> typing_extensions.Literal["string_set_properties"] | None: ... + +global___ProtoAttributeDTO = ProtoAttributeDTO + +@typing_extensions.final +class ProtoIntAttributeDTO(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + ATTRIBUTE_NAME_FIELD_NUMBER: builtins.int + ATTRIBUTE_TYPE_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + attribute_name: builtins.str + attribute_type: builtins.str + value: builtins.int + def __init__( + self, + *, + attribute_name: builtins.str = ..., + attribute_type: builtins.str = ..., + value: builtins.int = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["attribute_name", b"attribute_name", "attribute_type", b"attribute_type", "value", b"value"]) -> None: ... + +global___ProtoIntAttributeDTO = ProtoIntAttributeDTO + +@typing_extensions.final +class ProtoFloatAttributeDTO(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + ATTRIBUTE_NAME_FIELD_NUMBER: builtins.int + ATTRIBUTE_TYPE_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + attribute_name: builtins.str + attribute_type: builtins.str + value: builtins.float + def __init__( + self, + *, + attribute_name: builtins.str = ..., + attribute_type: builtins.str = ..., + value: builtins.float = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["attribute_name", b"attribute_name", "attribute_type", b"attribute_type", "value", b"value"]) -> None: ... + +global___ProtoFloatAttributeDTO = ProtoFloatAttributeDTO + +@typing_extensions.final +class ProtoStringAttributeDTO(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + ATTRIBUTE_NAME_FIELD_NUMBER: builtins.int + ATTRIBUTE_TYPE_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + attribute_name: builtins.str + attribute_type: builtins.str + value: builtins.str + def __init__( + self, + *, + attribute_name: builtins.str = ..., + attribute_type: builtins.str = ..., + value: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["attribute_name", b"attribute_name", "attribute_type", b"attribute_type", "value", b"value"]) -> None: ... + +global___ProtoStringAttributeDTO = ProtoStringAttributeDTO + +@typing_extensions.final +class ProtoBoolAttributeDTO(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + ATTRIBUTE_NAME_FIELD_NUMBER: builtins.int + ATTRIBUTE_TYPE_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + attribute_name: builtins.str + attribute_type: builtins.str + value: builtins.bool + def __init__( + self, + *, + attribute_name: builtins.str = ..., + attribute_type: builtins.str = ..., + value: builtins.bool = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["attribute_name", b"attribute_name", "attribute_type", b"attribute_type", "value", b"value"]) -> None: ... + +global___ProtoBoolAttributeDTO = ProtoBoolAttributeDTO + +@typing_extensions.final +class ProtoDatetimeAttributeDTO(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + ATTRIBUTE_NAME_FIELD_NUMBER: builtins.int + ATTRIBUTE_TYPE_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + attribute_name: builtins.str + attribute_type: builtins.str + value: builtins.int + def __init__( + self, + *, + attribute_name: builtins.str = ..., + attribute_type: builtins.str = ..., + value: builtins.int = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["attribute_name", b"attribute_name", "attribute_type", b"attribute_type", "value", b"value"]) -> None: ... + +global___ProtoDatetimeAttributeDTO = ProtoDatetimeAttributeDTO + +@typing_extensions.final +class ProtoStringSetAttributeDTO(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + ATTRIBUTE_NAME_FIELD_NUMBER: builtins.int + ATTRIBUTE_TYPE_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + attribute_name: builtins.str + attribute_type: builtins.str + @property + def value(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ... + def __init__( + self, + *, + attribute_name: builtins.str = ..., + attribute_type: builtins.str = ..., + value: collections.abc.Iterable[builtins.str] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["attribute_name", b"attribute_name", "attribute_type", b"attribute_type", "value", b"value"]) -> None: ... + +global___ProtoStringSetAttributeDTO = ProtoStringSetAttributeDTO + +@typing_extensions.final +class ProtoFloatSeriesAttributeDTO(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + ATTRIBUTE_NAME_FIELD_NUMBER: builtins.int + ATTRIBUTE_TYPE_FIELD_NUMBER: builtins.int + LAST_STEP_FIELD_NUMBER: builtins.int + LAST_FIELD_NUMBER: builtins.int + MIN_FIELD_NUMBER: builtins.int + MAX_FIELD_NUMBER: builtins.int + AVERAGE_FIELD_NUMBER: builtins.int + VARIANCE_FIELD_NUMBER: builtins.int + CONFIG_FIELD_NUMBER: builtins.int + attribute_name: builtins.str + attribute_type: builtins.str + last_step: builtins.float + last: builtins.float + min: builtins.float + max: builtins.float + average: builtins.float + variance: builtins.float + @property + def config(self) -> global___ProtoFloatSeriesAttributeConfigDTO: ... + def __init__( + self, + *, + attribute_name: builtins.str = ..., + attribute_type: builtins.str = ..., + last_step: builtins.float | None = ..., + last: builtins.float | None = ..., + min: builtins.float | None = ..., + max: builtins.float | None = ..., + average: builtins.float | None = ..., + variance: builtins.float | None = ..., + config: global___ProtoFloatSeriesAttributeConfigDTO | None = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["_average", b"_average", "_last", b"_last", "_last_step", b"_last_step", "_max", b"_max", "_min", b"_min", "_variance", b"_variance", "average", b"average", "config", b"config", "last", b"last", "last_step", b"last_step", "max", b"max", "min", b"min", "variance", b"variance"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["_average", b"_average", "_last", b"_last", "_last_step", b"_last_step", "_max", b"_max", "_min", b"_min", "_variance", b"_variance", "attribute_name", b"attribute_name", "attribute_type", b"attribute_type", "average", b"average", "config", b"config", "last", b"last", "last_step", b"last_step", "max", b"max", "min", b"min", "variance", b"variance"]) -> None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_average", b"_average"]) -> typing_extensions.Literal["average"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_last", b"_last"]) -> typing_extensions.Literal["last"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_last_step", b"_last_step"]) -> typing_extensions.Literal["last_step"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_max", b"_max"]) -> typing_extensions.Literal["max"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_min", b"_min"]) -> typing_extensions.Literal["min"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_variance", b"_variance"]) -> typing_extensions.Literal["variance"] | None: ... + +global___ProtoFloatSeriesAttributeDTO = ProtoFloatSeriesAttributeDTO + +@typing_extensions.final +class ProtoFloatSeriesAttributeConfigDTO(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + MIN_FIELD_NUMBER: builtins.int + MAX_FIELD_NUMBER: builtins.int + UNIT_FIELD_NUMBER: builtins.int + min: builtins.float + max: builtins.float + unit: builtins.str + def __init__( + self, + *, + min: builtins.float | None = ..., + max: builtins.float | None = ..., + unit: builtins.str | None = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["_max", b"_max", "_min", b"_min", "_unit", b"_unit", "max", b"max", "min", b"min", "unit", b"unit"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["_max", b"_max", "_min", b"_min", "_unit", b"_unit", "max", b"max", "min", b"min", "unit", b"unit"]) -> None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_max", b"_max"]) -> typing_extensions.Literal["max"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_min", b"_min"]) -> typing_extensions.Literal["min"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_unit", b"_unit"]) -> typing_extensions.Literal["unit"] | None: ... + +global___ProtoFloatSeriesAttributeConfigDTO = ProtoFloatSeriesAttributeConfigDTO diff --git a/src/neptune/internal/backends/hosted_neptune_backend.py b/src/neptune/internal/backends/hosted_neptune_backend.py index 4c416c798..0bf43c979 100644 --- a/src/neptune/internal/backends/hosted_neptune_backend.py +++ b/src/neptune/internal/backends/hosted_neptune_backend.py @@ -54,6 +54,7 @@ StringSeriesField, StringSetField, ) +from neptune.api.proto.neptune_pb.api.model.attributes_pb2 import ProtoAttributesSearchResultDTO from neptune.api.searching_entries import iter_over_pages from neptune.core.components.operation_storage import OperationStorage from neptune.envs import NEPTUNE_FETCH_TABLE_STEP_SIZE @@ -1118,6 +1119,57 @@ def get_model_version_url( base_url = self.get_display_address() return f"{base_url}/{workspace}/{project_name}/m/{model_id}/v/{sys_id}" + def get_attribute_definitions( + self, + container_id: str, + container_type: ContainerType, + filter_types: Optional[List[str]] = None, + use_proto: bool = False, + ) -> List[FieldDefinition]: + params = { + "experimentIdentifier": container_id, + **DEFAULT_REQUEST_KWARGS, + } + + try: + if use_proto: + result = self.leaderboard_client.api.queryAttributeDefinitionsProto(**params).response().result + result = ProtoAttributesSearchResultDTO.FromString(result) + else: + result = self.leaderboard_client.api.queryAttributeDefinitions(**params).response().result + + attributes = result.entries + + if filter_types is not None: + attributes = [attr for attr in attributes if attr.type in filter_types] + + return [FieldDefinition(attr.name, FieldType(attr.type)) for attr in attributes] + except HTTPNotFound as e: + raise ContainerUUIDNotFound( + container_id=container_id, + container_type=container_type, + ) from e + + def get_attributes_with_paths_filter( + self, container_id: str, container_type: ContainerType, paths: List[str] + ) -> Any: + params = { + "holderIdentifier": container_id, + "holderType": "experiment", + "attributeQuery": { + "attributePathsFilter": paths, + }, + **DEFAULT_REQUEST_KWARGS, + } + + try: + return self.leaderboard_client.api.getAttributesWithPathsFilter(**params).response().result + except HTTPNotFound as e: + raise ContainerUUIDNotFound( + container_id=container_id, + container_type=container_type, + ) from e + def _get_column_type_from_entries(entries: List[Any], column: str) -> str: if not entries: # column chosen is not present in the table diff --git a/src/neptune/internal/backends/neptune_backend.py b/src/neptune/internal/backends/neptune_backend.py index f6d21dbeb..aa2970a06 100644 --- a/src/neptune/internal/backends/neptune_backend.py +++ b/src/neptune/internal/backends/neptune_backend.py @@ -295,6 +295,16 @@ def get_model_version_url( ) -> str: pass + # WARN: Used in Neptune Fetcher + @abc.abstractmethod + def get_attribute_definitions( + self, + container_id: str, + container_type: ContainerType, + filter_types: Optional[List[str]] = None, + use_proto: bool = False, + ) -> List[Attribute]: ... + @abc.abstractmethod def fetch_atom_attribute_values( self, container_id: str, container_type: ContainerType, path: List[str] diff --git a/src/neptune/internal/backends/neptune_backend_mock.py b/src/neptune/internal/backends/neptune_backend_mock.py index 0a8a1b515..f808a0752 100644 --- a/src/neptune/internal/backends/neptune_backend_mock.py +++ b/src/neptune/internal/backends/neptune_backend_mock.py @@ -514,6 +514,15 @@ def get_model_version_url( ) -> str: return f"offline/{model_version_id}" + def get_attribute_definitions( + self, + container_id: str, + container_type: ContainerType, + filter_types: Optional[List[str]] = None, + use_proto: bool = False, + ) -> List[Attribute]: + return [] + def _get_attribute_values(self, value_dict, path_prefix: List[str]): assert isinstance(value_dict, dict) for k, value in value_dict.items(): From 2cf5bab9456f809676b4ce04f868aa455415d32b Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Mon, 8 Apr 2024 19:13:01 +0200 Subject: [PATCH 2/8] CHANGELOG updated --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ac9f937af..2a3b33725 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ - Stop sending `X-Neptune-LegacyClient` header ([#1715](https://github.com/neptune-ai/neptune-client/pull/1715)) - Use `tqdm.auto` ([#1717](https://github.com/neptune-ai/neptune-client/pull/1717)) - Fields DTO conversion reworked ([#1722](https://github.com/neptune-ai/neptune-client/pull/1722)) +- Added support for Protocol Buffers ([#1728](https://github.com/neptune-ai/neptune-client/pull/1728)) ### Features - ? From 6320d7676ca3640620641a65d0f892758484a9cc Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Mon, 8 Apr 2024 19:38:21 +0200 Subject: [PATCH 3/8] Added converters --- src/neptune/api/models.py | 115 +++++++++++++++++- src/neptune/api/searching_entries.py | 1 + .../backends/hosted_neptune_backend.py | 3 +- 3 files changed, 117 insertions(+), 2 deletions(-) diff --git a/src/neptune/api/models.py b/src/neptune/api/models.py index a1604ed9d..b3a37a0d4 100644 --- a/src/neptune/api/models.py +++ b/src/neptune/api/models.py @@ -44,7 +44,10 @@ import abc from dataclasses import dataclass from dataclasses import field as dataclass_field -from datetime import datetime +from datetime import ( + datetime, + timezone, +) from enum import Enum from typing import ( Any, @@ -58,6 +61,19 @@ TypeVar, ) +from neptune.api.proto.neptune_pb.api.model.attributes_pb2 import ProtoAttributeDefinitionDTO +from neptune.api.proto.neptune_pb.api.model.leaderboard_entries_pb2 import ( + ProtoAttributeDTO, + ProtoAttributesDTO, + ProtoBoolAttributeDTO, + ProtoDatetimeAttributeDTO, + ProtoFloatAttributeDTO, + ProtoFloatSeriesAttributeDTO, + ProtoIntAttributeDTO, + ProtoLeaderboardEntriesSearchResultDTO, + ProtoStringAttributeDTO, + ProtoStringSetAttributeDTO, +) from neptune.internal.utils.iso_dates import parse_iso_date from neptune.internal.utils.run_state import RunState @@ -122,6 +138,11 @@ def from_model(model: Any) -> Field: field_type = str(model.type) return Field._registry[field_type].from_model(model.__getattribute__(f"{field_type}Properties")) + @staticmethod + def from_proto(data: Any) -> Field: + # TODO: Implement + raise NotImplementedError() + class FieldVisitor(Generic[Ret], abc.ABC): @@ -189,6 +210,10 @@ def from_dict(data: Dict[str, Any]) -> FloatField: def from_model(model: Any) -> FloatField: return FloatField(path=model.attributeName, value=model.value) + @staticmethod + def from_proto(data: ProtoFloatAttributeDTO) -> Field: + return FloatField(path=data.attribute_name, value=data.value) + @dataclass class IntField(Field, field_type=FieldType.INT): @@ -205,6 +230,10 @@ def from_dict(data: Dict[str, Any]) -> IntField: def from_model(model: Any) -> IntField: return IntField(path=model.attributeName, value=model.value) + @staticmethod + def from_proto(data: ProtoIntAttributeDTO) -> IntField: + return IntField(path=data.attribute_name, value=data.value) + @dataclass class BoolField(Field, field_type=FieldType.BOOL): @@ -221,6 +250,10 @@ def from_dict(data: Dict[str, Any]) -> BoolField: def from_model(model: Any) -> BoolField: return BoolField(path=model.attributeName, value=model.value) + @staticmethod + def from_proto(data: ProtoBoolAttributeDTO) -> BoolField: + return BoolField(path=data.attribute_name, value=data.value) + @dataclass class StringField(Field, field_type=FieldType.STRING): @@ -237,6 +270,10 @@ def from_dict(data: Dict[str, Any]) -> StringField: def from_model(model: Any) -> StringField: return StringField(path=model.attributeName, value=model.value) + @staticmethod + def from_proto(data: ProtoStringAttributeDTO) -> StringField: + return StringField(path=data.attribute_name, value=data.value) + @dataclass class DateTimeField(Field, field_type=FieldType.DATETIME): @@ -253,6 +290,11 @@ def from_dict(data: Dict[str, Any]) -> DateTimeField: def from_model(model: Any) -> DateTimeField: return DateTimeField(path=model.attributeName, value=parse_iso_date(model.value)) + @staticmethod + def from_proto(data: ProtoDatetimeAttributeDTO) -> DateTimeField: + # TODO: Ensure that the timestamp is in UTC + return DateTimeField(path=data.attribute_name, value=datetime.fromtimestamp(data.value, tz=timezone.utc)) + @dataclass class FileField(Field, field_type=FieldType.FILE): @@ -271,6 +313,11 @@ def from_dict(data: Dict[str, Any]) -> FileField: def from_model(model: Any) -> FileField: return FileField(path=model.attributeName, name=model.name, ext=model.ext, size=model.size) + @staticmethod + def from_proto(data: Any) -> FileField: + # TODO: implement + raise NotImplementedError() + @dataclass class FileSetField(Field, field_type=FieldType.FILE_SET): @@ -287,6 +334,11 @@ def from_dict(data: Dict[str, Any]) -> FileSetField: def from_model(model: Any) -> FileSetField: return FileSetField(path=model.attributeName, size=model.size) + @staticmethod + def from_proto(data: Any) -> FileSetField: + # TODO: implement + raise NotImplementedError() + @dataclass class FloatSeriesField(Field, field_type=FieldType.FLOAT_SERIES): @@ -304,6 +356,11 @@ def from_dict(data: Dict[str, Any]) -> FloatSeriesField: def from_model(model: Any) -> FloatSeriesField: return FloatSeriesField(path=model.attributeName, last=model.last) + @staticmethod + def from_proto(data: ProtoFloatSeriesAttributeDTO) -> FloatSeriesField: + last = data.last if data.HasField("last") else None + return FloatSeriesField(path=data.attribute_name, last=last) + @dataclass class StringSeriesField(Field, field_type=FieldType.STRING_SERIES): @@ -321,6 +378,11 @@ def from_dict(data: Dict[str, Any]) -> StringSeriesField: def from_model(model: Any) -> StringSeriesField: return StringSeriesField(path=model.attributeName, last=model.last) + @staticmethod + def from_proto(data: Any) -> StringSeriesField: + # TODO: implement + raise NotImplementedError() + @dataclass class ImageSeriesField(Field, field_type=FieldType.IMAGE_SERIES): @@ -338,6 +400,11 @@ def from_dict(data: Dict[str, Any]) -> ImageSeriesField: def from_model(model: Any) -> ImageSeriesField: return ImageSeriesField(path=model.attributeName, last_step=model.lastStep) + @staticmethod + def from_proto(data: Any) -> ImageSeriesField: + # TODO: implement + raise NotImplementedError() + @dataclass class StringSetField(Field, field_type=FieldType.STRING_SET): @@ -354,6 +421,10 @@ def from_dict(data: Dict[str, Any]) -> StringSetField: def from_model(model: Any) -> StringSetField: return StringSetField(path=model.attributeName, values=set(model.values)) + @staticmethod + def from_proto(data: ProtoStringSetAttributeDTO) -> StringSetField: + return StringSetField(path=data.attribute_name, values=set(data.value)) + @dataclass class GitCommit: @@ -368,6 +439,11 @@ def from_dict(data: Dict[str, Any]) -> GitCommit: def from_model(model: Any) -> GitCommit: return GitCommit(commit_id=model.commitId) + @staticmethod + def from_proto(data: Any) -> GitCommit: + # TODO: implement + raise NotImplementedError() + @dataclass class GitRefField(Field, field_type=FieldType.GIT_REF): @@ -386,6 +462,11 @@ def from_model(model: Any) -> GitRefField: commit = GitCommit.from_model(model.commit) if model.commit is not None else None return GitRefField(path=model.attributeName, commit=commit) + @staticmethod + def from_proto(data: ProtoAttributeDTO) -> GitRefField: + # TODO: implement + raise NotImplementedError() + @dataclass class ObjectStateField(Field, field_type=FieldType.OBJECT_STATE): @@ -404,6 +485,11 @@ def from_model(model: Any) -> ObjectStateField: value = RunState.from_api(str(model.value)).value return ObjectStateField(path=model.attributeName, value=value) + @staticmethod + def from_proto(data: Any) -> ObjectStateField: + # TODO: implement + raise NotImplementedError() + @dataclass class NotebookRefField(Field, field_type=FieldType.NOTEBOOK_REF): @@ -421,6 +507,11 @@ def from_dict(data: Dict[str, Any]) -> NotebookRefField: def from_model(model: Any) -> NotebookRefField: return NotebookRefField(path=model.attributeName, notebook_name=model.notebookName) + @staticmethod + def from_proto(data: Any) -> NotebookRefField: + # TODO: implement + raise NotImplementedError() + @dataclass class ArtifactField(Field, field_type=FieldType.ARTIFACT): @@ -437,6 +528,11 @@ def from_dict(data: Dict[str, Any]) -> ArtifactField: def from_model(model: Any) -> ArtifactField: return ArtifactField(path=model.attributeName, hash=model.hash) + @staticmethod + def from_proto(data: Any) -> ArtifactField: + # TODO: implement + raise NotImplementedError() + @dataclass class LeaderboardEntry: @@ -455,6 +551,12 @@ def from_model(model: Any) -> LeaderboardEntry: object_id=model.experimentId, fields=[Field.from_model(field) for field in model.attributes] ) + @staticmethod + def from_proto(data: ProtoAttributesDTO) -> LeaderboardEntry: + return LeaderboardEntry( + object_id=data.experiment_id, fields=[Field.from_proto(field) for field in data.attributes] + ) + @dataclass class LeaderboardEntriesSearchResult: @@ -475,6 +577,13 @@ def from_model(result: Any) -> LeaderboardEntriesSearchResult: matching_item_count=result.matchingItemCount, ) + @staticmethod + def from_proto(data: ProtoLeaderboardEntriesSearchResultDTO) -> LeaderboardEntriesSearchResult: + return LeaderboardEntriesSearchResult( + entries=[LeaderboardEntry.from_proto(entry) for entry in data.entries], + matching_item_count=data.matching_item_count, + ) + @dataclass class FieldDefinition: @@ -488,3 +597,7 @@ def from_dict(data: Dict[str, Any]) -> FieldDefinition: @staticmethod def from_model(model: Any) -> FieldDefinition: return FieldDefinition(path=model.name, type=FieldType(model.type)) + + @staticmethod + def from_proto(data: ProtoAttributeDefinitionDTO) -> FieldDefinition: + return FieldDefinition(path=data.name, type=FieldType(data.type)) diff --git a/src/neptune/api/searching_entries.py b/src/neptune/api/searching_entries.py index e1f6cc572..056de1886 100644 --- a/src/neptune/api/searching_entries.py +++ b/src/neptune/api/searching_entries.py @@ -146,6 +146,7 @@ def get_single_page( http_client = client.swagger_spec.http_client try: + # TODO: Allow to fetch using protocol buffers return ( http_client.request(request_params, operation=None, request_config=request_config) .response() diff --git a/src/neptune/internal/backends/hosted_neptune_backend.py b/src/neptune/internal/backends/hosted_neptune_backend.py index 0bf43c979..3180ae9f2 100644 --- a/src/neptune/internal/backends/hosted_neptune_backend.py +++ b/src/neptune/internal/backends/hosted_neptune_backend.py @@ -1124,7 +1124,7 @@ def get_attribute_definitions( container_id: str, container_type: ContainerType, filter_types: Optional[List[str]] = None, - use_proto: bool = False, + use_proto: bool = False, # TODO: Use environment variable instead ) -> List[FieldDefinition]: params = { "experimentIdentifier": container_id, @@ -1132,6 +1132,7 @@ def get_attribute_definitions( } try: + # TODO: Rework as we should call `from_model`/`from_proto` on the result if use_proto: result = self.leaderboard_client.api.queryAttributeDefinitionsProto(**params).response().result result = ProtoAttributesSearchResultDTO.FromString(result) From 505d31f6c5d3513438b3ec0fe6314e181902182c Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Tue, 9 Apr 2024 11:01:24 +0200 Subject: [PATCH 4/8] Implemented all methods and unittests added --- src/neptune/api/models.py | 28 +- src/neptune/api/searching_entries.py | 37 +- .../backends/hosted_neptune_backend.py | 35 +- .../internal/backends/neptune_backend.py | 14 +- .../internal/backends/neptune_backend_mock.py | 13 +- .../backends/offline_neptune_backend.py | 14 + tests/unit/neptune/new/api/test_models.py | 610 ++++++++++++++++++ .../neptune/new/api/test_searching_entries.py | 73 +-- 8 files changed, 731 insertions(+), 93 deletions(-) diff --git a/src/neptune/api/models.py b/src/neptune/api/models.py index b3a37a0d4..71ccd3ddb 100644 --- a/src/neptune/api/models.py +++ b/src/neptune/api/models.py @@ -42,6 +42,7 @@ ) import abc +import re from dataclasses import dataclass from dataclasses import field as dataclass_field from datetime import ( @@ -140,8 +141,15 @@ def from_model(model: Any) -> Field: @staticmethod def from_proto(data: Any) -> Field: - # TODO: Implement - raise NotImplementedError() + field_type = str(data.type) + return Field._registry[field_type].from_proto(data.__getattribute__(f"{camel_to_snake(field_type)}_properties")) + + +def camel_to_snake(name: str) -> str: + # Insert an underscore before any uppercase letters and convert the string to lowercase + s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) + # Handle the case where there are uppercase letters in the middle of the name + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() class FieldVisitor(Generic[Ret], abc.ABC): @@ -211,7 +219,7 @@ def from_model(model: Any) -> FloatField: return FloatField(path=model.attributeName, value=model.value) @staticmethod - def from_proto(data: ProtoFloatAttributeDTO) -> Field: + def from_proto(data: ProtoFloatAttributeDTO) -> FloatField: return FloatField(path=data.attribute_name, value=data.value) @@ -293,6 +301,7 @@ def from_model(model: Any) -> DateTimeField: @staticmethod def from_proto(data: ProtoDatetimeAttributeDTO) -> DateTimeField: # TODO: Ensure that the timestamp is in UTC + # TODO: Ensure that we are supporting seconds and miliseconds return DateTimeField(path=data.attribute_name, value=datetime.fromtimestamp(data.value, tz=timezone.utc)) @@ -553,8 +562,19 @@ def from_model(model: Any) -> LeaderboardEntry: @staticmethod def from_proto(data: ProtoAttributesDTO) -> LeaderboardEntry: + with_proto_support = { + FieldType.STRING.value, + FieldType.BOOL.value, + FieldType.INT.value, + FieldType.FLOAT.value, + FieldType.DATETIME.value, + FieldType.STRING_SET.value, + FieldType.FLOAT_SERIES.value, + } + return LeaderboardEntry( - object_id=data.experiment_id, fields=[Field.from_proto(field) for field in data.attributes] + object_id=data.experiment_id, + fields=[Field.from_proto(field) for field in data.attributes if str(field.type) in with_proto_support], ) diff --git a/src/neptune/api/searching_entries.py b/src/neptune/api/searching_entries.py index 056de1886..ee405c475 100644 --- a/src/neptune/api/searching_entries.py +++ b/src/neptune/api/searching_entries.py @@ -39,6 +39,7 @@ LeaderboardEntriesSearchResult, LeaderboardEntry, ) +from neptune.api.proto.neptune_pb.api.model.leaderboard_entries_pb2 import ProtoLeaderboardEntriesSearchResultDTO from neptune.exceptions import NeptuneInvalidQueryException from neptune.internal.backends.hosted_client import DEFAULT_REQUEST_KWARGS from neptune.internal.backends.nql import ( @@ -97,7 +98,8 @@ def get_single_page( types: Optional[Iterable[str]], query: Optional["NQLQuery"], searching_after: Optional[str], -) -> Any: + use_proto: Optional[bool] = False, +) -> LeaderboardEntriesSearchResult: normalized_query = query or NQLEmptyQuery() sort_by_column_type = sort_by_column_type if sort_by_column_type else FieldType.STRING.value if sort_by and searching_after: @@ -139,19 +141,25 @@ def get_single_page( }, } - request_options = DEFAULT_REQUEST_KWARGS.get("_request_options", {}) - request_config = RequestConfig(request_options, True) - request_params = construct_request(client.api.searchLeaderboardEntries, request_options, **params) + try: + if use_proto: + result = client.api.searchLeaderboardEntriesProto(**params).response().result + proto_data = ProtoLeaderboardEntriesSearchResultDTO.FromString(result) + return LeaderboardEntriesSearchResult.from_proto(proto_data) + else: + request_options = DEFAULT_REQUEST_KWARGS.get("_request_options", {}) + request_config = RequestConfig(request_options, True) + request_params = construct_request(client.api.searchLeaderboardEntries, request_options, **params) - http_client = client.swagger_spec.http_client + http_client = client.swagger_spec.http_client - try: - # TODO: Allow to fetch using protocol buffers - return ( - http_client.request(request_params, operation=None, request_config=request_config) - .response() - .incoming_response.json() - ) + json_data = ( + http_client.request(request_params, operation=None, request_config=request_config) + .response() + .incoming_response.json() + ) + + return LeaderboardEntriesSearchResult.from_dict(json_data) except HTTPBadRequest as e: title = e.response.json().get("title") if title == "Syntax error": @@ -186,7 +194,7 @@ def iter_over_pages( searching_after=None, **kwargs, ) - total = LeaderboardEntriesSearchResult.from_dict(data).matching_item_count + total = data.matching_item_count limit = limit if limit is not None else NoLimit() @@ -217,7 +225,7 @@ def iter_over_pages( if extracted_records + local_limit > limit: local_limit = limit - extracted_records - data = get_single_page( + result = get_single_page( limit=local_limit, offset=offset, sort_by=sort_by, @@ -226,7 +234,6 @@ def iter_over_pages( ascending=ascending, **kwargs, ) - result = LeaderboardEntriesSearchResult.from_dict(data) # fetch the item count everytime a new page is started (except for the very fist page) if offset == 0 and last_page is not None: diff --git a/src/neptune/internal/backends/hosted_neptune_backend.py b/src/neptune/internal/backends/hosted_neptune_backend.py index 3180ae9f2..0fb69ad9c 100644 --- a/src/neptune/internal/backends/hosted_neptune_backend.py +++ b/src/neptune/internal/backends/hosted_neptune_backend.py @@ -42,6 +42,7 @@ ArtifactField, BoolField, DateTimeField, + Field, FieldDefinition, FieldType, FileEntry, @@ -55,6 +56,7 @@ StringSetField, ) from neptune.api.proto.neptune_pb.api.model.attributes_pb2 import ProtoAttributesSearchResultDTO +from neptune.api.proto.neptune_pb.api.model.leaderboard_entries_pb2 import ProtoAttributesDTO from neptune.api.searching_entries import iter_over_pages from neptune.core.components.operation_storage import OperationStorage from neptune.envs import NEPTUNE_FETCH_TABLE_STEP_SIZE @@ -1119,12 +1121,11 @@ def get_model_version_url( base_url = self.get_display_address() return f"{base_url}/{workspace}/{project_name}/m/{model_id}/v/{sys_id}" - def get_attribute_definitions( + def get_fields_definitions( self, container_id: str, container_type: ContainerType, - filter_types: Optional[List[str]] = None, - use_proto: bool = False, # TODO: Use environment variable instead + use_proto: Optional[bool] = False, # TODO: Use environment variable instead ) -> List[FieldDefinition]: params = { "experimentIdentifier": container_id, @@ -1132,28 +1133,22 @@ def get_attribute_definitions( } try: - # TODO: Rework as we should call `from_model`/`from_proto` on the result if use_proto: result = self.leaderboard_client.api.queryAttributeDefinitionsProto(**params).response().result - result = ProtoAttributesSearchResultDTO.FromString(result) + data = ProtoAttributesSearchResultDTO.FromString(result) + return [FieldDefinition.from_proto(field_def) for field_def in data.entries] else: - result = self.leaderboard_client.api.queryAttributeDefinitions(**params).response().result - - attributes = result.entries - - if filter_types is not None: - attributes = [attr for attr in attributes if attr.type in filter_types] - - return [FieldDefinition(attr.name, FieldType(attr.type)) for attr in attributes] + data = self.leaderboard_client.api.queryAttributeDefinitions(**params).response().result + return [FieldDefinition.from_model(field_def) for field_def in data.entries] except HTTPNotFound as e: raise ContainerUUIDNotFound( container_id=container_id, container_type=container_type, ) from e - def get_attributes_with_paths_filter( - self, container_id: str, container_type: ContainerType, paths: List[str] - ) -> Any: + def get_fields_with_paths_filter( + self, container_id: str, container_type: ContainerType, paths: List[str], use_proto: Optional[bool] = False + ) -> List[Field]: params = { "holderIdentifier": container_id, "holderType": "experiment", @@ -1164,7 +1159,13 @@ def get_attributes_with_paths_filter( } try: - return self.leaderboard_client.api.getAttributesWithPathsFilter(**params).response().result + if use_proto: + result = self.leaderboard_client.api.getAttributesWithPathsFilterProto(**params).response().result + data = ProtoAttributesDTO.FromString(result) + return [Field.from_proto(field) for field in data.attributes] + else: + data = self.leaderboard_client.api.getAttributesWithPathsFilter(**params).response().result + return [Field.from_model(field) for field in data.attributes] except HTTPNotFound as e: raise ContainerUUIDNotFound( container_id=container_id, diff --git a/src/neptune/internal/backends/neptune_backend.py b/src/neptune/internal/backends/neptune_backend.py index aa2970a06..b87f6557e 100644 --- a/src/neptune/internal/backends/neptune_backend.py +++ b/src/neptune/internal/backends/neptune_backend.py @@ -29,6 +29,7 @@ ArtifactField, BoolField, DateTimeField, + Field, FieldDefinition, FieldType, FileEntry, @@ -297,13 +298,18 @@ def get_model_version_url( # WARN: Used in Neptune Fetcher @abc.abstractmethod - def get_attribute_definitions( + def get_fields_definitions( self, container_id: str, container_type: ContainerType, - filter_types: Optional[List[str]] = None, - use_proto: bool = False, - ) -> List[Attribute]: ... + use_proto: Optional[bool] = False, + ) -> List[FieldDefinition]: ... + + # WARN: Used in Neptune Fetcher + @abc.abstractmethod + def get_fields_with_paths_filter( + self, container_id: str, container_type: ContainerType, paths: List[str], use_proto: Optional[bool] = False + ) -> List[Field]: ... @abc.abstractmethod def fetch_atom_attribute_values( diff --git a/src/neptune/internal/backends/neptune_backend_mock.py b/src/neptune/internal/backends/neptune_backend_mock.py index f808a0752..5e145bb79 100644 --- a/src/neptune/internal/backends/neptune_backend_mock.py +++ b/src/neptune/internal/backends/neptune_backend_mock.py @@ -38,6 +38,7 @@ ArtifactField, BoolField, DateTimeField, + Field, FieldDefinition, FieldType, FileEntry, @@ -514,13 +515,12 @@ def get_model_version_url( ) -> str: return f"offline/{model_version_id}" - def get_attribute_definitions( + def get_fields_definitions( self, container_id: str, container_type: ContainerType, - filter_types: Optional[List[str]] = None, - use_proto: bool = False, - ) -> List[Attribute]: + use_proto: Optional[bool] = False, + ) -> List[FieldDefinition]: return [] def _get_attribute_values(self, value_dict, path_prefix: List[str]): @@ -796,3 +796,8 @@ def list_fileset_files(self, attribute: List[str], container_id: str, path: str) file_type="file", ) ] + + def get_fields_with_paths_filter( + self, container_id: str, container_type: ContainerType, paths: List[str], use_proto: Optional[bool] = False + ) -> List[Field]: + return [] diff --git a/src/neptune/internal/backends/offline_neptune_backend.py b/src/neptune/internal/backends/offline_neptune_backend.py index 48f2b8264..032787c04 100644 --- a/src/neptune/internal/backends/offline_neptune_backend.py +++ b/src/neptune/internal/backends/offline_neptune_backend.py @@ -24,6 +24,7 @@ ArtifactField, BoolField, DateTimeField, + Field, FieldDefinition, FileEntry, FileField, @@ -138,3 +139,16 @@ def download_file_series_by_index( def list_fileset_files(self, attribute: List[str], container_id: str, path: str) -> List[FileEntry]: raise NeptuneOfflineModeFetchException + + def get_fields_with_paths_filter( + self, container_id: str, container_type: ContainerType, paths: List[str], use_proto: Optional[bool] = False + ) -> List[Field]: + raise NeptuneOfflineModeFetchException + + def get_fields_definitions( + self, + container_id: str, + container_type: ContainerType, + use_proto: Optional[bool] = False, + ) -> List[FieldDefinition]: + raise NeptuneOfflineModeFetchException diff --git a/tests/unit/neptune/new/api/test_models.py b/tests/unit/neptune/new/api/test_models.py index 8e09b1974..2300cb7d6 100644 --- a/tests/unit/neptune/new/api/test_models.py +++ b/tests/unit/neptune/new/api/test_models.py @@ -41,6 +41,18 @@ StringSeriesField, StringSetField, ) +from neptune.api.proto.neptune_pb.api.model.leaderboard_entries_pb2 import ( + ProtoAttributeDTO, + ProtoAttributesDTO, + ProtoBoolAttributeDTO, + ProtoDatetimeAttributeDTO, + ProtoFloatAttributeDTO, + ProtoFloatSeriesAttributeDTO, + ProtoIntAttributeDTO, + ProtoLeaderboardEntriesSearchResultDTO, + ProtoStringAttributeDTO, + ProtoStringSetAttributeDTO, +) def test__float_field__from_dict(): @@ -67,6 +79,22 @@ def test__float_field__from_model(): assert result.value == 18.5 +def test__float_field__from_proto(): + # given + proto = ProtoFloatAttributeDTO( + attribute_name="some/float", + attribute_type="float", + value=18.5, + ) + + # when + result = FloatField.from_proto(proto) + + # then + assert result.path == "some/float" + assert result.value == 18.5 + + def test__int_field__from_dict(): # given data = {"attributeType": "int", "attributeName": "some/int", "value": 18} @@ -91,6 +119,22 @@ def test__int_field__from_model(): assert result.value == 18 +def test__int_field__from_proto(): + # given + proto = ProtoIntAttributeDTO( + attribute_name="some/int", + attribute_type="int", + value=18, + ) + + # when + result = IntField.from_proto(proto) + + # then + assert result.path == "some/int" + assert result.value == 18 + + def test__string_field__from_dict(): # given data = {"attributeType": "string", "attributeName": "some/string", "value": "hello"} @@ -115,6 +159,22 @@ def test__string_field__from_model(): assert result.value == "hello" +def test__string_field__from_proto(): + # given + proto = ProtoStringAttributeDTO( + attribute_name="some/string", + attribute_type="string", + value="hello", + ) + + # when + result = StringField.from_proto(proto) + + # then + assert result.path == "some/string" + assert result.value == "hello" + + def test__string_field__from_dict__empty(): # given data = {"attributeType": "string", "attributeName": "some/string", "value": ""} @@ -139,6 +199,22 @@ def test__string_field__from_model__empty(): assert result.value == "" +def test__string_field__from_proto__empty(): + # given + proto = ProtoStringAttributeDTO( + attribute_name="some/string", + attribute_type="string", + value="", + ) + + # when + result = StringField.from_proto(proto) + + # then + assert result.path == "some/string" + assert result.value == "" + + def test__bool_field__from_dict(): # given data = {"attributeType": "bool", "attributeName": "some/bool", "value": True} @@ -163,6 +239,22 @@ def test__bool_field__from_model(): assert result.value is True +def test__bool_field__from_proto(): + # given + proto = ProtoBoolAttributeDTO( + attribute_name="some/bool", + attribute_type="bool", + value=True, + ) + + # when + result = BoolField.from_proto(proto) + + # then + assert result.path == "some/bool" + assert result.value is True + + def test__datetime_field__from_dict(): # given data = {"attributeType": "datetime", "attributeName": "some/datetime", "value": "2024-01-01T00:12:34.567890Z"} @@ -187,6 +279,22 @@ def test__datetime_field__from_model(): assert result.value == datetime.datetime(2024, 1, 1, 0, 12, 34, 567890) +def test__datetime_field__from_proto(): + # given + at = datetime.datetime(2024, 1, 1, 0, 12, 34, tzinfo=datetime.timezone.utc) + + proto = ProtoDatetimeAttributeDTO( + attribute_name="some/datetime", attribute_type="datetime", value=int(at.timestamp()) + ) + + # when + result = DateTimeField.from_proto(proto) + + # then + assert result.path == "some/datetime" + assert result.value == at + + def test__float_series_field__from_dict(): # given data = { @@ -250,6 +358,38 @@ def test__float_series_field__from_model__no_last(): assert result.last is None +def test__float_series_field__from_proto(): + # given + proto = ProtoFloatSeriesAttributeDTO( + attribute_name="some/floatSeries", + attribute_type="floatSeries", + last=19.5, + ) + + # when + result = FloatSeriesField.from_proto(proto) + + # then + assert result.path == "some/floatSeries" + assert result.last == 19.5 + + +def test__float_series_field__from_proto__no_last(): + # given + proto = ProtoFloatSeriesAttributeDTO( + attribute_name="some/floatSeries", + attribute_type="floatSeries", + last=None, + ) + + # when + result = FloatSeriesField.from_proto(proto) + + # then + assert result.path == "some/floatSeries" + assert result.last is None + + def test__string_series_field__from_dict(): # given data = { @@ -313,6 +453,15 @@ def test__string_series_field__from_model__no_last(): assert result.last is None +def test__string_series_field__from_proto(): + # given + proto = Mock() + + # then + with pytest.raises(NotImplementedError): + StringSeriesField.from_proto(proto) + + def test__image_series_field__from_dict(): # given data = { @@ -376,6 +525,15 @@ def test__image_series_field__from_model__no_last_step(): assert result.last_step is None +def test__image_series_field__from_proto(): + # given + proto = Mock() + + # then + with pytest.raises(NotImplementedError): + ImageSeriesField.from_proto(proto) + + def test__string_set_field__from_dict(): # given data = { @@ -440,6 +598,38 @@ def test__string_set_field__from_model__empty(): assert result.values == set() +def test__string_set_field__from_proto(): + # given + proto = ProtoStringSetAttributeDTO( + attribute_name="some/stringSet", + attribute_type="stringSet", + value=["hello", "world"], + ) + + # when + result = StringSetField.from_proto(proto) + + # then + assert result.path == "some/stringSet" + assert result.values == {"hello", "world"} + + +def test__string_set_field__from_proto__empty(): + # given + proto = ProtoStringSetAttributeDTO( + attribute_name="some/stringSet", + attribute_type="stringSet", + value=[], + ) + + # when + result = StringSetField.from_proto(proto) + + # then + assert result.path == "some/stringSet" + assert result.values == set() + + def test__file_field__from_dict(): # given data = { @@ -480,6 +670,15 @@ def test__file_field__from_model(): assert result.ext == "txt" +def test__file_field__from_proto(): + # given + proto = Mock() + + # then + with pytest.raises(NotImplementedError): + FileField.from_proto(proto) + + @pytest.mark.parametrize("state,expected", [("running", "Active"), ("idle", "Inactive")]) def test__object_state_field__from_dict(state, expected): # given @@ -506,6 +705,16 @@ def test__object_state_field__from_model(state, expected): assert result.value == expected +@pytest.mark.parametrize("state,expected", [("running", "Active"), ("idle", "Inactive")]) +def test__object_state_field__from_proto(state, expected): + # given + model = Mock() + + # then + with pytest.raises(NotImplementedError): + ObjectStateField.from_proto(model) + + def test__file_set_field__from_dict(): # given data = { @@ -538,6 +747,15 @@ def test__file_set_field__from_model(): assert result.size == 3072 +def test__file_set_field__from_proto(): + # given + proto = Mock() + + # then + with pytest.raises(NotImplementedError): + FileSetField.from_proto(proto) + + def test__notebook_ref_field__from_dict(): # given data = { @@ -601,6 +819,24 @@ def test__notebook_ref_field__from_model__no_notebook_name(): assert result.notebook_name is None +def test__notebook_ref_field__from_proto(): + # given + proto = Mock() + + # then + with pytest.raises(NotImplementedError): + NotebookRefField.from_proto(proto) + + +def test__notebook_ref_field__from_proto__no_notebook_name(): + # given + proto = Mock() + + # then + with pytest.raises(NotImplementedError): + NotebookRefField.from_proto(proto) + + def test__git_ref_field__from_dict(): # given data = { @@ -668,6 +904,24 @@ def test__git_ref_field__from_model__no_commit(): assert result.commit is None +def test__git_ref_field__from_proto(): + # given + proto = Mock() + + # then + with pytest.raises(NotImplementedError): + GitRefField.from_proto(proto) + + +def test__git_ref_field__from_proto__no_commit(): + # given + proto = Mock() + + # then + with pytest.raises(NotImplementedError): + GitRefField.from_proto(proto) + + def test__artifact_field__from_dict(): # given data = { @@ -700,6 +954,15 @@ def test__artifact_field__from_model(): assert result.hash == "f192cddb2b98c0b4c72bba22b68d2245" +def test__artifact_field__from_proto(): + # given + proto = Mock() + + # then + with pytest.raises(NotImplementedError): + ArtifactField.from_proto(proto) + + def test__field__from_dict__float(): # given data = { @@ -734,6 +997,27 @@ def test__field__from_model__float(): assert result.value == 18.5 +def test__field__from_proto__float(): + # given + proto = ProtoAttributeDTO( + name="some/float", + type="float", + float_properties=ProtoFloatAttributeDTO( + attribute_name="some/float", + attribute_type="float", + value=18.5, + ), + ) + + # when + result = Field.from_proto(proto) + + # then + assert result.path == "some/float" + assert isinstance(result, FloatField) + assert result.value == 18.5 + + def test__field__from_dict__int(): # given data = { @@ -766,6 +1050,27 @@ def test__field__from_model__int(): assert result.value == 18 +def test__field__from_proto__int(): + # given + proto = ProtoAttributeDTO( + name="some/int", + type="int", + int_properties=ProtoIntAttributeDTO( + attribute_name="some/int", + attribute_type="int", + value=18, + ), + ) + + # when + result = Field.from_proto(proto) + + # then + assert result.path == "some/int" + assert isinstance(result, IntField) + assert result.value == 18 + + def test__field__from_dict__string(): # given data = { @@ -800,6 +1105,27 @@ def test__field__from_model__string(): assert result.value == "hello" +def test__field__from_proto__string(): + # given + proto = ProtoAttributeDTO( + name="some/string", + type="string", + string_properties=ProtoStringAttributeDTO( + attribute_name="some/string", + attribute_type="string", + value="hello", + ), + ) + + # when + result = Field.from_proto(proto) + + # then + assert result.path == "some/string" + assert isinstance(result, StringField) + assert result.value == "hello" + + def test__field__from_dict__bool(): # given data = { @@ -832,6 +1158,27 @@ def test__field__from_model__bool(): assert result.value is True +def test__field__from_proto__bool(): + # given + proto = ProtoAttributeDTO( + name="some/bool", + type="bool", + bool_properties=ProtoBoolAttributeDTO( + attribute_name="some/bool", + attribute_type="bool", + value=True, + ), + ) + + # when + result = Field.from_proto(proto) + + # then + assert result.path == "some/bool" + assert isinstance(result, BoolField) + assert result.value is True + + def test__field__from_dict__datetime(): # given data = { @@ -872,6 +1219,30 @@ def test__field__from_model__datetime(): assert result.value == datetime.datetime(2024, 1, 1, 0, 12, 34, 567890) +def test__field__from_proto__datetime(): + # given + at = datetime.datetime(2021, 1, 1, 0, 12, 34, tzinfo=datetime.timezone.utc) + + # and + proto = ProtoAttributeDTO( + name="some/datetime", + type="datetime", + datetime_properties=ProtoDatetimeAttributeDTO( + attribute_name="some/datetime", + attribute_type="datetime", + value=int(at.timestamp()), + ), + ) + + # when + result = Field.from_proto(proto) + + # then + assert result.path == "some/datetime" + assert isinstance(result, DateTimeField) + assert result.value == at + + def test__field__from_dict__float_series(): # given data = { @@ -906,6 +1277,27 @@ def test__field__from_model__float_series(): assert result.last == 19.5 +def test__field__from_proto__float_series(): + # given + proto = ProtoAttributeDTO( + name="some/floatSeries", + type="floatSeries", + float_series_properties=ProtoFloatSeriesAttributeDTO( + attribute_name="some/floatSeries", + attribute_type="floatSeries", + last=19.5, + ), + ) + + # when + result = Field.from_proto(proto) + + # then + assert result.path == "some/floatSeries" + assert isinstance(result, FloatSeriesField) + assert result.last == 19.5 + + def test__field__from_dict__string_series(): # given data = { @@ -944,6 +1336,15 @@ def test__field__from_model__string_series(): assert result.last == "hello" +def test__field__from_proto__string_series(): + # given + proto = Mock(name="some/stringSeries", type="stringSeries", string_series_properties=Mock()) + + # when + with pytest.raises(NotImplementedError): + Field.from_proto(proto) + + def test__field__from_dict__image_series(): # given data = { @@ -982,6 +1383,15 @@ def test__field__from_model__image_series(): assert result.last_step == 15.0 +def test__field__from_proto__image_series(): + # given + proto = Mock(name="some/imageSeries", type="imageSeries", image_series_properties=Mock()) + + # when + with pytest.raises(NotImplementedError): + Field.from_proto(proto) + + def test__field__from_dict__string_set(): # given data = { @@ -1020,6 +1430,27 @@ def test__field__from_model__string_set(): assert result.values == {"hello", "world"} +def test__field__from_proto__string_set(): + # given + proto = ProtoAttributeDTO( + name="some/stringSet", + type="stringSet", + string_set_properties=ProtoStringSetAttributeDTO( + attribute_name="some/stringSet", + attribute_type="stringSet", + value=["hello", "world"], + ), + ) + + # when + result = Field.from_proto(proto) + + # then + assert result.path == "some/stringSet" + assert isinstance(result, StringSetField) + assert result.values == {"hello", "world"} + + def test__field__from_dict__file(): # given data = { @@ -1065,6 +1496,15 @@ def test__field__from_model__file(): assert result.ext == "txt" +def test__field__from_proto__file(): + # given + proto = Mock(name="some/file", type="file", file_properties=Mock()) + + # then + with pytest.raises(NotImplementedError): + FileField.from_proto(proto) + + def test__field__from_dict__object_state(): # given data = { @@ -1103,6 +1543,15 @@ def test__field__from_model__object_state(): assert result.value == "Active" +def test__field__from_proto__object_state(): + # given + proto = Mock(name="sys/state", type="experimentState", experiment_state_properties=Mock()) + + # when + with pytest.raises(NotImplementedError): + Field.from_proto(proto) + + def test__field__from_dict__file_set(): # given data = { @@ -1137,6 +1586,15 @@ def test__field__from_model__file_set(): assert result.size == 3072 +def test__field__from_proto__file_set(): + # given + proto = Mock(name="some/fileSet", type="fileSet", file_set_properties=Mock()) + + # then + with pytest.raises(NotImplementedError): + FileSetField.from_proto(proto) + + def test__field__from_dict__notebook_ref(): # given data = { @@ -1177,6 +1635,15 @@ def test__field__from_model__notebook_ref(): assert result.notebook_name == "Data Processing.ipynb" +def test__field__from_proto__notebook_ref(): + # given + proto = Mock(name="some/notebookRef", type="notebookRef", notebook_ref_properties=Mock()) + + # then + with pytest.raises(NotImplementedError): + NotebookRefField.from_proto(proto) + + def test__field__from_dict__git_ref(): # given data = { @@ -1215,6 +1682,15 @@ def test__field__from_model__git_ref(): assert result.commit.commit_id == "b2d7f8a" +def test__field__from_proto__git_ref(): + # given + proto = Mock(name="some/gitRef", type="gitRef", git_ref_properties=Mock()) + + # then + with pytest.raises(NotImplementedError): + GitRefField.from_proto(proto) + + def test__field__from_dict__artifact(): # given data = { @@ -1255,6 +1731,15 @@ def test__field__from_model__artifact(): assert result.hash == "f192cddb2b98c0b4c72bba22b68d2245" +def test__field__from_proto__artifact(): + # given + proto = Mock(name="some/artifact", type="artifact", artifact_properties=Mock()) + + # then + with pytest.raises(NotImplementedError): + ArtifactField.from_proto(proto) + + def test__field_definition__from_dict(): # given data = { @@ -1285,6 +1770,21 @@ def test__field_definition__from_model(): assert result.type == FieldType.FLOAT +def test__field_definition__from_proto(): + # given + proto = ProtoAttributeDTO( + name="some/float", + type="float", + ) + + # when + result = FieldDefinition.from_proto(proto) + + # then + assert result.path == "some/float" + assert result.type == FieldType.FLOAT + + def test__leaderboard_entry__from_dict(): # given data = { @@ -1371,6 +1871,62 @@ def test__leaderboard_entry__from_model(): assert string_field.path == "some/string" +def test__leaderboard_entry__from_proto(): + # given + proto = ProtoAttributesDTO( + experiment_id="some-id", + attributes=[ + ProtoAttributeDTO( + name="some/float", + type="float", + float_properties=ProtoFloatAttributeDTO( + attribute_name="some/float", + attribute_type="float", + value=18.5, + ), + ), + ProtoAttributeDTO( + name="some/int", + type="int", + int_properties=ProtoIntAttributeDTO( + attribute_name="some/int", + attribute_type="int", + value=18, + ), + ), + ProtoAttributeDTO( + name="some/string", + type="string", + string_properties=ProtoStringAttributeDTO( + attribute_name="some/string", + attribute_type="string", + value="hello", + ), + ), + ], + ) + + # when + result = LeaderboardEntry.from_proto(proto) + + # then + assert result.object_id == "some-id" + assert len(result.fields) == 3 + + float_field = result.fields[0] + assert isinstance(float_field, FloatField) + assert float_field.path == "some/float" + assert float_field.value == 18.5 + + int_field = result.fields[1] + assert isinstance(int_field, IntField) + assert int_field.path == "some/int" + + string_field = result.fields[2] + assert isinstance(string_field, StringField) + assert string_field.path == "some/string" + + def test__leaderboard_entries_search_result__from_dict(): # given data = { @@ -1463,6 +2019,60 @@ def test__leaderboard_entries_search_result__from_model(): assert isinstance(entry_2.fields[0], IntField) +def test__leaderboard_entries_search_result__from_proto(): + # given + proto = ProtoLeaderboardEntriesSearchResultDTO( + matching_item_count=2, + entries=[ + ProtoAttributesDTO( + experiment_id="some-id-1", + attributes=[ + ProtoAttributeDTO( + name="some/float", + type="float", + float_properties=ProtoFloatAttributeDTO( + attribute_name="some/float", + attribute_type="float", + value=18.5, + ), + ), + ], + ), + ProtoAttributesDTO( + experiment_id="some-id-2", + attributes=[ + ProtoAttributeDTO( + name="some/int", + type="int", + int_properties=ProtoIntAttributeDTO( + attribute_name="some/int", + attribute_type="int", + value=18, + ), + ), + ], + ), + ], + ) + + # when + result = LeaderboardEntriesSearchResult.from_proto(proto) + + # then + assert result.matching_item_count == 2 + assert len(result.entries) == 2 + + entry_1 = result.entries[0] + assert entry_1.object_id == "some-id-1" + assert len(entry_1.fields) == 1 + assert isinstance(entry_1.fields[0], FloatField) + + entry_2 = result.entries[1] + assert entry_2.object_id == "some-id-2" + assert len(entry_2.fields) == 1 + assert isinstance(entry_2.fields[0], IntField) + + @pytest.mark.parametrize("field_type", list(FieldType)) def test__all_field_types__have_class(field_type): # when diff --git a/tests/unit/neptune/new/api/test_searching_entries.py b/tests/unit/neptune/new/api/test_searching_entries.py index c17e3f13d..01e081871 100644 --- a/tests/unit/neptune/new/api/test_searching_entries.py +++ b/tests/unit/neptune/new/api/test_searching_entries.py @@ -13,11 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from typing import ( - Any, - Dict, - Sequence, -) +from typing import Sequence import pytest from bravado.exception import HTTPBadRequest @@ -38,6 +34,8 @@ iter_over_pages, ) from neptune.exceptions import NeptuneInvalidQueryException +from neptune.internal.backends.nql import RawNQLQuery +from neptune.internal.id_formats import UniqueId def test__to_leaderboard_entry(): @@ -81,7 +79,7 @@ def test__to_leaderboard_entry(): def test__iter_over_pages__single_pagination(get_single_page_mock): # given get_single_page_mock.side_effect = [ - {"matchingItemCount": 9}, + LeaderboardEntriesSearchResult(matching_item_count=9, entries=[]), generate_leaderboard_entries(values=["a", "b", "c"]), generate_leaderboard_entries(values=["d", "e", "f"]), generate_leaderboard_entries(values=["g", "h", "j"]), @@ -101,12 +99,7 @@ def test__iter_over_pages__single_pagination(get_single_page_mock): ) # then - assert ( - result - == LeaderboardEntriesSearchResult.from_dict( - generate_leaderboard_entries(values=["a", "b", "c", "d", "e", "f", "g", "h", "j"]) - ).entries - ) + assert result == generate_leaderboard_entries(values=["a", "b", "c", "d", "e", "f", "g", "h", "j"]).entries assert get_single_page_mock.mock_calls == [ # total checking call(limit=0, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), @@ -121,7 +114,7 @@ def test__iter_over_pages__single_pagination(get_single_page_mock): def test__iter_over_pages__multiple_search_after(get_single_page_mock): # given get_single_page_mock.side_effect = [ - {"matchingItemCount": 9}, + LeaderboardEntriesSearchResult(matching_item_count=9, entries=[]), generate_leaderboard_entries(values=["a", "b", "c"]), generate_leaderboard_entries(values=["d", "e", "f"]), generate_leaderboard_entries(values=["g", "h", "j"]), @@ -142,12 +135,7 @@ def test__iter_over_pages__multiple_search_after(get_single_page_mock): ) # then - assert ( - result - == LeaderboardEntriesSearchResult.from_dict( - generate_leaderboard_entries(values=["a", "b", "c", "d", "e", "f", "g", "h", "j"]) - ).entries - ) + assert result == generate_leaderboard_entries(values=["a", "b", "c", "d", "e", "f", "g", "h", "j"]).entries assert get_single_page_mock.mock_calls == [ # total checking call(limit=0, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), @@ -162,7 +150,7 @@ def test__iter_over_pages__multiple_search_after(get_single_page_mock): def test__iter_over_pages__empty(get_single_page_mock): # given get_single_page_mock.side_effect = [ - {"matchingItemCount": 0}, + LeaderboardEntriesSearchResult(matching_item_count=0, entries=[]), generate_leaderboard_entries(values=[]), ] @@ -191,7 +179,7 @@ def test__iter_over_pages__empty(get_single_page_mock): def test__iter_over_pages__max_server_offset(get_single_page_mock): # given get_single_page_mock.side_effect = [ - {"matchingItemCount": 5}, + LeaderboardEntriesSearchResult(matching_item_count=5, entries=[]), generate_leaderboard_entries(values=["a", "b", "c"]), generate_leaderboard_entries(values=["d", "e"]), generate_leaderboard_entries(values=[]), @@ -211,12 +199,7 @@ def test__iter_over_pages__max_server_offset(get_single_page_mock): ) # then - assert ( - result - == LeaderboardEntriesSearchResult.from_dict( - generate_leaderboard_entries(values=["a", "b", "c", "d", "e"]) - ).entries - ) + assert result == generate_leaderboard_entries(values=["a", "b", "c", "d", "e"]).entries assert get_single_page_mock.mock_calls == [ # total checking call(limit=0, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), @@ -233,7 +216,7 @@ def test__iter_over_pages__limit(get_single_page_mock): # given get_single_page_mock.side_effect = [ - {"matchingItemCount": 5}, + LeaderboardEntriesSearchResult(matching_item_count=5, entries=[]), generate_leaderboard_entries(values=["a", "b"]), generate_leaderboard_entries(values=["c", "d"]), generate_leaderboard_entries(values=["e"]), @@ -261,27 +244,19 @@ def test__iter_over_pages__limit(get_single_page_mock): ] -def generate_leaderboard_entries(values: Sequence, experiment_id: str = "foo") -> Dict[str, Any]: - return { - "matchingItemCount": len(values), - "entries": [ - { - "experimentId": f"{experiment_id}-{value}", - "attributes": [ - { - "name": "sys/id", - "type": "string", - "stringProperties": { - "attributeName": "sys/id", - "attributeType": "string", - "value": value, - }, - }, +def generate_leaderboard_entries(values: Sequence, experiment_id: str = "foo") -> LeaderboardEntriesSearchResult: + return LeaderboardEntriesSearchResult( + matching_item_count=len(values), + entries=[ + LeaderboardEntry( + object_id=f"{experiment_id}-{value}", + fields=[ + StringField(path="sys/id", value=value), ], - } + ) for value in values ], - } + ) @patch("neptune.api.searching_entries.construct_request") @@ -296,15 +271,15 @@ def test_get_single_page_error_handling(construct_request_mock): # then with pytest.raises(NeptuneInvalidQueryException): get_single_page( - project_id="id", + project_id=UniqueId("id"), attributes_filter={}, types=None, - query="invalid_query", + query=RawNQLQuery("invalid_query"), limit=0, offset=0, sort_by="sys/id", ascending=False, - sort_by_column_type=None, + sort_by_column_type="string", searching_after=None, client=failing_clinet, ) From 8a308d97a8989e15e940967fddf848662915c5bd Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Tue, 9 Apr 2024 13:21:41 +0200 Subject: [PATCH 5/8] Unittests skiped and minor adjustments --- src/neptune/api/models.py | 9 +++++---- .../backends/hosted_neptune_backend.py | 2 ++ .../internal/backends/neptune_backend.py | 1 + .../internal/backends/neptune_backend_mock.py | 1 + .../backends/offline_neptune_backend.py | 19 +++++++++++++++++++ .../neptune/new/internal/utils/test_images.py | 2 ++ 6 files changed, 30 insertions(+), 4 deletions(-) diff --git a/src/neptune/api/models.py b/src/neptune/api/models.py index 71ccd3ddb..b365259b6 100644 --- a/src/neptune/api/models.py +++ b/src/neptune/api/models.py @@ -137,7 +137,8 @@ def from_dict(data: Dict[str, Any]) -> Field: @staticmethod def from_model(model: Any) -> Field: field_type = str(model.type) - return Field._registry[field_type].from_model(model.__getattribute__(f"{field_type}Properties")) + print(model) + return Field._registry[field_type].from_model(model[f"{field_type}Properties"]) @staticmethod def from_proto(data: Any) -> Field: @@ -300,9 +301,9 @@ def from_model(model: Any) -> DateTimeField: @staticmethod def from_proto(data: ProtoDatetimeAttributeDTO) -> DateTimeField: - # TODO: Ensure that the timestamp is in UTC - # TODO: Ensure that we are supporting seconds and miliseconds - return DateTimeField(path=data.attribute_name, value=datetime.fromtimestamp(data.value, tz=timezone.utc)) + return DateTimeField( + path=data.attribute_name, value=datetime.fromtimestamp(data.value / 1000.0, tz=timezone.utc) + ) @dataclass diff --git a/src/neptune/internal/backends/hosted_neptune_backend.py b/src/neptune/internal/backends/hosted_neptune_backend.py index 0fb69ad9c..71e1b77ab 100644 --- a/src/neptune/internal/backends/hosted_neptune_backend.py +++ b/src/neptune/internal/backends/hosted_neptune_backend.py @@ -1065,6 +1065,7 @@ def search_leaderboard_entries( ascending: bool = False, progress_bar: Optional[ProgressBarType] = None, step_size: Optional[int] = None, + use_proto: Optional[bool] = False, ) -> Generator[LeaderboardEntry, None, None]: default_step_size = step_size or int(os.getenv(NEPTUNE_FETCH_TABLE_STEP_SIZE, "100")) @@ -1094,6 +1095,7 @@ def search_leaderboard_entries( ascending=ascending, sort_by_column_type=sort_by_column_type, progress_bar=progress_bar, + use_proto=use_proto, ) except HTTPNotFound: raise ProjectNotFound(project_id) diff --git a/src/neptune/internal/backends/neptune_backend.py b/src/neptune/internal/backends/neptune_backend.py index b87f6557e..c96444f41 100644 --- a/src/neptune/internal/backends/neptune_backend.py +++ b/src/neptune/internal/backends/neptune_backend.py @@ -328,6 +328,7 @@ def search_leaderboard_entries( sort_by: str = "sys/creation_time", ascending: bool = False, progress_bar: Optional[ProgressBarType] = None, + use_proto: Optional[bool] = False, ) -> Generator[LeaderboardEntry, None, None]: pass diff --git a/src/neptune/internal/backends/neptune_backend_mock.py b/src/neptune/internal/backends/neptune_backend_mock.py index 5e145bb79..2c0bf8685 100644 --- a/src/neptune/internal/backends/neptune_backend_mock.py +++ b/src/neptune/internal/backends/neptune_backend_mock.py @@ -561,6 +561,7 @@ def search_leaderboard_entries( sort_by: str = "sys/creation_time", ascending: bool = False, progress_bar: Optional[ProgressBarType] = None, + use_proto: Optional[bool] = False, ) -> Generator[LeaderboardEntry, None, None]: """Non relevant for mock""" diff --git a/src/neptune/internal/backends/offline_neptune_backend.py b/src/neptune/internal/backends/offline_neptune_backend.py index 032787c04..a73661c90 100644 --- a/src/neptune/internal/backends/offline_neptune_backend.py +++ b/src/neptune/internal/backends/offline_neptune_backend.py @@ -16,6 +16,8 @@ __all__ = ["OfflineNeptuneBackend"] from typing import ( + Generator, + Iterable, List, Optional, ) @@ -31,6 +33,7 @@ FloatField, FloatSeriesField, IntField, + LeaderboardEntry, StringField, StringSeriesField, StringSetField, @@ -43,7 +46,9 @@ StringSeriesValues, ) from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock +from neptune.internal.backends.nql import NQLQuery from neptune.internal.container_type import ContainerType +from neptune.internal.id_formats import UniqueId from neptune.typing import ProgressBarType @@ -152,3 +157,17 @@ def get_fields_definitions( use_proto: Optional[bool] = False, ) -> List[FieldDefinition]: raise NeptuneOfflineModeFetchException + + def search_leaderboard_entries( + self, + project_id: UniqueId, + types: Optional[Iterable[ContainerType]] = None, + query: Optional[NQLQuery] = None, + columns: Optional[Iterable[str]] = None, + limit: Optional[int] = None, + sort_by: str = "sys/creation_time", + ascending: bool = False, + progress_bar: Optional[ProgressBarType] = None, + use_proto: Optional[bool] = False, + ) -> Generator[LeaderboardEntry, None, None]: + raise NeptuneOfflineModeFetchException diff --git a/tests/unit/neptune/new/internal/utils/test_images.py b/tests/unit/neptune/new/internal/utils/test_images.py index 89fa28f58..7712f4d41 100644 --- a/tests/unit/neptune/new/internal/utils/test_images.py +++ b/tests/unit/neptune/new/internal/utils/test_images.py @@ -27,6 +27,7 @@ import numpy import pandas import plotly.express as px +import pytest import seaborn as sns from bokeh.plotting import figure from matplotlib import pyplot @@ -136,6 +137,7 @@ def test_get_image_content_from_torch_tensor(self): # and make sure that original image's size was preserved self.assertFalse((image_tensor.numpy() * 255 - expected_array).any()) + @pytest.mark.skip("Conflicts with protobuf version") def test_get_image_content_from_tensorflow_tensor(self): import tensorflow as tf From 8ec5ea6bba298050b2a7586bd3604c640b77a85a Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Tue, 9 Apr 2024 13:23:04 +0200 Subject: [PATCH 6/8] TODOs removed --- src/neptune/api/models.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/neptune/api/models.py b/src/neptune/api/models.py index b365259b6..8db218ae1 100644 --- a/src/neptune/api/models.py +++ b/src/neptune/api/models.py @@ -325,7 +325,6 @@ def from_model(model: Any) -> FileField: @staticmethod def from_proto(data: Any) -> FileField: - # TODO: implement raise NotImplementedError() @@ -346,7 +345,6 @@ def from_model(model: Any) -> FileSetField: @staticmethod def from_proto(data: Any) -> FileSetField: - # TODO: implement raise NotImplementedError() @@ -390,7 +388,6 @@ def from_model(model: Any) -> StringSeriesField: @staticmethod def from_proto(data: Any) -> StringSeriesField: - # TODO: implement raise NotImplementedError() @@ -412,7 +409,6 @@ def from_model(model: Any) -> ImageSeriesField: @staticmethod def from_proto(data: Any) -> ImageSeriesField: - # TODO: implement raise NotImplementedError() @@ -451,7 +447,6 @@ def from_model(model: Any) -> GitCommit: @staticmethod def from_proto(data: Any) -> GitCommit: - # TODO: implement raise NotImplementedError() @@ -474,7 +469,6 @@ def from_model(model: Any) -> GitRefField: @staticmethod def from_proto(data: ProtoAttributeDTO) -> GitRefField: - # TODO: implement raise NotImplementedError() @@ -497,7 +491,6 @@ def from_model(model: Any) -> ObjectStateField: @staticmethod def from_proto(data: Any) -> ObjectStateField: - # TODO: implement raise NotImplementedError() @@ -519,7 +512,6 @@ def from_model(model: Any) -> NotebookRefField: @staticmethod def from_proto(data: Any) -> NotebookRefField: - # TODO: implement raise NotImplementedError() @@ -540,7 +532,6 @@ def from_model(model: Any) -> ArtifactField: @staticmethod def from_proto(data: Any) -> ArtifactField: - # TODO: implement raise NotImplementedError() From 6187b23c2d5a73ee34854cf170cf1a57d25813ed Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Tue, 9 Apr 2024 13:55:00 +0200 Subject: [PATCH 7/8] Environment variable and removed debug --- src/neptune/api/models.py | 1 - src/neptune/api/searching_entries.py | 2 +- src/neptune/envs.py | 3 +++ .../internal/backends/hosted_neptune_backend.py | 17 +++++++++++++---- .../internal/backends/neptune_backend.py | 6 +++--- .../internal/backends/neptune_backend_mock.py | 6 +++--- .../backends/offline_neptune_backend.py | 6 +++--- 7 files changed, 26 insertions(+), 15 deletions(-) diff --git a/src/neptune/api/models.py b/src/neptune/api/models.py index 8db218ae1..18c7cf1ee 100644 --- a/src/neptune/api/models.py +++ b/src/neptune/api/models.py @@ -137,7 +137,6 @@ def from_dict(data: Dict[str, Any]) -> Field: @staticmethod def from_model(model: Any) -> Field: field_type = str(model.type) - print(model) return Field._registry[field_type].from_model(model[f"{field_type}Properties"]) @staticmethod diff --git a/src/neptune/api/searching_entries.py b/src/neptune/api/searching_entries.py index ee405c475..0f7bc9109 100644 --- a/src/neptune/api/searching_entries.py +++ b/src/neptune/api/searching_entries.py @@ -98,7 +98,7 @@ def get_single_page( types: Optional[Iterable[str]], query: Optional["NQLQuery"], searching_after: Optional[str], - use_proto: Optional[bool] = False, + use_proto: Optional[bool] = None, ) -> LeaderboardEntriesSearchResult: normalized_query = query or NQLEmptyQuery() sort_by_column_type = sort_by_column_type if sort_by_column_type else FieldType.STRING.value diff --git a/src/neptune/envs.py b/src/neptune/envs.py index 3d717ac30..bb210c5ff 100644 --- a/src/neptune/envs.py +++ b/src/neptune/envs.py @@ -32,6 +32,7 @@ "NEPTUNE_RAISE_ERROR_ON_DISK_USAGE_EXCEEDED", "NEPTUNE_ENABLE_DEFAULT_ASYNC_LAG_CALLBACK", "NEPTUNE_ENABLE_DEFAULT_ASYNC_NO_PROGRESS_CALLBACK", + "NEPTUNE_USE_PROTOCOL_BUFFERS", "NEPTUNE_ASYNC_BATCH_SIZE", ] @@ -74,4 +75,6 @@ NEPTUNE_ASYNC_BATCH_SIZE = "NEPTUNE_ASYNC_BATCH_SIZE" +NEPTUNE_USE_PROTOCOL_BUFFERS = "NEPTUNE_USE_PROTOCOL_BUFFERS" + S3_ENDPOINT_URL = "S3_ENDPOINT_URL" diff --git a/src/neptune/internal/backends/hosted_neptune_backend.py b/src/neptune/internal/backends/hosted_neptune_backend.py index 71e1b77ab..4313923a7 100644 --- a/src/neptune/internal/backends/hosted_neptune_backend.py +++ b/src/neptune/internal/backends/hosted_neptune_backend.py @@ -59,7 +59,10 @@ from neptune.api.proto.neptune_pb.api.model.leaderboard_entries_pb2 import ProtoAttributesDTO from neptune.api.searching_entries import iter_over_pages from neptune.core.components.operation_storage import OperationStorage -from neptune.envs import NEPTUNE_FETCH_TABLE_STEP_SIZE +from neptune.envs import ( + NEPTUNE_FETCH_TABLE_STEP_SIZE, + NEPTUNE_USE_PROTOCOL_BUFFERS, +) from neptune.exceptions import ( AmbiguousProjectName, ContainerUUIDNotFound, @@ -174,6 +177,7 @@ def __init__(self, credentials: Credentials, proxies: Optional[Dict[str, str]] = self.credentials = credentials self.proxies = proxies self.missing_features = [] + self.use_proto = os.getenv(NEPTUNE_USE_PROTOCOL_BUFFERS, "False").lower() in {"true", "1", "y"} http_client, client_config = create_http_client_with_auth( credentials=credentials, ssl_verify=ssl_verify(), proxies=proxies @@ -1065,8 +1069,9 @@ def search_leaderboard_entries( ascending: bool = False, progress_bar: Optional[ProgressBarType] = None, step_size: Optional[int] = None, - use_proto: Optional[bool] = False, + use_proto: Optional[bool] = None, ) -> Generator[LeaderboardEntry, None, None]: + use_proto = use_proto if use_proto is not None else self.use_proto default_step_size = step_size or int(os.getenv(NEPTUNE_FETCH_TABLE_STEP_SIZE, "100")) step_size = min(default_step_size, limit) if limit else default_step_size @@ -1127,8 +1132,10 @@ def get_fields_definitions( self, container_id: str, container_type: ContainerType, - use_proto: Optional[bool] = False, # TODO: Use environment variable instead + use_proto: Optional[bool] = None, ) -> List[FieldDefinition]: + use_proto = use_proto if use_proto is not None else self.use_proto + params = { "experimentIdentifier": container_id, **DEFAULT_REQUEST_KWARGS, @@ -1149,8 +1156,10 @@ def get_fields_definitions( ) from e def get_fields_with_paths_filter( - self, container_id: str, container_type: ContainerType, paths: List[str], use_proto: Optional[bool] = False + self, container_id: str, container_type: ContainerType, paths: List[str], use_proto: Optional[bool] = None ) -> List[Field]: + use_proto = use_proto if use_proto is not None else self.use_proto + params = { "holderIdentifier": container_id, "holderType": "experiment", diff --git a/src/neptune/internal/backends/neptune_backend.py b/src/neptune/internal/backends/neptune_backend.py index c96444f41..c89770a07 100644 --- a/src/neptune/internal/backends/neptune_backend.py +++ b/src/neptune/internal/backends/neptune_backend.py @@ -302,13 +302,13 @@ def get_fields_definitions( self, container_id: str, container_type: ContainerType, - use_proto: Optional[bool] = False, + use_proto: Optional[bool] = None, ) -> List[FieldDefinition]: ... # WARN: Used in Neptune Fetcher @abc.abstractmethod def get_fields_with_paths_filter( - self, container_id: str, container_type: ContainerType, paths: List[str], use_proto: Optional[bool] = False + self, container_id: str, container_type: ContainerType, paths: List[str], use_proto: Optional[bool] = None ) -> List[Field]: ... @abc.abstractmethod @@ -328,7 +328,7 @@ def search_leaderboard_entries( sort_by: str = "sys/creation_time", ascending: bool = False, progress_bar: Optional[ProgressBarType] = None, - use_proto: Optional[bool] = False, + use_proto: Optional[bool] = None, ) -> Generator[LeaderboardEntry, None, None]: pass diff --git a/src/neptune/internal/backends/neptune_backend_mock.py b/src/neptune/internal/backends/neptune_backend_mock.py index 2c0bf8685..b20ce11bc 100644 --- a/src/neptune/internal/backends/neptune_backend_mock.py +++ b/src/neptune/internal/backends/neptune_backend_mock.py @@ -519,7 +519,7 @@ def get_fields_definitions( self, container_id: str, container_type: ContainerType, - use_proto: Optional[bool] = False, + use_proto: Optional[bool] = None, ) -> List[FieldDefinition]: return [] @@ -561,7 +561,7 @@ def search_leaderboard_entries( sort_by: str = "sys/creation_time", ascending: bool = False, progress_bar: Optional[ProgressBarType] = None, - use_proto: Optional[bool] = False, + use_proto: Optional[bool] = None, ) -> Generator[LeaderboardEntry, None, None]: """Non relevant for mock""" @@ -799,6 +799,6 @@ def list_fileset_files(self, attribute: List[str], container_id: str, path: str) ] def get_fields_with_paths_filter( - self, container_id: str, container_type: ContainerType, paths: List[str], use_proto: Optional[bool] = False + self, container_id: str, container_type: ContainerType, paths: List[str], use_proto: Optional[bool] = None ) -> List[Field]: return [] diff --git a/src/neptune/internal/backends/offline_neptune_backend.py b/src/neptune/internal/backends/offline_neptune_backend.py index a73661c90..400de4600 100644 --- a/src/neptune/internal/backends/offline_neptune_backend.py +++ b/src/neptune/internal/backends/offline_neptune_backend.py @@ -146,7 +146,7 @@ def list_fileset_files(self, attribute: List[str], container_id: str, path: str) raise NeptuneOfflineModeFetchException def get_fields_with_paths_filter( - self, container_id: str, container_type: ContainerType, paths: List[str], use_proto: Optional[bool] = False + self, container_id: str, container_type: ContainerType, paths: List[str], use_proto: Optional[bool] = None ) -> List[Field]: raise NeptuneOfflineModeFetchException @@ -154,7 +154,7 @@ def get_fields_definitions( self, container_id: str, container_type: ContainerType, - use_proto: Optional[bool] = False, + use_proto: Optional[bool] = None, ) -> List[FieldDefinition]: raise NeptuneOfflineModeFetchException @@ -168,6 +168,6 @@ def search_leaderboard_entries( sort_by: str = "sys/creation_time", ascending: bool = False, progress_bar: Optional[ProgressBarType] = None, - use_proto: Optional[bool] = False, + use_proto: Optional[bool] = None, ) -> Generator[LeaderboardEntry, None, None]: raise NeptuneOfflineModeFetchException From 64ad22a8ef58d45bf209f0746172b1269e6afc26 Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Tue, 9 Apr 2024 14:00:59 +0200 Subject: [PATCH 8/8] Test fixes and field lookup --- src/neptune/api/models.py | 2 +- tests/unit/neptune/new/api/test_models.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/neptune/api/models.py b/src/neptune/api/models.py index 18c7cf1ee..59e6774fc 100644 --- a/src/neptune/api/models.py +++ b/src/neptune/api/models.py @@ -137,7 +137,7 @@ def from_dict(data: Dict[str, Any]) -> Field: @staticmethod def from_model(model: Any) -> Field: field_type = str(model.type) - return Field._registry[field_type].from_model(model[f"{field_type}Properties"]) + return Field._registry[field_type].from_model(model.__getattr__(f"{field_type}Properties")) @staticmethod def from_proto(data: Any) -> Field: diff --git a/tests/unit/neptune/new/api/test_models.py b/tests/unit/neptune/new/api/test_models.py index 2300cb7d6..a941027a0 100644 --- a/tests/unit/neptune/new/api/test_models.py +++ b/tests/unit/neptune/new/api/test_models.py @@ -281,10 +281,10 @@ def test__datetime_field__from_model(): def test__datetime_field__from_proto(): # given - at = datetime.datetime(2024, 1, 1, 0, 12, 34, tzinfo=datetime.timezone.utc) + at = datetime.datetime(2024, 1, 1, 0, 12, 34, 123000, tzinfo=datetime.timezone.utc) proto = ProtoDatetimeAttributeDTO( - attribute_name="some/datetime", attribute_type="datetime", value=int(at.timestamp()) + attribute_name="some/datetime", attribute_type="datetime", value=int(at.timestamp() * 1000) ) # when @@ -1221,7 +1221,7 @@ def test__field__from_model__datetime(): def test__field__from_proto__datetime(): # given - at = datetime.datetime(2021, 1, 1, 0, 12, 34, tzinfo=datetime.timezone.utc) + at = datetime.datetime(2021, 1, 1, 0, 12, 34, 123000, tzinfo=datetime.timezone.utc) # and proto = ProtoAttributeDTO( @@ -1230,7 +1230,7 @@ def test__field__from_proto__datetime(): datetime_properties=ProtoDatetimeAttributeDTO( attribute_name="some/datetime", attribute_type="datetime", - value=int(at.timestamp()), + value=int(at.timestamp() * 1000), ), )