Skip to content

Commit

Permalink
Merge branch 'main' into cl_callback
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg authored Jun 25, 2024
2 parents d745a12 + 21c9e0a commit 92ae4ee
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 11 deletions.
3 changes: 2 additions & 1 deletion llmfoundry/callbacks/run_timeout_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
from composer.loggers import MosaicMLLogger

from llmfoundry.utils.exceptions import RunTimeoutError
from llmfoundry.utils.mosaicml_logger_utils import no_override_excepthook

log = logging.getLogger(__name__)


def _timeout(timeout: int, mosaicml_logger: Optional[MosaicMLLogger] = None):
    """Report an inactivity timeout and interrupt the current process.

    Logs an error stating that the run was inactive for ``timeout`` seconds
    after ``fit_end``, optionally records a ``RunTimeoutError`` on the
    MosaicML logger, and then sends ``SIGINT`` to this process.

    Args:
        timeout (int): Number of seconds of inactivity reported in the log
            message and attached to the ``RunTimeoutError``.
        mosaicml_logger (Optional[MosaicMLLogger]): Logger that receives the
            ``RunTimeoutError``. Nothing is recorded when this is ``None`` or
            when ``no_override_excepthook()`` returns ``False``.
    """
    log.error(f'Timeout after {timeout} seconds of inactivity after fit_end.',)

    # Only log the exception by hand when a logger is available and the
    # excepthook override flag is off.
    if mosaicml_logger is not None and no_override_excepthook():
        timeout_error = RunTimeoutError(timeout=timeout)
        mosaicml_logger.log_exception(timeout_error)

    # Interrupt our own process.
    current_pid = os.getpid()
    os.kill(current_pid, signal.SIGINT)

Expand Down
2 changes: 2 additions & 0 deletions llmfoundry/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
log_eval_analytics,
log_train_analytics,
maybe_create_mosaicml_logger,
no_override_excepthook,
)
from llmfoundry.utils.prompt_files import load_prompts, load_prompts_from_file
from llmfoundry.utils.registry_utils import (
Expand Down Expand Up @@ -97,6 +98,7 @@
'download_from_hf_hub',
'download_from_oras',
'maybe_create_mosaicml_logger',
'no_override_excepthook',
'find_mosaicml_logger',
'log_eval_analytics',
'log_train_analytics',
Expand Down
11 changes: 11 additions & 0 deletions llmfoundry/utils/mosaicml_logger_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,17 @@ def maybe_create_mosaicml_logger() -> Optional[MosaicMLLogger]:
return MosaicMLLogger()


def no_override_excepthook() -> bool:
    """Return True when the ``OVERRIDE_EXCEPTHOOK`` flag is off.

    The flag is read from the ``OVERRIDE_EXCEPTHOOK`` environment variable and
    is considered on only when its value equals ``'true'`` (compared
    case-insensitively). An unset variable defaults to ``'false'``.

    A True result means we are not automatically catching exceptions for
    MosaicML runs.

    Returns:
        bool: ``True`` if the flag is unset or set to anything other than
            ``'true'``; ``False`` otherwise.
    """
    flag_value = os.environ.get('OVERRIDE_EXCEPTHOOK', 'false')
    override_enabled = flag_value.lower() == 'true'
    return not override_enabled


def find_mosaicml_logger(
loggers: List[LoggerDestination],
) -> Optional[MosaicMLLogger]:
Expand Down
7 changes: 5 additions & 2 deletions scripts/data_prep/convert_delta_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@
from pyspark.sql.dataframe import DataFrame as SparkDataFrame
from pyspark.sql.types import Row

from llmfoundry.utils import maybe_create_mosaicml_logger
from llmfoundry.utils import (
maybe_create_mosaicml_logger,
no_override_excepthook,
)
from llmfoundry.utils.exceptions import (
ClusterDoesNotExistError,
FailedToConnectToDatabricksError,
Expand Down Expand Up @@ -676,6 +679,6 @@ def fetch_DT(args: Namespace) -> None:
log.info(f'Elapsed time {time.time() - tik}')

except Exception as e:
if mosaicml_logger is not None:
if mosaicml_logger is not None and no_override_excepthook():
mosaicml_logger.log_exception(e)
raise e
7 changes: 5 additions & 2 deletions scripts/data_prep/convert_text_to_mds.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from llmfoundry.data.data import AbstractConcatTokensDataset
from llmfoundry.utils import maybe_create_mosaicml_logger
from llmfoundry.utils import (
maybe_create_mosaicml_logger,
no_override_excepthook,
)
from llmfoundry.utils.data_prep_utils import (
DownloadingIterable,
download_file,
Expand Down Expand Up @@ -624,6 +627,6 @@ def _configure_logging(logging_level: str):
args_str=_args_str(args),
)
except Exception as e:
if mosaicml_logger is not None:
if mosaicml_logger is not None and no_override_excepthook():
mosaicml_logger.log_exception(e)
raise e
13 changes: 7 additions & 6 deletions scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
EvalDataLoaderLocation,
TrainDataLoaderLocation,
)
from llmfoundry.utils.mosaicml_logger_utils import no_override_excepthook
from llmfoundry.utils.registry_utils import import_file

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -384,8 +385,8 @@ def main(cfg: DictConfig) -> Trainer:
train_cfg.device_train_batch_size,
)
except BaseContextualError as e:
if mosaicml_logger is not None:
e.location = TrainDataLoaderLocation
e.location = TrainDataLoaderLocation
if mosaicml_logger is not None and no_override_excepthook():
mosaicml_logger.log_exception(e)
raise e

Expand Down Expand Up @@ -417,8 +418,8 @@ def main(cfg: DictConfig) -> Trainer:
if eval_gauntlet_callback is not None:
callbacks.append(eval_gauntlet_callback)
except BaseContextualError as e:
if mosaicml_logger is not None:
e.location = EvalDataLoaderLocation
e.location = EvalDataLoaderLocation
if mosaicml_logger is not None and no_override_excepthook():
mosaicml_logger.log_exception(e)
raise e

Expand Down Expand Up @@ -467,8 +468,8 @@ def main(cfg: DictConfig) -> Trainer:
non_icl_metrics,
)
except BaseContextualError as e:
if mosaicml_logger is not None:
e.location = EvalDataLoaderLocation
e.location = EvalDataLoaderLocation
if mosaicml_logger is not None and no_override_excepthook():
mosaicml_logger.log_exception(e)
raise e

Expand Down

0 comments on commit 92ae4ee

Please sign in to comment.