Commit dddec21

Merge branch 'main' into main
joyce-chen-uni authored Aug 27, 2024
2 parents 184311a + cef39d1 commit dddec21
Showing 10 changed files with 162 additions and 10 deletions.
30 changes: 25 additions & 5 deletions llmfoundry/data/finetuning/tasks.py
@@ -65,6 +65,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
stitch_turns_decoder_only,
stitch_turns_encoder_decoder,
)
from llmfoundry.tokenizers import get_date_string
# yapf: disable
from llmfoundry.utils.exceptions import (
ALLOWED_MESSAGES_KEYS,
@@ -249,11 +250,13 @@ def slice_out_last_turn(
full_conversation = tokenizer.apply_chat_template(
messages_through_current_turn,
tokenize=False,
date_string=get_date_string(),
)
prompt_with_history = tokenizer.apply_chat_template(
messages_through_current_turn[:-1],
tokenize=False,
add_generation_prompt=True,
date_string=get_date_string(),
)
except Exception as e:
raise ChatTemplateError(
@@ -319,10 +322,24 @@ def _tokenize_with_bos_removal(
)

# Remove the BOS token from the start of the labels if it was automatically added
if hasattr(tokenizer, 'add_bos_token') and tokenizer.add_bos_token:
if tokenizer.bos_token_id is not None and tokenized_sample['labels'][
0] == tokenizer.bos_token_id:
tokenized_sample['labels'] = tokenized_sample['labels'][1:]
# Unfortunately if the tokenizer is PretrainedTokenizerFast, as llama3 is, it does not provide
# an add_bos_token attr that we can check explicitly, so instead we rely on checking if both the
# text and the text_target start with bos_token_id to determine whether to strip bos.
has_bos_token = hasattr(
tokenizer,
'bos_token_id',
) and tokenizer.bos_token_id is not None
input_ids_starts_with_bos = False
labels_starts_with_bos = False
if has_bos_token and len(
tokenized_sample['input_ids'],
) > 0 and len(tokenized_sample['labels']) > 0:
input_ids_starts_with_bos = tokenized_sample['input_ids'][
0] == tokenizer.bos_token_id
labels_starts_with_bos = tokenized_sample['labels'][
0] == tokenizer.bos_token_id
if input_ids_starts_with_bos and labels_starts_with_bos:
tokenized_sample['labels'] = tokenized_sample['labels'][1:]

return tokenized_sample

@@ -642,6 +659,9 @@ def __init__(
self.max_seq_len = max_seq_len
self.packing_ratio = packing_ratio

def tokenize_example(self, example: Example) -> TokenizedExample:
return tokenize_formatted_example(example, self.tokenizer)

# How to process a sample
def __getitem__(self, idx: int) -> dict[str, Any]:
sample = super().__getitem__(idx)
@@ -670,7 +690,7 @@ def __getitem__(self, idx: int) -> dict[str, Any]:
)
# Convert to latest format by wrapping sample as a "turn"
return {'turns': [sample]}
return tokenize_formatted_example(sample, tokenizer=self.tokenizer)
return self.tokenize_example(sample)

def state_dict(self, num_samples: int,
from_beginning: bool) -> dict[str, Any]:
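The new tokenize_example method above factors per-sample tokenization out of __getitem__ so subclasses can override it. A minimal sketch of such an override, assuming the hook lives on the StreamingFinetuningDataset class in this file; the subclass name and the logging it performs are illustrative only, not part of this commit:

# Hypothetical subclass (not part of the diff) overriding the new
# tokenize_example hook to inspect raw examples before the default
# tokenization path runs.
from llmfoundry.data.finetuning.tasks import (
    Example,
    StreamingFinetuningDataset,
    TokenizedExample,
)


class InspectingFinetuningDataset(StreamingFinetuningDataset):

    def tokenize_example(self, example: Example) -> TokenizedExample:
        # Log which keys the raw example carries, then defer to the
        # default tokenization implemented by the parent class.
        print(f'Tokenizing example with keys: {sorted(example.keys())}')
        return super().tokenize_example(example)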
3 changes: 3 additions & 0 deletions llmfoundry/models/hf/hf_base.py
@@ -395,6 +395,9 @@ def build_inner_model(
pretrained_lora_id_or_path,
)

if prepare_for_fsdp:
cls.prepare_inner_model(model, init_device)

return model

def get_peft_config(self, peft_config_dict: dict[str, Any]) -> 'PeftConfig':
2 changes: 2 additions & 0 deletions llmfoundry/tokenizers/__init__.py
@@ -3,9 +3,11 @@

from llmfoundry.registry import tokenizers
from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper
from llmfoundry.tokenizers.utils import get_date_string

tokenizers.register('tiktoken', func=TiktokenTokenizerWrapper)

__all__ = [
'TiktokenTokenizerWrapper',
'get_date_string',
]
13 changes: 13 additions & 0 deletions llmfoundry/tokenizers/utils.py
@@ -0,0 +1,13 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import datetime

__all__ = [
'get_date_string',
]


def get_date_string() -> str:
"""Get the current date string."""
return datetime.datetime.now().strftime('%d %b %Y')
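The helper above is passed as the date_string template variable wherever this commit calls apply_chat_template, which is the variable Llama 3.1 Instruct chat templates read. A minimal usage sketch of that pattern; the tokenizer name below is an arbitrary example, not one prescribed by the change:

# Sketch of the date_string usage pattern introduced in this commit.
from transformers import AutoTokenizer

from llmfoundry.tokenizers import get_date_string

tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
prompt = tokenizer.apply_chat_template(
    [{'role': 'user', 'content': 'Hello!'}],
    tokenize=False,
    add_generation_prompt=True,
    # Extra kwargs are forwarded to the Jinja template, so templates that
    # reference `date_string` render today's date; others simply ignore it.
    date_string=get_date_string(),
)
print(prompt)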
3 changes: 3 additions & 0 deletions scripts/inference/hf_chat.py
@@ -19,6 +19,7 @@
TextStreamer,
)

from llmfoundry.tokenizers import get_date_string
from llmfoundry.utils.exceptions import ChatTemplateError

DEFAULT_SYSTEM_PROMPT = 'You are a friendly chatbot who aims to be helpful and honest.'
@@ -132,6 +133,7 @@ def _history_as_formatted_str(self) -> str:
chat_conversation,
tokenize=False,
add_generation_prompt=False,
date_string=get_date_string(),
)
except Exception as e:
raise ChatTemplateError(
@@ -149,6 +151,7 @@ def turn(self, user_inp: str) -> None:
tokenize=True,
add_generation_prompt=True,
return_tensors='pt',
date_string=get_date_string(),
)
except Exception as e:
raise ChatTemplateError(
4 changes: 2 additions & 2 deletions setup.py
@@ -56,7 +56,7 @@
'mlflow>=2.14.1,<2.16',
'accelerate>=0.25,<0.34', # for HF inference `device_map`
'transformers>=4.43.2,<4.44',
'mosaicml-streaming>=0.8.0,<0.9',
'mosaicml-streaming>=0.8.1,<0.9',
'torch>=2.3.0,<2.4',
'datasets>=2.19,<2.20',
'fsspec==2023.6.0', # newer version results in a bug in datasets that duplicates data
@@ -87,7 +87,7 @@
'pyright==1.1.256',
'toml>=0.10.2,<0.11',
'packaging>=21,<25',
'hf_transfer==0.1.3',
'hf_transfer==0.1.8',
]

extra_deps['databricks'] = [
38 changes: 35 additions & 3 deletions tests/data/test_template_tokenization.py
@@ -11,6 +11,7 @@
dataset_constructor,
tokenize_formatted_example,
)
from llmfoundry.tokenizers import get_date_string
from llmfoundry.utils.builders import build_tokenizer
from llmfoundry.utils.exceptions import (
ALLOWED_PROMPT_KEYS,
@@ -210,10 +211,27 @@ def test_tokenize_instruct_example_well_formed():

@pytest.mark.parametrize(
'tokenizer_name',
['EleutherAI/gpt-neox-20b', 'HuggingFaceH4/zephyr-7b-beta', 't5-base'],
[
'EleutherAI/gpt-neox-20b',
'HuggingFaceH4/zephyr-7b-beta',
't5-base',
'meta-llama/Meta-Llama-3.1-8B-Instruct',
],
)
@pytest.mark.parametrize('messages_format', [True, False])
def test_multi_turn_chat_slicing(tokenizer_name: str, messages_format: bool):
@pytest.mark.parametrize('use_date_string', [True, False])
def test_multi_turn_chat_slicing(
tokenizer_name: str,
messages_format: bool,
use_date_string: bool,
):
if 'meta-llama' in tokenizer_name:
pytest.skip('Model is gated. Skipping test.')
is_llama_3_1_instruct = 'Meta-Llama-3.1' in tokenizer_name and 'Instruct' in tokenizer_name
if is_llama_3_1_instruct and use_date_string:
pytest.skip(
'Llama 3.1 Instruct models use date_string in chat template already. Skipping test.',
)
if messages_format:
convo = [
{
@@ -272,6 +290,10 @@ def test_multi_turn_chat_slicing(tokenizer_name: str, messages_format: bool):

tok = transformers.AutoTokenizer.from_pretrained(tokenizer_name)

# Manually set a chat template to test if the date_string is being used.
if use_date_string:
tok.chat_template = "{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{{- \"Today Date: \" + date_string }}\n"

templated_prompt_response_turns = _slice_chat_formatted_example(
example,
tok,
@@ -281,9 +303,19 @@ def test_multi_turn_chat_slicing(tokenizer_name: str, messages_format: bool):
for prompt, response in templated_prompt_response_turns:
reconstructed_chat += prompt + response

full_chat = tok.apply_chat_template(convo, tokenize=False)
date_string = get_date_string()
full_chat = tok.apply_chat_template(
convo,
tokenize=False,
date_string=date_string,
)
assert reconstructed_chat == full_chat

if is_llama_3_1_instruct or use_date_string:
assert date_string in full_chat
else:
assert date_string not in full_chat


def test_fail_chat_template():
convo = [
25 changes: 25 additions & 0 deletions tests/models/hf/test_hf_base.py
@@ -0,0 +1,25 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from llmfoundry.models.hf.hf_base import BaseHuggingFaceModel


def test_build_inner_model_fsdp():
model = BaseHuggingFaceModel.build_inner_model(
pretrained_model_name_or_path='codellama/CodeLlama-7b-hf',
pretrained_lora_id_or_path=None,
trust_remote_code=False,
init_device='cpu',
use_flash_attention_2=False,
use_auth_token=False,
config_overrides={
'num_hidden_layers': 2,
'hidden_size': 32,
'intermediate_size': 64,
},
load_in_8bit=False,
pretrained=False,
prepare_for_fsdp=True,
)

assert model.fsdp_wrap_fn(model.model.layers[0])
4 changes: 4 additions & 0 deletions tests/tokenizers/test_tiktoken.py
@@ -7,6 +7,7 @@
import pytest
import transformers

from llmfoundry.tokenizers import get_date_string
from llmfoundry.tokenizers.tiktoken import (
TiktokenTokenizerWrapper,
bytes_to_unicode,
@@ -516,6 +517,7 @@ def test_chat_formatting(
dict_chats,
tokenize=False,
add_generation_prompt=False,
date_string=get_date_string(),
)
assert chat_str == MULTI_TURN_CHAT_STRING_NO_SYSTEM_PROMPT[i]
# Using default system prompt.
@@ -533,13 +535,15 @@ def test_chat_formatting(
dict_chats,
tokenize=False,
add_generation_prompt=False,
date_string=get_date_string(),
)
assert chat_str == MULTI_TURN_CHAT_STRING_SYSTEM_PROMPT[i]
for i, dict_chats in enumerate(MULTI_TURN_GENERATE_CHAT_ML):
chat_str = wrapped_tokenizer.apply_chat_template(
dict_chats,
tokenize=False,
add_generation_prompt=True,
date_string=get_date_string(),
)
assert chat_str == MULTI_TURN_GENERATE_STRING[i]

50 changes: 50 additions & 0 deletions tests/tokenizers/test_tokenizer.py
@@ -1,9 +1,13 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
from omegaconf import OmegaConf as om
from transformers import AutoTokenizer

from llmfoundry.tokenizers.utils import get_date_string


def get_config(conf_path: str = 'scripts/train/yamls/pretrain/mpt-125m.yaml'):
with open(conf_path) as f:
@@ -80,3 +84,49 @@ def test_load_tokenizer():
)['attention_mask']
attn_mask_key = [1, 1, 1, 1] + [0] * (tokenizer.model_max_length - 4)
assert attention_mask == attn_mask_key


@pytest.mark.parametrize(
'tokenizer_name',
[
'EleutherAI/gpt-neox-20b',
'meta-llama/Meta-Llama-3-8B-Instruct',
'meta-llama/Meta-Llama-3.1-8B-Instruct',
'meta-llama/Meta-Llama-3.1-70B-Instruct',
'meta-llama/Meta-Llama-3.1-405B-Instruct',
],
)
@pytest.mark.parametrize('use_date_string', [True, False])
def test_tokenizer_date_string(tokenizer_name: str, use_date_string: bool):
if 'meta-llama' in tokenizer_name:
pytest.skip('Model is gated. Skipping test.')

is_llama_3_1_instruct = 'Meta-Llama-3.1' in tokenizer_name and 'Instruct' in tokenizer_name
if is_llama_3_1_instruct and use_date_string:
pytest.skip(
'Llama 3.1 Instruct models use date_string in chat template already. Skipping test.',
)

tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
messages = [{'role': 'system', 'content': ''}]
date_string = get_date_string()

# Manually set a chat template to test if the date_string is being used.
if use_date_string:
tokenizer.chat_template = "{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{{- \"Today Date: \" + date_string }}\n"

token_ids = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
return_tensors='pt',
date_string=date_string,
)

assert isinstance(token_ids, torch.Tensor)
decoded_text = tokenizer.decode(token_ids.flatten())

# Only Llama 3.1 instruct family models should use the current date in their chat templates.
if is_llama_3_1_instruct or use_date_string:
assert date_string in decoded_text
else:
assert date_string not in decoded_text
