Skip to content

Commit

Permalink
Add support for defining tags with asdf.extension.TagDefinition (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
ketozhang authored Aug 26, 2024
2 parents 445c002 + d127085 commit 7c0abfc
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 18 deletions.
2 changes: 1 addition & 1 deletion asdf_pydantic/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def add_models(
cls, *model_classes: Type[AsdfPydanticModel]
) -> "AsdfPydanticConverter":
for model_class in model_classes:
cls._tag_to_class[model_class._tag] = model_class
cls._tag_to_class[model_class.get_tag_uri()] = model_class
return cls()

@property
Expand Down
33 changes: 22 additions & 11 deletions asdf_pydantic/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import ClassVar

import yaml
from pydantic import BaseModel
from asdf.extension import TagDefinition
from pydantic import BaseModel, ConfigDict
from typing_extensions import deprecated

from asdf_pydantic.schema import DEFAULT_ASDF_SCHEMA_REF_TEMPLATE, GenerateAsdfSchema
Expand All @@ -16,10 +17,8 @@ class AsdfPydanticModel(BaseModel):
AsdfPydanticModel object with py:meth`AsdfPydanticModel.parse_obj()`.
"""

_tag: ClassVar[str]

class Config:
arbitrary_types_allowed = True
_tag: ClassVar[str | TagDefinition]
model_config = ConfigDict(arbitrary_types_allowed=True)

def asdf_yaml_tree(self) -> dict:
d = {}
Expand All @@ -43,6 +42,22 @@ def asdf_yaml_tree(self) -> dict:

return d

@classmethod
def get_tag_definition(cls):
if isinstance(cls._tag, str):
return TagDefinition( # TODO: Add title and description
cls._tag,
schema_uris=[f"{cls._tag}/schema"],
)
return cls._tag

@classmethod
def get_tag_uri(cls):
if isinstance(cls._tag, TagDefinition):
return cls._tag.tag_uri
else:
return cls._tag

@classmethod
def model_asdf_schema(
cls,
Expand All @@ -53,15 +68,11 @@ def model_asdf_schema(
"""Get the ASDF schema definition for this model."""
# Implementation follows closely with the `BaseModel.model_json_schema`
schema_generator_instance = schema_generator(
by_alias=by_alias,
ref_template=ref_template,
tag=cls._tag,
by_alias=by_alias, ref_template=ref_template, tag_uri=cls.get_tag_uri()
)
json_schema = schema_generator_instance.generate(cls.__pydantic_core_schema__)

header = "%YAML 1.1\n---\n"

return f"{header}\n{yaml.safe_dump(json_schema)}"
return f"%YAML 1.1\n---\n{yaml.safe_dump(json_schema)}"

@classmethod
@deprecated(
Expand Down
11 changes: 5 additions & 6 deletions asdf_pydantic/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,24 @@ class GenerateAsdfSchema(GenerateJsonSchema):
"""

# HACK: When we can support tree models, then not all schema should have tag
tag: Optional[str]
schema_dialect = "http://stsci.edu/schemas/asdf/asdf-schema-1.0.0"

def __init__(
self,
by_alias: bool = True,
ref_template: str = DEFAULT_ASDF_SCHEMA_REF_TEMPLATE,
tag: Optional[str] = None,
tag_uri: Optional[str] = None,
):
super().__init__(by_alias=by_alias, ref_template=ref_template)
self.tag = tag
self.tag_uri = tag_uri

def generate(self, schema, mode="validation"):
json_schema = super().generate(schema, mode) # noqa: F841

if self.tag:
if self.tag_uri:
json_schema["$schema"] = self.schema_dialect
json_schema["id"] = self.tag
json_schema["tag"] = f"tag:{self.tag.split('://', maxsplit=2)[-1]}"
json_schema["id"] = self.tag_uri
json_schema["tag"] = f"tag:{self.tag_uri.split('://', maxsplit=2)[-1]}"

# TODO: Convert jsonschema 2020-12 to ASDF schema
return json_schema
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ source = "vcs"

# Default Environment
[tool.hatch.envs.default]
installer = "uv"
dependencies = [
"ipython",
"pytest",
Expand Down
51 changes: 51 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import pytest
from asdf.extension import TagDefinition

from asdf_pydantic import AsdfPydanticModel


def tag(request):
return request.param


@pytest.mark.parametrize(
"tag",
(
"asdf://asdf-pydantic/test/tags/",
pytest.param(
TagDefinition("asdf://asdf-pydantic/tags/test-0.0.1"),
marks=pytest.mark.xfail(
reason="Tag definition without schema URIs not supported"
),
),
TagDefinition(
"asdf://asdf-pydantic/tags/test-0.0.1",
schema_uris=["asdf://asdf-pydantic/test/schemas/test-0.0.1"],
),
),
)
def test_can_get_tag_definition(tag):
class TestModel(AsdfPydanticModel):
_tag = tag

tag_definition = TestModel.get_tag_definition()
assert isinstance(tag_definition, TagDefinition)
assert tag_definition.schema_uris


@pytest.mark.parametrize(
"tag",
(
"asdf://asdf-pydantic/test/tags/",
TagDefinition("asdf://asdf-pydantic/tags/test-0.0.1"),
TagDefinition(
"asdf://asdf-pydantic/tags/test-0.0.1",
schema_uris=["asdf://asdf-pydantic/test/schemas/test-0.0.1"],
),
),
)
def test_can_get_tag_uris(tag):
class TestModel(AsdfPydanticModel):
_tag = tag

assert TestModel.get_tag_uri()

0 comments on commit 7c0abfc

Please sign in to comment.