From 752d96e813c218536238fa20368be14abd4272c5 Mon Sep 17 00:00:00 2001 From: Ed Slavich Date: Tue, 4 Aug 2020 13:16:22 -0400 Subject: [PATCH] Add Extension and Converter ABCs --- CHANGES.rst | 2 +- asdf/asdf.py | 56 +++- asdf/config.py | 10 +- asdf/entry_points.py | 4 +- asdf/extension/__init__.py | 11 +- asdf/extension/_converter.py | 358 ++++++++++++++++++++++++ asdf/extension/_extension.py | 142 +++++++++- asdf/extension/_legacy.py | 6 +- asdf/extension/_manager.py | 210 ++++++++++++++ asdf/extension/_tag.py | 76 +++++ asdf/fits_embed.py | 7 +- asdf/schema.py | 25 +- asdf/tests/test_asdf.py | 8 +- asdf/tests/test_entry_points.py | 49 ++++ asdf/tests/test_extension.py | 474 +++++++++++++++++++++++++++++++- asdf/tests/test_integration.py | 70 +++++ asdf/tests/test_resource.py | 10 + asdf/yamlutil.py | 64 ++++- 18 files changed, 1517 insertions(+), 65 deletions(-) create mode 100644 asdf/extension/_converter.py create mode 100644 asdf/extension/_manager.py create mode 100644 asdf/extension/_tag.py create mode 100644 asdf/tests/test_integration.py diff --git a/CHANGES.rst b/CHANGES.rst index 61ddf0b75..777b5bec8 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -17,7 +17,7 @@ - Drop support for Python 3.5. [#856] - Add new extension API to support versioned extensions. - [#850, #851] + [#850, #851, #853] - Permit wildcard in tag validator URIs. [#858] diff --git a/asdf/asdf.py b/asdf/asdf.py index a9777b568..5437762ef 100644 --- a/asdf/asdf.py +++ b/asdf/asdf.py @@ -27,6 +27,7 @@ AsdfExtension, ExtensionProxy, get_cached_asdf_extension_list, + get_cached_extension_manager, ) from .util import NotSet from .search import AsdfSearchResult @@ -71,9 +72,10 @@ def __init__(self, tree=None, uri=None, extensions=None, version=None, extensions : object, optional Additional extensions to use when reading and writing the file. 
- May be any of the following: `asdf.extension.AsdfExtension`, `str` - extension URI, `asdf.extension.AsdfExtensionList` or a `list` - of URIs and/or extensions. + May be any of the following: `asdf.extension.AsdfExtension`, + `asdf.extension.Extension`, `str` extension URI, + `asdf.extension.AsdfExtensionList` or a `list` of URIs and/or + extensions. version : str, optional The ASDF Standard version. If not provided, defaults to the @@ -184,7 +186,7 @@ def version(self): @version.setter def version(self, value): - """" + """ Set this AsdfFile's ASDF Standard version. Parameters @@ -219,7 +221,7 @@ def extensions(self): Returns ------- - list of asdf.extension.AsdfExtension + list of asdf.extension.AsdfExtension or asdf.extension.Extension """ return self._extensions @@ -231,15 +233,30 @@ def extensions(self, value): Parameters ---------- - value : list of asdf.extension.AsdfExtension + value : list of asdf.extension.AsdfExtension or asdf.extension.Extension """ self._extensions = self._process_extensions(value) + self._extension_manager = None self._extension_list = None + @property + def extension_manager(self): + """ + Get the ExtensionManager for this AsdfFile. + + Returns + ------- + asdf.extension.ExtensionManager + """ + if self._extension_manager is None: + self._extension_manager = get_cached_extension_manager(self.extensions) + return self._extension_manager + @property def extension_list(self): """ Get the AsdfExtensionList for this AsdfFile. + Returns ------- asdf.extension.AsdfExtensionList @@ -1549,7 +1566,7 @@ def _warn_tag_mismatch(self, tag, best_tag): # This function is called from within yamlutil methods to create # a context when one isn't explicitly passed in. 
def _create_serialization_context(self): - return SerializationContext(self.version_string) + return SerializationContext(self.version_string, self.extension_manager) # Inherit docstring from dictionary @@ -1605,9 +1622,10 @@ def open_asdf(fd, uri=None, mode=None, validate_checksums=False, extensions : object, optional Additional extensions to use when reading and writing the file. - May be any of the following: `asdf.extension.AsdfExtension`, `str` - extension URI, `asdf.extension.AsdfExtensionList` or a `list` - of URIs and/or extensions. + May be any of the following: `asdf.extension.AsdfExtension`, + `asdf.extension.Extension`, `str` extension URI, + `asdf.extension.AsdfExtensionList` or a `list` of URIs and/or + extensions. do_not_fill_defaults : bool, optional When `True`, do not fill in missing default values. @@ -1729,8 +1747,9 @@ class SerializationContext: """ Container for parameters of the current (de)serialization. """ - def __init__(self, version): + def __init__(self, version, extension_manager): self._version = validate_version(version) + self._extension_manager = extension_manager self._extensions_used = set() @@ -1745,13 +1764,24 @@ def version(self): """ return self._version + @property + def extension_manager(self): + """ + Get the ExtensionManager for enabled extensions. + + Returns + ------- + asdf.extension.ExtensionManager + """ + return self._extension_manager + def mark_extension_used(self, extension): """ Note that an extension was used when reading or writing the file. 
Parameters ---------- - extension : asdf.extension.AsdfExtension + extension : asdf.extension.AsdfExtension or asdf.extension.Extension """ self._extensions_used.add(ExtensionProxy.maybe_wrap(extension)) @@ -1762,6 +1792,6 @@ def extensions_used(self): Returns ------- - set of asdf.extension.AsdfExtension + set of asdf.extension.AsdfExtension or asdf.extension.Extension """ return self._extensions_used diff --git a/asdf/config.py b/asdf/config.py index 31d82cf29..48c137067 100644 --- a/asdf/config.py +++ b/asdf/config.py @@ -139,7 +139,7 @@ def extensions(self): Returns ------- - list of asdf.extension.AsdfExtension + list of asdf.extension.AsdfExtension or asdf.extension.Extension """ if self._extensions is None: with self._lock: @@ -154,7 +154,7 @@ def add_extension(self, extension): Parameters ---------- - extension : asdf.extension.AsdfExtension + extension : asdf.extension.AsdfExtension or asdf.extension.Extension """ with self._lock: extension = ExtensionProxy.maybe_wrap(extension) @@ -166,8 +166,8 @@ def remove_extension(self, extension=None, *, package=None): Parameters ---------- - extension : asdf.extension.AsdfExtension or str, optional - An extension instance or URI or URI pattern to remove. + extension : asdf.extension.AsdfExtension or asdf.extension.Extension or str, optional + An extension instance or URI pattern to remove. package : str, optional Remove only extensions provided by this package. 
If the `extension` argument is omitted, then all extensions from this package will @@ -206,7 +206,7 @@ def get_extension(self, extension_uri): Returns ------- - asdf.extension.AsdfExtension + asdf.extension.AsdfExtension or asdf.extension.Extension Raises ------ diff --git a/asdf/entry_points.py b/asdf/entry_points.py index 60652ff6d..959626d51 100644 --- a/asdf/entry_points.py +++ b/asdf/entry_points.py @@ -7,6 +7,7 @@ RESOURCE_MAPPINGS_GROUP = "asdf.resource_mappings" +EXTENSIONS_GROUP = "asdf.extensions" LEGACY_EXTENSIONS_GROUP = "asdf_extensions" @@ -15,9 +16,10 @@ def get_resource_mappings(): def get_extensions(): + extensions = _list_entry_points(EXTENSIONS_GROUP, ExtensionProxy) legacy_extensions = _list_entry_points(LEGACY_EXTENSIONS_GROUP, ExtensionProxy) - return legacy_extensions + return extensions + legacy_extensions def _list_entry_points(group, proxy_class): diff --git a/asdf/extension/__init__.py b/asdf/extension/__init__.py index c82cb414c..ab7852825 100644 --- a/asdf/extension/__init__.py +++ b/asdf/extension/__init__.py @@ -2,7 +2,10 @@ Support for plugins that extend asdf to serialize additional custom types. """ -from ._extension import ExtensionProxy +from ._extension import Extension, ExtensionProxy +from ._manager import ExtensionManager, get_cached_extension_manager +from ._tag import TagDefinition +from ._converter import Converter, ConverterProxy from ._legacy import ( AsdfExtension, AsdfExtensionList, @@ -15,7 +18,13 @@ __all__ = [ # New API + "Extension", "ExtensionProxy", + "ExtensionManager", + "get_cached_extension_manager", + "TagDefinition", + "Converter", + "ConverterProxy", # Legacy API "AsdfExtension", "AsdfExtensionList", diff --git a/asdf/extension/_converter.py b/asdf/extension/_converter.py new file mode 100644 index 000000000..8bb8beda4 --- /dev/null +++ b/asdf/extension/_converter.py @@ -0,0 +1,358 @@ +""" +Support for Converter, the new API for serializing custom +types. Will eventually replace the `asdf.types` module. 
+""" +import abc + +from ..util import get_class_name, uri_match +from ._tag import TagDefinition + + +class Converter(abc.ABC): + """ + Abstract base class for plugins that convert nodes from the + parsed YAML tree into custom objects, and vice versa. + + Implementing classes must provide the `tags` and `types` + properties and `to_yaml_tree` and `from_yaml_tree` methods. + The `select_tag` method is optional. + """ + @classmethod + def __subclasshook__(cls, C): + if cls is Converter: + return (hasattr(C, "tags") and + hasattr(C, "types") and + hasattr(C, "to_yaml_tree") and + hasattr(C, "from_yaml_tree")) + return NotImplemented # pragma: no cover + + @abc.abstractproperty + def tags(self): + """ + Get the YAML tags that this converter is capable of + handling. URI patterns are permitted, see + `asdf.util.uri_match` for details. + + Returns + ------- + iterable of str + Tag URIs or URI patterns. + """ + pass # pragma: no cover + + @abc.abstractproperty + def types(self): + """ + Get the Python types that this converter is capable of + handling. + + Returns + ------- + iterable of str or type + If str, the fully qualified class name of the type. + """ + pass # pragma: no cover + + def select_tag(self, obj, tags, ctx): + """ + Select the tag to use when converting an object to YAML. + Typically only one tag will be active in a given context, but + converters that map one type to many tags must provide logic + to choose the appropriate tag. + + Parameters + ---------- + obj : object + Instance of the custom type being converted. Guaranteed + to be an instance of one of the types listed in the + `types` property. + tags : list of str + List of active tags to choose from. Guaranteed to match + one of the tag patterns listed in the 'tags' property. + ctx : asdf.asdf.SerializationContext + Context of the current serialization request. + + Returns + ------- + str + The selected tag. Should be one of the tags passed + to this method in the `tags` parameter. 
+ """ + return tags[0] + + @abc.abstractmethod + def to_yaml_tree(self, obj, tag, ctx): + """ + Convert an object into a node suitable for YAML serialization. + This method is not responsible for writing actual YAML; rather, it + converts an instance of a custom type to a built-in Python object type + (such as dict, list, str, or number), which can then be automatically + serialized to YAML as needed. + + For container types returned by this method (dict or list), + the children of the container need not themselves be converted. + Any list elements or dict values will be converted by subsequent + calls to to_yaml_tree implementations. + + The returned node must be an instance of `dict`, `list`, or `str`. + Children may be any type supported by an available Converter. + + Parameters + ---------- + obj : object + Instance of a custom type to be serialized. Guaranteed to + be an instance of one of the types listed in the `types` + property. + tag : str + The tag identifying the YAML type that `obj` should be + converted into. Selected by a call to this converter's + select_tag method. + ctx : asdf.asdf.SerializationContext + The context of the current serialization request. + + Returns + ------- + dict or list or str + The YAML node representation of the object. + """ + pass # pragma: no cover + + @abc.abstractmethod + def from_yaml_tree(self, node, tag, ctx): + """ + Convert a YAML node into an instance of a custom type. + + For container types received by this method (dict or list), + the children of the container will have already been converted + by prior calls to from_yaml_tree implementations. + + Note on circular references: trees that reference themselves + among their descendants must be handled with care. Most + implementations need not concern themselves with this case, but + if the custom type supports circular references, then the + implementation of this method will need to return a generator. + Consult the documentation for more details. 
+ + Parameters + ---------- + node : dict or list or str + The YAML node to convert. + tag : str + The YAML tag of the object being converted. + ctx : asdf.asdf.SerializationContext + The context of the current deserialization request. + + Returns + ------- + object + An instance of one of the types listed in the `types` property, + or a generator that yields such an instance. + """ + pass # pragma: no cover + + +class ConverterProxy(Converter): + """ + Proxy that wraps a `Converter` and provides default + implementations of optional methods. + """ + def __init__(self, delegate, extension): + if not isinstance(delegate, Converter): + raise TypeError("Converter must implement the asdf.extension.Converter interface") + + self._delegate = delegate + self._extension = extension + self._class_name = get_class_name(delegate) + + # Sort these out up-front so that errors are raised when the extension is loaded + # and not in the middle of the user's session. The extension will fail to load + # and a warning will be emitted, but it won't crash the program. + + relevant_tags = set() + for tag in delegate.tags: + if isinstance(tag, str): + relevant_tags.update(t.tag_uri for t in extension.tags if uri_match(tag, t.tag_uri)) + else: + raise TypeError("Converter property 'tags' must contain str values") + + if len(relevant_tags) > 1 and not hasattr(delegate, "select_tag"): + raise RuntimeError( + "Converter handles multiple tags for this extension, " + "but does not implement a select_tag method." + ) + + self._tags = sorted(relevant_tags) + + self._types = [] + for typ in delegate.types: + if isinstance(typ, (str, type)): + self._types.append(typ) + else: + raise TypeError("Converter property 'types' must contain str or type values") + + @property + def tags(self): + """ + Get the list of tag URIs that this converter is capable of + handling. 
+ + Returns + ------- + list of str + """ + return self._tags + + @property + def types(self): + """ + Get the Python types that this converter is capable of + handling. + + Returns + ------- + list of type or str + """ + return self._types + + def select_tag(self, obj, ctx): + """ + Select the tag to use when converting an object to YAML. + + Parameters + ---------- + obj : object + Instance of the custom type being converted. + ctx : asdf.asdf.SerializationContext + Serialization parameters. + + Returns + ------- + str + Selected tag. + """ + method = getattr(self._delegate, "select_tag", None) + if method is None: + return self._tags[0] + else: + return method(obj, self._tags, ctx) + + def to_yaml_tree(self, obj, tag, ctx): + """ + Convert an object into a node suitable for YAML serialization. + + Parameters + ---------- + obj : object + Instance of a custom type to be serialized. + tag : str + The tag identifying the YAML type that `obj` should be + converted into. + ctx : asdf.asdf.SerializationContext + Serialization parameters. + + Returns + ------- + object + The YAML node representation of the object. + """ + return self._delegate.to_yaml_tree(obj, tag, ctx) + + def from_yaml_tree(self, node, tag, ctx): + """ + Convert a YAML node into an instance of a custom type. + + Parameters + ---------- + node : dict or list or str + The YAML node to convert. + tag : str + The YAML tag of the object being converted. + ctx : asdf.asdf.SerializationContext + Serialization parameters. + + Returns + ------- + object + """ + return self._delegate.from_yaml_tree(node, tag, ctx) + + @property + def delegate(self): + """ + Get the wrapped converter instance. + + Returns + ------- + asdf.extension.Converter + """ + return self._delegate + + @property + def extension(self): + """ + Get the extension that provided this converter. 
+ + Returns + ------- + asdf.extension.Extension + """ + return self._extension + + @property + def package_name(self): + """ + Get the name of the Python package of this converter's + extension. This may not be the same package that implements + the converter's class. + + Returns + ------- + str or None + Package name, or `None` if the extension was added at runtime. + """ + return self.extension.package_name + + @property + def package_version(self): + """ + Get the version of the Python package of this converter's + extension. This may not be the same package that implements + the converter's class. + + Returns + ------- + str or None + Package version, or `None` if the extension was added at runtime. + """ + return self.extension.package_version + + @property + def class_name(self): + """ + Get the fully qualified class name of this converter. + + Returns + ------- + str + """ + return self._class_name + + def __eq__(self, other): + if isinstance(other, ConverterProxy): + return other.delegate is self.delegate and other.extension is self.extension + else: + return False + + def __hash__(self): + return hash((id(self.delegate), id(self.extension))) + + def __repr__(self): + if self.package_name is None: + package_description = "(none)" + else: + package_description = "{}=={}".format(self.package_name, self.package_version) + + return "<ConverterProxy class: {} package: {}>".format( + self.class_name, + package_description, + ) diff --git a/asdf/extension/_extension.py b/asdf/extension/_extension.py index 8ce9e2564..b0a838e8d 100644 --- a/asdf/extension/_extension.py +++ b/asdf/extension/_extension.py @@ -1,10 +1,91 @@ +import abc + from packaging.specifiers import SpecifierSet from ..util import get_class_name +from ._tag import TagDefinition from ._legacy import AsdfExtension +from ._converter import ConverterProxy + + +class Extension(abc.ABC): + """ + Abstract base class defining an extension to ASDF. + + Implementing classes must provide the `extension_uri`. + Other properties are optional. 
+ """ + @classmethod + def __subclasshook__(cls, C): + if cls is Extension: + return hasattr(C, "extension_uri") + return NotImplemented # pragma: no cover + + @abc.abstractproperty + def extension_uri(self): + """ + Get the URI of the extension to the ASDF Standard implemented + by this class. Note that this may not uniquely identify the + class itself. + + Returns + ------- + str + """ + pass # pragma: no cover + + @property + def legacy_class_names(self): + """ + Get the set of fully-qualified class names used by older + versions of this extension. This allows a new-style + implementation of an extension to prevent warnings when a + legacy extension is missing. + + Returns + ------- + iterable of str + """ + return set() + + @property + def asdf_standard_requirement(self): + """ + Get the ASDF Standard version requirement for this extension. + + Returns + ------- + str or None + If str, PEP 440 version specifier. + If None, support all versions. + """ + return None + @property + def converters(self): + """ + Get the `asdf.extension.Converter` instances for tags + and Python types supported by this extension. + + Returns + ------- + iterable of asdf.extension.Converter + """ + return [] -class ExtensionProxy(AsdfExtension): + @property + def tags(self): + """ + Get the YAML tags supported by this extension. 
+ + Returns + ------- + iterable of str or asdf.extension.TagDefinition + """ + return [] + + +class ExtensionProxy(Extension, AsdfExtension): """ Proxy that wraps an extension, provides default implementations of optional methods, and carries additional information on the @@ -18,8 +99,10 @@ def maybe_wrap(self, delegate): return ExtensionProxy(delegate) def __init__(self, delegate, package_name=None, package_version=None): - if not isinstance(delegate, AsdfExtension): - raise TypeError("Extension must implement the AsdfExtension interface") + if not isinstance(delegate, (Extension, AsdfExtension)): + raise TypeError( + "Extension must implement the Extension or AsdfExtension interface" + ) self._delegate = delegate self._package_name = package_name @@ -27,7 +110,11 @@ def __init__(self, delegate, package_name=None, package_version=None): self._class_name = get_class_name(delegate) - self._legacy = True + self._legacy = isinstance(delegate, AsdfExtension) + + # Sort these out up-front so that errors are raised when the extension is loaded + # and not in the middle of the user's session. The extension will fail to load + # and a warning will be emitted, but it won't crash the program. 
self._legacy_class_names = set() for class_name in getattr(self._delegate, "legacy_class_names", []): @@ -36,6 +123,9 @@ def __init__(self, delegate, package_name=None, package_version=None): else: raise TypeError("Extension property 'legacy_class_names' must contain str values") + if self._legacy: + self._legacy_class_names.add(self._class_name) + value = getattr(self._delegate, "asdf_standard_requirement", None) if isinstance(value, str): self._asdf_standard_requirement = SpecifierSet(value) @@ -44,13 +134,25 @@ def __init__(self, delegate, package_name=None, package_version=None): else: raise TypeError("Extension property 'asdf_standard_requirement' must be str or None") - if self._legacy: - self._legacy_class_names.add(self._class_name) + self._tags = [] + for tag in getattr(self._delegate, "tags", []): + if isinstance(tag, str): + self._tags.append(TagDefinition(tag)) + elif isinstance(tag, TagDefinition): + self._tags.append(tag) + else: + raise TypeError("Extension property 'tags' must contain str or asdf.extension.TagDefinition values") + + # Process the converters last, since they expect ExtensionProxy + # properties to already be available. + self._converters = [ConverterProxy(c, self) for c in getattr(self._delegate, "converters", [])] @property def extension_uri(self): """ - Get the extension's identifying URI. + Get the URI of the extension to the ASDF Standard implemented + by this class. Note that this may not uniquely identify the + class itself. Returns ------- @@ -68,7 +170,7 @@ def legacy_class_names(self): Returns ------- - iterable of str + set of str """ return self._legacy_class_names @@ -83,6 +185,28 @@ def asdf_standard_requirement(self): """ return self._asdf_standard_requirement + @property + def converters(self): + """ + Get the extension's converters. + + Returns + ------- + list of asdf.extension.Converter + """ + return self._converters + + @property + def tags(self): + """ + Get the YAML tags supported by this extension. 
+ + Returns + ------- + list of asdf.extension.TagDefinition + """ + return self._tags + @property def types(self): """ @@ -123,7 +247,7 @@ def delegate(self): Returns ------- - asdf.extension.AsdfExtension + asdf.extension.Extension or asdf.extension.AsdfExtension """ return self._delegate diff --git a/asdf/extension/_legacy.py b/asdf/extension/_legacy.py index 6cf7faa6f..55f41cde0 100644 --- a/asdf/extension/_legacy.py +++ b/asdf/extension/_legacy.py @@ -8,12 +8,10 @@ from ..exceptions import AsdfDeprecationWarning -ASDF_TEST_BUILD_ENV = 'ASDF_TEST_BUILD' - - class AsdfExtension(metaclass=abc.ABCMeta): """ - Abstract base class defining an extension to ASDF. + Abstract base class defining a (legacy) extension to ASDF. + New code should use `asdf.extension.Extension` instead. """ @classmethod def __subclasshook__(cls, C): diff --git a/asdf/extension/_manager.py b/asdf/extension/_manager.py new file mode 100644 index 000000000..525943695 --- /dev/null +++ b/asdf/extension/_manager.py @@ -0,0 +1,210 @@ +from functools import lru_cache + +from ._extension import ExtensionProxy +from ..util import get_class_name + + +class ExtensionManager: + """ + Wraps a list of extensions and indexes their converters + by tag and by Python type. + + Parameters + ---------- + extensions : iterable of asdf.extension.Extension + List of enabled extensions to manage. Extensions placed earlier + in the list take precedence. 
+ """ + def __init__(self, extensions): + self._extensions = [ExtensionProxy.maybe_wrap(e) for e in extensions] + + self._tag_defs_by_tag = {} + self._converters_by_tag = {} + # This dict has both str and type keys: + self._converters_by_type = {} + + for extension in self._extensions: + for tag_def in extension.tags: + if tag_def.tag_uri not in self._tag_defs_by_tag: + self._tag_defs_by_tag[tag_def.tag_uri] = tag_def + for converter in extension.converters: + # If a converter's tags do not actually overlap with + # the extension tag list, then there's no reason to + # use it. + if len(converter.tags) > 0: + for tag in converter.tags: + if tag not in self._converters_by_tag: + self._converters_by_tag[tag] = converter + for typ in converter.types: + if isinstance(typ, str): + if typ not in self._converters_by_type: + self._converters_by_type[typ] = converter + else: + type_class_name = get_class_name(typ, instance=False) + if typ not in self._converters_by_type and type_class_name not in self._converters_by_type: + self._converters_by_type[typ] = converter + self._converters_by_type[type_class_name] = converter + + @property + def extensions(self): + """ + Get the list of extensions. + + Returns + ------- + list of asdf.extension.Extension + """ + return self._extensions + + def handles_tag(self, tag): + """ + Return `True` if the specified tag is handled by a + converter. + + Parameters + ---------- + tag : str + Tag URI. + + Returns + ------- + bool + """ + return tag in self._converters_by_tag + + def handles_type(self, typ): + """ + Returns `True` if the specified Python type is handled + by a converter. + + Parameters + ---------- + typ : type + + Returns + ------- + bool + """ + return ( + typ in self._converters_by_type + or get_class_name(typ, instance=False) in self._converters_by_type + ) + + def get_tag_definition(self, tag): + """ + Get the tag definition for the specified tag. + + Parameters + ---------- + tag : str + Tag URI. 
+ + Returns + ------- + asdf.extension.TagDefinition + + Raises + ------ + KeyError + Unrecognized tag URI. + """ + try: + return self._tag_defs_by_tag[tag] + except KeyError: + raise KeyError( + "No support available for YAML tag '{}'. " + "You may need to install a missing extension.".format( + tag + ) + ) from None + + def get_converter_for_tag(self, tag): + """ + Get the converter for the specified tag. + + Parameters + ---------- + tag : str + Tag URI. + + Returns + ------- + asdf.extension.Converter + + Raises + ------ + KeyError + Unrecognized tag URI. + """ + try: + return self._converters_by_tag[tag] + except KeyError: + raise KeyError( + "No support available for YAML tag '{}'. " + "You may need to install a missing extension.".format( + tag + ) + ) from None + + def get_converter_for_type(self, typ): + """ + Get the converter for the specified Python type. + + Parameters + ---------- + typ : type + + Returns + ------- + asdf.extension.Converter + + Raises + ------ + KeyError + Unrecognized type. + """ + try: + return self._converters_by_type[typ] + except KeyError: + class_name = get_class_name(typ, instance=False) + try: + return self._converters_by_type[class_name] + except KeyError: + raise KeyError( + "No support available for Python type '{}'. " + "You may need to install or enable an extension.".format( + get_class_name(typ, instance=False) + ) + ) from None + + +def get_cached_extension_manager(extensions): + """ + Get a previously created ExtensionManager for the specified + extensions, or create and cache one if necessary. Building + the manager is expensive, so it helps performance to reuse + it when possible. + + Parameters + ---------- + extensions : list of asdf.extension.AsdfExtension or asdf.extension.Extension + + Returns + ------- + asdf.extension.ExtensionManager + """ + from ._extension import ExtensionProxy + # The tuple makes the extensions hashable so that we + # can pass them to the lru_cache method. 
The ExtensionProxy + # overrides __hash__ to return the hashed object id of the wrapped + # extension, so this method will only return the same + # ExtensionManager if the list contains identical extension + # instances in identical order. + extensions = tuple(ExtensionProxy.maybe_wrap(e) for e in extensions) + + return _get_cached_extension_manager(extensions) + + +@lru_cache() +def _get_cached_extension_manager(extensions): + return ExtensionManager(extensions) diff --git a/asdf/extension/_tag.py b/asdf/extension/_tag.py new file mode 100644 index 000000000..1d6ea7186 --- /dev/null +++ b/asdf/extension/_tag.py @@ -0,0 +1,76 @@ +""" +TagDefinition is defined here to avoid a circular reference between +_extension.py and _converter.py. +""" +class TagDefinition: + """ + Container for properties of a custom YAML tag. + + Parameters + ---------- + tag_uri : str + Tag URI. + schema_uri : str, optional + URI of the schema that should be used to validate objects + with this tag. + title : str, optional + Short description of the tag. + description : str, optional + Long description of the tag. + """ + def __init__(self, tag_uri, *, schema_uri=None, title=None, description=None): + if "*" in tag_uri: + raise ValueError("URI patterns are not permitted in TagDefinition") + + self._tag_uri = tag_uri + self._schema_uri = schema_uri + self._title = title + self._description = description + + @property + def tag_uri(self): + """ + Get the tag URI. + + Returns + ------- + str + """ + return self._tag_uri + + @property + def schema_uri(self): + """ + Get the URI of the schema that should be used to validate + objects with this tag. + + Returns + ------- + str or None + """ + return self._schema_uri + + @property + def title(self): + """ + Get the short description of the tag. + + Returns + ------- + str or None + """ + return self._title + + @property + def description(self): + """ + Get the long description of the tag. 
+ + Returns + ------- + str or None + """ + return self._description + + def __repr__(self): + return ("<TagDefinition URI: {}>".format(self.tag_uri)) diff --git a/asdf/fits_embed.py b/asdf/fits_embed.py index fc4772409..485db6fea 100644 --- a/asdf/fits_embed.py +++ b/asdf/fits_embed.py @@ -198,9 +198,10 @@ def open(cls, fd, uri=None, validate_checksums=False, extensions=None, extensions : object, optional Additional extensions to use when reading and writing the file. - May be any of the following: `asdf.extension.AsdfExtension`, `str` - extension URI, `asdf.extension.AsdfExtensionList` or a `list` - of URIs and/or extensions. + May be any of the following: `asdf.extension.AsdfExtension`, + `asdf.extension.Extension`, `str` extension URI, + `asdf.extension.AsdfExtensionList` or a `list` of URIs and/or + extensions. ignore_version_mismatch : bool, optional When `True`, do not raise warnings for mismatched schema versions. diff --git a/asdf/schema.py b/asdf/schema.py index ffe5432a8..174db1a27 100644 --- a/asdf/schema.py +++ b/asdf/schema.py @@ -303,19 +303,27 @@ def iter_errors(self, instance, _schema=None): if _schema is None: tag = getattr(instance, '_tag', None) if tag is not None: - schema_path = self.ctx.resolver(tag) - if schema_path != tag: + if self.serialization_context.extension_manager.handles_tag(tag): + tag_def = self.serialization_context.extension_manager.get_tag_definition(tag) + schema_uri = tag_def.schema_uri + else: + schema_uri = self.ctx.tag_mapping(tag) + if schema_uri == tag: + schema_uri = None + + if schema_uri is not None: try: - s = _load_schema_cached(schema_path, self.ctx.resolver, False, False) + s = _load_schema_cached(schema_uri, self.ctx.resolver, False, False) except FileNotFoundError: msg = "Unable to locate schema file for '{}': '{}'" - warnings.warn(msg.format(tag, schema_path), AsdfWarning) + warnings.warn(msg.format(tag, schema_uri), AsdfWarning) s = {} if s: - with self.resolver.in_scope(schema_path): + with self.resolver.in_scope(schema_uri): for 
x in super(ASDFValidator, self).iter_errors(instance, s): yield x + if isinstance(instance, dict): for val in instance.values(): for x in self.iter_errors(val): @@ -492,7 +500,8 @@ def resolve_refs(node, json_id): def get_validator(schema={}, ctx=None, validators=None, url_mapping=None, - *args, _visit_repeat_nodes=False, **kwargs): + *args, _visit_repeat_nodes=False, _serialization_context=None, + **kwargs): """ Get a JSON schema validator object for the given schema. @@ -530,6 +539,9 @@ def get_validator(schema={}, ctx=None, validators=None, url_mapping=None, from .asdf import AsdfFile ctx = AsdfFile() + if _serialization_context is None: + _serialization_context = ctx._create_serialization_context() + if validators is None: validators = util.HashableDict(YAML_VALIDATORS.copy()) validators.update(ctx.extension_list.validators) @@ -544,6 +556,7 @@ def get_validator(schema={}, ctx=None, validators=None, url_mapping=None, cls = _create_validator(validators=validators, visit_repeat_nodes=_visit_repeat_nodes) validator = cls(schema, *args, **kwargs) validator.ctx = ctx + validator.serialization_context = _serialization_context return validator diff --git a/asdf/tests/test_asdf.py b/asdf/tests/test_asdf.py index 54e96d3a1..500d9d1e7 100644 --- a/asdf/tests/test_asdf.py +++ b/asdf/tests/test_asdf.py @@ -4,7 +4,7 @@ from asdf import config_context, get_config from asdf.versioning import AsdfVersion from asdf.exceptions import AsdfWarning -from asdf.extension import ExtensionProxy, AsdfExtensionList +from asdf.extension import ExtensionProxy, AsdfExtensionList, ExtensionManager from asdf.tests.helpers import yaml_to_asdf, assert_no_warnings @@ -211,8 +211,10 @@ def test_open_asdf_extensions(tmpdir): def test_serialization_context(): - context = SerializationContext("1.4.0") + extension_manager = ExtensionManager([]) + context = SerializationContext("1.4.0", extension_manager) assert context.version == "1.4.0" + assert context.extension_manager is extension_manager assert 
context.extensions_used == set() extension = get_config().extensions[0] @@ -227,7 +229,7 @@ def test_serialization_context(): context.mark_extension_used(object()) with pytest.raises(ValueError): - SerializationContext("0.5.4") + SerializationContext("0.5.4", extension_manager) def test_reading_extension_metadata(): diff --git a/asdf/tests/test_entry_points.py b/asdf/tests/test_entry_points.py index f7f3f2271..ca5da7d1b 100644 --- a/asdf/tests/test_entry_points.py +++ b/asdf/tests/test_entry_points.py @@ -70,6 +70,34 @@ def test_get_resource_mappings(mock_entry_points): assert len(mappings) == 2 +class MinimumExtension: + def __init__(self, extension_uri): + self._extension_uri = extension_uri + + @property + def extension_uri(self): + return self._extension_uri + + +def extensions_entry_point_successful(): + return [ + MinimumExtension("http://somewhere.org/extensions/foo-1.0"), + MinimumExtension("http://somewhere.org/extensions/bar-1.0"), + ] + + +def extensions_entry_point_failing(): + raise Exception("NOPE") + + +def extensions_entry_point_bad_element(): + return [ + MinimumExtension("http://somewhere.org/extensions/baz-1.0"), + object(), + MinimumExtension("http://somewhere.org/extensions/foz-1.0"), + ] + + class LegacyExtension: types = [] tag_mapping = [] @@ -81,6 +109,27 @@ class FauxLegacyExtension: def test_get_extensions(mock_entry_points): + mock_entry_points.append(("asdf.extensions", "successful", "extensions_entry_point_successful")) + extensions = entry_points.get_extensions() + assert len(extensions) == 2 + for e in extensions: + assert isinstance(e, ExtensionProxy) + assert e.package_name == "asdf" + assert e.package_version == asdf_package_version + + mock_entry_points.clear() + mock_entry_points.append(("asdf.extensions", "failing", "extensions_entry_point_failing")) + with pytest.warns(AsdfWarning, match="Exception: NOPE"): + extensions = entry_points.get_extensions() + assert len(extensions) == 0 + + mock_entry_points.clear() + 
mock_entry_points.append(("asdf.extensions", "bad_element", "extensions_entry_point_bad_element")) + with pytest.warns(AsdfWarning, match="TypeError: Extension must implement the Extension or AsdfExtension interface"): + extensions = entry_points.get_extensions() + assert len(extensions) == 2 + + mock_entry_points.clear() mock_entry_points.append(("asdf_extensions", "legacy", "LegacyExtension")) extensions = entry_points.get_extensions() assert len(extensions) == 1 diff --git a/asdf/tests/test_extension.py b/asdf/tests/test_extension.py index 2c0c0db7c..8e0aff646 100644 --- a/asdf/tests/test_extension.py +++ b/asdf/tests/test_extension.py @@ -2,9 +2,16 @@ from packaging.specifiers import SpecifierSet from asdf.extension import ( - BuiltinExtension, + Extension, ExtensionProxy, - get_cached_asdf_extension_list, + ExtensionManager, + get_cached_extension_manager, + TagDefinition, + Converter, + ConverterProxy, + AsdfExtension, + BuiltinExtension, + get_cached_asdf_extension_list ) from asdf.types import CustomType @@ -28,8 +35,93 @@ class LegacyExtension: url_mapping = [("http://somewhere.org/", "http://somewhere.org/{url_suffix}.yaml")] -def test_proxy_maybe_wrap(): - extension = LegacyExtension() +class MinimumExtension: + extension_uri = "asdf://somewhere.org/extensions/minimum-1.0" + + +class MinimumExtensionSubclassed(Extension): + extension_uri = "asdf://somewhere.org/extensions/minimum-1.0" + + +class FullExtension: + extension_uri = "asdf://somewhere.org/extensions/full-1.0" + + def __init__( + self, + converters=None, + asdf_standard_requirement=None, + tags=None, + legacy_class_names=None, + ): + self._converters = [] if converters is None else converters + self._asdf_standard_requirement = asdf_standard_requirement + self._tags = tags + self._legacy_class_names = [] if legacy_class_names is None else legacy_class_names + + @property + def converters(self): + return self._converters + + @property + def asdf_standard_requirement(self): + return 
self._asdf_standard_requirement + + @property + def tags(self): + return self._tags + + @property + def legacy_class_names(self): + return self._legacy_class_names + + +class MinimumConverter: + def __init__(self, tags=None, types=None): + if tags is None: + self._tags = [] + else: + self._tags = tags + + if types is None: + self._types = [] + else: + self._types = types + + @property + def tags(self): + return self._tags + + @property + def types(self): + return self._types + + def to_yaml_tree(self, obj, tag, ctx): + return "to_yaml_tree result" + + def from_yaml_tree(self, obj, tag, ctx): + return "from_yaml_tree result" + + +class FullConverter(MinimumConverter): + def select_tag(self, obj, tags, ctx): + return "select_tag result" + + +# Some dummy types for testing converters: +class FooType: + pass + + +class BarType: + pass + + +class BazType: + pass + + +def test_extension_proxy_maybe_wrap(): + extension = MinimumExtension() proxy = ExtensionProxy.maybe_wrap(extension) assert proxy.delegate is extension assert ExtensionProxy.maybe_wrap(proxy) is proxy @@ -38,13 +130,151 @@ def test_proxy_maybe_wrap(): ExtensionProxy.maybe_wrap(object()) -def test_proxy_legacy(): +def test_extension_proxy(): + # Test with minimum properties: + extension = MinimumExtension() + proxy = ExtensionProxy(extension) + + assert isinstance(proxy, Extension) + assert isinstance(proxy, AsdfExtension) + + assert proxy.extension_uri == "asdf://somewhere.org/extensions/minimum-1.0" + assert proxy.legacy_class_names == set() + assert proxy.asdf_standard_requirement == SpecifierSet() + assert proxy.converters == [] + assert proxy.tags == [] + assert proxy.types == [] + assert proxy.tag_mapping == [] + assert proxy.url_mapping == [] + assert proxy.delegate is extension + assert proxy.legacy is False + assert proxy.package_name is None + assert proxy.package_version is None + assert proxy.class_name == "asdf.tests.test_extension.MinimumExtension" + + # The subclassed version should have the 
same defaults: + extension = MinimumExtensionSubclassed() + subclassed_proxy = ExtensionProxy(extension) + assert subclassed_proxy.extension_uri == proxy.extension_uri + assert subclassed_proxy.legacy_class_names == proxy.legacy_class_names + assert subclassed_proxy.asdf_standard_requirement == proxy.asdf_standard_requirement + assert subclassed_proxy.converters == proxy.converters + assert subclassed_proxy.tags == proxy.tags + assert subclassed_proxy.types == proxy.types + assert subclassed_proxy.tag_mapping == proxy.tag_mapping + assert subclassed_proxy.url_mapping == proxy.url_mapping + assert subclassed_proxy.delegate is extension + assert subclassed_proxy.legacy == proxy.legacy + assert subclassed_proxy.package_name == proxy.package_name + assert subclassed_proxy.package_version == proxy.package_name + assert subclassed_proxy.class_name == "asdf.tests.test_extension.MinimumExtensionSubclassed" + + # Test with all properties present: + converters = [ + MinimumConverter( + tags=["asdf://somewhere.org/extensions/full/tags/foo-*"], + types=[] + ) + ] + extension = FullExtension( + converters=converters, + asdf_standard_requirement=">=1.4.0", + tags=["asdf://somewhere.org/extensions/full/tags/foo-1.0"], + legacy_class_names=["foo.extensions.SomeOldExtensionClass"] + ) + proxy = ExtensionProxy(extension, package_name="foo", package_version="1.2.3") + + assert proxy.extension_uri == "asdf://somewhere.org/extensions/full-1.0" + assert proxy.legacy_class_names == {"foo.extensions.SomeOldExtensionClass"} + assert proxy.asdf_standard_requirement == SpecifierSet(">=1.4.0") + assert proxy.converters == [ConverterProxy(c, proxy) for c in converters] + assert len(proxy.tags) == 1 + assert proxy.tags[0].tag_uri == "asdf://somewhere.org/extensions/full/tags/foo-1.0" + assert proxy.types == [] + assert proxy.tag_mapping == [] + assert proxy.url_mapping == [] + assert proxy.delegate is extension + assert proxy.legacy is False + assert proxy.package_name == "foo" + assert 
proxy.package_version == "1.2.3" + assert proxy.class_name == "asdf.tests.test_extension.FullExtension" + + # Should fail when the input is not one of the two extension interfaces: + with pytest.raises(TypeError): + ExtensionProxy(object) + + # Should fail with a bad converter: + with pytest.raises(TypeError): + ExtensionProxy(FullExtension(converters=[object()])) + + # Unparseable ASDF Standard requirement: + with pytest.raises(ValueError): + ExtensionProxy(FullExtension(asdf_standard_requirement="asdf-standard >= 1.4.0")) + + # Unrecognized ASDF Standard requirement type: + with pytest.raises(TypeError): + ExtensionProxy(FullExtension(asdf_standard_requirement=object())) + + # Bad tag: + with pytest.raises(TypeError): + ExtensionProxy(FullExtension(tags=[object()])) + + # Bad legacy class names: + with pytest.raises(TypeError): + ExtensionProxy(FullExtension(legacy_class_names=[object])) + + +def test_extension_proxy_tags(): + """ + The tags behavior is a tad complex, so they get their own test. + """ + foo_tag_uri = "asdf://somewhere.org/extensions/full/tags/foo-1.0" + foo_tag_def = TagDefinition( + foo_tag_uri, + schema_uri="asdf://somewhere.org/extensions/full/schemas/foo-1.0", + title="Some tag title", + description="Some tag description" + ) + + bar_tag_uri = "asdf://somewhere.org/extensions/full/tags/bar-1.0" + bar_tag_def = TagDefinition( + bar_tag_uri, + schema_uri="asdf://somewhere.org/extensions/full/schemas/bar-1.0", + title="Some other tag title", + description="Some other tag description" + ) + + # The converter should return only the tags + # supported by the extension. + converter = FullConverter(tags=["**"]) + extension = FullExtension(tags=[foo_tag_def], converters=[converter]) + proxy = ExtensionProxy(extension) + assert proxy.converters[0].tags == [foo_tag_uri] + + # The converter should not return tags that + # its patterns do not match. 
+ converter = FullConverter(tags=["**/foo-1.0"]) + extension = FullExtension(tags=[foo_tag_def, bar_tag_def], converters=[converter]) + proxy = ExtensionProxy(extension) + assert proxy.converters[0].tags == [foo_tag_uri] + + # The process should still work if the extension property + # contains str instead of TagDefinition. + converter = FullConverter(tags=["**/foo-1.0"]) + extension = FullExtension(tags=[foo_tag_uri, bar_tag_uri], converters=[converter]) + proxy = ExtensionProxy(extension) + assert proxy.converters[0].tags == [foo_tag_uri] + + +def test_extension_proxy_legacy(): extension = LegacyExtension() proxy = ExtensionProxy(extension, package_name="foo", package_version="1.2.3") assert proxy.extension_uri is None assert proxy.legacy_class_names == {"asdf.tests.test_extension.LegacyExtension"} assert proxy.asdf_standard_requirement == SpecifierSet() + assert proxy.converters == [] + assert proxy.tags == [] assert proxy.types == [LegacyType] assert proxy.tag_mapping == LegacyExtension.tag_mapping assert proxy.url_mapping == LegacyExtension.url_mapping @@ -55,8 +285,8 @@ def test_proxy_legacy(): assert proxy.class_name == "asdf.tests.test_extension.LegacyExtension" -def test_proxy_hash_and_eq(): - extension = LegacyExtension() +def test_extension_proxy_hash_and_eq(): + extension = MinimumExtension() proxy1 = ExtensionProxy(extension) proxy2 = ExtensionProxy(extension, package_name="foo", package_version="1.2.3") @@ -66,16 +296,238 @@ def test_proxy_hash_and_eq(): assert proxy2 != extension -def test_proxy_repr(): +def test_extension_proxy_repr(): + proxy = ExtensionProxy(MinimumExtension(), package_name="foo", package_version="1.2.3") + assert "class: asdf.tests.test_extension.MinimumExtension" in repr(proxy) + assert "package: foo==1.2.3" in repr(proxy) + assert "legacy: False" in repr(proxy) + + proxy = ExtensionProxy(MinimumExtension()) + assert "class: asdf.tests.test_extension.MinimumExtension" in repr(proxy) + assert "package: (none)" in repr(proxy) +
assert "legacy: False" in repr(proxy) + proxy = ExtensionProxy(LegacyExtension(), package_name="foo", package_version="1.2.3") assert "class: asdf.tests.test_extension.LegacyExtension" in repr(proxy) assert "package: foo==1.2.3" in repr(proxy) assert "legacy: True" in repr(proxy) - proxy = ExtensionProxy(LegacyExtension()) - assert "class: asdf.tests.test_extension.LegacyExtension" in repr(proxy) + +def test_extension_manager(): + converter1 = FullConverter( + tags=[ + "asdf://somewhere.org/extensions/full/tags/foo-*", + "asdf://somewhere.org/extensions/full/tags/bar-*", + ], + types=[ + FooType, + "asdf.tests.test_extension.BarType", + ], + ) + converter2 = FullConverter( + tags=[ + "asdf://somewhere.org/extensions/full/tags/baz-*", + ], + types=[ + BazType + ], + ) + converter3= FullConverter( + tags=[ + "asdf://somewhere.org/extensions/full/tags/foo-*", + ], + types=[ + FooType, + BarType, + ], + ) + extension1 = FullExtension( + converters=[converter1, converter2], + tags=[ + "asdf://somewhere.org/extensions/full/tags/foo-1.0", + "asdf://somewhere.org/extensions/full/tags/baz-1.0", + ] + ) + extension2 = FullExtension( + converters=[converter3], + tags = [ + "asdf://somewhere.org/extensions/full/tags/foo-1.0", + ] + ) + + manager = ExtensionManager([extension1, extension2]) + + assert manager.extensions == [ExtensionProxy(extension1), ExtensionProxy(extension2)] + + assert manager.handles_tag("asdf://somewhere.org/extensions/full/tags/foo-1.0") is True + assert manager.handles_tag("asdf://somewhere.org/extensions/full/tags/bar-1.0") is False + assert manager.handles_tag("asdf://somewhere.org/extensions/full/tags/baz-1.0") is True + + assert manager.handles_type(FooType) is True + # This should return True even though BarType was listed + # as string class name: + assert manager.handles_type(BarType) is True + assert manager.handles_type(BazType) is True + + assert manager.get_tag_definition("asdf://somewhere.org/extensions/full/tags/foo-1.0").tag_uri == 
"asdf://somewhere.org/extensions/full/tags/foo-1.0" + assert manager.get_tag_definition("asdf://somewhere.org/extensions/full/tags/baz-1.0").tag_uri == "asdf://somewhere.org/extensions/full/tags/baz-1.0" + with pytest.raises(KeyError): + manager.get_tag_definition("asdf://somewhere.org/extensions/full/tags/bar-1.0") + + assert manager.get_converter_for_tag("asdf://somewhere.org/extensions/full/tags/foo-1.0").delegate is converter1 + assert manager.get_converter_for_tag("asdf://somewhere.org/extensions/full/tags/baz-1.0").delegate is converter2 + with pytest.raises(KeyError): + manager.get_converter_for_tag("asdf://somewhere.org/extensions/full/tags/bar-1.0") + + assert manager.get_converter_for_type(FooType).delegate is converter1 + assert manager.get_converter_for_type(BarType).delegate is converter1 + assert manager.get_converter_for_type(BazType).delegate is converter2 + with pytest.raises(KeyError): + manager.get_converter_for_type(object) + + +def test_get_cached_extension_manager(): + extension = MinimumExtension() + extension_manager = get_cached_extension_manager([extension]) + assert get_cached_extension_manager([extension]) is extension_manager + assert get_cached_extension_manager([MinimumExtension()]) is not extension_manager + + +def test_tag_definition(): + tag_def = TagDefinition( + "asdf://somewhere.org/extensions/foo/tags/foo-1.0", + schema_uri="asdf://somewhere.org/extensions/foo/schemas/foo-1.0", + title="Some title", + description="Some description", + ) + + assert tag_def.tag_uri == "asdf://somewhere.org/extensions/foo/tags/foo-1.0" + assert tag_def.schema_uri == "asdf://somewhere.org/extensions/foo/schemas/foo-1.0" + assert tag_def.title == "Some title" + assert tag_def.description == "Some description" + + assert "URI: asdf://somewhere.org/extensions/foo/tags/foo-1.0" in repr(tag_def) + + with pytest.raises(ValueError): + TagDefinition("asdf://somewhere.org/extensions/foo/tags/foo-*") + + +def test_converter(): + class ConverterNoSubclass: + 
tags = [] + types = [] + + def to_yaml_tree(self, *args): + pass + + def from_yaml_tree(self, *args): + pass + + assert issubclass(ConverterNoSubclass, Converter) + + class ConverterWithSubclass(Converter): + tags = [] + types = [] + + def to_yaml_tree(self, *args): + pass + + def from_yaml_tree(self, *args): + pass + + # Confirm the behavior of the default select_tag implementation + assert ConverterWithSubclass().select_tag(object(), ["tag1", "tag2"], object()) == "tag1" + + +def test_converter_proxy(): + # Test the minimum set of converter methods: + extension = ExtensionProxy(MinimumExtension()) + converter = MinimumConverter() + proxy = ConverterProxy(converter, extension) + + assert isinstance(proxy, Converter) + + assert proxy.tags == [] + assert proxy.types == [] + assert proxy.to_yaml_tree(None, None, None) == "to_yaml_tree result" + assert proxy.from_yaml_tree(None, None, None) == "from_yaml_tree result" + assert proxy.tags == [] + assert proxy.delegate is converter + assert proxy.extension == extension + assert proxy.package_name is None + assert proxy.package_version is None + assert proxy.class_name == "asdf.tests.test_extension.MinimumConverter" + + # Check the __eq__ and __hash__ behavior: + assert proxy == ConverterProxy(converter, extension) + assert proxy != ConverterProxy(MinimumConverter(), extension) + assert proxy != ConverterProxy(converter, MinimumExtension()) + assert proxy in {ConverterProxy(converter, extension)} + assert proxy not in { + ConverterProxy(MinimumConverter(), extension), + ConverterProxy(converter, MinimumExtension()) + } + + # Check the __repr__: + assert "class: asdf.tests.test_extension.MinimumConverter" in repr(proxy) assert "package: (none)" in repr(proxy) - assert "legacy: True" in repr(proxy) + + # Test the full set of converter methods: + converter = FullConverter( + tags=[ + "asdf://somewhere.org/extensions/test/tags/foo-*", + "asdf://somewhere.org/extensions/test/tags/bar-*", + ], + types=[FooType, BarType] + ) + + 
extension = FullExtension( + tags=[ + TagDefinition( + "asdf://somewhere.org/extensions/test/tags/foo-1.0", + schema_uri="asdf://somewhere.org/extensions/test/schemas/foo-1.0", + title="Foo tag title", + description="Foo tag description" + ), + TagDefinition( + "asdf://somewhere.org/extensions/test/tags/bar-1.0", + schema_uri="asdf://somewhere.org/extensions/test/schemas/bar-1.0", + title="Bar tag title", + description="Bar tag description" + ), + ] + ) + + extension_proxy = ExtensionProxy(extension, package_name="foo", package_version="1.2.3") + proxy = ConverterProxy(converter, extension_proxy) + assert len(proxy.tags) == 2 + assert "asdf://somewhere.org/extensions/test/tags/foo-1.0" in proxy.tags + assert "asdf://somewhere.org/extensions/test/tags/bar-1.0" in proxy.tags + assert proxy.types == [FooType, BarType] + assert proxy.to_yaml_tree(None, None, None) == "to_yaml_tree result" + assert proxy.from_yaml_tree(None, None, None) == "from_yaml_tree result" + assert proxy.select_tag(None, None) == "select_tag result" + assert proxy.delegate is converter + assert proxy.extension == extension_proxy + assert proxy.package_name == "foo" + assert proxy.package_version == "1.2.3" + assert proxy.class_name == "asdf.tests.test_extension.FullConverter" + + # Check the __repr__ since it will contain package info now: + assert "class: asdf.tests.test_extension.FullConverter" in repr(proxy) + assert "package: foo==1.2.3" in repr(proxy) + + # Should error because object() does not fulfill the Converter interface: + with pytest.raises(TypeError): + ConverterProxy(object(), extension) + + # Should fail because tags must be str: + with pytest.raises(TypeError): + ConverterProxy(MinimumConverter(tags=[object()]), extension) + + # Should fail because types must be instances of type: + with pytest.raises(TypeError): + ConverterProxy(MinimumConverter(types=[object()]), extension) def test_get_cached_asdf_extension_list(): diff --git a/asdf/tests/test_integration.py
b/asdf/tests/test_integration.py new file mode 100644 index 000000000..8238df6f6 --- /dev/null +++ b/asdf/tests/test_integration.py @@ -0,0 +1,70 @@ +""" +Integration tests for the new plugin APIs. +""" +import pytest + +import asdf +from asdf.extension import TagDefinition + + +FOO_SCHEMA_URI = "asdf://somewhere.org/extensions/foo/schemas/foo-1.0" +FOO_SCHEMA = """ +id: {} +type: object +properties: + value: + type: string +required: ["value"] +""".format(FOO_SCHEMA_URI) + + +class Foo: + def __init__(self, value): + self._value = value + + @property + def value(self): + return self._value + + +class FooConverter: + types = [Foo] + tags = ["asdf://somewhere.org/extensions/foo/tags/foo-*"] + + def to_yaml_tree(self, obj, tag, ctx): + return { + "value": obj.value + } + + def from_yaml_tree(self, obj, tag, ctx): + return Foo(obj["value"]) + + +class FooExtension: + extension_uri = "asdf://somewhere.org/extensions/foo-1.0" + converters = [FooConverter()] + tags = [ + TagDefinition( + "asdf://somewhere.org/extensions/foo/tags/foo-1.0", + schema_uri=FOO_SCHEMA_URI, + ) + ] + + +def test_serialize_custom_type(tmpdir): + with asdf.config_context() as config: + config.add_resource_mapping({FOO_SCHEMA_URI: FOO_SCHEMA}) + config.add_extension(FooExtension()) + + path = str(tmpdir/"test.asdf") + + af = asdf.AsdfFile() + af["foo"] = Foo("bar") + af.write_to(path) + + with asdf.open(path) as af2: + assert af2["foo"].value == "bar" + + with pytest.raises(asdf.ValidationError): + af["foo"] = Foo(12) + af.write_to(path) diff --git a/asdf/tests/test_resource.py b/asdf/tests/test_resource.py index 8cd476060..1de5fee0e 100644 --- a/asdf/tests/test_resource.py +++ b/asdf/tests/test_resource.py @@ -1,6 +1,7 @@ import io import sys from pathlib import Path +from collections.abc import Mapping if sys.version_info < (3, 9): import importlib_resources @@ -29,6 +30,7 @@ def test_directory_resource_mapping(tmpdir): f.write("id: http://somewhere.org/schemas/baz-7.8.9\n") mapping = 
DirectoryResourceMapping(str(tmpdir/"schemas"), "http://somewhere.org/schemas") + assert isinstance(mapping, Mapping) assert len(mapping) == 1 assert set(mapping) == {"http://somewhere.org/schemas/foo-1.2.3"} assert "http://somewhere.org/schemas/foo-1.2.3" in mapping @@ -191,6 +193,8 @@ def test_resource_manager(): } manager = ResourceManager([mapping1, mapping2]) + assert isinstance(manager, Mapping) + assert len(manager) == 4 assert set(manager) == { "http://somewhere.org/schemas/foo-1.0.0", @@ -216,6 +220,8 @@ def test_resource_manager(): def test_jsonschema_resource_mapping(): mapping = JsonschemaResourceMapping() + assert isinstance(mapping, Mapping) + assert len(mapping) == 1 assert set(mapping) == {"http://json-schema.org/draft-04/schema"} assert "http://json-schema.org/draft-04/schema" in mapping @@ -238,6 +244,10 @@ def test_get_core_resource_mappings(uri): assert uri.encode("utf-8") in mapping[uri] +def test_proxy_is_mapping(): + assert isinstance(ResourceMappingProxy({}), Mapping) + + def test_proxy_maybe_wrap(): mapping = { "http://somewhere.org/resources/foo": "foo", diff --git a/asdf/yamlutil.py b/asdf/yamlutil.py index a3745bf75..bb7fe7f71 100644 --- a/asdf/yamlutil.py +++ b/asdf/yamlutil.py @@ -1,5 +1,6 @@ import warnings from collections import OrderedDict +from types import GeneratorType import numpy as np @@ -221,15 +222,54 @@ def custom_tree_to_tagged_tree(tree, ctx, _serialization_context=None): if _serialization_context is None: _serialization_context = ctx._create_serialization_context() - def walker(node): - tag = ctx.type_index.from_custom_type(type(node), ctx.version_string, _serialization_context=_serialization_context) - if tag is not None: - return tag.to_tree_tagged(node, ctx) - return node + extension_manager = _serialization_context.extension_manager + + def _convert_obj(obj): + converter = extension_manager.get_converter_for_type(type(obj)) + tag = converter.select_tag(obj, _serialization_context) + node = 
converter.to_yaml_tree(obj, tag, _serialization_context) + + if isinstance(node, GeneratorType): + generator = node + node = next(generator) + else: + generator = None + + if isinstance(node, dict): + tagged_node = tagged.TaggedDict(node, tag) + elif isinstance(node, list): + tagged_node = tagged.TaggedList(node, tag) + elif isinstance(node, str): + tagged_node = tagged.TaggedString(node) + tagged_node._tag = tag + else: + raise TypeError( + "AsdfConverter returned illegal node type: {}".format(util.get_class_name(node)) + ) + + _serialization_context.mark_extension_used(converter.extension) + + yield tagged_node + if generator is not None: + yield from generator + + def _walker(obj): + if extension_manager.handles_type(type(obj)): + return _convert_obj(obj) + else: + tag = ctx.type_index.from_custom_type( + type(obj), + ctx.version_string, + _serialization_context=_serialization_context + ) + + if tag is not None: + return tag.to_tree_tagged(obj, ctx) + return obj return treeutil.walk_and_modify( tree, - walker, + _walker, ignore_implicit_conversion=ctx._ignore_implicit_conversion, # Walk the tree in preorder, so that extensions can return # container nodes with unserialized children. 
@@ -246,7 +286,9 @@ def tagged_tree_to_custom_tree(tree, ctx, force_raw_types=False, _serialization_ if _serialization_context is None: _serialization_context = ctx._create_serialization_context() - def walker(node): + extension_manager = _serialization_context.extension_manager + + def _walker(node): if force_raw_types: return node @@ -254,6 +296,12 @@ def walker(node): if tag is None: return node + if extension_manager.handles_tag(tag): + converter = extension_manager.get_converter_for_tag(tag) + obj = converter.from_yaml_tree(node.data, tag, _serialization_context) + _serialization_context.mark_extension_used(converter.extension) + return obj + tag_type = ctx.type_index.from_yaml_tag(ctx, tag, _serialization_context=_serialization_context) # This means the tag did not correspond to any type in our type index. if tag_type is None: @@ -289,7 +337,7 @@ def walker(node): return treeutil.walk_and_modify( tree, - walker, + _walker, ignore_implicit_conversion=ctx._ignore_implicit_conversion, # Walk the tree in postorder, so that extensions receive # container nodes with children already deserialized.