diff --git a/asdf_pydantic/schema.py b/asdf_pydantic/schema.py index 38a2821..705cbe7 100644 --- a/asdf_pydantic/schema.py +++ b/asdf_pydantic/schema.py @@ -1,5 +1,36 @@ -from typing import Optional +""" +## Adding existing ASDF tags as a field +Type annotation must be added to the field to specify the ASDF tag to use in the +ASDF schema. There are a few options to do this: + + - Use `AsdfTag` to specify the tag URI. + - Use `WithAsdfSchema` and pass in a dictionary to extend the schema with + additional properties. The key `"$ref"` can be used to specify the tag URI. + + from asdf_pydantic import AsdfPydanticModel + from asdf_pydantic.schema import AsdfTag + from astropy.table import Table + + class MyModel(AsdfPydanticModel): + table: Annotated[Table, AsdfTag("http://stsci.edu/schemas/asdf.org/table/table-1.1.0")] + +For more customization of the ASDF schema output, you can use `WithAsdfSchema` to +extend the schema with additional properties. + + # Changing the title of the field + table: Annotated[ + Table, + WithAsdfSchema({ + "title": "TABLE", + "$ref": "http://stsci.edu/schemas/asdf.org/table/table-1.1.0" + }), + ] +""" + +from typing import Literal, Optional + +from pydantic import WithJsonSchema from pydantic.json_schema import GenerateJsonSchema DEFAULT_ASDF_SCHEMA_REF_TEMPLATE = "#/definitions/{model}" @@ -60,3 +91,20 @@ def generate(self, schema, mode="validation"): } return json_schema + + +class WithAsdfSchema(WithJsonSchema): + def __init__(self, asdf_schema: dict, **kwargs): + super().__init__(asdf_schema, **kwargs) + + +def AsdfTag(tag: str, mode: Literal["auto", "ref", "tag"] = "auto") -> WithAsdfSchema: + if mode == "auto": + parsed_mode = "tag" if tag.startswith("tag") else "ref" + else: + parsed_mode = mode + + if parsed_mode == "tag": + return WithAsdfSchema({"tag": tag}) + else: + return WithAsdfSchema({"$ref": tag}) diff --git a/tests/examples/test_astropy_tables.py b/tests/examples/test_astropy_tables.py new file mode 100644 index 0000000..d90bd4a --- /dev/null +++ b/tests/examples/test_astropy_tables.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from typing import Annotated + +import asdf +import astropy.units as u +import pytest +import yaml +from asdf.extension import Extension +from astropy.table import Table +from astropy.units import Quantity + +from asdf_pydantic import AsdfPydanticConverter, AsdfPydanticModel +from asdf_pydantic.schema import AsdfTag + + +class Database(AsdfPydanticModel): + _tag = "asdf://asdf-pydantic/examples/tags/database-1.0.0" + positions: Annotated[ + Table, AsdfTag("http://stsci.edu/schemas/asdf.org/table/table-1.1.0") + ] + + +@pytest.fixture() +def asdf_extension(): + """Registers an ASDF extension containing models for this test.""" + AsdfPydanticConverter.add_models(Database) + + class TestExtension(Extension): + extension_uri = "asdf://asdf-pydantic/examples/extensions/test-1.0.0" + + converters = [AsdfPydanticConverter()] # type: ignore + tags = [*AsdfPydanticConverter().tags] # type: ignore + + asdf.get_config().add_extension(TestExtension()) + + with asdf.config_context() as asdf_config: + asdf_config.add_resource_mapping( + { + yaml.safe_load(Database.model_asdf_schema())[ + "id" + ]: Database.model_asdf_schema() + } + ) + print(Database.model_asdf_schema()) + asdf_config.add_extension(TestExtension()) + yield asdf_config + + +@pytest.mark.usefixtures("asdf_extension") +def test_convert_to_asdf(tmp_path): + database = Database( + positions=Table( + { + "x": Quantity([1, 2, 3], u.m), + "y": Quantity([4, 5, 6], u.m), + } + ) + ) + asdf.AsdfFile({"data": database}).write_to(tmp_path / "test.asdf") + + with asdf.open(tmp_path / "test.asdf") as af: + assert isinstance(af.tree["data"], Database) + assert isinstance(af.tree["data"].positions, Table) + + +@pytest.mark.usefixtures("asdf_extension") +def test_check_schema(): + """Tests the model schema is correct.""" + schema = yaml.safe_load(Database.model_asdf_schema()) + asdf.schema.check_schema(schema) diff --git a/tests/schema_validation_test.py b/tests/schema_validation_test.py index 73b4d52..60a3f09 100644 --- a/tests/schema_validation_test.py +++ b/tests/schema_validation_test.py @@ -1,4 +1,5 @@ from tempfile import NamedTemporaryFile +from typing import Annotated import asdf import pydantic @@ -7,6 +8,7 @@ from asdf.extension import Extension from asdf_pydantic import AsdfPydanticConverter +from asdf_pydantic.model import AsdfPydanticModel from tests.examples.shapes import AsdfRectangle from tests.examples.tree import AsdfTreeNode @@ -136,3 +138,32 @@ def test_given_child_field_contains_asdf_object_then_schema_has_child_tag(): child_schema = schema["definitions"]["AsdfNode"]["properties"]["child"] assert {"tag": AsdfTreeNode._tag} in child_schema["anyOf"] + + +######################################################################################## +# AsdfTag +######################################################################################## +from asdf_pydantic.schema import AsdfTag # noqa: E402 + + +@pytest.mark.parametrize( + "asdf_tag_str, mode, expected_ref_key", + [ + ("http://stsci.edu/schemas/asdf/unit/quantity-1.2.0", "auto", "$ref"), + ("http://stsci.edu/schemas/asdf/unit/quantity-1.2.0", "ref", "$ref"), + ("tag:stsci.edu:asdf/table/table-1.1.0", "auto", "tag"), + ("tag:stsci.edu:asdf/table/table-1.1.0", "tag", "tag"), + ], +) +def test_tag_mode(asdf_tag_str: str, mode, expected_ref_key): + """Test that schema correctly has ``$ref:`` or ``tag:`` depending on the + selected mode. + """ + from astropy.table import Table + + class TestModel(AsdfPydanticModel): + _tag = "asdf://asdf-pydantic/examples/tags/test-model-1.0.0" + table: Annotated[Table, AsdfTag(asdf_tag_str, mode=mode)] + + schema = yaml.safe_load(TestModel.model_asdf_schema()) + assert expected_ref_key in schema["properties"]["table"]