From e235f4200083413f854cf4209c84089fa2d44c6e Mon Sep 17 00:00:00 2001
From: Saaketh Narayan
Date: Wed, 21 Aug 2024 22:24:12 -0700
Subject: [PATCH] Add date_string when applying tokenizer chat template (#1474)

---
 llmfoundry/data/finetuning/tasks.py      |  3 ++
 llmfoundry/tokenizers/__init__.py        |  2 +
 llmfoundry/tokenizers/utils.py           | 13 ++++++
 scripts/inference/hf_chat.py             |  3 ++
 tests/data/test_template_tokenization.py | 38 ++++++++++++++++--
 tests/tokenizers/test_tiktoken.py        |  4 ++
 tests/tokenizers/test_tokenizer.py       | 50 ++++++++++++++++++++++++
 7 files changed, 110 insertions(+), 3 deletions(-)
 create mode 100644 llmfoundry/tokenizers/utils.py

diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index e42c07b2b1..aaaa5e145a 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -65,6 +65,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
     stitch_turns_decoder_only,
     stitch_turns_encoder_decoder,
 )
+from llmfoundry.tokenizers import get_date_string
 # yapf: disable
 from llmfoundry.utils.exceptions import (
     ALLOWED_MESSAGES_KEYS,
@@ -249,11 +250,13 @@ def slice_out_last_turn(
             full_conversation = tokenizer.apply_chat_template(
                 messages_through_current_turn,
                 tokenize=False,
+                date_string=get_date_string(),
             )
             prompt_with_history = tokenizer.apply_chat_template(
                 messages_through_current_turn[:-1],
                 tokenize=False,
                 add_generation_prompt=True,
+                date_string=get_date_string(),
             )
         except Exception as e:
             raise ChatTemplateError(
diff --git a/llmfoundry/tokenizers/__init__.py b/llmfoundry/tokenizers/__init__.py
index d37c12a555..6c580caf48 100644
--- a/llmfoundry/tokenizers/__init__.py
+++ b/llmfoundry/tokenizers/__init__.py
@@ -3,9 +3,11 @@
 
 from llmfoundry.registry import tokenizers
 from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper
+from llmfoundry.tokenizers.utils import get_date_string
 
 tokenizers.register('tiktoken', func=TiktokenTokenizerWrapper)
 
 __all__ = [
     'TiktokenTokenizerWrapper',
+    'get_date_string',
 ]
diff --git a/llmfoundry/tokenizers/utils.py b/llmfoundry/tokenizers/utils.py
new file mode 100644
index 0000000000..c087076771
--- /dev/null
+++ b/llmfoundry/tokenizers/utils.py
@@ -0,0 +1,13 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+import datetime
+
+__all__ = [
+    'get_date_string',
+]
+
+
+def get_date_string() -> str:
+    """Get the current date string."""
+    return datetime.datetime.now().strftime('%d %b %Y')
diff --git a/scripts/inference/hf_chat.py b/scripts/inference/hf_chat.py
index 89c73e5afc..cd4336c74a 100644
--- a/scripts/inference/hf_chat.py
+++ b/scripts/inference/hf_chat.py
@@ -19,6 +19,7 @@
     TextStreamer,
 )
 
+from llmfoundry.tokenizers import get_date_string
 from llmfoundry.utils.exceptions import ChatTemplateError
 
 DEFAULT_SYSTEM_PROMPT = 'You are a friendly chatbot who aims to be helpful and honest.'
@@ -132,6 +133,7 @@ def _history_as_formatted_str(self) -> str:
                 chat_conversation,
                 tokenize=False,
                 add_generation_prompt=False,
+                date_string=get_date_string(),
             )
         except Exception as e:
             raise ChatTemplateError(
@@ -149,6 +151,7 @@ def turn(self, user_inp: str) -> None:
                 tokenize=True,
                 add_generation_prompt=True,
                 return_tensors='pt',
+                date_string=get_date_string(),
             )
         except Exception as e:
             raise ChatTemplateError(
diff --git a/tests/data/test_template_tokenization.py b/tests/data/test_template_tokenization.py
index 9f44739b6b..fdf7233115 100644
--- a/tests/data/test_template_tokenization.py
+++ b/tests/data/test_template_tokenization.py
@@ -11,6 +11,7 @@
     dataset_constructor,
     tokenize_formatted_example,
 )
+from llmfoundry.tokenizers import get_date_string
 from llmfoundry.utils.builders import build_tokenizer
 from llmfoundry.utils.exceptions import (
     ALLOWED_PROMPT_KEYS,
@@ -210,10 +211,27 @@ def test_tokenize_instruct_example_well_formed():
 
 @pytest.mark.parametrize(
     'tokenizer_name',
-    ['EleutherAI/gpt-neox-20b', 'HuggingFaceH4/zephyr-7b-beta', 't5-base'],
+    [
+        'EleutherAI/gpt-neox-20b',
+        'HuggingFaceH4/zephyr-7b-beta',
+        't5-base',
+        'meta-llama/Meta-Llama-3.1-8B-Instruct',
+    ],
 )
 @pytest.mark.parametrize('messages_format', [True, False])
-def test_multi_turn_chat_slicing(tokenizer_name: str, messages_format: bool):
+@pytest.mark.parametrize('use_date_string', [True, False])
+def test_multi_turn_chat_slicing(
+    tokenizer_name: str,
+    messages_format: bool,
+    use_date_string: bool,
+):
+    if 'meta-llama' in tokenizer_name:
+        pytest.skip('Model is gated. Skipping test.')
+    is_llama_3_1_instruct = 'Meta-Llama-3.1' in tokenizer_name and 'Instruct' in tokenizer_name
+    if is_llama_3_1_instruct and use_date_string:
+        pytest.skip(
+            'Llama 3.1 Instruct models use date_string in chat template already. Skipping test.',
+        )
     if messages_format:
         convo = [
             {
@@ -272,6 +290,10 @@ def test_multi_turn_chat_slicing(tokenizer_name: str, messages_format: bool):
 
     tok = transformers.AutoTokenizer.from_pretrained(tokenizer_name)
 
+    # Manually set a chat template to test if the date_string is being used.
+    if use_date_string:
+        tok.chat_template = "{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{{- \"Today Date: \" + date_string }}\n"
+
     templated_prompt_response_turns = _slice_chat_formatted_example(
         example,
         tok,
@@ -281,9 +303,19 @@ def test_multi_turn_chat_slicing(tokenizer_name: str, messages_format: bool):
     for prompt, response in templated_prompt_response_turns:
         reconstructed_chat += prompt + response
 
-    full_chat = tok.apply_chat_template(convo, tokenize=False)
+    date_string = get_date_string()
+    full_chat = tok.apply_chat_template(
+        convo,
+        tokenize=False,
+        date_string=date_string,
+    )
     assert reconstructed_chat == full_chat
 
+    if is_llama_3_1_instruct or use_date_string:
+        assert date_string in full_chat
+    else:
+        assert date_string not in full_chat
+
 
 def test_fail_chat_template():
     convo = [
diff --git a/tests/tokenizers/test_tiktoken.py b/tests/tokenizers/test_tiktoken.py
index 8a61c6124f..97255e9ebb 100644
--- a/tests/tokenizers/test_tiktoken.py
+++ b/tests/tokenizers/test_tiktoken.py
@@ -7,6 +7,7 @@
 import pytest
 import transformers
 
+from llmfoundry.tokenizers import get_date_string
 from llmfoundry.tokenizers.tiktoken import (
     TiktokenTokenizerWrapper,
     bytes_to_unicode,
@@ -516,6 +517,7 @@ def test_chat_formatting(
             dict_chats,
             tokenize=False,
             add_generation_prompt=False,
+            date_string=get_date_string(),
         )
         assert chat_str == MULTI_TURN_CHAT_STRING_NO_SYSTEM_PROMPT[i]
     # Using default system prompt.
@@ -533,6 +535,7 @@ def test_chat_formatting(
             dict_chats,
             tokenize=False,
             add_generation_prompt=False,
+            date_string=get_date_string(),
         )
         assert chat_str == MULTI_TURN_CHAT_STRING_SYSTEM_PROMPT[i]
     for i, dict_chats in enumerate(MULTI_TURN_GENERATE_CHAT_ML):
@@ -540,6 +543,7 @@ def test_chat_formatting(
             dict_chats,
             tokenize=False,
             add_generation_prompt=True,
+            date_string=get_date_string(),
         )
         assert chat_str == MULTI_TURN_GENERATE_STRING[i]
 
diff --git a/tests/tokenizers/test_tokenizer.py b/tests/tokenizers/test_tokenizer.py
index b4f1846091..d42f810214 100644
--- a/tests/tokenizers/test_tokenizer.py
+++ b/tests/tokenizers/test_tokenizer.py
@@ -1,9 +1,13 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
+import pytest
+import torch
 from omegaconf import OmegaConf as om
 from transformers import AutoTokenizer
 
+from llmfoundry.tokenizers.utils import get_date_string
+
 
 def get_config(conf_path: str = 'scripts/train/yamls/pretrain/mpt-125m.yaml'):
     with open(conf_path) as f:
@@ -80,3 +84,49 @@ def test_load_tokenizer():
     )['attention_mask']
     attn_mask_key = [1, 1, 1, 1] + [0] * (tokenizer.model_max_length - 4)
     assert attention_mask == attn_mask_key
+
+
+@pytest.mark.parametrize(
+    'tokenizer_name',
+    [
+        'EleutherAI/gpt-neox-20b',
+        'meta-llama/Meta-Llama-3-8B-Instruct',
+        'meta-llama/Meta-Llama-3.1-8B-Instruct',
+        'meta-llama/Meta-Llama-3.1-70B-Instruct',
+        'meta-llama/Meta-Llama-3.1-405B-Instruct',
+    ],
+)
+@pytest.mark.parametrize('use_date_string', [True, False])
+def test_tokenizer_date_string(tokenizer_name: str, use_date_string: bool):
+    if 'meta-llama' in tokenizer_name:
+        pytest.skip('Model is gated. Skipping test.')
+
+    is_llama_3_1_instruct = 'Meta-Llama-3.1' in tokenizer_name and 'Instruct' in tokenizer_name
+    if is_llama_3_1_instruct and use_date_string:
+        pytest.skip(
+            'Llama 3.1 Instruct models use date_string in chat template already. Skipping test.',
+        )
+
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+    messages = [{'role': 'system', 'content': ''}]
+    date_string = get_date_string()
+
+    # Manually set a chat template to test if the date_string is being used.
+    if use_date_string:
+        tokenizer.chat_template = "{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{{- \"Today Date: \" + date_string }}\n"
+
+    token_ids = tokenizer.apply_chat_template(
+        messages,
+        add_generation_prompt=True,
+        return_tensors='pt',
+        date_string=date_string,
+    )
+
+    assert isinstance(token_ids, torch.Tensor)
+    decoded_text = tokenizer.decode(token_ids.flatten())
+
+    # Only Llama 3.1 instruct family models should use the current date in their chat templates.
+    if is_llama_3_1_instruct or use_date_string:
+        assert date_string in decoded_text
+    else:
+        assert date_string not in decoded_text
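
Illustrative usage note (not part of the patch): the change threads get_date_string() through every apply_chat_template call so templates that reference a date_string variable, such as Llama 3.1 Instruct's, render with the current date. The sketch below is a hedged, minimal example under the assumption that llm-foundry is installed; it reuses the throwaway Jinja template the new tests set manually so the effect is visible with any tokenizer.

# Minimal usage sketch, assuming llm-foundry is installed and a Hugging Face
# tokenizer can be downloaded. The tiny template mirrors the one the new
# tests install; it is not a real model's chat template.
from transformers import AutoTokenizer

from llmfoundry.tokenizers import get_date_string

tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
tokenizer.chat_template = (
    '{%- if not date_string is defined %}'
    '{%- set date_string = "26 Jul 2024" %}'
    '{%- endif %}'
    '{{- "Today Date: " + date_string }}'
)

messages = [{'role': 'system', 'content': ''}]

# Extra keyword arguments to apply_chat_template are exposed to the template,
# so date_string overrides the template's hard-coded default.
rendered = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    date_string=get_date_string(),
)
print(rendered)  # e.g. "Today Date: 21 Aug 2024"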