Skip to content

Commit

Permalink
feat: Add before-validation converts all ASDF NDArrayType to ndarray
Browse files Browse the repository at this point in the history
  • Loading branch information
ketozhang committed Jun 16, 2024
1 parent 338323e commit 7551610
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 8 deletions.
18 changes: 16 additions & 2 deletions asdf_pydantic/model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import textwrap
from typing import ClassVar
from typing import Any, ClassVar

import numpy as np
import yaml
from pydantic import BaseModel
from asdf.tags.core import NDArrayType
from numpy.typing import NDArray
from pydantic import BaseModel, ValidationInfo, field_validator


class AsdfPydanticModel(BaseModel):
Expand Down Expand Up @@ -67,3 +70,14 @@ def schema_asdf(
)
body = yaml.dump(cls.schema())
return header + body

@field_validator("*", mode="before")
@classmethod
def _allow_asdf_NDArrayType_to_be_ndarray(
cls, value: Any, info: ValidationInfo
) -> Any | NDArray:
"""Before Pydantic validation, convert ASDF `NDArrayType` to numpy `NDArray`."""
if not isinstance(value, NDArrayType):
return value

return np.asarray(value)
13 changes: 7 additions & 6 deletions tests/patterns/numpy_type_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,13 @@ def test_convert_ArrayContainer_to_asdf(tmp_path):
"""When writing ArrayContainer to an ASDF file, the array field should be
serialized to the original numpy array.
"""
af = asdf.AsdfFile({"data": ArrayContainer(array=np.array([1, 2, 3]))}).write_to(
tmp_path / "test.asdf"
)
data = ArrayContainer(array=np.array([1, 2, 3]))
af = asdf.AsdfFile({"data": data})
af.write_to(tmp_path / "test.asdf")

with asdf.open(tmp_path / "test.asdf") as af:
assert isinstance(af.tree["array"], np.ndarray), (
f"Expected {type(np.ndarray)}, " f"got {type(af.tree['array'])}"
breakpoint()
assert isinstance(af.tree["data"], np.ndarray), (
f"Expected {type(np.ndarray)}, " f"got {type(af.tree['data'])}"
)
assert np.all(af.tree["array"] == np.array([1, 2, 3]))
assert np.all(af.tree["data"] == np.array([1, 2, 3]))

0 comments on commit 7551610

Please sign in to comment.