diff --git a/amniotic/config.py b/amniotic/config.py index 5ccffd1..f4c1ee2 100644 --- a/amniotic/config.py +++ b/amniotic/config.py @@ -1,6 +1,5 @@ from os import getenv -import json import logging import tempfile import yaml @@ -46,8 +45,12 @@ def __post_init__(self): self.logging = self.logging or logging.INFO @classmethod - def from_file(cls): + def get_path_config(cls) -> Path: + """ + + Get path to config file from environment variable, or default location + """ path_config = getenv('AMNIOTIC_CONFIG_PATH') if not path_config: @@ -56,6 +59,13 @@ def from_file(cls): path_config = Path(path_config).absolute() + return path_config + + @classmethod + def from_file(cls): + + path_config = cls.get_path_config() + if not path_config.exists(): msg = f'Config file not found at "{path_config}". Default values will be used.' logging.warning(msg) @@ -63,18 +73,8 @@ def from_file(cls): else: msg = f'Config file found at "{path_config}"' logging.info(msg) - config_str = Path(path_config).read_text() - - if path_config.suffix in {'.yml', '.yaml'}: - config = yaml.safe_load(config_str) - elif path_config.suffix in {'.json'}: - config = json.loads(config_str) - else: - msg = f'Unknown config format "{path_config.suffix}"' - raise ValueError(msg) - - logging.warning(msg) + config = yaml.safe_load(config_str) field_names = {field.name for field in fields(Config)} for key in field_names: