diff --git a/aiopenapi3/base.py b/aiopenapi3/base.py index 3a69f5f6..60af5d1b 100644 --- a/aiopenapi3/base.py +++ b/aiopenapi3/base.py @@ -18,7 +18,7 @@ else: from typing_extensions import TypeGuard -from pydantic import BaseModel, Field, AnyUrl, model_validator, PrivateAttr, ConfigDict +from pydantic import RootModel, BaseModel, TypeAdapter, Field, AnyUrl, model_validator, PrivateAttr, ConfigDict from .json import JSONPointer, JSONReference from .errors import ReferenceResolutionError, OperationParameterValidationError @@ -424,7 +424,7 @@ def get_type( discriminators: Optional[Sequence[DiscriminatorBase]] = None, extra: Optional["SchemaBase"] = None, fwdref: bool = False, - ) -> Union[Type[BaseModel], ForwardRef]: + ) -> Union[Type[BaseModel], Type[TypeAdapter], ForwardRef]: if fwdref: if "module" in ForwardRef.__init__.__code__.co_varnames: # FIXME Python < 3.9 compat @@ -449,18 +449,15 @@ def model(self, data: "JSON") -> Union[BaseModel, List[BaseModel]]: :rtype: self.get_type() """ - if self.type == "boolean": - assert len(self.properties) == 0 - t = Model.typeof(cast("SchemaType", self)) - if not isinstance(data, t): - return t(data) - return data + type_ = cast("SchemaType", self.get_type()) + if isinstance(type_, TypeAdapter): + r = type_.validate_python(data) else: - type_ = cast("SchemaType", self.get_type()) r = type_.model_validate(data) - if self.type in ("string", "number", "integer", "array"): + if self.type in ("string", "number", "integer", "array", "boolean"): + if isinstance(r, RootModel): return r.root - return r + return r class OperationBase: diff --git a/aiopenapi3/model.py b/aiopenapi3/model.py index b6f0dd79..1d25171e 100644 --- a/aiopenapi3/model.py +++ b/aiopenapi3/model.py @@ -17,7 +17,7 @@ from typing import List, Optional, Union, Tuple, Dict from typing_extensions import Annotated, Literal -from pydantic import BaseModel, Field, RootModel, ConfigDict +from pydantic import BaseModel, TypeAdapter, Field, RootModel, ConfigDict import pydantic from .base import ReferenceBase, SchemaBase @@ -69,12 +69,15 @@ def class_from_schema(s, _type): return b +import pydantic_core + + @dataclasses.dataclass class _ClassInfo: @dataclasses.dataclass class _PropertyInfo: annotation: Any = None - default: Any = None + default: Any = pydantic_core.PydanticUndefined root: Any = None config: Dict[str, Any] = dataclasses.field(default_factory=dict) @@ -115,16 +118,18 @@ def from_schema( schemanames: Optional[List[str]] = None, discriminators: Optional[List["DiscriminatorType"]] = None, extra: Optional["SchemaType"] = None, - ) -> Type[BaseModel]: + ) -> Union[Type[BaseModel], Type[TypeAdapter]]: if schemanames is None: schemanames = [] if discriminators is None: discriminators = [] - r: List[Type[BaseModel]] = list() + r: List[Union[Type[BaseModel], Type[TypeAdapter]]] = list() for _type in Model.types(schema): + if _type == "null": + continue r.append(Model.from_schema_type(schema, _type, schemanames, discriminators, extra)) if len(r) > 1: @@ -134,7 +139,13 @@ def from_schema( elif len(r) == 1: m: Type[BaseModel] = cast(Type[BaseModel], r[0]) else: # == 0 - raise ValueError(r) + assert schema.type == "null" + return TypeAdapter(None.__class__) + + if not isinstance(m, TypeAdapter) and Model.is_nullable(schema): + n = TypeAdapter(Optional[m]) + return cast(Type[TypeAdapter], n) + return cast(Type[BaseModel], m) @classmethod @@ -152,11 +163,8 @@ def from_schema_type( classinfo = _ClassInfo() - # do not create models for primitive types + # create models for primitive types to be nullable if _type in ("string", "integer", "number", "boolean"): - if _type == "boolean": - return bool - if typing.get_origin((_t := Model.typeof(schema, _type=_type))) != Literal: classinfo.root = Annotated[_t, Model.fieldof_args(schema, None)] else: @@ -325,7 +333,7 @@ def typeof( if schema is None: return BaseModel if isinstance(schema, SchemaBase): - nullable = False + nullable = Model.is_nullable(schema) schema = cast("SchemaType", schema) """ Required, can be None: Optional[str] @@ -520,7 +528,7 @@ def or_type(schema: "SchemaType", type_: str, l: Optional[int] = 2) -> bool: @staticmethod def is_nullable(schema: "SchemaType") -> bool: - return Model.or_type(schema, "null", l=None) or getattr(schema, "nullable", False) + return Model.or_type(schema, "null", l=None) or getattr(schema, "nullable", False) is True @staticmethod def is_type_any(schema: "SchemaType"): @@ -537,14 +545,12 @@ def fieldof(schema: "SchemaType", classinfo: _ClassInfo): for name, f in schema.properties.items(): args: Dict[str, Any] = dict() assert schema.required is not None - if name not in schema.required: + if (v := getattr(f, "default", None)) is not None: + args["default"] = v + elif name not in schema.required: args["default"] = None + name = Model.nameof(name, args=args) - if Model.is_nullable(f): - args["default"] = None - for i in ["default"]: - if (v := getattr(f, i, None)) is not None: - args[i] = v classinfo.properties[name].default = Model.fieldof_args(f, args) else: raise ValueError(schema.type) diff --git a/tests/conftest.py b/tests/conftest.py index 9c27ba05..e883b95b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -493,3 +493,8 @@ def with_paths_response_error(): @pytest.fixture def with_schema_ref_nesting(): yield _get_parsed_yaml("schema-ref-nesting.yaml") + + +@pytest.fixture +def with_schema_nullable(openapi_version): + yield _get_parsed_yaml(f"schema-nullable-v{openapi_version.major}{openapi_version.minor}.yaml") diff --git a/tests/fixtures/schema-nullable-v30.yaml b/tests/fixtures/schema-nullable-v30.yaml new file mode 100644 index 00000000..fe9092e8 --- /dev/null +++ b/tests/fixtures/schema-nullable-v30.yaml @@ -0,0 +1,45 @@ +openapi: 3.0.3 +info: + title: '' + version: 0.0.0 +servers: + - url: http://127.0.0.1/api + +security: + - {} + +paths: {} + +components: + schemas: + object: + type: object + additionalProperties: false + properties: + attr: + $ref: '#/components/schemas/nullable' + nullable: true + required: + - attr + + array: + type: array + items: + $ref: '#/components/schemas/nullable' + nullable: true + + string: + type: string + nullable: true + + integer: + type: integer + nullable: true + + boolean: + type: boolean + nullable: true + + nullable: + nullable: true + type: string diff --git a/tests/fixtures/schema-nullable-v31.yaml b/tests/fixtures/schema-nullable-v31.yaml new file mode 100644 index 00000000..37dd6933 --- /dev/null +++ b/tests/fixtures/schema-nullable-v31.yaml @@ -0,0 +1,40 @@ +openapi: 3.1.0 +info: + title: '' + version: 0.0.0 +servers: + - url: http://127.0.0.1/api + +security: + - {} + +paths: {} + +components: + schemas: + object: + type: [object, "null"] + additionalProperties: false + properties: + attr: + $ref: '#/components/schemas/nullable' + required: + - attr + + array: + type: [array, "null"] + items: + $ref: '#/components/schemas/nullable' + + + string: + type: [string, "null"] + + integer: + type: [integer, "null"] + + boolean: + type: [boolean, "null"] + + nullable: + type: [string, "null"] diff --git a/tests/schema_test.py b/tests/schema_test.py index f24843fe..581f1d95 100644 --- a/tests/schema_test.py +++ b/tests/schema_test.py @@ -423,11 +423,11 @@ def test_schema_enum(with_schema_enum): with pytest.raises(ValidationError): String(None) - Nullable = api.components.schemas["Nullable"].get_type() - Nullable("a") - Nullable(None) + Nullable = api.components.schemas["Nullable"] + Nullable.model("a") + Nullable.model(None) with pytest.raises(ValidationError): - Nullable("c") + Nullable.model("c") Mixed = api.components.schemas["Mixed"].get_type() Mixed(1) @@ -470,3 +470,32 @@ def test_schema_baseurl_v20(with_schema_baseurl_v20): def test_schema_ref_nesting(with_schema_ref_nesting): for i in range(10): OpenAPI("/", with_schema_ref_nesting) + + +@pytest.mark.parametrize( + "schema, input, output, okay", + [ + ("object", None, None, True), + ("object", {"attr": "a"}, {"attr": "a"}, True), + ("object", {"attr": None}, {"attr": None}, True), + ("object", {}, {}, False), + ("integer", None, None, True), + ("integer", 1, 1, True), + ("boolean", None, None, True), + ("boolean", True, True, True), + ("string", None, None, True), + ("string", "a", "a", True), + ("array", None, None, True), + ("array", [], [], True), + ], +) +def test_schema_nullable(with_schema_nullable, schema, input, output, okay): + api = OpenAPI("/", with_schema_nullable) # , plugins=[NullableRefs()]) + + m = api.components.schemas[schema] + t = m.get_type() + if okay: + m.model(input) + else: + with pytest.raises(ValidationError): + m.model(input)