diff --git a/src/armonik_cli/core/config.py b/src/armonik_cli/core/config.py index 3bb05f0..f703f19 100644 --- a/src/armonik_cli/core/config.py +++ b/src/armonik_cli/core/config.py @@ -20,7 +20,8 @@ class Configuration: _config_keys = ["endpoint"] _default_config = {"endpoint": None} - def __init__(self, endpoint: str) -> None: + def __init__(self, path: Path, endpoint: str) -> None: + self.path = path self.endpoint = endpoint @classmethod @@ -44,8 +45,21 @@ def load_default(cls) -> "Configuration": Returns: An instance of Configuration populated with values from the default file. """ - with cls.default_path.open("r") as config_file: - return cls(**json.loads(config_file.read())) + return cls.load(cls._default_config) + + @classmethod + def load(cls, path: Path) -> "Configuration": + """ + Load a configuration from a given configuration file. + + Args: + path: Path to the configuration file. + + Returns: + An instance of Configuration populated with values from the configuration file. + """ + with path.open("r") as config_file: + return cls(path=path, **json.loads(config_file.read())) def has(self, key: str) -> bool: """ @@ -98,5 +112,5 @@ def _save(self): """ Save the current configuration values to the default configuration file. """ - with self.default_path.open("w") as config_file: + with self.path.open("w") as config_file: config_file.write(json.dumps(self.to_dict(), indent=4)) diff --git a/src/armonik_cli/core/decorators.py b/src/armonik_cli/core/decorators.py index fc401d1..6a219f8 100644 --- a/src/armonik_cli/core/decorators.py +++ b/src/armonik_cli/core/decorators.py @@ -1,10 +1,14 @@ +import logging + from functools import wraps, partial +from typing import Dict, Optional import grpc import rich_click as click from armonik.common.channel import create_channel +from armonik_cli.core.config import Configuration from armonik_cli.core.console import console from armonik_cli.exceptions import NotFoundError, InternalError @@ -140,8 +144,11 @@ def decorator(func): @wraps(func) def wrapper(*args, **kwargs): if connection_args: - optional_config = kwargs.pop("optional_config_file") - kwargs["channel_ctx"] = create_channel(kwargs.pop("endpoint"), certificate_authority=kwargs.pop("ca"), client_certificate=kwargs.pop("cert"), client_key=kwargs.pop("key")) + optional_config_file = kwargs.pop("optional_config_file") + optional_config = Configuration.load(optional_config_file) if optional_config_file else None + default_config = Configuration.load_default() + connection_details = reconcile_connection_details(cl_inputs={"endpoint": kwargs.pop("endpoint"), "certificate_authority": kwargs.pop("ca"), "client_certificate": kwargs.pop("cert"), "client_key": kwargs.pop("key")}, default_config, )optional_config + kwargs["channel_ctx"] = create_channel(connection_details.pop("endpoint"), **connection_details) return func(*args, **kwargs) return wrapper @@ -149,3 +156,59 @@ def wrapper(*args, **kwargs): if _func is None: return decorator return decorator(_func) + + +def reconcile_connection_details( + *, + cl_inputs: Dict[str, str], + default_config: Configuration, + optional_config: Optional[Configuration] = None, + logger: logging.Logger, +) -> Dict[str, str]: + """ + Reconciles parameters from command-line inputs, optional config file, and default config file. + Command-line params have highest priority, then optional config, then default config. + + Args: + cl_inputs: Parameters provided via command-line. + optional_config: Optional external configuration. Default is None. + default_config: Default configuration. + logger: A logger to log the origin of configuration parameters. + + Returns: + +Final parameters obtained in order of priority. + """ + + final_params = {} + + # Resolve each parameter in priority order + for key in cl_inputs.keys(): + source = "" + # Priority 1: command-line parameters + if cl_inputs.get(key) is not None: + source = click.get_current_context().get_parameter_source(key).name.lower() + final_params[key] = cl_inputs[key] + # Priority 2: optional config file, if exists + elif optional_config and optional_config.has(key): + source = str(optional_config.path) + final_params[key] = optional_config.get(key) + # Priority 3: default config file + elif default_config.HAS(key): + source = str(default_config.path) + final_params[key] = default_config.get(key) + else: + if key == "endpoint": + raise click.exceptions.UsageError("No endpoint provided.") + final_params[key] = None + + if source: + logger.info(f"Parameter '{key}' retrieved from {source}.") + else: + logger.info(f"Parameter '{key}' is missing.") + + return final_params + +# TODO: add check on the channel_ctx within tests +# TODO: update docstrings of base_command and reconcile_connection_details +# TODO: move reconcile_connection_details to a utils.py file