Skip to content

Commit

Permalink
precommit
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea committed Dec 6, 2024
1 parent cdd3f7e commit 10f65f8
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
9 changes: 5 additions & 4 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,7 @@ def __init__(
}
is_chat = mlflow_logging_config['task'].endswith('chat') or (
mlflow_logging_config['metadata'] is not None and
mlflow_logging_config['metadata'].get('task',
'').endswith('chat')
mlflow_logging_config['metadata'].get('task', '').endswith('chat')
)
if is_chat:
default_input_example = {
Expand Down Expand Up @@ -395,8 +394,10 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
event,
) and self.last_checkpoint_batch != state.timestamp.batch:
is_last_batch = self._is_last_batch(state)
register = self.mlflow_registered_model_name is not None and is_last_batch # Register only on the last batch
upload_to_save_folder = self.save_folder is not None and (not self.final_register_only or not is_last_batch)
register = self.mlflow_registered_model_name is not None and is_last_batch # Register only on the last batch
upload_to_save_folder = self.save_folder is not None and (
not self.final_register_only or not is_last_batch
)
self._save_checkpoint(
state,
logger,
Expand Down
6 changes: 4 additions & 2 deletions tests/a_scripts/inference/test_convert_composer_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,8 @@ def test_huggingface_conversion_callback_interval(
batches_per_epoch = math.ceil(dataset_size / device_batch_size)

checkpointer_callback = HuggingFaceCheckpointer(
save_folder=os.path.join(tmp_path, 'checkpoints') if hf_save_folder else None,
save_folder=os.path.join(tmp_path, 'checkpoints')
if hf_save_folder else None,
save_interval=hf_save_interval,
precision=precision_str,
mlflow_registered_model_name='dummy-registered-name'
Expand All @@ -544,7 +545,8 @@ def test_huggingface_conversion_callback_interval(
save_interval=save_interval,
max_duration=max_duration,
callbacks=[checkpointer_callback],
loggers=[mlflow_logger_mock] if log_to_mlflow or not hf_save_folder else [],
loggers=[mlflow_logger_mock]
if log_to_mlflow or not hf_save_folder else [],
optimizers=optimizer,
save_latest_filename=None,
)
Expand Down

0 comments on commit 10f65f8

Please sign in to comment.