From 459947c747004e8e71a669324afc756ffe704181 Mon Sep 17 00:00:00 2001 From: Jane Zhang Date: Fri, 20 Oct 2023 15:00:37 -0700 Subject: [PATCH] Adding Mosaic logger + logging data validated event (#670) --- scripts/train/train.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/scripts/train/train.py b/scripts/train/train.py index 5e93e33056..28ecb68e34 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -4,6 +4,7 @@ import logging import os import sys +import time import warnings from typing import Any, Dict, List, Optional, Union @@ -11,6 +12,9 @@ from composer import Trainer from composer.core import Evaluator from composer.core.callback import Callback +from composer.loggers import MosaicMLLogger +from composer.loggers.mosaicml_logger import (MOSAICML_ACCESS_TOKEN_ENV_VAR, + MOSAICML_PLATFORM_ENV_VAR) from composer.profiler import (JSONTraceHandler, Profiler, TraceHandler, cyclic_schedule) from composer.utils import dist, get_device, reproducibility @@ -462,7 +466,17 @@ def main(cfg: DictConfig) -> Trainer: loggers = [ build_logger(str(name), logger_cfg) for name, logger_cfg in logger_configs.items() - ] if logger_configs else None + ] if logger_configs else [] + + mosaicml_logger = next( + (logger for logger in loggers if isinstance(logger, MosaicMLLogger)), + None) + if mosaicml_logger is None: + if os.environ.get(MOSAICML_PLATFORM_ENV_VAR, 'false').lower( + ) == 'true' and os.environ.get(MOSAICML_ACCESS_TOKEN_ENV_VAR): + # Adds mosaicml logger to composer if the run was sent from Mosaic platform, access token is set, and mosaic logger wasn't previously added + mosaicml_logger = MosaicMLLogger() + loggers.append(mosaicml_logger) # Profiling profiler: Optional[Profiler] = None @@ -510,6 +524,10 @@ def main(cfg: DictConfig) -> Trainer: tokenizer, device_train_batch_size, ) + + if mosaicml_logger is not None: + mosaicml_logger.log_metrics({'data_validated': time.time()}) + ## Evaluation print('Building eval loader...') evaluators = []