diff --git a/docs/source/data_utils.mdx b/docs/source/data_utils.mdx
index 6bbfc5b32d..9b8391278d 100644
--- a/docs/source/data_utils.mdx
+++ b/docs/source/data_utils.mdx
@@ -1,15 +1,29 @@
-## Data Utilities
+# Data Utilities
+
+## is_conversational
 
 [[autodoc]] is_conversational
 
+## apply_chat_template
+
 [[autodoc]] apply_chat_template
 
+## maybe_apply_chat_template
+
 [[autodoc]] maybe_apply_chat_template
 
+## extract_prompt
+
 [[autodoc]] extract_prompt
 
+## maybe_extract_prompt
+
 [[autodoc]] maybe_extract_prompt
 
+## unpair_preference_dataset
+
 [[autodoc]] unpair_preference_dataset
 
+## maybe_unpair_preference_dataset
+
 [[autodoc]] maybe_unpair_preference_dataset
diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py
index a4eb13683c..aebf7396ac 100644
--- a/tests/test_data_utils.py
+++ b/tests/test_data_utils.py
@@ -17,7 +17,7 @@
 from datasets import Dataset, DatasetDict
 from parameterized import parameterized
-from transformers import AutoTokenizer
+from transformers import AutoProcessor, AutoTokenizer
 
 from trl.data_utils import (
     apply_chat_template,
@@ -196,6 +196,37 @@ def test_maybe_apply_chat_template(self, tokenizer_id, example):
         self.assertIsInstance(result["label"], bool)
         self.assertEqual(result["label"], example["label"])
 
+    def test_apply_chat_template_with_tools(self):
+        tokenizer = AutoProcessor.from_pretrained("trl-internal-testing/tiny-LlamaForCausalLM-3.2")
+
+        # Define dummy test tools
+        def get_current_temperature(location: str):
+            """
+            Gets the temperature at a given location.
+
+            Args:
+                location: The location to get the temperature for
+            """
+            return 22.0
+
+        # Define test case
+        test_case = {
+            "prompt": [
+                {"content": "What's the temperature in London?", "role": "user"},
+            ]
+        }
+        # Test with tools
+        result_with_tools = apply_chat_template(test_case, tokenizer, tools=[get_current_temperature])
+
+        # Verify tools are included in the output
+        self.assertIn("get_current_temperature", result_with_tools["prompt"])
+
+        # Test without tools
+        result_without_tools = apply_chat_template(test_case, tokenizer, tools=None)
+
+        # Verify tools are not included in the output
+        self.assertNotIn("get_current_temperature", result_without_tools["prompt"])
+
 
 class UnpairPreferenceDatasetTester(unittest.TestCase):
     paired_dataset = Dataset.from_dict(
diff --git a/trl/data_utils.py b/trl/data_utils.py
index 9bc68f1d5c..8c8f448adf 100644
--- a/trl/data_utils.py
+++ b/trl/data_utils.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Optional, Sequence, TypeVar
+from typing import Any, Callable, Optional, Sequence, TypeVar, Union
 
 from datasets import Dataset, DatasetDict
 from transformers import PreTrainedTokenizer
@@ -61,9 +61,13 @@ def is_conversational(example: dict[str, Any]) -> bool:
     return False
 
 
-def apply_chat_template(example: dict[str, list[dict[str, str]]], tokenizer: PreTrainedTokenizer) -> dict[str, str]:
+def apply_chat_template(
+    example: dict[str, list[dict[str, str]]],
+    tokenizer: PreTrainedTokenizer,
+    tools: Optional[list[Union[dict, Callable]]] = None,
+) -> dict[str, str]:
     r"""
-    Apply a chat template to a conversational example.
+    Apply a chat template to a conversational example, along with the schemas of the functions in `tools`.
 
     For more details, see [`maybe_apply_chat_template`].
""" @@ -82,30 +86,36 @@ def apply_chat_template(example: dict[str, list[dict[str, str]]], tokenizer: Pre # Apply the chat template to the whole conversation if "messages" in example: - messages = tokenizer.apply_chat_template(example["messages"], tokenize=False) + messages = tokenizer.apply_chat_template(example["messages"], tools=tools, tokenize=False) # Apply the chat template to the prompt, adding the generation prompt if "prompt" in example: - prompt = tokenizer.apply_chat_template(example["prompt"], tokenize=False, add_generation_prompt=True) + prompt = tokenizer.apply_chat_template( + example["prompt"], tools=tools, tokenize=False, add_generation_prompt=True + ) # Apply the chat template to the entire prompt + completion if "prompt" in example: # explicit prompt and prompt-completion case if "chosen" in example: - prompt_chosen = tokenizer.apply_chat_template(example["prompt"] + example["chosen"], tokenize=False) + prompt_chosen = tokenizer.apply_chat_template( + example["prompt"] + example["chosen"], tools=tools, tokenize=False + ) chosen = prompt_chosen[len(prompt) :] if "rejected" in example and "prompt" in example: # explicit prompt - prompt_rejected = tokenizer.apply_chat_template(example["prompt"] + example["rejected"], tokenize=False) + prompt_rejected = tokenizer.apply_chat_template( + example["prompt"] + example["rejected"], tools=tools, tokenize=False + ) rejected = prompt_rejected[len(prompt) :] if "completion" in example: prompt_completion = tokenizer.apply_chat_template( - example["prompt"] + example["completion"], tokenize=False + example["prompt"] + example["completion"], tools=tools, tokenize=False ) completion = prompt_completion[len(prompt) :] else: # implicit prompt case if "chosen" in example: - chosen = tokenizer.apply_chat_template(example["chosen"], tokenize=False) + chosen = tokenizer.apply_chat_template(example["chosen"], tools=tools, tokenize=False) if "rejected" in example: - rejected = tokenizer.apply_chat_template(example["rejected"], tokenize=False) + rejected = tokenizer.apply_chat_template(example["rejected"], tools=tools, tokenize=False) # Ensure that the prompt is the initial part of the prompt-completion string if "prompt" in example: @@ -140,7 +150,9 @@ def apply_chat_template(example: dict[str, list[dict[str, str]]], tokenizer: Pre def maybe_apply_chat_template( - example: dict[str, list[dict[str, str]]], tokenizer: PreTrainedTokenizer + example: dict[str, list[dict[str, str]]], + tokenizer: PreTrainedTokenizer, + tools: Optional[list[Union[dict, Callable]]] = None, ) -> dict[str, str]: r""" If the example is in a conversational format, apply a chat template to it. @@ -159,9 +171,11 @@ def maybe_apply_chat_template( For keys `"messages"`, `"prompt"`, `"chosen"`, `"rejected"`, and `"completion"`, the values are lists of messages, where each message is a dictionary with keys `"role"` and `"content"`. - tokenizer (`PreTrainedTokenizer`): The tokenizer to apply the chat template with. + tools (`Optional[list[Union[dict, Callable]]]`, *optional*, defaults to `None`): + A list of tools (callable functions) that will be accessible to the model. + If the template does not support function calling, this argument will have no effect Returns: `dict[str, str]`: The formatted example with the chat template applied. 
@@ -184,7 +198,7 @@ def maybe_apply_chat_template(
     ```
     """
     if is_conversational(example):
-        return apply_chat_template(example, tokenizer)
+        return apply_chat_template(example, tokenizer, tools)
     else:
         return example
diff --git a/trl/extras/dataset_formatting.py b/trl/extras/dataset_formatting.py
index 1be86580aa..69c73b219b 100644
--- a/trl/extras/dataset_formatting.py
+++ b/trl/extras/dataset_formatting.py
@@ -27,20 +27,24 @@
 }
 
 
-def conversations_formatting_function(tokenizer: AutoTokenizer, messages_field: Literal["messages", "conversations"]):
+def conversations_formatting_function(
+    tokenizer: AutoTokenizer, messages_field: Literal["messages", "conversations"], tools: Optional[list] = None
+):
     r"""
     return a callable function that takes in a "messages" dataset and returns a formatted dataset, based on the tokenizer
-    apply chat template to the dataset
+    apply chat template to the dataset, along with the schemas of the functions in the `tools` list.
     """
 
     def format_dataset(examples):
         if isinstance(examples[messages_field][0], list):
             output_texts = []
             for i in range(len(examples[messages_field])):
-                output_texts.append(tokenizer.apply_chat_template(examples[messages_field][i], tokenize=False))
+                output_texts.append(
+                    tokenizer.apply_chat_template(examples[messages_field][i], tokenize=False, tools=tools)
+                )
             return output_texts
         else:
-            return tokenizer.apply_chat_template(examples[messages_field], tokenize=False)
+            return tokenizer.apply_chat_template(examples[messages_field], tokenize=False, tools=tools)
 
     return format_dataset
@@ -72,7 +76,7 @@ def format_dataset(examples):
 
 
 def get_formatting_func_from_dataset(
-    dataset: Union[Dataset, ConstantLengthDataset], tokenizer: AutoTokenizer
+    dataset: Union[Dataset, ConstantLengthDataset], tokenizer: AutoTokenizer, tools: Optional[list] = None
 ) -> Optional[Callable]:
     r"""
     Finds the correct formatting function based on the dataset structure. Currently supported datasets are:
@@ -90,11 +94,11 @@ def get_formatting_func_from_dataset(
     if "messages" in dataset.features:
         if dataset.features["messages"] == FORMAT_MAPPING["chatml"]:
             logging.info("Formatting dataset with chatml format")
-            return conversations_formatting_function(tokenizer, "messages")
+            return conversations_formatting_function(tokenizer, "messages", tools)
     if "conversations" in dataset.features:
         if dataset.features["conversations"] == FORMAT_MAPPING["chatml"]:
             logging.info("Formatting dataset with chatml format")
-            return conversations_formatting_function(tokenizer, "conversations")
+            return conversations_formatting_function(tokenizer, "conversations", tools)
     elif dataset.features == FORMAT_MAPPING["instruction"]:
         logging.info("Formatting dataset with instruction format")
         return instructions_formatting_function(tokenizer)
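Taken together, the patch threads a `tools` list from TRL's data utilities down to `tokenizer.apply_chat_template`, which renders the function schemas into the prompt. A minimal usage sketch, assuming this patch is applied; the tiny checkpoint and the dummy `get_current_temperature` tool are borrowed from the test above:

```python
from transformers import AutoTokenizer

from trl.data_utils import apply_chat_template


def get_current_temperature(location: str):
    """
    Gets the temperature at a given location.

    Args:
        location: The location to get the temperature for
    """
    return 22.0


# Tiny test checkpoint used in the test suite above; any model whose chat
# template supports tool use should behave the same way.
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-LlamaForCausalLM-3.2")

example = {"prompt": [{"role": "user", "content": "What's the temperature in London?"}]}

# The schema of `get_current_temperature` is serialized into the rendered
# prompt so the model can see which functions it may call; with `tools=None`
# the schema is omitted.
result = apply_chat_template(example, tokenizer, tools=[get_current_temperature])
print(result["prompt"])
```

As the docstring notes, `tools` has no effect when the checkpoint's chat template does not support function calling.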
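The `dataset_formatting.py` changes expose the same knob to SFT-style formatting functions. A hedged sketch, again assuming the patch is applied; the one-row ChatML dataset and the tool definition are illustrative assumptions, not part of the patch:

```python
from datasets import Dataset
from transformers import AutoTokenizer

from trl.extras.dataset_formatting import get_formatting_func_from_dataset


def get_current_temperature(location: str):
    """
    Gets the temperature at a given location.

    Args:
        location: The location to get the temperature for
    """
    return 22.0


tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-LlamaForCausalLM-3.2")

# A minimal dataset whose "messages" feature matches the ChatML schema that
# get_formatting_func_from_dataset detects via FORMAT_MAPPING.
dataset = Dataset.from_dict(
    {"messages": [[{"role": "user", "content": "What's the temperature in London?"}]]}
)

formatting_func = get_formatting_func_from_dataset(dataset, tokenizer, tools=[get_current_temperature])
if formatting_func is not None:  # None means the dataset layout was not recognized
    print(formatting_func(dataset[0]))
```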