diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 7aff80e50f..801813b3ff 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -65,6 +65,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: stitch_turns_decoder_only, stitch_turns_encoder_decoder, ) +from llmfoundry.tokenizers import get_date_string # yapf: disable from llmfoundry.utils.exceptions import ( ALLOWED_MESSAGES_KEYS, @@ -249,11 +250,13 @@ def slice_out_last_turn( full_conversation = tokenizer.apply_chat_template( messages_through_current_turn, tokenize=False, + date_string=get_date_string(), ) prompt_with_history = tokenizer.apply_chat_template( messages_through_current_turn[:-1], tokenize=False, add_generation_prompt=True, + date_string=get_date_string(), ) except Exception as e: raise ChatTemplateError( @@ -319,10 +322,24 @@ def _tokenize_with_bos_removal( ) # Remove the BOS token from the start of the labels if it was automatically added - if hasattr(tokenizer, 'add_bos_token') and tokenizer.add_bos_token: - if tokenizer.bos_token_id is not None and tokenized_sample['labels'][ - 0] == tokenizer.bos_token_id: - tokenized_sample['labels'] = tokenized_sample['labels'][1:] + # Unfortunately if the tokenizer is PretrainedTokenizerFast, as llama3 is, it does not provide + # an add_bos_token attr that we can check explicitly, so instead we rely on checking if both the + # text and the text_target start with bos_token_id to determine whether to strip bos. + has_bos_token = hasattr( + tokenizer, + 'bos_token_id', + ) and tokenizer.bos_token_id is not None + input_ids_starts_with_bos = False + labels_starts_with_bos = False + if has_bos_token and len( + tokenized_sample['input_ids'], + ) > 0 and len(tokenized_sample['labels']) > 0: + input_ids_starts_with_bos = tokenized_sample['input_ids'][ + 0] == tokenizer.bos_token_id + labels_starts_with_bos = tokenized_sample['labels'][ + 0] == tokenizer.bos_token_id + if input_ids_starts_with_bos and labels_starts_with_bos: + tokenized_sample['labels'] = tokenized_sample['labels'][1:] return tokenized_sample @@ -642,6 +659,9 @@ def __init__( self.max_seq_len = max_seq_len self.packing_ratio = packing_ratio + def tokenize_example(self, example: Example) -> TokenizedExample: + return tokenize_formatted_example(example, self.tokenizer) + # How to process a sample def __getitem__(self, idx: int) -> dict[str, Any]: sample = super().__getitem__(idx) @@ -670,7 +690,7 @@ def __getitem__(self, idx: int) -> dict[str, Any]: ) # Convert to latest format by wrapping sample as a "turn" return {'turns': [sample]} - return tokenize_formatted_example(sample, tokenizer=self.tokenizer) + return self.tokenize_example(sample) def state_dict(self, num_samples: int, from_beginning: bool) -> dict[str, Any]: diff --git a/llmfoundry/models/hf/hf_base.py b/llmfoundry/models/hf/hf_base.py index 7dd92c7224..6b693f2d21 100644 --- a/llmfoundry/models/hf/hf_base.py +++ b/llmfoundry/models/hf/hf_base.py @@ -395,6 +395,9 @@ def build_inner_model( pretrained_lora_id_or_path, ) + if prepare_for_fsdp: + cls.prepare_inner_model(model, init_device) + return model def get_peft_config(self, peft_config_dict: dict[str, Any]) -> 'PeftConfig': diff --git a/llmfoundry/tokenizers/__init__.py b/llmfoundry/tokenizers/__init__.py index d37c12a555..6c580caf48 100644 --- a/llmfoundry/tokenizers/__init__.py +++ b/llmfoundry/tokenizers/__init__.py @@ -3,9 +3,11 @@ from llmfoundry.registry import tokenizers from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper +from llmfoundry.tokenizers.utils import get_date_string tokenizers.register('tiktoken', func=TiktokenTokenizerWrapper) __all__ = [ 'TiktokenTokenizerWrapper', + 'get_date_string', ] diff --git a/llmfoundry/tokenizers/utils.py b/llmfoundry/tokenizers/utils.py new file mode 100644 index 0000000000..c087076771 --- /dev/null +++ b/llmfoundry/tokenizers/utils.py @@ -0,0 +1,13 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import datetime + +__all__ = [ + 'get_date_string', +] + + +def get_date_string() -> str: + """Get the current date string.""" + return datetime.datetime.now().strftime('%d %b %Y') diff --git a/scripts/inference/hf_chat.py b/scripts/inference/hf_chat.py index 89c73e5afc..cd4336c74a 100644 --- a/scripts/inference/hf_chat.py +++ b/scripts/inference/hf_chat.py @@ -19,6 +19,7 @@ TextStreamer, ) +from llmfoundry.tokenizers import get_date_string from llmfoundry.utils.exceptions import ChatTemplateError DEFAULT_SYSTEM_PROMPT = 'You are a friendly chatbot who aims to be helpful and honest.' @@ -132,6 +133,7 @@ def _history_as_formatted_str(self) -> str: chat_conversation, tokenize=False, add_generation_prompt=False, + date_string=get_date_string(), ) except Exception as e: raise ChatTemplateError( @@ -149,6 +151,7 @@ def turn(self, user_inp: str) -> None: tokenize=True, add_generation_prompt=True, return_tensors='pt', + date_string=get_date_string(), ) except Exception as e: raise ChatTemplateError( diff --git a/setup.py b/setup.py index 229efe5b5b..a57a318119 100644 --- a/setup.py +++ b/setup.py @@ -56,7 +56,7 @@ 'mlflow>=2.14.1,<2.16', 'accelerate>=0.25,<0.34', # for HF inference `device_map` 'transformers>=4.43.2,<4.44', - 'mosaicml-streaming>=0.8.0,<0.9', + 'mosaicml-streaming>=0.8.1,<0.9', 'torch>=2.3.0,<2.4', 'datasets>=2.19,<2.20', 'fsspec==2023.6.0', # newer version results in a bug in datasets that duplicates data @@ -87,7 +87,7 @@ 'pyright==1.1.256', 'toml>=0.10.2,<0.11', 'packaging>=21,<25', - 'hf_transfer==0.1.3', + 'hf_transfer==0.1.8', ] extra_deps['databricks'] = [ diff --git a/tests/data/test_template_tokenization.py b/tests/data/test_template_tokenization.py index 9f44739b6b..fdf7233115 100644 --- a/tests/data/test_template_tokenization.py +++ b/tests/data/test_template_tokenization.py @@ -11,6 +11,7 @@ dataset_constructor, tokenize_formatted_example, ) +from llmfoundry.tokenizers import get_date_string from llmfoundry.utils.builders import build_tokenizer from llmfoundry.utils.exceptions import ( ALLOWED_PROMPT_KEYS, @@ -210,10 +211,27 @@ def test_tokenize_instruct_example_well_formed(): @pytest.mark.parametrize( 'tokenizer_name', - ['EleutherAI/gpt-neox-20b', 'HuggingFaceH4/zephyr-7b-beta', 't5-base'], + [ + 'EleutherAI/gpt-neox-20b', + 'HuggingFaceH4/zephyr-7b-beta', + 't5-base', + 'meta-llama/Meta-Llama-3.1-8B-Instruct', + ], ) @pytest.mark.parametrize('messages_format', [True, False]) -def test_multi_turn_chat_slicing(tokenizer_name: str, messages_format: bool): +@pytest.mark.parametrize('use_date_string', [True, False]) +def test_multi_turn_chat_slicing( + tokenizer_name: str, + messages_format: bool, + use_date_string: bool, +): + if 'meta-llama' in tokenizer_name: + pytest.skip('Model is gated. Skipping test.') + is_llama_3_1_instruct = 'Meta-Llama-3.1' in tokenizer_name and 'Instruct' in tokenizer_name + if is_llama_3_1_instruct and use_date_string: + pytest.skip( + 'Llama 3.1 Instruct models use date_string in chat template already. Skipping test.', + ) if messages_format: convo = [ { @@ -272,6 +290,10 @@ def test_multi_turn_chat_slicing(tokenizer_name: str, messages_format: bool): tok = transformers.AutoTokenizer.from_pretrained(tokenizer_name) + # Manually set a chat template to test if the date_string is being used. + if use_date_string: + tok.chat_template = "{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{{- \"Today Date: \" + date_string }}\n" + templated_prompt_response_turns = _slice_chat_formatted_example( example, tok, @@ -281,9 +303,19 @@ def test_multi_turn_chat_slicing(tokenizer_name: str, messages_format: bool): for prompt, response in templated_prompt_response_turns: reconstructed_chat += prompt + response - full_chat = tok.apply_chat_template(convo, tokenize=False) + date_string = get_date_string() + full_chat = tok.apply_chat_template( + convo, + tokenize=False, + date_string=date_string, + ) assert reconstructed_chat == full_chat + if is_llama_3_1_instruct or use_date_string: + assert date_string in full_chat + else: + assert date_string not in full_chat + def test_fail_chat_template(): convo = [ diff --git a/tests/models/hf/test_hf_base.py b/tests/models/hf/test_hf_base.py new file mode 100644 index 0000000000..4b0fb34e53 --- /dev/null +++ b/tests/models/hf/test_hf_base.py @@ -0,0 +1,25 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from llmfoundry.models.hf.hf_base import BaseHuggingFaceModel + + +def test_build_inner_model_fsdp(): + model = BaseHuggingFaceModel.build_inner_model( + pretrained_model_name_or_path='codellama/CodeLlama-7b-hf', + pretrained_lora_id_or_path=None, + trust_remote_code=False, + init_device='cpu', + use_flash_attention_2=False, + use_auth_token=False, + config_overrides={ + 'num_hidden_layers': 2, + 'hidden_size': 32, + 'intermediate_size': 64, + }, + load_in_8bit=False, + pretrained=False, + prepare_for_fsdp=True, + ) + + assert model.fsdp_wrap_fn(model.model.layers[0]) diff --git a/tests/tokenizers/test_tiktoken.py b/tests/tokenizers/test_tiktoken.py index 8a61c6124f..97255e9ebb 100644 --- a/tests/tokenizers/test_tiktoken.py +++ b/tests/tokenizers/test_tiktoken.py @@ -7,6 +7,7 @@ import pytest import transformers +from llmfoundry.tokenizers import get_date_string from llmfoundry.tokenizers.tiktoken import ( TiktokenTokenizerWrapper, bytes_to_unicode, @@ -516,6 +517,7 @@ def test_chat_formatting( dict_chats, tokenize=False, add_generation_prompt=False, + date_string=get_date_string(), ) assert chat_str == MULTI_TURN_CHAT_STRING_NO_SYSTEM_PROMPT[i] # Using default system prompt. @@ -533,6 +535,7 @@ def test_chat_formatting( dict_chats, tokenize=False, add_generation_prompt=False, + date_string=get_date_string(), ) assert chat_str == MULTI_TURN_CHAT_STRING_SYSTEM_PROMPT[i] for i, dict_chats in enumerate(MULTI_TURN_GENERATE_CHAT_ML): @@ -540,6 +543,7 @@ def test_chat_formatting( dict_chats, tokenize=False, add_generation_prompt=True, + date_string=get_date_string(), ) assert chat_str == MULTI_TURN_GENERATE_STRING[i] diff --git a/tests/tokenizers/test_tokenizer.py b/tests/tokenizers/test_tokenizer.py index b4f1846091..d42f810214 100644 --- a/tests/tokenizers/test_tokenizer.py +++ b/tests/tokenizers/test_tokenizer.py @@ -1,9 +1,13 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import pytest +import torch from omegaconf import OmegaConf as om from transformers import AutoTokenizer +from llmfoundry.tokenizers.utils import get_date_string + def get_config(conf_path: str = 'scripts/train/yamls/pretrain/mpt-125m.yaml'): with open(conf_path) as f: @@ -80,3 +84,49 @@ def test_load_tokenizer(): )['attention_mask'] attn_mask_key = [1, 1, 1, 1] + [0] * (tokenizer.model_max_length - 4) assert attention_mask == attn_mask_key + + +@pytest.mark.parametrize( + 'tokenizer_name', + [ + 'EleutherAI/gpt-neox-20b', + 'meta-llama/Meta-Llama-3-8B-Instruct', + 'meta-llama/Meta-Llama-3.1-8B-Instruct', + 'meta-llama/Meta-Llama-3.1-70B-Instruct', + 'meta-llama/Meta-Llama-3.1-405B-Instruct', + ], +) +@pytest.mark.parametrize('use_date_string', [True, False]) +def test_tokenizer_date_string(tokenizer_name: str, use_date_string: bool): + if 'meta-llama' in tokenizer_name: + pytest.skip('Model is gated. Skipping test.') + + is_llama_3_1_instruct = 'Meta-Llama-3.1' in tokenizer_name and 'Instruct' in tokenizer_name + if is_llama_3_1_instruct and use_date_string: + pytest.skip( + 'Llama 3.1 Instruct models use date_string in chat template already. Skipping test.', + ) + + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + messages = [{'role': 'system', 'content': ''}] + date_string = get_date_string() + + # Manually set a chat template to test if the date_string is being used. + if use_date_string: + tokenizer.chat_template = "{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{{- \"Today Date: \" + date_string }}\n" + + token_ids = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + return_tensors='pt', + date_string=date_string, + ) + + assert isinstance(token_ids, torch.Tensor) + decoded_text = tokenizer.decode(token_ids.flatten()) + + # Only Llama 3.1 instruct family models should use the current date in their chat templates. + if is_llama_3_1_instruct or use_date_string: + assert date_string in decoded_text + else: + assert date_string not in decoded_text