From d8c1552115bdb7189e04482c358a594d1d80d60d Mon Sep 17 00:00:00 2001 From: Abhay Gupta Date: Fri, 26 Jul 2024 12:59:02 -0700 Subject: [PATCH 1/5] Enable passing epsilon when building norm layers (#1399) * adding eps to building norms * adding norm eps to layers and configs * adding docstrings --- llmfoundry/models/layers/attention.py | 7 +++++++ llmfoundry/models/layers/blocks.py | 7 +++++++ llmfoundry/models/layers/layer_builders.py | 2 ++ llmfoundry/models/mpt/configuration_mpt.py | 3 +++ llmfoundry/models/mpt/modeling_mpt.py | 1 + 5 files changed, 20 insertions(+) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 8e740be2b3..c7fdb5b987 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -415,6 +415,7 @@ def __init__( softmax_scale: Optional[float] = None, attn_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', + norm_eps: float = 1e-05, fc_type: Optional[dict[str, Any]] = None, device: Optional[str] = None, bias: bool = True, @@ -520,6 +521,7 @@ def __init__( self.q_ln = build_norm( name=norm_type.lower(), normalized_shape=norm_size, + eps=norm_eps, device=device, ) if self.reuse_kv_layer_idx is None: @@ -528,6 +530,7 @@ def __init__( self.k_ln = build_norm( name=norm_type.lower(), normalized_shape=norm_size, + eps=norm_eps, device=device, ) @@ -796,6 +799,7 @@ def __init__( softmax_scale: Optional[float] = None, attn_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', + norm_eps: float = 1e-05, fc_type: Optional[dict[str, Any]] = None, device: Optional[str] = None, bias: bool = True, @@ -814,6 +818,7 @@ def __init__( softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, + norm_eps=norm_eps, fc_type=fc_type, device=device, bias=bias, @@ -841,6 +846,7 @@ def __init__( softmax_scale: Optional[float] = None, attn_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', + norm_eps: float = 1e-05, fc_type: Optional[dict[str, Any]] = None, device: Optional[str] = None, bias: bool = True, @@ -859,6 +865,7 @@ def __init__( softmax_scale=softmax_scale, attn_pdrop=attn_pdrop, norm_type=norm_type, + norm_eps=norm_eps, fc_type=fc_type, device=device, bias=bias, diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index c6988b7bd7..92735cc489 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -42,6 +42,7 @@ def __init__( ffn_config: Optional[Dict] = None, resid_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', + norm_eps: float = 1e-05, fc_type: Optional[dict[str, Any]] = None, device: Optional[str] = None, no_bias: bool = False, @@ -84,6 +85,7 @@ def __init__( fc_type=fc_type, resid_pdrop=resid_pdrop, norm_type=norm_type, + norm_eps=norm_eps, device=device, no_bias=no_bias, ) @@ -99,6 +101,7 @@ def __init__( self.norm_1 = build_norm( name=norm_type.lower(), normalized_shape=d_model, + eps=norm_eps, device=device, ) self.attn = build_attention_layer( @@ -117,6 +120,7 @@ def __init__( self.norm_2 = build_norm( name=norm_type.lower(), normalized_shape=d_model, + eps=norm_eps, device=device, ) @@ -260,6 +264,7 @@ def __init__( fc_type: Optional[dict[str, Any]] = None, resid_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', + norm_eps: float = 1e-05, device: Optional[str] = None, no_bias: bool = False, **kwargs: Any, @@ -283,6 +288,7 @@ def __init__( self.norm_1 = build_norm( name=norm_type.lower(), normalized_shape=d_model, + eps=norm_eps, device=device, ) 
self.attn = build_attention_layer( @@ -302,6 +308,7 @@ def __init__( self.norm_2 = build_norm( name=norm_type.lower(), normalized_shape=d_model, + eps=norm_eps, device=device, ) self.resid_attn_dropout = nn.Dropout(resid_pdrop) diff --git a/llmfoundry/models/layers/layer_builders.py b/llmfoundry/models/layers/layer_builders.py index 69d2059bad..d5fd1d37d4 100644 --- a/llmfoundry/models/layers/layer_builders.py +++ b/llmfoundry/models/layers/layer_builders.py @@ -26,10 +26,12 @@ def build_norm( name: str, normalized_shape: Union[int, List[int], torch.Size], + eps: Optional[float] = 1e-5, device: Optional[str] = None, ): kwargs = { 'normalized_shape': normalized_shape, + 'eps': eps, 'device': device, } diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 8ac5a8ac49..86cc3519ba 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -44,6 +44,7 @@ def __init__( no_bias: bool = False, embedding_fraction: float = 1.0, norm_type: str = 'low_precision_layernorm', + norm_eps: float = 1e-05, use_cache: bool = False, init_config: Optional[Dict] = None, fc_type: Union[str, Dict] = 'torch', @@ -101,6 +102,7 @@ def __init__( no_bias (bool): Whether to use bias in all layers. embedding_fraction (float): The fraction to scale the gradients of the embedding layer by. norm_type (str): choose type of norm to use + norm_eps (float): epsilon value for norm layer use_cache (bool): Whether or not the model should return the last key/values attentions init_config (Dict): A dictionary used to configure the model initialization: init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_', @@ -168,6 +170,7 @@ def __init__( self.no_bias = no_bias self.embedding_fraction = embedding_fraction self.norm_type = norm_type + self.norm_eps = norm_eps self.use_cache = use_cache self.init_config = init_config if init_config is not None else copy.deepcopy( init_config_defaults, diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 0bcd587084..6f9b6bf806 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -426,6 +426,7 @@ def __init__(self, config: MPTConfig): self.norm_f = build_norm( name=config.norm_type.lower(), normalized_shape=config.d_model, + eps=config.norm_eps, device=config.init_device, ) From 0c93331fee80a3438a663bc205fdb8ef4fb68fcd Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Fri, 26 Jul 2024 15:47:09 -0700 Subject: [PATCH 2/5] Add pre register method for mlflow (#1396) --- llmfoundry/callbacks/hf_checkpointer.py | 13 +++++++++++++ .../inference/test_convert_composer_to_hf.py | 5 +++++ 2 files changed, 18 insertions(+) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 35508cc0c7..1797c3b5b4 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -376,6 +376,17 @@ def transform_config( copied_config.ffn_config['moe_world_size'] = 1 return copied_config + def pre_register_edit(self, local_save_path: str): + """Edit the model before registering with MLflow. + + This allows a subclass to modify the model before registering with MLflow. The base class implementation will + make no modifications. + + Args: + local_save_path (str): The path to the model to be transformed. 
+ """ + pass + def transform_model_pre_registration( self, model: PreTrainedModel, @@ -618,6 +629,8 @@ def tensor_hook( os.path.join(local_save_path, license_filename), ) + self.pre_register_edit(local_save_path,) + # Spawn a new process to register the model. process = SpawnProcess( target=_register_model_with_run_id_multiprocess, diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py index ffdb09ca98..9eb214e83d 100644 --- a/tests/a_scripts/inference/test_convert_composer_to_hf.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -388,6 +388,9 @@ def test_huggingface_conversion_callback_interval( checkpointer_callback.transform_model_pre_registration = MagicMock( wraps=checkpointer_callback.transform_model_pre_registration, ) + checkpointer_callback.pre_register_edit = MagicMock( + wraps=checkpointer_callback.pre_register_edit, + ) trainer = Trainer( model=original_model, device='gpu', @@ -413,9 +416,11 @@ def test_huggingface_conversion_callback_interval( metadata={}, ) assert checkpointer_callback.transform_model_pre_registration.call_count == 1 + assert checkpointer_callback.pre_register_edit.call_count == 1 assert mlflow_logger_mock.register_model_with_run_id.call_count == 1 else: assert checkpointer_callback.transform_model_pre_registration.call_count == 0 + assert checkpointer_callback.pre_register_edit.call_count == 0 assert mlflow_logger_mock.save_model.call_count == 0 assert mlflow_logger_mock.register_model_with_run_id.call_count == 0 From 799279bc19ad1ae730b1f3ec985495e6100e066f Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Fri, 26 Jul 2024 20:43:10 -0700 Subject: [PATCH 3/5] Add pip requirements directly for mlflow save (#1400) --- llmfoundry/callbacks/hf_checkpointer.py | 6 ++++++ tests/a_scripts/inference/test_convert_composer_to_hf.py | 2 ++ 2 files changed, 8 insertions(+) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 1797c3b5b4..a186f67f14 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -613,6 +613,12 @@ def tensor_hook( ) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext( ) with context_manager: + # Add the pip requirements directly to avoid mlflow + # attempting to run inference on the model + model_saving_kwargs['pip_requirements'] = [ + 'transformers', + 'torch', + ] mlflow_logger.save_model(**model_saving_kwargs) # Upload the license file generated by mlflow during the model saving. 
diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py index 9eb214e83d..cd47b2df7c 100644 --- a/tests/a_scripts/inference/test_convert_composer_to_hf.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -414,6 +414,7 @@ def test_huggingface_conversion_callback_interval( task='llm/v1/completions', input_example=ANY, metadata={}, + pip_requirements=ANY, ) assert checkpointer_callback.transform_model_pre_registration.call_count == 1 assert checkpointer_callback.pre_register_edit.call_count == 1 @@ -594,6 +595,7 @@ def _assert_mlflow_logger_calls( 'task': 'llm/v1/completions', 'input_example': default_input_example, 'metadata': {}, + 'pip_requirements': ANY, } mlflow_logger_mock.save_model.assert_called_with(**expectation) assert mlflow_logger_mock.register_model_with_run_id.call_count == 1 From 5e07a05dd2f727928729ad23c26ce68ec8349286 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Sat, 27 Jul 2024 00:48:09 -0700 Subject: [PATCH 4/5] Remove orig params false default (#1401) --- llmfoundry/utils/config_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 48290bd7c5..dcb97eb0de 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -531,7 +531,6 @@ def process_init_device(model_cfg: Dict[str, Any], fsdp_config: Optional[Dict]): fsdp_config['sync_module_states'] = True # Set defaults for mixed initialization - fsdp_config.setdefault('use_orig_params', False) fsdp_config.setdefault('load_monolith_rank0_only', True) # Set ffn_config.device_mesh to fsdp_config.device_mesh From 5c7e99bed1dfe28ed4fe61a30c15a2b2d4527b8c Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Sun, 28 Jul 2024 18:39:24 -0700 Subject: [PATCH 5/5] Add spin_dataloaders flag (#1405) --- llmfoundry/command_utils/train.py | 1 + llmfoundry/utils/config_utils.py | 1 + 2 files changed, 2 insertions(+) diff --git a/llmfoundry/command_utils/train.py b/llmfoundry/command_utils/train.py index 77bb9dbcfe..c925e6e586 100644 --- a/llmfoundry/command_utils/train.py +++ b/llmfoundry/command_utils/train.py @@ -544,6 +544,7 @@ def train(cfg: DictConfig) -> Trainer: dist_timeout=train_cfg.dist_timeout, profiler=profiler, compile_config=compile_config, + spin_dataloaders=train_cfg.spin_dataloaders, ) # Optionally just save an HF checkpoint diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index dcb97eb0de..84a3376718 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -167,6 +167,7 @@ class TrainConfig: # Dataloader device_train_microbatch_size: Union[str, int, float] = 'auto' global_train_batch_size: Optional[int] = None + spin_dataloaders: bool = True # Eval dataloader eval_subset_num_batches: int = -1
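
Usage sketch for the configuration surface added in this series. The norm_eps and build_norm signatures come from PATCH 1; the small model sizes and the 1e-6 epsilon below are illustrative values, not defaults (the default norm_eps stays at 1e-05, matching the previously hard-coded value), and the other MPTConfig arguments shown are pre-existing ones.

# Illustrative values only; norm_eps plumbing follows PATCH 1.
from llmfoundry.models.layers.layer_builders import build_norm
from llmfoundry.models.mpt.configuration_mpt import MPTConfig

# norm_eps is threaded from MPTConfig through every build_norm call in the
# attention layers, the transformer blocks, and the final norm_f.
config = MPTConfig(
    d_model=256,
    n_heads=4,
    n_layers=2,
    norm_type='low_precision_layernorm',
    norm_eps=1e-6,
)

# The same epsilon can also be passed when building a norm layer directly.
norm = build_norm(
    name=config.norm_type.lower(),
    normalized_shape=config.d_model,
    eps=config.norm_eps,
    device='cpu',
)

PATCH 5's spin_dataloaders is a plain boolean on TrainConfig (default True) that is forwarded unchanged to the Composer Trainer; it can be set via spin_dataloaders: false in a training YAML.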