# Provenance: this content arrived as a git patch flattened onto one line.
#   diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py
#   index b89fcc4b37..90f748bea2 100644
#   hunk 1: @@ -8,8 +8,10 @@   (import changes)
#   hunk 2: @@ -60,3 +62,48 @@ (two new tests appended)
# What follows is the post-patch text of those hunks, reformatted.

import pytest

from llmfoundry.data.finetuning.tasks import (
    QA_format_preprocessor,
    _get_num_processes,
    dataset_constructor,
    messages_format_preprocessor,
)
from llmfoundry.utils.exceptions import DatasetTooSmallError

# NOTE(review): hunk 2 begins with context from the end of a pre-existing test
# (hunk header names `def get_local_world_size(self):`). Only its last three
# lines are visible and their indentation was lost, so they are recorded here
# verbatim instead of being reconstructed:
#     new=MockDataset, ): dataset_constructor.build_from_streaming()


def test_QA_format_preprocessor():
    """Check that a ``Q``/``A`` record is converted into chat ``messages``.

    The input also carries a ``meta`` key; the assertions require the output
    to contain exactly one key, ``messages``, holding a user turn built from
    ``Q`` followed by an assistant turn built from ``A``.
    """
    inp = {
        'Q': 'What is the capital of France?',
        'A': 'Paris',
        'meta': {
            'a': 'b',
        },
    }

    expected_messages = [{
        'role': 'user',
        'content': 'What is the capital of France?',
    }, {
        'role': 'assistant',
        'content': 'Paris',
    }]

    output = QA_format_preprocessor(inp)

    assert len(output) == 1
    assert 'messages' in output
    # Compare the full list in one assertion. Iterating over the output and
    # checking each key it happens to contain would pass vacuously for an
    # empty list, a truncated list, or messages missing 'role'/'content'.
    assert output['messages'] == expected_messages


def test_messages_format_preprocessor():
    """Check that an already chat-formatted record keeps only ``messages``.

    The input carries an extra ``other_key``; the assertions require the
    output to contain exactly one key, ``messages``, equal to the input list.
    """
    messages = [{
        'role': 'user',
        'content': 'What is the capital of France?',
    }, {
        'role': 'assistant',
        'content': 'Paris',
    }]
    inp = {
        'messages': messages,
        'other_key': 'other_value',
    }

    output = messages_format_preprocessor(inp)

    assert len(output) == 1
    assert 'messages' in output
    assert output['messages'] == messages