Skip to content

Commit

Permalink
Merge pull request #7 from tarsil/feature/set_attributes_from_fields
Browse files Browse the repository at this point in the history
PolyField
  • Loading branch information
tarsil authored Oct 16, 2023
2 parents f05b559 + 1bbcf82 commit ff7bfde
Show file tree
Hide file tree
Showing 7 changed files with 218 additions and 103 deletions.
41 changes: 29 additions & 12 deletions polyforce/_internal/_construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
cast,
)

from typing_extensions import dataclass_transform

from polyforce.exceptions import MissingAnnotation, ReturnSignatureMissing, ValidationError

from ..constants import INIT_FUNCTION, SPECIAL_CHECK
from ..core._polyforce_core import PolyforceUndefined
from ..decorator import polycheck
from ..fields import PolyField
from ..fields import Field, PolyField
from ._config import ConfigWrapper
from ._errors import ErrorDetail
from ._serializer import json_serializable
Expand All @@ -34,6 +36,7 @@
object_setattr = object.__setattr__


@dataclass_transform(kw_only_default=True, field_specifiers=(Field,))
class PolyMetaclass(ABCMeta):
"""
Base metaclass used for the PolyModel objects
Expand All @@ -45,7 +48,11 @@ class PolyMetaclass(ABCMeta):
__signature__: ClassVar[Dict[str, Signature]] = {}

def __new__(
cls: Type["PolyMetaclass"], name: str, bases: Tuple[Type], attrs: Dict[str, Any]
cls: Type["PolyMetaclass"],
name: str,
bases: Tuple[Type],
attrs: Dict[str, Any],
**kwargs: Any,
) -> Type["PolyModel"]:
"""
Create a new class using the PolyMetaclass.
Expand Down Expand Up @@ -326,17 +333,27 @@ def generate_polyfields(
"""
For all the fields found in the signature, it will generate
PolyField type variable.
When generating PolyFields, it matches if there is already a
PolyField generated by the Field() type.
"""
data = {
"annotation": parameter.annotation,
"name": parameter.name,
"default": PolyforceUndefined
if parameter.default == Signature.empty
else parameter.default,
}

field = PolyField(**data)
field_data = {field.name: field}
if not isinstance(parameter.default, PolyField):
data = {
"annotation": parameter.annotation,
"name": parameter.name,
"default": PolyforceUndefined
if parameter.default == Signature.empty
else parameter.default,
}

field = PolyField(**data)
else:
field = parameter.default
field.annotation = parameter.annotation
field.name = parameter.name
field._validate_default_with_annotation()

field_data = {parameter.name: field}

if method not in cls.poly_fields:
cls.poly_fields[method] = {}
Expand Down
74 changes: 52 additions & 22 deletions polyforce/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,22 +60,62 @@ def generate_polyfields(self) -> Dict[str, Dict[str, "PolyField"]]:
PolyField type variable.
"""
for parameter in self.args_spec.parameters.values():
data = {
"annotation": parameter.annotation,
"name": parameter.name,
"default": PolyforceUndefined
if parameter.default == inspect.Signature.empty
else parameter.default,
}
field = PolyField(**data)
field_data = {field.name: field}
if not isinstance(parameter.default, PolyField):
data = {
"annotation": parameter.annotation,
"name": parameter.name,
"default": PolyforceUndefined
if parameter.default == inspect.Signature.empty
else parameter.default,
}
field = PolyField(**data)
else:
field = parameter.default
field.annotation = parameter.annotation
field.name = parameter.name
field._validate_default_with_annotation()

field_data = {parameter.name: field}

if self.fn_name not in self.poly_fields:
self.poly_fields[self.fn_name] = {}

self.poly_fields[self.fn_name].update(field_data)
return self.poly_fields

def _extract_params(self) -> Dict[str, PolyField]:
"""
Extracts the params based on the type function.
If a function is of type staticmethod, means there is no `self`
or `cls` and therefore uses the signature or argspec generated.
If a function is of type classmethod or a simple function in general,
then validates if is a class or an object and extracts the values.
Returns:
Dict[str, PolyField]: A dictionary of function parameters.
"""
if not self.is_class_or_object:
return self.poly_fields[self.fn_name]

params: Dict[str, PolyField] = {}

# Get the function type (staticmethod, classmethod, or regular method)
func_type = getattr(self.class_or_object, self.fn_name)

if not isinstance(func_type, staticmethod):
if self.signature:
# If a signature is provided, use it to get function parameters
func_params = list(self.signature.parameters.values())
else:
# If no signature, use the poly_fields dictionary (modify as per your actual data structure)
func_params = list(
islice(self.poly_fields.get(self.fn_name, {}).values(), 1, None) # type: ignore[arg-type]
)
params = {param.name: param for param in func_params}
return params

def check_types(self, *args: Any, **kwargs: Any) -> Any:
"""
Validate the types of function parameters.
Expand All @@ -84,20 +124,8 @@ def check_types(self, *args: Any, **kwargs: Any) -> Any:
*args (Any): Positional arguments.
**kwargs (Any): Keyword arguments.
"""
merged_params: Dict[str, PolyField] = {}
if self.is_class_or_object:
func_type = inspect.getattr_static(self.class_or_object, self.fn_name)

# classmethod and staticmethod do not use the "self".
if not isinstance(func_type, (classmethod, staticmethod)):
func_params = list(
islice(self.poly_fields.get(self.fn_name, {}).values(), 1, None)
)
merged_params = {param.name: param for param in func_params}
else:
merged_params = self.poly_fields[self.fn_name]

params = dict(zip(merged_params, args))
params = dict(zip(self._extract_params(), args))
params.update(kwargs)

for name, value in params.items():
Expand Down Expand Up @@ -179,6 +207,8 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
if self.signature or len(args) == 1:
arguments = list(args)
arguments = arguments[1:]

# Is a class or an object
self.is_class_or_object = True
self.class_or_object = args[0]

Expand Down
20 changes: 16 additions & 4 deletions polyforce/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,14 @@ def __init__(self, **kwargs: Unpack[_FieldInputs]) -> None:
raise TypeError("cannot specify both default and default_factory")

self.name = kwargs.pop("name", None)
self._validate_default_with_annotation()

self.title = kwargs.pop("title", None)
self.description = kwargs.pop("description", None)
self.metadata = metadata

if self.default and self.default != PolyforceUndefined and self.annotation:
self._validate_default_with_annotation()

def _extract_type_hint(self, type_hint: Union[Type, tuple]) -> Any:
"""
Extracts the base type from a type hint, considering typing extensions.
Expand All @@ -117,7 +120,6 @@ def _extract_type_hint(self, type_hint: Union[Type, tuple]) -> Any:
original_hint = extract_type_hint(Union[int, str]) # Returns Union[int, str]
```
"""

origin = getattr(type_hint, "__origin__", type_hint)
if isinstance(origin, _SpecialForm):
origin = type_hint.__args__ # type: ignore
Expand All @@ -131,7 +133,7 @@ def _validate_default_with_annotation(self) -> None:
if not self.default or self.default == PolyforceUndefined:
return None

default = self.default() if callable(self.default) else self.default
default = self.get_default()

type_hint = self._extract_type_hint(self.annotation)
if not isinstance(default, type_hint):
Expand Down Expand Up @@ -161,6 +163,14 @@ def is_required(self) -> bool:
"""
return self.default is PolyforceUndefined and self.default_factory is None

def get_default(self) -> Any:
"""
Returns the default is
"""
if self.default_factory is None:
return self.default() if callable(self.default) else self.default
return self.default_factory()

@classmethod
def from_field(cls, default: Any = PolyforceUndefined, **kwargs: Unpack[_FieldInputs]) -> Self:
"""
Expand Down Expand Up @@ -212,11 +222,13 @@ def Field(
*,
default_factory: Union[Callable[[], Any], None] = PolyforceUndefined,
title: Union[str, None] = PolyforceUndefined, # type: ignore
name: Union[str, None] = PolyforceUndefined, # type: ignore
description: Union[str, None] = PolyforceUndefined, # type: ignore
) -> PolyField:
) -> Any:
return PolyField.from_field(
default=default,
default_factory=default_factory,
title=title,
description=description,
name=name,
)
2 changes: 1 addition & 1 deletion polyforce/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __getattribute__(self, name: str) -> Any:

def __repr_args__(self) -> "ReprArgs":
for k, v in self.__dict__.items():
field = self.poly_fields.get(k)
field = self.__dict__.get(k)
if field:
yield k, v

Expand Down
83 changes: 19 additions & 64 deletions tests/fields/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,86 +2,41 @@

import pytest

from polyforce import PolyField, PolyModel
from polyforce.exceptions import ValidationError
from polyforce import Field, PolyField, PolyModel


class Model(PolyModel):
def __init__(self, name: str, age: Union[str, int]) -> None:
def __init__(self, name: str = Field(), age: Union[str, int] = Field()) -> None:
...

def create_model(self, names: List[str]) -> None:
def create_model(self, names: List[str] = Field()) -> None:
return names

def get_model(self, models: Dict[str, Any]) -> Dict[str, Any]:
def get_model(self, models: Dict[str, Any] = Field()) -> Dict[str, Any]:
return models

def set_model(self, models: Mapping[str, PolyModel]) -> None:
def set_model(self, models: Mapping[str, PolyModel] = Field()) -> None:
return models


def test_can_create_polyfield():
field = PolyField(annotation=str, name="field")
assert field is not None
assert field.annotation == str
assert field.name == "field"
assert field.is_required() is True
def test_field():
model = Model(name="Polyforce", age=1)

assert len(model.poly_fields) == 4
assert model.poly_fields["create_model"]["names"].annotation == List[str]
assert model.poly_fields["get_model"]["models"].annotation == Dict[str, Any]
assert model.poly_fields["set_model"]["models"].annotation == Mapping[str, PolyModel]

def test_raise_type_error_on_default_field():
with pytest.raises(TypeError) as raised:
PolyField(annotation=str, default=2, name="name")

assert (
raised.value.args[0]
== "default 'int' for field 'name' is not valid for the field type annotation, it must be type 'str'"
)


def test_default_field():
default = "john"

def get_default():
nonlocal default
return default

field = PolyField(annotation=str, default=get_default, name="name")
assert field.default == default


def test_functions():
model = Model(name="PolyModel", age=1)

names = model.create_model(names=["poly"])
assert names == ["poly"]

models = model.get_model(models={"name": "poly"})
assert models == {"name": "poly"}

models = model.set_model(models={"name": "poly"})
assert models == {"name": "poly"}


@pytest.mark.parametrize("func", ["get_model", "set_model"])
def test_functions_raises_validation_error(func):
model = Model(name="PolyModel", age=1)

with pytest.raises(ValidationError):
model.create_model(names="a")

with pytest.raises(ValidationError):
getattr(model, func)(models="a")

def test_no_annotation():
field: PolyField = Field(default=2, name="name")

def test_poly_fields():
model = Model(name="PolyModel", age=1)
assert field.annotation is None

assert len(model.poly_fields) == 4

for value in ["create_model", "get_model", "set_model"]:
assert value in model.poly_fields
def test_raise_type_error_on_default_field():
with pytest.raises(TypeError):

assert len(model.poly_fields["__init__"]) == 2
assert len(model.poly_fields["create_model"]) == 1
assert len(model.poly_fields["get_model"]) == 1
assert len(model.poly_fields["set_model"]) == 1
class NotherModel(PolyModel):
def __init__(self, name: str = Field(default=2)) -> None:
...
Loading

0 comments on commit ff7bfde

Please sign in to comment.