From 96f6419f01145a6b0e568f2d5d9b764a6d9183d3 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Thu, 17 Oct 2024 12:19:41 -0700 Subject: [PATCH] Add configurability for hf checkpointer register timeout (#1599) --- llmfoundry/callbacks/hf_checkpointer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 4365a5b2e5..2c4603ea87 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -180,6 +180,7 @@ def __init__( mlflow_logging_config: Optional[dict] = None, flatten_imports: Sequence[str] = ('llmfoundry',), final_register_only: bool = False, + register_wait_seconds: int = 7200, ): _, _, self.save_dir_format_str = parse_uri(save_folder) self.overwrite = overwrite @@ -193,6 +194,7 @@ def __init__( self.using_peft = False self.final_register_only = final_register_only + self.register_wait_seconds = register_wait_seconds self.mlflow_registered_model_name = mlflow_registered_model_name if self.final_register_only and self.mlflow_registered_model_name is None: @@ -325,7 +327,7 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: self.using_peft = composer_model.using_peft elif event == Event.FIT_END: # Wait for all child processes spawned by the callback to finish. - timeout = 3600 + timeout = self.register_wait_seconds wait_start = time.time() while not self._all_register_processes_done(state.device): wait_time = time.time() - wait_start