
Commit

yo
snarayan21 committed Aug 21, 2024
1 parent c56401f commit b9d1d89
Showing 7 changed files with 80 additions and 1 deletion.
3 changes: 3 additions & 0 deletions llmfoundry/data/finetuning/tasks.py
@@ -65,6 +65,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
stitch_turns_decoder_only,
stitch_turns_encoder_decoder,
)
from llmfoundry.tokenizers import get_date_string
# yapf: disable
from llmfoundry.utils.exceptions import (
ALLOWED_MESSAGES_KEYS,
@@ -249,11 +250,13 @@ def slice_out_last_turn(
full_conversation = tokenizer.apply_chat_template(
messages_through_current_turn,
tokenize=False,
date_string=get_date_string(),
)
prompt_with_history = tokenizer.apply_chat_template(
messages_through_current_turn[:-1],
tokenize=False,
add_generation_prompt=True,
date_string=get_date_string(),
)
except Exception as e:
raise ChatTemplateError(
2 changes: 2 additions & 0 deletions llmfoundry/tokenizers/__init__.py
@@ -3,9 +3,11 @@

from llmfoundry.registry import tokenizers
from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper
from llmfoundry.tokenizers.utils import get_date_string

tokenizers.register('tiktoken', func=TiktokenTokenizerWrapper)

__all__ = [
'TiktokenTokenizerWrapper',
'get_date_string',
]
12 changes: 12 additions & 0 deletions llmfoundry/tokenizers/utils.py
@@ -0,0 +1,12 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import datetime

__all__ = [
    'get_date_string',
]

def get_date_string() -> str:
    """Get the current date string."""
    return datetime.datetime.now().strftime('%d %b %Y')
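
For context, a minimal sketch (not part of this commit) of how the new helper is meant to be consumed. The Llama 3.1 Instruct chat templates read an optional date_string variable and fall back to a hard-coded date such as "26 Jul 2024" when it is absent; extra keyword arguments to apply_chat_template are forwarded to the Jinja template, so passing get_date_string() supplies the current date in the same '%d %b %Y' format. The model name and messages below are placeholders.

from transformers import AutoTokenizer

from llmfoundry.tokenizers import get_date_string

# Placeholder (gated) model name; any tokenizer whose chat template reads
# `date_string` behaves the same way.
tokenizer = AutoTokenizer.from_pretrained('meta-llama/Meta-Llama-3.1-8B-Instruct')
messages = [{'role': 'user', 'content': 'What day is it?'}]

# Extra kwargs are forwarded to the chat template, so the template sees today's
# date (e.g. '21 Aug 2024') instead of its built-in default.
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    date_string=get_date_string(),
)
print(text)
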
3 changes: 3 additions & 0 deletions scripts/inference/hf_chat.py
@@ -19,6 +19,7 @@
TextStreamer,
)

from llmfoundry.tokenizers import get_date_string
from llmfoundry.utils.exceptions import ChatTemplateError

DEFAULT_SYSTEM_PROMPT = 'You are a friendly chatbot who aims to be helpful and honest.'
@@ -132,6 +133,7 @@ def _history_as_formatted_str(self) -> str:
chat_conversation,
tokenize=False,
add_generation_prompt=False,
date_string=get_date_string(),
)
except Exception as e:
raise ChatTemplateError(
@@ -149,6 +151,7 @@ def turn(self, user_inp: str) -> None:
tokenize=True,
add_generation_prompt=True,
return_tensors='pt',
date_string=get_date_string(),
)
except Exception as e:
raise ChatTemplateError(
7 changes: 6 additions & 1 deletion tests/data/test_template_tokenization.py
@@ -11,6 +11,7 @@
dataset_constructor,
tokenize_formatted_example,
)
from llmfoundry.tokenizers import get_date_string
from llmfoundry.utils.builders import build_tokenizer
from llmfoundry.utils.exceptions import (
ALLOWED_PROMPT_KEYS,
@@ -281,7 +282,11 @@ def test_multi_turn_chat_slicing(tokenizer_name: str, messages_format: bool):
for prompt, response in templated_prompt_response_turns:
reconstructed_chat += prompt + response

full_chat = tok.apply_chat_template(convo, tokenize=False)
full_chat = tok.apply_chat_template(
convo,
tokenize=False,
date_string=get_date_string(),
)
assert reconstructed_chat == full_chat


4 changes: 4 additions & 0 deletions tests/tokenizers/test_tiktoken.py
@@ -7,6 +7,7 @@
import pytest
import transformers

from llmfoundry.tokenizers import get_date_string
from llmfoundry.tokenizers.tiktoken import (
TiktokenTokenizerWrapper,
bytes_to_unicode,
@@ -516,6 +517,7 @@ def test_chat_formatting(
dict_chats,
tokenize=False,
add_generation_prompt=False,
date_string=get_date_string(),
)
assert chat_str == MULTI_TURN_CHAT_STRING_NO_SYSTEM_PROMPT[i]
# Using default system prompt.
@@ -533,13 +535,15 @@
dict_chats,
tokenize=False,
add_generation_prompt=False,
date_string=get_date_string(),
)
assert chat_str == MULTI_TURN_CHAT_STRING_SYSTEM_PROMPT[i]
for i, dict_chats in enumerate(MULTI_TURN_GENERATE_CHAT_ML):
chat_str = wrapped_tokenizer.apply_chat_template(
dict_chats,
tokenize=False,
add_generation_prompt=True,
date_string=get_date_string(),
)
assert chat_str == MULTI_TURN_GENERATE_STRING[i]

50 changes: 50 additions & 0 deletions tests/tokenizers/test_tokenizer.py
@@ -1,9 +1,13 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
from omegaconf import OmegaConf as om
from transformers import AutoTokenizer

from llmfoundry.tokenizers.utils import get_date_string


def get_config(conf_path: str = 'scripts/train/yamls/pretrain/mpt-125m.yaml'):
with open(conf_path) as f:
@@ -80,3 +84,49 @@ def test_load_tokenizer():
)['attention_mask']
attn_mask_key = [1, 1, 1, 1] + [0] * (tokenizer.model_max_length - 4)
assert attention_mask == attn_mask_key


@pytest.mark.parametrize(
    'tokenizer_name',
    [
        'EleutherAI/gpt-neox-20b',
        'meta-llama/Meta-Llama-3-8B-Instruct',
        'meta-llama/Meta-Llama-3.1-8B-Instruct',
        'meta-llama/Meta-Llama-3.1-70B-Instruct',
        'meta-llama/Meta-Llama-3.1-405B-Instruct',
    ],
)
@pytest.mark.parametrize('spoof_chat_template', [True, False])
def test_tokenizer_date_string(tokenizer_name: str, spoof_chat_template: bool):
    if 'meta-llama' in tokenizer_name:
        pytest.skip('Model is gated. Skipping test.')

    is_llama_3_1_instruct = 'Meta-Llama-3.1' in tokenizer_name and 'Instruct' in tokenizer_name
    if is_llama_3_1_instruct and spoof_chat_template:
        pytest.skip(
            'Llama 3.1 Instruct models use date_string in chat template, so no need to spoof. Skipping test.',
        )

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    messages = [{'role': 'system', 'content': ''}]
    date_string = get_date_string()

    # Manually set a chat template to test if the date_string is being used.
    if spoof_chat_template:
        tokenizer.chat_template = "{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{{- \"Today Date: \" + date_string }}\n"

    token_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors='pt',
        date_string=date_string,
    )

    assert isinstance(token_ids, torch.Tensor)
    decoded_text = tokenizer.decode(token_ids.flatten())

    # Only Llama 3.1 instruct family models should use the current date in their chat templates.
    if is_llama_3_1_instruct or spoof_chat_template:
        assert date_string in decoded_text
    else:
        assert date_string not in decoded_text
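
As an aside, a minimal sketch (not part of the commit) of what the spoofed template above renders to. transformers evaluates chat templates in a sandboxed Jinja environment, but this particular snippet behaves the same under plain jinja2.

import jinja2

# Same template string the test assigns to tokenizer.chat_template.
template = jinja2.Template(
    '{%- if not date_string is defined %}\n'
    ' {%- set date_string = "26 Jul 2024" %}\n'
    '{%- endif %}\n'
    '{{- "Today Date: " + date_string }}\n',
)

# When the date_string kwarg reaches the template, the caller's date is used ...
print(template.render(date_string='21 Aug 2024'))  # Today Date: 21 Aug 2024
# ... otherwise the template falls back to its hard-coded default.
print(template.render())                           # Today Date: 26 Jul 2024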
