From 2765769140fc68b3999b161cb7ae764f5f598e3b Mon Sep 17 00:00:00 2001 From: root <23239305+b-chu@users.noreply.github.com> Date: Wed, 5 Jun 2024 21:28:42 +0000 Subject: [PATCH] Add curriculum learning callback --- .../callbacks/curriculum_learning_callback.py | 246 +++++++++++++----- llmfoundry/data/text_data.py | 4 +- llmfoundry/models/hf/hf_causal_lm.py | 2 +- llmfoundry/models/hf/hf_t5.py | 2 +- .../models/inference_api_wrapper/interface.py | 2 +- llmfoundry/models/mpt/modeling_mpt.py | 11 +- llmfoundry/utils/__init__.py | 8 +- llmfoundry/utils/builder_tokenizer.py | 57 ++++ llmfoundry/utils/builders.py | 56 +--- llmfoundry/utils/config_dist_utils.py | 81 ++++++ llmfoundry/utils/config_utils.py | 79 ------ scripts/data_prep/convert_dataset_hf.py | 2 +- .../data_prep/convert_finetuning_dataset.py | 2 +- scripts/eval/eval.py | 4 +- scripts/inference/benchmarking/benchmark.py | 2 +- scripts/train/train.py | 10 +- tests/a_scripts/eval/test_eval.py | 3 +- .../inference/test_convert_composer_to_hf.py | 7 +- tests/callbacks/test_async_eval_callback.py | 2 +- .../test_curriculum_learning_callback.py | 2 +- .../callbacks/test_eval_gauntlet_callback.py | 2 +- .../test_loss_perp_v_len_callback.py | 2 +- .../test_mbmoe_tok_per_expert_callback.py | 2 +- tests/data/test_dataloader.py | 2 +- tests/data/test_icl_datasets.py | 2 +- tests/data/test_packing.py | 2 +- tests/data/test_template_tokenization.py | 2 +- tests/fixtures/models.py | 2 +- tests/models/hf/test_fsdp_weight_tying.py | 2 +- tests/models/hf/test_hf_config.py | 3 +- tests/models/hf/test_hf_peft_wrapping.py | 2 +- tests/models/hf/test_hf_v_mpt.py | 2 +- .../inference_api_wrapper/test_fmapi.py | 2 +- .../test_inference_api_eval_wrapper.py | 2 +- tests/models/layers/test_huggingface_flash.py | 2 +- tests/models/test_model.py | 3 +- tests/utils/test_builders.py | 2 +- 37 files changed, 372 insertions(+), 246 deletions(-) create mode 100644 llmfoundry/utils/builder_tokenizer.py create mode 100644 llmfoundry/utils/config_dist_utils.py diff --git a/llmfoundry/callbacks/curriculum_learning_callback.py b/llmfoundry/callbacks/curriculum_learning_callback.py index 961bf1cae1..510cbcb27a 100644 --- a/llmfoundry/callbacks/curriculum_learning_callback.py +++ b/llmfoundry/callbacks/curriculum_learning_callback.py @@ -8,22 +8,28 @@ """ import logging -from typing import Any, Dict +import copy +from typing import Any -from composer.core import State -from composer.loggers import Logger +from composer import DataSpec +from composer.core import State, TimeUnit, ensure_time +from composer.loggers import Logger, MosaicMLLogger +from composer.trainer.trainer import _get_initial_device_train_microbatch_size from streaming import StreamingDataset +from streaming.base.util import clean_stale_shared_memory from torch.utils.data import DataLoader +from llmfoundry.data.dataloader import build_dataloader from llmfoundry.interfaces import CallbackWithConfig -from llmfoundry.utils.warnings import experimental_class +from llmfoundry.utils.builder_tokenizer import build_tokenizer +from llmfoundry.utils.config_utils import calculate_batch_size_info +from llmfoundry.utils.exceptions import BaseContextualError, TrainDataLoaderLocation log = logging.getLogger(__name__) __all__ = ['CurriculumLearning'] -@experimental_class('CurriculumLearning callback') class CurriculumLearning(CallbackWithConfig): """Starts an epoch with a different dataset when resuming from a checkpoint. 
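A minimal sketch of how the reworked callback is constructed. None of the values below come from this patch: the batch sizes, tokenizer, durations, and dataset paths are illustrative, and in practice the callback is typically built through the callback registry (build_callback) with train_config supplied by the trainer setup rather than instantiated directly.

from llmfoundry.callbacks.curriculum_learning_callback import CurriculumLearning

# All values below are hypothetical. Durations must be positive and expressed
# in epochs ('ep') or tokens ('tok'), and every schedule entry needs both a
# 'duration' and a 'train_loader' config.
train_config = {
    'global_train_batch_size': 256,
    'device_train_microbatch_size': 4,
    'tokenizer': {
        'name': 'EleutherAI/gpt-neox-20b',
        'kwargs': {'model_max_length': 2048},
    },
}
phase2_loader = {
    'name': 'text',
    'dataset': {
        'remote': 's3://my-bucket/phase2',  # placeholder path
        'split': 'train',
        'shuffle': True,
        'max_seq_len': 2048,
    },
    'drop_last': True,
    'num_workers': 8,
}
callback = CurriculumLearning(
    train_config=train_config,
    duration='20ep',  # how long the original datamix (the main train_loader) runs
    schedule=[{'duration': '10ep', 'train_loader': phase2_loader}],
)
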
@@ -34,20 +40,179 @@ class CurriculumLearning(CallbackWithConfig): dataset_index (int): The index of the dataset currently being used. """ - def __init__(self, train_config: Dict, dataset_index: int): - self.dataset_index = dataset_index - self.saved_dataset_index = 0 - self.all_dataset_configs = [] - self.current_dataset_state = {} - # The current dataset config is resolved and passed in train.py - self.current_dataset_config = train_config['train_loader'] + def __init__( + self, train_config: dict[str, Any], duration: str | int | TimeUnit, + schedule: list[dict[str, Any]] + ): + non_positive_error = ValueError('The duration must be positive.') + unit_error = ValueError( + 'Schedules can only be defined in terms of epochs or tokens.' + ) + + # Ensure all duration values are positive + # Ensure all duration units are in epochs or tokens + self._duration = ensure_time(duration, TimeUnit.EPOCH) + if self._duration.value <= 0: + raise non_positive_error + if self._duration.unit != TimeUnit.EPOCH and self._duration.unit != TimeUnit.TOKEN: + raise unit_error + + self._schedule = schedule + for datamix in self._schedule: + assert 'duration' in datamix, 'Each datamix must have a duration.' + datamix['duration'] = ensure_time( + datamix['duration'], TimeUnit.EPOCH + ) + if datamix['duration'].value <= 0: + raise non_positive_error + if datamix['duration'].unit != TimeUnit.EPOCH and datamix[ + 'duration'].unit != TimeUnit.TOKEN: + raise unit_error + assert 'train_loader' in datamix, 'Each datamix must have a train_loader.' + + self._schedule_index = -1 + + # Copied from llmfoundry/utils/config_utils.py + self.device_train_batch_size, _, _ = calculate_batch_size_info( + train_config['global_train_batch_size'], + train_config['device_train_microbatch_size'], + data_replication_degree=1, + ) + + # Copied from scripts/train/train.py + tokenizer_name = train_config['tokenizer']['name'] + tokenizer_kwargs = train_config['tokenizer'].get('kwargs', {}) + self.tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) def before_load(self, state: State, logger: Logger): del logger - # Save the current dataset state so we can restore it correctly - # if we are resuming with a new dataset. - train_loader = state.train_dataloader + # Ensure all duration units are the same as max_duration + units_match = True + if self._duration.unit != state.max_duration.unit: + units_match = False + for datamix in self._schedule: + if datamix['duration'].unit != state.max_duration.unit: + units_match = False + if not units_match: + raise ValueError(( + 'All durations in the schedule must have the same units as ' + 'the max_duration.' + )) + + # Ensure schedule duration is greater than max_duration + schedule_duration = self._duration + for datamix in self._schedule: + schedule_duration += datamix['duration'] + if schedule_duration < state.max_duration: + raise ValueError(( + 'The sum of all durations in the schedule must be greater than ' + 'or equal to the max_duration.' 
+ )) + + self._validate_dataloader(state.train_dataloader) + + def after_load(self, state: State, logger: Logger): + del logger + + self._validate_dataloader(state.train_dataloader) + + # Check if adding a new datamix to a run that didn't use this callback + if self._schedule_index == -1 and state.timestamp >= self._duration: + self._schedule_index = 0 + state.timestamp = state.timestamp.to_next_iteration() + # If checkpoint was saved before iteration was incremented, we need to increment it now + elif (( + self._schedule[self._schedule_index]['duration'].unit + == TimeUnit.TOKEN and state.timestamp.token_in_iteration + >= self._schedule[self._schedule_index]['duration'].value + ) or ( + self._schedule[self._schedule_index]['duration'].unit + == TimeUnit.EPOCH and state.timestamp.epoch_in_iteration + >= self._schedule[self._schedule_index]['duration'].value + )): + log.warning(( + 'The CurriculumLearning callback has detected that the previous run did not correctly ' + 'increment the iteration.' + )) + self._schedule_index += 1 + state.timestamp = state.timestamp.to_next_iteration() + + def iteration_start(self, state: State, logger: Logger): + # Reset and initialize state train dataloader + log.warning( + 'trainer._train_data_spec should be updated whenever the dataloader is updated' + ) + + # Swap the dataset if starting a new iteration that's not the original datamix + if self._schedule_index >= 0: + clean_stale_shared_memory() + data_spec = self._build_train_loader( + train_loader_config=copy.deepcopy( + self._schedule[self._schedule_index] + )['train_loader'], + logger=logger, + ) + state.set_dataloader( + dataloader=data_spec.dataloader, + dataloader_label='train', + ) + # state.train_dataloader = state.dataloader + state.device_train_microbatch_size = _get_initial_device_train_microbatch_size( + state.device_train_microbatch_size, + state.auto_microbatching, + state.train_dataloader, + ) + self._validate_dataloader(state.train_dataloader) + + # Set the length of the new iteration + if self._schedule_index == -1: + state._iteration_length = self._duration + else: + state._iteration_length = self._schedule[self._schedule_index + ]['duration'] + + def iteration_end(self, state: State, logger: Logger): + del state, logger # unused + + self._schedule_index += 1 + + def state_dict(self): + return { + 'duration': self._duration, + 'schedule': self._schedule, + 'schedule_index': self._schedule_index, + } + + def load_state_dict(self, state: dict[str, Any]): + # Ensure that the schedule has not changed on already trained datamixes + assert self._duration == state['duration'] + for idx in range(state['schedule_index'] + 1): + assert self._schedule[idx] == state['schedule'][idx] + + self._schedule_index = state['schedule_index'] + + def _build_train_loader( + self, train_loader_config: dict[str, Any], logger: Logger + ) -> DataSpec: + # Copied from scripts/train/train.py + log.info( + f'Building train loader in CurriculumLearning callback for dataset {self._schedule_index}' + ) + try: + return build_dataloader( + train_loader_config, + self.tokenizer, + self.device_train_batch_size, + ) + except BaseContextualError as e: + for destination in logger.destinations: + if isinstance(destination, MosaicMLLogger): + e.location = TrainDataLoaderLocation + destination.log_exception(e) + raise e + + def _validate_dataloader(self, train_loader: Any): # Check if we are using a DataLoader and StreamingDataset if not isinstance(train_loader, DataLoader): raise ValueError( @@ -61,54 +226,3 @@ def 
before_load(self, state: State, logger: Logger): f'because it requires loading and saving dataset state. ', f'Instead, got a dataset of type {type(dataset)}', ) - assert isinstance(dataset, StreamingDataset) - # Save the current dataset state so we can restore it if needed. - self.current_dataset_state = dataset.state_dict( # type: ignore - num_samples=0, from_beginning=False) - - def after_load(self, state: State, logger: Logger): - del logger - - # As saved_dataset_index is loaded from state_dict, this only runs when - # a user explicitly increments the dataset_index and not on any other - # resumption, including autoresume. - train_loader = state._train_dataloader - assert isinstance( - train_loader, - DataLoader, - ), 'CurriculumLearning callback requires a DataLoader.' - dataset = train_loader.dataset - assert isinstance( - dataset, - StreamingDataset, - ), 'CurriculumLearning callback requires a StreamingDataset.' - if self.saved_dataset_index < self.dataset_index: - # Ignore the dataset state that was read in from the checkpoint, and - # replace with the new dataset state. This preserves resumption info. - if self.current_dataset_state['epoch'] < 0: - # Make sure the epoch in the loaded state dict is not negative. - # Since `__iter__` has not yet been called on the dataset, the - # epoch index in the dataset will still be -1. We need to ensure - # that we set the epoch correctly to 0 in this case. - self.current_dataset_state['epoch'] = 0 - dataset.load_state_dict( # type: ignore - self.current_dataset_state) - # Start a new epoch since we are using a new dataset. - # This will also reset the sample_in_epoch written to checkpoint, - # making sure that subsequent resumptions proceed correctly. - state.timestamp = state.timestamp.to_next_epoch() - # Append the new dataset config to the list of all dataset configs. - self.all_dataset_configs.append(self.current_dataset_config) - elif self.dataset_index == 0 and len(self.all_dataset_configs) == 0: - # Make sure to track our current dataset config if we are just starting training. - self.all_dataset_configs.append(self.current_dataset_config) - - def state_dict(self): - return { - 'dataset_index': self.dataset_index, - 'all_dataset_configs': self.all_dataset_configs, - } - - def load_state_dict(self, state: Dict[str, Any]): - self.saved_dataset_index = state.get('dataset_index', 0) - self.all_dataset_configs = state.get('all_dataset_configs', []) diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py index 60b81cd145..e3b3529547 100644 --- a/llmfoundry/data/text_data.py +++ b/llmfoundry/data/text_data.py @@ -266,7 +266,7 @@ def get_sequence_id_from_batch( return torch.cat([left_zeros, cumulative_sep[:, :-1]], dim=1) -def build_streams(streams: Optional[Dict[str, Any]] = None,): +def build_streams(streams: Optional[Dict[str, Any]] = None): streams_dict = streams # build streams streams_ret = [] @@ -374,7 +374,7 @@ def build_text_dataloader( if __name__ == '__main__': import argparse - from llmfoundry.utils.builders import build_tokenizer + from llmfoundry.utils import build_tokenizer parser = argparse.ArgumentParser() parser.add_argument( diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py index 5f3a53ed18..0e0c400813 100644 --- a/llmfoundry/models/hf/hf_causal_lm.py +++ b/llmfoundry/models/hf/hf_causal_lm.py @@ -149,7 +149,7 @@ def build_metrics( Returns: Tuple[List[Metric], List[Metric]]: A tuple containing the list of training metrics and evaluation metrics. 
""" - from llmfoundry.utils.builders import build_metric + from llmfoundry.utils import build_metric train_metric_names = DEFAULT_CAUSAL_LM_TRAIN_METRICS + ( additional_train_metrics or [] diff --git a/llmfoundry/models/hf/hf_t5.py b/llmfoundry/models/hf/hf_t5.py index f54b7c42ec..71ea7a1c7d 100644 --- a/llmfoundry/models/hf/hf_t5.py +++ b/llmfoundry/models/hf/hf_t5.py @@ -57,7 +57,7 @@ def __init__( additional_train_metrics: Optional[List] = None, name: Optional[str] = None, ): - from llmfoundry.utils.builders import build_metric + from llmfoundry.utils import build_metric config_overrides = config_overrides or {} additional_train_metrics = additional_train_metrics or [] diff --git a/llmfoundry/models/inference_api_wrapper/interface.py b/llmfoundry/models/inference_api_wrapper/interface.py index 6d231441ae..df20a3f772 100644 --- a/llmfoundry/models/inference_api_wrapper/interface.py +++ b/llmfoundry/models/inference_api_wrapper/interface.py @@ -19,7 +19,7 @@ class InferenceAPIEvalWrapper(ComposerModel): def __init__(self, om_model_config: DictConfig, tokenizer: AutoTokenizer): - from llmfoundry.utils.builders import build_metric + from llmfoundry.utils import build_metric self.tokenizer = tokenizer self.labels = None diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 9d18799e93..bb1f5ac7c3 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -580,8 +580,9 @@ def forward( 'sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.', ) - elif (self.attn_uses_sequence_id is - False) and (sequence_id is not None): + elif ( + self.attn_uses_sequence_id is False and sequence_id is not None + ): warnings.warn( 'MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. 
' + @@ -1088,11 +1089,13 @@ def __init__( DEFAULT_CAUSAL_LM_EVAL_METRICS, DEFAULT_CAUSAL_LM_TRAIN_METRICS, ) - from llmfoundry.utils.builders import build_metric + from llmfoundry.utils import build_metric additional_train_metrics = additional_train_metrics or [] - model = self.model_class(self.config_class(**kwargs),) + model = self.model_class( + self.config_class(**kwargs), + ) use_train_metrics = use_train_metrics train_metric_names = DEFAULT_CAUSAL_LM_TRAIN_METRICS + additional_train_metrics diff --git a/llmfoundry/utils/__init__.py b/llmfoundry/utils/__init__.py index dd43efcdd7..39e4d0e261 100644 --- a/llmfoundry/utils/__init__.py +++ b/llmfoundry/utils/__init__.py @@ -2,9 +2,11 @@ # SPDX-License-Identifier: Apache-2.0 from llmfoundry.utils.builders import ( + add_metrics_to_eval_loaders, build_algorithm, build_callback, build_composer_model, + build_eval_loaders, build_evaluators, build_icl_data_and_gauntlet, build_icl_evaluators, @@ -12,18 +14,18 @@ build_metric, build_optimizer, build_scheduler, - build_tokenizer, ) +from llmfoundry.utils.builder_tokenizer import build_tokenizer from llmfoundry.utils.checkpoint_conversion_helpers import ( convert_and_save_ft_weights, get_hf_tokenizer_from_composer_state_dict, load_tokenizer, ) +from llmfoundry.utils.config_dist_utils import process_init_device from llmfoundry.utils.config_utils import ( calculate_batch_size_info, log_config, pop_config, - process_init_device, update_batch_size_info, ) from llmfoundry.utils.data_prep_utils import ( @@ -60,8 +62,10 @@ ) __all__ = [ + 'add_metrics_to_eval_loaders', 'build_algorithm', 'build_callback', + 'build_eval_loaders', 'build_evaluators', 'build_icl_data_and_gauntlet', 'build_icl_evaluators', diff --git a/llmfoundry/utils/builder_tokenizer.py b/llmfoundry/utils/builder_tokenizer.py new file mode 100644 index 0000000000..af0c977f4a --- /dev/null +++ b/llmfoundry/utils/builder_tokenizer.py @@ -0,0 +1,57 @@ +import os +from typing import Any + +from composer.utils import dist +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper + + +def build_tokenizer( + tokenizer_name: str, + tokenizer_kwargs: dict[str, Any], +) -> PreTrainedTokenizerBase: + os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1' + os.environ['TOKENIZERS_PARALLELISM'] = 'false' + + signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed_tokenizer_setup' + + if dist.is_available() and dist.is_initialized() and dist.get_world_size( + ) > 1: + # Make sure the tokenizer files are downloaded and cached first by local rank 0 + with dist.local_rank_zero_download_and_wait(signal_file_path): + pass + + if tokenizer_name.startswith('tiktoken'): + tokenizer = TiktokenTokenizerWrapper(**tokenizer_kwargs) + else: + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, + **tokenizer_kwargs, + ) + + # HuggingFace does not respect the model_max_length kwarg, and overrides it with + # min(kwargs['model_max_length'], original_config['model_max_length']), so we + # explicitly set it here + tokenizer.model_max_length = tokenizer_kwargs.get( + 'model_max_length', + int(1e30), + ) + + if not hasattr(tokenizer, 'eos_token') or tokenizer.eos_token is None: + raise ValueError( + f'The tokenizer {tokenizer_name} must have an eos_token.', + ) + + if dist.is_available() and dist.is_initialized() and dist.get_world_size( + ) > 1: + if dist.get_local_rank() == 0: + with open(signal_file_path, 'wb') as f: + f.write(b'local_rank0_completed_tokenizer_setup') 
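+            # Writing this signal file releases the other local ranks, which are
+            # blocked in local_rank_zero_download_and_wait above; the barrier that
+            # follows re-syncs all ranks before local rank 0 removes the file.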
+ + dist.barrier() + + if dist.get_local_rank() == 0: + os.remove(signal_file_path) + + return tokenizer diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 73eb026d98..12913e8056 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -28,29 +28,29 @@ from omegaconf import OmegaConf as om from torch.optim.optimizer import Optimizer from torchmetrics import Metric -from transformers import AutoTokenizer, PreTrainedTokenizerBase +from transformers import PreTrainedTokenizerBase from llmfoundry import registry from llmfoundry.callbacks import EvalGauntlet from llmfoundry.data.dataloader import build_dataloader from llmfoundry.eval.datasets.in_context_learning_evaluation import \ get_icl_task_dataloader -from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper from llmfoundry.utils.config_utils import to_dict_container, to_list_container from llmfoundry.utils.registry_utils import construct_from_registry log = logging.getLogger(__name__) __all__ = [ + 'add_metrics_to_eval_loaders', 'build_algorithm', 'build_callback', + 'build_eval_loaders', 'build_evaluators', 'build_icl_data_and_gauntlet', 'build_icl_evaluators', 'build_logger', 'build_optimizer', 'build_scheduler', - 'build_tokenizer', 'build_composer_model', 'build_metric', ] @@ -446,56 +446,6 @@ def build_scheduler( ) -def build_tokenizer( - tokenizer_name: str, - tokenizer_kwargs: Dict[str, Any], -) -> PreTrainedTokenizerBase: - os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1' - os.environ['TOKENIZERS_PARALLELISM'] = 'false' - - signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed_tokenizer_setup' - - if dist.is_available() and dist.is_initialized( - ) and dist.get_world_size() > 1: - # Make sure the tokenizer files are downloaded and cached first by local rank 0 - with dist.local_rank_zero_download_and_wait(signal_file_path): - pass - - if tokenizer_name.startswith('tiktoken'): - tokenizer = TiktokenTokenizerWrapper(**tokenizer_kwargs) - else: - tokenizer = AutoTokenizer.from_pretrained( - tokenizer_name, - **tokenizer_kwargs, - ) - - # HuggingFace does not respect the model_max_length kwarg, and overrides it with - # min(kwargs['model_max_length'], original_config['model_max_length']), so we - # explicitly set it here - tokenizer.model_max_length = tokenizer_kwargs.get( - 'model_max_length', - int(1e30), - ) - - if not hasattr(tokenizer, 'eos_token') or tokenizer.eos_token is None: - raise ValueError( - f'The tokenizer {tokenizer_name} must have an eos_token.', - ) - - if dist.is_available() and dist.is_initialized( - ) and dist.get_world_size() > 1: - if dist.get_local_rank() == 0: - with open(signal_file_path, 'wb') as f: - f.write(b'local_rank0_completed_tokenizer_setup') - - dist.barrier() - - if dist.get_local_rank() == 0: - os.remove(signal_file_path) - - return tokenizer - - def build_icl_evaluators( icl_tasks: Union[str, List[Dict[str, Any]]], tokenizer: PreTrainedTokenizerBase, diff --git a/llmfoundry/utils/config_dist_utils.py b/llmfoundry/utils/config_dist_utils.py new file mode 100644 index 0000000000..f75821aae8 --- /dev/null +++ b/llmfoundry/utils/config_dist_utils.py @@ -0,0 +1,81 @@ +import contextlib +import warnings +from typing import Any, Mapping, Optional + +from llmfoundry.layers_registry import ffns_with_megablocks +from llmfoundry.models.utils import init_empty_weights + +__all__ = ['process_init_device'] + + +def process_init_device(model_cfg: dict[str, Any], fsdp_config: Optional[dict]): + # Restrict model init_device to 
'meta' and 'cpu', + # using 'cuda' vs. 'cuda:id' is tricky and can lead to common user errors + # when multiple GPUs are available. + # Also 'meta' is only valid when using FSDP + init_context = contextlib.nullcontext() + if 'init_device' in model_cfg: + assert model_cfg['init_device'] in ['meta', 'cpu', 'mixed'] + if fsdp_config is None and model_cfg['init_device'] == 'meta': + warnings.warn( + "Using `cfg.model.init_device='meta'` is only valid when using FSDP! " +\ + "Reverting to `cfg.model.init_device='cpu'`.") + model_cfg['init_device'] = 'cpu' + if model_cfg['init_device'] == 'meta': + init_context = init_empty_weights() + if model_cfg['init_device'] == 'mixed': + if fsdp_config is None: + raise NotImplementedError( + 'Using init_device `mixed` is only supported with FSDP. ' + + 'Please add a FSDP config.', + ) + # Always set `sync_module_states` to True for mixed initialization + if not fsdp_config.get('sync_module_states', False): + warnings.warn(( + 'Setting `sync_module_states = True` for FSDP. This is required ' + 'when using mixed initialization.' + )) + fsdp_config['sync_module_states'] = True + + # Set defaults for mixed initialization + fsdp_config.setdefault('use_orig_params', False) + fsdp_config.setdefault('load_monolith_rank0_only', True) + + # Set ffn_config.device_mesh to fsdp_config.device_mesh + if fsdp_config is not None and 'device_mesh' in fsdp_config and 'ffn_config' in model_cfg and model_cfg[ + 'ffn_config'].get('ffn_type', None) in ffns_with_megablocks: + # Raise ValueError if not using device mesh with MoE expert parallelism + if fsdp_config['device_mesh'] is None and model_cfg['ffn_config'].get( + 'moe_world_size', + 1, + ) > 1: + raise ValueError( + 'device_mesh must be specified in fsdp_config when using MoE with moe_world_size > 1.', + ) + model_cfg['ffn_config']['device_mesh'] = fsdp_config['device_mesh'] + + # No mixed precision needed for weights when they're already 16 bits + master_dtype = model_cfg.get('master_weights_dtype') + small_dtypes = ( + 'bf16', + 'fp16', + 'float16', + 'bfloat16', + 'amp_fp16', + 'amp_bf16', + ) + if fsdp_config and master_dtype in small_dtypes: + reduce_dtype = None + buffer_dtype = None + mixed_precision = fsdp_config.get('mixed_precision') + if isinstance(mixed_precision, Mapping): + reduce_dtype = mixed_precision.get('reduce_dtype') + buffer_dtype = mixed_precision.get('buffer_dtype') + fsdp_config['mixed_precision'] = { + 'param_dtype': None, + 'reduce_dtype': reduce_dtype, + 'buffer_dtype': buffer_dtype, + 'keep_low_precision_grads': True, + } + + return init_context diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index b6a5acf6d9..eb863de21e 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -1,7 +1,6 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -import contextlib import copy import logging import math @@ -14,7 +13,6 @@ Dict, List, Literal, - Mapping, Optional, Set, Tuple, @@ -28,16 +26,12 @@ from omegaconf import OmegaConf as om from transformers import PretrainedConfig -from llmfoundry.layers_registry import ffns_with_megablocks -from llmfoundry.models.utils import init_empty_weights - log = logging.getLogger(__name__) __all__ = [ 'pop_config', 'calculate_batch_size_info', 'update_batch_size_info', - 'process_init_device', 'log_config', 'log_dataset_uri', ] @@ -421,79 +415,6 @@ def update_batch_size_info(cfg: Dict[str, Any]) -> Dict[str, Any]: return cfg -def process_init_device(model_cfg: Dict[str, 
Any], fsdp_config: Optional[Dict]): - # Restrict model init_device to 'meta' and 'cpu', - # using 'cuda' vs. 'cuda:id' is tricky and can lead to common user errors - # when multiple GPUs are available. - # Also 'meta' is only valid when using FSDP - init_context = contextlib.nullcontext() - if 'init_device' in model_cfg: - assert model_cfg['init_device'] in ['meta', 'cpu', 'mixed'] - if fsdp_config is None and model_cfg['init_device'] == 'meta': - warnings.warn( - "Using `cfg.model.init_device='meta'` is only valid when using FSDP! " +\ - "Reverting to `cfg.model.init_device='cpu'`.") - model_cfg['init_device'] = 'cpu' - if model_cfg['init_device'] == 'meta': - init_context = init_empty_weights() - if model_cfg['init_device'] == 'mixed': - if fsdp_config is None: - raise NotImplementedError( - 'Using init_device `mixed` is only supported with FSDP. ' + - 'Please add a FSDP config.', - ) - # Always set `sync_module_states` to True for mixed initialization - if not fsdp_config.get('sync_module_states', False): - warnings.warn(( - 'Setting `sync_module_states = True` for FSDP. This is required ' - 'when using mixed initialization.' - )) - fsdp_config['sync_module_states'] = True - - # Set defaults for mixed initialization - fsdp_config.setdefault('use_orig_params', False) - fsdp_config.setdefault('load_monolith_rank0_only', True) - - # Set ffn_config.device_mesh to fsdp_config.device_mesh - if fsdp_config is not None and 'device_mesh' in fsdp_config and 'ffn_config' in model_cfg and model_cfg[ - 'ffn_config'].get('ffn_type', None) in ffns_with_megablocks: - # Raise ValueError if not using device mesh with MoE expert parallelism - if fsdp_config['device_mesh'] is None and model_cfg['ffn_config'].get( - 'moe_world_size', - 1, - ) > 1: - raise ValueError( - 'device_mesh must be specified in fsdp_config when using MoE with moe_world_size > 1.', - ) - model_cfg['ffn_config']['device_mesh'] = fsdp_config['device_mesh'] - - # No mixed precision needed for weights when they're already 16 bits - master_dtype = model_cfg.get('master_weights_dtype') - small_dtypes = ( - 'bf16', - 'fp16', - 'float16', - 'bfloat16', - 'amp_fp16', - 'amp_bf16', - ) - if fsdp_config and master_dtype in small_dtypes: - reduce_dtype = None - buffer_dtype = None - mixed_precision = fsdp_config.get('mixed_precision') - if isinstance(mixed_precision, Mapping): - reduce_dtype = mixed_precision.get('reduce_dtype') - buffer_dtype = mixed_precision.get('buffer_dtype') - fsdp_config['mixed_precision'] = { - 'param_dtype': None, - 'reduce_dtype': reduce_dtype, - 'buffer_dtype': buffer_dtype, - 'keep_low_precision_grads': True, - } - - return init_context - - def log_config(cfg: Dict[str, Any]) -> None: """Logs the current config and updates the wandb and mlflow configs. 
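For context, a minimal sketch (not taken from this patch) of how the relocated process_init_device helper is consumed; the model and FSDP configs below are illustrative placeholders.

from llmfoundry.utils import process_init_device

# Hypothetical configs; real values come from the training YAML.
model_cfg = {
    'name': 'mpt_causal_lm',
    'init_device': 'meta',
    'master_weights_dtype': 'bf16',
}
fsdp_config = {'sharding_strategy': 'FULL_SHARD', 'mixed_precision': 'PURE'}

# With init_device='meta' and an FSDP config, the returned context creates
# weights on the meta device (to be materialized later by FSDP); bf16 master
# weights additionally drop the FSDP param-dtype cast.
init_context = process_init_device(model_cfg, fsdp_config)
with init_context:
    pass  # model construction (e.g. build_composer_model) would go here
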
diff --git a/scripts/data_prep/convert_dataset_hf.py b/scripts/data_prep/convert_dataset_hf.py index d7aaa52193..10bd943fc4 100644 --- a/scripts/data_prep/convert_dataset_hf.py +++ b/scripts/data_prep/convert_dataset_hf.py @@ -18,7 +18,7 @@ from transformers import PreTrainedTokenizerBase from llmfoundry.data import ConcatTokensDataset, NoConcatDataset -from llmfoundry.utils.builders import build_tokenizer +from llmfoundry.utils import build_tokenizer class ConcatMode(Enum): diff --git a/scripts/data_prep/convert_finetuning_dataset.py b/scripts/data_prep/convert_finetuning_dataset.py index 523d45093d..d0e621ab62 100644 --- a/scripts/data_prep/convert_finetuning_dataset.py +++ b/scripts/data_prep/convert_finetuning_dataset.py @@ -22,7 +22,7 @@ is_valid_ift_example, tokenize_formatted_example, ) -from llmfoundry.utils.builders import build_tokenizer +from llmfoundry.utils import build_tokenizer HFDataset = Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset] diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py index 8a0a5c104f..4334d0af49 100644 --- a/scripts/eval/eval.py +++ b/scripts/eval/eval.py @@ -24,20 +24,20 @@ ) install() -from llmfoundry.utils.builders import ( +from llmfoundry.utils import ( add_metrics_to_eval_loaders, build_callback, build_composer_model, build_evaluators, build_logger, build_tokenizer, + process_init_device, ) from llmfoundry.utils.config_utils import ( EVAL_CONFIG_KEYS, EvalConfig, log_config, make_dataclass_and_log_config, - process_init_device, ) from llmfoundry.utils.registry_utils import import_file diff --git a/scripts/inference/benchmarking/benchmark.py b/scripts/inference/benchmarking/benchmark.py index f85b895316..c55cca7382 100644 --- a/scripts/inference/benchmarking/benchmark.py +++ b/scripts/inference/benchmarking/benchmark.py @@ -9,7 +9,7 @@ from omegaconf import DictConfig from omegaconf import OmegaConf as om -from llmfoundry.utils.builders import build_composer_model, build_tokenizer +from llmfoundry.utils import build_composer_model, build_tokenizer def get_dtype(dtype: str): diff --git a/scripts/train/train.py b/scripts/train/train.py index c9e2d67bf4..364f55d8cb 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -27,11 +27,6 @@ from llmfoundry.eval.metrics.nlp import InContextLearningMetric from llmfoundry.layers_registry import ffns_with_megablocks from llmfoundry.utils import ( - find_mosaicml_logger, - log_train_analytics, - maybe_create_mosaicml_logger, -) -from llmfoundry.utils.builders import ( add_metrics_to_eval_loaders, build_algorithm, build_callback, @@ -41,6 +36,10 @@ build_optimizer, build_scheduler, build_tokenizer, + find_mosaicml_logger, + log_train_analytics, + maybe_create_mosaicml_logger, + process_init_device, ) from llmfoundry.utils.config_utils import ( TRAIN_CONFIG_KEYS, @@ -49,7 +48,6 @@ log_dataset_uri, make_dataclass_and_log_config, pop_config, - process_init_device, update_batch_size_info, ) from llmfoundry.utils.exceptions import ( diff --git a/tests/a_scripts/eval/test_eval.py b/tests/a_scripts/eval/test_eval.py index a56778538c..c8dd3d940a 100644 --- a/tests/a_scripts/eval/test_eval.py +++ b/tests/a_scripts/eval/test_eval.py @@ -11,8 +11,7 @@ from composer import Trainer from composer.loggers import InMemoryLogger -from llmfoundry.utils import build_tokenizer -from llmfoundry.utils.builders import build_composer_model +from llmfoundry.utils import build_composer_model, build_tokenizer from llmfoundry.utils.config_utils import to_dict_container from scripts.eval.eval import main # 
noqa: E402 from tests.data_utils import create_c4_dataset_xxsmall, gpt_tiny_cfg diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py index 0577e13a1f..36e62fd4f4 100644 --- a/tests/a_scripts/inference/test_convert_composer_to_hf.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -27,13 +27,14 @@ from llmfoundry.callbacks.hf_checkpointer import _maybe_get_license_filename from llmfoundry.data.finetuning import build_finetuning_dataloader from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM -from llmfoundry.utils import edit_files_for_hf_compatibility -from llmfoundry.utils.builders import ( +from llmfoundry.utils import ( build_composer_model, build_optimizer, build_tokenizer, + edit_files_for_hf_compatibility, + process_init_device, ) -from llmfoundry.utils.config_utils import process_init_device, to_dict_container +from llmfoundry.utils.config_utils import to_dict_container from scripts.inference.convert_composer_to_hf import convert_composer_to_hf from tests.data_utils import make_tiny_ft_dataset diff --git a/tests/callbacks/test_async_eval_callback.py b/tests/callbacks/test_async_eval_callback.py index e0794df08f..36e21f3f1b 100644 --- a/tests/callbacks/test_async_eval_callback.py +++ b/tests/callbacks/test_async_eval_callback.py @@ -15,7 +15,7 @@ validate_eval_run_config, validate_interval, ) -from llmfoundry.utils.builders import build_callback +from llmfoundry.utils import build_callback from mcli import Run, RunConfig, RunStatus RUN_NAME = 'foo_bar-1234' diff --git a/tests/callbacks/test_curriculum_learning_callback.py b/tests/callbacks/test_curriculum_learning_callback.py index bbdbf3d691..f2d6dc2275 100644 --- a/tests/callbacks/test_curriculum_learning_callback.py +++ b/tests/callbacks/test_curriculum_learning_callback.py @@ -1,7 +1,7 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from llmfoundry.utils.builders import build_callback +from llmfoundry.utils import build_callback def test_curriculum_learning_callback_builds(): diff --git a/tests/callbacks/test_eval_gauntlet_callback.py b/tests/callbacks/test_eval_gauntlet_callback.py index 9c80127af5..a8993d049c 100644 --- a/tests/callbacks/test_eval_gauntlet_callback.py +++ b/tests/callbacks/test_eval_gauntlet_callback.py @@ -12,7 +12,7 @@ from transformers import AutoTokenizer from llmfoundry.eval.metrics.nlp import InContextLearningLMAccuracy -from llmfoundry.utils.builders import build_icl_data_and_gauntlet +from llmfoundry.utils import build_icl_data_and_gauntlet from llmfoundry.utils.config_utils import to_dict_container diff --git a/tests/callbacks/test_loss_perp_v_len_callback.py b/tests/callbacks/test_loss_perp_v_len_callback.py index 46bde1c2f1..b84dd143ac 100644 --- a/tests/callbacks/test_loss_perp_v_len_callback.py +++ b/tests/callbacks/test_loss_perp_v_len_callback.py @@ -18,7 +18,7 @@ StreamingTextDataset, build_text_dataloader, ) -from llmfoundry.utils.builders import build_composer_model +from llmfoundry.utils import build_composer_model from llmfoundry.utils.registry_utils import construct_from_registry diff --git a/tests/callbacks/test_mbmoe_tok_per_expert_callback.py b/tests/callbacks/test_mbmoe_tok_per_expert_callback.py index aff9f1724b..943f0e4fb5 100644 --- a/tests/callbacks/test_mbmoe_tok_per_expert_callback.py +++ b/tests/callbacks/test_mbmoe_tok_per_expert_callback.py @@ -1,7 +1,7 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 
-from llmfoundry.utils.builders import build_callback +from llmfoundry.utils import build_callback def test_mbmoe_tok_per_expert_builds(): diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index 7c8e808bab..64312e706d 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -39,7 +39,7 @@ build_text_dataloader, ) from llmfoundry.data.utils import get_tokens_per_batch_func -from llmfoundry.utils.builders import build_tokenizer +from llmfoundry.utils import build_tokenizer from llmfoundry.utils.config_utils import to_dict_container # yapf: disable from llmfoundry.utils.exceptions import ( diff --git a/tests/data/test_icl_datasets.py b/tests/data/test_icl_datasets.py index 5254c8e862..79c650a6ad 100644 --- a/tests/data/test_icl_datasets.py +++ b/tests/data/test_icl_datasets.py @@ -7,7 +7,7 @@ from omegaconf import OmegaConf as om from transformers import AutoTokenizer, PreTrainedTokenizerBase -from llmfoundry.utils.builders import build_icl_evaluators +from llmfoundry.utils import build_icl_evaluators from llmfoundry.utils.config_utils import to_list_container diff --git a/tests/data/test_packing.py b/tests/data/test_packing.py index b910b8c5ff..ece55813eb 100644 --- a/tests/data/test_packing.py +++ b/tests/data/test_packing.py @@ -15,7 +15,7 @@ from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio -from llmfoundry.utils.builders import build_tokenizer +from llmfoundry.utils import build_tokenizer def _data_to_batch(data: List[List[int]], max_seq_len: int, diff --git a/tests/data/test_template_tokenization.py b/tests/data/test_template_tokenization.py index 16447d6623..f1eef08d59 100644 --- a/tests/data/test_template_tokenization.py +++ b/tests/data/test_template_tokenization.py @@ -11,7 +11,7 @@ dataset_constructor, tokenize_formatted_example, ) -from llmfoundry.utils.builders import build_tokenizer +from llmfoundry.utils import build_tokenizer from llmfoundry.utils.exceptions import ( ALLOWED_PROMPT_KEYS, ALLOWED_RESPONSE_KEYS, diff --git a/tests/fixtures/models.py b/tests/fixtures/models.py index 83b0924a5d..4acfae86fa 100644 --- a/tests/fixtures/models.py +++ b/tests/fixtures/models.py @@ -10,7 +10,7 @@ from llmfoundry.models.hf.hf_causal_lm import ComposerHFCausalLM from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM -from llmfoundry.utils.builders import build_composer_model, build_tokenizer +from llmfoundry.utils import build_composer_model, build_tokenizer def _build_model(config: Dict[str, Any], tokenizer: PreTrainedTokenizerBase): diff --git a/tests/models/hf/test_fsdp_weight_tying.py b/tests/models/hf/test_fsdp_weight_tying.py index 69ced673a1..4ad0a1399d 100644 --- a/tests/models/hf/test_fsdp_weight_tying.py +++ b/tests/models/hf/test_fsdp_weight_tying.py @@ -8,7 +8,7 @@ from composer import Trainer from composer.models.huggingface import maybe_get_underlying_model -from llmfoundry.utils.builders import build_composer_model, build_tokenizer +from llmfoundry.utils import build_composer_model, build_tokenizer @pytest.mark.world_size(2) diff --git a/tests/models/hf/test_hf_config.py b/tests/models/hf/test_hf_config.py index 1ca384171d..53582f835e 100644 --- a/tests/models/hf/test_hf_config.py +++ b/tests/models/hf/test_hf_config.py @@ -11,8 +11,7 @@ from transformers import PretrainedConfig from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM -from llmfoundry.utils import build_tokenizer -from llmfoundry.utils.builders 
import build_composer_model +from llmfoundry.utils import build_composer_model, build_tokenizer from llmfoundry.utils.config_utils import to_dict_container diff --git a/tests/models/hf/test_hf_peft_wrapping.py b/tests/models/hf/test_hf_peft_wrapping.py index 522fc5db57..870627ee38 100644 --- a/tests/models/hf/test_hf_peft_wrapping.py +++ b/tests/models/hf/test_hf_peft_wrapping.py @@ -12,7 +12,7 @@ from peft import LoraConfig, get_peft_model from llmfoundry.models.hf.hf_fsdp import prepare_hf_model_for_fsdp -from llmfoundry.utils.builders import build_composer_model, build_tokenizer +from llmfoundry.utils import build_composer_model, build_tokenizer def test_peft_wraps(): diff --git a/tests/models/hf/test_hf_v_mpt.py b/tests/models/hf/test_hf_v_mpt.py index 66f04e0c4a..5defa63588 100644 --- a/tests/models/hf/test_hf_v_mpt.py +++ b/tests/models/hf/test_hf_v_mpt.py @@ -8,7 +8,7 @@ from composer.utils import reproducibility from omegaconf import OmegaConf as om -from llmfoundry.utils.builders import build_composer_model, build_tokenizer +from llmfoundry.utils import build_composer_model, build_tokenizer from llmfoundry.utils.config_utils import to_dict_container diff --git a/tests/models/inference_api_wrapper/test_fmapi.py b/tests/models/inference_api_wrapper/test_fmapi.py index af26823aae..3cce7ad719 100644 --- a/tests/models/inference_api_wrapper/test_fmapi.py +++ b/tests/models/inference_api_wrapper/test_fmapi.py @@ -13,7 +13,7 @@ FMAPIChatAPIEvalWrapper, ) from llmfoundry.models.inference_api_wrapper.fmapi import FMAPIEvalInterface -from llmfoundry.utils.builders import build_icl_evaluators +from llmfoundry.utils import build_icl_evaluators from llmfoundry.utils.config_utils import to_list_container diff --git a/tests/models/inference_api_wrapper/test_inference_api_eval_wrapper.py b/tests/models/inference_api_wrapper/test_inference_api_eval_wrapper.py index f35e5cd750..6cedc56fad 100644 --- a/tests/models/inference_api_wrapper/test_inference_api_eval_wrapper.py +++ b/tests/models/inference_api_wrapper/test_inference_api_eval_wrapper.py @@ -13,7 +13,7 @@ OpenAIChatAPIEvalWrapper, ) from llmfoundry.tokenizers import TiktokenTokenizerWrapper -from llmfoundry.utils.builders import build_icl_evaluators +from llmfoundry.utils import build_icl_evaluators from llmfoundry.utils.config_utils import to_list_container diff --git a/tests/models/layers/test_huggingface_flash.py b/tests/models/layers/test_huggingface_flash.py index 3dc3e5dda1..58c63b8766 100644 --- a/tests/models/layers/test_huggingface_flash.py +++ b/tests/models/layers/test_huggingface_flash.py @@ -8,7 +8,7 @@ from llmfoundry.models.hf.hf_fsdp import rgetattr from llmfoundry.models.layers.attention import is_flash_v2_installed -from llmfoundry.utils.builders import build_composer_model, build_tokenizer +from llmfoundry.utils import build_composer_model, build_tokenizer @pytest.mark.gpu diff --git a/tests/models/test_model.py b/tests/models/test_model.py index a62a7dd114..425682e215 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -40,8 +40,7 @@ ) from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM -from llmfoundry.utils import build_tokenizer -from llmfoundry.utils.builders import build_composer_model +from llmfoundry.utils import build_composer_model, build_tokenizer from llmfoundry.utils.config_utils import to_dict_container diff --git a/tests/utils/test_builders.py b/tests/utils/test_builders.py index dfcb5b327c..55e4d52e87 100644 --- 
a/tests/utils/test_builders.py +++ b/tests/utils/test_builders.py @@ -17,7 +17,7 @@ from llmfoundry.callbacks import HuggingFaceCheckpointer from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper -from llmfoundry.utils.builders import ( +from llmfoundry.utils import ( add_metrics_to_eval_loaders, build_callback, build_eval_loaders,