From b9322862b75f5d4fa995fc564b96389899048aa7 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Wed, 27 Nov 2024 12:40:26 -0800 Subject: [PATCH 1/5] .. --- llmfoundry/data/finetuning/tasks.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 6b4bd25936..bda19c9b5e 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -1161,3 +1161,21 @@ def shareGPT_format_preprocessor(inp: dict) -> ChatFormattedDict: except Exception as e: raise UnableToProcessPromptResponseError(inp) from e return {'messages': messages} + + +@dataset_constructor.register('abc/def') +def QA_format_preprocessor(inp: dict) -> ChatFormattedDict: + """Convert from QA format to our chat format.""" + try: + Q = inp['Q'] + A = inp['A'] + messages: list[dict[str, str]] = [{ + 'role': 'user', + 'content': Q, + }, { + 'role': 'assistant', + 'content': A, + }] + except Exception as e: + raise UnableToProcessPromptResponseError(inp) from e + return {'messages': messages} From a0b08a465a255bf4529b70bd09f407211a82cc2a Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Wed, 27 Nov 2024 12:45:19 -0800 Subject: [PATCH 2/5] .. 
--- llmfoundry/data/finetuning/tasks.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index bda19c9b5e..02af7c09f3 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -1179,3 +1179,13 @@ def QA_format_preprocessor(inp: dict) -> ChatFormattedDict: except Exception as e: raise UnableToProcessPromptResponseError(inp) from e return {'messages': messages} + + +@dataset_constructor.register('abc/msg') +def messages_format_preprocessor(inp: dict) -> ChatFormattedDict: + """Convert from messages format to our chat format.""" + try: + messages = inp['messages'] + except Exception as e: + raise UnableToProcessPromptResponseError(inp) from e + return {'messages': messages} From a6eae451013f10ae97a485119483a4de987912ab Mon Sep 17 00:00:00 2001 From: Shashank Rajput <144760128+ShashankMosaicML@users.noreply.github.com> Date: Mon, 16 Dec 2024 21:16:50 -0800 Subject: [PATCH 3/5] Update tasks.py --- llmfoundry/data/finetuning/tasks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 02af7c09f3..7237e86c5c 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -1163,7 +1163,7 @@ def shareGPT_format_preprocessor(inp: dict) -> ChatFormattedDict: return {'messages': messages} -@dataset_constructor.register('abc/def') +@dataset_constructor.register('qa_format') def QA_format_preprocessor(inp: dict) -> ChatFormattedDict: """Convert from QA format to our chat format.""" try: @@ -1181,7 +1181,7 @@ def QA_format_preprocessor(inp: dict) -> ChatFormattedDict: return {'messages': messages} -@dataset_constructor.register('abc/msg') +@dataset_constructor.register('messages_format') def messages_format_preprocessor(inp: dict) -> ChatFormattedDict: """Convert from messages format to our chat format.""" try: From 
6367a43944621ac54cb9849871f7518fb4fe2bcb Mon Sep 17 00:00:00 2001 From: Shashank Rajput <144760128+ShashankMosaicML@users.noreply.github.com> Date: Mon, 16 Dec 2024 21:26:49 -0800 Subject: [PATCH 4/5] Update tasks.py --- llmfoundry/data/finetuning/tasks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 7237e86c5c..a97914da2a 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -1163,7 +1163,7 @@ def shareGPT_format_preprocessor(inp: dict) -> ChatFormattedDict: return {'messages': messages} -@dataset_constructor.register('qa_format') +@dataset_constructor.register('math-ai/StackMathQA') def QA_format_preprocessor(inp: dict) -> ChatFormattedDict: """Convert from QA format to our chat format.""" try: @@ -1181,7 +1181,7 @@ def QA_format_preprocessor(inp: dict) -> ChatFormattedDict: return {'messages': messages} -@dataset_constructor.register('messages_format') +@dataset_constructor.register('AI-MO/NuminaMath-CoT') def messages_format_preprocessor(inp: dict) -> ChatFormattedDict: """Convert from messages format to our chat format.""" try: From f4c2399111b551392e64ef74211a27aa8ee5ede1 Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Tue, 17 Dec 2024 14:35:49 -0500 Subject: [PATCH 5/5] adding tests --- tests/data/test_dataset.py | 47 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index b89fcc4b37..90f748bea2 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -8,8 +8,10 @@ import pytest from llmfoundry.data.finetuning.tasks import ( + QA_format_preprocessor, _get_num_processes, dataset_constructor, + messages_format_preprocessor, ) from llmfoundry.utils.exceptions import DatasetTooSmallError @@ -60,3 +62,48 @@ def get_local_world_size(self): new=MockDataset, ): dataset_constructor.build_from_streaming() + + +def 
test_QA_format_preprocessor(): + inp = { + 'Q': 'What is the capital of France?', + 'A': 'Paris', + 'meta': { + 'a': 'b', + }, + } + + expected_messages = [{ + 'role': 'user', + 'content': 'What is the capital of France?', + }, { + 'role': 'assistant', + 'content': 'Paris', + }] + output = QA_format_preprocessor(inp) + assert len(output) == 1 + assert 'messages' in output + for i, message in enumerate(output['messages']): + expected_message = expected_messages[i] + for k, v in message.items(): + assert k in expected_message + assert v == expected_message[k] + + +def test_messages_format_preprocessor(): + messages = [{ + 'role': 'user', + 'content': 'What is the capital of France?', + }, { + 'role': 'assistant', + 'content': 'Paris', + }] + inp = { + 'messages': messages, + 'other_key': 'other_value', + } + + output = messages_format_preprocessor(inp) + assert len(output) == 1 + assert 'messages' in output + assert output['messages'] == messages