Skip to content

Commit

Permalink
feat: add schema support for asdf standard tags (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
ketozhang authored Nov 17, 2024
2 parents 28a01b5 + d2ae563 commit c6de000
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 1 deletion.
50 changes: 49 additions & 1 deletion asdf_pydantic/schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,36 @@
from typing import Optional
"""
## Adding existing ASDF tags as a field
Type annotation must be added to the field to specify the ASDF tag to use in the
ASDF schema. There are a few options to do this:
- Use `AsdfTag` to specify the tag URI.
- Use `WithAsdfSchema` and pass in a dictionary to extend the schema with
additional properties. The key `"$ref"` can be used to specify the tag URI.
from asdf_pydantic import AsdfPydanticModel
from asdf_pydantic.schema import AsdfTag
from astropy.table import Table
class MyModel(AsdfPydanticModel):
table: Annotated[Table, AsdfTag("http://stsci.edu/schemas/asdf.org/table/table-1.1.0")]
For more customization of the ASDF schema output, you can use `WithAsdfSchema` to
extend the schema with additional properties.
# Changing the title of the field
table: Annotated[
Table,
WithAsdfSchema({
"title": "TABLE",
"$ref": "http://stsci.edu/schemas/asdf.org/table/table-1.1.0"
}),
]
"""

from typing import Literal, Optional

from pydantic import WithJsonSchema
from pydantic.json_schema import GenerateJsonSchema

DEFAULT_ASDF_SCHEMA_REF_TEMPLATE = "#/definitions/{model}"
Expand Down Expand Up @@ -60,3 +91,20 @@ def generate(self, schema, mode="validation"):
}

return json_schema


class WithAsdfSchema(WithJsonSchema):
def __init__(self, asdf_schema: dict, **kwargs):
super().__init__(asdf_schema, **kwargs)


def AsdfTag(tag: str, mode: Literal["auto", "ref", "tag"] = "auto") -> WithAsdfSchema:
if mode == "auto":
parsed_mode = "tag" if tag.startswith("tag") else "ref"
else:
parsed_mode = mode

if parsed_mode == "tag":
return WithAsdfSchema({"tag": tag})
else:
return WithAsdfSchema({"$ref": tag})
71 changes: 71 additions & 0 deletions tests/examples/test_astropy_tables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from __future__ import annotations

from typing import Annotated

import asdf
import astropy.units as u
import pytest
import yaml
from asdf.extension import Extension
from astropy.table import Table
from astropy.units import Quantity

from asdf_pydantic import AsdfPydanticConverter, AsdfPydanticModel
from asdf_pydantic.schema import AsdfTag


class Database(AsdfPydanticModel):
_tag = "asdf://asdf-pydantic/examples/tags/database-1.0.0"
positions: Annotated[
Table, AsdfTag("http://stsci.edu/schemas/asdf.org/table/table-1.1.0")
]


@pytest.fixture()
def asdf_extension():
"""Registers an ASDF extension containing models for this test."""
AsdfPydanticConverter.add_models(Database)

class TestExtension(Extension):
extension_uri = "asdf://asdf-pydantic/examples/extensions/test-1.0.0"

converters = [AsdfPydanticConverter()] # type: ignore
tags = [*AsdfPydanticConverter().tags] # type: ignore

asdf.get_config().add_extension(TestExtension())

with asdf.config_context() as asdf_config:
asdf_config.add_resource_mapping(
{
yaml.safe_load(Database.model_asdf_schema())[
"id"
]: Database.model_asdf_schema()
}
)
print(Database.model_asdf_schema())
asdf_config.add_extension(TestExtension())
yield asdf_config


@pytest.mark.usefixtures("asdf_extension")
def test_convert_to_asdf(tmp_path):
database = Database(
positions=Table(
{
"x": Quantity([1, 2, 3], u.m),
"y": Quantity([4, 5, 6], u.m),
}
)
)
asdf.AsdfFile({"data": database}).write_to(tmp_path / "test.asdf")

with asdf.open(tmp_path / "test.asdf") as af:
assert isinstance(af.tree["data"], Database)
assert isinstance(af.tree["data"].positions, Table)


@pytest.mark.usefixtures("asdf_extension")
def test_check_schema():
"""Tests the model schema is correct."""
schema = yaml.safe_load(Database.model_asdf_schema())
asdf.schema.check_schema(schema)
31 changes: 31 additions & 0 deletions tests/schema_validation_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from tempfile import NamedTemporaryFile
from typing import Annotated

import asdf
import pydantic
Expand All @@ -7,6 +8,7 @@
from asdf.extension import Extension

from asdf_pydantic import AsdfPydanticConverter
from asdf_pydantic.model import AsdfPydanticModel
from tests.examples.shapes import AsdfRectangle
from tests.examples.tree import AsdfTreeNode

Expand Down Expand Up @@ -136,3 +138,32 @@ def test_given_child_field_contains_asdf_object_then_schema_has_child_tag():
child_schema = schema["definitions"]["AsdfNode"]["properties"]["child"]

assert {"tag": AsdfTreeNode._tag} in child_schema["anyOf"]


########################################################################################
# AsdfTag
########################################################################################
from asdf_pydantic.schema import AsdfTag # noqa: E402


@pytest.mark.parametrize(
"asdf_tag_str, mode, expected_ref_key",
[
("http://stsci.edu/schemas/asdf/unit/quantity-1.2.0", "auto", "$ref"),
("http://stsci.edu/schemas/asdf/unit/quantity-1.2.0", "ref", "$ref"),
("tag:stsci.edu:asdf/table/table-1.1.0", "auto", "tag"),
("tag:stsci.edu:asdf/table/table-1.1.0", "tag", "tag"),
],
)
def test_tag_mode(asdf_tag_str: str, mode, expected_ref_key):
"""Test that schema correctly has ``$ref:`` or ``tag:`` depending on the
selected mode.
"""
from astropy.table import Table

class TestModel(AsdfPydanticModel):
_tag = "asdf://asdf-pydantic/examples/tags/test-model-1.0.0"
table: Annotated[Table, AsdfTag(asdf_tag_str, mode=mode)]

schema = yaml.safe_load(TestModel.model_asdf_schema())
assert expected_ref_key in schema["properties"]["table"]

0 comments on commit c6de000

Please sign in to comment.