From 9103dbd769006b201a75d3f1b6a3da06a77d0cc9 Mon Sep 17 00:00:00 2001
From: Milo Cress
Date: Tue, 17 Dec 2024 11:34:59 -0500
Subject: [PATCH] Update hf_checkpointer.py

---
 llmfoundry/callbacks/hf_checkpointer.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index b5a49ec7c0..e0e2d2be3a 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -619,7 +619,9 @@ def tensor_hook(
 
         hooks = []
         for _, module in state_dict_model.named_modules():
-            hooks.append(module._register_state_dict_hook(tensor_hook),)
+            hooks.append(
+                module._register_state_dict_hook(tensor_hook),
+            )
 
         state_dict = get_model_state_dict(
             state_dict_model,
@@ -690,10 +692,10 @@ def tensor_hook(
 
     def _register_hf_model(
         self,
-        state: State,
         temp_save_dir: str,
         original_tokenizer: PreTrainedTokenizerBase,
         use_temp_dir: bool,
+        new_model_instance: PreTrainedModel,
     ):
         assert new_model_instance is not None
         new_model_instance = self.transform_model_pre_registration(
@@ -834,12 +836,11 @@ def _save_checkpoint(
             dist.barrier()
 
             if dist.get_global_rank() == 0:
+                assert new_model_instance is not None
                 if register_to_mlflow:
                     self._register_hf_model(
-                        state,
-                        temp_save_dir,
-                        original_tokenizer,
-                        use_temp_dir,
+                        temp_save_dir, original_tokenizer, use_temp_dir,
+                        new_model_instance
                     )
                 else:
                     # Clean up the temporary directory if we don't need to register to mlflow.
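
--
For reference on the first hunk: _register_state_dict_hook is a private
torch.nn.Module API that registers a callback to run while state_dict() is
being materialized, and it returns a removable handle. Below is a minimal
runnable sketch of the same register-then-remove pattern; the toy model and
no-op hook are stand-ins for the checkpointer's state_dict_model and real
tensor_hook, not the actual llm-foundry implementations.

    from torch import nn

    def tensor_hook(module, state_dict, prefix, local_metadata):
        # No-op stand-in; the real hook in hf_checkpointer.py inspects and
        # rewrites tensors in state_dict in place while it is being built.
        pass

    state_dict_model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))

    # Register the hook on every submodule, keeping each handle for cleanup.
    hooks = []
    for _, module in state_dict_model.named_modules():
        hooks.append(
            module._register_state_dict_hook(tensor_hook),
        )

    state_dict = state_dict_model.state_dict()  # hooks fire here

    for hook in hooks:
        hook.remove()  # detach once the state dict has been collected

Keeping the returned handles in hooks is what makes the later cleanup
possible; the hunk itself only reflows the append call across lines without
changing behavior.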