From 805cf83c709732e0d99b952dd51ab528f109814d Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Thu, 8 Aug 2024 13:17:02 -0400 Subject: [PATCH] Add special errors for bad chat/ift types (#1437) --- llmfoundry/data/finetuning/tasks.py | 21 ++++++++------------- llmfoundry/utils/exceptions.py | 16 ++++++++++++++++ tests/data/test_template_tokenization.py | 18 ++++++++++++++++-- 3 files changed, 40 insertions(+), 15 deletions(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index dd9b495ce4..e8175b4446 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -78,8 +78,10 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: ConsecutiveRepeatedChatRolesError, IncorrectMessageKeyQuantityError, InvalidContentTypeError, + InvalidExampleTypeError, InvalidFileExtensionError, InvalidLastChatMessageRoleError, + InvalidMessageTypeError, InvalidPromptResponseKeysError, InvalidPromptTypeError, InvalidResponseTypeError, @@ -139,9 +141,7 @@ def _get_example_type(example: Example) -> ExampleType: KeyError: If the example type is unknown. """ if not isinstance(example, Mapping): - raise TypeError( - f'Expected example to be a Mapping, but found {type(example)}', - ) + raise InvalidExampleTypeError(str(type(example))) if ( len(example.keys()) == 1 and any( allowed_message_key in example @@ -156,7 +156,8 @@ def _get_example_type(example: Example) -> ExampleType: ): return 'prompt_response' else: - raise UnknownExampleTypeError(str(example.keys())) + keys = str(set(example.keys())) + raise UnknownExampleTypeError(keys) def _is_empty_or_nonexistent(dirpath: str) -> bool: @@ -173,23 +174,17 @@ def _is_empty_or_nonexistent(dirpath: str) -> bool: def _get_key(dictionary: Mapping[str, Any], allowed_keys: set[str]): if not isinstance(dictionary, Mapping): - raise TypeError( - f'Expected dictionary to be a mapping, but found {type(dictionary)}', - ) + raise InvalidExampleTypeError(str(type(dictionary))) desired_keys = allowed_keys.intersection(dictionary.keys()) return list(desired_keys)[0] def _validate_chat_formatted_example(example: ChatFormattedDict): if not isinstance(example, Mapping): - raise TypeError( - f'Expected example to be a mapping, but found {type(example)}', - ) + raise InvalidExampleTypeError(str(type(example))) messages = example[_get_key(example, ALLOWED_MESSAGES_KEYS)] if not isinstance(messages, List): - raise TypeError( - f'Expected messages to be an iterable, but found {type(messages)}', - ) + raise InvalidMessageTypeError(str(type(messages))) if len(messages) <= 1: raise NotEnoughChatDataError() diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py index 140bf8540b..c6a667697d 100644 --- a/llmfoundry/utils/exceptions.py +++ b/llmfoundry/utils/exceptions.py @@ -161,6 +161,22 @@ def __init__( ## Tasks exceptions +class InvalidExampleTypeError(UserError): + """Error thrown when a message type is not a `Mapping`.""" + + def __init__(self, example_type: str) -> None: + message = f'Expected example to be a `Mapping`, but found type {example_type}' + super().__init__(message, example_type=example_type) + + +class InvalidMessageTypeError(UserError): + """Error thrown when a message type is not an `Iterable`.""" + + def __init__(self, message_type: str) -> None: + message = f'Expected message to be an `Iterable`, but found type {message_type}' + super().__init__(message, message_type=message_type) + + class UnknownExampleTypeError(UserError): """Error thrown when an unknown example type is used in a task.""" diff --git a/tests/data/test_template_tokenization.py b/tests/data/test_template_tokenization.py index 16447d6623..9f44739b6b 100644 --- a/tests/data/test_template_tokenization.py +++ b/tests/data/test_template_tokenization.py @@ -16,6 +16,8 @@ ALLOWED_PROMPT_KEYS, ALLOWED_RESPONSE_KEYS, ChatTemplateError, + InvalidExampleTypeError, + InvalidMessageTypeError, ) @@ -48,13 +50,13 @@ def test_tokenize_chat_example_malformed(): 'content': 'user message not followed by an assistant label', }], } - wrong_type = {'messages': 'this is not a list of messages'} + wrong_example_type = ['this is not a dictionary'] + wrong_messages_type = {'messages': 'this is not a list of messages'} malformed_chat_examples = [ too_few_messages, no_content, ends_with_user_role, no_assistant_message, - wrong_type, ] my_tokenizer = build_tokenizer('mosaicml/mpt-7b-8k-chat', {}) for example in malformed_chat_examples: @@ -63,6 +65,18 @@ def test_tokenize_chat_example_malformed(): example, my_tokenizer, ) # type: ignore (the typing here is supposed to be malformed) + with pytest.raises(InvalidExampleTypeError): + # Ignore the type here because it's the mistyping that we're + # trying to test. + tokenize_formatted_example( # type: ignore + wrong_example_type, # type: ignore + my_tokenizer, # type: ignore + ) + with pytest.raises(InvalidMessageTypeError): + tokenize_formatted_example( + wrong_messages_type, + my_tokenizer, + ) def test_tokenize_chat_example_well_formed():