Skip to content

Commit

Permalink
Migrate to Pydantic v2
Browse files Browse the repository at this point in the history
  • Loading branch information
mkjpryor committed Nov 6, 2023
1 parent f9b4737 commit c485afe
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 55 deletions.
118 changes: 82 additions & 36 deletions configomatic/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,56 +2,99 @@
Module containing the configuration base classes for configomatic.
"""

from functools import reduce
import functools
import os
import pathlib
import typing as t

from pydantic import BaseModel, BaseConfig
from pydantic.utils import deep_update
from pydantic import BaseModel

from .exceptions import FileNotFound
from .loader import load_file as default_load_file
from .utils import merge, snake_to_pascal


def snake_to_pascal(name):
class Section(
BaseModel,
alias_generator = snake_to_pascal,
populate_by_name = True
):
"""
Converts a snake case name to pascalCase.
Base class for a configuration section.
"""


class ConfigEnvironmentDict(t.TypedDict, total = False):
"""
first, *rest = name.split("_")
return "".join([first] + [part.capitalize() for part in rest])
TypedDict for configuring the config environment.
"""
path_env_var: t.Optional[str]
"""An environment variable that may specify the config path"""

default_path: t.Optional[str]
"""The default configuration path"""

load_file: t.Optional[t.Callable[[str], t.Dict[str, t.Any]]]
"""The function to use to load the configuration file"""

env_prefix: t.Optional[str]
"""The prefix to use for environment overrides"""


_config_env_keys = set(ConfigEnvironmentDict.__annotations__.keys())


class Section(BaseModel):
class ConfigurationMeta(type(BaseModel)):
"""
Base class for a configuration section.
Metaclass for a configuration.
"""
class Config:
# Allow pascalCase names as well as snake_case
alias_generator = snake_to_pascal
allow_population_by_field_name = True
def __new__(
cls,
name,
bases,
attrs,
/,
**kwargs
):
# Config environment configuration is inherited from the bases
config_env = ConfigEnvironmentDict()
for base in bases:
config_env.update(getattr(base, "config_env", None) or {})

# Then a config_env specified on the child
config_env.update(attrs.pop("config_env", None) or {})

# Partition the kwargs into those for the config env and those for Pydantic
config_env_kwargs = {}
pydantic_kwargs = {}
for key, value in kwargs.items():
if key in _config_env_keys:
config_env_kwargs[key] = value
else:
pydantic_kwargs[key] = value

# Update the config environment with the keywords
config_env.update(config_env_kwargs)

# Add the config env to the model attrs
# Add a classvar annotation so that Pydantic leaves it alone
attrs["config_env"] = config_env
attrs.setdefault("__annotations__", {})["config_env"] = t.ClassVar[ConfigEnvironmentDict]

class Configuration(BaseModel):
return super().__new__(cls, name, bases, attrs, **pydantic_kwargs)


class Configuration(
BaseModel,
metaclass = ConfigurationMeta,
alias_generator = snake_to_pascal,
populate_by_name = True,
):
"""
Base class for a configuration.
"""
class Config(BaseConfig):
# An environment variable that may specify the config path
path_env_var = None
# The default configuration path
default_path = None
# The function to use to load the configuration file
load_file = default_load_file
# The prefix to use for environment overrides
env_prefix = None
# Allow pascalCase names as well as snake_case
alias_generator = snake_to_pascal
allow_population_by_field_name = True

__config__ = Config
config_env = ConfigEnvironmentDict()


def __init__(self, _use_file = True, _path = None, _use_env = True, **init_kwargs):
# Work out which configs to use
Expand All @@ -64,22 +107,24 @@ def __init__(self, _use_file = True, _path = None, _use_env = True, **init_kwarg
configs.append(self._load_environ())
# The highest precedence is given to directly supplied keyword args
configs.append(init_kwargs)
super().__init__(**deep_update(*configs))
super().__init__(**merge(*configs))

def _load_file(self, path):
# If no path is given, try the environment variable
if not path and self.__config__.path_env_var:
path = os.environ.get(self.__config__.path_env_var)
path_env_var = self.config_env.get("path_env_var")
if not path and path_env_var:
path = os.environ.get(path_env_var)
# Any path found by this point is explicitly defined by the user
explicit_path = bool(path)
# If we have still not found a path, use the default
if not path:
path = self.__config__.default_path
path = self.config_env.get("default_path")
# Load the specified configuration file
if path:
path = pathlib.Path(path)
if path.is_file():
return self.__config__.load_file(path) or {}
load_file = self.config_env.get("load_file") or default_load_file
return load_file(path) or {}
elif explicit_path:
# If the file was explicitly specified by the user, require it to exist
raise FileNotFound(f"{path} does not exist")
Expand All @@ -93,19 +138,20 @@ def _load_environ(self):
# Build a nested dict from environment variables with the specified prefix
# Nesting is specified using __
env_vars = {}
env_prefix = self.config_env.get("env_prefix")
for env_var, env_val in os.environ.items():
# Only consider non-empty environment variables
if not env_val:
continue
env_var_parts = env_var.split('__')
# The first part must match the prefix, but is otherwise thrown away
if self.__config__.env_prefix:
if env_var_parts[0].upper() == self.__config__.env_prefix:
if env_prefix:
if env_var_parts[0].upper() == env_prefix.upper():
env_var_parts = env_var_parts[1:]
else:
continue
# The rest of the parts form a nested dictionary
nested_vars = reduce(
nested_vars = functools.reduce(
lambda vars, part: vars.setdefault(part.lower(), {}),
env_var_parts[:-1],
env_vars
Expand Down
5 changes: 2 additions & 3 deletions configomatic/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import json
import pathlib

from pydantic.utils import deep_update

try:
import yaml
yaml_available = True
Expand All @@ -17,6 +15,7 @@
toml_available = False

from .exceptions import RequiredPackageNotAvailable, NoSuitableLoader
from .utils import merge


if yaml_available:
Expand Down Expand Up @@ -55,7 +54,7 @@ def include_constructor(loader, node):
for p in glob.iglob(path, recursive = True)
)
# Merge the configs in sort order, so overrides are predictable
return deep_update(*[
return merge(*[
load_file(path)
for path in sorted(included_paths - excluded_paths)
])
Expand Down
33 changes: 17 additions & 16 deletions configomatic/logging.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import logging

from pydantic import BaseModel, Field, validator
from pydantic.utils import deep_update
from pydantic import BaseModel, Field, field_validator

from .utils import merge


class LessThanLevelFilter(logging.Filter):
Expand All @@ -22,14 +23,14 @@ class LoggingConfiguration(BaseModel):
# See https://docs.python.org/3/library/logging.config.html#logging-config-dictschema
version: int = 1
disable_existing_loggers: bool = False
formatters: dict = Field(default_factory = dict)
filters: dict = Field(default_factory = dict)
handlers: dict = Field(default_factory = dict)
loggers: dict = Field(default_factory = dict)
formatters: dict = Field(default_factory = dict, validate_default = True)
filters: dict = Field(default_factory = dict, validate_default = True)
handlers: dict = Field(default_factory = dict, validate_default = True)
loggers: dict = Field(default_factory = dict, validate_default = True)

@validator("formatters", pre = True, always = True)
@field_validator("formatters")
def default_formatters(cls, v):
return deep_update(
return merge(
{
"default": {
"format": "[%(asctime)s] %(name)-20.20s [%(levelname)-8.8s] %(message)s",
Expand All @@ -38,9 +39,9 @@ def default_formatters(cls, v):
v or {}
)

@validator("filters", pre = True, always = True)
@field_validator("filters")
def default_filters(cls, v):
return deep_update(
return merge(
{
# This filter allows us to send >= WARNING to stderr and < WARNING to stdout
"less_than_warning": {
Expand All @@ -51,9 +52,9 @@ def default_filters(cls, v):
v or {}
)

@validator("handlers", pre = True, always = True)
@field_validator("handlers")
def default_handlers(cls, v):
return deep_update(
return merge(
{
# Handlers for stdout/err with default formatting
"stdout": {
Expand All @@ -72,9 +73,9 @@ def default_handlers(cls, v):
v or {}
)

@validator("loggers", pre = True, always = True)
@field_validator("loggers")
def default_loggers(cls, v):
return deep_update(
return merge(
{
# Just set the config for the default logger here
"": {
Expand All @@ -91,7 +92,7 @@ def apply(self, overrides = None):
Apply the logging configuration.
"""
import logging.config
config = self.dict()
config = self.model_dump()
if overrides:
config = deep_update(config, overrides)
config = merge(config, overrides)
logging.config.dictConfig(config)
28 changes: 28 additions & 0 deletions configomatic/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import functools


def snake_to_pascal(name):
"""
Converts a snake_case name to pascalCase.
"""
first, *rest = name.split("_")
return "".join([first] + [part.capitalize() for part in rest])


def merge(defaults, *overrides):
"""
Returns a new dictionary obtained by deep-merging multiple sets of overrides
into defaults, with precedence from right to left.
"""
def merge2(defaults, overrides):
if isinstance(defaults, dict) and isinstance(overrides, dict):
merged = defaults.copy()
for key, value in overrides.items():
if key in defaults:
merged[key] = merge2(defaults[key], value)
else:
merged[key] = value
return merged
else:
return overrides if overrides is not None else defaults
return functools.reduce(merge2, overrides, defaults)

0 comments on commit c485afe

Please sign in to comment.