From 1721467a15718ed2355941103b0bc6379c525709 Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Mon, 1 Apr 2024 19:40:51 +0200 Subject: [PATCH 01/22] AttributeDefinition --- src/neptune/internal/backends/api_model.py | 4 ++-- .../internal/backends/hosted_neptune_backend.py | 8 ++++---- src/neptune/internal/backends/neptune_backend.py | 4 ++-- src/neptune/internal/backends/neptune_backend_mock.py | 8 ++++---- .../internal/backends/offline_neptune_backend.py | 4 ++-- tests/unit/neptune/new/client/abstract_tables_test.py | 4 ++-- tests/unit/neptune/new/client/test_model.py | 6 +++--- tests/unit/neptune/new/client/test_model_version.py | 10 +++++----- tests/unit/neptune/new/client/test_project.py | 4 ++-- tests/unit/neptune/new/client/test_run.py | 6 +++--- 10 files changed, 29 insertions(+), 29 deletions(-) diff --git a/src/neptune/internal/backends/api_model.py b/src/neptune/internal/backends/api_model.py index 1116de3a5..1d9927b6b 100644 --- a/src/neptune/internal/backends/api_model.py +++ b/src/neptune/internal/backends/api_model.py @@ -21,7 +21,7 @@ "VersionInfo", "ClientConfig", "AttributeType", - "Attribute", + "AttributeDefinition", "AttributeWithProperties", "LeaderboardEntry", "StringPointValue", @@ -218,7 +218,7 @@ class AttributeType(Enum): @dataclass -class Attribute: +class AttributeDefinition: path: str type: AttributeType diff --git a/src/neptune/internal/backends/hosted_neptune_backend.py b/src/neptune/internal/backends/hosted_neptune_backend.py index 1a51b930f..fd365c938 100644 --- a/src/neptune/internal/backends/hosted_neptune_backend.py +++ b/src/neptune/internal/backends/hosted_neptune_backend.py @@ -59,7 +59,7 @@ from neptune.internal.backends.api_model import ( ApiExperiment, ArtifactAttribute, - Attribute, + AttributeDefinition, AttributeType, BoolAttribute, DatetimeAttribute, @@ -682,9 +682,9 @@ def _execute_operations( raise NeptuneLimitExceedException(reason=e.response.json().get("title", "Unknown reason")) from e @with_api_exceptions_handler - def get_attributes(self, container_id: str, container_type: ContainerType) -> List[Attribute]: - def to_attribute(attr) -> Attribute: - return Attribute(attr.name, AttributeType(attr.type)) + def get_attributes(self, container_id: str, container_type: ContainerType) -> List[AttributeDefinition]: + def to_attribute(attr) -> AttributeDefinition: + return AttributeDefinition(attr.name, AttributeType(attr.type)) params = { "experimentId": container_id, diff --git a/src/neptune/internal/backends/neptune_backend.py b/src/neptune/internal/backends/neptune_backend.py index b282042bb..dfcba3454 100644 --- a/src/neptune/internal/backends/neptune_backend.py +++ b/src/neptune/internal/backends/neptune_backend.py @@ -31,7 +31,7 @@ from neptune.internal.backends.api_model import ( ApiExperiment, ArtifactAttribute, - Attribute, + AttributeDefinition, AttributeType, BoolAttribute, DatetimeAttribute, @@ -146,7 +146,7 @@ def execute_operations( pass @abc.abstractmethod - def get_attributes(self, container_id: str, container_type: ContainerType) -> List[Attribute]: + def get_attributes(self, container_id: str, container_type: ContainerType) -> List[AttributeDefinition]: pass @abc.abstractmethod diff --git a/src/neptune/internal/backends/neptune_backend_mock.py b/src/neptune/internal/backends/neptune_backend_mock.py index 4ff41e88e..e415e26f5 100644 --- a/src/neptune/internal/backends/neptune_backend_mock.py +++ b/src/neptune/internal/backends/neptune_backend_mock.py @@ -47,7 +47,7 @@ from neptune.internal.backends.api_model import ( ApiExperiment, ArtifactAttribute, - Attribute, + AttributeDefinition, AttributeType, BoolAttribute, DatetimeAttribute, @@ -311,7 +311,7 @@ def _execute_operation( else: run.pop(op.path) - def get_attributes(self, container_id: str, container_type: ContainerType) -> List[Attribute]: + def get_attributes(self, container_id: str, container_type: ContainerType) -> List[AttributeDefinition]: run = self._get_container(container_id, container_type) return list(self._generate_attributes(None, run.get_structure())) @@ -321,7 +321,7 @@ def _generate_attributes(self, base_path: Optional[str], values: dict): if isinstance(value_or_dict, dict): yield from self._generate_attributes(new_path, value_or_dict) else: - yield Attribute( + yield AttributeDefinition( new_path, value_or_dict.accept(self._attribute_type_converter_value_visitor), ) @@ -597,7 +597,7 @@ def visit_artifact(self, _: Artifact) -> AttributeType: def visit_namespace(self, _: Namespace) -> AttributeType: raise NotImplementedError - def copy_value(self, source_type: Type[Attribute], source_path: List[str]) -> AttributeType: + def copy_value(self, source_type: Type[AttributeDefinition], source_path: List[str]) -> AttributeType: raise NotImplementedError class NewValueOpVisitor(OperationVisitor[Optional[Value]]): diff --git a/src/neptune/internal/backends/offline_neptune_backend.py b/src/neptune/internal/backends/offline_neptune_backend.py index f5d2589bf..d333716c2 100644 --- a/src/neptune/internal/backends/offline_neptune_backend.py +++ b/src/neptune/internal/backends/offline_neptune_backend.py @@ -25,7 +25,7 @@ from neptune.internal.artifacts.types import ArtifactFileData from neptune.internal.backends.api_model import ( ArtifactAttribute, - Attribute, + AttributeDefinition, BoolAttribute, DatetimeAttribute, FileAttribute, @@ -47,7 +47,7 @@ class OfflineNeptuneBackend(NeptuneBackendMock): WORKSPACE_NAME = "offline" - def get_attributes(self, container_id: str, container_type: ContainerType) -> List[Attribute]: + def get_attributes(self, container_id: str, container_type: ContainerType) -> List[AttributeDefinition]: raise NeptuneOfflineModeFetchException def get_float_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FloatAttribute: diff --git a/tests/unit/neptune/new/client/abstract_tables_test.py b/tests/unit/neptune/new/client/abstract_tables_test.py index 1a5101b8d..b4a0ca8c4 100644 --- a/tests/unit/neptune/new/client/abstract_tables_test.py +++ b/tests/unit/neptune/new/client/abstract_tables_test.py @@ -29,7 +29,7 @@ ) from neptune.exceptions import MetadataInconsistency from neptune.internal.backends.api_model import ( - Attribute, + AttributeDefinition, AttributeType, AttributeWithProperties, LeaderboardEntry, @@ -43,7 +43,7 @@ @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_attributes", - new=lambda _, _uuid, _type: [Attribute(path="test", type=AttributeType.STRING)], + new=lambda _, _uuid, _type: [AttributeDefinition(path="test", type=AttributeType.STRING)], ) @patch("neptune.internal.backends.factory.HostedNeptuneBackend", NeptuneBackendMock) class AbstractTablesTestMixin: diff --git a/tests/unit/neptune/new/client/test_model.py b/tests/unit/neptune/new/client/test_model.py index d34982c9c..b337b7138 100644 --- a/tests/unit/neptune/new/client/test_model.py +++ b/tests/unit/neptune/new/client/test_model.py @@ -33,7 +33,7 @@ NeptuneWrongInitParametersException, ) from neptune.internal.backends.api_model import ( - Attribute, + AttributeDefinition, AttributeType, IntAttribute, ) @@ -77,7 +77,7 @@ def test_offline_mode(self): ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_attributes", - new=lambda _, _uuid, _type: [Attribute("some/variable", AttributeType.INT)], + new=lambda _, _uuid, _type: [AttributeDefinition("some/variable", AttributeType.INT)], ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_int_attribute", @@ -102,7 +102,7 @@ def test_read_only_mode(self, warn_once): ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_attributes", - new=lambda _, _uuid, _type: [Attribute("test", AttributeType.STRING)], + new=lambda _, _uuid, _type: [AttributeDefinition("test", AttributeType.STRING)], ) def test_resume(self): with init_model(flush_period=0.5, with_id="whatever") as exp: diff --git a/tests/unit/neptune/new/client/test_model_version.py b/tests/unit/neptune/new/client/test_model_version.py index 9583dcadd..5ef9cceea 100644 --- a/tests/unit/neptune/new/client/test_model_version.py +++ b/tests/unit/neptune/new/client/test_model_version.py @@ -34,7 +34,7 @@ NeptuneWrongInitParametersException, ) from neptune.internal.backends.api_model import ( - Attribute, + AttributeDefinition, AttributeType, IntAttribute, StringAttribute, @@ -87,8 +87,8 @@ def test_offline_mode(self): @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_attributes", new=lambda _, _uuid, _type: [ - Attribute("some/variable", AttributeType.INT), - Attribute("sys/model_id", AttributeType.STRING), + AttributeDefinition("some/variable", AttributeType.INT), + AttributeDefinition("sys/model_id", AttributeType.STRING), ], ) @patch( @@ -115,8 +115,8 @@ def test_read_only_mode(self, warn_once): @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_attributes", new=lambda _, _uuid, _type: [ - Attribute("test", AttributeType.STRING), - Attribute("sys/model_id", AttributeType.STRING), + AttributeDefinition("test", AttributeType.STRING), + AttributeDefinition("sys/model_id", AttributeType.STRING), ], ) @patch( diff --git a/tests/unit/neptune/new/client/test_project.py b/tests/unit/neptune/new/client/test_project.py index 41c8d81df..cba55ad4b 100644 --- a/tests/unit/neptune/new/client/test_project.py +++ b/tests/unit/neptune/new/client/test_project.py @@ -33,7 +33,7 @@ NeptuneUnsupportedFunctionalityException, ) from neptune.internal.backends.api_model import ( - Attribute, + AttributeDefinition, AttributeType, AttributeWithProperties, IntAttribute, @@ -54,7 +54,7 @@ @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_attributes", - new=lambda _, _uuid, _type: [Attribute("test", AttributeType.STRING)], + new=lambda _, _uuid, _type: [AttributeDefinition("test", AttributeType.STRING)], ) @patch("neptune.internal.backends.factory.HostedNeptuneBackend", NeptuneBackendMock) class TestClientProject(AbstractExperimentTestMixin, unittest.TestCase): diff --git a/tests/unit/neptune/new/client/test_run.py b/tests/unit/neptune/new/client/test_run.py index e515f2d85..c8e3ff9f5 100644 --- a/tests/unit/neptune/new/client/test_run.py +++ b/tests/unit/neptune/new/client/test_run.py @@ -34,7 +34,7 @@ ) from neptune.exceptions import MissingFieldException from neptune.internal.backends.api_model import ( - Attribute, + AttributeDefinition, AttributeType, IntAttribute, ) @@ -68,7 +68,7 @@ def setUpClass(cls) -> None: ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_attributes", - new=lambda _, _uuid, _type: [Attribute("some/variable", AttributeType.INT)], + new=lambda _, _uuid, _type: [AttributeDefinition("some/variable", AttributeType.INT)], ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_int_attribute", @@ -93,7 +93,7 @@ def test_read_only_mode(self, warn_once): ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_attributes", - new=lambda _, _uuid, _type: [Attribute("test", AttributeType.STRING)], + new=lambda _, _uuid, _type: [AttributeDefinition("test", AttributeType.STRING)], ) def test_resume(self): with init_run(flush_period=0.5, with_id="whatever") as exp: From a5da872e40e92c5d1ffd9494f80c63d32eaa000f Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Mon, 1 Apr 2024 19:43:21 +0200 Subject: [PATCH 02/22] Field and FieldDefinition --- src/neptune/api/searching_entries.py | 6 ++-- src/neptune/integrations/pandas/__init__.py | 4 +-- src/neptune/internal/backends/api_model.py | 10 +++---- .../backends/hosted_neptune_backend.py | 8 +++--- .../internal/backends/neptune_backend.py | 4 +-- .../internal/backends/neptune_backend_mock.py | 8 +++--- .../backends/offline_neptune_backend.py | 4 +-- src/neptune/objects/utils.py | 4 +-- src/neptune/table.py | 4 +-- .../neptune/new/api/test_searching_entries.py | 8 +++--- .../new/client/abstract_tables_test.py | 28 +++++++++---------- tests/unit/neptune/new/client/test_model.py | 6 ++-- .../neptune/new/client/test_model_version.py | 10 +++---- tests/unit/neptune/new/client/test_project.py | 12 ++++---- tests/unit/neptune/new/client/test_run.py | 6 ++-- .../neptune/new/client/test_run_tables.py | 4 +-- 16 files changed, 63 insertions(+), 63 deletions(-) diff --git a/src/neptune/api/searching_entries.py b/src/neptune/api/searching_entries.py index 7c9f0007d..dc43529b9 100644 --- a/src/neptune/api/searching_entries.py +++ b/src/neptune/api/searching_entries.py @@ -36,7 +36,7 @@ from neptune.exceptions import NeptuneInvalidQueryException from neptune.internal.backends.api_model import ( AttributeType, - AttributeWithProperties, + Field, LeaderboardEntry, ) from neptune.internal.backends.hosted_client import DEFAULT_REQUEST_KWARGS @@ -161,7 +161,7 @@ def to_leaderboard_entry(entry: Dict[str, Any]) -> LeaderboardEntry: return LeaderboardEntry( id=entry["experimentId"], attributes=[ - AttributeWithProperties( + Field( path=attr["name"], type=AttributeType(attr["type"]), properties=attr.__getitem__(f"{attr['type']}Properties"), @@ -172,7 +172,7 @@ def to_leaderboard_entry(entry: Dict[str, Any]) -> LeaderboardEntry: ) -def find_attribute(*, entry: LeaderboardEntry, path: str) -> Optional[AttributeWithProperties]: +def find_attribute(*, entry: LeaderboardEntry, path: str) -> Optional[Field]: return next((attr for attr in entry.attributes if attr.path == path), None) diff --git a/src/neptune/integrations/pandas/__init__.py b/src/neptune/integrations/pandas/__init__.py index 268bc9ee4..9e53748a5 100644 --- a/src/neptune/integrations/pandas/__init__.py +++ b/src/neptune/integrations/pandas/__init__.py @@ -30,7 +30,7 @@ from neptune.internal.backends.api_model import ( AttributeType, - AttributeWithProperties, + Field, LeaderboardEntry, ) from neptune.internal.utils.logger import get_logger @@ -43,7 +43,7 @@ def to_pandas(table: Table) -> pd.DataFrame: - def make_attribute_value(attribute: AttributeWithProperties) -> Any: + def make_attribute_value(attribute: Field) -> Any: _type = attribute.type _properties = attribute.properties if _type == AttributeType.RUN_STATE: diff --git a/src/neptune/internal/backends/api_model.py b/src/neptune/internal/backends/api_model.py index 1d9927b6b..7c1957b48 100644 --- a/src/neptune/internal/backends/api_model.py +++ b/src/neptune/internal/backends/api_model.py @@ -21,8 +21,8 @@ "VersionInfo", "ClientConfig", "AttributeType", - "AttributeDefinition", - "AttributeWithProperties", + "FieldDefinition", + "Field", "LeaderboardEntry", "StringPointValue", "ImageSeriesValues", @@ -218,13 +218,13 @@ class AttributeType(Enum): @dataclass -class AttributeDefinition: +class FieldDefinition: path: str type: AttributeType @dataclass -class AttributeWithProperties: +class Field: path: str type: AttributeType properties: Any @@ -233,7 +233,7 @@ class AttributeWithProperties: @dataclass class LeaderboardEntry: id: str - attributes: List[AttributeWithProperties] + attributes: List[Field] @dataclass diff --git a/src/neptune/internal/backends/hosted_neptune_backend.py b/src/neptune/internal/backends/hosted_neptune_backend.py index fd365c938..6e22d86df 100644 --- a/src/neptune/internal/backends/hosted_neptune_backend.py +++ b/src/neptune/internal/backends/hosted_neptune_backend.py @@ -59,10 +59,10 @@ from neptune.internal.backends.api_model import ( ApiExperiment, ArtifactAttribute, - AttributeDefinition, AttributeType, BoolAttribute, DatetimeAttribute, + FieldDefinition, FileAttribute, FloatAttribute, FloatPointValue, @@ -682,9 +682,9 @@ def _execute_operations( raise NeptuneLimitExceedException(reason=e.response.json().get("title", "Unknown reason")) from e @with_api_exceptions_handler - def get_attributes(self, container_id: str, container_type: ContainerType) -> List[AttributeDefinition]: - def to_attribute(attr) -> AttributeDefinition: - return AttributeDefinition(attr.name, AttributeType(attr.type)) + def get_attributes(self, container_id: str, container_type: ContainerType) -> List[FieldDefinition]: + def to_attribute(attr) -> FieldDefinition: + return FieldDefinition(attr.name, AttributeType(attr.type)) params = { "experimentId": container_id, diff --git a/src/neptune/internal/backends/neptune_backend.py b/src/neptune/internal/backends/neptune_backend.py index dfcba3454..23902b55c 100644 --- a/src/neptune/internal/backends/neptune_backend.py +++ b/src/neptune/internal/backends/neptune_backend.py @@ -31,10 +31,10 @@ from neptune.internal.backends.api_model import ( ApiExperiment, ArtifactAttribute, - AttributeDefinition, AttributeType, BoolAttribute, DatetimeAttribute, + FieldDefinition, FileAttribute, FloatAttribute, FloatSeriesAttribute, @@ -146,7 +146,7 @@ def execute_operations( pass @abc.abstractmethod - def get_attributes(self, container_id: str, container_type: ContainerType) -> List[AttributeDefinition]: + def get_attributes(self, container_id: str, container_type: ContainerType) -> List[FieldDefinition]: pass @abc.abstractmethod diff --git a/src/neptune/internal/backends/neptune_backend_mock.py b/src/neptune/internal/backends/neptune_backend_mock.py index e415e26f5..41243314c 100644 --- a/src/neptune/internal/backends/neptune_backend_mock.py +++ b/src/neptune/internal/backends/neptune_backend_mock.py @@ -47,10 +47,10 @@ from neptune.internal.backends.api_model import ( ApiExperiment, ArtifactAttribute, - AttributeDefinition, AttributeType, BoolAttribute, DatetimeAttribute, + FieldDefinition, FileAttribute, FloatAttribute, FloatPointValue, @@ -311,7 +311,7 @@ def _execute_operation( else: run.pop(op.path) - def get_attributes(self, container_id: str, container_type: ContainerType) -> List[AttributeDefinition]: + def get_attributes(self, container_id: str, container_type: ContainerType) -> List[FieldDefinition]: run = self._get_container(container_id, container_type) return list(self._generate_attributes(None, run.get_structure())) @@ -321,7 +321,7 @@ def _generate_attributes(self, base_path: Optional[str], values: dict): if isinstance(value_or_dict, dict): yield from self._generate_attributes(new_path, value_or_dict) else: - yield AttributeDefinition( + yield FieldDefinition( new_path, value_or_dict.accept(self._attribute_type_converter_value_visitor), ) @@ -597,7 +597,7 @@ def visit_artifact(self, _: Artifact) -> AttributeType: def visit_namespace(self, _: Namespace) -> AttributeType: raise NotImplementedError - def copy_value(self, source_type: Type[AttributeDefinition], source_path: List[str]) -> AttributeType: + def copy_value(self, source_type: Type[FieldDefinition], source_path: List[str]) -> AttributeType: raise NotImplementedError class NewValueOpVisitor(OperationVisitor[Optional[Value]]): diff --git a/src/neptune/internal/backends/offline_neptune_backend.py b/src/neptune/internal/backends/offline_neptune_backend.py index d333716c2..dd3b0b092 100644 --- a/src/neptune/internal/backends/offline_neptune_backend.py +++ b/src/neptune/internal/backends/offline_neptune_backend.py @@ -25,9 +25,9 @@ from neptune.internal.artifacts.types import ArtifactFileData from neptune.internal.backends.api_model import ( ArtifactAttribute, - AttributeDefinition, BoolAttribute, DatetimeAttribute, + FieldDefinition, FileAttribute, FloatAttribute, FloatSeriesAttribute, @@ -47,7 +47,7 @@ class OfflineNeptuneBackend(NeptuneBackendMock): WORKSPACE_NAME = "offline" - def get_attributes(self, container_id: str, container_type: ContainerType) -> List[AttributeDefinition]: + def get_attributes(self, container_id: str, container_type: ContainerType) -> List[FieldDefinition]: raise NeptuneOfflineModeFetchException def get_float_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FloatAttribute: diff --git a/src/neptune/objects/utils.py b/src/neptune/objects/utils.py index 446e0637c..40c9acc60 100644 --- a/src/neptune/objects/utils.py +++ b/src/neptune/objects/utils.py @@ -29,7 +29,7 @@ from neptune.internal.backends.api_model import ( AttributeType, - AttributeWithProperties, + Field, LeaderboardEntry, ) from neptune.internal.backends.nql import ( @@ -146,7 +146,7 @@ def _parse_entry(entry: LeaderboardEntry) -> LeaderboardEntry: entry.id, attributes=[ ( - AttributeWithProperties( + Field( attribute.path, attribute.type, { diff --git a/src/neptune/table.py b/src/neptune/table.py index 9b997cc09..02ebdc646 100644 --- a/src/neptune/table.py +++ b/src/neptune/table.py @@ -27,7 +27,7 @@ from neptune.integrations.pandas import to_pandas from neptune.internal.backends.api_model import ( AttributeType, - AttributeWithProperties, + Field, LeaderboardEntry, ) from neptune.internal.backends.neptune_backend import NeptuneBackend @@ -53,7 +53,7 @@ def __init__( backend: NeptuneBackend, container_type: ContainerType, _id: str, - attributes: List[AttributeWithProperties], + attributes: List[Field], ): self._backend = backend self._container_type = container_type diff --git a/tests/unit/neptune/new/api/test_searching_entries.py b/tests/unit/neptune/new/api/test_searching_entries.py index 6260b19d1..bf09b74af 100644 --- a/tests/unit/neptune/new/api/test_searching_entries.py +++ b/tests/unit/neptune/new/api/test_searching_entries.py @@ -34,7 +34,7 @@ from neptune.exceptions import NeptuneInvalidQueryException from neptune.internal.backends.api_model import ( AttributeType, - AttributeWithProperties, + Field, LeaderboardEntry, ) @@ -67,14 +67,14 @@ def test__to_leaderboard_entry(): # then assert result.id == "foo" assert result.attributes == [ - AttributeWithProperties( + Field( path="plugh", type=AttributeType.FLOAT, properties={ "value": 1.0, }, ), - AttributeWithProperties( + Field( path="sys/id", type=AttributeType.STRING, properties={ @@ -255,7 +255,7 @@ def generate_leaderboard_entries(values: Sequence, experiment_id: str = "foo") - return [ LeaderboardEntry( id=experiment_id, - attributes=[AttributeWithProperties(path="sys/id", type=AttributeType.STRING, properties={"value": value})], + attributes=[Field(path="sys/id", type=AttributeType.STRING, properties={"value": value})], ) for value in values ] diff --git a/tests/unit/neptune/new/client/abstract_tables_test.py b/tests/unit/neptune/new/client/abstract_tables_test.py index b4a0ca8c4..e413c1369 100644 --- a/tests/unit/neptune/new/client/abstract_tables_test.py +++ b/tests/unit/neptune/new/client/abstract_tables_test.py @@ -29,9 +29,9 @@ ) from neptune.exceptions import MetadataInconsistency from neptune.internal.backends.api_model import ( - AttributeDefinition, AttributeType, - AttributeWithProperties, + Field, + FieldDefinition, LeaderboardEntry, ) from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock @@ -43,7 +43,7 @@ @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_attributes", - new=lambda _, _uuid, _type: [AttributeDefinition(path="test", type=AttributeType.STRING)], + new=lambda _, _uuid, _type: [FieldDefinition(path="test", type=AttributeType.STRING)], ) @patch("neptune.internal.backends.factory.HostedNeptuneBackend", NeptuneBackendMock) class AbstractTablesTestMixin: @@ -69,23 +69,23 @@ def setUp(cls) -> None: @staticmethod def build_attributes_leaderboard(now: datetime): attributes = [] - attributes.append(AttributeWithProperties("run/state", AttributeType.RUN_STATE, {"value": "idle"})) - attributes.append(AttributeWithProperties("float", AttributeType.FLOAT, {"value": 12.5})) - attributes.append(AttributeWithProperties("string", AttributeType.STRING, {"value": "some text"})) - attributes.append(AttributeWithProperties("datetime", AttributeType.DATETIME, {"value": now})) - attributes.append(AttributeWithProperties("float/series", AttributeType.FLOAT_SERIES, {"last": 8.7})) - attributes.append(AttributeWithProperties("string/series", AttributeType.STRING_SERIES, {"last": "last text"})) - attributes.append(AttributeWithProperties("string/set", AttributeType.STRING_SET, {"values": ["a", "b"]})) + attributes.append(Field("run/state", AttributeType.RUN_STATE, {"value": "idle"})) + attributes.append(Field("float", AttributeType.FLOAT, {"value": 12.5})) + attributes.append(Field("string", AttributeType.STRING, {"value": "some text"})) + attributes.append(Field("datetime", AttributeType.DATETIME, {"value": now})) + attributes.append(Field("float/series", AttributeType.FLOAT_SERIES, {"last": 8.7})) + attributes.append(Field("string/series", AttributeType.STRING_SERIES, {"last": "last text"})) + attributes.append(Field("string/set", AttributeType.STRING_SET, {"values": ["a", "b"]})) attributes.append( - AttributeWithProperties( + Field( "git/ref", AttributeType.GIT_REF, {"commit": {"commitId": "abcdef0123456789"}}, ) ) - attributes.append(AttributeWithProperties("file", AttributeType.FILE, None)) - attributes.append(AttributeWithProperties("file/set", AttributeType.FILE_SET, None)) - attributes.append(AttributeWithProperties("image/series", AttributeType.IMAGE_SERIES, None)) + attributes.append(Field("file", AttributeType.FILE, None)) + attributes.append(Field("file/set", AttributeType.FILE_SET, None)) + attributes.append(Field("image/series", AttributeType.IMAGE_SERIES, None)) return attributes @patch.object(NeptuneBackendMock, "search_leaderboard_entries") diff --git a/tests/unit/neptune/new/client/test_model.py b/tests/unit/neptune/new/client/test_model.py index b337b7138..8013b99ae 100644 --- a/tests/unit/neptune/new/client/test_model.py +++ b/tests/unit/neptune/new/client/test_model.py @@ -33,8 +33,8 @@ NeptuneWrongInitParametersException, ) from neptune.internal.backends.api_model import ( - AttributeDefinition, AttributeType, + FieldDefinition, IntAttribute, ) from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock @@ -77,7 +77,7 @@ def test_offline_mode(self): ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_attributes", - new=lambda _, _uuid, _type: [AttributeDefinition("some/variable", AttributeType.INT)], + new=lambda _, _uuid, _type: [FieldDefinition("some/variable", AttributeType.INT)], ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_int_attribute", @@ -102,7 +102,7 @@ def test_read_only_mode(self, warn_once): ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_attributes", - new=lambda _, _uuid, _type: [AttributeDefinition("test", AttributeType.STRING)], + new=lambda _, _uuid, _type: [FieldDefinition("test", AttributeType.STRING)], ) def test_resume(self): with init_model(flush_period=0.5, with_id="whatever") as exp: diff --git a/tests/unit/neptune/new/client/test_model_version.py b/tests/unit/neptune/new/client/test_model_version.py index 5ef9cceea..787189b3e 100644 --- a/tests/unit/neptune/new/client/test_model_version.py +++ b/tests/unit/neptune/new/client/test_model_version.py @@ -34,8 +34,8 @@ NeptuneWrongInitParametersException, ) from neptune.internal.backends.api_model import ( - AttributeDefinition, AttributeType, + FieldDefinition, IntAttribute, StringAttribute, ) @@ -87,8 +87,8 @@ def test_offline_mode(self): @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_attributes", new=lambda _, _uuid, _type: [ - AttributeDefinition("some/variable", AttributeType.INT), - AttributeDefinition("sys/model_id", AttributeType.STRING), + FieldDefinition("some/variable", AttributeType.INT), + FieldDefinition("sys/model_id", AttributeType.STRING), ], ) @patch( @@ -115,8 +115,8 @@ def test_read_only_mode(self, warn_once): @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_attributes", new=lambda _, _uuid, _type: [ - AttributeDefinition("test", AttributeType.STRING), - AttributeDefinition("sys/model_id", AttributeType.STRING), + FieldDefinition("test", AttributeType.STRING), + FieldDefinition("sys/model_id", AttributeType.STRING), ], ) @patch( diff --git a/tests/unit/neptune/new/client/test_project.py b/tests/unit/neptune/new/client/test_project.py index cba55ad4b..895a30064 100644 --- a/tests/unit/neptune/new/client/test_project.py +++ b/tests/unit/neptune/new/client/test_project.py @@ -33,9 +33,9 @@ NeptuneUnsupportedFunctionalityException, ) from neptune.internal.backends.api_model import ( - AttributeDefinition, AttributeType, - AttributeWithProperties, + Field, + FieldDefinition, IntAttribute, LeaderboardEntry, ) @@ -54,7 +54,7 @@ @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_attributes", - new=lambda _, _uuid, _type: [AttributeDefinition("test", AttributeType.STRING)], + new=lambda _, _uuid, _type: [FieldDefinition("test", AttributeType.STRING)], ) @patch("neptune.internal.backends.factory.HostedNeptuneBackend", NeptuneBackendMock) class TestClientProject(AbstractExperimentTestMixin, unittest.TestCase): @@ -175,12 +175,12 @@ def entries_generator(): yield LeaderboardEntry( id="test", attributes=[ - AttributeWithProperties( + Field( "attr1", AttributeType.DATETIME, {"value": "2024-02-05T20:37:40.915000Z"}, ), - AttributeWithProperties( + Field( "attr2", AttributeType.DATETIME, {"value": "2024-02-05T20:37:40.915000Z"}, @@ -199,7 +199,7 @@ def test_parse_dates_wrong_format(mock_warn_once): LeaderboardEntry( id="test", attributes=[ - AttributeWithProperties( + Field( "attr1", AttributeType.DATETIME, {"value": "07-02-2024"}, # different format than expected diff --git a/tests/unit/neptune/new/client/test_run.py b/tests/unit/neptune/new/client/test_run.py index c8e3ff9f5..dd9d510b9 100644 --- a/tests/unit/neptune/new/client/test_run.py +++ b/tests/unit/neptune/new/client/test_run.py @@ -34,8 +34,8 @@ ) from neptune.exceptions import MissingFieldException from neptune.internal.backends.api_model import ( - AttributeDefinition, AttributeType, + FieldDefinition, IntAttribute, ) from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock @@ -68,7 +68,7 @@ def setUpClass(cls) -> None: ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_attributes", - new=lambda _, _uuid, _type: [AttributeDefinition("some/variable", AttributeType.INT)], + new=lambda _, _uuid, _type: [FieldDefinition("some/variable", AttributeType.INT)], ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_int_attribute", @@ -93,7 +93,7 @@ def test_read_only_mode(self, warn_once): ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_attributes", - new=lambda _, _uuid, _type: [AttributeDefinition("test", AttributeType.STRING)], + new=lambda _, _uuid, _type: [FieldDefinition("test", AttributeType.STRING)], ) def test_resume(self): with init_run(flush_period=0.5, with_id="whatever") as exp: diff --git a/tests/unit/neptune/new/client/test_run_tables.py b/tests/unit/neptune/new/client/test_run_tables.py index 8b5a75869..c66472d8a 100644 --- a/tests/unit/neptune/new/client/test_run_tables.py +++ b/tests/unit/neptune/new/client/test_run_tables.py @@ -23,7 +23,7 @@ from neptune import init_project from neptune.internal.backends.api_model import ( AttributeType, - AttributeWithProperties, + Field, LeaderboardEntry, ) from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock @@ -68,7 +68,7 @@ def test_fetch_runs_table_raises_correct_exception_for_incorrect_states(self): LeaderboardEntry( id="123", attributes=[ - AttributeWithProperties( + Field( "sys/creation_time", AttributeType.DATETIME, {"value": "2024-02-05T20:37:40.915000Z"}, From 049820d3920cbc13bd1d9d1ace42c8fb25c3dcb1 Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Mon, 1 Apr 2024 19:45:38 +0200 Subject: [PATCH 03/22] Fields renamed --- src/neptune/api/searching_entries.py | 4 ++-- src/neptune/integrations/pandas/__init__.py | 2 +- src/neptune/internal/backends/api_model.py | 2 +- .../internal/backends/hosted_neptune_backend.py | 6 +++--- src/neptune/objects/utils.py | 4 ++-- src/neptune/table.py | 2 +- tests/unit/neptune/new/api/test_searching_entries.py | 4 ++-- tests/unit/neptune/new/client/test_project.py | 10 +++++----- tests/unit/neptune/new/client/test_run_tables.py | 2 +- 9 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/neptune/api/searching_entries.py b/src/neptune/api/searching_entries.py index dc43529b9..ffa13a26e 100644 --- a/src/neptune/api/searching_entries.py +++ b/src/neptune/api/searching_entries.py @@ -160,7 +160,7 @@ def get_single_page( def to_leaderboard_entry(entry: Dict[str, Any]) -> LeaderboardEntry: return LeaderboardEntry( id=entry["experimentId"], - attributes=[ + fields=[ Field( path=attr["name"], type=AttributeType(attr["type"]), @@ -173,7 +173,7 @@ def to_leaderboard_entry(entry: Dict[str, Any]) -> LeaderboardEntry: def find_attribute(*, entry: LeaderboardEntry, path: str) -> Optional[Field]: - return next((attr for attr in entry.attributes if attr.path == path), None) + return next((attr for attr in entry.fields if attr.path == path), None) def iter_over_pages( diff --git a/src/neptune/integrations/pandas/__init__.py b/src/neptune/integrations/pandas/__init__.py index 9e53748a5..771f380da 100644 --- a/src/neptune/integrations/pandas/__init__.py +++ b/src/neptune/integrations/pandas/__init__.py @@ -80,7 +80,7 @@ def make_attribute_value(attribute: Field) -> Any: def make_row(entry: LeaderboardEntry) -> Dict[str, Any]: row: Dict[str, Union[str, float, datetime]] = dict() - for attr in entry.attributes: + for attr in entry.fields: value = make_attribute_value(attr) if value is not None: row[attr.path] = value diff --git a/src/neptune/internal/backends/api_model.py b/src/neptune/internal/backends/api_model.py index 7c1957b48..6837c5666 100644 --- a/src/neptune/internal/backends/api_model.py +++ b/src/neptune/internal/backends/api_model.py @@ -233,7 +233,7 @@ class Field: @dataclass class LeaderboardEntry: id: str - attributes: List[Field] + fields: List[Field] @dataclass diff --git a/src/neptune/internal/backends/hosted_neptune_backend.py b/src/neptune/internal/backends/hosted_neptune_backend.py index 6e22d86df..214ad83c6 100644 --- a/src/neptune/internal/backends/hosted_neptune_backend.py +++ b/src/neptune/internal/backends/hosted_neptune_backend.py @@ -694,10 +694,10 @@ def to_attribute(attr) -> FieldDefinition: experiment = self.leaderboard_client.api.getExperimentAttributes(**params).response().result attribute_type_names = [at.value for at in AttributeType] - accepted_attributes = [attr for attr in experiment.attributes if attr.type in attribute_type_names] + accepted_attributes = [attr for attr in experiment.fields if attr.type in attribute_type_names] # Notify about ignored attrs - ignored_attributes = set(attr.type for attr in experiment.attributes) - set( + ignored_attributes = set(attr.type for attr in experiment.fields) - set( attr.type for attr in accepted_attributes ) if ignored_attributes: @@ -1028,7 +1028,7 @@ def fetch_atom_attribute_values( result = self.leaderboard_client.api.getExperimentAttributes(**params).response().result return [ (attr.name, attr.type, map_attribute_result_to_value(attr)) - for attr in result.attributes + for attr in result.fields if attr.name.startswith(namespace_prefix) ] except HTTPNotFound as e: diff --git a/src/neptune/objects/utils.py b/src/neptune/objects/utils.py index 40c9acc60..c76666139 100644 --- a/src/neptune/objects/utils.py +++ b/src/neptune/objects/utils.py @@ -144,7 +144,7 @@ def _parse_entry(entry: LeaderboardEntry) -> LeaderboardEntry: try: return LeaderboardEntry( entry.id, - attributes=[ + fields=[ ( Field( attribute.path, @@ -157,7 +157,7 @@ def _parse_entry(entry: LeaderboardEntry) -> LeaderboardEntry: if attribute.type == AttributeType.DATETIME else attribute ) - for attribute in entry.attributes + for attribute in entry.fields ], ) except ValueError: diff --git a/src/neptune/table.py b/src/neptune/table.py index 02ebdc646..812bd28bb 100644 --- a/src/neptune/table.py +++ b/src/neptune/table.py @@ -194,7 +194,7 @@ def __next__(self) -> TableEntry: backend=self._backend, container_type=self._container_type, _id=entry.id, - attributes=entry.attributes, + attributes=entry.fields, ) def to_pandas(self) -> "pandas.DataFrame": diff --git a/tests/unit/neptune/new/api/test_searching_entries.py b/tests/unit/neptune/new/api/test_searching_entries.py index bf09b74af..d5fff2c9c 100644 --- a/tests/unit/neptune/new/api/test_searching_entries.py +++ b/tests/unit/neptune/new/api/test_searching_entries.py @@ -66,7 +66,7 @@ def test__to_leaderboard_entry(): # then assert result.id == "foo" - assert result.attributes == [ + assert result.fields == [ Field( path="plugh", type=AttributeType.FLOAT, @@ -255,7 +255,7 @@ def generate_leaderboard_entries(values: Sequence, experiment_id: str = "foo") - return [ LeaderboardEntry( id=experiment_id, - attributes=[Field(path="sys/id", type=AttributeType.STRING, properties={"value": value})], + fields=[Field(path="sys/id", type=AttributeType.STRING, properties={"value": value})], ) for value in values ] diff --git a/tests/unit/neptune/new/client/test_project.py b/tests/unit/neptune/new/client/test_project.py index 895a30064..1176462ea 100644 --- a/tests/unit/neptune/new/client/test_project.py +++ b/tests/unit/neptune/new/client/test_project.py @@ -174,7 +174,7 @@ def test_parse_dates(): def entries_generator(): yield LeaderboardEntry( id="test", - attributes=[ + fields=[ Field( "attr1", AttributeType.DATETIME, @@ -189,8 +189,8 @@ def entries_generator(): ) parsed = list(parse_dates(entries_generator())) - assert parsed[0].attributes[0].properties["value"] == datetime(2024, 2, 5, 20, 37, 40, 915000) - assert parsed[0].attributes[1].properties["value"] == datetime(2024, 2, 5, 20, 37, 40, 915000) + assert parsed[0].fields[0].properties["value"] == datetime(2024, 2, 5, 20, 37, 40, 915000) + assert parsed[0].fields[1].properties["value"] == datetime(2024, 2, 5, 20, 37, 40, 915000) @patch("neptune.objects.utils.warn_once") @@ -198,7 +198,7 @@ def test_parse_dates_wrong_format(mock_warn_once): entries = [ LeaderboardEntry( id="test", - attributes=[ + fields=[ Field( "attr1", AttributeType.DATETIME, @@ -209,7 +209,7 @@ def test_parse_dates_wrong_format(mock_warn_once): ] parsed = list(parse_dates(entries)) - assert parsed[0].attributes[0].properties["value"] == "07-02-2024" # should be left unchanged due to ValueError + assert parsed[0].fields[0].properties["value"] == "07-02-2024" # should be left unchanged due to ValueError mock_warn_once.assert_called_once_with( "Date parsing failed. The date format is incorrect. Returning as string instead of datetime.", exception=NeptuneWarning, diff --git a/tests/unit/neptune/new/client/test_run_tables.py b/tests/unit/neptune/new/client/test_run_tables.py index c66472d8a..b626cc0b1 100644 --- a/tests/unit/neptune/new/client/test_run_tables.py +++ b/tests/unit/neptune/new/client/test_run_tables.py @@ -67,7 +67,7 @@ def test_fetch_runs_table_raises_correct_exception_for_incorrect_states(self): new=lambda *args, **kwargs: [ LeaderboardEntry( id="123", - attributes=[ + fields=[ Field( "sys/creation_time", AttributeType.DATETIME, From 0e211f0171b75703c34f17b696daa98fd6d6622f Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Thu, 4 Apr 2024 13:43:22 +0200 Subject: [PATCH 04/22] Initial refactor --- src/neptune/api/dtos.py | 33 -- src/neptune/api/models.py | 383 ++++++++++++++++++ src/neptune/api/searching_entries.py | 42 +- src/neptune/attributes/utils.py | 34 +- src/neptune/integrations/pandas/__init__.py | 146 ++++--- src/neptune/internal/backends/api_model.py | 107 ----- .../backends/hosted_artifact_operations.py | 6 +- .../backends/hosted_neptune_backend.py | 108 ++--- .../internal/backends/neptune_backend.py | 37 +- .../internal/backends/neptune_backend_mock.py | 131 +++--- .../backends/offline_neptune_backend.py | 33 +- .../utils/generic_attribute_mapper.py | 30 +- src/neptune/objects/neptune_object.py | 4 +- src/neptune/objects/utils.py | 10 +- src/neptune/table.py | 136 ++++--- .../neptune/new/api/test_searching_entries.py | 16 +- .../new/attributes/test_attribute_utils.py | 4 +- .../new/client/abstract_tables_test.py | 31 +- tests/unit/neptune/new/client/test_model.py | 12 +- .../neptune/new/client/test_model_version.py | 21 +- tests/unit/neptune/new/client/test_project.py | 22 +- tests/unit/neptune/new/client/test_run.py | 12 +- .../neptune/new/client/test_run_tables.py | 10 +- .../internal/backends/test_hosted_client.py | 14 +- .../backends/test_neptune_backend_mock.py | 37 +- 25 files changed, 840 insertions(+), 579 deletions(-) delete mode 100644 src/neptune/api/dtos.py create mode 100644 src/neptune/api/models.py diff --git a/src/neptune/api/dtos.py b/src/neptune/api/dtos.py deleted file mode 100644 index 2895b43df..000000000 --- a/src/neptune/api/dtos.py +++ /dev/null @@ -1,33 +0,0 @@ -# -# Copyright (c) 2024, Neptune Labs Sp. z o.o. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -__all__ = ["FileEntry"] - -import datetime -from dataclasses import dataclass -from typing import Any - - -@dataclass -class FileEntry: - name: str - size: int - mtime: datetime.datetime - file_type: str - - @classmethod - def from_dto(cls, file_dto: Any) -> "FileEntry": - return cls(name=file_dto.name, size=file_dto.size, mtime=file_dto.mtime, file_type=file_dto.fileType) diff --git a/src/neptune/api/models.py b/src/neptune/api/models.py new file mode 100644 index 000000000..30998ec09 --- /dev/null +++ b/src/neptune/api/models.py @@ -0,0 +1,383 @@ +# +# Copyright (c) 2024, Neptune Labs Sp. z o.o. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +__all__ = ( + "FileEntry", + "Field", + "FieldType", + "LeaderboardEntry", + "LeaderboardEntriesSearchResult", + "FieldVisitor", + "FloatField", + "IntField", + "BoolField", + "StringField", + "DatetimeField", + "FileField", + "FileSetField", + "FloatSeriesField", + "StringSeriesField", + "ImageSeriesField", + "StringSetField", + "GitRefField", + "ObjectStateField", + "NotebookRefField", + "ArtifactField", + "FieldDefinition", +) + +import abc +from dataclasses import ( + dataclass, + field as dataclass_field, +) +from typing import TypeVar, Generic, Dict +from datetime import datetime +from enum import Enum +from typing import ( + Any, + Optional, + Set, + List, +) + +Ret = TypeVar("Ret") + + +@dataclass +class FileEntry: + name: str + size: int + mtime: datetime + file_type: str + + @classmethod + def from_dto(cls, file_dto: Any) -> "FileEntry": + return cls(name=file_dto.name, size=file_dto.size, mtime=file_dto.mtime, file_type=file_dto.fileType) + + +class FieldType(Enum): + FLOAT = "float" + INT = "int" + BOOL = "bool" + STRING = "string" + DATETIME = "datetime" + FILE = "file" + FILE_SET = "fileSet" + FLOAT_SERIES = "floatSeries" + STRING_SERIES = "stringSeries" + IMAGE_SERIES = "imageSeries" + STRING_SET = "stringSet" + GIT_REF = "gitRef" + OBJECT_STATE = "experimentState" + NOTEBOOK_REF = "notebookRef" + ARTIFACT = "artifact" + + +@dataclass +class Field(abc.ABC): + path: str + type: FieldType = dataclass_field(init=False, default=None) + + def __init_subclass__(cls, **kwargs) -> None: + super().__init_subclass__(**kwargs) + # TODO: remove this when we have proper type hints + cls.type = kwargs.get('type', None) + + @abc.abstractmethod + def accept(self, visitor: FieldVisitor[Ret]) -> Ret: + ... + + @staticmethod + def from_dict(field: Dict[str, Any]) -> Field: + raise NotImplementedError() + + +class FieldVisitor(Generic[Ret], abc.ABC): + + def visit(self, field: Field) -> Ret: + return field.accept(self) + + def visit_float(self, field: FloatField) -> Ret: + ... + + def visit_int(self, field: IntField) -> Ret: + ... + + def visit_bool(self, field: BoolField) -> Ret: + ... + + def visit_string(self, field: StringField) -> Ret: + ... + + def visit_datetime(self, field: DatetimeField) -> Ret: + ... + + def visit_file(self, field: FileField) -> Ret: + ... + + def visit_file_set(self, field: FileSetField) -> Ret: + ... + + def visit_float_series(self, field: FloatSeriesField) -> Ret: + ... + + def visit_string_series(self, field: StringSeriesField) -> Ret: + ... + + def visit_image_series(self, field: ImageSeriesField) -> Ret: + ... + + def visit_string_set(self, field: StringSetField) -> Ret: + ... + + def visit_git_ref(self, field: GitRefField) -> Ret: + ... + + def visit_object_state(self, field: ObjectStateField) -> Ret: + ... + + def visit_notebook_ref(self, field: NotebookRefField) -> Ret: + ... + + def visit_artifact(self, field: ArtifactField) -> Ret: + ... + + +@dataclass +class FloatField(Field, type=FieldType.FLOAT): + value: float + + def accept(self, visitor: FieldVisitor[Ret]) -> Ret: + return visitor.visit_float(self) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> FloatField: + # TODO: Map only if not null + return FloatField(path=data["attributeName"], value=float(data["value"])) + + +@dataclass +class IntField(Field, type=FieldType.INT): + value: int + + def accept(self, visitor: FieldVisitor[Ret]) -> Ret: + return visitor.visit_int(self) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> IntField: + # TODO: Map only if not null + return IntField(path=data["attributeName"], value=int(data["value"])) + + +@dataclass +class BoolField(Field, type=FieldType.BOOL): + value: bool + + def accept(self, visitor: FieldVisitor[Ret]) -> Ret: + return visitor.visit_bool(self) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> BoolField: + # TODO: Map only if not null + return BoolField(path=data["attributeName"], value=bool(data["value"])) + + +@dataclass +class StringField(Field, type=FieldType.STRING): + value: str + + def accept(self, visitor: FieldVisitor[Ret]) -> Ret: + return visitor.visit_string(self) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> StringField: + # TODO: Map only if not null + return StringField(path=data["attributeName"], value=str(data["value"])) + + +@dataclass +class DatetimeField(Field, type=FieldType.DATETIME): + value: datetime + + def accept(self, visitor: FieldVisitor[Ret]) -> Ret: + return visitor.visit_datetime(self) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> DatetimeField: + # TODO: parse datetime + return DatetimeField(path=data["attributeName"], value=data["value"]) + + +@dataclass +class FileField(Field, type=FieldType.FILE): + name: str + ext: str + size: int + + def accept(self, visitor: FieldVisitor[Ret]) -> Ret: + return visitor.visit_file(self) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> FileField: + return FileField( + path=data["path"], + name=data["name"], + ext=data["ext"], + size=int(data["size"]) + ) + + +@dataclass +class FileSetField(Field, type=FieldType.FILE_SET): + size: int + + def accept(self, visitor: FieldVisitor[Ret]) -> Ret: + return visitor.visit_file_set(self) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> FileSetField: + return FileSetField(path=data["attributeName"], size=int(data["size"])) + + +@dataclass +class FloatSeriesField(Field, type=FieldType.FLOAT_SERIES): + last: Optional[float] + + def accept(self, visitor: "FieldVisitor[Ret]") -> Ret: + return visitor.visit_float_series(self) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> FloatSeriesField: + # TODO: last is optional so map to float if present + return FloatSeriesField(path=data["attributeName"], last=data["last"]) + + +@dataclass +class StringSeriesField(Field, type=FieldType.STRING_SERIES): + last: Optional[str] + + def accept(self, visitor: "FieldVisitor[Ret]") -> Ret: + return visitor.visit_string_series(self) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> StringSeriesField: + # TODO: last is optional so map to str if present + return StringSeriesField(path=data["attributeName"], last=data["last"]) + + +@dataclass +class ImageSeriesField(Field, type=FieldType.IMAGE_SERIES): + last_step: Optional[float] + + def accept(self, visitor: "FieldVisitor[Ret]") -> Ret: + return visitor.visit_image_series(self) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> ImageSeriesField: + # TODO: last_step is optional so map to float if present + return ImageSeriesField(path=data["attributeName"], last_step=data["lastStep"]) + + +@dataclass +class StringSetField(Field, type=FieldType.STRING_SET): + values: Set[str] + + def accept(self, visitor: "FieldVisitor[Ret]") -> Ret: + return visitor.visit_string_set(self) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> StringSetField: + return StringSetField(path=data["attributeName"], values=set(data["values"])) + + +@dataclass +class GitRefField(Field, type=FieldType.GIT_REF): + commit_id: Optional[str] + + def accept(self, visitor: FieldVisitor[Ret]) -> Ret: + return visitor.visit_git_ref(self) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> GitRefField: + # TODO: commit and commit_id is optional so map to str if present + return GitRefField(path=data["attributeName"], commit_id=data["commit"]["commitId"]) + + +@dataclass +class ObjectStateField(Field, type=FieldType.OBJECT_STATE): + value: str + + def accept(self, visitor: FieldVisitor[Ret]) -> Ret: + return visitor.visit_object_state(self) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> ObjectStateField: + return ObjectStateField(path=data["attributeName"], value=str(data["value"])) + + +@dataclass +class NotebookRefField(Field, type=FieldType.NOTEBOOK_REF): + notebook_name: Optional[str] + + def accept(self, visitor: FieldVisitor[Ret]) -> Ret: + return visitor.visit_notebook_ref(self) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> NotebookRefField: + # TODO: notebook_name is optional so map to str if present + return NotebookRefField(path=data["attributeName"], notebook_name=data["notebookName"]) + + +@dataclass +class ArtifactField(Field, type=FieldType.ARTIFACT): + hash: str + + def accept(self, visitor: FieldVisitor[Ret]) -> Ret: + return visitor.visit_artifact(self) + + +@dataclass +class LeaderboardEntry: + object_id: str + fields: List[Field] + + @staticmethod + def from_dict(data: Dict[str, Any]) -> LeaderboardEntry: + return LeaderboardEntry( + object_id=data["experimentId"], + fields=[] # TODO: map fields + ) + + +@dataclass +class LeaderboardEntriesSearchResult: + entries: List[LeaderboardEntry] + matching_item_count: int + + @staticmethod + def from_dict(result: Dict[str, Any]) -> LeaderboardEntriesSearchResult: + return LeaderboardEntriesSearchResult( + entries=[LeaderboardEntry.from_dict(entry) for entry in result["entries"]], + matching_item_count=result["matchingItemCount"], + ) + + +@dataclass +class FieldDefinition: + path: str + type: FieldType diff --git a/src/neptune/api/searching_entries.py b/src/neptune/api/searching_entries.py index ffa13a26e..b12b0a4be 100644 --- a/src/neptune/api/searching_entries.py +++ b/src/neptune/api/searching_entries.py @@ -34,11 +34,7 @@ ) from neptune.exceptions import NeptuneInvalidQueryException -from neptune.internal.backends.api_model import ( - AttributeType, - Field, - LeaderboardEntry, -) +from neptune.api.models import Field, FieldType, LeaderboardEntry, LeaderboardEntriesSearchResult from neptune.internal.backends.hosted_client import DEFAULT_REQUEST_KWARGS from neptune.internal.backends.nql import ( NQLAggregator, @@ -58,7 +54,7 @@ from neptune.internal.id_formats import UniqueId -SUPPORTED_ATTRIBUTE_TYPES = {item.value for item in AttributeType} +SUPPORTED_ATTRIBUTE_TYPES = {item.value for item in FieldType} SORT_BY_COLUMN_TYPE: TypeAlias = Literal["string", "datetime", "integer", "boolean", "float"] @@ -98,7 +94,7 @@ def get_single_page( searching_after: Optional[str], ) -> Any: normalized_query = query or NQLEmptyQuery() - sort_by_column_type = sort_by_column_type if sort_by_column_type else AttributeType.STRING.value + sort_by_column_type = sort_by_column_type if sort_by_column_type else FieldType.STRING.value if sort_by and searching_after: sort_by_as_nql = NQLQueryAttribute( name=sort_by, @@ -119,7 +115,7 @@ def get_single_page( "aggregationMode": "none", "sortBy": { "name": sort_by, - "type": sort_by_column_type if sort_by_column_type else AttributeType.STRING.value, + "type": sort_by_column_type if sort_by_column_type else FieldType.STRING.value, }, } } @@ -158,18 +154,17 @@ def get_single_page( def to_leaderboard_entry(entry: Dict[str, Any]) -> LeaderboardEntry: - return LeaderboardEntry( - id=entry["experimentId"], - fields=[ - Field( - path=attr["name"], - type=AttributeType(attr["type"]), - properties=attr.__getitem__(f"{attr['type']}Properties"), - ) - for attr in entry["attributes"] - if attr["type"] in SUPPORTED_ATTRIBUTE_TYPES - ], - ) + # return LeaderboardEntry( + # fields=[ + # Field( + # path=attr["name"], + # type=FieldType(attr["type"]), + # properties=attr.__getitem__(f"{attr['type']}Properties"), + # ) + # for attr in entry["attributes"] + # if attr["type"] in SUPPORTED_ATTRIBUTE_TYPES + # ], + # ) def find_attribute(*, entry: LeaderboardEntry, path: str) -> Optional[Field]: @@ -190,6 +185,7 @@ def iter_over_pages( searching_after = None last_page = None + # TODO: Refactor total = get_single_page( limit=0, offset=0, @@ -222,12 +218,14 @@ def iter_over_pages( if not page_attribute: raise ValueError(f"Cannot find attribute {sort_by} in last page") + # TODO: Refactor searching_after = page_attribute.properties["value"] for offset in range(0, max_offset, step_size): local_limit = min(step_size, max_offset - offset) if extracted_records + local_limit > limit: local_limit = limit - extracted_records + result = get_single_page( limit=local_limit, offset=offset, @@ -240,6 +238,7 @@ def iter_over_pages( # fetch the item count everytime a new page is started (except for the very fist page) if offset == 0 and last_page is not None: + # TODO: Refactor total += result.get("matchingItemCount", 0) total = min(total, limit) @@ -259,5 +258,6 @@ def iter_over_pages( last_page = page +# TODO: Refactor def _entries_from_page(single_page: Dict[str, Any]) -> List[LeaderboardEntry]: - return list(map(to_leaderboard_entry, single_page.get("entries", []))) + return LeaderboardEntriesSearchResult.from_dict(single_page).entries diff --git a/src/neptune/attributes/utils.py b/src/neptune/attributes/utils.py index ea50f1f1f..f7bb02b6f 100644 --- a/src/neptune/attributes/utils.py +++ b/src/neptune/attributes/utils.py @@ -37,7 +37,7 @@ StringSeries, StringSet, ) -from neptune.internal.backends.api_model import AttributeType +from neptune.api.models import FieldType from neptune.internal.exceptions import InternalClientError if TYPE_CHECKING: @@ -45,26 +45,26 @@ from neptune.objects import NeptuneObject _attribute_type_to_attr_class_map = { - AttributeType.FLOAT: Float, - AttributeType.INT: Integer, - AttributeType.BOOL: Boolean, - AttributeType.STRING: String, - AttributeType.DATETIME: Datetime, - AttributeType.FILE: File, - AttributeType.FILE_SET: FileSet, - AttributeType.FLOAT_SERIES: FloatSeries, - AttributeType.STRING_SERIES: StringSeries, - AttributeType.IMAGE_SERIES: FileSeries, - AttributeType.STRING_SET: StringSet, - AttributeType.GIT_REF: GitRef, - AttributeType.RUN_STATE: RunState, - AttributeType.NOTEBOOK_REF: NotebookRef, - AttributeType.ARTIFACT: Artifact, + FieldType.FLOAT: Float, + FieldType.INT: Integer, + FieldType.BOOL: Boolean, + FieldType.STRING: String, + FieldType.DATETIME: Datetime, + FieldType.FILE: File, + FieldType.FILE_SET: FileSet, + FieldType.FLOAT_SERIES: FloatSeries, + FieldType.STRING_SERIES: StringSeries, + FieldType.IMAGE_SERIES: FileSeries, + FieldType.STRING_SET: StringSet, + FieldType.GIT_REF: GitRef, + FieldType.OBJECT_STATE: RunState, + FieldType.NOTEBOOK_REF: NotebookRef, + FieldType.ARTIFACT: Artifact, } def create_attribute_from_type( - attribute_type: AttributeType, + attribute_type: FieldType, container: "NeptuneObject", path: List[str], ) -> "Attribute": diff --git a/src/neptune/integrations/pandas/__init__.py b/src/neptune/integrations/pandas/__init__.py index 771f380da..f0ff57014 100644 --- a/src/neptune/integrations/pandas/__init__.py +++ b/src/neptune/integrations/pandas/__init__.py @@ -20,81 +20,113 @@ from datetime import datetime from typing import ( TYPE_CHECKING, - Any, Dict, Tuple, Union, + Optional, ) import pandas as pd -from neptune.internal.backends.api_model import ( - AttributeType, - Field, +from neptune.api.models import ( LeaderboardEntry, + FieldVisitor, + FloatField, + IntField, + BoolField, + StringField, + DatetimeField, + FloatSeriesField, + StringSeriesField, + ImageSeriesField, + FileField, + FileSetField, + GitRefField, + NotebookRefField, + ArtifactField, + StringSetField, + ObjectStateField, ) -from neptune.internal.utils.logger import get_logger from neptune.internal.utils.run_state import RunState if TYPE_CHECKING: from neptune.table import Table -logger = get_logger() +PANDAS_AVAILABLE_TYPES = Union[str, float, int, bool, datetime, None] -def to_pandas(table: Table) -> pd.DataFrame: - def make_attribute_value(attribute: Field) -> Any: - _type = attribute.type - _properties = attribute.properties - if _type == AttributeType.RUN_STATE: - return RunState.from_api(_properties.get("value")).value - if _type in ( - AttributeType.FLOAT, - AttributeType.INT, - AttributeType.BOOL, - AttributeType.STRING, - AttributeType.DATETIME, - ): - return _properties.get("value") - if _type == AttributeType.FLOAT_SERIES: - return _properties.get("last") - if _type == AttributeType.STRING_SERIES: - return _properties.get("last") - if _type == AttributeType.IMAGE_SERIES: - return None - if _type == AttributeType.FILE or _type == AttributeType.FILE_SET: - return None - if _type == AttributeType.STRING_SET: - return ",".join(_properties.get("values")) - if _type == AttributeType.GIT_REF: - return _properties.get("commit", {}).get("commitId") - if _type == AttributeType.NOTEBOOK_REF: - return _properties.get("notebookName") - if _type == AttributeType.ARTIFACT: - return _properties.get("hash") - logger.error( - "Attribute type %s not supported in this version, yielding None. Recommended client upgrade.", - _type, - ) +class FieldToPandasValueVisitor(FieldVisitor[PANDAS_AVAILABLE_TYPES]): + + def visit_float(self, field: FloatField) -> float: + return field.value + + def visit_int(self, field: IntField) -> int: + return field.value + + def visit_bool(self, field: BoolField) -> bool: + return field.value + + def visit_string(self, field: StringField) -> str: + return field.value + + def visit_datetime(self, field: DatetimeField) -> datetime: + return field.value + + def visit_file(self, field: FileField) -> None: + return None + + def visit_string_set(self, field: StringSetField) -> Optional[str]: + return ",".join(field.values) + + def visit_float_series(self, field: FloatSeriesField) -> Optional[float]: + return field.last + + def visit_string_series(self, field: StringSeriesField) -> Optional[str]: + return field.last + + def visit_image_series(self, field: ImageSeriesField) -> None: + return None + + def visit_file_set(self, field: FileSetField) -> None: return None - def make_row(entry: LeaderboardEntry) -> Dict[str, Any]: - row: Dict[str, Union[str, float, datetime]] = dict() - for attr in entry.fields: - value = make_attribute_value(attr) - if value is not None: - row[attr.path] = value - return row - - def sort_key(attr: str) -> Tuple[int, str]: - domain = attr.split("/")[0] - if domain == "sys": - return 0, attr - if domain == "monitoring": - return 2, attr - return 1, attr - - rows = dict((n, make_row(entry)) for (n, entry) in enumerate(table._entries)) + def visit_git_ref(self, field: GitRefField) -> str: + return field.commit_id + + def visit_object_state(self, field: ObjectStateField) -> str: + return RunState.from_api(field.value).value + + def visit_notebook_ref(self, field: NotebookRefField) -> str: + return field.notebook_name + + def visit_artifact(self, field: ArtifactField) -> str: + return field.hash + + +def make_row(entry: LeaderboardEntry, to_value_visitor: FieldVisitor) -> Dict[str, PANDAS_AVAILABLE_TYPES]: + row: Dict[str, PANDAS_AVAILABLE_TYPES] = dict() + + for field in entry.fields: + value = to_value_visitor.visit(field) + if value is not None: + row[field.path] = value + + return row + + +def sort_key(field: str) -> Tuple[int, str]: + domain = field.split("/")[0] + if domain == "sys": + return 0, field + if domain == "monitoring": + return 2, field + return 1, field + + +def to_pandas(table: Table) -> pd.DataFrame: + + to_value_visitor = FieldToPandasValueVisitor() + rows = dict((n, make_row(entry, to_value_visitor)) for (n, entry) in enumerate(table._entries)) df = pd.DataFrame.from_dict(data=rows, orient="index") df = df.reindex(sorted(df.columns, key=sort_key), axis="columns") diff --git a/src/neptune/internal/backends/api_model.py b/src/neptune/internal/backends/api_model.py index 6837c5666..660d1c7e5 100644 --- a/src/neptune/internal/backends/api_model.py +++ b/src/neptune/internal/backends/api_model.py @@ -20,38 +20,20 @@ "OptionalFeatures", "VersionInfo", "ClientConfig", - "AttributeType", - "FieldDefinition", - "Field", - "LeaderboardEntry", "StringPointValue", "ImageSeriesValues", "StringSeriesValues", "FloatPointValue", "FloatSeriesValues", - "FloatAttribute", - "IntAttribute", - "BoolAttribute", - "FileAttribute", - "StringAttribute", - "DatetimeAttribute", - "ArtifactAttribute", "ArtifactModel", - "FloatSeriesAttribute", - "StringSeriesAttribute", - "StringSetAttribute", "MultipartConfig", ] from dataclasses import dataclass -from datetime import datetime -from enum import Enum from typing import ( - Any, FrozenSet, List, Optional, - Set, ) from packaging import version @@ -199,43 +181,6 @@ def from_api_response(config) -> "ClientConfig": ) -class AttributeType(Enum): - FLOAT = "float" - INT = "int" - BOOL = "bool" - STRING = "string" - DATETIME = "datetime" - FILE = "file" - FILE_SET = "fileSet" - FLOAT_SERIES = "floatSeries" - STRING_SERIES = "stringSeries" - IMAGE_SERIES = "imageSeries" - STRING_SET = "stringSet" - GIT_REF = "gitRef" - RUN_STATE = "experimentState" - NOTEBOOK_REF = "notebookRef" - ARTIFACT = "artifact" - - -@dataclass -class FieldDefinition: - path: str - type: AttributeType - - -@dataclass -class Field: - path: str - type: AttributeType - properties: Any - - -@dataclass -class LeaderboardEntry: - id: str - fields: List[Field] - - @dataclass class StringPointValue: timestampMillis: int @@ -267,60 +212,8 @@ class FloatSeriesValues: values: List[FloatPointValue] -@dataclass -class FloatAttribute: - value: float - - -@dataclass -class IntAttribute: - value: int - - -@dataclass -class BoolAttribute: - value: bool - - -@dataclass -class FileAttribute: - name: str - ext: str - size: int - - -@dataclass -class StringAttribute: - value: str - - -@dataclass -class DatetimeAttribute: - value: datetime - - -@dataclass -class ArtifactAttribute: - hash: str - - @dataclass class ArtifactModel: received_metadata: bool hash: str size: int - - -@dataclass -class FloatSeriesAttribute: - last: Optional[float] - - -@dataclass -class StringSeriesAttribute: - last: Optional[str] - - -@dataclass -class StringSetAttribute: - values: Set[str] diff --git a/src/neptune/internal/backends/hosted_artifact_operations.py b/src/neptune/internal/backends/hosted_artifact_operations.py index 497e85fbc..393ce3aa2 100644 --- a/src/neptune/internal/backends/hosted_artifact_operations.py +++ b/src/neptune/internal/backends/hosted_artifact_operations.py @@ -43,9 +43,9 @@ ArtifactFileData, ) from neptune.internal.backends.api_model import ( - ArtifactAttribute, ArtifactModel, ) +from neptune.api.models import ArtifactField from neptune.internal.backends.swagger_client_wrapper import SwaggerClientWrapper from neptune.internal.backends.utils import with_api_exceptions_handler from neptune.internal.operation import ( @@ -254,7 +254,7 @@ def get_artifact_attribute( parent_identifier: str, path: List[str], default_request_params: Dict, -) -> ArtifactAttribute: +) -> ArtifactField: requests_params = add_artifact_version_to_request_params(default_request_params) params = { "experimentId": parent_identifier, @@ -263,7 +263,7 @@ def get_artifact_attribute( } try: result = swagger_client.api.getArtifactAttribute(**params).response().result - return ArtifactAttribute(hash=result.hash) + return ArtifactField(hash=result.hash) except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) diff --git a/src/neptune/internal/backends/hosted_neptune_backend.py b/src/neptune/internal/backends/hosted_neptune_backend.py index 214ad83c6..d29799f97 100644 --- a/src/neptune/internal/backends/hosted_neptune_backend.py +++ b/src/neptune/internal/backends/hosted_neptune_backend.py @@ -38,7 +38,6 @@ HTTPUnprocessableEntity, ) -from neptune.api.dtos import FileEntry from neptune.api.searching_entries import iter_over_pages from neptune.core.components.operation_storage import OperationStorage from neptune.envs import NEPTUNE_FETCH_TABLE_STEP_SIZE @@ -58,28 +57,31 @@ from neptune.internal.artifacts.types import ArtifactFileData from neptune.internal.backends.api_model import ( ApiExperiment, - ArtifactAttribute, - AttributeType, - BoolAttribute, - DatetimeAttribute, - FieldDefinition, - FileAttribute, - FloatAttribute, FloatPointValue, - FloatSeriesAttribute, FloatSeriesValues, ImageSeriesValues, - IntAttribute, - LeaderboardEntry, OptionalFeatures, Project, - StringAttribute, StringPointValue, - StringSeriesAttribute, StringSeriesValues, - StringSetAttribute, Workspace, ) +from neptune.api.models import ( + FloatField, + IntField, + BoolField, + FileField, + StringField, + DatetimeField, + ArtifactField, + FloatSeriesField, + StringSeriesField, + StringSetField, + FieldType, + FieldDefinition, + LeaderboardEntry, + FileEntry, +) from neptune.internal.backends.hosted_artifact_operations import ( get_artifact_attribute, list_artifact_files, @@ -155,21 +157,21 @@ _logger = get_logger() ATOMIC_ATTRIBUTE_TYPES = { - AttributeType.INT.value, - AttributeType.FLOAT.value, - AttributeType.STRING.value, - AttributeType.BOOL.value, - AttributeType.DATETIME.value, - AttributeType.RUN_STATE.value, + FieldType.INT.value, + FieldType.FLOAT.value, + FieldType.STRING.value, + FieldType.BOOL.value, + FieldType.DATETIME.value, + FieldType.OBJECT_STATE.value, } ATOMIC_ATTRIBUTE_TYPES = { - AttributeType.INT.value, - AttributeType.FLOAT.value, - AttributeType.STRING.value, - AttributeType.BOOL.value, - AttributeType.DATETIME.value, - AttributeType.RUN_STATE.value, + FieldType.INT.value, + FieldType.FLOAT.value, + FieldType.STRING.value, + FieldType.BOOL.value, + FieldType.DATETIME.value, + FieldType.OBJECT_STATE.value, } @@ -684,7 +686,7 @@ def _execute_operations( @with_api_exceptions_handler def get_attributes(self, container_id: str, container_type: ContainerType) -> List[FieldDefinition]: def to_attribute(attr) -> FieldDefinition: - return FieldDefinition(attr.name, AttributeType(attr.type)) + return FieldDefinition(attr.name, FieldType(attr.type)) params = { "experimentId": container_id, @@ -693,7 +695,7 @@ def to_attribute(attr) -> FieldDefinition: try: experiment = self.leaderboard_client.api.getExperimentAttributes(**params).response().result - attribute_type_names = [at.value for at in AttributeType] + attribute_type_names = [at.value for at in FieldType] accepted_attributes = [attr for attr in experiment.fields if attr.type in attribute_type_names] # Notify about ignored attrs @@ -782,7 +784,7 @@ def download_file_set( raise @with_api_exceptions_handler - def get_float_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FloatAttribute: + def get_float_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FloatField: params = { "experimentId": container_id, "attribute": path_to_str(path), @@ -790,12 +792,12 @@ def get_float_attribute(self, container_id: str, container_type: ContainerType, } try: result = self.leaderboard_client.api.getFloatAttribute(**params).response().result - return FloatAttribute(result.value) + return FloatField.from_dict(result) except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) @with_api_exceptions_handler - def get_int_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> IntAttribute: + def get_int_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> IntField: params = { "experimentId": container_id, "attribute": path_to_str(path), @@ -803,12 +805,12 @@ def get_int_attribute(self, container_id: str, container_type: ContainerType, pa } try: result = self.leaderboard_client.api.getIntAttribute(**params).response().result - return IntAttribute(result.value) + return IntField.from_dict(result) except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) @with_api_exceptions_handler - def get_bool_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> BoolAttribute: + def get_bool_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> BoolField: params = { "experimentId": container_id, "attribute": path_to_str(path), @@ -816,12 +818,12 @@ def get_bool_attribute(self, container_id: str, container_type: ContainerType, p } try: result = self.leaderboard_client.api.getBoolAttribute(**params).response().result - return BoolAttribute(result.value) + return BoolField.from_dict(result) except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) @with_api_exceptions_handler - def get_file_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FileAttribute: + def get_file_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FileField: params = { "experimentId": container_id, "attribute": path_to_str(path), @@ -829,14 +831,14 @@ def get_file_attribute(self, container_id: str, container_type: ContainerType, p } try: result = self.leaderboard_client.api.getFileAttribute(**params).response().result - return FileAttribute(name=result.name, ext=result.ext, size=result.size) + return FileField.from_dict(result) except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) @with_api_exceptions_handler def get_string_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> StringAttribute: + ) -> StringField: params = { "experimentId": container_id, "attribute": path_to_str(path), @@ -844,14 +846,14 @@ def get_string_attribute( } try: result = self.leaderboard_client.api.getStringAttribute(**params).response().result - return StringAttribute(result.value) + return StringField.from_dict(result) except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) @with_api_exceptions_handler def get_datetime_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> DatetimeAttribute: + ) -> DatetimeField: params = { "experimentId": container_id, "attribute": path_to_str(path), @@ -859,13 +861,13 @@ def get_datetime_attribute( } try: result = self.leaderboard_client.api.getDatetimeAttribute(**params).response().result - return DatetimeAttribute(result.value) + return DatetimeField.from_dict(result) except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) def get_artifact_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> ArtifactAttribute: + ) -> ArtifactField: return get_artifact_attribute( swagger_client=self.leaderboard_client, parent_identifier=container_id, @@ -899,7 +901,7 @@ def list_fileset_files(self, attribute: List[str], container_id: str, path: str) @with_api_exceptions_handler def get_float_series_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> FloatSeriesAttribute: + ) -> FloatSeriesField: params = { "experimentId": container_id, "attribute": path_to_str(path), @@ -907,14 +909,14 @@ def get_float_series_attribute( } try: result = self.leaderboard_client.api.getFloatSeriesAttribute(**params).response().result - return FloatSeriesAttribute(result.last) + return FloatSeriesField.from_dict(result) except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) @with_api_exceptions_handler def get_string_series_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> StringSeriesAttribute: + ) -> StringSeriesField: params = { "experimentId": container_id, "attribute": path_to_str(path), @@ -922,14 +924,14 @@ def get_string_series_attribute( } try: result = self.leaderboard_client.api.getStringSeriesAttribute(**params).response().result - return StringSeriesAttribute(result.last) + return StringSeriesField.from_dict(result) except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) @with_api_exceptions_handler def get_string_set_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> StringSetAttribute: + ) -> StringSetField: params = { "experimentId": container_id, "attribute": path_to_str(path), @@ -937,7 +939,7 @@ def get_string_set_attribute( } try: result = self.leaderboard_client.api.getStringSetAttribute(**params).response().result - return StringSetAttribute(set(result.values)) + return StringSetField.from_dict(result) except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) @@ -1016,7 +1018,7 @@ def get_float_series_values( @with_api_exceptions_handler def fetch_atom_attribute_values( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> List[Tuple[str, AttributeType, Any]]: + ) -> List[Tuple[str, FieldType, Any]]: params = { "experimentId": container_id, } @@ -1081,9 +1083,9 @@ def search_leaderboard_entries( attributes_filter = {"attributeFilters": [{"path": column} for column in columns]} if columns else {} if sort_by == "sys/creation_time": - sort_by_column_type = AttributeType.DATETIME.value + sort_by_column_type = FieldType.DATETIME.value if sort_by == "sys/id": - sort_by_column_type = AttributeType.STRING.value + sort_by_column_type = FieldType.STRING.value else: sort_by_column_type_candidates = self._get_column_types(project_id, sort_by, types_filter) sort_by_column_type = _get_column_type_from_entries(sort_by_column_type_candidates, sort_by) @@ -1147,11 +1149,11 @@ def _get_column_type_from_entries(entries: List[Any], column: str) -> str: ) types.add(entry.type) - if types == {AttributeType.INT.value, AttributeType.FLOAT.value}: - return AttributeType.FLOAT.value + if types == {FieldType.INT.value, FieldType.FLOAT.value}: + return FieldType.FLOAT.value warn_once( f"Column {column} contains more than one simple data type. Sorting result might be inaccurate.", exception=NeptuneWarning, ) - return AttributeType.STRING.value + return FieldType.STRING.value diff --git a/src/neptune/internal/backends/neptune_backend.py b/src/neptune/internal/backends/neptune_backend.py index 23902b55c..51766968b 100644 --- a/src/neptune/internal/backends/neptune_backend.py +++ b/src/neptune/internal/backends/neptune_backend.py @@ -30,25 +30,14 @@ from neptune.internal.artifacts.types import ArtifactFileData from neptune.internal.backends.api_model import ( ApiExperiment, - ArtifactAttribute, - AttributeType, - BoolAttribute, - DatetimeAttribute, - FieldDefinition, - FileAttribute, - FloatAttribute, - FloatSeriesAttribute, FloatSeriesValues, ImageSeriesValues, - IntAttribute, - LeaderboardEntry, Project, - StringAttribute, - StringSeriesAttribute, StringSeriesValues, - StringSetAttribute, Workspace, ) +from neptune.api.models import FloatField, IntField, BoolField, FileField, StringField, DatetimeField, ArtifactField, \ + FloatSeriesField, StringSeriesField, StringSetField, FieldType, FieldDefinition, LeaderboardEntry from neptune.internal.backends.nql import NQLQuery from neptune.internal.container_type import ContainerType from neptune.internal.exceptions import NeptuneException @@ -172,37 +161,37 @@ def download_file_set( pass @abc.abstractmethod - def get_float_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FloatAttribute: + def get_float_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FloatField: pass @abc.abstractmethod - def get_int_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> IntAttribute: + def get_int_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> IntField: pass @abc.abstractmethod - def get_bool_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> BoolAttribute: + def get_bool_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> BoolField: pass @abc.abstractmethod - def get_file_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FileAttribute: + def get_file_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FileField: pass @abc.abstractmethod def get_string_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> StringAttribute: + ) -> StringField: pass @abc.abstractmethod def get_datetime_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> DatetimeAttribute: + ) -> DatetimeField: pass @abc.abstractmethod def get_artifact_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> ArtifactAttribute: + ) -> ArtifactField: pass @abc.abstractmethod @@ -212,19 +201,19 @@ def list_artifact_files(self, project_id: str, artifact_hash: str) -> List[Artif @abc.abstractmethod def get_float_series_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> FloatSeriesAttribute: + ) -> FloatSeriesField: pass @abc.abstractmethod def get_string_series_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> StringSeriesAttribute: + ) -> StringSeriesField: pass @abc.abstractmethod def get_string_set_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> StringSetAttribute: + ) -> StringSetField: pass @abc.abstractmethod @@ -298,7 +287,7 @@ def get_model_version_url( @abc.abstractmethod def fetch_atom_attribute_values( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> List[Tuple[str, AttributeType, Any]]: + ) -> List[Tuple[str, FieldType, Any]]: pass @abc.abstractmethod diff --git a/src/neptune/internal/backends/neptune_backend_mock.py b/src/neptune/internal/backends/neptune_backend_mock.py index 41243314c..f3b0b0729 100644 --- a/src/neptune/internal/backends/neptune_backend_mock.py +++ b/src/neptune/internal/backends/neptune_backend_mock.py @@ -34,7 +34,7 @@ ) from zipfile import ZipFile -from neptune.api.dtos import FileEntry +from neptune.api.models import FileEntry from neptune.core.components.operation_storage import OperationStorage from neptune.exceptions import ( ContainerUUIDNotFound, @@ -46,27 +46,29 @@ from neptune.internal.artifacts.types import ArtifactFileData from neptune.internal.backends.api_model import ( ApiExperiment, - ArtifactAttribute, - AttributeType, - BoolAttribute, - DatetimeAttribute, - FieldDefinition, - FileAttribute, - FloatAttribute, FloatPointValue, - FloatSeriesAttribute, FloatSeriesValues, ImageSeriesValues, - IntAttribute, - LeaderboardEntry, Project, - StringAttribute, StringPointValue, - StringSeriesAttribute, StringSeriesValues, - StringSetAttribute, Workspace, ) +from neptune.api.models import ( + FloatField, + IntField, + BoolField, + FileField, + StringField, + DatetimeField, + ArtifactField, + FloatSeriesField, + StringSeriesField, + StringSetField, + FieldType, + FieldDefinition, + LeaderboardEntry, +) from neptune.internal.backends.hosted_file_operations import get_unique_upload_entries from neptune.internal.backends.neptune_backend import NeptuneBackend from neptune.internal.backends.nql import NQLQuery @@ -370,21 +372,22 @@ def download_file_set( for upload_entry in upload_entries: zipObj.write(upload_entry.source, upload_entry.target_path) - def get_float_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FloatAttribute: + def get_float_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FloatField: val = self._get_attribute(container_id, container_type, path, Float) - return FloatAttribute(val.value) + return FloatField(path=path_to_str(path), value=val.value) - def get_int_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> IntAttribute: + def get_int_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> IntField: val = self._get_attribute(container_id, container_type, path, Integer) - return IntAttribute(val.value) + return IntField(path=path_to_str(path), value=val.value) - def get_bool_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> BoolAttribute: + def get_bool_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> BoolField: val = self._get_attribute(container_id, container_type, path, Boolean) - return BoolAttribute(val.value) + return BoolField(path=path_to_str(path), value=val.value) - def get_file_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FileAttribute: + def get_file_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FileField: val = self._get_attribute(container_id, container_type, path, File) - return FileAttribute( + return FileField( + path=path_to_str(path), name=os.path.basename(val.path) if val.file_type is FileType.LOCAL_FILE else "", ext=val.extension or "", size=0, @@ -392,42 +395,42 @@ def get_file_attribute(self, container_id: str, container_type: ContainerType, p def get_string_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> StringAttribute: + ) -> StringField: val = self._get_attribute(container_id, container_type, path, String) - return StringAttribute(val.value) + return StringField(path=path_to_str(path), value=val.value) def get_datetime_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> DatetimeAttribute: + ) -> DatetimeField: val = self._get_attribute(container_id, container_type, path, Datetime) - return DatetimeAttribute(val.value) + return DatetimeField(path=path_to_str(path), value=val.value) def get_artifact_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> ArtifactAttribute: + ) -> ArtifactField: val = self._get_attribute(container_id, container_type, path, Artifact) - return ArtifactAttribute(val.hash) + return ArtifactField(path=path_to_str(path), hash=val.hash) def list_artifact_files(self, project_id: str, artifact_hash: str) -> List[ArtifactFileData]: return self._artifacts[(project_id, artifact_hash)] def get_float_series_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> FloatSeriesAttribute: + ) -> FloatSeriesField: val = self._get_attribute(container_id, container_type, path, FloatSeries) - return FloatSeriesAttribute(val.values[-1] if val.values else None) + return FloatSeriesField(path=path_to_str(path), last=val.values[-1] if val.values else None) def get_string_series_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> StringSeriesAttribute: + ) -> StringSeriesField: val = self._get_attribute(container_id, container_type, path, StringSeries) - return StringSeriesAttribute(val.values[-1] if val.values else None) + return StringSeriesField(path=path_to_str(path), last=val.values[-1] if val.values else None) def get_string_set_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> StringSetAttribute: + ) -> StringSetField: val = self._get_attribute(container_id, container_type, path, StringSet) - return StringSetAttribute(set(val.values)) + return StringSetField(path=path_to_str(path), values=set(val.values)) def _get_attribute( self, @@ -528,7 +531,7 @@ def _get_attribute_values(self, value_dict, path_prefix: List[str]): def fetch_atom_attribute_values( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> List[Tuple[str, AttributeType, Any]]: + ) -> List[Tuple[str, FieldType, Any]]: run = self._get_container(container_id, container_type) values = self._get_attribute_values(run.get(path), path) namespace_prefix = path_to_str(path) @@ -554,50 +557,50 @@ def search_leaderboard_entries( ) -> Generator[LeaderboardEntry, None, None]: """Non relevant for mock""" - class AttributeTypeConverterValueVisitor(ValueVisitor[AttributeType]): - def visit_float(self, _: Float) -> AttributeType: - return AttributeType.FLOAT + class AttributeTypeConverterValueVisitor(ValueVisitor[FieldType]): + def visit_float(self, _: Float) -> FieldType: + return FieldType.FLOAT - def visit_integer(self, _: Integer) -> AttributeType: - return AttributeType.INT + def visit_integer(self, _: Integer) -> FieldType: + return FieldType.INT - def visit_boolean(self, _: Boolean) -> AttributeType: - return AttributeType.BOOL + def visit_boolean(self, _: Boolean) -> FieldType: + return FieldType.BOOL - def visit_string(self, _: String) -> AttributeType: - return AttributeType.STRING + def visit_string(self, _: String) -> FieldType: + return FieldType.STRING - def visit_datetime(self, _: Datetime) -> AttributeType: - return AttributeType.DATETIME + def visit_datetime(self, _: Datetime) -> FieldType: + return FieldType.DATETIME - def visit_file(self, _: File) -> AttributeType: - return AttributeType.FILE + def visit_file(self, _: File) -> FieldType: + return FieldType.FILE - def visit_file_set(self, _: FileSet) -> AttributeType: - return AttributeType.FILE_SET + def visit_file_set(self, _: FileSet) -> FieldType: + return FieldType.FILE_SET - def visit_float_series(self, _: FloatSeries) -> AttributeType: - return AttributeType.FLOAT_SERIES + def visit_float_series(self, _: FloatSeries) -> FieldType: + return FieldType.FLOAT_SERIES - def visit_string_series(self, _: StringSeries) -> AttributeType: - return AttributeType.STRING_SERIES + def visit_string_series(self, _: StringSeries) -> FieldType: + return FieldType.STRING_SERIES - def visit_image_series(self, _: FileSeries) -> AttributeType: - return AttributeType.IMAGE_SERIES + def visit_image_series(self, _: FileSeries) -> FieldType: + return FieldType.IMAGE_SERIES - def visit_string_set(self, _: StringSet) -> AttributeType: - return AttributeType.STRING_SET + def visit_string_set(self, _: StringSet) -> FieldType: + return FieldType.STRING_SET - def visit_git_ref(self, _: GitRef) -> AttributeType: - return AttributeType.GIT_REF + def visit_git_ref(self, _: GitRef) -> FieldType: + return FieldType.GIT_REF - def visit_artifact(self, _: Artifact) -> AttributeType: - return AttributeType.ARTIFACT + def visit_artifact(self, _: Artifact) -> FieldType: + return FieldType.ARTIFACT - def visit_namespace(self, _: Namespace) -> AttributeType: + def visit_namespace(self, _: Namespace) -> FieldType: raise NotImplementedError - def copy_value(self, source_type: Type[FieldDefinition], source_path: List[str]) -> AttributeType: + def copy_value(self, source_type: Type[FieldDefinition], source_path: List[str]) -> FieldType: raise NotImplementedError class NewValueOpVisitor(OperationVisitor[Optional[Value]]): diff --git a/src/neptune/internal/backends/offline_neptune_backend.py b/src/neptune/internal/backends/offline_neptune_backend.py index dd3b0b092..dda11fe74 100644 --- a/src/neptune/internal/backends/offline_neptune_backend.py +++ b/src/neptune/internal/backends/offline_neptune_backend.py @@ -24,21 +24,12 @@ from neptune.exceptions import NeptuneOfflineModeFetchException from neptune.internal.artifacts.types import ArtifactFileData from neptune.internal.backends.api_model import ( - ArtifactAttribute, - BoolAttribute, - DatetimeAttribute, - FieldDefinition, - FileAttribute, - FloatAttribute, - FloatSeriesAttribute, FloatSeriesValues, ImageSeriesValues, - IntAttribute, - StringAttribute, - StringSeriesAttribute, StringSeriesValues, - StringSetAttribute, ) +from neptune.api.models import FloatField, IntField, BoolField, FileField, StringField, DatetimeField, ArtifactField, \ + FloatSeriesField, StringSeriesField, StringSetField, FieldDefinition from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock from neptune.internal.container_type import ContainerType from neptune.typing import ProgressBarType @@ -50,31 +41,31 @@ class OfflineNeptuneBackend(NeptuneBackendMock): def get_attributes(self, container_id: str, container_type: ContainerType) -> List[FieldDefinition]: raise NeptuneOfflineModeFetchException - def get_float_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FloatAttribute: + def get_float_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FloatField: raise NeptuneOfflineModeFetchException - def get_int_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> IntAttribute: + def get_int_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> IntField: raise NeptuneOfflineModeFetchException - def get_bool_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> BoolAttribute: + def get_bool_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> BoolField: raise NeptuneOfflineModeFetchException - def get_file_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FileAttribute: + def get_file_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FileField: raise NeptuneOfflineModeFetchException def get_string_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> StringAttribute: + ) -> StringField: raise NeptuneOfflineModeFetchException def get_datetime_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> DatetimeAttribute: + ) -> DatetimeField: raise NeptuneOfflineModeFetchException def get_artifact_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> ArtifactAttribute: + ) -> ArtifactField: raise NeptuneOfflineModeFetchException def list_artifact_files(self, project_id: str, artifact_hash: str) -> List[ArtifactFileData]: @@ -82,17 +73,17 @@ def list_artifact_files(self, project_id: str, artifact_hash: str) -> List[Artif def get_float_series_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> FloatSeriesAttribute: + ) -> FloatSeriesField: raise NeptuneOfflineModeFetchException def get_string_series_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> StringSeriesAttribute: + ) -> StringSeriesField: raise NeptuneOfflineModeFetchException def get_string_set_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> StringSetAttribute: + ) -> StringSetField: raise NeptuneOfflineModeFetchException def get_string_series_values( diff --git a/src/neptune/internal/utils/generic_attribute_mapper.py b/src/neptune/internal/utils/generic_attribute_mapper.py index 3f5a29752..da40ba9b2 100644 --- a/src/neptune/internal/utils/generic_attribute_mapper.py +++ b/src/neptune/internal/utils/generic_attribute_mapper.py @@ -15,7 +15,7 @@ # __all__ = ["NoValue", "atomic_attribute_types_map", "map_attribute_result_to_value"] -from neptune.internal.backends.api_model import AttributeType +from neptune.api.models import FieldType class NoValue: @@ -27,30 +27,30 @@ class NoValue: VALUES = "values" atomic_attribute_types_map = { - AttributeType.FLOAT.value: "floatProperties", - AttributeType.INT.value: "intProperties", - AttributeType.BOOL.value: "boolProperties", - AttributeType.STRING.value: "stringProperties", - AttributeType.DATETIME.value: "datetimeProperties", - AttributeType.RUN_STATE.value: "experimentStateProperties", - AttributeType.NOTEBOOK_REF.value: "notebookRefProperties", + FieldType.FLOAT.value: "floatProperties", + FieldType.INT.value: "intProperties", + FieldType.BOOL.value: "boolProperties", + FieldType.STRING.value: "stringProperties", + FieldType.DATETIME.value: "datetimeProperties", + FieldType.OBJECT_STATE.value: "experimentStateProperties", + FieldType.NOTEBOOK_REF.value: "notebookRefProperties", } value_series_attribute_types_map = { - AttributeType.FLOAT_SERIES.value: "floatSeriesProperties", - AttributeType.STRING_SERIES.value: "stringSeriesProperties", + FieldType.FLOAT_SERIES.value: "floatSeriesProperties", + FieldType.STRING_SERIES.value: "stringSeriesProperties", } value_set_attribute_types_map = { - AttributeType.STRING_SET.value: "stringSetProperties", + FieldType.STRING_SET.value: "stringSetProperties", } # TODO: nicer mapping? _unmapped_attribute_types_map = { - AttributeType.FILE_SET.value: "fileSetProperties", # TODO: return size? - AttributeType.FILE.value: "fileProperties", # TODO: name? size? - AttributeType.IMAGE_SERIES.value: "imageSeriesProperties", # TODO: return last step? - AttributeType.GIT_REF.value: "gitRefProperties", # TODO: commit? branch? + FieldType.FILE_SET.value: "fileSetProperties", # TODO: return size? + FieldType.FILE.value: "fileProperties", # TODO: name? size? + FieldType.IMAGE_SERIES.value: "imageSeriesProperties", # TODO: return last step? + FieldType.GIT_REF.value: "gitRefProperties", # TODO: commit? branch? } diff --git a/src/neptune/objects/neptune_object.py b/src/neptune/objects/neptune_object.py index ce4def13c..6f5d05202 100644 --- a/src/neptune/objects/neptune_object.py +++ b/src/neptune/objects/neptune_object.py @@ -51,9 +51,9 @@ from neptune.handler import Handler from neptune.internal.backends.api_model import ( ApiExperiment, - AttributeType, Project, ) +from neptune.api.models import FieldType from neptune.internal.backends.factory import get_backend from neptune.internal.backends.neptune_backend import NeptuneBackend from neptune.internal.backends.nql import NQLQuery @@ -630,7 +630,7 @@ def sync(self, *, wait: bool = True) -> None: for attribute in attributes: self._define_attribute(parse_path(attribute.path), attribute.type) - def _define_attribute(self, _path: List[str], _type: AttributeType): + def _define_attribute(self, _path: List[str], _type: FieldType): attr = create_attribute_from_type(_type, self, _path) self._structure.set(_path, attr) diff --git a/src/neptune/objects/utils.py b/src/neptune/objects/utils.py index c76666139..2dd033a35 100644 --- a/src/neptune/objects/utils.py +++ b/src/neptune/objects/utils.py @@ -27,11 +27,7 @@ Union, ) -from neptune.internal.backends.api_model import ( - AttributeType, - Field, - LeaderboardEntry, -) +from neptune.api.models import Field, FieldType, LeaderboardEntry from neptune.internal.backends.nql import ( NQLAggregator, NQLAttributeOperator, @@ -143,7 +139,7 @@ def parse_dates(leaderboard_entries: Iterable[LeaderboardEntry]) -> Generator[Le def _parse_entry(entry: LeaderboardEntry) -> LeaderboardEntry: try: return LeaderboardEntry( - entry.id, + entry.object_id, fields=[ ( Field( @@ -154,7 +150,7 @@ def _parse_entry(entry: LeaderboardEntry) -> LeaderboardEntry: "value": parse_iso_date(attribute.properties["value"]), }, ) - if attribute.type == AttributeType.DATETIME + if attribute.type == FieldType.DATETIME else attribute ) for attribute in entry.fields diff --git a/src/neptune/table.py b/src/neptune/table.py index 812bd28bb..d9651a125 100644 --- a/src/neptune/table.py +++ b/src/neptune/table.py @@ -20,15 +20,32 @@ Any, Generator, List, - Optional, + Optional, Set, ) +from datetime import datetime from neptune.exceptions import MetadataInconsistency from neptune.integrations.pandas import to_pandas -from neptune.internal.backends.api_model import ( - AttributeType, +from neptune.api.models import ( Field, + FieldType, LeaderboardEntry, + FieldVisitor, + FloatField, + IntField, + BoolField, + StringField, + DatetimeField, + FileField, + FileSetField, + FloatSeriesField, + StringSeriesField, + ImageSeriesField, + StringSetField, + GitRefField, + ObjectStateField, + NotebookRefField, + ArtifactField ) from neptune.internal.backends.neptune_backend import NeptuneBackend from neptune.internal.container_type import ContainerType @@ -47,6 +64,54 @@ logger = get_logger() +class FieldToValueVisitor(FieldVisitor[Any]): + + def visit_float(self, field: FloatField) -> float: + return field.value + + def visit_int(self, field: IntField) -> int: + return field.value + + def visit_bool(self, field: BoolField) -> bool: + return field.value + + def visit_string(self, field: StringField) -> str: + return field.value + + def visit_datetime(self, field: DatetimeField) -> datetime: + ... + + def visit_file(self, field: FileField) -> None: + raise MetadataInconsistency("Cannot get value for file attribute. Use download() instead.") + + def visit_file_set(self, field: FileSetField) -> None: + raise MetadataInconsistency("Cannot get value for file set attribute. Use download() instead.") + + def visit_float_series(self, field: FloatSeriesField) -> Optional[float]: + return field.last + + def visit_string_series(self, field: StringSeriesField) -> Optional[str]: + return field.last + + def visit_image_series(self, field: ImageSeriesField) -> None: + raise MetadataInconsistency("Cannot get value for image series.") + + def visit_string_set(self, field: StringSetField) -> Set[str]: + return field.values + + def visit_git_ref(self, field: GitRefField) -> Optional[str]: + return field.commit_id + + def visit_object_state(self, field: ObjectStateField) -> str: + return RunState.from_api(field.value).value + + def visit_notebook_ref(self, field: NotebookRefField) -> Optional[str]: + return field.notebook_name + + def visit_artifact(self, field: ArtifactField) -> str: + return field.hash + + class TableEntry: def __init__( self, @@ -58,52 +123,23 @@ def __init__( self._backend = backend self._container_type = container_type self._id = _id - self._attributes = attributes + self._fields = attributes + self._field_to_value_visitor = FieldToValueVisitor() def __getitem__(self, path: str) -> "LeaderboardHandler": return LeaderboardHandler(table_entry=self, path=path) - def get_attribute_type(self, path: str) -> AttributeType: - for attr in self._attributes: - if attr.path == path: - return attr.type - raise ValueError("Could not find {} attribute".format(path)) + def get_attribute_type(self, path: str) -> FieldType: + for field in self._fields: + if field.path == path: + return field.type + + raise ValueError(f"Could not find {path} field") def get_attribute_value(self, path: str) -> Any: - for attr in self._attributes: - if attr.path == path: - _type = attr.type - if _type == AttributeType.RUN_STATE: - return RunState.from_api(attr.properties.get("value")).value - if _type in ( - AttributeType.FLOAT, - AttributeType.INT, - AttributeType.BOOL, - AttributeType.STRING, - AttributeType.DATETIME, - ): - return attr.properties.get("value") - if _type == AttributeType.FLOAT_SERIES or _type == AttributeType.STRING_SERIES: - return attr.properties.get("last") - if _type == AttributeType.IMAGE_SERIES: - raise MetadataInconsistency("Cannot get value for image series.") - if _type == AttributeType.FILE: - raise MetadataInconsistency("Cannot get value for file attribute. Use download() instead.") - if _type == AttributeType.FILE_SET: - raise MetadataInconsistency("Cannot get value for file set attribute. Use download() instead.") - if _type == AttributeType.STRING_SET: - return set(attr.properties.get("values")) - if _type == AttributeType.GIT_REF: - return attr.properties.get("commit", {}).get("commitId") - if _type == AttributeType.NOTEBOOK_REF: - return attr.properties.get("notebookName") - if _type == AttributeType.ARTIFACT: - return attr.properties.get("hash") - logger.error( - "Attribute type %s not supported in this version, yielding None. Recommended client upgrade.", - _type, - ) - return None + for field in self._fields: + if field.path == path: + return self._field_to_value_visitor.visit(field) raise ValueError("Could not find {} attribute".format(path)) def download_file_attribute( @@ -112,10 +148,10 @@ def download_file_attribute( destination: Optional[str], progress_bar: Optional[ProgressBarType] = None, ) -> None: - for attr in self._attributes: + for attr in self._fields: if attr.path == path: _type = attr.type - if _type == AttributeType.FILE: + if _type == FieldType.FILE: self._backend.download_file( container_id=self._id, container_type=self._container_type, @@ -133,10 +169,10 @@ def download_file_set_attribute( destination: Optional[str], progress_bar: Optional[ProgressBarType] = None, ) -> None: - for attr in self._attributes: + for attr in self._fields: if attr.path == path: _type = attr.type - if _type == AttributeType.FILE_SET: + if _type == FieldType.FILE_SET: self._backend.download_file_set( container_id=self._id, container_type=self._container_type, @@ -162,9 +198,9 @@ def get(self) -> Any: def download(self, destination: Optional[str]) -> None: attr_type = self._table_entry.get_attribute_type(self._path) - if attr_type == AttributeType.FILE: + if attr_type == FieldType.FILE: return self._table_entry.download_file_attribute(self._path, destination) - elif attr_type == AttributeType.FILE_SET: + elif attr_type == FieldType.FILE_SET: return self._table_entry.download_file_set_attribute(path=self._path, destination=destination) raise MetadataInconsistency("Cannot download file from attribute of type {}".format(attr_type)) @@ -193,7 +229,7 @@ def __next__(self) -> TableEntry: return TableEntry( backend=self._backend, container_type=self._container_type, - _id=entry.id, + _id=entry.object_id, attributes=entry.fields, ) diff --git a/tests/unit/neptune/new/api/test_searching_entries.py b/tests/unit/neptune/new/api/test_searching_entries.py index d5fff2c9c..4054337c1 100644 --- a/tests/unit/neptune/new/api/test_searching_entries.py +++ b/tests/unit/neptune/new/api/test_searching_entries.py @@ -32,11 +32,7 @@ to_leaderboard_entry, ) from neptune.exceptions import NeptuneInvalidQueryException -from neptune.internal.backends.api_model import ( - AttributeType, - Field, - LeaderboardEntry, -) +from neptune.api.models import Field, StringField, FieldType, LeaderboardEntry def test__to_leaderboard_entry(): @@ -65,18 +61,18 @@ def test__to_leaderboard_entry(): result = to_leaderboard_entry(entry=entry) # then - assert result.id == "foo" + assert result.object_id == "foo" assert result.fields == [ Field( path="plugh", - type=AttributeType.FLOAT, + type=FieldType.FLOAT, properties={ "value": 1.0, }, ), Field( path="sys/id", - type=AttributeType.STRING, + type=FieldType.STRING, properties={ "value": "TEST-123", }, @@ -254,8 +250,8 @@ def test__iter_over_pages__limit(get_single_page, entries_from_page): def generate_leaderboard_entries(values: Sequence, experiment_id: str = "foo") -> List[LeaderboardEntry]: return [ LeaderboardEntry( - id=experiment_id, - fields=[Field(path="sys/id", type=AttributeType.STRING, properties={"value": value})], + object_id=experiment_id, + fields=[StringField(path="sys/id", value=value)], ) for value in values ] diff --git a/tests/unit/neptune/new/attributes/test_attribute_utils.py b/tests/unit/neptune/new/attributes/test_attribute_utils.py index 4bbfc7735..3e0456d32 100644 --- a/tests/unit/neptune/new/attributes/test_attribute_utils.py +++ b/tests/unit/neptune/new/attributes/test_attribute_utils.py @@ -18,7 +18,7 @@ from neptune.attributes import create_attribute_from_type from neptune.attributes.attribute import Attribute -from neptune.internal.backends.api_model import AttributeType +from neptune.api.models import FieldType class TestAttributeUtils(unittest.TestCase): @@ -28,6 +28,6 @@ def test_attribute_type_to_atom(self): self.assertTrue( all( isinstance(create_attribute_from_type(attr_type, MagicMock(), ""), Attribute) - for attr_type in AttributeType + for attr_type in FieldType ) ) diff --git a/tests/unit/neptune/new/client/abstract_tables_test.py b/tests/unit/neptune/new/client/abstract_tables_test.py index e413c1369..1c3b201a1 100644 --- a/tests/unit/neptune/new/client/abstract_tables_test.py +++ b/tests/unit/neptune/new/client/abstract_tables_test.py @@ -28,12 +28,7 @@ PROJECT_ENV_NAME, ) from neptune.exceptions import MetadataInconsistency -from neptune.internal.backends.api_model import ( - AttributeType, - Field, - FieldDefinition, - LeaderboardEntry, -) +from neptune.api.models import Field, FieldType, FieldDefinition, LeaderboardEntry from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock from neptune.table import ( Table, @@ -43,7 +38,7 @@ @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_attributes", - new=lambda _, _uuid, _type: [FieldDefinition(path="test", type=AttributeType.STRING)], + new=lambda _, _uuid, _type: [FieldDefinition(path="test", type=FieldType.STRING)], ) @patch("neptune.internal.backends.factory.HostedNeptuneBackend", NeptuneBackendMock) class AbstractTablesTestMixin: @@ -69,23 +64,23 @@ def setUp(cls) -> None: @staticmethod def build_attributes_leaderboard(now: datetime): attributes = [] - attributes.append(Field("run/state", AttributeType.RUN_STATE, {"value": "idle"})) - attributes.append(Field("float", AttributeType.FLOAT, {"value": 12.5})) - attributes.append(Field("string", AttributeType.STRING, {"value": "some text"})) - attributes.append(Field("datetime", AttributeType.DATETIME, {"value": now})) - attributes.append(Field("float/series", AttributeType.FLOAT_SERIES, {"last": 8.7})) - attributes.append(Field("string/series", AttributeType.STRING_SERIES, {"last": "last text"})) - attributes.append(Field("string/set", AttributeType.STRING_SET, {"values": ["a", "b"]})) + attributes.append(Field("run/state", FieldType.OBJECT_STATE, {"value": "idle"})) + attributes.append(Field("float", FieldType.FLOAT, {"value": 12.5})) + attributes.append(Field("string", FieldType.STRING, {"value": "some text"})) + attributes.append(Field("datetime", FieldType.DATETIME, {"value": now})) + attributes.append(Field("float/series", FieldType.FLOAT_SERIES, {"last": 8.7})) + attributes.append(Field("string/series", FieldType.STRING_SERIES, {"last": "last text"})) + attributes.append(Field("string/set", FieldType.STRING_SET, {"values": ["a", "b"]})) attributes.append( Field( "git/ref", - AttributeType.GIT_REF, + FieldType.GIT_REF, {"commit": {"commitId": "abcdef0123456789"}}, ) ) - attributes.append(Field("file", AttributeType.FILE, None)) - attributes.append(Field("file/set", AttributeType.FILE_SET, None)) - attributes.append(Field("image/series", AttributeType.IMAGE_SERIES, None)) + attributes.append(Field("file", FieldType.FILE, None)) + attributes.append(Field("file/set", FieldType.FILE_SET, None)) + attributes.append(Field("image/series", FieldType.IMAGE_SERIES, None)) return attributes @patch.object(NeptuneBackendMock, "search_leaderboard_entries") diff --git a/tests/unit/neptune/new/client/test_model.py b/tests/unit/neptune/new/client/test_model.py index 8013b99ae..213a1e30b 100644 --- a/tests/unit/neptune/new/client/test_model.py +++ b/tests/unit/neptune/new/client/test_model.py @@ -32,11 +32,7 @@ NeptuneUnsupportedFunctionalityException, NeptuneWrongInitParametersException, ) -from neptune.internal.backends.api_model import ( - AttributeType, - FieldDefinition, - IntAttribute, -) +from neptune.api.models import IntField, FieldType, FieldDefinition from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock from neptune.internal.exceptions import NeptuneException from neptune.internal.warnings import ( @@ -77,11 +73,11 @@ def test_offline_mode(self): ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_attributes", - new=lambda _, _uuid, _type: [FieldDefinition("some/variable", AttributeType.INT)], + new=lambda _, _uuid, _type: [FieldDefinition("some/variable", FieldType.INT)], ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_int_attribute", - new=lambda _, _uuid, _type, _path: IntAttribute(42), + new=lambda _, _uuid, _type, _path: IntField(42), ) @patch("neptune.internal.operation_processors.read_only_operation_processor.warn_once") def test_read_only_mode(self, warn_once): @@ -102,7 +98,7 @@ def test_read_only_mode(self, warn_once): ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_attributes", - new=lambda _, _uuid, _type: [FieldDefinition("test", AttributeType.STRING)], + new=lambda _, _uuid, _type: [FieldDefinition("test", FieldType.STRING)], ) def test_resume(self): with init_model(flush_period=0.5, with_id="whatever") as exp: diff --git a/tests/unit/neptune/new/client/test_model_version.py b/tests/unit/neptune/new/client/test_model_version.py index 787189b3e..a37ba7a03 100644 --- a/tests/unit/neptune/new/client/test_model_version.py +++ b/tests/unit/neptune/new/client/test_model_version.py @@ -33,12 +33,7 @@ NeptuneUnsupportedFunctionalityException, NeptuneWrongInitParametersException, ) -from neptune.internal.backends.api_model import ( - AttributeType, - FieldDefinition, - IntAttribute, - StringAttribute, -) +from neptune.api.models import IntField, StringField, FieldType, FieldDefinition from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock from neptune.internal.container_type import ContainerType from neptune.internal.exceptions import NeptuneException @@ -87,17 +82,17 @@ def test_offline_mode(self): @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_attributes", new=lambda _, _uuid, _type: [ - FieldDefinition("some/variable", AttributeType.INT), - FieldDefinition("sys/model_id", AttributeType.STRING), + FieldDefinition("some/variable", FieldType.INT), + FieldDefinition("sys/model_id", FieldType.STRING), ], ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_int_attribute", - new=lambda _, _uuid, _type, _path: IntAttribute(42), + new=lambda _, _uuid, _type, _path: IntField(42), ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_string_attribute", - new=lambda _, _uuid, _type, _path: StringAttribute("MDL"), + new=lambda _, _uuid, _type, _path: StringField("MDL"), ) @patch("neptune.internal.operation_processors.read_only_operation_processor.warn_once") def test_read_only_mode(self, warn_once): @@ -115,13 +110,13 @@ def test_read_only_mode(self, warn_once): @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_attributes", new=lambda _, _uuid, _type: [ - FieldDefinition("test", AttributeType.STRING), - FieldDefinition("sys/model_id", AttributeType.STRING), + FieldDefinition("test", FieldType.STRING), + FieldDefinition("sys/model_id", FieldType.STRING), ], ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_string_attribute", - new=lambda _, _uuid, _type, _path: StringAttribute("MDL"), + new=lambda _, _uuid, _type, _path: StringField("MDL"), ) def test_resume(self): with init_model_version(flush_period=0.5, with_id="whatever") as exp: diff --git a/tests/unit/neptune/new/client/test_project.py b/tests/unit/neptune/new/client/test_project.py index 1176462ea..800216dd3 100644 --- a/tests/unit/neptune/new/client/test_project.py +++ b/tests/unit/neptune/new/client/test_project.py @@ -32,13 +32,7 @@ NeptuneMissingProjectNameException, NeptuneUnsupportedFunctionalityException, ) -from neptune.internal.backends.api_model import ( - AttributeType, - Field, - FieldDefinition, - IntAttribute, - LeaderboardEntry, -) +from neptune.api.models import Field, IntField, FieldType, FieldDefinition, LeaderboardEntry from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock from neptune.internal.exceptions import NeptuneException from neptune.internal.warnings import ( @@ -54,7 +48,7 @@ @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_attributes", - new=lambda _, _uuid, _type: [FieldDefinition("test", AttributeType.STRING)], + new=lambda _, _uuid, _type: [FieldDefinition("test", FieldType.STRING)], ) @patch("neptune.internal.backends.factory.HostedNeptuneBackend", NeptuneBackendMock) class TestClientProject(AbstractExperimentTestMixin, unittest.TestCase): @@ -100,7 +94,7 @@ def test_project_name_env_var(self): @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_int_attribute", - new=lambda _, _uuid, _type, _path: IntAttribute(42), + new=lambda _, _uuid, _type, _path: IntField(42), ) @patch("neptune.internal.operation_processors.read_only_operation_processor.warn_once") def test_read_only_mode(self, warn_once): @@ -173,16 +167,16 @@ def test_prepare_nql_query(): def test_parse_dates(): def entries_generator(): yield LeaderboardEntry( - id="test", + object_id="test", fields=[ Field( "attr1", - AttributeType.DATETIME, + FieldType.DATETIME, {"value": "2024-02-05T20:37:40.915000Z"}, ), Field( "attr2", - AttributeType.DATETIME, + FieldType.DATETIME, {"value": "2024-02-05T20:37:40.915000Z"}, ), ], @@ -197,11 +191,11 @@ def entries_generator(): def test_parse_dates_wrong_format(mock_warn_once): entries = [ LeaderboardEntry( - id="test", + object_id="test", fields=[ Field( "attr1", - AttributeType.DATETIME, + FieldType.DATETIME, {"value": "07-02-2024"}, # different format than expected ) ], diff --git a/tests/unit/neptune/new/client/test_run.py b/tests/unit/neptune/new/client/test_run.py index dd9d510b9..cdb6aa198 100644 --- a/tests/unit/neptune/new/client/test_run.py +++ b/tests/unit/neptune/new/client/test_run.py @@ -33,11 +33,7 @@ PROJECT_ENV_NAME, ) from neptune.exceptions import MissingFieldException -from neptune.internal.backends.api_model import ( - AttributeType, - FieldDefinition, - IntAttribute, -) +from neptune.api.models import IntField, FieldType, FieldDefinition from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock from neptune.internal.utils.utils import IS_WINDOWS from neptune.internal.warnings import ( @@ -68,11 +64,11 @@ def setUpClass(cls) -> None: ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_attributes", - new=lambda _, _uuid, _type: [FieldDefinition("some/variable", AttributeType.INT)], + new=lambda _, _uuid, _type: [FieldDefinition("some/variable", FieldType.INT)], ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_int_attribute", - new=lambda _, _uuid, _type, _path: IntAttribute(42), + new=lambda _, _uuid, _type, _path: IntField(42), ) @patch("neptune.internal.operation_processors.read_only_operation_processor.warn_once") def test_read_only_mode(self, warn_once): @@ -93,7 +89,7 @@ def test_read_only_mode(self, warn_once): ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_attributes", - new=lambda _, _uuid, _type: [FieldDefinition("test", AttributeType.STRING)], + new=lambda _, _uuid, _type: [FieldDefinition("test", FieldType.STRING)], ) def test_resume(self): with init_run(flush_period=0.5, with_id="whatever") as exp: diff --git a/tests/unit/neptune/new/client/test_run_tables.py b/tests/unit/neptune/new/client/test_run_tables.py index b626cc0b1..b53616ceb 100644 --- a/tests/unit/neptune/new/client/test_run_tables.py +++ b/tests/unit/neptune/new/client/test_run_tables.py @@ -21,11 +21,7 @@ from mock import patch from neptune import init_project -from neptune.internal.backends.api_model import ( - AttributeType, - Field, - LeaderboardEntry, -) +from neptune.api.models import Field, FieldType, LeaderboardEntry from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock from neptune.internal.container_type import ContainerType from neptune.table import ( @@ -66,11 +62,11 @@ def test_fetch_runs_table_raises_correct_exception_for_incorrect_states(self): "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.search_leaderboard_entries", new=lambda *args, **kwargs: [ LeaderboardEntry( - id="123", + object_id="123", fields=[ Field( "sys/creation_time", - AttributeType.DATETIME, + FieldType.DATETIME, {"value": "2024-02-05T20:37:40.915000Z"}, ) ], diff --git a/tests/unit/neptune/new/internal/backends/test_hosted_client.py b/tests/unit/neptune/new/internal/backends/test_hosted_client.py index aefe62b1a..ba77f1e72 100644 --- a/tests/unit/neptune/new/internal/backends/test_hosted_client.py +++ b/tests/unit/neptune/new/internal/backends/test_hosted_client.py @@ -32,7 +32,7 @@ patch, ) -from neptune.internal.backends.api_model import AttributeType +from neptune.api.models import FieldType from neptune.internal.backends.hosted_client import ( DEFAULT_REQUEST_KWARGS, _get_token_client, @@ -545,15 +545,15 @@ class DTO: # when test_cases = [ {"entries": [], "exc": ValueError}, - {"entries": [DTO(type="float")], "result": AttributeType.FLOAT.value}, - {"entries": [DTO(type="string")], "result": AttributeType.STRING.value}, + {"entries": [DTO(type="float")], "result": FieldType.FLOAT.value}, + {"entries": [DTO(type="string")], "result": FieldType.STRING.value}, {"entries": [DTO(type="float"), DTO(type="floatSeries")], "exc": ValueError}, - {"entries": [DTO(type="float"), DTO(type="int")], "result": AttributeType.FLOAT.value}, - {"entries": [DTO(type="float"), DTO(type="int"), DTO(type="datetime")], "result": AttributeType.STRING.value}, - {"entries": [DTO(type="float"), DTO(type="int"), DTO(type="string")], "result": AttributeType.STRING.value}, + {"entries": [DTO(type="float"), DTO(type="int")], "result": FieldType.FLOAT.value}, + {"entries": [DTO(type="float"), DTO(type="int"), DTO(type="datetime")], "result": FieldType.STRING.value}, + {"entries": [DTO(type="float"), DTO(type="int"), DTO(type="string")], "result": FieldType.STRING.value}, { "entries": [DTO(type="float"), DTO(type="int"), DTO(type="string", name="test_column_different")], - "result": AttributeType.FLOAT.value, + "result": FieldType.FLOAT.value, }, ] diff --git a/tests/unit/neptune/new/internal/backends/test_neptune_backend_mock.py b/tests/unit/neptune/new/internal/backends/test_neptune_backend_mock.py index a83f7da84..eca5d3343 100644 --- a/tests/unit/neptune/new/internal/backends/test_neptune_backend_mock.py +++ b/tests/unit/neptune/new/internal/backends/test_neptune_backend_mock.py @@ -27,17 +27,13 @@ MetadataInconsistency, ) from neptune.internal.backends.api_model import ( - DatetimeAttribute, - FloatAttribute, FloatPointValue, - FloatSeriesAttribute, FloatSeriesValues, - StringAttribute, StringPointValue, - StringSeriesAttribute, StringSeriesValues, - StringSetAttribute, ) +from neptune.api.models import FloatField, StringField, DatetimeField, FloatSeriesField, StringSeriesField, \ + StringSetField from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock from neptune.internal.container_type import ContainerType from neptune.internal.operation import ( @@ -78,36 +74,38 @@ def test_get_float_attribute(self): with self.subTest(f"For containerType: {container_type}"): # given digit = random.randint(1, 10**4) + path = ["x"] self.backend.execute_operations( container_id, container_type, - operations=[AssignFloat(["x"], digit)], + operations=[AssignFloat(path, digit)], operation_storage=self.dummy_operation_storage, ) # when - ret = self.backend.get_float_attribute(container_id, container_type, path=["x"]) + ret = self.backend.get_float_attribute(container_id, container_type, path=path) # then - self.assertEqual(FloatAttribute(digit), ret) + self.assertEqual(FloatField(path="x", value=digit), ret) def test_get_string_attribute(self): for container_id, container_type in self.ids_with_types: with self.subTest(f"For containerType: {container_type}"): # given text = a_string() + path = ["x"] self.backend.execute_operations( container_id, container_type, - operations=[AssignString(["x"], text)], + operations=[AssignString(path, text)], operation_storage=self.dummy_operation_storage, ) # when - ret = self.backend.get_string_attribute(container_id, container_type, path=["x"]) + ret = self.backend.get_string_attribute(container_id, container_type, path=path) # then - self.assertEqual(StringAttribute(text), ret) + self.assertEqual(StringField(path="x", value=text), ret) def test_get_datetime_attribute(self): for container_id, container_type in self.ids_with_types: @@ -115,18 +113,21 @@ def test_get_datetime_attribute(self): # given now = datetime.datetime.now() now = now.replace(microsecond=1000 * int(now.microsecond / 1000)) + path = ["x"] + + # and self.backend.execute_operations( container_id, container_type, - [AssignDatetime(["x"], now)], + [AssignDatetime(path, now)], operation_storage=self.dummy_operation_storage, ) # when - ret = self.backend.get_datetime_attribute(container_id, container_type, ["x"]) + ret = self.backend.get_datetime_attribute(container_id, container_type, path) # then - self.assertEqual(DatetimeAttribute(now), ret) + self.assertEqual(DatetimeField(path="x", value=now), ret) def test_get_float_series_attribute(self): # given @@ -165,7 +166,7 @@ def test_get_float_series_attribute(self): ret = self.backend.get_float_series_attribute(container_id, container_type, ["x"]) # then - self.assertEqual(FloatSeriesAttribute(9), ret) + self.assertEqual(FloatSeriesField(9), ret) def test_get_string_series_attribute(self): # given @@ -204,7 +205,7 @@ def test_get_string_series_attribute(self): ret = self.backend.get_string_series_attribute(container_id, container_type, ["x"]) # then - self.assertEqual(StringSeriesAttribute("qwe"), ret) + self.assertEqual(StringSeriesField("qwe"), ret) def test_get_string_set_attribute(self): # given @@ -221,7 +222,7 @@ def test_get_string_set_attribute(self): ret = self.backend.get_string_set_attribute(container_id, container_type, ["x"]) # then - self.assertEqual(StringSetAttribute({"abcx", "qwe"}), ret) + self.assertEqual(StringSetField({"abcx", "qwe"}), ret) def test_get_string_series_values(self): # given From f10a79c7889b6e820cc40c0dbce2c66223160033 Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Thu, 4 Apr 2024 13:56:15 +0200 Subject: [PATCH 05/22] Fixes --- .../backends/hosted_artifact_operations.py | 2 +- .../internal/backends/neptune_backend.py | 17 +++++++++++++++-- .../backends/offline_neptune_backend.py | 15 +++++++++++++-- 3 files changed, 29 insertions(+), 5 deletions(-) diff --git a/src/neptune/internal/backends/hosted_artifact_operations.py b/src/neptune/internal/backends/hosted_artifact_operations.py index 393ce3aa2..f6dda166a 100644 --- a/src/neptune/internal/backends/hosted_artifact_operations.py +++ b/src/neptune/internal/backends/hosted_artifact_operations.py @@ -263,7 +263,7 @@ def get_artifact_attribute( } try: result = swagger_client.api.getArtifactAttribute(**params).response().result - return ArtifactField(hash=result.hash) + return ArtifactField(path=path_to_str(path), hash=result.hash) except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) diff --git a/src/neptune/internal/backends/neptune_backend.py b/src/neptune/internal/backends/neptune_backend.py index 51766968b..4b15e5cd1 100644 --- a/src/neptune/internal/backends/neptune_backend.py +++ b/src/neptune/internal/backends/neptune_backend.py @@ -36,8 +36,21 @@ StringSeriesValues, Workspace, ) -from neptune.api.models import FloatField, IntField, BoolField, FileField, StringField, DatetimeField, ArtifactField, \ - FloatSeriesField, StringSeriesField, StringSetField, FieldType, FieldDefinition, LeaderboardEntry +from neptune.api.models import ( + FloatField, + IntField, + BoolField, + FileField, + StringField, + DatetimeField, + ArtifactField, + FloatSeriesField, + StringSeriesField, + StringSetField, + FieldType, + FieldDefinition, + LeaderboardEntry, +) from neptune.internal.backends.nql import NQLQuery from neptune.internal.container_type import ContainerType from neptune.internal.exceptions import NeptuneException diff --git a/src/neptune/internal/backends/offline_neptune_backend.py b/src/neptune/internal/backends/offline_neptune_backend.py index dda11fe74..435aa87a1 100644 --- a/src/neptune/internal/backends/offline_neptune_backend.py +++ b/src/neptune/internal/backends/offline_neptune_backend.py @@ -28,8 +28,19 @@ ImageSeriesValues, StringSeriesValues, ) -from neptune.api.models import FloatField, IntField, BoolField, FileField, StringField, DatetimeField, ArtifactField, \ - FloatSeriesField, StringSeriesField, StringSetField, FieldDefinition +from neptune.api.models import ( + FloatField, + IntField, + BoolField, + FileField, + StringField, + DatetimeField, + ArtifactField, + FloatSeriesField, + StringSeriesField, + StringSetField, + FieldDefinition, +) from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock from neptune.internal.container_type import ContainerType from neptune.typing import ProgressBarType From 675e4672d488999781c99d8520fbe1e866917848 Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Thu, 4 Apr 2024 14:32:21 +0200 Subject: [PATCH 06/22] AttributesDTO conversion --- src/neptune/api/field_visitor.py | 92 ++++++++++++++++++++++++++++ src/neptune/api/models.py | 14 +++-- src/neptune/api/searching_entries.py | 35 +++-------- src/neptune/table.py | 69 +-------------------- 4 files changed, 112 insertions(+), 98 deletions(-) create mode 100644 src/neptune/api/field_visitor.py diff --git a/src/neptune/api/field_visitor.py b/src/neptune/api/field_visitor.py new file mode 100644 index 000000000..0f2054405 --- /dev/null +++ b/src/neptune/api/field_visitor.py @@ -0,0 +1,92 @@ +# +# Copyright (c) 2024, Neptune Labs Sp. z o.o. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +__all__ = ('FieldToValueVisitor',) + +from typing import ( + Any, + Optional, + Set, +) +from datetime import datetime + +from neptune.exceptions import MetadataInconsistency +from neptune.api.models import ( + FieldVisitor, + FloatField, + IntField, + BoolField, + StringField, + DatetimeField, + FileField, + FileSetField, + FloatSeriesField, + StringSeriesField, + ImageSeriesField, + StringSetField, + GitRefField, + ObjectStateField, + NotebookRefField, + ArtifactField +) +from neptune.internal.utils.run_state import RunState + + +class FieldToValueVisitor(FieldVisitor[Any]): + + def visit_float(self, field: FloatField) -> float: + return field.value + + def visit_int(self, field: IntField) -> int: + return field.value + + def visit_bool(self, field: BoolField) -> bool: + return field.value + + def visit_string(self, field: StringField) -> str: + return field.value + + def visit_datetime(self, field: DatetimeField) -> datetime: + ... + + def visit_file(self, field: FileField) -> None: + raise MetadataInconsistency("Cannot get value for file attribute. Use download() instead.") + + def visit_file_set(self, field: FileSetField) -> None: + raise MetadataInconsistency("Cannot get value for file set attribute. Use download() instead.") + + def visit_float_series(self, field: FloatSeriesField) -> Optional[float]: + return field.last + + def visit_string_series(self, field: StringSeriesField) -> Optional[str]: + return field.last + + def visit_image_series(self, field: ImageSeriesField) -> None: + raise MetadataInconsistency("Cannot get value for image series.") + + def visit_string_set(self, field: StringSetField) -> Set[str]: + return field.values + + def visit_git_ref(self, field: GitRefField) -> Optional[str]: + return field.commit_id + + def visit_object_state(self, field: ObjectStateField) -> str: + return RunState.from_api(field.value).value + + def visit_notebook_ref(self, field: NotebookRefField) -> Optional[str]: + return field.notebook_name + + def visit_artifact(self, field: ArtifactField) -> str: + return field.hash diff --git a/src/neptune/api/models.py b/src/neptune/api/models.py index 30998ec09..ee3c3fc73 100644 --- a/src/neptune/api/models.py +++ b/src/neptune/api/models.py @@ -45,7 +45,7 @@ dataclass, field as dataclass_field, ) -from typing import TypeVar, Generic, Dict +from typing import TypeVar, Generic, Dict, Type, ClassVar from datetime import datetime from enum import Enum from typing import ( @@ -92,11 +92,14 @@ class FieldType(Enum): class Field(abc.ABC): path: str type: FieldType = dataclass_field(init=False, default=None) + _registry: ClassVar[Dict[str, Type[Field]]] = {t.value: {} for t in FieldType} def __init_subclass__(cls, **kwargs) -> None: super().__init_subclass__(**kwargs) - # TODO: remove this when we have proper type hints - cls.type = kwargs.get('type', None) + field_type: Optional[FieldType] = kwargs.get('type', None) + if field_type is not None: + cls.type = field_type + cls._registry[field_type.value] = cls @abc.abstractmethod def accept(self, visitor: FieldVisitor[Ret]) -> Ret: @@ -104,7 +107,8 @@ def accept(self, visitor: FieldVisitor[Ret]) -> Ret: @staticmethod def from_dict(field: Dict[str, Any]) -> Field: - raise NotImplementedError() + field_type = field["type"].value + return Field._registry[field_type].from_dict(field[f"{field_type}Properties"]) class FieldVisitor(Generic[Ret], abc.ABC): @@ -360,7 +364,7 @@ class LeaderboardEntry: def from_dict(data: Dict[str, Any]) -> LeaderboardEntry: return LeaderboardEntry( object_id=data["experimentId"], - fields=[] # TODO: map fields + fields=[Field.from_dict(field) for field in data["attributes"]] ) diff --git a/src/neptune/api/searching_entries.py b/src/neptune/api/searching_entries.py index b12b0a4be..3bf730af2 100644 --- a/src/neptune/api/searching_entries.py +++ b/src/neptune/api/searching_entries.py @@ -21,7 +21,6 @@ Dict, Generator, Iterable, - List, Optional, ) @@ -153,20 +152,6 @@ def get_single_page( raise e -def to_leaderboard_entry(entry: Dict[str, Any]) -> LeaderboardEntry: - # return LeaderboardEntry( - # fields=[ - # Field( - # path=attr["name"], - # type=FieldType(attr["type"]), - # properties=attr.__getitem__(f"{attr['type']}Properties"), - # ) - # for attr in entry["attributes"] - # if attr["type"] in SUPPORTED_ATTRIBUTE_TYPES - # ], - # ) - - def find_attribute(*, entry: LeaderboardEntry, path: str) -> Optional[Field]: return next((attr for attr in entry.fields if attr.path == path), None) @@ -186,7 +171,7 @@ def iter_over_pages( last_page = None # TODO: Refactor - total = get_single_page( + data = get_single_page( limit=0, offset=0, sort_by=sort_by, @@ -194,7 +179,8 @@ def iter_over_pages( sort_by_column_type=sort_by_column_type, searching_after=None, **kwargs, - ).get("matchingItemCount", 0) + ) + total = LeaderboardEntriesSearchResult.from_dict(data).matching_item_count limit = limit if limit is not None else NoLimit() @@ -204,6 +190,8 @@ def iter_over_pages( extracted_records = 0 + field_to_value_visitor = FieldToValueVisitor() + with construct_progress_bar(progress_bar, "Fetching table...") as bar: # beginning of the first page bar.update( @@ -226,7 +214,7 @@ def iter_over_pages( if extracted_records + local_limit > limit: local_limit = limit - extracted_records - result = get_single_page( + data = get_single_page( limit=local_limit, offset=offset, sort_by=sort_by, @@ -235,15 +223,15 @@ def iter_over_pages( ascending=ascending, **kwargs, ) + result = LeaderboardEntriesSearchResult.from_dict(data) # fetch the item count everytime a new page is started (except for the very fist page) if offset == 0 and last_page is not None: - # TODO: Refactor - total += result.get("matchingItemCount", 0) + total += result.matching_item_count total = min(total, limit) - page = _entries_from_page(result) + page = result.entries extracted_records += len(page) bar.update(by=len(page), total=total) @@ -256,8 +244,3 @@ def iter_over_pages( return last_page = page - - -# TODO: Refactor -def _entries_from_page(single_page: Dict[str, Any]) -> List[LeaderboardEntry]: - return LeaderboardEntriesSearchResult.from_dict(single_page).entries diff --git a/src/neptune/table.py b/src/neptune/table.py index d9651a125..43f0030f8 100644 --- a/src/neptune/table.py +++ b/src/neptune/table.py @@ -20,32 +20,16 @@ Any, Generator, List, - Optional, Set, + Optional, ) -from datetime import datetime +from neptune.api.field_visitor import FieldToValueVisitor from neptune.exceptions import MetadataInconsistency from neptune.integrations.pandas import to_pandas from neptune.api.models import ( Field, FieldType, LeaderboardEntry, - FieldVisitor, - FloatField, - IntField, - BoolField, - StringField, - DatetimeField, - FileField, - FileSetField, - FloatSeriesField, - StringSeriesField, - ImageSeriesField, - StringSetField, - GitRefField, - ObjectStateField, - NotebookRefField, - ArtifactField ) from neptune.internal.backends.neptune_backend import NeptuneBackend from neptune.internal.container_type import ContainerType @@ -54,7 +38,6 @@ join_paths, parse_path, ) -from neptune.internal.utils.run_state import RunState from neptune.typing import ProgressBarType if TYPE_CHECKING: @@ -64,54 +47,6 @@ logger = get_logger() -class FieldToValueVisitor(FieldVisitor[Any]): - - def visit_float(self, field: FloatField) -> float: - return field.value - - def visit_int(self, field: IntField) -> int: - return field.value - - def visit_bool(self, field: BoolField) -> bool: - return field.value - - def visit_string(self, field: StringField) -> str: - return field.value - - def visit_datetime(self, field: DatetimeField) -> datetime: - ... - - def visit_file(self, field: FileField) -> None: - raise MetadataInconsistency("Cannot get value for file attribute. Use download() instead.") - - def visit_file_set(self, field: FileSetField) -> None: - raise MetadataInconsistency("Cannot get value for file set attribute. Use download() instead.") - - def visit_float_series(self, field: FloatSeriesField) -> Optional[float]: - return field.last - - def visit_string_series(self, field: StringSeriesField) -> Optional[str]: - return field.last - - def visit_image_series(self, field: ImageSeriesField) -> None: - raise MetadataInconsistency("Cannot get value for image series.") - - def visit_string_set(self, field: StringSetField) -> Set[str]: - return field.values - - def visit_git_ref(self, field: GitRefField) -> Optional[str]: - return field.commit_id - - def visit_object_state(self, field: ObjectStateField) -> str: - return RunState.from_api(field.value).value - - def visit_notebook_ref(self, field: NotebookRefField) -> Optional[str]: - return field.notebook_name - - def visit_artifact(self, field: ArtifactField) -> str: - return field.hash - - class TableEntry: def __init__( self, From 67ff0250d8c73d8338821fe21bee936a8ce03012 Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Thu, 4 Apr 2024 14:34:56 +0200 Subject: [PATCH 07/22] Visitor in searching after --- src/neptune/api/field_visitor.py | 2 ++ src/neptune/api/searching_entries.py | 11 ++++------- tests/unit/neptune/new/api/test_searching_entries.py | 1 - 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/neptune/api/field_visitor.py b/src/neptune/api/field_visitor.py index 0f2054405..1a3283a25 100644 --- a/src/neptune/api/field_visitor.py +++ b/src/neptune/api/field_visitor.py @@ -59,6 +59,7 @@ def visit_string(self, field: StringField) -> str: return field.value def visit_datetime(self, field: DatetimeField) -> datetime: + # TODO: Datetime ... def visit_file(self, field: FileField) -> None: @@ -83,6 +84,7 @@ def visit_git_ref(self, field: GitRefField) -> Optional[str]: return field.commit_id def visit_object_state(self, field: ObjectStateField) -> str: + # TODO: Refactor not to use RunState return RunState.from_api(field.value).value def visit_notebook_ref(self, field: NotebookRefField) -> Optional[str]: diff --git a/src/neptune/api/searching_entries.py b/src/neptune/api/searching_entries.py index 3bf730af2..0ec5b4609 100644 --- a/src/neptune/api/searching_entries.py +++ b/src/neptune/api/searching_entries.py @@ -27,6 +27,7 @@ from bravado.client import construct_request # type: ignore from bravado.config import RequestConfig # type: ignore from bravado.exception import HTTPBadRequest # type: ignore +from neptune.api.field_visitor import FieldToValueVisitor from typing_extensions import ( Literal, TypeAlias, @@ -170,7 +171,6 @@ def iter_over_pages( searching_after = None last_page = None - # TODO: Refactor data = get_single_page( limit=0, offset=0, @@ -201,13 +201,10 @@ def iter_over_pages( while True: if last_page: - page_attribute = find_attribute(entry=last_page[-1], path=sort_by) - - if not page_attribute: + searching_after_filed = find_attribute(entry=last_page[-1], path=sort_by) + if not searching_after_filed: raise ValueError(f"Cannot find attribute {sort_by} in last page") - - # TODO: Refactor - searching_after = page_attribute.properties["value"] + searching_after = field_to_value_visitor.visit(searching_after_filed) for offset in range(0, max_offset, step_size): local_limit = min(step_size, max_offset - offset) diff --git a/tests/unit/neptune/new/api/test_searching_entries.py b/tests/unit/neptune/new/api/test_searching_entries.py index 4054337c1..79e26f8c6 100644 --- a/tests/unit/neptune/new/api/test_searching_entries.py +++ b/tests/unit/neptune/new/api/test_searching_entries.py @@ -29,7 +29,6 @@ from neptune.api.searching_entries import ( get_single_page, iter_over_pages, - to_leaderboard_entry, ) from neptune.exceptions import NeptuneInvalidQueryException from neptune.api.models import Field, StringField, FieldType, LeaderboardEntry From 488ceb376b0a08484fb9bd433683ef1f65d2cd46 Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Thu, 4 Apr 2024 18:04:23 +0200 Subject: [PATCH 08/22] e2e fixes --- src/neptune/api/field_visitor.py | 3 +- src/neptune/api/models.py | 53 +++++++++++-------- src/neptune/attributes/file_set.py | 2 +- src/neptune/handler.py | 2 +- .../backends/hosted_neptune_backend.py | 11 ++-- .../internal/backends/neptune_backend.py | 2 +- .../backends/offline_neptune_backend.py | 2 +- src/neptune/objects/neptune_object.py | 3 -- src/neptune/objects/utils.py | 41 -------------- tests/unit/neptune/new/client/test_project.py | 5 +- 10 files changed, 42 insertions(+), 82 deletions(-) diff --git a/src/neptune/api/field_visitor.py b/src/neptune/api/field_visitor.py index 1a3283a25..bbf377a23 100644 --- a/src/neptune/api/field_visitor.py +++ b/src/neptune/api/field_visitor.py @@ -59,8 +59,7 @@ def visit_string(self, field: StringField) -> str: return field.value def visit_datetime(self, field: DatetimeField) -> datetime: - # TODO: Datetime - ... + return field.value def visit_file(self, field: FileField) -> None: raise MetadataInconsistency("Cannot get value for file attribute. Use download() instead.") diff --git a/src/neptune/api/models.py b/src/neptune/api/models.py index ee3c3fc73..ee0719d40 100644 --- a/src/neptune/api/models.py +++ b/src/neptune/api/models.py @@ -55,6 +55,9 @@ List, ) +from neptune.internal.utils.iso_dates import parse_iso_date +from neptune.internal.warnings import warn_once, NeptuneWarning + Ret = TypeVar("Ret") @@ -94,9 +97,8 @@ class Field(abc.ABC): type: FieldType = dataclass_field(init=False, default=None) _registry: ClassVar[Dict[str, Type[Field]]] = {t.value: {} for t in FieldType} - def __init_subclass__(cls, **kwargs) -> None: + def __init_subclass__(cls, field_type: FieldType, **kwargs) -> None: super().__init_subclass__(**kwargs) - field_type: Optional[FieldType] = kwargs.get('type', None) if field_type is not None: cls.type = field_type cls._registry[field_type.value] = cls @@ -107,7 +109,7 @@ def accept(self, visitor: FieldVisitor[Ret]) -> Ret: @staticmethod def from_dict(field: Dict[str, Any]) -> Field: - field_type = field["type"].value + field_type = field["type"] return Field._registry[field_type].from_dict(field[f"{field_type}Properties"]) @@ -163,7 +165,7 @@ def visit_artifact(self, field: ArtifactField) -> Ret: @dataclass -class FloatField(Field, type=FieldType.FLOAT): +class FloatField(Field, field_type=FieldType.FLOAT): value: float def accept(self, visitor: FieldVisitor[Ret]) -> Ret: @@ -176,7 +178,7 @@ def from_dict(data: Dict[str, Any]) -> FloatField: @dataclass -class IntField(Field, type=FieldType.INT): +class IntField(Field, field_type=FieldType.INT): value: int def accept(self, visitor: FieldVisitor[Ret]) -> Ret: @@ -189,7 +191,7 @@ def from_dict(data: Dict[str, Any]) -> IntField: @dataclass -class BoolField(Field, type=FieldType.BOOL): +class BoolField(Field, field_type=FieldType.BOOL): value: bool def accept(self, visitor: FieldVisitor[Ret]) -> Ret: @@ -202,7 +204,7 @@ def from_dict(data: Dict[str, Any]) -> BoolField: @dataclass -class StringField(Field, type=FieldType.STRING): +class StringField(Field, field_type=FieldType.STRING): value: str def accept(self, visitor: FieldVisitor[Ret]) -> Ret: @@ -215,7 +217,7 @@ def from_dict(data: Dict[str, Any]) -> StringField: @dataclass -class DatetimeField(Field, type=FieldType.DATETIME): +class DatetimeField(Field, field_type=FieldType.DATETIME): value: datetime def accept(self, visitor: FieldVisitor[Ret]) -> Ret: @@ -223,12 +225,12 @@ def accept(self, visitor: FieldVisitor[Ret]) -> Ret: @staticmethod def from_dict(data: Dict[str, Any]) -> DatetimeField: - # TODO: parse datetime - return DatetimeField(path=data["attributeName"], value=data["value"]) + # TODO: what if none + return DatetimeField(path=data["attributeName"], value=parse_iso_date(data["value"])) @dataclass -class FileField(Field, type=FieldType.FILE): +class FileField(Field, field_type=FieldType.FILE): name: str ext: str size: int @@ -239,7 +241,7 @@ def accept(self, visitor: FieldVisitor[Ret]) -> Ret: @staticmethod def from_dict(data: Dict[str, Any]) -> FileField: return FileField( - path=data["path"], + path=data["attributeName"], name=data["name"], ext=data["ext"], size=int(data["size"]) @@ -247,7 +249,7 @@ def from_dict(data: Dict[str, Any]) -> FileField: @dataclass -class FileSetField(Field, type=FieldType.FILE_SET): +class FileSetField(Field, field_type=FieldType.FILE_SET): size: int def accept(self, visitor: FieldVisitor[Ret]) -> Ret: @@ -259,7 +261,7 @@ def from_dict(data: Dict[str, Any]) -> FileSetField: @dataclass -class FloatSeriesField(Field, type=FieldType.FLOAT_SERIES): +class FloatSeriesField(Field, field_type=FieldType.FLOAT_SERIES): last: Optional[float] def accept(self, visitor: "FieldVisitor[Ret]") -> Ret: @@ -272,7 +274,7 @@ def from_dict(data: Dict[str, Any]) -> FloatSeriesField: @dataclass -class StringSeriesField(Field, type=FieldType.STRING_SERIES): +class StringSeriesField(Field, field_type=FieldType.STRING_SERIES): last: Optional[str] def accept(self, visitor: "FieldVisitor[Ret]") -> Ret: @@ -285,7 +287,7 @@ def from_dict(data: Dict[str, Any]) -> StringSeriesField: @dataclass -class ImageSeriesField(Field, type=FieldType.IMAGE_SERIES): +class ImageSeriesField(Field, field_type=FieldType.IMAGE_SERIES): last_step: Optional[float] def accept(self, visitor: "FieldVisitor[Ret]") -> Ret: @@ -298,7 +300,7 @@ def from_dict(data: Dict[str, Any]) -> ImageSeriesField: @dataclass -class StringSetField(Field, type=FieldType.STRING_SET): +class StringSetField(Field, field_type=FieldType.STRING_SET): values: Set[str] def accept(self, visitor: "FieldVisitor[Ret]") -> Ret: @@ -310,7 +312,7 @@ def from_dict(data: Dict[str, Any]) -> StringSetField: @dataclass -class GitRefField(Field, type=FieldType.GIT_REF): +class GitRefField(Field, field_type=FieldType.GIT_REF): commit_id: Optional[str] def accept(self, visitor: FieldVisitor[Ret]) -> Ret: @@ -323,7 +325,7 @@ def from_dict(data: Dict[str, Any]) -> GitRefField: @dataclass -class ObjectStateField(Field, type=FieldType.OBJECT_STATE): +class ObjectStateField(Field, field_type=FieldType.OBJECT_STATE): value: str def accept(self, visitor: FieldVisitor[Ret]) -> Ret: @@ -335,7 +337,7 @@ def from_dict(data: Dict[str, Any]) -> ObjectStateField: @dataclass -class NotebookRefField(Field, type=FieldType.NOTEBOOK_REF): +class NotebookRefField(Field, field_type=FieldType.NOTEBOOK_REF): notebook_name: Optional[str] def accept(self, visitor: FieldVisitor[Ret]) -> Ret: @@ -348,12 +350,16 @@ def from_dict(data: Dict[str, Any]) -> NotebookRefField: @dataclass -class ArtifactField(Field, type=FieldType.ARTIFACT): +class ArtifactField(Field, field_type=FieldType.ARTIFACT): hash: str def accept(self, visitor: FieldVisitor[Ret]) -> Ret: return visitor.visit_artifact(self) + @staticmethod + def from_dict(data: Dict[str, Any]) -> ArtifactField: + return ArtifactField(path=data["attributeName"], hash=str(data["hash"])) + @dataclass class LeaderboardEntry: @@ -376,6 +382,7 @@ class LeaderboardEntriesSearchResult: @staticmethod def from_dict(result: Dict[str, Any]) -> LeaderboardEntriesSearchResult: return LeaderboardEntriesSearchResult( + # TODO: Use generator instead entries=[LeaderboardEntry.from_dict(entry) for entry in result["entries"]], matching_item_count=result["matchingItemCount"], ) @@ -385,3 +392,7 @@ def from_dict(result: Dict[str, Any]) -> LeaderboardEntriesSearchResult: class FieldDefinition: path: str type: FieldType + + @staticmethod + def from_dict(data: Dict[str, Any]) -> FieldDefinition: + return FieldDefinition(path=data["name"], type=FieldType(data["type"])) diff --git a/src/neptune/attributes/file_set.py b/src/neptune/attributes/file_set.py index 922522a23..5f32a4705 100644 --- a/src/neptune/attributes/file_set.py +++ b/src/neptune/attributes/file_set.py @@ -23,7 +23,7 @@ Union, ) -from neptune.api.dtos import FileEntry +from neptune.api.models import FileEntry from neptune.attributes.attribute import Attribute from neptune.internal.operation import ( DeleteFiles, diff --git a/src/neptune/handler.py b/src/neptune/handler.py index 4607eba6c..6e5f8d28c 100644 --- a/src/neptune/handler.py +++ b/src/neptune/handler.py @@ -28,7 +28,7 @@ Union, ) -from neptune.api.dtos import FileEntry +from neptune.api.models import FileEntry from neptune.attributes import File from neptune.attributes.atoms.artifact import Artifact from neptune.attributes.constants import SYSTEM_STAGE_ATTRIBUTE_PATH diff --git a/src/neptune/internal/backends/hosted_neptune_backend.py b/src/neptune/internal/backends/hosted_neptune_backend.py index d29799f97..aa3d82e3c 100644 --- a/src/neptune/internal/backends/hosted_neptune_backend.py +++ b/src/neptune/internal/backends/hosted_neptune_backend.py @@ -685,9 +685,6 @@ def _execute_operations( @with_api_exceptions_handler def get_attributes(self, container_id: str, container_type: ContainerType) -> List[FieldDefinition]: - def to_attribute(attr) -> FieldDefinition: - return FieldDefinition(attr.name, FieldType(attr.type)) - params = { "experimentId": container_id, **DEFAULT_REQUEST_KWARGS, @@ -696,10 +693,10 @@ def to_attribute(attr) -> FieldDefinition: experiment = self.leaderboard_client.api.getExperimentAttributes(**params).response().result attribute_type_names = [at.value for at in FieldType] - accepted_attributes = [attr for attr in experiment.fields if attr.type in attribute_type_names] + accepted_attributes = [attr for attr in experiment.attributes if attr.type in attribute_type_names] # Notify about ignored attrs - ignored_attributes = set(attr.type for attr in experiment.fields) - set( + ignored_attributes = set(attr.type for attr in experiment.attributes) - set( attr.type for attr in accepted_attributes ) if ignored_attributes: @@ -708,7 +705,7 @@ def to_attribute(attr) -> FieldDefinition: ignored_attributes, ) - return [to_attribute(attr) for attr in accepted_attributes if attr.type in attribute_type_names] + return [FieldDefinition.from_dict(field) for field in accepted_attributes if field.type in attribute_type_names] except HTTPNotFound as e: raise ContainerUUIDNotFound( container_id=container_id, @@ -1030,7 +1027,7 @@ def fetch_atom_attribute_values( result = self.leaderboard_client.api.getExperimentAttributes(**params).response().result return [ (attr.name, attr.type, map_attribute_result_to_value(attr)) - for attr in result.fields + for attr in result.attributes if attr.name.startswith(namespace_prefix) ] except HTTPNotFound as e: diff --git a/src/neptune/internal/backends/neptune_backend.py b/src/neptune/internal/backends/neptune_backend.py index 4b15e5cd1..fefeb914f 100644 --- a/src/neptune/internal/backends/neptune_backend.py +++ b/src/neptune/internal/backends/neptune_backend.py @@ -25,7 +25,7 @@ Union, ) -from neptune.api.dtos import FileEntry +from neptune.api.models import FileEntry from neptune.core.components.operation_storage import OperationStorage from neptune.internal.artifacts.types import ArtifactFileData from neptune.internal.backends.api_model import ( diff --git a/src/neptune/internal/backends/offline_neptune_backend.py b/src/neptune/internal/backends/offline_neptune_backend.py index 435aa87a1..c3b8fdf25 100644 --- a/src/neptune/internal/backends/offline_neptune_backend.py +++ b/src/neptune/internal/backends/offline_neptune_backend.py @@ -20,7 +20,7 @@ Optional, ) -from neptune.api.dtos import FileEntry +from neptune.api.models import FileEntry from neptune.exceptions import NeptuneOfflineModeFetchException from neptune.internal.artifacts.types import ArtifactFileData from neptune.internal.backends.api_model import ( diff --git a/src/neptune/objects/neptune_object.py b/src/neptune/objects/neptune_object.py index 6f5d05202..c5824f768 100644 --- a/src/neptune/objects/neptune_object.py +++ b/src/neptune/objects/neptune_object.py @@ -97,7 +97,6 @@ AbstractNeptuneObject, NeptuneObjectCallback, ) -from neptune.objects.utils import parse_dates from neptune.table import Table from neptune.types.mode import Mode from neptune.types.type_casting import cast_value @@ -685,8 +684,6 @@ def _fetch_entries( progress_bar=progress_bar, ) - leaderboard_entries = parse_dates(leaderboard_entries) - return Table( backend=self._backend, container_type=child_type, diff --git a/src/neptune/objects/utils.py b/src/neptune/objects/utils.py index 2dd033a35..73763e746 100644 --- a/src/neptune/objects/utils.py +++ b/src/neptune/objects/utils.py @@ -15,19 +15,16 @@ # __all__ = [ - "parse_dates", "prepare_nql_query", ] from typing import ( - Generator, Iterable, List, Optional, Union, ) -from neptune.api.models import Field, FieldType, LeaderboardEntry from neptune.internal.backends.nql import ( NQLAggregator, NQLAttributeOperator, @@ -37,12 +34,7 @@ NQLQueryAttribute, RawNQLQuery, ) -from neptune.internal.utils.iso_dates import parse_iso_date from neptune.internal.utils.run_state import RunState -from neptune.internal.warnings import ( - NeptuneWarning, - warn_once, -) def prepare_nql_query( @@ -132,39 +124,6 @@ def prepare_nql_query( return query -def parse_dates(leaderboard_entries: Iterable[LeaderboardEntry]) -> Generator[LeaderboardEntry, None, None]: - yield from [_parse_entry(entry) for entry in leaderboard_entries] - - -def _parse_entry(entry: LeaderboardEntry) -> LeaderboardEntry: - try: - return LeaderboardEntry( - entry.object_id, - fields=[ - ( - Field( - attribute.path, - attribute.type, - { - **attribute.properties, - "value": parse_iso_date(attribute.properties["value"]), - }, - ) - if attribute.type == FieldType.DATETIME - else attribute - ) - for attribute in entry.fields - ], - ) - except ValueError: - # the parsing format is incorrect - warn_once( - "Date parsing failed. The date format is incorrect. Returning as string instead of datetime.", - exception=NeptuneWarning, - ) - return entry - - def build_raw_query(query: str, trashed: Optional[bool]) -> NQLQuery: raw_nql = RawNQLQuery(query) diff --git a/tests/unit/neptune/new/client/test_project.py b/tests/unit/neptune/new/client/test_project.py index 800216dd3..c4d3a7bd3 100644 --- a/tests/unit/neptune/new/client/test_project.py +++ b/tests/unit/neptune/new/client/test_project.py @@ -39,10 +39,7 @@ NeptuneWarning, warned_once, ) -from neptune.objects.utils import ( - parse_dates, - prepare_nql_query, -) +from neptune.objects.utils import prepare_nql_query from tests.unit.neptune.new.client.abstract_experiment_test_mixin import AbstractExperimentTestMixin From d8e137f599e093d44753a6040d9ea9bb1ba42eb1 Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Fri, 5 Apr 2024 08:41:25 +0200 Subject: [PATCH 09/22] RunState mapping --- src/neptune/api/field_visitor.py | 4 +--- src/neptune/api/models.py | 5 +++-- src/neptune/integrations/pandas/__init__.py | 3 +-- src/neptune/internal/backends/hosted_neptune_backend.py | 2 +- 4 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/neptune/api/field_visitor.py b/src/neptune/api/field_visitor.py index bbf377a23..6442f958c 100644 --- a/src/neptune/api/field_visitor.py +++ b/src/neptune/api/field_visitor.py @@ -41,7 +41,6 @@ NotebookRefField, ArtifactField ) -from neptune.internal.utils.run_state import RunState class FieldToValueVisitor(FieldVisitor[Any]): @@ -83,8 +82,7 @@ def visit_git_ref(self, field: GitRefField) -> Optional[str]: return field.commit_id def visit_object_state(self, field: ObjectStateField) -> str: - # TODO: Refactor not to use RunState - return RunState.from_api(field.value).value + return field.value def visit_notebook_ref(self, field: NotebookRefField) -> Optional[str]: return field.notebook_name diff --git a/src/neptune/api/models.py b/src/neptune/api/models.py index ee0719d40..6840f568d 100644 --- a/src/neptune/api/models.py +++ b/src/neptune/api/models.py @@ -56,7 +56,7 @@ ) from neptune.internal.utils.iso_dates import parse_iso_date -from neptune.internal.warnings import warn_once, NeptuneWarning +from neptune.internal.utils.run_state import RunState Ret = TypeVar("Ret") @@ -333,7 +333,8 @@ def accept(self, visitor: FieldVisitor[Ret]) -> Ret: @staticmethod def from_dict(data: Dict[str, Any]) -> ObjectStateField: - return ObjectStateField(path=data["attributeName"], value=str(data["value"])) + value = RunState.from_api(str(data["value"])).value + return ObjectStateField(path=data["attributeName"], value=value) @dataclass diff --git a/src/neptune/integrations/pandas/__init__.py b/src/neptune/integrations/pandas/__init__.py index f0ff57014..011a6d568 100644 --- a/src/neptune/integrations/pandas/__init__.py +++ b/src/neptune/integrations/pandas/__init__.py @@ -47,7 +47,6 @@ StringSetField, ObjectStateField, ) -from neptune.internal.utils.run_state import RunState if TYPE_CHECKING: from neptune.table import Table @@ -94,7 +93,7 @@ def visit_git_ref(self, field: GitRefField) -> str: return field.commit_id def visit_object_state(self, field: ObjectStateField) -> str: - return RunState.from_api(field.value).value + return field.value def visit_notebook_ref(self, field: NotebookRefField) -> str: return field.notebook_name diff --git a/src/neptune/internal/backends/hosted_neptune_backend.py b/src/neptune/internal/backends/hosted_neptune_backend.py index aa3d82e3c..2431bfc20 100644 --- a/src/neptune/internal/backends/hosted_neptune_backend.py +++ b/src/neptune/internal/backends/hosted_neptune_backend.py @@ -1081,7 +1081,7 @@ def search_leaderboard_entries( if sort_by == "sys/creation_time": sort_by_column_type = FieldType.DATETIME.value - if sort_by == "sys/id": + elif sort_by == "sys/id": sort_by_column_type = FieldType.STRING.value else: sort_by_column_type_candidates = self._get_column_types(project_id, sort_by, types_filter) From 1a055a9eb591ec9901d0dc15766b1d82b821e52b Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Fri, 5 Apr 2024 09:11:08 +0200 Subject: [PATCH 10/22] Pre-commit suggestions --- src/neptune/api/field_visitor.py | 24 ++--- src/neptune/api/models.py | 100 ++++++++---------- src/neptune/api/searching_entries.py | 9 +- src/neptune/attributes/utils.py | 2 +- src/neptune/integrations/pandas/__init__.py | 26 ++--- .../backends/hosted_artifact_operations.py | 6 +- .../backends/hosted_neptune_backend.py | 40 +++---- .../internal/backends/neptune_backend.py | 36 +++---- .../internal/backends/neptune_backend_mock.py | 36 +++---- .../backends/offline_neptune_backend.py | 32 +++--- src/neptune/objects/neptune_object.py | 2 +- src/neptune/table.py | 4 +- .../neptune/new/api/test_searching_entries.py | 7 +- .../new/attributes/test_attribute_utils.py | 5 +- .../new/client/abstract_tables_test.py | 7 +- tests/unit/neptune/new/client/test_model.py | 6 +- .../neptune/new/client/test_model_version.py | 7 +- tests/unit/neptune/new/client/test_project.py | 8 +- tests/unit/neptune/new/client/test_run.py | 6 +- .../neptune/new/client/test_run_tables.py | 6 +- .../backends/test_neptune_backend_mock.py | 10 +- 21 files changed, 204 insertions(+), 175 deletions(-) diff --git a/src/neptune/api/field_visitor.py b/src/neptune/api/field_visitor.py index 6442f958c..5c22bac31 100644 --- a/src/neptune/api/field_visitor.py +++ b/src/neptune/api/field_visitor.py @@ -13,34 +13,34 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__all__ = ('FieldToValueVisitor',) +__all__ = ("FieldToValueVisitor",) +from datetime import datetime from typing import ( Any, Optional, Set, ) -from datetime import datetime -from neptune.exceptions import MetadataInconsistency from neptune.api.models import ( - FieldVisitor, - FloatField, - IntField, + ArtifactField, BoolField, - StringField, DatetimeField, + FieldVisitor, FileField, FileSetField, + FloatField, FloatSeriesField, - StringSeriesField, - ImageSeriesField, - StringSetField, GitRefField, - ObjectStateField, + ImageSeriesField, + IntField, NotebookRefField, - ArtifactField + ObjectStateField, + StringField, + StringSeriesField, + StringSetField, ) +from neptune.exceptions import MetadataInconsistency class FieldToValueVisitor(FieldVisitor[Any]): diff --git a/src/neptune/api/models.py b/src/neptune/api/models.py index 6840f568d..2b95ab978 100644 --- a/src/neptune/api/models.py +++ b/src/neptune/api/models.py @@ -41,18 +41,20 @@ ) import abc -from dataclasses import ( - dataclass, - field as dataclass_field, -) -from typing import TypeVar, Generic, Dict, Type, ClassVar +from dataclasses import dataclass +from dataclasses import field as dataclass_field from datetime import datetime from enum import Enum from typing import ( Any, + ClassVar, + Dict, + Generic, + List, Optional, Set, - List, + Type, + TypeVar, ) from neptune.internal.utils.iso_dates import parse_iso_date @@ -94,18 +96,16 @@ class FieldType(Enum): @dataclass class Field(abc.ABC): path: str - type: FieldType = dataclass_field(init=False, default=None) - _registry: ClassVar[Dict[str, Type[Field]]] = {t.value: {} for t in FieldType} + type: FieldType = dataclass_field(init=False) + _registry: ClassVar[Dict[str, Type[Field]]] = {} - def __init_subclass__(cls, field_type: FieldType, **kwargs) -> None: - super().__init_subclass__(**kwargs) - if field_type is not None: - cls.type = field_type - cls._registry[field_type.value] = cls + def __init_subclass__(cls, *args: Any, field_type: FieldType, **kwargs: Any) -> None: + super().__init_subclass__(*args, **kwargs) + cls.type = field_type + cls._registry[field_type.value] = cls @abc.abstractmethod - def accept(self, visitor: FieldVisitor[Ret]) -> Ret: - ... + def accept(self, visitor: FieldVisitor[Ret]) -> Ret: ... @staticmethod def from_dict(field: Dict[str, Any]) -> Field: @@ -118,50 +118,50 @@ class FieldVisitor(Generic[Ret], abc.ABC): def visit(self, field: Field) -> Ret: return field.accept(self) - def visit_float(self, field: FloatField) -> Ret: - ... + @abc.abstractmethod + def visit_float(self, field: FloatField) -> Ret: ... - def visit_int(self, field: IntField) -> Ret: - ... + @abc.abstractmethod + def visit_int(self, field: IntField) -> Ret: ... - def visit_bool(self, field: BoolField) -> Ret: - ... + @abc.abstractmethod + def visit_bool(self, field: BoolField) -> Ret: ... - def visit_string(self, field: StringField) -> Ret: - ... + @abc.abstractmethod + def visit_string(self, field: StringField) -> Ret: ... - def visit_datetime(self, field: DatetimeField) -> Ret: - ... + @abc.abstractmethod + def visit_datetime(self, field: DatetimeField) -> Ret: ... - def visit_file(self, field: FileField) -> Ret: - ... + @abc.abstractmethod + def visit_file(self, field: FileField) -> Ret: ... - def visit_file_set(self, field: FileSetField) -> Ret: - ... + @abc.abstractmethod + def visit_file_set(self, field: FileSetField) -> Ret: ... - def visit_float_series(self, field: FloatSeriesField) -> Ret: - ... + @abc.abstractmethod + def visit_float_series(self, field: FloatSeriesField) -> Ret: ... - def visit_string_series(self, field: StringSeriesField) -> Ret: - ... + @abc.abstractmethod + def visit_string_series(self, field: StringSeriesField) -> Ret: ... - def visit_image_series(self, field: ImageSeriesField) -> Ret: - ... + @abc.abstractmethod + def visit_image_series(self, field: ImageSeriesField) -> Ret: ... - def visit_string_set(self, field: StringSetField) -> Ret: - ... + @abc.abstractmethod + def visit_string_set(self, field: StringSetField) -> Ret: ... - def visit_git_ref(self, field: GitRefField) -> Ret: - ... + @abc.abstractmethod + def visit_git_ref(self, field: GitRefField) -> Ret: ... - def visit_object_state(self, field: ObjectStateField) -> Ret: - ... + @abc.abstractmethod + def visit_object_state(self, field: ObjectStateField) -> Ret: ... - def visit_notebook_ref(self, field: NotebookRefField) -> Ret: - ... + @abc.abstractmethod + def visit_notebook_ref(self, field: NotebookRefField) -> Ret: ... - def visit_artifact(self, field: ArtifactField) -> Ret: - ... + @abc.abstractmethod + def visit_artifact(self, field: ArtifactField) -> Ret: ... @dataclass @@ -240,12 +240,7 @@ def accept(self, visitor: FieldVisitor[Ret]) -> Ret: @staticmethod def from_dict(data: Dict[str, Any]) -> FileField: - return FileField( - path=data["attributeName"], - name=data["name"], - ext=data["ext"], - size=int(data["size"]) - ) + return FileField(path=data["attributeName"], name=data["name"], ext=data["ext"], size=int(data["size"])) @dataclass @@ -370,8 +365,7 @@ class LeaderboardEntry: @staticmethod def from_dict(data: Dict[str, Any]) -> LeaderboardEntry: return LeaderboardEntry( - object_id=data["experimentId"], - fields=[Field.from_dict(field) for field in data["attributes"]] + object_id=data["experimentId"], fields=[Field.from_dict(field) for field in data["attributes"]] ) diff --git a/src/neptune/api/searching_entries.py b/src/neptune/api/searching_entries.py index 0ec5b4609..d2eb222e0 100644 --- a/src/neptune/api/searching_entries.py +++ b/src/neptune/api/searching_entries.py @@ -27,14 +27,19 @@ from bravado.client import construct_request # type: ignore from bravado.config import RequestConfig # type: ignore from bravado.exception import HTTPBadRequest # type: ignore -from neptune.api.field_visitor import FieldToValueVisitor from typing_extensions import ( Literal, TypeAlias, ) +from neptune.api.field_visitor import FieldToValueVisitor +from neptune.api.models import ( + Field, + FieldType, + LeaderboardEntriesSearchResult, + LeaderboardEntry, +) from neptune.exceptions import NeptuneInvalidQueryException -from neptune.api.models import Field, FieldType, LeaderboardEntry, LeaderboardEntriesSearchResult from neptune.internal.backends.hosted_client import DEFAULT_REQUEST_KWARGS from neptune.internal.backends.nql import ( NQLAggregator, diff --git a/src/neptune/attributes/utils.py b/src/neptune/attributes/utils.py index f7bb02b6f..18d83410b 100644 --- a/src/neptune/attributes/utils.py +++ b/src/neptune/attributes/utils.py @@ -20,6 +20,7 @@ List, ) +from neptune.api.models import FieldType from neptune.attributes import ( Artifact, Boolean, @@ -37,7 +38,6 @@ StringSeries, StringSet, ) -from neptune.api.models import FieldType from neptune.internal.exceptions import InternalClientError if TYPE_CHECKING: diff --git a/src/neptune/integrations/pandas/__init__.py b/src/neptune/integrations/pandas/__init__.py index 011a6d568..f7260a912 100644 --- a/src/neptune/integrations/pandas/__init__.py +++ b/src/neptune/integrations/pandas/__init__.py @@ -21,31 +21,31 @@ from typing import ( TYPE_CHECKING, Dict, + Optional, Tuple, Union, - Optional, ) import pandas as pd from neptune.api.models import ( - LeaderboardEntry, - FieldVisitor, - FloatField, - IntField, + ArtifactField, BoolField, - StringField, DatetimeField, - FloatSeriesField, - StringSeriesField, - ImageSeriesField, + FieldVisitor, FileField, FileSetField, + FloatField, + FloatSeriesField, GitRefField, + ImageSeriesField, + IntField, + LeaderboardEntry, NotebookRefField, - ArtifactField, - StringSetField, ObjectStateField, + StringField, + StringSeriesField, + StringSetField, ) if TYPE_CHECKING: @@ -89,13 +89,13 @@ def visit_image_series(self, field: ImageSeriesField) -> None: def visit_file_set(self, field: FileSetField) -> None: return None - def visit_git_ref(self, field: GitRefField) -> str: + def visit_git_ref(self, field: GitRefField) -> Optional[str]: return field.commit_id def visit_object_state(self, field: ObjectStateField) -> str: return field.value - def visit_notebook_ref(self, field: NotebookRefField) -> str: + def visit_notebook_ref(self, field: NotebookRefField) -> Optional[str]: return field.notebook_name def visit_artifact(self, field: ArtifactField) -> str: diff --git a/src/neptune/internal/backends/hosted_artifact_operations.py b/src/neptune/internal/backends/hosted_artifact_operations.py index f6dda166a..672d102fb 100644 --- a/src/neptune/internal/backends/hosted_artifact_operations.py +++ b/src/neptune/internal/backends/hosted_artifact_operations.py @@ -30,6 +30,7 @@ from bravado.exception import HTTPNotFound +from neptune.api.models import ArtifactField from neptune.exceptions import ( ArtifactNotFoundException, ArtifactUploadingError, @@ -42,10 +43,7 @@ ArtifactDriversMap, ArtifactFileData, ) -from neptune.internal.backends.api_model import ( - ArtifactModel, -) -from neptune.api.models import ArtifactField +from neptune.internal.backends.api_model import ArtifactModel from neptune.internal.backends.swagger_client_wrapper import SwaggerClientWrapper from neptune.internal.backends.utils import with_api_exceptions_handler from neptune.internal.operation import ( diff --git a/src/neptune/internal/backends/hosted_neptune_backend.py b/src/neptune/internal/backends/hosted_neptune_backend.py index 2431bfc20..2d4499f3a 100644 --- a/src/neptune/internal/backends/hosted_neptune_backend.py +++ b/src/neptune/internal/backends/hosted_neptune_backend.py @@ -38,6 +38,22 @@ HTTPUnprocessableEntity, ) +from neptune.api.models import ( + ArtifactField, + BoolField, + DatetimeField, + FieldDefinition, + FieldType, + FileEntry, + FileField, + FloatField, + FloatSeriesField, + IntField, + LeaderboardEntry, + StringField, + StringSeriesField, + StringSetField, +) from neptune.api.searching_entries import iter_over_pages from neptune.core.components.operation_storage import OperationStorage from neptune.envs import NEPTUNE_FETCH_TABLE_STEP_SIZE @@ -66,22 +82,6 @@ StringSeriesValues, Workspace, ) -from neptune.api.models import ( - FloatField, - IntField, - BoolField, - FileField, - StringField, - DatetimeField, - ArtifactField, - FloatSeriesField, - StringSeriesField, - StringSetField, - FieldType, - FieldDefinition, - LeaderboardEntry, - FileEntry, -) from neptune.internal.backends.hosted_artifact_operations import ( get_artifact_attribute, list_artifact_files, @@ -705,7 +705,9 @@ def get_attributes(self, container_id: str, container_type: ContainerType) -> Li ignored_attributes, ) - return [FieldDefinition.from_dict(field) for field in accepted_attributes if field.type in attribute_type_names] + return [ + FieldDefinition.from_dict(field) for field in accepted_attributes if field.type in attribute_type_names + ] except HTTPNotFound as e: raise ContainerUUIDNotFound( container_id=container_id, @@ -833,9 +835,7 @@ def get_file_attribute(self, container_id: str, container_type: ContainerType, p raise FetchAttributeNotFoundException(path_to_str(path)) @with_api_exceptions_handler - def get_string_attribute( - self, container_id: str, container_type: ContainerType, path: List[str] - ) -> StringField: + def get_string_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> StringField: params = { "experimentId": container_id, "attribute": path_to_str(path), diff --git a/src/neptune/internal/backends/neptune_backend.py b/src/neptune/internal/backends/neptune_backend.py index fefeb914f..006a2ab75 100644 --- a/src/neptune/internal/backends/neptune_backend.py +++ b/src/neptune/internal/backends/neptune_backend.py @@ -25,7 +25,22 @@ Union, ) -from neptune.api.models import FileEntry +from neptune.api.models import ( + ArtifactField, + BoolField, + DatetimeField, + FieldDefinition, + FieldType, + FileEntry, + FileField, + FloatField, + FloatSeriesField, + IntField, + LeaderboardEntry, + StringField, + StringSeriesField, + StringSetField, +) from neptune.core.components.operation_storage import OperationStorage from neptune.internal.artifacts.types import ArtifactFileData from neptune.internal.backends.api_model import ( @@ -36,21 +51,6 @@ StringSeriesValues, Workspace, ) -from neptune.api.models import ( - FloatField, - IntField, - BoolField, - FileField, - StringField, - DatetimeField, - ArtifactField, - FloatSeriesField, - StringSeriesField, - StringSetField, - FieldType, - FieldDefinition, - LeaderboardEntry, -) from neptune.internal.backends.nql import NQLQuery from neptune.internal.container_type import ContainerType from neptune.internal.exceptions import NeptuneException @@ -190,9 +190,7 @@ def get_file_attribute(self, container_id: str, container_type: ContainerType, p pass @abc.abstractmethod - def get_string_attribute( - self, container_id: str, container_type: ContainerType, path: List[str] - ) -> StringField: + def get_string_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> StringField: pass @abc.abstractmethod diff --git a/src/neptune/internal/backends/neptune_backend_mock.py b/src/neptune/internal/backends/neptune_backend_mock.py index f3b0b0729..6e40c0894 100644 --- a/src/neptune/internal/backends/neptune_backend_mock.py +++ b/src/neptune/internal/backends/neptune_backend_mock.py @@ -34,7 +34,22 @@ ) from zipfile import ZipFile -from neptune.api.models import FileEntry +from neptune.api.models import ( + ArtifactField, + BoolField, + DatetimeField, + FieldDefinition, + FieldType, + FileEntry, + FileField, + FloatField, + FloatSeriesField, + IntField, + LeaderboardEntry, + StringField, + StringSeriesField, + StringSetField, +) from neptune.core.components.operation_storage import OperationStorage from neptune.exceptions import ( ContainerUUIDNotFound, @@ -54,21 +69,6 @@ StringSeriesValues, Workspace, ) -from neptune.api.models import ( - FloatField, - IntField, - BoolField, - FileField, - StringField, - DatetimeField, - ArtifactField, - FloatSeriesField, - StringSeriesField, - StringSetField, - FieldType, - FieldDefinition, - LeaderboardEntry, -) from neptune.internal.backends.hosted_file_operations import get_unique_upload_entries from neptune.internal.backends.neptune_backend import NeptuneBackend from neptune.internal.backends.nql import NQLQuery @@ -393,9 +393,7 @@ def get_file_attribute(self, container_id: str, container_type: ContainerType, p size=0, ) - def get_string_attribute( - self, container_id: str, container_type: ContainerType, path: List[str] - ) -> StringField: + def get_string_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> StringField: val = self._get_attribute(container_id, container_type, path, String) return StringField(path=path_to_str(path), value=val.value) diff --git a/src/neptune/internal/backends/offline_neptune_backend.py b/src/neptune/internal/backends/offline_neptune_backend.py index c3b8fdf25..aee0c7c83 100644 --- a/src/neptune/internal/backends/offline_neptune_backend.py +++ b/src/neptune/internal/backends/offline_neptune_backend.py @@ -20,26 +20,26 @@ Optional, ) -from neptune.api.models import FileEntry -from neptune.exceptions import NeptuneOfflineModeFetchException -from neptune.internal.artifacts.types import ArtifactFileData -from neptune.internal.backends.api_model import ( - FloatSeriesValues, - ImageSeriesValues, - StringSeriesValues, -) from neptune.api.models import ( - FloatField, - IntField, + ArtifactField, BoolField, - FileField, - StringField, DatetimeField, - ArtifactField, + FieldDefinition, + FileEntry, + FileField, + FloatField, FloatSeriesField, + IntField, + StringField, StringSeriesField, StringSetField, - FieldDefinition, +) +from neptune.exceptions import NeptuneOfflineModeFetchException +from neptune.internal.artifacts.types import ArtifactFileData +from neptune.internal.backends.api_model import ( + FloatSeriesValues, + ImageSeriesValues, + StringSeriesValues, ) from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock from neptune.internal.container_type import ContainerType @@ -64,9 +64,7 @@ def get_bool_attribute(self, container_id: str, container_type: ContainerType, p def get_file_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FileField: raise NeptuneOfflineModeFetchException - def get_string_attribute( - self, container_id: str, container_type: ContainerType, path: List[str] - ) -> StringField: + def get_string_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> StringField: raise NeptuneOfflineModeFetchException def get_datetime_attribute( diff --git a/src/neptune/objects/neptune_object.py b/src/neptune/objects/neptune_object.py index c5824f768..d9459ef06 100644 --- a/src/neptune/objects/neptune_object.py +++ b/src/neptune/objects/neptune_object.py @@ -39,6 +39,7 @@ Union, ) +from neptune.api.models import FieldType from neptune.attributes import create_attribute_from_type from neptune.attributes.attribute import Attribute from neptune.attributes.namespace import Namespace as NamespaceAttr @@ -53,7 +54,6 @@ ApiExperiment, Project, ) -from neptune.api.models import FieldType from neptune.internal.backends.factory import get_backend from neptune.internal.backends.neptune_backend import NeptuneBackend from neptune.internal.backends.nql import NQLQuery diff --git a/src/neptune/table.py b/src/neptune/table.py index 43f0030f8..ac4a75c2e 100644 --- a/src/neptune/table.py +++ b/src/neptune/table.py @@ -24,13 +24,13 @@ ) from neptune.api.field_visitor import FieldToValueVisitor -from neptune.exceptions import MetadataInconsistency -from neptune.integrations.pandas import to_pandas from neptune.api.models import ( Field, FieldType, LeaderboardEntry, ) +from neptune.exceptions import MetadataInconsistency +from neptune.integrations.pandas import to_pandas from neptune.internal.backends.neptune_backend import NeptuneBackend from neptune.internal.container_type import ContainerType from neptune.internal.utils.logger import get_logger diff --git a/tests/unit/neptune/new/api/test_searching_entries.py b/tests/unit/neptune/new/api/test_searching_entries.py index 79e26f8c6..50924bb1a 100644 --- a/tests/unit/neptune/new/api/test_searching_entries.py +++ b/tests/unit/neptune/new/api/test_searching_entries.py @@ -26,12 +26,17 @@ patch, ) +from neptune.api.models import ( + Field, + FieldType, + LeaderboardEntry, + StringField, +) from neptune.api.searching_entries import ( get_single_page, iter_over_pages, ) from neptune.exceptions import NeptuneInvalidQueryException -from neptune.api.models import Field, StringField, FieldType, LeaderboardEntry def test__to_leaderboard_entry(): diff --git a/tests/unit/neptune/new/attributes/test_attribute_utils.py b/tests/unit/neptune/new/attributes/test_attribute_utils.py index 3e0456d32..5839101e2 100644 --- a/tests/unit/neptune/new/attributes/test_attribute_utils.py +++ b/tests/unit/neptune/new/attributes/test_attribute_utils.py @@ -16,9 +16,9 @@ import unittest from unittest.mock import MagicMock +from neptune.api.models import FieldType from neptune.attributes import create_attribute_from_type from neptune.attributes.attribute import Attribute -from neptune.api.models import FieldType class TestAttributeUtils(unittest.TestCase): @@ -27,7 +27,6 @@ def test_attribute_type_to_atom(self): # ... and this reflection is class based on `Attribute` self.assertTrue( all( - isinstance(create_attribute_from_type(attr_type, MagicMock(), ""), Attribute) - for attr_type in FieldType + isinstance(create_attribute_from_type(attr_type, MagicMock(), ""), Attribute) for attr_type in FieldType ) ) diff --git a/tests/unit/neptune/new/client/abstract_tables_test.py b/tests/unit/neptune/new/client/abstract_tables_test.py index 1c3b201a1..c8b3a7bfd 100644 --- a/tests/unit/neptune/new/client/abstract_tables_test.py +++ b/tests/unit/neptune/new/client/abstract_tables_test.py @@ -23,12 +23,17 @@ from mock import patch from neptune import ANONYMOUS_API_TOKEN +from neptune.api.models import ( + Field, + FieldDefinition, + FieldType, + LeaderboardEntry, +) from neptune.envs import ( API_TOKEN_ENV_NAME, PROJECT_ENV_NAME, ) from neptune.exceptions import MetadataInconsistency -from neptune.api.models import Field, FieldType, FieldDefinition, LeaderboardEntry from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock from neptune.table import ( Table, diff --git a/tests/unit/neptune/new/client/test_model.py b/tests/unit/neptune/new/client/test_model.py index 213a1e30b..0041eb1a6 100644 --- a/tests/unit/neptune/new/client/test_model.py +++ b/tests/unit/neptune/new/client/test_model.py @@ -23,6 +23,11 @@ ANONYMOUS_API_TOKEN, init_model, ) +from neptune.api.models import ( + FieldDefinition, + FieldType, + IntField, +) from neptune.attributes import String from neptune.envs import ( API_TOKEN_ENV_NAME, @@ -32,7 +37,6 @@ NeptuneUnsupportedFunctionalityException, NeptuneWrongInitParametersException, ) -from neptune.api.models import IntField, FieldType, FieldDefinition from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock from neptune.internal.exceptions import NeptuneException from neptune.internal.warnings import ( diff --git a/tests/unit/neptune/new/client/test_model_version.py b/tests/unit/neptune/new/client/test_model_version.py index a37ba7a03..0ce87d90e 100644 --- a/tests/unit/neptune/new/client/test_model_version.py +++ b/tests/unit/neptune/new/client/test_model_version.py @@ -23,6 +23,12 @@ ANONYMOUS_API_TOKEN, init_model_version, ) +from neptune.api.models import ( + FieldDefinition, + FieldType, + IntField, + StringField, +) from neptune.attributes import String from neptune.envs import ( API_TOKEN_ENV_NAME, @@ -33,7 +39,6 @@ NeptuneUnsupportedFunctionalityException, NeptuneWrongInitParametersException, ) -from neptune.api.models import IntField, StringField, FieldType, FieldDefinition from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock from neptune.internal.container_type import ContainerType from neptune.internal.exceptions import NeptuneException diff --git a/tests/unit/neptune/new/client/test_project.py b/tests/unit/neptune/new/client/test_project.py index c4d3a7bd3..d7ada9078 100644 --- a/tests/unit/neptune/new/client/test_project.py +++ b/tests/unit/neptune/new/client/test_project.py @@ -24,6 +24,13 @@ ANONYMOUS_API_TOKEN, init_project, ) +from neptune.api.models import ( + Field, + FieldDefinition, + FieldType, + IntField, + LeaderboardEntry, +) from neptune.envs import ( API_TOKEN_ENV_NAME, PROJECT_ENV_NAME, @@ -32,7 +39,6 @@ NeptuneMissingProjectNameException, NeptuneUnsupportedFunctionalityException, ) -from neptune.api.models import Field, IntField, FieldType, FieldDefinition, LeaderboardEntry from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock from neptune.internal.exceptions import NeptuneException from neptune.internal.warnings import ( diff --git a/tests/unit/neptune/new/client/test_run.py b/tests/unit/neptune/new/client/test_run.py index cdb6aa198..e36c4eb0c 100644 --- a/tests/unit/neptune/new/client/test_run.py +++ b/tests/unit/neptune/new/client/test_run.py @@ -27,13 +27,17 @@ ANONYMOUS_API_TOKEN, init_run, ) +from neptune.api.models import ( + FieldDefinition, + FieldType, + IntField, +) from neptune.attributes.atoms import String from neptune.envs import ( API_TOKEN_ENV_NAME, PROJECT_ENV_NAME, ) from neptune.exceptions import MissingFieldException -from neptune.api.models import IntField, FieldType, FieldDefinition from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock from neptune.internal.utils.utils import IS_WINDOWS from neptune.internal.warnings import ( diff --git a/tests/unit/neptune/new/client/test_run_tables.py b/tests/unit/neptune/new/client/test_run_tables.py index b53616ceb..ce7adaf9f 100644 --- a/tests/unit/neptune/new/client/test_run_tables.py +++ b/tests/unit/neptune/new/client/test_run_tables.py @@ -21,7 +21,11 @@ from mock import patch from neptune import init_project -from neptune.api.models import Field, FieldType, LeaderboardEntry +from neptune.api.models import ( + Field, + FieldType, + LeaderboardEntry, +) from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock from neptune.internal.container_type import ContainerType from neptune.table import ( diff --git a/tests/unit/neptune/new/internal/backends/test_neptune_backend_mock.py b/tests/unit/neptune/new/internal/backends/test_neptune_backend_mock.py index eca5d3343..febea363e 100644 --- a/tests/unit/neptune/new/internal/backends/test_neptune_backend_mock.py +++ b/tests/unit/neptune/new/internal/backends/test_neptune_backend_mock.py @@ -21,6 +21,14 @@ from pathlib import Path from time import time +from neptune.api.models import ( + DatetimeField, + FloatField, + FloatSeriesField, + StringField, + StringSeriesField, + StringSetField, +) from neptune.core.components.operation_storage import OperationStorage from neptune.exceptions import ( ContainerUUIDNotFound, @@ -32,8 +40,6 @@ StringPointValue, StringSeriesValues, ) -from neptune.api.models import FloatField, StringField, DatetimeField, FloatSeriesField, StringSeriesField, \ - StringSetField from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock from neptune.internal.container_type import ContainerType from neptune.internal.operation import ( From 7ea78829d21a624d415371a3a1bd190abea5166a Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Fri, 5 Apr 2024 12:50:18 +0200 Subject: [PATCH 11/22] Fixes and Bravado model deserialization --- src/neptune/api/field_visitor.py | 2 +- src/neptune/api/models.py | 127 ++++++++++++++++-- src/neptune/integrations/pandas/__init__.py | 2 +- .../backends/hosted_neptune_backend.py | 22 +-- 4 files changed, 130 insertions(+), 23 deletions(-) diff --git a/src/neptune/api/field_visitor.py b/src/neptune/api/field_visitor.py index 5c22bac31..a39571783 100644 --- a/src/neptune/api/field_visitor.py +++ b/src/neptune/api/field_visitor.py @@ -79,7 +79,7 @@ def visit_string_set(self, field: StringSetField) -> Set[str]: return field.values def visit_git_ref(self, field: GitRefField) -> Optional[str]: - return field.commit_id + return field.commit.commit_id if field.commit is not None else None def visit_object_state(self, field: ObjectStateField) -> str: return field.value diff --git a/src/neptune/api/models.py b/src/neptune/api/models.py index 2b95ab978..0ce8b29a3 100644 --- a/src/neptune/api/models.py +++ b/src/neptune/api/models.py @@ -19,6 +19,7 @@ "FileEntry", "Field", "FieldType", + "GitCommit", "LeaderboardEntry", "LeaderboardEntriesSearchResult", "FieldVisitor", @@ -108,9 +109,14 @@ def __init_subclass__(cls, *args: Any, field_type: FieldType, **kwargs: Any) -> def accept(self, visitor: FieldVisitor[Ret]) -> Ret: ... @staticmethod - def from_dict(field: Dict[str, Any]) -> Field: - field_type = field["type"] - return Field._registry[field_type].from_dict(field[f"{field_type}Properties"]) + def from_dict(data: Dict[str, Any]) -> Field: + field_type = data["type"] + return Field._registry[field_type].from_dict(data[f"{field_type}Properties"]) + + @staticmethod + def from_model(model: Any) -> Field: + field_type = str(model.type) + return Field._registry[field_type].from_model(model.__getattribute__(f"{field_type}_properties")) class FieldVisitor(Generic[Ret], abc.ABC): @@ -176,6 +182,10 @@ def from_dict(data: Dict[str, Any]) -> FloatField: # TODO: Map only if not null return FloatField(path=data["attributeName"], value=float(data["value"])) + @staticmethod + def from_model(model: Any) -> FloatField: + return FloatField(path=model.attributeName, value=model.value) + @dataclass class IntField(Field, field_type=FieldType.INT): @@ -189,6 +199,10 @@ def from_dict(data: Dict[str, Any]) -> IntField: # TODO: Map only if not null return IntField(path=data["attributeName"], value=int(data["value"])) + @staticmethod + def from_model(model: Any) -> IntField: + return IntField(path=model.attributeName, value=model.value) + @dataclass class BoolField(Field, field_type=FieldType.BOOL): @@ -202,6 +216,10 @@ def from_dict(data: Dict[str, Any]) -> BoolField: # TODO: Map only if not null return BoolField(path=data["attributeName"], value=bool(data["value"])) + @staticmethod + def from_model(model: Any) -> BoolField: + return BoolField(path=model.attributeName, value=model.value) + @dataclass class StringField(Field, field_type=FieldType.STRING): @@ -215,6 +233,10 @@ def from_dict(data: Dict[str, Any]) -> StringField: # TODO: Map only if not null return StringField(path=data["attributeName"], value=str(data["value"])) + @staticmethod + def from_model(model: Any) -> StringField: + return StringField(path=model.attributeName, value=model.value) + @dataclass class DatetimeField(Field, field_type=FieldType.DATETIME): @@ -226,8 +248,13 @@ def accept(self, visitor: FieldVisitor[Ret]) -> Ret: @staticmethod def from_dict(data: Dict[str, Any]) -> DatetimeField: # TODO: what if none + # TODO: Exceptions return DatetimeField(path=data["attributeName"], value=parse_iso_date(data["value"])) + @staticmethod + def from_model(model: Any) -> DatetimeField: + return DatetimeField(path=model.attributeName, value=parse_iso_date(model.value)) + @dataclass class FileField(Field, field_type=FieldType.FILE): @@ -240,8 +267,13 @@ def accept(self, visitor: FieldVisitor[Ret]) -> Ret: @staticmethod def from_dict(data: Dict[str, Any]) -> FileField: + # TODO: Map to str if not null name and ext return FileField(path=data["attributeName"], name=data["name"], ext=data["ext"], size=int(data["size"])) + @staticmethod + def from_model(model: Any) -> FileField: + return FileField(path=model.attributeName, name=model.name, ext=model.ext, size=model.size) + @dataclass class FileSetField(Field, field_type=FieldType.FILE_SET): @@ -254,6 +286,10 @@ def accept(self, visitor: FieldVisitor[Ret]) -> Ret: def from_dict(data: Dict[str, Any]) -> FileSetField: return FileSetField(path=data["attributeName"], size=int(data["size"])) + @staticmethod + def from_model(model: Any) -> FileSetField: + return FileSetField(path=model.attributeName, size=model.size) + @dataclass class FloatSeriesField(Field, field_type=FieldType.FLOAT_SERIES): @@ -265,7 +301,13 @@ def accept(self, visitor: "FieldVisitor[Ret]") -> Ret: @staticmethod def from_dict(data: Dict[str, Any]) -> FloatSeriesField: # TODO: last is optional so map to float if present - return FloatSeriesField(path=data["attributeName"], last=data["last"]) + # TODO: Last may not be present at all + # TODO: Ensure that it's same as previously (last vs lastStep) + return FloatSeriesField(path=data["attributeName"], last=data.get("last", None)) + + @staticmethod + def from_model(model: Any) -> FloatSeriesField: + return FloatSeriesField(path=model.attributeName, last=model.last) @dataclass @@ -278,7 +320,13 @@ def accept(self, visitor: "FieldVisitor[Ret]") -> Ret: @staticmethod def from_dict(data: Dict[str, Any]) -> StringSeriesField: # TODO: last is optional so map to str if present - return StringSeriesField(path=data["attributeName"], last=data["last"]) + # TODO: Last may not be present at all + # TODO: Ensure that it's same as previously (last vs lastStep) + return StringSeriesField(path=data["attributeName"], last=data.get("last", "")) + + @staticmethod + def from_model(model: Any) -> StringSeriesField: + return StringSeriesField(path=model.attributeName, last=model.last) @dataclass @@ -293,6 +341,10 @@ def from_dict(data: Dict[str, Any]) -> ImageSeriesField: # TODO: last_step is optional so map to float if present return ImageSeriesField(path=data["attributeName"], last_step=data["lastStep"]) + @staticmethod + def from_model(model: Any) -> ImageSeriesField: + return ImageSeriesField(path=model.attributeName, last_step=model.lastStep) + @dataclass class StringSetField(Field, field_type=FieldType.STRING_SET): @@ -303,20 +355,43 @@ def accept(self, visitor: "FieldVisitor[Ret]") -> Ret: @staticmethod def from_dict(data: Dict[str, Any]) -> StringSetField: - return StringSetField(path=data["attributeName"], values=set(data["values"])) + return StringSetField(path=data["attributeName"], values=set(map(str, data.get("values", [])))) + + @staticmethod + def from_model(model: Any) -> StringSetField: + return StringSetField(path=model.attributeName, values=set(model.values)) + + +@dataclass +class GitCommit: + commit_id: str + + @staticmethod + def from_dict(data: Dict[str, Any]) -> GitCommit: + # TODO: commit and commit_id is optional so map to str if present + return GitCommit(commit_id=str(data["commitId"])) + + @staticmethod + def from_model(model: Any) -> GitCommit: + return GitCommit(commit_id=model.commitId) @dataclass class GitRefField(Field, field_type=FieldType.GIT_REF): - commit_id: Optional[str] + commit: Optional[GitCommit] def accept(self, visitor: FieldVisitor[Ret]) -> Ret: return visitor.visit_git_ref(self) @staticmethod def from_dict(data: Dict[str, Any]) -> GitRefField: - # TODO: commit and commit_id is optional so map to str if present - return GitRefField(path=data["attributeName"], commit_id=data["commit"]["commitId"]) + commit = GitCommit.from_dict(data["commit"]) if "commit" in data else None + return GitRefField(path=data["attributeName"], commit=commit) + + @staticmethod + def from_model(model: Any) -> GitRefField: + commit = GitCommit.from_model(model.commit) if model.commit is not None else None + return GitRefField(path=model.attributeName, commit=commit) @dataclass @@ -328,9 +403,15 @@ def accept(self, visitor: FieldVisitor[Ret]) -> Ret: @staticmethod def from_dict(data: Dict[str, Any]) -> ObjectStateField: + # TODO: value is optional so map to str if present value = RunState.from_api(str(data["value"])).value return ObjectStateField(path=data["attributeName"], value=value) + @staticmethod + def from_model(model: Any) -> ObjectStateField: + value = RunState.from_api(str(model.value)).value + return ObjectStateField(path=model.attributeName, value=value) + @dataclass class NotebookRefField(Field, field_type=FieldType.NOTEBOOK_REF): @@ -342,7 +423,11 @@ def accept(self, visitor: FieldVisitor[Ret]) -> Ret: @staticmethod def from_dict(data: Dict[str, Any]) -> NotebookRefField: # TODO: notebook_name is optional so map to str if present - return NotebookRefField(path=data["attributeName"], notebook_name=data["notebookName"]) + return NotebookRefField(path=data["attributeName"], notebook_name=data.get("notebookName", None)) + + @staticmethod + def from_model(model: Any) -> NotebookRefField: + return NotebookRefField(path=model.attributeName, notebook_name=model.notebookName) @dataclass @@ -354,8 +439,13 @@ def accept(self, visitor: FieldVisitor[Ret]) -> Ret: @staticmethod def from_dict(data: Dict[str, Any]) -> ArtifactField: + # TODO: hash is optional so map to str if present return ArtifactField(path=data["attributeName"], hash=str(data["hash"])) + @staticmethod + def from_model(model: Any) -> ArtifactField: + return ArtifactField(path=model.attributeName, hash=model.hash) + @dataclass class LeaderboardEntry: @@ -368,6 +458,12 @@ def from_dict(data: Dict[str, Any]) -> LeaderboardEntry: object_id=data["experimentId"], fields=[Field.from_dict(field) for field in data["attributes"]] ) + @staticmethod + def from_model(model: Any) -> LeaderboardEntry: + return LeaderboardEntry( + object_id=model.experimentId, fields=[Field.from_model(field) for field in model.attributes] + ) + @dataclass class LeaderboardEntriesSearchResult: @@ -382,6 +478,13 @@ def from_dict(result: Dict[str, Any]) -> LeaderboardEntriesSearchResult: matching_item_count=result["matchingItemCount"], ) + @staticmethod + def from_model(result: Any) -> LeaderboardEntriesSearchResult: + return LeaderboardEntriesSearchResult( + entries=[LeaderboardEntry.from_model(entry) for entry in result.entries], + matching_item_count=result.matchingItemCount, + ) + @dataclass class FieldDefinition: @@ -391,3 +494,7 @@ class FieldDefinition: @staticmethod def from_dict(data: Dict[str, Any]) -> FieldDefinition: return FieldDefinition(path=data["name"], type=FieldType(data["type"])) + + @staticmethod + def from_model(model: Any) -> FieldDefinition: + return FieldDefinition(path=model.name, type=FieldType(model.type)) diff --git a/src/neptune/integrations/pandas/__init__.py b/src/neptune/integrations/pandas/__init__.py index f7260a912..523cc3e10 100644 --- a/src/neptune/integrations/pandas/__init__.py +++ b/src/neptune/integrations/pandas/__init__.py @@ -90,7 +90,7 @@ def visit_file_set(self, field: FileSetField) -> None: return None def visit_git_ref(self, field: GitRefField) -> Optional[str]: - return field.commit_id + return field.commit.commit_id if field.commit is not None else None def visit_object_state(self, field: ObjectStateField) -> str: return field.value diff --git a/src/neptune/internal/backends/hosted_neptune_backend.py b/src/neptune/internal/backends/hosted_neptune_backend.py index 2d4499f3a..b040d4e7c 100644 --- a/src/neptune/internal/backends/hosted_neptune_backend.py +++ b/src/neptune/internal/backends/hosted_neptune_backend.py @@ -706,7 +706,7 @@ def get_attributes(self, container_id: str, container_type: ContainerType) -> Li ) return [ - FieldDefinition.from_dict(field) for field in accepted_attributes if field.type in attribute_type_names + FieldDefinition.from_model(field) for field in accepted_attributes if field.type in attribute_type_names ] except HTTPNotFound as e: raise ContainerUUIDNotFound( @@ -791,7 +791,7 @@ def get_float_attribute(self, container_id: str, container_type: ContainerType, } try: result = self.leaderboard_client.api.getFloatAttribute(**params).response().result - return FloatField.from_dict(result) + return FloatField.from_model(result) except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) @@ -804,7 +804,7 @@ def get_int_attribute(self, container_id: str, container_type: ContainerType, pa } try: result = self.leaderboard_client.api.getIntAttribute(**params).response().result - return IntField.from_dict(result) + return IntField.from_model(result) except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) @@ -817,7 +817,7 @@ def get_bool_attribute(self, container_id: str, container_type: ContainerType, p } try: result = self.leaderboard_client.api.getBoolAttribute(**params).response().result - return BoolField.from_dict(result) + return BoolField.from_model(result) except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) @@ -830,7 +830,7 @@ def get_file_attribute(self, container_id: str, container_type: ContainerType, p } try: result = self.leaderboard_client.api.getFileAttribute(**params).response().result - return FileField.from_dict(result) + return FileField.from_model(result) except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) @@ -843,7 +843,7 @@ def get_string_attribute(self, container_id: str, container_type: ContainerType, } try: result = self.leaderboard_client.api.getStringAttribute(**params).response().result - return StringField.from_dict(result) + return StringField.from_model(result) except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) @@ -858,7 +858,7 @@ def get_datetime_attribute( } try: result = self.leaderboard_client.api.getDatetimeAttribute(**params).response().result - return DatetimeField.from_dict(result) + return DatetimeField.from_model(result) except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) @@ -906,7 +906,7 @@ def get_float_series_attribute( } try: result = self.leaderboard_client.api.getFloatSeriesAttribute(**params).response().result - return FloatSeriesField.from_dict(result) + return FloatSeriesField.from_model(result) except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) @@ -921,7 +921,7 @@ def get_string_series_attribute( } try: result = self.leaderboard_client.api.getStringSeriesAttribute(**params).response().result - return StringSeriesField.from_dict(result) + return StringSeriesField.from_model(result) except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) @@ -935,8 +935,8 @@ def get_string_set_attribute( **DEFAULT_REQUEST_KWARGS, } try: - result = self.leaderboard_client.api.getStringSetAttribute(**params).response().result - return StringSetField.from_dict(result) + result = self.leaderboard_client.api.getStringSetAttribute(**params) + return StringSetField.from_model(result) except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) From bef202650cb952e9799c260ffe35b33351e8adbd Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Fri, 5 Apr 2024 12:55:10 +0200 Subject: [PATCH 12/22] Fixes --- src/neptune/internal/backends/hosted_neptune_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/neptune/internal/backends/hosted_neptune_backend.py b/src/neptune/internal/backends/hosted_neptune_backend.py index b040d4e7c..ececaa08e 100644 --- a/src/neptune/internal/backends/hosted_neptune_backend.py +++ b/src/neptune/internal/backends/hosted_neptune_backend.py @@ -935,7 +935,7 @@ def get_string_set_attribute( **DEFAULT_REQUEST_KWARGS, } try: - result = self.leaderboard_client.api.getStringSetAttribute(**params) + result = self.leaderboard_client.api.getStringSetAttribute(**params).response().result return StringSetField.from_model(result) except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) From 582106d3909e2e8ab130e34084c09fb5815c2e79 Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Fri, 5 Apr 2024 13:41:12 +0200 Subject: [PATCH 13/22] Unittests fixes --- src/neptune/api/models.py | 3 +- .../neptune/new/api/test_searching_entries.py | 187 ++++++++++-------- .../new/client/abstract_tables_test.py | 65 +++--- tests/unit/neptune/new/client/test_project.py | 52 +---- tests/unit/neptune/new/client/test_run.py | 3 +- .../neptune/new/client/test_run_tables.py | 15 +- .../backends/test_neptune_backend_mock.py | 31 +-- .../neptune/new/internal/test_file_entry.py | 2 +- 8 files changed, 176 insertions(+), 182 deletions(-) diff --git a/src/neptune/api/models.py b/src/neptune/api/models.py index 0ce8b29a3..0b8ebf066 100644 --- a/src/neptune/api/models.py +++ b/src/neptune/api/models.py @@ -473,8 +473,7 @@ class LeaderboardEntriesSearchResult: @staticmethod def from_dict(result: Dict[str, Any]) -> LeaderboardEntriesSearchResult: return LeaderboardEntriesSearchResult( - # TODO: Use generator instead - entries=[LeaderboardEntry.from_dict(entry) for entry in result["entries"]], + entries=[LeaderboardEntry.from_dict(entry) for entry in result.get("entries", [])], matching_item_count=result["matchingItemCount"], ) diff --git a/tests/unit/neptune/new/api/test_searching_entries.py b/tests/unit/neptune/new/api/test_searching_entries.py index 50924bb1a..c17e3f13d 100644 --- a/tests/unit/neptune/new/api/test_searching_entries.py +++ b/tests/unit/neptune/new/api/test_searching_entries.py @@ -14,7 +14,8 @@ # limitations under the License. # from typing import ( - List, + Any, + Dict, Sequence, ) @@ -27,8 +28,8 @@ ) from neptune.api.models import ( - Field, - FieldType, + FloatField, + LeaderboardEntriesSearchResult, LeaderboardEntry, StringField, ) @@ -48,6 +49,8 @@ def test__to_leaderboard_entry(): "name": "plugh", "type": "float", "floatProperties": { + "attributeName": "plugh", + "attributeType": "float", "value": 1.0, }, }, @@ -55,6 +58,8 @@ def test__to_leaderboard_entry(): "name": "sys/id", "type": "string", "stringProperties": { + "attributeName": "sys/id", + "attributeType": "string", "value": "TEST-123", }, }, @@ -62,37 +67,25 @@ def test__to_leaderboard_entry(): } # when - result = to_leaderboard_entry(entry=entry) + result = LeaderboardEntry.from_dict(entry) # then assert result.object_id == "foo" assert result.fields == [ - Field( - path="plugh", - type=FieldType.FLOAT, - properties={ - "value": 1.0, - }, - ), - Field( - path="sys/id", - type=FieldType.STRING, - properties={ - "value": "TEST-123", - }, - ), + FloatField(path="plugh", value=1.0), + StringField(path="sys/id", value="TEST-123"), ] -@patch("neptune.api.searching_entries._entries_from_page") -@patch("neptune.api.searching_entries.get_single_page", return_value={"matchingItemCount": 9}) -def test__iter_over_pages__single_pagination(get_single_page, entries_from_page): +@patch("neptune.api.searching_entries.get_single_page") +def test__iter_over_pages__single_pagination(get_single_page_mock): # given - entries_from_page.side_effect = [ + get_single_page_mock.side_effect = [ + {"matchingItemCount": 9}, generate_leaderboard_entries(values=["a", "b", "c"]), generate_leaderboard_entries(values=["d", "e", "f"]), generate_leaderboard_entries(values=["g", "h", "j"]), - [], + generate_leaderboard_entries(values=[]), ] # when @@ -101,33 +94,38 @@ def test__iter_over_pages__single_pagination(get_single_page, entries_from_page) step_size=3, limit=None, sort_by="sys/id", - sort_by_column_type=None, + sort_by_column_type="string", ascending=False, progress_bar=None, ) ) # then - assert result == generate_leaderboard_entries(values=["a", "b", "c", "d", "e", "f", "g", "h", "j"]) - assert get_single_page.mock_calls == [ + assert ( + result + == LeaderboardEntriesSearchResult.from_dict( + generate_leaderboard_entries(values=["a", "b", "c", "d", "e", "f", "g", "h", "j"]) + ).entries + ) + assert get_single_page_mock.mock_calls == [ # total checking - call(limit=0, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), - call(limit=3, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), - call(limit=3, offset=3, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), - call(limit=3, offset=6, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), - call(limit=3, offset=9, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), + call(limit=0, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), + call(limit=3, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), + call(limit=3, offset=3, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), + call(limit=3, offset=6, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), + call(limit=3, offset=9, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), ] -@patch("neptune.api.searching_entries._entries_from_page") -@patch("neptune.api.searching_entries.get_single_page", return_value={"matchingItemCount": 9}) -def test__iter_over_pages__multiple_search_after(get_single_page, entries_from_page): +@patch("neptune.api.searching_entries.get_single_page") +def test__iter_over_pages__multiple_search_after(get_single_page_mock): # given - entries_from_page.side_effect = [ + get_single_page_mock.side_effect = [ + {"matchingItemCount": 9}, generate_leaderboard_entries(values=["a", "b", "c"]), generate_leaderboard_entries(values=["d", "e", "f"]), generate_leaderboard_entries(values=["g", "h", "j"]), - [], + generate_leaderboard_entries(values=[]), ] # when @@ -136,7 +134,7 @@ def test__iter_over_pages__multiple_search_after(get_single_page, entries_from_p step_size=3, limit=None, sort_by="sys/id", - sort_by_column_type=None, + sort_by_column_type="string", ascending=False, progress_bar=None, max_offset=6, @@ -144,22 +142,29 @@ def test__iter_over_pages__multiple_search_after(get_single_page, entries_from_p ) # then - assert result == generate_leaderboard_entries(values=["a", "b", "c", "d", "e", "f", "g", "h", "j"]) - assert get_single_page.mock_calls == [ + assert ( + result + == LeaderboardEntriesSearchResult.from_dict( + generate_leaderboard_entries(values=["a", "b", "c", "d", "e", "f", "g", "h", "j"]) + ).entries + ) + assert get_single_page_mock.mock_calls == [ # total checking - call(limit=0, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), - call(limit=3, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), - call(limit=3, offset=3, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), - call(limit=3, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after="f"), - call(limit=3, offset=3, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after="f"), + call(limit=0, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), + call(limit=3, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), + call(limit=3, offset=3, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), + call(limit=3, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after="f"), + call(limit=3, offset=3, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after="f"), ] -@patch("neptune.api.searching_entries._entries_from_page") -@patch("neptune.api.searching_entries.get_single_page", return_value={"matchingItemCount": 1}) -def test__iter_over_pages__empty(get_single_page, entries_from_page): +@patch("neptune.api.searching_entries.get_single_page") +def test__iter_over_pages__empty(get_single_page_mock): # given - entries_from_page.side_effect = [[]] + get_single_page_mock.side_effect = [ + {"matchingItemCount": 0}, + generate_leaderboard_entries(values=[]), + ] # when result = list( @@ -167,7 +172,7 @@ def test__iter_over_pages__empty(get_single_page, entries_from_page): step_size=3, limit=None, sort_by="sys/id", - sort_by_column_type=None, + sort_by_column_type="string", ascending=False, progress_bar=None, ) @@ -175,21 +180,21 @@ def test__iter_over_pages__empty(get_single_page, entries_from_page): # then assert result == [] - assert get_single_page.mock_calls == [ + assert get_single_page_mock.mock_calls == [ # total checking - call(limit=0, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), - call(limit=3, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), + call(limit=0, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), + call(limit=3, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), ] -@patch("neptune.api.searching_entries._entries_from_page") -@patch("neptune.api.searching_entries.get_single_page", return_value={"matchingItemCount": 1}) -def test__iter_over_pages__max_server_offset(get_single_page, entries_from_page): +@patch("neptune.api.searching_entries.get_single_page") +def test__iter_over_pages__max_server_offset(get_single_page_mock): # given - entries_from_page.side_effect = [ + get_single_page_mock.side_effect = [ + {"matchingItemCount": 5}, generate_leaderboard_entries(values=["a", "b", "c"]), generate_leaderboard_entries(values=["d", "e"]), - [], + generate_leaderboard_entries(values=[]), ] # when @@ -198,7 +203,7 @@ def test__iter_over_pages__max_server_offset(get_single_page, entries_from_page) step_size=3, limit=None, sort_by="sys/id", - sort_by_column_type=None, + sort_by_column_type="string", ascending=False, progress_bar=None, max_offset=5, @@ -206,28 +211,33 @@ def test__iter_over_pages__max_server_offset(get_single_page, entries_from_page) ) # then - assert result == generate_leaderboard_entries(values=["a", "b", "c", "d", "e"]) - assert get_single_page.mock_calls == [ + assert ( + result + == LeaderboardEntriesSearchResult.from_dict( + generate_leaderboard_entries(values=["a", "b", "c", "d", "e"]) + ).entries + ) + assert get_single_page_mock.mock_calls == [ # total checking - call(limit=0, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), - call(offset=0, limit=3, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), - call(offset=3, limit=2, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), - call(offset=0, limit=3, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after="e"), + call(limit=0, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), + call(offset=0, limit=3, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), + call(offset=3, limit=2, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), + call(offset=0, limit=3, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after="e"), ] -@patch("neptune.api.searching_entries._entries_from_page") -@patch("neptune.api.searching_entries.get_single_page", return_value={"matchingItemCount": 5}) -def test__iter_over_pages__limit(get_single_page, entries_from_page): +@patch("neptune.api.searching_entries.get_single_page") +def test__iter_over_pages__limit(get_single_page_mock): # since the limiting itself takes place in an external service, we can't test the results # we can only test if the limit is properly passed to the external service call # given - entries_from_page.side_effect = [ + get_single_page_mock.side_effect = [ + {"matchingItemCount": 5}, generate_leaderboard_entries(values=["a", "b"]), generate_leaderboard_entries(values=["c", "d"]), generate_leaderboard_entries(values=["e"]), - [], + generate_leaderboard_entries(values=[]), ] # when @@ -236,29 +246,42 @@ def test__iter_over_pages__limit(get_single_page, entries_from_page): step_size=2, limit=4, sort_by="sys/id", - sort_by_column_type=None, + sort_by_column_type="string", ascending=False, progress_bar=None, ) ) # then - assert get_single_page.mock_calls == [ + assert get_single_page_mock.mock_calls == [ # total checking - call(limit=0, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), - call(offset=0, limit=2, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), - call(offset=2, limit=2, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), + call(limit=0, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), + call(offset=0, limit=2, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), + call(offset=2, limit=2, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), ] -def generate_leaderboard_entries(values: Sequence, experiment_id: str = "foo") -> List[LeaderboardEntry]: - return [ - LeaderboardEntry( - object_id=experiment_id, - fields=[StringField(path="sys/id", value=value)], - ) - for value in values - ] +def generate_leaderboard_entries(values: Sequence, experiment_id: str = "foo") -> Dict[str, Any]: + return { + "matchingItemCount": len(values), + "entries": [ + { + "experimentId": f"{experiment_id}-{value}", + "attributes": [ + { + "name": "sys/id", + "type": "string", + "stringProperties": { + "attributeName": "sys/id", + "attributeType": "string", + "value": value, + }, + }, + ], + } + for value in values + ], + } @patch("neptune.api.searching_entries.construct_request") diff --git a/tests/unit/neptune/new/client/abstract_tables_test.py b/tests/unit/neptune/new/client/abstract_tables_test.py index c8b3a7bfd..e9fbba2ea 100644 --- a/tests/unit/neptune/new/client/abstract_tables_test.py +++ b/tests/unit/neptune/new/client/abstract_tables_test.py @@ -24,10 +24,21 @@ from neptune import ANONYMOUS_API_TOKEN from neptune.api.models import ( - Field, + DatetimeField, FieldDefinition, FieldType, + FileField, + FileSetField, + FloatField, + FloatSeriesField, + GitCommit, + GitRefField, + ImageSeriesField, LeaderboardEntry, + ObjectStateField, + StringField, + StringSeriesField, + StringSetField, ) from neptune.envs import ( API_TOKEN_ENV_NAME, @@ -67,26 +78,20 @@ def setUp(cls) -> None: del os.environ[PROJECT_ENV_NAME] @staticmethod - def build_attributes_leaderboard(now: datetime): - attributes = [] - attributes.append(Field("run/state", FieldType.OBJECT_STATE, {"value": "idle"})) - attributes.append(Field("float", FieldType.FLOAT, {"value": 12.5})) - attributes.append(Field("string", FieldType.STRING, {"value": "some text"})) - attributes.append(Field("datetime", FieldType.DATETIME, {"value": now})) - attributes.append(Field("float/series", FieldType.FLOAT_SERIES, {"last": 8.7})) - attributes.append(Field("string/series", FieldType.STRING_SERIES, {"last": "last text"})) - attributes.append(Field("string/set", FieldType.STRING_SET, {"values": ["a", "b"]})) - attributes.append( - Field( - "git/ref", - FieldType.GIT_REF, - {"commit": {"commitId": "abcdef0123456789"}}, - ) - ) - attributes.append(Field("file", FieldType.FILE, None)) - attributes.append(Field("file/set", FieldType.FILE_SET, None)) - attributes.append(Field("image/series", FieldType.IMAGE_SERIES, None)) - return attributes + def build_fields_leaderboard(now: datetime): + return [ + ObjectStateField(path="run/state", value="Inactive"), + FloatField(path="float", value=12.5), + StringField(path="string", value="some text"), + DatetimeField(path="datetime", value=now), + FloatSeriesField(path="float/series", last=8.7), + StringSeriesField(path="string/series", last="last text"), + StringSetField(path="string/set", values={"a", "b"}), + GitRefField(path="git/ref", commit=GitCommit(commit_id="abcdef0123456789")), + FileField(path="file", size=0, name="file.txt", ext="txt"), + FileSetField(path="file/set", size=0), + ImageSeriesField(path="image/series", last_step=None), + ] @patch.object(NeptuneBackendMock, "search_leaderboard_entries") def test_get_table_with_columns_filter(self, search_leaderboard_entries): @@ -102,11 +107,11 @@ def test_get_table_with_columns_filter(self, search_leaderboard_entries): def test_get_table_as_pandas(self, search_leaderboard_entries): # given now = datetime.now() - attributes = self.build_attributes_leaderboard(now) + fields = self.build_fields_leaderboard(now) # and - empty_entry = LeaderboardEntry(str(uuid.uuid4()), []) - filled_entry = LeaderboardEntry(str(uuid.uuid4()), attributes) + empty_entry = LeaderboardEntry(object_id=str(uuid.uuid4()), fields=[]) + filled_entry = LeaderboardEntry(object_id=str(uuid.uuid4()), fields=fields) search_leaderboard_entries.return_value = [empty_entry, filled_entry] # when @@ -119,7 +124,7 @@ def test_get_table_as_pandas(self, search_leaderboard_entries): self.assertEqual(now, df["datetime"][1]) self.assertEqual(8.7, df["float/series"][1]) self.assertEqual("last text", df["string/series"][1]) - self.assertEqual("a,b", df["string/set"][1]) + self.assertEqual({"a", "b"}, set(df["string/set"][1].split(","))) self.assertEqual("abcdef0123456789", df["git/ref"][1]) with self.assertRaises(KeyError): @@ -133,11 +138,11 @@ def test_get_table_as_pandas(self, search_leaderboard_entries): def test_get_table_as_rows(self, search_leaderboard_entries): # given now = datetime.now() - attributes = self.build_attributes_leaderboard(now) + fields = self.build_fields_leaderboard(now) # and - empty_entry = LeaderboardEntry(str(uuid.uuid4()), []) - filled_entry = LeaderboardEntry(str(uuid.uuid4()), attributes) + empty_entry = LeaderboardEntry(object_id=str(uuid.uuid4()), fields=[]) + filled_entry = LeaderboardEntry(object_id=str(uuid.uuid4()), fields=fields) search_leaderboard_entries.return_value = [empty_entry, filled_entry] # and @@ -173,10 +178,10 @@ def test_get_table_as_table_entries( # given exp_id = str(uuid.uuid4()) now = datetime.now() - attributes = self.build_attributes_leaderboard(now) + fields = self.build_fields_leaderboard(now) # and - search_leaderboard_entries.return_value = [LeaderboardEntry(exp_id, attributes)] + search_leaderboard_entries.return_value = [LeaderboardEntry(object_id=exp_id, fields=fields)] # when table_entry = self.get_table_entries(table=self.get_table())[0] diff --git a/tests/unit/neptune/new/client/test_project.py b/tests/unit/neptune/new/client/test_project.py index d7ada9078..e73284c8c 100644 --- a/tests/unit/neptune/new/client/test_project.py +++ b/tests/unit/neptune/new/client/test_project.py @@ -15,7 +15,6 @@ # import os import unittest -from datetime import datetime import pytest from mock import patch @@ -25,11 +24,9 @@ init_project, ) from neptune.api.models import ( - Field, FieldDefinition, FieldType, IntField, - LeaderboardEntry, ) from neptune.envs import ( API_TOKEN_ENV_NAME, @@ -41,6 +38,7 @@ ) from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock from neptune.internal.exceptions import NeptuneException +from neptune.internal.utils.paths import path_to_str from neptune.internal.warnings import ( NeptuneWarning, warned_once, @@ -97,7 +95,7 @@ def test_project_name_env_var(self): @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_int_attribute", - new=lambda _, _uuid, _type, _path: IntField(42), + new=lambda _, _uuid, _type, _path: IntField(value=42, path=path_to_str(_path)), ) @patch("neptune.internal.operation_processors.read_only_operation_processor.warn_once") def test_read_only_mode(self, warn_once): @@ -165,49 +163,3 @@ def test_prepare_nql_query(): trashed=None, ) assert len(query.items) == 0 - - -def test_parse_dates(): - def entries_generator(): - yield LeaderboardEntry( - object_id="test", - fields=[ - Field( - "attr1", - FieldType.DATETIME, - {"value": "2024-02-05T20:37:40.915000Z"}, - ), - Field( - "attr2", - FieldType.DATETIME, - {"value": "2024-02-05T20:37:40.915000Z"}, - ), - ], - ) - - parsed = list(parse_dates(entries_generator())) - assert parsed[0].fields[0].properties["value"] == datetime(2024, 2, 5, 20, 37, 40, 915000) - assert parsed[0].fields[1].properties["value"] == datetime(2024, 2, 5, 20, 37, 40, 915000) - - -@patch("neptune.objects.utils.warn_once") -def test_parse_dates_wrong_format(mock_warn_once): - entries = [ - LeaderboardEntry( - object_id="test", - fields=[ - Field( - "attr1", - FieldType.DATETIME, - {"value": "07-02-2024"}, # different format than expected - ) - ], - ) - ] - - parsed = list(parse_dates(entries)) - assert parsed[0].fields[0].properties["value"] == "07-02-2024" # should be left unchanged due to ValueError - mock_warn_once.assert_called_once_with( - "Date parsing failed. The date format is incorrect. Returning as string instead of datetime.", - exception=NeptuneWarning, - ) diff --git a/tests/unit/neptune/new/client/test_run.py b/tests/unit/neptune/new/client/test_run.py index e36c4eb0c..765d724bc 100644 --- a/tests/unit/neptune/new/client/test_run.py +++ b/tests/unit/neptune/new/client/test_run.py @@ -39,6 +39,7 @@ ) from neptune.exceptions import MissingFieldException from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock +from neptune.internal.utils.paths import path_to_str from neptune.internal.utils.utils import IS_WINDOWS from neptune.internal.warnings import ( NeptuneWarning, @@ -72,7 +73,7 @@ def setUpClass(cls) -> None: ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_int_attribute", - new=lambda _, _uuid, _type, _path: IntField(42), + new=lambda _, _uuid, _type, _path: IntField(value=42, path=path_to_str(_path)), ) @patch("neptune.internal.operation_processors.read_only_operation_processor.warn_once") def test_read_only_mode(self, warn_once): diff --git a/tests/unit/neptune/new/client/test_run_tables.py b/tests/unit/neptune/new/client/test_run_tables.py index ce7adaf9f..146855068 100644 --- a/tests/unit/neptune/new/client/test_run_tables.py +++ b/tests/unit/neptune/new/client/test_run_tables.py @@ -23,7 +23,6 @@ from neptune import init_project from neptune.api.models import ( Field, - FieldType, LeaderboardEntry, ) from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock @@ -68,10 +67,16 @@ def test_fetch_runs_table_raises_correct_exception_for_incorrect_states(self): LeaderboardEntry( object_id="123", fields=[ - Field( - "sys/creation_time", - FieldType.DATETIME, - {"value": "2024-02-05T20:37:40.915000Z"}, + Field.from_dict( + { + "type": "datetime", + "path": "sys/creation_time", + "datetimeProperties": { + "attributeName": "sys/creation_time", + "attributeType": "datetime", + "value": "2024-02-05T20:37:40.915000Z", + }, + } ) ], ) diff --git a/tests/unit/neptune/new/internal/backends/test_neptune_backend_mock.py b/tests/unit/neptune/new/internal/backends/test_neptune_backend_mock.py index febea363e..e2b7797aa 100644 --- a/tests/unit/neptune/new/internal/backends/test_neptune_backend_mock.py +++ b/tests/unit/neptune/new/internal/backends/test_neptune_backend_mock.py @@ -137,6 +137,9 @@ def test_get_datetime_attribute(self): def test_get_float_series_attribute(self): # given + path = ["x"] + + # and for container_id, container_type in self.ids_with_types: with self.subTest(f"For containerType: {container_type}"): self.backend.execute_operations( @@ -144,7 +147,7 @@ def test_get_float_series_attribute(self): container_type, [ LogFloats( - ["x"], + path, [ LogFloats.ValueType(5, None, time()), LogFloats.ValueType(3, None, time()), @@ -158,7 +161,7 @@ def test_get_float_series_attribute(self): container_type, [ LogFloats( - ["x"], + path, [ LogFloats.ValueType(2, None, time()), LogFloats.ValueType(9, None, time()), @@ -169,13 +172,16 @@ def test_get_float_series_attribute(self): ) # when - ret = self.backend.get_float_series_attribute(container_id, container_type, ["x"]) + ret = self.backend.get_float_series_attribute(container_id, container_type, path) # then - self.assertEqual(FloatSeriesField(9), ret) + self.assertEqual(FloatSeriesField(last=9, path="x"), ret) def test_get_string_series_attribute(self): # given + path = ["x"] + + # and for container_id, container_type in self.ids_with_types: with self.subTest(f"For containerType: {container_type}"): self.backend.execute_operations( @@ -183,7 +189,7 @@ def test_get_string_series_attribute(self): container_type, [ LogStrings( - ["x"], + path, [ LogStrings.ValueType("adf", None, time()), LogStrings.ValueType("sdg", None, time()), @@ -197,7 +203,7 @@ def test_get_string_series_attribute(self): container_type, [ LogStrings( - ["x"], + path, [ LogStrings.ValueType("dfh", None, time()), LogStrings.ValueType("qwe", None, time()), @@ -208,27 +214,30 @@ def test_get_string_series_attribute(self): ) # when - ret = self.backend.get_string_series_attribute(container_id, container_type, ["x"]) + ret = self.backend.get_string_series_attribute(container_id, container_type, path) # then - self.assertEqual(StringSeriesField("qwe"), ret) + self.assertEqual(StringSeriesField(last="qwe", path="x"), ret) def test_get_string_set_attribute(self): # given + path = ["x"] + + # and for container_id, container_type in self.ids_with_types: with self.subTest(f"For containerType: {container_type}"): self.backend.execute_operations( container_id, container_type, - [AddStrings(["x"], {"abcx", "qwe"})], + [AddStrings(path, {"abcx", "qwe"})], operation_storage=self.dummy_operation_storage, ) # when - ret = self.backend.get_string_set_attribute(container_id, container_type, ["x"]) + ret = self.backend.get_string_set_attribute(container_id, container_type, path) # then - self.assertEqual(StringSetField({"abcx", "qwe"}), ret) + self.assertEqual(StringSetField(values={"abcx", "qwe"}, path="x"), ret) def test_get_string_series_values(self): # given diff --git a/tests/unit/neptune/new/internal/test_file_entry.py b/tests/unit/neptune/new/internal/test_file_entry.py index ba3d9844c..9e81b800c 100644 --- a/tests/unit/neptune/new/internal/test_file_entry.py +++ b/tests/unit/neptune/new/internal/test_file_entry.py @@ -1,7 +1,7 @@ import datetime from dataclasses import dataclass -from neptune.api.dtos import FileEntry +from neptune.api.models import FileEntry def test_file_entry_from_dto(): From c74a41dab12dce40f13d69a9609164db183c364a Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Fri, 5 Apr 2024 14:12:28 +0200 Subject: [PATCH 14/22] More tests --- src/neptune/api/field_visitor.py | 4 +- src/neptune/api/models.py | 18 +-- src/neptune/integrations/pandas/__init__.py | 4 +- .../backends/hosted_neptune_backend.py | 6 +- .../internal/backends/neptune_backend.py | 4 +- .../internal/backends/neptune_backend_mock.py | 6 +- .../backends/offline_neptune_backend.py | 4 +- tests/unit/neptune/new/api/test_models.py | 143 ++++++++++++++++++ .../new/client/abstract_tables_test.py | 4 +- .../backends/test_neptune_backend_mock.py | 4 +- 10 files changed, 168 insertions(+), 29 deletions(-) create mode 100644 tests/unit/neptune/new/api/test_models.py diff --git a/src/neptune/api/field_visitor.py b/src/neptune/api/field_visitor.py index a39571783..acc4c8d51 100644 --- a/src/neptune/api/field_visitor.py +++ b/src/neptune/api/field_visitor.py @@ -25,7 +25,7 @@ from neptune.api.models import ( ArtifactField, BoolField, - DatetimeField, + DateTimeField, FieldVisitor, FileField, FileSetField, @@ -57,7 +57,7 @@ def visit_bool(self, field: BoolField) -> bool: def visit_string(self, field: StringField) -> str: return field.value - def visit_datetime(self, field: DatetimeField) -> datetime: + def visit_datetime(self, field: DateTimeField) -> datetime: return field.value def visit_file(self, field: FileField) -> None: diff --git a/src/neptune/api/models.py b/src/neptune/api/models.py index 0b8ebf066..4365f7db9 100644 --- a/src/neptune/api/models.py +++ b/src/neptune/api/models.py @@ -27,7 +27,7 @@ "IntField", "BoolField", "StringField", - "DatetimeField", + "DateTimeField", "FileField", "FileSetField", "FloatSeriesField", @@ -137,7 +137,7 @@ def visit_bool(self, field: BoolField) -> Ret: ... def visit_string(self, field: StringField) -> Ret: ... @abc.abstractmethod - def visit_datetime(self, field: DatetimeField) -> Ret: ... + def visit_datetime(self, field: DateTimeField) -> Ret: ... @abc.abstractmethod def visit_file(self, field: FileField) -> Ret: ... @@ -179,7 +179,6 @@ def accept(self, visitor: FieldVisitor[Ret]) -> Ret: @staticmethod def from_dict(data: Dict[str, Any]) -> FloatField: - # TODO: Map only if not null return FloatField(path=data["attributeName"], value=float(data["value"])) @staticmethod @@ -196,7 +195,6 @@ def accept(self, visitor: FieldVisitor[Ret]) -> Ret: @staticmethod def from_dict(data: Dict[str, Any]) -> IntField: - # TODO: Map only if not null return IntField(path=data["attributeName"], value=int(data["value"])) @staticmethod @@ -213,7 +211,6 @@ def accept(self, visitor: FieldVisitor[Ret]) -> Ret: @staticmethod def from_dict(data: Dict[str, Any]) -> BoolField: - # TODO: Map only if not null return BoolField(path=data["attributeName"], value=bool(data["value"])) @staticmethod @@ -230,7 +227,6 @@ def accept(self, visitor: FieldVisitor[Ret]) -> Ret: @staticmethod def from_dict(data: Dict[str, Any]) -> StringField: - # TODO: Map only if not null return StringField(path=data["attributeName"], value=str(data["value"])) @staticmethod @@ -239,21 +235,21 @@ def from_model(model: Any) -> StringField: @dataclass -class DatetimeField(Field, field_type=FieldType.DATETIME): +class DateTimeField(Field, field_type=FieldType.DATETIME): value: datetime def accept(self, visitor: FieldVisitor[Ret]) -> Ret: return visitor.visit_datetime(self) @staticmethod - def from_dict(data: Dict[str, Any]) -> DatetimeField: + def from_dict(data: Dict[str, Any]) -> DateTimeField: # TODO: what if none # TODO: Exceptions - return DatetimeField(path=data["attributeName"], value=parse_iso_date(data["value"])) + return DateTimeField(path=data["attributeName"], value=parse_iso_date(data["value"])) @staticmethod - def from_model(model: Any) -> DatetimeField: - return DatetimeField(path=model.attributeName, value=parse_iso_date(model.value)) + def from_model(model: Any) -> DateTimeField: + return DateTimeField(path=model.attributeName, value=parse_iso_date(model.value)) @dataclass diff --git a/src/neptune/integrations/pandas/__init__.py b/src/neptune/integrations/pandas/__init__.py index 523cc3e10..52161c7f5 100644 --- a/src/neptune/integrations/pandas/__init__.py +++ b/src/neptune/integrations/pandas/__init__.py @@ -31,7 +31,7 @@ from neptune.api.models import ( ArtifactField, BoolField, - DatetimeField, + DateTimeField, FieldVisitor, FileField, FileSetField, @@ -68,7 +68,7 @@ def visit_bool(self, field: BoolField) -> bool: def visit_string(self, field: StringField) -> str: return field.value - def visit_datetime(self, field: DatetimeField) -> datetime: + def visit_datetime(self, field: DateTimeField) -> datetime: return field.value def visit_file(self, field: FileField) -> None: diff --git a/src/neptune/internal/backends/hosted_neptune_backend.py b/src/neptune/internal/backends/hosted_neptune_backend.py index ececaa08e..9f8974b87 100644 --- a/src/neptune/internal/backends/hosted_neptune_backend.py +++ b/src/neptune/internal/backends/hosted_neptune_backend.py @@ -41,7 +41,7 @@ from neptune.api.models import ( ArtifactField, BoolField, - DatetimeField, + DateTimeField, FieldDefinition, FieldType, FileEntry, @@ -850,7 +850,7 @@ def get_string_attribute(self, container_id: str, container_type: ContainerType, @with_api_exceptions_handler def get_datetime_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> DatetimeField: + ) -> DateTimeField: params = { "experimentId": container_id, "attribute": path_to_str(path), @@ -858,7 +858,7 @@ def get_datetime_attribute( } try: result = self.leaderboard_client.api.getDatetimeAttribute(**params).response().result - return DatetimeField.from_model(result) + return DateTimeField.from_model(result) except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) diff --git a/src/neptune/internal/backends/neptune_backend.py b/src/neptune/internal/backends/neptune_backend.py index 006a2ab75..f6d21dbeb 100644 --- a/src/neptune/internal/backends/neptune_backend.py +++ b/src/neptune/internal/backends/neptune_backend.py @@ -28,7 +28,7 @@ from neptune.api.models import ( ArtifactField, BoolField, - DatetimeField, + DateTimeField, FieldDefinition, FieldType, FileEntry, @@ -196,7 +196,7 @@ def get_string_attribute(self, container_id: str, container_type: ContainerType, @abc.abstractmethod def get_datetime_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> DatetimeField: + ) -> DateTimeField: pass @abc.abstractmethod diff --git a/src/neptune/internal/backends/neptune_backend_mock.py b/src/neptune/internal/backends/neptune_backend_mock.py index 6e40c0894..0a8a1b515 100644 --- a/src/neptune/internal/backends/neptune_backend_mock.py +++ b/src/neptune/internal/backends/neptune_backend_mock.py @@ -37,7 +37,7 @@ from neptune.api.models import ( ArtifactField, BoolField, - DatetimeField, + DateTimeField, FieldDefinition, FieldType, FileEntry, @@ -399,9 +399,9 @@ def get_string_attribute(self, container_id: str, container_type: ContainerType, def get_datetime_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> DatetimeField: + ) -> DateTimeField: val = self._get_attribute(container_id, container_type, path, Datetime) - return DatetimeField(path=path_to_str(path), value=val.value) + return DateTimeField(path=path_to_str(path), value=val.value) def get_artifact_attribute( self, container_id: str, container_type: ContainerType, path: List[str] diff --git a/src/neptune/internal/backends/offline_neptune_backend.py b/src/neptune/internal/backends/offline_neptune_backend.py index aee0c7c83..48f2b8264 100644 --- a/src/neptune/internal/backends/offline_neptune_backend.py +++ b/src/neptune/internal/backends/offline_neptune_backend.py @@ -23,7 +23,7 @@ from neptune.api.models import ( ArtifactField, BoolField, - DatetimeField, + DateTimeField, FieldDefinition, FileEntry, FileField, @@ -69,7 +69,7 @@ def get_string_attribute(self, container_id: str, container_type: ContainerType, def get_datetime_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> DatetimeField: + ) -> DateTimeField: raise NeptuneOfflineModeFetchException def get_artifact_attribute( diff --git a/tests/unit/neptune/new/api/test_models.py b/tests/unit/neptune/new/api/test_models.py new file mode 100644 index 000000000..5c13f6f1d --- /dev/null +++ b/tests/unit/neptune/new/api/test_models.py @@ -0,0 +1,143 @@ +# +# Copyright (c) 2024, Neptune Labs Sp. z o.o. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from mock import Mock + +from neptune.api.models import ( + BoolField, + FloatField, + IntField, + StringField, +) + + +def test__float_field__from_dict(): + # given + data = {"attributeType": "float", "attributeName": "some/float", "value": 18.5} + + # when + result = FloatField.from_dict(data) + + # then + assert result.path == "some/float" + assert result.value == 18.5 + + +def test__float_field__from_model(): + # given + model = Mock(attributeType="float", attributeName="some/float", value=18.5) + + # when + result = FloatField.from_model(model) + + # then + assert result.path == "some/float" + assert result.value == 18.5 + + +def test__int_field__from_dict(): + # given + data = {"attributeType": "int", "attributeName": "some/int", "value": 18} + + # when + result = IntField.from_dict(data) + + # then + assert result.path == "some/int" + assert result.value == 18 + + +def test__int_field__from_model(): + # given + model = Mock(attributeType="int", attributeName="some/int", value=18) + + # when + result = IntField.from_model(model) + + # then + assert result.path == "some/int" + assert result.value == 18 + + +def test__string_field__from_dict(): + # given + data = {"attributeType": "string", "attributeName": "some/string", "value": "hello"} + + # when + result = StringField.from_dict(data) + + # then + assert result.path == "some/string" + assert result.value == "hello" + + +def test__string_field__from_model(): + # given + model = Mock(attributeType="string", attributeName="some/string", value="hello") + + # when + result = StringField.from_model(model) + + # then + assert result.path == "some/string" + assert result.value == "hello" + + +def test__string_field__from_dict__empty(): + # given + data = {"attributeType": "string", "attributeName": "some/string", "value": ""} + + # when + result = StringField.from_dict(data) + + # then + assert result.path == "some/string" + assert result.value == "" + + +def test__string_field__from_model__empty(): + # given + model = Mock(attributeType="string", attributeName="some/string", value="") + + # when + result = StringField.from_model(model) + + # then + assert result.path == "some/string" + assert result.value == "" + + +def test__bool_field__from_dict(): + # given + data = {"attributeType": "bool", "attributeName": "some/bool", "value": True} + + # when + result = BoolField.from_dict(data) + + # then + assert result.path == "some/bool" + assert result.value is True + + +def test__bool_field__from_model(): + # given + model = Mock(attributeType="bool", attributeName="some/bool", value=True) + + # when + result = BoolField.from_model(model) + + # then + assert result.path == "some/bool" + assert result.value is True diff --git a/tests/unit/neptune/new/client/abstract_tables_test.py b/tests/unit/neptune/new/client/abstract_tables_test.py index e9fbba2ea..0eee972b5 100644 --- a/tests/unit/neptune/new/client/abstract_tables_test.py +++ b/tests/unit/neptune/new/client/abstract_tables_test.py @@ -24,7 +24,7 @@ from neptune import ANONYMOUS_API_TOKEN from neptune.api.models import ( - DatetimeField, + DateTimeField, FieldDefinition, FieldType, FileField, @@ -83,7 +83,7 @@ def build_fields_leaderboard(now: datetime): ObjectStateField(path="run/state", value="Inactive"), FloatField(path="float", value=12.5), StringField(path="string", value="some text"), - DatetimeField(path="datetime", value=now), + DateTimeField(path="datetime", value=now), FloatSeriesField(path="float/series", last=8.7), StringSeriesField(path="string/series", last="last text"), StringSetField(path="string/set", values={"a", "b"}), diff --git a/tests/unit/neptune/new/internal/backends/test_neptune_backend_mock.py b/tests/unit/neptune/new/internal/backends/test_neptune_backend_mock.py index e2b7797aa..b8ff429d8 100644 --- a/tests/unit/neptune/new/internal/backends/test_neptune_backend_mock.py +++ b/tests/unit/neptune/new/internal/backends/test_neptune_backend_mock.py @@ -22,7 +22,7 @@ from time import time from neptune.api.models import ( - DatetimeField, + DateTimeField, FloatField, FloatSeriesField, StringField, @@ -133,7 +133,7 @@ def test_get_datetime_attribute(self): ret = self.backend.get_datetime_attribute(container_id, container_type, path) # then - self.assertEqual(DatetimeField(path="x", value=now), ret) + self.assertEqual(DateTimeField(path="x", value=now), ret) def test_get_float_series_attribute(self): # given From 0f8425b273de91d6a9b75a523c293db41da9ff94 Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Fri, 5 Apr 2024 15:54:40 +0200 Subject: [PATCH 15/22] More tests --- src/neptune/api/models.py | 39 +- tests/unit/neptune/new/api/test_models.py | 1329 +++++++++++++++++++++ 2 files changed, 1346 insertions(+), 22 deletions(-) diff --git a/src/neptune/api/models.py b/src/neptune/api/models.py index 4365f7db9..8514cf281 100644 --- a/src/neptune/api/models.py +++ b/src/neptune/api/models.py @@ -105,6 +105,10 @@ def __init_subclass__(cls, *args: Any, field_type: FieldType, **kwargs: Any) -> cls.type = field_type cls._registry[field_type.value] = cls + @classmethod + def by_type(cls, field_type: FieldType) -> Type[Field]: + return cls._registry[field_type.value] + @abc.abstractmethod def accept(self, visitor: FieldVisitor[Ret]) -> Ret: ... @@ -116,7 +120,7 @@ def from_dict(data: Dict[str, Any]) -> Field: @staticmethod def from_model(model: Any) -> Field: field_type = str(model.type) - return Field._registry[field_type].from_model(model.__getattribute__(f"{field_type}_properties")) + return Field._registry[field_type].from_model(model.__getattribute__(f"{field_type}Properties")) class FieldVisitor(Generic[Ret], abc.ABC): @@ -243,8 +247,6 @@ def accept(self, visitor: FieldVisitor[Ret]) -> Ret: @staticmethod def from_dict(data: Dict[str, Any]) -> DateTimeField: - # TODO: what if none - # TODO: Exceptions return DateTimeField(path=data["attributeName"], value=parse_iso_date(data["value"])) @staticmethod @@ -263,7 +265,6 @@ def accept(self, visitor: FieldVisitor[Ret]) -> Ret: @staticmethod def from_dict(data: Dict[str, Any]) -> FileField: - # TODO: Map to str if not null name and ext return FileField(path=data["attributeName"], name=data["name"], ext=data["ext"], size=int(data["size"])) @staticmethod @@ -296,10 +297,8 @@ def accept(self, visitor: "FieldVisitor[Ret]") -> Ret: @staticmethod def from_dict(data: Dict[str, Any]) -> FloatSeriesField: - # TODO: last is optional so map to float if present - # TODO: Last may not be present at all - # TODO: Ensure that it's same as previously (last vs lastStep) - return FloatSeriesField(path=data["attributeName"], last=data.get("last", None)) + last = float(data["last"]) if "last" in data else None + return FloatSeriesField(path=data["attributeName"], last=last) @staticmethod def from_model(model: Any) -> FloatSeriesField: @@ -315,10 +314,8 @@ def accept(self, visitor: "FieldVisitor[Ret]") -> Ret: @staticmethod def from_dict(data: Dict[str, Any]) -> StringSeriesField: - # TODO: last is optional so map to str if present - # TODO: Last may not be present at all - # TODO: Ensure that it's same as previously (last vs lastStep) - return StringSeriesField(path=data["attributeName"], last=data.get("last", "")) + last = str(data["last"]) if "last" in data else None + return StringSeriesField(path=data["attributeName"], last=last) @staticmethod def from_model(model: Any) -> StringSeriesField: @@ -334,8 +331,8 @@ def accept(self, visitor: "FieldVisitor[Ret]") -> Ret: @staticmethod def from_dict(data: Dict[str, Any]) -> ImageSeriesField: - # TODO: last_step is optional so map to float if present - return ImageSeriesField(path=data["attributeName"], last_step=data["lastStep"]) + last_step = float(data["lastStep"]) if "lastStep" in data else None + return ImageSeriesField(path=data["attributeName"], last_step=last_step) @staticmethod def from_model(model: Any) -> ImageSeriesField: @@ -351,7 +348,7 @@ def accept(self, visitor: "FieldVisitor[Ret]") -> Ret: @staticmethod def from_dict(data: Dict[str, Any]) -> StringSetField: - return StringSetField(path=data["attributeName"], values=set(map(str, data.get("values", [])))) + return StringSetField(path=data["attributeName"], values=set(map(str, data["values"]))) @staticmethod def from_model(model: Any) -> StringSetField: @@ -360,12 +357,12 @@ def from_model(model: Any) -> StringSetField: @dataclass class GitCommit: - commit_id: str + commit_id: Optional[str] @staticmethod def from_dict(data: Dict[str, Any]) -> GitCommit: - # TODO: commit and commit_id is optional so map to str if present - return GitCommit(commit_id=str(data["commitId"])) + commit_id = str(data["commitId"]) if "commitId" in data else None + return GitCommit(commit_id=commit_id) @staticmethod def from_model(model: Any) -> GitCommit: @@ -399,7 +396,6 @@ def accept(self, visitor: FieldVisitor[Ret]) -> Ret: @staticmethod def from_dict(data: Dict[str, Any]) -> ObjectStateField: - # TODO: value is optional so map to str if present value = RunState.from_api(str(data["value"])).value return ObjectStateField(path=data["attributeName"], value=value) @@ -418,8 +414,8 @@ def accept(self, visitor: FieldVisitor[Ret]) -> Ret: @staticmethod def from_dict(data: Dict[str, Any]) -> NotebookRefField: - # TODO: notebook_name is optional so map to str if present - return NotebookRefField(path=data["attributeName"], notebook_name=data.get("notebookName", None)) + notebook_name = str(data["notebookName"]) if "notebookName" in data else None + return NotebookRefField(path=data["attributeName"], notebook_name=notebook_name) @staticmethod def from_model(model: Any) -> NotebookRefField: @@ -435,7 +431,6 @@ def accept(self, visitor: FieldVisitor[Ret]) -> Ret: @staticmethod def from_dict(data: Dict[str, Any]) -> ArtifactField: - # TODO: hash is optional so map to str if present return ArtifactField(path=data["attributeName"], hash=str(data["hash"])) @staticmethod diff --git a/tests/unit/neptune/new/api/test_models.py b/tests/unit/neptune/new/api/test_models.py index 5c13f6f1d..ed6245f2b 100644 --- a/tests/unit/neptune/new/api/test_models.py +++ b/tests/unit/neptune/new/api/test_models.py @@ -13,13 +13,32 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import datetime + +import pytest from mock import Mock from neptune.api.models import ( + ArtifactField, BoolField, + DateTimeField, + Field, + FieldDefinition, + FieldType, + FileField, + FileSetField, FloatField, + FloatSeriesField, + GitRefField, + ImageSeriesField, IntField, + LeaderboardEntriesSearchResult, + LeaderboardEntry, + NotebookRefField, + ObjectStateField, StringField, + StringSeriesField, + StringSetField, ) @@ -141,3 +160,1313 @@ def test__bool_field__from_model(): # then assert result.path == "some/bool" assert result.value is True + + +def test__datetime_field__from_dict(): + # given + data = {"attributeType": "datetime", "attributeName": "some/datetime", "value": "2024-01-01T00:12:34.567890Z"} + + # when + result = DateTimeField.from_dict(data) + + # then + assert result.path == "some/datetime" + assert result.value == datetime.datetime(2024, 1, 1, 0, 12, 34, 567890) + + +def test__datetime_field__from_model(): + # given + model = Mock(attributeType="datetime", attributeName="some/datetime", value="2024-01-01T00:12:34.567890Z") + + # when + result = DateTimeField.from_model(model) + + # then + assert result.path == "some/datetime" + assert result.value == datetime.datetime(2024, 1, 1, 0, 12, 34, 567890) + + +def test__float_series_field__from_dict(): + # given + data = { + "attributeType": "floatSeries", + "attributeName": "some/floatSeries", + "last": 19.5, + } + + # when + result = FloatSeriesField.from_dict(data) + + # then + assert result.path == "some/floatSeries" + assert result.last == 19.5 + + +def test__float_series_field__from_dict__no_last(): + # given + data = { + "attributeType": "floatSeries", + "attributeName": "some/floatSeries", + } + + # when + result = FloatSeriesField.from_dict(data) + + # then + assert result.path == "some/floatSeries" + assert result.last is None + + +def test__float_series_field__from_model(): + # given + model = Mock( + attributeType="floatSeries", + attributeName="some/floatSeries", + last=19.5, + ) + + # when + result = FloatSeriesField.from_model(model) + + # then + assert result.path == "some/floatSeries" + assert result.last == 19.5 + + +def test__float_series_field__from_model__no_last(): + # given + model = Mock( + attributeType="floatSeries", + attributeName="some/floatSeries", + last=None, + ) + + # when + result = FloatSeriesField.from_model(model) + + # then + assert result.path == "some/floatSeries" + assert result.last is None + + +def test__string_series_field__from_dict(): + # given + data = { + "attributeType": "stringSeries", + "attributeName": "some/stringSeries", + "last": "hello", + } + + # when + result = StringSeriesField.from_dict(data) + + # then + assert result.path == "some/stringSeries" + assert result.last == "hello" + + +def test__string_series_field__from_dict__no_last(): + # given + data = { + "attributeType": "stringSeries", + "attributeName": "some/stringSeries", + } + + # when + result = StringSeriesField.from_dict(data) + + # then + assert result.path == "some/stringSeries" + assert result.last is None + + +def test__string_series_field__from_model(): + # given + model = Mock( + attributeType="stringSeries", + attributeName="some/stringSeries", + last="hello", + ) + + # when + result = StringSeriesField.from_model(model) + + # then + assert result.path == "some/stringSeries" + assert result.last == "hello" + + +def test__string_series_field__from_model__no_last(): + # given + model = Mock( + attributeType="stringSeries", + attributeName="some/stringSeries", + last=None, + ) + + # when + result = StringSeriesField.from_model(model) + + # then + assert result.path == "some/stringSeries" + assert result.last is None + + +def test__image_series_field__from_dict(): + # given + data = { + "attributeType": "imageSeries", + "attributeName": "some/imageSeries", + "lastStep": 15.0, + } + + # when + result = ImageSeriesField.from_dict(data) + + # then + assert result.path == "some/imageSeries" + assert result.last_step == 15.0 + + +def test__image_series_field__from_dict__no_last_step(): + # given + data = { + "attributeType": "imageSeries", + "attributeName": "some/imageSeries", + } + + # when + result = ImageSeriesField.from_dict(data) + + # then + assert result.path == "some/imageSeries" + assert result.last_step is None + + +def test__image_series_field__from_model(): + # given + model = Mock( + attributeType="imageSeries", + attributeName="some/imageSeries", + lastStep=15.0, + ) + + # when + result = ImageSeriesField.from_model(model) + + # then + assert result.path == "some/imageSeries" + assert result.last_step == 15.0 + + +def test__image_series_field__from_model__no_last_step(): + # given + model = Mock( + attributeType="imageSeries", + attributeName="some/imageSeries", + lastStep=None, + ) + + # when + result = ImageSeriesField.from_model(model) + + # then + assert result.path == "some/imageSeries" + assert result.last_step is None + + +def test__string_set_field__from_dict(): + # given + data = { + "attributeType": "stringSet", + "attributeName": "some/stringSet", + "values": ["hello", "world"], + } + + # when + result = StringSetField.from_dict(data) + + # then + assert result.path == "some/stringSet" + assert result.values == {"hello", "world"} + + +def test__string_set_field__from_dict__empty(): + # given + data = { + "attributeType": "stringSet", + "attributeName": "some/stringSet", + "values": [], + } + + # when + result = StringSetField.from_dict(data) + + # then + assert result.path == "some/stringSet" + assert result.values == set() + + +def test__string_set_field__from_model(): + # given + model = Mock( + attributeType="stringSet", + attributeName="some/stringSet", + values=["hello", "world"], + ) + + # when + result = StringSetField.from_model(model) + + # then + assert result.path == "some/stringSet" + assert result.values == {"hello", "world"} + + +def test__string_set_field__from_model__empty(): + # given + model = Mock( + attributeType="stringSet", + attributeName="some/stringSet", + values=[], + ) + + # when + result = StringSetField.from_model(model) + + # then + assert result.path == "some/stringSet" + assert result.values == set() + + +def test__file_field__from_dict(): + # given + data = { + "attributeType": "file", + "attributeName": "some/file", + "name": "file.txt", + "size": 1024, + "ext": "txt", + } + + # when + result = FileField.from_dict(data) + + # then + assert result.path == "some/file" + assert result.name == "file.txt" + assert result.size == 1024 + assert result.ext == "txt" + + +def test__file_field__from_model(): + # given + model = Mock( + attributeType="file", + attributeName="some/file", + size=1024, + ext="txt", + ) + model.name = "file.txt" + + # when + result = FileField.from_model(model) + + # then + assert result.path == "some/file" + assert result.name == "file.txt" + assert result.size == 1024 + assert result.ext == "txt" + + +@pytest.mark.parametrize("state,expected", [("running", "Active"), ("idle", "Inactive")]) +def test__object_state_field__from_dict(state, expected): + # given + data = {"attributeType": "experimentState", "attributeName": "sys/state", "value": state} + + # when + result = ObjectStateField.from_dict(data) + + # then + assert result.path == "sys/state" + assert result.value == expected + + +@pytest.mark.parametrize("state,expected", [("running", "Active"), ("idle", "Inactive")]) +def test__object_state_field__from_model(state, expected): + # given + model = Mock(attributeType="experimentState", attributeName="sys/state", value=state) + + # when + result = ObjectStateField.from_model(model) + + # then + assert result.path == "sys/state" + assert result.value == expected + + +def test__file_set_field__from_dict(): + # given + data = { + "attributeType": "fileSet", + "attributeName": "some/fileSet", + "size": 3072, + } + + # when + result = FileSetField.from_dict(data) + + # then + assert result.path == "some/fileSet" + assert result.size == 3072 + + +def test__file_set_field__from_model(): + # given + model = Mock( + attributeType="fileSet", + attributeName="some/fileSet", + size=3072, + ) + + # when + result = FileSetField.from_model(model) + + # then + assert result.path == "some/fileSet" + assert result.size == 3072 + + +def test__notebook_ref_field__from_dict(): + # given + data = { + "attributeType": "notebookRef", + "attributeName": "some/notebookRef", + "notebookName": "Data Processing.ipynb", + } + + # when + result = NotebookRefField.from_dict(data) + + # then + assert result.path == "some/notebookRef" + assert result.notebook_name == "Data Processing.ipynb" + + +def test__notebook_ref_field__from_dict__no_notebook_name(): + # given + data = { + "attributeType": "notebookRef", + "attributeName": "some/notebookRef", + } + + # when + result = NotebookRefField.from_dict(data) + + # then + assert result.path == "some/notebookRef" + assert result.notebook_name is None + + +def test__notebook_ref_field__from_model(): + # given + model = Mock( + attributeType="notebookRef", + attributeName="some/notebookRef", + notebookName="Data Processing.ipynb", + ) + + # when + result = NotebookRefField.from_model(model) + + # then + assert result.path == "some/notebookRef" + assert result.notebook_name == "Data Processing.ipynb" + + +def test__notebook_ref_field__from_model__no_notebook_name(): + # given + model = Mock( + attributeType="notebookRef", + attributeName="some/notebookRef", + notebookName=None, + ) + + # when + result = NotebookRefField.from_model(model) + + # then + assert result.path == "some/notebookRef" + assert result.notebook_name is None + + +def test__git_ref_field__from_dict(): + # given + data = { + "attributeType": "gitRef", + "attributeName": "some/gitRef", + "commit": { + "commitId": "b2d7f8a", + }, + } + + # when + result = GitRefField.from_dict(data) + + # then + assert result.path == "some/gitRef" + assert result.commit.commit_id == "b2d7f8a" + + +def test__git_ref_field__from_dict__no_commit(): + # given + data = { + "attributeType": "gitRef", + "attributeName": "some/gitRef", + } + + # when + result = GitRefField.from_dict(data) + + # then + assert result.path == "some/gitRef" + assert result.commit is None + + +def test__git_ref_field__from_model(): + # given + model = Mock( + attributeType="gitRef", + attributeName="some/gitRef", + commit=Mock( + commitId="b2d7f8a", + ), + ) + + # when + result = GitRefField.from_model(model) + + # then + assert result.path == "some/gitRef" + assert result.commit.commit_id == "b2d7f8a" + + +def test__git_ref_field__from_model__no_commit(): + # given + model = Mock( + attributeType="gitRef", + attributeName="some/gitRef", + commit=None, + ) + + # when + result = GitRefField.from_model(model) + + # then + assert result.path == "some/gitRef" + assert result.commit is None + + +def test__artifact_field__from_dict(): + # given + data = { + "attributeType": "artifact", + "attributeName": "some/artifact", + "hash": "f192cddb2b98c0b4c72bba22b68d2245", + } + + # when + result = ArtifactField.from_dict(data) + + # then + assert result.path == "some/artifact" + assert result.hash == "f192cddb2b98c0b4c72bba22b68d2245" + + +def test__artifact_field__from_model(): + # given + model = Mock( + attributeType="artifact", + attributeName="some/artifact", + hash="f192cddb2b98c0b4c72bba22b68d2245", + ) + + # when + result = ArtifactField.from_model(model) + + # then + assert result.path == "some/artifact" + assert result.hash == "f192cddb2b98c0b4c72bba22b68d2245" + + +def test__field__from_dict__float(): + # given + data = { + "path": "some/float", + "type": "float", + "floatProperties": {"attributeType": "float", "attributeName": "some/float", "value": 18.5}, + } + + # when + result = Field.from_dict(data) + + # then + assert result.path == "some/float" + assert isinstance(result, FloatField) + assert result.value == 18.5 + + +def test__field__from_model__float(): + # given + model = Mock( + path="some/float", + type="float", + floatProperties=Mock(attributeType="float", attributeName="some/float", value=18.5), + ) + + # when + result = Field.from_model(model) + + # then + assert result.path == "some/float" + assert isinstance(result, FloatField) + assert result.value == 18.5 + + +def test__field__from_dict__int(): + # given + data = { + "path": "some/int", + "type": "int", + "intProperties": {"attributeType": "int", "attributeName": "some/int", "value": 18}, + } + + # when + result = Field.from_dict(data) + + # then + assert result.path == "some/int" + assert isinstance(result, IntField) + assert result.value == 18 + + +def test__field__from_model__int(): + # given + model = Mock( + path="some/int", type="int", intProperties=Mock(attributeType="int", attributeName="some/int", value=18) + ) + + # when + result = Field.from_model(model) + + # then + assert result.path == "some/int" + assert isinstance(result, IntField) + assert result.value == 18 + + +def test__field__from_dict__string(): + # given + data = { + "path": "some/string", + "type": "string", + "stringProperties": {"attributeType": "string", "attributeName": "some/string", "value": "hello"}, + } + + # when + result = Field.from_dict(data) + + # then + assert result.path == "some/string" + assert isinstance(result, StringField) + assert result.value == "hello" + + +def test__field__from_model__string(): + # given + model = Mock( + path="some/string", + type="string", + stringProperties=Mock(attributeType="string", attributeName="some/string", value="hello"), + ) + + # when + result = Field.from_model(model) + + # then + assert result.path == "some/string" + assert isinstance(result, StringField) + assert result.value == "hello" + + +def test__field__from_dict__bool(): + # given + data = { + "path": "some/bool", + "type": "bool", + "boolProperties": {"attributeType": "bool", "attributeName": "some/bool", "value": True}, + } + + # when + result = Field.from_dict(data) + + # then + assert result.path == "some/bool" + assert isinstance(result, BoolField) + assert result.value is True + + +def test__field__from_model__bool(): + # given + model = Mock( + path="some/bool", type="bool", boolProperties=Mock(attributeType="bool", attributeName="some/bool", value=True) + ) + + # when + result = Field.from_model(model) + + # then + assert result.path == "some/bool" + assert isinstance(result, BoolField) + assert result.value is True + + +def test__field__from_dict__datetime(): + # given + data = { + "path": "some/datetime", + "type": "datetime", + "datetimeProperties": { + "attributeType": "datetime", + "attributeName": "some/datetime", + "value": "2024-01-01T00:12:34.567890Z", + }, + } + + # when + result = Field.from_dict(data) + + # then + assert result.path == "some/datetime" + assert isinstance(result, DateTimeField) + assert result.value == datetime.datetime(2024, 1, 1, 0, 12, 34, 567890) + + +def test__field__from_model__datetime(): + # given + model = Mock( + path="some/datetime", + type="datetime", + datetimeProperties=Mock( + attributeType="datetime", attributeName="some/datetime", value="2024-01-01T00:12:34.567890Z" + ), + ) + + # when + result = Field.from_model(model) + + # then + assert result.path == "some/datetime" + assert isinstance(result, DateTimeField) + assert result.value == datetime.datetime(2024, 1, 1, 0, 12, 34, 567890) + + +def test__field__from_dict__float_series(): + # given + data = { + "path": "some/floatSeries", + "type": "floatSeries", + "floatSeriesProperties": {"attributeType": "floatSeries", "attributeName": "some/floatSeries", "last": 19.5}, + } + + # when + result = Field.from_dict(data) + + # then + assert result.path == "some/floatSeries" + assert isinstance(result, FloatSeriesField) + assert result.last == 19.5 + + +def test__field__from_model__float_series(): + # given + model = Mock( + path="some/floatSeries", + type="floatSeries", + floatSeriesProperties=Mock(attributeType="floatSeries", attributeName="some/floatSeries", last=19.5), + ) + + # when + result = Field.from_model(model) + + # then + assert result.path == "some/floatSeries" + assert isinstance(result, FloatSeriesField) + assert result.last == 19.5 + + +def test__field__from_dict__string_series(): + # given + data = { + "path": "some/stringSeries", + "type": "stringSeries", + "stringSeriesProperties": { + "attributeType": "stringSeries", + "attributeName": "some/stringSeries", + "last": "hello", + }, + } + + # when + result = Field.from_dict(data) + + # then + assert result.path == "some/stringSeries" + assert isinstance(result, StringSeriesField) + assert result.last == "hello" + + +def test__field__from_model__string_series(): + # given + model = Mock( + path="some/stringSeries", + type="stringSeries", + stringSeriesProperties=Mock(attributeType="stringSeries", attributeName="some/stringSeries", last="hello"), + ) + + # when + result = Field.from_model(model) + + # then + assert result.path == "some/stringSeries" + assert isinstance(result, StringSeriesField) + assert result.last == "hello" + + +def test__field__from_dict__image_series(): + # given + data = { + "path": "some/imageSeries", + "type": "imageSeries", + "imageSeriesProperties": { + "attributeType": "imageSeries", + "attributeName": "some/imageSeries", + "lastStep": 15.0, + }, + } + + # when + result = Field.from_dict(data) + + # then + assert result.path == "some/imageSeries" + assert isinstance(result, ImageSeriesField) + assert result.last_step == 15.0 + + +def test__field__from_model__image_series(): + # given + model = Mock( + path="some/imageSeries", + type="imageSeries", + imageSeriesProperties=Mock(attributeType="imageSeries", attributeName="some/imageSeries", lastStep=15.0), + ) + + # when + result = Field.from_model(model) + + # then + assert result.path == "some/imageSeries" + assert isinstance(result, ImageSeriesField) + assert result.last_step == 15.0 + + +def test__field__from_dict__string_set(): + # given + data = { + "path": "some/stringSet", + "type": "stringSet", + "stringSetProperties": { + "attributeType": "stringSet", + "attributeName": "some/stringSet", + "values": ["hello", "world"], + }, + } + + # when + result = Field.from_dict(data) + + # then + assert result.path == "some/stringSet" + assert isinstance(result, StringSetField) + assert result.values == {"hello", "world"} + + +def test__field__from_model__string_set(): + # given + model = Mock( + path="some/stringSet", + type="stringSet", + stringSetProperties=Mock(attributeType="stringSet", attributeName="some/stringSet", values=["hello", "world"]), + ) + + # when + result = Field.from_model(model) + + # then + assert result.path == "some/stringSet" + assert isinstance(result, StringSetField) + assert result.values == {"hello", "world"} + + +def test__field__from_dict__file(): + # given + data = { + "path": "some/file", + "type": "file", + "fileProperties": { + "attributeType": "file", + "attributeName": "some/file", + "name": "file.txt", + "size": 1024, + "ext": "txt", + }, + } + + # when + result = Field.from_dict(data) + + # then + assert result.path == "some/file" + assert isinstance(result, FileField) + assert result.name == "file.txt" + assert result.size == 1024 + assert result.ext == "txt" + + +def test__field__from_model__file(): + # given + model = Mock( + path="some/file", + type="file", + fileProperties=Mock(attributeType="file", attributeName="some/file", size=1024, ext="txt"), + ) + model.fileProperties.name = "file.txt" + + # when + result = Field.from_model(model) + + # then + assert result.path == "some/file" + assert isinstance(result, FileField) + assert result.name == "file.txt" + assert result.size == 1024 + assert result.ext == "txt" + + +def test__field__from_dict__object_state(): + # given + data = { + "path": "sys/state", + "type": "experimentState", + "experimentStateProperties": { + "attributeType": "experimentState", + "attributeName": "sys/state", + "value": "running", + }, + } + + # when + result = Field.from_dict(data) + + # then + assert result.path == "sys/state" + assert isinstance(result, ObjectStateField) + assert result.value == "Active" + + +def test__field__from_model__object_state(): + # given + model = Mock( + path="sys/state", + type="experimentState", + experimentStateProperties=Mock(attributeType="experimentState", attributeName="sys/state", value="running"), + ) + + # when + result = Field.from_model(model) + + # then + assert result.path == "sys/state" + assert isinstance(result, ObjectStateField) + assert result.value == "Active" + + +def test__field__from_dict__file_set(): + # given + data = { + "path": "some/fileSet", + "type": "fileSet", + "fileSetProperties": {"attributeType": "fileSet", "attributeName": "some/fileSet", "size": 3072}, + } + + # when + result = Field.from_dict(data) + + # then + assert result.path == "some/fileSet" + assert isinstance(result, FileSetField) + assert result.size == 3072 + + +def test__field__from_model__file_set(): + # given + model = Mock( + path="some/fileSet", + type="fileSet", + fileSetProperties=Mock(attributeType="fileSet", attributeName="some/fileSet", size=3072), + ) + + # when + result = Field.from_model(model) + + # then + assert result.path == "some/fileSet" + assert isinstance(result, FileSetField) + assert result.size == 3072 + + +def test__field__from_dict__notebook_ref(): + # given + data = { + "path": "some/notebookRef", + "type": "notebookRef", + "notebookRefProperties": { + "attributeType": "notebookRef", + "attributeName": "some/notebookRef", + "notebookName": "Data Processing.ipynb", + }, + } + + # when + result = Field.from_dict(data) + + # then + assert result.path == "some/notebookRef" + assert isinstance(result, NotebookRefField) + assert result.notebook_name == "Data Processing.ipynb" + + +def test__field__from_model__notebook_ref(): + # given + model = Mock( + path="some/notebookRef", + type="notebookRef", + notebookRefProperties=Mock( + attributeType="notebookRef", attributeName="some/notebookRef", notebookName="Data Processing.ipynb" + ), + ) + + # when + result = Field.from_model(model) + + # then + assert result.path == "some/notebookRef" + assert isinstance(result, NotebookRefField) + assert result.notebook_name == "Data Processing.ipynb" + + +def test__field__from_dict__git_ref(): + # given + data = { + "path": "some/gitRef", + "type": "gitRef", + "gitRefProperties": { + "attributeType": "gitRef", + "attributeName": "some/gitRef", + "commit": {"commitId": "b2d7f8a"}, + }, + } + + # when + result = Field.from_dict(data) + + # then + assert result.path == "some/gitRef" + assert isinstance(result, GitRefField) + assert result.commit.commit_id == "b2d7f8a" + + +def test__field__from_model__git_ref(): + # given + model = Mock( + path="some/gitRef", + type="gitRef", + gitRefProperties=Mock(attributeType="gitRef", attributeName="some/gitRef", commit=Mock(commitId="b2d7f8a")), + ) + + # when + result = Field.from_model(model) + + # then + assert result.path == "some/gitRef" + assert isinstance(result, GitRefField) + assert result.commit.commit_id == "b2d7f8a" + + +def test__field__from_dict__artifact(): + # given + data = { + "path": "some/artifact", + "type": "artifact", + "artifactProperties": { + "attributeType": "artifact", + "attributeName": "some/artifact", + "hash": "f192cddb2b98c0b4c72bba22b68d2245", + }, + } + + # when + result = Field.from_dict(data) + + # then + assert result.path == "some/artifact" + assert isinstance(result, ArtifactField) + assert result.hash == "f192cddb2b98c0b4c72bba22b68d2245" + + +def test__field__from_model__artifact(): + # given + model = Mock( + path="some/artifact", + type="artifact", + artifactProperties=Mock( + attributeType="artifact", attributeName="some/artifact", hash="f192cddb2b98c0b4c72bba22b68d2245" + ), + ) + + # when + result = Field.from_model(model) + + # then + assert result.path == "some/artifact" + assert isinstance(result, ArtifactField) + assert result.hash == "f192cddb2b98c0b4c72bba22b68d2245" + + +def test__field_definition__from_dict(): + # given + data = { + "name": "some/float", + "type": "float", + } + + # when + result = FieldDefinition.from_dict(data) + + # then + assert result.path == "some/float" + assert result.type == FieldType.FLOAT + + +def test__field_definition__from_model(): + # given + model = Mock( + type="float", + ) + model.name = "some/float" + + # when + result = FieldDefinition.from_model(model) + + # then + assert result.path == "some/float" + assert result.type == FieldType.FLOAT + + +def test__leaderboard_entry__from_dict(): + # given + data = { + "experimentId": "some-id", + "attributes": [ + { + "path": "some/float", + "type": "float", + "floatProperties": {"attributeType": "float", "attributeName": "some/float", "value": 18.5}, + }, + { + "path": "some/int", + "type": "int", + "intProperties": {"attributeType": "int", "attributeName": "some/int", "value": 18}, + }, + { + "path": "some/string", + "type": "string", + "stringProperties": {"attributeType": "string", "attributeName": "some/string", "value": "hello"}, + }, + ], + } + + # when + result = LeaderboardEntry.from_dict(data) + + # then + assert result.object_id == "some-id" + assert len(result.fields) == 3 + + float_field = result.fields[0] + assert isinstance(float_field, FloatField) + assert float_field.path == "some/float" + assert float_field.value == 18.5 + + int_field = result.fields[1] + assert isinstance(int_field, IntField) + assert int_field.path == "some/int" + + string_field = result.fields[2] + assert isinstance(string_field, StringField) + assert string_field.path == "some/string" + + +def test__leaderboard_entry__from_model(): + # given + model = Mock( + experimentId="some-id", + attributes=[ + Mock( + path="some/float", + type="float", + floatProperties=Mock(attributeType="float", attributeName="some/float", value=18.5), + ), + Mock( + path="some/int", type="int", intProperties=Mock(attributeType="int", attributeName="some/int", value=18) + ), + Mock( + path="some/string", + type="string", + stringProperties=Mock(attributeType="string", attributeName="some/string", value="hello"), + ), + ], + ) + + # when + result = LeaderboardEntry.from_model(model) + + # then + assert result.object_id == "some-id" + assert len(result.fields) == 3 + + float_field = result.fields[0] + assert isinstance(float_field, FloatField) + assert float_field.path == "some/float" + assert float_field.value == 18.5 + + int_field = result.fields[1] + assert isinstance(int_field, IntField) + assert int_field.path == "some/int" + + string_field = result.fields[2] + assert isinstance(string_field, StringField) + assert string_field.path == "some/string" + + +def test__leaderboard_entries_search_result__from_dict(): + # given + data = { + "matchingItemCount": 2, + "entries": [ + { + "experimentId": "some-id-1", + "attributes": [ + { + "path": "some/float", + "type": "float", + "floatProperties": {"attributeType": "float", "attributeName": "some/float", "value": 18.5}, + }, + ], + }, + { + "experimentId": "some-id-2", + "attributes": [ + { + "path": "some/int", + "type": "int", + "intProperties": {"attributeType": "int", "attributeName": "some/int", "value": 18}, + }, + ], + }, + ], + } + + # when + result = LeaderboardEntriesSearchResult.from_dict(data) + + # then + assert result.matching_item_count == 2 + assert len(result.entries) == 2 + + entry_1 = result.entries[0] + assert entry_1.object_id == "some-id-1" + assert len(entry_1.fields) == 1 + assert isinstance(entry_1.fields[0], FloatField) + + entry_2 = result.entries[1] + assert entry_2.object_id == "some-id-2" + assert len(entry_2.fields) == 1 + assert isinstance(entry_2.fields[0], IntField) + + +def test__leaderboard_entries_search_result__from_model(): + # given + model = Mock( + matchingItemCount=2, + entries=[ + Mock( + experimentId="some-id-1", + attributes=[ + Mock( + path="some/float", + type="float", + floatProperties=Mock(attributeType="float", attributeName="some/float", value=18.5), + ), + ], + ), + Mock( + experimentId="some-id-2", + attributes=[ + Mock( + path="some/int", + type="int", + intProperties=Mock(attributeType="int", attributeName="some/int", value=18), + ), + ], + ), + ], + ) + + # when + result = LeaderboardEntriesSearchResult.from_model(model) + + # then + assert result.matching_item_count == 2 + assert len(result.entries) == 2 + + entry_1 = result.entries[0] + assert entry_1.object_id == "some-id-1" + assert len(entry_1.fields) == 1 + assert isinstance(entry_1.fields[0], FloatField) + + entry_2 = result.entries[1] + assert entry_2.object_id == "some-id-2" + assert len(entry_2.fields) == 1 + assert isinstance(entry_2.fields[0], IntField) + + +@pytest.mark.parametrize("field_type", list(FieldType)) +def test__all_field_types__have_class(field_type): + # when + field_class = Field.by_type(field_type) + + # then + assert field_class is not None + assert field_class.type == field_type From 6910c2b5235bd46d6206e47eecdfcd571089662c Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Fri, 5 Apr 2024 15:56:06 +0200 Subject: [PATCH 16/22] CHANGELOG --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 04e3359fd..f25d780b5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ ### Changes - Stop sending `X-Neptune-LegacyClient` header ([#1715](https://github.com/neptune-ai/neptune-client/pull/1715)) - Use `tqdm.auto` ([#1717](https://github.com/neptune-ai/neptune-client/pull/1717)) +- Fields DTO conversion reworked ([#1722](https://github.com/neptune-ai/neptune-client/pull/1722)) ### Features - ? From de8ea1caefca0bb75f6b80a403abb8d3bd8c10a3 Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Fri, 5 Apr 2024 16:01:17 +0200 Subject: [PATCH 17/22] FileEntry tests moved --- tests/unit/neptune/new/api/test_models.py | 21 ++++++++++++++++ .../neptune/new/internal/test_file_entry.py | 24 ------------------- 2 files changed, 21 insertions(+), 24 deletions(-) delete mode 100644 tests/unit/neptune/new/internal/test_file_entry.py diff --git a/tests/unit/neptune/new/api/test_models.py b/tests/unit/neptune/new/api/test_models.py index ed6245f2b..8e09b1974 100644 --- a/tests/unit/neptune/new/api/test_models.py +++ b/tests/unit/neptune/new/api/test_models.py @@ -25,6 +25,7 @@ Field, FieldDefinition, FieldType, + FileEntry, FileField, FileSetField, FloatField, @@ -1470,3 +1471,23 @@ def test__all_field_types__have_class(field_type): # then assert field_class is not None assert field_class.type == field_type + + +def test__file_entry__from_model(): + # given + now = datetime.datetime.now() + + # and + model = Mock( + size=100, + mtime=now, + fileType="file", + ) + model.name = "mock_name" + + entry = FileEntry.from_dto(model) + + assert entry.name == "mock_name" + assert entry.size == 100 + assert entry.mtime == now + assert entry.file_type == "file" diff --git a/tests/unit/neptune/new/internal/test_file_entry.py b/tests/unit/neptune/new/internal/test_file_entry.py deleted file mode 100644 index 9e81b800c..000000000 --- a/tests/unit/neptune/new/internal/test_file_entry.py +++ /dev/null @@ -1,24 +0,0 @@ -import datetime -from dataclasses import dataclass - -from neptune.api.models import FileEntry - - -def test_file_entry_from_dto(): - now = datetime.datetime.now() - - @dataclass - class MockDto: - name: str - size: int - mtime: datetime.datetime - fileType: str - - dto = MockDto("mock_name", 100, now, "file") - - entry = FileEntry.from_dto(dto) - - assert entry.name == "mock_name" - assert entry.size == 100 - assert entry.mtime == now - assert entry.file_type == "file" From 63766321d8170fa2523acb2dba3143a916e1f414 Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Fri, 5 Apr 2024 17:04:43 +0200 Subject: [PATCH 18/22] Self review part 1 --- tests/unit/neptune/new/client/test_run_tables.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/tests/unit/neptune/new/client/test_run_tables.py b/tests/unit/neptune/new/client/test_run_tables.py index 146855068..d1fe07f5f 100644 --- a/tests/unit/neptune/new/client/test_run_tables.py +++ b/tests/unit/neptune/new/client/test_run_tables.py @@ -22,7 +22,7 @@ from neptune import init_project from neptune.api.models import ( - Field, + DateTimeField, LeaderboardEntry, ) from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock @@ -67,16 +67,9 @@ def test_fetch_runs_table_raises_correct_exception_for_incorrect_states(self): LeaderboardEntry( object_id="123", fields=[ - Field.from_dict( - { - "type": "datetime", - "path": "sys/creation_time", - "datetimeProperties": { - "attributeName": "sys/creation_time", - "attributeType": "datetime", - "value": "2024-02-05T20:37:40.915000Z", - }, - } + DateTimeField( + path="sys/creation_time", + value=datetime(2024, 2, 5, 20, 37, 40, 915000), ) ], ) From 42c96820798ab448aad3a6e76a266798a39c78a8 Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Fri, 5 Apr 2024 17:17:56 +0200 Subject: [PATCH 19/22] Self review part 2 --- .../internal/backends/hosted_artifact_operations.py | 2 +- tests/unit/neptune/new/client/test_model.py | 3 ++- tests/unit/neptune/new/client/test_model_version.py | 7 ++++--- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/neptune/internal/backends/hosted_artifact_operations.py b/src/neptune/internal/backends/hosted_artifact_operations.py index 672d102fb..ec722b5e2 100644 --- a/src/neptune/internal/backends/hosted_artifact_operations.py +++ b/src/neptune/internal/backends/hosted_artifact_operations.py @@ -261,7 +261,7 @@ def get_artifact_attribute( } try: result = swagger_client.api.getArtifactAttribute(**params).response().result - return ArtifactField(path=path_to_str(path), hash=result.hash) + return ArtifactField.from_model(result) except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) diff --git a/tests/unit/neptune/new/client/test_model.py b/tests/unit/neptune/new/client/test_model.py index 0041eb1a6..30acb6c40 100644 --- a/tests/unit/neptune/new/client/test_model.py +++ b/tests/unit/neptune/new/client/test_model.py @@ -39,6 +39,7 @@ ) from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock from neptune.internal.exceptions import NeptuneException +from neptune.internal.utils.paths import path_to_str from neptune.internal.warnings import ( NeptuneWarning, warned_once, @@ -81,7 +82,7 @@ def test_offline_mode(self): ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_int_attribute", - new=lambda _, _uuid, _type, _path: IntField(42), + new=lambda _, _uuid, _type, _path: IntField(path=path_to_str(_path), value=42), ) @patch("neptune.internal.operation_processors.read_only_operation_processor.warn_once") def test_read_only_mode(self, warn_once): diff --git a/tests/unit/neptune/new/client/test_model_version.py b/tests/unit/neptune/new/client/test_model_version.py index 0ce87d90e..14ca4404b 100644 --- a/tests/unit/neptune/new/client/test_model_version.py +++ b/tests/unit/neptune/new/client/test_model_version.py @@ -42,6 +42,7 @@ from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock from neptune.internal.container_type import ContainerType from neptune.internal.exceptions import NeptuneException +from neptune.internal.utils.paths import path_to_str from neptune.internal.warnings import ( NeptuneWarning, warned_once, @@ -93,11 +94,11 @@ def test_offline_mode(self): ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_int_attribute", - new=lambda _, _uuid, _type, _path: IntField(42), + new=lambda _, _uuid, _type, _path: IntField(path=path_to_str(_path), value=42), ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_string_attribute", - new=lambda _, _uuid, _type, _path: StringField("MDL"), + new=lambda _, _uuid, _type, _path: StringField(path=path_to_str(_path), value="MDL"), ) @patch("neptune.internal.operation_processors.read_only_operation_processor.warn_once") def test_read_only_mode(self, warn_once): @@ -121,7 +122,7 @@ def test_read_only_mode(self, warn_once): ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_string_attribute", - new=lambda _, _uuid, _type, _path: StringField("MDL"), + new=lambda _, _uuid, _type, _path: StringField(path=path_to_str(_path), value="MDL"), ) def test_resume(self): with init_model_version(flush_period=0.5, with_id="whatever") as exp: From 5f93c922f98eb2da46e944033d8e5624293740f3 Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Fri, 5 Apr 2024 17:45:35 +0200 Subject: [PATCH 20/22] Self review part 3 --- src/neptune/api/searching_entries.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/neptune/api/searching_entries.py b/src/neptune/api/searching_entries.py index d2eb222e0..e1f6cc572 100644 --- a/src/neptune/api/searching_entries.py +++ b/src/neptune/api/searching_entries.py @@ -206,10 +206,10 @@ def iter_over_pages( while True: if last_page: - searching_after_filed = find_attribute(entry=last_page[-1], path=sort_by) - if not searching_after_filed: + searching_after_field = find_attribute(entry=last_page[-1], path=sort_by) + if not searching_after_field: raise ValueError(f"Cannot find attribute {sort_by} in last page") - searching_after = field_to_value_visitor.visit(searching_after_filed) + searching_after = field_to_value_visitor.visit(searching_after_field) for offset in range(0, max_offset, step_size): local_limit = min(step_size, max_offset - offset) From 157f8e4179fbb345afb34b5827d4f8c7b15523c2 Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Mon, 8 Apr 2024 16:48:16 +0200 Subject: [PATCH 21/22] Code review --- src/neptune/api/models.py | 8 ++++---- src/neptune/internal/backends/hosted_neptune_backend.py | 9 --------- 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/src/neptune/api/models.py b/src/neptune/api/models.py index 8514cf281..8c0d1fd3c 100644 --- a/src/neptune/api/models.py +++ b/src/neptune/api/models.py @@ -292,7 +292,7 @@ def from_model(model: Any) -> FileSetField: class FloatSeriesField(Field, field_type=FieldType.FLOAT_SERIES): last: Optional[float] - def accept(self, visitor: "FieldVisitor[Ret]") -> Ret: + def accept(self, visitor: FieldVisitor[Ret]) -> Ret: return visitor.visit_float_series(self) @staticmethod @@ -309,7 +309,7 @@ def from_model(model: Any) -> FloatSeriesField: class StringSeriesField(Field, field_type=FieldType.STRING_SERIES): last: Optional[str] - def accept(self, visitor: "FieldVisitor[Ret]") -> Ret: + def accept(self, visitor: FieldVisitor[Ret]) -> Ret: return visitor.visit_string_series(self) @staticmethod @@ -326,7 +326,7 @@ def from_model(model: Any) -> StringSeriesField: class ImageSeriesField(Field, field_type=FieldType.IMAGE_SERIES): last_step: Optional[float] - def accept(self, visitor: "FieldVisitor[Ret]") -> Ret: + def accept(self, visitor: FieldVisitor[Ret]) -> Ret: return visitor.visit_image_series(self) @staticmethod @@ -343,7 +343,7 @@ def from_model(model: Any) -> ImageSeriesField: class StringSetField(Field, field_type=FieldType.STRING_SET): values: Set[str] - def accept(self, visitor: "FieldVisitor[Ret]") -> Ret: + def accept(self, visitor: FieldVisitor[Ret]) -> Ret: return visitor.visit_string_set(self) @staticmethod diff --git a/src/neptune/internal/backends/hosted_neptune_backend.py b/src/neptune/internal/backends/hosted_neptune_backend.py index 9f8974b87..4c416c798 100644 --- a/src/neptune/internal/backends/hosted_neptune_backend.py +++ b/src/neptune/internal/backends/hosted_neptune_backend.py @@ -165,15 +165,6 @@ FieldType.OBJECT_STATE.value, } -ATOMIC_ATTRIBUTE_TYPES = { - FieldType.INT.value, - FieldType.FLOAT.value, - FieldType.STRING.value, - FieldType.BOOL.value, - FieldType.DATETIME.value, - FieldType.OBJECT_STATE.value, -} - class HostedNeptuneBackend(NeptuneBackend): def __init__(self, credentials: Credentials, proxies: Optional[Dict[str, str]] = None): From e80c971951cf26a4d46ce0933f06d65ca2643d15 Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Mon, 8 Apr 2024 17:17:59 +0200 Subject: [PATCH 22/22] Code review 2 --- src/neptune/api/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/neptune/api/models.py b/src/neptune/api/models.py index 8c0d1fd3c..a1604ed9d 100644 --- a/src/neptune/api/models.py +++ b/src/neptune/api/models.py @@ -97,7 +97,7 @@ class FieldType(Enum): @dataclass class Field(abc.ABC): path: str - type: FieldType = dataclass_field(init=False) + type: ClassVar[FieldType] = dataclass_field(init=False) _registry: ClassVar[Dict[str, Type[Field]]] = {} def __init_subclass__(cls, *args: Any, field_type: FieldType, **kwargs: Any) -> None: