From a0c8309c3621212d3fd6b364d83accf5df9dcf71 Mon Sep 17 00:00:00 2001 From: jjanezhang Date: Mon, 24 Jun 2024 10:47:04 -0700 Subject: [PATCH] remove logger in data scripts and callback --- llmfoundry/callbacks/run_timeout_callback.py | 5 +++- scripts/data_prep/convert_delta_to_json.py | 10 ++++--- scripts/data_prep/convert_text_to_mds.py | 5 +++- scripts/train/train.py | 28 ++++++++++++-------- 4 files changed, 32 insertions(+), 16 deletions(-) diff --git a/llmfoundry/callbacks/run_timeout_callback.py b/llmfoundry/callbacks/run_timeout_callback.py index eb8051240d..717a1979b0 100644 --- a/llmfoundry/callbacks/run_timeout_callback.py +++ b/llmfoundry/callbacks/run_timeout_callback.py @@ -17,7 +17,10 @@ def _timeout(timeout: int, mosaicml_logger: Optional[MosaicMLLogger] = None): log.error(f'Timeout after {timeout} seconds of inactivity after fit_end.',) - if mosaicml_logger is not None: + if mosaicml_logger is not None and os.environ.get( + 'OVERRIDE_EXCEPTHOOK', + 'false', + ).lower() != 'true': mosaicml_logger.log_exception(RunTimeoutError(timeout=timeout)) os.kill(os.getpid(), signal.SIGINT) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index f63f1b0027..694317ee35 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -27,8 +27,9 @@ from packaging import version from pyspark.sql import SparkSession from pyspark.sql.connect.client.core import SparkConnectClient -from pyspark.sql.connect.client.reattach import \ - ExecutePlanResponseReattachableIterator +from pyspark.sql.connect.client.reattach import ( + ExecutePlanResponseReattachableIterator, +) from pyspark.sql.connect.dataframe import DataFrame from pyspark.sql.dataframe import DataFrame as SparkDataFrame from pyspark.sql.types import Row @@ -644,6 +645,9 @@ def fetch_DT(args: Namespace) -> None: log.info(f'Elapsed time {time.time() - tik}') except Exception as e: - if mosaicml_logger is not None: + if mosaicml_logger is not None and os.environ.get( + 'OVERRIDE_EXCEPTHOOK', + 'false', + ).lower() != 'true': mosaicml_logger.log_exception(e) raise e diff --git a/scripts/data_prep/convert_text_to_mds.py b/scripts/data_prep/convert_text_to_mds.py index b2f0b0e7b4..ed0e802fad 100644 --- a/scripts/data_prep/convert_text_to_mds.py +++ b/scripts/data_prep/convert_text_to_mds.py @@ -624,6 +624,9 @@ def _configure_logging(logging_level: str): args_str=_args_str(args), ) except Exception as e: - if mosaicml_logger is not None: + if mosaicml_logger is not None and os.environ.get( + 'OVERRIDE_EXCEPTHOOK', + 'false', + ).lower() != 'true': mosaicml_logger.log_exception(e) raise e diff --git a/scripts/train/train.py b/scripts/train/train.py index d86302ecaa..fe984b642f 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -394,10 +394,12 @@ def main(cfg: DictConfig) -> Trainer: train_cfg.device_train_batch_size, ) except BaseContextualError as e: - if mosaicml_logger is not None: - e.location = TrainDataLoaderLocation - if os.environ.get('OVERRIDE_EXCEPTHOOK', 'false').lower() != 'true': - mosaicml_logger.log_exception(e) + e.location = TrainDataLoaderLocation + if mosaicml_logger is not None and os.environ.get( + 'OVERRIDE_EXCEPTHOOK', + 'false', + ).lower() != 'true': + mosaicml_logger.log_exception(e) raise e if mosaicml_logger is not None: @@ -428,11 +430,12 @@ def main(cfg: DictConfig) -> Trainer: if eval_gauntlet_callback is not None: callbacks.append(eval_gauntlet_callback) except BaseContextualError as e: - if mosaicml_logger is not None: - e.location = EvalDataLoaderLocation - if os.environ.get('OVERRIDE_EXCEPTHOOK', - 'false').lower() != 'true': - mosaicml_logger.log_exception(e) + e.location = EvalDataLoaderLocation + if mosaicml_logger is not None and os.environ.get( + 'OVERRIDE_EXCEPTHOOK', + 'false', + ).lower() != 'true': + mosaicml_logger.log_exception(e) raise e if mosaicml_logger is not None: @@ -480,8 +483,11 @@ def main(cfg: DictConfig) -> Trainer: non_icl_metrics, ) except BaseContextualError as e: - if mosaicml_logger is not None: - e.location = EvalDataLoaderLocation + e.location = EvalDataLoaderLocation + if mosaicml_logger is not None and os.environ.get( + 'OVERRIDE_EXCEPTHOOK', + 'false', + ).lower() != 'true': mosaicml_logger.log_exception(e) raise e