From bca74f2a3b96713da2d3200b69f60f7dba9dc60e Mon Sep 17 00:00:00 2001 From: Brett Date: Fri, 19 Jan 2024 15:31:02 -0500 Subject: [PATCH] validate tagged objects using schema for tag --- asdf/_tests/test_history.py | 2 +- asdf/_tests/test_schema.py | 27 +++++++++++++++++++++++++++ asdf/schema.py | 37 +++++++++++++++++-------------------- 3 files changed, 45 insertions(+), 21 deletions(-) diff --git a/asdf/_tests/test_history.py b/asdf/_tests/test_history.py index 8da69a354..435fdc4c1 100644 --- a/asdf/_tests/test_history.py +++ b/asdf/_tests/test_history.py @@ -20,7 +20,7 @@ def test_history(): ) assert len(ff.tree["history"]["entries"]) == 1 - with pytest.raises(ValidationError, match=r".* is not valid under any of the given schemas"): + with pytest.raises(ValidationError, match=r"'name' is a required property"): ff.add_history_entry("That happened", {"author": "John Doe", "version": "2.0"}) assert len(ff.tree["history"]["entries"]) == 1 diff --git a/asdf/_tests/test_schema.py b/asdf/_tests/test_schema.py index 831c35f1a..cdb86745b 100644 --- a/asdf/_tests/test_schema.py +++ b/asdf/_tests/test_schema.py @@ -1268,3 +1268,30 @@ def test_tag_validator(): schema.validate(instance, schema=schema_tree) with pytest.raises(ValidationError, match=r"mismatched tags, wanted .*, got .*"): schema.validate(tagged.TaggedDict(tag="asdf://somewhere.org/tags/foo-1.0"), schema=schema_tree) + + +def test_tagged_object_validation(): + """ + Passing a tagged object to the asdf validator + should validate the object using the schema for the tag + """ + tag = "tag:stsci.edu:asdf/core/ndarray-1.0.0" + t = asdf.tagged.TaggedDict({"shape": "a"}, tag=tag) + + schema = { + "$schema": "http://stsci.edu/schemas/asdf-schema/0.1.0/asdf-schema", + "tag": tag, + } + + with pytest.raises(ValidationError, match=r"is not valid under any of the given schema"): + asdf.schema.validate(t, schema=schema) + + # and the custom schema should be validated + schema = { + "$schema": "http://stsci.edu/schemas/asdf-schema/0.1.0/asdf-schema", + "tag": "tag:stsci.edu:asdf/core/time-1.0.0", + } + + t = asdf.tagged.TaggedDict({"data": [1, 2, 3]}, tag=tag) + with pytest.raises(ValidationError, match=r"mismatched tags"): + asdf.schema.validate(t, schema=schema) diff --git a/asdf/schema.py b/asdf/schema.py index c09bad792..1283719b1 100644 --- a/asdf/schema.py +++ b/asdf/schema.py @@ -238,7 +238,7 @@ def _make_seen_key(self, instance, schema): @lru_cache -def _create_validator(validators=YAML_VALIDATORS, visit_repeat_nodes=False): +def _create_validator(validators=YAML_VALIDATORS, visit_repeat_nodes=False, ctx=None, serialization_context=None): meta_schema = _load_schema_cached(YAML_SCHEMA_METASCHEMA_ID, _tag_to_uri, False) type_checker = mvalidators.Draft4Validator.TYPE_CHECKER.redefine_many( @@ -255,17 +255,8 @@ def _create_validator(validators=YAML_VALIDATORS, visit_repeat_nodes=False): type_checker=type_checker, id_of=id_of, ) - - def _patch_init(cls): - original_init = cls.__init__ - - def init(self, *args, **kwargs): - self.ctx = kwargs.pop("ctx", None) - self.serialization_context = kwargs.pop("serialization_context", None) - - original_init(self, *args, **kwargs) - - cls.__init__ = init + ASDFvalidator.ctx = ctx + ASDFvalidator.serialization_context = serialization_context def _patch_iter_errors(cls): original_iter_errors = cls.iter_errors @@ -289,8 +280,8 @@ def iter_errors(self, instance, *args, **kwargs): if (isinstance(instance, dict) and "$ref" in instance) or isinstance(instance, reference.Reference): return - if not self.schema: - tag = getattr(instance, "_tag", None) + if hasattr(instance, "_tag") and self.serialization_context is not None: + tag = instance._tag if tag is not None and self.serialization_context.extension_manager.handles_tag_definition(tag): tag_def = self.serialization_context.extension_manager.get_tag_definition(tag) schema_uris = tag_def.schema_uris @@ -299,10 +290,14 @@ def iter_errors(self, instance, *args, **kwargs): for schema_uri in schema_uris: try: with self.resolver.resolving(schema_uri) as resolved: - yield from self.descend(instance, resolved) + if id(resolved) != id(self.schema): + yield from self.descend(instance, resolved) except RefResolutionError: warnings.warn(f"Unable to locate schema file for '{tag}': '{schema_uri}'", AsdfWarning) + if self.schema: + yield from original_iter_errors(self, instance) + else: if isinstance(instance, dict): for val in instance.values(): yield from self.iter_errors(val) @@ -310,12 +305,9 @@ def iter_errors(self, instance, *args, **kwargs): elif isinstance(instance, list): for val in instance: yield from self.iter_errors(val) - else: - yield from original_iter_errors(self, instance) cls.iter_errors = iter_errors - _patch_init(ASDFvalidator) _patch_iter_errors(ASDFvalidator) return ASDFvalidator @@ -556,8 +548,13 @@ def get_validator( # time of this writing, it was half of the runtime of the unit # test suite!!!). Instead, we assume that the schemas are valid # through the running of the unit tests, not at run time. - cls = _create_validator(validators=validators, visit_repeat_nodes=_visit_repeat_nodes) - return cls({} if schema is None else schema, *args, ctx=ctx, serialization_context=_serialization_context, **kwargs) + cls = _create_validator( + validators=validators, + visit_repeat_nodes=_visit_repeat_nodes, + ctx=ctx, + serialization_context=_serialization_context, + ) + return cls({} if schema is None else schema, *args, **kwargs) def _validate_large_literals(instance, reading):