Skip to content

Commit

Permalink
encapsulate chat template logic
Browse files Browse the repository at this point in the history
  • Loading branch information
KonradSzafer committed Jul 22, 2024
1 parent 7987710 commit 8470940
Showing 1 changed file with 72 additions and 48 deletions.
120 changes: 72 additions & 48 deletions src/transformers/tokenization_utils_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1764,54 +1764,7 @@ def apply_chat_template(
if tokenizer_kwargs is None:
tokenizer_kwargs = {}

using_default_template = False

# First, handle the cases when the model has a dict of multiple templates
if isinstance(self.chat_template, dict) or (
self.chat_template is None and isinstance(self.default_chat_template, dict)
):
if self.chat_template is not None:
template_dict = self.chat_template
using_default_dict = False
else:
template_dict = self.default_chat_template
using_default_dict = True
if chat_template is not None and chat_template in template_dict:
# The user can pass the name of a template to the chat template argument instead of an entire template
chat_template = template_dict[chat_template]
if using_default_dict:
using_default_template = True
elif chat_template is None:
if tools is not None and "tool_use" in template_dict:
chat_template = template_dict["tool_use"]
elif "default" in template_dict:
chat_template = template_dict["default"]
else:
raise ValueError(
"This model has multiple chat templates with no default specified! Please either pass a chat "
"template or the name of the template you wish to use to the `chat_template` argument. Available "
f"template names are {sorted(template_dict.keys())}."
)
if using_default_dict:
using_default_template = True

elif chat_template is None:
# These are the cases when the model has a single template
# priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template
if self.chat_template is not None:
chat_template = self.chat_template
else:
chat_template = self.default_chat_template
using_default_template = True

if using_default_template:
logger.warning_once(
"No chat template is set for this tokenizer, falling back to a default class-level template. This is "
"very error-prone, because models are often trained with templates different from the class default! "
"Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
"point any code depending on them will stop working. We recommend setting a valid chat template before "
"then to ensure that this model continues working without issues."
)
chat_template = self.get_chat_template(chat_template, tools)

# Compilation function uses a cache to avoid recompiling the same template
compiled_template = self._compile_jinja_template(chat_template)
Expand Down Expand Up @@ -1908,6 +1861,77 @@ def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False)
jinja_env.globals["raise_exception"] = raise_exception
return jinja_env.from_string(chat_template)

def get_chat_template(self, chat_template: Optional[str] = None, tools: Optional[List[Dict]] = None) -> str:
"""
Retrieve the chat template string used for tokenizing chat messages. This template is used
internally by the `apply_chat_template` method and can also be used externally to retrieve the model's chat
template for better generation tracking.
Args:
chat_template (`str`, *optional*):
A Jinja template or the name of a template to use for this conversion.
It is usually not necessary to pass anything to this argument,
as the model's template will be used by default.
tools (`List[Dict]`, *optional*):
A list of tools (callable functions) that will be accessible to the model. If the template does not
support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema,
giving the name, description and argument types for the tool. See our
[chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
for more information.
Returns:
`str`: The chat template string.
"""
using_default_template = False
# First, handle the cases when the model has a dict of multiple templates
if isinstance(self.chat_template, dict) or (
self.chat_template is None and isinstance(self.default_chat_template, dict)
):
if self.chat_template is not None:
template_dict = self.chat_template
using_default_dict = False
else:
template_dict = self.default_chat_template
using_default_dict = True
if chat_template is not None and chat_template in template_dict:
# The user can pass the name of a template to the chat template argument instead of an entire template
chat_template = template_dict[chat_template]
if using_default_dict:
using_default_template = True
elif chat_template is None:
if tools is not None and "tool_use" in template_dict:
chat_template = template_dict["tool_use"]
elif "default" in template_dict:
chat_template = template_dict["default"]
else:
raise ValueError(
"This model has multiple chat templates with no default specified! Please either pass a chat "
"template or the name of the template you wish to use to the `chat_template` argument. Available "
f"template names are {sorted(template_dict.keys())}."
)
if using_default_dict:
using_default_template = True

elif chat_template is None:
# These are the cases when the model has a single template
# priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template
if self.chat_template is not None:
chat_template = self.chat_template
else:
chat_template = self.default_chat_template
using_default_template = True

if using_default_template:
logger.warning_once(
"No chat template is set for this tokenizer, falling back to a default class-level template. This is "
"very error-prone, because models are often trained with templates different from the class default! "
"Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
"point any code depending on them will stop working. We recommend setting a valid chat template before "
"then to ensure that this model continues working without issues."
)

return chat_template

@property
def default_chat_template(self):
"""
Expand Down

0 comments on commit 8470940

Please sign in to comment.