diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index 688d8deb74..7ce9818426 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -743,7 +743,10 @@ def tensor_hook(
                 ) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext(
                 )
                 with context_manager:
-                    new_model_instance.save_pretrained(temp_save_dir)
+                    new_model_instance.save_pretrained(
+                        temp_save_dir,
+                        max_shard_size='1GB',
+                    )
                 if original_tokenizer is not None:
                     assert isinstance(
                         original_tokenizer,
@@ -799,7 +802,10 @@ def tensor_hook(
                 new_model_instance = self.transform_model_pre_registration(
                     new_model_instance,
                 )
-                new_model_instance.save_pretrained(register_save_dir)
+                new_model_instance.save_pretrained(
+                    register_save_dir,
+                    max_shard_size='1GB',
+                )
                 if original_tokenizer:
                     original_tokenizer.save_pretrained(register_save_dir)
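
For context, `max_shard_size` is a standard keyword argument of Hugging Face's `PreTrainedModel.save_pretrained`: it caps the size of each serialized weight file, so a large checkpoint is written as several shards plus an index JSON rather than one monolithic file. Below is a minimal sketch of the resulting layout; the `gpt2` checkpoint and temporary directory are purely illustrative and not part of this patch, which applies the same call shape to whatever model instance the checkpointer has reconstructed.

import os
import tempfile

from transformers import AutoModelForCausalLM

# Illustrative small model; any PreTrainedModel behaves the same way.
model = AutoModelForCausalLM.from_pretrained('gpt2')

with tempfile.TemporaryDirectory() as save_dir:
    # Same call shape as the patch: each shard is capped at 1GB.
    # Checkpoints larger than the cap are written as
    # model-00001-of-0000N.safetensors (and so on) plus a
    # model.safetensors.index.json mapping each weight to its shard.
    model.save_pretrained(save_dir, max_shard_size='1GB')
    print(sorted(os.listdir(save_dir)))

Presumably the motivation here is that smaller shards make the callback's subsequent remote upload and MLflow registration steps less sensitive to the failure or retry cost of transferring a single very large file.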