Add date_string when applying tokenizer chat template (#1474)
snarayan21 authored Aug 22, 2024
1 parent e84d97e commit e235f42
Showing 7 changed files with 110 additions and 3 deletions.
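For context: Hugging Face chat templates can reference a date_string variable, and templates that do (such as the Llama 3.1 Instruct template mirrored in the tests below) fall back to a hard-coded date when the caller does not supply one. A minimal sketch of that behavior with a recent transformers release (which forwards extra keyword arguments to the template renderer); the model name and dates are illustrative only:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')  # example model only

# Same fallback pattern that the tests below install manually.
tok.chat_template = (
    '{%- if not date_string is defined %}'
    '{%- set date_string = "26 Jul 2024" %}'
    '{%- endif %}'
    '{{- "Today Date: " + date_string }}'
)

messages = [{'role': 'user', 'content': 'hi'}]

# Without date_string, the template falls back to its hard-coded default.
print(tok.apply_chat_template(messages, tokenize=False))
# -> Today Date: 26 Jul 2024

# Extra keyword arguments are exposed to the template, so a caller-supplied date wins.
print(tok.apply_chat_template(messages, tokenize=False, date_string='22 Aug 2024'))
# -> Today Date: 22 Aug 2024

The changes below thread get_date_string() through LLM Foundry's apply_chat_template call sites so that rendered prompts carry the actual current date.
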
3 changes: 3 additions & 0 deletions llmfoundry/data/finetuning/tasks.py
@@ -65,6 +65,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
    stitch_turns_decoder_only,
    stitch_turns_encoder_decoder,
)
from llmfoundry.tokenizers import get_date_string
# yapf: disable
from llmfoundry.utils.exceptions import (
    ALLOWED_MESSAGES_KEYS,
@@ -249,11 +250,13 @@ def slice_out_last_turn(
            full_conversation = tokenizer.apply_chat_template(
                messages_through_current_turn,
                tokenize=False,
                date_string=get_date_string(),
            )
            prompt_with_history = tokenizer.apply_chat_template(
                messages_through_current_turn[:-1],
                tokenize=False,
                add_generation_prompt=True,
                date_string=get_date_string(),
            )
        except Exception as e:
            raise ChatTemplateError(
2 changes: 2 additions & 0 deletions llmfoundry/tokenizers/__init__.py
@@ -3,9 +3,11 @@

from llmfoundry.registry import tokenizers
from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper
from llmfoundry.tokenizers.utils import get_date_string

tokenizers.register('tiktoken', func=TiktokenTokenizerWrapper)

__all__ = [
    'TiktokenTokenizerWrapper',
    'get_date_string',
]
13 changes: 13 additions & 0 deletions llmfoundry/tokenizers/utils.py
@@ -0,0 +1,13 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import datetime

__all__ = [
    'get_date_string',
]


def get_date_string() -> str:
    """Get the current date string."""
    return datetime.datetime.now().strftime('%d %b %Y')
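
As a quick illustration of the format (assuming an English locale, since %b is locale-dependent): the '%d %b %Y' pattern produces the same 'DD Mon YYYY' style that Llama 3.1 Instruct chat templates default to ('26 Jul 2024'). A minimal check with an illustrative fixed date:

import datetime

# What get_date_string() would return if "today" were 22 Aug 2024.
assert datetime.datetime(2024, 8, 22).strftime('%d %b %Y') == '22 Aug 2024'
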
3 changes: 3 additions & 0 deletions scripts/inference/hf_chat.py
@@ -19,6 +19,7 @@
    TextStreamer,
)

from llmfoundry.tokenizers import get_date_string
from llmfoundry.utils.exceptions import ChatTemplateError

DEFAULT_SYSTEM_PROMPT = 'You are a friendly chatbot who aims to be helpful and honest.'
@@ -132,6 +133,7 @@ def _history_as_formatted_str(self) -> str:
                chat_conversation,
                tokenize=False,
                add_generation_prompt=False,
                date_string=get_date_string(),
            )
        except Exception as e:
            raise ChatTemplateError(
@@ -149,6 +151,7 @@ def turn(self, user_inp: str) -> None:
                tokenize=True,
                add_generation_prompt=True,
                return_tensors='pt',
                date_string=get_date_string(),
            )
        except Exception as e:
            raise ChatTemplateError(
38 changes: 35 additions & 3 deletions tests/data/test_template_tokenization.py
@@ -11,6 +11,7 @@
    dataset_constructor,
    tokenize_formatted_example,
)
from llmfoundry.tokenizers import get_date_string
from llmfoundry.utils.builders import build_tokenizer
from llmfoundry.utils.exceptions import (
    ALLOWED_PROMPT_KEYS,
@@ -210,10 +211,27 @@ def test_tokenize_instruct_example_well_formed():

@pytest.mark.parametrize(
    'tokenizer_name',
    ['EleutherAI/gpt-neox-20b', 'HuggingFaceH4/zephyr-7b-beta', 't5-base'],
    [
        'EleutherAI/gpt-neox-20b',
        'HuggingFaceH4/zephyr-7b-beta',
        't5-base',
        'meta-llama/Meta-Llama-3.1-8B-Instruct',
    ],
)
@pytest.mark.parametrize('messages_format', [True, False])
def test_multi_turn_chat_slicing(tokenizer_name: str, messages_format: bool):
@pytest.mark.parametrize('use_date_string', [True, False])
def test_multi_turn_chat_slicing(
    tokenizer_name: str,
    messages_format: bool,
    use_date_string: bool,
):
    if 'meta-llama' in tokenizer_name:
        pytest.skip('Model is gated. Skipping test.')
    is_llama_3_1_instruct = 'Meta-Llama-3.1' in tokenizer_name and 'Instruct' in tokenizer_name
    if is_llama_3_1_instruct and use_date_string:
        pytest.skip(
            'Llama 3.1 Instruct models use date_string in chat template already. Skipping test.',
        )
    if messages_format:
        convo = [
            {
@@ -272,6 +290,10 @@ def test_multi_turn_chat_slicing(tokenizer_name: str, messages_format: bool):

    tok = transformers.AutoTokenizer.from_pretrained(tokenizer_name)

    # Manually set a chat template to test if the date_string is being used.
    if use_date_string:
        tok.chat_template = "{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{{- \"Today Date: \" + date_string }}\n"

    templated_prompt_response_turns = _slice_chat_formatted_example(
        example,
        tok,
@@ -281,9 +303,19 @@
    for prompt, response in templated_prompt_response_turns:
        reconstructed_chat += prompt + response

    full_chat = tok.apply_chat_template(convo, tokenize=False)
    date_string = get_date_string()
    full_chat = tok.apply_chat_template(
        convo,
        tokenize=False,
        date_string=date_string,
    )
    assert reconstructed_chat == full_chat

    if is_llama_3_1_instruct or use_date_string:
        assert date_string in full_chat
    else:
        assert date_string not in full_chat


def test_fail_chat_template():
    convo = [
4 changes: 4 additions & 0 deletions tests/tokenizers/test_tiktoken.py
@@ -7,6 +7,7 @@
import pytest
import transformers

from llmfoundry.tokenizers import get_date_string
from llmfoundry.tokenizers.tiktoken import (
    TiktokenTokenizerWrapper,
    bytes_to_unicode,
@@ -516,6 +517,7 @@ def test_chat_formatting(
            dict_chats,
            tokenize=False,
            add_generation_prompt=False,
            date_string=get_date_string(),
        )
        assert chat_str == MULTI_TURN_CHAT_STRING_NO_SYSTEM_PROMPT[i]
        # Using default system prompt.
@@ -533,13 +535,15 @@
            dict_chats,
            tokenize=False,
            add_generation_prompt=False,
            date_string=get_date_string(),
        )
        assert chat_str == MULTI_TURN_CHAT_STRING_SYSTEM_PROMPT[i]
    for i, dict_chats in enumerate(MULTI_TURN_GENERATE_CHAT_ML):
        chat_str = wrapped_tokenizer.apply_chat_template(
            dict_chats,
            tokenize=False,
            add_generation_prompt=True,
            date_string=get_date_string(),
        )
        assert chat_str == MULTI_TURN_GENERATE_STRING[i]

50 changes: 50 additions & 0 deletions tests/tokenizers/test_tokenizer.py
@@ -1,9 +1,13 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
from omegaconf import OmegaConf as om
from transformers import AutoTokenizer

from llmfoundry.tokenizers.utils import get_date_string


def get_config(conf_path: str = 'scripts/train/yamls/pretrain/mpt-125m.yaml'):
    with open(conf_path) as f:
@@ -80,3 +84,49 @@ def test_load_tokenizer():
    )['attention_mask']
    attn_mask_key = [1, 1, 1, 1] + [0] * (tokenizer.model_max_length - 4)
    assert attention_mask == attn_mask_key


@pytest.mark.parametrize(
    'tokenizer_name',
    [
        'EleutherAI/gpt-neox-20b',
        'meta-llama/Meta-Llama-3-8B-Instruct',
        'meta-llama/Meta-Llama-3.1-8B-Instruct',
        'meta-llama/Meta-Llama-3.1-70B-Instruct',
        'meta-llama/Meta-Llama-3.1-405B-Instruct',
    ],
)
@pytest.mark.parametrize('use_date_string', [True, False])
def test_tokenizer_date_string(tokenizer_name: str, use_date_string: bool):
    if 'meta-llama' in tokenizer_name:
        pytest.skip('Model is gated. Skipping test.')

    is_llama_3_1_instruct = 'Meta-Llama-3.1' in tokenizer_name and 'Instruct' in tokenizer_name
    if is_llama_3_1_instruct and use_date_string:
        pytest.skip(
            'Llama 3.1 Instruct models use date_string in chat template already. Skipping test.',
        )

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    messages = [{'role': 'system', 'content': ''}]
    date_string = get_date_string()

    # Manually set a chat template to test if the date_string is being used.
    if use_date_string:
        tokenizer.chat_template = "{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{{- \"Today Date: \" + date_string }}\n"

    token_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors='pt',
        date_string=date_string,
    )

    assert isinstance(token_ids, torch.Tensor)
    decoded_text = tokenizer.decode(token_ids.flatten())

    # Only Llama 3.1 instruct family models should use the current date in their chat templates.
    if is_llama_3_1_instruct or use_date_string:
        assert date_string in decoded_text
    else:
        assert date_string not in decoded_text
