diff --git a/CHANGELOG.md b/CHANGELOG.md index 50a5469fe..ac9f937af 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ ### Changes - Stop sending `X-Neptune-LegacyClient` header ([#1715](https://github.com/neptune-ai/neptune-client/pull/1715)) - Use `tqdm.auto` ([#1717](https://github.com/neptune-ai/neptune-client/pull/1717)) +- Fields DTO conversion reworked ([#1722](https://github.com/neptune-ai/neptune-client/pull/1722)) ### Features - ? diff --git a/src/neptune/api/dtos.py b/src/neptune/api/dtos.py deleted file mode 100644 index 2895b43df..000000000 --- a/src/neptune/api/dtos.py +++ /dev/null @@ -1,33 +0,0 @@ -# -# Copyright (c) 2024, Neptune Labs Sp. z o.o. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -__all__ = ["FileEntry"] - -import datetime -from dataclasses import dataclass -from typing import Any - - -@dataclass -class FileEntry: - name: str - size: int - mtime: datetime.datetime - file_type: str - - @classmethod - def from_dto(cls, file_dto: Any) -> "FileEntry": - return cls(name=file_dto.name, size=file_dto.size, mtime=file_dto.mtime, file_type=file_dto.fileType) diff --git a/src/neptune/api/field_visitor.py b/src/neptune/api/field_visitor.py new file mode 100644 index 000000000..acc4c8d51 --- /dev/null +++ b/src/neptune/api/field_visitor.py @@ -0,0 +1,91 @@ +# +# Copyright (c) 2024, Neptune Labs Sp. z o.o. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +__all__ = ("FieldToValueVisitor",) + +from datetime import datetime +from typing import ( + Any, + Optional, + Set, +) + +from neptune.api.models import ( + ArtifactField, + BoolField, + DateTimeField, + FieldVisitor, + FileField, + FileSetField, + FloatField, + FloatSeriesField, + GitRefField, + ImageSeriesField, + IntField, + NotebookRefField, + ObjectStateField, + StringField, + StringSeriesField, + StringSetField, +) +from neptune.exceptions import MetadataInconsistency + + +class FieldToValueVisitor(FieldVisitor[Any]): + + def visit_float(self, field: FloatField) -> float: + return field.value + + def visit_int(self, field: IntField) -> int: + return field.value + + def visit_bool(self, field: BoolField) -> bool: + return field.value + + def visit_string(self, field: StringField) -> str: + return field.value + + def visit_datetime(self, field: DateTimeField) -> datetime: + return field.value + + def visit_file(self, field: FileField) -> None: + raise MetadataInconsistency("Cannot get value for file attribute. Use download() instead.") + + def visit_file_set(self, field: FileSetField) -> None: + raise MetadataInconsistency("Cannot get value for file set attribute. Use download() instead.") + + def visit_float_series(self, field: FloatSeriesField) -> Optional[float]: + return field.last + + def visit_string_series(self, field: StringSeriesField) -> Optional[str]: + return field.last + + def visit_image_series(self, field: ImageSeriesField) -> None: + raise MetadataInconsistency("Cannot get value for image series.") + + def visit_string_set(self, field: StringSetField) -> Set[str]: + return field.values + + def visit_git_ref(self, field: GitRefField) -> Optional[str]: + return field.commit.commit_id if field.commit is not None else None + + def visit_object_state(self, field: ObjectStateField) -> str: + return field.value + + def visit_notebook_ref(self, field: NotebookRefField) -> Optional[str]: + return field.notebook_name + + def visit_artifact(self, field: ArtifactField) -> str: + return field.hash diff --git a/src/neptune/api/models.py b/src/neptune/api/models.py new file mode 100644 index 000000000..a1604ed9d --- /dev/null +++ b/src/neptune/api/models.py @@ -0,0 +1,490 @@ +# +# Copyright (c) 2024, Neptune Labs Sp. z o.o. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +__all__ = ( + "FileEntry", + "Field", + "FieldType", + "GitCommit", + "LeaderboardEntry", + "LeaderboardEntriesSearchResult", + "FieldVisitor", + "FloatField", + "IntField", + "BoolField", + "StringField", + "DateTimeField", + "FileField", + "FileSetField", + "FloatSeriesField", + "StringSeriesField", + "ImageSeriesField", + "StringSetField", + "GitRefField", + "ObjectStateField", + "NotebookRefField", + "ArtifactField", + "FieldDefinition", +) + +import abc +from dataclasses import dataclass +from dataclasses import field as dataclass_field +from datetime import datetime +from enum import Enum +from typing import ( + Any, + ClassVar, + Dict, + Generic, + List, + Optional, + Set, + Type, + TypeVar, +) + +from neptune.internal.utils.iso_dates import parse_iso_date +from neptune.internal.utils.run_state import RunState + +Ret = TypeVar("Ret") + + +@dataclass +class FileEntry: + name: str + size: int + mtime: datetime + file_type: str + + @classmethod + def from_dto(cls, file_dto: Any) -> "FileEntry": + return cls(name=file_dto.name, size=file_dto.size, mtime=file_dto.mtime, file_type=file_dto.fileType) + + +class FieldType(Enum): + FLOAT = "float" + INT = "int" + BOOL = "bool" + STRING = "string" + DATETIME = "datetime" + FILE = "file" + FILE_SET = "fileSet" + FLOAT_SERIES = "floatSeries" + STRING_SERIES = "stringSeries" + IMAGE_SERIES = "imageSeries" + STRING_SET = "stringSet" + GIT_REF = "gitRef" + OBJECT_STATE = "experimentState" + NOTEBOOK_REF = "notebookRef" + ARTIFACT = "artifact" + + +@dataclass +class Field(abc.ABC): + path: str + type: ClassVar[FieldType] = dataclass_field(init=False) + _registry: ClassVar[Dict[str, Type[Field]]] = {} + + def __init_subclass__(cls, *args: Any, field_type: FieldType, **kwargs: Any) -> None: + super().__init_subclass__(*args, **kwargs) + cls.type = field_type + cls._registry[field_type.value] = cls + + @classmethod + def by_type(cls, field_type: FieldType) -> Type[Field]: + return cls._registry[field_type.value] + + @abc.abstractmethod + def accept(self, visitor: FieldVisitor[Ret]) -> Ret: ... + + @staticmethod + def from_dict(data: Dict[str, Any]) -> Field: + field_type = data["type"] + return Field._registry[field_type].from_dict(data[f"{field_type}Properties"]) + + @staticmethod + def from_model(model: Any) -> Field: + field_type = str(model.type) + return Field._registry[field_type].from_model(model.__getattribute__(f"{field_type}Properties")) + + +class FieldVisitor(Generic[Ret], abc.ABC): + + def visit(self, field: Field) -> Ret: + return field.accept(self) + + @abc.abstractmethod + def visit_float(self, field: FloatField) -> Ret: ... + + @abc.abstractmethod + def visit_int(self, field: IntField) -> Ret: ... + + @abc.abstractmethod + def visit_bool(self, field: BoolField) -> Ret: ... + + @abc.abstractmethod + def visit_string(self, field: StringField) -> Ret: ... + + @abc.abstractmethod + def visit_datetime(self, field: DateTimeField) -> Ret: ... + + @abc.abstractmethod + def visit_file(self, field: FileField) -> Ret: ... + + @abc.abstractmethod + def visit_file_set(self, field: FileSetField) -> Ret: ... + + @abc.abstractmethod + def visit_float_series(self, field: FloatSeriesField) -> Ret: ... + + @abc.abstractmethod + def visit_string_series(self, field: StringSeriesField) -> Ret: ... + + @abc.abstractmethod + def visit_image_series(self, field: ImageSeriesField) -> Ret: ... + + @abc.abstractmethod + def visit_string_set(self, field: StringSetField) -> Ret: ... + + @abc.abstractmethod + def visit_git_ref(self, field: GitRefField) -> Ret: ... + + @abc.abstractmethod + def visit_object_state(self, field: ObjectStateField) -> Ret: ... + + @abc.abstractmethod + def visit_notebook_ref(self, field: NotebookRefField) -> Ret: ... + + @abc.abstractmethod + def visit_artifact(self, field: ArtifactField) -> Ret: ... + + +@dataclass +class FloatField(Field, field_type=FieldType.FLOAT): + value: float + + def accept(self, visitor: FieldVisitor[Ret]) -> Ret: + return visitor.visit_float(self) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> FloatField: + return FloatField(path=data["attributeName"], value=float(data["value"])) + + @staticmethod + def from_model(model: Any) -> FloatField: + return FloatField(path=model.attributeName, value=model.value) + + +@dataclass +class IntField(Field, field_type=FieldType.INT): + value: int + + def accept(self, visitor: FieldVisitor[Ret]) -> Ret: + return visitor.visit_int(self) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> IntField: + return IntField(path=data["attributeName"], value=int(data["value"])) + + @staticmethod + def from_model(model: Any) -> IntField: + return IntField(path=model.attributeName, value=model.value) + + +@dataclass +class BoolField(Field, field_type=FieldType.BOOL): + value: bool + + def accept(self, visitor: FieldVisitor[Ret]) -> Ret: + return visitor.visit_bool(self) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> BoolField: + return BoolField(path=data["attributeName"], value=bool(data["value"])) + + @staticmethod + def from_model(model: Any) -> BoolField: + return BoolField(path=model.attributeName, value=model.value) + + +@dataclass +class StringField(Field, field_type=FieldType.STRING): + value: str + + def accept(self, visitor: FieldVisitor[Ret]) -> Ret: + return visitor.visit_string(self) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> StringField: + return StringField(path=data["attributeName"], value=str(data["value"])) + + @staticmethod + def from_model(model: Any) -> StringField: + return StringField(path=model.attributeName, value=model.value) + + +@dataclass +class DateTimeField(Field, field_type=FieldType.DATETIME): + value: datetime + + def accept(self, visitor: FieldVisitor[Ret]) -> Ret: + return visitor.visit_datetime(self) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> DateTimeField: + return DateTimeField(path=data["attributeName"], value=parse_iso_date(data["value"])) + + @staticmethod + def from_model(model: Any) -> DateTimeField: + return DateTimeField(path=model.attributeName, value=parse_iso_date(model.value)) + + +@dataclass +class FileField(Field, field_type=FieldType.FILE): + name: str + ext: str + size: int + + def accept(self, visitor: FieldVisitor[Ret]) -> Ret: + return visitor.visit_file(self) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> FileField: + return FileField(path=data["attributeName"], name=data["name"], ext=data["ext"], size=int(data["size"])) + + @staticmethod + def from_model(model: Any) -> FileField: + return FileField(path=model.attributeName, name=model.name, ext=model.ext, size=model.size) + + +@dataclass +class FileSetField(Field, field_type=FieldType.FILE_SET): + size: int + + def accept(self, visitor: FieldVisitor[Ret]) -> Ret: + return visitor.visit_file_set(self) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> FileSetField: + return FileSetField(path=data["attributeName"], size=int(data["size"])) + + @staticmethod + def from_model(model: Any) -> FileSetField: + return FileSetField(path=model.attributeName, size=model.size) + + +@dataclass +class FloatSeriesField(Field, field_type=FieldType.FLOAT_SERIES): + last: Optional[float] + + def accept(self, visitor: FieldVisitor[Ret]) -> Ret: + return visitor.visit_float_series(self) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> FloatSeriesField: + last = float(data["last"]) if "last" in data else None + return FloatSeriesField(path=data["attributeName"], last=last) + + @staticmethod + def from_model(model: Any) -> FloatSeriesField: + return FloatSeriesField(path=model.attributeName, last=model.last) + + +@dataclass +class StringSeriesField(Field, field_type=FieldType.STRING_SERIES): + last: Optional[str] + + def accept(self, visitor: FieldVisitor[Ret]) -> Ret: + return visitor.visit_string_series(self) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> StringSeriesField: + last = str(data["last"]) if "last" in data else None + return StringSeriesField(path=data["attributeName"], last=last) + + @staticmethod + def from_model(model: Any) -> StringSeriesField: + return StringSeriesField(path=model.attributeName, last=model.last) + + +@dataclass +class ImageSeriesField(Field, field_type=FieldType.IMAGE_SERIES): + last_step: Optional[float] + + def accept(self, visitor: FieldVisitor[Ret]) -> Ret: + return visitor.visit_image_series(self) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> ImageSeriesField: + last_step = float(data["lastStep"]) if "lastStep" in data else None + return ImageSeriesField(path=data["attributeName"], last_step=last_step) + + @staticmethod + def from_model(model: Any) -> ImageSeriesField: + return ImageSeriesField(path=model.attributeName, last_step=model.lastStep) + + +@dataclass +class StringSetField(Field, field_type=FieldType.STRING_SET): + values: Set[str] + + def accept(self, visitor: FieldVisitor[Ret]) -> Ret: + return visitor.visit_string_set(self) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> StringSetField: + return StringSetField(path=data["attributeName"], values=set(map(str, data["values"]))) + + @staticmethod + def from_model(model: Any) -> StringSetField: + return StringSetField(path=model.attributeName, values=set(model.values)) + + +@dataclass +class GitCommit: + commit_id: Optional[str] + + @staticmethod + def from_dict(data: Dict[str, Any]) -> GitCommit: + commit_id = str(data["commitId"]) if "commitId" in data else None + return GitCommit(commit_id=commit_id) + + @staticmethod + def from_model(model: Any) -> GitCommit: + return GitCommit(commit_id=model.commitId) + + +@dataclass +class GitRefField(Field, field_type=FieldType.GIT_REF): + commit: Optional[GitCommit] + + def accept(self, visitor: FieldVisitor[Ret]) -> Ret: + return visitor.visit_git_ref(self) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> GitRefField: + commit = GitCommit.from_dict(data["commit"]) if "commit" in data else None + return GitRefField(path=data["attributeName"], commit=commit) + + @staticmethod + def from_model(model: Any) -> GitRefField: + commit = GitCommit.from_model(model.commit) if model.commit is not None else None + return GitRefField(path=model.attributeName, commit=commit) + + +@dataclass +class ObjectStateField(Field, field_type=FieldType.OBJECT_STATE): + value: str + + def accept(self, visitor: FieldVisitor[Ret]) -> Ret: + return visitor.visit_object_state(self) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> ObjectStateField: + value = RunState.from_api(str(data["value"])).value + return ObjectStateField(path=data["attributeName"], value=value) + + @staticmethod + def from_model(model: Any) -> ObjectStateField: + value = RunState.from_api(str(model.value)).value + return ObjectStateField(path=model.attributeName, value=value) + + +@dataclass +class NotebookRefField(Field, field_type=FieldType.NOTEBOOK_REF): + notebook_name: Optional[str] + + def accept(self, visitor: FieldVisitor[Ret]) -> Ret: + return visitor.visit_notebook_ref(self) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> NotebookRefField: + notebook_name = str(data["notebookName"]) if "notebookName" in data else None + return NotebookRefField(path=data["attributeName"], notebook_name=notebook_name) + + @staticmethod + def from_model(model: Any) -> NotebookRefField: + return NotebookRefField(path=model.attributeName, notebook_name=model.notebookName) + + +@dataclass +class ArtifactField(Field, field_type=FieldType.ARTIFACT): + hash: str + + def accept(self, visitor: FieldVisitor[Ret]) -> Ret: + return visitor.visit_artifact(self) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> ArtifactField: + return ArtifactField(path=data["attributeName"], hash=str(data["hash"])) + + @staticmethod + def from_model(model: Any) -> ArtifactField: + return ArtifactField(path=model.attributeName, hash=model.hash) + + +@dataclass +class LeaderboardEntry: + object_id: str + fields: List[Field] + + @staticmethod + def from_dict(data: Dict[str, Any]) -> LeaderboardEntry: + return LeaderboardEntry( + object_id=data["experimentId"], fields=[Field.from_dict(field) for field in data["attributes"]] + ) + + @staticmethod + def from_model(model: Any) -> LeaderboardEntry: + return LeaderboardEntry( + object_id=model.experimentId, fields=[Field.from_model(field) for field in model.attributes] + ) + + +@dataclass +class LeaderboardEntriesSearchResult: + entries: List[LeaderboardEntry] + matching_item_count: int + + @staticmethod + def from_dict(result: Dict[str, Any]) -> LeaderboardEntriesSearchResult: + return LeaderboardEntriesSearchResult( + entries=[LeaderboardEntry.from_dict(entry) for entry in result.get("entries", [])], + matching_item_count=result["matchingItemCount"], + ) + + @staticmethod + def from_model(result: Any) -> LeaderboardEntriesSearchResult: + return LeaderboardEntriesSearchResult( + entries=[LeaderboardEntry.from_model(entry) for entry in result.entries], + matching_item_count=result.matchingItemCount, + ) + + +@dataclass +class FieldDefinition: + path: str + type: FieldType + + @staticmethod + def from_dict(data: Dict[str, Any]) -> FieldDefinition: + return FieldDefinition(path=data["name"], type=FieldType(data["type"])) + + @staticmethod + def from_model(model: Any) -> FieldDefinition: + return FieldDefinition(path=model.name, type=FieldType(model.type)) diff --git a/src/neptune/api/searching_entries.py b/src/neptune/api/searching_entries.py index 7c9f0007d..e1f6cc572 100644 --- a/src/neptune/api/searching_entries.py +++ b/src/neptune/api/searching_entries.py @@ -21,7 +21,6 @@ Dict, Generator, Iterable, - List, Optional, ) @@ -33,12 +32,14 @@ TypeAlias, ) -from neptune.exceptions import NeptuneInvalidQueryException -from neptune.internal.backends.api_model import ( - AttributeType, - AttributeWithProperties, +from neptune.api.field_visitor import FieldToValueVisitor +from neptune.api.models import ( + Field, + FieldType, + LeaderboardEntriesSearchResult, LeaderboardEntry, ) +from neptune.exceptions import NeptuneInvalidQueryException from neptune.internal.backends.hosted_client import DEFAULT_REQUEST_KWARGS from neptune.internal.backends.nql import ( NQLAggregator, @@ -58,7 +59,7 @@ from neptune.internal.id_formats import UniqueId -SUPPORTED_ATTRIBUTE_TYPES = {item.value for item in AttributeType} +SUPPORTED_ATTRIBUTE_TYPES = {item.value for item in FieldType} SORT_BY_COLUMN_TYPE: TypeAlias = Literal["string", "datetime", "integer", "boolean", "float"] @@ -98,7 +99,7 @@ def get_single_page( searching_after: Optional[str], ) -> Any: normalized_query = query or NQLEmptyQuery() - sort_by_column_type = sort_by_column_type if sort_by_column_type else AttributeType.STRING.value + sort_by_column_type = sort_by_column_type if sort_by_column_type else FieldType.STRING.value if sort_by and searching_after: sort_by_as_nql = NQLQueryAttribute( name=sort_by, @@ -119,7 +120,7 @@ def get_single_page( "aggregationMode": "none", "sortBy": { "name": sort_by, - "type": sort_by_column_type if sort_by_column_type else AttributeType.STRING.value, + "type": sort_by_column_type if sort_by_column_type else FieldType.STRING.value, }, } } @@ -157,23 +158,8 @@ def get_single_page( raise e -def to_leaderboard_entry(entry: Dict[str, Any]) -> LeaderboardEntry: - return LeaderboardEntry( - id=entry["experimentId"], - attributes=[ - AttributeWithProperties( - path=attr["name"], - type=AttributeType(attr["type"]), - properties=attr.__getitem__(f"{attr['type']}Properties"), - ) - for attr in entry["attributes"] - if attr["type"] in SUPPORTED_ATTRIBUTE_TYPES - ], - ) - - -def find_attribute(*, entry: LeaderboardEntry, path: str) -> Optional[AttributeWithProperties]: - return next((attr for attr in entry.attributes if attr.path == path), None) +def find_attribute(*, entry: LeaderboardEntry, path: str) -> Optional[Field]: + return next((attr for attr in entry.fields if attr.path == path), None) def iter_over_pages( @@ -190,7 +176,7 @@ def iter_over_pages( searching_after = None last_page = None - total = get_single_page( + data = get_single_page( limit=0, offset=0, sort_by=sort_by, @@ -198,7 +184,8 @@ def iter_over_pages( sort_by_column_type=sort_by_column_type, searching_after=None, **kwargs, - ).get("matchingItemCount", 0) + ) + total = LeaderboardEntriesSearchResult.from_dict(data).matching_item_count limit = limit if limit is not None else NoLimit() @@ -208,6 +195,8 @@ def iter_over_pages( extracted_records = 0 + field_to_value_visitor = FieldToValueVisitor() + with construct_progress_bar(progress_bar, "Fetching table...") as bar: # beginning of the first page bar.update( @@ -217,18 +206,17 @@ def iter_over_pages( while True: if last_page: - page_attribute = find_attribute(entry=last_page[-1], path=sort_by) - - if not page_attribute: + searching_after_field = find_attribute(entry=last_page[-1], path=sort_by) + if not searching_after_field: raise ValueError(f"Cannot find attribute {sort_by} in last page") - - searching_after = page_attribute.properties["value"] + searching_after = field_to_value_visitor.visit(searching_after_field) for offset in range(0, max_offset, step_size): local_limit = min(step_size, max_offset - offset) if extracted_records + local_limit > limit: local_limit = limit - extracted_records - result = get_single_page( + + data = get_single_page( limit=local_limit, offset=offset, sort_by=sort_by, @@ -237,14 +225,15 @@ def iter_over_pages( ascending=ascending, **kwargs, ) + result = LeaderboardEntriesSearchResult.from_dict(data) # fetch the item count everytime a new page is started (except for the very fist page) if offset == 0 and last_page is not None: - total += result.get("matchingItemCount", 0) + total += result.matching_item_count total = min(total, limit) - page = _entries_from_page(result) + page = result.entries extracted_records += len(page) bar.update(by=len(page), total=total) @@ -257,7 +246,3 @@ def iter_over_pages( return last_page = page - - -def _entries_from_page(single_page: Dict[str, Any]) -> List[LeaderboardEntry]: - return list(map(to_leaderboard_entry, single_page.get("entries", []))) diff --git a/src/neptune/attributes/file_set.py b/src/neptune/attributes/file_set.py index 922522a23..5f32a4705 100644 --- a/src/neptune/attributes/file_set.py +++ b/src/neptune/attributes/file_set.py @@ -23,7 +23,7 @@ Union, ) -from neptune.api.dtos import FileEntry +from neptune.api.models import FileEntry from neptune.attributes.attribute import Attribute from neptune.internal.operation import ( DeleteFiles, diff --git a/src/neptune/attributes/utils.py b/src/neptune/attributes/utils.py index ea50f1f1f..18d83410b 100644 --- a/src/neptune/attributes/utils.py +++ b/src/neptune/attributes/utils.py @@ -20,6 +20,7 @@ List, ) +from neptune.api.models import FieldType from neptune.attributes import ( Artifact, Boolean, @@ -37,7 +38,6 @@ StringSeries, StringSet, ) -from neptune.internal.backends.api_model import AttributeType from neptune.internal.exceptions import InternalClientError if TYPE_CHECKING: @@ -45,26 +45,26 @@ from neptune.objects import NeptuneObject _attribute_type_to_attr_class_map = { - AttributeType.FLOAT: Float, - AttributeType.INT: Integer, - AttributeType.BOOL: Boolean, - AttributeType.STRING: String, - AttributeType.DATETIME: Datetime, - AttributeType.FILE: File, - AttributeType.FILE_SET: FileSet, - AttributeType.FLOAT_SERIES: FloatSeries, - AttributeType.STRING_SERIES: StringSeries, - AttributeType.IMAGE_SERIES: FileSeries, - AttributeType.STRING_SET: StringSet, - AttributeType.GIT_REF: GitRef, - AttributeType.RUN_STATE: RunState, - AttributeType.NOTEBOOK_REF: NotebookRef, - AttributeType.ARTIFACT: Artifact, + FieldType.FLOAT: Float, + FieldType.INT: Integer, + FieldType.BOOL: Boolean, + FieldType.STRING: String, + FieldType.DATETIME: Datetime, + FieldType.FILE: File, + FieldType.FILE_SET: FileSet, + FieldType.FLOAT_SERIES: FloatSeries, + FieldType.STRING_SERIES: StringSeries, + FieldType.IMAGE_SERIES: FileSeries, + FieldType.STRING_SET: StringSet, + FieldType.GIT_REF: GitRef, + FieldType.OBJECT_STATE: RunState, + FieldType.NOTEBOOK_REF: NotebookRef, + FieldType.ARTIFACT: Artifact, } def create_attribute_from_type( - attribute_type: AttributeType, + attribute_type: FieldType, container: "NeptuneObject", path: List[str], ) -> "Attribute": diff --git a/src/neptune/handler.py b/src/neptune/handler.py index 4607eba6c..6e5f8d28c 100644 --- a/src/neptune/handler.py +++ b/src/neptune/handler.py @@ -28,7 +28,7 @@ Union, ) -from neptune.api.dtos import FileEntry +from neptune.api.models import FileEntry from neptune.attributes import File from neptune.attributes.atoms.artifact import Artifact from neptune.attributes.constants import SYSTEM_STAGE_ATTRIBUTE_PATH diff --git a/src/neptune/integrations/pandas/__init__.py b/src/neptune/integrations/pandas/__init__.py index 268bc9ee4..52161c7f5 100644 --- a/src/neptune/integrations/pandas/__init__.py +++ b/src/neptune/integrations/pandas/__init__.py @@ -20,81 +20,112 @@ from datetime import datetime from typing import ( TYPE_CHECKING, - Any, Dict, + Optional, Tuple, Union, ) import pandas as pd -from neptune.internal.backends.api_model import ( - AttributeType, - AttributeWithProperties, +from neptune.api.models import ( + ArtifactField, + BoolField, + DateTimeField, + FieldVisitor, + FileField, + FileSetField, + FloatField, + FloatSeriesField, + GitRefField, + ImageSeriesField, + IntField, LeaderboardEntry, + NotebookRefField, + ObjectStateField, + StringField, + StringSeriesField, + StringSetField, ) -from neptune.internal.utils.logger import get_logger -from neptune.internal.utils.run_state import RunState if TYPE_CHECKING: from neptune.table import Table -logger = get_logger() +PANDAS_AVAILABLE_TYPES = Union[str, float, int, bool, datetime, None] -def to_pandas(table: Table) -> pd.DataFrame: - def make_attribute_value(attribute: AttributeWithProperties) -> Any: - _type = attribute.type - _properties = attribute.properties - if _type == AttributeType.RUN_STATE: - return RunState.from_api(_properties.get("value")).value - if _type in ( - AttributeType.FLOAT, - AttributeType.INT, - AttributeType.BOOL, - AttributeType.STRING, - AttributeType.DATETIME, - ): - return _properties.get("value") - if _type == AttributeType.FLOAT_SERIES: - return _properties.get("last") - if _type == AttributeType.STRING_SERIES: - return _properties.get("last") - if _type == AttributeType.IMAGE_SERIES: - return None - if _type == AttributeType.FILE or _type == AttributeType.FILE_SET: - return None - if _type == AttributeType.STRING_SET: - return ",".join(_properties.get("values")) - if _type == AttributeType.GIT_REF: - return _properties.get("commit", {}).get("commitId") - if _type == AttributeType.NOTEBOOK_REF: - return _properties.get("notebookName") - if _type == AttributeType.ARTIFACT: - return _properties.get("hash") - logger.error( - "Attribute type %s not supported in this version, yielding None. Recommended client upgrade.", - _type, - ) +class FieldToPandasValueVisitor(FieldVisitor[PANDAS_AVAILABLE_TYPES]): + + def visit_float(self, field: FloatField) -> float: + return field.value + + def visit_int(self, field: IntField) -> int: + return field.value + + def visit_bool(self, field: BoolField) -> bool: + return field.value + + def visit_string(self, field: StringField) -> str: + return field.value + + def visit_datetime(self, field: DateTimeField) -> datetime: + return field.value + + def visit_file(self, field: FileField) -> None: + return None + + def visit_string_set(self, field: StringSetField) -> Optional[str]: + return ",".join(field.values) + + def visit_float_series(self, field: FloatSeriesField) -> Optional[float]: + return field.last + + def visit_string_series(self, field: StringSeriesField) -> Optional[str]: + return field.last + + def visit_image_series(self, field: ImageSeriesField) -> None: + return None + + def visit_file_set(self, field: FileSetField) -> None: return None - def make_row(entry: LeaderboardEntry) -> Dict[str, Any]: - row: Dict[str, Union[str, float, datetime]] = dict() - for attr in entry.attributes: - value = make_attribute_value(attr) - if value is not None: - row[attr.path] = value - return row - - def sort_key(attr: str) -> Tuple[int, str]: - domain = attr.split("/")[0] - if domain == "sys": - return 0, attr - if domain == "monitoring": - return 2, attr - return 1, attr - - rows = dict((n, make_row(entry)) for (n, entry) in enumerate(table._entries)) + def visit_git_ref(self, field: GitRefField) -> Optional[str]: + return field.commit.commit_id if field.commit is not None else None + + def visit_object_state(self, field: ObjectStateField) -> str: + return field.value + + def visit_notebook_ref(self, field: NotebookRefField) -> Optional[str]: + return field.notebook_name + + def visit_artifact(self, field: ArtifactField) -> str: + return field.hash + + +def make_row(entry: LeaderboardEntry, to_value_visitor: FieldVisitor) -> Dict[str, PANDAS_AVAILABLE_TYPES]: + row: Dict[str, PANDAS_AVAILABLE_TYPES] = dict() + + for field in entry.fields: + value = to_value_visitor.visit(field) + if value is not None: + row[field.path] = value + + return row + + +def sort_key(field: str) -> Tuple[int, str]: + domain = field.split("/")[0] + if domain == "sys": + return 0, field + if domain == "monitoring": + return 2, field + return 1, field + + +def to_pandas(table: Table) -> pd.DataFrame: + + to_value_visitor = FieldToPandasValueVisitor() + rows = dict((n, make_row(entry, to_value_visitor)) for (n, entry) in enumerate(table._entries)) df = pd.DataFrame.from_dict(data=rows, orient="index") df = df.reindex(sorted(df.columns, key=sort_key), axis="columns") diff --git a/src/neptune/internal/backends/api_model.py b/src/neptune/internal/backends/api_model.py index 1116de3a5..660d1c7e5 100644 --- a/src/neptune/internal/backends/api_model.py +++ b/src/neptune/internal/backends/api_model.py @@ -20,38 +20,20 @@ "OptionalFeatures", "VersionInfo", "ClientConfig", - "AttributeType", - "Attribute", - "AttributeWithProperties", - "LeaderboardEntry", "StringPointValue", "ImageSeriesValues", "StringSeriesValues", "FloatPointValue", "FloatSeriesValues", - "FloatAttribute", - "IntAttribute", - "BoolAttribute", - "FileAttribute", - "StringAttribute", - "DatetimeAttribute", - "ArtifactAttribute", "ArtifactModel", - "FloatSeriesAttribute", - "StringSeriesAttribute", - "StringSetAttribute", "MultipartConfig", ] from dataclasses import dataclass -from datetime import datetime -from enum import Enum from typing import ( - Any, FrozenSet, List, Optional, - Set, ) from packaging import version @@ -199,43 +181,6 @@ def from_api_response(config) -> "ClientConfig": ) -class AttributeType(Enum): - FLOAT = "float" - INT = "int" - BOOL = "bool" - STRING = "string" - DATETIME = "datetime" - FILE = "file" - FILE_SET = "fileSet" - FLOAT_SERIES = "floatSeries" - STRING_SERIES = "stringSeries" - IMAGE_SERIES = "imageSeries" - STRING_SET = "stringSet" - GIT_REF = "gitRef" - RUN_STATE = "experimentState" - NOTEBOOK_REF = "notebookRef" - ARTIFACT = "artifact" - - -@dataclass -class Attribute: - path: str - type: AttributeType - - -@dataclass -class AttributeWithProperties: - path: str - type: AttributeType - properties: Any - - -@dataclass -class LeaderboardEntry: - id: str - attributes: List[AttributeWithProperties] - - @dataclass class StringPointValue: timestampMillis: int @@ -267,60 +212,8 @@ class FloatSeriesValues: values: List[FloatPointValue] -@dataclass -class FloatAttribute: - value: float - - -@dataclass -class IntAttribute: - value: int - - -@dataclass -class BoolAttribute: - value: bool - - -@dataclass -class FileAttribute: - name: str - ext: str - size: int - - -@dataclass -class StringAttribute: - value: str - - -@dataclass -class DatetimeAttribute: - value: datetime - - -@dataclass -class ArtifactAttribute: - hash: str - - @dataclass class ArtifactModel: received_metadata: bool hash: str size: int - - -@dataclass -class FloatSeriesAttribute: - last: Optional[float] - - -@dataclass -class StringSeriesAttribute: - last: Optional[str] - - -@dataclass -class StringSetAttribute: - values: Set[str] diff --git a/src/neptune/internal/backends/hosted_artifact_operations.py b/src/neptune/internal/backends/hosted_artifact_operations.py index 497e85fbc..ec722b5e2 100644 --- a/src/neptune/internal/backends/hosted_artifact_operations.py +++ b/src/neptune/internal/backends/hosted_artifact_operations.py @@ -30,6 +30,7 @@ from bravado.exception import HTTPNotFound +from neptune.api.models import ArtifactField from neptune.exceptions import ( ArtifactNotFoundException, ArtifactUploadingError, @@ -42,10 +43,7 @@ ArtifactDriversMap, ArtifactFileData, ) -from neptune.internal.backends.api_model import ( - ArtifactAttribute, - ArtifactModel, -) +from neptune.internal.backends.api_model import ArtifactModel from neptune.internal.backends.swagger_client_wrapper import SwaggerClientWrapper from neptune.internal.backends.utils import with_api_exceptions_handler from neptune.internal.operation import ( @@ -254,7 +252,7 @@ def get_artifact_attribute( parent_identifier: str, path: List[str], default_request_params: Dict, -) -> ArtifactAttribute: +) -> ArtifactField: requests_params = add_artifact_version_to_request_params(default_request_params) params = { "experimentId": parent_identifier, @@ -263,7 +261,7 @@ def get_artifact_attribute( } try: result = swagger_client.api.getArtifactAttribute(**params).response().result - return ArtifactAttribute(hash=result.hash) + return ArtifactField.from_model(result) except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) diff --git a/src/neptune/internal/backends/hosted_neptune_backend.py b/src/neptune/internal/backends/hosted_neptune_backend.py index 1a51b930f..4c416c798 100644 --- a/src/neptune/internal/backends/hosted_neptune_backend.py +++ b/src/neptune/internal/backends/hosted_neptune_backend.py @@ -38,7 +38,22 @@ HTTPUnprocessableEntity, ) -from neptune.api.dtos import FileEntry +from neptune.api.models import ( + ArtifactField, + BoolField, + DateTimeField, + FieldDefinition, + FieldType, + FileEntry, + FileField, + FloatField, + FloatSeriesField, + IntField, + LeaderboardEntry, + StringField, + StringSeriesField, + StringSetField, +) from neptune.api.searching_entries import iter_over_pages from neptune.core.components.operation_storage import OperationStorage from neptune.envs import NEPTUNE_FETCH_TABLE_STEP_SIZE @@ -58,26 +73,13 @@ from neptune.internal.artifacts.types import ArtifactFileData from neptune.internal.backends.api_model import ( ApiExperiment, - ArtifactAttribute, - Attribute, - AttributeType, - BoolAttribute, - DatetimeAttribute, - FileAttribute, - FloatAttribute, FloatPointValue, - FloatSeriesAttribute, FloatSeriesValues, ImageSeriesValues, - IntAttribute, - LeaderboardEntry, OptionalFeatures, Project, - StringAttribute, StringPointValue, - StringSeriesAttribute, StringSeriesValues, - StringSetAttribute, Workspace, ) from neptune.internal.backends.hosted_artifact_operations import ( @@ -155,21 +157,12 @@ _logger = get_logger() ATOMIC_ATTRIBUTE_TYPES = { - AttributeType.INT.value, - AttributeType.FLOAT.value, - AttributeType.STRING.value, - AttributeType.BOOL.value, - AttributeType.DATETIME.value, - AttributeType.RUN_STATE.value, -} - -ATOMIC_ATTRIBUTE_TYPES = { - AttributeType.INT.value, - AttributeType.FLOAT.value, - AttributeType.STRING.value, - AttributeType.BOOL.value, - AttributeType.DATETIME.value, - AttributeType.RUN_STATE.value, + FieldType.INT.value, + FieldType.FLOAT.value, + FieldType.STRING.value, + FieldType.BOOL.value, + FieldType.DATETIME.value, + FieldType.OBJECT_STATE.value, } @@ -682,10 +675,7 @@ def _execute_operations( raise NeptuneLimitExceedException(reason=e.response.json().get("title", "Unknown reason")) from e @with_api_exceptions_handler - def get_attributes(self, container_id: str, container_type: ContainerType) -> List[Attribute]: - def to_attribute(attr) -> Attribute: - return Attribute(attr.name, AttributeType(attr.type)) - + def get_attributes(self, container_id: str, container_type: ContainerType) -> List[FieldDefinition]: params = { "experimentId": container_id, **DEFAULT_REQUEST_KWARGS, @@ -693,7 +683,7 @@ def to_attribute(attr) -> Attribute: try: experiment = self.leaderboard_client.api.getExperimentAttributes(**params).response().result - attribute_type_names = [at.value for at in AttributeType] + attribute_type_names = [at.value for at in FieldType] accepted_attributes = [attr for attr in experiment.attributes if attr.type in attribute_type_names] # Notify about ignored attrs @@ -706,7 +696,9 @@ def to_attribute(attr) -> Attribute: ignored_attributes, ) - return [to_attribute(attr) for attr in accepted_attributes if attr.type in attribute_type_names] + return [ + FieldDefinition.from_model(field) for field in accepted_attributes if field.type in attribute_type_names + ] except HTTPNotFound as e: raise ContainerUUIDNotFound( container_id=container_id, @@ -782,7 +774,7 @@ def download_file_set( raise @with_api_exceptions_handler - def get_float_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FloatAttribute: + def get_float_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FloatField: params = { "experimentId": container_id, "attribute": path_to_str(path), @@ -790,12 +782,12 @@ def get_float_attribute(self, container_id: str, container_type: ContainerType, } try: result = self.leaderboard_client.api.getFloatAttribute(**params).response().result - return FloatAttribute(result.value) + return FloatField.from_model(result) except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) @with_api_exceptions_handler - def get_int_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> IntAttribute: + def get_int_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> IntField: params = { "experimentId": container_id, "attribute": path_to_str(path), @@ -803,12 +795,12 @@ def get_int_attribute(self, container_id: str, container_type: ContainerType, pa } try: result = self.leaderboard_client.api.getIntAttribute(**params).response().result - return IntAttribute(result.value) + return IntField.from_model(result) except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) @with_api_exceptions_handler - def get_bool_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> BoolAttribute: + def get_bool_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> BoolField: params = { "experimentId": container_id, "attribute": path_to_str(path), @@ -816,12 +808,12 @@ def get_bool_attribute(self, container_id: str, container_type: ContainerType, p } try: result = self.leaderboard_client.api.getBoolAttribute(**params).response().result - return BoolAttribute(result.value) + return BoolField.from_model(result) except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) @with_api_exceptions_handler - def get_file_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FileAttribute: + def get_file_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FileField: params = { "experimentId": container_id, "attribute": path_to_str(path), @@ -829,14 +821,12 @@ def get_file_attribute(self, container_id: str, container_type: ContainerType, p } try: result = self.leaderboard_client.api.getFileAttribute(**params).response().result - return FileAttribute(name=result.name, ext=result.ext, size=result.size) + return FileField.from_model(result) except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) @with_api_exceptions_handler - def get_string_attribute( - self, container_id: str, container_type: ContainerType, path: List[str] - ) -> StringAttribute: + def get_string_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> StringField: params = { "experimentId": container_id, "attribute": path_to_str(path), @@ -844,14 +834,14 @@ def get_string_attribute( } try: result = self.leaderboard_client.api.getStringAttribute(**params).response().result - return StringAttribute(result.value) + return StringField.from_model(result) except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) @with_api_exceptions_handler def get_datetime_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> DatetimeAttribute: + ) -> DateTimeField: params = { "experimentId": container_id, "attribute": path_to_str(path), @@ -859,13 +849,13 @@ def get_datetime_attribute( } try: result = self.leaderboard_client.api.getDatetimeAttribute(**params).response().result - return DatetimeAttribute(result.value) + return DateTimeField.from_model(result) except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) def get_artifact_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> ArtifactAttribute: + ) -> ArtifactField: return get_artifact_attribute( swagger_client=self.leaderboard_client, parent_identifier=container_id, @@ -899,7 +889,7 @@ def list_fileset_files(self, attribute: List[str], container_id: str, path: str) @with_api_exceptions_handler def get_float_series_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> FloatSeriesAttribute: + ) -> FloatSeriesField: params = { "experimentId": container_id, "attribute": path_to_str(path), @@ -907,14 +897,14 @@ def get_float_series_attribute( } try: result = self.leaderboard_client.api.getFloatSeriesAttribute(**params).response().result - return FloatSeriesAttribute(result.last) + return FloatSeriesField.from_model(result) except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) @with_api_exceptions_handler def get_string_series_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> StringSeriesAttribute: + ) -> StringSeriesField: params = { "experimentId": container_id, "attribute": path_to_str(path), @@ -922,14 +912,14 @@ def get_string_series_attribute( } try: result = self.leaderboard_client.api.getStringSeriesAttribute(**params).response().result - return StringSeriesAttribute(result.last) + return StringSeriesField.from_model(result) except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) @with_api_exceptions_handler def get_string_set_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> StringSetAttribute: + ) -> StringSetField: params = { "experimentId": container_id, "attribute": path_to_str(path), @@ -937,7 +927,7 @@ def get_string_set_attribute( } try: result = self.leaderboard_client.api.getStringSetAttribute(**params).response().result - return StringSetAttribute(set(result.values)) + return StringSetField.from_model(result) except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) @@ -1016,7 +1006,7 @@ def get_float_series_values( @with_api_exceptions_handler def fetch_atom_attribute_values( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> List[Tuple[str, AttributeType, Any]]: + ) -> List[Tuple[str, FieldType, Any]]: params = { "experimentId": container_id, } @@ -1081,9 +1071,9 @@ def search_leaderboard_entries( attributes_filter = {"attributeFilters": [{"path": column} for column in columns]} if columns else {} if sort_by == "sys/creation_time": - sort_by_column_type = AttributeType.DATETIME.value - if sort_by == "sys/id": - sort_by_column_type = AttributeType.STRING.value + sort_by_column_type = FieldType.DATETIME.value + elif sort_by == "sys/id": + sort_by_column_type = FieldType.STRING.value else: sort_by_column_type_candidates = self._get_column_types(project_id, sort_by, types_filter) sort_by_column_type = _get_column_type_from_entries(sort_by_column_type_candidates, sort_by) @@ -1147,11 +1137,11 @@ def _get_column_type_from_entries(entries: List[Any], column: str) -> str: ) types.add(entry.type) - if types == {AttributeType.INT.value, AttributeType.FLOAT.value}: - return AttributeType.FLOAT.value + if types == {FieldType.INT.value, FieldType.FLOAT.value}: + return FieldType.FLOAT.value warn_once( f"Column {column} contains more than one simple data type. Sorting result might be inaccurate.", exception=NeptuneWarning, ) - return AttributeType.STRING.value + return FieldType.STRING.value diff --git a/src/neptune/internal/backends/neptune_backend.py b/src/neptune/internal/backends/neptune_backend.py index b282042bb..f6d21dbeb 100644 --- a/src/neptune/internal/backends/neptune_backend.py +++ b/src/neptune/internal/backends/neptune_backend.py @@ -25,28 +25,30 @@ Union, ) -from neptune.api.dtos import FileEntry +from neptune.api.models import ( + ArtifactField, + BoolField, + DateTimeField, + FieldDefinition, + FieldType, + FileEntry, + FileField, + FloatField, + FloatSeriesField, + IntField, + LeaderboardEntry, + StringField, + StringSeriesField, + StringSetField, +) from neptune.core.components.operation_storage import OperationStorage from neptune.internal.artifacts.types import ArtifactFileData from neptune.internal.backends.api_model import ( ApiExperiment, - ArtifactAttribute, - Attribute, - AttributeType, - BoolAttribute, - DatetimeAttribute, - FileAttribute, - FloatAttribute, - FloatSeriesAttribute, FloatSeriesValues, ImageSeriesValues, - IntAttribute, - LeaderboardEntry, Project, - StringAttribute, - StringSeriesAttribute, StringSeriesValues, - StringSetAttribute, Workspace, ) from neptune.internal.backends.nql import NQLQuery @@ -146,7 +148,7 @@ def execute_operations( pass @abc.abstractmethod - def get_attributes(self, container_id: str, container_type: ContainerType) -> List[Attribute]: + def get_attributes(self, container_id: str, container_type: ContainerType) -> List[FieldDefinition]: pass @abc.abstractmethod @@ -172,37 +174,35 @@ def download_file_set( pass @abc.abstractmethod - def get_float_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FloatAttribute: + def get_float_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FloatField: pass @abc.abstractmethod - def get_int_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> IntAttribute: + def get_int_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> IntField: pass @abc.abstractmethod - def get_bool_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> BoolAttribute: + def get_bool_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> BoolField: pass @abc.abstractmethod - def get_file_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FileAttribute: + def get_file_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FileField: pass @abc.abstractmethod - def get_string_attribute( - self, container_id: str, container_type: ContainerType, path: List[str] - ) -> StringAttribute: + def get_string_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> StringField: pass @abc.abstractmethod def get_datetime_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> DatetimeAttribute: + ) -> DateTimeField: pass @abc.abstractmethod def get_artifact_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> ArtifactAttribute: + ) -> ArtifactField: pass @abc.abstractmethod @@ -212,19 +212,19 @@ def list_artifact_files(self, project_id: str, artifact_hash: str) -> List[Artif @abc.abstractmethod def get_float_series_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> FloatSeriesAttribute: + ) -> FloatSeriesField: pass @abc.abstractmethod def get_string_series_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> StringSeriesAttribute: + ) -> StringSeriesField: pass @abc.abstractmethod def get_string_set_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> StringSetAttribute: + ) -> StringSetField: pass @abc.abstractmethod @@ -298,7 +298,7 @@ def get_model_version_url( @abc.abstractmethod def fetch_atom_attribute_values( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> List[Tuple[str, AttributeType, Any]]: + ) -> List[Tuple[str, FieldType, Any]]: pass @abc.abstractmethod diff --git a/src/neptune/internal/backends/neptune_backend_mock.py b/src/neptune/internal/backends/neptune_backend_mock.py index 4ff41e88e..0a8a1b515 100644 --- a/src/neptune/internal/backends/neptune_backend_mock.py +++ b/src/neptune/internal/backends/neptune_backend_mock.py @@ -34,7 +34,22 @@ ) from zipfile import ZipFile -from neptune.api.dtos import FileEntry +from neptune.api.models import ( + ArtifactField, + BoolField, + DateTimeField, + FieldDefinition, + FieldType, + FileEntry, + FileField, + FloatField, + FloatSeriesField, + IntField, + LeaderboardEntry, + StringField, + StringSeriesField, + StringSetField, +) from neptune.core.components.operation_storage import OperationStorage from neptune.exceptions import ( ContainerUUIDNotFound, @@ -46,25 +61,12 @@ from neptune.internal.artifacts.types import ArtifactFileData from neptune.internal.backends.api_model import ( ApiExperiment, - ArtifactAttribute, - Attribute, - AttributeType, - BoolAttribute, - DatetimeAttribute, - FileAttribute, - FloatAttribute, FloatPointValue, - FloatSeriesAttribute, FloatSeriesValues, ImageSeriesValues, - IntAttribute, - LeaderboardEntry, Project, - StringAttribute, StringPointValue, - StringSeriesAttribute, StringSeriesValues, - StringSetAttribute, Workspace, ) from neptune.internal.backends.hosted_file_operations import get_unique_upload_entries @@ -311,7 +313,7 @@ def _execute_operation( else: run.pop(op.path) - def get_attributes(self, container_id: str, container_type: ContainerType) -> List[Attribute]: + def get_attributes(self, container_id: str, container_type: ContainerType) -> List[FieldDefinition]: run = self._get_container(container_id, container_type) return list(self._generate_attributes(None, run.get_structure())) @@ -321,7 +323,7 @@ def _generate_attributes(self, base_path: Optional[str], values: dict): if isinstance(value_or_dict, dict): yield from self._generate_attributes(new_path, value_or_dict) else: - yield Attribute( + yield FieldDefinition( new_path, value_or_dict.accept(self._attribute_type_converter_value_visitor), ) @@ -370,64 +372,63 @@ def download_file_set( for upload_entry in upload_entries: zipObj.write(upload_entry.source, upload_entry.target_path) - def get_float_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FloatAttribute: + def get_float_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FloatField: val = self._get_attribute(container_id, container_type, path, Float) - return FloatAttribute(val.value) + return FloatField(path=path_to_str(path), value=val.value) - def get_int_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> IntAttribute: + def get_int_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> IntField: val = self._get_attribute(container_id, container_type, path, Integer) - return IntAttribute(val.value) + return IntField(path=path_to_str(path), value=val.value) - def get_bool_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> BoolAttribute: + def get_bool_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> BoolField: val = self._get_attribute(container_id, container_type, path, Boolean) - return BoolAttribute(val.value) + return BoolField(path=path_to_str(path), value=val.value) - def get_file_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FileAttribute: + def get_file_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FileField: val = self._get_attribute(container_id, container_type, path, File) - return FileAttribute( + return FileField( + path=path_to_str(path), name=os.path.basename(val.path) if val.file_type is FileType.LOCAL_FILE else "", ext=val.extension or "", size=0, ) - def get_string_attribute( - self, container_id: str, container_type: ContainerType, path: List[str] - ) -> StringAttribute: + def get_string_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> StringField: val = self._get_attribute(container_id, container_type, path, String) - return StringAttribute(val.value) + return StringField(path=path_to_str(path), value=val.value) def get_datetime_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> DatetimeAttribute: + ) -> DateTimeField: val = self._get_attribute(container_id, container_type, path, Datetime) - return DatetimeAttribute(val.value) + return DateTimeField(path=path_to_str(path), value=val.value) def get_artifact_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> ArtifactAttribute: + ) -> ArtifactField: val = self._get_attribute(container_id, container_type, path, Artifact) - return ArtifactAttribute(val.hash) + return ArtifactField(path=path_to_str(path), hash=val.hash) def list_artifact_files(self, project_id: str, artifact_hash: str) -> List[ArtifactFileData]: return self._artifacts[(project_id, artifact_hash)] def get_float_series_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> FloatSeriesAttribute: + ) -> FloatSeriesField: val = self._get_attribute(container_id, container_type, path, FloatSeries) - return FloatSeriesAttribute(val.values[-1] if val.values else None) + return FloatSeriesField(path=path_to_str(path), last=val.values[-1] if val.values else None) def get_string_series_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> StringSeriesAttribute: + ) -> StringSeriesField: val = self._get_attribute(container_id, container_type, path, StringSeries) - return StringSeriesAttribute(val.values[-1] if val.values else None) + return StringSeriesField(path=path_to_str(path), last=val.values[-1] if val.values else None) def get_string_set_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> StringSetAttribute: + ) -> StringSetField: val = self._get_attribute(container_id, container_type, path, StringSet) - return StringSetAttribute(set(val.values)) + return StringSetField(path=path_to_str(path), values=set(val.values)) def _get_attribute( self, @@ -528,7 +529,7 @@ def _get_attribute_values(self, value_dict, path_prefix: List[str]): def fetch_atom_attribute_values( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> List[Tuple[str, AttributeType, Any]]: + ) -> List[Tuple[str, FieldType, Any]]: run = self._get_container(container_id, container_type) values = self._get_attribute_values(run.get(path), path) namespace_prefix = path_to_str(path) @@ -554,50 +555,50 @@ def search_leaderboard_entries( ) -> Generator[LeaderboardEntry, None, None]: """Non relevant for mock""" - class AttributeTypeConverterValueVisitor(ValueVisitor[AttributeType]): - def visit_float(self, _: Float) -> AttributeType: - return AttributeType.FLOAT + class AttributeTypeConverterValueVisitor(ValueVisitor[FieldType]): + def visit_float(self, _: Float) -> FieldType: + return FieldType.FLOAT - def visit_integer(self, _: Integer) -> AttributeType: - return AttributeType.INT + def visit_integer(self, _: Integer) -> FieldType: + return FieldType.INT - def visit_boolean(self, _: Boolean) -> AttributeType: - return AttributeType.BOOL + def visit_boolean(self, _: Boolean) -> FieldType: + return FieldType.BOOL - def visit_string(self, _: String) -> AttributeType: - return AttributeType.STRING + def visit_string(self, _: String) -> FieldType: + return FieldType.STRING - def visit_datetime(self, _: Datetime) -> AttributeType: - return AttributeType.DATETIME + def visit_datetime(self, _: Datetime) -> FieldType: + return FieldType.DATETIME - def visit_file(self, _: File) -> AttributeType: - return AttributeType.FILE + def visit_file(self, _: File) -> FieldType: + return FieldType.FILE - def visit_file_set(self, _: FileSet) -> AttributeType: - return AttributeType.FILE_SET + def visit_file_set(self, _: FileSet) -> FieldType: + return FieldType.FILE_SET - def visit_float_series(self, _: FloatSeries) -> AttributeType: - return AttributeType.FLOAT_SERIES + def visit_float_series(self, _: FloatSeries) -> FieldType: + return FieldType.FLOAT_SERIES - def visit_string_series(self, _: StringSeries) -> AttributeType: - return AttributeType.STRING_SERIES + def visit_string_series(self, _: StringSeries) -> FieldType: + return FieldType.STRING_SERIES - def visit_image_series(self, _: FileSeries) -> AttributeType: - return AttributeType.IMAGE_SERIES + def visit_image_series(self, _: FileSeries) -> FieldType: + return FieldType.IMAGE_SERIES - def visit_string_set(self, _: StringSet) -> AttributeType: - return AttributeType.STRING_SET + def visit_string_set(self, _: StringSet) -> FieldType: + return FieldType.STRING_SET - def visit_git_ref(self, _: GitRef) -> AttributeType: - return AttributeType.GIT_REF + def visit_git_ref(self, _: GitRef) -> FieldType: + return FieldType.GIT_REF - def visit_artifact(self, _: Artifact) -> AttributeType: - return AttributeType.ARTIFACT + def visit_artifact(self, _: Artifact) -> FieldType: + return FieldType.ARTIFACT - def visit_namespace(self, _: Namespace) -> AttributeType: + def visit_namespace(self, _: Namespace) -> FieldType: raise NotImplementedError - def copy_value(self, source_type: Type[Attribute], source_path: List[str]) -> AttributeType: + def copy_value(self, source_type: Type[FieldDefinition], source_path: List[str]) -> FieldType: raise NotImplementedError class NewValueOpVisitor(OperationVisitor[Optional[Value]]): diff --git a/src/neptune/internal/backends/offline_neptune_backend.py b/src/neptune/internal/backends/offline_neptune_backend.py index f5d2589bf..48f2b8264 100644 --- a/src/neptune/internal/backends/offline_neptune_backend.py +++ b/src/neptune/internal/backends/offline_neptune_backend.py @@ -20,24 +20,26 @@ Optional, ) -from neptune.api.dtos import FileEntry +from neptune.api.models import ( + ArtifactField, + BoolField, + DateTimeField, + FieldDefinition, + FileEntry, + FileField, + FloatField, + FloatSeriesField, + IntField, + StringField, + StringSeriesField, + StringSetField, +) from neptune.exceptions import NeptuneOfflineModeFetchException from neptune.internal.artifacts.types import ArtifactFileData from neptune.internal.backends.api_model import ( - ArtifactAttribute, - Attribute, - BoolAttribute, - DatetimeAttribute, - FileAttribute, - FloatAttribute, - FloatSeriesAttribute, FloatSeriesValues, ImageSeriesValues, - IntAttribute, - StringAttribute, - StringSeriesAttribute, StringSeriesValues, - StringSetAttribute, ) from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock from neptune.internal.container_type import ContainerType @@ -47,34 +49,32 @@ class OfflineNeptuneBackend(NeptuneBackendMock): WORKSPACE_NAME = "offline" - def get_attributes(self, container_id: str, container_type: ContainerType) -> List[Attribute]: + def get_attributes(self, container_id: str, container_type: ContainerType) -> List[FieldDefinition]: raise NeptuneOfflineModeFetchException - def get_float_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FloatAttribute: + def get_float_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FloatField: raise NeptuneOfflineModeFetchException - def get_int_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> IntAttribute: + def get_int_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> IntField: raise NeptuneOfflineModeFetchException - def get_bool_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> BoolAttribute: + def get_bool_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> BoolField: raise NeptuneOfflineModeFetchException - def get_file_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FileAttribute: + def get_file_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> FileField: raise NeptuneOfflineModeFetchException - def get_string_attribute( - self, container_id: str, container_type: ContainerType, path: List[str] - ) -> StringAttribute: + def get_string_attribute(self, container_id: str, container_type: ContainerType, path: List[str]) -> StringField: raise NeptuneOfflineModeFetchException def get_datetime_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> DatetimeAttribute: + ) -> DateTimeField: raise NeptuneOfflineModeFetchException def get_artifact_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> ArtifactAttribute: + ) -> ArtifactField: raise NeptuneOfflineModeFetchException def list_artifact_files(self, project_id: str, artifact_hash: str) -> List[ArtifactFileData]: @@ -82,17 +82,17 @@ def list_artifact_files(self, project_id: str, artifact_hash: str) -> List[Artif def get_float_series_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> FloatSeriesAttribute: + ) -> FloatSeriesField: raise NeptuneOfflineModeFetchException def get_string_series_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> StringSeriesAttribute: + ) -> StringSeriesField: raise NeptuneOfflineModeFetchException def get_string_set_attribute( self, container_id: str, container_type: ContainerType, path: List[str] - ) -> StringSetAttribute: + ) -> StringSetField: raise NeptuneOfflineModeFetchException def get_string_series_values( diff --git a/src/neptune/internal/utils/generic_attribute_mapper.py b/src/neptune/internal/utils/generic_attribute_mapper.py index 3f5a29752..da40ba9b2 100644 --- a/src/neptune/internal/utils/generic_attribute_mapper.py +++ b/src/neptune/internal/utils/generic_attribute_mapper.py @@ -15,7 +15,7 @@ # __all__ = ["NoValue", "atomic_attribute_types_map", "map_attribute_result_to_value"] -from neptune.internal.backends.api_model import AttributeType +from neptune.api.models import FieldType class NoValue: @@ -27,30 +27,30 @@ class NoValue: VALUES = "values" atomic_attribute_types_map = { - AttributeType.FLOAT.value: "floatProperties", - AttributeType.INT.value: "intProperties", - AttributeType.BOOL.value: "boolProperties", - AttributeType.STRING.value: "stringProperties", - AttributeType.DATETIME.value: "datetimeProperties", - AttributeType.RUN_STATE.value: "experimentStateProperties", - AttributeType.NOTEBOOK_REF.value: "notebookRefProperties", + FieldType.FLOAT.value: "floatProperties", + FieldType.INT.value: "intProperties", + FieldType.BOOL.value: "boolProperties", + FieldType.STRING.value: "stringProperties", + FieldType.DATETIME.value: "datetimeProperties", + FieldType.OBJECT_STATE.value: "experimentStateProperties", + FieldType.NOTEBOOK_REF.value: "notebookRefProperties", } value_series_attribute_types_map = { - AttributeType.FLOAT_SERIES.value: "floatSeriesProperties", - AttributeType.STRING_SERIES.value: "stringSeriesProperties", + FieldType.FLOAT_SERIES.value: "floatSeriesProperties", + FieldType.STRING_SERIES.value: "stringSeriesProperties", } value_set_attribute_types_map = { - AttributeType.STRING_SET.value: "stringSetProperties", + FieldType.STRING_SET.value: "stringSetProperties", } # TODO: nicer mapping? _unmapped_attribute_types_map = { - AttributeType.FILE_SET.value: "fileSetProperties", # TODO: return size? - AttributeType.FILE.value: "fileProperties", # TODO: name? size? - AttributeType.IMAGE_SERIES.value: "imageSeriesProperties", # TODO: return last step? - AttributeType.GIT_REF.value: "gitRefProperties", # TODO: commit? branch? + FieldType.FILE_SET.value: "fileSetProperties", # TODO: return size? + FieldType.FILE.value: "fileProperties", # TODO: name? size? + FieldType.IMAGE_SERIES.value: "imageSeriesProperties", # TODO: return last step? + FieldType.GIT_REF.value: "gitRefProperties", # TODO: commit? branch? } diff --git a/src/neptune/objects/neptune_object.py b/src/neptune/objects/neptune_object.py index ce4def13c..d9459ef06 100644 --- a/src/neptune/objects/neptune_object.py +++ b/src/neptune/objects/neptune_object.py @@ -39,6 +39,7 @@ Union, ) +from neptune.api.models import FieldType from neptune.attributes import create_attribute_from_type from neptune.attributes.attribute import Attribute from neptune.attributes.namespace import Namespace as NamespaceAttr @@ -51,7 +52,6 @@ from neptune.handler import Handler from neptune.internal.backends.api_model import ( ApiExperiment, - AttributeType, Project, ) from neptune.internal.backends.factory import get_backend @@ -97,7 +97,6 @@ AbstractNeptuneObject, NeptuneObjectCallback, ) -from neptune.objects.utils import parse_dates from neptune.table import Table from neptune.types.mode import Mode from neptune.types.type_casting import cast_value @@ -630,7 +629,7 @@ def sync(self, *, wait: bool = True) -> None: for attribute in attributes: self._define_attribute(parse_path(attribute.path), attribute.type) - def _define_attribute(self, _path: List[str], _type: AttributeType): + def _define_attribute(self, _path: List[str], _type: FieldType): attr = create_attribute_from_type(_type, self, _path) self._structure.set(_path, attr) @@ -685,8 +684,6 @@ def _fetch_entries( progress_bar=progress_bar, ) - leaderboard_entries = parse_dates(leaderboard_entries) - return Table( backend=self._backend, container_type=child_type, diff --git a/src/neptune/objects/utils.py b/src/neptune/objects/utils.py index 446e0637c..73763e746 100644 --- a/src/neptune/objects/utils.py +++ b/src/neptune/objects/utils.py @@ -15,23 +15,16 @@ # __all__ = [ - "parse_dates", "prepare_nql_query", ] from typing import ( - Generator, Iterable, List, Optional, Union, ) -from neptune.internal.backends.api_model import ( - AttributeType, - AttributeWithProperties, - LeaderboardEntry, -) from neptune.internal.backends.nql import ( NQLAggregator, NQLAttributeOperator, @@ -41,12 +34,7 @@ NQLQueryAttribute, RawNQLQuery, ) -from neptune.internal.utils.iso_dates import parse_iso_date from neptune.internal.utils.run_state import RunState -from neptune.internal.warnings import ( - NeptuneWarning, - warn_once, -) def prepare_nql_query( @@ -136,39 +124,6 @@ def prepare_nql_query( return query -def parse_dates(leaderboard_entries: Iterable[LeaderboardEntry]) -> Generator[LeaderboardEntry, None, None]: - yield from [_parse_entry(entry) for entry in leaderboard_entries] - - -def _parse_entry(entry: LeaderboardEntry) -> LeaderboardEntry: - try: - return LeaderboardEntry( - entry.id, - attributes=[ - ( - AttributeWithProperties( - attribute.path, - attribute.type, - { - **attribute.properties, - "value": parse_iso_date(attribute.properties["value"]), - }, - ) - if attribute.type == AttributeType.DATETIME - else attribute - ) - for attribute in entry.attributes - ], - ) - except ValueError: - # the parsing format is incorrect - warn_once( - "Date parsing failed. The date format is incorrect. Returning as string instead of datetime.", - exception=NeptuneWarning, - ) - return entry - - def build_raw_query(query: str, trashed: Optional[bool]) -> NQLQuery: raw_nql = RawNQLQuery(query) diff --git a/src/neptune/table.py b/src/neptune/table.py index 9b997cc09..ac4a75c2e 100644 --- a/src/neptune/table.py +++ b/src/neptune/table.py @@ -23,13 +23,14 @@ Optional, ) -from neptune.exceptions import MetadataInconsistency -from neptune.integrations.pandas import to_pandas -from neptune.internal.backends.api_model import ( - AttributeType, - AttributeWithProperties, +from neptune.api.field_visitor import FieldToValueVisitor +from neptune.api.models import ( + Field, + FieldType, LeaderboardEntry, ) +from neptune.exceptions import MetadataInconsistency +from neptune.integrations.pandas import to_pandas from neptune.internal.backends.neptune_backend import NeptuneBackend from neptune.internal.container_type import ContainerType from neptune.internal.utils.logger import get_logger @@ -37,7 +38,6 @@ join_paths, parse_path, ) -from neptune.internal.utils.run_state import RunState from neptune.typing import ProgressBarType if TYPE_CHECKING: @@ -53,57 +53,28 @@ def __init__( backend: NeptuneBackend, container_type: ContainerType, _id: str, - attributes: List[AttributeWithProperties], + attributes: List[Field], ): self._backend = backend self._container_type = container_type self._id = _id - self._attributes = attributes + self._fields = attributes + self._field_to_value_visitor = FieldToValueVisitor() def __getitem__(self, path: str) -> "LeaderboardHandler": return LeaderboardHandler(table_entry=self, path=path) - def get_attribute_type(self, path: str) -> AttributeType: - for attr in self._attributes: - if attr.path == path: - return attr.type - raise ValueError("Could not find {} attribute".format(path)) + def get_attribute_type(self, path: str) -> FieldType: + for field in self._fields: + if field.path == path: + return field.type + + raise ValueError(f"Could not find {path} field") def get_attribute_value(self, path: str) -> Any: - for attr in self._attributes: - if attr.path == path: - _type = attr.type - if _type == AttributeType.RUN_STATE: - return RunState.from_api(attr.properties.get("value")).value - if _type in ( - AttributeType.FLOAT, - AttributeType.INT, - AttributeType.BOOL, - AttributeType.STRING, - AttributeType.DATETIME, - ): - return attr.properties.get("value") - if _type == AttributeType.FLOAT_SERIES or _type == AttributeType.STRING_SERIES: - return attr.properties.get("last") - if _type == AttributeType.IMAGE_SERIES: - raise MetadataInconsistency("Cannot get value for image series.") - if _type == AttributeType.FILE: - raise MetadataInconsistency("Cannot get value for file attribute. Use download() instead.") - if _type == AttributeType.FILE_SET: - raise MetadataInconsistency("Cannot get value for file set attribute. Use download() instead.") - if _type == AttributeType.STRING_SET: - return set(attr.properties.get("values")) - if _type == AttributeType.GIT_REF: - return attr.properties.get("commit", {}).get("commitId") - if _type == AttributeType.NOTEBOOK_REF: - return attr.properties.get("notebookName") - if _type == AttributeType.ARTIFACT: - return attr.properties.get("hash") - logger.error( - "Attribute type %s not supported in this version, yielding None. Recommended client upgrade.", - _type, - ) - return None + for field in self._fields: + if field.path == path: + return self._field_to_value_visitor.visit(field) raise ValueError("Could not find {} attribute".format(path)) def download_file_attribute( @@ -112,10 +83,10 @@ def download_file_attribute( destination: Optional[str], progress_bar: Optional[ProgressBarType] = None, ) -> None: - for attr in self._attributes: + for attr in self._fields: if attr.path == path: _type = attr.type - if _type == AttributeType.FILE: + if _type == FieldType.FILE: self._backend.download_file( container_id=self._id, container_type=self._container_type, @@ -133,10 +104,10 @@ def download_file_set_attribute( destination: Optional[str], progress_bar: Optional[ProgressBarType] = None, ) -> None: - for attr in self._attributes: + for attr in self._fields: if attr.path == path: _type = attr.type - if _type == AttributeType.FILE_SET: + if _type == FieldType.FILE_SET: self._backend.download_file_set( container_id=self._id, container_type=self._container_type, @@ -162,9 +133,9 @@ def get(self) -> Any: def download(self, destination: Optional[str]) -> None: attr_type = self._table_entry.get_attribute_type(self._path) - if attr_type == AttributeType.FILE: + if attr_type == FieldType.FILE: return self._table_entry.download_file_attribute(self._path, destination) - elif attr_type == AttributeType.FILE_SET: + elif attr_type == FieldType.FILE_SET: return self._table_entry.download_file_set_attribute(path=self._path, destination=destination) raise MetadataInconsistency("Cannot download file from attribute of type {}".format(attr_type)) @@ -193,8 +164,8 @@ def __next__(self) -> TableEntry: return TableEntry( backend=self._backend, container_type=self._container_type, - _id=entry.id, - attributes=entry.attributes, + _id=entry.object_id, + attributes=entry.fields, ) def to_pandas(self) -> "pandas.DataFrame": diff --git a/tests/unit/neptune/new/api/test_models.py b/tests/unit/neptune/new/api/test_models.py new file mode 100644 index 000000000..8e09b1974 --- /dev/null +++ b/tests/unit/neptune/new/api/test_models.py @@ -0,0 +1,1493 @@ +# +# Copyright (c) 2024, Neptune Labs Sp. z o.o. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import datetime + +import pytest +from mock import Mock + +from neptune.api.models import ( + ArtifactField, + BoolField, + DateTimeField, + Field, + FieldDefinition, + FieldType, + FileEntry, + FileField, + FileSetField, + FloatField, + FloatSeriesField, + GitRefField, + ImageSeriesField, + IntField, + LeaderboardEntriesSearchResult, + LeaderboardEntry, + NotebookRefField, + ObjectStateField, + StringField, + StringSeriesField, + StringSetField, +) + + +def test__float_field__from_dict(): + # given + data = {"attributeType": "float", "attributeName": "some/float", "value": 18.5} + + # when + result = FloatField.from_dict(data) + + # then + assert result.path == "some/float" + assert result.value == 18.5 + + +def test__float_field__from_model(): + # given + model = Mock(attributeType="float", attributeName="some/float", value=18.5) + + # when + result = FloatField.from_model(model) + + # then + assert result.path == "some/float" + assert result.value == 18.5 + + +def test__int_field__from_dict(): + # given + data = {"attributeType": "int", "attributeName": "some/int", "value": 18} + + # when + result = IntField.from_dict(data) + + # then + assert result.path == "some/int" + assert result.value == 18 + + +def test__int_field__from_model(): + # given + model = Mock(attributeType="int", attributeName="some/int", value=18) + + # when + result = IntField.from_model(model) + + # then + assert result.path == "some/int" + assert result.value == 18 + + +def test__string_field__from_dict(): + # given + data = {"attributeType": "string", "attributeName": "some/string", "value": "hello"} + + # when + result = StringField.from_dict(data) + + # then + assert result.path == "some/string" + assert result.value == "hello" + + +def test__string_field__from_model(): + # given + model = Mock(attributeType="string", attributeName="some/string", value="hello") + + # when + result = StringField.from_model(model) + + # then + assert result.path == "some/string" + assert result.value == "hello" + + +def test__string_field__from_dict__empty(): + # given + data = {"attributeType": "string", "attributeName": "some/string", "value": ""} + + # when + result = StringField.from_dict(data) + + # then + assert result.path == "some/string" + assert result.value == "" + + +def test__string_field__from_model__empty(): + # given + model = Mock(attributeType="string", attributeName="some/string", value="") + + # when + result = StringField.from_model(model) + + # then + assert result.path == "some/string" + assert result.value == "" + + +def test__bool_field__from_dict(): + # given + data = {"attributeType": "bool", "attributeName": "some/bool", "value": True} + + # when + result = BoolField.from_dict(data) + + # then + assert result.path == "some/bool" + assert result.value is True + + +def test__bool_field__from_model(): + # given + model = Mock(attributeType="bool", attributeName="some/bool", value=True) + + # when + result = BoolField.from_model(model) + + # then + assert result.path == "some/bool" + assert result.value is True + + +def test__datetime_field__from_dict(): + # given + data = {"attributeType": "datetime", "attributeName": "some/datetime", "value": "2024-01-01T00:12:34.567890Z"} + + # when + result = DateTimeField.from_dict(data) + + # then + assert result.path == "some/datetime" + assert result.value == datetime.datetime(2024, 1, 1, 0, 12, 34, 567890) + + +def test__datetime_field__from_model(): + # given + model = Mock(attributeType="datetime", attributeName="some/datetime", value="2024-01-01T00:12:34.567890Z") + + # when + result = DateTimeField.from_model(model) + + # then + assert result.path == "some/datetime" + assert result.value == datetime.datetime(2024, 1, 1, 0, 12, 34, 567890) + + +def test__float_series_field__from_dict(): + # given + data = { + "attributeType": "floatSeries", + "attributeName": "some/floatSeries", + "last": 19.5, + } + + # when + result = FloatSeriesField.from_dict(data) + + # then + assert result.path == "some/floatSeries" + assert result.last == 19.5 + + +def test__float_series_field__from_dict__no_last(): + # given + data = { + "attributeType": "floatSeries", + "attributeName": "some/floatSeries", + } + + # when + result = FloatSeriesField.from_dict(data) + + # then + assert result.path == "some/floatSeries" + assert result.last is None + + +def test__float_series_field__from_model(): + # given + model = Mock( + attributeType="floatSeries", + attributeName="some/floatSeries", + last=19.5, + ) + + # when + result = FloatSeriesField.from_model(model) + + # then + assert result.path == "some/floatSeries" + assert result.last == 19.5 + + +def test__float_series_field__from_model__no_last(): + # given + model = Mock( + attributeType="floatSeries", + attributeName="some/floatSeries", + last=None, + ) + + # when + result = FloatSeriesField.from_model(model) + + # then + assert result.path == "some/floatSeries" + assert result.last is None + + +def test__string_series_field__from_dict(): + # given + data = { + "attributeType": "stringSeries", + "attributeName": "some/stringSeries", + "last": "hello", + } + + # when + result = StringSeriesField.from_dict(data) + + # then + assert result.path == "some/stringSeries" + assert result.last == "hello" + + +def test__string_series_field__from_dict__no_last(): + # given + data = { + "attributeType": "stringSeries", + "attributeName": "some/stringSeries", + } + + # when + result = StringSeriesField.from_dict(data) + + # then + assert result.path == "some/stringSeries" + assert result.last is None + + +def test__string_series_field__from_model(): + # given + model = Mock( + attributeType="stringSeries", + attributeName="some/stringSeries", + last="hello", + ) + + # when + result = StringSeriesField.from_model(model) + + # then + assert result.path == "some/stringSeries" + assert result.last == "hello" + + +def test__string_series_field__from_model__no_last(): + # given + model = Mock( + attributeType="stringSeries", + attributeName="some/stringSeries", + last=None, + ) + + # when + result = StringSeriesField.from_model(model) + + # then + assert result.path == "some/stringSeries" + assert result.last is None + + +def test__image_series_field__from_dict(): + # given + data = { + "attributeType": "imageSeries", + "attributeName": "some/imageSeries", + "lastStep": 15.0, + } + + # when + result = ImageSeriesField.from_dict(data) + + # then + assert result.path == "some/imageSeries" + assert result.last_step == 15.0 + + +def test__image_series_field__from_dict__no_last_step(): + # given + data = { + "attributeType": "imageSeries", + "attributeName": "some/imageSeries", + } + + # when + result = ImageSeriesField.from_dict(data) + + # then + assert result.path == "some/imageSeries" + assert result.last_step is None + + +def test__image_series_field__from_model(): + # given + model = Mock( + attributeType="imageSeries", + attributeName="some/imageSeries", + lastStep=15.0, + ) + + # when + result = ImageSeriesField.from_model(model) + + # then + assert result.path == "some/imageSeries" + assert result.last_step == 15.0 + + +def test__image_series_field__from_model__no_last_step(): + # given + model = Mock( + attributeType="imageSeries", + attributeName="some/imageSeries", + lastStep=None, + ) + + # when + result = ImageSeriesField.from_model(model) + + # then + assert result.path == "some/imageSeries" + assert result.last_step is None + + +def test__string_set_field__from_dict(): + # given + data = { + "attributeType": "stringSet", + "attributeName": "some/stringSet", + "values": ["hello", "world"], + } + + # when + result = StringSetField.from_dict(data) + + # then + assert result.path == "some/stringSet" + assert result.values == {"hello", "world"} + + +def test__string_set_field__from_dict__empty(): + # given + data = { + "attributeType": "stringSet", + "attributeName": "some/stringSet", + "values": [], + } + + # when + result = StringSetField.from_dict(data) + + # then + assert result.path == "some/stringSet" + assert result.values == set() + + +def test__string_set_field__from_model(): + # given + model = Mock( + attributeType="stringSet", + attributeName="some/stringSet", + values=["hello", "world"], + ) + + # when + result = StringSetField.from_model(model) + + # then + assert result.path == "some/stringSet" + assert result.values == {"hello", "world"} + + +def test__string_set_field__from_model__empty(): + # given + model = Mock( + attributeType="stringSet", + attributeName="some/stringSet", + values=[], + ) + + # when + result = StringSetField.from_model(model) + + # then + assert result.path == "some/stringSet" + assert result.values == set() + + +def test__file_field__from_dict(): + # given + data = { + "attributeType": "file", + "attributeName": "some/file", + "name": "file.txt", + "size": 1024, + "ext": "txt", + } + + # when + result = FileField.from_dict(data) + + # then + assert result.path == "some/file" + assert result.name == "file.txt" + assert result.size == 1024 + assert result.ext == "txt" + + +def test__file_field__from_model(): + # given + model = Mock( + attributeType="file", + attributeName="some/file", + size=1024, + ext="txt", + ) + model.name = "file.txt" + + # when + result = FileField.from_model(model) + + # then + assert result.path == "some/file" + assert result.name == "file.txt" + assert result.size == 1024 + assert result.ext == "txt" + + +@pytest.mark.parametrize("state,expected", [("running", "Active"), ("idle", "Inactive")]) +def test__object_state_field__from_dict(state, expected): + # given + data = {"attributeType": "experimentState", "attributeName": "sys/state", "value": state} + + # when + result = ObjectStateField.from_dict(data) + + # then + assert result.path == "sys/state" + assert result.value == expected + + +@pytest.mark.parametrize("state,expected", [("running", "Active"), ("idle", "Inactive")]) +def test__object_state_field__from_model(state, expected): + # given + model = Mock(attributeType="experimentState", attributeName="sys/state", value=state) + + # when + result = ObjectStateField.from_model(model) + + # then + assert result.path == "sys/state" + assert result.value == expected + + +def test__file_set_field__from_dict(): + # given + data = { + "attributeType": "fileSet", + "attributeName": "some/fileSet", + "size": 3072, + } + + # when + result = FileSetField.from_dict(data) + + # then + assert result.path == "some/fileSet" + assert result.size == 3072 + + +def test__file_set_field__from_model(): + # given + model = Mock( + attributeType="fileSet", + attributeName="some/fileSet", + size=3072, + ) + + # when + result = FileSetField.from_model(model) + + # then + assert result.path == "some/fileSet" + assert result.size == 3072 + + +def test__notebook_ref_field__from_dict(): + # given + data = { + "attributeType": "notebookRef", + "attributeName": "some/notebookRef", + "notebookName": "Data Processing.ipynb", + } + + # when + result = NotebookRefField.from_dict(data) + + # then + assert result.path == "some/notebookRef" + assert result.notebook_name == "Data Processing.ipynb" + + +def test__notebook_ref_field__from_dict__no_notebook_name(): + # given + data = { + "attributeType": "notebookRef", + "attributeName": "some/notebookRef", + } + + # when + result = NotebookRefField.from_dict(data) + + # then + assert result.path == "some/notebookRef" + assert result.notebook_name is None + + +def test__notebook_ref_field__from_model(): + # given + model = Mock( + attributeType="notebookRef", + attributeName="some/notebookRef", + notebookName="Data Processing.ipynb", + ) + + # when + result = NotebookRefField.from_model(model) + + # then + assert result.path == "some/notebookRef" + assert result.notebook_name == "Data Processing.ipynb" + + +def test__notebook_ref_field__from_model__no_notebook_name(): + # given + model = Mock( + attributeType="notebookRef", + attributeName="some/notebookRef", + notebookName=None, + ) + + # when + result = NotebookRefField.from_model(model) + + # then + assert result.path == "some/notebookRef" + assert result.notebook_name is None + + +def test__git_ref_field__from_dict(): + # given + data = { + "attributeType": "gitRef", + "attributeName": "some/gitRef", + "commit": { + "commitId": "b2d7f8a", + }, + } + + # when + result = GitRefField.from_dict(data) + + # then + assert result.path == "some/gitRef" + assert result.commit.commit_id == "b2d7f8a" + + +def test__git_ref_field__from_dict__no_commit(): + # given + data = { + "attributeType": "gitRef", + "attributeName": "some/gitRef", + } + + # when + result = GitRefField.from_dict(data) + + # then + assert result.path == "some/gitRef" + assert result.commit is None + + +def test__git_ref_field__from_model(): + # given + model = Mock( + attributeType="gitRef", + attributeName="some/gitRef", + commit=Mock( + commitId="b2d7f8a", + ), + ) + + # when + result = GitRefField.from_model(model) + + # then + assert result.path == "some/gitRef" + assert result.commit.commit_id == "b2d7f8a" + + +def test__git_ref_field__from_model__no_commit(): + # given + model = Mock( + attributeType="gitRef", + attributeName="some/gitRef", + commit=None, + ) + + # when + result = GitRefField.from_model(model) + + # then + assert result.path == "some/gitRef" + assert result.commit is None + + +def test__artifact_field__from_dict(): + # given + data = { + "attributeType": "artifact", + "attributeName": "some/artifact", + "hash": "f192cddb2b98c0b4c72bba22b68d2245", + } + + # when + result = ArtifactField.from_dict(data) + + # then + assert result.path == "some/artifact" + assert result.hash == "f192cddb2b98c0b4c72bba22b68d2245" + + +def test__artifact_field__from_model(): + # given + model = Mock( + attributeType="artifact", + attributeName="some/artifact", + hash="f192cddb2b98c0b4c72bba22b68d2245", + ) + + # when + result = ArtifactField.from_model(model) + + # then + assert result.path == "some/artifact" + assert result.hash == "f192cddb2b98c0b4c72bba22b68d2245" + + +def test__field__from_dict__float(): + # given + data = { + "path": "some/float", + "type": "float", + "floatProperties": {"attributeType": "float", "attributeName": "some/float", "value": 18.5}, + } + + # when + result = Field.from_dict(data) + + # then + assert result.path == "some/float" + assert isinstance(result, FloatField) + assert result.value == 18.5 + + +def test__field__from_model__float(): + # given + model = Mock( + path="some/float", + type="float", + floatProperties=Mock(attributeType="float", attributeName="some/float", value=18.5), + ) + + # when + result = Field.from_model(model) + + # then + assert result.path == "some/float" + assert isinstance(result, FloatField) + assert result.value == 18.5 + + +def test__field__from_dict__int(): + # given + data = { + "path": "some/int", + "type": "int", + "intProperties": {"attributeType": "int", "attributeName": "some/int", "value": 18}, + } + + # when + result = Field.from_dict(data) + + # then + assert result.path == "some/int" + assert isinstance(result, IntField) + assert result.value == 18 + + +def test__field__from_model__int(): + # given + model = Mock( + path="some/int", type="int", intProperties=Mock(attributeType="int", attributeName="some/int", value=18) + ) + + # when + result = Field.from_model(model) + + # then + assert result.path == "some/int" + assert isinstance(result, IntField) + assert result.value == 18 + + +def test__field__from_dict__string(): + # given + data = { + "path": "some/string", + "type": "string", + "stringProperties": {"attributeType": "string", "attributeName": "some/string", "value": "hello"}, + } + + # when + result = Field.from_dict(data) + + # then + assert result.path == "some/string" + assert isinstance(result, StringField) + assert result.value == "hello" + + +def test__field__from_model__string(): + # given + model = Mock( + path="some/string", + type="string", + stringProperties=Mock(attributeType="string", attributeName="some/string", value="hello"), + ) + + # when + result = Field.from_model(model) + + # then + assert result.path == "some/string" + assert isinstance(result, StringField) + assert result.value == "hello" + + +def test__field__from_dict__bool(): + # given + data = { + "path": "some/bool", + "type": "bool", + "boolProperties": {"attributeType": "bool", "attributeName": "some/bool", "value": True}, + } + + # when + result = Field.from_dict(data) + + # then + assert result.path == "some/bool" + assert isinstance(result, BoolField) + assert result.value is True + + +def test__field__from_model__bool(): + # given + model = Mock( + path="some/bool", type="bool", boolProperties=Mock(attributeType="bool", attributeName="some/bool", value=True) + ) + + # when + result = Field.from_model(model) + + # then + assert result.path == "some/bool" + assert isinstance(result, BoolField) + assert result.value is True + + +def test__field__from_dict__datetime(): + # given + data = { + "path": "some/datetime", + "type": "datetime", + "datetimeProperties": { + "attributeType": "datetime", + "attributeName": "some/datetime", + "value": "2024-01-01T00:12:34.567890Z", + }, + } + + # when + result = Field.from_dict(data) + + # then + assert result.path == "some/datetime" + assert isinstance(result, DateTimeField) + assert result.value == datetime.datetime(2024, 1, 1, 0, 12, 34, 567890) + + +def test__field__from_model__datetime(): + # given + model = Mock( + path="some/datetime", + type="datetime", + datetimeProperties=Mock( + attributeType="datetime", attributeName="some/datetime", value="2024-01-01T00:12:34.567890Z" + ), + ) + + # when + result = Field.from_model(model) + + # then + assert result.path == "some/datetime" + assert isinstance(result, DateTimeField) + assert result.value == datetime.datetime(2024, 1, 1, 0, 12, 34, 567890) + + +def test__field__from_dict__float_series(): + # given + data = { + "path": "some/floatSeries", + "type": "floatSeries", + "floatSeriesProperties": {"attributeType": "floatSeries", "attributeName": "some/floatSeries", "last": 19.5}, + } + + # when + result = Field.from_dict(data) + + # then + assert result.path == "some/floatSeries" + assert isinstance(result, FloatSeriesField) + assert result.last == 19.5 + + +def test__field__from_model__float_series(): + # given + model = Mock( + path="some/floatSeries", + type="floatSeries", + floatSeriesProperties=Mock(attributeType="floatSeries", attributeName="some/floatSeries", last=19.5), + ) + + # when + result = Field.from_model(model) + + # then + assert result.path == "some/floatSeries" + assert isinstance(result, FloatSeriesField) + assert result.last == 19.5 + + +def test__field__from_dict__string_series(): + # given + data = { + "path": "some/stringSeries", + "type": "stringSeries", + "stringSeriesProperties": { + "attributeType": "stringSeries", + "attributeName": "some/stringSeries", + "last": "hello", + }, + } + + # when + result = Field.from_dict(data) + + # then + assert result.path == "some/stringSeries" + assert isinstance(result, StringSeriesField) + assert result.last == "hello" + + +def test__field__from_model__string_series(): + # given + model = Mock( + path="some/stringSeries", + type="stringSeries", + stringSeriesProperties=Mock(attributeType="stringSeries", attributeName="some/stringSeries", last="hello"), + ) + + # when + result = Field.from_model(model) + + # then + assert result.path == "some/stringSeries" + assert isinstance(result, StringSeriesField) + assert result.last == "hello" + + +def test__field__from_dict__image_series(): + # given + data = { + "path": "some/imageSeries", + "type": "imageSeries", + "imageSeriesProperties": { + "attributeType": "imageSeries", + "attributeName": "some/imageSeries", + "lastStep": 15.0, + }, + } + + # when + result = Field.from_dict(data) + + # then + assert result.path == "some/imageSeries" + assert isinstance(result, ImageSeriesField) + assert result.last_step == 15.0 + + +def test__field__from_model__image_series(): + # given + model = Mock( + path="some/imageSeries", + type="imageSeries", + imageSeriesProperties=Mock(attributeType="imageSeries", attributeName="some/imageSeries", lastStep=15.0), + ) + + # when + result = Field.from_model(model) + + # then + assert result.path == "some/imageSeries" + assert isinstance(result, ImageSeriesField) + assert result.last_step == 15.0 + + +def test__field__from_dict__string_set(): + # given + data = { + "path": "some/stringSet", + "type": "stringSet", + "stringSetProperties": { + "attributeType": "stringSet", + "attributeName": "some/stringSet", + "values": ["hello", "world"], + }, + } + + # when + result = Field.from_dict(data) + + # then + assert result.path == "some/stringSet" + assert isinstance(result, StringSetField) + assert result.values == {"hello", "world"} + + +def test__field__from_model__string_set(): + # given + model = Mock( + path="some/stringSet", + type="stringSet", + stringSetProperties=Mock(attributeType="stringSet", attributeName="some/stringSet", values=["hello", "world"]), + ) + + # when + result = Field.from_model(model) + + # then + assert result.path == "some/stringSet" + assert isinstance(result, StringSetField) + assert result.values == {"hello", "world"} + + +def test__field__from_dict__file(): + # given + data = { + "path": "some/file", + "type": "file", + "fileProperties": { + "attributeType": "file", + "attributeName": "some/file", + "name": "file.txt", + "size": 1024, + "ext": "txt", + }, + } + + # when + result = Field.from_dict(data) + + # then + assert result.path == "some/file" + assert isinstance(result, FileField) + assert result.name == "file.txt" + assert result.size == 1024 + assert result.ext == "txt" + + +def test__field__from_model__file(): + # given + model = Mock( + path="some/file", + type="file", + fileProperties=Mock(attributeType="file", attributeName="some/file", size=1024, ext="txt"), + ) + model.fileProperties.name = "file.txt" + + # when + result = Field.from_model(model) + + # then + assert result.path == "some/file" + assert isinstance(result, FileField) + assert result.name == "file.txt" + assert result.size == 1024 + assert result.ext == "txt" + + +def test__field__from_dict__object_state(): + # given + data = { + "path": "sys/state", + "type": "experimentState", + "experimentStateProperties": { + "attributeType": "experimentState", + "attributeName": "sys/state", + "value": "running", + }, + } + + # when + result = Field.from_dict(data) + + # then + assert result.path == "sys/state" + assert isinstance(result, ObjectStateField) + assert result.value == "Active" + + +def test__field__from_model__object_state(): + # given + model = Mock( + path="sys/state", + type="experimentState", + experimentStateProperties=Mock(attributeType="experimentState", attributeName="sys/state", value="running"), + ) + + # when + result = Field.from_model(model) + + # then + assert result.path == "sys/state" + assert isinstance(result, ObjectStateField) + assert result.value == "Active" + + +def test__field__from_dict__file_set(): + # given + data = { + "path": "some/fileSet", + "type": "fileSet", + "fileSetProperties": {"attributeType": "fileSet", "attributeName": "some/fileSet", "size": 3072}, + } + + # when + result = Field.from_dict(data) + + # then + assert result.path == "some/fileSet" + assert isinstance(result, FileSetField) + assert result.size == 3072 + + +def test__field__from_model__file_set(): + # given + model = Mock( + path="some/fileSet", + type="fileSet", + fileSetProperties=Mock(attributeType="fileSet", attributeName="some/fileSet", size=3072), + ) + + # when + result = Field.from_model(model) + + # then + assert result.path == "some/fileSet" + assert isinstance(result, FileSetField) + assert result.size == 3072 + + +def test__field__from_dict__notebook_ref(): + # given + data = { + "path": "some/notebookRef", + "type": "notebookRef", + "notebookRefProperties": { + "attributeType": "notebookRef", + "attributeName": "some/notebookRef", + "notebookName": "Data Processing.ipynb", + }, + } + + # when + result = Field.from_dict(data) + + # then + assert result.path == "some/notebookRef" + assert isinstance(result, NotebookRefField) + assert result.notebook_name == "Data Processing.ipynb" + + +def test__field__from_model__notebook_ref(): + # given + model = Mock( + path="some/notebookRef", + type="notebookRef", + notebookRefProperties=Mock( + attributeType="notebookRef", attributeName="some/notebookRef", notebookName="Data Processing.ipynb" + ), + ) + + # when + result = Field.from_model(model) + + # then + assert result.path == "some/notebookRef" + assert isinstance(result, NotebookRefField) + assert result.notebook_name == "Data Processing.ipynb" + + +def test__field__from_dict__git_ref(): + # given + data = { + "path": "some/gitRef", + "type": "gitRef", + "gitRefProperties": { + "attributeType": "gitRef", + "attributeName": "some/gitRef", + "commit": {"commitId": "b2d7f8a"}, + }, + } + + # when + result = Field.from_dict(data) + + # then + assert result.path == "some/gitRef" + assert isinstance(result, GitRefField) + assert result.commit.commit_id == "b2d7f8a" + + +def test__field__from_model__git_ref(): + # given + model = Mock( + path="some/gitRef", + type="gitRef", + gitRefProperties=Mock(attributeType="gitRef", attributeName="some/gitRef", commit=Mock(commitId="b2d7f8a")), + ) + + # when + result = Field.from_model(model) + + # then + assert result.path == "some/gitRef" + assert isinstance(result, GitRefField) + assert result.commit.commit_id == "b2d7f8a" + + +def test__field__from_dict__artifact(): + # given + data = { + "path": "some/artifact", + "type": "artifact", + "artifactProperties": { + "attributeType": "artifact", + "attributeName": "some/artifact", + "hash": "f192cddb2b98c0b4c72bba22b68d2245", + }, + } + + # when + result = Field.from_dict(data) + + # then + assert result.path == "some/artifact" + assert isinstance(result, ArtifactField) + assert result.hash == "f192cddb2b98c0b4c72bba22b68d2245" + + +def test__field__from_model__artifact(): + # given + model = Mock( + path="some/artifact", + type="artifact", + artifactProperties=Mock( + attributeType="artifact", attributeName="some/artifact", hash="f192cddb2b98c0b4c72bba22b68d2245" + ), + ) + + # when + result = Field.from_model(model) + + # then + assert result.path == "some/artifact" + assert isinstance(result, ArtifactField) + assert result.hash == "f192cddb2b98c0b4c72bba22b68d2245" + + +def test__field_definition__from_dict(): + # given + data = { + "name": "some/float", + "type": "float", + } + + # when + result = FieldDefinition.from_dict(data) + + # then + assert result.path == "some/float" + assert result.type == FieldType.FLOAT + + +def test__field_definition__from_model(): + # given + model = Mock( + type="float", + ) + model.name = "some/float" + + # when + result = FieldDefinition.from_model(model) + + # then + assert result.path == "some/float" + assert result.type == FieldType.FLOAT + + +def test__leaderboard_entry__from_dict(): + # given + data = { + "experimentId": "some-id", + "attributes": [ + { + "path": "some/float", + "type": "float", + "floatProperties": {"attributeType": "float", "attributeName": "some/float", "value": 18.5}, + }, + { + "path": "some/int", + "type": "int", + "intProperties": {"attributeType": "int", "attributeName": "some/int", "value": 18}, + }, + { + "path": "some/string", + "type": "string", + "stringProperties": {"attributeType": "string", "attributeName": "some/string", "value": "hello"}, + }, + ], + } + + # when + result = LeaderboardEntry.from_dict(data) + + # then + assert result.object_id == "some-id" + assert len(result.fields) == 3 + + float_field = result.fields[0] + assert isinstance(float_field, FloatField) + assert float_field.path == "some/float" + assert float_field.value == 18.5 + + int_field = result.fields[1] + assert isinstance(int_field, IntField) + assert int_field.path == "some/int" + + string_field = result.fields[2] + assert isinstance(string_field, StringField) + assert string_field.path == "some/string" + + +def test__leaderboard_entry__from_model(): + # given + model = Mock( + experimentId="some-id", + attributes=[ + Mock( + path="some/float", + type="float", + floatProperties=Mock(attributeType="float", attributeName="some/float", value=18.5), + ), + Mock( + path="some/int", type="int", intProperties=Mock(attributeType="int", attributeName="some/int", value=18) + ), + Mock( + path="some/string", + type="string", + stringProperties=Mock(attributeType="string", attributeName="some/string", value="hello"), + ), + ], + ) + + # when + result = LeaderboardEntry.from_model(model) + + # then + assert result.object_id == "some-id" + assert len(result.fields) == 3 + + float_field = result.fields[0] + assert isinstance(float_field, FloatField) + assert float_field.path == "some/float" + assert float_field.value == 18.5 + + int_field = result.fields[1] + assert isinstance(int_field, IntField) + assert int_field.path == "some/int" + + string_field = result.fields[2] + assert isinstance(string_field, StringField) + assert string_field.path == "some/string" + + +def test__leaderboard_entries_search_result__from_dict(): + # given + data = { + "matchingItemCount": 2, + "entries": [ + { + "experimentId": "some-id-1", + "attributes": [ + { + "path": "some/float", + "type": "float", + "floatProperties": {"attributeType": "float", "attributeName": "some/float", "value": 18.5}, + }, + ], + }, + { + "experimentId": "some-id-2", + "attributes": [ + { + "path": "some/int", + "type": "int", + "intProperties": {"attributeType": "int", "attributeName": "some/int", "value": 18}, + }, + ], + }, + ], + } + + # when + result = LeaderboardEntriesSearchResult.from_dict(data) + + # then + assert result.matching_item_count == 2 + assert len(result.entries) == 2 + + entry_1 = result.entries[0] + assert entry_1.object_id == "some-id-1" + assert len(entry_1.fields) == 1 + assert isinstance(entry_1.fields[0], FloatField) + + entry_2 = result.entries[1] + assert entry_2.object_id == "some-id-2" + assert len(entry_2.fields) == 1 + assert isinstance(entry_2.fields[0], IntField) + + +def test__leaderboard_entries_search_result__from_model(): + # given + model = Mock( + matchingItemCount=2, + entries=[ + Mock( + experimentId="some-id-1", + attributes=[ + Mock( + path="some/float", + type="float", + floatProperties=Mock(attributeType="float", attributeName="some/float", value=18.5), + ), + ], + ), + Mock( + experimentId="some-id-2", + attributes=[ + Mock( + path="some/int", + type="int", + intProperties=Mock(attributeType="int", attributeName="some/int", value=18), + ), + ], + ), + ], + ) + + # when + result = LeaderboardEntriesSearchResult.from_model(model) + + # then + assert result.matching_item_count == 2 + assert len(result.entries) == 2 + + entry_1 = result.entries[0] + assert entry_1.object_id == "some-id-1" + assert len(entry_1.fields) == 1 + assert isinstance(entry_1.fields[0], FloatField) + + entry_2 = result.entries[1] + assert entry_2.object_id == "some-id-2" + assert len(entry_2.fields) == 1 + assert isinstance(entry_2.fields[0], IntField) + + +@pytest.mark.parametrize("field_type", list(FieldType)) +def test__all_field_types__have_class(field_type): + # when + field_class = Field.by_type(field_type) + + # then + assert field_class is not None + assert field_class.type == field_type + + +def test__file_entry__from_model(): + # given + now = datetime.datetime.now() + + # and + model = Mock( + size=100, + mtime=now, + fileType="file", + ) + model.name = "mock_name" + + entry = FileEntry.from_dto(model) + + assert entry.name == "mock_name" + assert entry.size == 100 + assert entry.mtime == now + assert entry.file_type == "file" diff --git a/tests/unit/neptune/new/api/test_searching_entries.py b/tests/unit/neptune/new/api/test_searching_entries.py index 6260b19d1..c17e3f13d 100644 --- a/tests/unit/neptune/new/api/test_searching_entries.py +++ b/tests/unit/neptune/new/api/test_searching_entries.py @@ -14,7 +14,8 @@ # limitations under the License. # from typing import ( - List, + Any, + Dict, Sequence, ) @@ -26,17 +27,17 @@ patch, ) +from neptune.api.models import ( + FloatField, + LeaderboardEntriesSearchResult, + LeaderboardEntry, + StringField, +) from neptune.api.searching_entries import ( get_single_page, iter_over_pages, - to_leaderboard_entry, ) from neptune.exceptions import NeptuneInvalidQueryException -from neptune.internal.backends.api_model import ( - AttributeType, - AttributeWithProperties, - LeaderboardEntry, -) def test__to_leaderboard_entry(): @@ -48,6 +49,8 @@ def test__to_leaderboard_entry(): "name": "plugh", "type": "float", "floatProperties": { + "attributeName": "plugh", + "attributeType": "float", "value": 1.0, }, }, @@ -55,6 +58,8 @@ def test__to_leaderboard_entry(): "name": "sys/id", "type": "string", "stringProperties": { + "attributeName": "sys/id", + "attributeType": "string", "value": "TEST-123", }, }, @@ -62,37 +67,25 @@ def test__to_leaderboard_entry(): } # when - result = to_leaderboard_entry(entry=entry) + result = LeaderboardEntry.from_dict(entry) # then - assert result.id == "foo" - assert result.attributes == [ - AttributeWithProperties( - path="plugh", - type=AttributeType.FLOAT, - properties={ - "value": 1.0, - }, - ), - AttributeWithProperties( - path="sys/id", - type=AttributeType.STRING, - properties={ - "value": "TEST-123", - }, - ), + assert result.object_id == "foo" + assert result.fields == [ + FloatField(path="plugh", value=1.0), + StringField(path="sys/id", value="TEST-123"), ] -@patch("neptune.api.searching_entries._entries_from_page") -@patch("neptune.api.searching_entries.get_single_page", return_value={"matchingItemCount": 9}) -def test__iter_over_pages__single_pagination(get_single_page, entries_from_page): +@patch("neptune.api.searching_entries.get_single_page") +def test__iter_over_pages__single_pagination(get_single_page_mock): # given - entries_from_page.side_effect = [ + get_single_page_mock.side_effect = [ + {"matchingItemCount": 9}, generate_leaderboard_entries(values=["a", "b", "c"]), generate_leaderboard_entries(values=["d", "e", "f"]), generate_leaderboard_entries(values=["g", "h", "j"]), - [], + generate_leaderboard_entries(values=[]), ] # when @@ -101,33 +94,38 @@ def test__iter_over_pages__single_pagination(get_single_page, entries_from_page) step_size=3, limit=None, sort_by="sys/id", - sort_by_column_type=None, + sort_by_column_type="string", ascending=False, progress_bar=None, ) ) # then - assert result == generate_leaderboard_entries(values=["a", "b", "c", "d", "e", "f", "g", "h", "j"]) - assert get_single_page.mock_calls == [ + assert ( + result + == LeaderboardEntriesSearchResult.from_dict( + generate_leaderboard_entries(values=["a", "b", "c", "d", "e", "f", "g", "h", "j"]) + ).entries + ) + assert get_single_page_mock.mock_calls == [ # total checking - call(limit=0, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), - call(limit=3, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), - call(limit=3, offset=3, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), - call(limit=3, offset=6, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), - call(limit=3, offset=9, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), + call(limit=0, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), + call(limit=3, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), + call(limit=3, offset=3, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), + call(limit=3, offset=6, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), + call(limit=3, offset=9, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), ] -@patch("neptune.api.searching_entries._entries_from_page") -@patch("neptune.api.searching_entries.get_single_page", return_value={"matchingItemCount": 9}) -def test__iter_over_pages__multiple_search_after(get_single_page, entries_from_page): +@patch("neptune.api.searching_entries.get_single_page") +def test__iter_over_pages__multiple_search_after(get_single_page_mock): # given - entries_from_page.side_effect = [ + get_single_page_mock.side_effect = [ + {"matchingItemCount": 9}, generate_leaderboard_entries(values=["a", "b", "c"]), generate_leaderboard_entries(values=["d", "e", "f"]), generate_leaderboard_entries(values=["g", "h", "j"]), - [], + generate_leaderboard_entries(values=[]), ] # when @@ -136,7 +134,7 @@ def test__iter_over_pages__multiple_search_after(get_single_page, entries_from_p step_size=3, limit=None, sort_by="sys/id", - sort_by_column_type=None, + sort_by_column_type="string", ascending=False, progress_bar=None, max_offset=6, @@ -144,22 +142,29 @@ def test__iter_over_pages__multiple_search_after(get_single_page, entries_from_p ) # then - assert result == generate_leaderboard_entries(values=["a", "b", "c", "d", "e", "f", "g", "h", "j"]) - assert get_single_page.mock_calls == [ + assert ( + result + == LeaderboardEntriesSearchResult.from_dict( + generate_leaderboard_entries(values=["a", "b", "c", "d", "e", "f", "g", "h", "j"]) + ).entries + ) + assert get_single_page_mock.mock_calls == [ # total checking - call(limit=0, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), - call(limit=3, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), - call(limit=3, offset=3, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), - call(limit=3, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after="f"), - call(limit=3, offset=3, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after="f"), + call(limit=0, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), + call(limit=3, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), + call(limit=3, offset=3, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), + call(limit=3, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after="f"), + call(limit=3, offset=3, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after="f"), ] -@patch("neptune.api.searching_entries._entries_from_page") -@patch("neptune.api.searching_entries.get_single_page", return_value={"matchingItemCount": 1}) -def test__iter_over_pages__empty(get_single_page, entries_from_page): +@patch("neptune.api.searching_entries.get_single_page") +def test__iter_over_pages__empty(get_single_page_mock): # given - entries_from_page.side_effect = [[]] + get_single_page_mock.side_effect = [ + {"matchingItemCount": 0}, + generate_leaderboard_entries(values=[]), + ] # when result = list( @@ -167,7 +172,7 @@ def test__iter_over_pages__empty(get_single_page, entries_from_page): step_size=3, limit=None, sort_by="sys/id", - sort_by_column_type=None, + sort_by_column_type="string", ascending=False, progress_bar=None, ) @@ -175,21 +180,21 @@ def test__iter_over_pages__empty(get_single_page, entries_from_page): # then assert result == [] - assert get_single_page.mock_calls == [ + assert get_single_page_mock.mock_calls == [ # total checking - call(limit=0, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), - call(limit=3, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), + call(limit=0, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), + call(limit=3, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), ] -@patch("neptune.api.searching_entries._entries_from_page") -@patch("neptune.api.searching_entries.get_single_page", return_value={"matchingItemCount": 1}) -def test__iter_over_pages__max_server_offset(get_single_page, entries_from_page): +@patch("neptune.api.searching_entries.get_single_page") +def test__iter_over_pages__max_server_offset(get_single_page_mock): # given - entries_from_page.side_effect = [ + get_single_page_mock.side_effect = [ + {"matchingItemCount": 5}, generate_leaderboard_entries(values=["a", "b", "c"]), generate_leaderboard_entries(values=["d", "e"]), - [], + generate_leaderboard_entries(values=[]), ] # when @@ -198,7 +203,7 @@ def test__iter_over_pages__max_server_offset(get_single_page, entries_from_page) step_size=3, limit=None, sort_by="sys/id", - sort_by_column_type=None, + sort_by_column_type="string", ascending=False, progress_bar=None, max_offset=5, @@ -206,28 +211,33 @@ def test__iter_over_pages__max_server_offset(get_single_page, entries_from_page) ) # then - assert result == generate_leaderboard_entries(values=["a", "b", "c", "d", "e"]) - assert get_single_page.mock_calls == [ + assert ( + result + == LeaderboardEntriesSearchResult.from_dict( + generate_leaderboard_entries(values=["a", "b", "c", "d", "e"]) + ).entries + ) + assert get_single_page_mock.mock_calls == [ # total checking - call(limit=0, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), - call(offset=0, limit=3, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), - call(offset=3, limit=2, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), - call(offset=0, limit=3, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after="e"), + call(limit=0, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), + call(offset=0, limit=3, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), + call(offset=3, limit=2, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), + call(offset=0, limit=3, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after="e"), ] -@patch("neptune.api.searching_entries._entries_from_page") -@patch("neptune.api.searching_entries.get_single_page", return_value={"matchingItemCount": 5}) -def test__iter_over_pages__limit(get_single_page, entries_from_page): +@patch("neptune.api.searching_entries.get_single_page") +def test__iter_over_pages__limit(get_single_page_mock): # since the limiting itself takes place in an external service, we can't test the results # we can only test if the limit is properly passed to the external service call # given - entries_from_page.side_effect = [ + get_single_page_mock.side_effect = [ + {"matchingItemCount": 5}, generate_leaderboard_entries(values=["a", "b"]), generate_leaderboard_entries(values=["c", "d"]), generate_leaderboard_entries(values=["e"]), - [], + generate_leaderboard_entries(values=[]), ] # when @@ -236,29 +246,42 @@ def test__iter_over_pages__limit(get_single_page, entries_from_page): step_size=2, limit=4, sort_by="sys/id", - sort_by_column_type=None, + sort_by_column_type="string", ascending=False, progress_bar=None, ) ) # then - assert get_single_page.mock_calls == [ + assert get_single_page_mock.mock_calls == [ # total checking - call(limit=0, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), - call(offset=0, limit=2, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), - call(offset=2, limit=2, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), + call(limit=0, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), + call(offset=0, limit=2, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), + call(offset=2, limit=2, sort_by="sys/id", ascending=False, sort_by_column_type="string", searching_after=None), ] -def generate_leaderboard_entries(values: Sequence, experiment_id: str = "foo") -> List[LeaderboardEntry]: - return [ - LeaderboardEntry( - id=experiment_id, - attributes=[AttributeWithProperties(path="sys/id", type=AttributeType.STRING, properties={"value": value})], - ) - for value in values - ] +def generate_leaderboard_entries(values: Sequence, experiment_id: str = "foo") -> Dict[str, Any]: + return { + "matchingItemCount": len(values), + "entries": [ + { + "experimentId": f"{experiment_id}-{value}", + "attributes": [ + { + "name": "sys/id", + "type": "string", + "stringProperties": { + "attributeName": "sys/id", + "attributeType": "string", + "value": value, + }, + }, + ], + } + for value in values + ], + } @patch("neptune.api.searching_entries.construct_request") diff --git a/tests/unit/neptune/new/attributes/test_attribute_utils.py b/tests/unit/neptune/new/attributes/test_attribute_utils.py index 4bbfc7735..5839101e2 100644 --- a/tests/unit/neptune/new/attributes/test_attribute_utils.py +++ b/tests/unit/neptune/new/attributes/test_attribute_utils.py @@ -16,9 +16,9 @@ import unittest from unittest.mock import MagicMock +from neptune.api.models import FieldType from neptune.attributes import create_attribute_from_type from neptune.attributes.attribute import Attribute -from neptune.internal.backends.api_model import AttributeType class TestAttributeUtils(unittest.TestCase): @@ -27,7 +27,6 @@ def test_attribute_type_to_atom(self): # ... and this reflection is class based on `Attribute` self.assertTrue( all( - isinstance(create_attribute_from_type(attr_type, MagicMock(), ""), Attribute) - for attr_type in AttributeType + isinstance(create_attribute_from_type(attr_type, MagicMock(), ""), Attribute) for attr_type in FieldType ) ) diff --git a/tests/unit/neptune/new/client/abstract_tables_test.py b/tests/unit/neptune/new/client/abstract_tables_test.py index 1a5101b8d..0eee972b5 100644 --- a/tests/unit/neptune/new/client/abstract_tables_test.py +++ b/tests/unit/neptune/new/client/abstract_tables_test.py @@ -23,17 +23,28 @@ from mock import patch from neptune import ANONYMOUS_API_TOKEN +from neptune.api.models import ( + DateTimeField, + FieldDefinition, + FieldType, + FileField, + FileSetField, + FloatField, + FloatSeriesField, + GitCommit, + GitRefField, + ImageSeriesField, + LeaderboardEntry, + ObjectStateField, + StringField, + StringSeriesField, + StringSetField, +) from neptune.envs import ( API_TOKEN_ENV_NAME, PROJECT_ENV_NAME, ) from neptune.exceptions import MetadataInconsistency -from neptune.internal.backends.api_model import ( - Attribute, - AttributeType, - AttributeWithProperties, - LeaderboardEntry, -) from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock from neptune.table import ( Table, @@ -43,7 +54,7 @@ @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_attributes", - new=lambda _, _uuid, _type: [Attribute(path="test", type=AttributeType.STRING)], + new=lambda _, _uuid, _type: [FieldDefinition(path="test", type=FieldType.STRING)], ) @patch("neptune.internal.backends.factory.HostedNeptuneBackend", NeptuneBackendMock) class AbstractTablesTestMixin: @@ -67,26 +78,20 @@ def setUp(cls) -> None: del os.environ[PROJECT_ENV_NAME] @staticmethod - def build_attributes_leaderboard(now: datetime): - attributes = [] - attributes.append(AttributeWithProperties("run/state", AttributeType.RUN_STATE, {"value": "idle"})) - attributes.append(AttributeWithProperties("float", AttributeType.FLOAT, {"value": 12.5})) - attributes.append(AttributeWithProperties("string", AttributeType.STRING, {"value": "some text"})) - attributes.append(AttributeWithProperties("datetime", AttributeType.DATETIME, {"value": now})) - attributes.append(AttributeWithProperties("float/series", AttributeType.FLOAT_SERIES, {"last": 8.7})) - attributes.append(AttributeWithProperties("string/series", AttributeType.STRING_SERIES, {"last": "last text"})) - attributes.append(AttributeWithProperties("string/set", AttributeType.STRING_SET, {"values": ["a", "b"]})) - attributes.append( - AttributeWithProperties( - "git/ref", - AttributeType.GIT_REF, - {"commit": {"commitId": "abcdef0123456789"}}, - ) - ) - attributes.append(AttributeWithProperties("file", AttributeType.FILE, None)) - attributes.append(AttributeWithProperties("file/set", AttributeType.FILE_SET, None)) - attributes.append(AttributeWithProperties("image/series", AttributeType.IMAGE_SERIES, None)) - return attributes + def build_fields_leaderboard(now: datetime): + return [ + ObjectStateField(path="run/state", value="Inactive"), + FloatField(path="float", value=12.5), + StringField(path="string", value="some text"), + DateTimeField(path="datetime", value=now), + FloatSeriesField(path="float/series", last=8.7), + StringSeriesField(path="string/series", last="last text"), + StringSetField(path="string/set", values={"a", "b"}), + GitRefField(path="git/ref", commit=GitCommit(commit_id="abcdef0123456789")), + FileField(path="file", size=0, name="file.txt", ext="txt"), + FileSetField(path="file/set", size=0), + ImageSeriesField(path="image/series", last_step=None), + ] @patch.object(NeptuneBackendMock, "search_leaderboard_entries") def test_get_table_with_columns_filter(self, search_leaderboard_entries): @@ -102,11 +107,11 @@ def test_get_table_with_columns_filter(self, search_leaderboard_entries): def test_get_table_as_pandas(self, search_leaderboard_entries): # given now = datetime.now() - attributes = self.build_attributes_leaderboard(now) + fields = self.build_fields_leaderboard(now) # and - empty_entry = LeaderboardEntry(str(uuid.uuid4()), []) - filled_entry = LeaderboardEntry(str(uuid.uuid4()), attributes) + empty_entry = LeaderboardEntry(object_id=str(uuid.uuid4()), fields=[]) + filled_entry = LeaderboardEntry(object_id=str(uuid.uuid4()), fields=fields) search_leaderboard_entries.return_value = [empty_entry, filled_entry] # when @@ -119,7 +124,7 @@ def test_get_table_as_pandas(self, search_leaderboard_entries): self.assertEqual(now, df["datetime"][1]) self.assertEqual(8.7, df["float/series"][1]) self.assertEqual("last text", df["string/series"][1]) - self.assertEqual("a,b", df["string/set"][1]) + self.assertEqual({"a", "b"}, set(df["string/set"][1].split(","))) self.assertEqual("abcdef0123456789", df["git/ref"][1]) with self.assertRaises(KeyError): @@ -133,11 +138,11 @@ def test_get_table_as_pandas(self, search_leaderboard_entries): def test_get_table_as_rows(self, search_leaderboard_entries): # given now = datetime.now() - attributes = self.build_attributes_leaderboard(now) + fields = self.build_fields_leaderboard(now) # and - empty_entry = LeaderboardEntry(str(uuid.uuid4()), []) - filled_entry = LeaderboardEntry(str(uuid.uuid4()), attributes) + empty_entry = LeaderboardEntry(object_id=str(uuid.uuid4()), fields=[]) + filled_entry = LeaderboardEntry(object_id=str(uuid.uuid4()), fields=fields) search_leaderboard_entries.return_value = [empty_entry, filled_entry] # and @@ -173,10 +178,10 @@ def test_get_table_as_table_entries( # given exp_id = str(uuid.uuid4()) now = datetime.now() - attributes = self.build_attributes_leaderboard(now) + fields = self.build_fields_leaderboard(now) # and - search_leaderboard_entries.return_value = [LeaderboardEntry(exp_id, attributes)] + search_leaderboard_entries.return_value = [LeaderboardEntry(object_id=exp_id, fields=fields)] # when table_entry = self.get_table_entries(table=self.get_table())[0] diff --git a/tests/unit/neptune/new/client/test_model.py b/tests/unit/neptune/new/client/test_model.py index d34982c9c..30acb6c40 100644 --- a/tests/unit/neptune/new/client/test_model.py +++ b/tests/unit/neptune/new/client/test_model.py @@ -23,6 +23,11 @@ ANONYMOUS_API_TOKEN, init_model, ) +from neptune.api.models import ( + FieldDefinition, + FieldType, + IntField, +) from neptune.attributes import String from neptune.envs import ( API_TOKEN_ENV_NAME, @@ -32,13 +37,9 @@ NeptuneUnsupportedFunctionalityException, NeptuneWrongInitParametersException, ) -from neptune.internal.backends.api_model import ( - Attribute, - AttributeType, - IntAttribute, -) from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock from neptune.internal.exceptions import NeptuneException +from neptune.internal.utils.paths import path_to_str from neptune.internal.warnings import ( NeptuneWarning, warned_once, @@ -77,11 +78,11 @@ def test_offline_mode(self): ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_attributes", - new=lambda _, _uuid, _type: [Attribute("some/variable", AttributeType.INT)], + new=lambda _, _uuid, _type: [FieldDefinition("some/variable", FieldType.INT)], ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_int_attribute", - new=lambda _, _uuid, _type, _path: IntAttribute(42), + new=lambda _, _uuid, _type, _path: IntField(path=path_to_str(_path), value=42), ) @patch("neptune.internal.operation_processors.read_only_operation_processor.warn_once") def test_read_only_mode(self, warn_once): @@ -102,7 +103,7 @@ def test_read_only_mode(self, warn_once): ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_attributes", - new=lambda _, _uuid, _type: [Attribute("test", AttributeType.STRING)], + new=lambda _, _uuid, _type: [FieldDefinition("test", FieldType.STRING)], ) def test_resume(self): with init_model(flush_period=0.5, with_id="whatever") as exp: diff --git a/tests/unit/neptune/new/client/test_model_version.py b/tests/unit/neptune/new/client/test_model_version.py index 9583dcadd..14ca4404b 100644 --- a/tests/unit/neptune/new/client/test_model_version.py +++ b/tests/unit/neptune/new/client/test_model_version.py @@ -23,6 +23,12 @@ ANONYMOUS_API_TOKEN, init_model_version, ) +from neptune.api.models import ( + FieldDefinition, + FieldType, + IntField, + StringField, +) from neptune.attributes import String from neptune.envs import ( API_TOKEN_ENV_NAME, @@ -33,15 +39,10 @@ NeptuneUnsupportedFunctionalityException, NeptuneWrongInitParametersException, ) -from neptune.internal.backends.api_model import ( - Attribute, - AttributeType, - IntAttribute, - StringAttribute, -) from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock from neptune.internal.container_type import ContainerType from neptune.internal.exceptions import NeptuneException +from neptune.internal.utils.paths import path_to_str from neptune.internal.warnings import ( NeptuneWarning, warned_once, @@ -87,17 +88,17 @@ def test_offline_mode(self): @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_attributes", new=lambda _, _uuid, _type: [ - Attribute("some/variable", AttributeType.INT), - Attribute("sys/model_id", AttributeType.STRING), + FieldDefinition("some/variable", FieldType.INT), + FieldDefinition("sys/model_id", FieldType.STRING), ], ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_int_attribute", - new=lambda _, _uuid, _type, _path: IntAttribute(42), + new=lambda _, _uuid, _type, _path: IntField(path=path_to_str(_path), value=42), ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_string_attribute", - new=lambda _, _uuid, _type, _path: StringAttribute("MDL"), + new=lambda _, _uuid, _type, _path: StringField(path=path_to_str(_path), value="MDL"), ) @patch("neptune.internal.operation_processors.read_only_operation_processor.warn_once") def test_read_only_mode(self, warn_once): @@ -115,13 +116,13 @@ def test_read_only_mode(self, warn_once): @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_attributes", new=lambda _, _uuid, _type: [ - Attribute("test", AttributeType.STRING), - Attribute("sys/model_id", AttributeType.STRING), + FieldDefinition("test", FieldType.STRING), + FieldDefinition("sys/model_id", FieldType.STRING), ], ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_string_attribute", - new=lambda _, _uuid, _type, _path: StringAttribute("MDL"), + new=lambda _, _uuid, _type, _path: StringField(path=path_to_str(_path), value="MDL"), ) def test_resume(self): with init_model_version(flush_period=0.5, with_id="whatever") as exp: diff --git a/tests/unit/neptune/new/client/test_project.py b/tests/unit/neptune/new/client/test_project.py index 41c8d81df..e73284c8c 100644 --- a/tests/unit/neptune/new/client/test_project.py +++ b/tests/unit/neptune/new/client/test_project.py @@ -15,7 +15,6 @@ # import os import unittest -from datetime import datetime import pytest from mock import patch @@ -24,6 +23,11 @@ ANONYMOUS_API_TOKEN, init_project, ) +from neptune.api.models import ( + FieldDefinition, + FieldType, + IntField, +) from neptune.envs import ( API_TOKEN_ENV_NAME, PROJECT_ENV_NAME, @@ -32,29 +36,20 @@ NeptuneMissingProjectNameException, NeptuneUnsupportedFunctionalityException, ) -from neptune.internal.backends.api_model import ( - Attribute, - AttributeType, - AttributeWithProperties, - IntAttribute, - LeaderboardEntry, -) from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock from neptune.internal.exceptions import NeptuneException +from neptune.internal.utils.paths import path_to_str from neptune.internal.warnings import ( NeptuneWarning, warned_once, ) -from neptune.objects.utils import ( - parse_dates, - prepare_nql_query, -) +from neptune.objects.utils import prepare_nql_query from tests.unit.neptune.new.client.abstract_experiment_test_mixin import AbstractExperimentTestMixin @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_attributes", - new=lambda _, _uuid, _type: [Attribute("test", AttributeType.STRING)], + new=lambda _, _uuid, _type: [FieldDefinition("test", FieldType.STRING)], ) @patch("neptune.internal.backends.factory.HostedNeptuneBackend", NeptuneBackendMock) class TestClientProject(AbstractExperimentTestMixin, unittest.TestCase): @@ -100,7 +95,7 @@ def test_project_name_env_var(self): @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_int_attribute", - new=lambda _, _uuid, _type, _path: IntAttribute(42), + new=lambda _, _uuid, _type, _path: IntField(value=42, path=path_to_str(_path)), ) @patch("neptune.internal.operation_processors.read_only_operation_processor.warn_once") def test_read_only_mode(self, warn_once): @@ -168,49 +163,3 @@ def test_prepare_nql_query(): trashed=None, ) assert len(query.items) == 0 - - -def test_parse_dates(): - def entries_generator(): - yield LeaderboardEntry( - id="test", - attributes=[ - AttributeWithProperties( - "attr1", - AttributeType.DATETIME, - {"value": "2024-02-05T20:37:40.915000Z"}, - ), - AttributeWithProperties( - "attr2", - AttributeType.DATETIME, - {"value": "2024-02-05T20:37:40.915000Z"}, - ), - ], - ) - - parsed = list(parse_dates(entries_generator())) - assert parsed[0].attributes[0].properties["value"] == datetime(2024, 2, 5, 20, 37, 40, 915000) - assert parsed[0].attributes[1].properties["value"] == datetime(2024, 2, 5, 20, 37, 40, 915000) - - -@patch("neptune.objects.utils.warn_once") -def test_parse_dates_wrong_format(mock_warn_once): - entries = [ - LeaderboardEntry( - id="test", - attributes=[ - AttributeWithProperties( - "attr1", - AttributeType.DATETIME, - {"value": "07-02-2024"}, # different format than expected - ) - ], - ) - ] - - parsed = list(parse_dates(entries)) - assert parsed[0].attributes[0].properties["value"] == "07-02-2024" # should be left unchanged due to ValueError - mock_warn_once.assert_called_once_with( - "Date parsing failed. The date format is incorrect. Returning as string instead of datetime.", - exception=NeptuneWarning, - ) diff --git a/tests/unit/neptune/new/client/test_run.py b/tests/unit/neptune/new/client/test_run.py index e515f2d85..765d724bc 100644 --- a/tests/unit/neptune/new/client/test_run.py +++ b/tests/unit/neptune/new/client/test_run.py @@ -27,18 +27,19 @@ ANONYMOUS_API_TOKEN, init_run, ) +from neptune.api.models import ( + FieldDefinition, + FieldType, + IntField, +) from neptune.attributes.atoms import String from neptune.envs import ( API_TOKEN_ENV_NAME, PROJECT_ENV_NAME, ) from neptune.exceptions import MissingFieldException -from neptune.internal.backends.api_model import ( - Attribute, - AttributeType, - IntAttribute, -) from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock +from neptune.internal.utils.paths import path_to_str from neptune.internal.utils.utils import IS_WINDOWS from neptune.internal.warnings import ( NeptuneWarning, @@ -68,11 +69,11 @@ def setUpClass(cls) -> None: ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_attributes", - new=lambda _, _uuid, _type: [Attribute("some/variable", AttributeType.INT)], + new=lambda _, _uuid, _type: [FieldDefinition("some/variable", FieldType.INT)], ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_int_attribute", - new=lambda _, _uuid, _type, _path: IntAttribute(42), + new=lambda _, _uuid, _type, _path: IntField(value=42, path=path_to_str(_path)), ) @patch("neptune.internal.operation_processors.read_only_operation_processor.warn_once") def test_read_only_mode(self, warn_once): @@ -93,7 +94,7 @@ def test_read_only_mode(self, warn_once): ) @patch( "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.get_attributes", - new=lambda _, _uuid, _type: [Attribute("test", AttributeType.STRING)], + new=lambda _, _uuid, _type: [FieldDefinition("test", FieldType.STRING)], ) def test_resume(self): with init_run(flush_period=0.5, with_id="whatever") as exp: diff --git a/tests/unit/neptune/new/client/test_run_tables.py b/tests/unit/neptune/new/client/test_run_tables.py index 8b5a75869..d1fe07f5f 100644 --- a/tests/unit/neptune/new/client/test_run_tables.py +++ b/tests/unit/neptune/new/client/test_run_tables.py @@ -21,9 +21,8 @@ from mock import patch from neptune import init_project -from neptune.internal.backends.api_model import ( - AttributeType, - AttributeWithProperties, +from neptune.api.models import ( + DateTimeField, LeaderboardEntry, ) from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock @@ -66,12 +65,11 @@ def test_fetch_runs_table_raises_correct_exception_for_incorrect_states(self): "neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.search_leaderboard_entries", new=lambda *args, **kwargs: [ LeaderboardEntry( - id="123", - attributes=[ - AttributeWithProperties( - "sys/creation_time", - AttributeType.DATETIME, - {"value": "2024-02-05T20:37:40.915000Z"}, + object_id="123", + fields=[ + DateTimeField( + path="sys/creation_time", + value=datetime(2024, 2, 5, 20, 37, 40, 915000), ) ], ) diff --git a/tests/unit/neptune/new/internal/backends/test_hosted_client.py b/tests/unit/neptune/new/internal/backends/test_hosted_client.py index aefe62b1a..ba77f1e72 100644 --- a/tests/unit/neptune/new/internal/backends/test_hosted_client.py +++ b/tests/unit/neptune/new/internal/backends/test_hosted_client.py @@ -32,7 +32,7 @@ patch, ) -from neptune.internal.backends.api_model import AttributeType +from neptune.api.models import FieldType from neptune.internal.backends.hosted_client import ( DEFAULT_REQUEST_KWARGS, _get_token_client, @@ -545,15 +545,15 @@ class DTO: # when test_cases = [ {"entries": [], "exc": ValueError}, - {"entries": [DTO(type="float")], "result": AttributeType.FLOAT.value}, - {"entries": [DTO(type="string")], "result": AttributeType.STRING.value}, + {"entries": [DTO(type="float")], "result": FieldType.FLOAT.value}, + {"entries": [DTO(type="string")], "result": FieldType.STRING.value}, {"entries": [DTO(type="float"), DTO(type="floatSeries")], "exc": ValueError}, - {"entries": [DTO(type="float"), DTO(type="int")], "result": AttributeType.FLOAT.value}, - {"entries": [DTO(type="float"), DTO(type="int"), DTO(type="datetime")], "result": AttributeType.STRING.value}, - {"entries": [DTO(type="float"), DTO(type="int"), DTO(type="string")], "result": AttributeType.STRING.value}, + {"entries": [DTO(type="float"), DTO(type="int")], "result": FieldType.FLOAT.value}, + {"entries": [DTO(type="float"), DTO(type="int"), DTO(type="datetime")], "result": FieldType.STRING.value}, + {"entries": [DTO(type="float"), DTO(type="int"), DTO(type="string")], "result": FieldType.STRING.value}, { "entries": [DTO(type="float"), DTO(type="int"), DTO(type="string", name="test_column_different")], - "result": AttributeType.FLOAT.value, + "result": FieldType.FLOAT.value, }, ] diff --git a/tests/unit/neptune/new/internal/backends/test_neptune_backend_mock.py b/tests/unit/neptune/new/internal/backends/test_neptune_backend_mock.py index a83f7da84..b8ff429d8 100644 --- a/tests/unit/neptune/new/internal/backends/test_neptune_backend_mock.py +++ b/tests/unit/neptune/new/internal/backends/test_neptune_backend_mock.py @@ -21,22 +21,24 @@ from pathlib import Path from time import time +from neptune.api.models import ( + DateTimeField, + FloatField, + FloatSeriesField, + StringField, + StringSeriesField, + StringSetField, +) from neptune.core.components.operation_storage import OperationStorage from neptune.exceptions import ( ContainerUUIDNotFound, MetadataInconsistency, ) from neptune.internal.backends.api_model import ( - DatetimeAttribute, - FloatAttribute, FloatPointValue, - FloatSeriesAttribute, FloatSeriesValues, - StringAttribute, StringPointValue, - StringSeriesAttribute, StringSeriesValues, - StringSetAttribute, ) from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock from neptune.internal.container_type import ContainerType @@ -78,36 +80,38 @@ def test_get_float_attribute(self): with self.subTest(f"For containerType: {container_type}"): # given digit = random.randint(1, 10**4) + path = ["x"] self.backend.execute_operations( container_id, container_type, - operations=[AssignFloat(["x"], digit)], + operations=[AssignFloat(path, digit)], operation_storage=self.dummy_operation_storage, ) # when - ret = self.backend.get_float_attribute(container_id, container_type, path=["x"]) + ret = self.backend.get_float_attribute(container_id, container_type, path=path) # then - self.assertEqual(FloatAttribute(digit), ret) + self.assertEqual(FloatField(path="x", value=digit), ret) def test_get_string_attribute(self): for container_id, container_type in self.ids_with_types: with self.subTest(f"For containerType: {container_type}"): # given text = a_string() + path = ["x"] self.backend.execute_operations( container_id, container_type, - operations=[AssignString(["x"], text)], + operations=[AssignString(path, text)], operation_storage=self.dummy_operation_storage, ) # when - ret = self.backend.get_string_attribute(container_id, container_type, path=["x"]) + ret = self.backend.get_string_attribute(container_id, container_type, path=path) # then - self.assertEqual(StringAttribute(text), ret) + self.assertEqual(StringField(path="x", value=text), ret) def test_get_datetime_attribute(self): for container_id, container_type in self.ids_with_types: @@ -115,21 +119,27 @@ def test_get_datetime_attribute(self): # given now = datetime.datetime.now() now = now.replace(microsecond=1000 * int(now.microsecond / 1000)) + path = ["x"] + + # and self.backend.execute_operations( container_id, container_type, - [AssignDatetime(["x"], now)], + [AssignDatetime(path, now)], operation_storage=self.dummy_operation_storage, ) # when - ret = self.backend.get_datetime_attribute(container_id, container_type, ["x"]) + ret = self.backend.get_datetime_attribute(container_id, container_type, path) # then - self.assertEqual(DatetimeAttribute(now), ret) + self.assertEqual(DateTimeField(path="x", value=now), ret) def test_get_float_series_attribute(self): # given + path = ["x"] + + # and for container_id, container_type in self.ids_with_types: with self.subTest(f"For containerType: {container_type}"): self.backend.execute_operations( @@ -137,7 +147,7 @@ def test_get_float_series_attribute(self): container_type, [ LogFloats( - ["x"], + path, [ LogFloats.ValueType(5, None, time()), LogFloats.ValueType(3, None, time()), @@ -151,7 +161,7 @@ def test_get_float_series_attribute(self): container_type, [ LogFloats( - ["x"], + path, [ LogFloats.ValueType(2, None, time()), LogFloats.ValueType(9, None, time()), @@ -162,13 +172,16 @@ def test_get_float_series_attribute(self): ) # when - ret = self.backend.get_float_series_attribute(container_id, container_type, ["x"]) + ret = self.backend.get_float_series_attribute(container_id, container_type, path) # then - self.assertEqual(FloatSeriesAttribute(9), ret) + self.assertEqual(FloatSeriesField(last=9, path="x"), ret) def test_get_string_series_attribute(self): # given + path = ["x"] + + # and for container_id, container_type in self.ids_with_types: with self.subTest(f"For containerType: {container_type}"): self.backend.execute_operations( @@ -176,7 +189,7 @@ def test_get_string_series_attribute(self): container_type, [ LogStrings( - ["x"], + path, [ LogStrings.ValueType("adf", None, time()), LogStrings.ValueType("sdg", None, time()), @@ -190,7 +203,7 @@ def test_get_string_series_attribute(self): container_type, [ LogStrings( - ["x"], + path, [ LogStrings.ValueType("dfh", None, time()), LogStrings.ValueType("qwe", None, time()), @@ -201,27 +214,30 @@ def test_get_string_series_attribute(self): ) # when - ret = self.backend.get_string_series_attribute(container_id, container_type, ["x"]) + ret = self.backend.get_string_series_attribute(container_id, container_type, path) # then - self.assertEqual(StringSeriesAttribute("qwe"), ret) + self.assertEqual(StringSeriesField(last="qwe", path="x"), ret) def test_get_string_set_attribute(self): # given + path = ["x"] + + # and for container_id, container_type in self.ids_with_types: with self.subTest(f"For containerType: {container_type}"): self.backend.execute_operations( container_id, container_type, - [AddStrings(["x"], {"abcx", "qwe"})], + [AddStrings(path, {"abcx", "qwe"})], operation_storage=self.dummy_operation_storage, ) # when - ret = self.backend.get_string_set_attribute(container_id, container_type, ["x"]) + ret = self.backend.get_string_set_attribute(container_id, container_type, path) # then - self.assertEqual(StringSetAttribute({"abcx", "qwe"}), ret) + self.assertEqual(StringSetField(values={"abcx", "qwe"}, path="x"), ret) def test_get_string_series_values(self): # given diff --git a/tests/unit/neptune/new/internal/test_file_entry.py b/tests/unit/neptune/new/internal/test_file_entry.py deleted file mode 100644 index ba3d9844c..000000000 --- a/tests/unit/neptune/new/internal/test_file_entry.py +++ /dev/null @@ -1,24 +0,0 @@ -import datetime -from dataclasses import dataclass - -from neptune.api.dtos import FileEntry - - -def test_file_entry_from_dto(): - now = datetime.datetime.now() - - @dataclass - class MockDto: - name: str - size: int - mtime: datetime.datetime - fileType: str - - dto = MockDto("mock_name", 100, now, "file") - - entry = FileEntry.from_dto(dto) - - assert entry.name == "mock_name" - assert entry.size == 100 - assert entry.mtime == now - assert entry.file_type == "file"