Skip to content

Commit

Permalink
Add configurability for hf checkpointer register timeout (#1599)
Browse files: browse the repository at this point in the history
  • Loading branch information
dakinggg authored Oct 17, 2024
1 parent 10f25ae commit 96f6419
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def __init__(
mlflow_logging_config: Optional[dict] = None,
flatten_imports: Sequence[str] = ('llmfoundry',),
final_register_only: bool = False,
+ register_wait_seconds: int = 7200,
):
_, _, self.save_dir_format_str = parse_uri(save_folder)
self.overwrite = overwrite
Expand All @@ -193,6 +194,7 @@ def __init__(
self.using_peft = False

self.final_register_only = final_register_only
+ self.register_wait_seconds = register_wait_seconds

self.mlflow_registered_model_name = mlflow_registered_model_name
if self.final_register_only and self.mlflow_registered_model_name is None:
Expand Down Expand Up @@ -325,7 +327,7 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
self.using_peft = composer_model.using_peft
elif event == Event.FIT_END:
# Wait for all child processes spawned by the callback to finish.
- timeout = 3600
+ timeout = self.register_wait_seconds
wait_start = time.time()
while not self._all_register_processes_done(state.device):
wait_time = time.time() - wait_start
Expand Down

0 comments on commit 96f6419

Please sign in to comment.