Skip to content

Commit

Permalink
Update hf_checkpointer.py
Browse files Browse the repository at this point in the history
  • Loading branch information
milocress authored Dec 17, 2024
1 parent 5110d2d commit 9103dbd
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,9 @@ def tensor_hook(

hooks = []
for _, module in state_dict_model.named_modules():
hooks.append(module._register_state_dict_hook(tensor_hook),)
hooks.append(
module._register_state_dict_hook(tensor_hook),
)

state_dict = get_model_state_dict(
state_dict_model,
Expand Down Expand Up @@ -690,10 +692,10 @@ def tensor_hook(

def _register_hf_model(
self,
state: State,
temp_save_dir: str,
original_tokenizer: PreTrainedTokenizerBase,
use_temp_dir: bool,
new_model_instance: PreTrainedModel,
):
assert new_model_instance is not None
new_model_instance = self.transform_model_pre_registration(
Expand Down Expand Up @@ -834,12 +836,11 @@ def _save_checkpoint(
dist.barrier()

if dist.get_global_rank() == 0:
assert new_model_instance is not None
if register_to_mlflow:
self._register_hf_model(
state,
temp_save_dir,
original_tokenizer,
use_temp_dir,
temp_save_dir, original_tokenizer, use_temp_dir,
new_model_instance
)
else:
# Clean up the temporary directory if we don't need to register to mlflow.
Expand Down

0 comments on commit 9103dbd

Please sign in to comment.