🔨 Support for tools for data utils (#2455)
* function calling training support for SFTTrainer

* adding tool support to data_utils

* adding test for function calling tokenizer

* reverting changes to sfttrainer and config, added maybe_apply_chat_template

* arg for maybe_apply_chat_template docstring

* Doc sectioning

* minor test modification

* minor doc modification

---------

Co-authored-by: Quentin Gallouédec <[email protected]>
3 people authored Dec 12, 2024
1 parent b3aff44 commit e3e171a
Showing 4 changed files with 85 additions and 22 deletions.
16 changes: 15 additions & 1 deletion docs/source/data_utils.mdx
@@ -1,15 +1,29 @@
## Data Utilities
# Data Utilities

## is_conversational

[[autodoc]] is_conversational

## apply_chat_template

[[autodoc]] apply_chat_template

## maybe_apply_chat_template

[[autodoc]] maybe_apply_chat_template

## extract_prompt

[[autodoc]] extract_prompt

## maybe_extract_prompt

[[autodoc]] maybe_extract_prompt

## unpair_preference_dataset

[[autodoc]] unpair_preference_dataset

## maybe_unpair_preference_dataset

[[autodoc]] maybe_unpair_preference_dataset
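
For orientation, here is a minimal usage sketch of `maybe_apply_chat_template` with the new `tools` argument. The checkpoint and the example tool are illustrative assumptions, not part of this commit; the assertions mirror the behaviour exercised by the new test below.

```python
# Illustrative sketch only: the checkpoint and the tool are assumptions, not part of this commit.
from transformers import AutoTokenizer

from trl.data_utils import maybe_apply_chat_template


def get_current_temperature(location: str):
    """
    Gets the temperature at a given location.

    Args:
        location: The location to get the temperature for
    """
    return 22.0


tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # assumed tool-aware checkpoint

# Conversational example: the chat template is applied and the tool schema is injected.
example = {"prompt": [{"role": "user", "content": "What's the temperature in London?"}]}
formatted = maybe_apply_chat_template(example, tokenizer, tools=[get_current_temperature])
assert "get_current_temperature" in formatted["prompt"]

# Non-conversational example: returned unchanged.
plain = {"prompt": "What's the temperature in London?"}
assert maybe_apply_chat_template(plain, tokenizer, tools=[get_current_temperature]) == plain
```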
33 changes: 32 additions & 1 deletion tests/test_data_utils.py
@@ -17,7 +17,7 @@

from datasets import Dataset, DatasetDict
from parameterized import parameterized
from transformers import AutoTokenizer
from transformers import AutoProcessor, AutoTokenizer

from trl.data_utils import (
apply_chat_template,
@@ -196,6 +196,37 @@ def test_maybe_apply_chat_template(self, tokenizer_id, example):
self.assertIsInstance(result["label"], bool)
self.assertEqual(result["label"], example["label"])

def test_apply_chat_template_with_tools(self):
tokenizer = AutoProcessor.from_pretrained("trl-internal-testing/tiny-LlamaForCausalLM-3.2")

# Define dummy test tools
def get_current_temperature(location: str):
"""
Gets the temperature at a given location.
Args:
location: The location to get the temperature for
"""
return 22.0

# Define test case
test_case = {
"prompt": [
{"content": "Whats the temperature in London?", "role": "user"},
]
}
# Test with tools
result_with_tools = apply_chat_template(test_case, tokenizer, tools=[get_current_temperature])

# Verify tools are included in the output
self.assertIn("get_current_temperature", result_with_tools["prompt"])

# Test without tools
result_without_tools = apply_chat_template(test_case, tokenizer, tools=None)

# Verify tools are not included in the output
self.assertNotIn("get_current_temperature", result_without_tools["prompt"])


class UnpairPreferenceDatasetTester(unittest.TestCase):
paired_dataset = Dataset.from_dict(
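
The new test passes a plain Python callable as a tool. When a callable is given, `transformers` derives a JSON schema for it from the type hints and the Google-style docstring before rendering the template. Below is a small sketch of that conversion, assuming a `transformers` release that exposes `get_json_schema` from `transformers.utils`.

```python
# Sketch of how a callable tool becomes the schema seen by the chat template.
# Assumes a transformers release that ships transformers.utils.get_json_schema.
from transformers.utils import get_json_schema


def get_current_temperature(location: str):
    """
    Gets the temperature at a given location.

    Args:
        location: The location to get the temperature for
    """
    return 22.0


schema = get_json_schema(get_current_temperature)
print(schema["function"]["name"])        # expected: "get_current_temperature"
print(schema["function"]["parameters"])  # JSON schema built from the type hint and docstring
```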
40 changes: 27 additions & 13 deletions trl/data_utils.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional, Sequence, TypeVar
from typing import Any, Callable, Optional, Sequence, TypeVar, Union

from datasets import Dataset, DatasetDict
from transformers import PreTrainedTokenizer
@@ -61,9 +61,13 @@ def is_conversational(example: dict[str, Any]) -> bool:
return False


def apply_chat_template(example: dict[str, list[dict[str, str]]], tokenizer: PreTrainedTokenizer) -> dict[str, str]:
def apply_chat_template(
example: dict[str, list[dict[str, str]]],
tokenizer: PreTrainedTokenizer,
tools: Optional[list[Union[dict, Callable]]] = None,
) -> dict[str, str]:
r"""
Apply a chat template to a conversational example.
Apply a chat template to a conversational example along with the schema for a list of functions in `tools`.
For more details, see [`maybe_apply_chat_template`].
"""
@@ -82,30 +86,36 @@ def apply_chat_template(example: dict[str, list[dict[str, str]]], tokenizer: Pre

# Apply the chat template to the whole conversation
if "messages" in example:
messages = tokenizer.apply_chat_template(example["messages"], tokenize=False)
messages = tokenizer.apply_chat_template(example["messages"], tools=tools, tokenize=False)

# Apply the chat template to the prompt, adding the generation prompt
if "prompt" in example:
prompt = tokenizer.apply_chat_template(example["prompt"], tokenize=False, add_generation_prompt=True)
prompt = tokenizer.apply_chat_template(
example["prompt"], tools=tools, tokenize=False, add_generation_prompt=True
)

# Apply the chat template to the entire prompt + completion
if "prompt" in example: # explicit prompt and prompt-completion case
if "chosen" in example:
prompt_chosen = tokenizer.apply_chat_template(example["prompt"] + example["chosen"], tokenize=False)
prompt_chosen = tokenizer.apply_chat_template(
example["prompt"] + example["chosen"], tools=tools, tokenize=False
)
chosen = prompt_chosen[len(prompt) :]
if "rejected" in example and "prompt" in example: # explicit prompt
prompt_rejected = tokenizer.apply_chat_template(example["prompt"] + example["rejected"], tokenize=False)
prompt_rejected = tokenizer.apply_chat_template(
example["prompt"] + example["rejected"], tools=tools, tokenize=False
)
rejected = prompt_rejected[len(prompt) :]
if "completion" in example:
prompt_completion = tokenizer.apply_chat_template(
example["prompt"] + example["completion"], tokenize=False
example["prompt"] + example["completion"], tools=tools, tokenize=False
)
completion = prompt_completion[len(prompt) :]
else: # implicit prompt case
if "chosen" in example:
chosen = tokenizer.apply_chat_template(example["chosen"], tokenize=False)
chosen = tokenizer.apply_chat_template(example["chosen"], tools=tools, tokenize=False)
if "rejected" in example:
rejected = tokenizer.apply_chat_template(example["rejected"], tokenize=False)
rejected = tokenizer.apply_chat_template(example["rejected"], tools=tools, tokenize=False)

# Ensure that the prompt is the initial part of the prompt-completion string
if "prompt" in example:
@@ -140,7 +150,9 @@ def apply_chat_template(example: dict[str, list[dict[str, str]]], tokenizer: Pre


def maybe_apply_chat_template(
example: dict[str, list[dict[str, str]]], tokenizer: PreTrainedTokenizer
example: dict[str, list[dict[str, str]]],
tokenizer: PreTrainedTokenizer,
tools: Optional[list[Union[dict, Callable]]] = None,
) -> dict[str, str]:
r"""
If the example is in a conversational format, apply a chat template to it.
@@ -159,9 +171,11 @@ def maybe_apply_chat_template(
For keys `"messages"`, `"prompt"`, `"chosen"`, `"rejected"`, and `"completion"`, the values are lists of
messages, where each message is a dictionary with keys `"role"` and `"content"`.
tokenizer (`PreTrainedTokenizer`):
The tokenizer to apply the chat template with.
tools (`Optional[list[Union[dict, Callable]]]`, *optional*, defaults to `None`):
A list of tools (callable functions) that will be accessible to the model.
If the template does not support function calling, this argument will have no effect
Returns:
`dict[str, str]`: The formatted example with the chat template applied.
Expand All @@ -184,7 +198,7 @@ def maybe_apply_chat_template(
```
"""
if is_conversational(example):
return apply_chat_template(example, tokenizer)
return apply_chat_template(example, tokenizer, tools)
else:
return example

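
Because the parameter is typed `Optional[list[Union[dict, Callable]]]`, pre-built JSON schemas can be passed instead of callables. Here is a dataset-level sketch of forwarding such a schema through `Dataset.map`; the checkpoint, the column layout, and the hand-written schema are illustrative assumptions.

```python
# Illustrative sketch: forward a hand-written (dict) tool schema through Dataset.map.
# The checkpoint, column layout, and schema are assumptions, not part of this commit.
from datasets import Dataset
from transformers import AutoTokenizer

from trl.data_utils import maybe_apply_chat_template

temperature_tool = {
    "type": "function",
    "function": {
        "name": "get_current_temperature",
        "description": "Gets the temperature at a given location.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {"type": "string", "description": "The location to get the temperature for"}
            },
            "required": ["location"],
        },
    },
}

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # assumed tool-aware checkpoint
dataset = Dataset.from_dict(
    {"prompt": [[{"role": "user", "content": "What's the temperature in London?"}]]}
)
dataset = dataset.map(
    maybe_apply_chat_template,
    fn_kwargs={"tokenizer": tokenizer, "tools": [temperature_tool]},
)
print(dataset[0]["prompt"])  # rendered prompt string containing the tool schema
```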
18 changes: 11 additions & 7 deletions trl/extras/dataset_formatting.py
@@ -27,20 +27,24 @@
}


def conversations_formatting_function(tokenizer: AutoTokenizer, messages_field: Literal["messages", "conversations"]):
def conversations_formatting_function(
tokenizer: AutoTokenizer, messages_field: Literal["messages", "conversations"], tools: Optional[list] = None
):
r"""
return a callable function that takes in a "messages" dataset and returns a formatted dataset, based on the tokenizer
apply chat template to the dataset
apply chat template to the dataset along with the schema of the list of functions in the tools list.
"""

def format_dataset(examples):
if isinstance(examples[messages_field][0], list):
output_texts = []
for i in range(len(examples[messages_field])):
output_texts.append(tokenizer.apply_chat_template(examples[messages_field][i], tokenize=False))
output_texts.append(
tokenizer.apply_chat_template(examples[messages_field][i], tokenize=False, tools=tools)
)
return output_texts
else:
return tokenizer.apply_chat_template(examples[messages_field], tokenize=False)
return tokenizer.apply_chat_template(examples[messages_field], tokenize=False, tools=tools)

return format_dataset

@@ -72,7 +76,7 @@ def format_dataset(examples):


def get_formatting_func_from_dataset(
dataset: Union[Dataset, ConstantLengthDataset], tokenizer: AutoTokenizer
dataset: Union[Dataset, ConstantLengthDataset], tokenizer: AutoTokenizer, tools: Optional[list] = None
) -> Optional[Callable]:
r"""
Finds the correct formatting function based on the dataset structure. Currently supported datasets are:
@@ -90,11 +94,11 @@ def get_formatting_func_from_dataset(
if "messages" in dataset.features:
if dataset.features["messages"] == FORMAT_MAPPING["chatml"]:
logging.info("Formatting dataset with chatml format")
return conversations_formatting_function(tokenizer, "messages")
return conversations_formatting_function(tokenizer, "messages", tools)
if "conversations" in dataset.features:
if dataset.features["conversations"] == FORMAT_MAPPING["chatml"]:
logging.info("Formatting dataset with chatml format")
return conversations_formatting_function(tokenizer, "conversations")
return conversations_formatting_function(tokenizer, "conversations", tools)
elif dataset.features == FORMAT_MAPPING["instruction"]:
logging.info("Formatting dataset with instruction format")
return instructions_formatting_function(tokenizer)
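
For completeness, a sketch of driving the extended `get_formatting_func_from_dataset` with tools. The checkpoint and dataset below are assumptions; the function returns `None` when the dataset features do not match one of the recognized formats, so the result is guarded accordingly.

```python
# Sketch (assumed checkpoint and data): a ChatML-style dataset picks up the conversational
# formatting function, which now forwards `tools` to the chat template.
from datasets import Dataset
from transformers import AutoTokenizer

from trl.extras.dataset_formatting import get_formatting_func_from_dataset


def get_current_temperature(location: str):
    """
    Gets the temperature at a given location.

    Args:
        location: The location to get the temperature for
    """
    return 22.0


tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # assumed checkpoint
dataset = Dataset.from_dict(
    {
        "messages": [
            [
                {"content": "What's the temperature in London?", "role": "user"},
                {"content": "Let me check.", "role": "assistant"},
            ]
        ]
    }
)

formatting_func = get_formatting_func_from_dataset(dataset, tokenizer, tools=[get_current_temperature])
if formatting_func is not None:  # None when the dataset layout is not recognized
    print(formatting_func(dataset[:1]))
```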
