Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/v2' into test-unstable-asdf-v4
Browse files Browse the repository at this point in the history
  • Loading branch information
ketozhang committed Jun 28, 2024
2 parents 4de4c9b + b5140a9 commit 59b2212
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 27 deletions.
56 changes: 37 additions & 19 deletions asdf_pydantic/model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import textwrap
from typing import ClassVar

import yaml
from pydantic import BaseModel
from typing_extensions import deprecated

from asdf_pydantic.schema import DEFAULT_ASDF_SCHEMA_REF_TEMPLATE, GenerateAsdfSchema


class AsdfPydanticModel(BaseModel):
Expand Down Expand Up @@ -42,29 +44,45 @@ def asdf_yaml_tree(self) -> dict:
return d

@classmethod
def model_asdf_schema(
cls,
by_alias: bool = True,
ref_template: str = DEFAULT_ASDF_SCHEMA_REF_TEMPLATE,
schema_generator: type[GenerateAsdfSchema] = GenerateAsdfSchema,
):
"""Get the ASDF schema definition for this model."""
# Implementation follows closely with the `BaseModel.model_json_schema`
schema_generator_instance = schema_generator(
by_alias=by_alias,
ref_template=ref_template,
tag=cls._tag,
)
json_schema = schema_generator_instance.generate(cls.__pydantic_core_schema__)

header = "%YAML 1.1\n---\n"

return f"{header}\n{yaml.safe_dump(json_schema)}"

@classmethod
@deprecated(
"The `schema_asdf` method is deprecated; use `model_asdf_schema` instead."
)
def schema_asdf(
cls, *, metaschema: str = "http://stsci.edu/schemas/asdf/asdf-schema-1.0.0"
cls,
*,
metaschema: str = GenerateAsdfSchema.schema_dialect,
**kwargs,
) -> str:
"""Get the ASDF schema definition for this model.
Parameters
----------
metaschema, optional
A metaschema URI, by default "http://stsci.edu/schemas/asdf/asdf-schema-1.0.0".
See https://asdf.readthedocs.io/en/stable/asdf/extending/schemas.html#anatomy-of-a-schema
for more options.
A metaschema URI
""" # noqa: E501
# TODO: Function signature should follow BaseModel.schema() or
# BaseModel.schema_json()
header = textwrap.dedent(
f"""
%YAML 1.1
---
$schema: {metaschema}
id: {cls._tag}
tag: tag:{cls._tag.split('://', maxsplit=2)[-1]}
"""
)
body = yaml.safe_dump(cls.model_json_schema())
return header + body
if metaschema != GenerateAsdfSchema.schema_dialect:
raise NotImplementedError(
f"Only {GenerateAsdfSchema.schema_dialect} is supported as metaschema."
)

return cls.model_asdf_schema(**kwargs)
40 changes: 40 additions & 0 deletions asdf_pydantic/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import Optional

from pydantic.json_schema import GenerateJsonSchema

DEFAULT_ASDF_SCHEMA_REF_TEMPLATE = "#/definitions/{model}"


class GenerateAsdfSchema(GenerateJsonSchema):
"""Generates ASDF-compatible schema from Pydantic's default JSON schema generator.
```{caution} Experimental
This schema generator is not complete. It currently creates JSON 2020-12
schema (despite `$schema` says it's `asdf-schema-1.0.0`) which are not
compatible with ASDF.
```
"""

# HACK: When we can support tree models, then not all schema should have tag
tag: Optional[str]
schema_dialect = "http://stsci.edu/schemas/asdf/asdf-schema-1.0.0"

def __init__(
self,
by_alias: bool = True,
ref_template: str = DEFAULT_ASDF_SCHEMA_REF_TEMPLATE,
tag: Optional[str] = None,
):
super().__init__(by_alias=by_alias, ref_template=ref_template)
self.tag = tag

def generate(self, schema, mode="validation"):
json_schema = super().generate(schema, mode) # noqa: F841

if self.tag:
json_schema["$schema"] = self.schema_dialect
json_schema["id"] = self.tag
json_schema["tag"] = f"tag:{self.tag.split('://', maxsplit=2)[-1]}"

# TODO: Convert jsonschema 2020-12 to ASDF schema
return json_schema
11 changes: 9 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ classifiers = [
dependencies = [
"asdf>=3",
"pydantic>=2",
"numpy>=1.25"
"numpy>=1.25",
"numpy<2;python_version<'3.10'",
]
dynamic = ["version"]

Expand Down Expand Up @@ -63,14 +64,20 @@ matrix-name-format = "{variable}={value}"
test = "pytest {args}"

[[tool.hatch.envs.test.matrix]]
python = ["3.9", "3.10", "3.11", "3.12"]
# Only test with numpy v1 on Python 3.9
python = ["3.9"]
numpy-version = ["1"]

[[tool.hatch.envs.test.matrix]]
python = ["3.10", "3.11", "3.12"]
numpy = ["1", "2"]
asdf = ["3", "non_lazy_ndarray"]

[tool.hatch.envs.test.overrides]
matrix.numpy.dependencies = [
{ value = "numpy>=1,<2", if = ["1"] },
{ value = "numpy>=2,<3", if = ["2"] },
{ value = "astropy>=6.1", if = ["2"] },
]

matrix.asdf.dependencies = [
Expand Down
12 changes: 8 additions & 4 deletions tests/examples/test_rectangle.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asdf
import pytest
from asdf.extension import Extension
from asdf.schema import check_schema, load_schema
from yaml.scanner import ScannerError

from asdf_pydantic import AsdfPydanticConverter
from asdf_pydantic.examples.shapes import AsdfRectangle
Expand All @@ -23,17 +25,19 @@ class TestExtension(Extension):
asdf.get_config().add_resource_mapping(
{
"asdf://asdf-pydantic/shapes/schemas/rectangle-1.0.0": (
AsdfRectangle.schema_asdf().encode("utf-8")
AsdfRectangle.model_asdf_schema().encode("utf-8")
)
}
)
asdf.get_config().add_extension(TestExtension())


def test_schema():
schema = load_schema("asdf://asdf-pydantic/shapes/schemas/rectangle-1.0.0")

check_schema(schema)
try:
schema = load_schema("asdf://asdf-pydantic/shapes/schemas/rectangle-1.0.0")
check_schema(schema)
except ScannerError as e:
pytest.fail(f"{e}\n{AsdfRectangle.model_asdf_schema()}")

assert schema["$schema"] == "http://stsci.edu/schemas/asdf/asdf-schema-1.0.0"
assert schema["title"] == "AsdfRectangle"
Expand Down
4 changes: 2 additions & 2 deletions tests/schema_validation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class TestExtension(Extension):
asdf.get_config().add_resource_mapping(
{
"asdf://asdf-pydantic/shapes/schemas/rectangle-1.0.0": (
AsdfRectangle.schema_asdf().encode("utf-8")
AsdfRectangle.model_asdf_schema().encode("utf-8")
)
}
)
Expand Down Expand Up @@ -130,7 +130,7 @@ def test_validate_fail_on_bad_yaml_file():
def test_given_child_field_contains_asdf_object_then_schema_has_child_tag():
from asdf.schema import check_schema

schema = yaml.safe_load(AsdfNode.schema_asdf()) # type: ignore
schema = yaml.safe_load(AsdfNode.model_asdf_schema()) # type: ignore
check_schema(schema)

child_schema = schema["definitions"]["AsdfNode"]["properties"]["child"]
Expand Down

0 comments on commit 59b2212

Please sign in to comment.