From 04c3d69822909bed2de6b50a1e5fe5be43809fc4 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Wed, 7 Aug 2024 19:19:29 +0000 Subject: [PATCH 1/4] add special errors for bad types --- llmfoundry/data/finetuning/tasks.py | 21 ++++++++------------- llmfoundry/utils/exceptions.py | 16 ++++++++++++++++ 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 0adad8af4e..330a88703a 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -77,8 +77,10 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: ConsecutiveRepeatedChatRolesError, IncorrectMessageKeyQuantityError, InvalidContentTypeError, + InvalidExampleTypeError, InvalidFileExtensionError, InvalidLastChatMessageRoleError, + InvalidMessageTypeError, InvalidPromptResponseKeysError, InvalidPromptTypeError, InvalidResponseTypeError, @@ -136,9 +138,7 @@ def _get_example_type(example: Example) -> ExampleType: KeyError: If the example type is unknown. """ if not isinstance(example, Mapping): - raise TypeError( - f'Expected example to be a Mapping, but found {type(example)}', - ) + raise InvalidExampleTypeError(str(type(example))) if ( len(example.keys()) == 1 and any( allowed_message_key in example @@ -153,7 +153,8 @@ def _get_example_type(example: Example) -> ExampleType: ): return 'prompt_response' else: - raise UnknownExampleTypeError(str(example.keys())) + keys = str(example.keys()) + raise UnknownExampleTypeError(keys) def _is_empty_or_nonexistent(dirpath: str) -> bool: @@ -170,23 +171,17 @@ def _is_empty_or_nonexistent(dirpath: str) -> bool: def _get_key(dictionary: Mapping[str, Any], allowed_keys: set[str]): if not isinstance(dictionary, Mapping): - raise TypeError( - f'Expected dictionary to be a mapping, but found {type(dictionary)}', - ) + raise InvalidExampleTypeError(str(type(dictionary))) desired_keys = allowed_keys.intersection(dictionary.keys()) return list(desired_keys)[0] def _validate_chat_formatted_example(example: ChatFormattedDict): if not isinstance(example, Mapping): - raise TypeError( - f'Expected example to be a mapping, but found {type(example)}', - ) + raise InvalidExampleTypeError(str(type(example))) messages = example[_get_key(example, ALLOWED_MESSAGES_KEYS)] if not isinstance(messages, List): - raise TypeError( - f'Expected messages to be an iterable, but found {type(messages)}', - ) + raise InvalidMessageTypeError(str(type(messages))) if len(messages) <= 1: raise NotEnoughChatDataError() diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py index 76f378f8c6..251bc86631 100644 --- a/llmfoundry/utils/exceptions.py +++ b/llmfoundry/utils/exceptions.py @@ -160,6 +160,22 @@ def __init__( ## Tasks exceptions +class InvalidExampleTypeError(UserError): + """Error thrown when a message type is not a `Mapping`.""" + + def __init__(self, example_type: str) -> None: + message = f'Expected example to be a `Mapping`, but found type {example_type}' + super().__init__(message, example_type=example_type) + + +class InvalidMessageTypeError(UserError): + """Error thrown when a message type is not an `Iterable`.""" + + def __init__(self, message_type: str) -> None: + message = f'Expected message to be an `Iterable`, but found type {message_type}' + super().__init__(message, message_type=message_type) + + class UnknownExampleTypeError(UserError): """Error thrown when an unknown example type is used in a task.""" From dbe4faf841f5596a1f4132957a2491fb52fe5b6b Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Wed, 7 Aug 2024 19:29:05 +0000 Subject: [PATCH 2/4] set to get rid of keysivew --- llmfoundry/data/finetuning/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 330a88703a..1b3d257b42 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -153,7 +153,7 @@ def _get_example_type(example: Example) -> ExampleType: ): return 'prompt_response' else: - keys = str(example.keys()) + keys = str(set(example.keys())) raise UnknownExampleTypeError(keys) From d800fd8f08c20edbe88c81e0a6adeedee4152083 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Thu, 8 Aug 2024 16:26:45 +0000 Subject: [PATCH 3/4] add tests --- tests/data/test_template_tokenization.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/data/test_template_tokenization.py b/tests/data/test_template_tokenization.py index 16447d6623..11e7b84bfc 100644 --- a/tests/data/test_template_tokenization.py +++ b/tests/data/test_template_tokenization.py @@ -16,6 +16,8 @@ ALLOWED_PROMPT_KEYS, ALLOWED_RESPONSE_KEYS, ChatTemplateError, + InvalidExampleTypeError, + InvalidMessageTypeError, ) @@ -48,13 +50,13 @@ def test_tokenize_chat_example_malformed(): 'content': 'user message not followed by an assistant label', }], } - wrong_type = {'messages': 'this is not a list of messages'} + wrong_example_type = ['this is not a dictionary'] + wrong_messages_type = {'messages': 'this is not a list of messages'} malformed_chat_examples = [ too_few_messages, no_content, ends_with_user_role, no_assistant_message, - wrong_type, ] my_tokenizer = build_tokenizer('mosaicml/mpt-7b-8k-chat', {}) for example in malformed_chat_examples: @@ -63,6 +65,16 @@ def test_tokenize_chat_example_malformed(): example, my_tokenizer, ) # type: ignore (the typing here is supposed to be malformed) + with pytest.raises(InvalidExampleTypeError): + tokenize_formatted_example( + wrong_example_type, + my_tokenizer, + ) + with pytest.raises(InvalidMessageTypeError): + tokenize_formatted_example( + wrong_messages_type, + my_tokenizer, + ) def test_tokenize_chat_example_well_formed(): From 2506d28601741c619c6746aa30c8648255dcccf1 Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Thu, 8 Aug 2024 16:47:46 +0000 Subject: [PATCH 4/4] fix types --- tests/data/test_template_tokenization.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/data/test_template_tokenization.py b/tests/data/test_template_tokenization.py index 11e7b84bfc..9f44739b6b 100644 --- a/tests/data/test_template_tokenization.py +++ b/tests/data/test_template_tokenization.py @@ -66,9 +66,11 @@ def test_tokenize_chat_example_malformed(): my_tokenizer, ) # type: ignore (the typing here is supposed to be malformed) with pytest.raises(InvalidExampleTypeError): - tokenize_formatted_example( - wrong_example_type, - my_tokenizer, + # Ignore the type here because it's the mistyping that we're + # trying to test. + tokenize_formatted_example( # type: ignore + wrong_example_type, # type: ignore + my_tokenizer, # type: ignore ) with pytest.raises(InvalidMessageTypeError): tokenize_formatted_example(