Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(GoodConf): generate TOML configuration for complex types #33

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 77 additions & 12 deletions goodconf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
from io import StringIO
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, cast

from pydantic import BaseSettings, PrivateAttr
from pydantic import BaseModel, BaseSettings, PrivateAttr
from pydantic.env_settings import SettingsSourceCallable
from pydantic.fields import Field, FieldInfo, ModelField, Undefined # noqa
from pydantic.fields import ModelField, Undefined

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -53,18 +53,32 @@ def _find_file(filename: str, require: bool = True) -> Optional[str]:

def initial_for_field(name: str, field: ModelField) -> Any:
info = field.field_info
initial = "" # Default value
try:
if not callable(info.extra["initial"]):
raise ValueError(f"Initial value for `{name}` must be a callable.")
return info.extra["initial"]()
initial = info.extra["initial"]()
except KeyError:
if info.default is not Undefined and info.default is not ...:
return info.default
initial = info.default
if info.default_factory is not None:
return info.default_factory()
initial = info.default_factory()

# If initial is a BaseModel generate the dictionary representation using pydantic
# built-in method
if isinstance(initial, BaseModel):
return initial.dict()
# If initial is a list, concatenate the result in an output list
elif isinstance(initial, list):
# If it contains a list of BaseModel, invoke dict on each of them
if any(isinstance(element, BaseModel) for element in initial):
return [element.dict() for element in initial]
else:
# If they are basic types, simply concatenate them
return [inner for inner in initial]
if field.allow_none:
return None
return ""
return initial


def file_config_settings_source(settings: BaseSettings) -> Dict[str, Any]:
Expand Down Expand Up @@ -144,7 +158,7 @@ def load(self, filename: Optional[str] = None) -> None:
super().__init__()

@classmethod
def get_initial(cls, **override) -> dict:
def get_initial(cls, **override) -> dict[str, Any]:
return {
k: override.get(k, initial_for_field(k, v))
for k, v in cls.__fields__.items()
Expand Down Expand Up @@ -199,11 +213,62 @@ def generate_toml(cls, **override) -> str:
document = tomlkit.document()
if cls.__doc__:
document.add(tomlkit.comment(cls.__doc__))
for k, v in dict_from_toml.unwrap().items():
document.add(k, v)
if cls.__fields__[k].field_info.description:
description = cast(str, cls.__fields__[k].field_info.description)
cast(Item, document[k]).comment(description)

def create_item(field: ModelField, initial_value: Any) -> Item:
"""Recursively traverse the input field,
building the appropriate TOML Item while descending the hierarchy.
Stop when find a basic type is encountered, created as a basic TOML Item"""
# Check to see if the initial_value is a complex type
if isinstance(initial_value, dict):
# If this field contains sub-fields inside,
# create them inside a TOML table
table = tomlkit.table()
# Invoke recursively on each subfield
for name, field in field.type_.__fields__.items():
item = create_item(field, initial_value[name])
# Add the item to the table
table[name] = item
return table
# Che if the initial_value is a list of object
elif isinstance(initial_value, list):
# Check to see if the list of sub-fields contains any complex type.
# In that case, an array of table (aot) is required
if getattr(field, "sub_fields") and any(
sub_field.is_complex() for sub_field in field.sub_fields
):
array = tomlkit.aot()
else:
# The sub-fields are basic types
array = tomlkit.array()

for index, _ in enumerate(initial_value):
# Invoke recursively on each element
if getattr(field, "sub_fields"):
# We have a complex type in the sub_fields
item = create_item(field.sub_fields[0], initial_value[index])
else:
# We have a simple type
item = create_item(field, initial_value[index])
# Append each item to the array
array.append(item)

return array
# Base of the recursion: the initial_value is a simple type
else:
# Create a base TOML item
item = tomlkit.item(initial_value)

# Add description to the item, if present
if field.field_info.description:
description = cast(str, field.field_info.description)
item.comment(description)

return item

for k, initial_value in dict_from_toml.unwrap().items():
item = create_item(cls.__fields__[k], initial_value)
document.add(k, item)

return tomlkit.dumps(document)

@classmethod
Expand Down
30 changes: 29 additions & 1 deletion tests/test_goodconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Optional

import pytest
from pydantic import Field, ValidationError
from pydantic import BaseModel, Field, ValidationError

from goodconf import GoodConf
from tests.utils import env_var
Expand Down Expand Up @@ -54,6 +54,34 @@ class TestConf(GoodConf):
assert 'b = ""' in output


def test_dump_complex_toml():
"""Dump a complex configuration class, with inner classes and lists"""
pytest.importorskip("tomlkit")
import tomlkit

class TestConf(GoodConf):
class A(BaseModel):
inner: bool = False
index: int

outer = A(index=0)
simple_list: list[int] = [1, 2]
complex_list: list[A] = [A(index=0)]

output = TestConf.generate_toml()
assert "[outer]" in output
assert "inner = false" in output

# Check that generated toml is valid
doc = tomlkit.parse(output)
assert doc["outer"]["inner"] is False

# Check the lists
assert len(doc["simple_list"]) == 2
assert doc["simple_list"][0] == 1
assert doc["complex_list"][0]["index"] == 0


def test_dump_yaml():
pytest.importorskip("ruamel.yaml")

Expand Down
38 changes: 38 additions & 0 deletions tests/test_initial.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional

import pytest
from pydantic import BaseModel

from goodconf import Field, GoodConf, initial_for_field

Expand Down Expand Up @@ -59,3 +60,40 @@ class G(GoodConf):

initial = G().get_initial()
assert initial["a"] is None


def test_complex_initial():
"""Test a nested inner BaseModel"""

class G(GoodConf):
class A(BaseModel):
inner: str = "test A"

outer_a = A()

initial = G().get_initial()
assert initial["outer_a"]["inner"] == "test A"


def test_list_initial():
"""Test a list of basic types"""

class G(GoodConf):
list = [0, 1, 2]

initial = G().get_initial()
assert len(initial["list"]) == 3


def test_list_complex_initial():
"""Test a list of nested inner BaseModel"""

class G(GoodConf):
class A(BaseModel):
inner: str = "test A"

list = [A()]

initial = G().get_initial()
assert len(initial["list"]) == 1
assert initial["list"][0]["inner"] == "test A"