Skip to content

Commit

Permalink
Add special errors for bad chat/ift types (#1437)
Browse files Browse the repository at this point in the history
  • Loading branch information
milocress authored Aug 8, 2024
1 parent f006d07 commit 805cf83
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 15 deletions.
21 changes: 8 additions & 13 deletions llmfoundry/data/finetuning/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,10 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
ConsecutiveRepeatedChatRolesError,
IncorrectMessageKeyQuantityError,
InvalidContentTypeError,
InvalidExampleTypeError,
InvalidFileExtensionError,
InvalidLastChatMessageRoleError,
InvalidMessageTypeError,
InvalidPromptResponseKeysError,
InvalidPromptTypeError,
InvalidResponseTypeError,
Expand Down Expand Up @@ -139,9 +141,7 @@ def _get_example_type(example: Example) -> ExampleType:
KeyError: If the example type is unknown.
"""
if not isinstance(example, Mapping):
raise TypeError(
f'Expected example to be a Mapping, but found {type(example)}',
)
raise InvalidExampleTypeError(str(type(example)))
if (
len(example.keys()) == 1 and any(
allowed_message_key in example
Expand All @@ -156,7 +156,8 @@ def _get_example_type(example: Example) -> ExampleType:
):
return 'prompt_response'
else:
raise UnknownExampleTypeError(str(example.keys()))
keys = str(set(example.keys()))
raise UnknownExampleTypeError(keys)


def _is_empty_or_nonexistent(dirpath: str) -> bool:
Expand All @@ -173,23 +174,17 @@ def _is_empty_or_nonexistent(dirpath: str) -> bool:

def _get_key(dictionary: Mapping[str, Any], allowed_keys: set[str]):
if not isinstance(dictionary, Mapping):
raise TypeError(
f'Expected dictionary to be a mapping, but found {type(dictionary)}',
)
raise InvalidExampleTypeError(str(type(dictionary)))
desired_keys = allowed_keys.intersection(dictionary.keys())
return list(desired_keys)[0]


def _validate_chat_formatted_example(example: ChatFormattedDict):
if not isinstance(example, Mapping):
raise TypeError(
f'Expected example to be a mapping, but found {type(example)}',
)
raise InvalidExampleTypeError(str(type(example)))
messages = example[_get_key(example, ALLOWED_MESSAGES_KEYS)]
if not isinstance(messages, List):
raise TypeError(
f'Expected messages to be an iterable, but found {type(messages)}',
)
raise InvalidMessageTypeError(str(type(messages)))
if len(messages) <= 1:
raise NotEnoughChatDataError()

Expand Down
16 changes: 16 additions & 0 deletions llmfoundry/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,22 @@ def __init__(


## Tasks exceptions
class InvalidExampleTypeError(UserError):
"""Error thrown when a message type is not a `Mapping`."""

def __init__(self, example_type: str) -> None:
message = f'Expected example to be a `Mapping`, but found type {example_type}'
super().__init__(message, example_type=example_type)


class InvalidMessageTypeError(UserError):
"""Error thrown when a message type is not an `Iterable`."""

def __init__(self, message_type: str) -> None:
message = f'Expected message to be an `Iterable`, but found type {message_type}'
super().__init__(message, message_type=message_type)


class UnknownExampleTypeError(UserError):
"""Error thrown when an unknown example type is used in a task."""

Expand Down
18 changes: 16 additions & 2 deletions tests/data/test_template_tokenization.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
ALLOWED_PROMPT_KEYS,
ALLOWED_RESPONSE_KEYS,
ChatTemplateError,
InvalidExampleTypeError,
InvalidMessageTypeError,
)


Expand Down Expand Up @@ -48,13 +50,13 @@ def test_tokenize_chat_example_malformed():
'content': 'user message not followed by an assistant label',
}],
}
wrong_type = {'messages': 'this is not a list of messages'}
wrong_example_type = ['this is not a dictionary']
wrong_messages_type = {'messages': 'this is not a list of messages'}
malformed_chat_examples = [
too_few_messages,
no_content,
ends_with_user_role,
no_assistant_message,
wrong_type,
]
my_tokenizer = build_tokenizer('mosaicml/mpt-7b-8k-chat', {})
for example in malformed_chat_examples:
Expand All @@ -63,6 +65,18 @@ def test_tokenize_chat_example_malformed():
example,
my_tokenizer,
) # type: ignore (the typing here is supposed to be malformed)
with pytest.raises(InvalidExampleTypeError):
# Ignore the type here because it's the mistyping that we're
# trying to test.
tokenize_formatted_example( # type: ignore
wrong_example_type, # type: ignore
my_tokenizer, # type: ignore
)
with pytest.raises(InvalidMessageTypeError):
tokenize_formatted_example(
wrong_messages_type,
my_tokenizer,
)


def test_tokenize_chat_example_well_formed():
Expand Down

0 comments on commit 805cf83

Please sign in to comment.