From 2a773b6536e1be1e7eee07ff783193008d727e74 Mon Sep 17 00:00:00 2001
From: jjanezhang
Date: Wed, 19 Jun 2024 16:58:36 -0700
Subject: [PATCH 01/10] removed logging exception

---
 scripts/train/train.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/scripts/train/train.py b/scripts/train/train.py
index f2a70b526d..3799a3eb7f 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -394,9 +394,7 @@ def main(cfg: DictConfig) -> Trainer:
             train_cfg.device_train_batch_size,
         )
     except BaseContextualError as e:
-        if mosaicml_logger is not None:
-            e.location = TrainDataLoaderLocation
-            mosaicml_logger.log_exception(e)
+        e.location = TrainDataLoaderLocation
         raise e
 
     if mosaicml_logger is not None:
@@ -427,9 +425,7 @@ def main(cfg: DictConfig) -> Trainer:
         if eval_gauntlet_callback is not None:
            callbacks.append(eval_gauntlet_callback)
     except BaseContextualError as e:
-        if mosaicml_logger is not None:
-            e.location = EvalDataLoaderLocation
-            mosaicml_logger.log_exception(e)
+        e.location = EvalDataLoaderLocation
         raise e
 
     if mosaicml_logger is not None:

From 678a8060d4e58b88c0b09ce21448aa5288572dcf Mon Sep 17 00:00:00 2001
From: jjanezhang
Date: Wed, 3 Jul 2024 15:07:12 -0700
Subject: [PATCH 02/10] merged main

---
 .../callbacks/curriculum_learning_callback.py | 150 +++++++++++++++++-
 scripts/data_prep/convert_delta_to_json.py    |   6 +-
 scripts/data_prep/convert_text_to_mds.py      |  37 ++---
 scripts/train/train.py                        |   4 +-
 4 files changed, 164 insertions(+), 33 deletions(-)

diff --git a/llmfoundry/callbacks/curriculum_learning_callback.py b/llmfoundry/callbacks/curriculum_learning_callback.py
index 961bf1cae1..e0b1efa99a 100644
--- a/llmfoundry/callbacks/curriculum_learning_callback.py
+++ b/llmfoundry/callbacks/curriculum_learning_callback.py
@@ -16,7 +16,10 @@
 from torch.utils.data import DataLoader
 
 from llmfoundry.interfaces import CallbackWithConfig
-from llmfoundry.utils.warnings import experimental_class
+from llmfoundry.utils.exceptions import (
+    BaseContextualError,
+    TrainDataLoaderLocation,
+)
 
 log = logging.getLogger(__name__)
 
@@ -45,9 +48,148 @@ def __init__(self, train_config: Dict, dataset_index: int):
     def before_load(self, state: State, logger: Logger):
         del logger
 
-        # Save the current dataset state so we can restore it correctly
-        # if we are resuming with a new dataset.
-        train_loader = state.train_dataloader
+        # Ensure all duration units are the same as max_duration
+        datamix_units = [datamix['duration'].unit for datamix in self._schedule]
+        assert state.max_duration is not None, 'max_duration should have been set.'
+        if any(state.max_duration.unit != unit for unit in datamix_units):
+            raise ValueError((
+                f'All durations in the schedule must have the same units as '
+                f'the max_duration. Expected {state.max_duration.unit}, but '
+                f'got {datamix_units}.'
+            ))
+
+        # Ensure schedule duration is equal to max_duration
+        schedule_duration = Time(0, state.max_duration.unit)
+        for datamix in self._schedule:
+            assert isinstance(datamix['duration'], Time)
+            schedule_duration += datamix['duration']
+        if schedule_duration != state.max_duration:
+            raise ValueError((
+                'The sum of all durations in the schedule must be equal to the '
+                'max_duration.'
+            ))
+
+        self._validate_dataloader(state.train_dataloader)
+
+    def after_load(self, state: State, logger: Logger):
+        del logger  # unused
+
+        self._validate_dataloader(state.train_dataloader)
+
+        # If checkpoint was saved before iteration was incremented, we need to increment it now
+        if ((
+            self._schedule[self._schedule_index]['duration'].unit
+            == TimeUnit.TOKEN and state.timestamp.token_in_iteration
+            >= self._schedule[self._schedule_index]['duration'].value
+        ) or (
+            self._schedule[self._schedule_index]['duration'].unit
+            == TimeUnit.EPOCH and state.timestamp.epoch_in_iteration
+            >= self._schedule[self._schedule_index]['duration'].value
+        )):
+            log.warning((
+                'The CurriculumLearning callback has detected that the previous run did not correctly '
+                'increment the iteration.'
+            ))
+            self._schedule_index += 1
+            state.timestamp = state.timestamp.to_next_iteration()
+
+    def iteration_start(self, state: State, logger: Logger):
+        # Swap the dataset if starting a new iteration that's not the original datamix
+        if self._schedule_index > 0:
+            # TODO: trainer._train_data_spec should be updated whenever the dataloader is updated
+            # Dataloaders with the same prefix access the same shared memory
+            # which is stale
+            clean_stale_shared_memory()
+            datamix = copy.deepcopy(self._schedule[self._schedule_index])
+            data_spec = self._build_train_loader(
+                train_loader_config=datamix['train_loader'],
+                logger=logger,
+            )
+            state.set_dataloader(
+                dataloader=data_spec.dataloader,
+                dataloader_label='train',
+            )
+            state.train_dataloader = state.dataloader
+            self._validate_dataloader(state.train_dataloader)
+
+        # Set the length of the new iteration
+        state._iteration_length = self._schedule[self._schedule_index
+                                                ]['duration']
+
+    def iteration_end(self, state: State, logger: Logger):
+        del state, logger  # unused
+
+        self._schedule_index += 1
+
+    def state_dict(self):
+        return {
+            'schedule': self._schedule,
+            'schedule_index': self._schedule_index,
+        }
+
+    def load_state_dict(self, state: dict[str, Any]):
+        self._schedule_index = state['schedule_index']
+
+        # Ensure that the schedule has not changed on previously trained datamixes
+        for idx in range(state['schedule_index']):
+            if self._schedule[idx] != state['schedule'][idx]:
+                raise ValueError((
+                    f'Previous datamixes must stay the same across ',
+                    f'resumptions. Expected {state["schedule"][idx]} but got ',
+                    f'{self._schedule[idx]}',
+                ))
+
+        # Ensure that the datamix has not changed on the current datamix
+        current_loader = self._schedule[self._schedule_index]['train_loader']
+        saved_loader = state['schedule'][self._schedule_index]['train_loader']
+        if current_loader != saved_loader:
+            raise ValueError((
+                f'The current datamix must stay the same across resumptions. ',
+                f'Expected {saved_loader} but got {current_loader}',
+            ))
+
+        # Ensure that the current datamix duration is greater than timestamp
+        duration = self._schedule[self._schedule_index]['duration']
+        if duration.unit != TimeUnit.TOKEN and duration.unit != TimeUnit.EPOCH:
+            raise ValueError((
+                f'Duration must be in terms of tokens or epochs, but got ',
+                f'{duration.unit}.',
+            ))
+        if ((
+            duration.unit == TimeUnit.TOKEN and
+            duration > state['timestamp'].token_in_iteration
+        ) or (
+            duration.unit == TimeUnit.EPOCH and
+            duration > state['timestamp'].epoch_in_iteration
+        )):
+            raise ValueError((
+                'The duration of the current datamix must be less than or '
+                'equal to the saved timestamp.'
+            ))
+
+    def _build_train_loader(
+        self,
+        train_loader_config: dict[str, Any],
+        logger: Logger,
+    ) -> DataSpec:
+        from llmfoundry.data.dataloader import build_dataloader
+
+        # Copied from scripts/train/train.py
+        log.info(
+            f'Building train loader in CurriculumLearning callback for dataset {self._schedule_index}',
+        )
+        assert self.tokenizer is not None
+        try:
+            return build_dataloader(
+                train_loader_config,
+                self.tokenizer,
+                self.device_train_batch_size,
+            )
+        except BaseContextualError as e:
+            e.location = TrainDataLoaderLocation
+            raise e
+
+    def _validate_dataloader(self, train_loader: Any):
         # Check if we are using a DataLoader and StreamingDataset
         if not isinstance(train_loader, DataLoader):
             raise ValueError(
diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py
index f63f1b0027..b2892d521a 100644
--- a/scripts/data_prep/convert_delta_to_json.py
+++ b/scripts/data_prep/convert_delta_to_json.py
@@ -33,7 +33,9 @@
 from pyspark.sql.dataframe import DataFrame as SparkDataFrame
 from pyspark.sql.types import Row
 
-from llmfoundry.utils import maybe_create_mosaicml_logger
+from llmfoundry.utils import (
+    maybe_create_mosaicml_logger,
+)
 from llmfoundry.utils.exceptions import (
     ClusterDoesNotExistError,
     FailedToConnectToDatabricksError,
@@ -644,6 +646,4 @@ def fetch_DT(args: Namespace) -> None:
         log.info(f'Elapsed time {time.time() - tik}')
 
     except Exception as e:
-        if mosaicml_logger is not None:
-            mosaicml_logger.log_exception(e)
         raise e
diff --git a/scripts/data_prep/convert_text_to_mds.py b/scripts/data_prep/convert_text_to_mds.py
index b2f0b0e7b4..92c36eb35d 100644
--- a/scripts/data_prep/convert_text_to_mds.py
+++ b/scripts/data_prep/convert_text_to_mds.py
@@ -24,7 +24,6 @@
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
 
 from llmfoundry.data.data import AbstractConcatTokensDataset
-from llmfoundry.utils import maybe_create_mosaicml_logger
 from llmfoundry.utils.data_prep_utils import (
     DownloadingIterable,
     download_file,
@@ -605,25 +604,17 @@ def _configure_logging(logging_level: str):
 if __name__ == '__main__':
     args = parse_args()
     _configure_logging(args.logging_level)
-
-    mosaicml_logger = maybe_create_mosaicml_logger()
-
-    try:
-        convert_text_to_mds(
-            tokenizer_name=args.tokenizer,
-            output_folder=args.output_folder,
-            input_folder=args.input_folder,
-            concat_tokens=args.concat_tokens,
-            eos_text=args.eos_text,
-            bos_text=args.bos_text,
-            no_wrap=args.no_wrap,
-            compression=args.compression,
-            processes=args.processes,
-            reprocess=args.reprocess,
-            trust_remote_code=args.trust_remote_code,
-            args_str=_args_str(args),
-        )
-    except Exception as e:
-        if mosaicml_logger is not None:
-            mosaicml_logger.log_exception(e)
-        raise e
+    convert_text_to_mds(
+        tokenizer_name=args.tokenizer,
+        output_folder=args.output_folder,
+        input_folder=args.input_folder,
+        concat_tokens=args.concat_tokens,
+        eos_text=args.eos_text,
+        bos_text=args.bos_text,
+        no_wrap=args.no_wrap,
+        compression=args.compression,
+        processes=args.processes,
+        reprocess=args.reprocess,
+        trust_remote_code=args.trust_remote_code,
+        args_str=_args_str(args),
+    )
diff --git a/scripts/train/train.py b/scripts/train/train.py
index 3799a3eb7f..f77abd8a7d 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -473,9 +473,7 @@ def main(cfg: DictConfig) -> Trainer:
             non_icl_metrics,
         )
     except BaseContextualError as e:
-        if mosaicml_logger is not None:
-            e.location = EvalDataLoaderLocation
-            mosaicml_logger.log_exception(e)
+        e.location = EvalDataLoaderLocation
         raise e
 
     compile_config = train_cfg.compile_config

From 04bff5c22a5a2925bf432f24d346e0880c4c8024 Mon Sep 17 00:00:00 2001
From: jjanezhang
Date: Mon, 8 Jul 2024 13:36:01 -0700
Subject: [PATCH 03/10] format files

---
 llmfoundry/callbacks/curriculum_learning_callback.py | 8 ++++----
 llmfoundry/callbacks/run_timeout_callback.py         | 4 +---
 scripts/data_prep/convert_delta_to_json.py           | 5 ++---
 3 files changed, 7 insertions(+), 10 deletions(-)

diff --git a/llmfoundry/callbacks/curriculum_learning_callback.py b/llmfoundry/callbacks/curriculum_learning_callback.py
index 7988462d14..98a672f8db 100644
--- a/llmfoundry/callbacks/curriculum_learning_callback.py
+++ b/llmfoundry/callbacks/curriculum_learning_callback.py
@@ -130,12 +130,12 @@ def after_load(self, state: State, logger: Logger):
         # If checkpoint was saved before iteration was incremented, we need to increment it now
         if ((
             self._schedule[self._schedule_index]['duration'].unit
-            == TimeUnit.TOKEN and state.timestamp.token_in_iteration
-            >= self._schedule[self._schedule_index]['duration'].value
+            == TimeUnit.TOKEN and state.timestamp.token_in_iteration >=
+            self._schedule[self._schedule_index]['duration'].value
         ) or (
             self._schedule[self._schedule_index]['duration'].unit
-            == TimeUnit.EPOCH and state.timestamp.epoch_in_iteration
-            >= self._schedule[self._schedule_index]['duration'].value
+            == TimeUnit.EPOCH and state.timestamp.epoch_in_iteration >=
+            self._schedule[self._schedule_index]['duration'].value
         )):
             log.warning((
                 'The CurriculumLearning callback has detected that the previous run did not correctly '
diff --git a/llmfoundry/callbacks/run_timeout_callback.py b/llmfoundry/callbacks/run_timeout_callback.py
index f240efe0de..2f64f6f931 100644
--- a/llmfoundry/callbacks/run_timeout_callback.py
+++ b/llmfoundry/callbacks/run_timeout_callback.py
@@ -15,9 +15,7 @@
 
 
 def _timeout(timeout: int):
-    log.error(
-        f'Timeout after {timeout} seconds of inactivity after fit_end.',
-    )
+    log.error(f'Timeout after {timeout} seconds of inactivity after fit_end.',)
     try:
         raise RunTimeoutError(timeout=timeout)
     except RunTimeoutError:
diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py
index 7fc43f0dba..e1fb6f9fe1 100644
--- a/scripts/data_prep/convert_delta_to_json.py
+++ b/scripts/data_prep/convert_delta_to_json.py
@@ -28,9 +28,8 @@
 from packaging import version
 from pyspark.sql import SparkSession
 from pyspark.sql.connect.client.core import SparkConnectClient
-from pyspark.sql.connect.client.reattach import (
-    ExecutePlanResponseReattachableIterator,
-)
+from pyspark.sql.connect.client.reattach import \
+    ExecutePlanResponseReattachableIterator
 from pyspark.sql.connect.dataframe import DataFrame
 from pyspark.sql.dataframe import DataFrame as SparkDataFrame
 from pyspark.sql.types import Row

From a13e70855d59b6de10864f641a6f2fe51a210153 Mon Sep 17 00:00:00 2001
From: jjanezhang
Date: Mon, 8 Jul 2024 13:47:39 -0700
Subject: [PATCH 04/10] removed unused import

---
 scripts/train/train.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/scripts/train/train.py b/scripts/train/train.py
index 41daa966f6..668e20dbab 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -56,7 +56,6 @@
     EvalDataLoaderLocation,
     TrainDataLoaderLocation,
 )
-from llmfoundry.utils.mosaicml_logger_utils import no_override_excepthook
 from llmfoundry.utils.registry_utils import import_file
 
 log = logging.getLogger(__name__)

From 2ce048324c03cbb15d44a70107ba53fbf33203e4 Mon Sep 17 00:00:00 2001
From: jjanezhang
Date: Mon, 8 Jul 2024 15:10:26 -0700
Subject: [PATCH 05/10] put os kill in finally

---
 llmfoundry/callbacks/run_timeout_callback.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llmfoundry/callbacks/run_timeout_callback.py b/llmfoundry/callbacks/run_timeout_callback.py
index 2f64f6f931..a050e6b1a8 100644
--- a/llmfoundry/callbacks/run_timeout_callback.py
+++ b/llmfoundry/callbacks/run_timeout_callback.py
@@ -18,7 +18,7 @@ def _timeout(timeout: int):
     log.error(f'Timeout after {timeout} seconds of inactivity after fit_end.',)
     try:
         raise RunTimeoutError(timeout=timeout)
-    except RunTimeoutError:
+    finally:
         os.kill(os.getpid(), signal.SIGINT)
 
 

From 729fe967ccdc1b8e4b3fec384bb70e2801697b42 Mon Sep 17 00:00:00 2001
From: jjanezhang
Date: Mon, 8 Jul 2024 16:29:09 -0700
Subject: [PATCH 06/10] moving timeout earlier for testing

---
 llmfoundry/callbacks/run_timeout_callback.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/llmfoundry/callbacks/run_timeout_callback.py b/llmfoundry/callbacks/run_timeout_callback.py
index a050e6b1a8..28f04adea5 100644
--- a/llmfoundry/callbacks/run_timeout_callback.py
+++ b/llmfoundry/callbacks/run_timeout_callback.py
@@ -15,7 +15,9 @@
 
 
 def _timeout(timeout: int):
-    log.error(f'Timeout after {timeout} seconds of inactivity after fit_end.',)
+    log.error(
+        f'Timeout after {timeout} seconds of inactivity after fit_end.',
+    )
     try:
         raise RunTimeoutError(timeout=timeout)
     finally:
@@ -30,6 +32,7 @@ def __init__(
     ):
         self.timeout = timeout
         self.timer: Optional[threading.Timer] = None
+        self._timeout()
 
     def _reset(self):
         if self.timer is not None:

From e13e50e066c88cb079f1b90d8a442e203edea44b Mon Sep 17 00:00:00 2001
From: jjanezhang
Date: Mon, 8 Jul 2024 17:00:52 -0700
Subject: [PATCH 07/10] test timeout without try catch

---
 llmfoundry/callbacks/run_timeout_callback.py |  8 ++++----
 llmfoundry/utils/__init__.py                 |  2 --
 llmfoundry/utils/mosaicml_logger_utils.py    | 11 -----------
 3 files changed, 4 insertions(+), 17 deletions(-)

diff --git a/llmfoundry/callbacks/run_timeout_callback.py b/llmfoundry/callbacks/run_timeout_callback.py
index 28f04adea5..193be74cd0 100644
--- a/llmfoundry/callbacks/run_timeout_callback.py
+++ b/llmfoundry/callbacks/run_timeout_callback.py
@@ -18,10 +18,10 @@ def _timeout(timeout: int):
     log.error(
         f'Timeout after {timeout} seconds of inactivity after fit_end.',
     )
-    try:
-        raise RunTimeoutError(timeout=timeout)
-    finally:
-        os.kill(os.getpid(), signal.SIGINT)
+    # try:
+    raise RunTimeoutError(timeout=timeout)
+    # finally:
+    # os.kill(os.getpid(), signal.SIGINT)
 
 
 class RunTimeoutCallback(Callback):
diff --git a/llmfoundry/utils/__init__.py b/llmfoundry/utils/__init__.py
index 3768fef843..87a08a999d 100644
--- a/llmfoundry/utils/__init__.py
+++ b/llmfoundry/utils/__init__.py
@@ -46,7 +46,6 @@
     log_eval_analytics,
     log_train_analytics,
     maybe_create_mosaicml_logger,
-    no_override_excepthook,
 )
 from llmfoundry.utils.prompt_files import load_prompts, load_prompts_from_file
 from llmfoundry.utils.registry_utils import (
@@ -98,7 +97,6 @@
     'download_from_hf_hub',
     'download_from_oras',
     'maybe_create_mosaicml_logger',
-    'no_override_excepthook',
     'find_mosaicml_logger',
     'log_eval_analytics',
     'log_train_analytics',
diff --git a/llmfoundry/utils/mosaicml_logger_utils.py b/llmfoundry/utils/mosaicml_logger_utils.py
index e3f88e40dc..b01170ff0f 100644
--- a/llmfoundry/utils/mosaicml_logger_utils.py
+++ b/llmfoundry/utils/mosaicml_logger_utils.py
@@ -37,17 +37,6 @@ def maybe_create_mosaicml_logger() -> Optional[MosaicMLLogger]:
         return MosaicMLLogger()
 
 
-def no_override_excepthook() -> bool:
-    """Returns True if the excepthook flag is off.
-
-    This means we are not automatically catching exceptions for MosaicMl runs.
-    """
-    return os.environ.get(
-        'OVERRIDE_EXCEPTHOOK',
-        'false',
-    ).lower() != 'true'
-
-
 def find_mosaicml_logger(
     loggers: List[LoggerDestination],
 ) -> Optional[MosaicMLLogger]:

From f223f55202c53e055f552e41f51ed334dbffb853 Mon Sep 17 00:00:00 2001
From: jjanezhang
Date: Mon, 8 Jul 2024 17:14:30 -0700
Subject: [PATCH 08/10] added back try finally

---
 llmfoundry/callbacks/run_timeout_callback.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/llmfoundry/callbacks/run_timeout_callback.py b/llmfoundry/callbacks/run_timeout_callback.py
index 193be74cd0..28f04adea5 100644
--- a/llmfoundry/callbacks/run_timeout_callback.py
+++ b/llmfoundry/callbacks/run_timeout_callback.py
@@ -18,10 +18,10 @@ def _timeout(timeout: int):
     log.error(
         f'Timeout after {timeout} seconds of inactivity after fit_end.',
     )
-    # try:
-    raise RunTimeoutError(timeout=timeout)
-    # finally:
-    # os.kill(os.getpid(), signal.SIGINT)
+    try:
+        raise RunTimeoutError(timeout=timeout)
+    finally:
+        os.kill(os.getpid(), signal.SIGINT)
 
 
 class RunTimeoutCallback(Callback):

From 0ccc918d27fcb0d9583c70256e4c01f5c4788404 Mon Sep 17 00:00:00 2001
From: jjanezhang
Date: Mon, 8 Jul 2024 17:19:49 -0700
Subject: [PATCH 09/10] removing runtime error

---
 llmfoundry/callbacks/run_timeout_callback.py | 11 ++---------
 1 file changed, 2 insertions(+), 9 deletions(-)

diff --git a/llmfoundry/callbacks/run_timeout_callback.py b/llmfoundry/callbacks/run_timeout_callback.py
index 28f04adea5..144c67c220 100644
--- a/llmfoundry/callbacks/run_timeout_callback.py
+++ b/llmfoundry/callbacks/run_timeout_callback.py
@@ -9,19 +9,12 @@
 
 from composer import Callback, Logger, State
 
-from llmfoundry.utils.exceptions import RunTimeoutError
-
 log = logging.getLogger(__name__)
 
 
 def _timeout(timeout: int):
-    log.error(
-        f'Timeout after {timeout} seconds of inactivity after fit_end.',
-    )
-    try:
-        raise RunTimeoutError(timeout=timeout)
-    finally:
-        os.kill(os.getpid(), signal.SIGINT)
+    log.error(f'Timeout after {timeout} seconds of inactivity after fit_end.',)
+    os.kill(os.getpid(), signal.SIGINT)
 
 
 class RunTimeoutCallback(Callback):

From f506c37443a880b1f562bbd7f0c57e1a5f1069cf Mon Sep 17 00:00:00 2001
From: jjanezhang
Date: Mon, 8 Jul 2024 17:21:06 -0700
Subject: [PATCH 10/10] removed test init

---
 llmfoundry/callbacks/run_timeout_callback.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/llmfoundry/callbacks/run_timeout_callback.py b/llmfoundry/callbacks/run_timeout_callback.py
index 144c67c220..4c791f37d7 100644
--- a/llmfoundry/callbacks/run_timeout_callback.py
+++ b/llmfoundry/callbacks/run_timeout_callback.py
@@ -25,7 +25,6 @@ def __init__(
     ):
         self.timeout = timeout
         self.timer: Optional[threading.Timer] = None
-        self._timeout()
 
     def _reset(self):
         if self.timer is not None: