From 48e174740139c18e9e5989e1c62b5c98fbcfde29 Mon Sep 17 00:00:00 2001
From: Jose Javier <26491792+josejg@users.noreply.github.com>
Date: Mon, 5 Aug 2024 20:44:33 -0700
Subject: [PATCH] logger

---
 llmfoundry/command_utils/eval.py  |  2 +-
 llmfoundry/command_utils/train.py |  2 +-
 llmfoundry/utils/config_utils.py  | 12 +++---------
 3 files changed, 5 insertions(+), 11 deletions(-)

diff --git a/llmfoundry/command_utils/eval.py b/llmfoundry/command_utils/eval.py
index bddd592dba..4f6b0fdce9 100644
--- a/llmfoundry/command_utils/eval.py
+++ b/llmfoundry/command_utils/eval.py
@@ -157,7 +157,7 @@ def evaluate_model(
 
     if should_log_config:
         log.info('Evaluation config:')
-        log_config(logged_config)
+        log_config(trainer.logger, logged_config)
 
     log.info(f'Starting eval for {model_name}...')
     if torch.cuda.is_available():
diff --git a/llmfoundry/command_utils/train.py b/llmfoundry/command_utils/train.py
index c925e6e586..a55f53b6c1 100644
--- a/llmfoundry/command_utils/train.py
+++ b/llmfoundry/command_utils/train.py
@@ -573,7 +573,7 @@ def train(cfg: DictConfig) -> Trainer:
 
     if train_cfg.log_config:
         log.info('Logging config')
-        log_config(logged_cfg)
+        log_config(trainer.logger, logged_cfg)
     log_dataset_uri(logged_cfg)
     torch.cuda.empty_cache()
     gc.collect()
diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py
index eb54fabc3d..130b8944ab 100644
--- a/llmfoundry/utils/config_utils.py
+++ b/llmfoundry/utils/config_utils.py
@@ -23,6 +23,7 @@
 )
 
 import mlflow
+from composer.loggers import Logger
 from composer.utils import dist, parse_uri
 from omegaconf import MISSING, DictConfig, ListConfig, MissingMandatoryValue
 from omegaconf import OmegaConf as om
@@ -575,21 +576,14 @@ def process_init_device(model_cfg: Dict[str, Any], fsdp_config: Optional[Dict]):
     return init_context
 
 
-def log_config(cfg: Dict[str, Any]) -> None:
+def log_config(logger: Logger, cfg: Dict[str, Any]) -> None:
     """Logs the current config and updates the wandb and mlflow configs.
 
     This function can be called multiple times to update the wandb and MLflow
     config with different variables.
     """
     print(om.to_yaml(cfg))
-    loggers = cfg.get('loggers', None) or {}
-    if 'wandb' in loggers:
-        import wandb
-        if wandb.run:
-            wandb.config.update(cfg)
-
-    if 'mlflow' in loggers and mlflow.active_run():
-        mlflow.log_params(params=cfg)
+    logger.log_hyperparameters(cfg)
 
 
 def _parse_source_dataset(cfg: Dict[str, Any]) -> List[Tuple[str, str, str]]:
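
For context, a minimal sketch of the call pattern after this patch: `log_config` now takes a Composer `Logger` and forwards the config through `Logger.log_hyperparameters`, which fans out to whatever logger destinations (W&B, MLflow, etc.) are attached to the trainer, instead of touching wandb/mlflow directly. This assumes composer and llm-foundry are installed; `TinyModel` and the config dict below are hypothetical stand-ins, not part of llm-foundry.

```python
# Sketch: calling the updated log_config with a Composer Trainer's logger.
# TinyModel and the cfg dict are illustrative placeholders only.
import torch
import torch.nn.functional as F
from composer import Trainer
from composer.models import ComposerModel

from llmfoundry.utils.config_utils import log_config


class TinyModel(ComposerModel):
    """Smallest possible ComposerModel so a Trainer can be constructed."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, batch):
        inputs, _ = batch
        return self.fc(inputs)

    def loss(self, outputs, batch):
        _, targets = batch
        return F.cross_entropy(outputs, targets)


# Logger destinations (WandBLogger, MLFlowLogger, ...) would normally be
# passed to the Trainer via its `loggers` argument.
trainer = Trainer(model=TinyModel())

cfg = {'max_seq_len': 2048, 'global_train_batch_size': 256}

# Prints the config as YAML and routes it through Composer's Logger, which
# forwards it to every attached LoggerDestination.
log_config(trainer.logger, cfg)
```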