From 3ad5bd4accfa600fe69cf50b09b8ad70c32714f9 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 2 May 2024 15:40:00 +0100 Subject: [PATCH 01/69] First draft, still missing automatic function conversion --- docs/source/en/chat_templating.md | 88 +++++++++++++++++++++ src/transformers/tokenization_utils_base.py | 18 ++++- 2 files changed, 105 insertions(+), 1 deletion(-) diff --git a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md index 0a0e3effc2a946..430396ceb8f57d 100644 --- a/docs/source/en/chat_templating.md +++ b/docs/source/en/chat_templating.md @@ -233,6 +233,94 @@ The sun. From here, just continue training like you would with a standard language modelling task, using the `formatted_chat` column. +## Can I pass other arguments to the chat template? + +Yes, you can! The only argument that `apply_chat_template` requires is `messages`. However, you can pass any keyword +argument to `apply_chat_template` and it will be accessible inside the template. This gives you a lot of freedom to use +chat templates for many things. There are no restrictions on the names or the format of these arguments - you can pass +strings, lists, dicts or whatever else you want. + +That said, there are some common use-cases for these extra arguments, +such as passing tools for function calling, or documents for retrieval-augmented generation. In these common cases, +we have some opinionated recommendations about what the names and formats of these arguments should be. By sticking +to these conventions when writing your template, you make it easier for users to use standard tool-use or RAG input +pipelines with your model without needing any manual reformatting. + +### Arguments for tool use + +Our recommendation for "tool use" LLMs which can choose to call functions as external tools is that their template +should accept a `tools` argument. This should be a list of tools, defined via [JSON Schema](https://json-schema.org/). 
Each "tool" +is a single function that the model can choose to call, and the schema should include the function name, its description +and the expected spec for its arguments. + +#### Example + +```python +# A simple function that takes no arguments +current_time = { + "name": "current_time", + "description": "Get the current local time as a string.", + "parameters": {}, # TODO - double-check if this is the correct schema for this case + } + +# A more complete function that takes two numerical arguments +multiply = { + "name": "multiply", + "description": "Multiply two numbers together.", + "parameters": { + "type": "object", + "properties": { + "a": {"type": "number", "description": "The first number to multiply."}, + "b": {"type": "number", "description": "The second number to multiply."}, + }, + "required": ["a", "b"], + } + } + +model_input = tokenizer.apply_chat_template( + messages, + tools = [current_time, multiply] +) +``` + +JSON schemas permit highly detailed parameter specifications, so you can pass in functions with very complex, nested +arguments. Be careful, however - we find that in practice this can degrade performance, even for state-of-the-art +models. We recommend trying to keep your tool schemas simple and flat where possible. + +### Automated function conversion + +Although JSON schemas are precise, widely-supported and language-agnostic, they can be a bit verbose, which means +that writing them can be annoying. Don't panic, though, we have a solution! + +TODO Should descriptions come from the docstrings or the type hints? + +TODO Do we need to define a special format for args in the docstrings? + +### Arguments for retrieval-augmented generation (RAG) + +Our recommendation for "RAG" LLMs which can search a corpus of documents for information is that their template +should accept a `documents` argument. This should be a list of documents, where each "document" +is a single dict with `title` and `contents` keys, both of which are strings. 
+ +#### Example + +```python +document1 = { + "title": "The Moon: Our Age-Old Foe", + "contents": "Man has always dreamed of destroying the moon. In this essay, I shall..." +} + +document2 = { + "title": "The Sun: Our Age-Old Friend", + "contents": "Although often underappreciated, the sun provides several notable benefits..." +} + +model_input = tokenizer.apply_chat_template( + messages, + documents = [document1, document2] +) +``` + ## Advanced: How do chat templates work? The chat template for a model is stored on the `tokenizer.chat_template` attribute. If no chat template is set, the diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index bc9417f3587c1e..399ec06394bf35 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1683,6 +1683,8 @@ def get_vocab(self) -> Dict[str, int]: def apply_chat_template( self, conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]], + tools: Optional[List[Dict]] = None, + documents: Optional[List[Dict[str, str]]] = None, chat_template: Optional[str] = None, add_generation_prompt: bool = False, tokenize: bool = True, @@ -1703,6 +1705,16 @@ def apply_chat_template( Args: conversation (Union[List[Dict[str, str]], List[List[Dict[str, str]]]]): A list of dicts with "role" and "content" keys, representing the chat history so far. + tools (List[Dict], *optional*): A list of tools (callable functions) that will be accessible + to the model. If the template does not support function calling, this argument will have no effect. + We recommend passing the list of tools as a JSON Schema[link!], although note that some models and + templates may require a different format. Please see the docs for examples of passing tools with + chat templates[link!!]. 
+ documents (List[Dict[str, str]], *optional*): A list of dicts representing documents that will be accessible + to the model if it is performing RAG (retrieval-augmented generation). If the template does not support + RAG, this argument will have no effect. We recommend that each document should be a dict containing + "title" and "text" keys. Please see the docs for examples of passing documents with chat + templates [link!!]. chat_template (str, *optional*): A Jinja template to use for this conversion. If this is not passed, the model's default chat template will be used instead. add_generation_prompt (bool, *optional*): Whether to end the prompt with the token(s) that indicate @@ -1809,7 +1821,11 @@ def apply_chat_template( # Indicates it's a Conversation object chat = chat.messages rendered_chat = compiled_template.render( - messages=chat, add_generation_prompt=add_generation_prompt, **template_kwargs + messages=chat, + tools=tools, + documents=documents, + add_generation_prompt=add_generation_prompt, + **template_kwargs, ) rendered.append(rendered_chat) From 59e2cb640fe5ffa380bb7c10f5283dfaad1453d5 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 3 May 2024 15:16:56 +0100 Subject: [PATCH 02/69] First draft of the automatic schema generator --- docs/source/en/chat_templating.md | 6 +- src/transformers/utils/__init__.py | 1 + src/transformers/utils/chat_template_utils.py | 84 +++++++++++++++++++ 3 files changed, 87 insertions(+), 4 deletions(-) create mode 100644 src/transformers/utils/chat_template_utils.py diff --git a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md index 430396ceb8f57d..7a2e2507b9cc55 100644 --- a/docs/source/en/chat_templating.md +++ b/docs/source/en/chat_templating.md @@ -287,14 +287,12 @@ JSON schemas permit highly detailed parameter specifications, so you can pass in arguments. Be careful, however - we find that in practice this can degrade performance, even for state-of-the-art models. 
We recommend trying to keep your tool schemas simple and flat where possible. -### Automated function conversion +### Automated function conversion for tool use Although JSON schemas are precise, widely-supported and language-agnostic, they can be a bit verbose, which means that writing them can be annoying. Don't panic, though, we have a solution! -TODO Should descriptions come from the docstrings or the type hints? - -TODO Do we need to define a special format for args in the docstrings? +TODO Explain function conversion with examples ### Arguments for retrieval-augmented generation (RAG) diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 51c1113cab3c2c..ac2ecaef3a7fd6 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -21,6 +21,7 @@ from .. import __version__ from .backbone_utils import BackboneConfigMixin, BackboneMixin +from .chat_template_utils import get_json_schema from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD from .doc import ( add_code_sample_docstrings, diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py new file mode 100644 index 00000000000000..f34a7ecbc43d14 --- /dev/null +++ b/src/transformers/utils/chat_template_utils.py @@ -0,0 +1,84 @@ +import inspect +import re +from typing import Any, Union, get_origin, get_type_hints + + +BASIC_TYPES = (int, float, str, bool, Any) + + +def get_json_schema(func): + doc = inspect.getdoc(func).strip() + if not doc: + raise ValueError(f"Cannot generate JSON schema for {func.__name__} because it has no docstring!") + param_descriptions = _get_argument_descriptions_from_docstring(doc) + + json_schema = _convert_type_hints_to_json_schema(func) + for arg in json_schema["properties"]: + if arg not in param_descriptions: + raise ValueError( + f"Cannot generate JSON schema for {func.__name__} because the docstring has no 
description for the argument '{arg}'" + ) + json_schema["properties"][arg]["description"] = param_descriptions[arg] + + return json_schema + + +def _get_argument_descriptions_from_docstring(doc): + param_pattern = r":param (\w+): (.+)" + params = re.findall(param_pattern, doc) + return dict(params) + + +def _convert_type_hints_to_json_schema(func): + type_hints = get_type_hints(func) + properties = {} + + signature = inspect.signature(func) + required = [ + param_name for param_name, param in signature.parameters.items() if param.default == inspect.Parameter.empty + ] + + for param_name, param_type in type_hints.items(): + if param_name == "return": + continue + + if origin := get_origin(param_type) is not None: + if origin is Union: + if all(t in BASIC_TYPES for t in param_type.__args__): + properties[param_name] = { + "type": [_get_json_schema_type(t)["type"] for t in param_type.__args__ if t != type(None)], + "nullable": type(None) in param_type.__args__, + } + else: + properties[param_name] = { + "anyOf": [_get_json_schema_type(t) for t in param_type.__args__ if t != type(None)], + "nullable": type(None) in param_type.__args__, + } + elif origin is list: + properties[param_name] = {"type": "array", "items": _get_json_schema_type(param_type.__args__[0])} + elif origin is dict: + properties[param_name] = { + "type": "object", + "additionalProperties": _get_json_schema_type(param_type.__args__[1]), + } + else: + properties[param_name] = _get_json_schema_type(param_type) + + schema = {"type": "object", "properties": properties, "required": required} + + return schema + + +def _get_json_schema_type(param_type): + if param_type == int: + return {"type": "integer"} + elif param_type == float: + return {"type": "number"} + elif param_type == str: + return {"type": "string"} + elif param_type == bool: + return {"type": "boolean"} + elif param_type == Any: + return {} + else: + return {"type": "object"} From 8f7655db92704d00442efafe8ba96445fba3fcfd Mon Sep 17 00:00:00 
2001 From: Matt Date: Fri, 3 May 2024 15:50:33 +0100 Subject: [PATCH 03/69] Lots of small fixes --- src/transformers/utils/chat_template_utils.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index f34a7ecbc43d14..b3098d26297d9e 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -7,10 +7,11 @@ def get_json_schema(func): - doc = inspect.getdoc(func).strip() + doc = inspect.getdoc(func) if not doc: raise ValueError(f"Cannot generate JSON schema for {func.__name__} because it has no docstring!") - param_descriptions = _get_argument_descriptions_from_docstring(doc) + doc = doc.strip() + main_doc, param_descriptions = _get_argument_descriptions_from_docstring(doc) json_schema = _convert_type_hints_to_json_schema(func) for arg in json_schema["properties"]: @@ -20,13 +21,14 @@ def get_json_schema(func): ) json_schema["properties"][arg]["description"] = param_descriptions[arg] - return json_schema + return {"name": func.__name__, "description": main_doc, "parameters": json_schema} def _get_argument_descriptions_from_docstring(doc): param_pattern = r":param (\w+): (.+)" params = re.findall(param_pattern, doc) - return dict(params) + main_doc = doc.split(":param")[0].strip() + return main_doc, dict(params) def _convert_type_hints_to_json_schema(func): @@ -64,7 +66,9 @@ def _convert_type_hints_to_json_schema(func): else: properties[param_name] = _get_json_schema_type(param_type) - schema = {"type": "object", "properties": properties, "required": required} + schema = {"type": "object", "properties": properties} + if required: + schema["required"] = required return schema From 0b2ead3f3a1eba058bcb197e14dd04a7bbd82551 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 3 May 2024 16:08:55 +0100 Subject: [PATCH 04/69] the walrus has betrayed me --- src/transformers/utils/chat_template_utils.py | 
5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index b3098d26297d9e..3a28727c18f17b 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -1,6 +1,7 @@ import inspect import re from typing import Any, Union, get_origin, get_type_hints +import pdb BASIC_TYPES = (int, float, str, bool, Any) @@ -43,8 +44,8 @@ def _convert_type_hints_to_json_schema(func): for param_name, param_type in type_hints.items(): if param_name == "return": continue - - if origin := get_origin(param_type) is not None: + pdb.set_trace() + if (origin := get_origin(param_type)) is not None: if origin is Union: if all(t in BASIC_TYPES for t in param_type.__args__): properties[param_name] = { From cb67fd2dbd981e6eba0b24f1a940aabaf35e116d Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 3 May 2024 16:09:16 +0100 Subject: [PATCH 05/69] please stop committing your debug breakpoints --- src/transformers/utils/chat_template_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 3a28727c18f17b..5b418bcf3738ba 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -44,7 +44,6 @@ def _convert_type_hints_to_json_schema(func): for param_name, param_type in type_hints.items(): if param_name == "return": continue - pdb.set_trace() if (origin := get_origin(param_type)) is not None: if origin is Union: if all(t in BASIC_TYPES for t in param_type.__args__): From 41df7d11ec67f08588fde7b835b3f6baa7e74516 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 3 May 2024 18:00:47 +0100 Subject: [PATCH 06/69] Lots of cleanup and edge cases, looking better now --- src/transformers/utils/chat_template_utils.py | 61 ++++++++++++------- 1 file changed, 39 insertions(+), 22 deletions(-) diff --git 
a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 5b418bcf3738ba..6e06e10da3e42e 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -1,7 +1,6 @@ import inspect import re from typing import Any, Union, get_origin, get_type_hints -import pdb BASIC_TYPES = (int, float, str, bool, Any) @@ -44,27 +43,8 @@ def _convert_type_hints_to_json_schema(func): for param_name, param_type in type_hints.items(): if param_name == "return": continue - if (origin := get_origin(param_type)) is not None: - if origin is Union: - if all(t in BASIC_TYPES for t in param_type.__args__): - properties[param_name] = { - "type": [_get_json_schema_type(t)["type"] for t in param_type.__args__ if t != type(None)], - "nullable": type(None) in param_type.__args__, - } - else: - properties[param_name] = { - "anyOf": [_get_json_schema_type(t) for t in param_type.__args__ if t != type(None)], - "nullable": type(None) in param_type.__args__, - } - elif origin is list: - properties[param_name] = {"type": "array", "items": _get_json_schema_type(param_type.__args__[0])} - elif origin is dict: - properties[param_name] = { - "type": "object", - "additionalProperties": _get_json_schema_type(param_type.__args__[1]), - } - else: - properties[param_name] = _get_json_schema_type(param_type) + properties[param_name] = _parse_type_hint(param_type) + schema = {"type": "object", "properties": properties} if required: @@ -72,6 +52,43 @@ def _convert_type_hints_to_json_schema(func): return schema +def _parse_type_hint(hint): + if (origin := get_origin(hint)) is not None: + if origin is Union: + if all(t in BASIC_TYPES for t in hint.__args__): + return_dict = {"type": [_get_json_schema_type(t)["type"] for t in hint.__args__ if t != type(None)]} + if len(return_dict["type"]) == 1: + return_dict["type"] = return_dict["type"][0] + else: + return_dict = {"anyOf": [_parse_type_hint(t) for t in hint.__args__ if t 
!= type(None)],} + if len(return_dict["anyOf"]) == 1: + return_dict = return_dict["anyOf"][0] + if type(None) in hint.__args__: + return_dict["nullable"] = True + return return_dict + elif origin is list or origin is tuple: + if not hasattr(hint, "__args__"): + return {"type": "array"} + if all(t in BASIC_TYPES for t in hint.__args__): + items = {"type": [_get_json_schema_type(t)["type"] for t in hint.__args__ if t != type(None)]} + if len(items["type"]) == 1: + items["type"] = items["type"][0] + else: + items = {"anyOf": [_parse_type_hint(t) for t in hint.__args__ if t != type(None)]} + if len(items["anyOf"]) == 1: + items = items["anyOf"][0] + return_dict = {"type": "array", "items": items} + if "nullable" in hint.__args__: + return_dict["nullable"] = True + return return_dict + elif origin is dict: + return { + "type": "object", + "additionalProperties": _parse_type_hint(hint.__args__[1]), + } + else: + return _get_json_schema_type(hint) + def _get_json_schema_type(param_type): if param_type == int: From eec2486582e3cc4f050ea699ee9352a209012dac Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 7 May 2024 15:30:13 +0100 Subject: [PATCH 07/69] Comments and bugfixes for the type hint parser --- src/transformers/utils/chat_template_utils.py | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 6e06e10da3e42e..cfe39733741b77 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -1,6 +1,6 @@ import inspect import re -from typing import Any, Union, get_origin, get_type_hints +from typing import Any, Union, get_origin, get_type_hints, get_args BASIC_TYPES = (int, float, str, bool, Any) @@ -52,40 +52,49 @@ def _convert_type_hints_to_json_schema(func): return schema + def _parse_type_hint(hint): if (origin := get_origin(hint)) is not None: if origin is Union: - if all(t in BASIC_TYPES 
for t in hint.__args__): + # If it's a union of basic types, we can express that as a simple list in the schema + if all(t in BASIC_TYPES for t in get_args(hint)): return_dict = {"type": [_get_json_schema_type(t)["type"] for t in hint.__args__ if t != type(None)]} if len(return_dict["type"]) == 1: return_dict["type"] = return_dict["type"][0] else: + # A union of more complex types requires us to recurse into each subtype return_dict = {"anyOf": [_parse_type_hint(t) for t in hint.__args__ if t != type(None)],} if len(return_dict["anyOf"]) == 1: return_dict = return_dict["anyOf"][0] - if type(None) in hint.__args__: + if type(None) in get_args(hint): return_dict["nullable"] = True return return_dict elif origin is list or origin is tuple: - if not hasattr(hint, "__args__"): + if not get_args(hint): return {"type": "array"} - if all(t in BASIC_TYPES for t in hint.__args__): + if all(t in BASIC_TYPES for t in get_args(hint)): + # Similarly to unions, a list of basic types can be expressed as a list in the schema items = {"type": [_get_json_schema_type(t)["type"] for t in hint.__args__ if t != type(None)]} if len(items["type"]) == 1: items["type"] = items["type"][0] else: + # And a list of more complex types requires us to recurse into each subtype again items = {"anyOf": [_parse_type_hint(t) for t in hint.__args__ if t != type(None)]} if len(items["anyOf"]) == 1: items = items["anyOf"][0] return_dict = {"type": "array", "items": items} - if "nullable" in hint.__args__: + if type(None) in get_args(hint): return_dict["nullable"] = True return return_dict elif origin is dict: + # The JSON equivalent to a dict is 'object', which mandates that all keys are strings + # However, we can specify the type of the dict values with "additionalProperties" return { "type": "object", "additionalProperties": _parse_type_hint(hint.__args__[1]), } + else: + raise ValueError("Couldn't parse this type hint, likely due to a custom class or object: ", hint) else: return 
_get_json_schema_type(hint) From cf2b8dab1299cb3223a4d5d3fd31cab4d7156bc1 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 7 May 2024 16:08:22 +0100 Subject: [PATCH 08/69] More cleanup --- src/transformers/tokenization_utils_base.py | 5 ++--- src/transformers/utils/chat_template_utils.py | 10 +++++----- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 399ec06394bf35..8746ecca6ada38 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1707,9 +1707,8 @@ def apply_chat_template( with "role" and "content" keys, representing the chat history so far. tools (List[Dict], *optional*): A list of tools (callable functions) that will be accessible to the model. If the template does not support function calling, this argument will have no effect. - We recommend passing the list of tools as a JSON Schema[link!], although note that some models and - templates may require a different format. Please see the docs for examples of passing tools with - chat templates[link!!]. + Each tool should be passed as a JSON Schema[link!], giving the name, description and argument types + for the tool. [Docs and links here, including auto-generation of schemas!] documents (List[Dict[str, str]], *optional*): A list of dicts representing documents that will be accessible to the model if it is performing RAG (retrieval-augmented generation). If the template does not support RAG, this argument will have no effect. 
We recommend that each document should be a dict containing diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index cfe39733741b77..dc85175588355c 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -58,12 +58,12 @@ def _parse_type_hint(hint): if origin is Union: # If it's a union of basic types, we can express that as a simple list in the schema if all(t in BASIC_TYPES for t in get_args(hint)): - return_dict = {"type": [_get_json_schema_type(t)["type"] for t in hint.__args__ if t != type(None)]} + return_dict = {"type": [_get_json_schema_type(t)["type"] for t in get_args(hint) if t != type(None)]} if len(return_dict["type"]) == 1: return_dict["type"] = return_dict["type"][0] else: # A union of more complex types requires us to recurse into each subtype - return_dict = {"anyOf": [_parse_type_hint(t) for t in hint.__args__ if t != type(None)],} + return_dict = {"anyOf": [_parse_type_hint(t) for t in get_args(hint) if t != type(None)],} if len(return_dict["anyOf"]) == 1: return_dict = return_dict["anyOf"][0] if type(None) in get_args(hint): @@ -74,12 +74,12 @@ def _parse_type_hint(hint): return {"type": "array"} if all(t in BASIC_TYPES for t in get_args(hint)): # Similarly to unions, a list of basic types can be expressed as a list in the schema - items = {"type": [_get_json_schema_type(t)["type"] for t in hint.__args__ if t != type(None)]} + items = {"type": [_get_json_schema_type(t)["type"] for t in get_args(hint) if t != type(None)]} if len(items["type"]) == 1: items["type"] = items["type"][0] else: # And a list of more complex types requires us to recurse into each subtype again - items = {"anyOf": [_parse_type_hint(t) for t in hint.__args__ if t != type(None)]} + items = {"anyOf": [_parse_type_hint(t) for t in get_args(hint) if t != type(None)]} if len(items["anyOf"]) == 1: items = items["anyOf"][0] return_dict = {"type": "array", "items": 
items} @@ -91,7 +91,7 @@ def _parse_type_hint(hint): # However, we can specify the type of the dict values with "additionalProperties" return { "type": "object", - "additionalProperties": _parse_type_hint(hint.__args__[1]), + "additionalProperties": _parse_type_hint(get_args(hint)[1]), } else: raise ValueError("Couldn't parse this type hint, likely due to a custom class or object: ", hint) From ad0984b6bde5dee401c7d2f5b71a448014f4bb88 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 7 May 2024 18:35:28 +0100 Subject: [PATCH 09/69] Add tests, update schema generator --- src/transformers/utils/chat_template_utils.py | 26 +++- tests/utils/test_chat_template_utils.py | 128 ++++++++++++++++++ 2 files changed, 147 insertions(+), 7 deletions(-) create mode 100644 tests/utils/test_chat_template_utils.py diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index dc85175588355c..cee8a0bcdf8b90 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -1,6 +1,6 @@ import inspect import re -from typing import Any, Union, get_origin, get_type_hints, get_args +from typing import Any, Union, get_args, get_origin, get_type_hints BASIC_TYPES = (int, float, str, bool, Any) @@ -45,31 +45,35 @@ def _convert_type_hints_to_json_schema(func): continue properties[param_name] = _parse_type_hint(param_type) - schema = {"type": "object", "properties": properties} if required: schema["required"] = required return schema - +# TODO: Return types!! How are those even handled? Does it even matter? 
I should check what the different APIs do for this +# and also add tests def _parse_type_hint(hint): if (origin := get_origin(hint)) is not None: if origin is Union: # If it's a union of basic types, we can express that as a simple list in the schema if all(t in BASIC_TYPES for t in get_args(hint)): - return_dict = {"type": [_get_json_schema_type(t)["type"] for t in get_args(hint) if t != type(None)]} + return_dict = { + "type": [_get_json_schema_type(t)["type"] for t in get_args(hint) if t not in (type(None), ...)] + } if len(return_dict["type"]) == 1: return_dict["type"] = return_dict["type"][0] else: # A union of more complex types requires us to recurse into each subtype - return_dict = {"anyOf": [_parse_type_hint(t) for t in get_args(hint) if t != type(None)],} + return_dict = { + "anyOf": [_parse_type_hint(t) for t in get_args(hint) if t not in (type(None), ...)], + } if len(return_dict["anyOf"]) == 1: return_dict = return_dict["anyOf"][0] if type(None) in get_args(hint): return_dict["nullable"] = True return return_dict - elif origin is list or origin is tuple: + elif origin is list: if not get_args(hint): return {"type": "array"} if all(t in BASIC_TYPES for t in get_args(hint)): @@ -79,13 +83,21 @@ def _parse_type_hint(hint): items["type"] = items["type"][0] else: # And a list of more complex types requires us to recurse into each subtype again - items = {"anyOf": [_parse_type_hint(t) for t in get_args(hint) if t != type(None)]} + items = {"anyOf": [_parse_type_hint(t) for t in get_args(hint) if t not in (type(None), ...)]} if len(items["anyOf"]) == 1: items = items["anyOf"][0] return_dict = {"type": "array", "items": items} if type(None) in get_args(hint): return_dict["nullable"] = True return return_dict + elif origin is tuple: + raise ValueError( + "This helper does not parse Tuple types, as they are usually used to indicate that " + "each position is associated with a specific type, and this requires JSON schemas " + "that are not supported by most 
templates. We recommend " + "either using List or List[Union] instead for arguments where this is appropriate, or " + "splitting arguments with Tuple types into multiple arguments that take single inputs." + ) elif origin is dict: # The JSON equivalent to a dict is 'object', which mandates that all keys are strings # However, we can specify the type of the dict values with "additionalProperties" diff --git a/tests/utils/test_chat_template_utils.py b/tests/utils/test_chat_template_utils.py new file mode 100644 index 00000000000000..972ed779ba474d --- /dev/null +++ b/tests/utils/test_chat_template_utils.py @@ -0,0 +1,128 @@ +import unittest +from typing import List, Optional, Union + +from transformers.utils import get_json_schema + + +class JsonSchemaGeneratorTest(unittest.TestCase): + def test_simple_function(self): + def fn(x: int): + """ + Test function + + :param x: The input + """ + return x + + schema = get_json_schema(fn) + expected_schema = {'name': 'fn', 'description': 'Test function', 'parameters': {'type': 'object', 'properties': {'x': {'type': 'integer', 'description': 'The input'}}, 'required': ['x']}} + self.assertEqual(schema, expected_schema) + + def test_union(self): + def fn(x: Union[int, float]): + """ + Test function + + :param x: The input + """ + return x + + schema = get_json_schema(fn) + expected_schema = {'name': 'fn', 'description': 'Test function', 'parameters': {'type': 'object', 'properties': {'x': {'type': ['integer', 'number'], 'description': 'The input'}}, 'required': ['x']}} + self.assertEqual(schema, expected_schema) + + def test_optional(self): + def fn(x: Optional[int]): + """ + Test function + + :param x: The input + """ + return x + + schema = get_json_schema(fn) + expected_schema = {'name': 'fn', 'description': 'Test function', 'parameters': {'type': 'object', 'properties': {'x': {'type': 'integer', 'description': 'The input', "nullable": True}}, 'required': ['x']}} + self.assertEqual(schema, expected_schema) + + def 
test_default_arg(self): + def fn(x: int = 42): + """ + Test function + + :param x: The input + """ + return x + + schema = get_json_schema(fn) + expected_schema = {'name': 'fn', 'description': 'Test function', 'parameters': {'type': 'object', 'properties': {'x': {'type': 'integer', 'description': 'The input'}}}} + self.assertEqual(schema, expected_schema) + + def test_nested_list(self): + def fn(x: List[List[Union[int, str]]]): + """ + Test function + + :param x: The input + """ + return x + + schema = get_json_schema(fn) + expected_schema = {'name': 'fn', 'description': 'Test function', 'parameters': {'type': 'object', 'properties': {'x': {'type': 'array', 'items': {'type': 'array', 'items': {'type': ['integer', 'string']}}, 'description': 'The input'}}, 'required': ['x']}} + self.assertEqual(schema, expected_schema) + + def test_multiple_arguments(self): + def fn(x: int, y: str): + """ + Test function + + :param x: The input + :param y: Also the input + """ + return x + + schema = get_json_schema(fn) + expected_schema = {'name': 'fn', 'description': 'Test function', 'parameters': {'type': 'object', 'properties': {'x': {'type': 'integer', 'description': 'The input'}, 'y': {'type': 'string', 'description': 'Also the input'}}, 'required': ['x', 'y']}} + self.assertEqual(schema, expected_schema) + + def test_multiple_complex_arguments(self): + def fn(x: List[Union[int, float]], y: Optional[Union[int, str]] = None): + """ + Test function + + :param x: The input + :param y: Also the input + """ + return x + + schema = get_json_schema(fn) + expected_schema = {'name': 'fn', 'description': 'Test function', 'parameters': {'type': 'object', 'properties': {'x': {'type': 'array', 'items': {'type': ['integer', 'number']}, 'description': 'The input'}, 'y': {'anyOf': [{'type': 'integer'}, {'type': 'string'}], 'nullable': True, 'description': 'Also the input'}}, 'required': ['x']}} + self.assertEqual(schema, expected_schema) + + def test_missing_docstring(self): + def fn(x: int): 
+ return x + + with self.assertRaises(ValueError): + get_json_schema(fn) + + def test_missing_param_docstring(self): + def fn(x: int): + """ + Test function + """ + return x + + with self.assertRaises(ValueError): + get_json_schema(fn) + + def test_missing_type_hint(self): + def fn(x): + """ + Test function + + :param x: The input + """ + return x + + with self.assertRaises(ValueError): + get_json_schema(fn) From c9fb3de830f5e4c4d2013d25a5928e749aae3e6c Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 8 May 2024 13:54:57 +0100 Subject: [PATCH 10/69] Update tests, proper handling of return values --- src/transformers/utils/chat_template_utils.py | 12 +- tests/utils/test_chat_template_utils.py | 104 ++++++++++++++++-- 2 files changed, 104 insertions(+), 12 deletions(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index cee8a0bcdf8b90..04c0624ad3385c 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -36,9 +36,12 @@ def _convert_type_hints_to_json_schema(func): properties = {} signature = inspect.signature(func) - required = [ - param_name for param_name, param in signature.parameters.items() if param.default == inspect.Parameter.empty - ] + required = [] + for param_name, param in signature.parameters.items(): + if param.annotation == inspect.Parameter.empty: + raise ValueError(f"Argument {param.name} is missing a type hint in function {func.__name__}") + if param.default == inspect.Parameter.empty: + required.append(param_name) for param_name, param_type in type_hints.items(): if param_name == "return": @@ -51,8 +54,7 @@ def _convert_type_hints_to_json_schema(func): return schema -# TODO: Return types!! How are those even handled? Does it even matter? 
I should check what the different APIs do for this -# and also add tests + def _parse_type_hint(hint): if (origin := get_origin(hint)) is not None: if origin is Union: diff --git a/tests/utils/test_chat_template_utils.py b/tests/utils/test_chat_template_utils.py index 972ed779ba474d..56e3f4c267e5dc 100644 --- a/tests/utils/test_chat_template_utils.py +++ b/tests/utils/test_chat_template_utils.py @@ -15,7 +15,15 @@ def fn(x: int): return x schema = get_json_schema(fn) - expected_schema = {'name': 'fn', 'description': 'Test function', 'parameters': {'type': 'object', 'properties': {'x': {'type': 'integer', 'description': 'The input'}}, 'required': ['x']}} + expected_schema = { + "name": "fn", + "description": "Test function", + "parameters": { + "type": "object", + "properties": {"x": {"type": "integer", "description": "The input"}}, + "required": ["x"], + }, + } self.assertEqual(schema, expected_schema) def test_union(self): @@ -28,7 +36,15 @@ def fn(x: Union[int, float]): return x schema = get_json_schema(fn) - expected_schema = {'name': 'fn', 'description': 'Test function', 'parameters': {'type': 'object', 'properties': {'x': {'type': ['integer', 'number'], 'description': 'The input'}}, 'required': ['x']}} + expected_schema = { + "name": "fn", + "description": "Test function", + "parameters": { + "type": "object", + "properties": {"x": {"type": ["integer", "number"], "description": "The input"}}, + "required": ["x"], + }, + } self.assertEqual(schema, expected_schema) def test_optional(self): @@ -41,7 +57,15 @@ def fn(x: Optional[int]): return x schema = get_json_schema(fn) - expected_schema = {'name': 'fn', 'description': 'Test function', 'parameters': {'type': 'object', 'properties': {'x': {'type': 'integer', 'description': 'The input', "nullable": True}}, 'required': ['x']}} + expected_schema = { + "name": "fn", + "description": "Test function", + "parameters": { + "type": "object", + "properties": {"x": {"type": "integer", "description": "The input", 
"nullable": True}}, + "required": ["x"], + }, + } self.assertEqual(schema, expected_schema) def test_default_arg(self): @@ -54,7 +78,11 @@ def fn(x: int = 42): return x schema = get_json_schema(fn) - expected_schema = {'name': 'fn', 'description': 'Test function', 'parameters': {'type': 'object', 'properties': {'x': {'type': 'integer', 'description': 'The input'}}}} + expected_schema = { + "name": "fn", + "description": "Test function", + "parameters": {"type": "object", "properties": {"x": {"type": "integer", "description": "The input"}}}, + } self.assertEqual(schema, expected_schema) def test_nested_list(self): @@ -67,7 +95,21 @@ def fn(x: List[List[Union[int, str]]]): return x schema = get_json_schema(fn) - expected_schema = {'name': 'fn', 'description': 'Test function', 'parameters': {'type': 'object', 'properties': {'x': {'type': 'array', 'items': {'type': 'array', 'items': {'type': ['integer', 'string']}}, 'description': 'The input'}}, 'required': ['x']}} + expected_schema = { + "name": "fn", + "description": "Test function", + "parameters": { + "type": "object", + "properties": { + "x": { + "type": "array", + "items": {"type": "array", "items": {"type": ["integer", "string"]}}, + "description": "The input", + } + }, + "required": ["x"], + }, + } self.assertEqual(schema, expected_schema) def test_multiple_arguments(self): @@ -81,7 +123,18 @@ def fn(x: int, y: str): return x schema = get_json_schema(fn) - expected_schema = {'name': 'fn', 'description': 'Test function', 'parameters': {'type': 'object', 'properties': {'x': {'type': 'integer', 'description': 'The input'}, 'y': {'type': 'string', 'description': 'Also the input'}}, 'required': ['x', 'y']}} + expected_schema = { + "name": "fn", + "description": "Test function", + "parameters": { + "type": "object", + "properties": { + "x": {"type": "integer", "description": "The input"}, + "y": {"type": "string", "description": "Also the input"}, + }, + "required": ["x", "y"], + }, + } self.assertEqual(schema, 
expected_schema) def test_multiple_complex_arguments(self): @@ -95,7 +148,22 @@ def fn(x: List[Union[int, float]], y: Optional[Union[int, str]] = None): return x schema = get_json_schema(fn) - expected_schema = {'name': 'fn', 'description': 'Test function', 'parameters': {'type': 'object', 'properties': {'x': {'type': 'array', 'items': {'type': ['integer', 'number']}, 'description': 'The input'}, 'y': {'anyOf': [{'type': 'integer'}, {'type': 'string'}], 'nullable': True, 'description': 'Also the input'}}, 'required': ['x']}} + expected_schema = { + "name": "fn", + "description": "Test function", + "parameters": { + "type": "object", + "properties": { + "x": {"type": "array", "items": {"type": ["integer", "number"]}, "description": "The input"}, + "y": { + "anyOf": [{"type": "integer"}, {"type": "string"}], + "nullable": True, + "description": "Also the input", + }, + }, + "required": ["x"], + }, + } self.assertEqual(schema, expected_schema) def test_missing_docstring(self): @@ -126,3 +194,25 @@ def fn(x): with self.assertRaises(ValueError): get_json_schema(fn) + + def test_return_value_has_no_effect(self): + # We ignore return values, so we want to make sure they don't affect the schema + def fn(x: int) -> int: + """ + Test function + + :param x: The input + """ + return x + + schema = get_json_schema(fn) + expected_schema = { + "name": "fn", + "description": "Test function", + "parameters": { + "type": "object", + "properties": {"x": {"type": "integer", "description": "The input"}}, + "required": ["x"], + }, + } + self.assertEqual(schema, expected_schema) From 10be3b19da898e2b482037621b9397c5265d5cf5 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 8 May 2024 13:59:07 +0100 Subject: [PATCH 11/69] Small docstring change --- src/transformers/utils/chat_template_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 04c0624ad3385c..3ee179bc4bb604 
100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -97,7 +97,7 @@ def _parse_type_hint(hint): "This helper does not parse Tuple types, as they are usually used to indicate that " "each position is associated with a specific type, and this requires JSON schemas " "that are not supported by most templates. We recommend " - "either using List or List[Union] instead for arguments where this is appropriate, or " + "either using List instead for arguments where this is appropriate, or " "splitting arguments with Tuple types into multiple arguments that take single inputs." ) elif origin is dict: From d9e64540f0afcc0aced5b3a5afbf94c29a22a4b4 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 8 May 2024 15:02:40 +0100 Subject: [PATCH 12/69] More doc updates --- docs/source/en/chat_templating.md | 86 +++++++++++++++++++-- src/transformers/tokenization_utils_base.py | 4 +- tests/utils/test_chat_template_utils.py | 19 +++++ 3 files changed, 100 insertions(+), 9 deletions(-) diff --git a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md index 7a2e2507b9cc55..5a8a0af470e103 100644 --- a/docs/source/en/chat_templating.md +++ b/docs/source/en/chat_templating.md @@ -242,13 +242,12 @@ strings, lists, dicts or whatever else you want. That said, there are some common use-cases for these extra arguments, such as passing tools for function calling, or documents for retrieval-augmented generation. In these common cases, -we have some opinionated recommendations about what the names and formats of these arguments should be. By sticking -to these conventions when writing your template, you make it easier for users to use stabdard tool-use or RAG input -pipelines with your model without needing any manual reformatting. +we have some opinionated recommendations about what the names and formats of these arguments should be. 
### Arguments for tool use -Our recommendation for "tool use" LLMs which can choose to call functions as external tools is that their template +"Tool use" LLMs can choose to call functions as external tools before generating an answer. Our recommendation for +tool use models is that their template should accept a `tools` argument. This should be a list of tools, defined via [JSON Schema](https://json-schema.org/). Each "tool" is a single function that the model can choose to call, and the schema should include the function name, its description and the expected spec for its arguments. @@ -260,7 +259,10 @@ and the expected spec for its arguments. current_time = { "name": "current_time", "description": "Get the current local time as a string.", - "parameters": {}, # TODO - double-check if this is the correct schema for this case + "parameters": { + 'type': 'object', + 'properties': {} + }, } # A more complete function that takes two numerical arguments @@ -289,10 +291,80 @@ models. We recommend trying to keep your tool schemas simple and flat where poss ### Automated function conversion for tool use +# TODO Docstring and doc for get_json_schema(), then link to it + Although JSON schemas are precise, widely-supported and language-agnostic, they can be a bit verbose, which means -that writing them can be annoying. Don't panic, though, we have a solution! +that writing them can be annoying. Don't panic, though, we have a solution! You can simply define Python functions +as tools, and use the `get_json_schema()` function. This function will automatically generate a JSON schema for any +function that has a valid docstring with parameter annotations and valid type hints. Let's see it in action! + +```python +from transformers.utils import get_json_schema + +def multiply(a: float, b: float): + """Multiply two numbers together. + + :param a: The first number to multiply. + :param b: The second number to multiply. 
+ """ + return a * b + +schema = get_json_schema(multiply) +print(schema) +``` + +This will yield: + +```json +{ + "name": "multiply", + "description": "Multiply two numbers together.", + "parameters": { + "type": "object", + "properties": { + "a": { + "type": "number", + "description": "The first number to multiply." + }, + "b": { + "type": "number", + "description": "The second number to multiply." + } + }, + "required": ["a", "b"] + } +} +``` + +TODO Add a JSON schema decorator so this can be even shorter + +We can use this function to greatly simplify tool-calling: + +```python +import datetime + +def current_time(): + """Get the current local time as a string.""" + return str(datetime.now()) + +def multiply(a: float, b: float): + """Multiply two numbers together. + + :param a: The first number to multiply. + :param b: The second number to multiply. + """ + return a * b + +tools = [current_time, multiply] +schemas = [get_json_schema(tool) for tool in tools] + +model_input = tokenizer.apply_chat_template( + messages, + tools=schemas +) +``` + -TODO Explain function conversion with examples ### Arguments for retrieval-augmented generation (RAG) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 8746ecca6ada38..cb173a9061fcba 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1707,8 +1707,8 @@ def apply_chat_template( with "role" and "content" keys, representing the chat history so far. tools (List[Dict], *optional*): A list of tools (callable functions) that will be accessible to the model. If the template does not support function calling, this argument will have no effect. - Each tool should be passed as a JSON Schema[link!], giving the name, description and argument types - for the tool. [Docs and links here, including auto-generation of schemas!] 
+ Each tool should be passed as a JSON Schema, giving the name, description and argument types + for the tool. See our [chat templating guide]( documents (List[Dict[str, str]], *optional*): A list of dicts representing documents that will be accessible to the model if it is performing RAG (retrieval-augmented generation). If the template does not support RAG, this argument will have no effect. We recommend that each document should be a dict containing diff --git a/tests/utils/test_chat_template_utils.py b/tests/utils/test_chat_template_utils.py index 56e3f4c267e5dc..11fa697555d994 100644 --- a/tests/utils/test_chat_template_utils.py +++ b/tests/utils/test_chat_template_utils.py @@ -5,6 +5,7 @@ class JsonSchemaGeneratorTest(unittest.TestCase): + def test_simple_function(self): def fn(x: int): """ @@ -26,6 +27,24 @@ def fn(x: int): } self.assertEqual(schema, expected_schema) + def test_no_arguments(self): + def fn(): + """ + Test function + """ + return True + + schema = get_json_schema(fn) + expected_schema = { + "name": "fn", + "description": "Test function", + "parameters": { + 'type': 'object', + 'properties': {} + }, + } + self.assertEqual(schema, expected_schema) + def test_union(self): def fn(x: Union[int, float]): """ From 80addcf9b567e275fccbc245c616cacdad3c3056 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 8 May 2024 15:38:18 +0100 Subject: [PATCH 13/69] More doc updates --- docs/source/en/chat_templating.md | 11 ++++++----- src/transformers/tokenization_utils_base.py | 12 +++++++----- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md index 5a8a0af470e103..4632a193907ba1 100644 --- a/docs/source/en/chat_templating.md +++ b/docs/source/en/chat_templating.md @@ -364,13 +364,14 @@ model_input = tokenizer.apply_chat_template( ) ``` +### Arguments for RAG - -### Arguments for retrieval-augmented generation (RAG) - -Our recommendation for "RAG" LLMs which can search a corpus of 
documents for information is that their template +"Retrieval-augmented generation" or "RAG" LLMs can search a corpus of documents for information before responding +to a query. This allows models to vastly expand their knowledge base beyond their limited context size. Our +recommendation for RAG models is that their template should accept a `documents` argument. This should be a list of documents, where each "document" -is a single dict with `title` and `contents` keys, both of which are strings. +is a single dict with `title` and `contents` keys, both of which are strings. Because this format is much simpler +than the JSON schemas used for tools, no helper functions are necessary. #### Example diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index cb173a9061fcba..39d2ce481a40f5 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1708,14 +1708,16 @@ def apply_chat_template( tools (List[Dict], *optional*): A list of tools (callable functions) that will be accessible to the model. If the template does not support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema, giving the name, description and argument types - for the tool. See our [chat templating guide]( + for the tool. See our [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use) + for more information. documents (List[Dict[str, str]], *optional*): A list of dicts representing documents that will be accessible to the model if it is performing RAG (retrieval-augmented generation). If the template does not support RAG, this argument will have no effect. We recommend that each document should be a dict containing - "title" and "text" keys. Please see the docs for examples of passing documents with chat - templates [link!!]. 
- chat_template (str, *optional*): A Jinja template to use for this conversion. If - this is not passed, the model's default chat template will be used instead. + "title" and "text" keys. Please see the RAG section of the + [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#arguments-for-RAG) + for examples of passing documents with chat templates. + chat_template (str, *optional*): A Jinja template to use for this conversion. By default, the model's + template will be used. add_generation_prompt (bool, *optional*): Whether to end the prompt with the token(s) that indicate the start of an assistant message. This is useful when you want to generate a response from the model. Note that this argument will be passed to the chat template, and so it must be supported in the From d3d677bb7a672a529c9ec6165f377acef8bd52d8 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 9 May 2024 15:55:15 +0100 Subject: [PATCH 14/69] Add json_schema decorator --- src/transformers/tokenization_utils_base.py | 5 +++++ src/transformers/utils/chat_template_utils.py | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 39d2ce481a40f5..ef37c63239d744 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1815,6 +1815,11 @@ def apply_chat_template( conversations = [conversation] is_batched = False + # The add_json_schema decorator for tools adds a schema under the `json_schema` attribute. 
If we're passed + # decorated functions, let's extract the schema decoration now + if tools is not None: + tools = [tool.json_schema if hasattr(tool, "json_schema") else tool for tool in tools] + rendered = [] template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present for chat in conversations: diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 3ee179bc4bb604..33d7dd6b091bdb 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -24,6 +24,11 @@ def get_json_schema(func): return {"name": func.__name__, "description": main_doc, "parameters": json_schema} +def add_json_schema(func): + func.json_schema = get_json_schema(func) + return func + + def _get_argument_descriptions_from_docstring(doc): param_pattern = r":param (\w+): (.+)" params = re.findall(param_pattern, doc) From a7d241c60e3240a752e60f78ad50823e254b6807 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 9 May 2024 16:52:29 +0100 Subject: [PATCH 15/69] Clean up the TODOs and finish the docs --- docs/source/en/chat_templating.md | 15 ++- src/transformers/utils/chat_template_utils.py | 105 ++++++++++++++++++ tests/utils/test_chat_template_utils.py | 6 +- 3 files changed, 113 insertions(+), 13 deletions(-) diff --git a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md index 4632a193907ba1..65709cfc1fb24c 100644 --- a/docs/source/en/chat_templating.md +++ b/docs/source/en/chat_templating.md @@ -291,11 +291,9 @@ models. We recommend trying to keep your tool schemas simple and flat where poss ### Automated function conversion for tool use -# TODO Docstring and doc for get_json_schema(), then link to it - Although JSON schemas are precise, widely-supported and language-agnostic, they can be a bit verbose, which means that writing them can be annoying. Don't panic, though, we have a solution! 
You can simply define Python functions -as tools, and use the `get_json_schema()` function. This function will automatically generate a JSON schema for any +as tools, and use the [`get_json_schema`] function. This function will automatically generate a JSON schema for any function that has a valid docstring with parameter annotations and valid type hints. Let's see it in action! ```python @@ -336,17 +334,19 @@ This will yield: } ``` -TODO Add a JSON schema decorator so this can be even shorter - -We can use this function to greatly simplify tool-calling: +We can use this function, or the equivalent [`add_json_schema`] decorator, to avoid the need to manually write JSON +schemas when passing tools to the chat template: ```python import datetime +from transformers.utils import add_json_schema +@add_json_schema def current_time(): """Get the current local time as a string.""" return str(datetime.now()) +@add_json_schema def multiply(a: float, b: float): """Multiply two numbers together. @@ -356,11 +356,10 @@ def multiply(a: float, b: float): return a * b tools = [current_time, multiply] -schemas = [get_json_schema(tool) for tool in tools] model_input = tokenizer.apply_chat_template( messages, - tools=schemas + tools=tools ) ``` diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 33d7dd6b091bdb..80d4a04db908db 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -7,6 +7,78 @@ def get_json_schema(func): + """ + This function generates a JSON schema for a given function, based on its docstring and type hints. This is + mostly used for passing lists of tools to a chat template. The JSON schema contains the name and description of + the function, as well as the names, types and descriptions for each of its arguments. 
`get_json_schema()` requires + that the function has a docstring, and that each argument has a description in the docstring, in the format + `:param arg_name: arg_description`. It also requires that all the function arguments have a valid Python type hint. + + Args: + func: The function to generate a JSON schema for. + + Returns: + A dictionary containing the JSON schema for the function. + + Examples: + ```python + >>> def multiply(x: float, y: float): + >>> ''' + >>> A function that multiplies two numbers + >>> + >>> :param x: The first number to multiply + >>> :param y: The second number to multiply + >>> ''' + >>> return x * y + >>> + >>> print(get_json_schema(multiply)) + { + "name": "multiply", + "description": "A function that multiplies two numbers", + "parameters": { + "type": "object", + "properties": { + "x": {"type": "number", "description": "The first number to multiply"}, + "y": {"type": "number", "description": "The second number to multiply"} + }, + "required": ["x", "y"] + } + } + ``` + + The general use for these schemas is that they are used to generate tool descriptions for chat templates that + support them, like so: + + ```python + >>> from transformers import AutoTokenizer + >>> from transformers.utils import get_json_schema + >>> + >>> def multiply(x: float, y: float): + >>> ''' + >>> A function that multiplies two numbers + >>> + >>> :param x: The first number to multiply + >>> :param y: The second number to multiply + >>> ''' + >>> return x * y + >>> + >>> multiply_schema = get_json_schema(multiply) + >>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01") + >>> messages = [{"role": "user", "content": "What is 179 x 4571?"}] + >>> formatted_chat = tokenizer.apply_chat_template( + >>> messages, + >>> tools=[multiply_schema], + >>> chat_template="tool_use", + >>> return_dict=True, + >>> return_tensors="pt", + >>> add_generation_prompt=True + >>> ) + >>> # The formatted chat can now be passed to model.generate() + ``` 
+ + In many cases, it is more convenient to define tool functions with the [`add_json_schema`] decorator rather than + calling this function directly. + """ doc = inspect.getdoc(func) if not doc: raise ValueError(f"Cannot generate JSON schema for {func.__name__} because it has no docstring!") @@ -25,6 +97,39 @@ def get_json_schema(func): def add_json_schema(func): + """ + This decorator adds a JSON schema to a function, based on its docstring and type hints. The JSON schema is the + same as the one generated by the [`get_json_schema`] function. It is stored in the `json_schema` attribute of the + function, which will be automatically read by `apply_chat_template()` if present. + + Example: + + ```python + >>> from transformers import AutoTokenizer + >>> from transformers.utils import get_json_schema + >>> + >>> @add_json_schema + >>> def multiply(x: float, y: float): + >>> ''' + >>> A function that multiplies two numbers + >>> + >>> :param x: The first number to multiply + >>> :param y: The second number to multiply + >>> ''' + >>> return x * y + >>> + >>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01") + >>> messages = [{"role": "user", "content": "What is 179 x 4571?"}] + >>> formatted_chat = tokenizer.apply_chat_template( + >>> messages, + >>> tools=[multiply], + >>> chat_template="tool_use", + >>> return_dict=True, + >>> return_tensors="pt", + >>> add_generation_prompt=True + >>> ) + >>> # The formatted chat can now be passed to model.generate() + """ func.json_schema = get_json_schema(func) return func diff --git a/tests/utils/test_chat_template_utils.py b/tests/utils/test_chat_template_utils.py index 11fa697555d994..9aab80ed377d6a 100644 --- a/tests/utils/test_chat_template_utils.py +++ b/tests/utils/test_chat_template_utils.py @@ -5,7 +5,6 @@ class JsonSchemaGeneratorTest(unittest.TestCase): - def test_simple_function(self): def fn(x: int): """ @@ -38,10 +37,7 @@ def fn(): expected_schema = { "name": "fn", "description": "Test 
function", - "parameters": { - 'type': 'object', - 'properties': {} - }, + "parameters": {"type": "object", "properties": {}}, } self.assertEqual(schema, expected_schema) From fad9ae207d2718b3b64398264ca46acc8cf6c86d Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 9 May 2024 18:17:42 +0100 Subject: [PATCH 16/69] self.maxDiff = None to see the whole diff for the nested list test --- tests/utils/test_chat_template_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/utils/test_chat_template_utils.py b/tests/utils/test_chat_template_utils.py index 9aab80ed377d6a..578c89f0e166c4 100644 --- a/tests/utils/test_chat_template_utils.py +++ b/tests/utils/test_chat_template_utils.py @@ -109,6 +109,8 @@ def fn(x: List[List[Union[int, str]]]): """ return x + self.maxDiff = None + schema = get_json_schema(fn) expected_schema = { "name": "fn", From d202bfe0f2d51528bafd9192ba487f0cdbae77c0 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 9 May 2024 18:18:47 +0100 Subject: [PATCH 17/69] add import for add_json_schema --- src/transformers/utils/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index ac2ecaef3a7fd6..cf47a2db544eb3 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -21,7 +21,7 @@ from .. 
import __version__ from .backbone_utils import BackboneConfigMixin, BackboneMixin -from .chat_template_utils import get_json_schema +from .chat_template_utils import add_json_schema, get_json_schema from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD from .doc import ( add_code_sample_docstrings, From 11cf8f57ea65e6dc04332445d2d20023fab0d1b8 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 9 May 2024 18:57:32 +0100 Subject: [PATCH 18/69] Quick test fix --- tests/utils/test_chat_template_utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/utils/test_chat_template_utils.py b/tests/utils/test_chat_template_utils.py index 578c89f0e166c4..c1e0489350fe40 100644 --- a/tests/utils/test_chat_template_utils.py +++ b/tests/utils/test_chat_template_utils.py @@ -101,7 +101,7 @@ def fn(x: int = 42): self.assertEqual(schema, expected_schema) def test_nested_list(self): - def fn(x: List[List[Union[int, str]]]): + def fn(x: List[List[Union[str, int]]]): """ Test function @@ -109,8 +109,6 @@ def fn(x: List[List[Union[int, str]]]): """ return x - self.maxDiff = None - schema = get_json_schema(fn) expected_schema = { "name": "fn", @@ -120,7 +118,7 @@ def fn(x: List[List[Union[int, str]]]): "properties": { "x": { "type": "array", - "items": {"type": "array", "items": {"type": ["integer", "string"]}}, + "items": {"type": "array", "items": {"type": ["string", "integer"]}}, "description": "The input", } }, From 8962c421cc57568beb83879ff50cf502c5eb3243 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 9 May 2024 19:00:06 +0100 Subject: [PATCH 19/69] Fix something that was bugging me in the chat template docstring --- src/transformers/tokenization_utils_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index ef37c63239d744..c7d1e20c33fc12 100644 --- 
a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1716,8 +1716,8 @@ def apply_chat_template( "title" and "text" keys. Please see the RAG section of the [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#arguments-for-RAG) for examples of passing documents with chat templates. - chat_template (str, *optional*): A Jinja template to use for this conversion. By default, the model's - template will be used. + chat_template (str, *optional*): A Jinja template to use for this conversion. It is usually not necessary + to pass anything to this argument, as the model's template will be used by default. add_generation_prompt (bool, *optional*): Whether to end the prompt with the token(s) that indicate the start of an assistant message. This is useful when you want to generate a response from the model. Note that this argument will be passed to the chat template, and so it must be supported in the From 6f4a897e5a1e994e16c29ff9c446279c8d74d131 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 10 May 2024 13:24:23 +0100 Subject: [PATCH 20/69] Less "anyOf" when unnecessary --- src/transformers/utils/chat_template_utils.py | 2 +- tests/utils/test_chat_template_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 80d4a04db908db..841b0ff9e91c79 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -3,7 +3,7 @@ from typing import Any, Union, get_args, get_origin, get_type_hints -BASIC_TYPES = (int, float, str, bool, Any) +BASIC_TYPES = (int, float, str, bool, Any, type(None), ...) 
def get_json_schema(func): diff --git a/tests/utils/test_chat_template_utils.py b/tests/utils/test_chat_template_utils.py index c1e0489350fe40..95c7825f2acb06 100644 --- a/tests/utils/test_chat_template_utils.py +++ b/tests/utils/test_chat_template_utils.py @@ -171,7 +171,7 @@ def fn(x: List[Union[int, float]], y: Optional[Union[int, str]] = None): "properties": { "x": {"type": "array", "items": {"type": ["integer", "number"]}, "description": "The input"}, "y": { - "anyOf": [{"type": "integer"}, {"type": "string"}], + "type": ["integer", "string"], "nullable": True, "description": "Also the input", }, From 6462de215032d30eedb15799b54e5d9946a80a37 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 16 May 2024 18:55:41 +0100 Subject: [PATCH 21/69] Support return types for the templates that need them --- src/transformers/utils/chat_template_utils.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 841b0ff9e91c79..3b015e5bd4a181 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -83,11 +83,15 @@ def get_json_schema(func): if not doc: raise ValueError(f"Cannot generate JSON schema for {func.__name__} because it has no docstring!") doc = doc.strip() - main_doc, param_descriptions = _get_argument_descriptions_from_docstring(doc) + main_doc, param_descriptions, return_doc = _get_argument_descriptions_from_docstring(doc) json_schema = _convert_type_hints_to_json_schema(func) for arg in json_schema["properties"]: - if arg not in param_descriptions: + if arg == "return": + if return_doc is not None: # We allow a missing return docstring since most templates ignore it + json_schema["properties"][arg]["description"] = return_doc + continue + elif arg not in param_descriptions: raise ValueError( f"Cannot generate JSON schema for {func.__name__} because the docstring has no description for 
the argument '{arg}'" ) @@ -137,8 +141,10 @@ def add_json_schema(func): def _get_argument_descriptions_from_docstring(doc): param_pattern = r":param (\w+): (.+)" params = re.findall(param_pattern, doc) + return_pattern = r":returns?: (.+)" + return_doc = re.search(return_pattern, doc) main_doc = doc.split(":param")[0].strip() - return main_doc, dict(params) + return main_doc, dict(params), return_doc.group(1) if return_doc else None def _convert_type_hints_to_json_schema(func): @@ -154,8 +160,6 @@ def _convert_type_hints_to_json_schema(func): required.append(param_name) for param_name, param_type in type_hints.items(): - if param_name == "return": - continue properties[param_name] = _parse_type_hint(param_type) schema = {"type": "object", "properties": properties} From a49b68ea2f4e34e1bab545db66ec124ed07267d6 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 16 May 2024 19:02:50 +0100 Subject: [PATCH 22/69] Proper return type tests --- tests/utils/test_chat_template_utils.py | 30 ++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/tests/utils/test_chat_template_utils.py b/tests/utils/test_chat_template_utils.py index 95c7825f2acb06..2dcf8aec49456b 100644 --- a/tests/utils/test_chat_template_utils.py +++ b/tests/utils/test_chat_template_utils.py @@ -210,8 +210,7 @@ def fn(x): with self.assertRaises(ValueError): get_json_schema(fn) - def test_return_value_has_no_effect(self): - # We ignore return values, so we want to make sure they don't affect the schema + def test_return_value(self): def fn(x: int) -> int: """ Test function @@ -226,7 +225,32 @@ def fn(x: int) -> int: "description": "Test function", "parameters": { "type": "object", - "properties": {"x": {"type": "integer", "description": "The input"}}, + "properties": {"x": {"type": "integer", "description": "The input"}, "return": {"type": "integer"}}, + "required": ["x"], + }, + } + self.assertEqual(schema, expected_schema) + + def test_return_value_docstring(self): + def fn(x: 
int) -> int: + """ + Test function + + :param x: The input + :returns: The output + """ + return x + + schema = get_json_schema(fn) + expected_schema = { + "name": "fn", + "description": "Test function", + "parameters": { + "type": "object", + "properties": { + "x": {"type": "integer", "description": "The input"}, + "return": {"type": "integer", "description": "The output"}, + }, "required": ["x"], }, } From c8a021e159c77957bec9b010e9b7d184d29c67ff Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 17 May 2024 17:16:21 +0100 Subject: [PATCH 23/69] Switch to Google format docstrings --- src/transformers/utils/chat_template_utils.py | 56 ++++++++++++++----- tests/utils/test_chat_template_utils.py | 39 ++++++++----- 2 files changed, 69 insertions(+), 26 deletions(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 3b015e5bd4a181..917fd632c35f7d 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -26,8 +26,9 @@ def get_json_schema(func): >>> ''' >>> A function that multiplies two numbers >>> - >>> :param x: The first number to multiply - >>> :param y: The second number to multiply + >>> Args: + >>> x: The first number to multiply + >>> y: The second number to multiply >>> ''' >>> return x * y >>> @@ -57,10 +58,11 @@ def get_json_schema(func): >>> ''' >>> A function that multiplies two numbers >>> - >>> :param x: The first number to multiply - >>> :param y: The second number to multiply - >>> ''' + >>> Args: + >>> x: The first number to multiply + >>> y: The second number to multiply >>> return x * y + >>> ''' >>> >>> multiply_schema = get_json_schema(multiply) >>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01") @@ -83,7 +85,7 @@ def get_json_schema(func): if not doc: raise ValueError(f"Cannot generate JSON schema for {func.__name__} because it has no docstring!") doc = doc.strip() - main_doc, param_descriptions, 
return_doc = _get_argument_descriptions_from_docstring(doc) + main_doc, param_descriptions, return_doc = parse_google_format_docstring(doc) json_schema = _convert_type_hints_to_json_schema(func) for arg in json_schema["properties"]: @@ -138,13 +140,41 @@ def add_json_schema(func): return func -def _get_argument_descriptions_from_docstring(doc): - param_pattern = r":param (\w+): (.+)" - params = re.findall(param_pattern, doc) - return_pattern = r":returns?: (.+)" - return_doc = re.search(return_pattern, doc) - main_doc = doc.split(":param")[0].strip() - return main_doc, dict(params), return_doc.group(1) if return_doc else None +def parse_google_format_docstring(docstring): + """ + Parses a Google-style docstring to extract the function description, + argument descriptions, and return description. + + Args: + docstring (str): The docstring to parse. + + Returns: + dict: A dictionary containing the function description, arguments, and return description. + """ + # Regular expressions to match the sections + description_re = re.compile(r"^(.*?)[\n\s]*(Args:|Returns:|Raises:|\Z)", re.DOTALL) + args_re = re.compile(r"\n\s*Args:\n\s*(.*?)[\n\s]*(Returns:|Raises:|\Z)", re.DOTALL) + returns_re = re.compile(r"\n\s*Returns:\n\s*(.*?)[\n\s]*(Raises:|\Z)", re.DOTALL) + + # Extract the sections + description_match = description_re.search(docstring) + args_match = args_re.search(docstring) + returns_match = returns_re.search(docstring) + + # Clean and store the sections + description = description_match.group(1).strip() if description_match else None + args = args_match.group(1).strip() if args_match else None + returns = returns_match.group(1).strip() if returns_match else None + + # Parsing the arguments into a dictionary + args_dict = {} + if args is not None: + arg_lines = args.split("\n") + for line in arg_lines: + arg_name, arg_desc = line.split(":", 1) + args_dict[arg_name.strip()] = arg_desc.strip() + + return description, args_dict, returns def 
_convert_type_hints_to_json_schema(func): diff --git a/tests/utils/test_chat_template_utils.py b/tests/utils/test_chat_template_utils.py index 2dcf8aec49456b..b48385ae13b174 100644 --- a/tests/utils/test_chat_template_utils.py +++ b/tests/utils/test_chat_template_utils.py @@ -10,7 +10,8 @@ def fn(x: int): """ Test function - :param x: The input + Args: + x: The input """ return x @@ -46,7 +47,8 @@ def fn(x: Union[int, float]): """ Test function - :param x: The input + Args: + x: The input """ return x @@ -67,7 +69,8 @@ def fn(x: Optional[int]): """ Test function - :param x: The input + Args: + x: The input """ return x @@ -88,7 +91,8 @@ def fn(x: int = 42): """ Test function - :param x: The input + Args: + x: The input """ return x @@ -105,7 +109,8 @@ def fn(x: List[List[Union[str, int]]]): """ Test function - :param x: The input + Args: + x: The input """ return x @@ -132,8 +137,9 @@ def fn(x: int, y: str): """ Test function - :param x: The input - :param y: Also the input + Args: + x: The input + y: Also the input """ return x @@ -157,8 +163,9 @@ def fn(x: List[Union[int, float]], y: Optional[Union[int, str]] = None): """ Test function - :param x: The input - :param y: Also the input + Args: + x: The input + y: Also the input """ return x @@ -203,7 +210,8 @@ def fn(x): """ Test function - :param x: The input + Args: + x: The input """ return x @@ -215,7 +223,8 @@ def fn(x: int) -> int: """ Test function - :param x: The input + Args: + x: The input """ return x @@ -236,8 +245,12 @@ def fn(x: int) -> int: """ Test function - :param x: The input - :returns: The output + Args: + x: The input + + + Returns: + The output """ return x From 69b6d31b9620cbd26b74a43fe273708a36ab41f7 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 17 May 2024 17:32:16 +0100 Subject: [PATCH 24/69] Update chat templating docs to match new format --- docs/source/en/chat_templating.md | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git 
a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md index 65709cfc1fb24c..bb8ec1c450d3aa 100644 --- a/docs/source/en/chat_templating.md +++ b/docs/source/en/chat_templating.md @@ -300,10 +300,12 @@ function that has a valid docstring with parameter annotations and valid type hi from transformers.utils import get_json_schema def multiply(a: float, b: float): - """Multiply two numbers together. + """ + A function that multiplies two numbers - :param a: The first number to multiply. - :param b: The second number to multiply. + Args: + a: The first number to multiply + b: The second number to multiply """ return a * b @@ -348,10 +350,12 @@ def current_time(): @add_json_schema def multiply(a: float, b: float): - """Multiply two numbers together. + """ + A function that multiplies two numbers - :param a: The first number to multiply. - :param b: The second number to multiply. + Args: + a: The first number to multiply + b: The second number to multiply """ return a * b @@ -363,6 +367,16 @@ model_input = tokenizer.apply_chat_template( ) ``` +#### Notes on automatic conversion + +`get_json_schema` and `add_json_schema` both expect a specific docstring format. The docstring should +begin with a description of the function, followed by an `Args:` block that describes each argument. It can also +optionally include a `Returns:` block that describes the value(s) returned by the function. Many templates ignore this, +because the model will see the return format after calling the function anyway, but some require it. + +Argument descriptions in the docstring should not include the argument types - these are read from the type hints +in the function signature instead. 
+ ### Arguments for RAG "Retrieval-augmented generation" or "RAG" LLMs can search a corpus of documents for information before responding From 7f20d44573f9593afdae90f9df2f7de6f0f70aa9 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 21 May 2024 15:45:54 +0100 Subject: [PATCH 25/69] Stop putting the return type in with the other parameters --- src/transformers/utils/chat_template_utils.py | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 917fd632c35f7d..af217eb844e5f5 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -88,18 +88,20 @@ def get_json_schema(func): main_doc, param_descriptions, return_doc = parse_google_format_docstring(doc) json_schema = _convert_type_hints_to_json_schema(func) + if (return_dict := json_schema["properties"].pop("return", None)) is not None: + if return_doc is not None: # We allow a missing return docstring since most templates ignore it + return_dict["description"] = return_doc for arg in json_schema["properties"]: - if arg == "return": - if return_doc is not None: # We allow a missing return docstring since most templates ignore it - json_schema["properties"][arg]["description"] = return_doc - continue - elif arg not in param_descriptions: + if arg not in param_descriptions: raise ValueError( f"Cannot generate JSON schema for {func.__name__} because the docstring has no description for the argument '{arg}'" ) json_schema["properties"][arg]["description"] = param_descriptions[arg] - return {"name": func.__name__, "description": main_doc, "parameters": json_schema} + output = {"name": func.__name__, "description": main_doc, "parameters": json_schema} + if return_dict is not None: + output["return"] = return_dict + return output def add_json_schema(func): @@ -247,10 +249,10 @@ def _parse_type_hint(hint): elif origin is dict: # The JSON 
equivalent to a dict is 'object', which mandates that all keys are strings # However, we can specify the type of the dict values with "additionalProperties" - return { - "type": "object", - "additionalProperties": _parse_type_hint(get_args(hint)[1]), - } + out = {"type": "object"} + if len(get_args(hint)) == 2: + out["additionalProperties"] = _parse_type_hint(get_args(hint)[1]) + return out else: raise ValueError("Couldn't parse this type hint, likely due to a custom class or object: ", hint) else: From c3cf8723e80ed3d31638172c1121d1325be6d6ab Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 22 May 2024 16:39:39 +0100 Subject: [PATCH 26/69] Add Tuple support --- src/transformers/utils/chat_template_utils.py | 25 +++++-- tests/utils/test_chat_template_utils.py | 75 ++++++++++++++++++- 2 files changed, 89 insertions(+), 11 deletions(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index af217eb844e5f5..8d9fad6ef855a1 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -239,13 +239,24 @@ def _parse_type_hint(hint): return_dict["nullable"] = True return return_dict elif origin is tuple: - raise ValueError( - "This helper does not parse Tuple types, as they are usually used to indicate that " - "each position is associated with a specific type, and this requires JSON schemas " - "that are not supported by most templates. We recommend " - "either using List instead for arguments where this is appropriate, or " - "splitting arguments with Tuple types into multiple arguments that take single inputs." - ) + if not get_args(hint): + return {"type": "array"} + if len(get_args(hint)) == 1: + raise ValueError( + "Tuple type hints should only be used when the argument has a fixed length and each " + f"element has a specific type. The hint {hint} indicates a Tuple of length 1. 
" + "This should be replaced with an unwrapped type hint instead like " + f"{get_args(hint)[0]}. Alternatively, if the " + "function can actually take a tuple with multiple elements, please either indicate " + f"each element type (e.g. Tuple[{get_args(hint)[0]}, {get_args(hint)[0]}]), " + f"or if the input can be variable length, use List[{get_args(hint)[0]}] instead." + ) + if ... in get_args(hint): + raise ValueError( + "'...' is not supported in Tuple type hints. Use List[] types for variable-length" + " inputs instead." + ) + return {"type": "array", "prefixItems": [_parse_type_hint(t) for t in get_args(hint)]} elif origin is dict: # The JSON equivalent to a dict is 'object', which mandates that all keys are strings # However, we can specify the type of the dict values with "additionalProperties" diff --git a/tests/utils/test_chat_template_utils.py b/tests/utils/test_chat_template_utils.py index b48385ae13b174..65ad7b2a292254 100644 --- a/tests/utils/test_chat_template_utils.py +++ b/tests/utils/test_chat_template_utils.py @@ -1,5 +1,5 @@ import unittest -from typing import List, Optional, Union +from typing import List, Optional, Tuple, Union from transformers.utils import get_json_schema @@ -234,9 +234,10 @@ def fn(x: int) -> int: "description": "Test function", "parameters": { "type": "object", - "properties": {"x": {"type": "integer", "description": "The input"}, "return": {"type": "integer"}}, + "properties": {"x": {"type": "integer", "description": "The input"}}, "required": ["x"], }, + "return": {"type": "integer"}, } self.assertEqual(schema, expected_schema) @@ -254,6 +255,33 @@ def fn(x: int) -> int: """ return x + schema = get_json_schema(fn) + expected_schema = { + "name": "fn", + "description": "Test function", + "parameters": { + "type": "object", + "properties": {"x": {"type": "integer", "description": "The input"}}, + "required": ["x"], + }, + "return": {"type": "integer", "description": "The output"}, + } + self.assertEqual(schema, 
expected_schema) + + def test_tuple(self): + def fn(x: Tuple[int, str]): + """ + Test function + + Args: + x: The input + + + Returns: + The output + """ + return x + schema = get_json_schema(fn) expected_schema = { "name": "fn", @@ -261,10 +289,49 @@ def fn(x: int) -> int: "parameters": { "type": "object", "properties": { - "x": {"type": "integer", "description": "The input"}, - "return": {"type": "integer", "description": "The output"}, + "x": { + "type": "array", + "prefixItems": [{"type": "integer"}, {"type": "string"}], + "description": "The input", + } }, "required": ["x"], }, } self.assertEqual(schema, expected_schema) + + def test_single_element_tuple_fails(self): + def fn(x: Tuple[int]): + """ + Test function + + Args: + x: The input + + + Returns: + The output + """ + return x + + # Single-element tuples should just be the type itself, or List[type] for variable-length inputs + with self.assertRaises(ValueError): + get_json_schema(fn) + + def test_ellipsis_type_fails(self): + def fn(x: Tuple[int, ...]): + """ + Test function + + Args: + x: The input + + + Returns: + The output + """ + return x + + # Variable length inputs should be specified with List[type], not Tuple[type, ...] + with self.assertRaises(ValueError): + get_json_schema(fn) From 098780da374fe5fc95e238ee87f6e61b79c34de9 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 23 May 2024 14:57:13 +0100 Subject: [PATCH 27/69] No more decorator - we just do it implicitly! 
--- docs/source/en/chat_templating.md | 9 ++-- src/transformers/tokenization_utils_base.py | 21 ++++++++-- src/transformers/utils/__init__.py | 2 +- src/transformers/utils/chat_template_utils.py | 42 +------------------ 4 files changed, 23 insertions(+), 51 deletions(-) diff --git a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md index bb8ec1c450d3aa..169281d115b98c 100644 --- a/docs/source/en/chat_templating.md +++ b/docs/source/en/chat_templating.md @@ -336,19 +336,16 @@ This will yield: } ``` -We can use this function, or the equivalent [`add_json_schema`] decorator, to avoid the need to manually write JSON -schemas when passing tools to the chat template: +We can use this function to avoid the need to manually write JSON schemas when passing tools to the chat template. +In addition, if you pass functions in the `tools` argument, they will automatically be converted with this function: ```python import datetime -from transformers.utils import add_json_schema -@add_json_schema def current_time(): """Get the current local time as a string.""" return str(datetime.now()) -@add_json_schema def multiply(a: float, b: float): """ A function that multiplies two numbers @@ -369,7 +366,7 @@ model_input = tokenizer.apply_chat_template( #### Notes on automatic conversion -`get_json_schema` and `add_json_schema` both expect a specific docstring format. The docstring should +`get_json_schema` expects a specific docstring format. The docstring should begin with a description of the function, followed by an `Args:` block that describes each argument. It can also optionally include a `Returns:` block that describes the value(s) returned by the function. Many templates ignore this, because the model will see the return format after calling the function anyway, but some require it. 
diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index c7d1e20c33fc12..7c6a778fc68d29 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -28,6 +28,7 @@ from contextlib import contextmanager from dataclasses import dataclass from functools import lru_cache +from inspect import isfunction from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union import numpy as np @@ -47,6 +48,7 @@ copy_func, download_url, extract_commit_hash, + get_json_schema, is_flax_available, is_jax_tensor, is_mlx_available, @@ -1815,10 +1817,21 @@ def apply_chat_template( conversations = [conversation] is_batched = False - # The add_json_schema decorator for tools adds a schema under the `json_schema` attribute. If we're passed - # decorated functions, let's extract the schema decoration now + # We accept either JSON schemas or functions for tools. If we get functions, we convert them to schemas if tools is not None: - tools = [tool.json_schema if hasattr(tool, "json_schema") else tool for tool in tools] + tool_schemas = [] + for tool in tools: + if isinstance(tool, dict): + tool_schemas.append(tool) + elif isfunction(tool): + tool_schemas.append(get_json_schema(tool)) + else: + raise ValueError( + "Tools should either be a JSON schema, or a callable function with type hints " + "and a docstring suitable for auto-conversion to a schema." 
+ ) + else: + tool_schemas = None rendered = [] template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present @@ -1828,7 +1841,7 @@ def apply_chat_template( chat = chat.messages rendered_chat = compiled_template.render( messages=chat, - tools=tools, + tools=tool_schemas, documents=documents, add_generation_prompt=add_generation_prompt, **template_kwargs, diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index cf47a2db544eb3..ac2ecaef3a7fd6 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -21,7 +21,7 @@ from .. import __version__ from .backbone_utils import BackboneConfigMixin, BackboneMixin -from .chat_template_utils import add_json_schema, get_json_schema +from .chat_template_utils import get_json_schema from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD from .doc import ( add_code_sample_docstrings, diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 8d9fad6ef855a1..0cb74dc973c5fe 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -78,8 +78,8 @@ def get_json_schema(func): >>> # The formatted chat can now be passed to model.generate() ``` - In many cases, it is more convenient to define tool functions with the [`add_json_schema`] decorator rather than - calling this function directly. + In many cases, it is more convenient to simply pass the functions directly to apply_chat_template and let it + autogenerate schemas than calling this function directly. """ doc = inspect.getdoc(func) if not doc: @@ -104,44 +104,6 @@ def get_json_schema(func): return output -def add_json_schema(func): - """ - This decorator adds a JSON schema to a function, based on its docstring and type hints. 
The JSON schema is the - same as the one generated by the [`get_json_schema`] function. It is stored in the `json_schema` attribute of the - function, which will be automatically read by `apply_chat_template()` if present. - - Example: - - ```python - >>> from transformers import AutoTokenizer - >>> from transformers.utils import get_json_schema - >>> - >>> @add_json_schema - >>> def multiply(x: float, y: float): - >>> ''' - >>> A function that multiplies two numbers - >>> - >>> :param x: The first number to multiply - >>> :param y: The second number to multiply - >>> ''' - >>> return x * y - >>> - >>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01") - >>> messages = [{"role": "user", "content": "What is 179 x 4571?"}] - >>> formatted_chat = tokenizer.apply_chat_template( - >>> messages, - >>> tools=[multiply], - >>> chat_template="tool_use", - >>> return_dict=True, - >>> return_tensors="pt", - >>> add_generation_prompt=True - >>> ) - >>> # The formatted chat can now be passed to model.generate() - """ - func.json_schema = get_json_schema(func) - return func - - def parse_google_format_docstring(docstring): """ Parses a Google-style docstring to extract the function description, From 0ad7fdeee2dbb6d36f5628c90d4cec59f1d5cc12 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 23 May 2024 18:06:56 +0100 Subject: [PATCH 28/69] Add enum support to get_json_schema --- src/transformers/utils/chat_template_utils.py | 8 ++++- tests/utils/test_chat_template_utils.py | 34 +++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 0cb74dc973c5fe..5f045332baacbb 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -1,4 +1,5 @@ import inspect +import json import re from typing import Any, Union, get_args, get_origin, get_type_hints @@ -96,7 +97,12 @@ def 
get_json_schema(func): raise ValueError( f"Cannot generate JSON schema for {func.__name__} because the docstring has no description for the argument '{arg}'" ) - json_schema["properties"][arg]["description"] = param_descriptions[arg] + desc = param_descriptions[arg] + enum_choices = re.search(r"\(choices:\s*([^)]+)\)\s*$", desc, flags=re.IGNORECASE) + if enum_choices: + json_schema["properties"][arg]["enum"] = [c.strip() for c in json.loads(enum_choices.group(1))] + desc = enum_choices.string[: enum_choices.start()].strip() + json_schema["properties"][arg]["description"] = desc output = {"name": func.__name__, "description": main_doc, "parameters": json_schema} if return_dict is not None: diff --git a/tests/utils/test_chat_template_utils.py b/tests/utils/test_chat_template_utils.py index 65ad7b2a292254..c48d373d75c346 100644 --- a/tests/utils/test_chat_template_utils.py +++ b/tests/utils/test_chat_template_utils.py @@ -335,3 +335,37 @@ def fn(x: Tuple[int, ...]): # Variable length inputs should be specified with List[type], not Tuple[type, ...] 
with self.assertRaises(ValueError): get_json_schema(fn) + + def test_enum_extraction(self): + def fn(temperature_format: str): + """ + Test function + + Args: + temperature_format: The temperature format to use (Choices: ["celsius", "fahrenheit"]) + + + Returns: + The temperature + """ + return -40.0 + + # Let's see if that gets correctly parsed as an enum + schema = get_json_schema(fn) + expected_schema = { + "name": "fn", + "description": "Test function", + "parameters": { + "type": "object", + "properties": { + "temperature_format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature format to use", + } + }, + "required": ["temperature_format"], + }, + } + + self.assertEqual(schema, expected_schema) From 90a3c5bd6cbe290dbe5d298fe275d29e1680ff8a Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 23 May 2024 18:09:56 +0100 Subject: [PATCH 29/69] Update docstring --- src/transformers/utils/chat_template_utils.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 5f045332baacbb..11642bf86b25e9 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -12,8 +12,14 @@ def get_json_schema(func): This function generates a JSON schema for a given function, based on its docstring and type hints. This is mostly used for passing lists of tools to a chat template. The JSON schema contains the name and description of the function, as well as the names, types and descriptions for each of its arguments. `get_json_schema()` requires - that the function has a docstring, and that each argument has a description in the docstring, in the format - `:param arg_name: arg_description`. It also requires that all the function arguments have a valid Python type hint. 
+ that the function has a docstring, and that each argument has a description in the docstring, in the standard + Google docstring format shown below. It also requires that all the function arguments have a valid Python type hint. + + Although it is not required, a `Returns` block can also be added, which will be included in the schema. This is + optional because most chat templates ignore the return value of the function. Each argument description + can also have an optional `(choices: ...)` block at the end, such as `(choices: ["tea", "coffee"])`, which will be + parsed into an `enum` field in the schema. Note that this will only be parsed correctly if it is at the end of the + line. Args: func: The function to generate a JSON schema for. From ab2e741d97a29bbb819fc0876123dd94e730cf66 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 23 May 2024 19:42:21 +0100 Subject: [PATCH 30/69] Add copyright header --- src/transformers/utils/chat_template_utils.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 11642bf86b25e9..f9fd62d95e2a6f 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -1,3 +1,17 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import inspect import json import re From b2563c218252aba2691b872047e148cd06532b78 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 23 May 2024 19:46:34 +0100 Subject: [PATCH 31/69] Update src/transformers/tokenization_utils_base.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/tokenization_utils_base.py | 24 +++++++++++---------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 7c6a778fc68d29..095c30381bca29 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1707,19 +1707,21 @@ def apply_chat_template( Args: conversation (Union[List[Dict[str, str]], List[List[Dict[str, str]]]]): A list of dicts with "role" and "content" keys, representing the chat history so far. - tools (List[Dict], *optional*): A list of tools (callable functions) that will be accessible - to the model. If the template does not support function calling, this argument will have no effect. - Each tool should be passed as a JSON Schema, giving the name, description and argument types - for the tool. See our [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use) + tools (`List[Dict]`, *optional*): + A list of tools (callable functions) that will be accessible to the model. If the template does not + support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema, + giving the name, description and argument types for the tool. See our + [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use) for more information. 
- documents (List[Dict[str, str]], *optional*): A list of dicts representing documents that will be accessible - to the model if it is performing RAG (retrieval-augmented generation). If the template does not support - RAG, this argument will have no effect. We recommend that each document should be a dict containing - "title" and "text" keys. Please see the RAG section of the - [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#arguments-for-RAG) + documents (`List[Dict[str, str]]`, *optional*): + A list of dicts representing documents that will be accessible to the model if it is performing RAG + (retrieval-augmented generation). If the template does not support RAG, this argument will have no + effect. We recommend that each document should be a dict containing "title" and "text" keys. Please + see the RAG section of the [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#arguments-for-RAG) for examples of passing documents with chat templates. - chat_template (str, *optional*): A Jinja template to use for this conversion. It is usually not necessary - to pass anything to this argument, as the model's template will be used by default. + chat_template (`str`, *optional*): + A Jinja template to use for this conversion. It is usually not necessary to pass anything to this + argument, as the model's template will be used by default. add_generation_prompt (bool, *optional*): Whether to end the prompt with the token(s) that indicate the start of an assistant message. This is useful when you want to generate a response from the model. 
Note that this argument will be passed to the chat template, and so it must be supported in the From 24c05893272d0a8d285e17f26cd84171217b8e10 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 23 May 2024 19:46:49 +0100 Subject: [PATCH 32/69] Update docs/source/en/chat_templating.md Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- docs/source/en/chat_templating.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md index 169281d115b98c..2bbfc90aa3491f 100644 --- a/docs/source/en/chat_templating.md +++ b/docs/source/en/chat_templating.md @@ -398,7 +398,7 @@ document2 = { model_input = tokenizer.apply_chat_template( messages, - documents = [document1, document2] + documents=[document1, document2] ) ``` From 49f1e97730e5184b8b981f2e5769b5adb47ba8b8 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 23 May 2024 19:46:57 +0100 Subject: [PATCH 33/69] Update src/transformers/utils/chat_template_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/utils/chat_template_utils.py | 140 ++++++++++-------- 1 file changed, 75 insertions(+), 65 deletions(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index f9fd62d95e2a6f..07c38373b2ecc9 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -190,73 +190,83 @@ def _convert_type_hints_to_json_schema(func): def _parse_type_hint(hint): - if (origin := get_origin(hint)) is not None: - if origin is Union: - # If it's a union of basic types, we can express that as a simple list in the schema - if all(t in BASIC_TYPES for t in get_args(hint)): - return_dict = { - "type": [_get_json_schema_type(t)["type"] for t in get_args(hint) if t not in (type(None), ...)] - } - if len(return_dict["type"]) == 1: - return_dict["type"] = return_dict["type"][0] - else: - # A 
union of more complex types requires us to recurse into each subtype - return_dict = { - "anyOf": [_parse_type_hint(t) for t in get_args(hint) if t not in (type(None), ...)], - } - if len(return_dict["anyOf"]) == 1: - return_dict = return_dict["anyOf"][0] - if type(None) in get_args(hint): - return_dict["nullable"] = True - return return_dict - elif origin is list: - if not get_args(hint): - return {"type": "array"} - if all(t in BASIC_TYPES for t in get_args(hint)): - # Similarly to unions, a list of basic types can be expressed as a list in the schema - items = {"type": [_get_json_schema_type(t)["type"] for t in get_args(hint) if t != type(None)]} - if len(items["type"]) == 1: - items["type"] = items["type"][0] - else: - # And a list of more complex types requires us to recurse into each subtype again - items = {"anyOf": [_parse_type_hint(t) for t in get_args(hint) if t not in (type(None), ...)]} - if len(items["anyOf"]) == 1: - items = items["anyOf"][0] - return_dict = {"type": "array", "items": items} - if type(None) in get_args(hint): - return_dict["nullable"] = True - return return_dict - elif origin is tuple: - if not get_args(hint): - return {"type": "array"} - if len(get_args(hint)) == 1: - raise ValueError( - "Tuple type hints should only be used when the argument has a fixed length and each " - f"element has a specific type. The hint {hint} indicates a Tuple of length 1. " - "This should be replaced with an unwrapped type hint instead like " - f"{get_args(hint)[0]}. Alternatively, if the " - "function can actually take a tuple with multiple elements, please either indicate " - f"each element type (e.g. Tuple[{get_args(hint)[0]}, {get_args(hint)[0]}]), " - f"or if the input can be variable length, use List[{get_args(hint)[0]}] instead." - ) - if ... in get_args(hint): - raise ValueError( - "'...' is not supported in Tuple type hints. Use List[] types for variable-length" - " inputs instead." 
- ) - return {"type": "array", "prefixItems": [_parse_type_hint(t) for t in get_args(hint)]} - elif origin is dict: - # The JSON equivalent to a dict is 'object', which mandates that all keys are strings - # However, we can specify the type of the dict values with "additionalProperties" - out = {"type": "object"} - if len(get_args(hint)) == 2: - out["additionalProperties"] = _parse_type_hint(get_args(hint)[1]) - return out - else: - raise ValueError("Couldn't parse this type hint, likely due to a custom class or object: ", hint) - else: + origin = get_origin(hint) + args = get_args(hint) + + if origin is None: return _get_json_schema_type(hint) + if origin is Union: + # If it's a union of basic types, we can express that as a simple list in the schema + if all(t in BASIC_TYPES for t in args): + return_dict = { + "type": [_get_json_schema_type(t)["type"] for t in args if t not in (type(None), ...)] + } + if len(return_dict["type"]) == 1: + return_dict["type"] = return_dict["type"][0] + else: + # A union of more complex types requires us to recurse into each subtype + return_dict = { + "anyOf": [_parse_type_hint(t) for t in args if t not in (type(None), ...)], + } + if len(return_dict["anyOf"]) == 1: + return_dict = return_dict["anyOf"][0] + if type(None) in args: + return_dict["nullable"] = True + return return_dict + + if origin is list: + if not args: + return {"type": "array"} + + # Similarly to unions, a list of basic types can be expressed as a list in the schema + if all(t in BASIC_TYPES for t in args): + items = {"type": [_get_json_schema_type(t)["type"] for t in args if t != type(None)]} + if len(items["type"]) == 1: + items["type"] = items["type"][0] + else: + # And a list of more complex types requires us to recurse into each subtype again + items = {"anyOf": [_parse_type_hint(t) for t in args if t not in (type(None), ...)]} + if len(items["anyOf"]) == 1: + items = items["anyOf"][0] + + return_dict = {"type": "array", "items": items} + + if type(None) in 
args: + return_dict["nullable"] = True + + return return_dict + + if origin is tuple: + if not args: + return {"type": "array"} + if len(args) == 1: + raise ValueError( + "Tuple type hints should only be used when the argument has a fixed length and each " + f"element has a specific type. The hint {hint} indicates a Tuple of length 1. " + "This should be replaced with an unwrapped type hint instead like " + f"{args[0]}. Alternatively, if the " + "function can actually take a tuple with multiple elements, please either indicate " + f"each element type (e.g. Tuple[{args[0]}, {args[0]}]), " + f"or if the input can be variable length, use List[{args[0]}] instead." + ) + if ... in args: + raise ValueError( + "'...' is not supported in Tuple type hints. Use List[] types for variable-length" + " inputs instead." + ) + return {"type": "array", "prefixItems": [_parse_type_hint(t) for t in args]} + + if origin is dict: + # The JSON equivalent to a dict is 'object', which mandates that all keys are strings + # However, we can specify the type of the dict values with "additionalProperties" + out = {"type": "object"} + if len(args) == 2: + out["additionalProperties"] = _parse_type_hint(args[1]) + return out + + raise ValueError("Couldn't parse this type hint, likely due to a custom class or object: ", hint) + def _get_json_schema_type(param_type): if param_type == int: From 1036c5a48e4638954cb9baaddf37bf79aad709f1 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 23 May 2024 19:47:22 +0100 Subject: [PATCH 34/69] Update src/transformers/utils/chat_template_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/utils/chat_template_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 07c38373b2ecc9..db34f8a09ef1ba 100644 --- a/src/transformers/utils/chat_template_utils.py +++ 
b/src/transformers/utils/chat_template_utils.py @@ -169,8 +169,6 @@ def parse_google_format_docstring(docstring): def _convert_type_hints_to_json_schema(func): type_hints = get_type_hints(func) - properties = {} - signature = inspect.signature(func) required = [] for param_name, param in signature.parameters.items(): @@ -179,6 +177,7 @@ def _convert_type_hints_to_json_schema(func): if param.default == inspect.Parameter.empty: required.append(param_name) + properties = {} for param_name, param_type in type_hints.items(): properties[param_name] = _parse_type_hint(param_type) From cdbc9bc00f9d2b97ee000c998ca36aa76010982f Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 23 May 2024 19:48:46 +0100 Subject: [PATCH 35/69] Add copyright header --- tests/utils/test_chat_template_utils.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/utils/test_chat_template_utils.py b/tests/utils/test_chat_template_utils.py index c48d373d75c346..13330ddb212ce6 100644 --- a/tests/utils/test_chat_template_utils.py +++ b/tests/utils/test_chat_template_utils.py @@ -1,3 +1,17 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import unittest from typing import List, Optional, Tuple, Union From dbb157f2ae669eeef6066bcbf72efe2c2280f0f2 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 23 May 2024 19:49:16 +0100 Subject: [PATCH 36/69] make fixup --- src/transformers/tokenization_utils_base.py | 6 +++--- src/transformers/utils/chat_template_utils.py | 13 +++++-------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 095c30381bca29..966f6045183302 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1707,10 +1707,10 @@ def apply_chat_template( Args: conversation (Union[List[Dict[str, str]], List[List[Dict[str, str]]]]): A list of dicts with "role" and "content" keys, representing the chat history so far. - tools (`List[Dict]`, *optional*): + tools (`List[Dict]`, *optional*): A list of tools (callable functions) that will be accessible to the model. If the template does not support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema, - giving the name, description and argument types for the tool. See our + giving the name, description and argument types for the tool. See our [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use) for more information. documents (`List[Dict[str, str]]`, *optional*): @@ -1719,7 +1719,7 @@ def apply_chat_template( effect. We recommend that each document should be a dict containing "title" and "text" keys. Please see the RAG section of the [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#arguments-for-RAG) for examples of passing documents with chat templates. - chat_template (`str`, *optional*): + chat_template (`str`, *optional*): A Jinja template to use for this conversion. 
It is usually not necessary to pass anything to this argument, as the model's template will be used by default. add_generation_prompt (bool, *optional*): Whether to end the prompt with the token(s) that indicate diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index db34f8a09ef1ba..6de19103a42191 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -191,16 +191,14 @@ def _convert_type_hints_to_json_schema(func): def _parse_type_hint(hint): origin = get_origin(hint) args = get_args(hint) - + if origin is None: return _get_json_schema_type(hint) if origin is Union: # If it's a union of basic types, we can express that as a simple list in the schema if all(t in BASIC_TYPES for t in args): - return_dict = { - "type": [_get_json_schema_type(t)["type"] for t in args if t not in (type(None), ...)] - } + return_dict = {"type": [_get_json_schema_type(t)["type"] for t in args if t not in (type(None), ...)]} if len(return_dict["type"]) == 1: return_dict["type"] = return_dict["type"][0] else: @@ -230,10 +228,10 @@ def _parse_type_hint(hint): items = items["anyOf"][0] return_dict = {"type": "array", "items": items} - + if type(None) in args: return_dict["nullable"] = True - + return return_dict if origin is tuple: @@ -251,8 +249,7 @@ def _parse_type_hint(hint): ) if ... in args: raise ValueError( - "'...' is not supported in Tuple type hints. Use List[] types for variable-length" - " inputs instead." + "'...' is not supported in Tuple type hints. Use List[] types for variable-length" " inputs instead." 
) return {"type": "array", "prefixItems": [_parse_type_hint(t) for t in args]} From 6171a9f0b8aa6d9b1fcd66f422a7cb7a1b107f92 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 23 May 2024 20:03:49 +0100 Subject: [PATCH 37/69] Fix indentation --- src/transformers/utils/chat_template_utils.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 6de19103a42191..ecfee8c68e4b22 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -89,13 +89,13 @@ def get_json_schema(func): >>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01") >>> messages = [{"role": "user", "content": "What is 179 x 4571?"}] >>> formatted_chat = tokenizer.apply_chat_template( - >>> messages, - >>> tools=[multiply_schema], - >>> chat_template="tool_use", - >>> return_dict=True, - >>> return_tensors="pt", - >>> add_generation_prompt=True - >>> ) + >>> messages, + >>> tools=[multiply_schema], + >>> chat_template="tool_use", + >>> return_dict=True, + >>> return_tensors="pt", + >>> add_generation_prompt=True + >>> ) >>> # The formatted chat can now be passed to model.generate() ``` From 098d6e6a951811041cb50d8ff7c1b19227d28170 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 23 May 2024 20:06:06 +0100 Subject: [PATCH 38/69] Reformat chat_template_utils --- src/transformers/utils/chat_template_utils.py | 298 +++++++++--------- 1 file changed, 149 insertions(+), 149 deletions(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index ecfee8c68e4b22..97a61d9cd18f7b 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -21,6 +21,155 @@ BASIC_TYPES = (int, float, str, bool, Any, type(None), ...) 
+def _get_json_schema_type(param_type): + if param_type == int: + return {"type": "integer"} + elif param_type == float: + return {"type": "number"} + elif param_type == str: + return {"type": "string"} + elif param_type == bool: + return {"type": "boolean"} + elif param_type == Any: + return {} + else: + return {"type": "object"} + + +def _parse_type_hint(hint): + origin = get_origin(hint) + args = get_args(hint) + + if origin is None: + return _get_json_schema_type(hint) + + if origin is Union: + # If it's a union of basic types, we can express that as a simple list in the schema + if all(t in BASIC_TYPES for t in args): + return_dict = {"type": [_get_json_schema_type(t)["type"] for t in args if t not in (type(None), ...)]} + if len(return_dict["type"]) == 1: + return_dict["type"] = return_dict["type"][0] + else: + # A union of more complex types requires us to recurse into each subtype + return_dict = { + "anyOf": [_parse_type_hint(t) for t in args if t not in (type(None), ...)], + } + if len(return_dict["anyOf"]) == 1: + return_dict = return_dict["anyOf"][0] + if type(None) in args: + return_dict["nullable"] = True + return return_dict + + if origin is list: + if not args: + return {"type": "array"} + + # Similarly to unions, a list of basic types can be expressed as a list in the schema + if all(t in BASIC_TYPES for t in args): + items = {"type": [_get_json_schema_type(t)["type"] for t in args if t != type(None)]} + if len(items["type"]) == 1: + items["type"] = items["type"][0] + else: + # And a list of more complex types requires us to recurse into each subtype again + items = {"anyOf": [_parse_type_hint(t) for t in args if t not in (type(None), ...)]} + if len(items["anyOf"]) == 1: + items = items["anyOf"][0] + + return_dict = {"type": "array", "items": items} + + if type(None) in args: + return_dict["nullable"] = True + + return return_dict + + if origin is tuple: + if not args: + return {"type": "array"} + if len(args) == 1: + raise ValueError( + "Tuple 
type hints should only be used when the argument has a fixed length and each " + f"element has a specific type. The hint {hint} indicates a Tuple of length 1. " + "This should be replaced with an unwrapped type hint instead like " + f"{args[0]}. Alternatively, if the " + "function can actually take a tuple with multiple elements, please either indicate " + f"each element type (e.g. Tuple[{args[0]}, {args[0]}]), " + f"or if the input can be variable length, use List[{args[0]}] instead." + ) + if ... in args: + raise ValueError( + "'...' is not supported in Tuple type hints. Use List[] types for variable-length" " inputs instead." + ) + return {"type": "array", "prefixItems": [_parse_type_hint(t) for t in args]} + + if origin is dict: + # The JSON equivalent to a dict is 'object', which mandates that all keys are strings + # However, we can specify the type of the dict values with "additionalProperties" + out = {"type": "object"} + if len(args) == 2: + out["additionalProperties"] = _parse_type_hint(args[1]) + return out + + raise ValueError("Couldn't parse this type hint, likely due to a custom class or object: ", hint) + + +def _convert_type_hints_to_json_schema(func): + type_hints = get_type_hints(func) + signature = inspect.signature(func) + required = [] + for param_name, param in signature.parameters.items(): + if param.annotation == inspect.Parameter.empty: + raise ValueError(f"Argument {param.name} is missing a type hint in function {func.__name__}") + if param.default == inspect.Parameter.empty: + required.append(param_name) + + properties = {} + for param_name, param_type in type_hints.items(): + properties[param_name] = _parse_type_hint(param_type) + + schema = {"type": "object", "properties": properties} + if required: + schema["required"] = required + + return schema + + +def parse_google_format_docstring(docstring): + """ + Parses a Google-style docstring to extract the function description, + argument descriptions, and return description. 
+ + Args: + docstring (str): The docstring to parse. + + Returns: + dict: A dictionary containing the function description, arguments, and return description. + """ + # Regular expressions to match the sections + description_re = re.compile(r"^(.*?)[\n\s]*(Args:|Returns:|Raises:|\Z)", re.DOTALL) + args_re = re.compile(r"\n\s*Args:\n\s*(.*?)[\n\s]*(Returns:|Raises:|\Z)", re.DOTALL) + returns_re = re.compile(r"\n\s*Returns:\n\s*(.*?)[\n\s]*(Raises:|\Z)", re.DOTALL) + + # Extract the sections + description_match = description_re.search(docstring) + args_match = args_re.search(docstring) + returns_match = returns_re.search(docstring) + + # Clean and store the sections + description = description_match.group(1).strip() if description_match else None + args = args_match.group(1).strip() if args_match else None + returns = returns_match.group(1).strip() if returns_match else None + + # Parsing the arguments into a dictionary + args_dict = {} + if args is not None: + arg_lines = args.split("\n") + for line in arg_lines: + arg_name, arg_desc = line.split(":", 1) + args_dict[arg_name.strip()] = arg_desc.strip() + + return description, args_dict, returns + + def get_json_schema(func): """ This function generates a JSON schema for a given function, based on its docstring and type hints. This is @@ -128,152 +277,3 @@ def get_json_schema(func): if return_dict is not None: output["return"] = return_dict return output - - -def parse_google_format_docstring(docstring): - """ - Parses a Google-style docstring to extract the function description, - argument descriptions, and return description. - - Args: - docstring (str): The docstring to parse. - - Returns: - dict: A dictionary containing the function description, arguments, and return description. 
- """ - # Regular expressions to match the sections - description_re = re.compile(r"^(.*?)[\n\s]*(Args:|Returns:|Raises:|\Z)", re.DOTALL) - args_re = re.compile(r"\n\s*Args:\n\s*(.*?)[\n\s]*(Returns:|Raises:|\Z)", re.DOTALL) - returns_re = re.compile(r"\n\s*Returns:\n\s*(.*?)[\n\s]*(Raises:|\Z)", re.DOTALL) - - # Extract the sections - description_match = description_re.search(docstring) - args_match = args_re.search(docstring) - returns_match = returns_re.search(docstring) - - # Clean and store the sections - description = description_match.group(1).strip() if description_match else None - args = args_match.group(1).strip() if args_match else None - returns = returns_match.group(1).strip() if returns_match else None - - # Parsing the arguments into a dictionary - args_dict = {} - if args is not None: - arg_lines = args.split("\n") - for line in arg_lines: - arg_name, arg_desc = line.split(":", 1) - args_dict[arg_name.strip()] = arg_desc.strip() - - return description, args_dict, returns - - -def _convert_type_hints_to_json_schema(func): - type_hints = get_type_hints(func) - signature = inspect.signature(func) - required = [] - for param_name, param in signature.parameters.items(): - if param.annotation == inspect.Parameter.empty: - raise ValueError(f"Argument {param.name} is missing a type hint in function {func.__name__}") - if param.default == inspect.Parameter.empty: - required.append(param_name) - - properties = {} - for param_name, param_type in type_hints.items(): - properties[param_name] = _parse_type_hint(param_type) - - schema = {"type": "object", "properties": properties} - if required: - schema["required"] = required - - return schema - - -def _parse_type_hint(hint): - origin = get_origin(hint) - args = get_args(hint) - - if origin is None: - return _get_json_schema_type(hint) - - if origin is Union: - # If it's a union of basic types, we can express that as a simple list in the schema - if all(t in BASIC_TYPES for t in args): - return_dict = {"type": 
[_get_json_schema_type(t)["type"] for t in args if t not in (type(None), ...)]} - if len(return_dict["type"]) == 1: - return_dict["type"] = return_dict["type"][0] - else: - # A union of more complex types requires us to recurse into each subtype - return_dict = { - "anyOf": [_parse_type_hint(t) for t in args if t not in (type(None), ...)], - } - if len(return_dict["anyOf"]) == 1: - return_dict = return_dict["anyOf"][0] - if type(None) in args: - return_dict["nullable"] = True - return return_dict - - if origin is list: - if not args: - return {"type": "array"} - - # Similarly to unions, a list of basic types can be expressed as a list in the schema - if all(t in BASIC_TYPES for t in args): - items = {"type": [_get_json_schema_type(t)["type"] for t in args if t != type(None)]} - if len(items["type"]) == 1: - items["type"] = items["type"][0] - else: - # And a list of more complex types requires us to recurse into each subtype again - items = {"anyOf": [_parse_type_hint(t) for t in args if t not in (type(None), ...)]} - if len(items["anyOf"]) == 1: - items = items["anyOf"][0] - - return_dict = {"type": "array", "items": items} - - if type(None) in args: - return_dict["nullable"] = True - - return return_dict - - if origin is tuple: - if not args: - return {"type": "array"} - if len(args) == 1: - raise ValueError( - "Tuple type hints should only be used when the argument has a fixed length and each " - f"element has a specific type. The hint {hint} indicates a Tuple of length 1. " - "This should be replaced with an unwrapped type hint instead like " - f"{args[0]}. Alternatively, if the " - "function can actually take a tuple with multiple elements, please either indicate " - f"each element type (e.g. Tuple[{args[0]}, {args[0]}]), " - f"or if the input can be variable length, use List[{args[0]}] instead." - ) - if ... in args: - raise ValueError( - "'...' is not supported in Tuple type hints. Use List[] types for variable-length" " inputs instead." 
- ) - return {"type": "array", "prefixItems": [_parse_type_hint(t) for t in args]} - - if origin is dict: - # The JSON equivalent to a dict is 'object', which mandates that all keys are strings - # However, we can specify the type of the dict values with "additionalProperties" - out = {"type": "object"} - if len(args) == 2: - out["additionalProperties"] = _parse_type_hint(args[1]) - return out - - raise ValueError("Couldn't parse this type hint, likely due to a custom class or object: ", hint) - - -def _get_json_schema_type(param_type): - if param_type == int: - return {"type": "integer"} - elif param_type == float: - return {"type": "number"} - elif param_type == str: - return {"type": "string"} - elif param_type == bool: - return {"type": "boolean"} - elif param_type == Any: - return {} - else: - return {"type": "object"} From 0a92408be9592a4aa019e098747f94d848ad41c9 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 23 May 2024 20:09:12 +0100 Subject: [PATCH 39/69] Correct return value --- src/transformers/utils/chat_template_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 97a61d9cd18f7b..163c10bc9084a2 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -142,7 +142,7 @@ def parse_google_format_docstring(docstring): docstring (str): The docstring to parse. Returns: - dict: A dictionary containing the function description, arguments, and return description. + The function description, arguments, and return description. 
""" # Regular expressions to match the sections description_re = re.compile(r"^(.*?)[\n\s]*(Args:|Returns:|Raises:|\Z)", re.DOTALL) From a4dc182dfb18ef033747f13498c41f929f3c7bc6 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 23 May 2024 20:10:08 +0100 Subject: [PATCH 40/69] Make regexes module-level --- src/transformers/utils/chat_template_utils.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 163c10bc9084a2..83e02ee13ee84e 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -19,6 +19,9 @@ BASIC_TYPES = (int, float, str, bool, Any, type(None), ...) +description_re = re.compile(r"^(.*?)[\n\s]*(Args:|Returns:|Raises:|\Z)", re.DOTALL) +args_re = re.compile(r"\n\s*Args:\n\s*(.*?)[\n\s]*(Returns:|Raises:|\Z)", re.DOTALL) +returns_re = re.compile(r"\n\s*Returns:\n\s*(.*?)[\n\s]*(Raises:|\Z)", re.DOTALL) def _get_json_schema_type(param_type): @@ -144,10 +147,6 @@ def parse_google_format_docstring(docstring): Returns: The function description, arguments, and return description. 
""" - # Regular expressions to match the sections - description_re = re.compile(r"^(.*?)[\n\s]*(Args:|Returns:|Raises:|\Z)", re.DOTALL) - args_re = re.compile(r"\n\s*Args:\n\s*(.*?)[\n\s]*(Returns:|Raises:|\Z)", re.DOTALL) - returns_re = re.compile(r"\n\s*Returns:\n\s*(.*?)[\n\s]*(Raises:|\Z)", re.DOTALL) # Extract the sections description_match = description_re.search(docstring) From 390596e19e64275172dacd8d6d78d534f4cdb7df Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 24 May 2024 14:55:01 +0100 Subject: [PATCH 41/69] Support more complex, multi-line arg docstrings --- src/transformers/utils/chat_template_utils.py | 15 +++++---- tests/utils/test_chat_template_utils.py | 32 +++++++++++++++++++ 2 files changed, 40 insertions(+), 7 deletions(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 83e02ee13ee84e..00f7d990c410ca 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -21,6 +21,7 @@ BASIC_TYPES = (int, float, str, bool, Any, type(None), ...) 
description_re = re.compile(r"^(.*?)[\n\s]*(Args:|Returns:|Raises:|\Z)", re.DOTALL) args_re = re.compile(r"\n\s*Args:\n\s*(.*?)[\n\s]*(Returns:|Raises:|\Z)", re.DOTALL) +args_split_re = re.compile(r"(?:^|\n)\s*(\w+)\s*(?:\(\w+\))?:\s*(.*?)\s*(?=\n\s*\w|\Z)", re.DOTALL) returns_re = re.compile(r"\n\s*Returns:\n\s*(.*?)[\n\s]*(Raises:|\Z)", re.DOTALL) @@ -155,16 +156,16 @@ def parse_google_format_docstring(docstring): # Clean and store the sections description = description_match.group(1).strip() if description_match else None - args = args_match.group(1).strip() if args_match else None + docstring_args = args_match.group(1).strip() if args_match else None returns = returns_match.group(1).strip() if returns_match else None # Parsing the arguments into a dictionary - args_dict = {} - if args is not None: - arg_lines = args.split("\n") - for line in arg_lines: - arg_name, arg_desc = line.split(":", 1) - args_dict[arg_name.strip()] = arg_desc.strip() + if docstring_args is not None: + docstring_args = "\n".join([line for line in docstring_args.split("\n") if line.strip()]) # Remove blank lines + matches = args_split_re.findall(docstring_args) + args_dict = {match[0]: match[1].replace("\n", " ").strip() for match in matches} + else: + args_dict = {} return description, args_dict, returns diff --git a/tests/utils/test_chat_template_utils.py b/tests/utils/test_chat_template_utils.py index 13330ddb212ce6..de09f37619a6cf 100644 --- a/tests/utils/test_chat_template_utils.py +++ b/tests/utils/test_chat_template_utils.py @@ -383,3 +383,35 @@ def fn(temperature_format: str): } self.assertEqual(schema, expected_schema) + + def test_multiline_docstring_with_types(self): + def fn(x: int, y: int): + """ + Test function + + Args: + x (int): The first input + + y (int): The second input. This is a longer description + that spans multiple lines with indentation and stuff. 
+ + Returns: + God knows what + """ + pass + + schema = get_json_schema(fn) + expected_schema = { + "name": "fn", + "description": "Test function", + "parameters": { + "type": "object", + "properties": { + "x": {"type": "integer", "description": "The first input"}, + "y": {"type": "integer", "description": "The second input. This is a longer description"}, + }, + "required": ["x", "y"], + }, + } + + self.assertEqual(schema, expected_schema) From 5568cb314d67202a2d7af2051224d4f0a6bdc370 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 24 May 2024 14:55:55 +0100 Subject: [PATCH 42/69] Update error message for ... --- src/transformers/utils/chat_template_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 00f7d990c410ca..e2005557a03112 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -101,7 +101,8 @@ def _parse_type_hint(hint): ) if ... in args: raise ValueError( - "'...' is not supported in Tuple type hints. Use List[] types for variable-length" " inputs instead." + "Conversion of '...' is not supported in Tuple type hints. " + "Use List[] types for variable-length" " inputs instead." ) return {"type": "array", "prefixItems": [_parse_type_hint(t) for t in args]} From ae443e206040fc91092a155ba7e75b57e6f9dd67 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 24 May 2024 15:04:28 +0100 Subject: [PATCH 43/69] Update ruff --- src/transformers/utils/chat_template_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index e2005557a03112..c3d09bcd3f64de 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -102,7 +102,8 @@ def _parse_type_hint(hint): if ... in args: raise ValueError( "Conversion of '...' 
is not supported in Tuple type hints. " - "Use List[] types for variable-length" " inputs instead." + "Use List[] types for variable-length" + " inputs instead." ) return {"type": "array", "prefixItems": [_parse_type_hint(t) for t in args]} From 85cba4e6a9fa74beb90b110a92344fa94066555c Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 24 May 2024 15:27:39 +0100 Subject: [PATCH 44/69] Add document type validation --- src/transformers/tokenization_utils_base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 966f6045183302..14a31560c6f1a8 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1835,6 +1835,11 @@ def apply_chat_template( else: tool_schemas = None + if documents is not None: + for document in documents: + if not isinstance(document, dict): + raise TypeError("Documents should be a list of dicts with 'title' and 'text' keys!") + rendered = [] template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present for chat in conversations: From 23692c3c45508f2a7ca9ce37abe9b0e5033da00b Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 24 May 2024 15:48:22 +0100 Subject: [PATCH 45/69] Refactor docs --- docs/source/en/chat_templating.md | 146 +++++++++++++++--------------- 1 file changed, 74 insertions(+), 72 deletions(-) diff --git a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md index 2bbfc90aa3491f..850a6c0a3c8229 100644 --- a/docs/source/en/chat_templating.md +++ b/docs/source/en/chat_templating.md @@ -233,9 +233,9 @@ The sun. From here, just continue training like you would with a standard language modelling task, using the `formatted_chat` column. -## Can I pass other arguments to the chat template? +## Advanced: Extra inputs to chat templates -Yes, you can! The only argument that `apply_chat_template` requires is `messages`. 
However, you can pass any keyword +The only argument that `apply_chat_template` requires is `messages`. However, you can pass any keyword argument to `apply_chat_template` and it will be accessible inside the template. This gives you a lot of freedom to use chat templates for many things. There are no restrictions on the names or the format of these arguments - you can pass strings, lists, dicts or whatever else you want. @@ -244,57 +244,57 @@ That said, there are some common use-cases for these extra arguments, such as passing tools for function calling, or documents for retrieval-augmented generation. In these common cases, we have some opinionated recommendations about what the names and formats of these arguments should be. -### Arguments for tool use +### Tool use / function calling -"Tool use" LLMs can choose to call functions as external tools before generating an answer. Our recommendation for -tool use models is that their template -should accept a `tools` argument. This should be a list of tools, defined via [JSON Schema](https://json-schema.org/). Each "tool" -is a single function that the model can choose to call, and the schema should include the function name, its description -and the expected spec for its arguments. - -#### Example +"Tool use" LLMs can choose to call functions as external tools before generating an answer. 
When passing tools +to a tool-use model, you can simply pass a list of functions to the `tools` argument: ```python -# A simple function that takes no arguments -current_time = { - "name": "current_time", - "description": "Get the current local time as a string.", - "parameters": { - 'type': 'object', - 'properties': {} - }, - } +import datetime -# A more complete function that takes two numerical arguments -multiply = { - "name": "multiply", - "description": "Multiply two numbers together.", - "parameters": { - "type": "object", - "properties": { - "a": {"type": "number", "description": "The first number to multiply."}, - "b": {"type": "number", "description": "The second number to multiply."}, - }, - "required": ["a", "b"], - } - } +def current_time(): + """Get the current local time as a string.""" + return str(datetime.now()) + +def multiply(a: float, b: float): + """ + A function that multiplies two numbers + + Args: + a: The first number to multiply + b: The second number to multiply + """ + return a * b + +tools = [current_time, multiply] model_input = tokenizer.apply_chat_template( messages, - tools = [current_time, multiply] + tools=tools ) ``` -JSON schemas permit highly detailed parameter specifications, so you can pass in functions with very complex, nested -arguments. Be careful, however - we find that in practice this can degrade performance, even for state-of-the-art -models. We recommend trying to keep your tool schemas simple and flat where possible. +In order for this to work correctly, you should use the following conventions, so that the functions can be parsed +correctly as tools: + +- Each function should have a descriptive name +- Every argument should have a type hint +- The function should have a docstring in the standard Google style (in other words, an initial function description + followed by an `Args:` block that describes the arguments. It is not necessary to include types in the `Args:` block. 
+- The function can have a return type and a `Returns:` block in the docstring. However, these are optional + because most tool-use models ignore them. -### Automated function conversion for tool use +### Understanding tool schemas -Although JSON schemas are precise, widely-supported and language-agnostic, they can be a bit verbose, which means -that writing them can be annoying. Don't panic, though, we have a solution! You can simply define Python functions -as tools, and use the [`get_json_schema`] function. This function will automatically generate a JSON schema for any -function that has a valid docstring with parameter annotations and valid type hints. Let's see it in action! +Each function you pass to the `tools` argument of `apply_chat_template` is converted into a +[JSON schema](https://json-schema.org/learn/getting-started-step-by-step. These schemas +are then passed to the model chat template. In other words, tool-use models do not see your functions directly, and they +never see the actual code inside them. What they care about is the function **definitions** and the **arguments** they +need to pass to them - they care about what the tools do and how to use them, not how they work! + +Generating JSON schemas to pass to the template should be automatic and invisible as long as your functions +follow the specification above, but if you encounter problems, or you simply want more control over the conversion, +you can handle the conversion manually. Here is an example of a manual schema conversion. ```python from transformers.utils import get_json_schema @@ -336,45 +336,47 @@ This will yield: } ``` -We can use this function to avoid the need to manually write JSON schemas when passing tools to the chat template. -In addition, if you pass functions in the `tools` argument, they will automatically be converted with this function: +If you wish, you can edit these schemas, or even write them from scratch yourself without using `get_json_schema` at +all. 
JSON schemas can be passed directly to the `tools` argument of +`apply_chat_template` - this gives you a lot of power to define precise schemas for more complex functions. Be careful, +though - the more complex your schemas, the more likely the model is to get confused when dealing with them! We +recommend simple function signatures where possible, keeping arguments (and especially complex, nested arguments) +to a minimum. -```python -import datetime - -def current_time(): - """Get the current local time as a string.""" - return str(datetime.now()) +Here is an example of defining schemas by hand, and passing them directly to `apply_chat_template`: -def multiply(a: float, b: float): - """ - A function that multiplies two numbers - - Args: - a: The first number to multiply - b: The second number to multiply - """ - return a * b +```python +# A simple function that takes no arguments +current_time = { + "name": "current_time", + "description": "Get the current local time as a string.", + "parameters": { + 'type': 'object', + 'properties': {} + }, + } -tools = [current_time, multiply] +# A more complete function that takes two numerical arguments +multiply = { + "name": "multiply", + "description": "Multiply two numbers together.", + "parameters": { + "type": "object", + "properties": { + "a": {"type": "number", "description": "The first number to multiply."}, + "b": {"type": "number", "description": "The second number to multiply."}, + }, + "required": ["a", "b"], + } + } model_input = tokenizer.apply_chat_template( messages, - tools=tools + tools = [current_time, multiply] ) ``` -#### Notes on automatic conversion - -`get_json_schema` expects a specific docstring format. The docstring should -begin with a description of the function, followed by an `Args:` block that describes each argument. It can also -optionally include a `Returns:` block that describes the value(s) returned by the function. 
Many templates ignore this, -because the model will see the return format after calling the function anyway, but some require it. - -Argument descriptions in the docstring should not include the argument types - these are read from the type hints -in the function signature instead. - -### Arguments for RAG +### Retrieval-augmented generation "Retrieval-augmented generation" or "RAG" LLMs can search a corpus of documents for information before responding to a query. This allows models to vastly expand their knowledge base beyond their limited context size. Our @@ -383,7 +385,7 @@ should accept a `documents` argument. This should be a list of documents, where is a single dict with `title` and `contents` keys, both of which are strings. Because this format is much simpler than the JSON schemas used for tools, no helper functions are necessary. -#### Example +Here's an example of a RAG template in action: ```python document1 = { From 6ebad3ad0d7b4604d0994b144b9ad59e732a463c Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 24 May 2024 15:50:21 +0100 Subject: [PATCH 46/69] Refactor docs --- docs/source/en/chat_templating.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md index 850a6c0a3c8229..a1108ee0387e59 100644 --- a/docs/source/en/chat_templating.md +++ b/docs/source/en/chat_templating.md @@ -290,7 +290,9 @@ Each function you pass to the `tools` argument of `apply_chat_template` is conve [JSON schema](https://json-schema.org/learn/getting-started-step-by-step. These schemas are then passed to the model chat template. In other words, tool-use models do not see your functions directly, and they never see the actual code inside them. What they care about is the function **definitions** and the **arguments** they -need to pass to them - they care about what the tools do and how to use them, not how they work! 
+need to pass to them - they care about what the tools do and how to use them, not how they work! It is up to you +to read their outputs, detect if they have requested to use a tool, pass their arguments to the tool function, and +return the response in the chat. Generating JSON schemas to pass to the template should be automatic and invisible as long as your functions follow the specification above, but if you encounter problems, or you simply want more control over the conversion, From 19a9624b0f6e14facf1e4da0008e439f486857df Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 24 May 2024 16:04:55 +0100 Subject: [PATCH 47/69] Refactor docs --- docs/source/en/chat_templating.md | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md index a1108ee0387e59..03889529325e56 100644 --- a/docs/source/en/chat_templating.md +++ b/docs/source/en/chat_templating.md @@ -274,13 +274,14 @@ model_input = tokenizer.apply_chat_template( ) ``` -In order for this to work correctly, you should use the following conventions, so that the functions can be parsed -correctly as tools: - -- Each function should have a descriptive name -- Every argument should have a type hint -- The function should have a docstring in the standard Google style (in other words, an initial function description - followed by an `Args:` block that describes the arguments. It is not necessary to include types in the `Args:` block. +In order for this to work correctly, you should write your functions in the format above, so that they can be parsed +correctly as tools. Specifically, you should follow these rules: + +- The function should have a descriptive name +- Every argument must have a type hint +- The function must have a docstring in the standard Google style (in other words, an initial function description + followed by an `Args:` block that describes the arguments, unless the function does not have any arguments. 
+ It is not necessary to include types in the `Args:` block. - The function can have a return type and a `Returns:` block in the docstring. However, these are optional because most tool-use models ignore them. From 28b5bc24fd05a1f3b075cdf99e566847f7c9f3ac Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 24 May 2024 16:27:09 +0100 Subject: [PATCH 48/69] Clean up Tuple error --- src/transformers/utils/chat_template_utils.py | 12 ++++---- tests/utils/test_chat_template_utils.py | 28 ++++++++++++++++++- 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index c3d09bcd3f64de..c7294c2f428c5d 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -90,14 +90,12 @@ def _parse_type_hint(hint): if not args: return {"type": "array"} if len(args) == 1: + breakpoint() raise ValueError( - "Tuple type hints should only be used when the argument has a fixed length and each " - f"element has a specific type. The hint {hint} indicates a Tuple of length 1. " - "This should be replaced with an unwrapped type hint instead like " - f"{args[0]}. Alternatively, if the " - "function can actually take a tuple with multiple elements, please either indicate " - f"each element type (e.g. Tuple[{args[0]}, {args[0]}]), " - f"or if the input can be variable length, use List[{args[0]}] instead." + f"The type hint {hint.replace('typing.', '')} is a Tuple with a single element, which we do not " + "support as it is rarely necessary. If this input can contain more than one element, we recommend " + "using a List[] type instead, or if it really is a single element, remove the Tuple[] wrapper and just " + "pass the element directly." ) if ... 
in args: raise ValueError( diff --git a/tests/utils/test_chat_template_utils.py b/tests/utils/test_chat_template_utils.py index de09f37619a6cf..722ac73d99f88e 100644 --- a/tests/utils/test_chat_template_utils.py +++ b/tests/utils/test_chat_template_utils.py @@ -137,7 +137,7 @@ def fn(x: List[List[Union[str, int]]]): "properties": { "x": { "type": "array", - "items": {"type": "array", "items": {"type": ["string", "integer"]}}, + "items": {"type": "array", "items": {"type": ["integer", "string"]}}, "description": "The input", } }, @@ -415,3 +415,29 @@ def fn(x: int, y: int): } self.assertEqual(schema, expected_schema) + + def test_everything_all_at_once(self): + def fn( + x: str, y: Optional[List[Union[int, str]]], z: Tuple[Union[int, str]] = (42, "hello") + ) -> Tuple[int, str]: + """ + Test function with multiple args, and docstring args that we have to strip out. + + Args: + x (str): The first input. It's got a big multiline + description and also contains + (choices: ["a", "b", "c"]) + + y (List[int, str], *optional*): The second input. It's a big list with a single-line description. + + z (Tuple[int, str]): The third input. It's some kind of tuple with a default arg. + + Returns: + The output. The return description is also a big multiline + description that spans multiple lines. 
+ """ + pass + + schema = get_json_schema(fn) + breakpoint() + self.assertEqual(schema, expected_schema) From 880bc997aba4d93baf0d57082353555eba56b966 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 24 May 2024 16:40:37 +0100 Subject: [PATCH 49/69] Add an extra test for very complex defs and docstrings and clean everything up for it --- src/transformers/utils/chat_template_utils.py | 8 ++--- tests/utils/test_chat_template_utils.py | 34 ++++++++++++++++--- 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index c7294c2f428c5d..ce0b6eb7fa7548 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -21,7 +21,7 @@ BASIC_TYPES = (int, float, str, bool, Any, type(None), ...) description_re = re.compile(r"^(.*?)[\n\s]*(Args:|Returns:|Raises:|\Z)", re.DOTALL) args_re = re.compile(r"\n\s*Args:\n\s*(.*?)[\n\s]*(Returns:|Raises:|\Z)", re.DOTALL) -args_split_re = re.compile(r"(?:^|\n)\s*(\w+)\s*(?:\(\w+\))?:\s*(.*?)\s*(?=\n\s*\w|\Z)", re.DOTALL) +args_split_re = re.compile(r"(?:^|\n)\s*(\w+)\s*(?:\([\w\s\[\],.*]+\))?:\s*(.*?)\s*(?=\n\s*\w|\Z)", re.DOTALL) returns_re = re.compile(r"\n\s*Returns:\n\s*(.*?)[\n\s]*(Raises:|\Z)", re.DOTALL) @@ -90,10 +90,10 @@ def _parse_type_hint(hint): if not args: return {"type": "array"} if len(args) == 1: - breakpoint() raise ValueError( - f"The type hint {hint.replace('typing.', '')} is a Tuple with a single element, which we do not " - "support as it is rarely necessary. If this input can contain more than one element, we recommend " + f"The type hint {str(hint).replace('typing.', '')} is a Tuple with a single element, which " + "we do not automatically convert to JSON schema as it is rarely necessary. 
If this input can contain " + "more than one element, we recommend " "using a List[] type instead, or if it really is a single element, remove the Tuple[] wrapper and just " "pass the element directly." ) diff --git a/tests/utils/test_chat_template_utils.py b/tests/utils/test_chat_template_utils.py index 722ac73d99f88e..c1b503afa5b2bf 100644 --- a/tests/utils/test_chat_template_utils.py +++ b/tests/utils/test_chat_template_utils.py @@ -418,7 +418,7 @@ def fn(x: int, y: int): def test_everything_all_at_once(self): def fn( - x: str, y: Optional[List[Union[int, str]]], z: Tuple[Union[int, str]] = (42, "hello") + x: str, y: Optional[List[Union[int, str]]], z: Tuple[Union[int, str], str] = (42, "hello") ) -> Tuple[int, str]: """ Test function with multiple args, and docstring args that we have to strip out. @@ -428,9 +428,9 @@ def fn( description and also contains (choices: ["a", "b", "c"]) - y (List[int, str], *optional*): The second input. It's a big list with a single-line description. + y (List[Union[int, str], *optional*): The second input. It's a big list with a single-line description. - z (Tuple[int, str]): The third input. It's some kind of tuple with a default arg. + z (Tuple[Union[int, str], str]): The third input. It's some kind of tuple with a default arg. Returns: The output. The return description is also a big multiline @@ -439,5 +439,31 @@ def fn( pass schema = get_json_schema(fn) - breakpoint() + expected_schema = { + "name": "fn", + "description": "Test function with multiple args, and docstring args that we have to strip out.", + "parameters": { + "type": "object", + "properties": { + "x": {"type": "string", "description": "The first input. It's got a big multiline"}, + "y": { + "type": "array", + "items": {"type": ["integer", "string"]}, + "nullable": True, + "description": "The second input. 
It's a big list with a single-line description.", + }, + "z": { + "type": "array", + "prefixItems": [{"type": ["integer", "string"]}, {"type": "string"}], + "description": "The third input. It's some kind of tuple with a default arg.", + }, + }, + "required": ["x", "y"], + }, + "return": { + "type": "array", + "prefixItems": [{"type": "integer"}, {"type": "string"}], + "description": "The output. The return description is also a big multiline\n description that spans multiple lines.", + }, + } self.assertEqual(schema, expected_schema) From 58f69db7e2212e37b4878cbdeed9c48878a53e1a Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 24 May 2024 16:47:32 +0100 Subject: [PATCH 50/69] Document enum block --- src/transformers/utils/chat_template_utils.py | 38 ++++++++++++++++--- 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index ce0b6eb7fa7548..e26e5aec60930d 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -179,10 +179,7 @@ def get_json_schema(func): Google docstring format shown below. It also requires that all the function arguments have a valid Python type hint. Although it is not required, a `Returns` block can also be added, which will be included in the schema. This is - optional because most chat templates ignore the return value of the function. Each argument description - can also have an optional `(choices: ...)` block at the end, such as `(choices: ["tea", "coffee"])`, which will be - parsed into an `enum` field in the schema. Note that this will only be parsed correctly if it is at the end of the - line. + optional because most chat templates ignore the return value of the function. Args: func: The function to generate a JSON schema for. 
@@ -248,8 +245,37 @@ def get_json_schema(func): >>> # The formatted chat can now be passed to model.generate() ``` - In many cases, it is more convenient to simply pass the functions directly to apply_chat_template and let it - autogenerate schemas than calling this function directly. + Each argument description can also have an optional `(choices: ...)` block at the end, such as + `(choices: ["tea", "coffee"])`, which will be parsed into an `enum` field in the schema. Note that this will + only be parsed correctly if it is at the end of the line: + + ```python + >>> def drink_beverage(beverage: str): + >>> ''' + >>> A function that drinks a beverage + >>> + >>> Args: + >>> beverage: The beverage to drink (choices: ["tea", "coffee"]) + >>> ''' + >>> pass + >>> + >>> print(get_json_schema(drink_beverage)) + ``` + { + 'name': 'drink_beverage', + 'description': 'A function that drinks a beverage', + 'parameters': { + 'type': 'object', + 'properties': { + 'beverage': { + 'type': 'string', + 'enum': ['tea', 'coffee'], + 'description': 'The beverage to drink' + } + }, + 'required': ['beverage'] + } + } """ doc = inspect.getdoc(func) if not doc: From df3123e3266832148f206e362cc1913d93b1a90d Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 24 May 2024 16:59:56 +0100 Subject: [PATCH 51/69] Quick test fixes --- tests/utils/test_chat_template_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/utils/test_chat_template_utils.py b/tests/utils/test_chat_template_utils.py index c1b503afa5b2bf..e795468996b91a 100644 --- a/tests/utils/test_chat_template_utils.py +++ b/tests/utils/test_chat_template_utils.py @@ -137,7 +137,7 @@ def fn(x: List[List[Union[str, int]]]): "properties": { "x": { "type": "array", - "items": {"type": "array", "items": {"type": ["integer", "string"]}}, + "items": {"type": "array", "items": {"type": ["string", "integer"]}}, "description": "The input", } }, @@ -418,7 +418,7 @@ def fn(x: int, y: int): def 
test_everything_all_at_once(self): def fn( - x: str, y: Optional[List[Union[int, str]]], z: Tuple[Union[int, str], str] = (42, "hello") + x: str, y: Optional[List[Union[str, int]]], z: Tuple[Union[str, int], str] = (42, "hello") ) -> Tuple[int, str]: """ Test function with multiple args, and docstring args that we have to strip out. @@ -448,13 +448,13 @@ def fn( "x": {"type": "string", "description": "The first input. It's got a big multiline"}, "y": { "type": "array", - "items": {"type": ["integer", "string"]}, + "items": {"type": ["string", "integer"]}, "nullable": True, "description": "The second input. It's a big list with a single-line description.", }, "z": { "type": "array", - "prefixItems": [{"type": ["integer", "string"]}, {"type": "string"}], + "prefixItems": [{"type": ["string", "integer"]}, {"type": "string"}], "description": "The third input. It's some kind of tuple with a default arg.", }, }, From cfb1190748929a6ab814d83fc46341cf136d913d Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 28 May 2024 16:17:34 +0100 Subject: [PATCH 52/69] Stop supporting type hints in docstring to fix bugs and simplify the regex --- src/transformers/utils/chat_template_utils.py | 12 ++++++++-- tests/utils/test_chat_template_utils.py | 23 ++++++++++++------- 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index e26e5aec60930d..81ee9a5695f005 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -21,7 +21,15 @@ BASIC_TYPES = (int, float, str, bool, Any, type(None), ...) 
description_re = re.compile(r"^(.*?)[\n\s]*(Args:|Returns:|Raises:|\Z)", re.DOTALL) args_re = re.compile(r"\n\s*Args:\n\s*(.*?)[\n\s]*(Returns:|Raises:|\Z)", re.DOTALL) -args_split_re = re.compile(r"(?:^|\n)\s*(\w+)\s*(?:\([\w\s\[\],.*]+\))?:\s*(.*?)\s*(?=\n\s*\w|\Z)", re.DOTALL) +args_split_re = re.compile( + r""" +(?:^|\n) # Match the start of the args block, or a newline +\s*(\w+):\s* # Capture the argument name and strip spacing +(.*?)\s* # Capture the argument description, which can span multiple lines, and strip trailing spacing +(?=\n\s*\w+:|\Z) # Stop when you hit the next argument or the end of the block +""", + re.DOTALL | re.VERBOSE, +) returns_re = re.compile(r"\n\s*Returns:\n\s*(.*?)[\n\s]*(Raises:|\Z)", re.DOTALL) @@ -163,7 +171,7 @@ def parse_google_format_docstring(docstring): if docstring_args is not None: docstring_args = "\n".join([line for line in docstring_args.split("\n") if line.strip()]) # Remove blank lines matches = args_split_re.findall(docstring_args) - args_dict = {match[0]: match[1].replace("\n", " ").strip() for match in matches} + args_dict = {match[0]: re.sub(r"\s*\n+\s*", " ", match[1].strip()) for match in matches} else: args_dict = {} diff --git a/tests/utils/test_chat_template_utils.py b/tests/utils/test_chat_template_utils.py index e795468996b91a..b1aa23b4f6a0be 100644 --- a/tests/utils/test_chat_template_utils.py +++ b/tests/utils/test_chat_template_utils.py @@ -390,10 +390,10 @@ def fn(x: int, y: int): Test function Args: - x (int): The first input + x: The first input - y (int): The second input. This is a longer description - that spans multiple lines with indentation and stuff. + y: The second input. This is a longer description + that spans multiple lines with indentation and stuff. Returns: God knows what @@ -408,7 +408,10 @@ def fn(x: int, y: int): "type": "object", "properties": { "x": {"type": "integer", "description": "The first input"}, - "y": {"type": "integer", "description": "The second input. 
This is a longer description"}, + "y": { + "type": "integer", + "description": "The second input. This is a longer description that spans multiple lines with indentation and stuff.", + }, }, "required": ["x", "y"], }, @@ -424,13 +427,13 @@ def fn( Test function with multiple args, and docstring args that we have to strip out. Args: - x (str): The first input. It's got a big multiline + x: The first input. It's got a big multiline description and also contains (choices: ["a", "b", "c"]) - y (List[Union[int, str], *optional*): The second input. It's a big list with a single-line description. + y: The second input. It's a big list with a single-line description. - z (Tuple[Union[int, str], str]): The third input. It's some kind of tuple with a default arg. + z: The third input. It's some kind of tuple with a default arg. Returns: The output. The return description is also a big multiline @@ -445,7 +448,11 @@ def fn( "parameters": { "type": "object", "properties": { - "x": {"type": "string", "description": "The first input. It's got a big multiline"}, + "x": { + "type": "string", + "enum": ["a", "b", "c"], + "description": "The first input. It's got a big multiline description and also contains", + }, "y": { "type": "array", "items": {"type": ["string", "integer"]}, From fc66acc74b4561e5a34ff4dc8e1758a6e4783813 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 28 May 2024 16:18:49 +0100 Subject: [PATCH 53/69] Update docs for the regex change --- docs/source/en/chat_templating.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md index 03889529325e56..8c6d30eddde971 100644 --- a/docs/source/en/chat_templating.md +++ b/docs/source/en/chat_templating.md @@ -281,7 +281,7 @@ correctly as tools. 
Specifically, you should follow these rules: - Every argument must have a type hint - The function must have a docstring in the standard Google style (in other words, an initial function description followed by an `Args:` block that describes the arguments, unless the function does not have any arguments. - It is not necessary to include types in the `Args:` block. + Do not include types in the `Args:` block - put them in the type hints in the function header instead. - The function can have a return type and a `Returns:` block in the docstring. However, these are optional because most tool-use models ignore them. From bb6ba18060d04e3646d95512f7159679086af935 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 28 May 2024 16:21:22 +0100 Subject: [PATCH 54/69] Clean up enum regex --- src/transformers/utils/chat_template_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 81ee9a5695f005..3c669c4bc31c99 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -301,7 +301,7 @@ def get_json_schema(func): f"Cannot generate JSON schema for {func.__name__} because the docstring has no description for the argument '{arg}'" ) desc = param_descriptions[arg] - enum_choices = re.search(r"\(choices:\s*([^)]+)\)\s*$", desc, flags=re.IGNORECASE) + enum_choices = re.search(r"\(choices:\s*(.*?)\)\s*$", desc, flags=re.IGNORECASE) if enum_choices: json_schema["properties"][arg]["enum"] = [c.strip() for c in json.loads(enum_choices.group(1))] desc = enum_choices.string[: enum_choices.start()].strip() From eda660ff4a0b79a82233e793a14fdbf66c6c298e Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 29 May 2024 18:12:40 +0100 Subject: [PATCH 55/69] Wrap functions in {"type": "function", "function": ...} --- src/transformers/utils/chat_template_utils.py | 2 +- tests/utils/test_chat_template_utils.py | 28 
+++++++++---------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 3c669c4bc31c99..6fcaed19ab4572 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -310,4 +310,4 @@ def get_json_schema(func): output = {"name": func.__name__, "description": main_doc, "parameters": json_schema} if return_dict is not None: output["return"] = return_dict - return output + return {"type": "function", "function": output} diff --git a/tests/utils/test_chat_template_utils.py b/tests/utils/test_chat_template_utils.py index b1aa23b4f6a0be..7ebf93d0205bf1 100644 --- a/tests/utils/test_chat_template_utils.py +++ b/tests/utils/test_chat_template_utils.py @@ -39,7 +39,7 @@ def fn(x: int): "required": ["x"], }, } - self.assertEqual(schema, expected_schema) + self.assertEqual(schema["function"], expected_schema) def test_no_arguments(self): def fn(): @@ -54,7 +54,7 @@ def fn(): "description": "Test function", "parameters": {"type": "object", "properties": {}}, } - self.assertEqual(schema, expected_schema) + self.assertEqual(schema["function"], expected_schema) def test_union(self): def fn(x: Union[int, float]): @@ -76,7 +76,7 @@ def fn(x: Union[int, float]): "required": ["x"], }, } - self.assertEqual(schema, expected_schema) + self.assertEqual(schema["function"], expected_schema) def test_optional(self): def fn(x: Optional[int]): @@ -98,7 +98,7 @@ def fn(x: Optional[int]): "required": ["x"], }, } - self.assertEqual(schema, expected_schema) + self.assertEqual(schema["function"], expected_schema) def test_default_arg(self): def fn(x: int = 42): @@ -116,7 +116,7 @@ def fn(x: int = 42): "description": "Test function", "parameters": {"type": "object", "properties": {"x": {"type": "integer", "description": "The input"}}}, } - self.assertEqual(schema, expected_schema) + self.assertEqual(schema["function"], expected_schema) def 
test_nested_list(self): def fn(x: List[List[Union[str, int]]]): @@ -144,7 +144,7 @@ def fn(x: List[List[Union[str, int]]]): "required": ["x"], }, } - self.assertEqual(schema, expected_schema) + self.assertEqual(schema["function"], expected_schema) def test_multiple_arguments(self): def fn(x: int, y: str): @@ -170,7 +170,7 @@ def fn(x: int, y: str): "required": ["x", "y"], }, } - self.assertEqual(schema, expected_schema) + self.assertEqual(schema["function"], expected_schema) def test_multiple_complex_arguments(self): def fn(x: List[Union[int, float]], y: Optional[Union[int, str]] = None): @@ -200,7 +200,7 @@ def fn(x: List[Union[int, float]], y: Optional[Union[int, str]] = None): "required": ["x"], }, } - self.assertEqual(schema, expected_schema) + self.assertEqual(schema["function"], expected_schema) def test_missing_docstring(self): def fn(x: int): @@ -253,7 +253,7 @@ def fn(x: int) -> int: }, "return": {"type": "integer"}, } - self.assertEqual(schema, expected_schema) + self.assertEqual(schema["function"], expected_schema) def test_return_value_docstring(self): def fn(x: int) -> int: @@ -280,7 +280,7 @@ def fn(x: int) -> int: }, "return": {"type": "integer", "description": "The output"}, } - self.assertEqual(schema, expected_schema) + self.assertEqual(schema["function"], expected_schema) def test_tuple(self): def fn(x: Tuple[int, str]): @@ -312,7 +312,7 @@ def fn(x: Tuple[int, str]): "required": ["x"], }, } - self.assertEqual(schema, expected_schema) + self.assertEqual(schema["function"], expected_schema) def test_single_element_tuple_fails(self): def fn(x: Tuple[int]): @@ -382,7 +382,7 @@ def fn(temperature_format: str): }, } - self.assertEqual(schema, expected_schema) + self.assertEqual(schema["function"], expected_schema) def test_multiline_docstring_with_types(self): def fn(x: int, y: int): @@ -417,7 +417,7 @@ def fn(x: int, y: int): }, } - self.assertEqual(schema, expected_schema) + self.assertEqual(schema["function"], expected_schema) def 
test_everything_all_at_once(self): def fn( @@ -473,4 +473,4 @@ def fn( "description": "The output. The return description is also a big multiline\n description that spans multiple lines.", }, } - self.assertEqual(schema, expected_schema) + self.assertEqual(schema["function"], expected_schema) From 50c00e4bde7eecaa7b9e7faae87b79f84ebb03b4 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 28 May 2024 19:20:34 +0100 Subject: [PATCH 56/69] Update src/transformers/utils/chat_template_utils.py Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com> --- src/transformers/utils/chat_template_utils.py | 20 ++++++++----------- 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 6fcaed19ab4572..ee72befa662413 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -34,18 +34,14 @@ def _get_json_schema_type(param_type): - if param_type == int: - return {"type": "integer"} - elif param_type == float: - return {"type": "number"} - elif param_type == str: - return {"type": "string"} - elif param_type == bool: - return {"type": "boolean"} - elif param_type == Any: - return {} - else: - return {"type": "object"} + type_mapping = { + int: {"type": "integer"}, + float: {"type": "number"}, + str: {"type": "string"}, + bool: {"type": "boolean"}, + Any: {}, + } + return type_mapping.get(param_type, {"type": "object"}) def _parse_type_hint(hint): From 7b396ec94bfe9d7b78c1d740252d9fca09d57712 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 31 May 2024 14:10:17 +0100 Subject: [PATCH 57/69] Temporary tool calling commit --- docs/source/en/chat_templating.md | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md index 8c6d30eddde971..1650226ef1263a 100644 --- a/docs/source/en/chat_templating.md +++ 
b/docs/source/en/chat_templating.md @@ -233,7 +233,7 @@ The sun. From here, just continue training like you would with a standard language modelling task, using the `formatted_chat` column. -## Advanced: Extra inputs to chat templates +## Extra inputs to chat templates The only argument that `apply_chat_template` requires is `messages`. However, you can pass any keyword argument to `apply_chat_template` and it will be accessible inside the template. This gives you a lot of freedom to use @@ -242,9 +242,11 @@ strings, lists, dicts or whatever else you want. That said, there are some common use-cases for these extra arguments, such as passing tools for function calling, or documents for retrieval-augmented generation. In these common cases, -we have some opinionated recommendations about what the names and formats of these arguments should be. +we have some opinionated recommendations about what the names and formats of these arguments should be, which are +described in the sections below. We encourage model authors to make their chat templates compatible with this format, +to make it easy to transfer tool-calling code between models. -### Tool use / function calling +## Tool use / function calling "Tool use" LLMs can choose to call functions as external tools before generating an answer. When passing tools to a tool-use model, you can simply pass a list of functions to the `tools` argument: @@ -285,7 +287,22 @@ correctly as tools. Specifically, you should follow these rules: - The function can have a return type and a `Returns:` block in the docstring. However, these are optional because most tool-use models ignore them. -### Understanding tool schemas +### Passing tool results to the model + +The sample code above is enough to list the available tools for your model, but what happens if it wants to actually use +one? If that happens, you should: + +1. Parse the model's output to get the tool name and arguments. +2. 
Add the model's tool call to the conversation, in the format `{role: "assistant", "tool_calls": [{"name": function_name, "arguments": arguments}]}` +3. Call the corresponding function with those arguments. +4. Add the result to the conversation, in the format `{"role": "tool", "content": tool_results}` + +Here is an example conversation, containing tool calls: + +# TODO example goes here + + +### Advanced: Understanding tool schemas Each function you pass to the `tools` argument of `apply_chat_template` is converted into a [JSON schema](https://json-schema.org/learn/getting-started-step-by-step. These schemas @@ -379,7 +396,7 @@ model_input = tokenizer.apply_chat_template( ) ``` -### Retrieval-augmented generation +## Retrieval-augmented generation "Retrieval-augmented generation" or "RAG" LLMs can search a corpus of documents for information before responding to a query. This allows models to vastly expand their knowledge base beyond their limited context size. Our From f64acfd34aab740ea1fcc69be86b75cd5ddd322c Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 5 Jun 2024 17:12:34 +0100 Subject: [PATCH 58/69] Add type hints to chat template utils, partially update docs (incomplete!) --- docs/source/en/chat_templating.md | 11 ++++++----- src/transformers/utils/chat_template_utils.py | 12 ++++++------ 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md index 1650226ef1263a..bb51d441fa2661 100644 --- a/docs/source/en/chat_templating.md +++ b/docs/source/en/chat_templating.md @@ -233,7 +233,7 @@ The sun. From here, just continue training like you would with a standard language modelling task, using the `formatted_chat` column. -## Extra inputs to chat templates +## Advanced: Extra inputs to chat templates The only argument that `apply_chat_template` requires is `messages`. However, you can pass any keyword argument to `apply_chat_template` and it will be accessible inside the template. 
This gives you a lot of freedom to use @@ -246,7 +246,7 @@ we have some opinionated recommendations about what the names and formats of the described in the sections below. We encourage model authors to make their chat templates compatible with this format, to make it easy to transfer tool-calling code between models. -## Tool use / function calling +## Advanced: Tool use / function calling "Tool use" LLMs can choose to call functions as external tools before generating an answer. When passing tools to a tool-use model, you can simply pass a list of functions to the `tools` argument: @@ -283,11 +283,12 @@ correctly as tools. Specifically, you should follow these rules: - Every argument must have a type hint - The function must have a docstring in the standard Google style (in other words, an initial function description followed by an `Args:` block that describes the arguments, unless the function does not have any arguments. - Do not include types in the `Args:` block - put them in the type hints in the function header instead. +- Do not include types in the `Args:` block. In other words, write `a: The first number to multiply`, not + `a (int): The first number to multiply`. Type hints should go in the function header instead. - The function can have a return type and a `Returns:` block in the docstring. However, these are optional because most tool-use models ignore them. -### Passing tool results to the model +### Advanced: Passing tool results to the model The sample code above is enough to list the available tools for your model, but what happens if it wants to actually use one? If that happens, you should: @@ -299,7 +300,7 @@ one? 
If that happens, you should: Here is an example conversation, containing tool calls: -# TODO example goes here +# TODO example goes here after we update all the templates ### Advanced: Understanding tool schemas diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index ee72befa662413..a5bbe4592dad33 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -15,7 +15,7 @@ import inspect import json import re -from typing import Any, Union, get_args, get_origin, get_type_hints +from typing import Any, Callable, Dict, Optional, Tuple, Union, get_args, get_origin, get_type_hints BASIC_TYPES = (int, float, str, bool, Any, type(None), ...) @@ -33,7 +33,7 @@ returns_re = re.compile(r"\n\s*Returns:\n\s*(.*?)[\n\s]*(Raises:|\Z)", re.DOTALL) -def _get_json_schema_type(param_type): +def _get_json_schema_type(param_type: str) -> Dict[str, str]: type_mapping = { int: {"type": "integer"}, float: {"type": "number"}, @@ -44,7 +44,7 @@ def _get_json_schema_type(param_type): return type_mapping.get(param_type, {"type": "object"}) -def _parse_type_hint(hint): +def _parse_type_hint(hint: str) -> Dict: origin = get_origin(hint) args = get_args(hint) @@ -120,7 +120,7 @@ def _parse_type_hint(hint): raise ValueError("Couldn't parse this type hint, likely due to a custom class or object: ", hint) -def _convert_type_hints_to_json_schema(func): +def _convert_type_hints_to_json_schema(func: Callable) -> Dict: type_hints = get_type_hints(func) signature = inspect.signature(func) required = [] @@ -141,7 +141,7 @@ def _convert_type_hints_to_json_schema(func): return schema -def parse_google_format_docstring(docstring): +def parse_google_format_docstring(docstring: str) -> Tuple[Optional[str], Optional[Dict], Optional[str]]: """ Parses a Google-style docstring to extract the function description, argument descriptions, and return description. 
@@ -174,7 +174,7 @@ def parse_google_format_docstring(docstring): return description, args_dict, returns -def get_json_schema(func): +def get_json_schema(func: Callable) -> Dict: """ This function generates a JSON schema for a given function, based on its docstring and type hints. This is mostly used for passing lists of tools to a chat template. The JSON schema contains the name and description of From 4e839cb04471a7a183c0a6d730fecea06942887d Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 5 Jun 2024 17:14:33 +0100 Subject: [PATCH 59/69] Code cleanup based on @molbap's suggestion --- src/transformers/utils/chat_template_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index a5bbe4592dad33..decd08ce655087 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -291,7 +291,7 @@ def get_json_schema(func: Callable) -> Dict: if (return_dict := json_schema["properties"].pop("return", None)) is not None: if return_doc is not None: # We allow a missing return docstring since most templates ignore it return_dict["description"] = return_doc - for arg in json_schema["properties"]: + for arg, schema in json_schema["properties"].items(): if arg not in param_descriptions: raise ValueError( f"Cannot generate JSON schema for {func.__name__} because the docstring has no description for the argument '{arg}'" @@ -299,9 +299,9 @@ def get_json_schema(func: Callable) -> Dict: desc = param_descriptions[arg] enum_choices = re.search(r"\(choices:\s*(.*?)\)\s*$", desc, flags=re.IGNORECASE) if enum_choices: - json_schema["properties"][arg]["enum"] = [c.strip() for c in json.loads(enum_choices.group(1))] + schema["enum"] = [c.strip() for c in json.loads(enum_choices.group(1))] desc = enum_choices.string[: enum_choices.start()].strip() - json_schema["properties"][arg]["description"] = desc + schema["description"] 
= desc output = {"name": func.__name__, "description": main_doc, "parameters": json_schema} if return_dict is not None: From 8d7ca8f2e6ee01e1acf04aff0af8f4b5a5405a09 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 5 Jun 2024 17:32:23 +0100 Subject: [PATCH 60/69] Add comments to explain regexes --- src/transformers/utils/chat_template_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index decd08ce655087..ed462550fc1e88 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -19,8 +19,11 @@ BASIC_TYPES = (int, float, str, bool, Any, type(None), ...) +# Extracts the initial segment of the docstring, containing the function description description_re = re.compile(r"^(.*?)[\n\s]*(Args:|Returns:|Raises:|\Z)", re.DOTALL) +# Extracts the Args: block from the docstring args_re = re.compile(r"\n\s*Args:\n\s*(.*?)[\n\s]*(Returns:|Raises:|\Z)", re.DOTALL) +# Splits the Args: block into individual arguments args_split_re = re.compile( r""" (?:^|\n) # Match the start of the args block, or a newline @@ -30,6 +33,7 @@ """, re.DOTALL | re.VERBOSE, ) +# Extracts the Returns: block from the docstring, if present. Note that most chat templates ignore the return type/doc! 
returns_re = re.compile(r"\n\s*Returns:\n\s*(.*?)[\n\s]*(Raises:|\Z)", re.DOTALL) From 3fbd682365b8a5e80d0d940650089280c64e24d6 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 5 Jun 2024 18:25:51 +0100 Subject: [PATCH 61/69] Fix up type parsing for unions and lists --- src/transformers/utils/chat_template_utils.py | 55 +++++++------------ 1 file changed, 21 insertions(+), 34 deletions(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index ed462550fc1e88..9ea019bd4acf23 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -53,48 +53,35 @@ def _parse_type_hint(hint: str) -> Dict: args = get_args(hint) if origin is None: - return _get_json_schema_type(hint) - - if origin is Union: - # If it's a union of basic types, we can express that as a simple list in the schema - if all(t in BASIC_TYPES for t in args): - return_dict = {"type": [_get_json_schema_type(t)["type"] for t in args if t not in (type(None), ...)]} - if len(return_dict["type"]) == 1: - return_dict["type"] = return_dict["type"][0] + try: + return _get_json_schema_type(hint) + except KeyError: + raise ValueError("Couldn't parse this type hint, likely due to a custom class or object: ", hint) + + elif origin is Union: + # Recurse into each of the subtypes in the Union, except None, which is handled separately at the end + subtypes = [_parse_type_hint(t) for t in args if t != type(None)] + if len(subtypes) == 1: + # A single non-null type can be expressed directly + return_dict = subtypes[0] + elif all(isinstance(subtype["type"], str) for subtype in subtypes): + # A union of basic types can be expressed as a list in the schema + return_dict = {"type": [subtype["type"] for subtype in subtypes]} else: - # A union of more complex types requires us to recurse into each subtype - return_dict = { - "anyOf": [_parse_type_hint(t) for t in args if t not in (type(None), ...)], - } - if 
len(return_dict["anyOf"]) == 1: - return_dict = return_dict["anyOf"][0] + # A union of more complex types requires "anyOf" + return_dict = {"anyOf": subtypes} if type(None) in args: return_dict["nullable"] = True return return_dict - if origin is list: + elif origin is list: if not args: return {"type": "array"} - - # Similarly to unions, a list of basic types can be expressed as a list in the schema - if all(t in BASIC_TYPES for t in args): - items = {"type": [_get_json_schema_type(t)["type"] for t in args if t != type(None)]} - if len(items["type"]) == 1: - items["type"] = items["type"][0] else: - # And a list of more complex types requires us to recurse into each subtype again - items = {"anyOf": [_parse_type_hint(t) for t in args if t not in (type(None), ...)]} - if len(items["anyOf"]) == 1: - items = items["anyOf"][0] - - return_dict = {"type": "array", "items": items} - - if type(None) in args: - return_dict["nullable"] = True - - return return_dict + # Lists can only have a single type argument, so recurse into it + return {"type": "array", "items": _parse_type_hint(args[0])} - if origin is tuple: + elif origin is tuple: if not args: return {"type": "array"} if len(args) == 1: @@ -113,7 +100,7 @@ def _parse_type_hint(hint: str) -> Dict: ) return {"type": "array", "prefixItems": [_parse_type_hint(t) for t in args]} - if origin is dict: + elif origin is dict: # The JSON equivalent to a dict is 'object', which mandates that all keys are strings # However, we can specify the type of the dict values with "additionalProperties" out = {"type": "object"} From 76c3320079490a01d68dd20217ae6b9b1ec594f1 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 6 Jun 2024 13:18:57 +0100 Subject: [PATCH 62/69] Add custom exception types and adjust tests to look for them --- src/transformers/utils/__init__.py | 2 +- src/transformers/utils/chat_template_utils.py | 30 ++++++++++++++----- tests/utils/test_chat_template_utils.py | 12 ++++---- 3 files changed, 30 insertions(+), 14 
deletions(-) diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index ac2ecaef3a7fd6..ce87bc8623132e 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -21,7 +21,7 @@ from .. import __version__ from .backbone_utils import BackboneConfigMixin, BackboneMixin -from .chat_template_utils import get_json_schema +from .chat_template_utils import DocstringParsingException, TypeHintParsingException, get_json_schema from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD from .doc import ( add_code_sample_docstrings, diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 9ea019bd4acf23..ee6173f2a1532b 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -37,6 +37,18 @@ returns_re = re.compile(r"\n\s*Returns:\n\s*(.*?)[\n\s]*(Raises:|\Z)", re.DOTALL) +class TypeHintParsingException(Exception): + """Exception raised for errors in parsing type hints to generate JSON schemas""" + + pass + + +class DocstringParsingException(Exception): + """Exception raised for errors in parsing docstrings to generate JSON schemas""" + + pass + + def _get_json_schema_type(param_type: str) -> Dict[str, str]: type_mapping = { int: {"type": "integer"}, @@ -56,7 +68,9 @@ def _parse_type_hint(hint: str) -> Dict: try: return _get_json_schema_type(hint) except KeyError: - raise ValueError("Couldn't parse this type hint, likely due to a custom class or object: ", hint) + raise TypeHintParsingException( + "Couldn't parse this type hint, likely due to a custom class or object: ", hint + ) elif origin is Union: # Recurse into each of the subtypes in the Union, except None, which is handled separately at the end @@ -85,7 +99,7 @@ def _parse_type_hint(hint: str) -> Dict: if not args: return {"type": "array"} if len(args) == 1: - raise ValueError( + raise 
TypeHintParsingException( f"The type hint {str(hint).replace('typing.', '')} is a Tuple with a single element, which " "we do not automatically convert to JSON schema as it is rarely necessary. If this input can contain " "more than one element, we recommend " @@ -93,7 +107,7 @@ def _parse_type_hint(hint: str) -> Dict: "pass the element directly." ) if ... in args: - raise ValueError( + raise TypeHintParsingException( "Conversion of '...' is not supported in Tuple type hints. " "Use List[] types for variable-length" " inputs instead." @@ -108,7 +122,7 @@ def _parse_type_hint(hint: str) -> Dict: out["additionalProperties"] = _parse_type_hint(args[1]) return out - raise ValueError("Couldn't parse this type hint, likely due to a custom class or object: ", hint) + raise TypeHintParsingException("Couldn't parse this type hint, likely due to a custom class or object: ", hint) def _convert_type_hints_to_json_schema(func: Callable) -> Dict: @@ -117,7 +131,7 @@ def _convert_type_hints_to_json_schema(func: Callable) -> Dict: required = [] for param_name, param in signature.parameters.items(): if param.annotation == inspect.Parameter.empty: - raise ValueError(f"Argument {param.name} is missing a type hint in function {func.__name__}") + raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}") if param.default == inspect.Parameter.empty: required.append(param_name) @@ -274,7 +288,9 @@ def get_json_schema(func: Callable) -> Dict: """ doc = inspect.getdoc(func) if not doc: - raise ValueError(f"Cannot generate JSON schema for {func.__name__} because it has no docstring!") + raise DocstringParsingException( + f"Cannot generate JSON schema for {func.__name__} because it has no docstring!" 
+ ) doc = doc.strip() main_doc, param_descriptions, return_doc = parse_google_format_docstring(doc) @@ -284,7 +300,7 @@ def get_json_schema(func: Callable) -> Dict: return_dict["description"] = return_doc for arg, schema in json_schema["properties"].items(): if arg not in param_descriptions: - raise ValueError( + raise DocstringParsingException( f"Cannot generate JSON schema for {func.__name__} because the docstring has no description for the argument '{arg}'" ) desc = param_descriptions[arg] diff --git a/tests/utils/test_chat_template_utils.py b/tests/utils/test_chat_template_utils.py index 7ebf93d0205bf1..cff31c1f8a3483 100644 --- a/tests/utils/test_chat_template_utils.py +++ b/tests/utils/test_chat_template_utils.py @@ -15,7 +15,7 @@ import unittest from typing import List, Optional, Tuple, Union -from transformers.utils import get_json_schema +from transformers.utils import DocstringParsingException, TypeHintParsingException, get_json_schema class JsonSchemaGeneratorTest(unittest.TestCase): @@ -206,7 +206,7 @@ def test_missing_docstring(self): def fn(x: int): return x - with self.assertRaises(ValueError): + with self.assertRaises(DocstringParsingException): get_json_schema(fn) def test_missing_param_docstring(self): @@ -216,7 +216,7 @@ def fn(x: int): """ return x - with self.assertRaises(ValueError): + with self.assertRaises(DocstringParsingException): get_json_schema(fn) def test_missing_type_hint(self): @@ -229,7 +229,7 @@ def fn(x): """ return x - with self.assertRaises(ValueError): + with self.assertRaises(TypeHintParsingException): get_json_schema(fn) def test_return_value(self): @@ -329,7 +329,7 @@ def fn(x: Tuple[int]): return x # Single-element tuples should just be the type itself, or List[type] for variable-length inputs - with self.assertRaises(ValueError): + with self.assertRaises(TypeHintParsingException): get_json_schema(fn) def test_ellipsis_type_fails(self): @@ -347,7 +347,7 @@ def fn(x: Tuple[int, ...]): return x # Variable length inputs 
should be specified with List[type], not Tuple[type, ...] - with self.assertRaises(ValueError): + with self.assertRaises(TypeHintParsingException): get_json_schema(fn) def test_enum_extraction(self): From 170e3675b07064a68ef04ccdbd5453aebc6f0613 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 6 Jun 2024 21:03:52 +0100 Subject: [PATCH 63/69] Update docs with a demo! --- docs/source/en/chat_templating.md | 94 ++++++++++++++++++++++++++++++- 1 file changed, 92 insertions(+), 2 deletions(-) diff --git a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md index bb51d441fa2661..c79c167d1016dd 100644 --- a/docs/source/en/chat_templating.md +++ b/docs/source/en/chat_templating.md @@ -298,10 +298,100 @@ one? If that happens, you should: 3. Call the corresponding function with those arguments. 4. Add the result to the conversation, in the format `{"role": "tool", "content": tool_results}` -Here is an example conversation, containing tool calls: +### Advanced: A complete tool use example -# TODO example goes here after we update all the templates +Let's walk through a tool use example, step by step. For this example, we will use an 8B `Hermes-2-Pro` model, +as it is one of the highest-performing tool-use models in its size category at the time of writing. If you have the +memory, you can consider using a larger model instead, like [Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-v01) +or [Mixtral-8x22B](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1), both of which also support tool use. +First, let's load our model and tokenizer: + +```python +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +checkpoint = "NousResearch/Hermes-2-Pro-Llama-3-8B" + +tokenizer = AutoTokenizer.from_pretrained(checkpoint, revision="pr/13") +model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, device_map="auto") +``` + +Next, let's define a list of tools. 
For simplicity, we'll just have a single tool in this example: + +```python +def get_current_temperature(location: str, unit: str) -> float: + """ + Get the current temperature at a location. + + Args: + location: The location to get the temperature for, in the format "City, Country" + unit: The unit to return the temperature in. (choices: ["celsius", "fahrenheit"]) + Returns: + The current temperature in the specified units, as a float. + """ + return 22. # Your real function should probably actually get the temperature! + +tools = [get_current_temperature] +``` + +Now, let's set up a conversation for our bot: + +```python +messages = [ + {"role": "system", "content": "You are a bot that responds to temperature queries. You should choose the unit used in the queried location."}, + {"role": "user", "content": "Hey, what's the temperature in Paris right now?"} +] +``` + +Now, let's apply the chat template and generate a response: + +```python +inputs = tokenizer.apply_chat_template(messages, chat_template="tool_use", tools=tools, add_generation_prompt=True, return_dict=True, return_tensors="pt") +inputs = {k: v.to(model.device) for k, v in inputs.items()} +out = model.generate(**inputs, max_new_tokens=128) +print(tokenizer.decode(out[0][len(inputs["input_ids"][0]):])) +``` + +And we get: + +```text + +{"arguments": {"location": "Paris, France", "unit": "celsius"}, "name": "get_current_temperature"} +<|im_end|> +``` + +The model has called the function with valid arguments, in the format requested by the function docstring. It has +inferred that we're most likely referring to the Paris in France, and it remembered that, as the home of SI units, +the temperature in France should certainly be displayed in Celsius. + +Let's append the model's tool call to the conversation, followed by the result of calling the tool. Remember, in +reality this is the point where you'd actually call the function, rather than just using a dummy +result! 
+ +```python +messages.append({"role": "assistant", "tool_calls": [{"name": "get_current_temperature", "arguments": {"location": "Paris, France", "unit": "celsius"}}]}) +messages.append({"role": "tool", "name": "get_current_temperature", "content": 22.}) +``` + +Finally, let's let the assistant read the function outputs and continue chatting with the user: + +```python +inputs = tokenizer.apply_chat_template(messages, chat_template="tool_use", tools=tools, add_generation_prompt=True, return_dict=True, return_tensors="pt") +inputs = {k: v.to(model.device) for k, v in inputs.items()} +out = model.generate(**inputs, max_new_tokens=128) +print(tokenizer.decode(out[0][len(inputs["input_ids"][0]):])) +``` + +And we get: + +```text +The current temperature in Paris, France is 22.0°C (71.6°F). +``` + +Although this was a simple demo with only a single call, the same technique works with +multiple tools and longer conversations. This can be a powerful ways to extend the capabilities of conversational +agents with real-time information, computational tools like calculators, or access to large databases. ### Advanced: Understanding tool schemas From 1d4952a1a602499683dfd4436a11fe1db3b76fc0 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 7 Jun 2024 13:54:09 +0100 Subject: [PATCH 64/69] Docs cleanup --- docs/source/en/chat_templating.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md index c79c167d1016dd..76483cb7efd4d7 100644 --- a/docs/source/en/chat_templating.md +++ b/docs/source/en/chat_templating.md @@ -288,7 +288,7 @@ correctly as tools. Specifically, you should follow these rules: - The function can have a return type and a `Returns:` block in the docstring. However, these are optional because most tool-use models ignore them. 
-### Advanced: Passing tool results to the model +### Passing tool results to the model The sample code above is enough to list the available tools for your model, but what happens if it wants to actually use one? If that happens, you should: @@ -298,11 +298,11 @@ one? If that happens, you should: 3. Call the corresponding function with those arguments. 4. Add the result to the conversation, in the format `{"role": "tool", "content": tool_results}` -### Advanced: A complete tool use example +### A complete tool use example Let's walk through a tool use example, step by step. For this example, we will use an 8B `Hermes-2-Pro` model, as it is one of the highest-performing tool-use models in its size category at the time of writing. If you have the -memory, you can consider using a larger model instead, like [Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-v01) +memory, you can consider using a larger model instead like [Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-v01) or [Mixtral-8x22B](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1), both of which also support tool use. First, let's load our model and tokenizer: @@ -393,10 +393,10 @@ Although this was a simple demo with only a single call, the same technique work multiple tools and longer conversations. This can be a powerful ways to extend the capabilities of conversational agents with real-time information, computational tools like calculators, or access to large databases. -### Advanced: Understanding tool schemas +### Understanding tool schemas Each function you pass to the `tools` argument of `apply_chat_template` is converted into a -[JSON schema](https://json-schema.org/learn/getting-started-step-by-step. These schemas +[JSON schema](https://json-schema.org/learn/getting-started-step-by-step). These schemas are then passed to the model chat template. In other words, tool-use models do not see your functions directly, and they never see the actual code inside them. 
What they care about is the function **definitions** and the **arguments** they need to pass to them - they care about what the tools do and how to use them, not how they work! It is up to you @@ -487,7 +487,7 @@ model_input = tokenizer.apply_chat_template( ) ``` -## Retrieval-augmented generation +## Advanced: Retrieval-augmented generation "Retrieval-augmented generation" or "RAG" LLMs can search a corpus of documents for information before responding to a query. This allows models to vastly expand their knowledge base beyond their limited context size. Our From 9f71f6c63aad9dd02fa08c549bd93f404798cb24 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 7 Jun 2024 13:54:51 +0100 Subject: [PATCH 65/69] Pass content as string --- docs/source/en/chat_templating.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md index 76483cb7efd4d7..fe6cb3dc187a94 100644 --- a/docs/source/en/chat_templating.md +++ b/docs/source/en/chat_templating.md @@ -371,7 +371,7 @@ result! 
```python messages.append({"role": "assistant", "tool_calls": [{"name": "get_current_temperature", "arguments": {"location": "Paris, France", "unit": "celsius"}}]}) -messages.append({"role": "tool", "name": "get_current_temperature", "content": 22.}) +messages.append({"role": "tool", "name": "get_current_temperature", "content": "22.0"}) ``` Finally, let's let the assistant read the function outputs and continue chatting with the user: From 6c29d46559cf2d84544ba86fad0301157c3f96eb Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 10 Jun 2024 18:24:29 +0100 Subject: [PATCH 66/69] Update tool call formatting --- docs/source/en/chat_templating.md | 39 +++++++++++++++++++++++-------- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md index fe6cb3dc187a94..a2efecf866140d 100644 --- a/docs/source/en/chat_templating.md +++ b/docs/source/en/chat_templating.md @@ -293,17 +293,18 @@ correctly as tools. Specifically, you should follow these rules: The sample code above is enough to list the available tools for your model, but what happens if it wants to actually use one? If that happens, you should: -1. Parse the model's output to get the tool name and arguments. -2. Add the model's tool call to the conversation, in the format `{role: "assistant", "tool_calls": [{"name": function_name, "arguments": arguments}]}` -3. Call the corresponding function with those arguments. -4. Add the result to the conversation, in the format `{"role": "tool", "content": tool_results}` +1. Parse the model's output to get the tool name(s) and arguments. +2. Add the model's tool call(s) to the conversation. +3. Call the corresponding function(s) with those arguments. +4. Add the result(s) to the conversation ### A complete tool use example Let's walk through a tool use example, step by step. 
For this example, we will use an 8B `Hermes-2-Pro` model, as it is one of the highest-performing tool-use models in its size category at the time of writing. If you have the memory, you can consider using a larger model instead like [Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-v01) -or [Mixtral-8x22B](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1), both of which also support tool use. +or [Mixtral-8x22B](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1), both of which also support tool use +and offer even stronger performance. First, let's load our model and tokenizer: @@ -365,13 +366,23 @@ The model has called the function with valid arguments, in the format requested inferred that we're most likely referring to the Paris in France, and it remembered that, as the home of SI units, the temperature in France should certainly be displayed in Celsius. -Let's append the model's tool call to the conversation, followed by the result of calling the tool. Remember, in -reality this is the point where you'd actually call the function, rather than just using a dummy -result! +Let's append the model's tool call to the conversation. Note that we generate a random `tool_call_id` here. These IDs +are not used by all models, but they allow models to issue multiple tool calls at once and keep track of which response +corresponds to which call. You can generate them any way you like, but they should be unique within each chat. 
```python -messages.append({"role": "assistant", "tool_calls": [{"name": "get_current_temperature", "arguments": {"location": "Paris, France", "unit": "celsius"}}]}) -messages.append({"role": "tool", "name": "get_current_temperature", "content": "22.0"}) +tool_call_id = "vAHdf3" # Random ID, should be unique for each tool call +tool_call = {"name": "get_current_temperature", "arguments": {"location": "Paris, France", "unit": "celsius"}} +messages.append({"role": "assistant", "tool_calls": [{"id": tool_call_id, "type": "function", "function": tool_call}]}) +``` + + +Now that we've added the tool call to the conversation, we can call the function and append the result to the +conversation. Since we're just using a dummy function for this example that always returns 22.0, we can just append +that result directly. Again, note the `tool_call_id` - this should match the ID used in the tool call above. + +```python +messages.append({"role": "tool", "tool_call_id": tool_call_id, "name": "get_current_temperature", "content": "22.0"}) ``` Finally, let's let the assistant read the function outputs and continue chatting with the user: @@ -393,6 +404,14 @@ Although this was a simple demo with only a single call, the same technique work multiple tools and longer conversations. This can be a powerful ways to extend the capabilities of conversational agents with real-time information, computational tools like calculators, or access to large databases. + +Not all of the tool-calling features shown above are used by all models. Some use tool call IDs, others simply use the function name and +match tool calls to results using the ordering, and there are several models that use neither and only issue one tool +call at a time to avoid confusion. If you want your code to be compatible across as many models as possible, we +recommend structuring your tools calls like we've shown here, and returning tool results in the order that +they were issued by the model. 
The chat templates on each model should handle the rest. + + ### Understanding tool schemas Each function you pass to the `tools` argument of `apply_chat_template` is converted into a From 0ab8c7f3ee47fdf57836c43abad23b42f5eed0f3 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 11 Jun 2024 13:24:58 +0100 Subject: [PATCH 67/69] Update docs with new function format --- docs/source/en/chat_templating.md | 64 +++++++++++++++++++------------ 1 file changed, 39 insertions(+), 25 deletions(-) diff --git a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md index a2efecf866140d..00600fb28484d5 100644 --- a/docs/source/en/chat_templating.md +++ b/docs/source/en/chat_templating.md @@ -401,7 +401,7 @@ The current temperature in Paris, France is 22.0°C (71.6°F). ``` Although this was a simple demo with only a single call, the same technique works with -multiple tools and longer conversations. This can be a powerful ways to extend the capabilities of conversational +multiple tools and longer conversations. This can be a powerful way to extend the capabilities of conversational agents with real-time information, computational tools like calculators, or access to large databases. @@ -447,22 +447,25 @@ This will yield: ```json { + "type": "function", + "function": { "name": "multiply", - "description": "Multiply two numbers together.", + "description": "A function that multiplies two numbers", "parameters": { - "type": "object", - "properties": { - "a": { - "type": "number", - "description": "The first number to multiply." - }, - "b": { - "type": "number", - "description": "The second number to multiply." 
- } + "type": "object", + "properties": { + "a": { + "type": "number", + "description": "The first number to multiply" }, - "required": ["a", "b"] + "b": { + "type": "number", + "description": "The second number to multiply" + } + }, + "required": ["a", "b"] } + } } ``` @@ -478,27 +481,38 @@ Here is an example of defining schemas by hand, and passing them directly to `ap ```python # A simple function that takes no arguments current_time = { + "type": "function", + "function": { "name": "current_time", "description": "Get the current local time as a string.", "parameters": { - 'type': 'object', - 'properties': {} - }, + 'type': 'object', + 'properties': {} } + } +} # A more complete function that takes two numerical arguments multiply = { - "name": "multiply", - "description": "Multiply two numbers together.", - "parameters": { - "type": "object", - "properties": { - "a": {"type": "number", "description": "The first number to multiply."}, - "b": {"type": "number", "description": "The second number to multiply."}, - }, - "required": ["a", "b"], + 'type': 'function', + 'function': { + 'name': 'multiply', + 'description': 'A function that multiplies two numbers', + 'parameters': { + 'type': 'object', + 'properties': { + 'a': { + 'type': 'number', + 'description': 'The first number to multiply' + }, + 'b': { + 'type': 'number', 'description': 'The second number to multiply' } + }, + 'required': ['a', 'b'] } + } +} model_input = tokenizer.apply_chat_template( messages, From 2b3e222772d78cffa2726b9dcebcdf4301e824ff Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 11 Jun 2024 15:13:55 +0100 Subject: [PATCH 68/69] Update docs --- docs/source/en/chat_templating.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md index 00600fb28484d5..d3a68958fdd2fa 100644 --- a/docs/source/en/chat_templating.md +++ b/docs/source/en/chat_templating.md @@ -340,7 +340,7 @@ Now, let's set up a conversation 
for our bot: ```python messages = [ - {"role": "system", "content": "You are a bot that responds to temperature queries. You should choose the unit used in the queried location."}, + {"role": "system", "content": "You are a bot that responds to weather queries. You should reply with the unit used in the queried location."}, {"role": "user", "content": "Hey, what's the temperature in Paris right now?"} ] ``` @@ -397,7 +397,7 @@ print(tokenizer.decode(out[0][len(inputs["input_ids"][0]):])) And we get: ```text -The current temperature in Paris, France is 22.0°C (71.6°F). +The current temperature in Paris, France is 22.0 degrees Celsius.<|im_end|> ``` Although this was a simple demo with only a single call, the same technique works with From 2b19b0a3e6b385ac7886d92775da628373b2601d Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 11 Jun 2024 15:20:05 +0100 Subject: [PATCH 69/69] Update docs with a second tool to show the model choosing correctly --- docs/source/en/chat_templating.md | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md index d3a68958fdd2fa..8d49fa5c80ee28 100644 --- a/docs/source/en/chat_templating.md +++ b/docs/source/en/chat_templating.md @@ -318,7 +318,7 @@ tokenizer = AutoTokenizer.from_pretrained(checkpoint, revision="pr/13") model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, device_map="auto") ``` -Next, let's define a list of tools. For simplicity, we'll just have a single tool in this example: +Next, let's define a list of tools: ```python def get_current_temperature(location: str, unit: str) -> float: @@ -329,11 +329,22 @@ def get_current_temperature(location: str, unit: str) -> float: location: The location to get the temperature for, in the format "City, Country" unit: The unit to return the temperature in. 
(choices: ["celsius", "fahrenheit"]) Returns: - The current temperature in the specified units, as a float. + The current temperature at the specified location in the specified units, as a float. """ - return 22. # Your real function should probably actually get the temperature! + return 22. # A real function should probably actually get the temperature! -tools = [get_current_temperature] +def get_current_wind_speed(location: str) -> float: + """ + Get the current wind speed in km/h at a given location. + + Args: + location: The location to get the temperature for, in the format "City, Country" + Returns: + The current wind speed at the given location in km/h, as a float. + """ + return 6. # A real function should probably actually get the wind speed! + +tools = [get_current_temperature, get_current_wind_speed] ``` Now, let's set up a conversation for our bot: @@ -397,11 +408,11 @@ print(tokenizer.decode(out[0][len(inputs["input_ids"][0]):])) And we get: ```text -The current temperature in Paris, France is 22.0 degrees Celsius.<|im_end|> +The current temperature in Paris, France is 22.0 ° Celsius.<|im_end|> ``` -Although this was a simple demo with only a single call, the same technique works with -multiple tools and longer conversations. This can be a powerful way to extend the capabilities of conversational +Although this was a simple demo with dummy tools and a single call, the same technique works with +multiple real tools and longer conversations. This can be a powerful way to extend the capabilities of conversational agents with real-time information, computational tools like calculators, or access to large databases.