diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index aa3beda513..3050529a5a 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -4,13 +4,14 @@ import contextlib import copy import logging +import math import os import tempfile from pathlib import Path from typing import Optional, Union import torch -from composer.core import Callback, Event, State, Time +from composer.core import Callback, Event, State, Time, TimeUnit from composer.core.state import fsdp_state_dict_type_context from composer.loggers import Logger, MLFlowLogger from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader @@ -83,6 +84,13 @@ def __init__( self.huggingface_folder_name_fstr = os.path.join( 'huggingface', huggingface_folder_name) + + if isinstance(save_interval, str): + save_interval = Time.from_timestring(save_interval) + if isinstance(save_interval, int): + save_interval = Time(save_interval, TimeUnit.EPOCH) + + self.save_interval = save_interval self.check_interval = create_interval_scheduler( save_interval, include_end_of_training=True) self.upload_to_object_store = (self.backend != '') @@ -128,6 +136,21 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set( '5GB') + def _is_last_batch(self, state: State): + elapsed_duration = state.get_elapsed_duration() + if elapsed_duration is not None and elapsed_duration >= 1.0: + return True + + assert state.max_duration is not None # for pyright + # If the save interval is specified as 1dur, and the max duration is in epoch units + # we need a special case to identify we are on the last batch and should write the mlflow checkpoint + if self.save_interval.unit == TimeUnit.DURATION and self.save_interval.value == 1 and state.max_duration.unit == TimeUnit.EPOCH: + assert state.dataloader_len is not None # for pyright + return int(state.timestamp.batch) % math.ceil( + state.max_duration.value * state.dataloader_len) == 0 + + return False + def _save_checkpoint(self, state: State, logger: Logger): del logger # unused @@ -224,8 +247,8 @@ def _save_checkpoint(self, state: State, logger: Logger): overwrite=self.overwrite, ) - elapsed_duration = state.get_elapsed_duration() - if self.mlflow_registered_model_name is not None and elapsed_duration is not None and elapsed_duration >= 1.0: + if self.mlflow_registered_model_name and self._is_last_batch( + state): components = {'model': new_model_instance} if original_tokenizer is not None: components['tokenizer'] = original_tokenizer diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py index fcb2cc3a7e..d2f203d3a0 100644 --- a/tests/test_hf_conversion_script.py +++ b/tests/test_hf_conversion_script.py @@ -251,25 +251,30 @@ def test_callback_inits_with_defaults(): @pytest.mark.parametrize('model', ['mpt', 'neo', 'llama2']) @pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded', None]) @pytest.mark.parametrize('log_to_mlflow', [True, False]) +@pytest.mark.parametrize( + 'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints', + [('3ba', '2ba', '7ba', 3, 4), ('1dur', '2ba', '1ep', 1, 4)]) def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, fsdp_state_dict_type: Optional[str], - log_to_mlflow: bool): + log_to_mlflow: bool, + hf_save_interval: str, + save_interval: str, max_duration: str, + expected_hf_checkpoints: int, + expected_normal_checkpoints: int): delete_transformers_cache() dist.initialize_dist(get_device('gpu')) max_seq_len = 16 - save_interval_batches = 2 - huggingface_save_interval_batches = 3 device_batch_size = 1 dataset_size = 14 - max_duration_batches = 7 precision_str = 'bfloat16' precision = torch.bfloat16 + batches_per_epoch = math.ceil(dataset_size / (device_batch_size * 2)) checkpointer_callback = HuggingFaceCheckpointer( save_folder=os.path.join(tmp_path, 'checkpoints'), - save_interval=f'{huggingface_save_interval_batches}ba', + save_interval=hf_save_interval, precision=precision_str, mlflow_registered_model_name='dummy-registered-name' if log_to_mlflow else None, @@ -405,8 +410,8 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, fsdp_config=fsdp_config if fsdp_state_dict_type is not None else None, train_dataloader=train_dataloader, save_folder=os.path.join(tmp_path, 'checkpoints'), - save_interval=f'{save_interval_batches}ba', - max_duration=f'{max_duration_batches}ba', + save_interval=save_interval, + max_duration=max_duration, callbacks=[checkpointer_callback], loggers=[mlflow_logger_mock] if log_to_mlflow else [], optimizers=optimizer, @@ -442,15 +447,13 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, name for name in os.listdir( os.path.join(tmp_path, 'checkpoints', 'huggingface')) ] - assert len(normal_checkpoints) == math.ceil(max_duration_batches / - save_interval_batches) - assert len(huggingface_checkpoints) == math.ceil( - max_duration_batches / huggingface_save_interval_batches) + assert len(normal_checkpoints) == expected_normal_checkpoints + assert len(huggingface_checkpoints) == expected_hf_checkpoints # Load the last huggingface checkpoint loaded_model = transformers.AutoModelForCausalLM.from_pretrained( os.path.join(tmp_path, 'checkpoints', 'huggingface', - f'ba{max_duration_batches}'), + f'ba{batches_per_epoch}'), trust_remote_code=True, ) @@ -471,7 +474,7 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, loaded_tokenizer = transformers.AutoTokenizer.from_pretrained( os.path.join(tmp_path, 'checkpoints', 'huggingface', - f'ba{max_duration_batches}'), + f'ba{batches_per_epoch}'), trust_remote_code=True, )