diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 6b4bd25936..a97914da2a 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -1161,3 +1161,31 @@ def shareGPT_format_preprocessor(inp: dict) -> ChatFormattedDict: except Exception as e: raise UnableToProcessPromptResponseError(inp) from e return {'messages': messages} + + +@dataset_constructor.register('math-ai/StackMathQA') +def QA_format_preprocessor(inp: dict) -> ChatFormattedDict: + """Convert from QA format to our chat format.""" + try: + Q = inp['Q'] + A = inp['A'] + messages: list[dict[str, str]] = [{ + 'role': 'user', + 'content': Q, + }, { + 'role': 'assistant', + 'content': A, + }] + except Exception as e: + raise UnableToProcessPromptResponseError(inp) from e + return {'messages': messages} + + +@dataset_constructor.register('AI-MO/NuminaMath-CoT') +def messages_format_preprocessor(inp: dict) -> ChatFormattedDict: + """Convert from QA format to our chat format.""" + try: + messages = inp['messages'] + except Exception as e: + raise UnableToProcessPromptResponseError(inp) from e + return {'messages': messages} diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index b89fcc4b37..90f748bea2 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -8,8 +8,10 @@ import pytest from llmfoundry.data.finetuning.tasks import ( + QA_format_preprocessor, _get_num_processes, dataset_constructor, + messages_format_preprocessor, ) from llmfoundry.utils.exceptions import DatasetTooSmallError @@ -60,3 +62,48 @@ def get_local_world_size(self): new=MockDataset, ): dataset_constructor.build_from_streaming() + + +def test_QA_format_preprocessor(): + inp = { + 'Q': 'What is the capital of France?', + 'A': 'Paris', + 'meta': { + 'a': 'b', + }, + } + + expected_messages = [{ + 'role': 'user', + 'content': 'What is the capital of France?', + }, { + 'role': 'assistant', + 'content': 'Paris', + }] + output = QA_format_preprocessor(inp) + assert len(output) == 1 + assert 'messages' in output + for i, message in enumerate(output['messages']): + expected_message = expected_messages[i] + for k, v in message.items(): + assert k in expected_message + assert v == expected_message[k] + + +def test_messages_format_preprocessor(): + messages = [{ + 'role': 'user', + 'content': 'What is the capital of France?', + }, { + 'role': 'assistant', + 'content': 'Paris', + }] + inp = { + 'messages': messages, + 'other_key': 'other_value', + } + + output = messages_format_preprocessor(inp) + assert len(output) == 1 + assert 'messages' in output + assert output['messages'] == messages