diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index e0e2d2be3a..2e25141fe3 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -788,6 +788,7 @@ def _save_checkpoint( dist.barrier() if dist.get_global_rank() == 0: + assert new_model_instance is not None if upload_to_save_folder: # This context manager casts the TE extra state in io.BytesIO format to tensor format # Needed for proper hf ckpt saving.