
Commit

No more decorator - we just do it implicitly!
Rocketknight1 committed May 23, 2024
1 parent cbb0ad4 commit 9ec58d3
Showing 4 changed files with 23 additions and 51 deletions.
9 changes: 3 additions & 6 deletions docs/source/en/chat_templating.md
@@ -336,19 +336,16 @@ This will yield:
}
```

-We can use this function, or the equivalent [`add_json_schema`] decorator, to avoid the need to manually write JSON
-schemas when passing tools to the chat template:
+We can use this function to avoid the need to manually write JSON schemas when passing tools to the chat template.
+In addition, if you pass functions in the `tools` argument, they will automatically be converted with this function:

```python
import datetime
-from transformers.utils import add_json_schema

-@add_json_schema
def current_time():
    """Get the current local time as a string."""
    return str(datetime.datetime.now())

-@add_json_schema
def multiply(a: float, b: float):
    """
    A function that multiplies two numbers
@@ -369,7 +366,7 @@ model_input = tokenizer.apply_chat_template(

#### Notes on automatic conversion

-`get_json_schema` and `add_json_schema` both expect a specific docstring format. The docstring should
+`get_json_schema` expects a specific docstring format. The docstring should
begin with a description of the function, followed by an `Args:` block that describes each argument. It can also
optionally include a `Returns:` block that describes the value(s) returned by the function. Many templates ignore this,
because the model will see the return format after calling the function anyway, but some require it.
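
(Not part of the commit; a minimal illustrative sketch.) A tool written in the docstring format described above, together with a call to `get_json_schema`, might look something like the following. The function name, arguments, and docstring text here are hypothetical:

```python
from transformers.utils import get_json_schema

# Hypothetical tool following the documented format: a description, an `Args:`
# entry for each argument, and an optional `Returns:` block.
def count_words(text: str, ignore_punctuation: bool):
    """
    Counts the number of words in a piece of text.

    Args:
        text: The text to count words in.
        ignore_punctuation: Whether to strip punctuation before counting.

    Returns:
        The number of words as an integer.
    """
    if ignore_punctuation:
        text = "".join(char for char in text if char.isalnum() or char.isspace())
    return len(text.split())

# The description, `Args:` entries and type hints above are what get_json_schema()
# uses to build the schema dict that the chat template receives.
print(get_json_schema(count_words))
```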
21 changes: 17 additions & 4 deletions src/transformers/tokenization_utils_base.py
@@ -28,6 +28,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from functools import lru_cache
+from inspect import isfunction
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union

import numpy as np
@@ -47,6 +48,7 @@
copy_func,
download_url,
extract_commit_hash,
+get_json_schema,
is_flax_available,
is_jax_tensor,
is_mlx_available,
@@ -1817,10 +1819,21 @@ def apply_chat_template(
conversations = [conversation]
is_batched = False

-# The add_json_schema decorator for tools adds a schema under the `json_schema` attribute. If we're passed
-# decorated functions, let's extract the schema decoration now
+# We accept either JSON schemas or functions for tools. If we get functions, we convert them to schemas
if tools is not None:
-    tools = [tool.json_schema if hasattr(tool, "json_schema") else tool for tool in tools]
+    tool_schemas = []
+    for tool in tools:
+        if isinstance(tool, dict):
+            tool_schemas.append(tool)
+        elif isfunction(tool):
+            tool_schemas.append(get_json_schema(tool))
+        else:
+            raise ValueError(
+                "Tools should either be a JSON schema, or a callable function with type hints "
+                "and a docstring suitable for auto-conversion to a schema."
+            )
+else:
+    tool_schemas = None

rendered = []
template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present
@@ -1830,7 +1843,7 @@ def apply_chat_template(
chat = chat.messages
rendered_chat = compiled_template.render(
messages=chat,
-tools=tools,
+tools=tool_schemas,
documents=documents,
add_generation_prompt=add_generation_prompt,
**template_kwargs,
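(Not part of the commit; a rough usage sketch of the new behavior.) With this change, plain Python functions can be passed straight to `apply_chat_template` via `tools` and are converted to JSON schemas internally, while anything that is neither a dict schema nor a function raises the `ValueError` above. Something like:

```python
from transformers import AutoTokenizer

def multiply(a: float, b: float):
    """
    A function that multiplies two numbers

    Args:
        a: The first number to multiply
        b: The second number to multiply
    """
    return a * b

tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
messages = [{"role": "user", "content": "What is 179 x 4571?"}]

# `multiply` is an undecorated function: apply_chat_template() now converts it
# to a JSON schema via get_json_schema() before rendering the template.
model_input = tokenizer.apply_chat_template(
    messages,
    tools=[multiply],
    chat_template="tool_use",
    add_generation_prompt=True,
    return_dict=True,
    return_tensors="pt",
)
```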
2 changes: 1 addition & 1 deletion src/transformers/utils/__init__.py
@@ -21,7 +21,7 @@

from .. import __version__
from .backbone_utils import BackboneConfigMixin, BackboneMixin
-from .chat_template_utils import add_json_schema, get_json_schema
+from .chat_template_utils import get_json_schema
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
from .doc import (
add_code_sample_docstrings,
42 changes: 2 additions & 40 deletions src/transformers/utils/chat_template_utils.py
@@ -78,8 +78,8 @@ def get_json_schema(func):
>>> # The formatted chat can now be passed to model.generate()
```
-In many cases, it is more convenient to define tool functions with the [`add_json_schema`] decorator rather than
-calling this function directly.
+In many cases, it is more convenient to simply pass the functions directly to apply_chat_template and let it
+autogenerate schemas, rather than calling this function directly.
"""
doc = inspect.getdoc(func)
if not doc:
@@ -104,44 +104,6 @@ def get_json_schema(func):
return output


-def add_json_schema(func):
-    """
-    This decorator adds a JSON schema to a function, based on its docstring and type hints. The JSON schema is the
-    same as the one generated by the [`get_json_schema`] function. It is stored in the `json_schema` attribute of the
-    function, which will be automatically read by `apply_chat_template()` if present.
-    Example:
-    ```python
-    >>> from transformers import AutoTokenizer
-    >>> from transformers.utils import get_json_schema
-    >>>
-    >>> @add_json_schema
-    >>> def multiply(x: float, y: float):
-    >>>     '''
-    >>>     A function that multiplies two numbers
-    >>>
-    >>>     :param x: The first number to multiply
-    >>>     :param y: The second number to multiply
-    >>>     '''
-    >>>     return x * y
-    >>>
-    >>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
-    >>> messages = [{"role": "user", "content": "What is 179 x 4571?"}]
-    >>> formatted_chat = tokenizer.apply_chat_template(
-    >>>     messages,
-    >>>     tools=[multiply],
-    >>>     chat_template="tool_use",
-    >>>     return_dict=True,
-    >>>     return_tensors="pt",
-    >>>     add_generation_prompt=True
-    >>> )
-    >>> # The formatted chat can now be passed to model.generate()
-    """
-    func.json_schema = get_json_schema(func)
-    return func


def parse_google_format_docstring(docstring):
"""
Parses a Google-style docstring to extract the function description,
