Skip to content

Commit

Permalink
drop dict and list parent classes for lazy_nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
braingram committed Feb 12, 2024
1 parent fb81020 commit 94bc2cd
Show file tree
Hide file tree
Showing 9 changed files with 87 additions and 45 deletions.
13 changes: 9 additions & 4 deletions asdf/_asdf.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import collections.abc
import copy
import datetime
import io
Expand Down Expand Up @@ -321,7 +322,11 @@ def _check_extensions(self, tree, strict=False):
strict : bool, optional
Set to `True` to convert warnings to exceptions.
"""
if "history" not in tree or not isinstance(tree["history"], dict) or "extensions" not in tree["history"]:
if (
"history" not in tree
or not isinstance(tree["history"], collections.abc.Mapping)
or "extensions" not in tree["history"]
):
return

for extension in tree["history"]["extensions"]:
Expand Down Expand Up @@ -437,7 +442,7 @@ def _update_extension_history(self, tree, serialization_context):
if "history" not in tree:
tree["history"] = {"extensions": []}
# Support clients who are still using the old history format
elif isinstance(tree["history"], list):
elif isinstance(tree["history"], collections.abc.Sequence):
histlist = tree["history"]
tree["history"] = {"entries": histlist, "extensions": []}
warnings.warn(
Expand Down Expand Up @@ -1330,7 +1335,7 @@ def add_history_entry(self, description, software=None):
- ``homepage``: A URI to the homepage of the software
- ``version``: The version of the software
"""
if isinstance(software, list):
if isinstance(software, collections.abc.Sequence) and not isinstance(software, str):
software = [Software(x) for x in software]
elif software is not None:
software = Software(software)
Expand Down Expand Up @@ -1387,7 +1392,7 @@ def get_history_entries(self):
if "history" not in self.tree:
return []

if isinstance(self.tree["history"], list):
if isinstance(self.tree["history"], collections.abc.Sequence) and not isinstance(self.tree["history"], str):
return self.tree["history"]

if "entries" in self.tree["history"]:
Expand Down
6 changes: 4 additions & 2 deletions asdf/_core/_converters/ndarray.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import collections.abc

import numpy as np

from asdf.extension import Converter
Expand Down Expand Up @@ -133,12 +135,12 @@ def from_yaml_tree(self, node, tag, ctx):
from asdf.tags.core import NDArrayType
from asdf.tags.core.ndarray import asdf_datatype_to_numpy_dtype

if isinstance(node, list):
if isinstance(node, collections.abc.Sequence) and not isinstance(node, str):
instance = NDArrayType(node, None, None, None, None, None, None)
ctx._blocks._set_array_storage(instance, "inline")
return instance

if isinstance(node, dict):
if isinstance(node, collections.abc.Mapping):
shape = node.get("shape", None)
if "source" in node and "data" in node:
msg = "Both source and data may not be provided at the same time"
Expand Down
60 changes: 41 additions & 19 deletions asdf/_lazy_nodes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import collections
import collections.abc
import copy
import warnings
from types import GeneratorType

Expand Down Expand Up @@ -66,32 +68,46 @@ def __init__(self, data=None, af_ref=None):
def tagged(self):
return self.data

def copy(self):
return self.__class__(copy.copy(self.data), self._af_ref)

class AsdfListNode(AsdfNode, collections.UserList, list):
def __asdf_traverse__(self):
return self.data


class AsdfListNode(AsdfNode, collections.abc.MutableSequence):
def __init__(self, data=None, af_ref=None):
if data is None:
data = []
AsdfNode.__init__(self, data, af_ref)
collections.UserList.__init__(self, data)
list.__init__(self, data)

def __setitem__(self, index, value):
self.data.__setitem__(index, value)

def __delitem__(self, index):
self.data.__delitem__(index)

def __len__(self):
return self.data.__len__()

def insert(self, index, value):
self.data.insert(index, value)

def __eq__(self, other):
if self is other:
return True
if not isinstance(other, collections.abc.Sequence):
return False
return list(self) == list(other)

def __ne__(self, other):
return not self.__eq__(other)

def __reduce__(self):
return collections.UserList.__reduce__(self)

def __getitem__(self, key):
# key might be an int or slice
value = super().__getitem__(key)
value = self.data.__getitem__(key)
if isinstance(key, slice):
value._af_ref = self._af_ref
return value
return AsdfListNode(value, self._af_ref)
if isinstance(value, tagged.Tagged):
value = _convert(value, self._af_ref)
self[key] = value
Expand Down Expand Up @@ -131,29 +147,36 @@ def __getitem__(self, key):
return value


# dict is required here so TaggedDict doesn't convert this to a dict
# and so that json.dumps will work for this node TODO add test
class AsdfDictNode(AsdfNode, collections.UserDict, dict):
class AsdfDictNode(AsdfNode, collections.abc.MutableMapping):
def __init__(self, data=None, af_ref=None):
if data is None:
data = {}
AsdfNode.__init__(self, data, af_ref)
collections.UserDict.__init__(self, data)
dict.__init__(self, data)

def __setitem__(self, index, value):
self.data.__setitem__(index, value)

def __delitem__(self, index):
self.data.__delitem__(index)

def __len__(self):
return self.data.__len__()

def __iter__(self):
return self.data.__iter__()

def __eq__(self, other):
if self is other:
return True
if not isinstance(other, collections.abc.Mapping):
return False
return dict(self) == dict(other)

def __ne__(self, other):
return not self.__eq__(other)

def __reduce__(self):
return collections.UserDict.__reduce__(self)

def __getitem__(self, key):
value = super().__getitem__(key)
value = self.data.__getitem__(key)
if isinstance(value, tagged.Tagged):
value = _convert(value, self._af_ref)
self[key] = value
Expand Down Expand Up @@ -198,4 +221,3 @@ def __init__(self, data=None, af_ref=None):
if data is None:
data = collections.OrderedDict()
AsdfDictNode.__init__(self, data, af_ref)
collections.OrderedDict.__init__(self, data)
2 changes: 2 additions & 0 deletions asdf/_node_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,8 @@ def set_schema_for_property(self, parent, identifier):
def set_schema_from_node(self, node, extension_manager):
"""Pull a tagged schema for the node"""

if not hasattr(node, "_tag"):
return
tag_def = extension_manager.get_tag_definition(node._tag)
schema_uri = tag_def.schema_uris[0]
schema = load_schema(schema_uri)
Expand Down
10 changes: 8 additions & 2 deletions asdf/_tests/_helpers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import collections.abc
import io
import os
import warnings
Expand Down Expand Up @@ -115,12 +116,17 @@ def recurse(old, new):
else:
getattr(old_type, funcname)(old, new)

elif isinstance(old, dict) and isinstance(new, dict):
elif isinstance(old, collections.abc.Mapping) and isinstance(new, collections.abc.Mapping):
assert {x for x in old if x not in ignore_keys} == {x for x in new if x not in ignore_keys}
for key in old:
if key not in ignore_keys:
recurse(old[key], new[key])
elif isinstance(old, (list, tuple)) and isinstance(new, (list, tuple)):
elif (
isinstance(old, collections.abc.Sequence)
and isinstance(new, collections.abc.Sequence)
and not isinstance(old, str)
and not isinstance(new, str)
):
assert len(old) == len(new)
for a, b in zip(old, new):
recurse(a, b)
Expand Down
5 changes: 3 additions & 2 deletions asdf/_tests/test_history.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import collections.abc
import datetime
import fractions
import warnings
Expand Down Expand Up @@ -53,7 +54,7 @@ def test_history_to_file(tmp_path):
# Test the history entry retrieval API
entries = ff.get_history_entries()
assert len(entries) == 1
assert isinstance(entries, list)
assert isinstance(entries, collections.abc.Sequence) and not isinstance(entries, str)
assert isinstance(entries[0], HistoryEntry)
assert entries[0]["description"] == "This happened"
assert entries[0]["software"]["name"] == "my_tool"
Expand All @@ -78,7 +79,7 @@ def test_old_history():
# Test the history entry retrieval API
entries = af.get_history_entries()
assert len(entries) == 1
assert isinstance(entries, list)
assert isinstance(entries, collections.abc.Sequence) and not isinstance(entries, str)
assert isinstance(entries[0], HistoryEntry)
assert entries[0]["description"] == "Here's a test of old history entries"
assert entries[0]["software"]["name"] == "foo"
Expand Down
6 changes: 3 additions & 3 deletions asdf/_tests/test_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import yaml

import asdf
from asdf import tagged, treeutil, yamlutil
from asdf import _lazy_nodes, tagged, treeutil, yamlutil
from asdf.exceptions import AsdfConversionWarning, AsdfWarning

from . import _helpers as helpers
Expand All @@ -30,7 +30,7 @@ def check_asdf(asdf):
assert list(tree["ordered_dict"].keys()) == ["first", "second", "third"]

assert not isinstance(tree["unordered_dict"], OrderedDict)
assert isinstance(tree["unordered_dict"], dict)
assert isinstance(tree["unordered_dict"], (dict, _lazy_nodes.AsdfDictNode))

def check_raw_yaml(content):
assert b"OrderedDict" not in content
Expand Down Expand Up @@ -79,7 +79,7 @@ class Foo:

def run_tuple_test(tree, tmp_path):
def check_asdf(asdf):
assert isinstance(asdf.tree["val"], list)
assert isinstance(asdf.tree["val"], (list, _lazy_nodes.AsdfListNode))

def check_raw_yaml(content):
assert b"tuple" not in content
Expand Down
26 changes: 15 additions & 11 deletions asdf/tags/core/ndarray.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import collections.abc
import mmap
import sys

Expand Down Expand Up @@ -48,7 +49,8 @@ def asdf_datatype_to_numpy_dtype(datatype, byteorder=None):
return np.dtype(str(byteorder + datatype))

if (
isinstance(datatype, list)
isinstance(datatype, collections.abc.Sequence)
and not isinstance(datatype, str)
and len(datatype) == 2
and isinstance(datatype[0], str)
and isinstance(datatype[1], int)
Expand All @@ -60,7 +62,7 @@ def asdf_datatype_to_numpy_dtype(datatype, byteorder=None):

return np.dtype(datatype)

if isinstance(datatype, dict):
if isinstance(datatype, collections.abc.Mapping):
if "datatype" not in datatype:
msg = f"Field entry has no datatype: '{datatype}'"
raise ValueError(msg)
Expand All @@ -75,7 +77,7 @@ def asdf_datatype_to_numpy_dtype(datatype, byteorder=None):

return (str(name), datatype, tuple(shape))

if isinstance(datatype, list):
if isinstance(datatype, collections.abc.Sequence) and not isinstance(datatype, str):
datatype_list = []
for subdatatype in datatype:
np_dtype = asdf_datatype_to_numpy_dtype(subdatatype, byteorder)
Expand Down Expand Up @@ -160,7 +162,7 @@ def inline_data_asarray(inline, dtype=None):
if dtype is not None and dtype.fields is not None:

def find_innermost_match(line, depth=0):
if not isinstance(line, list) or not len(line):
if not isinstance(line, collections.abc.Sequence) or not len(line):
msg = "data can not be converted to structured array"
raise ValueError(msg)
try:
Expand All @@ -183,7 +185,7 @@ def convert_to_tuples(line, data_depth, depth=0):
return np.asarray(inline, dtype=dtype)

def handle_mask(inline):
if isinstance(inline, list):
if isinstance(inline, collections.abc.Sequence) and not isinstance(inline, str):
if None in inline:
inline_array = np.asarray(inline)
nones = np.equal(inline_array, None)
Expand All @@ -207,7 +209,7 @@ def tolist(x):
if isinstance(x, (np.ndarray, NDArrayType)):
x = x.astype("U").tolist() if x.dtype.char == "S" else x.tolist()

if isinstance(x, (list, tuple)):
if isinstance(x, collections.abc.Sequence) and not isinstance(x, str):
return [tolist(y) for y in x]

return x
Expand All @@ -233,7 +235,7 @@ def __init__(self, source, shape, dtype, offset, strides, order, mask, data_call
self._array = None
self._mask = mask

if isinstance(source, list):
if isinstance(source, collections.abc.Sequence) and not isinstance(source, str):
self._array = inline_data_asarray(source, dtype)
self._array = self._apply_mask(self._array, self._mask)
# single element structured arrays can have shape == ()
Expand Down Expand Up @@ -338,6 +340,8 @@ def get_actual_shape(self, shape, strides, dtype, block_size):
Get the actual shape of an array, by computing it against the
block_size if it contains a ``*``.
"""
if hasattr(shape, "data"):
shape = shape.data
num_stars = shape.count("*")
if num_stars == 0:
return shape
Expand Down Expand Up @@ -478,11 +482,11 @@ def operation(self, *args):


def _get_ndim(instance):
if isinstance(instance, list):
if isinstance(instance, collections.abc.Sequence) and not isinstance(instance, str):
array = inline_data_asarray(instance)
return array.ndim

if isinstance(instance, dict):
if isinstance(instance, collections.abc.Mapping):
if "shape" in instance:
return len(instance["shape"])

Expand Down Expand Up @@ -514,10 +518,10 @@ def validate_max_ndim(validator, max_ndim, instance, schema):


def validate_datatype(validator, datatype, instance, schema):
if isinstance(instance, list):
if isinstance(instance, collections.abc.Sequence):
array = inline_data_asarray(instance)
in_datatype, _ = numpy_dtype_to_asdf_datatype(array.dtype)
elif isinstance(instance, dict):
elif isinstance(instance, collections.abc.Mapping):
if "datatype" in instance:
in_datatype = instance["datatype"]
elif "data" in instance:
Expand Down
4 changes: 2 additions & 2 deletions asdf/treeutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,11 +369,11 @@ def _handle_immutable_sequence(node, json_id):
return result

def _handle_children(node, json_id):
if isinstance(node, dict):
if isinstance(node, (dict, _lazy_nodes.AsdfDictNode)):
result = _handle_mapping(node, json_id)
elif isinstance(node, tuple):
result = _handle_immutable_sequence(node, json_id)
elif isinstance(node, list):
elif isinstance(node, (list, _lazy_nodes.AsdfListNode)):
result = _handle_mutable_sequence(node, json_id)
else:
result = node
Expand Down

0 comments on commit 94bc2cd

Please sign in to comment.