From c485afefb9850430012e1526e5339c31b1ecee33 Mon Sep 17 00:00:00 2001 From: Matt Pryor Date: Mon, 6 Nov 2023 15:37:56 +0000 Subject: [PATCH] Migrate to Pydantic v2 --- configomatic/configuration.py | 118 +++++++++++++++++++++++----------- configomatic/loader.py | 5 +- configomatic/logging.py | 33 +++++----- configomatic/utils.py | 28 ++++++++ 4 files changed, 129 insertions(+), 55 deletions(-) create mode 100644 configomatic/utils.py diff --git a/configomatic/configuration.py b/configomatic/configuration.py index c8391df..e0cb7a6 100755 --- a/configomatic/configuration.py +++ b/configomatic/configuration.py @@ -2,56 +2,99 @@ Module containing the configuration base classes for configomatic. """ -from functools import reduce +import functools import os import pathlib +import typing as t -from pydantic import BaseModel, BaseConfig -from pydantic.utils import deep_update +from pydantic import BaseModel from .exceptions import FileNotFound from .loader import load_file as default_load_file +from .utils import merge, snake_to_pascal -def snake_to_pascal(name): +class Section( + BaseModel, + alias_generator = snake_to_pascal, + populate_by_name = True +): """ - Converts a snake case name to pascalCase. + Base class for a configuration section. + """ + + +class ConfigEnvironmentDict(t.TypedDict, total = False): """ - first, *rest = name.split("_") - return "".join([first] + [part.capitalize() for part in rest]) + TypedDict for configuring the config environment. + """ + path_env_var: t.Optional[str] + """An environment variable that may specify the config path""" + + default_path: t.Optional[str] + """The default configuration path""" + + load_file: t.Optional[t.Callable[[str], t.Dict[str, t.Any]]] + """The function to use to load the configuration file""" + env_prefix: t.Optional[str] + """The prefix to use for environment overrides""" +_config_env_keys = set(ConfigEnvironmentDict.__annotations__.keys()) -class Section(BaseModel): +class ConfigurationMeta(type(BaseModel)): """ - Base class for a configuration section. + Metaclass for a configuration. """ - class Config: - # Allow pascalCase names as well as snake_case - alias_generator = snake_to_pascal - allow_population_by_field_name = True + def __new__( + cls, + name, + bases, + attrs, + /, + **kwargs + ): + # Config environment configuration is inherited from the bases + config_env = ConfigEnvironmentDict() + for base in bases: + config_env.update(getattr(base, "config_env", None) or {}) + + # Then a config_env specified on the child + config_env.update(attrs.pop("config_env", None) or {}) + + # Partition the kwargs into those for the config env and those for Pydantic + config_env_kwargs = {} + pydantic_kwargs = {} + for key, value in kwargs.items(): + if key in _config_env_keys: + config_env_kwargs[key] = value + else: + pydantic_kwargs[key] = value + + # Update the config environment with the keywords + config_env.update(config_env_kwargs) + # Add the config env to the model attrs + # Add a classvar annotation so that Pydantic leaves it alone + attrs["config_env"] = config_env + attrs.setdefault("__annotations__", {})["config_env"] = t.ClassVar[ConfigEnvironmentDict] -class Configuration(BaseModel): + return super().__new__(cls, name, bases, attrs, **pydantic_kwargs) + + +class Configuration( + BaseModel, + metaclass = ConfigurationMeta, + alias_generator = snake_to_pascal, + populate_by_name = True, +): """ Base class for a configuration. """ - class Config(BaseConfig): - # An environment variable that may specify the config path - path_env_var = None - # The default configuration path - default_path = None - # The function to use to load the configuration file - load_file = default_load_file - # The prefix to use for environment overrides - env_prefix = None - # Allow pascalCase names as well as snake_case - alias_generator = snake_to_pascal - allow_population_by_field_name = True - - __config__ = Config + config_env = ConfigEnvironmentDict() + def __init__(self, _use_file = True, _path = None, _use_env = True, **init_kwargs): # Work out which configs to use @@ -64,22 +107,24 @@ def __init__(self, _use_file = True, _path = None, _use_env = True, **init_kwarg configs.append(self._load_environ()) # The highest precedence is given to directly supplied keyword args configs.append(init_kwargs) - super().__init__(**deep_update(*configs)) + super().__init__(**merge(*configs)) def _load_file(self, path): # If no path is given, try the environment variable - if not path and self.__config__.path_env_var: - path = os.environ.get(self.__config__.path_env_var) + path_env_var = self.config_env.get("path_env_var") + if not path and path_env_var: + path = os.environ.get(path_env_var) # Any path found by this point is explicitly defined by the user explicit_path = bool(path) # If we have still not found a path, use the default if not path: - path = self.__config__.default_path + path = self.config_env.get("default_path") # Load the specified configuration file if path: path = pathlib.Path(path) if path.is_file(): - return self.__config__.load_file(path) or {} + load_file = self.config_env.get("load_file") or default_load_file + return load_file(path) or {} elif explicit_path: # If the file was explicitly specified by the user, require it to exist raise FileNotFound(f"{path} does not exist") @@ -93,19 +138,20 @@ def _load_environ(self): # Build a nested dict from environment variables with the specified prefix # Nesting is specified using __ env_vars = {} + env_prefix = self.config_env.get("env_prefix") for env_var, env_val in os.environ.items(): # Only consider non-empty environment variables if not env_val: continue env_var_parts = env_var.split('__') # The first part must match the prefix, but is otherwise thrown away - if self.__config__.env_prefix: - if env_var_parts[0].upper() == self.__config__.env_prefix: + if env_prefix: + if env_var_parts[0].upper() == env_prefix.upper(): env_var_parts = env_var_parts[1:] else: continue # The rest of the parts form a nested dictionary - nested_vars = reduce( + nested_vars = functools.reduce( lambda vars, part: vars.setdefault(part.lower(), {}), env_var_parts[:-1], env_vars diff --git a/configomatic/loader.py b/configomatic/loader.py index b9ede29..ac2bf5f 100644 --- a/configomatic/loader.py +++ b/configomatic/loader.py @@ -2,8 +2,6 @@ import json import pathlib -from pydantic.utils import deep_update - try: import yaml yaml_available = True @@ -17,6 +15,7 @@ toml_available = False from .exceptions import RequiredPackageNotAvailable, NoSuitableLoader +from .utils import merge if yaml_available: @@ -55,7 +54,7 @@ def include_constructor(loader, node): for p in glob.iglob(path, recursive = True) ) # Merge the configs in sort order, so overrides are predictable - return deep_update(*[ + return merge(*[ load_file(path) for path in sorted(included_paths - excluded_paths) ]) diff --git a/configomatic/logging.py b/configomatic/logging.py index 2d0f33e..80feffd 100644 --- a/configomatic/logging.py +++ b/configomatic/logging.py @@ -1,7 +1,8 @@ import logging -from pydantic import BaseModel, Field, validator -from pydantic.utils import deep_update +from pydantic import BaseModel, Field, field_validator + +from .utils import merge class LessThanLevelFilter(logging.Filter): @@ -22,14 +23,14 @@ class LoggingConfiguration(BaseModel): # See https://docs.python.org/3/library/logging.config.html#logging-config-dictschema version: int = 1 disable_existing_loggers: bool = False - formatters: dict = Field(default_factory = dict) - filters: dict = Field(default_factory = dict) - handlers: dict = Field(default_factory = dict) - loggers: dict = Field(default_factory = dict) + formatters: dict = Field(default_factory = dict, validate_default = True) + filters: dict = Field(default_factory = dict, validate_default = True) + handlers: dict = Field(default_factory = dict, validate_default = True) + loggers: dict = Field(default_factory = dict, validate_default = True) - @validator("formatters", pre = True, always = True) + @field_validator("formatters") def default_formatters(cls, v): - return deep_update( + return merge( { "default": { "format": "[%(asctime)s] %(name)-20.20s [%(levelname)-8.8s] %(message)s", @@ -38,9 +39,9 @@ def default_formatters(cls, v): v or {} ) - @validator("filters", pre = True, always = True) + @field_validator("filters") def default_filters(cls, v): - return deep_update( + return merge( { # This filter allows us to send >= WARNING to stderr and < WARNING to stdout "less_than_warning": { @@ -51,9 +52,9 @@ def default_filters(cls, v): v or {} ) - @validator("handlers", pre = True, always = True) + @field_validator("handlers") def default_handlers(cls, v): - return deep_update( + return merge( { # Handlers for stdout/err with default formatting "stdout": { @@ -72,9 +73,9 @@ def default_handlers(cls, v): v or {} ) - @validator("loggers", pre = True, always = True) + @field_validator("loggers") def default_loggers(cls, v): - return deep_update( + return merge( { # Just set the config for the default logger here "": { @@ -91,7 +92,7 @@ def apply(self, overrides = None): Apply the logging configuration. """ import logging.config - config = self.dict() + config = self.model_dump() if overrides: - config = deep_update(config, overrides) + config = merge(config, overrides) logging.config.dictConfig(config) diff --git a/configomatic/utils.py b/configomatic/utils.py new file mode 100644 index 0000000..f4861bb --- /dev/null +++ b/configomatic/utils.py @@ -0,0 +1,28 @@ +import functools + + +def snake_to_pascal(name): + """ + Converts a snake_case name to pascalCase. + """ + first, *rest = name.split("_") + return "".join([first] + [part.capitalize() for part in rest]) + + +def merge(defaults, *overrides): + """ + Returns a new dictionary obtained by deep-merging multiple sets of overrides + into defaults, with precedence from right to left. + """ + def merge2(defaults, overrides): + if isinstance(defaults, dict) and isinstance(overrides, dict): + merged = defaults.copy() + for key, value in overrides.items(): + if key in defaults: + merged[key] = merge2(defaults[key], value) + else: + merged[key] = value + return merged + else: + return overrides if overrides is not None else defaults + return functools.reduce(merge2, overrides, defaults)