From 8a1e55eb53f0645c1bc98b96c3e2be76a79753c3 Mon Sep 17 00:00:00 2001
From: Irene Dea
Date: Mon, 18 Nov 2024 06:56:41 -0800
Subject: [PATCH] Move transform_model_pre_registration in hf_checkpointer
 (#1664)

Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
---
 llmfoundry/callbacks/hf_checkpointer.py      | 11 ++-
 .../inference/test_convert_composer_to_hf.py | 72 ++++++++++++++++++-
 2 files changed, 75 insertions(+), 8 deletions(-)

diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index 7ce9818426..4cc5f46d1a 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -784,6 +784,10 @@ def tensor_hook(
 
         if dist.get_global_rank() == 0:
             if register_to_mlflow:
+                assert new_model_instance is not None
+                new_model_instance = self.transform_model_pre_registration(
+                    new_model_instance,
+                )
                 if self.using_peft:
 
                     # Save and register peft model to mlflow, this code path uses our older two step logic
@@ -798,10 +802,6 @@ def tensor_hook(
                         temp_save_dir,
                         'register_save',
                     )
-                    assert new_model_instance is not None
-                    new_model_instance = self.transform_model_pre_registration(
-                        new_model_instance,
-                    )
                     new_model_instance.save_pretrained(
                         register_save_dir,
                         max_shard_size='1GB',
@@ -860,9 +860,6 @@ def _save_and_register_peft_model(
         original_tokenizer: Optional[Any],
         save_dir: str,
     ):
-        new_model_instance = self.transform_model_pre_registration(
-            new_model_instance,
-        )
         components = {'model': new_model_instance}
         if original_tokenizer is not None:
             components['tokenizer'] = original_tokenizer
diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py
index 5dafdcb466..67b4a69a3b 100644
--- a/tests/a_scripts/inference/test_convert_composer_to_hf.py
+++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py
@@ -624,7 +624,7 @@ def test_huggingface_conversion_callback_interval(
 def _get_model_and_tokenizer(
     model: str,
     max_seq_len: int,
-    tie_word_embeddings: bool,
+    tie_word_embeddings: Optional[bool],
     precision: str,
 ):
     if model == 'mpt':
@@ -1110,6 +1110,76 @@ def test_huggingface_conversion_callback(
     delete_transformers_cache()
 
 
+@patch('os.cpu_count', MagicMock(return_value=1))
+@patch(
+    'llmfoundry.callbacks.hf_checkpointer.SpawnProcess',
+    new=MockSpawnProcess,
+)
+def test_transform_model_pre_registration():
+    """Test `transform_model_pre_registration` method is called."""
+
+    class ExtendedHuggingFaceCheckpointer(HuggingFaceCheckpointer):
+        """Set PEFT to false before registering for testing."""
+
+        def transform_model_pre_registration(self, model: PreTrainedModel):
+            self.using_peft = False
+            return super().transform_model_pre_registration(model)
+
+    model_cfg, tokenizer_name = _get_model_and_tokenizer(
+        model='neo',
+        max_seq_len=10,
+        tie_word_embeddings=None,
+        precision='bfloat16',
+    )
+    model_cfg['peft_config'] = {
+        'peft_type': 'LORA',
+        'task_type': 'CAUSAL_LM',
+        'lora_alpha': 32,
+        'lora_dropout': 0.05,
+        'r': 16,
+        'target_modules': 'all-linear',
+    }
+    tokenizer = build_tokenizer(
+        tokenizer_name=tokenizer_name,
+        tokenizer_kwargs={},
+    )
+
+    original_model = build_composer_model(
+        model_cfg.pop('name'),
+        tokenizer=tokenizer,
+        cfg=model_cfg,
+    )
+
+    logger = MagicMock()
+    state = MagicMock()
+    state.timestamp.batch = 1
+    state.is_model_ddp = False
+    state.model = original_model
+    state.model.tokenizer = tokenizer
+
+    checkpointer = ExtendedHuggingFaceCheckpointer(
+        save_folder='test',
+        save_interval='1ba',
+    )
+    mlflow_logger_mock = _create_mlflow_logger_mock()
+    checkpointer.mlflow_loggers = [mlflow_logger_mock]  # type: ignore
+
+    assert model_cfg is not None
+    assert tokenizer_name is not None
+
+    checkpointer._save_and_register_peft_model = MagicMock()
+    checkpointer.using_peft = True
+    checkpointer._save_checkpoint(
+        state=state,
+        logger=logger,
+        upload_to_save_folder=True,
+        register_to_mlflow=True,
+    )
+
+    checkpointer._save_and_register_peft_model.assert_not_called()
+    assert mlflow_logger_mock.log_model.call_count == 1
+
+
 # TODO(GRT-2431): Refactor as enums
 @pytest.mark.parametrize(
     'model,tie_word_embeddings',