From 5f5a144aec99d6079b284f04241d51f79776ebc6 Mon Sep 17 00:00:00 2001
From: Milo Cress
Date: Thu, 18 Jan 2024 19:15:53 +0000
Subject: [PATCH] fix conflicting formatting and linting guidelines

---
 llmfoundry/data/finetuning/tasks.py | 76 +++++++++++++++++++++--
 tests/data/test_dataloader.py       | 93 ++++++++++++++++++++++++++---
 2 files changed, 156 insertions(+), 13 deletions(-)

diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index e61d138c41..94321a2853 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -1,5 +1,6 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
+# isort:skip_file
 
 """Includes code for task-specific seq-to-seq data formatting.
 
@@ -36,7 +37,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
 import os
 import warnings
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Union
 
 import datasets as hf_datasets
 import huggingface_hub as hf_hub
@@ -57,6 +58,23 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
                  '.downloaded_finetuning'))
 SUPPORTED_EXTENSIONS = ['.csv', '.jsonl', '.parquet']
 
+PromptResponseDict = Dict[str, str]
+ChatFormattedDict = Dict[str, List[Dict[str, str]]]
+Conversation = Union[PromptResponseDict, ChatFormattedDict]
+ConversationType = Literal['prompt_response', 'chat']
+TokenizedConversation = Dict[str, List[Union[int, str]]]
+
+
+def _get_conversation_type(example: Conversation) -> ConversationType:
+    # Note: this function does not validate the conversation; it only
+    # determines which validator should be applied.
+    if 'messages' in example:
+        return 'chat'
+    elif 'prompt' in example or 'response' in example:
+        return 'prompt_response'
+    else:
+        raise KeyError(f'unknown conversation type {example=}')
+
 
 def _is_empty_or_nonexistent(dirpath: str) -> bool:
     """Check if a directory is empty or non-existent.
@@ -70,9 +88,42 @@ def _is_empty_or_nonexistent(dirpath: str) -> bool:
     return not os.path.isdir(dirpath) or len(os.listdir(dirpath)) == 0
 
 
-def _tokenize_formatted_example(
-    example: Dict[str, Any],
-    tokenizer: PreTrainedTokenizerBase) -> Dict[str, List[int]]:
+def _tokenize_chat_formatted_example(
+        example: ChatFormattedDict,
+        tokenizer: PreTrainedTokenizerBase) -> TokenizedConversation:
+
+    def split_last(s: str, sep: str):
+        # Split s on the last occurrence of sep; sep stays with the tail.
+        slices = s.split(sep)
+        if len(slices) < 2:
+            raise ValueError(f'separator not in string. {sep=}, {s=}')
+        a, b = sep.join(slices[:-1]), sep + slices[-1]
+        return a, b
+
+    messages = example['messages']
+
+    if len(messages) < 2:
+        raise ValueError(
+            f'chat example must have at least two messages. {messages=}')
+    last_message = messages[-1]
+    if last_message['role'] != 'assistant':
+        raise ValueError(
+            f'last message must be from assistant. {last_message=}')
+    for message in messages:
+        if 'role' not in message or 'content' not in message:
+            raise KeyError(f'message must have role and content. {message=}')
+
+    applied_template = tokenizer.apply_chat_template(messages, tokenize=False)
+    prompt, response = split_last(applied_template, last_message['content'])
+    return {
+        'input_ids': tokenizer.tokenize(prompt),
+        'labels': tokenizer.tokenize(response)
+    }
+
+
+def _tokenize_prompt_response_formatted_example(
+        example: PromptResponseDict,
+        tokenizer: PreTrainedTokenizerBase) -> TokenizedConversation:
     """Tokenize a formatted example and validate expected keys."""
     example_keys = set(example.keys())
     prompt_keys = example_keys.intersection(_ALLOWED_PROMPT_KEYS)
@@ -108,6 +159,23 @@ def _tokenize_formatted_example(
     return tokenizer(text=prompt, text_target=response)
 
 
+def _tokenize_formatted_example(
+        example: Conversation,
+        tokenizer: PreTrainedTokenizerBase) -> TokenizedConversation:
+    # Determine the example format, then dispatch to the matching tokenizer.
+    example_format = _get_conversation_type(example)
+
+    if example_format == 'chat':
+        chat_example: ChatFormattedDict = example  # type: ignore
+        return _tokenize_chat_formatted_example(chat_example, tokenizer)
+    elif example_format == 'prompt_response':
+        prompt_response_example: PromptResponseDict = example  # type: ignore
+        return _tokenize_prompt_response_formatted_example(
+            prompt_response_example, tokenizer)
+    else:
+        raise ValueError(f'unknown conversation type {example_format=}')
+
+
 class StreamingFinetuningDataset(StreamingDataset):
     """Finetuning dataset with flexible tokenization using StreamingDataset.
 
diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py
index 7f99eeda25..be50c389a0 100644
--- a/tests/data/test_dataloader.py
+++ b/tests/data/test_dataloader.py
@@ -1,5 +1,6 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
+# isort: skip_file
 import contextlib
 import os
 import pathlib
@@ -9,7 +10,7 @@
 from argparse import Namespace
 from contextlib import nullcontext as does_not_raise
 from pathlib import Path
-from typing import ContextManager, Literal, Optional, Union
+from typing import ContextManager, List, Literal, Optional, Union
 from unittest.mock import MagicMock
 
 import pytest
@@ -23,11 +24,10 @@
 from llmfoundry import (build_finetuning_dataloader,
                         build_text_denoising_dataloader)
 from llmfoundry.data import build_dataloader
-from llmfoundry.data.finetuning.tasks import (_ALLOWED_PROMPT_KEYS,
-                                              _ALLOWED_RESPONSE_KEYS,
-                                              DOWNLOADED_FT_DATASETS_DIRPATH,
-                                              SUPPORTED_EXTENSIONS,
-                                              _tokenize_formatted_example)
+from llmfoundry.data.finetuning.tasks import (
+    _ALLOWED_PROMPT_KEYS, _ALLOWED_RESPONSE_KEYS,
+    DOWNLOADED_FT_DATASETS_DIRPATH, SUPPORTED_EXTENSIONS, ChatFormattedDict,
+    PromptResponseDict, _tokenize_formatted_example)
 from llmfoundry.data.text_data import (ConcatenatedSequenceCollatorWrapper,
                                        build_text_dataloader,
                                        get_tokens_per_batch_func)
@@ -428,27 +428,102 @@ def test_tokenize_example_malformed():
         'response': 'response',
         'completion': 'completion'
     }
+    no_content = {'messages': [{'role': 'user'}]}
+    ends_with_user_role: ChatFormattedDict = {
+        'messages': [{
+            'role': 'user',
+            'content': 'Hello GPT!'
+        }, {
+            'role': 'assistant',
+            'content': 'Hi, User!'
+        }, {
+            'role': 'user',
+            'content': 'user message not followed by an assistant label'
+        }]
+    }
+    no_assistant_message: ChatFormattedDict = {
+        'messages': [{
+            'role': 'user',
+            'content': 'Hello GPT!'
+        }, {
+            'role': 'user',
+            'content': 'user message not followed by an assistant label'
+        }]
+    }
 
-    malformed_examples = [
+    malformed_prompt_response_examples = [
         no_keys, no_prompt_key, no_response_key, extra_keys_with_prompt,
         extra_keys_with_response, multiple_allowed_response_keys
     ]
+    malformed_chat_examples = [
+        no_content, ends_with_user_role, no_assistant_message
+    ]
 
-    for example in malformed_examples:
+    for example in malformed_prompt_response_examples:
         with pytest.raises(KeyError):
             _tokenize_formatted_example(example, MagicMock())
+    my_tokenizer = build_tokenizer('TinyLlama/TinyLlama-1.1B-Chat-v1.0', {})
+    for example in malformed_chat_examples:
+        with pytest.raises(Exception):
+            _tokenize_formatted_example(
+                example,  # type: ignore (intentionally malformed)
+                my_tokenizer)
 
 
 def test_tokenize_example_well_formed():
     tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
 
     for prompt_key in _ALLOWED_PROMPT_KEYS:
        for response_key in _ALLOWED_RESPONSE_KEYS:
-            example = {prompt_key: 'prompt', response_key: 'response'}
+
+            example: PromptResponseDict = {
+                prompt_key: 'prompt',
+                response_key: 'response'
+            }
             tokenized_example = _tokenize_formatted_example(example, tokenizer)
             assert 'input_ids' in tokenized_example
             assert 'labels' in tokenized_example
 
+    chat_examples: List[ChatFormattedDict] = [
+        {
+            'messages': [{
+                'role': 'user',
+                'content': 'Hello, GPT'
+            }, {
+                'role': 'assistant',
+                'content': 'this is my response'
+            }]
+        },  # prompt/response but in chat format
+        {
+            'messages': [
+                {
+                    'role': 'user',
+                    'content': 'Hello, GPT'
+                },
+                {
+                    'role': 'assistant',
+                    'content': 'this is my response'
+                },
+                {
+                    'role': 'user',
+                    'content': 'Nice to hear that.'
+                },
+                {
+                    'role': 'assistant',
+                    'content': 'multi-way chat works too!'
+                },
+            ]
+        },  # multi-way chat
+    ]
+
+    chat_tokenizer = build_tokenizer('TinyLlama/TinyLlama-1.1B-Chat-v1.0', {})
+    for chat_example in chat_examples:
+        tokenized_example = _tokenize_formatted_example(chat_example,
+                                                        chat_tokenizer)
+        assert 'input_ids' in tokenized_example
+        assert 'labels' in tokenized_example
+
 
 @pytest.mark.parametrize('split', ['train', 'custom', 'data'])
 def test_finetuning_dataloader_custom_split(tmp_path: pathlib.Path,
                                             split: str):
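
For reviewers, a minimal sketch of the two example shapes the new
_tokenize_formatted_example dispatch accepts. This is illustration only, not
part of the patch: it assumes a transformers version with chat-template
support (apply_chat_template), and the TinyLlama checkpoint is used simply
because the tests above use it.

    from transformers import AutoTokenizer

    from llmfoundry.data.finetuning.tasks import _tokenize_formatted_example

    # Any tokenizer that ships a chat template works for the chat branch.
    tokenizer = AutoTokenizer.from_pretrained(
        'TinyLlama/TinyLlama-1.1B-Chat-v1.0')

    # Prompt/response format: exactly one allowed prompt key and one allowed
    # response key; tokenized via tokenizer(text=..., text_target=...).
    prompt_response = {'prompt': 'What is 2 + 2?', 'response': '4'}

    # Chat format: a 'messages' list whose final turn is the assistant's; the
    # chat template is applied, then split just before the final response.
    chat = {
        'messages': [{
            'role': 'user',
            'content': 'What is 2 + 2?'
        }, {
            'role': 'assistant',
            'content': '4'
        }]
    }

    for example in (prompt_response, chat):
        tokenized = _tokenize_formatted_example(example, tokenizer)
        assert 'input_ids' in tokenized and 'labels' in tokenized

Note that the two branches return slightly different value types: the
prompt/response branch returns token ids from tokenizer.__call__, while the
chat branch returns string tokens from tokenizer.tokenize, which is why
TokenizedConversation is typed as Dict[str, List[Union[int, str]]].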