diff --git a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md
index 0a0e3effc2a946..8d49fa5c80ee28 100644
--- a/docs/source/en/chat_templating.md
+++ b/docs/source/en/chat_templating.md
@@ -233,6 +233,332 @@ The sun.
From here, just continue training like you would with a standard language modelling task, using the `formatted_chat` column.
+## Advanced: Extra inputs to chat templates
+
+The only argument that `apply_chat_template` requires is `messages`. However, you can pass any keyword
+argument to `apply_chat_template` and it will be accessible inside the template. This gives you a lot of freedom to use
+chat templates for many things. There are no restrictions on the names or the format of these arguments - you can pass
+strings, lists, dicts or whatever else you want.
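+
+For example, if a model's chat template happened to reference a variable named `user_name` (a hypothetical template
+variable, used here purely for illustration), you could supply it directly as a keyword argument:
+
+```python
+# Hypothetical sketch: this only has an effect if the model's chat template
+# actually references a `user_name` variable
+formatted_chat = tokenizer.apply_chat_template(
+    messages,
+    user_name="Alice",  # Accessible inside the template as `user_name`
+    tokenize=False,
+    add_generation_prompt=True,
+)
+```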
+
+That said, there are some common use-cases for these extra arguments,
+such as passing tools for function calling, or documents for retrieval-augmented generation. In these common cases,
+we have some opinionated recommendations about what the names and formats of these arguments should be, which are
+described in the sections below. We encourage model authors to make their chat templates compatible with this format,
+to make it easy to transfer tool-calling code between models.
+
+## Advanced: Tool use / function calling
+
+"Tool use" LLMs can choose to call functions as external tools before generating an answer. When passing tools
+to a tool-use model, you can simply pass a list of functions to the `tools` argument:
+
+```python
+from datetime import datetime
+
+def current_time():
+ """Get the current local time as a string."""
+ return str(datetime.now())
+
+def multiply(a: float, b: float):
+ """
+ A function that multiplies two numbers
+
+ Args:
+ a: The first number to multiply
+ b: The second number to multiply
+ """
+ return a * b
+
+tools = [current_time, multiply]
+
+model_input = tokenizer.apply_chat_template(
+ messages,
+ tools=tools
+)
+```
+
+For this to work, you should write your functions in the format above, so that they can be parsed
+correctly as tools. Specifically, you should follow these rules:
+
+- The function should have a descriptive name
+- Every argument must have a type hint
+- The function must have a docstring in the standard Google style (in other words, an initial function description
+  followed by an `Args:` block that describes the arguments, unless the function does not have any arguments).
+- Do not include types in the `Args:` block. In other words, write `a: The first number to multiply`, not
+ `a (int): The first number to multiply`. Type hints should go in the function header instead.
+- The function can have a return type and a `Returns:` block in the docstring. However, these are optional
+ because most tool-use models ignore them.
+
+### Passing tool results to the model
+
+The sample code above is enough to list the available tools for your model, but what happens if the model wants to
+actually use one? If that happens, you should:
+
+1. Parse the model's output to get the tool name(s) and arguments.
+2. Add the model's tool call(s) to the conversation.
+3. Call the corresponding function(s) with those arguments.
+4. Add the result(s) to the conversation.
+
+### A complete tool use example
+
+Let's walk through a tool use example, step by step. For this example, we will use an 8B `Hermes-2-Pro` model,
+as it is one of the highest-performing tool-use models in its size category at the time of writing. If you have the
+memory, you can consider using a larger model instead like [Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-v01)
+or [Mixtral-8x22B](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1), both of which also support tool use
+and offer even stronger performance.
+
+First, let's load our model and tokenizer:
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+checkpoint = "NousResearch/Hermes-2-Pro-Llama-3-8B"
+
+tokenizer = AutoTokenizer.from_pretrained(checkpoint, revision="pr/13")
+model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, device_map="auto")
+```
+
+Next, let's define a list of tools:
+
+```python
+def get_current_temperature(location: str, unit: str) -> float:
+ """
+ Get the current temperature at a location.
+
+ Args:
+ location: The location to get the temperature for, in the format "City, Country"
+ unit: The unit to return the temperature in. (choices: ["celsius", "fahrenheit"])
+ Returns:
+ The current temperature at the specified location in the specified units, as a float.
+ """
+ return 22. # A real function should probably actually get the temperature!
+
+def get_current_wind_speed(location: str) -> float:
+ """
+ Get the current wind speed in km/h at a given location.
+
+ Args:
+        location: The location to get the wind speed for, in the format "City, Country"
+ Returns:
+ The current wind speed at the given location in km/h, as a float.
+ """
+ return 6. # A real function should probably actually get the wind speed!
+
+tools = [get_current_temperature, get_current_wind_speed]
+```
+
+Now, let's set up a conversation for our bot:
+
+```python
+messages = [
+ {"role": "system", "content": "You are a bot that responds to weather queries. You should reply with the unit used in the queried location."},
+ {"role": "user", "content": "Hey, what's the temperature in Paris right now?"}
+]
+```
+
+Now, let's apply the chat template and generate a response:
+
+```python
+inputs = tokenizer.apply_chat_template(messages, chat_template="tool_use", tools=tools, add_generation_prompt=True, return_dict=True, return_tensors="pt")
+inputs = {k: v.to(model.device) for k, v in inputs.items()}
+out = model.generate(**inputs, max_new_tokens=128)
+print(tokenizer.decode(out[0][len(inputs["input_ids"][0]):]))
+```
+
+And we get:
+
+```text
+
+{"arguments": {"location": "Paris, France", "unit": "celsius"}, "name": "get_current_temperature"}
+<|im_end|>
+```
+
+The model has called the function with valid arguments, in the format requested by the function docstring. It has
+inferred that we're most likely referring to the Paris in France, and it remembered that, as the home of SI units,
+the temperature in France should certainly be displayed in Celsius.
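+
+In this walkthrough we can read the tool call straight off the output, but in real code you will want to extract it
+programmatically. Here is a minimal sketch, assuming (as in the output above) that the model emits the tool call as a
+single JSON object on its own line - production code should be more defensive, and handle models that wrap calls in
+special tags or emit several calls at once:
+
+```python
+import json
+
+generated_text = tokenizer.decode(out[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
+
+# Take the first output line that parses as a JSON object and treat it as the tool call
+tool_call = None
+for line in generated_text.splitlines():
+    line = line.strip()
+    if line.startswith("{"):
+        tool_call = json.loads(line)
+        break
+
+if tool_call is not None:
+    print(tool_call["name"], tool_call["arguments"])
+```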
+
+Let's append the model's tool call to the conversation. Note that we generate a random `tool_call_id` here. These IDs
+are not used by all models, but they allow models to issue multiple tool calls at once and keep track of which response
+corresponds to which call. You can generate them any way you like, but they should be unique within each chat.
+
+```python
+tool_call_id = "vAHdf3" # Random ID, should be unique for each tool call
+tool_call = {"name": "get_current_temperature", "arguments": {"location": "Paris, France", "unit": "celsius"}}
+messages.append({"role": "assistant", "tool_calls": [{"id": tool_call_id, "type": "function", "function": tool_call}]})
+```
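+
+These IDs don't need to be anything fancy - if you'd rather not invent them by hand, something like the standard
+library's `uuid` module works fine. This is just one option; any scheme that produces strings that are unique within
+the chat will do:
+
+```python
+import uuid
+
+# A short random hex string is plenty - it only needs to be unique within this chat
+tool_call_id = uuid.uuid4().hex[:8]
+```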
+
+
+Now that we've added the tool call to the conversation, we can call the function and append the result to the
+conversation. Since our dummy function always returns 22.0, we can simply append that result directly. Again, note
+the `tool_call_id` - this should match the ID used in the tool call above.
+
+```python
+messages.append({"role": "tool", "tool_call_id": tool_call_id, "name": "get_current_temperature", "content": "22.0"})
+```
+
+Finally, let's let the assistant read the function outputs and continue chatting with the user:
+
+```python
+inputs = tokenizer.apply_chat_template(messages, chat_template="tool_use", tools=tools, add_generation_prompt=True, return_dict=True, return_tensors="pt")
+inputs = {k: v.to(model.device) for k, v in inputs.items()}
+out = model.generate(**inputs, max_new_tokens=128)
+print(tokenizer.decode(out[0][len(inputs["input_ids"][0]):]))
+```
+
+And we get:
+
+```text
+The current temperature in Paris, France is 22.0 ° Celsius.<|im_end|>
+```
+
+Although this was a simple demo with dummy tools and a single call, the same technique works with
+multiple real tools and longer conversations. This can be a powerful way to extend the capabilities of conversational
+agents with real-time information, computational tools like calculators, or access to large databases.
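+
+For example, with several tools in play, a simple dispatch table that maps tool names to Python functions keeps the
+"call the corresponding function" step generic. This is just a sketch of one way to structure it, reusing the tools
+defined earlier:
+
+```python
+# Map tool names to the actual Python functions defined above
+tool_registry = {
+    "get_current_temperature": get_current_temperature,
+    "get_current_wind_speed": get_current_wind_speed,
+}
+
+def execute_tool_call(tool_call: dict) -> str:
+    """Look up the requested tool, call it with the model's arguments, and return the result as a string."""
+    func = tool_registry[tool_call["name"]]
+    result = func(**tool_call["arguments"])
+    return str(result)  # Tool results are appended to the chat as strings
+```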
+
+
+Not all of the tool-calling features shown above are used by all models. Some use tool call IDs, others simply use the function name and
+match tool calls to results using the ordering, and there are several models that use neither and only issue one tool
+call at a time to avoid confusion. If you want your code to be compatible across as many models as possible, we
+recommend structuring your tool calls like we've shown here, and returning tool results in the order that
+they were issued by the model. The chat templates on each model should handle the rest.
+
+
+### Understanding tool schemas
+
+Each function you pass to the `tools` argument of `apply_chat_template` is converted into a
+[JSON schema](https://json-schema.org/learn/getting-started-step-by-step). These schemas
+are then passed to the model's chat template. In other words, tool-use models do not see your functions directly, and they
+never see the actual code inside them. What they care about is the function **definitions** and the **arguments** they
+need to pass to them - they care about what the tools do and how to use them, not how they work! It is up to you
+to read their outputs, detect if they have requested to use a tool, pass their arguments to the tool function, and
+return the response in the chat.
+
+Generating JSON schemas to pass to the template should be automatic and invisible as long as your functions
+follow the specification above, but if you encounter problems, or you simply want more control over the conversion,
+you can handle the conversion manually. Here is an example of a manual schema conversion.
+
+```python
+from transformers.utils import get_json_schema
+
+def multiply(a: float, b: float):
+ """
+ A function that multiplies two numbers
+
+ Args:
+ a: The first number to multiply
+ b: The second number to multiply
+ """
+ return a * b
+
+schema = get_json_schema(multiply)
+print(schema)
+```
+
+This will yield:
+
+```json
+{
+ "type": "function",
+ "function": {
+ "name": "multiply",
+ "description": "A function that multiplies two numbers",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "a": {
+ "type": "number",
+ "description": "The first number to multiply"
+ },
+ "b": {
+ "type": "number",
+ "description": "The second number to multiply"
+ }
+ },
+ "required": ["a", "b"]
+ }
+ }
+}
+```
+
+If you wish, you can edit these schemas, or even write them from scratch yourself without using `get_json_schema` at
+all. JSON schemas can be passed directly to the `tools` argument of
+`apply_chat_template` - this gives you a lot of power to define precise schemas for more complex functions. Be careful,
+though - the more complex your schemas, the more likely the model is to get confused when dealing with them! We
+recommend simple function signatures where possible, keeping arguments (and especially complex, nested arguments)
+to a minimum.
+
+Here is an example of defining schemas by hand, and passing them directly to `apply_chat_template`:
+
+```python
+# A simple function that takes no arguments
+current_time = {
+    "type": "function",
+    "function": {
+        "name": "current_time",
+        "description": "Get the current local time as a string.",
+        "parameters": {
+            "type": "object",
+            "properties": {}
+        }
+    }
+}
+
+# A more complete function that takes two numerical arguments
+multiply = {
+    "type": "function",
+    "function": {
+        "name": "multiply",
+        "description": "A function that multiplies two numbers",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "a": {
+                    "type": "number",
+                    "description": "The first number to multiply"
+                },
+                "b": {
+                    "type": "number",
+                    "description": "The second number to multiply"
+                }
+            },
+            "required": ["a", "b"]
+        }
+    }
+}
+
+model_input = tokenizer.apply_chat_template(
+    messages,
+    tools=[current_time, multiply]
+)
+```
+
+## Advanced: Retrieval-augmented generation
+
+"Retrieval-augmented generation" or "RAG" LLMs can search a corpus of documents for information before responding
+to a query. This allows models to vastly expand their knowledge base beyond their limited context size. Our
+recommendation for RAG models is that their template
+should accept a `documents` argument. This should be a list of documents, where each "document"
+is a single dict with `title` and `contents` keys, both of which are strings. Because this format is much simpler
+than the JSON schemas used for tools, no helper functions are necessary.
+
+Here's an example of a RAG template in action:
+
+```python
+document1 = {
+ "title": "The Moon: Our Age-Old Foe",
+ "contents": "Man has always dreamed of destroying the moon. In this essay, I shall..."
+}
+
+document2 = {
+ "title": "The Sun: Our Age-Old Friend",
+ "contents": "Although often underappreciated, the sun provides several notable benefits..."
+}
+
+model_input = tokenizer.apply_chat_template(
+ messages,
+ documents=[document1, document2]
+)
+```
+
## Advanced: How do chat templates work?
The chat template for a model is stored on the `tokenizer.chat_template` attribute. If no chat template is set, the
diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index bc9417f3587c1e..14a31560c6f1a8 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -28,6 +28,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from functools import lru_cache
+from inspect import isfunction
-from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
import numpy as np
@@ -47,6 +48,7 @@
copy_func,
download_url,
extract_commit_hash,
+ get_json_schema,
is_flax_available,
is_jax_tensor,
is_mlx_available,
@@ -1683,6 +1685,8 @@ def get_vocab(self) -> Dict[str, int]:
def apply_chat_template(
self,
conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
+        tools: Optional[List[Union[Dict, Callable]]] = None,
+ documents: Optional[List[Dict[str, str]]] = None,
chat_template: Optional[str] = None,
add_generation_prompt: bool = False,
tokenize: bool = True,
@@ -1703,8 +1707,21 @@ def apply_chat_template(
Args:
conversation (Union[List[Dict[str, str]], List[List[Dict[str, str]]]]): A list of dicts
with "role" and "content" keys, representing the chat history so far.
- chat_template (str, *optional*): A Jinja template to use for this conversion. If
- this is not passed, the model's default chat template will be used instead.
+            tools (`List[Union[Dict, Callable]]`, *optional*):
+                A list of tools (callable functions) that will be accessible to the model. If the template does not
+                support function calling, this argument will have no effect. Each tool can be passed either as a
+                Python function (which will be automatically converted to a JSON schema giving the name, description
+                and argument types for the tool), or as a dict containing such a JSON schema directly. See our
+                [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
+                for more information.
+ documents (`List[Dict[str, str]]`, *optional*):
+ A list of dicts representing documents that will be accessible to the model if it is performing RAG
+ (retrieval-augmented generation). If the template does not support RAG, this argument will have no
+                effect. We recommend that each document should be a dict containing "title" and "contents" keys. Please
+ see the RAG section of the [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#arguments-for-RAG)
+ for examples of passing documents with chat templates.
+ chat_template (`str`, *optional*):
+ A Jinja template to use for this conversion. It is usually not necessary to pass anything to this
+ argument, as the model's template will be used by default.
add_generation_prompt (bool, *optional*): Whether to end the prompt with the token(s) that indicate
the start of an assistant message. This is useful when you want to generate a response from the model.
Note that this argument will be passed to the chat template, and so it must be supported in the
@@ -1802,6 +1819,27 @@ def apply_chat_template(
conversations = [conversation]
is_batched = False
+ # We accept either JSON schemas or functions for tools. If we get functions, we convert them to schemas
+ if tools is not None:
+ tool_schemas = []
+ for tool in tools:
+ if isinstance(tool, dict):
+ tool_schemas.append(tool)
+ elif isfunction(tool):
+ tool_schemas.append(get_json_schema(tool))
+ else:
+ raise ValueError(
+ "Tools should either be a JSON schema, or a callable function with type hints "
+ "and a docstring suitable for auto-conversion to a schema."
+ )
+ else:
+ tool_schemas = None
+
+ if documents is not None:
+ for document in documents:
+ if not isinstance(document, dict):
+                raise TypeError("Documents should be a list of dicts with 'title' and 'contents' keys!")
+
rendered = []
template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present
for chat in conversations:
@@ -1809,7 +1847,11 @@ def apply_chat_template(
# Indicates it's a Conversation object
chat = chat.messages
rendered_chat = compiled_template.render(
- messages=chat, add_generation_prompt=add_generation_prompt, **template_kwargs
+ messages=chat,
+ tools=tool_schemas,
+ documents=documents,
+ add_generation_prompt=add_generation_prompt,
+ **template_kwargs,
)
rendered.append(rendered_chat)
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 51c1113cab3c2c..ce87bc8623132e 100755
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -21,6 +21,7 @@
from .. import __version__
from .backbone_utils import BackboneConfigMixin, BackboneMixin
+from .chat_template_utils import DocstringParsingException, TypeHintParsingException, get_json_schema
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
from .doc import (
add_code_sample_docstrings,
diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py
new file mode 100644
index 00000000000000..ee6173f2a1532b
--- /dev/null
+++ b/src/transformers/utils/chat_template_utils.py
@@ -0,0 +1,316 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import json
+import re
+from typing import Any, Callable, Dict, Optional, Tuple, Union, get_args, get_origin, get_type_hints
+
+
+BASIC_TYPES = (int, float, str, bool, Any, type(None), ...)
+# Extracts the initial segment of the docstring, containing the function description
+description_re = re.compile(r"^(.*?)[\n\s]*(Args:|Returns:|Raises:|\Z)", re.DOTALL)
+# Extracts the Args: block from the docstring
+args_re = re.compile(r"\n\s*Args:\n\s*(.*?)[\n\s]*(Returns:|Raises:|\Z)", re.DOTALL)
+# Splits the Args: block into individual arguments
+args_split_re = re.compile(
+ r"""
+(?:^|\n) # Match the start of the args block, or a newline
+\s*(\w+):\s* # Capture the argument name and strip spacing
+(.*?)\s* # Capture the argument description, which can span multiple lines, and strip trailing spacing
+(?=\n\s*\w+:|\Z) # Stop when you hit the next argument or the end of the block
+""",
+ re.DOTALL | re.VERBOSE,
+)
+# Extracts the Returns: block from the docstring, if present. Note that most chat templates ignore the return type/doc!
+returns_re = re.compile(r"\n\s*Returns:\n\s*(.*?)[\n\s]*(Raises:|\Z)", re.DOTALL)
+
+
+class TypeHintParsingException(Exception):
+ """Exception raised for errors in parsing type hints to generate JSON schemas"""
+
+ pass
+
+
+class DocstringParsingException(Exception):
+ """Exception raised for errors in parsing docstrings to generate JSON schemas"""
+
+ pass
+
+
+def _get_json_schema_type(param_type) -> Dict[str, str]:
+ type_mapping = {
+ int: {"type": "integer"},
+ float: {"type": "number"},
+ str: {"type": "string"},
+ bool: {"type": "boolean"},
+ Any: {},
+ }
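+    # Types outside this mapping (e.g. custom classes) fall back to the generic "object" type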
+ return type_mapping.get(param_type, {"type": "object"})
+
+
+def _parse_type_hint(hint) -> Dict:
+ origin = get_origin(hint)
+ args = get_args(hint)
+
+ if origin is None:
+ try:
+ return _get_json_schema_type(hint)
+ except KeyError:
+            raise TypeHintParsingException(
+                f"Couldn't parse this type hint, likely due to a custom class or object: {hint}"
+            )
+
+ elif origin is Union:
+ # Recurse into each of the subtypes in the Union, except None, which is handled separately at the end
+ subtypes = [_parse_type_hint(t) for t in args if t != type(None)]
+ if len(subtypes) == 1:
+ # A single non-null type can be expressed directly
+ return_dict = subtypes[0]
+ elif all(isinstance(subtype["type"], str) for subtype in subtypes):
+ # A union of basic types can be expressed as a list in the schema
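+            # e.g. Union[int, str] becomes {"type": ["integer", "string"]}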
+ return_dict = {"type": [subtype["type"] for subtype in subtypes]}
+ else:
+ # A union of more complex types requires "anyOf"
+ return_dict = {"anyOf": subtypes}
+ if type(None) in args:
+ return_dict["nullable"] = True
+ return return_dict
+
+ elif origin is list:
+ if not args:
+ return {"type": "array"}
+ else:
+ # Lists can only have a single type argument, so recurse into it
+ return {"type": "array", "items": _parse_type_hint(args[0])}
+
+ elif origin is tuple:
+ if not args:
+ return {"type": "array"}
+ if len(args) == 1:
+ raise TypeHintParsingException(
+ f"The type hint {str(hint).replace('typing.', '')} is a Tuple with a single element, which "
+ "we do not automatically convert to JSON schema as it is rarely necessary. If this input can contain "
+ "more than one element, we recommend "
+ "using a List[] type instead, or if it really is a single element, remove the Tuple[] wrapper and just "
+ "pass the element directly."
+ )
+ if ... in args:
+ raise TypeHintParsingException(
+ "Conversion of '...' is not supported in Tuple type hints. "
+ "Use List[] types for variable-length"
+ " inputs instead."
+ )
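+        # e.g. Tuple[int, str] becomes {"type": "array", "prefixItems": [{"type": "integer"}, {"type": "string"}]}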
+ return {"type": "array", "prefixItems": [_parse_type_hint(t) for t in args]}
+
+ elif origin is dict:
+ # The JSON equivalent to a dict is 'object', which mandates that all keys are strings
+ # However, we can specify the type of the dict values with "additionalProperties"
+ out = {"type": "object"}
+ if len(args) == 2:
+ out["additionalProperties"] = _parse_type_hint(args[1])
+ return out
+
+    raise TypeHintParsingException(f"Couldn't parse this type hint, likely due to a custom class or object: {hint}")
+
+
+def _convert_type_hints_to_json_schema(func: Callable) -> Dict:
+ type_hints = get_type_hints(func)
+ signature = inspect.signature(func)
+ required = []
+ for param_name, param in signature.parameters.items():
+ if param.annotation == inspect.Parameter.empty:
+ raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}")
+ if param.default == inspect.Parameter.empty:
+ required.append(param_name)
+
+ properties = {}
+ for param_name, param_type in type_hints.items():
+ properties[param_name] = _parse_type_hint(param_type)
+
+ schema = {"type": "object", "properties": properties}
+ if required:
+ schema["required"] = required
+
+ return schema
+
+
+def parse_google_format_docstring(docstring: str) -> Tuple[Optional[str], Optional[Dict], Optional[str]]:
+ """
+ Parses a Google-style docstring to extract the function description,
+ argument descriptions, and return description.
+
+ Args:
+ docstring (str): The docstring to parse.
+
+ Returns:
+ The function description, arguments, and return description.
+ """
+
+ # Extract the sections
+ description_match = description_re.search(docstring)
+ args_match = args_re.search(docstring)
+ returns_match = returns_re.search(docstring)
+
+ # Clean and store the sections
+ description = description_match.group(1).strip() if description_match else None
+ docstring_args = args_match.group(1).strip() if args_match else None
+ returns = returns_match.group(1).strip() if returns_match else None
+
+ # Parsing the arguments into a dictionary
+ if docstring_args is not None:
+ docstring_args = "\n".join([line for line in docstring_args.split("\n") if line.strip()]) # Remove blank lines
+ matches = args_split_re.findall(docstring_args)
+ args_dict = {match[0]: re.sub(r"\s*\n+\s*", " ", match[1].strip()) for match in matches}
+ else:
+ args_dict = {}
+
+ return description, args_dict, returns
+
+
+def get_json_schema(func: Callable) -> Dict:
+ """
+ This function generates a JSON schema for a given function, based on its docstring and type hints. This is
+ mostly used for passing lists of tools to a chat template. The JSON schema contains the name and description of
+ the function, as well as the names, types and descriptions for each of its arguments. `get_json_schema()` requires
+ that the function has a docstring, and that each argument has a description in the docstring, in the standard
+ Google docstring format shown below. It also requires that all the function arguments have a valid Python type hint.
+
+ Although it is not required, a `Returns` block can also be added, which will be included in the schema. This is
+ optional because most chat templates ignore the return value of the function.
+
+ Args:
+ func: The function to generate a JSON schema for.
+
+ Returns:
+ A dictionary containing the JSON schema for the function.
+
+ Examples:
+ ```python
+ >>> def multiply(x: float, y: float):
+ >>> '''
+ >>> A function that multiplies two numbers
+ >>>
+ >>> Args:
+ >>> x: The first number to multiply
+ >>> y: The second number to multiply
+ >>> '''
+ >>> return x * y
+ >>>
+ >>> print(get_json_schema(multiply))
+ {
+ "name": "multiply",
+ "description": "A function that multiplies two numbers",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "x": {"type": "number", "description": "The first number to multiply"},
+ "y": {"type": "number", "description": "The second number to multiply"}
+ },
+ "required": ["x", "y"]
+ }
+ }
+ ```
+
+ The general use for these schemas is that they are used to generate tool descriptions for chat templates that
+ support them, like so:
+
+ ```python
+ >>> from transformers import AutoTokenizer
+ >>> from transformers.utils import get_json_schema
+ >>>
+ >>> def multiply(x: float, y: float):
+ >>> '''
+ >>> A function that multiplies two numbers
+ >>>
+ >>> Args:
+ >>> x: The first number to multiply
+ >>> y: The second number to multiply
+    >>>    '''
+    >>> return x * y
+ >>>
+ >>> multiply_schema = get_json_schema(multiply)
+ >>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
+ >>> messages = [{"role": "user", "content": "What is 179 x 4571?"}]
+ >>> formatted_chat = tokenizer.apply_chat_template(
+ >>> messages,
+ >>> tools=[multiply_schema],
+ >>> chat_template="tool_use",
+ >>> return_dict=True,
+ >>> return_tensors="pt",
+ >>> add_generation_prompt=True
+ >>> )
+ >>> # The formatted chat can now be passed to model.generate()
+ ```
+
+ Each argument description can also have an optional `(choices: ...)` block at the end, such as
+ `(choices: ["tea", "coffee"])`, which will be parsed into an `enum` field in the schema. Note that this will
+ only be parsed correctly if it is at the end of the line:
+
+ ```python
+ >>> def drink_beverage(beverage: str):
+ >>> '''
+ >>> A function that drinks a beverage
+ >>>
+ >>> Args:
+ >>> beverage: The beverage to drink (choices: ["tea", "coffee"])
+ >>> '''
+ >>> pass
+ >>>
+ >>> print(get_json_schema(drink_beverage))
+ ```
+ {
+ 'name': 'drink_beverage',
+ 'description': 'A function that drinks a beverage',
+ 'parameters': {
+ 'type': 'object',
+ 'properties': {
+ 'beverage': {
+ 'type': 'string',
+ 'enum': ['tea', 'coffee'],
+ 'description': 'The beverage to drink'
+ }
+ },
+ 'required': ['beverage']
+ }
+ }
+ """
+ doc = inspect.getdoc(func)
+ if not doc:
+ raise DocstringParsingException(
+ f"Cannot generate JSON schema for {func.__name__} because it has no docstring!"
+ )
+ doc = doc.strip()
+ main_doc, param_descriptions, return_doc = parse_google_format_docstring(doc)
+
+ json_schema = _convert_type_hints_to_json_schema(func)
+ if (return_dict := json_schema["properties"].pop("return", None)) is not None:
+ if return_doc is not None: # We allow a missing return docstring since most templates ignore it
+ return_dict["description"] = return_doc
+ for arg, schema in json_schema["properties"].items():
+ if arg not in param_descriptions:
+ raise DocstringParsingException(
+ f"Cannot generate JSON schema for {func.__name__} because the docstring has no description for the argument '{arg}'"
+ )
+ desc = param_descriptions[arg]
+ enum_choices = re.search(r"\(choices:\s*(.*?)\)\s*$", desc, flags=re.IGNORECASE)
+ if enum_choices:
+ schema["enum"] = [c.strip() for c in json.loads(enum_choices.group(1))]
+ desc = enum_choices.string[: enum_choices.start()].strip()
+ schema["description"] = desc
+
+ output = {"name": func.__name__, "description": main_doc, "parameters": json_schema}
+ if return_dict is not None:
+ output["return"] = return_dict
+ return {"type": "function", "function": output}
diff --git a/tests/utils/test_chat_template_utils.py b/tests/utils/test_chat_template_utils.py
new file mode 100644
index 00000000000000..cff31c1f8a3483
--- /dev/null
+++ b/tests/utils/test_chat_template_utils.py
@@ -0,0 +1,476 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from typing import List, Optional, Tuple, Union
+
+from transformers.utils import DocstringParsingException, TypeHintParsingException, get_json_schema
+
+
+class JsonSchemaGeneratorTest(unittest.TestCase):
+ def test_simple_function(self):
+ def fn(x: int):
+ """
+ Test function
+
+ Args:
+ x: The input
+ """
+ return x
+
+ schema = get_json_schema(fn)
+ expected_schema = {
+ "name": "fn",
+ "description": "Test function",
+ "parameters": {
+ "type": "object",
+ "properties": {"x": {"type": "integer", "description": "The input"}},
+ "required": ["x"],
+ },
+ }
+ self.assertEqual(schema["function"], expected_schema)
+
+ def test_no_arguments(self):
+ def fn():
+ """
+ Test function
+ """
+ return True
+
+ schema = get_json_schema(fn)
+ expected_schema = {
+ "name": "fn",
+ "description": "Test function",
+ "parameters": {"type": "object", "properties": {}},
+ }
+ self.assertEqual(schema["function"], expected_schema)
+
+ def test_union(self):
+ def fn(x: Union[int, float]):
+ """
+ Test function
+
+ Args:
+ x: The input
+ """
+ return x
+
+ schema = get_json_schema(fn)
+ expected_schema = {
+ "name": "fn",
+ "description": "Test function",
+ "parameters": {
+ "type": "object",
+ "properties": {"x": {"type": ["integer", "number"], "description": "The input"}},
+ "required": ["x"],
+ },
+ }
+ self.assertEqual(schema["function"], expected_schema)
+
+ def test_optional(self):
+ def fn(x: Optional[int]):
+ """
+ Test function
+
+ Args:
+ x: The input
+ """
+ return x
+
+ schema = get_json_schema(fn)
+ expected_schema = {
+ "name": "fn",
+ "description": "Test function",
+ "parameters": {
+ "type": "object",
+ "properties": {"x": {"type": "integer", "description": "The input", "nullable": True}},
+ "required": ["x"],
+ },
+ }
+ self.assertEqual(schema["function"], expected_schema)
+
+ def test_default_arg(self):
+ def fn(x: int = 42):
+ """
+ Test function
+
+ Args:
+ x: The input
+ """
+ return x
+
+ schema = get_json_schema(fn)
+ expected_schema = {
+ "name": "fn",
+ "description": "Test function",
+ "parameters": {"type": "object", "properties": {"x": {"type": "integer", "description": "The input"}}},
+ }
+ self.assertEqual(schema["function"], expected_schema)
+
+ def test_nested_list(self):
+ def fn(x: List[List[Union[str, int]]]):
+ """
+ Test function
+
+ Args:
+ x: The input
+ """
+ return x
+
+ schema = get_json_schema(fn)
+ expected_schema = {
+ "name": "fn",
+ "description": "Test function",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "x": {
+ "type": "array",
+ "items": {"type": "array", "items": {"type": ["string", "integer"]}},
+ "description": "The input",
+ }
+ },
+ "required": ["x"],
+ },
+ }
+ self.assertEqual(schema["function"], expected_schema)
+
+ def test_multiple_arguments(self):
+ def fn(x: int, y: str):
+ """
+ Test function
+
+ Args:
+ x: The input
+ y: Also the input
+ """
+ return x
+
+ schema = get_json_schema(fn)
+ expected_schema = {
+ "name": "fn",
+ "description": "Test function",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "x": {"type": "integer", "description": "The input"},
+ "y": {"type": "string", "description": "Also the input"},
+ },
+ "required": ["x", "y"],
+ },
+ }
+ self.assertEqual(schema["function"], expected_schema)
+
+ def test_multiple_complex_arguments(self):
+ def fn(x: List[Union[int, float]], y: Optional[Union[int, str]] = None):
+ """
+ Test function
+
+ Args:
+ x: The input
+ y: Also the input
+ """
+ return x
+
+ schema = get_json_schema(fn)
+ expected_schema = {
+ "name": "fn",
+ "description": "Test function",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "x": {"type": "array", "items": {"type": ["integer", "number"]}, "description": "The input"},
+ "y": {
+ "type": ["integer", "string"],
+ "nullable": True,
+ "description": "Also the input",
+ },
+ },
+ "required": ["x"],
+ },
+ }
+ self.assertEqual(schema["function"], expected_schema)
+
+ def test_missing_docstring(self):
+ def fn(x: int):
+ return x
+
+ with self.assertRaises(DocstringParsingException):
+ get_json_schema(fn)
+
+ def test_missing_param_docstring(self):
+ def fn(x: int):
+ """
+ Test function
+ """
+ return x
+
+ with self.assertRaises(DocstringParsingException):
+ get_json_schema(fn)
+
+ def test_missing_type_hint(self):
+ def fn(x):
+ """
+ Test function
+
+ Args:
+ x: The input
+ """
+ return x
+
+ with self.assertRaises(TypeHintParsingException):
+ get_json_schema(fn)
+
+ def test_return_value(self):
+ def fn(x: int) -> int:
+ """
+ Test function
+
+ Args:
+ x: The input
+ """
+ return x
+
+ schema = get_json_schema(fn)
+ expected_schema = {
+ "name": "fn",
+ "description": "Test function",
+ "parameters": {
+ "type": "object",
+ "properties": {"x": {"type": "integer", "description": "The input"}},
+ "required": ["x"],
+ },
+ "return": {"type": "integer"},
+ }
+ self.assertEqual(schema["function"], expected_schema)
+
+ def test_return_value_docstring(self):
+ def fn(x: int) -> int:
+ """
+ Test function
+
+ Args:
+ x: The input
+
+
+ Returns:
+ The output
+ """
+ return x
+
+ schema = get_json_schema(fn)
+ expected_schema = {
+ "name": "fn",
+ "description": "Test function",
+ "parameters": {
+ "type": "object",
+ "properties": {"x": {"type": "integer", "description": "The input"}},
+ "required": ["x"],
+ },
+ "return": {"type": "integer", "description": "The output"},
+ }
+ self.assertEqual(schema["function"], expected_schema)
+
+ def test_tuple(self):
+ def fn(x: Tuple[int, str]):
+ """
+ Test function
+
+ Args:
+ x: The input
+
+
+ Returns:
+ The output
+ """
+ return x
+
+ schema = get_json_schema(fn)
+ expected_schema = {
+ "name": "fn",
+ "description": "Test function",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "x": {
+ "type": "array",
+ "prefixItems": [{"type": "integer"}, {"type": "string"}],
+ "description": "The input",
+ }
+ },
+ "required": ["x"],
+ },
+ }
+ self.assertEqual(schema["function"], expected_schema)
+
+ def test_single_element_tuple_fails(self):
+ def fn(x: Tuple[int]):
+ """
+ Test function
+
+ Args:
+ x: The input
+
+
+ Returns:
+ The output
+ """
+ return x
+
+ # Single-element tuples should just be the type itself, or List[type] for variable-length inputs
+ with self.assertRaises(TypeHintParsingException):
+ get_json_schema(fn)
+
+ def test_ellipsis_type_fails(self):
+ def fn(x: Tuple[int, ...]):
+ """
+ Test function
+
+ Args:
+ x: The input
+
+
+ Returns:
+ The output
+ """
+ return x
+
+ # Variable length inputs should be specified with List[type], not Tuple[type, ...]
+ with self.assertRaises(TypeHintParsingException):
+ get_json_schema(fn)
+
+ def test_enum_extraction(self):
+ def fn(temperature_format: str):
+ """
+ Test function
+
+ Args:
+ temperature_format: The temperature format to use (Choices: ["celsius", "fahrenheit"])
+
+
+ Returns:
+ The temperature
+ """
+ return -40.0
+
+ # Let's see if that gets correctly parsed as an enum
+ schema = get_json_schema(fn)
+ expected_schema = {
+ "name": "fn",
+ "description": "Test function",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "temperature_format": {
+ "type": "string",
+ "enum": ["celsius", "fahrenheit"],
+ "description": "The temperature format to use",
+ }
+ },
+ "required": ["temperature_format"],
+ },
+ }
+
+ self.assertEqual(schema["function"], expected_schema)
+
+ def test_multiline_docstring_with_types(self):
+ def fn(x: int, y: int):
+ """
+ Test function
+
+ Args:
+ x: The first input
+
+ y: The second input. This is a longer description
+ that spans multiple lines with indentation and stuff.
+
+ Returns:
+ God knows what
+ """
+ pass
+
+ schema = get_json_schema(fn)
+ expected_schema = {
+ "name": "fn",
+ "description": "Test function",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "x": {"type": "integer", "description": "The first input"},
+ "y": {
+ "type": "integer",
+ "description": "The second input. This is a longer description that spans multiple lines with indentation and stuff.",
+ },
+ },
+ "required": ["x", "y"],
+ },
+ }
+
+ self.assertEqual(schema["function"], expected_schema)
+
+ def test_everything_all_at_once(self):
+ def fn(
+ x: str, y: Optional[List[Union[str, int]]], z: Tuple[Union[str, int], str] = (42, "hello")
+ ) -> Tuple[int, str]:
+ """
+ Test function with multiple args, and docstring args that we have to strip out.
+
+ Args:
+ x: The first input. It's got a big multiline
+ description and also contains
+ (choices: ["a", "b", "c"])
+
+ y: The second input. It's a big list with a single-line description.
+
+ z: The third input. It's some kind of tuple with a default arg.
+
+ Returns:
+ The output. The return description is also a big multiline
+ description that spans multiple lines.
+ """
+ pass
+
+ schema = get_json_schema(fn)
+ expected_schema = {
+ "name": "fn",
+ "description": "Test function with multiple args, and docstring args that we have to strip out.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "x": {
+ "type": "string",
+ "enum": ["a", "b", "c"],
+ "description": "The first input. It's got a big multiline description and also contains",
+ },
+ "y": {
+ "type": "array",
+ "items": {"type": ["string", "integer"]},
+ "nullable": True,
+ "description": "The second input. It's a big list with a single-line description.",
+ },
+ "z": {
+ "type": "array",
+ "prefixItems": [{"type": ["string", "integer"]}, {"type": "string"}],
+ "description": "The third input. It's some kind of tuple with a default arg.",
+ },
+ },
+ "required": ["x", "y"],
+ },
+ "return": {
+ "type": "array",
+ "prefixItems": [{"type": "integer"}, {"type": "string"}],
+ "description": "The output. The return description is also a big multiline\n description that spans multiple lines.",
+ },
+ }
+ self.assertEqual(schema["function"], expected_schema)