Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add iterators for UUID nodes #439

Merged
merged 14 commits into from
Mar 14, 2024
50 changes: 15 additions & 35 deletions src/cript/nodes/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
import re
import uuid
from abc import ABC
from dataclasses import asdict, dataclass, replace
from dataclasses import dataclass, replace
from typing import Dict, List, Optional, Set

from cript.nodes.exceptions import (
CRIPTAttributeModificationError,
CRIPTExtraJsonAttributes,
CRIPTJsonSerializationError,
)
from cript.nodes.node_iterator import NodeIterator

tolerated_extra_json = []

Expand Down Expand Up @@ -102,7 +103,7 @@ def __str__(self) -> str:
str
A string representation of the node.
"""
return str(asdict(self._json_attrs))
return str(self._json_attrs)

@property
def uid(self):
Expand Down Expand Up @@ -200,6 +201,7 @@ def _from_json(cls, json_dict: dict):
attrs = replace(attrs, uid="_:" + attrs.uid)
except AttributeError:
pass

# But here we force even usually unwritable fields to be set.
node._update_json_attrs_if_valid(attrs)

Expand Down Expand Up @@ -466,7 +468,7 @@ class ReturnTuple:
if is_patch:
del tmp_dict["uuid"] # patches do not allow UUID is the parent most node

return ReturnTuple(json.dumps(tmp_dict), tmp_dict, NodeEncoder.handled_ids)
return ReturnTuple(json.dumps(tmp_dict, **kwargs), tmp_dict, NodeEncoder.handled_ids)
except Exception as exc:
# TODO: this handler doesn't tell the user what happened or how they can fix it;
# it just tells the user that something is wrong
Expand Down Expand Up @@ -600,40 +602,18 @@ def is_attr_present(node: BaseNode, key, value):
if handled_nodes is None:
handled_nodes = []

# Protect against cycles in graph, by handling every instance of a node only once
if self in handled_nodes:
return []
handled_nodes += [self]

found_children = []

# In this search we include the calling node itself.
# We check for this node if all specified attributes are present by counting them (AND condition).
found_attr = 0
for key, value in search_attr.items():
if is_attr_present(self, key, value):
found_attr += 1
# If exactly all attributes are found, it matches the search criterion
if found_attr == len(search_attr):
found_children += [self]

# Recursion according to the recursion depth for all node children.
if search_depth != 0:
# Loop over all attributes, runtime contribution (none, or constant (max number of attributes of a node)
for field in self._json_attrs.__dataclass_fields__:
value = getattr(self._json_attrs, field)
# To save code paths, I convert non-lists into lists with one element.
if not isinstance(value, list):
value = [value]
# Run time contribution: number of elements in the attribute list.
for v in value:
try: # Try every attribute for recursion (duck-typing)
found_children += v.find_children(search_attr, search_depth - 1, handled_nodes=handled_nodes)
except AttributeError:
pass
# Total runtime, of non-recursive call: O(m*h) + O(k) where k is the number of children for this node,
# h being the depth of the search dictionary, m being the number of nodes in the attribute list.
# Total runtime, with recursion: O(n*(k+m*h). A full graph traversal O(n) with a cost per node, that scales with the number of children per node and the search depth of the search dictionary.
node_iterator = NodeIterator(self, search_depth)
for node in node_iterator:
found_attr = 0
for key, value in search_attr.items():
if is_attr_present(node, key, value):
found_attr += 1
# If exactly all attributes are found, it matches the search criterion
if found_attr == len(search_attr):
found_children += [node]

return found_children

def remove_child(self, child) -> bool:
Expand Down
68 changes: 68 additions & 0 deletions src/cript/nodes/node_iterator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from dataclasses import fields
from typing import Any, List, Set


class NodeIterator:
    """Pre-order, depth-first iterator over a graph of nodes linked via ``_json_attrs``.

    The whole traversal is performed eagerly in ``__init__``; the visited nodes
    are stored in ``self._stack`` (in visiting order) and the depth at which each
    node was first reached is stored at the same index in ``self._recursion_depth``.

    A node is any object that exposes a ``uuid`` attribute and a dataclass
    instance ``_json_attrs``. Every node is visited at most once (keyed by its
    ``uuid``), which also protects against cycles in the graph.

    Parameters
    ----------
    root:
        Node the traversal starts from. It must expose ``uuid``; otherwise an
        ``AttributeError`` is raised.
    max_recursion_depth: int
        Maximum depth (root is depth 0) whose children are still expanded.
        A negative value (default ``-1``) means no depth limit.
    """

    def __init__(self, root, max_recursion_depth=-1):
        self._iter_position: int = 0
        self._uuid_visited: Set[str] = set()
        self._stack: List[Any] = []
        # Parallel to ``_stack``: depth at which the node with the same index was added.
        self._recursion_depth = []
        self._max_recursion_depth = max_recursion_depth
        self._depth_first(root, 0)

    def _add_node(self, child_node, recursion_depth: int):
        """Record ``child_node`` as visited unless its ``uuid`` was already seen.

        Returns ``True`` if the node was newly added, ``False`` otherwise.
        """
        if child_node.uuid in self._uuid_visited:
            return False

        self._stack.append(child_node)
        self._recursion_depth.append(recursion_depth)
        self._uuid_visited.add(child_node.uuid)
        return True

    def _check_recursion(self, child_node) -> bool:
        """Tell whether ``child_node`` is an unvisited node worth descending into.

        This function can be called for both listed children and regular children attributes.
        Objects without a ``uuid`` attribute (plain values such as strings or numbers)
        are not nodes and yield ``False``. This method only checks; it does not
        modify the iterator state.
        """
        try:
            uuid = child_node.uuid
        except AttributeError:
            return False

        return uuid not in self._uuid_visited

    def _depth_first(self, node, recursion_depth: int) -> None:
        """Traverse the graph from ``node`` in depth-first pre-order and store the result.

        The traversal uses an explicit work stack instead of recursive calls, so
        the depth of the graph is not bounded by the interpreter's recursion limit.
        """
        pending = [(node, recursion_depth)]
        while pending:
            current, depth = pending.pop()

            # A node may have been reached through another path since it was queued.
            if not self._add_node(current, depth):
                continue

            if self._max_recursion_depth >= 0 and depth >= self._max_recursion_depth:
                continue

            children = []
            # Sorted field names give a deterministic visiting order.
            for attr_name in sorted(field.name for field in fields(current._json_attrs)):
                attr = getattr(current._json_attrs, attr_name)
                # Treat single-valued attributes like one-element lists.
                if not isinstance(attr, list):
                    attr = [attr]
                children.extend(child for child in attr if self._check_recursion(child))

            # Reverse so that the first child is popped (and fully explored) first,
            # which reproduces the order of a recursive pre-order traversal.
            pending.extend((child, depth + 1) for child in reversed(children))

    def __next__(self):
        """Return the next visited node, raising ``StopIteration`` at the end."""
        if self._iter_position >= len(self._stack):
            raise StopIteration
        self._iter_position += 1
        return self._stack[self._iter_position - 1]

    def __iter__(self):
        """Restart iteration from the first node and return this same object.

        The position is stored on the instance, so nested loops over the same
        ``NodeIterator`` object share (and reset) one cursor.
        """
        self._iter_position = 0
        return self

    def __len__(self):
        """Return the number of distinct nodes found by the traversal."""
        return len(self._stack)

    def __getitem__(self, idx: int):
        """Return the node at position ``idx`` in visiting order."""
        return self._stack[idx]
21 changes: 12 additions & 9 deletions src/cript/nodes/primary_nodes/collection.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from dataclasses import dataclass, field, replace
from typing import Any, List, Optional
from typing import Any, List, Optional, Union

from beartype import beartype

from cript.nodes.primary_nodes.primary_base_node import PrimaryBaseNode
from cript.nodes.supporting_nodes import User
from cript.nodes.util.json import UIDProxy


class Collection(PrimaryBaseNode):
Expand Down Expand Up @@ -56,17 +57,19 @@ class JsonAttributes(PrimaryBaseNode.JsonAttributes):
"""

# TODO add proper typing in future, using Any for now to avoid circular import error
member: List[User] = field(default_factory=list)
admin: List[User] = field(default_factory=list)
experiment: List[Any] = field(default_factory=list)
inventory: List[Any] = field(default_factory=list)
member: List[Union[User, UIDProxy]] = field(default_factory=list)
admin: List[Union[User, UIDProxy]] = field(default_factory=list)
experiment: List[Union[Any, UIDProxy]] = field(default_factory=list)
inventory: List[Union[Any, UIDProxy]] = field(default_factory=list)
doi: str = ""
citation: List[Any] = field(default_factory=list)
citation: List[Union[Any, UIDProxy]] = field(default_factory=list)

_json_attrs: JsonAttributes = JsonAttributes()

@beartype
def __init__(self, name: str, experiment: Optional[List[Any]] = None, inventory: Optional[List[Any]] = None, doi: str = "", citation: Optional[List[Any]] = None, notes: str = "", **kwargs) -> None:
def __init__(
self, name: str, experiment: Optional[List[Union[Any, UIDProxy]]] = None, inventory: Optional[List[Union[Any, UIDProxy]]] = None, doi: str = "", citation: Optional[List[Union[Any, UIDProxy]]] = None, notes: str = "", **kwargs
) -> None:
"""
create a Collection with a name
add list of experiment, inventory, citation, doi, and notes if available.
Expand Down Expand Up @@ -117,12 +120,12 @@ def __init__(self, name: str, experiment: Optional[List[Any]] = None, inventory:

@property
@beartype
def member(self) -> List[User]:
def member(self) -> List[Union[User, UIDProxy]]:
return self._json_attrs.member.copy()

@property
@beartype
def admin(self) -> List[User]:
def admin(self) -> List[Union[User, UIDProxy]]:
return self._json_attrs.admin

@property
Expand Down
31 changes: 16 additions & 15 deletions src/cript/nodes/primary_nodes/computation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from dataclasses import dataclass, field, replace
from typing import Any, List, Optional
from typing import Any, List, Optional, Union

from beartype import beartype

from cript.nodes.primary_nodes.primary_base_node import PrimaryBaseNode
from cript.nodes.util.json import UIDProxy


class Computation(PrimaryBaseNode):
Expand Down Expand Up @@ -64,12 +65,12 @@ class JsonAttributes(PrimaryBaseNode.JsonAttributes):

type: str = ""
# TODO add proper typing in future, using Any for now to avoid circular import error
input_data: List[Any] = field(default_factory=list)
output_data: List[Any] = field(default_factory=list)
software_configuration: List[Any] = field(default_factory=list)
condition: List[Any] = field(default_factory=list)
prerequisite_computation: Optional["Computation"] = None
citation: List[Any] = field(default_factory=list)
input_data: List[Union[Any, UIDProxy]] = field(default_factory=list)
output_data: List[Union[Any, UIDProxy]] = field(default_factory=list)
software_configuration: List[Union[Any, UIDProxy]] = field(default_factory=list)
condition: List[Union[Any, UIDProxy]] = field(default_factory=list)
prerequisite_computation: Optional[Union["Computation", UIDProxy]] = None
citation: List[Union[Any, UIDProxy]] = field(default_factory=list)

_json_attrs: JsonAttributes = JsonAttributes()

Expand All @@ -78,12 +79,12 @@ def __init__(
self,
name: str,
type: str,
input_data: Optional[List[Any]] = None,
output_data: Optional[List[Any]] = None,
software_configuration: Optional[List[Any]] = None,
condition: Optional[List[Any]] = None,
prerequisite_computation: Optional["Computation"] = None,
citation: Optional[List[Any]] = None,
input_data: Optional[List[Union[Any, UIDProxy]]] = None,
output_data: Optional[List[Union[Any, UIDProxy]]] = None,
software_configuration: Optional[List[Union[Any, UIDProxy]]] = None,
condition: Optional[List[Union[Any, UIDProxy]]] = None,
prerequisite_computation: Optional[Union["Computation", UIDProxy]] = None,
citation: Optional[List[Union[Any, UIDProxy]]] = None,
notes: str = "",
**kwargs
) -> None:
Expand Down Expand Up @@ -364,7 +365,7 @@ def condition(self, new_condition_list: List[Any]) -> None:

@property
@beartype
def prerequisite_computation(self) -> Optional["Computation"]:
def prerequisite_computation(self) -> Optional[Union["Computation", UIDProxy]]:
"""
prerequisite computation

Expand All @@ -386,7 +387,7 @@ def prerequisite_computation(self) -> Optional["Computation"]:

@prerequisite_computation.setter
@beartype
def prerequisite_computation(self, new_prerequisite_computation: Optional["Computation"]) -> None:
def prerequisite_computation(self, new_prerequisite_computation: Optional[Union["Computation", UIDProxy]]) -> None:
"""
set new prerequisite_computation

Expand Down
31 changes: 16 additions & 15 deletions src/cript/nodes/primary_nodes/computation_process.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from dataclasses import dataclass, field, replace
from typing import Any, List, Optional
from typing import Any, List, Optional, Union

from beartype import beartype

from cript.nodes.primary_nodes.primary_base_node import PrimaryBaseNode
from cript.nodes.util.json import UIDProxy


class ComputationProcess(PrimaryBaseNode):
Expand Down Expand Up @@ -113,13 +114,13 @@ class JsonAttributes(PrimaryBaseNode.JsonAttributes):

type: str = ""
# TODO add proper typing in future, using Any for now to avoid circular import error
input_data: List[Any] = field(default_factory=list)
output_data: List[Any] = field(default_factory=list)
ingredient: List[Any] = field(default_factory=list)
software_configuration: List[Any] = field(default_factory=list)
condition: List[Any] = field(default_factory=list)
property: List[Any] = field(default_factory=list)
citation: List[Any] = field(default_factory=list)
input_data: List[Union[Any, UIDProxy]] = field(default_factory=list)
output_data: List[Union[Any, UIDProxy]] = field(default_factory=list)
ingredient: List[Union[Any, UIDProxy]] = field(default_factory=list)
software_configuration: List[Union[Any, UIDProxy]] = field(default_factory=list)
condition: List[Union[Any, UIDProxy]] = field(default_factory=list)
property: List[Union[Any, UIDProxy]] = field(default_factory=list)
citation: List[Union[Any, UIDProxy]] = field(default_factory=list)

_json_attrs: JsonAttributes = JsonAttributes()

Expand All @@ -128,13 +129,13 @@ def __init__(
self,
name: str,
type: str,
input_data: List[Any],
ingredient: List[Any],
output_data: Optional[List[Any]] = None,
software_configuration: Optional[List[Any]] = None,
condition: Optional[List[Any]] = None,
property: Optional[List[Any]] = None,
citation: Optional[List[Any]] = None,
input_data: List[Union[Any, UIDProxy]],
ingredient: List[Union[Any, UIDProxy]],
output_data: Optional[List[Union[Any, UIDProxy]]] = None,
software_configuration: Optional[List[Union[Any, UIDProxy]]] = None,
condition: Optional[List[Union[Any, UIDProxy]]] = None,
property: Optional[List[Union[Any, UIDProxy]]] = None,
citation: Optional[List[Union[Any, UIDProxy]]] = None,
notes: str = "",
**kwargs
):
Expand Down
29 changes: 15 additions & 14 deletions src/cript/nodes/primary_nodes/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from beartype import beartype

from cript.nodes.primary_nodes.primary_base_node import PrimaryBaseNode
from cript.nodes.util.json import UIDProxy


class Data(PrimaryBaseNode):
Expand Down Expand Up @@ -74,13 +75,13 @@ class JsonAttributes(PrimaryBaseNode.JsonAttributes):

type: str = ""
# TODO add proper typing in future, using Any for now to avoid circular import error
file: List[Any] = field(default_factory=list)
sample_preparation: Any = field(default_factory=list)
computation: List[Any] = field(default_factory=list)
computation_process: Any = field(default_factory=list)
material: List[Any] = field(default_factory=list)
process: List[Any] = field(default_factory=list)
citation: List[Any] = field(default_factory=list)
file: List[Union[Any, UIDProxy]] = field(default_factory=list)
sample_preparation: Union[Any, UIDProxy] = field(default_factory=list)
computation: List[Union[Any, UIDProxy]] = field(default_factory=list)
computation_process: Union[Any, UIDProxy] = field(default_factory=list)
material: List[Union[Any, UIDProxy]] = field(default_factory=list)
process: List[Union[Any, UIDProxy]] = field(default_factory=list)
citation: List[Union[Any, UIDProxy]] = field(default_factory=list)

_json_attrs: JsonAttributes = JsonAttributes()

Expand All @@ -89,13 +90,13 @@ def __init__(
self,
name: str,
type: str,
file: List[Any],
sample_preparation: Any = None,
computation: Optional[List[Any]] = None,
computation_process: Optional[Any] = None,
material: Optional[List[Any]] = None,
process: Optional[List[Any]] = None,
citation: Optional[List[Any]] = None,
file: List[Union[Any, UIDProxy]],
sample_preparation: Union[Any, UIDProxy] = None,
computation: Optional[List[Union[Any, UIDProxy]]] = None,
computation_process: Optional[Union[Any, UIDProxy]] = None,
material: Optional[List[Union[Any, UIDProxy]]] = None,
process: Optional[List[Union[Any, UIDProxy]]] = None,
citation: Optional[List[Union[Any, UIDProxy]]] = None,
notes: str = "",
**kwargs
) -> None:
Expand Down
Loading
Loading