Chat Template support for function calling and RAG #30621

Merged · 69 commits · Jun 11, 2024

Commits
3ad5bd4
First draft, still missing automatic function conversion
Rocketknight1 May 2, 2024
59e2cb6
First draft of the automatic schema generator
Rocketknight1 May 3, 2024
8f7655d
Lots of small fixes
Rocketknight1 May 3, 2024
0b2ead3
the walrus has betrayed me
Rocketknight1 May 3, 2024
cb67fd2
please stop committing your debug breakpoints
Rocketknight1 May 3, 2024
41df7d1
Lots of cleanup and edge cases, looking better now
Rocketknight1 May 3, 2024
eec2486
Comments and bugfixes for the type hint parser
Rocketknight1 May 7, 2024
cf2b8da
More cleanup
Rocketknight1 May 7, 2024
ad0984b
Add tests, update schema generator
Rocketknight1 May 7, 2024
c9fb3de
Update tests, proper handling of return values
Rocketknight1 May 8, 2024
10be3b1
Small docstring change
Rocketknight1 May 8, 2024
d9e6454
More doc updates
Rocketknight1 May 8, 2024
80addcf
More doc updates
Rocketknight1 May 8, 2024
d3d677b
Add json_schema decorator
Rocketknight1 May 9, 2024
a7d241c
Clean up the TODOs and finish the docs
Rocketknight1 May 9, 2024
fad9ae2
self.maxDiff = None to see the whole diff for the nested list test
Rocketknight1 May 9, 2024
d202bfe
add import for add_json_schema
Rocketknight1 May 9, 2024
11cf8f5
Quick test fix
Rocketknight1 May 9, 2024
8962c42
Fix something that was bugging me in the chat template docstring
Rocketknight1 May 9, 2024
6f4a897
Less "anyOf" when unnecessary
Rocketknight1 May 10, 2024
6462de2
Support return types for the templates that need them
Rocketknight1 May 16, 2024
a49b68e
Proper return type tests
Rocketknight1 May 16, 2024
c8a021e
Switch to Google format docstrings
Rocketknight1 May 17, 2024
69b6d31
Update chat templating docs to match new format
Rocketknight1 May 17, 2024
7f20d44
Stop putting the return type in with the other parameters
Rocketknight1 May 21, 2024
c3cf872
Add Tuple support
Rocketknight1 May 22, 2024
098780d
No more decorator - we just do it implicitly!
Rocketknight1 May 23, 2024
0ad7fde
Add enum support to get_json_schema
Rocketknight1 May 23, 2024
90a3c5b
Update docstring
Rocketknight1 May 23, 2024
ab2e741
Add copyright header
Rocketknight1 May 23, 2024
b2563c2
Update src/transformers/tokenization_utils_base.py
Rocketknight1 May 23, 2024
24c0589
Update docs/source/en/chat_templating.md
Rocketknight1 May 23, 2024
49f1e97
Update src/transformers/utils/chat_template_utils.py
Rocketknight1 May 23, 2024
1036c5a
Update src/transformers/utils/chat_template_utils.py
Rocketknight1 May 23, 2024
cdbc9bc
Add copyright header
Rocketknight1 May 23, 2024
dbb157f
make fixup
Rocketknight1 May 23, 2024
6171a9f
Fix indentation
Rocketknight1 May 23, 2024
098d6e6
Reformat chat_template_utils
Rocketknight1 May 23, 2024
0a92408
Correct return value
Rocketknight1 May 23, 2024
a4dc182
Make regexes module-level
Rocketknight1 May 23, 2024
390596e
Support more complex, multi-line arg docstrings
Rocketknight1 May 24, 2024
5568cb3
Update error message for ...
Rocketknight1 May 24, 2024
ae443e2
Update ruff
Rocketknight1 May 24, 2024
85cba4e
Add document type validation
Rocketknight1 May 24, 2024
23692c3
Refactor docs
Rocketknight1 May 24, 2024
6ebad3a
Refactor docs
Rocketknight1 May 24, 2024
19a9624
Refactor docs
Rocketknight1 May 24, 2024
28b5bc2
Clean up Tuple error
Rocketknight1 May 24, 2024
880bc99
Add an extra test for very complex defs and docstrings and clean ever…
Rocketknight1 May 24, 2024
58f69db
Document enum block
Rocketknight1 May 24, 2024
df3123e
Quick test fixes
Rocketknight1 May 24, 2024
cfb1190
Stop supporting type hints in docstring to fix bugs and simplify the …
Rocketknight1 May 28, 2024
fc66acc
Update docs for the regex change
Rocketknight1 May 28, 2024
bb6ba18
Clean up enum regex
Rocketknight1 May 28, 2024
eda660f
Wrap functions in {"type": "function", "function": ...}
Rocketknight1 May 29, 2024
50c00e4
Update src/transformers/utils/chat_template_utils.py
Rocketknight1 May 28, 2024
7b396ec
Temporary tool calling commit
Rocketknight1 May 31, 2024
f64acfd
Add type hints to chat template utils, partially update docs (incompl…
Rocketknight1 Jun 5, 2024
4e839cb
Code cleanup based on @molbap's suggestion
Rocketknight1 Jun 5, 2024
8d7ca8f
Add comments to explain regexes
Rocketknight1 Jun 5, 2024
3fbd682
Fix up type parsing for unions and lists
Rocketknight1 Jun 5, 2024
76c3320
Add custom exception types and adjust tests to look for them
Rocketknight1 Jun 6, 2024
170e367
Update docs with a demo!
Rocketknight1 Jun 6, 2024
1d4952a
Docs cleanup
Rocketknight1 Jun 7, 2024
9f71f6c
Pass content as string
Rocketknight1 Jun 7, 2024
6c29d46
Update tool call formatting
Rocketknight1 Jun 10, 2024
0ab8c7f
Update docs with new function format
Rocketknight1 Jun 11, 2024
2b3e222
Update docs
Rocketknight1 Jun 11, 2024
2b19b0a
Update docs with a second tool to show the model choosing correctly
Rocketknight1 Jun 11, 2024
326 changes: 326 additions & 0 deletions docs/source/en/chat_templating.md
@@ -233,6 +233,332 @@ The sun.</s>

From here, just continue training like you would with a standard language modelling task, using the `formatted_chat` column.
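As a minimal sketch, assuming a Hugging Face `dataset` with the `formatted_chat` column produced above (the variable names here are illustrative):

```python
# Tokenize the formatted chats for a standard language modelling pipeline.
# `dataset` and the "formatted_chat" column are assumed from the section above.
tokenized_dataset = dataset.map(
    lambda example: tokenizer(example["formatted_chat"], truncation=True)
)
```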

## Advanced: Extra inputs to chat templates

The only argument that `apply_chat_template` requires is `messages`. However, you can pass any keyword
argument to `apply_chat_template` and it will be accessible inside the template. This gives you a lot of freedom to use
chat templates for many things. There are no restrictions on the names or the format of these arguments - you can pass
strings, lists, dicts or whatever else you want.

> **Review comment (Collaborator):** No restrictions 😬

That said, there are some common use-cases for these extra arguments,
such as passing tools for function calling, or documents for retrieval-augmented generation. In these common cases,
we have some opinionated recommendations about what the names and formats of these arguments should be, which are
described in the sections below. We encourage model authors to make their chat templates compatible with this format,
to make it easy to transfer tool-calling code between models.
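For example, any extra keyword argument you pass is exposed to the template as a variable of the same name. The `today` argument below is purely illustrative, and only has an effect if the model's template actually references it:

```python
# "today" is a hypothetical extra input: a template could render it as {{ today }}
model_input = tokenizer.apply_chat_template(
    messages,
    today="11 June 2024",
)
```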

## Advanced: Tool use / function calling

"Tool use" LLMs can choose to call functions as external tools before generating an answer. When passing tools
to a tool-use model, you can simply pass a list of functions to the `tools` argument:

```python
from datetime import datetime

def current_time():
    """Get the current local time as a string."""
    return str(datetime.now())

def multiply(a: float, b: float):
    """
    A function that multiplies two numbers

    Args:
        a: The first number to multiply
        b: The second number to multiply
    """
    return a * b

tools = [current_time, multiply]

model_input = tokenizer.apply_chat_template(
    messages,
    tools=tools
)
```

For this to work, you should write your functions in the format above, so that they can be parsed
correctly as tools. Specifically, you should follow these rules:

- The function should have a descriptive name
- Every argument must have a type hint
- The function must have a docstring in the standard Google style: an initial function description
followed by an `Args:` block that describes the arguments (unless the function does not have any arguments).
- Do not include types in the `Args:` block. In other words, write `a: The first number to multiply`, not
`a (int): The first number to multiply`. Type hints should go in the function header instead.
- The function can have a return type and a `Returns:` block in the docstring. However, these are optional
because most tool-use models ignore them.

### Passing tool results to the model

The sample code above is enough to list the available tools for your model, but what happens if the model actually
wants to use one? If that happens, you should:

1. Parse the model's output to get the tool name(s) and arguments.
2. Add the model's tool call(s) to the conversation.
3. Call the corresponding function(s) with those arguments.
4. Add the result(s) to the conversation.

A minimal sketch of this loop is shown below; the next section then walks through a complete example step by step.
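This sketch assumes the Hermes-style `<tool_call>` output format shown later in this document; other models emit tool calls differently. The helper name `handle_tool_calls` is purely illustrative, and the sketch omits the `tool_call_id` handling discussed below.

```python
import json
import re

def handle_tool_calls(generated_text, messages, tools):
    """Parse <tool_call> blocks, run the matching tools, and append the results."""
    available_tools = {fn.__name__: fn for fn in tools}
    # 1. Parse the model's output to get the tool name(s) and arguments
    calls = [
        json.loads(block)
        for block in re.findall(r"<tool_call>\s*(.*?)\s*</tool_call>", generated_text, re.DOTALL)
    ]
    if not calls:
        return False  # The model answered directly, with no tool calls
    # 2. Add the model's tool call(s) to the conversation
    messages.append({
        "role": "assistant",
        "tool_calls": [{"type": "function", "function": call} for call in calls],
    })
    for call in calls:
        # 3. Call the corresponding function with the arguments the model generated
        result = available_tools[call["name"]](**call["arguments"])
        # 4. Add the result to the conversation
        messages.append({"role": "tool", "name": call["name"], "content": str(result)})
    return True
```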

### A complete tool use example

Let's walk through a tool use example, step by step. For this example, we will use an 8B `Hermes-2-Pro` model,
as it is one of the highest-performing tool-use models in its size category at the time of writing. If you have the
memory, you can consider using a larger model instead, like [Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-v01)
or [Mixtral-8x22B](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1), both of which also support tool use
and offer even stronger performance.

First, let's load our model and tokenizer:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "NousResearch/Hermes-2-Pro-Llama-3-8B"

tokenizer = AutoTokenizer.from_pretrained(checkpoint, revision="pr/13")
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, device_map="auto")
```

Next, let's define a list of tools:

```python
def get_current_temperature(location: str, unit: str) -> float:
    """
    Get the current temperature at a location.

    Args:
        location: The location to get the temperature for, in the format "City, Country"
        unit: The unit to return the temperature in. (choices: ["celsius", "fahrenheit"])

    Returns:
        The current temperature at the specified location in the specified units, as a float.
    """
    return 22.  # A real function should probably actually get the temperature!

def get_current_wind_speed(location: str) -> float:
    """
    Get the current wind speed in km/h at a given location.

    Args:
        location: The location to get the wind speed for, in the format "City, Country"

    Returns:
        The current wind speed at the given location in km/h, as a float.
    """
    return 6.  # A real function should probably actually get the wind speed!

tools = [get_current_temperature, get_current_wind_speed]
```

Now, let's set up a conversation for our bot:

```python
messages = [
    {"role": "system", "content": "You are a bot that responds to weather queries. You should reply with the unit used in the queried location."},
    {"role": "user", "content": "Hey, what's the temperature in Paris right now?"}
]
```

Now, let's apply the chat template and generate a response:

```python
inputs = tokenizer.apply_chat_template(messages, chat_template="tool_use", tools=tools, add_generation_prompt=True, return_dict=True, return_tensors="pt")
inputs = {k: v.to(model.device) for k, v in inputs.items()}
out = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(out[0][len(inputs["input_ids"][0]):]))
```

And we get:

```text
<tool_call>
{"arguments": {"location": "Paris, France", "unit": "celsius"}, "name": "get_current_temperature"}
</tool_call><|im_end|>
```

The model has called the function with valid arguments, in the format requested by the function docstring. It has
inferred that we're most likely referring to the Paris in France, and it remembered that, as the home of SI units,
the temperature in France should certainly be displayed in Celsius.

Let's append the model's tool call to the conversation. Note that we generate a random `tool_call_id` here. These IDs
are not used by all models, but they allow models to issue multiple tool calls at once and keep track of which response
corresponds to which call. You can generate them any way you like, but they should be unique within each chat.
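For instance, one simple way to generate such an ID (purely illustrative; any unique scheme works):

```python
import uuid

# Eight hex characters is plenty to be unique within a single chat
tool_call_id = uuid.uuid4().hex[:8]
```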

```python
tool_call_id = "vAHdf3" # Random ID, should be unique for each tool call
tool_call = {"name": "get_current_temperature", "arguments": {"location": "Paris, France", "unit": "celsius"}}
messages.append({"role": "assistant", "tool_calls": [{"id": tool_call_id, "type": "function", "function": tool_call}]})
```


Now that we've added the tool call to the conversation, we can call the function and append the result to the
conversation. Since our dummy function always returns 22.0, we can just append that result directly. Again, note
the `tool_call_id` - this should match the ID used in the tool call above.

```python
messages.append({"role": "tool", "tool_call_id": tool_call_id, "name": "get_current_temperature", "content": "22.0"})
```
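With a real tool, rather than hard-coding the result, you would call the function with the model's arguments and stringify the output. Here is a sketch of the equivalent call, reusing the `get_current_temperature` and `tool_call` objects defined above:

```python
# Equivalent to the hard-coded append above, but actually calling the tool
result = get_current_temperature(**tool_call["arguments"])
messages.append({"role": "tool", "tool_call_id": tool_call_id, "name": tool_call["name"], "content": str(result)})
```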

Finally, let's let the assistant read the function outputs and continue chatting with the user:

```python
inputs = tokenizer.apply_chat_template(messages, chat_template="tool_use", tools=tools, add_generation_prompt=True, return_dict=True, return_tensors="pt")
inputs = {k: v.to(model.device) for k, v in inputs.items()}
out = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(out[0][len(inputs["input_ids"][0]):]))
```

And we get:

```text
The current temperature in Paris, France is 22.0 ° Celsius.<|im_end|>
```

Although this was a simple demo with dummy tools and a single call, the same technique works with
multiple real tools and longer conversations. This can be a powerful way to extend the capabilities of conversational
agents with real-time information, computational tools like calculators, or access to large databases.

<Tip>
Not all of the tool-calling features shown above are used by all models. Some use tool call IDs, others simply use the
function name and match tool calls to results using the ordering, and several models use neither, issuing only one tool
call at a time to avoid confusion. If you want your code to be compatible across as many models as possible, we
recommend structuring your tool calls like we've shown here, and returning tool results in the order that
they were issued by the model. The chat templates on each model should handle the rest.
</Tip>

### Understanding tool schemas

Each function you pass to the `tools` argument of `apply_chat_template` is converted into a
[JSON schema](https://json-schema.org/learn/getting-started-step-by-step). These schemas
are then passed to the model chat template. In other words, tool-use models do not see your functions directly, and they
never see the actual code inside them. What they care about is the function **definitions** and the **arguments** they
need to pass to them - they care about what the tools do and how to use them, not how they work! It is up to you
to read their outputs, detect if they have requested to use a tool, pass their arguments to the tool function, and
return the response in the chat.

Generating JSON schemas to pass to the template should be automatic and invisible as long as your functions
follow the specification above. However, if you encounter problems, or you simply want more control, you can
handle the conversion manually. Here is an example of a manual schema conversion:

```python
from transformers.utils import get_json_schema

def multiply(a: float, b: float):
    """
    A function that multiplies two numbers

    Args:
        a: The first number to multiply
        b: The second number to multiply
    """
    return a * b

schema = get_json_schema(multiply)
print(schema)
```

> **Comment on lines +443 to +451:**
>
> **Collaborator:** I don't see why we need a docstring?
>
> **Member (author):** The docstring will be turned into the description in the JSON, which will in turn be formatted by the chat template and passed to the model, so we do need to keep it!
>
> **Contributor:** So it's fine if there's no return type hint?
>
> **Member (author):** Yes. This was a tricky issue. Almost all models don't use return type hints or descriptions at all, and they are not supported in the OpenAI or Anthropic APIs. However, the Hermes class of models does use them. The only choices I had were to either break support for those models in the templating system, or to allow optional return hints/descriptions. I definitely don't want to make them mandatory, since most models ignore them!
>
> **@interstellarninja (Jun 23, 2024):** We are using the Google Python style guide, so the docstring would have `Returns:` and `Raises:` blocks:
>
> ```python
> def connect_to_next_port(self, minimum: int) -> int:
>     """Connects to the next available port.
>
>     Args:
>       minimum: A port value greater or equal to 1024.
>
>     Returns:
>       The new minimum port.
>
>     Raises:
>       ConnectionError: If no available port is found.
>     """
> ```
>
> We are also experimenting with function signature and call schemas that include "returns" fields specifically to support chained tool use. Besides, we believe that hinting the model about what a function returns helps it call a relevant function to assist with the user query.


This will yield:

```json
{
    "type": "function",
    "function": {
        "name": "multiply",
        "description": "A function that multiplies two numbers",
        "parameters": {
            "type": "object",
            "properties": {
                "a": {
                    "type": "number",
                    "description": "The first number to multiply"
                },
                "b": {
                    "type": "number",
                    "description": "The second number to multiply"
                }
            },
            "required": ["a", "b"]
        }
    }
}
```

If you wish, you can edit these schemas, or even write them from scratch yourself without using `get_json_schema` at
all. JSON schemas can be passed directly to the `tools` argument of
`apply_chat_template` - this gives you a lot of power to define precise schemas for more complex functions. Be careful,
though - the more complex your schemas, the more likely the model is to get confused when dealing with them! We
recommend simple function signatures where possible, keeping arguments (and especially complex, nested arguments)
to a minimum.

Here is an example of defining schemas by hand, and passing them directly to `apply_chat_template`:

```python
# A simple function that takes no arguments
current_time = {
    "type": "function",
    "function": {
        "name": "current_time",
        "description": "Get the current local time as a string.",
        "parameters": {
            "type": "object",
            "properties": {}
        }
    }
}

# A more complete function that takes two numerical arguments
multiply = {
    "type": "function",
    "function": {
        "name": "multiply",
        "description": "A function that multiplies two numbers",
        "parameters": {
            "type": "object",
            "properties": {
                "a": {
                    "type": "number",
                    "description": "The first number to multiply"
                },
                "b": {
                    "type": "number",
                    "description": "The second number to multiply"
                }
            },
            "required": ["a", "b"]
        }
    }
}

model_input = tokenizer.apply_chat_template(
    messages,
    tools=[current_time, multiply]
)
```

## Advanced: Retrieval-augmented generation

"Retrieval-augmented generation" or "RAG" LLMs can search a corpus of documents for information before responding
to a query. This allows models to vastly expand their knowledge base beyond their limited context size. Our
recommendation for RAG models is that their template
should accept a `documents` argument. This should be a list of documents, where each "document"
is a single dict with `title` and `contents` keys, both of which are strings. Because this format is much simpler
than the JSON schemas used for tools, no helper functions are necessary.

Here's an example of a RAG template in action:

```python
document1 = {
    "title": "The Moon: Our Age-Old Foe",
    "contents": "Man has always dreamed of destroying the moon. In this essay, I shall..."
}

document2 = {
    "title": "The Sun: Our Age-Old Friend",
    "contents": "Although often underappreciated, the sun provides several notable benefits..."
}

model_input = tokenizer.apply_chat_template(
    messages,
    documents=[document1, document2]
)
```
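How the documents are actually rendered is up to each model's chat template. Chat templates are written in Jinja, so a RAG-aware template might contain a fragment like the following (a purely hypothetical sketch, not taken from any real model):

```jinja
{#- Hypothetical template fragment: iterate over the documents kwarg -#}
{%- for doc in documents %}
Title: {{ doc.title }}
Contents: {{ doc.contents }}
{%- endfor %}
```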

## Advanced: How do chat templates work?

The chat template for a model is stored on the `tokenizer.chat_template` attribute. If no chat template is set, the