diff --git a/goodconf/__init__.py b/goodconf/__init__.py index e008710..4c04eeb 100644 --- a/goodconf/__init__.py +++ b/goodconf/__init__.py @@ -9,9 +9,9 @@ from io import StringIO from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, cast -from pydantic import BaseSettings, PrivateAttr +from pydantic import BaseModel, BaseSettings, PrivateAttr from pydantic.env_settings import SettingsSourceCallable -from pydantic.fields import Field, FieldInfo, ModelField, Undefined # noqa +from pydantic.fields import ModelField, Undefined log = logging.getLogger(__name__) @@ -53,18 +53,32 @@ def _find_file(filename: str, require: bool = True) -> Optional[str]: def initial_for_field(name: str, field: ModelField) -> Any: info = field.field_info + initial = "" # Default value try: if not callable(info.extra["initial"]): raise ValueError(f"Initial value for `{name}` must be a callable.") - return info.extra["initial"]() + initial = info.extra["initial"]() except KeyError: if info.default is not Undefined and info.default is not ...: - return info.default + initial = info.default if info.default_factory is not None: - return info.default_factory() + initial = info.default_factory() + + # If initial is a BaseModel generate the dictionary representation using pydantic + # built-in method + if isinstance(initial, BaseModel): + return initial.dict() + # If initial is a list, concatenate the result in an output list + elif isinstance(initial, list): + # If it contains a list of BaseModel, invoke dict on each of them + if any(isinstance(element, BaseModel) for element in initial): + return [element.dict() for element in initial] + else: + # If they are basic types, simply concatenate them + return [inner for inner in initial] if field.allow_none: return None - return "" + return initial def file_config_settings_source(settings: BaseSettings) -> Dict[str, Any]: @@ -144,7 +158,7 @@ def load(self, filename: Optional[str] = None) -> None: super().__init__() @classmethod - def get_initial(cls, **override) -> dict: + def get_initial(cls, **override) -> dict[str, Any]: return { k: override.get(k, initial_for_field(k, v)) for k, v in cls.__fields__.items() @@ -199,11 +213,62 @@ def generate_toml(cls, **override) -> str: document = tomlkit.document() if cls.__doc__: document.add(tomlkit.comment(cls.__doc__)) - for k, v in dict_from_toml.unwrap().items(): - document.add(k, v) - if cls.__fields__[k].field_info.description: - description = cast(str, cls.__fields__[k].field_info.description) - cast(Item, document[k]).comment(description) + + def create_item(field: ModelField, initial_value: Any) -> Item: + """Recursively traverse the input field, + building the appropriate TOML Item while descending the hierarchy. + Stop when find a basic type is encountered, created as a basic TOML Item""" + # Check to see if the initial_value is a complex type + if isinstance(initial_value, dict): + # If this field contains sub-fields inside, + # create them inside a TOML table + table = tomlkit.table() + # Invoke recursively on each subfield + for name, field in field.type_.__fields__.items(): + item = create_item(field, initial_value[name]) + # Add the item to the table + table[name] = item + return table + # Che if the initial_value is a list of object + elif isinstance(initial_value, list): + # Check to see if the list of sub-fields contains any complex type. + # In that case, an array of table (aot) is required + if getattr(field, "sub_fields") and any( + sub_field.is_complex() for sub_field in field.sub_fields + ): + array = tomlkit.aot() + else: + # The sub-fields are basic types + array = tomlkit.array() + + for index, _ in enumerate(initial_value): + # Invoke recursively on each element + if getattr(field, "sub_fields"): + # We have a complex type in the sub_fields + item = create_item(field.sub_fields[0], initial_value[index]) + else: + # We have a simple type + item = create_item(field, initial_value[index]) + # Append each item to the array + array.append(item) + + return array + # Base of the recursion: the initial_value is a simple type + else: + # Create a base TOML item + item = tomlkit.item(initial_value) + + # Add description to the item, if present + if field.field_info.description: + description = cast(str, field.field_info.description) + item.comment(description) + + return item + + for k, initial_value in dict_from_toml.unwrap().items(): + item = create_item(cls.__fields__[k], initial_value) + document.add(k, item) + return tomlkit.dumps(document) @classmethod diff --git a/tests/test_goodconf.py b/tests/test_goodconf.py index d56ef55..d345344 100644 --- a/tests/test_goodconf.py +++ b/tests/test_goodconf.py @@ -5,7 +5,7 @@ from typing import Optional import pytest -from pydantic import Field, ValidationError +from pydantic import BaseModel, Field, ValidationError from goodconf import GoodConf from tests.utils import env_var @@ -54,6 +54,34 @@ class TestConf(GoodConf): assert 'b = ""' in output +def test_dump_complex_toml(): + """Dump a complex configuration class, with inner classes and lists""" + pytest.importorskip("tomlkit") + import tomlkit + + class TestConf(GoodConf): + class A(BaseModel): + inner: bool = False + index: int + + outer = A(index=0) + simple_list: list[int] = [1, 2] + complex_list: list[A] = [A(index=0)] + + output = TestConf.generate_toml() + assert "[outer]" in output + assert "inner = false" in output + + # Check that generated toml is valid + doc = tomlkit.parse(output) + assert doc["outer"]["inner"] is False + + # Check the lists + assert len(doc["simple_list"]) == 2 + assert doc["simple_list"][0] == 1 + assert doc["complex_list"][0]["index"] == 0 + + def test_dump_yaml(): pytest.importorskip("ruamel.yaml") diff --git a/tests/test_initial.py b/tests/test_initial.py index 5a643e3..1f79390 100644 --- a/tests/test_initial.py +++ b/tests/test_initial.py @@ -1,6 +1,7 @@ from typing import Optional import pytest +from pydantic import BaseModel from goodconf import Field, GoodConf, initial_for_field @@ -59,3 +60,40 @@ class G(GoodConf): initial = G().get_initial() assert initial["a"] is None + + +def test_complex_initial(): + """Test a nested inner BaseModel""" + + class G(GoodConf): + class A(BaseModel): + inner: str = "test A" + + outer_a = A() + + initial = G().get_initial() + assert initial["outer_a"]["inner"] == "test A" + + +def test_list_initial(): + """Test a list of basic types""" + + class G(GoodConf): + list = [0, 1, 2] + + initial = G().get_initial() + assert len(initial["list"]) == 3 + + +def test_list_complex_initial(): + """Test a list of nested inner BaseModel""" + + class G(GoodConf): + class A(BaseModel): + inner: str = "test A" + + list = [A()] + + initial = G().get_initial() + assert len(initial["list"]) == 1 + assert initial["list"][0]["inner"] == "test A"