From b9d1d89548e618567d8394b661d65ecd85175a23 Mon Sep 17 00:00:00 2001
From: Saaketh
Date: Wed, 21 Aug 2024 16:46:11 -0400
Subject: [PATCH] yo

---
 llmfoundry/data/finetuning/tasks.py      |  3 ++
 llmfoundry/tokenizers/__init__.py        |  2 +
 llmfoundry/tokenizers/utils.py           | 12 ++++++
 scripts/inference/hf_chat.py             |  3 ++
 tests/data/test_template_tokenization.py |  7 +++-
 tests/tokenizers/test_tiktoken.py        |  4 ++
 tests/tokenizers/test_tokenizer.py       | 50 ++++++++++++++++++++++++
 7 files changed, 80 insertions(+), 1 deletion(-)
 create mode 100644 llmfoundry/tokenizers/utils.py

diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index 7aff80e50f..66c01ddec6 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -65,6 +65,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
     stitch_turns_decoder_only,
     stitch_turns_encoder_decoder,
 )
+from llmfoundry.tokenizers import get_date_string
 # yapf: disable
 from llmfoundry.utils.exceptions import (
     ALLOWED_MESSAGES_KEYS,
@@ -249,11 +250,13 @@ def slice_out_last_turn(
         full_conversation = tokenizer.apply_chat_template(
             messages_through_current_turn,
             tokenize=False,
+            date_string=get_date_string(),
         )
         prompt_with_history = tokenizer.apply_chat_template(
             messages_through_current_turn[:-1],
             tokenize=False,
             add_generation_prompt=True,
+            date_string=get_date_string(),
         )
     except Exception as e:
         raise ChatTemplateError(
diff --git a/llmfoundry/tokenizers/__init__.py b/llmfoundry/tokenizers/__init__.py
index d37c12a555..6c580caf48 100644
--- a/llmfoundry/tokenizers/__init__.py
+++ b/llmfoundry/tokenizers/__init__.py
@@ -3,9 +3,11 @@
 
 from llmfoundry.registry import tokenizers
 from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper
+from llmfoundry.tokenizers.utils import get_date_string
 
 tokenizers.register('tiktoken', func=TiktokenTokenizerWrapper)
 
 __all__ = [
     'TiktokenTokenizerWrapper',
+    'get_date_string',
 ]
diff --git a/llmfoundry/tokenizers/utils.py b/llmfoundry/tokenizers/utils.py
new file mode 100644
index 0000000000..5d3491804d
--- /dev/null
+++ b/llmfoundry/tokenizers/utils.py
@@ -0,0 +1,12 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+import datetime
+
+__all__ = [
+    'get_date_string',
+]
+
+def get_date_string() -> str:
+    """Get the current date string."""
+    return datetime.datetime.now().strftime('%d %b %Y')
diff --git a/scripts/inference/hf_chat.py b/scripts/inference/hf_chat.py
index 89c73e5afc..cd4336c74a 100644
--- a/scripts/inference/hf_chat.py
+++ b/scripts/inference/hf_chat.py
@@ -19,6 +19,7 @@
     TextStreamer,
 )
 
+from llmfoundry.tokenizers import get_date_string
 from llmfoundry.utils.exceptions import ChatTemplateError
 
 DEFAULT_SYSTEM_PROMPT = 'You are a friendly chatbot who aims to be helpful and honest.'
@@ -132,6 +133,7 @@ def _history_as_formatted_str(self) -> str:
                 chat_conversation,
                 tokenize=False,
                 add_generation_prompt=False,
+                date_string=get_date_string(),
             )
         except Exception as e:
             raise ChatTemplateError(
@@ -149,6 +151,7 @@ def turn(self, user_inp: str) -> None:
                 tokenize=True,
                 add_generation_prompt=True,
                 return_tensors='pt',
+                date_string=get_date_string(),
             )
         except Exception as e:
             raise ChatTemplateError(
diff --git a/tests/data/test_template_tokenization.py b/tests/data/test_template_tokenization.py
index 9f44739b6b..d63a5f58fb 100644
--- a/tests/data/test_template_tokenization.py
+++ b/tests/data/test_template_tokenization.py
@@ -11,6 +11,7 @@
     dataset_constructor,
     tokenize_formatted_example,
 )
+from llmfoundry.tokenizers import get_date_string
 from llmfoundry.utils.builders import build_tokenizer
 from llmfoundry.utils.exceptions import (
     ALLOWED_PROMPT_KEYS,
@@ -281,7 +282,11 @@ def test_multi_turn_chat_slicing(tokenizer_name: str, messages_format: bool):
     for prompt, response in templated_prompt_response_turns:
         reconstructed_chat += prompt + response
 
-    full_chat = tok.apply_chat_template(convo, tokenize=False)
+    full_chat = tok.apply_chat_template(
+        convo,
+        tokenize=False,
+        date_string=get_date_string(),
+    )
 
     assert reconstructed_chat == full_chat
 
diff --git a/tests/tokenizers/test_tiktoken.py b/tests/tokenizers/test_tiktoken.py
index 8a61c6124f..97255e9ebb 100644
--- a/tests/tokenizers/test_tiktoken.py
+++ b/tests/tokenizers/test_tiktoken.py
@@ -7,6 +7,7 @@
 import pytest
 import transformers
 
+from llmfoundry.tokenizers import get_date_string
 from llmfoundry.tokenizers.tiktoken import (
     TiktokenTokenizerWrapper,
     bytes_to_unicode,
@@ -516,6 +517,7 @@ def test_chat_formatting(
             dict_chats,
             tokenize=False,
             add_generation_prompt=False,
+            date_string=get_date_string(),
         )
         assert chat_str == MULTI_TURN_CHAT_STRING_NO_SYSTEM_PROMPT[i]
     # Using default system prompt.
@@ -533,6 +535,7 @@ def test_chat_formatting(
             dict_chats,
             tokenize=False,
             add_generation_prompt=False,
+            date_string=get_date_string(),
         )
         assert chat_str == MULTI_TURN_CHAT_STRING_SYSTEM_PROMPT[i]
     for i, dict_chats in enumerate(MULTI_TURN_GENERATE_CHAT_ML):
@@ -540,6 +543,7 @@ def test_chat_formatting(
             dict_chats,
             tokenize=False,
             add_generation_prompt=True,
+            date_string=get_date_string(),
         )
         assert chat_str == MULTI_TURN_GENERATE_STRING[i]
diff --git a/tests/tokenizers/test_tokenizer.py b/tests/tokenizers/test_tokenizer.py
index b4f1846091..44add81112 100644
--- a/tests/tokenizers/test_tokenizer.py
+++ b/tests/tokenizers/test_tokenizer.py
@@ -1,9 +1,13 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
+import pytest
+import torch
 from omegaconf import OmegaConf as om
 from transformers import AutoTokenizer
 
+from llmfoundry.tokenizers.utils import get_date_string
+
 
 def get_config(conf_path: str = 'scripts/train/yamls/pretrain/mpt-125m.yaml'):
     with open(conf_path) as f:
@@ -80,3 +84,49 @@ def test_load_tokenizer():
         )['attention_mask']
         attn_mask_key = [1, 1, 1, 1] + [0] * (tokenizer.model_max_length - 4)
         assert attention_mask == attn_mask_key
+
+
+@pytest.mark.parametrize(
+    'tokenizer_name',
+    [
+        'EleutherAI/gpt-neox-20b',
+        'meta-llama/Meta-Llama-3-8B-Instruct',
+        'meta-llama/Meta-Llama-3.1-8B-Instruct',
+        'meta-llama/Meta-Llama-3.1-70B-Instruct',
+        'meta-llama/Meta-Llama-3.1-405B-Instruct',
+    ],
+)
+@pytest.mark.parametrize('spoof_chat_template', [True, False])
+def test_tokenizer_date_string(tokenizer_name: str, spoof_chat_template: bool):
+    if 'meta-llama' in tokenizer_name:
+        pytest.skip('Model is gated. Skipping test.')
+
+    is_llama_3_1_instruct = 'Meta-Llama-3.1' in tokenizer_name and 'Instruct' in tokenizer_name
+    if is_llama_3_1_instruct and spoof_chat_template:
+        pytest.skip(
+            'Llama 3.1 Instruct models use date_string in chat template, so no need to spoof. Skipping test.',
+        )
+
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+    messages = [{'role': 'system', 'content': ''}]
+    date_string = get_date_string()
+
+    # Manually set a chat template to test if the date_string is being used.
+    if spoof_chat_template:
+        tokenizer.chat_template = "{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{{- \"Today Date: \" + date_string }}\n"
+
+    token_ids = tokenizer.apply_chat_template(
+        messages,
+        add_generation_prompt=True,
+        return_tensors='pt',
+        date_string=date_string,
+    )
+
+    assert isinstance(token_ids, torch.Tensor)
+    decoded_text = tokenizer.decode(token_ids.flatten())
+
+    # Only Llama 3.1 instruct family models should use the current date in their chat templates.
+    if is_llama_3_1_instruct or spoof_chat_template:
+        assert date_string in decoded_text
+    else:
+        assert date_string not in decoded_text
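
A minimal usage sketch of the pattern this patch introduces, assuming only that transformers is installed; the checkpoint name 'meta-llama/Meta-Llama-3.1-8B-Instruct' is illustrative (it is gated, so substitute any tokenizer whose chat template reads date_string). Extra keyword arguments to apply_chat_template are exposed as variables to the Jinja chat template, which is why passing date_string makes Llama-3.1-style templates render today's date instead of their hard-coded default.

    import datetime

    from transformers import AutoTokenizer


    def get_date_string() -> str:
        """Current date in the '%d %b %Y' form Llama 3.1 templates expect, e.g. '21 Aug 2024'."""
        return datetime.datetime.now().strftime('%d %b %Y')


    # Illustrative checkpoint; any tokenizer whose chat template uses `date_string` behaves the same.
    tokenizer = AutoTokenizer.from_pretrained('meta-llama/Meta-Llama-3.1-8B-Instruct')

    prompt = tokenizer.apply_chat_template(
        [{'role': 'user', 'content': 'Hello!'}],
        tokenize=False,
        add_generation_prompt=True,
        date_string=get_date_string(),  # forwarded to the Jinja template as a variable
    )
    print(prompt)  # the rendered system header should now contain today's date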