Skip to content

Commit

Permalink
validate tagged objects using schema for tag
Browse files (browse the repository at this point in the history)
  • Loading branch information
braingram committed May 8, 2024
1 parent f64fdb8 commit bca74f2
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 21 deletions.
2 changes: 1 addition & 1 deletion asdf/_tests/test_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_history():
)
assert len(ff.tree["history"]["entries"]) == 1

with pytest.raises(ValidationError, match=r".* is not valid under any of the given schemas"):
with pytest.raises(ValidationError, match=r"'name' is a required property"):
ff.add_history_entry("That happened", {"author": "John Doe", "version": "2.0"})
assert len(ff.tree["history"]["entries"]) == 1

Expand Down
27 changes: 27 additions & 0 deletions asdf/_tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1268,3 +1268,30 @@ def test_tag_validator():
schema.validate(instance, schema=schema_tree)
with pytest.raises(ValidationError, match=r"mismatched tags, wanted .*, got .*"):
schema.validate(tagged.TaggedDict(tag="asdf://somewhere.org/tags/foo-1.0"), schema=schema_tree)


def test_tagged_object_validation():
    """
    A tagged object handed to the asdf validator is expected to be
    checked against the schema associated with its tag, while the
    schema supplied by the caller is still enforced as well.
    """
    ndarray_tag = "tag:stsci.edu:asdf/core/ndarray-1.0.0"
    metaschema_uri = "http://stsci.edu/schemas/asdf-schema/0.1.0/asdf-schema"

    # Case 1: the supplied schema only names the tag, so the expected
    # failure has to originate from the schema looked up for the tag.
    tag_only_schema = {"$schema": metaschema_uri, "tag": ndarray_tag}
    bad_ndarray = asdf.tagged.TaggedDict({"shape": "a"}, tag=tag)if False else asdf.tagged.TaggedDict({"shape": "a"}, tag=ndarray_tag)
    with pytest.raises(ValidationError, match=r"is not valid under any of the given schema"):
        asdf.schema.validate(bad_ndarray, schema=tag_only_schema)

    # Case 2: the supplied schema demands a different tag than the one
    # carried by the instance; the caller's schema must still be applied
    # and report the mismatch.
    time_tag_schema = {
        "$schema": metaschema_uri,
        "tag": "tag:stsci.edu:asdf/core/time-1.0.0",
    }
    ndarray_node = asdf.tagged.TaggedDict({"data": [1, 2, 3]}, tag=ndarray_tag)
    with pytest.raises(ValidationError, match=r"mismatched tags"):
        asdf.schema.validate(ndarray_node, schema=time_tag_schema)
37 changes: 17 additions & 20 deletions asdf/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def _make_seen_key(self, instance, schema):


@lru_cache
def _create_validator(validators=YAML_VALIDATORS, visit_repeat_nodes=False):
def _create_validator(validators=YAML_VALIDATORS, visit_repeat_nodes=False, ctx=None, serialization_context=None):
meta_schema = _load_schema_cached(YAML_SCHEMA_METASCHEMA_ID, _tag_to_uri, False)

type_checker = mvalidators.Draft4Validator.TYPE_CHECKER.redefine_many(
Expand All @@ -255,17 +255,8 @@ def _create_validator(validators=YAML_VALIDATORS, visit_repeat_nodes=False):
type_checker=type_checker,
id_of=id_of,
)

def _patch_init(cls):
original_init = cls.__init__

def init(self, *args, **kwargs):
self.ctx = kwargs.pop("ctx", None)
self.serialization_context = kwargs.pop("serialization_context", None)

original_init(self, *args, **kwargs)

cls.__init__ = init
ASDFvalidator.ctx = ctx
ASDFvalidator.serialization_context = serialization_context

def _patch_iter_errors(cls):
original_iter_errors = cls.iter_errors
Expand All @@ -289,8 +280,8 @@ def iter_errors(self, instance, *args, **kwargs):
if (isinstance(instance, dict) and "$ref" in instance) or isinstance(instance, reference.Reference):
return

if not self.schema:
tag = getattr(instance, "_tag", None)
if hasattr(instance, "_tag") and self.serialization_context is not None:
tag = instance._tag
if tag is not None and self.serialization_context.extension_manager.handles_tag_definition(tag):
tag_def = self.serialization_context.extension_manager.get_tag_definition(tag)
schema_uris = tag_def.schema_uris
Expand All @@ -299,23 +290,24 @@ def iter_errors(self, instance, *args, **kwargs):
for schema_uri in schema_uris:
try:
with self.resolver.resolving(schema_uri) as resolved:
yield from self.descend(instance, resolved)
if id(resolved) != id(self.schema):
yield from self.descend(instance, resolved)
except RefResolutionError:
warnings.warn(f"Unable to locate schema file for '{tag}': '{schema_uri}'", AsdfWarning)

if self.schema:
yield from original_iter_errors(self, instance)
else:
if isinstance(instance, dict):
for val in instance.values():
yield from self.iter_errors(val)

elif isinstance(instance, list):
for val in instance:
yield from self.iter_errors(val)
else:
yield from original_iter_errors(self, instance)

cls.iter_errors = iter_errors

_patch_init(ASDFvalidator)
_patch_iter_errors(ASDFvalidator)

return ASDFvalidator
Expand Down Expand Up @@ -556,8 +548,13 @@ def get_validator(
# time of this writing, it was half of the runtime of the unit
# test suite!!!). Instead, we assume that the schemas are valid
# through the running of the unit tests, not at run time.
cls = _create_validator(validators=validators, visit_repeat_nodes=_visit_repeat_nodes)
return cls({} if schema is None else schema, *args, ctx=ctx, serialization_context=_serialization_context, **kwargs)
cls = _create_validator(
validators=validators,
visit_repeat_nodes=_visit_repeat_nodes,
ctx=ctx,
serialization_context=_serialization_context,
)
return cls({} if schema is None else schema, *args, **kwargs)


def _validate_large_literals(instance, reading):
Expand Down

0 comments on commit bca74f2

Please sign in to comment.