Skip to content

Commit

Permalink
Add iterators for UUID nodes (#439)
Browse files Browse the repository at this point in the history
* add untested iterator functionality

* add alphabetical order

* strictly enforce that UUID is a str

* right order of decorators

* fix deepcopy problem

* Fix uid cycle deserialization (#442)
  • Loading branch information
InnocentBug authored Mar 14, 2024
1 parent 3cf3c4e commit 74104b1
Show file tree
Hide file tree
Showing 24 changed files with 826 additions and 318 deletions.
50 changes: 15 additions & 35 deletions src/cript/nodes/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
import re
import uuid
from abc import ABC
from dataclasses import asdict, dataclass, replace
from dataclasses import dataclass, replace
from typing import Dict, List, Optional, Set

from cript.nodes.exceptions import (
CRIPTAttributeModificationError,
CRIPTExtraJsonAttributes,
CRIPTJsonSerializationError,
)
from cript.nodes.node_iterator import NodeIterator

tolerated_extra_json = []

Expand Down Expand Up @@ -102,7 +103,7 @@ def __str__(self) -> str:
str
A string representation of the node.
"""
return str(asdict(self._json_attrs))
return str(self._json_attrs)

@property
def uid(self):
Expand Down Expand Up @@ -200,6 +201,7 @@ def _from_json(cls, json_dict: dict):
attrs = replace(attrs, uid="_:" + attrs.uid)
except AttributeError:
pass

# But here we force even usually unwritable fields to be set.
node._update_json_attrs_if_valid(attrs)

Expand Down Expand Up @@ -466,7 +468,7 @@ class ReturnTuple:
if is_patch:
del tmp_dict["uuid"] # patches do not allow UUID is the parent most node

return ReturnTuple(json.dumps(tmp_dict), tmp_dict, NodeEncoder.handled_ids)
return ReturnTuple(json.dumps(tmp_dict, **kwargs), tmp_dict, NodeEncoder.handled_ids)
except Exception as exc:
# TODO this handling doesn't tell the user what happened or how they can fix it;
# it just tells the user that something is wrong
Expand Down Expand Up @@ -600,40 +602,18 @@ def is_attr_present(node: BaseNode, key, value):
if handled_nodes is None:
handled_nodes = []

# Protect against cycles in graph, by handling every instance of a node only once
if self in handled_nodes:
return []
handled_nodes += [self]

found_children = []

# In this search we include the calling node itself.
# We check for this node if all specified attributes are present by counting them (AND condition).
found_attr = 0
for key, value in search_attr.items():
if is_attr_present(self, key, value):
found_attr += 1
# If exactly all attributes are found, it matches the search criterion
if found_attr == len(search_attr):
found_children += [self]

# Recursion according to the recursion depth for all node children.
if search_depth != 0:
# Loop over all attributes, runtime contribution (none, or constant (max number of attributes of a node)
for field in self._json_attrs.__dataclass_fields__:
value = getattr(self._json_attrs, field)
# To save code paths, I convert non-lists into lists with one element.
if not isinstance(value, list):
value = [value]
# Run time contribution: number of elements in the attribute list.
for v in value:
try: # Try every attribute for recursion (duck-typing)
found_children += v.find_children(search_attr, search_depth - 1, handled_nodes=handled_nodes)
except AttributeError:
pass
# Total runtime, of non-recursive call: O(m*h) + O(k) where k is the number of children for this node,
# h being the depth of the search dictionary, m being the number of nodes in the attribute list.
# Total runtime, with recursion: O(n*(k+m*h)). A full graph traversal O(n) with a cost per node that scales with the number of children per node and the search depth of the search dictionary.
node_iterator = NodeIterator(self, search_depth)
for node in node_iterator:
found_attr = 0
for key, value in search_attr.items():
if is_attr_present(node, key, value):
found_attr += 1
# If exactly all attributes are found, it matches the search criterion
if found_attr == len(search_attr):
found_children += [node]

return found_children

def remove_child(self, child) -> bool:
Expand Down
68 changes: 68 additions & 0 deletions src/cript/nodes/node_iterator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from dataclasses import fields
from typing import Any, List, Set


class NodeIterator:
    """Depth-first, pre-order sequence of the nodes reachable from a root node.

    The traversal is computed eagerly in the constructor and stored, so the
    object can afterwards be iterated, indexed and measured with ``len``.
    Nodes are identified by their ``uuid`` attribute and each uuid is reported
    at most once, which also protects against cycles in the graph.

    Children are discovered through the dataclass fields of
    ``node._json_attrs``; fields are visited in alphabetical order of their
    names, and list-valued fields are visited element by element. Any value
    without a ``uuid`` attribute is ignored.

    Parameters
    ----------
    root
        Node the traversal starts from. It must expose ``uuid`` and
        ``_json_attrs``; it is always the first element of the result.
    max_recursion_depth : int
        Maximum distance from ``root`` whose nodes are still expanded.
        A negative value (the default ``-1``) means no limit; ``0`` yields
        only ``root``.

    Notes
    -----
    The object is its own iterator: ``iter(obj)`` rewinds and returns the same
    object, so two simultaneous loops over one instance share a position.
    """

    def __init__(self, root, max_recursion_depth=-1):
        self._iter_position: int = 0
        self._uuid_visited: Set[str] = set()
        self._stack: List[Any] = []
        # Depth at which the node with the same index in ``_stack`` was found.
        self._recursion_depth: List[int] = []
        self._max_recursion_depth = max_recursion_depth
        self._depth_first(root, 0)

    def _add_node(self, child_node, recursion_depth: int):
        """Record ``child_node`` unless its uuid was already seen.

        Returns ``True`` when the node was appended, ``False`` otherwise.
        """
        if child_node.uuid not in self._uuid_visited:
            self._stack.append(child_node)
            self._recursion_depth.append(recursion_depth)
            self._uuid_visited.add(child_node.uuid)
            return True
        return False

    def _check_recursion(self, child_node) -> bool:
        """Tell whether ``child_node`` is a candidate for traversal.

        A candidate has a ``uuid`` attribute that has not been visited yet.
        This check does not modify the iterator. It works for both elements
        of list attributes and plain attribute values.
        """
        try:
            uuid = child_node.uuid
        except AttributeError:
            # Plain values (str, None, numbers, ...) are not nodes.
            return False

        return uuid not in self._uuid_visited

    def _depth_first(self, node, recursion_depth: int) -> None:
        """Traverse from ``node`` in depth-first pre-order and store the result.

        An explicit work stack is used instead of Python recursion so that
        deep graphs cannot exhaust the interpreter's recursion limit.
        """
        pending = [(node, recursion_depth)]
        while pending:
            current, depth = pending.pop()

            # A node may have been reached via another path after it was
            # queued; `_add_node` rejects it in that case.
            if not self._add_node(current, depth):
                continue

            if self._max_recursion_depth >= 0 and depth >= self._max_recursion_depth:
                continue

            children = []
            json_attrs = current._json_attrs
            for attr_name in sorted(field.name for field in fields(json_attrs)):
                attr = getattr(json_attrs, attr_name)
                if not isinstance(attr, list):
                    attr = [attr]
                children.extend(child for child in attr if self._check_recursion(child))

            # Reverse so that the first child is popped (and explored) first,
            # matching the order a recursive traversal would produce.
            pending.extend((child, depth + 1) for child in reversed(children))

    def __next__(self):
        if self._iter_position >= len(self._stack):
            raise StopIteration
        self._iter_position += 1
        return self._stack[self._iter_position - 1]

    def __iter__(self):
        # Rewind so the same object can be looped over more than once.
        self._iter_position = 0
        return self

    def __len__(self):
        return len(self._stack)

    def __getitem__(self, idx: int):
        return self._stack[idx]
21 changes: 12 additions & 9 deletions src/cript/nodes/primary_nodes/collection.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from dataclasses import dataclass, field, replace
from typing import Any, List, Optional
from typing import Any, List, Optional, Union

from beartype import beartype

from cript.nodes.primary_nodes.primary_base_node import PrimaryBaseNode
from cript.nodes.supporting_nodes import User
from cript.nodes.util.json import UIDProxy


class Collection(PrimaryBaseNode):
Expand Down Expand Up @@ -56,17 +57,19 @@ class JsonAttributes(PrimaryBaseNode.JsonAttributes):
"""

# TODO add proper typing in future, using Any for now to avoid circular import error
member: List[User] = field(default_factory=list)
admin: List[User] = field(default_factory=list)
experiment: List[Any] = field(default_factory=list)
inventory: List[Any] = field(default_factory=list)
member: List[Union[User, UIDProxy]] = field(default_factory=list)
admin: List[Union[User, UIDProxy]] = field(default_factory=list)
experiment: List[Union[Any, UIDProxy]] = field(default_factory=list)
inventory: List[Union[Any, UIDProxy]] = field(default_factory=list)
doi: str = ""
citation: List[Any] = field(default_factory=list)
citation: List[Union[Any, UIDProxy]] = field(default_factory=list)

_json_attrs: JsonAttributes = JsonAttributes()

@beartype
def __init__(self, name: str, experiment: Optional[List[Any]] = None, inventory: Optional[List[Any]] = None, doi: str = "", citation: Optional[List[Any]] = None, notes: str = "", **kwargs) -> None:
def __init__(
self, name: str, experiment: Optional[List[Union[Any, UIDProxy]]] = None, inventory: Optional[List[Union[Any, UIDProxy]]] = None, doi: str = "", citation: Optional[List[Union[Any, UIDProxy]]] = None, notes: str = "", **kwargs
) -> None:
"""
create a Collection with a name
add list of experiment, inventory, citation, doi, and notes if available.
Expand Down Expand Up @@ -117,12 +120,12 @@ def __init__(self, name: str, experiment: Optional[List[Any]] = None, inventory:

@property
@beartype
def member(self) -> List[User]:
def member(self) -> List[Union[User, UIDProxy]]:
return self._json_attrs.member.copy()

@property
@beartype
def admin(self) -> List[User]:
def admin(self) -> List[Union[User, UIDProxy]]:
return self._json_attrs.admin

@property
Expand Down
31 changes: 16 additions & 15 deletions src/cript/nodes/primary_nodes/computation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from dataclasses import dataclass, field, replace
from typing import Any, List, Optional
from typing import Any, List, Optional, Union

from beartype import beartype

from cript.nodes.primary_nodes.primary_base_node import PrimaryBaseNode
from cript.nodes.util.json import UIDProxy


class Computation(PrimaryBaseNode):
Expand Down Expand Up @@ -64,12 +65,12 @@ class JsonAttributes(PrimaryBaseNode.JsonAttributes):

type: str = ""
# TODO add proper typing in future, using Any for now to avoid circular import error
input_data: List[Any] = field(default_factory=list)
output_data: List[Any] = field(default_factory=list)
software_configuration: List[Any] = field(default_factory=list)
condition: List[Any] = field(default_factory=list)
prerequisite_computation: Optional["Computation"] = None
citation: List[Any] = field(default_factory=list)
input_data: List[Union[Any, UIDProxy]] = field(default_factory=list)
output_data: List[Union[Any, UIDProxy]] = field(default_factory=list)
software_configuration: List[Union[Any, UIDProxy]] = field(default_factory=list)
condition: List[Union[Any, UIDProxy]] = field(default_factory=list)
prerequisite_computation: Optional[Union["Computation", UIDProxy]] = None
citation: List[Union[Any, UIDProxy]] = field(default_factory=list)

_json_attrs: JsonAttributes = JsonAttributes()

Expand All @@ -78,12 +79,12 @@ def __init__(
self,
name: str,
type: str,
input_data: Optional[List[Any]] = None,
output_data: Optional[List[Any]] = None,
software_configuration: Optional[List[Any]] = None,
condition: Optional[List[Any]] = None,
prerequisite_computation: Optional["Computation"] = None,
citation: Optional[List[Any]] = None,
input_data: Optional[List[Union[Any, UIDProxy]]] = None,
output_data: Optional[List[Union[Any, UIDProxy]]] = None,
software_configuration: Optional[List[Union[Any, UIDProxy]]] = None,
condition: Optional[List[Union[Any, UIDProxy]]] = None,
prerequisite_computation: Optional[Union["Computation", UIDProxy]] = None,
citation: Optional[List[Union[Any, UIDProxy]]] = None,
notes: str = "",
**kwargs
) -> None:
Expand Down Expand Up @@ -364,7 +365,7 @@ def condition(self, new_condition_list: List[Any]) -> None:

@property
@beartype
def prerequisite_computation(self) -> Optional["Computation"]:
def prerequisite_computation(self) -> Optional[Union["Computation", UIDProxy]]:
"""
prerequisite computation
Expand All @@ -386,7 +387,7 @@ def prerequisite_computation(self) -> Optional["Computation"]:

@prerequisite_computation.setter
@beartype
def prerequisite_computation(self, new_prerequisite_computation: Optional["Computation"]) -> None:
def prerequisite_computation(self, new_prerequisite_computation: Optional[Union["Computation", UIDProxy]]) -> None:
"""
set new prerequisite_computation
Expand Down
31 changes: 16 additions & 15 deletions src/cript/nodes/primary_nodes/computation_process.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from dataclasses import dataclass, field, replace
from typing import Any, List, Optional
from typing import Any, List, Optional, Union

from beartype import beartype

from cript.nodes.primary_nodes.primary_base_node import PrimaryBaseNode
from cript.nodes.util.json import UIDProxy


class ComputationProcess(PrimaryBaseNode):
Expand Down Expand Up @@ -113,13 +114,13 @@ class JsonAttributes(PrimaryBaseNode.JsonAttributes):

type: str = ""
# TODO add proper typing in future, using Any for now to avoid circular import error
input_data: List[Any] = field(default_factory=list)
output_data: List[Any] = field(default_factory=list)
ingredient: List[Any] = field(default_factory=list)
software_configuration: List[Any] = field(default_factory=list)
condition: List[Any] = field(default_factory=list)
property: List[Any] = field(default_factory=list)
citation: List[Any] = field(default_factory=list)
input_data: List[Union[Any, UIDProxy]] = field(default_factory=list)
output_data: List[Union[Any, UIDProxy]] = field(default_factory=list)
ingredient: List[Union[Any, UIDProxy]] = field(default_factory=list)
software_configuration: List[Union[Any, UIDProxy]] = field(default_factory=list)
condition: List[Union[Any, UIDProxy]] = field(default_factory=list)
property: List[Union[Any, UIDProxy]] = field(default_factory=list)
citation: List[Union[Any, UIDProxy]] = field(default_factory=list)

_json_attrs: JsonAttributes = JsonAttributes()

Expand All @@ -128,13 +129,13 @@ def __init__(
self,
name: str,
type: str,
input_data: List[Any],
ingredient: List[Any],
output_data: Optional[List[Any]] = None,
software_configuration: Optional[List[Any]] = None,
condition: Optional[List[Any]] = None,
property: Optional[List[Any]] = None,
citation: Optional[List[Any]] = None,
input_data: List[Union[Any, UIDProxy]],
ingredient: List[Union[Any, UIDProxy]],
output_data: Optional[List[Union[Any, UIDProxy]]] = None,
software_configuration: Optional[List[Union[Any, UIDProxy]]] = None,
condition: Optional[List[Union[Any, UIDProxy]]] = None,
property: Optional[List[Union[Any, UIDProxy]]] = None,
citation: Optional[List[Union[Any, UIDProxy]]] = None,
notes: str = "",
**kwargs
):
Expand Down
29 changes: 15 additions & 14 deletions src/cript/nodes/primary_nodes/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from beartype import beartype

from cript.nodes.primary_nodes.primary_base_node import PrimaryBaseNode
from cript.nodes.util.json import UIDProxy


class Data(PrimaryBaseNode):
Expand Down Expand Up @@ -74,13 +75,13 @@ class JsonAttributes(PrimaryBaseNode.JsonAttributes):

type: str = ""
# TODO add proper typing in future, using Any for now to avoid circular import error
file: List[Any] = field(default_factory=list)
sample_preparation: Any = field(default_factory=list)
computation: List[Any] = field(default_factory=list)
computation_process: Any = field(default_factory=list)
material: List[Any] = field(default_factory=list)
process: List[Any] = field(default_factory=list)
citation: List[Any] = field(default_factory=list)
file: List[Union[Any, UIDProxy]] = field(default_factory=list)
sample_preparation: Union[Any, UIDProxy] = field(default_factory=list)
computation: List[Union[Any, UIDProxy]] = field(default_factory=list)
computation_process: Union[Any, UIDProxy] = field(default_factory=list)
material: List[Union[Any, UIDProxy]] = field(default_factory=list)
process: List[Union[Any, UIDProxy]] = field(default_factory=list)
citation: List[Union[Any, UIDProxy]] = field(default_factory=list)

_json_attrs: JsonAttributes = JsonAttributes()

Expand All @@ -89,13 +90,13 @@ def __init__(
self,
name: str,
type: str,
file: List[Any],
sample_preparation: Any = None,
computation: Optional[List[Any]] = None,
computation_process: Optional[Any] = None,
material: Optional[List[Any]] = None,
process: Optional[List[Any]] = None,
citation: Optional[List[Any]] = None,
file: List[Union[Any, UIDProxy]],
sample_preparation: Union[Any, UIDProxy] = None,
computation: Optional[List[Union[Any, UIDProxy]]] = None,
computation_process: Optional[Union[Any, UIDProxy]] = None,
material: Optional[List[Union[Any, UIDProxy]]] = None,
process: Optional[List[Union[Any, UIDProxy]]] = None,
citation: Optional[List[Union[Any, UIDProxy]]] = None,
notes: str = "",
**kwargs
) -> None:
Expand Down
Loading

0 comments on commit 74104b1

Please sign in to comment.