From 94bc2cd59dd413daf57d9d049cf76671b51cf956 Mon Sep 17 00:00:00 2001 From: Brett Date: Tue, 9 Jan 2024 15:20:30 -0500 Subject: [PATCH] drop dict and list parent classes for lazy_nodes --- asdf/_asdf.py | 13 ++++--- asdf/_core/_converters/ndarray.py | 6 ++-- asdf/_lazy_nodes.py | 60 +++++++++++++++++++++---------- asdf/_node_info.py | 2 ++ asdf/_tests/_helpers.py | 10 ++++-- asdf/_tests/test_history.py | 5 +-- asdf/_tests/test_yaml.py | 6 ++-- asdf/tags/core/ndarray.py | 26 ++++++++------ asdf/treeutil.py | 4 +-- 9 files changed, 87 insertions(+), 45 deletions(-) diff --git a/asdf/_asdf.py b/asdf/_asdf.py index c57d755d6..63bc63712 100644 --- a/asdf/_asdf.py +++ b/asdf/_asdf.py @@ -1,3 +1,4 @@ +import collections.abc import copy import datetime import io @@ -321,7 +322,11 @@ def _check_extensions(self, tree, strict=False): strict : bool, optional Set to `True` to convert warnings to exceptions. """ - if "history" not in tree or not isinstance(tree["history"], dict) or "extensions" not in tree["history"]: + if ( + "history" not in tree + or not isinstance(tree["history"], collections.abc.Mapping) + or "extensions" not in tree["history"] + ): return for extension in tree["history"]["extensions"]: @@ -437,7 +442,7 @@ def _update_extension_history(self, tree, serialization_context): if "history" not in tree: tree["history"] = {"extensions": []} # Support clients who are still using the old history format - elif isinstance(tree["history"], list): + elif isinstance(tree["history"], collections.abc.Sequence): histlist = tree["history"] tree["history"] = {"entries": histlist, "extensions": []} warnings.warn( @@ -1330,7 +1335,7 @@ def add_history_entry(self, description, software=None): - ``homepage``: A URI to the homepage of the software - ``version``: The version of the software """ - if isinstance(software, list): + if isinstance(software, collections.abc.Sequence) and not isinstance(software, str): software = [Software(x) for x in software] elif software is not None: software = Software(software) @@ -1387,7 +1392,7 @@ def get_history_entries(self): if "history" not in self.tree: return [] - if isinstance(self.tree["history"], list): + if isinstance(self.tree["history"], collections.abc.Sequence) and not isinstance(self.tree["history"], str): return self.tree["history"] if "entries" in self.tree["history"]: diff --git a/asdf/_core/_converters/ndarray.py b/asdf/_core/_converters/ndarray.py index 197d8b28c..ae0eb42b2 100644 --- a/asdf/_core/_converters/ndarray.py +++ b/asdf/_core/_converters/ndarray.py @@ -1,3 +1,5 @@ +import collections.abc + import numpy as np from asdf.extension import Converter @@ -133,12 +135,12 @@ def from_yaml_tree(self, node, tag, ctx): from asdf.tags.core import NDArrayType from asdf.tags.core.ndarray import asdf_datatype_to_numpy_dtype - if isinstance(node, list): + if isinstance(node, collections.abc.Sequence) and not isinstance(node, str): instance = NDArrayType(node, None, None, None, None, None, None) ctx._blocks._set_array_storage(instance, "inline") return instance - if isinstance(node, dict): + if isinstance(node, collections.abc.Mapping): shape = node.get("shape", None) if "source" in node and "data" in node: msg = "Both source and data may not be provided at the same time" diff --git a/asdf/_lazy_nodes.py b/asdf/_lazy_nodes.py index 68626e241..6cb110979 100644 --- a/asdf/_lazy_nodes.py +++ b/asdf/_lazy_nodes.py @@ -1,4 +1,6 @@ import collections +import collections.abc +import copy import warnings from types import GeneratorType @@ -66,32 +68,46 @@ def __init__(self, data=None, af_ref=None): def tagged(self): return self.data + def copy(self): + return self.__class__(copy.copy(self.data), self._af_ref) -class AsdfListNode(AsdfNode, collections.UserList, list): + def __asdf_traverse__(self): + return self.data + + +class AsdfListNode(AsdfNode, collections.abc.MutableSequence): def __init__(self, data=None, af_ref=None): if data is None: data = [] AsdfNode.__init__(self, data, af_ref) - collections.UserList.__init__(self, data) - list.__init__(self, data) + + def __setitem__(self, index, value): + self.data.__setitem__(index, value) + + def __delitem__(self, index): + self.data.__delitem__(index) + + def __len__(self): + return self.data.__len__() + + def insert(self, index, value): + self.data.insert(index, value) def __eq__(self, other): if self is other: return True + if not isinstance(other, collections.abc.Sequence): + return False return list(self) == list(other) def __ne__(self, other): return not self.__eq__(other) - def __reduce__(self): - return collections.UserList.__reduce__(self) - def __getitem__(self, key): # key might be an int or slice - value = super().__getitem__(key) + value = self.data.__getitem__(key) if isinstance(key, slice): - value._af_ref = self._af_ref - return value + return AsdfListNode(value, self._af_ref) if isinstance(value, tagged.Tagged): value = _convert(value, self._af_ref) self[key] = value @@ -131,29 +147,36 @@ def __getitem__(self, key): return value -# dict is required here so TaggedDict doesn't convert this to a dict -# and so that json.dumps will work for this node TODO add test -class AsdfDictNode(AsdfNode, collections.UserDict, dict): +class AsdfDictNode(AsdfNode, collections.abc.MutableMapping): def __init__(self, data=None, af_ref=None): if data is None: data = {} AsdfNode.__init__(self, data, af_ref) - collections.UserDict.__init__(self, data) - dict.__init__(self, data) + + def __setitem__(self, index, value): + self.data.__setitem__(index, value) + + def __delitem__(self, index): + self.data.__delitem__(index) + + def __len__(self): + return self.data.__len__() + + def __iter__(self): + return self.data.__iter__() def __eq__(self, other): if self is other: return True + if not isinstance(other, collections.abc.Mapping): + return False return dict(self) == dict(other) def __ne__(self, other): return not self.__eq__(other) - def __reduce__(self): - return collections.UserDict.__reduce__(self) - def __getitem__(self, key): - value = super().__getitem__(key) + value = self.data.__getitem__(key) if isinstance(value, tagged.Tagged): value = _convert(value, self._af_ref) self[key] = value @@ -198,4 +221,3 @@ def __init__(self, data=None, af_ref=None): if data is None: data = collections.OrderedDict() AsdfDictNode.__init__(self, data, af_ref) - collections.OrderedDict.__init__(self, data) diff --git a/asdf/_node_info.py b/asdf/_node_info.py index e61b22241..dbd05b0ae 100644 --- a/asdf/_node_info.py +++ b/asdf/_node_info.py @@ -239,6 +239,8 @@ def set_schema_for_property(self, parent, identifier): def set_schema_from_node(self, node, extension_manager): """Pull a tagged schema for the node""" + if not hasattr(node, "_tag"): + return tag_def = extension_manager.get_tag_definition(node._tag) schema_uri = tag_def.schema_uris[0] schema = load_schema(schema_uri) diff --git a/asdf/_tests/_helpers.py b/asdf/_tests/_helpers.py index daa8085ad..e4aa1cd1e 100644 --- a/asdf/_tests/_helpers.py +++ b/asdf/_tests/_helpers.py @@ -1,3 +1,4 @@ +import collections.abc import io import os import warnings @@ -115,12 +116,17 @@ def recurse(old, new): else: getattr(old_type, funcname)(old, new) - elif isinstance(old, dict) and isinstance(new, dict): + elif isinstance(old, collections.abc.Mapping) and isinstance(new, collections.abc.Mapping): assert {x for x in old if x not in ignore_keys} == {x for x in new if x not in ignore_keys} for key in old: if key not in ignore_keys: recurse(old[key], new[key]) - elif isinstance(old, (list, tuple)) and isinstance(new, (list, tuple)): + elif ( + isinstance(old, collections.abc.Sequence) + and isinstance(new, collections.abc.Sequence) + and not isinstance(old, str) + and not isinstance(new, str) + ): assert len(old) == len(new) for a, b in zip(old, new): recurse(a, b) diff --git a/asdf/_tests/test_history.py b/asdf/_tests/test_history.py index 8da69a354..afbb4c45f 100644 --- a/asdf/_tests/test_history.py +++ b/asdf/_tests/test_history.py @@ -1,3 +1,4 @@ +import collections.abc import datetime import fractions import warnings @@ -53,7 +54,7 @@ def test_history_to_file(tmp_path): # Test the history entry retrieval API entries = ff.get_history_entries() assert len(entries) == 1 - assert isinstance(entries, list) + assert isinstance(entries, collections.abc.Sequence) and not isinstance(entries, str) assert isinstance(entries[0], HistoryEntry) assert entries[0]["description"] == "This happened" assert entries[0]["software"]["name"] == "my_tool" @@ -78,7 +79,7 @@ def test_old_history(): # Test the history entry retrieval API entries = af.get_history_entries() assert len(entries) == 1 - assert isinstance(entries, list) + assert isinstance(entries, collections.abc.Sequence) and not isinstance(entries, str) assert isinstance(entries[0], HistoryEntry) assert entries[0]["description"] == "Here's a test of old history entries" assert entries[0]["software"]["name"] == "foo" diff --git a/asdf/_tests/test_yaml.py b/asdf/_tests/test_yaml.py index a4f9f7e31..4c4f8787d 100644 --- a/asdf/_tests/test_yaml.py +++ b/asdf/_tests/test_yaml.py @@ -7,7 +7,7 @@ import yaml import asdf -from asdf import tagged, treeutil, yamlutil +from asdf import _lazy_nodes, tagged, treeutil, yamlutil from asdf.exceptions import AsdfConversionWarning, AsdfWarning from . import _helpers as helpers @@ -30,7 +30,7 @@ def check_asdf(asdf): assert list(tree["ordered_dict"].keys()) == ["first", "second", "third"] assert not isinstance(tree["unordered_dict"], OrderedDict) - assert isinstance(tree["unordered_dict"], dict) + assert isinstance(tree["unordered_dict"], (dict, _lazy_nodes.AsdfDictNode)) def check_raw_yaml(content): assert b"OrderedDict" not in content @@ -79,7 +79,7 @@ class Foo: def run_tuple_test(tree, tmp_path): def check_asdf(asdf): - assert isinstance(asdf.tree["val"], list) + assert isinstance(asdf.tree["val"], (list, _lazy_nodes.AsdfListNode)) def check_raw_yaml(content): assert b"tuple" not in content diff --git a/asdf/tags/core/ndarray.py b/asdf/tags/core/ndarray.py index b59e331aa..0c2490ea8 100644 --- a/asdf/tags/core/ndarray.py +++ b/asdf/tags/core/ndarray.py @@ -1,3 +1,4 @@ +import collections.abc import mmap import sys @@ -48,7 +49,8 @@ def asdf_datatype_to_numpy_dtype(datatype, byteorder=None): return np.dtype(str(byteorder + datatype)) if ( - isinstance(datatype, list) + isinstance(datatype, collections.abc.Sequence) + and not isinstance(datatype, str) and len(datatype) == 2 and isinstance(datatype[0], str) and isinstance(datatype[1], int) @@ -60,7 +62,7 @@ def asdf_datatype_to_numpy_dtype(datatype, byteorder=None): return np.dtype(datatype) - if isinstance(datatype, dict): + if isinstance(datatype, collections.abc.Mapping): if "datatype" not in datatype: msg = f"Field entry has no datatype: '{datatype}'" raise ValueError(msg) @@ -75,7 +77,7 @@ def asdf_datatype_to_numpy_dtype(datatype, byteorder=None): return (str(name), datatype, tuple(shape)) - if isinstance(datatype, list): + if isinstance(datatype, collections.abc.Sequence) and not isinstance(datatype, str): datatype_list = [] for subdatatype in datatype: np_dtype = asdf_datatype_to_numpy_dtype(subdatatype, byteorder) @@ -160,7 +162,7 @@ def inline_data_asarray(inline, dtype=None): if dtype is not None and dtype.fields is not None: def find_innermost_match(line, depth=0): - if not isinstance(line, list) or not len(line): + if not isinstance(line, collections.abc.Sequence) or not len(line): msg = "data can not be converted to structured array" raise ValueError(msg) try: @@ -183,7 +185,7 @@ def convert_to_tuples(line, data_depth, depth=0): return np.asarray(inline, dtype=dtype) def handle_mask(inline): - if isinstance(inline, list): + if isinstance(inline, collections.abc.Sequence) and not isinstance(inline, str): if None in inline: inline_array = np.asarray(inline) nones = np.equal(inline_array, None) @@ -207,7 +209,7 @@ def tolist(x): if isinstance(x, (np.ndarray, NDArrayType)): x = x.astype("U").tolist() if x.dtype.char == "S" else x.tolist() - if isinstance(x, (list, tuple)): + if isinstance(x, collections.abc.Sequence) and not isinstance(x, str): return [tolist(y) for y in x] return x @@ -233,7 +235,7 @@ def __init__(self, source, shape, dtype, offset, strides, order, mask, data_call self._array = None self._mask = mask - if isinstance(source, list): + if isinstance(source, collections.abc.Sequence) and not isinstance(source, str): self._array = inline_data_asarray(source, dtype) self._array = self._apply_mask(self._array, self._mask) # single element structured arrays can have shape == () @@ -338,6 +340,8 @@ def get_actual_shape(self, shape, strides, dtype, block_size): Get the actual shape of an array, by computing it against the block_size if it contains a ``*``. """ + if hasattr(shape, "data"): + shape = shape.data num_stars = shape.count("*") if num_stars == 0: return shape @@ -478,11 +482,11 @@ def operation(self, *args): def _get_ndim(instance): - if isinstance(instance, list): + if isinstance(instance, collections.abc.Sequence) and not isinstance(instance, str): array = inline_data_asarray(instance) return array.ndim - if isinstance(instance, dict): + if isinstance(instance, collections.abc.Mapping): if "shape" in instance: return len(instance["shape"]) @@ -514,10 +518,10 @@ def validate_max_ndim(validator, max_ndim, instance, schema): def validate_datatype(validator, datatype, instance, schema): - if isinstance(instance, list): + if isinstance(instance, collections.abc.Sequence): array = inline_data_asarray(instance) in_datatype, _ = numpy_dtype_to_asdf_datatype(array.dtype) - elif isinstance(instance, dict): + elif isinstance(instance, collections.abc.Mapping): if "datatype" in instance: in_datatype = instance["datatype"] elif "data" in instance: diff --git a/asdf/treeutil.py b/asdf/treeutil.py index 3923ccfd6..2a90b92d8 100644 --- a/asdf/treeutil.py +++ b/asdf/treeutil.py @@ -369,11 +369,11 @@ def _handle_immutable_sequence(node, json_id): return result def _handle_children(node, json_id): - if isinstance(node, dict): + if isinstance(node, (dict, _lazy_nodes.AsdfDictNode)): result = _handle_mapping(node, json_id) elif isinstance(node, tuple): result = _handle_immutable_sequence(node, json_id) - elif isinstance(node, list): + elif isinstance(node, (list, _lazy_nodes.AsdfListNode)): result = _handle_mutable_sequence(node, json_id) else: result = node