Skip to content

Commit

Permalink
Fix bug preventing nesting from working properly
Browse files Browse the repository at this point in the history
  • Loading branch information
eriknyquist committed Apr 6, 2023
1 parent 7b66afa commit 8968eee
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 20 deletions.
53 changes: 51 additions & 2 deletions tests/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,41 @@ class TestConfig(VersionedObject):
self.assertEqual(1, cfg.val2.val1)
self.assertEqual(55.5, cfg.val2.val2)

def test_nested_config_dict_2(self):
"""
Tests that a VersionedObject instance with nested VersionedObjbects still
contains the same values after being serialized to a dict and deserialized
back from the dict
"""
class NestedConfig(VersionedObject):
val1 = 1
val2 = 55.5

class TestConfig(VersionedObject):
val1 = "a"
val2 = NestedConfig()
val3 = NestedConfig()

cfg = TestConfig()
s = Serializer(cfg)

d = s.to_dict()

self.assertEqual("a", d["val1"])
self.assertEqual(1, d["val2"]["val1"])
self.assertEqual(55.5, d["val2"]["val2"])
self.assertEqual(1, d["val3"]["val1"])
self.assertEqual(55.5, d["val3"]["val2"])

result = s.from_dict(d)
self.assertIs(None, result) # verify no migrations peformewd

self.assertEqual("a", cfg.val1)
self.assertEqual(1, cfg.val2.val1)
self.assertEqual(55.5, cfg.val2.val2)
self.assertEqual(1, cfg.val3.val1)
self.assertEqual(55.5, cfg.val3.val2)

def test_nested_config_dict_change(self):
"""
Tests that we can change instance attributes on a VersionedObject instance
Expand Down Expand Up @@ -331,6 +366,7 @@ class Level4(VersionedObject):
class Level3(VersionedObject):
val1 = 66.6
val2 = Level4()
val3 = Level4()

class Level2(VersionedObject):
val1 = "gg"
Expand All @@ -342,13 +378,16 @@ class Level1(VersionedObject):

ser = Serializer()
cfg = Level1()

d = ser.to_dict(cfg)

self.assertEqual(3, d['val1'])
self.assertEqual('gg', d['val2']['val1'])
self.assertEqual(66.6, d['val2']['val2']['val1'])
self.assertEqual(True, d['val2']['val2']['val2']['val1'])
self.assertEqual(False, d['val2']['val2']['val2']['val2'])
self.assertEqual(True, d['val2']['val2']['val3']['val1'])
self.assertEqual(False, d['val2']['val2']['val3']['val2'])

d['val2']['val1'] = "changed"
result = ser.from_dict(d, cfg)
Expand All @@ -359,6 +398,8 @@ class Level1(VersionedObject):
self.assertEqual(66.6, cfg.val2.val2.val1)
self.assertEqual(True, cfg.val2.val2.val2.val1)
self.assertEqual(False, cfg.val2.val2.val2.val2)
self.assertEqual(True, cfg.val2.val2.val3.val1)
self.assertEqual(False, cfg.val2.val2.val3.val2)

def test_load_json_deeper_nesting(self):
"""
Expand All @@ -385,7 +426,7 @@ class Level1(VersionedObject):
ser = Serializer(cfg)
d = ser.to_json(cfg)
result = ser.from_json(d)
self.assertIs(None, result) # verify no migrations peformewd
self.assertIs(None, result) # verify no migrations peformed

self.assertEqual(3, cfg.val1)
self.assertEqual('gg', cfg.val2.val1)
Expand Down Expand Up @@ -1834,17 +1875,25 @@ class TestConfig2(VersionedObject):

class TestConfig(VersionedObject):
var1 = TestConfig2()
var2 = 7
var3 = TestConfig2()

serializer = Serializer()
config = TestConfig()

self.assertEqual(config.var1.var1, 1)
self.assertEqual(config.var2, 7)
self.assertEqual(config.var3.var1, 1)

serializer.from_dict({"var1": {"var1": 2}}, config)
serializer.from_dict({"var1": {"var1": 2}, "var2": 8, "var3": {"var1": 3}}, config)

config2 = TestConfig()
self.assertEqual(config.var1.var1, 2)
self.assertEqual(config.var2, 8)
self.assertEqual(config.var3.var1, 3)
self.assertEqual(config2.var1.var1, 1)
self.assertEqual(config2.var2, 7)
self.assertEqual(config2.var3.var1, 1)

def test_loading_doesnt_change_default_values_listfield(self):
"""
Expand Down
15 changes: 6 additions & 9 deletions versionedobj/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,21 @@ def _walk_dict_attrs(obj, parent_attrs, only=[], ignore=[]):
:param list only: List of 'only' names
:param list ignore: List of 'ignore' names
"""
parents = []
attrs_stack = [(None, parent_attrs)]
attrs_stack = [(None, [], parent_attrs)]

while attrs_stack:
fieldname, attrs = attrs_stack.pop(0)
if fieldname is not None:
parents.append(fieldname)
fieldname, parents, attrs = attrs_stack.pop(0)

for n in attrs:
p = parents if fieldname is None else parents + [fieldname]
value = attrs[n]
field = _ObjField(parents, n, value)
dotname = field.dot_name()
field = _ObjField(p, n, value)
field_value = field.get_obj_field(obj)

if (isinstance(field_value, VersionedObject) and (type(value) == dict)):
attrs_stack.append((n, value))
attrs_stack.append((n, p, value))
else:
if not _field_should_be_skipped(dotname, only, ignore):
if not _field_should_be_skipped(field.dot_name(), only, ignore):
yield field


Expand Down
15 changes: 6 additions & 9 deletions versionedobj/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,23 +140,20 @@ def _walk_obj_attrs(parent_obj, only=[], ignore=[]):
:param list only: List of 'only' names
:param list ignore: List of 'ignore' names
"""
parents = []
obj_stack = [(None, parent_obj)]
obj_stack = [(None, [], parent_obj)]

while obj_stack:
fieldname, obj = obj_stack.pop(0)
if fieldname is not None:
parents.append(fieldname)
fieldname, parents, obj = obj_stack.pop(0)

for n in _iter_obj_attrs(obj):
p = parents if fieldname is None else parents + [fieldname]
value = obj.__dict__[n]
field = _ObjField(parents, n, value)
dotname = field.dot_name()
field = _ObjField(p, n, value)

if isinstance(value, _ObjField.obj_class):
obj_stack.append((n, value))
obj_stack.append((n, p, value))
else:
if not _field_should_be_skipped(dotname, only, ignore):
if not _field_should_be_skipped(field.dot_name(), only, ignore):
yield field


Expand Down

0 comments on commit 8968eee

Please sign in to comment.