diff --git a/tests/test_reference_files.py b/tests/test_reference_files.py index a5dde810..6527453a 100644 --- a/tests/test_reference_files.py +++ b/tests/test_reference_files.py @@ -1,13 +1,44 @@ import os +import asdf +import numpy as np import pytest -from asdf import open as asdf_open -from asdf import versioning -from asdf._tests._helpers import assert_tree_match _REFFILE_PATH = os.path.join(os.path.dirname(__file__), "..", "reference_files") +def assert_tree_match(tree_a, tree_b): + seen = set() + + def walk_trees(a, b): + seen_key = (id(a), id(b)) + + if seen_key in seen: + return + + seen.add(seen_key) + + assert isinstance(a, type(b)) + if isinstance(a, dict): + assert set(a.keys()) == set(b.keys()) + for k in a: + walk_trees(a[k], b[k]) + elif isinstance(a, list): + assert len(a) == len(b) + for i in range(len(a)): + walk_trees(a[k], b[k]) + elif isinstance(a, (np.ndarray, asdf.tags.core.ndarray.NDArrayType)): + np.testing.assert_array_equal(a, b) + else: + assert a == b + + ignore_keys = {"asdf_library", "history"} + walk_trees( + {k: v for k, v in tree_a.items() if k not in ignore_keys}, + {k: v for k, v in tree_b.items() if k not in ignore_keys}, + ) + + def get_test_id(reference_file_path): """Helper function to return the informative part of a schema path""" path = os.path.normpath(str(reference_file_path)) @@ -16,7 +47,7 @@ def get_test_id(reference_file_path): def collect_reference_files(): """Function used by pytest to collect ASDF reference files for testing.""" - for version in versioning.supported_versions: + for version in asdf.versioning.supported_versions: version_dir = os.path.join(_REFFILE_PATH, str(version)) if os.path.exists(version_dir): for filename in os.listdir(version_dir): @@ -31,10 +62,11 @@ def _compare_trees(name_without_ext): asdf_path = name_without_ext + ".asdf" yaml_path = name_without_ext + ".yaml" - with asdf_open(asdf_path) as af_handle: + with asdf.open(asdf_path) as af_handle: af_handle.resolve_references() - with asdf_open(yaml_path) as ref: + with asdf.open(yaml_path) as ref: + # TODO strip asdf_library and history assert_tree_match(af_handle.tree, ref.tree)