From d5ed9e7f407e10792e61af312e0ca61003ed9dfd Mon Sep 17 00:00:00 2001
From: Jose Javier <26491792+josejg@users.noreply.github.com>
Date: Fri, 16 Aug 2024 23:13:12 -0700
Subject: [PATCH] Fix log_config (#1432)

Co-authored-by: Saaketh Narayan
---
 llmfoundry/command_utils/eval.py  |  2 +-
 llmfoundry/command_utils/train.py |  2 +-
 llmfoundry/utils/config_utils.py  | 12 +++---------
 3 files changed, 5 insertions(+), 11 deletions(-)

diff --git a/llmfoundry/command_utils/eval.py b/llmfoundry/command_utils/eval.py
index d628107db2..f622ca182d 100644
--- a/llmfoundry/command_utils/eval.py
+++ b/llmfoundry/command_utils/eval.py
@@ -157,7 +157,7 @@ def evaluate_model(
 
     if should_log_config:
         log.info('Evaluation config:')
-        log_config(logged_config)
+        log_config(trainer.logger, logged_config)
     log.info(f'Starting eval for {model_name}...')
 
     if torch.cuda.is_available():
diff --git a/llmfoundry/command_utils/train.py b/llmfoundry/command_utils/train.py
index b11924004d..8fac739544 100644
--- a/llmfoundry/command_utils/train.py
+++ b/llmfoundry/command_utils/train.py
@@ -573,7 +573,7 @@ def train(cfg: DictConfig) -> Trainer:
 
     if train_cfg.log_config:
         log.info('Logging config')
-        log_config(logged_cfg)
+        log_config(trainer.logger, logged_cfg)
         log_dataset_uri(logged_cfg)
     torch.cuda.empty_cache()
     gc.collect()
diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py
index 623aadb5cc..7da60c5a02 100644
--- a/llmfoundry/utils/config_utils.py
+++ b/llmfoundry/utils/config_utils.py
@@ -19,6 +19,7 @@
 )
 
 import mlflow
+from composer.loggers import Logger
 from composer.utils import dist, parse_uri
 from omegaconf import MISSING, DictConfig, ListConfig, MissingMandatoryValue
 from omegaconf import OmegaConf as om
@@ -571,21 +572,14 @@ def process_init_device(model_cfg: dict[str, Any], fsdp_config: Optional[dict]):
     return init_context
 
 
-def log_config(cfg: dict[str, Any]) -> None:
+def log_config(logger: Logger, cfg: dict[str, Any]) -> None:
     """Logs the current config and updates the wandb and mlflow configs.
 
     This function can be called multiple times to update the wandb and MLflow
     config with different variables.
     """
     print(om.to_yaml(cfg))
-    loggers = cfg.get('loggers', None) or {}
-    if 'wandb' in loggers:
-        import wandb
-        if wandb.run:
-            wandb.config.update(cfg)
-
-    if 'mlflow' in loggers and mlflow.active_run():
-        mlflow.log_params(params=cfg)
+    logger.log_hyperparameters(cfg)
 
 
 def _parse_source_dataset(cfg: dict[str, Any]) -> list[tuple[str, str, str]]: