Skip to content

Commit

Permalink
Add Field() function to defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
tarsil committed Oct 16, 2023
1 parent f05b559 commit 2c2fcf7
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 94 deletions.
41 changes: 29 additions & 12 deletions polyforce/_internal/_construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
cast,
)

from typing_extensions import dataclass_transform

from polyforce.exceptions import MissingAnnotation, ReturnSignatureMissing, ValidationError

from ..constants import INIT_FUNCTION, SPECIAL_CHECK
from ..core._polyforce_core import PolyforceUndefined
from ..decorator import polycheck
from ..fields import PolyField
from ..fields import Field, PolyField
from ._config import ConfigWrapper
from ._errors import ErrorDetail
from ._serializer import json_serializable
Expand All @@ -34,6 +36,7 @@
object_setattr = object.__setattr__


@dataclass_transform(kw_only_default=True, field_specifiers=(Field,))
class PolyMetaclass(ABCMeta):
"""
Base metaclass used for the PolyModel objects
Expand All @@ -45,7 +48,11 @@ class PolyMetaclass(ABCMeta):
__signature__: ClassVar[Dict[str, Signature]] = {}

def __new__(
cls: Type["PolyMetaclass"], name: str, bases: Tuple[Type], attrs: Dict[str, Any]
cls: Type["PolyMetaclass"],
name: str,
bases: Tuple[Type],
attrs: Dict[str, Any],
**kwargs: Any,
) -> Type["PolyModel"]:
"""
Create a new class using the PolyMetaclass.
Expand Down Expand Up @@ -326,17 +333,27 @@ def generate_polyfields(
"""
For all the fields found in the signature, it will generate
PolyField type variable.
When generating PolyFields, it matches if there is already a
PolyField generated by the Field() type.
"""
data = {
"annotation": parameter.annotation,
"name": parameter.name,
"default": PolyforceUndefined
if parameter.default == Signature.empty
else parameter.default,
}

field = PolyField(**data)
field_data = {field.name: field}
if not isinstance(parameter.default, PolyField):
data = {
"annotation": parameter.annotation,
"name": parameter.name,
"default": PolyforceUndefined
if parameter.default == Signature.empty
else parameter.default,
}

field = PolyField(**data)
else:
field = parameter.default
field.annotation = parameter.annotation
field.name = parameter.name
field._validate_default_with_annotation()

field_data = {parameter.name: field}

if method not in cls.poly_fields:
cls.poly_fields[method] = {}
Expand Down
25 changes: 16 additions & 9 deletions polyforce/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,22 @@ def generate_polyfields(self) -> Dict[str, Dict[str, "PolyField"]]:
PolyField type variable.
"""
for parameter in self.args_spec.parameters.values():
data = {
"annotation": parameter.annotation,
"name": parameter.name,
"default": PolyforceUndefined
if parameter.default == inspect.Signature.empty
else parameter.default,
}
field = PolyField(**data)
field_data = {field.name: field}
if not isinstance(parameter.default, PolyField):
data = {
"annotation": parameter.annotation,
"name": parameter.name,
"default": PolyforceUndefined
if parameter.default == inspect.Signature.empty
else parameter.default,
}
field = PolyField(**data)
else:
field = parameter.default
field.annotation = parameter.annotation
field.name = parameter.name
field._validate_default_with_annotation()

field_data = {parameter.name: field}

if self.fn_name not in self.poly_fields:
self.poly_fields[self.fn_name] = {}
Expand Down
20 changes: 16 additions & 4 deletions polyforce/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,14 @@ def __init__(self, **kwargs: Unpack[_FieldInputs]) -> None:
raise TypeError("cannot specify both default and default_factory")

self.name = kwargs.pop("name", None)
self._validate_default_with_annotation()

self.title = kwargs.pop("title", None)
self.description = kwargs.pop("description", None)
self.metadata = metadata

if self.default and self.default != PolyforceUndefined and self.annotation:
self._validate_default_with_annotation()

def _extract_type_hint(self, type_hint: Union[Type, tuple]) -> Any:
"""
Extracts the base type from a type hint, considering typing extensions.
Expand All @@ -117,7 +120,6 @@ def _extract_type_hint(self, type_hint: Union[Type, tuple]) -> Any:
original_hint = extract_type_hint(Union[int, str]) # Returns Union[int, str]
```
"""

origin = getattr(type_hint, "__origin__", type_hint)
if isinstance(origin, _SpecialForm):
origin = type_hint.__args__ # type: ignore
Expand All @@ -131,7 +133,7 @@ def _validate_default_with_annotation(self) -> None:
if not self.default or self.default == PolyforceUndefined:
return None

default = self.default() if callable(self.default) else self.default
default = self.get_default()

type_hint = self._extract_type_hint(self.annotation)
if not isinstance(default, type_hint):
Expand Down Expand Up @@ -161,6 +163,14 @@ def is_required(self) -> bool:
"""
return self.default is PolyforceUndefined and self.default_factory is None

def get_default(self) -> Any:
"""
Returns the default is
"""
if self.default_factory is None:
return self.default() if callable(self.default) else self.default
return self.default_factory()

@classmethod
def from_field(cls, default: Any = PolyforceUndefined, **kwargs: Unpack[_FieldInputs]) -> Self:
"""
Expand Down Expand Up @@ -212,11 +222,13 @@ def Field(
*,
default_factory: Union[Callable[[], Any], None] = PolyforceUndefined,
title: Union[str, None] = PolyforceUndefined, # type: ignore
name: Union[str, None] = PolyforceUndefined, # type: ignore
description: Union[str, None] = PolyforceUndefined, # type: ignore
) -> PolyField:
) -> Any:
return PolyField.from_field(
default=default,
default_factory=default_factory,
title=title,
description=description,
name=name,
)
83 changes: 14 additions & 69 deletions tests/fields/test_fields.py
Original file line number Diff line number Diff line change
@@ -1,87 +1,32 @@
from typing import Any, Dict, List, Mapping, Union

import pytest

from polyforce import PolyField, PolyModel
from polyforce.exceptions import ValidationError
from polyforce import Field, PolyField, PolyModel


class Model(PolyModel):
def __init__(self, name: str, age: Union[str, int]) -> None:
def __init__(self, name: str = Field(), age: Union[str, int] = Field()) -> None:
...

def create_model(self, names: List[str]) -> None:
def create_model(self, names: List[str] = Field()) -> None:
return names

def get_model(self, models: Dict[str, Any]) -> Dict[str, Any]:
def get_model(self, models: Dict[str, Any] = Field()) -> Dict[str, Any]:
return models

def set_model(self, models: Mapping[str, PolyModel]) -> None:
def set_model(self, models: Mapping[str, PolyModel] = Field()) -> None:
return models


def test_can_create_polyfield():
field = PolyField(annotation=str, name="field")
assert field is not None
assert field.annotation == str
assert field.name == "field"
assert field.is_required() is True


def test_raise_type_error_on_default_field():
with pytest.raises(TypeError) as raised:
PolyField(annotation=str, default=2, name="name")

assert (
raised.value.args[0]
== "default 'int' for field 'name' is not valid for the field type annotation, it must be type 'str'"
)


def test_default_field():
default = "john"

def get_default():
nonlocal default
return default

field = PolyField(annotation=str, default=get_default, name="name")
assert field.default == default


def test_functions():
model = Model(name="PolyModel", age=1)

names = model.create_model(names=["poly"])
assert names == ["poly"]

models = model.get_model(models={"name": "poly"})
assert models == {"name": "poly"}

models = model.set_model(models={"name": "poly"})
assert models == {"name": "poly"}


@pytest.mark.parametrize("func", ["get_model", "set_model"])
def test_functions_raises_validation_error(func):
model = Model(name="PolyModel", age=1)

with pytest.raises(ValidationError):
model.create_model(names="a")

with pytest.raises(ValidationError):
getattr(model, func)(models="a")


def test_poly_fields():
model = Model(name="PolyModel", age=1)
def test_field():
model = Model(name="Polyforce", age=1)

assert len(model.poly_fields) == 4
assert model.poly_fields["create_model"]["names"].annotation == List[str]
assert model.poly_fields["get_model"]["models"].annotation == Dict[str, Any]
assert model.poly_fields["set_model"]["models"].annotation == Mapping[str, PolyModel]


for value in ["create_model", "get_model", "set_model"]:
assert value in model.poly_fields
def test_no_annotation():
field: PolyField = Field(default=2, name="name")

assert len(model.poly_fields["__init__"]) == 2
assert len(model.poly_fields["create_model"]) == 1
assert len(model.poly_fields["get_model"]) == 1
assert len(model.poly_fields["set_model"]) == 1
assert field.annotation is None
87 changes: 87 additions & 0 deletions tests/fields/test_polyfield_fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from typing import Any, Dict, List, Mapping, Union

import pytest

from polyforce import PolyField, PolyModel
from polyforce.exceptions import ValidationError


class Model(PolyModel):
def __init__(self, name: str, age: Union[str, int]) -> None:
...

def create_model(self, names: List[str]) -> None:
return names

def get_model(self, models: Dict[str, Any]) -> Dict[str, Any]:
return models

def set_model(self, models: Mapping[str, PolyModel]) -> None:
return models


def test_can_create_polyfield():
field = PolyField(annotation=str, name="field")
assert field is not None
assert field.annotation == str
assert field.name == "field"
assert field.is_required() is True


def test_raise_type_error_on_default_field():
with pytest.raises(TypeError) as raised:
PolyField(annotation=str, default=2, name="name")

assert (
raised.value.args[0]
== "default 'int' for field 'name' is not valid for the field type annotation, it must be type 'str'"
)


def test_default_field():
default = "john"

def get_default():
nonlocal default
return default

field = PolyField(annotation=str, default=get_default, name="name")
assert field.default == default


def test_functions():
model = Model(name="PolyModel", age=1)

names = model.create_model(names=["poly"])
assert names == ["poly"]

models = model.get_model(models={"name": "poly"})
assert models == {"name": "poly"}

models = model.set_model(models={"name": "poly"})
assert models == {"name": "poly"}


@pytest.mark.parametrize("func", ["get_model", "set_model"])
def test_functions_raises_validation_error(func):
model = Model(name="PolyModel", age=1)

with pytest.raises(ValidationError):
model.create_model(names="a")

with pytest.raises(ValidationError):
getattr(model, func)(models="a")


def test_poly_fields():
model = Model(name="PolyModel", age=1)

assert len(model.poly_fields) == 4

for value in ["create_model", "get_model", "set_model"]:
assert value in model.poly_fields

assert len(model.poly_fields["__init__"]) == 2
assert len(model.poly_fields["create_model"]) == 1
assert len(model.poly_fields["get_model"]) == 1
assert len(model.poly_fields["set_model"]) == 1

0 comments on commit 2c2fcf7

Please sign in to comment.