Skip to content

Commit

Permalink
Fix log_config (#1432)
Browse files Browse the repository at this point in the history
Co-authored-by: Saaketh Narayan <[email protected]>
  • Loading branch information
josejg and snarayan21 authored Aug 17, 2024
1 parent ddccd12 commit d5ed9e7
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 11 deletions.
2 changes: 1 addition & 1 deletion llmfoundry/command_utils/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def evaluate_model(

if should_log_config:
log.info('Evaluation config:')
log_config(logged_config)
log_config(trainer.logger, logged_config)

log.info(f'Starting eval for {model_name}...')
if torch.cuda.is_available():
Expand Down
2 changes: 1 addition & 1 deletion llmfoundry/command_utils/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ def train(cfg: DictConfig) -> Trainer:

if train_cfg.log_config:
log.info('Logging config')
log_config(logged_cfg)
log_config(trainer.logger, logged_cfg)
log_dataset_uri(logged_cfg)
torch.cuda.empty_cache()
gc.collect()
Expand Down
12 changes: 3 additions & 9 deletions llmfoundry/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)

import mlflow
from composer.loggers import Logger
from composer.utils import dist, parse_uri
from omegaconf import MISSING, DictConfig, ListConfig, MissingMandatoryValue
from omegaconf import OmegaConf as om
Expand Down Expand Up @@ -571,21 +572,14 @@ def process_init_device(model_cfg: dict[str, Any], fsdp_config: Optional[dict]):
return init_context


def log_config(cfg: dict[str, Any]) -> None:
def log_config(logger: Logger, cfg: dict[str, Any]) -> None:
"""Logs the current config and updates the wandb and mlflow configs.
This function can be called multiple times to update the wandb and MLflow
config with different variables.
"""
print(om.to_yaml(cfg))
loggers = cfg.get('loggers', None) or {}
if 'wandb' in loggers:
import wandb
if wandb.run:
wandb.config.update(cfg)

if 'mlflow' in loggers and mlflow.active_run():
mlflow.log_params(params=cfg)
logger.log_hyperparameters(cfg)


def _parse_source_dataset(cfg: dict[str, Any]) -> list[tuple[str, str, str]]:
Expand Down

0 comments on commit d5ed9e7

Please sign in to comment.