diff --git a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md index 0a0e3effc2a946..8d49fa5c80ee28 100644 --- a/docs/source/en/chat_templating.md +++ b/docs/source/en/chat_templating.md @@ -233,6 +233,332 @@ The sun. From here, just continue training like you would with a standard language modelling task, using the `formatted_chat` column. +## Advanced: Extra inputs to chat templates + +The only argument that `apply_chat_template` requires is `messages`. However, you can pass any keyword +argument to `apply_chat_template` and it will be accessible inside the template. This gives you a lot of freedom to use +chat templates for many things. There are no restrictions on the names or the format of these arguments - you can pass +strings, lists, dicts or whatever else you want. + +That said, there are some common use-cases for these extra arguments, +such as passing tools for function calling, or documents for retrieval-augmented generation. In these common cases, +we have some opinionated recommendations about what the names and formats of these arguments should be, which are +described in the sections below. We encourage model authors to make their chat templates compatible with this format, +to make it easy to transfer tool-calling code between models. + +## Advanced: Tool use / function calling + +"Tool use" LLMs can choose to call functions as external tools before generating an answer. When passing tools +to a tool-use model, you can simply pass a list of functions to the `tools` argument: + +```python +import datetime + +def current_time(): + """Get the current local time as a string.""" + return str(datetime.now()) + +def multiply(a: float, b: float): + """ + A function that multiplies two numbers + + Args: + a: The first number to multiply + b: The second number to multiply + """ + return a * b + +tools = [current_time, multiply] + +model_input = tokenizer.apply_chat_template( + messages, + tools=tools +) +``` + +In order for this to work correctly, you should write your functions in the format above, so that they can be parsed +correctly as tools. Specifically, you should follow these rules: + +- The function should have a descriptive name +- Every argument must have a type hint +- The function must have a docstring in the standard Google style (in other words, an initial function description + followed by an `Args:` block that describes the arguments, unless the function does not have any arguments. +- Do not include types in the `Args:` block. In other words, write `a: The first number to multiply`, not + `a (int): The first number to multiply`. Type hints should go in the function header instead. +- The function can have a return type and a `Returns:` block in the docstring. However, these are optional + because most tool-use models ignore them. + +### Passing tool results to the model + +The sample code above is enough to list the available tools for your model, but what happens if it wants to actually use +one? If that happens, you should: + +1. Parse the model's output to get the tool name(s) and arguments. +2. Add the model's tool call(s) to the conversation. +3. Call the corresponding function(s) with those arguments. +4. Add the result(s) to the conversation + +### A complete tool use example + +Let's walk through a tool use example, step by step. For this example, we will use an 8B `Hermes-2-Pro` model, +as it is one of the highest-performing tool-use models in its size category at the time of writing. If you have the +memory, you can consider using a larger model instead like [Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-v01) +or [Mixtral-8x22B](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1), both of which also support tool use +and offer even stronger performance. + +First, let's load our model and tokenizer: + +```python +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +checkpoint = "NousResearch/Hermes-2-Pro-Llama-3-8B" + +tokenizer = AutoTokenizer.from_pretrained(checkpoint, revision="pr/13") +model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, device_map="auto") +``` + +Next, let's define a list of tools: + +```python +def get_current_temperature(location: str, unit: str) -> float: + """ + Get the current temperature at a location. + + Args: + location: The location to get the temperature for, in the format "City, Country" + unit: The unit to return the temperature in. (choices: ["celsius", "fahrenheit"]) + Returns: + The current temperature at the specified location in the specified units, as a float. + """ + return 22. # A real function should probably actually get the temperature! + +def get_current_wind_speed(location: str) -> float: + """ + Get the current wind speed in km/h at a given location. + + Args: + location: The location to get the temperature for, in the format "City, Country" + Returns: + The current wind speed at the given location in km/h, as a float. + """ + return 6. # A real function should probably actually get the wind speed! + +tools = [get_current_temperature, get_current_wind_speed] +``` + +Now, let's set up a conversation for our bot: + +```python +messages = [ + {"role": "system", "content": "You are a bot that responds to weather queries. You should reply with the unit used in the queried location."}, + {"role": "user", "content": "Hey, what's the temperature in Paris right now?"} +] +``` + +Now, let's apply the chat template and generate a response: + +```python +inputs = tokenizer.apply_chat_template(messages, chat_template="tool_use", tools=tools, add_generation_prompt=True, return_dict=True, return_tensors="pt") +inputs = {k: v.to(model.device) for k, v in inputs.items()} +out = model.generate(**inputs, max_new_tokens=128) +print(tokenizer.decode(out[0][len(inputs["input_ids"][0]):])) +``` + +And we get: + +```text + +{"arguments": {"location": "Paris, France", "unit": "celsius"}, "name": "get_current_temperature"} +<|im_end|> +``` + +The model has called the function with valid arguments, in the format requested by the function docstring. It has +inferred that we're most likely referring to the Paris in France, and it remembered that, as the home of SI units, +the temperature in France should certainly be displayed in Celsius. + +Let's append the model's tool call to the conversation. Note that we generate a random `tool_call_id` here. These IDs +are not used by all models, but they allow models to issue multiple tool calls at once and keep track of which response +corresponds to which call. You can generate them any way you like, but they should be unique within each chat. + +```python +tool_call_id = "vAHdf3" # Random ID, should be unique for each tool call +tool_call = {"name": "get_current_temperature", "arguments": {"location": "Paris, France", "unit": "celsius"}} +messages.append({"role": "assistant", "tool_calls": [{"id": tool_call_id, "type": "function", "function": tool_call}]}) +``` + + +Now that we've added the tool call to the conversation, we can call the function and append the result to the +conversation. Since we're just using a dummy function for this example that always returns 22.0, we can just append +that result directly. Again, note the `tool_call_id` - this should match the ID used in the tool call above. + +```python +messages.append({"role": "tool", "tool_call_id": tool_call_id, "name": "get_current_temperature", "content": "22.0"}) +``` + +Finally, let's let the assistant read the function outputs and continue chatting with the user: + +```python +inputs = tokenizer.apply_chat_template(messages, chat_template="tool_use", tools=tools, add_generation_prompt=True, return_dict=True, return_tensors="pt") +inputs = {k: v.to(model.device) for k, v in inputs.items()} +out = model.generate(**inputs, max_new_tokens=128) +print(tokenizer.decode(out[0][len(inputs["input_ids"][0]):])) +``` + +And we get: + +```text +The current temperature in Paris, France is 22.0 ° Celsius.<|im_end|> +``` + +Although this was a simple demo with dummy tools and a single call, the same technique works with +multiple real tools and longer conversations. This can be a powerful way to extend the capabilities of conversational +agents with real-time information, computational tools like calculators, or access to large databases. + + +Not all of the tool-calling features shown above are used by all models. Some use tool call IDs, others simply use the function name and +match tool calls to results using the ordering, and there are several models that use neither and only issue one tool +call at a time to avoid confusion. If you want your code to be compatible across as many models as possible, we +recommend structuring your tools calls like we've shown here, and returning tool results in the order that +they were issued by the model. The chat templates on each model should handle the rest. + + +### Understanding tool schemas + +Each function you pass to the `tools` argument of `apply_chat_template` is converted into a +[JSON schema](https://json-schema.org/learn/getting-started-step-by-step). These schemas +are then passed to the model chat template. In other words, tool-use models do not see your functions directly, and they +never see the actual code inside them. What they care about is the function **definitions** and the **arguments** they +need to pass to them - they care about what the tools do and how to use them, not how they work! It is up to you +to read their outputs, detect if they have requested to use a tool, pass their arguments to the tool function, and +return the response in the chat. + +Generating JSON schemas to pass to the template should be automatic and invisible as long as your functions +follow the specification above, but if you encounter problems, or you simply want more control over the conversion, +you can handle the conversion manually. Here is an example of a manual schema conversion. + +```python +from transformers.utils import get_json_schema + +def multiply(a: float, b: float): + """ + A function that multiplies two numbers + + Args: + a: The first number to multiply + b: The second number to multiply + """ + return a * b + +schema = get_json_schema(multiply) +print(schema) +``` + +This will yield: + +```json +{ + "type": "function", + "function": { + "name": "multiply", + "description": "A function that multiplies two numbers", + "parameters": { + "type": "object", + "properties": { + "a": { + "type": "number", + "description": "The first number to multiply" + }, + "b": { + "type": "number", + "description": "The second number to multiply" + } + }, + "required": ["a", "b"] + } + } +} +``` + +If you wish, you can edit these schemas, or even write them from scratch yourself without using `get_json_schema` at +all. JSON schemas can be passed directly to the `tools` argument of +`apply_chat_template` - this gives you a lot of power to define precise schemas for more complex functions. Be careful, +though - the more complex your schemas, the more likely the model is to get confused when dealing with them! We +recommend simple function signatures where possible, keeping arguments (and especially complex, nested arguments) +to a minimum. + +Here is an example of defining schemas by hand, and passing them directly to `apply_chat_template`: + +```python +# A simple function that takes no arguments +current_time = { + "type": "function", + "function": { + "name": "current_time", + "description": "Get the current local time as a string.", + "parameters": { + 'type': 'object', + 'properties': {} + } + } +} + +# A more complete function that takes two numerical arguments +multiply = { + 'type': 'function', + 'function': { + 'name': 'multiply', + 'description': 'A function that multiplies two numbers', + 'parameters': { + 'type': 'object', + 'properties': { + 'a': { + 'type': 'number', + 'description': 'The first number to multiply' + }, + 'b': { + 'type': 'number', 'description': 'The second number to multiply' + } + }, + 'required': ['a', 'b'] + } + } +} + +model_input = tokenizer.apply_chat_template( + messages, + tools = [current_time, multiply] +) +``` + +## Advanced: Retrieval-augmented generation + +"Retrieval-augmented generation" or "RAG" LLMs can search a corpus of documents for information before responding +to a query. This allows models to vastly expand their knowledge base beyond their limited context size. Our +recommendation for RAG models is that their template +should accept a `documents` argument. This should be a list of documents, where each "document" +is a single dict with `title` and `contents` keys, both of which are strings. Because this format is much simpler +than the JSON schemas used for tools, no helper functions are necessary. + +Here's an example of a RAG template in action: + +```python +document1 = { + "title": "The Moon: Our Age-Old Foe", + "contents": "Man has always dreamed of destroying the moon. In this essay, I shall..." +} + +document2 = { + "title": "The Sun: Our Age-Old Friend", + "contents": "Although often underappreciated, the sun provides several notable benefits..." +} + +model_input = tokenizer.apply_chat_template( + messages, + documents=[document1, document2] +) +``` + ## Advanced: How do chat templates work? The chat template for a model is stored on the `tokenizer.chat_template` attribute. If no chat template is set, the diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index bc9417f3587c1e..14a31560c6f1a8 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -28,6 +28,7 @@ from contextlib import contextmanager from dataclasses import dataclass from functools import lru_cache +from inspect import isfunction from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union import numpy as np @@ -47,6 +48,7 @@ copy_func, download_url, extract_commit_hash, + get_json_schema, is_flax_available, is_jax_tensor, is_mlx_available, @@ -1683,6 +1685,8 @@ def get_vocab(self) -> Dict[str, int]: def apply_chat_template( self, conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]], + tools: Optional[List[Dict]] = None, + documents: Optional[List[Dict[str, str]]] = None, chat_template: Optional[str] = None, add_generation_prompt: bool = False, tokenize: bool = True, @@ -1703,8 +1707,21 @@ def apply_chat_template( Args: conversation (Union[List[Dict[str, str]], List[List[Dict[str, str]]]]): A list of dicts with "role" and "content" keys, representing the chat history so far. - chat_template (str, *optional*): A Jinja template to use for this conversion. If - this is not passed, the model's default chat template will be used instead. + tools (`List[Dict]`, *optional*): + A list of tools (callable functions) that will be accessible to the model. If the template does not + support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema, + giving the name, description and argument types for the tool. See our + [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use) + for more information. + documents (`List[Dict[str, str]]`, *optional*): + A list of dicts representing documents that will be accessible to the model if it is performing RAG + (retrieval-augmented generation). If the template does not support RAG, this argument will have no + effect. We recommend that each document should be a dict containing "title" and "text" keys. Please + see the RAG section of the [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#arguments-for-RAG) + for examples of passing documents with chat templates. + chat_template (`str`, *optional*): + A Jinja template to use for this conversion. It is usually not necessary to pass anything to this + argument, as the model's template will be used by default. add_generation_prompt (bool, *optional*): Whether to end the prompt with the token(s) that indicate the start of an assistant message. This is useful when you want to generate a response from the model. Note that this argument will be passed to the chat template, and so it must be supported in the @@ -1802,6 +1819,27 @@ def apply_chat_template( conversations = [conversation] is_batched = False + # We accept either JSON schemas or functions for tools. If we get functions, we convert them to schemas + if tools is not None: + tool_schemas = [] + for tool in tools: + if isinstance(tool, dict): + tool_schemas.append(tool) + elif isfunction(tool): + tool_schemas.append(get_json_schema(tool)) + else: + raise ValueError( + "Tools should either be a JSON schema, or a callable function with type hints " + "and a docstring suitable for auto-conversion to a schema." + ) + else: + tool_schemas = None + + if documents is not None: + for document in documents: + if not isinstance(document, dict): + raise TypeError("Documents should be a list of dicts with 'title' and 'text' keys!") + rendered = [] template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present for chat in conversations: @@ -1809,7 +1847,11 @@ def apply_chat_template( # Indicates it's a Conversation object chat = chat.messages rendered_chat = compiled_template.render( - messages=chat, add_generation_prompt=add_generation_prompt, **template_kwargs + messages=chat, + tools=tool_schemas, + documents=documents, + add_generation_prompt=add_generation_prompt, + **template_kwargs, ) rendered.append(rendered_chat) diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 51c1113cab3c2c..ce87bc8623132e 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -21,6 +21,7 @@ from .. import __version__ from .backbone_utils import BackboneConfigMixin, BackboneMixin +from .chat_template_utils import DocstringParsingException, TypeHintParsingException, get_json_schema from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD from .doc import ( add_code_sample_docstrings, diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py new file mode 100644 index 00000000000000..ee6173f2a1532b --- /dev/null +++ b/src/transformers/utils/chat_template_utils.py @@ -0,0 +1,316 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import json +import re +from typing import Any, Callable, Dict, Optional, Tuple, Union, get_args, get_origin, get_type_hints + + +BASIC_TYPES = (int, float, str, bool, Any, type(None), ...) +# Extracts the initial segment of the docstring, containing the function description +description_re = re.compile(r"^(.*?)[\n\s]*(Args:|Returns:|Raises:|\Z)", re.DOTALL) +# Extracts the Args: block from the docstring +args_re = re.compile(r"\n\s*Args:\n\s*(.*?)[\n\s]*(Returns:|Raises:|\Z)", re.DOTALL) +# Splits the Args: block into individual arguments +args_split_re = re.compile( + r""" +(?:^|\n) # Match the start of the args block, or a newline +\s*(\w+):\s* # Capture the argument name and strip spacing +(.*?)\s* # Capture the argument description, which can span multiple lines, and strip trailing spacing +(?=\n\s*\w+:|\Z) # Stop when you hit the next argument or the end of the block +""", + re.DOTALL | re.VERBOSE, +) +# Extracts the Returns: block from the docstring, if present. Note that most chat templates ignore the return type/doc! +returns_re = re.compile(r"\n\s*Returns:\n\s*(.*?)[\n\s]*(Raises:|\Z)", re.DOTALL) + + +class TypeHintParsingException(Exception): + """Exception raised for errors in parsing type hints to generate JSON schemas""" + + pass + + +class DocstringParsingException(Exception): + """Exception raised for errors in parsing docstrings to generate JSON schemas""" + + pass + + +def _get_json_schema_type(param_type: str) -> Dict[str, str]: + type_mapping = { + int: {"type": "integer"}, + float: {"type": "number"}, + str: {"type": "string"}, + bool: {"type": "boolean"}, + Any: {}, + } + return type_mapping.get(param_type, {"type": "object"}) + + +def _parse_type_hint(hint: str) -> Dict: + origin = get_origin(hint) + args = get_args(hint) + + if origin is None: + try: + return _get_json_schema_type(hint) + except KeyError: + raise TypeHintParsingException( + "Couldn't parse this type hint, likely due to a custom class or object: ", hint + ) + + elif origin is Union: + # Recurse into each of the subtypes in the Union, except None, which is handled separately at the end + subtypes = [_parse_type_hint(t) for t in args if t != type(None)] + if len(subtypes) == 1: + # A single non-null type can be expressed directly + return_dict = subtypes[0] + elif all(isinstance(subtype["type"], str) for subtype in subtypes): + # A union of basic types can be expressed as a list in the schema + return_dict = {"type": [subtype["type"] for subtype in subtypes]} + else: + # A union of more complex types requires "anyOf" + return_dict = {"anyOf": subtypes} + if type(None) in args: + return_dict["nullable"] = True + return return_dict + + elif origin is list: + if not args: + return {"type": "array"} + else: + # Lists can only have a single type argument, so recurse into it + return {"type": "array", "items": _parse_type_hint(args[0])} + + elif origin is tuple: + if not args: + return {"type": "array"} + if len(args) == 1: + raise TypeHintParsingException( + f"The type hint {str(hint).replace('typing.', '')} is a Tuple with a single element, which " + "we do not automatically convert to JSON schema as it is rarely necessary. If this input can contain " + "more than one element, we recommend " + "using a List[] type instead, or if it really is a single element, remove the Tuple[] wrapper and just " + "pass the element directly." + ) + if ... in args: + raise TypeHintParsingException( + "Conversion of '...' is not supported in Tuple type hints. " + "Use List[] types for variable-length" + " inputs instead." + ) + return {"type": "array", "prefixItems": [_parse_type_hint(t) for t in args]} + + elif origin is dict: + # The JSON equivalent to a dict is 'object', which mandates that all keys are strings + # However, we can specify the type of the dict values with "additionalProperties" + out = {"type": "object"} + if len(args) == 2: + out["additionalProperties"] = _parse_type_hint(args[1]) + return out + + raise TypeHintParsingException("Couldn't parse this type hint, likely due to a custom class or object: ", hint) + + +def _convert_type_hints_to_json_schema(func: Callable) -> Dict: + type_hints = get_type_hints(func) + signature = inspect.signature(func) + required = [] + for param_name, param in signature.parameters.items(): + if param.annotation == inspect.Parameter.empty: + raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}") + if param.default == inspect.Parameter.empty: + required.append(param_name) + + properties = {} + for param_name, param_type in type_hints.items(): + properties[param_name] = _parse_type_hint(param_type) + + schema = {"type": "object", "properties": properties} + if required: + schema["required"] = required + + return schema + + +def parse_google_format_docstring(docstring: str) -> Tuple[Optional[str], Optional[Dict], Optional[str]]: + """ + Parses a Google-style docstring to extract the function description, + argument descriptions, and return description. + + Args: + docstring (str): The docstring to parse. + + Returns: + The function description, arguments, and return description. + """ + + # Extract the sections + description_match = description_re.search(docstring) + args_match = args_re.search(docstring) + returns_match = returns_re.search(docstring) + + # Clean and store the sections + description = description_match.group(1).strip() if description_match else None + docstring_args = args_match.group(1).strip() if args_match else None + returns = returns_match.group(1).strip() if returns_match else None + + # Parsing the arguments into a dictionary + if docstring_args is not None: + docstring_args = "\n".join([line for line in docstring_args.split("\n") if line.strip()]) # Remove blank lines + matches = args_split_re.findall(docstring_args) + args_dict = {match[0]: re.sub(r"\s*\n+\s*", " ", match[1].strip()) for match in matches} + else: + args_dict = {} + + return description, args_dict, returns + + +def get_json_schema(func: Callable) -> Dict: + """ + This function generates a JSON schema for a given function, based on its docstring and type hints. This is + mostly used for passing lists of tools to a chat template. The JSON schema contains the name and description of + the function, as well as the names, types and descriptions for each of its arguments. `get_json_schema()` requires + that the function has a docstring, and that each argument has a description in the docstring, in the standard + Google docstring format shown below. It also requires that all the function arguments have a valid Python type hint. + + Although it is not required, a `Returns` block can also be added, which will be included in the schema. This is + optional because most chat templates ignore the return value of the function. + + Args: + func: The function to generate a JSON schema for. + + Returns: + A dictionary containing the JSON schema for the function. + + Examples: + ```python + >>> def multiply(x: float, y: float): + >>> ''' + >>> A function that multiplies two numbers + >>> + >>> Args: + >>> x: The first number to multiply + >>> y: The second number to multiply + >>> ''' + >>> return x * y + >>> + >>> print(get_json_schema(multiply)) + { + "name": "multiply", + "description": "A function that multiplies two numbers", + "parameters": { + "type": "object", + "properties": { + "x": {"type": "number", "description": "The first number to multiply"}, + "y": {"type": "number", "description": "The second number to multiply"} + }, + "required": ["x", "y"] + } + } + ``` + + The general use for these schemas is that they are used to generate tool descriptions for chat templates that + support them, like so: + + ```python + >>> from transformers import AutoTokenizer + >>> from transformers.utils import get_json_schema + >>> + >>> def multiply(x: float, y: float): + >>> ''' + >>> A function that multiplies two numbers + >>> + >>> Args: + >>> x: The first number to multiply + >>> y: The second number to multiply + >>> return x * y + >>> ''' + >>> + >>> multiply_schema = get_json_schema(multiply) + >>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01") + >>> messages = [{"role": "user", "content": "What is 179 x 4571?"}] + >>> formatted_chat = tokenizer.apply_chat_template( + >>> messages, + >>> tools=[multiply_schema], + >>> chat_template="tool_use", + >>> return_dict=True, + >>> return_tensors="pt", + >>> add_generation_prompt=True + >>> ) + >>> # The formatted chat can now be passed to model.generate() + ``` + + Each argument description can also have an optional `(choices: ...)` block at the end, such as + `(choices: ["tea", "coffee"])`, which will be parsed into an `enum` field in the schema. Note that this will + only be parsed correctly if it is at the end of the line: + + ```python + >>> def drink_beverage(beverage: str): + >>> ''' + >>> A function that drinks a beverage + >>> + >>> Args: + >>> beverage: The beverage to drink (choices: ["tea", "coffee"]) + >>> ''' + >>> pass + >>> + >>> print(get_json_schema(drink_beverage)) + ``` + { + 'name': 'drink_beverage', + 'description': 'A function that drinks a beverage', + 'parameters': { + 'type': 'object', + 'properties': { + 'beverage': { + 'type': 'string', + 'enum': ['tea', 'coffee'], + 'description': 'The beverage to drink' + } + }, + 'required': ['beverage'] + } + } + """ + doc = inspect.getdoc(func) + if not doc: + raise DocstringParsingException( + f"Cannot generate JSON schema for {func.__name__} because it has no docstring!" + ) + doc = doc.strip() + main_doc, param_descriptions, return_doc = parse_google_format_docstring(doc) + + json_schema = _convert_type_hints_to_json_schema(func) + if (return_dict := json_schema["properties"].pop("return", None)) is not None: + if return_doc is not None: # We allow a missing return docstring since most templates ignore it + return_dict["description"] = return_doc + for arg, schema in json_schema["properties"].items(): + if arg not in param_descriptions: + raise DocstringParsingException( + f"Cannot generate JSON schema for {func.__name__} because the docstring has no description for the argument '{arg}'" + ) + desc = param_descriptions[arg] + enum_choices = re.search(r"\(choices:\s*(.*?)\)\s*$", desc, flags=re.IGNORECASE) + if enum_choices: + schema["enum"] = [c.strip() for c in json.loads(enum_choices.group(1))] + desc = enum_choices.string[: enum_choices.start()].strip() + schema["description"] = desc + + output = {"name": func.__name__, "description": main_doc, "parameters": json_schema} + if return_dict is not None: + output["return"] = return_dict + return {"type": "function", "function": output} diff --git a/tests/utils/test_chat_template_utils.py b/tests/utils/test_chat_template_utils.py new file mode 100644 index 00000000000000..cff31c1f8a3483 --- /dev/null +++ b/tests/utils/test_chat_template_utils.py @@ -0,0 +1,476 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from typing import List, Optional, Tuple, Union + +from transformers.utils import DocstringParsingException, TypeHintParsingException, get_json_schema + + +class JsonSchemaGeneratorTest(unittest.TestCase): + def test_simple_function(self): + def fn(x: int): + """ + Test function + + Args: + x: The input + """ + return x + + schema = get_json_schema(fn) + expected_schema = { + "name": "fn", + "description": "Test function", + "parameters": { + "type": "object", + "properties": {"x": {"type": "integer", "description": "The input"}}, + "required": ["x"], + }, + } + self.assertEqual(schema["function"], expected_schema) + + def test_no_arguments(self): + def fn(): + """ + Test function + """ + return True + + schema = get_json_schema(fn) + expected_schema = { + "name": "fn", + "description": "Test function", + "parameters": {"type": "object", "properties": {}}, + } + self.assertEqual(schema["function"], expected_schema) + + def test_union(self): + def fn(x: Union[int, float]): + """ + Test function + + Args: + x: The input + """ + return x + + schema = get_json_schema(fn) + expected_schema = { + "name": "fn", + "description": "Test function", + "parameters": { + "type": "object", + "properties": {"x": {"type": ["integer", "number"], "description": "The input"}}, + "required": ["x"], + }, + } + self.assertEqual(schema["function"], expected_schema) + + def test_optional(self): + def fn(x: Optional[int]): + """ + Test function + + Args: + x: The input + """ + return x + + schema = get_json_schema(fn) + expected_schema = { + "name": "fn", + "description": "Test function", + "parameters": { + "type": "object", + "properties": {"x": {"type": "integer", "description": "The input", "nullable": True}}, + "required": ["x"], + }, + } + self.assertEqual(schema["function"], expected_schema) + + def test_default_arg(self): + def fn(x: int = 42): + """ + Test function + + Args: + x: The input + """ + return x + + schema = get_json_schema(fn) + expected_schema = { + "name": "fn", + "description": "Test function", + "parameters": {"type": "object", "properties": {"x": {"type": "integer", "description": "The input"}}}, + } + self.assertEqual(schema["function"], expected_schema) + + def test_nested_list(self): + def fn(x: List[List[Union[str, int]]]): + """ + Test function + + Args: + x: The input + """ + return x + + schema = get_json_schema(fn) + expected_schema = { + "name": "fn", + "description": "Test function", + "parameters": { + "type": "object", + "properties": { + "x": { + "type": "array", + "items": {"type": "array", "items": {"type": ["string", "integer"]}}, + "description": "The input", + } + }, + "required": ["x"], + }, + } + self.assertEqual(schema["function"], expected_schema) + + def test_multiple_arguments(self): + def fn(x: int, y: str): + """ + Test function + + Args: + x: The input + y: Also the input + """ + return x + + schema = get_json_schema(fn) + expected_schema = { + "name": "fn", + "description": "Test function", + "parameters": { + "type": "object", + "properties": { + "x": {"type": "integer", "description": "The input"}, + "y": {"type": "string", "description": "Also the input"}, + }, + "required": ["x", "y"], + }, + } + self.assertEqual(schema["function"], expected_schema) + + def test_multiple_complex_arguments(self): + def fn(x: List[Union[int, float]], y: Optional[Union[int, str]] = None): + """ + Test function + + Args: + x: The input + y: Also the input + """ + return x + + schema = get_json_schema(fn) + expected_schema = { + "name": "fn", + "description": "Test function", + "parameters": { + "type": "object", + "properties": { + "x": {"type": "array", "items": {"type": ["integer", "number"]}, "description": "The input"}, + "y": { + "type": ["integer", "string"], + "nullable": True, + "description": "Also the input", + }, + }, + "required": ["x"], + }, + } + self.assertEqual(schema["function"], expected_schema) + + def test_missing_docstring(self): + def fn(x: int): + return x + + with self.assertRaises(DocstringParsingException): + get_json_schema(fn) + + def test_missing_param_docstring(self): + def fn(x: int): + """ + Test function + """ + return x + + with self.assertRaises(DocstringParsingException): + get_json_schema(fn) + + def test_missing_type_hint(self): + def fn(x): + """ + Test function + + Args: + x: The input + """ + return x + + with self.assertRaises(TypeHintParsingException): + get_json_schema(fn) + + def test_return_value(self): + def fn(x: int) -> int: + """ + Test function + + Args: + x: The input + """ + return x + + schema = get_json_schema(fn) + expected_schema = { + "name": "fn", + "description": "Test function", + "parameters": { + "type": "object", + "properties": {"x": {"type": "integer", "description": "The input"}}, + "required": ["x"], + }, + "return": {"type": "integer"}, + } + self.assertEqual(schema["function"], expected_schema) + + def test_return_value_docstring(self): + def fn(x: int) -> int: + """ + Test function + + Args: + x: The input + + + Returns: + The output + """ + return x + + schema = get_json_schema(fn) + expected_schema = { + "name": "fn", + "description": "Test function", + "parameters": { + "type": "object", + "properties": {"x": {"type": "integer", "description": "The input"}}, + "required": ["x"], + }, + "return": {"type": "integer", "description": "The output"}, + } + self.assertEqual(schema["function"], expected_schema) + + def test_tuple(self): + def fn(x: Tuple[int, str]): + """ + Test function + + Args: + x: The input + + + Returns: + The output + """ + return x + + schema = get_json_schema(fn) + expected_schema = { + "name": "fn", + "description": "Test function", + "parameters": { + "type": "object", + "properties": { + "x": { + "type": "array", + "prefixItems": [{"type": "integer"}, {"type": "string"}], + "description": "The input", + } + }, + "required": ["x"], + }, + } + self.assertEqual(schema["function"], expected_schema) + + def test_single_element_tuple_fails(self): + def fn(x: Tuple[int]): + """ + Test function + + Args: + x: The input + + + Returns: + The output + """ + return x + + # Single-element tuples should just be the type itself, or List[type] for variable-length inputs + with self.assertRaises(TypeHintParsingException): + get_json_schema(fn) + + def test_ellipsis_type_fails(self): + def fn(x: Tuple[int, ...]): + """ + Test function + + Args: + x: The input + + + Returns: + The output + """ + return x + + # Variable length inputs should be specified with List[type], not Tuple[type, ...] + with self.assertRaises(TypeHintParsingException): + get_json_schema(fn) + + def test_enum_extraction(self): + def fn(temperature_format: str): + """ + Test function + + Args: + temperature_format: The temperature format to use (Choices: ["celsius", "fahrenheit"]) + + + Returns: + The temperature + """ + return -40.0 + + # Let's see if that gets correctly parsed as an enum + schema = get_json_schema(fn) + expected_schema = { + "name": "fn", + "description": "Test function", + "parameters": { + "type": "object", + "properties": { + "temperature_format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature format to use", + } + }, + "required": ["temperature_format"], + }, + } + + self.assertEqual(schema["function"], expected_schema) + + def test_multiline_docstring_with_types(self): + def fn(x: int, y: int): + """ + Test function + + Args: + x: The first input + + y: The second input. This is a longer description + that spans multiple lines with indentation and stuff. + + Returns: + God knows what + """ + pass + + schema = get_json_schema(fn) + expected_schema = { + "name": "fn", + "description": "Test function", + "parameters": { + "type": "object", + "properties": { + "x": {"type": "integer", "description": "The first input"}, + "y": { + "type": "integer", + "description": "The second input. This is a longer description that spans multiple lines with indentation and stuff.", + }, + }, + "required": ["x", "y"], + }, + } + + self.assertEqual(schema["function"], expected_schema) + + def test_everything_all_at_once(self): + def fn( + x: str, y: Optional[List[Union[str, int]]], z: Tuple[Union[str, int], str] = (42, "hello") + ) -> Tuple[int, str]: + """ + Test function with multiple args, and docstring args that we have to strip out. + + Args: + x: The first input. It's got a big multiline + description and also contains + (choices: ["a", "b", "c"]) + + y: The second input. It's a big list with a single-line description. + + z: The third input. It's some kind of tuple with a default arg. + + Returns: + The output. The return description is also a big multiline + description that spans multiple lines. + """ + pass + + schema = get_json_schema(fn) + expected_schema = { + "name": "fn", + "description": "Test function with multiple args, and docstring args that we have to strip out.", + "parameters": { + "type": "object", + "properties": { + "x": { + "type": "string", + "enum": ["a", "b", "c"], + "description": "The first input. It's got a big multiline description and also contains", + }, + "y": { + "type": "array", + "items": {"type": ["string", "integer"]}, + "nullable": True, + "description": "The second input. It's a big list with a single-line description.", + }, + "z": { + "type": "array", + "prefixItems": [{"type": ["string", "integer"]}, {"type": "string"}], + "description": "The third input. It's some kind of tuple with a default arg.", + }, + }, + "required": ["x", "y"], + }, + "return": { + "type": "array", + "prefixItems": [{"type": "integer"}, {"type": "string"}], + "description": "The output. The return description is also a big multiline\n description that spans multiple lines.", + }, + } + self.assertEqual(schema["function"], expected_schema)