From 30544f0844d4aca4169aaac6e819939abba52a17 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Thu, 14 Sep 2023 11:45:37 -0700 Subject: [PATCH] Add a callback to write huggingface checkpoints during the training run (#594) --- llmfoundry/callbacks/__init__.py | 13 +- llmfoundry/callbacks/hf_checkpointer.py | 167 ++++++++++++ llmfoundry/utils/builders.py | 6 +- tests/__init__.py | 2 + tests/data_utils.py | 67 +++++ tests/test_dataloader.py | 63 +---- tests/test_hf_conversion_script.py | 345 ++++++++++++++++++++++++ 7 files changed, 596 insertions(+), 67 deletions(-) create mode 100644 llmfoundry/callbacks/hf_checkpointer.py create mode 100644 tests/__init__.py create mode 100644 tests/data_utils.py diff --git a/llmfoundry/callbacks/__init__.py b/llmfoundry/callbacks/__init__.py index 465bdd01d1..62ffcd565c 100644 --- a/llmfoundry/callbacks/__init__.py +++ b/llmfoundry/callbacks/__init__.py @@ -5,6 +5,7 @@ from llmfoundry.callbacks.eval_gauntlet_callback import EvalGauntlet from llmfoundry.callbacks.fdiff_callback import FDiffMetrics from llmfoundry.callbacks.generate_callback import Generate + from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer from llmfoundry.callbacks.model_gauntlet_callback import ModelGauntlet from llmfoundry.callbacks.monolithic_ckpt_callback import \ MonolithicCheckpointSaver @@ -18,7 +19,13 @@ ) from e __all__ = [ - 'FDiffMetrics', 'Generate', 'MonolithicCheckpointSaver', 'GlobalLRScaling', - 'LayerFreezing', 'ScheduledGarbageCollector', 'EvalGauntlet', - 'ModelGauntlet' + 'FDiffMetrics', + 'Generate', + 'MonolithicCheckpointSaver', + 'GlobalLRScaling', + 'LayerFreezing', + 'ScheduledGarbageCollector', + 'EvalGauntlet', + 'ModelGauntlet', + 'HuggingFaceCheckpointer', ] diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py new file mode 100644 index 0000000000..fe3028ab19 --- /dev/null +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -0,0 +1,167 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import contextlib +import json +import logging +import os +import tempfile +from pathlib import Path +from typing import Optional, Union + +import torch +from composer.callbacks.utils import create_interval_scheduler +from composer.core import Callback, Event, State, Time +from composer.core.state import fsdp_state_dict_type_context +from composer.loggers import Logger +from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader +from composer.models import HuggingFaceModel +from composer.utils import dist, format_name_with_dist_and_time, parse_uri +from transformers import PreTrainedTokenizerBase + +from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM +from llmfoundry.utils.huggingface_hub_utils import \ + edit_files_for_hf_compatibility + +log = logging.getLogger(__name__) + + +class HuggingFaceCheckpointer(Callback): + """Save a huggingface formatted checkpoint during training. + + Args: + save_folder (str): Top level folder to save checkpoints to (can be a URI). It is likely that + this would be the same as your save_folder. + save_interval: Union[str, int, Time]: The interval describing how often checkpoints should be + saved. If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`. + Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`, + :attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`. 
+        huggingface_folder_name (str): Folder to save each checkpoint under (can be a format string). Default is ``ba{batch}``.
+        precision (str): The precision to save the model in. Default is ``float32``. Options are ``bfloat16``, ``float16``, or ``float32``.
+        overwrite (bool): Whether to overwrite previous checkpoints.
+    """
+
+    def __init__(
+        self,
+        save_folder: str,
+        save_interval: Union[str, int, Time],
+        huggingface_folder_name: str = 'ba{batch}',
+        precision: str = 'float32',
+        overwrite: bool = False,
+    ):
+        self.backend, self.bucket_name, self.save_dir_format_str = parse_uri(
+            save_folder)
+        self.overwrite = overwrite
+        self.precision = precision
+        self.dtype = {
+            'float32': torch.float32,
+            'float16': torch.float16,
+            'bfloat16': torch.bfloat16,
+        }[precision]
+        self.huggingface_folder_name_fstr = os.path.join(
+            'huggingface', huggingface_folder_name)
+        self.check_interval = create_interval_scheduler(
+            save_interval, include_end_of_training=True)
+        self.upload_to_object_store = (self.backend != '')
+        if self.upload_to_object_store:
+            self.remote_ud = RemoteUploaderDownloader(
+                bucket_uri=f'{self.backend}://{self.bucket_name}',
+                num_concurrent_uploads=4)
+        else:
+            self.remote_ud = None
+
+        self.last_checkpoint_batch: Optional[Time] = None
+
+    def run_event(self, event: Event, state: State, logger: Logger) -> None:
+        # The interval scheduler handles only returning True for the appropriate events
+        if state.get_elapsed_duration() is not None and self.check_interval(
+                state,
+                event) and self.last_checkpoint_batch != state.timestamp.batch:
+            self._save_checkpoint(state, logger)
+        elif event == Event.INIT:
+            if not isinstance(state.model, HuggingFaceModel):
+                raise ValueError(
+                    f'`HuggingFaceCheckpointer` is only compatible with `HuggingFaceModel`s. '
+                    + f'Got {type(state.model)} instead.')
+            if self.upload_to_object_store and self.remote_ud is not None:
+                self.remote_ud.init(state, logger)
+                state.callbacks.append(self.remote_ud)
+
+    def _save_checkpoint(self, state: State, logger: Logger):
+        del logger  # unused
+
+        self.last_checkpoint_batch = state.timestamp.batch
+
+        log.info('Saving HuggingFace formatted checkpoint')
+
+        from transformers.models.auto.configuration_auto import CONFIG_MAPPING
+        CONFIG_MAPPING._extra_content['mpt'] = MPTConfig
+        MPTConfig.register_for_auto_class()
+        MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM')
+
+        assert isinstance(state.model, HuggingFaceModel)
+
+        save_dir = format_name_with_dist_and_time(
+            str(
+                Path(self.save_dir_format_str) /
+                self.huggingface_folder_name_fstr), state.run_name,
+            state.timestamp)
+        dir_context_mgr = tempfile.TemporaryDirectory(
+        ) if self.upload_to_object_store else contextlib.nullcontext(
+            enter_result=save_dir)
+
+        with dir_context_mgr as temp_save_dir:
+            assert isinstance(temp_save_dir,
+                              str)  # pyright doesn't know about enter_result
+
+            with fsdp_state_dict_type_context(state.model.model,
+                                              state_dict_type='full'):
+                state_dict = state.model.model.state_dict()
+
+                # convert the state dict to the requested precision
+                for k, v in state_dict.items():
+                    if isinstance(v, torch.Tensor):
+                        state_dict[k] = v.to(dtype=self.dtype)
+
+            if dist.get_global_rank() == 0:
+                # We raise above if the model is not a HuggingFaceModel, so this assert is safe
+                assert hasattr(state.model.model, 'save_pretrained')
+                state.model.model.save_pretrained(temp_save_dir,
+                                                  state_dict=state_dict)
+
+                if state.model.tokenizer is not None:
+                    assert isinstance(state.model.tokenizer,
+                                      PreTrainedTokenizerBase)
+                    state.model.tokenizer.save_pretrained(temp_save_dir)
+
+                # Only need to edit files for MPT because it has custom code
+                if state.model.model.config.model_type == 'mpt':
+                    edit_files_for_hf_compatibility(temp_save_dir)
+
+                with open(os.path.join(temp_save_dir, 'config.json'), 'r') as f:
+                    edited_config = json.load(f)
+
+                if state.model.model.config.model_type == 'mpt':
+                    edited_config['attn_config']['attn_impl'] = 'torch'
+                    edited_config['init_device'] = 'cpu'
+
+                edited_config['torch_dtype'] = self.precision
+                with open(os.path.join(temp_save_dir, 'config.json'), 'w') as f:
+                    json.dump(edited_config, f, indent=4)
+
+                if self.upload_to_object_store:
+                    assert self.remote_ud is not None
+                    # TODO change to log after other pr
+                    log.info(
+                        f'Uploading HuggingFace formatted checkpoint to {self.backend}://{self.bucket_name}/{save_dir}'
+                    )
+                    for filename in os.listdir(temp_save_dir):
+                        self.remote_ud.upload_file(
+                            state=state,
+                            remote_file_name=os.path.join(save_dir, filename),
+                            file_path=Path(os.path.join(temp_save_dir,
+                                                        filename)),
+                            overwrite=self.overwrite,
+                        )
+
+        dist.barrier()
diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py
index c0eb2a59df..32f7aceea3 100644
--- a/llmfoundry/utils/builders.py
+++ b/llmfoundry/utils/builders.py
@@ -27,8 +27,8 @@
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
 
 from llmfoundry.callbacks import (EvalGauntlet, FDiffMetrics, Generate,
-                                  GlobalLRScaling, LayerFreezing,
-                                  MonolithicCheckpointSaver,
+                                  GlobalLRScaling, HuggingFaceCheckpointer,
+                                  LayerFreezing, MonolithicCheckpointSaver,
                                   ScheduledGarbageCollector)
 from llmfoundry.optim import (DecoupledAdaLRLion, DecoupledClipLion,
                               DecoupledLionW, DecoupledLionW_8bit)
@@ -99,6 +99,8 @@ def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback:
         return
ScheduledGarbageCollector(**kwargs) elif name == 'early_stopper': return EarlyStopper(**kwargs) + elif name == 'hf_checkpointer': + return HuggingFaceCheckpointer(**kwargs) else: raise ValueError(f'Not sure how to build callback: {name}') diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000..f6c1f9f3ab --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/data_utils.py b/tests/data_utils.py new file mode 100644 index 0000000000..075933de7d --- /dev/null +++ b/tests/data_utils.py @@ -0,0 +1,67 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import json +import os +from typing import Optional + + +def make_tiny_ft_dataset( + path: str, + size: int = 4, + add_bad_data_dropped: bool = False, + add_bad_data_error: bool = False, + add_just_bos_eos_pad: bool = False, + pad_token: Optional[str] = None, + start_token: Optional[str] = None, + end_token: Optional[str] = None, +): + good_sample = {'prompt': 'hello', 'response': 'goodbye'} + samples = [good_sample] * size + if add_bad_data_dropped: + if pad_token is None: + raise ValueError( + 'pad_token, start_token, and end_token must be specified if add_bad_data is True' + ) + # empty prompt + samples.append({'prompt': '', 'response': 'goodbye'}) + # empty response + samples.append({'prompt': 'hello', 'response': ''}) + # response just pad + samples.append({'prompt': 'hello', 'response': pad_token}) + # response just pad multiple times + samples.append({'prompt': 'hello', 'response': pad_token * 3}) + + if add_bad_data_error: + # prompt just None + samples.append({ + 'prompt': None, + 'response': 'goodbye' + }) # type: ignore (intentional test) + # response just None + samples.append({ + 'prompt': 'hello', + 'response': None + }) # type: ignore (intentional test) + + if add_just_bos_eos_pad: + if pad_token is None or start_token is None or end_token is None: + raise ValueError( + 'pad_token, start_token, and end_token must be specified if add_just_bos_eos is True' + ) + # prompt just start + samples.append({'prompt': start_token, 'response': 'goodbye'}) + # response just start + samples.append({'prompt': 'hello', 'response': start_token}) + # prompt just end + samples.append({'prompt': end_token, 'response': 'goodbye'}) + # response just end + samples.append({'prompt': 'hello', 'response': end_token}) + # prompt just pad + samples.append({'prompt': pad_token, 'response': 'goodbye'}) + + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, 'w') as _f: + for sample in samples: + _f.write(json.dumps(sample)) + _f.write('\n') diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index f9a281efa7..53549ccfe1 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -1,7 +1,6 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 import contextlib -import json import os import pathlib import shutil @@ -25,6 +24,7 @@ repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) sys.path.append(repo_dir) from scripts.data_prep.convert_dataset_hf import main as main_hf +from tests.data_utils import make_tiny_ft_dataset def get_config(conf_path: str = 'yamls/mpt/125m.yaml'): @@ -279,67 +279,6 @@ def test_finetuning_dataloader(decoder_only_format: bool, break -def make_tiny_ft_dataset( - path: str, - size: int = 4, - add_bad_data_dropped: bool = False, - add_bad_data_error: bool = False, - 
add_just_bos_eos_pad: bool = False, - pad_token: Optional[str] = None, - start_token: Optional[str] = None, - end_token: Optional[str] = None, -): - good_sample = {'prompt': 'hello', 'response': 'goodbye'} - samples = [good_sample] * size - if add_bad_data_dropped: - if pad_token is None: - raise ValueError( - 'pad_token, start_token, and end_token must be specified if add_bad_data is True' - ) - # empty prompt - samples.append({'prompt': '', 'response': 'goodbye'}) - # empty response - samples.append({'prompt': 'hello', 'response': ''}) - # response just pad - samples.append({'prompt': 'hello', 'response': pad_token}) - # response just pad multiple times - samples.append({'prompt': 'hello', 'response': pad_token * 3}) - - if add_bad_data_error: - # prompt just None - samples.append({ - 'prompt': None, - 'response': 'goodbye' - }) # type: ignore (intentional test) - # response just None - samples.append({ - 'prompt': 'hello', - 'response': None - }) # type: ignore (intentional test) - - if add_just_bos_eos_pad: - if pad_token is None or start_token is None or end_token is None: - raise ValueError( - 'pad_token, start_token, and end_token must be specified if add_just_bos_eos is True' - ) - # prompt just start - samples.append({'prompt': start_token, 'response': 'goodbye'}) - # response just start - samples.append({'prompt': 'hello', 'response': start_token}) - # prompt just end - samples.append({'prompt': end_token, 'response': 'goodbye'}) - # response just end - samples.append({'prompt': 'hello', 'response': end_token}) - # prompt just pad - samples.append({'prompt': pad_token, 'response': 'goodbye'}) - - os.makedirs(os.path.dirname(path), exist_ok=True) - with open(path, 'w') as _f: - for sample in samples: - _f.write(json.dumps(sample)) - _f.write('\n') - - @pytest.mark.world_size(2) @pytest.mark.parametrize('dataset_size', [4, 8]) @pytest.mark.parametrize('device_batch_size', [2, 4]) diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py index e16832d803..2a175a04e9 100644 --- a/tests/test_hf_conversion_script.py +++ b/tests/test_hf_conversion_script.py @@ -1,12 +1,15 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import math import os import pathlib import sys from composer import Trainer +from composer.utils import dist, get_device +from llmfoundry.callbacks import HuggingFaceCheckpointer from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM # Add repo root to path so we can import scripts and test it @@ -21,9 +24,134 @@ import transformers from omegaconf import DictConfig from omegaconf import OmegaConf as om +from transformers import PreTrainedModel, PreTrainedTokenizerBase from llmfoundry import COMPOSER_MODEL_REGISTRY +from llmfoundry.data.finetuning import build_finetuning_dataloader +from llmfoundry.utils.builders import build_optimizer, build_tokenizer from scripts.inference.convert_composer_to_hf import convert_composer_to_hf +from tests.data_utils import make_tiny_ft_dataset + + +def check_hf_tokenizer_equivalence(tokenizer1: PreTrainedTokenizerBase, + tokenizer2: PreTrainedTokenizerBase): + """WARNING: Parameters are updated within the check so don't call check_hf_tokenizer_equivalence on the same + + params more than once + + This is a best effort attempt to compare two tokenizers for equivalence + + This is not a perfect test, but it should catch most issues. We first check that the vocab is identical + and that a string is tokenized the same one. 
Then we compare the __dict__ of the tokenizers, but we remove + some keys that are not important for equivalence. See the inline explanations for each one. + """ + if hasattr(tokenizer1, 'vocab') or hasattr(tokenizer2, 'vocab'): + assert tokenizer1.vocab == tokenizer2.vocab + + # we only care about the file and class name, not the full import path + assert str(type(tokenizer1)).split('.')[-2:] == str( + type(tokenizer2)).split('.')[-2:] + + expected_tokenizer_output = tokenizer2( + 'This is some text that should get tokenizer !? @ totallyarealtoken') + actual_tokenizer_output = tokenizer1( + 'This is some text that should get tokenizer !? @ totallyarealtoken') + assert expected_tokenizer_output == actual_tokenizer_output + + # we remove the actual _tokenizer object because it is an instantiated object and so does not pass equality + # the tokenizers are not usable below these pops + if hasattr(tokenizer1, '_tokenizer') or hasattr(tokenizer2, '_tokenizer'): + tokenizer1.__dict__.pop('_tokenizer') + tokenizer2.__dict__.pop('_tokenizer') + + # we remove a couple more objects because they are instantiated objects and so do not pass equality + if hasattr(tokenizer1, 'sp_model') or hasattr(tokenizer2, 'sp_model'): + tokenizer1.__dict__.pop('sp_model') + tokenizer2.__dict__.pop('sp_model') + + if hasattr(tokenizer1, 'tokens_trie') or hasattr(tokenizer2, 'tokens_trie'): + tokenizer1.__dict__.pop('tokens_trie') + tokenizer2.__dict__.pop('tokens_trie') + + # extra key that is not important + if hasattr(tokenizer1, 'deprecation_warnings') or hasattr( + tokenizer2, 'deprecation_warnings'): + tokenizer1.__dict__.pop('deprecation_warnings') + tokenizer2.__dict__.pop('deprecation_warnings') + + # name_or_path will be the path that the tokenizer was loaded from, which will just be a temporary directory for + # the reloaded tokenizer, so we remove it and don't compare it between the two tokenizers + tokenizer1.__dict__.pop('name_or_path') + tokenizer2.__dict__.pop('name_or_path') + tokenizer1.init_kwargs.pop('name_or_path', None) + tokenizer2.init_kwargs.pop('name_or_path', None) + + # The init_kwargs are not always the same between initial load and reload, even though the tokenizers are the same + # and have the attributes set correctly. 
This section removes the keys that are different, only checking for equality if they + # are present in both tokenizers + model_max_length_1 = tokenizer1.init_kwargs.get('model_max_length', None) + model_max_length_2 = tokenizer2.init_kwargs.get('model_max_length', None) + if model_max_length_1 is not None and model_max_length_2 is not None: + assert model_max_length_1 == model_max_length_2 + tokenizer1.__dict__['init_kwargs'].pop('model_max_length', None) + tokenizer2.__dict__['init_kwargs'].pop('model_max_length', None) + + spaces_1 = tokenizer1.init_kwargs.get('clean_up_tokenization_spaces', None) + spaces_2 = tokenizer2.init_kwargs.get('clean_up_tokenization_spaces', None) + if spaces_1 is not None and spaces_2 is not None: + assert spaces_1 == spaces_2 + tokenizer1.__dict__['init_kwargs'].pop('clean_up_tokenization_spaces', None) + tokenizer2.__dict__['init_kwargs'].pop('clean_up_tokenization_spaces', None) + + tokenizer1.__dict__['init_kwargs'].pop('special_tokens_map_file', None) + tokenizer2.__dict__['init_kwargs'].pop('special_tokens_map_file', None) + + # tokenizer.init_kwargs['tokenizer_file'] is unset when the tokenizer does not specify it, but is set to + # None when you save and reload, so here we just check that its the same if it is present in both tokenizers. + tokenizer_file_1 = tokenizer1.init_kwargs.get('tokenizer_file', None) + tokenizer_file_2 = tokenizer2.init_kwargs.get('tokenizer_file', None) + if tokenizer_file_1 is not None or tokenizer_file_2 is not None: + assert tokenizer_file_1 == tokenizer_file_2 + + tokenizer1.__dict__['init_kwargs'].pop('tokenizer_file', None) + tokenizer2.__dict__['init_kwargs'].pop('tokenizer_file', None) + tokenizer1.__dict__['init_kwargs'].pop('vocab_file', None) + tokenizer2.__dict__['init_kwargs'].pop('vocab_file', None) + + # vocab_file will be the path that the tokenizer was loaded from, which will just be a temporary directory for + # the reloaded tokenizer, so we remove it and don't compare it between the two tokenizers + tokenizer1.__dict__.pop('vocab_file', None) + tokenizer2.__dict__.pop('vocab_file', None) + tokenizer1.__dict__.pop('special_tokens_map_file', None) + tokenizer2.__dict__.pop('special_tokens_map_file', None) + + # The tokenizer name is changed in transformers 4.31 when changing the tokenizer mapping, so we remove it and compare + # if necessary. Checks whether the names are subsets of each other. 
+ tokenizer1_name = tokenizer1.__dict__['init_kwargs'].get( + 'auto_map', {}).get('AutoTokenizer', [None])[0] + tokenizer2_name = tokenizer2.__dict__['init_kwargs'].get( + 'auto_map', {}).get('AutoTokenizer', [None])[0] + if tokenizer1_name is not None and tokenizer2_name is not None: + assert tokenizer1_name in tokenizer2_name or tokenizer2_name in tokenizer1_name + tokenizer1.__dict__['init_kwargs'].pop('auto_map', None) + tokenizer2.__dict__['init_kwargs'].pop('auto_map', None) + + assert tokenizer1.__dict__ == tokenizer2.__dict__ + + +def check_hf_model_equivalence(model1: PreTrainedModel, + model2: PreTrainedModel): + expected_model_config_dict = model1.config.to_dict() + new_model_config_dict = model2.config.to_dict() + + # _name_or_path is different depending on whether the model was loaded from disk or the hub, + # so we remove it + expected_model_config_dict.pop('_name_or_path') + new_model_config_dict.pop('_name_or_path') + assert expected_model_config_dict == new_model_config_dict + assert all( + torch.equal(p1.cpu(), p2.cpu()) + for p1, p2 in zip(model1.parameters(), model2.parameters())) def delete_transformers_cache(): @@ -48,6 +176,223 @@ def get_config( return cast(DictConfig, test_cfg) +@pytest.mark.world_size(2) +@pytest.mark.gpu +@pytest.mark.parametrize('model', ['mpt', 'neo', 'llama2']) +@pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded']) +def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, + fsdp_state_dict_type: str): + delete_transformers_cache() + + dist.initialize_dist(get_device('gpu')) + + max_seq_len = 16 + save_interval_batches = 2 + huggingface_save_interval_batches = 3 + device_batch_size = 1 + dataset_size = 14 + max_duration_batches = 7 + precision_str = 'bfloat16' + precision = torch.bfloat16 + + checkpointer_callback = HuggingFaceCheckpointer( + save_folder=os.path.join(tmp_path, 'checkpoints'), + save_interval=f'{huggingface_save_interval_batches}ba', + precision=precision_str, + ) + + # get small version of each model + model_cfg = None + tokenizer_name = None + if model == 'mpt': + model_cfg = { + 'name': 'mpt_causal_lm', + 'init_device': 'cpu', + 'd_model': 128, + 'n_heads': 2, + 'n_layers': 2, + 'expansion_ratio': 4, + 'max_seq_len': max_seq_len, + 'vocab_size': 50368, + 'attn_config': { + 'attn_impl': 'torch', + }, + 'loss_fn': 'torch_crossentropy', + } + tokenizer_name = 'EleutherAI/gpt-neox-20b' + elif model == 'neo': + model_cfg = { + 'name': 'hf_causal_lm', + 'pretrained_model_name_or_path': 'EleutherAI/gpt-neo-125M', + 'config_overrides': { + 'max_position_embeddings': max_seq_len, + 'hidden_size': 36, + }, + 'pretrained': False, + 'init_device': 'cpu', + } + tokenizer_name = 'EleutherAI/gpt-neo-125M' + elif model == 'llama2': + if 'HUGGING_FACE_HUB_TOKEN' not in os.environ: + pytest.skip( + 'The CI cluster does not have access to the Llama models, so skip this test.' 
+ ) + model_cfg = { + 'name': 'hf_causal_lm', + 'pretrained_model_name_or_path': 'meta-llama/Llama-2-7b-hf', + 'config_overrides': { + 'num_hidden_layers': 2, + 'hidden_size': 32, + 'intermediate_size': 64, + }, + 'use_auth_token': True, + 'pretrained': False, + 'init_device': 'cpu', + } + tokenizer_name = 'meta-llama/Llama-2-7b-hf' + else: + raise ValueError(f'Unknown model {model}') + assert model_cfg is not None + assert tokenizer_name is not None + model_cfg = om.create(model_cfg) + + fsdp_config = { + 'sharding_strategy': 'FULL_SHARD', + 'mixed_precision': 'PURE', + 'activation_checkpointing': False, + 'activation_checkpointing_reentrant': False, + 'activation_cpu_offload': False, + 'limit_all_gathers': True, + 'state_dict_type': fsdp_state_dict_type, + } + + tokenizer = transformers.AutoTokenizer.from_pretrained( + tokenizer_name, use_auth_token=model == 'llama2') + + tiny_dataset_folder_path = os.path.join(os.getcwd(), 'test-ift-data-small') + tiny_dataset_path = os.path.join(tiny_dataset_folder_path, 'train.jsonl') + if dist.get_global_rank() == 0: + make_tiny_ft_dataset(path=tiny_dataset_path, size=dataset_size) + + dataloader_cfg = { + 'name': 'finetuning', + 'dataset': { + 'hf_name': tiny_dataset_folder_path, + 'split': 'train', + 'max_seq_len': max_seq_len, + 'decoder_only_format': True, + 'allow_pad_trimming': False, + 'packing_ratio': None, + 'shuffle': True, + }, + 'drop_last': False, + 'num_workers': 4, + 'pin_memory': False, + 'prefetch_factor': 2, + 'persistent_workers': False, + 'timeout': 0 + } + + dataloader_cfg = om.create(dataloader_cfg) + + tokenizer = build_tokenizer( + tokenizer_name=tokenizer_name, + tokenizer_kwargs={'model_max_length': max_seq_len}, + ) + + train_dataloader = build_finetuning_dataloader( + dataloader_cfg, + tokenizer, + device_batch_size, + ) + + original_model = COMPOSER_MODEL_REGISTRY[model_cfg['name']](model_cfg, + tokenizer) + + optimizer_config = { + 'name': 'decoupled_adamw', + 'lr': 6e-4, + 'betas': [0.9, 0.95], + 'eps': 1e-8, + 'weight_decay': 0.0, + } + optimizer_name = optimizer_config.pop('name') + optimizer = build_optimizer(original_model, optimizer_name, + optimizer_config) + + trainer = Trainer( + model=original_model, + device='gpu', + fsdp_config=fsdp_config, + train_dataloader=train_dataloader, + save_folder=os.path.join(tmp_path, 'checkpoints'), + save_interval=f'{save_interval_batches}ba', + max_duration=f'{max_duration_batches}ba', + callbacks=[checkpointer_callback], + optimizers=optimizer, + save_latest_filename=None, + ) + trainer.fit() + + # summon full params to check equivalence + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + with FSDP.summon_full_params(trainer.state.model, + writeback=False, + recurse=True): + loaded_model = None + loaded_tokenizer = None + # Only rank zero is saving the huggingface checkpoints, so only check + # for equivalence on rank zero + if dist.get_global_rank() == 0: + normal_checkpoints = [ + name + for name in os.listdir(os.path.join(tmp_path, 'checkpoints')) + if name != 'huggingface' + ] + huggingface_checkpoints = [ + name for name in os.listdir( + os.path.join(tmp_path, 'checkpoints', 'huggingface')) + ] + assert len(normal_checkpoints) == math.ceil(max_duration_batches / + save_interval_batches) + assert len(huggingface_checkpoints) == math.ceil( + max_duration_batches / huggingface_save_interval_batches) + + # Load the last huggingface checkpoint + loaded_model = transformers.AutoModelForCausalLM.from_pretrained( + os.path.join(tmp_path, 'checkpoints', 
'huggingface', + f'ba{max_duration_batches}'), + trust_remote_code=True, + ) + + # Check that the loaded model has the correct precision, and then set it back + # to the original for the equivalence check + assert loaded_model.config.torch_dtype == precision + loaded_model.config.torch_dtype = original_model.model.config.torch_dtype + + if model == 'mpt': + # Check that we have correctly set these attributes, and then set them back + # to the original for the equivalence check + assert loaded_model.config.attn_config['attn_impl'] == 'torch' + assert loaded_model.config.init_device == 'cpu' + loaded_model.config.attn_config[ + 'attn_impl'] = original_model.model.config.attn_config[ + 'attn_impl'] + loaded_model.config.init_device = original_model.model.config.init_device + + loaded_tokenizer = transformers.AutoTokenizer.from_pretrained( + os.path.join(tmp_path, 'checkpoints', 'huggingface', + f'ba{max_duration_batches}'), + trust_remote_code=True, + ) + + check_hf_model_equivalence(trainer.state.model.model.to(precision), + loaded_model) + check_hf_tokenizer_equivalence(tokenizer, loaded_tokenizer) + + delete_transformers_cache() + + @pytest.mark.parametrize('model', ['mpt', 'neo', 'llama2']) def test_convert_and_generate(model: str, tmp_path: pathlib.Path): delete_transformers_cache()
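
Usage note (not part of the patch): a minimal sketch of how the new callback might be wired into a Composer Trainer and how the exported checkpoint can be reloaded, mirroring test_huggingface_conversion_callback above. The paths, the save interval, and the `model` / `train_dataloader` objects are illustrative placeholders rather than values from this PR; the callback is also reachable from training configs through the new 'hf_checkpointer' entry in build_callback.

    import transformers
    from composer import Trainer

    from llmfoundry.callbacks import HuggingFaceCheckpointer

    # Write a HuggingFace-format checkpoint every 1000 batches (placeholder interval).
    hf_checkpointer = HuggingFaceCheckpointer(
        save_folder='./checkpoints',  # local path or object store URI
        save_interval='1000ba',
        precision='bfloat16',  # 'float32' (default), 'float16', or 'bfloat16'
    )

    # `model` must be a composer HuggingFaceModel (e.g. ComposerMPTCausalLM) and
    # `train_dataloader` any compatible dataloader; both are assumed to be built
    # elsewhere, for example as in the test above.
    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        max_duration='3000ba',
        callbacks=[hf_checkpointer],
    )
    trainer.fit()

    # Checkpoints land under <save_folder>/huggingface/<huggingface_folder_name>,
    # so the last one here is ./checkpoints/huggingface/ba3000; it loads directly
    # with transformers (trust_remote_code is needed for the MPT custom code).
    loaded = transformers.AutoModelForCausalLM.from_pretrained(
        './checkpoints/huggingface/ba3000', trust_remote_code=True)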