diff --git a/CHANGES.rst b/CHANGES.rst
index 5945a2666..61ddf0b75 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -19,6 +19,8 @@
 
 - Add new extension API to support versioned extensions. [#850, #851]
 
+- Permit wildcard in tag validator URIs. [#858]
+
 2.7.0 (2020-07-23)
 ------------------
 
diff --git a/asdf/config.py b/asdf/config.py
index 1396d93dd..31d82cf29 100644
--- a/asdf/config.py
+++ b/asdf/config.py
@@ -11,6 +11,8 @@
 from . import versioning
 from ._helpers import validate_version
 from .extension import ExtensionProxy
+from . import util
+
 
 __all__ = ["AsdfConfig", "get_config", "config_context"]
 
@@ -165,7 +167,7 @@ def remove_extension(self, extension=None, *, package=None):
         Parameters
         ----------
         extension : asdf.extension.AsdfExtension or str, optional
-            An extension instance or URI to remove.
+            An extension instance or URI or URI pattern to remove.
         package : str, optional
             Remove only extensions provided by this package. If the `extension`
             argument is omitted, then all extensions from this package will
@@ -181,7 +183,7 @@ def _remove_condition(e):
             result = True
 
             if isinstance(extension, str):
-                result = result and e.extension_uri == extension
+                result = result and util.uri_match(extension, e.extension_uri)
             elif isinstance(extension, ExtensionProxy):
                 result = result and e == extension
 
diff --git a/asdf/schema.py b/asdf/schema.py
index 754077e7e..2fe3a51a8 100644
--- a/asdf/schema.py
+++ b/asdf/schema.py
@@ -76,18 +76,29 @@ def _type_to_tag(type_):
     return None
 
 
-def validate_tag(validator, tagname, instance, schema):
-
+def validate_tag(validator, tag_pattern, instance, schema):
+    """
+    Implements the tag validation directive, which checks the
+    tag against a pattern which may include '*' wildcards.
+    """
     if hasattr(instance, '_tag'):
         instance_tag = instance._tag
     else:
         # Try tags for known Python builtins
         instance_tag = _type_to_tag(type(instance))
 
-    if instance_tag is not None and instance_tag != tagname:
+    if instance_tag is None:
+        yield ValidationError(
+            "mismatched tags, wanted '{}', got unhandled object type '{}'".format(
+                tag_pattern, util.get_class_name(instance)
+            )
+        )
+    # elif, not if: the untagged case was just reported above, so don't
+    # also emit a second "got 'None'" mismatch for the same instance.
+    elif not util.uri_match(tag_pattern, instance_tag):
         yield ValidationError(
             "mismatched tags, wanted '{0}', got '{1}'".format(
-                tagname, instance_tag))
+                tag_pattern, instance_tag))
 
 
 def validate_propertyOrder(validator, order, instance, schema):
diff --git a/asdf/tests/conftest.py b/asdf/tests/conftest.py
index 627622d85..c8dd378d2 100644
--- a/asdf/tests/conftest.py
+++ b/asdf/tests/conftest.py
@@ -6,6 +6,7 @@
 from . import create_small_tree, create_large_tree
 
 from asdf import config
+from asdf import schema
 
 
 @pytest.fixture
@@ -23,3 +24,15 @@ def restore_default_config():
     yield
     config._global_config = config.AsdfConfig()
     config._local = config._ConfigLocal()
+
+
+@pytest.fixture(autouse=True)
+def clear_schema_cache():
+    """
+    Fixture that clears schema caches to prevent issues
+    when tests use same URI for different schema content.
+    """
+    yield
+    schema._load_schema.cache_clear()
+    schema._load_schema_cached.cache_clear()
+    schema.load_custom_schema.cache_clear()
diff --git a/asdf/tests/test_config.py b/asdf/tests/test_config.py
index 48f33fe54..2fa75a179 100644
--- a/asdf/tests/test_config.py
+++ b/asdf/tests/test_config.py
@@ -230,6 +230,11 @@ class BarExtension:
         config.remove_extension(uri_extension.extension_uri)
         assert len(config.extensions) == len(original_extensions)
 
+        # And also by URI pattern:
+        config.add_extension(uri_extension)
+        config.remove_extension("asdf://somewhere.org/extensions/*")
+        assert len(config.extensions) == len(original_extensions)
+
         # Remove by the name of the extension's package:
         config.add_extension(ExtensionProxy(new_extension, package_name="foo"))
         config.add_extension(ExtensionProxy(uri_extension, package_name="foo"))
diff --git a/asdf/tests/test_schema.py b/asdf/tests/test_schema.py
index 7961a2faf..83b6ea002 100644
--- a/asdf/tests/test_schema.py
+++ b/asdf/tests/test_schema.py
@@ -1058,3 +1058,37 @@ def _test_validator(validator, value, instance, schema):
     )
     validator.validate(tree)
     assert len(visited_nodes) == 3
+
+
+def test_tag_validator():
+    content="""%YAML 1.1
+---
+$schema: http://stsci.edu/schemas/asdf/asdf-schema-1.0.0
+id: asdf://somewhere.org/schemas/foo
+tag: asdf://somewhere.org/tags/foo
+...
+"""
+    with asdf.config_context() as config:
+        config.add_resource_mapping({"asdf://somewhere.org/schemas/foo": content})
+
+        schema_tree = schema.load_schema("asdf://somewhere.org/schemas/foo")
+        instance = tagged.TaggedDict(tag="asdf://somewhere.org/tags/foo")
+        schema.validate(instance, schema=schema_tree)
+        with pytest.raises(ValidationError):
+            schema.validate(tagged.TaggedDict(tag="asdf://somewhere.org/tags/bar"), schema=schema_tree)
+
+    content="""%YAML 1.1
+---
+$schema: http://stsci.edu/schemas/asdf/asdf-schema-1.0.0
+id: asdf://somewhere.org/schemas/bar
+tag: asdf://somewhere.org/tags/bar-*
+...
+"""
+    with asdf.config_context() as config:
+        config.add_resource_mapping({"asdf://somewhere.org/schemas/bar": content})
+
+        schema_tree = schema.load_schema("asdf://somewhere.org/schemas/bar")
+        instance = tagged.TaggedDict(tag="asdf://somewhere.org/tags/bar-2.5")
+        schema.validate(instance, schema=schema_tree)
+        with pytest.raises(ValidationError):
+            schema.validate(tagged.TaggedDict(tag="asdf://somewhere.org/tags/foo-1.0"), schema=schema_tree)
diff --git a/asdf/tests/test_util.py b/asdf/tests/test_util.py
index 33c8071e3..83164c94d 100644
--- a/asdf/tests/test_util.py
+++ b/asdf/tests/test_util.py
@@ -1,3 +1,5 @@
+import pytest
+
 from asdf import util
 from asdf.extension import BuiltinExtension
 
@@ -40,3 +42,17 @@ def test_patched_urllib_parse():
     assert urllib.parse is not util.patched_urllib_parse
     assert "asdf" not in urllib.parse.uses_relative
     assert "asdf" not in urllib.parse.uses_netloc
+
+
+@pytest.mark.parametrize("pattern, uri, result", [
+    ("asdf://somewhere.org/tags/foo-1.0", "asdf://somewhere.org/tags/foo-1.0", True),
+    ("asdf://somewhere.org/tags/foo-1.0", "asdf://somewhere.org/tags/bar-1.0", False),
+    ("asdf://somewhere.org/tags/foo-*", "asdf://somewhere.org/tags/foo-1.0", True),
+    ("asdf://somewhere.org/tags/foo-*", "asdf://somewhere.org/tags/bar-1.0", False),
+    ("asdf://*/tags/foo-*", "asdf://anywhere.org/tags/foo-4.9", True),
+    ("asdf://*/tags/foo-*", "asdf://anywhere.org/tags/bar-4.9", False),
+    ("asdf://*/tags/foo", "asdf://anywhere.org/tags/foo-4.9", False),
+    ("asdf://somewhere.org/tags/foo-*", None, False),
+])
+def test_uri_match(pattern, uri, result):
+    assert util.uri_match(pattern, uri) is result
diff --git a/asdf/util.py b/asdf/util.py
index 6e3e5bcb5..926b61b40 100644
--- a/asdf/util.py
+++ b/asdf/util.py
@@ -3,6 +3,8 @@
 import struct
 import types
 import importlib.util
+import re
+from functools import lru_cache
 
 from urllib.request import pathname2url
 
@@ -449,3 +451,41 @@ def is_primitive(value):
         or isinstance(value, complex)
         or isinstance(value, str)
     )
+
+
+def uri_match(pattern, uri):
+    """
+    Determine if a URI matches a URI pattern with possible
+    wildcards.  The pattern must match the entire URI; each
+    '*' matches any sequence of characters, including '/'.
+
+    Parameters
+    ----------
+    pattern : str
+        URI pattern with * wildcards.
+    uri : str
+        URI to check against the pattern.
+
+    Returns
+    -------
+    bool
+        `True` if URI matches the pattern.  `False` if it does
+        not, or if `uri` is not a `str` (e.g. `None`).
+    """
+    if not isinstance(uri, str):
+        return False
+
+    if "*" in pattern:
+        # fullmatch anchors both ends; plain match would accept any
+        # URI that merely starts with text matching the pattern.
+        return _compile_uri_match_pattern(pattern).fullmatch(uri) is not None
+    else:
+        return pattern == uri
+
+
+@lru_cache(128)
+def _compile_uri_match_pattern(pattern):
+    # Escape the pattern in case it contains regex special characters
+    # ('.' in particular is common in URIs) and then replace the
+    # escaped asterisk with a .* regex matcher.
+    return re.compile(re.escape(pattern).replace(r"\*", ".*"))