diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index d86f27aa86..bd722e1cfb 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -268,7 +268,10 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: self._save_checkpoint( state, logger, - register_to_mflow=is_last_batch, + register_to_mflow=( + self.mlflow_registered_model_name is not None and + is_last_batch + ), upload_to_save_folder=not ( self.final_register_only and is_last_batch ), diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py index b863e1d0a8..a4a9eb9ff4 100644 --- a/tests/a_scripts/inference/test_convert_composer_to_hf.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -317,6 +317,7 @@ class MockSpawnProcess: def __init__(self, target: Callable, kwargs: dict[str, Any]): self.target = target self.kwargs = kwargs + self.exitcode = 0 def start(self): self.target(**self.kwargs)