Skip to content

Commit

Permalink
add tree_type yaml
Browse files Browse the repository at this point in the history
  • Loading branch information
braingram committed Dec 8, 2023
1 parent 4f31ffc commit 4b1a625
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 32 deletions.
43 changes: 19 additions & 24 deletions asdf/_asdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import time
import warnings

import yaml
from packaging.version import Version

from . import _compression as mcompression
Expand All @@ -28,6 +29,8 @@
from .tags.core import AsdfObject, ExtensionMetadata, HistoryEntry, Software
from .util import NotSet

_yaml_base_loader = yaml.CBaseLoader if getattr(yaml, "__with_libyaml__", None) else yaml.BaseLoader


def __getattr__(name):
if name == "SerializationContext":
Expand Down Expand Up @@ -762,13 +765,22 @@ def _open_asdf(
fd,
validate_checksums=False,
extensions=None,
tree_type="custom",
_get_yaml_content=False,
tree_type=None,
_force_raw_types=NotSet,
strict_extension_check=False,
ignore_missing_extensions=False,
):
"""Attempt to populate AsdfFile data from file-like object"""
if tree_type is None:
tree_type = "custom"

if _force_raw_types is not NotSet:
warnings.warn("_force_raw_types is deprecated and will be replaced by tree_type", AsdfDeprecationWarning)
tree_type = "tagged" if _force_raw_types else tree_type

if tree_type not in ("custom", "tagged", "yaml"):
msg = f"Unsupported tree type {tree_type}"
raise ValueError(msg)

if strict_extension_check and ignore_missing_extensions:
msg = "'strict_extension_check' and 'ignore_missing_extensions' are incompatible options"
Expand Down Expand Up @@ -815,11 +827,10 @@ def _open_asdf(
initial_content=yaml_token,
)

# For testing: just return the raw YAML content
if _get_yaml_content:
yaml_content = reader.read()
fd.close()
return yaml_content
# just return the raw YAML content
if tree_type == "yaml":
self._tree = yaml.load(reader.read(), Loader=_yaml_base_loader) # noqa: S506
return self

# We parse the YAML content into basic data structures
# now, but we don't do anything special with it until
Expand Down Expand Up @@ -852,19 +863,6 @@ def _open_asdf(
self.close()
raise

if _force_raw_types is not NotSet:
warnings.warn(
"_force_raw_types is deprecated and will be replaced by tree_type", AsdfDeprecationWarning
)
tree_type = "tagged" if _force_raw_types else "custom"

if tree_type is None:
tree_type = "custom"

if tree_type not in ("custom", "tagged"):
msg = f"Unsupported tree type {tree_type}"
raise ValueError(msg)

if tree_type == "custom":
tree = yamlutil.tagged_tree_to_custom_tree(tree, self)

Expand All @@ -885,7 +883,6 @@ def _open_impl(
validate_checksums=False,
extensions=None,
tree_type=None,
_get_yaml_content=False,
_force_raw_types=NotSet,
strict_extension_check=False,
ignore_missing_extensions=False,
Expand All @@ -900,7 +897,6 @@ def _open_impl(
validate_checksums=validate_checksums,
extensions=extensions,
tree_type=tree_type,
_get_yaml_content=_get_yaml_content,
_force_raw_types=_force_raw_types,
strict_extension_check=strict_extension_check,
ignore_missing_extensions=ignore_missing_extensions,
Expand Down Expand Up @@ -1511,7 +1507,6 @@ def open_asdf(
custom_schema=None,
strict_extension_check=False,
ignore_missing_extensions=False,
_get_yaml_content=False,
):
"""
Open an existing ASDF file.
Expand Down Expand Up @@ -1584,6 +1579,7 @@ def open_asdf(
- "custom" (default) for a tree with custom objects
- "tagged" for a tree with `asdf.tagged` objects
- "yaml" return the raw yaml contents of just file
Returns
-------
Expand Down Expand Up @@ -1620,7 +1616,6 @@ def open_asdf(
validate_checksums=validate_checksums,
extensions=extensions,
tree_type=tree_type,
_get_yaml_content=_get_yaml_content,
_force_raw_types=_force_raw_types,
strict_extension_check=strict_extension_check,
ignore_missing_extensions=ignore_missing_extensions,
Expand Down
8 changes: 4 additions & 4 deletions asdf/_tests/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import asdf
from asdf import generic_io, versioning
from asdf._asdf import AsdfFile, _get_asdf_library_info
from asdf.constants import YAML_TAG_PREFIX
from asdf.constants import YAML_END_MARKER_REGEX, YAML_TAG_PREFIX
from asdf.exceptions import AsdfConversionWarning
from asdf.tags.core import AsdfObject
from asdf.versioning import (
Expand Down Expand Up @@ -219,13 +219,13 @@ def _assert_roundtrip_tree(
asdf_check_func(ff)

buff.seek(0)
ff = AsdfFile(extensions=extensions, **init_options)
content = AsdfFile._open_impl(ff, buff, mode="r", _get_yaml_content=True)
gf = generic_io.get_file(buff)
content = gf.reader_until(YAML_END_MARKER_REGEX, 7, "End of YAML marker", include=True).read()
buff.close()
# We *never* want to get any raw python objects out
assert b"!!python" not in content
assert b"!core/asdf" in content
assert content.startswith(b"%YAML 1.1")
assert b"%YAML 1.1" in content
if raw_yaml_check_func:
raw_yaml_check_func(content)

Expand Down
4 changes: 3 additions & 1 deletion asdf/_tests/test_asdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ def test_fsspec_http(httpserver):
assert_tree_match(tree, af.tree)


@pytest.mark.parametrize("tree_type", [None, "custom", "tagged", "unknown"])
@pytest.mark.parametrize("tree_type", [None, "custom", "tagged", "yaml", "unknown"])
def test_asdf_open_tree_type(tmp_path, tree_type):
fn = tmp_path / "test.asdf"
asdf.AsdfFile({"a": np.zeros(3)}).write_to(fn)
Expand All @@ -407,5 +407,7 @@ def test_asdf_open_tree_type(tmp_path, tree_type):
with asdf.open(fn, tree_type=tree_type) as af:
if tree_type == "tagged":
assert isinstance(af["a"], asdf.tagged.TaggedDict)
elif tree_type == "yaml":
assert isinstance(af["a"], dict)
else:
assert isinstance(af["a"], asdf.tags.core.ndarray.NDArrayType)
4 changes: 1 addition & 3 deletions asdf/_tests/test_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,7 @@ def test_internal_reference(tmp_path):
buff = io.BytesIO()
ff.write_to(buff)
buff.seek(0)
ff = asdf.AsdfFile()
content = asdf.AsdfFile()._open_impl(ff, buff, _get_yaml_content=True)
assert b"{$ref: ''}" in content
assert b"{$ref: ''}" in buff.read()


def test_implicit_internal_reference(tmp_path):
Expand Down

0 comments on commit 4b1a625

Please sign in to comment.