From 785881b686c8429ee562050bd787a699fe72ed5b Mon Sep 17 00:00:00 2001 From: "Keto D. Zhang" Date: Sat, 15 Jun 2024 18:18:32 -0700 Subject: [PATCH] feat: Add before-validation converts all ASDF NDArrayType to ndarray --- asdf_pydantic/model.py | 18 ++++++++++++++++-- tests/patterns/numpy_type_test.py | 13 +++++++------ 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/asdf_pydantic/model.py b/asdf_pydantic/model.py index 638ed77..4015db6 100644 --- a/asdf_pydantic/model.py +++ b/asdf_pydantic/model.py @@ -1,8 +1,11 @@ import textwrap -from typing import ClassVar +from typing import Any, ClassVar +import numpy as np import yaml -from pydantic import BaseModel +from asdf.tags.core import NDArrayType +from numpy.typing import NDArray +from pydantic import BaseModel, ValidationInfo, field_validator class AsdfPydanticModel(BaseModel): @@ -67,3 +70,14 @@ def schema_asdf( ) body = yaml.dump(cls.schema()) return header + body + + @field_validator("*", mode="before") + @classmethod + def _allow_asdf_NDArrayType_to_be_ndarray( + cls, value: Any, info: ValidationInfo + ) -> Any | NDArray: + """Before Pydantic validation, convert NDArrayType to ndarray.""" + if not isinstance(value, NDArrayType): + return value + + return np.asarray(value) diff --git a/tests/patterns/numpy_type_test.py b/tests/patterns/numpy_type_test.py index ecce787..ded9af5 100644 --- a/tests/patterns/numpy_type_test.py +++ b/tests/patterns/numpy_type_test.py @@ -37,12 +37,13 @@ def test_convert_ArrayContainer_to_asdf(tmp_path): """When writing ArrayContainer to an ASDF file, the array field should be serialized to the original numpy array. """ - af = asdf.AsdfFile({"data": ArrayContainer(array=np.array([1, 2, 3]))}).write_to( - tmp_path / "test.asdf" - ) + data = ArrayContainer(array=np.array([1, 2, 3])) + af = asdf.AsdfFile({"data": data}) + af.write_to(tmp_path / "test.asdf") with asdf.open(tmp_path / "test.asdf") as af: - assert isinstance(af.tree["array"], np.ndarray), ( - f"Expected {type(np.ndarray)}, " f"got {type(af.tree['array'])}" + breakpoint() + assert isinstance(af.tree["data"], np.ndarray), ( + f"Expected {type(np.ndarray)}, " f"got {type(af.tree['data'])}" ) - assert np.all(af.tree["array"] == np.array([1, 2, 3])) + assert np.all(af.tree["data"] == np.array([1, 2, 3]))