From 257fe419521a8eecd1f0da3db11e198436372eb3 Mon Sep 17 00:00:00 2001 From: Ludwig Schneider Date: Tue, 12 Mar 2024 18:07:45 -0500 Subject: [PATCH] it finally works --- src/cript/nodes/core.py | 10 ++++++++-- src/cript/nodes/node_iterator.py | 31 +++++++++++++++---------------- src/cript/nodes/uuid_base.py | 2 +- tests/fixtures/primary_nodes.py | 12 ++++++------ tests/test_node_util.py | 1 + 5 files changed, 31 insertions(+), 25 deletions(-) diff --git a/src/cript/nodes/core.py b/src/cript/nodes/core.py index c56c9866..da9e039b 100644 --- a/src/cript/nodes/core.py +++ b/src/cript/nodes/core.py @@ -4,7 +4,7 @@ import re import uuid from abc import ABC -from dataclasses import asdict, dataclass, replace +from dataclasses import dataclass, replace from typing import Dict, List, Optional, Set from cript.nodes.exceptions import ( @@ -103,7 +103,7 @@ def __str__(self) -> str: str A string representation of the node. """ - return str(asdict(self._json_attrs)) + return str(self._json_attrs) @property def uid(self): @@ -201,6 +201,12 @@ def _from_json(cls, json_dict: dict): attrs = replace(attrs, uid="_:" + attrs.uid) except AttributeError: pass + + try: + attrs = replace(attrs, uuid=str(attrs.uuid)) + except AttributeError: + pass + # But here we force even usually unwritable fields to be set. node._update_json_attrs_if_valid(attrs) diff --git a/src/cript/nodes/node_iterator.py b/src/cript/nodes/node_iterator.py index 3e95b47b..1640e1f9 100644 --- a/src/cript/nodes/node_iterator.py +++ b/src/cript/nodes/node_iterator.py @@ -5,13 +5,18 @@ class NodeIterator: def __init__(self, root, max_recursion_depth=-1): self._iter_position: int = 0 - self._uuid_visited: Set[str] = set(str(root.uuid)) - self._stack: List[Any] = [root] - self._recursion_depth = [0] + self._uuid_visited: Set[str] = set() + self._stack: List[Any] = [] + self._recursion_depth = [] self._max_recursion_depth = max_recursion_depth self._depth_first(root, 0) - def _handle_child_node(self, child_node, recursion_depth: int) -> bool: + def _add_node(self, child_node, recursion_depth: int): + self._stack.append(child_node) + self._recursion_depth.append(recursion_depth) + self._uuid_visited.add(child_node.uuid) + + def _check_recursion(self, child_node) -> bool: """Helper function that adds a child to the stack. This function can be called for both listed children and regular children attributes @@ -22,14 +27,12 @@ def _handle_child_node(self, child_node, recursion_depth: int) -> bool: return False if uuid not in self._uuid_visited: - self._stack.append(child_node) - self._uuid_visited.add(str(child_node.uuid)) - self._recursion_depth.append(recursion_depth) return True return False def _depth_first(self, node, recursion_depth: int) -> None: """Helper function that does the traversal in depth first order and stores result in stack""" + self._add_node(node, recursion_depth) if self._max_recursion_depth >= 0 and recursion_depth >= self._max_recursion_depth: return @@ -37,15 +40,11 @@ def _depth_first(self, node, recursion_depth: int) -> None: field_names = [field.name for field in fields(node._json_attrs)] for attr_name in sorted(field_names): attr = getattr(node._json_attrs, attr_name) - node_added = self._handle_child_node(attr, recursion_depth) - if node_added: - self._depth_first(node, recursion_depth + 1) - else: - if isinstance(attr, list): - for list_attr in attr: - node_added = self._handle_child_node(list_attr, recursion_depth) - if node_added: - self._depth_first(list_attr, recursion_depth + 1) + if not isinstance(attr, list): + attr = [attr] + for list_attr in attr: + if self._check_recursion(list_attr): + self._depth_first(list_attr, recursion_depth + 1) def __next__(self): if self._iter_position >= len(self._stack): diff --git a/src/cript/nodes/uuid_base.py b/src/cript/nodes/uuid_base.py index a5d3729e..70a3ca01 100644 --- a/src/cript/nodes/uuid_base.py +++ b/src/cript/nodes/uuid_base.py @@ -31,7 +31,7 @@ class JsonAttributes(BaseNode.JsonAttributes): _json_attrs: JsonAttributes = JsonAttributes() def __new__(cls, *args, **kwargs): - uuid: Optional[str] = kwargs.get("uuid") + uuid: Optional[str] = str(kwargs.get("uuid")) if uuid and uuid in UUIDBaseNode._uuid_cache: existing_node_to_overwrite = UUIDBaseNode._uuid_cache[uuid] if type(existing_node_to_overwrite) is not cls: diff --git a/tests/fixtures/primary_nodes.py b/tests/fixtures/primary_nodes.py index 65fbf022..db86046b 100644 --- a/tests/fixtures/primary_nodes.py +++ b/tests/fixtures/primary_nodes.py @@ -26,15 +26,15 @@ def complex_project_dict(complex_collection_node, simple_material_node, complex_ project_dict = {"node": ["Project"]} project_dict["locked"] = True project_dict["model_version"] = "1.0.0" - project_dict["updated_by"] = json.loads(copy.deepcopy(complex_user_node).get_json(condense_to_uuid={}).json) - project_dict["created_by"] = json.loads(complex_user_node.get_json(condense_to_uuid={}).json) + project_dict["updated_by"] = json.loads(copy.deepcopy(complex_user_node).get_expanded_json()) + project_dict["created_by"] = json.loads(complex_user_node.get_expanded_json()) project_dict["public"] = True project_dict["name"] = "my project name" project_dict["notes"] = "my project notes" - project_dict["member"] = [json.loads(complex_user_node.get_json(condense_to_uuid={}).json)] - project_dict["admin"] = [json.loads(complex_user_node.get_json(condense_to_uuid={}).json)] - project_dict["collection"] = [json.loads(complex_collection_node.get_json(condense_to_uuid={}).json)] - project_dict["material"] = [json.loads(copy.deepcopy(simple_material_node).get_json(condense_to_uuid={}).json)] + project_dict["member"] = [json.loads(complex_user_node.get_expanded_json())] + project_dict["admin"] = [json.loads(complex_user_node.get_expanded_json())] + project_dict["collection"] = [json.loads(complex_collection_node.get_expanded_json())] + project_dict["material"] = [json.loads(copy.deepcopy(simple_material_node).get_expanded_json())] return project_dict diff --git a/tests/test_node_util.py b/tests/test_node_util.py index 704226d1..52f50a0b 100644 --- a/tests/test_node_util.py +++ b/tests/test_node_util.py @@ -299,6 +299,7 @@ def test_invalid_project_graphs(simple_project_node, simple_material_node, simpl # Now add an orphan data data = copy.deepcopy(simple_data_node) property.data = [data] + with pytest.raises(CRIPTOrphanedDataError): project.validate() # Fix with the helper function