diff --git a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md index bb8ec1c450d3aa..169281d115b98c 100644 --- a/docs/source/en/chat_templating.md +++ b/docs/source/en/chat_templating.md @@ -336,19 +336,16 @@ This will yield: } ``` -We can use this function, or the equivalent [`add_json_schema`] decorator, to avoid the need to manually write JSON -schemas when passing tools to the chat template: +We can use this function to avoid the need to manually write JSON schemas when passing tools to the chat template. +In addition, if you pass functions in the `tools` argument, they will automatically be converted with this function: ```python import datetime -from transformers.utils import add_json_schema -@add_json_schema def current_time(): """Get the current local time as a string.""" return str(datetime.now()) -@add_json_schema def multiply(a: float, b: float): """ A function that multiplies two numbers @@ -369,7 +366,7 @@ model_input = tokenizer.apply_chat_template( #### Notes on automatic conversion -`get_json_schema` and `add_json_schema` both expect a specific docstring format. The docstring should +`get_json_schema` expects a specific docstring format. The docstring should begin with a description of the function, followed by an `Args:` block that describes each argument. It can also optionally include a `Returns:` block that describes the value(s) returned by the function. Many templates ignore this, because the model will see the return format after calling the function anyway, but some require it. 
diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 7ded690ade1ca0..881bcb1f4b7a87 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -28,6 +28,7 @@ from contextlib import contextmanager from dataclasses import dataclass from functools import lru_cache +from inspect import isfunction from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union import numpy as np @@ -47,6 +48,7 @@ copy_func, download_url, extract_commit_hash, + get_json_schema, is_flax_available, is_jax_tensor, is_mlx_available, @@ -1817,10 +1819,21 @@ def apply_chat_template( conversations = [conversation] is_batched = False - # The add_json_schema decorator for tools adds a schema under the `json_schema` attribute. If we're passed - # decorated functions, let's extract the schema decoration now + # We accept either JSON schemas or functions for tools. If we get functions, we convert them to schemas if tools is not None: - tools = [tool.json_schema if hasattr(tool, "json_schema") else tool for tool in tools] + tool_schemas = [] + for tool in tools: + if isinstance(tool, dict): + tool_schemas.append(tool) + elif isfunction(tool): + tool_schemas.append(get_json_schema(tool)) + else: + raise ValueError( + "Tools should either be a JSON schema, or a callable function with type hints " + "and a docstring suitable for auto-conversion to a schema." 
+ ) + else: + tool_schemas = None rendered = [] template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present @@ -1830,7 +1843,7 @@ def apply_chat_template( chat = chat.messages rendered_chat = compiled_template.render( messages=chat, - tools=tools, + tools=tool_schemas, documents=documents, add_generation_prompt=add_generation_prompt, **template_kwargs, diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 4642d8e88a38ff..75108c6975ae8a 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -21,7 +21,7 @@ from .. import __version__ from .backbone_utils import BackboneConfigMixin, BackboneMixin -from .chat_template_utils import add_json_schema, get_json_schema +from .chat_template_utils import get_json_schema from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD from .doc import ( add_code_sample_docstrings, diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 8d9fad6ef855a1..0cb74dc973c5fe 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -78,8 +78,8 @@ def get_json_schema(func): >>> # The formatted chat can now be passed to model.generate() ``` - In many cases, it is more convenient to define tool functions with the [`add_json_schema`] decorator rather than - calling this function directly. + In many cases, it is more convenient to simply pass the functions directly to apply_chat_template and let it + autogenerate schemas than to call this function directly. """ doc = inspect.getdoc(func) if not doc: @@ -104,44 +104,6 @@ def get_json_schema(func): return output -def add_json_schema(func): - """ - This decorator adds a JSON schema to a function, based on its docstring and type hints. 
The JSON schema is the - same as the one generated by the [`get_json_schema`] function. It is stored in the `json_schema` attribute of the - function, which will be automatically read by `apply_chat_template()` if present. - - Example: - - ```python - >>> from transformers import AutoTokenizer - >>> from transformers.utils import get_json_schema - >>> - >>> @add_json_schema - >>> def multiply(x: float, y: float): - >>> ''' - >>> A function that multiplies two numbers - >>> - >>> :param x: The first number to multiply - >>> :param y: The second number to multiply - >>> ''' - >>> return x * y - >>> - >>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01") - >>> messages = [{"role": "user", "content": "What is 179 x 4571?"}] - >>> formatted_chat = tokenizer.apply_chat_template( - >>> messages, - >>> tools=[multiply], - >>> chat_template="tool_use", - >>> return_dict=True, - >>> return_tensors="pt", - >>> add_generation_prompt=True - >>> ) - >>> # The formatted chat can now be passed to model.generate() - """ - func.json_schema = get_json_schema(func) - return func - - def parse_google_format_docstring(docstring): """ Parses a Google-style docstring to extract the function description,